summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorsergeyu@chromium.org <sergeyu@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2011-01-22 02:34:56 +0000
committersergeyu@chromium.org <sergeyu@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2011-01-22 02:34:56 +0000
commit6852d7d96f3643277e8dab49e3bfa0e482aafffe (patch)
tree65b97cf15fb3cb6a60b914abac6310c9ff22e710
parenta4f4692c776153b2b61c0a63d8b9c80e56613881 (diff)
downloadchromium_src-6852d7d96f3643277e8dab49e3bfa0e482aafffe.zip
chromium_src-6852d7d96f3643277e8dab49e3bfa0e482aafffe.tar.gz
chromium_src-6852d7d96f3643277e8dab49e3bfa0e482aafffe.tar.bz2
Changed MessageReader so that it doesn't read from the socket if there are
other messages being processed. Added unittests for MessageReader. BUG=None TEST=Unittests Review URL: http://codereview.chromium.org/6271004 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@72262 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r--remoting/base/compound_buffer.h2
-rw-r--r--remoting/proto/event.proto14
-rw-r--r--remoting/proto/internal.proto9
-rw-r--r--remoting/protocol/client_message_dispatcher.cc15
-rw-r--r--remoting/protocol/client_message_dispatcher.h2
-rw-r--r--remoting/protocol/fake_session.cc4
-rw-r--r--remoting/protocol/fake_session.h15
-rw-r--r--remoting/protocol/host_message_dispatcher.cc39
-rw-r--r--remoting/protocol/host_message_dispatcher.h8
-rw-r--r--remoting/protocol/input_sender.cc11
-rw-r--r--remoting/protocol/message_decoder.cc7
-rw-r--r--remoting/protocol/message_decoder.h9
-rw-r--r--remoting/protocol/message_decoder_unittest.cc38
-rw-r--r--remoting/protocol/message_reader.cc57
-rw-r--r--remoting/protocol/message_reader.h58
-rw-r--r--remoting/protocol/message_reader_unittest.cc242
-rw-r--r--remoting/protocol/protobuf_video_reader.cc4
-rw-r--r--remoting/protocol/protobuf_video_reader.h2
-rw-r--r--remoting/protocol/ref_counted_message.h45
-rw-r--r--remoting/remoting.gyp1
20 files changed, 430 insertions, 152 deletions
diff --git a/remoting/base/compound_buffer.h b/remoting/base/compound_buffer.h
index eef8d13..cf92242 100644
--- a/remoting/base/compound_buffer.h
+++ b/remoting/base/compound_buffer.h
@@ -105,6 +105,8 @@ class CompoundBufferInputStream
explicit CompoundBufferInputStream(const CompoundBuffer* buffer);
virtual ~CompoundBufferInputStream();
+ int position() const { return position_; }
+
// google::protobuf::io::ZeroCopyInputStream interface.
virtual bool Next(const void** data, int* size);
virtual void BackUp(int count);
diff --git a/remoting/proto/event.proto b/remoting/proto/event.proto
index 19c734d..945d0b0 100644
--- a/remoting/proto/event.proto
+++ b/remoting/proto/event.proto
@@ -40,17 +40,3 @@ message MouseEvent {
optional MouseButton button = 5;
optional bool button_down = 6;
}
-
-// Defines an event message on the event channel.
-message Event {
- required int32 timestamp = 1; // Client timestamp for event
- optional bool dummy = 2; // Is this a dummy event?
-
- optional KeyEvent key = 3;
- optional MouseEvent mouse = 4;
-}
-
-// Message sent in the event channel.
-message EventMessage {
- repeated Event event = 1;
-}
diff --git a/remoting/proto/internal.proto b/remoting/proto/internal.proto
index 339fd59..c52dd2d 100644
--- a/remoting/proto/internal.proto
+++ b/remoting/proto/internal.proto
@@ -22,3 +22,12 @@ message ControlMessage {
optional BeginSessionRequest begin_session_request = 3;
optional BeginSessionResponse begin_session_response = 4;
}
+
+// Defines an event message on the event channel.
+message EventMessage {
+ required int32 timestamp = 1; // Client timestamp for event
+ optional bool dummy = 2; // Is this a dummy event?
+
+ optional KeyEvent key_event = 3;
+ optional MouseEvent mouse_event = 4;
+}
diff --git a/remoting/protocol/client_message_dispatcher.cc b/remoting/protocol/client_message_dispatcher.cc
index 9568685..e7b6dd6 100644
--- a/remoting/protocol/client_message_dispatcher.cc
+++ b/remoting/protocol/client_message_dispatcher.cc
@@ -11,7 +11,6 @@
#include "remoting/protocol/client_stub.h"
#include "remoting/protocol/input_stub.h"
#include "remoting/protocol/message_reader.h"
-#include "remoting/protocol/ref_counted_message.h"
#include "remoting/protocol/session.h"
namespace remoting {
@@ -39,18 +38,18 @@ void ClientMessageDispatcher::Initialize(
}
void ClientMessageDispatcher::OnControlMessageReceived(
- ControlMessage* message) {
- scoped_refptr<RefCountedMessage<ControlMessage> > ref_msg =
- new RefCountedMessage<ControlMessage>(message);
+ ControlMessage* message, Task* done_task) {
+ // TODO(sergeyu): Add message validation.
if (message->has_notify_resolution()) {
client_stub_->NotifyResolution(
- &message->notify_resolution(), NewDeleteTask(ref_msg));
+ &message->notify_resolution(), done_task);
} else if (message->has_begin_session_response()) {
client_stub_->BeginSessionResponse(
- &message->begin_session_response().login_status(),
- NewDeleteTask(ref_msg));
+ &message->begin_session_response().login_status(), done_task);
} else {
- NOTREACHED() << "Invalid control message received";
+ LOG(WARNING) << "Invalid control message received.";
+ done_task->Run();
+ delete done_task;
}
}
diff --git a/remoting/protocol/client_message_dispatcher.h b/remoting/protocol/client_message_dispatcher.h
index 8f0f5a6..88c7c14 100644
--- a/remoting/protocol/client_message_dispatcher.h
+++ b/remoting/protocol/client_message_dispatcher.h
@@ -40,7 +40,7 @@ class ClientMessageDispatcher {
void Initialize(protocol::Session* session, ClientStub* client_stub);
private:
- void OnControlMessageReceived(ControlMessage* message);
+ void OnControlMessageReceived(ControlMessage* message, Task* done_task);
// MessageReader that runs on the control channel. It runs a loop
// that parses data on the channel and then calls the corresponding handler
diff --git a/remoting/protocol/fake_session.cc b/remoting/protocol/fake_session.cc
index ac344e0..fe50f2b 100644
--- a/remoting/protocol/fake_session.cc
+++ b/remoting/protocol/fake_session.cc
@@ -21,7 +21,7 @@ FakeSocket::FakeSocket()
FakeSocket::~FakeSocket() {
}
-void FakeSocket::AppendInputData(char* data, int data_size) {
+void FakeSocket::AppendInputData(const char* data, int data_size) {
input_data_.insert(input_data_.end(), data, data + data_size);
// Complete pending read if any.
if (read_pending_) {
@@ -78,7 +78,7 @@ FakeUdpSocket::FakeUdpSocket()
FakeUdpSocket::~FakeUdpSocket() {
}
-void FakeUdpSocket::AppendInputPacket(char* data, int data_size) {
+void FakeUdpSocket::AppendInputPacket(const char* data, int data_size) {
input_packets_.push_back(std::string());
input_packets_.back().assign(data, data + data_size);
diff --git a/remoting/protocol/fake_session.h b/remoting/protocol/fake_session.h
index 62818e4..3ffe20b 100644
--- a/remoting/protocol/fake_session.h
+++ b/remoting/protocol/fake_session.h
@@ -27,10 +27,11 @@ class FakeSocket : public net::Socket {
FakeSocket();
virtual ~FakeSocket();
- const std::string& written_data() { return written_data_; }
+ const std::string& written_data() const { return written_data_; }
- void AppendInputData(char* data, int data_size);
- int input_pos() { return input_pos_; }
+ void AppendInputData(const char* data, int data_size);
+ int input_pos() const { return input_pos_; }
+ bool read_pending() const { return read_pending_; }
// net::Socket interface.
virtual int Read(net::IOBuffer* buf, int buf_len,
@@ -60,12 +61,12 @@ class FakeUdpSocket : public net::Socket {
FakeUdpSocket();
virtual ~FakeUdpSocket();
- const std::vector<std::string>& written_packets() {
+ const std::vector<std::string>& written_packets() const {
return written_packets_;
}
- void AppendInputPacket(char* data, int data_size);
- int input_pos() { return input_pos_; }
+ void AppendInputPacket(const char* data, int data_size);
+ int input_pos() const { return input_pos_; }
// net::Socket interface.
virtual int Read(net::IOBuffer* buf, int buf_len,
@@ -100,7 +101,7 @@ class FakeSession : public Session {
message_loop_ = message_loop;
}
- bool is_closed() { return closed_; }
+ bool is_closed() const { return closed_; }
virtual void SetStateChangeCallback(StateChangeCallback* callback);
diff --git a/remoting/protocol/host_message_dispatcher.cc b/remoting/protocol/host_message_dispatcher.cc
index 2554c9b..1e1eea8 100644
--- a/remoting/protocol/host_message_dispatcher.cc
+++ b/remoting/protocol/host_message_dispatcher.cc
@@ -11,7 +11,6 @@
#include "remoting/protocol/host_stub.h"
#include "remoting/protocol/input_stub.h"
#include "remoting/protocol/message_reader.h"
-#include "remoting/protocol/ref_counted_message.h"
#include "remoting/protocol/session.h"
namespace remoting {
@@ -47,34 +46,32 @@ void HostMessageDispatcher::Initialize(
NewCallback(this, &HostMessageDispatcher::OnControlMessageReceived));
}
-void HostMessageDispatcher::OnControlMessageReceived(ControlMessage* message) {
- scoped_refptr<RefCountedMessage<ControlMessage> > ref_msg =
- new RefCountedMessage<ControlMessage>(message);
+void HostMessageDispatcher::OnControlMessageReceived(
+ ControlMessage* message, Task* done_task) {
+ // TODO(sergeyu): Add message validation.
if (message->has_suggest_resolution()) {
- host_stub_->SuggestResolution(
- &message->suggest_resolution(), NewDeleteTask(ref_msg));
+ host_stub_->SuggestResolution(&message->suggest_resolution(), done_task);
} else if (message->has_begin_session_request()) {
host_stub_->BeginSessionRequest(
- &message->begin_session_request().credentials(),
- NewDeleteTask(ref_msg));
+ &message->begin_session_request().credentials(), done_task);
} else {
- NOTREACHED() << "Invalid control message received";
+ LOG(WARNING) << "Invalid control message received.";
+ done_task->Run();
+ delete done_task;
}
}
void HostMessageDispatcher::OnEventMessageReceived(
- EventMessage* message) {
- scoped_refptr<RefCountedMessage<EventMessage> > ref_msg =
- new RefCountedMessage<EventMessage>(message);
- for (int i = 0; i < message->event_size(); ++i) {
- if (message->event(i).has_key()) {
- input_stub_->InjectKeyEvent(
- &message->event(i).key(), NewDeleteTask(ref_msg));
- }
- if (message->event(i).has_mouse()) {
- input_stub_->InjectMouseEvent(
- &message->event(i).mouse(), NewDeleteTask(ref_msg));
- }
+ EventMessage* message, Task* done_task) {
+ // TODO(sergeyu): Add message validation.
+ if (message->has_key_event()) {
+ input_stub_->InjectKeyEvent(&message->key_event(), done_task);
+ } else if (message->has_mouse_event()) {
+ input_stub_->InjectMouseEvent(&message->mouse_event(), done_task);
+ } else {
+ LOG(WARNING) << "Invalid event message received.";
+ done_task->Run();
+ delete done_task;
}
}
diff --git a/remoting/protocol/host_message_dispatcher.h b/remoting/protocol/host_message_dispatcher.h
index 8ebed01..60afb27 100644
--- a/remoting/protocol/host_message_dispatcher.h
+++ b/remoting/protocol/host_message_dispatcher.h
@@ -11,12 +11,10 @@
#include "remoting/protocol/message_reader.h"
namespace remoting {
-
-class EventMessage;
-
namespace protocol {
class ControlMessage;
+class EventMessage;
class HostStub;
class InputStub;
class Session;
@@ -45,11 +43,11 @@ class HostMessageDispatcher {
private:
// This method is called by |control_channel_reader_| when a control
// message is received.
- void OnControlMessageReceived(ControlMessage* message);
+ void OnControlMessageReceived(ControlMessage* message, Task* done_task);
// This method is called by |event_channel_reader_| when a event
// message is received.
- void OnEventMessageReceived(EventMessage* message);
+ void OnEventMessageReceived(EventMessage* message, Task* done_task);
// MessageReader that runs on the control channel. It runs a loop
// that parses data on the channel and then delegates the message to this
diff --git a/remoting/protocol/input_sender.cc b/remoting/protocol/input_sender.cc
index 8831649..8ba90ca 100644
--- a/remoting/protocol/input_sender.cc
+++ b/remoting/protocol/input_sender.cc
@@ -9,6 +9,7 @@
#include "base/task.h"
#include "remoting/proto/event.pb.h"
+#include "remoting/proto/internal.pb.h"
#include "remoting/protocol/buffered_socket_writer.h"
#include "remoting/protocol/util.h"
@@ -27,19 +28,17 @@ InputSender::~InputSender() {
void InputSender::InjectKeyEvent(const KeyEvent* event, Task* done) {
EventMessage message;
- Event* evt = message.add_event();
// TODO(hclam): Provide timestamp.
- evt->set_timestamp(0);
- evt->mutable_key()->CopyFrom(*event);
+ message.set_timestamp(0);
+ message.mutable_key_event()->CopyFrom(*event);
buffered_writer_->Write(SerializeAndFrameMessage(message), done);
}
void InputSender::InjectMouseEvent(const MouseEvent* event, Task* done) {
EventMessage message;
- Event* evt = message.add_event();
// TODO(hclam): Provide timestamp.
- evt->set_timestamp(0);
- evt->mutable_mouse()->CopyFrom(*event);
+ message.set_timestamp(0);
+ message.mutable_mouse_event()->CopyFrom(*event);
buffered_writer_->Write(SerializeAndFrameMessage(message), done);
}
diff --git a/remoting/protocol/message_decoder.cc b/remoting/protocol/message_decoder.cc
index b460b4d..2e33229 100644
--- a/remoting/protocol/message_decoder.cc
+++ b/remoting/protocol/message_decoder.cc
@@ -25,7 +25,7 @@ void MessageDecoder::AddData(scoped_refptr<net::IOBuffer> data,
buffer_.Append(data, data_size);
}
-bool MessageDecoder::GetNextMessage(CompoundBuffer* message_buffer) {
+CompoundBuffer* MessageDecoder::GetNextMessage() {
// Determine the payload size. If we already know it then skip this part.
// We may not have enough data to determine the payload size so use a
// utility function to find out.
@@ -39,14 +39,15 @@ bool MessageDecoder::GetNextMessage(CompoundBuffer* message_buffer) {
// If the next payload size is still not known or we don't have enough
// data for parsing then exit.
if (!next_payload_known_ || buffer_.total_bytes() < next_payload_)
- return false;
+ return NULL;
+ CompoundBuffer* message_buffer = new CompoundBuffer();
message_buffer->CopyFrom(buffer_, 0, next_payload_);
message_buffer->Lock();
buffer_.CropFront(next_payload_);
next_payload_known_ = false;
- return true;
+ return message_buffer;
}
bool MessageDecoder::GetPayloadSize(int* size) {
diff --git a/remoting/protocol/message_decoder.h b/remoting/protocol/message_decoder.h
index 3e0745f7..0ba8b78 100644
--- a/remoting/protocol/message_decoder.h
+++ b/remoting/protocol/message_decoder.h
@@ -35,10 +35,11 @@ class MessageDecoder {
// its bytes are consumed.
void AddData(scoped_refptr<net::IOBuffer> data, int data_size);
- // Get next message from the stream and puts it in
- // |message_buffer|. Returns false if there are no complete messages
- // yet.
- bool GetNextMessage(CompoundBuffer* message_buffer);
+ // Returns next message from the stream. Ownership of the result is
+ // passed to the caller. Returns NULL if there are no complete
+ // messages yet, otherwise returns a buffer that contains one
+ // message.
+ CompoundBuffer* GetNextMessage();
private:
// Retrieves the read payload size of the current protocol buffer via |size|.
diff --git a/remoting/protocol/message_decoder_unittest.cc b/remoting/protocol/message_decoder_unittest.cc
index 81bb699..c00d57d 100644
--- a/remoting/protocol/message_decoder_unittest.cc
+++ b/remoting/protocol/message_decoder_unittest.cc
@@ -6,7 +6,9 @@
#include "base/scoped_ptr.h"
#include "base/stl_util-inl.h"
+#include "base/string_number_conversions.h"
#include "remoting/proto/event.pb.h"
+#include "remoting/proto/internal.pb.h"
#include "remoting/protocol/message_decoder.h"
#include "remoting/protocol/util.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -29,16 +31,13 @@ static void PrepareData(uint8** buffer, int* size) {
// Contains all encoded messages.
std::string encoded_data;
- EventMessage msg;
-
// Then append 10 update sequences to the data.
for (int i = 0; i < 10; ++i) {
- Event* event = msg.add_event();
- event->set_timestamp(i);
- event->mutable_key()->set_keycode(kTestKey + i);
- event->mutable_key()->set_pressed((i % 2) != 0);
+ EventMessage msg;
+ msg.set_timestamp(i);
+ msg.mutable_key_event()->set_keycode(kTestKey + i);
+ msg.mutable_key_event()->set_pressed((i % 2) != 0);
AppendMessage(msg, &encoded_data);
- msg.Clear();
}
*size = encoded_data.length();
@@ -62,25 +61,27 @@ void SimulateReadSequence(const int read_sequence[], int sequence_size) {
// Then feed the protocol decoder using the above generated data and the
// read pattern.
std::list<EventMessage*> message_list;
- for (int i = 0; i < size;) {
+ for (int pos = 0; pos < size;) {
+ SCOPED_TRACE("Input position: " + base::IntToString(pos));
+
// First generate the amount to feed the decoder.
- int read = std::min(size - i, read_sequence[i % sequence_size]);
+ int read = std::min(size - pos, read_sequence[pos % sequence_size]);
// And then prepare an IOBuffer for feeding it.
scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(read));
- memcpy(buffer->data(), test_data + i, read);
+ memcpy(buffer->data(), test_data + pos, read);
decoder.AddData(buffer, read);
while (true) {
- CompoundBuffer message;
- if (!decoder.GetNextMessage(&message))
+ scoped_ptr<CompoundBuffer> message(decoder.GetNextMessage());
+ if (!message.get())
break;
EventMessage* event = new EventMessage();
- CompoundBufferInputStream stream(&message);
+ CompoundBufferInputStream stream(message.get());
ASSERT_TRUE(event->ParseFromZeroCopyStream(&stream));
message_list.push_back(event);
}
- i += read;
+ pos += read;
}
// Then verify the decoded messages.
@@ -90,15 +91,16 @@ void SimulateReadSequence(const int read_sequence[], int sequence_size) {
for (std::list<EventMessage*>::iterator it =
message_list.begin();
it != message_list.end(); ++it) {
+ SCOPED_TRACE("Message " + base::IntToString(index));
+
EventMessage* message = *it;
// Partial update stream.
- EXPECT_EQ(message->event_size(), 1);
- EXPECT_TRUE(message->event(0).has_key());
+ EXPECT_TRUE(message->has_key_event());
// TODO(sergeyu): Don't use index here. Instead store the expected values
// in an array.
- EXPECT_EQ(kTestKey + index, message->event(0).key().keycode());
- EXPECT_EQ((index % 2) != 0, message->event(0).key().pressed());
+ EXPECT_EQ(kTestKey + index, message->key_event().keycode());
+ EXPECT_EQ((index % 2) != 0, message->key_event().pressed());
++index;
}
STLDeleteElements(&message_list);
diff --git a/remoting/protocol/message_reader.cc b/remoting/protocol/message_reader.cc
index 7b818ee..3bbfd59 100644
--- a/remoting/protocol/message_reader.cc
+++ b/remoting/protocol/message_reader.cc
@@ -18,12 +18,16 @@ static const int kReadBufferSize = 4096;
MessageReader::MessageReader()
: socket_(NULL),
+ message_loop_(NULL),
+ read_pending_(false),
+ pending_messages_(0),
closed_(false),
ALLOW_THIS_IN_INITIALIZER_LIST(
read_callback_(this, &MessageReader::OnRead)) {
}
MessageReader::~MessageReader() {
+ CHECK_EQ(pending_messages_, 0);
}
void MessageReader::Init(net::Socket* socket,
@@ -31,21 +35,27 @@ void MessageReader::Init(net::Socket* socket,
message_received_callback_.reset(callback);
DCHECK(socket);
socket_ = socket;
+ message_loop_ = MessageLoop::current();
DoRead();
}
void MessageReader::DoRead() {
- while (!closed_) {
+ DCHECK(!read_pending_);
+
+ // Don't try to read again if there is another read pending or we
+ // have messages that we haven't finished processing yet.
+ while (!closed_ && !read_pending_ && pending_messages_ == 0) {
read_buffer_ = new net::IOBuffer(kReadBufferSize);
int result = socket_->Read(
read_buffer_, kReadBufferSize, &read_callback_);
HandleReadResult(result);
- if (result < 0)
- break;
}
}
void MessageReader::OnRead(int result) {
+ DCHECK(read_pending_);
+ read_pending_ = false;
+
if (!closed_) {
HandleReadResult(result);
DoRead();
@@ -53,12 +63,17 @@ void MessageReader::OnRead(int result) {
}
void MessageReader::HandleReadResult(int result) {
+ if (closed_)
+ return;
+
if (result > 0) {
OnDataReceived(read_buffer_, result);
} else {
if (result == net::ERR_CONNECTION_CLOSED) {
closed_ = true;
- } else if (result != net::ERR_IO_PENDING) {
+ } else if (result == net::ERR_IO_PENDING) {
+ read_pending_ = true;
+ } else {
LOG(ERROR) << "Read() returned error " << result;
}
}
@@ -67,14 +82,42 @@ void MessageReader::HandleReadResult(int result) {
void MessageReader::OnDataReceived(net::IOBuffer* data, int data_size) {
message_decoder_.AddData(data, data_size);
+ // Get list of all new messages first, and then call the callback
+ // for all of them.
+ std::vector<CompoundBuffer*> new_messages;
while (true) {
- CompoundBuffer buffer;
- if (!message_decoder_.GetNextMessage(&buffer))
+ CompoundBuffer* buffer = message_decoder_.GetNextMessage();
+ if (!buffer)
break;
+ new_messages.push_back(buffer);
+ }
+
+ pending_messages_ += new_messages.size();
- message_received_callback_->Run(&buffer);
+ for (std::vector<CompoundBuffer*>::iterator it = new_messages.begin();
+ it != new_messages.end(); ++it) {
+ message_received_callback_->Run(*it, NewRunnableMethod(
+ this, &MessageReader::OnMessageDone, *it));
}
}
+void MessageReader::OnMessageDone(CompoundBuffer* message) {
+ delete message;
+ ProcessDoneEvent();
+}
+
+void MessageReader::ProcessDoneEvent() {
+ if (MessageLoop::current() != message_loop_) {
+ message_loop_->PostTask(FROM_HERE, NewRunnableMethod(
+ this, &MessageReader::ProcessDoneEvent));
+ return;
+ }
+
+ pending_messages_--;
+ DCHECK_GE(pending_messages_, 0);
+
+ DoRead(); // Start next read if neccessary.
+}
+
} // namespace protocol
} // namespace remoting
diff --git a/remoting/protocol/message_reader.h b/remoting/protocol/message_reader.h
index d493778..d14cb90 100644
--- a/remoting/protocol/message_reader.h
+++ b/remoting/protocol/message_reader.h
@@ -13,6 +13,8 @@
#include "remoting/base/compound_buffer.h"
#include "remoting/protocol/message_decoder.h"
+class MessageLoop;
+
namespace net {
class IOBuffer;
class Socket;
@@ -22,10 +24,23 @@ namespace remoting {
namespace protocol {
// MessageReader reads data from the socket asynchronously and calls
-// callback for each message it receives
-class MessageReader {
+// callback for each message it receives. It stops calling the
+// callback as soon as the socket is closed, so the socket should
+// always be closed before the callback handler is destroyed.
+//
+// In order to throttle the stream, MessageReader doesn't try to read
+// new data from the socket until all previously received messages are
+// processed by the receiver (|done_task| is called for each message).
+// It is still possible that the MessageReceivedCallback is called
+// twice (so that there is more than one outstanding message),
+// e.g. when we the sender sends multiple messages in one TCP packet.
+class MessageReader : public base::RefCountedThreadSafe<MessageReader> {
public:
- typedef Callback1<CompoundBuffer*>::Type MessageReceivedCallback;
+ // The callback is given ownership of the second argument
+ // (|done_task|). The buffer (first argument) is owned by
+ // MessageReader and is freed when the task specified by the second
+ // argument is called.
+ typedef Callback2<CompoundBuffer*, Task*>::Type MessageReceivedCallback;
MessageReader();
virtual ~MessageReader();
@@ -39,9 +54,23 @@ class MessageReader {
void OnRead(int result);
void HandleReadResult(int result);
void OnDataReceived(net::IOBuffer* data, int data_size);
+ void OnMessageDone(CompoundBuffer* message);
+ void ProcessDoneEvent();
net::Socket* socket_;
+ // The network message loop this object runs on.
+ MessageLoop* message_loop_;
+
+ // Set to true, when we have a socket read pending, and expecting
+ // OnRead() to be called when new data is received.
+ bool read_pending_;
+
+ // Number of messages that we received, but haven't finished
+ // processing yet, i.e. |done_task| hasn't been called for these
+ // messages.
+ int pending_messages_;
+
bool closed_;
scoped_refptr<net::IOBuffer> read_buffer_;
net::CompletionCallbackImpl<MessageReader> read_callback_;
@@ -52,33 +81,46 @@ class MessageReader {
scoped_ptr<MessageReceivedCallback> message_received_callback_;
};
+// Version of MessageReader for protocol buffer messages, that parses
+// each incoming message.
template <class T>
class ProtobufMessageReader {
public:
- typedef typename Callback1<T*>::Type MessageReceivedCallback;
+ typedef typename Callback2<T*, Task*>::Type MessageReceivedCallback;
ProtobufMessageReader() { };
~ProtobufMessageReader() { };
void Init(net::Socket* socket, MessageReceivedCallback* callback) {
message_received_callback_.reset(callback);
- message_reader_.Init(
+ message_reader_ = new MessageReader();
+ message_reader_->Init(
socket, NewCallback(this, &ProtobufMessageReader<T>::OnNewData));
}
private:
- void OnNewData(CompoundBuffer* buffer) {
+ void OnNewData(CompoundBuffer* buffer, Task* done_task) {
T* message = new T();
CompoundBufferInputStream stream(buffer);
bool ret = message->ParseFromZeroCopyStream(&stream);
if (!ret) {
+ LOG(WARNING) << "Received message that is not a valid protocol buffer.";
delete message;
} else {
- message_received_callback_->Run(message);
+ DCHECK_EQ(stream.position(), buffer->total_bytes());
+ message_received_callback_->Run(
+ message, NewRunnableFunction(
+ &ProtobufMessageReader<T>::OnDone, message, done_task));
}
}
- MessageReader message_reader_;
+ static void OnDone(T* message, Task* done_task) {
+ delete message;
+ done_task->Run();
+ delete done_task;
+ }
+
+ scoped_refptr<MessageReader> message_reader_;
scoped_ptr<MessageReceivedCallback> message_received_callback_;
};
diff --git a/remoting/protocol/message_reader_unittest.cc b/remoting/protocol/message_reader_unittest.cc
new file mode 100644
index 0000000..5f17a1a
--- /dev/null
+++ b/remoting/protocol/message_reader_unittest.cc
@@ -0,0 +1,242 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <string>
+
+#include "base/message_loop.h"
+#include "net/socket/socket.h"
+#include "remoting/protocol/fake_session.h"
+#include "remoting/protocol/message_reader.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "third_party/libjingle/source/talk/base/byteorder.h"
+
+using testing::_;
+using testing::DoAll;
+using testing::Mock;
+using testing::SaveArg;
+
+namespace remoting {
+namespace protocol {
+
+namespace {
+const char kTestMessage1[] = "Message1";
+const char kTestMessage2[] = "Message2";
+
+ACTION(CallDoneTask) {
+ arg1->Run();
+ delete arg1;
+}
+}
+
+class MockMessageReceivedCallback {
+ public:
+ MOCK_METHOD2(OnMessage, void(CompoundBuffer*, Task*));
+};
+
+class MessageReaderTest : public testing::Test {
+ protected:
+ virtual void SetUp() {
+ reader_ = new MessageReader();
+ }
+
+ void InitReader() {
+ reader_->Init(&socket_, NewCallback(
+ &callback_, &MockMessageReceivedCallback::OnMessage));
+ }
+
+ void AddMessage(const std::string& message) {
+ std::string data = std::string(4, ' ') + message;
+ talk_base::SetBE32(const_cast<char*>(data.data()), message.size());
+
+ socket_.AppendInputData(data.data(), data.size());
+ }
+
+ bool CompareResult(CompoundBuffer* buffer, const std::string& expected) {
+ std::string result(buffer->total_bytes(), ' ');
+ buffer->CopyTo(const_cast<char*>(result.data()), result.size());
+ return result == expected;
+ }
+
+ // MessageLoop must be first here, so that is is destroyed the last.
+ MessageLoop message_loop_;
+
+ scoped_refptr<MessageReader> reader_;
+ FakeSocket socket_;
+ MockMessageReceivedCallback callback_;
+};
+
+// Receive one message and process it with delay
+TEST_F(MessageReaderTest, OneMessage_Delay) {
+ CompoundBuffer* buffer;
+ Task* done_task;
+
+ AddMessage(kTestMessage1);
+
+ EXPECT_CALL(callback_, OnMessage(_, _))
+ .Times(1)
+ .WillOnce(DoAll(SaveArg<0>(&buffer),
+ SaveArg<1>(&done_task)));
+
+ InitReader();
+
+ Mock::VerifyAndClearExpectations(&callback_);
+ Mock::VerifyAndClearExpectations(&socket_);
+
+ EXPECT_TRUE(CompareResult(buffer, kTestMessage1));
+
+ // Verify that the reader starts reading again only after we've
+ // finished processing the previous message.
+ EXPECT_FALSE(socket_.read_pending());
+
+ done_task->Run();
+
+ EXPECT_TRUE(socket_.read_pending());
+}
+
+// Receive one message and process it instantly.
+TEST_F(MessageReaderTest, OneMessage_Instant) {
+ AddMessage(kTestMessage1);
+
+ EXPECT_CALL(callback_, OnMessage(_, _))
+ .Times(1)
+ .WillOnce(CallDoneTask());
+
+ InitReader();
+
+ EXPECT_TRUE(socket_.read_pending());
+}
+
+// Receive two messages in one packet.
+TEST_F(MessageReaderTest, TwoMessages_Together) {
+ CompoundBuffer* buffer1;
+ Task* done_task1;
+ CompoundBuffer* buffer2;
+ Task* done_task2;
+
+ AddMessage(kTestMessage1);
+ AddMessage(kTestMessage2);
+
+ EXPECT_CALL(callback_, OnMessage(_, _))
+ .Times(2)
+ .WillOnce(DoAll(SaveArg<0>(&buffer1),
+ SaveArg<1>(&done_task1)))
+ .WillOnce(DoAll(SaveArg<0>(&buffer2),
+ SaveArg<1>(&done_task2)));
+
+ InitReader();
+
+ Mock::VerifyAndClearExpectations(&callback_);
+ Mock::VerifyAndClearExpectations(&socket_);
+
+ EXPECT_TRUE(CompareResult(buffer1, kTestMessage1));
+ EXPECT_TRUE(CompareResult(buffer2, kTestMessage2));
+
+ // Verify that the reader starts reading again only after we've
+ // finished processing the previous message.
+ EXPECT_FALSE(socket_.read_pending());
+
+ done_task1->Run();
+
+ EXPECT_FALSE(socket_.read_pending());
+
+ done_task2->Run();
+
+ EXPECT_TRUE(socket_.read_pending());
+}
+
+// Receive two messages in one packet, and process the first one
+// instantly.
+TEST_F(MessageReaderTest, TwoMessages_Instant) {
+ CompoundBuffer* buffer2;
+ Task* done_task2;
+
+ AddMessage(kTestMessage1);
+ AddMessage(kTestMessage2);
+
+ EXPECT_CALL(callback_, OnMessage(_, _))
+ .Times(2)
+ .WillOnce(CallDoneTask())
+ .WillOnce(DoAll(SaveArg<0>(&buffer2),
+ SaveArg<1>(&done_task2)));
+
+ InitReader();
+
+ Mock::VerifyAndClearExpectations(&callback_);
+ Mock::VerifyAndClearExpectations(&socket_);
+
+ EXPECT_TRUE(CompareResult(buffer2, kTestMessage2));
+
+ // Verify that the reader starts reading again only after we've
+ // finished processing the second message.
+ EXPECT_FALSE(socket_.read_pending());
+
+ done_task2->Run();
+
+ EXPECT_TRUE(socket_.read_pending());
+}
+
+// Receive two messages in one packet, and process both of them
+// instantly.
+TEST_F(MessageReaderTest, TwoMessages_Instant2) {
+ AddMessage(kTestMessage1);
+ AddMessage(kTestMessage2);
+
+ EXPECT_CALL(callback_, OnMessage(_, _))
+ .Times(2)
+ .WillOnce(CallDoneTask())
+ .WillOnce(CallDoneTask());
+
+ InitReader();
+
+ EXPECT_TRUE(socket_.read_pending());
+}
+
+// Receive two messages in separate packets.
+TEST_F(MessageReaderTest, TwoMessages_Separately) {
+ CompoundBuffer* buffer;
+ Task* done_task;
+
+ AddMessage(kTestMessage1);
+
+ EXPECT_CALL(callback_, OnMessage(_, _))
+ .Times(1)
+ .WillOnce(DoAll(SaveArg<0>(&buffer),
+ SaveArg<1>(&done_task)));
+
+ InitReader();
+
+ Mock::VerifyAndClearExpectations(&callback_);
+ Mock::VerifyAndClearExpectations(&socket_);
+
+ EXPECT_TRUE(CompareResult(buffer, kTestMessage1));
+
+ // Verify that the reader starts reading again only after we've
+ // finished processing the previous message.
+ EXPECT_FALSE(socket_.read_pending());
+
+ done_task->Run();
+
+ EXPECT_TRUE(socket_.read_pending());
+
+ // Write another message and verify that we receive it.
+ EXPECT_CALL(callback_, OnMessage(_, _))
+ .Times(1)
+ .WillOnce(DoAll(SaveArg<0>(&buffer),
+ SaveArg<1>(&done_task)));
+ AddMessage(kTestMessage2);
+
+ EXPECT_TRUE(CompareResult(buffer, kTestMessage2));
+
+ // Verify that the reader starts reading again only after we've
+ // finished processing the previous message.
+ EXPECT_FALSE(socket_.read_pending());
+
+ done_task->Run();
+
+ EXPECT_TRUE(socket_.read_pending());
+}
+
+} // namespace protocol
+} // namespace remoting
diff --git a/remoting/protocol/protobuf_video_reader.cc b/remoting/protocol/protobuf_video_reader.cc
index 02e06f8..a1dc2d2 100644
--- a/remoting/protocol/protobuf_video_reader.cc
+++ b/remoting/protocol/protobuf_video_reader.cc
@@ -25,8 +25,8 @@ void ProtobufVideoReader::Init(protocol::Session* session,
video_stub_ = video_stub;
}
-void ProtobufVideoReader::OnNewData(VideoPacket* packet) {
- video_stub_->ProcessVideoPacket(packet, new DeleteTask<VideoPacket>(packet));
+void ProtobufVideoReader::OnNewData(VideoPacket* packet, Task* done_task) {
+ video_stub_->ProcessVideoPacket(packet, done_task);
}
} // namespace protocol
diff --git a/remoting/protocol/protobuf_video_reader.h b/remoting/protocol/protobuf_video_reader.h
index 7305dd7..8e8ce42 100644
--- a/remoting/protocol/protobuf_video_reader.h
+++ b/remoting/protocol/protobuf_video_reader.h
@@ -23,7 +23,7 @@ class ProtobufVideoReader : public VideoReader {
virtual void Init(protocol::Session* session, VideoStub* video_stub);
private:
- void OnNewData(VideoPacket* packet);
+ void OnNewData(VideoPacket* packet, Task* done_task);
VideoPacketFormat::Encoding encoding_;
diff --git a/remoting/protocol/ref_counted_message.h b/remoting/protocol/ref_counted_message.h
deleted file mode 100644
index 24a8b52..0000000
--- a/remoting/protocol/ref_counted_message.h
+++ /dev/null
@@ -1,45 +0,0 @@
-// Copyright (c) 2010 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-// This is a wrapper class to help ref-counting a protobuf message.
-// This file should only be inclued on host_message_dispatcher.cc and
-// client_message_dispatche.cc.
-
-// A single protobuf can contain multiple messages that will be handled by
-// different message handlers. We use this wrapper to ensure that the
-// protobuf is only deleted after all the handlers have finished executing.
-
-#ifndef REMOTING_PROTOCOL_REF_COUNTED_MESSAGE_H_
-#define REMOTING_PROTOCOL_REF_COUNTED_MESSAGE_H_
-
-#include "base/ref_counted.h"
-#include "base/task.h"
-
-namespace remoting {
-namespace protocol {
-
-template <typename T>
-class RefCountedMessage : public base::RefCounted<RefCountedMessage<T> > {
- public:
- RefCountedMessage(T* message) : message_(message) { }
-
- T* message() { return message_.get(); }
-
- private:
- scoped_ptr<T> message_;
-};
-
-// Dummy methods to destroy messages.
-template <class T>
-static void DeleteMessage(scoped_refptr<T> message) { }
-
-template <class T>
-static Task* NewDeleteTask(scoped_refptr<T> message) {
- return NewRunnableFunction(&DeleteMessage<T>, message);
-}
-
-} // namespace protocol
-} // namespace remoting
-
-#endif // REMOTING_PROTOCOL_REF_COUNTED_MESSAGE_H_
diff --git a/remoting/remoting.gyp b/remoting/remoting.gyp
index 8b6b5f1..70dc0fd 100644
--- a/remoting/remoting.gyp
+++ b/remoting/remoting.gyp
@@ -486,6 +486,7 @@
'protocol/fake_session.h',
'protocol/jingle_session_unittest.cc',
'protocol/message_decoder_unittest.cc',
+ 'protocol/message_reader_unittest.cc',
'protocol/mock_objects.h',
'protocol/rtp_video_reader_unittest.cc',
'protocol/rtp_video_writer_unittest.cc',