view cosocket.py @ 92:1414765e0929

Fix to client to deal with annoying no-content responses from the lighttpd proxy.
author Atul Varma <varmaa@toolness.com>
date Sat, 02 May 2009 01:04:30 +0000
parents 0b8c3a21335c
children 3be28af79baf
line wrap: on
line source

import sys
import socket
import asyncore
import asynchat
import types
import traceback
import time
import weakref
import logging

# Default number of seconds an instruction (until_received / until_sent)
# may wait, and the default poll interval used by loop().
DEFAULT_TIMEOUT = 90.0
# Default cap, in bytes, on data buffered while waiting for a terminator.
DEFAULT_MAX_DATA = 64 * 1024

# Timer callbacks ticked by loop(); every entry maps a callback to itself.
time_map = dict()

def loop(timeout = DEFAULT_TIMEOUT):
    """Run the asyncore event loop forever, firing periodic timer ticks.

    Each asyncore.loop() pass waits at most *timeout* seconds for socket
    activity.  Whenever more than *timeout* seconds have passed since the
    previous tick, every callback registered in ``time_map`` is called with
    the elapsed time in seconds.
    """
    start_time = time.time()
    while 1:
        asyncore.loop(timeout = timeout, count = 1)
        curr_time = time.time()
        time_elapsed = curr_time - start_time
        if time_elapsed > timeout:
            start_time = curr_time
            # Iterate over a snapshot: a callback that expires closes its
            # dispatcher, which deletes entries from time_map mid-iteration.
            for func in list(time_map.values()):
                func(time_elapsed)

class _AsyncChatCoroutineDispatcher(asynchat.async_chat):
    """Drive a generator-based coroutine from asynchat socket events.

    The coroutine yields either an instruction object (anything with an
    ``execute(dispatcher)`` method, such as until_received, until_sent or
    return_value) or another generator.  A yielded generator runs as a
    nested coroutine; when it finishes, its caller resumes.  Data received
    for an until_received instruction is sent back in as the value of the
    pending ``yield``.  An exception escaping a nested coroutine is
    re-raised inside its caller.
    """

    def __init__(self, coroutine, conn = None):
        asynchat.async_chat.__init__(self, conn)
        # No terminator until the coroutine asks for one; meanwhile every
        # received chunk is buffered by collect_incoming_data().
        self.set_terminator(None)
        self.__max_data = DEFAULT_MAX_DATA
        self.__coroutine = coroutine
        self.__data = []              # received chunks not yet delivered
        self.__data_len = 0           # total length of self.__data
        self.__coroutine_stack = []   # suspended callers, innermost last
        self.__timeout = 0
        self.__time_passed = 0
        # An accepted connection is already live, so start right away; a
        # client (no conn) starts from handle_connect() instead.
        if conn:
            self.continue_from_yield()

    def close_coroutine_and_return_to_caller(self, message):
        """Finish the current coroutine and resume its caller with *message*.

        If the current coroutine is the outermost one there is no caller
        to resume, so the connection is closed -- the same outcome as when
        the outermost coroutine simply runs to completion.
        """
        self.__close_coroutine(self.__coroutine)
        if self.__coroutine_stack:
            self.__coroutine = self.__coroutine_stack.pop()
            self.continue_from_yield(message)
        else:
            self.__coroutine = None
            self.handle_close()

    def continue_from_yield(self, message = None,
                            exception_info = None):
        """Resume the current coroutine and act on what it yields next.

        *message* becomes the value of the coroutine's pending ``yield``.
        If *exception_info* (a ``sys.exc_info()`` triple) is given, that
        exception is raised inside the coroutine instead.
        Any armed timeout is cleared first.
        """
        self.clear_timeout()
        try:
            if exception_info:
                instruction = self.__coroutine.throw(*exception_info)
            else:
                instruction = self.__coroutine.send(message)
        except StopIteration:
            # Coroutine finished: resume its caller, or close the
            # connection if it was the outermost one.
            if self.__coroutine_stack:
                self.__coroutine = self.__coroutine_stack.pop()
                self.continue_from_yield()
            else:
                self.__coroutine = None
                self.handle_close()
        except Exception:
            # Propagate the error into the caller so it can handle it; at
            # the outermost level there is nobody left, so report it.
            if self.__coroutine_stack:
                self.__coroutine = self.__coroutine_stack.pop()
                self.continue_from_yield(exception_info = sys.exc_info())
            else:
                self.handle_error()
        else:
            if isinstance(instruction, types.GeneratorType):
                # Nested coroutine: suspend the caller and run the callee.
                self.__coroutine_stack.append(self.__coroutine)
                self.__coroutine = instruction
                self.continue_from_yield()
            else:
                instruction.execute(self)

    def __close_coroutine(self, coroutine):
        """Close *coroutine*, logging (not raising) any error it produces."""
        try:
            coroutine.close()
        except Exception:
            self.log_info(traceback.format_exc(), 'error')

    def get_coroutine_stack_frames(self):
        """Return the frames of the suspended callers, outermost first.

        The currently running coroutine is not included.
        """
        return [coroutine.gi_frame
                for coroutine in self.__coroutine_stack]

    def get_formatted_coroutine_traceback(self):
        """Format the suspended caller chain like a Python traceback."""
        lines = ['Coroutine traceback (most recent call last):']
        for frame in self.get_coroutine_stack_frames():
            name = frame.f_code.co_name
            filename = frame.f_code.co_filename
            lineno = frame.f_lineno
            lines.append('File "%s", line %d, in coroutine %s' %
                          (filename, lineno, name))
        return '\n'.join(lines)

    def handle_close(self):
        """Close every coroutine (current and suspended) and the socket."""
        if self.__coroutine:
            # close() raises GeneratorExit inside each coroutine to kick it
            # out of whatever yielding state it's in.
            self.__close_coroutine(self.__coroutine)
            self.__coroutine = None
            while self.__coroutine_stack:
                self.__close_coroutine(self.__coroutine_stack.pop())
        self.clear_timeout()
        self.close()

    def log_info(self, message, type='info'):
        """Route asyncore diagnostics to the logging module.

        *type* (named to match asyncore.dispatcher.log_info) is a level
        name such as 'error'; unknown names fall back to INFO.
        """
        try:
            level = getattr(logging, type.upper())
        except AttributeError:
            level = logging.INFO
        logging.log(level, message)

    def handle_error(self):
        """Log the active exception and coroutine chain, then close.

        Closing the channel matches asyncore.dispatcher.handle_error's
        default.  Without it, a connection whose coroutine has died, or
        whose socket keeps failing, would stay registered with the event
        loop.  The log is written first because handle_close() empties
        the coroutine stack.
        """
        self.log_info(traceback.format_exc() +
                      self.get_formatted_coroutine_traceback(),
                      'error')
        self.handle_close()

    def handle_connect(self):
        # Client side: the connection is up, so start the coroutine.
        self.continue_from_yield()

    def initiate_send(self):
        """Flush queued output; once all of it is sent, resume the coroutine.

        NOTE(review): this also resumes the coroutine when initiate_send()
        runs with nothing queued at all.  Confirm that asyncore never calls
        handle_write in that state (e.g. right after a non-blocking
        connect).  Otherwise a coroutine waiting in until_received would be
        woken with None.
        """
        asynchat.async_chat.initiate_send(self)
        # asynchat before Python 2.6 keeps partly-sent output in
        # ac_out_buffer; 2.6+ dropped that attribute and queues everything
        # in producer_fifo, so treat a missing buffer as empty.
        out_buffer = getattr(self, 'ac_out_buffer', '')
        if ((not out_buffer) and
            (len(self.producer_fifo) == 0) and
            self.connected):
            self.continue_from_yield()

    def __on_tick(self, time_elapsed):
        """Timer callback run by loop(); closes the connection on timeout."""
        self.__time_passed += time_elapsed
        if self.__time_passed > self.__timeout:
            self.log_info("Timeout expired (%ss)." % self.__timeout,
                          'error')
            self.handle_close()

    def clear_timeout(self):
        """Disarm the timeout and unregister from loop()'s timer map."""
        self.__timeout = 0
        self.__time_passed = 0
        time_map.pop(self.__on_tick, None)

    def set_timeout(self, timeout):
        """Arm a timeout of *timeout* seconds, ticked by loop()."""
        self.__timeout = timeout
        self.__time_passed = 0
        time_map[self.__on_tick] = self.__on_tick

    def collect_incoming_data(self, data):
        """Buffer *data*; close the connection once max_data is exceeded.

        A max_data of 0 means no cap.
        """
        self.__data.append(data)
        self.__data_len += len(data)
        if self.__max_data and self.__data_len > self.__max_data:
            self.log_info("Max data reached (%s bytes)." % self.__max_data,
                          'error')
            self.handle_close()

    def set_max_data(self, amount):
        """Set the buffering cap in bytes (0 disables it)."""
        self.__max_data = amount

    def found_terminator(self):
        """Deliver the buffered data to the coroutine and reset the buffer.

        Nothing is delivered if the cap was exceeded; in that case
        collect_incoming_data() has already closed the connection.
        """
        if not (self.__max_data and self.__data_len > self.__max_data):
            self.set_terminator(None)
            data = ''.join(self.__data)
            self.__data = []
            self.__data_len = 0
            self.continue_from_yield(data)

class CoroutineSocketServer(asyncore.dispatcher):
    """Listening TCP socket that runs one coroutine per accepted connection.

    *coroutineFactory* is called with each client's address and must return
    the generator that will drive that connection.  *backlog* is passed to
    listen(); it defaults to the historical value of 1.
    """

    def __init__(self, addr, coroutineFactory, backlog = 1):
        asyncore.dispatcher.__init__(self)
        self.__coroutineFactory = coroutineFactory
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind(addr)
        self.listen(backlog)

    def run(self, timeout = DEFAULT_TIMEOUT):
        """Run the module-level event loop; it does not return normally."""
        loop(timeout)

    def handle_accept(self):
        pair = self.accept()
        # asyncore's accept() may return None instead of a pair when the
        # connection did not actually take place. Unpacking None would
        # raise, and asyncore's default error handling would then close
        # this listening socket.
        if pair is None:
            return
        conn, addr = pair
        coroutine = self.__coroutineFactory(addr)
        # The new dispatcher registers itself with asyncore's socket map,
        # which keeps it alive.
        _AsyncChatCoroutineDispatcher(coroutine, conn)

class CoroutineSocketClient(_AsyncChatCoroutineDispatcher):
    """Outgoing TCP connection driven by a coroutine.

    *coroutineFactory* is called with the remote *addr* to build the
    coroutine; since no socket is passed to the base class, the coroutine
    first runs from handle_connect() once the connection is established.
    """

    def __init__(self, addr, coroutineFactory):
        _AsyncChatCoroutineDispatcher.__init__(self, coroutineFactory(addr))
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.connect(addr)

# Instructions that coroutines yield.

class until_received(object):
    """Instruction: suspend the coroutine until specific input arrives.

    Give either *terminator* (a delimiter string) or *bytes* (a byte
    count); if both are truthy, *terminator* takes precedence.  The
    received data is sent back into the coroutine as the value of its
    ``yield``.  *max_data* caps how much may be buffered while waiting for
    a terminator; with a byte count no cap is applied, since the count
    already bounds the read.  The dispatcher's timeout is armed with
    *timeout* seconds.  (*bytes* shadows the builtin but is kept for
    caller compatibility.)

    Raises ValueError if neither a terminator nor a nonzero byte count is
    given.
    """

    def __init__(self, terminator = None, bytes = None,
                 timeout = DEFAULT_TIMEOUT, max_data = DEFAULT_MAX_DATA):
        self._timeout = timeout
        if terminator:
            self._terminator = terminator
            self._max_data = max_data
        elif bytes:
            self._terminator = bytes
            self._max_data = 0      # 0 disables the dispatcher's cap
        else:
            raise ValueError("until_received() needs a terminator or a "
                             "nonzero byte count")

    def execute(self, dispatcher):
        dispatcher.set_timeout(self._timeout)
        dispatcher.set_max_data(self._max_data)
        dispatcher.set_terminator(self._terminator)

class until_sent(object):
    """Instruction: queue *content* for sending; resume once it has gone out.

    Falsy content is rejected with ValueError.  The dispatcher's timeout is
    armed with *timeout* seconds before the data is pushed.
    """

    def __init__(self, content, timeout = DEFAULT_TIMEOUT):
        if content:
            self._timeout = timeout
            self.content = content
        else:
            raise ValueError(content)

    def execute(self, dispatcher):
        # Arm the timeout before pushing: the push may finish sending
        # immediately and resume the coroutine, whose next instruction
        # must not have its own timeout overwritten afterwards.
        dispatcher.set_timeout(self._timeout)
        dispatcher.push(self.content)

class return_value(object):
    """Instruction: end the current coroutine, handing *value* to its caller."""

    def __init__(self, value):
        self.value = value

    def execute(self, dispatcher):
        result = self.value
        dispatcher.close_coroutine_and_return_to_caller(result)