Mercurial > cosocket
view cosocket.py @ 97:0d3dd2ab36cd
Factored out all instruction-specific logic from the dispatcher into individual instructions.
author | Atul Varma <varmaa@toolness.com> |
---|---|
date | Sat, 02 May 2009 00:14:59 -0700 |
parents | 68598f164855 |
children | 06aa973a54c3 |
line wrap: on
line source
import sys
import socket
import asyncore
import asynchat
import types
import traceback
import time
import weakref
import logging

DEFAULT_LOOP_TIMEOUT = 1.0
DEFAULT_TIMEOUT = 90.0
DEFAULT_MAX_DATA = 65536

# Pending timeouts: maps a zero-argument callable to the number of
# seconds remaining before loop() removes it and calls it.
time_map = {}


def loop(timeout = DEFAULT_LOOP_TIMEOUT):
    """Run the asyncore event loop forever, firing expired timeouts.

    Each iteration polls asyncore once, waiting up to `timeout` seconds.
    Once more than `timeout` seconds have passed since the previous
    update, that elapsed time is subtracted from every entry in
    time_map. Each callable whose remaining time has reached zero is
    removed from time_map and then called. Timeouts therefore fire with
    a granularity of roughly `timeout` seconds. Exceptions raised by a
    timeout callback are logged and do not stop the loop.
    """

    start_time = time.time()
    while 1:
        asyncore.loop(timeout = timeout, count = 1)
        curr_time = time.time()
        time_elapsed = curr_time - start_time
        if time_elapsed <= timeout:
            continue
        start_time = curr_time
        funcs_to_call = []
        for func in time_map:
            time_map[func] -= time_elapsed
            if time_map[func] <= 0:
                funcs_to_call.append(func)
        # Drop every expired entry before running any callback, so a
        # callback that registers a new timeout keeps it.
        for func in funcs_to_call:
            del time_map[func]
        for func in funcs_to_call:
            try:
                func()
            except Exception:
                # Not a bare except: KeyboardInterrupt and SystemExit
                # must still be able to stop the loop.
                logging.error(traceback.format_exc())


class _Trampoline(object):
    """Run a generator-based coroutine and the coroutines it calls.

    A coroutine may yield:

    * another generator -- treated as a nested call: the current
      coroutine is pushed onto a stack and the new generator runs; when
      it finishes, its caller is resumed; or
    * anything else -- passed to handler.handle_coroutine_instruction().

    An exception escaping a nested coroutine is thrown into its caller.
    When the outermost coroutine finishes, the handler's
    handle_coroutine_complete() is called with None, or with the
    exception that ended it (after logging it).
    """

    def __init__(self, coroutine, handler):
        self.__handler = handler
        self.__coroutine = coroutine
        # Suspended callers of the current coroutine, outermost first.
        self.__coroutine_stack = []

    def __log_error(self):
        # Log the exception currently being handled, followed by the
        # coroutine call chain.
        logging.error(traceback.format_exc() +
                      self.get_formatted_coroutine_traceback())

    def __close_coroutine(self, coroutine):
        # Anything raised while closing the generator is logged rather
        # than propagated.
        try:
            coroutine.close()
        except Exception:
            self.__log_error()

    def get_formatted_coroutine_traceback(self):
        """Return a traceback-like listing of the coroutine call chain.

        Entries run from the outermost caller to the current coroutine.
        Returns "" when there is no current coroutine.
        """

        if not self.__coroutine:
            return ""
        lines = ['Coroutine traceback (most recent call last):']
        for coroutine in self.__coroutine_stack + [self.__coroutine]:
            frame = coroutine.gi_frame
            if frame is None:
                # A generator that has already finished (returned or
                # raised) has no frame. This is the case when logging the
                # exception that ended the current coroutine, so skip it
                # instead of failing while reporting that very error.
                continue
            code = frame.f_code
            lines.append('File "%s", line %d, in coroutine %s' %
                         (code.co_filename, frame.f_lineno, code.co_name))
        return '\n'.join(lines)

    def close_coroutine_stack(self):
        """Close the current coroutine and all of its suspended callers."""

        if self.__coroutine:
            # Pass an exception back into the coroutine to kick
            # it out of whatever yielding state it's in.
            self.__close_coroutine(self.__coroutine)
            self.__coroutine = None
        while self.__coroutine_stack:
            self.__close_coroutine(self.__coroutine_stack.pop())

    def close_coroutine_and_return_to_caller(self, message):
        """Finish the current coroutine, handing `message` to its caller.

        The current coroutine is closed. If it has a caller, the caller is
        resumed with `message` as the value of its pending yield.
        Otherwise the handler is told the coroutine has completed.
        """

        self.__close_coroutine(self.__coroutine)
        if self.__coroutine_stack:
            self.__coroutine = self.__coroutine_stack.pop()
            self.continue_from_yield(message)
        else:
            self.__coroutine = None
            # Same notification as when the outermost coroutine runs to
            # completion (StopIteration in continue_from_yield). Without
            # it, the handler would never learn the coroutine finished.
            self.__handler.handle_coroutine_complete(None)

    def continue_from_yield(self, message = None, exception_info = None):
        """Resume the current coroutine and run it until it yields.

        `message` is sent in as the result of the pending yield. If
        `exception_info` is given instead (a (type, value, traceback)
        tuple as returned by sys.exc_info()), that exception is raised at
        the pending yield.
        """

        try:
            if exception_info:
                instruction = self.__coroutine.throw(*exception_info)
            else:
                instruction = self.__coroutine.send(message)
        except StopIteration:
            # The coroutine finished: resume its caller, or report
            # completion if it was the outermost one.
            if self.__coroutine_stack:
                self.__coroutine = self.__coroutine_stack.pop()
                self.continue_from_yield()
            else:
                self.__coroutine = None
                self.__handler.handle_coroutine_complete(None)
        except Exception as e:
            # The coroutine raised: throw the exception into its caller,
            # or log and report it if there is no caller.
            if self.__coroutine_stack:
                self.__coroutine = self.__coroutine_stack.pop()
                self.continue_from_yield(exception_info = sys.exc_info())
            else:
                self.__log_error()
                self.__handler.handle_coroutine_complete(e)
        else:
            if isinstance(instruction, types.GeneratorType):
                # Nested call: suspend the caller and start the callee.
                self.__coroutine_stack.append(self.__coroutine)
                self.__coroutine = instruction
                self.continue_from_yield()
            else:
                self.__handler.handle_coroutine_instruction(instruction)


class _AsyncChatCoroutineDispatcher(asynchat.async_chat):
    """asynchat channel that drives a coroutine through a _Trampoline.

    Instructions yielded by the coroutine are executed against this
    dispatcher, and data, terminator, send and timeout events are
    forwarded to the current instruction. If `conn` (an accepted socket)
    is given, the coroutine starts immediately. Otherwise it starts when
    handle_connect() is called.
    """

    def __init__(self, coroutine, conn = None):
        asynchat.async_chat.__init__(self, conn)
        # The instruction most recently yielded by the coroutine. It is
        # None until the first one arrives and after the coroutine
        # completes.
        self.__instruction = None
        self.trampoline = _Trampoline(coroutine, self)
        self.set_terminator(None)
        if conn:
            self.trampoline.continue_from_yield()

    def handle_coroutine_instruction(self, instruction):
        self.__instruction = instruction
        instruction.execute(self)

    def handle_coroutine_complete(self, exception):
        self.__instruction = None
        # NOTE(review): the connection is only closed when the coroutine
        # completed without an exception; confirm that leaving it open
        # after an unhandled exception is intended.
        if not exception:
            self.handle_close()

    def handle_close(self):
        self.trampoline.close_coroutine_stack()
        self.clear_timeout()
        self.close()

    def log_info(self, message, type='info'):
        # Route asyncore's own diagnostics through the logging module,
        # falling back to INFO for unknown level names.
        try:
            level = getattr(logging, type.upper())
        except AttributeError:
            level = logging.INFO
        logging.log(level, message)

    def handle_error(self):
        self.log_info(traceback.format_exc() +
                      self.trampoline.get_formatted_coroutine_traceback(),
                      'error')

    def handle_connect(self):
        # Outgoing connections are created without `conn`, so their
        # coroutine has not run yet. Start it now that the socket is
        # connected. No instruction exists at this point, so there is
        # nothing to forward the event to.
        self.trampoline.continue_from_yield()

    def initiate_send(self):
        asynchat.async_chat.initiate_send(self)
        # asyncore's write path can reach here before the first
        # instruction exists, or while the pending instruction has no
        # interest in sends. Only notify instructions that define the
        # hook.
        hook = getattr(self.__instruction, 'handle_initiate_send', None)
        if hook is not None:
            hook()

    def __on_timeout(self):
        self.__instruction.handle_timeout()

    def clear_timeout(self):
        time_map.pop(self.__on_timeout, None)

    def set_timeout(self, timeout):
        time_map[self.__on_timeout] = timeout

    def collect_incoming_data(self, data):
        self.__instruction.collect_incoming_data(data)

    def found_terminator(self):
        self.__instruction.found_terminator()


class CoroutineSocketServer(asyncore.dispatcher):
    """TCP server that runs one coroutine per accepted connection.

    coroutineFactory(addr) is called with each client's address and must
    return the generator that will handle that connection.
    """

    def __init__(self, addr, coroutineFactory):
        asyncore.dispatcher.__init__(self)
        self.__coroutineFactory = coroutineFactory
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind(addr)
        self.listen(1)

    def run(self, timeout = DEFAULT_LOOP_TIMEOUT):
        """Run loop() with the given poll timeout; does not return."""
        loop(timeout)

    def handle_accept(self):
        pair = self.accept()
        if pair is None:
            # asyncore's accept() returns None when no connection was
            # actually ready (e.g. EWOULDBLOCK).
            return
        conn, addr = pair
        coroutine = self.__coroutineFactory(addr)
        _AsyncChatCoroutineDispatcher(coroutine, conn)


class CoroutineSocketClient(_AsyncChatCoroutineDispatcher):
    """TCP client that connects to `addr` and then runs the coroutine
    returned by coroutineFactory(addr)."""

    def __init__(self, addr, coroutineFactory):
        coroutine = coroutineFactory(addr)
        _AsyncChatCoroutineDispatcher.__init__(self, coroutine)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.connect(addr)


# Instructions that coroutines yield.
class CoroutineInstruction(object):
    """Base class for the objects a coroutine yields to request I/O.

    The constructor only records its arguments. When the dispatcher
    receives the yielded instruction it calls execute(). execute()
    stores the dispatcher on self.dispatcher and passes the recorded
    arguments to the subclass's do_execute(). The dispatcher then
    forwards events such as incoming data and timeouts to the current
    instruction.
    """

    def __init__(self, *args, **kwargs):
        self.__args = args
        self.__kwargs = kwargs

    def execute(self, dispatcher):
        self.dispatcher = dispatcher
        self.do_execute(*self.__args, **self.__kwargs)

    def handle_initiate_send(self):
        """Send-progress hook; ignored unless a subclass overrides it."""
        # The dispatcher's initiate_send() calls this on whichever
        # instruction is current, including ones that never send (e.g.
        # until_received), so the default must exist and do nothing.
        pass

    def handle_timeout(self):
        # Default reaction to this instruction's timeout: drop the
        # connection.
        self.dispatcher.handle_close()

class until_received(CoroutineInstruction):
    """Suspend the coroutine until a terminator or a byte count arrives.

    Either `terminator` or `bytes` must be given; if both are, the
    terminator wins. Otherwise ValueError is raised. The collected data
    is joined into one string and sent back into the coroutine. If more
    than `max_data` bytes accumulate, the error is logged and the
    connection is closed (a `max_data` of 0 means no limit). A timeout
    of `timeout` seconds is armed while waiting; on expiry the base
    handle_timeout() closes the connection.
    """

    def do_execute(self, terminator = None, bytes = None,
                   timeout = DEFAULT_TIMEOUT, max_data = DEFAULT_MAX_DATA):
        self.dispatcher.set_timeout(timeout)
        if terminator:
            # NOTE(review): the size cap is disabled when reading up to a
            # terminator and only enforced for byte counts. Confirm this is
            # intended: with no terminator in sight, data can grow until
            # the timeout fires.
            max_data = 0
            self.dispatcher.set_terminator(terminator)
        elif bytes:
            self.dispatcher.set_terminator(bytes)
        else:
            raise ValueError('Must specify terminator or bytes')
        self.__max_data = max_data
        self.__data = []
        self.__data_len = 0

    def __over_limit(self):
        # True once more than max_data bytes have been collected; a
        # max_data of 0 disables the check.
        return bool(self.__max_data) and self.__data_len > self.__max_data

    def collect_incoming_data(self, data):
        if self.__over_limit():
            # The limit was already exceeded and the connection closed.
            # Anything still delivered from asynchat's buffer is
            # discarded, not accumulated, logged or closed again.
            return
        self.__data.append(data)
        self.__data_len += len(data)
        if self.__over_limit():
            logging.error("Max data reached (%s bytes)." % self.__max_data)
            self.dispatcher.handle_close()

    def found_terminator(self):
        # After an over-limit close, ignore the terminator so the closed
        # coroutine is not resumed.
        if not self.__over_limit():
            self.dispatcher.set_terminator(None)
            data = ''.join(self.__data)
            self.__data = []
            self.__data_len = 0
            self.dispatcher.clear_timeout()
            self.dispatcher.trampoline.continue_from_yield(data)

class until_sent(CoroutineInstruction):
    """Suspend the coroutine until `content` has been fully sent.

    The coroutine is resumed with None once asynchat's output buffer and
    producer queue are both empty and the socket is still connected. A
    timeout of `timeout` seconds is armed while waiting.
    """

    def do_execute(self, content, timeout = DEFAULT_TIMEOUT):
        self.dispatcher.set_timeout(timeout)
        self.dispatcher.push(content)

    def handle_initiate_send(self):
        if ((not self.dispatcher.ac_out_buffer) and
            (len(self.dispatcher.producer_fifo) == 0) and
            self.dispatcher.connected):
            self.dispatcher.clear_timeout()
            self.dispatcher.trampoline.continue_from_yield()

class return_value(CoroutineInstruction):
    """Finish the current coroutine, handing `value` to its caller."""

    def do_execute(self, value):
        self.dispatcher.trampoline.close_coroutine_and_return_to_caller(value)