diff options
author | sergeyu@chromium.org <sergeyu@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2012-08-23 21:37:04 +0000 |
---|---|---|
committer | sergeyu@chromium.org <sergeyu@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2012-08-23 21:37:04 +0000 |
commit | eeda031ebf73dab9fe55ea1c51bd6ac4f4e1f35c (patch) | |
tree | 62dde3058ac403dc7cff3d79810c4051027ec56f /remoting/protocol | |
parent | f1e0be7d135642015c9eedadc33d8878c35dd544 (diff) | |
download | chromium_src-eeda031ebf73dab9fe55ea1c51bd6ac4f4e1f35c.zip chromium_src-eeda031ebf73dab9fe55ea1c51bd6ac4f4e1f35c.tar.gz chromium_src-eeda031ebf73dab9fe55ea1c51bd6ac4f4e1f35c.tar.bz2 |
Make MessageReader class not ref-counted.
Previously MessageReader class was ref-counted, which means in some cases
it could outlive the socket. Specifically the problem is when received messages
are processed asynchronously and the callback for one of the messages is called
after the socket is destroyed.
BUG=139257
Review URL: https://chromiumcodereview.appspot.com/10870021
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@153079 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'remoting/protocol')
-rw-r--r-- | remoting/protocol/message_reader.cc | 32 | ||||
-rw-r--r-- | remoting/protocol/message_reader.h | 28 | ||||
-rw-r--r-- | remoting/protocol/message_reader_unittest.cc | 42 |
3 files changed, 33 insertions, 69 deletions
diff --git a/remoting/protocol/message_reader.cc b/remoting/protocol/message_reader.cc index 441d729..fbbe9be 100644 --- a/remoting/protocol/message_reader.cc +++ b/remoting/protocol/message_reader.cc @@ -6,6 +6,7 @@ #include "base/bind.h" #include "base/callback.h" +#include "base/compiler_specific.h" #include "base/location.h" #include "base/thread_task_runner_handle.h" #include "net/base/io_buffer.h" @@ -23,11 +24,13 @@ MessageReader::MessageReader() : socket_(NULL), read_pending_(false), pending_messages_(0), - closed_(false) { + closed_(false), + ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { } void MessageReader::Init(net::Socket* socket, const MessageReceivedCallback& callback) { + DCHECK(CalledOnValidThread()); message_received_callback_ = callback; DCHECK(socket); socket_ = socket; @@ -39,18 +42,20 @@ MessageReader::~MessageReader() { } void MessageReader::DoRead() { + DCHECK(CalledOnValidThread()); // Don't try to read again if there is another read pending or we // have messages that we haven't finished processing yet. while (!closed_ && !read_pending_ && pending_messages_ == 0) { read_buffer_ = new net::IOBuffer(kReadBufferSize); int result = socket_->Read( - read_buffer_, kReadBufferSize, base::Bind(&MessageReader::OnRead, - base::Unretained(this))); + read_buffer_, kReadBufferSize, + base::Bind(&MessageReader::OnRead, weak_factory_.GetWeakPtr())); HandleReadResult(result); } } void MessageReader::OnRead(int result) { + DCHECK(CalledOnValidThread()); DCHECK(read_pending_); read_pending_ = false; @@ -61,6 +66,7 @@ void MessageReader::OnRead(int result) { } void MessageReader::HandleReadResult(int result) { + DCHECK(CalledOnValidThread()); if (closed_) return; @@ -78,6 +84,7 @@ void MessageReader::HandleReadResult(int result) { } void MessageReader::OnDataReceived(net::IOBuffer* data, int data_size) { + DCHECK(CalledOnValidThread()); message_decoder_.AddData(data, data_size); // Get list of all new messages first, and then call the callback @@ -96,27 +103,18 @@ void MessageReader::OnDataReceived(net::IOBuffer* data, int data_size) { it != new_messages.end(); ++it) { message_received_callback_.Run( scoped_ptr<CompoundBuffer>(*it), - base::Bind(&MessageReader::OnMessageDone, this, - base::ThreadTaskRunnerHandle::Get())); + base::Bind(&MessageReader::OnMessageDone, + weak_factory_.GetWeakPtr())); } } -void MessageReader::OnMessageDone( - scoped_refptr<base::SingleThreadTaskRunner> task_runner) { - if (task_runner->BelongsToCurrentThread()) { - ProcessDoneEvent(); - } else { - task_runner->PostTask( - FROM_HERE, base::Bind(&MessageReader::ProcessDoneEvent, this)); - } -} - -void MessageReader::ProcessDoneEvent() { +void MessageReader::OnMessageDone() { + DCHECK(CalledOnValidThread()); pending_messages_--; DCHECK_GE(pending_messages_, 0); if (!read_pending_) - DoRead(); // Start next read if neccessary. + DoRead(); // Start next read if necessary. } } // namespace protocol diff --git a/remoting/protocol/message_reader.h b/remoting/protocol/message_reader.h index b9a373b..efba5e1 100644 --- a/remoting/protocol/message_reader.h +++ b/remoting/protocol/message_reader.h @@ -7,15 +7,13 @@ #include "base/bind.h" #include "base/callback.h" -#include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" -#include "base/message_loop_proxy.h" +#include "base/memory/weak_ptr.h" +#include "base/threading/non_thread_safe.h" #include "net/base/completion_callback.h" #include "remoting/base/compound_buffer.h" #include "remoting/protocol/message_decoder.h" -class MessageLoop; - namespace net { class IOBuffer; class Socket; @@ -35,27 +33,24 @@ namespace protocol { // It is still possible that the MessageReceivedCallback is called // twice (so that there is more than one outstanding message), // e.g. when we the sender sends multiple messages in one TCP packet. -class MessageReader : public base::RefCountedThreadSafe<MessageReader> { +class MessageReader : public base::NonThreadSafe { public: typedef base::Callback<void(scoped_ptr<CompoundBuffer>, const base::Closure&)> MessageReceivedCallback; MessageReader(); + virtual ~MessageReader(); // Initialize the MessageReader with a socket. If a message is received // |callback| is called. void Init(net::Socket* socket, const MessageReceivedCallback& callback); private: - friend class base::RefCountedThreadSafe<MessageReader>; - virtual ~MessageReader(); - void DoRead(); void OnRead(int result); void HandleReadResult(int result); void OnDataReceived(net::IOBuffer* data, int data_size); - void OnMessageDone(scoped_refptr<base::SingleThreadTaskRunner> task_runner); - void ProcessDoneEvent(); + void OnMessageDone(); net::Socket* socket_; @@ -75,6 +70,10 @@ class MessageReader : public base::RefCountedThreadSafe<MessageReader> { // Callback is called when a message is received. MessageReceivedCallback message_received_callback_; + + base::WeakPtrFactory<MessageReader> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(MessageReader); }; // Version of MessageReader for protocol buffer messages, that parses @@ -82,7 +81,10 @@ class MessageReader : public base::RefCountedThreadSafe<MessageReader> { template <class T> class ProtobufMessageReader { public: - typedef typename base::Callback<void(scoped_ptr<T>, const base::Closure&)> + // The callback that is called when a new message is received. |done_task| + // must be called by the callback when it's done processing the |message|. + typedef typename base::Callback<void(scoped_ptr<T> message, + const base::Closure& done_task)> MessageReceivedCallback; ProtobufMessageReader() { }; @@ -91,7 +93,7 @@ class ProtobufMessageReader { void Init(net::Socket* socket, const MessageReceivedCallback& callback) { DCHECK(!callback.is_null()); message_received_callback_ = callback; - message_reader_ = new MessageReader(); + message_reader_.reset(new MessageReader()); message_reader_->Init( socket, base::Bind(&ProtobufMessageReader<T>::OnNewData, base::Unretained(this))); @@ -111,7 +113,7 @@ class ProtobufMessageReader { } } - scoped_refptr<MessageReader> message_reader_; + scoped_ptr<MessageReader> message_reader_; MessageReceivedCallback message_received_callback_; }; diff --git a/remoting/protocol/message_reader_unittest.cc b/remoting/protocol/message_reader_unittest.cc index 7867343..c789b96 100644 --- a/remoting/protocol/message_reader_unittest.cc +++ b/remoting/protocol/message_reader_unittest.cc @@ -10,7 +10,6 @@ #include "base/message_loop.h" #include "base/stl_util.h" #include "base/synchronization/waitable_event.h" -#include "base/threading/thread.h" #include "net/base/net_errors.h" #include "net/socket/socket.h" #include "remoting/protocol/fake_session.h" @@ -44,20 +43,12 @@ class MockMessageReceivedCallback { class MessageReaderTest : public testing::Test { public: MessageReaderTest() - : other_thread_("SecondTestThread"), - run_task_finished_(false, false) { - } - - void RunDoneTaskOnOtherThread(const base::Closure& done_task) { - other_thread_.message_loop()->PostTask( - FROM_HERE, - base::Bind(&MessageReaderTest::RunClosure, - base::Unretained(this), done_task)); + : run_task_finished_(false, false) { } protected: virtual void SetUp() OVERRIDE { - reader_ = new MessageReader(); + reader_.reset(new MessageReader()); } virtual void TearDown() OVERRIDE { @@ -94,9 +85,8 @@ class MessageReaderTest : public testing::Test { } MessageLoop message_loop_; - base::Thread other_thread_; base::WaitableEvent run_task_finished_; - scoped_refptr<MessageReader> reader_; + scoped_ptr<MessageReader> reader_; FakeSocket socket_; MockMessageReceivedCallback callback_; std::vector<CompoundBuffer*> messages_; @@ -263,32 +253,6 @@ TEST_F(MessageReaderTest, TwoMessages_Separately) { EXPECT_TRUE(socket_.read_pending()); } -// Verify that socket operations occur on same thread, even when the OnMessage() -// callback triggers |done_task| to run on a different thread. -TEST_F(MessageReaderTest, UseSocketOnCorrectThread) { - AddMessage(kTestMessage1); - other_thread_.Start(); - - EXPECT_CALL(callback_, OnMessage(_)) - .WillOnce(Invoke(this, &MessageReaderTest::RunDoneTaskOnOtherThread)); - - InitReader(); - - run_task_finished_.Wait(); - message_loop_.RunAllPending(); - - Mock::VerifyAndClearExpectations(&callback_); - - // Write another message and verify that we receive it. - base::Closure done_task; - EXPECT_CALL(callback_, OnMessage(_)) - .WillOnce(SaveArg<0>(&done_task)); - AddMessage(kTestMessage2); - EXPECT_TRUE(CompareResult(messages_[1], kTestMessage2)); - - done_task.Run(); -} - // Read() returns error. TEST_F(MessageReaderTest, ReadError) { socket_.set_next_read_error(net::ERR_FAILED); |