diff options
author | jhawkins@chromium.org <jhawkins@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-12-06 23:33:24 +0000 |
---|---|---|
committer | jhawkins@chromium.org <jhawkins@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-12-06 23:33:24 +0000 |
commit | dbf036fcb743cfdcd5be421364c8b89b10ee3f55 (patch) | |
tree | 034c3028c8b523d2bb6c2703416b26a81bc31663 | |
parent | ad24b1827fe58c4a22c0cddb5791a95f2ab1b21b (diff) | |
download | chromium_src-dbf036fcb743cfdcd5be421364c8b89b10ee3f55.zip chromium_src-dbf036fcb743cfdcd5be421364c8b89b10ee3f55.tar.gz chromium_src-dbf036fcb743cfdcd5be421364c8b89b10ee3f55.tar.bz2 |
base::Bind: Convert StreamSocket::Connect.
BUG=none
TEST=none
R=csilv
Review URL: http://codereview.chromium.org/8801004
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@113283 0039d316-1c4b-4281-b951-d872f2087c98
47 files changed, 903 insertions, 228 deletions
diff --git a/chrome/browser/chromeos/web_socket_proxy.cc b/chrome/browser/chromeos/web_socket_proxy.cc index 7a04efe..10dbf48 100644 --- a/chrome/browser/chromeos/web_socket_proxy.cc +++ b/chrome/browser/chromeos/web_socket_proxy.cc @@ -26,6 +26,7 @@ #include "base/base64.h" #include "base/basictypes.h" #include "base/bind.h" +#include "base/bind_helpers.h" #include "base/lazy_instance.h" #include "base/logging.h" #include "base/memory/ref_counted.h" @@ -554,9 +555,11 @@ class SSLChan : public MessageLoopForIO::Watcher { read_pipe_(read_pipe), write_pipe_(write_pipe), method_factory_(this), - socket_connect_callback_(NewCallback(this, &SSLChan::OnSocketConnect)), + socket_connect_callback_( + base::Bind(&SSLChan::OnSocketConnect, base::Unretained(this))), ssl_handshake_callback_( - NewCallback(this, &SSLChan::OnSSLHandshakeCompleted)), + base::Bind(&SSLChan::OnSSLHandshakeCompleted, + base::Unretained(this))), socket_read_callback_(NewCallback(this, &SSLChan::OnSocketRead)), socket_write_callback_(NewCallback(this, &SSLChan::OnSocketWrite)) { if (!SetNonBlock(read_pipe_) || !SetNonBlock(write_pipe_)) { @@ -571,7 +574,7 @@ class SSLChan : public MessageLoopForIO::Watcher { Shut(net::ERR_FAILED); return; } - int result = socket_->Connect(socket_connect_callback_.get()); + int result = socket_->Connect(socket_connect_callback_); if (result != net::ERR_IO_PENDING) OnSocketConnect(result); } @@ -631,7 +634,7 @@ class SSLChan : public MessageLoopForIO::Watcher { OnSSLHandshakeCompleted(net::ERR_UNEXPECTED); return; } - result = socket_->Connect(ssl_handshake_callback_.get()); + result = socket_->Connect(ssl_handshake_callback_); if (result != net::ERR_IO_PENDING) OnSSLHandshakeCompleted(result); } @@ -792,8 +795,8 @@ class SSLChan : public MessageLoopForIO::Watcher { bool is_read_pipe_blocked_; bool is_write_pipe_blocked_; ScopedRunnableMethodFactory<SSLChan> method_factory_; - scoped_ptr<net::OldCompletionCallback> socket_connect_callback_; - scoped_ptr<net::OldCompletionCallback> ssl_handshake_callback_; + net::CompletionCallback socket_connect_callback_; + net::CompletionCallback ssl_handshake_callback_; scoped_ptr<net::OldCompletionCallback> socket_read_callback_; scoped_ptr<net::OldCompletionCallback> socket_write_callback_; MessageLoopForIO::FileDescriptorWatcher read_pipe_controller_; diff --git a/content/browser/renderer_host/p2p/socket_host_test_utils.h b/content/browser/renderer_host/p2p/socket_host_test_utils.h index e7002f4..bb05a30 100644 --- a/content/browser/renderer_host/p2p/socket_host_test_utils.h +++ b/content/browser/renderer_host/p2p/socket_host_test_utils.h @@ -63,6 +63,7 @@ class FakeSocket : public net::StreamSocket { virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -167,6 +168,10 @@ int FakeSocket::Connect(net::OldCompletionCallback* callback) { return 0; } +int FakeSocket::Connect(const net::CompletionCallback& callback) { + return 0; +} + void FakeSocket::Disconnect() { NOTREACHED(); } diff --git a/jingle/glue/pseudotcp_adapter.cc b/jingle/glue/pseudotcp_adapter.cc index 9f91834..6f3ddd1 100644 --- a/jingle/glue/pseudotcp_adapter.cc +++ b/jingle/glue/pseudotcp_adapter.cc @@ -34,6 +34,7 @@ class PseudoTcpAdapter::Core : public cricket::IPseudoTcpNotify, int Write(net::IOBuffer* buffer, int buffer_size, net::OldCompletionCallback* callback); int Connect(net::OldCompletionCallback* callback); + int Connect(const net::CompletionCallback& callback); void Disconnect(); bool IsConnected() const; @@ -68,7 +69,8 @@ class PseudoTcpAdapter::Core : public cricket::IPseudoTcpNotify, // This re-sets |timer| without triggering callbacks. void AdjustClock(); - net::OldCompletionCallback* connect_callback_; + net::OldCompletionCallback* old_connect_callback_; + net::CompletionCallback connect_callback_; net::OldCompletionCallback* read_callback_; net::OldCompletionCallback* write_callback_; @@ -93,7 +95,7 @@ class PseudoTcpAdapter::Core : public cricket::IPseudoTcpNotify, PseudoTcpAdapter::Core::Core(net::Socket* socket) - : connect_callback_(NULL), + : old_connect_callback_(NULL), read_callback_(NULL), write_callback_(NULL), ALLOW_THIS_IN_INITIALIZER_LIST(pseudo_tcp_(this, 0)), @@ -171,6 +173,27 @@ int PseudoTcpAdapter::Core::Connect(net::OldCompletionCallback* callback) { AdjustClock(); + old_connect_callback_ = callback; + connect_callback_.Reset(); + DoReadFromSocket(); + + return net::ERR_IO_PENDING; +} + +int PseudoTcpAdapter::Core::Connect(const net::CompletionCallback& callback) { + DCHECK_EQ(pseudo_tcp_.State(), cricket::PseudoTcp::TCP_LISTEN); + + // Reference the Core in case a callback deletes the adapter. + scoped_refptr<Core> core(this); + + // Start the connection attempt. + int result = pseudo_tcp_.Connect(); + if (result < 0) + return net::ERR_FAILED; + + AdjustClock(); + + old_connect_callback_ = NULL; connect_callback_ = callback; DoReadFromSocket(); @@ -183,7 +206,8 @@ void PseudoTcpAdapter::Core::Disconnect() { read_buffer_ = NULL; write_callback_ = NULL; write_buffer_ = NULL; - connect_callback_ = NULL; + old_connect_callback_ = NULL; + connect_callback_.Reset(); // TODO(wez): Connect should succeed if called after Disconnect, which // PseudoTcp doesn't support, so we need to teardown the internal PseudoTcp @@ -202,10 +226,14 @@ bool PseudoTcpAdapter::Core::IsConnected() const { void PseudoTcpAdapter::Core::OnTcpOpen(PseudoTcp* tcp) { DCHECK(tcp == &pseudo_tcp_); - if (connect_callback_) { - net::OldCompletionCallback* callback = connect_callback_; - connect_callback_ = NULL; + if (old_connect_callback_) { + net::OldCompletionCallback* callback = old_connect_callback_; + old_connect_callback_ = NULL; callback->Run(net::OK); + } else if (!connect_callback_.is_null()) { + net::CompletionCallback callback = connect_callback_; + connect_callback_.Reset(); + callback.Run(net::OK); } OnTcpReadable(tcp); @@ -257,10 +285,14 @@ void PseudoTcpAdapter::Core::OnTcpWriteable(PseudoTcp* tcp) { void PseudoTcpAdapter::Core::OnTcpClosed(PseudoTcp* tcp, uint32 error) { DCHECK_EQ(tcp, &pseudo_tcp_); - if (connect_callback_) { - net::OldCompletionCallback* callback = connect_callback_; - connect_callback_ = NULL; + if (old_connect_callback_) { + net::OldCompletionCallback* callback = old_connect_callback_; + old_connect_callback_ = NULL; callback->Run(net::MapSystemError(error)); + } else if (!connect_callback_.is_null()) { + net::CompletionCallback callback = connect_callback_; + connect_callback_.Reset(); + callback.Run(net::MapSystemError(error)); } if (read_callback_) { @@ -430,6 +462,16 @@ int PseudoTcpAdapter::Connect(net::OldCompletionCallback* callback) { return core_->Connect(callback); } +int PseudoTcpAdapter::Connect(const net::CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + + // net::StreamSocket requires that Connect return OK if already connected. + if (IsConnected()) + return net::OK; + + return core_->Connect(callback); +} + void PseudoTcpAdapter::Disconnect() { DCHECK(CalledOnValidThread()); core_->Disconnect(); diff --git a/jingle/glue/pseudotcp_adapter.h b/jingle/glue/pseudotcp_adapter.h index 5edb3e0..fa48b15 100644 --- a/jingle/glue/pseudotcp_adapter.h +++ b/jingle/glue/pseudotcp_adapter.h @@ -39,6 +39,7 @@ class PseudoTcpAdapter : public net::StreamSocket, base::NonThreadSafe { // net::StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; diff --git a/jingle/notifier/base/fake_ssl_client_socket.cc b/jingle/notifier/base/fake_ssl_client_socket.cc index 2ab7ef5..e247038 100644 --- a/jingle/notifier/base/fake_ssl_client_socket.cc +++ b/jingle/notifier/base/fake_ssl_client_socket.cc @@ -89,7 +89,7 @@ FakeSSLClientSocket::FakeSSLClientSocket( transport_socket_(transport_socket), next_handshake_state_(STATE_NONE), handshake_completed_(false), - user_connect_callback_(NULL), + old_user_connect_callback_(NULL), write_buf_(NewDrainableIOBufferWithSize(arraysize(kSslClientHello))), read_buf_(NewDrainableIOBufferWithSize(arraysize(kSslServerHello))) { CHECK(transport_socket_.get()); @@ -126,17 +126,34 @@ int FakeSSLClientSocket::Connect(net::OldCompletionCallback* callback) { DCHECK(callback); DCHECK_EQ(next_handshake_state_, STATE_NONE); DCHECK(!handshake_completed_); - DCHECK(!user_connect_callback_); + DCHECK(!old_user_connect_callback_); DCHECK_EQ(write_buf_->BytesConsumed(), 0); DCHECK_EQ(read_buf_->BytesConsumed(), 0); next_handshake_state_ = STATE_CONNECT; int status = DoHandshakeLoop(); if (status == net::ERR_IO_PENDING) { - user_connect_callback_ = callback; + old_user_connect_callback_ = callback; } return status; } +int FakeSSLClientSocket::Connect(const net::CompletionCallback& callback) { + // We don't support synchronous operation, even if + // |transport_socket_| does. + DCHECK(!callback.is_null()); + DCHECK_EQ(next_handshake_state_, STATE_NONE); + DCHECK(!handshake_completed_); + DCHECK(user_connect_callback_.is_null()); + DCHECK_EQ(write_buf_->BytesConsumed(), 0); + DCHECK_EQ(read_buf_->BytesConsumed(), 0); + + next_handshake_state_ = STATE_CONNECT; + int status = DoHandshakeLoop(); + if (status == net::ERR_IO_PENDING) + user_connect_callback_ = callback; + + return status; +} int FakeSSLClientSocket::DoHandshakeLoop() { DCHECK_NE(next_handshake_state_, STATE_NONE); @@ -167,9 +184,16 @@ int FakeSSLClientSocket::DoHandshakeLoop() { void FakeSSLClientSocket::RunUserConnectCallback(int status) { DCHECK_LE(status, net::OK); next_handshake_state_ = STATE_NONE; - net::OldCompletionCallback* user_connect_callback = user_connect_callback_; - user_connect_callback_ = NULL; - user_connect_callback->Run(status); + if (old_user_connect_callback_) { + net::OldCompletionCallback* user_connect_callback = + old_user_connect_callback_; + old_user_connect_callback_ = NULL; + user_connect_callback->Run(status); + } else { + net::CompletionCallback user_connect_callback = user_connect_callback_; + user_connect_callback_.Reset(); + user_connect_callback.Run(status); + } } void FakeSSLClientSocket::DoHandshakeLoopWithUserConnectCallback() { @@ -191,7 +215,7 @@ int FakeSSLClientSocket::DoConnect() { void FakeSSLClientSocket::OnConnectDone(int status) { DCHECK_NE(status, net::ERR_IO_PENDING); DCHECK_LE(status, net::OK); - DCHECK(user_connect_callback_); + DCHECK(old_user_connect_callback_ || !user_connect_callback_.is_null()); if (status != net::OK) { RunUserConnectCallback(status); return; @@ -219,7 +243,7 @@ int FakeSSLClientSocket::DoSendClientHello() { void FakeSSLClientSocket::OnSendClientHelloDone(int status) { DCHECK_NE(status, net::ERR_IO_PENDING); - DCHECK(user_connect_callback_); + DCHECK(old_user_connect_callback_ || !user_connect_callback_.is_null()); if (status < net::OK) { RunUserConnectCallback(status); return; @@ -252,7 +276,7 @@ int FakeSSLClientSocket::DoVerifyServerHello() { void FakeSSLClientSocket::OnVerifyServerHelloDone(int status) { DCHECK_NE(status, net::ERR_IO_PENDING); - DCHECK(user_connect_callback_); + DCHECK(old_user_connect_callback_ || !user_connect_callback_.is_null()); if (status < net::OK) { RunUserConnectCallback(status); return; @@ -295,7 +319,8 @@ void FakeSSLClientSocket::Disconnect() { transport_socket_->Disconnect(); next_handshake_state_ = STATE_NONE; handshake_completed_ = false; - user_connect_callback_ = NULL; + old_user_connect_callback_ = NULL; + user_connect_callback_.Reset(); write_buf_->SetOffset(0); read_buf_->SetOffset(0); } diff --git a/jingle/notifier/base/fake_ssl_client_socket.h b/jingle/notifier/base/fake_ssl_client_socket.h index 9a5af54..9061abe 100644 --- a/jingle/notifier/base/fake_ssl_client_socket.h +++ b/jingle/notifier/base/fake_ssl_client_socket.h @@ -53,6 +53,7 @@ class FakeSSLClientSocket : public net::StreamSocket { virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -107,7 +108,8 @@ class FakeSSLClientSocket : public net::StreamSocket { bool handshake_completed_; // The callback passed to Connect(). - net::OldCompletionCallback* user_connect_callback_; + net::OldCompletionCallback* old_user_connect_callback_; + net::CompletionCallback user_connect_callback_; scoped_refptr<net::DrainableIOBuffer> write_buf_; scoped_refptr<net::DrainableIOBuffer> read_buf_; diff --git a/jingle/notifier/base/fake_ssl_client_socket_unittest.cc b/jingle/notifier/base/fake_ssl_client_socket_unittest.cc index 86f5fb1..6c73af5 100644 --- a/jingle/notifier/base/fake_ssl_client_socket_unittest.cc +++ b/jingle/notifier/base/fake_ssl_client_socket_unittest.cc @@ -52,6 +52,7 @@ class MockClientSocket : public net::StreamSocket { MOCK_METHOD1(SetReceiveBufferSize, bool(int32)); MOCK_METHOD1(SetSendBufferSize, bool(int32)); MOCK_METHOD1(Connect, int(net::OldCompletionCallback*)); + MOCK_METHOD1(Connect, int(const net::CompletionCallback&)); MOCK_METHOD0(Disconnect, void()); MOCK_CONST_METHOD0(IsConnected, bool()); MOCK_CONST_METHOD0(IsConnectedAndIdle, bool()); diff --git a/jingle/notifier/base/proxy_resolving_client_socket.cc b/jingle/notifier/base/proxy_resolving_client_socket.cc index 4c31ba3..b3a95f9 100644 --- a/jingle/notifier/base/proxy_resolving_client_socket.cc +++ b/jingle/notifier/base/proxy_resolving_client_socket.cc @@ -37,7 +37,7 @@ ProxyResolvingClientSocket::ProxyResolvingClientSocket( request_context_getter->GetURLRequestContext()->net_log(), net::NetLog::SOURCE_SOCKET)), ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)), - user_connect_callback_(NULL) { + old_user_connect_callback_(NULL) { DCHECK(request_context_getter); net::URLRequestContext* request_context = request_context_getter->GetURLRequestContext(); @@ -97,7 +97,35 @@ bool ProxyResolvingClientSocket::SetSendBufferSize(int32 size) { } int ProxyResolvingClientSocket::Connect(net::OldCompletionCallback* callback) { - DCHECK(!user_connect_callback_); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); + + tried_direct_connect_fallback_ = false; + + // First we try and resolve the proxy. + GURL url("http://" + dest_host_port_pair_.ToString()); + int status = network_session_->proxy_service()->ResolveProxy( + url, + &proxy_info_, + &proxy_resolve_callback_, + &pac_request_, + bound_net_log_); + if (status != net::ERR_IO_PENDING) { + // We defer execution of ProcessProxyResolveDone instead of calling it + // directly here for simplicity. From the caller's point of view, + // the connect always happens asynchronously. + MessageLoop* message_loop = MessageLoop::current(); + CHECK(message_loop); + message_loop->PostTask( + FROM_HERE, + base::Bind(&ProxyResolvingClientSocket::ProcessProxyResolveDone, + weak_factory_.GetWeakPtr(), status)); + } + old_user_connect_callback_ = callback; + return net::ERR_IO_PENDING; +} +int ProxyResolvingClientSocket::Connect( + const net::CompletionCallback& callback) { + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); tried_direct_connect_fallback_ = false; @@ -126,9 +154,16 @@ int ProxyResolvingClientSocket::Connect(net::OldCompletionCallback* callback) { void ProxyResolvingClientSocket::RunUserConnectCallback(int status) { DCHECK_LE(status, net::OK); - net::OldCompletionCallback* user_connect_callback = user_connect_callback_; - user_connect_callback_ = NULL; - user_connect_callback->Run(status); + if (old_user_connect_callback_) { + net::OldCompletionCallback* user_connect_callback = + old_user_connect_callback_; + old_user_connect_callback_ = NULL; + user_connect_callback->Run(status); + } else { + net::CompletionCallback user_connect_callback = user_connect_callback_; + user_connect_callback_.Reset(); + user_connect_callback.Run(status); + } } // Always runs asynchronously. @@ -287,7 +322,8 @@ void ProxyResolvingClientSocket::Disconnect() { CloseTransportSocket(); if (pac_request_) network_session_->proxy_service()->CancelPacRequest(pac_request_); - user_connect_callback_ = NULL; + old_user_connect_callback_ = NULL; + user_connect_callback_.Reset(); } bool ProxyResolvingClientSocket::IsConnected() const { diff --git a/jingle/notifier/base/proxy_resolving_client_socket.h b/jingle/notifier/base/proxy_resolving_client_socket.h index 0c40a22..0f27bc1 100644 --- a/jingle/notifier/base/proxy_resolving_client_socket.h +++ b/jingle/notifier/base/proxy_resolving_client_socket.h @@ -53,6 +53,7 @@ class ProxyResolvingClientSocket : public net::StreamSocket { virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -95,7 +96,8 @@ class ProxyResolvingClientSocket : public net::StreamSocket { base::WeakPtrFactory<ProxyResolvingClientSocket> weak_factory_; // The callback passed to Connect(). - net::OldCompletionCallback* user_connect_callback_; + net::OldCompletionCallback* old_user_connect_callback_; + net::CompletionCallback user_connect_callback_; }; } // namespace notifier diff --git a/net/curvecp/client_packetizer.cc b/net/curvecp/client_packetizer.cc index 72bf4a0..77ddfda 100644 --- a/net/curvecp/client_packetizer.cc +++ b/net/curvecp/client_packetizer.cc @@ -33,7 +33,7 @@ ClientPacketizer::ClientPacketizer() : Packetizer(), next_state_(NONE), listener_(NULL), - user_callback_(NULL), + old_user_callback_(NULL), current_address_(NULL), hello_attempts_(0), initiate_sent_(false), @@ -52,7 +52,23 @@ ClientPacketizer::~ClientPacketizer() { int ClientPacketizer::Connect(const AddressList& server, Packetizer::Listener* listener, OldCompletionCallback* callback) { - DCHECK(!user_callback_); + DCHECK(!old_user_callback_); + DCHECK(!socket_.get()); + DCHECK(!listener_); + + listener_ = listener; + + addresses_ = server; + + old_user_callback_ = callback; + next_state_ = LOOKUP_COOKIE; + + return DoLoop(OK); +} +int ClientPacketizer::Connect(const AddressList& server, + Packetizer::Listener* listener, + const net::CompletionCallback& callback) { + DCHECK(user_callback_.is_null()); DCHECK(!socket_.get()); DCHECK(!listener_); @@ -279,11 +295,17 @@ int ClientPacketizer::DoConnected(int rv) { void ClientPacketizer::DoCallback(int result) { DCHECK_NE(result, ERR_IO_PENDING); - DCHECK(user_callback_); - - OldCompletionCallback* callback = user_callback_; - user_callback_ = NULL; - callback->Run(result); + DCHECK(old_user_callback_ || !user_callback_.is_null()); + + if (old_user_callback_) { + OldCompletionCallback* callback = old_user_callback_; + old_user_callback_ = NULL; + callback->Run(result); + } else { + CompletionCallback callback = user_callback_; + user_callback_.Reset(); + callback.Run(result); + } } int ClientPacketizer::ConnectNextAddress() { diff --git a/net/curvecp/client_packetizer.h b/net/curvecp/client_packetizer.h index 4915bbe..f8edeaf 100644 --- a/net/curvecp/client_packetizer.h +++ b/net/curvecp/client_packetizer.h @@ -30,8 +30,11 @@ class ClientPacketizer : public Packetizer { int Connect(const AddressList& server, Packetizer::Listener* listener, OldCompletionCallback* callback); + int Connect(const AddressList& server, + Packetizer::Listener* listener, + const CompletionCallback& callback); - // Packetizer methods + // Packetizer implementation. virtual int SendMessage(ConnectionKey key, const char* data, size_t length, @@ -83,13 +86,14 @@ class ClientPacketizer : public Packetizer { StateType next_state_; scoped_ptr<UDPClientSocket> socket_; Packetizer::Listener* listener_; - OldCompletionCallback* user_callback_; + OldCompletionCallback* old_user_callback_; + CompletionCallback user_callback_; AddressList addresses_; const struct addrinfo* current_address_; int hello_attempts_; // Number of attempts to send a Hello Packet. - bool initiate_sent_; // Indicates whether the Initialte Packet was sent. + bool initiate_sent_; // Indicates whether the Initiate Packet was sent. - scoped_refptr<IOBuffer> read_buffer_; // Buffer for interal reads. + scoped_refptr<IOBuffer> read_buffer_; // Buffer for internal reads. uchar shortterm_public_key_[32]; diff --git a/net/curvecp/curvecp_client_socket.cc b/net/curvecp/curvecp_client_socket.cc index cba3d0e..8074016 100644 --- a/net/curvecp/curvecp_client_socket.cc +++ b/net/curvecp/curvecp_client_socket.cc @@ -25,6 +25,10 @@ int CurveCPClientSocket::Connect(OldCompletionCallback* callback) { return packetizer_.Connect(addresses_, &messenger_, callback); } +int CurveCPClientSocket::Connect(const net::CompletionCallback& callback) { + return packetizer_.Connect(addresses_, &messenger_, callback); +} + void CurveCPClientSocket::Disconnect() { // TODO(mbelshe): DCHECK that we're connected. // Record the ConnectionKey so that we can disconnect it properly. diff --git a/net/curvecp/curvecp_client_socket.h b/net/curvecp/curvecp_client_socket.h index fc56440..85f8817 100644 --- a/net/curvecp/curvecp_client_socket.h +++ b/net/curvecp/curvecp_client_socket.h @@ -25,8 +25,9 @@ class CurveCPClientSocket : public StreamSocket { const net::NetLog::Source& source); virtual ~CurveCPClientSocket(); - // ClientSocket methods: + // ClientSocket implementation. virtual int Connect(OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; diff --git a/net/http/http_proxy_client_socket.cc b/net/http/http_proxy_client_socket.cc index 1434c655..15c6e5f 100644 --- a/net/http/http_proxy_client_socket.cc +++ b/net/http/http_proxy_client_socket.cc @@ -37,7 +37,7 @@ HttpProxyClientSocket::HttpProxyClientSocket( : ALLOW_THIS_IN_INITIALIZER_LIST( io_callback_(this, &HttpProxyClientSocket::OnIOComplete)), next_state_(STATE_NONE), - user_callback_(NULL), + old_user_callback_(NULL), transport_(transport_socket), endpoint_(endpoint), auth_(tunnel ? @@ -65,7 +65,7 @@ HttpProxyClientSocket::~HttpProxyClientSocket() { int HttpProxyClientSocket::RestartWithAuth(OldCompletionCallback* callback) { DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!user_callback_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); int rv = PrepareForAuthRestart(); if (rv != OK || next_state_ == STATE_NONE) @@ -73,7 +73,7 @@ int HttpProxyClientSocket::RestartWithAuth(OldCompletionCallback* callback) { rv = DoLoop(OK); if (rv == ERR_IO_PENDING) - user_callback_ = callback; + old_user_callback_ = callback; return rv; } @@ -95,7 +95,31 @@ HttpStream* HttpProxyClientSocket::CreateConnectResponseStream() { int HttpProxyClientSocket::Connect(OldCompletionCallback* callback) { DCHECK(transport_.get()); DCHECK(transport_->socket()); - DCHECK(!user_callback_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); + + // TODO(rch): figure out the right way to set up a tunnel with SPDY. + // This approach sends the complete HTTPS request to the proxy + // which allows the proxy to see "private" data. Instead, we should + // create an SSL tunnel to the origin server using the CONNECT method + // inside a single SPDY stream. + if (using_spdy_ || !tunnel_) + next_state_ = STATE_DONE; + if (next_state_ == STATE_DONE) + return OK; + + DCHECK_EQ(STATE_NONE, next_state_); + next_state_ = STATE_GENERATE_AUTH_TOKEN; + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) + old_user_callback_ = callback; + return rv; +} + +int HttpProxyClientSocket::Connect(const CompletionCallback& callback) { + DCHECK(transport_.get()); + DCHECK(transport_->socket()); + DCHECK(!old_user_callback_ && user_callback_.is_null()); // TODO(rch): figure out the right way to set up a tunnel with SPDY. // This approach sends the complete HTTPS request to the proxy @@ -123,7 +147,8 @@ void HttpProxyClientSocket::Disconnect() { // Reset other states to make sure they aren't mistakenly used later. // These are the states initialized by Connect(). next_state_ = STATE_NONE; - user_callback_ = NULL; + old_user_callback_ = NULL; + user_callback_.Reset(); } bool HttpProxyClientSocket::IsConnected() const { @@ -189,7 +214,7 @@ base::TimeDelta HttpProxyClientSocket::GetConnectTimeMicros() const { int HttpProxyClientSocket::Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { - DCHECK(!user_callback_); + DCHECK(!old_user_callback_); if (next_state_ != STATE_DONE) { // We're trying to read the body of the response but we're still trying // to establish an SSL tunnel through the proxy. We can't read these @@ -210,7 +235,7 @@ int HttpProxyClientSocket::Read(IOBuffer* buf, int buf_len, int HttpProxyClientSocket::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK_EQ(STATE_DONE, next_state_); - DCHECK(!user_callback_); + DCHECK(!old_user_callback_); return transport_->socket()->Write(buf, buf_len, callback); } @@ -277,13 +302,19 @@ void HttpProxyClientSocket::LogBlockedTunnelResponse(int response_code) const { void HttpProxyClientSocket::DoCallback(int result) { DCHECK_NE(ERR_IO_PENDING, result); - DCHECK(user_callback_); + DCHECK(old_user_callback_ || !user_callback_.is_null()); // Since Run() may result in Read being called, - // clear user_callback_ up front. - OldCompletionCallback* c = user_callback_; - user_callback_ = NULL; - c->Run(result); + // clear old_user_callback_ up front. + if (old_user_callback_) { + OldCompletionCallback* c = old_user_callback_; + old_user_callback_ = NULL; + c->Run(result); + } else { + CompletionCallback c = user_callback_; + user_callback_.Reset(); + c.Run(result); + } } void HttpProxyClientSocket::OnIOComplete(int result) { diff --git a/net/http/http_proxy_client_socket.h b/net/http/http_proxy_client_socket.h index 70db3707..15baf7c 100644 --- a/net/http/http_proxy_client_socket.h +++ b/net/http/http_proxy_client_socket.h @@ -60,8 +60,9 @@ class HttpProxyClientSocket : public ProxyClientSocket { virtual int RestartWithAuth(OldCompletionCallback* callback) OVERRIDE; virtual const scoped_refptr<HttpAuthController>& auth_controller() OVERRIDE; - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -126,7 +127,8 @@ class HttpProxyClientSocket : public ProxyClientSocket { State next_state_; // Stores the callback to the layer above, called on completing Connect(). - OldCompletionCallback* user_callback_; + OldCompletionCallback* old_user_callback_; + CompletionCallback user_callback_; HttpRequestInfo request_; HttpResponseInfo response_; diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index f8a21fb..d318dd6 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -64,12 +64,15 @@ class MockClientSocket : public StreamSocket { virtual bool SetReceiveBufferSize(int32 size) { return true; } virtual bool SetSendBufferSize(int32 size) { return true; } - // StreamSocket methods: - + // StreamSocket implementation. virtual int Connect(OldCompletionCallback* callback) { connected_ = true; return OK; } + virtual int Connect(const net::CompletionCallback& callback) { + connected_ = true; + return OK; + } virtual void Disconnect() { connected_ = false; } virtual bool IsConnected() const { return connected_; } diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 9685697..eb1599f 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -7,8 +7,9 @@ #include <algorithm> #include <vector> - #include "base/basictypes.h" +#include "base/bind.h" +#include "base/bind_helpers.h" #include "base/compiler_specific.h" #include "base/message_loop.h" #include "base/time.h" @@ -630,7 +631,7 @@ void MockClientSocketFactory::ClearSSLSessionCache() { } MockClientSocket::MockClientSocket(net::NetLog* net_log) - : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), + : ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)), connected_(false), net_log_(NetLog::Source(), net_log) { } @@ -703,15 +704,26 @@ MockClientSocket::~MockClientSocket() {} void MockClientSocket::RunCallbackAsync(net::OldCompletionCallback* callback, int result) { MessageLoop::current()->PostTask(FROM_HERE, - method_factory_.NewRunnableMethod( - &MockClientSocket::RunCallback, callback, result)); + base::Bind(&MockClientSocket::RunOldCallback, weak_factory_.GetWeakPtr(), + callback, result)); +} +void MockClientSocket::RunCallbackAsync(const net::CompletionCallback& callback, + int result) { + MessageLoop::current()->PostTask(FROM_HERE, + base::Bind(&MockClientSocket::RunCallback, weak_factory_.GetWeakPtr(), + callback, result)); } -void MockClientSocket::RunCallback(net::OldCompletionCallback* callback, - int result) { +void MockClientSocket::RunOldCallback(net::OldCompletionCallback* callback, + int result) { if (callback) callback->Run(result); } +void MockClientSocket::RunCallback(const net::CompletionCallback& callback, + int result) { + if (!callback.is_null()) + callback.Run(result); +} MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses, net::NetLog* net_log, @@ -797,6 +809,19 @@ int MockTCPClientSocket::Connect(net::OldCompletionCallback* callback) { } return data_->connect_data().result; } +int MockTCPClientSocket::Connect(const net::CompletionCallback& callback) { + if (connected_) + return net::OK; + + connected_ = true; + peer_closed_connection_ = false; + if (data_->connect_data().async) { + RunCallbackAsync(callback, data_->connect_data().result); + return net::ERR_IO_PENDING; + } + + return data_->connect_data().result; +} void MockTCPClientSocket::Disconnect() { MockClientSocket::Disconnect(); @@ -853,7 +878,7 @@ void MockTCPClientSocket::OnReadComplete(const MockRead& data) { net::OldCompletionCallback* callback = pending_callback_; int rv = CompleteRead(); - RunCallback(callback, rv); + RunOldCallback(callback, rv); } int MockTCPClientSocket::CompleteRead() { @@ -1006,6 +1031,19 @@ int DeterministicMockTCPClientSocket::Connect( } return data_->connect_data().result; } +int DeterministicMockTCPClientSocket::Connect( + const net::CompletionCallback& callback) { + if (connected_) + return net::OK; + + connected_ = true; + if (data_->connect_data().async) { + RunCallbackAsync(callback, data_->connect_data().result); + return net::ERR_IO_PENDING; + } + + return data_->connect_data().result; +} void DeterministicMockTCPClientSocket::Disconnect() { MockClientSocket::Disconnect(); @@ -1037,15 +1075,17 @@ base::TimeDelta DeterministicMockTCPClientSocket::GetConnectTimeMicros() const { void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {} -class MockSSLClientSocket::ConnectCallback - : public net::OldCompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> { +class MockSSLClientSocket::OldConnectCallback + : public net::OldCompletionCallbackImpl< + MockSSLClientSocket::OldConnectCallback> { public: - ConnectCallback(MockSSLClientSocket *ssl_client_socket, - net::OldCompletionCallback* user_callback, - int rv) + OldConnectCallback(MockSSLClientSocket *ssl_client_socket, + net::OldCompletionCallback* user_callback, + int rv) : ALLOW_THIS_IN_INITIALIZER_LIST( - net::OldCompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>( - this, &ConnectCallback::Wrapper)), + net::OldCompletionCallbackImpl< + MockSSLClientSocket::OldConnectCallback>( + this, &OldConnectCallback::Wrapper)), ssl_client_socket_(ssl_client_socket), user_callback_(user_callback), rv_(rv) { @@ -1063,6 +1103,32 @@ class MockSSLClientSocket::ConnectCallback net::OldCompletionCallback* user_callback_; int rv_; }; +class MockSSLClientSocket::ConnectCallback { + public: + ConnectCallback(MockSSLClientSocket *ssl_client_socket, + const CompletionCallback& user_callback, + int rv) + : ALLOW_THIS_IN_INITIALIZER_LIST(callback_( + base::Bind(&ConnectCallback::Wrapper, base::Unretained(this)))), + ssl_client_socket_(ssl_client_socket), + user_callback_(user_callback), + rv_(rv) { + } + + const CompletionCallback& callback() const { return callback_; } + + private: + void Wrapper(int rv) { + if (rv_ == net::OK) + ssl_client_socket_->connected_ = true; + user_callback_.Run(rv_); + } + + CompletionCallback callback_; + MockSSLClientSocket* ssl_client_socket_; + CompletionCallback user_callback_; + int rv_; +}; MockSSLClientSocket::MockSSLClientSocket( net::ClientSocketHandle* transport_socket, @@ -1094,7 +1160,7 @@ int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, } int MockSSLClientSocket::Connect(net::OldCompletionCallback* callback) { - ConnectCallback* connect_callback = new ConnectCallback( + OldConnectCallback* connect_callback = new OldConnectCallback( this, callback, data_->connect.result); int rv = transport_->socket()->Connect(connect_callback); if (rv == net::OK) { @@ -1109,6 +1175,20 @@ int MockSSLClientSocket::Connect(net::OldCompletionCallback* callback) { } return rv; } +int MockSSLClientSocket::Connect(const net::CompletionCallback& callback) { + ConnectCallback connect_callback(this, callback, data_->connect.result); + int rv = transport_->socket()->Connect(connect_callback.callback()); + if (rv == net::OK) { + if (data_->connect.result == net::OK) + connected_ = true; + if (data_->connect.async) { + RunCallbackAsync(callback, data_->connect.result); + return net::ERR_IO_PENDING; + } + return data_->connect.result; + } + return rv; +} void MockSSLClientSocket::Disconnect() { MockClientSocket::Disconnect(); @@ -1187,7 +1267,7 @@ MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data, pending_buf_len_(0), pending_callback_(NULL), net_log_(NetLog::Source(), net_log), - ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)) { + ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { DCHECK(data_); data_->Reset(); } @@ -1330,8 +1410,8 @@ int MockUDPClientSocket::CompleteRead() { void MockUDPClientSocket::RunCallbackAsync(net::OldCompletionCallback* callback, int result) { MessageLoop::current()->PostTask(FROM_HERE, - method_factory_.NewRunnableMethod( - &MockUDPClientSocket::RunCallback, callback, result)); + base::Bind(&MockUDPClientSocket::RunCallback, weak_factory_.GetWeakPtr(), + callback, result)); } void MockUDPClientSocket::RunCallback(net::OldCompletionCallback* callback, diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index c7082ca..fa47d26 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -198,7 +198,7 @@ class StaticSocketDataProvider : public SocketDataProvider { virtual void CompleteRead() {} - // SocketDataProvider methods: + // SocketDataProvider implementation. virtual MockRead GetNextRead() OVERRIDE; virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE; virtual void Reset() OVERRIDE; @@ -228,7 +228,7 @@ class DynamicSocketDataProvider : public SocketDataProvider { void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; } - // SocketDataProvider methods: + // SocketDataProvider implementation. virtual MockRead GetNextRead() OVERRIDE; virtual MockWriteResult OnWrite(const std::string& data) = 0; virtual void Reset() OVERRIDE; @@ -584,7 +584,7 @@ class MockClientSocket : public net::SSLClientSocket { public: explicit MockClientSocket(net::NetLog* net_log); - // Socket methods: + // Socket implementation. virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) = 0; virtual int Write(net::IOBuffer* buf, int buf_len, @@ -592,8 +592,9 @@ class MockClientSocket : public net::SSLClientSocket { virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) = 0; + virtual int Connect(const net::CompletionCallback& callback) = 0; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -603,7 +604,7 @@ class MockClientSocket : public net::SSLClientSocket { virtual void SetSubresourceSpeculation() OVERRIDE {} virtual void SetOmniboxSpeculation() OVERRIDE {} - // SSLClientSocket methods: + // SSLClientSocket implementation. virtual void GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE; virtual void GetSSLCertRequestInfo( net::SSLCertRequestInfo* cert_request_info) OVERRIDE; @@ -617,9 +618,11 @@ class MockClientSocket : public net::SSLClientSocket { protected: virtual ~MockClientSocket(); void RunCallbackAsync(net::OldCompletionCallback* callback, int result); - void RunCallback(net::OldCompletionCallback*, int result); + void RunCallbackAsync(const net::CompletionCallback& callback, int result); + void RunOldCallback(net::OldCompletionCallback*, int result); + void RunCallback(const net::CompletionCallback&, int result); - ScopedRunnableMethodFactory<MockClientSocket> method_factory_; + base::WeakPtrFactory<MockClientSocket> weak_factory_; // True if Connect completed successfully and Disconnect hasn't been called. bool connected_; @@ -634,14 +637,15 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { net::AddressList addresses() const { return addresses_; } - // Socket methods: + // Socket implementation. virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -697,8 +701,9 @@ class DeterministicMockTCPClientSocket : public MockClientSocket, virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; - // StreamSocket: + // StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -735,14 +740,15 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { net::SSLSocketDataProvider* socket); virtual ~MockSSLClientSocket(); - // Socket methods: + // Socket implementation. virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool WasEverUsed() const OVERRIDE; @@ -750,7 +756,7 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; - // SSLClientSocket methods: + // SSLClientSocket implementation. virtual void GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE; virtual void GetSSLCertRequestInfo( net::SSLCertRequestInfo* cert_request_info) OVERRIDE; @@ -763,6 +769,7 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { virtual void OnReadComplete(const MockRead& data) OVERRIDE; private: + class OldConnectCallback; class ConnectCallback; scoped_ptr<ClientSocketHandle> transport_; @@ -817,7 +824,7 @@ class MockUDPClientSocket : public DatagramClientSocket, BoundNetLog net_log_; - ScopedRunnableMethodFactory<MockUDPClientSocket> method_factory_; + base::WeakPtrFactory<MockUDPClientSocket> weak_factory_; DISALLOW_COPY_AND_ASSIGN(MockUDPClientSocket); }; diff --git a/net/socket/socks5_client_socket.cc b/net/socket/socks5_client_socket.cc index f106f70..c9d2825 100644 --- a/net/socket/socks5_client_socket.cc +++ b/net/socket/socks5_client_socket.cc @@ -34,7 +34,7 @@ SOCKS5ClientSocket::SOCKS5ClientSocket( io_callback_(this, &SOCKS5ClientSocket::OnIOComplete)), transport_(transport_socket), next_state_(STATE_NONE), - user_callback_(NULL), + old_user_callback_(NULL), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), @@ -50,7 +50,7 @@ SOCKS5ClientSocket::SOCKS5ClientSocket( io_callback_(this, &SOCKS5ClientSocket::OnIOComplete)), transport_(new ClientSocketHandle()), next_state_(STATE_NONE), - user_callback_(NULL), + old_user_callback_(NULL), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), @@ -68,7 +68,7 @@ int SOCKS5ClientSocket::Connect(OldCompletionCallback* callback) { DCHECK(transport_.get()); DCHECK(transport_->socket()); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!user_callback_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); // If already connected, then just return OK. if (completed_handshake_) @@ -81,12 +81,35 @@ int SOCKS5ClientSocket::Connect(OldCompletionCallback* callback) { int rv = DoLoop(OK); if (rv == ERR_IO_PENDING) { - user_callback_ = callback; + old_user_callback_ = callback; } else { net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_CONNECT, rv); } return rv; } +int SOCKS5ClientSocket::Connect(const CompletionCallback& callback) { + DCHECK(transport_.get()); + DCHECK(transport_->socket()); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); + + // If already connected, then just return OK. + if (completed_handshake_) + return OK; + + net_log_.BeginEvent(NetLog::TYPE_SOCKS5_CONNECT, NULL); + + next_state_ = STATE_GREET_WRITE; + buffer_.clear(); + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) + user_callback_ = callback; + else + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_CONNECT, rv); + + return rv; +} void SOCKS5ClientSocket::Disconnect() { completed_handshake_ = false; @@ -95,7 +118,8 @@ void SOCKS5ClientSocket::Disconnect() { // Reset other states to make sure they aren't mistakenly used later. // These are the states initialized by Connect(). next_state_ = STATE_NONE; - user_callback_ = NULL; + old_user_callback_ = NULL; + user_callback_.Reset(); } bool SOCKS5ClientSocket::IsConnected() const { @@ -164,7 +188,7 @@ int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!user_callback_); + DCHECK(!old_user_callback_); return transport_->socket()->Read(buf, buf_len, callback); } @@ -175,7 +199,7 @@ int SOCKS5ClientSocket::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!user_callback_); + DCHECK(!old_user_callback_); return transport_->socket()->Write(buf, buf_len, callback); } @@ -190,13 +214,19 @@ bool SOCKS5ClientSocket::SetSendBufferSize(int32 size) { void SOCKS5ClientSocket::DoCallback(int result) { DCHECK_NE(ERR_IO_PENDING, result); - DCHECK(user_callback_); + DCHECK(old_user_callback_ || !user_callback_.is_null()); // Since Run() may result in Read being called, // clear user_callback_ up front. - OldCompletionCallback* c = user_callback_; - user_callback_ = NULL; - c->Run(result); + if (old_user_callback_) { + OldCompletionCallback* c = old_user_callback_; + old_user_callback_ = NULL; + c->Run(result); + } else { + CompletionCallback c = user_callback_; + user_callback_.Reset(); + c.Run(result); + } } void SOCKS5ClientSocket::OnIOComplete(int result) { diff --git a/net/socket/socks5_client_socket.h b/net/socket/socks5_client_socket.h index aeb1864..748b55a 100644 --- a/net/socket/socks5_client_socket.h +++ b/net/socket/socks5_client_socket.h @@ -52,6 +52,7 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { // Does the SOCKS handshake and completes the protocol. virtual int Connect(OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -129,7 +130,8 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { State next_state_; // Stores the callback to the layer above, called on completing Connect(). - OldCompletionCallback* user_callback_; + OldCompletionCallback* old_user_callback_; + CompletionCallback user_callback_; // This IOBuffer is used by the class to read and write // SOCKS handshake data. The length contains the expected size to diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc index 3ce723e..a4c4b47 100644 --- a/net/socket/socks_client_socket.cc +++ b/net/socket/socks_client_socket.cc @@ -62,7 +62,7 @@ SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket, io_callback_(this, &SOCKSClientSocket::OnIOComplete)), transport_(transport_socket), next_state_(STATE_NONE), - user_callback_(NULL), + old_user_callback_(NULL), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), @@ -78,7 +78,7 @@ SOCKSClientSocket::SOCKSClientSocket(StreamSocket* transport_socket, io_callback_(this, &SOCKSClientSocket::OnIOComplete)), transport_(new ClientSocketHandle()), next_state_(STATE_NONE), - user_callback_(NULL), + old_user_callback_(NULL), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), @@ -96,7 +96,7 @@ int SOCKSClientSocket::Connect(OldCompletionCallback* callback) { DCHECK(transport_.get()); DCHECK(transport_->socket()); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!user_callback_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); // If already connected, then just return OK. if (completed_handshake_) @@ -108,12 +108,34 @@ int SOCKSClientSocket::Connect(OldCompletionCallback* callback) { int rv = DoLoop(OK); if (rv == ERR_IO_PENDING) { - user_callback_ = callback; + old_user_callback_ = callback; } else { net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); } return rv; } +int SOCKSClientSocket::Connect(const net::CompletionCallback& callback) { + DCHECK(transport_.get()); + DCHECK(transport_->socket()); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); + + // If already connected, then just return OK. + if (completed_handshake_) + return OK; + + next_state_ = STATE_RESOLVE_HOST; + + net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT, NULL); + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) + user_callback_ = callback; + else + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); + + return rv; +} void SOCKSClientSocket::Disconnect() { completed_handshake_ = false; @@ -123,7 +145,8 @@ void SOCKSClientSocket::Disconnect() { // Reset other states to make sure they aren't mistakenly used later. // These are the states initialized by Connect(). next_state_ = STATE_NONE; - user_callback_ = NULL; + old_user_callback_ = NULL; + user_callback_.Reset(); } bool SOCKSClientSocket::IsConnected() const { @@ -193,7 +216,7 @@ int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!user_callback_); + DCHECK(!old_user_callback_); return transport_->socket()->Read(buf, buf_len, callback); } @@ -204,7 +227,7 @@ int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!user_callback_); + DCHECK(!old_user_callback_); return transport_->socket()->Write(buf, buf_len, callback); } @@ -219,14 +242,21 @@ bool SOCKSClientSocket::SetSendBufferSize(int32 size) { void SOCKSClientSocket::DoCallback(int result) { DCHECK_NE(ERR_IO_PENDING, result); - DCHECK(user_callback_); + DCHECK(old_user_callback_ || !user_callback_.is_null()); // Since Run() may result in Read being called, // clear user_callback_ up front. - OldCompletionCallback* c = user_callback_; - user_callback_ = NULL; - DVLOG(1) << "Finished setting up SOCKS handshake"; - c->Run(result); + if (old_user_callback_) { + OldCompletionCallback* c = old_user_callback_; + old_user_callback_ = NULL; + DVLOG(1) << "Finished setting up SOCKS handshake"; + c->Run(result); + } else { + CompletionCallback c = user_callback_; + user_callback_.Reset(); + DVLOG(1) << "Finished setting up SOCKS handshake"; + c.Run(result); + } } void SOCKSClientSocket::OnIOComplete(int result) { diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h index eb74a5e..c7089af 100644 --- a/net/socket/socks_client_socket.h +++ b/net/socket/socks_client_socket.h @@ -49,6 +49,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { // Does the SOCKS handshake and completes the protocol. virtual int Connect(OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -110,7 +111,8 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { State next_state_; // Stores the callback to the layer above, called on completing Connect(). - OldCompletionCallback* user_callback_; + OldCompletionCallback* old_user_callback_; + CompletionCallback user_callback_; // This IOBuffer is used by the class to read and write // SOCKS handshake data. The length contains the expected size to diff --git a/net/socket/ssl_client_socket_mac.cc b/net/socket/ssl_client_socket_mac.cc index c719946..817b3be 100644 --- a/net/socket/ssl_client_socket_mac.cc +++ b/net/socket/ssl_client_socket_mac.cc @@ -531,7 +531,7 @@ SSLClientSocketMac::SSLClientSocketMac(ClientSocketHandle* transport_socket, transport_(transport_socket), host_and_port_(host_and_port), ssl_config_(ssl_config), - user_connect_callback_(NULL), + old_user_connect_callback_(NULL), user_read_callback_(NULL), user_write_callback_(NULL), user_read_buf_len_(0), @@ -558,7 +558,29 @@ SSLClientSocketMac::~SSLClientSocketMac() { int SSLClientSocketMac::Connect(OldCompletionCallback* callback) { DCHECK(transport_.get()); DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!user_connect_callback_); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); + + net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT, NULL); + + int rv = InitializeSSLContext(); + if (rv != OK) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + return rv; + } + + next_handshake_state_ = STATE_HANDSHAKE; + rv = DoHandshakeLoop(OK); + if (rv == ERR_IO_PENDING) { + old_user_connect_callback_ = callback; + } else { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + } + return rv; +} +int SSLClientSocketMac::Connect(const CompletionCallback& callback) { + DCHECK(transport_.get()); + DCHECK(next_handshake_state_ == STATE_NONE); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT, NULL); @@ -896,11 +918,17 @@ int SSLClientSocketMac::InitializeSSLContext() { void SSLClientSocketMac::DoConnectCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_connect_callback_); + DCHECK(old_user_connect_callback_ || !user_connect_callback_.is_null()); - OldCompletionCallback* c = user_connect_callback_; - user_connect_callback_ = NULL; - c->Run(rv > OK ? OK : rv); + if (old_user_connect_callback_) { + OldCompletionCallback* c = old_user_connect_callback_; + old_user_connect_callback_ = NULL; + c->Run(rv > OK ? OK : rv); + } else { + CompletionCallback c = user_connect_callback_; + user_connect_callback_.Reset(); + c.Run(rv > OK ? OK : rv); + } } void SSLClientSocketMac::DoReadCallback(int rv) { @@ -936,7 +964,7 @@ void SSLClientSocketMac::OnHandshakeIOComplete(int result) { // renegotiating (which occurs because we are in the middle of a Read // when the renegotiation process starts). So we complete the Read // here. - if (!user_connect_callback_) { + if (!old_user_connect_callback_ && user_connect_callback_.is_null()) { DoReadCallback(rv); return; } @@ -1274,7 +1302,7 @@ int SSLClientSocketMac::DoCompletedRenegotiation(int result) { } void SSLClientSocketMac::DidCompleteRenegotiation() { - DCHECK(!user_connect_callback_); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); renegotiating_ = false; next_handshake_state_ = STATE_COMPLETED_RENEGOTIATION; } diff --git a/net/socket/ssl_client_socket_mac.h b/net/socket/ssl_client_socket_mac.h index 7fa95c4..b9dccc0 100644 --- a/net/socket/ssl_client_socket_mac.h +++ b/net/socket/ssl_client_socket_mac.h @@ -40,7 +40,7 @@ class SSLClientSocketMac : public SSLClientSocket { const SSLClientSocketContext& context); virtual ~SSLClientSocketMac(); - // SSLClientSocket methods: + // SSLClientSocket implementation. virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; virtual void GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) OVERRIDE; @@ -51,8 +51,9 @@ class SSLClientSocketMac : public SSLClientSocket { virtual NextProtoStatus GetNextProto(std::string* proto, std::string* server_protos) OVERRIDE; - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -66,7 +67,7 @@ class SSLClientSocketMac : public SSLClientSocket { virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; - // Socket methods: + // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; @@ -118,7 +119,8 @@ class SSLClientSocketMac : public SSLClientSocket { HostPortPair host_and_port_; SSLConfig ssl_config_; - OldCompletionCallback* user_connect_callback_; + OldCompletionCallback* old_user_connect_callback_; + CompletionCallback user_connect_callback_; OldCompletionCallback* user_read_callback_; OldCompletionCallback* user_write_callback_; diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index 1e95fc8..e1ac396 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -446,7 +446,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocketHandle* transport_socket, transport_(transport_socket), host_and_port_(host_and_port), ssl_config_(ssl_config), - user_connect_callback_(NULL), + old_user_connect_callback_(NULL), user_read_callback_(NULL), user_write_callback_(NULL), user_read_buf_len_(0), @@ -578,7 +578,55 @@ int SSLClientSocketNSS::Connect(OldCompletionCallback* callback) { DCHECK(next_handshake_state_ == STATE_NONE); DCHECK(!user_read_callback_); DCHECK(!user_write_callback_); - DCHECK(!user_connect_callback_); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); + DCHECK(!user_read_buf_); + DCHECK(!user_write_buf_); + + EnsureThreadIdAssigned(); + + net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT, NULL); + + int rv = Init(); + if (rv != OK) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + return rv; + } + + rv = InitializeSSLOptions(); + if (rv != OK) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + return rv; + } + + rv = InitializeSSLPeerName(); + if (rv != OK) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + return rv; + } + + if (ssl_config_.cached_info_enabled && ssl_host_info_.get()) { + GotoState(STATE_LOAD_SSL_HOST_INFO); + } else { + GotoState(STATE_HANDSHAKE); + } + + rv = DoHandshakeLoop(OK); + if (rv == ERR_IO_PENDING) { + old_user_connect_callback_ = callback; + } else { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + } + + LeaveFunction(""); + return rv > OK ? OK : rv; +} +int SSLClientSocketNSS::Connect(const CompletionCallback& callback) { + EnterFunction(""); + DCHECK(transport_.get()); + DCHECK(next_handshake_state_ == STATE_NONE); + DCHECK(!user_read_callback_); + DCHECK(!user_write_callback_); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); DCHECK(!user_read_buf_); DCHECK(!user_write_buf_); @@ -645,7 +693,8 @@ void SSLClientSocketNSS::Disconnect() { // Reset object state transport_send_busy_ = false; transport_recv_busy_ = false; - user_connect_callback_ = NULL; + old_user_connect_callback_ = NULL; + user_connect_callback_.Reset(); user_read_callback_ = NULL; user_write_callback_ = NULL; user_read_buf_ = NULL; @@ -767,7 +816,7 @@ int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, DCHECK(completed_handshake_); DCHECK(next_handshake_state_ == STATE_NONE); DCHECK(!user_read_callback_); - DCHECK(!user_connect_callback_); + DCHECK(!old_user_connect_callback_); DCHECK(!user_read_buf_); DCHECK(nss_bufs_); @@ -792,7 +841,7 @@ int SSLClientSocketNSS::Write(IOBuffer* buf, int buf_len, DCHECK(completed_handshake_); DCHECK(next_handshake_state_ == STATE_NONE); DCHECK(!user_write_callback_); - DCHECK(!user_connect_callback_); + DCHECK(!old_user_connect_callback_); DCHECK(!user_write_buf_); DCHECK(nss_bufs_); @@ -1188,11 +1237,17 @@ void SSLClientSocketNSS::DoWriteCallback(int rv) { void SSLClientSocketNSS::DoConnectCallback(int rv) { EnterFunction(rv); DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(user_connect_callback_); + DCHECK(old_user_connect_callback_ || !user_connect_callback_.is_null()); - OldCompletionCallback* c = user_connect_callback_; - user_connect_callback_ = NULL; - c->Run(rv > OK ? OK : rv); + if (old_user_connect_callback_) { + OldCompletionCallback* c = old_user_connect_callback_; + old_user_connect_callback_ = NULL; + c->Run(rv > OK ? OK : rv); + } else { + CompletionCallback c = user_connect_callback_; + user_connect_callback_.Reset(); + c.Run(rv > OK ? OK : rv); + } LeaveFunction(""); } diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index 019412a..78e222b 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -59,7 +59,7 @@ class SSLClientSocketNSS : public SSLClientSocket { NET_EXPORT_PRIVATE static void ClearSessionCache(); - // SSLClientSocket methods: + // SSLClientSocket implementation. virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; virtual void GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) OVERRIDE; @@ -70,8 +70,9 @@ class SSLClientSocketNSS : public SSLClientSocket { virtual NextProtoStatus GetNextProto(std::string* proto, std::string* server_protos) OVERRIDE; - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -85,7 +86,7 @@ class SSLClientSocketNSS : public SSLClientSocket { virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; - // Socket methods: + // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; @@ -225,7 +226,8 @@ class SSLClientSocketNSS : public SSLClientSocket { HostPortPair host_and_port_; SSLConfig ssl_config_; - OldCompletionCallback* user_connect_callback_; + OldCompletionCallback* old_user_connect_callback_; + CompletionCallback user_connect_callback_; OldCompletionCallback* user_read_callback_; OldCompletionCallback* user_write_callback_; diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc index e2f3a78..e6e9aac 100644 --- a/net/socket/ssl_client_socket_openssl.cc +++ b/net/socket/ssl_client_socket_openssl.cc @@ -390,7 +390,7 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( this, &SSLClientSocketOpenSSL::BufferRecvComplete)), transport_send_busy_(false), transport_recv_busy_(false), - user_connect_callback_(NULL), + old_user_connect_callback_(NULL), user_read_callback_(NULL), user_write_callback_(NULL), completed_handshake_(false), @@ -649,6 +649,29 @@ int SSLClientSocketOpenSSL::Connect(OldCompletionCallback* callback) { GotoState(STATE_HANDSHAKE); int rv = DoHandshakeLoop(net::OK); if (rv == ERR_IO_PENDING) { + old_user_connect_callback_ = callback; + } else { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + } + + return rv > OK ? OK : rv; +} +int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { + net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT, NULL); + + // Set up new ssl object. + if (!Init()) { + int result = ERR_UNEXPECTED; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, result); + return result; + } + + // Set SSL to client mode. Handshake happens in the loop below. + SSL_set_connect_state(ssl_); + + GotoState(STATE_HANDSHAKE); + int rv = DoHandshakeLoop(net::OK); + if (rv == ERR_IO_PENDING) { user_connect_callback_ = callback; } else { net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); @@ -677,7 +700,8 @@ void SSLClientSocketOpenSSL::Disconnect() { transport_recv_busy_ = false; recv_buffer_ = NULL; - user_connect_callback_ = NULL; + old_user_connect_callback_ = NULL; + user_connect_callback_.Reset(); user_read_callback_ = NULL; user_write_callback_ = NULL; user_read_buf_ = NULL; @@ -1019,9 +1043,15 @@ void SSLClientSocketOpenSSL::TransportReadComplete(int result) { } void SSLClientSocketOpenSSL::DoConnectCallback(int rv) { - OldCompletionCallback* c = user_connect_callback_; - user_connect_callback_ = NULL; - c->Run(rv > OK ? OK : rv); + if (old_user_connect_callback_) { + OldCompletionCallback* c = old_user_connect_callback_; + old_user_connect_callback_ = NULL; + c->Run(rv > OK ? OK : rv); + } else { + CompletionCallback c = user_connect_callback_; + user_connect_callback_.Reset(); + c.Run(rv > OK ? OK : rv); + } } void SSLClientSocketOpenSSL::OnHandshakeIOComplete(int result) { diff --git a/net/socket/ssl_client_socket_openssl.h b/net/socket/ssl_client_socket_openssl.h index 281bb1c..010930a 100644 --- a/net/socket/ssl_client_socket_openssl.h +++ b/net/socket/ssl_client_socket_openssl.h @@ -52,7 +52,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { int SelectNextProtoCallback(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen); - // SSLClientSocket methods: + // SSLClientSocket implementation. virtual void GetSSLInfo(SSLInfo* ssl_info); virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info); virtual int ExportKeyingMaterial(const base::StringPiece& label, @@ -62,8 +62,9 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { virtual NextProtoStatus GetNextProto(std::string* proto, std::string* server_protos); - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(OldCompletionCallback* callback); + virtual int Connect(const CompletionCallback& callback); virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; @@ -77,7 +78,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { virtual int64 NumBytesRead() const; virtual base::TimeDelta GetConnectTimeMicros() const; - // Socket methods: + // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual bool SetReceiveBufferSize(int32 size); @@ -119,7 +120,8 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { bool transport_recv_busy_; scoped_refptr<IOBuffer> recv_buffer_; - OldCompletionCallback* user_connect_callback_; + OldCompletionCallback* old_user_connect_callback_; + CompletionCallback user_connect_callback_; OldCompletionCallback* user_read_callback_; OldCompletionCallback* user_write_callback_; diff --git a/net/socket/ssl_client_socket_win.cc b/net/socket/ssl_client_socket_win.cc index 1e42414..d1ed130 100644 --- a/net/socket/ssl_client_socket_win.cc +++ b/net/socket/ssl_client_socket_win.cc @@ -397,7 +397,7 @@ SSLClientSocketWin::SSLClientSocketWin(ClientSocketHandle* transport_socket, transport_(transport_socket), host_and_port_(host_and_port), ssl_config_(ssl_config), - user_connect_callback_(NULL), + old_user_connect_callback_(NULL), user_read_callback_(NULL), user_read_buf_len_(0), user_write_callback_(NULL), @@ -565,7 +565,30 @@ SSLClientSocketWin::GetNextProto(std::string* proto, int SSLClientSocketWin::Connect(OldCompletionCallback* callback) { DCHECK(transport_.get()); DCHECK(next_state_ == STATE_NONE); - DCHECK(!user_connect_callback_); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); + + net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT, NULL); + + int rv = InitializeSSLContext(); + if (rv != OK) { + net_log_.EndEvent(NetLog::TYPE_SSL_CONNECT, NULL); + return rv; + } + + writing_first_token_ = true; + next_state_ = STATE_HANDSHAKE_WRITE; + rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + old_user_connect_callback_ = callback; + } else { + net_log_.EndEvent(NetLog::TYPE_SSL_CONNECT, NULL); + } + return rv; +} +int SSLClientSocketWin::Connect(const CompletionCallback& callback) { + DCHECK(transport_.get()); + DCHECK(next_state_ == STATE_NONE); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT, NULL); @@ -842,7 +865,7 @@ void SSLClientSocketWin::OnHandshakeIOComplete(int result) { // If there is no connect callback available to call, we are renegotiating // (which occurs because we are in the middle of a Read when the // renegotiation process starts). So we complete the Read here. - if (!user_connect_callback_) { + if (!old_user_connect_callback_ && user_connect_callback_.is_null()) { OldCompletionCallback* c = user_read_callback_; user_read_callback_ = NULL; user_read_buf_ = NULL; @@ -851,9 +874,15 @@ void SSLClientSocketWin::OnHandshakeIOComplete(int result) { return; } net_log_.EndEvent(NetLog::TYPE_SSL_CONNECT, NULL); - OldCompletionCallback* c = user_connect_callback_; - user_connect_callback_ = NULL; - c->Run(rv); + if (old_user_connect_callback_) { + OldCompletionCallback* c = old_user_connect_callback_; + old_user_connect_callback_ = NULL; + c->Run(rv); + } else { + CompletionCallback c = user_connect_callback_; + user_connect_callback_.Reset(); + c.Run(rv); + } } } @@ -1549,7 +1578,7 @@ int SSLClientSocketWin::DidCompleteHandshake() { // Called when a renegotiation is completed. |result| is the verification // result of the server certificate received during renegotiation. void SSLClientSocketWin::DidCompleteRenegotiation() { - DCHECK(!user_connect_callback_); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); DCHECK(user_read_callback_); renegotiating_ = false; next_state_ = STATE_COMPLETED_RENEGOTIATION; diff --git a/net/socket/ssl_client_socket_win.h b/net/socket/ssl_client_socket_win.h index adff167..01a5509b 100644 --- a/net/socket/ssl_client_socket_win.h +++ b/net/socket/ssl_client_socket_win.h @@ -45,7 +45,7 @@ class SSLClientSocketWin : public SSLClientSocket { const SSLClientSocketContext& context); ~SSLClientSocketWin(); - // SSLClientSocket methods: + // SSLClientSocket implementation. virtual void GetSSLInfo(SSLInfo* ssl_info); virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info); virtual int ExportKeyingMaterial(const base::StringPiece& label, @@ -55,8 +55,9 @@ class SSLClientSocketWin : public SSLClientSocket { virtual NextProtoStatus GetNextProto(std::string* proto, std::string* server_protos); - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(OldCompletionCallback* callback); + virtual int Connect(const CompletionCallback& callback); virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; @@ -70,7 +71,7 @@ class SSLClientSocketWin : public SSLClientSocket { virtual int64 NumBytesRead() const; virtual base::TimeDelta GetConnectTimeMicros() const; - // Socket methods: + // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); @@ -121,7 +122,8 @@ class SSLClientSocketWin : public SSLClientSocket { SSLConfig ssl_config_; // User function to callback when the Connect() completes. - OldCompletionCallback* user_connect_callback_; + OldCompletionCallback* old_user_connect_callback_; + CompletionCallback user_connect_callback_; // User function to callback when a Read() completes. OldCompletionCallback* user_read_callback_; diff --git a/net/socket/ssl_server_socket_nss.cc b/net/socket/ssl_server_socket_nss.cc index 8f1b43b..8ead679 100644 --- a/net/socket/ssl_server_socket_nss.cc +++ b/net/socket/ssl_server_socket_nss.cc @@ -147,6 +147,10 @@ int SSLServerSocketNSS::Connect(OldCompletionCallback* callback) { NOTIMPLEMENTED(); return ERR_NOT_IMPLEMENTED; } +int SSLServerSocketNSS::Connect(const CompletionCallback& callback) { + NOTIMPLEMENTED(); + return ERR_NOT_IMPLEMENTED; +} int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { diff --git a/net/socket/ssl_server_socket_nss.h b/net/socket/ssl_server_socket_nss.h index 0004c4a..7967ffa 100644 --- a/net/socket/ssl_server_socket_nss.h +++ b/net/socket/ssl_server_socket_nss.h @@ -46,8 +46,9 @@ class SSLServerSocketNSS : public SSLServerSocket { virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; - // StreamSocket interface. + // StreamSocket implementation. virtual int Connect(OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc index 2340967..5af50f8 100644 --- a/net/socket/ssl_server_socket_unittest.cc +++ b/net/socket/ssl_server_socket_unittest.cc @@ -147,6 +147,9 @@ class FakeSocket : public StreamSocket { virtual int Connect(OldCompletionCallback* callback) { return net::OK; } + virtual int Connect(const CompletionCallback& callback) { + return net::OK; + } virtual void Disconnect() {} diff --git a/net/socket/stream_socket.h b/net/socket/stream_socket.h index 0a68442..3ba5b42 100644 --- a/net/socket/stream_socket.h +++ b/net/socket/stream_socket.h @@ -34,6 +34,7 @@ class NET_EXPORT_PRIVATE StreamSocket : public Socket { // Connect may also be called again after a call to the Disconnect method. // virtual int Connect(OldCompletionCallback* callback) = 0; + virtual int Connect(const CompletionCallback& callback) = 0; // Called to disconnect a socket. Does nothing if the socket is already // disconnected. After calling Disconnect it is possible to call Connect diff --git a/net/socket/tcp_client_socket_libevent.cc b/net/socket/tcp_client_socket_libevent.cc index 2a490e1..7c8af80 100644 --- a/net/socket/tcp_client_socket_libevent.cc +++ b/net/socket/tcp_client_socket_libevent.cc @@ -131,7 +131,7 @@ TCPClientSocketLibevent::TCPClientSocketLibevent( read_watcher_(this), write_watcher_(this), read_callback_(NULL), - write_callback_(NULL), + old_write_callback_(NULL), next_connect_state_(CONNECT_STATE_NONE), connect_os_error_(0), net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), @@ -227,6 +227,38 @@ int TCPClientSocketLibevent::Connect(OldCompletionCallback* callback) { if (rv == ERR_IO_PENDING) { // Synchronous operation not supported. DCHECK(callback); + old_write_callback_ = callback; + } else { + LogConnectCompletion(rv); + } + + return rv; +} +int TCPClientSocketLibevent::Connect(const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + + // If already connected, then just return OK. + if (socket_ != kInvalidSocket) + return OK; + + base::StatsCounter connects("tcp.connect"); + connects.Increment(); + + DCHECK(!waiting_connect()); + + net_log_.BeginEvent( + NetLog::TYPE_TCP_CONNECT, + make_scoped_refptr(new AddressListNetLogParam(addresses_))); + + // We will try to connect to each address in addresses_. Start with the + // first one in the list. + next_connect_state_ = CONNECT_STATE_CONNECT; + current_ai_ = addresses_.head(); + + int rv = DoConnectLoop(OK); + if (rv == ERR_IO_PENDING) { + // Synchronous operation not supported. + DCHECK(!callback.is_null()); write_callback_ = callback; } else { LogConnectCompletion(rv); @@ -473,7 +505,7 @@ int TCPClientSocketLibevent::Write(IOBuffer* buf, DCHECK(CalledOnValidThread()); DCHECK_NE(kInvalidSocket, socket_); DCHECK(!waiting_connect()); - DCHECK(!write_callback_); + DCHECK(!old_write_callback_ && write_callback_.is_null()); // Synchronous operation not supported DCHECK(callback); DCHECK_GT(buf_len, 0); @@ -500,7 +532,7 @@ int TCPClientSocketLibevent::Write(IOBuffer* buf, write_buf_ = buf; write_buf_len_ = buf_len; - write_callback_ = callback; + old_write_callback_ = callback; return ERR_IO_PENDING; } @@ -596,12 +628,18 @@ void TCPClientSocketLibevent::DoReadCallback(int rv) { void TCPClientSocketLibevent::DoWriteCallback(int rv) { DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(write_callback_); + DCHECK(old_write_callback_ || !write_callback_.is_null()); // since Run may result in Write being called, clear write_callback_ up front. - OldCompletionCallback* c = write_callback_; - write_callback_ = NULL; - c->Run(rv); + if (old_write_callback_) { + OldCompletionCallback* c = old_write_callback_; + old_write_callback_ = NULL; + c->Run(rv); + } else { + CompletionCallback c = write_callback_; + write_callback_.Reset(); + c.Run(rv); + } } void TCPClientSocketLibevent::DidCompleteConnect() { diff --git a/net/socket/tcp_client_socket_libevent.h b/net/socket/tcp_client_socket_libevent.h index 2a6b041..ac73f2c 100644 --- a/net/socket/tcp_client_socket_libevent.h +++ b/net/socket/tcp_client_socket_libevent.h @@ -42,8 +42,9 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, // Binds the socket to a local IP address and port. int Bind(const IPEndPoint& address); - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -57,7 +58,7 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; - // Socket methods: + // Socket implementation. // Multiple outstanding requests are not supported. // Full duplex mode (reading and writing at the same time) is supported virtual int Read(IOBuffer* buf, @@ -100,14 +101,13 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, public: explicit WriteWatcher(TCPClientSocketLibevent* socket) : socket_(socket) {} - // MessageLoopForIO::Watcher methods - + // MessageLoopForIO::Watcher implementation. virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE {} - virtual void OnFileCanWriteWithoutBlocking(int /* fd */) OVERRIDE { if (socket_->waiting_connect()) { socket_->DidCompleteConnect(); - } else if (socket_->write_callback_) { + } else if (socket_->old_write_callback_ || + !socket_->write_callback_.is_null()) { socket_->DidCompleteWrite(); } } @@ -179,7 +179,8 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, OldCompletionCallback* read_callback_; // External callback; called when write is complete. - OldCompletionCallback* write_callback_; + OldCompletionCallback* old_write_callback_; + CompletionCallback write_callback_; // The next state for the Connect() state machine. ConnectState next_connect_state_; diff --git a/net/socket/tcp_client_socket_win.cc b/net/socket/tcp_client_socket_win.cc index ca5ad63..1267d06 100644 --- a/net/socket/tcp_client_socket_win.cc +++ b/net/socket/tcp_client_socket_win.cc @@ -317,7 +317,7 @@ TCPClientSocketWin::TCPClientSocketWin(const AddressList& addresses, current_ai_(NULL), waiting_read_(false), waiting_write_(false), - read_callback_(NULL), + old_read_callback_(NULL), write_callback_(NULL), next_connect_state_(CONNECT_STATE_NONE), connect_os_error_(0), @@ -404,6 +404,35 @@ int TCPClientSocketWin::Connect(OldCompletionCallback* callback) { if (rv == ERR_IO_PENDING) { // Synchronous operation not supported. DCHECK(callback); + old_read_callback_ = callback; + } else { + LogConnectCompletion(rv); + } + + return rv; +} +int TCPClientSocketWin::Connect(const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + + // If already connected, then just return OK. + if (socket_ != INVALID_SOCKET) + return OK; + + base::StatsCounter connects("tcp.connect"); + connects.Increment(); + + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, + new AddressListNetLogParam(addresses_)); + + // We will try to connect to each address in addresses_. Start with the + // first one in the list. + next_connect_state_ = CONNECT_STATE_CONNECT; + current_ai_ = addresses_.head(); + + int rv = DoConnectLoop(OK); + if (rv == ERR_IO_PENDING) { + // Synchronous operation not supported. + DCHECK(!callback.is_null()); read_callback_ = callback; } else { LogConnectCompletion(rv); @@ -680,7 +709,7 @@ int TCPClientSocketWin::Read(IOBuffer* buf, DCHECK(CalledOnValidThread()); DCHECK_NE(socket_, INVALID_SOCKET); DCHECK(!waiting_read_); - DCHECK(!read_callback_); + DCHECK(!old_read_callback_ && read_callback_.is_null()); DCHECK(!core_->read_iobuffer_); buf_len = core_->ThrottleReadSize(buf_len); @@ -711,7 +740,7 @@ int TCPClientSocketWin::Read(IOBuffer* buf, } core_->WatchForRead(); waiting_read_ = true; - read_callback_ = callback; + old_read_callback_ = callback; core_->read_iobuffer_ = buf; return ERR_IO_PENDING; } @@ -811,12 +840,18 @@ void TCPClientSocketWin::LogConnectCompletion(int net_error) { void TCPClientSocketWin::DoReadCallback(int rv) { DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(read_callback_); + DCHECK(old_read_callback_ || !read_callback_.is_null()); // since Run may result in Read being called, clear read_callback_ up front. - OldCompletionCallback* c = read_callback_; - read_callback_ = NULL; - c->Run(rv); + if (old_read_callback_) { + OldCompletionCallback* c = old_read_callback_; + old_read_callback_ = NULL; + c->Run(rv); + } else { + CompletionCallback c = read_callback_; + read_callback_.Reset(); + c.Run(rv); + } } void TCPClientSocketWin::DoWriteCallback(int rv) { diff --git a/net/socket/tcp_client_socket_win.h b/net/socket/tcp_client_socket_win.h index 59ab3b4..bda2585 100644 --- a/net/socket/tcp_client_socket_win.h +++ b/net/socket/tcp_client_socket_win.h @@ -41,8 +41,9 @@ class NET_EXPORT TCPClientSocketWin : public StreamSocket, // Binds the socket to a local IP address and port. int Bind(const IPEndPoint& address); - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(OldCompletionCallback* callback); + virtual int Connect(const CompletionCallback& callback); virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; @@ -56,7 +57,7 @@ class NET_EXPORT TCPClientSocketWin : public StreamSocket, virtual int64 NumBytesRead() const; virtual base::TimeDelta GetConnectTimeMicros() const; - // Socket methods: + // Socket implementation. // Multiple outstanding requests are not supported. // Full duplex mode (reading and writing at the same time) is supported virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); @@ -123,7 +124,8 @@ class NET_EXPORT TCPClientSocketWin : public StreamSocket, scoped_refptr<Core> core_; // External callback; called when connect or read is complete. - OldCompletionCallback* read_callback_; + OldCompletionCallback* old_read_callback_; + CompletionCallback read_callback_; // External callback; called when write is complete. OldCompletionCallback* write_callback_; diff --git a/net/socket/transport_client_socket_pool_unittest.cc b/net/socket/transport_client_socket_pool_unittest.cc index 5a61d00..6604727 100644 --- a/net/socket/transport_client_socket_pool_unittest.cc +++ b/net/socket/transport_client_socket_pool_unittest.cc @@ -51,11 +51,15 @@ class MockClientSocket : public StreamSocket { : connected_(false), addrlist_(addrlist) {} - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(OldCompletionCallback* callback) { connected_ = true; return OK; } + virtual int Connect(const CompletionCallback& callback) { + connected_ = true; + return OK; + } virtual void Disconnect() { connected_ = false; } @@ -112,10 +116,13 @@ class MockFailingClientSocket : public StreamSocket { public: MockFailingClientSocket(const AddressList& addrlist) : addrlist_(addrlist) {} - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(OldCompletionCallback* callback) { return ERR_CONNECTION_FAILED; } + virtual int Connect(const net::CompletionCallback& callback) { + return ERR_CONNECTION_FAILED; + } virtual void Disconnect() {} @@ -173,19 +180,28 @@ class MockPendingClientSocket : public StreamSocket { bool should_connect, bool should_stall, int delay_ms) - : method_factory_(ALLOW_THIS_IN_INITIALIZER_LIST(this)), + : ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)), should_connect_(should_connect), should_stall_(should_stall), delay_ms_(delay_ms), is_connected_(false), addrlist_(addrlist) {} - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(OldCompletionCallback* callback) { MessageLoop::current()->PostDelayedTask( FROM_HERE, - method_factory_.NewRunnableMethod( - &MockPendingClientSocket::DoCallback, callback), delay_ms_); + base::Bind(&MockPendingClientSocket::DoOldCallback, + weak_factory_.GetWeakPtr(), callback), + delay_ms_); + return ERR_IO_PENDING; + } + virtual int Connect(const CompletionCallback& callback) { + MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(&MockPendingClientSocket::DoCallback, + weak_factory_.GetWeakPtr(), callback), + delay_ms_); return ERR_IO_PENDING; } @@ -236,7 +252,7 @@ class MockPendingClientSocket : public StreamSocket { virtual bool SetSendBufferSize(int32 size) { return true; } private: - void DoCallback(OldCompletionCallback* callback) { + void DoOldCallback(OldCompletionCallback* callback) { if (should_stall_) return; @@ -248,8 +264,20 @@ class MockPendingClientSocket : public StreamSocket { callback->Run(ERR_CONNECTION_FAILED); } } + void DoCallback(const CompletionCallback& callback) { + if (should_stall_) + return; + + if (should_connect_) { + is_connected_ = true; + callback.Run(OK); + } else { + is_connected_ = false; + callback.Run(ERR_CONNECTION_FAILED); + } + } - ScopedRunnableMethodFactory<MockPendingClientSocket> method_factory_; + base::WeakPtrFactory<MockPendingClientSocket> weak_factory_; bool should_connect_; bool should_stall_; int delay_ms_; diff --git a/net/spdy/spdy_proxy_client_socket.cc b/net/spdy/spdy_proxy_client_socket.cc index 7dbf83e..c9f0d8f 100644 --- a/net/spdy/spdy_proxy_client_socket.cc +++ b/net/spdy/spdy_proxy_client_socket.cc @@ -33,7 +33,7 @@ SpdyProxyClientSocket::SpdyProxyClientSocket( io_callback_(this, &SpdyProxyClientSocket::OnIOComplete)), next_state_(STATE_DISCONNECTED), spdy_stream_(spdy_stream), - read_callback_(NULL), + old_read_callback_(NULL), write_callback_(NULL), endpoint_(endpoint), auth_( @@ -92,7 +92,20 @@ HttpStream* SpdyProxyClientSocket::CreateConnectResponseStream() { // TODO(rch): create a more appropriate error code to disambiguate // the HTTPS Proxy tunnel failure from an HTTP Proxy tunnel failure. int SpdyProxyClientSocket::Connect(OldCompletionCallback* callback) { - DCHECK(!read_callback_); + DCHECK(!old_read_callback_ && read_callback_.is_null()); + if (next_state_ == STATE_OPEN) + return OK; + + DCHECK_EQ(STATE_DISCONNECTED, next_state_); + next_state_ = STATE_GENERATE_AUTH_TOKEN; + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) + old_read_callback_ = callback; + return rv; +} +int SpdyProxyClientSocket::Connect(const CompletionCallback& callback) { + DCHECK(!old_read_callback_ && read_callback_.is_null()); if (next_state_ == STATE_OPEN) return OK; @@ -108,7 +121,8 @@ int SpdyProxyClientSocket::Connect(OldCompletionCallback* callback) { void SpdyProxyClientSocket::Disconnect() { read_buffer_.clear(); user_buffer_ = NULL; - read_callback_ = NULL; + old_read_callback_ = NULL; + read_callback_.Reset(); write_buffer_len_ = 0; write_bytes_outstanding_ = 0; @@ -160,7 +174,7 @@ base::TimeDelta SpdyProxyClientSocket::GetConnectTimeMicros() const { int SpdyProxyClientSocket::Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { - DCHECK(!read_callback_); + DCHECK(!old_read_callback_ && read_callback_.is_null()); DCHECK(!user_buffer_); if (next_state_ == STATE_DISCONNECTED) @@ -176,7 +190,7 @@ int SpdyProxyClientSocket::Read(IOBuffer* buf, int buf_len, int result = PopulateUserReadBuffer(); if (result == 0) { DCHECK(callback); - read_callback_ = callback; + old_read_callback_ = callback; return ERR_IO_PENDING; } user_buffer_ = NULL; @@ -271,9 +285,15 @@ void SpdyProxyClientSocket::OnIOComplete(int result) { DCHECK_NE(STATE_DISCONNECTED, next_state_); int rv = DoLoop(result); if (rv != ERR_IO_PENDING) { - OldCompletionCallback* c = read_callback_; - read_callback_ = NULL; - c->Run(rv); + if (old_read_callback_) { + OldCompletionCallback* c = old_read_callback_; + old_read_callback_ = NULL; + c->Run(rv); + } else { + CompletionCallback c = read_callback_; + read_callback_.Reset(); + c.Run(rv); + } } } @@ -472,12 +492,18 @@ void SpdyProxyClientSocket::OnDataReceived(const char* data, int length) { make_scoped_refptr(new DrainableIOBuffer(io_buffer, length))); } - if (read_callback_) { + if (old_read_callback_) { int rv = PopulateUserReadBuffer(); - OldCompletionCallback* c = read_callback_; - read_callback_ = NULL; + OldCompletionCallback* c = old_read_callback_; + old_read_callback_ = NULL; user_buffer_ = NULL; c->Run(rv); + } else if (!read_callback_.is_null()) { + int rv = PopulateUserReadBuffer(); + CompletionCallback c = read_callback_; + read_callback_.Reset(); + user_buffer_ = NULL; + c.Run(rv); } } @@ -519,11 +545,17 @@ void SpdyProxyClientSocket::OnClose(int status) { // If we're in the middle of connecting, we need to make sure // we invoke the connect callback. if (connecting) { - DCHECK(read_callback_); - OldCompletionCallback* read_callback = read_callback_; - read_callback_ = NULL; - read_callback->Run(status); - } else if (read_callback_) { + DCHECK(old_read_callback_ || !read_callback_.is_null()); + if (old_read_callback_) { + OldCompletionCallback* read_callback = old_read_callback_; + old_read_callback_ = NULL; + read_callback->Run(status); + } else { + CompletionCallback read_callback = read_callback_; + read_callback_.Reset(); + read_callback.Run(status); + } + } else if (old_read_callback_ || !read_callback_.is_null()) { // If we have a read_callback_, the we need to make sure we call it back. OnDataReceived(NULL, 0); } diff --git a/net/spdy/spdy_proxy_client_socket.h b/net/spdy/spdy_proxy_client_socket.h index b7e48f8..8527896 100644 --- a/net/spdy/spdy_proxy_client_socket.h +++ b/net/spdy/spdy_proxy_client_socket.h @@ -60,8 +60,9 @@ class NET_EXPORT_PRIVATE SpdyProxyClientSocket : public ProxyClientSocket, virtual int RestartWithAuth(OldCompletionCallback* callback) OVERRIDE; virtual const scoped_refptr<HttpAuthController>& auth_controller() OVERRIDE; - // StreamSocket methods: + // StreamSocket implementation. virtual int Connect(OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -130,7 +131,8 @@ class NET_EXPORT_PRIVATE SpdyProxyClientSocket : public ProxyClientSocket, // Stores the callback to the layer above, called on completing Read() or // Connect(). - OldCompletionCallback* read_callback_; + OldCompletionCallback* old_read_callback_; + CompletionCallback read_callback_; // Stores the callback to the layer above, called on completing Write(). OldCompletionCallback* write_callback_; diff --git a/remoting/jingle_glue/ssl_socket_adapter.cc b/remoting/jingle_glue/ssl_socket_adapter.cc index 625b8c1..1948c68 100644 --- a/remoting/jingle_glue/ssl_socket_adapter.cc +++ b/remoting/jingle_glue/ssl_socket_adapter.cc @@ -206,6 +206,12 @@ int TransportSocket::Connect(net::OldCompletionCallback* callback) { NOTREACHED(); return false; } +int TransportSocket::Connect(const net::CompletionCallback& callback) { + // Connect is never called by SSLClientSocket, instead SSLSocketAdapter + // calls Connect() on socket_ directly. + NOTREACHED(); + return false; +} void TransportSocket::Disconnect() { socket_->Close(); diff --git a/remoting/jingle_glue/ssl_socket_adapter.h b/remoting/jingle_glue/ssl_socket_adapter.h index 2e6ebeb..ffa593d 100644 --- a/remoting/jingle_glue/ssl_socket_adapter.h +++ b/remoting/jingle_glue/ssl_socket_adapter.h @@ -39,9 +39,9 @@ class TransportSocket : public net::StreamSocket, public sigslot::has_slots<> { addr_ = addr; } - // net::StreamSocket implementation - + // net::StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; diff --git a/remoting/protocol/fake_session.cc b/remoting/protocol/fake_session.cc index c09a62a5..3b5499a 100644 --- a/remoting/protocol/fake_session.cc +++ b/remoting/protocol/fake_session.cc @@ -80,6 +80,10 @@ int FakeSocket::Connect(net::OldCompletionCallback* callback) { EXPECT_EQ(message_loop_, MessageLoop::current()); return net::OK; } +int FakeSocket::Connect(const net::CompletionCallback& callback) { + EXPECT_EQ(message_loop_, MessageLoop::current()); + return net::OK; +} void FakeSocket::Disconnect() { NOTIMPLEMENTED(); diff --git a/remoting/protocol/fake_session.h b/remoting/protocol/fake_session.h index 7eaab6b..2faa554 100644 --- a/remoting/protocol/fake_session.h +++ b/remoting/protocol/fake_session.h @@ -37,7 +37,7 @@ class FakeSocket : public net::StreamSocket { int input_pos() const { return input_pos_; } bool read_pending() const { return read_pending_; } - // net::Socket interface. + // net::Socket implementation. virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, @@ -46,8 +46,9 @@ class FakeSocket : public net::StreamSocket { virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; - // net::StreamSocket interface. + // net::StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -93,7 +94,7 @@ class FakeUdpSocket : public net::Socket { void AppendInputPacket(const char* data, int data_size); int input_pos() const { return input_pos_; } - // net::Socket interface. + // net::Socket implementation. virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, @@ -137,7 +138,7 @@ class FakeSession : public Session { FakeSocket* GetStreamChannel(const std::string& name); FakeUdpSocket* GetDatagramChannel(const std::string& name); - // Session interface. + // Session implementation. virtual void SetStateChangeCallback( const StateChangeCallback& callback) OVERRIDE; diff --git a/remoting/protocol/pepper_transport_socket_adapter.cc b/remoting/protocol/pepper_transport_socket_adapter.cc index 6acb8e3..a48afce 100644 --- a/remoting/protocol/pepper_transport_socket_adapter.cc +++ b/remoting/protocol/pepper_transport_socket_adapter.cc @@ -123,7 +123,31 @@ bool PepperTransportSocketAdapter::SetSendBufferSize(int32 size) { return false; } -int PepperTransportSocketAdapter::Connect(net::OldCompletionCallback* callback) { +int PepperTransportSocketAdapter::Connect( + net::OldCompletionCallback* callback) { + DCHECK(CalledOnValidThread()); + + if (!transport_.get()) + return net::ERR_UNEXPECTED; + + old_connect_callback_ = callback; + + // This will return false when GetNextAddress() returns an + // error. This helps to detect when the P2P Transport API is not + // supported. + int result = ProcessCandidates(); + if (result != net::OK) + return result; + + result = transport_->Connect( + callback_factory_.NewRequiredCallback( + &PepperTransportSocketAdapter::OnConnect)); + DCHECK_EQ(result, PP_OK_COMPLETIONPENDING); + + return net::ERR_IO_PENDING; +} +int PepperTransportSocketAdapter::Connect( + const net::CompletionCallback& callback) { DCHECK(CalledOnValidThread()); if (!transport_.get()) @@ -254,14 +278,20 @@ void PepperTransportSocketAdapter::OnNextAddress(int32_t result) { void PepperTransportSocketAdapter::OnConnect(int result) { DCHECK(CalledOnValidThread()); - DCHECK(connect_callback_); + DCHECK(old_connect_callback_ || !connect_callback_.is_null()); if (result == PP_OK) connected_ = true; - net::OldCompletionCallback* callback = connect_callback_; - connect_callback_ = NULL; - callback->Run(PPErrorToNetError(result)); + if (old_connect_callback_) { + net::OldCompletionCallback* callback = old_connect_callback_; + old_connect_callback_ = NULL; + callback->Run(PPErrorToNetError(result)); + } else { + net::CompletionCallback callback = connect_callback_; + connect_callback_.Reset(); + callback.Run(PPErrorToNetError(result)); + } } void PepperTransportSocketAdapter::OnRead(int32_t result) { diff --git a/remoting/protocol/pepper_transport_socket_adapter.h b/remoting/protocol/pepper_transport_socket_adapter.h index 881d612..c25506d 100644 --- a/remoting/protocol/pepper_transport_socket_adapter.h +++ b/remoting/protocol/pepper_transport_socket_adapter.h @@ -47,7 +47,7 @@ class PepperTransportSocketAdapter : public base::NonThreadSafe, // Adds candidate received from the peer. void AddRemoteCandidate(const std::string& candidate); - // net::Socket interface. + // net::Socket implementation. virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, @@ -55,8 +55,9 @@ class PepperTransportSocketAdapter : public base::NonThreadSafe, virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; - // net::StreamSocket interface. + // net::StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -84,7 +85,8 @@ class PepperTransportSocketAdapter : public base::NonThreadSafe, scoped_ptr<pp::Transport_Dev> transport_; - net::OldCompletionCallback* connect_callback_; + net::OldCompletionCallback* old_connect_callback_; + net::CompletionCallback connect_callback_; bool connected_; bool get_address_pending_; |