diff options
Diffstat (limited to 'net/tools/testserver/xmppserver.py')
-rw-r--r-- | net/tools/testserver/xmppserver.py | 48 |
1 files changed, 36 insertions, 12 deletions
diff --git a/net/tools/testserver/xmppserver.py b/net/tools/testserver/xmppserver.py index 6952a99..996da96 100644 --- a/net/tools/testserver/xmppserver.py +++ b/net/tools/testserver/xmppserver.py @@ -220,6 +220,10 @@ class HandshakeTask(object): _AUTH_SUCCESS_STANZA = ParseXml( '<success xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>') + # Used when in the _AUTH_NEEDED state. + _AUTH_FAILURE_STANZA = ParseXml( + '<failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>') + # Used when in the _AUTH_STREAM_NEEDED state. _BIND_STANZA = ParseXml( '<stream:features xmlns:stream="http://etherx.jabber.org/streams">' @@ -242,12 +246,13 @@ class HandshakeTask(object): # The id attribute is filled in later. _IQ_RESPONSE_STANZA = ParseXml('<iq id="" type="result"/>') - def __init__(self, connection, resource_prefix): + def __init__(self, connection, resource_prefix, authenticated): self._connection = connection self._id_generator = IdGenerator(resource_prefix) self._username = '' self._domain = '' self._jid = None + self._authenticated = authenticated self._resource_prefix = resource_prefix self._state = self._INITIAL_STREAM_NEEDED @@ -304,6 +309,10 @@ class HandshakeTask(object): domain = self._domain return (username, domain) + def Finish(): + self._state = self._FINISHED + self._connection.HandshakeDone(self._jid) + if self._state == self._INITIAL_STREAM_NEEDED: HandleStream(stanza) self._connection.SendStanza(self._AUTH_STANZA, False) @@ -312,8 +321,12 @@ class HandshakeTask(object): elif self._state == self._AUTH_NEEDED: ExpectStanza(stanza, 'auth') (self._username, self._domain) = GetUserDomain(stanza) - self._connection.SendStanza(self._AUTH_SUCCESS_STANZA, False) - self._state = self._AUTH_STREAM_NEEDED + if self._authenticated: + self._connection.SendStanza(self._AUTH_SUCCESS_STANZA, False) + self._state = self._AUTH_STREAM_NEEDED + else: + self._connection.SendStanza(self._AUTH_FAILURE_STANZA, False) + Finish() elif self._state == self._AUTH_STREAM_NEEDED: HandleStream(stanza) @@ -340,8 +353,7 @@ class HandshakeTask(object): xml = CloneXml(self._IQ_RESPONSE_STANZA) xml.setAttribute('id', stanza_id) self._connection.SendStanza(xml) - self._state = self._FINISHED - self._connection.HandshakeDone(self._jid) + Finish() def AddrString(addr): @@ -361,7 +373,7 @@ class XmppConnection(asynchat.async_chat): # The from and id attributes are filled in later. _IQ_RESPONSE_STANZA = ParseXml('<iq from="" id="" type="result"/>') - def __init__(self, sock, socket_map, delegate, addr): + def __init__(self, sock, socket_map, delegate, addr, authenticated): """Starts up the xmpp connection. Args: @@ -389,7 +401,7 @@ class XmppConnection(asynchat.async_chat): self._addr = addr addr_str = AddrString(self._addr) - self._handshake_task = HandshakeTask(self, addr_str) + self._handshake_task = HandshakeTask(self, addr_str, authenticated) print 'Starting connection to %s' % self def __str__(self): @@ -427,10 +439,14 @@ class XmppConnection(asynchat.async_chat): # Called by self._handshake_task. def HandshakeDone(self, jid): - self._jid = jid - self._handshake_task = None - self._delegate.OnXmppHandshakeDone(self) - print "Handshake done for %s" % self + if jid: + self._jid = jid + self._handshake_task = None + self._delegate.OnXmppHandshakeDone(self) + print "Handshake done for %s" % self + else: + print "Handshake failed for %s" % self + self.close() def _HandlePushCommand(self, stanza): if stanza.tagName == 'iq' and stanza.firstChild.tagName == 'subscribe': @@ -505,10 +521,12 @@ class XmppServer(asyncore.dispatcher): self._connections = set() self._handshake_done_connections = set() self._notifications_enabled = True + self._authenticated = True def handle_accept(self): (sock, addr) = self.accept() - xmpp_connection = XmppConnection(sock, self._socket_map, self, addr) + xmpp_connection = XmppConnection( + sock, self._socket_map, self, addr, self._authenticated) self._connections.add(xmpp_connection) # Return the new XmppConnection for testing. return xmpp_connection @@ -553,6 +571,12 @@ class XmppServer(asyncore.dispatcher): self.ForwardNotification(None, notification_stanza) notification_stanza.unlink() + def SetAuthenticated(self, auth_valid): + self._authenticated = auth_valid + + def GetAuthenticated(self): + return self._authenticated + # XmppConnection delegate methods. def OnXmppHandshakeDone(self, xmpp_connection): self._handshake_done_connections.add(xmpp_connection) |