diff options
author | akalin@chromium.org <akalin@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-08-15 00:13:44 +0000 |
---|---|---|
committer | akalin@chromium.org <akalin@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-08-15 00:13:44 +0000 |
commit | 18ccfdb7574c4868e37f53386454277e3e63bbe8 (patch) | |
tree | f1e177773e0b1cdc80deb3d755a8d7baf1233df6 | |
parent | 582a8575e5259762d5cb7b517b928ed7fc75ca11 (diff) | |
download | chromium_src-18ccfdb7574c4868e37f53386454277e3e63bbe8.zip chromium_src-18ccfdb7574c4868e37f53386454277e3e63bbe8.tar.gz chromium_src-18ccfdb7574c4868e37f53386454277e3e63bbe8.tar.bz2 |
[net] Use scoped_ptr<> consistently in ClientSocketFactory and related code
This will make it easier to modify ClientSocketFactory et al. to support
reprioritization. This also fixes a few latent memory leaks in tests.
Make SocketStream use a ClientSocketHandle instead of
just a StreamSocket.
Rename {set,release}_socket() to {Set,Pass}Socket().
BUG=166689
TBR=eroman@chromium.org, rsleevi@chromium.org, sergeyu@chromium.org
Review URL: https://codereview.chromium.org/22995002
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@217707 0039d316-1c4b-4281-b951-d872f2087c98
69 files changed, 727 insertions, 605 deletions
diff --git a/chrome/browser/net/network_stats.cc b/chrome/browser/net/network_stats.cc index 23a8a6a..1bc3161e 100644 --- a/chrome/browser/net/network_stats.cc +++ b/chrome/browser/net/network_stats.cc @@ -189,28 +189,25 @@ bool NetworkStats::DoConnect(int result) { return false; } - net::DatagramClientSocket* udp_socket = + scoped_ptr<net::DatagramClientSocket> udp_socket = socket_factory_->CreateDatagramClientSocket( net::DatagramSocket::DEFAULT_BIND, net::RandIntCallback(), NULL, net::NetLog::Source()); - if (!udp_socket) { - TestPhaseComplete(SOCKET_CREATE_FAILED, net::ERR_INVALID_ARGUMENT); - return false; - } - DCHECK(!socket_.get()); - socket_.reset(udp_socket); + DCHECK(udp_socket); + DCHECK(!socket_); + socket_ = udp_socket.Pass(); const net::IPEndPoint& endpoint = addresses_.front(); - int rv = udp_socket->Connect(endpoint); + int rv = socket_->Connect(endpoint); if (rv < 0) { TestPhaseComplete(CONNECT_FAILED, rv); return false; } - udp_socket->SetSendBufferSize(kMaxUdpSendBufferSize); - udp_socket->SetReceiveBufferSize(kMaxUdpReceiveBufferSize); + socket_->SetSendBufferSize(kMaxUdpSendBufferSize); + socket_->SetReceiveBufferSize(kMaxUdpReceiveBufferSize); return ConnectComplete(rv); } diff --git a/content/browser/renderer_host/p2p/socket_host_tcp.cc b/content/browser/renderer_host/p2p/socket_host_tcp.cc index 8e13bf8..314026e 100644 --- a/content/browser/renderer_host/p2p/socket_host_tcp.cc +++ b/content/browser/renderer_host/p2p/socket_host_tcp.cc @@ -136,7 +136,9 @@ void P2PSocketHostTcpBase::OnConnected(int result) { StartTls(); } else { if (IsPseudoTlsClientSocket(type_)) { - socket_.reset(new jingle_glue::FakeSSLClientSocket(socket_.release())); + scoped_ptr<net::StreamSocket> transport_socket = socket_.Pass(); + socket_.reset( + new jingle_glue::FakeSSLClientSocket(transport_socket.Pass())); } // If we are not doing TLS, we are ready to send data now. @@ -155,7 +157,7 @@ void P2PSocketHostTcpBase::StartTls() { scoped_ptr<net::ClientSocketHandle> socket_handle( new net::ClientSocketHandle()); - socket_handle->set_socket(socket_.release()); + socket_handle->SetSocket(socket_.Pass()); net::SSLClientSocketContext context; context.cert_verifier = url_context_->GetURLRequestContext()->cert_verifier(); @@ -171,8 +173,8 @@ void P2PSocketHostTcpBase::StartTls() { net::ClientSocketFactory::GetDefaultFactory(); DCHECK(socket_factory); - socket_.reset(socket_factory->CreateSSLClientSocket( - socket_handle.release(), dest_host_port_pair, ssl_config, context)); + socket_ = socket_factory->CreateSSLClientSocket( + socket_handle.Pass(), dest_host_port_pair, ssl_config, context); int status = socket_->Connect( base::Bind(&P2PSocketHostTcpBase::ProcessTlsConnectDone, base::Unretained(this))); diff --git a/content/browser/renderer_host/pepper/pepper_tcp_socket.cc b/content/browser/renderer_host/pepper/pepper_tcp_socket.cc index f05be87..05b2e09 100644 --- a/content/browser/renderer_host/pepper/pepper_tcp_socket.cc +++ b/content/browser/renderer_host/pepper/pepper_tcp_socket.cc @@ -141,16 +141,16 @@ void PepperTCPSocket::SSLHandshake( connection_state_ = SSL_HANDSHAKE_IN_PROGRESS; // TODO(raymes,rsleevi): Use trusted/untrusted certificates when connecting. - net::ClientSocketHandle* handle = new net::ClientSocketHandle(); - handle->set_socket(socket_.release()); + scoped_ptr<net::ClientSocketHandle> handle(new net::ClientSocketHandle()); + handle->SetSocket(socket_.Pass()); net::ClientSocketFactory* factory = net::ClientSocketFactory::GetDefaultFactory(); net::HostPortPair host_port_pair(server_name, server_port); net::SSLClientSocketContext ssl_context; ssl_context.cert_verifier = manager_->GetCertVerifier(); ssl_context.transport_security_state = manager_->GetTransportSecurityState(); - socket_.reset(factory->CreateSSLClientSocket( - handle, host_port_pair, manager_->ssl_config(), ssl_context)); + socket_ = factory->CreateSSLClientSocket( + handle.Pass(), host_port_pair, manager_->ssl_config(), ssl_context); if (!socket_) { LOG(WARNING) << "Failed to create an SSL client socket."; OnSSLHandshakeCompleted(net::ERR_UNEXPECTED); diff --git a/jingle/glue/chrome_async_socket.cc b/jingle/glue/chrome_async_socket.cc index 39085e1..c14fb99 100644 --- a/jingle/glue/chrome_async_socket.cc +++ b/jingle/glue/chrome_async_socket.cc @@ -106,9 +106,9 @@ bool ChromeAsyncSocket::Connect(const talk_base::SocketAddress& address) { net::HostPortPair dest_host_port_pair(address.hostname(), address.port()); - transport_socket_.reset( + transport_socket_ = resolving_client_socket_factory_->CreateTransportClientSocket( - dest_host_port_pair)); + dest_host_port_pair); int status = transport_socket_->Connect( base::Bind(&ChromeAsyncSocket::ProcessConnectDone, weak_ptr_factory_.GetWeakPtr())); @@ -404,10 +404,10 @@ bool ChromeAsyncSocket::StartTls(const std::string& domain_name) { DCHECK(transport_socket_.get()); scoped_ptr<net::ClientSocketHandle> socket_handle( new net::ClientSocketHandle()); - socket_handle->set_socket(transport_socket_.release()); - transport_socket_.reset( + socket_handle->SetSocket(transport_socket_.Pass()); + transport_socket_ = resolving_client_socket_factory_->CreateSSLClientSocket( - socket_handle.release(), net::HostPortPair(domain_name, 443))); + socket_handle.Pass(), net::HostPortPair(domain_name, 443)); int status = transport_socket_->Connect( base::Bind(&ChromeAsyncSocket::ProcessSSLConnectDone, weak_ptr_factory_.GetWeakPtr())); diff --git a/jingle/glue/chrome_async_socket_unittest.cc b/jingle/glue/chrome_async_socket_unittest.cc index ebb69a2..db3d2b0 100644 --- a/jingle/glue/chrome_async_socket_unittest.cc +++ b/jingle/glue/chrome_async_socket_unittest.cc @@ -113,20 +113,20 @@ class MockXmppClientSocketFactory : public ResolvingClientSocketFactory { } // ResolvingClientSocketFactory implementation. - virtual net::StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<net::StreamSocket> CreateTransportClientSocket( const net::HostPortPair& host_and_port) OVERRIDE { return mock_client_socket_factory_->CreateTransportClientSocket( address_list_, NULL, net::NetLog::Source()); } - virtual net::SSLClientSocket* CreateSSLClientSocket( - net::ClientSocketHandle* transport_socket, + virtual scoped_ptr<net::SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<net::ClientSocketHandle> transport_socket, const net::HostPortPair& host_and_port) OVERRIDE { net::SSLClientSocketContext context; context.cert_verifier = cert_verifier_.get(); context.transport_security_state = transport_security_state_.get(); return mock_client_socket_factory_->CreateSSLClientSocket( - transport_socket, host_and_port, ssl_config_, context); + transport_socket.Pass(), host_and_port, ssl_config_, context); } private: diff --git a/jingle/glue/fake_ssl_client_socket.cc b/jingle/glue/fake_ssl_client_socket.cc index bf6d12a..9d722c7 100644 --- a/jingle/glue/fake_ssl_client_socket.cc +++ b/jingle/glue/fake_ssl_client_socket.cc @@ -77,8 +77,8 @@ base::StringPiece FakeSSLClientSocket::GetSslServerHello() { } FakeSSLClientSocket::FakeSSLClientSocket( - net::StreamSocket* transport_socket) - : transport_socket_(transport_socket), + scoped_ptr<net::StreamSocket> transport_socket) + : transport_socket_(transport_socket.Pass()), next_handshake_state_(STATE_NONE), handshake_completed_(false), write_buf_(NewDrainableIOBufferWithSize(arraysize(kSslClientHello))), diff --git a/jingle/glue/fake_ssl_client_socket.h b/jingle/glue/fake_ssl_client_socket.h index 5bc4547..54a9e2f 100644 --- a/jingle/glue/fake_ssl_client_socket.h +++ b/jingle/glue/fake_ssl_client_socket.h @@ -36,8 +36,7 @@ namespace jingle_glue { class FakeSSLClientSocket : public net::StreamSocket { public: - // Takes ownership of |transport_socket|. - explicit FakeSSLClientSocket(net::StreamSocket* transport_socket); + explicit FakeSSLClientSocket(scoped_ptr<net::StreamSocket> transport_socket); virtual ~FakeSSLClientSocket(); diff --git a/jingle/glue/fake_ssl_client_socket_unittest.cc b/jingle/glue/fake_ssl_client_socket_unittest.cc index 5c061f3..f6d8fea 100644 --- a/jingle/glue/fake_ssl_client_socket_unittest.cc +++ b/jingle/glue/fake_ssl_client_socket_unittest.cc @@ -91,7 +91,7 @@ class FakeSSLClientSocketTest : public testing::Test { virtual ~FakeSSLClientSocketTest() {} - net::StreamSocket* MakeClientSocket() { + scoped_ptr<net::StreamSocket> MakeClientSocket() { return mock_client_socket_factory_.CreateTransportClientSocket( net::AddressList(), NULL, net::NetLog::Source()); } @@ -269,7 +269,7 @@ class FakeSSLClientSocketTest : public testing::Test { }; TEST_F(FakeSSLClientSocketTest, PassThroughMethods) { - MockClientSocket* mock_client_socket = new MockClientSocket(); + scoped_ptr<MockClientSocket> mock_client_socket(new MockClientSocket()); const int kReceiveBufferSize = 10; const int kSendBufferSize = 20; net::IPEndPoint ip_endpoint(net::IPAddressNumber(net::kIPv4AddressSize), 80); @@ -284,7 +284,8 @@ TEST_F(FakeSSLClientSocketTest, PassThroughMethods) { EXPECT_CALL(*mock_client_socket, SetOmniboxSpeculation()); // Takes ownership of |mock_client_socket|. - FakeSSLClientSocket fake_ssl_client_socket(mock_client_socket); + FakeSSLClientSocket fake_ssl_client_socket( + mock_client_socket.PassAs<net::StreamSocket>()); fake_ssl_client_socket.SetReceiveBufferSize(kReceiveBufferSize); fake_ssl_client_socket.SetSendBufferSize(kSendBufferSize); EXPECT_EQ(kPeerAddress, diff --git a/jingle/glue/resolving_client_socket_factory.h b/jingle/glue/resolving_client_socket_factory.h index 5be8bc8..d1b9fc1 100644 --- a/jingle/glue/resolving_client_socket_factory.h +++ b/jingle/glue/resolving_client_socket_factory.h @@ -5,6 +5,7 @@ #ifndef JINGLE_GLUE_RESOLVING_CLIENT_SOCKET_FACTORY_H_ #define JINGLE_GLUE_RESOLVING_CLIENT_SOCKET_FACTORY_H_ +#include "base/memory/scoped_ptr.h" namespace net { class ClientSocketHandle; @@ -23,11 +24,11 @@ class ResolvingClientSocketFactory { public: virtual ~ResolvingClientSocketFactory() { } // Method to create a transport socket using a HostPortPair. - virtual net::StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<net::StreamSocket> CreateTransportClientSocket( const net::HostPortPair& host_and_port) = 0; - virtual net::SSLClientSocket* CreateSSLClientSocket( - net::ClientSocketHandle* transport_socket, + virtual scoped_ptr<net::SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<net::ClientSocketHandle> transport_socket, const net::HostPortPair& host_and_port) = 0; }; diff --git a/jingle/glue/xmpp_client_socket_factory.cc b/jingle/glue/xmpp_client_socket_factory.cc index b9e040d..4823ee5 100644 --- a/jingle/glue/xmpp_client_socket_factory.cc +++ b/jingle/glue/xmpp_client_socket_factory.cc @@ -8,6 +8,7 @@ #include "jingle/glue/fake_ssl_client_socket.h" #include "jingle/glue/proxy_resolving_client_socket.h" #include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" #include "net/socket/ssl_client_socket.h" #include "net/url_request/url_request_context.h" #include "net/url_request/url_request_context_getter.h" @@ -28,20 +29,25 @@ XmppClientSocketFactory::XmppClientSocketFactory( XmppClientSocketFactory::~XmppClientSocketFactory() {} -net::StreamSocket* XmppClientSocketFactory::CreateTransportClientSocket( +scoped_ptr<net::StreamSocket> +XmppClientSocketFactory::CreateTransportClientSocket( const net::HostPortPair& host_and_port) { // TODO(akalin): Use socket pools. - net::StreamSocket* transport_socket = new ProxyResolvingClientSocket( - NULL, - request_context_getter_, - ssl_config_, - host_and_port); + scoped_ptr<net::StreamSocket> transport_socket( + new ProxyResolvingClientSocket( + NULL, + request_context_getter_, + ssl_config_, + host_and_port)); return (use_fake_ssl_client_socket_ ? - new FakeSSLClientSocket(transport_socket) : transport_socket); + scoped_ptr<net::StreamSocket>( + new FakeSSLClientSocket(transport_socket.Pass())) : + transport_socket.Pass()); } -net::SSLClientSocket* XmppClientSocketFactory::CreateSSLClientSocket( - net::ClientSocketHandle* transport_socket, +scoped_ptr<net::SSLClientSocket> +XmppClientSocketFactory::CreateSSLClientSocket( + scoped_ptr<net::ClientSocketHandle> transport_socket, const net::HostPortPair& host_and_port) { net::SSLClientSocketContext context; context.cert_verifier = @@ -52,7 +58,7 @@ net::SSLClientSocket* XmppClientSocketFactory::CreateSSLClientSocket( // TODO(rkn): context.server_bound_cert_service is NULL because the // ServerBoundCertService class is not thread safe. return client_socket_factory_->CreateSSLClientSocket( - transport_socket, host_and_port, ssl_config_, context); + transport_socket.Pass(), host_and_port, ssl_config_, context); } diff --git a/jingle/glue/xmpp_client_socket_factory.h b/jingle/glue/xmpp_client_socket_factory.h index c2a0d6a..4204c19 100644 --- a/jingle/glue/xmpp_client_socket_factory.h +++ b/jingle/glue/xmpp_client_socket_factory.h @@ -35,11 +35,11 @@ class XmppClientSocketFactory : public ResolvingClientSocketFactory { virtual ~XmppClientSocketFactory(); // ResolvingClientSocketFactory implementation. - virtual net::StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<net::StreamSocket> CreateTransportClientSocket( const net::HostPortPair& host_and_port) OVERRIDE; - virtual net::SSLClientSocket* CreateSSLClientSocket( - net::ClientSocketHandle* transport_socket, + virtual scoped_ptr<net::SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<net::ClientSocketHandle> transport_socket, const net::HostPortPair& host_and_port) OVERRIDE; private: diff --git a/net/dns/address_sorter_posix_unittest.cc b/net/dns/address_sorter_posix_unittest.cc index 96cbfc6..c451737 100644 --- a/net/dns/address_sorter_posix_unittest.cc +++ b/net/dns/address_sorter_posix_unittest.cc @@ -10,6 +10,8 @@ #include "net/base/net_util.h" #include "net/base/test_completion_callback.h" #include "net/socket/client_socket_factory.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/stream_socket.h" #include "net/udp/datagram_client_socket.h" #include "testing/gtest/include/gtest/gtest.h" @@ -90,27 +92,27 @@ class TestSocketFactory : public ClientSocketFactory { TestSocketFactory() {} virtual ~TestSocketFactory() {} - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType, const RandIntCallback&, NetLog*, const NetLog::Source&) OVERRIDE { - return new TestUDPClientSocket(&mapping_); + return scoped_ptr<DatagramClientSocket>(new TestUDPClientSocket(&mapping_)); } - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList&, NetLog*, const NetLog::Source&) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<StreamSocket>(); } - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle*, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle>, const HostPortPair&, const SSLConfig&, const SSLClientSocketContext&) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<SSLClientSocket>(); } virtual void ClearSSLSessionCache() OVERRIDE { NOTIMPLEMENTED(); diff --git a/net/dns/dns_session_unittest.cc b/net/dns/dns_session_unittest.cc index 4662706..ed726f2 100644 --- a/net/dns/dns_session_unittest.cc +++ b/net/dns/dns_session_unittest.cc @@ -14,6 +14,8 @@ #include "net/dns/dns_protocol.h" #include "net/dns/dns_socket_pool.h" #include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/stream_socket.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { @@ -24,26 +26,26 @@ class TestClientSocketFactory : public ClientSocketFactory { public: virtual ~TestClientSocketFactory(); - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, net::NetLog* net_log, const net::NetLog::Source& source) OVERRIDE; - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog*, const NetLog::Source&) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<StreamSocket>(); } - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<SSLClientSocket>(); } virtual void ClearSSLSessionCache() OVERRIDE { @@ -179,7 +181,8 @@ bool DnsSessionTest::ExpectEvent(const PoolEvent& expected) { return true; } -DatagramClientSocket* TestClientSocketFactory::CreateDatagramClientSocket( +scoped_ptr<DatagramClientSocket> +TestClientSocketFactory::CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, net::NetLog* net_log, @@ -188,9 +191,10 @@ DatagramClientSocket* TestClientSocketFactory::CreateDatagramClientSocket( // simplest SocketDataProvider with no data supplied. SocketDataProvider* data_provider = new StaticSocketDataProvider(); data_providers_.push_back(data_provider); - MockUDPClientSocket* socket = new MockUDPClientSocket(data_provider, net_log); - data_provider->set_socket(socket); - return socket; + scoped_ptr<MockUDPClientSocket> socket( + new MockUDPClientSocket(data_provider, net_log)); + data_provider->set_socket(socket.get()); + return socket.PassAs<DatagramClientSocket>(); } TestClientSocketFactory::~TestClientSocketFactory() { diff --git a/net/dns/dns_socket_pool.cc b/net/dns/dns_socket_pool.cc index 64570fc..7a7ecd6 100644 --- a/net/dns/dns_socket_pool.cc +++ b/net/dns/dns_socket_pool.cc @@ -76,8 +76,8 @@ scoped_ptr<DatagramClientSocket> DnsSocketPool::CreateConnectedSocket( scoped_ptr<DatagramClientSocket> socket; NetLog::Source no_source; - socket.reset(socket_factory_->CreateDatagramClientSocket( - kBindType, base::Bind(&base::RandInt), net_log_, no_source)); + socket = socket_factory_->CreateDatagramClientSocket( + kBindType, base::Bind(&base::RandInt), net_log_, no_source); if (socket.get()) { int rv = socket->Connect((*nameservers_)[server_index]); diff --git a/net/dns/dns_transaction_unittest.cc b/net/dns/dns_transaction_unittest.cc index f9667ee..7040e44 100644 --- a/net/dns/dns_transaction_unittest.cc +++ b/net/dns/dns_transaction_unittest.cc @@ -180,21 +180,21 @@ class TestSocketFactory : public MockClientSocketFactory { TestSocketFactory() : fail_next_socket_(false) {} virtual ~TestSocketFactory() {} - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, net::NetLog* net_log, const net::NetLog::Source& source) OVERRIDE { if (fail_next_socket_) { fail_next_socket_ = false; - return new FailingUDPClientSocket(&empty_data_, net_log); + return scoped_ptr<DatagramClientSocket>( + new FailingUDPClientSocket(&empty_data_, net_log)); } SocketDataProvider* data_provider = mock_data().GetNext(); - TestUDPClientSocket* socket = new TestUDPClientSocket(this, - data_provider, - net_log); - data_provider->set_socket(socket); - return socket; + scoped_ptr<TestUDPClientSocket> socket( + new TestUDPClientSocket(this, data_provider, net_log)); + data_provider->set_socket(socket.get()); + return socket.PassAs<DatagramClientSocket>(); } void OnConnect(const IPEndPoint& endpoint) { diff --git a/net/ftp/ftp_network_transaction.cc b/net/ftp/ftp_network_transaction.cc index ccd6e2e..f9f7b82 100644 --- a/net/ftp/ftp_network_transaction.cc +++ b/net/ftp/ftp_network_transaction.cc @@ -663,8 +663,8 @@ int FtpNetworkTransaction::DoCtrlResolveHostComplete(int result) { int FtpNetworkTransaction::DoCtrlConnect() { next_state_ = STATE_CTRL_CONNECT_COMPLETE; - ctrl_socket_.reset(socket_factory_->CreateTransportClientSocket( - addresses_, net_log_.net_log(), net_log_.source())); + ctrl_socket_ = socket_factory_->CreateTransportClientSocket( + addresses_, net_log_.net_log(), net_log_.source()); net_log_.AddEvent( NetLog::TYPE_FTP_CONTROL_CONNECTION, ctrl_socket_->NetLog().source().ToEventParametersCallback()); @@ -1249,8 +1249,8 @@ int FtpNetworkTransaction::DoDataConnect() { return Stop(rv); data_address = AddressList::CreateFromIPAddress( ip_endpoint.address(), data_connection_port_); - data_socket_.reset(socket_factory_->CreateTransportClientSocket( - data_address, net_log_.net_log(), net_log_.source())); + data_socket_ = socket_factory_->CreateTransportClientSocket( + data_address, net_log_.net_log(), net_log_.source()); net_log_.AddEvent( NetLog::TYPE_FTP_DATA_CONNECTION, data_socket_->NetLog().source().ToEventParametersCallback()); diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index 0968b14..d89ab54 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -452,7 +452,7 @@ class CaptureGroupNameSocketPool : public ParentPool { virtual void CancelRequest(const std::string& group_name, ClientSocketHandle* handle) {} virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) {} virtual void CloseIdleSockets() {} virtual int IdleSocketCount() const { diff --git a/net/http/http_pipelined_host_forced.cc b/net/http/http_pipelined_host_forced.cc index 8179e86..8059d84 100644 --- a/net/http/http_pipelined_host_forced.cc +++ b/net/http/http_pipelined_host_forced.cc @@ -36,10 +36,9 @@ HttpPipelinedStream* HttpPipelinedHostForced::CreateStreamOnNewPipeline( bool was_npn_negotiated, NextProto protocol_negotiated) { CHECK(!pipeline_.get()); - StreamSocket* wrapped_socket = connection->release_socket(); - BufferedWriteStreamSocket* buffered_socket = new BufferedWriteStreamSocket( - wrapped_socket); - connection->set_socket(buffered_socket); + scoped_ptr<BufferedWriteStreamSocket> buffered_socket( + new BufferedWriteStreamSocket(connection->PassSocket())); + connection->SetSocket(buffered_socket.PassAs<StreamSocket>()); pipeline_.reset(factory_->CreateNewPipeline( connection, this, key_.origin(), used_ssl_config, used_proxy_info, net_log, was_npn_negotiated, protocol_negotiated)); diff --git a/net/http/http_proxy_client_socket_pool.cc b/net/http/http_proxy_client_socket_pool.cc index b80df37..c75df6f 100644 --- a/net/http/http_proxy_client_socket_pool.cc +++ b/net/http/http_proxy_client_socket_pool.cc @@ -289,7 +289,7 @@ int HttpProxyConnectJob::DoHttpProxyConnect() { int HttpProxyConnectJob::DoHttpProxyConnectComplete(int result) { if (result == OK || result == ERR_PROXY_AUTH_REQUESTED || result == ERR_HTTPS_PROXY_TUNNEL_RESPONSE) { - set_socket(transport_socket_.release()); + SetSocket(transport_socket_.PassAs<StreamSocket>()); } return result; @@ -380,19 +380,19 @@ HttpProxyConnectJobFactory::HttpProxyConnectJobFactory( } -ConnectJob* +scoped_ptr<ConnectJob> HttpProxyClientSocketPool::HttpProxyConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return new HttpProxyConnectJob(group_name, - request.params(), - ConnectionTimeout(), - transport_pool_, - ssl_pool_, - host_resolver_, - delegate, - net_log_); + return scoped_ptr<ConnectJob>(new HttpProxyConnectJob(group_name, + request.params(), + ConnectionTimeout(), + transport_pool_, + ssl_pool_, + host_resolver_, + delegate, + net_log_)); } base::TimeDelta @@ -462,8 +462,9 @@ void HttpProxyClientSocketPool::CancelRequest( } void HttpProxyClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { - base_.ReleaseSocket(group_name, socket, id); + scoped_ptr<StreamSocket> socket, + int id) { + base_.ReleaseSocket(group_name, socket.Pass(), id); } void HttpProxyClientSocketPool::FlushWithError(int error) { diff --git a/net/http/http_proxy_client_socket_pool.h b/net/http/http_proxy_client_socket_pool.h index a15b8ca..b77b5ae 100644 --- a/net/http/http_proxy_client_socket_pool.h +++ b/net/http/http_proxy_client_socket_pool.h @@ -204,7 +204,7 @@ class NET_EXPORT_PRIVATE HttpProxyClientSocketPool ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE; virtual void FlushWithError(int error) OVERRIDE; @@ -250,7 +250,7 @@ class NET_EXPORT_PRIVATE HttpProxyClientSocketPool NetLog* net_log); // ClientSocketPoolBase::ConnectJobFactory methods. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE; diff --git a/net/http/http_stream_factory_impl_unittest.cc b/net/http/http_stream_factory_impl_unittest.cc index 14fbc03..f378c93 100644 --- a/net/http/http_stream_factory_impl_unittest.cc +++ b/net/http/http_stream_factory_impl_unittest.cc @@ -314,7 +314,7 @@ class CapturePreconnectsSocketPool : public ParentPool { ADD_FAILURE(); } virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE { ADD_FAILURE(); } diff --git a/net/http/http_stream_parser_unittest.cc b/net/http/http_stream_parser_unittest.cc index d530c2d..8477594 100644 --- a/net/http/http_stream_parser_unittest.cc +++ b/net/http/http_stream_parser_unittest.cc @@ -220,7 +220,7 @@ TEST(HttpStreamParser, AsyncChunkAndAsyncSocket) { ASSERT_EQ(OK, rv); scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle); - socket_handle->set_socket(transport.release()); + socket_handle->SetSocket(transport.PassAs<StreamSocket>()); HttpRequestInfo request_info; request_info.method = "GET"; @@ -375,7 +375,7 @@ TEST(HttpStreamParser, TruncatedHeaders) { ASSERT_EQ(OK, rv); scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle); - socket_handle->set_socket(transport.release()); + socket_handle->SetSocket(transport.PassAs<StreamSocket>()); HttpRequestInfo request_info; request_info.method = "GET"; diff --git a/net/quic/quic_client_session.cc b/net/quic/quic_client_session.cc index f6641d0..1ba03a3 100644 --- a/net/quic/quic_client_session.cc +++ b/net/quic/quic_client_session.cc @@ -81,7 +81,7 @@ void QuicClientSession::StreamRequest::OnRequestCompleteFailure(int rv) { QuicClientSession::QuicClientSession( QuicConnection* connection, - DatagramClientSocket* socket, + scoped_ptr<DatagramClientSocket> socket, QuicStreamFactory* stream_factory, QuicCryptoClientStreamFactory* crypto_client_stream_factory, const string& server_hostname, @@ -91,7 +91,7 @@ QuicClientSession::QuicClientSession( : QuicSession(connection, config, false), weak_factory_(this), stream_factory_(stream_factory), - socket_(socket), + socket_(socket.Pass()), read_buffer_(new IOBufferWithSize(kMaxPacketSize)), read_pending_(false), num_total_streams_(0), diff --git a/net/quic/quic_client_session.h b/net/quic/quic_client_session.h index d124fdb..555837f 100644 --- a/net/quic/quic_client_session.h +++ b/net/quic/quic_client_session.h @@ -13,6 +13,7 @@ #include <string> #include "base/containers/hash_tables.h" +#include "base/memory/scoped_ptr.h" #include "net/base/completion_callback.h" #include "net/quic/quic_connection_logger.h" #include "net/quic/quic_crypto_client_stream.h" @@ -74,7 +75,7 @@ class NET_EXPORT_PRIVATE QuicClientSession : public QuicSession { // not |stream_factory|, which must outlive this session. // TODO(rch): decouple the factory from the session via a Delegate interface. QuicClientSession(QuicConnection* connection, - DatagramClientSocket* socket, + scoped_ptr<DatagramClientSocket> socket, QuicStreamFactory* stream_factory, QuicCryptoClientStreamFactory* crypto_client_stream_factory, const std::string& server_hostname, diff --git a/net/quic/quic_client_session_test.cc b/net/quic/quic_client_session_test.cc index 6113f45..09e7d21 100644 --- a/net/quic/quic_client_session_test.cc +++ b/net/quic/quic_client_session_test.cc @@ -15,6 +15,7 @@ #include "net/quic/test_tools/crypto_test_utils.h" #include "net/quic/test_tools/quic_client_session_peer.h" #include "net/quic/test_tools/quic_test_utils.h" +#include "net/udp/datagram_client_socket.h" using testing::_; @@ -29,8 +30,9 @@ class QuicClientSessionTest : public ::testing::Test { QuicClientSessionTest() : guid_(1), connection_(new PacketSavingConnection(guid_, IPEndPoint(), false)), - session_(connection_, NULL, NULL, NULL, kServerHostname, - DefaultQuicConfig(), &crypto_config_, &net_log_) { + session_(connection_, scoped_ptr<DatagramClientSocket>(), NULL, + NULL, kServerHostname, DefaultQuicConfig(), &crypto_config_, + &net_log_) { session_.config()->SetDefaults(); crypto_config_.SetDefaults(); } diff --git a/net/quic/quic_http_stream_test.cc b/net/quic/quic_http_stream_test.cc index b378416..1e4ac91 100644 --- a/net/quic/quic_http_stream_test.cc +++ b/net/quic/quic_http_stream_test.cc @@ -179,10 +179,12 @@ class QuicHttpStreamTest : public ::testing::TestWithParam<bool> { connection_->SetSendAlgorithm(send_algorithm_); connection_->SetReceiveAlgorithm(receive_algorithm_); crypto_config_.SetDefaults(); - session_.reset(new QuicClientSession(connection_, socket, NULL, - &crypto_client_stream_factory_, - "www.google.com", DefaultQuicConfig(), - &crypto_config_, NULL)); + session_.reset( + new QuicClientSession(connection_, + scoped_ptr<DatagramClientSocket>(socket), NULL, + &crypto_client_stream_factory_, + "www.google.com", DefaultQuicConfig(), + &crypto_config_, NULL)); session_->GetCryptoStream()->CryptoConnect(); EXPECT_TRUE(session_->IsCryptoHandshakeConfirmed()); stream_.reset(use_closing_stream_ ? diff --git a/net/quic/quic_stream_factory.cc b/net/quic/quic_stream_factory.cc index fba7f0b..86bd8a1 100644 --- a/net/quic/quic_stream_factory.cc +++ b/net/quic/quic_stream_factory.cc @@ -408,10 +408,10 @@ QuicClientSession* QuicStreamFactory::CreateSession( const BoundNetLog& net_log) { QuicGuid guid = random_generator_->RandUint64(); IPEndPoint addr = *address_list.begin(); - DatagramClientSocket* socket = + scoped_ptr<DatagramClientSocket> socket( client_socket_factory_->CreateDatagramClientSocket( DatagramSocket::DEFAULT_BIND, base::Bind(&base::RandInt), - net_log.net_log(), net_log.source()); + net_log.net_log(), net_log.source())); socket->Connect(addr); // We should adaptively set this buffer size, but for now, we'll use a size @@ -437,7 +437,7 @@ QuicClientSession* QuicStreamFactory::CreateSession( base::MessageLoop::current()->message_loop_proxy().get(), clock_.get(), random_generator_, - socket); + socket.get()); QuicConnection* connection = new QuicConnection(guid, addr, helper, false, QuicVersionMax()); @@ -447,7 +447,7 @@ QuicClientSession* QuicStreamFactory::CreateSession( DCHECK(crypto_config); QuicClientSession* session = - new QuicClientSession(connection, socket, this, + new QuicClientSession(connection, socket.Pass(), this, quic_crypto_client_stream_factory_, host_port_proxy_pair.first.host(), config_, crypto_config, net_log.net_log()); diff --git a/net/socket/buffered_write_stream_socket.cc b/net/socket/buffered_write_stream_socket.cc index 36b9df7..cf13c5e 100644 --- a/net/socket/buffered_write_stream_socket.cc +++ b/net/socket/buffered_write_stream_socket.cc @@ -23,8 +23,8 @@ void AppendBuffer(GrowableIOBuffer* dst, IOBuffer* src, int src_len) { } // anonymous namespace BufferedWriteStreamSocket::BufferedWriteStreamSocket( - StreamSocket* socket_to_wrap) - : wrapped_socket_(socket_to_wrap), + scoped_ptr<StreamSocket> socket_to_wrap) + : wrapped_socket_(socket_to_wrap.Pass()), io_buffer_(new GrowableIOBuffer()), backup_buffer_(new GrowableIOBuffer()), weak_factory_(this), diff --git a/net/socket/buffered_write_stream_socket.h b/net/socket/buffered_write_stream_socket.h index fcb33a8..aad5736 100644 --- a/net/socket/buffered_write_stream_socket.h +++ b/net/socket/buffered_write_stream_socket.h @@ -5,6 +5,8 @@ #ifndef NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_ #define NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_ +#include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" #include "base/memory/weak_ptr.h" #include "net/base/net_log.h" #include "net/socket/stream_socket.h" @@ -33,7 +35,7 @@ class IPEndPoint; // There are no bounds on the local buffer size. Use carefully. class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket { public: - BufferedWriteStreamSocket(StreamSocket* socket_to_wrap); + explicit BufferedWriteStreamSocket(scoped_ptr<StreamSocket> socket_to_wrap); virtual ~BufferedWriteStreamSocket(); // Socket interface @@ -71,6 +73,8 @@ class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket { bool callback_pending_; bool wrapped_write_in_progress_; int error_; + + DISALLOW_COPY_AND_ASSIGN(BufferedWriteStreamSocket); }; } // namespace net diff --git a/net/socket/buffered_write_stream_socket_unittest.cc b/net/socket/buffered_write_stream_socket_unittest.cc index e579a7f..485295f 100644 --- a/net/socket/buffered_write_stream_socket_unittest.cc +++ b/net/socket/buffered_write_stream_socket_unittest.cc @@ -30,10 +30,11 @@ class BufferedWriteStreamSocketTest : public testing::Test { if (writes_count) { data_->StopAfter(writes_count); } - DeterministicMockTCPClientSocket* wrapped_socket = - new DeterministicMockTCPClientSocket(net_log_.net_log(), data_.get()); + scoped_ptr<DeterministicMockTCPClientSocket> wrapped_socket( + new DeterministicMockTCPClientSocket(net_log_.net_log(), data_.get())); data_->set_delegate(wrapped_socket->AsWeakPtr()); - socket_.reset(new BufferedWriteStreamSocket(wrapped_socket)); + socket_.reset(new BufferedWriteStreamSocket( + wrapped_socket.PassAs<StreamSocket>())); socket_->Connect(callback_.callback()); } diff --git a/net/socket/client_socket_factory.cc b/net/socket/client_socket_factory.cc index 6d93034..a86688e 100644 --- a/net/socket/client_socket_factory.cc +++ b/net/socket/client_socket_factory.cc @@ -67,23 +67,25 @@ class DefaultClientSocketFactory : public ClientSocketFactory, ClearSSLSessionCache(); } - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE { - return new UDPClientSocket(bind_type, rand_int_cb, net_log, source); + return scoped_ptr<DatagramClientSocket>( + new UDPClientSocket(bind_type, rand_int_cb, net_log, source)); } - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, const NetLog::Source& source) OVERRIDE { - return new TCPClientSocket(addresses, net_log, source); + return scoped_ptr<StreamSocket>( + new TCPClientSocket(addresses, net_log, source)); } - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE { @@ -102,17 +104,19 @@ class DefaultClientSocketFactory : public ClientSocketFactory, nss_task_runner = base::ThreadTaskRunnerHandle::Get(); #if defined(USE_OPENSSL) - return new SSLClientSocketOpenSSL(transport_socket, host_and_port, - ssl_config, context); + return scoped_ptr<SSLClientSocket>( + new SSLClientSocketOpenSSL(transport_socket.Pass(), host_and_port, + ssl_config, context)); #elif defined(USE_NSS) || defined(OS_MACOSX) || defined(OS_WIN) - return new SSLClientSocketNSS(nss_task_runner.get(), - transport_socket, - host_and_port, - ssl_config, - context); + return scoped_ptr<SSLClientSocket>( + new SSLClientSocketNSS(nss_task_runner.get(), + transport_socket.Pass(), + host_and_port, + ssl_config, + context)); #else NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<SSLClientSocket>(); #endif } diff --git a/net/socket/client_socket_factory.h b/net/socket/client_socket_factory.h index a78fc48..6cb5949 100644 --- a/net/socket/client_socket_factory.h +++ b/net/socket/client_socket_factory.h @@ -8,6 +8,7 @@ #include <string> #include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" #include "net/base/net_export.h" #include "net/base/net_log.h" #include "net/base/rand_callback.h" @@ -32,13 +33,13 @@ class NET_EXPORT ClientSocketFactory { // |source| is the NetLog::Source for the entity trying to create the socket, // if it has one. - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) = 0; - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, const NetLog::Source& source) = 0; @@ -46,8 +47,8 @@ class NET_EXPORT ClientSocketFactory { // It is allowed to pass in a |transport_socket| that is not obtained from a // socket pool. The caller could create a ClientSocketHandle directly and call // set_socket() on it to set a valid StreamSocket instance. - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) = 0; diff --git a/net/socket/client_socket_handle.cc b/net/socket/client_socket_handle.cc index 3894fa7..acb896b 100644 --- a/net/socket/client_socket_handle.cc +++ b/net/socket/client_socket_handle.cc @@ -43,7 +43,7 @@ void ClientSocketHandle::ResetInternal(bool cancel) { if (pool_) // If we've still got a socket, release it back to the ClientSocketPool so // it can be deleted or reused. - pool_->ReleaseSocket(group_name_, release_socket(), pool_id_); + pool_->ReleaseSocket(group_name_, PassSocket(), pool_id_); } else if (cancel) { // If we did not get initialized yet, we've got a socket request pending. // Cancel it. @@ -121,6 +121,10 @@ bool ClientSocketHandle::GetLoadTimingInfo( return true; } +void ClientSocketHandle::SetSocket(scoped_ptr<StreamSocket> s) { + socket_ = s.Pass(); +} + void ClientSocketHandle::OnIOComplete(int result) { CompletionCallback callback = user_callback_; user_callback_.Reset(); @@ -128,6 +132,10 @@ void ClientSocketHandle::OnIOComplete(int result) { callback.Run(result); } +scoped_ptr<StreamSocket> ClientSocketHandle::PassSocket() { + return socket_.Pass(); +} + void ClientSocketHandle::HandleInitCompletion(int result) { CHECK_NE(ERR_IO_PENDING, result); ClientSocketPoolHistograms* histograms = pool_->histograms(); diff --git a/net/socket/client_socket_handle.h b/net/socket/client_socket_handle.h index 7d5588a..9651f08 100644 --- a/net/socket/client_socket_handle.h +++ b/net/socket/client_socket_handle.h @@ -116,8 +116,8 @@ class NET_EXPORT ClientSocketHandle { LoadTimingInfo* load_timing_info) const; // Used by ClientSocketPool to initialize the ClientSocketHandle. + void SetSocket(scoped_ptr<StreamSocket> s); void set_is_reused(bool is_reused) { is_reused_ = is_reused; } - void set_socket(StreamSocket* s) { socket_.reset(s); } void set_idle_time(base::TimeDelta idle_time) { idle_time_ = idle_time; } void set_pool_id(int id) { pool_id_ = id; } void set_is_ssl_error(bool is_ssl_error) { is_ssl_error_ = is_ssl_error; } @@ -144,10 +144,10 @@ class NET_EXPORT ClientSocketHandle { } // These may only be used if is_initialized() is true. + scoped_ptr<StreamSocket> PassSocket(); + StreamSocket* socket() { return socket_.get(); } const std::string& group_name() const { return group_name_; } int id() const { return pool_id_; } - StreamSocket* socket() { return socket_.get(); } - StreamSocket* release_socket() { return socket_.release(); } bool is_reused() const { return is_reused_; } base::TimeDelta idle_time() const { return idle_time_; } SocketReuseType reuse_type() const { diff --git a/net/socket/client_socket_pool.h b/net/socket/client_socket_pool.h index 7cb9a7e..af18454 100644 --- a/net/socket/client_socket_pool.h +++ b/net/socket/client_socket_pool.h @@ -10,6 +10,7 @@ #include "base/basictypes.h" #include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" #include "base/template_util.h" #include "base/time/time.h" #include "net/base/completion_callback.h" @@ -111,7 +112,7 @@ class NET_EXPORT ClientSocketPool { // change when it flushes, so it can use this |id| to discard sockets with // mismatched ids. virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) = 0; // This flushes all state from the ClientSocketPool. This means that all diff --git a/net/socket/client_socket_pool_base.cc b/net/socket/client_socket_pool_base.cc index 4a5b118..b1ddd40 100644 --- a/net/socket/client_socket_pool_base.cc +++ b/net/socket/client_socket_pool_base.cc @@ -82,6 +82,10 @@ ConnectJob::~ConnectJob() { net_log().EndEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB); } +scoped_ptr<StreamSocket> ConnectJob::PassSocket() { + return socket_.Pass(); +} + int ConnectJob::Connect() { if (timeout_duration_ != base::TimeDelta()) timer_.Start(FROM_HERE, timeout_duration_, this, &ConnectJob::OnTimeout); @@ -100,16 +104,16 @@ int ConnectJob::Connect() { return rv; } -void ConnectJob::set_socket(StreamSocket* socket) { +void ConnectJob::SetSocket(scoped_ptr<StreamSocket> socket) { if (socket) { net_log().AddEvent(NetLog::TYPE_CONNECT_JOB_SET_SOCKET, socket->NetLog().source().ToEventParametersCallback()); } - socket_.reset(socket); + socket_ = socket.Pass(); } void ConnectJob::NotifyDelegateOfCompletion(int rv) { - // The delegate will delete |this|. + // The delegate will own |this|. Delegate* delegate = delegate_; delegate_ = NULL; @@ -135,7 +139,7 @@ void ConnectJob::LogConnectCompletion(int net_error) { void ConnectJob::OnTimeout() { // Make sure the socket is NULL before calling into |delegate|. - set_socket(NULL); + SetSocket(scoped_ptr<StreamSocket>()); net_log_.AddEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_TIMED_OUT); @@ -392,11 +396,11 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( if (rv == OK) { LogBoundConnectJobToRequest(connect_job->net_log().source(), request); if (!preconnecting) { - HandOutSocket(connect_job->ReleaseSocket(), false /* not reused */, + HandOutSocket(connect_job->PassSocket(), false /* not reused */, connect_job->connect_timing(), handle, base::TimeDelta(), group, request->net_log()); } else { - AddIdleSocket(connect_job->ReleaseSocket(), group); + AddIdleSocket(connect_job->PassSocket(), group); } } else if (rv == ERR_IO_PENDING) { // If we don't have any sockets in this group, set a timer for potentially @@ -409,17 +413,17 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( connecting_socket_count_++; - group->AddJob(connect_job.release(), preconnecting); + group->AddJob(connect_job.Pass(), preconnecting); } else { LogBoundConnectJobToRequest(connect_job->net_log().source(), request); - StreamSocket* error_socket = NULL; + scoped_ptr<StreamSocket> error_socket; if (!preconnecting) { DCHECK(handle); connect_job->GetAdditionalErrorState(handle); - error_socket = connect_job->ReleaseSocket(); + error_socket = connect_job->PassSocket(); } if (error_socket) { - HandOutSocket(error_socket, false /* not reused */, + HandOutSocket(error_socket.Pass(), false /* not reused */, connect_job->connect_timing(), handle, base::TimeDelta(), group, request->net_log()); } else if (group->IsEmpty()) { @@ -469,7 +473,7 @@ bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest( IdleSocket idle_socket = *idle_socket_it; idle_sockets->erase(idle_socket_it); HandOutSocket( - idle_socket.socket, + scoped_ptr<StreamSocket>(idle_socket.socket), idle_socket.socket->WasEverUsed(), LoadTimingInfo::ConnectTiming(), request->handle(), @@ -495,11 +499,11 @@ void ClientSocketPoolBaseHelper::CancelRequest( if (callback_it != pending_callback_map_.end()) { int result = callback_it->second.result; pending_callback_map_.erase(callback_it); - StreamSocket* socket = handle->release_socket(); + scoped_ptr<StreamSocket> socket = handle->PassSocket(); if (socket) { if (result != OK) socket->Disconnect(); - ReleaseSocket(handle->group_name(), socket, handle->id()); + ReleaseSocket(handle->group_name(), socket.Pass(), handle->id()); } return; } @@ -756,7 +760,7 @@ void ClientSocketPoolBaseHelper::StartIdleSocketTimer() { } void ClientSocketPoolBaseHelper::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) { GroupMap::iterator i = group_map_.find(group_name); CHECK(i != group_map_.end()); @@ -773,10 +777,10 @@ void ClientSocketPoolBaseHelper::ReleaseSocket(const std::string& group_name, id == pool_generation_number_; if (can_reuse) { // Add it to the idle list. - AddIdleSocket(socket, group); + AddIdleSocket(socket.Pass(), group); OnAvailableSocketSlot(group_name, group); } else { - delete socket; + socket.reset(); } CheckForStalledSocketGroups(); @@ -854,13 +858,16 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete( CHECK(group_it != group_map_.end()); Group* group = group_it->second; - scoped_ptr<StreamSocket> socket(job->ReleaseSocket()); + scoped_ptr<StreamSocket> socket = job->PassSocket(); // Copies of these are needed because |job| may be deleted before they are // accessed. BoundNetLog job_log = job->net_log(); LoadTimingInfo::ConnectTiming connect_timing = job->connect_timing(); + // RemoveConnectJob(job, _) must be called by all branches below; + // otherwise, |job| will be leaked. + if (result == OK) { DCHECK(socket.get()); RemoveConnectJob(job, group); @@ -869,12 +876,12 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete( group->mutable_pending_requests()->begin(), group)); LogBoundConnectJobToRequest(job_log.source(), r.get()); HandOutSocket( - socket.release(), false /* unused socket */, connect_timing, + socket.Pass(), false /* unused socket */, connect_timing, r->handle(), base::TimeDelta(), group, r->net_log()); r->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL); InvokeUserCallbackLater(r->handle(), r->callback(), result); } else { - AddIdleSocket(socket.release(), group); + AddIdleSocket(socket.Pass(), group); OnAvailableSocketSlot(group_name, group); CheckForStalledSocketGroups(); } @@ -890,7 +897,7 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete( RemoveConnectJob(job, group); if (socket.get()) { handed_out_socket = true; - HandOutSocket(socket.release(), false /* unused socket */, + HandOutSocket(socket.Pass(), false /* unused socket */, connect_timing, r->handle(), base::TimeDelta(), group, r->net_log()); } @@ -975,7 +982,7 @@ void ClientSocketPoolBaseHelper::ProcessPendingRequest( } void ClientSocketPoolBaseHelper::HandOutSocket( - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, bool reused, const LoadTimingInfo::ConnectTiming& connect_timing, ClientSocketHandle* handle, @@ -983,7 +990,7 @@ void ClientSocketPoolBaseHelper::HandOutSocket( Group* group, const BoundNetLog& net_log) { DCHECK(socket); - handle->set_socket(socket); + handle->SetSocket(socket.Pass()); handle->set_is_reused(reused); handle->set_idle_time(idle_time); handle->set_pool_id(pool_generation_number_); @@ -996,18 +1003,20 @@ void ClientSocketPoolBaseHelper::HandOutSocket( "idle_ms", static_cast<int>(idle_time.InMilliseconds()))); } - net_log.AddEvent(NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET, - socket->NetLog().source().ToEventParametersCallback()); + net_log.AddEvent( + NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET, + handle->socket()->NetLog().source().ToEventParametersCallback()); handed_out_socket_count_++; group->IncrementActiveSocketCount(); } void ClientSocketPoolBaseHelper::AddIdleSocket( - StreamSocket* socket, Group* group) { + scoped_ptr<StreamSocket> socket, + Group* group) { DCHECK(socket); IdleSocket idle_socket; - idle_socket.socket = socket; + idle_socket.socket = socket.release(); idle_socket.start_time = base::TimeTicks::Now(); group->mutable_idle_sockets()->push_back(idle_socket); @@ -1178,13 +1187,13 @@ bool ClientSocketPoolBaseHelper::Group::TryToUseUnassignedConnectJob() { return true; } -void ClientSocketPoolBaseHelper::Group::AddJob(ConnectJob* job, +void ClientSocketPoolBaseHelper::Group::AddJob(scoped_ptr<ConnectJob> job, bool is_preconnect) { SanityCheck(); if (is_preconnect) ++unassigned_job_count_; - jobs_.insert(job); + jobs_.insert(job.release()); } void ClientSocketPoolBaseHelper::Group::RemoveJob(ConnectJob* job) { @@ -1224,15 +1233,17 @@ void ClientSocketPoolBaseHelper::Group::OnBackupSocketTimerFired( if (pending_requests_.empty()) return; - ConnectJob* backup_job = pool->connect_job_factory_->NewConnectJob( - group_name, **pending_requests_.begin(), pool); + scoped_ptr<ConnectJob> backup_job = + pool->connect_job_factory_->NewConnectJob( + group_name, **pending_requests_.begin(), pool); backup_job->net_log().AddEvent(NetLog::TYPE_SOCKET_BACKUP_CREATED); SIMPLE_STATS_COUNTER("socket.backup_created"); int rv = backup_job->Connect(); pool->connecting_socket_count_++; - AddJob(backup_job, false); + ConnectJob* raw_backup_job = backup_job.get(); + AddJob(backup_job.Pass(), false); if (rv != ERR_IO_PENDING) - pool->OnConnectJobComplete(rv, backup_job); + pool->OnConnectJobComplete(rv, raw_backup_job); } void ClientSocketPoolBaseHelper::Group::SanityCheck() { diff --git a/net/socket/client_socket_pool_base.h b/net/socket/client_socket_pool_base.h index ae1331b..eb642ed 100644 --- a/net/socket/client_socket_pool_base.h +++ b/net/socket/client_socket_pool_base.h @@ -61,8 +61,11 @@ class NET_EXPORT_PRIVATE ConnectJob { Delegate() {} virtual ~Delegate() {} - // Alerts the delegate that the connection completed. - virtual void OnConnectJobComplete(int result, ConnectJob* job) = 0; + // Alerts the delegate that the connection completed. |job| must + // be destroyed by the delegate. A scoped_ptr<> isn't used because + // the caller of this function doesn't own |job|. + virtual void OnConnectJobComplete(int result, + ConnectJob* job) = 0; private: DISALLOW_COPY_AND_ASSIGN(Delegate); @@ -79,9 +82,10 @@ class NET_EXPORT_PRIVATE ConnectJob { const std::string& group_name() const { return group_name_; } const BoundNetLog& net_log() { return net_log_; } - // Releases |socket_| to the client. On connection error, this should return - // NULL. - StreamSocket* ReleaseSocket() { return socket_.release(); } + // Releases ownership of the underlying socket to the caller. + // Returns the released socket, or NULL if there was a connection + // error. + scoped_ptr<StreamSocket> PassSocket(); // Begins connecting the socket. Returns OK on success, ERR_IO_PENDING if it // cannot complete synchronously without blocking, or another net error code @@ -105,7 +109,7 @@ class NET_EXPORT_PRIVATE ConnectJob { const BoundNetLog& net_log() const { return net_log_; } protected: - void set_socket(StreamSocket* socket); + void SetSocket(scoped_ptr<StreamSocket> socket); StreamSocket* socket() { return socket_.get(); } void NotifyDelegateOfCompletion(int rv); void ResetTimer(base::TimeDelta remainingTime); @@ -188,7 +192,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper ConnectJobFactory() {} virtual ~ConnectJobFactory() {} - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const Request& request, ConnectJob::Delegate* delegate) const = 0; @@ -229,7 +233,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // See ClientSocketPool::ReleaseSocket for documentation on this function. void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id); // See ClientSocketPool::FlushWithError for documentation on this function. @@ -386,7 +390,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // Otherwise, returns false. bool TryToUseUnassignedConnectJob(); - void AddJob(ConnectJob* job, bool is_preconnect); + void AddJob(scoped_ptr<ConnectJob> job, bool is_preconnect); // Remove |job| from this group, which must already own |job|. void RemoveJob(ConnectJob* job); void RemoveAllJobs(); @@ -476,7 +480,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper CleanupIdleSockets(false); } - // Removes |job| from |connect_job_set_|. Also updates |group| if non-NULL. + // Removes |job| from |group|, which must already own |job|. void RemoveConnectJob(ConnectJob* job, Group* group); // Tries to see if we can handle any more requests for |group|. @@ -486,7 +490,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper void ProcessPendingRequest(const std::string& group_name, Group* group); // Assigns |socket| to |handle| and updates |group|'s counters appropriately. - void HandOutSocket(StreamSocket* socket, + void HandOutSocket(scoped_ptr<StreamSocket> socket, bool reused, const LoadTimingInfo::ConnectTiming& connect_timing, ClientSocketHandle* handle, @@ -495,7 +499,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper const BoundNetLog& net_log); // Adds |socket| to the list of idle sockets for |group|. - void AddIdleSocket(StreamSocket* socket, Group* group); + void AddIdleSocket(scoped_ptr<StreamSocket> socket, Group* group); // Iterates through |group_map_|, canceling all ConnectJobs and deleting // groups if they are no longer needed. @@ -625,7 +629,7 @@ class ClientSocketPoolBase { ConnectJobFactory() {} virtual ~ConnectJobFactory() {} - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const Request& request, ConnectJob::Delegate* delegate) const = 0; @@ -703,9 +707,10 @@ class ClientSocketPoolBase { return helper_.CancelRequest(group_name, handle); } - void ReleaseSocket(const std::string& group_name, StreamSocket* socket, + void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, int id) { - return helper_.ReleaseSocket(group_name, socket, id); + return helper_.ReleaseSocket(group_name, socket.Pass(), id); } void FlushWithError(int error) { helper_.FlushWithError(error); } @@ -786,13 +791,13 @@ class ClientSocketPoolBase { : connect_job_factory_(connect_job_factory) {} virtual ~ConnectJobFactoryAdaptor() {} - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const internal::ClientSocketPoolBaseHelper::Request& request, - ConnectJob::Delegate* delegate) const { - const Request* casted_request = static_cast<const Request*>(&request); + ConnectJob::Delegate* delegate) const OVERRIDE { + const Request& casted_request = static_cast<const Request&>(request); return connect_job_factory_->NewConnectJob( - group_name, *casted_request, delegate); + group_name, casted_request, delegate); } virtual base::TimeDelta ConnectionTimeout() const { diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index 5eeda97..6688e01 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -30,7 +30,9 @@ #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" #include "net/socket/stream_socket.h" +#include "net/udp/datagram_client_socket.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" @@ -189,30 +191,30 @@ class MockClientSocketFactory : public ClientSocketFactory { public: MockClientSocketFactory() : allocation_count_(0) {} - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE { NOTREACHED(); - return NULL; + return scoped_ptr<DatagramClientSocket>(); } - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* /* net_log */, const NetLog::Source& /*source*/) OVERRIDE { allocation_count_++; - return NULL; + return scoped_ptr<StreamSocket>(); } - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<SSLClientSocket>(); } virtual void ClearSSLSessionCache() OVERRIDE { @@ -294,7 +296,8 @@ class TestConnectJob : public ConnectJob { AddressList ignored; client_socket_factory_->CreateTransportClientSocket( ignored, NULL, net::NetLog::Source()); - set_socket(new MockClientSocket(net_log().net_log())); + SetSocket( + scoped_ptr<StreamSocket>(new MockClientSocket(net_log().net_log()))); switch (job_type_) { case kMockJob: return DoConnect(true /* successful */, false /* sync */, @@ -373,7 +376,7 @@ class TestConnectJob : public ConnectJob { return ERR_IO_PENDING; default: NOTREACHED(); - set_socket(NULL); + SetSocket(scoped_ptr<StreamSocket>()); return ERR_FAILED; } } @@ -386,7 +389,7 @@ class TestConnectJob : public ConnectJob { result = ERR_PROXY_AUTH_REQUESTED; } else { result = ERR_CONNECTION_FAILED; - set_socket(NULL); + SetSocket(scoped_ptr<StreamSocket>()); } if (was_async) @@ -430,7 +433,7 @@ class TestConnectJobFactory // ConnectJobFactory implementation. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const TestClientSocketPoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE { @@ -440,13 +443,13 @@ class TestConnectJobFactory job_type = job_types_->front(); job_types_->pop_front(); } - return new TestConnectJob(job_type, - group_name, - request, - timeout_duration_, - delegate, - client_socket_factory_, - net_log_); + return scoped_ptr<ConnectJob>(new TestConnectJob(job_type, + group_name, + request, + timeout_duration_, + delegate, + client_socket_factory_, + net_log_)); } virtual base::TimeDelta ConnectionTimeout() const OVERRIDE { @@ -509,9 +512,9 @@ class TestClientSocketPool : public ClientSocketPool { virtual void ReleaseSocket( const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE { - base_.ReleaseSocket(group_name, socket, id); + base_.ReleaseSocket(group_name, socket.Pass(), id); } virtual void FlushWithError(int error) OVERRIDE { @@ -630,10 +633,10 @@ class TestConnectJobDelegate : public ConnectJob::Delegate { virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE { result_ = result; - scoped_ptr<StreamSocket> socket(job->ReleaseSocket()); + scoped_ptr<ConnectJob> owned_job(job); + scoped_ptr<StreamSocket> socket = owned_job->PassSocket(); // socket.get() should be NULL iff result != OK - EXPECT_EQ(socket.get() == NULL, result != OK); - delete job; + EXPECT_EQ(socket == NULL, result != OK); have_result_ = true; if (waiting_for_result_) base::MessageLoop::current()->Quit(); diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 8b2bdfc..159f62e 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -657,37 +657,39 @@ void MockClientSocketFactory::ResetNextMockIndexes() { mock_ssl_data_.ResetNextIndex(); } -DatagramClientSocket* MockClientSocketFactory::CreateDatagramClientSocket( +scoped_ptr<DatagramClientSocket> +MockClientSocketFactory::CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, net::NetLog* net_log, const net::NetLog::Source& source) { SocketDataProvider* data_provider = mock_data_.GetNext(); - MockUDPClientSocket* socket = new MockUDPClientSocket(data_provider, net_log); - data_provider->set_socket(socket); - return socket; + scoped_ptr<MockUDPClientSocket> socket( + new MockUDPClientSocket(data_provider, net_log)); + data_provider->set_socket(socket.get()); + return socket.PassAs<DatagramClientSocket>(); } -StreamSocket* MockClientSocketFactory::CreateTransportClientSocket( +scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket( const AddressList& addresses, net::NetLog* net_log, const net::NetLog::Source& source) { SocketDataProvider* data_provider = mock_data_.GetNext(); - MockTCPClientSocket* socket = - new MockTCPClientSocket(addresses, net_log, data_provider); - data_provider->set_socket(socket); - return socket; + scoped_ptr<MockTCPClientSocket> socket( + new MockTCPClientSocket(addresses, net_log, data_provider)); + data_provider->set_socket(socket.get()); + return socket.PassAs<StreamSocket>(); } -SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( - ClientSocketHandle* transport_socket, +scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) { - MockSSLClientSocket* socket = - new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, - mock_ssl_data_.GetNext()); - return socket; + return scoped_ptr<SSLClientSocket>( + new MockSSLClientSocket(transport_socket.Pass(), + host_and_port, ssl_config, + mock_ssl_data_.GetNext())); } void MockClientSocketFactory::ClearSSLSessionCache() { @@ -1278,7 +1280,7 @@ void DeterministicMockTCPClientSocket::OnConnectComplete( // static void MockSSLClientSocket::ConnectCallback( - MockSSLClientSocket *ssl_client_socket, + MockSSLClientSocket* ssl_client_socket, const CompletionCallback& callback, int rv) { if (rv == OK) @@ -1287,7 +1289,7 @@ void MockSSLClientSocket::ConnectCallback( } MockSSLClientSocket::MockSSLClientSocket( - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_port_pair, const SSLConfig& ssl_config, SSLSocketDataProvider* data) @@ -1295,7 +1297,7 @@ MockSSLClientSocket::MockSSLClientSocket( // Have to use the right BoundNetLog for LoadTimingInfo regression // tests. transport_socket->socket()->NetLog()), - transport_(transport_socket), + transport_(transport_socket.Pass()), data_(data), is_npn_state_set_(false), new_npn_value_(false), @@ -1664,10 +1666,10 @@ void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) { } MockTransportClientSocketPool::MockConnectJob::MockConnectJob( - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, ClientSocketHandle* handle, const CompletionCallback& callback) - : socket_(socket), + : socket_(socket.Pass()), handle_(handle), user_callback_(callback) { } @@ -1698,7 +1700,7 @@ void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) { if (!socket_.get()) return; if (rv == OK) { - handle_->set_socket(socket_.release()); + handle_->SetSocket(socket_.Pass()); // Needed for socket pool tests that layer other sockets on top of mock // sockets. @@ -1740,9 +1742,10 @@ int MockTransportClientSocketPool::RequestSocket( const std::string& group_name, const void* socket_params, RequestPriority priority, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) { - StreamSocket* socket = client_socket_factory_->CreateTransportClientSocket( - AddressList(), net_log.net_log(), net::NetLog::Source()); - MockConnectJob* job = new MockConnectJob(socket, handle, callback); + scoped_ptr<StreamSocket> socket = + client_socket_factory_->CreateTransportClientSocket( + AddressList(), net_log.net_log(), net::NetLog::Source()); + MockConnectJob* job = new MockConnectJob(socket.Pass(), handle, callback); job_list_.push_back(job); handle->set_pool_id(1); return job->Connect(); @@ -1759,11 +1762,12 @@ void MockTransportClientSocketPool::CancelRequest(const std::string& group_name, } } -void MockTransportClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { +void MockTransportClientSocketPool::ReleaseSocket( + const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) { EXPECT_EQ(1, id); release_count_++; - delete socket; } DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {} @@ -1791,42 +1795,45 @@ MockSSLClientSocket* DeterministicMockClientSocketFactory:: return ssl_client_sockets_[index]; } -DatagramClientSocket* +scoped_ptr<DatagramClientSocket> DeterministicMockClientSocketFactory::CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, net::NetLog* net_log, const NetLog::Source& source) { DeterministicSocketData* data_provider = mock_data().GetNext(); - DeterministicMockUDPClientSocket* socket = - new DeterministicMockUDPClientSocket(net_log, data_provider); + scoped_ptr<DeterministicMockUDPClientSocket> socket( + new DeterministicMockUDPClientSocket(net_log, data_provider)); data_provider->set_delegate(socket->AsWeakPtr()); - udp_client_sockets().push_back(socket); - return socket; + udp_client_sockets().push_back(socket.get()); + return socket.PassAs<DatagramClientSocket>(); } -StreamSocket* DeterministicMockClientSocketFactory::CreateTransportClientSocket( +scoped_ptr<StreamSocket> +DeterministicMockClientSocketFactory::CreateTransportClientSocket( const AddressList& addresses, net::NetLog* net_log, const net::NetLog::Source& source) { DeterministicSocketData* data_provider = mock_data().GetNext(); - DeterministicMockTCPClientSocket* socket = - new DeterministicMockTCPClientSocket(net_log, data_provider); + scoped_ptr<DeterministicMockTCPClientSocket> socket( + new DeterministicMockTCPClientSocket(net_log, data_provider)); data_provider->set_delegate(socket->AsWeakPtr()); - tcp_client_sockets().push_back(socket); - return socket; + tcp_client_sockets().push_back(socket.get()); + return socket.PassAs<StreamSocket>(); } -SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket( - ClientSocketHandle* transport_socket, +scoped_ptr<SSLClientSocket> +DeterministicMockClientSocketFactory::CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) { - MockSSLClientSocket* socket = - new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, - mock_ssl_data_.GetNext()); - ssl_client_sockets_.push_back(socket); - return socket; + scoped_ptr<MockSSLClientSocket> socket( + new MockSSLClientSocket(transport_socket.Pass(), + host_and_port, ssl_config, + mock_ssl_data_.GetNext())); + ssl_client_sockets_.push_back(socket.get()); + return socket.PassAs<SSLClientSocket>(); } void DeterministicMockClientSocketFactory::ClearSSLSessionCache() { @@ -1859,8 +1866,9 @@ void MockSOCKSClientSocketPool::CancelRequest( } void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { - return transport_pool_->ReleaseSocket(group_name, socket, id); + scoped_ptr<StreamSocket> socket, + int id) { + return transport_pool_->ReleaseSocket(group_name, socket.Pass(), id); } const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 }; diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 6afe170..a888249 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -592,17 +592,17 @@ class MockClientSocketFactory : public ClientSocketFactory { } // ClientSocketFactory - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE; - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, const NetLog::Source& source) OVERRIDE; - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE; @@ -857,7 +857,7 @@ class DeterministicMockTCPClientSocket class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { public: MockSSLClientSocket( - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLSocketDataProvider* socket); @@ -1049,7 +1049,7 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { public: class MockConnectJob { public: - MockConnectJob(StreamSocket* socket, ClientSocketHandle* handle, + MockConnectJob(scoped_ptr<StreamSocket> socket, ClientSocketHandle* handle, const CompletionCallback& callback); ~MockConnectJob(); @@ -1088,7 +1088,8 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { virtual void CancelRequest(const std::string& group_name, ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) OVERRIDE; + scoped_ptr<StreamSocket> socket, + int id) OVERRIDE; private: ClientSocketFactory* client_socket_factory_; @@ -1123,17 +1124,17 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory { } // ClientSocketFactory - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE; - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, const NetLog::Source& source) OVERRIDE; - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE; @@ -1170,7 +1171,8 @@ class MockSOCKSClientSocketPool : public SOCKSClientSocketPool { virtual void CancelRequest(const std::string& group_name, ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) OVERRIDE; + scoped_ptr<StreamSocket> socket, + int id) OVERRIDE; private: TransportClientSocketPool* const transport_pool_; diff --git a/net/socket/socks5_client_socket.cc b/net/socket/socks5_client_socket.cc index 8e329d1..537b584 100644 --- a/net/socket/socks5_client_socket.cc +++ b/net/socket/socks5_client_socket.cc @@ -28,18 +28,18 @@ COMPILE_ASSERT(sizeof(struct in_addr) == 4, incorrect_system_size_of_IPv4); COMPILE_ASSERT(sizeof(struct in6_addr) == 16, incorrect_system_size_of_IPv6); SOCKS5ClientSocket::SOCKS5ClientSocket( - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostResolver::RequestInfo& req_info) : io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete, base::Unretained(this))), - transport_(transport_socket), + transport_(transport_socket.Pass()), next_state_(STATE_NONE), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), read_header_size(kReadHeaderSize), host_request_info_(req_info), - net_log_(transport_socket->socket()->NetLog()) { + net_log_(transport_->socket()->NetLog()) { } SOCKS5ClientSocket::~SOCKS5ClientSocket() { diff --git a/net/socket/socks5_client_socket.h b/net/socket/socks5_client_socket.h index 28e829e..4521624 100644 --- a/net/socket/socks5_client_socket.h +++ b/net/socket/socks5_client_socket.h @@ -28,16 +28,13 @@ class BoundNetLog; // Currently no SOCKSv5 authentication is supported. class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { public: - // Takes ownership of the |transport_socket|, which should already be - // connected by the time Connect() is called. - // // |req_info| contains the hostname and port to which the socket above will // communicate to via the SOCKS layer. // // Although SOCKS 5 supports 3 different modes of addressing, we will // always pass it a hostname. This means the DNS resolving is done // proxy side. - SOCKS5ClientSocket(ClientSocketHandle* transport_socket, + SOCKS5ClientSocket(scoped_ptr<ClientSocketHandle> transport_socket, const HostResolver::RequestInfo& req_info); // On destruction Disconnect() is called. diff --git a/net/socket/socks5_client_socket_unittest.cc b/net/socket/socks5_client_socket_unittest.cc index 630884f..4c9240f 100644 --- a/net/socket/socks5_client_socket_unittest.cc +++ b/net/socket/socks5_client_socket_unittest.cc @@ -32,13 +32,13 @@ class SOCKS5ClientSocketTest : public PlatformTest { public: SOCKS5ClientSocketTest(); // Create a SOCKSClientSocket on top of a MockSocket. - SOCKS5ClientSocket* BuildMockSocket(MockRead reads[], - size_t reads_count, - MockWrite writes[], - size_t writes_count, - const std::string& hostname, - int port, - NetLog* net_log); + scoped_ptr<SOCKS5ClientSocket> BuildMockSocket(MockRead reads[], + size_t reads_count, + MockWrite writes[], + size_t writes_count, + const std::string& hostname, + int port, + NetLog* net_log); virtual void SetUp(); @@ -77,7 +77,7 @@ void SOCKS5ClientSocketTest::SetUp() { ASSERT_EQ(OK, rv); } -SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket( +scoped_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket( MockRead reads[], size_t reads_count, MockWrite writes[], @@ -99,10 +99,10 @@ SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket( scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); // |connection| takes ownership of |tcp_sock_|, but keep a // non-owning pointer to it. - connection->set_socket(tcp_sock_); - return new SOCKS5ClientSocket( - connection.release(), - HostResolver::RequestInfo(HostPortPair(hostname, port))); + connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_)); + return scoped_ptr<SOCKS5ClientSocket>(new SOCKS5ClientSocket( + connection.Pass(), + HostResolver::RequestInfo(HostPortPair(hostname, port)))); } // Tests a complete SOCKS5 handshake and the disconnection. @@ -130,9 +130,9 @@ TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) { MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength), MockRead(ASYNC, payload_read.data(), payload_read.size()) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - "localhost", 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + "localhost", 80, &net_log_); // At this state the TCP connection is completed but not the SOCKS handshake. EXPECT_TRUE(tcp_sock_->IsConnected()); @@ -202,9 +202,9 @@ TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) { MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, NULL)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, NULL); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(OK, rv); @@ -224,9 +224,9 @@ TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) { // Create a SOCKS socket, with mock transport socket. MockWrite data_writes[] = {MockWrite()}; MockRead data_reads[] = {MockRead()}; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - large_host_name, 80, NULL)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + large_host_name, 80, NULL); // Try to connect -- should fail (without having read/written anything to // the transport socket first) because the hostname is too long. @@ -260,9 +260,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { MockRead data_reads[] = { MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -291,9 +291,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { MockRead(ASYNC, partial1, arraysize(partial1)), MockRead(ASYNC, partial2, arraysize(partial2)), MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -321,9 +321,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { MockRead data_reads[] = { MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); CapturingNetLog::CapturedEntryList net_log_entries; @@ -352,9 +352,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { kSOCKS5OkResponseLength - kSplitPoint) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); CapturingNetLog::CapturedEntryList net_log_entries; diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc index cefcd3f..1941fdb 100644 --- a/net/socket/socks_client_socket.cc +++ b/net/socket/socks_client_socket.cc @@ -55,17 +55,18 @@ struct SOCKS4ServerResponse { COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, socks4_server_response_struct_wrong_size); -SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket, - const HostResolver::RequestInfo& req_info, - HostResolver* host_resolver) - : transport_(transport_socket), +SOCKSClientSocket::SOCKSClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostResolver::RequestInfo& req_info, + HostResolver* host_resolver) + : transport_(transport_socket.Pass()), next_state_(STATE_NONE), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), host_resolver_(host_resolver), host_request_info_(req_info), - net_log_(transport_socket->socket()->NetLog()) { + net_log_(transport_->socket()->NetLog()) { } SOCKSClientSocket::~SOCKSClientSocket() { diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h index a8b3b71..285c75e 100644 --- a/net/socket/socks_client_socket.h +++ b/net/socket/socks_client_socket.h @@ -27,12 +27,9 @@ class BoundNetLog; // The SOCKS client socket implementation class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { public: - // Takes ownership of the |transport_socket|, which should already be - // connected by the time Connect() is called. - // // |req_info| contains the hostname and port to which the socket above will // communicate to via the socks layer. For testing the referrer is optional. - SOCKSClientSocket(ClientSocketHandle* transport_socket, + SOCKSClientSocket(scoped_ptr<ClientSocketHandle> transport_socket, const HostResolver::RequestInfo& req_info, HostResolver* host_resolver); diff --git a/net/socket/socks_client_socket_pool.cc b/net/socket/socks_client_socket_pool.cc index d740e5b..e49eaba 100644 --- a/net/socket/socks_client_socket_pool.cc +++ b/net/socket/socks_client_socket_pool.cc @@ -140,10 +140,10 @@ int SOCKSConnectJob::DoSOCKSConnect() { // Add a SOCKS connection on top of the tcp socket. if (socks_params_->is_socks_v5()) { - socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.release(), + socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.Pass(), socks_params_->destination())); } else { - socket_.reset(new SOCKSClientSocket(transport_socket_handle_.release(), + socket_.reset(new SOCKSClientSocket(transport_socket_handle_.Pass(), socks_params_->destination(), resolver_)); } @@ -157,7 +157,7 @@ int SOCKSConnectJob::DoSOCKSConnectComplete(int result) { return result; } - set_socket(socket_.release()); + SetSocket(socket_.Pass()); return result; } @@ -166,17 +166,18 @@ int SOCKSConnectJob::ConnectInternal() { return DoLoop(OK); } -ConnectJob* SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( +scoped_ptr<ConnectJob> +SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return new SOCKSConnectJob(group_name, - request.params(), - ConnectionTimeout(), - transport_pool_, - host_resolver_, - delegate, - net_log_); + return scoped_ptr<ConnectJob>(new SOCKSConnectJob(group_name, + request.params(), + ConnectionTimeout(), + transport_pool_, + host_resolver_, + delegate, + net_log_)); } base::TimeDelta @@ -238,8 +239,9 @@ void SOCKSClientSocketPool::CancelRequest(const std::string& group_name, } void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { - base_.ReleaseSocket(group_name, socket, id); + scoped_ptr<StreamSocket> socket, + int id) { + base_.ReleaseSocket(group_name, socket.Pass(), id); } void SOCKSClientSocketPool::FlushWithError(int error) { diff --git a/net/socket/socks_client_socket_pool.h b/net/socket/socks_client_socket_pool.h index 86609a1..fe69a78 100644 --- a/net/socket/socks_client_socket_pool.h +++ b/net/socket/socks_client_socket_pool.h @@ -134,7 +134,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE; virtual void FlushWithError(int error) OVERRIDE; @@ -183,7 +183,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool virtual ~SOCKSConnectJobFactory() {} // ClientSocketPoolBase::ConnectJobFactory methods. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE; diff --git a/net/socket/socks_client_socket_unittest.cc b/net/socket/socks_client_socket_unittest.cc index 640c4f1..8c30838 100644 --- a/net/socket/socks_client_socket_unittest.cc +++ b/net/socket/socks_client_socket_unittest.cc @@ -4,6 +4,7 @@ #include "net/socket/socks_client_socket.h" +#include "base/memory/scoped_ptr.h" #include "net/base/address_list.h" #include "net/base/net_log.h" #include "net/base/net_log_unittest.h" @@ -27,11 +28,12 @@ class SOCKSClientSocketTest : public PlatformTest { public: SOCKSClientSocketTest(); // Create a SOCKSClientSocket on top of a MockSocket. - SOCKSClientSocket* BuildMockSocket(MockRead reads[], size_t reads_count, - MockWrite writes[], size_t writes_count, - HostResolver* host_resolver, - const std::string& hostname, int port, - NetLog* net_log); + scoped_ptr<SOCKSClientSocket> BuildMockSocket( + MockRead reads[], size_t reads_count, + MockWrite writes[], size_t writes_count, + HostResolver* host_resolver, + const std::string& hostname, int port, + NetLog* net_log); virtual void SetUp(); protected: @@ -54,7 +56,7 @@ void SOCKSClientSocketTest::SetUp() { PlatformTest::SetUp(); } -SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket( +scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket( MockRead reads[], size_t reads_count, MockWrite writes[], @@ -78,11 +80,11 @@ SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket( scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); // |connection| takes ownership of |tcp_sock_|, but keep a // non-owning pointer to it. - connection->set_socket(tcp_sock_); - return new SOCKSClientSocket( - connection.release(), + connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_)); + return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket( + connection.Pass(), HostResolver::RequestInfo(HostPortPair(hostname, port)), - host_resolver); + host_resolver)); } // Implementation of HostResolver that never completes its resolve request. @@ -141,11 +143,11 @@ TEST_F(SOCKSClientSocketTest, CompleteHandshake) { MockRead(ASYNC, payload_read.data(), payload_read.size()) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); // At this state the TCP connection is completed but not the SOCKS handshake. EXPECT_TRUE(tcp_sock_->IsConnected()); @@ -217,11 +219,11 @@ TEST_F(SOCKSClientSocketTest, HandshakeFailures) { arraysize(tests[i].fail_reply)) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -254,11 +256,11 @@ TEST_F(SOCKSClientSocketTest, PartialServerReads) { MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -292,11 +294,11 @@ TEST_F(SOCKSClientSocketTest, PartialClientWrites) { MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -324,11 +326,11 @@ TEST_F(SOCKSClientSocketTest, FailedSocketRead) { MockRead(SYNCHRONOUS, 0) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -354,11 +356,11 @@ TEST_F(SOCKSClientSocketTest, FailedDNS) { CapturingNetLog log; - user_sock_.reset(BuildMockSocket(NULL, 0, - NULL, 0, - host_resolver_.get(), - hostname, 80, - &log)); + user_sock_ = BuildMockSocket(NULL, 0, + NULL, 0, + host_resolver_.get(), + hostname, 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -385,11 +387,11 @@ TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) { MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) }; MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hanging_resolver.get(), - "foo", 80, - NULL)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hanging_resolver.get(), + "foo", 80, + NULL); // Start connecting (will get stuck waiting for the host to resolve). int rv = user_sock_->Connect(callback_.callback()); diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index 6274e64..acc1b0de 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -2751,12 +2751,12 @@ void SSLClientSocketNSS::Core::SetChannelIDProvided() { SSLClientSocketNSS::SSLClientSocketNSS( base::SequencedTaskRunner* nss_task_runner, - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) : nss_task_runner_(nss_task_runner), - transport_(transport_socket), + transport_(transport_socket.Pass()), host_and_port_(host_and_port), ssl_config_(ssl_config), cert_verifier_(context.cert_verifier), @@ -2765,7 +2765,7 @@ SSLClientSocketNSS::SSLClientSocketNSS( completed_handshake_(false), next_handshake_state_(STATE_NONE), nss_fd_(NULL), - net_log_(transport_socket->socket()->NetLog()), + net_log_(transport_->socket()->NetLog()), transport_security_state_(context.transport_security_state), valid_thread_id_(base::kInvalidThreadId) { EnterFunction(""); diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index fed8ef7..b41d28d 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -59,7 +59,7 @@ class SSLClientSocketNSS : public SSLClientSocket { // behaviour is desired, for performance or compatibility, the current task // runner should be supplied instead. SSLClientSocketNSS(base::SequencedTaskRunner* nss_task_runner, - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context); diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc index 1431bc6..4591cec 100644 --- a/net/socket/ssl_client_socket_openssl.cc +++ b/net/socket/ssl_client_socket_openssl.cc @@ -425,7 +425,7 @@ void SSLClientSocket::ClearSessionCache() { } SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) @@ -439,14 +439,14 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( cert_verifier_(context.cert_verifier), ssl_(NULL), transport_bio_(NULL), - transport_(transport_socket), + transport_(transport_socket.Pass()), host_and_port_(host_and_port), ssl_config_(ssl_config), ssl_session_cache_shard_(context.ssl_session_cache_shard), trying_cached_session_(false), next_handshake_state_(STATE_NONE), npn_status_(kNextProtoUnsupported), - net_log_(transport_socket->socket()->NetLog()) { + net_log_(transport_->socket()->NetLog()) { } SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() { diff --git a/net/socket/ssl_client_socket_openssl.h b/net/socket/ssl_client_socket_openssl.h index 520f432..f66d95c 100644 --- a/net/socket/ssl_client_socket_openssl.h +++ b/net/socket/ssl_client_socket_openssl.h @@ -41,7 +41,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // The given hostname will be compared with the name(s) in the server's // certificate during the SSL handshake. ssl_config specifies the SSL // settings. - SSLClientSocketOpenSSL(ClientSocketHandle* transport_socket, + SSLClientSocketOpenSSL(scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context); diff --git a/net/socket/ssl_client_socket_openssl_unittest.cc b/net/socket/ssl_client_socket_openssl_unittest.cc index 7da8625..04f8999 100644 --- a/net/socket/ssl_client_socket_openssl_unittest.cc +++ b/net/socket/ssl_client_socket_openssl_unittest.cc @@ -107,13 +107,13 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { } protected: - SSLClientSocket* CreateSSLClientSocket( - StreamSocket* transport_socket, + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<StreamSocket> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config) { scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); - connection->set_socket(transport_socket); - return socket_factory_->CreateSSLClientSocket(connection.release(), + connection->SetSocket(transport_socket.Pass()); + return socket_factory_->CreateSSLClientSocket(connection.Pass(), host_and_port, ssl_config, context_); @@ -166,9 +166,9 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { // itself was a success. bool CreateAndConnectSSLClientSocket(SSLConfig& ssl_config, int* result) { - sock_.reset(CreateSSLClientSocket(transport_.release(), - test_server_->host_port_pair(), - ssl_config)); + sock_ = CreateSSLClientSocket(transport_.Pass(), + test_server_->host_port_pair(), + ssl_config); if (sock_->IsConnected()) { LOG(ERROR) << "SSL Socket prematurely connected"; diff --git a/net/socket/ssl_client_socket_pool.cc b/net/socket/ssl_client_socket_pool.cc index fed268d..d07c76f 100644 --- a/net/socket/ssl_client_socket_pool.cc +++ b/net/socket/ssl_client_socket_pool.cc @@ -287,11 +287,11 @@ int SSLConnectJob::DoSSLConnect() { connect_timing_.ssl_start = base::TimeTicks::Now(); - ssl_socket_.reset(client_socket_factory_->CreateSSLClientSocket( - transport_socket_handle_.release(), + ssl_socket_ = client_socket_factory_->CreateSSLClientSocket( + transport_socket_handle_.Pass(), params_->host_and_port(), params_->ssl_config(), - context_)); + context_); return ssl_socket_->Connect(callback_); } @@ -410,7 +410,7 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { } if (result == OK || IsCertificateError(result)) { - set_socket(ssl_socket_.release()); + SetSocket(ssl_socket_.PassAs<StreamSocket>()); } else if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) { error_response_info_.cert_request_info = new SSLCertRequestInfo; ssl_socket_->GetSSLCertRequestInfo( @@ -527,14 +527,16 @@ SSLClientSocketPool::~SSLClientSocketPool() { ssl_config_service_->RemoveObserver(this); } -ConnectJob* SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( +scoped_ptr<ConnectJob> +SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return new SSLConnectJob(group_name, request.params(), ConnectionTimeout(), - transport_pool_, socks_pool_, http_proxy_pool_, - client_socket_factory_, host_resolver_, - context_, delegate, net_log_); + return scoped_ptr<ConnectJob>( + new SSLConnectJob(group_name, request.params(), ConnectionTimeout(), + transport_pool_, socks_pool_, http_proxy_pool_, + client_socket_factory_, host_resolver_, + context_, delegate, net_log_)); } base::TimeDelta @@ -572,8 +574,9 @@ void SSLClientSocketPool::CancelRequest(const std::string& group_name, } void SSLClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { - base_.ReleaseSocket(group_name, socket, id); + scoped_ptr<StreamSocket> socket, + int id) { + base_.ReleaseSocket(group_name, socket.Pass(), id); } void SSLClientSocketPool::FlushWithError(int error) { diff --git a/net/socket/ssl_client_socket_pool.h b/net/socket/ssl_client_socket_pool.h index bc54bc9..431a1b7c 100644 --- a/net/socket/ssl_client_socket_pool.h +++ b/net/socket/ssl_client_socket_pool.h @@ -204,7 +204,7 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE; virtual void FlushWithError(int error) OVERRIDE; @@ -261,7 +261,7 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool virtual ~SSLConnectJobFactory() {} // ClientSocketPoolBase::ConnectJobFactory methods. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE; diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc index af7f9b9..f791928 100644 --- a/net/socket/ssl_client_socket_unittest.cc +++ b/net/socket/ssl_client_socket_unittest.cc @@ -508,13 +508,14 @@ class SSLClientSocketTest : public PlatformTest { } protected: - SSLClientSocket* CreateSSLClientSocket(StreamSocket* transport_socket, - const HostPortPair& host_and_port, - const SSLConfig& ssl_config) { + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<StreamSocket> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config) { scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); - connection->set_socket(transport_socket); + connection->SetSocket(transport_socket.Pass()); return socket_factory_->CreateSSLClientSocket( - connection.release(), host_and_port, ssl_config, context_); + connection.Pass(), host_and_port, ssl_config, context_); } ClientSocketFactory* socket_factory_; @@ -552,14 +553,15 @@ TEST_F(SSLClientSocketTest, Connect) { TestCompletionCallback callback; CapturingNetLog log; - StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); @@ -593,14 +595,15 @@ TEST_F(SSLClientSocketTest, ConnectExpired) { TestCompletionCallback callback; CapturingNetLog log; - StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); @@ -636,14 +639,15 @@ TEST_F(SSLClientSocketTest, ConnectMismatched) { TestCompletionCallback callback; CapturingNetLog log; - StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); @@ -679,14 +683,15 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) { TestCompletionCallback callback; CapturingNetLog log; - StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); @@ -737,7 +742,8 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { TestCompletionCallback callback; CapturingNetLog log; - StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); @@ -748,7 +754,7 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { ssl_config.client_cert = NULL; scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), ssl_config)); + transport.Pass(), test_server.host_port_pair(), ssl_config)); EXPECT_FALSE(sock->IsConnected()); @@ -793,14 +799,15 @@ TEST_F(SSLClientSocketTest, Read) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - StreamSocket* transport = new TCPClientSocket(addr, NULL, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -851,8 +858,8 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { TestCompletionCallback callback; scoped_ptr<StreamSocket> real_transport( new TCPClientSocket(addr, NULL, NetLog::Source())); - SynchronousErrorStreamSocket* transport = - new SynchronousErrorStreamSocket(real_transport.Pass()); + scoped_ptr<SynchronousErrorStreamSocket> transport( + new SynchronousErrorStreamSocket(real_transport.Pass())); int rv = callback.GetResult(transport->Connect(callback.callback())); EXPECT_EQ(OK, rv); @@ -860,8 +867,11 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { SSLConfig ssl_config; ssl_config.false_start_enabled = false; - scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), ssl_config)); + SynchronousErrorStreamSocket* raw_transport = transport.get(); + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), + ssl_config)); rv = callback.GetResult(sock->Connect(callback.callback())); EXPECT_EQ(OK, rv); @@ -878,7 +888,7 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { EXPECT_EQ(kRequestTextSize, rv); // Simulate an unclean/forcible shutdown. - transport->SetNextReadError(ERR_CONNECTION_RESET); + raw_transport->SetNextReadError(ERR_CONNECTION_RESET); scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); @@ -912,12 +922,14 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { TestCompletionCallback callback; scoped_ptr<StreamSocket> real_transport( new TCPClientSocket(addr, NULL, NetLog::Source())); - // Note: |error_socket|'s ownership is handed to |transport|, but the pointer + // Note: |error_socket|'s ownership is handed to |transport|, but a pointer // is retained in order to configure additional errors. - SynchronousErrorStreamSocket* error_socket = - new SynchronousErrorStreamSocket(real_transport.Pass()); - FakeBlockingStreamSocket* transport = - new FakeBlockingStreamSocket(scoped_ptr<StreamSocket>(error_socket)); + scoped_ptr<SynchronousErrorStreamSocket> error_socket( + new SynchronousErrorStreamSocket(real_transport.Pass())); + SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); + scoped_ptr<FakeBlockingStreamSocket> transport( + new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>())); + FakeBlockingStreamSocket* raw_transport = transport.get(); int rv = callback.GetResult(transport->Connect(callback.callback())); EXPECT_EQ(OK, rv); @@ -925,8 +937,10 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { SSLConfig ssl_config; ssl_config.false_start_enabled = false; - scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), ssl_config)); + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), + ssl_config)); rv = callback.GetResult(sock->Connect(callback.callback())); EXPECT_EQ(OK, rv); @@ -940,8 +954,8 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { // Simulate an unclean/forcible shutdown on the underlying socket. // However, simulate this error asynchronously. - error_socket->SetNextWriteError(ERR_CONNECTION_RESET); - transport->SetNextWriteShouldBlock(); + raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET); + raw_transport->SetNextWriteShouldBlock(); // This write should complete synchronously, because the TLS ciphertext // can be created and placed into the outgoing buffers independent of the @@ -957,7 +971,7 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { // Now unblock the outgoing request, having it fail with the connection // being reset. - transport->UnblockWrite(); + raw_transport->UnblockWrite(); // Note: This will cause an inifite loop if this bug has regressed. Simply // checking that rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING @@ -986,14 +1000,15 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) { TestCompletionCallback callback; // Used for everything except Write. - StreamSocket* transport = new TCPClientSocket(addr, NULL, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -1049,12 +1064,14 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { TestCompletionCallback callback; scoped_ptr<StreamSocket> real_transport( new TCPClientSocket(addr, NULL, NetLog::Source())); - // Note: |error_socket|'s ownership is handed to |transport|, but the pointer + // Note: |error_socket|'s ownership is handed to |transport|, but a pointer // is retained in order to configure additional errors. - SynchronousErrorStreamSocket* error_socket = - new SynchronousErrorStreamSocket(real_transport.Pass()); - FakeBlockingStreamSocket* transport = - new FakeBlockingStreamSocket(scoped_ptr<StreamSocket>(error_socket)); + scoped_ptr<SynchronousErrorStreamSocket> error_socket( + new SynchronousErrorStreamSocket(real_transport.Pass())); + SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); + scoped_ptr<FakeBlockingStreamSocket> transport( + new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>())); + FakeBlockingStreamSocket* raw_transport = transport.get(); int rv = callback.GetResult(transport->Connect(callback.callback())); EXPECT_EQ(OK, rv); @@ -1063,8 +1080,10 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { SSLConfig ssl_config; ssl_config.false_start_enabled = false; - SSLClientSocket* sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), ssl_config)); + scoped_ptr<SSLClientSocket> sock = + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), + ssl_config); rv = callback.GetResult(sock->Connect(callback.callback())); EXPECT_EQ(OK, rv); @@ -1077,18 +1096,19 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { new StringIOBuffer(request_text), request_text.size())); // Simulate errors being returned from the underlying Read() and Write() ... - error_socket->SetNextReadError(ERR_CONNECTION_RESET); - error_socket->SetNextWriteError(ERR_CONNECTION_RESET); + raw_error_socket->SetNextReadError(ERR_CONNECTION_RESET); + raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET); // ... but have those errors returned asynchronously. Because the Write() will // return first, this will trigger the error. - transport->SetNextReadShouldBlock(); - transport->SetNextWriteShouldBlock(); + raw_transport->SetNextReadShouldBlock(); + raw_transport->SetNextWriteShouldBlock(); // Enqueue a Read() before calling Write(), which should "hang" due to // the ERR_IO_PENDING caused by SetReadShouldBlock() and thus return. - DeleteSocketCallback read_callback(sock); + SSLClientSocket* raw_sock = sock.get(); + DeleteSocketCallback read_callback(sock.release()); scoped_refptr<IOBuffer> read_buf(new IOBuffer(4096)); - rv = sock->Read(read_buf.get(), 4096, read_callback.callback()); + rv = raw_sock->Read(read_buf.get(), 4096, read_callback.callback()); // Ensure things didn't complete synchronously, otherwise |sock| is invalid. ASSERT_EQ(ERR_IO_PENDING, rv); @@ -1111,9 +1131,9 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { // SSLClientSocketOpenSSL::Write() will not return until all of // |request_buffer| has been written to the underlying BIO (although not // necessarily the underlying transport). - rv = callback.GetResult(sock->Write(request_buffer.get(), - request_buffer->BytesRemaining(), - callback.callback())); + rv = callback.GetResult(raw_sock->Write(request_buffer.get(), + request_buffer->BytesRemaining(), + callback.callback())); ASSERT_LT(0, rv); request_buffer->DidConsume(rv); @@ -1126,16 +1146,16 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { // Attempt to write the remaining data. NSS will not be able to consume the // application data because the internal buffers are full, while OpenSSL will // return that its blocked because the underlying transport is blocked. - rv = sock->Write(request_buffer.get(), - request_buffer->BytesRemaining(), - callback.callback()); + rv = raw_sock->Write(request_buffer.get(), + request_buffer->BytesRemaining(), + callback.callback()); ASSERT_EQ(ERR_IO_PENDING, rv); ASSERT_FALSE(callback.have_result()); // Now unblock Write(), which will invoke OnSendComplete and (eventually) // call the Read() callback, deleting the socket and thus aborting calling // the Write() callback. - transport->UnblockWrite(); + raw_transport->UnblockWrite(); rv = read_callback.WaitForResult(); @@ -1161,14 +1181,15 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - StreamSocket* transport = new TCPClientSocket(addr, NULL, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -1215,13 +1236,16 @@ TEST_F(SSLClientSocketTest, Read_ManySmallRecords) { scoped_ptr<StreamSocket> real_transport( new TCPClientSocket(addr, NULL, NetLog::Source())); - ReadBufferingStreamSocket* transport = - new ReadBufferingStreamSocket(real_transport.Pass()); + scoped_ptr<ReadBufferingStreamSocket> transport( + new ReadBufferingStreamSocket(real_transport.Pass())); + ReadBufferingStreamSocket* raw_transport = transport.get(); int rv = callback.GetResult(transport->Connect(callback.callback())); ASSERT_EQ(OK, rv); - scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), + kDefaultSSLConfig)); rv = callback.GetResult(sock->Connect(callback.callback())); ASSERT_EQ(OK, rv); @@ -1246,7 +1270,7 @@ TEST_F(SSLClientSocketTest, Read_ManySmallRecords) { // 15K was chosen because 15K is smaller than the 17K (max) read issued by // the SSLClientSocket implementation, and larger than the minimum amount // of ciphertext necessary to contain the 8K of plaintext requested below. - transport->SetBufferSize(15000); + raw_transport->SetBufferSize(15000); scoped_refptr<IOBuffer> buffer(new IOBuffer(8192)); rv = callback.GetResult(sock->Read(buffer.get(), 8192, callback.callback())); @@ -1263,14 +1287,15 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - StreamSocket* transport = new TCPClientSocket(addr, NULL, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -1313,14 +1338,15 @@ TEST_F(SSLClientSocketTest, Read_FullLogging) { TestCompletionCallback callback; CapturingNetLog log; log.SetLogLevel(NetLog::LOG_ALL); - StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -1398,14 +1424,15 @@ TEST_F(SSLClientSocketTest, PrematureApplicationData) { StaticSocketDataProvider data(data_reads, arraysize(data_reads), NULL, 0); - StreamSocket* transport = new MockTCPClientSocket(addr, NULL, &data); + scoped_ptr<StreamSocket> transport( + new MockTCPClientSocket(addr, NULL, &data)); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -1433,7 +1460,8 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) { TestCompletionCallback callback; CapturingNetLog log; - StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); @@ -1444,7 +1472,7 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) { ssl_config.disabled_cipher_suites.push_back(kCiphersToDisable[i]); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), ssl_config)); + transport.Pass(), test_server.host_port_pair(), ssl_config)); EXPECT_FALSE(sock->IsConnected()); @@ -1499,17 +1527,18 @@ TEST_F(SSLClientSocketTest, ClientSocketHandleNotFromPool) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - StreamSocket* transport = new TCPClientSocket(addr, NULL, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); - ClientSocketHandle* socket_handle = new ClientSocketHandle(); - socket_handle->set_socket(transport); + scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle()); + socket_handle->SetSocket(transport.Pass()); scoped_ptr<SSLClientSocket> sock( - socket_factory_->CreateSSLClientSocket(socket_handle, + socket_factory_->CreateSSLClientSocket(socket_handle.Pass(), test_server.host_port_pair(), kDefaultSSLConfig, context_)); @@ -1534,14 +1563,15 @@ TEST_F(SSLClientSocketTest, ExportKeyingMaterial) { TestCompletionCallback callback; - StreamSocket* transport = new TCPClientSocket(addr, NULL, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -1629,14 +1659,15 @@ TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) { TestCompletionCallback callback; CapturingNetLog log; - StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); @@ -1688,14 +1719,15 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest { TestCompletionCallback callback; CapturingNetLog log; - StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); diff --git a/net/socket/ssl_server_socket.h b/net/socket/ssl_server_socket.h index 52d53cb..8b607bf 100644 --- a/net/socket/ssl_server_socket.h +++ b/net/socket/ssl_server_socket.h @@ -6,6 +6,7 @@ #define NET_SOCKET_SSL_SERVER_SOCKET_H_ #include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" #include "net/socket/ssl_socket.h" @@ -52,8 +53,8 @@ NET_EXPORT void EnableSSLServerSockets(); // // The caller starts the SSL server handshake by calling Handshake on the // returned socket. -NET_EXPORT SSLServerSocket* CreateSSLServerSocket( - StreamSocket* socket, +NET_EXPORT scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket, X509Certificate* certificate, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config); diff --git a/net/socket/ssl_server_socket_nss.cc b/net/socket/ssl_server_socket_nss.cc index c2681d3..7e5d701 100644 --- a/net/socket/ssl_server_socket_nss.cc +++ b/net/socket/ssl_server_socket_nss.cc @@ -78,19 +78,20 @@ void EnableSSLServerSockets() { g_nss_ssl_server_init_singleton.Get(); } -SSLServerSocket* CreateSSLServerSocket( - StreamSocket* socket, +scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket, X509Certificate* cert, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config) { DCHECK(g_nss_server_sockets_init) << "EnableSSLServerSockets() has not been" << "called yet!"; - return new SSLServerSocketNSS(socket, cert, key, ssl_config); + return scoped_ptr<SSLServerSocket>( + new SSLServerSocketNSS(socket.Pass(), cert, key, ssl_config)); } SSLServerSocketNSS::SSLServerSocketNSS( - StreamSocket* transport_socket, + scoped_ptr<StreamSocket> transport_socket, scoped_refptr<X509Certificate> cert, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config) @@ -100,7 +101,7 @@ SSLServerSocketNSS::SSLServerSocketNSS( user_write_buf_len_(0), nss_fd_(NULL), nss_bufs_(NULL), - transport_socket_(transport_socket), + transport_socket_(transport_socket.Pass()), ssl_config_(ssl_config), cert_(cert), next_handshake_state_(STATE_NONE), diff --git a/net/socket/ssl_server_socket_nss.h b/net/socket/ssl_server_socket_nss.h index 17a1fc3..8bbb0e3 100644 --- a/net/socket/ssl_server_socket_nss.h +++ b/net/socket/ssl_server_socket_nss.h @@ -24,7 +24,7 @@ class SSLServerSocketNSS : public SSLServerSocket { public: // See comments on CreateSSLServerSocket for details of how these // parameters are used. - SSLServerSocketNSS(StreamSocket* socket, + SSLServerSocketNSS(scoped_ptr<StreamSocket> socket, scoped_refptr<X509Certificate> certificate, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config); diff --git a/net/socket/ssl_server_socket_openssl.cc b/net/socket/ssl_server_socket_openssl.cc index e0cf8bc..c327f2c 100644 --- a/net/socket/ssl_server_socket_openssl.cc +++ b/net/socket/ssl_server_socket_openssl.cc @@ -16,13 +16,13 @@ void EnableSSLServerSockets() { NOTIMPLEMENTED(); } -SSLServerSocket* CreateSSLServerSocket(StreamSocket* socket, - X509Certificate* certificate, - crypto::RSAPrivateKey* key, - const SSLConfig& ssl_config) { +scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket, + X509Certificate* certificate, + crypto::RSAPrivateKey* key, + const SSLConfig& ssl_config) { NOTIMPLEMENTED(); - delete socket; - return NULL; + return scoped_ptr<SSLServerSocket>(); } } // namespace net diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc index 63bf037..64c8549 100644 --- a/net/socket/ssl_server_socket_unittest.cc +++ b/net/socket/ssl_server_socket_unittest.cc @@ -304,8 +304,11 @@ class SSLServerSocketTest : public PlatformTest { protected: void Initialize() { - FakeSocket* fake_client_socket = new FakeSocket(&channel_1_, &channel_2_); - FakeSocket* fake_server_socket = new FakeSocket(&channel_2_, &channel_1_); + scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle); + client_connection->SetSocket( + scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_))); + scoped_ptr<StreamSocket> server_socket( + new FakeSocket(&channel_2_, &channel_1_)); base::FilePath certs_dir(GetTestCertsDirectory()); @@ -344,13 +347,12 @@ class SSLServerSocketTest : public PlatformTest { net::SSLClientSocketContext context; context.cert_verifier = cert_verifier_.get(); context.transport_security_state = transport_security_state_.get(); - scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); - connection->set_socket(fake_client_socket); - client_socket_.reset( + client_socket_ = socket_factory_->CreateSSLClientSocket( - connection.release(), host_and_pair, ssl_config, context)); - server_socket_.reset(net::CreateSSLServerSocket( - fake_server_socket, cert.get(), private_key.get(), net::SSLConfig())); + client_connection.Pass(), host_and_pair, ssl_config, context); + server_socket_ = net::CreateSSLServerSocket( + server_socket.Pass(), + cert.get(), private_key.get(), net::SSLConfig()); } FakeDataChannel channel_1_; diff --git a/net/socket/transport_client_socket_pool.cc b/net/socket/transport_client_socket_pool.cc index 8255e98..6d0afac 100644 --- a/net/socket/transport_client_socket_pool.cc +++ b/net/socket/transport_client_socket_pool.cc @@ -190,8 +190,8 @@ int TransportConnectJob::DoResolveHostComplete(int result) { int TransportConnectJob::DoTransportConnect() { next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; - transport_socket_.reset(client_socket_factory_->CreateTransportClientSocket( - addresses_, net_log().net_log(), net_log().source())); + transport_socket_ = client_socket_factory_->CreateTransportClientSocket( + addresses_, net_log().net_log(), net_log().source()); int rv = transport_socket_->Connect( base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this))); if (rv == ERR_IO_PENDING && @@ -246,7 +246,7 @@ int TransportConnectJob::DoTransportConnectComplete(int result) { 100); } } - set_socket(transport_socket_.release()); + SetSocket(transport_socket_.Pass()); fallback_timer_.Stop(); } else { // Be a bit paranoid and kill off the fallback members to prevent reuse. @@ -270,9 +270,9 @@ void TransportConnectJob::DoIPv6FallbackTransportConnect() { fallback_addresses_.reset(new AddressList(addresses_)); MakeAddressListStartWithIPv4(fallback_addresses_.get()); - fallback_transport_socket_.reset( + fallback_transport_socket_ = client_socket_factory_->CreateTransportClientSocket( - *fallback_addresses_, net_log().net_log(), net_log().source())); + *fallback_addresses_, net_log().net_log(), net_log().source()); fallback_connect_start_time_ = base::TimeTicks::Now(); int rv = fallback_transport_socket_->Connect( base::Bind( @@ -317,7 +317,7 @@ void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) { base::TimeDelta::FromMilliseconds(1), base::TimeDelta::FromMinutes(10), 100); - set_socket(fallback_transport_socket_.release()); + SetSocket(fallback_transport_socket_.Pass()); next_state_ = STATE_NONE; transport_socket_.reset(); } else { @@ -333,18 +333,19 @@ int TransportConnectJob::ConnectInternal() { return DoLoop(OK); } -ConnectJob* +scoped_ptr<ConnectJob> TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return new TransportConnectJob(group_name, - request.params(), - ConnectionTimeout(), - client_socket_factory_, - host_resolver_, - delegate, - net_log_); + return scoped_ptr<ConnectJob>( + new TransportConnectJob(group_name, + request.params(), + ConnectionTimeout(), + client_socket_factory_, + host_resolver_, + delegate, + net_log_)); } base::TimeDelta @@ -419,9 +420,9 @@ void TransportClientSocketPool::CancelRequest( void TransportClientSocketPool::ReleaseSocket( const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) { - base_.ReleaseSocket(group_name, socket, id); + base_.ReleaseSocket(group_name, socket.Pass(), id); } void TransportClientSocketPool::FlushWithError(int error) { diff --git a/net/socket/transport_client_socket_pool.h b/net/socket/transport_client_socket_pool.h index bb53b3d..f07dc1f 100644 --- a/net/socket/transport_client_socket_pool.h +++ b/net/socket/transport_client_socket_pool.h @@ -156,7 +156,7 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { virtual void CancelRequest(const std::string& group_name, ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE; virtual void FlushWithError(int error) OVERRIDE; virtual bool IsStalled() const OVERRIDE; @@ -193,7 +193,7 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { // ClientSocketPoolBase::ConnectJobFactory methods. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE; diff --git a/net/socket/transport_client_socket_pool_unittest.cc b/net/socket/transport_client_socket_pool_unittest.cc index dfa1151..c607a38 100644 --- a/net/socket/transport_client_socket_pool_unittest.cc +++ b/net/socket/transport_client_socket_pool_unittest.cc @@ -23,6 +23,7 @@ #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" #include "net/socket/stream_socket.h" #include "testing/gtest/include/gtest/gtest.h" @@ -340,16 +341,16 @@ class MockClientSocketFactory : public ClientSocketFactory { delay_(base::TimeDelta::FromMilliseconds( ClientSocketPool::kMaxConnectRetryIntervalMs)) {} - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE { NOTREACHED(); - return NULL; + return scoped_ptr<DatagramClientSocket>(); } - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* /* net_log */, const NetLog::Source& /* source */) OVERRIDE { @@ -363,34 +364,41 @@ class MockClientSocketFactory : public ClientSocketFactory { switch (type) { case MOCK_CLIENT_SOCKET: - return new MockClientSocket(addresses, net_log_); + return scoped_ptr<StreamSocket>( + new MockClientSocket(addresses, net_log_)); case MOCK_FAILING_CLIENT_SOCKET: - return new MockFailingClientSocket(addresses, net_log_); + return scoped_ptr<StreamSocket>( + new MockFailingClientSocket(addresses, net_log_)); case MOCK_PENDING_CLIENT_SOCKET: - return new MockPendingClientSocket( - addresses, true, false, base::TimeDelta(), net_log_); + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, true, false, base::TimeDelta(), net_log_)); case MOCK_PENDING_FAILING_CLIENT_SOCKET: - return new MockPendingClientSocket( - addresses, false, false, base::TimeDelta(), net_log_); + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, false, false, base::TimeDelta(), net_log_)); case MOCK_DELAYED_CLIENT_SOCKET: - return new MockPendingClientSocket( - addresses, true, false, delay_, net_log_); + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, true, false, delay_, net_log_)); case MOCK_STALLED_CLIENT_SOCKET: - return new MockPendingClientSocket( - addresses, true, true, base::TimeDelta(), net_log_); + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, true, true, base::TimeDelta(), net_log_)); default: NOTREACHED(); - return new MockClientSocket(addresses, net_log_); + return scoped_ptr<StreamSocket>( + new MockClientSocket(addresses, net_log_)); } } - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<SSLClientSocket>(); } virtual void ClearSSLSessionCache() OVERRIDE { diff --git a/net/socket/transport_client_socket_unittest.cc b/net/socket/transport_client_socket_unittest.cc index 2f75e74..5c5a303 100644 --- a/net/socket/transport_client_socket_unittest.cc +++ b/net/socket/transport_client_socket_unittest.cc @@ -130,10 +130,10 @@ void TransportClientSocketTest::SetUp() { CHECK_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); CHECK_EQ(rv, OK); - sock_.reset( + sock_ = socket_factory_->CreateTransportClientSocket(addr, &net_log_, - NetLog::Source())); + NetLog::Source()); } int TransportClientSocketTest::DrainClientSocket( diff --git a/net/socket_stream/socket_stream.cc b/net/socket_stream/socket_stream.cc index ab4c3fd..c549fcb 100644 --- a/net/socket_stream/socket_stream.cc +++ b/net/socket_stream/socket_stream.cc @@ -97,6 +97,7 @@ SocketStream::SocketStream(const GURL& url, Delegate* delegate) proxy_mode_(kDirectConnection), proxy_url_(url), pac_request_(NULL), + connection_(new ClientSocketHandle), privacy_mode_(kPrivacyModeDisabled), // Unretained() is required; without it, Bind() creates a circular // dependency and the SocketStream object will not be freed. @@ -206,8 +207,10 @@ bool SocketStream::SendData(const char* data, int len) { << "The current base::MessageLoop must be TYPE_IO"; DCHECK_GT(len, 0); - if (!socket_.get() || !socket_->IsConnected() || next_state_ == STATE_NONE) + if (!connection_->socket() || + !connection_->socket()->IsConnected() || next_state_ == STATE_NONE) { return false; + } int total_buffered_bytes = len; if (current_write_buf_.get()) { @@ -265,7 +268,7 @@ void SocketStream::RestartWithAuth(const AuthCredentials& credentials) { DCHECK_EQ(base::MessageLoop::TYPE_IO, base::MessageLoop::current()->type()) << "The current base::MessageLoop must be TYPE_IO"; DCHECK(proxy_auth_controller_.get()); - if (!socket_.get()) { + if (!connection_->socket()) { DVLOG(1) << "Socket is closed before restarting with auth."; return; } @@ -370,7 +373,7 @@ void SocketStream::Finish(int result) { } int SocketStream::DidEstablishConnection() { - if (!socket_.get() || !socket_->IsConnected()) { + if (!connection_->socket() || !connection_->socket()->IsConnected()) { next_state_ = STATE_CLOSE; return ERR_CONNECTION_FAILED; } @@ -731,11 +734,12 @@ int SocketStream::DoTcpConnect(int result) { } next_state_ = STATE_TCP_CONNECT_COMPLETE; DCHECK(factory_); - socket_.reset(factory_->CreateTransportClientSocket(addresses_, - net_log_.net_log(), - net_log_.source())); + connection_->SetSocket( + factory_->CreateTransportClientSocket(addresses_, + net_log_.net_log(), + net_log_.source())); metrics_->OnStartConnection(); - return socket_->Connect(io_callback_); + return connection_->socket()->Connect(io_callback_); } int SocketStream::DoTcpConnectComplete(int result) { @@ -820,7 +824,8 @@ int SocketStream::DoWriteTunnelHeaders() { int buf_len = static_cast<int>(tunnel_request_headers_->headers_.size() - tunnel_request_headers_bytes_sent_); DCHECK_GT(buf_len, 0); - return socket_->Write(tunnel_request_headers_.get(), buf_len, io_callback_); + return connection_->socket()->Write( + tunnel_request_headers_.get(), buf_len, io_callback_); } int SocketStream::DoWriteTunnelHeadersComplete(int result) { @@ -863,7 +868,8 @@ int SocketStream::DoReadTunnelHeaders() { tunnel_response_headers_->SetDataOffset(tunnel_response_headers_len_); CHECK(tunnel_response_headers_->data()); - return socket_->Read(tunnel_response_headers_.get(), buf_len, io_callback_); + return connection_->socket()->Read( + tunnel_response_headers_.get(), buf_len, io_callback_); } int SocketStream::DoReadTunnelHeadersComplete(int result) { @@ -957,16 +963,17 @@ int SocketStream::DoSOCKSConnect() { HostResolver::RequestInfo req_info(HostPortPair::FromURL(url_)); DCHECK(!proxy_info_.is_empty()); - scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); - connection->set_socket(socket_.release()); + scoped_ptr<StreamSocket> s; if (proxy_info_.proxy_server().scheme() == ProxyServer::SCHEME_SOCKS5) { - socket_.reset(new SOCKS5ClientSocket(connection.release(), req_info)); + s.reset(new SOCKS5ClientSocket(connection_.Pass(), req_info)); } else { - socket_.reset(new SOCKSClientSocket( - connection.release(), req_info, context_->host_resolver())); + s.reset(new SOCKSClientSocket( + connection_.Pass(), req_info, context_->host_resolver())); } + connection_.reset(new ClientSocketHandle); + connection_->SetSocket(s.Pass()); metrics_->OnCountConnectionType(SocketStreamMetrics::SOCKS_CONNECTION); - return socket_->Connect(io_callback_); + return connection_->socket()->Connect(io_callback_); } int SocketStream::DoSOCKSConnectComplete(int result) { @@ -989,16 +996,16 @@ int SocketStream::DoSecureProxyConnect() { ssl_context.cert_verifier = context_->cert_verifier(); ssl_context.transport_security_state = context_->transport_security_state(); ssl_context.server_bound_cert_service = context_->server_bound_cert_service(); - scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); - connection->set_socket(socket_.release()); - socket_.reset(factory_->CreateSSLClientSocket( - connection.release(), + scoped_ptr<StreamSocket> socket(factory_->CreateSSLClientSocket( + connection_.Pass(), proxy_info_.proxy_server().host_port_pair(), proxy_ssl_config_, ssl_context)); + connection_.reset(new ClientSocketHandle); + connection_->SetSocket(socket.Pass()); next_state_ = STATE_SECURE_PROXY_CONNECT_COMPLETE; metrics_->OnCountConnectionType(SocketStreamMetrics::SECURE_PROXY_CONNECTION); - return socket_->Connect(io_callback_); + return connection_->socket()->Connect(io_callback_); } int SocketStream::DoSecureProxyConnectComplete(int result) { @@ -1030,7 +1037,7 @@ int SocketStream::DoSecureProxyHandleCertError(int result) { int SocketStream::DoSecureProxyHandleCertErrorComplete(int result) { DCHECK_EQ(STATE_NONE, next_state_); if (result == OK) { - if (!socket_->IsConnectedAndIdle()) + if (!connection_->socket()->IsConnectedAndIdle()) return AllowCertErrorForReconnection(&proxy_ssl_config_); next_state_ = STATE_GENERATE_PROXY_AUTH_TOKEN; } else { @@ -1045,15 +1052,16 @@ int SocketStream::DoSSLConnect() { ssl_context.cert_verifier = context_->cert_verifier(); ssl_context.transport_security_state = context_->transport_security_state(); ssl_context.server_bound_cert_service = context_->server_bound_cert_service(); - scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); - connection->set_socket(socket_.release()); - socket_.reset(factory_->CreateSSLClientSocket(connection.release(), - HostPortPair::FromURL(url_), - server_ssl_config_, - ssl_context)); + scoped_ptr<StreamSocket> socket( + factory_->CreateSSLClientSocket(connection_.Pass(), + HostPortPair::FromURL(url_), + server_ssl_config_, + ssl_context)); + connection_.reset(new ClientSocketHandle); + connection_->SetSocket(socket.Pass()); next_state_ = STATE_SSL_CONNECT_COMPLETE; metrics_->OnCountConnectionType(SocketStreamMetrics::SSL_CONNECTION); - return socket_->Connect(io_callback_); + return connection_->socket()->Connect(io_callback_); } int SocketStream::DoSSLConnectComplete(int result) { @@ -1089,7 +1097,7 @@ int SocketStream::DoSSLHandleCertErrorComplete(int result) { // we should take care of TLS NPN extension here. if (result == OK) { - if (!socket_->IsConnectedAndIdle()) + if (!connection_->socket()->IsConnectedAndIdle()) return AllowCertErrorForReconnection(&server_ssl_config_); result = DidEstablishConnection(); } else { @@ -1103,7 +1111,7 @@ int SocketStream::DoReadWrite(int result) { next_state_ = STATE_CLOSE; return result; } - if (!socket_.get() || !socket_->IsConnected()) { + if (!connection_->socket() || !connection_->socket()->IsConnected()) { next_state_ = STATE_CLOSE; return ERR_CONNECTION_CLOSED; } @@ -1112,7 +1120,7 @@ int SocketStream::DoReadWrite(int result) { // let's close the socket. // We don't care about receiving data after the socket is closed. if (closing_ && !current_write_buf_.get() && pending_write_bufs_.empty()) { - socket_->Disconnect(); + connection_->socket()->Disconnect(); next_state_ = STATE_CLOSE; return OK; } @@ -1124,7 +1132,7 @@ int SocketStream::DoReadWrite(int result) { if (!read_buf_.get()) { // No read pending and server didn't close the socket. read_buf_ = new IOBuffer(kReadBufferSize); - result = socket_->Read( + result = connection_->socket()->Read( read_buf_.get(), kReadBufferSize, base::Bind(&SocketStream::OnReadCompleted, base::Unretained(this))); @@ -1163,7 +1171,7 @@ int SocketStream::DoReadWrite(int result) { pending_write_bufs_.pop_front(); } - result = socket_->Write( + result = connection_->socket()->Write( current_write_buf_.get(), current_write_buf_->BytesRemaining(), base::Bind(&SocketStream::OnWriteCompleted, base::Unretained(this))); @@ -1195,10 +1203,10 @@ int SocketStream::HandleCertificateRequest(int result, SSLConfig* ssl_config) { return result; } - DCHECK(socket_.get()); + DCHECK(connection_->socket()); scoped_refptr<SSLCertRequestInfo> cert_request_info = new SSLCertRequestInfo; SSLClientSocket* ssl_socket = - static_cast<SSLClientSocket*>(socket_.get()); + static_cast<SSLClientSocket*>(connection_->socket()); ssl_socket->GetSSLCertRequestInfo(cert_request_info.get()); HttpTransactionFactory* factory = context_->http_transaction_factory(); @@ -1244,7 +1252,8 @@ int SocketStream::AllowCertErrorForReconnection(SSLConfig* ssl_config) { // allowed bad certificates in |ssl_config|. // See also net/http/http_network_transaction.cc HandleCertificateError() and // RestartIgnoringLastError(). - SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(socket_.get()); + SSLClientSocket* ssl_socket = + static_cast<SSLClientSocket*>(connection_->socket()); SSLInfo ssl_info; ssl_socket->GetSSLInfo(&ssl_info); if (ssl_info.cert.get() == NULL || @@ -1266,8 +1275,8 @@ int SocketStream::AllowCertErrorForReconnection(SSLConfig* ssl_config) { bad_cert.cert_status = ssl_info.cert_status; ssl_config->allowed_bad_certs.push_back(bad_cert); // Restart connection ignoring the bad certificate. - socket_->Disconnect(); - socket_.reset(); + connection_->socket()->Disconnect(); + connection_->SetSocket(scoped_ptr<StreamSocket>()); next_state_ = STATE_TCP_CONNECT; return OK; } @@ -1293,7 +1302,8 @@ void SocketStream::DoRestartWithAuth() { int SocketStream::HandleCertificateError(int result) { DCHECK(IsCertificateError(result)); - SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(socket_.get()); + SSLClientSocket* ssl_socket = + static_cast<SSLClientSocket*>(connection_->socket()); DCHECK(ssl_socket); if (!context_) diff --git a/net/socket_stream/socket_stream.h b/net/socket_stream/socket_stream.h index 5004060..90aeb8c 100644 --- a/net/socket_stream/socket_stream.h +++ b/net/socket_stream/socket_stream.h @@ -28,13 +28,13 @@ namespace net { class AuthChallengeInfo; class CertVerifier; class ClientSocketFactory; +class ClientSocketHandle; class CookieOptions; class HostResolver; class HttpAuthController; class SSLInfo; class ServerBoundCertService; class SingleRequestHostResolver; -class StreamSocket; class SocketStreamMetrics; class TransportSecurityState; class URLRequestContext; @@ -364,7 +364,7 @@ class NET_EXPORT SocketStream scoped_ptr<SingleRequestHostResolver> resolver_; AddressList addresses_; - scoped_ptr<StreamSocket> socket_; + scoped_ptr<ClientSocketHandle> connection_; SSLConfig server_ssl_config_; SSLConfig proxy_ssl_config_; diff --git a/net/spdy/spdy_test_util_common.cc b/net/spdy/spdy_test_util_common.cc index 4383db0..9546684 100644 --- a/net/spdy/spdy_test_util_common.cc +++ b/net/spdy/spdy_test_util_common.cc @@ -649,8 +649,8 @@ base::WeakPtr<SpdySession> CreateFakeSpdySessionHelper( EXPECT_FALSE(HasSpdySession(pool, key)); base::WeakPtr<SpdySession> spdy_session; scoped_ptr<ClientSocketHandle> handle(new ClientSocketHandle()); - handle->set_socket(new FakeSpdySessionClientSocket( - expected_status == OK ? ERR_IO_PENDING : expected_status)); + handle->SetSocket(scoped_ptr<StreamSocket>(new FakeSpdySessionClientSocket( + expected_status == OK ? ERR_IO_PENDING : expected_status))); EXPECT_EQ( expected_status, pool->CreateAvailableSessionFromSocket( diff --git a/remoting/protocol/ssl_hmac_channel_authenticator.cc b/remoting/protocol/ssl_hmac_channel_authenticator.cc index d5db72f..20bfc53b 100644 --- a/remoting/protocol/ssl_hmac_channel_authenticator.cc +++ b/remoting/protocol/ssl_hmac_channel_authenticator.cc @@ -73,16 +73,16 @@ void SslHmacChannelAuthenticator::SecureAndAuthenticate( return; } - net::SSLConfig ssl_config; - net::SSLServerSocket* server_socket = - net::CreateSSLServerSocket(socket.release(), + scoped_ptr<net::SSLServerSocket> server_socket = + net::CreateSSLServerSocket(socket.Pass(), cert.get(), local_key_pair_->private_key(), - ssl_config); - socket_.reset(server_socket); - - result = server_socket->Handshake(base::Bind( - &SslHmacChannelAuthenticator::OnConnected, base::Unretained(this))); + net::SSLConfig()); + net::SSLServerSocket* raw_server_socket = server_socket.get(); + socket_ = server_socket.Pass(); + result = raw_server_socket->Handshake( + base::Bind(&SslHmacChannelAuthenticator::OnConnected, + base::Unretained(this))); } else { cert_verifier_.reset(net::CertVerifier::CreateDefault()); transport_security_state_.reset(new net::TransportSecurityState); @@ -105,10 +105,10 @@ void SslHmacChannelAuthenticator::SecureAndAuthenticate( context.cert_verifier = cert_verifier_.get(); context.transport_security_state = transport_security_state_.get(); scoped_ptr<net::ClientSocketHandle> connection(new net::ClientSocketHandle); - connection->set_socket(socket.release()); - socket_.reset( + connection->SetSocket(socket.Pass()); + socket_ = net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket( - connection.release(), host_and_port, ssl_config, context)); + connection.Pass(), host_and_port, ssl_config, context); result = socket_->Connect( base::Bind(&SslHmacChannelAuthenticator::OnConnected, |