summaryrefslogtreecommitdiffstats
path: root/remoting/protocol
diff options
context:
space:
mode:
authorsergeyu@chromium.org <sergeyu@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2012-08-23 21:37:04 +0000
committersergeyu@chromium.org <sergeyu@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2012-08-23 21:37:04 +0000
commiteeda031ebf73dab9fe55ea1c51bd6ac4f4e1f35c (patch)
tree62dde3058ac403dc7cff3d79810c4051027ec56f /remoting/protocol
parentf1e0be7d135642015c9eedadc33d8878c35dd544 (diff)
downloadchromium_src-eeda031ebf73dab9fe55ea1c51bd6ac4f4e1f35c.zip
chromium_src-eeda031ebf73dab9fe55ea1c51bd6ac4f4e1f35c.tar.gz
chromium_src-eeda031ebf73dab9fe55ea1c51bd6ac4f4e1f35c.tar.bz2
Make MessageReader class not ref-counted.
Previously MessageReader class was ref-counted, which means in some cases it could outlive the socket. Specifically the problem is when received messages are processed asynchronously and the callback for one of the messages is called after the socket is destroyed. BUG=139257 Review URL: https://chromiumcodereview.appspot.com/10870021 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@153079 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'remoting/protocol')
-rw-r--r--remoting/protocol/message_reader.cc32
-rw-r--r--remoting/protocol/message_reader.h28
-rw-r--r--remoting/protocol/message_reader_unittest.cc42
3 files changed, 33 insertions, 69 deletions
diff --git a/remoting/protocol/message_reader.cc b/remoting/protocol/message_reader.cc
index 441d729..fbbe9be 100644
--- a/remoting/protocol/message_reader.cc
+++ b/remoting/protocol/message_reader.cc
@@ -6,6 +6,7 @@
#include "base/bind.h"
#include "base/callback.h"
+#include "base/compiler_specific.h"
#include "base/location.h"
#include "base/thread_task_runner_handle.h"
#include "net/base/io_buffer.h"
@@ -23,11 +24,13 @@ MessageReader::MessageReader()
: socket_(NULL),
read_pending_(false),
pending_messages_(0),
- closed_(false) {
+ closed_(false),
+ ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) {
}
void MessageReader::Init(net::Socket* socket,
const MessageReceivedCallback& callback) {
+ DCHECK(CalledOnValidThread());
message_received_callback_ = callback;
DCHECK(socket);
socket_ = socket;
@@ -39,18 +42,20 @@ MessageReader::~MessageReader() {
}
void MessageReader::DoRead() {
+ DCHECK(CalledOnValidThread());
// Don't try to read again if there is another read pending or we
// have messages that we haven't finished processing yet.
while (!closed_ && !read_pending_ && pending_messages_ == 0) {
read_buffer_ = new net::IOBuffer(kReadBufferSize);
int result = socket_->Read(
- read_buffer_, kReadBufferSize, base::Bind(&MessageReader::OnRead,
- base::Unretained(this)));
+ read_buffer_, kReadBufferSize,
+ base::Bind(&MessageReader::OnRead, weak_factory_.GetWeakPtr()));
HandleReadResult(result);
}
}
void MessageReader::OnRead(int result) {
+ DCHECK(CalledOnValidThread());
DCHECK(read_pending_);
read_pending_ = false;
@@ -61,6 +66,7 @@ void MessageReader::OnRead(int result) {
}
void MessageReader::HandleReadResult(int result) {
+ DCHECK(CalledOnValidThread());
if (closed_)
return;
@@ -78,6 +84,7 @@ void MessageReader::HandleReadResult(int result) {
}
void MessageReader::OnDataReceived(net::IOBuffer* data, int data_size) {
+ DCHECK(CalledOnValidThread());
message_decoder_.AddData(data, data_size);
// Get list of all new messages first, and then call the callback
@@ -96,27 +103,18 @@ void MessageReader::OnDataReceived(net::IOBuffer* data, int data_size) {
it != new_messages.end(); ++it) {
message_received_callback_.Run(
scoped_ptr<CompoundBuffer>(*it),
- base::Bind(&MessageReader::OnMessageDone, this,
- base::ThreadTaskRunnerHandle::Get()));
+ base::Bind(&MessageReader::OnMessageDone,
+ weak_factory_.GetWeakPtr()));
}
}
-void MessageReader::OnMessageDone(
- scoped_refptr<base::SingleThreadTaskRunner> task_runner) {
- if (task_runner->BelongsToCurrentThread()) {
- ProcessDoneEvent();
- } else {
- task_runner->PostTask(
- FROM_HERE, base::Bind(&MessageReader::ProcessDoneEvent, this));
- }
-}
-
-void MessageReader::ProcessDoneEvent() {
+void MessageReader::OnMessageDone() {
+ DCHECK(CalledOnValidThread());
pending_messages_--;
DCHECK_GE(pending_messages_, 0);
if (!read_pending_)
- DoRead(); // Start next read if neccessary.
+ DoRead(); // Start next read if necessary.
}
} // namespace protocol
diff --git a/remoting/protocol/message_reader.h b/remoting/protocol/message_reader.h
index b9a373b..efba5e1 100644
--- a/remoting/protocol/message_reader.h
+++ b/remoting/protocol/message_reader.h
@@ -7,15 +7,13 @@
#include "base/bind.h"
#include "base/callback.h"
-#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
-#include "base/message_loop_proxy.h"
+#include "base/memory/weak_ptr.h"
+#include "base/threading/non_thread_safe.h"
#include "net/base/completion_callback.h"
#include "remoting/base/compound_buffer.h"
#include "remoting/protocol/message_decoder.h"
-class MessageLoop;
-
namespace net {
class IOBuffer;
class Socket;
@@ -35,27 +33,24 @@ namespace protocol {
// It is still possible that the MessageReceivedCallback is called
// twice (so that there is more than one outstanding message),
// e.g. when we the sender sends multiple messages in one TCP packet.
-class MessageReader : public base::RefCountedThreadSafe<MessageReader> {
+class MessageReader : public base::NonThreadSafe {
public:
typedef base::Callback<void(scoped_ptr<CompoundBuffer>, const base::Closure&)>
MessageReceivedCallback;
MessageReader();
+ virtual ~MessageReader();
// Initialize the MessageReader with a socket. If a message is received
// |callback| is called.
void Init(net::Socket* socket, const MessageReceivedCallback& callback);
private:
- friend class base::RefCountedThreadSafe<MessageReader>;
- virtual ~MessageReader();
-
void DoRead();
void OnRead(int result);
void HandleReadResult(int result);
void OnDataReceived(net::IOBuffer* data, int data_size);
- void OnMessageDone(scoped_refptr<base::SingleThreadTaskRunner> task_runner);
- void ProcessDoneEvent();
+ void OnMessageDone();
net::Socket* socket_;
@@ -75,6 +70,10 @@ class MessageReader : public base::RefCountedThreadSafe<MessageReader> {
// Callback is called when a message is received.
MessageReceivedCallback message_received_callback_;
+
+ base::WeakPtrFactory<MessageReader> weak_factory_;
+
+ DISALLOW_COPY_AND_ASSIGN(MessageReader);
};
// Version of MessageReader for protocol buffer messages, that parses
@@ -82,7 +81,10 @@ class MessageReader : public base::RefCountedThreadSafe<MessageReader> {
template <class T>
class ProtobufMessageReader {
public:
- typedef typename base::Callback<void(scoped_ptr<T>, const base::Closure&)>
+ // The callback that is called when a new message is received. |done_task|
+ // must be called by the callback when it's done processing the |message|.
+ typedef typename base::Callback<void(scoped_ptr<T> message,
+ const base::Closure& done_task)>
MessageReceivedCallback;
ProtobufMessageReader() { };
@@ -91,7 +93,7 @@ class ProtobufMessageReader {
void Init(net::Socket* socket, const MessageReceivedCallback& callback) {
DCHECK(!callback.is_null());
message_received_callback_ = callback;
- message_reader_ = new MessageReader();
+ message_reader_.reset(new MessageReader());
message_reader_->Init(
socket, base::Bind(&ProtobufMessageReader<T>::OnNewData,
base::Unretained(this)));
@@ -111,7 +113,7 @@ class ProtobufMessageReader {
}
}
- scoped_refptr<MessageReader> message_reader_;
+ scoped_ptr<MessageReader> message_reader_;
MessageReceivedCallback message_received_callback_;
};
diff --git a/remoting/protocol/message_reader_unittest.cc b/remoting/protocol/message_reader_unittest.cc
index 7867343..c789b96 100644
--- a/remoting/protocol/message_reader_unittest.cc
+++ b/remoting/protocol/message_reader_unittest.cc
@@ -10,7 +10,6 @@
#include "base/message_loop.h"
#include "base/stl_util.h"
#include "base/synchronization/waitable_event.h"
-#include "base/threading/thread.h"
#include "net/base/net_errors.h"
#include "net/socket/socket.h"
#include "remoting/protocol/fake_session.h"
@@ -44,20 +43,12 @@ class MockMessageReceivedCallback {
class MessageReaderTest : public testing::Test {
public:
MessageReaderTest()
- : other_thread_("SecondTestThread"),
- run_task_finished_(false, false) {
- }
-
- void RunDoneTaskOnOtherThread(const base::Closure& done_task) {
- other_thread_.message_loop()->PostTask(
- FROM_HERE,
- base::Bind(&MessageReaderTest::RunClosure,
- base::Unretained(this), done_task));
+ : run_task_finished_(false, false) {
}
protected:
virtual void SetUp() OVERRIDE {
- reader_ = new MessageReader();
+ reader_.reset(new MessageReader());
}
virtual void TearDown() OVERRIDE {
@@ -94,9 +85,8 @@ class MessageReaderTest : public testing::Test {
}
MessageLoop message_loop_;
- base::Thread other_thread_;
base::WaitableEvent run_task_finished_;
- scoped_refptr<MessageReader> reader_;
+ scoped_ptr<MessageReader> reader_;
FakeSocket socket_;
MockMessageReceivedCallback callback_;
std::vector<CompoundBuffer*> messages_;
@@ -263,32 +253,6 @@ TEST_F(MessageReaderTest, TwoMessages_Separately) {
EXPECT_TRUE(socket_.read_pending());
}
-// Verify that socket operations occur on same thread, even when the OnMessage()
-// callback triggers |done_task| to run on a different thread.
-TEST_F(MessageReaderTest, UseSocketOnCorrectThread) {
- AddMessage(kTestMessage1);
- other_thread_.Start();
-
- EXPECT_CALL(callback_, OnMessage(_))
- .WillOnce(Invoke(this, &MessageReaderTest::RunDoneTaskOnOtherThread));
-
- InitReader();
-
- run_task_finished_.Wait();
- message_loop_.RunAllPending();
-
- Mock::VerifyAndClearExpectations(&callback_);
-
- // Write another message and verify that we receive it.
- base::Closure done_task;
- EXPECT_CALL(callback_, OnMessage(_))
- .WillOnce(SaveArg<0>(&done_task));
- AddMessage(kTestMessage2);
- EXPECT_TRUE(CompareResult(messages_[1], kTestMessage2));
-
- done_task.Run();
-}
-
// Read() returns error.
TEST_F(MessageReaderTest, ReadError) {
socket_.set_next_read_error(net::ERR_FAILED);