changeset 34:b4fab248d1eb

Added a bunch of files from http://hg.toolness.com/cosocket
author Atul Varma <varmaa@toolness.com>
date Thu, 24 Dec 2009 15:40:39 -0800
parents 97e681243579
children 45d84e588d14
files bzapi_server.py channels.py cosocket.py media/html/index.html test_cosocket.py
diffstat 5 files changed, 611 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/bzapi_server.py	Thu Dec 24 15:40:39 2009 -0800
@@ -0,0 +1,215 @@
+import os
+import sys
+import math
+import re
+import httplib
+import cStringIO
+import mimetools
+import weakref
+import cgi
+import logging
+
+from cosocket import *
+import channels
+
+try:
+    import json
+except ImportError:
+    import simplejson as json
+
+KEEP_ALIVE_MAX_REQUESTS = 99
+KEEP_ALIVE_TIMEOUT = int(DEFAULT_TIMEOUT)
+MAX_MESSAGE_SIZE = 8192
+
+ROBOTS_TXT = "User-agent: *\r\nDisallow: /"
+
+def _parse_qs(querystring):
+    querydict = {}
+    cgi_querydict = cgi.parse_qs(querystring)
+    for key, value in cgi_querydict.items():
+        querydict[key] = cgi_querydict[key][0]
+    return querydict
+
+class BugzillaApiServer(object):
+    QUERYSTRING_TEMPLATE = re.compile('([^\?]*)\?(.*)')
+    REDIRECT_TEMPLATE = re.compile('\/([A-Za-z0-9_]+)$')
+    URL_TEMPLATE = re.compile('\/([A-Za-z0-9_]+)/(.*)')
+
+    BOUNDARY = "'''"
+
+    BLOCK_SIZE = 8192
+
+    MIME_TYPES = {'html' : 'text/html',
+                  'js' : 'text/javascript',
+                  'css' : 'text/css'}
+
+    def __init__(self, addr, media_dir, index_filename,
+                 is_keep_alive = False):
+        self._is_keep_alive = is_keep_alive
+        self._num_connections = 0
+        self._media_dir = media_dir
+        self._index_filename = index_filename
+        AsyncChatCoroutine(self._server_coroutine(addr))
+
+    def _until_http_response_sent(self, msg = '', mimetype = 'text/plain',
+                                  length = None, code = 200,
+                                  additional_headers = None):
+        headers = {'Content-Type': mimetype}
+        if self._is_keep_alive:
+            headers.update({'Keep-Alive': 'timeout=%d, max=%d' %
+                            (KEEP_ALIVE_TIMEOUT,
+                             KEEP_ALIVE_MAX_REQUESTS),
+                            'Connection': 'Keep-Alive'})
+        if additional_headers:
+            headers.update(additional_headers)
+        if length is None:
+            length = len(msg)
+        headers['Content-Length'] = str(length)
+
+        header_lines = ['HTTP/1.1 %d %s' % (code,
+                                            httplib.responses[code])]
+        header_lines.extend(['%s: %s' % (key, value)
+                             for key, value in headers.items()])
+        header_lines.extend(['', msg])
+        content = '\r\n'.join(header_lines)
+        yield until_sent(content)
+
+    def _until_file_sent(self, filename):
+        mimetype = self.MIME_TYPES[filename.split('.')[-1]]
+
+        length = os.stat(filename).st_size
+        num_blocks = length / self.BLOCK_SIZE
+        if length % self.BLOCK_SIZE:
+            num_blocks += 1
+        infile = open(filename, 'r')
+
+        yield self._until_http_response_sent(mimetype = mimetype,
+                                             length = length)
+
+        for i in range(num_blocks):
+            # TODO: This could be bad since we're reading the file
+            # synchronously.
+            block = infile.read(self.BLOCK_SIZE)
+            yield until_sent(block)
+
+    def _server_coroutine(self, bind_addr):
+        yield until_listening(bind_addr)
+        while 1:
+            conn, addr = yield until_connection_accepted()
+            AsyncChatCoroutine(self._connection_coroutine(addr), conn)
+
+    def _connection_coroutine(self, addr):
+        self._num_connections += 1
+        try:
+            if self._is_keep_alive:
+                for i in range(KEEP_ALIVE_MAX_REQUESTS):
+                    yield self._until_one_request_processed(addr)
+            else:
+                yield self._until_one_request_processed(addr)
+        finally:
+            logging.info('Closing connection to %s' % repr(addr))
+            self._num_connections -= 1
+
+    def _until_conv_request_processed(self, addr, headers, method,
+                                      conv_name, page):
+        match = self.QUERYSTRING_TEMPLATE.match(page)
+        querydict = {}
+        if match:
+            querydict.update(_parse_qs(match.group(2)))
+            page = match.group(1)
+        if page == 'listen':
+            logging.info("Waiting for message on channel '%s' for %s" %
+                         (conv_name, addr))
+            msg = yield channels.until_message_received(conv_name)
+            yield self._until_http_response_sent(
+                json.dumps(msg),
+                mimetype = 'application/json'
+                )
+        elif page == 'send':
+            length = int(headers.getheader('Content-Length', 0))
+            if length == 0 or length > MAX_MESSAGE_SIZE:
+                yield self._until_http_response_sent('message too large',
+                                                     code = 413)
+            else:
+                msg = yield until_received(bytes = length)
+                json_msg = json.loads(msg)
+                yield channels.until_message_sent(conv_name, json_msg)
+                yield self._until_http_response_sent('sent.')
+        else:
+            yield self._until_http_response_sent('not found',
+                                                 code = 404)
+
+    def _until_one_request_processed(self, addr):
+        request = yield until_received(terminator = '\r\n\r\n')
+        request = request.splitlines()
+        request_line = request[0]
+        logging.info("Request from %s: %s" % (addr, request_line))
+        stringfile = cStringIO.StringIO('\n'.join(request[1:]))
+        headers = mimetools.Message(stringfile)
+        req_parts = request_line.split()
+        method = req_parts[0]
+        match = self.URL_TEMPLATE.match(req_parts[1])
+
+        if req_parts[1] == '/':
+            yield self._until_file_sent(self._index_filename)
+        elif not match:
+            match = self.REDIRECT_TEMPLATE.match(req_parts[1])
+            if match:
+                newpath = req_parts[1] + '/'
+                yield self._until_http_response_sent(
+                    newpath,
+                    code = 301,
+                    additional_headers = {'Location': newpath}
+                    )
+            elif req_parts[1] == '/robots.txt':
+                yield self._until_http_response_sent(ROBOTS_TXT)
+            else:
+                yield self._until_http_response_sent('not found',
+                                                     code = 404)
+        else:
+            conv_name = match.group(1)
+            page = match.group(2)
+            if conv_name == 'status':
+                # TODO: Return 404 if page is non-empty.
+                lines = ('open connections  : %d' % self._num_connections,
+                         'open timers       : %d' % len(time_map))
+                yield self._until_http_response_sent('\r\n'.join(lines))
+            elif conv_name == 'media':
+                path = os.path.join(self._media_dir, *page.split('/'))
+                path = os.path.normpath(path)
+                if (path.startswith(self._media_dir) and
+                    os.path.exists(path) and
+                    os.path.isfile(path)):
+                    yield self._until_file_sent(path)
+                else:
+                    yield self._until_http_response_sent('not found',
+                                                         code = 404)
+            else:
+                yield self._until_conv_request_processed(addr, headers,
+                                                         method, conv_name,
+                                                         page)
+
+if __name__ == '__main__':
+    args = dict(ip = '127.0.0.1',
+                port = 8071,
+                is_keep_alive = True,
+                logfile = '',
+                loglevel = 'info',
+                media_dir = os.path.abspath('media'))
+
+    args['index_filename'] = os.path.join(args['media_dir'], 'html',
+                                          'index.html')
+    args['loglevel'] = getattr(logging, args['loglevel'].upper())
+    if args['logfile']:
+        logging.basicConfig(filename = args['logfile'],
+                            level = args['loglevel'])
+    else:
+        logging.basicConfig(stream = sys.stdout,
+                            level = args['loglevel'])
+
+    server = BugzillaApiServer(addr = (args['ip'], args['port']),
+                               is_keep_alive = args['is_keep_alive'],
+                               media_dir = args['media_dir'],
+                               index_filename = args['index_filename'])
+    logging.info("Starting server with configuration: %s" % args)
+    loop()
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/channels.py	Thu Dec 24 15:40:39 2009 -0800
@@ -0,0 +1,38 @@
+import cosocket
+
+_channels = {}
+
+class until_message_sent(cosocket.CoroutineInstruction):
+    def do_execute(self, channel_name, message):
+        if channel_name in _channels:
+            receivers = _channels[channel_name].values()
+            del _channels[channel_name]
+            for receiver in receivers:
+                receiver.trampoline.continue_from_yield(message)
+        self.dispatcher.trampoline.continue_from_yield()
+
+class _until_message_received(cosocket.CoroutineInstruction):
+    def do_execute(self, channel_name,
+                   timeout = cosocket.DEFAULT_TIMEOUT):
+        if channel_name not in _channels:
+            _channels[channel_name] = {}
+        fd = self.dispatcher.socket.fileno()
+        _channels[channel_name][fd] = self.dispatcher
+        self.dispatcher.set_timeout(timeout)
+        self.channel_name = channel_name
+        self._fd = fd
+
+    def finalize(self):
+        if (self.channel_name in _channels and
+            self._fd in _channels[self.channel_name]):
+            del _channels[self.channel_name][self._fd]
+            if not _channels[self.channel_name]:
+                del _channels[self.channel_name]
+
+def until_message_received(channel_name):
+    instruction = _until_message_received(channel_name)
+    try:
+        message = yield instruction
+        yield cosocket.return_value(message)
+    finally:
+        instruction.finalize()
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/cosocket.py	Thu Dec 24 15:40:39 2009 -0800
@@ -0,0 +1,314 @@
+'''
+    >>> IP = '127.0.0.1'
+    >>> PORT = 38424
+
+    >>> def server_coroutine():
+    ...     print 'server now listening.'
+    ...     yield until_listening((IP, PORT))
+    ...     print 'server now spawning client.'
+    ...     AsyncChatCoroutine(client_coroutine())
+    ...     print 'server now accepting connections.'
+    ...     conn, addr = yield until_connection_accepted()
+    ...     print 'server now spawning connection.'
+    ...     AsyncChatCoroutine(connection_coroutine(addr), conn)
+
+    >>> def client_coroutine():
+    ...     print 'client now connecting to server.'
+    ...     yield until_connected((IP, PORT))
+    ...     print 'client now connected, sending text.'
+    ...     yield until_sent('hai2u\\r\\n')
+    ...     print 'client waiting for response.'
+    ...     data = yield until_received(terminator = '\\r\\n')
+    ...     print 'client received response: %s' % data
+
+    >>> def connection_coroutine(addr):
+    ...     print 'server connection waiting for request.'
+    ...     data = yield until_received('\\r\\n')
+    ...     print 'server connection sending back response: %s' % data
+    ...     yield until_sent(data + '\\r\\n')
+
+    >>> server = AsyncChatCoroutine(server_coroutine())
+    server now listening.
+    server now spawning client.
+    client now connecting to server.
+    server now accepting connections.
+
+    >>> loop()
+    server now spawning connection.
+    server connection waiting for request.
+    client now connected, sending text.
+    client waiting for response.
+    server connection sending back response: hai2u
+    client received response: hai2u
+
+'''
+
+import sys
+import socket
+import asyncore
+import asynchat
+import types
+import traceback
+import time
+import weakref
+import logging
+
+DEFAULT_LOOP_TIMEOUT = 1.0
+DEFAULT_TIMEOUT = 90.0
+DEFAULT_MAX_DATA = 65536
+
+time_map = {}
+
+def loop(timeout = DEFAULT_LOOP_TIMEOUT):
+    start_time = time.time()
+    while asyncore.socket_map or time_map:
+        asyncore.loop(timeout = timeout, count = 1)
+        curr_time = time.time()
+        time_elapsed = curr_time - start_time
+        if time_elapsed > timeout:
+            start_time = curr_time
+            funcs_to_call = []
+            for func in time_map:
+                time_map[func] -= time_elapsed
+                if time_map[func] <= 0:
+                    funcs_to_call.append(func)
+            for func in funcs_to_call:
+                del time_map[func]
+            for func in funcs_to_call:
+                try:
+                    func()
+                except:
+                    logging.error(traceback.format_exc())
+
+class _Trampoline(object):
+    def __init__(self, coroutine, handler):
+        self.__handler = handler
+        self.__coroutine = coroutine
+        self.__coroutine_stack = []
+
+    def __log_error(self):
+        logging.error(traceback.format_exc() +
+                      self.get_formatted_coroutine_traceback())
+
+    def __close_coroutine(self, coroutine):
+        try:
+            coroutine.close()
+        except Exception:
+            self.__log_error()
+
+    def get_formatted_coroutine_traceback(self):
+        if not self.__coroutine:
+            return ""
+        lines = []
+        frames = [coroutine.gi_frame
+                  for coroutine in self.__coroutine_stack]
+        if self.__coroutine.gi_frame:
+            frames.append(self.__coroutine.gi_frame)
+        for frame in frames:
+            name = frame.f_code.co_name
+            filename = frame.f_code.co_filename
+            lineno = frame.f_lineno
+            lines.append('File "%s", line %d, in coroutine %s' %
+                          (filename, lineno, name))
+        if not lines:
+            return 'No coroutine traceback available.'
+        lines.insert(0, 'Coroutine traceback (most recent call last):')
+        return '\n'.join(lines)
+
+    def close_coroutine_stack(self):
+        if self.__coroutine:
+            # Pass an exception back into the coroutine to kick
+            # it out of whatever yielding state it's in.
+            self.__close_coroutine(self.__coroutine)
+            self.__coroutine = None
+            while self.__coroutine_stack:
+                self.__close_coroutine(self.__coroutine_stack.pop())
+
+    def close_coroutine_and_return_to_caller(self, message):
+        self.__close_coroutine(self.__coroutine)
+        if self.__coroutine_stack:
+            self.__coroutine = self.__coroutine_stack.pop()
+            self.continue_from_yield(message)
+        else:
+            self.__coroutine = None
+
+    def continue_from_yield(self, message = None, exception_info = None):
+        try:
+            if exception_info:
+                instruction = self.__coroutine.throw(*exception_info)
+            else:
+                instruction = self.__coroutine.send(message)
+        except StopIteration:
+            if self.__coroutine_stack:
+                self.__coroutine = self.__coroutine_stack.pop()
+                self.continue_from_yield()
+            else:
+                self.__coroutine = None
+                self.__handler.handle_coroutine_complete(None)
+        except Exception, e:
+            if self.__coroutine_stack:
+                self.__coroutine = self.__coroutine_stack.pop()
+                self.continue_from_yield(exception_info = sys.exc_info())
+            else:
+                self.__log_error()
+                self.__handler.handle_coroutine_complete(e)
+        else:
+            if type(instruction) == types.GeneratorType:
+                self.__coroutine_stack.append(self.__coroutine)
+                self.__coroutine = instruction
+                self.continue_from_yield()
+            else:
+                self.__handler.handle_coroutine_instruction(instruction)
+
+class AsyncChatCoroutine(asynchat.async_chat):
+    def __init__(self, coroutine, conn = None):
+        asynchat.async_chat.__init__(self, conn)
+        self.trampoline = _Trampoline(coroutine, self)
+        self.set_terminator(None)
+        self.trampoline.continue_from_yield()
+
+    def handle_coroutine_instruction(self, instruction):
+        self.__instruction = instruction
+        instruction.execute(self)
+
+    def handle_coroutine_complete(self, exception):
+        self.__instruction = None
+        if not exception:
+            self.handle_close()
+
+    def handle_close(self):
+        self.trampoline.close_coroutine_stack()
+        self.clear_timeout()
+        self.close()
+
+    def log_info(self, message, type='info'):
+        try:
+            level = getattr(logging, type.upper())
+        except AttributeError:
+            level = logging.INFO
+        logging.log(level, message)
+
+    def handle_error(self):
+        self.log_info(traceback.format_exc() +
+                      self.trampoline.get_formatted_coroutine_traceback(),
+                      'error')
+
+    def handle_accept(self):
+        self.__instruction.handle_accept()
+
+    def handle_connect(self):
+        self.__instruction.handle_connect()
+
+    def initiate_send(self):
+        asynchat.async_chat.initiate_send(self)
+        self.__instruction.handle_initiate_send()
+
+    def __on_timeout(self):
+        self.__instruction.handle_timeout()
+
+    def clear_timeout(self):
+        if self.__on_timeout in time_map:
+            del time_map[self.__on_timeout]
+
+    def set_timeout(self, timeout):
+        time_map[self.__on_timeout] = timeout
+
+    def collect_incoming_data(self, data):
+        self.__instruction.collect_incoming_data(data)
+
+    def found_terminator(self):
+        self.__instruction.found_terminator()
+
+# Instructions that coroutines yield.
+
+class CoroutineInstruction(object):
+    def __init__(self, *args, **kwargs):
+        self.__args = args
+        self.__kwargs = kwargs
+
+    def execute(self, dispatcher):
+        self.dispatcher = dispatcher
+        self.do_execute(*self.__args, **self.__kwargs)
+
+    def handle_timeout(self):
+        self.dispatcher.handle_close()
+
+class until_listening(CoroutineInstruction):
+    def do_execute(self, bind_addr):
+        self.dispatcher.create_socket(socket.AF_INET,
+                                      socket.SOCK_STREAM)
+        self.dispatcher.set_reuse_addr()
+        self.dispatcher.bind(bind_addr)
+        self.dispatcher.listen(1)
+        self.dispatcher.trampoline.continue_from_yield()
+
+class until_connection_accepted(CoroutineInstruction):
+    def do_execute(self):
+        pass
+
+    def handle_accept(self):
+        data = self.dispatcher.accept()
+        self.dispatcher.trampoline.continue_from_yield(data)
+
+class until_connected(CoroutineInstruction):
+    def do_execute(self, addr):
+        self.dispatcher.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.dispatcher.connect(addr)
+
+    def handle_connect(self):
+        self.dispatcher.trampoline.continue_from_yield()
+
+class until_received(CoroutineInstruction):
+    def do_execute(self,
+                   terminator = None,
+                   bytes = None,
+                   timeout = DEFAULT_TIMEOUT,
+                   max_data = DEFAULT_MAX_DATA):
+        self.dispatcher.set_timeout(timeout)
+        if terminator:
+            max_data = 0
+            self.dispatcher.set_terminator(terminator)
+        elif bytes:
+            self.dispatcher.set_terminator(bytes)
+        else:
+            raise ValueError('Must specify terminator or bytes')
+        self.__max_data = max_data
+        self.__data = []
+        self.__data_len = 0
+
+    def collect_incoming_data(self, data):
+        self.__data.append(data)
+        self.__data_len += len(data)
+        if self.__max_data and self.__data_len > self.__max_data:
+            logging.error("Max data reached (%s bytes)." % self.__max_data)
+            self.dispatcher.handle_close()
+
+    def found_terminator(self):
+        if not (self.__max_data and self.__data_len > self.__max_data):
+            self.dispatcher.set_terminator(None)
+            data = ''.join(self.__data)
+            self.__data = []
+            self.__data_len = 0
+            self.dispatcher.clear_timeout()
+            self.dispatcher.trampoline.continue_from_yield(data)
+
+class until_sent(CoroutineInstruction):
+    def do_execute(self, content, timeout = DEFAULT_TIMEOUT):
+        self.dispatcher.set_timeout(timeout)
+        self.dispatcher.push(content)
+
+    def handle_initiate_send(self):
+        if ((not self.dispatcher.ac_out_buffer) and
+            (len(self.dispatcher.producer_fifo) == 0) and
+            self.dispatcher.connected):
+            self.dispatcher.clear_timeout()
+            self.dispatcher.trampoline.continue_from_yield()
+
+class return_value(CoroutineInstruction):
+    def do_execute(self, value):
+        self.dispatcher.trampoline.close_coroutine_and_return_to_caller(value)
+
+if __name__ == '__main__':
+    import doctest
+
+    doctest.testmod(verbose = True)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/media/html/index.html	Thu Dec 24 15:40:39 2009 -0800
@@ -0,0 +1,9 @@
+<html>
+<head>
+  <meta http-equiv="Content-type" content="text/html; charset=utf-8" />
+  <title>bzapi</title>
+</head>
+<body>
+TODO: Put stuff here.
+</body>
+</html>
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test_cosocket.py	Thu Dec 24 15:40:39 2009 -0800
@@ -0,0 +1,35 @@
+import cosocket
+import unittest
+
+class Tests(unittest.TestCase):
+    PORT = 38424
+    IP = '127.0.0.1'
+
+    def testSimple(self):
+        done = dict(client = 0, server = 0, connection = 0)
+
+        def server_coroutine():
+            yield cosocket.until_listening((self.IP, self.PORT))
+            cosocket.AsyncChatCoroutine(client_coroutine())
+            conn, addr = yield cosocket.until_connection_accepted()
+            cosocket.AsyncChatCoroutine(connection_coroutine(addr), conn)
+            done['server'] += 1
+
+        def client_coroutine():
+            yield cosocket.until_connected((self.IP, self.PORT))
+            yield cosocket.until_sent('hai2u\r\n')
+            data = yield cosocket.until_received(terminator = '\r\n')
+            self.assertEqual(data, 'hai2u')
+            done['client'] += 1
+
+        def connection_coroutine(addr):
+            data = yield cosocket.until_received('\r\n')
+            yield cosocket.until_sent(data + '\r\n')
+            done['connection'] += 1
+
+        cosocket.AsyncChatCoroutine(server_coroutine())
+        cosocket.loop()
+        self.assertEqual(done['server'], 1)
+
+if __name__ == '__main__':
+    unittest.main()