view cosocket.py @ 99:b9cdccd5fbe4 default tip

Added unit test (doctest).
author Atul Varma <varmaa@toolness.com>
date Sat, 02 May 2009 08:13:45 -0700
parents 06aa973a54c3
children
line wrap: on
line source

'''
    Coroutine-driven sockets built on asyncore/asynchat.

    A coroutine (a generator) scripts a socket by yielding instruction
    objects such as until_listening(), until_connected(), until_sent()
    and until_received(). AsyncChatCoroutine executes each instruction
    and resumes the generator (with the instruction's result, e.g. the
    received data) when the corresponding socket event occurs. A
    coroutine may also yield another generator to run it as a
    subroutine. Call loop() to run the event loop until all channels
    and pending timeouts are gone.

    The example below starts a server and a client on localhost and has
    them exchange one line of text:

    >>> IP = '127.0.0.1'
    >>> PORT = 38424

    >>> def server_coroutine():
    ...     print 'server now listening.'
    ...     yield until_listening((IP, PORT))
    ...     print 'server now spawning client.'
    ...     AsyncChatCoroutine(client_coroutine())
    ...     print 'server now accepting connections.'
    ...     conn, addr = yield until_connection_accepted()
    ...     print 'server now spawning connection.'
    ...     AsyncChatCoroutine(connection_coroutine(addr), conn)

    >>> def client_coroutine():
    ...     print 'client now connecting to server.'
    ...     yield until_connected((IP, PORT))
    ...     print 'client now connected, sending text.'
    ...     yield until_sent('hai2u\\r\\n')
    ...     print 'client waiting for response.'
    ...     data = yield until_received(terminator = '\\r\\n')
    ...     print 'client received response: %s' % data

    >>> def connection_coroutine(addr):
    ...     print 'server connection waiting for request.'
    ...     data = yield until_received('\\r\\n')
    ...     print 'server connection sending back response: %s' % data
    ...     yield until_sent(data + '\\r\\n')

    >>> server = AsyncChatCoroutine(server_coroutine())
    server now listening.
    server now spawning client.
    client now connecting to server.
    server now accepting connections.

    >>> loop()
    server now spawning connection.
    server connection waiting for request.
    client now connected, sending text.
    client waiting for response.
    server connection sending back response: hai2u
    client received response: hai2u

'''

import sys
import socket
import asyncore
import asynchat
import types
import traceback
import time
import weakref
import logging

# Longest time (seconds) a single asyncore poll in loop() may block; loop()
# also uses it as the interval for ageing the entries of time_map.
DEFAULT_LOOP_TIMEOUT = 1.0
# Seconds an until_sent()/until_received() instruction may wait by default
# before its timeout handler runs.
DEFAULT_TIMEOUT = 90.0
# Default byte limit used by until_received() (see that class for when it
# applies).
DEFAULT_MAX_DATA = 64 * 1024

# Pending timeouts: zero-argument callback -> seconds left before loop()
# invokes it.
time_map = dict()

def loop(timeout = DEFAULT_LOOP_TIMEOUT):
    '''Run the event loop until no channels and no pending timeouts remain.

    Each iteration services the asyncore channels once, blocking for at
    most ``timeout`` seconds (or just sleeps that long when only timers
    are pending). Whenever more than ``timeout`` seconds have passed since
    the last check, every entry in ``time_map`` is reduced by the elapsed
    time; entries that reach zero are removed and their callbacks are
    called. Timer resolution is therefore roughly ``timeout`` seconds.
    An exception raised by a callback is logged and does not stop the
    loop.
    '''
    start_time = time.time()
    while asyncore.socket_map or time_map:
        if asyncore.socket_map:
            asyncore.loop(timeout = timeout, count = 1)
        else:
            # asyncore.loop() returns immediately when it has no channels,
            # so without this sleep pending timers would make this loop
            # spin and burn CPU until they expire.
            time.sleep(timeout)
        curr_time = time.time()
        time_elapsed = curr_time - start_time
        if time_elapsed > timeout:
            start_time = curr_time
            funcs_to_call = []
            # Only values change here, so iterating the dict is safe.
            for func in time_map:
                time_map[func] -= time_elapsed
                if time_map[func] <= 0:
                    funcs_to_call.append(func)
            # Remove every expired entry before calling any callback, so
            # callbacks may freely set or clear timeouts.
            for func in funcs_to_call:
                del time_map[func]
            for func in funcs_to_call:
                try:
                    func()
                except Exception:
                    # Keep the loop alive for the other channels, but let
                    # KeyboardInterrupt/SystemExit propagate.
                    logging.error(traceback.format_exc())

class _Trampoline(object):
    def __init__(self, coroutine, handler):
        self.__handler = handler
        self.__coroutine = coroutine
        self.__coroutine_stack = []

    def __log_error(self):
        logging.error(traceback.format_exc() +
                      self.get_formatted_coroutine_traceback())

    def __close_coroutine(self, coroutine):
        try:
            coroutine.close()
        except Exception:
            self.__log_error()

    def get_formatted_coroutine_traceback(self):
        if not self.__coroutine:
            return ""
        lines = []
        frames = [coroutine.gi_frame
                  for coroutine in self.__coroutine_stack]
        if self.__coroutine.gi_frame:
            frames.append(self.__coroutine.gi_frame)
        for frame in frames:
            name = frame.f_code.co_name
            filename = frame.f_code.co_filename
            lineno = frame.f_lineno
            lines.append('File "%s", line %d, in coroutine %s' %
                          (filename, lineno, name))
        if not lines:
            return 'No coroutine traceback available.'
        lines.insert(0, 'Coroutine traceback (most recent call last):')
        return '\n'.join(lines)

    def close_coroutine_stack(self):
        if self.__coroutine:
            # Pass an exception back into the coroutine to kick
            # it out of whatever yielding state it's in.
            self.__close_coroutine(self.__coroutine)
            self.__coroutine = None
            while self.__coroutine_stack:
                self.__close_coroutine(self.__coroutine_stack.pop())

    def close_coroutine_and_return_to_caller(self, message):
        self.__close_coroutine(self.__coroutine)
        if self.__coroutine_stack:
            self.__coroutine = self.__coroutine_stack.pop()
            self.continue_from_yield(message)
        else:
            self.__coroutine = None

    def continue_from_yield(self, message = None, exception_info = None):
        try:
            if exception_info:
                instruction = self.__coroutine.throw(*exception_info)
            else:
                instruction = self.__coroutine.send(message)
        except StopIteration:
            if self.__coroutine_stack:
                self.__coroutine = self.__coroutine_stack.pop()
                self.continue_from_yield()
            else:
                self.__coroutine = None
                self.__handler.handle_coroutine_complete(None)
        except Exception, e:
            if self.__coroutine_stack:
                self.__coroutine = self.__coroutine_stack.pop()
                self.continue_from_yield(exception_info = sys.exc_info())
            else:
                self.__log_error()
                self.__handler.handle_coroutine_complete(e)
        else:
            if type(instruction) == types.GeneratorType:
                self.__coroutine_stack.append(self.__coroutine)
                self.__coroutine = instruction
                self.continue_from_yield()
            else:
                self.__handler.handle_coroutine_instruction(instruction)

class AsyncChatCoroutine(asynchat.async_chat):
    '''An asynchat channel whose behavior is scripted by a coroutine.

    The coroutine (a generator) is started immediately. Each instruction
    it yields is executed against this channel, which then forwards the
    asyncore/asynchat events to that instruction until the instruction
    resumes the coroutine. Pass ``conn`` to wrap an existing socket, e.g.
    the one produced by until_connection_accepted().
    '''

    def __init__(self, coroutine, conn = None):
        # No instruction is active until the coroutine yields one.
        self.__instruction = None
        asynchat.async_chat.__init__(self, conn)
        self.trampoline = _Trampoline(coroutine, self)
        self.set_terminator(None)
        self.trampoline.continue_from_yield()

    def handle_coroutine_instruction(self, instruction):
        '''Make ``instruction`` current and start executing it.'''
        self.__instruction = instruction
        instruction.execute(self)

    def handle_coroutine_complete(self, exception):
        '''Called by the trampoline once the outermost coroutine is done.

        ``exception`` is None on normal completion, or the exception that
        escaped the coroutine (already logged by the trampoline).
        '''
        self.__instruction = None
        # A finished coroutine can never drive this channel again, so close
        # it after a failure too; otherwise its socket stays registered
        # with asyncore (a leak) and loop() never returns. A coroutine that
        # failed before any socket was created has nothing to close.
        if not exception or self.socket is not None:
            self.handle_close()

    def handle_close(self):
        self.trampoline.close_coroutine_stack()
        self.clear_timeout()
        self.close()

    def log_info(self, message, type='info'):
        # ``type`` shadows the builtin but is asyncore's signature for
        # this hook, so it must keep that name.
        try:
            level = getattr(logging, type.upper())
        except AttributeError:
            level = logging.INFO
        logging.log(level, message)

    def handle_error(self):
        self.log_info(traceback.format_exc() +
                      self.trampoline.get_formatted_coroutine_traceback(),
                      'error')

    def handle_accept(self):
        self.__instruction.handle_accept()

    def handle_connect(self):
        self.__instruction.handle_connect()

    def initiate_send(self):
        asynchat.async_chat.initiate_send(self)
        # asynchat's write path calls initiate_send() regardless of which
        # instruction is current (or after the coroutine has finished), so
        # only forward to instructions that implement the hook.
        hook = getattr(self.__instruction, 'handle_initiate_send', None)
        if hook is not None:
            hook()

    def __on_timeout(self):
        self.__instruction.handle_timeout()

    def clear_timeout(self):
        # Bound methods of the same object compare equal, so this finds
        # the entry added by set_timeout().
        time_map.pop(self.__on_timeout, None)

    def set_timeout(self, timeout):
        '''Have loop() call the current instruction's handle_timeout()
        after roughly ``timeout`` seconds unless cleared first.'''
        time_map[self.__on_timeout] = timeout

    def collect_incoming_data(self, data):
        self.__instruction.collect_incoming_data(data)

    def found_terminator(self):
        self.__instruction.found_terminator()

# Instructions that coroutines yield.

class CoroutineInstruction(object):
    '''Base class for the objects a coroutine yields to its channel.

    The constructor only records its arguments; they are passed on to the
    subclass's do_execute() once the channel calls execute(). Subclasses
    then receive channel events through handle_* hooks and resume the
    coroutine via ``self.dispatcher.trampoline.continue_from_yield()``.
    '''

    def __init__(self, *args, **kwargs):
        self.__args = args
        self.__kwargs = kwargs

    def execute(self, dispatcher):
        '''Bind this instruction to ``dispatcher`` and start it.'''
        self.dispatcher = dispatcher
        self.do_execute(*self.__args, **self.__kwargs)

    def handle_timeout(self):
        '''Default timeout behavior: close the channel.'''
        self.dispatcher.handle_close()

    def handle_initiate_send(self):
        '''Hook run after every send attempt on the channel; no-op here.

        AsyncChatCoroutine.initiate_send() calls this on whichever
        instruction is current, not only on send instructions, so
        instructions that do not care about output must tolerate it.
        '''

class until_listening(CoroutineInstruction):
    '''Create a TCP socket bound to ``bind_addr`` and start listening.

    ``backlog`` is passed to listen() (default 1, as before). The
    coroutine is resumed (with None) immediately once the socket is
    listening.
    '''

    def do_execute(self, bind_addr, backlog = 1):
        self.dispatcher.create_socket(socket.AF_INET,
                                      socket.SOCK_STREAM)
        self.dispatcher.set_reuse_addr()
        self.dispatcher.bind(bind_addr)
        self.dispatcher.listen(backlog)
        self.dispatcher.trampoline.continue_from_yield()

class until_connection_accepted(CoroutineInstruction):
    '''Wait for and accept an incoming connection on a listening channel.

    The coroutine is resumed with the ``(conn, addr)`` pair returned by
    the dispatcher's accept().
    '''

    def do_execute(self):
        # Nothing to start: wait for asyncore to call handle_accept().
        pass

    def handle_accept(self):
        data = self.dispatcher.accept()
        if data is None:
            # asyncore's accept() returns None when no connection was
            # actually ready (e.g. EWOULDBLOCK); keep waiting instead of
            # resuming the coroutine with a value it cannot unpack.
            return
        self.dispatcher.trampoline.continue_from_yield(data)

class until_connected(CoroutineInstruction):
    '''Open a TCP connection to ``addr``.

    The coroutine is resumed (with None) once the channel's
    handle_connect() fires.
    '''

    def do_execute(self, addr):
        channel = self.dispatcher
        channel.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        channel.connect(addr)

    def handle_connect(self):
        trampoline = self.dispatcher.trampoline
        trampoline.continue_from_yield()

class until_received(CoroutineInstruction):
    '''Wait until data has been received on the channel.

    Give either ``terminator`` (a string marking the end of the data) or
    ``bytes`` (a byte count); it is passed to the channel's
    set_terminator(). The coroutine is resumed with the received data as
    one string. ``timeout`` is the number of seconds allowed before the
    instruction's timeout handler runs.

    ``max_data`` caps how many bytes may be buffered; exceeding it logs an
    error and closes the channel. A false value (0 or None) disables the
    cap. When omitted it defaults to DEFAULT_MAX_DATA for ``bytes`` reads
    and to no cap for ``terminator`` reads; an explicitly passed value is
    honored in both modes.

    Raises ValueError if neither ``terminator`` nor ``bytes`` is given.
    '''

    # Sentinel telling "max_data not passed" apart from any caller value.
    _UNSET = object()

    def do_execute(self,
                   terminator = None,
                   bytes = None,
                   timeout = DEFAULT_TIMEOUT,
                   max_data = _UNSET):
        self.dispatcher.set_timeout(timeout)
        if terminator:
            if max_data is self._UNSET:
                # Keep the historical default: terminator reads uncapped.
                max_data = 0
            self.dispatcher.set_terminator(terminator)
        elif bytes:
            if max_data is self._UNSET:
                max_data = DEFAULT_MAX_DATA
            self.dispatcher.set_terminator(bytes)
        else:
            raise ValueError('Must specify terminator or bytes')
        self.__max_data = max_data
        self.__data = []
        self.__data_len = 0

    def collect_incoming_data(self, data):
        self.__data.append(data)
        self.__data_len += len(data)
        if self.__max_data and self.__data_len > self.__max_data:
            logging.error("Max data reached (%s bytes)." % self.__max_data)
            self.dispatcher.handle_close()

    def found_terminator(self):
        # If the cap was exceeded the channel has already been closed, so
        # the coroutine must not be resumed.
        if not (self.__max_data and self.__data_len > self.__max_data):
            self.dispatcher.set_terminator(None)
            data = ''.join(self.__data)
            self.__data = []
            self.__data_len = 0
            self.dispatcher.clear_timeout()
            self.dispatcher.trampoline.continue_from_yield(data)

class until_sent(CoroutineInstruction):
    '''Queue ``content`` for sending and wait until it has all been sent.

    The coroutine is resumed (with None) once the channel's output queue
    is empty and the channel is connected. ``timeout`` is the number of
    seconds allowed before the instruction's timeout handler runs.
    '''

    def do_execute(self, content, timeout = DEFAULT_TIMEOUT):
        self.dispatcher.set_timeout(timeout)
        self.dispatcher.push(content)

    def handle_initiate_send(self):
        # Newer asynchat versions (Python 2.6+) keep pending output only in
        # producer_fifo and no longer define ac_out_buffer; treat a missing
        # attribute as an empty buffer so this works with either version.
        out_buffer = getattr(self.dispatcher, 'ac_out_buffer', '')
        if ((not out_buffer) and
            (len(self.dispatcher.producer_fifo) == 0) and
            self.dispatcher.connected):
            self.dispatcher.clear_timeout()
            self.dispatcher.trampoline.continue_from_yield()

class return_value(CoroutineInstruction):
    '''Finish the current coroutine, resuming its caller with ``value``.'''

    def do_execute(self, value):
        trampoline = self.dispatcher.trampoline
        trampoline.close_coroutine_and_return_to_caller(value)

if __name__ == '__main__':
    # Running the module directly executes the example session in the
    # module docstring as a self-test.
    from doctest import testmod
    testmod(verbose = True)