summaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rwxr-xr-xnet/tools/testserver/testserver.py25
-rw-r--r--net/tools/testserver/xmppserver.py48
-rwxr-xr-xnet/tools/testserver/xmppserver_test.py70
3 files changed, 127 insertions, 16 deletions
diff --git a/net/tools/testserver/testserver.py b/net/tools/testserver/testserver.py
index 4f97587..eefc448 100755
--- a/net/tools/testserver/testserver.py
+++ b/net/tools/testserver/testserver.py
@@ -1788,6 +1788,7 @@ class SyncPageHandler(BasePageHandler):
get_handlers = [self.ChromiumSyncTimeHandler,
self.ChromiumSyncMigrationOpHandler,
self.ChromiumSyncCredHandler,
+ self.ChromiumSyncXmppCredHandler,
self.ChromiumSyncDisableNotificationsOpHandler,
self.ChromiumSyncEnableNotificationsOpHandler,
self.ChromiumSyncSendNotificationOpHandler,
@@ -1894,6 +1895,30 @@ class SyncPageHandler(BasePageHandler):
self.wfile.write(raw_reply)
return True
+ def ChromiumSyncXmppCredHandler(self):
+ test_name = "/chromiumsync/xmppcred"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+ xmpp_server = self.server.GetXmppServer()
+ try:
+ query = urlparse.urlparse(self.path)[4]
+ cred_valid = urlparse.parse_qs(query)['valid']
+ if cred_valid[0] == 'True':
+ xmpp_server.SetAuthenticated(True)
+ else:
+ xmpp_server.SetAuthenticated(False)
+ except:
+ xmpp_server.SetAuthenticated(False)
+
+ http_response = 200
+ raw_reply = 'XMPP Authenticated: %s ' % xmpp_server.GetAuthenticated()
+ self.send_response(http_response)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
def ChromiumSyncDisableNotificationsOpHandler(self):
test_name = "/chromiumsync/disablenotifications"
if not self._ShouldHandleRequest(test_name):
diff --git a/net/tools/testserver/xmppserver.py b/net/tools/testserver/xmppserver.py
index 6952a99..996da96 100644
--- a/net/tools/testserver/xmppserver.py
+++ b/net/tools/testserver/xmppserver.py
@@ -220,6 +220,10 @@ class HandshakeTask(object):
_AUTH_SUCCESS_STANZA = ParseXml(
'<success xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>')
+ # Used when in the _AUTH_NEEDED state.
+ _AUTH_FAILURE_STANZA = ParseXml(
+ '<failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>')
+
# Used when in the _AUTH_STREAM_NEEDED state.
_BIND_STANZA = ParseXml(
'<stream:features xmlns:stream="http://etherx.jabber.org/streams">'
@@ -242,12 +246,13 @@ class HandshakeTask(object):
# The id attribute is filled in later.
_IQ_RESPONSE_STANZA = ParseXml('<iq id="" type="result"/>')
- def __init__(self, connection, resource_prefix):
+ def __init__(self, connection, resource_prefix, authenticated):
self._connection = connection
self._id_generator = IdGenerator(resource_prefix)
self._username = ''
self._domain = ''
self._jid = None
+ self._authenticated = authenticated
self._resource_prefix = resource_prefix
self._state = self._INITIAL_STREAM_NEEDED
@@ -304,6 +309,10 @@ class HandshakeTask(object):
domain = self._domain
return (username, domain)
+ def Finish():
+ self._state = self._FINISHED
+ self._connection.HandshakeDone(self._jid)
+
if self._state == self._INITIAL_STREAM_NEEDED:
HandleStream(stanza)
self._connection.SendStanza(self._AUTH_STANZA, False)
@@ -312,8 +321,12 @@ class HandshakeTask(object):
elif self._state == self._AUTH_NEEDED:
ExpectStanza(stanza, 'auth')
(self._username, self._domain) = GetUserDomain(stanza)
- self._connection.SendStanza(self._AUTH_SUCCESS_STANZA, False)
- self._state = self._AUTH_STREAM_NEEDED
+ if self._authenticated:
+ self._connection.SendStanza(self._AUTH_SUCCESS_STANZA, False)
+ self._state = self._AUTH_STREAM_NEEDED
+ else:
+ self._connection.SendStanza(self._AUTH_FAILURE_STANZA, False)
+ Finish()
elif self._state == self._AUTH_STREAM_NEEDED:
HandleStream(stanza)
@@ -340,8 +353,7 @@ class HandshakeTask(object):
xml = CloneXml(self._IQ_RESPONSE_STANZA)
xml.setAttribute('id', stanza_id)
self._connection.SendStanza(xml)
- self._state = self._FINISHED
- self._connection.HandshakeDone(self._jid)
+ Finish()
def AddrString(addr):
@@ -361,7 +373,7 @@ class XmppConnection(asynchat.async_chat):
# The from and id attributes are filled in later.
_IQ_RESPONSE_STANZA = ParseXml('<iq from="" id="" type="result"/>')
- def __init__(self, sock, socket_map, delegate, addr):
+ def __init__(self, sock, socket_map, delegate, addr, authenticated):
"""Starts up the xmpp connection.
Args:
@@ -389,7 +401,7 @@ class XmppConnection(asynchat.async_chat):
self._addr = addr
addr_str = AddrString(self._addr)
- self._handshake_task = HandshakeTask(self, addr_str)
+ self._handshake_task = HandshakeTask(self, addr_str, authenticated)
print 'Starting connection to %s' % self
def __str__(self):
@@ -427,10 +439,14 @@ class XmppConnection(asynchat.async_chat):
# Called by self._handshake_task.
def HandshakeDone(self, jid):
- self._jid = jid
- self._handshake_task = None
- self._delegate.OnXmppHandshakeDone(self)
- print "Handshake done for %s" % self
+ if jid:
+ self._jid = jid
+ self._handshake_task = None
+ self._delegate.OnXmppHandshakeDone(self)
+ print "Handshake done for %s" % self
+ else:
+ print "Handshake failed for %s" % self
+ self.close()
def _HandlePushCommand(self, stanza):
if stanza.tagName == 'iq' and stanza.firstChild.tagName == 'subscribe':
@@ -505,10 +521,12 @@ class XmppServer(asyncore.dispatcher):
self._connections = set()
self._handshake_done_connections = set()
self._notifications_enabled = True
+ self._authenticated = True
def handle_accept(self):
(sock, addr) = self.accept()
- xmpp_connection = XmppConnection(sock, self._socket_map, self, addr)
+ xmpp_connection = XmppConnection(
+ sock, self._socket_map, self, addr, self._authenticated)
self._connections.add(xmpp_connection)
# Return the new XmppConnection for testing.
return xmpp_connection
@@ -553,6 +571,12 @@ class XmppServer(asyncore.dispatcher):
self.ForwardNotification(None, notification_stanza)
notification_stanza.unlink()
+ def SetAuthenticated(self, auth_valid):
+ self._authenticated = auth_valid
+
+ def GetAuthenticated(self):
+ return self._authenticated
+
# XmppConnection delegate methods.
def OnXmppHandshakeDone(self, xmpp_connection):
self._handshake_done_connections.add(xmpp_connection)
diff --git a/net/tools/testserver/xmppserver_test.py b/net/tools/testserver/xmppserver_test.py
index b00885a..7431c84 100755
--- a/net/tools/testserver/xmppserver_test.py
+++ b/net/tools/testserver/xmppserver_test.py
@@ -96,7 +96,12 @@ class IdGeneratorTest(unittest.TestCase):
class HandshakeTaskTest(unittest.TestCase):
def setUp(self):
+ self.Reset()
+
+ def Reset(self):
self.data_received = 0
+ self.handshake_done = False
+ self.jid = None
def SendData(self, _):
self.data_received += 1
@@ -105,13 +110,14 @@ class HandshakeTaskTest(unittest.TestCase):
self.data_received += 1
def HandshakeDone(self, jid):
+ self.handshake_done = True
self.jid = jid
def DoHandshake(self, resource_prefix, resource, username,
initial_stream_domain, auth_domain, auth_stream_domain):
- self.data_received = 0
+ self.Reset()
handshake_task = (
- xmppserver.HandshakeTask(self, resource_prefix))
+ xmppserver.HandshakeTask(self, resource_prefix, True))
stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
stream_xml.setAttribute('to', initial_stream_domain)
self.assertEqual(self.data_received, 0)
@@ -137,11 +143,15 @@ class HandshakeTaskTest(unittest.TestCase):
handshake_task.FeedStanza(bind_xml)
self.assertEqual(self.data_received, 6)
+ self.assertFalse(self.handshake_done)
+
session_xml = xmppserver.ParseXml(
'<iq type="set"><session></session></iq>')
handshake_task.FeedStanza(session_xml)
self.assertEqual(self.data_received, 7)
+ self.assertTrue(self.handshake_done)
+
self.assertEqual(self.jid.username, username)
self.assertEqual(self.jid.domain,
auth_stream_domain or auth_domain or
@@ -149,6 +159,34 @@ class HandshakeTaskTest(unittest.TestCase):
self.assertEqual(self.jid.resource,
'%s.%s' % (resource_prefix, resource))
+ handshake_task.FeedStanza('<ignored/>')
+ self.assertEqual(self.data_received, 7)
+
+ def DoHandshakeUnauthenticated(self, resource_prefix, resource, username,
+ initial_stream_domain):
+ self.Reset()
+ handshake_task = (
+ xmppserver.HandshakeTask(self, resource_prefix, False))
+ stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
+ stream_xml.setAttribute('to', initial_stream_domain)
+ self.assertEqual(self.data_received, 0)
+ handshake_task.FeedStanza(stream_xml)
+ self.assertEqual(self.data_received, 2)
+
+ self.assertFalse(self.handshake_done)
+
+ auth_string = base64.b64encode('\0%s\0bar' % username)
+ auth_xml = xmppserver.ParseXml('<auth>%s</auth>'% auth_string)
+ handshake_task.FeedStanza(auth_xml)
+ self.assertEqual(self.data_received, 3)
+
+ self.assertTrue(self.handshake_done)
+
+ self.assertEqual(self.jid, None)
+
+ handshake_task.FeedStanza('<ignored/>')
+ self.assertEqual(self.data_received, 3)
+
def testBasic(self):
self.DoHandshake('resource_prefix', 'resource',
'foo', 'bar.com', 'baz.com', 'quux.com')
@@ -163,6 +201,10 @@ class HandshakeTaskTest(unittest.TestCase):
self.DoHandshake('resource_prefix', 'resource',
'foo', '', '', '')
+ def testBasicUnauthenticated(self):
+ self.DoHandshakeUnauthenticated('resource_prefix', 'resource',
+ 'foo', 'bar.com')
+
class FakeSocket(object):
"""A fake socket object used for testing.
@@ -212,7 +254,7 @@ class XmppConnectionTest(unittest.TestCase):
def testBasic(self):
socket_map = {}
xmpp_connection = xmppserver.XmppConnection(
- self.fake_socket, socket_map, self, ('', 0))
+ self.fake_socket, socket_map, self, ('', 0), True)
self.assertEqual(len(socket_map), 1)
self.assertEqual(len(self.connections), 0)
xmpp_connection.HandshakeDone(xmppserver.Jid('foo', 'bar'))
@@ -248,7 +290,27 @@ class XmppConnectionTest(unittest.TestCase):
self.assertRaises(xmppserver.UnexpectedXml,
SendUnexpectedNotifierCommand)
- # Test close
+ # Test close.
+ xmpp_connection.close()
+ self.assertEqual(len(socket_map), 0)
+ self.assertEqual(len(self.connections), 0)
+
+ def testBasicUnauthenticated(self):
+ socket_map = {}
+ xmpp_connection = xmppserver.XmppConnection(
+ self.fake_socket, socket_map, self, ('', 0), False)
+ self.assertEqual(len(socket_map), 1)
+ self.assertEqual(len(self.connections), 0)
+ xmpp_connection.HandshakeDone(None)
+ self.assertEqual(len(socket_map), 0)
+ self.assertEqual(len(self.connections), 0)
+
+ # Test unexpected stanza.
+ def SendUnexpectedStanza():
+ xmpp_connection.collect_incoming_data('<foo/>')
+ self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedStanza)
+
+ # Test redundant close.
xmpp_connection.close()
self.assertEqual(len(socket_map), 0)
self.assertEqual(len(self.connections), 0)