diff options
author | vandebo@chromium.org <vandebo@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2010-03-22 17:17:26 +0000 |
---|---|---|
committer | vandebo@chromium.org <vandebo@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2010-03-22 17:17:26 +0000 |
commit | a796bcec176ca3875a55346800b3a60a83e2dd89 (patch) | |
tree | 2533c17673ff50f4f101e803c2dff3bf8f5cbf7b | |
parent | 35818452760c23c570b7947e00a3b38e733ce58e (diff) | |
download | chromium_src-a796bcec176ca3875a55346800b3a60a83e2dd89.zip chromium_src-a796bcec176ca3875a55346800b3a60a83e2dd89.tar.gz chromium_src-a796bcec176ca3875a55346800b3a60a83e2dd89.tar.bz2 |
Implement SOCKSClientSocketPool
This is the first layered pool, so there are several infrastructure changes in this change as well.
Add a ConnectionTimeout method to pools so that layered pools can timeout each phase.
Add a name method to pools to support per pool UMA histograms.
Change SOCKS sockets to take a ClientSocketHandle instead of a ClientSocket
BUG=30357 (blocks an SSL Pool)
TEST=existing unit tests
Review URL: http://codereview.chromium.org/668097
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@42231 0039d316-1c4b-4281-b951-d872f2087c98
25 files changed, 1183 insertions, 262 deletions
diff --git a/net/http/http_network_session.cc b/net/http/http_network_session.cc index 79af40f..d7dbe94 100644 --- a/net/http/http_network_session.cc +++ b/net/http/http_network_session.cc @@ -20,6 +20,8 @@ int HttpNetworkSession::max_sockets_per_group_ = 6; uint16 HttpNetworkSession::g_fixed_http_port = 0; uint16 HttpNetworkSession::g_fixed_https_port = 0; +// TODO(vandebo) when we've completely converted to pools, the base TCP +// pool name should get changed to TCP instead of Transport. HttpNetworkSession::HttpNetworkSession( NetworkChangeNotifier* network_change_notifier, HostResolver* host_resolver, @@ -30,8 +32,15 @@ HttpNetworkSession::HttpNetworkSession( HttpAuthHandlerFactory* http_auth_handler_factory) : network_change_notifier_(network_change_notifier), tcp_socket_pool_(new TCPClientSocketPool( - max_sockets_, max_sockets_per_group_, + max_sockets_, max_sockets_per_group_, "Transport", host_resolver, client_socket_factory, network_change_notifier_)), + socks_socket_pool_(new SOCKSClientSocketPool( + max_sockets_, max_sockets_per_group_, "SOCKS", host_resolver, + new TCPClientSocketPool(max_sockets_, max_sockets_per_group_, + "TCPForSOCKS", host_resolver, + client_socket_factory, + network_change_notifier_), + network_change_notifier_)), socket_factory_(client_socket_factory), host_resolver_(host_resolver), proxy_service_(proxy_service), @@ -53,9 +62,12 @@ void HttpNetworkSession::set_max_sockets_per_group(int socket_count) { max_sockets_per_group_ = socket_count; } +// TODO(vandebo) when we've completely converted to pools, the base TCP +// pool name should get changed to TCP instead of Transport. void HttpNetworkSession::ReplaceTCPSocketPool() { tcp_socket_pool_ = new TCPClientSocketPool(max_sockets_, max_sockets_per_group_, + "Transport", host_resolver_, socket_factory_, network_change_notifier_); diff --git a/net/http/http_network_session.h b/net/http/http_network_session.h index 633798ec..5b2cb97 100644 --- a/net/http/http_network_session.h +++ b/net/http/http_network_session.h @@ -12,6 +12,7 @@ #include "net/http/http_alternate_protocols.h" #include "net/http/http_auth_cache.h" #include "net/proxy/proxy_service.h" +#include "net/socket/socks_client_socket_pool.h" #include "net/socket/tcp_client_socket_pool.h" namespace net { @@ -46,8 +47,13 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession> { } // TCP sockets come from the tcp_socket_pool(). - TCPClientSocketPool* tcp_socket_pool() { return tcp_socket_pool_; } - // SSL sockets come frmo the socket_factory(). + const scoped_refptr<TCPClientSocketPool>& tcp_socket_pool() { + return tcp_socket_pool_; + } + const scoped_refptr<SOCKSClientSocketPool>& socks_socket_pool() { + return socks_socket_pool_; + } + // SSL sockets come from the socket_factory(). ClientSocketFactory* socket_factory() { return socket_factory_; } HostResolver* host_resolver() { return host_resolver_; } ProxyService* proxy_service() { return proxy_service_; } @@ -93,6 +99,7 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession> { HttpAlternateProtocols alternate_protocols_; NetworkChangeNotifier* const network_change_notifier_; scoped_refptr<TCPClientSocketPool> tcp_socket_pool_; + scoped_refptr<SOCKSClientSocketPool> socks_socket_pool_; ClientSocketFactory* socket_factory_; scoped_refptr<HostResolver> host_resolver_; scoped_refptr<ProxyService> proxy_service_; diff --git a/net/http/http_network_transaction.cc b/net/http/http_network_transaction.cc index 8c81725..1a2c9a0 100644 --- a/net/http/http_network_transaction.cc +++ b/net/http/http_network_transaction.cc @@ -30,8 +30,7 @@ #include "net/http/http_response_info.h" #include "net/http/http_util.h" #include "net/socket/client_socket_factory.h" -#include "net/socket/socks5_client_socket.h" -#include "net/socket/socks_client_socket.h" +#include "net/socket/socks_client_socket_pool.h" #include "net/socket/ssl_client_socket.h" #include "net/socket/tcp_client_socket_pool.h" #include "net/spdy/spdy_session.h" @@ -521,15 +520,6 @@ int HttpNetworkTransaction::DoLoop(int result) { rv = DoInitConnectionComplete(rv); TRACE_EVENT_END("http.init_conn", request_, request_->url.spec()); break; - case STATE_SOCKS_CONNECT: - DCHECK_EQ(OK, rv); - TRACE_EVENT_BEGIN("http.socks_connect", request_, request_->url.spec()); - rv = DoSOCKSConnect(); - break; - case STATE_SOCKS_CONNECT_COMPLETE: - rv = DoSOCKSConnectComplete(rv); - TRACE_EVENT_END("http.socks_connect", request_, request_->url.spec()); - break; case STATE_SSL_CONNECT: DCHECK_EQ(OK, rv); TRACE_EVENT_BEGIN("http.ssl_connect", request_, request_->url.spec()); @@ -674,6 +664,7 @@ int HttpNetworkTransaction::DoInitConnection() { using_ssl_ = request_->url.SchemeIs("https"); using_spdy_ = false; + // TODO(vandebo) get rid of proxy_mode_, it's redundant if (proxy_info_.is_direct()) proxy_mode_ = kDirectConnection; else if (proxy_info_.proxy_server().is_socks()) @@ -745,9 +736,24 @@ int HttpNetworkTransaction::DoInitConnection() { TCPSocketParams tcp_params(host, port, request_->priority, request_->referrer, disable_resolver_cache); - int rv = connection_->Init(connection_group, tcp_params, request_->priority, - &io_callback_, session_->tcp_socket_pool(), - net_log_); + int rv; + if (proxy_mode_ != kSOCKSProxy) { + rv = connection_->Init(connection_group, tcp_params, request_->priority, + &io_callback_, session_->tcp_socket_pool(), + net_log_); + } else { + bool socks_v5 = proxy_info_.proxy_server().scheme() == + ProxyServer::SCHEME_SOCKS5; + SOCKSSocketParams socks_params(tcp_params, socks_v5, + request_->url.HostNoBrackets(), + request_->url.EffectiveIntPort(), + request_->priority, request_->referrer); + + rv = connection_->Init(connection_group, socks_params, request_->priority, + &io_callback_, session_->socks_socket_pool(), + net_log_); + } + return rv; } @@ -770,7 +776,7 @@ int HttpNetworkTransaction::DoInitConnectionComplete(int result) { return OK; } - LogTCPConnectedMetrics(*connection_); + LogHttpConnectedMetrics(*connection_); // Set the reused_socket_ flag to indicate that we are using a keep-alive // connection. This flag is used to handle errors that occur while we are @@ -782,9 +788,8 @@ int HttpNetworkTransaction::DoInitConnectionComplete(int result) { // Now we have a TCP connected socket. Perform other connection setup as // needed. UpdateConnectionTypeHistograms(CONNECTION_HTTP); - if (proxy_mode_ == kSOCKSProxy) - next_state_ = STATE_SOCKS_CONNECT; - else if (using_ssl_ && proxy_mode_ == kDirectConnection) { + if (using_ssl_ && (proxy_mode_ == kDirectConnection || + proxy_mode_ == kSOCKSProxy)) { next_state_ = STATE_SSL_CONNECT; } else { next_state_ = STATE_SEND_REQUEST; @@ -796,41 +801,6 @@ int HttpNetworkTransaction::DoInitConnectionComplete(int result) { return OK; } -int HttpNetworkTransaction::DoSOCKSConnect() { - DCHECK_EQ(kSOCKSProxy, proxy_mode_); - - next_state_ = STATE_SOCKS_CONNECT_COMPLETE; - - // Add a SOCKS connection on top of our existing transport socket. - ClientSocket* s = connection_->release_socket(); - HostResolver::RequestInfo req_info(request_->url.HostNoBrackets(), - request_->url.EffectiveIntPort()); - req_info.set_referrer(request_->referrer); - req_info.set_priority(request_->priority); - - if (proxy_info_.proxy_server().scheme() == ProxyServer::SCHEME_SOCKS5) - s = new SOCKS5ClientSocket(s, req_info); - else - s = new SOCKSClientSocket(s, req_info, session_->host_resolver()); - connection_->set_socket(s); - return connection_->socket()->Connect(&io_callback_, net_log_); -} - -int HttpNetworkTransaction::DoSOCKSConnectComplete(int result) { - DCHECK_EQ(kSOCKSProxy, proxy_mode_); - - if (result == OK) { - if (using_ssl_) { - next_state_ = STATE_SSL_CONNECT; - } else { - next_state_ = STATE_SEND_REQUEST; - } - } else { - result = ReconsiderProxyAfterError(result); - } - return result; -} - int HttpNetworkTransaction::DoSSLConnect() { next_state_ = STATE_SSL_CONNECT_COMPLETE; @@ -1262,42 +1232,32 @@ int HttpNetworkTransaction::DoSpdyReadBodyComplete(int result) { return result; } -void HttpNetworkTransaction::LogTCPConnectedMetrics( +void HttpNetworkTransaction::LogHttpConnectedMetrics( const ClientSocketHandle& handle) { - const base::TimeDelta time_to_obtain_connected_socket = - base::TimeTicks::Now() - handle.init_time(); - - if (handle.reuse_type() == ClientSocketHandle::UNUSED) { - UMA_HISTOGRAM_CUSTOM_TIMES( - "Net.HttpConnectionLatency", - time_to_obtain_connected_socket, - base::TimeDelta::FromMilliseconds(1), base::TimeDelta::FromMinutes(10), - 100); - } - - UMA_HISTOGRAM_ENUMERATION("Net.TCPSocketType", handle.reuse_type(), + UMA_HISTOGRAM_ENUMERATION("Net.HttpSocketType", handle.reuse_type(), ClientSocketHandle::NUM_TYPES); - UMA_HISTOGRAM_CLIPPED_TIMES( - "Net.TransportSocketRequestTime", - time_to_obtain_connected_socket, - base::TimeDelta::FromMilliseconds(1), base::TimeDelta::FromMinutes(10), - 100); - switch (handle.reuse_type()) { case ClientSocketHandle::UNUSED: + UMA_HISTOGRAM_CUSTOM_TIMES("Net.HttpConnectionLatency", + handle.setup_time(), + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); break; case ClientSocketHandle::UNUSED_IDLE: - UMA_HISTOGRAM_CUSTOM_TIMES( - "Net.SocketIdleTimeBeforeNextUse_UnusedSocket", - handle.idle_time(), base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(6), 100); + UMA_HISTOGRAM_CUSTOM_TIMES("Net.SocketIdleTimeBeforeNextUse_UnusedSocket", + handle.idle_time(), + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(6), + 100); break; case ClientSocketHandle::REUSED_IDLE: - UMA_HISTOGRAM_CUSTOM_TIMES( - "Net.SocketIdleTimeBeforeNextUse_ReusedSocket", - handle.idle_time(), base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(6), 100); + UMA_HISTOGRAM_CUSTOM_TIMES("Net.SocketIdleTimeBeforeNextUse_ReusedSocket", + handle.idle_time(), + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(6), + 100); break; default: NOTREACHED(); diff --git a/net/http/http_network_transaction.h b/net/http/http_network_transaction.h index 5e51e90..ca4e882 100644 --- a/net/http/http_network_transaction.h +++ b/net/http/http_network_transaction.h @@ -1,4 +1,4 @@ -// Copyright (c) 2006-2009 The Chromium Authors. All rights reserved. +// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. @@ -77,8 +77,6 @@ class HttpNetworkTransaction : public HttpTransaction { STATE_RESOLVE_PROXY_COMPLETE, STATE_INIT_CONNECTION, STATE_INIT_CONNECTION_COMPLETE, - STATE_SOCKS_CONNECT, - STATE_SOCKS_CONNECT_COMPLETE, STATE_SSL_CONNECT, STATE_SSL_CONNECT_COMPLETE, STATE_SEND_REQUEST, @@ -125,8 +123,6 @@ class HttpNetworkTransaction : public HttpTransaction { int DoResolveProxyComplete(int result); int DoInitConnection(); int DoInitConnectionComplete(int result); - int DoSOCKSConnect(); - int DoSOCKSConnectComplete(int result); int DoSSLConnect(); int DoSSLConnectComplete(int result); int DoSendRequest(); @@ -145,7 +141,7 @@ class HttpNetworkTransaction : public HttpTransaction { int DoSpdyReadBodyComplete(int result); // Record histograms of latency until Connect() completes. - static void LogTCPConnectedMetrics(const ClientSocketHandle& handle); + static void LogHttpConnectedMetrics(const ClientSocketHandle& handle); // Record histogram of time until first byte of header is received. void LogTransactionConnectedMetrics() const; diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index caf03fd..b25dcfb 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -182,9 +182,14 @@ std::string MockGetHostName() { return "WTC-WIN7"; } -class CaptureGroupNameSocketPool : public TCPClientSocketPool { +template<typename EmulatedClientSocketPool, typename SocketSourceType> +class CaptureGroupNameSocketPool : public EmulatedClientSocketPool { public: - CaptureGroupNameSocketPool() : TCPClientSocketPool(0, 0, NULL, NULL, NULL) {} + CaptureGroupNameSocketPool(HttpNetworkSession* session, + SocketSourceType* socket_source) + : EmulatedClientSocketPool(0, 0, "CaptureGroupNameTestPool", + session->host_resolver(), socket_source, + NULL) {} const std::string last_group_name_received() const { return last_group_name_; } @@ -216,11 +221,18 @@ class CaptureGroupNameSocketPool : public TCPClientSocketPool { const ClientSocketHandle* handle) const { return LOAD_STATE_IDLE; } + virtual base::TimeDelta ConnectionTimeout() const { + return base::TimeDelta(); + } private: std::string last_group_name_; }; +typedef CaptureGroupNameSocketPool<TCPClientSocketPool, ClientSocketFactory> + CaptureGroupNameTCPSocketPool; +typedef CaptureGroupNameSocketPool<SOCKSClientSocketPool, TCPClientSocketPool> + CaptureGroupNameSOCKSSocketPool; //----------------------------------------------------------------------------- TEST_F(HttpNetworkTransactionTest, Basic) { @@ -3654,11 +3666,16 @@ TEST_F(HttpNetworkTransactionTest, GroupNameForProxyConnections) { SessionDependencies session_deps( CreateFixedProxyService(tests[i].proxy_server)); - scoped_refptr<CaptureGroupNameSocketPool> conn_pool( - new CaptureGroupNameSocketPool()); - scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps)); - session->tcp_socket_pool_ = conn_pool.get(); + + scoped_refptr<CaptureGroupNameTCPSocketPool> tcp_conn_pool( + new CaptureGroupNameTCPSocketPool(session.get(), + session->socket_factory())); + session->tcp_socket_pool_ = tcp_conn_pool.get(); + scoped_refptr<CaptureGroupNameSOCKSSocketPool> socks_conn_pool( + new CaptureGroupNameSOCKSSocketPool(session.get(), + tcp_conn_pool.get())); + session->socks_socket_pool_ = socks_conn_pool.get(); scoped_ptr<HttpTransaction> trans(new HttpNetworkTransaction(session)); @@ -3671,8 +3688,9 @@ TEST_F(HttpNetworkTransactionTest, GroupNameForProxyConnections) { // We do not complete this request, the dtor will clean the transaction up. EXPECT_EQ(ERR_IO_PENDING, trans->Start(&request, &callback, NULL)); - EXPECT_EQ(tests[i].expected_group_name, - conn_pool->last_group_name_received()); + std::string allgroups = tcp_conn_pool->last_group_name_received() + + socks_conn_pool->last_group_name_received(); + EXPECT_EQ(tests[i].expected_group_name, allgroups); } } diff --git a/net/net.gyp b/net/net.gyp index 7e22af8..1b66ab8 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -424,6 +424,8 @@ 'socket/socks5_client_socket.h', 'socket/socks_client_socket.cc', 'socket/socks_client_socket.h', + 'socket/socks_client_socket_pool.cc', + 'socket/socks_client_socket_pool.h', 'socket/ssl_client_socket.h', 'socket/ssl_client_socket_mac.cc', 'socket/ssl_client_socket_mac.h', @@ -681,6 +683,7 @@ 'proxy/single_threaded_proxy_resolver_unittest.cc', 'socket/client_socket_pool_base_unittest.cc', 'socket/socks5_client_socket_unittest.cc', + 'socket/socks_client_socket_pool_unittest.cc', 'socket/socks_client_socket_unittest.cc', 'socket/ssl_client_socket_unittest.cc', 'socket/tcp_client_socket_pool_unittest.cc', diff --git a/net/socket/client_socket_handle.cc b/net/socket/client_socket_handle.cc index d8b0f8f..d99c7ad 100644 --- a/net/socket/client_socket_handle.cc +++ b/net/socket/client_socket_handle.cc @@ -1,10 +1,11 @@ -// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved. +// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/socket/client_socket_handle.h" #include "base/compiler_specific.h" +#include "base/histogram.h" #include "base/logging.h" #include "net/base/net_errors.h" #include "net/socket/client_socket_pool.h" @@ -29,9 +30,12 @@ void ClientSocketHandle::ResetInternal(bool cancel) { if (group_name_.empty()) // Was Init called? return; if (socket_.get()) { - // If we've still got a socket, release it back to the ClientSocketPool so - // it can be deleted or reused. - pool_->ReleaseSocket(group_name_, release_socket()); + // Because of http://crbug.com/37810 we may not have a pool, but have + // just a raw socket. + if (pool_) + // If we've still got a socket, release it back to the ClientSocketPool so + // it can be deleted or reused. + pool_->ReleaseSocket(group_name_, release_socket()); } else if (cancel) { // If we did not get initialized yet, so we've got a socket request pending. // Cancel it. @@ -43,11 +47,16 @@ void ClientSocketHandle::ResetInternal(bool cancel) { pool_ = NULL; idle_time_ = base::TimeDelta(); init_time_ = base::TimeTicks(); + setup_time_ = base::TimeDelta(); } LoadState ClientSocketHandle::GetLoadState() const { CHECK(!is_initialized()); CHECK(!group_name_.empty()); + // Because of http://crbug.com/37810 we may not have a pool, but have + // just a raw socket. + if (!pool_) + return LOAD_STATE_IDLE; return pool_->GetLoadState(group_name_, this); } @@ -62,8 +71,39 @@ void ClientSocketHandle::HandleInitCompletion(int result) { CHECK_NE(ERR_IO_PENDING, result); // TODO(vandebo) remove when bug 31096 is resolved CHECK(socket_.get() || result != OK); - if (result != OK) + if (result != OK) { ResetInternal(false); // The request failed, so there's nothing to cancel. + return; + } + setup_time_ = base::TimeTicks::Now() - init_time_; + + std::string metric = "Net." + pool_->name() + "SocketType"; + UMA_HISTOGRAM_ENUMERATION(metric, reuse_type(), NUM_TYPES); + switch (reuse_type()) { + case ClientSocketHandle::UNUSED: + metric = "Net." + pool_->name() + "SocketRequestTime"; + UMA_HISTOGRAM_CLIPPED_TIMES(metric, setup_time(), + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), 100); + break; + case ClientSocketHandle::UNUSED_IDLE: + metric = "Net." + pool_->name() + + "SocketIdleTimeBeforeNextUse_UnusedSocket"; + UMA_HISTOGRAM_CUSTOM_TIMES(metric, idle_time(), + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(6), 100); + break; + case ClientSocketHandle::REUSED_IDLE: + metric = "Net." + pool_->name() + + "SocketIdleTimeBeforeNextUse_ReusedSocket"; + UMA_HISTOGRAM_CUSTOM_TIMES(metric, idle_time(), + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(6), 100); + break; + default: + NOTREACHED(); + break; + } } } // namespace net diff --git a/net/socket/client_socket_handle.h b/net/socket/client_socket_handle.h index 782ad48..f750c7b 100644 --- a/net/socket/client_socket_handle.h +++ b/net/socket/client_socket_handle.h @@ -1,4 +1,4 @@ -// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved. +// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. @@ -62,7 +62,7 @@ class ClientSocketHandle { const SocketParams& socket_params, RequestPriority priority, CompletionCallback* callback, - PoolType* pool, + const scoped_refptr<PoolType>& pool, const BoundNetLog& net_log); // An initialized handle can be reset, which causes it to return to the @@ -85,6 +85,9 @@ class ClientSocketHandle { // Returns the time tick when Init() was called. base::TimeTicks init_time() const { return init_time_; } + // Returns the time between Init() and when is_initialized() becomes true. + base::TimeDelta setup_time() const { return setup_time_; } + // Used by ClientSocketPool to initialize the ClientSocketHandle. void set_is_reused(bool is_reused) { is_reused_ = is_reused; } void set_socket(ClientSocket* s) { socket_.reset(s); } @@ -139,6 +142,7 @@ class ClientSocketHandle { CompletionCallback* user_callback_; base::TimeDelta idle_time_; base::TimeTicks init_time_; + base::TimeDelta setup_time_; DISALLOW_COPY_AND_ASSIGN(ClientSocketHandle); }; @@ -149,7 +153,7 @@ int ClientSocketHandle::Init(const std::string& group_name, const SocketParams& socket_params, RequestPriority priority, CompletionCallback* callback, - PoolType* pool, + const scoped_refptr<PoolType>& pool, const BoundNetLog& net_log) { CHECK(!group_name.empty()); // Note that this will result in a link error if the SocketParams has not been diff --git a/net/socket/client_socket_pool.h b/net/socket/client_socket_pool.h index 16f408b..fb77ac4 100644 --- a/net/socket/client_socket_pool.h +++ b/net/socket/client_socket_pool.h @@ -1,4 +1,4 @@ -// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved. +// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. @@ -58,7 +58,8 @@ class ClientSocketPool : public base::RefCounted<ClientSocketPool> { // Called to cancel a RequestSocket call that returned ERR_IO_PENDING. The // same handle parameter must be passed to this method as was passed to the // RequestSocket call being cancelled. The associated CompletionCallback is - // not run. + // not run. However, for performance, we will let one ConnectJob complete + // and go idle. virtual void CancelRequest(const std::string& group_name, const ClientSocketHandle* handle) = 0; @@ -85,10 +86,16 @@ class ClientSocketPool : public base::RefCounted<ClientSocketPool> { // Returns the maximum amount of time to wait before retrying a connect. static const int kMaxConnectRetryIntervalMs = 250; + // The name of this pool, i.e. TCP, SOCKS. + virtual const std::string& name() const = 0; + protected: ClientSocketPool() {} virtual ~ClientSocketPool() {} + // Return the connection timeout for this pool. + virtual base::TimeDelta ConnectionTimeout() const = 0; + private: friend class base::RefCounted<ClientSocketPool>; diff --git a/net/socket/client_socket_pool_base.cc b/net/socket/client_socket_pool_base.cc index d567743..75196f5 100644 --- a/net/socket/client_socket_pool_base.cc +++ b/net/socket/client_socket_pool_base.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved. +// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. @@ -78,6 +78,11 @@ void ConnectJob::NotifyDelegateOfCompletion(int rv) { delegate->OnConnectJobComplete(rv, this); } +void ConnectJob::ResetTimer(base::TimeDelta remaining_time) { + timer_.Stop(); + timer_.Start(remaining_time, this, &ConnectJob::OnTimeout); +} + void ConnectJob::OnTimeout() { // Make sure the socket is NULL before calling into |delegate|. set_socket(NULL); @@ -324,6 +329,7 @@ void ClientSocketPoolBaseHelper::CancelRequest( req->net_log().AddEvent(NetLog::TYPE_CANCELLED); req->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL); delete req; + // Let one connect job connect and become idle for potential future use. if (group.jobs.size() > group.pending_requests.size() + 1) { // TODO(willchan): Cancel the job in the earliest LoadState. RemoveConnectJob(*group.jobs.begin(), &group); diff --git a/net/socket/client_socket_pool_base.h b/net/socket/client_socket_pool_base.h index 08fb0f0..d1abd58 100644 --- a/net/socket/client_socket_pool_base.h +++ b/net/socket/client_socket_pool_base.h @@ -1,4 +1,4 @@ -// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // @@ -94,6 +94,7 @@ class ConnectJob { void set_socket(ClientSocket* socket) { socket_.reset(socket); } ClientSocket* socket() { return socket_.get(); } void NotifyDelegateOfCompletion(int rv); + void ResetTimer(base::TimeDelta remainingTime); private: virtual int ConnectInternal() = 0; @@ -158,6 +159,8 @@ class ClientSocketPoolBaseHelper ConnectJob::Delegate* delegate, const BoundNetLog& net_log) const = 0; + virtual base::TimeDelta ConnectionTimeout() const = 0; + private: DISALLOW_COPY_AND_ASSIGN(ConnectJobFactory); }; @@ -222,6 +225,10 @@ class ClientSocketPoolBaseHelper // sockets that timed out or can't be reused. Made public for testing. void CleanupIdleSockets(bool force); + base::TimeDelta ConnectionTimeout() const { + return connect_job_factory_->ConnectionTimeout(); + } + void enable_backup_jobs() { backup_jobs_enabled_ = true; }; private: @@ -477,6 +484,8 @@ class ClientSocketPoolBase { ConnectJob::Delegate* delegate, const BoundNetLog& net_log) const = 0; + virtual base::TimeDelta ConnectionTimeout() const = 0; + private: DISALLOW_COPY_AND_ASSIGN(ConnectJobFactory); }; @@ -490,11 +499,13 @@ class ClientSocketPoolBase { ClientSocketPoolBase( int max_sockets, int max_sockets_per_group, + const std::string& name, base::TimeDelta unused_idle_socket_timeout, base::TimeDelta used_idle_socket_timeout, ConnectJobFactory* connect_job_factory, NetworkChangeNotifier* network_change_notifier) - : helper_(new internal::ClientSocketPoolBaseHelper( + : name_(name), + helper_(new internal::ClientSocketPoolBaseHelper( max_sockets, max_sockets_per_group, unused_idle_socket_timeout, used_idle_socket_timeout, new ConnectJobFactoryAdaptor(connect_job_factory), @@ -560,6 +571,12 @@ class ClientSocketPoolBase { return helper_->CleanupIdleSockets(force); } + base::TimeDelta ConnectionTimeout() const { + return helper_->ConnectionTimeout(); + } + + const std::string& name() const { return name_; } + void enable_backup_jobs() { helper_->enable_backup_jobs(); }; private: @@ -589,9 +606,16 @@ class ClientSocketPoolBase { group_name, *casted_request, delegate, net_log); } + virtual base::TimeDelta ConnectionTimeout() const { + return connect_job_factory_->ConnectionTimeout(); + } + const scoped_ptr<ConnectJobFactory> connect_job_factory_; }; + // Name of this pool. + const std::string name_; + // One might ask why ClientSocketPoolBaseHelper is also refcounted if its // containing ClientSocketPool is already refcounted. The reason is because // DoReleaseSocket() posts a task. If ClientSocketPool gets deleted between diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index cffa75a..65feffd 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -259,6 +259,10 @@ class TestConnectJobFactory net_log); } + virtual base::TimeDelta ConnectionTimeout() const { + return timeout_duration_; + } + private: TestConnectJob::JobType job_type_; base::TimeDelta timeout_duration_; @@ -272,10 +276,11 @@ class TestClientSocketPool : public ClientSocketPool { TestClientSocketPool( int max_sockets, int max_sockets_per_group, + const std::string& name, base::TimeDelta unused_idle_socket_timeout, base::TimeDelta used_idle_socket_timeout, TestClientSocketPoolBase::ConnectJobFactory* connect_job_factory) - : base_(max_sockets, max_sockets_per_group, + : base_(max_sockets, max_sockets_per_group, name, unused_idle_socket_timeout, used_idle_socket_timeout, connect_job_factory, NULL) {} @@ -317,6 +322,12 @@ class TestClientSocketPool : public ClientSocketPool { return base_.GetLoadState(group_name, handle); } + virtual base::TimeDelta ConnectionTimeout() const { + return base_.ConnectionTimeout(); + } + + virtual const std::string& name() const { return base_.name(); } + const TestClientSocketPoolBase* base() const { return &base_; } int NumConnectJobsInGroup(const std::string& group_name) const { @@ -401,6 +412,7 @@ class ClientSocketPoolBaseTest : public ClientSocketPoolTest { connect_job_factory_ = new TestConnectJobFactory(&client_socket_factory_); pool_ = new TestClientSocketPool(max_sockets, max_sockets_per_group, + "IdleTimeoutTestPool", unused_idle_socket_timeout, used_idle_socket_timeout, connect_job_factory_); @@ -409,7 +421,7 @@ class ClientSocketPoolBaseTest : public ClientSocketPoolTest { int StartRequest(const std::string& group_name, net::RequestPriority priority) { return StartRequestUsingPool<TestClientSocketPool, TestSocketParams>( - pool_.get(), group_name, priority, NULL); + pool_, group_name, priority, NULL); } virtual void TearDown() { @@ -441,7 +453,7 @@ int InitHandle(ClientSocketHandle* handle, const std::string& group_name, net::RequestPriority priority, CompletionCallback* callback, - TestClientSocketPool* pool, + const scoped_refptr<TestClientSocketPool>& pool, const BoundNetLog& net_log) { return handle->Init<TestSocketParams, TestClientSocketPool>( group_name, NULL, priority, callback, pool, net_log); @@ -502,8 +514,8 @@ TEST_F(ClientSocketPoolBaseTest, BasicSynchronous) { ClientSocketHandle handle; CapturingBoundNetLog log(CapturingNetLog::kUnbounded); - EXPECT_EQ(OK, InitHandle(&handle, "a", kDefaultPriority, - &callback, pool_.get(), log.bound())); + EXPECT_EQ(OK, InitHandle(&handle, "a", kDefaultPriority, &callback, pool_, + log.bound())); EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); handle.Reset(); @@ -530,8 +542,8 @@ TEST_F(ClientSocketPoolBaseTest, InitConnectionFailure) { TestSocketRequest req(&request_order_, &completion_count_); EXPECT_EQ(ERR_CONNECTION_FAILED, - InitHandle(req.handle(), "a", kDefaultPriority, &req, - pool_.get(), log.bound())); + InitHandle(req.handle(), "a", kDefaultPriority, &req, pool_, + log.bound())); EXPECT_EQ(5u, log.entries().size()); EXPECT_TRUE(LogContainsBeginEvent(log.entries(), 0, NetLog::TYPE_SOCKET_POOL)); @@ -839,8 +851,7 @@ TEST_F(ClientSocketPoolBaseTest, CancelRequestClearGroup) { connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); TestSocketRequest req(&request_order_, &completion_count_); EXPECT_EQ(ERR_IO_PENDING, - InitHandle(req.handle(), "a", kDefaultPriority, &req, - pool_.get(), NULL)); + InitHandle(req.handle(), "a", kDefaultPriority, &req, pool_, NULL)); req.handle()->Reset(); } @@ -853,15 +864,13 @@ TEST_F(ClientSocketPoolBaseTest, ConnectCancelConnect) { TestSocketRequest req(&request_order_, &completion_count_); EXPECT_EQ(ERR_IO_PENDING, - InitHandle(&handle, "a", kDefaultPriority, &callback, - pool_.get(), NULL)); + InitHandle(&handle, "a", kDefaultPriority, &callback, pool_, NULL)); handle.Reset(); TestCompletionCallback callback2; - EXPECT_EQ(ERR_IO_PENDING, - InitHandle(&handle, "a", kDefaultPriority, &callback2, - pool_.get(), NULL)); + EXPECT_EQ(ERR_IO_PENDING, InitHandle(&handle, "a", kDefaultPriority, + &callback2, pool_, NULL)); EXPECT_EQ(OK, callback2.WaitForResult()); EXPECT_FALSE(callback.have_result()); @@ -934,9 +943,8 @@ class RequestSocketCallback : public CallbackRunner< Tuple1<int> > { } within_callback_ = true; TestCompletionCallback next_job_callback; - int rv = InitHandle( - handle_, "a", kDefaultPriority, &next_job_callback, pool_.get(), - NULL); + int rv = InitHandle(handle_, "a", kDefaultPriority, &next_job_callback, + pool_, NULL); switch (next_job_type_) { case TestConnectJob::kMockJob: EXPECT_EQ(OK, rv); @@ -985,8 +993,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestPendingJobTwice) { RequestSocketCallback callback( &handle, pool_.get(), connect_job_factory_, TestConnectJob::kMockPendingJob); - int rv = InitHandle(&handle, "a", kDefaultPriority, &callback, - pool_.get(), NULL); + int rv = InitHandle(&handle, "a", kDefaultPriority, &callback, pool_, NULL); ASSERT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(OK, callback.WaitForResult()); @@ -999,8 +1006,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestPendingJobThenSynchronous) { ClientSocketHandle handle; RequestSocketCallback callback( &handle, pool_.get(), connect_job_factory_, TestConnectJob::kMockJob); - int rv = InitHandle(&handle, "a", kDefaultPriority, &callback, - pool_.get(), NULL); + int rv = InitHandle(&handle, "a", kDefaultPriority, &callback, pool_, NULL); ASSERT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(OK, callback.WaitForResult()); @@ -1061,15 +1067,13 @@ TEST_F(ClientSocketPoolBaseTest, CancelActiveRequestThenRequestSocket) { connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); TestSocketRequest req(&request_order_, &completion_count_); - int rv = InitHandle(req.handle(), "a", kDefaultPriority, &req, - pool_.get(), NULL); + int rv = InitHandle(req.handle(), "a", kDefaultPriority, &req, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); // Cancel the active request. req.handle()->Reset(); - rv = InitHandle(req.handle(), "a", kDefaultPriority, &req, - pool_.get(), NULL); + rv = InitHandle(req.handle(), "a", kDefaultPriority, &req, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(OK, req.WaitForResult()); @@ -1119,7 +1123,7 @@ TEST_F(ClientSocketPoolBaseTest, BasicAsynchronous) { connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); TestSocketRequest req(&request_order_, &completion_count_); CapturingBoundNetLog log(CapturingNetLog::kUnbounded); - int rv = InitHandle(req.handle(), "a", LOWEST, &req, pool_.get(), log.bound()); + int rv = InitHandle(req.handle(), "a", LOWEST, &req, pool_, log.bound()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", req.handle())); EXPECT_EQ(OK, req.WaitForResult()); @@ -1146,8 +1150,8 @@ TEST_F(ClientSocketPoolBaseTest, TestSocketRequest req(&request_order_, &completion_count_); CapturingBoundNetLog log(CapturingNetLog::kUnbounded); EXPECT_EQ(ERR_IO_PENDING, - InitHandle(req.handle(), "a", kDefaultPriority, &req, - pool_.get(), log.bound())); + InitHandle(req.handle(), "a", kDefaultPriority, &req, pool_, + log.bound())); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", req.handle())); EXPECT_EQ(ERR_CONNECTION_FAILED, req.WaitForResult()); @@ -1172,12 +1176,12 @@ TEST_F(ClientSocketPoolBaseTest, TwoRequestsCancelOne) { TestSocketRequest req2(&request_order_, &completion_count_); EXPECT_EQ(ERR_IO_PENDING, - InitHandle(req.handle(), "a", kDefaultPriority, &req, - pool_.get(), BoundNetLog())); + InitHandle(req.handle(), "a", kDefaultPriority, &req, pool_, + BoundNetLog())); CapturingBoundNetLog log2(CapturingNetLog::kUnbounded); EXPECT_EQ(ERR_IO_PENDING, - InitHandle(req2.handle(), "a", kDefaultPriority, &req2, - pool_.get(), BoundNetLog())); + InitHandle(req2.handle(), "a", kDefaultPriority, &req2, pool_, + BoundNetLog())); req.handle()->Reset(); @@ -1222,8 +1226,7 @@ TEST_F(ClientSocketPoolBaseTest, ReleaseSockets) { connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); TestSocketRequest req1(&request_order_, &completion_count_); - int rv = InitHandle(req1.handle(), "a", kDefaultPriority, - &req1, pool_.get(), NULL); + int rv = InitHandle(req1.handle(), "a", kDefaultPriority, &req1, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(OK, req1.WaitForResult()); @@ -1232,12 +1235,10 @@ TEST_F(ClientSocketPoolBaseTest, ReleaseSockets) { connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); TestSocketRequest req2(&request_order_, &completion_count_); - rv = InitHandle(req2.handle(), "a", kDefaultPriority, &req2, - pool_.get(), NULL); + rv = InitHandle(req2.handle(), "a", kDefaultPriority, &req2, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); TestSocketRequest req3(&request_order_, &completion_count_); - rv = InitHandle( - req3.handle(), "a", kDefaultPriority, &req3, pool_.get(), NULL); + rv = InitHandle(req3.handle(), "a", kDefaultPriority, &req3, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); // Both Requests 2 and 3 are pending. We release socket 1 which should @@ -1268,21 +1269,18 @@ TEST_F(ClientSocketPoolBaseTest, PendingJobCompletionOrder) { connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob); TestSocketRequest req1(&request_order_, &completion_count_); - int rv = InitHandle( - req1.handle(), "a", kDefaultPriority, &req1, pool_.get(), NULL); + int rv = InitHandle(req1.handle(), "a", kDefaultPriority, &req1, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); TestSocketRequest req2(&request_order_, &completion_count_); - rv = InitHandle(req2.handle(), "a", kDefaultPriority, &req2, - pool_.get(), NULL); + rv = InitHandle(req2.handle(), "a", kDefaultPriority, &req2, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); // The pending job is sync. connect_job_factory_->set_job_type(TestConnectJob::kMockJob); TestSocketRequest req3(&request_order_, &completion_count_); - rv = InitHandle( - req3.handle(), "a", kDefaultPriority, &req3, pool_.get(), NULL); + rv = InitHandle(req3.handle(), "a", kDefaultPriority, &req3, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(ERR_CONNECTION_FAILED, req1.WaitForResult()); @@ -1301,16 +1299,14 @@ TEST_F(ClientSocketPoolBaseTest, DISABLED_LoadState) { TestConnectJob::kMockAdvancingLoadStateJob); TestSocketRequest req1(&request_order_, &completion_count_); - int rv = InitHandle( - req1.handle(), "a", kDefaultPriority, &req1, pool_.get(), NULL); + int rv = InitHandle(req1.handle(), "a", kDefaultPriority, &req1, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_IDLE, req1.handle()->GetLoadState()); MessageLoop::current()->RunAllPending(); TestSocketRequest req2(&request_order_, &completion_count_); - rv = InitHandle(req2.handle(), "a", kDefaultPriority, &req2, - pool_.get(), NULL); + rv = InitHandle(req2.handle(), "a", kDefaultPriority, &req2, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_WAITING_FOR_CACHE, req1.handle()->GetLoadState()); EXPECT_EQ(LOAD_STATE_WAITING_FOR_CACHE, req2.handle()->GetLoadState()); @@ -1327,12 +1323,12 @@ TEST_F(ClientSocketPoolBaseTest, CleanupTimedOutIdleSockets) { // Startup two mock pending connect jobs, which will sit in the MessageLoop. TestSocketRequest req(&request_order_, &completion_count_); - int rv = InitHandle(req.handle(), "a", LOWEST, &req, pool_.get(), NULL); + int rv = InitHandle(req.handle(), "a", LOWEST, &req, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", req.handle())); TestSocketRequest req2(&request_order_, &completion_count_); - rv = InitHandle(req2.handle(), "a", LOWEST, &req2, pool_.get(), NULL); + rv = InitHandle(req2.handle(), "a", LOWEST, &req2, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", req2.handle())); @@ -1359,7 +1355,7 @@ TEST_F(ClientSocketPoolBaseTest, CleanupTimedOutIdleSockets) { pool_->CleanupTimedOutIdleSockets(); CapturingBoundNetLog log(CapturingNetLog::kUnbounded); - rv = InitHandle(req.handle(), "a", LOWEST, &req, pool_.get(), log.bound()); + rv = InitHandle(req.handle(), "a", LOWEST, &req, pool_, log.bound()); EXPECT_EQ(OK, rv); EXPECT_TRUE(req.handle()->is_reused()); EXPECT_TRUE(LogContainsEntryWithType( @@ -1379,19 +1375,19 @@ TEST_F(ClientSocketPoolBaseTest, MultipleReleasingDisconnectedSockets) { // Startup 4 connect jobs. Two of them will be pending. TestSocketRequest req(&request_order_, &completion_count_); - int rv = InitHandle(req.handle(), "a", LOWEST, &req, pool_.get(), NULL); + int rv = InitHandle(req.handle(), "a", LOWEST, &req, pool_, NULL); EXPECT_EQ(OK, rv); TestSocketRequest req2(&request_order_, &completion_count_); - rv = InitHandle(req2.handle(), "a", LOWEST, &req2, pool_.get(), NULL); + rv = InitHandle(req2.handle(), "a", LOWEST, &req2, pool_, NULL); EXPECT_EQ(OK, rv); TestSocketRequest req3(&request_order_, &completion_count_); - rv = InitHandle(req3.handle(), "a", LOWEST, &req3, pool_.get(), NULL); + rv = InitHandle(req3.handle(), "a", LOWEST, &req3, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); TestSocketRequest req4(&request_order_, &completion_count_); - rv = InitHandle(req4.handle(), "a", LOWEST, &req4, pool_.get(), NULL); + rv = InitHandle(req4.handle(), "a", LOWEST, &req4, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); // Release two disconnected sockets. @@ -1453,9 +1449,8 @@ class TestReleasingSocketRequest : public CallbackRunner< Tuple1<int> > { virtual void RunWithParams(const Tuple1<int>& params) { callback_.RunWithParams(params); handle_.Reset(); - EXPECT_EQ(ERR_IO_PENDING, - InitHandle(&handle2_, "a", kDefaultPriority, - &callback2_, pool_, NULL)); + EXPECT_EQ(ERR_IO_PENDING, InitHandle(&handle2_, "a", kDefaultPriority, + &callback2_, pool_, NULL)); } private: @@ -1477,17 +1472,16 @@ TEST_F(ClientSocketPoolBaseTest, ReleasedSocketReleasesToo) { // Complete one request and release the socket. ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(OK, InitHandle( - &handle, "a", kDefaultPriority, &callback, pool_.get(), NULL)); + EXPECT_EQ(OK, InitHandle(&handle, "a", kDefaultPriority, &callback, pool_, + NULL)); handle.Reset(); // Before the DoReleaseSocket() task has run, start up a // TestReleasingSocketRequest. This one will be ERR_IO_PENDING since // num_releasing_sockets > 0 and there was no idle socket to use yet. TestReleasingSocketRequest request(pool_.get()); - EXPECT_EQ(ERR_IO_PENDING, - InitHandle(request.handle(), "a", kDefaultPriority, &request, - pool_.get(), NULL)); + EXPECT_EQ(ERR_IO_PENDING, InitHandle(request.handle(), "a", kDefaultPriority, + &request, pool_, NULL)); EXPECT_EQ(OK, request.WaitForResult()); } diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index ba0c8dc..8b8153e 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -472,4 +472,18 @@ void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) { } while (released_one); } +const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 }; +const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest); + +const char kSOCKS5GreetResponse[] = { 0x05, 0x00 }; +const int kSOCKS5GreetResponseLength = arraysize(kSOCKS5GreetResponse); + +const char kSOCKS5OkRequest[] = + { 0x05, 0x01, 0x00, 0x03, 0x04, 'h', 'o', 's', 't', 0x00, 0x50 }; +const int kSOCKS5OkRequestLength = arraysize(kSOCKS5OkRequest); + +const char kSOCKS5OkResponse[] = + { 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50 }; +const int kSOCKS5OkResponseLength = arraysize(kSOCKS5OkResponse); + } // namespace net diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 69b78da..2daa901 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -422,11 +422,11 @@ class ClientSocketPoolTest : public testing::Test { virtual void TearDown(); template <typename PoolType, typename SocketParams> - int StartRequestUsingPool(PoolType* socket_pool, + int StartRequestUsingPool(const scoped_refptr<PoolType>& socket_pool, const std::string& group_name, RequestPriority priority, const SocketParams& socket_params) { - DCHECK(socket_pool); + DCHECK(socket_pool.get()); TestSocketRequest* request = new TestSocketRequest(&request_order_, &completion_count_); requests_.push_back(request); @@ -456,6 +456,19 @@ class ClientSocketPoolTest : public testing::Test { size_t completion_count_; }; +// Constants for a successful SOCKS v5 handshake. +extern const char kSOCKS5GreetRequest[]; +extern const int kSOCKS5GreetRequestLength; + +extern const char kSOCKS5GreetResponse[]; +extern const int kSOCKS5GreetResponseLength; + +extern const char kSOCKS5OkRequest[]; +extern const int kSOCKS5OkRequestLength; + +extern const char kSOCKS5OkResponse[]; +extern const int kSOCKS5OkResponseLength; + } // namespace net #endif // NET_SOCKET_SOCKET_TEST_UTIL_H_ diff --git a/net/socket/socks5_client_socket.cc b/net/socket/socks5_client_socket.cc index 0e29e2c..eef3403 100644 --- a/net/socket/socks5_client_socket.cc +++ b/net/socket/socks5_client_socket.cc @@ -13,6 +13,7 @@ #include "net/base/net_log.h" #include "net/base/net_util.h" #include "net/base/sys_addrinfo.h" +#include "net/socket/client_socket_handle.h" namespace net { @@ -46,7 +47,8 @@ const uint8 SOCKS5ClientSocket::kNullByte = 0x00; COMPILE_ASSERT(sizeof(struct in_addr) == 4, incorrect_system_size_of_IPv4); COMPILE_ASSERT(sizeof(struct in6_addr) == 16, incorrect_system_size_of_IPv6); -SOCKS5ClientSocket::SOCKS5ClientSocket(ClientSocket* transport_socket, +SOCKS5ClientSocket::SOCKS5ClientSocket( + ClientSocketHandle* transport_socket, const HostResolver::RequestInfo& req_info) : ALLOW_THIS_IN_INITIALIZER_LIST( io_callback_(this, &SOCKS5ClientSocket::OnIOComplete)), @@ -60,6 +62,22 @@ SOCKS5ClientSocket::SOCKS5ClientSocket(ClientSocket* transport_socket, host_request_info_(req_info) { } +SOCKS5ClientSocket::SOCKS5ClientSocket( + ClientSocket* transport_socket, + const HostResolver::RequestInfo& req_info) + : ALLOW_THIS_IN_INITIALIZER_LIST( + io_callback_(this, &SOCKS5ClientSocket::OnIOComplete)), + transport_(new ClientSocketHandle()), + next_state_(STATE_NONE), + user_callback_(NULL), + completed_handshake_(false), + bytes_sent_(0), + bytes_received_(0), + read_header_size(kReadHeaderSize), + host_request_info_(req_info) { + transport_->set_socket(transport_socket); +} + SOCKS5ClientSocket::~SOCKS5ClientSocket() { Disconnect(); } @@ -67,7 +85,8 @@ SOCKS5ClientSocket::~SOCKS5ClientSocket() { int SOCKS5ClientSocket::Connect(CompletionCallback* callback, const BoundNetLog& net_log) { DCHECK(transport_.get()); - DCHECK(transport_->IsConnected()); + DCHECK(transport_->socket()); + DCHECK(transport_->socket()->IsConnected()); DCHECK_EQ(STATE_NONE, next_state_); DCHECK(!user_callback_); @@ -93,7 +112,7 @@ int SOCKS5ClientSocket::Connect(CompletionCallback* callback, void SOCKS5ClientSocket::Disconnect() { completed_handshake_ = false; - transport_->Disconnect(); + transport_->socket()->Disconnect(); // Reset other states to make sure they aren't mistakenly used later. // These are the states initialized by Connect(). @@ -103,11 +122,11 @@ void SOCKS5ClientSocket::Disconnect() { } bool SOCKS5ClientSocket::IsConnected() const { - return completed_handshake_ && transport_->IsConnected(); + return completed_handshake_ && transport_->socket()->IsConnected(); } bool SOCKS5ClientSocket::IsConnectedAndIdle() const { - return completed_handshake_ && transport_->IsConnectedAndIdle(); + return completed_handshake_ && transport_->socket()->IsConnectedAndIdle(); } // Read is called by the transport layer above to read. This can only be done @@ -118,7 +137,7 @@ int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len, DCHECK_EQ(STATE_NONE, next_state_); DCHECK(!user_callback_); - return transport_->Read(buf, buf_len, callback); + return transport_->socket()->Read(buf, buf_len, callback); } // Write is called by the transport layer. This can only be done if the @@ -129,15 +148,15 @@ int SOCKS5ClientSocket::Write(IOBuffer* buf, int buf_len, DCHECK_EQ(STATE_NONE, next_state_); DCHECK(!user_callback_); - return transport_->Write(buf, buf_len, callback); + return transport_->socket()->Write(buf, buf_len, callback); } bool SOCKS5ClientSocket::SetReceiveBufferSize(int32 size) { - return transport_->SetReceiveBufferSize(size); + return transport_->socket()->SetReceiveBufferSize(size); } bool SOCKS5ClientSocket::SetSendBufferSize(int32 size) { - return transport_->SetSendBufferSize(size); + return transport_->socket()->SetSendBufferSize(size); } void SOCKS5ClientSocket::DoCallback(int result) { @@ -236,7 +255,8 @@ int SOCKS5ClientSocket::DoGreetWrite() { handshake_buf_ = new IOBuffer(handshake_buf_len); memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_], handshake_buf_len); - return transport_->Write(handshake_buf_, handshake_buf_len, &io_callback_); + return transport_->socket()->Write(handshake_buf_, handshake_buf_len, + &io_callback_); } int SOCKS5ClientSocket::DoGreetWriteComplete(int result) { @@ -258,7 +278,8 @@ int SOCKS5ClientSocket::DoGreetRead() { next_state_ = STATE_GREET_READ_COMPLETE; size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_; handshake_buf_ = new IOBuffer(handshake_buf_len); - return transport_->Read(handshake_buf_, handshake_buf_len, &io_callback_); + return transport_->socket()->Read(handshake_buf_, handshake_buf_len, + &io_callback_); } int SOCKS5ClientSocket::DoGreetReadComplete(int result) { @@ -335,7 +356,8 @@ int SOCKS5ClientSocket::DoHandshakeWrite() { handshake_buf_ = new IOBuffer(handshake_buf_len); memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], handshake_buf_len); - return transport_->Write(handshake_buf_, handshake_buf_len, &io_callback_); + return transport_->socket()->Write(handshake_buf_, handshake_buf_len, + &io_callback_); } int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) { @@ -368,7 +390,8 @@ int SOCKS5ClientSocket::DoHandshakeRead() { int handshake_buf_len = read_header_size - bytes_received_; handshake_buf_ = new IOBuffer(handshake_buf_len); - return transport_->Read(handshake_buf_, handshake_buf_len, &io_callback_); + return transport_->socket()->Read(handshake_buf_, handshake_buf_len, + &io_callback_); } int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) { @@ -448,7 +471,7 @@ int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) { } int SOCKS5ClientSocket::GetPeerAddress(AddressList* address) const { - return transport_->GetPeerAddress(address); + return transport_->socket()->GetPeerAddress(address); } } // namespace net diff --git a/net/socket/socks5_client_socket.h b/net/socket/socks5_client_socket.h index ae3ef76..3e30c19 100644 --- a/net/socket/socks5_client_socket.h +++ b/net/socket/socks5_client_socket.h @@ -21,6 +21,7 @@ namespace net { +class ClientSocketHandle; class BoundNetLog; // This ClientSocket is used to setup a SOCKSv5 handshake with a socks proxy. @@ -36,6 +37,10 @@ class SOCKS5ClientSocket : public ClientSocket { // Although SOCKS 5 supports 3 different modes of addressing, we will // always pass it a hostname. This means the DNS resolving is done // proxy side. + SOCKS5ClientSocket(ClientSocketHandle* transport_socket, + const HostResolver::RequestInfo& req_info); + + // Deprecated constructor (http://crbug.com/37810) that takes a ClientSocket. SOCKS5ClientSocket(ClientSocket* transport_socket, const HostResolver::RequestInfo& req_info); @@ -106,7 +111,7 @@ class SOCKS5ClientSocket : public ClientSocket { CompletionCallbackImpl<SOCKS5ClientSocket> io_callback_; // Stores the underlying socket. - scoped_ptr<ClientSocket> transport_; + scoped_ptr<ClientSocketHandle> transport_; State next_state_; diff --git a/net/socket/socks5_client_socket_unittest.cc b/net/socket/socks5_client_socket_unittest.cc index d15676c..d9d1012 100644 --- a/net/socket/socks5_client_socket_unittest.cc +++ b/net/socket/socks5_client_socket_unittest.cc @@ -89,18 +89,12 @@ SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket( HostResolver::RequestInfo(hostname, port)); } -const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 }; -const char kSOCKS5GreetResponse[] = { 0x05, 0x00 }; -const char kSOCKS5OkResponse[] = - { 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50 }; - - // Tests a complete SOCKS5 handshake and the disconnection. TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) { const std::string payload_write = "random data"; const std::string payload_read = "moar random data"; - const char kSOCKS5OkRequest[] = { + const char kOkRequest[] = { 0x05, // Version 0x01, // Command (CONNECT) 0x00, // Reserved. @@ -112,12 +106,12 @@ TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) { }; MockWrite data_writes[] = { - MockWrite(true, kSOCKS5GreetRequest, arraysize(kSOCKS5GreetRequest)), - MockWrite(true, kSOCKS5OkRequest, arraysize(kSOCKS5OkRequest)), + MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), + MockWrite(true, kOkRequest, arraysize(kOkRequest)), MockWrite(true, payload_write.data(), payload_write.size()) }; MockRead data_reads[] = { - MockRead(true, kSOCKS5GreetResponse, arraysize(kSOCKS5GreetResponse)), - MockRead(true, kSOCKS5OkResponse, arraysize(kSOCKS5OkResponse)), + MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), + MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength), MockRead(true, payload_read.data(), payload_read.size()) }; user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), @@ -176,12 +170,12 @@ TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) { for (int i = 0; i < 2; ++i) { MockWrite data_writes[] = { - MockWrite(false, kSOCKS5GreetRequest, arraysize(kSOCKS5GreetRequest)), + MockWrite(false, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), MockWrite(false, request.data(), request.size()) }; MockRead data_reads[] = { - MockRead(false, kSOCKS5GreetResponse, arraysize(kSOCKS5GreetResponse)), - MockRead(false, kSOCKS5OkResponse, arraysize(kSOCKS5OkResponse)) + MockRead(false, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), + MockRead(false, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), @@ -220,7 +214,7 @@ TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) { TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { const std::string hostname = "www.google.com"; - const char kSOCKS5OkRequest[] = { + const char kOkRequest[] = { 0x05, // Version 0x01, // Command (CONNECT) 0x00, // Reserved. @@ -238,10 +232,10 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { MockWrite data_writes[] = { MockWrite(true, arraysize(partial1)), MockWrite(true, partial2, arraysize(partial2)), - MockWrite(true, kSOCKS5OkRequest, arraysize(kSOCKS5OkRequest)) }; + MockWrite(true, kOkRequest, arraysize(kOkRequest)) }; MockRead data_reads[] = { - MockRead(true, kSOCKS5GreetResponse, arraysize(kSOCKS5GreetResponse)), - MockRead(true, kSOCKS5OkResponse, arraysize(kSOCKS5OkResponse)) }; + MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), + MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes), hostname, 80)); @@ -260,12 +254,12 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { const char partial1[] = { 0x05 }; const char partial2[] = { 0x00 }; MockWrite data_writes[] = { - MockWrite(true, kSOCKS5GreetRequest, arraysize(kSOCKS5GreetRequest)), - MockWrite(true, kSOCKS5OkRequest, arraysize(kSOCKS5OkRequest)) }; + MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), + MockWrite(true, kOkRequest, arraysize(kOkRequest)) }; MockRead data_reads[] = { MockRead(true, partial1, arraysize(partial1)), MockRead(true, partial2, arraysize(partial2)), - MockRead(true, kSOCKS5OkResponse, arraysize(kSOCKS5OkResponse)) }; + MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes), hostname, 80)); @@ -283,14 +277,14 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { { const int kSplitPoint = 3; // Break handshake write into two parts. MockWrite data_writes[] = { - MockWrite(true, kSOCKS5GreetRequest, arraysize(kSOCKS5GreetRequest)), - MockWrite(true, kSOCKS5OkRequest, kSplitPoint), - MockWrite(true, kSOCKS5OkRequest + kSplitPoint, - arraysize(kSOCKS5OkRequest) - kSplitPoint) + MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), + MockWrite(true, kOkRequest, kSplitPoint), + MockWrite(true, kOkRequest + kSplitPoint, + arraysize(kOkRequest) - kSplitPoint) }; MockRead data_reads[] = { - MockRead(true, kSOCKS5GreetResponse, arraysize(kSOCKS5GreetResponse)), - MockRead(true, kSOCKS5OkResponse, arraysize(kSOCKS5OkResponse)) }; + MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), + MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes), hostname, 80)); @@ -308,14 +302,14 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { { const int kSplitPoint = 6; // Break the handshake read into two parts. MockWrite data_writes[] = { - MockWrite(true, kSOCKS5GreetRequest, arraysize(kSOCKS5GreetRequest)), - MockWrite(true, kSOCKS5OkRequest, arraysize(kSOCKS5OkRequest)) + MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), + MockWrite(true, kOkRequest, arraysize(kOkRequest)) }; MockRead data_reads[] = { - MockRead(true, kSOCKS5GreetResponse, arraysize(kSOCKS5GreetResponse)), + MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), MockRead(true, kSOCKS5OkResponse, kSplitPoint), MockRead(true, kSOCKS5OkResponse + kSplitPoint, - arraysize(kSOCKS5OkResponse) - kSplitPoint) + kSOCKS5OkResponseLength - kSplitPoint) }; user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc index c92189c..e5049ff 100644 --- a/net/socket/socks_client_socket.cc +++ b/net/socket/socks_client_socket.cc @@ -11,6 +11,7 @@ #include "net/base/net_log.h" #include "net/base/net_util.h" #include "net/base/sys_addrinfo.h" +#include "net/socket/client_socket_handle.h" namespace net { @@ -59,7 +60,7 @@ struct SOCKS4ServerResponse { COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, socks4_server_response_struct_wrong_size); -SOCKSClientSocket::SOCKSClientSocket(ClientSocket* transport_socket, +SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket, const HostResolver::RequestInfo& req_info, HostResolver* host_resolver) : ALLOW_THIS_IN_INITIALIZER_LIST( @@ -75,6 +76,23 @@ SOCKSClientSocket::SOCKSClientSocket(ClientSocket* transport_socket, host_request_info_(req_info) { } +SOCKSClientSocket::SOCKSClientSocket(ClientSocket* transport_socket, + const HostResolver::RequestInfo& req_info, + HostResolver* host_resolver) + : ALLOW_THIS_IN_INITIALIZER_LIST( + io_callback_(this, &SOCKSClientSocket::OnIOComplete)), + transport_(new ClientSocketHandle()), + next_state_(STATE_NONE), + socks_version_(kSOCKS4Unresolved), + user_callback_(NULL), + completed_handshake_(false), + bytes_sent_(0), + bytes_received_(0), + host_resolver_(host_resolver), + host_request_info_(req_info) { + transport_->set_socket(transport_socket); +} + SOCKSClientSocket::~SOCKSClientSocket() { Disconnect(); } @@ -82,7 +100,8 @@ SOCKSClientSocket::~SOCKSClientSocket() { int SOCKSClientSocket::Connect(CompletionCallback* callback, const BoundNetLog& net_log) { DCHECK(transport_.get()); - DCHECK(transport_->IsConnected()); + DCHECK(transport_->socket()); + DCHECK(transport_->socket()->IsConnected()); DCHECK_EQ(STATE_NONE, next_state_); DCHECK(!user_callback_); @@ -108,7 +127,7 @@ int SOCKSClientSocket::Connect(CompletionCallback* callback, void SOCKSClientSocket::Disconnect() { completed_handshake_ = false; host_resolver_.Cancel(); - transport_->Disconnect(); + transport_->socket()->Disconnect(); // Reset other states to make sure they aren't mistakenly used later. // These are the states initialized by Connect(). @@ -118,11 +137,11 @@ void SOCKSClientSocket::Disconnect() { } bool SOCKSClientSocket::IsConnected() const { - return completed_handshake_ && transport_->IsConnected(); + return completed_handshake_ && transport_->socket()->IsConnected(); } bool SOCKSClientSocket::IsConnectedAndIdle() const { - return completed_handshake_ && transport_->IsConnectedAndIdle(); + return completed_handshake_ && transport_->socket()->IsConnectedAndIdle(); } // Read is called by the transport layer above to read. This can only be done @@ -133,7 +152,7 @@ int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, DCHECK_EQ(STATE_NONE, next_state_); DCHECK(!user_callback_); - return transport_->Read(buf, buf_len, callback); + return transport_->socket()->Read(buf, buf_len, callback); } // Write is called by the transport layer. This can only be done if the @@ -144,15 +163,15 @@ int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len, DCHECK_EQ(STATE_NONE, next_state_); DCHECK(!user_callback_); - return transport_->Write(buf, buf_len, callback); + return transport_->socket()->Write(buf, buf_len, callback); } bool SOCKSClientSocket::SetReceiveBufferSize(int32 size) { - return transport_->SetReceiveBufferSize(size); + return transport_->socket()->SetReceiveBufferSize(size); } bool SOCKSClientSocket::SetSendBufferSize(int32 size) { - return transport_->SetSendBufferSize(size); + return transport_->socket()->SetSendBufferSize(size); } void SOCKSClientSocket::DoCallback(int result) { @@ -306,7 +325,8 @@ int SOCKSClientSocket::DoHandshakeWrite() { handshake_buf_ = new IOBuffer(handshake_buf_len); memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], handshake_buf_len); - return transport_->Write(handshake_buf_, handshake_buf_len, &io_callback_); + return transport_->socket()->Write(handshake_buf_, handshake_buf_len, + &io_callback_); } int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { @@ -342,7 +362,8 @@ int SOCKSClientSocket::DoHandshakeRead() { int handshake_buf_len = kReadHeaderSize - bytes_received_; handshake_buf_ = new IOBuffer(handshake_buf_len); - return transport_->Read(handshake_buf_, handshake_buf_len, &io_callback_); + return transport_->socket()->Read(handshake_buf_, handshake_buf_len, + &io_callback_); } int SOCKSClientSocket::DoHandshakeReadComplete(int result) { @@ -399,7 +420,7 @@ int SOCKSClientSocket::DoHandshakeReadComplete(int result) { } int SOCKSClientSocket::GetPeerAddress(AddressList* address) const { - return transport_->GetPeerAddress(address); + return transport_->socket()->GetPeerAddress(address); } } // namespace net diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h index 943c6c0..7aabf5d 100644 --- a/net/socket/socks_client_socket.h +++ b/net/socket/socks_client_socket.h @@ -21,6 +21,7 @@ namespace net { +class ClientSocketHandle; class BoundNetLog; // The SOCKS client socket implementation @@ -31,6 +32,11 @@ class SOCKSClientSocket : public ClientSocket { // // |req_info| contains the hostname and port to which the socket above will // communicate to via the socks layer. For testing the referrer is optional. + SOCKSClientSocket(ClientSocketHandle* transport_socket, + const HostResolver::RequestInfo& req_info, + HostResolver* host_resolver); + + // Deprecated constructor (http://crbug.com/37810) that takes a ClientSocket. SOCKSClientSocket(ClientSocket* transport_socket, const HostResolver::RequestInfo& req_info, HostResolver* host_resolver); @@ -96,7 +102,7 @@ class SOCKSClientSocket : public ClientSocket { CompletionCallbackImpl<SOCKSClientSocket> io_callback_; // Stores the underlying socket. - scoped_ptr<ClientSocket> transport_; + scoped_ptr<ClientSocketHandle> transport_; State next_state_; SocksVersion socks_version_; diff --git a/net/socket/socks_client_socket_pool.cc b/net/socket/socks_client_socket_pool.cc new file mode 100644 index 0000000..246cc20 --- /dev/null +++ b/net/socket/socks_client_socket_pool.cc @@ -0,0 +1,214 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socks_client_socket_pool.h" + +#include "base/time.h" +#include "googleurl/src/gurl.h" +#include "net/base/net_errors.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_base.h" +#include "net/socket/socks5_client_socket.h" +#include "net/socket/socks_client_socket.h" + +namespace net { + +// SOCKSConnectJobs will time out after this many seconds. Note this is on +// top of the timeout for the transport socket. +static const int kSOCKSConnectJobTimeoutInSeconds = 30; + +SOCKSConnectJob::SOCKSConnectJob( + const std::string& group_name, + const SOCKSSocketParams& socks_params, + const base::TimeDelta& timeout_duration, + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + const scoped_refptr<HostResolver>& host_resolver, + Delegate* delegate, + const BoundNetLog& net_log) + : ConnectJob(group_name, timeout_duration, delegate, net_log), + socks_params_(socks_params), + tcp_pool_(tcp_pool), + resolver_(host_resolver), + ALLOW_THIS_IN_INITIALIZER_LIST( + callback_(this, &SOCKSConnectJob::OnIOComplete)) {} + +SOCKSConnectJob::~SOCKSConnectJob() { + // We don't worry about cancelling the tcp socket since the destructor in + // scoped_ptr<ClientSocketHandle> tcp_socket_handle_ will take care of it. +} + +LoadState SOCKSConnectJob::GetLoadState() const { + switch (next_state_) { + case kStateTCPConnect: + case kStateTCPConnectComplete: + return tcp_socket_handle_->GetLoadState(); + case kStateSOCKSConnect: + case kStateSOCKSConnectComplete: + return LOAD_STATE_CONNECTING; + default: + NOTREACHED(); + return LOAD_STATE_IDLE; + } +} + +int SOCKSConnectJob::ConnectInternal() { + next_state_ = kStateTCPConnect; + return DoLoop(OK); +} + +void SOCKSConnectJob::OnIOComplete(int result) { + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) + NotifyDelegateOfCompletion(rv); // Deletes |this| +} + +int SOCKSConnectJob::DoLoop(int result) { + DCHECK_NE(next_state_, kStateNone); + + int rv = result; + do { + State state = next_state_; + next_state_ = kStateNone; + switch (state) { + case kStateTCPConnect: + DCHECK_EQ(OK, rv); + rv = DoTCPConnect(); + break; + case kStateTCPConnectComplete: + rv = DoTCPConnectComplete(rv); + break; + case kStateSOCKSConnect: + DCHECK_EQ(OK, rv); + rv = DoSOCKSConnect(); + break; + case kStateSOCKSConnectComplete: + rv = DoSOCKSConnectComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_FAILED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != kStateNone); + + return rv; +} + +int SOCKSConnectJob::DoTCPConnect() { + next_state_ = kStateTCPConnectComplete; + tcp_socket_handle_.reset(new ClientSocketHandle()); + return tcp_socket_handle_->Init(group_name(), socks_params_.tcp_params(), + socks_params_.destination().priority(), + &callback_, tcp_pool_, net_log()); +} + +int SOCKSConnectJob::DoTCPConnectComplete(int result) { + if (result != OK) + return result; + + // Reset the timer to just the length of time allowed for SOCKS handshake + // so that a fast TCP connection plus a slow SOCKS failure doesn't take + // longer to timeout than it should. + ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds)); + next_state_ = kStateSOCKSConnect; + return result; +} + +int SOCKSConnectJob::DoSOCKSConnect() { + next_state_ = kStateSOCKSConnectComplete; + + // Add a SOCKS connection on top of the tcp socket. + if (socks_params_.is_socks_v5()) { + socket_.reset(new SOCKS5ClientSocket(tcp_socket_handle_.release(), + socks_params_.destination())); + } else { + socket_.reset(new SOCKSClientSocket(tcp_socket_handle_.release(), + socks_params_.destination(), + resolver_)); + } + return socket_->Connect(&callback_, net_log()); +} + +int SOCKSConnectJob::DoSOCKSConnectComplete(int result) { + if (result != OK) { + socket_->Disconnect(); + return result; + } + + set_socket(socket_.release()); + return result; +} + +ConnectJob* SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( + const std::string& group_name, + const PoolBase::Request& request, + ConnectJob::Delegate* delegate, + const BoundNetLog& net_log) const { + return new SOCKSConnectJob(group_name, request.params(), ConnectionTimeout(), + tcp_pool_, host_resolver_, delegate, net_log); +} + +base::TimeDelta +SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const { + return tcp_pool_->ConnectionTimeout() + + base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds); +} + +SOCKSClientSocketPool::SOCKSClientSocketPool( + int max_sockets, + int max_sockets_per_group, + const std::string& name, + const scoped_refptr<HostResolver>& host_resolver, + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + NetworkChangeNotifier* network_change_notifier) + : base_(max_sockets, max_sockets_per_group, name, + base::TimeDelta::FromSeconds(kUnusedIdleSocketTimeout), + base::TimeDelta::FromSeconds(kUsedIdleSocketTimeout), + new SOCKSConnectJobFactory(tcp_pool, host_resolver), + network_change_notifier) {} + +SOCKSClientSocketPool::~SOCKSClientSocketPool() {} + +int SOCKSClientSocketPool::RequestSocket( + const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + CompletionCallback* callback, + const BoundNetLog& net_log) { + const SOCKSSocketParams* casted_socket_params = + static_cast<const SOCKSSocketParams*>(socket_params); + + return base_.RequestSocket(group_name, *casted_socket_params, priority, + handle, callback, net_log); +} + +void SOCKSClientSocketPool::CancelRequest( + const std::string& group_name, + const ClientSocketHandle* handle) { + base_.CancelRequest(group_name, handle); +} + +void SOCKSClientSocketPool::ReleaseSocket( + const std::string& group_name, + ClientSocket* socket) { + base_.ReleaseSocket(group_name, socket); +} + +void SOCKSClientSocketPool::CloseIdleSockets() { + base_.CloseIdleSockets(); +} + +int SOCKSClientSocketPool::IdleSocketCountInGroup( + const std::string& group_name) const { + return base_.IdleSocketCountInGroup(group_name); +} + +LoadState SOCKSClientSocketPool::GetLoadState( + const std::string& group_name, const ClientSocketHandle* handle) const { + return base_.GetLoadState(group_name, handle); +} + +} // namespace net diff --git a/net/socket/socks_client_socket_pool.h b/net/socket/socks_client_socket_pool.h new file mode 100644 index 0000000..2c30600 --- /dev/null +++ b/net/socket/socks_client_socket_pool.h @@ -0,0 +1,185 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SOCKS_CLIENT_SOCKET_POOL_H_ +#define NET_SOCKET_SOCKS_CLIENT_SOCKET_POOL_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/scoped_ptr.h" +#include "base/time.h" +#include "net/base/host_resolver.h" +#include "net/proxy/proxy_server.h" +#include "net/socket/client_socket_pool_base.h" +#include "net/socket/client_socket_pool.h" +#include "net/socket/tcp_client_socket_pool.h" + +namespace net { + +class ClientSocketFactory; +class ConnectJobFactory; + +class SOCKSSocketParams { + public: + SOCKSSocketParams(const TCPSocketParams& proxy_server, bool socks_v5, + const std::string& destination_host, int destination_port, + RequestPriority priority, const GURL& referrer) + : tcp_params_(proxy_server), + destination_(destination_host, destination_port), + socks_v5_(socks_v5) { + // The referrer is used by the DNS prefetch system to correlate resolutions + // with the page that triggered them. It doesn't impact the actual addresses + // that we resolve to. + destination_.set_referrer(referrer); + destination_.set_priority(priority); + } + + const TCPSocketParams& tcp_params() const { return tcp_params_; } + const HostResolver::RequestInfo& destination() const { return destination_; } + bool is_socks_v5() const { return socks_v5_; }; + + private: + // The tcp connection must point toward the proxy server. + const TCPSocketParams tcp_params_; + // This is the HTTP destination. + HostResolver::RequestInfo destination_; + const bool socks_v5_; +}; + +// SOCKSConnectJob handles the handshake to a socks server after setting up +// an underlying transport socket. +class SOCKSConnectJob : public ConnectJob { + public: + SOCKSConnectJob(const std::string& group_name, + const SOCKSSocketParams& params, + const base::TimeDelta& timeout_duration, + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + const scoped_refptr<HostResolver> &host_resolver, + Delegate* delegate, + const BoundNetLog& net_log); + virtual ~SOCKSConnectJob(); + + // ConnectJob methods. + virtual LoadState GetLoadState() const; + + private: + enum State { + kStateTCPConnect, + kStateTCPConnectComplete, + kStateSOCKSConnect, + kStateSOCKSConnectComplete, + kStateNone, + }; + + // Begins the tcp connection and the SOCKS handshake. Returns OK on success + // and ERR_IO_PENDING if it cannot immediately service the request. + // Otherwise, it returns a net error code. + virtual int ConnectInternal(); + + void OnIOComplete(int result); + + // Runs the state transition loop. + int DoLoop(int result); + + int DoTCPConnect(); + int DoTCPConnectComplete(int result); + int DoSOCKSConnect(); + int DoSOCKSConnectComplete(int result); + + SOCKSSocketParams socks_params_; + const scoped_refptr<TCPClientSocketPool> tcp_pool_; + const scoped_refptr<HostResolver> resolver_; + + State next_state_; + CompletionCallbackImpl<SOCKSConnectJob> callback_; + scoped_ptr<ClientSocketHandle> tcp_socket_handle_; + scoped_ptr<ClientSocket> socket_; + + DISALLOW_COPY_AND_ASSIGN(SOCKSConnectJob); +}; + +class SOCKSClientSocketPool : public ClientSocketPool { + public: + SOCKSClientSocketPool( + int max_sockets, + int max_sockets_per_group, + const std::string& name, + const scoped_refptr<HostResolver>& host_resolver, + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + NetworkChangeNotifier* network_change_notifier); + + // ClientSocketPool methods: + virtual int RequestSocket(const std::string& group_name, + const void* connect_params, + RequestPriority priority, + ClientSocketHandle* handle, + CompletionCallback* callback, + const BoundNetLog& net_log); + + virtual void CancelRequest(const std::string& group_name, + const ClientSocketHandle* handle); + + virtual void ReleaseSocket(const std::string& group_name, + ClientSocket* socket); + + virtual void CloseIdleSockets(); + + virtual int IdleSocketCount() const { + return base_.idle_socket_count(); + } + + virtual int IdleSocketCountInGroup(const std::string& group_name) const; + + virtual LoadState GetLoadState(const std::string& group_name, + const ClientSocketHandle* handle) const; + + virtual base::TimeDelta ConnectionTimeout() const { + return base_.ConnectionTimeout(); + } + + virtual const std::string& name() const { return base_.name(); }; + + protected: + virtual ~SOCKSClientSocketPool(); + + private: + typedef ClientSocketPoolBase<SOCKSSocketParams> PoolBase; + + class SOCKSConnectJobFactory : public PoolBase::ConnectJobFactory { + public: + SOCKSConnectJobFactory(const scoped_refptr<TCPClientSocketPool>& tcp_pool, + HostResolver* host_resolver) + : tcp_pool_(tcp_pool), + host_resolver_(host_resolver) {} + + virtual ~SOCKSConnectJobFactory() {} + + // ClientSocketPoolBase::ConnectJobFactory methods. + virtual ConnectJob* NewConnectJob( + const std::string& group_name, + const PoolBase::Request& request, + ConnectJob::Delegate* delegate, + const BoundNetLog& net_log) const; + + virtual base::TimeDelta ConnectionTimeout() const; + + private: + const scoped_refptr<TCPClientSocketPool> tcp_pool_; + const scoped_refptr<HostResolver> host_resolver_; + + DISALLOW_COPY_AND_ASSIGN(SOCKSConnectJobFactory); + }; + + PoolBase base_; + + DISALLOW_COPY_AND_ASSIGN(SOCKSClientSocketPool); +}; + +REGISTER_SOCKET_PARAMS_FOR_POOL(SOCKSClientSocketPool, SOCKSSocketParams) + +} // namespace net + +#endif // NET_SOCKET_SOCKS_CLIENT_SOCKET_POOL_H_ diff --git a/net/socket/socks_client_socket_pool_unittest.cc b/net/socket/socks_client_socket_pool_unittest.cc new file mode 100644 index 0000000..9f8ff7f --- /dev/null +++ b/net/socket/socks_client_socket_pool_unittest.cc @@ -0,0 +1,368 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socks_client_socket_pool.h" + +#include <vector> + +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/time.h" +#include "net/base/mock_host_resolver.h" +#include "net/base/mock_network_change_notifier.h" +#include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/socket_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +const int kMaxSockets = 32; +const int kMaxSocketsPerGroup = 6; + +class MockTCPClientSocketPool : public TCPClientSocketPool { + public: + class MockConnectJob { + public: + MockConnectJob(ClientSocket* socket, ClientSocketHandle* handle, + CompletionCallback* callback) + : socket_(socket), + handle_(handle), + user_callback_(callback), + ALLOW_THIS_IN_INITIALIZER_LIST( + connect_callback_(this, &MockConnectJob::OnConnect)) {} + + int Connect(const BoundNetLog& net_log) { + int rv = socket_->Connect(&connect_callback_, net_log); + if (rv == OK) { + user_callback_ = NULL; + OnConnect(OK); + } + return rv; + } + + bool CancelHandle(const ClientSocketHandle* handle) { + if (handle != handle_) + return false; + socket_.reset(NULL); + handle_ = NULL; + user_callback_ = NULL; + return true; + } + + private: + void OnConnect(int rv) { + if (!socket_.get()) + return; + if (rv == OK) + handle_->set_socket(socket_.get()); + + socket_.release(); + handle_ = NULL; + + if (user_callback_) { + CompletionCallback* callback = user_callback_; + user_callback_ = NULL; + callback->Run(rv); + } + } + + scoped_ptr<ClientSocket> socket_; + ClientSocketHandle* handle_; + CompletionCallback* user_callback_; + CompletionCallbackImpl<MockConnectJob> connect_callback_; + + DISALLOW_COPY_AND_ASSIGN(MockConnectJob); + }; + + MockTCPClientSocketPool(int max_sockets, int max_sockets_per_group, + const std::string& name, + ClientSocketFactory* socket_factory, + NetworkChangeNotifier* network_change_notifier) + : TCPClientSocketPool(max_sockets, max_sockets_per_group, name, + NULL, NULL, network_change_notifier), + client_socket_factory_(socket_factory), + release_count_(0), + cancel_count_(0) {} + + int release_count() { return release_count_; }; + int cancel_count() { return cancel_count_; }; + + // TCPClientSocketPool methods. + virtual int RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + CompletionCallback* callback, + const BoundNetLog& net_log) { + ClientSocket* socket = client_socket_factory_->CreateTCPClientSocket( + AddressList()); + MockConnectJob* job = new MockConnectJob(socket, handle, callback); + job_list_.push_back(job); + return job->Connect(net_log); + } + + virtual void CancelRequest(const std::string& group_name, + const ClientSocketHandle* handle) { + std::vector<MockConnectJob*>::iterator i; + for (i = job_list_.begin(); i != job_list_.end(); ++i) { + if ((*i)->CancelHandle(handle)) { + cancel_count_++; + break; + } + } + } + + virtual void ReleaseSocket(const std::string& group_name, + ClientSocket* socket) { + release_count_++; + delete socket; + } + + protected: + virtual ~MockTCPClientSocketPool() {} + + private: + ClientSocketFactory* client_socket_factory_; + int release_count_; + int cancel_count_; + std::vector<MockConnectJob*> job_list_; + + DISALLOW_COPY_AND_ASSIGN(MockTCPClientSocketPool); +}; + +class SOCKSClientSocketPoolTest : public ClientSocketPoolTest { + protected: + class SOCKS5MockData { + public: + explicit SOCKS5MockData(bool async) { + writes_.reset(new MockWrite[3]); + writes_[0] = MockWrite(async, kSOCKS5GreetRequest, + kSOCKS5GreetRequestLength); + writes_[1] = MockWrite(async, kSOCKS5OkRequest, kSOCKS5OkRequestLength); + writes_[2] = MockWrite(async, 0); + + reads_.reset(new MockRead[3]); + reads_[0] = MockRead(async, kSOCKS5GreetResponse, + kSOCKS5GreetResponseLength); + reads_[1] = MockRead(async, kSOCKS5OkResponse, kSOCKS5OkResponseLength); + reads_[2] = MockRead(async, 0); + + data_.reset(new StaticSocketDataProvider(reads_.get(), 3, + writes_.get(), 3)); + } + + SocketDataProvider* data_provider() { return data_.get(); } + + private: + scoped_ptr<StaticSocketDataProvider> data_; + scoped_array<MockWrite> writes_; + scoped_array<MockWrite> reads_; + }; + + SOCKSClientSocketPoolTest() + : ignored_tcp_socket_params_("proxy", 80, MEDIUM, GURL(), false), + tcp_socket_pool_(new MockTCPClientSocketPool( + kMaxSockets, kMaxSocketsPerGroup, "MockTCP", + &tcp_client_socket_factory_, &tcp_notifier_)), + ignored_socket_params_(ignored_tcp_socket_params_, true, "host", 80, + MEDIUM, GURL()), + pool_(new SOCKSClientSocketPool(kMaxSockets, kMaxSocketsPerGroup, + "SOCKSUnitTest", NULL, tcp_socket_pool_.get(), &socks_notifier_)) { + } + + int StartRequest(const std::string& group_name, RequestPriority priority) { + return StartRequestUsingPool( + pool_, group_name, priority, ignored_socket_params_); + } + + TCPSocketParams ignored_tcp_socket_params_; + MockClientSocketFactory tcp_client_socket_factory_; + MockNetworkChangeNotifier tcp_notifier_; + scoped_refptr<MockTCPClientSocketPool> tcp_socket_pool_; + + SOCKSSocketParams ignored_socket_params_; + MockNetworkChangeNotifier socks_notifier_; + scoped_refptr<SOCKSClientSocketPool> pool_; +}; + +TEST_F(SOCKSClientSocketPoolTest, Simple) { + SOCKS5MockData data(false); + data.data_provider()->set_connect_data(MockConnect(false, 0)); + tcp_client_socket_factory_.AddSocketDataProvider(data.data_provider()); + + ClientSocketHandle handle; + int rv = handle.Init("a", ignored_socket_params_, LOW, NULL, pool_, NULL); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(SOCKSClientSocketPoolTest, Async) { + SOCKS5MockData data(true); + tcp_client_socket_factory_.AddSocketDataProvider(data.data_provider()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("a", ignored_socket_params_, LOW, &callback, pool_, + NULL); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(SOCKSClientSocketPoolTest, TCPConnectError) { + SocketDataProvider* socket_data = new StaticSocketDataProvider(); + socket_data->set_connect_data(MockConnect(false, ERR_CONNECTION_REFUSED)); + tcp_client_socket_factory_.AddSocketDataProvider(socket_data); + + ClientSocketHandle handle; + int rv = handle.Init("a", ignored_socket_params_, LOW, NULL, pool_, NULL); + EXPECT_EQ(ERR_CONNECTION_REFUSED, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); +} + +TEST_F(SOCKSClientSocketPoolTest, AsyncTCPConnectError) { + SocketDataProvider* socket_data = new StaticSocketDataProvider(); + socket_data->set_connect_data(MockConnect(true, ERR_CONNECTION_REFUSED)); + tcp_client_socket_factory_.AddSocketDataProvider(socket_data); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("a", ignored_socket_params_, LOW, &callback, pool_, + NULL); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_CONNECTION_REFUSED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); +} + +TEST_F(SOCKSClientSocketPoolTest, SOCKSConnectError) { + MockRead failed_read[] = { + MockRead(false, 0), + }; + SocketDataProvider* socket_data = + new StaticSocketDataProvider(failed_read, arraysize(failed_read), + NULL, 0); + socket_data->set_connect_data(MockConnect(false, 0)); + tcp_client_socket_factory_.AddSocketDataProvider(socket_data); + + ClientSocketHandle handle; + EXPECT_EQ(0, tcp_socket_pool_->release_count()); + int rv = handle.Init("a", ignored_socket_params_, LOW, NULL, pool_, NULL); + EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_EQ(1, tcp_socket_pool_->release_count()); +} + +TEST_F(SOCKSClientSocketPoolTest, AsyncSOCKSConnectError) { + MockRead failed_read[] = { + MockRead(true, 0), + }; + SocketDataProvider* socket_data = + new StaticSocketDataProvider(failed_read, arraysize(failed_read), + NULL, 0); + socket_data->set_connect_data(MockConnect(false, 0)); + tcp_client_socket_factory_.AddSocketDataProvider(socket_data); + + TestCompletionCallback callback; + ClientSocketHandle handle; + EXPECT_EQ(0, tcp_socket_pool_->release_count()); + int rv = handle.Init("a", ignored_socket_params_, LOW, &callback, pool_, + NULL); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_EQ(1, tcp_socket_pool_->release_count()); +} + +TEST_F(SOCKSClientSocketPoolTest, CancelDuringTCPConnect) { + SOCKS5MockData data(false); + tcp_client_socket_factory_.AddSocketDataProvider(data.data_provider()); + // We need two connections because the pool base lets one cancelled + // connect job proceed for potential future use. + SOCKS5MockData data2(false); + tcp_client_socket_factory_.AddSocketDataProvider(data2.data_provider()); + + EXPECT_EQ(0, tcp_socket_pool_->cancel_count()); + int rv = StartRequest("a", LOW); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = StartRequest("a", LOW); + EXPECT_EQ(ERR_IO_PENDING, rv); + + pool_->CancelRequest("a", requests_[0]->handle()); + pool_->CancelRequest("a", requests_[1]->handle()); + EXPECT_EQ(1, tcp_socket_pool_->cancel_count()); + + // Now wait for the TCP sockets to connect. + MessageLoop::current()->RunAllPending(); + + EXPECT_EQ(kRequestNotFound, GetOrderOfRequest(1)); + EXPECT_EQ(kRequestNotFound, GetOrderOfRequest(2)); + EXPECT_EQ(1, tcp_socket_pool_->cancel_count()); + EXPECT_EQ(1, pool_->IdleSocketCount()); + + requests_[0]->handle()->Reset(); + requests_[1]->handle()->Reset(); +} + +TEST_F(SOCKSClientSocketPoolTest, CancelDuringSOCKSConnect) { + SOCKS5MockData data(true); + data.data_provider()->set_connect_data(MockConnect(false, 0)); + tcp_client_socket_factory_.AddSocketDataProvider(data.data_provider()); + // We need two connections because the pool base lets one cancelled + // connect job proceed for potential future use. + SOCKS5MockData data2(true); + data2.data_provider()->set_connect_data(MockConnect(false, 0)); + tcp_client_socket_factory_.AddSocketDataProvider(data2.data_provider()); + + EXPECT_EQ(0, tcp_socket_pool_->cancel_count()); + EXPECT_EQ(0, tcp_socket_pool_->release_count()); + int rv = StartRequest("a", LOW); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = StartRequest("a", LOW); + EXPECT_EQ(ERR_IO_PENDING, rv); + + pool_->CancelRequest("a", requests_[0]->handle()); + pool_->CancelRequest("a", requests_[1]->handle()); + EXPECT_EQ(0, tcp_socket_pool_->cancel_count()); + EXPECT_EQ(1, tcp_socket_pool_->release_count()); + + // Now wait for the async data to reach the SOCKS connect jobs. + MessageLoop::current()->RunAllPending(); + + EXPECT_EQ(kRequestNotFound, GetOrderOfRequest(1)); + EXPECT_EQ(kRequestNotFound, GetOrderOfRequest(2)); + EXPECT_EQ(0, tcp_socket_pool_->cancel_count()); + EXPECT_EQ(1, tcp_socket_pool_->release_count()); + EXPECT_EQ(1, pool_->IdleSocketCount()); + + requests_[0]->handle()->Reset(); + requests_[1]->handle()->Reset(); +} + +// It would be nice to also test the timeouts in SOCKSClientSocketPool. + +} // namespace + +} // namespace net diff --git a/net/socket/tcp_client_socket_pool.cc b/net/socket/tcp_client_socket_pool.cc index 2c5d63f..a833e01 100644 --- a/net/socket/tcp_client_socket_pool.cc +++ b/net/socket/tcp_client_socket_pool.cc @@ -161,19 +161,24 @@ ConnectJob* TCPClientSocketPool::TCPConnectJobFactory::NewConnectJob( const PoolBase::Request& request, ConnectJob::Delegate* delegate, const BoundNetLog& net_log) const { - return new TCPConnectJob( - group_name, request.params(), - base::TimeDelta::FromSeconds(kTCPConnectJobTimeoutInSeconds), - client_socket_factory_, host_resolver_, delegate, net_log); + return new TCPConnectJob(group_name, request.params(), ConnectionTimeout(), + client_socket_factory_, host_resolver_, delegate, + net_log); +} + +base::TimeDelta + TCPClientSocketPool::TCPConnectJobFactory::ConnectionTimeout() const { + return base::TimeDelta::FromSeconds(kTCPConnectJobTimeoutInSeconds); } TCPClientSocketPool::TCPClientSocketPool( int max_sockets, int max_sockets_per_group, + const std::string& name, HostResolver* host_resolver, ClientSocketFactory* client_socket_factory, NetworkChangeNotifier* network_change_notifier) - : base_(max_sockets, max_sockets_per_group, + : base_(max_sockets, max_sockets_per_group, name, base::TimeDelta::FromSeconds(kUnusedIdleSocketTimeout), base::TimeDelta::FromSeconds(kUsedIdleSocketTimeout), new TCPConnectJobFactory(client_socket_factory, host_resolver), diff --git a/net/socket/tcp_client_socket_pool.h b/net/socket/tcp_client_socket_pool.h index 50b1237..76950c3 100644 --- a/net/socket/tcp_client_socket_pool.h +++ b/net/socket/tcp_client_socket_pool.h @@ -101,6 +101,7 @@ class TCPClientSocketPool : public ClientSocketPool { TCPClientSocketPool( int max_sockets, int max_sockets_per_group, + const std::string& name, HostResolver* host_resolver, ClientSocketFactory* client_socket_factory, NetworkChangeNotifier* network_change_notifier); @@ -131,6 +132,12 @@ class TCPClientSocketPool : public ClientSocketPool { virtual LoadState GetLoadState(const std::string& group_name, const ClientSocketHandle* handle) const; + virtual base::TimeDelta ConnectionTimeout() const { + return base_.ConnectionTimeout(); + } + + virtual const std::string& name() const { return base_.name(); } + protected: virtual ~TCPClientSocketPool(); @@ -155,6 +162,8 @@ class TCPClientSocketPool : public ClientSocketPool { ConnectJob::Delegate* delegate, const BoundNetLog& net_log) const; + virtual base::TimeDelta ConnectionTimeout() const; + private: ClientSocketFactory* const client_socket_factory_; const scoped_refptr<HostResolver> host_resolver_; diff --git a/net/socket/tcp_client_socket_pool_unittest.cc b/net/socket/tcp_client_socket_pool_unittest.cc index a275a63..4ef6b98 100644 --- a/net/socket/tcp_client_socket_pool_unittest.cc +++ b/net/socket/tcp_client_socket_pool_unittest.cc @@ -244,6 +244,7 @@ class TCPClientSocketPoolTest : public ClientSocketPoolTest { host_resolver_(new MockHostResolver), pool_(new TCPClientSocketPool(kMaxSockets, kMaxSocketsPerGroup, + "TCPUnitTest", host_resolver_, &client_socket_factory_, ¬ifier_)) { @@ -251,7 +252,7 @@ class TCPClientSocketPoolTest : public ClientSocketPoolTest { int StartRequest(const std::string& group_name, RequestPriority priority) { return StartRequestUsingPool( - pool_.get(), group_name, priority, ignored_socket_params_); + pool_, group_name, priority, ignored_socket_params_); } TCPSocketParams ignored_socket_params_; @@ -265,7 +266,7 @@ TEST_F(TCPClientSocketPoolTest, Basic) { TestCompletionCallback callback; ClientSocketHandle handle; TCPSocketParams dest("www.google.com", 80, LOW, GURL(), false); - int rv = handle.Init("a", dest, LOW, &callback, pool_.get(), NULL); + int rv = handle.Init("a", dest, LOW, &callback, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -283,8 +284,7 @@ TEST_F(TCPClientSocketPoolTest, InitHostResolutionFailure) { TCPSocketParams dest("unresolvable.host.name", 80, kDefaultPriority, GURL(), false); EXPECT_EQ(ERR_IO_PENDING, - req.handle()->Init( - "a", dest, kDefaultPriority, &req, pool_.get(), NULL)); + req.handle()->Init("a", dest, kDefaultPriority, &req, pool_, NULL)); EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req.WaitForResult()); } @@ -294,15 +294,13 @@ TEST_F(TCPClientSocketPoolTest, InitConnectionFailure) { TestSocketRequest req(&request_order_, &completion_count_); TCPSocketParams dest("a", 80, kDefaultPriority, GURL(), false); EXPECT_EQ(ERR_IO_PENDING, - req.handle()->Init( - "a", dest, kDefaultPriority, &req, pool_.get(), NULL)); + req.handle()->Init("a", dest, kDefaultPriority, &req, pool_, NULL)); EXPECT_EQ(ERR_CONNECTION_FAILED, req.WaitForResult()); // Make the host resolutions complete synchronously this time. host_resolver_->set_synchronous_mode(true); EXPECT_EQ(ERR_CONNECTION_FAILED, - req.handle()->Init( - "a", dest, kDefaultPriority, &req, pool_.get(), NULL)); + req.handle()->Init("a", dest, kDefaultPriority, &req, pool_, NULL)); } TEST_F(TCPClientSocketPoolTest, PendingRequests) { @@ -408,8 +406,7 @@ TEST_F(TCPClientSocketPoolTest, CancelRequestClearGroup) { TestSocketRequest req(&request_order_, &completion_count_); TCPSocketParams dest("www.google.com", 80, kDefaultPriority, GURL(), false); EXPECT_EQ(ERR_IO_PENDING, - req.handle()->Init( - "a", dest, kDefaultPriority, &req, pool_.get(), NULL)); + req.handle()->Init("a", dest, kDefaultPriority, &req, pool_, NULL)); req.handle()->Reset(); // There is a race condition here. If the worker pool doesn't post the task @@ -427,11 +424,9 @@ TEST_F(TCPClientSocketPoolTest, TwoRequestsCancelOne) { TCPSocketParams dest("www.google.com", 80, kDefaultPriority, GURL(), false); EXPECT_EQ(ERR_IO_PENDING, - req.handle()->Init( - "a", dest, kDefaultPriority, &req, pool_.get(), NULL)); + req.handle()->Init("a", dest, kDefaultPriority, &req, pool_, NULL)); EXPECT_EQ(ERR_IO_PENDING, - req2.handle()->Init( - "a", dest, kDefaultPriority, &req2, pool_.get(), NULL)); + req2.handle()->Init("a", dest, kDefaultPriority, &req2, pool_, NULL)); req.handle()->Reset(); @@ -448,15 +443,13 @@ TEST_F(TCPClientSocketPoolTest, ConnectCancelConnect) { TCPSocketParams dest("www.google.com", 80, kDefaultPriority, GURL(), false); EXPECT_EQ(ERR_IO_PENDING, - handle.Init( - "a", dest, kDefaultPriority, &callback, pool_.get(), NULL)); + handle.Init("a", dest, kDefaultPriority, &callback, pool_, NULL)); handle.Reset(); TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, - handle.Init( - "a", dest, kDefaultPriority, &callback2, pool_.get(), NULL)); + handle.Init("a", dest, kDefaultPriority, &callback2, pool_, NULL)); host_resolver_->set_synchronous_mode(true); // At this point, handle has two ConnectingSockets out for it. Due to the @@ -555,7 +548,7 @@ class RequestSocketCallback : public CallbackRunner< Tuple1<int> > { } within_callback_ = true; TCPSocketParams dest("www.google.com", 80, LOWEST, GURL(), false); - int rv = handle_->Init("a", dest, LOWEST, this, pool_.get(), NULL); + int rv = handle_->Init("a", dest, LOWEST, this, pool_, NULL); EXPECT_EQ(OK, rv); } } @@ -575,7 +568,7 @@ TEST_F(TCPClientSocketPoolTest, RequestTwice) { ClientSocketHandle handle; RequestSocketCallback callback(&handle, pool_.get()); TCPSocketParams dest("www.google.com", 80, LOWEST, GURL(), false); - int rv = handle.Init("a", dest, LOWEST, &callback, pool_.get(), NULL); + int rv = handle.Init("a", dest, LOWEST, &callback, pool_, NULL); ASSERT_EQ(ERR_IO_PENDING, rv); // The callback is going to request "www.google.com". We want it to complete @@ -638,7 +631,7 @@ TEST_F(TCPClientSocketPoolTest, ResetIdleSocketsOnIPAddressChange) { TestCompletionCallback callback; ClientSocketHandle handle; TCPSocketParams dest("www.google.com", 80, LOW, GURL(), false); - int rv = handle.Init("a", dest, LOW, &callback, pool_.get(), NULL); + int rv = handle.Init("a", dest, LOW, &callback, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -692,7 +685,7 @@ TEST_F(TCPClientSocketPoolTest, BackupSocketConnect) { TestCompletionCallback callback; ClientSocketHandle handle; TCPSocketParams dest("www.google.com", 80, LOW, GURL(), false); - int rv = handle.Init("b", dest, LOW, &callback, pool_.get(), NULL); + int rv = handle.Init("b", dest, LOW, &callback, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -730,7 +723,7 @@ TEST_F(TCPClientSocketPoolTest, BackupSocketCancel) { TestCompletionCallback callback; ClientSocketHandle handle; TCPSocketParams dest("www.google.com", 80, LOW, GURL(), false); - int rv = handle.Init("c", dest, LOW, &callback, pool_.get(), NULL); + int rv = handle.Init("c", dest, LOW, &callback, pool_, NULL); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); |