diff options
Diffstat (limited to 'net')
-rwxr-xr-x | net/tools/testserver/testserver.py | 25 | ||||
-rw-r--r-- | net/tools/testserver/xmppserver.py | 48 | ||||
-rwxr-xr-x | net/tools/testserver/xmppserver_test.py | 70 |
3 files changed, 127 insertions, 16 deletions
diff --git a/net/tools/testserver/testserver.py b/net/tools/testserver/testserver.py index 4f97587..eefc448 100755 --- a/net/tools/testserver/testserver.py +++ b/net/tools/testserver/testserver.py @@ -1788,6 +1788,7 @@ class SyncPageHandler(BasePageHandler): get_handlers = [self.ChromiumSyncTimeHandler, self.ChromiumSyncMigrationOpHandler, self.ChromiumSyncCredHandler, + self.ChromiumSyncXmppCredHandler, self.ChromiumSyncDisableNotificationsOpHandler, self.ChromiumSyncEnableNotificationsOpHandler, self.ChromiumSyncSendNotificationOpHandler, @@ -1894,6 +1895,30 @@ class SyncPageHandler(BasePageHandler): self.wfile.write(raw_reply) return True + def ChromiumSyncXmppCredHandler(self): + test_name = "/chromiumsync/xmppcred" + if not self._ShouldHandleRequest(test_name): + return False + xmpp_server = self.server.GetXmppServer() + try: + query = urlparse.urlparse(self.path)[4] + cred_valid = urlparse.parse_qs(query)['valid'] + if cred_valid[0] == 'True': + xmpp_server.SetAuthenticated(True) + else: + xmpp_server.SetAuthenticated(False) + except: + xmpp_server.SetAuthenticated(False) + + http_response = 200 + raw_reply = 'XMPP Authenticated: %s ' % xmpp_server.GetAuthenticated() + self.send_response(http_response) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + def ChromiumSyncDisableNotificationsOpHandler(self): test_name = "/chromiumsync/disablenotifications" if not self._ShouldHandleRequest(test_name): diff --git a/net/tools/testserver/xmppserver.py b/net/tools/testserver/xmppserver.py index 6952a99..996da96 100644 --- a/net/tools/testserver/xmppserver.py +++ b/net/tools/testserver/xmppserver.py @@ -220,6 +220,10 @@ class HandshakeTask(object): _AUTH_SUCCESS_STANZA = ParseXml( '<success xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>') + # Used when in the _AUTH_NEEDED state. + _AUTH_FAILURE_STANZA = ParseXml( + '<failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>') + # Used when in the _AUTH_STREAM_NEEDED state. _BIND_STANZA = ParseXml( '<stream:features xmlns:stream="http://etherx.jabber.org/streams">' @@ -242,12 +246,13 @@ class HandshakeTask(object): # The id attribute is filled in later. _IQ_RESPONSE_STANZA = ParseXml('<iq id="" type="result"/>') - def __init__(self, connection, resource_prefix): + def __init__(self, connection, resource_prefix, authenticated): self._connection = connection self._id_generator = IdGenerator(resource_prefix) self._username = '' self._domain = '' self._jid = None + self._authenticated = authenticated self._resource_prefix = resource_prefix self._state = self._INITIAL_STREAM_NEEDED @@ -304,6 +309,10 @@ class HandshakeTask(object): domain = self._domain return (username, domain) + def Finish(): + self._state = self._FINISHED + self._connection.HandshakeDone(self._jid) + if self._state == self._INITIAL_STREAM_NEEDED: HandleStream(stanza) self._connection.SendStanza(self._AUTH_STANZA, False) @@ -312,8 +321,12 @@ class HandshakeTask(object): elif self._state == self._AUTH_NEEDED: ExpectStanza(stanza, 'auth') (self._username, self._domain) = GetUserDomain(stanza) - self._connection.SendStanza(self._AUTH_SUCCESS_STANZA, False) - self._state = self._AUTH_STREAM_NEEDED + if self._authenticated: + self._connection.SendStanza(self._AUTH_SUCCESS_STANZA, False) + self._state = self._AUTH_STREAM_NEEDED + else: + self._connection.SendStanza(self._AUTH_FAILURE_STANZA, False) + Finish() elif self._state == self._AUTH_STREAM_NEEDED: HandleStream(stanza) @@ -340,8 +353,7 @@ class HandshakeTask(object): xml = CloneXml(self._IQ_RESPONSE_STANZA) xml.setAttribute('id', stanza_id) self._connection.SendStanza(xml) - self._state = self._FINISHED - self._connection.HandshakeDone(self._jid) + Finish() def AddrString(addr): @@ -361,7 +373,7 @@ class XmppConnection(asynchat.async_chat): # The from and id attributes are filled in later. _IQ_RESPONSE_STANZA = ParseXml('<iq from="" id="" type="result"/>') - def __init__(self, sock, socket_map, delegate, addr): + def __init__(self, sock, socket_map, delegate, addr, authenticated): """Starts up the xmpp connection. Args: @@ -389,7 +401,7 @@ class XmppConnection(asynchat.async_chat): self._addr = addr addr_str = AddrString(self._addr) - self._handshake_task = HandshakeTask(self, addr_str) + self._handshake_task = HandshakeTask(self, addr_str, authenticated) print 'Starting connection to %s' % self def __str__(self): @@ -427,10 +439,14 @@ class XmppConnection(asynchat.async_chat): # Called by self._handshake_task. def HandshakeDone(self, jid): - self._jid = jid - self._handshake_task = None - self._delegate.OnXmppHandshakeDone(self) - print "Handshake done for %s" % self + if jid: + self._jid = jid + self._handshake_task = None + self._delegate.OnXmppHandshakeDone(self) + print "Handshake done for %s" % self + else: + print "Handshake failed for %s" % self + self.close() def _HandlePushCommand(self, stanza): if stanza.tagName == 'iq' and stanza.firstChild.tagName == 'subscribe': @@ -505,10 +521,12 @@ class XmppServer(asyncore.dispatcher): self._connections = set() self._handshake_done_connections = set() self._notifications_enabled = True + self._authenticated = True def handle_accept(self): (sock, addr) = self.accept() - xmpp_connection = XmppConnection(sock, self._socket_map, self, addr) + xmpp_connection = XmppConnection( + sock, self._socket_map, self, addr, self._authenticated) self._connections.add(xmpp_connection) # Return the new XmppConnection for testing. return xmpp_connection @@ -553,6 +571,12 @@ class XmppServer(asyncore.dispatcher): self.ForwardNotification(None, notification_stanza) notification_stanza.unlink() + def SetAuthenticated(self, auth_valid): + self._authenticated = auth_valid + + def GetAuthenticated(self): + return self._authenticated + # XmppConnection delegate methods. def OnXmppHandshakeDone(self, xmpp_connection): self._handshake_done_connections.add(xmpp_connection) diff --git a/net/tools/testserver/xmppserver_test.py b/net/tools/testserver/xmppserver_test.py index b00885a..7431c84 100755 --- a/net/tools/testserver/xmppserver_test.py +++ b/net/tools/testserver/xmppserver_test.py @@ -96,7 +96,12 @@ class IdGeneratorTest(unittest.TestCase): class HandshakeTaskTest(unittest.TestCase): def setUp(self): + self.Reset() + + def Reset(self): self.data_received = 0 + self.handshake_done = False + self.jid = None def SendData(self, _): self.data_received += 1 @@ -105,13 +110,14 @@ class HandshakeTaskTest(unittest.TestCase): self.data_received += 1 def HandshakeDone(self, jid): + self.handshake_done = True self.jid = jid def DoHandshake(self, resource_prefix, resource, username, initial_stream_domain, auth_domain, auth_stream_domain): - self.data_received = 0 + self.Reset() handshake_task = ( - xmppserver.HandshakeTask(self, resource_prefix)) + xmppserver.HandshakeTask(self, resource_prefix, True)) stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>') stream_xml.setAttribute('to', initial_stream_domain) self.assertEqual(self.data_received, 0) @@ -137,11 +143,15 @@ class HandshakeTaskTest(unittest.TestCase): handshake_task.FeedStanza(bind_xml) self.assertEqual(self.data_received, 6) + self.assertFalse(self.handshake_done) + session_xml = xmppserver.ParseXml( '<iq type="set"><session></session></iq>') handshake_task.FeedStanza(session_xml) self.assertEqual(self.data_received, 7) + self.assertTrue(self.handshake_done) + self.assertEqual(self.jid.username, username) self.assertEqual(self.jid.domain, auth_stream_domain or auth_domain or @@ -149,6 +159,34 @@ class HandshakeTaskTest(unittest.TestCase): self.assertEqual(self.jid.resource, '%s.%s' % (resource_prefix, resource)) + handshake_task.FeedStanza('<ignored/>') + self.assertEqual(self.data_received, 7) + + def DoHandshakeUnauthenticated(self, resource_prefix, resource, username, + initial_stream_domain): + self.Reset() + handshake_task = ( + xmppserver.HandshakeTask(self, resource_prefix, False)) + stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>') + stream_xml.setAttribute('to', initial_stream_domain) + self.assertEqual(self.data_received, 0) + handshake_task.FeedStanza(stream_xml) + self.assertEqual(self.data_received, 2) + + self.assertFalse(self.handshake_done) + + auth_string = base64.b64encode('\0%s\0bar' % username) + auth_xml = xmppserver.ParseXml('<auth>%s</auth>'% auth_string) + handshake_task.FeedStanza(auth_xml) + self.assertEqual(self.data_received, 3) + + self.assertTrue(self.handshake_done) + + self.assertEqual(self.jid, None) + + handshake_task.FeedStanza('<ignored/>') + self.assertEqual(self.data_received, 3) + def testBasic(self): self.DoHandshake('resource_prefix', 'resource', 'foo', 'bar.com', 'baz.com', 'quux.com') @@ -163,6 +201,10 @@ class HandshakeTaskTest(unittest.TestCase): self.DoHandshake('resource_prefix', 'resource', 'foo', '', '', '') + def testBasicUnauthenticated(self): + self.DoHandshakeUnauthenticated('resource_prefix', 'resource', + 'foo', 'bar.com') + class FakeSocket(object): """A fake socket object used for testing. @@ -212,7 +254,7 @@ class XmppConnectionTest(unittest.TestCase): def testBasic(self): socket_map = {} xmpp_connection = xmppserver.XmppConnection( - self.fake_socket, socket_map, self, ('', 0)) + self.fake_socket, socket_map, self, ('', 0), True) self.assertEqual(len(socket_map), 1) self.assertEqual(len(self.connections), 0) xmpp_connection.HandshakeDone(xmppserver.Jid('foo', 'bar')) @@ -248,7 +290,27 @@ class XmppConnectionTest(unittest.TestCase): self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedNotifierCommand) - # Test close + # Test close. + xmpp_connection.close() + self.assertEqual(len(socket_map), 0) + self.assertEqual(len(self.connections), 0) + + def testBasicUnauthenticated(self): + socket_map = {} + xmpp_connection = xmppserver.XmppConnection( + self.fake_socket, socket_map, self, ('', 0), False) + self.assertEqual(len(socket_map), 1) + self.assertEqual(len(self.connections), 0) + xmpp_connection.HandshakeDone(None) + self.assertEqual(len(socket_map), 0) + self.assertEqual(len(self.connections), 0) + + # Test unexpected stanza. + def SendUnexpectedStanza(): + xmpp_connection.collect_incoming_data('<foo/>') + self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedStanza) + + # Test redundant close. xmpp_connection.close() self.assertEqual(len(socket_map), 0) self.assertEqual(len(self.connections), 0) |