diff options
29 files changed, 332 insertions, 49 deletions
diff --git a/remoting/host/chromoting_host.cc b/remoting/host/chromoting_host.cc index e4bdae7..469c5b2 100644 --- a/remoting/host/chromoting_host.cc +++ b/remoting/host/chromoting_host.cc @@ -171,6 +171,20 @@ void ChromotingHost::SetMaximumSessionDuration( //////////////////////////////////////////////////////////////////////////// // protocol::ClientSession::EventHandler implementation. +void ChromotingHost::OnSessionAuthenticating(ClientSession* client) { + // We treat each incoming connection as a failure to authenticate, + // and clear the backoff when a connection successfully + // authenticates. This allows the backoff to protect from parallel + // connection attempts as well as sequential ones. + if (login_backoff_.ShouldRejectRequest()) { + LOG(WARNING) << "Disconnecting client " << client->client_jid() << " due to" + " an overload of failed login attempts."; + client->DisconnectSession(); + return; + } + login_backoff_.InformOfRequest(false); +} + bool ChromotingHost::OnSessionAuthenticated(ClientSession* client) { DCHECK(CalledOnValidThread()); @@ -265,16 +279,12 @@ void ChromotingHost::OnIncomingSession( } if (login_backoff_.ShouldRejectRequest()) { + LOG(WARNING) << "Rejecting connection due to" + " an overload of failed login attempts."; *response = protocol::SessionManager::OVERLOAD; return; } - // We treat each incoming connection as a failure to authenticate, - // and clear the backoff when a connection successfully - // authenticates. This allows the backoff to protect from parallel - // connection attempts as well as sequential ones. - login_backoff_.InformOfRequest(false); - protocol::SessionConfig config; if (!protocol_config_->Select(session->candidate_config(), &config)) { LOG(WARNING) << "Rejecting connection from " << session->jid() diff --git a/remoting/host/chromoting_host.h b/remoting/host/chromoting_host.h index 3e1fa74..a6642d4 100644 --- a/remoting/host/chromoting_host.h +++ b/remoting/host/chromoting_host.h @@ -115,6 +115,7 @@ class ChromotingHost : public base::NonThreadSafe, //////////////////////////////////////////////////////////////////////////// // ClientSession::EventHandler implementation. + virtual void OnSessionAuthenticating(ClientSession* client) OVERRIDE; virtual bool OnSessionAuthenticated(ClientSession* client) OVERRIDE; virtual void OnSessionChannelsConnected(ClientSession* client) OVERRIDE; virtual void OnSessionAuthenticationFailed(ClientSession* client) OVERRIDE; diff --git a/remoting/host/chromoting_host_unittest.cc b/remoting/host/chromoting_host_unittest.cc index 1410371..5ebe051 100644 --- a/remoting/host/chromoting_host_unittest.cc +++ b/remoting/host/chromoting_host_unittest.cc @@ -28,12 +28,13 @@ using ::remoting::protocol::MockConnectionToClientEventHandler; using ::remoting::protocol::MockHostStub; using ::remoting::protocol::MockSession; using ::remoting::protocol::MockVideoStub; +using ::remoting::protocol::Session; using ::remoting::protocol::SessionConfig; using testing::_; using testing::AnyNumber; -using testing::AtMost; using testing::AtLeast; +using testing::AtMost; using testing::CreateFunctor; using testing::DeleteArg; using testing::DoAll; @@ -44,6 +45,7 @@ using testing::InvokeArgument; using testing::InvokeWithoutArgs; using testing::Return; using testing::ReturnRef; +using testing::SaveArg; using testing::Sequence; namespace remoting { @@ -124,9 +126,10 @@ class ChromotingHostTest : public testing::Test { .Times(AnyNumber()); EXPECT_CALL(*session_unowned1_, SetEventHandler(_)) .Times(AnyNumber()) - .WillRepeatedly(Invoke(this, &ChromotingHostTest::SetEventHandler)); + .WillRepeatedly(SaveArg<0>(&session_unowned1_event_handler_)); EXPECT_CALL(*session_unowned2_, SetEventHandler(_)) - .Times(AnyNumber()); + .Times(AnyNumber()) + .WillRepeatedly(SaveArg<0>(&session_unowned2_event_handler_)); EXPECT_CALL(*session1_, config()) .WillRepeatedly(ReturnRef(session_config1_)); EXPECT_CALL(*session2_, config()) @@ -287,13 +290,15 @@ class ChromotingHostTest : public testing::Test { get_connection(connection_index), protocol::OK); } - void SetEventHandler(protocol::Session::EventHandler* event_handler) { - session_event_handler_ = event_handler; + void NotifyConnectionClosed1() { + if (session_unowned1_event_handler_) { + session_unowned1_event_handler_->OnSessionStateChange(Session::CLOSED); + } } - void NotifyConnectionClosed() { - if (session_event_handler_) { - session_event_handler_->OnSessionStateChange(protocol::Session::CLOSED); + void NotifyConnectionClosed2() { + if (session_unowned2_event_handler_) { + session_unowned2_event_handler_->OnSessionStateChange(Session::CLOSED); } } @@ -424,7 +429,8 @@ class ChromotingHostTest : public testing::Test { scoped_ptr<MockSession> session_unowned2_; // Not owned by a connection. SessionConfig session_unowned_config2_; std::string session_unowned_jid2_; - protocol::Session::EventHandler* session_event_handler_; + protocol::Session::EventHandler* session_unowned1_event_handler_; + protocol::Session::EventHandler* session_unowned2_event_handler_; scoped_ptr<protocol::CandidateSessionConfig> empty_candidate_config_; scoped_ptr<protocol::CandidateSessionConfig> default_candidate_config_; @@ -432,10 +438,16 @@ class ChromotingHostTest : public testing::Test { return (connection_index == 0) ? connection1_ : connection2_; } + // Returns the cached client pointers client1_ or client2_. ClientSession*& get_client(int connection_index) { return (connection_index == 0) ? client1_ : client2_; } + // Returns the list of clients of the host_. + std::list<ClientSession*>& get_clients_from_host() { + return host_->clients_; + } + const std::string& get_session_jid(int connection_index) { return (connection_index == 0) ? session_jid1_ : session_jid2_; } @@ -578,7 +590,7 @@ TEST_F(ChromotingHostTest, IncomingSessionAccepted) { default_candidate_config_.get())); EXPECT_CALL(*session_unowned1_, set_config(_)); EXPECT_CALL(*session_unowned1_, Close()).WillOnce(InvokeWithoutArgs( - this, &ChromotingHostTest::NotifyConnectionClosed)); + this, &ChromotingHostTest::NotifyConnectionClosed1)); EXPECT_CALL(host_status_observer_, OnAccessDenied(_)); EXPECT_CALL(host_status_observer_, OnShutdown()); @@ -593,13 +605,13 @@ TEST_F(ChromotingHostTest, IncomingSessionAccepted) { message_loop_.Run(); } -TEST_F(ChromotingHostTest, IncomingSessionOverload) { +TEST_F(ChromotingHostTest, LoginBackOffUponConnection) { ExpectHostAndSessionManagerStart(); - EXPECT_CALL(*session_unowned1_, candidate_config()).WillOnce(Return( - default_candidate_config_.get())); + EXPECT_CALL(*session_unowned1_, candidate_config()).WillOnce( + Return(default_candidate_config_.get())); EXPECT_CALL(*session_unowned1_, set_config(_)); - EXPECT_CALL(*session_unowned1_, Close()).WillOnce(InvokeWithoutArgs( - this, &ChromotingHostTest::NotifyConnectionClosed)); + EXPECT_CALL(*session_unowned1_, Close()).WillOnce( + InvokeWithoutArgs(this, &ChromotingHostTest::NotifyConnectionClosed1)); EXPECT_CALL(host_status_observer_, OnAccessDenied(_)); EXPECT_CALL(host_status_observer_, OnShutdown()); @@ -607,9 +619,11 @@ TEST_F(ChromotingHostTest, IncomingSessionOverload) { protocol::SessionManager::IncomingSessionResponse response = protocol::SessionManager::DECLINE; + host_->OnIncomingSession(session_unowned1_.release(), &response); EXPECT_EQ(protocol::SessionManager::ACCEPT, response); + host_->OnSessionAuthenticating(get_clients_from_host().front()); host_->OnIncomingSession(session_unowned2_.get(), &response); EXPECT_EQ(protocol::SessionManager::OVERLOAD, response); @@ -617,6 +631,46 @@ TEST_F(ChromotingHostTest, IncomingSessionOverload) { message_loop_.Run(); } +TEST_F(ChromotingHostTest, LoginBackOffUponAuthenticating) { + Expectation start = ExpectHostAndSessionManagerStart(); + EXPECT_CALL(*session_unowned1_, candidate_config()).WillOnce( + Return(default_candidate_config_.get())); + EXPECT_CALL(*session_unowned1_, set_config(_)); + EXPECT_CALL(*session_unowned1_, Close()).WillOnce( + InvokeWithoutArgs(this, &ChromotingHostTest::NotifyConnectionClosed1)); + + EXPECT_CALL(*session_unowned2_, candidate_config()).WillOnce( + Return(default_candidate_config_.get())); + EXPECT_CALL(*session_unowned2_, set_config(_)); + EXPECT_CALL(*session_unowned2_, Close()).WillOnce( + InvokeWithoutArgs(this, &ChromotingHostTest::NotifyConnectionClosed2)); + + EXPECT_CALL(host_status_observer_, OnShutdown()); + + host_->Start(xmpp_login_); + + protocol::SessionManager::IncomingSessionResponse response = + protocol::SessionManager::DECLINE; + + host_->OnIncomingSession(session_unowned1_.release(), &response); + EXPECT_EQ(protocol::SessionManager::ACCEPT, response); + + host_->OnIncomingSession(session_unowned2_.release(), &response); + EXPECT_EQ(protocol::SessionManager::ACCEPT, response); + + // This will set the backoff. + host_->OnSessionAuthenticating(get_clients_from_host().front()); + + // This should disconnect client2. + host_->OnSessionAuthenticating(get_clients_from_host().back()); + + // Verify that the host only has 1 client at this point. + EXPECT_EQ(get_clients_from_host().size(), 1U); + + ShutdownHost(); + message_loop_.Run(); +} + TEST_F(ChromotingHostTest, OnSessionRouteChange) { std::string channel_name("ChannelName"); protocol::TransportRoute route; diff --git a/remoting/host/client_session.cc b/remoting/host/client_session.cc index 7a3c76c..2661e0e 100644 --- a/remoting/host/client_session.cc +++ b/remoting/host/client_session.cc @@ -210,6 +210,11 @@ void ClientSession::DeliverClientMessage( << message.type() << ": " << message.data(); } +void ClientSession::OnConnectionAuthenticating( + protocol::ConnectionToClient* connection) { + event_handler_->OnSessionAuthenticating(this); +} + void ClientSession::OnConnectionAuthenticated( protocol::ConnectionToClient* connection) { DCHECK(CalledOnValidThread()); diff --git a/remoting/host/client_session.h b/remoting/host/client_session.h index ef75b25..f892329 100644 --- a/remoting/host/client_session.h +++ b/remoting/host/client_session.h @@ -54,6 +54,9 @@ class ClientSession // Callback interface for passing events to the ChromotingHost. class EventHandler { public: + // Called after authentication has started. + virtual void OnSessionAuthenticating(ClientSession* client) = 0; + // Called after authentication has finished successfully. Returns true if // the connection is allowed, or false otherwise. virtual bool OnSessionAuthenticated(ClientSession* client) = 0; @@ -115,6 +118,8 @@ class ClientSession const protocol::ExtensionMessage& message) OVERRIDE; // protocol::ConnectionToClient::EventHandler interface. + virtual void OnConnectionAuthenticating( + protocol::ConnectionToClient* connection) OVERRIDE; virtual void OnConnectionAuthenticated( protocol::ConnectionToClient* connection) OVERRIDE; virtual void OnConnectionChannelsConnected( diff --git a/remoting/host/host_mock_objects.h b/remoting/host/host_mock_objects.h index d365695..d916f04 100644 --- a/remoting/host/host_mock_objects.h +++ b/remoting/host/host_mock_objects.h @@ -68,6 +68,7 @@ class MockClientSessionEventHandler : public ClientSession::EventHandler { MockClientSessionEventHandler(); virtual ~MockClientSessionEventHandler(); + MOCK_METHOD1(OnSessionAuthenticating, void(ClientSession* client)); MOCK_METHOD1(OnSessionAuthenticated, bool(ClientSession* client)); MOCK_METHOD1(OnSessionChannelsConnected, void(ClientSession* client)); MOCK_METHOD1(OnSessionAuthenticationFailed, void(ClientSession* client)); diff --git a/remoting/host/pam_authorization_factory_posix.cc b/remoting/host/pam_authorization_factory_posix.cc index c89c71f..eef0c48 100644 --- a/remoting/host/pam_authorization_factory_posix.cc +++ b/remoting/host/pam_authorization_factory_posix.cc @@ -24,6 +24,7 @@ class PamAuthorizer : public protocol::Authenticator { // protocol::Authenticator interface. virtual State state() const OVERRIDE; + virtual bool started() const OVERRIDE; virtual RejectionReason rejection_reason() const OVERRIDE; virtual void ProcessMessage(const buzz::XmlElement* message, const base::Closure& resume_callback) OVERRIDE; @@ -62,6 +63,10 @@ protocol::Authenticator::State PamAuthorizer::state() const { } } +bool PamAuthorizer::started() const { + return underlying_->started(); +} + protocol::Authenticator::RejectionReason PamAuthorizer::rejection_reason() const { if (local_login_status_ == DISALLOWED) { diff --git a/remoting/protocol/authenticator.h b/remoting/protocol/authenticator.h index 28288e5..1210989 100644 --- a/remoting/protocol/authenticator.h +++ b/remoting/protocol/authenticator.h @@ -89,6 +89,11 @@ class Authenticator { // Returns current state of the authenticator. virtual State state() const = 0; + // Returns whether authentication has started. The chromoting host uses this + // method to starts the back off process to prevent malicious clients from + // guessing the PIN by spamming the host with auth requests. + virtual bool started() const = 0; + // Returns rejection reason. Can be called only when in REJECTED state. virtual RejectionReason rejection_reason() const = 0; diff --git a/remoting/protocol/authenticator_test_base.cc b/remoting/protocol/authenticator_test_base.cc index 25b0efe..1b65dbe 100644 --- a/remoting/protocol/authenticator_test_base.cc +++ b/remoting/protocol/authenticator_test_base.cc @@ -60,22 +60,43 @@ void AuthenticatorTestBase::SetUp() { } void AuthenticatorTestBase::RunAuthExchange() { - ContinueAuthExchangeWith(client_.get(), host_.get()); + ContinueAuthExchangeWith(client_.get(), + host_.get(), + client_->started(), + host_->started()); } void AuthenticatorTestBase::RunHostInitiatedAuthExchange() { - ContinueAuthExchangeWith(host_.get(), client_.get()); + ContinueAuthExchangeWith(host_.get(), + client_.get(), + host_->started(), + client_->started()); } // static +// This function sends a message from the sender and receiver and recursively +// calls itself to the send the next message from the receiver to the sender +// untils the authentication completes. void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender, - Authenticator* receiver) { + Authenticator* receiver, + bool sender_started, + bool receiver_started) { scoped_ptr<buzz::XmlElement> message; ASSERT_NE(Authenticator::WAITING_MESSAGE, sender->state()); if (sender->state() == Authenticator::ACCEPTED || sender->state() == Authenticator::REJECTED) return; - // Pass message from client to host. + + // Verify that once the started flag for either party is set to true, + // it should always stay true. + if (receiver_started) { + ASSERT_TRUE(receiver->started()); + } + + if (sender_started) { + ASSERT_TRUE(sender->started()); + } + ASSERT_EQ(Authenticator::MESSAGE_READY, sender->state()); message = sender->GetNextMessage(); ASSERT_TRUE(message.get()); @@ -84,7 +105,8 @@ void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender, ASSERT_EQ(Authenticator::WAITING_MESSAGE, receiver->state()); receiver->ProcessMessage(message.get(), base::Bind( &AuthenticatorTestBase::ContinueAuthExchangeWith, - base::Unretained(receiver), base::Unretained(sender))); + base::Unretained(receiver), base::Unretained(sender), + receiver->started(), sender->started())); } void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) { diff --git a/remoting/protocol/authenticator_test_base.h b/remoting/protocol/authenticator_test_base.h index 9b299db..e20774a 100644 --- a/remoting/protocol/authenticator_test_base.h +++ b/remoting/protocol/authenticator_test_base.h @@ -41,7 +41,9 @@ class AuthenticatorTestBase : public testing::Test { }; static void ContinueAuthExchangeWith(Authenticator* sender, - Authenticator* receiver); + Authenticator* receiver, + bool sender_started, + bool receiver_srated); virtual void SetUp() OVERRIDE; void RunAuthExchange(); void RunHostInitiatedAuthExchange(); diff --git a/remoting/protocol/connection_to_client.cc b/remoting/protocol/connection_to_client.cc index dd32452..a6d3143 100644 --- a/remoting/protocol/connection_to_client.cc +++ b/remoting/protocol/connection_to_client.cc @@ -112,7 +112,9 @@ void ConnectionToClient::OnSessionStateChange(Session::State state) { case Session::CONNECTED: // Don't care about these events. break; - + case Session::AUTHENTICATING: + handler_->OnConnectionAuthenticating(this); + break; case Session::AUTHENTICATED: // Initialize channels. control_dispatcher_.reset(new HostControlDispatcher()); diff --git a/remoting/protocol/connection_to_client.h b/remoting/protocol/connection_to_client.h index 9865307..9a64dcd 100644 --- a/remoting/protocol/connection_to_client.h +++ b/remoting/protocol/connection_to_client.h @@ -38,6 +38,9 @@ class ConnectionToClient : public base::NonThreadSafe, public: class EventHandler { public: + // Called when the network connection is authenticating + virtual void OnConnectionAuthenticating(ConnectionToClient* connection) = 0; + // Called when the network connection is authenticated. virtual void OnConnectionAuthenticated(ConnectionToClient* connection) = 0; diff --git a/remoting/protocol/connection_to_host.cc b/remoting/protocol/connection_to_host.cc index cdaf4b6..978da56 100644 --- a/remoting/protocol/connection_to_host.cc +++ b/remoting/protocol/connection_to_host.cc @@ -152,6 +152,7 @@ void ConnectionToHost::OnSessionStateChange( case Session::CONNECTING: case Session::ACCEPTING: case Session::CONNECTED: + case Session::AUTHENTICATING: // Don't care about these events. break; diff --git a/remoting/protocol/fake_authenticator.cc b/remoting/protocol/fake_authenticator.cc index 67f5d35..67c1cca 100644 --- a/remoting/protocol/fake_authenticator.cc +++ b/remoting/protocol/fake_authenticator.cc @@ -84,12 +84,17 @@ FakeAuthenticator::FakeAuthenticator( round_trips_(round_trips), action_(action), async_(async), - messages_(0) { + messages_(0), + messages_till_started_(0) { } FakeAuthenticator::~FakeAuthenticator() { } +void FakeAuthenticator::set_messages_till_started(int messages) { + messages_till_started_ = messages; +} + Authenticator::State FakeAuthenticator::state() const { EXPECT_LE(messages_, round_trips_ * 2); if (messages_ >= round_trips_ * 2) { @@ -116,6 +121,10 @@ Authenticator::State FakeAuthenticator::state() const { } } +bool FakeAuthenticator::started() const { + return messages_ > messages_till_started_; +} + Authenticator::RejectionReason FakeAuthenticator::rejection_reason() const { EXPECT_EQ(REJECTED, state()); return INVALID_CREDENTIALS; @@ -153,8 +162,10 @@ FakeAuthenticator::CreateChannelAuthenticator() const { } FakeHostAuthenticatorFactory::FakeHostAuthenticatorFactory( - int round_trips, FakeAuthenticator::Action action, bool async) + int round_trips, int messages_till_started, + FakeAuthenticator::Action action, bool async) : round_trips_(round_trips), + messages_till_started_(messages_till_started), action_(action), async_(async) { } @@ -165,8 +176,12 @@ scoped_ptr<Authenticator> FakeHostAuthenticatorFactory::CreateAuthenticator( const std::string& local_jid, const std::string& remote_jid, const buzz::XmlElement* first_message) { - return scoped_ptr<Authenticator>(new FakeAuthenticator( - FakeAuthenticator::HOST, round_trips_, action_, async_)); + FakeAuthenticator* authenticator = new FakeAuthenticator( + FakeAuthenticator::HOST, round_trips_, action_, async_); + authenticator->set_messages_till_started(messages_till_started_); + + scoped_ptr<Authenticator> result(authenticator); + return result.Pass(); } } // namespace protocol diff --git a/remoting/protocol/fake_authenticator.h b/remoting/protocol/fake_authenticator.h index a5a2de8..b74654d0 100644 --- a/remoting/protocol/fake_authenticator.h +++ b/remoting/protocol/fake_authenticator.h @@ -59,10 +59,16 @@ class FakeAuthenticator : public Authenticator { }; FakeAuthenticator(Type type, int round_trips, Action action, bool async); + virtual ~FakeAuthenticator(); + // Set the number of messages that the authenticator needs to process before + // started() returns true. Default to 0. + void set_messages_till_started(int messages); + // Authenticator interface. virtual State state() const OVERRIDE; + virtual bool started() const OVERRIDE; virtual RejectionReason rejection_reason() const OVERRIDE; virtual void ProcessMessage(const buzz::XmlElement* message, const base::Closure& resume_callback) OVERRIDE; @@ -78,6 +84,9 @@ class FakeAuthenticator : public Authenticator { // Total number of messages that have been processed. int messages_; + // Number of messages that the authenticator needs to process before started() + // returns true. Default to 0. + int messages_till_started_; DISALLOW_COPY_AND_ASSIGN(FakeAuthenticator); }; @@ -85,7 +94,8 @@ class FakeAuthenticator : public Authenticator { class FakeHostAuthenticatorFactory : public AuthenticatorFactory { public: FakeHostAuthenticatorFactory( - int round_trips, FakeAuthenticator::Action action, bool async); + int round_trips, int messages_till_start, + FakeAuthenticator::Action action, bool async); virtual ~FakeHostAuthenticatorFactory(); // AuthenticatorFactory interface. @@ -96,6 +106,7 @@ class FakeHostAuthenticatorFactory : public AuthenticatorFactory { private: int round_trips_; + int messages_till_started_; FakeAuthenticator::Action action_; bool async_; diff --git a/remoting/protocol/jingle_session.cc b/remoting/protocol/jingle_session.cc index 2aecd4c..57883cb 100644 --- a/remoting/protocol/jingle_session.cc +++ b/remoting/protocol/jingle_session.cc @@ -6,8 +6,10 @@ #include "base/bind.h" #include "base/rand_util.h" +#include "base/single_thread_task_runner.h" #include "base/stl_util.h" #include "base/strings/string_number_conversions.h" +#include "base/thread_task_runner_handle.h" #include "base/time/time.h" #include "remoting/base/constants.h" #include "remoting/jingle_glue/iq_sender.h" @@ -63,7 +65,8 @@ JingleSession::JingleSession(JingleSessionManager* session_manager) event_handler_(NULL), state_(INITIALIZING), error_(OK), - config_is_set_(false) { + config_is_set_(false), + weak_factory_(this) { } JingleSession::~JingleSession() { @@ -181,9 +184,10 @@ void JingleSession::ContinueAcceptIncomingConnection() { SetState(AUTHENTICATED); } else { DCHECK_EQ(authenticator_->state(), Authenticator::WAITING_MESSAGE); + if (authenticator_->started()) { + SetState(AUTHENTICATING); + } } - - return; } const std::string& JingleSession::jid() { @@ -485,7 +489,7 @@ void JingleSession::OnSessionInfo(const JingleMessage& message, return; } - if (state_ != CONNECTED || + if ((state_ != CONNECTED && state_ != AUTHENTICATING) || authenticator_->state() != Authenticator::WAITING_MESSAGE) { LOG(WARNING) << "Received unexpected authenticator message " << message.info->Str(); @@ -517,8 +521,7 @@ void JingleSession::ProcessTransportInfo(const JingleMessage& message) { void JingleSession::OnTerminate(const JingleMessage& message, const ReplyCallback& reply_callback) { - if (state_ != CONNECTING && state_ != ACCEPTING && state_ != CONNECTED && - state_ != AUTHENTICATED) { + if (!is_session_active()) { LOG(WARNING) << "Received unexpected session-terminate message."; reply_callback.Run(JingleMessageReply::UNEXPECTED_REQUEST); return; @@ -577,7 +580,7 @@ void JingleSession::ProcessAuthenticationStep() { DCHECK(CalledOnValidThread()); DCHECK_NE(authenticator_->state(), Authenticator::PROCESSING_MESSAGE); - if (state_ != CONNECTED) { + if (state_ != CONNECTED && state_ != AUTHENTICATING) { DCHECK(state_ == FAILED || state_ == CLOSED); // The remote host closed the connection while the authentication was being // processed asynchronously, nothing to do. @@ -592,6 +595,21 @@ void JingleSession::ProcessAuthenticationStep() { } DCHECK_NE(authenticator_->state(), Authenticator::MESSAGE_READY); + // The current JingleSession object can be destroyed by event_handler of + // SetState(AUTHENTICATING) and cause subsequent dereferencing of the this + // pointer to crash. To protect against it, we run ContinueAuthenticationStep + // asychronously using a weak pointer. + base::ThreadTaskRunnerHandle::Get()->PostTask( + FROM_HERE, + base::Bind(&JingleSession::ContinueAuthenticationStep, + weak_factory_.GetWeakPtr())); + + if (authenticator_->started()) { + SetState(AUTHENTICATING); + } +} + +void JingleSession::ContinueAuthenticationStep() { if (authenticator_->state() == Authenticator::ACCEPTED) { SetState(AUTHENTICATED); } else if (authenticator_->state() == Authenticator::REJECTED) { @@ -603,8 +621,7 @@ void JingleSession::ProcessAuthenticationStep() { void JingleSession::CloseInternal(ErrorCode error) { DCHECK(CalledOnValidThread()); - if (state_ == CONNECTING || state_ == ACCEPTING || state_ == CONNECTED || - state_ == AUTHENTICATED) { + if (is_session_active()) { // Send session-terminate message with the appropriate error code. JingleMessage::Reason reason; switch (error) { @@ -655,5 +672,10 @@ void JingleSession::SetState(State new_state) { } } +bool JingleSession::is_session_active() { + return state_ == CONNECTING || state_ == ACCEPTING || state_ == CONNECTED || + state_ == AUTHENTICATING || state_ == AUTHENTICATED; +} + } // namespace protocol } // namespace remoting diff --git a/remoting/protocol/jingle_session.h b/remoting/protocol/jingle_session.h index 189cb53..3b704fb 100644 --- a/remoting/protocol/jingle_session.h +++ b/remoting/protocol/jingle_session.h @@ -132,9 +132,13 @@ class JingleSession : public Session, // Called after the initial incoming authenticator message is processed. void ContinueAcceptIncomingConnection(); + // Called after subsequent authenticator messages are processed. void ProcessAuthenticationStep(); + // Called after the authenticating step is finished. + void ContinueAuthenticationStep(); + // Terminates the session and sends session-terminate if it is // necessary. |error| specifies the error code in case when the // session is being closed due to an error. @@ -143,6 +147,9 @@ class JingleSession : public Session, // Sets |state_| to |new_state| and calls state change callback. void SetState(State new_state); + // Returns true if the state of the session is not CLOSED or FAILED + bool is_session_active(); + JingleSessionManager* session_manager_; std::string peer_jid_; scoped_ptr<CandidateSessionConfig> candidate_config_; @@ -172,6 +179,8 @@ class JingleSession : public Session, // Pending remote candidates, received before the local channels were created. std::list<JingleMessage::NamedCandidate> pending_remote_candidates_; + base::WeakPtrFactory<JingleSession> weak_factory_; + DISALLOW_COPY_AND_ASSIGN(JingleSession); }; diff --git a/remoting/protocol/jingle_session_unittest.cc b/remoting/protocol/jingle_session_unittest.cc index dcca077..2f95100 100644 --- a/remoting/protocol/jingle_session_unittest.cc +++ b/remoting/protocol/jingle_session_unittest.cc @@ -105,6 +105,10 @@ class JingleSessionTest : public testing::Test { session->set_config(SessionConfig::ForTest()); } + void DeleteSession() { + host_session_.reset(); + } + void OnClientChannelCreated(scoped_ptr<net::StreamSocket> socket) { client_channel_callback_.OnDone(socket.get()); client_socket_ = socket.Pass(); @@ -132,7 +136,7 @@ class JingleSessionTest : public testing::Test { client_session_.reset(); } - void CreateSessionManagers(int auth_round_trips, + void CreateSessionManagers(int auth_round_trips, int messages_till_start, FakeAuthenticator::Action auth_action) { host_signal_strategy_.reset(new FakeSignalStrategy(kHostJid)); client_signal_strategy_.reset(new FakeSignalStrategy(kClientJid)); @@ -153,7 +157,8 @@ class JingleSessionTest : public testing::Test { host_server_->Init(host_signal_strategy_.get(), &host_server_listener_); scoped_ptr<AuthenticatorFactory> factory( - new FakeHostAuthenticatorFactory(auth_round_trips, auth_action, true)); + new FakeHostAuthenticatorFactory(auth_round_trips, + messages_till_start, auth_action, true)); host_server_->set_authenticator_factory(factory.Pass()); EXPECT_CALL(client_server_listener_, OnSessionManagerReady()) @@ -169,6 +174,11 @@ class JingleSessionTest : public testing::Test { &client_server_listener_); } + void CreateSessionManagers(int auth_round_trips, + FakeAuthenticator::Action auth_action) { + CreateSessionManagers(auth_round_trips, 0, auth_action); + } + void CloseSessionManager() { if (host_server_.get()) { host_server_->Close(); @@ -196,6 +206,9 @@ class JingleSessionTest : public testing::Test { EXPECT_CALL(host_session_event_handler_, OnSessionStateChange(Session::CONNECTED)) .Times(AtMost(1)); + EXPECT_CALL(host_session_event_handler_, + OnSessionStateChange(Session::AUTHENTICATING)) + .Times(AtMost(1)); if (expect_fail) { EXPECT_CALL(host_session_event_handler_, OnSessionStateChange(Session::FAILED)) @@ -217,6 +230,9 @@ class JingleSessionTest : public testing::Test { EXPECT_CALL(client_session_event_handler_, OnSessionStateChange(Session::CONNECTED)) .Times(AtMost(1)); + EXPECT_CALL(client_session_event_handler_, + OnSessionStateChange(Session::AUTHENTICATING)) + .Times(AtMost(1)); if (expect_fail) { EXPECT_CALL(client_session_event_handler_, OnSessionStateChange(Session::FAILED)) @@ -375,6 +391,60 @@ TEST_F(JingleSessionTest, TestStreamChannel) { tester.CheckResults(); } +TEST_F(JingleSessionTest, DeleteSessionOnIncomingConnection) { + CreateSessionManagers(3, FakeAuthenticator::ACCEPT); + + EXPECT_CALL(host_server_listener_, OnIncomingSession(_, _)) + .WillOnce(DoAll( + WithArg<0>(Invoke(this, &JingleSessionTest::SetHostSession)), + SetArgumentPointee<1>(protocol::SessionManager::ACCEPT))); + + EXPECT_CALL(host_session_event_handler_, + OnSessionStateChange(Session::CONNECTED)) + .Times(AtMost(1)); + + EXPECT_CALL(host_session_event_handler_, + OnSessionStateChange(Session::AUTHENTICATING)) + .WillOnce(InvokeWithoutArgs(this, &JingleSessionTest::DeleteSession)); + + scoped_ptr<Authenticator> authenticator(new FakeAuthenticator( + FakeAuthenticator::CLIENT, 3, FakeAuthenticator::ACCEPT, true)); + + client_session_ = client_server_->Connect( + kHostJid, authenticator.Pass(), + CandidateSessionConfig::CreateDefault()); + + base::RunLoop().RunUntilIdle(); +} + +TEST_F(JingleSessionTest, DeleteSessionOnAuth) { + // Same as the previous test, but set messages_till_started to 2 in + // CreateSessionManagers so that the session will goes into the + // AUTHENTICATING state after two message exchanges. + CreateSessionManagers(3, 2, FakeAuthenticator::ACCEPT); + + EXPECT_CALL(host_server_listener_, OnIncomingSession(_, _)) + .WillOnce(DoAll( + WithArg<0>(Invoke(this, &JingleSessionTest::SetHostSession)), + SetArgumentPointee<1>(protocol::SessionManager::ACCEPT))); + + EXPECT_CALL(host_session_event_handler_, + OnSessionStateChange(Session::CONNECTED)) + .Times(AtMost(1)); + + EXPECT_CALL(host_session_event_handler_, + OnSessionStateChange(Session::AUTHENTICATING)) + .WillOnce(InvokeWithoutArgs(this, &JingleSessionTest::DeleteSession)); + + scoped_ptr<Authenticator> authenticator(new FakeAuthenticator( + FakeAuthenticator::CLIENT, 3, FakeAuthenticator::ACCEPT, true)); + + client_session_ = client_server_->Connect( + kHostJid, authenticator.Pass(), + CandidateSessionConfig::CreateDefault()); + base::RunLoop().RunUntilIdle(); +} + // Verify that data can be sent over a multiplexed channel. TEST_F(JingleSessionTest, TestMuxStreamChannel) { CreateSessionManagers(1, FakeAuthenticator::ACCEPT); diff --git a/remoting/protocol/me2me_host_authenticator_factory.cc b/remoting/protocol/me2me_host_authenticator_factory.cc index 24c0ca4..7e407c8 100644 --- a/remoting/protocol/me2me_host_authenticator_factory.cc +++ b/remoting/protocol/me2me_host_authenticator_factory.cc @@ -30,6 +30,10 @@ class RejectingAuthenticator : public Authenticator { return state_; } + virtual bool started() const OVERRIDE { + return true; + } + virtual RejectionReason rejection_reason() const OVERRIDE { DCHECK_EQ(state_, REJECTED); return INVALID_CREDENTIALS; diff --git a/remoting/protocol/negotiating_authenticator_base.cc b/remoting/protocol/negotiating_authenticator_base.cc index 30bd8ba..5c61d30 100644 --- a/remoting/protocol/negotiating_authenticator_base.cc +++ b/remoting/protocol/negotiating_authenticator_base.cc @@ -40,6 +40,13 @@ Authenticator::State NegotiatingAuthenticatorBase::state() const { return state_; } +bool NegotiatingAuthenticatorBase::started() const { + if (!current_authenticator_) { + return false; + } + return current_authenticator_->started(); +} + Authenticator::RejectionReason NegotiatingAuthenticatorBase::rejection_reason() const { return rejection_reason_; diff --git a/remoting/protocol/negotiating_authenticator_base.h b/remoting/protocol/negotiating_authenticator_base.h index 1f80967..bcb005f2 100644 --- a/remoting/protocol/negotiating_authenticator_base.h +++ b/remoting/protocol/negotiating_authenticator_base.h @@ -64,6 +64,7 @@ class NegotiatingAuthenticatorBase : public Authenticator { // Authenticator interface. virtual State state() const OVERRIDE; + virtual bool started() const OVERRIDE; virtual RejectionReason rejection_reason() const OVERRIDE; virtual scoped_ptr<ChannelAuthenticator> CreateChannelAuthenticator() const OVERRIDE; diff --git a/remoting/protocol/pairing_authenticator_base.cc b/remoting/protocol/pairing_authenticator_base.cc index 47ca275..7435e55 100644 --- a/remoting/protocol/pairing_authenticator_base.cc +++ b/remoting/protocol/pairing_authenticator_base.cc @@ -38,6 +38,13 @@ Authenticator::State PairingAuthenticatorBase::state() const { return v2_authenticator_->state(); } +bool PairingAuthenticatorBase::started() const { + if (!v2_authenticator_) { + return false; + } + return v2_authenticator_->started(); +} + Authenticator::RejectionReason PairingAuthenticatorBase::rejection_reason() const { if (!v2_authenticator_) { diff --git a/remoting/protocol/pairing_authenticator_base.h b/remoting/protocol/pairing_authenticator_base.h index 2a4bc4e..4cb042c 100644 --- a/remoting/protocol/pairing_authenticator_base.h +++ b/remoting/protocol/pairing_authenticator_base.h @@ -43,6 +43,7 @@ class PairingAuthenticatorBase : public Authenticator { // Authenticator interface. virtual State state() const OVERRIDE; + virtual bool started() const OVERRIDE; virtual RejectionReason rejection_reason() const OVERRIDE; virtual void ProcessMessage(const buzz::XmlElement* message, const base::Closure& resume_callback) OVERRIDE; diff --git a/remoting/protocol/protocol_mock_objects.h b/remoting/protocol/protocol_mock_objects.h index fecf1d8..07e03306 100644 --- a/remoting/protocol/protocol_mock_objects.h +++ b/remoting/protocol/protocol_mock_objects.h @@ -52,6 +52,8 @@ class MockConnectionToClientEventHandler : MockConnectionToClientEventHandler(); virtual ~MockConnectionToClientEventHandler(); + MOCK_METHOD1(OnConnectionAuthenticating, + void(ConnectionToClient* connection)); MOCK_METHOD1(OnConnectionAuthenticated, void(ConnectionToClient* connection)); MOCK_METHOD1(OnConnectionChannelsConnected, void(ConnectionToClient* connection)); diff --git a/remoting/protocol/session.h b/remoting/protocol/session.h index a542283..232a1d04 100644 --- a/remoting/protocol/session.h +++ b/remoting/protocol/session.h @@ -38,6 +38,9 @@ class Session { // Session has been accepted and is pending authentication. CONNECTED, + // Session has started authenticating. + AUTHENTICATING, + // Session has been connected and authenticated. AUTHENTICATED, @@ -54,8 +57,8 @@ class Session { virtual ~EventHandler() {} // Called after session state has changed. It is safe to destroy - // the session from within the handler if |state| is CLOSED or - // FAILED. + // the session from within the handler if |state| is AUTHENTICATING + // or CLOSED or FAILED. virtual void OnSessionStateChange(State state) = 0; // Called whenever route for the channel specified with diff --git a/remoting/protocol/third_party_authenticator_base.cc b/remoting/protocol/third_party_authenticator_base.cc index 9d6be72..019e788 100644 --- a/remoting/protocol/third_party_authenticator_base.cc +++ b/remoting/protocol/third_party_authenticator_base.cc @@ -28,12 +28,17 @@ const buzz::StaticQName ThirdPartyAuthenticatorBase::kTokenTag = ThirdPartyAuthenticatorBase::ThirdPartyAuthenticatorBase( Authenticator::State initial_state) : token_state_(initial_state), + started_(false), rejection_reason_(INVALID_CREDENTIALS) { } ThirdPartyAuthenticatorBase::~ThirdPartyAuthenticatorBase() { } +bool ThirdPartyAuthenticatorBase::started() const { + return started_; +} + Authenticator::State ThirdPartyAuthenticatorBase::state() const { if (token_state_ == ACCEPTED) return underlying_->state(); @@ -74,9 +79,10 @@ scoped_ptr<buzz::XmlElement> ThirdPartyAuthenticatorBase::GetNextMessage() { message = CreateEmptyAuthenticatorMessage(); } - if (token_state_ == MESSAGE_READY) + if (token_state_ == MESSAGE_READY) { AddTokenElements(message.get()); - + started_ = true; + } return message.Pass(); } diff --git a/remoting/protocol/third_party_authenticator_base.h b/remoting/protocol/third_party_authenticator_base.h index 0db203d..7141be8 100644 --- a/remoting/protocol/third_party_authenticator_base.h +++ b/remoting/protocol/third_party_authenticator_base.h @@ -36,6 +36,7 @@ class ThirdPartyAuthenticatorBase : public Authenticator { // Authenticator interface. virtual State state() const OVERRIDE; + virtual bool started() const OVERRIDE; virtual RejectionReason rejection_reason() const OVERRIDE; virtual void ProcessMessage(const buzz::XmlElement* message, const base::Closure& resume_callback) OVERRIDE; @@ -66,6 +67,7 @@ class ThirdPartyAuthenticatorBase : public Authenticator { scoped_ptr<Authenticator> underlying_; State token_state_; + bool started_; RejectionReason rejection_reason_; private: diff --git a/remoting/protocol/v2_authenticator.cc b/remoting/protocol/v2_authenticator.cc index ee5c9d1..1b13c7c 100644 --- a/remoting/protocol/v2_authenticator.cc +++ b/remoting/protocol/v2_authenticator.cc @@ -62,6 +62,7 @@ V2Authenticator::V2Authenticator( : certificate_sent_(false), key_exchange_impl_(type, shared_secret), state_(initial_state), + started_(false), rejection_reason_(INVALID_CREDENTIALS) { pending_messages_.push(key_exchange_impl_.GetMessage()); } @@ -75,6 +76,10 @@ Authenticator::State V2Authenticator::state() const { return state_; } +bool V2Authenticator::started() const { + return started_; +} + Authenticator::RejectionReason V2Authenticator::rejection_reason() const { DCHECK_EQ(state(), REJECTED); return rejection_reason_; @@ -127,6 +132,7 @@ void V2Authenticator::ProcessMessageInternal(const buzz::XmlElement* message) { P224EncryptedKeyExchange::Result result = key_exchange_impl_.ProcessMessage(spake_message); + started_ = true; switch (result) { case P224EncryptedKeyExchange::kResultPending: pending_messages_.push(key_exchange_impl_.GetMessage()); @@ -143,7 +149,6 @@ void V2Authenticator::ProcessMessageInternal(const buzz::XmlElement* message) { return; } } - state_ = MESSAGE_READY; } diff --git a/remoting/protocol/v2_authenticator.h b/remoting/protocol/v2_authenticator.h index 0675cb2..b52a3e1 100644 --- a/remoting/protocol/v2_authenticator.h +++ b/remoting/protocol/v2_authenticator.h @@ -38,6 +38,7 @@ class V2Authenticator : public Authenticator { // Authenticator interface. virtual State state() const OVERRIDE; + virtual bool started() const OVERRIDE; virtual RejectionReason rejection_reason() const OVERRIDE; virtual void ProcessMessage(const buzz::XmlElement* message, const base::Closure& resume_callback) OVERRIDE; @@ -67,6 +68,7 @@ class V2Authenticator : public Authenticator { // Used for both host and client authenticators. crypto::P224EncryptedKeyExchange key_exchange_impl_; State state_; + bool started_; RejectionReason rejection_reason_; std::queue<std::string> pending_messages_; std::string auth_key_; |