diff options
59 files changed, 1060 insertions, 170 deletions
diff --git a/content/browser/renderer_host/p2p/socket_host_test_utils.h b/content/browser/renderer_host/p2p/socket_host_test_utils.h index 0f1aeb4..79b4794 100644 --- a/content/browser/renderer_host/p2p/socket_host_test_utils.h +++ b/content/browser/renderer_host/p2p/socket_host_test_utils.h @@ -62,6 +62,8 @@ class FakeSocket : public net::StreamSocket { const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; @@ -179,6 +181,15 @@ int FakeSocket::Write(net::IOBuffer* buf, int buf_len, } return buf_len; } +int FakeSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + DCHECK(buf); + if (written_data_) { + written_data_->insert(written_data_->end(), + buf->data(), buf->data() + buf_len); + } + return buf_len; +} bool FakeSocket::SetReceiveBufferSize(int32 size) { diff --git a/jingle/glue/channel_socket_adapter.cc b/jingle/glue/channel_socket_adapter.cc index 3e22a35..39b3eb45c 100644 --- a/jingle/glue/channel_socket_adapter.cc +++ b/jingle/glue/channel_socket_adapter.cc @@ -19,7 +19,7 @@ TransportChannelSocketAdapter::TransportChannelSocketAdapter( : message_loop_(MessageLoop::current()), channel_(channel), old_read_callback_(NULL), - write_callback_(NULL), + old_write_callback_(NULL), closed_error_code_(net::OK) { DCHECK(channel_); @@ -73,11 +73,47 @@ int TransportChannelSocketAdapter::Read( } int TransportChannelSocketAdapter::Write( - net::IOBuffer* buffer, int buffer_size, net::OldCompletionCallback* callback) { + net::IOBuffer* buffer, int buffer_size, + net::OldCompletionCallback* callback) { DCHECK_EQ(MessageLoop::current(), message_loop_); DCHECK(buffer); DCHECK(callback); - CHECK(!write_callback_); + CHECK(!old_write_callback_ && write_callback_.is_null()); + + if (!channel_) { + DCHECK(closed_error_code_ != net::OK); + return closed_error_code_; + } + + int result; + if (channel_->writable()) { + result = channel_->SendPacket(buffer->data(), buffer_size); + if (result < 0) { + result = net::MapSystemError(channel_->GetError()); + + // If the underlying socket returns IO pending where it shouldn't we + // pretend the packet is dropped and return as succeeded because no + // writeable callback will happen. + if (result == net::ERR_IO_PENDING) + result = net::OK; + } + } else { + // Channel is not writable yet. + result = net::ERR_IO_PENDING; + old_write_callback_ = callback; + write_buffer_ = buffer; + write_buffer_size_ = buffer_size; + } + + return result; +} +int TransportChannelSocketAdapter::Write( + net::IOBuffer* buffer, int buffer_size, + const net::CompletionCallback& callback) { + DCHECK_EQ(MessageLoop::current(), message_loop_); + DCHECK(buffer); + DCHECK(!callback.is_null()); + CHECK(!old_write_callback_ && write_callback_.is_null()); if (!channel_) { DCHECK(closed_error_code_ != net::OK); @@ -141,11 +177,16 @@ void TransportChannelSocketAdapter::Close(int error_code) { callback.Run(error_code); } - if (write_callback_) { - net::OldCompletionCallback* callback = write_callback_; - write_callback_ = NULL; + if (old_write_callback_) { + net::OldCompletionCallback* callback = old_write_callback_; + old_write_callback_ = NULL; write_buffer_ = NULL; callback->Run(error_code); + } else if (!write_callback_.is_null()) { + net::CompletionCallback callback = write_callback_; + write_callback_.Reset(); + write_buffer_ = NULL; + callback.Run(error_code); } } @@ -186,17 +227,24 @@ void TransportChannelSocketAdapter::OnWritableState( cricket::TransportChannel* channel) { DCHECK_EQ(MessageLoop::current(), message_loop_); // Try to send the packet if there is a pending write. - if (write_callback_) { + if (old_write_callback_ || !write_callback_.is_null()) { int result = channel_->SendPacket(write_buffer_->data(), write_buffer_size_); if (result < 0) result = net::MapSystemError(channel_->GetError()); if (result != net::ERR_IO_PENDING) { - net::OldCompletionCallback* callback = write_callback_; - write_callback_ = NULL; - write_buffer_ = NULL; - callback->Run(result); + if (old_write_callback_) { + net::OldCompletionCallback* callback = old_write_callback_; + old_write_callback_ = NULL; + write_buffer_ = NULL; + callback->Run(result); + } else { + net::CompletionCallback callback = write_callback_; + write_callback_.Reset(); + write_buffer_ = NULL; + callback.Run(result); + } } } } diff --git a/jingle/glue/channel_socket_adapter.h b/jingle/glue/channel_socket_adapter.h index 1f367e8..75af022 100644 --- a/jingle/glue/channel_socket_adapter.h +++ b/jingle/glue/channel_socket_adapter.h @@ -41,6 +41,8 @@ class TransportChannelSocketAdapter : public net::Socket, const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -60,7 +62,8 @@ class TransportChannelSocketAdapter : public net::Socket, scoped_refptr<net::IOBuffer> read_buffer_; int read_buffer_size_; - net::OldCompletionCallback* write_callback_; // Not owned. + net::OldCompletionCallback* old_write_callback_; // Not owned. + net::CompletionCallback write_callback_; scoped_refptr<net::IOBuffer> write_buffer_; int write_buffer_size_; diff --git a/jingle/glue/pseudotcp_adapter.cc b/jingle/glue/pseudotcp_adapter.cc index 7807eed..84b0dcc 100644 --- a/jingle/glue/pseudotcp_adapter.cc +++ b/jingle/glue/pseudotcp_adapter.cc @@ -35,6 +35,8 @@ class PseudoTcpAdapter::Core : public cricket::IPseudoTcpNotify, const net::CompletionCallback& callback); int Write(net::IOBuffer* buffer, int buffer_size, net::OldCompletionCallback* callback); + int Write(net::IOBuffer* buffer, int buffer_size, + const net::CompletionCallback& callback); int Connect(net::OldCompletionCallback* callback); int Connect(const net::CompletionCallback& callback); void Disconnect(); @@ -75,7 +77,8 @@ class PseudoTcpAdapter::Core : public cricket::IPseudoTcpNotify, net::CompletionCallback connect_callback_; net::OldCompletionCallback* old_read_callback_; net::CompletionCallback read_callback_; - net::OldCompletionCallback* write_callback_; + net::OldCompletionCallback* old_write_callback_; + net::CompletionCallback write_callback_; cricket::PseudoTcp pseudo_tcp_; scoped_ptr<net::Socket> socket_; @@ -100,7 +103,7 @@ class PseudoTcpAdapter::Core : public cricket::IPseudoTcpNotify, PseudoTcpAdapter::Core::Core(net::Socket* socket) : old_connect_callback_(NULL), old_read_callback_(NULL), - write_callback_(NULL), + old_write_callback_(NULL), ALLOW_THIS_IN_INITIALIZER_LIST(pseudo_tcp_(this, 0)), socket_(socket), socket_write_pending_(false), @@ -164,7 +167,30 @@ int PseudoTcpAdapter::Core::Read(net::IOBuffer* buffer, int buffer_size, int PseudoTcpAdapter::Core::Write(net::IOBuffer* buffer, int buffer_size, net::OldCompletionCallback* callback) { - DCHECK(!write_callback_); + DCHECK(!old_write_callback_ && write_callback_.is_null()); + + // Reference the Core in case a callback deletes the adapter. + scoped_refptr<Core> core(this); + + int result = pseudo_tcp_.Send(buffer->data(), buffer_size); + if (result < 0) { + result = net::MapSystemError(pseudo_tcp_.GetError()); + DCHECK(result < 0); + } + + if (result == net::ERR_IO_PENDING) { + write_buffer_ = buffer; + write_buffer_size_ = buffer_size; + old_write_callback_ = callback; + } + + AdjustClock(); + + return result; +} +int PseudoTcpAdapter::Core::Write(net::IOBuffer* buffer, int buffer_size, + const net::CompletionCallback& callback) { + DCHECK(!old_write_callback_ && write_callback_.is_null()); // Reference the Core in case a callback deletes the adapter. scoped_refptr<Core> core(this); @@ -231,7 +257,8 @@ void PseudoTcpAdapter::Core::Disconnect() { old_read_callback_ = NULL; read_callback_.Reset(); read_buffer_ = NULL; - write_callback_ = NULL; + old_write_callback_ = NULL; + write_callback_.Reset(); write_buffer_ = NULL; old_connect_callback_ = NULL; connect_callback_.Reset(); @@ -297,7 +324,7 @@ void PseudoTcpAdapter::Core::OnTcpReadable(PseudoTcp* tcp) { void PseudoTcpAdapter::Core::OnTcpWriteable(PseudoTcp* tcp) { DCHECK_EQ(tcp, &pseudo_tcp_); - if (!write_callback_) + if (!old_write_callback_ && write_callback_.is_null()) return; int result = pseudo_tcp_.Send(write_buffer_->data(), write_buffer_size_); @@ -310,10 +337,17 @@ void PseudoTcpAdapter::Core::OnTcpWriteable(PseudoTcp* tcp) { AdjustClock(); - net::OldCompletionCallback* callback = write_callback_; - write_callback_ = NULL; - write_buffer_ = NULL; - callback->Run(result); + if (old_write_callback_) { + net::OldCompletionCallback* callback = old_write_callback_; + old_write_callback_ = NULL; + write_buffer_ = NULL; + callback->Run(result); + } else { + net::CompletionCallback callback = write_callback_; + write_callback_.Reset(); + write_buffer_ = NULL; + callback.Run(result); + } } void PseudoTcpAdapter::Core::OnTcpClosed(PseudoTcp* tcp, uint32 error) { @@ -339,10 +373,14 @@ void PseudoTcpAdapter::Core::OnTcpClosed(PseudoTcp* tcp, uint32 error) { callback.Run(net::MapSystemError(error)); } - if (write_callback_) { - net::OldCompletionCallback* callback = write_callback_; - write_callback_ = NULL; + if (old_write_callback_) { + net::OldCompletionCallback* callback = old_write_callback_; + old_write_callback_ = NULL; callback->Run(net::MapSystemError(error)); + } else if (!write_callback_.is_null()) { + net::CompletionCallback callback = write_callback_; + write_callback_.Reset(); + callback.Run(net::MapSystemError(error)); } } @@ -480,6 +518,11 @@ int PseudoTcpAdapter::Write(net::IOBuffer* buffer, int buffer_size, DCHECK(CalledOnValidThread()); return core_->Write(buffer, buffer_size, callback); } +int PseudoTcpAdapter::Write(net::IOBuffer* buffer, int buffer_size, + const net::CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + return core_->Write(buffer, buffer_size, callback); +} bool PseudoTcpAdapter::SetReceiveBufferSize(int32 size) { DCHECK(CalledOnValidThread()); diff --git a/jingle/glue/pseudotcp_adapter.h b/jingle/glue/pseudotcp_adapter.h index f0d27ee..6d3cf8c 100644 --- a/jingle/glue/pseudotcp_adapter.h +++ b/jingle/glue/pseudotcp_adapter.h @@ -36,6 +36,8 @@ class PseudoTcpAdapter : public net::StreamSocket, base::NonThreadSafe { const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buffer, int buffer_size, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buffer, int buffer_size, + const net::CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; diff --git a/jingle/glue/pseudotcp_adapter_unittest.cc b/jingle/glue/pseudotcp_adapter_unittest.cc index 6d803a5..94b403f 100644 --- a/jingle/glue/pseudotcp_adapter_unittest.cc +++ b/jingle/glue/pseudotcp_adapter_unittest.cc @@ -174,6 +174,20 @@ class FakeSocket : public net::Socket { return buf_len; } + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE { + DCHECK(buf); + if (peer_socket_) { + MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(&FakeSocket::AppendInputPacket, + base::Unretained(peer_socket_), + std::vector<char>(buf->data(), buf->data() + buf_len)), + latency_ms_); + } + + return buf_len; + } virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { NOTIMPLEMENTED(); diff --git a/jingle/notifier/base/fake_ssl_client_socket.cc b/jingle/notifier/base/fake_ssl_client_socket.cc index bdad879..9c13f52 100644 --- a/jingle/notifier/base/fake_ssl_client_socket.cc +++ b/jingle/notifier/base/fake_ssl_client_socket.cc @@ -112,7 +112,13 @@ int FakeSSLClientSocket::Read(net::IOBuffer* buf, int buf_len, } int FakeSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) { + net::OldCompletionCallback* callback) { + DCHECK_EQ(next_handshake_state_, STATE_NONE); + DCHECK(handshake_completed_); + return transport_socket_->Write(buf, buf_len, callback); +} +int FakeSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { DCHECK_EQ(next_handshake_state_, STATE_NONE); DCHECK(handshake_completed_); return transport_socket_->Write(buf, buf_len, callback); diff --git a/jingle/notifier/base/fake_ssl_client_socket.h b/jingle/notifier/base/fake_ssl_client_socket.h index 623e21b..047923d 100644 --- a/jingle/notifier/base/fake_ssl_client_socket.h +++ b/jingle/notifier/base/fake_ssl_client_socket.h @@ -52,6 +52,8 @@ class FakeSSLClientSocket : public net::StreamSocket { const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; diff --git a/jingle/notifier/base/fake_ssl_client_socket_unittest.cc b/jingle/notifier/base/fake_ssl_client_socket_unittest.cc index eb3ba5a..28ee5db 100644 --- a/jingle/notifier/base/fake_ssl_client_socket_unittest.cc +++ b/jingle/notifier/base/fake_ssl_client_socket_unittest.cc @@ -51,6 +51,8 @@ class MockClientSocket : public net::StreamSocket { MOCK_METHOD3(Read, int(net::IOBuffer*, int, const net::CompletionCallback&)); MOCK_METHOD3(Write, int(net::IOBuffer*, int, net::OldCompletionCallback*)); + MOCK_METHOD3(Write, int(net::IOBuffer*, int, + const net::CompletionCallback&)); MOCK_METHOD1(SetReceiveBufferSize, bool(int32)); MOCK_METHOD1(SetSendBufferSize, bool(int32)); MOCK_METHOD1(Connect, int(net::OldCompletionCallback*)); diff --git a/jingle/notifier/base/proxy_resolving_client_socket.cc b/jingle/notifier/base/proxy_resolving_client_socket.cc index 3d46ed6b..c7d54c2 100644 --- a/jingle/notifier/base/proxy_resolving_client_socket.cc +++ b/jingle/notifier/base/proxy_resolving_client_socket.cc @@ -89,6 +89,13 @@ int ProxyResolvingClientSocket::Write(net::IOBuffer* buf, int buf_len, NOTREACHED(); return net::ERR_SOCKET_NOT_CONNECTED; } +int ProxyResolvingClientSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + if (transport_.get() && transport_->socket()) + return transport_->socket()->Write(buf, buf_len, callback); + NOTREACHED(); + return net::ERR_SOCKET_NOT_CONNECTED; +} bool ProxyResolvingClientSocket::SetReceiveBufferSize(int32 size) { if (transport_.get() && transport_->socket()) diff --git a/jingle/notifier/base/proxy_resolving_client_socket.h b/jingle/notifier/base/proxy_resolving_client_socket.h index 4463364..6eae9f4 100644 --- a/jingle/notifier/base/proxy_resolving_client_socket.h +++ b/jingle/notifier/base/proxy_resolving_client_socket.h @@ -52,6 +52,8 @@ class ProxyResolvingClientSocket : public net::StreamSocket { const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; diff --git a/net/curvecp/curvecp_client_socket.cc b/net/curvecp/curvecp_client_socket.cc index 6772e38..d6889cf 100644 --- a/net/curvecp/curvecp_client_socket.cc +++ b/net/curvecp/curvecp_client_socket.cc @@ -116,6 +116,11 @@ int CurveCPClientSocket::Write(IOBuffer* buf, OldCompletionCallback* callback) { return messenger_.Write(buf, buf_len, callback); } +int CurveCPClientSocket::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + return messenger_.Write(buf, buf_len, callback); +} bool CurveCPClientSocket::SetReceiveBufferSize(int32 size) { return true; diff --git a/net/curvecp/curvecp_client_socket.h b/net/curvecp/curvecp_client_socket.h index 8062c87..90e2253 100644 --- a/net/curvecp/curvecp_client_socket.h +++ b/net/curvecp/curvecp_client_socket.h @@ -51,6 +51,9 @@ class CurveCPClientSocket : public StreamSocket { virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; diff --git a/net/curvecp/curvecp_server_socket.cc b/net/curvecp/curvecp_server_socket.cc index 6e7934c..b48f82f 100644 --- a/net/curvecp/curvecp_server_socket.cc +++ b/net/curvecp/curvecp_server_socket.cc @@ -63,6 +63,11 @@ int CurveCPServerSocket::Write(IOBuffer* buf, OldCompletionCallback* callback) { return messenger_.Write(buf, buf_len, callback); } +int CurveCPServerSocket::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + return messenger_.Write(buf, buf_len, callback); +} bool CurveCPServerSocket::SetReceiveBufferSize(int32 size) { return true; diff --git a/net/curvecp/curvecp_server_socket.h b/net/curvecp/curvecp_server_socket.h index a2ea1d7..bd005f3 100644 --- a/net/curvecp/curvecp_server_socket.h +++ b/net/curvecp/curvecp_server_socket.h @@ -43,6 +43,9 @@ class CurveCPServerSocket : public Socket, virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; diff --git a/net/curvecp/messenger.cc b/net/curvecp/messenger.cc index a99abda..fc9adf5 100644 --- a/net/curvecp/messenger.cc +++ b/net/curvecp/messenger.cc @@ -58,7 +58,7 @@ static const size_t kReceiveBufferSize = (128 * 1024); Messenger::Messenger(Packetizer* packetizer) : packetizer_(packetizer), send_buffer_(kSendBufferSize), - send_complete_callback_(NULL), + old_send_complete_callback_(NULL), old_receive_complete_callback_(NULL), pending_receive_length_(0), send_message_in_progress_(false), @@ -104,10 +104,31 @@ int Messenger::Read(IOBuffer* buf, int buf_len, return bytes_read; } -int Messenger::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { +int Messenger::Write( + IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(CalledOnValidThread()); DCHECK(!pending_send_.get()); // Already a write pending! - DCHECK(!send_complete_callback_); + DCHECK(!old_send_complete_callback_ && send_complete_callback_.is_null()); + DCHECK_LT(0, buf_len); + + int len = send_buffer_.write(buf->data(), buf_len); + if (!send_timer_.IsRunning()) + send_timer_.Start(FROM_HERE, base::TimeDelta(), + this, &Messenger::OnSendTimer); + if (len) + return len; + + // We couldn't add data to the send buffer, so block the application. + pending_send_ = buf; + pending_send_length_ = buf_len; + old_send_complete_callback_ = callback; + return ERR_IO_PENDING; +} +int Messenger::Write( + IOBuffer* buf, int buf_len, const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK(!pending_send_.get()); // Already a write pending! + DCHECK(!old_send_complete_callback_ && send_complete_callback_.is_null()); DCHECK_LT(0, buf_len); int len = send_buffer_.write(buf->data(), buf_len); @@ -168,15 +189,21 @@ IOBufferWithSize* Messenger::CreateBufferFromSendQueue() { DCHECK_EQ(bytes, length); // We consumed data, check to see if someone is waiting to write more data. - if (send_complete_callback_) { + if (old_send_complete_callback_ || !send_complete_callback_.is_null()) { DCHECK(pending_send_.get()); int len = send_buffer_.write(pending_send_->data(), pending_send_length_); if (len) { pending_send_ = NULL; - OldCompletionCallback* callback = send_complete_callback_; - send_complete_callback_ = NULL; - callback->Run(len); + if (old_send_complete_callback_) { + OldCompletionCallback* callback = old_send_complete_callback_; + old_send_complete_callback_ = NULL; + callback->Run(len); + } else { + CompletionCallback callback = send_complete_callback_; + send_complete_callback_.Reset(); + callback.Run(len); + } } } diff --git a/net/curvecp/messenger.h b/net/curvecp/messenger.h index bb67946..b71c684 100644 --- a/net/curvecp/messenger.h +++ b/net/curvecp/messenger.h @@ -36,6 +36,7 @@ class Messenger : public base::NonThreadSafe, int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback); // Packetizer::Listener implementation. virtual void OnConnection(ConnectionKey key) OVERRIDE; @@ -71,7 +72,8 @@ class Messenger : public base::NonThreadSafe, // The send_buffer is a list of pending data to pack into messages and send // to the remote. CircularBuffer send_buffer_; - OldCompletionCallback* send_complete_callback_; + OldCompletionCallback* old_send_complete_callback_; + CompletionCallback send_complete_callback_; scoped_refptr<IOBuffer> pending_send_; int pending_send_length_; diff --git a/net/http/http_proxy_client_socket.cc b/net/http/http_proxy_client_socket.cc index ef9fb82..8e72f7d 100644 --- a/net/http/http_proxy_client_socket.cc +++ b/net/http/http_proxy_client_socket.cc @@ -254,7 +254,14 @@ int HttpProxyClientSocket::Read(IOBuffer* buf, int buf_len, int HttpProxyClientSocket::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK_EQ(STATE_DONE, next_state_); - DCHECK(!old_user_callback_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); + + return transport_->socket()->Write(buf, buf_len, callback); +} +int HttpProxyClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK_EQ(STATE_DONE, next_state_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); return transport_->socket()->Write(buf, buf_len, callback); } diff --git a/net/http/http_proxy_client_socket.h b/net/http/http_proxy_client_socket.h index 662c305..6161bd2 100644 --- a/net/http/http_proxy_client_socket.h +++ b/net/http/http_proxy_client_socket.h @@ -84,6 +84,9 @@ class HttpProxyClientSocket : public ProxyClientSocket { virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; virtual int GetPeerAddress(AddressList* address) const OVERRIDE; diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index 5ef38a7..6405362 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -78,6 +78,11 @@ class MockClientSocket : public StreamSocket { was_used_to_convey_data_ = true; return len; } + virtual int Write( + IOBuffer* /* buf */, int len, const CompletionCallback& /* callback */) { + was_used_to_convey_data_ = true; + return len; + } virtual bool SetReceiveBufferSize(int32 size) { return true; } virtual bool SetSendBufferSize(int32 size) { return true; } diff --git a/net/socket/socket.h b/net/socket/socket.h index c185c44..b02de0a 100644 --- a/net/socket/socket.h +++ b/net/socket/socket.h @@ -48,6 +48,8 @@ class NET_EXPORT Socket { // Disconnected before the write completes, the callback will not be invoked. virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) = 0; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) = 0; // Set the receive buffer size (in bytes) for the socket. // Note: changing this value can affect the TCP window size on some platforms. diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 376d542..7730a5b 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -832,6 +832,26 @@ int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len, return write_result.result; } +int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!connected_) + return net::ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + net::MockWriteResult write_result = data_->OnWrite(data); + + was_used_to_convey_data_ = true; + + if (write_result.async) { + RunCallbackAsync(callback, write_result.result); + return net::ERR_IO_PENDING; + } + + return write_result.result; +} int MockTCPClientSocket::Connect(net::OldCompletionCallback* callback) { if (connected_) @@ -972,7 +992,7 @@ DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket( net::NetLog* net_log, net::DeterministicSocketData* data) : MockClientSocket(net_log), write_pending_(false), - write_callback_(NULL), + old_write_callback_(NULL), write_result_(0), read_data_(), read_buf_(NULL), @@ -987,7 +1007,10 @@ DeterministicMockTCPClientSocket::~DeterministicMockTCPClientSocket() {} void DeterministicMockTCPClientSocket::CompleteWrite() { was_used_to_convey_data_ = true; write_pending_ = false; - write_callback_->Run(write_result_); + if (old_write_callback_) + old_write_callback_->Run(write_result_); + else + write_callback_.Run(write_result_); } int DeterministicMockTCPClientSocket::CompleteRead() { @@ -1021,30 +1044,6 @@ int DeterministicMockTCPClientSocket::CompleteRead() { return result; } -int DeterministicMockTCPClientSocket::Write( - net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { - DCHECK(buf); - DCHECK_GT(buf_len, 0); - - if (!connected_) - return net::ERR_UNEXPECTED; - - std::string data(buf->data(), buf_len); - net::MockWriteResult write_result = data_->OnWrite(data); - - if (write_result.async) { - write_callback_ = callback; - write_result_ = write_result.result; - DCHECK(write_callback_ != NULL); - write_pending_ = true; - return net::ERR_IO_PENDING; - } - - was_used_to_convey_data_ = true; - write_pending_ = false; - return write_result.result; -} - int DeterministicMockTCPClientSocket::Read( net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { if (!connected_) @@ -1092,6 +1091,53 @@ int DeterministicMockTCPClientSocket::Read( return CompleteRead(); } +int DeterministicMockTCPClientSocket::Write( + net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!connected_) + return net::ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + net::MockWriteResult write_result = data_->OnWrite(data); + + if (write_result.async) { + old_write_callback_ = callback; + write_result_ = write_result.result; + DCHECK(old_write_callback_ != NULL); + write_pending_ = true; + return net::ERR_IO_PENDING; + } + + was_used_to_convey_data_ = true; + write_pending_ = false; + return write_result.result; +} +int DeterministicMockTCPClientSocket::Write( + net::IOBuffer* buf, int buf_len, const net::CompletionCallback& callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!connected_) + return net::ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + net::MockWriteResult write_result = data_->OnWrite(data); + + if (write_result.async) { + write_callback_ = callback; + write_result_ = write_result.result; + DCHECK(!write_callback_.is_null()); + write_pending_ = true; + return net::ERR_IO_PENDING; + } + + was_used_to_convey_data_ = true; + write_pending_ = false; + return write_result.result; +} + // TODO(erikchen): Support connect sequencing. int DeterministicMockTCPClientSocket::Connect( net::OldCompletionCallback* callback) { @@ -1235,6 +1281,10 @@ int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { return transport_->socket()->Write(buf, buf_len, callback); } +int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + return transport_->socket()->Write(buf, buf_len, callback); +} int MockSSLClientSocket::Connect(net::OldCompletionCallback* callback) { OldConnectCallback* connect_callback = new OldConnectCallback( @@ -1422,6 +1472,23 @@ int MockUDPClientSocket::Write(net::IOBuffer* buf, int buf_len, } return write_result.result; } +int MockUDPClientSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!connected_) + return ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + MockWriteResult write_result = data_->OnWrite(data); + + if (write_result.async) { + RunCallbackAsync(callback, write_result.result); + return ERR_IO_PENDING; + } + return write_result.result; +} bool MockUDPClientSocket::SetReceiveBufferSize(int32 size) { return true; diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 73ffc3d..1840bfc 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -592,6 +592,8 @@ class MockClientSocket : public net::SSLClientSocket { const net::CompletionCallback& callback) = 0; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) = 0; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) = 0; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -648,6 +650,8 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; // StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; @@ -703,12 +707,14 @@ class DeterministicMockTCPClientSocket : public MockClientSocket, int CompleteRead(); // Socket implementation. - virtual int Write(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) OVERRIDE; virtual int Read(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; virtual int Read(net::IOBuffer* buf, int buf_len, const net::CompletionCallback& callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; // StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; @@ -726,7 +732,8 @@ class DeterministicMockTCPClientSocket : public MockClientSocket, private: bool write_pending_; - net::OldCompletionCallback* write_callback_; + net::OldCompletionCallback* old_write_callback_; + net::CompletionCallback write_callback_; int write_result_; net::MockRead read_data_; @@ -757,6 +764,8 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; // StreamSocket implementation. virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; @@ -804,6 +813,8 @@ class MockUDPClientSocket : public DatagramClientSocket, const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; diff --git a/net/socket/socks5_client_socket.cc b/net/socket/socks5_client_socket.cc index ea5fc7a..3497e24 100644 --- a/net/socket/socks5_client_socket.cc +++ b/net/socket/socks5_client_socket.cc @@ -204,10 +204,18 @@ int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len, // Write is called by the transport layer. This can only be done if the // SOCKS handshake is complete. int SOCKS5ClientSocket::Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + OldCompletionCallback* callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); + + return transport_->socket()->Write(buf, buf_len, callback); +} +int SOCKS5ClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); return transport_->socket()->Write(buf, buf_len, callback); } diff --git a/net/socket/socks5_client_socket.h b/net/socket/socks5_client_socket.h index b83a347..3c0255d 100644 --- a/net/socket/socks5_client_socket.h +++ b/net/socket/socks5_client_socket.h @@ -74,6 +74,9 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc index 623f202..40f8453 100644 --- a/net/socket/socks_client_socket.cc +++ b/net/socket/socks_client_socket.cc @@ -235,7 +235,15 @@ int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); + + return transport_->socket()->Write(buf, buf_len, callback); +} +int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!old_user_callback_ && user_callback_.is_null()); return transport_->socket()->Write(buf, buf_len, callback); } diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h index 1a4a75c..0f28e10 100644 --- a/net/socket/socks_client_socket.h +++ b/net/socket/socks_client_socket.h @@ -71,6 +71,9 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; diff --git a/net/socket/ssl_client_socket_mac.cc b/net/socket/ssl_client_socket_mac.cc index f58d340..32f0a08 100644 --- a/net/socket/ssl_client_socket_mac.cc +++ b/net/socket/ssl_client_socket_mac.cc @@ -533,7 +533,7 @@ SSLClientSocketMac::SSLClientSocketMac(ClientSocketHandle* transport_socket, ssl_config_(ssl_config), old_user_connect_callback_(NULL), old_user_read_callback_(NULL), - user_write_callback_(NULL), + old_user_write_callback_(NULL), user_read_buf_len_(0), user_write_buf_len_(0), next_handshake_state_(STATE_NONE), @@ -737,7 +737,25 @@ int SSLClientSocketMac::Read(IOBuffer* buf, int buf_len, int SSLClientSocketMac::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake()); - DCHECK(!user_write_callback_); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); + DCHECK(!user_write_buf_); + + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoPayloadWrite(); + if (rv == ERR_IO_PENDING) { + old_user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + return rv; +} +int SSLClientSocketMac::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake()); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); DCHECK(!user_write_buf_); user_write_buf_ = buf; @@ -972,15 +990,23 @@ void SSLClientSocketMac::DoReadCallback(int rv) { void SSLClientSocketMac::DoWriteCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_write_callback_); + DCHECK(old_user_write_callback_ && !user_write_callback_.is_null()); // Since Run may result in Write being called, clear user_write_callback_ up // front. - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - c->Run(rv); + if (old_user_write_callback_) { + OldCompletionCallback* c = old_user_write_callback_; + old_user_write_callback_ = NULL; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c.Run(rv); + } } void SSLClientSocketMac::OnHandshakeIOComplete(int result) { diff --git a/net/socket/ssl_client_socket_mac.h b/net/socket/ssl_client_socket_mac.h index 7792cb3..661432b 100644 --- a/net/socket/ssl_client_socket_mac.h +++ b/net/socket/ssl_client_socket_mac.h @@ -77,6 +77,9 @@ class SSLClientSocketMac : public SSLClientSocket { virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -126,7 +129,8 @@ class SSLClientSocketMac : public SSLClientSocket { CompletionCallback user_connect_callback_; OldCompletionCallback* old_user_read_callback_; CompletionCallback user_read_callback_; - OldCompletionCallback* user_write_callback_; + OldCompletionCallback* old_user_write_callback_; + CompletionCallback user_write_callback_; // Used by Read function. scoped_refptr<IOBuffer> user_read_buf_; diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index 6c16c4a..cd40140 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -447,7 +447,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocketHandle* transport_socket, ssl_config_(ssl_config), old_user_connect_callback_(NULL), old_user_read_callback_(NULL), - user_write_callback_(NULL), + old_user_write_callback_(NULL), user_read_buf_len_(0), user_write_buf_len_(0), server_cert_nss_(NULL), @@ -576,7 +576,7 @@ int SSLClientSocketNSS::Connect(OldCompletionCallback* callback) { DCHECK(transport_.get()); DCHECK(next_handshake_state_ == STATE_NONE); DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); - DCHECK(!user_write_callback_); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); DCHECK(!user_read_buf_); DCHECK(!user_write_buf_); @@ -624,7 +624,7 @@ int SSLClientSocketNSS::Connect(const CompletionCallback& callback) { DCHECK(transport_.get()); DCHECK(next_handshake_state_ == STATE_NONE); DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); - DCHECK(!user_write_callback_); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); DCHECK(!user_read_buf_); DCHECK(!user_write_buf_); @@ -696,7 +696,8 @@ void SSLClientSocketNSS::Disconnect() { user_connect_callback_.Reset(); old_user_read_callback_ = NULL; user_read_callback_.Reset(); - user_write_callback_ = NULL; + old_user_write_callback_ = NULL; + user_write_callback_.Reset(); user_read_buf_ = NULL; user_read_buf_len_ = 0; user_write_buf_ = NULL; @@ -864,8 +865,36 @@ int SSLClientSocketNSS::Write(IOBuffer* buf, int buf_len, EnterFunction(buf_len); DCHECK(completed_handshake_); DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!user_write_callback_); - DCHECK(!old_user_connect_callback_); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); + DCHECK(!user_write_buf_); + DCHECK(nss_bufs_); + + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + if (corked_) { + corked_ = false; + uncork_timer_.Reset(); + } + int rv = DoWriteLoop(OK); + + if (rv == ERR_IO_PENDING) { + old_user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + LeaveFunction(rv); + return rv; +} +int SSLClientSocketNSS::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + EnterFunction(buf_len); + DCHECK(completed_handshake_); + DCHECK(next_handshake_state_ == STATE_NONE); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); + DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); DCHECK(!user_write_buf_); DCHECK(nss_bufs_); @@ -1247,15 +1276,23 @@ void SSLClientSocketNSS::DoReadCallback(int rv) { void SSLClientSocketNSS::DoWriteCallback(int rv) { EnterFunction(rv); DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_write_callback_); + DCHECK(old_user_write_callback_ || !user_write_callback_.is_null()); // Since Run may result in Write being called, clear |user_write_callback_| // up front. - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - c->Run(rv); + if (old_user_write_callback_) { + OldCompletionCallback* c = old_user_write_callback_; + old_user_write_callback_ = NULL; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c.Run(rv); + } LeaveFunction(""); } diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index 7e75cd5..22cb4a5 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -96,6 +96,9 @@ class SSLClientSocketNSS : public SSLClientSocket { virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -233,7 +236,8 @@ class SSLClientSocketNSS : public SSLClientSocket { CompletionCallback user_connect_callback_; OldCompletionCallback* old_user_read_callback_; CompletionCallback user_read_callback_; - OldCompletionCallback* user_write_callback_; + OldCompletionCallback* old_user_write_callback_; + CompletionCallback user_write_callback_; // Used by Read function. scoped_refptr<IOBuffer> user_read_buf_; diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc index 1237348..ea3b629 100644 --- a/net/socket/ssl_client_socket_openssl.cc +++ b/net/socket/ssl_client_socket_openssl.cc @@ -392,7 +392,7 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( transport_recv_busy_(false), old_user_connect_callback_(NULL), old_user_read_callback_(NULL), - user_write_callback_(NULL), + old_user_write_callback_(NULL), completed_handshake_(false), client_auth_cert_needed_(false), cert_verifier_(context.cert_verifier), @@ -632,11 +632,19 @@ void SSLClientSocketOpenSSL::DoReadCallback(int rv) { void SSLClientSocketOpenSSL::DoWriteCallback(int rv) { // Since Run may result in Write being called, clear |user_write_callback_| // up front. - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - c->Run(rv); + if (old_user_write_callback_) { + OldCompletionCallback* c = old_user_write_callback_; + old_user_write_callback_ = NULL; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c.Run(rv); + } } // StreamSocket methods @@ -712,7 +720,8 @@ void SSLClientSocketOpenSSL::Disconnect() { user_connect_callback_.Reset(); old_user_read_callback_ = NULL; user_read_callback_.Reset(); - user_write_callback_ = NULL; + old_user_write_callback_ = NULL; + user_write_callback_.Reset(); user_read_buf_ = NULL; user_read_buf_len_ = 0; user_write_buf_ = NULL; @@ -1246,6 +1255,23 @@ int SSLClientSocketOpenSSL::Write(IOBuffer* buf, int rv = DoWriteLoop(OK); if (rv == ERR_IO_PENDING) { + old_user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + + return rv; +} +int SSLClientSocketOpenSSL::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoWriteLoop(OK); + + if (rv == ERR_IO_PENDING) { user_write_callback_ = callback; } else { user_write_buf_ = NULL; diff --git a/net/socket/ssl_client_socket_openssl.h b/net/socket/ssl_client_socket_openssl.h index a15c0e3..74ee3bb 100644 --- a/net/socket/ssl_client_socket_openssl.h +++ b/net/socket/ssl_client_socket_openssl.h @@ -82,7 +82,10 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); - virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); @@ -126,7 +129,8 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { CompletionCallback user_connect_callback_; OldCompletionCallback* old_user_read_callback_; CompletionCallback user_read_callback_; - OldCompletionCallback* user_write_callback_; + OldCompletionCallback* old_user_write_callback_; + CompletionCallback user_write_callback_; // Used by Read function. scoped_refptr<IOBuffer> user_read_buf_; diff --git a/net/socket/ssl_client_socket_win.cc b/net/socket/ssl_client_socket_win.cc index 30f599d..3c855fb 100644 --- a/net/socket/ssl_client_socket_win.cc +++ b/net/socket/ssl_client_socket_win.cc @@ -400,7 +400,7 @@ SSLClientSocketWin::SSLClientSocketWin(ClientSocketHandle* transport_socket, old_user_connect_callback_(NULL), old_user_read_callback_(NULL), user_read_buf_len_(0), - user_write_callback_(NULL), + old_user_write_callback_(NULL), user_write_buf_len_(0), next_state_(STATE_NONE), cert_verifier_(context.cert_verifier), @@ -869,7 +869,29 @@ int SSLClientSocketWin::Read(IOBuffer* buf, int buf_len, int SSLClientSocketWin::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { DCHECK(completed_handshake()); - DCHECK(!user_write_callback_); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); + + DCHECK(!user_write_buf_); + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoPayloadEncrypt(); + if (rv != OK) + return rv; + + rv = DoPayloadWrite(); + if (rv == ERR_IO_PENDING) { + old_user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + return rv; +} +int SSLClientSocketWin::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake()); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); DCHECK(!user_write_buf_); user_write_buf_ = buf; @@ -964,12 +986,20 @@ void SSLClientSocketWin::OnWriteComplete(int result) { int rv = DoPayloadWriteComplete(result); if (rv != ERR_IO_PENDING) { - DCHECK(user_write_callback_); - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - c->Run(rv); + DCHECK(old_user_write_callback_ || !user_write_callback_.is_null()); + if (old_user_write_callback_) { + OldCompletionCallback* c = old_user_write_callback_; + old_user_write_callback_ = NULL; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c.Run(rv); + } } } diff --git a/net/socket/ssl_client_socket_win.h b/net/socket/ssl_client_socket_win.h index 27ce300..8fc326d 100644 --- a/net/socket/ssl_client_socket_win.h +++ b/net/socket/ssl_client_socket_win.h @@ -75,7 +75,10 @@ class SSLClientSocketWin : public SSLClientSocket { virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); - virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); @@ -134,7 +137,8 @@ class SSLClientSocketWin : public SSLClientSocket { int user_read_buf_len_; // User function to callback when a Write() completes. - OldCompletionCallback* user_write_callback_; + OldCompletionCallback* old_user_write_callback_; + CompletionCallback user_write_callback_; scoped_refptr<IOBuffer> user_write_buf_; int user_write_buf_len_; diff --git a/net/socket/ssl_server_socket_nss.cc b/net/socket/ssl_server_socket_nss.cc index 0785dd7d..c87076f 100644 --- a/net/socket/ssl_server_socket_nss.cc +++ b/net/socket/ssl_server_socket_nss.cc @@ -66,7 +66,7 @@ SSLServerSocketNSS::SSLServerSocketNSS( transport_recv_busy_(false), user_handshake_callback_(NULL), old_user_read_callback_(NULL), - user_write_callback_(NULL), + old_user_write_callback_(NULL), nss_fd_(NULL), nss_bufs_(NULL), transport_socket_(transport_socket), @@ -199,7 +199,26 @@ int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len, int SSLServerSocketNSS::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { - DCHECK(!user_write_callback_); + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); + DCHECK(!user_write_buf_); + DCHECK(nss_bufs_); + + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoWriteLoop(OK); + + if (rv == ERR_IO_PENDING) { + old_user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + return rv; +} +int SSLServerSocketNSS::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(!old_user_write_callback_ && user_write_callback_.is_null()); DCHECK(!user_write_buf_); DCHECK(nss_bufs_); @@ -760,15 +779,23 @@ void SSLServerSocketNSS::DoReadCallback(int rv) { void SSLServerSocketNSS::DoWriteCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_write_callback_); + DCHECK(old_user_write_callback_ || !user_write_callback_.is_null()); // Since Run may result in Write being called, clear |user_write_callback_| // up front. - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - c->Run(rv); + if (old_user_write_callback_) { + OldCompletionCallback* c = old_user_write_callback_; + old_user_write_callback_ = NULL; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c->Run(rv); + } else { + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c.Run(rv); + } } // static diff --git a/net/socket/ssl_server_socket_nss.h b/net/socket/ssl_server_socket_nss.h index 39283f6..4a4c3f5 100644 --- a/net/socket/ssl_server_socket_nss.h +++ b/net/socket/ssl_server_socket_nss.h @@ -45,6 +45,8 @@ class SSLServerSocketNSS : public SSLServerSocket { const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -113,7 +115,8 @@ class SSLServerSocketNSS : public SSLServerSocket { OldCompletionCallback* user_handshake_callback_; OldCompletionCallback* old_user_read_callback_; CompletionCallback user_read_callback_; - OldCompletionCallback* user_write_callback_; + OldCompletionCallback* old_user_write_callback_; + CompletionCallback user_write_callback_; // Used by Read function. scoped_refptr<IOBuffer> user_read_buf_; diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc index eb9dc7c..1ec6b4b 100644 --- a/net/socket/ssl_server_socket_unittest.cc +++ b/net/socket/ssl_server_socket_unittest.cc @@ -85,6 +85,14 @@ class FakeDataChannel { &FakeDataChannel::DoReadCallback)); return buf_len; } + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + data_.push(new net::DrainableIOBuffer(buf, buf_len)); + MessageLoop::current()->PostTask( + FROM_HERE, task_factory_.NewRunnableMethod( + &FakeDataChannel::DoReadCallback)); + return buf_len; + } private: void DoReadCallback() { @@ -160,6 +168,12 @@ class FakeSocket : public StreamSocket { buf_len = rand() % buf_len + 1; return outgoing_->Write(buf, buf_len, callback); } + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + // Write random number of bytes. + buf_len = rand() % buf_len + 1; + return outgoing_->Write(buf, buf_len, callback); + } virtual bool SetReceiveBufferSize(int32 size) { return true; diff --git a/net/socket/tcp_client_socket_libevent.cc b/net/socket/tcp_client_socket_libevent.cc index 3c99ae5..3243177 100644 --- a/net/socket/tcp_client_socket_libevent.cc +++ b/net/socket/tcp_client_socket_libevent.cc @@ -574,6 +574,42 @@ int TCPClientSocketLibevent::Write(IOBuffer* buf, old_write_callback_ = callback; return ERR_IO_PENDING; } +int TCPClientSocketLibevent::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(kInvalidSocket, socket_); + DCHECK(!waiting_connect()); + DCHECK(!old_write_callback_ && write_callback_.is_null()); + // Synchronous operation not supported + DCHECK(!callback.is_null()); + DCHECK_GT(buf_len, 0); + + int nwrite = InternalWrite(buf, buf_len); + if (nwrite >= 0) { + base::StatsCounter write_bytes("tcp.write_bytes"); + write_bytes.Add(nwrite); + if (nwrite > 0) + use_history_.set_was_used_to_convey_data(); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, nwrite, + buf->data()); + return nwrite; + } + if (errno != EAGAIN && errno != EWOULDBLOCK) + return MapSystemError(errno); + + if (!MessageLoopForIO::current()->WatchFileDescriptor( + socket_, true, MessageLoopForIO::WATCH_WRITE, + &write_socket_watcher_, &write_watcher_)) { + DVLOG(1) << "WatchFileDescriptor failed on write, errno " << errno; + return MapSystemError(errno); + } + + write_buf_ = buf; + write_buf_len_ = buf_len; + write_callback_ = callback; + return ERR_IO_PENDING; +} int TCPClientSocketLibevent::InternalWrite(IOBuffer* buf, int buf_len) { int nwrite; diff --git a/net/socket/tcp_client_socket_libevent.h b/net/socket/tcp_client_socket_libevent.h index 47f19a0..448d7a7 100644 --- a/net/socket/tcp_client_socket_libevent.h +++ b/net/socket/tcp_client_socket_libevent.h @@ -70,6 +70,9 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; diff --git a/net/socket/tcp_client_socket_win.cc b/net/socket/tcp_client_socket_win.cc index cb59a98..8101782 100644 --- a/net/socket/tcp_client_socket_win.cc +++ b/net/socket/tcp_client_socket_win.cc @@ -318,7 +318,7 @@ TCPClientSocketWin::TCPClientSocketWin(const AddressList& addresses, waiting_read_(false), waiting_write_(false), old_read_callback_(NULL), - write_callback_(NULL), + old_write_callback_(NULL), next_connect_state_(CONNECT_STATE_NONE), connect_os_error_(0), net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), @@ -792,7 +792,58 @@ int TCPClientSocketWin::Write(IOBuffer* buf, DCHECK(CalledOnValidThread()); DCHECK_NE(socket_, INVALID_SOCKET); DCHECK(!waiting_write_); - DCHECK(!write_callback_); + DCHECK(!old_write_callback_ && write_callback_.is_null()); + DCHECK_GT(buf_len, 0); + DCHECK(!core_->write_iobuffer_); + + base::StatsCounter writes("tcp.writes"); + writes.Increment(); + + core_->write_buffer_.len = buf_len; + core_->write_buffer_.buf = buf->data(); + core_->write_buffer_length_ = buf_len; + + // TODO(wtc): Remove the assertion after enough testing. + AssertEventNotSignaled(core_->write_overlapped_.hEvent); + DWORD num; + int rv = WSASend(socket_, &core_->write_buffer_, 1, &num, 0, + &core_->write_overlapped_, NULL); + if (rv == 0) { + if (ResetEventIfSignaled(core_->write_overlapped_.hEvent)) { + rv = static_cast<int>(num); + if (rv > buf_len || rv < 0) { + // It seems that some winsock interceptors report that more was written + // than was available. Treat this as an error. http://crbug.com/27870 + LOG(ERROR) << "Detected broken LSP: Asked to write " << buf_len + << " bytes, but " << rv << " bytes reported."; + return ERR_WINSOCK_UNEXPECTED_WRITTEN_BYTES; + } + base::StatsCounter write_bytes("tcp.write_bytes"); + write_bytes.Add(rv); + if (rv > 0) + use_history_.set_was_used_to_convey_data(); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, rv, + core_->write_buffer_.buf); + return rv; + } + } else { + int os_error = WSAGetLastError(); + if (os_error != WSA_IO_PENDING) + return MapSystemError(os_error); + } + core_->WatchForWrite(); + waiting_write_ = true; + old_write_callback_ = callback; + core_->write_iobuffer_ = buf; + return ERR_IO_PENDING; +} +int TCPClientSocketWin::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(socket_, INVALID_SOCKET); + DCHECK(!waiting_write_); + DCHECK(!old_write_callback_ && write_callback_.is_null()); DCHECK_GT(buf_len, 0); DCHECK(!core_->write_iobuffer_); @@ -897,12 +948,19 @@ void TCPClientSocketWin::DoReadCallback(int rv) { void TCPClientSocketWin::DoWriteCallback(int rv) { DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(write_callback_); + DCHECK(old_write_callback_ || !write_callback_.is_null()); - // since Run may result in Write being called, clear write_callback_ up front. - OldCompletionCallback* c = write_callback_; - write_callback_ = NULL; - c->Run(rv); + // Since Run may result in Write being called, clear old_write_callback_ up + // front. + if (old_write_callback_) { + OldCompletionCallback* c = old_write_callback_; + old_write_callback_ = NULL; + c->Run(rv); + } else { + CompletionCallback c = write_callback_; + write_callback_.Reset(); + c.Run(rv); + } } void TCPClientSocketWin::DidCompleteConnect() { diff --git a/net/socket/tcp_client_socket_win.h b/net/socket/tcp_client_socket_win.h index 1e75933..2e58cad 100644 --- a/net/socket/tcp_client_socket_win.h +++ b/net/socket/tcp_client_socket_win.h @@ -63,7 +63,10 @@ class NET_EXPORT TCPClientSocketWin : public StreamSocket, virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); - virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); @@ -130,7 +133,8 @@ class NET_EXPORT TCPClientSocketWin : public StreamSocket, CompletionCallback read_callback_; // External callback; called when write is complete. - OldCompletionCallback* write_callback_; + OldCompletionCallback* old_write_callback_; + CompletionCallback write_callback_; // The next state for the Connect() state machine. ConnectState next_connect_state_; diff --git a/net/socket/transport_client_socket_pool_unittest.cc b/net/socket/transport_client_socket_pool_unittest.cc index 56b1fa9c..bfd330a 100644 --- a/net/socket/transport_client_socket_pool_unittest.cc +++ b/net/socket/transport_client_socket_pool_unittest.cc @@ -107,6 +107,10 @@ class MockClientSocket : public StreamSocket { OldCompletionCallback* callback) { return ERR_FAILED; } + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + return ERR_FAILED; + } virtual bool SetReceiveBufferSize(int32 size) { return true; } virtual bool SetSendBufferSize(int32 size) { return true; } @@ -169,6 +173,10 @@ class MockFailingClientSocket : public StreamSocket { OldCompletionCallback* callback) { return ERR_FAILED; } + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + return ERR_FAILED; + } virtual bool SetReceiveBufferSize(int32 size) { return true; } virtual bool SetSendBufferSize(int32 size) { return true; } @@ -260,6 +268,10 @@ class MockPendingClientSocket : public StreamSocket { OldCompletionCallback* callback) { return ERR_FAILED; } + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + return ERR_FAILED; + } virtual bool SetReceiveBufferSize(int32 size) { return true; } virtual bool SetSendBufferSize(int32 size) { return true; } diff --git a/net/socket/web_socket_server_socket.cc b/net/socket/web_socket_server_socket.cc index d792689..b24f25c 100644 --- a/net/socket/web_socket_server_socket.cc +++ b/net/socket/web_socket_server_socket.cc @@ -397,6 +397,53 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { ConsiderTransportWrite(); return net::ERR_IO_PENDING; } + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE { + if (buf_len == 0) + return 0; + if (buf == NULL || buf_len < 0) { + NOTREACHED(); + return net::ERR_INVALID_ARGUMENT; + } + DCHECK_EQ(std::find(buf->data(), buf->data() + buf_len, '\xff'), + buf->data() + buf_len); + switch (phase_) { + case PHASE_FRAME_OUTSIDE: + case PHASE_FRAME_INSIDE: + case PHASE_FRAME_LENGTH: + case PHASE_FRAME_SKIP: { + break; + } + case PHASE_SHUT: { + return net::ERR_SOCKET_NOT_CONNECTED; + } + case PHASE_NYMPH: + case PHASE_HANDSHAKE: + default: { + NOTREACHED(); + return net::ERR_UNEXPECTED; + } + } + + net::IOBuffer* frame_start = new net::IOBuffer(1); + frame_start->data()[0] = '\x00'; + pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, + new net::DrainableIOBuffer(frame_start, 1), + NULL)); + + pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE, + new net::DrainableIOBuffer(buf, buf_len), + callback)); + + net::IOBuffer* frame_end = new net::IOBuffer(1); + frame_end->data()[0] = '\xff'; + pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, + new net::DrainableIOBuffer(frame_end, 1), + NULL)); + + ConsiderTransportWrite(); + return net::ERR_IO_PENDING; + } virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { return transport_socket_->SetReceiveBufferSize(size); diff --git a/net/socket/web_socket_server_socket_unittest.cc b/net/socket/web_socket_server_socket_unittest.cc index cabb4b9..bb4d022 100644 --- a/net/socket/web_socket_server_socket_unittest.cc +++ b/net/socket/web_socket_server_socket_unittest.cc @@ -157,6 +157,24 @@ class TestingTransportSocket : public net::Socket { } MessageLoop::current()->PostTask(FROM_HERE, method_factory_.NewRunnableMethod( + &TestingTransportSocket::DoOldWriteCallback, callback, lot)); + return net::ERR_IO_PENDING; + } + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + CHECK_GT(buf_len, 0); + int remaining = answer_->BytesRemaining(); + CHECK_GE(remaining, buf_len); + int lot = std::min(remaining, buf_len); + if (GetRand(0, 1)) + lot = GetRand(1, lot); + std::copy(buf->data(), buf->data() + lot, answer_->data()); + answer_->DidConsume(lot); + if (GetRand(0, 1)) { + return lot; + } + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( &TestingTransportSocket::DoWriteCallback, callback, lot)); return net::ERR_IO_PENDING; } @@ -192,10 +210,14 @@ class TestingTransportSocket : public net::Socket { } } - void DoWriteCallback(net::OldCompletionCallback* callback, int result) { + void DoOldWriteCallback(net::OldCompletionCallback* callback, int result) { if (callback) callback->Run(result); } + void DoWriteCallback(const net::CompletionCallback& callback, int result) { + if (!callback.is_null()) + callback.Run(result); + } bool is_closed_; diff --git a/net/spdy/spdy_proxy_client_socket.cc b/net/spdy/spdy_proxy_client_socket.cc index 7b2c83b..9dd5145 100644 --- a/net/spdy/spdy_proxy_client_socket.cc +++ b/net/spdy/spdy_proxy_client_socket.cc @@ -34,7 +34,7 @@ SpdyProxyClientSocket::SpdyProxyClientSocket( next_state_(STATE_DISCONNECTED), spdy_stream_(spdy_stream), old_read_callback_(NULL), - write_callback_(NULL), + old_write_callback_(NULL), endpoint_(endpoint), auth_( new HttpAuthController(HttpAuth::AUTH_PROXY, @@ -126,7 +126,8 @@ void SpdyProxyClientSocket::Disconnect() { write_buffer_len_ = 0; write_bytes_outstanding_ = 0; - write_callback_ = NULL; + old_write_callback_ = NULL; + write_callback_.Reset(); next_state_ = STATE_DISCONNECTED; @@ -244,7 +245,45 @@ int SpdyProxyClientSocket::PopulateUserReadBuffer() { int SpdyProxyClientSocket::Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) { - DCHECK(!write_callback_); + DCHECK(!old_write_callback_ && write_callback_.is_null()); + if (next_state_ != STATE_OPEN) + return ERR_SOCKET_NOT_CONNECTED; + + DCHECK(spdy_stream_); + write_bytes_outstanding_= buf_len; + if (buf_len <= kMaxSpdyFrameChunkSize) { + int rv = spdy_stream_->WriteStreamData(buf, buf_len, spdy::DATA_FLAG_NONE); + if (rv == ERR_IO_PENDING) { + old_write_callback_ = callback; + write_buffer_len_ = buf_len; + } + return rv; + } + + // Since a SPDY Data frame can only include kMaxSpdyFrameChunkSize bytes + // we need to send multiple data frames + for (int i = 0; i < buf_len; i += kMaxSpdyFrameChunkSize) { + int len = std::min(kMaxSpdyFrameChunkSize, buf_len - i); + scoped_refptr<DrainableIOBuffer> iobuf(new DrainableIOBuffer(buf, i + len)); + iobuf->SetOffset(i); + int rv = spdy_stream_->WriteStreamData(iobuf, len, spdy::DATA_FLAG_NONE); + if (rv > 0) { + write_bytes_outstanding_ -= rv; + } else if (rv != ERR_IO_PENDING) { + return rv; + } + } + if (write_bytes_outstanding_ > 0) { + old_write_callback_ = callback; + write_buffer_len_ = buf_len; + return ERR_IO_PENDING; + } else { + return buf_len; + } +} +int SpdyProxyClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(!old_write_callback_ && write_callback_.is_null()); if (next_state_ != STATE_OPEN) return ERR_SOCKET_NOT_CONNECTED; @@ -532,7 +571,7 @@ void SpdyProxyClientSocket::OnDataReceived(const char* data, int length) { } void SpdyProxyClientSocket::OnDataSent(int length) { - DCHECK(write_callback_); + DCHECK(old_write_callback_ || !write_callback_.is_null()); write_bytes_outstanding_ -= length; @@ -542,9 +581,15 @@ void SpdyProxyClientSocket::OnDataSent(int length) { int rv = write_buffer_len_; write_buffer_len_ = 0; write_bytes_outstanding_ = 0; - OldCompletionCallback* c = write_callback_; - write_callback_ = NULL; - c->Run(rv); + if (old_write_callback_) { + OldCompletionCallback* c = old_write_callback_; + old_write_callback_ = NULL; + c->Run(rv); + } else { + CompletionCallback c = write_callback_; + write_callback_.Reset(); + c.Run(rv); + } } } @@ -561,8 +606,10 @@ void SpdyProxyClientSocket::OnClose(int status) { next_state_ = STATE_DISCONNECTED; base::WeakPtr<SpdyProxyClientSocket> weak_ptr = weak_factory_.GetWeakPtr(); - OldCompletionCallback* write_callback = write_callback_; - write_callback_ = NULL; + OldCompletionCallback* old_write_callback = old_write_callback_; + CompletionCallback write_callback = write_callback_; + old_write_callback_ = NULL; + write_callback_.Reset(); write_buffer_len_ = 0; write_bytes_outstanding_ = 0; @@ -584,8 +631,12 @@ void SpdyProxyClientSocket::OnClose(int status) { OnDataReceived(NULL, 0); } // This may have been deleted by read_callback_, so check first. - if (weak_ptr && write_callback) - write_callback->Run(ERR_CONNECTION_CLOSED); + if (weak_ptr && (old_write_callback || !write_callback.is_null())) { + if (old_write_callback) + old_write_callback->Run(ERR_CONNECTION_CLOSED); + else + write_callback.Run(ERR_CONNECTION_CLOSED); + } } void SpdyProxyClientSocket::set_chunk_callback(ChunkCallback* /*callback*/) { diff --git a/net/spdy/spdy_proxy_client_socket.h b/net/spdy/spdy_proxy_client_socket.h index fb2d9a3..9fb97e8 100644 --- a/net/spdy/spdy_proxy_client_socket.h +++ b/net/spdy/spdy_proxy_client_socket.h @@ -84,6 +84,9 @@ class NET_EXPORT_PRIVATE SpdyProxyClientSocket : public ProxyClientSocket, virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; virtual int GetPeerAddress(AddressList* address) const OVERRIDE; @@ -137,7 +140,8 @@ class NET_EXPORT_PRIVATE SpdyProxyClientSocket : public ProxyClientSocket, OldCompletionCallback* old_read_callback_; CompletionCallback read_callback_; // Stores the callback to the layer above, called on completing Write(). - OldCompletionCallback* write_callback_; + OldCompletionCallback* old_write_callback_; + CompletionCallback write_callback_; // CONNECT request and response. HttpRequestInfo request_; diff --git a/net/udp/udp_client_socket.cc b/net/udp/udp_client_socket.cc index 4bb3885..08b2cbe 100644 --- a/net/udp/udp_client_socket.cc +++ b/net/udp/udp_client_socket.cc @@ -38,6 +38,11 @@ int UDPClientSocket::Write(IOBuffer* buf, OldCompletionCallback* callback) { return socket_.Write(buf, buf_len, callback); } +int UDPClientSocket::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + return socket_.Write(buf, buf_len, callback); +} void UDPClientSocket::Close() { socket_.Close(); diff --git a/net/udp/udp_client_socket.h b/net/udp/udp_client_socket.h index b25df63..2fee21a 100644 --- a/net/udp/udp_client_socket.h +++ b/net/udp/udp_client_socket.h @@ -32,6 +32,8 @@ class NET_EXPORT_PRIVATE UDPClientSocket : public DatagramClientSocket { const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual void Close() OVERRIDE; virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; diff --git a/net/udp/udp_socket_libevent.cc b/net/udp/udp_socket_libevent.cc index a669c37..9777626 100644 --- a/net/udp/udp_socket_libevent.cc +++ b/net/udp/udp_socket_libevent.cc @@ -48,7 +48,7 @@ UDPSocketLibevent::UDPSocketLibevent( recv_from_address_(NULL), write_buf_len_(0), old_read_callback_(NULL), - write_callback_(NULL), + old_write_callback_(NULL), net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_UDP_SOCKET)) { scoped_refptr<NetLog::EventParameters> params; if (source.is_valid()) @@ -77,7 +77,8 @@ void UDPSocketLibevent::Close() { recv_from_address_ = NULL; write_buf_ = NULL; write_buf_len_ = 0; - write_callback_ = NULL; + old_write_callback_ = NULL; + write_callback_.Reset(); send_to_address_.reset(); bool ok = read_socket_watcher_.StopWatchingFileDescriptor(); @@ -212,6 +213,11 @@ int UDPSocketLibevent::Write(IOBuffer* buf, OldCompletionCallback* callback) { return SendToOrWrite(buf, buf_len, NULL, callback); } +int UDPSocketLibevent::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + return SendToOrWrite(buf, buf_len, NULL, callback); +} int UDPSocketLibevent::SendTo(IOBuffer* buf, int buf_len, @@ -226,7 +232,7 @@ int UDPSocketLibevent::SendToOrWrite(IOBuffer* buf, OldCompletionCallback* callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(kInvalidSocket, socket_); - DCHECK(!write_callback_); + DCHECK(!old_write_callback_ && write_callback_.is_null()); DCHECK(callback); // Synchronous operation not supported DCHECK_GT(buf_len, 0); @@ -249,6 +255,38 @@ int UDPSocketLibevent::SendToOrWrite(IOBuffer* buf, if (address) { send_to_address_.reset(new IPEndPoint(*address)); } + old_write_callback_ = callback; + return ERR_IO_PENDING; +} +int UDPSocketLibevent::SendToOrWrite(IOBuffer* buf, + int buf_len, + const IPEndPoint* address, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(kInvalidSocket, socket_); + DCHECK(!old_write_callback_ && write_callback_.is_null()); + DCHECK(!callback.is_null()); // Synchronous operation not supported + DCHECK_GT(buf_len, 0); + + int result = InternalSendTo(buf, buf_len, address); + if (result != ERR_IO_PENDING) + return result; + + if (!MessageLoopForIO::current()->WatchFileDescriptor( + socket_, true, MessageLoopForIO::WATCH_WRITE, + &write_socket_watcher_, &write_watcher_)) { + DVLOG(1) << "WatchFileDescriptor failed on write, errno " << errno; + int result = MapSystemError(errno); + LogWrite(result, NULL, NULL); + return result; + } + + write_buf_ = buf; + write_buf_len_ = buf_len; + DCHECK(!send_to_address_.get()); + if (address) { + send_to_address_.reset(new IPEndPoint(*address)); + } write_callback_ = callback; return ERR_IO_PENDING; } @@ -339,12 +377,18 @@ void UDPSocketLibevent::DoReadCallback(int rv) { void UDPSocketLibevent::DoWriteCallback(int rv) { DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(write_callback_); + DCHECK(old_write_callback_ || !write_callback_.is_null()); // since Run may result in Write being called, clear write_callback_ up front. - OldCompletionCallback* c = write_callback_; - write_callback_ = NULL; - c->Run(rv); + if (old_write_callback_) { + OldCompletionCallback* c = old_write_callback_; + old_write_callback_ = NULL; + c->Run(rv); + } else { + CompletionCallback c = write_callback_; + write_callback_.Reset(); + c.Run(rv); + } } void UDPSocketLibevent::DidCompleteRead() { diff --git a/net/udp/udp_socket_libevent.h b/net/udp/udp_socket_libevent.h index 6cef1c0..45ec1bb 100644 --- a/net/udp/udp_socket_libevent.h +++ b/net/udp/udp_socket_libevent.h @@ -61,6 +61,7 @@ class UDPSocketLibevent : public base::NonThreadSafe { // Only usable from the client-side of a UDP socket, after the socket // has been connected. int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback); // Read from a socket and receive sender address information. // |buf| is the buffer to read data into. @@ -140,7 +141,7 @@ class UDPSocketLibevent : public base::NonThreadSafe { virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE {} virtual void OnFileCanWriteWithoutBlocking(int /* fd */) OVERRIDE { - if (socket_->write_callback_) + if (socket_->old_write_callback_) socket_->DidCompleteWrite(); } @@ -173,6 +174,10 @@ class UDPSocketLibevent : public base::NonThreadSafe { int buf_len, const IPEndPoint* address, OldCompletionCallback* callback); + int SendToOrWrite(IOBuffer* buf, + int buf_len, + const IPEndPoint* address, + const CompletionCallback& callback); int InternalConnect(const IPEndPoint& address); int InternalRecvFrom(IOBuffer* buf, int buf_len, IPEndPoint* address); @@ -218,7 +223,8 @@ class UDPSocketLibevent : public base::NonThreadSafe { CompletionCallback read_callback_; // External callback; called when write is complete. - OldCompletionCallback* write_callback_; + OldCompletionCallback* old_write_callback_; + CompletionCallback write_callback_; BoundNetLog net_log_; diff --git a/net/udp/udp_socket_win.cc b/net/udp/udp_socket_win.cc index 88aa23b..6499444 100644 --- a/net/udp/udp_socket_win.cc +++ b/net/udp/udp_socket_win.cc @@ -52,7 +52,7 @@ UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type, ALLOW_THIS_IN_INITIALIZER_LIST(write_delegate_(this)), recv_from_address_(NULL), old_read_callback_(NULL), - write_callback_(NULL), + old_write_callback_(NULL), net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_UDP_SOCKET)) { EnsureWinsockInit(); scoped_refptr<NetLog::EventParameters> params; @@ -82,7 +82,8 @@ void UDPSocketWin::Close() { old_read_callback_ = NULL; read_callback_.Reset(); recv_from_address_ = NULL; - write_callback_ = NULL; + old_write_callback_ = NULL; + write_callback_.Reset(); read_watcher_.StopWatching(); write_watcher_.StopWatching(); @@ -192,6 +193,11 @@ int UDPSocketWin::Write(IOBuffer* buf, OldCompletionCallback* callback) { return SendToOrWrite(buf, buf_len, NULL, callback); } +int UDPSocketWin::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + return SendToOrWrite(buf, buf_len, NULL, callback); +} int UDPSocketWin::SendTo(IOBuffer* buf, int buf_len, @@ -206,7 +212,7 @@ int UDPSocketWin::SendToOrWrite(IOBuffer* buf, OldCompletionCallback* callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(INVALID_SOCKET, socket_); - DCHECK(!write_callback_); + DCHECK(!old_write_callback_ && write_callback_.is_null()); DCHECK(callback); // Synchronous operation not supported. DCHECK_GT(buf_len, 0); DCHECK(!send_to_address_.get()); @@ -218,6 +224,27 @@ int UDPSocketWin::SendToOrWrite(IOBuffer* buf, if (address) send_to_address_.reset(new IPEndPoint(*address)); write_iobuffer_ = buf; + old_write_callback_ = callback; + return ERR_IO_PENDING; +} +int UDPSocketWin::SendToOrWrite(IOBuffer* buf, + int buf_len, + const IPEndPoint* address, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(INVALID_SOCKET, socket_); + DCHECK(!old_write_callback_ && write_callback_.is_null()); + DCHECK(!callback.is_null()); // Synchronous operation not supported. + DCHECK_GT(buf_len, 0); + DCHECK(!send_to_address_.get()); + + int nwrite = InternalSendTo(buf, buf_len, address); + if (nwrite != ERR_IO_PENDING) + return nwrite; + + if (address) + send_to_address_.reset(new IPEndPoint(*address)); + write_iobuffer_ = buf; write_callback_ = callback; return ERR_IO_PENDING; } @@ -314,12 +341,18 @@ void UDPSocketWin::DoReadCallback(int rv) { void UDPSocketWin::DoWriteCallback(int rv) { DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(write_callback_); + DCHECK(old_write_callback_ && !write_callback_.is_null()); // since Run may result in Write being called, clear write_callback_ up front. - OldCompletionCallback* c = write_callback_; - write_callback_ = NULL; - c->Run(rv); + if (old_write_callback_) { + OldCompletionCallback* c = old_write_callback_; + old_write_callback_ = NULL; + c->Run(rv); + } else { + CompletionCallback c = write_callback_; + write_callback_.Reset(); + c.Run(rv); + } } void UDPSocketWin::DidCompleteRead() { diff --git a/net/udp/udp_socket_win.h b/net/udp/udp_socket_win.h index 6aace29..4436860 100644 --- a/net/udp/udp_socket_win.h +++ b/net/udp/udp_socket_win.h @@ -62,6 +62,7 @@ class UDPSocketWin : public base::NonThreadSafe { // Only usable from the client-side of a UDP socket, after the socket // has been connected. int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback); // Read from a socket and receive sender address information. // |buf| is the buffer to read data into. @@ -156,6 +157,10 @@ class UDPSocketWin : public base::NonThreadSafe { int buf_len, const IPEndPoint* address, OldCompletionCallback* callback); + int SendToOrWrite(IOBuffer* buf, + int buf_len, + const IPEndPoint* address, + const CompletionCallback& callback); int InternalConnect(const IPEndPoint& address); int InternalRecvFrom(IOBuffer* buf, int buf_len, IPEndPoint* address); @@ -212,7 +217,8 @@ class UDPSocketWin : public base::NonThreadSafe { CompletionCallback read_callback_; // External callback; called when write is complete. - OldCompletionCallback* write_callback_; + OldCompletionCallback* old_write_callback_; + CompletionCallback write_callback_; BoundNetLog net_log_; diff --git a/remoting/jingle_glue/ssl_socket_adapter.cc b/remoting/jingle_glue/ssl_socket_adapter.cc index 7414c4b..3df508e 100644 --- a/remoting/jingle_glue/ssl_socket_adapter.cc +++ b/remoting/jingle_glue/ssl_socket_adapter.cc @@ -188,7 +188,7 @@ void SSLSocketAdapter::OnConnectEvent(talk_base::AsyncSocket* socket) { TransportSocket::TransportSocket(talk_base::AsyncSocket* socket, SSLSocketAdapter *ssl_adapter) : old_read_callback_(NULL), - write_callback_(NULL), + old_write_callback_(NULL), read_buffer_len_(0), write_buffer_len_(0), socket_(socket), @@ -327,7 +327,25 @@ int TransportSocket::Read(net::IOBuffer* buf, int buf_len, int TransportSocket::Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { DCHECK(buf); - DCHECK(!write_callback_); + DCHECK(!old_write_callback_ && write_callback_.is_null()); + DCHECK(!write_buffer_.get()); + int result = socket_->Send(buf->data(), buf_len); + if (result < 0) { + result = net::MapSystemError(socket_->GetError()); + if (result == net::ERR_IO_PENDING) { + old_write_callback_ = callback; + write_buffer_ = buf; + write_buffer_len_ = buf_len; + } + } + if (result != net::ERR_IO_PENDING) + was_used_to_convey_data_ = true; + return result; +} +int TransportSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + DCHECK(buf); + DCHECK(!old_write_callback_ && write_callback_.is_null()); DCHECK(!write_buffer_.get()); int result = socket_->Send(buf->data(), buf_len); if (result < 0) { @@ -386,13 +404,15 @@ void TransportSocket::OnReadEvent(talk_base::AsyncSocket* socket) { } void TransportSocket::OnWriteEvent(talk_base::AsyncSocket* socket) { - if (write_callback_) { + if (old_write_callback_ || !write_callback_.is_null()) { DCHECK(write_buffer_.get()); - net::OldCompletionCallback* callback = write_callback_; + net::OldCompletionCallback* old_callback = old_write_callback_; + net::CompletionCallback callback = write_callback_; scoped_refptr<net::IOBuffer> buffer = write_buffer_; int buffer_len = write_buffer_len_; - write_callback_ = NULL; + old_write_callback_ = NULL; + write_callback_.Reset(); write_buffer_ = NULL; write_buffer_len_ = 0; @@ -400,6 +420,7 @@ void TransportSocket::OnWriteEvent(talk_base::AsyncSocket* socket) { if (result < 0) { result = net::MapSystemError(socket_->GetError()); if (result == net::ERR_IO_PENDING) { + old_write_callback_ = old_callback; write_callback_ = callback; write_buffer_ = buffer; write_buffer_len_ = buffer_len; @@ -407,7 +428,10 @@ void TransportSocket::OnWriteEvent(talk_base::AsyncSocket* socket) { } } was_used_to_convey_data_ = true; - callback->RunWithParams(Tuple1<int>(result)); + if (old_callback) + old_callback->RunWithParams(Tuple1<int>(result)); + else + callback.Run(result); } } diff --git a/remoting/jingle_glue/ssl_socket_adapter.h b/remoting/jingle_glue/ssl_socket_adapter.h index f929bf8..3638ae1 100644 --- a/remoting/jingle_glue/ssl_socket_adapter.h +++ b/remoting/jingle_glue/ssl_socket_adapter.h @@ -62,6 +62,8 @@ class TransportSocket : public net::StreamSocket, public sigslot::has_slots<> { const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -73,7 +75,8 @@ class TransportSocket : public net::StreamSocket, public sigslot::has_slots<> { net::OldCompletionCallback* old_read_callback_; net::CompletionCallback read_callback_; - net::OldCompletionCallback* write_callback_; + net::OldCompletionCallback* old_write_callback_; + net::CompletionCallback write_callback_; scoped_refptr<net::IOBuffer> read_buffer_; int read_buffer_len_; diff --git a/remoting/protocol/fake_session.cc b/remoting/protocol/fake_session.cc index f1a890a..18071fd 100644 --- a/remoting/protocol/fake_session.cc +++ b/remoting/protocol/fake_session.cc @@ -86,6 +86,13 @@ int FakeSocket::Write(net::IOBuffer* buf, int buf_len, buf->data(), buf->data() + buf_len); return buf_len; } +int FakeSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + EXPECT_EQ(message_loop_, MessageLoop::current()); + written_data_.insert(written_data_.end(), + buf->data(), buf->data() + buf_len); + return buf_len; +} bool FakeSocket::SetReceiveBufferSize(int32 size) { NOTIMPLEMENTED(); @@ -236,6 +243,13 @@ int FakeUdpSocket::Write(net::IOBuffer* buf, int buf_len, written_packets_.back().assign(buf->data(), buf->data() + buf_len); return buf_len; } +int FakeUdpSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + EXPECT_EQ(message_loop_, MessageLoop::current()); + written_packets_.push_back(std::string()); + written_packets_.back().assign(buf->data(), buf->data() + buf_len); + return buf_len; +} bool FakeUdpSocket::SetReceiveBufferSize(int32 size) { NOTIMPLEMENTED(); diff --git a/remoting/protocol/fake_session.h b/remoting/protocol/fake_session.h index a8239f5..c98d888 100644 --- a/remoting/protocol/fake_session.h +++ b/remoting/protocol/fake_session.h @@ -45,6 +45,8 @@ class FakeSocket : public net::StreamSocket { const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -105,6 +107,8 @@ class FakeUdpSocket : public net::Socket { const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; diff --git a/remoting/protocol/pepper_transport_socket_adapter.cc b/remoting/protocol/pepper_transport_socket_adapter.cc index 0385b9b..4876088 100644 --- a/remoting/protocol/pepper_transport_socket_adapter.cc +++ b/remoting/protocol/pepper_transport_socket_adapter.cc @@ -50,7 +50,7 @@ PepperTransportSocketAdapter::PepperTransportSocketAdapter( connected_(false), get_address_pending_(false), old_read_callback_(NULL), - write_callback_(NULL) { + old_write_callback_(NULL) { callback_factory_.Initialize(this); } @@ -112,7 +112,28 @@ int PepperTransportSocketAdapter::Read( int PepperTransportSocketAdapter::Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { DCHECK(CalledOnValidThread()); - DCHECK(!write_callback_); + DCHECK(!old_write_callback_ && write_callback_.is_null()); + DCHECK(!write_buffer_); + + if (!transport_.get()) + return net::ERR_SOCKET_NOT_CONNECTED; + + int result = PPErrorToNetError(transport_->Send( + buf->data(), buf_len, + callback_factory_.NewOptionalCallback( + &PepperTransportSocketAdapter::OnWrite))); + + if (result == net::ERR_IO_PENDING) { + old_write_callback_ = callback; + write_buffer_ = buf; + } + + return result; +} +int PepperTransportSocketAdapter::Write( + net::IOBuffer* buf, int buf_len, const net::CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK(!old_write_callback_ && write_callback_.is_null()); DCHECK(!write_buffer_); if (!transport_.get()) @@ -336,13 +357,20 @@ void PepperTransportSocketAdapter::OnRead(int32_t result) { void PepperTransportSocketAdapter::OnWrite(int32_t result) { DCHECK(CalledOnValidThread()); - DCHECK(write_callback_); + DCHECK(old_write_callback_ || !write_callback_.is_null()); DCHECK(write_buffer_); - net::OldCompletionCallback* callback = write_callback_; - write_callback_ = NULL; - write_buffer_ = NULL; - callback->Run(PPErrorToNetError(result)); + if (old_write_callback_) { + net::OldCompletionCallback* callback = old_write_callback_; + old_write_callback_ = NULL; + write_buffer_ = NULL; + callback->Run(PPErrorToNetError(result)); + } else { + net::CompletionCallback callback = write_callback_; + write_callback_.Reset(); + write_buffer_ = NULL; + callback.Run(PPErrorToNetError(result)); + } } } // namespace protocol diff --git a/remoting/protocol/pepper_transport_socket_adapter.h b/remoting/protocol/pepper_transport_socket_adapter.h index 40b75ce..8453fc4 100644 --- a/remoting/protocol/pepper_transport_socket_adapter.h +++ b/remoting/protocol/pepper_transport_socket_adapter.h @@ -55,6 +55,8 @@ class PepperTransportSocketAdapter : public base::NonThreadSafe, const net::CompletionCallback& callback) OVERRIDE; virtual int Write(net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -98,7 +100,8 @@ class PepperTransportSocketAdapter : public base::NonThreadSafe, net::CompletionCallback read_callback_; scoped_refptr<net::IOBuffer> read_buffer_; - net::OldCompletionCallback* write_callback_; + net::OldCompletionCallback* old_write_callback_; + net::CompletionCallback write_callback_; scoped_refptr<net::IOBuffer> write_buffer_; net::BoundNetLog net_log_; |