diff options
author | kelvinp@chromium.org <kelvinp@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2014-04-07 22:33:28 +0000 |
---|---|---|
committer | kelvinp@chromium.org <kelvinp@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2014-04-07 22:33:28 +0000 |
commit | 064128c1d8c7a3fb1d4ceaae996891130f2cf171 (patch) | |
tree | ef143838848f3d3d28c7c8d327cc21828c3e43a1 /remoting | |
parent | 57baec29d609c0e1ae53dea806e5fbcc70f83ed6 (diff) | |
download | chromium_src-064128c1d8c7a3fb1d4ceaae996891130f2cf171.zip chromium_src-064128c1d8c7a3fb1d4ceaae996891130f2cf171.tar.gz chromium_src-064128c1d8c7a3fb1d4ceaae996891130f2cf171.tar.bz2 |
Cause:
To prevent a malicious client from guessing the PIN by spamming the host with bogus logins, the chromoting host can throttle incoming requests after too many unsuccessful login attempts. In the current implementation, every time when there is an incoming request, we start incrementing the bad login counter, regardless of whether the host has actually starts authenticating.
Fix:
This change adds an extra flag on the authenticator to indicate whether authentication has started.
The JingleSession checks the flag and progagates the message back all the way up to the host through the callback Session::OnSessionAuthenticationBegin
BUG=350208
Review URL: https://codereview.chromium.org/205583011
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@262228 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'remoting')
29 files changed, 332 insertions, 49 deletions
diff --git a/remoting/host/chromoting_host.cc b/remoting/host/chromoting_host.cc index e4bdae7..469c5b2 100644 --- a/remoting/host/chromoting_host.cc +++ b/remoting/host/chromoting_host.cc @@ -171,6 +171,20 @@ void ChromotingHost::SetMaximumSessionDuration( //////////////////////////////////////////////////////////////////////////// // protocol::ClientSession::EventHandler implementation. +void ChromotingHost::OnSessionAuthenticating(ClientSession* client) { + // We treat each incoming connection as a failure to authenticate, + // and clear the backoff when a connection successfully + // authenticates. This allows the backoff to protect from parallel + // connection attempts as well as sequential ones. + if (login_backoff_.ShouldRejectRequest()) { + LOG(WARNING) << "Disconnecting client " << client->client_jid() << " due to" + " an overload of failed login attempts."; + client->DisconnectSession(); + return; + } + login_backoff_.InformOfRequest(false); +} + bool ChromotingHost::OnSessionAuthenticated(ClientSession* client) { DCHECK(CalledOnValidThread()); @@ -265,16 +279,12 @@ void ChromotingHost::OnIncomingSession( } if (login_backoff_.ShouldRejectRequest()) { + LOG(WARNING) << "Rejecting connection due to" + " an overload of failed login attempts."; *response = protocol::SessionManager::OVERLOAD; return; } - // We treat each incoming connection as a failure to authenticate, - // and clear the backoff when a connection successfully - // authenticates. This allows the backoff to protect from parallel - // connection attempts as well as sequential ones. - login_backoff_.InformOfRequest(false); - protocol::SessionConfig config; if (!protocol_config_->Select(session->candidate_config(), &config)) { LOG(WARNING) << "Rejecting connection from " << session->jid() diff --git a/remoting/host/chromoting_host.h b/remoting/host/chromoting_host.h index 3e1fa74..a6642d4 100644 --- a/remoting/host/chromoting_host.h +++ b/remoting/host/chromoting_host.h @@ -115,6 +115,7 @@ class ChromotingHost : public base::NonThreadSafe, //////////////////////////////////////////////////////////////////////////// // ClientSession::EventHandler implementation. + virtual void OnSessionAuthenticating(ClientSession* client) OVERRIDE; virtual bool OnSessionAuthenticated(ClientSession* client) OVERRIDE; virtual void OnSessionChannelsConnected(ClientSession* client) OVERRIDE; virtual void OnSessionAuthenticationFailed(ClientSession* client) OVERRIDE; diff --git a/remoting/host/chromoting_host_unittest.cc b/remoting/host/chromoting_host_unittest.cc index 1410371..5ebe051 100644 --- a/remoting/host/chromoting_host_unittest.cc +++ b/remoting/host/chromoting_host_unittest.cc @@ -28,12 +28,13 @@ using ::remoting::protocol::MockConnectionToClientEventHandler; using ::remoting::protocol::MockHostStub; using ::remoting::protocol::MockSession; using ::remoting::protocol::MockVideoStub; +using ::remoting::protocol::Session; using ::remoting::protocol::SessionConfig; using testing::_; using testing::AnyNumber; -using testing::AtMost; using testing::AtLeast; +using testing::AtMost; using testing::CreateFunctor; using testing::DeleteArg; using testing::DoAll; @@ -44,6 +45,7 @@ using testing::InvokeArgument; using testing::InvokeWithoutArgs; using testing::Return; using testing::ReturnRef; +using testing::SaveArg; using testing::Sequence; namespace remoting { @@ -124,9 +126,10 @@ class ChromotingHostTest : public testing::Test { .Times(AnyNumber()); EXPECT_CALL(*session_unowned1_, SetEventHandler(_)) .Times(AnyNumber()) - .WillRepeatedly(Invoke(this, &ChromotingHostTest::SetEventHandler)); + .WillRepeatedly(SaveArg<0>(&session_unowned1_event_handler_)); EXPECT_CALL(*session_unowned2_, SetEventHandler(_)) - .Times(AnyNumber()); + .Times(AnyNumber()) + .WillRepeatedly(SaveArg<0>(&session_unowned2_event_handler_)); EXPECT_CALL(*session1_, config()) .WillRepeatedly(ReturnRef(session_config1_)); EXPECT_CALL(*session2_, config()) @@ -287,13 +290,15 @@ class ChromotingHostTest : public testing::Test { get_connection(connection_index), protocol::OK); } - void SetEventHandler(protocol::Session::EventHandler* event_handler) { - session_event_handler_ = event_handler; + void NotifyConnectionClosed1() { + if (session_unowned1_event_handler_) { + session_unowned1_event_handler_->OnSessionStateChange(Session::CLOSED); + } } - void NotifyConnectionClosed() { - if (session_event_handler_) { - session_event_handler_->OnSessionStateChange(protocol::Session::CLOSED); + void NotifyConnectionClosed2() { + if (session_unowned2_event_handler_) { + session_unowned2_event_handler_->OnSessionStateChange(Session::CLOSED); } } @@ -424,7 +429,8 @@ class ChromotingHostTest : public testing::Test { scoped_ptr<MockSession> session_unowned2_; // Not owned by a connection. SessionConfig session_unowned_config2_; std::string session_unowned_jid2_; - protocol::Session::EventHandler* session_event_handler_; + protocol::Session::EventHandler* session_unowned1_event_handler_; + protocol::Session::EventHandler* session_unowned2_event_handler_; scoped_ptr<protocol::CandidateSessionConfig> empty_candidate_config_; scoped_ptr<protocol::CandidateSessionConfig> default_candidate_config_; @@ -432,10 +438,16 @@ class ChromotingHostTest : public testing::Test { return (connection_index == 0) ? connection1_ : connection2_; } + // Returns the cached client pointers client1_ or client2_. ClientSession*& get_client(int connection_index) { return (connection_index == 0) ? client1_ : client2_; } + // Returns the list of clients of the host_. + std::list<ClientSession*>& get_clients_from_host() { + return host_->clients_; + } + const std::string& get_session_jid(int connection_index) { return (connection_index == 0) ? session_jid1_ : session_jid2_; } @@ -578,7 +590,7 @@ TEST_F(ChromotingHostTest, IncomingSessionAccepted) { default_candidate_config_.get())); EXPECT_CALL(*session_unowned1_, set_config(_)); EXPECT_CALL(*session_unowned1_, Close()).WillOnce(InvokeWithoutArgs( - this, &ChromotingHostTest::NotifyConnectionClosed)); + this, &ChromotingHostTest::NotifyConnectionClosed1)); EXPECT_CALL(host_status_observer_, OnAccessDenied(_)); EXPECT_CALL(host_status_observer_, OnShutdown()); @@ -593,13 +605,13 @@ TEST_F(ChromotingHostTest, IncomingSessionAccepted) { message_loop_.Run(); } -TEST_F(ChromotingHostTest, IncomingSessionOverload) { +TEST_F(ChromotingHostTest, LoginBackOffUponConnection) { ExpectHostAndSessionManagerStart(); - EXPECT_CALL(*session_unowned1_, candidate_config()).WillOnce(Return( - default_candidate_config_.get())); + EXPECT_CALL(*session_unowned1_, candidate_config()).WillOnce( + Return(default_candidate_config_.get())); EXPECT_CALL(*session_unowned1_, set_config(_)); - EXPECT_CALL(*session_unowned1_, Close()).WillOnce(InvokeWithoutArgs( - this, &ChromotingHostTest::NotifyConnectionClosed)); + EXPECT_CALL(*session_unowned1_, Close()).WillOnce( + InvokeWithoutArgs(this, &ChromotingHostTest::NotifyConnectionClosed1)); EXPECT_CALL(host_status_observer_, OnAccessDenied(_)); EXPECT_CALL(host_status_observer_, OnShutdown()); @@ -607,9 +619,11 @@ TEST_F(ChromotingHostTest, IncomingSessionOverload) { protocol::SessionManager::IncomingSessionResponse response = protocol::SessionManager::DECLINE; + host_->OnIncomingSession(session_unowned1_.release(), &response); EXPECT_EQ(protocol::SessionManager::ACCEPT, response); + host_->OnSessionAuthenticating(get_clients_from_host().front()); host_->OnIncomingSession(session_unowned2_.get(), &response); EXPECT_EQ(protocol::SessionManager::OVERLOAD, response); @@ -617,6 +631,46 @@ TEST_F(ChromotingHostTest, IncomingSessionOverload) { message_loop_.Run(); } +TEST_F(ChromotingHostTest, LoginBackOffUponAuthenticating) { + Expectation start = ExpectHostAndSessionManagerStart(); + EXPECT_CALL(*session_unowned1_, candidate_config()).WillOnce( + Return(default_candidate_config_.get())); + EXPECT_CALL(*session_unowned1_, set_config(_)); + EXPECT_CALL(*session_unowned1_, Close()).WillOnce( + InvokeWithoutArgs(this, &ChromotingHostTest::NotifyConnectionClosed1)); + + EXPECT_CALL(*session_unowned2_, candidate_config()).WillOnce( + Return(default_candidate_config_.get())); + EXPECT_CALL(*session_unowned2_, set_config(_)); + EXPECT_CALL(*session_unowned2_, Close()).WillOnce( + InvokeWithoutArgs(this, &ChromotingHostTest::NotifyConnectionClosed2)); + + EXPECT_CALL(host_status_observer_, OnShutdown()); + + host_->Start(xmpp_login_); + + protocol::SessionManager::IncomingSessionResponse response = + protocol::SessionManager::DECLINE; + + host_->OnIncomingSession(session_unowned1_.release(), &response); + EXPECT_EQ(protocol::SessionManager::ACCEPT, response); + + host_->OnIncomingSession(session_unowned2_.release(), &response); + EXPECT_EQ(protocol::SessionManager::ACCEPT, response); + + // This will set the backoff. + host_->OnSessionAuthenticating(get_clients_from_host().front()); + + // This should disconnect client2. + host_->OnSessionAuthenticating(get_clients_from_host().back()); + + // Verify that the host only has 1 client at this point. + EXPECT_EQ(get_clients_from_host().size(), 1U); + + ShutdownHost(); + message_loop_.Run(); +} + TEST_F(ChromotingHostTest, OnSessionRouteChange) { std::string channel_name("ChannelName"); protocol::TransportRoute route; diff --git a/remoting/host/client_session.cc b/remoting/host/client_session.cc index 7a3c76c..2661e0e 100644 --- a/remoting/host/client_session.cc +++ b/remoting/host/client_session.cc @@ -210,6 +210,11 @@ void ClientSession::DeliverClientMessage( << message.type() << ": " << message.data(); } +void ClientSession::OnConnectionAuthenticating( + protocol::ConnectionToClient* connection) { + event_handler_->OnSessionAuthenticating(this); +} + void ClientSession::OnConnectionAuthenticated( protocol::ConnectionToClient* connection) { DCHECK(CalledOnValidThread()); diff --git a/remoting/host/client_session.h b/remoting/host/client_session.h index ef75b25..f892329 100644 --- a/remoting/host/client_session.h +++ b/remoting/host/client_session.h @@ -54,6 +54,9 @@ class ClientSession // Callback interface for passing events to the ChromotingHost. class EventHandler { public: + // Called after authentication has started. + virtual void OnSessionAuthenticating(ClientSession* client) = 0; + // Called after authentication has finished successfully. Returns true if // the connection is allowed, or false otherwise. virtual bool OnSessionAuthenticated(ClientSession* client) = 0; @@ -115,6 +118,8 @@ class ClientSession const protocol::ExtensionMessage& message) OVERRIDE; // protocol::ConnectionToClient::EventHandler interface. + virtual void OnConnectionAuthenticating( + protocol::ConnectionToClient* connection) OVERRIDE; virtual void OnConnectionAuthenticated( protocol::ConnectionToClient* connection) OVERRIDE; virtual void OnConnectionChannelsConnected( diff --git a/remoting/host/host_mock_objects.h b/remoting/host/host_mock_objects.h index d365695..d916f04 100644 --- a/remoting/host/host_mock_objects.h +++ b/remoting/host/host_mock_objects.h @@ -68,6 +68,7 @@ class MockClientSessionEventHandler : public ClientSession::EventHandler { MockClientSessionEventHandler(); virtual ~MockClientSessionEventHandler(); + MOCK_METHOD1(OnSessionAuthenticating, void(ClientSession* client)); MOCK_METHOD1(OnSessionAuthenticated, bool(ClientSession* client)); MOCK_METHOD1(OnSessionChannelsConnected, void(ClientSession* client)); MOCK_METHOD1(OnSessionAuthenticationFailed, void(ClientSession* client)); diff --git a/remoting/host/pam_authorization_factory_posix.cc b/remoting/host/pam_authorization_factory_posix.cc index c89c71f..eef0c48 100644 --- a/remoting/host/pam_authorization_factory_posix.cc +++ b/remoting/host/pam_authorization_factory_posix.cc @@ -24,6 +24,7 @@ class PamAuthorizer : public protocol::Authenticator { // protocol::Authenticator interface. virtual State state() const OVERRIDE; + virtual bool started() const OVERRIDE; virtual RejectionReason rejection_reason() const OVERRIDE; virtual void ProcessMessage(const buzz::XmlElement* message, const base::Closure& resume_callback) OVERRIDE; @@ -62,6 +63,10 @@ protocol::Authenticator::State PamAuthorizer::state() const { } } +bool PamAuthorizer::started() const { + return underlying_->started(); +} + protocol::Authenticator::RejectionReason PamAuthorizer::rejection_reason() const { if (local_login_status_ == DISALLOWED) { diff --git a/remoting/protocol/authenticator.h b/remoting/protocol/authenticator.h index 28288e5..1210989 100644 --- a/remoting/protocol/authenticator.h +++ b/remoting/protocol/authenticator.h @@ -89,6 +89,11 @@ class Authenticator { // Returns current state of the authenticator. virtual State state() const = 0; + // Returns whether authentication has started. The chromoting host uses this + // method to starts the back off process to prevent malicious clients from + // guessing the PIN by spamming the host with auth requests. + virtual bool started() const = 0; + // Returns rejection reason. Can be called only when in REJECTED state. virtual RejectionReason rejection_reason() const = 0; diff --git a/remoting/protocol/authenticator_test_base.cc b/remoting/protocol/authenticator_test_base.cc index 25b0efe..1b65dbe 100644 --- a/remoting/protocol/authenticator_test_base.cc +++ b/remoting/protocol/authenticator_test_base.cc @@ -60,22 +60,43 @@ void AuthenticatorTestBase::SetUp() { } void AuthenticatorTestBase::RunAuthExchange() { - ContinueAuthExchangeWith(client_.get(), host_.get()); + ContinueAuthExchangeWith(client_.get(), + host_.get(), + client_->started(), + host_->started()); } void AuthenticatorTestBase::RunHostInitiatedAuthExchange() { - ContinueAuthExchangeWith(host_.get(), client_.get()); + ContinueAuthExchangeWith(host_.get(), + client_.get(), + host_->started(), + client_->started()); } // static +// This function sends a message from the sender and receiver and recursively +// calls itself to the send the next message from the receiver to the sender +// untils the authentication completes. void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender, - Authenticator* receiver) { + Authenticator* receiver, + bool sender_started, + bool receiver_started) { scoped_ptr<buzz::XmlElement> message; ASSERT_NE(Authenticator::WAITING_MESSAGE, sender->state()); if (sender->state() == Authenticator::ACCEPTED || sender->state() == Authenticator::REJECTED) return; - // Pass message from client to host. + + // Verify that once the started flag for either party is set to true, + // it should always stay true. + if (receiver_started) { + ASSERT_TRUE(receiver->started()); + } + + if (sender_started) { + ASSERT_TRUE(sender->started()); + } + ASSERT_EQ(Authenticator::MESSAGE_READY, sender->state()); message = sender->GetNextMessage(); ASSERT_TRUE(message.get()); @@ -84,7 +105,8 @@ void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender, ASSERT_EQ(Authenticator::WAITING_MESSAGE, receiver->state()); receiver->ProcessMessage(message.get(), base::Bind( &AuthenticatorTestBase::ContinueAuthExchangeWith, - base::Unretained(receiver), base::Unretained(sender))); + base::Unretained(receiver), base::Unretained(sender), + receiver->started(), sender->started())); } void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) { diff --git a/remoting/protocol/authenticator_test_base.h b/remoting/protocol/authenticator_test_base.h index 9b299db..e20774a 100644 --- a/remoting/protocol/authenticator_test_base.h +++ b/remoting/protocol/authenticator_test_base.h @@ -41,7 +41,9 @@ class AuthenticatorTestBase : public testing::Test { }; static void ContinueAuthExchangeWith(Authenticator* sender, - Authenticator* receiver); + Authenticator* receiver, + bool sender_started, + bool receiver_srated); virtual void SetUp() OVERRIDE; void RunAuthExchange(); void RunHostInitiatedAuthExchange(); diff --git a/remoting/protocol/connection_to_client.cc b/remoting/protocol/connection_to_client.cc index dd32452..a6d3143 100644 --- a/remoting/protocol/connection_to_client.cc +++ b/remoting/protocol/connection_to_client.cc @@ -112,7 +112,9 @@ void ConnectionToClient::OnSessionStateChange(Session::State state) { case Session::CONNECTED: // Don't care about these events. break; - + case Session::AUTHENTICATING: + handler_->OnConnectionAuthenticating(this); + break; case Session::AUTHENTICATED: // Initialize channels. control_dispatcher_.reset(new HostControlDispatcher()); diff --git a/remoting/protocol/connection_to_client.h b/remoting/protocol/connection_to_client.h index 9865307..9a64dcd 100644 --- a/remoting/protocol/connection_to_client.h +++ b/remoting/protocol/connection_to_client.h @@ -38,6 +38,9 @@ class ConnectionToClient : public base::NonThreadSafe, public: class EventHandler { public: + // Called when the network connection is authenticating + virtual void OnConnectionAuthenticating(ConnectionToClient* connection) = 0; + // Called when the network connection is authenticated. virtual void OnConnectionAuthenticated(ConnectionToClient* connection) = 0; diff --git a/remoting/protocol/connection_to_host.cc b/remoting/protocol/connection_to_host.cc index cdaf4b6..978da56 100644 --- a/remoting/protocol/connection_to_host.cc +++ b/remoting/protocol/connection_to_host.cc @@ -152,6 +152,7 @@ void ConnectionToHost::OnSessionStateChange( case Session::CONNECTING: case Session::ACCEPTING: case Session::CONNECTED: + case Session::AUTHENTICATING: // Don't care about these events. break; diff --git a/remoting/protocol/fake_authenticator.cc b/remoting/protocol/fake_authenticator.cc index 67f5d35..67c1cca 100644 --- a/remoting/protocol/fake_authenticator.cc +++ b/remoting/protocol/fake_authenticator.cc @@ -84,12 +84,17 @@ FakeAuthenticator::FakeAuthenticator( round_trips_(round_trips), action_(action), async_(async), - messages_(0) { + messages_(0), + messages_till_started_(0) { } FakeAuthenticator::~FakeAuthenticator() { } +void FakeAuthenticator::set_messages_till_started(int messages) { + messages_till_started_ = messages; +} + Authenticator::State FakeAuthenticator::state() const { EXPECT_LE(messages_, round_trips_ * 2); if (messages_ >= round_trips_ * 2) { @@ -116,6 +121,10 @@ Authenticator::State FakeAuthenticator::state() const { } } +bool FakeAuthenticator::started() const { + return messages_ > messages_till_started_; +} + Authenticator::RejectionReason FakeAuthenticator::rejection_reason() const { EXPECT_EQ(REJECTED, state()); return INVALID_CREDENTIALS; @@ -153,8 +162,10 @@ FakeAuthenticator::CreateChannelAuthenticator() const { } FakeHostAuthenticatorFactory::FakeHostAuthenticatorFactory( - int round_trips, FakeAuthenticator::Action action, bool async) + int round_trips, int messages_till_started, + FakeAuthenticator::Action action, bool async) : round_trips_(round_trips), + messages_till_started_(messages_till_started), action_(action), async_(async) { } @@ -165,8 +176,12 @@ scoped_ptr<Authenticator> FakeHostAuthenticatorFactory::CreateAuthenticator( const std::string& local_jid, const std::string& remote_jid, const buzz::XmlElement* first_message) { - return scoped_ptr<Authenticator>(new FakeAuthenticator( - FakeAuthenticator::HOST, round_trips_, action_, async_)); + FakeAuthenticator* authenticator = new FakeAuthenticator( + FakeAuthenticator::HOST, round_trips_, action_, async_); + authenticator->set_messages_till_started(messages_till_started_); + + scoped_ptr<Authenticator> result(authenticator); + return result.Pass(); } } // namespace protocol diff --git a/remoting/protocol/fake_authenticator.h b/remoting/protocol/fake_authenticator.h index a5a2de8..b74654d0 100644 --- a/remoting/protocol/fake_authenticator.h +++ b/remoting/protocol/fake_authenticator.h @@ -59,10 +59,16 @@ class FakeAuthenticator : public Authenticator { }; FakeAuthenticator(Type type, int round_trips, Action action, bool async); + virtual ~FakeAuthenticator(); + // Set the number of messages that the authenticator needs to process before + // started() returns true. Default to 0. + void set_messages_till_started(int messages); + // Authenticator interface. virtual State state() const OVERRIDE; + virtual bool started() const OVERRIDE; virtual RejectionReason rejection_reason() const OVERRIDE; virtual void ProcessMessage(const buzz::XmlElement* message, const base::Closure& resume_callback) OVERRIDE; @@ -78,6 +84,9 @@ class FakeAuthenticator : public Authenticator { // Total number of messages that have been processed. int messages_; + // Number of messages that the authenticator needs to process before started() + // returns true. Default to 0. + int messages_till_started_; DISALLOW_COPY_AND_ASSIGN(FakeAuthenticator); }; @@ -85,7 +94,8 @@ class FakeAuthenticator : public Authenticator { class FakeHostAuthenticatorFactory : public AuthenticatorFactory { public: FakeHostAuthenticatorFactory( - int round_trips, FakeAuthenticator::Action action, bool async); + int round_trips, int messages_till_start, + FakeAuthenticator::Action action, bool async); virtual ~FakeHostAuthenticatorFactory(); // AuthenticatorFactory interface. @@ -96,6 +106,7 @@ class FakeHostAuthenticatorFactory : public AuthenticatorFactory { private: int round_trips_; + int messages_till_started_; FakeAuthenticator::Action action_; bool async_; diff --git a/remoting/protocol/jingle_session.cc b/remoting/protocol/jingle_session.cc index 2aecd4c..57883cb 100644 --- a/remoting/protocol/jingle_session.cc +++ b/remoting/protocol/jingle_session.cc @@ -6,8 +6,10 @@ #include "base/bind.h" #include "base/rand_util.h" +#include "base/single_thread_task_runner.h" #include "base/stl_util.h" #include "base/strings/string_number_conversions.h" +#include "base/thread_task_runner_handle.h" #include "base/time/time.h" #include "remoting/base/constants.h" #include "remoting/jingle_glue/iq_sender.h" @@ -63,7 +65,8 @@ JingleSession::JingleSession(JingleSessionManager* session_manager) event_handler_(NULL), state_(INITIALIZING), error_(OK), - config_is_set_(false) { + config_is_set_(false), + weak_factory_(this) { } JingleSession::~JingleSession() { @@ -181,9 +184,10 @@ void JingleSession::ContinueAcceptIncomingConnection() { SetState(AUTHENTICATED); } else { DCHECK_EQ(authenticator_->state(), Authenticator::WAITING_MESSAGE); + if (authenticator_->started()) { + SetState(AUTHENTICATING); + } } - - return; } const std::string& JingleSession::jid() { @@ -485,7 +489,7 @@ void JingleSession::OnSessionInfo(const JingleMessage& message, return; } - if (state_ != CONNECTED || + if ((state_ != CONNECTED && state_ != AUTHENTICATING) || authenticator_->state() != Authenticator::WAITING_MESSAGE) { LOG(WARNING) << "Received unexpected authenticator message " << message.info->Str(); @@ -517,8 +521,7 @@ void JingleSession::ProcessTransportInfo(const JingleMessage& message) { void JingleSession::OnTerminate(const JingleMessage& message, const ReplyCallback& reply_callback) { - if (state_ != CONNECTING && state_ != ACCEPTING && state_ != CONNECTED && - state_ != AUTHENTICATED) { + if (!is_session_active()) { LOG(WARNING) << "Received unexpected session-terminate message."; reply_callback.Run(JingleMessageReply::UNEXPECTED_REQUEST); return; @@ -577,7 +580,7 @@ void JingleSession::ProcessAuthenticationStep() { DCHECK(CalledOnValidThread()); DCHECK_NE(authenticator_->state(), Authenticator::PROCESSING_MESSAGE); - if (state_ != CONNECTED) { + if (state_ != CONNECTED && state_ != AUTHENTICATING) { DCHECK(state_ == FAILED || state_ == CLOSED); // The remote host closed the connection while the authentication was being // processed asynchronously, nothing to do. @@ -592,6 +595,21 @@ void JingleSession::ProcessAuthenticationStep() { } DCHECK_NE(authenticator_->state(), Authenticator::MESSAGE_READY); + // The current JingleSession object can be destroyed by event_handler of + // SetState(AUTHENTICATING) and cause subsequent dereferencing of the this + // pointer to crash. To protect against it, we run ContinueAuthenticationStep + // asychronously using a weak pointer. + base::ThreadTaskRunnerHandle::Get()->PostTask( + FROM_HERE, + base::Bind(&JingleSession::ContinueAuthenticationStep, + weak_factory_.GetWeakPtr())); + + if (authenticator_->started()) { + SetState(AUTHENTICATING); + } +} + +void JingleSession::ContinueAuthenticationStep() { if (authenticator_->state() == Authenticator::ACCEPTED) { SetState(AUTHENTICATED); } else if (authenticator_->state() == Authenticator::REJECTED) { @@ -603,8 +621,7 @@ void JingleSession::ProcessAuthenticationStep() { void JingleSession::CloseInternal(ErrorCode error) { DCHECK(CalledOnValidThread()); - if (state_ == CONNECTING || state_ == ACCEPTING || state_ == CONNECTED || - state_ == AUTHENTICATED) { + if (is_session_active()) { // Send session-terminate message with the appropriate error code. JingleMessage::Reason reason; switch (error) { @@ -655,5 +672,10 @@ void JingleSession::SetState(State new_state) { } } +bool JingleSession::is_session_active() { + return state_ == CONNECTING || state_ == ACCEPTING || state_ == CONNECTED || + state_ == AUTHENTICATING || state_ == AUTHENTICATED; +} + } // namespace protocol } // namespace remoting diff --git a/remoting/protocol/jingle_session.h b/remoting/protocol/jingle_session.h index 189cb53..3b704fb 100644 --- a/remoting/protocol/jingle_session.h +++ b/remoting/protocol/jingle_session.h @@ -132,9 +132,13 @@ class JingleSession : public Session, // Called after the initial incoming authenticator message is processed. void ContinueAcceptIncomingConnection(); + // Called after subsequent authenticator messages are processed. void ProcessAuthenticationStep(); + // Called after the authenticating step is finished. + void ContinueAuthenticationStep(); + // Terminates the session and sends session-terminate if it is // necessary. |error| specifies the error code in case when the // session is being closed due to an error. @@ -143,6 +147,9 @@ class JingleSession : public Session, // Sets |state_| to |new_state| and calls state change callback. void SetState(State new_state); + // Returns true if the state of the session is not CLOSED or FAILED + bool is_session_active(); + JingleSessionManager* session_manager_; std::string peer_jid_; scoped_ptr<CandidateSessionConfig> candidate_config_; @@ -172,6 +179,8 @@ class JingleSession : public Session, // Pending remote candidates, received before the local channels were created. std::list<JingleMessage::NamedCandidate> pending_remote_candidates_; + base::WeakPtrFactory<JingleSession> weak_factory_; + DISALLOW_COPY_AND_ASSIGN(JingleSession); }; diff --git a/remoting/protocol/jingle_session_unittest.cc b/remoting/protocol/jingle_session_unittest.cc index dcca077..2f95100 100644 --- a/remoting/protocol/jingle_session_unittest.cc +++ b/remoting/protocol/jingle_session_unittest.cc @@ -105,6 +105,10 @@ class JingleSessionTest : public testing::Test { session->set_config(SessionConfig::ForTest()); } + void DeleteSession() { + host_session_.reset(); + } + void OnClientChannelCreated(scoped_ptr<net::StreamSocket> socket) { client_channel_callback_.OnDone(socket.get()); client_socket_ = socket.Pass(); @@ -132,7 +136,7 @@ class JingleSessionTest : public testing::Test { client_session_.reset(); } - void CreateSessionManagers(int auth_round_trips, + void CreateSessionManagers(int auth_round_trips, int messages_till_start, FakeAuthenticator::Action auth_action) { host_signal_strategy_.reset(new FakeSignalStrategy(kHostJid)); client_signal_strategy_.reset(new FakeSignalStrategy(kClientJid)); @@ -153,7 +157,8 @@ class JingleSessionTest : public testing::Test { host_server_->Init(host_signal_strategy_.get(), &host_server_listener_); scoped_ptr<AuthenticatorFactory> factory( - new FakeHostAuthenticatorFactory(auth_round_trips, auth_action, true)); + new FakeHostAuthenticatorFactory(auth_round_trips, + messages_till_start, auth_action, true)); host_server_->set_authenticator_factory(factory.Pass()); EXPECT_CALL(client_server_listener_, OnSessionManagerReady()) @@ -169,6 +174,11 @@ class JingleSessionTest : public testing::Test { &client_server_listener_); } + void CreateSessionManagers(int auth_round_trips, + FakeAuthenticator::Action auth_action) { + CreateSessionManagers(auth_round_trips, 0, auth_action); + } + void CloseSessionManager() { if (host_server_.get()) { host_server_->Close(); @@ -196,6 +206,9 @@ class JingleSessionTest : public testing::Test { EXPECT_CALL(host_session_event_handler_, OnSessionStateChange(Session::CONNECTED)) .Times(AtMost(1)); + EXPECT_CALL(host_session_event_handler_, + OnSessionStateChange(Session::AUTHENTICATING)) + .Times(AtMost(1)); if (expect_fail) { EXPECT_CALL(host_session_event_handler_, OnSessionStateChange(Session::FAILED)) @@ -217,6 +230,9 @@ class JingleSessionTest : public testing::Test { EXPECT_CALL(client_session_event_handler_, OnSessionStateChange(Session::CONNECTED)) .Times(AtMost(1)); + EXPECT_CALL(client_session_event_handler_, + OnSessionStateChange(Session::AUTHENTICATING)) + .Times(AtMost(1)); if (expect_fail) { EXPECT_CALL(client_session_event_handler_, OnSessionStateChange(Session::FAILED)) @@ -375,6 +391,60 @@ TEST_F(JingleSessionTest, TestStreamChannel) { tester.CheckResults(); } +TEST_F(JingleSessionTest, DeleteSessionOnIncomingConnection) { + CreateSessionManagers(3, FakeAuthenticator::ACCEPT); + + EXPECT_CALL(host_server_listener_, OnIncomingSession(_, _)) + .WillOnce(DoAll( + WithArg<0>(Invoke(this, &JingleSessionTest::SetHostSession)), + SetArgumentPointee<1>(protocol::SessionManager::ACCEPT))); + + EXPECT_CALL(host_session_event_handler_, + OnSessionStateChange(Session::CONNECTED)) + .Times(AtMost(1)); + + EXPECT_CALL(host_session_event_handler_, + OnSessionStateChange(Session::AUTHENTICATING)) + .WillOnce(InvokeWithoutArgs(this, &JingleSessionTest::DeleteSession)); + + scoped_ptr<Authenticator> authenticator(new FakeAuthenticator( + FakeAuthenticator::CLIENT, 3, FakeAuthenticator::ACCEPT, true)); + + client_session_ = client_server_->Connect( + kHostJid, authenticator.Pass(), + CandidateSessionConfig::CreateDefault()); + + base::RunLoop().RunUntilIdle(); +} + +TEST_F(JingleSessionTest, DeleteSessionOnAuth) { + // Same as the previous test, but set messages_till_started to 2 in + // CreateSessionManagers so that the session will goes into the + // AUTHENTICATING state after two message exchanges. + CreateSessionManagers(3, 2, FakeAuthenticator::ACCEPT); + + EXPECT_CALL(host_server_listener_, OnIncomingSession(_, _)) + .WillOnce(DoAll( + WithArg<0>(Invoke(this, &JingleSessionTest::SetHostSession)), + SetArgumentPointee<1>(protocol::SessionManager::ACCEPT))); + + EXPECT_CALL(host_session_event_handler_, + OnSessionStateChange(Session::CONNECTED)) + .Times(AtMost(1)); + + EXPECT_CALL(host_session_event_handler_, + OnSessionStateChange(Session::AUTHENTICATING)) + .WillOnce(InvokeWithoutArgs(this, &JingleSessionTest::DeleteSession)); + + scoped_ptr<Authenticator> authenticator(new FakeAuthenticator( + FakeAuthenticator::CLIENT, 3, FakeAuthenticator::ACCEPT, true)); + + client_session_ = client_server_->Connect( + kHostJid, authenticator.Pass(), + CandidateSessionConfig::CreateDefault()); + base::RunLoop().RunUntilIdle(); +} + // Verify that data can be sent over a multiplexed channel. TEST_F(JingleSessionTest, TestMuxStreamChannel) { CreateSessionManagers(1, FakeAuthenticator::ACCEPT); diff --git a/remoting/protocol/me2me_host_authenticator_factory.cc b/remoting/protocol/me2me_host_authenticator_factory.cc index 24c0ca4..7e407c8 100644 --- a/remoting/protocol/me2me_host_authenticator_factory.cc +++ b/remoting/protocol/me2me_host_authenticator_factory.cc @@ -30,6 +30,10 @@ class RejectingAuthenticator : public Authenticator { return state_; } + virtual bool started() const OVERRIDE { + return true; + } + virtual RejectionReason rejection_reason() const OVERRIDE { DCHECK_EQ(state_, REJECTED); return INVALID_CREDENTIALS; diff --git a/remoting/protocol/negotiating_authenticator_base.cc b/remoting/protocol/negotiating_authenticator_base.cc index 30bd8ba..5c61d30 100644 --- a/remoting/protocol/negotiating_authenticator_base.cc +++ b/remoting/protocol/negotiating_authenticator_base.cc @@ -40,6 +40,13 @@ Authenticator::State NegotiatingAuthenticatorBase::state() const { return state_; } +bool NegotiatingAuthenticatorBase::started() const { + if (!current_authenticator_) { + return false; + } + return current_authenticator_->started(); +} + Authenticator::RejectionReason NegotiatingAuthenticatorBase::rejection_reason() const { return rejection_reason_; diff --git a/remoting/protocol/negotiating_authenticator_base.h b/remoting/protocol/negotiating_authenticator_base.h index 1f80967..bcb005f2 100644 --- a/remoting/protocol/negotiating_authenticator_base.h +++ b/remoting/protocol/negotiating_authenticator_base.h @@ -64,6 +64,7 @@ class NegotiatingAuthenticatorBase : public Authenticator { // Authenticator interface. virtual State state() const OVERRIDE; + virtual bool started() const OVERRIDE; virtual RejectionReason rejection_reason() const OVERRIDE; virtual scoped_ptr<ChannelAuthenticator> CreateChannelAuthenticator() const OVERRIDE; diff --git a/remoting/protocol/pairing_authenticator_base.cc b/remoting/protocol/pairing_authenticator_base.cc index 47ca275..7435e55 100644 --- a/remoting/protocol/pairing_authenticator_base.cc +++ b/remoting/protocol/pairing_authenticator_base.cc @@ -38,6 +38,13 @@ Authenticator::State PairingAuthenticatorBase::state() const { return v2_authenticator_->state(); } +bool PairingAuthenticatorBase::started() const { + if (!v2_authenticator_) { + return false; + } + return v2_authenticator_->started(); +} + Authenticator::RejectionReason PairingAuthenticatorBase::rejection_reason() const { if (!v2_authenticator_) { diff --git a/remoting/protocol/pairing_authenticator_base.h b/remoting/protocol/pairing_authenticator_base.h index 2a4bc4e..4cb042c 100644 --- a/remoting/protocol/pairing_authenticator_base.h +++ b/remoting/protocol/pairing_authenticator_base.h @@ -43,6 +43,7 @@ class PairingAuthenticatorBase : public Authenticator { // Authenticator interface. virtual State state() const OVERRIDE; + virtual bool started() const OVERRIDE; virtual RejectionReason rejection_reason() const OVERRIDE; virtual void ProcessMessage(const buzz::XmlElement* message, const base::Closure& resume_callback) OVERRIDE; diff --git a/remoting/protocol/protocol_mock_objects.h b/remoting/protocol/protocol_mock_objects.h index fecf1d8..07e03306 100644 --- a/remoting/protocol/protocol_mock_objects.h +++ b/remoting/protocol/protocol_mock_objects.h @@ -52,6 +52,8 @@ class MockConnectionToClientEventHandler : MockConnectionToClientEventHandler(); virtual ~MockConnectionToClientEventHandler(); + MOCK_METHOD1(OnConnectionAuthenticating, + void(ConnectionToClient* connection)); MOCK_METHOD1(OnConnectionAuthenticated, void(ConnectionToClient* connection)); MOCK_METHOD1(OnConnectionChannelsConnected, void(ConnectionToClient* connection)); diff --git a/remoting/protocol/session.h b/remoting/protocol/session.h index a542283..232a1d04 100644 --- a/remoting/protocol/session.h +++ b/remoting/protocol/session.h @@ -38,6 +38,9 @@ class Session { // Session has been accepted and is pending authentication. CONNECTED, + // Session has started authenticating. + AUTHENTICATING, + // Session has been connected and authenticated. AUTHENTICATED, @@ -54,8 +57,8 @@ class Session { virtual ~EventHandler() {} // Called after session state has changed. It is safe to destroy - // the session from within the handler if |state| is CLOSED or - // FAILED. + // the session from within the handler if |state| is AUTHENTICATING + // or CLOSED or FAILED. virtual void OnSessionStateChange(State state) = 0; // Called whenever route for the channel specified with diff --git a/remoting/protocol/third_party_authenticator_base.cc b/remoting/protocol/third_party_authenticator_base.cc index 9d6be72..019e788 100644 --- a/remoting/protocol/third_party_authenticator_base.cc +++ b/remoting/protocol/third_party_authenticator_base.cc @@ -28,12 +28,17 @@ const buzz::StaticQName ThirdPartyAuthenticatorBase::kTokenTag = ThirdPartyAuthenticatorBase::ThirdPartyAuthenticatorBase( Authenticator::State initial_state) : token_state_(initial_state), + started_(false), rejection_reason_(INVALID_CREDENTIALS) { } ThirdPartyAuthenticatorBase::~ThirdPartyAuthenticatorBase() { } +bool ThirdPartyAuthenticatorBase::started() const { + return started_; +} + Authenticator::State ThirdPartyAuthenticatorBase::state() const { if (token_state_ == ACCEPTED) return underlying_->state(); @@ -74,9 +79,10 @@ scoped_ptr<buzz::XmlElement> ThirdPartyAuthenticatorBase::GetNextMessage() { message = CreateEmptyAuthenticatorMessage(); } - if (token_state_ == MESSAGE_READY) + if (token_state_ == MESSAGE_READY) { AddTokenElements(message.get()); - + started_ = true; + } return message.Pass(); } diff --git a/remoting/protocol/third_party_authenticator_base.h b/remoting/protocol/third_party_authenticator_base.h index 0db203d..7141be8 100644 --- a/remoting/protocol/third_party_authenticator_base.h +++ b/remoting/protocol/third_party_authenticator_base.h @@ -36,6 +36,7 @@ class ThirdPartyAuthenticatorBase : public Authenticator { // Authenticator interface. virtual State state() const OVERRIDE; + virtual bool started() const OVERRIDE; virtual RejectionReason rejection_reason() const OVERRIDE; virtual void ProcessMessage(const buzz::XmlElement* message, const base::Closure& resume_callback) OVERRIDE; @@ -66,6 +67,7 @@ class ThirdPartyAuthenticatorBase : public Authenticator { scoped_ptr<Authenticator> underlying_; State token_state_; + bool started_; RejectionReason rejection_reason_; private: diff --git a/remoting/protocol/v2_authenticator.cc b/remoting/protocol/v2_authenticator.cc index ee5c9d1..1b13c7c 100644 --- a/remoting/protocol/v2_authenticator.cc +++ b/remoting/protocol/v2_authenticator.cc @@ -62,6 +62,7 @@ V2Authenticator::V2Authenticator( : certificate_sent_(false), key_exchange_impl_(type, shared_secret), state_(initial_state), + started_(false), rejection_reason_(INVALID_CREDENTIALS) { pending_messages_.push(key_exchange_impl_.GetMessage()); } @@ -75,6 +76,10 @@ Authenticator::State V2Authenticator::state() const { return state_; } +bool V2Authenticator::started() const { + return started_; +} + Authenticator::RejectionReason V2Authenticator::rejection_reason() const { DCHECK_EQ(state(), REJECTED); return rejection_reason_; @@ -127,6 +132,7 @@ void V2Authenticator::ProcessMessageInternal(const buzz::XmlElement* message) { P224EncryptedKeyExchange::Result result = key_exchange_impl_.ProcessMessage(spake_message); + started_ = true; switch (result) { case P224EncryptedKeyExchange::kResultPending: pending_messages_.push(key_exchange_impl_.GetMessage()); @@ -143,7 +149,6 @@ void V2Authenticator::ProcessMessageInternal(const buzz::XmlElement* message) { return; } } - state_ = MESSAGE_READY; } diff --git a/remoting/protocol/v2_authenticator.h b/remoting/protocol/v2_authenticator.h index 0675cb2..b52a3e1 100644 --- a/remoting/protocol/v2_authenticator.h +++ b/remoting/protocol/v2_authenticator.h @@ -38,6 +38,7 @@ class V2Authenticator : public Authenticator { // Authenticator interface. virtual State state() const OVERRIDE; + virtual bool started() const OVERRIDE; virtual RejectionReason rejection_reason() const OVERRIDE; virtual void ProcessMessage(const buzz::XmlElement* message, const base::Closure& resume_callback) OVERRIDE; @@ -67,6 +68,7 @@ class V2Authenticator : public Authenticator { // Used for both host and client authenticators. crypto::P224EncryptedKeyExchange key_exchange_impl_; State state_; + bool started_; RejectionReason rejection_reason_; std::queue<std::string> pending_messages_; std::string auth_key_; |