view cosocket.py @ 97:0d3dd2ab36cd

Factored out all instruction-specific logic from the dispatcher into individual instructions.
author Atul Varma <varmaa@toolness.com>
date Sat, 02 May 2009 00:14:59 -0700
parents 68598f164855
children 06aa973a54c3
line wrap: on
line source

import sys
import socket
import asyncore
import asynchat
import types
import traceback
import time
import weakref
import logging

# Seconds each poll of the sockets may block in loop(); also the
# granularity at which loop() services time_map.
DEFAULT_LOOP_TIMEOUT = 1.0
# Default timeout, in seconds, for until_received / until_sent.
DEFAULT_TIMEOUT = 90.0
# Default cap, in bytes, on data buffered by until_received.
DEFAULT_MAX_DATA = 65536

# Maps a zero-argument callback to the number of seconds left before
# loop() fires it. An entry is removed just before its callback runs.
time_map = {}

def loop(timeout = DEFAULT_LOOP_TIMEOUT):
    """Run the asyncore event loop forever, servicing timers in time_map.

    Each pass polls the sockets for at most ``timeout`` seconds. Whenever
    more than ``timeout`` seconds have passed since the last timer sweep,
    the elapsed time is subtracted from every entry in time_map; entries
    that reach zero are removed and their callbacks invoked.  Ordinary
    exceptions raised by a callback are logged and the loop continues;
    KeyboardInterrupt and SystemExit propagate to the caller.

    This function never returns normally.
    """
    start_time = time.time()
    while 1:
        if asyncore.socket_map:
            asyncore.loop(timeout = timeout, count = 1)
        else:
            # asyncore.loop() returns immediately when there are no
            # channels; without this sleep the loop would spin the CPU.
            time.sleep(timeout)
        curr_time = time.time()
        time_elapsed = curr_time - start_time
        if time_elapsed > timeout:
            start_time = curr_time
            funcs_to_call = []
            for func in time_map:
                time_map[func] -= time_elapsed
                if time_map[func] <= 0:
                    funcs_to_call.append(func)
            # Remove expired entries before calling anything, so callbacks
            # may freely re-arm or clear timers.
            for func in funcs_to_call:
                del time_map[func]
            for func in funcs_to_call:
                try:
                    func()
                except Exception:
                    # Keep serving other channels, but do not swallow
                    # KeyboardInterrupt / SystemExit.
                    logging.error(traceback.format_exc())

class _Trampoline(object):
    def __init__(self, coroutine, handler):
        self.__handler = handler
        self.__coroutine = coroutine
        self.__coroutine_stack = []

    def __log_error(self):
        logging.error(traceback.format_exc() +
                      self.get_formatted_coroutine_traceback())

    def __close_coroutine(self, coroutine):
        try:
            coroutine.close()
        except Exception:
            self.__log_error()

    def get_formatted_coroutine_traceback(self):
        if not self.__coroutine:
            return ""
        lines = ['Coroutine traceback (most recent call last):']
        frames = [coroutine.gi_frame
                  for coroutine in self.__coroutine_stack]
        frames.append(self.__coroutine.gi_frame)
        for frame in frames:
            name = frame.f_code.co_name
            filename = frame.f_code.co_filename
            lineno = frame.f_lineno
            lines.append('File "%s", line %d, in coroutine %s' %
                          (filename, lineno, name))
        return '\n'.join(lines)

    def close_coroutine_stack(self):
        if self.__coroutine:
            # Pass an exception back into the coroutine to kick
            # it out of whatever yielding state it's in.
            self.__close_coroutine(self.__coroutine)
            self.__coroutine = None
            while self.__coroutine_stack:
                self.__close_coroutine(self.__coroutine_stack.pop())

    def close_coroutine_and_return_to_caller(self, message):
        self.__close_coroutine(self.__coroutine)
        if self.__coroutine_stack:
            self.__coroutine = self.__coroutine_stack.pop()
            self.continue_from_yield(message)
        else:
            self.__coroutine = None

    def continue_from_yield(self, message = None, exception_info = None):
        try:
            if exception_info:
                instruction = self.__coroutine.throw(*exception_info)
            else:
                instruction = self.__coroutine.send(message)
        except StopIteration:
            if self.__coroutine_stack:
                self.__coroutine = self.__coroutine_stack.pop()
                self.continue_from_yield()
            else:
                self.__coroutine = None
                self.__handler.handle_coroutine_complete(None)
        except Exception, e:
            if self.__coroutine_stack:
                self.__coroutine = self.__coroutine_stack.pop()
                self.continue_from_yield(exception_info = sys.exc_info())
            else:
                self.__log_error()
                self.__handler.handle_coroutine_complete(e)
        else:
            if type(instruction) == types.GeneratorType:
                self.__coroutine_stack.append(self.__coroutine)
                self.__coroutine = instruction
                self.continue_from_yield()
            else:
                self.__handler.handle_coroutine_instruction(instruction)

class _AsyncChatCoroutineDispatcher(asynchat.async_chat):
    def __init__(self, coroutine, conn = None):
        asynchat.async_chat.__init__(self, conn)
        self.trampoline = _Trampoline(coroutine, self)
        self.set_terminator(None)
        if conn:
            self.trampoline.continue_from_yield()

    def handle_coroutine_instruction(self, instruction):
        self.__instruction = instruction
        instruction.execute(self)

    def handle_coroutine_complete(self, exception):
        self.__instruction = None
        if not exception:
            self.handle_close()

    def handle_close(self):
        self.trampoline.close_coroutine_stack()
        self.clear_timeout()
        self.close()

    def log_info(self, message, type='info'):
        try:
            level = getattr(logging, type.upper())
        except AttributeError:
            level = logging.INFO
        logging.log(level, message)

    def handle_error(self):
        self.log_info(traceback.format_exc() +
                      self.trampoline.get_formatted_coroutine_traceback(),
                      'error')

    def handle_connect(self):
        self.__instruction.handle_connect()

    def initiate_send(self):
        asynchat.async_chat.initiate_send(self)
        self.__instruction.handle_initiate_send()

    def __on_timeout(self):
        self.__instruction.handle_timeout()

    def clear_timeout(self):
        if self.__on_timeout in time_map:
            del time_map[self.__on_timeout]

    def set_timeout(self, timeout):
        time_map[self.__on_timeout] = timeout

    def collect_incoming_data(self, data):
        self.__instruction.collect_incoming_data(data)

    def found_terminator(self):
        self.__instruction.found_terminator()

class CoroutineSocketServer(asyncore.dispatcher):
    """TCP listening socket that runs a new coroutine per connection.

    ``coroutineFactory(addr)`` is called with the remote address of each
    accepted connection and must return the generator that drives it.
    """

    def __init__(self, addr, coroutineFactory):
        asyncore.dispatcher.__init__(self)
        self.__coroutineFactory = coroutineFactory
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind(addr)
        self.listen(1)

    def run(self, timeout = DEFAULT_LOOP_TIMEOUT):
        """Run the module's event loop (loop() does not return normally)."""
        loop(timeout)

    def handle_accept(self):
        pair = self.accept()
        if pair is None:
            # asyncore's accept() may return None (e.g. EWOULDBLOCK, or a
            # connection aborted before we got to it): nothing to serve.
            return
        conn, addr = pair
        coroutine = self.__coroutineFactory(addr)
        _AsyncChatCoroutineDispatcher(coroutine, conn)

class CoroutineSocketClient(_AsyncChatCoroutineDispatcher):
    """Outgoing TCP connection scripted by a coroutine.

    ``coroutineFactory(addr)`` is called with the address being connected
    to, and the generator it returns drives the channel.
    """

    def __init__(self, addr, coroutineFactory):
        # The factory runs before the base initializer, as it always has.
        _AsyncChatCoroutineDispatcher.__init__(self, coroutineFactory(addr))
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.connect(addr)

# Instructions that coroutines yield.

class CoroutineInstruction(object):
    """Base class for the objects a coroutine yields to request I/O.

    Constructor arguments are held until execute() binds the instruction
    to a dispatcher, then passed on to the subclass's do_execute().
    """

    def __init__(self, *args, **kwargs):
        self.__deferred_call = (args, kwargs)

    def execute(self, dispatcher):
        """Bind to ``dispatcher`` and run do_execute() with the saved args."""
        self.dispatcher = dispatcher
        positional, keyword = self.__deferred_call
        self.do_execute(*positional, **keyword)

    def handle_timeout(self):
        """Default timeout behaviour: close the connection."""
        self.dispatcher.handle_close()

class until_received(CoroutineInstruction):
    """Wait for a terminator string or a fixed number of bytes.

    Exactly one of ``terminator`` or ``bytes`` must be given; it is
    installed as the dispatcher's asynchat terminator.  When it is found,
    the coroutine is resumed with the collected data as one string.  If
    more than ``max_data`` bytes are buffered (a max_data of 0 means no
    limit), the connection is closed instead.  ``timeout`` is armed via
    the dispatcher.

    Raises ValueError if neither ``terminator`` nor ``bytes`` is given.
    """

    def do_execute(self,
                   terminator = None,
                   bytes = None,
                   timeout = DEFAULT_TIMEOUT,
                   max_data = DEFAULT_MAX_DATA):
        # Validate before arming the timer, so that a bad call does not
        # leave a stray timeout behind that would later close the channel.
        if terminator:
            # NOTE(review): the max_data cap is disabled for terminator
            # reads, the one mode where the amount buffered is unbounded;
            # confirm this is intended.
            max_data = 0
            self.dispatcher.set_terminator(terminator)
        elif bytes:
            self.dispatcher.set_terminator(bytes)
        else:
            raise ValueError('Must specify terminator or bytes')
        self.__max_data = max_data
        self.__data = []
        self.__data_len = 0
        self.dispatcher.set_timeout(timeout)

    def __over_limit(self):
        # True once more than max_data bytes have been buffered.
        return bool(self.__max_data) and self.__data_len > self.__max_data

    def collect_incoming_data(self, data):
        if self.__over_limit():
            # The connection was already closed for exceeding the limit;
            # drop anything delivered afterwards instead of re-closing.
            return
        self.__data.append(data)
        self.__data_len += len(data)
        if self.__over_limit():
            logging.error("Max data reached (%s bytes)." % self.__max_data)
            self.dispatcher.handle_close()

    def found_terminator(self):
        if self.__over_limit():
            return
        self.dispatcher.set_terminator(None)
        data = ''.join(self.__data)
        self.__data = []
        self.__data_len = 0
        self.dispatcher.clear_timeout()
        self.dispatcher.trampoline.continue_from_yield(data)

class until_sent(CoroutineInstruction):
    """Send ``content`` and resume the coroutine once it has all gone out.

    ``timeout`` is armed via the dispatcher and cleared once the
    outgoing buffers are empty.
    """

    def do_execute(self, content, timeout = DEFAULT_TIMEOUT):
        self.dispatcher.set_timeout(timeout)
        self.dispatcher.push(content)

    def handle_initiate_send(self):
        dispatcher = self.dispatcher
        # Older asynchat (Python 2.5) keeps partially-sent data in
        # ac_out_buffer; newer versions no longer define that attribute and
        # keep everything in producer_fifo, so treat it as empty if absent.
        pending = getattr(dispatcher, 'ac_out_buffer', '')
        if ((not pending) and
            (len(dispatcher.producer_fifo) == 0) and
            dispatcher.connected):
            dispatcher.clear_timeout()
            dispatcher.trampoline.continue_from_yield()

class return_value(CoroutineInstruction):
    """End the yielding coroutine, handing ``value`` back to its caller."""

    def do_execute(self, value):
        trampoline = self.dispatcher.trampoline
        trampoline.close_coroutine_and_return_to_caller(value)