Mercurial > cosocket
view cosocket.py @ 99:b9cdccd5fbe4 default tip
Added unit test (doctest).
author | Atul Varma <varmaa@toolness.com> |
---|---|
date | Sat, 02 May 2009 08:13:45 -0700 |
parents | 06aa973a54c3 |
children |
line wrap: on
line source
'''
    Coroutine-driven sockets on top of asyncore/asynchat.

    A generator ("coroutine") is attached to an AsyncChatCoroutine
    dispatcher.  Every value the generator yields is either a
    CoroutineInstruction (until_listening, until_connected,
    until_received, until_sent, ...) that tells the dispatcher what
    socket event to wait for, or another generator, which is run as a
    sub-coroutine.  Call loop() to drive all dispatchers and timeouts.

    >>> IP = '127.0.0.1'
    >>> PORT = 38424

    >>> def server_coroutine():
    ...     print 'server now listening.'
    ...     yield until_listening((IP, PORT))
    ...     print 'server now spawning client.'
    ...     AsyncChatCoroutine(client_coroutine())
    ...     print 'server now accepting connections.'
    ...     conn, addr = yield until_connection_accepted()
    ...     print 'server now spawning connection.'
    ...     AsyncChatCoroutine(connection_coroutine(addr), conn)

    >>> def client_coroutine():
    ...     print 'client now connecting to server.'
    ...     yield until_connected((IP, PORT))
    ...     print 'client now connected, sending text.'
    ...     yield until_sent('hai2u\\r\\n')
    ...     print 'client waiting for response.'
    ...     data = yield until_received(terminator = '\\r\\n')
    ...     print 'client received response: %s' % data

    >>> def connection_coroutine(addr):
    ...     print 'server connection waiting for request.'
    ...     data = yield until_received('\\r\\n')
    ...     print 'server connection sending back response: %s' % data
    ...     yield until_sent(data + '\\r\\n')

    >>> server = AsyncChatCoroutine(server_coroutine())
    server now listening.
    server now spawning client.
    client now connecting to server.
    server now accepting connections.

    >>> loop()
    server now spawning connection.
    server connection waiting for request.
    client now connected, sending text.
    client waiting for response.
    server connection sending back response: hai2u
    client received response: hai2u
'''

import sys
import socket
import asyncore
import asynchat
import types
import traceback
import time
import weakref
import logging

# Default timeout, in seconds, passed to each single asyncore.loop() pass.
DEFAULT_LOOP_TIMEOUT = 1.0

# Default time, in seconds, an until_received/until_sent instruction may
# stay pending before its handle_timeout() is invoked.
DEFAULT_TIMEOUT = 90.0

# Default cap on the number of bytes until_received will buffer.
DEFAULT_MAX_DATA = 65536

# Maps a timeout callback to the number of seconds left before it fires.
# Entries are added by AsyncChatCoroutine.set_timeout() and counted down
# by loop().
time_map = {}

def loop(timeout = DEFAULT_LOOP_TIMEOUT):
    '''
    Run the event loop until there are no sockets and no timers left.

    Each iteration performs a single asyncore.loop() pass.  Whenever
    more than `timeout` seconds have elapsed since the timers were last
    updated, the elapsed time is subtracted from every entry of
    time_map; entries that reach zero are removed and their callbacks
    are invoked.

    timeout -- seconds passed to asyncore.loop(); also the granularity
               at which pending timers are counted down.
    '''

    start_time = time.time()

    while asyncore.socket_map or time_map:
        asyncore.loop(timeout = timeout, count = 1)
        curr_time = time.time()
        time_elapsed = curr_time - start_time
        if time_elapsed > timeout:
            start_time = curr_time
            funcs_to_call = []
            for func in time_map:
                time_map[func] -= time_elapsed
                if time_map[func] <= 0:
                    funcs_to_call.append(func)
            # Remove every expired timer before calling any of them, so
            # that a callback that clears or re-arms timers sees a
            # consistent map.
            for func in funcs_to_call:
                del time_map[func]
            for func in funcs_to_call:
                try:
                    func()
                except Exception:
                    # A failing timer callback must not stop the loop,
                    # but KeyboardInterrupt/SystemExit are deliberately
                    # allowed to propagate.
                    logging.error(traceback.format_exc())

class _Trampoline(object):
    '''
    Drives a coroutine (generator) and the sub-coroutines it yields.

    The handler passed to the constructor must provide:

      handle_coroutine_instruction(instruction)
          called with every non-generator value a coroutine yields.

      handle_coroutine_complete(exception)
          called once the outermost coroutine has finished; the
          argument is None on normal completion, or the exception
          instance that escaped the outermost coroutine.
    '''

    def __init__(self, coroutine, handler):
        self.__handler = handler
        # The coroutine currently being run (innermost), or None once
        # everything has finished or been closed.
        self.__coroutine = coroutine
        # Suspended callers of the current coroutine, outermost first.
        self.__coroutine_stack = []

    def __log_error(self):
        '''
        Log the exception being handled plus the coroutine traceback.
        '''

        logging.error(traceback.format_exc() +
                      self.get_formatted_coroutine_traceback())

    def __close_coroutine(self, coroutine):
        '''
        Close a generator, logging (not raising) any error it produces.
        '''

        try:
            coroutine.close()
        except Exception:
            self.__log_error()

    def get_formatted_coroutine_traceback(self):
        '''
        Return a traceback-like string describing where each coroutine
        on the stack is currently suspended.

        Returns "" when there is no current coroutine, and a fixed
        placeholder message when no frames are available.
        '''

        if not self.__coroutine:
            return ""
        lines = []
        frames = [coroutine.gi_frame
                  for coroutine in self.__coroutine_stack]
        # The current coroutine has no frame once it has finished.
        if self.__coroutine.gi_frame:
            frames.append(self.__coroutine.gi_frame)
        for frame in frames:
            name = frame.f_code.co_name
            filename = frame.f_code.co_filename
            lineno = frame.f_lineno
            lines.append('File "%s", line %d, in coroutine %s' %
                         (filename, lineno, name))
        if not lines:
            return 'No coroutine traceback available.'
        lines.insert(0, 'Coroutine traceback (most recent call last):')
        return '\n'.join(lines)

    def close_coroutine_stack(self):
        '''
        Close the current coroutine and every suspended caller,
        innermost first.  Does nothing if there is no current coroutine.
        '''

        if self.__coroutine:
            # Pass an exception back into the coroutine to kick
            # it out of whatever yielding state it's in.
            self.__close_coroutine(self.__coroutine)
            self.__coroutine = None
            while self.__coroutine_stack:
                self.__close_coroutine(self.__coroutine_stack.pop())

    def close_coroutine_and_return_to_caller(self, message):
        '''
        Close the current coroutine and resume its caller, if any,
        sending it `message` as the result of its pending yield.
        '''

        self.__close_coroutine(self.__coroutine)
        if self.__coroutine_stack:
            self.__coroutine = self.__coroutine_stack.pop()
            self.continue_from_yield(message)
        else:
            self.__coroutine = None

    def continue_from_yield(self, message = None, exception_info = None):
        '''
        Resume the current coroutine and act on whatever it does next.

        message        -- value sent into the coroutine as the result of
                          its pending yield.
        exception_info -- if given, a sys.exc_info()-style tuple that is
                          thrown into the coroutine instead of sending
                          `message`.

        A yielded generator is pushed as a sub-coroutine and started
        immediately; any other yielded value is handed to the handler
        as an instruction.  When a coroutine finishes, its caller is
        resumed; when it raises, the exception is thrown into its
        caller.  Once the outermost coroutine finishes or raises, the
        handler's handle_coroutine_complete() is called.
        '''

        try:
            if exception_info:
                instruction = self.__coroutine.throw(*exception_info)
            else:
                instruction = self.__coroutine.send(message)
        except StopIteration:
            if self.__coroutine_stack:
                self.__coroutine = self.__coroutine_stack.pop()
                self.continue_from_yield()
            else:
                self.__coroutine = None
                self.__handler.handle_coroutine_complete(None)
        except Exception:
            # Fetched via sys.exc_info() rather than "except ..., e" so
            # the module parses on every Python version.
            exc_info = sys.exc_info()
            if self.__coroutine_stack:
                self.__coroutine = self.__coroutine_stack.pop()
                self.continue_from_yield(exception_info = exc_info)
            else:
                self.__log_error()
                self.__handler.handle_coroutine_complete(exc_info[1])
        else:
            if isinstance(instruction, types.GeneratorType):
                self.__coroutine_stack.append(self.__coroutine)
                self.__coroutine = instruction
                self.continue_from_yield()
            else:
                self.__handler.handle_coroutine_instruction(instruction)

class AsyncChatCoroutine(asynchat.async_chat):
    '''
    An asynchat dispatcher whose behavior is scripted by a coroutine.

    The coroutine is started as soon as the dispatcher is constructed.
    Each socket event received by the dispatcher is forwarded to the
    instruction the coroutine most recently yielded.
    '''

    def __init__(self, coroutine, conn = None):
        '''
        coroutine -- generator yielding CoroutineInstruction objects
                     (or sub-generators).
        conn      -- optional already-connected socket, passed to
                     asynchat.async_chat.
        '''

        asynchat.async_chat.__init__(self, conn)
        self.trampoline = _Trampoline(coroutine, self)
        self.set_terminator(None)
        self.trampoline.continue_from_yield()

    def handle_coroutine_instruction(self, instruction):
        '''
        Trampoline callback: remember and execute a yielded instruction.
        '''

        # Must be stored before execute(): executing may synchronously
        # trigger events that are dispatched to the current instruction.
        self.__instruction = instruction
        instruction.execute(self)

    def handle_coroutine_complete(self, exception):
        '''
        Trampoline callback: the outermost coroutine has finished.

        exception -- None on normal completion, else the exception that
                     escaped the coroutine.
        '''

        self.__instruction = None
        # NOTE(review): the dispatcher is only closed on normal
        # completion; after a failed coroutine it stays open with no
        # pending instruction.  Confirm whether that is intentional.
        if not exception:
            self.handle_close()

    def handle_close(self):
        '''
        Close the coroutine stack, cancel any timeout, close the socket.
        '''

        self.trampoline.close_coroutine_stack()
        self.clear_timeout()
        self.close()

    def log_info(self, message, type='info'):
        '''
        Route asyncore log messages to the logging module.

        type -- name of a logging level (e.g. 'info', 'error');
                unknown names are logged at INFO.
        '''

        try:
            level = getattr(logging, type.upper())
        except AttributeError:
            level = logging.INFO
        logging.log(level, message)

    def handle_error(self):
        '''
        Log the current exception together with the coroutine traceback.
        '''

        self.log_info(traceback.format_exc() +
                      self.trampoline.get_formatted_coroutine_traceback(),
                      'error')

    def handle_accept(self):
        self.__instruction.handle_accept()

    def handle_connect(self):
        self.__instruction.handle_connect()

    def initiate_send(self):
        '''
        Perform the normal asynchat send, then notify the pending
        instruction so it can check whether all output was flushed.
        '''

        asynchat.async_chat.initiate_send(self)
        self.__instruction.handle_initiate_send()

    def __on_timeout(self):
        self.__instruction.handle_timeout()

    def clear_timeout(self):
        '''
        Cancel this dispatcher's pending timeout, if there is one.
        '''

        if self.__on_timeout in time_map:
            del time_map[self.__on_timeout]

    def set_timeout(self, timeout):
        '''
        Arm (or re-arm) this dispatcher's timeout to `timeout` seconds.
        '''

        time_map[self.__on_timeout] = timeout

    def collect_incoming_data(self, data):
        self.__instruction.collect_incoming_data(data)

    def found_terminator(self):
        self.__instruction.found_terminator()

# Instructions that coroutines yield.

class CoroutineInstruction(object):
    '''
    Base class for the values a coroutine yields to its dispatcher.

    The constructor arguments are stored and later passed to the
    subclass's do_execute() once the instruction is executed on a
    dispatcher.
    '''

    def __init__(self, *args, **kwargs):
        self.__args = args
        self.__kwargs = kwargs

    def execute(self, dispatcher):
        '''
        Bind the instruction to `dispatcher` and run do_execute().
        '''

        self.dispatcher = dispatcher
        self.do_execute(*self.__args, **self.__kwargs)

    def handle_timeout(self):
        '''
        Called when the dispatcher's timeout expires; closes it.
        '''

        self.dispatcher.handle_close()

    def handle_initiate_send(self):
        '''
        Called after every send attempt made by the dispatcher.

        The dispatcher notifies whichever instruction happens to be
        pending, so instructions that are not waiting on output simply
        ignore the notification.
        '''

        pass

class until_listening(CoroutineInstruction):
    '''
    Create a listening TCP socket bound to `bind_addr`; resumes the
    coroutine immediately.
    '''

    def do_execute(self, bind_addr):
        self.dispatcher.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.dispatcher.set_reuse_addr()
        self.dispatcher.bind(bind_addr)
        self.dispatcher.listen(1)
        self.dispatcher.trampoline.continue_from_yield()

class until_connection_accepted(CoroutineInstruction):
    '''
    Wait for an incoming connection; the yield evaluates to whatever
    the dispatcher's accept() returns.
    '''

    def do_execute(self):
        pass

    def handle_accept(self):
        data = self.dispatcher.accept()
        self.dispatcher.trampoline.continue_from_yield(data)

class until_connected(CoroutineInstruction):
    '''
    Open a TCP connection to `addr` and wait until it is established.
    '''

    def do_execute(self, addr):
        self.dispatcher.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.dispatcher.connect(addr)

    def handle_connect(self):
        self.dispatcher.trampoline.continue_from_yield()

class until_received(CoroutineInstruction):
    '''
    Wait for incoming data; the yield evaluates to the data received.

    terminator -- read until this terminator string is seen.
    bytes      -- alternatively, read exactly this many bytes.
    timeout    -- seconds before the dispatcher is closed.
    max_data   -- maximum number of bytes to buffer; the dispatcher is
                  closed when it is exceeded.  Ignored when
                  `terminator` is given.

    Raises ValueError if neither `terminator` nor `bytes` is given.
    '''

    def do_execute(self, terminator = None, bytes = None,
                   timeout = DEFAULT_TIMEOUT, max_data = DEFAULT_MAX_DATA):
        # Validate before arming the timeout, so that a bad call does
        # not leave a stale timer behind in time_map.
        if not terminator and not bytes:
            raise ValueError('Must specify terminator or bytes')
        self.dispatcher.set_timeout(timeout)
        if terminator:
            # NOTE(review): this disables the size cap exactly when the
            # amount of incoming data is not bounded by a byte count --
            # confirm that this is intended.
            max_data = 0
            self.dispatcher.set_terminator(terminator)
        else:
            self.dispatcher.set_terminator(bytes)
        self.__max_data = max_data
        self.__data = []
        self.__data_len = 0

    def collect_incoming_data(self, data):
        self.__data.append(data)
        self.__data_len += len(data)
        if self.__max_data and self.__data_len > self.__max_data:
            logging.error("Max data reached (%s bytes)." % self.__max_data)
            self.dispatcher.handle_close()

    def found_terminator(self):
        # Do not resume the coroutine if the size cap was exceeded; the
        # dispatcher has already been closed in that case.
        if not (self.__max_data and self.__data_len > self.__max_data):
            self.dispatcher.set_terminator(None)
            data = ''.join(self.__data)
            self.__data = []
            self.__data_len = 0
            self.dispatcher.clear_timeout()
            self.dispatcher.trampoline.continue_from_yield(data)

class until_sent(CoroutineInstruction):
    '''
    Send `content` and wait until all of it has been written out.

    timeout -- seconds before the dispatcher is closed.
    '''

    def do_execute(self, content, timeout = DEFAULT_TIMEOUT):
        self.dispatcher.set_timeout(timeout)
        self.dispatcher.push(content)

    def handle_initiate_send(self):
        # Resume only once nothing is left to send and the socket is
        # actually connected.
        if ((not self.dispatcher.ac_out_buffer) and
            (len(self.dispatcher.producer_fifo) == 0) and
            self.dispatcher.connected):
            self.dispatcher.clear_timeout()
            self.dispatcher.trampoline.continue_from_yield()

class return_value(CoroutineInstruction):
    '''
    End the current sub-coroutine and make `value` the result of the
    yield that started it in the calling coroutine.
    '''

    def do_execute(self, value):
        self.dispatcher.trampoline.close_coroutine_and_return_to_caller(value)

if __name__ == '__main__':
    import doctest

    doctest.testmod(verbose = True)