diff options
-rw-r--r-- | google_apis/gcm/engine/connection_factory.cc | 12 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_factory.h | 56 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_factory_impl.cc | 200 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_factory_impl.h | 99 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_factory_impl_unittest.cc | 295 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_handler.cc | 388 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_handler.h | 118 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_handler_impl.cc | 404 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_handler_impl.h | 122 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_handler_impl_unittest.cc (renamed from google_apis/gcm/engine/connection_handler_unittest.cc) | 84 | ||||
-rw-r--r-- | google_apis/gcm/gcm.gyp | 12 | ||||
-rw-r--r-- | net/socket/client_socket_pool_manager.cc | 25 | ||||
-rw-r--r-- | net/socket/client_socket_pool_manager.h | 15 |
13 files changed, 1300 insertions, 530 deletions
diff --git a/google_apis/gcm/engine/connection_factory.cc b/google_apis/gcm/engine/connection_factory.cc new file mode 100644 index 0000000..016e1e2 --- /dev/null +++ b/google_apis/gcm/engine/connection_factory.cc @@ -0,0 +1,12 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "google_apis/gcm/engine/connection_factory.h" + +namespace gcm { + +ConnectionFactory::ConnectionFactory() {} +ConnectionFactory::~ConnectionFactory() {} + +} // namespace gcm diff --git a/google_apis/gcm/engine/connection_factory.h b/google_apis/gcm/engine/connection_factory.h new file mode 100644 index 0000000..598c211 --- /dev/null +++ b/google_apis/gcm/engine/connection_factory.h @@ -0,0 +1,56 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef GOOGLE_APIS_GCM_ENGINE_CONNECTION_FACTORY_H_ +#define GOOGLE_APIS_GCM_ENGINE_CONNECTION_FACTORY_H_ + +#include "base/time/time.h" +#include "google_apis/gcm/base/gcm_export.h" +#include "google_apis/gcm/engine/connection_handler.h" + +namespace mcs_proto { +class LoginRequest; +} + +namespace gcm { + +// Factory for creating a ConnectionHandler and maintaining its connection. +// The factory retains ownership of the ConnectionHandler and will enforce +// backoff policies when attempting connections. +class GCM_EXPORT ConnectionFactory { + public: + ConnectionFactory(); + virtual ~ConnectionFactory(); + + // Create a new uninitialized connection handler. Should only be called once. + // The factory will retain ownership of the connection handler. + // |read_callback| will be invoked with the contents of any received protobuf + // message. + // |write_callback| will be invoked anytime a message has been successfully + // sent. Note: this just means the data was sent to the wire, not that the + // other end received it. + virtual ConnectionHandler* BuildConnectionHandler( + const ConnectionHandler::ProtoReceivedCallback& read_callback, + const ConnectionHandler::ProtoSentCallback& write_callback) = 0; + + // Opens a new connection for use by the locally owned connection handler + // (created via BuildConnectionHandler), and initiates login handshake using + // |login_request|. Upon completion of the handshake, |read_callback| + // will be invoked with a valid mcs_proto::LoginResponse. + // Note: BuildConnectionHandler must have already been invoked. + virtual void Connect(const mcs_proto::LoginRequest& login_request) = 0; + + // Whether or not the MCS endpoint is currently reachable with an active + // connection. + virtual bool IsEndpointReachable() const = 0; + + // If in backoff, the time at which the next retry will be made. Otherwise, + // a null time, indicating either no attempt to connect has been made or no + // backoff is in progress. + virtual base::TimeTicks NextRetryAttempt() const = 0; +}; + +} // namespace gcm + +#endif // GOOGLE_APIS_GCM_ENGINE_CONNECTION_FACTORY_H_ diff --git a/google_apis/gcm/engine/connection_factory_impl.cc b/google_apis/gcm/engine/connection_factory_impl.cc new file mode 100644 index 0000000..0a87acc --- /dev/null +++ b/google_apis/gcm/engine/connection_factory_impl.cc @@ -0,0 +1,200 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "google_apis/gcm/engine/connection_factory_impl.h" + +#include "base/message_loop/message_loop.h" +#include "google_apis/gcm/engine/connection_handler_impl.h" +#include "google_apis/gcm/protocol/mcs.pb.h" +#include "net/base/net_errors.h" +#include "net/http/http_network_session.h" +#include "net/http/http_request_headers.h" +#include "net/proxy/proxy_info.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/ssl/ssl_config_service.h" + +namespace gcm { + +namespace { + +// The amount of time a Socket read should wait before timing out. +const int kReadTimeoutMs = 30000; // 30 seconds. + +// Backoff policy. +const net::BackoffEntry::Policy kConnectionBackoffPolicy = { + // Number of initial errors (in sequence) to ignore before applying + // exponential back-off rules. + 0, + + // Initial delay for exponential back-off in ms. + 10000, // 10 seconds. + + // Factor by which the waiting time will be multiplied. + 2, + + // Fuzzing percentage. ex: 10% will spread requests randomly + // between 90%-100% of the calculated time. + 0.2, // 20%. + + // Maximum amount of time we are willing to delay our request in ms. + 1000 * 3600 * 4, // 4 hours. + + // Time to keep an entry from being discarded even when it + // has no significant state, -1 to never discard. + -1, + + // Don't use initial delay unless the last request was an error. + false, +}; + +} // namespace + +ConnectionFactoryImpl::ConnectionFactoryImpl( + const GURL& mcs_endpoint, + scoped_refptr<net::HttpNetworkSession> network_session, + net::NetLog* net_log) + : mcs_endpoint_(mcs_endpoint), + network_session_(network_session), + net_log_(net_log), + weak_ptr_factory_(this) { +} + +ConnectionFactoryImpl::~ConnectionFactoryImpl() { +} + +ConnectionHandler* ConnectionFactoryImpl::BuildConnectionHandler( + const ConnectionHandler::ProtoReceivedCallback& read_callback, + const ConnectionHandler::ProtoSentCallback& write_callback) { + DCHECK(!connection_handler_); + + backoff_entry_ = CreateBackoffEntry(&kConnectionBackoffPolicy); + + net::NetworkChangeNotifier::AddIPAddressObserver(this); + net::NetworkChangeNotifier::AddConnectionTypeObserver(this); + connection_handler_.reset( + new ConnectionHandlerImpl( + base::TimeDelta::FromMilliseconds(kReadTimeoutMs), + read_callback, + write_callback, + base::Bind(&ConnectionFactoryImpl::ConnectionHandlerCallback, + weak_ptr_factory_.GetWeakPtr()))); + return connection_handler_.get(); +} + +void ConnectionFactoryImpl::Connect( + const mcs_proto::LoginRequest& login_request) { + DCHECK(connection_handler_); + DCHECK(!IsEndpointReachable()); + + if (login_request.IsInitialized()) { + DCHECK(!login_request_.IsInitialized()); + login_request_ = login_request; + } + + if (backoff_entry_->ShouldRejectRequest()) { + DVLOG(1) << "Delaying MCS endpoint connection for " + << backoff_entry_->GetTimeUntilRelease().InMilliseconds() + << " milliseconds."; + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(&ConnectionFactoryImpl::Connect, + weak_ptr_factory_.GetWeakPtr(), + login_request_), + NextRetryAttempt() - base::TimeTicks::Now()); + return; + } + + DVLOG(1) << "Attempting connection to MCS endpoint."; + ConnectImpl(); +} + +bool ConnectionFactoryImpl::IsEndpointReachable() const { + return connection_handler_ && connection_handler_->CanSendMessage(); +} + +base::TimeTicks ConnectionFactoryImpl::NextRetryAttempt() const { + if (!backoff_entry_) + return base::TimeTicks(); + return backoff_entry_->GetReleaseTime(); +} + +void ConnectionFactoryImpl::OnConnectionTypeChanged( + net::NetworkChangeNotifier::ConnectionType type) { + if (type == net::NetworkChangeNotifier::CONNECTION_NONE) + return; + + // TODO(zea): implement different backoff/retry policies based on connection + // type. + DVLOG(1) << "Connection type changed to " << type << ", resetting backoff."; + backoff_entry_->Reset(); + // Connect(..) should be retrying with backoff already if a connection is + // necessary, so no need to call again. +} + +void ConnectionFactoryImpl::OnIPAddressChanged() { + DVLOG(1) << "IP Address changed, resetting backoff."; + backoff_entry_->Reset(); + // Connect(..) should be retrying with backoff already if a connection is + // necessary, so no need to call again. +} + +void ConnectionFactoryImpl::ConnectImpl() { + DCHECK(!IsEndpointReachable()); + + // TODO(zea): resolve proxies. + net::ProxyInfo proxy_info; + proxy_info.UseDirect(); + net::SSLConfig ssl_config; + network_session_->ssl_config_service()->GetSSLConfig(&ssl_config); + + int status = net::InitSocketHandleForTlsConnect( + net::HostPortPair::FromURL(mcs_endpoint_), + network_session_.get(), + proxy_info, + ssl_config, + ssl_config, + net::kPrivacyModeDisabled, + net::BoundNetLog::Make(net_log_, net::NetLog::SOURCE_SOCKET), + &socket_handle_, + base::Bind(&ConnectionFactoryImpl::OnConnectDone, + weak_ptr_factory_.GetWeakPtr())); + if (status != net::ERR_IO_PENDING) + OnConnectDone(status); +} + +void ConnectionFactoryImpl::InitHandler() { + connection_handler_->Init(login_request_, socket_handle_.PassSocket()); +} + +scoped_ptr<net::BackoffEntry> ConnectionFactoryImpl::CreateBackoffEntry( + const net::BackoffEntry::Policy* const policy) { + return scoped_ptr<net::BackoffEntry>(new net::BackoffEntry(policy)); +} + +void ConnectionFactoryImpl::OnConnectDone(int result) { + if (result != net::OK) { + LOG(ERROR) << "Failed to connect to MCS endpoint with error " << result; + backoff_entry_->InformOfRequest(false); + Connect(mcs_proto::LoginRequest()); + return; + } + + DVLOG(1) << "MCS endpoint connection success."; + + // Reset the backoff. + backoff_entry_->Reset(); + + InitHandler(); +} + +void ConnectionFactoryImpl::ConnectionHandlerCallback(int result) { + // TODO(zea): Consider how to handle errors that may require some sort of + // user intervention (login page, etc.). + LOG(ERROR) << "Connection reset with error " << result; + backoff_entry_->InformOfRequest(false); + Connect(mcs_proto::LoginRequest()); +} + +} // namespace gcm diff --git a/google_apis/gcm/engine/connection_factory_impl.h b/google_apis/gcm/engine/connection_factory_impl.h new file mode 100644 index 0000000..0e40521 --- /dev/null +++ b/google_apis/gcm/engine/connection_factory_impl.h @@ -0,0 +1,99 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef GOOGLE_APIS_GCM_ENGINE_CONNECTION_FACTORY_IMPL_H_ +#define GOOGLE_APIS_GCM_ENGINE_CONNECTION_FACTORY_IMPL_H_ + +#include "google_apis/gcm/engine/connection_factory.h" + +#include "base/memory/weak_ptr.h" +#include "google_apis/gcm/protocol/mcs.pb.h" +#include "net/base/backoff_entry.h" +#include "net/base/network_change_notifier.h" +#include "net/socket/client_socket_handle.h" +#include "url/gurl.h" + +namespace net { +class HttpNetworkSession; +class NetLog; +} + +namespace gcm { + +class ConnectionHandlerImpl; + +class GCM_EXPORT ConnectionFactoryImpl : + public ConnectionFactory, + public net::NetworkChangeNotifier::ConnectionTypeObserver, + public net::NetworkChangeNotifier::IPAddressObserver { + public: + ConnectionFactoryImpl( + const GURL& mcs_endpoint, + scoped_refptr<net::HttpNetworkSession> network_session, + net::NetLog* net_log); + virtual ~ConnectionFactoryImpl(); + + // ConnectionFactory implementation. + virtual ConnectionHandler* BuildConnectionHandler( + const ConnectionHandler::ProtoReceivedCallback& read_callback, + const ConnectionHandler::ProtoSentCallback& write_callback) OVERRIDE; + virtual void Connect(const mcs_proto::LoginRequest& login_request) OVERRIDE; + virtual bool IsEndpointReachable() const OVERRIDE; + virtual base::TimeTicks NextRetryAttempt() const OVERRIDE; + + // NetworkChangeNotifier observer implementations. + virtual void OnConnectionTypeChanged( + net::NetworkChangeNotifier::ConnectionType type) OVERRIDE; + virtual void OnIPAddressChanged() OVERRIDE; + + protected: + // Implementation of Connect(..). If not in backoff, uses |login_request_| + // in attempting a connection/handshake. On connection/handshake failure, goes + // into backoff. + // Virtual for testing. + virtual void ConnectImpl(); + + // Helper method for initalizing the connection hander. + // Virtual for testing. + virtual void InitHandler(); + + // Helper method for creating a backoff entry. + // Virtual for testing. + virtual scoped_ptr<net::BackoffEntry> CreateBackoffEntry( + const net::BackoffEntry::Policy* const policy); + + // Callback for Socket connection completion. + void OnConnectDone(int result); + + private: + // ConnectionHandler callback for connection issues. + void ConnectionHandlerCallback(int result); + + // The MCS endpoint to make connections to. + const GURL mcs_endpoint_; + + // ---- net:: components for establishing connections. ---- + // Network session for creating new connections. + const scoped_refptr<net::HttpNetworkSession> network_session_; + // Net log to use in connection attempts. + net::NetLog* const net_log_; + // The handle to the socket for the current connection, if one exists. + net::ClientSocketHandle socket_handle_; + // Connection attempt backoff policy. + scoped_ptr<net::BackoffEntry> backoff_entry_; + + // The current connection handler, if one exists. + scoped_ptr<ConnectionHandlerImpl> connection_handler_; + + // The current login request if a connection attempt is in progress/pending. + mcs_proto::LoginRequest login_request_; + + base::WeakPtrFactory<ConnectionFactoryImpl> weak_ptr_factory_; + + DISALLOW_COPY_AND_ASSIGN(ConnectionFactoryImpl); +}; + +} // namespace gcm + +#endif // GOOGLE_APIS_GCM_ENGINE_CONNECTION_FACTORY_IMPL_H_ diff --git a/google_apis/gcm/engine/connection_factory_impl_unittest.cc b/google_apis/gcm/engine/connection_factory_impl_unittest.cc new file mode 100644 index 0000000..40adcf2 --- /dev/null +++ b/google_apis/gcm/engine/connection_factory_impl_unittest.cc @@ -0,0 +1,295 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "google_apis/gcm/engine/connection_factory_impl.h" + +#include <cmath> + +#include "base/message_loop/message_loop.h" +#include "base/run_loop.h" +#include "base/time/time.h" +#include "net/base/backoff_entry.h" +#include "net/http/http_network_session.h" +#include "testing/gtest/include/gtest/gtest.h" + +class Policy; + +namespace gcm { +namespace { + +const char kMCSEndpoint[] = "http://my.server"; + +const int kBackoffDelayMs = 1; +const int kBackoffMultiplier = 2; + +// A backoff policy with small enough delays that tests aren't burdened. +const net::BackoffEntry::Policy kTestBackoffPolicy = { + // Number of initial errors (in sequence) to ignore before applying + // exponential back-off rules. + 0, + + // Initial delay for exponential back-off in ms. + kBackoffDelayMs, + + // Factor by which the waiting time will be multiplied. + kBackoffMultiplier, + + // Fuzzing percentage. ex: 10% will spread requests randomly + // between 90%-100% of the calculated time. + 0, + + // Maximum amount of time we are willing to delay our request in ms. + 10, + + // Time to keep an entry from being discarded even when it + // has no significant state, -1 to never discard. + -1, + + // Don't use initial delay unless the last request was an error. + false, +}; + +// Helper for calculating total expected exponential backoff delay given an +// arbitrary number of failed attempts. See BackoffEntry::CalculateReleaseTime. +double CalculateBackoff(int num_attempts) { + double delay = kBackoffDelayMs; + for (int i = 1; i < num_attempts; ++i) { + delay += kBackoffDelayMs * pow(static_cast<double>(kBackoffMultiplier), + i - 1); + } + DVLOG(1) << "Expected backoff " << delay << " milliseconds."; + return delay; +} + +// Helper methods that should never actually be called due to real connections +// being stubbed out. +void ReadContinuation( + scoped_ptr<google::protobuf::MessageLite> message) { + ADD_FAILURE(); +} + +void WriteContinuation() { + ADD_FAILURE(); +} + +// A connection factory that stubs out network requests and overrides the +// backoff policy. +class TestConnectionFactoryImpl : public ConnectionFactoryImpl { + public: + TestConnectionFactoryImpl(const base::Closure& finished_callback); + virtual ~TestConnectionFactoryImpl(); + + // Overridden stubs. + virtual void ConnectImpl() OVERRIDE; + virtual void InitHandler() OVERRIDE; + virtual scoped_ptr<net::BackoffEntry> CreateBackoffEntry( + const net::BackoffEntry::Policy* const policy) OVERRIDE; + + // Helpers for verifying connection attempts are made. Connection results + // must be consumed. + void SetConnectResult(int connect_result); + void SetMultipleConnectResults(int connect_result, int num_expected_attempts); + + private: + // The result to return on the next connect attempt. + int connect_result_; + // The number of expected connection attempts; + int num_expected_attempts_; + // Whether all expected connection attempts have been fulfilled since an + // expectation was last set. + bool connections_fulfilled_; + // Callback to invoke when all connection attempts have been made. + base::Closure finished_callback_; +}; + +TestConnectionFactoryImpl::TestConnectionFactoryImpl( + const base::Closure& finished_callback) + : ConnectionFactoryImpl(GURL(kMCSEndpoint), NULL, NULL), + connect_result_(net::ERR_UNEXPECTED), + num_expected_attempts_(0), + connections_fulfilled_(true), + finished_callback_(finished_callback) { +} + +TestConnectionFactoryImpl::~TestConnectionFactoryImpl() { + EXPECT_EQ(0, num_expected_attempts_); +} + +void TestConnectionFactoryImpl::ConnectImpl() { + ASSERT_GT(num_expected_attempts_, 0); + + OnConnectDone(connect_result_); + --num_expected_attempts_; + if (num_expected_attempts_ == 0) { + connect_result_ = net::ERR_UNEXPECTED; + connections_fulfilled_ = true; + finished_callback_.Run(); + } +} + +void TestConnectionFactoryImpl::InitHandler() { + EXPECT_NE(connect_result_, net::ERR_UNEXPECTED); +} + +scoped_ptr<net::BackoffEntry> TestConnectionFactoryImpl::CreateBackoffEntry( + const net::BackoffEntry::Policy* const policy) { + return scoped_ptr<net::BackoffEntry>( + new net::BackoffEntry(&kTestBackoffPolicy)); +} + +void TestConnectionFactoryImpl::SetConnectResult(int connect_result) { + DCHECK_NE(connect_result, net::ERR_UNEXPECTED); + ASSERT_EQ(0, num_expected_attempts_); + connections_fulfilled_ = false; + connect_result_ = connect_result; + num_expected_attempts_ = 1; +} + +void TestConnectionFactoryImpl::SetMultipleConnectResults( + int connect_result, + int num_expected_attempts) { + DCHECK_NE(connect_result, net::ERR_UNEXPECTED); + DCHECK_GT(num_expected_attempts, 0); + ASSERT_EQ(0, num_expected_attempts_); + connections_fulfilled_ = false; + connect_result_ = connect_result; + num_expected_attempts_ = num_expected_attempts; +} + +class ConnectionFactoryImplTest : public testing::Test { + public: + ConnectionFactoryImplTest(); + virtual ~ConnectionFactoryImplTest(); + + TestConnectionFactoryImpl* factory() { return &factory_; } + + void WaitForConnections(); + + private: + void ConnectionsComplete(); + + TestConnectionFactoryImpl factory_; + base::MessageLoop message_loop_; + scoped_ptr<base::RunLoop> run_loop_; +}; + +ConnectionFactoryImplTest::ConnectionFactoryImplTest() + : factory_(base::Bind(&ConnectionFactoryImplTest::ConnectionsComplete, + base::Unretained(this))), + run_loop_(new base::RunLoop()) {} +ConnectionFactoryImplTest::~ConnectionFactoryImplTest() {} + +void ConnectionFactoryImplTest::WaitForConnections() { + run_loop_->Run(); + run_loop_.reset(new base::RunLoop()); +} + +void ConnectionFactoryImplTest::ConnectionsComplete() { + if (!run_loop_) + return; + run_loop_->Quit(); +} + +// Verify building a connection handler works. +TEST_F(ConnectionFactoryImplTest, BuildConnectionHandler) { + EXPECT_FALSE(factory()->IsEndpointReachable()); + ConnectionHandler* handler = factory()->BuildConnectionHandler( + base::Bind(&ReadContinuation), + base::Bind(&WriteContinuation)); + ASSERT_TRUE(handler); + EXPECT_FALSE(factory()->IsEndpointReachable()); +} + +// An initial successful connection should not result in backoff. +TEST_F(ConnectionFactoryImplTest, ConnectSuccess) { + factory()->BuildConnectionHandler( + ConnectionHandler::ProtoReceivedCallback(), + ConnectionHandler::ProtoSentCallback()); + factory()->SetConnectResult(net::OK); + factory()->Connect(mcs_proto::LoginRequest()); + EXPECT_TRUE(factory()->NextRetryAttempt().is_null()); +} + +// A connection failure should result in backoff. +TEST_F(ConnectionFactoryImplTest, ConnectFail) { + factory()->BuildConnectionHandler( + ConnectionHandler::ProtoReceivedCallback(), + ConnectionHandler::ProtoSentCallback()); + factory()->SetConnectResult(net::ERR_CONNECTION_FAILED); + factory()->Connect(mcs_proto::LoginRequest()); + EXPECT_FALSE(factory()->NextRetryAttempt().is_null()); +} + +// A connection success after a failure should reset backoff. +TEST_F(ConnectionFactoryImplTest, FailThenSucceed) { + factory()->BuildConnectionHandler( + ConnectionHandler::ProtoReceivedCallback(), + ConnectionHandler::ProtoSentCallback()); + factory()->SetConnectResult(net::ERR_CONNECTION_FAILED); + base::TimeTicks connect_time = base::TimeTicks::Now(); + factory()->Connect(mcs_proto::LoginRequest()); + WaitForConnections(); + base::TimeTicks retry_time = factory()->NextRetryAttempt(); + EXPECT_FALSE(retry_time.is_null()); + EXPECT_GE((retry_time - connect_time).InMilliseconds(), CalculateBackoff(1)); + factory()->SetConnectResult(net::OK); + WaitForConnections(); + EXPECT_TRUE(factory()->NextRetryAttempt().is_null()); +} + +// Multiple connection failures should retry with an exponentially increasing +// backoff, then reset on success. +TEST_F(ConnectionFactoryImplTest, MultipleFailuresThenSucceed) { + factory()->BuildConnectionHandler( + ConnectionHandler::ProtoReceivedCallback(), + ConnectionHandler::ProtoSentCallback()); + + const int kNumAttempts = 5; + factory()->SetMultipleConnectResults(net::ERR_CONNECTION_FAILED, + kNumAttempts); + + base::TimeTicks connect_time = base::TimeTicks::Now(); + factory()->Connect(mcs_proto::LoginRequest()); + WaitForConnections(); + base::TimeTicks retry_time = factory()->NextRetryAttempt(); + EXPECT_FALSE(retry_time.is_null()); + EXPECT_GE((retry_time - connect_time).InMilliseconds(), + CalculateBackoff(kNumAttempts)); + + factory()->SetConnectResult(net::OK); + WaitForConnections(); + EXPECT_TRUE(factory()->NextRetryAttempt().is_null()); +} + +// IP events should reset backoff. +TEST_F(ConnectionFactoryImplTest, FailThenIPEvent) { + factory()->BuildConnectionHandler( + ConnectionHandler::ProtoReceivedCallback(), + ConnectionHandler::ProtoSentCallback()); + factory()->SetConnectResult(net::ERR_CONNECTION_FAILED); + factory()->Connect(mcs_proto::LoginRequest()); + WaitForConnections(); + EXPECT_FALSE(factory()->NextRetryAttempt().is_null()); + + factory()->OnIPAddressChanged(); + EXPECT_TRUE(factory()->NextRetryAttempt().is_null()); +} + +// Connection type events should reset backoff. +TEST_F(ConnectionFactoryImplTest, FailThenConnectionTypeEvent) { + factory()->BuildConnectionHandler( + ConnectionHandler::ProtoReceivedCallback(), + ConnectionHandler::ProtoSentCallback()); + factory()->SetConnectResult(net::ERR_CONNECTION_FAILED); + factory()->Connect(mcs_proto::LoginRequest()); + WaitForConnections(); + EXPECT_FALSE(factory()->NextRetryAttempt().is_null()); + + factory()->OnConnectionTypeChanged( + net::NetworkChangeNotifier::CONNECTION_WIFI); + EXPECT_TRUE(factory()->NextRetryAttempt().is_null()); +} + +} // namespace +} // namespace gcm diff --git a/google_apis/gcm/engine/connection_handler.cc b/google_apis/gcm/engine/connection_handler.cc index b4eb602..bc9b658 100644 --- a/google_apis/gcm/engine/connection_handler.cc +++ b/google_apis/gcm/engine/connection_handler.cc @@ -4,398 +4,12 @@ #include "google_apis/gcm/engine/connection_handler.h" -#include "base/message_loop/message_loop.h" -#include "google/protobuf/io/coded_stream.h" -#include "google_apis/gcm/base/mcs_util.h" -#include "google_apis/gcm/base/socket_stream.h" -#include "net/base/net_errors.h" -#include "net/socket/stream_socket.h" - -using namespace google::protobuf::io; - namespace gcm { -namespace { - -// # of bytes a MCS version packet consumes. -const int kVersionPacketLen = 1; -// # of bytes a tag packet consumes. -const int kTagPacketLen = 1; -// Max # of bytes a length packet consumes. -const int kSizePacketLenMin = 1; -const int kSizePacketLenMax = 2; - -// The current MCS protocol version. -const int kMCSVersion = 38; - -} // namespace - -ConnectionHandler::ConnectionHandler(base::TimeDelta read_timeout) - : read_timeout_(read_timeout), - handshake_complete_(false), - message_tag_(0), - message_size_(0), - weak_ptr_factory_(this) { +ConnectionHandler::ConnectionHandler() { } ConnectionHandler::~ConnectionHandler() { } -void ConnectionHandler::Init( - scoped_ptr<net::StreamSocket> socket, - const google::protobuf::MessageLite& login_request, - const ProtoReceivedCallback& read_callback, - const ProtoSentCallback& write_callback, - const ConnectionChangedCallback& connection_callback) { - DCHECK(!read_callback.is_null()); - DCHECK(!write_callback.is_null()); - DCHECK(!connection_callback.is_null()); - - // Invalidate any previously outstanding reads. - weak_ptr_factory_.InvalidateWeakPtrs(); - - handshake_complete_ = false; - message_tag_ = 0; - message_size_ = 0; - socket_ = socket.Pass(); - input_stream_.reset(new SocketInputStream(socket_.get())); - output_stream_.reset(new SocketOutputStream(socket_.get())); - read_callback_ = read_callback; - write_callback_ = write_callback; - connection_callback_ = connection_callback; - - Login(login_request); -} - -bool ConnectionHandler::CanSendMessage() const { - return handshake_complete_ && output_stream_.get() && - output_stream_->GetState() == SocketOutputStream::EMPTY; -} - -void ConnectionHandler::SendMessage( - const google::protobuf::MessageLite& message) { - DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); - DCHECK(handshake_complete_); - - { - CodedOutputStream coded_output_stream(output_stream_.get()); - DVLOG(1) << "Writing proto of size " << message.ByteSize(); - int tag = GetMCSProtoTag(message); - DCHECK_NE(tag, -1); - coded_output_stream.WriteRaw(&tag, 1); - coded_output_stream.WriteVarint32(message.ByteSize()); - message.SerializeToCodedStream(&coded_output_stream); - } - - if (output_stream_->Flush( - base::Bind(&ConnectionHandler::OnMessageSent, - weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { - OnMessageSent(); - } -} - -void ConnectionHandler::Login( - const google::protobuf::MessageLite& login_request) { - DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); - - const char version_byte[1] = {kMCSVersion}; - const char login_request_tag[1] = {kLoginRequestTag}; - { - CodedOutputStream coded_output_stream(output_stream_.get()); - coded_output_stream.WriteRaw(version_byte, 1); - coded_output_stream.WriteRaw(login_request_tag, 1); - coded_output_stream.WriteVarint32(login_request.ByteSize()); - login_request.SerializeToCodedStream(&coded_output_stream); - } - - if (output_stream_->Flush( - base::Bind(&ConnectionHandler::OnMessageSent, - weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { - base::MessageLoop::current()->PostTask( - FROM_HERE, - base::Bind(&ConnectionHandler::OnMessageSent, - weak_ptr_factory_.GetWeakPtr())); - } - - read_timeout_timer_.Start(FROM_HERE, - read_timeout_, - base::Bind(&ConnectionHandler::OnTimeout, - weak_ptr_factory_.GetWeakPtr())); - WaitForData(MCS_VERSION_TAG_AND_SIZE); -} - -void ConnectionHandler::OnMessageSent() { - if (!output_stream_.get()) { - // The connection has already been closed. Just return. - DCHECK(!input_stream_.get()); - DCHECK(!read_timeout_timer_.IsRunning()); - return; - } - - if (output_stream_->GetState() != SocketOutputStream::EMPTY) { - int last_error = output_stream_->last_error(); - CloseConnection(); - // If the socket stream had an error, plumb it up, else plumb up FAILED. - if (last_error == net::OK) - last_error = net::ERR_FAILED; - connection_callback_.Run(last_error); - return; - } - - write_callback_.Run(); -} - -void ConnectionHandler::GetNextMessage() { - DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() || - SocketInputStream::READY == input_stream_->GetState()); - message_tag_ = 0; - message_size_ = 0; - - WaitForData(MCS_TAG_AND_SIZE); -} - -void ConnectionHandler::WaitForData(ProcessingState state) { - DVLOG(1) << "Waiting for MCS data: state == " << state; - - if (!input_stream_) { - // The connection has already been closed. Just return. - DCHECK(!output_stream_.get()); - DCHECK(!read_timeout_timer_.IsRunning()); - return; - } - - if (input_stream_->GetState() != SocketInputStream::EMPTY && - input_stream_->GetState() != SocketInputStream::READY) { - // An error occurred. - int last_error = output_stream_->last_error(); - CloseConnection(); - // If the socket stream had an error, plumb it up, else plumb up FAILED. - if (last_error == net::OK) - last_error = net::ERR_FAILED; - connection_callback_.Run(last_error); - return; - } - - // Used to determine whether a Socket::Read is necessary. - int min_bytes_needed = 0; - // Used to limit the size of the Socket::Read. - int max_bytes_needed = 0; - - switch(state) { - case MCS_VERSION_TAG_AND_SIZE: - min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin; - max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax; - break; - case MCS_TAG_AND_SIZE: - min_bytes_needed = kTagPacketLen + kSizePacketLenMin; - max_bytes_needed = kTagPacketLen + kSizePacketLenMax; - break; - case MCS_FULL_SIZE: - // If in this state, the minimum size packet length must already have been - // insufficient, so set both to the max length. - min_bytes_needed = kSizePacketLenMax; - max_bytes_needed = kSizePacketLenMax; - break; - case MCS_PROTO_BYTES: - read_timeout_timer_.Reset(); - // No variability in the message size, set both to the same. - min_bytes_needed = message_size_; - max_bytes_needed = message_size_; - break; - default: - NOTREACHED(); - } - DCHECK_GE(max_bytes_needed, min_bytes_needed); - - int byte_count = input_stream_->UnreadByteCount(); - if (min_bytes_needed - byte_count > 0 && - input_stream_->Refresh( - base::Bind(&ConnectionHandler::WaitForData, - weak_ptr_factory_.GetWeakPtr(), - state), - max_bytes_needed - byte_count) == net::ERR_IO_PENDING) { - return; - } - - // Check for refresh errors. - if (input_stream_->GetState() != SocketInputStream::READY) { - // An error occurred. - int last_error = output_stream_->last_error(); - CloseConnection(); - // If the socket stream had an error, plumb it up, else plumb up FAILED. - if (last_error == net::OK) - last_error = net::ERR_FAILED; - connection_callback_.Run(last_error); - return; - } - - // Received enough bytes, process them. - DVLOG(1) << "Processing MCS data: state == " << state; - switch(state) { - case MCS_VERSION_TAG_AND_SIZE: - OnGotVersion(); - break; - case MCS_TAG_AND_SIZE: - OnGotMessageTag(); - break; - case MCS_FULL_SIZE: - OnGotMessageSize(); - break; - case MCS_PROTO_BYTES: - OnGotMessageBytes(); - break; - default: - NOTREACHED(); - } -} - -void ConnectionHandler::OnGotVersion() { - uint8 version = 0; - { - CodedInputStream coded_input_stream(input_stream_.get()); - coded_input_stream.ReadRaw(&version, 1); - } - if (version < kMCSVersion) { - LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version); - connection_callback_.Run(net::ERR_FAILED); - return; - } - - input_stream_->RebuildBuffer(); - - // Process the LoginResponse message tag. - OnGotMessageTag(); -} - -void ConnectionHandler::OnGotMessageTag() { - if (input_stream_->GetState() != SocketInputStream::READY) { - LOG(ERROR) << "Failed to receive protobuf tag."; - read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); - return; - } - - { - CodedInputStream coded_input_stream(input_stream_.get()); - coded_input_stream.ReadRaw(&message_tag_, 1); - } - - DVLOG(1) << "Received proto of type " - << static_cast<unsigned int>(message_tag_); - - if (!read_timeout_timer_.IsRunning()) { - read_timeout_timer_.Start(FROM_HERE, - read_timeout_, - base::Bind(&ConnectionHandler::OnTimeout, - weak_ptr_factory_.GetWeakPtr())); - } - OnGotMessageSize(); -} - -void ConnectionHandler::OnGotMessageSize() { - if (input_stream_->GetState() != SocketInputStream::READY) { - LOG(ERROR) << "Failed to receive message size."; - read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); - return; - } - - bool need_another_byte = false; - int prev_byte_count = input_stream_->ByteCount(); - { - CodedInputStream coded_input_stream(input_stream_.get()); - if (!coded_input_stream.ReadVarint32(&message_size_)) - need_another_byte = true; - } - - if (need_another_byte) { - DVLOG(1) << "Expecting another message size byte."; - if (prev_byte_count >= kSizePacketLenMax) { - // Already had enough bytes, something else went wrong. - LOG(ERROR) << "Failed to process message size."; - read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); - return; - } - // Back up by the amount read (should always be 1 byte). - int bytes_read = prev_byte_count - input_stream_->ByteCount(); - DCHECK_EQ(bytes_read, 1); - input_stream_->BackUp(bytes_read); - WaitForData(MCS_FULL_SIZE); - return; - } - - DVLOG(1) << "Proto size: " << message_size_; - - if (message_size_ > 0) - WaitForData(MCS_PROTO_BYTES); - else - OnGotMessageBytes(); -} - -void ConnectionHandler::OnGotMessageBytes() { - read_timeout_timer_.Stop(); - scoped_ptr<google::protobuf::MessageLite> protobuf( - BuildProtobufFromTag(message_tag_)); - // Messages with no content are valid; just use the default protobuf for - // that tag. - if (protobuf.get() && message_size_ == 0) { - base::MessageLoop::current()->PostTask( - FROM_HERE, - base::Bind(&ConnectionHandler::GetNextMessage, - weak_ptr_factory_.GetWeakPtr())); - read_callback_.Run(protobuf.Pass()); - return; - } - - if (!protobuf.get() || - input_stream_->GetState() != SocketInputStream::READY) { - LOG(ERROR) << "Failed to extract protobuf bytes of type " - << static_cast<unsigned int>(message_tag_); - protobuf.reset(); // Return a null pointer to denote an error. - read_callback_.Run(protobuf.Pass()); - return; - } - - { - CodedInputStream coded_input_stream(input_stream_.get()); - if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) { - NOTREACHED() << "Unable to parse GCM message of type " - << static_cast<unsigned int>(message_tag_); - protobuf.reset(); // Return a null pointer to denote an error. - read_callback_.Run(protobuf.Pass()); - return; - } - } - - input_stream_->RebuildBuffer(); - base::MessageLoop::current()->PostTask( - FROM_HERE, - base::Bind(&ConnectionHandler::GetNextMessage, - weak_ptr_factory_.GetWeakPtr())); - if (message_tag_ == kLoginResponseTag) { - if (handshake_complete_) { - LOG(ERROR) << "Unexpected login response."; - } else { - handshake_complete_ = true; - DVLOG(1) << "GCM Handshake complete."; - } - } - read_callback_.Run(protobuf.Pass()); -} - -void ConnectionHandler::OnTimeout() { - LOG(ERROR) << "Timed out waiting for GCM Protocol buffer."; - CloseConnection(); - connection_callback_.Run(net::ERR_TIMED_OUT); -} - -void ConnectionHandler::CloseConnection() { - DVLOG(1) << "Closing connection."; - read_callback_.Reset(); - write_callback_.Reset(); - read_timeout_timer_.Stop(); - socket_->Disconnect(); - input_stream_.reset(); - output_stream_.reset(); - weak_ptr_factory_.InvalidateWeakPtrs(); -} - } // namespace gcm diff --git a/google_apis/gcm/engine/connection_handler.h b/google_apis/gcm/engine/connection_handler.h index 6dd838c..5b9ea71 100644 --- a/google_apis/gcm/engine/connection_handler.h +++ b/google_apis/gcm/engine/connection_handler.h @@ -5,13 +5,21 @@ #ifndef GOOGLE_APIS_GCM_ENGINE_CONNECTION_HANDLER_H_ #define GOOGLE_APIS_GCM_ENGINE_CONNECTION_HANDLER_H_ -#include "base/memory/weak_ptr.h" -#include "base/timer/timer.h" +#include "base/callback.h" #include "google_apis/gcm/base/gcm_export.h" -#include "google_apis/gcm/protocol/mcs.pb.h" namespace net{ class StreamSocket; +} // namespace net + +namespace google { +namespace protobuf { +class MessageLite; +} // namespace protobuf +} // namepsace google + +namespace mcs_proto { +class LoginRequest; } namespace gcm { @@ -31,113 +39,23 @@ class GCM_EXPORT ConnectionHandler { typedef base::Closure ProtoSentCallback; typedef base::Callback<void(int)> ConnectionChangedCallback; - explicit ConnectionHandler(base::TimeDelta read_timeout); - ~ConnectionHandler(); + ConnectionHandler(); + virtual ~ConnectionHandler(); // Starts a new MCS connection handshake (using |login_request|) and, upon - // success, begins listening for incoming/outgoing messages. A successful - // handshake is when a mcs_proto::LoginResponse is received, and is signaled - // via the |read_callback|. - // Outputs: - // |read_callback| will be invoked with the contents of any received protobuf - // message. - // |write_callback| will be invoked anytime a message has been successfully - // sent. Note: this just means the data was sent to the wire, not that the - // other end received it. - // |connection_callback| will be invoked with any fatal read/write errors - // encountered. + // success, begins listening for incoming/outgoing messages. // // Note: It is correct and expected to call Init more than once, as connection // issues are encountered and new connections must be made. - void Init(scoped_ptr<net::StreamSocket> socket, - const google::protobuf::MessageLite& login_request, - const ProtoReceivedCallback& read_callback, - const ProtoSentCallback& write_callback, - const ConnectionChangedCallback& connection_callback); + virtual void Init(const mcs_proto::LoginRequest& login_request, + scoped_ptr<net::StreamSocket> socket) = 0; // Checks that a handshake has been completed and a message is not already // in flight. - bool CanSendMessage() const; + virtual bool CanSendMessage() const = 0; // Send an MCS protobuf message. CanSendMessage() must be true. - void SendMessage(const google::protobuf::MessageLite& message); - - private: - // State machine for handling incoming data. See WaitForData(..) for usage. - enum ProcessingState { - // Processing the version, tag, and size packets (assuming minimum length - // size packet). Only used during the login handshake. - MCS_VERSION_TAG_AND_SIZE = 0, - // Processing the tag and size packets (assuming minimum length size - // packet). Used for normal messages. - MCS_TAG_AND_SIZE, - // Processing a maximum length size packet (for messages with length > 128). - // Used when a normal size packet was not sufficient to read the message - // size. - MCS_FULL_SIZE, - // Processing the protocol buffer bytes (for those messages with non-zero - // sizes). - MCS_PROTO_BYTES - }; - - // Sends the protocol version and login request. First step in the MCS - // connection handshake. - void Login(const google::protobuf::MessageLite& login_request); - - // SendMessage continuation. Invoked when Socket::Write completes. - void OnMessageSent(); - - // Starts the message processing process, which is comprised of the tag, - // message size, and bytes packet types. - void GetNextMessage(); - - // Performs any necessary SocketInputStream refreshing until the data - // associated with |packet_type| is fully ready, then calls the appropriate - // OnGot* message to process the packet data. If the read times out, - // will close the stream and invoke the connection callback. - void WaitForData(ProcessingState state); - - // Incoming data helper methods. - void OnGotVersion(); - void OnGotMessageTag(); - void OnGotMessageSize(); - void OnGotMessageBytes(); - - // Timeout handler. - void OnTimeout(); - - // Closes the current connection. - void CloseConnection(); - - // Timeout policy: the timeout is only enforced while waiting on the - // handshake (version and/or LoginResponse) or once at least a tag packet has - // been received. It is reset every time new data is received, and is - // only stopped when a full message is processed. - // TODO(zea): consider enforcing a separate timeout when waiting for - // a message to send. - const base::TimeDelta read_timeout_; - base::OneShotTimer<ConnectionHandler> read_timeout_timer_; - - // This connection's socket and the input/output streams attached to it. - scoped_ptr<net::StreamSocket> socket_; - scoped_ptr<SocketInputStream> input_stream_; - scoped_ptr<SocketOutputStream> output_stream_; - - // Whether the MCS login handshake has successfully completed. See Init(..) - // description for more info on what the handshake involves. - bool handshake_complete_; - - // State for the message currently being processed, if there is one. - uint8 message_tag_; - uint32 message_size_; - - ProtoReceivedCallback read_callback_; - ProtoSentCallback write_callback_; - ConnectionChangedCallback connection_callback_; - - base::WeakPtrFactory<ConnectionHandler> weak_ptr_factory_; - - DISALLOW_COPY_AND_ASSIGN(ConnectionHandler); + virtual void SendMessage(const google::protobuf::MessageLite& message) = 0; }; } // namespace gcm diff --git a/google_apis/gcm/engine/connection_handler_impl.cc b/google_apis/gcm/engine/connection_handler_impl.cc new file mode 100644 index 0000000..aff0dfd --- /dev/null +++ b/google_apis/gcm/engine/connection_handler_impl.cc @@ -0,0 +1,404 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "google_apis/gcm/engine/connection_handler_impl.h" + +#include "base/message_loop/message_loop.h" +#include "google/protobuf/io/coded_stream.h" +#include "google_apis/gcm/base/mcs_util.h" +#include "google_apis/gcm/base/socket_stream.h" +#include "google_apis/gcm/protocol/mcs.pb.h" +#include "net/base/net_errors.h" +#include "net/socket/stream_socket.h" + +using namespace google::protobuf::io; + +namespace gcm { + +namespace { + +// # of bytes a MCS version packet consumes. +const int kVersionPacketLen = 1; +// # of bytes a tag packet consumes. +const int kTagPacketLen = 1; +// Max # of bytes a length packet consumes. +const int kSizePacketLenMin = 1; +const int kSizePacketLenMax = 2; + +// The current MCS protocol version. +// TODO(zea): bump to 41 once the server supports it. +const int kMCSVersion = 38; + +} // namespace + +ConnectionHandlerImpl::ConnectionHandlerImpl( + base::TimeDelta read_timeout, + const ProtoReceivedCallback& read_callback, + const ProtoSentCallback& write_callback, + const ConnectionChangedCallback& connection_callback) + : read_timeout_(read_timeout), + handshake_complete_(false), + message_tag_(0), + message_size_(0), + read_callback_(read_callback), + write_callback_(write_callback), + connection_callback_(connection_callback), + weak_ptr_factory_(this) { +} + +ConnectionHandlerImpl::~ConnectionHandlerImpl() { +} + +void ConnectionHandlerImpl::Init( + const mcs_proto::LoginRequest& login_request, + scoped_ptr<net::StreamSocket> socket) { + DCHECK(!read_callback_.is_null()); + DCHECK(!write_callback_.is_null()); + DCHECK(!connection_callback_.is_null()); + + // Invalidate any previously outstanding reads. + weak_ptr_factory_.InvalidateWeakPtrs(); + + handshake_complete_ = false; + message_tag_ = 0; + message_size_ = 0; + socket_ = socket.Pass(); + input_stream_.reset(new SocketInputStream(socket_.get())); + output_stream_.reset(new SocketOutputStream(socket_.get())); + + Login(login_request); +} + +bool ConnectionHandlerImpl::CanSendMessage() const { + return handshake_complete_ && output_stream_.get() && + output_stream_->GetState() == SocketOutputStream::EMPTY; +} + +void ConnectionHandlerImpl::SendMessage( + const google::protobuf::MessageLite& message) { + DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); + DCHECK(handshake_complete_); + + { + CodedOutputStream coded_output_stream(output_stream_.get()); + DVLOG(1) << "Writing proto of size " << message.ByteSize(); + int tag = GetMCSProtoTag(message); + DCHECK_NE(tag, -1); + coded_output_stream.WriteRaw(&tag, 1); + coded_output_stream.WriteVarint32(message.ByteSize()); + message.SerializeToCodedStream(&coded_output_stream); + } + + if (output_stream_->Flush( + base::Bind(&ConnectionHandlerImpl::OnMessageSent, + weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { + OnMessageSent(); + } +} + +void ConnectionHandlerImpl::Login( + const google::protobuf::MessageLite& login_request) { + DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); + + const char version_byte[1] = {kMCSVersion}; + const char login_request_tag[1] = {kLoginRequestTag}; + { + CodedOutputStream coded_output_stream(output_stream_.get()); + coded_output_stream.WriteRaw(version_byte, 1); + coded_output_stream.WriteRaw(login_request_tag, 1); + coded_output_stream.WriteVarint32(login_request.ByteSize()); + login_request.SerializeToCodedStream(&coded_output_stream); + } + + if (output_stream_->Flush( + base::Bind(&ConnectionHandlerImpl::OnMessageSent, + weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&ConnectionHandlerImpl::OnMessageSent, + weak_ptr_factory_.GetWeakPtr())); + } + + read_timeout_timer_.Start(FROM_HERE, + read_timeout_, + base::Bind(&ConnectionHandlerImpl::OnTimeout, + weak_ptr_factory_.GetWeakPtr())); + WaitForData(MCS_VERSION_TAG_AND_SIZE); +} + +void ConnectionHandlerImpl::OnMessageSent() { + if (!output_stream_.get()) { + // The connection has already been closed. Just return. + DCHECK(!input_stream_.get()); + DCHECK(!read_timeout_timer_.IsRunning()); + return; + } + + if (output_stream_->GetState() != SocketOutputStream::EMPTY) { + int last_error = output_stream_->last_error(); + CloseConnection(); + // If the socket stream had an error, plumb it up, else plumb up FAILED. + if (last_error == net::OK) + last_error = net::ERR_FAILED; + connection_callback_.Run(last_error); + return; + } + + write_callback_.Run(); +} + +void ConnectionHandlerImpl::GetNextMessage() { + DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() || + SocketInputStream::READY == input_stream_->GetState()); + message_tag_ = 0; + message_size_ = 0; + + WaitForData(MCS_TAG_AND_SIZE); +} + +void ConnectionHandlerImpl::WaitForData(ProcessingState state) { + DVLOG(1) << "Waiting for MCS data: state == " << state; + + if (!input_stream_) { + // The connection has already been closed. Just return. + DCHECK(!output_stream_.get()); + DCHECK(!read_timeout_timer_.IsRunning()); + return; + } + + if (input_stream_->GetState() != SocketInputStream::EMPTY && + input_stream_->GetState() != SocketInputStream::READY) { + // An error occurred. + int last_error = output_stream_->last_error(); + CloseConnection(); + // If the socket stream had an error, plumb it up, else plumb up FAILED. + if (last_error == net::OK) + last_error = net::ERR_FAILED; + connection_callback_.Run(last_error); + return; + } + + // Used to determine whether a Socket::Read is necessary. + int min_bytes_needed = 0; + // Used to limit the size of the Socket::Read. + int max_bytes_needed = 0; + + switch(state) { + case MCS_VERSION_TAG_AND_SIZE: + min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin; + max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax; + break; + case MCS_TAG_AND_SIZE: + min_bytes_needed = kTagPacketLen + kSizePacketLenMin; + max_bytes_needed = kTagPacketLen + kSizePacketLenMax; + break; + case MCS_FULL_SIZE: + // If in this state, the minimum size packet length must already have been + // insufficient, so set both to the max length. + min_bytes_needed = kSizePacketLenMax; + max_bytes_needed = kSizePacketLenMax; + break; + case MCS_PROTO_BYTES: + read_timeout_timer_.Reset(); + // No variability in the message size, set both to the same. + min_bytes_needed = message_size_; + max_bytes_needed = message_size_; + break; + default: + NOTREACHED(); + } + DCHECK_GE(max_bytes_needed, min_bytes_needed); + + int byte_count = input_stream_->UnreadByteCount(); + if (min_bytes_needed - byte_count > 0 && + input_stream_->Refresh( + base::Bind(&ConnectionHandlerImpl::WaitForData, + weak_ptr_factory_.GetWeakPtr(), + state), + max_bytes_needed - byte_count) == net::ERR_IO_PENDING) { + return; + } + + // Check for refresh errors. + if (input_stream_->GetState() != SocketInputStream::READY) { + // An error occurred. + int last_error = output_stream_->last_error(); + CloseConnection(); + // If the socket stream had an error, plumb it up, else plumb up FAILED. + if (last_error == net::OK) + last_error = net::ERR_FAILED; + connection_callback_.Run(last_error); + return; + } + + // Received enough bytes, process them. + DVLOG(1) << "Processing MCS data: state == " << state; + switch(state) { + case MCS_VERSION_TAG_AND_SIZE: + OnGotVersion(); + break; + case MCS_TAG_AND_SIZE: + OnGotMessageTag(); + break; + case MCS_FULL_SIZE: + OnGotMessageSize(); + break; + case MCS_PROTO_BYTES: + OnGotMessageBytes(); + break; + default: + NOTREACHED(); + } +} + +void ConnectionHandlerImpl::OnGotVersion() { + uint8 version = 0; + { + CodedInputStream coded_input_stream(input_stream_.get()); + coded_input_stream.ReadRaw(&version, 1); + } + if (version < kMCSVersion) { + LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version); + connection_callback_.Run(net::ERR_FAILED); + return; + } + + input_stream_->RebuildBuffer(); + + // Process the LoginResponse message tag. + OnGotMessageTag(); +} + +void ConnectionHandlerImpl::OnGotMessageTag() { + if (input_stream_->GetState() != SocketInputStream::READY) { + LOG(ERROR) << "Failed to receive protobuf tag."; + read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); + return; + } + + { + CodedInputStream coded_input_stream(input_stream_.get()); + coded_input_stream.ReadRaw(&message_tag_, 1); + } + + DVLOG(1) << "Received proto of type " + << static_cast<unsigned int>(message_tag_); + + if (!read_timeout_timer_.IsRunning()) { + read_timeout_timer_.Start(FROM_HERE, + read_timeout_, + base::Bind(&ConnectionHandlerImpl::OnTimeout, + weak_ptr_factory_.GetWeakPtr())); + } + OnGotMessageSize(); +} + +void ConnectionHandlerImpl::OnGotMessageSize() { + if (input_stream_->GetState() != SocketInputStream::READY) { + LOG(ERROR) << "Failed to receive message size."; + read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); + return; + } + + bool need_another_byte = false; + int prev_byte_count = input_stream_->ByteCount(); + { + CodedInputStream coded_input_stream(input_stream_.get()); + if (!coded_input_stream.ReadVarint32(&message_size_)) + need_another_byte = true; + } + + if (need_another_byte) { + DVLOG(1) << "Expecting another message size byte."; + if (prev_byte_count >= kSizePacketLenMax) { + // Already had enough bytes, something else went wrong. + LOG(ERROR) << "Failed to process message size."; + read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); + return; + } + // Back up by the amount read (should always be 1 byte). + int bytes_read = prev_byte_count - input_stream_->ByteCount(); + DCHECK_EQ(bytes_read, 1); + input_stream_->BackUp(bytes_read); + WaitForData(MCS_FULL_SIZE); + return; + } + + DVLOG(1) << "Proto size: " << message_size_; + + if (message_size_ > 0) + WaitForData(MCS_PROTO_BYTES); + else + OnGotMessageBytes(); +} + +void ConnectionHandlerImpl::OnGotMessageBytes() { + read_timeout_timer_.Stop(); + scoped_ptr<google::protobuf::MessageLite> protobuf( + BuildProtobufFromTag(message_tag_)); + // Messages with no content are valid; just use the default protobuf for + // that tag. + if (protobuf.get() && message_size_ == 0) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&ConnectionHandlerImpl::GetNextMessage, + weak_ptr_factory_.GetWeakPtr())); + read_callback_.Run(protobuf.Pass()); + return; + } + + if (!protobuf.get() || + input_stream_->GetState() != SocketInputStream::READY) { + LOG(ERROR) << "Failed to extract protobuf bytes of type " + << static_cast<unsigned int>(message_tag_); + protobuf.reset(); // Return a null pointer to denote an error. + read_callback_.Run(protobuf.Pass()); + return; + } + + { + CodedInputStream coded_input_stream(input_stream_.get()); + if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) { + NOTREACHED() << "Unable to parse GCM message of type " + << static_cast<unsigned int>(message_tag_); + protobuf.reset(); // Return a null pointer to denote an error. + read_callback_.Run(protobuf.Pass()); + return; + } + } + + input_stream_->RebuildBuffer(); + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&ConnectionHandlerImpl::GetNextMessage, + weak_ptr_factory_.GetWeakPtr())); + if (message_tag_ == kLoginResponseTag) { + if (handshake_complete_) { + LOG(ERROR) << "Unexpected login response."; + } else { + handshake_complete_ = true; + DVLOG(1) << "GCM Handshake complete."; + } + } + read_callback_.Run(protobuf.Pass()); +} + +void ConnectionHandlerImpl::OnTimeout() { + LOG(ERROR) << "Timed out waiting for GCM Protocol buffer."; + CloseConnection(); + connection_callback_.Run(net::ERR_TIMED_OUT); +} + +void ConnectionHandlerImpl::CloseConnection() { + DVLOG(1) << "Closing connection."; + read_callback_.Reset(); + write_callback_.Reset(); + read_timeout_timer_.Stop(); + socket_->Disconnect(); + input_stream_.reset(); + output_stream_.reset(); + weak_ptr_factory_.InvalidateWeakPtrs(); +} + +} // namespace gcm diff --git a/google_apis/gcm/engine/connection_handler_impl.h b/google_apis/gcm/engine/connection_handler_impl.h new file mode 100644 index 0000000..110cdcd --- /dev/null +++ b/google_apis/gcm/engine/connection_handler_impl.h @@ -0,0 +1,122 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef GOOGLE_APIS_GCM_ENGINE_CONNECTION_HANDLER_IMPL_H_ +#define GOOGLE_APIS_GCM_ENGINE_CONNECTION_HANDLER_IMPL_H_ + +#include "base/memory/weak_ptr.h" +#include "base/time/time.h" +#include "base/timer/timer.h" +#include "google_apis/gcm/engine/connection_handler.h" + +namespace mcs_proto { +class LoginRequest; +} // namespace mcs_proto + +namespace gcm { + +class GCM_EXPORT ConnectionHandlerImpl : public ConnectionHandler { + public: + // |read_callback| will be invoked with the contents of any received protobuf + // message. + // |write_callback| will be invoked anytime a message has been successfully + // sent. Note: this just means the data was sent to the wire, not that the + // other end received it. + // |connection_callback| will be invoked with any fatal read/write errors + // encountered. + ConnectionHandlerImpl( + base::TimeDelta read_timeout, + const ProtoReceivedCallback& read_callback, + const ProtoSentCallback& write_callback, + const ConnectionChangedCallback& connection_callback); + virtual ~ConnectionHandlerImpl(); + + // ConnectionHandler implementation. + virtual void Init(const mcs_proto::LoginRequest& login_request, + scoped_ptr<net::StreamSocket> socket) OVERRIDE; + virtual bool CanSendMessage() const OVERRIDE; + virtual void SendMessage(const google::protobuf::MessageLite& message) + OVERRIDE; + + private: + // State machine for handling incoming data. See WaitForData(..) for usage. + enum ProcessingState { + // Processing the version, tag, and size packets (assuming minimum length + // size packet). Only used during the login handshake. + MCS_VERSION_TAG_AND_SIZE = 0, + // Processing the tag and size packets (assuming minimum length size + // packet). Used for normal messages. + MCS_TAG_AND_SIZE, + // Processing a maximum length size packet (for messages with length > 128). + // Used when a normal size packet was not sufficient to read the message + // size. + MCS_FULL_SIZE, + // Processing the protocol buffer bytes (for those messages with non-zero + // sizes). + MCS_PROTO_BYTES + }; + + // Sends the protocol version and login request. First step in the MCS + // connection handshake. + void Login(const google::protobuf::MessageLite& login_request); + + // SendMessage continuation. Invoked when Socket::Write completes. + void OnMessageSent(); + + // Starts the message processing process, which is comprised of the tag, + // message size, and bytes packet types. + void GetNextMessage(); + + // Performs any necessary SocketInputStream refreshing until the data + // associated with |packet_type| is fully ready, then calls the appropriate + // OnGot* message to process the packet data. If the read times out, + // will close the stream and invoke the connection callback. + void WaitForData(ProcessingState state); + + // Incoming data helper methods. + void OnGotVersion(); + void OnGotMessageTag(); + void OnGotMessageSize(); + void OnGotMessageBytes(); + + // Timeout handler. + void OnTimeout(); + + // Closes the current connection. + void CloseConnection(); + + // Timeout policy: the timeout is only enforced while waiting on the + // handshake (version and/or LoginResponse) or once at least a tag packet has + // been received. It is reset every time new data is received, and is + // only stopped when a full message is processed. + // TODO(zea): consider enforcing a separate timeout when waiting for + // a message to send. + const base::TimeDelta read_timeout_; + base::OneShotTimer<ConnectionHandlerImpl> read_timeout_timer_; + + // This connection's socket and the input/output streams attached to it. + scoped_ptr<net::StreamSocket> socket_; + scoped_ptr<SocketInputStream> input_stream_; + scoped_ptr<SocketOutputStream> output_stream_; + + // Whether the MCS login handshake has successfully completed. See Init(..) + // description for more info on what the handshake involves. + bool handshake_complete_; + + // State for the message currently being processed, if there is one. + uint8 message_tag_; + uint32 message_size_; + + ProtoReceivedCallback read_callback_; + ProtoSentCallback write_callback_; + ConnectionChangedCallback connection_callback_; + + base::WeakPtrFactory<ConnectionHandlerImpl> weak_ptr_factory_; + + DISALLOW_COPY_AND_ASSIGN(ConnectionHandlerImpl); +}; + +} // namespace gcm + +#endif // GOOGLE_APIS_GCM_ENGINE_CONNECTION_HANDLER_IMPL_H_ diff --git a/google_apis/gcm/engine/connection_handler_unittest.cc b/google_apis/gcm/engine/connection_handler_impl_unittest.cc index d46c068..0cdcdc6 100644 --- a/google_apis/gcm/engine/connection_handler_unittest.cc +++ b/google_apis/gcm/engine/connection_handler_impl_unittest.cc @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "google_apis/gcm/engine/connection_handler.h" +#include "google_apis/gcm/engine/connection_handler_impl.h" #include "base/bind.h" #include "base/memory/scoped_ptr.h" @@ -97,10 +97,10 @@ std::string BuildDataMessage(const std::string& from, return data_message.SerializeAsString(); } -class GCMConnectionHandlerTest : public testing::Test { +class GCMConnectionHandlerImplTest : public testing::Test { public: - GCMConnectionHandlerTest(); - virtual ~GCMConnectionHandlerTest(); + GCMConnectionHandlerImplTest(); + virtual ~GCMConnectionHandlerImplTest(); net::StreamSocket* BuildSocket(const ReadList& read_list, const WriteList& write_list); @@ -108,7 +108,9 @@ class GCMConnectionHandlerTest : public testing::Test { // Pump |message_loop_|, resetting |run_loop_| after completion. void PumpLoop(); - ConnectionHandler* connection_handler() { return &connection_handler_; } + ConnectionHandlerImpl* connection_handler() { + return connection_handler_.get(); + } base::MessageLoop* message_loop() { return &message_loop_; }; net::DelayedSocketData* data_provider() { return data_provider_.get(); } int last_error() const { return last_error_; } @@ -133,7 +135,7 @@ class GCMConnectionHandlerTest : public testing::Test { scoped_ptr<SocketOutputStream> socket_output_stream_; // The connection handler being tested. - ConnectionHandler connection_handler_; + scoped_ptr<ConnectionHandlerImpl> connection_handler_; // The last connection error received. int last_error_; @@ -147,18 +149,17 @@ class GCMConnectionHandlerTest : public testing::Test { scoped_ptr<base::RunLoop> run_loop_; }; -GCMConnectionHandlerTest::GCMConnectionHandlerTest() - : connection_handler_(TestTimeouts::tiny_timeout()), - last_error_(0) { +GCMConnectionHandlerImplTest::GCMConnectionHandlerImplTest() + : last_error_(0) { net::IPAddressNumber ip_number; net::ParseIPLiteralToNumber("127.0.0.1", &ip_number); address_list_ = net::AddressList::CreateFromIPAddress(ip_number, kMCSPort); } -GCMConnectionHandlerTest::~GCMConnectionHandlerTest() { +GCMConnectionHandlerImplTest::~GCMConnectionHandlerImplTest() { } -net::StreamSocket* GCMConnectionHandlerTest::BuildSocket( +net::StreamSocket* GCMConnectionHandlerImplTest::BuildSocket( const ReadList& read_list, const WriteList& write_list) { mock_reads_ = read_list; @@ -180,49 +181,51 @@ net::StreamSocket* GCMConnectionHandlerTest::BuildSocket( return socket_.get(); } -void GCMConnectionHandlerTest::PumpLoop() { +void GCMConnectionHandlerImplTest::PumpLoop() { run_loop_->RunUntilIdle(); run_loop_.reset(new base::RunLoop()); } -void GCMConnectionHandlerTest::Connect( +void GCMConnectionHandlerImplTest::Connect( ScopedMessage* dst_proto) { - connection_handler_.Init( - socket_.Pass(), - *BuildLoginRequest(kAuthId, kAuthToken), - base::Bind(&GCMConnectionHandlerTest::ReadContinuation, - base::Unretained(this), - dst_proto), - base::Bind(&GCMConnectionHandlerTest::WriteContinuation, - base::Unretained(this)), - base::Bind(&GCMConnectionHandlerTest::ConnectionContinuation, - base::Unretained(this))); + connection_handler_.reset(new ConnectionHandlerImpl( + TestTimeouts::tiny_timeout(), + base::Bind(&GCMConnectionHandlerImplTest::ReadContinuation, + base::Unretained(this), + dst_proto), + base::Bind(&GCMConnectionHandlerImplTest::WriteContinuation, + base::Unretained(this)), + base::Bind(&GCMConnectionHandlerImplTest::ConnectionContinuation, + base::Unretained(this)))); + EXPECT_FALSE(connection_handler()->CanSendMessage()); + connection_handler_->Init(*BuildLoginRequest(kAuthId, kAuthToken), + socket_.Pass()); } -void GCMConnectionHandlerTest::ReadContinuation( +void GCMConnectionHandlerImplTest::ReadContinuation( ScopedMessage* dst_proto, ScopedMessage new_proto) { *dst_proto = new_proto.Pass(); run_loop_->Quit(); } -void GCMConnectionHandlerTest::WaitForMessage() { +void GCMConnectionHandlerImplTest::WaitForMessage() { run_loop_->Run(); run_loop_.reset(new base::RunLoop()); } -void GCMConnectionHandlerTest::WriteContinuation() { +void GCMConnectionHandlerImplTest::WriteContinuation() { run_loop_->Quit(); } -void GCMConnectionHandlerTest::ConnectionContinuation(int error) { +void GCMConnectionHandlerImplTest::ConnectionContinuation(int error) { last_error_ = error; run_loop_->Quit(); } // Initialize the connection handler and ensure the handshake completes // successfully. -TEST_F(GCMConnectionHandlerTest, Init) { +TEST_F(GCMConnectionHandlerImplTest, Init) { std::string handshake_request = EncodeHandshakeRequest(); WriteList write_list(1, net::MockWrite(net::ASYNC, handshake_request.c_str(), @@ -234,7 +237,6 @@ TEST_F(GCMConnectionHandlerTest, Init) { BuildSocket(read_list, write_list); ScopedMessage received_message; - EXPECT_FALSE(connection_handler()->CanSendMessage()); Connect(&received_message); EXPECT_FALSE(connection_handler()->CanSendMessage()); WaitForMessage(); // The login send. @@ -246,7 +248,7 @@ TEST_F(GCMConnectionHandlerTest, Init) { // Simulate the handshake response returning an older version. Initialization // should fail. -TEST_F(GCMConnectionHandlerTest, InitFailedVersionCheck) { +TEST_F(GCMConnectionHandlerImplTest, InitFailedVersionCheck) { std::string handshake_request = EncodeHandshakeRequest(); WriteList write_list(1, net::MockWrite(net::ASYNC, handshake_request.c_str(), @@ -270,7 +272,7 @@ TEST_F(GCMConnectionHandlerTest, InitFailedVersionCheck) { // Attempt to initialize, but receive no server response, resulting in a time // out. -TEST_F(GCMConnectionHandlerTest, InitTimeout) { +TEST_F(GCMConnectionHandlerImplTest, InitTimeout) { std::string handshake_request = EncodeHandshakeRequest(); WriteList write_list(1, net::MockWrite(net::ASYNC, handshake_request.c_str(), @@ -290,7 +292,7 @@ TEST_F(GCMConnectionHandlerTest, InitTimeout) { // Attempt to initialize, but receive an incomplete server response, resulting // in a time out. -TEST_F(GCMConnectionHandlerTest, InitIncompleteTimeout) { +TEST_F(GCMConnectionHandlerImplTest, InitIncompleteTimeout) { std::string handshake_request = EncodeHandshakeRequest(); WriteList write_list(1, net::MockWrite(net::ASYNC, handshake_request.c_str(), @@ -314,7 +316,7 @@ TEST_F(GCMConnectionHandlerTest, InitIncompleteTimeout) { } // Reinitialize the connection handler after failing to initialize. -TEST_F(GCMConnectionHandlerTest, ReInit) { +TEST_F(GCMConnectionHandlerImplTest, ReInit) { std::string handshake_request = EncodeHandshakeRequest(); WriteList write_list(1, net::MockWrite(net::ASYNC, handshake_request.c_str(), @@ -347,7 +349,7 @@ TEST_F(GCMConnectionHandlerTest, ReInit) { } // Verify that messages can be received after initialization. -TEST_F(GCMConnectionHandlerTest, RecvMsg) { +TEST_F(GCMConnectionHandlerImplTest, RecvMsg) { std::string handshake_request = EncodeHandshakeRequest(); WriteList write_list(1, net::MockWrite(net::ASYNC, handshake_request.c_str(), @@ -377,7 +379,7 @@ TEST_F(GCMConnectionHandlerTest, RecvMsg) { } // Verify that if two messages arrive at once, they're treated appropriately. -TEST_F(GCMConnectionHandlerTest, Recv2Msgs) { +TEST_F(GCMConnectionHandlerImplTest, Recv2Msgs) { std::string handshake_request = EncodeHandshakeRequest(); WriteList write_list(1, net::MockWrite(net::ASYNC, handshake_request.c_str(), @@ -414,7 +416,7 @@ TEST_F(GCMConnectionHandlerTest, Recv2Msgs) { } // Receive a long (>128 bytes) message. -TEST_F(GCMConnectionHandlerTest, RecvLongMsg) { +TEST_F(GCMConnectionHandlerImplTest, RecvLongMsg) { std::string handshake_request = EncodeHandshakeRequest(); WriteList write_list(1, net::MockWrite(net::ASYNC, handshake_request.c_str(), @@ -445,7 +447,7 @@ TEST_F(GCMConnectionHandlerTest, RecvLongMsg) { } // Receive two long (>128 bytes) message. -TEST_F(GCMConnectionHandlerTest, Recv2LongMsgs) { +TEST_F(GCMConnectionHandlerImplTest, Recv2LongMsgs) { std::string handshake_request = EncodeHandshakeRequest(); WriteList write_list(1, net::MockWrite(net::ASYNC, handshake_request.c_str(), @@ -484,7 +486,7 @@ TEST_F(GCMConnectionHandlerTest, Recv2LongMsgs) { // Simulate a message where the end of the data does not arrive in time and the // read times out. -TEST_F(GCMConnectionHandlerTest, ReadTimeout) { +TEST_F(GCMConnectionHandlerImplTest, ReadTimeout) { std::string handshake_request = EncodeHandshakeRequest(); WriteList write_list(1, net::MockWrite(net::ASYNC, handshake_request.c_str(), @@ -527,7 +529,7 @@ TEST_F(GCMConnectionHandlerTest, ReadTimeout) { } // Receive a message with zero data bytes. -TEST_F(GCMConnectionHandlerTest, RecvMsgNoData) { +TEST_F(GCMConnectionHandlerImplTest, RecvMsgNoData) { std::string handshake_request = EncodeHandshakeRequest(); WriteList write_list(1, net::MockWrite(net::ASYNC, handshake_request.c_str(), @@ -558,7 +560,7 @@ TEST_F(GCMConnectionHandlerTest, RecvMsgNoData) { } // Send a message after performing the handshake. -TEST_F(GCMConnectionHandlerTest, SendMsg) { +TEST_F(GCMConnectionHandlerImplTest, SendMsg) { mcs_proto::DataMessageStanza data_message; data_message.set_from(kDataMsgFrom); data_message.set_category(kDataMsgCategory); @@ -592,7 +594,7 @@ TEST_F(GCMConnectionHandlerTest, SendMsg) { } // Attempt to send a message after the socket is disconnected due to a timeout. -TEST_F(GCMConnectionHandlerTest, SendMsgSocketDisconnected) { +TEST_F(GCMConnectionHandlerImplTest, SendMsgSocketDisconnected) { std::string handshake_request = EncodeHandshakeRequest(); WriteList write_list; write_list.push_back(net::MockWrite(net::ASYNC, diff --git a/google_apis/gcm/gcm.gyp b/google_apis/gcm/gcm.gyp index 52abd22..833f2c8 100644 --- a/google_apis/gcm/gcm.gyp +++ b/google_apis/gcm/gcm.gyp @@ -34,7 +34,8 @@ '../../components/components.gyp:encryptor', '../../net/net.gyp:net', '../../third_party/leveldatabase/leveldatabase.gyp:leveldatabase', - '../../third_party/protobuf/protobuf.gyp:protobuf_lite' + '../../third_party/protobuf/protobuf.gyp:protobuf_lite', + '../../url/url.gyp:url_lib', ], 'sources': [ 'base/mcs_message.h', @@ -43,8 +44,14 @@ 'base/mcs_util.cc', 'base/socket_stream.h', 'base/socket_stream.cc', + 'engine/connection_factory.h', + 'engine/connection_factory.cc', + 'engine/connection_factory_impl.h', + 'engine/connection_factory_impl.cc', 'engine/connection_handler.h', 'engine/connection_handler.cc', + 'engine/connection_handler_impl.h', + 'engine/connection_handler_impl.cc', 'engine/rmq_store.h', 'engine/rmq_store.cc', 'gcm_client.cc', @@ -78,7 +85,8 @@ 'sources': [ 'base/mcs_util_unittest.cc', 'base/socket_stream_unittest.cc', - 'engine/connection_handler_unittest.cc', + 'engine/connection_factory_impl_unittest.cc', + 'engine/connection_handler_impl_unittest.cc', 'engine/rmq_store_unittest.cc', ] }, diff --git a/net/socket/client_socket_pool_manager.cc b/net/socket/client_socket_pool_manager.cc index b37d2d1..24d6b70 100644 --- a/net/socket/client_socket_pool_manager.cc +++ b/net/socket/client_socket_pool_manager.cc @@ -437,6 +437,31 @@ int InitSocketHandleForRawConnect( callback); } +int InitSocketHandleForTlsConnect( + const HostPortPair& host_port_pair, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + ClientSocketHandle* socket_handle, + const CompletionCallback& callback) { + DCHECK(socket_handle); + // Synthesize an HttpRequestInfo. + GURL request_url = GURL("https://" + host_port_pair.ToString()); + HttpRequestHeaders request_extra_headers; + int request_load_flags = 0; + RequestPriority request_priority = MEDIUM; + + return InitSocketPoolHelper( + request_url, request_extra_headers, request_load_flags, request_priority, + session, proxy_info, false, false, ssl_config_for_origin, + ssl_config_for_proxy, true, privacy_mode, net_log, 0, socket_handle, + HttpNetworkSession::NORMAL_SOCKET_POOL, OnHostResolutionCallback(), + callback); +} + int PreconnectSocketsForHttpRequest( const GURL& request_url, const HttpRequestHeaders& request_extra_headers, diff --git a/net/socket/client_socket_pool_manager.h b/net/socket/client_socket_pool_manager.h index 1b78324..1215480 100644 --- a/net/socket/client_socket_pool_manager.h +++ b/net/socket/client_socket_pool_manager.h @@ -147,6 +147,21 @@ NET_EXPORT int InitSocketHandleForRawConnect( ClientSocketHandle* socket_handle, const CompletionCallback& callback); +// A helper method that uses the passed in proxy information to initialize a +// ClientSocketHandle with the relevant socket pool. Use this method for +// a raw socket connection with TLS negotiation to a host-port pair (that needs +// to tunnel through the proxies). +NET_EXPORT int InitSocketHandleForTlsConnect( + const HostPortPair& host_port_pair, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + ClientSocketHandle* socket_handle, + const CompletionCallback& callback); + // Similar to InitSocketHandleForHttpRequest except that it initiates the // desired number of preconnect streams from the relevant socket pool. int PreconnectSocketsForHttpRequest( |