diff options
Diffstat (limited to 'net/tools')
| -rw-r--r-- | net/tools/fetch/fetch_client.cc | 2 | ||||
| -rw-r--r-- | net/tools/fetch/http_listen_socket.cc | 4 | ||||
| -rwxr-xr-x[-rw-r--r--] | net/tools/testserver/testserver.py | 73 | ||||
| -rw-r--r-- | net/tools/testserver/xmppserver.py | 527 | ||||
| -rw-r--r-- | net/tools/testserver/xmppserver_test.py | 250 |
5 files changed, 832 insertions, 24 deletions
diff --git a/net/tools/fetch/fetch_client.cc b/net/tools/fetch/fetch_client.cc index 42949c8..3bdbcbf 100644 --- a/net/tools/fetch/fetch_client.cc +++ b/net/tools/fetch/fetch_client.cc @@ -137,7 +137,7 @@ int main(int argc, char**argv) { scoped_ptr<net::HostResolver> host_resolver( net::CreateSystemHostResolver(net::HostResolver::kDefaultParallelism, - NULL)); + NULL, NULL)); scoped_refptr<net::ProxyService> proxy_service( net::ProxyService::CreateDirect()); diff --git a/net/tools/fetch/http_listen_socket.cc b/net/tools/fetch/http_listen_socket.cc index fd788c8..0db714f 100644 --- a/net/tools/fetch/http_listen_socket.cc +++ b/net/tools/fetch/http_listen_socket.cc @@ -30,8 +30,8 @@ void HttpListenSocket::Accept() { if (conn == ListenSocket::kInvalidSocket) { // TODO } else { - scoped_refptr<HttpListenSocket> sock = - new HttpListenSocket(conn, delegate_); + scoped_refptr<HttpListenSocket> sock( + new HttpListenSocket(conn, delegate_)); // it's up to the delegate to AddRef if it wants to keep it around DidAccept(this, sock); } diff --git a/net/tools/testserver/testserver.py b/net/tools/testserver/testserver.py index c3fe86b..55aa6a9 100644..100755 --- a/net/tools/testserver/testserver.py +++ b/net/tools/testserver/testserver.py @@ -21,7 +21,7 @@ import shutil import SocketServer import sys import time -import urllib2 +import urlparse import warnings # Ignore deprecation warnings, they make our output more cluttered. @@ -64,7 +64,7 @@ class HTTPSServer(tlslite.api.TLSSocketServerMixIn, StoppableHTTPServer): """This is a specialization of StoppableHTTPerver that add https support.""" def __init__(self, server_address, request_hander_class, cert_path, - ssl_client_auth, ssl_client_cas): + ssl_client_auth, ssl_client_cas, ssl_bulk_ciphers): s = open(cert_path).read() x509 = tlslite.api.X509() x509.parse(s) @@ -78,6 +78,9 @@ class HTTPSServer(tlslite.api.TLSSocketServerMixIn, StoppableHTTPServer): x509 = tlslite.api.X509() x509.parse(s) self.ssl_client_cas.append(x509.subject) + self.ssl_handshake_settings = tlslite.api.HandshakeSettings() + if ssl_bulk_ciphers is not None: + self.ssl_handshake_settings.cipherNames = ssl_bulk_ciphers self.session_cache = tlslite.api.SessionCache() StoppableHTTPServer.__init__(self, server_address, request_hander_class) @@ -89,6 +92,7 @@ class HTTPSServer(tlslite.api.TLSSocketServerMixIn, StoppableHTTPServer): privateKey=self.private_key, sessionCache=self.session_cache, reqCert=self.ssl_client_auth, + settings=self.ssl_handshake_settings, reqCAs=self.ssl_client_cas) tlsConnection.ignoreAbruptClose = True return True @@ -569,6 +573,24 @@ class TestPageHandler(BaseHTTPServer.BaseHTTPRequestHandler): self.end_headers() return True + def _ReplaceFileData(self, data, query_parameters): + """Replaces matching substrings in a file. + + If the 'replace_orig' and 'replace_new' URL query parameters are present, + a new string is returned with all occasions of the 'replace_orig' value + replaced by the 'replace_new' value. + + If the parameters are not present, |data| is returned. + """ + query_dict = cgi.parse_qs(query_parameters) + orig_values = query_dict.get('replace_orig', []) + new_values = query_dict.get('replace_new', []) + if not orig_values or not new_values: + return data + orig_value = orig_values[0] + new_value = new_values[0] + return data.replace(orig_value, new_value) + def FileHandler(self): """This handler sends the contents of the requested file. Wow, it's like a real webserver!""" @@ -581,29 +603,27 @@ class TestPageHandler(BaseHTTPServer.BaseHTTPRequestHandler): if self.command == 'POST' or self.command == 'PUT' : self.rfile.read(int(self.headers.getheader('content-length'))) - file = self.path[len(prefix):] - if file.find('?') > -1: - # Ignore the query parameters entirely. - url, querystring = file.split('?') - else: - url = file - entries = url.split('/') - path = os.path.join(self.server.data_dir, *entries) - if os.path.isdir(path): - path = os.path.join(path, 'index.html') - - if not os.path.isfile(path): - print "File not found " + file + " full path:" + path + _, _, url_path, _, query, _ = urlparse.urlparse(self.path) + sub_path = url_path[len(prefix):] + entries = sub_path.split('/') + file_path = os.path.join(self.server.data_dir, *entries) + if os.path.isdir(file_path): + file_path = os.path.join(file_path, 'index.html') + + if not os.path.isfile(file_path): + print "File not found " + sub_path + " full path:" + file_path self.send_error(404) return True - f = open(path, "rb") + f = open(file_path, "rb") data = f.read() f.close() + data = self._ReplaceFileData(data, query) + # If file.mock-http-headers exists, it contains the headers we # should send. Read them in and parse them. - headers_path = path + '.mock-http-headers' + headers_path = file_path + '.mock-http-headers' if os.path.isfile(headers_path): f = open(headers_path, "r") @@ -623,7 +643,7 @@ class TestPageHandler(BaseHTTPServer.BaseHTTPRequestHandler): # Could be more generic once we support mime-type sniffing, but for # now we need to set it explicitly. self.send_response(200) - self.send_header('Content-type', self.GetMIMETypeFromName(file)) + self.send_header('Content-type', self.GetMIMETypeFromName(file_path)) self.send_header('Content-Length', len(data)) self.end_headers() @@ -1169,7 +1189,8 @@ def main(options, args): ' exiting...' return server = HTTPSServer(('127.0.0.1', port), TestPageHandler, options.cert, - options.ssl_client_auth, options.ssl_client_ca) + options.ssl_client_auth, options.ssl_client_ca, + options.ssl_bulk_cipher) print 'HTTPS server started on port %d...' % port else: server = StoppableHTTPServer(('127.0.0.1', port), TestPageHandler) @@ -1240,8 +1261,18 @@ if __name__ == '__main__': help='Require SSL client auth on every connection.') option_parser.add_option('', '--ssl-client-ca', action='append', default=[], help='Specify that the client certificate request ' - 'should indicate that it supports the CA contained ' - 'in the specified certificate file') + 'should include the CA named in the subject of ' + 'the DER-encoded certificate contained in the ' + 'specified file. This option may appear multiple ' + 'times, indicating multiple CA names should be ' + 'sent in the request.') + option_parser.add_option('', '--ssl-bulk-cipher', action='append', + help='Specify the bulk encryption algorithm(s)' + 'that will be accepted by the SSL server. Valid ' + 'values are "aes256", "aes128", "3des", "rc4". If ' + 'omitted, all algorithms will be used. This ' + 'option may appear multiple times, indicating ' + 'multiple algorithms should be enabled.'); option_parser.add_option('', '--file-root-url', default='/files/', help='Specify a root URL for files served.') option_parser.add_option('', '--startup-pipe', type='int', diff --git a/net/tools/testserver/xmppserver.py b/net/tools/testserver/xmppserver.py new file mode 100644 index 0000000..ad99571 --- /dev/null +++ b/net/tools/testserver/xmppserver.py @@ -0,0 +1,527 @@ +#!/usr/bin/python2.4 +# Copyright (c) 2010 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +"""A bare-bones and non-compliant XMPP server. + +Just enough of the protocol is implemented to get it to work with +Chrome's sync notification system. +""" + +import asynchat +import asyncore +import base64 +import re +import socket +from xml.dom import minidom + +# pychecker complains about the use of fileno(), which is implemented +# by asyncore by forwarding to an internal object via __getattr__. +__pychecker__ = 'no-classattr' + + +class Error(Exception): + """Error class for this module.""" + pass + + +class UnexpectedXml(Error): + """Raised when an unexpected XML element has been encountered.""" + + def __init__(self, xml_element): + xml_text = xml_element.toxml() + Error.__init__(self, 'Unexpected XML element', xml_text) + + +def ParseXml(xml_string): + """Parses the given string as XML and returns a minidom element + object. + """ + dom = minidom.parseString(xml_string) + + # minidom handles xmlns specially, but there's a bug where it sets + # the attribute value to None, which causes toxml() or toprettyxml() + # to break. + def FixMinidomXmlnsBug(xml_element): + if xml_element.getAttribute('xmlns') is None: + xml_element.setAttribute('xmlns', '') + + def ApplyToAllDescendantElements(xml_element, fn): + fn(xml_element) + for node in xml_element.childNodes: + if node.nodeType == node.ELEMENT_NODE: + ApplyToAllDescendantElements(node, fn) + + root = dom.documentElement + ApplyToAllDescendantElements(root, FixMinidomXmlnsBug) + return root + + +def CloneXml(xml): + """Returns a deep copy of the given XML element. + + Args: + xml: The XML element, which should be something returned from + ParseXml() (i.e., a root element). + """ + return xml.ownerDocument.cloneNode(True).documentElement + + +class StanzaParser(object): + """A hacky incremental XML parser. + + StanzaParser consumes data incrementally via FeedString() and feeds + its delegate complete parsed stanzas (i.e., XML documents) via + FeedStanza(). Any stanzas passed to FeedStanza() are unlinked after + the callback is done. + + Use like so: + + class MyClass(object): + ... + def __init__(self, ...): + ... + self._parser = StanzaParser(self) + ... + + def SomeFunction(self, ...): + ... + self._parser.FeedString(some_data) + ... + + def FeedStanza(self, stanza): + ... + print stanza.toprettyxml() + ... + """ + + # NOTE(akalin): The following regexps are naive, but necessary since + # none of the existing Python 2.4/2.5 XML libraries support + # incremental parsing. This works well enough for our purposes. + # + # The regexps below assume that any present XML element starts at + # the beginning of the string, but there may be trailing whitespace. + + # Matches an opening stream tag (e.g., '<stream:stream foo="bar">') + # (assumes that the stream XML namespace is defined in the tag). + _stream_re = re.compile(r'^(<stream:stream [^>]*>)\s*') + + # Matches an empty element tag (e.g., '<foo bar="baz"/>'). + _empty_element_re = re.compile(r'^(<[^>]*/>)\s*') + + # Matches a non-empty element (e.g., '<foo bar="baz">quux</foo>'). + # Does *not* handle nested elements. + _non_empty_element_re = re.compile(r'^(<([^ >]*)[^>]*>.*?</\2>)\s*') + + # The closing tag for a stream tag. We have to insert this + # ourselves since all XML stanzas are children of the stream tag, + # which is never closed until the connection is closed. + _stream_suffix = '</stream:stream>' + + def __init__(self, delegate): + self._buffer = '' + self._delegate = delegate + + def FeedString(self, data): + """Consumes the given string data, possibly feeding one or more + stanzas to the delegate. + """ + self._buffer += data + while (self._ProcessBuffer(self._stream_re, self._stream_suffix) or + self._ProcessBuffer(self._empty_element_re) or + self._ProcessBuffer(self._non_empty_element_re)): + pass + + def _ProcessBuffer(self, regexp, xml_suffix=''): + """If the buffer matches the given regexp, removes the match from + the buffer, appends the given suffix, parses it, and feeds it to + the delegate. + + Returns: + Whether or not the buffer matched the given regexp. + """ + results = regexp.match(self._buffer) + if not results: + return False + xml_text = self._buffer[:results.end()] + xml_suffix + self._buffer = self._buffer[results.end():] + stanza = ParseXml(xml_text) + self._delegate.FeedStanza(stanza) + # Needed because stanza may have cycles. + stanza.unlink() + return True + + +class Jid(object): + """Simple struct for an XMPP jid (essentially an e-mail address with + an optional resource string). + """ + + def __init__(self, username, domain, resource=''): + self.username = username + self.domain = domain + self.resource = resource + + def __str__(self): + jid_str = "%s@%s" % (self.username, self.domain) + if self.resource: + jid_str += '/' + self.resource + return jid_str + + def GetBareJid(self): + return Jid(self.username, self.domain) + + +class IdGenerator(object): + """Simple class to generate unique IDs for XMPP messages.""" + + def __init__(self, prefix): + self._prefix = prefix + self._id = 0 + + def GetNextId(self): + next_id = "%s.%s" % (self._prefix, self._id) + self._id += 1 + return next_id + + +class HandshakeTask(object): + """Class to handle the initial handshake with a connected XMPP + client. + """ + + # The handshake states in order. + (_INITIAL_STREAM_NEEDED, + _AUTH_NEEDED, + _AUTH_STREAM_NEEDED, + _BIND_NEEDED, + _SESSION_NEEDED, + _FINISHED) = range(6) + + # Used when in the _INITIAL_STREAM_NEEDED and _AUTH_STREAM_NEEDED + # states. Not an XML object as it's only the opening tag. + # + # The from and id attributes are filled in later. + _STREAM_DATA = ( + '<stream:stream from="%s" id="%s" ' + 'version="1.0" xmlns:stream="http://etherx.jabber.org/streams" ' + 'xmlns="jabber:client">') + + # Used when in the _INITIAL_STREAM_NEEDED state. + _AUTH_STANZA = ParseXml( + '<stream:features xmlns:stream="http://etherx.jabber.org/streams">' + ' <mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl">' + ' <mechanism>PLAIN</mechanism>' + ' <mechanism>X-GOOGLE-TOKEN</mechanism>' + ' </mechanisms>' + '</stream:features>') + + # Used when in the _AUTH_NEEDED state. + _AUTH_SUCCESS_STANZA = ParseXml( + '<success xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>') + + # Used when in the _AUTH_STREAM_NEEDED state. + _BIND_STANZA = ParseXml( + '<stream:features xmlns:stream="http://etherx.jabber.org/streams">' + ' <bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"/>' + ' <session xmlns="urn:ietf:params:xml:ns:xmpp-session"/>' + '</stream:features>') + + # Used when in the _BIND_NEEDED state. + # + # The id and jid attributes are filled in later. + _BIND_RESULT_STANZA = ParseXml( + '<iq id="" type="result">' + ' <bind xmlns="urn:ietf:params:xml:ns:xmpp-bind">' + ' <jid/>' + ' </bind>' + '</iq>') + + # Used when in the _SESSION_NEEDED state. + # + # The id attribute is filled in later. + _IQ_RESPONSE_STANZA = ParseXml('<iq id="" type="result"/>') + + def __init__(self, connection, id_generator, resource_prefix): + self._connection = connection + self._id_generator = id_generator + self._username = '' + self._domain = '' + self._jid = None + self._resource_prefix = resource_prefix + self._state = self._INITIAL_STREAM_NEEDED + + def FeedStanza(self, stanza): + """Inspects the given stanza and changes the handshake state if needed. + + Called when a stanza is received from the client. Inspects the + stanza to make sure it has the expected attributes given the + current state, advances the state if needed, and sends a reply to + the client if needed. + """ + def ExpectStanza(stanza, name): + if stanza.tagName != name: + raise UnexpectedXml(stanza) + + def ExpectIq(stanza, type, name): + ExpectStanza(stanza, 'iq') + if (stanza.getAttribute('type') != type or + stanza.firstChild.tagName != name): + raise UnexpectedXml(stanza) + + def GetStanzaId(stanza): + return stanza.getAttribute('id') + + def HandleStream(stanza): + ExpectStanza(stanza, 'stream:stream') + domain = stanza.getAttribute('to') + if domain: + self._domain = domain + SendStreamData() + + def SendStreamData(): + next_id = self._id_generator.GetNextId() + stream_data = self._STREAM_DATA % (self._domain, next_id) + self._connection.SendData(stream_data) + + def GetUserDomain(stanza): + encoded_username_password = stanza.firstChild.data + username_password = base64.b64decode(encoded_username_password) + (_, username_domain, _) = username_password.split('\0') + # The domain may be omitted. + # + # If we were using python 2.5, we'd be able to do: + # + # username, _, domain = username_domain.partition('@') + # if not domain: + # domain = self._domain + at_pos = username_domain.find('@') + if at_pos != -1: + username = username_domain[:at_pos] + domain = username_domain[at_pos+1:] + else: + username = username_domain + domain = self._domain + return (username, domain) + + if self._state == self._INITIAL_STREAM_NEEDED: + HandleStream(stanza) + self._connection.SendStanza(self._AUTH_STANZA, False) + self._state = self._AUTH_NEEDED + + elif self._state == self._AUTH_NEEDED: + ExpectStanza(stanza, 'auth') + (self._username, self._domain) = GetUserDomain(stanza) + self._connection.SendStanza(self._AUTH_SUCCESS_STANZA, False) + self._state = self._AUTH_STREAM_NEEDED + + elif self._state == self._AUTH_STREAM_NEEDED: + HandleStream(stanza) + self._connection.SendStanza(self._BIND_STANZA, False) + self._state = self._BIND_NEEDED + + elif self._state == self._BIND_NEEDED: + ExpectIq(stanza, 'set', 'bind') + stanza_id = GetStanzaId(stanza) + resource_element = stanza.getElementsByTagName('resource')[0] + resource = resource_element.firstChild.data + full_resource = '%s.%s' % (self._resource_prefix, resource) + response = CloneXml(self._BIND_RESULT_STANZA) + response.setAttribute('id', stanza_id) + self._jid = Jid(self._username, self._domain, full_resource) + jid_text = response.parentNode.createTextNode(str(self._jid)) + response.getElementsByTagName('jid')[0].appendChild(jid_text) + self._connection.SendStanza(response) + self._state = self._SESSION_NEEDED + + elif self._state == self._SESSION_NEEDED: + ExpectIq(stanza, 'set', 'session') + stanza_id = GetStanzaId(stanza) + xml = CloneXml(self._IQ_RESPONSE_STANZA) + xml.setAttribute('id', stanza_id) + self._connection.SendStanza(xml) + self._state = self._FINISHED + self._connection.HandshakeDone(self._jid) + + +def AddrString(addr): + return '%s:%d' % addr + + +class XmppConnection(asynchat.async_chat): + """A single XMPP client connection. + + This class handles the connection to a single XMPP client (via a + socket). It does the XMPP handshake and also implements the (old) + Google notification protocol. + """ + + # We use this XML template for subscription responses as well as + # notifications (conveniently enough, the same template works + # for both). + # + # The from, to, id, and type attributes are filled in later. + _NOTIFIER_STANZA = ParseXml( + """<iq from="" to="" id="" type=""> + <not:getAll xmlns:not="google:notifier"> + <Result xmlns=""/> + </not:getAll> + </iq> + """) + + def __init__(self, sock, socket_map, connections, addr): + """Starts up the xmpp connection. + + Args: + sock: The socket to the client. + socket_map: A map from sockets to their owning objects. + connections: The set of handshake-completed connections. + addr: The host/port of the client. + """ + asynchat.async_chat.__init__(self, sock) + self.set_terminator(None) + # async_chat in Python 2.4 has a bug where it ignores a + # socket_map argument. So we handle that ourselves. + self._socket_map = socket_map + self._socket_map[self.fileno()] = self + + self._connections = connections + self._parser = StanzaParser(self) + self._jid = None + + self._addr = addr + addr_str = AddrString(self._addr) + self._id_generator = IdGenerator(addr_str) + self._handshake_task = ( + HandshakeTask(self, self._id_generator, addr_str)) + print 'Starting connection to %s' % self + + def __str__(self): + if self._jid: + return str(self._jid) + else: + return AddrString(self._addr) + + # async_chat implementation. + + def collect_incoming_data(self, data): + self._parser.FeedString(data) + + # This is only here to make pychecker happy. + def found_terminator(self): + asynchat.async_chat.found_terminator(self) + + def handle_close(self): + print "Closing connection to %s" % self + # Remove ourselves from anywhere we possibly installed ourselves. + self._connections.discard(self) + del self._socket_map[self.fileno()] + + # Called by self._parser.FeedString(). + def FeedStanza(self, stanza): + if self._handshake_task: + self._handshake_task.FeedStanza(stanza) + elif stanza.tagName == 'iq': + self._HandleIq(stanza) + else: + raise UnexpectedXml(stanza) + + # Called by self._handshake_task. + def HandshakeDone(self, jid): + self._jid = jid + self._handshake_task = None + self._connections.add(self) + print "Handshake done for %s" % self + + def _HandleIq(self, iq): + if (iq.firstChild and + iq.firstChild.namespaceURI == 'google:notifier'): + iq_id = iq.getAttribute('id') + self._HandleNotifierCommand(iq_id, iq.firstChild) + elif iq.getAttribute('type') == 'result': + # Ignore all client acks. + pass + else: + raise UnexpectedXml(iq) + + def _HandleNotifierCommand(self, id, command_xml): + command = command_xml.tagName + if command == 'getAll': + # Subscription request. + if not command_xml.getElementsByTagName('SubscribedServiceUrl'): + raise UnexpectedXml(command_xml) + self._SendNotifierStanza(id, 'result') + elif command == 'set': + # Send notification request. + SendNotification(self._connections) + else: + raise UnexpectedXml(command_xml) + + def _SendNotifierStanza(self, id, type): + stanza = CloneXml(self._NOTIFIER_STANZA) + stanza.setAttribute('from', str(self._jid.GetBareJid())) + stanza.setAttribute('to', str(self._jid)) + stanza.setAttribute('id', id) + stanza.setAttribute('type', type) + self.SendStanza(stanza) + + def SendStanza(self, stanza, unlink=True): + """Sends a stanza to the client. + + Args: + stanza: The stanza to send. + unlink: Whether to unlink stanza after sending it. (Pass in + False if stanza is a constant.) + """ + self.SendData(stanza.toxml()) + if unlink: + stanza.unlink() + + def SendData(self, data): + """Sends raw data to the client. + """ + # We explicitly encode to ascii as that is what the client expects + # (some minidom library functions return unicode strings). + self.push(data.encode('ascii')) + + def SendNotification(self): + """Sends a notification to the client.""" + next_id = self._id_generator.GetNextId() + self._SendNotifierStanza(next_id, 'set') + + +def SendNotification(connections): + """Sends a notification to all connections in the given sequence.""" + for connection in connections: + print 'Sending notification to %s' % connection + connection.SendNotification() + + +class XmppServer(asyncore.dispatcher): + """The main XMPP server class. + + The XMPP server starts accepting connections on the given address + and spawns off XmppConnection objects for each one. + + Use like so: + + socket_map = {} + xmpp_server = xmppserver.XmppServer(socket_map, ('127.0.0.1', 5222)) + asyncore.loop(30.0, False, socket_map) + """ + + def __init__(self, socket_map, addr): + asyncore.dispatcher.__init__(self, None, socket_map) + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.set_reuse_addr() + self.bind(addr) + self.listen(5) + self._socket_map = socket_map + self._socket_map[self.fileno()] = self + self._connections = set() + print 'XMPP server running at %s' % AddrString(addr) + + def handle_accept(self): + (sock, addr) = self.accept() + XmppConnection(sock, self._socket_map, self._connections, addr) diff --git a/net/tools/testserver/xmppserver_test.py b/net/tools/testserver/xmppserver_test.py new file mode 100644 index 0000000..e033a69 --- /dev/null +++ b/net/tools/testserver/xmppserver_test.py @@ -0,0 +1,250 @@ +#!/usr/bin/python2.4 +# Copyright (c) 2010 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +"""Tests exercising the various classes in xmppserver.py.""" + +import unittest + +import base64 +import xmppserver + +class XmlUtilsTest(unittest.TestCase): + + def testParseXml(self): + xml_text = """<foo xmlns=""><bar xmlns=""><baz/></bar></foo>""" + xml = xmppserver.ParseXml(xml_text) + self.assertEqual(xml.toxml(), xml_text) + + def testCloneXml(self): + xml = xmppserver.ParseXml('<foo/>') + xml_clone = xmppserver.CloneXml(xml) + xml_clone.setAttribute('bar', 'baz') + self.assertEqual(xml, xml) + self.assertEqual(xml_clone, xml_clone) + self.assertNotEqual(xml, xml_clone) + + def testCloneXmlUnlink(self): + xml_text = '<foo/>' + xml = xmppserver.ParseXml(xml_text) + xml_clone = xmppserver.CloneXml(xml) + xml.unlink() + self.assertEqual(xml.parentNode, None) + self.assertNotEqual(xml_clone.parentNode, None) + self.assertEqual(xml_clone.toxml(), xml_text) + +class StanzaParserTest(unittest.TestCase): + + def setUp(self): + self.stanzas = [] + + def FeedStanza(self, stanza): + # We can't append stanza directly because it is unlinked after + # this callback. + self.stanzas.append(stanza.toxml()) + + def testBasic(self): + parser = xmppserver.StanzaParser(self) + parser.FeedString('<foo') + self.assertEqual(len(self.stanzas), 0) + parser.FeedString('/><bar></bar>') + self.assertEqual(self.stanzas[0], '<foo/>') + self.assertEqual(self.stanzas[1], '<bar/>') + + def testStream(self): + parser = xmppserver.StanzaParser(self) + parser.FeedString('<stream') + self.assertEqual(len(self.stanzas), 0) + parser.FeedString(':stream foo="bar" xmlns:stream="baz">') + self.assertEqual(self.stanzas[0], + '<stream:stream foo="bar" xmlns:stream="baz"/>') + + def testNested(self): + parser = xmppserver.StanzaParser(self) + parser.FeedString('<foo') + self.assertEqual(len(self.stanzas), 0) + parser.FeedString(' bar="baz"') + parser.FeedString('><baz/><blah>meh</blah></foo>') + self.assertEqual(self.stanzas[0], + '<foo bar="baz"><baz/><blah>meh</blah></foo>') + + +class JidTest(unittest.TestCase): + + def testBasic(self): + jid = xmppserver.Jid('foo', 'bar.com') + self.assertEqual(str(jid), 'foo@bar.com') + + def testResource(self): + jid = xmppserver.Jid('foo', 'bar.com', 'resource') + self.assertEqual(str(jid), 'foo@bar.com/resource') + + def testGetBareJid(self): + jid = xmppserver.Jid('foo', 'bar.com', 'resource') + self.assertEqual(str(jid.GetBareJid()), 'foo@bar.com') + + +class IdGeneratorTest(unittest.TestCase): + + def testBasic(self): + id_generator = xmppserver.IdGenerator('foo') + for i in xrange(0, 100): + self.assertEqual('foo.%d' % i, id_generator.GetNextId()) + + +class HandshakeTaskTest(unittest.TestCase): + + def setUp(self): + self.data_received = 0 + + def SendData(self, _): + self.data_received += 1 + + def SendStanza(self, _, unused=True): + self.data_received += 1 + + def HandshakeDone(self, jid): + self.jid = jid + + def DoHandshake(self, resource_prefix, resource, username, + initial_stream_domain, auth_domain, auth_stream_domain): + self.data_received = 0 + id_generator = xmppserver.IdGenerator('foo') + handshake_task = ( + xmppserver.HandshakeTask(self, id_generator, resource_prefix)) + stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>') + stream_xml.setAttribute('to', initial_stream_domain) + self.assertEqual(self.data_received, 0) + handshake_task.FeedStanza(stream_xml) + self.assertEqual(self.data_received, 2) + + if auth_domain: + username_domain = '%s@%s' % (username, auth_domain) + else: + username_domain = username + auth_string = base64.b64encode('\0%s\0bar' % username_domain) + auth_xml = xmppserver.ParseXml('<auth>%s</auth>'% auth_string) + handshake_task.FeedStanza(auth_xml) + self.assertEqual(self.data_received, 3) + + stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>') + stream_xml.setAttribute('to', auth_stream_domain) + handshake_task.FeedStanza(stream_xml) + self.assertEqual(self.data_received, 5) + + bind_xml = xmppserver.ParseXml( + '<iq type="set"><bind><resource>%s</resource></bind></iq>' % resource) + handshake_task.FeedStanza(bind_xml) + self.assertEqual(self.data_received, 6) + + session_xml = xmppserver.ParseXml( + '<iq type="set"><session></session></iq>') + handshake_task.FeedStanza(session_xml) + self.assertEqual(self.data_received, 7) + + self.assertEqual(self.jid.username, username) + self.assertEqual(self.jid.domain, + auth_stream_domain or auth_domain or + initial_stream_domain) + self.assertEqual(self.jid.resource, + '%s.%s' % (resource_prefix, resource)) + + def testBasic(self): + self.DoHandshake('resource_prefix', 'resource', + 'foo', 'bar.com', 'baz.com', 'quux.com') + + def testDomainBehavior(self): + self.DoHandshake('resource_prefix', 'resource', + 'foo', 'bar.com', 'baz.com', 'quux.com') + self.DoHandshake('resource_prefix', 'resource', + 'foo', 'bar.com', 'baz.com', '') + self.DoHandshake('resource_prefix', 'resource', + 'foo', 'bar.com', '', '') + self.DoHandshake('resource_prefix', 'resource', + 'foo', '', '', '') + + +class XmppConnectionTest(unittest.TestCase): + + def setUp(self): + self.data = [] + + # socket-like methods. + def fileno(self): + return 0 + + def setblocking(self, int): + pass + + def getpeername(self): + return ('', 0) + + def send(self, data): + self.data.append(data) + pass + + def testBasic(self): + connections = set() + xmpp_connection = xmppserver.XmppConnection( + self, {}, connections, ('', 0)) + self.assertEqual(len(connections), 0) + xmpp_connection.HandshakeDone(xmppserver.Jid('foo', 'bar')) + self.assertEqual(len(connections), 1) + + # Test subscription request. + self.assertEqual(len(self.data), 0) + xmpp_connection.collect_incoming_data( + '<iq><getAll xmlns="google:notifier">' + '<SubscribedServiceUrl/></getAll></iq>') + self.assertEqual(len(self.data), 1) + + # Test acks. + xmpp_connection.collect_incoming_data('<iq type="result"/>') + self.assertEqual(len(self.data), 1) + + # Test notification. + xmpp_connection.collect_incoming_data( + '<iq><set xmlns="google:notifier"/></iq>') + self.assertEqual(len(self.data), 2) + + # Test unexpected stanza. + def SendUnexpectedStanza(): + xmpp_connection.collect_incoming_data('<foo/>') + self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedStanza) + + # Test unexpected notifier command. + def SendUnexpectedNotifierCommand(): + xmpp_connection.collect_incoming_data( + '<iq><foo xmlns="google:notifier"/></iq>') + self.assertRaises(xmppserver.UnexpectedXml, + SendUnexpectedNotifierCommand) + + +class XmppServerTest(unittest.TestCase): + + # socket-like methods. + def fileno(self): + return 0 + + def setblocking(self, int): + pass + + def getpeername(self): + return ('', 0) + + def testBasic(self): + class FakeXmppServer(xmppserver.XmppServer): + def accept(self2): + return (self, ('', 0)) + + socket_map = {} + self.assertEqual(len(socket_map), 0) + xmpp_server = FakeXmppServer(socket_map, ('', 0)) + self.assertEqual(len(socket_map), 1) + xmpp_server.handle_accept() + self.assertEqual(len(socket_map), 2) + + +if __name__ == '__main__': + unittest.main() |
