// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/socket/ssl_client_socket_pool.h" #include "base/metrics/histogram.h" #include "base/values.h" #include "net/base/net_errors.h" #include "net/base/host_port_pair.h" #include "net/base/ssl_cert_request_info.h" #include "net/http/http_proxy_client_socket.h" #include "net/http/http_proxy_client_socket_pool.h" #include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" #include "net/socket/socks_client_socket_pool.h" #include "net/socket/ssl_client_socket.h" #include "net/socket/ssl_host_info.h" #include "net/socket/tcp_client_socket_pool.h" namespace net { SSLSocketParams::SSLSocketParams( const scoped_refptr& tcp_params, const scoped_refptr& socks_params, const scoped_refptr& http_proxy_params, ProxyServer::Scheme proxy, const HostPortPair& host_and_port, const SSLConfig& ssl_config, int load_flags, bool force_spdy_over_ssl, bool want_spdy_over_npn) : tcp_params_(tcp_params), http_proxy_params_(http_proxy_params), socks_params_(socks_params), proxy_(proxy), host_and_port_(host_and_port), ssl_config_(ssl_config), load_flags_(load_flags), force_spdy_over_ssl_(force_spdy_over_ssl), want_spdy_over_npn_(want_spdy_over_npn) { switch (proxy_) { case ProxyServer::SCHEME_DIRECT: DCHECK(tcp_params_.get() != NULL); DCHECK(http_proxy_params_.get() == NULL); DCHECK(socks_params_.get() == NULL); break; case ProxyServer::SCHEME_HTTP: case ProxyServer::SCHEME_HTTPS: DCHECK(tcp_params_.get() == NULL); DCHECK(http_proxy_params_.get() != NULL); DCHECK(socks_params_.get() == NULL); break; case ProxyServer::SCHEME_SOCKS4: case ProxyServer::SCHEME_SOCKS5: DCHECK(tcp_params_.get() == NULL); DCHECK(http_proxy_params_.get() == NULL); DCHECK(socks_params_.get() != NULL); break; default: LOG(DFATAL) << "unknown proxy type"; break; } } SSLSocketParams::~SSLSocketParams() {} // Timeout 
for the SSL handshake portion of the connect. static const int kSSLHandshakeTimeoutInSeconds = 30; SSLConnectJob::SSLConnectJob( const std::string& group_name, const scoped_refptr& params, const base::TimeDelta& timeout_duration, TCPClientSocketPool* tcp_pool, SOCKSClientSocketPool* socks_pool, HttpProxyClientSocketPool* http_proxy_pool, ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, Delegate* delegate, NetLog* net_log) : ConnectJob(group_name, timeout_duration, delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), params_(params), tcp_pool_(tcp_pool), socks_pool_(socks_pool), http_proxy_pool_(http_proxy_pool), client_socket_factory_(client_socket_factory), resolver_(host_resolver), dnsrr_resolver_(dnsrr_resolver), dns_cert_checker_(dns_cert_checker), ssl_host_info_factory_(ssl_host_info_factory), ALLOW_THIS_IN_INITIALIZER_LIST( callback_(this, &SSLConnectJob::OnIOComplete)) {} SSLConnectJob::~SSLConnectJob() {} LoadState SSLConnectJob::GetLoadState() const { switch (next_state_) { case STATE_TUNNEL_CONNECT_COMPLETE: if (transport_socket_handle_->socket()) return LOAD_STATE_ESTABLISHING_PROXY_TUNNEL; // else, fall through. 
case STATE_TCP_CONNECT: case STATE_TCP_CONNECT_COMPLETE: case STATE_SOCKS_CONNECT: case STATE_SOCKS_CONNECT_COMPLETE: case STATE_TUNNEL_CONNECT: return transport_socket_handle_->GetLoadState(); case STATE_SSL_CONNECT: case STATE_SSL_CONNECT_COMPLETE: return LOAD_STATE_SSL_HANDSHAKE; default: NOTREACHED(); return LOAD_STATE_IDLE; } } int SSLConnectJob::ConnectInternal() { switch (params_->proxy()) { case ProxyServer::SCHEME_DIRECT: next_state_ = STATE_TCP_CONNECT; break; case ProxyServer::SCHEME_HTTP: case ProxyServer::SCHEME_HTTPS: next_state_ = STATE_TUNNEL_CONNECT; break; case ProxyServer::SCHEME_SOCKS4: case ProxyServer::SCHEME_SOCKS5: next_state_ = STATE_SOCKS_CONNECT; break; default: NOTREACHED() << "unknown proxy type"; break; } return DoLoop(OK); } void SSLConnectJob::OnIOComplete(int result) { int rv = DoLoop(result); if (rv != ERR_IO_PENDING) NotifyDelegateOfCompletion(rv); // Deletes |this|. } int SSLConnectJob::DoLoop(int result) { DCHECK_NE(next_state_, STATE_NONE); int rv = result; do { State state = next_state_; next_state_ = STATE_NONE; switch (state) { case STATE_TCP_CONNECT: DCHECK_EQ(OK, rv); rv = DoTCPConnect(); break; case STATE_TCP_CONNECT_COMPLETE: rv = DoTCPConnectComplete(rv); break; case STATE_SOCKS_CONNECT: DCHECK_EQ(OK, rv); rv = DoSOCKSConnect(); break; case STATE_SOCKS_CONNECT_COMPLETE: rv = DoSOCKSConnectComplete(rv); break; case STATE_TUNNEL_CONNECT: DCHECK_EQ(OK, rv); rv = DoTunnelConnect(); break; case STATE_TUNNEL_CONNECT_COMPLETE: rv = DoTunnelConnectComplete(rv); break; case STATE_SSL_CONNECT: DCHECK_EQ(OK, rv); rv = DoSSLConnect(); break; case STATE_SSL_CONNECT_COMPLETE: rv = DoSSLConnectComplete(rv); break; default: NOTREACHED() << "bad state"; rv = ERR_FAILED; break; } } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); return rv; } int SSLConnectJob::DoTCPConnect() { DCHECK(tcp_pool_); if (ssl_host_info_factory_ && SSLConfigService::snap_start_enabled()) { ssl_host_info_.reset( 
ssl_host_info_factory_->GetForHost(params_->host_and_port().host(), params_->ssl_config())); } if (ssl_host_info_.get()) { // This starts fetching the SSL host info from the disk cache for Snap // Start. ssl_host_info_->Start(); } next_state_ = STATE_TCP_CONNECT_COMPLETE; transport_socket_handle_.reset(new ClientSocketHandle()); scoped_refptr tcp_params = params_->tcp_params(); return transport_socket_handle_->Init(group_name(), tcp_params, tcp_params->destination().priority(), &callback_, tcp_pool_, net_log()); } int SSLConnectJob::DoTCPConnectComplete(int result) { if (result == OK) next_state_ = STATE_SSL_CONNECT; return result; } int SSLConnectJob::DoSOCKSConnect() { DCHECK(socks_pool_); next_state_ = STATE_SOCKS_CONNECT_COMPLETE; transport_socket_handle_.reset(new ClientSocketHandle()); scoped_refptr socks_params = params_->socks_params(); return transport_socket_handle_->Init(group_name(), socks_params, socks_params->destination().priority(), &callback_, socks_pool_, net_log()); } int SSLConnectJob::DoSOCKSConnectComplete(int result) { if (result == OK) next_state_ = STATE_SSL_CONNECT; return result; } int SSLConnectJob::DoTunnelConnect() { DCHECK(http_proxy_pool_); next_state_ = STATE_TUNNEL_CONNECT_COMPLETE; transport_socket_handle_.reset(new ClientSocketHandle()); scoped_refptr http_proxy_params = params_->http_proxy_params(); return transport_socket_handle_->Init( group_name(), http_proxy_params, http_proxy_params->destination().priority(), &callback_, http_proxy_pool_, net_log()); } int SSLConnectJob::DoTunnelConnectComplete(int result) { // Extract the information needed to prompt for appropriate proxy // authentication so that when ClientSocketPoolBaseHelper calls // |GetAdditionalErrorState|, we can easily set the state. 
if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) { error_response_info_ = transport_socket_handle_->ssl_error_response_info(); } else if (result == ERR_PROXY_AUTH_REQUESTED) { ClientSocket* socket = transport_socket_handle_->socket(); HttpProxyClientSocket* tunnel_socket = static_cast(socket); error_response_info_ = *tunnel_socket->GetResponseInfo(); } if (result < 0) return result; next_state_ = STATE_SSL_CONNECT; return result; } void SSLConnectJob::GetAdditionalErrorState(ClientSocketHandle * handle) { // Headers in |error_response_info_| indicate a proxy tunnel setup // problem. See DoTunnelConnectComplete. if (error_response_info_.headers) { handle->set_pending_http_proxy_connection( transport_socket_handle_.release()); } handle->set_ssl_error_response_info(error_response_info_); if (!ssl_connect_start_time_.is_null()) handle->set_is_ssl_error(true); } int SSLConnectJob::DoSSLConnect() { next_state_ = STATE_SSL_CONNECT_COMPLETE; // Reset the timeout to just the time allowed for the SSL handshake. ResetTimer(base::TimeDelta::FromSeconds(kSSLHandshakeTimeoutInSeconds)); ssl_connect_start_time_ = base::TimeTicks::Now(); ssl_socket_.reset(client_socket_factory_->CreateSSLClientSocket( transport_socket_handle_.release(), params_->host_and_port(), params_->ssl_config(), ssl_host_info_.release(), dns_cert_checker_)); return ssl_socket_->Connect(&callback_); } int SSLConnectJob::DoSSLConnectComplete(int result) { SSLClientSocket::NextProtoStatus status = SSLClientSocket::kNextProtoUnsupported; std::string proto; // GetNextProto will fail and and trigger a NOTREACHED if we pass in a socket // that hasn't had SSL_ImportFD called on it. If we get a certificate error // here, then we know that we called SSL_ImportFD. if (result == OK || IsCertificateError(result)) status = ssl_socket_->GetNextProto(&proto); // If we want spdy over npn, make sure it succeeded. 
if (status == SSLClientSocket::kNextProtoNegotiated) { ssl_socket_->set_was_npn_negotiated(true); SSLClientSocket::NextProto next_protocol = SSLClientSocket::NextProtoFromString(proto); // If we negotiated either version of SPDY, we must have // advertised it, so allow it. // TODO(mbelshe): verify it was a protocol we advertised? if (next_protocol == SSLClientSocket::kProtoSPDY1 || next_protocol == SSLClientSocket::kProtoSPDY2) { ssl_socket_->set_was_spdy_negotiated(true); } } if (params_->want_spdy_over_npn() && !ssl_socket_->was_spdy_negotiated()) return ERR_NPN_NEGOTIATION_FAILED; // Spdy might be turned on by default, or it might be over npn. bool using_spdy = params_->force_spdy_over_ssl() || params_->want_spdy_over_npn(); if (result == OK || ssl_socket_->IgnoreCertError(result, params_->load_flags())) { DCHECK(ssl_connect_start_time_ != base::TimeTicks()); base::TimeDelta connect_duration = base::TimeTicks::Now() - ssl_connect_start_time_; if (using_spdy) { UMA_HISTOGRAM_CUSTOM_TIMES("Net.SpdyConnectionLatency", connect_duration, base::TimeDelta::FromMilliseconds(1), base::TimeDelta::FromMinutes(10), 100); } else { UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency", connect_duration, base::TimeDelta::FromMilliseconds(1), base::TimeDelta::FromMinutes(10), 100); } } if (result == OK || IsCertificateError(result)) { set_socket(ssl_socket_.release()); } else if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) { error_response_info_.cert_request_info = new SSLCertRequestInfo; ssl_socket_->GetSSLCertRequestInfo(error_response_info_.cert_request_info); } return result; } ConnectJob* SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { return new SSLConnectJob(group_name, request.params(), ConnectionTimeout(), tcp_pool_, socks_pool_, http_proxy_pool_, client_socket_factory_, host_resolver_, dnsrr_resolver_, dns_cert_checker_, ssl_host_info_factory_, delegate, 
net_log_); } SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( TCPClientSocketPool* tcp_pool, SOCKSClientSocketPool* socks_pool, HttpProxyClientSocketPool* http_proxy_pool, ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, NetLog* net_log) : tcp_pool_(tcp_pool), socks_pool_(socks_pool), http_proxy_pool_(http_proxy_pool), client_socket_factory_(client_socket_factory), host_resolver_(host_resolver), dnsrr_resolver_(dnsrr_resolver), dns_cert_checker_(dns_cert_checker), ssl_host_info_factory_(ssl_host_info_factory), net_log_(net_log) { base::TimeDelta max_transport_timeout = base::TimeDelta(); base::TimeDelta pool_timeout; if (tcp_pool_) max_transport_timeout = tcp_pool_->ConnectionTimeout(); if (socks_pool_) { pool_timeout = socks_pool_->ConnectionTimeout(); if (pool_timeout > max_transport_timeout) max_transport_timeout = pool_timeout; } if (http_proxy_pool_) { pool_timeout = http_proxy_pool_->ConnectionTimeout(); if (pool_timeout > max_transport_timeout) max_transport_timeout = pool_timeout; } timeout_ = max_transport_timeout + base::TimeDelta::FromSeconds(kSSLHandshakeTimeoutInSeconds); } SSLClientSocketPool::SSLClientSocketPool( int max_sockets, int max_sockets_per_group, ClientSocketPoolHistograms* histograms, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, ClientSocketFactory* client_socket_factory, TCPClientSocketPool* tcp_pool, SOCKSClientSocketPool* socks_pool, HttpProxyClientSocketPool* http_proxy_pool, SSLConfigService* ssl_config_service, NetLog* net_log) : tcp_pool_(tcp_pool), socks_pool_(socks_pool), http_proxy_pool_(http_proxy_pool), base_(max_sockets, max_sockets_per_group, histograms, base::TimeDelta::FromSeconds( ClientSocketPool::unused_idle_socket_timeout()), 
base::TimeDelta::FromSeconds(kUsedIdleSocketTimeout), new SSLConnectJobFactory(tcp_pool, socks_pool, http_proxy_pool, client_socket_factory, host_resolver, dnsrr_resolver, dns_cert_checker, ssl_host_info_factory, net_log)), ssl_config_service_(ssl_config_service) { if (ssl_config_service_) ssl_config_service_->AddObserver(this); } SSLClientSocketPool::~SSLClientSocketPool() { if (ssl_config_service_) ssl_config_service_->RemoveObserver(this); } int SSLClientSocketPool::RequestSocket(const std::string& group_name, const void* socket_params, RequestPriority priority, ClientSocketHandle* handle, CompletionCallback* callback, const BoundNetLog& net_log) { const scoped_refptr* casted_socket_params = static_cast*>(socket_params); return base_.RequestSocket(group_name, *casted_socket_params, priority, handle, callback, net_log); } void SSLClientSocketPool::RequestSockets( const std::string& group_name, const void* params, int num_sockets, const BoundNetLog& net_log) { const scoped_refptr* casted_params = static_cast*>(params); base_.RequestSockets(group_name, *casted_params, num_sockets, net_log); } void SSLClientSocketPool::CancelRequest(const std::string& group_name, ClientSocketHandle* handle) { base_.CancelRequest(group_name, handle); } void SSLClientSocketPool::ReleaseSocket(const std::string& group_name, ClientSocket* socket, int id) { base_.ReleaseSocket(group_name, socket, id); } void SSLClientSocketPool::Flush() { base_.Flush(); } void SSLClientSocketPool::CloseIdleSockets() { base_.CloseIdleSockets(); } int SSLClientSocketPool::IdleSocketCount() const { return base_.idle_socket_count(); } int SSLClientSocketPool::IdleSocketCountInGroup( const std::string& group_name) const { return base_.IdleSocketCountInGroup(group_name); } LoadState SSLClientSocketPool::GetLoadState( const std::string& group_name, const ClientSocketHandle* handle) const { return base_.GetLoadState(group_name, handle); } void SSLClientSocketPool::OnSSLConfigChanged() { Flush(); } 
// Describes this pool's state as a DictionaryValue from |base_|. When
// |include_nested_pools| is set, the descriptions of the non-NULL nested pools
// are attached under the "nested_pools" key. Presumably the caller takes
// ownership of the returned dictionary.
DictionaryValue* SSLClientSocketPool::GetInfoAsValue(
    const std::string& name,
    const std::string& type,
    bool include_nested_pools) const {
  DictionaryValue* info = base_.GetInfoAsValue(name, type);
  if (!include_nested_pools)
    return info;

  ListValue* nested = new ListValue();
  // The TCP pool is queried without nesting; the SOCKS and HTTP proxy pools
  // are asked to include their own nested pools.
  if (tcp_pool_) {
    nested->Append(tcp_pool_->GetInfoAsValue("tcp_socket_pool",
                                             "tcp_socket_pool",
                                             false));
  }
  if (socks_pool_) {
    nested->Append(socks_pool_->GetInfoAsValue("socks_pool",
                                               "socks_pool",
                                               true));
  }
  if (http_proxy_pool_) {
    nested->Append(http_proxy_pool_->GetInfoAsValue("http_proxy_pool",
                                                    "http_proxy_pool",
                                                    true));
  }
  info->Set("nested_pools", nested);
  return info;
}

// Connection timeout as reported by the underlying pool base.
base::TimeDelta SSLClientSocketPool::ConnectionTimeout() const {
  const base::TimeDelta timeout = base_.ConnectionTimeout();
  return timeout;
}

// Histograms object held by the underlying pool base.
ClientSocketPoolHistograms* SSLClientSocketPool::histograms() const {
  ClientSocketPoolHistograms* pool_histograms = base_.histograms();
  return pool_histograms;
}

}  // namespace net