summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorakalin@chromium.org <akalin@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2013-08-15 00:13:44 +0000
committerakalin@chromium.org <akalin@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2013-08-15 00:13:44 +0000
commit18ccfdb7574c4868e37f53386454277e3e63bbe8 (patch)
treef1e177773e0b1cdc80deb3d755a8d7baf1233df6
parent582a8575e5259762d5cb7b517b928ed7fc75ca11 (diff)
downloadchromium_src-18ccfdb7574c4868e37f53386454277e3e63bbe8.zip
chromium_src-18ccfdb7574c4868e37f53386454277e3e63bbe8.tar.gz
chromium_src-18ccfdb7574c4868e37f53386454277e3e63bbe8.tar.bz2
[net] Use scoped_ptr<> consistently in ClientSocketFactory and related code
This will make it easier to modify ClientSocketFactory et al. to support reprioritization. This also fixes a few latent memory leaks in tests. Make SocketStream use a ClientSocketHandle instead of just a StreamSocket. Rename {set,release}_socket() to {Set,Pass}Socket(). BUG=166689 TBR=eroman@chromium.org, rsleevi@chromium.org, sergeyu@chromium.org Review URL: https://codereview.chromium.org/22995002 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@217707 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r--chrome/browser/net/network_stats.cc17
-rw-r--r--content/browser/renderer_host/p2p/socket_host_tcp.cc10
-rw-r--r--content/browser/renderer_host/pepper/pepper_tcp_socket.cc8
-rw-r--r--jingle/glue/chrome_async_socket.cc10
-rw-r--r--jingle/glue/chrome_async_socket_unittest.cc8
-rw-r--r--jingle/glue/fake_ssl_client_socket.cc4
-rw-r--r--jingle/glue/fake_ssl_client_socket.h3
-rw-r--r--jingle/glue/fake_ssl_client_socket_unittest.cc7
-rw-r--r--jingle/glue/resolving_client_socket_factory.h7
-rw-r--r--jingle/glue/xmpp_client_socket_factory.cc26
-rw-r--r--jingle/glue/xmpp_client_socket_factory.h6
-rw-r--r--net/dns/address_sorter_posix_unittest.cc16
-rw-r--r--net/dns/dns_session_unittest.cc24
-rw-r--r--net/dns/dns_socket_pool.cc4
-rw-r--r--net/dns/dns_transaction_unittest.cc14
-rw-r--r--net/ftp/ftp_network_transaction.cc8
-rw-r--r--net/http/http_network_transaction_unittest.cc2
-rw-r--r--net/http/http_pipelined_host_forced.cc7
-rw-r--r--net/http/http_proxy_client_socket_pool.cc25
-rw-r--r--net/http/http_proxy_client_socket_pool.h4
-rw-r--r--net/http/http_stream_factory_impl_unittest.cc2
-rw-r--r--net/http/http_stream_parser_unittest.cc4
-rw-r--r--net/quic/quic_client_session.cc4
-rw-r--r--net/quic/quic_client_session.h3
-rw-r--r--net/quic/quic_client_session_test.cc6
-rw-r--r--net/quic/quic_http_stream_test.cc10
-rw-r--r--net/quic/quic_stream_factory.cc8
-rw-r--r--net/socket/buffered_write_stream_socket.cc4
-rw-r--r--net/socket/buffered_write_stream_socket.h6
-rw-r--r--net/socket/buffered_write_stream_socket_unittest.cc7
-rw-r--r--net/socket/client_socket_factory.cc32
-rw-r--r--net/socket/client_socket_factory.h9
-rw-r--r--net/socket/client_socket_handle.cc10
-rw-r--r--net/socket/client_socket_handle.h6
-rw-r--r--net/socket/client_socket_pool.h3
-rw-r--r--net/socket/client_socket_pool_base.cc75
-rw-r--r--net/socket/client_socket_pool_base.h43
-rw-r--r--net/socket/client_socket_pool_base_unittest.cc49
-rw-r--r--net/socket/socket_test_util.cc100
-rw-r--r--net/socket/socket_test_util.h26
-rw-r--r--net/socket/socks5_client_socket.cc6
-rw-r--r--net/socket/socks5_client_socket.h5
-rw-r--r--net/socket/socks5_client_socket_unittest.cc66
-rw-r--r--net/socket/socks_client_socket.cc11
-rw-r--r--net/socket/socks_client_socket.h5
-rw-r--r--net/socket/socks_client_socket_pool.cc28
-rw-r--r--net/socket/socks_client_socket_pool.h4
-rw-r--r--net/socket/socks_client_socket_unittest.cc92
-rw-r--r--net/socket/ssl_client_socket_nss.cc6
-rw-r--r--net/socket/ssl_client_socket_nss.h2
-rw-r--r--net/socket/ssl_client_socket_openssl.cc6
-rw-r--r--net/socket/ssl_client_socket_openssl.h2
-rw-r--r--net/socket/ssl_client_socket_openssl_unittest.cc14
-rw-r--r--net/socket/ssl_client_socket_pool.cc25
-rw-r--r--net/socket/ssl_client_socket_pool.h4
-rw-r--r--net/socket/ssl_client_socket_unittest.cc190
-rw-r--r--net/socket/ssl_server_socket.h5
-rw-r--r--net/socket/ssl_server_socket_nss.cc11
-rw-r--r--net/socket/ssl_server_socket_nss.h2
-rw-r--r--net/socket/ssl_server_socket_openssl.cc12
-rw-r--r--net/socket/ssl_server_socket_unittest.cc18
-rw-r--r--net/socket/transport_client_socket_pool.cc33
-rw-r--r--net/socket/transport_client_socket_pool.h4
-rw-r--r--net/socket/transport_client_socket_pool_unittest.cc42
-rw-r--r--net/socket/transport_client_socket_unittest.cc4
-rw-r--r--net/socket_stream/socket_stream.cc88
-rw-r--r--net/socket_stream/socket_stream.h4
-rw-r--r--net/spdy/spdy_test_util_common.cc4
-rw-r--r--remoting/protocol/ssl_hmac_channel_authenticator.cc22
69 files changed, 727 insertions, 605 deletions
diff --git a/chrome/browser/net/network_stats.cc b/chrome/browser/net/network_stats.cc
index 23a8a6a..1bc3161e 100644
--- a/chrome/browser/net/network_stats.cc
+++ b/chrome/browser/net/network_stats.cc
@@ -189,28 +189,25 @@ bool NetworkStats::DoConnect(int result) {
return false;
}
- net::DatagramClientSocket* udp_socket =
+ scoped_ptr<net::DatagramClientSocket> udp_socket =
socket_factory_->CreateDatagramClientSocket(
net::DatagramSocket::DEFAULT_BIND,
net::RandIntCallback(),
NULL,
net::NetLog::Source());
- if (!udp_socket) {
- TestPhaseComplete(SOCKET_CREATE_FAILED, net::ERR_INVALID_ARGUMENT);
- return false;
- }
- DCHECK(!socket_.get());
- socket_.reset(udp_socket);
+ DCHECK(udp_socket);
+ DCHECK(!socket_);
+ socket_ = udp_socket.Pass();
const net::IPEndPoint& endpoint = addresses_.front();
- int rv = udp_socket->Connect(endpoint);
+ int rv = socket_->Connect(endpoint);
if (rv < 0) {
TestPhaseComplete(CONNECT_FAILED, rv);
return false;
}
- udp_socket->SetSendBufferSize(kMaxUdpSendBufferSize);
- udp_socket->SetReceiveBufferSize(kMaxUdpReceiveBufferSize);
+ socket_->SetSendBufferSize(kMaxUdpSendBufferSize);
+ socket_->SetReceiveBufferSize(kMaxUdpReceiveBufferSize);
return ConnectComplete(rv);
}
diff --git a/content/browser/renderer_host/p2p/socket_host_tcp.cc b/content/browser/renderer_host/p2p/socket_host_tcp.cc
index 8e13bf8..314026e 100644
--- a/content/browser/renderer_host/p2p/socket_host_tcp.cc
+++ b/content/browser/renderer_host/p2p/socket_host_tcp.cc
@@ -136,7 +136,9 @@ void P2PSocketHostTcpBase::OnConnected(int result) {
StartTls();
} else {
if (IsPseudoTlsClientSocket(type_)) {
- socket_.reset(new jingle_glue::FakeSSLClientSocket(socket_.release()));
+ scoped_ptr<net::StreamSocket> transport_socket = socket_.Pass();
+ socket_.reset(
+ new jingle_glue::FakeSSLClientSocket(transport_socket.Pass()));
}
// If we are not doing TLS, we are ready to send data now.
@@ -155,7 +157,7 @@ void P2PSocketHostTcpBase::StartTls() {
scoped_ptr<net::ClientSocketHandle> socket_handle(
new net::ClientSocketHandle());
- socket_handle->set_socket(socket_.release());
+ socket_handle->SetSocket(socket_.Pass());
net::SSLClientSocketContext context;
context.cert_verifier = url_context_->GetURLRequestContext()->cert_verifier();
@@ -171,8 +173,8 @@ void P2PSocketHostTcpBase::StartTls() {
net::ClientSocketFactory::GetDefaultFactory();
DCHECK(socket_factory);
- socket_.reset(socket_factory->CreateSSLClientSocket(
- socket_handle.release(), dest_host_port_pair, ssl_config, context));
+ socket_ = socket_factory->CreateSSLClientSocket(
+ socket_handle.Pass(), dest_host_port_pair, ssl_config, context);
int status = socket_->Connect(
base::Bind(&P2PSocketHostTcpBase::ProcessTlsConnectDone,
base::Unretained(this)));
diff --git a/content/browser/renderer_host/pepper/pepper_tcp_socket.cc b/content/browser/renderer_host/pepper/pepper_tcp_socket.cc
index f05be87..05b2e09 100644
--- a/content/browser/renderer_host/pepper/pepper_tcp_socket.cc
+++ b/content/browser/renderer_host/pepper/pepper_tcp_socket.cc
@@ -141,16 +141,16 @@ void PepperTCPSocket::SSLHandshake(
connection_state_ = SSL_HANDSHAKE_IN_PROGRESS;
// TODO(raymes,rsleevi): Use trusted/untrusted certificates when connecting.
- net::ClientSocketHandle* handle = new net::ClientSocketHandle();
- handle->set_socket(socket_.release());
+ scoped_ptr<net::ClientSocketHandle> handle(new net::ClientSocketHandle());
+ handle->SetSocket(socket_.Pass());
net::ClientSocketFactory* factory =
net::ClientSocketFactory::GetDefaultFactory();
net::HostPortPair host_port_pair(server_name, server_port);
net::SSLClientSocketContext ssl_context;
ssl_context.cert_verifier = manager_->GetCertVerifier();
ssl_context.transport_security_state = manager_->GetTransportSecurityState();
- socket_.reset(factory->CreateSSLClientSocket(
- handle, host_port_pair, manager_->ssl_config(), ssl_context));
+ socket_ = factory->CreateSSLClientSocket(
+ handle.Pass(), host_port_pair, manager_->ssl_config(), ssl_context);
if (!socket_) {
LOG(WARNING) << "Failed to create an SSL client socket.";
OnSSLHandshakeCompleted(net::ERR_UNEXPECTED);
diff --git a/jingle/glue/chrome_async_socket.cc b/jingle/glue/chrome_async_socket.cc
index 39085e1..c14fb99 100644
--- a/jingle/glue/chrome_async_socket.cc
+++ b/jingle/glue/chrome_async_socket.cc
@@ -106,9 +106,9 @@ bool ChromeAsyncSocket::Connect(const talk_base::SocketAddress& address) {
net::HostPortPair dest_host_port_pair(address.hostname(), address.port());
- transport_socket_.reset(
+ transport_socket_ =
resolving_client_socket_factory_->CreateTransportClientSocket(
- dest_host_port_pair));
+ dest_host_port_pair);
int status = transport_socket_->Connect(
base::Bind(&ChromeAsyncSocket::ProcessConnectDone,
weak_ptr_factory_.GetWeakPtr()));
@@ -404,10 +404,10 @@ bool ChromeAsyncSocket::StartTls(const std::string& domain_name) {
DCHECK(transport_socket_.get());
scoped_ptr<net::ClientSocketHandle> socket_handle(
new net::ClientSocketHandle());
- socket_handle->set_socket(transport_socket_.release());
- transport_socket_.reset(
+ socket_handle->SetSocket(transport_socket_.Pass());
+ transport_socket_ =
resolving_client_socket_factory_->CreateSSLClientSocket(
- socket_handle.release(), net::HostPortPair(domain_name, 443)));
+ socket_handle.Pass(), net::HostPortPair(domain_name, 443));
int status = transport_socket_->Connect(
base::Bind(&ChromeAsyncSocket::ProcessSSLConnectDone,
weak_ptr_factory_.GetWeakPtr()));
diff --git a/jingle/glue/chrome_async_socket_unittest.cc b/jingle/glue/chrome_async_socket_unittest.cc
index ebb69a2..db3d2b0 100644
--- a/jingle/glue/chrome_async_socket_unittest.cc
+++ b/jingle/glue/chrome_async_socket_unittest.cc
@@ -113,20 +113,20 @@ class MockXmppClientSocketFactory : public ResolvingClientSocketFactory {
}
// ResolvingClientSocketFactory implementation.
- virtual net::StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<net::StreamSocket> CreateTransportClientSocket(
const net::HostPortPair& host_and_port) OVERRIDE {
return mock_client_socket_factory_->CreateTransportClientSocket(
address_list_, NULL, net::NetLog::Source());
}
- virtual net::SSLClientSocket* CreateSSLClientSocket(
- net::ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<net::SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<net::ClientSocketHandle> transport_socket,
const net::HostPortPair& host_and_port) OVERRIDE {
net::SSLClientSocketContext context;
context.cert_verifier = cert_verifier_.get();
context.transport_security_state = transport_security_state_.get();
return mock_client_socket_factory_->CreateSSLClientSocket(
- transport_socket, host_and_port, ssl_config_, context);
+ transport_socket.Pass(), host_and_port, ssl_config_, context);
}
private:
diff --git a/jingle/glue/fake_ssl_client_socket.cc b/jingle/glue/fake_ssl_client_socket.cc
index bf6d12a..9d722c7 100644
--- a/jingle/glue/fake_ssl_client_socket.cc
+++ b/jingle/glue/fake_ssl_client_socket.cc
@@ -77,8 +77,8 @@ base::StringPiece FakeSSLClientSocket::GetSslServerHello() {
}
FakeSSLClientSocket::FakeSSLClientSocket(
- net::StreamSocket* transport_socket)
- : transport_socket_(transport_socket),
+ scoped_ptr<net::StreamSocket> transport_socket)
+ : transport_socket_(transport_socket.Pass()),
next_handshake_state_(STATE_NONE),
handshake_completed_(false),
write_buf_(NewDrainableIOBufferWithSize(arraysize(kSslClientHello))),
diff --git a/jingle/glue/fake_ssl_client_socket.h b/jingle/glue/fake_ssl_client_socket.h
index 5bc4547..54a9e2f 100644
--- a/jingle/glue/fake_ssl_client_socket.h
+++ b/jingle/glue/fake_ssl_client_socket.h
@@ -36,8 +36,7 @@ namespace jingle_glue {
class FakeSSLClientSocket : public net::StreamSocket {
public:
- // Takes ownership of |transport_socket|.
- explicit FakeSSLClientSocket(net::StreamSocket* transport_socket);
+ explicit FakeSSLClientSocket(scoped_ptr<net::StreamSocket> transport_socket);
virtual ~FakeSSLClientSocket();
diff --git a/jingle/glue/fake_ssl_client_socket_unittest.cc b/jingle/glue/fake_ssl_client_socket_unittest.cc
index 5c061f3..f6d8fea 100644
--- a/jingle/glue/fake_ssl_client_socket_unittest.cc
+++ b/jingle/glue/fake_ssl_client_socket_unittest.cc
@@ -91,7 +91,7 @@ class FakeSSLClientSocketTest : public testing::Test {
virtual ~FakeSSLClientSocketTest() {}
- net::StreamSocket* MakeClientSocket() {
+ scoped_ptr<net::StreamSocket> MakeClientSocket() {
return mock_client_socket_factory_.CreateTransportClientSocket(
net::AddressList(), NULL, net::NetLog::Source());
}
@@ -269,7 +269,7 @@ class FakeSSLClientSocketTest : public testing::Test {
};
TEST_F(FakeSSLClientSocketTest, PassThroughMethods) {
- MockClientSocket* mock_client_socket = new MockClientSocket();
+ scoped_ptr<MockClientSocket> mock_client_socket(new MockClientSocket());
const int kReceiveBufferSize = 10;
const int kSendBufferSize = 20;
net::IPEndPoint ip_endpoint(net::IPAddressNumber(net::kIPv4AddressSize), 80);
@@ -284,7 +284,8 @@ TEST_F(FakeSSLClientSocketTest, PassThroughMethods) {
EXPECT_CALL(*mock_client_socket, SetOmniboxSpeculation());
// Takes ownership of |mock_client_socket|.
- FakeSSLClientSocket fake_ssl_client_socket(mock_client_socket);
+ FakeSSLClientSocket fake_ssl_client_socket(
+ mock_client_socket.PassAs<net::StreamSocket>());
fake_ssl_client_socket.SetReceiveBufferSize(kReceiveBufferSize);
fake_ssl_client_socket.SetSendBufferSize(kSendBufferSize);
EXPECT_EQ(kPeerAddress,
diff --git a/jingle/glue/resolving_client_socket_factory.h b/jingle/glue/resolving_client_socket_factory.h
index 5be8bc8..d1b9fc1 100644
--- a/jingle/glue/resolving_client_socket_factory.h
+++ b/jingle/glue/resolving_client_socket_factory.h
@@ -5,6 +5,7 @@
#ifndef JINGLE_GLUE_RESOLVING_CLIENT_SOCKET_FACTORY_H_
#define JINGLE_GLUE_RESOLVING_CLIENT_SOCKET_FACTORY_H_
+#include "base/memory/scoped_ptr.h"
namespace net {
class ClientSocketHandle;
@@ -23,11 +24,11 @@ class ResolvingClientSocketFactory {
public:
virtual ~ResolvingClientSocketFactory() { }
// Method to create a transport socket using a HostPortPair.
- virtual net::StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<net::StreamSocket> CreateTransportClientSocket(
const net::HostPortPair& host_and_port) = 0;
- virtual net::SSLClientSocket* CreateSSLClientSocket(
- net::ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<net::SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<net::ClientSocketHandle> transport_socket,
const net::HostPortPair& host_and_port) = 0;
};
diff --git a/jingle/glue/xmpp_client_socket_factory.cc b/jingle/glue/xmpp_client_socket_factory.cc
index b9e040d..4823ee5 100644
--- a/jingle/glue/xmpp_client_socket_factory.cc
+++ b/jingle/glue/xmpp_client_socket_factory.cc
@@ -8,6 +8,7 @@
#include "jingle/glue/fake_ssl_client_socket.h"
#include "jingle/glue/proxy_resolving_client_socket.h"
#include "net/socket/client_socket_factory.h"
+#include "net/socket/client_socket_handle.h"
#include "net/socket/ssl_client_socket.h"
#include "net/url_request/url_request_context.h"
#include "net/url_request/url_request_context_getter.h"
@@ -28,20 +29,25 @@ XmppClientSocketFactory::XmppClientSocketFactory(
XmppClientSocketFactory::~XmppClientSocketFactory() {}
-net::StreamSocket* XmppClientSocketFactory::CreateTransportClientSocket(
+scoped_ptr<net::StreamSocket>
+XmppClientSocketFactory::CreateTransportClientSocket(
const net::HostPortPair& host_and_port) {
// TODO(akalin): Use socket pools.
- net::StreamSocket* transport_socket = new ProxyResolvingClientSocket(
- NULL,
- request_context_getter_,
- ssl_config_,
- host_and_port);
+ scoped_ptr<net::StreamSocket> transport_socket(
+ new ProxyResolvingClientSocket(
+ NULL,
+ request_context_getter_,
+ ssl_config_,
+ host_and_port));
return (use_fake_ssl_client_socket_ ?
- new FakeSSLClientSocket(transport_socket) : transport_socket);
+ scoped_ptr<net::StreamSocket>(
+ new FakeSSLClientSocket(transport_socket.Pass())) :
+ transport_socket.Pass());
}
-net::SSLClientSocket* XmppClientSocketFactory::CreateSSLClientSocket(
- net::ClientSocketHandle* transport_socket,
+scoped_ptr<net::SSLClientSocket>
+XmppClientSocketFactory::CreateSSLClientSocket(
+ scoped_ptr<net::ClientSocketHandle> transport_socket,
const net::HostPortPair& host_and_port) {
net::SSLClientSocketContext context;
context.cert_verifier =
@@ -52,7 +58,7 @@ net::SSLClientSocket* XmppClientSocketFactory::CreateSSLClientSocket(
// TODO(rkn): context.server_bound_cert_service is NULL because the
// ServerBoundCertService class is not thread safe.
return client_socket_factory_->CreateSSLClientSocket(
- transport_socket, host_and_port, ssl_config_, context);
+ transport_socket.Pass(), host_and_port, ssl_config_, context);
}
diff --git a/jingle/glue/xmpp_client_socket_factory.h b/jingle/glue/xmpp_client_socket_factory.h
index c2a0d6a..4204c19 100644
--- a/jingle/glue/xmpp_client_socket_factory.h
+++ b/jingle/glue/xmpp_client_socket_factory.h
@@ -35,11 +35,11 @@ class XmppClientSocketFactory : public ResolvingClientSocketFactory {
virtual ~XmppClientSocketFactory();
// ResolvingClientSocketFactory implementation.
- virtual net::StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<net::StreamSocket> CreateTransportClientSocket(
const net::HostPortPair& host_and_port) OVERRIDE;
- virtual net::SSLClientSocket* CreateSSLClientSocket(
- net::ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<net::SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<net::ClientSocketHandle> transport_socket,
const net::HostPortPair& host_and_port) OVERRIDE;
private:
diff --git a/net/dns/address_sorter_posix_unittest.cc b/net/dns/address_sorter_posix_unittest.cc
index 96cbfc6..c451737 100644
--- a/net/dns/address_sorter_posix_unittest.cc
+++ b/net/dns/address_sorter_posix_unittest.cc
@@ -10,6 +10,8 @@
#include "net/base/net_util.h"
#include "net/base/test_completion_callback.h"
#include "net/socket/client_socket_factory.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/socket/stream_socket.h"
#include "net/udp/datagram_client_socket.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -90,27 +92,27 @@ class TestSocketFactory : public ClientSocketFactory {
TestSocketFactory() {}
virtual ~TestSocketFactory() {}
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType,
const RandIntCallback&,
NetLog*,
const NetLog::Source&) OVERRIDE {
- return new TestUDPClientSocket(&mapping_);
+ return scoped_ptr<DatagramClientSocket>(new TestUDPClientSocket(&mapping_));
}
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList&,
NetLog*,
const NetLog::Source&) OVERRIDE {
NOTIMPLEMENTED();
- return NULL;
+ return scoped_ptr<StreamSocket>();
}
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle*,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle>,
const HostPortPair&,
const SSLConfig&,
const SSLClientSocketContext&) OVERRIDE {
NOTIMPLEMENTED();
- return NULL;
+ return scoped_ptr<SSLClientSocket>();
}
virtual void ClearSSLSessionCache() OVERRIDE {
NOTIMPLEMENTED();
diff --git a/net/dns/dns_session_unittest.cc b/net/dns/dns_session_unittest.cc
index 4662706..ed726f2 100644
--- a/net/dns/dns_session_unittest.cc
+++ b/net/dns/dns_session_unittest.cc
@@ -14,6 +14,8 @@
#include "net/dns/dns_protocol.h"
#include "net/dns/dns_socket_pool.h"
#include "net/socket/socket_test_util.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/socket/stream_socket.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
@@ -24,26 +26,26 @@ class TestClientSocketFactory : public ClientSocketFactory {
public:
virtual ~TestClientSocketFactory();
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
net::NetLog* net_log,
const net::NetLog::Source& source) OVERRIDE;
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog*, const NetLog::Source&) OVERRIDE {
NOTIMPLEMENTED();
- return NULL;
+ return scoped_ptr<StreamSocket>();
}
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) OVERRIDE {
NOTIMPLEMENTED();
- return NULL;
+ return scoped_ptr<SSLClientSocket>();
}
virtual void ClearSSLSessionCache() OVERRIDE {
@@ -179,7 +181,8 @@ bool DnsSessionTest::ExpectEvent(const PoolEvent& expected) {
return true;
}
-DatagramClientSocket* TestClientSocketFactory::CreateDatagramClientSocket(
+scoped_ptr<DatagramClientSocket>
+TestClientSocketFactory::CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
net::NetLog* net_log,
@@ -188,9 +191,10 @@ DatagramClientSocket* TestClientSocketFactory::CreateDatagramClientSocket(
// simplest SocketDataProvider with no data supplied.
SocketDataProvider* data_provider = new StaticSocketDataProvider();
data_providers_.push_back(data_provider);
- MockUDPClientSocket* socket = new MockUDPClientSocket(data_provider, net_log);
- data_provider->set_socket(socket);
- return socket;
+ scoped_ptr<MockUDPClientSocket> socket(
+ new MockUDPClientSocket(data_provider, net_log));
+ data_provider->set_socket(socket.get());
+ return socket.PassAs<DatagramClientSocket>();
}
TestClientSocketFactory::~TestClientSocketFactory() {
diff --git a/net/dns/dns_socket_pool.cc b/net/dns/dns_socket_pool.cc
index 64570fc..7a7ecd6 100644
--- a/net/dns/dns_socket_pool.cc
+++ b/net/dns/dns_socket_pool.cc
@@ -76,8 +76,8 @@ scoped_ptr<DatagramClientSocket> DnsSocketPool::CreateConnectedSocket(
scoped_ptr<DatagramClientSocket> socket;
NetLog::Source no_source;
- socket.reset(socket_factory_->CreateDatagramClientSocket(
- kBindType, base::Bind(&base::RandInt), net_log_, no_source));
+ socket = socket_factory_->CreateDatagramClientSocket(
+ kBindType, base::Bind(&base::RandInt), net_log_, no_source);
if (socket.get()) {
int rv = socket->Connect((*nameservers_)[server_index]);
diff --git a/net/dns/dns_transaction_unittest.cc b/net/dns/dns_transaction_unittest.cc
index f9667ee..7040e44 100644
--- a/net/dns/dns_transaction_unittest.cc
+++ b/net/dns/dns_transaction_unittest.cc
@@ -180,21 +180,21 @@ class TestSocketFactory : public MockClientSocketFactory {
TestSocketFactory() : fail_next_socket_(false) {}
virtual ~TestSocketFactory() {}
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
net::NetLog* net_log,
const net::NetLog::Source& source) OVERRIDE {
if (fail_next_socket_) {
fail_next_socket_ = false;
- return new FailingUDPClientSocket(&empty_data_, net_log);
+ return scoped_ptr<DatagramClientSocket>(
+ new FailingUDPClientSocket(&empty_data_, net_log));
}
SocketDataProvider* data_provider = mock_data().GetNext();
- TestUDPClientSocket* socket = new TestUDPClientSocket(this,
- data_provider,
- net_log);
- data_provider->set_socket(socket);
- return socket;
+ scoped_ptr<TestUDPClientSocket> socket(
+ new TestUDPClientSocket(this, data_provider, net_log));
+ data_provider->set_socket(socket.get());
+ return socket.PassAs<DatagramClientSocket>();
}
void OnConnect(const IPEndPoint& endpoint) {
diff --git a/net/ftp/ftp_network_transaction.cc b/net/ftp/ftp_network_transaction.cc
index ccd6e2e..f9f7b82 100644
--- a/net/ftp/ftp_network_transaction.cc
+++ b/net/ftp/ftp_network_transaction.cc
@@ -663,8 +663,8 @@ int FtpNetworkTransaction::DoCtrlResolveHostComplete(int result) {
int FtpNetworkTransaction::DoCtrlConnect() {
next_state_ = STATE_CTRL_CONNECT_COMPLETE;
- ctrl_socket_.reset(socket_factory_->CreateTransportClientSocket(
- addresses_, net_log_.net_log(), net_log_.source()));
+ ctrl_socket_ = socket_factory_->CreateTransportClientSocket(
+ addresses_, net_log_.net_log(), net_log_.source());
net_log_.AddEvent(
NetLog::TYPE_FTP_CONTROL_CONNECTION,
ctrl_socket_->NetLog().source().ToEventParametersCallback());
@@ -1249,8 +1249,8 @@ int FtpNetworkTransaction::DoDataConnect() {
return Stop(rv);
data_address = AddressList::CreateFromIPAddress(
ip_endpoint.address(), data_connection_port_);
- data_socket_.reset(socket_factory_->CreateTransportClientSocket(
- data_address, net_log_.net_log(), net_log_.source()));
+ data_socket_ = socket_factory_->CreateTransportClientSocket(
+ data_address, net_log_.net_log(), net_log_.source());
net_log_.AddEvent(
NetLog::TYPE_FTP_DATA_CONNECTION,
data_socket_->NetLog().source().ToEventParametersCallback());
diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc
index 0968b14..d89ab54 100644
--- a/net/http/http_network_transaction_unittest.cc
+++ b/net/http/http_network_transaction_unittest.cc
@@ -452,7 +452,7 @@ class CaptureGroupNameSocketPool : public ParentPool {
virtual void CancelRequest(const std::string& group_name,
ClientSocketHandle* handle) {}
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) {}
virtual void CloseIdleSockets() {}
virtual int IdleSocketCount() const {
diff --git a/net/http/http_pipelined_host_forced.cc b/net/http/http_pipelined_host_forced.cc
index 8179e86..8059d84 100644
--- a/net/http/http_pipelined_host_forced.cc
+++ b/net/http/http_pipelined_host_forced.cc
@@ -36,10 +36,9 @@ HttpPipelinedStream* HttpPipelinedHostForced::CreateStreamOnNewPipeline(
bool was_npn_negotiated,
NextProto protocol_negotiated) {
CHECK(!pipeline_.get());
- StreamSocket* wrapped_socket = connection->release_socket();
- BufferedWriteStreamSocket* buffered_socket = new BufferedWriteStreamSocket(
- wrapped_socket);
- connection->set_socket(buffered_socket);
+ scoped_ptr<BufferedWriteStreamSocket> buffered_socket(
+ new BufferedWriteStreamSocket(connection->PassSocket()));
+ connection->SetSocket(buffered_socket.PassAs<StreamSocket>());
pipeline_.reset(factory_->CreateNewPipeline(
connection, this, key_.origin(), used_ssl_config, used_proxy_info,
net_log, was_npn_negotiated, protocol_negotiated));
diff --git a/net/http/http_proxy_client_socket_pool.cc b/net/http/http_proxy_client_socket_pool.cc
index b80df37..c75df6f 100644
--- a/net/http/http_proxy_client_socket_pool.cc
+++ b/net/http/http_proxy_client_socket_pool.cc
@@ -289,7 +289,7 @@ int HttpProxyConnectJob::DoHttpProxyConnect() {
int HttpProxyConnectJob::DoHttpProxyConnectComplete(int result) {
if (result == OK || result == ERR_PROXY_AUTH_REQUESTED ||
result == ERR_HTTPS_PROXY_TUNNEL_RESPONSE) {
- set_socket(transport_socket_.release());
+ SetSocket(transport_socket_.PassAs<StreamSocket>());
}
return result;
@@ -380,19 +380,19 @@ HttpProxyConnectJobFactory::HttpProxyConnectJobFactory(
}
-ConnectJob*
+scoped_ptr<ConnectJob>
HttpProxyClientSocketPool::HttpProxyConnectJobFactory::NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const {
- return new HttpProxyConnectJob(group_name,
- request.params(),
- ConnectionTimeout(),
- transport_pool_,
- ssl_pool_,
- host_resolver_,
- delegate,
- net_log_);
+ return scoped_ptr<ConnectJob>(new HttpProxyConnectJob(group_name,
+ request.params(),
+ ConnectionTimeout(),
+ transport_pool_,
+ ssl_pool_,
+ host_resolver_,
+ delegate,
+ net_log_));
}
base::TimeDelta
@@ -462,8 +462,9 @@ void HttpProxyClientSocketPool::CancelRequest(
}
void HttpProxyClientSocketPool::ReleaseSocket(const std::string& group_name,
- StreamSocket* socket, int id) {
- base_.ReleaseSocket(group_name, socket, id);
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ base_.ReleaseSocket(group_name, socket.Pass(), id);
}
void HttpProxyClientSocketPool::FlushWithError(int error) {
diff --git a/net/http/http_proxy_client_socket_pool.h b/net/http/http_proxy_client_socket_pool.h
index a15b8ca..b77b5ae 100644
--- a/net/http/http_proxy_client_socket_pool.h
+++ b/net/http/http_proxy_client_socket_pool.h
@@ -204,7 +204,7 @@ class NET_EXPORT_PRIVATE HttpProxyClientSocketPool
ClientSocketHandle* handle) OVERRIDE;
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) OVERRIDE;
virtual void FlushWithError(int error) OVERRIDE;
@@ -250,7 +250,7 @@ class NET_EXPORT_PRIVATE HttpProxyClientSocketPool
NetLog* net_log);
// ClientSocketPoolBase::ConnectJobFactory methods.
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const OVERRIDE;
diff --git a/net/http/http_stream_factory_impl_unittest.cc b/net/http/http_stream_factory_impl_unittest.cc
index 14fbc03..f378c93 100644
--- a/net/http/http_stream_factory_impl_unittest.cc
+++ b/net/http/http_stream_factory_impl_unittest.cc
@@ -314,7 +314,7 @@ class CapturePreconnectsSocketPool : public ParentPool {
ADD_FAILURE();
}
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) OVERRIDE {
ADD_FAILURE();
}
diff --git a/net/http/http_stream_parser_unittest.cc b/net/http/http_stream_parser_unittest.cc
index d530c2d..8477594 100644
--- a/net/http/http_stream_parser_unittest.cc
+++ b/net/http/http_stream_parser_unittest.cc
@@ -220,7 +220,7 @@ TEST(HttpStreamParser, AsyncChunkAndAsyncSocket) {
ASSERT_EQ(OK, rv);
scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle);
- socket_handle->set_socket(transport.release());
+ socket_handle->SetSocket(transport.PassAs<StreamSocket>());
HttpRequestInfo request_info;
request_info.method = "GET";
@@ -375,7 +375,7 @@ TEST(HttpStreamParser, TruncatedHeaders) {
ASSERT_EQ(OK, rv);
scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle);
- socket_handle->set_socket(transport.release());
+ socket_handle->SetSocket(transport.PassAs<StreamSocket>());
HttpRequestInfo request_info;
request_info.method = "GET";
diff --git a/net/quic/quic_client_session.cc b/net/quic/quic_client_session.cc
index f6641d0..1ba03a3 100644
--- a/net/quic/quic_client_session.cc
+++ b/net/quic/quic_client_session.cc
@@ -81,7 +81,7 @@ void QuicClientSession::StreamRequest::OnRequestCompleteFailure(int rv) {
QuicClientSession::QuicClientSession(
QuicConnection* connection,
- DatagramClientSocket* socket,
+ scoped_ptr<DatagramClientSocket> socket,
QuicStreamFactory* stream_factory,
QuicCryptoClientStreamFactory* crypto_client_stream_factory,
const string& server_hostname,
@@ -91,7 +91,7 @@ QuicClientSession::QuicClientSession(
: QuicSession(connection, config, false),
weak_factory_(this),
stream_factory_(stream_factory),
- socket_(socket),
+ socket_(socket.Pass()),
read_buffer_(new IOBufferWithSize(kMaxPacketSize)),
read_pending_(false),
num_total_streams_(0),
diff --git a/net/quic/quic_client_session.h b/net/quic/quic_client_session.h
index d124fdb..555837f 100644
--- a/net/quic/quic_client_session.h
+++ b/net/quic/quic_client_session.h
@@ -13,6 +13,7 @@
#include <string>
#include "base/containers/hash_tables.h"
+#include "base/memory/scoped_ptr.h"
#include "net/base/completion_callback.h"
#include "net/quic/quic_connection_logger.h"
#include "net/quic/quic_crypto_client_stream.h"
@@ -74,7 +75,7 @@ class NET_EXPORT_PRIVATE QuicClientSession : public QuicSession {
// not |stream_factory|, which must outlive this session.
// TODO(rch): decouple the factory from the session via a Delegate interface.
QuicClientSession(QuicConnection* connection,
- DatagramClientSocket* socket,
+ scoped_ptr<DatagramClientSocket> socket,
QuicStreamFactory* stream_factory,
QuicCryptoClientStreamFactory* crypto_client_stream_factory,
const std::string& server_hostname,
diff --git a/net/quic/quic_client_session_test.cc b/net/quic/quic_client_session_test.cc
index 6113f45..09e7d21 100644
--- a/net/quic/quic_client_session_test.cc
+++ b/net/quic/quic_client_session_test.cc
@@ -15,6 +15,7 @@
#include "net/quic/test_tools/crypto_test_utils.h"
#include "net/quic/test_tools/quic_client_session_peer.h"
#include "net/quic/test_tools/quic_test_utils.h"
+#include "net/udp/datagram_client_socket.h"
using testing::_;
@@ -29,8 +30,9 @@ class QuicClientSessionTest : public ::testing::Test {
QuicClientSessionTest()
: guid_(1),
connection_(new PacketSavingConnection(guid_, IPEndPoint(), false)),
- session_(connection_, NULL, NULL, NULL, kServerHostname,
- DefaultQuicConfig(), &crypto_config_, &net_log_) {
+ session_(connection_, scoped_ptr<DatagramClientSocket>(), NULL,
+ NULL, kServerHostname, DefaultQuicConfig(), &crypto_config_,
+ &net_log_) {
session_.config()->SetDefaults();
crypto_config_.SetDefaults();
}
diff --git a/net/quic/quic_http_stream_test.cc b/net/quic/quic_http_stream_test.cc
index b378416..1e4ac91 100644
--- a/net/quic/quic_http_stream_test.cc
+++ b/net/quic/quic_http_stream_test.cc
@@ -179,10 +179,12 @@ class QuicHttpStreamTest : public ::testing::TestWithParam<bool> {
connection_->SetSendAlgorithm(send_algorithm_);
connection_->SetReceiveAlgorithm(receive_algorithm_);
crypto_config_.SetDefaults();
- session_.reset(new QuicClientSession(connection_, socket, NULL,
- &crypto_client_stream_factory_,
- "www.google.com", DefaultQuicConfig(),
- &crypto_config_, NULL));
+ session_.reset(
+ new QuicClientSession(connection_,
+ scoped_ptr<DatagramClientSocket>(socket), NULL,
+ &crypto_client_stream_factory_,
+ "www.google.com", DefaultQuicConfig(),
+ &crypto_config_, NULL));
session_->GetCryptoStream()->CryptoConnect();
EXPECT_TRUE(session_->IsCryptoHandshakeConfirmed());
stream_.reset(use_closing_stream_ ?
diff --git a/net/quic/quic_stream_factory.cc b/net/quic/quic_stream_factory.cc
index fba7f0b..86bd8a1 100644
--- a/net/quic/quic_stream_factory.cc
+++ b/net/quic/quic_stream_factory.cc
@@ -408,10 +408,10 @@ QuicClientSession* QuicStreamFactory::CreateSession(
const BoundNetLog& net_log) {
QuicGuid guid = random_generator_->RandUint64();
IPEndPoint addr = *address_list.begin();
- DatagramClientSocket* socket =
+ scoped_ptr<DatagramClientSocket> socket(
client_socket_factory_->CreateDatagramClientSocket(
DatagramSocket::DEFAULT_BIND, base::Bind(&base::RandInt),
- net_log.net_log(), net_log.source());
+ net_log.net_log(), net_log.source()));
socket->Connect(addr);
// We should adaptively set this buffer size, but for now, we'll use a size
@@ -437,7 +437,7 @@ QuicClientSession* QuicStreamFactory::CreateSession(
base::MessageLoop::current()->message_loop_proxy().get(),
clock_.get(),
random_generator_,
- socket);
+ socket.get());
QuicConnection* connection = new QuicConnection(guid, addr, helper, false,
QuicVersionMax());
@@ -447,7 +447,7 @@ QuicClientSession* QuicStreamFactory::CreateSession(
DCHECK(crypto_config);
QuicClientSession* session =
- new QuicClientSession(connection, socket, this,
+ new QuicClientSession(connection, socket.Pass(), this,
quic_crypto_client_stream_factory_,
host_port_proxy_pair.first.host(), config_,
crypto_config, net_log.net_log());
diff --git a/net/socket/buffered_write_stream_socket.cc b/net/socket/buffered_write_stream_socket.cc
index 36b9df7..cf13c5e 100644
--- a/net/socket/buffered_write_stream_socket.cc
+++ b/net/socket/buffered_write_stream_socket.cc
@@ -23,8 +23,8 @@ void AppendBuffer(GrowableIOBuffer* dst, IOBuffer* src, int src_len) {
} // anonymous namespace
BufferedWriteStreamSocket::BufferedWriteStreamSocket(
- StreamSocket* socket_to_wrap)
- : wrapped_socket_(socket_to_wrap),
+ scoped_ptr<StreamSocket> socket_to_wrap)
+ : wrapped_socket_(socket_to_wrap.Pass()),
io_buffer_(new GrowableIOBuffer()),
backup_buffer_(new GrowableIOBuffer()),
weak_factory_(this),
diff --git a/net/socket/buffered_write_stream_socket.h b/net/socket/buffered_write_stream_socket.h
index fcb33a8..aad5736 100644
--- a/net/socket/buffered_write_stream_socket.h
+++ b/net/socket/buffered_write_stream_socket.h
@@ -5,6 +5,8 @@
#ifndef NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_
#define NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_
+#include "base/basictypes.h"
+#include "base/memory/scoped_ptr.h"
#include "base/memory/weak_ptr.h"
#include "net/base/net_log.h"
#include "net/socket/stream_socket.h"
@@ -33,7 +35,7 @@ class IPEndPoint;
// There are no bounds on the local buffer size. Use carefully.
class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket {
public:
- BufferedWriteStreamSocket(StreamSocket* socket_to_wrap);
+ explicit BufferedWriteStreamSocket(scoped_ptr<StreamSocket> socket_to_wrap);
virtual ~BufferedWriteStreamSocket();
// Socket interface
@@ -71,6 +73,8 @@ class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket {
bool callback_pending_;
bool wrapped_write_in_progress_;
int error_;
+
+ DISALLOW_COPY_AND_ASSIGN(BufferedWriteStreamSocket);
};
} // namespace net
diff --git a/net/socket/buffered_write_stream_socket_unittest.cc b/net/socket/buffered_write_stream_socket_unittest.cc
index e579a7f..485295f 100644
--- a/net/socket/buffered_write_stream_socket_unittest.cc
+++ b/net/socket/buffered_write_stream_socket_unittest.cc
@@ -30,10 +30,11 @@ class BufferedWriteStreamSocketTest : public testing::Test {
if (writes_count) {
data_->StopAfter(writes_count);
}
- DeterministicMockTCPClientSocket* wrapped_socket =
- new DeterministicMockTCPClientSocket(net_log_.net_log(), data_.get());
+ scoped_ptr<DeterministicMockTCPClientSocket> wrapped_socket(
+ new DeterministicMockTCPClientSocket(net_log_.net_log(), data_.get()));
data_->set_delegate(wrapped_socket->AsWeakPtr());
- socket_.reset(new BufferedWriteStreamSocket(wrapped_socket));
+ socket_.reset(new BufferedWriteStreamSocket(
+ wrapped_socket.PassAs<StreamSocket>()));
socket_->Connect(callback_.callback());
}
diff --git a/net/socket/client_socket_factory.cc b/net/socket/client_socket_factory.cc
index 6d93034..a86688e 100644
--- a/net/socket/client_socket_factory.cc
+++ b/net/socket/client_socket_factory.cc
@@ -67,23 +67,25 @@ class DefaultClientSocketFactory : public ClientSocketFactory,
ClearSSLSessionCache();
}
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE {
- return new UDPClientSocket(bind_type, rand_int_cb, net_log, source);
+ return scoped_ptr<DatagramClientSocket>(
+ new UDPClientSocket(bind_type, rand_int_cb, net_log, source));
}
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE {
- return new TCPClientSocket(addresses, net_log, source);
+ return scoped_ptr<StreamSocket>(
+ new TCPClientSocket(addresses, net_log, source));
}
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) OVERRIDE {
@@ -102,17 +104,19 @@ class DefaultClientSocketFactory : public ClientSocketFactory,
nss_task_runner = base::ThreadTaskRunnerHandle::Get();
#if defined(USE_OPENSSL)
- return new SSLClientSocketOpenSSL(transport_socket, host_and_port,
- ssl_config, context);
+ return scoped_ptr<SSLClientSocket>(
+ new SSLClientSocketOpenSSL(transport_socket.Pass(), host_and_port,
+ ssl_config, context));
#elif defined(USE_NSS) || defined(OS_MACOSX) || defined(OS_WIN)
- return new SSLClientSocketNSS(nss_task_runner.get(),
- transport_socket,
- host_and_port,
- ssl_config,
- context);
+ return scoped_ptr<SSLClientSocket>(
+ new SSLClientSocketNSS(nss_task_runner.get(),
+ transport_socket.Pass(),
+ host_and_port,
+ ssl_config,
+ context));
#else
NOTIMPLEMENTED();
- return NULL;
+ return scoped_ptr<SSLClientSocket>();
#endif
}
diff --git a/net/socket/client_socket_factory.h b/net/socket/client_socket_factory.h
index a78fc48..6cb5949 100644
--- a/net/socket/client_socket_factory.h
+++ b/net/socket/client_socket_factory.h
@@ -8,6 +8,7 @@
#include <string>
#include "base/basictypes.h"
+#include "base/memory/scoped_ptr.h"
#include "net/base/net_export.h"
#include "net/base/net_log.h"
#include "net/base/rand_callback.h"
@@ -32,13 +33,13 @@ class NET_EXPORT ClientSocketFactory {
// |source| is the NetLog::Source for the entity trying to create the socket,
// if it has one.
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) = 0;
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* net_log,
const NetLog::Source& source) = 0;
@@ -46,8 +47,8 @@ class NET_EXPORT ClientSocketFactory {
// It is allowed to pass in a |transport_socket| that is not obtained from a
// socket pool. The caller could create a ClientSocketHandle directly and call
// set_socket() on it to set a valid StreamSocket instance.
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) = 0;
diff --git a/net/socket/client_socket_handle.cc b/net/socket/client_socket_handle.cc
index 3894fa7..acb896b 100644
--- a/net/socket/client_socket_handle.cc
+++ b/net/socket/client_socket_handle.cc
@@ -43,7 +43,7 @@ void ClientSocketHandle::ResetInternal(bool cancel) {
if (pool_)
// If we've still got a socket, release it back to the ClientSocketPool so
// it can be deleted or reused.
- pool_->ReleaseSocket(group_name_, release_socket(), pool_id_);
+ pool_->ReleaseSocket(group_name_, PassSocket(), pool_id_);
} else if (cancel) {
// If we did not get initialized yet, we've got a socket request pending.
// Cancel it.
@@ -121,6 +121,10 @@ bool ClientSocketHandle::GetLoadTimingInfo(
return true;
}
+void ClientSocketHandle::SetSocket(scoped_ptr<StreamSocket> s) {
+ socket_ = s.Pass();
+}
+
void ClientSocketHandle::OnIOComplete(int result) {
CompletionCallback callback = user_callback_;
user_callback_.Reset();
@@ -128,6 +132,10 @@ void ClientSocketHandle::OnIOComplete(int result) {
callback.Run(result);
}
+scoped_ptr<StreamSocket> ClientSocketHandle::PassSocket() {
+ return socket_.Pass();
+}
+
void ClientSocketHandle::HandleInitCompletion(int result) {
CHECK_NE(ERR_IO_PENDING, result);
ClientSocketPoolHistograms* histograms = pool_->histograms();
diff --git a/net/socket/client_socket_handle.h b/net/socket/client_socket_handle.h
index 7d5588a..9651f08 100644
--- a/net/socket/client_socket_handle.h
+++ b/net/socket/client_socket_handle.h
@@ -116,8 +116,8 @@ class NET_EXPORT ClientSocketHandle {
LoadTimingInfo* load_timing_info) const;
// Used by ClientSocketPool to initialize the ClientSocketHandle.
+ void SetSocket(scoped_ptr<StreamSocket> s);
void set_is_reused(bool is_reused) { is_reused_ = is_reused; }
- void set_socket(StreamSocket* s) { socket_.reset(s); }
void set_idle_time(base::TimeDelta idle_time) { idle_time_ = idle_time; }
void set_pool_id(int id) { pool_id_ = id; }
void set_is_ssl_error(bool is_ssl_error) { is_ssl_error_ = is_ssl_error; }
@@ -144,10 +144,10 @@ class NET_EXPORT ClientSocketHandle {
}
// These may only be used if is_initialized() is true.
+ scoped_ptr<StreamSocket> PassSocket();
+ StreamSocket* socket() { return socket_.get(); }
const std::string& group_name() const { return group_name_; }
int id() const { return pool_id_; }
- StreamSocket* socket() { return socket_.get(); }
- StreamSocket* release_socket() { return socket_.release(); }
bool is_reused() const { return is_reused_; }
base::TimeDelta idle_time() const { return idle_time_; }
SocketReuseType reuse_type() const {
diff --git a/net/socket/client_socket_pool.h b/net/socket/client_socket_pool.h
index 7cb9a7e..af18454 100644
--- a/net/socket/client_socket_pool.h
+++ b/net/socket/client_socket_pool.h
@@ -10,6 +10,7 @@
#include "base/basictypes.h"
#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
#include "base/template_util.h"
#include "base/time/time.h"
#include "net/base/completion_callback.h"
@@ -111,7 +112,7 @@ class NET_EXPORT ClientSocketPool {
// change when it flushes, so it can use this |id| to discard sockets with
// mismatched ids.
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) = 0;
// This flushes all state from the ClientSocketPool. This means that all
diff --git a/net/socket/client_socket_pool_base.cc b/net/socket/client_socket_pool_base.cc
index 4a5b118..b1ddd40 100644
--- a/net/socket/client_socket_pool_base.cc
+++ b/net/socket/client_socket_pool_base.cc
@@ -82,6 +82,10 @@ ConnectJob::~ConnectJob() {
net_log().EndEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB);
}
+scoped_ptr<StreamSocket> ConnectJob::PassSocket() {
+ return socket_.Pass();
+}
+
int ConnectJob::Connect() {
if (timeout_duration_ != base::TimeDelta())
timer_.Start(FROM_HERE, timeout_duration_, this, &ConnectJob::OnTimeout);
@@ -100,16 +104,16 @@ int ConnectJob::Connect() {
return rv;
}
-void ConnectJob::set_socket(StreamSocket* socket) {
+void ConnectJob::SetSocket(scoped_ptr<StreamSocket> socket) {
if (socket) {
net_log().AddEvent(NetLog::TYPE_CONNECT_JOB_SET_SOCKET,
socket->NetLog().source().ToEventParametersCallback());
}
- socket_.reset(socket);
+ socket_ = socket.Pass();
}
void ConnectJob::NotifyDelegateOfCompletion(int rv) {
- // The delegate will delete |this|.
+ // The delegate will own |this|.
Delegate* delegate = delegate_;
delegate_ = NULL;
@@ -135,7 +139,7 @@ void ConnectJob::LogConnectCompletion(int net_error) {
void ConnectJob::OnTimeout() {
// Make sure the socket is NULL before calling into |delegate|.
- set_socket(NULL);
+ SetSocket(scoped_ptr<StreamSocket>());
net_log_.AddEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_TIMED_OUT);
@@ -392,11 +396,11 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal(
if (rv == OK) {
LogBoundConnectJobToRequest(connect_job->net_log().source(), request);
if (!preconnecting) {
- HandOutSocket(connect_job->ReleaseSocket(), false /* not reused */,
+ HandOutSocket(connect_job->PassSocket(), false /* not reused */,
connect_job->connect_timing(), handle, base::TimeDelta(),
group, request->net_log());
} else {
- AddIdleSocket(connect_job->ReleaseSocket(), group);
+ AddIdleSocket(connect_job->PassSocket(), group);
}
} else if (rv == ERR_IO_PENDING) {
// If we don't have any sockets in this group, set a timer for potentially
@@ -409,17 +413,17 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal(
connecting_socket_count_++;
- group->AddJob(connect_job.release(), preconnecting);
+ group->AddJob(connect_job.Pass(), preconnecting);
} else {
LogBoundConnectJobToRequest(connect_job->net_log().source(), request);
- StreamSocket* error_socket = NULL;
+ scoped_ptr<StreamSocket> error_socket;
if (!preconnecting) {
DCHECK(handle);
connect_job->GetAdditionalErrorState(handle);
- error_socket = connect_job->ReleaseSocket();
+ error_socket = connect_job->PassSocket();
}
if (error_socket) {
- HandOutSocket(error_socket, false /* not reused */,
+ HandOutSocket(error_socket.Pass(), false /* not reused */,
connect_job->connect_timing(), handle, base::TimeDelta(),
group, request->net_log());
} else if (group->IsEmpty()) {
@@ -469,7 +473,7 @@ bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest(
IdleSocket idle_socket = *idle_socket_it;
idle_sockets->erase(idle_socket_it);
HandOutSocket(
- idle_socket.socket,
+ scoped_ptr<StreamSocket>(idle_socket.socket),
idle_socket.socket->WasEverUsed(),
LoadTimingInfo::ConnectTiming(),
request->handle(),
@@ -495,11 +499,11 @@ void ClientSocketPoolBaseHelper::CancelRequest(
if (callback_it != pending_callback_map_.end()) {
int result = callback_it->second.result;
pending_callback_map_.erase(callback_it);
- StreamSocket* socket = handle->release_socket();
+ scoped_ptr<StreamSocket> socket = handle->PassSocket();
if (socket) {
if (result != OK)
socket->Disconnect();
- ReleaseSocket(handle->group_name(), socket, handle->id());
+ ReleaseSocket(handle->group_name(), socket.Pass(), handle->id());
}
return;
}
@@ -756,7 +760,7 @@ void ClientSocketPoolBaseHelper::StartIdleSocketTimer() {
}
void ClientSocketPoolBaseHelper::ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) {
GroupMap::iterator i = group_map_.find(group_name);
CHECK(i != group_map_.end());
@@ -773,10 +777,10 @@ void ClientSocketPoolBaseHelper::ReleaseSocket(const std::string& group_name,
id == pool_generation_number_;
if (can_reuse) {
// Add it to the idle list.
- AddIdleSocket(socket, group);
+ AddIdleSocket(socket.Pass(), group);
OnAvailableSocketSlot(group_name, group);
} else {
- delete socket;
+ socket.reset();
}
CheckForStalledSocketGroups();
@@ -854,13 +858,16 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete(
CHECK(group_it != group_map_.end());
Group* group = group_it->second;
- scoped_ptr<StreamSocket> socket(job->ReleaseSocket());
+ scoped_ptr<StreamSocket> socket = job->PassSocket();
// Copies of these are needed because |job| may be deleted before they are
// accessed.
BoundNetLog job_log = job->net_log();
LoadTimingInfo::ConnectTiming connect_timing = job->connect_timing();
+ // RemoveConnectJob(job, _) must be called by all branches below;
+ // otherwise, |job| will be leaked.
+
if (result == OK) {
DCHECK(socket.get());
RemoveConnectJob(job, group);
@@ -869,12 +876,12 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete(
group->mutable_pending_requests()->begin(), group));
LogBoundConnectJobToRequest(job_log.source(), r.get());
HandOutSocket(
- socket.release(), false /* unused socket */, connect_timing,
+ socket.Pass(), false /* unused socket */, connect_timing,
r->handle(), base::TimeDelta(), group, r->net_log());
r->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL);
InvokeUserCallbackLater(r->handle(), r->callback(), result);
} else {
- AddIdleSocket(socket.release(), group);
+ AddIdleSocket(socket.Pass(), group);
OnAvailableSocketSlot(group_name, group);
CheckForStalledSocketGroups();
}
@@ -890,7 +897,7 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete(
RemoveConnectJob(job, group);
if (socket.get()) {
handed_out_socket = true;
- HandOutSocket(socket.release(), false /* unused socket */,
+ HandOutSocket(socket.Pass(), false /* unused socket */,
connect_timing, r->handle(), base::TimeDelta(), group,
r->net_log());
}
@@ -975,7 +982,7 @@ void ClientSocketPoolBaseHelper::ProcessPendingRequest(
}
void ClientSocketPoolBaseHelper::HandOutSocket(
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
bool reused,
const LoadTimingInfo::ConnectTiming& connect_timing,
ClientSocketHandle* handle,
@@ -983,7 +990,7 @@ void ClientSocketPoolBaseHelper::HandOutSocket(
Group* group,
const BoundNetLog& net_log) {
DCHECK(socket);
- handle->set_socket(socket);
+ handle->SetSocket(socket.Pass());
handle->set_is_reused(reused);
handle->set_idle_time(idle_time);
handle->set_pool_id(pool_generation_number_);
@@ -996,18 +1003,20 @@ void ClientSocketPoolBaseHelper::HandOutSocket(
"idle_ms", static_cast<int>(idle_time.InMilliseconds())));
}
- net_log.AddEvent(NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET,
- socket->NetLog().source().ToEventParametersCallback());
+ net_log.AddEvent(
+ NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET,
+ handle->socket()->NetLog().source().ToEventParametersCallback());
handed_out_socket_count_++;
group->IncrementActiveSocketCount();
}
void ClientSocketPoolBaseHelper::AddIdleSocket(
- StreamSocket* socket, Group* group) {
+ scoped_ptr<StreamSocket> socket,
+ Group* group) {
DCHECK(socket);
IdleSocket idle_socket;
- idle_socket.socket = socket;
+ idle_socket.socket = socket.release();
idle_socket.start_time = base::TimeTicks::Now();
group->mutable_idle_sockets()->push_back(idle_socket);
@@ -1178,13 +1187,13 @@ bool ClientSocketPoolBaseHelper::Group::TryToUseUnassignedConnectJob() {
return true;
}
-void ClientSocketPoolBaseHelper::Group::AddJob(ConnectJob* job,
+void ClientSocketPoolBaseHelper::Group::AddJob(scoped_ptr<ConnectJob> job,
bool is_preconnect) {
SanityCheck();
if (is_preconnect)
++unassigned_job_count_;
- jobs_.insert(job);
+ jobs_.insert(job.release());
}
void ClientSocketPoolBaseHelper::Group::RemoveJob(ConnectJob* job) {
@@ -1224,15 +1233,17 @@ void ClientSocketPoolBaseHelper::Group::OnBackupSocketTimerFired(
if (pending_requests_.empty())
return;
- ConnectJob* backup_job = pool->connect_job_factory_->NewConnectJob(
- group_name, **pending_requests_.begin(), pool);
+ scoped_ptr<ConnectJob> backup_job =
+ pool->connect_job_factory_->NewConnectJob(
+ group_name, **pending_requests_.begin(), pool);
backup_job->net_log().AddEvent(NetLog::TYPE_SOCKET_BACKUP_CREATED);
SIMPLE_STATS_COUNTER("socket.backup_created");
int rv = backup_job->Connect();
pool->connecting_socket_count_++;
- AddJob(backup_job, false);
+ ConnectJob* raw_backup_job = backup_job.get();
+ AddJob(backup_job.Pass(), false);
if (rv != ERR_IO_PENDING)
- pool->OnConnectJobComplete(rv, backup_job);
+ pool->OnConnectJobComplete(rv, raw_backup_job);
}
void ClientSocketPoolBaseHelper::Group::SanityCheck() {
diff --git a/net/socket/client_socket_pool_base.h b/net/socket/client_socket_pool_base.h
index ae1331b..eb642ed 100644
--- a/net/socket/client_socket_pool_base.h
+++ b/net/socket/client_socket_pool_base.h
@@ -61,8 +61,11 @@ class NET_EXPORT_PRIVATE ConnectJob {
Delegate() {}
virtual ~Delegate() {}
- // Alerts the delegate that the connection completed.
- virtual void OnConnectJobComplete(int result, ConnectJob* job) = 0;
+ // Alerts the delegate that the connection completed. |job| must
+ // be destroyed by the delegate. A scoped_ptr<> isn't used because
+ // the caller of this function doesn't own |job|.
+ virtual void OnConnectJobComplete(int result,
+ ConnectJob* job) = 0;
private:
DISALLOW_COPY_AND_ASSIGN(Delegate);
@@ -79,9 +82,10 @@ class NET_EXPORT_PRIVATE ConnectJob {
const std::string& group_name() const { return group_name_; }
const BoundNetLog& net_log() { return net_log_; }
- // Releases |socket_| to the client. On connection error, this should return
- // NULL.
- StreamSocket* ReleaseSocket() { return socket_.release(); }
+ // Releases ownership of the underlying socket to the caller.
+ // Returns the released socket, or NULL if there was a connection
+ // error.
+ scoped_ptr<StreamSocket> PassSocket();
// Begins connecting the socket. Returns OK on success, ERR_IO_PENDING if it
// cannot complete synchronously without blocking, or another net error code
@@ -105,7 +109,7 @@ class NET_EXPORT_PRIVATE ConnectJob {
const BoundNetLog& net_log() const { return net_log_; }
protected:
- void set_socket(StreamSocket* socket);
+ void SetSocket(scoped_ptr<StreamSocket> socket);
StreamSocket* socket() { return socket_.get(); }
void NotifyDelegateOfCompletion(int rv);
void ResetTimer(base::TimeDelta remainingTime);
@@ -188,7 +192,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
ConnectJobFactory() {}
virtual ~ConnectJobFactory() {}
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const Request& request,
ConnectJob::Delegate* delegate) const = 0;
@@ -229,7 +233,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
// See ClientSocketPool::ReleaseSocket for documentation on this function.
void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id);
// See ClientSocketPool::FlushWithError for documentation on this function.
@@ -386,7 +390,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
// Otherwise, returns false.
bool TryToUseUnassignedConnectJob();
- void AddJob(ConnectJob* job, bool is_preconnect);
+ void AddJob(scoped_ptr<ConnectJob> job, bool is_preconnect);
// Remove |job| from this group, which must already own |job|.
void RemoveJob(ConnectJob* job);
void RemoveAllJobs();
@@ -476,7 +480,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
CleanupIdleSockets(false);
}
- // Removes |job| from |connect_job_set_|. Also updates |group| if non-NULL.
+ // Removes |job| from |group|, which must already own |job|.
void RemoveConnectJob(ConnectJob* job, Group* group);
// Tries to see if we can handle any more requests for |group|.
@@ -486,7 +490,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
void ProcessPendingRequest(const std::string& group_name, Group* group);
// Assigns |socket| to |handle| and updates |group|'s counters appropriately.
- void HandOutSocket(StreamSocket* socket,
+ void HandOutSocket(scoped_ptr<StreamSocket> socket,
bool reused,
const LoadTimingInfo::ConnectTiming& connect_timing,
ClientSocketHandle* handle,
@@ -495,7 +499,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
const BoundNetLog& net_log);
// Adds |socket| to the list of idle sockets for |group|.
- void AddIdleSocket(StreamSocket* socket, Group* group);
+ void AddIdleSocket(scoped_ptr<StreamSocket> socket, Group* group);
// Iterates through |group_map_|, canceling all ConnectJobs and deleting
// groups if they are no longer needed.
@@ -625,7 +629,7 @@ class ClientSocketPoolBase {
ConnectJobFactory() {}
virtual ~ConnectJobFactory() {}
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const Request& request,
ConnectJob::Delegate* delegate) const = 0;
@@ -703,9 +707,10 @@ class ClientSocketPoolBase {
return helper_.CancelRequest(group_name, handle);
}
- void ReleaseSocket(const std::string& group_name, StreamSocket* socket,
+ void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
int id) {
- return helper_.ReleaseSocket(group_name, socket, id);
+ return helper_.ReleaseSocket(group_name, socket.Pass(), id);
}
void FlushWithError(int error) { helper_.FlushWithError(error); }
@@ -786,13 +791,13 @@ class ClientSocketPoolBase {
: connect_job_factory_(connect_job_factory) {}
virtual ~ConnectJobFactoryAdaptor() {}
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const internal::ClientSocketPoolBaseHelper::Request& request,
- ConnectJob::Delegate* delegate) const {
- const Request* casted_request = static_cast<const Request*>(&request);
+ ConnectJob::Delegate* delegate) const OVERRIDE {
+ const Request& casted_request = static_cast<const Request&>(request);
return connect_job_factory_->NewConnectJob(
- group_name, *casted_request, delegate);
+ group_name, casted_request, delegate);
}
virtual base::TimeDelta ConnectionTimeout() const {
diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc
index 5eeda97..6688e01 100644
--- a/net/socket/client_socket_pool_base_unittest.cc
+++ b/net/socket/client_socket_pool_base_unittest.cc
@@ -30,7 +30,9 @@
#include "net/socket/client_socket_handle.h"
#include "net/socket/client_socket_pool_histograms.h"
#include "net/socket/socket_test_util.h"
+#include "net/socket/ssl_client_socket.h"
#include "net/socket/stream_socket.h"
+#include "net/udp/datagram_client_socket.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -189,30 +191,30 @@ class MockClientSocketFactory : public ClientSocketFactory {
public:
MockClientSocketFactory() : allocation_count_(0) {}
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE {
NOTREACHED();
- return NULL;
+ return scoped_ptr<DatagramClientSocket>();
}
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* /* net_log */,
const NetLog::Source& /*source*/) OVERRIDE {
allocation_count_++;
- return NULL;
+ return scoped_ptr<StreamSocket>();
}
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) OVERRIDE {
NOTIMPLEMENTED();
- return NULL;
+ return scoped_ptr<SSLClientSocket>();
}
virtual void ClearSSLSessionCache() OVERRIDE {
@@ -294,7 +296,8 @@ class TestConnectJob : public ConnectJob {
AddressList ignored;
client_socket_factory_->CreateTransportClientSocket(
ignored, NULL, net::NetLog::Source());
- set_socket(new MockClientSocket(net_log().net_log()));
+ SetSocket(
+ scoped_ptr<StreamSocket>(new MockClientSocket(net_log().net_log())));
switch (job_type_) {
case kMockJob:
return DoConnect(true /* successful */, false /* sync */,
@@ -373,7 +376,7 @@ class TestConnectJob : public ConnectJob {
return ERR_IO_PENDING;
default:
NOTREACHED();
- set_socket(NULL);
+ SetSocket(scoped_ptr<StreamSocket>());
return ERR_FAILED;
}
}
@@ -386,7 +389,7 @@ class TestConnectJob : public ConnectJob {
result = ERR_PROXY_AUTH_REQUESTED;
} else {
result = ERR_CONNECTION_FAILED;
- set_socket(NULL);
+ SetSocket(scoped_ptr<StreamSocket>());
}
if (was_async)
@@ -430,7 +433,7 @@ class TestConnectJobFactory
// ConnectJobFactory implementation.
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const TestClientSocketPoolBase::Request& request,
ConnectJob::Delegate* delegate) const OVERRIDE {
@@ -440,13 +443,13 @@ class TestConnectJobFactory
job_type = job_types_->front();
job_types_->pop_front();
}
- return new TestConnectJob(job_type,
- group_name,
- request,
- timeout_duration_,
- delegate,
- client_socket_factory_,
- net_log_);
+ return scoped_ptr<ConnectJob>(new TestConnectJob(job_type,
+ group_name,
+ request,
+ timeout_duration_,
+ delegate,
+ client_socket_factory_,
+ net_log_));
}
virtual base::TimeDelta ConnectionTimeout() const OVERRIDE {
@@ -509,9 +512,9 @@ class TestClientSocketPool : public ClientSocketPool {
virtual void ReleaseSocket(
const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) OVERRIDE {
- base_.ReleaseSocket(group_name, socket, id);
+ base_.ReleaseSocket(group_name, socket.Pass(), id);
}
virtual void FlushWithError(int error) OVERRIDE {
@@ -630,10 +633,10 @@ class TestConnectJobDelegate : public ConnectJob::Delegate {
virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE {
result_ = result;
- scoped_ptr<StreamSocket> socket(job->ReleaseSocket());
+ scoped_ptr<ConnectJob> owned_job(job);
+ scoped_ptr<StreamSocket> socket = owned_job->PassSocket();
// socket.get() should be NULL iff result != OK
- EXPECT_EQ(socket.get() == NULL, result != OK);
- delete job;
+ EXPECT_EQ(socket == NULL, result != OK);
have_result_ = true;
if (waiting_for_result_)
base::MessageLoop::current()->Quit();
diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc
index 8b2bdfc..159f62e 100644
--- a/net/socket/socket_test_util.cc
+++ b/net/socket/socket_test_util.cc
@@ -657,37 +657,39 @@ void MockClientSocketFactory::ResetNextMockIndexes() {
mock_ssl_data_.ResetNextIndex();
}
-DatagramClientSocket* MockClientSocketFactory::CreateDatagramClientSocket(
+scoped_ptr<DatagramClientSocket>
+MockClientSocketFactory::CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
net::NetLog* net_log,
const net::NetLog::Source& source) {
SocketDataProvider* data_provider = mock_data_.GetNext();
- MockUDPClientSocket* socket = new MockUDPClientSocket(data_provider, net_log);
- data_provider->set_socket(socket);
- return socket;
+ scoped_ptr<MockUDPClientSocket> socket(
+ new MockUDPClientSocket(data_provider, net_log));
+ data_provider->set_socket(socket.get());
+ return socket.PassAs<DatagramClientSocket>();
}
-StreamSocket* MockClientSocketFactory::CreateTransportClientSocket(
+scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket(
const AddressList& addresses,
net::NetLog* net_log,
const net::NetLog::Source& source) {
SocketDataProvider* data_provider = mock_data_.GetNext();
- MockTCPClientSocket* socket =
- new MockTCPClientSocket(addresses, net_log, data_provider);
- data_provider->set_socket(socket);
- return socket;
+ scoped_ptr<MockTCPClientSocket> socket(
+ new MockTCPClientSocket(addresses, net_log, data_provider));
+ data_provider->set_socket(socket.get());
+ return socket.PassAs<StreamSocket>();
}
-SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) {
- MockSSLClientSocket* socket =
- new MockSSLClientSocket(transport_socket, host_and_port, ssl_config,
- mock_ssl_data_.GetNext());
- return socket;
+ return scoped_ptr<SSLClientSocket>(
+ new MockSSLClientSocket(transport_socket.Pass(),
+ host_and_port, ssl_config,
+ mock_ssl_data_.GetNext()));
}
void MockClientSocketFactory::ClearSSLSessionCache() {
@@ -1278,7 +1280,7 @@ void DeterministicMockTCPClientSocket::OnConnectComplete(
// static
void MockSSLClientSocket::ConnectCallback(
- MockSSLClientSocket *ssl_client_socket,
+ MockSSLClientSocket* ssl_client_socket,
const CompletionCallback& callback,
int rv) {
if (rv == OK)
@@ -1287,7 +1289,7 @@ void MockSSLClientSocket::ConnectCallback(
}
MockSSLClientSocket::MockSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_port_pair,
const SSLConfig& ssl_config,
SSLSocketDataProvider* data)
@@ -1295,7 +1297,7 @@ MockSSLClientSocket::MockSSLClientSocket(
// Have to use the right BoundNetLog for LoadTimingInfo regression
// tests.
transport_socket->socket()->NetLog()),
- transport_(transport_socket),
+ transport_(transport_socket.Pass()),
data_(data),
is_npn_state_set_(false),
new_npn_value_(false),
@@ -1664,10 +1666,10 @@ void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
}
MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
ClientSocketHandle* handle,
const CompletionCallback& callback)
- : socket_(socket),
+ : socket_(socket.Pass()),
handle_(handle),
user_callback_(callback) {
}
@@ -1698,7 +1700,7 @@ void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
if (!socket_.get())
return;
if (rv == OK) {
- handle_->set_socket(socket_.release());
+ handle_->SetSocket(socket_.Pass());
// Needed for socket pool tests that layer other sockets on top of mock
// sockets.
@@ -1740,9 +1742,10 @@ int MockTransportClientSocketPool::RequestSocket(
const std::string& group_name, const void* socket_params,
RequestPriority priority, ClientSocketHandle* handle,
const CompletionCallback& callback, const BoundNetLog& net_log) {
- StreamSocket* socket = client_socket_factory_->CreateTransportClientSocket(
- AddressList(), net_log.net_log(), net::NetLog::Source());
- MockConnectJob* job = new MockConnectJob(socket, handle, callback);
+ scoped_ptr<StreamSocket> socket =
+ client_socket_factory_->CreateTransportClientSocket(
+ AddressList(), net_log.net_log(), net::NetLog::Source());
+ MockConnectJob* job = new MockConnectJob(socket.Pass(), handle, callback);
job_list_.push_back(job);
handle->set_pool_id(1);
return job->Connect();
@@ -1759,11 +1762,12 @@ void MockTransportClientSocketPool::CancelRequest(const std::string& group_name,
}
}
-void MockTransportClientSocketPool::ReleaseSocket(const std::string& group_name,
- StreamSocket* socket, int id) {
+void MockTransportClientSocketPool::ReleaseSocket(
+ const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) {
EXPECT_EQ(1, id);
release_count_++;
- delete socket;
}
DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {}
@@ -1791,42 +1795,45 @@ MockSSLClientSocket* DeterministicMockClientSocketFactory::
return ssl_client_sockets_[index];
}
-DatagramClientSocket*
+scoped_ptr<DatagramClientSocket>
DeterministicMockClientSocketFactory::CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
net::NetLog* net_log,
const NetLog::Source& source) {
DeterministicSocketData* data_provider = mock_data().GetNext();
- DeterministicMockUDPClientSocket* socket =
- new DeterministicMockUDPClientSocket(net_log, data_provider);
+ scoped_ptr<DeterministicMockUDPClientSocket> socket(
+ new DeterministicMockUDPClientSocket(net_log, data_provider));
data_provider->set_delegate(socket->AsWeakPtr());
- udp_client_sockets().push_back(socket);
- return socket;
+ udp_client_sockets().push_back(socket.get());
+ return socket.PassAs<DatagramClientSocket>();
}
-StreamSocket* DeterministicMockClientSocketFactory::CreateTransportClientSocket(
+scoped_ptr<StreamSocket>
+DeterministicMockClientSocketFactory::CreateTransportClientSocket(
const AddressList& addresses,
net::NetLog* net_log,
const net::NetLog::Source& source) {
DeterministicSocketData* data_provider = mock_data().GetNext();
- DeterministicMockTCPClientSocket* socket =
- new DeterministicMockTCPClientSocket(net_log, data_provider);
+ scoped_ptr<DeterministicMockTCPClientSocket> socket(
+ new DeterministicMockTCPClientSocket(net_log, data_provider));
data_provider->set_delegate(socket->AsWeakPtr());
- tcp_client_sockets().push_back(socket);
- return socket;
+ tcp_client_sockets().push_back(socket.get());
+ return socket.PassAs<StreamSocket>();
}
-SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+scoped_ptr<SSLClientSocket>
+DeterministicMockClientSocketFactory::CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) {
- MockSSLClientSocket* socket =
- new MockSSLClientSocket(transport_socket, host_and_port, ssl_config,
- mock_ssl_data_.GetNext());
- ssl_client_sockets_.push_back(socket);
- return socket;
+ scoped_ptr<MockSSLClientSocket> socket(
+ new MockSSLClientSocket(transport_socket.Pass(),
+ host_and_port, ssl_config,
+ mock_ssl_data_.GetNext()));
+ ssl_client_sockets_.push_back(socket.get());
+ return socket.PassAs<SSLClientSocket>();
}
void DeterministicMockClientSocketFactory::ClearSSLSessionCache() {
@@ -1859,8 +1866,9 @@ void MockSOCKSClientSocketPool::CancelRequest(
}
void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
- StreamSocket* socket, int id) {
- return transport_pool_->ReleaseSocket(group_name, socket, id);
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ return transport_pool_->ReleaseSocket(group_name, socket.Pass(), id);
}
const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 };
diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h
index 6afe170..a888249 100644
--- a/net/socket/socket_test_util.h
+++ b/net/socket/socket_test_util.h
@@ -592,17 +592,17 @@ class MockClientSocketFactory : public ClientSocketFactory {
}
// ClientSocketFactory
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE;
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE;
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) OVERRIDE;
@@ -857,7 +857,7 @@ class DeterministicMockTCPClientSocket
class MockSSLClientSocket : public MockClientSocket, public AsyncSocket {
public:
MockSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
SSLSocketDataProvider* socket);
@@ -1049,7 +1049,7 @@ class MockTransportClientSocketPool : public TransportClientSocketPool {
public:
class MockConnectJob {
public:
- MockConnectJob(StreamSocket* socket, ClientSocketHandle* handle,
+ MockConnectJob(scoped_ptr<StreamSocket> socket, ClientSocketHandle* handle,
const CompletionCallback& callback);
~MockConnectJob();
@@ -1088,7 +1088,8 @@ class MockTransportClientSocketPool : public TransportClientSocketPool {
virtual void CancelRequest(const std::string& group_name,
ClientSocketHandle* handle) OVERRIDE;
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket, int id) OVERRIDE;
+ scoped_ptr<StreamSocket> socket,
+ int id) OVERRIDE;
private:
ClientSocketFactory* client_socket_factory_;
@@ -1123,17 +1124,17 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory {
}
// ClientSocketFactory
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE;
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE;
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) OVERRIDE;
@@ -1170,7 +1171,8 @@ class MockSOCKSClientSocketPool : public SOCKSClientSocketPool {
virtual void CancelRequest(const std::string& group_name,
ClientSocketHandle* handle) OVERRIDE;
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket, int id) OVERRIDE;
+ scoped_ptr<StreamSocket> socket,
+ int id) OVERRIDE;
private:
TransportClientSocketPool* const transport_pool_;
diff --git a/net/socket/socks5_client_socket.cc b/net/socket/socks5_client_socket.cc
index 8e329d1..537b584 100644
--- a/net/socket/socks5_client_socket.cc
+++ b/net/socket/socks5_client_socket.cc
@@ -28,18 +28,18 @@ COMPILE_ASSERT(sizeof(struct in_addr) == 4, incorrect_system_size_of_IPv4);
COMPILE_ASSERT(sizeof(struct in6_addr) == 16, incorrect_system_size_of_IPv6);
SOCKS5ClientSocket::SOCKS5ClientSocket(
- ClientSocketHandle* transport_socket,
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostResolver::RequestInfo& req_info)
: io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete,
base::Unretained(this))),
- transport_(transport_socket),
+ transport_(transport_socket.Pass()),
next_state_(STATE_NONE),
completed_handshake_(false),
bytes_sent_(0),
bytes_received_(0),
read_header_size(kReadHeaderSize),
host_request_info_(req_info),
- net_log_(transport_socket->socket()->NetLog()) {
+ net_log_(transport_->socket()->NetLog()) {
}
SOCKS5ClientSocket::~SOCKS5ClientSocket() {
diff --git a/net/socket/socks5_client_socket.h b/net/socket/socks5_client_socket.h
index 28e829e..4521624 100644
--- a/net/socket/socks5_client_socket.h
+++ b/net/socket/socks5_client_socket.h
@@ -28,16 +28,13 @@ class BoundNetLog;
// Currently no SOCKSv5 authentication is supported.
class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket {
public:
- // Takes ownership of the |transport_socket|, which should already be
- // connected by the time Connect() is called.
- //
// |req_info| contains the hostname and port to which the socket above will
// communicate to via the SOCKS layer.
//
// Although SOCKS 5 supports 3 different modes of addressing, we will
// always pass it a hostname. This means the DNS resolving is done
// proxy side.
- SOCKS5ClientSocket(ClientSocketHandle* transport_socket,
+ SOCKS5ClientSocket(scoped_ptr<ClientSocketHandle> transport_socket,
const HostResolver::RequestInfo& req_info);
// On destruction Disconnect() is called.
diff --git a/net/socket/socks5_client_socket_unittest.cc b/net/socket/socks5_client_socket_unittest.cc
index 630884f..4c9240f 100644
--- a/net/socket/socks5_client_socket_unittest.cc
+++ b/net/socket/socks5_client_socket_unittest.cc
@@ -32,13 +32,13 @@ class SOCKS5ClientSocketTest : public PlatformTest {
public:
SOCKS5ClientSocketTest();
// Create a SOCKSClientSocket on top of a MockSocket.
- SOCKS5ClientSocket* BuildMockSocket(MockRead reads[],
- size_t reads_count,
- MockWrite writes[],
- size_t writes_count,
- const std::string& hostname,
- int port,
- NetLog* net_log);
+ scoped_ptr<SOCKS5ClientSocket> BuildMockSocket(MockRead reads[],
+ size_t reads_count,
+ MockWrite writes[],
+ size_t writes_count,
+ const std::string& hostname,
+ int port,
+ NetLog* net_log);
virtual void SetUp();
@@ -77,7 +77,7 @@ void SOCKS5ClientSocketTest::SetUp() {
ASSERT_EQ(OK, rv);
}
-SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket(
+scoped_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket(
MockRead reads[],
size_t reads_count,
MockWrite writes[],
@@ -99,10 +99,10 @@ SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket(
scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
// |connection| takes ownership of |tcp_sock_|, but keep a
// non-owning pointer to it.
- connection->set_socket(tcp_sock_);
- return new SOCKS5ClientSocket(
- connection.release(),
- HostResolver::RequestInfo(HostPortPair(hostname, port)));
+ connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
+ return scoped_ptr<SOCKS5ClientSocket>(new SOCKS5ClientSocket(
+ connection.Pass(),
+ HostResolver::RequestInfo(HostPortPair(hostname, port))));
}
// Tests a complete SOCKS5 handshake and the disconnection.
@@ -130,9 +130,9 @@ TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) {
MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength),
MockRead(ASYNC, payload_read.data(), payload_read.size()) };
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- "localhost", 80, &net_log_));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ "localhost", 80, &net_log_);
// At this state the TCP connection is completed but not the SOCKS handshake.
EXPECT_TRUE(tcp_sock_->IsConnected());
@@ -202,9 +202,9 @@ TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) {
MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength)
};
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- hostname, 80, NULL));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, NULL);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(OK, rv);
@@ -224,9 +224,9 @@ TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) {
// Create a SOCKS socket, with mock transport socket.
MockWrite data_writes[] = {MockWrite()};
MockRead data_reads[] = {MockRead()};
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- large_host_name, 80, NULL));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ large_host_name, 80, NULL);
// Try to connect -- should fail (without having read/written anything to
// the transport socket first) because the hostname is too long.
@@ -260,9 +260,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
MockRead data_reads[] = {
MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- hostname, 80, &net_log_));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, &net_log_);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -291,9 +291,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
MockRead(ASYNC, partial1, arraysize(partial1)),
MockRead(ASYNC, partial2, arraysize(partial2)),
MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- hostname, 80, &net_log_));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, &net_log_);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -321,9 +321,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
MockRead data_reads[] = {
MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- hostname, 80, &net_log_));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, &net_log_);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
CapturingNetLog::CapturedEntryList net_log_entries;
@@ -352,9 +352,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
kSOCKS5OkResponseLength - kSplitPoint)
};
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- hostname, 80, &net_log_));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, &net_log_);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
CapturingNetLog::CapturedEntryList net_log_entries;
diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc
index cefcd3f..1941fdb 100644
--- a/net/socket/socks_client_socket.cc
+++ b/net/socket/socks_client_socket.cc
@@ -55,17 +55,18 @@ struct SOCKS4ServerResponse {
COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
socks4_server_response_struct_wrong_size);
-SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket,
- const HostResolver::RequestInfo& req_info,
- HostResolver* host_resolver)
- : transport_(transport_socket),
+SOCKSClientSocket::SOCKSClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostResolver::RequestInfo& req_info,
+ HostResolver* host_resolver)
+ : transport_(transport_socket.Pass()),
next_state_(STATE_NONE),
completed_handshake_(false),
bytes_sent_(0),
bytes_received_(0),
host_resolver_(host_resolver),
host_request_info_(req_info),
- net_log_(transport_socket->socket()->NetLog()) {
+ net_log_(transport_->socket()->NetLog()) {
}
SOCKSClientSocket::~SOCKSClientSocket() {
diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h
index a8b3b71..285c75e 100644
--- a/net/socket/socks_client_socket.h
+++ b/net/socket/socks_client_socket.h
@@ -27,12 +27,9 @@ class BoundNetLog;
// The SOCKS client socket implementation
class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket {
public:
- // Takes ownership of the |transport_socket|, which should already be
- // connected by the time Connect() is called.
- //
// |req_info| contains the hostname and port to which the socket above will
// communicate to via the socks layer. For testing the referrer is optional.
- SOCKSClientSocket(ClientSocketHandle* transport_socket,
+ SOCKSClientSocket(scoped_ptr<ClientSocketHandle> transport_socket,
const HostResolver::RequestInfo& req_info,
HostResolver* host_resolver);
diff --git a/net/socket/socks_client_socket_pool.cc b/net/socket/socks_client_socket_pool.cc
index d740e5b..e49eaba 100644
--- a/net/socket/socks_client_socket_pool.cc
+++ b/net/socket/socks_client_socket_pool.cc
@@ -140,10 +140,10 @@ int SOCKSConnectJob::DoSOCKSConnect() {
// Add a SOCKS connection on top of the tcp socket.
if (socks_params_->is_socks_v5()) {
- socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.release(),
+ socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.Pass(),
socks_params_->destination()));
} else {
- socket_.reset(new SOCKSClientSocket(transport_socket_handle_.release(),
+ socket_.reset(new SOCKSClientSocket(transport_socket_handle_.Pass(),
socks_params_->destination(),
resolver_));
}
@@ -157,7 +157,7 @@ int SOCKSConnectJob::DoSOCKSConnectComplete(int result) {
return result;
}
- set_socket(socket_.release());
+ SetSocket(socket_.Pass());
return result;
}
@@ -166,17 +166,18 @@ int SOCKSConnectJob::ConnectInternal() {
return DoLoop(OK);
}
-ConnectJob* SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob(
+scoped_ptr<ConnectJob>
+SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const {
- return new SOCKSConnectJob(group_name,
- request.params(),
- ConnectionTimeout(),
- transport_pool_,
- host_resolver_,
- delegate,
- net_log_);
+ return scoped_ptr<ConnectJob>(new SOCKSConnectJob(group_name,
+ request.params(),
+ ConnectionTimeout(),
+ transport_pool_,
+ host_resolver_,
+ delegate,
+ net_log_));
}
base::TimeDelta
@@ -238,8 +239,9 @@ void SOCKSClientSocketPool::CancelRequest(const std::string& group_name,
}
void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
- StreamSocket* socket, int id) {
- base_.ReleaseSocket(group_name, socket, id);
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ base_.ReleaseSocket(group_name, socket.Pass(), id);
}
void SOCKSClientSocketPool::FlushWithError(int error) {
diff --git a/net/socket/socks_client_socket_pool.h b/net/socket/socks_client_socket_pool.h
index 86609a1..fe69a78 100644
--- a/net/socket/socks_client_socket_pool.h
+++ b/net/socket/socks_client_socket_pool.h
@@ -134,7 +134,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool
ClientSocketHandle* handle) OVERRIDE;
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) OVERRIDE;
virtual void FlushWithError(int error) OVERRIDE;
@@ -183,7 +183,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool
virtual ~SOCKSConnectJobFactory() {}
// ClientSocketPoolBase::ConnectJobFactory methods.
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const OVERRIDE;
diff --git a/net/socket/socks_client_socket_unittest.cc b/net/socket/socks_client_socket_unittest.cc
index 640c4f1..8c30838 100644
--- a/net/socket/socks_client_socket_unittest.cc
+++ b/net/socket/socks_client_socket_unittest.cc
@@ -4,6 +4,7 @@
#include "net/socket/socks_client_socket.h"
+#include "base/memory/scoped_ptr.h"
#include "net/base/address_list.h"
#include "net/base/net_log.h"
#include "net/base/net_log_unittest.h"
@@ -27,11 +28,12 @@ class SOCKSClientSocketTest : public PlatformTest {
public:
SOCKSClientSocketTest();
// Create a SOCKSClientSocket on top of a MockSocket.
- SOCKSClientSocket* BuildMockSocket(MockRead reads[], size_t reads_count,
- MockWrite writes[], size_t writes_count,
- HostResolver* host_resolver,
- const std::string& hostname, int port,
- NetLog* net_log);
+ scoped_ptr<SOCKSClientSocket> BuildMockSocket(
+ MockRead reads[], size_t reads_count,
+ MockWrite writes[], size_t writes_count,
+ HostResolver* host_resolver,
+ const std::string& hostname, int port,
+ NetLog* net_log);
virtual void SetUp();
protected:
@@ -54,7 +56,7 @@ void SOCKSClientSocketTest::SetUp() {
PlatformTest::SetUp();
}
-SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket(
+scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket(
MockRead reads[],
size_t reads_count,
MockWrite writes[],
@@ -78,11 +80,11 @@ SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket(
scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
// |connection| takes ownership of |tcp_sock_|, but keep a
// non-owning pointer to it.
- connection->set_socket(tcp_sock_);
- return new SOCKSClientSocket(
- connection.release(),
+ connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
+ return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket(
+ connection.Pass(),
HostResolver::RequestInfo(HostPortPair(hostname, port)),
- host_resolver);
+ host_resolver));
}
// Implementation of HostResolver that never completes its resolve request.
@@ -141,11 +143,11 @@ TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
MockRead(ASYNC, payload_read.data(), payload_read.size()) };
CapturingNetLog log;
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- host_resolver_.get(),
- "localhost", 80,
- &log));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
// At this state the TCP connection is completed but not the SOCKS handshake.
EXPECT_TRUE(tcp_sock_->IsConnected());
@@ -217,11 +219,11 @@ TEST_F(SOCKSClientSocketTest, HandshakeFailures) {
arraysize(tests[i].fail_reply)) };
CapturingNetLog log;
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- host_resolver_.get(),
- "localhost", 80,
- &log));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -254,11 +256,11 @@ TEST_F(SOCKSClientSocketTest, PartialServerReads) {
MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) };
CapturingNetLog log;
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- host_resolver_.get(),
- "localhost", 80,
- &log));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -292,11 +294,11 @@ TEST_F(SOCKSClientSocketTest, PartialClientWrites) {
MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
CapturingNetLog log;
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- host_resolver_.get(),
- "localhost", 80,
- &log));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -324,11 +326,11 @@ TEST_F(SOCKSClientSocketTest, FailedSocketRead) {
MockRead(SYNCHRONOUS, 0) };
CapturingNetLog log;
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- host_resolver_.get(),
- "localhost", 80,
- &log));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -354,11 +356,11 @@ TEST_F(SOCKSClientSocketTest, FailedDNS) {
CapturingNetLog log;
- user_sock_.reset(BuildMockSocket(NULL, 0,
- NULL, 0,
- host_resolver_.get(),
- hostname, 80,
- &log));
+ user_sock_ = BuildMockSocket(NULL, 0,
+ NULL, 0,
+ host_resolver_.get(),
+ hostname, 80,
+ &log);
int rv = user_sock_->Connect(callback_.callback());
EXPECT_EQ(ERR_IO_PENDING, rv);
@@ -385,11 +387,11 @@ TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) {
MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) };
MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) };
- user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
- data_writes, arraysize(data_writes),
- hanging_resolver.get(),
- "foo", 80,
- NULL));
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hanging_resolver.get(),
+ "foo", 80,
+ NULL);
// Start connecting (will get stuck waiting for the host to resolve).
int rv = user_sock_->Connect(callback_.callback());
diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc
index 6274e64..acc1b0de 100644
--- a/net/socket/ssl_client_socket_nss.cc
+++ b/net/socket/ssl_client_socket_nss.cc
@@ -2751,12 +2751,12 @@ void SSLClientSocketNSS::Core::SetChannelIDProvided() {
SSLClientSocketNSS::SSLClientSocketNSS(
base::SequencedTaskRunner* nss_task_runner,
- ClientSocketHandle* transport_socket,
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context)
: nss_task_runner_(nss_task_runner),
- transport_(transport_socket),
+ transport_(transport_socket.Pass()),
host_and_port_(host_and_port),
ssl_config_(ssl_config),
cert_verifier_(context.cert_verifier),
@@ -2765,7 +2765,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(
completed_handshake_(false),
next_handshake_state_(STATE_NONE),
nss_fd_(NULL),
- net_log_(transport_socket->socket()->NetLog()),
+ net_log_(transport_->socket()->NetLog()),
transport_security_state_(context.transport_security_state),
valid_thread_id_(base::kInvalidThreadId) {
EnterFunction("");
diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h
index fed8ef7..b41d28d 100644
--- a/net/socket/ssl_client_socket_nss.h
+++ b/net/socket/ssl_client_socket_nss.h
@@ -59,7 +59,7 @@ class SSLClientSocketNSS : public SSLClientSocket {
// behaviour is desired, for performance or compatibility, the current task
// runner should be supplied instead.
SSLClientSocketNSS(base::SequencedTaskRunner* nss_task_runner,
- ClientSocketHandle* transport_socket,
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context);
diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc
index 1431bc6..4591cec 100644
--- a/net/socket/ssl_client_socket_openssl.cc
+++ b/net/socket/ssl_client_socket_openssl.cc
@@ -425,7 +425,7 @@ void SSLClientSocket::ClearSessionCache() {
}
SSLClientSocketOpenSSL::SSLClientSocketOpenSSL(
- ClientSocketHandle* transport_socket,
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context)
@@ -439,14 +439,14 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL(
cert_verifier_(context.cert_verifier),
ssl_(NULL),
transport_bio_(NULL),
- transport_(transport_socket),
+ transport_(transport_socket.Pass()),
host_and_port_(host_and_port),
ssl_config_(ssl_config),
ssl_session_cache_shard_(context.ssl_session_cache_shard),
trying_cached_session_(false),
next_handshake_state_(STATE_NONE),
npn_status_(kNextProtoUnsupported),
- net_log_(transport_socket->socket()->NetLog()) {
+ net_log_(transport_->socket()->NetLog()) {
}
SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() {
diff --git a/net/socket/ssl_client_socket_openssl.h b/net/socket/ssl_client_socket_openssl.h
index 520f432..f66d95c 100644
--- a/net/socket/ssl_client_socket_openssl.h
+++ b/net/socket/ssl_client_socket_openssl.h
@@ -41,7 +41,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket {
// The given hostname will be compared with the name(s) in the server's
// certificate during the SSL handshake. ssl_config specifies the SSL
// settings.
- SSLClientSocketOpenSSL(ClientSocketHandle* transport_socket,
+ SSLClientSocketOpenSSL(scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context);
diff --git a/net/socket/ssl_client_socket_openssl_unittest.cc b/net/socket/ssl_client_socket_openssl_unittest.cc
index 7da8625..04f8999 100644
--- a/net/socket/ssl_client_socket_openssl_unittest.cc
+++ b/net/socket/ssl_client_socket_openssl_unittest.cc
@@ -107,13 +107,13 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest {
}
protected:
- SSLClientSocket* CreateSSLClientSocket(
- StreamSocket* transport_socket,
+ scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<StreamSocket> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config) {
scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
- connection->set_socket(transport_socket);
- return socket_factory_->CreateSSLClientSocket(connection.release(),
+ connection->SetSocket(transport_socket.Pass());
+ return socket_factory_->CreateSSLClientSocket(connection.Pass(),
host_and_port,
ssl_config,
context_);
@@ -166,9 +166,9 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest {
// itself was a success.
bool CreateAndConnectSSLClientSocket(SSLConfig& ssl_config,
int* result) {
- sock_.reset(CreateSSLClientSocket(transport_.release(),
- test_server_->host_port_pair(),
- ssl_config));
+ sock_ = CreateSSLClientSocket(transport_.Pass(),
+ test_server_->host_port_pair(),
+ ssl_config);
if (sock_->IsConnected()) {
LOG(ERROR) << "SSL Socket prematurely connected";
diff --git a/net/socket/ssl_client_socket_pool.cc b/net/socket/ssl_client_socket_pool.cc
index fed268d..d07c76f 100644
--- a/net/socket/ssl_client_socket_pool.cc
+++ b/net/socket/ssl_client_socket_pool.cc
@@ -287,11 +287,11 @@ int SSLConnectJob::DoSSLConnect() {
connect_timing_.ssl_start = base::TimeTicks::Now();
- ssl_socket_.reset(client_socket_factory_->CreateSSLClientSocket(
- transport_socket_handle_.release(),
+ ssl_socket_ = client_socket_factory_->CreateSSLClientSocket(
+ transport_socket_handle_.Pass(),
params_->host_and_port(),
params_->ssl_config(),
- context_));
+ context_);
return ssl_socket_->Connect(callback_);
}
@@ -410,7 +410,7 @@ int SSLConnectJob::DoSSLConnectComplete(int result) {
}
if (result == OK || IsCertificateError(result)) {
- set_socket(ssl_socket_.release());
+ SetSocket(ssl_socket_.PassAs<StreamSocket>());
} else if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) {
error_response_info_.cert_request_info = new SSLCertRequestInfo;
ssl_socket_->GetSSLCertRequestInfo(
@@ -527,14 +527,16 @@ SSLClientSocketPool::~SSLClientSocketPool() {
ssl_config_service_->RemoveObserver(this);
}
-ConnectJob* SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob(
+scoped_ptr<ConnectJob>
+SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const {
- return new SSLConnectJob(group_name, request.params(), ConnectionTimeout(),
- transport_pool_, socks_pool_, http_proxy_pool_,
- client_socket_factory_, host_resolver_,
- context_, delegate, net_log_);
+ return scoped_ptr<ConnectJob>(
+ new SSLConnectJob(group_name, request.params(), ConnectionTimeout(),
+ transport_pool_, socks_pool_, http_proxy_pool_,
+ client_socket_factory_, host_resolver_,
+ context_, delegate, net_log_));
}
base::TimeDelta
@@ -572,8 +574,9 @@ void SSLClientSocketPool::CancelRequest(const std::string& group_name,
}
void SSLClientSocketPool::ReleaseSocket(const std::string& group_name,
- StreamSocket* socket, int id) {
- base_.ReleaseSocket(group_name, socket, id);
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ base_.ReleaseSocket(group_name, socket.Pass(), id);
}
void SSLClientSocketPool::FlushWithError(int error) {
diff --git a/net/socket/ssl_client_socket_pool.h b/net/socket/ssl_client_socket_pool.h
index bc54bc9..431a1b7c 100644
--- a/net/socket/ssl_client_socket_pool.h
+++ b/net/socket/ssl_client_socket_pool.h
@@ -204,7 +204,7 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
ClientSocketHandle* handle) OVERRIDE;
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) OVERRIDE;
virtual void FlushWithError(int error) OVERRIDE;
@@ -261,7 +261,7 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
virtual ~SSLConnectJobFactory() {}
// ClientSocketPoolBase::ConnectJobFactory methods.
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const OVERRIDE;
diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc
index af7f9b9..f791928 100644
--- a/net/socket/ssl_client_socket_unittest.cc
+++ b/net/socket/ssl_client_socket_unittest.cc
@@ -508,13 +508,14 @@ class SSLClientSocketTest : public PlatformTest {
}
protected:
- SSLClientSocket* CreateSSLClientSocket(StreamSocket* transport_socket,
- const HostPortPair& host_and_port,
- const SSLConfig& ssl_config) {
+ scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<StreamSocket> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config) {
scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
- connection->set_socket(transport_socket);
+ connection->SetSocket(transport_socket.Pass());
return socket_factory_->CreateSSLClientSocket(
- connection.release(), host_and_port, ssl_config, context_);
+ connection.Pass(), host_and_port, ssl_config, context_);
}
ClientSocketFactory* socket_factory_;
@@ -552,14 +553,15 @@ TEST_F(SSLClientSocketTest, Connect) {
TestCompletionCallback callback;
CapturingNetLog log;
- StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
@@ -593,14 +595,15 @@ TEST_F(SSLClientSocketTest, ConnectExpired) {
TestCompletionCallback callback;
CapturingNetLog log;
- StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
@@ -636,14 +639,15 @@ TEST_F(SSLClientSocketTest, ConnectMismatched) {
TestCompletionCallback callback;
CapturingNetLog log;
- StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
@@ -679,14 +683,15 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) {
TestCompletionCallback callback;
CapturingNetLog log;
- StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
@@ -737,7 +742,8 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) {
TestCompletionCallback callback;
CapturingNetLog log;
- StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
@@ -748,7 +754,7 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) {
ssl_config.client_cert = NULL;
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), ssl_config));
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
EXPECT_FALSE(sock->IsConnected());
@@ -793,14 +799,15 @@ TEST_F(SSLClientSocketTest, Read) {
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
- StreamSocket* transport = new TCPClientSocket(addr, NULL, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
@@ -851,8 +858,8 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) {
TestCompletionCallback callback;
scoped_ptr<StreamSocket> real_transport(
new TCPClientSocket(addr, NULL, NetLog::Source()));
- SynchronousErrorStreamSocket* transport =
- new SynchronousErrorStreamSocket(real_transport.Pass());
+ scoped_ptr<SynchronousErrorStreamSocket> transport(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
int rv = callback.GetResult(transport->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
@@ -860,8 +867,11 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) {
SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
- scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), ssl_config));
+ SynchronousErrorStreamSocket* raw_transport = transport.get();
+ scoped_ptr<SSLClientSocket> sock(
+ CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
+ test_server.host_port_pair(),
+ ssl_config));
rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
@@ -878,7 +888,7 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) {
EXPECT_EQ(kRequestTextSize, rv);
// Simulate an unclean/forcible shutdown.
- transport->SetNextReadError(ERR_CONNECTION_RESET);
+ raw_transport->SetNextReadError(ERR_CONNECTION_RESET);
scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
@@ -912,12 +922,14 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
TestCompletionCallback callback;
scoped_ptr<StreamSocket> real_transport(
new TCPClientSocket(addr, NULL, NetLog::Source()));
- // Note: |error_socket|'s ownership is handed to |transport|, but the pointer
+ // Note: |error_socket|'s ownership is handed to |transport|, but a pointer
// is retained in order to configure additional errors.
- SynchronousErrorStreamSocket* error_socket =
- new SynchronousErrorStreamSocket(real_transport.Pass());
- FakeBlockingStreamSocket* transport =
- new FakeBlockingStreamSocket(scoped_ptr<StreamSocket>(error_socket));
+ scoped_ptr<SynchronousErrorStreamSocket> error_socket(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
+ SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
+ scoped_ptr<FakeBlockingStreamSocket> transport(
+ new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
+ FakeBlockingStreamSocket* raw_transport = transport.get();
int rv = callback.GetResult(transport->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
@@ -925,8 +937,10 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
- scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), ssl_config));
+ scoped_ptr<SSLClientSocket> sock(
+ CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
+ test_server.host_port_pair(),
+ ssl_config));
rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
@@ -940,8 +954,8 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
// Simulate an unclean/forcible shutdown on the underlying socket.
// However, simulate this error asynchronously.
- error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
- transport->SetNextWriteShouldBlock();
+ raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
+ raw_transport->SetNextWriteShouldBlock();
// This write should complete synchronously, because the TLS ciphertext
// can be created and placed into the outgoing buffers independent of the
@@ -957,7 +971,7 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
// Now unblock the outgoing request, having it fail with the connection
// being reset.
- transport->UnblockWrite();
+ raw_transport->UnblockWrite();
// Note: This will cause an inifite loop if this bug has regressed. Simply
// checking that rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING
@@ -986,14 +1000,15 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) {
TestCompletionCallback callback; // Used for everything except Write.
- StreamSocket* transport = new TCPClientSocket(addr, NULL, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
@@ -1049,12 +1064,14 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
TestCompletionCallback callback;
scoped_ptr<StreamSocket> real_transport(
new TCPClientSocket(addr, NULL, NetLog::Source()));
- // Note: |error_socket|'s ownership is handed to |transport|, but the pointer
+ // Note: |error_socket|'s ownership is handed to |transport|, but a pointer
// is retained in order to configure additional errors.
- SynchronousErrorStreamSocket* error_socket =
- new SynchronousErrorStreamSocket(real_transport.Pass());
- FakeBlockingStreamSocket* transport =
- new FakeBlockingStreamSocket(scoped_ptr<StreamSocket>(error_socket));
+ scoped_ptr<SynchronousErrorStreamSocket> error_socket(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
+ SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
+ scoped_ptr<FakeBlockingStreamSocket> transport(
+ new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
+ FakeBlockingStreamSocket* raw_transport = transport.get();
int rv = callback.GetResult(transport->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
@@ -1063,8 +1080,10 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
- SSLClientSocket* sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), ssl_config));
+ scoped_ptr<SSLClientSocket> sock =
+ CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
+ test_server.host_port_pair(),
+ ssl_config);
rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
@@ -1077,18 +1096,19 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
new StringIOBuffer(request_text), request_text.size()));
// Simulate errors being returned from the underlying Read() and Write() ...
- error_socket->SetNextReadError(ERR_CONNECTION_RESET);
- error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
+ raw_error_socket->SetNextReadError(ERR_CONNECTION_RESET);
+ raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
// ... but have those errors returned asynchronously. Because the Write() will
// return first, this will trigger the error.
- transport->SetNextReadShouldBlock();
- transport->SetNextWriteShouldBlock();
+ raw_transport->SetNextReadShouldBlock();
+ raw_transport->SetNextWriteShouldBlock();
// Enqueue a Read() before calling Write(), which should "hang" due to
// the ERR_IO_PENDING caused by SetReadShouldBlock() and thus return.
- DeleteSocketCallback read_callback(sock);
+ SSLClientSocket* raw_sock = sock.get();
+ DeleteSocketCallback read_callback(sock.release());
scoped_refptr<IOBuffer> read_buf(new IOBuffer(4096));
- rv = sock->Read(read_buf.get(), 4096, read_callback.callback());
+ rv = raw_sock->Read(read_buf.get(), 4096, read_callback.callback());
// Ensure things didn't complete synchronously, otherwise |sock| is invalid.
ASSERT_EQ(ERR_IO_PENDING, rv);
@@ -1111,9 +1131,9 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
// SSLClientSocketOpenSSL::Write() will not return until all of
// |request_buffer| has been written to the underlying BIO (although not
// necessarily the underlying transport).
- rv = callback.GetResult(sock->Write(request_buffer.get(),
- request_buffer->BytesRemaining(),
- callback.callback()));
+ rv = callback.GetResult(raw_sock->Write(request_buffer.get(),
+ request_buffer->BytesRemaining(),
+ callback.callback()));
ASSERT_LT(0, rv);
request_buffer->DidConsume(rv);
@@ -1126,16 +1146,16 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
// Attempt to write the remaining data. NSS will not be able to consume the
// application data because the internal buffers are full, while OpenSSL will
// return that its blocked because the underlying transport is blocked.
- rv = sock->Write(request_buffer.get(),
- request_buffer->BytesRemaining(),
- callback.callback());
+ rv = raw_sock->Write(request_buffer.get(),
+ request_buffer->BytesRemaining(),
+ callback.callback());
ASSERT_EQ(ERR_IO_PENDING, rv);
ASSERT_FALSE(callback.have_result());
// Now unblock Write(), which will invoke OnSendComplete and (eventually)
// call the Read() callback, deleting the socket and thus aborting calling
// the Write() callback.
- transport->UnblockWrite();
+ raw_transport->UnblockWrite();
rv = read_callback.WaitForResult();
@@ -1161,14 +1181,15 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) {
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
- StreamSocket* transport = new TCPClientSocket(addr, NULL, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
@@ -1215,13 +1236,16 @@ TEST_F(SSLClientSocketTest, Read_ManySmallRecords) {
scoped_ptr<StreamSocket> real_transport(
new TCPClientSocket(addr, NULL, NetLog::Source()));
- ReadBufferingStreamSocket* transport =
- new ReadBufferingStreamSocket(real_transport.Pass());
+ scoped_ptr<ReadBufferingStreamSocket> transport(
+ new ReadBufferingStreamSocket(real_transport.Pass()));
+ ReadBufferingStreamSocket* raw_transport = transport.get();
int rv = callback.GetResult(transport->Connect(callback.callback()));
ASSERT_EQ(OK, rv);
- scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(
+ CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
+ test_server.host_port_pair(),
+ kDefaultSSLConfig));
rv = callback.GetResult(sock->Connect(callback.callback()));
ASSERT_EQ(OK, rv);
@@ -1246,7 +1270,7 @@ TEST_F(SSLClientSocketTest, Read_ManySmallRecords) {
// 15K was chosen because 15K is smaller than the 17K (max) read issued by
// the SSLClientSocket implementation, and larger than the minimum amount
// of ciphertext necessary to contain the 8K of plaintext requested below.
- transport->SetBufferSize(15000);
+ raw_transport->SetBufferSize(15000);
scoped_refptr<IOBuffer> buffer(new IOBuffer(8192));
rv = callback.GetResult(sock->Read(buffer.get(), 8192, callback.callback()));
@@ -1263,14 +1287,15 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) {
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
- StreamSocket* transport = new TCPClientSocket(addr, NULL, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
@@ -1313,14 +1338,15 @@ TEST_F(SSLClientSocketTest, Read_FullLogging) {
TestCompletionCallback callback;
CapturingNetLog log;
log.SetLogLevel(NetLog::LOG_ALL);
- StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
@@ -1398,14 +1424,15 @@ TEST_F(SSLClientSocketTest, PrematureApplicationData) {
StaticSocketDataProvider data(data_reads, arraysize(data_reads), NULL, 0);
- StreamSocket* transport = new MockTCPClientSocket(addr, NULL, &data);
+ scoped_ptr<StreamSocket> transport(
+ new MockTCPClientSocket(addr, NULL, &data));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
@@ -1433,7 +1460,8 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) {
TestCompletionCallback callback;
CapturingNetLog log;
- StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
@@ -1444,7 +1472,7 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) {
ssl_config.disabled_cipher_suites.push_back(kCiphersToDisable[i]);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), ssl_config));
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
EXPECT_FALSE(sock->IsConnected());
@@ -1499,17 +1527,18 @@ TEST_F(SSLClientSocketTest, ClientSocketHandleNotFromPool) {
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
- StreamSocket* transport = new TCPClientSocket(addr, NULL, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
- ClientSocketHandle* socket_handle = new ClientSocketHandle();
- socket_handle->set_socket(transport);
+ scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle());
+ socket_handle->SetSocket(transport.Pass());
scoped_ptr<SSLClientSocket> sock(
- socket_factory_->CreateSSLClientSocket(socket_handle,
+ socket_factory_->CreateSSLClientSocket(socket_handle.Pass(),
test_server.host_port_pair(),
kDefaultSSLConfig,
context_));
@@ -1534,14 +1563,15 @@ TEST_F(SSLClientSocketTest, ExportKeyingMaterial) {
TestCompletionCallback callback;
- StreamSocket* transport = new TCPClientSocket(addr, NULL, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = sock->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
@@ -1629,14 +1659,15 @@ TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) {
TestCompletionCallback callback;
CapturingNetLog log;
- StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
@@ -1688,14 +1719,15 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest {
TestCompletionCallback callback;
CapturingNetLog log;
- StreamSocket* transport = new TCPClientSocket(addr, &log, NetLog::Source());
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
int rv = transport->Connect(callback.callback());
if (rv == ERR_IO_PENDING)
rv = callback.WaitForResult();
EXPECT_EQ(OK, rv);
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport, test_server.host_port_pair(), kDefaultSSLConfig));
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
EXPECT_FALSE(sock->IsConnected());
rv = sock->Connect(callback.callback());
diff --git a/net/socket/ssl_server_socket.h b/net/socket/ssl_server_socket.h
index 52d53cb..8b607bf 100644
--- a/net/socket/ssl_server_socket.h
+++ b/net/socket/ssl_server_socket.h
@@ -6,6 +6,7 @@
#define NET_SOCKET_SSL_SERVER_SOCKET_H_
#include "base/basictypes.h"
+#include "base/memory/scoped_ptr.h"
#include "net/base/completion_callback.h"
#include "net/base/net_export.h"
#include "net/socket/ssl_socket.h"
@@ -52,8 +53,8 @@ NET_EXPORT void EnableSSLServerSockets();
//
// The caller starts the SSL server handshake by calling Handshake on the
// returned socket.
-NET_EXPORT SSLServerSocket* CreateSSLServerSocket(
- StreamSocket* socket,
+NET_EXPORT scoped_ptr<SSLServerSocket> CreateSSLServerSocket(
+ scoped_ptr<StreamSocket> socket,
X509Certificate* certificate,
crypto::RSAPrivateKey* key,
const SSLConfig& ssl_config);
diff --git a/net/socket/ssl_server_socket_nss.cc b/net/socket/ssl_server_socket_nss.cc
index c2681d3..7e5d701 100644
--- a/net/socket/ssl_server_socket_nss.cc
+++ b/net/socket/ssl_server_socket_nss.cc
@@ -78,19 +78,20 @@ void EnableSSLServerSockets() {
g_nss_ssl_server_init_singleton.Get();
}
-SSLServerSocket* CreateSSLServerSocket(
- StreamSocket* socket,
+scoped_ptr<SSLServerSocket> CreateSSLServerSocket(
+ scoped_ptr<StreamSocket> socket,
X509Certificate* cert,
crypto::RSAPrivateKey* key,
const SSLConfig& ssl_config) {
DCHECK(g_nss_server_sockets_init) << "EnableSSLServerSockets() has not been"
<< "called yet!";
- return new SSLServerSocketNSS(socket, cert, key, ssl_config);
+ return scoped_ptr<SSLServerSocket>(
+ new SSLServerSocketNSS(socket.Pass(), cert, key, ssl_config));
}
SSLServerSocketNSS::SSLServerSocketNSS(
- StreamSocket* transport_socket,
+ scoped_ptr<StreamSocket> transport_socket,
scoped_refptr<X509Certificate> cert,
crypto::RSAPrivateKey* key,
const SSLConfig& ssl_config)
@@ -100,7 +101,7 @@ SSLServerSocketNSS::SSLServerSocketNSS(
user_write_buf_len_(0),
nss_fd_(NULL),
nss_bufs_(NULL),
- transport_socket_(transport_socket),
+ transport_socket_(transport_socket.Pass()),
ssl_config_(ssl_config),
cert_(cert),
next_handshake_state_(STATE_NONE),
diff --git a/net/socket/ssl_server_socket_nss.h b/net/socket/ssl_server_socket_nss.h
index 17a1fc3..8bbb0e3 100644
--- a/net/socket/ssl_server_socket_nss.h
+++ b/net/socket/ssl_server_socket_nss.h
@@ -24,7 +24,7 @@ class SSLServerSocketNSS : public SSLServerSocket {
public:
// See comments on CreateSSLServerSocket for details of how these
// parameters are used.
- SSLServerSocketNSS(StreamSocket* socket,
+ SSLServerSocketNSS(scoped_ptr<StreamSocket> socket,
scoped_refptr<X509Certificate> certificate,
crypto::RSAPrivateKey* key,
const SSLConfig& ssl_config);
diff --git a/net/socket/ssl_server_socket_openssl.cc b/net/socket/ssl_server_socket_openssl.cc
index e0cf8bc..c327f2c 100644
--- a/net/socket/ssl_server_socket_openssl.cc
+++ b/net/socket/ssl_server_socket_openssl.cc
@@ -16,13 +16,13 @@ void EnableSSLServerSockets() {
NOTIMPLEMENTED();
}
-SSLServerSocket* CreateSSLServerSocket(StreamSocket* socket,
- X509Certificate* certificate,
- crypto::RSAPrivateKey* key,
- const SSLConfig& ssl_config) {
+scoped_ptr<SSLServerSocket> CreateSSLServerSocket(
+ scoped_ptr<StreamSocket> socket,
+ X509Certificate* certificate,
+ crypto::RSAPrivateKey* key,
+ const SSLConfig& ssl_config) {
NOTIMPLEMENTED();
- delete socket;
- return NULL;
+ return scoped_ptr<SSLServerSocket>();
}
} // namespace net
diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc
index 63bf037..64c8549 100644
--- a/net/socket/ssl_server_socket_unittest.cc
+++ b/net/socket/ssl_server_socket_unittest.cc
@@ -304,8 +304,11 @@ class SSLServerSocketTest : public PlatformTest {
protected:
void Initialize() {
- FakeSocket* fake_client_socket = new FakeSocket(&channel_1_, &channel_2_);
- FakeSocket* fake_server_socket = new FakeSocket(&channel_2_, &channel_1_);
+ scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle);
+ client_connection->SetSocket(
+ scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_)));
+ scoped_ptr<StreamSocket> server_socket(
+ new FakeSocket(&channel_2_, &channel_1_));
base::FilePath certs_dir(GetTestCertsDirectory());
@@ -344,13 +347,12 @@ class SSLServerSocketTest : public PlatformTest {
net::SSLClientSocketContext context;
context.cert_verifier = cert_verifier_.get();
context.transport_security_state = transport_security_state_.get();
- scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
- connection->set_socket(fake_client_socket);
- client_socket_.reset(
+ client_socket_ =
socket_factory_->CreateSSLClientSocket(
- connection.release(), host_and_pair, ssl_config, context));
- server_socket_.reset(net::CreateSSLServerSocket(
- fake_server_socket, cert.get(), private_key.get(), net::SSLConfig()));
+ client_connection.Pass(), host_and_pair, ssl_config, context);
+ server_socket_ = net::CreateSSLServerSocket(
+ server_socket.Pass(),
+ cert.get(), private_key.get(), net::SSLConfig());
}
FakeDataChannel channel_1_;
diff --git a/net/socket/transport_client_socket_pool.cc b/net/socket/transport_client_socket_pool.cc
index 8255e98..6d0afac 100644
--- a/net/socket/transport_client_socket_pool.cc
+++ b/net/socket/transport_client_socket_pool.cc
@@ -190,8 +190,8 @@ int TransportConnectJob::DoResolveHostComplete(int result) {
int TransportConnectJob::DoTransportConnect() {
next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
- transport_socket_.reset(client_socket_factory_->CreateTransportClientSocket(
- addresses_, net_log().net_log(), net_log().source()));
+ transport_socket_ = client_socket_factory_->CreateTransportClientSocket(
+ addresses_, net_log().net_log(), net_log().source());
int rv = transport_socket_->Connect(
base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this)));
if (rv == ERR_IO_PENDING &&
@@ -246,7 +246,7 @@ int TransportConnectJob::DoTransportConnectComplete(int result) {
100);
}
}
- set_socket(transport_socket_.release());
+ SetSocket(transport_socket_.Pass());
fallback_timer_.Stop();
} else {
// Be a bit paranoid and kill off the fallback members to prevent reuse.
@@ -270,9 +270,9 @@ void TransportConnectJob::DoIPv6FallbackTransportConnect() {
fallback_addresses_.reset(new AddressList(addresses_));
MakeAddressListStartWithIPv4(fallback_addresses_.get());
- fallback_transport_socket_.reset(
+ fallback_transport_socket_ =
client_socket_factory_->CreateTransportClientSocket(
- *fallback_addresses_, net_log().net_log(), net_log().source()));
+ *fallback_addresses_, net_log().net_log(), net_log().source());
fallback_connect_start_time_ = base::TimeTicks::Now();
int rv = fallback_transport_socket_->Connect(
base::Bind(
@@ -317,7 +317,7 @@ void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) {
base::TimeDelta::FromMilliseconds(1),
base::TimeDelta::FromMinutes(10),
100);
- set_socket(fallback_transport_socket_.release());
+ SetSocket(fallback_transport_socket_.Pass());
next_state_ = STATE_NONE;
transport_socket_.reset();
} else {
@@ -333,18 +333,19 @@ int TransportConnectJob::ConnectInternal() {
return DoLoop(OK);
}
-ConnectJob*
+scoped_ptr<ConnectJob>
TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const {
- return new TransportConnectJob(group_name,
- request.params(),
- ConnectionTimeout(),
- client_socket_factory_,
- host_resolver_,
- delegate,
- net_log_);
+ return scoped_ptr<ConnectJob>(
+ new TransportConnectJob(group_name,
+ request.params(),
+ ConnectionTimeout(),
+ client_socket_factory_,
+ host_resolver_,
+ delegate,
+ net_log_));
}
base::TimeDelta
@@ -419,9 +420,9 @@ void TransportClientSocketPool::CancelRequest(
void TransportClientSocketPool::ReleaseSocket(
const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) {
- base_.ReleaseSocket(group_name, socket, id);
+ base_.ReleaseSocket(group_name, socket.Pass(), id);
}
void TransportClientSocketPool::FlushWithError(int error) {
diff --git a/net/socket/transport_client_socket_pool.h b/net/socket/transport_client_socket_pool.h
index bb53b3d..f07dc1f 100644
--- a/net/socket/transport_client_socket_pool.h
+++ b/net/socket/transport_client_socket_pool.h
@@ -156,7 +156,7 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool {
virtual void CancelRequest(const std::string& group_name,
ClientSocketHandle* handle) OVERRIDE;
virtual void ReleaseSocket(const std::string& group_name,
- StreamSocket* socket,
+ scoped_ptr<StreamSocket> socket,
int id) OVERRIDE;
virtual void FlushWithError(int error) OVERRIDE;
virtual bool IsStalled() const OVERRIDE;
@@ -193,7 +193,7 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool {
// ClientSocketPoolBase::ConnectJobFactory methods.
- virtual ConnectJob* NewConnectJob(
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const OVERRIDE;
diff --git a/net/socket/transport_client_socket_pool_unittest.cc b/net/socket/transport_client_socket_pool_unittest.cc
index dfa1151..c607a38 100644
--- a/net/socket/transport_client_socket_pool_unittest.cc
+++ b/net/socket/transport_client_socket_pool_unittest.cc
@@ -23,6 +23,7 @@
#include "net/socket/client_socket_handle.h"
#include "net/socket/client_socket_pool_histograms.h"
#include "net/socket/socket_test_util.h"
+#include "net/socket/ssl_client_socket.h"
#include "net/socket/stream_socket.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -340,16 +341,16 @@ class MockClientSocketFactory : public ClientSocketFactory {
delay_(base::TimeDelta::FromMilliseconds(
ClientSocketPool::kMaxConnectRetryIntervalMs)) {}
- virtual DatagramClientSocket* CreateDatagramClientSocket(
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) OVERRIDE {
NOTREACHED();
- return NULL;
+ return scoped_ptr<DatagramClientSocket>();
}
- virtual StreamSocket* CreateTransportClientSocket(
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* /* net_log */,
const NetLog::Source& /* source */) OVERRIDE {
@@ -363,34 +364,41 @@ class MockClientSocketFactory : public ClientSocketFactory {
switch (type) {
case MOCK_CLIENT_SOCKET:
- return new MockClientSocket(addresses, net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockClientSocket(addresses, net_log_));
case MOCK_FAILING_CLIENT_SOCKET:
- return new MockFailingClientSocket(addresses, net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockFailingClientSocket(addresses, net_log_));
case MOCK_PENDING_CLIENT_SOCKET:
- return new MockPendingClientSocket(
- addresses, true, false, base::TimeDelta(), net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockPendingClientSocket(
+ addresses, true, false, base::TimeDelta(), net_log_));
case MOCK_PENDING_FAILING_CLIENT_SOCKET:
- return new MockPendingClientSocket(
- addresses, false, false, base::TimeDelta(), net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockPendingClientSocket(
+ addresses, false, false, base::TimeDelta(), net_log_));
case MOCK_DELAYED_CLIENT_SOCKET:
- return new MockPendingClientSocket(
- addresses, true, false, delay_, net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockPendingClientSocket(
+ addresses, true, false, delay_, net_log_));
case MOCK_STALLED_CLIENT_SOCKET:
- return new MockPendingClientSocket(
- addresses, true, true, base::TimeDelta(), net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockPendingClientSocket(
+ addresses, true, true, base::TimeDelta(), net_log_));
default:
NOTREACHED();
- return new MockClientSocket(addresses, net_log_);
+ return scoped_ptr<StreamSocket>(
+ new MockClientSocket(addresses, net_log_));
}
}
- virtual SSLClientSocket* CreateSSLClientSocket(
- ClientSocketHandle* transport_socket,
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) OVERRIDE {
NOTIMPLEMENTED();
- return NULL;
+ return scoped_ptr<SSLClientSocket>();
}
virtual void ClearSSLSessionCache() OVERRIDE {
diff --git a/net/socket/transport_client_socket_unittest.cc b/net/socket/transport_client_socket_unittest.cc
index 2f75e74..5c5a303 100644
--- a/net/socket/transport_client_socket_unittest.cc
+++ b/net/socket/transport_client_socket_unittest.cc
@@ -130,10 +130,10 @@ void TransportClientSocketTest::SetUp() {
CHECK_EQ(ERR_IO_PENDING, rv);
rv = callback.WaitForResult();
CHECK_EQ(rv, OK);
- sock_.reset(
+ sock_ =
socket_factory_->CreateTransportClientSocket(addr,
&net_log_,
- NetLog::Source()));
+ NetLog::Source());
}
int TransportClientSocketTest::DrainClientSocket(
diff --git a/net/socket_stream/socket_stream.cc b/net/socket_stream/socket_stream.cc
index ab4c3fd..c549fcb 100644
--- a/net/socket_stream/socket_stream.cc
+++ b/net/socket_stream/socket_stream.cc
@@ -97,6 +97,7 @@ SocketStream::SocketStream(const GURL& url, Delegate* delegate)
proxy_mode_(kDirectConnection),
proxy_url_(url),
pac_request_(NULL),
+ connection_(new ClientSocketHandle),
privacy_mode_(kPrivacyModeDisabled),
// Unretained() is required; without it, Bind() creates a circular
// dependency and the SocketStream object will not be freed.
@@ -206,8 +207,10 @@ bool SocketStream::SendData(const char* data, int len) {
<< "The current base::MessageLoop must be TYPE_IO";
DCHECK_GT(len, 0);
- if (!socket_.get() || !socket_->IsConnected() || next_state_ == STATE_NONE)
+ if (!connection_->socket() ||
+ !connection_->socket()->IsConnected() || next_state_ == STATE_NONE) {
return false;
+ }
int total_buffered_bytes = len;
if (current_write_buf_.get()) {
@@ -265,7 +268,7 @@ void SocketStream::RestartWithAuth(const AuthCredentials& credentials) {
DCHECK_EQ(base::MessageLoop::TYPE_IO, base::MessageLoop::current()->type())
<< "The current base::MessageLoop must be TYPE_IO";
DCHECK(proxy_auth_controller_.get());
- if (!socket_.get()) {
+ if (!connection_->socket()) {
DVLOG(1) << "Socket is closed before restarting with auth.";
return;
}
@@ -370,7 +373,7 @@ void SocketStream::Finish(int result) {
}
int SocketStream::DidEstablishConnection() {
- if (!socket_.get() || !socket_->IsConnected()) {
+ if (!connection_->socket() || !connection_->socket()->IsConnected()) {
next_state_ = STATE_CLOSE;
return ERR_CONNECTION_FAILED;
}
@@ -731,11 +734,12 @@ int SocketStream::DoTcpConnect(int result) {
}
next_state_ = STATE_TCP_CONNECT_COMPLETE;
DCHECK(factory_);
- socket_.reset(factory_->CreateTransportClientSocket(addresses_,
- net_log_.net_log(),
- net_log_.source()));
+ connection_->SetSocket(
+ factory_->CreateTransportClientSocket(addresses_,
+ net_log_.net_log(),
+ net_log_.source()));
metrics_->OnStartConnection();
- return socket_->Connect(io_callback_);
+ return connection_->socket()->Connect(io_callback_);
}
int SocketStream::DoTcpConnectComplete(int result) {
@@ -820,7 +824,8 @@ int SocketStream::DoWriteTunnelHeaders() {
int buf_len = static_cast<int>(tunnel_request_headers_->headers_.size() -
tunnel_request_headers_bytes_sent_);
DCHECK_GT(buf_len, 0);
- return socket_->Write(tunnel_request_headers_.get(), buf_len, io_callback_);
+ return connection_->socket()->Write(
+ tunnel_request_headers_.get(), buf_len, io_callback_);
}
int SocketStream::DoWriteTunnelHeadersComplete(int result) {
@@ -863,7 +868,8 @@ int SocketStream::DoReadTunnelHeaders() {
tunnel_response_headers_->SetDataOffset(tunnel_response_headers_len_);
CHECK(tunnel_response_headers_->data());
- return socket_->Read(tunnel_response_headers_.get(), buf_len, io_callback_);
+ return connection_->socket()->Read(
+ tunnel_response_headers_.get(), buf_len, io_callback_);
}
int SocketStream::DoReadTunnelHeadersComplete(int result) {
@@ -957,16 +963,17 @@ int SocketStream::DoSOCKSConnect() {
HostResolver::RequestInfo req_info(HostPortPair::FromURL(url_));
DCHECK(!proxy_info_.is_empty());
- scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
- connection->set_socket(socket_.release());
+ scoped_ptr<StreamSocket> s;
if (proxy_info_.proxy_server().scheme() == ProxyServer::SCHEME_SOCKS5) {
- socket_.reset(new SOCKS5ClientSocket(connection.release(), req_info));
+ s.reset(new SOCKS5ClientSocket(connection_.Pass(), req_info));
} else {
- socket_.reset(new SOCKSClientSocket(
- connection.release(), req_info, context_->host_resolver()));
+ s.reset(new SOCKSClientSocket(
+ connection_.Pass(), req_info, context_->host_resolver()));
}
+ connection_.reset(new ClientSocketHandle);
+ connection_->SetSocket(s.Pass());
metrics_->OnCountConnectionType(SocketStreamMetrics::SOCKS_CONNECTION);
- return socket_->Connect(io_callback_);
+ return connection_->socket()->Connect(io_callback_);
}
int SocketStream::DoSOCKSConnectComplete(int result) {
@@ -989,16 +996,16 @@ int SocketStream::DoSecureProxyConnect() {
ssl_context.cert_verifier = context_->cert_verifier();
ssl_context.transport_security_state = context_->transport_security_state();
ssl_context.server_bound_cert_service = context_->server_bound_cert_service();
- scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
- connection->set_socket(socket_.release());
- socket_.reset(factory_->CreateSSLClientSocket(
- connection.release(),
+ scoped_ptr<StreamSocket> socket(factory_->CreateSSLClientSocket(
+ connection_.Pass(),
proxy_info_.proxy_server().host_port_pair(),
proxy_ssl_config_,
ssl_context));
+ connection_.reset(new ClientSocketHandle);
+ connection_->SetSocket(socket.Pass());
next_state_ = STATE_SECURE_PROXY_CONNECT_COMPLETE;
metrics_->OnCountConnectionType(SocketStreamMetrics::SECURE_PROXY_CONNECTION);
- return socket_->Connect(io_callback_);
+ return connection_->socket()->Connect(io_callback_);
}
int SocketStream::DoSecureProxyConnectComplete(int result) {
@@ -1030,7 +1037,7 @@ int SocketStream::DoSecureProxyHandleCertError(int result) {
int SocketStream::DoSecureProxyHandleCertErrorComplete(int result) {
DCHECK_EQ(STATE_NONE, next_state_);
if (result == OK) {
- if (!socket_->IsConnectedAndIdle())
+ if (!connection_->socket()->IsConnectedAndIdle())
return AllowCertErrorForReconnection(&proxy_ssl_config_);
next_state_ = STATE_GENERATE_PROXY_AUTH_TOKEN;
} else {
@@ -1045,15 +1052,16 @@ int SocketStream::DoSSLConnect() {
ssl_context.cert_verifier = context_->cert_verifier();
ssl_context.transport_security_state = context_->transport_security_state();
ssl_context.server_bound_cert_service = context_->server_bound_cert_service();
- scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
- connection->set_socket(socket_.release());
- socket_.reset(factory_->CreateSSLClientSocket(connection.release(),
- HostPortPair::FromURL(url_),
- server_ssl_config_,
- ssl_context));
+ scoped_ptr<StreamSocket> socket(
+ factory_->CreateSSLClientSocket(connection_.Pass(),
+ HostPortPair::FromURL(url_),
+ server_ssl_config_,
+ ssl_context));
+ connection_.reset(new ClientSocketHandle);
+ connection_->SetSocket(socket.Pass());
next_state_ = STATE_SSL_CONNECT_COMPLETE;
metrics_->OnCountConnectionType(SocketStreamMetrics::SSL_CONNECTION);
- return socket_->Connect(io_callback_);
+ return connection_->socket()->Connect(io_callback_);
}
int SocketStream::DoSSLConnectComplete(int result) {
@@ -1089,7 +1097,7 @@ int SocketStream::DoSSLHandleCertErrorComplete(int result) {
// we should take care of TLS NPN extension here.
if (result == OK) {
- if (!socket_->IsConnectedAndIdle())
+ if (!connection_->socket()->IsConnectedAndIdle())
return AllowCertErrorForReconnection(&server_ssl_config_);
result = DidEstablishConnection();
} else {
@@ -1103,7 +1111,7 @@ int SocketStream::DoReadWrite(int result) {
next_state_ = STATE_CLOSE;
return result;
}
- if (!socket_.get() || !socket_->IsConnected()) {
+ if (!connection_->socket() || !connection_->socket()->IsConnected()) {
next_state_ = STATE_CLOSE;
return ERR_CONNECTION_CLOSED;
}
@@ -1112,7 +1120,7 @@ int SocketStream::DoReadWrite(int result) {
// let's close the socket.
// We don't care about receiving data after the socket is closed.
if (closing_ && !current_write_buf_.get() && pending_write_bufs_.empty()) {
- socket_->Disconnect();
+ connection_->socket()->Disconnect();
next_state_ = STATE_CLOSE;
return OK;
}
@@ -1124,7 +1132,7 @@ int SocketStream::DoReadWrite(int result) {
if (!read_buf_.get()) {
// No read pending and server didn't close the socket.
read_buf_ = new IOBuffer(kReadBufferSize);
- result = socket_->Read(
+ result = connection_->socket()->Read(
read_buf_.get(),
kReadBufferSize,
base::Bind(&SocketStream::OnReadCompleted, base::Unretained(this)));
@@ -1163,7 +1171,7 @@ int SocketStream::DoReadWrite(int result) {
pending_write_bufs_.pop_front();
}
- result = socket_->Write(
+ result = connection_->socket()->Write(
current_write_buf_.get(),
current_write_buf_->BytesRemaining(),
base::Bind(&SocketStream::OnWriteCompleted, base::Unretained(this)));
@@ -1195,10 +1203,10 @@ int SocketStream::HandleCertificateRequest(int result, SSLConfig* ssl_config) {
return result;
}
- DCHECK(socket_.get());
+ DCHECK(connection_->socket());
scoped_refptr<SSLCertRequestInfo> cert_request_info = new SSLCertRequestInfo;
SSLClientSocket* ssl_socket =
- static_cast<SSLClientSocket*>(socket_.get());
+ static_cast<SSLClientSocket*>(connection_->socket());
ssl_socket->GetSSLCertRequestInfo(cert_request_info.get());
HttpTransactionFactory* factory = context_->http_transaction_factory();
@@ -1244,7 +1252,8 @@ int SocketStream::AllowCertErrorForReconnection(SSLConfig* ssl_config) {
// allowed bad certificates in |ssl_config|.
// See also net/http/http_network_transaction.cc HandleCertificateError() and
// RestartIgnoringLastError().
- SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(socket_.get());
+ SSLClientSocket* ssl_socket =
+ static_cast<SSLClientSocket*>(connection_->socket());
SSLInfo ssl_info;
ssl_socket->GetSSLInfo(&ssl_info);
if (ssl_info.cert.get() == NULL ||
@@ -1266,8 +1275,8 @@ int SocketStream::AllowCertErrorForReconnection(SSLConfig* ssl_config) {
bad_cert.cert_status = ssl_info.cert_status;
ssl_config->allowed_bad_certs.push_back(bad_cert);
// Restart connection ignoring the bad certificate.
- socket_->Disconnect();
- socket_.reset();
+ connection_->socket()->Disconnect();
+ connection_->SetSocket(scoped_ptr<StreamSocket>());
next_state_ = STATE_TCP_CONNECT;
return OK;
}
@@ -1293,7 +1302,8 @@ void SocketStream::DoRestartWithAuth() {
int SocketStream::HandleCertificateError(int result) {
DCHECK(IsCertificateError(result));
- SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(socket_.get());
+ SSLClientSocket* ssl_socket =
+ static_cast<SSLClientSocket*>(connection_->socket());
DCHECK(ssl_socket);
if (!context_)
diff --git a/net/socket_stream/socket_stream.h b/net/socket_stream/socket_stream.h
index 5004060..90aeb8c 100644
--- a/net/socket_stream/socket_stream.h
+++ b/net/socket_stream/socket_stream.h
@@ -28,13 +28,13 @@ namespace net {
class AuthChallengeInfo;
class CertVerifier;
class ClientSocketFactory;
+class ClientSocketHandle;
class CookieOptions;
class HostResolver;
class HttpAuthController;
class SSLInfo;
class ServerBoundCertService;
class SingleRequestHostResolver;
-class StreamSocket;
class SocketStreamMetrics;
class TransportSecurityState;
class URLRequestContext;
@@ -364,7 +364,7 @@ class NET_EXPORT SocketStream
scoped_ptr<SingleRequestHostResolver> resolver_;
AddressList addresses_;
- scoped_ptr<StreamSocket> socket_;
+ scoped_ptr<ClientSocketHandle> connection_;
SSLConfig server_ssl_config_;
SSLConfig proxy_ssl_config_;
diff --git a/net/spdy/spdy_test_util_common.cc b/net/spdy/spdy_test_util_common.cc
index 4383db0..9546684 100644
--- a/net/spdy/spdy_test_util_common.cc
+++ b/net/spdy/spdy_test_util_common.cc
@@ -649,8 +649,8 @@ base::WeakPtr<SpdySession> CreateFakeSpdySessionHelper(
EXPECT_FALSE(HasSpdySession(pool, key));
base::WeakPtr<SpdySession> spdy_session;
scoped_ptr<ClientSocketHandle> handle(new ClientSocketHandle());
- handle->set_socket(new FakeSpdySessionClientSocket(
- expected_status == OK ? ERR_IO_PENDING : expected_status));
+ handle->SetSocket(scoped_ptr<StreamSocket>(new FakeSpdySessionClientSocket(
+ expected_status == OK ? ERR_IO_PENDING : expected_status)));
EXPECT_EQ(
expected_status,
pool->CreateAvailableSessionFromSocket(
diff --git a/remoting/protocol/ssl_hmac_channel_authenticator.cc b/remoting/protocol/ssl_hmac_channel_authenticator.cc
index d5db72f..20bfc53b 100644
--- a/remoting/protocol/ssl_hmac_channel_authenticator.cc
+++ b/remoting/protocol/ssl_hmac_channel_authenticator.cc
@@ -73,16 +73,16 @@ void SslHmacChannelAuthenticator::SecureAndAuthenticate(
return;
}
- net::SSLConfig ssl_config;
- net::SSLServerSocket* server_socket =
- net::CreateSSLServerSocket(socket.release(),
+ scoped_ptr<net::SSLServerSocket> server_socket =
+ net::CreateSSLServerSocket(socket.Pass(),
cert.get(),
local_key_pair_->private_key(),
- ssl_config);
- socket_.reset(server_socket);
-
- result = server_socket->Handshake(base::Bind(
- &SslHmacChannelAuthenticator::OnConnected, base::Unretained(this)));
+ net::SSLConfig());
+ net::SSLServerSocket* raw_server_socket = server_socket.get();
+ socket_ = server_socket.Pass();
+ result = raw_server_socket->Handshake(
+ base::Bind(&SslHmacChannelAuthenticator::OnConnected,
+ base::Unretained(this)));
} else {
cert_verifier_.reset(net::CertVerifier::CreateDefault());
transport_security_state_.reset(new net::TransportSecurityState);
@@ -105,10 +105,10 @@ void SslHmacChannelAuthenticator::SecureAndAuthenticate(
context.cert_verifier = cert_verifier_.get();
context.transport_security_state = transport_security_state_.get();
scoped_ptr<net::ClientSocketHandle> connection(new net::ClientSocketHandle);
- connection->set_socket(socket.release());
- socket_.reset(
+ connection->SetSocket(socket.Pass());
+ socket_ =
net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
- connection.release(), host_and_port, ssl_config, context));
+ connection.Pass(), host_and_port, ssl_config, context);
result = socket_->Connect(
base::Bind(&SslHmacChannelAuthenticator::OnConnected,