diff options
-rw-r--r-- | chrome/browser/extensions/api/socket/combined_socket_unittest.cc | 27 | ||||
-rw-r--r-- | extensions/browser/api/socket/tcp_socket.cc | 18 |
2 files changed, 36 insertions, 9 deletions
diff --git a/chrome/browser/extensions/api/socket/combined_socket_unittest.cc b/chrome/browser/extensions/api/socket/combined_socket_unittest.cc index d6ae864..22cdebb 100644 --- a/chrome/browser/extensions/api/socket/combined_socket_unittest.cc +++ b/chrome/browser/extensions/api/socket/combined_socket_unittest.cc @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include "base/memory/scoped_ptr.h" #include "chrome/browser/extensions/api/socket/mock_tcp_client_socket.h" #include "extensions/browser/api/socket/socket.h" #include "extensions/browser/api/socket/tcp_socket.h" @@ -14,6 +15,20 @@ namespace extensions { const int kBufferLength = 10; +template <typename T> +scoped_ptr<T> CreateTestSocket(scoped_ptr<MockTCPClientSocket> stream); + +template <> +scoped_ptr<TCPSocket> CreateTestSocket(scoped_ptr<MockTCPClientSocket> stream) { + return make_scoped_ptr(new TCPSocket(std::move(stream), "fake id", + true /* is_connected */)); +} + +template <> +scoped_ptr<TLSSocket> CreateTestSocket(scoped_ptr<MockTCPClientSocket> stream) { + return make_scoped_ptr(new TLSSocket(std::move(stream), "fake id")); +} + class CombinedSocketTest : public testing::Test { public: CombinedSocketTest() : count_(0), io_buffer_(nullptr) {} @@ -30,10 +45,10 @@ class CombinedSocketTest : public testing::Test { testing::Return(kBufferLength))); EXPECT_CALL(*stream, Disconnect()); - T socket(std::move(stream), "fake id"); + scoped_ptr<T> socket = CreateTestSocket<T>(std::move(stream)); ReadCompletionCallback read_callback = base::Bind(&CombinedSocketTest::OnRead, base::Unretained(this)); - socket.Read(kBufferLength, read_callback); + socket->Read(kBufferLength, read_callback); EXPECT_EQ(kBufferLength, count_); EXPECT_NE(nullptr, buffer); EXPECT_EQ(buffer, io_buffer_); @@ -53,10 +68,10 @@ class CombinedSocketTest : public testing::Test { testing::Return(net::ERR_IO_PENDING))); EXPECT_CALL(*stream, Disconnect()); - T socket(std::move(stream), "fake id"); + scoped_ptr<T> socket = CreateTestSocket<T>(std::move(stream)); ReadCompletionCallback read_callback = base::Bind(&CombinedSocketTest::OnRead, base::Unretained(this)); - socket.Read(kBufferLength, read_callback); + socket->Read(kBufferLength, read_callback); EXPECT_EQ(0, count_); EXPECT_EQ(nullptr, io_buffer_); socket_cb.Run(kBufferLength); @@ -78,10 +93,10 @@ class CombinedSocketTest : public testing::Test { ON_CALL(*stream, IsConnected()).WillByDefault(testing::Return(false)); EXPECT_CALL(*stream, Disconnect()); - T socket(std::move(stream), "fake id"); + scoped_ptr<T> socket = CreateTestSocket<T>(std::move(stream)); ReadCompletionCallback read_callback = base::Bind(&CombinedSocketTest::OnRead, base::Unretained(this)); - socket.Read(kBufferLength, read_callback); + socket->Read(kBufferLength, read_callback); EXPECT_EQ(kBufferLength, count_); EXPECT_NE(nullptr, buffer); EXPECT_EQ(buffer, io_buffer_); diff --git a/extensions/browser/api/socket/tcp_socket.cc b/extensions/browser/api/socket/tcp_socket.cc index 2a8e718..a69ed02 100644 --- a/extensions/browser/api/socket/tcp_socket.cc +++ b/extensions/browser/api/socket/tcp_socket.cc @@ -86,6 +86,12 @@ void TCPSocket::Connect(const net::AddressList& address, callback.Run(net::ERR_CONNECTION_FAILED); return; } + + if (is_connected_) { + callback.Run(net::ERR_SOCKET_IS_CONNECTED); + return; + } + DCHECK(!server_socket_.get()); socket_mode_ = CLIENT; connect_callback_ = callback; @@ -130,7 +136,9 @@ void TCPSocket::Read(int count, const ReadCompletionCallback& callback) { return; } - if (!read_callback_.is_null()) { + if (!read_callback_.is_null() || !connect_callback_.is_null()) { + // It's illegal to read a net::TCPSocket while a pending Connect or Read is + // already in progress. callback.Run(net::ERR_IO_PENDING, NULL); return; } @@ -140,7 +148,7 @@ void TCPSocket::Read(int count, const ReadCompletionCallback& callback) { return; } - if (!socket_.get()) { + if (!socket_.get() || !is_connected_) { callback.Run(net::ERR_SOCKET_NOT_CONNECTED, NULL); return; } @@ -277,8 +285,12 @@ void TCPSocket::OnConnectComplete(int result) { DCHECK(!connect_callback_.is_null()); DCHECK(!is_connected_); is_connected_ = result == net::OK; - connect_callback_.Run(result); + + // The completion callback may re-enter TCPSocket, e.g. to Read(); therefore + // we reset |connect_callback_| before calling it. + CompletionCallback connect_callback = connect_callback_; connect_callback_.Reset(); + connect_callback.Run(result); } void TCPSocket::OnReadComplete(scoped_refptr<net::IOBuffer> io_buffer, |