diff options
Diffstat (limited to 'net/base/client_socket_pool_unittest.cc')
-rw-r--r-- | net/base/client_socket_pool_unittest.cc | 274 |
1 files changed, 186 insertions, 88 deletions
diff --git a/net/base/client_socket_pool_unittest.cc b/net/base/client_socket_pool_unittest.cc index a889d9a..73fc0e3 100644 --- a/net/base/client_socket_pool_unittest.cc +++ b/net/base/client_socket_pool_unittest.cc @@ -4,11 +4,16 @@ #include "base/message_loop.h" #include "net/base/client_socket.h" +#include "net/base/client_socket_factory.h" #include "net/base/client_socket_handle.h" #include "net/base/client_socket_pool.h" +#include "net/base/host_resolver_unittest.h" #include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" #include "testing/gtest/include/gtest/gtest.h" +namespace net { + namespace { const int kMaxSocketsPerGroup = 6; @@ -21,16 +26,16 @@ const int kPriorities[10] = { 1, 7, 9, 5, 6, 2, 8, 3, 4, 1 }; // available sockets in the socket group. const int kNumPendingRequests = arraysize(kPriorities); -class MockClientSocket : public net::ClientSocket { +const int kNumRequests = kMaxSocketsPerGroup + kNumPendingRequests; + +class MockClientSocket : public ClientSocket { public: - MockClientSocket() : connected_(false) { - allocation_count++; - } + MockClientSocket() : connected_(false) {} // ClientSocket methods: - virtual int Connect(net::CompletionCallback* callback) { + virtual int Connect(CompletionCallback* callback) { connected_ = true; - return net::OK; + return OK; } virtual void Disconnect() { connected_ = false; @@ -43,51 +48,115 @@ class MockClientSocket : public net::ClientSocket { } // Socket methods: - virtual int Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) { - return net::ERR_FAILED; + virtual int Read(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + return ERR_FAILED; } - virtual int Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) { - return net::ERR_FAILED; + virtual int Write(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + return ERR_FAILED; } - static int allocation_count; - private: bool connected_; }; -int MockClientSocket::allocation_count = 0; +class MockFailingClientSocket : public ClientSocket { + public: + MockFailingClientSocket() {} + + // ClientSocket methods: + virtual int Connect(CompletionCallback* callback) { + return ERR_CONNECTION_FAILED; + } + + virtual void Disconnect() {} + + virtual bool IsConnected() const { + return false; + } + virtual bool IsConnectedAndIdle() const { + return false; + } + + // Socket methods: + virtual int Read(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + return ERR_FAILED; + } + + virtual int Write(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + return ERR_FAILED; + } +}; + +class MockClientSocketFactory : public ClientSocketFactory { + public: + enum ClientSocketType { + MOCK_CLIENT_SOCKET, + MOCK_FAILING_CLIENT_SOCKET, + }; + + MockClientSocketFactory() + : allocation_count_(0), client_socket_type_(MOCK_CLIENT_SOCKET) {} + + virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses) { + allocation_count_++; + switch (client_socket_type_) { + case MOCK_CLIENT_SOCKET: + return new MockClientSocket(); + case MOCK_FAILING_CLIENT_SOCKET: + return new MockFailingClientSocket(); + default: + NOTREACHED(); + return new MockClientSocket(); + } + } + + virtual SSLClientSocket* CreateSSLClientSocket( + ClientSocket* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config) { + NOTIMPLEMENTED(); + return NULL; + } + + int allocation_count() const { return allocation_count_; } + + void set_client_socket_type(ClientSocketType type) { + client_socket_type_ = type; + } + + private: + int allocation_count_; + ClientSocketType client_socket_type_; +}; class TestSocketRequest : public CallbackRunner< Tuple1<int> > { public: TestSocketRequest( - net::ClientSocketPool* pool, + ClientSocketPool* pool, std::vector<TestSocketRequest*>* request_order) : handle(pool), request_order_(request_order) {} - net::ClientSocketHandle handle; + ClientSocketHandle handle; - void EnsureSocket() { - DCHECK(handle.is_initialized()); - request_order_->push_back(this); - if (!handle.socket()) { - handle.set_socket(new MockClientSocket()); - handle.socket()->Connect(NULL); - } + int WaitForResult() { + return callback_.WaitForResult(); } virtual void RunWithParams(const Tuple1<int>& params) { - DCHECK(params.a == net::OK); + callback_.RunWithParams(params); completion_count++; - EnsureSocket(); + request_order_->push_back(this); } static int completion_count; private: std::vector<TestSocketRequest*>* request_order_; + TestCompletionCallback callback_; }; int TestSocketRequest::completion_count = 0; @@ -95,68 +164,76 @@ int TestSocketRequest::completion_count = 0; class ClientSocketPoolTest : public testing::Test { protected: ClientSocketPoolTest() - : pool_(new net::ClientSocketPool(kMaxSocketsPerGroup)) {} + : pool_(new ClientSocketPool(kMaxSocketsPerGroup, + &client_socket_factory_)) {} virtual void SetUp() { - MockClientSocket::allocation_count = 0; TestSocketRequest::completion_count = 0; } - scoped_refptr<net::ClientSocketPool> pool_; + MockClientSocketFactory client_socket_factory_; + scoped_refptr<ClientSocketPool> pool_; std::vector<TestSocketRequest*> request_order_; }; TEST_F(ClientSocketPoolTest, Basic) { - TestSocketRequest r(pool_.get(), &request_order_); - int rv; + TestCompletionCallback callback; + ClientSocketHandle handle(pool_.get()); + int rv = handle.Init("a", "www.google.com", 80, 0, &callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); - rv = r.handle.Init("a", 0, &r); - EXPECT_EQ(net::OK, rv); - EXPECT_TRUE(r.handle.is_initialized()); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); - r.handle.Reset(); + handle.Reset(); // The handle's Reset method may have posted a task. MessageLoop::current()->RunAllPending(); } -TEST_F(ClientSocketPoolTest, WithIdleConnection) { - TestSocketRequest r(pool_.get(), &request_order_); - int rv; - - rv = r.handle.Init("a", 0, &r); - EXPECT_EQ(net::OK, rv); - EXPECT_TRUE(r.handle.is_initialized()); - - // Create a socket. - r.EnsureSocket(); - - // Release the socket. It should find its way into the idle list. We're - // testing that this does not trigger a crash. - r.handle.Reset(); +TEST_F(ClientSocketPoolTest, InitHostResolutionFailure) { + RuleBasedHostMapper* host_mapper = new RuleBasedHostMapper; + host_mapper->AddSimulatedFailure("unresolvable.host.name"); + ScopedHostMapper scoped_host_mapper(host_mapper); + TestSocketRequest req(pool_.get(), &request_order_); + EXPECT_EQ(ERR_IO_PENDING, + req.handle.Init("a", "unresolvable.host.name", 80, 5, &req)); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req.WaitForResult()); +} - // The handle's Reset method may have posted a task. - MessageLoop::current()->RunAllPending(); +TEST_F(ClientSocketPoolTest, InitConnectionFailure) { + client_socket_factory_.set_client_socket_type( + MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET); + TestSocketRequest req(pool_.get(), &request_order_); + EXPECT_EQ(ERR_IO_PENDING, + req.handle.Init("a", "unresolvable.host.name", 80, 5, &req)); + EXPECT_EQ(ERR_CONNECTION_FAILED, req.WaitForResult()); } TEST_F(ClientSocketPoolTest, PendingRequests) { int rv; - scoped_ptr<TestSocketRequest> reqs[kMaxSocketsPerGroup + kNumPendingRequests]; + scoped_ptr<TestSocketRequest> reqs[kNumRequests]; for (size_t i = 0; i < arraysize(reqs); ++i) reqs[i].reset(new TestSocketRequest(pool_.get(), &request_order_)); // Create connections or queue up requests. for (int i = 0; i < kMaxSocketsPerGroup; ++i) { - rv = reqs[i]->handle.Init("a", 5, reqs[i].get()); - EXPECT_EQ(net::OK, rv); - reqs[i]->EnsureSocket(); + EXPECT_EQ( + ERR_IO_PENDING, + reqs[i]->handle.Init("a", "www.google.com", 80, 5, reqs[i].get())); + EXPECT_EQ(OK, reqs[i]->WaitForResult()); } + for (int i = 0; i < kNumPendingRequests; ++i) { rv = reqs[kMaxSocketsPerGroup + i]->handle.Init( - "a", kPriorities[i], reqs[kMaxSocketsPerGroup + i].get()); - EXPECT_EQ(net::ERR_IO_PENDING, rv); + "a", "www.google.com", 80, kPriorities[i], + reqs[kMaxSocketsPerGroup + i].get()); + EXPECT_EQ(ERR_IO_PENDING, rv); } // Release any connections until we have no connections. @@ -172,8 +249,8 @@ TEST_F(ClientSocketPoolTest, PendingRequests) { } } while (released_one); - EXPECT_EQ(kMaxSocketsPerGroup, MockClientSocket::allocation_count); - EXPECT_EQ(kNumPendingRequests, TestSocketRequest::completion_count); + EXPECT_EQ(kMaxSocketsPerGroup, client_socket_factory_.allocation_count()); + EXPECT_EQ(kNumRequests, TestSocketRequest::completion_count); for (int i = 0; i < kMaxSocketsPerGroup; ++i) { EXPECT_EQ(request_order_[i], reqs[i].get()) << @@ -194,58 +271,77 @@ TEST_F(ClientSocketPoolTest, PendingRequests) { } TEST_F(ClientSocketPoolTest, PendingRequests_NoKeepAlive) { - int rv; - - scoped_ptr<TestSocketRequest> reqs[kMaxSocketsPerGroup + kNumPendingRequests]; + scoped_ptr<TestSocketRequest> reqs[kNumRequests]; for (size_t i = 0; i < arraysize(reqs); ++i) reqs[i].reset(new TestSocketRequest(pool_.get(), &request_order_)); // Create connections or queue up requests. - for (size_t i = 0; i < arraysize(reqs); ++i) { - rv = reqs[i]->handle.Init("a", 0, reqs[i].get()); - if (rv != net::ERR_IO_PENDING) { - EXPECT_EQ(net::OK, rv); - reqs[i]->EnsureSocket(); - } + for (int i = 0; i < kMaxSocketsPerGroup; ++i) { + EXPECT_EQ( + ERR_IO_PENDING, + reqs[i]->handle.Init("a", "www.google.com", 80, 0, reqs[i].get())); + EXPECT_EQ(OK, reqs[i]->WaitForResult()); + } + + for (int i = 0; i < kNumPendingRequests; ++i) { + EXPECT_EQ(ERR_IO_PENDING, reqs[kMaxSocketsPerGroup + i]->handle.Init( + "a", "www.google.com", 80, 0, reqs[kMaxSocketsPerGroup + i].get())); } // Release any connections until we have no connections. - bool released_one; - do { - released_one = false; + + while (TestSocketRequest::completion_count < kNumRequests) { + int num_released = 0; for (size_t i = 0; i < arraysize(reqs); ++i) { if (reqs[i]->handle.is_initialized()) { reqs[i]->handle.socket()->Disconnect(); reqs[i]->handle.Reset(); - MessageLoop::current()->RunAllPending(); - released_one = true; + num_released++; } } - } while (released_one); + int curr_num_completed = TestSocketRequest::completion_count; + for (int i = 0; + (i < num_released) && (i + curr_num_completed < kNumRequests); ++i) { + EXPECT_EQ(OK, reqs[i + curr_num_completed]->WaitForResult()); + } + } - EXPECT_EQ(kMaxSocketsPerGroup + kNumPendingRequests, - MockClientSocket::allocation_count); - EXPECT_EQ(kNumPendingRequests, TestSocketRequest::completion_count); + EXPECT_EQ(kNumRequests, client_socket_factory_.allocation_count()); + EXPECT_EQ(kNumRequests, TestSocketRequest::completion_count); } -TEST_F(ClientSocketPoolTest, CancelRequest) { - int rv; +// This test will start up a RequestSocket() and then immediately Cancel() it. +// The pending host resolution will eventually complete, and destroy the +// ClientSocketPool which will crash if the group was not cleared properly. +TEST_F(ClientSocketPoolTest, CancelRequestClearGroup) { + TestSocketRequest req(pool_.get(), &request_order_); + EXPECT_EQ(ERR_IO_PENDING, + req.handle.Init("a", "www.google.com", 80, 5, &req)); + req.handle.Reset(); + // There is a race condition here. If the worker pool doesn't post the task + // before we get here, then this might not run ConnectingSocket::IOComplete + // and therefore leak the canceled ConnectingSocket. + MessageLoop::current()->RunAllPending(); +} - scoped_ptr<TestSocketRequest> reqs[kMaxSocketsPerGroup + kNumPendingRequests]; +TEST_F(ClientSocketPoolTest, CancelRequest) { + scoped_ptr<TestSocketRequest> reqs[kNumRequests]; for (size_t i = 0; i < arraysize(reqs); ++i) reqs[i].reset(new TestSocketRequest(pool_.get(), &request_order_)); // Create connections or queue up requests. for (int i = 0; i < kMaxSocketsPerGroup; ++i) { - rv = reqs[i]->handle.Init("a", 5, reqs[i].get()); - EXPECT_EQ(net::OK, rv); - reqs[i]->EnsureSocket(); + EXPECT_EQ( + ERR_IO_PENDING, + reqs[i]->handle.Init("a", "www.google.com", 80, 5, reqs[i].get())); + EXPECT_EQ(OK, reqs[i]->WaitForResult()); } + for (int i = 0; i < kNumPendingRequests; ++i) { - rv = reqs[kMaxSocketsPerGroup + i]->handle.Init( - "a", kPriorities[i], reqs[kMaxSocketsPerGroup + i].get()); - EXPECT_EQ(net::ERR_IO_PENDING, rv); + EXPECT_EQ(ERR_IO_PENDING, reqs[kMaxSocketsPerGroup + i]->handle.Init( + "a", "www.google.com", 80, kPriorities[i], + reqs[kMaxSocketsPerGroup + i].get())); } // Cancel a request. @@ -266,8 +362,8 @@ TEST_F(ClientSocketPoolTest, CancelRequest) { } } while (released_one); - EXPECT_EQ(kMaxSocketsPerGroup, MockClientSocket::allocation_count); - EXPECT_EQ(kNumPendingRequests - 1, TestSocketRequest::completion_count); + EXPECT_EQ(kMaxSocketsPerGroup, client_socket_factory_.allocation_count()); + EXPECT_EQ(kNumRequests - 1, TestSocketRequest::completion_count); for (int i = 0; i < kMaxSocketsPerGroup; ++i) { EXPECT_EQ(request_order_[i], reqs[i].get()) << "Request " << i << " was not in order."; @@ -290,3 +386,5 @@ TEST_F(ClientSocketPoolTest, CancelRequest) { } } // namespace + +} // namespace net |