Mercurial > cosocket
changeset 95:57681224a62a
Attempted to factor out the trampoline into a separate object, and modified the way coroutine instructions are done.
author | Atul Varma <varmaa@toolness.com> |
---|---|
date | Fri, 01 May 2009 23:41:52 -0700 |
parents | abb3952c2209 |
children | 68598f164855 |
files | channels.py cosocket.py |
diffstat | 2 files changed, 104 insertions(+), 95 deletions(-) [+] |
line wrap: on
line diff
--- a/channels.py Fri May 01 21:33:57 2009 -0700 +++ b/channels.py Fri May 01 23:41:52 2009 -0700 @@ -2,31 +2,25 @@ _channels = {} -class until_message_sent(object): - def __init__(self, channel_name, message): - self.channel_name = channel_name - self.message = message - - def execute(self, dispatcher): - if self.channel_name in _channels: - receivers = _channels[self.channel_name].values() - del _channels[self.channel_name] +class until_message_sent(cosocket.CoroutineInstruction): + def do_execute(self, channel_name, message): + if channel_name in _channels: + receivers = _channels[channel_name].values() + del _channels[channel_name] for receiver in receivers: - receiver.continue_from_yield(self.message) - dispatcher.continue_from_yield() + receiver.trampoline.continue_from_yield(message) + self.dispatcher.trampoline.continue_from_yield() -class _until_message_received(object): - def __init__(self, channel_name, timeout = cosocket.DEFAULT_TIMEOUT): +class _until_message_received(cosocket.CoroutineInstruction): + def do_execute(self, channel_name, + timeout = cosocket.DEFAULT_TIMEOUT): + if channel_name not in _channels: + _channels[channel_name] = {} + fd = self.dispatcher.socket.fileno() + _channels[channel_name][fd] = self.dispatcher + self.dispatcher.set_timeout(timeout) self.channel_name = channel_name - self._timeout = timeout - self._fd = None - - def execute(self, dispatcher): - if self.channel_name not in _channels: - _channels[self.channel_name] = {} - self._fd = dispatcher.socket.fileno() - _channels[self.channel_name][self._fd] = dispatcher - dispatcher.set_timeout(self._timeout) + self._fd = fd def finalize(self): if (self.channel_name in _channels and
--- a/cosocket.py Fri May 01 21:33:57 2009 -0700 +++ b/cosocket.py Fri May 01 23:41:52 2009 -0700 @@ -35,18 +35,45 @@ except: logging.error(traceback.format_exc()) -class _AsyncChatCoroutineDispatcher(asynchat.async_chat): - def __init__(self, coroutine, conn = None): - asynchat.async_chat.__init__(self, conn) - self.set_terminator(None) - self.__max_data = DEFAULT_MAX_DATA +class _Trampoline(object): + def __init__(self, coroutine, handler): + self.__handler = handler self.__coroutine = coroutine - self.__timeout = 0 - self.__data = [] - self.__data_len = 0 self.__coroutine_stack = [] - if conn: - self.continue_from_yield() + + def __log_error(self): + logging.error(traceback.format_exc() + + self.get_formatted_coroutine_traceback()) + + def __close_coroutine(self, coroutine): + try: + coroutine.close() + except Exception: + self.__log_error() + + def get_formatted_coroutine_traceback(self): + if not self.__coroutine: + return "" + lines = ['Coroutine traceback (most recent call last):'] + frames = [coroutine.gi_frame + for coroutine in self.__coroutine_stack] + frames.append(self.__coroutine.gi_frame) + for frame in frames: + name = frame.f_code.co_name + filename = frame.f_code.co_filename + lineno = frame.f_lineno + lines.append('File "%s", line %d, in coroutine %s' % + (filename, lineno, name)) + return '\n'.join(lines) + + def close_coroutine_stack(self): + if self.__coroutine: + # Pass an exception back into the coroutine to kick + # it out of whatever yielding state it's in. + self.__close_coroutine(self.__coroutine) + self.__coroutine = None + while self.__coroutine_stack: + self.__close_coroutine(self.__coroutine_stack.pop()) def close_coroutine_and_return_to_caller(self, message): self.__close_coroutine(self.__coroutine) @@ -56,9 +83,7 @@ else: self.__coroutine = None - def continue_from_yield(self, message = None, - exception_info = None): - self.clear_timeout() + def continue_from_yield(self, message = None, exception_info = None): try: if exception_info: instruction = self.__coroutine.throw(*exception_info) @@ -70,46 +95,38 @@ self.continue_from_yield() else: self.__coroutine = None - self.handle_close() + self.__handler.handle_coroutine_complete() except Exception, e: if self.__coroutine_stack: self.__coroutine = self.__coroutine_stack.pop() self.continue_from_yield(exception_info = sys.exc_info()) else: - self.handle_error() + self.__log_error() else: if type(instruction) == types.GeneratorType: self.__coroutine_stack.append(self.__coroutine) self.__coroutine = instruction self.continue_from_yield() else: - instruction.execute(self) - - def __close_coroutine(self, coroutine): - try: - coroutine.close() - except Exception: - self.log_info(traceback.format_exc(), 'error') + instruction.execute(self.__handler) - def get_formatted_coroutine_traceback(self): - lines = ['Coroutine traceback (most recent call last):'] - for frame in [coroutine.gi_frame - for coroutine in self.__coroutine_stack]: - name = frame.f_code.co_name - filename = frame.f_code.co_filename - lineno = frame.f_lineno - lines.append('File "%s", line %d, in coroutine %s' % - (filename, lineno, name)) - return '\n'.join(lines) +class _AsyncChatCoroutineDispatcher(asynchat.async_chat): + def __init__(self, coroutine, conn = None): + asynchat.async_chat.__init__(self, conn) + self.trampoline = _Trampoline(coroutine, self) + self.set_terminator(None) + self.__max_data = DEFAULT_MAX_DATA + self.__timeout = 0 + self.__data = [] + self.__data_len = 0 + if conn: + self.trampoline.continue_from_yield() + + def handle_coroutine_complete(self): + self.handle_close() def handle_close(self): - if self.__coroutine: - # Pass an exception back into the coroutine to kick - # it out of whatever yielding state it's in. - self.__close_coroutine(self.__coroutine) - self.__coroutine = None - while self.__coroutine_stack: - self.__close_coroutine(self.__coroutine_stack.pop()) + self.trampoline.close_coroutine_stack() self.clear_timeout() self.close() @@ -122,18 +139,18 @@ def handle_error(self): self.log_info(traceback.format_exc() + - self.get_formatted_coroutine_traceback(), + self.trampoline.get_formatted_coroutine_traceback(), 'error') def handle_connect(self): - self.continue_from_yield() + self.trampoline.continue_from_yield() def initiate_send(self): asynchat.async_chat.initiate_send(self) if ((not self.ac_out_buffer) and (len(self.producer_fifo) == 0) and self.connected): - self.continue_from_yield() + self.trampoline.continue_from_yield() def __on_timeout(self): self.log_info("Timeout expired (%ss)." % self.__timeout, @@ -166,7 +183,7 @@ data = ''.join(self.__data) self.__data = [] self.__data_len = 0 - self.continue_from_yield(data) + self.trampoline.continue_from_yield(data) class CoroutineSocketServer(asyncore.dispatcher): def __init__(self, addr, coroutineFactory): @@ -194,38 +211,36 @@ # Instructions that coroutines yield. -class until_received(object): - def __init__(self, terminator = None, bytes = None, - timeout = DEFAULT_TIMEOUT, max_data = DEFAULT_MAX_DATA): - self._timeout = timeout - if terminator: - self._terminator = terminator - self._max_data = max_data - elif bytes: - self._terminator = bytes - self._max_data = 0 - else: - raise ValueError() +class CoroutineInstruction(object): + def __init__(self, *args, **kwargs): + self.__args = args + self.__kwargs = kwargs def execute(self, dispatcher): - dispatcher.set_timeout(self._timeout) - dispatcher.set_max_data(self._max_data) - dispatcher.set_terminator(self._terminator) - -class until_sent(object): - def __init__(self, content, timeout = DEFAULT_TIMEOUT): - if not content: - raise ValueError(content) - self._timeout = timeout - self.content = content + self.dispatcher = dispatcher + self.do_execute(*self.__args, **self.__kwargs) - def execute(self, dispatcher): - dispatcher.set_timeout(self._timeout) - dispatcher.push(self.content) +class until_received(CoroutineInstruction): + def do_execute(self, + terminator = None, + bytes = None, + timeout = DEFAULT_TIMEOUT, + max_data = DEFAULT_MAX_DATA): + self.dispatcher.set_timeout(timeout) + if terminator: + max_data = 0 + self.dispatcher.set_terminator(terminator) + elif bytes: + self.dispatcher.set_terminator(bytes) + else: + raise ValueError('Must specify terminator or bytes') + self.dispatcher.set_max_data(max_data) -class return_value(object): - def __init__(self, value): - self.value = value +class until_sent(CoroutineInstruction): + def do_execute(self, content, timeout = DEFAULT_TIMEOUT): + self.dispatcher.set_timeout(timeout) + self.dispatcher.push(content) - def execute(self, dispatcher): - dispatcher.close_coroutine_and_return_to_caller(self.value) +class return_value(CoroutineInstruction): + def do_execute(self, value): + self.dispatcher.trampoline.close_coroutine_and_return_to_caller(value)