diff options
-rw-r--r-- | remoting/protocol/fake_session.cc | 18 | ||||
-rw-r--r-- | remoting/protocol/fake_session.h | 4 | ||||
-rw-r--r-- | remoting/protocol/message_reader.cc | 19 | ||||
-rw-r--r-- | remoting/protocol/message_reader.h | 4 | ||||
-rw-r--r-- | remoting/protocol/message_reader_unittest.cc | 54 |
5 files changed, 92 insertions, 7 deletions
diff --git a/remoting/protocol/fake_session.cc b/remoting/protocol/fake_session.cc index 6579208..822cfea 100644 --- a/remoting/protocol/fake_session.cc +++ b/remoting/protocol/fake_session.cc @@ -7,6 +7,7 @@ #include "base/message_loop.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" +#include "testing/gtest/include/gtest/gtest.h" namespace remoting { namespace protocol { @@ -15,13 +16,16 @@ const char kTestJid[] = "host1@gmail.com/chromoting123"; FakeSocket::FakeSocket() : read_pending_(false), - input_pos_(0) { + input_pos_(0), + message_loop_(MessageLoop::current()) { } FakeSocket::~FakeSocket() { + EXPECT_EQ(message_loop_, MessageLoop::current()); } void FakeSocket::AppendInputData(const char* data, int data_size) { + EXPECT_EQ(message_loop_, MessageLoop::current()); input_data_.insert(input_data_.end(), data, data + data_size); // Complete pending read if any. if (read_pending_) { @@ -39,6 +43,7 @@ void FakeSocket::AppendInputData(const char* data, int data_size) { int FakeSocket::Read(net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) { + EXPECT_EQ(message_loop_, MessageLoop::current()); if (input_pos_ < static_cast<int>(input_data_.size())) { int result = std::min(buf_len, static_cast<int>(input_data_.size()) - input_pos_); @@ -56,6 +61,7 @@ int FakeSocket::Read(net::IOBuffer* buf, int buf_len, int FakeSocket::Write(net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) { + EXPECT_EQ(message_loop_, MessageLoop::current()); written_data_.insert(written_data_.end(), buf->data(), buf->data() + buf_len); return buf_len; @@ -71,6 +77,7 @@ bool FakeSocket::SetSendBufferSize(int32 size) { } int FakeSocket::Connect(net::CompletionCallback* callback) { + EXPECT_EQ(message_loop_, MessageLoop::current()); return net::OK; } @@ -79,6 +86,7 @@ void FakeSocket::Disconnect() { } bool FakeSocket::IsConnected() const { + EXPECT_EQ(message_loop_, MessageLoop::current()); return true; } @@ -100,6 +108,7 @@ int FakeSocket::GetLocalAddress( } const net::BoundNetLog& FakeSocket::NetLog() const { + EXPECT_EQ(message_loop_, MessageLoop::current()); return net_log_; } @@ -133,13 +142,16 @@ base::TimeDelta FakeSocket::GetConnectTimeMicros() const { FakeUdpSocket::FakeUdpSocket() : read_pending_(false), - input_pos_(0) { + input_pos_(0), + message_loop_(MessageLoop::current()) { } FakeUdpSocket::~FakeUdpSocket() { + EXPECT_EQ(message_loop_, MessageLoop::current()); } void FakeUdpSocket::AppendInputPacket(const char* data, int data_size) { + EXPECT_EQ(message_loop_, MessageLoop::current()); input_packets_.push_back(std::string()); input_packets_.back().assign(data, data + data_size); @@ -156,6 +168,7 @@ void FakeUdpSocket::AppendInputPacket(const char* data, int data_size) { int FakeUdpSocket::Read(net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) { + EXPECT_EQ(message_loop_, MessageLoop::current()); if (input_pos_ < static_cast<int>(input_packets_.size())) { int result = std::min( buf_len, static_cast<int>(input_packets_[input_pos_].size())); @@ -173,6 +186,7 @@ int FakeUdpSocket::Read(net::IOBuffer* buf, int buf_len, int FakeUdpSocket::Write(net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) { + EXPECT_EQ(message_loop_, MessageLoop::current()); written_packets_.push_back(std::string()); written_packets_.back().assign(buf->data(), buf->data() + buf_len); return buf_len; diff --git a/remoting/protocol/fake_session.h b/remoting/protocol/fake_session.h index e89a208..13409d0 100644 --- a/remoting/protocol/fake_session.h +++ b/remoting/protocol/fake_session.h @@ -73,6 +73,8 @@ class FakeSocket : public net::StreamSocket { net::BoundNetLog net_log_; + MessageLoop* message_loop_; + DISALLOW_COPY_AND_ASSIGN(FakeSocket); }; @@ -110,6 +112,8 @@ class FakeUdpSocket : public net::Socket { std::vector<std::string> input_packets_; int input_pos_; + MessageLoop* message_loop_; + DISALLOW_COPY_AND_ASSIGN(FakeUdpSocket); }; diff --git a/remoting/protocol/message_reader.cc b/remoting/protocol/message_reader.cc index e5825b1..be9c455 100644 --- a/remoting/protocol/message_reader.cc +++ b/remoting/protocol/message_reader.cc @@ -1,9 +1,11 @@ -// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Copyright (c) 2011 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "remoting/protocol/message_reader.h" +#include "base/bind.h" +#include "base/callback.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/socket/socket.h" @@ -89,14 +91,25 @@ void MessageReader::OnDataReceived(net::IOBuffer* data, int data_size) { pending_messages_ += new_messages.size(); + // TODO(lambroslambrou): MessageLoopProxy::current() will not work from the + // plugin thread if this code is compiled into a separate binary. Fix this. for (std::vector<CompoundBuffer*>::iterator it = new_messages.begin(); it != new_messages.end(); ++it) { message_received_callback_->Run(*it, NewRunnableMethod( - this, &MessageReader::OnMessageDone, *it)); + this, &MessageReader::OnMessageDone, *it, + base::MessageLoopProxy::current())); } } -void MessageReader::OnMessageDone(CompoundBuffer* message) { +void MessageReader::OnMessageDone( + CompoundBuffer* message, + scoped_refptr<base::MessageLoopProxy> message_loop) { + if (!message_loop->BelongsToCurrentThread()) { + message_loop->PostTask( + FROM_HERE, + base::Bind(&MessageReader::OnMessageDone, this, message, message_loop)); + return; + } delete message; ProcessDoneEvent(); } diff --git a/remoting/protocol/message_reader.h b/remoting/protocol/message_reader.h index bf8aae4..c8d214c 100644 --- a/remoting/protocol/message_reader.h +++ b/remoting/protocol/message_reader.h @@ -8,6 +8,7 @@ #include "base/callback.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" +#include "base/message_loop_proxy.h" #include "base/task.h" #include "net/base/completion_callback.h" #include "remoting/base/compound_buffer.h" @@ -54,7 +55,8 @@ class MessageReader : public base::RefCountedThreadSafe<MessageReader> { void OnRead(int result); void HandleReadResult(int result); void OnDataReceived(net::IOBuffer* data, int data_size); - void OnMessageDone(CompoundBuffer* message); + void OnMessageDone(CompoundBuffer* message, + scoped_refptr<base::MessageLoopProxy> message_loop); void ProcessDoneEvent(); net::Socket* socket_; diff --git a/remoting/protocol/message_reader_unittest.cc b/remoting/protocol/message_reader_unittest.cc index 84e2e31..9aba90a 100644 --- a/remoting/protocol/message_reader_unittest.cc +++ b/remoting/protocol/message_reader_unittest.cc @@ -1,9 +1,14 @@ -// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Copyright (c) 2011 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include <string> +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/message_loop.h" +#include "base/synchronization/waitable_event.h" +#include "base/threading/thread.h" #include "net/socket/socket.h" #include "remoting/protocol/fake_session.h" #include "remoting/protocol/message_reader.h" @@ -35,6 +40,20 @@ class MockMessageReceivedCallback { }; class MessageReaderTest : public testing::Test { + public: + MessageReaderTest() + : other_thread_("SecondTestThread"), + run_task_finished_(false, false) { + } + + void RunDoneTaskOnOtherThread(CompoundBuffer* buffer, Task* done_task) { + other_thread_.message_loop()->PostTask( + FROM_HERE, + base::Bind(&MessageReaderTest::RunAndDeleteTask, + base::Unretained(this), + done_task)); + } + protected: virtual void SetUp() { reader_ = new MessageReader(); @@ -61,8 +80,12 @@ class MessageReaderTest : public testing::Test { void RunAndDeleteTask(Task* task) { task->Run(); delete task; + run_task_finished_.Signal(); } + MessageLoop message_loop_; + base::Thread other_thread_; + base::WaitableEvent run_task_finished_; scoped_refptr<MessageReader> reader_; FakeSocket socket_; MockMessageReceivedCallback callback_; @@ -239,5 +262,34 @@ TEST_F(MessageReaderTest, TwoMessages_Separately) { EXPECT_TRUE(socket_.read_pending()); } +// Verify that socket operations occur on same thread, even when the OnMessage() +// callback triggers |done_task| to run on a different thread. +TEST_F(MessageReaderTest, UseSocketOnCorrectThread) { + + AddMessage(kTestMessage1); + other_thread_.Start(); + + EXPECT_CALL(callback_, OnMessage(_, _)) + .Times(1) + .WillOnce(Invoke(this, &MessageReaderTest::RunDoneTaskOnOtherThread)); + + InitReader(); + + run_task_finished_.Wait(); + message_loop_.RunAllPending(); + + // Write another message and verify that we receive it. + CompoundBuffer* buffer; + Task* done_task; + EXPECT_CALL(callback_, OnMessage(_, _)) + .Times(1) + .WillOnce(DoAll(SaveArg<0>(&buffer), + SaveArg<1>(&done_task))); + AddMessage(kTestMessage2); + EXPECT_TRUE(CompareResult(buffer, kTestMessage2)); + + RunAndDeleteTask(done_task); +} + } // namespace protocol } // namespace remoting |