diff options
Diffstat (limited to 'net/socket')
26 files changed, 560 insertions, 92 deletions
diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index 5ef38a7..6405362 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -78,6 +78,11 @@ class MockClientSocket : public StreamSocket { was_used_to_convey_data_ = true; return len; } + virtual int Write( + IOBuffer* /* buf */, int len, const CompletionCallback& /* callback */) { + was_used_to_convey_data_ = true; + return len; + } virtual bool SetReceiveBufferSize(int32 size) { return true; } virtual bool SetSendBufferSize(int32 size) { return true; } diff --git a/net/socket/socket.h b/net/socket/socket.h index c185c44..b02de0a 100644 --- a/net/socket/socket.h +++ b/net/socket/socket.h @@ -48,6 +48,8 @@ class NET_EXPORT Socket { // Disconnected before the write completes, the callback will not be invoked. virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) = 0; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) = 0; // Set the receive buffer size (in bytes) for the socket. // Note: changing this value can affect the TCP window size on some platforms. diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 376d542..7730a5b 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -832,6 +832,26 @@ int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len, return write_result.result; } +int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!connected_) + return net::ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + net::MockWriteResult write_result = data_->OnWrite(data); + + was_used_to_convey_data_ = true; + + if (write_result.async) { + RunCallbackAsync(callback, write_result.result); + return net::ERR_IO_PENDING; + } + + return write_result.result; +} int MockTCPClientSocket::Connect(net::OldCompletionCallback* callback) { if (connected_) @@ -972,7 +992,7 @@ DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket( net::NetLog* net_log, net::DeterministicSocketData* data) : MockClientSocket(net_log), write_pending_(false), - write_callback_(NULL), + old_write_callback_(NULL), write_result_(0), read_data_(), read_buf_(NULL), @@ -987,7 +1007,10 @@ DeterministicMockTCPClientSocket::~DeterministicMockTCPClientSocket() {} void DeterministicMockTCPClientSocket::CompleteWrite() { was_used_to_convey_data_ = true; write_pending_ = false; - write_callback_->Run(write_result_); + if (old_write_callback_) + old_write_callback_->Run(write_result_); + else + write_callback_.Run(write_result_); } int DeterministicMockTCPClientSocket::CompleteRead() { @@ -1021,30 +1044,6 @@ int DeterministicMockTCPClientSocket::CompleteRead() { return result; } -int DeterministicMockTCPClientSocket::Write( - net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { - DCHECK(buf); - DCHECK_GT(buf_len, 0); - - if (!connected_) - return net::ERR_UNEXPECTED; - - std::string data(buf->data(), buf_len); - net::MockWriteResult write_result = data_->OnWrite(data); - - if (write_result.async) { - write_callback_ = callback; - write_result_ = write_result.result; - DCHECK(write_callback_ != NULL); - write_pending_ = true; - return net::ERR_IO_PENDING; - } - - was_used_to_convey_data_ = true; - write_pending_ = false; - return write_result.result; -} - int DeterministicMockTCPClientSocket::Read( net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { if (!connected_) @@ -1092,6 +1091,53 @@ int DeterministicMockTCPClientSocket::Read( return CompleteRead(); } +int DeterministicMockTCPClientSocket::Write( + net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!connected_) + return net::ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + net::MockWriteResult write_result = data_->OnWrite(data); + + if (write_result.async) { + old_write_callback_ = callback; + write_result_ = write_result.result; + DCHECK(old_write_callback_ != NULL); + write_pending_ = true; + return net::ERR_IO_PENDING; + } + + was_used_to_convey_data_ = true; + write_pending_ = false; + return write_result.result; +} +int DeterministicMockTCPClientSocket::Write( + net::IOBuffer* buf, int buf_len, const net::CompletionCallback& callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!connected_) + return net::ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + net::MockWriteResult write_result = data_->OnWrite(data); + + if (write_result.async) { + write_callback_ = callback; + write_result_ = write_result.result; + DCHECK(!write_callback_.is_null()); + write_pending_ = true; + return net::ERR_IO_PENDING; + } + + was_used_to_convey_data_ = true; + write_pending_ = false; + return write_result.result; +} + // TODO(erikchen): Support connect sequencing. int DeterministicMockTCPClientSocket::Connect( net::OldCompletionCallback* callback) { @@ -1235,6 +1281,10 @@ int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { return transport_->socket()->Write(buf, buf_len, callback); } +int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + return transport_->socket()->Write(buf, buf_len, callback); +} int MockSSLClientSocket::Connect(net::OldCompletionCallback* callback) { OldConnectCallback* connect_callback = new OldConnectCallback( @@ -1422,6 +1472,23 @@ int MockUDPClientSocket::Write(net::IOBuffer* buf, int buf_len, } return write_result.result; } +int MockUDPClientSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!connected_) + return ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + MockWriteResult write_result = data_->OnWrite(data); + + if (write_result.async) { + RunCallbackAsync(callback, write_result.result); + return ERR_IO_PENDING; + } + return write_result.result; +} bool MockUDPClientSocket::SetReceiveBufferSize(int32 size) { return true; diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 73ffc3d..1840bfc 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -592,6 +592,8 @@ class MockClientSocket : public net::SSLClientSocket { const net::CompletionCallback& callback) = 0; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) = 0; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) = 0; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -648,6 +650,8 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; // StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; @@ -703,12 +707,14 @@ class DeterministicMockTCPClientSocket : public MockClientSocket, int CompleteRead(); // Socket implementation. - virtual int Write(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) OVERRIDE; virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; virtual int Read(net::IOBuffer* buf, int buf_len, const net::CompletionCallback& callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; // StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; @@ -726,7 +732,8 @@ class DeterministicMockTCPClientSocket : public MockClientSocket, private: bool write_pending_; - net::OldCompletionCallback* write_callback_; + net::OldCompletionCallback* old_write_callback_; + net::CompletionCallback write_callback_; int write_result_; net::MockRead read_data_; @@ -757,6 +764,8 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; // StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; @@ -804,6 +813,8 @@ class MockUDPClientSocket : public DatagramClientSocket, const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; diff --git a/net/socket/socks5_client_socket.cc b/net/socket/socks5_client_socket.cc index ea5fc7a..3497e24 100644 --- a/net/socket/socks5_client_socket.cc +++ b/net/socket/socks5_client_socket.cc @@ -204,10 +204,18 @@ int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len, // Write is called by the transport layer. This can only be done if the // SOCKS handshake is complete. int SOCKS5ClientSocket::Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + OldCompletionCallback* callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); + + return transport_->socket()->Write(buf, buf_len, callback); +} +int SOCKS5ClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); return transport_->socket()->Write(buf, buf_len, callback); } diff --git a/net/socket/socks5_client_socket.h b/net/socket/socks5_client_socket.h index b83a347..3c0255d 100644 --- a/net/socket/socks5_client_socket.h +++ b/net/socket/socks5_client_socket.h @@ -74,6 +74,9 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc index 623f202..40f8453 100644 --- a/net/socket/socks_client_socket.cc +++ b/net/socket/socks_client_socket.cc @@ -235,7 +235,15 @@ int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); + + return transport_->socket()->Write(buf, buf_len, callback); +} +int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); return transport_->socket()->Write(buf, buf_len, callback); } diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h index 1a4a75c..0f28e10 100644 --- a/net/socket/socks_client_socket.h +++ b/net/socket/socks_client_socket.h @@ -71,6 +71,9 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; diff --git a/net/socket/ssl_client_socket_mac.cc b/net/socket/ssl_client_socket_mac.cc index f58d340..32f0a08 100644 --- a/net/socket/ssl_client_socket_mac.cc +++ b/net/socket/ssl_client_socket_mac.cc @@ -533,7 +533,7 @@ SSLClientSocketMac::SSLClientSocketMac(ClientSocketHandle* transport_socket, ssl_config_(ssl_config), old_user_connect_callback_(NULL), old_user_read_callback_(NULL), - user_write_callback_(NULL), + old_user_write_callback_(NULL), user_read_buf_len_(0), user_write_buf_len_(0), next_handshake_state_(STATE_NONE), @@ -737,7 +737,25 @@ int SSLClientSocketMac::Read(IOBuffer* buf, int buf_len, int SSLClientSocketMac::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake()); - DCHECK(!user_write_callback_); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); + DCHECK(!user_write_buf_); + + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoPayloadWrite(); + if (rv == ERR_IO_PENDING) { + old_user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + return rv; +} +int SSLClientSocketMac::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake()); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); DCHECK(!user_write_buf_); user_write_buf_ = buf; @@ -972,15 +990,23 @@ void SSLClientSocketMac::DoReadCallback(int rv) { void SSLClientSocketMac::DoWriteCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_write_callback_); + DCHECK(old_user_write_callback_ && !user_write_callback_.is_null()); // Since Run may result in Write being called, clear user_write_callback_ up // front. - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - c->Run(rv); + if (old_user_write_callback_) { + OldCompletionCallback* c = old_user_write_callback_; + old_user_write_callback_ = NULL; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c.Run(rv); + } } void SSLClientSocketMac::OnHandshakeIOComplete(int result) { diff --git a/net/socket/ssl_client_socket_mac.h b/net/socket/ssl_client_socket_mac.h index 7792cb3..661432b 100644 --- a/net/socket/ssl_client_socket_mac.h +++ b/net/socket/ssl_client_socket_mac.h @@ -77,6 +77,9 @@ class SSLClientSocketMac : public SSLClientSocket { virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -126,7 +129,8 @@ class SSLClientSocketMac : public SSLClientSocket { CompletionCallback user_connect_callback_; OldCompletionCallback* old_user_read_callback_; CompletionCallback user_read_callback_; - OldCompletionCallback* user_write_callback_; + OldCompletionCallback* old_user_write_callback_; + CompletionCallback user_write_callback_; // Used by Read function. scoped_refptr<IOBuffer> user_read_buf_; diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index 6c16c4a..cd40140 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -447,7 +447,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocketHandle* transport_socket, ssl_config_(ssl_config), old_user_connect_callback_(NULL), old_user_read_callback_(NULL), - user_write_callback_(NULL), + old_user_write_callback_(NULL), user_read_buf_len_(0), user_write_buf_len_(0), server_cert_nss_(NULL), @@ -576,7 +576,7 @@ int SSLClientSocketNSS::Connect(OldCompletionCallback* callback) { DCHECK(transport_.get()); DCHECK(next_handshake_state_ == STATE_NONE); DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); - DCHECK(!user_write_callback_); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); DCHECK(!user_read_buf_); DCHECK(!user_write_buf_); @@ -624,7 +624,7 @@ int SSLClientSocketNSS::Connect(const CompletionCallback& callback) { DCHECK(transport_.get()); DCHECK(next_handshake_state_ == STATE_NONE); DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); - DCHECK(!user_write_callback_); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); DCHECK(!user_read_buf_); DCHECK(!user_write_buf_); @@ -696,7 +696,8 @@ void SSLClientSocketNSS::Disconnect() { user_connect_callback_.Reset(); old_user_read_callback_ = NULL; user_read_callback_.Reset(); - user_write_callback_ = NULL; + old_user_write_callback_ = NULL; + user_write_callback_.Reset(); user_read_buf_ = NULL; user_read_buf_len_ = 0; user_write_buf_ = NULL; @@ -864,8 +865,36 @@ int SSLClientSocketNSS::Write(IOBuffer* buf, int buf_len, EnterFunction(buf_len); DCHECK(completed_handshake_); DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!user_write_callback_); - DCHECK(!old_user_connect_callback_); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); + DCHECK(!user_write_buf_); + DCHECK(nss_bufs_); + + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + if (corked_) { + corked_ = false; + uncork_timer_.Reset(); + } + int rv = DoWriteLoop(OK); + + if (rv == ERR_IO_PENDING) { + old_user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + LeaveFunction(rv); + return rv; +} +int SSLClientSocketNSS::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + EnterFunction(buf_len); + DCHECK(completed_handshake_); + DCHECK(next_handshake_state_ == STATE_NONE); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); DCHECK(!user_write_buf_); DCHECK(nss_bufs_); @@ -1247,15 +1276,23 @@ void SSLClientSocketNSS::DoReadCallback(int rv) { void SSLClientSocketNSS::DoWriteCallback(int rv) { EnterFunction(rv); DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_write_callback_); + DCHECK(old_user_write_callback_ || !user_write_callback_.is_null()); // Since Run may result in Write being called, clear |user_write_callback_| // up front. - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - c->Run(rv); + if (old_user_write_callback_) { + OldCompletionCallback* c = old_user_write_callback_; + old_user_write_callback_ = NULL; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c.Run(rv); + } LeaveFunction(""); } diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index 7e75cd5..22cb4a5 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -96,6 +96,9 @@ class SSLClientSocketNSS : public SSLClientSocket { virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -233,7 +236,8 @@ class SSLClientSocketNSS : public SSLClientSocket { CompletionCallback user_connect_callback_; OldCompletionCallback* old_user_read_callback_; CompletionCallback user_read_callback_; - OldCompletionCallback* user_write_callback_; + OldCompletionCallback* old_user_write_callback_; + CompletionCallback user_write_callback_; // Used by Read function. scoped_refptr<IOBuffer> user_read_buf_; diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc index 1237348..ea3b629 100644 --- a/net/socket/ssl_client_socket_openssl.cc +++ b/net/socket/ssl_client_socket_openssl.cc @@ -392,7 +392,7 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( transport_recv_busy_(false), old_user_connect_callback_(NULL), old_user_read_callback_(NULL), - user_write_callback_(NULL), + old_user_write_callback_(NULL), completed_handshake_(false), client_auth_cert_needed_(false), cert_verifier_(context.cert_verifier), @@ -632,11 +632,19 @@ void SSLClientSocketOpenSSL::DoReadCallback(int rv) { void SSLClientSocketOpenSSL::DoWriteCallback(int rv) { // Since Run may result in Write being called, clear |user_write_callback_| // up front. - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - c->Run(rv); + if (old_user_write_callback_) { + OldCompletionCallback* c = old_user_write_callback_; + old_user_write_callback_ = NULL; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c.Run(rv); + } } // StreamSocket methods @@ -712,7 +720,8 @@ void SSLClientSocketOpenSSL::Disconnect() { user_connect_callback_.Reset(); old_user_read_callback_ = NULL; user_read_callback_.Reset(); - user_write_callback_ = NULL; + old_user_write_callback_ = NULL; + user_write_callback_.Reset(); user_read_buf_ = NULL; user_read_buf_len_ = 0; user_write_buf_ = NULL; @@ -1246,6 +1255,23 @@ int SSLClientSocketOpenSSL::Write(IOBuffer* buf, int rv = DoWriteLoop(OK); if (rv == ERR_IO_PENDING) { + old_user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + + return rv; +} +int SSLClientSocketOpenSSL::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoWriteLoop(OK); + + if (rv == ERR_IO_PENDING) { user_write_callback_ = callback; } else { user_write_buf_ = NULL; diff --git a/net/socket/ssl_client_socket_openssl.h b/net/socket/ssl_client_socket_openssl.h index a15c0e3..74ee3bb 100644 --- a/net/socket/ssl_client_socket_openssl.h +++ b/net/socket/ssl_client_socket_openssl.h @@ -82,7 +82,10 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); - virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); @@ -126,7 +129,8 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { CompletionCallback user_connect_callback_; OldCompletionCallback* old_user_read_callback_; CompletionCallback user_read_callback_; - OldCompletionCallback* user_write_callback_; + OldCompletionCallback* old_user_write_callback_; + CompletionCallback user_write_callback_; // Used by Read function. scoped_refptr<IOBuffer> user_read_buf_; diff --git a/net/socket/ssl_client_socket_win.cc b/net/socket/ssl_client_socket_win.cc index 30f599d..3c855fb 100644 --- a/net/socket/ssl_client_socket_win.cc +++ b/net/socket/ssl_client_socket_win.cc @@ -400,7 +400,7 @@ SSLClientSocketWin::SSLClientSocketWin(ClientSocketHandle* transport_socket, old_user_connect_callback_(NULL), old_user_read_callback_(NULL), user_read_buf_len_(0), - user_write_callback_(NULL), + old_user_write_callback_(NULL), user_write_buf_len_(0), next_state_(STATE_NONE), cert_verifier_(context.cert_verifier), @@ -869,7 +869,29 @@ int SSLClientSocketWin::Read(IOBuffer* buf, int buf_len, int SSLClientSocketWin::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake()); - DCHECK(!user_write_callback_); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); + + DCHECK(!user_write_buf_); + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoPayloadEncrypt(); + if (rv != OK) + return rv; + + rv = DoPayloadWrite(); + if (rv == ERR_IO_PENDING) { + old_user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + return rv; +} +int SSLClientSocketWin::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake()); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); DCHECK(!user_write_buf_); user_write_buf_ = buf; @@ -964,12 +986,20 @@ void SSLClientSocketWin::OnWriteComplete(int result) { int rv = DoPayloadWriteComplete(result); if (rv != ERR_IO_PENDING) { - DCHECK(user_write_callback_); - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - c->Run(rv); + DCHECK(old_user_write_callback_ || !user_write_callback_.is_null()); + if (old_user_write_callback_) { + OldCompletionCallback* c = old_user_write_callback_; + old_user_write_callback_ = NULL; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c.Run(rv); + } } } diff --git a/net/socket/ssl_client_socket_win.h b/net/socket/ssl_client_socket_win.h index 27ce300..8fc326d 100644 --- a/net/socket/ssl_client_socket_win.h +++ b/net/socket/ssl_client_socket_win.h @@ -75,7 +75,10 @@ class SSLClientSocketWin : public SSLClientSocket { virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); - virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); @@ -134,7 +137,8 @@ class SSLClientSocketWin : public SSLClientSocket { int user_read_buf_len_; // User function to callback when a Write() completes. - OldCompletionCallback* user_write_callback_; + OldCompletionCallback* old_user_write_callback_; + CompletionCallback user_write_callback_; scoped_refptr<IOBuffer> user_write_buf_; int user_write_buf_len_; diff --git a/net/socket/ssl_server_socket_nss.cc b/net/socket/ssl_server_socket_nss.cc index 0785dd7d..c87076f 100644 --- a/net/socket/ssl_server_socket_nss.cc +++ b/net/socket/ssl_server_socket_nss.cc @@ -66,7 +66,7 @@ SSLServerSocketNSS::SSLServerSocketNSS( transport_recv_busy_(false), user_handshake_callback_(NULL), old_user_read_callback_(NULL), - user_write_callback_(NULL), + old_user_write_callback_(NULL), nss_fd_(NULL), nss_bufs_(NULL), transport_socket_(transport_socket), @@ -199,7 +199,26 @@ int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len, int SSLServerSocketNSS::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { - DCHECK(!user_write_callback_); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); + DCHECK(!user_write_buf_); + DCHECK(nss_bufs_); + + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoWriteLoop(OK); + + if (rv == ERR_IO_PENDING) { + old_user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + return rv; +} +int SSLServerSocketNSS::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); DCHECK(!user_write_buf_); DCHECK(nss_bufs_); @@ -760,15 +779,23 @@ void SSLServerSocketNSS::DoReadCallback(int rv) { void SSLServerSocketNSS::DoWriteCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_write_callback_); + DCHECK(old_user_write_callback_ || !user_write_callback_.is_null()); // Since Run may result in Write being called, clear |user_write_callback_| // up front. - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - c->Run(rv); + if (old_user_write_callback_) { + OldCompletionCallback* c = old_user_write_callback_; + old_user_write_callback_ = NULL; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c.Run(rv); + } } // static diff --git a/net/socket/ssl_server_socket_nss.h b/net/socket/ssl_server_socket_nss.h index 39283f6..4a4c3f5 100644 --- a/net/socket/ssl_server_socket_nss.h +++ b/net/socket/ssl_server_socket_nss.h @@ -45,6 +45,8 @@ class SSLServerSocketNSS : public SSLServerSocket { const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -113,7 +115,8 @@ class SSLServerSocketNSS : public SSLServerSocket { OldCompletionCallback* user_handshake_callback_; OldCompletionCallback* old_user_read_callback_; CompletionCallback user_read_callback_; - OldCompletionCallback* user_write_callback_; + OldCompletionCallback* old_user_write_callback_; + CompletionCallback user_write_callback_; // Used by Read function. scoped_refptr<IOBuffer> user_read_buf_; diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc index eb9dc7c..1ec6b4b 100644 --- a/net/socket/ssl_server_socket_unittest.cc +++ b/net/socket/ssl_server_socket_unittest.cc @@ -85,6 +85,14 @@ class FakeDataChannel { &FakeDataChannel::DoReadCallback)); return buf_len; } + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + data_.push(new net::DrainableIOBuffer(buf, buf_len)); + MessageLoop::current()->PostTask( + FROM_HERE, task_factory_.NewRunnableMethod( + &FakeDataChannel::DoReadCallback)); + return buf_len; + } private: void DoReadCallback() { @@ -160,6 +168,12 @@ class FakeSocket : public StreamSocket { buf_len = rand() % buf_len + 1; return outgoing_->Write(buf, buf_len, callback); } + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + // Write random number of bytes. + buf_len = rand() % buf_len + 1; + return outgoing_->Write(buf, buf_len, callback); + } virtual bool SetReceiveBufferSize(int32 size) { return true; diff --git a/net/socket/tcp_client_socket_libevent.cc b/net/socket/tcp_client_socket_libevent.cc index 3c99ae5..3243177 100644 --- a/net/socket/tcp_client_socket_libevent.cc +++ b/net/socket/tcp_client_socket_libevent.cc @@ -574,6 +574,42 @@ int TCPClientSocketLibevent::Write(IOBuffer* buf, old_write_callback_ = callback; return ERR_IO_PENDING; } +int TCPClientSocketLibevent::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(kInvalidSocket, socket_); + DCHECK(!waiting_connect()); + DCHECK(!old_write_callback_ && write_callback_.is_null()); + // Synchronous operation not supported + DCHECK(!callback.is_null()); + DCHECK_GT(buf_len, 0); + + int nwrite = InternalWrite(buf, buf_len); + if (nwrite >= 0) { + base::StatsCounter write_bytes("tcp.write_bytes"); + write_bytes.Add(nwrite); + if (nwrite > 0) + use_history_.set_was_used_to_convey_data(); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, nwrite, + buf->data()); + return nwrite; + } + if (errno != EAGAIN && errno != EWOULDBLOCK) + return MapSystemError(errno); + + if (!MessageLoopForIO::current()->WatchFileDescriptor( + socket_, true, MessageLoopForIO::WATCH_WRITE, + &write_socket_watcher_, &write_watcher_)) { + DVLOG(1) << "WatchFileDescriptor failed on write, errno " << errno; + return MapSystemError(errno); + } + + write_buf_ = buf; + write_buf_len_ = buf_len; + write_callback_ = callback; + return ERR_IO_PENDING; +} int TCPClientSocketLibevent::InternalWrite(IOBuffer* buf, int buf_len) { int nwrite; diff --git a/net/socket/tcp_client_socket_libevent.h b/net/socket/tcp_client_socket_libevent.h index 47f19a0..448d7a7 100644 --- a/net/socket/tcp_client_socket_libevent.h +++ b/net/socket/tcp_client_socket_libevent.h @@ -70,6 +70,9 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; diff --git a/net/socket/tcp_client_socket_win.cc b/net/socket/tcp_client_socket_win.cc index cb59a98..8101782 100644 --- a/net/socket/tcp_client_socket_win.cc +++ b/net/socket/tcp_client_socket_win.cc @@ -318,7 +318,7 @@ TCPClientSocketWin::TCPClientSocketWin(const AddressList& addresses, waiting_read_(false), waiting_write_(false), old_read_callback_(NULL), - write_callback_(NULL), + old_write_callback_(NULL), next_connect_state_(CONNECT_STATE_NONE), connect_os_error_(0), net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), @@ -792,7 +792,58 @@ int TCPClientSocketWin::Write(IOBuffer* buf, DCHECK(CalledOnValidThread()); DCHECK_NE(socket_, INVALID_SOCKET); DCHECK(!waiting_write_); - DCHECK(!write_callback_); + DCHECK(!old_write_callback_ && write_callback_.is_null()); + DCHECK_GT(buf_len, 0); + DCHECK(!core_->write_iobuffer_); + + base::StatsCounter writes("tcp.writes"); + writes.Increment(); + + core_->write_buffer_.len = buf_len; + core_->write_buffer_.buf = buf->data(); + core_->write_buffer_length_ = buf_len; + + // TODO(wtc): Remove the assertion after enough testing. + AssertEventNotSignaled(core_->write_overlapped_.hEvent); + DWORD num; + int rv = WSASend(socket_, &core_->write_buffer_, 1, &num, 0, + &core_->write_overlapped_, NULL); + if (rv == 0) { + if (ResetEventIfSignaled(core_->write_overlapped_.hEvent)) { + rv = static_cast<int>(num); + if (rv > buf_len || rv < 0) { + // It seems that some winsock interceptors report that more was written + // than was available. Treat this as an error. http://crbug.com/27870 + LOG(ERROR) << "Detected broken LSP: Asked to write " << buf_len + << " bytes, but " << rv << " bytes reported."; + return ERR_WINSOCK_UNEXPECTED_WRITTEN_BYTES; + } + base::StatsCounter write_bytes("tcp.write_bytes"); + write_bytes.Add(rv); + if (rv > 0) + use_history_.set_was_used_to_convey_data(); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, rv, + core_->write_buffer_.buf); + return rv; + } + } else { + int os_error = WSAGetLastError(); + if (os_error != WSA_IO_PENDING) + return MapSystemError(os_error); + } + core_->WatchForWrite(); + waiting_write_ = true; + old_write_callback_ = callback; + core_->write_iobuffer_ = buf; + return ERR_IO_PENDING; +} +int TCPClientSocketWin::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(socket_, INVALID_SOCKET); + DCHECK(!waiting_write_); + DCHECK(!old_write_callback_ && write_callback_.is_null()); DCHECK_GT(buf_len, 0); DCHECK(!core_->write_iobuffer_); @@ -897,12 +948,19 @@ void TCPClientSocketWin::DoReadCallback(int rv) { void TCPClientSocketWin::DoWriteCallback(int rv) { DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(write_callback_); + DCHECK(old_write_callback_ || !write_callback_.is_null()); - // since Run may result in Write being called, clear write_callback_ up front. - OldCompletionCallback* c = write_callback_; - write_callback_ = NULL; - c->Run(rv); + // Since Run may result in Write being called, clear old_write_callback_ up + // front. + if (old_write_callback_) { + OldCompletionCallback* c = old_write_callback_; + old_write_callback_ = NULL; + c->Run(rv); + } else { + CompletionCallback c = write_callback_; + write_callback_.Reset(); + c.Run(rv); + } } void TCPClientSocketWin::DidCompleteConnect() { diff --git a/net/socket/tcp_client_socket_win.h b/net/socket/tcp_client_socket_win.h index 1e75933..2e58cad 100644 --- a/net/socket/tcp_client_socket_win.h +++ b/net/socket/tcp_client_socket_win.h @@ -63,7 +63,10 @@ class NET_EXPORT TCPClientSocketWin : public StreamSocket, virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); - virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); @@ -130,7 +133,8 @@ class NET_EXPORT TCPClientSocketWin : public StreamSocket, CompletionCallback read_callback_; // External callback; called when write is complete. - OldCompletionCallback* write_callback_; + OldCompletionCallback* old_write_callback_; + CompletionCallback write_callback_; // The next state for the Connect() state machine. ConnectState next_connect_state_; diff --git a/net/socket/transport_client_socket_pool_unittest.cc b/net/socket/transport_client_socket_pool_unittest.cc index 56b1fa9c..bfd330a 100644 --- a/net/socket/transport_client_socket_pool_unittest.cc +++ b/net/socket/transport_client_socket_pool_unittest.cc @@ -107,6 +107,10 @@ class MockClientSocket : public StreamSocket { OldCompletionCallback* callback) { return ERR_FAILED; } + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + return ERR_FAILED; + } virtual bool SetReceiveBufferSize(int32 size) { return true; } virtual bool SetSendBufferSize(int32 size) { return true; } @@ -169,6 +173,10 @@ class MockFailingClientSocket : public StreamSocket { OldCompletionCallback* callback) { return ERR_FAILED; } + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + return ERR_FAILED; + } virtual bool SetReceiveBufferSize(int32 size) { return true; } virtual bool SetSendBufferSize(int32 size) { return true; } @@ -260,6 +268,10 @@ class MockPendingClientSocket : public StreamSocket { OldCompletionCallback* callback) { return ERR_FAILED; } + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + return ERR_FAILED; + } virtual bool SetReceiveBufferSize(int32 size) { return true; } virtual bool SetSendBufferSize(int32 size) { return true; } diff --git a/net/socket/web_socket_server_socket.cc b/net/socket/web_socket_server_socket.cc index d792689..b24f25c 100644 --- a/net/socket/web_socket_server_socket.cc +++ b/net/socket/web_socket_server_socket.cc @@ -397,6 +397,53 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { ConsiderTransportWrite(); return net::ERR_IO_PENDING; } + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE { + if (buf_len == 0) + return 0; + if (buf == NULL || buf_len < 0) { + NOTREACHED(); + return net::ERR_INVALID_ARGUMENT; + } + DCHECK_EQ(std::find(buf->data(), buf->data() + buf_len, '\xff'), + buf->data() + buf_len); + switch (phase_) { + case PHASE_FRAME_OUTSIDE: + case PHASE_FRAME_INSIDE: + case PHASE_FRAME_LENGTH: + case PHASE_FRAME_SKIP: { + break; + } + case PHASE_SHUT: { + return net::ERR_SOCKET_NOT_CONNECTED; + } + case PHASE_NYMPH: + case PHASE_HANDSHAKE: + default: { + NOTREACHED(); + return net::ERR_UNEXPECTED; + } + } + + net::IOBuffer* frame_start = new net::IOBuffer(1); + frame_start->data()[0] = '\x00'; + pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, + new net::DrainableIOBuffer(frame_start, 1), + NULL)); + + pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE, + new net::DrainableIOBuffer(buf, buf_len), + callback)); + + net::IOBuffer* frame_end = new net::IOBuffer(1); + frame_end->data()[0] = '\xff'; + pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, + new net::DrainableIOBuffer(frame_end, 1), + NULL)); + + ConsiderTransportWrite(); + return net::ERR_IO_PENDING; + } virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { return transport_socket_->SetReceiveBufferSize(size); diff --git a/net/socket/web_socket_server_socket_unittest.cc b/net/socket/web_socket_server_socket_unittest.cc index cabb4b9..bb4d022 100644 --- a/net/socket/web_socket_server_socket_unittest.cc +++ b/net/socket/web_socket_server_socket_unittest.cc @@ -157,6 +157,24 @@ class TestingTransportSocket : public net::Socket { } MessageLoop::current()->PostTask(FROM_HERE, method_factory_.NewRunnableMethod( + &TestingTransportSocket::DoOldWriteCallback, callback, lot)); + return net::ERR_IO_PENDING; + } + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + CHECK_GT(buf_len, 0); + int remaining = answer_->BytesRemaining(); + CHECK_GE(remaining, buf_len); + int lot = std::min(remaining, buf_len); + if (GetRand(0, 1)) + lot = GetRand(1, lot); + std::copy(buf->data(), buf->data() + lot, answer_->data()); + answer_->DidConsume(lot); + if (GetRand(0, 1)) { + return lot; + } + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( &TestingTransportSocket::DoWriteCallback, callback, lot)); return net::ERR_IO_PENDING; } @@ -192,10 +210,14 @@ class TestingTransportSocket : public net::Socket { } } - void DoWriteCallback(net::OldCompletionCallback* callback, int result) { + void DoOldWriteCallback(net::OldCompletionCallback* callback, int result) { if (callback) callback->Run(result); } + void DoWriteCallback(const net::CompletionCallback& callback, int result) { + if (!callback.is_null()) + callback.Run(result); + } bool is_closed_; |