diff options
21 files changed, 1105 insertions, 270 deletions
diff --git a/net/http/http_auth_controller.cc b/net/http/http_auth_controller.cc index cb899cb..22b4f20 100644 --- a/net/http/http_auth_controller.cc +++ b/net/http/http_auth_controller.cc @@ -62,6 +62,8 @@ HttpAuthController::HttpAuthController( net_log_(net_log) { } +HttpAuthController::~HttpAuthController() {} + int HttpAuthController::MaybeGenerateAuthToken(const HttpRequestInfo* request, CompletionCallback* callback) { bool needs_auth = HaveAuth() || SelectPreemptiveAuth(); diff --git a/net/http/http_auth_controller.h b/net/http/http_auth_controller.h index 016f88e..655e46d 100644 --- a/net/http/http_auth_controller.h +++ b/net/http/http_auth_controller.h @@ -23,7 +23,7 @@ class HttpNetworkSession; class HttpRequestHeaders; struct HttpRequestInfo; -class HttpAuthController { +class HttpAuthController : public base::RefCounted<HttpAuthController> { public: // The arguments are self explanatory except possibly for |auth_url|, which // should be both the auth target and auth path in a single url argument. @@ -35,32 +35,34 @@ class HttpAuthController { // value is a net error code. |OK| will be returned both in the case that // a token is correctly generated synchronously, as well as when no tokens // were necessary. - int MaybeGenerateAuthToken(const HttpRequestInfo* request, - CompletionCallback* callback); + virtual int MaybeGenerateAuthToken(const HttpRequestInfo* request, + CompletionCallback* callback); // Adds either the proxy auth header, or the origin server auth header, // as specified by |target_|. - void AddAuthorizationHeader(HttpRequestHeaders* authorization_headers); + virtual void AddAuthorizationHeader( + HttpRequestHeaders* authorization_headers); // Checks for and handles HTTP status code 401 or 407. // |HandleAuthChallenge()| returns OK on success, or a network error code // otherwise. It may also populate |auth_info_|. - int HandleAuthChallenge(scoped_refptr<HttpResponseHeaders> headers, - bool do_not_send_server_auth, - bool establishing_tunnel); + virtual int HandleAuthChallenge(scoped_refptr<HttpResponseHeaders> headers, + bool do_not_send_server_auth, + bool establishing_tunnel); // Store the supplied credentials and prepare to restart the auth. - void ResetAuth(const std::wstring& username, const std::wstring& password); + virtual void ResetAuth(const std::wstring& username, + const std::wstring& password); - bool HaveAuthHandler() const { + virtual bool HaveAuthHandler() const { return handler_.get() != NULL; } - bool HaveAuth() const { + virtual bool HaveAuth() const { return handler_.get() && !identity_.invalid; } - scoped_refptr<AuthChallengeInfo> auth_info() { + virtual scoped_refptr<AuthChallengeInfo> auth_info() { return auth_info_; } @@ -68,6 +70,10 @@ class HttpAuthController { net_log_ = net_log; } + protected: // So that we can mock this object. + friend class base::RefCounted<HttpAuthController>; + virtual ~HttpAuthController(); + private: // Searches the auth cache for an entry that encompasses the request's path. // If such an entry is found, updates |identity_| and |handler_| with the diff --git a/net/http/http_network_session.cc b/net/http/http_network_session.cc index ea5a5cf..e49674b 100644 --- a/net/http/http_network_session.cc +++ b/net/http/http_network_session.cc @@ -47,6 +47,8 @@ HttpNetworkSession::HttpNetworkSession( // TODO(vandebo) when we've completely converted to pools, the base TCP // pool name should get changed to TCP instead of Transport. : tcp_pool_histograms_(new ClientSocketPoolHistograms("Transport")), + tcp_for_http_proxy_pool_histograms_( + new ClientSocketPoolHistograms("TCPforHTTPProxy")), http_proxy_pool_histograms_(new ClientSocketPoolHistograms("HTTPProxy")), tcp_for_socks_pool_histograms_( new ClientSocketPoolHistograms("TCPforSOCKS")), @@ -69,7 +71,7 @@ HttpNetworkSession::HttpNetworkSession( HttpNetworkSession::~HttpNetworkSession() { } -const scoped_refptr<TCPClientSocketPool>& +const scoped_refptr<HttpProxyClientSocketPool>& HttpNetworkSession::GetSocketPoolForHTTPProxy(const HostPortPair& http_proxy) { HTTPProxySocketPoolMap::const_iterator it = http_proxy_socket_pool_.find(http_proxy); @@ -77,10 +79,17 @@ HttpNetworkSession::GetSocketPoolForHTTPProxy(const HostPortPair& http_proxy) { return it->second; std::pair<HTTPProxySocketPoolMap::iterator, bool> ret = - http_proxy_socket_pool_.insert(std::make_pair(http_proxy, - new TCPClientSocketPool(g_max_sockets_per_proxy_server, - g_max_sockets_per_group, http_proxy_pool_histograms_, - host_resolver_, socket_factory_, net_log_))); + http_proxy_socket_pool_.insert( + std::make_pair( + http_proxy, + new HttpProxyClientSocketPool( + g_max_sockets_per_proxy_server, g_max_sockets_per_group, + http_proxy_pool_histograms_, host_resolver_, + new TCPClientSocketPool( + g_max_sockets_per_proxy_server, g_max_sockets_per_group, + tcp_for_http_proxy_pool_histograms_, host_resolver_, + socket_factory_, net_log_), + net_log_))); return ret.first->second; } diff --git a/net/http/http_network_session.h b/net/http/http_network_session.h index 319753c..96cc7ba 100644 --- a/net/http/http_network_session.h +++ b/net/http/http_network_session.h @@ -16,6 +16,7 @@ #include "net/http/http_auth_cache.h" #include "net/http/http_network_delegate.h" #include "net/http/http_network_transaction.h" +#include "net/http/http_proxy_client_socket_pool.h" #include "net/proxy/proxy_service.h" #include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socks_client_socket_pool.h" @@ -71,7 +72,7 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession> { const scoped_refptr<SOCKSClientSocketPool>& GetSocketPoolForSOCKSProxy( const HostPortPair& socks_proxy); - const scoped_refptr<TCPClientSocketPool>& GetSocketPoolForHTTPProxy( + const scoped_refptr<HttpProxyClientSocketPool>& GetSocketPoolForHTTPProxy( const HostPortPair& http_proxy); // SSL sockets come from the socket_factory(). @@ -98,7 +99,7 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession> { static void set_fixed_https_port(uint16 port); private: - typedef std::map<HostPortPair, scoped_refptr<TCPClientSocketPool> > + typedef std::map<HostPortPair, scoped_refptr<HttpProxyClientSocketPool> > HTTPProxySocketPoolMap; typedef std::map<HostPortPair, scoped_refptr<SOCKSClientSocketPool> > SOCKSSocketPoolMap; @@ -112,6 +113,7 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession> { SSLClientAuthCache ssl_client_auth_cache_; HttpAlternateProtocols alternate_protocols_; scoped_refptr<ClientSocketPoolHistograms> tcp_pool_histograms_; + scoped_refptr<ClientSocketPoolHistograms> tcp_for_http_proxy_pool_histograms_; scoped_refptr<ClientSocketPoolHistograms> http_proxy_pool_histograms_; scoped_refptr<ClientSocketPoolHistograms> tcp_for_socks_pool_histograms_; scoped_refptr<ClientSocketPoolHistograms> socks_pool_histograms_; diff --git a/net/http/http_network_transaction.cc b/net/http/http_network_transaction.cc index c0fe07b..38e85f6 100644 --- a/net/http/http_network_transaction.cc +++ b/net/http/http_network_transaction.cc @@ -31,6 +31,7 @@ #include "net/http/http_net_log_params.h" #include "net/http/http_network_session.h" #include "net/http/http_proxy_client_socket.h" +#include "net/http/http_proxy_client_socket_pool.h" #include "net/http/http_request_headers.h" #include "net/http/http_request_info.h" #include "net/http/http_response_headers.h" @@ -309,15 +310,13 @@ int HttpNetworkTransaction::RestartWithAuth( } pending_auth_target_ = HttpAuth::AUTH_NONE; + auth_controllers_[target]->ResetAuth(username, password); + if (target == HttpAuth::AUTH_PROXY && using_ssl_ && proxy_info_.is_http()) { DCHECK(establishing_tunnel_); ResetStateForRestart(); - tunnel_credentials_.username = username; - tunnel_credentials_.password = password; - tunnel_credentials_.invalid = false; next_state_ = STATE_TUNNEL_RESTART_WITH_AUTH; } else { - auth_controllers_[target]->ResetAuth(username, password); PrepareForAuthRestart(target); } @@ -432,8 +431,6 @@ LoadState HttpNetworkTransaction::GetLoadState() const { return LOAD_STATE_RESOLVING_PROXY_FOR_URL; case STATE_INIT_CONNECTION_COMPLETE: return connection_->GetLoadState(); - case STATE_TUNNEL_CONNECT_COMPLETE: - return LOAD_STATE_ESTABLISHING_PROXY_TUNNEL; case STATE_SSL_CONNECT_COMPLETE: return LOAD_STATE_SSL_HANDSHAKE; case STATE_GENERATE_PROXY_AUTH_TOKEN_COMPLETE: @@ -520,13 +517,6 @@ int HttpNetworkTransaction::DoLoop(int result) { case STATE_INIT_CONNECTION_COMPLETE: rv = DoInitConnectionComplete(rv); break; - case STATE_TUNNEL_CONNECT: - DCHECK_EQ(OK, rv); - rv = DoTunnelConnect(); - break; - case STATE_TUNNEL_CONNECT_COMPLETE: - rv = DoTunnelConnectComplete(rv); - break; case STATE_TUNNEL_RESTART_WITH_AUTH: DCHECK_EQ(OK, rv); rv = DoTunnelRestartWithAuth(); @@ -717,10 +707,9 @@ int HttpNetworkTransaction::DoInitConnection() { for (int i = 0; i < HttpAuth::AUTH_NUM_TARGETS; i++) { HttpAuth::Target target = static_cast<HttpAuth::Target>(i); if (!auth_controllers_[target].get()) - auth_controllers_[target].reset(new HttpAuthController(target, - AuthURL(target), - session_, - net_log_)); + auth_controllers_[target] = new HttpAuthController(target, + AuthURL(target), + session_, net_log_); } next_state_ = STATE_INIT_CONNECTION_COMPLETE; @@ -795,10 +784,22 @@ int HttpNetworkTransaction::DoInitConnection() { &io_callback_, session_->GetSocketPoolForSOCKSProxy(proxy_host_port_pair), net_log_); } else { - rv = connection_->Init( - connection_group, tcp_params, request_->priority, - &io_callback_, - session_->GetSocketPoolForHTTPProxy(proxy_host_port_pair), net_log_); + DCHECK(proxy_info_.is_http()); + scoped_refptr<HttpAuthController> http_proxy_auth; + if (using_ssl_) { + http_proxy_auth = auth_controllers_[HttpAuth::AUTH_PROXY]; + establishing_tunnel_ = true; + } + + HttpProxySocketParams http_proxy_params(tcp_params, request_->url, + endpoint_, http_proxy_auth, + using_ssl_); + + rv = connection_->Init(connection_group, http_proxy_params, + request_->priority, &io_callback_, + session_->GetSocketPoolForHTTPProxy( + proxy_host_port_pair), + net_log_); } } else { TCPSocketParams tcp_params(endpoint_, request_->priority, @@ -813,6 +814,29 @@ int HttpNetworkTransaction::DoInitConnection() { int HttpNetworkTransaction::DoInitConnectionComplete(int result) { if (result < 0) { + if (result == ERR_RETRY_CONNECTION) { + DCHECK(establishing_tunnel_); + next_state_ = STATE_INIT_CONNECTION; + connection_->socket()->Disconnect(); + connection_->Reset(); + return OK; + } + + if (result == ERR_PROXY_AUTH_REQUESTED) { + DCHECK(establishing_tunnel_); + HttpProxyClientSocket* tunnel_socket = + static_cast<HttpProxyClientSocket*>(connection_->socket()); + DCHECK(tunnel_socket); + DCHECK(!tunnel_socket->IsConnected()); + const HttpResponseInfo* auth_response = tunnel_socket->GetResponseInfo(); + + response_.headers = auth_response->headers; + headers_valid_ = true; + response_.auth_challenge = auth_response->auth_challenge; + pending_auth_target_ = HttpAuth::AUTH_PROXY; + return OK; + } + if (alternate_protocol_mode_ == kUsingAlternateProtocol) { // Mark the alternate protocol as broken and fallback. MarkBrokenAlternateProtocolAndFallback(); @@ -823,6 +847,10 @@ int HttpNetworkTransaction::DoInitConnectionComplete(int result) { } DCHECK_EQ(OK, result); + if (establishing_tunnel_) { + DCHECK(connection_->socket()->IsConnected()); + establishing_tunnel_ = false; + } if (using_spdy_) { DCHECK(!connection_->is_initialized()); @@ -848,74 +876,21 @@ int HttpNetworkTransaction::DoInitConnectionComplete(int result) { // Now we have a TCP connected socket. Perform other connection setup as // needed. UpdateConnectionTypeHistograms(CONNECTION_HTTP); - if (using_ssl_) { - if (proxy_info_.is_direct() || proxy_info_.is_socks()) - next_state_ = STATE_SSL_CONNECT; - else - next_state_ = STATE_TUNNEL_CONNECT; - } else { + if (using_ssl_) + next_state_ = STATE_SSL_CONNECT; + else next_state_ = STATE_GENERATE_PROXY_AUTH_TOKEN; - } } return OK; } -int HttpNetworkTransaction::DoTunnelConnect() { - next_state_ = STATE_TUNNEL_CONNECT_COMPLETE; - establishing_tunnel_ = true; - - // Add a tunnel socket on top of our existing transport socket. - ClientSocket* socket = connection_->release_socket(); - ClientSocketHandle* transport_socket_handle = new ClientSocketHandle(); - transport_socket_handle->set_socket(socket); - socket = new HttpProxyClientSocket(transport_socket_handle, request_->url, - endpoint_, auth_controllers_[HttpAuth::AUTH_PROXY].release(), true); - connection_->set_socket(socket); - return connection_->socket()->Connect(&io_callback_); -} - -int HttpNetworkTransaction::DoTunnelConnectComplete(int result) { - if (result == OK) { - next_state_ = STATE_SSL_CONNECT; - establishing_tunnel_ = false; - } - - if (result == ERR_RETRY_CONNECTION) { - HttpProxyClientSocket* tunnel_socket = - reinterpret_cast<HttpProxyClientSocket*>(connection_->socket()); - auth_controllers_[HttpAuth::AUTH_PROXY].reset( - tunnel_socket->TakeAuthController()); - next_state_ = STATE_INIT_CONNECTION; - connection_->socket()->Disconnect(); - connection_->Reset(); - result = OK; - } else if (result == ERR_PROXY_AUTH_REQUESTED) { - HttpProxyClientSocket* tunnel_socket = - reinterpret_cast<HttpProxyClientSocket*>(connection_->socket()); - const HttpResponseInfo* auth_response = tunnel_socket->GetResponseInfo(); - - response_.headers = auth_response->headers; - headers_valid_ = true; - response_.auth_challenge = auth_response->auth_challenge; - pending_auth_target_ = HttpAuth::AUTH_PROXY; - result = OK; - } - return result; -} - int HttpNetworkTransaction::DoTunnelRestartWithAuth() { - DCHECK(establishing_tunnel_); - DCHECK(!tunnel_credentials_.invalid); - next_state_ = STATE_TUNNEL_CONNECT_COMPLETE; - + next_state_ = STATE_INIT_CONNECTION_COMPLETE; HttpProxyClientSocket* tunnel_socket = reinterpret_cast<HttpProxyClientSocket*>(connection_->socket()); - tunnel_credentials_.invalid = true; - return tunnel_socket->RestartWithAuth(tunnel_credentials_.username, - tunnel_credentials_.password, - &io_callback_); + return tunnel_socket->RestartWithAuth(&io_callback_); } int HttpNetworkTransaction::DoSSLConnect() { @@ -1212,7 +1187,7 @@ int HttpNetworkTransaction::DoReadHeadersComplete(int result) { endpoint_, session_->mutable_alternate_protocols()); - int rv = HandleAuthChallenge(false); + int rv = HandleAuthChallenge(); if (rv != OK) return rv; @@ -1795,7 +1770,7 @@ bool HttpNetworkTransaction::ShouldApplyServerAuth() const { return !(request_->load_flags & LOAD_DO_NOT_SEND_AUTH_DATA); } -int HttpNetworkTransaction::HandleAuthChallenge(bool establishing_tunnel) { +int HttpNetworkTransaction::HandleAuthChallenge() { scoped_refptr<HttpResponseHeaders> headers = GetResponseHeaders(); DCHECK(headers); @@ -1808,8 +1783,7 @@ int HttpNetworkTransaction::HandleAuthChallenge(bool establishing_tunnel) { return ERR_UNEXPECTED_PROXY_AUTH; int rv = auth_controllers_[target]->HandleAuthChallenge( - headers, (request_->load_flags & LOAD_DO_NOT_SEND_AUTH_DATA) != 0, - establishing_tunnel); + headers, (request_->load_flags & LOAD_DO_NOT_SEND_AUTH_DATA) != 0, false); if (auth_controllers_[target]->HaveAuthHandler()) pending_auth_target_ = target; @@ -1868,8 +1842,6 @@ std::string HttpNetworkTransaction::DescribeState(State state) { STATE_CASE(STATE_RESOLVE_PROXY_COMPLETE); STATE_CASE(STATE_INIT_CONNECTION); STATE_CASE(STATE_INIT_CONNECTION_COMPLETE); - STATE_CASE(STATE_TUNNEL_CONNECT); - STATE_CASE(STATE_TUNNEL_CONNECT_COMPLETE); STATE_CASE(STATE_TUNNEL_RESTART_WITH_AUTH); STATE_CASE(STATE_SSL_CONNECT); STATE_CASE(STATE_SSL_CONNECT_COMPLETE); diff --git a/net/http/http_network_transaction.h b/net/http/http_network_transaction.h index 0c8d029..373aba8 100644 --- a/net/http/http_network_transaction.h +++ b/net/http/http_network_transaction.h @@ -83,8 +83,6 @@ class HttpNetworkTransaction : public HttpTransaction { STATE_RESOLVE_PROXY_COMPLETE, STATE_INIT_CONNECTION, STATE_INIT_CONNECTION_COMPLETE, - STATE_TUNNEL_CONNECT, - STATE_TUNNEL_CONNECT_COMPLETE, STATE_TUNNEL_RESTART_WITH_AUTH, STATE_SSL_CONNECT, STATE_SSL_CONNECT_COMPLETE, @@ -129,8 +127,6 @@ class HttpNetworkTransaction : public HttpTransaction { int DoResolveProxyComplete(int result); int DoInitConnection(); int DoInitConnectionComplete(int result); - int DoTunnelConnect(); - int DoTunnelConnectComplete(int result); int DoTunnelRestartWithAuth(); int DoSSLConnect(); int DoSSLConnectComplete(int result); @@ -230,7 +226,7 @@ class HttpNetworkTransaction : public HttpTransaction { // Handles HTTP status code 401 or 407. // HandleAuthChallenge() returns a network error code, or OK on success. // May update |pending_auth_target_| or |response_.auth_challenge|. - int HandleAuthChallenge(bool establishing_tunnel); + int HandleAuthChallenge(); bool HaveAuth(HttpAuth::Target target) const { return auth_controllers_[target].get() && @@ -247,7 +243,8 @@ class HttpNetworkTransaction : public HttpTransaction { static bool g_ignore_certificate_errors; - scoped_ptr<HttpAuthController> auth_controllers_[HttpAuth::AUTH_NUM_TARGETS]; + scoped_refptr<HttpAuthController> + auth_controllers_[HttpAuth::AUTH_NUM_TARGETS]; // Whether this transaction is waiting for proxy auth, server auth, or is // not waiting for any auth at all. |pending_auth_target_| is read and @@ -318,10 +315,6 @@ class HttpNetworkTransaction : public HttpTransaction { // specified by the URL, due to Alternate-Protocol or fixed testing ports. HostPortPair endpoint_; - // Stores login and password between |RestartWithAuth| - // and |DoTunnelRestartWithAuth|. - HttpAuth::Identity tunnel_credentials_; - // True when the tunnel is in the process of being established - we can't // read from the socket until the tunnel is done. bool establishing_tunnel_; diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index 36b3ce2..0a2fa31 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -64,7 +64,7 @@ class HttpNetworkSessionPeer { void SetSocketPoolForHTTPProxy( const HostPortPair& http_proxy, - const scoped_refptr<TCPClientSocketPool>& pool) { + const scoped_refptr<HttpProxyClientSocketPool>& pool) { session_->http_proxy_socket_pool_[http_proxy] = pool; } @@ -286,6 +286,8 @@ class CaptureGroupNameSocketPool : public EmulatedClientSocketPool { typedef CaptureGroupNameSocketPool<TCPClientSocketPool> CaptureGroupNameTCPSocketPool; +typedef CaptureGroupNameSocketPool<HttpProxyClientSocketPool> +CaptureGroupNameHttpProxySocketPool; typedef CaptureGroupNameSocketPool<SOCKSClientSocketPool> CaptureGroupNameSOCKSSocketPool; @@ -1401,6 +1403,13 @@ TEST_F(HttpNetworkTransactionTest, BasicAuthProxyKeepAlive) { EXPECT_EQ(L"myproxy:70", response->auth_challenge->host_and_port); EXPECT_EQ(L"MyRealm1", response->auth_challenge->realm); EXPECT_EQ(L"basic", response->auth_challenge->scheme); + + // Cleanup the transaction so that the sockets are destroyed before the + // net log goes out of scope. + trans.reset(); + + // We also need to run the message queue for the socket releases to complete. + MessageLoop::current()->RunAllPending(); } // Test that we don't read the response body when we fail to establish a tunnel, @@ -4081,8 +4090,8 @@ TEST_F(HttpNetworkTransactionTest, GroupNameForHTTPProxyConnections) { HttpNetworkSessionPeer peer(session); - scoped_refptr<CaptureGroupNameTCPSocketPool> http_proxy_pool( - new CaptureGroupNameTCPSocketPool(session.get())); + scoped_refptr<CaptureGroupNameHttpProxySocketPool> http_proxy_pool( + new CaptureGroupNameHttpProxySocketPool(session.get())); peer.SetSocketPoolForHTTPProxy( HostPortPair("http_proxy", 80), http_proxy_pool); diff --git a/net/http/http_proxy_client_socket.cc b/net/http/http_proxy_client_socket.cc index cdfa183..321b5ea 100644 --- a/net/http/http_proxy_client_socket.cc +++ b/net/http/http_proxy_client_socket.cc @@ -50,7 +50,8 @@ void BuildTunnelRequest(const HttpRequestInfo* request_info, HttpProxyClientSocket::HttpProxyClientSocket( ClientSocketHandle* transport_socket, const GURL& request_url, - const HostPortPair& endpoint, HttpAuthController* auth, bool tunnel) + const HostPortPair& endpoint, const scoped_refptr<HttpAuthController>& auth, + bool tunnel) : ALLOW_THIS_IN_INITIALIZER_LIST( io_callback_(this, &HttpProxyClientSocket::OnIOComplete)), next_state_(STATE_NONE), @@ -92,14 +93,10 @@ int HttpProxyClientSocket::Connect(CompletionCallback* callback) { return rv; } -int HttpProxyClientSocket::RestartWithAuth(const std::wstring& username, - const std::wstring& password, - CompletionCallback* callback) { +int HttpProxyClientSocket::RestartWithAuth(CompletionCallback* callback) { DCHECK_EQ(STATE_NONE, next_state_); DCHECK(!user_callback_); - auth_->ResetAuth(username, password); - int rv = PrepareForAuthRestart(); if (rv != OK) return rv; @@ -111,6 +108,9 @@ int HttpProxyClientSocket::RestartWithAuth(const std::wstring& username, } int HttpProxyClientSocket::PrepareForAuthRestart() { + if (!response_.headers.get()) + return ERR_CONNECTION_RESET; + bool keep_alive = false; if (response_.headers->IsKeepAlive() && http_stream_->CanFindEndOfResponse()) { @@ -128,12 +128,13 @@ int HttpProxyClientSocket::PrepareForAuthRestart() { } int HttpProxyClientSocket::DidDrainBodyForAuthRestart(bool keep_alive) { + int rc = OK; if (keep_alive && transport_->socket()->IsConnectedAndIdle()) { next_state_ = STATE_GENERATE_AUTH_TOKEN; transport_->set_is_reused(true); } else { transport_->socket()->Disconnect(); - return ERR_RETRY_CONNECTION; + rc = ERR_RETRY_CONNECTION; } // Reset the other member variables. @@ -141,7 +142,7 @@ int HttpProxyClientSocket::DidDrainBodyForAuthRestart(bool keep_alive) { http_stream_.reset(); request_headers_.clear(); response_ = HttpResponseInfo(); - return OK; + return rc; } void HttpProxyClientSocket::LogBlockedTunnelResponse(int response_code) const { diff --git a/net/http/http_proxy_client_socket.h b/net/http/http_proxy_client_socket.h index 6ced031..4870f0c 100644 --- a/net/http/http_proxy_client_socket.h +++ b/net/http/http_proxy_client_socket.h @@ -29,30 +29,26 @@ class IOBuffer;; class HttpProxyClientSocket : public ClientSocket { public: - // Takes ownership of |auth| and the |transport_socket|, which should - // already be connected by the time Connect() is called. If tunnel is true - // then on Connect() this socket will establish an Http tunnel. + // Takes ownership of |transport_socket|, which should already be connected + // by the time Connect() is called. If tunnel is true then on Connect() + // this socket will establish an Http tunnel. HttpProxyClientSocket(ClientSocketHandle* transport_socket, const GURL& request_url, const HostPortPair& endpoint, - HttpAuthController* auth, bool tunnel); + const scoped_refptr<HttpAuthController>& auth, + bool tunnel); // On destruction Disconnect() is called. virtual ~HttpProxyClientSocket(); // If Connect (or its callback) returns PROXY_AUTH_REQUESTED, then - // credentials can be provided by calling RestartWithAuth. - int RestartWithAuth(const std::wstring& username, - const std::wstring& password, - CompletionCallback* callback); + // credentials should be added to the HttpAuthController before calling + // RestartWithAuth. + int RestartWithAuth(CompletionCallback* callback); const HttpResponseInfo* GetResponseInfo() const { return response_.headers ? &response_ : NULL; } - HttpAuthController* TakeAuthController() { - return auth_.release(); - } - // ClientSocket methods: // Authenticates to the Http Proxy and then passes data freely. @@ -128,7 +124,7 @@ class HttpProxyClientSocket : public ClientSocket { scoped_ptr<HttpStream> http_stream_; HttpRequestInfo request_; HttpResponseInfo response_; - scoped_ptr<HttpAuthController> auth_; + const scoped_refptr<HttpAuthController> auth_; // The hostname and port of the endpoint. This is not necessarily the one // specified by the URL, due to Alternate-Protocol or fixed testing ports. diff --git a/net/http/http_proxy_client_socket_pool.cc b/net/http/http_proxy_client_socket_pool.cc new file mode 100644 index 0000000..ce11bc3 --- /dev/null +++ b/net/http/http_proxy_client_socket_pool.cc @@ -0,0 +1,213 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/http/http_proxy_client_socket_pool.h" + +#include "base/time.h" +#include "googleurl/src/gurl.h" +#include "net/base/net_errors.h" +#include "net/http/http_proxy_client_socket.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_base.h" + +namespace net { + +// HttpProxyConnectJobs will time out after this many seconds. Note this is on +// top of the timeout for the transport socket. +static const int kHttpProxyConnectJobTimeoutInSeconds = 30; + +HttpProxyConnectJob::HttpProxyConnectJob( + const std::string& group_name, + const HttpProxySocketParams& params, + const base::TimeDelta& timeout_duration, + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + const scoped_refptr<HostResolver>& host_resolver, + Delegate* delegate, + NetLog* net_log) + : ConnectJob(group_name, timeout_duration, delegate, + BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), + params_(params), + tcp_pool_(tcp_pool), + resolver_(host_resolver), + ALLOW_THIS_IN_INITIALIZER_LIST( + callback_(this, &HttpProxyConnectJob::OnIOComplete)) { +} + +HttpProxyConnectJob::~HttpProxyConnectJob() {} + +LoadState HttpProxyConnectJob::GetLoadState() const { + switch (next_state_) { + case kStateTCPConnect: + case kStateTCPConnectComplete: + return tcp_socket_handle_->GetLoadState(); + case kStateHttpProxyConnect: + case kStateHttpProxyConnectComplete: + return LOAD_STATE_ESTABLISHING_PROXY_TUNNEL; + default: + NOTREACHED(); + return LOAD_STATE_IDLE; + } +} + +int HttpProxyConnectJob::ConnectInternal() { + next_state_ = kStateTCPConnect; + return DoLoop(OK); +} + +void HttpProxyConnectJob::OnIOComplete(int result) { + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) + NotifyDelegateOfCompletion(rv); // Deletes |this| +} + +int HttpProxyConnectJob::DoLoop(int result) { + DCHECK_NE(next_state_, kStateNone); + + int rv = result; + do { + State state = next_state_; + next_state_ = kStateNone; + switch (state) { + case kStateTCPConnect: + DCHECK_EQ(OK, rv); + rv = DoTCPConnect(); + break; + case kStateTCPConnectComplete: + rv = DoTCPConnectComplete(rv); + break; + case kStateHttpProxyConnect: + DCHECK_EQ(OK, rv); + rv = DoHttpProxyConnect(); + break; + case kStateHttpProxyConnectComplete: + rv = DoHttpProxyConnectComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_FAILED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != kStateNone); + + return rv; +} + +int HttpProxyConnectJob::DoTCPConnect() { + next_state_ = kStateTCPConnectComplete; + tcp_socket_handle_.reset(new ClientSocketHandle()); + return tcp_socket_handle_->Init(group_name(), params_.tcp_params(), + params_.tcp_params().destination().priority(), + &callback_, tcp_pool_, net_log()); +} + +int HttpProxyConnectJob::DoTCPConnectComplete(int result) { + if (result != OK) + return result; + + // Reset the timer to just the length of time allowed for HttpProxy handshake + // so that a fast TCP connection plus a slow HttpProxy failure doesn't take + // longer to timeout than it should. + ResetTimer(base::TimeDelta::FromSeconds( + kHttpProxyConnectJobTimeoutInSeconds)); + next_state_ = kStateHttpProxyConnect; + return result; +} + +int HttpProxyConnectJob::DoHttpProxyConnect() { + next_state_ = kStateHttpProxyConnectComplete; + + // Add a HttpProxy connection on top of the tcp socket. + socket_.reset(new HttpProxyClientSocket(tcp_socket_handle_.release(), + params_.request_url(), + params_.endpoint(), + params_.auth_controller(), + params_.tunnel())); + return socket_->Connect(&callback_); +} + +int HttpProxyConnectJob::DoHttpProxyConnectComplete(int result) { + DCHECK_NE(result, ERR_RETRY_CONNECTION); + + if (result == OK || result == ERR_PROXY_AUTH_REQUESTED) + set_socket(socket_.release()); + + return result; +} + +ConnectJob* +HttpProxyClientSocketPool::HttpProxyConnectJobFactory::NewConnectJob( + const std::string& group_name, + const PoolBase::Request& request, + ConnectJob::Delegate* delegate) const { + return new HttpProxyConnectJob(group_name, request.params(), + ConnectionTimeout(), tcp_pool_, host_resolver_, + delegate, net_log_); +} + +base::TimeDelta +HttpProxyClientSocketPool::HttpProxyConnectJobFactory::ConnectionTimeout() +const { + return tcp_pool_->ConnectionTimeout() + + base::TimeDelta::FromSeconds(kHttpProxyConnectJobTimeoutInSeconds); +} + +HttpProxyClientSocketPool::HttpProxyClientSocketPool( + int max_sockets, + int max_sockets_per_group, + const scoped_refptr<ClientSocketPoolHistograms>& histograms, + const scoped_refptr<HostResolver>& host_resolver, + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + NetLog* net_log) + : base_(max_sockets, max_sockets_per_group, histograms, + base::TimeDelta::FromSeconds( + ClientSocketPool::unused_idle_socket_timeout()), + base::TimeDelta::FromSeconds(kUsedIdleSocketTimeout), + new HttpProxyConnectJobFactory(tcp_pool, host_resolver, net_log)) {} + +HttpProxyClientSocketPool::~HttpProxyClientSocketPool() {} + +int HttpProxyClientSocketPool::RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + CompletionCallback* callback, + const BoundNetLog& net_log) { + const HttpProxySocketParams* casted_socket_params = + static_cast<const HttpProxySocketParams*>(socket_params); + + return base_.RequestSocket(group_name, *casted_socket_params, priority, + handle, callback, net_log); +} + +void HttpProxyClientSocketPool::CancelRequest( + const std::string& group_name, + const ClientSocketHandle* handle) { + base_.CancelRequest(group_name, handle); +} + +void HttpProxyClientSocketPool::ReleaseSocket(const std::string& group_name, + ClientSocket* socket, int id) { + base_.ReleaseSocket(group_name, socket, id); +} + +void HttpProxyClientSocketPool::Flush() { + base_.Flush(); +} + +void HttpProxyClientSocketPool::CloseIdleSockets() { + base_.CloseIdleSockets(); +} + +int HttpProxyClientSocketPool::IdleSocketCountInGroup( + const std::string& group_name) const { + return base_.IdleSocketCountInGroup(group_name); +} + +LoadState HttpProxyClientSocketPool::GetLoadState( + const std::string& group_name, const ClientSocketHandle* handle) const { + return base_.GetLoadState(group_name, handle); +} + +} // namespace net diff --git a/net/http/http_proxy_client_socket_pool.h b/net/http/http_proxy_client_socket_pool.h new file mode 100644 index 0000000..dc85c32 --- /dev/null +++ b/net/http/http_proxy_client_socket_pool.h @@ -0,0 +1,201 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_HTTP_HTTP_PROXY_CLIENT_SOCKET_POOL_H_ +#define NET_HTTP_HTTP_PROXY_CLIENT_SOCKET_POOL_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/scoped_ptr.h" +#include "base/time.h" +#include "net/base/host_port_pair.h" +#include "net/base/host_resolver.h" +#include "net/proxy/proxy_server.h" +#include "net/socket/client_socket_pool_base.h" +#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/client_socket_pool.h" +#include "net/socket/tcp_client_socket_pool.h" + +namespace net { + +class ClientSocketFactory; +class ConnectJobFactory; +class HttpAuthController; + +class HttpProxySocketParams { + public: + HttpProxySocketParams(const TCPSocketParams& proxy_server, + const GURL& request_url, HostPortPair endpoint, + scoped_refptr<HttpAuthController> auth_controller, + bool tunnel) + : tcp_params_(proxy_server), + request_url_(request_url), + endpoint_(endpoint), + auth_controller_(auth_controller), + tunnel_(tunnel) { + } + + const TCPSocketParams& tcp_params() const { return tcp_params_; } + const GURL& request_url() const { return request_url_; } + const HostPortPair& endpoint() const { return endpoint_; } + const scoped_refptr<HttpAuthController>& auth_controller() const { + return auth_controller_; + } + bool tunnel() const { return tunnel_; } + + private: + const TCPSocketParams tcp_params_; + const GURL request_url_; + const HostPortPair endpoint_; + const scoped_refptr<HttpAuthController> auth_controller_; + const bool tunnel_; +}; + +// HttpProxyConnectJob optionally establishes a tunnel through the proxy +// server after connecting the underlying transport socket. +class HttpProxyConnectJob : public ConnectJob { + public: + HttpProxyConnectJob(const std::string& group_name, + const HttpProxySocketParams& params, + const base::TimeDelta& timeout_duration, + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + const scoped_refptr<HostResolver> &host_resolver, + Delegate* delegate, + NetLog* net_log); + virtual ~HttpProxyConnectJob(); + + // ConnectJob methods. + virtual LoadState GetLoadState() const; + + private: + enum State { + kStateTCPConnect, + kStateTCPConnectComplete, + kStateHttpProxyConnect, + kStateHttpProxyConnectComplete, + kStateNone, + }; + + // Begins the tcp connection and the optional Http proxy tunnel. If the + // request is not immediately servicable (likely), the request will return + // ERR_IO_PENDING. An OK return from this function or the callback means + // that the connection is established; ERR_PROXY_AUTH_REQUESTED means + // that the tunnel needs authentication credentials, the socket will be + // returned in this case, and must be release back to the pool; or + // a standard net error code will be returned. + virtual int ConnectInternal(); + + void OnIOComplete(int result); + + // Runs the state transition loop. + int DoLoop(int result); + + int DoTCPConnect(); + int DoTCPConnectComplete(int result); + int DoHttpProxyConnect(); + int DoHttpProxyConnectComplete(int result); + + HttpProxySocketParams params_; + const scoped_refptr<TCPClientSocketPool> tcp_pool_; + const scoped_refptr<HostResolver> resolver_; + + State next_state_; + CompletionCallbackImpl<HttpProxyConnectJob> callback_; + scoped_ptr<ClientSocketHandle> tcp_socket_handle_; + scoped_ptr<ClientSocket> socket_; + + DISALLOW_COPY_AND_ASSIGN(HttpProxyConnectJob); +}; + +class HttpProxyClientSocketPool : public ClientSocketPool { + public: + HttpProxyClientSocketPool( + int max_sockets, + int max_sockets_per_group, + const scoped_refptr<ClientSocketPoolHistograms>& histograms, + const scoped_refptr<HostResolver>& host_resolver, + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + NetLog* net_log); + + // ClientSocketPool methods: + virtual int RequestSocket(const std::string& group_name, + const void* connect_params, + RequestPriority priority, + ClientSocketHandle* handle, + CompletionCallback* callback, + const BoundNetLog& net_log); + + virtual void CancelRequest(const std::string& group_name, + const ClientSocketHandle* handle); + + virtual void ReleaseSocket(const std::string& group_name, + ClientSocket* socket, + int id); + + virtual void Flush(); + + virtual void CloseIdleSockets(); + + virtual int IdleSocketCount() const { + return base_.idle_socket_count(); + } + + virtual int IdleSocketCountInGroup(const std::string& group_name) const; + + virtual LoadState GetLoadState(const std::string& group_name, + const ClientSocketHandle* handle) const; + + virtual base::TimeDelta ConnectionTimeout() const { + return base_.ConnectionTimeout(); + } + + virtual scoped_refptr<ClientSocketPoolHistograms> histograms() const { + return base_.histograms(); + }; + + protected: + virtual ~HttpProxyClientSocketPool(); + + private: + typedef ClientSocketPoolBase<HttpProxySocketParams> PoolBase; + + class HttpProxyConnectJobFactory : public PoolBase::ConnectJobFactory { + public: + HttpProxyConnectJobFactory( + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + HostResolver* host_resolver, + NetLog* net_log) + : tcp_pool_(tcp_pool), + host_resolver_(host_resolver), + net_log_(net_log) {} + + virtual ~HttpProxyConnectJobFactory() {} + + // ClientSocketPoolBase::ConnectJobFactory methods. + virtual ConnectJob* NewConnectJob(const std::string& group_name, + const PoolBase::Request& request, + ConnectJob::Delegate* delegate) const; + + virtual base::TimeDelta ConnectionTimeout() const; + + private: + const scoped_refptr<TCPClientSocketPool> tcp_pool_; + const scoped_refptr<HostResolver> host_resolver_; + NetLog* net_log_; + + DISALLOW_COPY_AND_ASSIGN(HttpProxyConnectJobFactory); + }; + + PoolBase base_; + + DISALLOW_COPY_AND_ASSIGN(HttpProxyClientSocketPool); +}; + +REGISTER_SOCKET_PARAMS_FOR_POOL(HttpProxyClientSocketPool, + HttpProxySocketParams); + +} // namespace net + +#endif // NET_HTTP_HTTP_PROXY_CLIENT_SOCKET_POOL_H_ diff --git a/net/http/http_proxy_client_socket_pool_unittest.cc b/net/http/http_proxy_client_socket_pool_unittest.cc new file mode 100644 index 0000000..a3b6cca --- /dev/null +++ b/net/http/http_proxy_client_socket_pool_unittest.cc @@ -0,0 +1,318 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/http/http_proxy_client_socket_pool.h" + +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/time.h" +#include "net/base/auth.h" +#include "net/base/mock_host_resolver.h" +#include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" +#include "net/http/http_auth_controller.h" +#include "net/http/http_network_session.h" +#include "net/http/http_request_headers.h" +#include "net/http/http_response_headers.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/socket_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +const int kMaxSockets = 32; +const int kMaxSocketsPerGroup = 6; + +struct MockHttpAuthControllerData { + MockHttpAuthControllerData(std::string header) : auth_header(header) {} + + std::string auth_header; +}; + +class MockHttpAuthController : public HttpAuthController { + public: + MockHttpAuthController() + : HttpAuthController(HttpAuth::AUTH_PROXY, GURL(), + scoped_refptr<HttpNetworkSession>(NULL), + BoundNetLog()), + data_(NULL), + data_index_(0), + data_count_(0) { + } + + void SetMockAuthControllerData(struct MockHttpAuthControllerData* data, + size_t data_length) { + data_ = data; + data_count_ = data_length; + } + + // HttpAuthController methods. + virtual int MaybeGenerateAuthToken(const HttpRequestInfo* request, + CompletionCallback* callback) { + return OK; + } + virtual void AddAuthorizationHeader( + HttpRequestHeaders* authorization_headers) { + authorization_headers->AddHeadersFromString(CurrentData().auth_header); + } + virtual int HandleAuthChallenge(scoped_refptr<HttpResponseHeaders> headers, + bool do_not_send_server_auth, + bool establishing_tunnel) { + return OK; + } + virtual bool HaveAuthHandler() const { return HaveAuth(); } + virtual bool HaveAuth() const { + return CurrentData().auth_header.size() != 0; } + + private: + virtual ~MockHttpAuthController() {} + const struct MockHttpAuthControllerData& CurrentData() const { + DCHECK(data_index_ < data_count_); + return data_[data_index_]; + } + + MockHttpAuthControllerData* data_; + size_t data_index_; + size_t data_count_; +}; + +class HttpProxyClientSocketPoolTest : public ClientSocketPoolTest { + protected: + HttpProxyClientSocketPoolTest() + : ignored_tcp_socket_params_( + HostPortPair("proxy", 80), MEDIUM, GURL(), false), + tcp_histograms_(new ClientSocketPoolHistograms("MockTCP")), + tcp_socket_pool_(new MockTCPClientSocketPool(kMaxSockets, + kMaxSocketsPerGroup, tcp_histograms_, &tcp_client_socket_factory_)), + notunnel_socket_params_(ignored_tcp_socket_params_, GURL("http://host"), + HostPortPair("host", 80), NULL, false), + auth_controller_(new MockHttpAuthController), + tunnel_socket_params_(ignored_tcp_socket_params_, GURL("http://host"), + HostPortPair("host", 80), auth_controller_, true), + http_proxy_histograms_( + new ClientSocketPoolHistograms("HttpProxyUnitTest")), + pool_(new HttpProxyClientSocketPool(kMaxSockets, kMaxSocketsPerGroup, + http_proxy_histograms_, NULL, tcp_socket_pool_, NULL)) { + } + + int StartRequest(const std::string& group_name, RequestPriority priority) { + return StartRequestUsingPool( + pool_, group_name, priority, tunnel_socket_params_); + } + + TCPSocketParams ignored_tcp_socket_params_; + scoped_refptr<ClientSocketPoolHistograms> tcp_histograms_; + MockClientSocketFactory tcp_client_socket_factory_; + scoped_refptr<MockTCPClientSocketPool> tcp_socket_pool_; + + HttpProxySocketParams notunnel_socket_params_; + scoped_refptr<MockHttpAuthController> auth_controller_; + HttpProxySocketParams tunnel_socket_params_; + scoped_refptr<ClientSocketPoolHistograms> http_proxy_histograms_; + scoped_refptr<HttpProxyClientSocketPool> pool_; +}; + +TEST_F(HttpProxyClientSocketPoolTest, NoTunnel) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(false, 0)); + tcp_client_socket_factory_.AddSocketDataProvider(&data); + + ClientSocketHandle handle; + int rv = handle.Init("a", notunnel_socket_params_, LOW, NULL, pool_, + BoundNetLog()); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(HttpProxyClientSocketPoolTest, NeedAuth) { + MockWrite writes[] = { + MockWrite("CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n\r\n"), + }; + MockRead reads[] = { + // No credentials. + MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"), + MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"), + MockRead("Content-Length: 10\r\n\r\n"), + MockRead("0123456789"), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + + tcp_client_socket_factory_.AddSocketDataProvider(&data); + MockHttpAuthControllerData auth_data[] = { + MockHttpAuthControllerData(""), + }; + auth_controller_->SetMockAuthControllerData(auth_data, arraysize(auth_data)); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", tunnel_socket_params_, LOW, &callback, pool_, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(HttpProxyClientSocketPoolTest, HaveAuth) { + MockWrite writes[] = { + MockWrite(false, + "CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJheg==\r\n\r\n"), + }; + MockRead reads[] = { + MockRead(false, "HTTP/1.1 200 Connection Established\r\n\r\n"), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + data.set_connect_data(MockConnect(false, 0)); + + tcp_client_socket_factory_.AddSocketDataProvider(&data); + MockHttpAuthControllerData auth_data[] = { + MockHttpAuthControllerData("Proxy-Authorization: Basic Zm9vOmJheg=="), + }; + auth_controller_->SetMockAuthControllerData(auth_data, arraysize(auth_data)); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", tunnel_socket_params_, LOW, &callback, pool_, + BoundNetLog()); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(HttpProxyClientSocketPoolTest, AsyncHaveAuth) { + MockWrite writes[] = { + MockWrite("CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJheg==\r\n\r\n"), + }; + MockRead reads[] = { + MockRead("HTTP/1.1 200 Connection Established\r\n\r\n"), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + + tcp_client_socket_factory_.AddSocketDataProvider(&data); + MockHttpAuthControllerData auth_data[] = { + MockHttpAuthControllerData("Proxy-Authorization: Basic Zm9vOmJheg=="), + }; + auth_controller_->SetMockAuthControllerData(auth_data, arraysize(auth_data)); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", tunnel_socket_params_, LOW, &callback, pool_, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(HttpProxyClientSocketPoolTest, TCPError) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(true, ERR_CONNECTION_CLOSED)); + + tcp_client_socket_factory_.AddSocketDataProvider(&data); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", tunnel_socket_params_, LOW, &callback, pool_, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_CONNECTION_CLOSED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); +} + +TEST_F(HttpProxyClientSocketPoolTest, TunnelUnexpectedClose) { + MockWrite writes[] = { + MockWrite("CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJheg==\r\n\r\n"), + }; + MockRead reads[] = { + MockRead("HTTP/1.1 200 Conn"), + MockRead(true, ERR_CONNECTION_CLOSED), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + + tcp_client_socket_factory_.AddSocketDataProvider(&data); + MockHttpAuthControllerData auth_data[] = { + MockHttpAuthControllerData("Proxy-Authorization: Basic Zm9vOmJheg=="), + }; + auth_controller_->SetMockAuthControllerData(auth_data, arraysize(auth_data)); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", tunnel_socket_params_, LOW, &callback, pool_, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_TUNNEL_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); +} + +TEST_F(HttpProxyClientSocketPoolTest, TunnelSetupError) { + MockWrite writes[] = { + MockWrite("CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJheg==\r\n\r\n"), + }; + MockRead reads[] = { + MockRead("HTTP/1.1 304 Not Modified\r\n\r\n"), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + + tcp_client_socket_factory_.AddSocketDataProvider(&data); + MockHttpAuthControllerData auth_data[] = { + MockHttpAuthControllerData("Proxy-Authorization: Basic Zm9vOmJheg=="), + }; + auth_controller_->SetMockAuthControllerData(auth_data, arraysize(auth_data)); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", tunnel_socket_params_, LOW, &callback, pool_, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_TUNNEL_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); +} + +// It would be nice to also test the timeouts in HttpProxyClientSocketPool. + +} // namespace + +} // namespace net diff --git a/net/net.gyp b/net/net.gyp index 3fc102a..908e17d 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -385,6 +385,8 @@ 'http/url_security_manager_win.cc', 'http/http_proxy_client_socket.cc', 'http/http_proxy_client_socket.h', + 'http/http_proxy_client_socket_pool.cc', + 'http/http_proxy_client_socket_pool.h', 'http/http_util.cc', 'http/http_util_icu.cc', 'http/http_util.h', @@ -715,6 +717,7 @@ 'http/http_chunked_decoder_unittest.cc', 'http/http_network_layer_unittest.cc', 'http/http_network_transaction_unittest.cc', + 'http/http_proxy_client_socket_pool_unittest.cc', 'http/http_request_headers_unittest.cc', 'http/http_response_headers_unittest.cc', 'http/http_transaction_unittest.cc', diff --git a/net/socket/client_socket_handle.cc b/net/socket/client_socket_handle.cc index 4adae02..fab8d3e 100644 --- a/net/socket/client_socket_handle.cc +++ b/net/socket/client_socket_handle.cc @@ -73,7 +73,8 @@ void ClientSocketHandle::OnIOComplete(int result) { void ClientSocketHandle::HandleInitCompletion(int result) { CHECK_NE(ERR_IO_PENDING, result); if (result != OK) { - ResetInternal(false); // The request failed, so there's nothing to cancel. + if (!socket_.get()) + ResetInternal(false); // Nothing to cancel since the request failed. return; } CHECK_NE(-1, pool_id_) << "Pool should have set |pool_id_| to a valid value."; diff --git a/net/socket/client_socket_handle.h b/net/socket/client_socket_handle.h index 13095ab..cc6de9d 100644 --- a/net/socket/client_socket_handle.h +++ b/net/socket/client_socket_handle.h @@ -54,6 +54,12 @@ class ClientSocketHandle { // This method returns ERR_IO_PENDING if it cannot complete synchronously, in // which case the consumer will be notified of completion via |callback|. // + // If the pool was not able to reuse an existing socket, the new socket + // may report a recoverable error. In this case, the return value will + // indicate an error and the socket member will be set. If it is determined + // that the error is not recoverable, the Disconnect method should be used + // on the socket, so that it does not get reused. + // // Init may be called multiple times. // // Profiling information for the request is saved to |net_log| if non-NULL. diff --git a/net/socket/client_socket_pool.h b/net/socket/client_socket_pool.h index 807104e..1c3784f 100644 --- a/net/socket/client_socket_pool.h +++ b/net/socket/client_socket_pool.h @@ -31,22 +31,27 @@ class ClientSocketPool : public base::RefCounted<ClientSocketPool> { public: // Requests a connected socket for a group_name. // - // There are four possible results from calling this function: + // There are five possible results from calling this function: // 1) RequestSocket returns OK and initializes |handle| with a reused socket. // 2) RequestSocket returns OK with a newly connected socket. // 3) RequestSocket returns ERR_IO_PENDING. The handle will be added to a // wait list until a socket is available to reuse or a new socket finishes // connecting. |priority| will determine the placement into the wait list. // 4) An error occurred early on, so RequestSocket returns an error code. + // 5) A recoverable error occurred while setting up the socket. An error + // code is returned, but the |handle| is initialized with the new socket. + // The caller must recover from the error before using the connection, or + // Disconnect the socket before releasing or resetting the |handle|. + // The current recoverable errors are: PROXY_AUTH_REQUESTED and the errors + // accepted by IsCertificateError(err). // // If this function returns OK, then |handle| is initialized upon return. // The |handle|'s is_initialized method will return true in this case. If a // ClientSocket was reused, then ClientSocketPool will call // |handle|->set_reused(true). In either case, the socket will have been // allocated and will be connected. A client might want to know whether or - // not the socket is reused in order to know whether or not he needs to - // perform SSL connection or tunnel setup or to request a new socket if he - // encounters an error with the reused socket. + // not the socket is reused in order to request a new socket if he encounters + // an error with the reused socket. // // If ERR_IO_PENDING is returned, then the callback will be used to notify the // client of completion. diff --git a/net/socket/client_socket_pool_base.cc b/net/socket/client_socket_pool_base.cc index 0b26b75..7e5c24fc 100644 --- a/net/socket/client_socket_pool_base.cc +++ b/net/socket/client_socket_pool_base.cc @@ -257,6 +257,11 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( group.jobs.insert(job); } else { LogBoundConnectJobToRequest(connect_job->net_log().source(), request); + ClientSocket* error_socket = connect_job->ReleaseSocket(); + if (error_socket) { + HandOutSocket(error_socket, false /* not reused */, handle, + base::TimeDelta(), &group, request->net_log()); + } if (group.IsEmpty()) group_map_.erase(group_name); } @@ -605,16 +610,24 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete( OnAvailableSocketSlot(group_name, MayHaveStalledGroups()); } } else { - DCHECK(!socket.get()); + // If we got a socket, it must contain error information so pass that + // up so that the caller can retrieve it. + bool handed_out_socket = false; if (!group.pending_requests.empty()) { scoped_ptr<const Request> r(RemoveRequestFromQueue( group.pending_requests.begin(), &group.pending_requests)); LogBoundConnectJobToRequest(job_log.source(), r.get()); + if (socket.get()) { + handed_out_socket = true; + HandOutSocket(socket.release(), false /* unused socket */, r->handle(), + base::TimeDelta(), &group, r->net_log()); + } r->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL, new NetLogIntegerParameter("net_error", result)); r->callback()->Run(result); } - OnAvailableSocketSlot(group_name, MayHaveStalledGroups()); + if (!handed_out_socket) + OnAvailableSocketSlot(group_name, MayHaveStalledGroups()); } } diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index 03c7879..3e1ec28 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -115,6 +115,8 @@ class TestConnectJob : public ConnectJob { kMockPendingFailingJob, kMockWaitingJob, kMockAdvancingLoadStateJob, + kMockRecoverableJob, + kMockPendingRecoverableJob, }; // The kMockPendingJob uses a slight delay before allowing the connect @@ -136,7 +138,7 @@ class TestConnectJob : public ConnectJob { load_state_(LOAD_STATE_IDLE) {} void Signal() { - DoConnect(waiting_success_, true /* async */); + DoConnect(waiting_success_, true /* async */, false /* recoverable */); } virtual LoadState GetLoadState() const { return load_state_; } @@ -150,9 +152,11 @@ class TestConnectJob : public ConnectJob { set_socket(new MockClientSocket()); switch (job_type_) { case kMockJob: - return DoConnect(true /* successful */, false /* sync */); + return DoConnect(true /* successful */, false /* sync */, + false /* recoverable */); case kMockFailingJob: - return DoConnect(false /* error */, false /* sync */); + return DoConnect(false /* error */, false /* sync */, + false /* recoverable */); case kMockPendingJob: set_load_state(LOAD_STATE_CONNECTING); @@ -172,7 +176,8 @@ class TestConnectJob : public ConnectJob { method_factory_.NewRunnableMethod( &TestConnectJob::DoConnect, true /* successful */, - true /* async */), + true /* async */, + false /* recoverable */), kPendingConnectDelay); return ERR_IO_PENDING; case kMockPendingFailingJob: @@ -182,7 +187,8 @@ class TestConnectJob : public ConnectJob { method_factory_.NewRunnableMethod( &TestConnectJob::DoConnect, false /* error */, - true /* async */), + true /* async */, + false /* recoverable */), 2); return ERR_IO_PENDING; case kMockWaitingJob: @@ -195,6 +201,20 @@ class TestConnectJob : public ConnectJob { method_factory_.NewRunnableMethod( &TestConnectJob::AdvanceLoadState, load_state_)); return ERR_IO_PENDING; + case kMockRecoverableJob: + return DoConnect(false /* error */, false /* sync */, + true /* recoverable */); + case kMockPendingRecoverableJob: + set_load_state(LOAD_STATE_CONNECTING); + MessageLoop::current()->PostDelayedTask( + FROM_HERE, + method_factory_.NewRunnableMethod( + &TestConnectJob::DoConnect, + false /* error */, + true /* async */, + true /* recoverable */), + 2); + return ERR_IO_PENDING; default: NOTREACHED(); set_socket(NULL); @@ -204,12 +224,14 @@ class TestConnectJob : public ConnectJob { void set_load_state(LoadState load_state) { load_state_ = load_state; } - int DoConnect(bool succeed, bool was_async) { - int result = ERR_CONNECTION_FAILED; + int DoConnect(bool succeed, bool was_async, bool recoverable) { + int result = OK; if (succeed) { - result = OK; socket()->Connect(NULL); + } else if (recoverable) { + result = ERR_PROXY_AUTH_REQUESTED; } else { + result = ERR_CONNECTION_FAILED; set_socket(NULL); } @@ -583,6 +605,7 @@ TEST_F(ClientSocketPoolBaseTest, InitConnectionFailure) { EXPECT_EQ(ERR_CONNECTION_FAILED, InitHandle(req.handle(), "a", kDefaultPriority, &req, pool_, log.bound())); + EXPECT_FALSE(req.handle()->socket()); EXPECT_EQ(3u, log.entries().size()); EXPECT_TRUE(LogContainsBeginEvent( @@ -1515,6 +1538,35 @@ TEST_F(ClientSocketPoolBaseTest, LoadState) { EXPECT_NE(LOAD_STATE_IDLE, req2.handle()->GetLoadState()); } +TEST_F(ClientSocketPoolBaseTest, Recoverable) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockRecoverableJob); + + TestSocketRequest req(&request_order_, &completion_count_); + EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, InitHandle(req.handle(), "a", + kDefaultPriority, &req, pool_, + BoundNetLog())); + EXPECT_TRUE(req.handle()->is_initialized()); + EXPECT_TRUE(req.handle()->socket()); + req.handle()->Reset(); +} + +TEST_F(ClientSocketPoolBaseTest, AsyncRecoverable) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type( + TestConnectJob::kMockPendingRecoverableJob); + TestSocketRequest req(&request_order_, &completion_count_); + int rv = InitHandle(req.handle(), "a", LOWEST, &req, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", req.handle())); + EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, req.WaitForResult()); + EXPECT_TRUE(req.handle()->is_initialized()); + EXPECT_TRUE(req.handle()->socket()); + req.handle()->Reset(); +} + + TEST_F(ClientSocketPoolBaseTest, CleanupTimedOutIdleSockets) { CreatePoolWithIdleTimeouts( kDefaultMaxSockets, kDefaultMaxSocketsPerGroup, diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 8c64e5f..c708657 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -5,6 +5,7 @@ #include "net/socket/socket_test_util.h" #include <algorithm> +#include <vector> #include "base/basictypes.h" #include "base/compiler_specific.h" @@ -13,6 +14,7 @@ #include "net/base/address_family.h" #include "net/base/host_resolver_proc.h" #include "net/base/ssl_info.h" +#include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket.h" #include "testing/gtest/include/gtest/gtest.h" @@ -744,6 +746,99 @@ void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) { } while (released_one); } +MockTCPClientSocketPool::MockConnectJob::MockConnectJob( + ClientSocket* socket, + ClientSocketHandle* handle, + CompletionCallback* callback) + : socket_(socket), + handle_(handle), + user_callback_(callback), + ALLOW_THIS_IN_INITIALIZER_LIST( + connect_callback_(this, &MockConnectJob::OnConnect)) { +} + +int MockTCPClientSocketPool::MockConnectJob::Connect() { + int rv = socket_->Connect(&connect_callback_); + if (rv == OK) { + user_callback_ = NULL; + OnConnect(OK); + } + return rv; +} + +bool MockTCPClientSocketPool::MockConnectJob::CancelHandle( + const ClientSocketHandle* handle) { + if (handle != handle_) + return false; + socket_.reset(); + handle_ = NULL; + user_callback_ = NULL; + return true; +} + +void MockTCPClientSocketPool::MockConnectJob::OnConnect(int rv) { + if (!socket_.get()) + return; + if (rv == OK) + handle_->set_socket(socket_.release()); + else + socket_.reset(); + + handle_ = NULL; + + if (user_callback_) { + CompletionCallback* callback = user_callback_; + user_callback_ = NULL; + callback->Run(rv); + } +} + +MockTCPClientSocketPool::MockTCPClientSocketPool( + int max_sockets, + int max_sockets_per_group, + const scoped_refptr<ClientSocketPoolHistograms>& histograms, + ClientSocketFactory* socket_factory) + : TCPClientSocketPool(max_sockets, max_sockets_per_group, histograms, + NULL, NULL, NULL), + client_socket_factory_(socket_factory), + release_count_(0), + cancel_count_(0) { +} + +int MockTCPClientSocketPool::RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + CompletionCallback* callback, + const BoundNetLog& net_log) { + ClientSocket* socket = client_socket_factory_->CreateTCPClientSocket( + AddressList(), net_log.net_log()); + MockConnectJob* job = new MockConnectJob(socket, handle, callback); + job_list_.push_back(job); + handle->set_pool_id(1); + return job->Connect(); +} + +void MockTCPClientSocketPool::CancelRequest(const std::string& group_name, + const ClientSocketHandle* handle) { + std::vector<MockConnectJob*>::iterator i; + for (i = job_list_.begin(); i != job_list_.end(); ++i) { + if ((*i)->CancelHandle(handle)) { + cancel_count_++; + break; + } + } +} + +void MockTCPClientSocketPool::ReleaseSocket(const std::string& group_name, + ClientSocket* socket, int id) { + EXPECT_EQ(1, id); + release_count_++; + delete socket; +} + +MockTCPClientSocketPool::~MockTCPClientSocketPool() {} + const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 }; const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest); diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index dca6307..0824d39 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -24,6 +24,7 @@ #include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" #include "net/socket/ssl_client_socket.h" +#include "net/socket/tcp_client_socket_pool.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { @@ -599,6 +600,61 @@ class ClientSocketPoolTest : public testing::Test { size_t completion_count_; }; +class MockTCPClientSocketPool : public TCPClientSocketPool { + public: + class MockConnectJob { + public: + MockConnectJob(ClientSocket* socket, ClientSocketHandle* handle, + CompletionCallback* callback); + + int Connect(); + bool CancelHandle(const ClientSocketHandle* handle); + + private: + void OnConnect(int rv); + + scoped_ptr<ClientSocket> socket_; + ClientSocketHandle* handle_; + CompletionCallback* user_callback_; + CompletionCallbackImpl<MockConnectJob> connect_callback_; + + DISALLOW_COPY_AND_ASSIGN(MockConnectJob); + }; + + MockTCPClientSocketPool( + int max_sockets, + int max_sockets_per_group, + const scoped_refptr<ClientSocketPoolHistograms>& histograms, + ClientSocketFactory* socket_factory); + + int release_count() { return release_count_; }; + int cancel_count() { return cancel_count_; }; + + // TCPClientSocketPool methods. + virtual int RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + CompletionCallback* callback, + const BoundNetLog& net_log); + + virtual void CancelRequest(const std::string& group_name, + const ClientSocketHandle* handle); + virtual void ReleaseSocket(const std::string& group_name, + ClientSocket* socket, int id); + + protected: + virtual ~MockTCPClientSocketPool(); + + private: + ClientSocketFactory* client_socket_factory_; + int release_count_; + int cancel_count_; + ScopedVector<MockConnectJob> job_list_; + + DISALLOW_COPY_AND_ASSIGN(MockTCPClientSocketPool); +}; + // Constants for a successful SOCKS v5 handshake. extern const char kSOCKS5GreetRequest[]; extern const int kSOCKS5GreetRequestLength; diff --git a/net/socket/socks_client_socket_pool_unittest.cc b/net/socket/socks_client_socket_pool_unittest.cc index 7bfce55..b46789e 100644 --- a/net/socket/socks_client_socket_pool_unittest.cc +++ b/net/socket/socks_client_socket_pool_unittest.cc @@ -4,8 +4,6 @@ #include "net/socket/socks_client_socket_pool.h" -#include <vector> - #include "base/callback.h" #include "base/compiler_specific.h" #include "base/time.h" @@ -25,122 +23,6 @@ namespace { const int kMaxSockets = 32; const int kMaxSocketsPerGroup = 6; -class MockTCPClientSocketPool : public TCPClientSocketPool { - public: - class MockConnectJob { - public: - MockConnectJob(ClientSocket* socket, ClientSocketHandle* handle, - CompletionCallback* callback) - : socket_(socket), - handle_(handle), - user_callback_(callback), - ALLOW_THIS_IN_INITIALIZER_LIST( - connect_callback_(this, &MockConnectJob::OnConnect)) {} - - int Connect() { - int rv = socket_->Connect(&connect_callback_); - if (rv == OK) { - user_callback_ = NULL; - OnConnect(OK); - } - return rv; - } - - bool CancelHandle(const ClientSocketHandle* handle) { - if (handle != handle_) - return false; - socket_.reset(); - handle_ = NULL; - user_callback_ = NULL; - return true; - } - - private: - void OnConnect(int rv) { - if (!socket_.get()) - return; - if (rv == OK) - handle_->set_socket(socket_.release()); - else - socket_.reset(); - - handle_ = NULL; - - if (user_callback_) { - CompletionCallback* callback = user_callback_; - user_callback_ = NULL; - callback->Run(rv); - } - } - - scoped_ptr<ClientSocket> socket_; - ClientSocketHandle* handle_; - CompletionCallback* user_callback_; - CompletionCallbackImpl<MockConnectJob> connect_callback_; - - DISALLOW_COPY_AND_ASSIGN(MockConnectJob); - }; - - MockTCPClientSocketPool( - int max_sockets, - int max_sockets_per_group, - const scoped_refptr<ClientSocketPoolHistograms>& histograms, - ClientSocketFactory* socket_factory) - : TCPClientSocketPool(max_sockets, max_sockets_per_group, histograms, - NULL, NULL, NULL), - client_socket_factory_(socket_factory), - release_count_(0), - cancel_count_(0) {} - - int release_count() { return release_count_; }; - int cancel_count() { return cancel_count_; }; - - // TCPClientSocketPool methods. - virtual int RequestSocket(const std::string& group_name, - const void* socket_params, - RequestPriority priority, - ClientSocketHandle* handle, - CompletionCallback* callback, - const BoundNetLog& net_log) { - ClientSocket* socket = client_socket_factory_->CreateTCPClientSocket( - AddressList(), net_log.net_log()); - MockConnectJob* job = new MockConnectJob(socket, handle, callback); - job_list_.push_back(job); - handle->set_pool_id(1); - return job->Connect(); - } - - virtual void CancelRequest(const std::string& group_name, - const ClientSocketHandle* handle) { - std::vector<MockConnectJob*>::iterator i; - for (i = job_list_.begin(); i != job_list_.end(); ++i) { - if ((*i)->CancelHandle(handle)) { - cancel_count_++; - break; - } - } - } - - virtual void ReleaseSocket(const std::string& group_name, - ClientSocket* socket, - int id) { - EXPECT_EQ(1, id); - release_count_++; - delete socket; - } - - protected: - virtual ~MockTCPClientSocketPool() {} - - private: - ClientSocketFactory* client_socket_factory_; - int release_count_; - int cancel_count_; - ScopedVector<MockConnectJob> job_list_; - - DISALLOW_COPY_AND_ASSIGN(MockTCPClientSocketPool); -}; - class SOCKSClientSocketPoolTest : public ClientSocketPoolTest { protected: class SOCKS5MockData { |