diff options
author | zea@chromium.org <zea@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-10-29 03:35:58 +0000 |
---|---|---|
committer | zea@chromium.org <zea@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-10-29 03:35:58 +0000 |
commit | eb9fa2a810020d6cb4e1b8252c04822055dfef4c (patch) | |
tree | bef0c93bc79471bfd8fa3eca380ef2b99d81777b /google_apis | |
parent | 28165927b8cf46ef9eb58b3c51ae1ae58df32114 (diff) | |
download | chromium_src-eb9fa2a810020d6cb4e1b8252c04822055dfef4c.zip chromium_src-eb9fa2a810020d6cb4e1b8252c04822055dfef4c.tar.gz chromium_src-eb9fa2a810020d6cb4e1b8252c04822055dfef4c.tar.bz2 |
[GCM] Add basic MCS connection logic
Introduce state machine for handling connections, as well as protobuf
definitions and util methods for dealing with protobufs.
BUG=284553
Review URL: https://codereview.chromium.org/27375002
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@231507 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'google_apis')
-rw-r--r-- | google_apis/gcm/base/mcs_util.cc | 220 | ||||
-rw-r--r-- | google_apis/gcm/base/mcs_util.h | 79 | ||||
-rw-r--r-- | google_apis/gcm/base/mcs_util_unittest.cc | 82 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_handler.cc | 401 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_handler.h | 145 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_handler_unittest.cc | 626 | ||||
-rw-r--r-- | google_apis/gcm/gcm.gyp | 19 | ||||
-rw-r--r-- | google_apis/gcm/protocol/mcs.proto | 269 |
8 files changed, 1840 insertions, 1 deletions
diff --git a/google_apis/gcm/base/mcs_util.cc b/google_apis/gcm/base/mcs_util.cc new file mode 100644 index 0000000..b52d429 --- /dev/null +++ b/google_apis/gcm/base/mcs_util.cc @@ -0,0 +1,220 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "google_apis/gcm/base/mcs_util.h" + +#include "base/format_macros.h" +#include "base/logging.h" +#include "base/strings/string_number_conversions.h" +#include "base/strings/stringprintf.h" + +namespace gcm { + +namespace { + +// Type names corresponding to MCSProtoTags. Useful for identifying what type +// of MCS protobuf is contained within a google::protobuf::MessageLite object. +// WARNING: must match the order in MCSProtoTag. +const char* kProtoNames[] = { + "mcs_proto.HeartbeatPing", + "mcs_proto.HeartbeatAck", + "mcs_proto.LoginRequest", + "mcs_proto.LoginResponse", + "mcs_proto.Close", + "mcs_proto.MessageStanza", + "mcs_proto.PresenceStanza", + "mcs_proto.IqStanza", + "mcs_proto.DataMessageStanza", + "mcs_proto.BatchPresenceStanza", + "mcs_proto.StreamErrorStanza", + "mcs_proto.HttpRequest", + "mcs_proto.HttpResponse", + "mcs_proto.BindAccountRequest", + "mcs_proto.BindAccountResponse", + "mcs_proto.TalkMetadata" +}; +COMPILE_ASSERT(arraysize(kProtoNames) == kNumProtoTypes, + ProtoNamesMustIncludeAllTags); + +// TODO(zea): replace these with proper values. +const char kLoginId[] = "login-1"; +const char kLoginDomain[] = "mcs.android.com"; +const char kLoginDeviceIdPrefix[] = "android-"; +const char kLoginSettingName[] = "new_vc"; +const char kLoginSettingValue[] = "1"; + +} // namespace + +scoped_ptr<mcs_proto::LoginRequest> BuildLoginRequest( + uint64 auth_id, + uint64 auth_token) { + // Create a hex encoded auth id for the device id field. + std::string auth_id_hex; + auth_id_hex = base::StringPrintf("%" PRIx64, auth_id); + + std::string auth_id_str = base::Uint64ToString(auth_id); + std::string auth_token_str = base::Uint64ToString(auth_token); + + scoped_ptr<mcs_proto::LoginRequest> login_request( + new mcs_proto::LoginRequest()); + + // TODO(zea): set better values. + login_request->set_account_id(1000000); + login_request->set_adaptive_heartbeat(false); + login_request->set_auth_service(mcs_proto::LoginRequest::ANDROID_ID); + login_request->set_auth_token(auth_token_str); + login_request->set_id(kLoginId); + login_request->set_domain(kLoginDomain); + login_request->set_device_id(kLoginDeviceIdPrefix + auth_id_hex); + login_request->set_network_type(1); + login_request->set_resource(auth_id_str); + login_request->set_user(auth_id_str); + login_request->set_use_rmq2(true); + + login_request->add_setting(); + login_request->mutable_setting(0)->set_name(kLoginSettingName); + login_request->mutable_setting(0)->set_value(kLoginSettingValue); + return login_request.Pass(); +} + +scoped_ptr<mcs_proto::IqStanza> BuildStreamAck() { + scoped_ptr<mcs_proto::IqStanza> stream_ack_iq(new mcs_proto::IqStanza()); + stream_ack_iq->set_type(mcs_proto::IqStanza::SET); + stream_ack_iq->set_id(""); + stream_ack_iq->mutable_extension()->set_id(kStreamAck); + stream_ack_iq->mutable_extension()->set_data(""); + return stream_ack_iq.Pass(); +} + +// Utility method to build a google::protobuf::MessageLite object from a MCS +// tag. +scoped_ptr<google::protobuf::MessageLite> BuildProtobufFromTag(uint8 tag) { + switch(tag) { + case kHeartbeatPingTag: + return scoped_ptr<google::protobuf::MessageLite>( + new mcs_proto::HeartbeatPing()); + case kHeartbeatAckTag: + return scoped_ptr<google::protobuf::MessageLite>( + new mcs_proto::HeartbeatAck()); + case kLoginRequestTag: + return scoped_ptr<google::protobuf::MessageLite>( + new mcs_proto::LoginRequest()); + case kLoginResponseTag: + return scoped_ptr<google::protobuf::MessageLite>( + new mcs_proto::LoginResponse()); + case kCloseTag: + return scoped_ptr<google::protobuf::MessageLite>( + new mcs_proto::Close()); + case kIqStanzaTag: + return scoped_ptr<google::protobuf::MessageLite>( + new mcs_proto::IqStanza()); + case kDataMessageStanzaTag: + return scoped_ptr<google::protobuf::MessageLite>( + new mcs_proto::DataMessageStanza()); + case kStreamErrorStanzaTag: + return scoped_ptr<google::protobuf::MessageLite>( + new mcs_proto::StreamErrorStanza()); + default: + return scoped_ptr<google::protobuf::MessageLite>(); + } +} + +// Utility method to extract a MCS tag from a google::protobuf::MessageLite +// object. +int GetMCSProtoTag(const google::protobuf::MessageLite& message) { + const std::string& type_name = message.GetTypeName(); + if (type_name == kProtoNames[kHeartbeatPingTag]) { + return kHeartbeatPingTag; + } else if (type_name == kProtoNames[kHeartbeatAckTag]) { + return kHeartbeatAckTag; + } else if (type_name == kProtoNames[kLoginRequestTag]) { + return kLoginRequestTag; + } else if (type_name == kProtoNames[kLoginResponseTag]) { + return kLoginResponseTag; + } else if (type_name == kProtoNames[kCloseTag]) { + return kCloseTag; + } else if (type_name == kProtoNames[kIqStanzaTag]) { + return kIqStanzaTag; + } else if (type_name == kProtoNames[kDataMessageStanzaTag]) { + return kDataMessageStanzaTag; + } else if (type_name == kProtoNames[kStreamErrorStanzaTag]) { + return kStreamErrorStanzaTag; + } + return -1; +} + +std::string GetPersistentId(const google::protobuf::MessageLite& protobuf) { + if (protobuf.GetTypeName() == kProtoNames[kIqStanzaTag]) { + return reinterpret_cast<const mcs_proto::IqStanza*>(&protobuf)-> + persistent_id(); + } else if (protobuf.GetTypeName() == kProtoNames[kDataMessageStanzaTag]) { + return reinterpret_cast<const mcs_proto::DataMessageStanza*>(&protobuf)-> + persistent_id(); + } + // Not all message types have persistent ids. Just return empty string; + return ""; +} + +void SetPersistentId(const std::string& persistent_id, + google::protobuf::MessageLite* protobuf) { + if (protobuf->GetTypeName() == kProtoNames[kIqStanzaTag]) { + reinterpret_cast<mcs_proto::IqStanza*>(protobuf)-> + set_persistent_id(persistent_id); + return; + } else if (protobuf->GetTypeName() == kProtoNames[kDataMessageStanzaTag]) { + reinterpret_cast<mcs_proto::DataMessageStanza*>(protobuf)-> + set_persistent_id(persistent_id); + return; + } + NOTREACHED(); +} + +uint32 GetLastStreamIdReceived(const google::protobuf::MessageLite& protobuf) { + if (protobuf.GetTypeName() == kProtoNames[kIqStanzaTag]) { + return reinterpret_cast<const mcs_proto::IqStanza*>(&protobuf)-> + last_stream_id_received(); + } else if (protobuf.GetTypeName() == kProtoNames[kDataMessageStanzaTag]) { + return reinterpret_cast<const mcs_proto::DataMessageStanza*>(&protobuf)-> + last_stream_id_received(); + } else if (protobuf.GetTypeName() == kProtoNames[kHeartbeatPingTag]) { + return reinterpret_cast<const mcs_proto::HeartbeatPing*>(&protobuf)-> + last_stream_id_received(); + } else if (protobuf.GetTypeName() == kProtoNames[kHeartbeatAckTag]) { + return reinterpret_cast<const mcs_proto::HeartbeatAck*>(&protobuf)-> + last_stream_id_received(); + } else if (protobuf.GetTypeName() == kProtoNames[kLoginResponseTag]) { + return reinterpret_cast<const mcs_proto::LoginResponse*>(&protobuf)-> + last_stream_id_received(); + } + // Not all message types have last stream ids. Just return 0. + return 0; +} + +void SetLastStreamIdReceived(uint32 val, + google::protobuf::MessageLite* protobuf) { + if (protobuf->GetTypeName() == kProtoNames[kIqStanzaTag]) { + reinterpret_cast<mcs_proto::IqStanza*>(protobuf)-> + set_last_stream_id_received(val); + return; + } else if (protobuf->GetTypeName() == kProtoNames[kHeartbeatPingTag]) { + reinterpret_cast<mcs_proto::HeartbeatPing*>(protobuf)-> + set_last_stream_id_received(val); + return; + } else if (protobuf->GetTypeName() == kProtoNames[kHeartbeatAckTag]) { + reinterpret_cast<mcs_proto::HeartbeatAck*>(protobuf)-> + set_last_stream_id_received(val); + return; + } else if (protobuf->GetTypeName() == kProtoNames[kDataMessageStanzaTag]) { + reinterpret_cast<mcs_proto::DataMessageStanza*>(protobuf)-> + set_last_stream_id_received(val); + return; + } else if (protobuf->GetTypeName() == kProtoNames[kLoginResponseTag]) { + reinterpret_cast<mcs_proto::LoginResponse*>(protobuf)-> + set_last_stream_id_received(val); + return; + } + NOTREACHED(); +} + +} // namespace gcm diff --git a/google_apis/gcm/base/mcs_util.h b/google_apis/gcm/base/mcs_util.h new file mode 100644 index 0000000..d125af7 --- /dev/null +++ b/google_apis/gcm/base/mcs_util.h @@ -0,0 +1,79 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// Utility methods for MCS interactions. + +#ifndef GOOGLE_APIS_GCM_BASE_MCS_UTIL_H_ +#define GOOGLE_APIS_GCM_BASE_MCS_UTIL_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "google_apis/gcm/base/gcm_export.h" +#include "google_apis/gcm/protocol/mcs.pb.h" + +namespace net { +class StreamSocket; +} + +namespace gcm { + +// MCS Message tags. +// WARNING: the order of these tags must remain the same, as the tag values +// must be consistent with those used on the server. +enum MCSProtoTag { + kHeartbeatPingTag = 0, + kHeartbeatAckTag, + kLoginRequestTag, + kLoginResponseTag, + kCloseTag, + kMessageStanzaTag, + kPresenceStanzaTag, + kIqStanzaTag, + kDataMessageStanzaTag, + kBatchPresenceStanzaTag, + kStreamErrorStanzaTag, + kHttpRequestTag, + kHttpResponseTag, + kBindAccountRequestTag, + kBindAccountResponseTag, + kTalkMetadataTag, + kNumProtoTypes, +}; + +enum MCSIqStanzaExtension { + kSelectiveAck = 12, + kStreamAck = 13, +}; + +// Builds a LoginRequest with the hardcoded local data. +GCM_EXPORT scoped_ptr<mcs_proto::LoginRequest> BuildLoginRequest( + uint64 auth_id, + uint64 auth_token); + +// Builds a StreamAck IqStanza message. +GCM_EXPORT scoped_ptr<mcs_proto::IqStanza> BuildStreamAck(); + +// Utility methods for building and identifying MCS protobufs. +GCM_EXPORT scoped_ptr<google::protobuf::MessageLite> + BuildProtobufFromTag(uint8 tag); +GCM_EXPORT int GetMCSProtoTag(const google::protobuf::MessageLite& message); + +// RMQ utility methods for extracting/setting common data from/to protobufs. +GCM_EXPORT std::string GetPersistentId( + const google::protobuf::MessageLite& message); +GCM_EXPORT void SetPersistentId( + const std::string& persistent_id, + google::protobuf::MessageLite* message); +GCM_EXPORT uint32 GetLastStreamIdReceived( + const google::protobuf::MessageLite& protobuf); +GCM_EXPORT void SetLastStreamIdReceived( + uint32 last_stream_id_received, + google::protobuf::MessageLite* protobuf); + +} // namespace gcm + +#endif // GOOGLE_APIS_GCM_BASE_MCS_UTIL_H_ diff --git a/google_apis/gcm/base/mcs_util_unittest.cc b/google_apis/gcm/base/mcs_util_unittest.cc new file mode 100644 index 0000000..d259145 --- /dev/null +++ b/google_apis/gcm/base/mcs_util_unittest.cc @@ -0,0 +1,82 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "google_apis/gcm/base/mcs_util.h" + +#include "base/bind.h" +#include "base/memory/scoped_ptr.h" +#include "base/run_loop.h" +#include "base/strings/string_number_conversions.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace gcm { +namespace { + +const uint64 kAuthId = 4421448356646222460; +const uint64 kAuthToken = 12345; + +// Build a login request protobuf. +TEST(MCSUtilTest, BuildLoginRequest) { + scoped_ptr<mcs_proto::LoginRequest> login_request = + BuildLoginRequest(kAuthId, kAuthToken); + ASSERT_EQ("login-1", login_request->id()); + ASSERT_EQ(base::Uint64ToString(kAuthToken), login_request->auth_token()); + ASSERT_EQ(base::Uint64ToString(kAuthId), login_request->user()); + ASSERT_EQ("android-3d5c23dac2a1fa7c", login_request->device_id()); + // TODO(zea): test the other fields once they have valid values. +} + +// Test building a protobuf and extracting the tag from a protobuf. +TEST(MCSUtilTest, ProtobufToTag) { + for (size_t i = 0; i < kNumProtoTypes; ++i) { + scoped_ptr<google::protobuf::MessageLite> protobuf = + BuildProtobufFromTag(i); + if (!protobuf.get()) // Not all tags have protobuf definitions. + continue; + ASSERT_EQ((int)i, GetMCSProtoTag(*protobuf)) << "Type " << i; + } +} + +// Test getting and setting persistent ids. +TEST(MCSUtilTest, PersistentIds) { + COMPILE_ASSERT(kNumProtoTypes == 16U, UpdatePersistentIds); + const int kTagsWithPersistentIds[] = { + kIqStanzaTag, + kDataMessageStanzaTag + }; + for (size_t i = 0; i < arraysize(kTagsWithPersistentIds); ++i) { + int tag = kTagsWithPersistentIds[i]; + scoped_ptr<google::protobuf::MessageLite> protobuf = + BuildProtobufFromTag(tag); + ASSERT_TRUE(protobuf.get()); + SetPersistentId(base::IntToString(tag), protobuf.get()); + int get_val = 0; + base::StringToInt(GetPersistentId(*protobuf), &get_val); + ASSERT_EQ(tag, get_val); + } +} + +// Test getting and setting stream ids. +TEST(MCSUtilTest, StreamIds) { + COMPILE_ASSERT(kNumProtoTypes == 16U, UpdateStreamIds); + const int kTagsWithStreamIds[] = { + kIqStanzaTag, + kDataMessageStanzaTag, + kHeartbeatPingTag, + kHeartbeatAckTag, + kLoginResponseTag, + }; + for (size_t i = 0; i < arraysize(kTagsWithStreamIds); ++i) { + int tag = kTagsWithStreamIds[i]; + scoped_ptr<google::protobuf::MessageLite> protobuf = + BuildProtobufFromTag(tag); + ASSERT_TRUE(protobuf.get()); + SetLastStreamIdReceived(tag, protobuf.get()); + int get_id = GetLastStreamIdReceived(*protobuf); + ASSERT_EQ(tag, get_id); + } +} + +} // namespace +} // namespace gcm diff --git a/google_apis/gcm/engine/connection_handler.cc b/google_apis/gcm/engine/connection_handler.cc new file mode 100644 index 0000000..b4eb602 --- /dev/null +++ b/google_apis/gcm/engine/connection_handler.cc @@ -0,0 +1,401 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "google_apis/gcm/engine/connection_handler.h" + +#include "base/message_loop/message_loop.h" +#include "google/protobuf/io/coded_stream.h" +#include "google_apis/gcm/base/mcs_util.h" +#include "google_apis/gcm/base/socket_stream.h" +#include "net/base/net_errors.h" +#include "net/socket/stream_socket.h" + +using namespace google::protobuf::io; + +namespace gcm { + +namespace { + +// # of bytes a MCS version packet consumes. +const int kVersionPacketLen = 1; +// # of bytes a tag packet consumes. +const int kTagPacketLen = 1; +// Max # of bytes a length packet consumes. +const int kSizePacketLenMin = 1; +const int kSizePacketLenMax = 2; + +// The current MCS protocol version. +const int kMCSVersion = 38; + +} // namespace + +ConnectionHandler::ConnectionHandler(base::TimeDelta read_timeout) + : read_timeout_(read_timeout), + handshake_complete_(false), + message_tag_(0), + message_size_(0), + weak_ptr_factory_(this) { +} + +ConnectionHandler::~ConnectionHandler() { +} + +void ConnectionHandler::Init( + scoped_ptr<net::StreamSocket> socket, + const google::protobuf::MessageLite& login_request, + const ProtoReceivedCallback& read_callback, + const ProtoSentCallback& write_callback, + const ConnectionChangedCallback& connection_callback) { + DCHECK(!read_callback.is_null()); + DCHECK(!write_callback.is_null()); + DCHECK(!connection_callback.is_null()); + + // Invalidate any previously outstanding reads. + weak_ptr_factory_.InvalidateWeakPtrs(); + + handshake_complete_ = false; + message_tag_ = 0; + message_size_ = 0; + socket_ = socket.Pass(); + input_stream_.reset(new SocketInputStream(socket_.get())); + output_stream_.reset(new SocketOutputStream(socket_.get())); + read_callback_ = read_callback; + write_callback_ = write_callback; + connection_callback_ = connection_callback; + + Login(login_request); +} + +bool ConnectionHandler::CanSendMessage() const { + return handshake_complete_ && output_stream_.get() && + output_stream_->GetState() == SocketOutputStream::EMPTY; +} + +void ConnectionHandler::SendMessage( + const google::protobuf::MessageLite& message) { + DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); + DCHECK(handshake_complete_); + + { + CodedOutputStream coded_output_stream(output_stream_.get()); + DVLOG(1) << "Writing proto of size " << message.ByteSize(); + int tag = GetMCSProtoTag(message); + DCHECK_NE(tag, -1); + coded_output_stream.WriteRaw(&tag, 1); + coded_output_stream.WriteVarint32(message.ByteSize()); + message.SerializeToCodedStream(&coded_output_stream); + } + + if (output_stream_->Flush( + base::Bind(&ConnectionHandler::OnMessageSent, + weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { + OnMessageSent(); + } +} + +void ConnectionHandler::Login( + const google::protobuf::MessageLite& login_request) { + DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); + + const char version_byte[1] = {kMCSVersion}; + const char login_request_tag[1] = {kLoginRequestTag}; + { + CodedOutputStream coded_output_stream(output_stream_.get()); + coded_output_stream.WriteRaw(version_byte, 1); + coded_output_stream.WriteRaw(login_request_tag, 1); + coded_output_stream.WriteVarint32(login_request.ByteSize()); + login_request.SerializeToCodedStream(&coded_output_stream); + } + + if (output_stream_->Flush( + base::Bind(&ConnectionHandler::OnMessageSent, + weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&ConnectionHandler::OnMessageSent, + weak_ptr_factory_.GetWeakPtr())); + } + + read_timeout_timer_.Start(FROM_HERE, + read_timeout_, + base::Bind(&ConnectionHandler::OnTimeout, + weak_ptr_factory_.GetWeakPtr())); + WaitForData(MCS_VERSION_TAG_AND_SIZE); +} + +void ConnectionHandler::OnMessageSent() { + if (!output_stream_.get()) { + // The connection has already been closed. Just return. + DCHECK(!input_stream_.get()); + DCHECK(!read_timeout_timer_.IsRunning()); + return; + } + + if (output_stream_->GetState() != SocketOutputStream::EMPTY) { + int last_error = output_stream_->last_error(); + CloseConnection(); + // If the socket stream had an error, plumb it up, else plumb up FAILED. + if (last_error == net::OK) + last_error = net::ERR_FAILED; + connection_callback_.Run(last_error); + return; + } + + write_callback_.Run(); +} + +void ConnectionHandler::GetNextMessage() { + DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() || + SocketInputStream::READY == input_stream_->GetState()); + message_tag_ = 0; + message_size_ = 0; + + WaitForData(MCS_TAG_AND_SIZE); +} + +void ConnectionHandler::WaitForData(ProcessingState state) { + DVLOG(1) << "Waiting for MCS data: state == " << state; + + if (!input_stream_) { + // The connection has already been closed. Just return. + DCHECK(!output_stream_.get()); + DCHECK(!read_timeout_timer_.IsRunning()); + return; + } + + if (input_stream_->GetState() != SocketInputStream::EMPTY && + input_stream_->GetState() != SocketInputStream::READY) { + // An error occurred. + int last_error = output_stream_->last_error(); + CloseConnection(); + // If the socket stream had an error, plumb it up, else plumb up FAILED. + if (last_error == net::OK) + last_error = net::ERR_FAILED; + connection_callback_.Run(last_error); + return; + } + + // Used to determine whether a Socket::Read is necessary. + int min_bytes_needed = 0; + // Used to limit the size of the Socket::Read. + int max_bytes_needed = 0; + + switch(state) { + case MCS_VERSION_TAG_AND_SIZE: + min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin; + max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax; + break; + case MCS_TAG_AND_SIZE: + min_bytes_needed = kTagPacketLen + kSizePacketLenMin; + max_bytes_needed = kTagPacketLen + kSizePacketLenMax; + break; + case MCS_FULL_SIZE: + // If in this state, the minimum size packet length must already have been + // insufficient, so set both to the max length. + min_bytes_needed = kSizePacketLenMax; + max_bytes_needed = kSizePacketLenMax; + break; + case MCS_PROTO_BYTES: + read_timeout_timer_.Reset(); + // No variability in the message size, set both to the same. + min_bytes_needed = message_size_; + max_bytes_needed = message_size_; + break; + default: + NOTREACHED(); + } + DCHECK_GE(max_bytes_needed, min_bytes_needed); + + int byte_count = input_stream_->UnreadByteCount(); + if (min_bytes_needed - byte_count > 0 && + input_stream_->Refresh( + base::Bind(&ConnectionHandler::WaitForData, + weak_ptr_factory_.GetWeakPtr(), + state), + max_bytes_needed - byte_count) == net::ERR_IO_PENDING) { + return; + } + + // Check for refresh errors. + if (input_stream_->GetState() != SocketInputStream::READY) { + // An error occurred. + int last_error = output_stream_->last_error(); + CloseConnection(); + // If the socket stream had an error, plumb it up, else plumb up FAILED. + if (last_error == net::OK) + last_error = net::ERR_FAILED; + connection_callback_.Run(last_error); + return; + } + + // Received enough bytes, process them. + DVLOG(1) << "Processing MCS data: state == " << state; + switch(state) { + case MCS_VERSION_TAG_AND_SIZE: + OnGotVersion(); + break; + case MCS_TAG_AND_SIZE: + OnGotMessageTag(); + break; + case MCS_FULL_SIZE: + OnGotMessageSize(); + break; + case MCS_PROTO_BYTES: + OnGotMessageBytes(); + break; + default: + NOTREACHED(); + } +} + +void ConnectionHandler::OnGotVersion() { + uint8 version = 0; + { + CodedInputStream coded_input_stream(input_stream_.get()); + coded_input_stream.ReadRaw(&version, 1); + } + if (version < kMCSVersion) { + LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version); + connection_callback_.Run(net::ERR_FAILED); + return; + } + + input_stream_->RebuildBuffer(); + + // Process the LoginResponse message tag. + OnGotMessageTag(); +} + +void ConnectionHandler::OnGotMessageTag() { + if (input_stream_->GetState() != SocketInputStream::READY) { + LOG(ERROR) << "Failed to receive protobuf tag."; + read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); + return; + } + + { + CodedInputStream coded_input_stream(input_stream_.get()); + coded_input_stream.ReadRaw(&message_tag_, 1); + } + + DVLOG(1) << "Received proto of type " + << static_cast<unsigned int>(message_tag_); + + if (!read_timeout_timer_.IsRunning()) { + read_timeout_timer_.Start(FROM_HERE, + read_timeout_, + base::Bind(&ConnectionHandler::OnTimeout, + weak_ptr_factory_.GetWeakPtr())); + } + OnGotMessageSize(); +} + +void ConnectionHandler::OnGotMessageSize() { + if (input_stream_->GetState() != SocketInputStream::READY) { + LOG(ERROR) << "Failed to receive message size."; + read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); + return; + } + + bool need_another_byte = false; + int prev_byte_count = input_stream_->ByteCount(); + { + CodedInputStream coded_input_stream(input_stream_.get()); + if (!coded_input_stream.ReadVarint32(&message_size_)) + need_another_byte = true; + } + + if (need_another_byte) { + DVLOG(1) << "Expecting another message size byte."; + if (prev_byte_count >= kSizePacketLenMax) { + // Already had enough bytes, something else went wrong. + LOG(ERROR) << "Failed to process message size."; + read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); + return; + } + // Back up by the amount read (should always be 1 byte). + int bytes_read = prev_byte_count - input_stream_->ByteCount(); + DCHECK_EQ(bytes_read, 1); + input_stream_->BackUp(bytes_read); + WaitForData(MCS_FULL_SIZE); + return; + } + + DVLOG(1) << "Proto size: " << message_size_; + + if (message_size_ > 0) + WaitForData(MCS_PROTO_BYTES); + else + OnGotMessageBytes(); +} + +void ConnectionHandler::OnGotMessageBytes() { + read_timeout_timer_.Stop(); + scoped_ptr<google::protobuf::MessageLite> protobuf( + BuildProtobufFromTag(message_tag_)); + // Messages with no content are valid; just use the default protobuf for + // that tag. + if (protobuf.get() && message_size_ == 0) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&ConnectionHandler::GetNextMessage, + weak_ptr_factory_.GetWeakPtr())); + read_callback_.Run(protobuf.Pass()); + return; + } + + if (!protobuf.get() || + input_stream_->GetState() != SocketInputStream::READY) { + LOG(ERROR) << "Failed to extract protobuf bytes of type " + << static_cast<unsigned int>(message_tag_); + protobuf.reset(); // Return a null pointer to denote an error. + read_callback_.Run(protobuf.Pass()); + return; + } + + { + CodedInputStream coded_input_stream(input_stream_.get()); + if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) { + NOTREACHED() << "Unable to parse GCM message of type " + << static_cast<unsigned int>(message_tag_); + protobuf.reset(); // Return a null pointer to denote an error. + read_callback_.Run(protobuf.Pass()); + return; + } + } + + input_stream_->RebuildBuffer(); + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&ConnectionHandler::GetNextMessage, + weak_ptr_factory_.GetWeakPtr())); + if (message_tag_ == kLoginResponseTag) { + if (handshake_complete_) { + LOG(ERROR) << "Unexpected login response."; + } else { + handshake_complete_ = true; + DVLOG(1) << "GCM Handshake complete."; + } + } + read_callback_.Run(protobuf.Pass()); +} + +void ConnectionHandler::OnTimeout() { + LOG(ERROR) << "Timed out waiting for GCM Protocol buffer."; + CloseConnection(); + connection_callback_.Run(net::ERR_TIMED_OUT); +} + +void ConnectionHandler::CloseConnection() { + DVLOG(1) << "Closing connection."; + read_callback_.Reset(); + write_callback_.Reset(); + read_timeout_timer_.Stop(); + socket_->Disconnect(); + input_stream_.reset(); + output_stream_.reset(); + weak_ptr_factory_.InvalidateWeakPtrs(); +} + +} // namespace gcm diff --git a/google_apis/gcm/engine/connection_handler.h b/google_apis/gcm/engine/connection_handler.h new file mode 100644 index 0000000..6dd838c --- /dev/null +++ b/google_apis/gcm/engine/connection_handler.h @@ -0,0 +1,145 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef GOOGLE_APIS_GCM_ENGINE_CONNECTION_HANDLER_H_ +#define GOOGLE_APIS_GCM_ENGINE_CONNECTION_HANDLER_H_ + +#include "base/memory/weak_ptr.h" +#include "base/timer/timer.h" +#include "google_apis/gcm/base/gcm_export.h" +#include "google_apis/gcm/protocol/mcs.pb.h" + +namespace net{ +class StreamSocket; +} + +namespace gcm { + +class SocketInputStream; +class SocketOutputStream; + +// Handles performing the protocol handshake and sending/receiving protobuf +// messages. Note that no retrying or queueing is enforced at this layer. +// Once a connection error is encountered, the ConnectionHandler will disconnect +// the socket and must be reinitialized with a new StreamSocket before +// messages can be sent/received again. +class GCM_EXPORT ConnectionHandler { + public: + typedef base::Callback<void(scoped_ptr<google::protobuf::MessageLite>)> + ProtoReceivedCallback; + typedef base::Closure ProtoSentCallback; + typedef base::Callback<void(int)> ConnectionChangedCallback; + + explicit ConnectionHandler(base::TimeDelta read_timeout); + ~ConnectionHandler(); + + // Starts a new MCS connection handshake (using |login_request|) and, upon + // success, begins listening for incoming/outgoing messages. A successful + // handshake is when a mcs_proto::LoginResponse is received, and is signaled + // via the |read_callback|. + // Outputs: + // |read_callback| will be invoked with the contents of any received protobuf + // message. + // |write_callback| will be invoked anytime a message has been successfully + // sent. Note: this just means the data was sent to the wire, not that the + // other end received it. + // |connection_callback| will be invoked with any fatal read/write errors + // encountered. + // + // Note: It is correct and expected to call Init more than once, as connection + // issues are encountered and new connections must be made. + void Init(scoped_ptr<net::StreamSocket> socket, + const google::protobuf::MessageLite& login_request, + const ProtoReceivedCallback& read_callback, + const ProtoSentCallback& write_callback, + const ConnectionChangedCallback& connection_callback); + + // Checks that a handshake has been completed and a message is not already + // in flight. + bool CanSendMessage() const; + + // Send an MCS protobuf message. CanSendMessage() must be true. + void SendMessage(const google::protobuf::MessageLite& message); + + private: + // State machine for handling incoming data. See WaitForData(..) for usage. + enum ProcessingState { + // Processing the version, tag, and size packets (assuming minimum length + // size packet). Only used during the login handshake. + MCS_VERSION_TAG_AND_SIZE = 0, + // Processing the tag and size packets (assuming minimum length size + // packet). Used for normal messages. + MCS_TAG_AND_SIZE, + // Processing a maximum length size packet (for messages with length > 128). + // Used when a normal size packet was not sufficient to read the message + // size. + MCS_FULL_SIZE, + // Processing the protocol buffer bytes (for those messages with non-zero + // sizes). + MCS_PROTO_BYTES + }; + + // Sends the protocol version and login request. First step in the MCS + // connection handshake. + void Login(const google::protobuf::MessageLite& login_request); + + // SendMessage continuation. Invoked when Socket::Write completes. + void OnMessageSent(); + + // Starts the message processing process, which is comprised of the tag, + // message size, and bytes packet types. + void GetNextMessage(); + + // Performs any necessary SocketInputStream refreshing until the data + // associated with |packet_type| is fully ready, then calls the appropriate + // OnGot* message to process the packet data. If the read times out, + // will close the stream and invoke the connection callback. + void WaitForData(ProcessingState state); + + // Incoming data helper methods. + void OnGotVersion(); + void OnGotMessageTag(); + void OnGotMessageSize(); + void OnGotMessageBytes(); + + // Timeout handler. + void OnTimeout(); + + // Closes the current connection. + void CloseConnection(); + + // Timeout policy: the timeout is only enforced while waiting on the + // handshake (version and/or LoginResponse) or once at least a tag packet has + // been received. It is reset every time new data is received, and is + // only stopped when a full message is processed. + // TODO(zea): consider enforcing a separate timeout when waiting for + // a message to send. + const base::TimeDelta read_timeout_; + base::OneShotTimer<ConnectionHandler> read_timeout_timer_; + + // This connection's socket and the input/output streams attached to it. + scoped_ptr<net::StreamSocket> socket_; + scoped_ptr<SocketInputStream> input_stream_; + scoped_ptr<SocketOutputStream> output_stream_; + + // Whether the MCS login handshake has successfully completed. See Init(..) + // description for more info on what the handshake involves. + bool handshake_complete_; + + // State for the message currently being processed, if there is one. + uint8 message_tag_; + uint32 message_size_; + + ProtoReceivedCallback read_callback_; + ProtoSentCallback write_callback_; + ConnectionChangedCallback connection_callback_; + + base::WeakPtrFactory<ConnectionHandler> weak_ptr_factory_; + + DISALLOW_COPY_AND_ASSIGN(ConnectionHandler); +}; + +} // namespace gcm + +#endif // GOOGLE_APIS_GCM_ENGINE_CONNECTION_HANDLER_H_ diff --git a/google_apis/gcm/engine/connection_handler_unittest.cc b/google_apis/gcm/engine/connection_handler_unittest.cc new file mode 100644 index 0000000..19deaf2 --- /dev/null +++ b/google_apis/gcm/engine/connection_handler_unittest.cc @@ -0,0 +1,626 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "google_apis/gcm/engine/connection_handler.h" + +#include "base/bind.h" +#include "base/memory/scoped_ptr.h" +#include "base/run_loop.h" +#include "base/strings/string_number_conversions.h" +#include "base/test/test_timeouts.h" +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" +#include "google_apis/gcm/base/mcs_util.h" +#include "google_apis/gcm/base/socket_stream.h" +#include "google_apis/gcm/protocol/mcs.pb.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/stream_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace gcm { +namespace { + +typedef scoped_ptr<google::protobuf::MessageLite> ScopedMessage; +typedef std::vector<net::MockRead> ReadList; +typedef std::vector<net::MockWrite> WriteList; + +const uint64 kAuthId = 54321; +const uint64 kAuthToken = 12345; +const char kMCSVersion = 38; // The protocol version. +const int kMCSPort = 5228; // The server port. +const char kDataMsgFrom[] = "data_from"; +const char kDataMsgCategory[] = "data_category"; +const char kDataMsgFrom2[] = "data_from2"; +const char kDataMsgCategory2[] = "data_category2"; +const char kDataMsgFromLong[] = + "this is a long from that will result in a message > 128 bytes"; +const char kDataMsgCategoryLong[] = + "this is a long category that will result in a message > 128 bytes"; +const char kDataMsgFromLong2[] = + "this is a second long from that will result in a message > 128 bytes"; +const char kDataMsgCategoryLong2[] = + "this is a second long category that will result in a message > 128 bytes"; + +// ---- Helpers for building messages. ---- + +// Encode a protobuf packet with protobuf type |tag| and serialized protobuf +// bytes |proto| into the MCS message form (tag + varint size + bytes). +std::string EncodePacket(uint8 tag, const std::string& proto) { + std::string result; + google::protobuf::io::StringOutputStream string_output_stream(&result); + google::protobuf::io::CodedOutputStream coded_output_stream( + &string_output_stream); + const char tag_byte[1] = {tag}; + coded_output_stream.WriteRaw(tag_byte, 1); + coded_output_stream.WriteVarint32(proto.size()); + coded_output_stream.WriteRaw(proto.c_str(), proto.size()); + return result; +} + +// Encode a handshake request into the MCS message form. +std::string EncodeHandshakeRequest() { + std::string result; + const char version_byte[1] = {kMCSVersion}; + result.append(version_byte, 1); + ScopedMessage login_request(BuildLoginRequest(kAuthId, kAuthToken)); + result.append(EncodePacket(kLoginRequestTag, + login_request->SerializeAsString())); + return result; +} + +// Build a serialized login response protobuf. +std::string BuildLoginResponse() { + std::string result; + mcs_proto::LoginResponse login_response; + login_response.set_id("id"); + result.append(login_response.SerializeAsString()); + return result; +} + +// Encoode a handshake response into the MCS message form. +std::string EncodeHandshakeResponse() { + std::string result; + const char version_byte[1] = {kMCSVersion}; + result.append(version_byte, 1); + result.append(EncodePacket(kLoginResponseTag, BuildLoginResponse())); + return result; +} + +// Build a serialized data message stanza protobuf. +std::string BuildDataMessage(const std::string& from, + const std::string& category) { + std::string result; + mcs_proto::DataMessageStanza data_message; + data_message.set_from(from); + data_message.set_category(category); + return data_message.SerializeAsString(); +} + +class GCMConnectionHandlerTest : public testing::Test { + public: + GCMConnectionHandlerTest(); + virtual ~GCMConnectionHandlerTest(); + + net::StreamSocket* BuildSocket(const ReadList& read_list, + const WriteList& write_list); + + // Pump |message_loop_|, resetting |run_loop_| after completion. + void PumpLoop(); + + ConnectionHandler* connection_handler() { return &connection_handler_; } + base::MessageLoop* message_loop() { return &message_loop_; }; + net::DelayedSocketData* data_provider() { return data_provider_.get(); } + int last_error() const { return last_error_; } + + // Initialize the connection handler, setting |dst_proto| as the destination + // for any received messages. + void Connect(ScopedMessage* dst_proto); + + // Runs the message loop until a message is received. + void WaitForMessage(); + + private: + void ReadContinuation(ScopedMessage* dst_proto, ScopedMessage new_proto); + void WriteContinuation(); + void ConnectionContinuation(int error); + + // SocketStreams and their data provider. + ReadList mock_reads_; + WriteList mock_writes_; + scoped_ptr<net::DelayedSocketData> data_provider_; + scoped_ptr<SocketInputStream> socket_input_stream_; + scoped_ptr<SocketOutputStream> socket_output_stream_; + + // The connection handler being tested. + ConnectionHandler connection_handler_; + + // The last connection error received. + int last_error_; + + // net:: components. + scoped_ptr<net::StreamSocket> socket_; + net::MockClientSocketFactory socket_factory_; + net::AddressList address_list_; + + base::MessageLoopForIO message_loop_; + scoped_ptr<base::RunLoop> run_loop_; +}; + +GCMConnectionHandlerTest::GCMConnectionHandlerTest() + : connection_handler_(TestTimeouts::tiny_timeout()), + last_error_(0) { + net::IPAddressNumber ip_number; + net::ParseIPLiteralToNumber("127.0.0.1", &ip_number); + address_list_ = net::AddressList::CreateFromIPAddress(ip_number, kMCSPort); +} + +GCMConnectionHandlerTest::~GCMConnectionHandlerTest() { +} + +net::StreamSocket* GCMConnectionHandlerTest::BuildSocket( + const ReadList& read_list, + const WriteList& write_list) { + mock_reads_ = read_list; + mock_writes_ = write_list; + data_provider_.reset( + new net::DelayedSocketData(0, + &(mock_reads_[0]), mock_reads_.size(), + &(mock_writes_[0]), mock_writes_.size())); + socket_factory_.AddSocketDataProvider(data_provider_.get()); + + socket_ = socket_factory_.CreateTransportClientSocket( + address_list_, NULL, net::NetLog::Source()); + socket_->Connect(net::CompletionCallback()); + + run_loop_.reset(new base::RunLoop()); + PumpLoop(); + + DCHECK(socket_->IsConnected()); + return socket_.get(); +} + +void GCMConnectionHandlerTest::PumpLoop() { + run_loop_->RunUntilIdle(); + run_loop_.reset(new base::RunLoop()); +} + +void GCMConnectionHandlerTest::Connect( + ScopedMessage* dst_proto) { + connection_handler_.Init( + socket_.Pass(), + *BuildLoginRequest(kAuthId, kAuthToken), + base::Bind(&GCMConnectionHandlerTest::ReadContinuation, + base::Unretained(this), + dst_proto), + base::Bind(&GCMConnectionHandlerTest::WriteContinuation, + base::Unretained(this)), + base::Bind(&GCMConnectionHandlerTest::ConnectionContinuation, + base::Unretained(this))); +} + +void GCMConnectionHandlerTest::ReadContinuation( + ScopedMessage* dst_proto, + ScopedMessage new_proto) { + *dst_proto = new_proto.Pass(); + run_loop_->Quit(); +} + +void GCMConnectionHandlerTest::WaitForMessage() { + run_loop_->Run(); + run_loop_.reset(new base::RunLoop()); +} + +void GCMConnectionHandlerTest::WriteContinuation() { + run_loop_->Quit(); +} + +void GCMConnectionHandlerTest::ConnectionContinuation(int error) { + last_error_ = error; + run_loop_->Quit(); +} + +// Initialize the connection handler and ensure the handshake completes +// successfully. +TEST_F(GCMConnectionHandlerTest, Init) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list(1, net::MockWrite(net::ASYNC, + handshake_request.c_str(), + handshake_request.size())); + std::string handshake_response = EncodeHandshakeResponse(); + ReadList read_list(1, net::MockRead(net::ASYNC, + handshake_response.c_str(), + handshake_response.size())); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + EXPECT_FALSE(connection_handler()->CanSendMessage()); + Connect(&received_message); + EXPECT_FALSE(connection_handler()->CanSendMessage()); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. + ASSERT_TRUE(received_message.get()); + EXPECT_EQ(BuildLoginResponse(), received_message->SerializeAsString()); + EXPECT_TRUE(connection_handler()->CanSendMessage()); +} + +// Simulate the handshake response returning an older version. Initialization +// should fail. +TEST_F(GCMConnectionHandlerTest, InitFailedVersionCheck) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list(1, net::MockWrite(net::ASYNC, + handshake_request.c_str(), + handshake_request.size())); + std::string handshake_response = EncodeHandshakeResponse(); + // Overwrite the version byte. + handshake_response[0] = 37; + ReadList read_list(1, net::MockRead(net::ASYNC, + handshake_response.c_str(), + handshake_response.size())); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. Should result in a connection error. + EXPECT_FALSE(received_message.get()); + EXPECT_FALSE(connection_handler()->CanSendMessage()); + EXPECT_EQ(net::ERR_FAILED, last_error()); +} + +// Attempt to initialize, but receive no server response, resulting in a time +// out. +TEST_F(GCMConnectionHandlerTest, InitTimeout) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list(1, net::MockWrite(net::ASYNC, + handshake_request.c_str(), + handshake_request.size())); + ReadList read_list(1, net::MockRead(net::SYNCHRONOUS, + net::ERR_IO_PENDING)); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. Should result in a connection error. + EXPECT_FALSE(received_message.get()); + EXPECT_FALSE(connection_handler()->CanSendMessage()); + EXPECT_EQ(net::ERR_TIMED_OUT, last_error()); +} + +// Attempt to initialize, but receive an incomplete server response, resulting +// in a time out. +TEST_F(GCMConnectionHandlerTest, InitIncompleteTimeout) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list(1, net::MockWrite(net::ASYNC, + handshake_request.c_str(), + handshake_request.size())); + std::string handshake_response = EncodeHandshakeResponse(); + ReadList read_list; + read_list.push_back(net::MockRead(net::ASYNC, + handshake_response.c_str(), + handshake_response.size() / 2)); + read_list.push_back(net::MockRead(net::SYNCHRONOUS, + net::ERR_IO_PENDING)); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. Should result in a connection error. + EXPECT_FALSE(received_message.get()); + EXPECT_FALSE(connection_handler()->CanSendMessage()); + EXPECT_EQ(net::ERR_TIMED_OUT, last_error()); +} + +// Reinitialize the connection handler after failing to initialize. +TEST_F(GCMConnectionHandlerTest, ReInit) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list(1, net::MockWrite(net::ASYNC, + handshake_request.c_str(), + handshake_request.size())); + ReadList read_list(1, net::MockRead(net::SYNCHRONOUS, + net::ERR_IO_PENDING)); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. Should result in a connection error. + EXPECT_FALSE(received_message.get()); + EXPECT_FALSE(connection_handler()->CanSendMessage()); + EXPECT_EQ(net::ERR_TIMED_OUT, last_error()); + + // Build a new socket and reconnect, successfully this time. + std::string handshake_response = EncodeHandshakeResponse(); + read_list[0] = net::MockRead(net::ASYNC, + handshake_response.c_str(), + handshake_response.size()); + BuildSocket(read_list, write_list); + Connect(&received_message); + EXPECT_FALSE(connection_handler()->CanSendMessage()); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. + ASSERT_TRUE(received_message.get()); + EXPECT_EQ(BuildLoginResponse(), received_message->SerializeAsString()); + EXPECT_TRUE(connection_handler()->CanSendMessage()); +} + +// Verify that messages can be received after initialization. +TEST_F(GCMConnectionHandlerTest, RecvMsg) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list(1, net::MockWrite(net::ASYNC, + handshake_request.c_str(), + handshake_request.size())); + std::string handshake_response = EncodeHandshakeResponse(); + + std::string data_message_proto = BuildDataMessage(kDataMsgFrom, + kDataMsgCategory); + std::string data_message_pkt = + EncodePacket(kDataMessageStanzaTag, data_message_proto); + ReadList read_list; + read_list.push_back(net::MockRead(net::ASYNC, + handshake_response.c_str(), + handshake_response.size())); + read_list.push_back(net::MockRead(net::ASYNC, + data_message_pkt.c_str(), + data_message_pkt.size())); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. + WaitForMessage(); // The data message. + ASSERT_TRUE(received_message.get()); + EXPECT_EQ(data_message_proto, received_message->SerializeAsString()); +} + +// Verify that if two messages arrive at once, they're treated appropriately. +TEST_F(GCMConnectionHandlerTest, Recv2Msgs) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list(1, net::MockWrite(net::ASYNC, + handshake_request.c_str(), + handshake_request.size())); + std::string handshake_response = EncodeHandshakeResponse(); + + std::string data_message_proto = BuildDataMessage(kDataMsgFrom, + kDataMsgCategory); + std::string data_message_proto2 = BuildDataMessage(kDataMsgFrom2, + kDataMsgCategory2); + std::string data_message_pkt = + EncodePacket(kDataMessageStanzaTag, data_message_proto); + data_message_pkt += EncodePacket(kDataMessageStanzaTag, data_message_proto2); + ReadList read_list; + read_list.push_back(net::MockRead(net::ASYNC, + handshake_response.c_str(), + handshake_response.size())); + read_list.push_back(net::MockRead(net::SYNCHRONOUS, + data_message_pkt.c_str(), + data_message_pkt.size())); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. + WaitForMessage(); // The first data message. + ASSERT_TRUE(received_message.get()); + EXPECT_EQ(data_message_proto, received_message->SerializeAsString()); + received_message.reset(); + WaitForMessage(); // The second data message. + ASSERT_TRUE(received_message.get()); + EXPECT_EQ(data_message_proto2, received_message->SerializeAsString()); +} + +// Receive a long (>128 bytes) message. +TEST_F(GCMConnectionHandlerTest, RecvLongMsg) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list(1, net::MockWrite(net::ASYNC, + handshake_request.c_str(), + handshake_request.size())); + std::string handshake_response = EncodeHandshakeResponse(); + + std::string data_message_proto = + BuildDataMessage(kDataMsgFromLong, kDataMsgCategoryLong); + std::string data_message_pkt = + EncodePacket(kDataMessageStanzaTag, data_message_proto); + DCHECK_GT(data_message_pkt.size(), 128U); + ReadList read_list; + read_list.push_back(net::MockRead(net::ASYNC, + handshake_response.c_str(), + handshake_response.size())); + read_list.push_back(net::MockRead(net::ASYNC, + data_message_pkt.c_str(), + data_message_pkt.size())); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. + WaitForMessage(); // The data message. + ASSERT_TRUE(received_message.get()); + EXPECT_EQ(data_message_proto, received_message->SerializeAsString()); +} + +// Receive two long (>128 bytes) message. +TEST_F(GCMConnectionHandlerTest, Recv2LongMsgs) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list(1, net::MockWrite(net::ASYNC, + handshake_request.c_str(), + handshake_request.size())); + std::string handshake_response = EncodeHandshakeResponse(); + + std::string data_message_proto = + BuildDataMessage(kDataMsgFromLong, kDataMsgCategoryLong); + std::string data_message_proto2 = + BuildDataMessage(kDataMsgFromLong2, kDataMsgCategoryLong2); + std::string data_message_pkt = + EncodePacket(kDataMessageStanzaTag, data_message_proto); + data_message_pkt += EncodePacket(kDataMessageStanzaTag, data_message_proto2); + DCHECK_GT(data_message_pkt.size(), 256U); + ReadList read_list; + read_list.push_back(net::MockRead(net::ASYNC, + handshake_response.c_str(), + handshake_response.size())); + read_list.push_back(net::MockRead(net::SYNCHRONOUS, + data_message_pkt.c_str(), + data_message_pkt.size())); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. + WaitForMessage(); // The first data message. + ASSERT_TRUE(received_message.get()); + EXPECT_EQ(data_message_proto, received_message->SerializeAsString()); + received_message.reset(); + WaitForMessage(); // The second data message. + ASSERT_TRUE(received_message.get()); + EXPECT_EQ(data_message_proto2, received_message->SerializeAsString()); +} + +// Simulate a message where the end of the data does not arrive in time and the +// read times out. +TEST_F(GCMConnectionHandlerTest, ReadTimeout) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list(1, net::MockWrite(net::ASYNC, + handshake_request.c_str(), + handshake_request.size())); + std::string handshake_response = EncodeHandshakeResponse(); + + std::string data_message_proto = BuildDataMessage(kDataMsgFrom, + kDataMsgCategory); + std::string data_message_pkt = + EncodePacket(kDataMessageStanzaTag, data_message_proto); + int bytes_in_first_message = data_message_pkt.size() / 2; + ReadList read_list; + read_list.push_back(net::MockRead(net::ASYNC, + handshake_response.c_str(), + handshake_response.size())); + read_list.push_back(net::MockRead(net::ASYNC, + data_message_pkt.c_str(), + bytes_in_first_message)); + read_list.push_back(net::MockRead(net::SYNCHRONOUS, + net::ERR_IO_PENDING)); + read_list.push_back(net::MockRead(net::ASYNC, + data_message_pkt.c_str() + + bytes_in_first_message, + data_message_pkt.size() - + bytes_in_first_message)); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. + received_message.reset(); + WaitForMessage(); // Should time out. + EXPECT_FALSE(received_message.get()); + EXPECT_EQ(net::ERR_TIMED_OUT, last_error()); + EXPECT_FALSE(connection_handler()->CanSendMessage()); + + // Finish the socket read. Should have no effect. + data_provider()->ForceNextRead(); +} + +// Receive a message with zero data bytes. +TEST_F(GCMConnectionHandlerTest, RecvMsgNoData) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list(1, net::MockWrite(net::ASYNC, + handshake_request.c_str(), + handshake_request.size())); + std::string handshake_response = EncodeHandshakeResponse(); + + std::string data_message_pkt = EncodePacket(kHeartbeatPingTag, ""); + ASSERT_EQ(data_message_pkt.size(), 2U); + ReadList read_list; + read_list.push_back(net::MockRead(net::ASYNC, + handshake_response.c_str(), + handshake_response.size())); + read_list.push_back(net::MockRead(net::ASYNC, + data_message_pkt.c_str(), + data_message_pkt.size())); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. + received_message.reset(); + WaitForMessage(); // The heartbeat ping. + EXPECT_TRUE(received_message.get()); + EXPECT_EQ(GetMCSProtoTag(*received_message), kHeartbeatPingTag); + EXPECT_EQ(net::OK, last_error()); + EXPECT_TRUE(connection_handler()->CanSendMessage()); +} + +// Send a message after performing the handshake. +TEST_F(GCMConnectionHandlerTest, SendMsg) { + mcs_proto::DataMessageStanza data_message; + data_message.set_from(kDataMsgFrom); + data_message.set_category(kDataMsgCategory); + std::string handshake_request = EncodeHandshakeRequest(); + std::string data_message_pkt = + EncodePacket(kDataMessageStanzaTag, data_message.SerializeAsString()); + WriteList write_list; + write_list.push_back(net::MockWrite(net::ASYNC, + handshake_request.c_str(), + handshake_request.size())); + write_list.push_back(net::MockWrite(net::ASYNC, + data_message_pkt.c_str(), + data_message_pkt.size())); + std::string handshake_response = EncodeHandshakeResponse(); + ReadList read_list; + read_list.push_back(net::MockRead(net::ASYNC, + handshake_response.c_str(), + handshake_response.size())); + read_list.push_back(net::MockRead(net::SYNCHRONOUS, net::ERR_IO_PENDING)); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. + EXPECT_TRUE(connection_handler()->CanSendMessage()); + connection_handler()->SendMessage(data_message); + EXPECT_FALSE(connection_handler()->CanSendMessage()); + WaitForMessage(); // The message send. + EXPECT_TRUE(connection_handler()->CanSendMessage()); +} + +// Attempt to send a message after the socket is disconnected due to a timeout. +TEST_F(GCMConnectionHandlerTest, SendMsgSocketDisconnected) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list; + write_list.push_back(net::MockWrite(net::ASYNC, + handshake_request.c_str(), + handshake_request.size())); + std::string handshake_response = EncodeHandshakeResponse(); + ReadList read_list; + read_list.push_back(net::MockRead(net::ASYNC, + handshake_response.c_str(), + handshake_response.size())); + read_list.push_back(net::MockRead(net::SYNCHRONOUS, net::ERR_IO_PENDING)); + net::StreamSocket* socket = BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. + EXPECT_TRUE(connection_handler()->CanSendMessage()); + socket->Disconnect(); + mcs_proto::DataMessageStanza data_message; + data_message.set_from(kDataMsgFrom); + data_message.set_category(kDataMsgCategory); + connection_handler()->SendMessage(data_message); + EXPECT_FALSE(connection_handler()->CanSendMessage()); + WaitForMessage(); // The message send. Should result in an error + EXPECT_FALSE(connection_handler()->CanSendMessage()); + EXPECT_EQ(net::ERR_CONNECTION_CLOSED, last_error()); +} + +} // namespace +} // namespace gcm diff --git a/google_apis/gcm/gcm.gyp b/google_apis/gcm/gcm.gyp index c1ab33b..f04fbee 100644 --- a/google_apis/gcm/gcm.gyp +++ b/google_apis/gcm/gcm.gyp @@ -12,7 +12,13 @@ { 'target_name': 'gcm', 'type': '<(component)', - 'variables': { 'enable_wexit_time_destructors': 1, }, + 'variables': { + 'enable_wexit_time_destructors': 1, + 'proto_in_dir': './protocol', + 'proto_out_dir': 'google_apis/gcm/protocol', + 'cc_generator_options': 'dllexport_decl=GCM_EXPORT:', + 'cc_include': 'google_apis/gcm/base/gcm_export.h', + }, 'include_dirs': [ '../..', ], @@ -28,8 +34,16 @@ '../../third_party/protobuf/protobuf.gyp:protobuf_lite' ], 'sources': [ + 'base/mcs_util.h', + 'base/mcs_util.cc', 'base/socket_stream.h', 'base/socket_stream.cc', + 'engine/connection_handler.h', + 'engine/connection_handler.cc', + 'protocol/mcs.proto', + ], + 'includes': [ + '../../build/protoc.gypi' ], }, @@ -46,10 +60,13 @@ '../../base/base.gyp:base', '../../net/net.gyp:net_test_support', '../../testing/gtest.gyp:gtest', + '../../third_party/protobuf/protobuf.gyp:protobuf_lite', 'gcm' ], 'sources': [ + 'base/mcs_util_unittest.cc', 'base/socket_stream_unittest.cc', + 'engine/connection_handler_unittest.cc', ] }, ], diff --git a/google_apis/gcm/protocol/mcs.proto b/google_apis/gcm/protocol/mcs.proto new file mode 100644 index 0000000..2926d10 --- /dev/null +++ b/google_apis/gcm/protocol/mcs.proto @@ -0,0 +1,269 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// MCS protocol for communication between Chrome client and Mobile Connection +// Server . + +syntax = "proto2"; + +option optimize_for = LITE_RUNTIME; +option retain_unknown_fields = true; + +package mcs_proto; + +/* + Common fields/comments: + + stream_id: no longer sent by server, each side keeps a counter + last_stream_id_received: sent only if a packet was received since last time + a last_stream was sent + status: new bitmask including the 'idle' as bit 0. + + */ + +/** + TAG: 0 + */ +message HeartbeatPing { + optional int32 stream_id = 1; + optional int32 last_stream_id_received = 2; + optional int64 status = 3; +} + +/** + TAG: 1 + */ +message HeartbeatAck { + optional int32 stream_id = 1; + optional int32 last_stream_id_received = 2; + optional int64 status = 3; +} + +message ErrorInfo { + required int32 code = 1; + optional string message = 2; + optional string type = 3; + optional Extension extension = 4; +} + +// MobileSettings class. +// "u:f", "u:b", "u:s" - multi user devices reporting foreground, background +// and stopped users. +// hbping: heatbeat ping interval +// rmq2v: include explicit stream IDs + +message Setting { + required string name = 1; + required string value = 2; +} + +message HeartbeatStat { + required string ip = 1; + required bool timeout = 2; + required int32 interval_ms = 3; +} + +message HeartbeatConfig { + optional bool upload_stat = 1; + optional string ip = 2; + optional int32 interval_ms = 3; +} + +/** + TAG: 2 + */ +message LoginRequest { + enum AuthService { + ANDROID_ID = 2; + } + required string id = 1; // Must be present ( proto required ), may be empty + // string. + // mcs.android.com. + required string domain = 2; + // Decimal android ID + required string user = 3; + + required string resource = 4; + + // Secret + required string auth_token = 5; + + // Format is: android-HEX_DEVICE_ID + // The user is the decimal value. + optional string device_id = 6; + + // RMQ1 - no longer used + optional int64 last_rmq_id = 7; + + repeated Setting setting = 8; + //optional int32 compress = 9; + repeated string received_persistent_id = 10; + + // Replaced by "rmq2v" setting + // optional bool include_stream_ids = 11; + + optional bool adaptive_heartbeat = 12; + optional HeartbeatStat heartbeat_stat = 13; + // Must be true. + optional bool use_rmq2 = 14; + optional int64 account_id = 15; + + // ANDROID_ID = 2 + optional AuthService auth_service = 16; + + optional int32 network_type = 17; + optional int64 status = 18; +} + +/** + * TAG: 3 + */ +message LoginResponse { + required string id = 1; + // Not used. + optional string jid = 2; + // Null if login was ok. + optional ErrorInfo error = 3; + repeated Setting setting = 4; + optional int32 stream_id = 5; + // Should be "1" + optional int32 last_stream_id_received = 6; + optional HeartbeatConfig heartbeat_config = 7; + // used by the client to synchronize with the server timestamp. + optional int64 server_timestamp = 8; +} + +message StreamErrorStanza { + required string type = 1; + optional string text = 2; +} + +/** + * TAG: 4 + */ +message Close { +} + +message Extension { + // 12: SelectiveAck + // 13: StreamAck + required int32 id = 1; + required bytes data = 2; +} + +/** + * TAG: 7 + * IqRequest must contain a single extension. IqResponse may contain 0 or 1 + * extensions. + */ +message IqStanza { + enum IqType { + GET = 0; + SET = 1; + RESULT = 2; + IQ_ERROR = 3; + } + + optional int64 rmq_id = 1; + required IqType type = 2; + required string id = 3; + optional string from = 4; + optional string to = 5; + optional ErrorInfo error = 6; + + // Only field used in the 38+ protocol (besides common last_stream_id_received, status, rmq_id) + optional Extension extension = 7; + + optional string persistent_id = 8; + optional int32 stream_id = 9; + optional int32 last_stream_id_received = 10; + optional int64 account_id = 11; + optional int64 status = 12; +} + +message AppData { + required string key = 1; + required string value = 2; +} + +/** + * TAG: 8 + */ +message DataMessageStanza { + // Not used. + // optional int64 rmq_id = 1; + + // This is the message ID, set by client, DMP.9 (message_id) + optional string id = 2; + + // Project ID of the sender, DMP.1 + required string from = 3; + + // Part of DMRequest - also the key in DataMessageProto. + optional string to = 4; + + // Package name. DMP.2 + required string category = 5; + + // The collapsed key, DMP.3 + optional string token = 6; + + // User data + GOOGLE. prefixed special entries, DMP.4 + repeated AppData app_data = 7; + + // Not used. + optional bool from_trusted_server = 8; + + // Part of the ACK protocol, returned in DataMessageResponse on server side. + // It's part of the key of DMP. + optional string persistent_id = 9; + + // In-stream ack. Increments on each message sent - a bit redundant + // Not used in DMP/DMR. + optional int32 stream_id = 10; + optional int32 last_stream_id_received = 11; + + // Not used. + // optional string permission = 12; + + // Sent by the device shortly after registration. + optional string reg_id = 13; + + // Not used. + // optional string pkg_signature = 14; + // Not used. + // optional string client_id = 15; + + // serial number of the target user, DMP.8 + // It is the 'serial number' according to user manager. + optional int64 device_user_id = 16; + + // Time to live, in seconds. + optional int32 ttl = 17; + // Timestamp ( according to client ) when message was sent by app, in seconds + optional int64 sent = 18; + + // How long has the message been queued before the flush, in seconds. + // This is needed to account for the time difference between server and + // client: server should adjust 'sent' based on his 'receive' time. + optional int32 queued = 19; + + optional int64 status = 20; +} + +/** + Included in IQ with ID 13, sent from client or server after 10 unconfirmed + messages. + */ +message StreamAck { + // No last_streamid_received required. This is included within an IqStanza, + // which includes the last_stream_id_received. +} + +/** + Included in IQ sent after LoginResponse from server with ID 12. +*/ +message SelectiveAck { + repeated string id = 1; +} |