diff options
author | jhawkins@chromium.org <jhawkins@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-12-07 02:03:33 +0000 |
---|---|---|
committer | jhawkins@chromium.org <jhawkins@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-12-07 02:03:33 +0000 |
commit | 3f55aa10587b3eaa629d7e95de87998b399fe3e2 (patch) | |
tree | 0f407a7e7fc837bed337a9a5af787edbd9473ef6 /net/socket | |
parent | b456003a580041d83e7f5a998c15f62ce380560f (diff) | |
download | chromium_src-3f55aa10587b3eaa629d7e95de87998b399fe3e2.zip chromium_src-3f55aa10587b3eaa629d7e95de87998b399fe3e2.tar.gz chromium_src-3f55aa10587b3eaa629d7e95de87998b399fe3e2.tar.bz2 |
base::Bind: Convert Socket::Read.
BUG=none
TEST=none
R=csilv
Review URL: http://codereview.chromium.org/8801005
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@113326 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net/socket')
26 files changed, 766 insertions, 138 deletions
diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index d318dd6..fdbc63f 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -49,12 +49,17 @@ class MockClientSocket : public StreamSocket { MockClientSocket() : connected_(false), was_used_to_convey_data_(false), num_bytes_read_(0) {} - // Socket methods: + // Socket implementation. virtual int Read( IOBuffer* /* buf */, int len, OldCompletionCallback* /* callback */) { num_bytes_read_ += len; return len; } + virtual int Read( + IOBuffer* /* buf */, int len, const CompletionCallback& /* callback */) { + num_bytes_read_ += len; + return len; + } virtual int Write( IOBuffer* /* buf */, int len, OldCompletionCallback* /* callback */) { @@ -211,7 +216,7 @@ class TestConnectJob : public ConnectJob { } private: - // ConnectJob methods: + // ConnectJob implementation. virtual int ConnectInternal() { AddressList ignored; @@ -366,8 +371,7 @@ class TestConnectJobFactory timeout_duration_ = timeout_duration; } - // ConnectJobFactory methods: - + // ConnectJobFactory implementation. virtual ConnectJob* NewConnectJob( const std::string& group_name, const TestClientSocketPoolBase::Request& request, diff --git a/net/socket/socket.h b/net/socket/socket.h index 90185a2..c185c44 100644 --- a/net/socket/socket.h +++ b/net/socket/socket.h @@ -31,6 +31,8 @@ class NET_EXPORT Socket { // callback will not be invoked. virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) = 0; + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) = 0; // Writes data, up to |buf_len| bytes, to the socket. Note: data may be // written partially. The number of bytes written is returned, or an error diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index eb1599f..376d542 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -738,12 +738,14 @@ MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses, peer_closed_connection_(false), pending_buf_(NULL), pending_buf_len_(0), - pending_callback_(NULL), + old_pending_callback_(NULL), was_used_to_convey_data_(false) { DCHECK(data_); data_->Reset(); } +MockTCPClientSocket::~MockTCPClientSocket() {} + int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { if (!connected_) @@ -755,7 +757,7 @@ int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, // Store our async IO data. pending_buf_ = buf; pending_buf_len_ = buf_len; - pending_callback_ = callback; + old_pending_callback_ = callback; if (need_read_data_) { read_data_ = data_->GetNextRead(); @@ -776,6 +778,39 @@ int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, return CompleteRead(); } +int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + if (!connected_) + return net::ERR_UNEXPECTED; + + // If the buffer is already in use, a read is already in progress! + DCHECK(pending_buf_ == NULL); + + // Store our async IO data. + pending_buf_ = buf; + pending_buf_len_ = buf_len; + pending_callback_ = callback; + + if (need_read_data_) { + read_data_ = data_->GetNextRead(); + if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { + // This MockRead is just a marker to instruct us to set + // peer_closed_connection_. Skip it and get the next one. + read_data_ = data_->GetNextRead(); + peer_closed_connection_ = true; + } + // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility + // to complete the async IO manually later (via OnReadComplete). + if (read_data_.result == ERR_IO_PENDING) { + // We need to be using async IO in this case. + DCHECK(!callback.is_null()); + return ERR_IO_PENDING; + } + need_read_data_ = false; + } + + return CompleteRead(); +} int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { @@ -825,7 +860,8 @@ int MockTCPClientSocket::Connect(const net::CompletionCallback& callback) { void MockTCPClientSocket::Disconnect() { MockClientSocket::Disconnect(); - pending_callback_ = NULL; + old_pending_callback_ = NULL; + pending_callback_.Reset(); } bool MockTCPClientSocket::IsConnected() const { @@ -876,9 +912,15 @@ void MockTCPClientSocket::OnReadComplete(const MockRead& data) { // let CompleteRead() schedule a callback. read_data_.async = false; - net::OldCompletionCallback* callback = pending_callback_; - int rv = CompleteRead(); - RunOldCallback(callback, rv); + if (old_pending_callback_) { + net::OldCompletionCallback* callback = old_pending_callback_; + int rv = CompleteRead(); + RunOldCallback(callback, rv); + } else { + net::CompletionCallback callback = pending_callback_; + int rv = CompleteRead(); + RunCallback(callback, rv); + } } int MockTCPClientSocket::CompleteRead() { @@ -890,10 +932,12 @@ int MockTCPClientSocket::CompleteRead() { // Save the pending async IO data and reset our |pending_| state. net::IOBuffer* buf = pending_buf_; int buf_len = pending_buf_len_; - net::OldCompletionCallback* callback = pending_callback_; + net::OldCompletionCallback* old_callback = old_pending_callback_; + net::CompletionCallback callback = pending_callback_; pending_buf_ = NULL; pending_buf_len_ = 0; - pending_callback_ = NULL; + old_pending_callback_ = NULL; + pending_callback_.Reset(); int result = read_data_.result; DCHECK(result != ERR_IO_PENDING); @@ -914,8 +958,11 @@ int MockTCPClientSocket::CompleteRead() { } if (read_data_.async) { - DCHECK(callback); - RunCallbackAsync(callback, result); + DCHECK(old_callback || !callback.is_null()); + if (old_callback) + RunCallbackAsync(old_callback, result); + else + RunCallbackAsync(callback, result); return net::ERR_IO_PENDING; } return result; @@ -931,7 +978,7 @@ DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket( read_buf_(NULL), read_buf_len_(0), read_pending_(false), - read_callback_(NULL), + old_read_callback_(NULL), data_(data), was_used_to_convey_data_(false) {} @@ -965,7 +1012,10 @@ int DeterministicMockTCPClientSocket::CompleteRead() { if (read_pending_) { read_pending_ = false; - read_callback_->Run(result); + if (old_read_callback_) + old_read_callback_->Run(result); + else + read_callback_.Run(result); } return result; @@ -1007,11 +1057,34 @@ int DeterministicMockTCPClientSocket::Read( read_buf_ = buf; read_buf_len_ = buf_len; + old_read_callback_ = callback; + + if (read_data_.async || (read_data_.result == ERR_IO_PENDING)) { + read_pending_ = true; + DCHECK(old_read_callback_); + return ERR_IO_PENDING; + } + + was_used_to_convey_data_ = true; + return CompleteRead(); +} +int DeterministicMockTCPClientSocket::Read( + net::IOBuffer* buf, int buf_len, const net::CompletionCallback& callback) { + if (!connected_) + return net::ERR_UNEXPECTED; + + read_data_ = data_->GetNextRead(); + // The buffer should always be big enough to contain all the MockRead data. To + // use small buffers, split the data into multiple MockReads. + DCHECK_LE(read_data_.data_len, buf_len); + + read_buf_ = buf; + read_buf_len_ = buf_len; read_callback_ = callback; if (read_data_.async || (read_data_.result == ERR_IO_PENDING)) { read_pending_ = true; - DCHECK(read_callback_); + DCHECK(!read_callback_.is_null()); return ERR_IO_PENDING; } @@ -1153,6 +1226,10 @@ int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { return transport_->socket()->Read(buf, buf_len, callback); } +int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + return transport_->socket()->Read(buf, buf_len, callback); +} int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { @@ -1265,7 +1342,7 @@ MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data, need_read_data_(true), pending_buf_(NULL), pending_buf_len_(0), - pending_callback_(NULL), + old_pending_callback_(NULL), net_log_(NetLog::Source(), net_log), ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { DCHECK(data_); @@ -1285,7 +1362,7 @@ int MockUDPClientSocket::Read(net::IOBuffer* buf, int buf_len, // Store our async IO data. pending_buf_ = buf; pending_buf_len_ = buf_len; - pending_callback_ = callback; + old_pending_callback_ = callback; if (need_read_data_) { read_data_ = data_->GetNextRead(); @@ -1300,6 +1377,33 @@ int MockUDPClientSocket::Read(net::IOBuffer* buf, int buf_len, return CompleteRead(); } +int MockUDPClientSocket::Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + if (!connected_) + return net::ERR_UNEXPECTED; + + // If the buffer is already in use, a read is already in progress! + DCHECK(pending_buf_ == NULL); + + // Store our async IO data. + pending_buf_ = buf; + pending_buf_len_ = buf_len; + pending_callback_ = callback; + + if (need_read_data_) { + read_data_ = data_->GetNextRead(); + // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility + // to complete the async IO manually later (via OnReadComplete). + if (read_data_.result == ERR_IO_PENDING) { + // We need to be using async IO in this case. + DCHECK(!callback.is_null()); + return ERR_IO_PENDING; + } + need_read_data_ = false; + } + + return CompleteRead(); +} int MockUDPClientSocket::Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { @@ -1365,9 +1469,15 @@ void MockUDPClientSocket::OnReadComplete(const MockRead& data) { // let CompleteRead() schedule a callback. read_data_.async = false; - net::OldCompletionCallback* callback = pending_callback_; - int rv = CompleteRead(); - RunCallback(callback, rv); + if (old_pending_callback_) { + net::OldCompletionCallback* callback = old_pending_callback_; + int rv = CompleteRead(); + RunOldCallback(callback, rv); + } else { + net::CompletionCallback callback = pending_callback_; + int rv = CompleteRead(); + RunCallback(callback, rv); + } } int MockUDPClientSocket::CompleteRead() { @@ -1377,10 +1487,13 @@ int MockUDPClientSocket::CompleteRead() { // Save the pending async IO data and reset our |pending_| state. net::IOBuffer* buf = pending_buf_; int buf_len = pending_buf_len_; - net::OldCompletionCallback* callback = pending_callback_; + net::OldCompletionCallback* old_callback = old_pending_callback_; + net::CompletionCallback callback = pending_callback_; pending_buf_ = NULL; pending_buf_len_ = 0; - pending_callback_ = NULL; + old_pending_callback_ = NULL; + pending_callback_.Reset(); + pending_callback_.Reset(); int result = read_data_.result; DCHECK(result != ERR_IO_PENDING); @@ -1400,8 +1513,11 @@ int MockUDPClientSocket::CompleteRead() { } if (read_data_.async) { - DCHECK(callback); - RunCallbackAsync(callback, result); + DCHECK(old_callback || !callback.is_null()); + if (old_callback) + RunCallbackAsync(old_callback, result); + else + RunCallbackAsync(callback, result); return net::ERR_IO_PENDING; } return result; @@ -1410,15 +1526,26 @@ int MockUDPClientSocket::CompleteRead() { void MockUDPClientSocket::RunCallbackAsync(net::OldCompletionCallback* callback, int result) { MessageLoop::current()->PostTask(FROM_HERE, + base::Bind(&MockUDPClientSocket::RunOldCallback, + weak_factory_.GetWeakPtr(), callback, result)); +} +void MockUDPClientSocket::RunCallbackAsync( + const net::CompletionCallback& callback, int result) { + MessageLoop::current()->PostTask(FROM_HERE, base::Bind(&MockUDPClientSocket::RunCallback, weak_factory_.GetWeakPtr(), callback, result)); } -void MockUDPClientSocket::RunCallback(net::OldCompletionCallback* callback, +void MockUDPClientSocket::RunOldCallback(net::OldCompletionCallback* callback, int result) { if (callback) callback->Run(result); } +void MockUDPClientSocket::RunCallback(const net::CompletionCallback& callback, + int result) { + if (!callback.is_null()) + callback.Run(result); +} TestSocketRequest::TestSocketRequest( std::vector<TestSocketRequest*>* request_order, diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index fa47d26..73ffc3d 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -19,6 +19,7 @@ #include "base/memory/weak_ptr.h" #include "base/string16.h" #include "net/base/address_list.h" +#include "net/base/completion_callback.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/base/net_log.h" @@ -587,6 +588,8 @@ class MockClientSocket : public net::SSLClientSocket { // Socket implementation. virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) = 0; + virtual int Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) = 0; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) = 0; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; @@ -634,12 +637,15 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { public: MockTCPClientSocket(const net::AddressList& addresses, net::NetLog* net_log, net::SocketDataProvider* socket); + virtual ~MockTCPClientSocket(); net::AddressList addresses() const { return addresses_; } // Socket implementation. virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; @@ -677,7 +683,8 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { // While an asynchronous IO is pending, we save our user-buffer state. net::IOBuffer* pending_buf_; int pending_buf_len_; - net::OldCompletionCallback* pending_callback_; + net::OldCompletionCallback* old_pending_callback_; + net::CompletionCallback pending_callback_; bool was_used_to_convey_data_; }; @@ -695,11 +702,13 @@ class DeterministicMockTCPClientSocket : public MockClientSocket, void CompleteWrite(); int CompleteRead(); - // Socket: + // Socket implementation. virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; // StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; @@ -725,7 +734,8 @@ class DeterministicMockTCPClientSocket : public MockClientSocket, net::IOBuffer* read_buf_; int read_buf_len_; bool read_pending_; - net::OldCompletionCallback* read_callback_; + net::OldCompletionCallback* old_read_callback_; + net::CompletionCallback read_callback_; net::DeterministicSocketData* data_; bool was_used_to_convey_data_; }; @@ -743,6 +753,8 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { // Socket implementation. virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; @@ -785,31 +797,35 @@ class MockUDPClientSocket : public DatagramClientSocket, MockUDPClientSocket(SocketDataProvider* data, net::NetLog* net_log); virtual ~MockUDPClientSocket(); - // Socket interface + // Socket implementation. virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; - // DatagramSocket interface + // DatagramSocket implementation. virtual void Close() OVERRIDE; virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; virtual const BoundNetLog& NetLog() const OVERRIDE; - // DatagramClientSocket interface + // DatagramClientSocket implementation. virtual int Connect(const IPEndPoint& address) OVERRIDE; - // AsyncSocket interface + // AsyncSocket implementation. virtual void OnReadComplete(const MockRead& data) OVERRIDE; private: int CompleteRead(); void RunCallbackAsync(net::OldCompletionCallback* callback, int result); - void RunCallback(net::OldCompletionCallback* callback, int result); + void RunCallbackAsync(const net::CompletionCallback& callback, int result); + void RunOldCallback(net::OldCompletionCallback* callback, int result); + void RunCallback(const net::CompletionCallback& callback, int result); bool connected_; SocketDataProvider* data_; @@ -820,7 +836,8 @@ class MockUDPClientSocket : public DatagramClientSocket, // While an asynchronous IO is pending, we save our user-buffer state. net::IOBuffer* pending_buf_; int pending_buf_len_; - net::OldCompletionCallback* pending_callback_; + net::OldCompletionCallback* old_pending_callback_; + net::CompletionCallback pending_callback_; BoundNetLog net_log_; diff --git a/net/socket/socks5_client_socket.cc b/net/socket/socks5_client_socket.cc index c9d2825..ea5fc7a 100644 --- a/net/socket/socks5_client_socket.cc +++ b/net/socket/socks5_client_socket.cc @@ -188,7 +188,15 @@ int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); + + return transport_->socket()->Read(buf, buf_len, callback); +} +int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); return transport_->socket()->Read(buf, buf_len, callback); } diff --git a/net/socket/socks5_client_socket.h b/net/socket/socks5_client_socket.h index 748b55a..b83a347 100644 --- a/net/socket/socks5_client_socket.h +++ b/net/socket/socks5_client_socket.h @@ -48,7 +48,7 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { // On destruction Disconnect() is called. virtual ~SOCKS5ClientSocket(); - // StreamSocket methods: + // StreamSocket implementation. // Does the SOCKS handshake and completes the protocol. virtual int Connect(OldCompletionCallback* callback) OVERRIDE; @@ -64,10 +64,13 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; - // Socket methods: + // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc index a4c4b47..623f202 100644 --- a/net/socket/socks_client_socket.cc +++ b/net/socket/socks_client_socket.cc @@ -216,7 +216,15 @@ int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); + + return transport_->socket()->Read(buf, buf_len, callback); +} +int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); return transport_->socket()->Read(buf, buf_len, callback); } diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h index c7089af..1a4a75c 100644 --- a/net/socket/socks_client_socket.h +++ b/net/socket/socks_client_socket.h @@ -45,7 +45,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { // On destruction Disconnect() is called. virtual ~SOCKSClientSocket(); - // StreamSocket methods: + // StreamSocket implementation. // Does the SOCKS handshake and completes the protocol. virtual int Connect(OldCompletionCallback* callback) OVERRIDE; @@ -61,10 +61,13 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; - // Socket methods: + // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; diff --git a/net/socket/ssl_client_socket_mac.cc b/net/socket/ssl_client_socket_mac.cc index 817b3be..f58d340 100644 --- a/net/socket/ssl_client_socket_mac.cc +++ b/net/socket/ssl_client_socket_mac.cc @@ -532,7 +532,7 @@ SSLClientSocketMac::SSLClientSocketMac(ClientSocketHandle* transport_socket, host_and_port_(host_and_port), ssl_config_(ssl_config), old_user_connect_callback_(NULL), - user_read_callback_(NULL), + old_user_read_callback_(NULL), user_write_callback_(NULL), user_read_buf_len_(0), user_write_buf_len_(0), @@ -700,7 +700,25 @@ base::TimeDelta SSLClientSocketMac::GetConnectTimeMicros() const { int SSLClientSocketMac::Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake()); - DCHECK(!user_read_callback_); + DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); + DCHECK(!user_read_buf_); + + user_read_buf_ = buf; + user_read_buf_len_ = buf_len; + + int rv = DoPayloadRead(); + if (rv == ERR_IO_PENDING) { + old_user_read_callback_ = callback; + } else { + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + } + return rv; +} +int SSLClientSocketMac::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake()); + DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); DCHECK(!user_read_buf_); user_read_buf_ = buf; @@ -933,15 +951,23 @@ void SSLClientSocketMac::DoConnectCallback(int rv) { void SSLClientSocketMac::DoReadCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_read_callback_); + DCHECK(old_user_read_callback_ || !user_read_callback_.is_null()); // Since Run may result in Read being called, clear user_read_callback_ up // front. - OldCompletionCallback* c = user_read_callback_; - user_read_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c->Run(rv); + if (old_user_read_callback_) { + OldCompletionCallback* c = old_user_read_callback_; + old_user_read_callback_ = NULL; + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_read_callback_; + user_read_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c.Run(rv); + } } void SSLClientSocketMac::DoWriteCallback(int rv) { diff --git a/net/socket/ssl_client_socket_mac.h b/net/socket/ssl_client_socket_mac.h index b9dccc0..7792cb3 100644 --- a/net/socket/ssl_client_socket_mac.h +++ b/net/socket/ssl_client_socket_mac.h @@ -71,6 +71,9 @@ class SSLClientSocketMac : public SSLClientSocket { virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; @@ -121,7 +124,8 @@ class SSLClientSocketMac : public SSLClientSocket { OldCompletionCallback* old_user_connect_callback_; CompletionCallback user_connect_callback_; - OldCompletionCallback* user_read_callback_; + OldCompletionCallback* old_user_read_callback_; + CompletionCallback user_read_callback_; OldCompletionCallback* user_write_callback_; // Used by Read function. diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index e1ac396..6e62edb 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -447,7 +447,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocketHandle* transport_socket, host_and_port_(host_and_port), ssl_config_(ssl_config), old_user_connect_callback_(NULL), - user_read_callback_(NULL), + old_user_read_callback_(NULL), user_write_callback_(NULL), user_read_buf_len_(0), user_write_buf_len_(0), @@ -576,7 +576,7 @@ int SSLClientSocketNSS::Connect(OldCompletionCallback* callback) { EnterFunction(""); DCHECK(transport_.get()); DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!user_read_callback_); + DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); DCHECK(!user_write_callback_); DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); DCHECK(!user_read_buf_); @@ -624,7 +624,7 @@ int SSLClientSocketNSS::Connect(const CompletionCallback& callback) { EnterFunction(""); DCHECK(transport_.get()); DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!user_read_callback_); + DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); DCHECK(!user_write_callback_); DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); DCHECK(!user_read_buf_); @@ -695,7 +695,8 @@ void SSLClientSocketNSS::Disconnect() { transport_recv_busy_ = false; old_user_connect_callback_ = NULL; user_connect_callback_.Reset(); - user_read_callback_ = NULL; + old_user_read_callback_ = NULL; + user_read_callback_.Reset(); user_write_callback_ = NULL; user_read_buf_ = NULL; user_read_buf_len_ = 0; @@ -815,8 +816,32 @@ int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, EnterFunction(buf_len); DCHECK(completed_handshake_); DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!user_read_callback_); - DCHECK(!old_user_connect_callback_); + DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); + DCHECK(!user_read_buf_); + DCHECK(nss_bufs_); + + user_read_buf_ = buf; + user_read_buf_len_ = buf_len; + + int rv = DoReadLoop(OK); + + if (rv == ERR_IO_PENDING) { + old_user_read_callback_ = callback; + } else { + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + } + LeaveFunction(rv); + return rv; +} +int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + EnterFunction(buf_len); + DCHECK(completed_handshake_); + DCHECK(next_handshake_state_ == STATE_NONE); + DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); DCHECK(!user_read_buf_); DCHECK(nss_bufs_); @@ -1200,15 +1225,23 @@ void SSLClientSocketNSS::UpdateConnectionStatus() { void SSLClientSocketNSS::DoReadCallback(int rv) { EnterFunction(rv); DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_read_callback_); + DCHECK(old_user_read_callback_ || user_read_callback_.is_null()); - // Since Run may result in Read being called, clear |user_read_callback_| + // Since Run may result in Read being called, clear |old_user_read_callback_| // up front. - OldCompletionCallback* c = user_read_callback_; - user_read_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c->Run(rv); + if (old_user_read_callback_) { + OldCompletionCallback* c = old_user_read_callback_; + old_user_read_callback_ = NULL; + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_read_callback_; + user_read_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c.Run(rv); + } LeaveFunction(""); } @@ -1864,7 +1897,7 @@ int SSLClientSocketNSS::DoVerifyCertComplete(int result) { completed_handshake_ = true; - if (user_read_callback_) { + if (old_user_read_callback_ || !user_read_callback_.is_null()) { int rv = DoReadLoop(OK); if (rv != ERR_IO_PENDING) DoReadCallback(rv); diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index 78e222b..7e75cd5 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -90,6 +90,9 @@ class SSLClientSocketNSS : public SSLClientSocket { virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; @@ -228,7 +231,8 @@ class SSLClientSocketNSS : public SSLClientSocket { OldCompletionCallback* old_user_connect_callback_; CompletionCallback user_connect_callback_; - OldCompletionCallback* user_read_callback_; + OldCompletionCallback* old_user_read_callback_; + CompletionCallback user_read_callback_; OldCompletionCallback* user_write_callback_; // Used by Read function. diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc index e6e9aac..1237348 100644 --- a/net/socket/ssl_client_socket_openssl.cc +++ b/net/socket/ssl_client_socket_openssl.cc @@ -391,7 +391,7 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( transport_send_busy_(false), transport_recv_busy_(false), old_user_connect_callback_(NULL), - user_read_callback_(NULL), + old_user_read_callback_(NULL), user_write_callback_(NULL), completed_handshake_(false), client_auth_cert_needed_(false), @@ -614,11 +614,19 @@ SSLClientSocket::NextProtoStatus SSLClientSocketOpenSSL::GetNextProto( void SSLClientSocketOpenSSL::DoReadCallback(int rv) { // Since Run may result in Read being called, clear |user_read_callback_| // up front. - OldCompletionCallback* c = user_read_callback_; - user_read_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c->Run(rv); + if (old_user_read_callback_) { + OldCompletionCallback* c = old_user_read_callback_; + old_user_read_callback_ = NULL; + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_read_callback_; + user_read_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c.Run(rv); + } } void SSLClientSocketOpenSSL::DoWriteCallback(int rv) { @@ -702,7 +710,8 @@ void SSLClientSocketOpenSSL::Disconnect() { old_user_connect_callback_ = NULL; user_connect_callback_.Reset(); - user_read_callback_ = NULL; + old_user_read_callback_ = NULL; + user_read_callback_.Reset(); user_write_callback_ = NULL; user_read_buf_ = NULL; user_read_buf_len_ = 0; @@ -1188,6 +1197,23 @@ int SSLClientSocketOpenSSL::Read(IOBuffer* buf, int rv = DoReadLoop(OK); if (rv == ERR_IO_PENDING) { + old_user_read_callback_ = callback; + } else { + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + } + + return rv; +} +int SSLClientSocketOpenSSL::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + user_read_buf_ = buf; + user_read_buf_len_ = buf_len; + + int rv = DoReadLoop(OK); + + if (rv == ERR_IO_PENDING) { user_read_callback_ = callback; } else { user_read_buf_ = NULL; diff --git a/net/socket/ssl_client_socket_openssl.h b/net/socket/ssl_client_socket_openssl.h index 010930a..a15c0e3 100644 --- a/net/socket/ssl_client_socket_openssl.h +++ b/net/socket/ssl_client_socket_openssl.h @@ -80,6 +80,8 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); @@ -122,7 +124,8 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { OldCompletionCallback* old_user_connect_callback_; CompletionCallback user_connect_callback_; - OldCompletionCallback* user_read_callback_; + OldCompletionCallback* old_user_read_callback_; + CompletionCallback user_read_callback_; OldCompletionCallback* user_write_callback_; // Used by Read function. diff --git a/net/socket/ssl_client_socket_win.cc b/net/socket/ssl_client_socket_win.cc index d1ed130..30f599d 100644 --- a/net/socket/ssl_client_socket_win.cc +++ b/net/socket/ssl_client_socket_win.cc @@ -398,7 +398,7 @@ SSLClientSocketWin::SSLClientSocketWin(ClientSocketHandle* transport_socket, host_and_port_(host_and_port), ssl_config_(ssl_config), old_user_connect_callback_(NULL), - user_read_callback_(NULL), + old_user_read_callback_(NULL), user_read_buf_len_(0), user_write_callback_(NULL), user_write_buf_len_(0), @@ -786,7 +786,48 @@ base::TimeDelta SSLClientSocketWin::GetConnectTimeMicros() const { int SSLClientSocketWin::Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake()); - DCHECK(!user_read_callback_); + DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); + + // If we have surplus decrypted plaintext, satisfy the Read with it without + // reading more ciphertext from the transport socket. + if (bytes_decrypted_ != 0) { + int len = std::min(buf_len, bytes_decrypted_); + net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, len, + decrypted_ptr_); + memcpy(buf->data(), decrypted_ptr_, len); + decrypted_ptr_ += len; + bytes_decrypted_ -= len; + if (bytes_decrypted_ == 0) { + decrypted_ptr_ = NULL; + if (bytes_received_ != 0) { + memmove(recv_buffer_.get(), received_ptr_, bytes_received_); + received_ptr_ = recv_buffer_.get(); + } + } + return len; + } + + DCHECK(!user_read_buf_); + // http://crbug.com/16371: We're seeing |buf->data()| return NULL. See if the + // user is passing in an IOBuffer with a NULL |data_|. + CHECK(buf); + CHECK(buf->data()); + user_read_buf_ = buf; + user_read_buf_len_ = buf_len; + + int rv = DoPayloadRead(); + if (rv == ERR_IO_PENDING) { + old_user_read_callback_ = callback; + } else { + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + } + return rv; +} +int SSLClientSocketWin::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake()); + DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); // If we have surplus decrypted plaintext, satisfy the Read with it without // reading more ciphertext from the transport socket. @@ -866,11 +907,19 @@ void SSLClientSocketWin::OnHandshakeIOComplete(int result) { // (which occurs because we are in the middle of a Read when the // renegotiation process starts). So we complete the Read here. if (!old_user_connect_callback_ && user_connect_callback_.is_null()) { - OldCompletionCallback* c = user_read_callback_; - user_read_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c->Run(rv); + if (old_user_read_callback_) { + OldCompletionCallback* c = old_user_read_callback_; + old_user_read_callback_ = NULL; + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_read_callback_; + user_read_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c.Run(rv); + } return; } net_log_.EndEvent(NetLog::TYPE_SSL_CONNECT, NULL); @@ -893,12 +942,20 @@ void SSLClientSocketWin::OnReadComplete(int result) { if (result > 0) result = DoPayloadDecrypt(); if (result != ERR_IO_PENDING) { - DCHECK(user_read_callback_); - OldCompletionCallback* c = user_read_callback_; - user_read_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c->Run(result); + DCHECK(old_user_read_callback_ || !user_read_callback_.is_null()); + if (old_user_read_callback_) { + OldCompletionCallback* c = old_user_read_callback_; + old_user_read_callback_ = NULL; + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c->Run(result); + } else { + CompletionCallback c = user_read_callback_; + user_read_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c.Run(result); + } } } @@ -1579,7 +1636,7 @@ int SSLClientSocketWin::DidCompleteHandshake() { // result of the server certificate received during renegotiation. void SSLClientSocketWin::DidCompleteRenegotiation() { DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); - DCHECK(user_read_callback_); + DCHECK(old_user_read_callback_ || !user_read_callback_.is_null()); renegotiating_ = false; next_state_ = STATE_COMPLETED_RENEGOTIATION; } diff --git a/net/socket/ssl_client_socket_win.h b/net/socket/ssl_client_socket_win.h index 01a5509b..27ce300 100644 --- a/net/socket/ssl_client_socket_win.h +++ b/net/socket/ssl_client_socket_win.h @@ -73,6 +73,8 @@ class SSLClientSocketWin : public SSLClientSocket { // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual bool SetReceiveBufferSize(int32 size); @@ -126,7 +128,8 @@ class SSLClientSocketWin : public SSLClientSocket { CompletionCallback user_connect_callback_; // User function to callback when a Read() completes. - OldCompletionCallback* user_read_callback_; + OldCompletionCallback* old_user_read_callback_; + CompletionCallback user_read_callback_; scoped_refptr<IOBuffer> user_read_buf_; int user_read_buf_len_; diff --git a/net/socket/ssl_server_socket_nss.cc b/net/socket/ssl_server_socket_nss.cc index 8ead679..0785dd7d 100644 --- a/net/socket/ssl_server_socket_nss.cc +++ b/net/socket/ssl_server_socket_nss.cc @@ -65,7 +65,7 @@ SSLServerSocketNSS::SSLServerSocketNSS( transport_send_busy_(false), transport_recv_busy_(false), user_handshake_callback_(NULL), - user_read_callback_(NULL), + old_user_read_callback_(NULL), user_write_callback_(NULL), nss_fd_(NULL), nss_bufs_(NULL), @@ -154,7 +154,29 @@ int SSLServerSocketNSS::Connect(const CompletionCallback& callback) { int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { - DCHECK(!user_read_callback_); + DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); + DCHECK(!user_handshake_callback_); + DCHECK(!user_read_buf_); + DCHECK(nss_bufs_); + + user_read_buf_ = buf; + user_read_buf_len_ = buf_len; + + DCHECK(completed_handshake_); + + int rv = DoReadLoop(OK); + + if (rv == ERR_IO_PENDING) { + old_user_read_callback_ = callback; + } else { + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + } + return rv; +} +int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); DCHECK(!user_handshake_callback_); DCHECK(!user_read_buf_); DCHECK(nss_bufs_); @@ -717,15 +739,23 @@ void SSLServerSocketNSS::DoHandshakeCallback(int rv) { void SSLServerSocketNSS::DoReadCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_read_callback_); + DCHECK(old_user_read_callback_ || !user_read_callback_.is_null()); // Since Run may result in Read being called, clear |user_read_callback_| // up front. - OldCompletionCallback* c = user_read_callback_; - user_read_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c->Run(rv); + if (old_user_read_callback_) { + OldCompletionCallback* c = old_user_read_callback_; + old_user_read_callback_ = NULL; + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_read_callback_; + user_read_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c.Run(rv); + } } void SSLServerSocketNSS::DoWriteCallback(int rv) { diff --git a/net/socket/ssl_server_socket_nss.h b/net/socket/ssl_server_socket_nss.h index 7967ffa..39283f6 100644 --- a/net/socket/ssl_server_socket_nss.h +++ b/net/socket/ssl_server_socket_nss.h @@ -41,6 +41,8 @@ class SSLServerSocketNSS : public SSLServerSocket { // Socket interface (via StreamSocket). virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; @@ -109,7 +111,8 @@ class SSLServerSocketNSS : public SSLServerSocket { BoundNetLog net_log_; OldCompletionCallback* user_handshake_callback_; - OldCompletionCallback* user_read_callback_; + OldCompletionCallback* old_user_read_callback_; + CompletionCallback user_read_callback_; OldCompletionCallback* user_write_callback_; // Used by Read function. diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc index 5af50f8..eb9dc7c 100644 --- a/net/socket/ssl_server_socket_unittest.cc +++ b/net/socket/ssl_server_socket_unittest.cc @@ -51,7 +51,7 @@ namespace { class FakeDataChannel { public: FakeDataChannel() - : read_callback_(NULL), + : old_read_callback_(NULL), read_buf_len_(0), ALLOW_THIS_IN_INITIALIZER_LIST(task_factory_(this)) { } @@ -59,6 +59,16 @@ class FakeDataChannel { virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { if (data_.empty()) { + old_read_callback_ = callback; + read_buf_ = buf; + read_buf_len_ = buf_len; + return net::ERR_IO_PENDING; + } + return PropogateData(buf, buf_len); + } + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + if (data_.empty()) { read_callback_ = callback; read_buf_ = buf; read_buf_len_ = buf_len; @@ -78,15 +88,23 @@ class FakeDataChannel { private: void DoReadCallback() { - if (!read_callback_ || data_.empty()) + if ((!old_read_callback_ && read_callback_.is_null()) || data_.empty()) return; int copied = PropogateData(read_buf_, read_buf_len_); - net::OldCompletionCallback* callback = read_callback_; - read_callback_ = NULL; - read_buf_ = NULL; - read_buf_len_ = 0; - callback->Run(copied); + if (old_read_callback_) { + net::OldCompletionCallback* callback = old_read_callback_; + old_read_callback_ = NULL; + read_buf_ = NULL; + read_buf_len_ = 0; + callback->Run(copied); + } else { + net::CompletionCallback callback = read_callback_; + read_callback_.Reset(); + read_buf_ = NULL; + read_buf_len_ = 0; + callback.Run(copied); + } } int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) { @@ -100,7 +118,8 @@ class FakeDataChannel { return copied; } - net::OldCompletionCallback* read_callback_; + net::OldCompletionCallback* old_read_callback_; + net::CompletionCallback read_callback_; scoped_refptr<net::IOBuffer> read_buf_; int read_buf_len_; @@ -128,6 +147,12 @@ class FakeSocket : public StreamSocket { buf_len = rand() % buf_len + 1; return incoming_->Read(buf, buf_len, callback); } + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + // Read random number of bytes. + buf_len = rand() % buf_len + 1; + return incoming_->Read(buf, buf_len, callback); + } virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { diff --git a/net/socket/tcp_client_socket_libevent.cc b/net/socket/tcp_client_socket_libevent.cc index 7c8af80..3c99ae5 100644 --- a/net/socket/tcp_client_socket_libevent.cc +++ b/net/socket/tcp_client_socket_libevent.cc @@ -130,7 +130,7 @@ TCPClientSocketLibevent::TCPClientSocketLibevent( current_ai_(NULL), read_watcher_(this), write_watcher_(this), - read_callback_(NULL), + old_read_callback_(NULL), old_write_callback_(NULL), next_connect_state_(CONNECT_STATE_NONE), connect_os_error_(0), @@ -465,7 +465,7 @@ int TCPClientSocketLibevent::Read(IOBuffer* buf, DCHECK(CalledOnValidThread()); DCHECK_NE(kInvalidSocket, socket_); DCHECK(!waiting_connect()); - DCHECK(!read_callback_); + DCHECK(!old_read_callback_ && read_callback_.is_null()); // Synchronous operation not supported DCHECK(callback); DCHECK_GT(buf_len, 0); @@ -495,6 +495,45 @@ int TCPClientSocketLibevent::Read(IOBuffer* buf, read_buf_ = buf; read_buf_len_ = buf_len; + old_read_callback_ = callback; + return ERR_IO_PENDING; +} +int TCPClientSocketLibevent::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(kInvalidSocket, socket_); + DCHECK(!waiting_connect()); + DCHECK(!old_read_callback_ && read_callback_.is_null()); + // Synchronous operation not supported + DCHECK(!callback.is_null()); + DCHECK_GT(buf_len, 0); + + int nread = HANDLE_EINTR(read(socket_, buf->data(), buf_len)); + if (nread >= 0) { + base::StatsCounter read_bytes("tcp.read_bytes"); + read_bytes.Add(nread); + num_bytes_read_ += static_cast<int64>(nread); + if (nread > 0) + use_history_.set_was_used_to_convey_data(); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, nread, + buf->data()); + return nread; + } + if (errno != EAGAIN && errno != EWOULDBLOCK) { + DVLOG(1) << "read failed, errno " << errno; + return MapSystemError(errno); + } + + if (!MessageLoopForIO::current()->WatchFileDescriptor( + socket_, true, MessageLoopForIO::WATCH_READ, + &read_socket_watcher_, &read_watcher_)) { + DVLOG(1) << "WatchFileDescriptor failed on read, errno " << errno; + return MapSystemError(errno); + } + + read_buf_ = buf; + read_buf_len_ = buf_len; read_callback_ = callback; return ERR_IO_PENDING; } @@ -618,12 +657,18 @@ void TCPClientSocketLibevent::LogConnectCompletion(int net_error) { void TCPClientSocketLibevent::DoReadCallback(int rv) { DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(read_callback_); + DCHECK(old_read_callback_ || !read_callback_.is_null()); // since Run may result in Read being called, clear read_callback_ up front. - OldCompletionCallback* c = read_callback_; - read_callback_ = NULL; - c->Run(rv); + if (old_read_callback_) { + OldCompletionCallback* c = old_read_callback_; + old_read_callback_ = NULL; + c->Run(rv); + } else { + CompletionCallback c = read_callback_; + read_callback_.Reset(); + c.Run(rv); + } } void TCPClientSocketLibevent::DoWriteCallback(int rv) { diff --git a/net/socket/tcp_client_socket_libevent.h b/net/socket/tcp_client_socket_libevent.h index ac73f2c..47f19a0 100644 --- a/net/socket/tcp_client_socket_libevent.h +++ b/net/socket/tcp_client_socket_libevent.h @@ -64,6 +64,9 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; @@ -85,7 +88,7 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, // MessageLoopForIO::Watcher methods virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE { - if (socket_->read_callback_) + if (socket_->old_read_callback_) socket_->DidCompleteRead(); } @@ -176,7 +179,8 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, int write_buf_len_; // External callback; called when read is complete. - OldCompletionCallback* read_callback_; + OldCompletionCallback* old_read_callback_; + CompletionCallback read_callback_; // External callback; called when write is complete. OldCompletionCallback* old_write_callback_; diff --git a/net/socket/tcp_client_socket_win.cc b/net/socket/tcp_client_socket_win.cc index 1267d06..cb59a98 100644 --- a/net/socket/tcp_client_socket_win.cc +++ b/net/socket/tcp_client_socket_win.cc @@ -744,6 +744,47 @@ int TCPClientSocketWin::Read(IOBuffer* buf, core_->read_iobuffer_ = buf; return ERR_IO_PENDING; } +int TCPClientSocketWin::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(socket_, INVALID_SOCKET); + DCHECK(!waiting_read_); + DCHECK(!old_read_callback_ && read_callback_.is_null()); + DCHECK(!core_->read_iobuffer_); + + buf_len = core_->ThrottleReadSize(buf_len); + + core_->read_buffer_.len = buf_len; + core_->read_buffer_.buf = buf->data(); + + // TODO(wtc): Remove the assertion after enough testing. + AssertEventNotSignaled(core_->read_overlapped_.hEvent); + DWORD num, flags = 0; + int rv = WSARecv(socket_, &core_->read_buffer_, 1, &num, &flags, + &core_->read_overlapped_, NULL); + if (rv == 0) { + if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) { + base::StatsCounter read_bytes("tcp.read_bytes"); + read_bytes.Add(num); + num_bytes_read_ += num; + if (num > 0) + use_history_.set_was_used_to_convey_data(); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, num, + core_->read_buffer_.buf); + return static_cast<int>(num); + } + } else { + int os_error = WSAGetLastError(); + if (os_error != WSA_IO_PENDING) + return MapSystemError(os_error); + } + core_->WatchForRead(); + waiting_read_ = true; + read_callback_ = callback; + core_->read_iobuffer_ = buf; + return ERR_IO_PENDING; +} int TCPClientSocketWin::Write(IOBuffer* buf, int buf_len, diff --git a/net/socket/tcp_client_socket_win.h b/net/socket/tcp_client_socket_win.h index bda2585..1e75933 100644 --- a/net/socket/tcp_client_socket_win.h +++ b/net/socket/tcp_client_socket_win.h @@ -61,6 +61,8 @@ class NET_EXPORT TCPClientSocketWin : public StreamSocket, // Multiple outstanding requests are not supported. // Full duplex mode (reading and writing at the same time) is supported virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual bool SetReceiveBufferSize(int32 size); diff --git a/net/socket/transport_client_socket_pool_unittest.cc b/net/socket/transport_client_socket_pool_unittest.cc index 6604727..56b1fa9c 100644 --- a/net/socket/transport_client_socket_pool_unittest.cc +++ b/net/socket/transport_client_socket_pool_unittest.cc @@ -94,11 +94,15 @@ class MockClientSocket : public StreamSocket { return base::TimeDelta::FromMicroseconds(-1); } - // Socket methods: + // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { return ERR_FAILED; } + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + return ERR_FAILED; + } virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { return ERR_FAILED; @@ -151,11 +155,15 @@ class MockFailingClientSocket : public StreamSocket { return base::TimeDelta::FromMicroseconds(-1); } - // Socket methods: + // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { return ERR_FAILED; } + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + return ERR_FAILED; + } virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { @@ -238,11 +246,15 @@ class MockPendingClientSocket : public StreamSocket { return base::TimeDelta::FromMicroseconds(-1); } - // Socket methods: + // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { return ERR_FAILED; } + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + return ERR_FAILED; + } virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { diff --git a/net/socket/web_socket_server_socket.cc b/net/socket/web_socket_server_socket.cc index 4bb7b9f..d792689 100644 --- a/net/socket/web_socket_server_socket.cc +++ b/net/socket/web_socket_server_socket.cc @@ -135,8 +135,11 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { it->type == PendingReq::TYPE_READ && it->io_buf != NULL && it->io_buf->data() != NULL && - it->callback != 0) { - it->callback->Run(0); // Report EOF. + (it->old_callback || !it->callback.is_null())) { + if (it->old_callback) + it->old_callback->Run(0); // Report EOF. + else + it->callback.Run(0); } } @@ -175,6 +178,26 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { net::OldCompletionCallback* callback) : type(type), io_buf(io_buf), + old_callback(callback) { + switch (type) { + case PendingReq::TYPE_READ: + case PendingReq::TYPE_WRITE: + case PendingReq::TYPE_READ_METADATA: + case PendingReq::TYPE_WRITE_METADATA: { + DCHECK(io_buf); + break; + } + default: { + NOTREACHED(); + break; + } + } + } + PendingReq(Type type, net::DrainableIOBuffer* io_buf, + const net::CompletionCallback& callback) + : type(type), + io_buf(io_buf), + old_callback(NULL), callback(callback) { switch (type) { case PendingReq::TYPE_READ: @@ -193,7 +216,8 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { Type type; scoped_refptr<net::DrainableIOBuffer> io_buf; - net::OldCompletionCallback* callback; + net::OldCompletionCallback* old_callback; + net::CompletionCallback callback; }; // Socket implementation. @@ -261,6 +285,70 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { } return net::ERR_IO_PENDING; } + virtual int Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE { + if (buf_len == 0) + return 0; + if (buf == NULL || buf_len < 0) { + NOTREACHED(); + return net::ERR_INVALID_ARGUMENT; + } + while (int bytes_remaining = fill_handshake_buf_->BytesConsumed() - + process_handshake_buf_->BytesConsumed()) { + DCHECK(!is_transport_read_pending_); + DCHECK(GetPendingReq(PendingReq::TYPE_READ) == pending_reqs_.end()); + switch (phase_) { + case PHASE_FRAME_OUTSIDE: + case PHASE_FRAME_INSIDE: + case PHASE_FRAME_LENGTH: + case PHASE_FRAME_SKIP: { + int n = std::min(bytes_remaining, buf_len); + int rv = ProcessDataFrames( + process_handshake_buf_->data(), n, buf->data(), buf_len); + process_handshake_buf_->DidConsume(n); + if (rv == 0) { + // ProcessDataFrames may return zero for non-empty buffer if it + // contains only frame delimiters without real data. In this case: + // try again and do not just return zero (zero stands for EOF). + continue; + } + return rv; + } + case PHASE_SHUT: { + return 0; + } + case PHASE_NYMPH: + case PHASE_HANDSHAKE: + default: { + NOTREACHED(); + return net::ERR_UNEXPECTED; + } + } + } + switch (phase_) { + case PHASE_FRAME_OUTSIDE: + case PHASE_FRAME_INSIDE: + case PHASE_FRAME_LENGTH: + case PHASE_FRAME_SKIP: { + pending_reqs_.push_back(PendingReq( + PendingReq::TYPE_READ, + new net::DrainableIOBuffer(buf, buf_len), + callback)); + ConsiderTransportRead(); + break; + } + case PHASE_SHUT: { + return 0; + } + case PHASE_NYMPH: + case PHASE_HANDSHAKE: + default: { + NOTREACHED(); + return net::ERR_UNEXPECTED; + } + } + return net::ERR_IO_PENDING; + } virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE { @@ -397,8 +485,10 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { if (result != 0) { while (!pending_reqs_.empty()) { PendingReq& req = pending_reqs_.front(); - if (req.callback) - req.callback->Run(result); + if (req.old_callback) + req.old_callback->Run(result); + else if (!req.callback.is_null()) + req.callback.Run(result); pending_reqs_.pop_front(); } transport_socket_.reset(); // terminate underlying connection. @@ -447,11 +537,15 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { if (rv > 0) { process_handshake_buf_->DidConsume(rv); phase_ = PHASE_FRAME_OUTSIDE; - net::OldCompletionCallback* cb = pending_reqs_.front().callback; + net::OldCompletionCallback* old_cb = + pending_reqs_.front().old_callback; + net::CompletionCallback cb = pending_reqs_.front().callback; pending_reqs_.pop_front(); ConsiderTransportWrite(); // Schedule answer handshake. - if (cb) - cb->Run(0); + if (old_cb) + old_cb->Run(0); + else if (!cb.is_null()) + cb.Run(0); } else if (rv == net::ERR_IO_PENDING) { if (fill_handshake_buf_->BytesRemaining() < 1) Shut(net::ERR_LIMIT_VIOLATION); @@ -474,10 +568,13 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { return; } if (rv > 0 || phase_ == PHASE_SHUT) { - net::OldCompletionCallback* cb = it->callback; + net::OldCompletionCallback* old_cb = it->old_callback; + net::CompletionCallback cb = it->callback; pending_reqs_.erase(it); - if (cb) - cb->Run(rv); + if (old_cb) + old_cb->Run(rv); + else if (!cb.is_null()) + cb.Run(rv); } break; } @@ -515,12 +612,15 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { DCHECK_LE(result, it->io_buf->BytesRemaining()); it->io_buf->DidConsume(result); if (it->io_buf->BytesRemaining() == 0) { - net::OldCompletionCallback* cb = it->callback; + net::OldCompletionCallback* old_cb = it->old_callback; + net::CompletionCallback cb = it->callback; int bytes_written = it->io_buf->BytesConsumed(); DCHECK_GT(bytes_written, 0); pending_reqs_.erase(it); - if (cb) - cb->Run(bytes_written); + if (old_cb) + old_cb->Run(bytes_written); + else if (!cb.is_null()) + cb.Run(bytes_written); } ConsiderTransportWrite(); } diff --git a/net/socket/web_socket_server_socket_unittest.cc b/net/socket/web_socket_server_socket_unittest.cc index 476fc33..cabb4b9 100644 --- a/net/socket/web_socket_server_socket_unittest.cc +++ b/net/socket/web_socket_server_socket_unittest.cc @@ -79,16 +79,22 @@ class TestingTransportSocket : public net::Socket { net::DrainableIOBuffer* sample, net::DrainableIOBuffer* answer) : sample_(sample), answer_(answer), - final_read_callback_(NULL), + old_final_read_callback_(NULL), method_factory_(this) { } ~TestingTransportSocket() { - if (final_read_callback_) { + if (old_final_read_callback_) { MessageLoop::current()->PostTask(FROM_HERE, method_factory_.NewRunnableMethod( + &TestingTransportSocket::DoOldReadCallback, + old_final_read_callback_, 0)); + } else if (!final_read_callback_.is_null()) { + MessageLoop::current()->PostTask( + FROM_HERE, + method_factory_.NewRunnableMethod( &TestingTransportSocket::DoReadCallback, - final_read_callback_, 0)); + final_read_callback_, 0)); } } @@ -98,7 +104,28 @@ class TestingTransportSocket : public net::Socket { CHECK_GT(buf_len, 0); int remaining = sample_->BytesRemaining(); if (remaining < 1) { - if (final_read_callback_) + if (old_final_read_callback_ || !final_read_callback_.is_null()) + return 0; + old_final_read_callback_ = callback; + return net::ERR_IO_PENDING; + } + int lot = GetRand(1, std::min(remaining, buf_len)); + std::copy(sample_->data(), sample_->data() + lot, buf->data()); + sample_->DidConsume(lot); + if (GetRand(0, 1)) { + return lot; + } + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( + &TestingTransportSocket::DoOldReadCallback, callback, lot)); + return net::ERR_IO_PENDING; + } + virtual int Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + CHECK_GT(buf_len, 0); + int remaining = sample_->BytesRemaining(); + if (remaining < 1) { + if (old_final_read_callback_ || !final_read_callback_.is_null()) return 0; final_read_callback_ = callback; return net::ERR_IO_PENDING; @@ -144,16 +171,26 @@ class TestingTransportSocket : public net::Socket { net::DrainableIOBuffer* answer() { return answer_.get(); } - void DoReadCallback(net::OldCompletionCallback* callback, int result) { + void DoOldReadCallback(net::OldCompletionCallback* callback, int result) { if (result == 0 && !is_closed_) { MessageLoop::current()->PostTask(FROM_HERE, method_factory_.NewRunnableMethod( - &TestingTransportSocket::DoReadCallback, callback, 0)); + &TestingTransportSocket::DoOldReadCallback, callback, 0)); } else { if (callback) callback->Run(result); } } + void DoReadCallback(const net::CompletionCallback& callback, int result) { + if (result == 0 && !is_closed_) { + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( + &TestingTransportSocket::DoReadCallback, callback, 0)); + } else { + if (!callback.is_null()) + callback.Run(result); + } + } void DoWriteCallback(net::OldCompletionCallback* callback, int result) { if (callback) @@ -169,7 +206,8 @@ class TestingTransportSocket : public net::Socket { scoped_refptr<net::DrainableIOBuffer> answer_; // Final read callback to report zero (zero stands for EOF). - net::OldCompletionCallback* final_read_callback_; + net::OldCompletionCallback* old_final_read_callback_; + net::CompletionCallback final_read_callback_; ScopedRunnableMethodFactory<TestingTransportSocket> method_factory_; }; |