view cosocket.py @ 86:408738c3cd5d

Added a handle_tick function to coroutine dispatchers.
author Atul Varma <varmaa@toolness.com>
date Fri, 01 May 2009 15:37:57 -0700
parents 7e3b3eb57ec2
children 43d37495e9d4
line wrap: on
line source

import sys
import socket
import asyncore
import asynchat
import types
import traceback
import time
import weakref

# Weak references to the _AsyncChatCoroutineDispatcher instances created so
# far (each one appends itself in __init__).  loop() walks this list to
# deliver periodic handle_tick() calls and prunes references whose
# dispatcher has been garbage-collected; being weak, the references never
# keep a dispatcher alive on their own.
coroutine_dispatchers = []

def loop(timeout):
    """Run the asyncore event loop forever, ticking coroutine dispatchers.

    Each iteration waits up to *timeout* seconds for socket activity.  Once
    more than *timeout* seconds have passed since the previous tick round,
    handle_tick(time_elapsed) is called on every still-alive dispatcher
    registered in coroutine_dispatchers, where time_elapsed is the number
    of seconds since that previous round.  An Exception raised by a tick
    handler is reported through that dispatcher's handle_error() and does
    not stop the loop; KeyboardInterrupt and SystemExit do propagate.
    """
    start_time = time.time()
    while 1:
        if asyncore.socket_map:
            asyncore.loop(timeout = timeout, count = 1)
        else:
            # asyncore.loop() returns immediately when no sockets are
            # registered; sleep explicitly so ticks keep their cadence
            # instead of busy-spinning the CPU.
            time.sleep(timeout)
        curr_time = time.time()
        time_elapsed = curr_time - start_time
        if time_elapsed > timeout:
            start_time = curr_time
            live_refs = []
            for ref in coroutine_dispatchers:
                dispatcher = ref()
                if dispatcher is None:
                    # The dispatcher was garbage-collected; drop its ref.
                    continue
                live_refs.append(ref)
                try:
                    dispatcher.handle_tick(time_elapsed)
                except Exception:
                    # Deliberately not a bare except: Ctrl-C / sys.exit()
                    # inside a tick handler must still stop the loop.
                    dispatcher.handle_error()
            coroutine_dispatchers[:] = live_refs

class _AsyncChatCoroutineDispatcher(asynchat.async_chat):
    """An asynchat channel whose protocol logic is driven by a coroutine.

    The coroutine is a generator that yields "instructions" -- objects with
    an execute(dispatcher) method such as until_received, until_sent and
    return_value -- telling the dispatcher what to wait for next.  It may
    also yield another generator to call it as a sub-coroutine: the caller
    is suspended on a stack until the callee finishes (the caller resumes
    with None), returns a value via return_value (the caller resumes with
    that value), or raises (the exception is thrown into the caller).
    When the outermost coroutine finishes, returns, or dies with an
    unhandled exception, the connection is closed.
    """

    def __init__(self, coroutine, conn = None):
        # Register weakly so loop() can deliver handle_tick() calls without
        # keeping this dispatcher alive.
        coroutine_dispatchers.append(weakref.ref(self))
        asynchat.async_chat.__init__(self, conn)
        # No terminator: collect all incoming data until the coroutine
        # asks for something specific via until_received.
        self.set_terminator(None)
        self.__coroutine = coroutine
        self.__data = []
        self.__coroutine_stack = []
        if conn:
            # An already-connected socket (e.g. accepted by a server):
            # start the coroutine now.  Client sockets start it from
            # handle_connect() instead.
            self.continue_from_yield()

    def close_coroutine_and_return_to_caller(self, message):
        """Finish the current coroutine and resume its caller with *message*.

        If the current coroutine is the outermost one there is no caller
        to return to, so the connection is closed -- exactly what happens
        when the outermost coroutine simply runs off its end.
        """
        self.__close_coroutine(self.__coroutine)
        if not self.__resume_caller(message):
            self.__coroutine = None
            self.handle_close()

    def continue_from_yield(self, message = None,
                            exception_info = None):
        """Resume the current coroutine and act on what it yields next.

        message -- value sent into the coroutine as the result of its yield.
        exception_info -- optional (type, value, traceback) triple; when
            given it is thrown into the coroutine instead of sending
            *message*.
        """
        if self.__coroutine is None:
            # The coroutine already finished; a late asyncore event (e.g.
            # the send buffer draining) has nothing left to resume.
            return
        try:
            if exception_info:
                instruction = self.__coroutine.throw(*exception_info)
            else:
                instruction = self.__coroutine.send(message)
        except StopIteration:
            # The coroutine ran off its end: resume its caller, or close
            # the connection if it was the outermost coroutine.
            if not self.__resume_caller():
                self.__coroutine = None
                self.handle_close()
        except Exception:
            # Propagate into the calling coroutine, as a normal function
            # call would.  At the outermost level nobody can handle it:
            # log it, then close the connection nothing drives any more.
            if not self.__resume_caller(exception_info = sys.exc_info()):
                self.handle_error()
                self.handle_close()
        else:
            if isinstance(instruction, types.GeneratorType):
                # Sub-coroutine call: suspend the caller, start the callee.
                self.__coroutine_stack.append(self.__coroutine)
                self.__coroutine = instruction
                self.continue_from_yield()
            else:
                instruction.execute(self)

    def __resume_caller(self, message = None, exception_info = None):
        """Pop the suspended caller coroutine and resume it.

        Returns False without doing anything when there is no caller (the
        current coroutine is the outermost one); returns True otherwise.
        """
        if not self.__coroutine_stack:
            return False
        self.__coroutine = self.__coroutine_stack.pop()
        self.continue_from_yield(message, exception_info)
        return True

    def __close_coroutine(self, coroutine):
        # Best effort: a coroutine that misbehaves while closing is logged
        # rather than allowed to abort the shutdown of the others.
        try:
            coroutine.close()
        except Exception:
            self.log_info(traceback.format_exc(), 'error')

    def get_coroutine_stack_frames(self):
        """Return the frames of the suspended caller coroutines.

        Outermost caller first; the currently running coroutine is not
        included.
        """
        return [coroutine.gi_frame
                for coroutine in self.__coroutine_stack]

    def get_formatted_coroutine_traceback(self):
        """Format the suspended caller coroutines like a traceback."""
        lines = ['Coroutine traceback (most recent call last):']
        for frame in self.get_coroutine_stack_frames():
            name = frame.f_code.co_name
            filename = frame.f_code.co_filename
            lineno = frame.f_lineno
            lines.append('File "%s", line %d, in coroutine %s' %
                          (filename, lineno, name))
        return '\n'.join(lines)

    def handle_tick(self, time_elapsed):
        """Hook called periodically by loop(); does nothing by default.

        time_elapsed is the number of seconds since the previous tick
        round.  Exceptions raised here are reported via handle_error().
        """
        pass

    def handle_close(self):
        """Close the coroutine, all suspended callers, and the socket."""
        if self.__coroutine is not None:
            # close() raises GeneratorExit inside each coroutine to kick it
            # out of whatever yielding state it is in (running its
            # finally/except handlers).
            self.__close_coroutine(self.__coroutine)
            self.__coroutine = None
            while self.__coroutine_stack:
                self.__close_coroutine(self.__coroutine_stack.pop())
        self.close()

    def log_info(self, message, type='info'):
        # TODO: Use the logging module here.
        print('%s: %s' % (type, message))

    def handle_error(self):
        """Log the current exception plus the coroutine call stack.

        Must be called while an exception is being handled, since it uses
        traceback.format_exc().  Unlike asyncore's default, this does not
        close the connection itself.
        """
        self.log_info(traceback.format_exc() +
                      self.get_formatted_coroutine_traceback(),
                      'error')

    def handle_connect(self):
        # Outgoing connection established: start the coroutine.
        self.continue_from_yield()

    def initiate_send(self):
        """Send what asynchat can, then resume the coroutine once drained.

        This is what completes an until_sent instruction: the coroutine is
        resumed only after the whole outgoing queue has been flushed.
        NOTE(review): relies on asynchat internals (ac_out_buffer and a
        len()-able producer_fifo) that differ between Python versions --
        confirm against the interpreter in use.
        """
        asynchat.async_chat.initiate_send(self)
        if ((not self.ac_out_buffer) and
            (len(self.producer_fifo) == 0) and
            self.connected):
            self.continue_from_yield()

    def collect_incoming_data(self, data):
        # Buffer received chunks until found_terminator() joins them.
        # TODO: Enforce some maximum data length.
        self.__data.append(data)

    def found_terminator(self):
        # Return to collect-all mode and hand the buffered data to the
        # coroutine as the value of its until_received yield.
        self.set_terminator(None)
        data = ''.join(self.__data)
        self.__data = []
        self.continue_from_yield(data)

class CoroutineSocketServer(asyncore.dispatcher):
    """TCP server that runs a fresh coroutine for each accepted connection.

    coroutineFactory(addr) is called with the peer address of every new
    connection and must return the generator that will drive it.
    """

    def __init__(self, addr, coroutineFactory, backlog = 1):
        asyncore.dispatcher.__init__(self)
        self.__coroutineFactory = coroutineFactory
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind(addr)
        # backlog defaults to 1 to match earlier behaviour; raise it for
        # servers that must absorb bursts of simultaneous connections.
        self.listen(backlog)

    def run(self, timeout):
        """Serve forever using the module-level loop(timeout)."""
        loop(timeout)

    def handle_accept(self):
        pair = self.accept()
        if pair is None:
            # asyncore's accept() returns None when the pending connection
            # went away before we accepted it; nothing to do.
            return
        conn, addr = pair
        try:
            coroutine = self.__coroutineFactory(addr)
        except Exception:
            # A failing factory should cost one connection, not the whole
            # server: letting it escape would reach asyncore's default
            # handle_error(), which closes the listening socket.
            self.log_info(traceback.format_exc(), 'error')
            conn.close()
            return
        _AsyncChatCoroutineDispatcher(coroutine, conn)

class CoroutineSocketClient(_AsyncChatCoroutineDispatcher):
    """Outgoing TCP connection driven by the coroutine coroutineFactory(addr).

    The coroutine is created immediately but only starts running once the
    connection is established (via handle_connect).
    """

    def __init__(self, addr, coroutineFactory):
        coroutine = coroutineFactory(addr)
        _AsyncChatCoroutineDispatcher.__init__(self, coroutine)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.connect(addr)
        except Exception:
            # connect() can fail synchronously (e.g. connection refused).
            # create_socket() already registered the socket with asyncore,
            # so close it -- and the never-started coroutine -- before
            # letting the caller see the error.
            self.handle_close()
            raise

# Instructions that coroutines yield.

class until_received(object):
    """Instruction: wait for a terminator string or a number of bytes.

    Give either *terminator* (a non-empty delimiter string) or *bytes* (a
    positive byte count); if both are given, *terminator* takes precedence.
    The data collected is sent back into the coroutine as the value of its
    yield.

    Raises ValueError if neither a usable terminator nor a positive byte
    count is supplied.
    """

    def __init__(self, terminator = None, bytes = None):
        if terminator:
            self._terminator = terminator
        elif bytes is not None and bytes > 0:
            self._terminator = bytes
        else:
            raise ValueError('until_received() needs a non-empty terminator '
                             'or a positive bytes count')

    def execute(self, dispatcher):
        # asynchat treats a string terminator as a delimiter and an integer
        # one as a byte count.
        dispatcher.set_terminator(self._terminator)

class until_sent(object):
    """Instruction: push *content* onto the dispatcher's outgoing queue.

    Empty content is rejected with ValueError.
    """

    def __init__(self, content):
        if content:
            self.content = content
        else:
            raise ValueError(content)

    def execute(self, dispatcher):
        payload = self.content
        dispatcher.push(payload)

class return_value(object):
    """Instruction: finish the current coroutine, handing *value* back to
    whichever coroutine called it."""

    def __init__(self, value):
        self.value = value

    def execute(self, dispatcher):
        # The dispatcher closes the current coroutine and resumes its
        # caller with this value.
        result = self.value
        dispatcher.close_coroutine_and_return_to_caller(result)