changeset 95:57681224a62a

Attempted to factor out the trampoline into a separate object, and modified the way coroutine instructions are done.
author Atul Varma <varmaa@toolness.com>
date Fri, 01 May 2009 23:41:52 -0700
parents abb3952c2209
children 68598f164855
files channels.py cosocket.py
diffstat 2 files changed, 104 insertions(+), 95 deletions(-) [+]
line wrap: on
line diff
--- a/channels.py	Fri May 01 21:33:57 2009 -0700
+++ b/channels.py	Fri May 01 23:41:52 2009 -0700
@@ -2,31 +2,25 @@
 
 _channels = {}
 
-class until_message_sent(object):
-    def __init__(self, channel_name, message):
-        self.channel_name = channel_name
-        self.message = message
-
-    def execute(self, dispatcher):
-        if self.channel_name in _channels:
-            receivers = _channels[self.channel_name].values()
-            del _channels[self.channel_name]
+class until_message_sent(cosocket.CoroutineInstruction):
+    def do_execute(self, channel_name, message):
+        if channel_name in _channels:
+            receivers = _channels[channel_name].values()
+            del _channels[channel_name]
             for receiver in receivers:
-                receiver.continue_from_yield(self.message)
-        dispatcher.continue_from_yield()
+                receiver.trampoline.continue_from_yield(message)
+        self.dispatcher.trampoline.continue_from_yield()
 
-class _until_message_received(object):
-    def __init__(self, channel_name, timeout = cosocket.DEFAULT_TIMEOUT):
+class _until_message_received(cosocket.CoroutineInstruction):
+    def do_execute(self, channel_name,
+                   timeout = cosocket.DEFAULT_TIMEOUT):
+        if channel_name not in _channels:
+            _channels[channel_name] = {}
+        fd = self.dispatcher.socket.fileno()
+        _channels[channel_name][fd] = self.dispatcher
+        self.dispatcher.set_timeout(timeout)
         self.channel_name = channel_name
-        self._timeout = timeout
-        self._fd = None
-
-    def execute(self, dispatcher):
-        if self.channel_name not in _channels:
-            _channels[self.channel_name] = {}
-        self._fd = dispatcher.socket.fileno()
-        _channels[self.channel_name][self._fd] = dispatcher
-        dispatcher.set_timeout(self._timeout)
+        self._fd = fd
 
     def finalize(self):
         if (self.channel_name in _channels and
--- a/cosocket.py	Fri May 01 21:33:57 2009 -0700
+++ b/cosocket.py	Fri May 01 23:41:52 2009 -0700
@@ -35,18 +35,45 @@
                 except:
                     logging.error(traceback.format_exc())
 
-class _AsyncChatCoroutineDispatcher(asynchat.async_chat):
-    def __init__(self, coroutine, conn = None):
-        asynchat.async_chat.__init__(self, conn)
-        self.set_terminator(None)
-        self.__max_data = DEFAULT_MAX_DATA
+class _Trampoline(object):
+    def __init__(self, coroutine, handler):
+        self.__handler = handler
         self.__coroutine = coroutine
-        self.__timeout = 0
-        self.__data = []
-        self.__data_len = 0
         self.__coroutine_stack = []
-        if conn:
-            self.continue_from_yield()
+
+    def __log_error(self):
+        logging.error(traceback.format_exc() +
+                      self.get_formatted_coroutine_traceback())
+
+    def __close_coroutine(self, coroutine):
+        try:
+            coroutine.close()
+        except Exception:
+            self.__log_error()
+
+    def get_formatted_coroutine_traceback(self):
+        if not self.__coroutine:
+            return ""
+        lines = ['Coroutine traceback (most recent call last):']
+        frames = [coroutine.gi_frame
+                  for coroutine in self.__coroutine_stack]
+        frames.append(self.__coroutine.gi_frame)
+        for frame in frames:
+            name = frame.f_code.co_name
+            filename = frame.f_code.co_filename
+            lineno = frame.f_lineno
+            lines.append('File "%s", line %d, in coroutine %s' %
+                          (filename, lineno, name))
+        return '\n'.join(lines)
+
+    def close_coroutine_stack(self):
+        if self.__coroutine:
+            # Pass an exception back into the coroutine to kick
+            # it out of whatever yielding state it's in.
+            self.__close_coroutine(self.__coroutine)
+            self.__coroutine = None
+            while self.__coroutine_stack:
+                self.__close_coroutine(self.__coroutine_stack.pop())
 
     def close_coroutine_and_return_to_caller(self, message):
         self.__close_coroutine(self.__coroutine)
@@ -56,9 +83,7 @@
         else:
             self.__coroutine = None
 
-    def continue_from_yield(self, message = None,
-                            exception_info = None):
-        self.clear_timeout()
+    def continue_from_yield(self, message = None, exception_info = None):
         try:
             if exception_info:
                 instruction = self.__coroutine.throw(*exception_info)
@@ -70,46 +95,38 @@
                 self.continue_from_yield()
             else:
                 self.__coroutine = None
-                self.handle_close()
+                self.__handler.handle_coroutine_complete()
         except Exception, e:
             if self.__coroutine_stack:
                 self.__coroutine = self.__coroutine_stack.pop()
                 self.continue_from_yield(exception_info = sys.exc_info())
             else:
-                self.handle_error()
+                self.__log_error()
         else:
             if type(instruction) == types.GeneratorType:
                 self.__coroutine_stack.append(self.__coroutine)
                 self.__coroutine = instruction
                 self.continue_from_yield()
             else:
-                instruction.execute(self)
-
-    def __close_coroutine(self, coroutine):
-        try:
-            coroutine.close()
-        except Exception:
-            self.log_info(traceback.format_exc(), 'error')
+                instruction.execute(self.__handler)
 
-    def get_formatted_coroutine_traceback(self):
-        lines = ['Coroutine traceback (most recent call last):']
-        for frame in [coroutine.gi_frame
-                      for coroutine in self.__coroutine_stack]:
-            name = frame.f_code.co_name
-            filename = frame.f_code.co_filename
-            lineno = frame.f_lineno
-            lines.append('File "%s", line %d, in coroutine %s' %
-                          (filename, lineno, name))
-        return '\n'.join(lines)
+class _AsyncChatCoroutineDispatcher(asynchat.async_chat):
+    def __init__(self, coroutine, conn = None):
+        asynchat.async_chat.__init__(self, conn)
+        self.trampoline = _Trampoline(coroutine, self)
+        self.set_terminator(None)
+        self.__max_data = DEFAULT_MAX_DATA
+        self.__timeout = 0
+        self.__data = []
+        self.__data_len = 0
+        if conn:
+            self.trampoline.continue_from_yield()
+
+    def handle_coroutine_complete(self):
+        self.handle_close()
 
     def handle_close(self):
-        if self.__coroutine:
-            # Pass an exception back into the coroutine to kick
-            # it out of whatever yielding state it's in.
-            self.__close_coroutine(self.__coroutine)
-            self.__coroutine = None
-            while self.__coroutine_stack:
-                self.__close_coroutine(self.__coroutine_stack.pop())
+        self.trampoline.close_coroutine_stack()
         self.clear_timeout()
         self.close()
 
@@ -122,18 +139,18 @@
 
     def handle_error(self):
         self.log_info(traceback.format_exc() +
-                      self.get_formatted_coroutine_traceback(),
+                      self.trampoline.get_formatted_coroutine_traceback(),
                       'error')
 
     def handle_connect(self):
-        self.continue_from_yield()
+        self.trampoline.continue_from_yield()
 
     def initiate_send(self):
         asynchat.async_chat.initiate_send(self)
         if ((not self.ac_out_buffer) and
             (len(self.producer_fifo) == 0) and
             self.connected):
-            self.continue_from_yield()
+            self.trampoline.continue_from_yield()
 
     def __on_timeout(self):
         self.log_info("Timeout expired (%ss)." % self.__timeout,
@@ -166,7 +183,7 @@
             data = ''.join(self.__data)
             self.__data = []
             self.__data_len = 0
-            self.continue_from_yield(data)
+            self.trampoline.continue_from_yield(data)
 
 class CoroutineSocketServer(asyncore.dispatcher):
     def __init__(self, addr, coroutineFactory):
@@ -194,38 +211,36 @@
 
 # Instructions that coroutines yield.
 
-class until_received(object):
-    def __init__(self, terminator = None, bytes = None,
-                 timeout = DEFAULT_TIMEOUT, max_data = DEFAULT_MAX_DATA):
-        self._timeout = timeout
-        if terminator:
-            self._terminator = terminator
-            self._max_data = max_data
-        elif bytes:
-            self._terminator = bytes
-            self._max_data = 0
-        else:
-            raise ValueError()
+class CoroutineInstruction(object):
+    def __init__(self, *args, **kwargs):
+        self.__args = args
+        self.__kwargs = kwargs
 
     def execute(self, dispatcher):
-        dispatcher.set_timeout(self._timeout)
-        dispatcher.set_max_data(self._max_data)
-        dispatcher.set_terminator(self._terminator)
-
-class until_sent(object):
-    def __init__(self, content, timeout = DEFAULT_TIMEOUT):
-        if not content:
-            raise ValueError(content)
-        self._timeout = timeout
-        self.content = content
+        self.dispatcher = dispatcher
+        self.do_execute(*self.__args, **self.__kwargs)
 
-    def execute(self, dispatcher):
-        dispatcher.set_timeout(self._timeout)
-        dispatcher.push(self.content)
+class until_received(CoroutineInstruction):
+    def do_execute(self,
+                   terminator = None,
+                   bytes = None,
+                   timeout = DEFAULT_TIMEOUT,
+                   max_data = DEFAULT_MAX_DATA):
+        self.dispatcher.set_timeout(timeout)
+        if terminator:
+            max_data = 0
+            self.dispatcher.set_terminator(terminator)
+        elif bytes:
+            self.dispatcher.set_terminator(bytes)
+        else:
+            raise ValueError('Must specify terminator or bytes')
+        self.dispatcher.set_max_data(max_data)
 
-class return_value(object):
-    def __init__(self, value):
-        self.value = value
+class until_sent(CoroutineInstruction):
+    def do_execute(self, content, timeout = DEFAULT_TIMEOUT):
+        self.dispatcher.set_timeout(timeout)
+        self.dispatcher.push(content)
 
-    def execute(self, dispatcher):
-        dispatcher.close_coroutine_and_return_to_caller(self.value)
+class return_value(CoroutineInstruction):
+    def do_execute(self, value):
+        self.dispatcher.trampoline.close_coroutine_and_return_to_caller(value)