Mercurial > cosocket
view cosocket.py @ 95:57681224a62a
Attempted to factor out the trampoline into a separate object, and modified the way coroutine instructions are done.
author | Atul Varma <varmaa@toolness.com> |
---|---|
date | Fri, 01 May 2009 23:41:52 -0700 |
parents | abb3952c2209 |
children | 68598f164855 |
line wrap: on
line source
import sys import socket import asyncore import asynchat import types import traceback import time import weakref import logging DEFAULT_LOOP_TIMEOUT = 1.0 DEFAULT_TIMEOUT = 90.0 DEFAULT_MAX_DATA = 65536 time_map = {} def loop(timeout = DEFAULT_LOOP_TIMEOUT): start_time = time.time() while 1: asyncore.loop(timeout = timeout, count = 1) curr_time = time.time() time_elapsed = curr_time - start_time if time_elapsed > timeout: start_time = curr_time funcs_to_call = [] for func in time_map: time_map[func] -= time_elapsed if time_map[func] <= 0: funcs_to_call.append(func) for func in funcs_to_call: del time_map[func] for func in funcs_to_call: try: func() except: logging.error(traceback.format_exc()) class _Trampoline(object): def __init__(self, coroutine, handler): self.__handler = handler self.__coroutine = coroutine self.__coroutine_stack = [] def __log_error(self): logging.error(traceback.format_exc() + self.get_formatted_coroutine_traceback()) def __close_coroutine(self, coroutine): try: coroutine.close() except Exception: self.__log_error() def get_formatted_coroutine_traceback(self): if not self.__coroutine: return "" lines = ['Coroutine traceback (most recent call last):'] frames = [coroutine.gi_frame for coroutine in self.__coroutine_stack] frames.append(self.__coroutine.gi_frame) for frame in frames: name = frame.f_code.co_name filename = frame.f_code.co_filename lineno = frame.f_lineno lines.append('File "%s", line %d, in coroutine %s' % (filename, lineno, name)) return '\n'.join(lines) def close_coroutine_stack(self): if self.__coroutine: # Pass an exception back into the coroutine to kick # it out of whatever yielding state it's in. 
self.__close_coroutine(self.__coroutine) self.__coroutine = None while self.__coroutine_stack: self.__close_coroutine(self.__coroutine_stack.pop()) def close_coroutine_and_return_to_caller(self, message): self.__close_coroutine(self.__coroutine) if self.__coroutine_stack: self.__coroutine = self.__coroutine_stack.pop() self.continue_from_yield(message) else: self.__coroutine = None def continue_from_yield(self, message = None, exception_info = None): try: if exception_info: instruction = self.__coroutine.throw(*exception_info) else: instruction = self.__coroutine.send(message) except StopIteration: if self.__coroutine_stack: self.__coroutine = self.__coroutine_stack.pop() self.continue_from_yield() else: self.__coroutine = None self.__handler.handle_coroutine_complete() except Exception, e: if self.__coroutine_stack: self.__coroutine = self.__coroutine_stack.pop() self.continue_from_yield(exception_info = sys.exc_info()) else: self.__log_error() else: if type(instruction) == types.GeneratorType: self.__coroutine_stack.append(self.__coroutine) self.__coroutine = instruction self.continue_from_yield() else: instruction.execute(self.__handler) class _AsyncChatCoroutineDispatcher(asynchat.async_chat): def __init__(self, coroutine, conn = None): asynchat.async_chat.__init__(self, conn) self.trampoline = _Trampoline(coroutine, self) self.set_terminator(None) self.__max_data = DEFAULT_MAX_DATA self.__timeout = 0 self.__data = [] self.__data_len = 0 if conn: self.trampoline.continue_from_yield() def handle_coroutine_complete(self): self.handle_close() def handle_close(self): self.trampoline.close_coroutine_stack() self.clear_timeout() self.close() def log_info(self, message, type='info'): try: level = getattr(logging, type.upper()) except AttributeError: level = logging.INFO logging.log(level, message) def handle_error(self): self.log_info(traceback.format_exc() + self.trampoline.get_formatted_coroutine_traceback(), 'error') def handle_connect(self): 
self.trampoline.continue_from_yield() def initiate_send(self): asynchat.async_chat.initiate_send(self) if ((not self.ac_out_buffer) and (len(self.producer_fifo) == 0) and self.connected): self.trampoline.continue_from_yield() def __on_timeout(self): self.log_info("Timeout expired (%ss)." % self.__timeout, 'error') self.handle_close() def clear_timeout(self): self.__timeout = 0 if self.__on_timeout in time_map: del time_map[self.__on_timeout] def set_timeout(self, timeout): self.__timeout = timeout time_map[self.__on_timeout] = timeout def collect_incoming_data(self, data): self.__data.append(data) self.__data_len += len(data) if self.__max_data and self.__data_len > self.__max_data: self.log_info("Max data reached (%s bytes)." % self.__max_data, 'error') self.handle_close() def set_max_data(self, amount): self.__max_data = amount def found_terminator(self): if not (self.__max_data and self.__data_len > self.__max_data): self.set_terminator(None) data = ''.join(self.__data) self.__data = [] self.__data_len = 0 self.trampoline.continue_from_yield(data) class CoroutineSocketServer(asyncore.dispatcher): def __init__(self, addr, coroutineFactory): asyncore.dispatcher.__init__(self) self.__coroutineFactory = coroutineFactory self.create_socket(socket.AF_INET, socket.SOCK_STREAM) self.set_reuse_addr() self.bind(addr) self.listen(1) def run(self, timeout = DEFAULT_LOOP_TIMEOUT): loop(timeout) def handle_accept(self): conn, addr = self.accept() coroutine = self.__coroutineFactory(addr) _AsyncChatCoroutineDispatcher(coroutine, conn) class CoroutineSocketClient(_AsyncChatCoroutineDispatcher): def __init__(self, addr, coroutineFactory): coroutine = coroutineFactory(addr) _AsyncChatCoroutineDispatcher.__init__(self, coroutine) self.create_socket(socket.AF_INET, socket.SOCK_STREAM) self.connect(addr) # Instructions that coroutines yield. 
class CoroutineInstruction(object):
    """Base class for the instruction objects a coroutine yields.

    Constructor arguments are captured unchanged and replayed into the
    subclass's ``do_execute`` once ``execute`` is handed a dispatcher.
    """

    def __init__(self, *args, **kwargs):
        self.__args = args
        self.__kwargs = kwargs

    def execute(self, dispatcher):
        """Remember *dispatcher* on the instance, then call ``do_execute``."""
        self.dispatcher = dispatcher
        positional, keywords = self.__args, self.__kwargs
        self.do_execute(*positional, **keywords)

class until_received(CoroutineInstruction):
    """Wait for incoming data: up to *terminator*, or exactly *bytes* bytes.

    The timeout is armed first; ValueError is then raised if neither
    *terminator* nor *bytes* is truthy.
    """

    def do_execute(self, terminator = None, bytes = None,
                   timeout = DEFAULT_TIMEOUT,
                   max_data = DEFAULT_MAX_DATA):
        channel = self.dispatcher
        channel.set_timeout(timeout)
        if not (terminator or bytes):
            raise ValueError('Must specify terminator or bytes')
        if terminator:
            # NOTE(review): a terminator read always disables the size
            # limit, so a caller-supplied max_data is ignored here and only
            # applies to byte-count reads -- confirm this is intended.
            max_data = 0
            channel.set_terminator(terminator)
        else:
            channel.set_terminator(bytes)
        channel.set_max_data(max_data)

class until_sent(CoroutineInstruction):
    """Queue *content* for sending, with *timeout* seconds to finish."""

    def do_execute(self, content, timeout = DEFAULT_TIMEOUT):
        channel = self.dispatcher
        channel.set_timeout(timeout)
        channel.push(content)

class return_value(CoroutineInstruction):
    """End the current coroutine, handing *value* back to its caller."""

    def do_execute(self, value):
        trampoline = self.dispatcher.trampoline
        trampoline.close_coroutine_and_return_to_caller(value)