diff options
author | willchan@chromium.org <willchan@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-06-30 18:49:05 +0000 |
---|---|---|
committer | willchan@chromium.org <willchan@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-06-30 18:49:05 +0000 |
commit | ab83889ea0a2351c800e9b4697c04749bdc2e255 (patch) | |
tree | 3ac2dd1f2b79f15f9f7474b3c448a44da7a28cb1 /net | |
parent | b8e787f0951866c565b80e9af124c825af9a4231 (diff) | |
download | chromium_src-ab83889ea0a2351c800e9b4697c04749bdc2e255.zip chromium_src-ab83889ea0a2351c800e9b4697c04749bdc2e255.tar.gz chromium_src-ab83889ea0a2351c800e9b4697c04749bdc2e255.tar.bz2 |
Refactor ClientSocketPoolBase to be testable without host resolution / tcp connections.
BUG=http://crbug.com/13289
TEST=none
Review URL: http://codereview.chromium.org/147252
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@19620 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
-rw-r--r-- | net/net.gyp | 1 | ||||
-rw-r--r-- | net/socket/client_socket_pool.h | 3 | ||||
-rw-r--r-- | net/socket/client_socket_pool_base.cc | 54 | ||||
-rw-r--r-- | net/socket/client_socket_pool_base.h | 81 | ||||
-rw-r--r-- | net/socket/client_socket_pool_base_unittest.cc | 499 | ||||
-rw-r--r-- | net/socket/tcp_client_socket_pool.cc | 52 | ||||
-rw-r--r-- | net/socket/tcp_client_socket_pool.h | 20 |
7 files changed, 340 insertions, 370 deletions
diff --git a/net/net.gyp b/net/net.gyp index fe05eaa..a5b3b12d 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -482,6 +482,7 @@ 'proxy/proxy_script_fetcher_unittest.cc', 'proxy/proxy_server_unittest.cc', 'proxy/proxy_service_unittest.cc', + 'socket/client_socket_pool_base_unittest.cc', 'socket/socks_client_socket_unittest.cc', 'socket/ssl_client_socket_unittest.cc', 'socket/tcp_client_socket_pool_unittest.cc', diff --git a/net/socket/client_socket_pool.h b/net/socket/client_socket_pool.h index b3a7bba..1c8b30e 100644 --- a/net/socket/client_socket_pool.h +++ b/net/socket/client_socket_pool.h @@ -69,9 +69,6 @@ class ClientSocketPool : public base::RefCounted<ClientSocketPool> { // Called to close any idle connections held by the connection manager. virtual void CloseIdleSockets() = 0; - // Returns the HostResolver that will be used for host lookups. - virtual HostResolver* GetHostResolver() const = 0; - // The total number of idle sockets in the pool. virtual int IdleSocketCount() const = 0; diff --git a/net/socket/client_socket_pool_base.cc b/net/socket/client_socket_pool_base.cc index 4a2770f..123c3cf 100644 --- a/net/socket/client_socket_pool_base.cc +++ b/net/socket/client_socket_pool_base.cc @@ -32,11 +32,9 @@ namespace net { ClientSocketPoolBase::ClientSocketPoolBase( int max_sockets_per_group, - HostResolver* host_resolver, ConnectJobFactory* connect_job_factory) : idle_socket_count_(0), max_sockets_per_group_(max_sockets_per_group), - host_resolver_(host_resolver), connect_job_factory_(connect_job_factory) {} ClientSocketPoolBase::~ClientSocketPoolBase() { @@ -69,6 +67,7 @@ int ClientSocketPoolBase::RequestSocket( CompletionCallback* callback) { DCHECK(!resolve_info.hostname().empty()); DCHECK_GE(priority, 0); + DCHECK(callback); Group& group = group_map_[group_name]; CheckSocketCounts(group); @@ -76,7 +75,7 @@ int ClientSocketPoolBase::RequestSocket( // Can we make another active socket now? if (group.active_socket_count == max_sockets_per_group_) { CHECK(callback); - Request r(handle, callback, priority, resolve_info, LOAD_STATE_IDLE); + Request r(handle, callback, priority, resolve_info); InsertRequestIntoQueue(r, &group.pending_requests); return ERR_IO_PENDING; } @@ -102,8 +101,7 @@ int ClientSocketPoolBase::RequestSocket( // We couldn't find a socket to reuse, so allocate and connect a new one. CHECK(callback); - Request r(handle, callback, priority, resolve_info, - LOAD_STATE_RESOLVING_HOST); + Request r(handle, callback, priority, resolve_info); group.connecting_requests[handle] = r; CHECK(!ContainsKey(connect_job_map_, handle)); @@ -178,17 +176,18 @@ LoadState ClientSocketPoolBase::GetLoadState( // Search connecting_requests for matching handle. RequestMap::const_iterator map_it = group.connecting_requests.find(handle); if (map_it != group.connecting_requests.end()) { - const LoadState load_state = map_it->second.load_state; - CHECK(load_state == LOAD_STATE_RESOLVING_HOST || - load_state == LOAD_STATE_CONNECTING); - return load_state; + ConnectJobMap::const_iterator job_it = connect_job_map_.find(handle); + if (job_it == connect_job_map_.end()) { + NOTREACHED(); + return LOAD_STATE_IDLE; + } + return job_it->second->load_state(); } // Search pending_requests for matching handle. RequestQueue::const_iterator it = group.pending_requests.begin(); for (; it != group.pending_requests.end(); ++it) { if (it->handle == handle) { - CHECK(LOAD_STATE_IDLE == it->load_state); // TODO(wtc): Add a state for being on the wait list. // See http://www.crbug.com/5077. return LOAD_STATE_IDLE; @@ -278,28 +277,12 @@ void ClientSocketPoolBase::DoReleaseSocket(const std::string& group_name, RemoveActiveSocket(group_name, &group); } -ClientSocketPoolBase::Request* ClientSocketPoolBase::GetConnectingRequest( - const std::string& group_name, const ClientSocketHandle* handle) { - GroupMap::iterator group_it = group_map_.find(group_name); - if (group_it == group_map_.end()) - return NULL; - - Group& group = group_it->second; - - RequestMap* request_map = &group.connecting_requests; - RequestMap::iterator it = request_map->find(handle); - if (it == request_map->end()) - return NULL; - - return &it->second; -} - -CompletionCallback* ClientSocketPoolBase::OnConnectingRequestComplete( +void ClientSocketPoolBase::OnConnectJobComplete( const std::string& group_name, - const ClientSocketHandle* handle, - bool deactivate, - ClientSocket* socket) { - CHECK((deactivate && !socket) || (!deactivate && socket)); + const ClientSocketHandle* key_handle, + ClientSocket* socket, + int result, + bool was_async) { GroupMap::iterator group_it = group_map_.find(group_name); CHECK(group_it != group_map_.end()); Group& group = group_it->second; @@ -308,13 +291,13 @@ CompletionCallback* ClientSocketPoolBase::OnConnectingRequestComplete( RequestMap* request_map = &group.connecting_requests; - RequestMap::iterator it = request_map->find(handle); + RequestMap::iterator it = request_map->find(key_handle); CHECK(it != request_map->end()); Request request = it->second; request_map->erase(it); - DCHECK_EQ(request.handle, handle); + DCHECK_EQ(request.handle, key_handle); - if (deactivate) { + if (!socket) { RemoveActiveSocket(group_name, &group); } else { request.handle->set_socket(socket); @@ -326,7 +309,8 @@ CompletionCallback* ClientSocketPoolBase::OnConnectingRequestComplete( RemoveConnectJob(request.handle); - return request.callback; + if (was_async) + request.callback->Run(result); } // static diff --git a/net/socket/client_socket_pool_base.h b/net/socket/client_socket_pool_base.h index c642b48..6d05b41 100644 --- a/net/socket/client_socket_pool_base.h +++ b/net/socket/client_socket_pool_base.h @@ -30,22 +30,55 @@ class ClientSocketPoolBase; // etc. class ConnectJob { public: + class Delegate { + public: + Delegate() {} + virtual ~Delegate() {} + + // Alerts the delegate that the connection completed (though not necessarily + // successfully). |group_name| indicates the connection group this + // ConnectJob corresponds to. |key_handle| uniquely identifies the + // ClientSocketHandle that this job is coupled to. |socket| is non-NULL if + // the connection completed successfully, and ownership is transferred to + // the delegate. |was_async| indicates whether or not the connect job + // completed asynchronously. + virtual void OnConnectJobComplete( + const std::string& group_name, + const ClientSocketHandle* key_handle, + ClientSocket* socket, + int result, + bool was_async) = 0; + + private: + DISALLOW_COPY_AND_ASSIGN(Delegate); + }; + ConnectJob() {} virtual ~ConnectJob() {} + // Returns the LoadState of this ConnectJob. + LoadState load_state() const { return load_state_; } + // Begins connecting the socket. Returns OK on success, ERR_IO_PENDING if it // cannot complete synchronously without blocking, or another net error code // on error. virtual int Connect() = 0; + protected: + void set_load_state(LoadState load_state) { load_state_ = load_state; } + private: + LoadState load_state_; + DISALLOW_COPY_AND_ASSIGN(ConnectJob); }; // A ClientSocketPoolBase is used to restrict the number of sockets open at // a time. It also maintains a list of idle persistent sockets. // -class ClientSocketPoolBase : public base::RefCounted<ClientSocketPoolBase> { +class ClientSocketPoolBase + : public base::RefCounted<ClientSocketPoolBase>, + public ConnectJob::Delegate { public: // A Request is allocated per call to RequestSocket that results in // ERR_IO_PENDING. @@ -56,17 +89,15 @@ class ClientSocketPoolBase : public base::RefCounted<ClientSocketPoolBase> { Request(ClientSocketHandle* handle, CompletionCallback* callback, int priority, - const HostResolver::RequestInfo& resolve_info, - LoadState load_state) + const HostResolver::RequestInfo& resolve_info) : handle(handle), callback(callback), priority(priority), - resolve_info(resolve_info), load_state(load_state) { + resolve_info(resolve_info) { } ClientSocketHandle* handle; CompletionCallback* callback; int priority; HostResolver::RequestInfo resolve_info; - LoadState load_state; }; class ConnectJobFactory { @@ -77,14 +108,13 @@ class ClientSocketPoolBase : public base::RefCounted<ClientSocketPoolBase> { virtual ConnectJob* NewConnectJob( const std::string& group_name, const Request& request, - ClientSocketPoolBase* pool) const = 0; + ConnectJob::Delegate* delegate) const = 0; private: DISALLOW_COPY_AND_ASSIGN(ConnectJobFactory); }; ClientSocketPoolBase(int max_sockets_per_group, - HostResolver* host_resolver, ConnectJobFactory* connect_job_factory); ~ClientSocketPoolBase(); @@ -103,10 +133,6 @@ class ClientSocketPoolBase : public base::RefCounted<ClientSocketPoolBase> { void CloseIdleSockets(); - HostResolver* GetHostResolver() const { - return host_resolver_; - } - int idle_socket_count() const { return idle_socket_count_; } @@ -116,27 +142,14 @@ class ClientSocketPoolBase : public base::RefCounted<ClientSocketPoolBase> { LoadState GetLoadState(const std::string& group_name, const ClientSocketHandle* handle) const; - // Used by ConnectJob until we remove the coupling between a specific - // ConnectJob and a ClientSocketHandle: - - // Returns NULL if not found. Otherwise it returns the Request* - // corresponding to the ConnectJob (keyed by |group_name| and |handle|. - // Note that this pointer may be invalidated after any call that might mutate - // the RequestMap or GroupMap, so the user should not hold onto the pointer - // for long. - Request* GetConnectingRequest(const std::string& group_name, - const ClientSocketHandle* handle); - - // Handles the completed Request corresponding to the ConnectJob (keyed - // by |group_name| and |handle|. |deactivate| indicates whether or not to - // deactivate the socket, making the socket slot available for a new socket - // connection. If |deactivate| is false, then set |socket| into |handle|. - // Returns the callback to run. - CompletionCallback* OnConnectingRequestComplete( + // If |was_async| is true, then ClientSocketPoolBase will pick a callback to + // run from a request associated with |group_name|. + virtual void OnConnectJobComplete( const std::string& group_name, - const ClientSocketHandle* handle, - bool deactivate, - ClientSocket* socket); + const ClientSocketHandle* key_handle, + ClientSocket* socket, + int result, + bool was_async); private: // Entry for a persistent socket which became idle at time |start_time|. @@ -218,11 +231,7 @@ class ClientSocketPoolBase : public base::RefCounted<ClientSocketPoolBase> { // The maximum number of sockets kept per group. const int max_sockets_per_group_; - // The host resolver that will be used to do host lookups for connecting - // sockets. - scoped_refptr<HostResolver> host_resolver_; - - scoped_ptr<ConnectJobFactory> connect_job_factory_; + const scoped_ptr<ConnectJobFactory> connect_job_factory_; DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolBase); }; diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index 9cd62e8..ef9f15f8 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "net/socket/tcp_client_socket_pool.h" +#include "net/socket/client_socket_pool_base.h" #include "base/compiler_specific.h" #include "base/message_loop.h" @@ -18,11 +18,11 @@ namespace net { namespace { -const int kMaxSocketsPerGroup = 6; +const int kMaxSocketsPerGroup = 2; // Note that the first and the last are the same, the first should be handled // before the last, since it was inserted first. -const int kPriorities[10] = { 1, 7, 9, 5, 6, 2, 8, 3, 4, 1 }; +const int kPriorities[] = { 1, 3, 4, 2, 1 }; // This is the number of extra requests beyond the first few that use up all // available sockets in the socket group. @@ -34,144 +34,48 @@ class MockClientSocket : public ClientSocket { public: MockClientSocket() : connected_(false) {} - // ClientSocket methods: - virtual int Connect(CompletionCallback* callback) { - connected_ = true; - return OK; - } - virtual void Disconnect() { - connected_ = false; - } - virtual bool IsConnected() const { - return connected_; - } - virtual bool IsConnectedAndIdle() const { - return connected_; - } - - // Socket methods: - virtual int Read(IOBuffer* buf, int buf_len, - CompletionCallback* callback) { - return ERR_FAILED; - } - virtual int Write(IOBuffer* buf, int buf_len, - CompletionCallback* callback) { - return ERR_FAILED; - } - - private: - bool connected_; -}; - -class MockFailingClientSocket : public ClientSocket { - public: - MockFailingClientSocket() {} - - // ClientSocket methods: - virtual int Connect(CompletionCallback* callback) { - return ERR_CONNECTION_FAILED; - } - - virtual void Disconnect() {} - - virtual bool IsConnected() const { - return false; - } - virtual bool IsConnectedAndIdle() const { - return false; - } - // Socket methods: - virtual int Read(IOBuffer* buf, int buf_len, - CompletionCallback* callback) { - return ERR_FAILED; + virtual int Read( + IOBuffer* /* buf */, int /* len */, CompletionCallback* /* callback */) { + return ERR_UNEXPECTED; } - virtual int Write(IOBuffer* buf, int buf_len, - CompletionCallback* callback) { - return ERR_FAILED; + virtual int Write( + IOBuffer* /* buf */, int /* len */, CompletionCallback* /* callback */) { + return ERR_UNEXPECTED; } -}; - -class MockPendingClientSocket : public ClientSocket { - public: - MockPendingClientSocket(bool should_connect) - : method_factory_(ALLOW_THIS_IN_INITIALIZER_LIST(this)), - should_connect_(should_connect), - is_connected_(false) {} // ClientSocket methods: - virtual int Connect(CompletionCallback* callback) { - MessageLoop::current()->PostTask( - FROM_HERE, - method_factory_.NewRunnableMethod( - &MockPendingClientSocket::DoCallback, callback)); - return ERR_IO_PENDING; - } - virtual void Disconnect() {} - - virtual bool IsConnected() const { - return is_connected_; - } - virtual bool IsConnectedAndIdle() const { - return is_connected_; - } - - // Socket methods: - virtual int Read(IOBuffer* buf, int buf_len, - CompletionCallback* callback) { - return ERR_FAILED; + virtual int Connect(CompletionCallback* callback) { + connected_ = true; + return OK; } - virtual int Write(IOBuffer* buf, int buf_len, - CompletionCallback* callback) { - return ERR_FAILED; + virtual void Disconnect() { connected_ = false; } + virtual bool IsConnected() const { return connected_; } + virtual bool IsConnectedAndIdle() const { return connected_; } + +#if defined(OS_LINUX) + virtual int GetPeerName(struct sockaddr* /* name */, + socklen_t* /* namelen */) { + return 0; } +#endif private: - void DoCallback(CompletionCallback* callback) { - if (should_connect_) { - is_connected_ = true; - callback->Run(OK); - } else { - is_connected_ = false; - callback->Run(ERR_CONNECTION_FAILED); - } - } + bool connected_; - ScopedRunnableMethodFactory<MockPendingClientSocket> method_factory_; - bool should_connect_; - bool is_connected_; + DISALLOW_COPY_AND_ASSIGN(MockClientSocket); }; class MockClientSocketFactory : public ClientSocketFactory { public: - enum ClientSocketType { - MOCK_CLIENT_SOCKET, - MOCK_FAILING_CLIENT_SOCKET, - MOCK_PENDING_CLIENT_SOCKET, - MOCK_PENDING_FAILING_CLIENT_SOCKET, - }; - - MockClientSocketFactory() - : allocation_count_(0), client_socket_type_(MOCK_CLIENT_SOCKET) {} + MockClientSocketFactory() : allocation_count_(0) {} virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses) { allocation_count_++; - switch (client_socket_type_) { - case MOCK_CLIENT_SOCKET: - return new MockClientSocket(); - case MOCK_FAILING_CLIENT_SOCKET: - return new MockFailingClientSocket(); - case MOCK_PENDING_CLIENT_SOCKET: - return new MockPendingClientSocket(true); - case MOCK_PENDING_FAILING_CLIENT_SOCKET: - return new MockPendingClientSocket(false); - default: - NOTREACHED(); - return new MockClientSocket(); - } + return NULL; } virtual SSLClientSocket* CreateSSLClientSocket( @@ -184,13 +88,8 @@ class MockClientSocketFactory : public ClientSocketFactory { int allocation_count() const { return allocation_count_; } - void set_client_socket_type(ClientSocketType type) { - client_socket_type_ = type; - } - private: int allocation_count_; - ClientSocketType client_socket_type_; }; class TestSocketRequest : public CallbackRunner< Tuple1<int> > { @@ -221,17 +120,174 @@ class TestSocketRequest : public CallbackRunner< Tuple1<int> > { int TestSocketRequest::completion_count = 0; -class TCPClientSocketPoolTest : public testing::Test { +class TestConnectJob : public ConnectJob { + public: + enum JobType { + kMockJob, + kMockFailingJob, + kMockPendingJob, + kMockPendingFailingJob, + }; + + TestConnectJob(JobType job_type, + const std::string& group_name, + const ClientSocketPoolBase::Request& request, + ConnectJob::Delegate* delegate, + ClientSocketFactory* client_socket_factory) + : job_type_(job_type), + group_name_(group_name), + handle_(request.handle), + client_socket_factory_(client_socket_factory), + delegate_(delegate), + method_factory_(ALLOW_THIS_IN_INITIALIZER_LIST(this)) {} + + // ConnectJob methods: + + virtual int Connect() { + AddressList ignored; + client_socket_factory_->CreateTCPClientSocket(ignored); + switch (job_type_) { + case kMockJob: + return DoConnect(true /* successful */, false /* sync */); + case kMockFailingJob: + return DoConnect(false /* error */, false /* sync */); + case kMockPendingJob: + MessageLoop::current()->PostTask( + FROM_HERE, + method_factory_.NewRunnableMethod( + &TestConnectJob::DoConnect, + true /* successful */, + true /* async */)); + return ERR_IO_PENDING; + case kMockPendingFailingJob: + MessageLoop::current()->PostTask( + FROM_HERE, + method_factory_.NewRunnableMethod( + &TestConnectJob::DoConnect, + false /* error */, + true /* async */)); + return ERR_IO_PENDING; + default: + NOTREACHED(); + return ERR_FAILED; + } + } + + private: + int DoConnect(bool succeed, bool was_async) { + int result = ERR_CONNECTION_FAILED; + ClientSocket* socket = NULL; + if (succeed) { + result = OK; + socket = new MockClientSocket(); + socket->Connect(NULL); + } + delegate_->OnConnectJobComplete( + group_name_, handle_, socket, result, was_async); + return result; + } + + const JobType job_type_; + const std::string group_name_; + const ClientSocketHandle* handle_; + ClientSocketFactory* const client_socket_factory_; + Delegate* const delegate_; + ScopedRunnableMethodFactory<TestConnectJob> method_factory_; + + DISALLOW_COPY_AND_ASSIGN(TestConnectJob); +}; + +class TestConnectJobFactory : public ClientSocketPoolBase::ConnectJobFactory { + public: + explicit TestConnectJobFactory(ClientSocketFactory* client_socket_factory) + : job_type_(TestConnectJob::kMockJob), + client_socket_factory_(client_socket_factory) {} + + virtual ~TestConnectJobFactory() {} + + void set_job_type(TestConnectJob::JobType job_type) { job_type_ = job_type; } + + // ConnectJobFactory methods: + + virtual ConnectJob* NewConnectJob( + const std::string& group_name, + const ClientSocketPoolBase::Request& request, + ConnectJob::Delegate* delegate) const { + return new TestConnectJob(job_type_, + group_name, + request, + delegate, + client_socket_factory_); + } + + private: + TestConnectJob::JobType job_type_; + ClientSocketFactory* const client_socket_factory_; + + DISALLOW_COPY_AND_ASSIGN(TestConnectJobFactory); +}; + +class TestClientSocketPool : public ClientSocketPool { + public: + TestClientSocketPool( + int max_sockets_per_group, + ClientSocketPoolBase::ConnectJobFactory* connect_job_factory) + : base_(new ClientSocketPoolBase( + kMaxSocketsPerGroup, connect_job_factory)) {} + + virtual int RequestSocket( + const std::string& group_name, + const HostResolver::RequestInfo& resolve_info, + int priority, + ClientSocketHandle* handle, + CompletionCallback* callback) { + return base_->RequestSocket( + group_name, resolve_info, priority, handle, callback); + } + + virtual void CancelRequest( + const std::string& group_name, + const ClientSocketHandle* handle) { + base_->CancelRequest(group_name, handle); + } + + virtual void ReleaseSocket( + const std::string& group_name, + ClientSocket* socket) { + base_->ReleaseSocket(group_name, socket); + } + + virtual void CloseIdleSockets() { + base_->CloseIdleSockets(); + } + + virtual int IdleSocketCount() const { return base_->idle_socket_count(); } + + virtual int IdleSocketCountInGroup(const std::string& group_name) const { + return base_->IdleSocketCountInGroup(group_name); + } + + virtual LoadState GetLoadState(const std::string& group_name, + const ClientSocketHandle* handle) const { + return base_->GetLoadState(group_name, handle); + } + + private: + const scoped_refptr<ClientSocketPoolBase> base_; + + DISALLOW_COPY_AND_ASSIGN(TestClientSocketPool); +}; + +class ClientSocketPoolBaseTest : public testing::Test { protected: - TCPClientSocketPoolTest() - : pool_(new TCPClientSocketPool(kMaxSocketsPerGroup, - &host_resolver_, - &client_socket_factory_)) {} + ClientSocketPoolBaseTest() + : ignored_request_info_("ignored", 80), + connect_job_factory_( + new TestConnectJobFactory(&client_socket_factory_)), + pool_(new TestClientSocketPool(kMaxSocketsPerGroup, + connect_job_factory_)) {} virtual void SetUp() { - RuleBasedHostMapper *host_mapper = new RuleBasedHostMapper(); - host_mapper->AddRule("*", "127.0.0.1"); - scoped_host_mapper_.Init(host_mapper); TestSocketRequest::completion_count = 0; } @@ -241,54 +297,31 @@ class TCPClientSocketPoolTest : public testing::Test { MessageLoop::current()->RunAllPending(); } - ScopedHostMapper scoped_host_mapper_; - HostResolver host_resolver_; + HostResolver::RequestInfo ignored_request_info_; MockClientSocketFactory client_socket_factory_; + TestConnectJobFactory* const connect_job_factory_; scoped_refptr<ClientSocketPool> pool_; std::vector<TestSocketRequest*> request_order_; }; -TEST_F(TCPClientSocketPoolTest, Basic) { +TEST_F(ClientSocketPoolBaseTest, Basic) { TestCompletionCallback callback; ClientSocketHandle handle(pool_.get()); - HostResolver::RequestInfo info("www.google.com", 80); - int rv = handle.Init("a", info, 0, &callback); - EXPECT_EQ(ERR_IO_PENDING, rv); - EXPECT_FALSE(handle.is_initialized()); - EXPECT_FALSE(handle.socket()); - - EXPECT_EQ(OK, callback.WaitForResult()); + int rv = handle.Init("a", ignored_request_info_, 0, &callback); + EXPECT_EQ(OK, rv); EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); - handle.Reset(); } -TEST_F(TCPClientSocketPoolTest, InitHostResolutionFailure) { - RuleBasedHostMapper* host_mapper = new RuleBasedHostMapper; - host_mapper->AddSimulatedFailure("unresolvable.host.name"); - ScopedHostMapper scoped_host_mapper(host_mapper); +TEST_F(ClientSocketPoolBaseTest, InitConnectionFailure) { + connect_job_factory_->set_job_type(TestConnectJob::kMockFailingJob); TestSocketRequest req(pool_.get(), &request_order_); - HostResolver::RequestInfo info("unresolvable.host.name", 80); - EXPECT_EQ(ERR_IO_PENDING, req.handle.Init("a", info, 5, &req)); - EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req.WaitForResult()); -} - -TEST_F(TCPClientSocketPoolTest, InitConnectionFailure) { - client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET); - TestSocketRequest req(pool_.get(), &request_order_); - HostResolver::RequestInfo info("a", 80); - EXPECT_EQ(ERR_IO_PENDING, - req.handle.Init("a", info, 5, &req)); - EXPECT_EQ(ERR_CONNECTION_FAILED, req.WaitForResult()); - // HostCache caches it, so MockFailingClientSocket will cause Init() to - // synchronously fail. EXPECT_EQ(ERR_CONNECTION_FAILED, - req.handle.Init("a", info, 5, &req)); + req.handle.Init("a", ignored_request_info_, 5, &req)); } -TEST_F(TCPClientSocketPoolTest, PendingRequests) { +TEST_F(ClientSocketPoolBaseTest, PendingRequests) { scoped_ptr<TestSocketRequest> reqs[kNumRequests]; for (size_t i = 0; i < arraysize(reqs); ++i) @@ -296,23 +329,17 @@ TEST_F(TCPClientSocketPoolTest, PendingRequests) { // Create connections or queue up requests. - // First request finishes asynchronously. - HostResolver::RequestInfo info("www.google.com", 80); - int rv = reqs[0]->handle.Init("a", info, 5, reqs[0].get()); - EXPECT_EQ(ERR_IO_PENDING, rv); - EXPECT_EQ(OK, reqs[0]->WaitForResult()); - - // Rest of them finish synchronously, since they're in the HostCache. - for (int i = 1; i < kMaxSocketsPerGroup; ++i) { - rv = reqs[i]->handle.Init("a", info, 5, reqs[i].get()); + for (int i = 0; i < kMaxSocketsPerGroup; ++i) { + int rv = reqs[i]->handle.Init("a", ignored_request_info_, 5, reqs[i].get()); EXPECT_EQ(OK, rv); request_order_.push_back(reqs[i].get()); } // The rest are pending since we've used all active sockets. for (int i = 0; i < kNumPendingRequests; ++i) { - rv = reqs[kMaxSocketsPerGroup + i]->handle.Init( - "a", info, kPriorities[i], reqs[kMaxSocketsPerGroup + i].get()); + int rv = reqs[kMaxSocketsPerGroup + i]->handle.Init( + "a", ignored_request_info_, kPriorities[i], + reqs[kMaxSocketsPerGroup + i].get()); EXPECT_EQ(ERR_IO_PENDING, rv); } @@ -330,7 +357,7 @@ TEST_F(TCPClientSocketPoolTest, PendingRequests) { } while (released_one); EXPECT_EQ(kMaxSocketsPerGroup, client_socket_factory_.allocation_count()); - EXPECT_EQ(kNumPendingRequests + 1, TestSocketRequest::completion_count); + EXPECT_EQ(kNumPendingRequests, TestSocketRequest::completion_count); for (int i = 0; i < kMaxSocketsPerGroup; ++i) { EXPECT_EQ(request_order_[i], reqs[i].get()) << @@ -350,30 +377,24 @@ TEST_F(TCPClientSocketPoolTest, PendingRequests) { "earlier into the queue."; } -TEST_F(TCPClientSocketPoolTest, PendingRequests_NoKeepAlive) { +TEST_F(ClientSocketPoolBaseTest, PendingRequests_NoKeepAlive) { scoped_ptr<TestSocketRequest> reqs[kNumRequests]; for (size_t i = 0; i < arraysize(reqs); ++i) reqs[i].reset(new TestSocketRequest(pool_.get(), &request_order_)); // Create connections or queue up requests. - - // First request finishes asynchronously. - HostResolver::RequestInfo info("www.google.com", 80); - int rv = reqs[0]->handle.Init("a", info, 5, reqs[0].get()); - EXPECT_EQ(ERR_IO_PENDING, rv); - EXPECT_EQ(OK, reqs[0]->WaitForResult()); - - // Rest of them finish synchronously, since they're in the HostCache. - for (int i = 1; i < kMaxSocketsPerGroup; ++i) { - rv = reqs[i]->handle.Init("a", info, 5, reqs[i].get()); + for (int i = 0; i < kMaxSocketsPerGroup; ++i) { + int rv = reqs[i]->handle.Init("a", ignored_request_info_, 5, reqs[i].get()); EXPECT_EQ(OK, rv); request_order_.push_back(reqs[i].get()); } // The rest are pending since we've used all active sockets. for (int i = 0; i < kNumPendingRequests; ++i) { - EXPECT_EQ(ERR_IO_PENDING, reqs[kMaxSocketsPerGroup + i]->handle.Init( - "a", info, 0, reqs[kMaxSocketsPerGroup + i].get())); + int rv = reqs[kMaxSocketsPerGroup + i]->handle.Init( + "a", ignored_request_info_, kPriorities[i], + reqs[kMaxSocketsPerGroup + i].get()); + EXPECT_EQ(ERR_IO_PENDING, rv); } // Release any connections until we have no connections. @@ -394,35 +415,29 @@ TEST_F(TCPClientSocketPoolTest, PendingRequests_NoKeepAlive) { EXPECT_EQ(OK, reqs[i]->WaitForResult()); EXPECT_EQ(kNumRequests, client_socket_factory_.allocation_count()); - EXPECT_EQ(kNumPendingRequests + 1, TestSocketRequest::completion_count); + EXPECT_EQ(kNumPendingRequests, TestSocketRequest::completion_count); } // This test will start up a RequestSocket() and then immediately Cancel() it. -// The pending host resolution will eventually complete, and destroy the -// ClientSocketPool which will crash if the group was not cleared properly. -TEST_F(TCPClientSocketPoolTest, CancelRequestClearGroup) { +// The pending connect job will be cancelled and should not call back into +// ClientSocketPoolBase. +TEST_F(ClientSocketPoolBaseTest, CancelRequestClearGroup) { + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); TestSocketRequest req(pool_.get(), &request_order_); - HostResolver::RequestInfo info("www.google.com", 80); - EXPECT_EQ(ERR_IO_PENDING, req.handle.Init("a", info, 5, &req)); + EXPECT_EQ(ERR_IO_PENDING, + req.handle.Init("a", ignored_request_info_, 5, &req)); req.handle.Reset(); - - PlatformThread::Sleep(100); - - // There is a race condition here. If the worker pool doesn't post the task - // before we get here, then this might not run ConnectingSocket::OnIOComplete - // and therefore leak the canceled ConnectingSocket. However, other tests - // after this will call MessageLoop::RunAllPending() which should prevent a - // leak, unless the worker thread takes longer than all of them. - MessageLoop::current()->RunAllPending(); } -TEST_F(TCPClientSocketPoolTest, TwoRequestsCancelOne) { +TEST_F(ClientSocketPoolBaseTest, TwoRequestsCancelOne) { + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); TestSocketRequest req(pool_.get(), &request_order_); TestSocketRequest req2(pool_.get(), &request_order_); - HostResolver::RequestInfo info("www.google.com", 80); - EXPECT_EQ(ERR_IO_PENDING, req.handle.Init("a", info, 5, &req)); - EXPECT_EQ(ERR_IO_PENDING, req2.handle.Init("a", info, 5, &req2)); + EXPECT_EQ(ERR_IO_PENDING, + req.handle.Init("a", ignored_request_info_, 5, &req)); + EXPECT_EQ(ERR_IO_PENDING, + req2.handle.Init("a", ignored_request_info_, 5, &req2)); req.handle.Reset(); @@ -430,28 +445,20 @@ TEST_F(TCPClientSocketPoolTest, TwoRequestsCancelOne) { req2.handle.Reset(); } -TEST_F(TCPClientSocketPoolTest, ConnectCancelConnect) { - client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); +TEST_F(ClientSocketPoolBaseTest, ConnectCancelConnect) { + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); ClientSocketHandle handle(pool_.get()); TestCompletionCallback callback; TestSocketRequest req(pool_.get(), &request_order_); - HostResolver::RequestInfo info("www.google.com", 80); - EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", info, 5, &callback)); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", ignored_request_info_, 5, &callback)); handle.Reset(); TestCompletionCallback callback2; - EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", info, 5, &callback2)); - - // At this point, handle has two ConnectingSockets out for it. Due to the - // host cache, the host resolution for both will return in the same loop of - // the MessageLoop. The client socket is a pending socket, so the Connect() - // will asynchronously complete on the next loop of the MessageLoop. That - // means that the first ConnectingSocket will enter OnIOComplete, and then the - // second one will. If the first one is not cancelled, it will advance the - // load state, and then the second one will crash. + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", ignored_request_info_, 5, &callback2)); EXPECT_EQ(OK, callback2.WaitForResult()); EXPECT_FALSE(callback.have_result()); @@ -459,23 +466,15 @@ TEST_F(TCPClientSocketPoolTest, ConnectCancelConnect) { handle.Reset(); } -TEST_F(TCPClientSocketPoolTest, CancelRequest) { +TEST_F(ClientSocketPoolBaseTest, CancelRequest) { scoped_ptr<TestSocketRequest> reqs[kNumRequests]; for (size_t i = 0; i < arraysize(reqs); ++i) reqs[i].reset(new TestSocketRequest(pool_.get(), &request_order_)); // Create connections or queue up requests. - HostResolver::RequestInfo info("www.google.com", 80); - - // First request finishes asynchronously. - int rv = reqs[0]->handle.Init("a", info, 5, reqs[0].get()); - EXPECT_EQ(ERR_IO_PENDING, rv); - EXPECT_EQ(OK, reqs[0]->WaitForResult()); - - // Rest of them finish synchronously, since they're in the HostCache. - for (int i = 1; i < kMaxSocketsPerGroup; ++i) { - rv = reqs[i]->handle.Init("a", info, 5, reqs[i].get()); + for (int i = 0; i < kMaxSocketsPerGroup; ++i) { + int rv = reqs[i]->handle.Init("a", ignored_request_info_, 5, reqs[i].get()); EXPECT_EQ(OK, rv); request_order_.push_back(reqs[i].get()); } @@ -483,7 +482,8 @@ TEST_F(TCPClientSocketPoolTest, CancelRequest) { // The rest are pending since we've used all active sockets. for (int i = 0; i < kNumPendingRequests; ++i) { EXPECT_EQ(ERR_IO_PENDING, reqs[kMaxSocketsPerGroup + i]->handle.Init( - "a", info, kPriorities[i], reqs[kMaxSocketsPerGroup + i].get())); + "a", ignored_request_info_, kPriorities[i], + reqs[kMaxSocketsPerGroup + i].get())); } // Cancel a request. @@ -505,7 +505,7 @@ TEST_F(TCPClientSocketPoolTest, CancelRequest) { } while (released_one); EXPECT_EQ(kMaxSocketsPerGroup, client_socket_factory_.allocation_count()); - EXPECT_EQ(kNumPendingRequests, TestSocketRequest::completion_count); + EXPECT_EQ(kNumPendingRequests - 1, TestSocketRequest::completion_count); for (int i = 0; i < kMaxSocketsPerGroup; ++i) { EXPECT_EQ(request_order_[i], reqs[i].get()) << @@ -543,7 +543,7 @@ class RequestSocketCallback : public CallbackRunner< Tuple1<int> > { within_callback_ = true; int rv = handle_->Init( "a", HostResolver::RequestInfo("www.google.com", 80), 0, this); - EXPECT_EQ(OK, rv); + EXPECT_EQ(ERR_IO_PENDING, rv); } } @@ -557,11 +557,12 @@ class RequestSocketCallback : public CallbackRunner< Tuple1<int> > { TestCompletionCallback callback_; }; -TEST_F(TCPClientSocketPoolTest, RequestTwice) { +TEST_F(ClientSocketPoolBaseTest, RequestTwice) { + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); ClientSocketHandle handle(pool_.get()); RequestSocketCallback callback(&handle); int rv = handle.Init( - "a", HostResolver::RequestInfo("www.google.com", 80), 0, &callback); + "a", ignored_request_info_, 0, &callback); ASSERT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(OK, callback.WaitForResult()); @@ -571,18 +572,15 @@ TEST_F(TCPClientSocketPoolTest, RequestTwice) { // Make sure that pending requests get serviced after active requests get // cancelled. -TEST_F(TCPClientSocketPoolTest, CancelActiveRequestWithPendingRequests) { - client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); +TEST_F(ClientSocketPoolBaseTest, CancelActiveRequestWithPendingRequests) { + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); scoped_ptr<TestSocketRequest> reqs[kNumRequests]; // Queue up all the requests - - HostResolver::RequestInfo info("www.google.com", 80); for (size_t i = 0; i < arraysize(reqs); ++i) { reqs[i].reset(new TestSocketRequest(pool_.get(), &request_order_)); - int rv = reqs[i]->handle.Init("a", info, 5, reqs[i].get()); + int rv = reqs[i]->handle.Init("a", ignored_request_info_, 5, reqs[i].get()); EXPECT_EQ(ERR_IO_PENDING, rv); } @@ -601,18 +599,15 @@ TEST_F(TCPClientSocketPoolTest, CancelActiveRequestWithPendingRequests) { } // Make sure that pending requests get serviced after active requests fail. -TEST_F(TCPClientSocketPoolTest, FailingActiveRequestWithPendingRequests) { - client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET); +TEST_F(ClientSocketPoolBaseTest, FailingActiveRequestWithPendingRequests) { + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob); scoped_ptr<TestSocketRequest> reqs[kMaxSocketsPerGroup * 2 + 1]; // Queue up all the requests - - HostResolver::RequestInfo info("www.google.com", 80); for (size_t i = 0; i < arraysize(reqs); ++i) { reqs[i].reset(new TestSocketRequest(pool_.get(), &request_order_)); - int rv = reqs[i]->handle.Init("a", info, 5, reqs[i].get()); + int rv = reqs[i]->handle.Init("a", ignored_request_info_, 5, reqs[i].get()); EXPECT_EQ(ERR_IO_PENDING, rv); } diff --git a/net/socket/tcp_client_socket_pool.cc b/net/socket/tcp_client_socket_pool.cc index 9892605..c9af808 100644 --- a/net/socket/tcp_client_socket_pool.cc +++ b/net/socket/tcp_client_socket_pool.cc @@ -23,7 +23,8 @@ TCPConnectJob::TCPConnectJob( const HostResolver::RequestInfo& resolve_info, const ClientSocketHandle* handle, ClientSocketFactory* client_socket_factory, - ClientSocketPoolBase* pool) + HostResolver* host_resolver, + Delegate* delegate) : group_name_(group_name), resolve_info_(resolve_info), handle_(handle), @@ -31,8 +32,8 @@ TCPConnectJob::TCPConnectJob( ALLOW_THIS_IN_INITIALIZER_LIST( callback_(this, &TCPConnectJob::OnIOComplete)), - pool_(pool), - resolver_(pool->GetHostResolver()) {} + delegate_(delegate), + resolver_(host_resolver) {} TCPConnectJob::~TCPConnectJob() { // We don't worry about cancelling the host resolution and TCP connect, since @@ -40,6 +41,7 @@ TCPConnectJob::~TCPConnectJob() { } int TCPConnectJob::Connect() { + set_load_state(LOAD_STATE_RESOLVING_HOST); int rv = resolver_.Resolve(resolve_info_, &addresses_, &callback_); if (rv != ERR_IO_PENDING) rv = OnIOCompleteInternal(rv, true /* synchronous */); @@ -54,12 +56,8 @@ int TCPConnectJob::OnIOCompleteInternal( int result, bool synchronous) { CHECK(result != ERR_IO_PENDING); - ClientSocketPoolBase::Request* request = pool_->GetConnectingRequest( - group_name_, handle_); - CHECK(request); - - if (result == OK && request->load_state == LOAD_STATE_RESOLVING_HOST) { - request->load_state = LOAD_STATE_CONNECTING; + if (result == OK && load_state() == LOAD_STATE_RESOLVING_HOST) { + set_load_state(LOAD_STATE_CONNECTING); socket_.reset(client_socket_factory_->CreateTCPClientSocket(addresses_)); connect_start_time_ = base::TimeTicks::Now(); result = socket_->Connect(&callback_); @@ -68,7 +66,7 @@ int TCPConnectJob::OnIOCompleteInternal( } if (result == OK) { - CHECK(request->load_state == LOAD_STATE_CONNECTING); + DCHECK_EQ(load_state(), LOAD_STATE_CONNECTING); CHECK(connect_start_time_ != base::TimeTicks()); base::TimeDelta connect_duration = base::TimeTicks::Now() - connect_start_time_; @@ -83,38 +81,24 @@ int TCPConnectJob::OnIOCompleteInternal( // Now, we either succeeded at Connect()'ing, or we failed at host resolution // or Connect()'ing. Either way, we'll run the callback to alert the client. - CompletionCallback* callback = NULL; - - if (result == OK) { - callback = pool_->OnConnectingRequestComplete( - group_name_, - handle_, - false /* don't deactivate socket */, - socket_.release()); - } else { - callback = pool_->OnConnectingRequestComplete( - group_name_, - handle_, - true /* deactivate socket */, - NULL /* no connected socket to give */); - } + delegate_->OnConnectJobComplete( + group_name_, + handle_, + result == OK ? socket_.release() : NULL, + result, + !synchronous); // |this| is deleted after this point. - - CHECK(callback); - - if (!synchronous) - callback->Run(result); return result; } ConnectJob* TCPClientSocketPool::TCPConnectJobFactory::NewConnectJob( const std::string& group_name, const ClientSocketPoolBase::Request& request, - ClientSocketPoolBase* pool) const { + ConnectJob::Delegate* delegate) const { return new TCPConnectJob( group_name, request.resolve_info, request.handle, - client_socket_factory_, pool); + client_socket_factory_, host_resolver_, delegate); } TCPClientSocketPool::TCPClientSocketPool( @@ -122,8 +106,8 @@ TCPClientSocketPool::TCPClientSocketPool( HostResolver* host_resolver, ClientSocketFactory* client_socket_factory) : base_(new ClientSocketPoolBase( - max_sockets_per_group, host_resolver, - new TCPConnectJobFactory(client_socket_factory))) {} + max_sockets_per_group, + new TCPConnectJobFactory(client_socket_factory, host_resolver))) {} TCPClientSocketPool::~TCPClientSocketPool() {} diff --git a/net/socket/tcp_client_socket_pool.h b/net/socket/tcp_client_socket_pool.h index 11982b6..4bbb9ba 100644 --- a/net/socket/tcp_client_socket_pool.h +++ b/net/socket/tcp_client_socket_pool.h @@ -25,8 +25,9 @@ class TCPConnectJob : public ConnectJob { const HostResolver::RequestInfo& resolve_info, const ClientSocketHandle* handle, ClientSocketFactory* client_socket_factory, - ClientSocketPoolBase* pool); - ~TCPConnectJob(); + HostResolver* host_resolver, + Delegate* delegate); + virtual ~TCPConnectJob(); // ConnectJob methods. @@ -53,7 +54,7 @@ class TCPConnectJob : public ConnectJob { ClientSocketFactory* const client_socket_factory_; CompletionCallbackImpl<TCPConnectJob> callback_; scoped_ptr<ClientSocket> socket_; - ClientSocketPoolBase* const pool_; + Delegate* const delegate_; SingleRequestHostResolver resolver_; AddressList addresses_; @@ -85,10 +86,6 @@ class TCPClientSocketPool : public ClientSocketPool { virtual void CloseIdleSockets(); - virtual HostResolver* GetHostResolver() const { - return base_->GetHostResolver(); - } - virtual int IdleSocketCount() const { return base_->idle_socket_count(); } @@ -104,8 +101,10 @@ class TCPClientSocketPool : public ClientSocketPool { class TCPConnectJobFactory : public ClientSocketPoolBase::ConnectJobFactory { public: - explicit TCPConnectJobFactory(ClientSocketFactory* client_socket_factory) - : client_socket_factory_(client_socket_factory) {} + TCPConnectJobFactory(ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver) + : client_socket_factory_(client_socket_factory), + host_resolver_(host_resolver) {} virtual ~TCPConnectJobFactory() {} @@ -114,10 +113,11 @@ class TCPClientSocketPool : public ClientSocketPool { virtual ConnectJob* NewConnectJob( const std::string& group_name, const ClientSocketPoolBase::Request& request, - ClientSocketPoolBase* pool) const; + ConnectJob::Delegate* delegate) const; private: ClientSocketFactory* const client_socket_factory_; + const scoped_refptr<HostResolver> host_resolver_; DISALLOW_COPY_AND_ASSIGN(TCPConnectJobFactory); }; |