diff options
Diffstat (limited to 'google_apis')
-rw-r--r-- | google_apis/gcm/base/mcs_message_unittest.cc | 92 | ||||
-rw-r--r-- | google_apis/gcm/base/mcs_util.cc | 19 | ||||
-rw-r--r-- | google_apis/gcm/base/mcs_util.h | 2 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_factory.h | 26 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_factory_impl.cc | 35 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_factory_impl.h | 10 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_factory_impl_unittest.cc | 36 | ||||
-rw-r--r-- | google_apis/gcm/engine/fake_connection_factory.cc | 46 | ||||
-rw-r--r-- | google_apis/gcm/engine/fake_connection_factory.h | 42 | ||||
-rw-r--r-- | google_apis/gcm/engine/fake_connection_handler.cc | 86 | ||||
-rw-r--r-- | google_apis/gcm/engine/fake_connection_handler.h | 74 | ||||
-rw-r--r-- | google_apis/gcm/engine/mcs_client.cc | 659 | ||||
-rw-r--r-- | google_apis/gcm/engine/mcs_client.h | 231 | ||||
-rw-r--r-- | google_apis/gcm/engine/mcs_client_unittest.cc | 540 | ||||
-rw-r--r-- | google_apis/gcm/gcm.gyp | 32 | ||||
-rw-r--r-- | google_apis/gcm/tools/mcs_probe.cc | 372 |
16 files changed, 2257 insertions, 45 deletions
diff --git a/google_apis/gcm/base/mcs_message_unittest.cc b/google_apis/gcm/base/mcs_message_unittest.cc new file mode 100644 index 0000000..4d4ef59 --- /dev/null +++ b/google_apis/gcm/base/mcs_message_unittest.cc @@ -0,0 +1,92 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "google_apis/gcm/base/mcs_message.h" + +#include "base/logging.h" +#include "base/message_loop/message_loop.h" +#include "google_apis/gcm/base/mcs_util.h" +#include "google_apis/gcm/protocol/mcs.pb.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace gcm { + +const uint64 kAndroidId = 12345; +const uint64 kSecret = 54321; + +class MCSMessageTest : public testing::Test { + public: + MCSMessageTest(); + virtual ~MCSMessageTest(); + private: + base::MessageLoop message_loop_; +}; + +MCSMessageTest::MCSMessageTest() { +} + +MCSMessageTest::~MCSMessageTest() { +} + +TEST_F(MCSMessageTest, Invalid) { + MCSMessage message; + EXPECT_FALSE(message.IsValid()); +} + +TEST_F(MCSMessageTest, InitInferTag) { + scoped_ptr<mcs_proto::LoginRequest> login_request( + BuildLoginRequest(kAndroidId, kSecret)); + scoped_ptr<google::protobuf::MessageLite> login_copy( + new mcs_proto::LoginRequest(*login_request)); + MCSMessage message(*login_copy); + login_copy.reset(); + ASSERT_TRUE(message.IsValid()); + EXPECT_EQ(kLoginRequestTag, message.tag()); + EXPECT_EQ(login_request->ByteSize(), message.size()); + EXPECT_EQ(login_request->SerializeAsString(), message.SerializeAsString()); + EXPECT_EQ(login_request->SerializeAsString(), + message.GetProtobuf().SerializeAsString()); + login_copy = message.CloneProtobuf(); + EXPECT_EQ(login_request->SerializeAsString(), + login_copy->SerializeAsString()); +} + +TEST_F(MCSMessageTest, InitWithTag) { + scoped_ptr<mcs_proto::LoginRequest> login_request( + BuildLoginRequest(kAndroidId, kSecret)); + scoped_ptr<google::protobuf::MessageLite> login_copy( + new mcs_proto::LoginRequest(*login_request)); + MCSMessage message(kLoginRequestTag, *login_copy); + login_copy.reset(); + ASSERT_TRUE(message.IsValid()); + EXPECT_EQ(kLoginRequestTag, message.tag()); + EXPECT_EQ(login_request->ByteSize(), message.size()); + EXPECT_EQ(login_request->SerializeAsString(), message.SerializeAsString()); + EXPECT_EQ(login_request->SerializeAsString(), + message.GetProtobuf().SerializeAsString()); + login_copy = message.CloneProtobuf(); + EXPECT_EQ(login_request->SerializeAsString(), + login_copy->SerializeAsString()); +} + +TEST_F(MCSMessageTest, InitPassOwnership) { + scoped_ptr<mcs_proto::LoginRequest> login_request( + BuildLoginRequest(kAndroidId, kSecret)); + scoped_ptr<google::protobuf::MessageLite> login_copy( + new mcs_proto::LoginRequest(*login_request)); + MCSMessage message(kLoginRequestTag, + login_copy.PassAs<const google::protobuf::MessageLite>()); + EXPECT_FALSE(login_copy.get()); + ASSERT_TRUE(message.IsValid()); + EXPECT_EQ(kLoginRequestTag, message.tag()); + EXPECT_EQ(login_request->ByteSize(), message.size()); + EXPECT_EQ(login_request->SerializeAsString(), message.SerializeAsString()); + EXPECT_EQ(login_request->SerializeAsString(), + message.GetProtobuf().SerializeAsString()); + login_copy = message.CloneProtobuf(); + EXPECT_EQ(login_request->SerializeAsString(), + login_copy->SerializeAsString()); +} + +} // namespace gcm diff --git a/google_apis/gcm/base/mcs_util.cc b/google_apis/gcm/base/mcs_util.cc index b52d429..7365560 100644 --- a/google_apis/gcm/base/mcs_util.cc +++ b/google_apis/gcm/base/mcs_util.cc @@ -46,9 +46,8 @@ const char kLoginSettingValue[] = "1"; } // namespace -scoped_ptr<mcs_proto::LoginRequest> BuildLoginRequest( - uint64 auth_id, - uint64 auth_token) { +scoped_ptr<mcs_proto::LoginRequest> BuildLoginRequest(uint64 auth_id, + uint64 auth_token) { // Create a hex encoded auth id for the device id field. std::string auth_id_hex; auth_id_hex = base::StringPrintf("%" PRIx64, auth_id); @@ -87,6 +86,20 @@ scoped_ptr<mcs_proto::IqStanza> BuildStreamAck() { return stream_ack_iq.Pass(); } +scoped_ptr<mcs_proto::IqStanza> BuildSelectiveAck( + const std::vector<std::string>& acked_ids) { + scoped_ptr<mcs_proto::IqStanza> selective_ack_iq(new mcs_proto::IqStanza()); + selective_ack_iq->set_type(mcs_proto::IqStanza::SET); + selective_ack_iq->set_id(""); + selective_ack_iq->mutable_extension()->set_id(kSelectiveAck); + mcs_proto::SelectiveAck selective_ack; + for (size_t i = 0; i < acked_ids.size(); ++i) + selective_ack.add_id(acked_ids[i]); + selective_ack_iq->mutable_extension()->set_data( + selective_ack.SerializeAsString()); + return selective_ack_iq.Pass(); +} + // Utility method to build a google::protobuf::MessageLite object from a MCS // tag. scoped_ptr<google::protobuf::MessageLite> BuildProtobufFromTag(uint8 tag) { diff --git a/google_apis/gcm/base/mcs_util.h b/google_apis/gcm/base/mcs_util.h index d125af7..7f92564 100644 --- a/google_apis/gcm/base/mcs_util.h +++ b/google_apis/gcm/base/mcs_util.h @@ -56,6 +56,8 @@ GCM_EXPORT scoped_ptr<mcs_proto::LoginRequest> BuildLoginRequest( // Builds a StreamAck IqStanza message. GCM_EXPORT scoped_ptr<mcs_proto::IqStanza> BuildStreamAck(); +GCM_EXPORT scoped_ptr<mcs_proto::IqStanza> BuildSelectiveAck( + const std::vector<std::string>& acked_ids); // Utility methods for building and identifying MCS protobufs. GCM_EXPORT scoped_ptr<google::protobuf::MessageLite> diff --git a/google_apis/gcm/engine/connection_factory.h b/google_apis/gcm/engine/connection_factory.h index 598c211..3cff482 100644 --- a/google_apis/gcm/engine/connection_factory.h +++ b/google_apis/gcm/engine/connection_factory.h @@ -20,26 +20,34 @@ namespace gcm { // backoff policies when attempting connections. class GCM_EXPORT ConnectionFactory { public: + typedef base::Callback<void(mcs_proto::LoginRequest* login_request)> + BuildLoginRequestCallback; + ConnectionFactory(); virtual ~ConnectionFactory(); - // Create a new uninitialized connection handler. Should only be called once. - // The factory will retain ownership of the connection handler. + // Initialize the factory, creating a connection handler with a disconnected + // socket. Should only be called once. + // Upon connection: // |read_callback| will be invoked with the contents of any received protobuf // message. // |write_callback| will be invoked anytime a message has been successfully // sent. Note: this just means the data was sent to the wire, not that the // other end received it. - virtual ConnectionHandler* BuildConnectionHandler( + virtual void Initialize( + const BuildLoginRequestCallback& request_builder, const ConnectionHandler::ProtoReceivedCallback& read_callback, const ConnectionHandler::ProtoSentCallback& write_callback) = 0; - // Opens a new connection for use by the locally owned connection handler - // (created via BuildConnectionHandler), and initiates login handshake using - // |login_request|. Upon completion of the handshake, |read_callback| - // will be invoked with a valid mcs_proto::LoginResponse. - // Note: BuildConnectionHandler must have already been invoked. - virtual void Connect(const mcs_proto::LoginRequest& login_request) = 0; + // Get the connection handler for this factory. Initialize(..) must have + // been called. + virtual ConnectionHandler* GetConnectionHandler() const = 0; + + // Opens a new connection and initiates login handshake. Upon completion of + // the handshake, |read_callback| will be invoked with a valid + // mcs_proto::LoginResponse. + // Note: Initialize must have already been invoked. + virtual void Connect() = 0; // Whether or not the MCS endpoint is currently reachable with an active // connection. diff --git a/google_apis/gcm/engine/connection_factory_impl.cc b/google_apis/gcm/engine/connection_factory_impl.cc index 0a87acc..388b9dc 100644 --- a/google_apis/gcm/engine/connection_factory_impl.cc +++ b/google_apis/gcm/engine/connection_factory_impl.cc @@ -64,12 +64,14 @@ ConnectionFactoryImpl::ConnectionFactoryImpl( ConnectionFactoryImpl::~ConnectionFactoryImpl() { } -ConnectionHandler* ConnectionFactoryImpl::BuildConnectionHandler( - const ConnectionHandler::ProtoReceivedCallback& read_callback, - const ConnectionHandler::ProtoSentCallback& write_callback) { +void ConnectionFactoryImpl::Initialize( + const BuildLoginRequestCallback& request_builder, + const ConnectionHandler::ProtoReceivedCallback& read_callback, + const ConnectionHandler::ProtoSentCallback& write_callback) { DCHECK(!connection_handler_); backoff_entry_ = CreateBackoffEntry(&kConnectionBackoffPolicy); + request_builder_ = request_builder; net::NetworkChangeNotifier::AddIPAddressObserver(this); net::NetworkChangeNotifier::AddConnectionTypeObserver(this); @@ -80,19 +82,16 @@ ConnectionHandler* ConnectionFactoryImpl::BuildConnectionHandler( write_callback, base::Bind(&ConnectionFactoryImpl::ConnectionHandlerCallback, weak_ptr_factory_.GetWeakPtr()))); +} + +ConnectionHandler* ConnectionFactoryImpl::GetConnectionHandler() const { return connection_handler_.get(); } -void ConnectionFactoryImpl::Connect( - const mcs_proto::LoginRequest& login_request) { +void ConnectionFactoryImpl::Connect() { DCHECK(connection_handler_); DCHECK(!IsEndpointReachable()); - if (login_request.IsInitialized()) { - DCHECK(!login_request_.IsInitialized()); - login_request_ = login_request; - } - if (backoff_entry_->ShouldRejectRequest()) { DVLOG(1) << "Delaying MCS endpoint connection for " << backoff_entry_->GetTimeUntilRelease().InMilliseconds() @@ -100,8 +99,7 @@ void ConnectionFactoryImpl::Connect( base::MessageLoop::current()->PostDelayedTask( FROM_HERE, base::Bind(&ConnectionFactoryImpl::Connect, - weak_ptr_factory_.GetWeakPtr(), - login_request_), + weak_ptr_factory_.GetWeakPtr()), NextRetryAttempt() - base::TimeTicks::Now()); return; } @@ -165,7 +163,14 @@ void ConnectionFactoryImpl::ConnectImpl() { } void ConnectionFactoryImpl::InitHandler() { - connection_handler_->Init(login_request_, socket_handle_.PassSocket()); + // May be null in tests. + mcs_proto::LoginRequest login_request; + if (!request_builder_.is_null()) { + request_builder_.Run(&login_request); + DCHECK(login_request.IsInitialized()); + } + + connection_handler_->Init(login_request, socket_handle_.PassSocket()); } scoped_ptr<net::BackoffEntry> ConnectionFactoryImpl::CreateBackoffEntry( @@ -177,7 +182,7 @@ void ConnectionFactoryImpl::OnConnectDone(int result) { if (result != net::OK) { LOG(ERROR) << "Failed to connect to MCS endpoint with error " << result; backoff_entry_->InformOfRequest(false); - Connect(mcs_proto::LoginRequest()); + Connect(); return; } @@ -194,7 +199,7 @@ void ConnectionFactoryImpl::ConnectionHandlerCallback(int result) { // user intervention (login page, etc.). LOG(ERROR) << "Connection reset with error " << result; backoff_entry_->InformOfRequest(false); - Connect(mcs_proto::LoginRequest()); + Connect(); } } // namespace gcm diff --git a/google_apis/gcm/engine/connection_factory_impl.h b/google_apis/gcm/engine/connection_factory_impl.h index 0e40521..d807270 100644 --- a/google_apis/gcm/engine/connection_factory_impl.h +++ b/google_apis/gcm/engine/connection_factory_impl.h @@ -35,10 +35,12 @@ class GCM_EXPORT ConnectionFactoryImpl : virtual ~ConnectionFactoryImpl(); // ConnectionFactory implementation. - virtual ConnectionHandler* BuildConnectionHandler( + virtual void Initialize( + const BuildLoginRequestCallback& request_builder, const ConnectionHandler::ProtoReceivedCallback& read_callback, const ConnectionHandler::ProtoSentCallback& write_callback) OVERRIDE; - virtual void Connect(const mcs_proto::LoginRequest& login_request) OVERRIDE; + virtual ConnectionHandler* GetConnectionHandler() const OVERRIDE; + virtual void Connect() OVERRIDE; virtual bool IsEndpointReachable() const OVERRIDE; virtual base::TimeTicks NextRetryAttempt() const OVERRIDE; @@ -86,8 +88,8 @@ class GCM_EXPORT ConnectionFactoryImpl : // The current connection handler, if one exists. scoped_ptr<ConnectionHandlerImpl> connection_handler_; - // The current login request if a connection attempt is in progress/pending. - mcs_proto::LoginRequest login_request_; + // Builder for generating new login requests. + BuildLoginRequestCallback request_builder_; base::WeakPtrFactory<ConnectionFactoryImpl> weak_ptr_factory_; diff --git a/google_apis/gcm/engine/connection_factory_impl_unittest.cc b/google_apis/gcm/engine/connection_factory_impl_unittest.cc index 40adcf2..1e0ccef 100644 --- a/google_apis/gcm/engine/connection_factory_impl_unittest.cc +++ b/google_apis/gcm/engine/connection_factory_impl_unittest.cc @@ -192,43 +192,48 @@ void ConnectionFactoryImplTest::ConnectionsComplete() { } // Verify building a connection handler works. -TEST_F(ConnectionFactoryImplTest, BuildConnectionHandler) { +TEST_F(ConnectionFactoryImplTest, Initialize) { EXPECT_FALSE(factory()->IsEndpointReachable()); - ConnectionHandler* handler = factory()->BuildConnectionHandler( + factory()->Initialize( + ConnectionFactory::BuildLoginRequestCallback(), base::Bind(&ReadContinuation), base::Bind(&WriteContinuation)); + ConnectionHandler* handler = factory()->GetConnectionHandler(); ASSERT_TRUE(handler); EXPECT_FALSE(factory()->IsEndpointReachable()); } // An initial successful connection should not result in backoff. TEST_F(ConnectionFactoryImplTest, ConnectSuccess) { - factory()->BuildConnectionHandler( + factory()->Initialize( + ConnectionFactory::BuildLoginRequestCallback(), ConnectionHandler::ProtoReceivedCallback(), ConnectionHandler::ProtoSentCallback()); factory()->SetConnectResult(net::OK); - factory()->Connect(mcs_proto::LoginRequest()); + factory()->Connect(); EXPECT_TRUE(factory()->NextRetryAttempt().is_null()); } // A connection failure should result in backoff. TEST_F(ConnectionFactoryImplTest, ConnectFail) { - factory()->BuildConnectionHandler( + factory()->Initialize( + ConnectionFactory::BuildLoginRequestCallback(), ConnectionHandler::ProtoReceivedCallback(), ConnectionHandler::ProtoSentCallback()); factory()->SetConnectResult(net::ERR_CONNECTION_FAILED); - factory()->Connect(mcs_proto::LoginRequest()); + factory()->Connect(); EXPECT_FALSE(factory()->NextRetryAttempt().is_null()); } // A connection success after a failure should reset backoff. TEST_F(ConnectionFactoryImplTest, FailThenSucceed) { - factory()->BuildConnectionHandler( + factory()->Initialize( + ConnectionFactory::BuildLoginRequestCallback(), ConnectionHandler::ProtoReceivedCallback(), ConnectionHandler::ProtoSentCallback()); factory()->SetConnectResult(net::ERR_CONNECTION_FAILED); base::TimeTicks connect_time = base::TimeTicks::Now(); - factory()->Connect(mcs_proto::LoginRequest()); + factory()->Connect(); WaitForConnections(); base::TimeTicks retry_time = factory()->NextRetryAttempt(); EXPECT_FALSE(retry_time.is_null()); @@ -241,7 +246,8 @@ TEST_F(ConnectionFactoryImplTest, FailThenSucceed) { // Multiple connection failures should retry with an exponentially increasing // backoff, then reset on success. TEST_F(ConnectionFactoryImplTest, MultipleFailuresThenSucceed) { - factory()->BuildConnectionHandler( + factory()->Initialize( + ConnectionFactory::BuildLoginRequestCallback(), ConnectionHandler::ProtoReceivedCallback(), ConnectionHandler::ProtoSentCallback()); @@ -250,7 +256,7 @@ TEST_F(ConnectionFactoryImplTest, MultipleFailuresThenSucceed) { kNumAttempts); base::TimeTicks connect_time = base::TimeTicks::Now(); - factory()->Connect(mcs_proto::LoginRequest()); + factory()->Connect(); WaitForConnections(); base::TimeTicks retry_time = factory()->NextRetryAttempt(); EXPECT_FALSE(retry_time.is_null()); @@ -264,11 +270,12 @@ TEST_F(ConnectionFactoryImplTest, MultipleFailuresThenSucceed) { // IP events should reset backoff. TEST_F(ConnectionFactoryImplTest, FailThenIPEvent) { - factory()->BuildConnectionHandler( + factory()->Initialize( + ConnectionFactory::BuildLoginRequestCallback(), ConnectionHandler::ProtoReceivedCallback(), ConnectionHandler::ProtoSentCallback()); factory()->SetConnectResult(net::ERR_CONNECTION_FAILED); - factory()->Connect(mcs_proto::LoginRequest()); + factory()->Connect(); WaitForConnections(); EXPECT_FALSE(factory()->NextRetryAttempt().is_null()); @@ -278,11 +285,12 @@ TEST_F(ConnectionFactoryImplTest, FailThenIPEvent) { // Connection type events should reset backoff. TEST_F(ConnectionFactoryImplTest, FailThenConnectionTypeEvent) { - factory()->BuildConnectionHandler( + factory()->Initialize( + ConnectionFactory::BuildLoginRequestCallback(), ConnectionHandler::ProtoReceivedCallback(), ConnectionHandler::ProtoSentCallback()); factory()->SetConnectResult(net::ERR_CONNECTION_FAILED); - factory()->Connect(mcs_proto::LoginRequest()); + factory()->Connect(); WaitForConnections(); EXPECT_FALSE(factory()->NextRetryAttempt().is_null()); diff --git a/google_apis/gcm/engine/fake_connection_factory.cc b/google_apis/gcm/engine/fake_connection_factory.cc new file mode 100644 index 0000000..54b3423 --- /dev/null +++ b/google_apis/gcm/engine/fake_connection_factory.cc @@ -0,0 +1,46 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "google_apis/gcm/engine/fake_connection_factory.h" + +#include "google_apis/gcm/engine/fake_connection_handler.h" +#include "google_apis/gcm/protocol/mcs.pb.h" +#include "net/socket/stream_socket.h" + +namespace gcm { + +FakeConnectionFactory::FakeConnectionFactory() { +} + +FakeConnectionFactory::~FakeConnectionFactory() { +} + +void FakeConnectionFactory::Initialize( + const BuildLoginRequestCallback& request_builder, + const ConnectionHandler::ProtoReceivedCallback& read_callback, + const ConnectionHandler::ProtoSentCallback& write_callback) { + request_builder_ = request_builder; + connection_handler_.reset(new FakeConnectionHandler(read_callback, + write_callback)); +} + +ConnectionHandler* FakeConnectionFactory::GetConnectionHandler() const { + return connection_handler_.get(); +} + +void FakeConnectionFactory::Connect() { + mcs_proto::LoginRequest login_request; + request_builder_.Run(&login_request); + connection_handler_->Init(login_request, scoped_ptr<net::StreamSocket>()); +} + +bool FakeConnectionFactory::IsEndpointReachable() const { + return connection_handler_.get() && connection_handler_->CanSendMessage(); +} + +base::TimeTicks FakeConnectionFactory::NextRetryAttempt() const { + return base::TimeTicks(); +} + +} // namespace gcm diff --git a/google_apis/gcm/engine/fake_connection_factory.h b/google_apis/gcm/engine/fake_connection_factory.h new file mode 100644 index 0000000..60b10e1 --- /dev/null +++ b/google_apis/gcm/engine/fake_connection_factory.h @@ -0,0 +1,42 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef GOOGLE_APIS_GCM_ENGINE_FAKE_CONNECTION_FACTORY_H_ +#define GOOGLE_APIS_GCM_ENGINE_FAKE_CONNECTION_FACTORY_H_ + +#include "base/memory/scoped_ptr.h" +#include "google_apis/gcm/engine/connection_factory.h" + +namespace gcm { + +class FakeConnectionHandler; + +// A connection factory that mocks out real connections, using a fake connection +// handler instead. +class FakeConnectionFactory : public ConnectionFactory { + public: + FakeConnectionFactory(); + virtual ~FakeConnectionFactory(); + + // ConnectionFactory implementation. + virtual void Initialize( + const BuildLoginRequestCallback& request_builder, + const ConnectionHandler::ProtoReceivedCallback& read_callback, + const ConnectionHandler::ProtoSentCallback& write_callback) OVERRIDE; + virtual ConnectionHandler* GetConnectionHandler() const OVERRIDE; + virtual void Connect() OVERRIDE; + virtual bool IsEndpointReachable() const OVERRIDE; + virtual base::TimeTicks NextRetryAttempt() const OVERRIDE; + + private: + scoped_ptr<FakeConnectionHandler> connection_handler_; + + BuildLoginRequestCallback request_builder_; + + DISALLOW_COPY_AND_ASSIGN(FakeConnectionFactory); +}; + +} // namespace gcm + +#endif // GOOGLE_APIS_GCM_ENGINE_FAKE_CONNECTION_FACTORY_H_ diff --git a/google_apis/gcm/engine/fake_connection_handler.cc b/google_apis/gcm/engine/fake_connection_handler.cc new file mode 100644 index 0000000..0663933 --- /dev/null +++ b/google_apis/gcm/engine/fake_connection_handler.cc @@ -0,0 +1,86 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "google_apis/gcm/engine/fake_connection_handler.h" + +#include "base/logging.h" +#include "google_apis/gcm/base/mcs_util.h" +#include "net/socket/stream_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace gcm { + +namespace { + +// Build a basic login response. +scoped_ptr<google::protobuf::MessageLite> BuildLoginResponse(bool fail_login) { + scoped_ptr<mcs_proto::LoginResponse> login_response( + new mcs_proto::LoginResponse()); + login_response->set_id("id"); + if (fail_login) + login_response->mutable_error()->set_code(1); + return login_response.PassAs<google::protobuf::MessageLite>(); +} + +} // namespace + +FakeConnectionHandler::FakeConnectionHandler( + const ConnectionHandler::ProtoReceivedCallback& read_callback, + const ConnectionHandler::ProtoSentCallback& write_callback) + : read_callback_(read_callback), + write_callback_(write_callback), + fail_login_(false), + fail_send_(false), + initialized_(false) { +} + +FakeConnectionHandler::~FakeConnectionHandler() { +} + +void FakeConnectionHandler::Init(const mcs_proto::LoginRequest& login_request, + scoped_ptr<net::StreamSocket> socket) { + EXPECT_EQ(expected_outgoing_messages_.front().SerializeAsString(), + login_request.SerializeAsString()); + expected_outgoing_messages_.pop_front(); + DVLOG(1) << "Received init call."; + read_callback_.Run(BuildLoginResponse(fail_login_)); + initialized_ = !fail_login_; +} + +bool FakeConnectionHandler::CanSendMessage() const { + return initialized_; +} + +void FakeConnectionHandler::SendMessage( + const google::protobuf::MessageLite& message) { + if (expected_outgoing_messages_.empty()) + FAIL() << "Unexpected message sent."; + EXPECT_EQ(expected_outgoing_messages_.front().SerializeAsString(), + message.SerializeAsString()); + expected_outgoing_messages_.pop_front(); + DVLOG(1) << "Received message, " + << (fail_send_ ? " failing send." : "calling back."); + if (!fail_send_) + write_callback_.Run(); + else + initialized_ = false; // Prevent future messages until reconnect. +} + +void FakeConnectionHandler::ExpectOutgoingMessage(const MCSMessage& message) { + expected_outgoing_messages_.push_back(message); +} + +void FakeConnectionHandler::ResetOutgoingMessageExpectations() { + expected_outgoing_messages_.clear(); +} + +bool FakeConnectionHandler::AllOutgoingMessagesReceived() const { + return expected_outgoing_messages_.empty(); +} + +void FakeConnectionHandler::ReceiveMessage(const MCSMessage& message) { + read_callback_.Run(message.CloneProtobuf()); +} + +} // namespace gcm diff --git a/google_apis/gcm/engine/fake_connection_handler.h b/google_apis/gcm/engine/fake_connection_handler.h new file mode 100644 index 0000000..5356b77 --- /dev/null +++ b/google_apis/gcm/engine/fake_connection_handler.h @@ -0,0 +1,74 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef GOOGLE_APIS_GCM_ENGINE_FAKE_CONNECTION_HANDLER_H_ +#define GOOGLE_APIS_GCM_ENGINE_FAKE_CONNECTION_HANDLER_H_ + +#include <list> + +#include "google_apis/gcm/base/mcs_message.h" +#include "google_apis/gcm/engine/connection_handler.h" + +namespace gcm { + +// A fake implementation of a ConnectionHandler that can arbitrarily receive +// messages and verify expectations for outgoing messages. +class FakeConnectionHandler : public ConnectionHandler { + public: + FakeConnectionHandler( + const ConnectionHandler::ProtoReceivedCallback& read_callback, + const ConnectionHandler::ProtoSentCallback& write_callback); + virtual ~FakeConnectionHandler(); + + // ConnectionHandler implementation. + virtual void Init(const mcs_proto::LoginRequest& login_request, + scoped_ptr<net::StreamSocket> socket) OVERRIDE; + virtual bool CanSendMessage() const OVERRIDE; + virtual void SendMessage(const google::protobuf::MessageLite& message) + OVERRIDE; + + // EXPECT's receipt of |message| via SendMessage(..). + void ExpectOutgoingMessage(const MCSMessage& message); + + // Reset the expected outgoing messages. + void ResetOutgoingMessageExpectations(); + + // Whether all expected outgoing messages have been received; + bool AllOutgoingMessagesReceived() const; + + // Passes on |message| to |write_callback_|. + void ReceiveMessage(const MCSMessage& message); + + // Whether to return an error with the next login response. + void set_fail_login(bool fail_login) { + fail_login_ = fail_login; + } + + // Whether to invoke the write callback on the next send attempt or fake a + // connection error instead. + void set_fail_send(bool fail_send) { + fail_send_ = fail_send; + } + + private: + ConnectionHandler::ProtoReceivedCallback read_callback_; + ConnectionHandler::ProtoSentCallback write_callback_; + + std::list<MCSMessage> expected_outgoing_messages_; + + // Whether to fail the login or not. + bool fail_login_; + + // Whether to fail a SendMessage call or not. + bool fail_send_; + + // Whether a successful login has completed. + bool initialized_; + + DISALLOW_COPY_AND_ASSIGN(FakeConnectionHandler); +}; + +} // namespace gcm + +#endif // GOOGLE_APIS_GCM_ENGINE_FAKE_CONNECTION_HANDLER_H_ diff --git a/google_apis/gcm/engine/mcs_client.cc b/google_apis/gcm/engine/mcs_client.cc new file mode 100644 index 0000000..f0af051 --- /dev/null +++ b/google_apis/gcm/engine/mcs_client.cc @@ -0,0 +1,659 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "google_apis/gcm/engine/mcs_client.h" + +#include "base/basictypes.h" +#include "base/message_loop/message_loop.h" +#include "base/strings/string_number_conversions.h" +#include "google_apis/gcm/base/mcs_util.h" +#include "google_apis/gcm/base/socket_stream.h" +#include "google_apis/gcm/engine/connection_factory.h" +#include "google_apis/gcm/engine/rmq_store.h" + +using namespace google::protobuf::io; + +namespace gcm { + +namespace { + +typedef scoped_ptr<google::protobuf::MessageLite> MCSProto; + +// TODO(zea): get these values from MCS settings. +const int64 kHeartbeatDefaultSeconds = 60 * 15; // 15 minutes. + +// The category of messages intended for the GCM client itself from MCS. +const char kMCSCategory[] = "com.google.android.gsf.gtalkservice"; + +// The from field for messages originating in the GCM client. +const char kGCMFromField[] = "gcm@android.com"; + +// MCS status message types. +const char kIdleNotification[] = "IdleNotification"; +// TODO(zea): consume the following message types: +// const char kAlwaysShowOnIdle[] = "ShowAwayOnIdle"; +// const char kPowerNotification[] = "PowerNotification"; +// const char kDataActiveNotification[] = "DataActiveNotification"; + +// The number of unacked messages to allow before sending a stream ack. +// Applies to both incoming and outgoing messages. +// TODO(zea): make this server configurable. +const int kUnackedMessageBeforeStreamAck = 10; + +// The global maximum number of pending messages to have in the send queue. +const size_t kMaxSendQueueSize = 10 * 1024; + +// The maximum message size that can be sent to the server. +const int kMaxMessageBytes = 4 * 1024; // 4KB, like the server. + +// Helper for converting a proto persistent id list to a vector of strings. +bool BuildPersistentIdListFromProto(const google::protobuf::string& bytes, + std::vector<std::string>* id_list) { + mcs_proto::SelectiveAck selective_ack; + if (!selective_ack.ParseFromString(bytes)) + return false; + std::vector<std::string> new_list; + for (int i = 0; i < selective_ack.id_size(); ++i) { + DCHECK(!selective_ack.id(i).empty()); + new_list.push_back(selective_ack.id(i)); + } + id_list->swap(new_list); + return true; +} + +} // namespace + +struct ReliablePacketInfo { + ReliablePacketInfo(); + ~ReliablePacketInfo(); + + // The stream id with which the message was sent. + uint32 stream_id; + + // If reliable delivery was requested, the persistent id of the message. + std::string persistent_id; + + // The type of message itself (for easier lookup). + uint8 tag; + + // The protobuf of the message itself. + MCSProto protobuf; +}; + +ReliablePacketInfo::ReliablePacketInfo() + : stream_id(0), tag(0) { +} +ReliablePacketInfo::~ReliablePacketInfo() {} + +MCSClient::MCSClient( + const base::FilePath& rmq_path, + ConnectionFactory* connection_factory, + scoped_refptr<base::SequencedTaskRunner> blocking_task_runner) + : state_(UNINITIALIZED), + android_id_(0), + security_token_(0), + connection_factory_(connection_factory), + connection_handler_(NULL), + last_device_to_server_stream_id_received_(0), + last_server_to_device_stream_id_received_(0), + stream_id_out_(0), + stream_id_in_(0), + rmq_store_(rmq_path, blocking_task_runner), + heartbeat_interval_( + base::TimeDelta::FromSeconds(kHeartbeatDefaultSeconds)), + heartbeat_timer_(true, true), + blocking_task_runner_(blocking_task_runner), + weak_ptr_factory_(this) { +} + +MCSClient::~MCSClient() { +} + +void MCSClient::Initialize( + const InitializationCompleteCallback& initialization_callback, + const OnMessageReceivedCallback& message_received_callback, + const OnMessageSentCallback& message_sent_callback) { + DCHECK_EQ(state_, UNINITIALIZED); + initialization_callback_ = initialization_callback; + message_received_callback_ = message_received_callback; + message_sent_callback_ = message_sent_callback; + + state_ = LOADING; + rmq_store_.Load(base::Bind(&MCSClient::OnRMQLoadFinished, + weak_ptr_factory_.GetWeakPtr())); + + connection_factory_->Initialize( + base::Bind(&MCSClient::ResetStateAndBuildLoginRequest, + weak_ptr_factory_.GetWeakPtr()), + base::Bind(&MCSClient::HandlePacketFromWire, + weak_ptr_factory_.GetWeakPtr()), + base::Bind(&MCSClient::MaybeSendMessage, + weak_ptr_factory_.GetWeakPtr())); + connection_handler_ = connection_factory_->GetConnectionHandler(); +} + +void MCSClient::Login(uint64 android_id, uint64 security_token) { + DCHECK_EQ(state_, LOADED); + if (android_id != android_id_ && security_token != security_token_) { + DCHECK(android_id); + DCHECK(security_token); + DCHECK(restored_unackeds_server_ids_.empty()); + android_id_ = android_id; + security_token_ = security_token; + rmq_store_.SetDeviceCredentials(android_id_, + security_token_, + base::Bind(&MCSClient::OnRMQUpdateFinished, + weak_ptr_factory_.GetWeakPtr())); + } + + state_ = CONNECTING; + connection_factory_->Connect(); +} + +void MCSClient::SendMessage(const MCSMessage& message, bool use_rmq) { + DCHECK_EQ(state_, CONNECTED); + if (to_send_.size() > kMaxSendQueueSize) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(message_sent_callback_, "Message queue full.")); + return; + } + if (message.size() > kMaxMessageBytes) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(message_sent_callback_, "Message too large.")); + return; + } + + ReliablePacketInfo* packet_info = new ReliablePacketInfo(); + packet_info->protobuf = message.CloneProtobuf(); + + if (use_rmq) { + PersistentId persistent_id = GetNextPersistentId(); + DVLOG(1) << "Setting persistent id to " << persistent_id; + packet_info->persistent_id = persistent_id; + SetPersistentId(persistent_id, + packet_info->protobuf.get()); + rmq_store_.AddOutgoingMessage(persistent_id, + MCSMessage(message.tag(), + *(packet_info->protobuf)), + base::Bind(&MCSClient::OnRMQUpdateFinished, + weak_ptr_factory_.GetWeakPtr())); + } else { + // Check that there is an active connection to the endpoint. + if (!connection_handler_->CanSendMessage()) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(message_sent_callback_, "Unable to reach endpoint")); + return; + } + } + to_send_.push_back(make_linked_ptr(packet_info)); + MaybeSendMessage(); +} + +void MCSClient::Destroy() { + rmq_store_.Destroy(base::Bind(&MCSClient::OnRMQUpdateFinished, + weak_ptr_factory_.GetWeakPtr())); +} + +void MCSClient::ResetStateAndBuildLoginRequest( + mcs_proto::LoginRequest* request) { + DCHECK(android_id_); + DCHECK(security_token_); + stream_id_in_ = 0; + stream_id_out_ = 1; + last_device_to_server_stream_id_received_ = 0; + last_server_to_device_stream_id_received_ = 0; + + // TODO(zea): expire all messages older than their TTL. + + // Add any pending acknowledgments to the list of ids. + for (StreamIdToPersistentIdMap::const_iterator iter = + unacked_server_ids_.begin(); + iter != unacked_server_ids_.end(); ++iter) { + restored_unackeds_server_ids_.push_back(iter->second); + } + unacked_server_ids_.clear(); + + // Any acknowledged server ids which have not been confirmed by the server + // are treated like unacknowledged ids. + for (std::map<StreamId, PersistentIdList>::const_iterator iter = + acked_server_ids_.begin(); + iter != acked_server_ids_.end(); ++iter) { + restored_unackeds_server_ids_.insert(restored_unackeds_server_ids_.end(), + iter->second.begin(), + iter->second.end()); + } + acked_server_ids_.clear(); + + // Then build the request, consuming all pending acknowledgments. + request->Swap(BuildLoginRequest(android_id_, security_token_).get()); + for (PersistentIdList::const_iterator iter = + restored_unackeds_server_ids_.begin(); + iter != restored_unackeds_server_ids_.end(); ++iter) { + request->add_received_persistent_id(*iter); + } + acked_server_ids_[stream_id_out_] = restored_unackeds_server_ids_; + restored_unackeds_server_ids_.clear(); + + // Push all unacknowledged messages to front of send queue. No need to save + // to RMQ, as all messages that reach this point should already have been + // saved as necessary. + while (!to_resend_.empty()) { + to_send_.push_front(to_resend_.back()); + to_resend_.pop_back(); + } + DVLOG(1) << "Resetting state, with " << request->received_persistent_id_size() + << " incoming acks pending, and " << to_send_.size() + << " pending outgoing messages."; + + heartbeat_timer_.Stop(); + + state_ = CONNECTING; +} + +void MCSClient::SendHeartbeat() { + SendMessage(MCSMessage(kHeartbeatPingTag, mcs_proto::HeartbeatPing()), + false); +} + +void MCSClient::OnRMQLoadFinished(const RMQStore::LoadResult& result) { + if (!result.success) { + state_ = UNINITIALIZED; + LOG(ERROR) << "Failed to load/create RMQ state. Not connecting."; + initialization_callback_.Run(false, 0, 0); + return; + } + state_ = LOADED; + stream_id_out_ = 1; // Login request is hardcoded to id 1. + + if (result.device_android_id == 0 || result.device_security_token == 0) { + DVLOG(1) << "No device credentials found, assuming new client."; + initialization_callback_.Run(true, 0, 0); + return; + } + + android_id_ = result.device_android_id; + security_token_ = result.device_security_token; + + DVLOG(1) << "RMQ Load finished with " << result.incoming_messages.size() + << " incoming acks pending and " << result.outgoing_messages.size() + << " outgoing messages pending."; + + restored_unackeds_server_ids_ = result.incoming_messages; + + // First go through and order the outgoing messages by recency. + std::map<uint64, google::protobuf::MessageLite*> ordered_messages; + for (std::map<PersistentId, google::protobuf::MessageLite*>::const_iterator + iter = result.outgoing_messages.begin(); + iter != result.outgoing_messages.end(); ++iter) { + uint64 timestamp = 0; + if (!base::StringToUint64(iter->first, ×tamp)) { + LOG(ERROR) << "Invalid restored message."; + return; + } + ordered_messages[timestamp] = iter->second; + } + + // Now go through and add the outgoing messages to the send queue in their + // appropriate order (oldest at front, most recent at back). + for (std::map<uint64, google::protobuf::MessageLite*>::const_iterator + iter = ordered_messages.begin(); + iter != ordered_messages.end(); ++iter) { + ReliablePacketInfo* packet_info = new ReliablePacketInfo(); + packet_info->protobuf.reset(iter->second); + packet_info->persistent_id = base::Uint64ToString(iter->first); + to_send_.push_back(make_linked_ptr(packet_info)); + } + + initialization_callback_.Run(true, android_id_, security_token_); +} + +void MCSClient::OnRMQUpdateFinished(bool success) { + LOG_IF(ERROR, !success) << "RMQ Update failed!"; + // TODO(zea): Rebuild the store from scratch in case of persistence failure? +} + +void MCSClient::MaybeSendMessage() { + if (to_send_.empty()) + return; + + if (!connection_handler_->CanSendMessage()) + return; + + // TODO(zea): drop messages older than their TTL. + + DVLOG(1) << "Pending output message found, sending."; + MCSPacketInternal packet = to_send_.front(); + to_send_.pop_front(); + if (!packet->persistent_id.empty()) + to_resend_.push_back(packet); + SendPacketToWire(packet.get()); +} + +void MCSClient::SendPacketToWire(ReliablePacketInfo* packet_info) { + // Reset the heartbeat interval. + heartbeat_timer_.Reset(); + packet_info->stream_id = ++stream_id_out_; + DVLOG(1) << "Sending packet of type " << packet_info->protobuf->GetTypeName(); + + // Set the proper last received stream id to acknowledge received server + // packets. + DVLOG(1) << "Setting last stream id received to " + << stream_id_in_; + SetLastStreamIdReceived(stream_id_in_, + packet_info->protobuf.get()); + if (stream_id_in_ != last_server_to_device_stream_id_received_) { + last_server_to_device_stream_id_received_ = stream_id_in_; + // Mark all acknowledged server messages as such. Note: they're not dropped, + // as it may be that they'll need to be re-acked if this message doesn't + // make it. + PersistentIdList persistent_id_list; + for (StreamIdToPersistentIdMap::const_iterator iter = + unacked_server_ids_.begin(); + iter != unacked_server_ids_.end(); ++iter) { + DCHECK_LE(iter->first, last_server_to_device_stream_id_received_); + persistent_id_list.push_back(iter->second); + } + unacked_server_ids_.clear(); + acked_server_ids_[stream_id_out_] = persistent_id_list; + } + + connection_handler_->SendMessage(*packet_info->protobuf); +} + +void MCSClient::HandleMCSDataMesssage( + scoped_ptr<google::protobuf::MessageLite> protobuf) { + mcs_proto::DataMessageStanza* data_message = + reinterpret_cast<mcs_proto::DataMessageStanza*>(protobuf.get()); + // TODO(zea): implement a proper status manager rather than hardcoding these + // values. + scoped_ptr<mcs_proto::DataMessageStanza> response( + new mcs_proto::DataMessageStanza()); + response->set_from(kGCMFromField); + bool send = false; + for (int i = 0; i < data_message->app_data_size(); ++i) { + const mcs_proto::AppData& app_data = data_message->app_data(i); + if (app_data.key() == kIdleNotification) { + // Tell the MCS server the client is not idle. + send = true; + mcs_proto::AppData data; + data.set_key(kIdleNotification); + data.set_value("false"); + response->add_app_data()->CopyFrom(data); + response->set_category(kMCSCategory); + } + } + + if (send) { + SendMessage( + MCSMessage(kDataMessageStanzaTag, + response.PassAs<const google::protobuf::MessageLite>()), + false); + } +} + +void MCSClient::HandlePacketFromWire( + scoped_ptr<google::protobuf::MessageLite> protobuf) { + if (!protobuf.get()) + return; + uint8 tag = GetMCSProtoTag(*protobuf); + PersistentId persistent_id = GetPersistentId(*protobuf); + StreamId last_stream_id_received = GetLastStreamIdReceived(*protobuf); + + if (last_stream_id_received != 0) { + last_device_to_server_stream_id_received_ = last_stream_id_received; + + // Process device to server messages that have now been acknowledged by the + // server. Because messages are stored in order, just pop off all that have + // a stream id lower than server's last received stream id. + HandleStreamAck(last_stream_id_received); + + // Process server_to_device_messages that the server now knows were + // acknowledged. Again, they're in order, so just keep going until the + // stream id is reached. + StreamIdList acked_stream_ids_to_remove; + for (std::map<StreamId, PersistentIdList>::iterator iter = + acked_server_ids_.begin(); + iter != acked_server_ids_.end() && + iter->first <= last_stream_id_received; ++iter) { + acked_stream_ids_to_remove.push_back(iter->first); + } + for (StreamIdList::iterator iter = acked_stream_ids_to_remove.begin(); + iter != acked_stream_ids_to_remove.end(); ++iter) { + acked_server_ids_.erase(*iter); + } + } + + ++stream_id_in_; + if (!persistent_id.empty()) { + unacked_server_ids_[stream_id_in_] = persistent_id; + rmq_store_.AddIncomingMessage(persistent_id, + base::Bind(&MCSClient::OnRMQUpdateFinished, + weak_ptr_factory_.GetWeakPtr())); + } + + DVLOG(1) << "Received message of type " << protobuf->GetTypeName() + << " with persistent id " + << (persistent_id.empty() ? "NULL" : persistent_id) + << ", stream id " << stream_id_in_ << " and last stream id received " + << last_stream_id_received; + + if (unacked_server_ids_.size() > 0 && + unacked_server_ids_.size() % kUnackedMessageBeforeStreamAck == 0) { + SendMessage(MCSMessage(kIqStanzaTag, + BuildStreamAck(). + PassAs<const google::protobuf::MessageLite>()), + false); + } + + switch (tag) { + case kLoginResponseTag: { + mcs_proto::LoginResponse* login_response = + reinterpret_cast<mcs_proto::LoginResponse*>(protobuf.get()); + DVLOG(1) << "Received login response:"; + DVLOG(1) << " Id: " << login_response->id(); + DVLOG(1) << " Timestamp: " << login_response->server_timestamp(); + if (login_response->has_error()) { + state_ = UNINITIALIZED; + DVLOG(1) << " Error code: " << login_response->error().code(); + DVLOG(1) << " Error message: " << login_response->error().message(); + initialization_callback_.Run(false, 0, 0); + return; + } + + state_ = CONNECTED; + stream_id_in_ = 1; // To account for the login response. + DCHECK_EQ(1U, stream_id_out_); + + // Pass the login response on up. + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(message_received_callback_, + MCSMessage(tag, + protobuf.PassAs< + const google::protobuf::MessageLite>()))); + + // If there are pending messages, attempt to send one. + if (!to_send_.empty()) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&MCSClient::MaybeSendMessage, + weak_ptr_factory_.GetWeakPtr())); + } + + heartbeat_timer_.Start(FROM_HERE, + heartbeat_interval_, + base::Bind(&MCSClient::SendHeartbeat, + weak_ptr_factory_.GetWeakPtr())); + return; + } + case kHeartbeatPingTag: + DCHECK_GE(stream_id_in_, 1U); + DVLOG(1) << "Received heartbeat ping, sending ack."; + SendMessage( + MCSMessage(kHeartbeatAckTag, mcs_proto::HeartbeatAck()), false); + return; + case kHeartbeatAckTag: + DCHECK_GE(stream_id_in_, 1U); + DVLOG(1) << "Received heartbeat ack."; + // TODO(zea): add logic to reconnect if no ack received within a certain + // timeout (with backoff). + return; + case kCloseTag: + LOG(ERROR) << "Received close command, closing connection."; + state_ = UNINITIALIZED; + initialization_callback_.Run(false, 0, 0); + // TODO(zea): should this happen in non-error cases? Reconnect? + return; + case kIqStanzaTag: { + DCHECK_GE(stream_id_in_, 1U); + mcs_proto::IqStanza* iq_stanza = + reinterpret_cast<mcs_proto::IqStanza*>(protobuf.get()); + const mcs_proto::Extension& iq_extension = iq_stanza->extension(); + switch (iq_extension.id()) { + case kSelectiveAck: { + PersistentIdList acked_ids; + if (BuildPersistentIdListFromProto(iq_extension.data(), + &acked_ids)) { + HandleSelectiveAck(acked_ids); + } + return; + } + case kStreamAck: + // Do nothing. The last received stream id is always processed if it's + // present. + return; + default: + LOG(WARNING) << "Received invalid iq stanza extension " + << iq_extension.id(); + return; + } + } + case kDataMessageStanzaTag: { + DCHECK_GE(stream_id_in_, 1U); + mcs_proto::DataMessageStanza* data_message = + reinterpret_cast<mcs_proto::DataMessageStanza*>(protobuf.get()); + if (data_message->category() == kMCSCategory) { + HandleMCSDataMesssage(protobuf.Pass()); + return; + } + + DCHECK(protobuf.get()); + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(message_received_callback_, + MCSMessage(tag, + protobuf.PassAs< + const google::protobuf::MessageLite>()))); + return; + } + default: + LOG(ERROR) << "Received unexpected message of type " + << static_cast<int>(tag); + return; + } +} + +void MCSClient::HandleStreamAck(StreamId last_stream_id_received) { + PersistentIdList acked_outgoing_persistent_ids; + StreamIdList acked_outgoing_stream_ids; + while (!to_resend_.empty() && + to_resend_.front()->stream_id <= last_stream_id_received) { + const MCSPacketInternal& outgoing_packet = to_resend_.front(); + acked_outgoing_persistent_ids.push_back(outgoing_packet->persistent_id); + acked_outgoing_stream_ids.push_back(outgoing_packet->stream_id); + to_resend_.pop_front(); + } + + DVLOG(1) << "Server acked " << acked_outgoing_persistent_ids.size() + << " outgoing messages, " << to_resend_.size() + << " remaining unacked"; + rmq_store_.RemoveOutgoingMessages(acked_outgoing_persistent_ids, + base::Bind(&MCSClient::OnRMQUpdateFinished, + weak_ptr_factory_.GetWeakPtr())); + + HandleServerConfirmedReceipt(last_stream_id_received); +} + +void MCSClient::HandleSelectiveAck(const PersistentIdList& id_list) { + // First check the to_resend_ queue. Acknowledgments should always happen + // in the order they were sent, so if messages are present they should match + // the acknowledge list. + PersistentIdList::const_iterator iter = id_list.begin(); + for (; iter != id_list.end() && !to_resend_.empty(); ++iter) { + const MCSPacketInternal& outgoing_packet = to_resend_.front(); + DCHECK_EQ(outgoing_packet->persistent_id, *iter); + + // No need to re-acknowledge any server messages this message already + // acknowledged. + StreamId device_stream_id = outgoing_packet->stream_id; + HandleServerConfirmedReceipt(device_stream_id); + + to_resend_.pop_front(); + } + + // If the acknowledged ids aren't all there, they might be in the to_send_ + // queue (typically when a StreamAck confirms messages as part of a login + // response). + for (; iter != id_list.end() && !to_send_.empty(); ++iter) { + const MCSPacketInternal& outgoing_packet = to_send_.front(); + DCHECK_EQ(outgoing_packet->persistent_id, *iter); + + // No need to re-acknowledge any server messages this message already + // acknowledged. + StreamId device_stream_id = outgoing_packet->stream_id; + HandleServerConfirmedReceipt(device_stream_id); + + to_send_.pop_front(); + } + + DCHECK(iter == id_list.end()); + + DVLOG(1) << "Server acked " << id_list.size() + << " messages, " << to_resend_.size() << " remaining unacked."; + rmq_store_.RemoveOutgoingMessages(id_list, + base::Bind(&MCSClient::OnRMQUpdateFinished, + weak_ptr_factory_.GetWeakPtr())); + + // Resend any remaining outgoing messages, as they were not received by the + // server. + DVLOG(1) << "Resending " << to_resend_.size() << " messages."; + while (!to_resend_.empty()) { + to_send_.push_front(to_resend_.back()); + to_resend_.pop_back(); + } +} + +void MCSClient::HandleServerConfirmedReceipt(StreamId device_stream_id) { + // TODO(zea): use a message id the sender understands. + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(message_sent_callback_, + "Message " + base::UintToString(device_stream_id) + " sent.")); + + PersistentIdList acked_incoming_ids; + for (std::map<StreamId, PersistentIdList>::iterator iter = + acked_server_ids_.begin(); + iter != acked_server_ids_.end() && + iter->first <= device_stream_id;) { + acked_incoming_ids.insert(acked_incoming_ids.end(), + iter->second.begin(), + iter->second.end()); + acked_server_ids_.erase(iter++); + } + + DVLOG(1) << "Server confirmed receipt of " << acked_incoming_ids.size() + << " acknowledged server messages."; + rmq_store_.RemoveIncomingMessages(acked_incoming_ids, + base::Bind(&MCSClient::OnRMQUpdateFinished, + weak_ptr_factory_.GetWeakPtr())); +} + +MCSClient::PersistentId MCSClient::GetNextPersistentId() { + return base::Uint64ToString(base::TimeTicks::Now().ToInternalValue()); +} + +} // namespace gcm diff --git a/google_apis/gcm/engine/mcs_client.h b/google_apis/gcm/engine/mcs_client.h new file mode 100644 index 0000000..4de62cb --- /dev/null +++ b/google_apis/gcm/engine/mcs_client.h @@ -0,0 +1,231 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef GOOGLE_APIS_GCM_ENGINE_MCS_CLIENT_H_ +#define GOOGLE_APIS_GCM_ENGINE_MCS_CLIENT_H_ + +#include <deque> +#include <map> +#include <string> +#include <vector> + +#include "base/files/file_path.h" +#include "base/memory/linked_ptr.h" +#include "base/memory/weak_ptr.h" +#include "base/timer/timer.h" +#include "google_apis/gcm/base/gcm_export.h" +#include "google_apis/gcm/base/mcs_message.h" +#include "google_apis/gcm/engine/connection_handler.h" +#include "google_apis/gcm/engine/rmq_store.h" + +namespace google { +namespace protobuf { +class MessageLite; +} // namespace protobuf +} // namespace google + +namespace mcs_proto { +class LoginRequest; +} + +namespace gcm { + +class ConnectionFactory; +struct ReliablePacketInfo; + +// An MCS client. This client is in charge of all communications with an +// MCS endpoint, and is capable of reliably sending/receiving GCM messages. +// NOTE: Not thread safe. This class should live on the same thread as that +// network requests are performed on. +class GCM_EXPORT MCSClient { + public: + enum State { + UNINITIALIZED, // Uninitialized. + LOADING, // Waiting for RMQ load to finish. + LOADED, // RMQ Load finished, waiting to connect. + CONNECTING, // Connection in progress. + CONNECTED, // Connected and running. + }; + + // Callback for informing MCSClient status. It is valid for this to be + // invoked more than once if a permanent error is encountered after a + // successful login was initiated. + typedef base::Callback< + void(bool success, + uint64 restored_android_id, + uint64 restored_security_token)> InitializationCompleteCallback; + // Callback when a message is received. + typedef base::Callback<void(const MCSMessage& message)> + OnMessageReceivedCallback; + // Callback when a message is sent (and receipt has been acknowledged by + // the MCS endpoint). + // TODO(zea): pass some sort of structure containing more details about + // send failures. + typedef base::Callback<void(const std::string& message_id)> + OnMessageSentCallback; + + MCSClient(const base::FilePath& rmq_path, + ConnectionFactory* connection_factory, + scoped_refptr<base::SequencedTaskRunner> blocking_task_runner); + virtual ~MCSClient(); + + // Initialize the client. Will load any previous id/token information as well + // as unacknowledged message information from the RMQ storage, if it exists, + // passing the id/token information back via |initialization_callback| along + // with a |success == true| result. If no RMQ information is present (and + // this is therefore a fresh client), a clean RMQ store will be created and + // values of 0 will be returned via |initialization_callback| with + // |success == true|. + /// If an error loading the RMQ store is encountered, + // |initialization_callback| will be invoked with |success == false|. + void Initialize(const InitializationCompleteCallback& initialization_callback, + const OnMessageReceivedCallback& message_received_callback, + const OnMessageSentCallback& message_sent_callback); + + // Logs the client into the server. Client must be initialized. + // |android_id| and |security_token| are optional if this is not a new + // client, else they must be non-zero. + // Successful login will result in |message_received_callback| being invoked + // with a valid LoginResponse. + // Login failure (typically invalid id/token) will shut down the client, and + // |initialization_callback| to be invoked with |success = false|. + void Login(uint64 android_id, uint64 security_token); + + // Sends a message, with or without reliable message queueing (RMQ) support. + // Will asynchronously invoke the OnMessageSent callback regardless. + // TODO(zea): support TTL. + void SendMessage(const MCSMessage& message, bool use_rmq); + + // Disconnects the client and permanently destroys the persistent RMQ store. + // WARNING: This is permanent, and the client must be recreated with new + // credentials afterwards. + void Destroy(); + + // Returns the current state of the client. + State state() const { return state_; } + + private: + typedef uint32 StreamId; + typedef std::string PersistentId; + typedef std::vector<StreamId> StreamIdList; + typedef std::vector<PersistentId> PersistentIdList; + typedef std::map<StreamId, PersistentId> StreamIdToPersistentIdMap; + typedef linked_ptr<ReliablePacketInfo> MCSPacketInternal; + + // Resets the internal state and builds a new login request, acknowledging + // any pending server-to-device messages and rebuilding the send queue + // from all unacknowledged device-to-server messages. + // Should only be called when the connection has been reset. + void ResetStateAndBuildLoginRequest(mcs_proto::LoginRequest* request); + + // Send a heartbeat to the MCS server. + void SendHeartbeat(); + + // RMQ Store callbacks. + void OnRMQLoadFinished(const RMQStore::LoadResult& result); + void OnRMQUpdateFinished(bool success); + + // Attempt to send a message. + void MaybeSendMessage(); + + // Helper for sending a protobuf along with any unacknowledged ids to the + // wire. + void SendPacketToWire(ReliablePacketInfo* packet_info); + + // Handle a data message sent to the MCS client system from the MCS server. + void HandleMCSDataMesssage( + scoped_ptr<google::protobuf::MessageLite> protobuf); + + // Handle a packet received over the wire. + void HandlePacketFromWire(scoped_ptr<google::protobuf::MessageLite> protobuf); + + // ReliableMessageQueue acknowledgment helpers. + // Handle a StreamAck sent by the server confirming receipt of all + // messages up to the message with stream id |last_stream_id_received|. + void HandleStreamAck(StreamId last_stream_id_received_); + // Handle a SelectiveAck sent by the server confirming all messages + // in |id_list|. + void HandleSelectiveAck(const PersistentIdList& id_list); + // Handle server confirmation of a device message, including device's + // acknowledgment of receipt of messages. + void HandleServerConfirmedReceipt(StreamId device_stream_id); + + // Generates a new persistent id for messages. + // Virtual for testing. + virtual PersistentId GetNextPersistentId(); + + // Client state. + State state_; + + // Callbacks for owner. + InitializationCompleteCallback initialization_callback_; + OnMessageReceivedCallback message_received_callback_; + OnMessageSentCallback message_sent_callback_; + + // The android id and security token in use by this device. + uint64 android_id_; + uint64 security_token_; + + // Factory for creating new connections and connection handlers. + ConnectionFactory* connection_factory_; + + // Connection handler to handle all over-the-wire protocol communication + // with the mobile connection server. + ConnectionHandler* connection_handler_; + + // ----- Reliablie Message Queue section ----- + // Note: all queues/maps are ordered from oldest (front/begin) message to + // most recent (back/end). + + // Send/acknowledge queues. + std::deque<MCSPacketInternal> to_send_; + std::deque<MCSPacketInternal> to_resend_; + + // Last device_to_server stream id acknowledged by the server. + StreamId last_device_to_server_stream_id_received_; + // Last server_to_device stream id acknowledged by this device. + StreamId last_server_to_device_stream_id_received_; + // The stream id for the last sent message. A new message should consume + // stream_id_out_ + 1. + StreamId stream_id_out_; + // The stream id of the last received message. The LoginResponse will always + // have a stream id of 1, and stream ids increment by 1 for each received + // message. + StreamId stream_id_in_; + + // The server messages that have not been acked by the device yet. Keyed by + // server stream id. + StreamIdToPersistentIdMap unacked_server_ids_; + + // Those server messages that have been acked. They must remain tracked + // until the ack message is itself confirmed. The list of all message ids + // acknowledged are keyed off the device stream id of the message that + // acknowledged them. + std::map<StreamId, PersistentIdList> acked_server_ids_; + + // Those server messages from a previous connection that were not fully + // acknowledged. They do not have associated stream ids, and will be + // acknowledged on the next login attempt. + PersistentIdList restored_unackeds_server_ids_; + + // The reliable message queue persistent store. + RMQStore rmq_store_; + + // ----- Heartbeats ----- + // The current heartbeat interval. + base::TimeDelta heartbeat_interval_; + // Timer for triggering heartbeats. + base::Timer heartbeat_timer_; + + // The task runner for blocking tasks (i.e. persisting RMQ state to disk). + scoped_refptr<base::SequencedTaskRunner> blocking_task_runner_; + + base::WeakPtrFactory<MCSClient> weak_ptr_factory_; + + DISALLOW_COPY_AND_ASSIGN(MCSClient); +}; + +} // namespace gcm + +#endif // GOOGLE_APIS_GCM_ENGINE_MCS_CLIENT_H_ diff --git a/google_apis/gcm/engine/mcs_client_unittest.cc b/google_apis/gcm/engine/mcs_client_unittest.cc new file mode 100644 index 0000000..6ef1405 --- /dev/null +++ b/google_apis/gcm/engine/mcs_client_unittest.cc @@ -0,0 +1,540 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "google_apis/gcm/engine/mcs_client.h" + +#include "base/files/scoped_temp_dir.h" +#include "base/message_loop/message_loop.h" +#include "base/run_loop.h" +#include "base/strings/string_number_conversions.h" +#include "components/webdata/encryptor/encryptor.h" +#include "google_apis/gcm/base/mcs_util.h" +#include "google_apis/gcm/engine/fake_connection_factory.h" +#include "google_apis/gcm/engine/fake_connection_handler.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace gcm { + +namespace { + +const uint64 kAndroidId = 54321; +const uint64 kSecurityToken = 12345; + +// Number of messages to send when testing batching. +// Note: must be even for tests that split batches in half. +const int kMessageBatchSize = 6; + +// The number of unacked messages the client will receive before sending a +// stream ack. +// TODO(zea): get this (and other constants) directly from the mcs client. +const int kAckLimitSize = 10; + +// Helper for building arbitrary data messages. +MCSMessage BuildDataMessage(const std::string& from, + const std::string& category, + int last_stream_id_received, + const std::string persistent_id) { + mcs_proto::DataMessageStanza data_message; + data_message.set_from(from); + data_message.set_category(category); + data_message.set_last_stream_id_received(last_stream_id_received); + if (!persistent_id.empty()) + data_message.set_persistent_id(persistent_id); + return MCSMessage(kDataMessageStanzaTag, data_message); +} + +// MCSClient with overriden exposed persistent id logic. +class TestMCSClient : public MCSClient { + public: + TestMCSClient(const base::FilePath& rmq_path, + ConnectionFactory* connection_factory, + scoped_refptr<base::SequencedTaskRunner> blocking_task_runner) + : MCSClient(rmq_path, connection_factory, blocking_task_runner), + next_id_(0) { + } + + virtual std::string GetNextPersistentId() OVERRIDE { + return base::UintToString(++next_id_); + } + + private: + uint32 next_id_; +}; + +class MCSClientTest : public testing::Test { + public: + MCSClientTest(); + virtual ~MCSClientTest(); + + void BuildMCSClient(); + void InitializeClient(); + void LoginClient(const std::vector<std::string>& acknowledged_ids); + + TestMCSClient* mcs_client() const { return mcs_client_.get(); } + FakeConnectionFactory* connection_factory() { + return &connection_factory_; + } + bool init_success() const { return init_success_; } + uint64 restored_android_id() const { return restored_android_id_; } + uint64 restored_security_token() const { return restored_security_token_; } + MCSMessage* received_message() const { return received_message_.get(); } + std::string sent_message_id() const { return sent_message_id_;} + + FakeConnectionHandler* GetFakeHandler() const; + + void WaitForMCSEvent(); + void PumpLoop(); + + private: + void InitializationCallback(bool success, + uint64 restored_android_id, + uint64 restored_security_token); + void MessageReceivedCallback(const MCSMessage& message); + void MessageSentCallback(const std::string& message_id); + + base::ScopedTempDir temp_directory_; + base::MessageLoop message_loop_; + scoped_ptr<base::RunLoop> run_loop_; + + FakeConnectionFactory connection_factory_; + scoped_ptr<TestMCSClient> mcs_client_; + bool init_success_; + uint64 restored_android_id_; + uint64 restored_security_token_; + scoped_ptr<MCSMessage> received_message_; + std::string sent_message_id_; +}; + +MCSClientTest::MCSClientTest() + : run_loop_(new base::RunLoop()), + init_success_(false), + restored_android_id_(0), + restored_security_token_(0) { + EXPECT_TRUE(temp_directory_.CreateUniqueTempDir()); + run_loop_.reset(new base::RunLoop()); + + // On OSX, prevent the Keychain permissions popup during unit tests. +#if defined(OS_MACOSX) + Encryptor::UseMockKeychain(true); +#endif +} + +MCSClientTest::~MCSClientTest() {} + +void MCSClientTest::BuildMCSClient() { + mcs_client_.reset( + new TestMCSClient(temp_directory_.path(), + &connection_factory_, + message_loop_.message_loop_proxy())); +} + +void MCSClientTest::InitializeClient() { + mcs_client_->Initialize(base::Bind(&MCSClientTest::InitializationCallback, + base::Unretained(this)), + base::Bind(&MCSClientTest::MessageReceivedCallback, + base::Unretained(this)), + base::Bind(&MCSClientTest::MessageSentCallback, + base::Unretained(this))); + run_loop_->Run(); + run_loop_.reset(new base::RunLoop()); +} + +void MCSClientTest::LoginClient( + const std::vector<std::string>& acknowledged_ids) { + scoped_ptr<mcs_proto::LoginRequest> login_request = + BuildLoginRequest(kAndroidId, kSecurityToken); + for (size_t i = 0; i < acknowledged_ids.size(); ++i) + login_request->add_received_persistent_id(acknowledged_ids[i]); + GetFakeHandler()->ExpectOutgoingMessage( + MCSMessage(kLoginRequestTag, + login_request.PassAs<const google::protobuf::MessageLite>())); + mcs_client_->Login(kAndroidId, kSecurityToken); + run_loop_->Run(); + run_loop_.reset(new base::RunLoop()); +} + +FakeConnectionHandler* MCSClientTest::GetFakeHandler() const { + return reinterpret_cast<FakeConnectionHandler*>( + connection_factory_.GetConnectionHandler()); +} + +void MCSClientTest::WaitForMCSEvent() { + run_loop_->Run(); + run_loop_.reset(new base::RunLoop()); +} + +void MCSClientTest::PumpLoop() { + run_loop_->RunUntilIdle(); + run_loop_.reset(new base::RunLoop()); +} + +void MCSClientTest::InitializationCallback(bool success, + uint64 restored_android_id, + uint64 restored_security_token) { + init_success_ = success; + restored_android_id_ = restored_android_id; + restored_security_token_ = restored_security_token; + DVLOG(1) << "Initialization callback invoked, killing loop."; + run_loop_->Quit(); +} + +void MCSClientTest::MessageReceivedCallback(const MCSMessage& message) { + received_message_.reset(new MCSMessage(message)); + DVLOG(1) << "Message received callback invoked, killing loop."; + run_loop_->Quit(); +} + +void MCSClientTest::MessageSentCallback(const std::string& message_id) { + DVLOG(1) << "Message sent callback invoked, killing loop."; + run_loop_->Quit(); +} + +// Initialize a new client. +TEST_F(MCSClientTest, InitializeNew) { + BuildMCSClient(); + InitializeClient(); + EXPECT_EQ(0U, restored_android_id()); + EXPECT_EQ(0U, restored_security_token()); + EXPECT_TRUE(init_success()); +} + +// Initialize a new client, shut it down, then restart the client. Should +// reload the existing device credentials. +TEST_F(MCSClientTest, InitializeExisting) { + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + + // Rebuild the client, to reload from the RMQ. + BuildMCSClient(); + InitializeClient(); + EXPECT_EQ(kAndroidId, restored_android_id()); + EXPECT_EQ(kSecurityToken, restored_security_token()); + EXPECT_TRUE(init_success()); +} + +// Log in successfully to the MCS endpoint. +TEST_F(MCSClientTest, LoginSuccess) { + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + EXPECT_TRUE(connection_factory()->IsEndpointReachable()); + EXPECT_TRUE(init_success()); + ASSERT_TRUE(received_message()); + EXPECT_EQ(kLoginResponseTag, received_message()->tag()); +} + +// Encounter a server error during the login attempt. +TEST_F(MCSClientTest, FailLogin) { + BuildMCSClient(); + InitializeClient(); + GetFakeHandler()->set_fail_login(true); + LoginClient(std::vector<std::string>()); + EXPECT_FALSE(connection_factory()->IsEndpointReachable()); + EXPECT_FALSE(init_success()); + EXPECT_FALSE(received_message()); +} + +// Send a message without RMQ support. +TEST_F(MCSClientTest, SendMessageNoRMQ) { + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + MCSMessage message(BuildDataMessage("from", "category", 1, "")); + GetFakeHandler()->ExpectOutgoingMessage(message); + mcs_client()->SendMessage(message, false); + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); +} + +// Send a message with RMQ support. +TEST_F(MCSClientTest, SendMessageRMQ) { + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + MCSMessage message(BuildDataMessage("from", "category", 1, "1")); + GetFakeHandler()->ExpectOutgoingMessage(message); + mcs_client()->SendMessage(message, true); + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); +} + +// Send a message with RMQ support while disconnected. On reconnect, the message +// should be resent. +TEST_F(MCSClientTest, SendMessageRMQWhileDisconnected) { + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + GetFakeHandler()->set_fail_send(true); + MCSMessage message(BuildDataMessage("from", "category", 1, "1")); + + // The initial (failed) send. + GetFakeHandler()->ExpectOutgoingMessage(message); + // The login request. + GetFakeHandler()->ExpectOutgoingMessage( + MCSMessage(kLoginRequestTag, + BuildLoginRequest(kAndroidId, kSecurityToken). + PassAs<const google::protobuf::MessageLite>())); + // The second (re)send. + GetFakeHandler()->ExpectOutgoingMessage(message); + mcs_client()->SendMessage(message, true); + EXPECT_FALSE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); + GetFakeHandler()->set_fail_send(false); + connection_factory()->Connect(); + WaitForMCSEvent(); // Wait for the login to finish. + PumpLoop(); // Wait for the send to happen. + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); +} + +// Send a message with RMQ support without receiving an acknowledgement. On +// restart the message should be resent. +TEST_F(MCSClientTest, SendMessageRMQOnRestart) { + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + GetFakeHandler()->set_fail_send(true); + MCSMessage message(BuildDataMessage("from", "category", 1, "1")); + + // The initial (failed) send. + GetFakeHandler()->ExpectOutgoingMessage(message); + GetFakeHandler()->set_fail_send(false); + mcs_client()->SendMessage(message, true); + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); + + // Rebuild the client, which should resend the old message. + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + GetFakeHandler()->ExpectOutgoingMessage(message); + PumpLoop(); + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); +} + +// Send messages with RMQ support, followed by receiving a stream ack. On +// restart nothing should be recent. +TEST_F(MCSClientTest, SendMessageRMQWithStreamAck) { + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + + // Send some messages. + for (int i = 1; i <= kMessageBatchSize; ++i) { + MCSMessage message( + BuildDataMessage("from", "category", 1, base::IntToString(i))); + GetFakeHandler()->ExpectOutgoingMessage(message); + mcs_client()->SendMessage(message, true); + } + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); + + // Receive the ack. + scoped_ptr<mcs_proto::IqStanza> ack = BuildStreamAck(); + ack->set_last_stream_id_received(kMessageBatchSize + 1); + GetFakeHandler()->ReceiveMessage( + MCSMessage(kIqStanzaTag, + ack.PassAs<const google::protobuf::MessageLite>())); + WaitForMCSEvent(); + + // Reconnect and ensure no messages are resent. + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + PumpLoop(); +} + +// Send messages with RMQ support. On restart, receive a SelectiveAck with +// the login response. No messages should be resent. +TEST_F(MCSClientTest, SendMessageRMQAckOnReconnect) { + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + + // Send some messages. + std::vector<std::string> id_list; + for (int i = 1; i <= kMessageBatchSize; ++i) { + id_list.push_back(base::IntToString(i)); + MCSMessage message( + BuildDataMessage("from", "category", 1, id_list.back())); + GetFakeHandler()->ExpectOutgoingMessage(message); + mcs_client()->SendMessage(message, true); + } + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); + + // Rebuild the client, and receive an acknowledgment for the messages as + // part of the login response. + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + scoped_ptr<mcs_proto::IqStanza> ack(BuildSelectiveAck(id_list)); + GetFakeHandler()->ReceiveMessage( + MCSMessage(kIqStanzaTag, + ack.PassAs<const google::protobuf::MessageLite>())); + WaitForMCSEvent(); + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); +} + +// Send messages with RMQ support. On restart, receive a SelectiveAck with +// the login response that only acks some messages. The unacked messages should +// be resent. +TEST_F(MCSClientTest, SendMessageRMQPartialAckOnReconnect) { + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + + // Send some messages. + std::vector<std::string> id_list; + for (int i = 1; i <= kMessageBatchSize; ++i) { + id_list.push_back(base::IntToString(i)); + MCSMessage message( + BuildDataMessage("from", "category", 1, id_list.back())); + GetFakeHandler()->ExpectOutgoingMessage(message); + mcs_client()->SendMessage(message, true); + } + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); + + // Rebuild the client, and receive an acknowledgment for the messages as + // part of the login response. + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + + std::vector<std::string> acked_ids, remaining_ids; + acked_ids.insert(acked_ids.end(), + id_list.begin(), + id_list.begin() + kMessageBatchSize / 2); + remaining_ids.insert(remaining_ids.end(), + id_list.begin() + kMessageBatchSize / 2, + id_list.end()); + for (int i = 1; i <= kMessageBatchSize / 2; ++i) { + MCSMessage message( + BuildDataMessage("from", + "category", + 2, + remaining_ids[i - 1])); + GetFakeHandler()->ExpectOutgoingMessage(message); + } + scoped_ptr<mcs_proto::IqStanza> ack(BuildSelectiveAck(acked_ids)); + GetFakeHandler()->ReceiveMessage( + MCSMessage(kIqStanzaTag, + ack.PassAs<const google::protobuf::MessageLite>())); + WaitForMCSEvent(); + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); +} + +// Receive some messages. On restart, the login request should contain the +// appropriate acknowledged ids. +TEST_F(MCSClientTest, AckOnLogin) { + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + + // Receive some messages. + std::vector<std::string> id_list; + for (int i = 1; i <= kMessageBatchSize; ++i) { + id_list.push_back(base::IntToString(i)); + MCSMessage message( + BuildDataMessage("from", "category", i, id_list.back())); + GetFakeHandler()->ReceiveMessage(message); + WaitForMCSEvent(); + PumpLoop(); + } + + // Restart the client. + BuildMCSClient(); + InitializeClient(); + LoginClient(id_list); +} + +// Receive some messages. On the next send, the outgoing message should contain +// the appropriate last stream id received field to ack the received messages. +TEST_F(MCSClientTest, AckOnSend) { + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + + // Receive some messages. + std::vector<std::string> id_list; + for (int i = 1; i <= kMessageBatchSize; ++i) { + id_list.push_back(base::IntToString(i)); + MCSMessage message( + BuildDataMessage("from", "category", i, id_list.back())); + GetFakeHandler()->ReceiveMessage(message); + WaitForMCSEvent(); + PumpLoop(); + } + + // Trigger a message send, which should acknowledge via stream ack. + MCSMessage message( + BuildDataMessage("from", "category", kMessageBatchSize + 1, "1")); + GetFakeHandler()->ExpectOutgoingMessage(message); + mcs_client()->SendMessage(message, true); + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); +} + +// Receive the ack limit in messages, which should trigger an automatic +// stream ack. Receive a heartbeat to confirm the ack. +TEST_F(MCSClientTest, AckWhenLimitReachedWithHeartbeat) { + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + + // The stream ack. + scoped_ptr<mcs_proto::IqStanza> ack = BuildStreamAck(); + ack->set_last_stream_id_received(kAckLimitSize + 1); + GetFakeHandler()->ExpectOutgoingMessage( + MCSMessage(kIqStanzaTag, + ack.PassAs<const google::protobuf::MessageLite>())); + + // Receive some messages. + std::vector<std::string> id_list; + for (int i = 1; i <= kAckLimitSize; ++i) { + id_list.push_back(base::IntToString(i)); + MCSMessage message( + BuildDataMessage("from", "category", i, id_list.back())); + GetFakeHandler()->ReceiveMessage(message); + WaitForMCSEvent(); + PumpLoop(); + } + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); + + // Receive a heartbeat confirming the ack (and receive the heartbeat ack). + scoped_ptr<mcs_proto::HeartbeatPing> heartbeat( + new mcs_proto::HeartbeatPing()); + heartbeat->set_last_stream_id_received(2); + + scoped_ptr<mcs_proto::HeartbeatAck> heartbeat_ack( + new mcs_proto::HeartbeatAck()); + heartbeat_ack->set_last_stream_id_received(kAckLimitSize + 2); + GetFakeHandler()->ExpectOutgoingMessage( + MCSMessage(kHeartbeatAckTag, + heartbeat_ack.PassAs<const google::protobuf::MessageLite>())); + + GetFakeHandler()->ReceiveMessage( + MCSMessage(kHeartbeatPingTag, + heartbeat.PassAs<const google::protobuf::MessageLite>())); + WaitForMCSEvent(); + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); + + // Rebuild the client. Nothing should be sent on login. + BuildMCSClient(); + InitializeClient(); + LoginClient(std::vector<std::string>()); + EXPECT_TRUE(GetFakeHandler()-> + AllOutgoingMessagesReceived()); +} + +} // namespace + +} // namespace gcm diff --git a/google_apis/gcm/gcm.gyp b/google_apis/gcm/gcm.gyp index 833f2c8..f81c4ef 100644 --- a/google_apis/gcm/gcm.gyp +++ b/google_apis/gcm/gcm.gyp @@ -52,6 +52,8 @@ 'engine/connection_handler.cc', 'engine/connection_handler_impl.h', 'engine/connection_handler_impl.cc', + 'engine/mcs_client.h', + 'engine/mcs_client.cc', 'engine/rmq_store.h', 'engine/rmq_store.cc', 'gcm_client.cc', @@ -65,6 +67,26 @@ ], }, + # A standalone MCS (mobile connection server) client. + { + 'target_name': 'mcs_probe', + 'type': 'executable', + 'variables': { 'enable_wexit_time_destructors': 1, }, + 'include_dirs': [ + '../..', + ], + 'dependencies': [ + '../../base/base.gyp:base', + '../../net/net.gyp:net', + '../../net/net.gyp:net_test_support', + '../../third_party/protobuf/protobuf.gyp:protobuf_lite', + 'gcm' + ], + 'sources': [ + 'tools/mcs_probe.cc', + ], + }, + # The main GCM unit tests. { 'target_name': 'gcm_unit_tests', @@ -73,20 +95,30 @@ 'include_dirs': [ '../..', ], + 'export_dependent_settings': [ + '../../third_party/protobuf/protobuf.gyp:protobuf_lite' + ], 'dependencies': [ '../../base/base.gyp:run_all_unittests', '../../base/base.gyp:base', '../../components/components.gyp:encryptor', + '../../net/net.gyp:net', '../../net/net.gyp:net_test_support', '../../testing/gtest.gyp:gtest', '../../third_party/protobuf/protobuf.gyp:protobuf_lite', 'gcm' ], 'sources': [ + 'base/mcs_message_unittest.cc', 'base/mcs_util_unittest.cc', 'base/socket_stream_unittest.cc', 'engine/connection_factory_impl_unittest.cc', 'engine/connection_handler_impl_unittest.cc', + 'engine/fake_connection_factory.h', + 'engine/fake_connection_factory.cc', + 'engine/fake_connection_handler.h', + 'engine/fake_connection_handler.cc', + 'engine/mcs_client_unittest.cc', 'engine/rmq_store_unittest.cc', ] }, diff --git a/google_apis/gcm/tools/mcs_probe.cc b/google_apis/gcm/tools/mcs_probe.cc new file mode 100644 index 0000000..bc4ad7c --- /dev/null +++ b/google_apis/gcm/tools/mcs_probe.cc @@ -0,0 +1,372 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// A standalone tool for testing MCS connections and the MCS client on their +// own. + +#include <cstddef> +#include <cstdio> +#include <string> + +#include "base/at_exit.h" +#include "base/command_line.h" +#include "base/compiler_specific.h" +#include "base/logging.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/run_loop.h" +#include "base/strings/string_number_conversions.h" +#include "base/threading/thread.h" +#include "base/threading/worker_pool.h" +#include "base/values.h" +#include "google_apis/gcm/base/mcs_message.h" +#include "google_apis/gcm/base/mcs_util.h" +#include "google_apis/gcm/engine/connection_factory_impl.h" +#include "google_apis/gcm/engine/mcs_client.h" +#include "net/base/host_mapping_rules.h" +#include "net/base/net_log_logger.h" +#include "net/cert/cert_verifier.h" +#include "net/dns/host_resolver.h" +#include "net/http/http_auth_handler_factory.h" +#include "net/http/http_network_session.h" +#include "net/http/http_server_properties_impl.h" +#include "net/http/transport_security_state.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/ssl_client_socket.h" +#include "net/ssl/default_server_bound_cert_store.h" +#include "net/ssl/server_bound_cert_service.h" +#include "net/url_request/url_request_test_util.h" + +#if defined(OS_MACOSX) +#include "base/mac/scoped_nsautorelease_pool.h" +#endif + +// This is a simple utility that initializes an mcs client and +// prints out any events. +namespace gcm { +namespace { + +// The default server to communicate with. +const char kMCSServerHost[] = "mtalk.google.com"; +const uint16 kMCSServerPort = 5228; + +// Command line switches. +const char kRMQFileName[] = "rmq_file"; +const char kAndroidIdSwitch[] = "android_id"; +const char kSecretSwitch[] = "secret"; +const char kLogFileSwitch[] = "log-file"; +const char kIgnoreCertSwitch[] = "ignore-certs"; +const char kServerHostSwitch[] = "host"; +const char kServerPortSwitch[] = "port"; + +void MessageReceivedCallback(const MCSMessage& message) { + LOG(INFO) << "Received message with id " + << GetPersistentId(message.GetProtobuf()) << " and tag " + << static_cast<int>(message.tag()); + + if (message.tag() == kDataMessageStanzaTag) { + const mcs_proto::DataMessageStanza& data_message = + reinterpret_cast<const mcs_proto::DataMessageStanza&>( + message.GetProtobuf()); + DVLOG(1) << " to: " << data_message.to(); + DVLOG(1) << " from: " << data_message.from(); + DVLOG(1) << " category: " << data_message.category(); + DVLOG(1) << " sent: " << data_message.sent(); + for (int i = 0; i < data_message.app_data_size(); ++i) { + DVLOG(1) << " App data " << i << " " + << data_message.app_data(i).key() << " : " + << data_message.app_data(i).value(); + } + } +} + +void MessageSentCallback(const std::string& local_id) { + LOG(INFO) << "Message sent. Status: " << local_id; +} + +// Needed to use a real host resolver. +class MyTestURLRequestContext : public net::TestURLRequestContext { + public: + MyTestURLRequestContext() : TestURLRequestContext(true) { + context_storage_.set_host_resolver( + net::HostResolver::CreateDefaultResolver(NULL)); + context_storage_.set_transport_security_state( + new net::TransportSecurityState()); + Init(); + } + + virtual ~MyTestURLRequestContext() {} +}; + +class MyTestURLRequestContextGetter : public net::TestURLRequestContextGetter { + public: + explicit MyTestURLRequestContextGetter( + const scoped_refptr<base::MessageLoopProxy>& io_message_loop_proxy) + : TestURLRequestContextGetter(io_message_loop_proxy) {} + + virtual net::TestURLRequestContext* GetURLRequestContext() OVERRIDE { + // Construct |context_| lazily so it gets constructed on the right + // thread (the IO thread). + if (!context_) + context_.reset(new MyTestURLRequestContext()); + return context_.get(); + } + + private: + virtual ~MyTestURLRequestContextGetter() {} + + scoped_ptr<MyTestURLRequestContext> context_; +}; + +// A net log that logs all events by default. +class MyTestNetLog : public net::NetLog { + public: + MyTestNetLog() { + SetBaseLogLevel(LOG_ALL); + } + virtual ~MyTestNetLog() {} +}; + +// A cert verifier that access all certificates. +class MyTestCertVerifier : public net::CertVerifier { + public: + MyTestCertVerifier() {} + virtual ~MyTestCertVerifier() {} + + virtual int Verify(net::X509Certificate* cert, + const std::string& hostname, + int flags, + net::CRLSet* crl_set, + net::CertVerifyResult* verify_result, + const net::CompletionCallback& callback, + RequestHandle* out_req, + const net::BoundNetLog& net_log) OVERRIDE { + return net::OK; + } + + virtual void CancelRequest(RequestHandle req) OVERRIDE { + // Do nothing. + } +}; + +class MCSProbe { + public: + MCSProbe( + const CommandLine& command_line, + scoped_refptr<net::URLRequestContextGetter> url_request_context_getter); + ~MCSProbe(); + + void Start(); + + uint64 android_id() const { return android_id_; } + uint64 secret() const { return secret_; } + + private: + void InitializeNetworkState(); + void BuildNetworkSession(); + + void InitializationCallback(bool success, + uint64 restored_android_id, + uint64 restored_security_token); + + CommandLine command_line_; + + base::FilePath rmq_path_; + uint64 android_id_; + uint64 secret_; + std::string server_host_; + int server_port_; + + // Network state. + scoped_refptr<net::URLRequestContextGetter> url_request_context_getter_; + MyTestNetLog net_log_; + scoped_ptr<net::NetLogLogger> logger_; + scoped_ptr<base::Value> net_constants_; + scoped_ptr<net::HostResolver> host_resolver_; + scoped_ptr<net::CertVerifier> cert_verifier_; + scoped_ptr<net::ServerBoundCertService> system_server_bound_cert_service_; + scoped_ptr<net::TransportSecurityState> transport_security_state_; + scoped_ptr<net::URLSecurityManager> url_security_manager_; + scoped_ptr<net::HttpAuthHandlerFactory> http_auth_handler_factory_; + scoped_ptr<net::HttpServerPropertiesImpl> http_server_properties_; + scoped_ptr<net::HostMappingRules> host_mapping_rules_; + scoped_refptr<net::HttpNetworkSession> network_session_; + scoped_ptr<net::ProxyService> proxy_service_; + + scoped_ptr<MCSClient> mcs_client_; + + scoped_ptr<ConnectionFactoryImpl> connection_factory_; + + base::Thread file_thread_; + + scoped_ptr<base::RunLoop> run_loop_; +}; + +MCSProbe::MCSProbe( + const CommandLine& command_line, + scoped_refptr<net::URLRequestContextGetter> url_request_context_getter) + : command_line_(command_line), + rmq_path_(base::FilePath(FILE_PATH_LITERAL("gcm_rmq_store"))), + android_id_(0), + secret_(0), + server_port_(0), + url_request_context_getter_(url_request_context_getter), + file_thread_("FileThread") { + if (command_line.HasSwitch(kRMQFileName)) { + rmq_path_ = command_line.GetSwitchValuePath(kRMQFileName); + } + if (command_line.HasSwitch(kAndroidIdSwitch)) { + base::StringToUint64(command_line.GetSwitchValueASCII(kAndroidIdSwitch), + &android_id_); + } + if (command_line.HasSwitch(kSecretSwitch)) { + base::StringToUint64(command_line.GetSwitchValueASCII(kSecretSwitch), + &secret_); + } + server_host_ = kMCSServerHost; + if (command_line.HasSwitch(kServerHostSwitch)) { + server_host_ = command_line.GetSwitchValueASCII(kServerHostSwitch); + } + server_port_ = kMCSServerPort; + if (command_line.HasSwitch(kServerPortSwitch)) { + base::StringToInt(command_line.GetSwitchValueASCII(kServerPortSwitch), + &server_port_); + } +} + +MCSProbe::~MCSProbe() { + file_thread_.Stop(); +} + +void MCSProbe::Start() { + file_thread_.Start(); + InitializeNetworkState(); + BuildNetworkSession(); + connection_factory_.reset( + new ConnectionFactoryImpl(GURL("https://" + net::HostPortPair( + server_host_, server_port_).ToString()), + network_session_, + &net_log_)); + mcs_client_.reset(new MCSClient(rmq_path_, + connection_factory_.get(), + file_thread_.message_loop_proxy())); + run_loop_.reset(new base::RunLoop()); + mcs_client_->Initialize(base::Bind(&MCSProbe::InitializationCallback, + base::Unretained(this)), + base::Bind(&MessageReceivedCallback), + base::Bind(&MessageSentCallback)); + run_loop_->Run(); +} + +void MCSProbe::InitializeNetworkState() { + FILE* log_file = NULL; + if (command_line_.HasSwitch(kLogFileSwitch)) { + base::FilePath log_path = command_line_.GetSwitchValuePath(kLogFileSwitch); +#if defined(OS_WIN) + log_file = _wfopen(log_path.value().c_str(), L"w"); +#elif defined(OS_POSIX) + log_file = fopen(log_path.value().c_str(), "w"); +#endif + } + net_constants_.reset(net::NetLogLogger::GetConstants()); + if (log_file != NULL) { + logger_.reset(new net::NetLogLogger(log_file, *net_constants_)); + logger_->StartObserving(&net_log_); + } + + host_resolver_ = net::HostResolver::CreateDefaultResolver(&net_log_); + + if (command_line_.HasSwitch(kIgnoreCertSwitch)) { + cert_verifier_.reset(new MyTestCertVerifier()); + } else { + cert_verifier_.reset(net::CertVerifier::CreateDefault()); + } + system_server_bound_cert_service_.reset( + new net::ServerBoundCertService( + new net::DefaultServerBoundCertStore(NULL), + base::WorkerPool::GetTaskRunner(true))); + + transport_security_state_.reset(new net::TransportSecurityState()); + url_security_manager_.reset(net::URLSecurityManager::Create(NULL, NULL)); + http_auth_handler_factory_.reset( + net::HttpAuthHandlerRegistryFactory::Create( + std::vector<std::string>(1, "basic"), + url_security_manager_.get(), + host_resolver_.get(), + std::string(), + false, + false)); + http_server_properties_.reset(new net::HttpServerPropertiesImpl()); + host_mapping_rules_.reset(new net::HostMappingRules()); + proxy_service_.reset(net::ProxyService::CreateDirectWithNetLog(&net_log_)); +} + +void MCSProbe::BuildNetworkSession() { + net::HttpNetworkSession::Params session_params; + session_params.host_resolver = host_resolver_.get(); + session_params.cert_verifier = cert_verifier_.get(); + session_params.server_bound_cert_service = + system_server_bound_cert_service_.get(); + session_params.transport_security_state = transport_security_state_.get(); + session_params.ssl_config_service = new net::SSLConfigServiceDefaults(); + session_params.http_auth_handler_factory = http_auth_handler_factory_.get(); + session_params.http_server_properties = + http_server_properties_->GetWeakPtr(); + session_params.network_delegate = NULL; // TODO(zea): implement? + session_params.host_mapping_rules = host_mapping_rules_.get(); + session_params.ignore_certificate_errors = true; + session_params.http_pipelining_enabled = false; + session_params.testing_fixed_http_port = 0; + session_params.testing_fixed_https_port = 0; + session_params.net_log = &net_log_; + session_params.proxy_service = proxy_service_.get(); + + network_session_ = new net::HttpNetworkSession(session_params); +} + +void MCSProbe::InitializationCallback(bool success, + uint64 restored_android_id, + uint64 restored_security_token) { + LOG(INFO) << "Initialization " << (success ? "success!" : "failure!"); + if (restored_android_id && restored_security_token) { + android_id_ = restored_android_id; + secret_ = restored_security_token; + } + if (success) + mcs_client_->Login(android_id_, secret_); +} + +int MCSProbeMain(int argc, char* argv[]) { + base::AtExitManager exit_manager; + + CommandLine::Init(argc, argv); + logging::LoggingSettings settings; + settings.logging_dest = logging::LOG_TO_SYSTEM_DEBUG_LOG; + logging::InitLogging(settings); + + base::MessageLoopForIO message_loop; + + // For check-in and creating registration ids. + const scoped_refptr<MyTestURLRequestContextGetter> context_getter = + new MyTestURLRequestContextGetter( + base::MessageLoop::current()->message_loop_proxy()); + + const CommandLine& command_line = *CommandLine::ForCurrentProcess(); + + MCSProbe mcs_probe(command_line, context_getter); + mcs_probe.Start(); + + base::RunLoop run_loop; + run_loop.Run(); + + return 0; +} + +} // namespace +} // namespace gcm + +int main(int argc, char* argv[]) { + return gcm::MCSProbeMain(argc, argv); +} |