diff options
Diffstat (limited to 'google_apis')
-rw-r--r-- | google_apis/gcm/engine/connection_handler_impl.cc | 40 | ||||
-rw-r--r-- | google_apis/gcm/engine/connection_handler_impl_unittest.cc | 94 |
2 files changed, 118 insertions, 16 deletions
diff --git a/google_apis/gcm/engine/connection_handler_impl.cc b/google_apis/gcm/engine/connection_handler_impl.cc index 4d3fc82..95c0286 100644 --- a/google_apis/gcm/engine/connection_handler_impl.cc +++ b/google_apis/gcm/engine/connection_handler_impl.cc @@ -339,6 +339,8 @@ void ConnectionHandlerImpl::OnGotMessageSize() { } int prev_byte_count = input_stream_->UnreadByteCount(); + int result = net::OK; + bool incomplete_size_packet = false; { CodedInputStream coded_input_stream(input_stream_.get()); if (!coded_input_stream.ReadVarint32(&message_size_)) { @@ -346,18 +348,25 @@ void ConnectionHandlerImpl::OnGotMessageSize() { if (prev_byte_count >= kSizePacketLenMax) { // Already had enough bytes, something else went wrong. LOG(ERROR) << "Failed to process message size"; - connection_callback_.Run(net::ERR_FILE_TOO_BIG); - return; + result = net::ERR_FILE_TOO_BIG; + } else { + // Back up by the amount read. + int bytes_read = prev_byte_count - input_stream_->UnreadByteCount(); + input_stream_->BackUp(bytes_read); + size_packet_so_far_ = bytes_read; + incomplete_size_packet = true; } - // Back up by the amount read. - int bytes_read = prev_byte_count - input_stream_->UnreadByteCount(); - input_stream_->BackUp(bytes_read); - size_packet_so_far_ = bytes_read; - WaitForData(MCS_SIZE); - return; } } + if (result != net::OK) { + connection_callback_.Run(result); + return; + } else if (incomplete_size_packet) { + WaitForData(MCS_SIZE); + return; + } + DVLOG(1) << "Proto size: " << message_size_; size_packet_so_far_ = 0; payload_input_buffer_.clear(); @@ -398,14 +407,13 @@ void ConnectionHandlerImpl::OnGotMessageBytes() { return; } + int result = net::OK; if (message_size_ < kDefaultDataPacketLimit) { CodedInputStream coded_input_stream(input_stream_.get()); if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) { LOG(ERROR) << "Unable to parse GCM message of type " << static_cast<unsigned int>(message_tag_); - // Reset the connection. - connection_callback_.Run(net::ERR_FAILED); - return; + result = net::ERR_FAILED; } } else { // Copy any data in the input stream onto the end of the buffer. @@ -424,9 +432,7 @@ void ConnectionHandlerImpl::OnGotMessageBytes() { if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) { LOG(ERROR) << "Unable to parse GCM message of type " << static_cast<unsigned int>(message_tag_); - // Reset the connection. - connection_callback_.Run(net::ERR_FAILED); - return; + result = net::ERR_FAILED; } } else { // Continue reading data. @@ -444,6 +450,12 @@ void ConnectionHandlerImpl::OnGotMessageBytes() { } } + if (result != net::OK) { + // Reset the connection. + connection_callback_.Run(result); + return; + } + input_stream_->RebuildBuffer(); base::ThreadTaskRunnerHandle::Get()->PostTask( FROM_HERE, diff --git a/google_apis/gcm/engine/connection_handler_impl_unittest.cc b/google_apis/gcm/engine/connection_handler_impl_unittest.cc index f15c841..ab5fdf2 100644 --- a/google_apis/gcm/engine/connection_handler_impl_unittest.cc +++ b/google_apis/gcm/engine/connection_handler_impl_unittest.cc @@ -4,6 +4,8 @@ #include "google_apis/gcm/engine/connection_handler_impl.h" +#include <string> + #include "base/bind.h" #include "base/memory/scoped_ptr.h" #include "base/run_loop.h" @@ -11,6 +13,7 @@ #include "base/test/test_timeouts.h" #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/zero_copy_stream_impl_lite.h" +#include "google/protobuf/wire_format_lite.h" #include "google_apis/gcm/base/mcs_util.h" #include "google_apis/gcm/base/socket_stream.h" #include "google_apis/gcm/protocol/mcs.pb.h" @@ -96,13 +99,40 @@ std::string EncodeHandshakeResponse() { // Build a serialized data message stanza protobuf. std::string BuildDataMessage(const std::string& from, const std::string& category) { - std::string result; mcs_proto::DataMessageStanza data_message; data_message.set_from(from); data_message.set_category(category); return data_message.SerializeAsString(); } +// Build a corrupt data message that will force the protobuf parser to backup +// after completion (useful in testing memory corruption cases due to a +// CodedInputStream going out of scope). +std::string BuildCorruptDataMessage() { + // Manually construct the message with invalid data. We set field 2 (id) to + // be an invalid string. + const int kMsgTag = + (2 << google::protobuf::internal::WireFormatLite::kTagTypeBits) | + google::protobuf::internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED; + const int kStringLength = -1; // Corrupted length. + const char kStringData[] = "id"; + std::string data_message_proto; + google::protobuf::io::StringOutputStream string_output_stream( + &data_message_proto); + { + google::protobuf::io::CodedOutputStream coded_output_stream( + &string_output_stream); + coded_output_stream.WriteVarint32(kMsgTag); + coded_output_stream.WriteVarint32( + static_cast<google::protobuf::uint32>(kStringLength)); + coded_output_stream.WriteRaw(&kStringData, sizeof(kStringData)); + // ~CodedOutputStream must run before the move constructor at the + // return statement. http://crbug.com/338962 + } + + return data_message_proto; +} + class GCMConnectionHandlerImplTest : public testing::Test { public: GCMConnectionHandlerImplTest(); @@ -117,7 +147,7 @@ class GCMConnectionHandlerImplTest : public testing::Test { ConnectionHandlerImpl* connection_handler() { return connection_handler_.get(); } - base::MessageLoop* message_loop() { return &message_loop_; }; + base::MessageLoop* message_loop() { return &message_loop_; } net::StaticSocketDataProvider* data_provider() { return data_provider_.get(); } @@ -227,6 +257,8 @@ void GCMConnectionHandlerImplTest::WriteContinuation() { void GCMConnectionHandlerImplTest::ConnectionContinuation(int error) { last_error_ = error; + if (error != net::OK) + connection_handler_->Reset(); run_loop_->Quit(); } @@ -815,5 +847,63 @@ TEST_F(GCMConnectionHandlerImplTest, RecvMsgSplitSize) { EXPECT_EQ(net::OK, last_error()); } +// Make sure a message with invalid data is handled gracefully and resets +// the connection with a FAILED error. +TEST_F(GCMConnectionHandlerImplTest, InvalidData) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list(1, net::MockWrite(net::ASYNC, handshake_request.c_str(), + handshake_request.size())); + std::string handshake_response = EncodeHandshakeResponse(); + std::string data_message_proto = BuildCorruptDataMessage(); + std::string invalid_message_pkt = + EncodePacket(kDataMessageStanzaTag, data_message_proto); + + ReadList read_list; + read_list.push_back(net::MockRead(net::ASYNC, handshake_response.c_str(), + handshake_response.size())); + read_list.push_back(net::MockRead(net::ASYNC, invalid_message_pkt.c_str(), + invalid_message_pkt.size())); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. + received_message.reset(); + WaitForMessage(); // The invalid message. + EXPECT_FALSE(received_message.get()); + EXPECT_EQ(net::ERR_FAILED, last_error()); +} + +// Make sure a long message with invalid data is handled gracefully and resets +// the connection with a FAILED error. +TEST_F(GCMConnectionHandlerImplTest, InvalidDataLong) { + std::string handshake_request = EncodeHandshakeRequest(); + WriteList write_list(1, net::MockWrite(net::ASYNC, handshake_request.c_str(), + handshake_request.size())); + std::string handshake_response = EncodeHandshakeResponse(); + std::string data_message_proto = BuildCorruptDataMessage(); + // Pad the corrupt data so it's beyond the normal single packet length. + data_message_proto.resize(1 << 12); + std::string invalid_message_pkt = + EncodePacket(kDataMessageStanzaTag, data_message_proto); + + ReadList read_list; + read_list.push_back(net::MockRead(net::ASYNC, handshake_response.c_str(), + handshake_response.size())); + read_list.push_back(net::MockRead(net::ASYNC, invalid_message_pkt.c_str(), + invalid_message_pkt.size())); + BuildSocket(read_list, write_list); + + ScopedMessage received_message; + Connect(&received_message); + WaitForMessage(); // The login send. + WaitForMessage(); // The login response. + received_message.reset(); + WaitForMessage(); // The invalid message. + EXPECT_FALSE(received_message.get()); + EXPECT_EQ(net::ERR_FAILED, last_error()); +} + } // namespace } // namespace gcm |