-----------------------------------------------------------------------------
-- A hacked dispatcher module
-- LuaSocket sample files
-- Author: Diego Nehab
-- RCS ID: $$
-----------------------------------------------------------------------------
local base = _G
local table = require("table")
local socket = require("socket")
local coroutine = require("coroutine")

-- Packs a vararg list into a table that records its length in `n`, so that
-- nil values (including trailing ones) survive unpack(t, 1, t.n).
local function pack(...)
    return {n = base.select("#", ...), ...}
end

-- unpack is a global in Lua 5.1 and lives in the table library later on.
local unpack = base.unpack or table.unpack

-- NOTE: module() without package.seeall makes the module table the
-- environment of the rest of this chunk, so standard globals (error, type,
-- setmetatable, pairs, ...) are only reachable through `base` from here on.
module("dispatch")

-- if too much time goes by without any activity in one of our sockets, we
-- just kill it (seconds, compared against differences of socket.gettime())
TIMEOUT = 60

-----------------------------------------------------------------------------
-- We implement 3 types of dispatchers:
--     sequential
--     coroutine
--     threaded
-- The user can choose whichever one is needed
-- NOTE(review): only handlert.sequential and handlert.coroutine are defined
-- in the code visible here; confirm whether "threaded" exists elsewhere.
-----------------------------------------------------------------------------
local handlert = {}

-- Creates a new dispatcher.
-- mode: name of a registered handler ("sequential" or "coroutine");
--       defaults to "coroutine".
-- Returns whatever that handler's factory builds.
-- Raises an error naming the mode if no such handler is registered.
function newhandler(mode)
    mode = mode or "coroutine"
    local factory = handlert[mode]
    if not factory then
        base.error("unknown dispatcher mode: " .. base.tostring(mode), 2)
    end
    return factory()
end

-- start() for the sequential dispatcher: just run func to completion
local function seqstart(self, func)
    return func()
end

-- sequential handler simply calls the functions and doesn't wrap I/O
function handlert.sequential()
    return {
        tcp = socket.tcp,
        start = seqstart
    }
end

-----------------------------------------------------------------------------
-- Mega hack. Don't try to do this at home.
-----------------------------------------------------------------------------
-- we can't yield across calls to protect, so we rewrite it in the style of
-- coxpcall. make sure you don't require any module that uses socket.protect
-- before loading our hack.
--
-- The returned function runs f inside its own coroutine. Every yield from f
-- (e.g. a wrapped socket asking the dispatcher to wait) is passed through to
-- whoever resumed us, and their answer is fed back into f. An error raised
-- as a table t becomes a `nil, t[1]` return (NOTE(review): presumably the
-- tables thrown by socket.try/newtry -- confirm); any other error is
-- re-raised.
function socket.protect(f)
    return function(...)
        local co = coroutine.create(f)
        -- arguments for the next resume: first the caller's, afterwards
        -- whatever our own resumer hands back from the forwarded yield
        local input = pack(...)
        while true do
            local results = pack(coroutine.resume(co, unpack(input, 1, input.n)))
            if not results[1] then
                local err = results[2]
                if base.type(err) == "table" then
                    return nil, err[1]
                end
                -- level 0: don't prepend this line's position to the message
                base.error(err, 0)
            end
            if coroutine.status(co) == "suspended" then
                -- f yielded: forward the yielded values (minus the status)
                input = pack(coroutine.yield(unpack(results, 2, results.n)))
            else
                -- f returned: hand back its results (minus the status)
                return unpack(results, 2, results.n)
            end
        end
    end
end

-----------------------------------------------------------------------------
-- Simple set data structure. O(1) everything.
-----------------------------------------------------------------------------
local function newset()
    -- reverse maps each member to its index in the array part of `set`
    local reverse = {}
    local set = {}
    return base.setmetatable(set, {__index = {
        -- append value unless it is already a member
        insert = function(set, value)
            if not reverse[value] then
                base.table.insert(set, value)
                reverse[value] = #set
            end
        end,
        -- remove value by moving the last member into its slot, so the
        -- array part stays a hole-free sequence (it is passed to
        -- socket.select as a list)
        remove = function(set, value)
            local index = reverse[value]
            if index then
                reverse[value] = nil
                local top = base.table.remove(set)
                if top ~= value then
                    reverse[top] = index
                    set[index] = top
                end
            end
        end
    }})
end

-----------------------------------------------------------------------------
-- socket.tcp() wrapper for the coroutine dispatcher
-----------------------------------------------------------------------------
-- Wraps a raw tcp object so its I/O methods yield to the dispatcher instead
-- of blocking.
-- dispatcher: coroutine dispatcher (uses .sending, .receiving and .stamp)
-- tcp, error: result of socket.tcp() or tcp:accept(); when tcp is nil the
--             pair `nil, error` is returned unchanged.
local function cowrap(dispatcher, tcp, error)
    if not tcp then return nil, error end
    -- put it in non-blocking mode right away
    tcp:settimeout(0)
    -- metatable for wrap produces new methods on demand for those that we
    -- don't override explicitly. A generated method calls the same-named
    -- method of the real socket with the real socket as self, and is cached
    -- in the wrap table.
    local metat = { __index = function(t, key)
        local method = function(_, ...)
            return tcp[key](tcp, ...)
        end
        t[key] = method
        return method
    end }
    -- does our user want to do his own non-blocking I/O?
    local zero = false
    -- create a wrap object that will behave just like a real socket object
    local wrap = {}
    -- we ignore settimeout to preserve our 0 timeout, but record whether
    -- the user wants to do his own non-blocking I/O
    function wrap:settimeout(value, mode)
        zero = (value == 0)
        return 1
    end
    -- send in non-blocking mode and yield on timeout
    function wrap:send(data, first, last)
        first = (first or 1) - 1
        local result, error
        while true do
            -- return control to dispatcher and tell it we want to send
            -- if upon return the dispatcher tells us we timed out,
            -- return an error to whoever called us
            if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try sending, continuing right after the last byte sent so far
            result, error, first = tcp:send(data, first + 1, last)
            -- if we are done, or there was an unexpected error,
            -- break away from loop
            if error ~= "timeout" then return result, error, first end
        end
    end
    -- receive in non-blocking mode and yield on timeout
    -- or simply return partial read, if user requested timeout = 0
    function wrap:receive(pattern, partial)
        local value, error
        while true do
            -- return control to dispatcher and tell it we want to receive
            -- if upon return the dispatcher tells us we timed out,
            -- return an error to whoever called us
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try receiving
            value, error, partial = tcp:receive(pattern, partial)
            -- if we are done, or there was an unexpected error,
            -- break away from loop. also, if the user requested
            -- zero timeout, return all we got
            if error ~= "timeout" or zero then
                return value, error, partial
            end
        end
    end
    -- connect in non-blocking mode and yield on timeout
    function wrap:connect(host, port)
        local result, error = tcp:connect(host, port)
        if error ~= "timeout" then return result, error end
        -- return control to dispatcher. we will be writable when
        -- connection succeeds.
        -- if upon return the dispatcher tells us we have a
        -- timeout, just abort
        if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
            return nil, "timeout"
        end
        -- when we come back, check if connection was successful
        result, error = tcp:connect(host, port)
        if result or error == "already connected" then
            return 1
        else
            return nil, "non-blocking connect failed"
        end
    end
    -- accept in non-blocking mode and yield on timeout
    function wrap:accept()
        while true do
            -- return control to dispatcher. we will be readable when a
            -- connection arrives.
            -- if upon return the dispatcher tells us we have a
            -- timeout, just abort
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            local client, error = tcp:accept()
            if error ~= "timeout" then
                return cowrap(dispatcher, client, error)
            end
        end
    end
    -- forget every coroutine and timestamp registered for this socket,
    -- then close the real socket
    function wrap:close()
        dispatcher.stamp[tcp] = nil
        dispatcher.sending.set:remove(tcp)
        dispatcher.sending.cortn[tcp] = nil
        dispatcher.receiving.set:remove(tcp)
        dispatcher.receiving.cortn[tcp] = nil
        return tcp:close()
    end
    return base.setmetatable(wrap, metat)
end

-----------------------------------------------------------------------------
-- Our coroutine dispatcher
-----------------------------------------------------------------------------
local cometat = { __index = {} }

-- Registers a coroutine that yielded (operation, tcp), so it is woken when
-- tcp becomes ready for that operation. Takes the coroutine followed by the
-- results of coroutine.resume on it; re-raises the coroutine's error.
function schedule(cortn, status, operation, tcp)
    if status then
        -- only a still-suspended coroutine is waiting on anything; a dead one
        -- may have returned arbitrary values in the `operation` slot
        if cortn and operation and coroutine.status(cortn) == "suspended" then
            operation.set:insert(tcp)
            operation.cortn[tcp] = cortn
            operation.stamp[tcp] = socket.gettime()
        end
    else
        -- level 0: don't prepend this line's position to the message
        base.error(operation, 0)
    end
end

-- unregister whatever coroutine is waiting on tcp for this operation
function kick(operation, tcp)
    operation.cortn[tcp] = nil
    operation.set:remove(tcp)
end

-- resume the coroutine waiting on tcp for this operation; the results are
-- shaped for schedule()
function wakeup(operation, tcp)
    local cortn = operation.cortn[tcp]
    -- if cortn is still valid, wake it up
    if cortn then
        kick(operation, tcp)
        return cortn, coroutine.resume(cortn)
    -- otherwise, just get scheduler not to do anything
    else
        return nil, true
    end
end

-- resume the coroutine waiting on tcp with "timeout", and reschedule it if
-- it goes on to wait for more I/O (otherwise it would be lost)
function abort(operation, tcp)
    local cortn = operation.cortn[tcp]
    if cortn then
        kick(operation, tcp)
        schedule(cortn, coroutine.resume(cortn, "timeout"))
    end
end

-- step through all active cortns
-- timeout: seconds passed to socket.select (defaults to 1)
function cometat.__index:step(timeout)
    timeout = timeout or 1
    -- check which sockets are interesting and act on them
    local readable, writable = socket.select(self.receiving.set,
        self.sending.set, timeout)
    -- for all readable connections, resume their cortns and reschedule
    -- when they yield back to us
    for _, tcp in base.ipairs(readable) do
        schedule(wakeup(self.receiving, tcp))
    end
    -- for all writable connections, do the same
    for _, tcp in base.ipairs(writable) do
        schedule(wakeup(self.sending, tcp))
    end
    -- politely ask replacement I/O functions in idle cortns to
    -- return reporting a timeout. collect the idle sockets first: aborting
    -- resumes coroutines, which may add keys to self.stamp, and adding keys
    -- during a pairs() traversal is undefined.
    local now = socket.gettime()
    local expired = {}
    for tcp, stamp in base.pairs(self.stamp) do
        if tcp.class == "tcp{client}" and now - stamp > TIMEOUT then
            expired[#expired + 1] = tcp
        end
    end
    for _, tcp in base.ipairs(expired) do
        -- decide both aborts up front, so a coroutine that the first abort
        -- rescheduled on this tcp is not immediately aborted again
        local sending = self.sending.cortn[tcp] ~= nil
        local receiving = self.receiving.cortn[tcp] ~= nil
        if sending then abort(self.sending, tcp) end
        if receiving then abort(self.receiving, tcp) end
    end
end

-- run func in a new coroutine until its first yield, then schedule it
function cometat.__index:start(func)
    local cortn = coroutine.create(func)
    schedule(cortn, coroutine.resume(cortn))
end

-- Builds a coroutine dispatcher: tcp() returns wrapped sockets whose I/O
-- yields to the dispatcher, start(func) launches a coroutine, and step()
-- performs one select/resume/timeout round.
function handlert.coroutine()
    -- shared by both operations: last time each socket was scheduled
    local stamp = {}
    local dispatcher = {
        stamp = stamp,
        sending = {
            name = "sending",
            set = newset(),
            cortn = {},
            stamp = stamp
        },
        receiving = {
            name = "receiving",
            set = newset(),
            cortn = {},
            stamp = stamp
        },
    }
    function dispatcher.tcp()
        return cowrap(dispatcher, socket.tcp())
    end
    return base.setmetatable(dispatcher, cometat)
end