diff options
Diffstat (limited to 'remoting/protocol')
-rw-r--r-- | remoting/protocol/client_message_dispatcher.cc | 15 | ||||
-rw-r--r-- | remoting/protocol/client_message_dispatcher.h | 2 | ||||
-rw-r--r-- | remoting/protocol/fake_session.cc | 4 | ||||
-rw-r--r-- | remoting/protocol/fake_session.h | 15 | ||||
-rw-r--r-- | remoting/protocol/host_message_dispatcher.cc | 39 | ||||
-rw-r--r-- | remoting/protocol/host_message_dispatcher.h | 8 | ||||
-rw-r--r-- | remoting/protocol/input_sender.cc | 11 | ||||
-rw-r--r-- | remoting/protocol/message_decoder.cc | 7 | ||||
-rw-r--r-- | remoting/protocol/message_decoder.h | 9 | ||||
-rw-r--r-- | remoting/protocol/message_decoder_unittest.cc | 38 | ||||
-rw-r--r-- | remoting/protocol/message_reader.cc | 57 | ||||
-rw-r--r-- | remoting/protocol/message_reader.h | 58 | ||||
-rw-r--r-- | remoting/protocol/message_reader_unittest.cc | 242 | ||||
-rw-r--r-- | remoting/protocol/protobuf_video_reader.cc | 4 | ||||
-rw-r--r-- | remoting/protocol/protobuf_video_reader.h | 2 | ||||
-rw-r--r-- | remoting/protocol/ref_counted_message.h | 45 |
16 files changed, 418 insertions, 138 deletions
diff --git a/remoting/protocol/client_message_dispatcher.cc b/remoting/protocol/client_message_dispatcher.cc index 9568685..e7b6dd6 100644 --- a/remoting/protocol/client_message_dispatcher.cc +++ b/remoting/protocol/client_message_dispatcher.cc @@ -11,7 +11,6 @@ #include "remoting/protocol/client_stub.h" #include "remoting/protocol/input_stub.h" #include "remoting/protocol/message_reader.h" -#include "remoting/protocol/ref_counted_message.h" #include "remoting/protocol/session.h" namespace remoting { @@ -39,18 +38,18 @@ void ClientMessageDispatcher::Initialize( } void ClientMessageDispatcher::OnControlMessageReceived( - ControlMessage* message) { - scoped_refptr<RefCountedMessage<ControlMessage> > ref_msg = - new RefCountedMessage<ControlMessage>(message); + ControlMessage* message, Task* done_task) { + // TODO(sergeyu): Add message validation. if (message->has_notify_resolution()) { client_stub_->NotifyResolution( - &message->notify_resolution(), NewDeleteTask(ref_msg)); + &message->notify_resolution(), done_task); } else if (message->has_begin_session_response()) { client_stub_->BeginSessionResponse( - &message->begin_session_response().login_status(), - NewDeleteTask(ref_msg)); + &message->begin_session_response().login_status(), done_task); } else { - NOTREACHED() << "Invalid control message received"; + LOG(WARNING) << "Invalid control message received."; + done_task->Run(); + delete done_task; } } diff --git a/remoting/protocol/client_message_dispatcher.h b/remoting/protocol/client_message_dispatcher.h index 8f0f5a6..88c7c14 100644 --- a/remoting/protocol/client_message_dispatcher.h +++ b/remoting/protocol/client_message_dispatcher.h @@ -40,7 +40,7 @@ class ClientMessageDispatcher { void Initialize(protocol::Session* session, ClientStub* client_stub); private: - void OnControlMessageReceived(ControlMessage* message); + void OnControlMessageReceived(ControlMessage* message, Task* done_task); // MessageReader that runs on the control channel. It runs a loop // that parses data on the channel and then calls the corresponding handler diff --git a/remoting/protocol/fake_session.cc b/remoting/protocol/fake_session.cc index ac344e0..fe50f2b 100644 --- a/remoting/protocol/fake_session.cc +++ b/remoting/protocol/fake_session.cc @@ -21,7 +21,7 @@ FakeSocket::FakeSocket() FakeSocket::~FakeSocket() { } -void FakeSocket::AppendInputData(char* data, int data_size) { +void FakeSocket::AppendInputData(const char* data, int data_size) { input_data_.insert(input_data_.end(), data, data + data_size); // Complete pending read if any. if (read_pending_) { @@ -78,7 +78,7 @@ FakeUdpSocket::FakeUdpSocket() FakeUdpSocket::~FakeUdpSocket() { } -void FakeUdpSocket::AppendInputPacket(char* data, int data_size) { +void FakeUdpSocket::AppendInputPacket(const char* data, int data_size) { input_packets_.push_back(std::string()); input_packets_.back().assign(data, data + data_size); diff --git a/remoting/protocol/fake_session.h b/remoting/protocol/fake_session.h index 62818e4..3ffe20b 100644 --- a/remoting/protocol/fake_session.h +++ b/remoting/protocol/fake_session.h @@ -27,10 +27,11 @@ class FakeSocket : public net::Socket { FakeSocket(); virtual ~FakeSocket(); - const std::string& written_data() { return written_data_; } + const std::string& written_data() const { return written_data_; } - void AppendInputData(char* data, int data_size); - int input_pos() { return input_pos_; } + void AppendInputData(const char* data, int data_size); + int input_pos() const { return input_pos_; } + bool read_pending() const { return read_pending_; } // net::Socket interface. virtual int Read(net::IOBuffer* buf, int buf_len, @@ -60,12 +61,12 @@ class FakeUdpSocket : public net::Socket { FakeUdpSocket(); virtual ~FakeUdpSocket(); - const std::vector<std::string>& written_packets() { + const std::vector<std::string>& written_packets() const { return written_packets_; } - void AppendInputPacket(char* data, int data_size); - int input_pos() { return input_pos_; } + void AppendInputPacket(const char* data, int data_size); + int input_pos() const { return input_pos_; } // net::Socket interface. virtual int Read(net::IOBuffer* buf, int buf_len, @@ -100,7 +101,7 @@ class FakeSession : public Session { message_loop_ = message_loop; } - bool is_closed() { return closed_; } + bool is_closed() const { return closed_; } virtual void SetStateChangeCallback(StateChangeCallback* callback); diff --git a/remoting/protocol/host_message_dispatcher.cc b/remoting/protocol/host_message_dispatcher.cc index 2554c9b..1e1eea8 100644 --- a/remoting/protocol/host_message_dispatcher.cc +++ b/remoting/protocol/host_message_dispatcher.cc @@ -11,7 +11,6 @@ #include "remoting/protocol/host_stub.h" #include "remoting/protocol/input_stub.h" #include "remoting/protocol/message_reader.h" -#include "remoting/protocol/ref_counted_message.h" #include "remoting/protocol/session.h" namespace remoting { @@ -47,34 +46,32 @@ void HostMessageDispatcher::Initialize( NewCallback(this, &HostMessageDispatcher::OnControlMessageReceived)); } -void HostMessageDispatcher::OnControlMessageReceived(ControlMessage* message) { - scoped_refptr<RefCountedMessage<ControlMessage> > ref_msg = - new RefCountedMessage<ControlMessage>(message); +void HostMessageDispatcher::OnControlMessageReceived( + ControlMessage* message, Task* done_task) { + // TODO(sergeyu): Add message validation. if (message->has_suggest_resolution()) { - host_stub_->SuggestResolution( - &message->suggest_resolution(), NewDeleteTask(ref_msg)); + host_stub_->SuggestResolution(&message->suggest_resolution(), done_task); } else if (message->has_begin_session_request()) { host_stub_->BeginSessionRequest( - &message->begin_session_request().credentials(), - NewDeleteTask(ref_msg)); + &message->begin_session_request().credentials(), done_task); } else { - NOTREACHED() << "Invalid control message received"; + LOG(WARNING) << "Invalid control message received."; + done_task->Run(); + delete done_task; } } void HostMessageDispatcher::OnEventMessageReceived( - EventMessage* message) { - scoped_refptr<RefCountedMessage<EventMessage> > ref_msg = - new RefCountedMessage<EventMessage>(message); - for (int i = 0; i < message->event_size(); ++i) { - if (message->event(i).has_key()) { - input_stub_->InjectKeyEvent( - &message->event(i).key(), NewDeleteTask(ref_msg)); - } - if (message->event(i).has_mouse()) { - input_stub_->InjectMouseEvent( - &message->event(i).mouse(), NewDeleteTask(ref_msg)); - } + EventMessage* message, Task* done_task) { + // TODO(sergeyu): Add message validation. + if (message->has_key_event()) { + input_stub_->InjectKeyEvent(&message->key_event(), done_task); + } else if (message->has_mouse_event()) { + input_stub_->InjectMouseEvent(&message->mouse_event(), done_task); + } else { + LOG(WARNING) << "Invalid event message received."; + done_task->Run(); + delete done_task; } } diff --git a/remoting/protocol/host_message_dispatcher.h b/remoting/protocol/host_message_dispatcher.h index 8ebed01..60afb27 100644 --- a/remoting/protocol/host_message_dispatcher.h +++ b/remoting/protocol/host_message_dispatcher.h @@ -11,12 +11,10 @@ #include "remoting/protocol/message_reader.h" namespace remoting { - -class EventMessage; - namespace protocol { class ControlMessage; +class EventMessage; class HostStub; class InputStub; class Session; @@ -45,11 +43,11 @@ class HostMessageDispatcher { private: // This method is called by |control_channel_reader_| when a control // message is received. - void OnControlMessageReceived(ControlMessage* message); + void OnControlMessageReceived(ControlMessage* message, Task* done_task); // This method is called by |event_channel_reader_| when a event // message is received. - void OnEventMessageReceived(EventMessage* message); + void OnEventMessageReceived(EventMessage* message, Task* done_task); // MessageReader that runs on the control channel. It runs a loop // that parses data on the channel and then delegates the message to this diff --git a/remoting/protocol/input_sender.cc b/remoting/protocol/input_sender.cc index 8831649..8ba90ca 100644 --- a/remoting/protocol/input_sender.cc +++ b/remoting/protocol/input_sender.cc @@ -9,6 +9,7 @@ #include "base/task.h" #include "remoting/proto/event.pb.h" +#include "remoting/proto/internal.pb.h" #include "remoting/protocol/buffered_socket_writer.h" #include "remoting/protocol/util.h" @@ -27,19 +28,17 @@ InputSender::~InputSender() { void InputSender::InjectKeyEvent(const KeyEvent* event, Task* done) { EventMessage message; - Event* evt = message.add_event(); // TODO(hclam): Provide timestamp. - evt->set_timestamp(0); - evt->mutable_key()->CopyFrom(*event); + message.set_timestamp(0); + message.mutable_key_event()->CopyFrom(*event); buffered_writer_->Write(SerializeAndFrameMessage(message), done); } void InputSender::InjectMouseEvent(const MouseEvent* event, Task* done) { EventMessage message; - Event* evt = message.add_event(); // TODO(hclam): Provide timestamp. - evt->set_timestamp(0); - evt->mutable_mouse()->CopyFrom(*event); + message.set_timestamp(0); + message.mutable_mouse_event()->CopyFrom(*event); buffered_writer_->Write(SerializeAndFrameMessage(message), done); } diff --git a/remoting/protocol/message_decoder.cc b/remoting/protocol/message_decoder.cc index b460b4d..2e33229 100644 --- a/remoting/protocol/message_decoder.cc +++ b/remoting/protocol/message_decoder.cc @@ -25,7 +25,7 @@ void MessageDecoder::AddData(scoped_refptr<net::IOBuffer> data, buffer_.Append(data, data_size); } -bool MessageDecoder::GetNextMessage(CompoundBuffer* message_buffer) { +CompoundBuffer* MessageDecoder::GetNextMessage() { // Determine the payload size. If we already know it then skip this part. // We may not have enough data to determine the payload size so use a // utility function to find out. @@ -39,14 +39,15 @@ bool MessageDecoder::GetNextMessage(CompoundBuffer* message_buffer) { // If the next payload size is still not known or we don't have enough // data for parsing then exit. if (!next_payload_known_ || buffer_.total_bytes() < next_payload_) - return false; + return NULL; + CompoundBuffer* message_buffer = new CompoundBuffer(); message_buffer->CopyFrom(buffer_, 0, next_payload_); message_buffer->Lock(); buffer_.CropFront(next_payload_); next_payload_known_ = false; - return true; + return message_buffer; } bool MessageDecoder::GetPayloadSize(int* size) { diff --git a/remoting/protocol/message_decoder.h b/remoting/protocol/message_decoder.h index 3e0745f7..0ba8b78 100644 --- a/remoting/protocol/message_decoder.h +++ b/remoting/protocol/message_decoder.h @@ -35,10 +35,11 @@ class MessageDecoder { // its bytes are consumed. void AddData(scoped_refptr<net::IOBuffer> data, int data_size); - // Get next message from the stream and puts it in - // |message_buffer|. Returns false if there are no complete messages - // yet. - bool GetNextMessage(CompoundBuffer* message_buffer); + // Returns next message from the stream. Ownership of the result is + // passed to the caller. Returns NULL if there are no complete + // messages yet, otherwise returns a buffer that contains one + // message. + CompoundBuffer* GetNextMessage(); private: // Retrieves the read payload size of the current protocol buffer via |size|. diff --git a/remoting/protocol/message_decoder_unittest.cc b/remoting/protocol/message_decoder_unittest.cc index 81bb699..c00d57d 100644 --- a/remoting/protocol/message_decoder_unittest.cc +++ b/remoting/protocol/message_decoder_unittest.cc @@ -6,7 +6,9 @@ #include "base/scoped_ptr.h" #include "base/stl_util-inl.h" +#include "base/string_number_conversions.h" #include "remoting/proto/event.pb.h" +#include "remoting/proto/internal.pb.h" #include "remoting/protocol/message_decoder.h" #include "remoting/protocol/util.h" #include "testing/gtest/include/gtest/gtest.h" @@ -29,16 +31,13 @@ static void PrepareData(uint8** buffer, int* size) { // Contains all encoded messages. std::string encoded_data; - EventMessage msg; - // Then append 10 update sequences to the data. for (int i = 0; i < 10; ++i) { - Event* event = msg.add_event(); - event->set_timestamp(i); - event->mutable_key()->set_keycode(kTestKey + i); - event->mutable_key()->set_pressed((i % 2) != 0); + EventMessage msg; + msg.set_timestamp(i); + msg.mutable_key_event()->set_keycode(kTestKey + i); + msg.mutable_key_event()->set_pressed((i % 2) != 0); AppendMessage(msg, &encoded_data); - msg.Clear(); } *size = encoded_data.length(); @@ -62,25 +61,27 @@ void SimulateReadSequence(const int read_sequence[], int sequence_size) { // Then feed the protocol decoder using the above generated data and the // read pattern. std::list<EventMessage*> message_list; - for (int i = 0; i < size;) { + for (int pos = 0; pos < size;) { + SCOPED_TRACE("Input position: " + base::IntToString(pos)); + // First generate the amount to feed the decoder. - int read = std::min(size - i, read_sequence[i % sequence_size]); + int read = std::min(size - pos, read_sequence[pos % sequence_size]); // And then prepare an IOBuffer for feeding it. scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(read)); - memcpy(buffer->data(), test_data + i, read); + memcpy(buffer->data(), test_data + pos, read); decoder.AddData(buffer, read); while (true) { - CompoundBuffer message; - if (!decoder.GetNextMessage(&message)) + scoped_ptr<CompoundBuffer> message(decoder.GetNextMessage()); + if (!message.get()) break; EventMessage* event = new EventMessage(); - CompoundBufferInputStream stream(&message); + CompoundBufferInputStream stream(message.get()); ASSERT_TRUE(event->ParseFromZeroCopyStream(&stream)); message_list.push_back(event); } - i += read; + pos += read; } // Then verify the decoded messages. @@ -90,15 +91,16 @@ void SimulateReadSequence(const int read_sequence[], int sequence_size) { for (std::list<EventMessage*>::iterator it = message_list.begin(); it != message_list.end(); ++it) { + SCOPED_TRACE("Message " + base::IntToString(index)); + EventMessage* message = *it; // Partial update stream. - EXPECT_EQ(message->event_size(), 1); - EXPECT_TRUE(message->event(0).has_key()); + EXPECT_TRUE(message->has_key_event()); // TODO(sergeyu): Don't use index here. Instead store the expected values // in an array. - EXPECT_EQ(kTestKey + index, message->event(0).key().keycode()); - EXPECT_EQ((index % 2) != 0, message->event(0).key().pressed()); + EXPECT_EQ(kTestKey + index, message->key_event().keycode()); + EXPECT_EQ((index % 2) != 0, message->key_event().pressed()); ++index; } STLDeleteElements(&message_list); diff --git a/remoting/protocol/message_reader.cc b/remoting/protocol/message_reader.cc index 7b818ee..3bbfd59 100644 --- a/remoting/protocol/message_reader.cc +++ b/remoting/protocol/message_reader.cc @@ -18,12 +18,16 @@ static const int kReadBufferSize = 4096; MessageReader::MessageReader() : socket_(NULL), + message_loop_(NULL), + read_pending_(false), + pending_messages_(0), closed_(false), ALLOW_THIS_IN_INITIALIZER_LIST( read_callback_(this, &MessageReader::OnRead)) { } MessageReader::~MessageReader() { + CHECK_EQ(pending_messages_, 0); } void MessageReader::Init(net::Socket* socket, @@ -31,21 +35,27 @@ void MessageReader::Init(net::Socket* socket, message_received_callback_.reset(callback); DCHECK(socket); socket_ = socket; + message_loop_ = MessageLoop::current(); DoRead(); } void MessageReader::DoRead() { - while (!closed_) { + DCHECK(!read_pending_); + + // Don't try to read again if there is another read pending or we + // have messages that we haven't finished processing yet. + while (!closed_ && !read_pending_ && pending_messages_ == 0) { read_buffer_ = new net::IOBuffer(kReadBufferSize); int result = socket_->Read( read_buffer_, kReadBufferSize, &read_callback_); HandleReadResult(result); - if (result < 0) - break; } } void MessageReader::OnRead(int result) { + DCHECK(read_pending_); + read_pending_ = false; + if (!closed_) { HandleReadResult(result); DoRead(); @@ -53,12 +63,17 @@ void MessageReader::OnRead(int result) { } void MessageReader::HandleReadResult(int result) { + if (closed_) + return; + if (result > 0) { OnDataReceived(read_buffer_, result); } else { if (result == net::ERR_CONNECTION_CLOSED) { closed_ = true; - } else if (result != net::ERR_IO_PENDING) { + } else if (result == net::ERR_IO_PENDING) { + read_pending_ = true; + } else { LOG(ERROR) << "Read() returned error " << result; } } @@ -67,14 +82,42 @@ void MessageReader::HandleReadResult(int result) { void MessageReader::OnDataReceived(net::IOBuffer* data, int data_size) { message_decoder_.AddData(data, data_size); + // Get list of all new messages first, and then call the callback + // for all of them. + std::vector<CompoundBuffer*> new_messages; while (true) { - CompoundBuffer buffer; - if (!message_decoder_.GetNextMessage(&buffer)) + CompoundBuffer* buffer = message_decoder_.GetNextMessage(); + if (!buffer) break; + new_messages.push_back(buffer); + } + + pending_messages_ += new_messages.size(); - message_received_callback_->Run(&buffer); + for (std::vector<CompoundBuffer*>::iterator it = new_messages.begin(); + it != new_messages.end(); ++it) { + message_received_callback_->Run(*it, NewRunnableMethod( + this, &MessageReader::OnMessageDone, *it)); } } +void MessageReader::OnMessageDone(CompoundBuffer* message) { + delete message; + ProcessDoneEvent(); +} + +void MessageReader::ProcessDoneEvent() { + if (MessageLoop::current() != message_loop_) { + message_loop_->PostTask(FROM_HERE, NewRunnableMethod( + this, &MessageReader::ProcessDoneEvent)); + return; + } + + pending_messages_--; + DCHECK_GE(pending_messages_, 0); + + DoRead(); // Start next read if neccessary. +} + } // namespace protocol } // namespace remoting diff --git a/remoting/protocol/message_reader.h b/remoting/protocol/message_reader.h index d493778..d14cb90 100644 --- a/remoting/protocol/message_reader.h +++ b/remoting/protocol/message_reader.h @@ -13,6 +13,8 @@ #include "remoting/base/compound_buffer.h" #include "remoting/protocol/message_decoder.h" +class MessageLoop; + namespace net { class IOBuffer; class Socket; @@ -22,10 +24,23 @@ namespace remoting { namespace protocol { // MessageReader reads data from the socket asynchronously and calls -// callback for each message it receives -class MessageReader { +// callback for each message it receives. It stops calling the +// callback as soon as the socket is closed, so the socket should +// always be closed before the callback handler is destroyed. +// +// In order to throttle the stream, MessageReader doesn't try to read +// new data from the socket until all previously received messages are +// processed by the receiver (|done_task| is called for each message). +// It is still possible that the MessageReceivedCallback is called +// twice (so that there is more than one outstanding message), +// e.g. when we the sender sends multiple messages in one TCP packet. +class MessageReader : public base::RefCountedThreadSafe<MessageReader> { public: - typedef Callback1<CompoundBuffer*>::Type MessageReceivedCallback; + // The callback is given ownership of the second argument + // (|done_task|). The buffer (first argument) is owned by + // MessageReader and is freed when the task specified by the second + // argument is called. + typedef Callback2<CompoundBuffer*, Task*>::Type MessageReceivedCallback; MessageReader(); virtual ~MessageReader(); @@ -39,9 +54,23 @@ class MessageReader { void OnRead(int result); void HandleReadResult(int result); void OnDataReceived(net::IOBuffer* data, int data_size); + void OnMessageDone(CompoundBuffer* message); + void ProcessDoneEvent(); net::Socket* socket_; + // The network message loop this object runs on. + MessageLoop* message_loop_; + + // Set to true, when we have a socket read pending, and expecting + // OnRead() to be called when new data is received. + bool read_pending_; + + // Number of messages that we received, but haven't finished + // processing yet, i.e. |done_task| hasn't been called for these + // messages. + int pending_messages_; + bool closed_; scoped_refptr<net::IOBuffer> read_buffer_; net::CompletionCallbackImpl<MessageReader> read_callback_; @@ -52,33 +81,46 @@ class MessageReader { scoped_ptr<MessageReceivedCallback> message_received_callback_; }; +// Version of MessageReader for protocol buffer messages, that parses +// each incoming message. template <class T> class ProtobufMessageReader { public: - typedef typename Callback1<T*>::Type MessageReceivedCallback; + typedef typename Callback2<T*, Task*>::Type MessageReceivedCallback; ProtobufMessageReader() { }; ~ProtobufMessageReader() { }; void Init(net::Socket* socket, MessageReceivedCallback* callback) { message_received_callback_.reset(callback); - message_reader_.Init( + message_reader_ = new MessageReader(); + message_reader_->Init( socket, NewCallback(this, &ProtobufMessageReader<T>::OnNewData)); } private: - void OnNewData(CompoundBuffer* buffer) { + void OnNewData(CompoundBuffer* buffer, Task* done_task) { T* message = new T(); CompoundBufferInputStream stream(buffer); bool ret = message->ParseFromZeroCopyStream(&stream); if (!ret) { + LOG(WARNING) << "Received message that is not a valid protocol buffer."; delete message; } else { - message_received_callback_->Run(message); + DCHECK_EQ(stream.position(), buffer->total_bytes()); + message_received_callback_->Run( + message, NewRunnableFunction( + &ProtobufMessageReader<T>::OnDone, message, done_task)); } } - MessageReader message_reader_; + static void OnDone(T* message, Task* done_task) { + delete message; + done_task->Run(); + delete done_task; + } + + scoped_refptr<MessageReader> message_reader_; scoped_ptr<MessageReceivedCallback> message_received_callback_; }; diff --git a/remoting/protocol/message_reader_unittest.cc b/remoting/protocol/message_reader_unittest.cc new file mode 100644 index 0000000..5f17a1a --- /dev/null +++ b/remoting/protocol/message_reader_unittest.cc @@ -0,0 +1,242 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <string> + +#include "base/message_loop.h" +#include "net/socket/socket.h" +#include "remoting/protocol/fake_session.h" +#include "remoting/protocol/message_reader.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "third_party/libjingle/source/talk/base/byteorder.h" + +using testing::_; +using testing::DoAll; +using testing::Mock; +using testing::SaveArg; + +namespace remoting { +namespace protocol { + +namespace { +const char kTestMessage1[] = "Message1"; +const char kTestMessage2[] = "Message2"; + +ACTION(CallDoneTask) { + arg1->Run(); + delete arg1; +} +} + +class MockMessageReceivedCallback { + public: + MOCK_METHOD2(OnMessage, void(CompoundBuffer*, Task*)); +}; + +class MessageReaderTest : public testing::Test { + protected: + virtual void SetUp() { + reader_ = new MessageReader(); + } + + void InitReader() { + reader_->Init(&socket_, NewCallback( + &callback_, &MockMessageReceivedCallback::OnMessage)); + } + + void AddMessage(const std::string& message) { + std::string data = std::string(4, ' ') + message; + talk_base::SetBE32(const_cast<char*>(data.data()), message.size()); + + socket_.AppendInputData(data.data(), data.size()); + } + + bool CompareResult(CompoundBuffer* buffer, const std::string& expected) { + std::string result(buffer->total_bytes(), ' '); + buffer->CopyTo(const_cast<char*>(result.data()), result.size()); + return result == expected; + } + + // MessageLoop must be first here, so that is is destroyed the last. + MessageLoop message_loop_; + + scoped_refptr<MessageReader> reader_; + FakeSocket socket_; + MockMessageReceivedCallback callback_; +}; + +// Receive one message and process it with delay +TEST_F(MessageReaderTest, OneMessage_Delay) { + CompoundBuffer* buffer; + Task* done_task; + + AddMessage(kTestMessage1); + + EXPECT_CALL(callback_, OnMessage(_, _)) + .Times(1) + .WillOnce(DoAll(SaveArg<0>(&buffer), + SaveArg<1>(&done_task))); + + InitReader(); + + Mock::VerifyAndClearExpectations(&callback_); + Mock::VerifyAndClearExpectations(&socket_); + + EXPECT_TRUE(CompareResult(buffer, kTestMessage1)); + + // Verify that the reader starts reading again only after we've + // finished processing the previous message. + EXPECT_FALSE(socket_.read_pending()); + + done_task->Run(); + + EXPECT_TRUE(socket_.read_pending()); +} + +// Receive one message and process it instantly. +TEST_F(MessageReaderTest, OneMessage_Instant) { + AddMessage(kTestMessage1); + + EXPECT_CALL(callback_, OnMessage(_, _)) + .Times(1) + .WillOnce(CallDoneTask()); + + InitReader(); + + EXPECT_TRUE(socket_.read_pending()); +} + +// Receive two messages in one packet. +TEST_F(MessageReaderTest, TwoMessages_Together) { + CompoundBuffer* buffer1; + Task* done_task1; + CompoundBuffer* buffer2; + Task* done_task2; + + AddMessage(kTestMessage1); + AddMessage(kTestMessage2); + + EXPECT_CALL(callback_, OnMessage(_, _)) + .Times(2) + .WillOnce(DoAll(SaveArg<0>(&buffer1), + SaveArg<1>(&done_task1))) + .WillOnce(DoAll(SaveArg<0>(&buffer2), + SaveArg<1>(&done_task2))); + + InitReader(); + + Mock::VerifyAndClearExpectations(&callback_); + Mock::VerifyAndClearExpectations(&socket_); + + EXPECT_TRUE(CompareResult(buffer1, kTestMessage1)); + EXPECT_TRUE(CompareResult(buffer2, kTestMessage2)); + + // Verify that the reader starts reading again only after we've + // finished processing the previous message. + EXPECT_FALSE(socket_.read_pending()); + + done_task1->Run(); + + EXPECT_FALSE(socket_.read_pending()); + + done_task2->Run(); + + EXPECT_TRUE(socket_.read_pending()); +} + +// Receive two messages in one packet, and process the first one +// instantly. +TEST_F(MessageReaderTest, TwoMessages_Instant) { + CompoundBuffer* buffer2; + Task* done_task2; + + AddMessage(kTestMessage1); + AddMessage(kTestMessage2); + + EXPECT_CALL(callback_, OnMessage(_, _)) + .Times(2) + .WillOnce(CallDoneTask()) + .WillOnce(DoAll(SaveArg<0>(&buffer2), + SaveArg<1>(&done_task2))); + + InitReader(); + + Mock::VerifyAndClearExpectations(&callback_); + Mock::VerifyAndClearExpectations(&socket_); + + EXPECT_TRUE(CompareResult(buffer2, kTestMessage2)); + + // Verify that the reader starts reading again only after we've + // finished processing the second message. + EXPECT_FALSE(socket_.read_pending()); + + done_task2->Run(); + + EXPECT_TRUE(socket_.read_pending()); +} + +// Receive two messages in one packet, and process both of them +// instantly. +TEST_F(MessageReaderTest, TwoMessages_Instant2) { + AddMessage(kTestMessage1); + AddMessage(kTestMessage2); + + EXPECT_CALL(callback_, OnMessage(_, _)) + .Times(2) + .WillOnce(CallDoneTask()) + .WillOnce(CallDoneTask()); + + InitReader(); + + EXPECT_TRUE(socket_.read_pending()); +} + +// Receive two messages in separate packets. +TEST_F(MessageReaderTest, TwoMessages_Separately) { + CompoundBuffer* buffer; + Task* done_task; + + AddMessage(kTestMessage1); + + EXPECT_CALL(callback_, OnMessage(_, _)) + .Times(1) + .WillOnce(DoAll(SaveArg<0>(&buffer), + SaveArg<1>(&done_task))); + + InitReader(); + + Mock::VerifyAndClearExpectations(&callback_); + Mock::VerifyAndClearExpectations(&socket_); + + EXPECT_TRUE(CompareResult(buffer, kTestMessage1)); + + // Verify that the reader starts reading again only after we've + // finished processing the previous message. + EXPECT_FALSE(socket_.read_pending()); + + done_task->Run(); + + EXPECT_TRUE(socket_.read_pending()); + + // Write another message and verify that we receive it. + EXPECT_CALL(callback_, OnMessage(_, _)) + .Times(1) + .WillOnce(DoAll(SaveArg<0>(&buffer), + SaveArg<1>(&done_task))); + AddMessage(kTestMessage2); + + EXPECT_TRUE(CompareResult(buffer, kTestMessage2)); + + // Verify that the reader starts reading again only after we've + // finished processing the previous message. + EXPECT_FALSE(socket_.read_pending()); + + done_task->Run(); + + EXPECT_TRUE(socket_.read_pending()); +} + +} // namespace protocol +} // namespace remoting diff --git a/remoting/protocol/protobuf_video_reader.cc b/remoting/protocol/protobuf_video_reader.cc index 02e06f8..a1dc2d2 100644 --- a/remoting/protocol/protobuf_video_reader.cc +++ b/remoting/protocol/protobuf_video_reader.cc @@ -25,8 +25,8 @@ void ProtobufVideoReader::Init(protocol::Session* session, video_stub_ = video_stub; } -void ProtobufVideoReader::OnNewData(VideoPacket* packet) { - video_stub_->ProcessVideoPacket(packet, new DeleteTask<VideoPacket>(packet)); +void ProtobufVideoReader::OnNewData(VideoPacket* packet, Task* done_task) { + video_stub_->ProcessVideoPacket(packet, done_task); } } // namespace protocol diff --git a/remoting/protocol/protobuf_video_reader.h b/remoting/protocol/protobuf_video_reader.h index 7305dd7..8e8ce42 100644 --- a/remoting/protocol/protobuf_video_reader.h +++ b/remoting/protocol/protobuf_video_reader.h @@ -23,7 +23,7 @@ class ProtobufVideoReader : public VideoReader { virtual void Init(protocol::Session* session, VideoStub* video_stub); private: - void OnNewData(VideoPacket* packet); + void OnNewData(VideoPacket* packet, Task* done_task); VideoPacketFormat::Encoding encoding_; diff --git a/remoting/protocol/ref_counted_message.h b/remoting/protocol/ref_counted_message.h deleted file mode 100644 index 24a8b52..0000000 --- a/remoting/protocol/ref_counted_message.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) 2010 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -// This is a wrapper class to help ref-counting a protobuf message. -// This file should only be inclued on host_message_dispatcher.cc and -// client_message_dispatche.cc. - -// A single protobuf can contain multiple messages that will be handled by -// different message handlers. We use this wrapper to ensure that the -// protobuf is only deleted after all the handlers have finished executing. - -#ifndef REMOTING_PROTOCOL_REF_COUNTED_MESSAGE_H_ -#define REMOTING_PROTOCOL_REF_COUNTED_MESSAGE_H_ - -#include "base/ref_counted.h" -#include "base/task.h" - -namespace remoting { -namespace protocol { - -template <typename T> -class RefCountedMessage : public base::RefCounted<RefCountedMessage<T> > { - public: - RefCountedMessage(T* message) : message_(message) { } - - T* message() { return message_.get(); } - - private: - scoped_ptr<T> message_; -}; - -// Dummy methods to destroy messages. -template <class T> -static void DeleteMessage(scoped_refptr<T> message) { } - -template <class T> -static Task* NewDeleteTask(scoped_refptr<T> message) { - return NewRunnableFunction(&DeleteMessage<T>, message); -} - -} // namespace protocol -} // namespace remoting - -#endif // REMOTING_PROTOCOL_REF_COUNTED_MESSAGE_H_ |