summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--remoting/host/chromoting_host.cc22
-rw-r--r--remoting/host/chromoting_host.h1
-rw-r--r--remoting/host/chromoting_host_unittest.cc84
-rw-r--r--remoting/host/client_session.cc5
-rw-r--r--remoting/host/client_session.h5
-rw-r--r--remoting/host/host_mock_objects.h1
-rw-r--r--remoting/host/pam_authorization_factory_posix.cc5
-rw-r--r--remoting/protocol/authenticator.h5
-rw-r--r--remoting/protocol/authenticator_test_base.cc32
-rw-r--r--remoting/protocol/authenticator_test_base.h4
-rw-r--r--remoting/protocol/connection_to_client.cc4
-rw-r--r--remoting/protocol/connection_to_client.h3
-rw-r--r--remoting/protocol/connection_to_host.cc1
-rw-r--r--remoting/protocol/fake_authenticator.cc23
-rw-r--r--remoting/protocol/fake_authenticator.h13
-rw-r--r--remoting/protocol/jingle_session.cc40
-rw-r--r--remoting/protocol/jingle_session.h9
-rw-r--r--remoting/protocol/jingle_session_unittest.cc74
-rw-r--r--remoting/protocol/me2me_host_authenticator_factory.cc4
-rw-r--r--remoting/protocol/negotiating_authenticator_base.cc7
-rw-r--r--remoting/protocol/negotiating_authenticator_base.h1
-rw-r--r--remoting/protocol/pairing_authenticator_base.cc7
-rw-r--r--remoting/protocol/pairing_authenticator_base.h1
-rw-r--r--remoting/protocol/protocol_mock_objects.h2
-rw-r--r--remoting/protocol/session.h7
-rw-r--r--remoting/protocol/third_party_authenticator_base.cc10
-rw-r--r--remoting/protocol/third_party_authenticator_base.h2
-rw-r--r--remoting/protocol/v2_authenticator.cc7
-rw-r--r--remoting/protocol/v2_authenticator.h2
29 files changed, 332 insertions, 49 deletions
diff --git a/remoting/host/chromoting_host.cc b/remoting/host/chromoting_host.cc
index e4bdae7..469c5b2 100644
--- a/remoting/host/chromoting_host.cc
+++ b/remoting/host/chromoting_host.cc
@@ -171,6 +171,20 @@ void ChromotingHost::SetMaximumSessionDuration(
////////////////////////////////////////////////////////////////////////////
// protocol::ClientSession::EventHandler implementation.
+void ChromotingHost::OnSessionAuthenticating(ClientSession* client) {
+ // We treat each incoming connection as a failure to authenticate,
+ // and clear the backoff when a connection successfully
+ // authenticates. This allows the backoff to protect from parallel
+ // connection attempts as well as sequential ones.
+ if (login_backoff_.ShouldRejectRequest()) {
+ LOG(WARNING) << "Disconnecting client " << client->client_jid() << " due to"
+ " an overload of failed login attempts.";
+ client->DisconnectSession();
+ return;
+ }
+ login_backoff_.InformOfRequest(false);
+}
+
bool ChromotingHost::OnSessionAuthenticated(ClientSession* client) {
DCHECK(CalledOnValidThread());
@@ -265,16 +279,12 @@ void ChromotingHost::OnIncomingSession(
}
if (login_backoff_.ShouldRejectRequest()) {
+ LOG(WARNING) << "Rejecting connection due to"
+ " an overload of failed login attempts.";
*response = protocol::SessionManager::OVERLOAD;
return;
}
- // We treat each incoming connection as a failure to authenticate,
- // and clear the backoff when a connection successfully
- // authenticates. This allows the backoff to protect from parallel
- // connection attempts as well as sequential ones.
- login_backoff_.InformOfRequest(false);
-
protocol::SessionConfig config;
if (!protocol_config_->Select(session->candidate_config(), &config)) {
LOG(WARNING) << "Rejecting connection from " << session->jid()
diff --git a/remoting/host/chromoting_host.h b/remoting/host/chromoting_host.h
index 3e1fa74..a6642d4 100644
--- a/remoting/host/chromoting_host.h
+++ b/remoting/host/chromoting_host.h
@@ -115,6 +115,7 @@ class ChromotingHost : public base::NonThreadSafe,
////////////////////////////////////////////////////////////////////////////
// ClientSession::EventHandler implementation.
+ virtual void OnSessionAuthenticating(ClientSession* client) OVERRIDE;
virtual bool OnSessionAuthenticated(ClientSession* client) OVERRIDE;
virtual void OnSessionChannelsConnected(ClientSession* client) OVERRIDE;
virtual void OnSessionAuthenticationFailed(ClientSession* client) OVERRIDE;
diff --git a/remoting/host/chromoting_host_unittest.cc b/remoting/host/chromoting_host_unittest.cc
index 1410371..5ebe051 100644
--- a/remoting/host/chromoting_host_unittest.cc
+++ b/remoting/host/chromoting_host_unittest.cc
@@ -28,12 +28,13 @@ using ::remoting::protocol::MockConnectionToClientEventHandler;
using ::remoting::protocol::MockHostStub;
using ::remoting::protocol::MockSession;
using ::remoting::protocol::MockVideoStub;
+using ::remoting::protocol::Session;
using ::remoting::protocol::SessionConfig;
using testing::_;
using testing::AnyNumber;
-using testing::AtMost;
using testing::AtLeast;
+using testing::AtMost;
using testing::CreateFunctor;
using testing::DeleteArg;
using testing::DoAll;
@@ -44,6 +45,7 @@ using testing::InvokeArgument;
using testing::InvokeWithoutArgs;
using testing::Return;
using testing::ReturnRef;
+using testing::SaveArg;
using testing::Sequence;
namespace remoting {
@@ -124,9 +126,10 @@ class ChromotingHostTest : public testing::Test {
.Times(AnyNumber());
EXPECT_CALL(*session_unowned1_, SetEventHandler(_))
.Times(AnyNumber())
- .WillRepeatedly(Invoke(this, &ChromotingHostTest::SetEventHandler));
+ .WillRepeatedly(SaveArg<0>(&session_unowned1_event_handler_));
EXPECT_CALL(*session_unowned2_, SetEventHandler(_))
- .Times(AnyNumber());
+ .Times(AnyNumber())
+ .WillRepeatedly(SaveArg<0>(&session_unowned2_event_handler_));
EXPECT_CALL(*session1_, config())
.WillRepeatedly(ReturnRef(session_config1_));
EXPECT_CALL(*session2_, config())
@@ -287,13 +290,15 @@ class ChromotingHostTest : public testing::Test {
get_connection(connection_index), protocol::OK);
}
- void SetEventHandler(protocol::Session::EventHandler* event_handler) {
- session_event_handler_ = event_handler;
+ void NotifyConnectionClosed1() {
+ if (session_unowned1_event_handler_) {
+ session_unowned1_event_handler_->OnSessionStateChange(Session::CLOSED);
+ }
}
- void NotifyConnectionClosed() {
- if (session_event_handler_) {
- session_event_handler_->OnSessionStateChange(protocol::Session::CLOSED);
+ void NotifyConnectionClosed2() {
+ if (session_unowned2_event_handler_) {
+ session_unowned2_event_handler_->OnSessionStateChange(Session::CLOSED);
}
}
@@ -424,7 +429,8 @@ class ChromotingHostTest : public testing::Test {
scoped_ptr<MockSession> session_unowned2_; // Not owned by a connection.
SessionConfig session_unowned_config2_;
std::string session_unowned_jid2_;
- protocol::Session::EventHandler* session_event_handler_;
+ protocol::Session::EventHandler* session_unowned1_event_handler_;
+ protocol::Session::EventHandler* session_unowned2_event_handler_;
scoped_ptr<protocol::CandidateSessionConfig> empty_candidate_config_;
scoped_ptr<protocol::CandidateSessionConfig> default_candidate_config_;
@@ -432,10 +438,16 @@ class ChromotingHostTest : public testing::Test {
return (connection_index == 0) ? connection1_ : connection2_;
}
+ // Returns the cached client pointers client1_ or client2_.
ClientSession*& get_client(int connection_index) {
return (connection_index == 0) ? client1_ : client2_;
}
+ // Returns the list of clients of the host_.
+ std::list<ClientSession*>& get_clients_from_host() {
+ return host_->clients_;
+ }
+
const std::string& get_session_jid(int connection_index) {
return (connection_index == 0) ? session_jid1_ : session_jid2_;
}
@@ -578,7 +590,7 @@ TEST_F(ChromotingHostTest, IncomingSessionAccepted) {
default_candidate_config_.get()));
EXPECT_CALL(*session_unowned1_, set_config(_));
EXPECT_CALL(*session_unowned1_, Close()).WillOnce(InvokeWithoutArgs(
- this, &ChromotingHostTest::NotifyConnectionClosed));
+ this, &ChromotingHostTest::NotifyConnectionClosed1));
EXPECT_CALL(host_status_observer_, OnAccessDenied(_));
EXPECT_CALL(host_status_observer_, OnShutdown());
@@ -593,13 +605,13 @@ TEST_F(ChromotingHostTest, IncomingSessionAccepted) {
message_loop_.Run();
}
-TEST_F(ChromotingHostTest, IncomingSessionOverload) {
+TEST_F(ChromotingHostTest, LoginBackOffUponConnection) {
ExpectHostAndSessionManagerStart();
- EXPECT_CALL(*session_unowned1_, candidate_config()).WillOnce(Return(
- default_candidate_config_.get()));
+ EXPECT_CALL(*session_unowned1_, candidate_config()).WillOnce(
+ Return(default_candidate_config_.get()));
EXPECT_CALL(*session_unowned1_, set_config(_));
- EXPECT_CALL(*session_unowned1_, Close()).WillOnce(InvokeWithoutArgs(
- this, &ChromotingHostTest::NotifyConnectionClosed));
+ EXPECT_CALL(*session_unowned1_, Close()).WillOnce(
+ InvokeWithoutArgs(this, &ChromotingHostTest::NotifyConnectionClosed1));
EXPECT_CALL(host_status_observer_, OnAccessDenied(_));
EXPECT_CALL(host_status_observer_, OnShutdown());
@@ -607,9 +619,11 @@ TEST_F(ChromotingHostTest, IncomingSessionOverload) {
protocol::SessionManager::IncomingSessionResponse response =
protocol::SessionManager::DECLINE;
+
host_->OnIncomingSession(session_unowned1_.release(), &response);
EXPECT_EQ(protocol::SessionManager::ACCEPT, response);
+ host_->OnSessionAuthenticating(get_clients_from_host().front());
host_->OnIncomingSession(session_unowned2_.get(), &response);
EXPECT_EQ(protocol::SessionManager::OVERLOAD, response);
@@ -617,6 +631,46 @@ TEST_F(ChromotingHostTest, IncomingSessionOverload) {
message_loop_.Run();
}
+TEST_F(ChromotingHostTest, LoginBackOffUponAuthenticating) {
+ Expectation start = ExpectHostAndSessionManagerStart();
+ EXPECT_CALL(*session_unowned1_, candidate_config()).WillOnce(
+ Return(default_candidate_config_.get()));
+ EXPECT_CALL(*session_unowned1_, set_config(_));
+ EXPECT_CALL(*session_unowned1_, Close()).WillOnce(
+ InvokeWithoutArgs(this, &ChromotingHostTest::NotifyConnectionClosed1));
+
+ EXPECT_CALL(*session_unowned2_, candidate_config()).WillOnce(
+ Return(default_candidate_config_.get()));
+ EXPECT_CALL(*session_unowned2_, set_config(_));
+ EXPECT_CALL(*session_unowned2_, Close()).WillOnce(
+ InvokeWithoutArgs(this, &ChromotingHostTest::NotifyConnectionClosed2));
+
+ EXPECT_CALL(host_status_observer_, OnShutdown());
+
+ host_->Start(xmpp_login_);
+
+ protocol::SessionManager::IncomingSessionResponse response =
+ protocol::SessionManager::DECLINE;
+
+ host_->OnIncomingSession(session_unowned1_.release(), &response);
+ EXPECT_EQ(protocol::SessionManager::ACCEPT, response);
+
+ host_->OnIncomingSession(session_unowned2_.release(), &response);
+ EXPECT_EQ(protocol::SessionManager::ACCEPT, response);
+
+ // This will set the backoff.
+ host_->OnSessionAuthenticating(get_clients_from_host().front());
+
+ // This should disconnect client2.
+ host_->OnSessionAuthenticating(get_clients_from_host().back());
+
+ // Verify that the host only has 1 client at this point.
+ EXPECT_EQ(get_clients_from_host().size(), 1U);
+
+ ShutdownHost();
+ message_loop_.Run();
+}
+
TEST_F(ChromotingHostTest, OnSessionRouteChange) {
std::string channel_name("ChannelName");
protocol::TransportRoute route;
diff --git a/remoting/host/client_session.cc b/remoting/host/client_session.cc
index 7a3c76c..2661e0e 100644
--- a/remoting/host/client_session.cc
+++ b/remoting/host/client_session.cc
@@ -210,6 +210,11 @@ void ClientSession::DeliverClientMessage(
<< message.type() << ": " << message.data();
}
+void ClientSession::OnConnectionAuthenticating(
+ protocol::ConnectionToClient* connection) {
+ event_handler_->OnSessionAuthenticating(this);
+}
+
void ClientSession::OnConnectionAuthenticated(
protocol::ConnectionToClient* connection) {
DCHECK(CalledOnValidThread());
diff --git a/remoting/host/client_session.h b/remoting/host/client_session.h
index ef75b25..f892329 100644
--- a/remoting/host/client_session.h
+++ b/remoting/host/client_session.h
@@ -54,6 +54,9 @@ class ClientSession
// Callback interface for passing events to the ChromotingHost.
class EventHandler {
public:
+ // Called after authentication has started.
+ virtual void OnSessionAuthenticating(ClientSession* client) = 0;
+
// Called after authentication has finished successfully. Returns true if
// the connection is allowed, or false otherwise.
virtual bool OnSessionAuthenticated(ClientSession* client) = 0;
@@ -115,6 +118,8 @@ class ClientSession
const protocol::ExtensionMessage& message) OVERRIDE;
// protocol::ConnectionToClient::EventHandler interface.
+ virtual void OnConnectionAuthenticating(
+ protocol::ConnectionToClient* connection) OVERRIDE;
virtual void OnConnectionAuthenticated(
protocol::ConnectionToClient* connection) OVERRIDE;
virtual void OnConnectionChannelsConnected(
diff --git a/remoting/host/host_mock_objects.h b/remoting/host/host_mock_objects.h
index d365695..d916f04 100644
--- a/remoting/host/host_mock_objects.h
+++ b/remoting/host/host_mock_objects.h
@@ -68,6 +68,7 @@ class MockClientSessionEventHandler : public ClientSession::EventHandler {
MockClientSessionEventHandler();
virtual ~MockClientSessionEventHandler();
+ MOCK_METHOD1(OnSessionAuthenticating, void(ClientSession* client));
MOCK_METHOD1(OnSessionAuthenticated, bool(ClientSession* client));
MOCK_METHOD1(OnSessionChannelsConnected, void(ClientSession* client));
MOCK_METHOD1(OnSessionAuthenticationFailed, void(ClientSession* client));
diff --git a/remoting/host/pam_authorization_factory_posix.cc b/remoting/host/pam_authorization_factory_posix.cc
index c89c71f..eef0c48 100644
--- a/remoting/host/pam_authorization_factory_posix.cc
+++ b/remoting/host/pam_authorization_factory_posix.cc
@@ -24,6 +24,7 @@ class PamAuthorizer : public protocol::Authenticator {
// protocol::Authenticator interface.
virtual State state() const OVERRIDE;
+ virtual bool started() const OVERRIDE;
virtual RejectionReason rejection_reason() const OVERRIDE;
virtual void ProcessMessage(const buzz::XmlElement* message,
const base::Closure& resume_callback) OVERRIDE;
@@ -62,6 +63,10 @@ protocol::Authenticator::State PamAuthorizer::state() const {
}
}
+bool PamAuthorizer::started() const {
+ return underlying_->started();
+}
+
protocol::Authenticator::RejectionReason
PamAuthorizer::rejection_reason() const {
if (local_login_status_ == DISALLOWED) {
diff --git a/remoting/protocol/authenticator.h b/remoting/protocol/authenticator.h
index 28288e5..1210989 100644
--- a/remoting/protocol/authenticator.h
+++ b/remoting/protocol/authenticator.h
@@ -89,6 +89,11 @@ class Authenticator {
// Returns current state of the authenticator.
virtual State state() const = 0;
+ // Returns whether authentication has started. The chromoting host uses this
+ // method to starts the back off process to prevent malicious clients from
+ // guessing the PIN by spamming the host with auth requests.
+ virtual bool started() const = 0;
+
// Returns rejection reason. Can be called only when in REJECTED state.
virtual RejectionReason rejection_reason() const = 0;
diff --git a/remoting/protocol/authenticator_test_base.cc b/remoting/protocol/authenticator_test_base.cc
index 25b0efe..1b65dbe 100644
--- a/remoting/protocol/authenticator_test_base.cc
+++ b/remoting/protocol/authenticator_test_base.cc
@@ -60,22 +60,43 @@ void AuthenticatorTestBase::SetUp() {
}
void AuthenticatorTestBase::RunAuthExchange() {
- ContinueAuthExchangeWith(client_.get(), host_.get());
+ ContinueAuthExchangeWith(client_.get(),
+ host_.get(),
+ client_->started(),
+ host_->started());
}
void AuthenticatorTestBase::RunHostInitiatedAuthExchange() {
- ContinueAuthExchangeWith(host_.get(), client_.get());
+ ContinueAuthExchangeWith(host_.get(),
+ client_.get(),
+ host_->started(),
+ client_->started());
}
// static
+// This function sends a message from the sender and receiver and recursively
+// calls itself to the send the next message from the receiver to the sender
+// untils the authentication completes.
void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender,
- Authenticator* receiver) {
+ Authenticator* receiver,
+ bool sender_started,
+ bool receiver_started) {
scoped_ptr<buzz::XmlElement> message;
ASSERT_NE(Authenticator::WAITING_MESSAGE, sender->state());
if (sender->state() == Authenticator::ACCEPTED ||
sender->state() == Authenticator::REJECTED)
return;
- // Pass message from client to host.
+
+ // Verify that once the started flag for either party is set to true,
+ // it should always stay true.
+ if (receiver_started) {
+ ASSERT_TRUE(receiver->started());
+ }
+
+ if (sender_started) {
+ ASSERT_TRUE(sender->started());
+ }
+
ASSERT_EQ(Authenticator::MESSAGE_READY, sender->state());
message = sender->GetNextMessage();
ASSERT_TRUE(message.get());
@@ -84,7 +105,8 @@ void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender,
ASSERT_EQ(Authenticator::WAITING_MESSAGE, receiver->state());
receiver->ProcessMessage(message.get(), base::Bind(
&AuthenticatorTestBase::ContinueAuthExchangeWith,
- base::Unretained(receiver), base::Unretained(sender)));
+ base::Unretained(receiver), base::Unretained(sender),
+ receiver->started(), sender->started()));
}
void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) {
diff --git a/remoting/protocol/authenticator_test_base.h b/remoting/protocol/authenticator_test_base.h
index 9b299db..e20774a 100644
--- a/remoting/protocol/authenticator_test_base.h
+++ b/remoting/protocol/authenticator_test_base.h
@@ -41,7 +41,9 @@ class AuthenticatorTestBase : public testing::Test {
};
static void ContinueAuthExchangeWith(Authenticator* sender,
- Authenticator* receiver);
+ Authenticator* receiver,
+ bool sender_started,
+ bool receiver_srated);
virtual void SetUp() OVERRIDE;
void RunAuthExchange();
void RunHostInitiatedAuthExchange();
diff --git a/remoting/protocol/connection_to_client.cc b/remoting/protocol/connection_to_client.cc
index dd32452..a6d3143 100644
--- a/remoting/protocol/connection_to_client.cc
+++ b/remoting/protocol/connection_to_client.cc
@@ -112,7 +112,9 @@ void ConnectionToClient::OnSessionStateChange(Session::State state) {
case Session::CONNECTED:
// Don't care about these events.
break;
-
+ case Session::AUTHENTICATING:
+ handler_->OnConnectionAuthenticating(this);
+ break;
case Session::AUTHENTICATED:
// Initialize channels.
control_dispatcher_.reset(new HostControlDispatcher());
diff --git a/remoting/protocol/connection_to_client.h b/remoting/protocol/connection_to_client.h
index 9865307..9a64dcd 100644
--- a/remoting/protocol/connection_to_client.h
+++ b/remoting/protocol/connection_to_client.h
@@ -38,6 +38,9 @@ class ConnectionToClient : public base::NonThreadSafe,
public:
class EventHandler {
public:
+ // Called when the network connection is authenticating
+ virtual void OnConnectionAuthenticating(ConnectionToClient* connection) = 0;
+
// Called when the network connection is authenticated.
virtual void OnConnectionAuthenticated(ConnectionToClient* connection) = 0;
diff --git a/remoting/protocol/connection_to_host.cc b/remoting/protocol/connection_to_host.cc
index cdaf4b6..978da56 100644
--- a/remoting/protocol/connection_to_host.cc
+++ b/remoting/protocol/connection_to_host.cc
@@ -152,6 +152,7 @@ void ConnectionToHost::OnSessionStateChange(
case Session::CONNECTING:
case Session::ACCEPTING:
case Session::CONNECTED:
+ case Session::AUTHENTICATING:
// Don't care about these events.
break;
diff --git a/remoting/protocol/fake_authenticator.cc b/remoting/protocol/fake_authenticator.cc
index 67f5d35..67c1cca 100644
--- a/remoting/protocol/fake_authenticator.cc
+++ b/remoting/protocol/fake_authenticator.cc
@@ -84,12 +84,17 @@ FakeAuthenticator::FakeAuthenticator(
round_trips_(round_trips),
action_(action),
async_(async),
- messages_(0) {
+ messages_(0),
+ messages_till_started_(0) {
}
FakeAuthenticator::~FakeAuthenticator() {
}
+void FakeAuthenticator::set_messages_till_started(int messages) {
+ messages_till_started_ = messages;
+}
+
Authenticator::State FakeAuthenticator::state() const {
EXPECT_LE(messages_, round_trips_ * 2);
if (messages_ >= round_trips_ * 2) {
@@ -116,6 +121,10 @@ Authenticator::State FakeAuthenticator::state() const {
}
}
+bool FakeAuthenticator::started() const {
+ return messages_ > messages_till_started_;
+}
+
Authenticator::RejectionReason FakeAuthenticator::rejection_reason() const {
EXPECT_EQ(REJECTED, state());
return INVALID_CREDENTIALS;
@@ -153,8 +162,10 @@ FakeAuthenticator::CreateChannelAuthenticator() const {
}
FakeHostAuthenticatorFactory::FakeHostAuthenticatorFactory(
- int round_trips, FakeAuthenticator::Action action, bool async)
+ int round_trips, int messages_till_started,
+ FakeAuthenticator::Action action, bool async)
: round_trips_(round_trips),
+ messages_till_started_(messages_till_started),
action_(action), async_(async) {
}
@@ -165,8 +176,12 @@ scoped_ptr<Authenticator> FakeHostAuthenticatorFactory::CreateAuthenticator(
const std::string& local_jid,
const std::string& remote_jid,
const buzz::XmlElement* first_message) {
- return scoped_ptr<Authenticator>(new FakeAuthenticator(
- FakeAuthenticator::HOST, round_trips_, action_, async_));
+ FakeAuthenticator* authenticator = new FakeAuthenticator(
+ FakeAuthenticator::HOST, round_trips_, action_, async_);
+ authenticator->set_messages_till_started(messages_till_started_);
+
+ scoped_ptr<Authenticator> result(authenticator);
+ return result.Pass();
}
} // namespace protocol
diff --git a/remoting/protocol/fake_authenticator.h b/remoting/protocol/fake_authenticator.h
index a5a2de8..b74654d0 100644
--- a/remoting/protocol/fake_authenticator.h
+++ b/remoting/protocol/fake_authenticator.h
@@ -59,10 +59,16 @@ class FakeAuthenticator : public Authenticator {
};
FakeAuthenticator(Type type, int round_trips, Action action, bool async);
+
virtual ~FakeAuthenticator();
+ // Set the number of messages that the authenticator needs to process before
+ // started() returns true. Default to 0.
+ void set_messages_till_started(int messages);
+
// Authenticator interface.
virtual State state() const OVERRIDE;
+ virtual bool started() const OVERRIDE;
virtual RejectionReason rejection_reason() const OVERRIDE;
virtual void ProcessMessage(const buzz::XmlElement* message,
const base::Closure& resume_callback) OVERRIDE;
@@ -78,6 +84,9 @@ class FakeAuthenticator : public Authenticator {
// Total number of messages that have been processed.
int messages_;
+ // Number of messages that the authenticator needs to process before started()
+ // returns true. Default to 0.
+ int messages_till_started_;
DISALLOW_COPY_AND_ASSIGN(FakeAuthenticator);
};
@@ -85,7 +94,8 @@ class FakeAuthenticator : public Authenticator {
class FakeHostAuthenticatorFactory : public AuthenticatorFactory {
public:
FakeHostAuthenticatorFactory(
- int round_trips, FakeAuthenticator::Action action, bool async);
+ int round_trips, int messages_till_start,
+ FakeAuthenticator::Action action, bool async);
virtual ~FakeHostAuthenticatorFactory();
// AuthenticatorFactory interface.
@@ -96,6 +106,7 @@ class FakeHostAuthenticatorFactory : public AuthenticatorFactory {
private:
int round_trips_;
+ int messages_till_started_;
FakeAuthenticator::Action action_;
bool async_;
diff --git a/remoting/protocol/jingle_session.cc b/remoting/protocol/jingle_session.cc
index 2aecd4c..57883cb 100644
--- a/remoting/protocol/jingle_session.cc
+++ b/remoting/protocol/jingle_session.cc
@@ -6,8 +6,10 @@
#include "base/bind.h"
#include "base/rand_util.h"
+#include "base/single_thread_task_runner.h"
#include "base/stl_util.h"
#include "base/strings/string_number_conversions.h"
+#include "base/thread_task_runner_handle.h"
#include "base/time/time.h"
#include "remoting/base/constants.h"
#include "remoting/jingle_glue/iq_sender.h"
@@ -63,7 +65,8 @@ JingleSession::JingleSession(JingleSessionManager* session_manager)
event_handler_(NULL),
state_(INITIALIZING),
error_(OK),
- config_is_set_(false) {
+ config_is_set_(false),
+ weak_factory_(this) {
}
JingleSession::~JingleSession() {
@@ -181,9 +184,10 @@ void JingleSession::ContinueAcceptIncomingConnection() {
SetState(AUTHENTICATED);
} else {
DCHECK_EQ(authenticator_->state(), Authenticator::WAITING_MESSAGE);
+ if (authenticator_->started()) {
+ SetState(AUTHENTICATING);
+ }
}
-
- return;
}
const std::string& JingleSession::jid() {
@@ -485,7 +489,7 @@ void JingleSession::OnSessionInfo(const JingleMessage& message,
return;
}
- if (state_ != CONNECTED ||
+ if ((state_ != CONNECTED && state_ != AUTHENTICATING) ||
authenticator_->state() != Authenticator::WAITING_MESSAGE) {
LOG(WARNING) << "Received unexpected authenticator message "
<< message.info->Str();
@@ -517,8 +521,7 @@ void JingleSession::ProcessTransportInfo(const JingleMessage& message) {
void JingleSession::OnTerminate(const JingleMessage& message,
const ReplyCallback& reply_callback) {
- if (state_ != CONNECTING && state_ != ACCEPTING && state_ != CONNECTED &&
- state_ != AUTHENTICATED) {
+ if (!is_session_active()) {
LOG(WARNING) << "Received unexpected session-terminate message.";
reply_callback.Run(JingleMessageReply::UNEXPECTED_REQUEST);
return;
@@ -577,7 +580,7 @@ void JingleSession::ProcessAuthenticationStep() {
DCHECK(CalledOnValidThread());
DCHECK_NE(authenticator_->state(), Authenticator::PROCESSING_MESSAGE);
- if (state_ != CONNECTED) {
+ if (state_ != CONNECTED && state_ != AUTHENTICATING) {
DCHECK(state_ == FAILED || state_ == CLOSED);
// The remote host closed the connection while the authentication was being
// processed asynchronously, nothing to do.
@@ -592,6 +595,21 @@ void JingleSession::ProcessAuthenticationStep() {
}
DCHECK_NE(authenticator_->state(), Authenticator::MESSAGE_READY);
+ // The current JingleSession object can be destroyed by event_handler of
+ // SetState(AUTHENTICATING) and cause subsequent dereferencing of the this
+ // pointer to crash. To protect against it, we run ContinueAuthenticationStep
+ // asychronously using a weak pointer.
+ base::ThreadTaskRunnerHandle::Get()->PostTask(
+ FROM_HERE,
+ base::Bind(&JingleSession::ContinueAuthenticationStep,
+ weak_factory_.GetWeakPtr()));
+
+ if (authenticator_->started()) {
+ SetState(AUTHENTICATING);
+ }
+}
+
+void JingleSession::ContinueAuthenticationStep() {
if (authenticator_->state() == Authenticator::ACCEPTED) {
SetState(AUTHENTICATED);
} else if (authenticator_->state() == Authenticator::REJECTED) {
@@ -603,8 +621,7 @@ void JingleSession::ProcessAuthenticationStep() {
void JingleSession::CloseInternal(ErrorCode error) {
DCHECK(CalledOnValidThread());
- if (state_ == CONNECTING || state_ == ACCEPTING || state_ == CONNECTED ||
- state_ == AUTHENTICATED) {
+ if (is_session_active()) {
// Send session-terminate message with the appropriate error code.
JingleMessage::Reason reason;
switch (error) {
@@ -655,5 +672,10 @@ void JingleSession::SetState(State new_state) {
}
}
+bool JingleSession::is_session_active() {
+ return state_ == CONNECTING || state_ == ACCEPTING || state_ == CONNECTED ||
+ state_ == AUTHENTICATING || state_ == AUTHENTICATED;
+}
+
} // namespace protocol
} // namespace remoting
diff --git a/remoting/protocol/jingle_session.h b/remoting/protocol/jingle_session.h
index 189cb53..3b704fb 100644
--- a/remoting/protocol/jingle_session.h
+++ b/remoting/protocol/jingle_session.h
@@ -132,9 +132,13 @@ class JingleSession : public Session,
// Called after the initial incoming authenticator message is processed.
void ContinueAcceptIncomingConnection();
+
// Called after subsequent authenticator messages are processed.
void ProcessAuthenticationStep();
+ // Called after the authenticating step is finished.
+ void ContinueAuthenticationStep();
+
// Terminates the session and sends session-terminate if it is
// necessary. |error| specifies the error code in case when the
// session is being closed due to an error.
@@ -143,6 +147,9 @@ class JingleSession : public Session,
// Sets |state_| to |new_state| and calls state change callback.
void SetState(State new_state);
+ // Returns true if the state of the session is not CLOSED or FAILED
+ bool is_session_active();
+
JingleSessionManager* session_manager_;
std::string peer_jid_;
scoped_ptr<CandidateSessionConfig> candidate_config_;
@@ -172,6 +179,8 @@ class JingleSession : public Session,
// Pending remote candidates, received before the local channels were created.
std::list<JingleMessage::NamedCandidate> pending_remote_candidates_;
+ base::WeakPtrFactory<JingleSession> weak_factory_;
+
DISALLOW_COPY_AND_ASSIGN(JingleSession);
};
diff --git a/remoting/protocol/jingle_session_unittest.cc b/remoting/protocol/jingle_session_unittest.cc
index dcca077..2f95100 100644
--- a/remoting/protocol/jingle_session_unittest.cc
+++ b/remoting/protocol/jingle_session_unittest.cc
@@ -105,6 +105,10 @@ class JingleSessionTest : public testing::Test {
session->set_config(SessionConfig::ForTest());
}
+ void DeleteSession() {
+ host_session_.reset();
+ }
+
void OnClientChannelCreated(scoped_ptr<net::StreamSocket> socket) {
client_channel_callback_.OnDone(socket.get());
client_socket_ = socket.Pass();
@@ -132,7 +136,7 @@ class JingleSessionTest : public testing::Test {
client_session_.reset();
}
- void CreateSessionManagers(int auth_round_trips,
+ void CreateSessionManagers(int auth_round_trips, int messages_till_start,
FakeAuthenticator::Action auth_action) {
host_signal_strategy_.reset(new FakeSignalStrategy(kHostJid));
client_signal_strategy_.reset(new FakeSignalStrategy(kClientJid));
@@ -153,7 +157,8 @@ class JingleSessionTest : public testing::Test {
host_server_->Init(host_signal_strategy_.get(), &host_server_listener_);
scoped_ptr<AuthenticatorFactory> factory(
- new FakeHostAuthenticatorFactory(auth_round_trips, auth_action, true));
+ new FakeHostAuthenticatorFactory(auth_round_trips,
+ messages_till_start, auth_action, true));
host_server_->set_authenticator_factory(factory.Pass());
EXPECT_CALL(client_server_listener_, OnSessionManagerReady())
@@ -169,6 +174,11 @@ class JingleSessionTest : public testing::Test {
&client_server_listener_);
}
+ void CreateSessionManagers(int auth_round_trips,
+ FakeAuthenticator::Action auth_action) {
+ CreateSessionManagers(auth_round_trips, 0, auth_action);
+ }
+
void CloseSessionManager() {
if (host_server_.get()) {
host_server_->Close();
@@ -196,6 +206,9 @@ class JingleSessionTest : public testing::Test {
EXPECT_CALL(host_session_event_handler_,
OnSessionStateChange(Session::CONNECTED))
.Times(AtMost(1));
+ EXPECT_CALL(host_session_event_handler_,
+ OnSessionStateChange(Session::AUTHENTICATING))
+ .Times(AtMost(1));
if (expect_fail) {
EXPECT_CALL(host_session_event_handler_,
OnSessionStateChange(Session::FAILED))
@@ -217,6 +230,9 @@ class JingleSessionTest : public testing::Test {
EXPECT_CALL(client_session_event_handler_,
OnSessionStateChange(Session::CONNECTED))
.Times(AtMost(1));
+ EXPECT_CALL(client_session_event_handler_,
+ OnSessionStateChange(Session::AUTHENTICATING))
+ .Times(AtMost(1));
if (expect_fail) {
EXPECT_CALL(client_session_event_handler_,
OnSessionStateChange(Session::FAILED))
@@ -375,6 +391,60 @@ TEST_F(JingleSessionTest, TestStreamChannel) {
tester.CheckResults();
}
+TEST_F(JingleSessionTest, DeleteSessionOnIncomingConnection) {
+ CreateSessionManagers(3, FakeAuthenticator::ACCEPT);
+
+ EXPECT_CALL(host_server_listener_, OnIncomingSession(_, _))
+ .WillOnce(DoAll(
+ WithArg<0>(Invoke(this, &JingleSessionTest::SetHostSession)),
+ SetArgumentPointee<1>(protocol::SessionManager::ACCEPT)));
+
+ EXPECT_CALL(host_session_event_handler_,
+ OnSessionStateChange(Session::CONNECTED))
+ .Times(AtMost(1));
+
+ EXPECT_CALL(host_session_event_handler_,
+ OnSessionStateChange(Session::AUTHENTICATING))
+ .WillOnce(InvokeWithoutArgs(this, &JingleSessionTest::DeleteSession));
+
+ scoped_ptr<Authenticator> authenticator(new FakeAuthenticator(
+ FakeAuthenticator::CLIENT, 3, FakeAuthenticator::ACCEPT, true));
+
+ client_session_ = client_server_->Connect(
+ kHostJid, authenticator.Pass(),
+ CandidateSessionConfig::CreateDefault());
+
+ base::RunLoop().RunUntilIdle();
+}
+
+TEST_F(JingleSessionTest, DeleteSessionOnAuth) {
+ // Same as the previous test, but set messages_till_started to 2 in
+ // CreateSessionManagers so that the session will goes into the
+ // AUTHENTICATING state after two message exchanges.
+ CreateSessionManagers(3, 2, FakeAuthenticator::ACCEPT);
+
+ EXPECT_CALL(host_server_listener_, OnIncomingSession(_, _))
+ .WillOnce(DoAll(
+ WithArg<0>(Invoke(this, &JingleSessionTest::SetHostSession)),
+ SetArgumentPointee<1>(protocol::SessionManager::ACCEPT)));
+
+ EXPECT_CALL(host_session_event_handler_,
+ OnSessionStateChange(Session::CONNECTED))
+ .Times(AtMost(1));
+
+ EXPECT_CALL(host_session_event_handler_,
+ OnSessionStateChange(Session::AUTHENTICATING))
+ .WillOnce(InvokeWithoutArgs(this, &JingleSessionTest::DeleteSession));
+
+ scoped_ptr<Authenticator> authenticator(new FakeAuthenticator(
+ FakeAuthenticator::CLIENT, 3, FakeAuthenticator::ACCEPT, true));
+
+ client_session_ = client_server_->Connect(
+ kHostJid, authenticator.Pass(),
+ CandidateSessionConfig::CreateDefault());
+ base::RunLoop().RunUntilIdle();
+}
+
// Verify that data can be sent over a multiplexed channel.
TEST_F(JingleSessionTest, TestMuxStreamChannel) {
CreateSessionManagers(1, FakeAuthenticator::ACCEPT);
diff --git a/remoting/protocol/me2me_host_authenticator_factory.cc b/remoting/protocol/me2me_host_authenticator_factory.cc
index 24c0ca4..7e407c8 100644
--- a/remoting/protocol/me2me_host_authenticator_factory.cc
+++ b/remoting/protocol/me2me_host_authenticator_factory.cc
@@ -30,6 +30,10 @@ class RejectingAuthenticator : public Authenticator {
return state_;
}
+ virtual bool started() const OVERRIDE {
+ return true;
+ }
+
virtual RejectionReason rejection_reason() const OVERRIDE {
DCHECK_EQ(state_, REJECTED);
return INVALID_CREDENTIALS;
diff --git a/remoting/protocol/negotiating_authenticator_base.cc b/remoting/protocol/negotiating_authenticator_base.cc
index 30bd8ba..5c61d30 100644
--- a/remoting/protocol/negotiating_authenticator_base.cc
+++ b/remoting/protocol/negotiating_authenticator_base.cc
@@ -40,6 +40,13 @@ Authenticator::State NegotiatingAuthenticatorBase::state() const {
return state_;
}
+bool NegotiatingAuthenticatorBase::started() const {
+ if (!current_authenticator_) {
+ return false;
+ }
+ return current_authenticator_->started();
+}
+
Authenticator::RejectionReason
NegotiatingAuthenticatorBase::rejection_reason() const {
return rejection_reason_;
diff --git a/remoting/protocol/negotiating_authenticator_base.h b/remoting/protocol/negotiating_authenticator_base.h
index 1f80967..bcb005f2 100644
--- a/remoting/protocol/negotiating_authenticator_base.h
+++ b/remoting/protocol/negotiating_authenticator_base.h
@@ -64,6 +64,7 @@ class NegotiatingAuthenticatorBase : public Authenticator {
// Authenticator interface.
virtual State state() const OVERRIDE;
+ virtual bool started() const OVERRIDE;
virtual RejectionReason rejection_reason() const OVERRIDE;
virtual scoped_ptr<ChannelAuthenticator>
CreateChannelAuthenticator() const OVERRIDE;
diff --git a/remoting/protocol/pairing_authenticator_base.cc b/remoting/protocol/pairing_authenticator_base.cc
index 47ca275..7435e55 100644
--- a/remoting/protocol/pairing_authenticator_base.cc
+++ b/remoting/protocol/pairing_authenticator_base.cc
@@ -38,6 +38,13 @@ Authenticator::State PairingAuthenticatorBase::state() const {
return v2_authenticator_->state();
}
+bool PairingAuthenticatorBase::started() const {
+ if (!v2_authenticator_) {
+ return false;
+ }
+ return v2_authenticator_->started();
+}
+
Authenticator::RejectionReason
PairingAuthenticatorBase::rejection_reason() const {
if (!v2_authenticator_) {
diff --git a/remoting/protocol/pairing_authenticator_base.h b/remoting/protocol/pairing_authenticator_base.h
index 2a4bc4e..4cb042c 100644
--- a/remoting/protocol/pairing_authenticator_base.h
+++ b/remoting/protocol/pairing_authenticator_base.h
@@ -43,6 +43,7 @@ class PairingAuthenticatorBase : public Authenticator {
// Authenticator interface.
virtual State state() const OVERRIDE;
+ virtual bool started() const OVERRIDE;
virtual RejectionReason rejection_reason() const OVERRIDE;
virtual void ProcessMessage(const buzz::XmlElement* message,
const base::Closure& resume_callback) OVERRIDE;
diff --git a/remoting/protocol/protocol_mock_objects.h b/remoting/protocol/protocol_mock_objects.h
index fecf1d8..07e03306 100644
--- a/remoting/protocol/protocol_mock_objects.h
+++ b/remoting/protocol/protocol_mock_objects.h
@@ -52,6 +52,8 @@ class MockConnectionToClientEventHandler :
MockConnectionToClientEventHandler();
virtual ~MockConnectionToClientEventHandler();
+ MOCK_METHOD1(OnConnectionAuthenticating,
+ void(ConnectionToClient* connection));
MOCK_METHOD1(OnConnectionAuthenticated, void(ConnectionToClient* connection));
MOCK_METHOD1(OnConnectionChannelsConnected,
void(ConnectionToClient* connection));
diff --git a/remoting/protocol/session.h b/remoting/protocol/session.h
index a542283..232a1d04 100644
--- a/remoting/protocol/session.h
+++ b/remoting/protocol/session.h
@@ -38,6 +38,9 @@ class Session {
// Session has been accepted and is pending authentication.
CONNECTED,
+ // Session has started authenticating.
+ AUTHENTICATING,
+
// Session has been connected and authenticated.
AUTHENTICATED,
@@ -54,8 +57,8 @@ class Session {
virtual ~EventHandler() {}
// Called after session state has changed. It is safe to destroy
- // the session from within the handler if |state| is CLOSED or
- // FAILED.
+ // the session from within the handler if |state| is AUTHENTICATING
+ // or CLOSED or FAILED.
virtual void OnSessionStateChange(State state) = 0;
// Called whenever route for the channel specified with
diff --git a/remoting/protocol/third_party_authenticator_base.cc b/remoting/protocol/third_party_authenticator_base.cc
index 9d6be72..019e788 100644
--- a/remoting/protocol/third_party_authenticator_base.cc
+++ b/remoting/protocol/third_party_authenticator_base.cc
@@ -28,12 +28,17 @@ const buzz::StaticQName ThirdPartyAuthenticatorBase::kTokenTag =
ThirdPartyAuthenticatorBase::ThirdPartyAuthenticatorBase(
Authenticator::State initial_state)
: token_state_(initial_state),
+ started_(false),
rejection_reason_(INVALID_CREDENTIALS) {
}
ThirdPartyAuthenticatorBase::~ThirdPartyAuthenticatorBase() {
}
+bool ThirdPartyAuthenticatorBase::started() const {
+ return started_;
+}
+
Authenticator::State ThirdPartyAuthenticatorBase::state() const {
if (token_state_ == ACCEPTED)
return underlying_->state();
@@ -74,9 +79,10 @@ scoped_ptr<buzz::XmlElement> ThirdPartyAuthenticatorBase::GetNextMessage() {
message = CreateEmptyAuthenticatorMessage();
}
- if (token_state_ == MESSAGE_READY)
+ if (token_state_ == MESSAGE_READY) {
AddTokenElements(message.get());
-
+ started_ = true;
+ }
return message.Pass();
}
diff --git a/remoting/protocol/third_party_authenticator_base.h b/remoting/protocol/third_party_authenticator_base.h
index 0db203d..7141be8 100644
--- a/remoting/protocol/third_party_authenticator_base.h
+++ b/remoting/protocol/third_party_authenticator_base.h
@@ -36,6 +36,7 @@ class ThirdPartyAuthenticatorBase : public Authenticator {
// Authenticator interface.
virtual State state() const OVERRIDE;
+ virtual bool started() const OVERRIDE;
virtual RejectionReason rejection_reason() const OVERRIDE;
virtual void ProcessMessage(const buzz::XmlElement* message,
const base::Closure& resume_callback) OVERRIDE;
@@ -66,6 +67,7 @@ class ThirdPartyAuthenticatorBase : public Authenticator {
scoped_ptr<Authenticator> underlying_;
State token_state_;
+ bool started_;
RejectionReason rejection_reason_;
private:
diff --git a/remoting/protocol/v2_authenticator.cc b/remoting/protocol/v2_authenticator.cc
index ee5c9d1..1b13c7c 100644
--- a/remoting/protocol/v2_authenticator.cc
+++ b/remoting/protocol/v2_authenticator.cc
@@ -62,6 +62,7 @@ V2Authenticator::V2Authenticator(
: certificate_sent_(false),
key_exchange_impl_(type, shared_secret),
state_(initial_state),
+ started_(false),
rejection_reason_(INVALID_CREDENTIALS) {
pending_messages_.push(key_exchange_impl_.GetMessage());
}
@@ -75,6 +76,10 @@ Authenticator::State V2Authenticator::state() const {
return state_;
}
+bool V2Authenticator::started() const {
+ return started_;
+}
+
Authenticator::RejectionReason V2Authenticator::rejection_reason() const {
DCHECK_EQ(state(), REJECTED);
return rejection_reason_;
@@ -127,6 +132,7 @@ void V2Authenticator::ProcessMessageInternal(const buzz::XmlElement* message) {
P224EncryptedKeyExchange::Result result =
key_exchange_impl_.ProcessMessage(spake_message);
+ started_ = true;
switch (result) {
case P224EncryptedKeyExchange::kResultPending:
pending_messages_.push(key_exchange_impl_.GetMessage());
@@ -143,7 +149,6 @@ void V2Authenticator::ProcessMessageInternal(const buzz::XmlElement* message) {
return;
}
}
-
state_ = MESSAGE_READY;
}
diff --git a/remoting/protocol/v2_authenticator.h b/remoting/protocol/v2_authenticator.h
index 0675cb2..b52a3e1 100644
--- a/remoting/protocol/v2_authenticator.h
+++ b/remoting/protocol/v2_authenticator.h
@@ -38,6 +38,7 @@ class V2Authenticator : public Authenticator {
// Authenticator interface.
virtual State state() const OVERRIDE;
+ virtual bool started() const OVERRIDE;
virtual RejectionReason rejection_reason() const OVERRIDE;
virtual void ProcessMessage(const buzz::XmlElement* message,
const base::Closure& resume_callback) OVERRIDE;
@@ -67,6 +68,7 @@ class V2Authenticator : public Authenticator {
// Used for both host and client authenticators.
crypto::P224EncryptedKeyExchange key_exchange_impl_;
State state_;
+ bool started_;
RejectionReason rejection_reason_;
std::queue<std::string> pending_messages_;
std::string auth_key_;