diff options
40 files changed, 2284 insertions, 474 deletions
diff --git a/chrome/browser/net/preconnect.cc b/chrome/browser/net/preconnect.cc index 10aa8b0..7930a68 100644 --- a/chrome/browser/net/preconnect.cc +++ b/chrome/browser/net/preconnect.cc @@ -54,6 +54,7 @@ static void HistogramPreconnectStatus(ProxyStatus status) { // static void Preconnect::PreconnectOnIOThread(const GURL& url) { + // TODO(jar): This does not handle proxies currently. URLRequestContextGetter* getter = Profile::GetDefaultRequestContext(); if (!getter) return; @@ -85,24 +86,45 @@ void Preconnect::PreconnectOnIOThread(const GURL& url) { net::HttpTransactionFactory* factory = context->http_transaction_factory(); net::HttpNetworkSession* session = factory->GetSession(); - scoped_refptr<net::TCPClientSocketPool> pool = session->tcp_socket_pool(); - - scoped_refptr<net::TCPSocketParams> params = - new net::TCPSocketParams(url.host(), url.EffectiveIntPort(), net::LOW, - GURL(), false); net::ClientSocketHandle handle; if (!callback_instance_) callback_instance_ = new Preconnect; + scoped_refptr<net::TCPSocketParams> tcp_params = + new net::TCPSocketParams(url.host(), url.EffectiveIntPort(), net::LOW, + GURL(), false); + net::HostPortPair endpoint(url.host(), url.EffectiveIntPort()); std::string group_name = endpoint.ToString(); - if (url.SchemeIs("https")) + + if (url.SchemeIs("https")) { group_name = StringPrintf("ssl/%s", group_name.c_str()); - // TODO(jar): This does not handle proxies currently. - handle.Init(group_name, params, net::LOWEST, - callback_instance_, pool, net::BoundNetLog()); + net::SSLConfig ssl_config; + session->ssl_config_service()->GetSSLConfig(&ssl_config); + // All preconnects should be for main pages. + ssl_config.verify_ev_cert = true; + + scoped_refptr<net::SSLSocketParams> ssl_params = + new net::SSLSocketParams(tcp_params, NULL, NULL, + net::ProxyServer::SCHEME_DIRECT, + url.HostNoBrackets(), ssl_config, + 0, false); + + const scoped_refptr<net::SSLClientSocketPool>& pool = + session->ssl_socket_pool(); + + handle.Init(group_name, ssl_params, net::LOWEST, callback_instance_, pool, + net::BoundNetLog()); + handle.Reset(); + return; + } + + const scoped_refptr<net::TCPClientSocketPool>& pool = + session->tcp_socket_pool(); + handle.Init(group_name, tcp_params, net::LOWEST, callback_instance_, pool, + net::BoundNetLog()); handle.Reset(); } diff --git a/net/base/load_flags_list.h b/net/base/load_flags_list.h index f7be1ff..2bcb23e 100644 --- a/net/base/load_flags_list.h +++ b/net/base/load_flags_list.h @@ -77,3 +77,5 @@ LOAD_FLAG(DO_NOT_SEND_COOKIES, 1 << 17) // to the server (as opposed to the proxy). LOAD_FLAG(DO_NOT_SEND_AUTH_DATA, 1 << 18) +// This should only be used for testing (set by HttpNetworkTransaction). +LOAD_FLAG(IGNORE_ALL_CERT_ERRORS, 1 << 19) diff --git a/net/base/net_error_list.h b/net/base/net_error_list.h index d834fa5..06dca41 100644 --- a/net/base/net_error_list.h +++ b/net/base/net_error_list.h @@ -139,7 +139,8 @@ NET_ERROR(SOCKS_CONNECTION_FAILED, -120) // because that host is unreachable. NET_ERROR(SOCKS_CONNECTION_HOST_UNREACHABLE, -121) -// Error number -122 is available for use. +// The request to negotiate an alternate protocol failed. +NET_ERROR(NPN_NEGOTIATION_FAILED, -122) // The peer sent an SSL no_renegotiation alert message. NET_ERROR(SSL_NO_RENEGOTIATION, -123) @@ -163,7 +164,9 @@ NET_ERROR(PROXY_AUTH_REQUESTED, -127) // A known TLS strict server didn't offer the renegotiation extension. NET_ERROR(SSL_UNSAFE_NEGOTIATION, -128) -// The socket needs a fresh connection in order to proceed. +// The socket is reporting that we tried to provide new credentials after a +// a failed attempt on a connection without keep alive. We need to +// reestablish the transport socket in order to retry the authentication. NET_ERROR(RETRY_CONNECTION, -129) // Certificate error codes diff --git a/net/http/http_network_session.cc b/net/http/http_network_session.cc index e49674b..971786a 100644 --- a/net/http/http_network_session.cc +++ b/net/http/http_network_session.cc @@ -44,18 +44,21 @@ HttpNetworkSession::HttpNetworkSession( HttpAuthHandlerFactory* http_auth_handler_factory, HttpNetworkDelegate* network_delegate, NetLog* net_log) - // TODO(vandebo) when we've completely converted to pools, the base TCP - // pool name should get changed to TCP instead of Transport. - : tcp_pool_histograms_(new ClientSocketPoolHistograms("Transport")), + : tcp_pool_histograms_(new ClientSocketPoolHistograms("TCP")), tcp_for_http_proxy_pool_histograms_( new ClientSocketPoolHistograms("TCPforHTTPProxy")), http_proxy_pool_histograms_(new ClientSocketPoolHistograms("HTTPProxy")), tcp_for_socks_pool_histograms_( new ClientSocketPoolHistograms("TCPforSOCKS")), socks_pool_histograms_(new ClientSocketPoolHistograms("SOCK")), - tcp_socket_pool_(new TCPClientSocketPool(g_max_sockets, - g_max_sockets_per_group, tcp_pool_histograms_, host_resolver, - client_socket_factory, net_log)), + ssl_pool_histograms_(new ClientSocketPoolHistograms("SSL")), + tcp_socket_pool_(new TCPClientSocketPool( + g_max_sockets, g_max_sockets_per_group, tcp_pool_histograms_, + host_resolver, client_socket_factory, net_log)), + ssl_socket_pool_(new SSLClientSocketPool( + g_max_sockets, g_max_sockets_per_group, ssl_pool_histograms_, + host_resolver, client_socket_factory, tcp_socket_pool_, NULL, + NULL, net_log)), socket_factory_(client_socket_factory), host_resolver_(host_resolver), proxy_service_(proxy_service), @@ -74,12 +77,12 @@ HttpNetworkSession::~HttpNetworkSession() { const scoped_refptr<HttpProxyClientSocketPool>& HttpNetworkSession::GetSocketPoolForHTTPProxy(const HostPortPair& http_proxy) { HTTPProxySocketPoolMap::const_iterator it = - http_proxy_socket_pool_.find(http_proxy); - if (it != http_proxy_socket_pool_.end()) + http_proxy_socket_pools_.find(http_proxy); + if (it != http_proxy_socket_pools_.end()) return it->second; std::pair<HTTPProxySocketPoolMap::iterator, bool> ret = - http_proxy_socket_pool_.insert( + http_proxy_socket_pools_.insert( std::make_pair( http_proxy, new HttpProxyClientSocketPool( @@ -97,18 +100,42 @@ HttpNetworkSession::GetSocketPoolForHTTPProxy(const HostPortPair& http_proxy) { const scoped_refptr<SOCKSClientSocketPool>& HttpNetworkSession::GetSocketPoolForSOCKSProxy( const HostPortPair& socks_proxy) { - SOCKSSocketPoolMap::const_iterator it = socks_socket_pool_.find(socks_proxy); - if (it != socks_socket_pool_.end()) + SOCKSSocketPoolMap::const_iterator it = socks_socket_pools_.find(socks_proxy); + if (it != socks_socket_pools_.end()) return it->second; - std::pair<SOCKSSocketPoolMap::iterator, bool> ret = socks_socket_pool_.insert( - std::make_pair(socks_proxy, new SOCKSClientSocketPool( - g_max_sockets_per_proxy_server, g_max_sockets_per_group, - socks_pool_histograms_, host_resolver_, - new TCPClientSocketPool(g_max_sockets_per_proxy_server, - g_max_sockets_per_group, tcp_for_socks_pool_histograms_, - host_resolver_, socket_factory_, net_log_), - net_log_))); + std::pair<SOCKSSocketPoolMap::iterator, bool> ret = + socks_socket_pools_.insert( + std::make_pair(socks_proxy, new SOCKSClientSocketPool( + g_max_sockets_per_proxy_server, g_max_sockets_per_group, + socks_pool_histograms_, host_resolver_, + new TCPClientSocketPool(g_max_sockets_per_proxy_server, + g_max_sockets_per_group, tcp_for_socks_pool_histograms_, + host_resolver_, socket_factory_, net_log_), + net_log_))); + + return ret.first->second; +} + +const scoped_refptr<SSLClientSocketPool>& +HttpNetworkSession::GetSocketPoolForSSLWithProxy( + const HostPortPair& proxy_server) { + SSLSocketPoolMap::const_iterator it = + ssl_socket_pools_for_proxies_.find(proxy_server); + if (it != ssl_socket_pools_for_proxies_.end()) + return it->second; + + SSLClientSocketPool* new_pool = new SSLClientSocketPool( + g_max_sockets_per_proxy_server, g_max_sockets_per_group, + ssl_pool_histograms_, host_resolver_, socket_factory_, + NULL, + GetSocketPoolForHTTPProxy(proxy_server), + GetSocketPoolForSOCKSProxy(proxy_server), + net_log_); + + std::pair<SSLSocketPoolMap::iterator, bool> ret = + ssl_socket_pools_for_proxies_.insert(std::make_pair(proxy_server, + new_pool)); return ret.first->second; } diff --git a/net/http/http_network_session.h b/net/http/http_network_session.h index 5f1b869..de57eba 100644 --- a/net/http/http_network_session.h +++ b/net/http/http_network_session.h @@ -21,6 +21,7 @@ #include "net/proxy/proxy_service.h" #include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socks_client_socket_pool.h" +#include "net/socket/ssl_client_socket_pool.h" #include "net/socket/tcp_client_socket_pool.h" #include "net/spdy/spdy_settings_storage.h" @@ -71,12 +72,19 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession>, return tcp_socket_pool_; } + const scoped_refptr<SSLClientSocketPool>& ssl_socket_pool() { + return ssl_socket_pool_; + } + const scoped_refptr<SOCKSClientSocketPool>& GetSocketPoolForSOCKSProxy( const HostPortPair& socks_proxy); const scoped_refptr<HttpProxyClientSocketPool>& GetSocketPoolForHTTPProxy( const HostPortPair& http_proxy); + const scoped_refptr<SSLClientSocketPool>& GetSocketPoolForSSLWithProxy( + const HostPortPair& proxy_server); + // SSL sockets come from the socket_factory(). ClientSocketFactory* socket_factory() { return socket_factory_; } HostResolver* host_resolver() { return host_resolver_; } @@ -100,11 +108,40 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession>, static uint16 fixed_https_port(); static void set_fixed_https_port(uint16 port); +#ifdef UNIT_TEST + void FlushSocketPools() { + if (ssl_socket_pool_.get()) + ssl_socket_pool_->Flush(); + if (tcp_socket_pool_.get()) + tcp_socket_pool_->Flush(); + + for (SSLSocketPoolMap::const_iterator it = + ssl_socket_pools_for_proxies_.begin(); + it != ssl_socket_pools_for_proxies_.end(); + it++) + it->second->Flush(); + + for (SOCKSSocketPoolMap::const_iterator it = + socks_socket_pools_.begin(); + it != socks_socket_pools_.end(); + it++) + it->second->Flush(); + + for (HTTPProxySocketPoolMap::const_iterator it = + http_proxy_socket_pools_.begin(); + it != http_proxy_socket_pools_.end(); + it++) + it->second->Flush(); + } +#endif + private: typedef std::map<HostPortPair, scoped_refptr<HttpProxyClientSocketPool> > HTTPProxySocketPoolMap; typedef std::map<HostPortPair, scoped_refptr<SOCKSClientSocketPool> > SOCKSSocketPoolMap; + typedef std::map<HostPortPair, scoped_refptr<SSLClientSocketPool> > + SSLSocketPoolMap; friend class base::RefCounted<HttpNetworkSession>; friend class HttpNetworkSessionPeer; @@ -119,9 +156,12 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession>, scoped_refptr<ClientSocketPoolHistograms> http_proxy_pool_histograms_; scoped_refptr<ClientSocketPoolHistograms> tcp_for_socks_pool_histograms_; scoped_refptr<ClientSocketPoolHistograms> socks_pool_histograms_; + scoped_refptr<ClientSocketPoolHistograms> ssl_pool_histograms_; scoped_refptr<TCPClientSocketPool> tcp_socket_pool_; - HTTPProxySocketPoolMap http_proxy_socket_pool_; - SOCKSSocketPoolMap socks_socket_pool_; + scoped_refptr<SSLClientSocketPool> ssl_socket_pool_; + HTTPProxySocketPoolMap http_proxy_socket_pools_; + SOCKSSocketPoolMap socks_socket_pools_; + SSLSocketPoolMap ssl_socket_pools_for_proxies_; ClientSocketFactory* socket_factory_; scoped_refptr<HostResolver> host_resolver_; scoped_refptr<ProxyService> proxy_service_; diff --git a/net/http/http_network_transaction.cc b/net/http/http_network_transaction.cc index 32f8530..7f9b696 100644 --- a/net/http/http_network_transaction.cc +++ b/net/http/http_network_transaction.cc @@ -41,6 +41,7 @@ #include "net/socket/client_socket_factory.h" #include "net/socket/socks_client_socket_pool.h" #include "net/socket/ssl_client_socket.h" +#include "net/socket/ssl_client_socket_pool.h" #include "net/socket/tcp_client_socket_pool.h" #include "net/spdy/spdy_http_stream.h" #include "net/spdy/spdy_session.h" @@ -260,7 +261,7 @@ int HttpNetworkTransaction::Start(const HttpRequestInfo* request_info, int HttpNetworkTransaction::RestartIgnoringLastError( CompletionCallback* callback) { - if (connection_->socket()->IsConnectedAndIdle()) { + if (connection_->socket() && connection_->socket()->IsConnectedAndIdle()) { // TODO(wtc): Should we update any of the connection histograms that we // update in DoSSLConnectComplete if |result| is OK? if (using_spdy_) { @@ -270,7 +271,8 @@ int HttpNetworkTransaction::RestartIgnoringLastError( next_state_ = STATE_GENERATE_PROXY_AUTH_TOKEN; } } else { - connection_->socket()->Disconnect(); + if (connection_->socket()) + connection_->socket()->Disconnect(); connection_->Reset(); next_state_ = STATE_INIT_CONNECTION; } @@ -314,8 +316,8 @@ int HttpNetworkTransaction::RestartWithAuth( if (target == HttpAuth::AUTH_PROXY && using_ssl_ && proxy_info_.is_http()) { DCHECK(establishing_tunnel_); + next_state_ = STATE_INIT_CONNECTION; ResetStateForRestart(); - next_state_ = STATE_TUNNEL_RESTART_WITH_AUTH; } else { PrepareForAuthRestart(target); } @@ -377,22 +379,8 @@ int HttpNetworkTransaction::Read(IOBuffer* buf, int buf_len, State next_state = STATE_NONE; - // Are we using SPDY or HTTP? - if (using_spdy_) { - DCHECK(!http_stream_.get()); - DCHECK(spdy_http_stream_->GetResponseInfo()->headers); - next_state = STATE_SPDY_READ_BODY; - } else { - DCHECK(!spdy_http_stream_.get()); - next_state = STATE_READ_BODY; - - if (!connection_->is_initialized()) - return 0; // connection_->has been reset. Treat like EOF. - } - scoped_refptr<HttpResponseHeaders> headers = GetResponseHeaders(); - DCHECK(headers.get()); - if (establishing_tunnel_) { + if (headers_valid_ && headers.get() && establishing_tunnel_) { // We're trying to read the body of the response but we're still trying // to establish an SSL tunnel through the proxy. We can't read these // bytes when establishing a tunnel because they might be controlled by @@ -408,6 +396,19 @@ int HttpNetworkTransaction::Read(IOBuffer* buf, int buf_len, return ERR_TUNNEL_CONNECTION_FAILED; } + // Are we using SPDY or HTTP? + if (using_spdy_) { + DCHECK(!http_stream_.get()); + DCHECK(spdy_http_stream_->GetResponseInfo()->headers); + next_state = STATE_SPDY_READ_BODY; + } else { + DCHECK(!spdy_http_stream_.get()); + next_state = STATE_READ_BODY; + + if (!connection_->is_initialized()) + return 0; // |*connection_| has been reset. Treat like EOF. + } + read_buf_ = buf; read_buf_len_ = buf_len; @@ -517,17 +518,6 @@ int HttpNetworkTransaction::DoLoop(int result) { case STATE_INIT_CONNECTION_COMPLETE: rv = DoInitConnectionComplete(rv); break; - case STATE_TUNNEL_RESTART_WITH_AUTH: - DCHECK_EQ(OK, rv); - rv = DoTunnelRestartWithAuth(); - break; - case STATE_SSL_CONNECT: - DCHECK_EQ(OK, rv); - rv = DoSSLConnect(); - break; - case STATE_SSL_CONNECT_COMPLETE: - rv = DoSSLConnectComplete(rv); - break; case STATE_GENERATE_PROXY_AUTH_TOKEN: DCHECK_EQ(OK, rv); rv = DoGenerateProxyAuthToken(); @@ -702,6 +692,7 @@ int HttpNetworkTransaction::DoResolveProxyComplete(int result) { int HttpNetworkTransaction::DoInitConnection() { DCHECK(!connection_->is_initialized()); DCHECK(proxy_info_.proxy_server().is_valid()); + next_state_ = STATE_INIT_CONNECTION_COMPLETE; // Now that the proxy server has been resolved, create the auth_controllers_. for (int i = 0; i < HttpAuth::AUTH_NUM_TARGETS; i++) { @@ -712,17 +703,11 @@ int HttpNetworkTransaction::DoInitConnection() { session_); } - next_state_ = STATE_INIT_CONNECTION_COMPLETE; - - using_ssl_ = request_->url.SchemeIs("https") || - (alternate_protocol_mode_ == kUsingAlternateProtocol && - alternate_protocol_ == HttpAlternateProtocols::NPN_SPDY_1); - + bool want_spdy = alternate_protocol_mode_ == kUsingAlternateProtocol + && alternate_protocol_ == HttpAlternateProtocols::NPN_SPDY_1; + using_ssl_ = request_->url.SchemeIs("https") || want_spdy; using_spdy_ = false; - - // Build the string used to uniquely identify connections of this type. - // Determine the host and port to connect to. - std::string connection_group; + response_.was_fetched_via_proxy = !proxy_info_.is_direct(); // Use the fixed testing ports if they've been provided. if (using_ssl_) { @@ -732,17 +717,18 @@ int HttpNetworkTransaction::DoInitConnection() { endpoint_.port = session_->fixed_http_port(); } - response_.was_fetched_via_proxy = !proxy_info_.is_direct(); - // Check first if we have a spdy session for this group. If so, then go // straight to using that. if (session_->spdy_session_pool()->HasSession(endpoint_)) { using_spdy_ = true; reused_socket_ = true; + next_state_ = STATE_SPDY_SEND_REQUEST; return OK; } - connection_group = endpoint_.ToString(); + // Build the string used to uniquely identify connections of this type. + // Determine the host and port to connect to. + std::string connection_group = endpoint_.ToString(); DCHECK(!connection_group.empty()); if (using_ssl_) @@ -753,206 +739,175 @@ int HttpNetworkTransaction::DoInitConnection() { request_->load_flags & LOAD_VALIDATE_CACHE || request_->load_flags & LOAD_DISABLE_CACHE; - int rv; - if (!proxy_info_.is_direct()) { - ProxyServer proxy_server = proxy_info_.proxy_server(); - HostPortPair proxy_host_port_pair(proxy_server.HostNoBrackets(), - proxy_server.port()); + // Build up the connection parameters. + scoped_refptr<TCPSocketParams> tcp_params; + scoped_refptr<HttpProxySocketParams> http_proxy_params; + scoped_refptr<SOCKSSocketParams> socks_params; + scoped_ptr<HostPortPair> proxy_host_port; - scoped_refptr<TCPSocketParams> tcp_params = - new TCPSocketParams(proxy_host_port_pair, request_->priority, + if (proxy_info_.is_direct()) { + tcp_params = new TCPSocketParams(endpoint_, request_->priority, + request_->referrer, + disable_resolver_cache); + } else { + ProxyServer proxy_server = proxy_info_.proxy_server(); + proxy_host_port.reset(new HostPortPair(proxy_server.HostNoBrackets(), + proxy_server.port())); + scoped_refptr<TCPSocketParams> proxy_tcp_params = + new TCPSocketParams(*proxy_host_port, request_->priority, request_->referrer, disable_resolver_cache); - if (proxy_info_.is_socks()) { - const char* socks_version; - bool socks_v5; - if (proxy_info_.proxy_server().scheme() == ProxyServer::SCHEME_SOCKS5) { - socks_version = "5"; - socks_v5 = true; - } else { - socks_version = "4"; - socks_v5 = false; - } - - connection_group = - StringPrintf("socks%s/%s", socks_version, connection_group.c_str()); - - scoped_refptr<SOCKSSocketParams> socks_params = - new SOCKSSocketParams(tcp_params, socks_v5, endpoint_, - request_->priority, request_->referrer); - - rv = connection_->Init( - connection_group, socks_params, request_->priority, - &io_callback_, - session_->GetSocketPoolForSOCKSProxy(proxy_host_port_pair), net_log_); - } else { - DCHECK(proxy_info_.is_http()); + if (proxy_info_.is_http()) { scoped_refptr<HttpAuthController> http_proxy_auth; if (using_ssl_) { http_proxy_auth = auth_controllers_[HttpAuth::AUTH_PROXY]; establishing_tunnel_ = true; } + http_proxy_params = new HttpProxySocketParams(proxy_tcp_params, + request_->url, endpoint_, + http_proxy_auth, + using_ssl_); + } else { + DCHECK(proxy_info_.is_socks()); + char socks_version; + if (proxy_server.scheme() == ProxyServer::SCHEME_SOCKS5) + socks_version = '5'; + else + socks_version = '4'; + connection_group = + StringPrintf("socks%c/%s", socks_version, connection_group.c_str()); - scoped_refptr<HttpProxySocketParams> http_proxy_params = - new HttpProxySocketParams(tcp_params, request_->url, endpoint_, - http_proxy_auth, using_ssl_); - - rv = connection_->Init(connection_group, http_proxy_params, - request_->priority, &io_callback_, - session_->GetSocketPoolForHTTPProxy( - proxy_host_port_pair), - net_log_); + socks_params = new SOCKSSocketParams(proxy_tcp_params, + socks_version == '5', + endpoint_, + request_->priority, + request_->referrer); } - } else { - scoped_refptr<TCPSocketParams> tcp_params = - new TCPSocketParams(endpoint_, request_->priority, request_->referrer, - disable_resolver_cache); - rv = connection_->Init(connection_group, tcp_params, request_->priority, - &io_callback_, session_->tcp_socket_pool(), - net_log_); } - return rv; -} - -int HttpNetworkTransaction::DoInitConnectionComplete(int result) { - if (result < 0) { - if (result == ERR_RETRY_CONNECTION) { - DCHECK(establishing_tunnel_); - next_state_ = STATE_INIT_CONNECTION; - connection_->socket()->Disconnect(); - connection_->Reset(); - return OK; + // Deal with SSL - which layers on top of any given proxy. + if (using_ssl_) { + if (ContainsKey(*g_tls_intolerant_servers, GetHostAndPort(request_->url))) { + LOG(WARNING) << "Falling back to SSLv3 because host is TLS intolerant: " + << GetHostAndPort(request_->url); + ssl_config_.ssl3_fallback = true; + ssl_config_.tls1_enabled = false; } - if (result == ERR_PROXY_AUTH_REQUESTED) { - DCHECK(establishing_tunnel_); - HttpProxyClientSocket* tunnel_socket = - static_cast<HttpProxyClientSocket*>(connection_->socket()); - DCHECK(tunnel_socket); - DCHECK(!tunnel_socket->IsConnected()); - const HttpResponseInfo* auth_response = tunnel_socket->GetResponseInfo(); - - response_.headers = auth_response->headers; - headers_valid_ = true; - response_.auth_challenge = auth_response->auth_challenge; - pending_auth_target_ = HttpAuth::AUTH_PROXY; - return OK; - } + UMA_HISTOGRAM_ENUMERATION("Net.ConnectionUsedSSLv3Fallback", + (int) ssl_config_.ssl3_fallback, 2); - if (alternate_protocol_mode_ == kUsingAlternateProtocol) { - // Mark the alternate protocol as broken and fallback. - MarkBrokenAlternateProtocolAndFallback(); - return OK; - } + int load_flags = request_->load_flags; + if (g_ignore_certificate_errors) + load_flags |= LOAD_IGNORE_ALL_CERT_ERRORS; + if (request_->load_flags & LOAD_VERIFY_EV_CERT) + ssl_config_.verify_ev_cert = true; - return ReconsiderProxyAfterError(result); - } + scoped_refptr<SSLSocketParams> ssl_params = + new SSLSocketParams(tcp_params, http_proxy_params, socks_params, + proxy_info_.proxy_server().scheme(), + request_->url.HostNoBrackets(), ssl_config_, + load_flags, want_spdy); - DCHECK_EQ(OK, result); - if (establishing_tunnel_) { - DCHECK(connection_->socket()->IsConnected()); - establishing_tunnel_ = false; - } + scoped_refptr<SSLClientSocketPool> ssl_pool; + if (proxy_info_.is_direct()) + ssl_pool = session_->ssl_socket_pool(); + else + ssl_pool = session_->GetSocketPoolForSSLWithProxy(*proxy_host_port); - if (using_spdy_) { - DCHECK(!connection_->is_initialized()); - // TODO(cbentzel): Add auth support to spdy. See http://crbug.com/46620 - next_state_ = STATE_SPDY_SEND_REQUEST; - return OK; + return connection_->Init(connection_group, ssl_params, request_->priority, + &io_callback_, ssl_pool, net_log_); } - LogHttpConnectedMetrics(*connection_); + // Finally, get the connection started. + if (proxy_info_.is_http()) { + return connection_->Init( + connection_group, http_proxy_params, request_->priority, &io_callback_, + session_->GetSocketPoolForHTTPProxy(*proxy_host_port), net_log_); + } - // Set the reused_socket_ flag to indicate that we are using a keep-alive - // connection. This flag is used to handle errors that occur while we are - // trying to reuse a keep-alive connection. - reused_socket_ = connection_->is_reused(); - if (reused_socket_) { - if (using_ssl_) { - SSLClientSocket* ssl_socket = - reinterpret_cast<SSLClientSocket*>(connection_->socket()); - response_.was_npn_negotiated = ssl_socket->wasNpnNegotiated(); - } - next_state_ = STATE_GENERATE_PROXY_AUTH_TOKEN; - } else { - // Now we have a TCP connected socket. Perform other connection setup as - // needed. - UpdateConnectionTypeHistograms(CONNECTION_HTTP); - if (using_ssl_) - next_state_ = STATE_SSL_CONNECT; - else - next_state_ = STATE_GENERATE_PROXY_AUTH_TOKEN; + if (proxy_info_.is_socks()) { + return connection_->Init( + connection_group, socks_params, request_->priority, &io_callback_, + session_->GetSocketPoolForSOCKSProxy(*proxy_host_port), net_log_); } - return OK; + DCHECK(proxy_info_.is_direct()); + return connection_->Init(connection_group, tcp_params, request_->priority, + &io_callback_, session_->tcp_socket_pool(), + net_log_); } -int HttpNetworkTransaction::DoTunnelRestartWithAuth() { - next_state_ = STATE_INIT_CONNECTION_COMPLETE; - HttpProxyClientSocket* tunnel_socket = - reinterpret_cast<HttpProxyClientSocket*>(connection_->socket()); - - return tunnel_socket->RestartWithAuth(&io_callback_); -} +int HttpNetworkTransaction::DoInitConnectionComplete(int result) { + // |result| may be the result of any of the stacked pools. The following + // logic is used when determining how to interpret an error. + // If |result| < 0: + // and connection_->socket() != NULL, then the SSL handshake ran and it + // is a potentially recoverable error. + // and connection_->socket == NULL and connection_->is_ssl_error() is true, + // then the SSL handshake ran with an unrecoverable error. + // otherwise, the error came from one of the other pools. + bool ssl_started = using_ssl_ && (result == OK || connection_->socket() || + connection_->is_ssl_error()); + + if (ssl_started && (result == OK || IsCertificateError(result))) { + SSLClientSocket* ssl_socket = + static_cast<SSLClientSocket*>(connection_->socket()); + if (ssl_socket->wasNpnNegotiated()) { + response_.was_npn_negotiated = true; + std::string proto; + ssl_socket->GetNextProto(&proto); + if (SSLClientSocket::NextProtoFromString(proto) == + SSLClientSocket::kProtoSPDY1) + using_spdy_ = true; + } + } -int HttpNetworkTransaction::DoSSLConnect() { - next_state_ = STATE_SSL_CONNECT_COMPLETE; + if (result == ERR_PROXY_AUTH_REQUESTED) { + DCHECK(!ssl_started); + const HttpResponseInfo& tunnel_auth_response = + connection_->tunnel_auth_response_info(); - if (ContainsKey(*g_tls_intolerant_servers, GetHostAndPort(request_->url))) { - LOG(WARNING) << "Falling back to SSLv3 because host is TLS intolerant: " - << GetHostAndPort(request_->url); - ssl_config_.ssl3_fallback = true; - ssl_config_.tls1_enabled = false; + response_.headers = tunnel_auth_response.headers; + response_.auth_challenge = tunnel_auth_response.auth_challenge; + headers_valid_ = true; + pending_auth_target_ = HttpAuth::AUTH_PROXY; + return OK; } - UMA_HISTOGRAM_ENUMERATION("Net.ConnectionUsedSSLv3Fallback", - (int) ssl_config_.ssl3_fallback, 2); + if ((!ssl_started && result < 0 && + alternate_protocol_mode_ == kUsingAlternateProtocol) || + result == ERR_NPN_NEGOTIATION_FAILED) { + // Mark the alternate protocol as broken and fallback. + MarkBrokenAlternateProtocolAndFallback(); + return OK; + } - if (request_->load_flags & LOAD_VERIFY_EV_CERT) - ssl_config_.verify_ev_cert = true; + if (result < 0 && !ssl_started) + return ReconsiderProxyAfterError(result); + establishing_tunnel_ = false; - ssl_connect_start_time_ = base::TimeTicks::Now(); + if (connection_->socket()) { + LogHttpConnectedMetrics(*connection_); - // Add a SSL socket on top of our existing transport socket. - ClientSocket* s = connection_->release_socket(); - s = session_->socket_factory()->CreateSSLClientSocket( - s, request_->url.HostNoBrackets(), ssl_config_); - connection_->set_socket(s); - return connection_->socket()->Connect(&io_callback_); -} + // Set the reused_socket_ flag to indicate that we are using a keep-alive + // connection. This flag is used to handle errors that occur while we are + // trying to reuse a keep-alive connection. + reused_socket_ = connection_->is_reused(); + // TODO(vandebo) should we exclude SPDY in the following if? + if (!reused_socket_) + UpdateConnectionTypeHistograms(CONNECTION_HTTP); -int HttpNetworkTransaction::DoSSLConnectComplete(int result) { - SSLClientSocket* ssl_socket = - reinterpret_cast<SSLClientSocket*>(connection_->socket()); - - SSLClientSocket::NextProtoStatus status = - SSLClientSocket::kNextProtoUnsupported; - std::string proto; - // GetNextProto will fail and and trigger a NOTREACHED if we pass in a socket - // that hasn't had SSL_ImportFD called on it. If we get a certificate error - // here, then we know that we called SSL_ImportFD. - if (result == OK || IsCertificateError(result)) - status = ssl_socket->GetNextProto(&proto); - - if (status == SSLClientSocket::kNextProtoNegotiated) { - ssl_socket->setWasNpnNegotiated(true); - response_.was_npn_negotiated = true; - if (SSLClientSocket::NextProtoFromString(proto) == - SSLClientSocket::kProtoSPDY1) { - using_spdy_ = true; + if (!using_ssl_) { + DCHECK_EQ(OK, result); + next_state_ = STATE_GENERATE_PROXY_AUTH_TOKEN; + return result; } } - if (alternate_protocol_mode_ == kUsingAlternateProtocol && - alternate_protocol_ == HttpAlternateProtocols::NPN_SPDY_1 && - !using_spdy_) { - // We tried using the NPN_SPDY_1 alternate protocol, but failed, so we - // fallback. - MarkBrokenAlternateProtocolAndFallback(); - return OK; - } - + // Handle SSL errors below. + DCHECK(using_ssl_); + DCHECK(ssl_started); if (IsCertificateError(result)) { if (using_spdy_ && request_->url.SchemeIs("http")) { // We ignore certificate errors for http over spdy. @@ -969,36 +924,19 @@ int HttpNetworkTransaction::DoSSLConnectComplete(int result) { } } - if (result == OK) { - DCHECK(ssl_connect_start_time_ != base::TimeTicks()); - base::TimeDelta connect_duration = - base::TimeTicks::Now() - ssl_connect_start_time_; - - if (using_spdy_) { - UMA_HISTOGRAM_CUSTOM_TIMES("Net.SpdyConnectionLatency", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - - UpdateConnectionTypeHistograms(CONNECTION_SPDY); - // TODO(cbentzel): Add auth support to spdy. See http://crbug.com/46620 - next_state_ = STATE_SPDY_SEND_REQUEST; - } else { - UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); + if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) + return HandleCertificateRequest(result); + if (result < 0) + return HandleSSLHandshakeError(result); - next_state_ = STATE_GENERATE_PROXY_AUTH_TOKEN; - } - } else if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) { - result = HandleCertificateRequest(result); + if (using_spdy_) { + UpdateConnectionTypeHistograms(CONNECTION_SPDY); + // TODO(cbentzel): Add auth support to spdy. See http://crbug.com/46620 + next_state_ = STATE_SPDY_SEND_REQUEST; } else { - result = HandleSSLHandshakeError(result); + next_state_ = STATE_GENERATE_PROXY_AUTH_TOKEN; } - return result; + return OK; } int HttpNetworkTransaction::DoGenerateProxyAuthToken() { @@ -1196,7 +1134,7 @@ int HttpNetworkTransaction::DoReadHeadersComplete(int result) { if (using_ssl_) { SSLClientSocket* ssl_socket = - reinterpret_cast<SSLClientSocket*>(connection_->socket()); + static_cast<SSLClientSocket*>(connection_->socket()); ssl_socket->GetSSLInfo(&response_.ssl_info); } @@ -1539,7 +1477,7 @@ int HttpNetworkTransaction::HandleCertificateError(int error) { DCHECK(IsCertificateError(error)); SSLClientSocket* ssl_socket = - reinterpret_cast<SSLClientSocket*>(connection_->socket()); + static_cast<SSLClientSocket*>(connection_->socket()); ssl_socket->GetSSLInfo(&response_.ssl_info); // Add the bad certificate to the set of allowed certificates in the @@ -1551,29 +1489,11 @@ int HttpNetworkTransaction::HandleCertificateError(int error) { bad_cert.cert_status = response_.ssl_info.cert_status; ssl_config_.allowed_bad_certs.push_back(bad_cert); + int load_flags = request_->load_flags; if (g_ignore_certificate_errors) + load_flags |= LOAD_IGNORE_ALL_CERT_ERRORS; + if (ssl_socket->IgnoreCertError(error, load_flags)) return OK; - - const int kCertFlags = LOAD_IGNORE_CERT_COMMON_NAME_INVALID | - LOAD_IGNORE_CERT_DATE_INVALID | - LOAD_IGNORE_CERT_AUTHORITY_INVALID | - LOAD_IGNORE_CERT_WRONG_USAGE; - if (request_->load_flags & kCertFlags) { - switch (error) { - case ERR_CERT_COMMON_NAME_INVALID: - if (request_->load_flags & LOAD_IGNORE_CERT_COMMON_NAME_INVALID) - error = OK; - break; - case ERR_CERT_DATE_INVALID: - if (request_->load_flags & LOAD_IGNORE_CERT_DATE_INVALID) - error = OK; - break; - case ERR_CERT_AUTHORITY_INVALID: - if (request_->load_flags & LOAD_IGNORE_CERT_AUTHORITY_INVALID) - error = OK; - break; - } - } return error; } @@ -1590,7 +1510,7 @@ int HttpNetworkTransaction::HandleCertificateRequest(int error) { response_.cert_request_info = new SSLCertRequestInfo; SSLClientSocket* ssl_socket = - reinterpret_cast<SSLClientSocket*>(connection_->socket()); + static_cast<SSLClientSocket*>(connection_->socket()); ssl_socket->GetSSLCertRequestInfo(response_.cert_request_info); // Close the connection while the user is selecting a certificate to send @@ -1695,11 +1615,13 @@ bool HttpNetworkTransaction::ShouldResendRequest(int error) const { } void HttpNetworkTransaction::ResetConnectionAndRequestForResend() { - connection_->socket()->Disconnect(); + if (connection_->socket()) + connection_->socket()->Disconnect(); connection_->Reset(); // We need to clear request_headers_ because it contains the real request // headers, but we may need to resend the CONNECT request first to recreate // the SSL tunnel. + request_headers_.clear(); next_state_ = STATE_INIT_CONNECTION; // Resend the request. } diff --git a/net/http/http_network_transaction.h b/net/http/http_network_transaction.h index 373aba8..a5207da 100644 --- a/net/http/http_network_transaction.h +++ b/net/http/http_network_transaction.h @@ -127,7 +127,6 @@ class HttpNetworkTransaction : public HttpTransaction { int DoResolveProxyComplete(int result); int DoInitConnection(); int DoInitConnectionComplete(int result); - int DoTunnelRestartWithAuth(); int DoSSLConnect(); int DoSSLConnectComplete(int result); int DoGenerateProxyAuthToken(); @@ -305,9 +304,6 @@ class HttpNetworkTransaction : public HttpTransaction { // The time the Start method was called. base::Time start_time_; - // The time the DoSSLConnect() method was called (if it got called). - base::TimeTicks ssl_connect_start_time_; - // The next state in the state machine. State next_state_; diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index 976452d..48945db 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -59,13 +59,23 @@ class HttpNetworkSessionPeer { void SetSocketPoolForSOCKSProxy( const HostPortPair& socks_proxy, const scoped_refptr<SOCKSClientSocketPool>& pool) { - session_->socks_socket_pool_[socks_proxy] = pool; + session_->socks_socket_pools_[socks_proxy] = pool; } void SetSocketPoolForHTTPProxy( const HostPortPair& http_proxy, const scoped_refptr<HttpProxyClientSocketPool>& pool) { - session_->http_proxy_socket_pool_[http_proxy] = pool; + session_->http_proxy_socket_pools_[http_proxy] = pool; + } + + void SetSSLSocketPool(const scoped_refptr<SSLClientSocketPool>& pool) { + session_->ssl_socket_pool_ = pool; + } + + void SetSocketPoolForSSLWithProxy( + const HostPortPair& proxy_host, + const scoped_refptr<SSLClientSocketPool>& pool) { + session_->ssl_socket_pools_for_proxies_[proxy_host] = pool; } private: @@ -238,13 +248,11 @@ std::string MockGetHostName() { return "WTC-WIN7"; } -template<typename EmulatedClientSocketPool> -class CaptureGroupNameSocketPool : public EmulatedClientSocketPool { +template<typename ParentPool> +class CaptureGroupNameSocketPool : public ParentPool { public: - explicit CaptureGroupNameSocketPool(HttpNetworkSession* session) - : EmulatedClientSocketPool(0, 0, NULL, session->host_resolver(), NULL, - NULL) { - } + explicit CaptureGroupNameSocketPool(HttpNetworkSession* session); + const std::string last_group_name_received() const { return last_group_name_; } @@ -290,6 +298,19 @@ typedef CaptureGroupNameSocketPool<HttpProxyClientSocketPool> CaptureGroupNameHttpProxySocketPool; typedef CaptureGroupNameSocketPool<SOCKSClientSocketPool> CaptureGroupNameSOCKSSocketPool; +typedef CaptureGroupNameSocketPool<SSLClientSocketPool> +CaptureGroupNameSSLSocketPool; + +template<typename ParentPool> +CaptureGroupNameSocketPool<ParentPool>::CaptureGroupNameSocketPool( + HttpNetworkSession* session) + : ParentPool(0, 0, NULL, session->host_resolver(), NULL, NULL) {} + +template<> +CaptureGroupNameSSLSocketPool::CaptureGroupNameSocketPool( + HttpNetworkSession* session) + : SSLClientSocketPool(0, 0, NULL, session->host_resolver(), NULL, NULL, + NULL, NULL, NULL) {} //----------------------------------------------------------------------------- @@ -1404,12 +1425,9 @@ TEST_F(HttpNetworkTransactionTest, BasicAuthProxyKeepAlive) { EXPECT_EQ(L"MyRealm1", response->auth_challenge->realm); EXPECT_EQ(L"basic", response->auth_challenge->scheme); - // Cleanup the transaction so that the sockets are destroyed before the - // net log goes out of scope. - trans.reset(); - - // We also need to run the message queue for the socket releases to complete. - MessageLoop::current()->RunAllPending(); + // Flush the idle socket before the NetLog and HttpNetworkTransaction go + // out of scope. + session->FlushSocketPools(); } // Test that we don't read the response body when we fail to establish a tunnel, @@ -1465,6 +1483,9 @@ TEST_F(HttpNetworkTransactionTest, BasicAuthProxyCancelTunnel) { std::string response_data; rv = ReadTransaction(trans.get(), &response_data); EXPECT_EQ(ERR_TUNNEL_CONNECTION_FAILED, rv); + + // Flush the idle socket before the HttpNetworkTransaction goes out of scope. + session->FlushSocketPools(); } // Test when a server (non-proxy) returns a 407 (proxy-authenticate). @@ -3977,6 +3998,7 @@ struct GroupNameTest { std::string proxy_server; std::string url; std::string expected_group_name; + bool ssl; }; scoped_refptr<HttpNetworkSession> SetupSessionForGroupNameTests( @@ -4015,11 +4037,13 @@ TEST_F(HttpNetworkTransactionTest, GroupNameForDirectConnections) { "", // unused "http://www.google.com/direct", "www.google.com:80", + false, }, { "", // unused "http://[2001:1418:13:1::25]/direct", "[2001:1418:13:1::25]:80", + false, }, // SSL Tests @@ -4027,16 +4051,19 @@ TEST_F(HttpNetworkTransactionTest, GroupNameForDirectConnections) { "", // unused "https://www.google.com/direct_ssl", "ssl/www.google.com:443", + true, }, { "", // unused "https://[2001:1418:13:1::25]/direct", "ssl/[2001:1418:13:1::25]:443", + true, }, { "", // unused "http://host.with.alternate/direct", "ssl/host.with.alternate:443", + true, }, }; @@ -4050,11 +4077,18 @@ TEST_F(HttpNetworkTransactionTest, GroupNameForDirectConnections) { scoped_refptr<CaptureGroupNameTCPSocketPool> tcp_conn_pool( new CaptureGroupNameTCPSocketPool(session.get())); peer.SetTCPSocketPool(tcp_conn_pool); + scoped_refptr<CaptureGroupNameSSLSocketPool> ssl_conn_pool( + new CaptureGroupNameSSLSocketPool(session.get())); + peer.SetSSLSocketPool(ssl_conn_pool); EXPECT_EQ(ERR_IO_PENDING, GroupNameTransactionHelper(tests[i].url, session)); - EXPECT_EQ(tests[i].expected_group_name, - tcp_conn_pool->last_group_name_received()); + if (tests[i].ssl) + EXPECT_EQ(tests[i].expected_group_name, + ssl_conn_pool->last_group_name_received()); + else + EXPECT_EQ(tests[i].expected_group_name, + tcp_conn_pool->last_group_name_received()); } HttpNetworkTransaction::SetUseAlternateProtocols(false); @@ -4066,6 +4100,7 @@ TEST_F(HttpNetworkTransactionTest, GroupNameForHTTPProxyConnections) { "http_proxy", "http://www.google.com/http_proxy_normal", "www.google.com:80", + false, }, // SSL Tests @@ -4073,12 +4108,14 @@ TEST_F(HttpNetworkTransactionTest, GroupNameForHTTPProxyConnections) { "http_proxy", "https://www.google.com/http_connect_ssl", "ssl/www.google.com:443", + true, }, { "http_proxy", "http://host.with.alternate/direct", "ssl/host.with.alternate:443", + true, }, }; @@ -4090,15 +4127,22 @@ TEST_F(HttpNetworkTransactionTest, GroupNameForHTTPProxyConnections) { HttpNetworkSessionPeer peer(session); + HostPortPair proxy_host("http_proxy", 80); scoped_refptr<CaptureGroupNameHttpProxySocketPool> http_proxy_pool( new CaptureGroupNameHttpProxySocketPool(session.get())); - peer.SetSocketPoolForHTTPProxy( - HostPortPair("http_proxy", 80), http_proxy_pool); + peer.SetSocketPoolForHTTPProxy(proxy_host, http_proxy_pool); + scoped_refptr<CaptureGroupNameSSLSocketPool> ssl_conn_pool( + new CaptureGroupNameSSLSocketPool(session.get())); + peer.SetSocketPoolForSSLWithProxy(proxy_host, ssl_conn_pool); EXPECT_EQ(ERR_IO_PENDING, GroupNameTransactionHelper(tests[i].url, session)); - EXPECT_EQ(tests[i].expected_group_name, - http_proxy_pool->last_group_name_received()); + if (tests[i].ssl) + EXPECT_EQ(tests[i].expected_group_name, + ssl_conn_pool->last_group_name_received()); + else + EXPECT_EQ(tests[i].expected_group_name, + http_proxy_pool->last_group_name_received()); } HttpNetworkTransaction::SetUseAlternateProtocols(false); @@ -4110,11 +4154,13 @@ TEST_F(HttpNetworkTransactionTest, GroupNameForSOCKSConnections) { "socks4://socks_proxy:1080", "http://www.google.com/socks4_direct", "socks4/www.google.com:80", + false, }, { "socks5://socks_proxy:1080", "http://www.google.com/socks5_direct", "socks5/www.google.com:80", + false, }, // SSL Tests @@ -4122,17 +4168,20 @@ TEST_F(HttpNetworkTransactionTest, GroupNameForSOCKSConnections) { "socks4://socks_proxy:1080", "https://www.google.com/socks4_ssl", "socks4/ssl/www.google.com:443", + true, }, { "socks5://socks_proxy:1080", "https://www.google.com/socks5_ssl", "socks5/ssl/www.google.com:443", + true, }, { "socks4://socks_proxy:1080", "http://host.with.alternate/direct", "socks4/ssl/host.with.alternate:443", + true, }, }; @@ -4143,17 +4192,24 @@ TEST_F(HttpNetworkTransactionTest, GroupNameForSOCKSConnections) { SetupSessionForGroupNameTests(tests[i].proxy_server)); HttpNetworkSessionPeer peer(session); + HostPortPair proxy_host("socks_proxy", 1080); scoped_refptr<CaptureGroupNameSOCKSSocketPool> socks_conn_pool( new CaptureGroupNameSOCKSSocketPool(session.get())); - peer.SetSocketPoolForSOCKSProxy( - HostPortPair("socks_proxy", 1080), socks_conn_pool); + peer.SetSocketPoolForSOCKSProxy(proxy_host, socks_conn_pool); + scoped_refptr<CaptureGroupNameSSLSocketPool> ssl_conn_pool( + new CaptureGroupNameSSLSocketPool(session.get())); + peer.SetSocketPoolForSSLWithProxy(proxy_host, ssl_conn_pool); scoped_ptr<HttpTransaction> trans(new HttpNetworkTransaction(session)); EXPECT_EQ(ERR_IO_PENDING, GroupNameTransactionHelper(tests[i].url, session)); - EXPECT_EQ(tests[i].expected_group_name, - socks_conn_pool->last_group_name_received()); + if (tests[i].ssl) + EXPECT_EQ(tests[i].expected_group_name, + ssl_conn_pool->last_group_name_received()); + else + EXPECT_EQ(tests[i].expected_group_name, + socks_conn_pool->last_group_name_received()); } HttpNetworkTransaction::SetUseAlternateProtocols(false); @@ -5869,6 +5925,9 @@ TEST_F(HttpNetworkTransactionTest, GenerateAuthToken) { } } } + + // Flush the idle socket before the HttpNetworkTransaction goes out of scope. + session->FlushSocketPools(); } class TLSDecompressionFailureSocketDataProvider : public SocketDataProvider { @@ -5919,6 +5978,11 @@ TEST_F(HttpNetworkTransactionTest, RestartAfterTLSDecompressionFailure) { session_deps.socket_factory.AddSSLSocketDataProvider( &ssl_socket_data_provider2); + // Work around http://crbug.com/37454 + StaticSocketDataProvider bug37454_connection; + bug37454_connection.set_connect_data(MockConnect(true, ERR_UNEXPECTED)); + session_deps.socket_factory.AddSocketDataProvider(&bug37454_connection); + scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps)); scoped_ptr<HttpTransaction> trans(new HttpNetworkTransaction(session)); TestCompletionCallback callback; diff --git a/net/http/http_proxy_client_socket.cc b/net/http/http_proxy_client_socket.cc index df64589..ec4b753 100644 --- a/net/http/http_proxy_client_socket.cc +++ b/net/http/http_proxy_client_socket.cc @@ -159,12 +159,15 @@ void HttpProxyClientSocket::Disconnect() { } bool HttpProxyClientSocket::IsConnected() const { - return next_state_ == STATE_DONE && transport_->socket()->IsConnected(); + return transport_->socket()->IsConnected(); } bool HttpProxyClientSocket::IsConnectedAndIdle() const { - return next_state_ == STATE_DONE - && transport_->socket()->IsConnectedAndIdle(); + return transport_->socket()->IsConnectedAndIdle(); +} + +bool HttpProxyClientSocket::NeedsRestartWithAuth() const { + return next_state_ != STATE_DONE; } int HttpProxyClientSocket::Read(IOBuffer* buf, int buf_len, @@ -336,11 +339,8 @@ int HttpProxyClientSocket::DoReadHeaders() { } int HttpProxyClientSocket::DoReadHeadersComplete(int result) { - if (result < 0) { - if (result == ERR_CONNECTION_CLOSED) - result = ERR_TUNNEL_CONNECTION_FAILED; + if (result < 0) return result; - } // Require the "HTTP/1.x" status line for SSL CONNECT. if (response_.headers->GetParsedHttpVersion() < HttpVersion(1, 0)) diff --git a/net/http/http_proxy_client_socket.h b/net/http/http_proxy_client_socket.h index 692cc19e..61e8158 100644 --- a/net/http/http_proxy_client_socket.h +++ b/net/http/http_proxy_client_socket.h @@ -45,6 +45,10 @@ class HttpProxyClientSocket : public ClientSocket { // RestartWithAuth. int RestartWithAuth(CompletionCallback* callback); + // Indicates if RestartWithAuth needs to be called. i.e. if Connect + // returned PROXY_AUTH_REQUESTED. Only valid after Connect has been called. + bool NeedsRestartWithAuth() const; + const HttpResponseInfo* GetResponseInfo() const { return response_.headers ? &response_ : NULL; } diff --git a/net/http/http_proxy_client_socket_pool.h b/net/http/http_proxy_client_socket_pool.h index 7a40424..a23318f 100644 --- a/net/http/http_proxy_client_socket_pool.h +++ b/net/http/http_proxy_client_socket_pool.h @@ -13,6 +13,7 @@ #include "base/time.h" #include "net/base/host_port_pair.h" #include "net/base/host_resolver.h" +#include "net/http/http_auth.h" #include "net/proxy/proxy_server.h" #include "net/socket/client_socket_pool_base.h" #include "net/socket/client_socket_pool_histograms.h" @@ -37,7 +38,7 @@ class HttpProxySocketParams : public base::RefCounted<HttpProxySocketParams> { } const GURL& request_url() const { return request_url_; } const HostPortPair& endpoint() const { return endpoint_; } - const scoped_refptr<HttpAuthController>& auth_controller() const { + const scoped_refptr<HttpAuthController>& auth_controller() { return auth_controller_; } bool tunnel() const { return tunnel_; } @@ -51,6 +52,8 @@ class HttpProxySocketParams : public base::RefCounted<HttpProxySocketParams> { const HostPortPair endpoint_; const scoped_refptr<HttpAuthController> auth_controller_; const bool tunnel_; + + DISALLOW_COPY_AND_ASSIGN(HttpProxySocketParams); }; // HttpProxyConnectJob optionally establishes a tunnel through the proxy diff --git a/net/http/http_proxy_client_socket_pool_unittest.cc b/net/http/http_proxy_client_socket_pool_unittest.cc index 228b6a8..9400dd8 100644 --- a/net/http/http_proxy_client_socket_pool_unittest.cc +++ b/net/http/http_proxy_client_socket_pool_unittest.cc @@ -6,16 +6,9 @@ #include "base/callback.h" #include "base/compiler_specific.h" -#include "base/time.h" -#include "net/base/auth.h" -#include "net/base/mock_host_resolver.h" #include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" -#include "net/http/http_auth_controller.h" -#include "net/http/http_network_session.h" -#include "net/http/http_request_headers.h" -#include "net/http/http_response_headers.h" -#include "net/socket/client_socket_factory.h" +#include "net/http/http_proxy_client_socket.h" #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket_test_util.h" @@ -28,60 +21,6 @@ namespace { const int kMaxSockets = 32; const int kMaxSocketsPerGroup = 6; -struct MockHttpAuthControllerData { - MockHttpAuthControllerData(std::string header) : auth_header(header) {} - - std::string auth_header; -}; - -class MockHttpAuthController : public HttpAuthController { - public: - MockHttpAuthController() - : HttpAuthController(HttpAuth::AUTH_PROXY, GURL(), - scoped_refptr<HttpNetworkSession>(NULL)), - data_(NULL), - data_index_(0), - data_count_(0) { - } - - void SetMockAuthControllerData(struct MockHttpAuthControllerData* data, - size_t data_length) { - data_ = data; - data_count_ = data_length; - } - - // HttpAuthController methods. - virtual int MaybeGenerateAuthToken(const HttpRequestInfo* request, - CompletionCallback* callback, - const BoundNetLog& net_log) { - return OK; - } - virtual void AddAuthorizationHeader( - HttpRequestHeaders* authorization_headers) { - authorization_headers->AddHeadersFromString(CurrentData().auth_header); - } - virtual int HandleAuthChallenge(scoped_refptr<HttpResponseHeaders> headers, - bool do_not_send_server_auth, - bool establishing_tunnel, - const BoundNetLog& net_log) { - return OK; - } - virtual bool HaveAuthHandler() const { return HaveAuth(); } - virtual bool HaveAuth() const { - return CurrentData().auth_header.size() != 0; } - - private: - virtual ~MockHttpAuthController() {} - const struct MockHttpAuthControllerData& CurrentData() const { - DCHECK(data_index_ < data_count_); - return data_[data_index_]; - } - - MockHttpAuthControllerData* data_; - size_t data_index_; - size_t data_count_; -}; - class HttpProxyClientSocketPoolTest : public ClientSocketPoolTest { protected: HttpProxyClientSocketPoolTest() @@ -131,6 +70,9 @@ TEST_F(HttpProxyClientSocketPoolTest, NoTunnel) { EXPECT_EQ(OK, rv); EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); + HttpProxyClientSocket* tunnel_socket = + static_cast<HttpProxyClientSocket*>(handle.socket()); + EXPECT_FALSE(tunnel_socket->NeedsRestartWithAuth()); } TEST_F(HttpProxyClientSocketPoolTest, NeedAuth) { @@ -166,6 +108,9 @@ TEST_F(HttpProxyClientSocketPoolTest, NeedAuth) { EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, callback.WaitForResult()); EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); + HttpProxyClientSocket* tunnel_socket = + static_cast<HttpProxyClientSocket*>(handle.socket()); + EXPECT_TRUE(tunnel_socket->NeedsRestartWithAuth()); } TEST_F(HttpProxyClientSocketPoolTest, HaveAuth) { @@ -196,6 +141,9 @@ TEST_F(HttpProxyClientSocketPoolTest, HaveAuth) { EXPECT_EQ(OK, rv); EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); + HttpProxyClientSocket* tunnel_socket = + static_cast<HttpProxyClientSocket*>(handle.socket()); + EXPECT_FALSE(tunnel_socket->NeedsRestartWithAuth()); } TEST_F(HttpProxyClientSocketPoolTest, AsyncHaveAuth) { @@ -228,6 +176,9 @@ TEST_F(HttpProxyClientSocketPoolTest, AsyncHaveAuth) { EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); + HttpProxyClientSocket* tunnel_socket = + static_cast<HttpProxyClientSocket*>(handle.socket()); + EXPECT_FALSE(tunnel_socket->NeedsRestartWithAuth()); } TEST_F(HttpProxyClientSocketPoolTest, TCPError) { @@ -277,7 +228,7 @@ TEST_F(HttpProxyClientSocketPoolTest, TunnelUnexpectedClose) { EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); - EXPECT_EQ(ERR_TUNNEL_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_EQ(ERR_CONNECTION_CLOSED, callback.WaitForResult()); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); } diff --git a/net/net.gyp b/net/net.gyp index dd9ffea..8e859e9 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -473,6 +473,8 @@ 'socket/ssl_client_socket_nss.h', 'socket/ssl_client_socket_nss_factory.cc', 'socket/ssl_client_socket_nss_factory.h', + 'socket/ssl_client_socket_pool.cc', + 'socket/ssl_client_socket_pool.h', 'socket/ssl_client_socket_win.cc', 'socket/ssl_client_socket_win.h', 'socket/tcp_client_socket.h', @@ -753,6 +755,7 @@ 'socket/socks_client_socket_pool_unittest.cc', 'socket/socks_client_socket_unittest.cc', 'socket/ssl_client_socket_unittest.cc', + 'socket/ssl_client_socket_pool_unittest.cc', 'socket/tcp_client_socket_pool_unittest.cc', 'socket/tcp_client_socket_unittest.cc', 'socket/tcp_pinger_unittest.cc', diff --git a/net/socket/client_socket_factory.cc b/net/socket/client_socket_factory.cc index db819db..fbccfcb 100644 --- a/net/socket/client_socket_factory.cc +++ b/net/socket/client_socket_factory.cc @@ -6,6 +6,7 @@ #include "base/singleton.h" #include "build/build_config.h" +#include "net/socket/client_socket_handle.h" #if defined(OS_WIN) #include "net/socket/ssl_client_socket_win.h" #elif defined(USE_NSS) @@ -21,7 +22,7 @@ namespace net { namespace { SSLClientSocket* DefaultSSLClientSocketFactory( - ClientSocket* transport_socket, + ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config) { #if defined(OS_WIN) @@ -52,7 +53,7 @@ class DefaultClientSocketFactory : public ClientSocketFactory { } virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocket* transport_socket, + ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config) { return g_ssl_factory(transport_socket, hostname, ssl_config); @@ -72,4 +73,14 @@ void ClientSocketFactory::SetSSLClientSocketFactory( g_ssl_factory = factory; } +// Deprecated function (http://crbug.com/37810) that takes a ClientSocket. +SSLClientSocket* ClientSocketFactory::CreateSSLClientSocket( + ClientSocket* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config) { + ClientSocketHandle* socket_handle = new ClientSocketHandle(); + socket_handle->set_socket(transport_socket); + return CreateSSLClientSocket(socket_handle, hostname, ssl_config); +} + } // namespace net diff --git a/net/socket/client_socket_factory.h b/net/socket/client_socket_factory.h index b519b32..dddf1de 100644 --- a/net/socket/client_socket_factory.h +++ b/net/socket/client_socket_factory.h @@ -11,13 +11,14 @@ namespace net { class AddressList; class ClientSocket; +class ClientSocketHandle; class NetLog; class SSLClientSocket; struct SSLConfig; // Callback function to create new SSLClientSocket objects. typedef SSLClientSocket* (*SSLClientSocketFactory)( - ClientSocket* transport_socket, + ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config); @@ -31,10 +32,16 @@ class ClientSocketFactory { const AddressList& addresses, NetLog* net_log) = 0; virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocket* transport_socket, + ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config) = 0; + + // Deprecated function (http://crbug.com/37810) that takes a ClientSocket. + virtual SSLClientSocket* CreateSSLClientSocket(ClientSocket* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config); + // Returns the default ClientSocketFactory. static ClientSocketFactory* GetDefaultFactory(); diff --git a/net/socket/client_socket_handle.cc b/net/socket/client_socket_handle.cc index fab8d3e..73142bf 100644 --- a/net/socket/client_socket_handle.cc +++ b/net/socket/client_socket_handle.cc @@ -17,7 +17,8 @@ ClientSocketHandle::ClientSocketHandle() : socket_(NULL), is_reused_(false), ALLOW_THIS_IN_INITIALIZER_LIST( - callback_(this, &ClientSocketHandle::OnIOComplete)) {} + callback_(this, &ClientSocketHandle::OnIOComplete)), + is_ssl_error_(false) {} ClientSocketHandle::~ClientSocketHandle() { Reset(); @@ -25,6 +26,7 @@ ClientSocketHandle::~ClientSocketHandle() { void ClientSocketHandle::Reset() { ResetInternal(true); + ResetErrorState(); } void ClientSocketHandle::ResetInternal(bool cancel) { @@ -53,6 +55,11 @@ void ClientSocketHandle::ResetInternal(bool cancel) { pool_id_ = -1; } +void ClientSocketHandle::ResetErrorState() { + is_ssl_error_ = false; + tunnel_auth_response_info_ = HttpResponseInfo(); +} + LoadState ClientSocketHandle::GetLoadState() const { CHECK(!is_initialized()); CHECK(!group_name_.empty()); diff --git a/net/socket/client_socket_handle.h b/net/socket/client_socket_handle.h index 8ec4eb5..b0cb574 100644 --- a/net/socket/client_socket_handle.h +++ b/net/socket/client_socket_handle.h @@ -16,6 +16,7 @@ #include "net/base/net_errors.h" #include "net/base/net_log.h" #include "net/base/request_priority.h" +#include "net/http/http_response_info.h" #include "net/socket/client_socket.h" #include "net/socket/client_socket_pool.h" @@ -60,6 +61,9 @@ class ClientSocketHandle { // that the error is not recoverable, the Disconnect method should be used // on the socket, so that it does not get reused. // + // A non-recoverable error may set additional state in the ClientSocketHandle + // to allow the caller to determine what went wrong. + // // Init may be called multiple times. // // Profiling information for the request is saved to |net_log| if non-NULL. @@ -100,11 +104,26 @@ class ClientSocketHandle { void set_socket(ClientSocket* s) { socket_.reset(s); } void set_idle_time(base::TimeDelta idle_time) { idle_time_ = idle_time; } void set_pool_id(int id) { pool_id_ = id; } + void set_tunnel_auth_response_info( + const scoped_refptr<HttpResponseHeaders>& headers, + const scoped_refptr<AuthChallengeInfo>& auth_challenge) { + tunnel_auth_response_info_.headers = headers; + tunnel_auth_response_info_.auth_challenge = auth_challenge; + } + void set_is_ssl_error(bool is_ssl_error) { is_ssl_error_ = is_ssl_error; } // These may only be used if is_initialized() is true. const std::string& group_name() const { return group_name_; } ClientSocket* socket() { return socket_.get(); } ClientSocket* release_socket() { return socket_.release(); } + const HttpResponseInfo& tunnel_auth_response_info() const { + return tunnel_auth_response_info_; + } + // Only valid if there is no |socket_|. + bool is_ssl_error() const { + DCHECK(socket_.get() == NULL); + return is_ssl_error_; + } bool is_reused() const { return is_reused_; } base::TimeDelta idle_time() const { return idle_time_; } SocketReuseType reuse_type() const { @@ -139,9 +158,13 @@ class ClientSocketHandle { void HandleInitCompletion(int result); // Resets the state of the ClientSocketHandle. |cancel| indicates whether or - // not to try to cancel the request with the ClientSocketPool. + // not to try to cancel the request with the ClientSocketPool. Does not + // reset the supplemental error state. void ResetInternal(bool cancel); + // Resets the supplemental error state. + void ResetErrorState(); + scoped_refptr<ClientSocketPool> pool_; scoped_ptr<ClientSocket> socket_; std::string group_name_; @@ -150,6 +173,8 @@ class ClientSocketHandle { CompletionCallback* user_callback_; base::TimeDelta idle_time_; int pool_id_; // See ClientSocketPool::ReleaseSocket() for an explanation. + bool is_ssl_error_; + HttpResponseInfo tunnel_auth_response_info_; base::TimeTicks init_time_; base::TimeDelta setup_time_; @@ -174,6 +199,7 @@ int ClientSocketHandle::Init(const std::string& group_name, // (defined in client_socket_pool.h). CheckIsValidSocketParamsForPool<PoolType, SocketParams>(); ResetInternal(true); + ResetErrorState(); pool_ = pool; group_name_ = group_name; init_time_ = base::TimeTicks::Now(); diff --git a/net/socket/client_socket_pool.h b/net/socket/client_socket_pool.h index 493fff1..b22da31 100644 --- a/net/socket/client_socket_pool.h +++ b/net/socket/client_socket_pool.h @@ -42,8 +42,9 @@ class ClientSocketPool : public base::RefCounted<ClientSocketPool> { // code is returned, but the |handle| is initialized with the new socket. // The caller must recover from the error before using the connection, or // Disconnect the socket before releasing or resetting the |handle|. - // The current recoverable errors are: PROXY_AUTH_REQUESTED and the errors - // accepted by IsCertificateError(err). + // The current recoverable errors are: the errors accepted by + // IsCertificateError(err) and PROXY_AUTH_REQUESTED when reported by + // HttpProxyClientSocketPool. // // If this function returns OK, then |handle| is initialized upon return. // The |handle|'s is_initialized method will return true in this case. If a diff --git a/net/socket/client_socket_pool_base.cc b/net/socket/client_socket_pool_base.cc index 7e5c24fc..5b116e6 100644 --- a/net/socket/client_socket_pool_base.cc +++ b/net/socket/client_socket_pool_base.cc @@ -257,6 +257,7 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( group.jobs.insert(job); } else { LogBoundConnectJobToRequest(connect_job->net_log().source(), request); + connect_job->GetAdditionalErrorState(handle); ClientSocket* error_socket = connect_job->ReleaseSocket(); if (error_socket) { HandOutSocket(error_socket, false /* not reused */, handle, @@ -592,10 +593,10 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete( scoped_ptr<ClientSocket> socket(job->ReleaseSocket()); BoundNetLog job_log = job->net_log(); - RemoveConnectJob(job, &group); if (result == OK) { DCHECK(socket.get()); + RemoveConnectJob(job, &group); if (!group.pending_requests.empty()) { scoped_ptr<const Request> r(RemoveRequestFromQueue( group.pending_requests.begin(), &group.pending_requests)); @@ -617,6 +618,8 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete( scoped_ptr<const Request> r(RemoveRequestFromQueue( group.pending_requests.begin(), &group.pending_requests)); LogBoundConnectJobToRequest(job_log.source(), r.get()); + job->GetAdditionalErrorState(r->handle()); + RemoveConnectJob(job, &group); if (socket.get()) { handed_out_socket = true; HandOutSocket(socket.release(), false /* unused socket */, r->handle(), @@ -624,7 +627,15 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete( } r->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL, new NetLogIntegerParameter("net_error", result)); + if (socket.get()) { + handed_out_socket = true; + HandOutSocket( + socket.release(), false /* unused socket */, r->handle(), + base::TimeDelta(), &group, r->net_log()); + } r->callback()->Run(result); + } else { + RemoveConnectJob(job, &group); } if (!handed_out_socket) OnAvailableSocketSlot(group_name, MayHaveStalledGroups()); diff --git a/net/socket/client_socket_pool_base.h b/net/socket/client_socket_pool_base.h index d515ca5..77698ce 100644 --- a/net/socket/client_socket_pool_base.h +++ b/net/socket/client_socket_pool_base.h @@ -90,6 +90,11 @@ class ConnectJob { virtual LoadState GetLoadState() const = 0; + // If Connect returns an error (or OnConnectJobComplete reports an error + // result) this method will be called, allowing the pool to add + // additional error state to the ClientSocketHandle (post late-binding). + virtual void GetAdditionalErrorState(ClientSocketHandle* handle) {} + protected: void set_socket(ClientSocket* socket); ClientSocket* socket() { return socket_.get(); } diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index 70cc604..53063ba 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -94,7 +94,7 @@ class MockClientSocketFactory : public ClientSocketFactory { } virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocket* transport_socket, + ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config) { NOTIMPLEMENTED(); @@ -122,6 +122,8 @@ class TestConnectJob : public ConnectJob { kMockAdvancingLoadStateJob, kMockRecoverableJob, kMockPendingRecoverableJob, + kMockAdditionalErrorStateJob, + kMockPendingAdditionalErrorStateJob, }; // The kMockPendingJob uses a slight delay before allowing the connect @@ -140,7 +142,8 @@ class TestConnectJob : public ConnectJob { job_type_(job_type), client_socket_factory_(client_socket_factory), method_factory_(ALLOW_THIS_IN_INITIALIZER_LIST(this)), - load_state_(LOAD_STATE_IDLE) {} + load_state_(LOAD_STATE_IDLE), + store_additional_error_state_(false) {} void Signal() { DoConnect(waiting_success_, true /* async */, false /* recoverable */); @@ -148,6 +151,15 @@ class TestConnectJob : public ConnectJob { virtual LoadState GetLoadState() const { return load_state_; } + virtual void GetAdditionalErrorState(ClientSocketHandle* handle) { + if (store_additional_error_state_) { + // Set all of the additional error state fields in some way. + handle->set_is_ssl_error(true); + scoped_refptr<HttpResponseHeaders> headers(new HttpResponseHeaders("")); + handle->set_tunnel_auth_response_info(headers, NULL); + } + } + private: // ConnectJob methods: @@ -220,6 +232,22 @@ class TestConnectJob : public ConnectJob { true /* recoverable */), 2); return ERR_IO_PENDING; + case kMockAdditionalErrorStateJob: + store_additional_error_state_ = true; + return DoConnect(false /* error */, false /* sync */, + false /* recoverable */); + case kMockPendingAdditionalErrorStateJob: + set_load_state(LOAD_STATE_CONNECTING); + store_additional_error_state_ = true; + MessageLoop::current()->PostDelayedTask( + FROM_HERE, + method_factory_.NewRunnableMethod( + &TestConnectJob::DoConnect, + false /* error */, + true /* async */, + false /* recoverable */), + 2); + return ERR_IO_PENDING; default: NOTREACHED(); set_socket(NULL); @@ -267,6 +295,7 @@ class TestConnectJob : public ConnectJob { MockClientSocketFactory* const client_socket_factory_; ScopedRunnableMethodFactory<TestConnectJob> method_factory_; LoadState load_state_; + bool store_additional_error_state_; DISALLOW_COPY_AND_ASSIGN(TestConnectJob); }; @@ -593,10 +622,16 @@ TEST_F(ClientSocketPoolBaseTest, InitConnectionFailure) { CapturingBoundNetLog log(CapturingNetLog::kUnbounded); TestSocketRequest req(&request_order_, &completion_count_); + // Set the additional error state members to ensure that they get cleared. + req.handle()->set_is_ssl_error(true); + scoped_refptr<HttpResponseHeaders> headers(new HttpResponseHeaders("")); + req.handle()->set_tunnel_auth_response_info(headers, NULL); EXPECT_EQ(ERR_CONNECTION_FAILED, req.handle()->Init("a", params_, kDefaultPriority, &req, pool_, log.bound())); EXPECT_FALSE(req.handle()->socket()); + EXPECT_FALSE(req.handle()->is_ssl_error()); + EXPECT_TRUE(req.handle()->tunnel_auth_response_info().headers.get() == NULL); EXPECT_EQ(3u, log.entries().size()); EXPECT_TRUE(LogContainsBeginEvent( @@ -1352,10 +1387,16 @@ TEST_F(ClientSocketPoolBaseTest, connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob); TestSocketRequest req(&request_order_, &completion_count_); CapturingBoundNetLog log(CapturingNetLog::kUnbounded); + // Set the additional error state members to ensure that they get cleared. + req.handle()->set_is_ssl_error(true); + scoped_refptr<HttpResponseHeaders> headers(new HttpResponseHeaders("")); + req.handle()->set_tunnel_auth_response_info(headers, NULL); EXPECT_EQ(ERR_IO_PENDING, req.handle()->Init("a", params_, kDefaultPriority, &req, pool_, log.bound())); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", req.handle())); EXPECT_EQ(ERR_CONNECTION_FAILED, req.WaitForResult()); + EXPECT_FALSE(req.handle()->is_ssl_error()); + EXPECT_TRUE(req.handle()->tunnel_auth_response_info().headers.get() == NULL); EXPECT_EQ(3u, log.entries().size()); EXPECT_TRUE(LogContainsBeginEvent( @@ -1548,6 +1589,39 @@ TEST_F(ClientSocketPoolBaseTest, AsyncRecoverable) { req.handle()->Reset(); } +TEST_F(ClientSocketPoolBaseTest, AdditionalErrorStateSynchronous) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type( + TestConnectJob::kMockAdditionalErrorStateJob); + + TestSocketRequest req(&request_order_, &completion_count_); + EXPECT_EQ(ERR_CONNECTION_FAILED, req.handle()->Init("a", params_, + kDefaultPriority, &req, + pool_, BoundNetLog())); + EXPECT_FALSE(req.handle()->is_initialized()); + EXPECT_FALSE(req.handle()->socket()); + EXPECT_TRUE(req.handle()->is_ssl_error()); + EXPECT_FALSE(req.handle()->tunnel_auth_response_info().headers.get() == NULL); + req.handle()->Reset(); +} + +TEST_F(ClientSocketPoolBaseTest, AdditionalErrorStateAsynchronous) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type( + TestConnectJob::kMockPendingAdditionalErrorStateJob); + TestSocketRequest req(&request_order_, &completion_count_); + EXPECT_EQ(ERR_IO_PENDING, req.handle()->Init("a", params_, kDefaultPriority, + &req, pool_, BoundNetLog())); + EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", req.handle())); + EXPECT_EQ(ERR_CONNECTION_FAILED, req.WaitForResult()); + EXPECT_FALSE(req.handle()->is_initialized()); + EXPECT_FALSE(req.handle()->socket()); + EXPECT_TRUE(req.handle()->is_ssl_error()); + EXPECT_FALSE(req.handle()->tunnel_auth_response_info().headers.get() == NULL); + req.handle()->Reset(); +} + TEST_F(ClientSocketPoolBaseTest, CleanupTimedOutIdleSockets) { CreatePoolWithIdleTimeouts( kDefaultMaxSockets, kDefaultMaxSocketsPerGroup, @@ -1735,8 +1809,11 @@ TEST_F(ClientSocketPoolBaseTest, class TestReleasingSocketRequest : public CallbackRunner< Tuple1<int> > { public: - explicit TestReleasingSocketRequest(TestClientSocketPool* pool) - : pool_(pool) {} + TestReleasingSocketRequest(TestClientSocketPool* pool, int expected_result, + bool reset_releasing_handle) + : pool_(pool), + expected_result_(expected_result), + reset_releasing_handle_(reset_releasing_handle) {} ClientSocketHandle* handle() { return &handle_; } @@ -1746,20 +1823,50 @@ class TestReleasingSocketRequest : public CallbackRunner< Tuple1<int> > { virtual void RunWithParams(const Tuple1<int>& params) { callback_.RunWithParams(params); - handle_.Reset(); + if (reset_releasing_handle_) + handle_.Reset(); scoped_refptr<TestSocketParams> con_params = new TestSocketParams(); - EXPECT_EQ(ERR_IO_PENDING, handle2_.Init("a", con_params, kDefaultPriority, - &callback2_, pool_, BoundNetLog())); + EXPECT_EQ(expected_result_, handle2_.Init("a", con_params, kDefaultPriority, + &callback2_, pool_, + BoundNetLog())); } private: scoped_refptr<TestClientSocketPool> pool_; + int expected_result_; + bool reset_releasing_handle_; ClientSocketHandle handle_; ClientSocketHandle handle2_; TestCompletionCallback callback_; TestCompletionCallback callback2_; }; + +TEST_F(ClientSocketPoolBaseTest, AdditionalErrorSocketsDontUseSlot) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); + + EXPECT_EQ(static_cast<int>(requests_.size()), + client_socket_factory_.allocation_count()); + + connect_job_factory_->set_job_type( + TestConnectJob::kMockPendingAdditionalErrorStateJob); + TestReleasingSocketRequest req(pool_.get(), OK, false); + EXPECT_EQ(ERR_IO_PENDING, req.handle()->Init("a", params_, kDefaultPriority, + &req, pool_, BoundNetLog())); + // The next job should complete synchronously + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + EXPECT_EQ(ERR_CONNECTION_FAILED, req.WaitForResult()); + EXPECT_FALSE(req.handle()->is_initialized()); + EXPECT_FALSE(req.handle()->socket()); + EXPECT_TRUE(req.handle()->is_ssl_error()); + EXPECT_FALSE(req.handle()->tunnel_auth_response_info().headers.get() == NULL); +} + // http://crbug.com/44724 regression test. // We start releasing the pool when we flush on network change. When that // happens, the only active references are in the ClientSocketHandles. When a diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index c708657..bb40354 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -7,13 +7,18 @@ #include <algorithm> #include <vector> + #include "base/basictypes.h" #include "base/compiler_specific.h" #include "base/message_loop.h" #include "base/time.h" #include "net/base/address_family.h" +#include "net/base/auth.h" #include "net/base/host_resolver_proc.h" #include "net/base/ssl_info.h" +#include "net/http/http_network_session.h" +#include "net/http/http_request_headers.h" +#include "net/http/http_response_headers.h" #include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket.h" #include "testing/gtest/include/gtest/gtest.h" @@ -334,13 +339,14 @@ class MockSSLClientSocket::ConnectCallback : }; MockSSLClientSocket::MockSSLClientSocket( - net::ClientSocket* transport_socket, + net::ClientSocketHandle* transport_socket, const std::string& hostname, const net::SSLConfig& ssl_config, net::SSLSocketDataProvider* data) - : MockClientSocket(transport_socket->NetLog().net_log()), + : MockClientSocket(transport_socket->socket()->NetLog().net_log()), transport_(transport_socket), - data_(data) { + data_(data), + is_npn_state_set_(false) { DCHECK(data_); } @@ -351,7 +357,7 @@ MockSSLClientSocket::~MockSSLClientSocket() { int MockSSLClientSocket::Connect(net::CompletionCallback* callback) { ConnectCallback* connect_callback = new ConnectCallback( this, callback, data_->connect.result); - int rv = transport_->Connect(connect_callback); + int rv = transport_->socket()->Connect(connect_callback); if (rv == net::OK) { delete connect_callback; if (data_->connect.async) { @@ -367,18 +373,18 @@ int MockSSLClientSocket::Connect(net::CompletionCallback* callback) { void MockSSLClientSocket::Disconnect() { MockClientSocket::Disconnect(); - if (transport_ != NULL) - transport_->Disconnect(); + if (transport_->socket() != NULL) + transport_->socket()->Disconnect(); } int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) { - return transport_->Read(buf, buf_len, callback); + return transport_->socket()->Read(buf, buf_len, callback); } int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) { - return transport_->Write(buf, buf_len, callback); + return transport_->socket()->Write(buf, buf_len, callback); } void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { @@ -392,9 +398,16 @@ SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto( } bool MockSSLClientSocket::wasNpnNegotiated() const { + if (is_npn_state_set_) + return new_npn_value_; return data_->was_npn_negotiated; } +bool MockSSLClientSocket::setWasNpnNegotiated(bool negotiated) { + is_npn_state_set_ = true; + return new_npn_value_ = negotiated; +} + MockRead StaticSocketDataProvider::GetNextRead() { DCHECK(!at_read_eof()); reads_[read_index_].time_stamp = base::Time::Now(); @@ -668,7 +681,7 @@ ClientSocket* MockClientSocketFactory::CreateTCPClientSocket( } SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( - ClientSocket* transport_socket, + ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config) { MockSSLClientSocket* socket = @@ -839,6 +852,86 @@ void MockTCPClientSocketPool::ReleaseSocket(const std::string& group_name, MockTCPClientSocketPool::~MockTCPClientSocketPool() {} +MockSOCKSClientSocketPool::MockSOCKSClientSocketPool( + int max_sockets, + int max_sockets_per_group, + const scoped_refptr<ClientSocketPoolHistograms>& histograms, + const scoped_refptr<TCPClientSocketPool>& tcp_pool) + : SOCKSClientSocketPool(max_sockets, max_sockets_per_group, histograms, + NULL, tcp_pool, NULL), + tcp_pool_(tcp_pool) { +} + +int MockSOCKSClientSocketPool::RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + CompletionCallback* callback, + const BoundNetLog& net_log) { + return tcp_pool_->RequestSocket(group_name, socket_params, priority, handle, + callback, net_log); +} + +void MockSOCKSClientSocketPool::CancelRequest( + const std::string& group_name, + const ClientSocketHandle* handle) { + return tcp_pool_->CancelRequest(group_name, handle); +} + +void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, + ClientSocket* socket, int id) { + return tcp_pool_->ReleaseSocket(group_name, socket, id); +} + +MockSOCKSClientSocketPool::~MockSOCKSClientSocketPool() {} + +MockHttpAuthController::MockHttpAuthController() + : HttpAuthController(HttpAuth::AUTH_PROXY, GURL(), + scoped_refptr<HttpNetworkSession>(NULL)), + data_(NULL), + data_index_(0), + data_count_(0) { +} + +void MockHttpAuthController::SetMockAuthControllerData( + struct MockHttpAuthControllerData* data, size_t count) { + data_ = data; + data_count_ = count; +} + +int MockHttpAuthController::MaybeGenerateAuthToken( + const HttpRequestInfo* request, + CompletionCallback* callback, + const BoundNetLog& net_log) { + return OK; +} + +void MockHttpAuthController::AddAuthorizationHeader( + HttpRequestHeaders* authorization_headers) { + authorization_headers->AddHeadersFromString(CurrentData().auth_header); +} + +int MockHttpAuthController::HandleAuthChallenge( + scoped_refptr<HttpResponseHeaders> headers, + bool do_not_send_server_auth, + bool establishing_tunnel, + const BoundNetLog& net_log) { + return OK; +} + +void MockHttpAuthController::ResetAuth(const std::wstring& username, + const std::wstring& password) { + data_index_++; +} + +bool MockHttpAuthController::HaveAuth() const { + return CurrentData().auth_header.size() != 0; +} + +bool MockHttpAuthController::HaveAuthHandler() const { + return HaveAuth(); +} + const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 }; const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest); diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 1087dbd..b9c0f0c 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -21,8 +21,11 @@ #include "net/base/net_log.h" #include "net/base/ssl_config_service.h" #include "net/base/test_completion_callback.h" +#include "net/http/http_auth_controller.h" +#include "net/http/http_proxy_client_socket_pool.h" #include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" +#include "net/socket/socks_client_socket_pool.h" #include "net/socket/ssl_client_socket.h" #include "net/socket/tcp_client_socket_pool.h" #include "testing/gtest/include/gtest/gtest.h" @@ -39,6 +42,8 @@ enum { }; class ClientSocket; +class HttpRequestHeaders; +class HttpResponseHeaders; class MockClientSocket; class SSLClientSocket; @@ -393,7 +398,7 @@ class MockClientSocketFactory : public ClientSocketFactory { virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses, NetLog* net_log); virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocket* transport_socket, + ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config); @@ -496,7 +501,7 @@ class MockTCPClientSocket : public MockClientSocket { class MockSSLClientSocket : public MockClientSocket { public: MockSSLClientSocket( - net::ClientSocket* transport_socket, + net::ClientSocketHandle* transport_socket, const std::string& hostname, const net::SSLConfig& ssl_config, net::SSLSocketDataProvider* socket); @@ -516,6 +521,7 @@ class MockSSLClientSocket : public MockClientSocket { virtual void GetSSLInfo(net::SSLInfo* ssl_info); virtual NextProtoStatus GetNextProto(std::string* proto); virtual bool wasNpnNegotiated() const; + virtual bool setWasNpnNegotiated(bool negotiated); // This MockSocket does not implement the manual async IO feature. virtual void OnReadComplete(const MockRead& data) { NOTIMPLEMENTED(); } @@ -523,8 +529,10 @@ class MockSSLClientSocket : public MockClientSocket { private: class ConnectCallback; - scoped_ptr<ClientSocket> transport_; + scoped_ptr<ClientSocketHandle> transport_; net::SSLSocketDataProvider* data_; + bool is_npn_state_set_; + bool new_npn_value_; }; class TestSocketRequest : public CallbackRunner< Tuple1<int> > { @@ -655,6 +663,75 @@ class MockTCPClientSocketPool : public TCPClientSocketPool { DISALLOW_COPY_AND_ASSIGN(MockTCPClientSocketPool); }; +class MockSOCKSClientSocketPool : public SOCKSClientSocketPool { + public: + MockSOCKSClientSocketPool( + int max_sockets, + int max_sockets_per_group, + const scoped_refptr<ClientSocketPoolHistograms>& histograms, + const scoped_refptr<TCPClientSocketPool>& tcp_pool); + + // SOCKSClientSocketPool methods. + virtual int RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + CompletionCallback* callback, + const BoundNetLog& net_log); + + virtual void CancelRequest(const std::string& group_name, + const ClientSocketHandle* handle); + virtual void ReleaseSocket(const std::string& group_name, + ClientSocket* socket, int id); + + protected: + virtual ~MockSOCKSClientSocketPool(); + + private: + const scoped_refptr<TCPClientSocketPool> tcp_pool_; + + DISALLOW_COPY_AND_ASSIGN(MockSOCKSClientSocketPool); +}; + +struct MockHttpAuthControllerData { + MockHttpAuthControllerData(std::string header) : auth_header(header) {} + + std::string auth_header; +}; + +class MockHttpAuthController : public HttpAuthController { + public: + MockHttpAuthController(); + void SetMockAuthControllerData(struct MockHttpAuthControllerData* data, + size_t data_length); + + // HttpAuthController methods. + virtual int MaybeGenerateAuthToken(const HttpRequestInfo* request, + CompletionCallback* callback, + const BoundNetLog& net_log); + virtual void AddAuthorizationHeader( + HttpRequestHeaders* authorization_headers); + virtual int HandleAuthChallenge(scoped_refptr<HttpResponseHeaders> headers, + bool do_not_send_server_auth, + bool establishing_tunnel, + const BoundNetLog& net_log); + virtual void ResetAuth(const std::wstring& username, + const std::wstring& password); + virtual bool HaveAuthHandler() const; + virtual bool HaveAuth() const; + + private: + virtual ~MockHttpAuthController() {} + const struct MockHttpAuthControllerData& CurrentData() const { + DCHECK(data_index_ < data_count_); + return data_[data_index_]; + } + + MockHttpAuthControllerData* data_; + size_t data_index_; + size_t data_count_; +}; + // Constants for a successful SOCKS v5 handshake. extern const char kSOCKS5GreetRequest[]; extern const int kSOCKS5GreetRequestLength; diff --git a/net/socket/socks_client_socket_pool.h b/net/socket/socks_client_socket_pool.h index 6a2ed80..795c13c 100644 --- a/net/socket/socks_client_socket_pool.h +++ b/net/socket/socks_client_socket_pool.h @@ -46,6 +46,8 @@ class SOCKSSocketParams : public base::RefCounted<SOCKSSocketParams> { // This is the HTTP destination. HostResolver::RequestInfo destination_; const bool socks_v5_; + + DISALLOW_COPY_AND_ASSIGN(SOCKSSocketParams); }; // SOCKSConnectJob handles the handshake to a socks server after setting up diff --git a/net/socket/ssl_client_socket.h b/net/socket/ssl_client_socket.h index 961447b..9c34282 100644 --- a/net/socket/ssl_client_socket.h +++ b/net/socket/ssl_client_socket.h @@ -7,6 +7,8 @@ #include <string> +#include "net/base/load_flags.h" +#include "net/base/net_errors.h" #include "net/socket/client_socket.h" namespace net { @@ -71,6 +73,23 @@ class SSLClientSocket : public ClientSocket { } } + static bool IgnoreCertError(int error, int load_flags) { + if (error == OK || load_flags & LOAD_IGNORE_ALL_CERT_ERRORS) + return true; + + if (error == ERR_CERT_COMMON_NAME_INVALID && + (load_flags & LOAD_IGNORE_CERT_COMMON_NAME_INVALID)) + return true; + if(error == ERR_CERT_DATE_INVALID && + (load_flags & LOAD_IGNORE_CERT_DATE_INVALID)) + return true; + if(error == ERR_CERT_AUTHORITY_INVALID && + (load_flags & LOAD_IGNORE_CERT_AUTHORITY_INVALID)) + return true; + + return false; + } + virtual bool wasNpnNegotiated() const { return was_npn_negotiated_; } diff --git a/net/socket/ssl_client_socket_mac.cc b/net/socket/ssl_client_socket_mac.cc index 325df61..c3c7d7a 100644 --- a/net/socket/ssl_client_socket_mac.cc +++ b/net/socket/ssl_client_socket_mac.cc @@ -20,6 +20,7 @@ #include "net/base/ssl_cert_request_info.h" #include "net/base/ssl_connection_status_flags.h" #include "net/base/ssl_info.h" +#include "net/socket/client_socket_handle.h" // Welcome to Mac SSL. We've been waiting for you. // @@ -497,7 +498,7 @@ EnabledCipherSuites::EnabledCipherSuites() { //----------------------------------------------------------------------------- -SSLClientSocketMac::SSLClientSocketMac(ClientSocket* transport_socket, +SSLClientSocketMac::SSLClientSocketMac(ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config) : handshake_io_callback_(this, &SSLClientSocketMac::OnHandshakeIOComplete), @@ -519,7 +520,7 @@ SSLClientSocketMac::SSLClientSocketMac(ClientSocket* transport_socket, client_cert_requested_(false), ssl_context_(NULL), pending_send_error_(OK), - net_log_(transport_socket->NetLog()) { + net_log_(transport_socket->socket()->NetLog()) { } SSLClientSocketMac::~SSLClientSocketMac() { @@ -561,7 +562,7 @@ void SSLClientSocketMac::Disconnect() { // Shut down anything that may call us back. verifier_.reset(); - transport_->Disconnect(); + transport_->socket()->Disconnect(); } bool SSLClientSocketMac::IsConnected() const { @@ -571,7 +572,7 @@ bool SSLClientSocketMac::IsConnected() const { // layer (HttpNetworkTransaction) needs to handle a persistent connection // closed by the server when we send a request anyway, a false positive in // exchange for simpler code is a good trade-off. - return completed_handshake_ && transport_->IsConnected(); + return completed_handshake_ && transport_->socket()->IsConnected(); } bool SSLClientSocketMac::IsConnectedAndIdle() const { @@ -580,13 +581,14 @@ bool SSLClientSocketMac::IsConnectedAndIdle() const { // Strictly speaking, we should check if we have received the close_notify // alert message from the server, and return false in that case. Although // the close_notify alert message means EOF in the SSL layer, it is just - // bytes to the transport layer below, so transport_->IsConnectedAndIdle() - // returns the desired false when we receive close_notify. - return completed_handshake_ && transport_->IsConnectedAndIdle(); + // bytes to the transport layer below, so + // transport_->socket()->IsConnectedAndIdle() returns the desired false + // when we receive close_notify. + return completed_handshake_ && transport_->socket()->IsConnectedAndIdle(); } int SSLClientSocketMac::GetPeerAddress(AddressList* address) const { - return transport_->GetPeerAddress(address); + return transport_->socket()->GetPeerAddress(address); } int SSLClientSocketMac::Read(IOBuffer* buf, int buf_len, @@ -628,11 +630,11 @@ int SSLClientSocketMac::Write(IOBuffer* buf, int buf_len, } bool SSLClientSocketMac::SetReceiveBufferSize(int32 size) { - return transport_->SetReceiveBufferSize(size); + return transport_->socket()->SetReceiveBufferSize(size); } bool SSLClientSocketMac::SetSendBufferSize(int32 size) { - return transport_->SetSendBufferSize(size); + return transport_->socket()->SetSendBufferSize(size); } void SSLClientSocketMac::GetSSLInfo(SSLInfo* ssl_info) { @@ -809,7 +811,7 @@ int SSLClientSocketMac::InitializeSSLContext() { // different peers, which puts us through certificate validation again // and catches hostname/certificate name mismatches. AddressList address; - int rv = transport_->GetPeerAddress(&address); + int rv = transport_->socket()->GetPeerAddress(&address); if (rv != OK) return rv; const struct addrinfo* ai = address.head(); @@ -1221,9 +1223,9 @@ OSStatus SSLClientSocketMac::SSLReadCallback(SSLConnectionRef connection, int rv = 1; // any old value to spin the loop below while (rv > 0 && total_read < *data_length) { us->read_io_buf_ = new IOBuffer(*data_length - total_read); - rv = us->transport_->Read(us->read_io_buf_, - *data_length - total_read, - &us->transport_read_callback_); + rv = us->transport_->socket()->Read(us->read_io_buf_, + *data_length - total_read, + &us->transport_read_callback_); if (rv >= 0) { us->recv_buffer_.insert(us->recv_buffer_.end(), @@ -1283,9 +1285,9 @@ OSStatus SSLClientSocketMac::SSLWriteCallback(SSLConnectionRef connection, us->write_io_buf_ = new IOBuffer(us->send_buffer_.size()); memcpy(us->write_io_buf_->data(), &us->send_buffer_[0], us->send_buffer_.size()); - rv = us->transport_->Write(us->write_io_buf_, - us->send_buffer_.size(), - &us->transport_write_callback_); + rv = us->transport_->socket()->Write(us->write_io_buf_, + us->send_buffer_.size(), + &us->transport_write_callback_); if (rv > 0) { us->send_buffer_.erase(us->send_buffer_.begin(), us->send_buffer_.begin() + rv); diff --git a/net/socket/ssl_client_socket_mac.h b/net/socket/ssl_client_socket_mac.h index bb25bda..dc2ed65 100644 --- a/net/socket/ssl_client_socket_mac.h +++ b/net/socket/ssl_client_socket_mac.h @@ -20,6 +20,7 @@ namespace net { class CertVerifier; +class ClientSocketHandle; // An SSL client socket implemented with Secure Transport. class SSLClientSocketMac : public SSLClientSocket { @@ -28,7 +29,7 @@ class SSLClientSocketMac : public SSLClientSocket { // The given hostname will be compared with the name(s) in the server's // certificate during the SSL handshake. ssl_config specifies the SSL // settings. - SSLClientSocketMac(ClientSocket* transport_socket, + SSLClientSocketMac(ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config); ~SSLClientSocketMac(); @@ -88,7 +89,7 @@ class SSLClientSocketMac : public SSLClientSocket { CompletionCallbackImpl<SSLClientSocketMac> transport_read_callback_; CompletionCallbackImpl<SSLClientSocketMac> transport_write_callback_; - scoped_ptr<ClientSocket> transport_; + scoped_ptr<ClientSocketHandle> transport_; std::string hostname_; SSLConfig ssl_config_; diff --git a/net/socket/ssl_client_socket_mac_factory.cc b/net/socket/ssl_client_socket_mac_factory.cc index f2884e9..ec41345 100644 --- a/net/socket/ssl_client_socket_mac_factory.cc +++ b/net/socket/ssl_client_socket_mac_factory.cc @@ -9,7 +9,7 @@ namespace net { SSLClientSocket* SSLClientSocketMacFactory( - ClientSocket* transport_socket, + ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config) { return new SSLClientSocketMac(transport_socket, hostname, ssl_config); diff --git a/net/socket/ssl_client_socket_mac_factory.h b/net/socket/ssl_client_socket_mac_factory.h index 8a0fe0c..dafc40f 100644 --- a/net/socket/ssl_client_socket_mac_factory.h +++ b/net/socket/ssl_client_socket_mac_factory.h @@ -11,7 +11,7 @@ namespace net { // Creates SSLClientSocketMac objects. SSLClientSocket* SSLClientSocketMacFactory( - ClientSocket* transport_socket, + ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config); diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index 5226c56..44c731d 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -75,6 +75,7 @@ #include "net/base/ssl_info.h" #include "net/base/sys_addrinfo.h" #include "net/ocsp/nss_ocsp.h" +#include "net/socket/client_socket_handle.h" static const int kRecvBufferSize = 4096; @@ -279,7 +280,7 @@ bool IsProblematicComodoEVCACert(const CERTCertificate& cert) { HCERTSTORE SSLClientSocketNSS::cert_store_ = NULL; #endif -SSLClientSocketNSS::SSLClientSocketNSS(ClientSocket* transport_socket, +SSLClientSocketNSS::SSLClientSocketNSS(ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config) : ALLOW_THIS_IN_INITIALIZER_LIST(buffer_send_callback_( @@ -305,7 +306,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocket* transport_socket, next_handshake_state_(STATE_NONE), nss_fd_(NULL), nss_bufs_(NULL), - net_log_(transport_socket->NetLog()) { + net_log_(transport_socket->socket()->NetLog()) { EnterFunction(""); } @@ -379,7 +380,7 @@ int SSLClientSocketNSS::InitializeSSLOptions() { // Tell NSS who we're connected to AddressList peer_address; - int err = transport_->GetPeerAddress(&peer_address); + int err = transport_->socket()->GetPeerAddress(&peer_address); if (err != OK) return err; @@ -550,7 +551,7 @@ void SSLClientSocketNSS::Disconnect() { // Shut down anything that may call us back (through buffer_send_callback_, // buffer_recv_callback, or handshake_io_callback_). verifier_.reset(); - transport_->Disconnect(); + transport_->socket()->Disconnect(); // Reset object state transport_send_busy_ = false; @@ -584,7 +585,7 @@ bool SSLClientSocketNSS::IsConnected() const { // closed by the server when we send a request anyway, a false positive in // exchange for simpler code is a good trade-off. EnterFunction(""); - bool ret = completed_handshake_ && transport_->IsConnected(); + bool ret = completed_handshake_ && transport_->socket()->IsConnected(); LeaveFunction(""); return ret; } @@ -595,16 +596,17 @@ bool SSLClientSocketNSS::IsConnectedAndIdle() const { // Strictly speaking, we should check if we have received the close_notify // alert message from the server, and return false in that case. Although // the close_notify alert message means EOF in the SSL layer, it is just - // bytes to the transport layer below, so transport_->IsConnectedAndIdle() - // returns the desired false when we receive close_notify. + // bytes to the transport layer below, so + // transport_->socket()->IsConnectedAndIdle() returns the desired false + // when we receive close_notify. EnterFunction(""); - bool ret = completed_handshake_ && transport_->IsConnectedAndIdle(); + bool ret = completed_handshake_ && transport_->socket()->IsConnectedAndIdle(); LeaveFunction(""); return ret; } int SSLClientSocketNSS::GetPeerAddress(AddressList* address) const { - return transport_->GetPeerAddress(address); + return transport_->socket()->GetPeerAddress(address); } int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, @@ -658,11 +660,11 @@ int SSLClientSocketNSS::Write(IOBuffer* buf, int buf_len, } bool SSLClientSocketNSS::SetReceiveBufferSize(int32 size) { - return transport_->SetReceiveBufferSize(size); + return transport_->socket()->SetReceiveBufferSize(size); } bool SSLClientSocketNSS::SetSendBufferSize(int32 size) { - return transport_->SetSendBufferSize(size); + return transport_->socket()->SetSendBufferSize(size); } #if defined(OS_WIN) @@ -1032,7 +1034,8 @@ int SSLClientSocketNSS::BufferSend(void) { scoped_refptr<IOBuffer> send_buffer = new IOBuffer(nb); memcpy(send_buffer->data(), buf, nb); - int rv = transport_->Write(send_buffer, nb, &buffer_send_callback_); + int rv = transport_->socket()->Write(send_buffer, nb, + &buffer_send_callback_); if (rv == ERR_IO_PENDING) { transport_send_busy_ = true; break; @@ -1072,7 +1075,7 @@ int SSLClientSocketNSS::BufferRecv(void) { rv = ERR_IO_PENDING; } else { recv_buffer_ = new IOBuffer(nb); - rv = transport_->Read(recv_buffer_, nb, &buffer_recv_callback_); + rv = transport_->socket()->Read(recv_buffer_, nb, &buffer_recv_callback_); if (rv == ERR_IO_PENDING) { transport_recv_busy_ = true; } else { diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index cf3b478..60544ea 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -26,6 +26,8 @@ namespace net { class BoundNetLog; class CertVerifier; +class ClientSocketHandle; +class X509Certificate; // An SSL client socket implemented with Mozilla NSS. class SSLClientSocketNSS : public SSLClientSocket { @@ -34,7 +36,7 @@ class SSLClientSocketNSS : public SSLClientSocket { // The given hostname will be compared with the name(s) in the server's // certificate during the SSL handshake. ssl_config specifies the SSL // settings. - SSLClientSocketNSS(ClientSocket* transport_socket, + SSLClientSocketNSS(ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config); ~SSLClientSocketNSS(); @@ -116,7 +118,7 @@ class SSLClientSocketNSS : public SSLClientSocket { scoped_refptr<IOBuffer> recv_buffer_; CompletionCallbackImpl<SSLClientSocketNSS> handshake_io_callback_; - scoped_ptr<ClientSocket> transport_; + scoped_ptr<ClientSocketHandle> transport_; std::string hostname_; SSLConfig ssl_config_; diff --git a/net/socket/ssl_client_socket_nss_factory.cc b/net/socket/ssl_client_socket_nss_factory.cc index 7cf73e78..99fb632 100644 --- a/net/socket/ssl_client_socket_nss_factory.cc +++ b/net/socket/ssl_client_socket_nss_factory.cc @@ -18,7 +18,7 @@ namespace net { SSLClientSocket* SSLClientSocketNSSFactory( - ClientSocket* transport_socket, + ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config) { // TODO(wtc): SSLClientSocketNSS can't do SSL client authentication using diff --git a/net/socket/ssl_client_socket_nss_factory.h b/net/socket/ssl_client_socket_nss_factory.h index a835516..b3b99b9 100644 --- a/net/socket/ssl_client_socket_nss_factory.h +++ b/net/socket/ssl_client_socket_nss_factory.h @@ -11,7 +11,7 @@ namespace net { // Creates SSLClientSocketNSS objects. SSLClientSocket* SSLClientSocketNSSFactory( - ClientSocket* transport_socket, + ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config); diff --git a/net/socket/ssl_client_socket_pool.cc b/net/socket/ssl_client_socket_pool.cc new file mode 100644 index 0000000..fedb3ca --- /dev/null +++ b/net/socket/ssl_client_socket_pool.cc @@ -0,0 +1,424 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/ssl_client_socket_pool.h" + +#include "net/base/net_errors.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" + +namespace net { + +SSLSocketParams::SSLSocketParams( + const scoped_refptr<TCPSocketParams>& tcp_params, + const scoped_refptr<HttpProxySocketParams>& http_proxy_params, + const scoped_refptr<SOCKSSocketParams>& socks_params, + ProxyServer::Scheme proxy, + const std::string& hostname, + const SSLConfig& ssl_config, + int load_flags, + bool want_spdy) + : tcp_params_(tcp_params), + http_proxy_params_(http_proxy_params), + socks_params_(socks_params), + proxy_(proxy), + hostname_(hostname), + ssl_config_(ssl_config), + load_flags_(load_flags), + want_spdy_(want_spdy) { + switch (proxy_) { + case ProxyServer::SCHEME_DIRECT: + DCHECK(tcp_params_.get() != NULL); + DCHECK(http_proxy_params_.get() == NULL); + DCHECK(socks_params_.get() == NULL); + break; + case ProxyServer::SCHEME_HTTP: + DCHECK(tcp_params_.get() == NULL); + DCHECK(http_proxy_params_.get() != NULL); + DCHECK(socks_params_.get() == NULL); + break; + case ProxyServer::SCHEME_SOCKS4: + case ProxyServer::SCHEME_SOCKS5: + DCHECK(tcp_params_.get() == NULL); + DCHECK(http_proxy_params_.get() == NULL); + DCHECK(socks_params_.get() != NULL); + break; + default: + LOG(DFATAL) << "unknown proxy type"; + break; + } +} + +SSLSocketParams::~SSLSocketParams() {} + +// Timeout for the SSL handshake portion of the connect. +static const int kSSLHandshakeTimeoutInSeconds = 30; + +SSLConnectJob::SSLConnectJob( + const std::string& group_name, + const scoped_refptr<SSLSocketParams>& params, + const base::TimeDelta& timeout_duration, + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + const scoped_refptr<HttpProxyClientSocketPool>& http_proxy_pool, + const scoped_refptr<SOCKSClientSocketPool>& socks_pool, + ClientSocketFactory* client_socket_factory, + const scoped_refptr<HostResolver>& host_resolver, + Delegate* delegate, + NetLog* net_log) + : ConnectJob(group_name, timeout_duration, delegate, + BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), + params_(params), + tcp_pool_(tcp_pool), + http_proxy_pool_(http_proxy_pool), + socks_pool_(socks_pool), + client_socket_factory_(client_socket_factory), + resolver_(host_resolver), + ALLOW_THIS_IN_INITIALIZER_LIST( + callback_(this, &SSLConnectJob::OnIOComplete)) {} + +SSLConnectJob::~SSLConnectJob() {} + +LoadState SSLConnectJob::GetLoadState() const { + switch (next_state_) { + case STATE_TCP_CONNECT: + case STATE_TCP_CONNECT_COMPLETE: + case STATE_SOCKS_CONNECT: + case STATE_SOCKS_CONNECT_COMPLETE: + case STATE_TUNNEL_CONNECT: + case STATE_TUNNEL_CONNECT_COMPLETE: + return transport_socket_handle_->GetLoadState(); + case STATE_SSL_CONNECT: + case STATE_SSL_CONNECT_COMPLETE: + return LOAD_STATE_SSL_HANDSHAKE; + default: + NOTREACHED(); + return LOAD_STATE_IDLE; + } +} + +int SSLConnectJob::ConnectInternal() { + DetermineFirstState(); + return DoLoop(OK); +} + +void SSLConnectJob::DetermineFirstState() { + switch (params_->proxy()) { + case ProxyServer::SCHEME_DIRECT: + next_state_ = STATE_TCP_CONNECT; + break; + case ProxyServer::SCHEME_HTTP: + next_state_ = STATE_TUNNEL_CONNECT; + break; + case ProxyServer::SCHEME_SOCKS4: + case ProxyServer::SCHEME_SOCKS5: + next_state_ = STATE_SOCKS_CONNECT; + break; + default: + NOTREACHED() << "unknown proxy type"; + break; + } +} + +void SSLConnectJob::OnIOComplete(int result) { + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) + NotifyDelegateOfCompletion(rv); // Deletes |this|. +} + +int SSLConnectJob::DoLoop(int result) { + DCHECK_NE(next_state_, STATE_NONE); + + int rv = result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_TCP_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoTCPConnect(); + break; + case STATE_TCP_CONNECT_COMPLETE: + rv = DoTCPConnectComplete(rv); + break; + case STATE_SOCKS_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoSOCKSConnect(); + break; + case STATE_SOCKS_CONNECT_COMPLETE: + rv = DoSOCKSConnectComplete(rv); + break; + case STATE_TUNNEL_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoTunnelConnect(); + break; + case STATE_TUNNEL_CONNECT_COMPLETE: + rv = DoTunnelConnectComplete(rv); + break; + case STATE_SSL_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoSSLConnect(); + break; + case STATE_SSL_CONNECT_COMPLETE: + rv = DoSSLConnectComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_FAILED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + + return rv; +} + +int SSLConnectJob::DoTCPConnect() { + DCHECK(tcp_pool_.get()); + next_state_ = STATE_TCP_CONNECT_COMPLETE; + transport_socket_handle_.reset(new ClientSocketHandle()); + scoped_refptr<TCPSocketParams> tcp_params = params_->tcp_params(); + return transport_socket_handle_->Init(group_name(), tcp_params, + tcp_params->destination().priority(), + &callback_, tcp_pool_, net_log()); +} + +int SSLConnectJob::DoTCPConnectComplete(int result) { + if (result == OK) + next_state_ = STATE_SSL_CONNECT; + + return result; +} + +int SSLConnectJob::DoSOCKSConnect() { + DCHECK(socks_pool_.get()); + next_state_ = STATE_SOCKS_CONNECT_COMPLETE; + transport_socket_handle_.reset(new ClientSocketHandle()); + scoped_refptr<SOCKSSocketParams> socks_params = params_->socks_params(); + return transport_socket_handle_->Init(group_name(), socks_params, + socks_params->destination().priority(), + &callback_, socks_pool_, net_log()); +} + +int SSLConnectJob::DoSOCKSConnectComplete(int result) { + if (result == OK) + next_state_ = STATE_SSL_CONNECT; + + return result; +} + +int SSLConnectJob::DoTunnelConnect() { + DCHECK(http_proxy_pool_.get()); + next_state_ = STATE_TUNNEL_CONNECT_COMPLETE; + transport_socket_handle_.reset(new ClientSocketHandle()); + scoped_refptr<HttpProxySocketParams> http_proxy_params = + params_->http_proxy_params(); + return transport_socket_handle_->Init( + group_name(), http_proxy_params, + http_proxy_params->tcp_params()->destination().priority(), &callback_, + http_proxy_pool_, net_log()); +} + +int SSLConnectJob::DoTunnelConnectComplete(int result) { + ClientSocket* socket = transport_socket_handle_->socket(); + HttpProxyClientSocket* tunnel_socket = + static_cast<HttpProxyClientSocket*>(socket); + + if (result == ERR_RETRY_CONNECTION) { + DetermineFirstState(); + transport_socket_handle_->socket()->Disconnect(); + return OK; + } + + if (result == ERR_PROXY_AUTH_REQUESTED) { + // Extract the information needed to prompt for the proxy authentication. + // so that when ClientSocketPoolBaseHelper calls |GetAdditionalErrorState|, + // we can easily set the state. + const HttpResponseInfo* tunnel_response = tunnel_socket->GetResponseInfo(); + + http_auth_response_headers_ = tunnel_response->headers; + http_auth_auth_challenge_ = tunnel_response->auth_challenge; + } + + if (result < 0) + return result; + + if (tunnel_socket->NeedsRestartWithAuth()) { + // We must have gotten an 'idle' tunnel socket that is waiting for auth. + // The HttpAuthController should have new credentials, we just need + // to retry. + next_state_ = STATE_TUNNEL_CONNECT_COMPLETE; + return tunnel_socket->RestartWithAuth(&callback_); + } + + next_state_ = STATE_SSL_CONNECT; + return result; +} + +void SSLConnectJob::GetAdditionalErrorState(ClientSocketHandle * handle) { + if (http_auth_response_headers_.get() != NULL) + handle->set_tunnel_auth_response_info(http_auth_response_headers_, + http_auth_auth_challenge_); + if (!ssl_connect_start_time_.is_null()) + handle->set_is_ssl_error(true); +} + +int SSLConnectJob::DoSSLConnect() { + next_state_ = STATE_SSL_CONNECT_COMPLETE; + // Reset the timeout to just the time allowed for the SSL handshake. + ResetTimer(base::TimeDelta::FromSeconds(kSSLHandshakeTimeoutInSeconds)); + ssl_connect_start_time_ = base::TimeTicks::Now(); + + ssl_socket_.reset(client_socket_factory_->CreateSSLClientSocket( + transport_socket_handle_.release(), params_->hostname(), + params_->ssl_config())); + return ssl_socket_->Connect(&callback_); +} + +int SSLConnectJob::DoSSLConnectComplete(int result) { + SSLClientSocket::NextProtoStatus status = + SSLClientSocket::kNextProtoUnsupported; + std::string proto; + // GetNextProto will fail and and trigger a NOTREACHED if we pass in a socket + // that hasn't had SSL_ImportFD called on it. If we get a certificate error + // here, then we know that we called SSL_ImportFD. + if (result == OK || IsCertificateError(result)) + status = ssl_socket_->GetNextProto(&proto); + + bool using_spdy = false; + if (status == SSLClientSocket::kNextProtoNegotiated) { + ssl_socket_->setWasNpnNegotiated(true); + if (SSLClientSocket::NextProtoFromString(proto) == + SSLClientSocket::kProtoSPDY1) { + using_spdy = true; + } + } + if (params_->want_spdy() && !using_spdy) + return ERR_NPN_NEGOTIATION_FAILED; + + if (result == OK || + ssl_socket_->IgnoreCertError(result, params_->load_flags())) { + DCHECK(ssl_connect_start_time_ != base::TimeTicks()); + base::TimeDelta connect_duration = + base::TimeTicks::Now() - ssl_connect_start_time_; + if (using_spdy) + UMA_HISTOGRAM_CUSTOM_TIMES("Net.SpdyConnectionLatency", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + else + UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + } + if (result == OK || IsCertificateError(result)) + set_socket(ssl_socket_.release()); + + return result; +} + +ConnectJob* SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( + const std::string& group_name, + const PoolBase::Request& request, + ConnectJob::Delegate* delegate) const { + return new SSLConnectJob(group_name, request.params(), ConnectionTimeout(), + tcp_pool_, http_proxy_pool_, socks_pool_, + client_socket_factory_, host_resolver_, delegate, + net_log_); +} + +SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + const scoped_refptr<HttpProxyClientSocketPool>& http_proxy_pool, + const scoped_refptr<SOCKSClientSocketPool>& socks_pool, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + NetLog* net_log) + : tcp_pool_(tcp_pool), + http_proxy_pool_(http_proxy_pool), + socks_pool_(socks_pool), + client_socket_factory_(client_socket_factory), + host_resolver_(host_resolver), + net_log_(net_log) { + base::TimeDelta max_transport_timeout = base::TimeDelta(); + base::TimeDelta pool_timeout; + if (tcp_pool_) + max_transport_timeout = tcp_pool_->ConnectionTimeout(); + if (socks_pool_) { + pool_timeout = socks_pool_->ConnectionTimeout(); + if (pool_timeout > max_transport_timeout) + max_transport_timeout = pool_timeout; + } + if (http_proxy_pool_) { + pool_timeout = http_proxy_pool_->ConnectionTimeout(); + if (pool_timeout > max_transport_timeout) + max_transport_timeout = pool_timeout; + } + timeout_ = max_transport_timeout + + base::TimeDelta::FromSeconds(kSSLHandshakeTimeoutInSeconds); +} + +SSLClientSocketPool::SSLClientSocketPool( + int max_sockets, + int max_sockets_per_group, + const scoped_refptr<ClientSocketPoolHistograms>& histograms, + const scoped_refptr<HostResolver>& host_resolver, + ClientSocketFactory* client_socket_factory, + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + const scoped_refptr<HttpProxyClientSocketPool>& http_proxy_pool, + const scoped_refptr<SOCKSClientSocketPool>& socks_pool, + NetLog* net_log) + : base_(max_sockets, max_sockets_per_group, histograms, + base::TimeDelta::FromSeconds( + ClientSocketPool::unused_idle_socket_timeout()), + base::TimeDelta::FromSeconds(kUsedIdleSocketTimeout), + new SSLConnectJobFactory(tcp_pool, http_proxy_pool, socks_pool, + client_socket_factory, host_resolver, + net_log)) {} + +SSLClientSocketPool::~SSLClientSocketPool() {} + +int SSLClientSocketPool::RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + CompletionCallback* callback, + const BoundNetLog& net_log) { + const scoped_refptr<SSLSocketParams>* casted_socket_params = + static_cast<const scoped_refptr<SSLSocketParams>*>(socket_params); + + return base_.RequestSocket(group_name, *casted_socket_params, priority, + handle, callback, net_log); +} + +void SSLClientSocketPool::CancelRequest(const std::string& group_name, + const ClientSocketHandle* handle) { + base_.CancelRequest(group_name, handle); +} + +void SSLClientSocketPool::ReleaseSocket(const std::string& group_name, + ClientSocket* socket, int id) { + base_.ReleaseSocket(group_name, socket, id); +} + +void SSLClientSocketPool::Flush() { + base_.Flush(); +} + +void SSLClientSocketPool::CloseIdleSockets() { + base_.CloseIdleSockets(); +} + +int SSLClientSocketPool::IdleSocketCountInGroup( + const std::string& group_name) const { + return base_.IdleSocketCountInGroup(group_name); +} + +LoadState SSLClientSocketPool::GetLoadState( + const std::string& group_name, const ClientSocketHandle* handle) const { + return base_.GetLoadState(group_name, handle); +} + +} // namespace net diff --git a/net/socket/ssl_client_socket_pool.h b/net/socket/ssl_client_socket_pool.h new file mode 100644 index 0000000..cba012c --- /dev/null +++ b/net/socket/ssl_client_socket_pool.h @@ -0,0 +1,248 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ +#define NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ + +#include <string> + +#include "base/ref_counted.h" +#include "base/scoped_ptr.h" +#include "base/time.h" +#include "net/base/host_resolver.h" +#include "net/base/ssl_config_service.h" +#include "net/http/http_proxy_client_socket.h" +#include "net/http/http_proxy_client_socket_pool.h" +#include "net/proxy/proxy_server.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_pool_base.h" +#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/client_socket_pool.h" +#include "net/socket/socks_client_socket_pool.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/tcp_client_socket_pool.h" + +namespace net { + +class ClientSocketFactory; +class ConnectJobFactory; + +// SSLSocketParams only needs the socket params for the transport socket +// that will be used (denoted by |proxy|). +class SSLSocketParams : public base::RefCounted<SSLSocketParams> { + public: + SSLSocketParams(const scoped_refptr<TCPSocketParams>& tcp_params, + const scoped_refptr<HttpProxySocketParams>& http_proxy_params, + const scoped_refptr<SOCKSSocketParams>& socks_params, + ProxyServer::Scheme proxy, + const std::string& hostname, + const SSLConfig& ssl_config, + int load_flags, + bool want_spdy); + + const scoped_refptr<TCPSocketParams>& tcp_params() { return tcp_params_; } + const scoped_refptr<HttpProxySocketParams>& http_proxy_params () { + return http_proxy_params_; + } + const scoped_refptr<SOCKSSocketParams>& socks_params() { + return socks_params_; + } + ProxyServer::Scheme proxy() const { return proxy_; } + const std::string& hostname() const { return hostname_; } + const SSLConfig& ssl_config() const { return ssl_config_; } + int load_flags() const { return load_flags_; } + bool want_spdy() const { return want_spdy_; } + + private: + friend class base::RefCounted<SSLSocketParams>; + ~SSLSocketParams(); + + const scoped_refptr<TCPSocketParams> tcp_params_; + const scoped_refptr<HttpProxySocketParams> http_proxy_params_; + const scoped_refptr<SOCKSSocketParams> socks_params_; + const ProxyServer::Scheme proxy_; + const std::string hostname_; + const SSLConfig ssl_config_; + const int load_flags_; + const bool want_spdy_; + + DISALLOW_COPY_AND_ASSIGN(SSLSocketParams); +}; + +// SSLConnectJob handles the SSL handshake after setting up the underlying +// connection as specified in the params. +class SSLConnectJob : public ConnectJob { + public: + SSLConnectJob( + const std::string& group_name, + const scoped_refptr<SSLSocketParams>& params, + const base::TimeDelta& timeout_duration, + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + const scoped_refptr<HttpProxyClientSocketPool>& http_proxy_pool, + const scoped_refptr<SOCKSClientSocketPool>& socks_pool, + ClientSocketFactory* client_socket_factory, + const scoped_refptr<HostResolver>& host_resolver, + Delegate* delegate, + NetLog* net_log); + virtual ~SSLConnectJob(); + + // ConnectJob methods. + virtual LoadState GetLoadState() const; + + virtual void GetAdditionalErrorState(ClientSocketHandle * handle); + + private: + enum State { + STATE_TCP_CONNECT, + STATE_TCP_CONNECT_COMPLETE, + STATE_SOCKS_CONNECT, + STATE_SOCKS_CONNECT_COMPLETE, + STATE_TUNNEL_CONNECT, + STATE_TUNNEL_CONNECT_COMPLETE, + STATE_SSL_CONNECT, + STATE_SSL_CONNECT_COMPLETE, + STATE_NONE, + }; + + // Starts the SSL connection process. Returns OK on success and + // ERR_IO_PENDING if it cannot immediately service the request. + // Otherwise, it returns a net error code. + virtual int ConnectInternal(); + + void DetermineFirstState(); + + void OnIOComplete(int result); + + // Runs the state transition loop. + int DoLoop(int result); + + int DoTCPConnect(); + int DoTCPConnectComplete(int result); + int DoSOCKSConnect(); + int DoSOCKSConnectComplete(int result); + int DoTunnelConnect(); + int DoTunnelConnectComplete(int result); + int DoSSLConnect(); + int DoSSLConnectComplete(int result); + + scoped_refptr<SSLSocketParams> params_; + const scoped_refptr<TCPClientSocketPool> tcp_pool_; + const scoped_refptr<HttpProxyClientSocketPool> http_proxy_pool_; + const scoped_refptr<SOCKSClientSocketPool> socks_pool_; + ClientSocketFactory* const client_socket_factory_; + const scoped_refptr<HostResolver> resolver_; + + State next_state_; + CompletionCallbackImpl<SSLConnectJob> callback_; + scoped_ptr<ClientSocketHandle> transport_socket_handle_; + scoped_ptr<SSLClientSocket> ssl_socket_; + + // The time the DoSSLConnect() method was called. + base::TimeTicks ssl_connect_start_time_; + + scoped_refptr<HttpResponseHeaders> http_auth_response_headers_; + scoped_refptr<AuthChallengeInfo> http_auth_auth_challenge_; + + DISALLOW_COPY_AND_ASSIGN(SSLConnectJob); +}; + +class SSLClientSocketPool : public ClientSocketPool { + public: + // Only the pools that will be used are required. i.e. if you never + // try to create an SSL over SOCKS socket, |socks_pool| may be NULL. + SSLClientSocketPool( + int max_sockets, + int max_sockets_per_group, + const scoped_refptr<ClientSocketPoolHistograms>& histograms, + const scoped_refptr<HostResolver>& host_resolver, + ClientSocketFactory* client_socket_factory, + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + const scoped_refptr<HttpProxyClientSocketPool>& http_proxy_pool, + const scoped_refptr<SOCKSClientSocketPool>& socks_pool, + NetLog* net_log); + + // ClientSocketPool methods: + virtual int RequestSocket(const std::string& group_name, + const void* connect_params, + RequestPriority priority, + ClientSocketHandle* handle, + CompletionCallback* callback, + const BoundNetLog& net_log); + + virtual void CancelRequest(const std::string& group_name, + const ClientSocketHandle* handle); + + virtual void ReleaseSocket(const std::string& group_name, + ClientSocket* socket, + int id); + + virtual void Flush(); + + virtual void CloseIdleSockets(); + + virtual int IdleSocketCount() const { + return base_.idle_socket_count(); + } + + virtual int IdleSocketCountInGroup(const std::string& group_name) const; + + virtual LoadState GetLoadState(const std::string& group_name, + const ClientSocketHandle* handle) const; + + virtual base::TimeDelta ConnectionTimeout() const { + return base_.ConnectionTimeout(); + } + + virtual scoped_refptr<ClientSocketPoolHistograms> histograms() const { + return base_.histograms(); + }; + + protected: + virtual ~SSLClientSocketPool(); + + private: + typedef ClientSocketPoolBase<SSLSocketParams> PoolBase; + + class SSLConnectJobFactory : public PoolBase::ConnectJobFactory { + public: + SSLConnectJobFactory( + const scoped_refptr<TCPClientSocketPool>& tcp_pool, + const scoped_refptr<HttpProxyClientSocketPool>& http_proxy_pool, + const scoped_refptr<SOCKSClientSocketPool>& socks_pool, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + NetLog* net_log); + + virtual ~SSLConnectJobFactory() {} + + // ClientSocketPoolBase::ConnectJobFactory methods. + virtual ConnectJob* NewConnectJob( + const std::string& group_name, + const PoolBase::Request& request, + ConnectJob::Delegate* delegate) const; + + virtual base::TimeDelta ConnectionTimeout() const { return timeout_; } + + private: + const scoped_refptr<TCPClientSocketPool> tcp_pool_; + const scoped_refptr<HttpProxyClientSocketPool> http_proxy_pool_; + const scoped_refptr<SOCKSClientSocketPool> socks_pool_; + ClientSocketFactory* const client_socket_factory_; + const scoped_refptr<HostResolver> host_resolver_; + base::TimeDelta timeout_; + NetLog* net_log_; + + DISALLOW_COPY_AND_ASSIGN(SSLConnectJobFactory); + }; + + PoolBase base_; + + DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool); +}; + +REGISTER_SOCKET_PARAMS_FOR_POOL(SSLClientSocketPool, SSLSocketParams); + +} // namespace net + +#endif // NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ diff --git a/net/socket/ssl_client_socket_pool_unittest.cc b/net/socket/ssl_client_socket_pool_unittest.cc new file mode 100644 index 0000000..30ff5da --- /dev/null +++ b/net/socket/ssl_client_socket_pool_unittest.cc @@ -0,0 +1,720 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/http/http_proxy_client_socket_pool.h" + +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/time.h" +#include "net/base/auth.h" +#include "net/base/mock_host_resolver.h" +#include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" +#include "net/base/ssl_config_service_defaults.h" +#include "net/http/http_auth_controller.h" +#include "net/http/http_network_session.h" +#include "net/http/http_request_headers.h" +#include "net/http/http_response_headers.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/socket_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +const int kMaxSockets = 32; +const int kMaxSocketsPerGroup = 6; + +class SSLClientSocketPoolTest : public ClientSocketPoolTest { + protected: + SSLClientSocketPoolTest() + : direct_tcp_socket_params_(new TCPSocketParams( + HostPortPair("host", 443), MEDIUM, GURL(), false)), + tcp_socket_pool_(new MockTCPClientSocketPool( + kMaxSockets, + kMaxSocketsPerGroup, + make_scoped_refptr(new ClientSocketPoolHistograms("MockTCP")), + &socket_factory_)), + proxy_tcp_socket_params_(new TCPSocketParams( + HostPortPair("proxy", 443), MEDIUM, GURL(), false)), + http_proxy_socket_pool_(new HttpProxyClientSocketPool( + kMaxSockets, + kMaxSocketsPerGroup, + make_scoped_refptr(new ClientSocketPoolHistograms("MockHttpProxy")), + new MockHostResolver, + tcp_socket_pool_, + NULL)), + socks_socket_params_(new SOCKSSocketParams( + proxy_tcp_socket_params_, true, HostPortPair("sockshost", 443), + MEDIUM, GURL())), + socks_socket_pool_(new MockSOCKSClientSocketPool( + kMaxSockets, + kMaxSocketsPerGroup, + make_scoped_refptr(new ClientSocketPoolHistograms("MockSOCKS")), + tcp_socket_pool_)) { + scoped_refptr<SSLConfigService> ssl_config_service( + new SSLConfigServiceDefaults); + ssl_config_service->GetSSLConfig(&ssl_config_); + } + + void CreatePool(bool tcp_pool, bool http_proxy_pool, bool socks_pool) { + pool_ = new SSLClientSocketPool( + kMaxSockets, + kMaxSocketsPerGroup, + make_scoped_refptr(new ClientSocketPoolHistograms("SSLUnitTest")), + NULL, + &socket_factory_, + tcp_pool ? tcp_socket_pool_ : NULL, + http_proxy_pool ? http_proxy_socket_pool_ : NULL, + socks_pool ? socks_socket_pool_ : NULL, + NULL); + } + + scoped_refptr<SSLSocketParams> SSLParams( + ProxyServer::Scheme proxy, struct MockHttpAuthControllerData* auth_data, + size_t auth_data_len, bool want_spdy) { + + scoped_refptr<HttpProxySocketParams> http_proxy_params; + if (proxy == ProxyServer::SCHEME_HTTP) { + scoped_refptr<MockHttpAuthController> auth_controller = + new MockHttpAuthController(); + auth_controller->SetMockAuthControllerData(auth_data, auth_data_len); + http_proxy_params = new HttpProxySocketParams(proxy_tcp_socket_params_, + GURL("http://host"), + HostPortPair("host", 80), + auth_controller, true); + } + + return make_scoped_refptr(new SSLSocketParams( + proxy == ProxyServer::SCHEME_DIRECT ? direct_tcp_socket_params_ : NULL, + http_proxy_params, + proxy == ProxyServer::SCHEME_SOCKS5 ? socks_socket_params_ : NULL, + proxy, + "host", + ssl_config_, + 0, + want_spdy)); + } + + MockClientSocketFactory socket_factory_; + + scoped_refptr<TCPSocketParams> direct_tcp_socket_params_; + scoped_refptr<MockTCPClientSocketPool> tcp_socket_pool_; + + scoped_refptr<TCPSocketParams> proxy_tcp_socket_params_; + scoped_refptr<HttpProxySocketParams> http_proxy_socket_params_; + scoped_refptr<HttpProxyClientSocketPool> http_proxy_socket_pool_; + + scoped_refptr<SOCKSSocketParams> socks_socket_params_; + scoped_refptr<MockSOCKSClientSocketPool> socks_socket_pool_; + + SSLConfig ssl_config_; + scoped_refptr<SSLClientSocketPool> pool_; +}; + +TEST_F(SSLClientSocketPoolTest, TCPFail) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(false, ERR_CONNECTION_FAILED)); + socket_factory_.AddSocketDataProvider(&data); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + NULL, 0, false); + + ClientSocketHandle handle; + int rv = handle.Init("a", params, MEDIUM, NULL, pool_, BoundNetLog()); + EXPECT_EQ(ERR_CONNECTION_FAILED, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); +} + +TEST_F(SSLClientSocketPoolTest, TCPFailAsync) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(true, ERR_CONNECTION_FAILED)); + socket_factory_.AddSocketDataProvider(&data); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + NULL, 0, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); +} + +TEST_F(SSLClientSocketPoolTest, BasicDirect) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(false, OK)); + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(false, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + NULL, 0, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(SSLClientSocketPoolTest, BasicDirectAsync) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(true, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + NULL, 0, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(SSLClientSocketPoolTest, DirectCertError) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(true, ERR_CERT_COMMON_NAME_INVALID); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + NULL, 0, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_CERT_COMMON_NAME_INVALID, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(SSLClientSocketPoolTest, DirectSSLError) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(true, ERR_SSL_PROTOCOL_ERROR); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + NULL, 0, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_TRUE(handle.is_ssl_error()); +} + +TEST_F(SSLClientSocketPoolTest, DirectWithNPN) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(true, OK); + ssl.next_proto_status = SSLClientSocket::kNextProtoNegotiated; + ssl.next_proto = "http/1.1"; + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + NULL, 0, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket()); + EXPECT_TRUE(ssl_socket->wasNpnNegotiated()); +} + +TEST_F(SSLClientSocketPoolTest, DirectNoSPDY) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(true, OK); + ssl.next_proto_status = SSLClientSocket::kNextProtoNegotiated; + ssl.next_proto = "http/1.1"; + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + NULL, 0, true); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_NPN_NEGOTIATION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_TRUE(handle.is_ssl_error()); +} + +TEST_F(SSLClientSocketPoolTest, DirectGotSPDY) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(true, OK); + ssl.next_proto_status = SSLClientSocket::kNextProtoNegotiated; + ssl.next_proto = "spdy/1"; + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + NULL, 0, true); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + + SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket()); + EXPECT_TRUE(ssl_socket->wasNpnNegotiated()); + std::string proto; + ssl_socket->GetNextProto(&proto); + EXPECT_EQ(SSLClientSocket::NextProtoFromString(proto), + SSLClientSocket::kProtoSPDY1); +} + +TEST_F(SSLClientSocketPoolTest, DirectGotBonusSPDY) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(true, OK); + ssl.next_proto_status = SSLClientSocket::kNextProtoNegotiated; + ssl.next_proto = "spdy/1"; + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + NULL, 0, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + + SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket()); + EXPECT_TRUE(ssl_socket->wasNpnNegotiated()); + std::string proto; + ssl_socket->GetNextProto(&proto); + EXPECT_EQ(SSLClientSocket::NextProtoFromString(proto), + SSLClientSocket::kProtoSPDY1); +} + +TEST_F(SSLClientSocketPoolTest, SOCKSFail) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(false, ERR_CONNECTION_FAILED)); + socket_factory_.AddSocketDataProvider(&data); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_SOCKS5, + NULL, 0, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_CONNECTION_FAILED, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); +} + +TEST_F(SSLClientSocketPoolTest, SOCKSFailAsync) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(true, ERR_CONNECTION_FAILED)); + socket_factory_.AddSocketDataProvider(&data); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_SOCKS5, + NULL, 0, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); +} + +TEST_F(SSLClientSocketPoolTest, SOCKSBasic) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(false, OK)); + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(false, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_SOCKS5, + NULL, 0, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(SSLClientSocketPoolTest, SOCKSBasicAsync) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(true, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_SOCKS5, + NULL, 0, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(SSLClientSocketPoolTest, HttpProxyFail) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(false, ERR_CONNECTION_FAILED)); + socket_factory_.AddSocketDataProvider(&data); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP, + NULL, 0, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_CONNECTION_FAILED, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); +} + +TEST_F(SSLClientSocketPoolTest, HttpProxyFailAsync) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(true, ERR_CONNECTION_FAILED)); + socket_factory_.AddSocketDataProvider(&data); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP, + NULL, 0, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); +} + +TEST_F(SSLClientSocketPoolTest, HttpProxyBasic) { + MockWrite writes[] = { + MockWrite(false, + "CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJheg==\r\n\r\n"), + }; + MockRead reads[] = { + MockRead(false, "HTTP/1.1 200 Connection Established\r\n\r\n"), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + data.set_connect_data(MockConnect(false, OK)); + socket_factory_.AddSocketDataProvider(&data); + MockHttpAuthControllerData auth_data[] = { + MockHttpAuthControllerData("Proxy-Authorization: Basic Zm9vOmJheg=="), + }; + SSLSocketDataProvider ssl(false, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP, + auth_data, + arraysize(auth_data), + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(SSLClientSocketPoolTest, HttpProxyBasicAsync) { + MockWrite writes[] = { + MockWrite("CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJheg==\r\n\r\n"), + }; + MockRead reads[] = { + MockRead("HTTP/1.1 200 Connection Established\r\n\r\n"), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + socket_factory_.AddSocketDataProvider(&data); + MockHttpAuthControllerData auth_data[] = { + MockHttpAuthControllerData("Proxy-Authorization: Basic Zm9vOmJheg=="), + }; + SSLSocketDataProvider ssl(true, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP, + auth_data, + arraysize(auth_data), + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(SSLClientSocketPoolTest, NeedProxyAuth) { + MockWrite writes[] = { + MockWrite("CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n\r\n"), + }; + MockRead reads[] = { + MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"), + MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"), + MockRead("Content-Length: 10\r\n\r\n"), + MockRead("0123456789"), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + socket_factory_.AddSocketDataProvider(&data); + MockHttpAuthControllerData auth_data[] = { + MockHttpAuthControllerData(""), + }; + SSLSocketDataProvider ssl(true, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP, + auth_data, + arraysize(auth_data), + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); + const HttpResponseInfo& tunnel_info = handle.tunnel_auth_response_info(); + EXPECT_EQ(tunnel_info.headers->response_code(), 407); +} + +TEST_F(SSLClientSocketPoolTest, DoProxyAuth) { + MockWrite writes[] = { + MockWrite("CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n\r\n"), + MockWrite("CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJheg==\r\n\r\n"), + }; + MockRead reads[] = { + MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"), + MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"), + MockRead("Content-Length: 10\r\n\r\n"), + MockRead("0123456789"), + MockRead("HTTP/1.1 200 Connection Established\r\n\r\n"), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + socket_factory_.AddSocketDataProvider(&data); + MockHttpAuthControllerData auth_data[] = { + MockHttpAuthControllerData(""), + MockHttpAuthControllerData("Proxy-Authorization: Basic Zm9vOmJheg=="), + }; + SSLSocketDataProvider ssl(true, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP, + auth_data, + arraysize(auth_data), + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); + const HttpResponseInfo& tunnel_info = handle.tunnel_auth_response_info(); + EXPECT_EQ(tunnel_info.headers->response_code(), 407); + + params->http_proxy_params()->auth_controller()->ResetAuth(std::wstring(), + std::wstring()); + rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(SSLClientSocketPoolTest, DoProxyAuthNoKeepAlive) { + MockWrite writes1[] = { + MockWrite("CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n\r\n"), + }; + MockWrite writes2[] = { + MockWrite("CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJheg==\r\n\r\n"), + }; + MockRead reads1[] = { + MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"), + MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n\r\n"), + MockRead("Content0123456789"), + }; + MockRead reads2[] = { + MockRead("HTTP/1.1 200 Connection Established\r\n\r\n"), + }; + StaticSocketDataProvider data1(reads1, arraysize(reads1), writes1, + arraysize(writes1)); + socket_factory_.AddSocketDataProvider(&data1); + StaticSocketDataProvider data2(reads2, arraysize(reads2), writes2, + arraysize(writes2)); + socket_factory_.AddSocketDataProvider(&data2); + MockHttpAuthControllerData auth_data[] = { + MockHttpAuthControllerData(""), + MockHttpAuthControllerData("Proxy-Authorization: Basic Zm9vOmJheg=="), + }; + SSLSocketDataProvider ssl(true, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP, + auth_data, + arraysize(auth_data), + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); + const HttpResponseInfo& tunnel_info = handle.tunnel_auth_response_info(); + EXPECT_EQ(tunnel_info.headers->response_code(), 407); + + params->http_proxy_params()->auth_controller()->ResetAuth(std::wstring(), + std::wstring()); + rv = handle.Init("a", params, MEDIUM, &callback, pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +// It would be nice to also test the timeouts in SSLClientSocketPool. + +} // namespace + +} // namespace net diff --git a/net/socket/ssl_client_socket_win.cc b/net/socket/ssl_client_socket_win.cc index 9a4be48..0484ebd 100644 --- a/net/socket/ssl_client_socket_win.cc +++ b/net/socket/ssl_client_socket_win.cc @@ -19,6 +19,7 @@ #include "net/base/ssl_cert_request_info.h" #include "net/base/ssl_connection_status_flags.h" #include "net/base/ssl_info.h" +#include "net/socket/client_socket_handle.h" #pragma comment(lib, "secur32.lib") @@ -293,7 +294,7 @@ class ClientCertStore { // 64: >= SSL record trailer (16 or 20 have been observed) static const int kRecvBufferSize = (5 + 16*1024 + 64); -SSLClientSocketWin::SSLClientSocketWin(ClientSocket* transport_socket, +SSLClientSocketWin::SSLClientSocketWin(ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config) : ALLOW_THIS_IN_INITIALIZER_LIST( @@ -324,7 +325,7 @@ SSLClientSocketWin::SSLClientSocketWin(ClientSocket* transport_socket, ignore_ok_result_(false), renegotiating_(false), need_more_data_(false), - net_log_(transport_socket->NetLog()) { + net_log_(transport_socket->socket()->NetLog()) { memset(&stream_sizes_, 0, sizeof(stream_sizes_)); memset(in_buffers_, 0, sizeof(in_buffers_)); memset(&send_buffer_, 0, sizeof(send_buffer_)); @@ -529,7 +530,7 @@ void SSLClientSocketWin::Disconnect() { // Shut down anything that may call us back. verifier_.reset(); - transport_->Disconnect(); + transport_->socket()->Disconnect(); if (send_buffer_.pvBuffer) FreeSendBuffer(); @@ -555,7 +556,7 @@ bool SSLClientSocketWin::IsConnected() const { // layer (HttpNetworkTransaction) needs to handle a persistent connection // closed by the server when we send a request anyway, a false positive in // exchange for simpler code is a good trade-off. - return completed_handshake() && transport_->IsConnected(); + return completed_handshake() && transport_->socket()->IsConnected(); } bool SSLClientSocketWin::IsConnectedAndIdle() const { @@ -564,13 +565,14 @@ bool SSLClientSocketWin::IsConnectedAndIdle() const { // Strictly speaking, we should check if we have received the close_notify // alert message from the server, and return false in that case. Although // the close_notify alert message means EOF in the SSL layer, it is just - // bytes to the transport layer below, so transport_->IsConnectedAndIdle() - // returns the desired false when we receive close_notify. - return completed_handshake() && transport_->IsConnectedAndIdle(); + // bytes to the transport layer below, so + // transport_->socket()->IsConnectedAndIdle() returns the desired false + // when we receive close_notify. + return completed_handshake() && transport_->socket()->IsConnectedAndIdle(); } int SSLClientSocketWin::GetPeerAddress(AddressList* address) const { - return transport_->GetPeerAddress(address); + return transport_->socket()->GetPeerAddress(address); } int SSLClientSocketWin::Read(IOBuffer* buf, int buf_len, @@ -637,11 +639,11 @@ int SSLClientSocketWin::Write(IOBuffer* buf, int buf_len, } bool SSLClientSocketWin::SetReceiveBufferSize(int32 size) { - return transport_->SetReceiveBufferSize(size); + return transport_->socket()->SetReceiveBufferSize(size); } bool SSLClientSocketWin::SetSendBufferSize(int32 size) { - return transport_->SetSendBufferSize(size); + return transport_->socket()->SetSendBufferSize(size); } void SSLClientSocketWin::OnHandshakeIOComplete(int result) { @@ -756,8 +758,8 @@ int SSLClientSocketWin::DoHandshakeRead() { DCHECK(!transport_read_buf_); transport_read_buf_ = new IOBuffer(buf_len); - return transport_->Read(transport_read_buf_, buf_len, - &handshake_io_callback_); + return transport_->socket()->Read(transport_read_buf_, buf_len, + &handshake_io_callback_); } int SSLClientSocketWin::DoHandshakeReadComplete(int result) { @@ -923,8 +925,8 @@ int SSLClientSocketWin::DoHandshakeWrite() { transport_write_buf_ = new IOBuffer(buf_len); memcpy(transport_write_buf_->data(), buf, buf_len); - return transport_->Write(transport_write_buf_, buf_len, - &handshake_io_callback_); + return transport_->socket()->Write(transport_write_buf_, buf_len, + &handshake_io_callback_); } int SSLClientSocketWin::DoHandshakeWriteComplete(int result) { @@ -1018,7 +1020,8 @@ int SSLClientSocketWin::DoPayloadRead() { DCHECK(!transport_read_buf_); transport_read_buf_ = new IOBuffer(buf_len); - rv = transport_->Read(transport_read_buf_, buf_len, &read_callback_); + rv = transport_->socket()->Read(transport_read_buf_, buf_len, + &read_callback_); if (rv != ERR_IO_PENDING) rv = DoPayloadReadComplete(rv); if (rv <= 0) @@ -1253,7 +1256,8 @@ int SSLClientSocketWin::DoPayloadWrite() { transport_write_buf_ = new IOBuffer(buf_len); memcpy(transport_write_buf_->data(), buf, buf_len); - int rv = transport_->Write(transport_write_buf_, buf_len, &write_callback_); + int rv = transport_->socket()->Write(transport_write_buf_, buf_len, + &write_callback_); if (rv != ERR_IO_PENDING) rv = DoPayloadWriteComplete(rv); return rv; diff --git a/net/socket/ssl_client_socket_win.h b/net/socket/ssl_client_socket_win.h index 3a273fe0..b4a0bad 100644 --- a/net/socket/ssl_client_socket_win.h +++ b/net/socket/ssl_client_socket_win.h @@ -23,6 +23,7 @@ namespace net { class CertVerifier; +class ClientSocketHandle; class BoundNetLog; // An SSL client socket implemented with the Windows Schannel. @@ -32,7 +33,7 @@ class SSLClientSocketWin : public SSLClientSocket { // The given hostname will be compared with the name(s) in the server's // certificate during the SSL handshake. ssl_config specifies the SSL // settings. - SSLClientSocketWin(ClientSocket* transport_socket, + SSLClientSocketWin(ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config); ~SSLClientSocketWin(); @@ -96,7 +97,7 @@ class SSLClientSocketWin : public SSLClientSocket { CompletionCallbackImpl<SSLClientSocketWin> read_callback_; CompletionCallbackImpl<SSLClientSocketWin> write_callback_; - scoped_ptr<ClientSocket> transport_; + scoped_ptr<ClientSocketHandle> transport_; std::string hostname_; SSLConfig ssl_config_; diff --git a/net/socket/tcp_client_socket_pool.h b/net/socket/tcp_client_socket_pool.h index 49edfed..99513fd 100644 --- a/net/socket/tcp_client_socket_pool.h +++ b/net/socket/tcp_client_socket_pool.h @@ -49,6 +49,8 @@ class TCPSocketParams : public base::RefCounted<TCPSocketParams> { } HostResolver::RequestInfo destination_; + + DISALLOW_COPY_AND_ASSIGN(TCPSocketParams); }; // TCPConnectJob handles the host resolution necessary for socket creation diff --git a/net/socket/tcp_client_socket_pool_unittest.cc b/net/socket/tcp_client_socket_pool_unittest.cc index 2b3408c..9516f9f 100644 --- a/net/socket/tcp_client_socket_pool_unittest.cc +++ b/net/socket/tcp_client_socket_pool_unittest.cc @@ -229,7 +229,7 @@ class MockClientSocketFactory : public ClientSocketFactory { } virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocket* transport_socket, + ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config) { NOTIMPLEMENTED(); |