diff options
author | ajwong@chromium.org <ajwong@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-12-09 18:43:55 +0000 |
---|---|---|
committer | ajwong@chromium.org <ajwong@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-12-09 18:43:55 +0000 |
commit | 83039bbf2f2ec0e918f7000b5212d104f60f2bb7 (patch) | |
tree | b22dbd0051b57a437a588772a874271f0d02ffdb /net/socket | |
parent | e7456a206fe5b50aeb322ebabd6c26adc869a5fd (diff) | |
download | chromium_src-83039bbf2f2ec0e918f7000b5212d104f60f2bb7.zip chromium_src-83039bbf2f2ec0e918f7000b5212d104f60f2bb7.tar.gz chromium_src-83039bbf2f2ec0e918f7000b5212d104f60f2bb7.tar.bz2 |
Migrate net/socket/socket.h, net/socket/stream_socket.h to base::Bind().
This changes Socket::Read(), Socket::Write, and StreamSocket::Connect() to use CompletionCallback and fixes all users.
BUG=none
TEST=existing.
Review URL: http://codereview.chromium.org/8824006
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@113825 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net/socket')
40 files changed, 733 insertions, 1806 deletions
diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index 779b8dee..1aac330 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -63,18 +63,15 @@ class MockClientSocket : public StreamSocket { // Socket implementation. virtual int Read( - IOBuffer* /* buf */, int len, OldCompletionCallback* /* callback */) { - num_bytes_read_ += len; - return len; - } - virtual int Read( - IOBuffer* /* buf */, int len, const CompletionCallback& /* callback */) { + IOBuffer* /* buf */, int len, + const CompletionCallback& /* callback */) OVERRIDE { num_bytes_read_ += len; return len; } virtual int Write( - IOBuffer* /* buf */, int len, OldCompletionCallback* /* callback */) { + IOBuffer* /* buf */, int len, + const CompletionCallback& /* callback */) OVERRIDE { was_used_to_convey_data_ = true; return len; } @@ -82,11 +79,7 @@ class MockClientSocket : public StreamSocket { virtual bool SetSendBufferSize(int32 size) { return true; } // StreamSocket implementation. - virtual int Connect(OldCompletionCallback* callback) { - connected_ = true; - return OK; - } - virtual int Connect(const net::CompletionCallback& callback) { + virtual int Connect(const CompletionCallback& callback) OVERRIDE { connected_ = true; return OK; } @@ -328,7 +321,7 @@ class TestConnectJob : public ConnectJob { int DoConnect(bool succeed, bool was_async, bool recoverable) { int result = OK; if (succeed) { - socket()->Connect(NULL); + socket()->Connect(CompletionCallback()); } else if (recoverable) { result = ERR_PROXY_AUTH_REQUESTED; } else { @@ -384,6 +377,7 @@ class TestConnectJobFactory } // ConnectJobFactory implementation. + virtual ConnectJob* NewConnectJob( const std::string& group_name, const TestClientSocketPoolBase::Request& request, @@ -677,7 +671,7 @@ TEST_F(ClientSocketPoolBaseTest, AssignIdleSocketToGroup_WarmestSocket) { MockClientSocket* sock = static_cast<MockClientSocket*>(s); CHECK(sock); sockets_[i] = sock; - sock->Read(NULL, 1024 - i, NULL); + sock->Read(NULL, 1024 - i, CompletionCallback()); } ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE); @@ -713,7 +707,7 @@ TEST_F(ClientSocketPoolBaseTest, AssignIdleSocketToGroup_LastAccessedSocket) { MockClientSocket* sock = static_cast<MockClientSocket*>(s); CHECK(sock); sockets_[i] = sock; - sock->Read(NULL, 1024 - i, NULL); + sock->Read(NULL, 1024 - i, CompletionCallback()); } ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE); @@ -2037,7 +2031,7 @@ TEST_F(ClientSocketPoolBaseTest, DisableCleanupTimer) { handle.Reset(); EXPECT_EQ(OK, callback2.WaitForResult()); // Use the socket. - EXPECT_EQ(1, handle2.socket()->Write(NULL, 1, NULL)); + EXPECT_EQ(1, handle2.socket()->Write(NULL, 1, CompletionCallback())); handle2.Reset(); // The idle socket timeout value was set to 10 milliseconds. Wait 100 @@ -2111,7 +2105,7 @@ TEST_F(ClientSocketPoolBaseTest, CleanupTimedOutIdleSockets) { handle.Reset(); EXPECT_EQ(OK, callback2.WaitForResult()); // Use the socket. - EXPECT_EQ(1, handle2.socket()->Write(NULL, 1, NULL)); + EXPECT_EQ(1, handle2.socket()->Write(NULL, 1, CompletionCallback())); handle2.Reset(); // We post all of our delayed tasks with a 2ms delay. I.e. they don't @@ -2872,8 +2866,8 @@ TEST_F(ClientSocketPoolBaseTest, PreferUsedSocketToUnusedSocket) { EXPECT_EQ(OK, callback3.WaitForResult()); // Use the socket. - EXPECT_EQ(1, handle1.socket()->Write(NULL, 1, NULL)); - EXPECT_EQ(1, handle3.socket()->Write(NULL, 1, NULL)); + EXPECT_EQ(1, handle1.socket()->Write(NULL, 1, CompletionCallback())); + EXPECT_EQ(1, handle3.socket()->Write(NULL, 1, CompletionCallback())); handle1.Reset(); handle2.Reset(); diff --git a/net/socket/deterministic_socket_data_unittest.cc b/net/socket/deterministic_socket_data_unittest.cc index cb6ca32..8740c8d 100644 --- a/net/socket/deterministic_socket_data_unittest.cc +++ b/net/socket/deterministic_socket_data_unittest.cc @@ -41,8 +41,8 @@ class DeterministicSocketDataTest : public PlatformTest { void AssertAsyncWriteEquals(const char* data, int len); void AssertWriteReturns(const char* data, int len, int rv); - TestOldCompletionCallback read_callback_; - TestOldCompletionCallback write_callback_; + TestCompletionCallback read_callback_; + TestCompletionCallback write_callback_; StreamSocket* sock_; scoped_refptr<DeterministicSocketData> data_; @@ -61,9 +61,7 @@ class DeterministicSocketDataTest : public PlatformTest { }; DeterministicSocketDataTest::DeterministicSocketDataTest() - : read_callback_(), - write_callback_(), - sock_(NULL), + : sock_(NULL), data_(NULL), read_buf_(NULL), connect_data_(false, OK), @@ -125,7 +123,7 @@ void DeterministicSocketDataTest::AssertAsyncReadEquals(const char* data, void DeterministicSocketDataTest::AssertReadReturns(const char* data, int len, int rv) { read_buf_ = new IOBuffer(len); - ASSERT_EQ(rv, sock_->Read(read_buf_, len, &read_callback_)); + ASSERT_EQ(rv, sock_->Read(read_buf_, len, read_callback_.callback())); } void DeterministicSocketDataTest::AssertReadBufferEquals(const char* data, @@ -139,7 +137,7 @@ void DeterministicSocketDataTest::AssertSyncWriteEquals(const char* data, memcpy(buf->data(), data, len); // Issue the write, which will complete immediately - ASSERT_EQ(len, sock_->Write(buf, len, &write_callback_)); + ASSERT_EQ(len, sock_->Write(buf, len, write_callback_.callback())); } void DeterministicSocketDataTest::AssertAsyncWriteEquals(const char* data, @@ -160,7 +158,7 @@ void DeterministicSocketDataTest::AssertWriteReturns(const char* data, memcpy(buf->data(), data, len); // Issue the read, which will complete asynchronously - ASSERT_EQ(rv, sock_->Write(buf, len, &write_callback_)); + ASSERT_EQ(rv, sock_->Write(buf, len, write_callback_.callback())); } // ----------- Read diff --git a/net/socket/socket.h b/net/socket/socket.h index c185c44..2f1fe50 100644 --- a/net/socket/socket.h +++ b/net/socket/socket.h @@ -30,8 +30,6 @@ class NET_EXPORT Socket { // closed. If the socket is Disconnected before the read completes, the // callback will not be invoked. virtual int Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) = 0; - virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) = 0; // Writes data, up to |buf_len| bytes, to the socket. Note: data may be @@ -47,7 +45,7 @@ class NET_EXPORT Socket { // of the actual buffer that is written to the socket. If the socket is // Disconnected before the write completes, the callback will not be invoked. virtual int Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) = 0; + const CompletionCallback& callback) = 0; // Set the receive buffer size (in bytes) for the socket. // Note: changing this value can affect the TCP window size on some platforms. diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 376d542..6b4e00f 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -176,7 +176,7 @@ MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) { // Check that what we are writing matches the expectation. // Then give the mocked return value. - net::MockWrite* w = &writes_[write_index_++]; + MockWrite* w = &writes_[write_index_++]; w->time_stamp = base::Time::Now(); int result = w->result; if (w->data) { @@ -191,8 +191,8 @@ MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) { std::string actual_data(data.substr(0, w->data_len)); EXPECT_EQ(expected_data, actual_data); if (expected_data != actual_data) - return MockWriteResult(false, net::ERR_UNEXPECTED); - if (result == net::OK) + return MockWriteResult(false, ERR_UNEXPECTED); + if (result == OK) result = w->data_len; } return MockWriteResult(w->async, result); @@ -252,7 +252,7 @@ DelayedSocketData::DelayedSocketData( MockWrite* writes, size_t writes_count) : StaticSocketDataProvider(reads, reads_count, writes, writes_count), write_delay_(write_delay), - ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) { + ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { DCHECK_GE(write_delay_, 0); } @@ -261,7 +261,7 @@ DelayedSocketData::DelayedSocketData( size_t reads_count, MockWrite* writes, size_t writes_count) : StaticSocketDataProvider(reads, reads_count, writes, writes_count), write_delay_(write_delay), - ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) { + ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { DCHECK_GE(write_delay_, 0); set_connect_data(connect); } @@ -284,14 +284,17 @@ MockWriteResult DelayedSocketData::OnWrite(const std::string& data) { MockWriteResult rv = StaticSocketDataProvider::OnWrite(data); // Now that our write has completed, we can allow reads to continue. if (!--write_delay_) - MessageLoop::current()->PostDelayedTask(FROM_HERE, - factory_.NewRunnableMethod(&DelayedSocketData::CompleteRead), 100); + MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(&DelayedSocketData::CompleteRead, + weak_factory_.GetWeakPtr()), + 100); return rv; } void DelayedSocketData::Reset() { set_socket(NULL); - factory_.RevokeAll(); + weak_factory_.InvalidateWeakPtrs(); StaticSocketDataProvider::Reset(); } @@ -303,8 +306,8 @@ void DelayedSocketData::CompleteRead() { OrderedSocketData::OrderedSocketData( MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count) : StaticSocketDataProvider(reads, reads_count, writes, writes_count), - sequence_number_(0), loop_stop_stage_(0), callback_(NULL), - blocked_(false), ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) { + sequence_number_(0), loop_stop_stage_(0), + blocked_(false), ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { } OrderedSocketData::OrderedSocketData( @@ -312,8 +315,8 @@ OrderedSocketData::OrderedSocketData( MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count) : StaticSocketDataProvider(reads, reads_count, writes, writes_count), - sequence_number_(0), loop_stop_stage_(0), callback_(NULL), - blocked_(false), ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) { + sequence_number_(0), loop_stop_stage_(0), + blocked_(false), ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { set_connect_data(connect); } @@ -336,12 +339,12 @@ void OrderedSocketData::EndLoop() { NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ << ": Posting Quit at read " << read_index(); loop_stop_stage_ = sequence_number_; - if (callback_) - callback_->RunWithParams(Tuple1<int>(ERR_IO_PENDING)); + if (!callback_.is_null()) + callback_.Run(ERR_IO_PENDING); } MockRead OrderedSocketData::GetNextRead() { - factory_.RevokeAll(); + weak_factory_.InvalidateWeakPtrs(); blocked_ = false; const MockRead& next_read = StaticSocketDataProvider::PeekRead(); if (next_read.sequence_number & MockRead::STOPLOOP) @@ -374,7 +377,9 @@ MockWriteResult OrderedSocketData::OnWrite(const std::string& data) { // SpdyStream::ReadResponseHeaders() is called, we hit a NOTREACHED(). MessageLoop::current()->PostDelayedTask( FROM_HERE, - factory_.NewRunnableMethod(&OrderedSocketData::CompleteRead), 100); + base::Bind(&OrderedSocketData::CompleteRead, + weak_factory_.GetWeakPtr()), + 100); } return StaticSocketDataProvider::OnWrite(data); } @@ -385,7 +390,7 @@ void OrderedSocketData::Reset() { sequence_number_ = 0; loop_stop_stage_ = 0; set_socket(NULL); - factory_.RevokeAll(); + weak_factory_.InvalidateWeakPtrs(); StaticSocketDataProvider::Reset(); } @@ -593,8 +598,8 @@ MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket( DatagramClientSocket* MockClientSocketFactory::CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, - NetLog* net_log, - const NetLog::Source& source) { + net::NetLog* net_log, + const net::NetLog::Source& source) { SocketDataProvider* data_provider = mock_data_.GetNext(); MockUDPClientSocket* socket = new MockUDPClientSocket(data_provider, net_log); data_provider->set_socket(socket); @@ -605,7 +610,7 @@ DatagramClientSocket* MockClientSocketFactory::CreateDatagramClientSocket( StreamSocket* MockClientSocketFactory::CreateTransportClientSocket( const AddressList& addresses, net::NetLog* net_log, - const NetLog::Source& source) { + const net::NetLog::Source& source) { SocketDataProvider* data_provider = mock_data_.GetNext(); MockTCPClientSocket* socket = new MockTCPClientSocket(addresses, net_log, data_provider); @@ -633,7 +638,7 @@ void MockClientSocketFactory::ClearSSLSessionCache() { MockClientSocket::MockClientSocket(net::NetLog* net_log) : ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)), connected_(false), - net_log_(NetLog::Source(), net_log) { + net_log_(net::NetLog::Source(), net_log) { } bool MockClientSocket::SetReceiveBufferSize(int32 size) { @@ -676,12 +681,12 @@ const BoundNetLog& MockClientSocket::NetLog() const { return net_log_; } -void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { +void MockClientSocket::GetSSLInfo(SSLInfo* ssl_info) { NOTREACHED(); } void MockClientSocket::GetSSLCertRequestInfo( - net::SSLCertRequestInfo* cert_request_info) { + SSLCertRequestInfo* cert_request_info) { } int MockClientSocket::ExportKeyingMaterial(const base::StringPiece& label, @@ -701,44 +706,32 @@ MockClientSocket::GetNextProto(std::string* proto, std::string* server_protos) { MockClientSocket::~MockClientSocket() {} -void MockClientSocket::RunCallbackAsync(net::OldCompletionCallback* callback, - int result) { - MessageLoop::current()->PostTask(FROM_HERE, - base::Bind(&MockClientSocket::RunOldCallback, weak_factory_.GetWeakPtr(), - callback, result)); -} -void MockClientSocket::RunCallbackAsync(const net::CompletionCallback& callback, +void MockClientSocket::RunCallbackAsync(const CompletionCallback& callback, int result) { MessageLoop::current()->PostTask(FROM_HERE, base::Bind(&MockClientSocket::RunCallback, weak_factory_.GetWeakPtr(), callback, result)); } -void MockClientSocket::RunOldCallback(net::OldCompletionCallback* callback, - int result) { - if (callback) - callback->Run(result); -} void MockClientSocket::RunCallback(const net::CompletionCallback& callback, int result) { if (!callback.is_null()) callback.Run(result); } -MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses, +MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses, net::NetLog* net_log, - net::SocketDataProvider* data) + SocketDataProvider* data) : MockClientSocket(net_log), addresses_(addresses), data_(data), read_offset_(0), num_bytes_read_(0), - read_data_(false, net::ERR_UNEXPECTED), + read_data_(false, ERR_UNEXPECTED), need_read_data_(true), peer_closed_connection_(false), pending_buf_(NULL), pending_buf_len_(0), - old_pending_callback_(NULL), was_used_to_convey_data_(false) { DCHECK(data_); data_->Reset(); @@ -746,42 +739,10 @@ MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses, MockTCPClientSocket::~MockTCPClientSocket() {} -int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) { +int MockTCPClientSocket::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { if (!connected_) - return net::ERR_UNEXPECTED; - - // If the buffer is already in use, a read is already in progress! - DCHECK(pending_buf_ == NULL); - - // Store our async IO data. - pending_buf_ = buf; - pending_buf_len_ = buf_len; - old_pending_callback_ = callback; - - if (need_read_data_) { - read_data_ = data_->GetNextRead(); - if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { - // This MockRead is just a marker to instruct us to set - // peer_closed_connection_. Skip it and get the next one. - read_data_ = data_->GetNextRead(); - peer_closed_connection_ = true; - } - // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility - // to complete the async IO manually later (via OnReadComplete). - if (read_data_.result == ERR_IO_PENDING) { - DCHECK(callback); // We need to be using async IO in this case. - return ERR_IO_PENDING; - } - need_read_data_ = false; - } - - return CompleteRead(); -} -int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) { - if (!connected_) - return net::ERR_UNEXPECTED; + return ERR_UNEXPECTED; // If the buffer is already in use, a read is already in progress! DCHECK(pending_buf_ == NULL); @@ -812,55 +773,41 @@ int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, return CompleteRead(); } -int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) { +int MockTCPClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { DCHECK(buf); DCHECK_GT(buf_len, 0); if (!connected_) - return net::ERR_UNEXPECTED; + return ERR_UNEXPECTED; std::string data(buf->data(), buf_len); - net::MockWriteResult write_result = data_->OnWrite(data); + MockWriteResult write_result = data_->OnWrite(data); was_used_to_convey_data_ = true; if (write_result.async) { RunCallbackAsync(callback, write_result.result); - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } return write_result.result; } -int MockTCPClientSocket::Connect(net::OldCompletionCallback* callback) { - if (connected_) - return net::OK; - connected_ = true; - peer_closed_connection_ = false; - if (data_->connect_data().async) { - RunCallbackAsync(callback, data_->connect_data().result); - return net::ERR_IO_PENDING; - } - return data_->connect_data().result; -} -int MockTCPClientSocket::Connect(const net::CompletionCallback& callback) { +int MockTCPClientSocket::Connect(const CompletionCallback& callback) { if (connected_) - return net::OK; - + return OK; connected_ = true; peer_closed_connection_ = false; if (data_->connect_data().async) { RunCallbackAsync(callback, data_->connect_data().result); - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } - return data_->connect_data().result; } void MockTCPClientSocket::Disconnect() { MockClientSocket::Disconnect(); - old_pending_callback_ = NULL; pending_callback_.Reset(); } @@ -912,15 +859,9 @@ void MockTCPClientSocket::OnReadComplete(const MockRead& data) { // let CompleteRead() schedule a callback. read_data_.async = false; - if (old_pending_callback_) { - net::OldCompletionCallback* callback = old_pending_callback_; - int rv = CompleteRead(); - RunOldCallback(callback, rv); - } else { - net::CompletionCallback callback = pending_callback_; - int rv = CompleteRead(); - RunCallback(callback, rv); - } + CompletionCallback callback = pending_callback_; + int rv = CompleteRead(); + RunCallback(callback, rv); } int MockTCPClientSocket::CompleteRead() { @@ -930,13 +871,11 @@ int MockTCPClientSocket::CompleteRead() { was_used_to_convey_data_ = true; // Save the pending async IO data and reset our |pending_| state. - net::IOBuffer* buf = pending_buf_; + IOBuffer* buf = pending_buf_; int buf_len = pending_buf_len_; - net::OldCompletionCallback* old_callback = old_pending_callback_; - net::CompletionCallback callback = pending_callback_; + CompletionCallback callback = pending_callback_; pending_buf_ = NULL; pending_buf_len_ = 0; - old_pending_callback_ = NULL; pending_callback_.Reset(); int result = read_data_.result; @@ -958,27 +897,22 @@ int MockTCPClientSocket::CompleteRead() { } if (read_data_.async) { - DCHECK(old_callback || !callback.is_null()); - if (old_callback) - RunCallbackAsync(old_callback, result); - else - RunCallbackAsync(callback, result); - return net::ERR_IO_PENDING; + DCHECK(!callback.is_null()); + RunCallbackAsync(callback, result); + return ERR_IO_PENDING; } return result; } DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket( - net::NetLog* net_log, net::DeterministicSocketData* data) + net::NetLog* net_log, DeterministicSocketData* data) : MockClientSocket(net_log), write_pending_(false), - write_callback_(NULL), write_result_(0), read_data_(), read_buf_(NULL), read_buf_len_(0), read_pending_(false), - old_read_callback_(NULL), data_(data), was_used_to_convey_data_(false) {} @@ -987,7 +921,7 @@ DeterministicMockTCPClientSocket::~DeterministicMockTCPClientSocket() {} void DeterministicMockTCPClientSocket::CompleteWrite() { was_used_to_convey_data_ = true; write_pending_ = false; - write_callback_->Run(write_result_); + write_callback_.Run(write_result_); } int DeterministicMockTCPClientSocket::CompleteRead() { @@ -1012,32 +946,29 @@ int DeterministicMockTCPClientSocket::CompleteRead() { if (read_pending_) { read_pending_ = false; - if (old_read_callback_) - old_read_callback_->Run(result); - else - read_callback_.Run(result); + read_callback_.Run(result); } return result; } int DeterministicMockTCPClientSocket::Write( - net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { + IOBuffer* buf, int buf_len, const CompletionCallback& callback) { DCHECK(buf); DCHECK_GT(buf_len, 0); if (!connected_) - return net::ERR_UNEXPECTED; + return ERR_UNEXPECTED; std::string data(buf->data(), buf_len); - net::MockWriteResult write_result = data_->OnWrite(data); + MockWriteResult write_result = data_->OnWrite(data); if (write_result.async) { write_callback_ = callback; write_result_ = write_result.result; - DCHECK(write_callback_ != NULL); + DCHECK(!write_callback_.is_null()); write_pending_ = true; - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } was_used_to_convey_data_ = true; @@ -1046,32 +977,9 @@ int DeterministicMockTCPClientSocket::Write( } int DeterministicMockTCPClientSocket::Read( - net::IOBuffer* buf, int buf_len, net::OldCompletionCallback* callback) { - if (!connected_) - return net::ERR_UNEXPECTED; - - read_data_ = data_->GetNextRead(); - // The buffer should always be big enough to contain all the MockRead data. To - // use small buffers, split the data into multiple MockReads. - DCHECK_LE(read_data_.data_len, buf_len); - - read_buf_ = buf; - read_buf_len_ = buf_len; - old_read_callback_ = callback; - - if (read_data_.async || (read_data_.result == ERR_IO_PENDING)) { - read_pending_ = true; - DCHECK(old_read_callback_); - return ERR_IO_PENDING; - } - - was_used_to_convey_data_ = true; - return CompleteRead(); -} -int DeterministicMockTCPClientSocket::Read( - net::IOBuffer* buf, int buf_len, const net::CompletionCallback& callback) { + IOBuffer* buf, int buf_len, const CompletionCallback& callback) { if (!connected_) - return net::ERR_UNEXPECTED; + return ERR_UNEXPECTED; read_data_ = data_->GetNextRead(); // The buffer should always be big enough to contain all the MockRead data. To @@ -1094,27 +1002,14 @@ int DeterministicMockTCPClientSocket::Read( // TODO(erikchen): Support connect sequencing. int DeterministicMockTCPClientSocket::Connect( - net::OldCompletionCallback* callback) { - if (connected_) - return net::OK; - connected_ = true; - if (data_->connect_data().async) { - RunCallbackAsync(callback, data_->connect_data().result); - return net::ERR_IO_PENDING; - } - return data_->connect_data().result; -} -int DeterministicMockTCPClientSocket::Connect( - const net::CompletionCallback& callback) { + const CompletionCallback& callback) { if (connected_) - return net::OK; - + return OK; connected_ = true; if (data_->connect_data().async) { RunCallbackAsync(callback, data_->connect_data().result); - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } - return data_->connect_data().result; } @@ -1148,67 +1043,22 @@ base::TimeDelta DeterministicMockTCPClientSocket::GetConnectTimeMicros() const { void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {} -class MockSSLClientSocket::OldConnectCallback - : public net::OldCompletionCallbackImpl< - MockSSLClientSocket::OldConnectCallback> { - public: - OldConnectCallback(MockSSLClientSocket *ssl_client_socket, - net::OldCompletionCallback* user_callback, - int rv) - : ALLOW_THIS_IN_INITIALIZER_LIST( - net::OldCompletionCallbackImpl< - MockSSLClientSocket::OldConnectCallback>( - this, &OldConnectCallback::Wrapper)), - ssl_client_socket_(ssl_client_socket), - user_callback_(user_callback), - rv_(rv) { - } - - private: - void Wrapper(int rv) { - if (rv_ == net::OK) - ssl_client_socket_->connected_ = true; - user_callback_->Run(rv_); - delete this; - } - - MockSSLClientSocket* ssl_client_socket_; - net::OldCompletionCallback* user_callback_; - int rv_; -}; -class MockSSLClientSocket::ConnectCallback { - public: - ConnectCallback(MockSSLClientSocket *ssl_client_socket, - const CompletionCallback& user_callback, - int rv) - : ALLOW_THIS_IN_INITIALIZER_LIST(callback_( - base::Bind(&ConnectCallback::Wrapper, base::Unretained(this)))), - ssl_client_socket_(ssl_client_socket), - user_callback_(user_callback), - rv_(rv) { - } - - const CompletionCallback& callback() const { return callback_; } - - private: - void Wrapper(int rv) { - if (rv_ == net::OK) - ssl_client_socket_->connected_ = true; - user_callback_.Run(rv_); - } - - CompletionCallback callback_; - MockSSLClientSocket* ssl_client_socket_; - CompletionCallback user_callback_; - int rv_; -}; +// static +void MockSSLClientSocket::ConnectCallback( + MockSSLClientSocket *ssl_client_socket, + const CompletionCallback& callback, + int rv) { + if (rv == OK) + ssl_client_socket->connected_ = true; + callback.Run(rv); +} MockSSLClientSocket::MockSSLClientSocket( - net::ClientSocketHandle* transport_socket, + ClientSocketHandle* transport_socket, const HostPortPair& host_port_pair, - const net::SSLConfig& ssl_config, + const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - net::SSLSocketDataProvider* data) + SSLSocketDataProvider* data) : MockClientSocket(transport_socket->socket()->NetLog().net_log()), transport_(transport_socket), data_(data), @@ -1222,45 +1072,25 @@ MockSSLClientSocket::~MockSSLClientSocket() { Disconnect(); } -int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) { - return transport_->socket()->Read(buf, buf_len, callback); -} -int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) { +int MockSSLClientSocket::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { return transport_->socket()->Read(buf, buf_len, callback); } -int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) { +int MockSSLClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { return transport_->socket()->Write(buf, buf_len, callback); } -int MockSSLClientSocket::Connect(net::OldCompletionCallback* callback) { - OldConnectCallback* connect_callback = new OldConnectCallback( - this, callback, data_->connect.result); - int rv = transport_->socket()->Connect(connect_callback); - if (rv == net::OK) { - delete connect_callback; - if (data_->connect.result == net::OK) - connected_ = true; - if (data_->connect.async) { - RunCallbackAsync(callback, data_->connect.result); - return net::ERR_IO_PENDING; - } - return data_->connect.result; - } - return rv; -} -int MockSSLClientSocket::Connect(const net::CompletionCallback& callback) { - ConnectCallback connect_callback(this, callback, data_->connect.result); - int rv = transport_->socket()->Connect(connect_callback.callback()); - if (rv == net::OK) { - if (data_->connect.result == net::OK) +int MockSSLClientSocket::Connect(const CompletionCallback& callback) { + int rv = transport_->socket()->Connect( + base::Bind(&ConnectCallback, base::Unretained(this), callback)); + if (rv == OK) { + if (data_->connect.result == OK) connected_ = true; if (data_->connect.async) { RunCallbackAsync(callback, data_->connect.result); - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } return data_->connect.result; } @@ -1293,14 +1123,14 @@ base::TimeDelta MockSSLClientSocket::GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } -void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { +void MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->Reset(); ssl_info->cert = data_->cert; ssl_info->client_cert_sent = data_->client_cert_sent; } void MockSSLClientSocket::GetSSLCertRequestInfo( - net::SSLCertRequestInfo* cert_request_info) { + SSLCertRequestInfo* cert_request_info) { DCHECK(cert_request_info); if (data_->cert_request_info) { cert_request_info->host_and_port = @@ -1338,12 +1168,11 @@ MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data, : connected_(false), data_(data), read_offset_(0), - read_data_(false, net::ERR_UNEXPECTED), + read_data_(false, ERR_UNEXPECTED), need_read_data_(true), pending_buf_(NULL), pending_buf_len_(0), - old_pending_callback_(NULL), - net_log_(NetLog::Source(), net_log), + net_log_(net::NetLog::Source(), net_log), ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { DCHECK(data_); data_->Reset(); @@ -1351,36 +1180,10 @@ MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data, MockUDPClientSocket::~MockUDPClientSocket() {} -int MockUDPClientSocket::Read(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) { +int MockUDPClientSocket::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { if (!connected_) - return net::ERR_UNEXPECTED; - - // If the buffer is already in use, a read is already in progress! - DCHECK(pending_buf_ == NULL); - - // Store our async IO data. - pending_buf_ = buf; - pending_buf_len_ = buf_len; - old_pending_callback_ = callback; - - if (need_read_data_) { - read_data_ = data_->GetNextRead(); - // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility - // to complete the async IO manually later (via OnReadComplete). - if (read_data_.result == ERR_IO_PENDING) { - DCHECK(callback); // We need to be using async IO in this case. - return ERR_IO_PENDING; - } - need_read_data_ = false; - } - - return CompleteRead(); -} -int MockUDPClientSocket::Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) { - if (!connected_) - return net::ERR_UNEXPECTED; + return ERR_UNEXPECTED; // If the buffer is already in use, a read is already in progress! DCHECK(pending_buf_ == NULL); @@ -1405,8 +1208,8 @@ int MockUDPClientSocket::Read(net::IOBuffer* buf, int buf_len, return CompleteRead(); } -int MockUDPClientSocket::Write(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) { +int MockUDPClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { DCHECK(buf); DCHECK_GT(buf_len, 0); @@ -1469,15 +1272,9 @@ void MockUDPClientSocket::OnReadComplete(const MockRead& data) { // let CompleteRead() schedule a callback. read_data_.async = false; - if (old_pending_callback_) { - net::OldCompletionCallback* callback = old_pending_callback_; - int rv = CompleteRead(); - RunOldCallback(callback, rv); - } else { - net::CompletionCallback callback = pending_callback_; - int rv = CompleteRead(); - RunCallback(callback, rv); - } + net::CompletionCallback callback = pending_callback_; + int rv = CompleteRead(); + RunCallback(callback, rv); } int MockUDPClientSocket::CompleteRead() { @@ -1485,14 +1282,11 @@ int MockUDPClientSocket::CompleteRead() { DCHECK(pending_buf_len_ > 0); // Save the pending async IO data and reset our |pending_| state. - net::IOBuffer* buf = pending_buf_; + IOBuffer* buf = pending_buf_; int buf_len = pending_buf_len_; - net::OldCompletionCallback* old_callback = old_pending_callback_; - net::CompletionCallback callback = pending_callback_; + CompletionCallback callback = pending_callback_; pending_buf_ = NULL; pending_buf_len_ = 0; - old_pending_callback_ = NULL; - pending_callback_.Reset(); pending_callback_.Reset(); int result = read_data_.result; @@ -1513,35 +1307,22 @@ int MockUDPClientSocket::CompleteRead() { } if (read_data_.async) { - DCHECK(old_callback || !callback.is_null()); - if (old_callback) - RunCallbackAsync(old_callback, result); - else - RunCallbackAsync(callback, result); - return net::ERR_IO_PENDING; + DCHECK(!callback.is_null()); + RunCallbackAsync(callback, result); + return ERR_IO_PENDING; } return result; } -void MockUDPClientSocket::RunCallbackAsync(net::OldCompletionCallback* callback, +void MockUDPClientSocket::RunCallbackAsync(const CompletionCallback& callback, int result) { - MessageLoop::current()->PostTask(FROM_HERE, - base::Bind(&MockUDPClientSocket::RunOldCallback, - weak_factory_.GetWeakPtr(), callback, result)); -} -void MockUDPClientSocket::RunCallbackAsync( - const net::CompletionCallback& callback, int result) { - MessageLoop::current()->PostTask(FROM_HERE, + MessageLoop::current()->PostTask( + FROM_HERE, base::Bind(&MockUDPClientSocket::RunCallback, weak_factory_.GetWeakPtr(), callback, result)); } -void MockUDPClientSocket::RunOldCallback(net::OldCompletionCallback* callback, - int result) { - if (callback) - callback->Run(result); -} -void MockUDPClientSocket::RunCallback(const net::CompletionCallback& callback, +void MockUDPClientSocket::RunCallback(const CompletionCallback& callback, int result) { if (!callback.is_null()) callback.Run(result); @@ -1614,20 +1395,19 @@ void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) { MockTransportClientSocketPool::MockConnectJob::MockConnectJob( StreamSocket* socket, ClientSocketHandle* handle, - OldCompletionCallback* callback) + const CompletionCallback& callback) : socket_(socket), handle_(handle), - user_callback_(callback), - ALLOW_THIS_IN_INITIALIZER_LIST( - connect_callback_(this, &MockConnectJob::OnConnect)) { + user_callback_(callback) { } MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() {} int MockTransportClientSocketPool::MockConnectJob::Connect() { - int rv = socket_->Connect(&connect_callback_); + int rv = socket_->Connect(base::Bind(&MockConnectJob::OnConnect, + base::Unretained(this))); if (rv == OK) { - user_callback_ = NULL; + user_callback_.Reset(); OnConnect(OK); } return rv; @@ -1639,7 +1419,7 @@ bool MockTransportClientSocketPool::MockConnectJob::CancelHandle( return false; socket_.reset(); handle_ = NULL; - user_callback_ = NULL; + user_callback_.Reset(); return true; } @@ -1654,10 +1434,10 @@ void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) { handle_ = NULL; - if (user_callback_) { - OldCompletionCallback* callback = user_callback_; - user_callback_ = NULL; - callback->Run(rv); + if (!user_callback_.is_null()) { + CompletionCallback callback = user_callback_; + user_callback_.Reset(); + callback.Run(rv); } } @@ -1675,15 +1455,21 @@ MockTransportClientSocketPool::MockTransportClientSocketPool( MockTransportClientSocketPool::~MockTransportClientSocketPool() {} -int MockTransportClientSocketPool::RequestSocket(const std::string& group_name, - const void* socket_params, - RequestPriority priority, - ClientSocketHandle* handle, - OldCompletionCallback* callback, - const BoundNetLog& net_log) { +int MockTransportClientSocketPool::RequestSocket( + const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + OldCompletionCallback* callback, + const BoundNetLog& net_log) { StreamSocket* socket = client_socket_factory_->CreateTransportClientSocket( AddressList(), net_log.net_log(), net::NetLog::Source()); - MockConnectJob* job = new MockConnectJob(socket, handle, callback); + CompletionCallback cb; + if (callback) { + cb = base::Bind(&OldCompletionCallback::Run<int>, + base::Unretained(callback)); + } + MockConnectJob* job = new MockConnectJob(socket, handle, cb); job_list_.push_back(job); handle->set_pool_id(1); return job->Connect(); @@ -1736,7 +1522,7 @@ DatagramClientSocket* DeterministicMockClientSocketFactory::CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, - NetLog* net_log, + net::NetLog* net_log, const NetLog::Source& source) { NOTREACHED(); return NULL; diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 73ffc3d..84c9c05 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -19,7 +19,6 @@ #include "base/memory/weak_ptr.h" #include "base/string16.h" #include "net/base/address_list.h" -#include "net/base/completion_callback.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/base/net_log.h" @@ -267,7 +266,7 @@ struct SSLSocketDataProvider { std::string server_protos; bool was_npn_negotiated; bool client_cert_sent; - net::SSLCertRequestInfo* cert_request_info; + SSLCertRequestInfo* cert_request_info; scoped_refptr<X509Certificate> cert; }; @@ -310,7 +309,7 @@ class DelayedSocketData : public StaticSocketDataProvider, private: int write_delay_; - ScopedRunnableMethodFactory<DelayedSocketData> factory_; + base::WeakPtrFactory<DelayedSocketData> weak_factory_; }; // A DataProvider where the reads are ordered. @@ -345,7 +344,7 @@ class OrderedSocketData : public StaticSocketDataProvider, MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count); - void SetOldCompletionCallback(OldCompletionCallback* callback) { + void SetCompletionCallback(const CompletionCallback& callback) { callback_ = callback; } @@ -364,9 +363,9 @@ class OrderedSocketData : public StaticSocketDataProvider, int sequence_number_; int loop_stop_stage_; - OldCompletionCallback* callback_; + CompletionCallback callback_; bool blocked_; - ScopedRunnableMethodFactory<OrderedSocketData> factory_; + base::WeakPtrFactory<OrderedSocketData> weak_factory_; }; class DeterministicMockTCPClientSocket; @@ -581,23 +580,21 @@ class MockClientSocketFactory : public ClientSocketFactory { std::vector<MockSSLClientSocket*> ssl_client_sockets_; }; -class MockClientSocket : public net::SSLClientSocket { +class MockClientSocket : public SSLClientSocket { public: + // TODO(ajwong): Why do we need net::NetLog? explicit MockClientSocket(net::NetLog* net_log); // Socket implementation. - virtual int Read(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) = 0; - virtual int Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) = 0; - virtual int Write(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) = 0; + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) = 0; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) = 0; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; // StreamSocket implementation. - virtual int Connect(net::OldCompletionCallback* callback) = 0; - virtual int Connect(const net::CompletionCallback& callback) = 0; + virtual int Connect(const CompletionCallback& callback) = 0; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -608,9 +605,9 @@ class MockClientSocket : public net::SSLClientSocket { virtual void SetOmniboxSpeculation() OVERRIDE {} // SSLClientSocket implementation. - virtual void GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE; + virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; virtual void GetSSLCertRequestInfo( - net::SSLCertRequestInfo* cert_request_info) OVERRIDE; + SSLCertRequestInfo* cert_request_info) OVERRIDE; virtual int ExportKeyingMaterial(const base::StringPiece& label, const base::StringPiece& context, unsigned char *out, @@ -620,38 +617,33 @@ class MockClientSocket : public net::SSLClientSocket { protected: virtual ~MockClientSocket(); - void RunCallbackAsync(net::OldCompletionCallback* callback, int result); - void RunCallbackAsync(const net::CompletionCallback& callback, int result); - void RunOldCallback(net::OldCompletionCallback*, int result); - void RunCallback(const net::CompletionCallback&, int result); + void RunCallbackAsync(const CompletionCallback& callback, int result); + void RunCallback(const CompletionCallback& callback, int result); base::WeakPtrFactory<MockClientSocket> weak_factory_; // True if Connect completed successfully and Disconnect hasn't been called. bool connected_; - net::BoundNetLog net_log_; + BoundNetLog net_log_; }; class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { public: - MockTCPClientSocket(const net::AddressList& addresses, net::NetLog* net_log, - net::SocketDataProvider* socket); + MockTCPClientSocket(const AddressList& addresses, net::NetLog* net_log, + SocketDataProvider* socket); virtual ~MockTCPClientSocket(); - net::AddressList addresses() const { return addresses_; } + AddressList addresses() const { return addresses_; } // Socket implementation. - virtual int Read(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) OVERRIDE; - virtual int Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE; - virtual int Write(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) OVERRIDE; + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; // StreamSocket implementation. - virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; - virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; + virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -667,12 +659,12 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { private: int CompleteRead(); - net::AddressList addresses_; + AddressList addresses_; - net::SocketDataProvider* data_; + SocketDataProvider* data_; int read_offset_; int num_bytes_read_; - net::MockRead read_data_; + MockRead read_data_; bool need_read_data_; // True if the peer has closed the connection. This allows us to simulate @@ -681,10 +673,9 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { bool peer_closed_connection_; // While an asynchronous IO is pending, we save our user-buffer state. - net::IOBuffer* pending_buf_; + IOBuffer* pending_buf_; int pending_buf_len_; - net::OldCompletionCallback* old_pending_callback_; - net::CompletionCallback pending_callback_; + CompletionCallback pending_callback_; bool was_used_to_convey_data_; }; @@ -693,7 +684,7 @@ class DeterministicMockTCPClientSocket : public MockClientSocket, public base::SupportsWeakPtr<DeterministicMockTCPClientSocket> { public: DeterministicMockTCPClientSocket(net::NetLog* net_log, - net::DeterministicSocketData* data); + DeterministicSocketData* data); virtual ~DeterministicMockTCPClientSocket(); bool write_pending() const { return write_pending_; } @@ -703,16 +694,14 @@ class DeterministicMockTCPClientSocket : public MockClientSocket, int CompleteRead(); // Socket implementation. - virtual int Write(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) OVERRIDE; - virtual int Read(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) OVERRIDE; - virtual int Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE; + // Socket: + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; // StreamSocket implementation. - virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; - virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; + virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -726,41 +715,37 @@ class DeterministicMockTCPClientSocket : public MockClientSocket, private: bool write_pending_; - net::OldCompletionCallback* write_callback_; + CompletionCallback write_callback_; int write_result_; - net::MockRead read_data_; + MockRead read_data_; - net::IOBuffer* read_buf_; + IOBuffer* read_buf_; int read_buf_len_; bool read_pending_; - net::OldCompletionCallback* old_read_callback_; - net::CompletionCallback read_callback_; - net::DeterministicSocketData* data_; + CompletionCallback read_callback_; + DeterministicSocketData* data_; bool was_used_to_convey_data_; }; class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { public: MockSSLClientSocket( - net::ClientSocketHandle* transport_socket, + ClientSocketHandle* transport_socket, const HostPortPair& host_and_port, - const net::SSLConfig& ssl_config, + const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - net::SSLSocketDataProvider* socket); + SSLSocketDataProvider* socket); virtual ~MockSSLClientSocket(); // Socket implementation. - virtual int Read(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) OVERRIDE; - virtual int Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE; - virtual int Write(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) OVERRIDE; + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; // StreamSocket implementation. - virtual int Connect(net::OldCompletionCallback* callback) OVERRIDE; - virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; + virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool WasEverUsed() const OVERRIDE; @@ -769,9 +754,9 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; // SSLClientSocket implementation. - virtual void GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE; + virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; virtual void GetSSLCertRequestInfo( - net::SSLCertRequestInfo* cert_request_info) OVERRIDE; + SSLCertRequestInfo* cert_request_info) OVERRIDE; virtual NextProtoStatus GetNextProto(std::string* proto, std::string* server_protos) OVERRIDE; virtual bool was_npn_negotiated() const OVERRIDE; @@ -781,11 +766,12 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { virtual void OnReadComplete(const MockRead& data) OVERRIDE; private: - class OldConnectCallback; - class ConnectCallback; + static void ConnectCallback(MockSSLClientSocket *ssl_client_socket, + const CompletionCallback& callback, + int rv); scoped_ptr<ClientSocketHandle> transport_; - net::SSLSocketDataProvider* data_; + SSLSocketDataProvider* data_; bool is_npn_state_set_; bool new_npn_value_; bool was_used_to_convey_data_; @@ -798,12 +784,10 @@ class MockUDPClientSocket : public DatagramClientSocket, virtual ~MockUDPClientSocket(); // Socket implementation. - virtual int Read(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) OVERRIDE; - virtual int Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE; - virtual int Write(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) OVERRIDE; + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -822,22 +806,19 @@ class MockUDPClientSocket : public DatagramClientSocket, private: int CompleteRead(); - void RunCallbackAsync(net::OldCompletionCallback* callback, int result); - void RunCallbackAsync(const net::CompletionCallback& callback, int result); - void RunOldCallback(net::OldCompletionCallback* callback, int result); - void RunCallback(const net::CompletionCallback& callback, int result); + void RunCallbackAsync(const CompletionCallback& callback, int result); + void RunCallback(const CompletionCallback& callback, int result); bool connected_; SocketDataProvider* data_; int read_offset_; - net::MockRead read_data_; + MockRead read_data_; bool need_read_data_; // While an asynchronous IO is pending, we save our user-buffer state. - net::IOBuffer* pending_buf_; + IOBuffer* pending_buf_; int pending_buf_len_; - net::OldCompletionCallback* old_pending_callback_; - net::CompletionCallback pending_callback_; + CompletionCallback pending_callback_; BoundNetLog net_log_; @@ -926,7 +907,7 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { class MockConnectJob { public: MockConnectJob(StreamSocket* socket, ClientSocketHandle* handle, - OldCompletionCallback* callback); + const CompletionCallback& callback); ~MockConnectJob(); int Connect(); @@ -937,8 +918,7 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { scoped_ptr<StreamSocket> socket_; ClientSocketHandle* handle_; - OldCompletionCallback* user_callback_; - OldCompletionCallbackImpl<MockConnectJob> connect_callback_; + CompletionCallback user_callback_; DISALLOW_COPY_AND_ASSIGN(MockConnectJob); }; diff --git a/net/socket/socks5_client_socket.cc b/net/socket/socks5_client_socket.cc index ea5fc7a..b8b3439 100644 --- a/net/socket/socks5_client_socket.cc +++ b/net/socket/socks5_client_socket.cc @@ -31,10 +31,10 @@ SOCKS5ClientSocket::SOCKS5ClientSocket( ClientSocketHandle* transport_socket, const HostResolver::RequestInfo& req_info) : ALLOW_THIS_IN_INITIALIZER_LIST( - io_callback_(this, &SOCKS5ClientSocket::OnIOComplete)), + io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete, + base::Unretained(this)))), transport_(transport_socket), next_state_(STATE_NONE), - old_user_callback_(NULL), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), @@ -47,10 +47,10 @@ SOCKS5ClientSocket::SOCKS5ClientSocket( StreamSocket* transport_socket, const HostResolver::RequestInfo& req_info) : ALLOW_THIS_IN_INITIALIZER_LIST( - io_callback_(this, &SOCKS5ClientSocket::OnIOComplete)), + io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete, + base::Unretained(this)))), transport_(new ClientSocketHandle()), next_state_(STATE_NONE), - old_user_callback_(NULL), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), @@ -64,11 +64,11 @@ SOCKS5ClientSocket::~SOCKS5ClientSocket() { Disconnect(); } -int SOCKS5ClientSocket::Connect(OldCompletionCallback* callback) { +int SOCKS5ClientSocket::Connect(const CompletionCallback& callback) { DCHECK(transport_.get()); DCHECK(transport_->socket()); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_ && user_callback_.is_null()); + DCHECK(user_callback_.is_null()); // If already connected, then just return OK. if (completed_handshake_) @@ -81,35 +81,12 @@ int SOCKS5ClientSocket::Connect(OldCompletionCallback* callback) { int rv = DoLoop(OK); if (rv == ERR_IO_PENDING) { - old_user_callback_ = callback; + user_callback_ = callback; } else { net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_CONNECT, rv); } return rv; } -int SOCKS5ClientSocket::Connect(const CompletionCallback& callback) { - DCHECK(transport_.get()); - DCHECK(transport_->socket()); - DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_ && user_callback_.is_null()); - - // If already connected, then just return OK. - if (completed_handshake_) - return OK; - - net_log_.BeginEvent(NetLog::TYPE_SOCKS5_CONNECT, NULL); - - next_state_ = STATE_GREET_WRITE; - buffer_.clear(); - - int rv = DoLoop(OK); - if (rv == ERR_IO_PENDING) - user_callback_ = callback; - else - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_CONNECT, rv); - - return rv; -} void SOCKS5ClientSocket::Disconnect() { completed_handshake_ = false; @@ -118,7 +95,6 @@ void SOCKS5ClientSocket::Disconnect() { // Reset other states to make sure they aren't mistakenly used later. // These are the states initialized by Connect(). next_state_ = STATE_NONE; - old_user_callback_ = NULL; user_callback_.Reset(); } @@ -185,18 +161,10 @@ base::TimeDelta SOCKS5ClientSocket::GetConnectTimeMicros() const { // Read is called by the transport layer above to read. This can only be done // if the SOCKS handshake is complete. int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - DCHECK(completed_handshake_); - DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_ && user_callback_.is_null()); - - return transport_->socket()->Read(buf, buf_len, callback); -} -int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_ && user_callback_.is_null()); + DCHECK(user_callback_.is_null()); return transport_->socket()->Read(buf, buf_len, callback); } @@ -204,10 +172,10 @@ int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len, // Write is called by the transport layer. This can only be done if the // SOCKS handshake is complete. int SOCKS5ClientSocket::Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + const CompletionCallback& callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_); + DCHECK(user_callback_.is_null()); return transport_->socket()->Write(buf, buf_len, callback); } @@ -222,19 +190,13 @@ bool SOCKS5ClientSocket::SetSendBufferSize(int32 size) { void SOCKS5ClientSocket::DoCallback(int result) { DCHECK_NE(ERR_IO_PENDING, result); - DCHECK(old_user_callback_ || !user_callback_.is_null()); + DCHECK(!user_callback_.is_null()); // Since Run() may result in Read being called, // clear user_callback_ up front. - if (old_user_callback_) { - OldCompletionCallback* c = old_user_callback_; - old_user_callback_ = NULL; - c->Run(result); - } else { - CompletionCallback c = user_callback_; - user_callback_.Reset(); - c.Run(result); - } + CompletionCallback c = user_callback_; + user_callback_.Reset(); + c.Run(result); } void SOCKS5ClientSocket::OnIOComplete(int result) { @@ -323,7 +285,7 @@ int SOCKS5ClientSocket::DoGreetWrite() { memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_], handshake_buf_len); return transport_->socket()->Write(handshake_buf_, handshake_buf_len, - &io_callback_); + io_callback_); } int SOCKS5ClientSocket::DoGreetWriteComplete(int result) { @@ -346,7 +308,7 @@ int SOCKS5ClientSocket::DoGreetRead() { size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_; handshake_buf_ = new IOBuffer(handshake_buf_len); return transport_->socket()->Read(handshake_buf_, handshake_buf_len, - &io_callback_); + io_callback_); } int SOCKS5ClientSocket::DoGreetReadComplete(int result) { @@ -424,7 +386,7 @@ int SOCKS5ClientSocket::DoHandshakeWrite() { memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], handshake_buf_len); return transport_->socket()->Write(handshake_buf_, handshake_buf_len, - &io_callback_); + io_callback_); } int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) { @@ -458,7 +420,7 @@ int SOCKS5ClientSocket::DoHandshakeRead() { int handshake_buf_len = read_header_size - bytes_received_; handshake_buf_ = new IOBuffer(handshake_buf_len); return transport_->socket()->Read(handshake_buf_, handshake_buf_len, - &io_callback_); + io_callback_); } int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) { diff --git a/net/socket/socks5_client_socket.h b/net/socket/socks5_client_socket.h index b83a347..84bb325 100644 --- a/net/socket/socks5_client_socket.h +++ b/net/socket/socks5_client_socket.h @@ -51,7 +51,6 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { // StreamSocket implementation. // Does the SOCKS handshake and completes the protocol. - virtual int Connect(OldCompletionCallback* callback) OVERRIDE; virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; @@ -67,13 +66,10 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) OVERRIDE; - virtual int Read(IOBuffer* buf, - int buf_len, const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) OVERRIDE; + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -125,7 +121,7 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { // and return OK on success. int BuildHandshakeWriteBuffer(std::string* handshake) const; - OldCompletionCallbackImpl<SOCKS5ClientSocket> io_callback_; + CompletionCallback io_callback_; // Stores the underlying socket. scoped_ptr<ClientSocketHandle> transport_; @@ -133,7 +129,6 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { State next_state_; // Stores the callback to the layer above, called on completing Connect(). - OldCompletionCallback* old_user_callback_; CompletionCallback user_callback_; // This IOBuffer is used by the class to read and write diff --git a/net/socket/socks5_client_socket_unittest.cc b/net/socket/socks5_client_socket_unittest.cc index 85b07b8..72f015a 100644 --- a/net/socket/socks5_client_socket_unittest.cc +++ b/net/socket/socks5_client_socket_unittest.cc @@ -47,7 +47,7 @@ class SOCKS5ClientSocketTest : public PlatformTest { scoped_ptr<SOCKS5ClientSocket> user_sock_; AddressList address_list_; StreamSocket* tcp_sock_; - TestOldCompletionCallback callback_; + TestCompletionCallback callback_; scoped_ptr<MockHostResolver> host_resolver_; scoped_ptr<SocketDataProvider> data_; @@ -83,12 +83,12 @@ SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket( const std::string& hostname, int port, NetLog* net_log) { - TestOldCompletionCallback callback; + TestCompletionCallback callback; data_.reset(new StaticSocketDataProvider(reads, reads_count, writes, writes_count)); tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get()); - int rv = tcp_sock_->Connect(&callback); + int rv = tcp_sock_->Connect(callback.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); @@ -131,7 +131,7 @@ TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) { EXPECT_TRUE(tcp_sock_->IsConnected()); EXPECT_FALSE(user_sock_->IsConnected()); - int rv = user_sock_->Connect(&callback_); + int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(user_sock_->IsConnected()); @@ -151,13 +151,13 @@ TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) { scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size())); memcpy(buffer->data(), payload_write.data(), payload_write.size()); - rv = user_sock_->Write(buffer, payload_write.size(), &callback_); + rv = user_sock_->Write(buffer, payload_write.size(), callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback_.WaitForResult(); EXPECT_EQ(static_cast<int>(payload_write.size()), rv); buffer = new IOBuffer(payload_read.size()); - rv = user_sock_->Read(buffer, payload_read.size(), &callback_); + rv = user_sock_->Read(buffer, payload_read.size(), callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback_.WaitForResult(); EXPECT_EQ(static_cast<int>(payload_read.size()), rv); @@ -197,7 +197,7 @@ TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) { data_writes, arraysize(data_writes), hostname, 80, NULL)); - int rv = user_sock_->Connect(&callback_); + int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(OK, rv); EXPECT_TRUE(user_sock_->IsConnected()); @@ -221,8 +221,8 @@ TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) { // Try to connect -- should fail (without having read/written anything to // the transport socket first) because the hostname is too long. - TestOldCompletionCallback callback; - int rv = user_sock_->Connect(&callback); + TestCompletionCallback callback; + int rv = user_sock_->Connect(callback.callback()); EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv); } @@ -254,7 +254,7 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes), hostname, 80, &net_log_)); - int rv = user_sock_->Connect(&callback_); + int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); net::CapturingNetLog::EntryList net_log_entries; @@ -285,7 +285,7 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes), hostname, 80, &net_log_)); - int rv = user_sock_->Connect(&callback_); + int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); net::CapturingNetLog::EntryList net_log_entries; @@ -315,7 +315,7 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes), hostname, 80, &net_log_)); - int rv = user_sock_->Connect(&callback_); + int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); net::CapturingNetLog::EntryList net_log_entries; net_log_.GetEntries(&net_log_entries); @@ -346,7 +346,7 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes), hostname, 80, &net_log_)); - int rv = user_sock_->Connect(&callback_); + int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); net::CapturingNetLog::EntryList net_log_entries; net_log_.GetEntries(&net_log_entries); diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc index 623f202..4c368c1 100644 --- a/net/socket/socks_client_socket.cc +++ b/net/socket/socks_client_socket.cc @@ -58,11 +58,8 @@ COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket, const HostResolver::RequestInfo& req_info, HostResolver* host_resolver) - : ALLOW_THIS_IN_INITIALIZER_LIST( - io_callback_(this, &SOCKSClientSocket::OnIOComplete)), - transport_(transport_socket), + : transport_(transport_socket), next_state_(STATE_NONE), - old_user_callback_(NULL), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), @@ -74,11 +71,8 @@ SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket, SOCKSClientSocket::SOCKSClientSocket(StreamSocket* transport_socket, const HostResolver::RequestInfo& req_info, HostResolver* host_resolver) - : ALLOW_THIS_IN_INITIALIZER_LIST( - io_callback_(this, &SOCKSClientSocket::OnIOComplete)), - transport_(new ClientSocketHandle()), + : transport_(new ClientSocketHandle()), next_state_(STATE_NONE), - old_user_callback_(NULL), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), @@ -92,11 +86,11 @@ SOCKSClientSocket::~SOCKSClientSocket() { Disconnect(); } -int SOCKSClientSocket::Connect(OldCompletionCallback* callback) { +int SOCKSClientSocket::Connect(const CompletionCallback& callback) { DCHECK(transport_.get()); DCHECK(transport_->socket()); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_ && user_callback_.is_null()); + DCHECK(user_callback_.is_null()); // If already connected, then just return OK. if (completed_handshake_) @@ -108,34 +102,12 @@ int SOCKSClientSocket::Connect(OldCompletionCallback* callback) { int rv = DoLoop(OK); if (rv == ERR_IO_PENDING) { - old_user_callback_ = callback; + user_callback_ = callback; } else { net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); } return rv; } -int SOCKSClientSocket::Connect(const net::CompletionCallback& callback) { - DCHECK(transport_.get()); - DCHECK(transport_->socket()); - DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_ && user_callback_.is_null()); - - // If already connected, then just return OK. - if (completed_handshake_) - return OK; - - next_state_ = STATE_RESOLVE_HOST; - - net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT, NULL); - - int rv = DoLoop(OK); - if (rv == ERR_IO_PENDING) - user_callback_ = callback; - else - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); - - return rv; -} void SOCKSClientSocket::Disconnect() { completed_handshake_ = false; @@ -145,7 +117,6 @@ void SOCKSClientSocket::Disconnect() { // Reset other states to make sure they aren't mistakenly used later. // These are the states initialized by Connect(). next_state_ = STATE_NONE; - old_user_callback_ = NULL; user_callback_.Reset(); } @@ -213,18 +184,10 @@ base::TimeDelta SOCKSClientSocket::GetConnectTimeMicros() const { // Read is called by the transport layer above to read. This can only be done // if the SOCKS handshake is complete. int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - DCHECK(completed_handshake_); - DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_ && user_callback_.is_null()); - - return transport_->socket()->Read(buf, buf_len, callback); -} -int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_ && user_callback_.is_null()); + DCHECK(user_callback_.is_null()); return transport_->socket()->Read(buf, buf_len, callback); } @@ -232,10 +195,10 @@ int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, // Write is called by the transport layer. This can only be done if the // SOCKS handshake is complete. int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + const CompletionCallback& callback) { DCHECK(completed_handshake_); DCHECK_EQ(STATE_NONE, next_state_); - DCHECK(!old_user_callback_); + DCHECK(user_callback_.is_null()); return transport_->socket()->Write(buf, buf_len, callback); } @@ -250,21 +213,14 @@ bool SOCKSClientSocket::SetSendBufferSize(int32 size) { void SOCKSClientSocket::DoCallback(int result) { DCHECK_NE(ERR_IO_PENDING, result); - DCHECK(old_user_callback_ || !user_callback_.is_null()); + DCHECK(!user_callback_.is_null()); // Since Run() may result in Read being called, // clear user_callback_ up front. - if (old_user_callback_) { - OldCompletionCallback* c = old_user_callback_; - old_user_callback_ = NULL; - DVLOG(1) << "Finished setting up SOCKS handshake"; - c->Run(result); - } else { - CompletionCallback c = user_callback_; - user_callback_.Reset(); - DVLOG(1) << "Finished setting up SOCKS handshake"; - c.Run(result); - } + CompletionCallback c = user_callback_; + user_callback_.Reset(); + DVLOG(1) << "Finished setting up SOCKS handshake"; + c.Run(result); } void SOCKSClientSocket::OnIOComplete(int result) { @@ -379,8 +335,9 @@ int SOCKSClientSocket::DoHandshakeWrite() { handshake_buf_ = new IOBuffer(handshake_buf_len); memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], handshake_buf_len); - return transport_->socket()->Write(handshake_buf_, handshake_buf_len, - &io_callback_); + return transport_->socket()->Write( + handshake_buf_, handshake_buf_len, + base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this))); } int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { @@ -413,7 +370,8 @@ int SOCKSClientSocket::DoHandshakeRead() { int handshake_buf_len = kReadHeaderSize - bytes_received_; handshake_buf_ = new IOBuffer(handshake_buf_len); return transport_->socket()->Read(handshake_buf_, handshake_buf_len, - &io_callback_); + base::Bind(&SOCKSClientSocket::OnIOComplete, + base::Unretained(this))); } int SOCKSClientSocket::DoHandshakeReadComplete(int result) { diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h index 1a4a75c..fb88cd2 100644 --- a/net/socket/socks_client_socket.h +++ b/net/socket/socks_client_socket.h @@ -48,8 +48,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { // StreamSocket implementation. // Does the SOCKS handshake and completes the protocol. - virtual int Connect(OldCompletionCallback* callback) OVERRIDE; - virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; + virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; virtual bool IsConnectedAndIdle() const OVERRIDE; @@ -64,13 +63,10 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) OVERRIDE; - virtual int Read(IOBuffer* buf, - int buf_len, const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) OVERRIDE; + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -106,15 +102,12 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { const std::string BuildHandshakeWriteBuffer() const; - OldCompletionCallbackImpl<SOCKSClientSocket> io_callback_; - // Stores the underlying socket. scoped_ptr<ClientSocketHandle> transport_; State next_state_; // Stores the callback to the layer above, called on completing Connect(). - OldCompletionCallback* old_user_callback_; CompletionCallback user_callback_; // This IOBuffer is used by the class to read and write diff --git a/net/socket/socks_client_socket_pool.cc b/net/socket/socks_client_socket_pool.cc index 5fa52bf..2602a3e 100644 --- a/net/socket/socks_client_socket_pool.cc +++ b/net/socket/socks_client_socket_pool.cc @@ -51,7 +51,7 @@ SOCKSConnectJob::SOCKSConnectJob( transport_pool_(transport_pool), resolver_(host_resolver), ALLOW_THIS_IN_INITIALIZER_LIST( - callback_(this, &SOCKSConnectJob::OnIOComplete)) { + callback_old_(this, &SOCKSConnectJob::OnIOComplete)) { } SOCKSConnectJob::~SOCKSConnectJob() { @@ -118,7 +118,7 @@ int SOCKSConnectJob::DoTransportConnect() { return transport_socket_handle_->Init(group_name(), socks_params_->transport_params(), socks_params_->destination().priority(), - &callback_, + &callback_old_, transport_pool_, net_log()); } @@ -147,7 +147,8 @@ int SOCKSConnectJob::DoSOCKSConnect() { socks_params_->destination(), resolver_)); } - return socket_->Connect(&callback_); + return socket_->Connect( + base::Bind(&SOCKSConnectJob::OnIOComplete, base::Unretained(this))); } int SOCKSConnectJob::DoSOCKSConnectComplete(int result) { diff --git a/net/socket/socks_client_socket_pool.h b/net/socket/socks_client_socket_pool.h index 501f3bf..422b772 100644 --- a/net/socket/socks_client_socket_pool.h +++ b/net/socket/socks_client_socket_pool.h @@ -98,7 +98,7 @@ class SOCKSConnectJob : public ConnectJob { HostResolver* const resolver_; State next_state_; - OldCompletionCallbackImpl<SOCKSConnectJob> callback_; + OldCompletionCallbackImpl<SOCKSConnectJob> callback_old_; scoped_ptr<ClientSocketHandle> transport_socket_handle_; scoped_ptr<StreamSocket> socket_; diff --git a/net/socket/socks_client_socket_unittest.cc b/net/socket/socks_client_socket_unittest.cc index 6762e1bb..1a6db28 100644 --- a/net/socket/socks_client_socket_unittest.cc +++ b/net/socket/socks_client_socket_unittest.cc @@ -38,7 +38,7 @@ class SOCKSClientSocketTest : public PlatformTest { scoped_ptr<SOCKSClientSocket> user_sock_; AddressList address_list_; StreamSocket* tcp_sock_; - TestOldCompletionCallback callback_; + TestCompletionCallback callback_; scoped_ptr<MockHostResolver> host_resolver_; scoped_ptr<SocketDataProvider> data_; }; @@ -62,12 +62,12 @@ SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket( int port, NetLog* net_log) { - TestOldCompletionCallback callback; + TestCompletionCallback callback; data_.reset(new StaticSocketDataProvider(reads, reads_count, writes, writes_count)); tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get()); - int rv = tcp_sock_->Connect(&callback); + int rv = tcp_sock_->Connect(callback.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); @@ -144,7 +144,7 @@ TEST_F(SOCKSClientSocketTest, CompleteHandshake) { EXPECT_TRUE(tcp_sock_->IsConnected()); EXPECT_FALSE(user_sock_->IsConnected()); - int rv = user_sock_->Connect(&callback_); + int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); net::CapturingNetLog::EntryList entries; @@ -162,13 +162,13 @@ TEST_F(SOCKSClientSocketTest, CompleteHandshake) { scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size())); memcpy(buffer->data(), payload_write.data(), payload_write.size()); - rv = user_sock_->Write(buffer, payload_write.size(), &callback_); + rv = user_sock_->Write(buffer, payload_write.size(), callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback_.WaitForResult(); EXPECT_EQ(static_cast<int>(payload_write.size()), rv); buffer = new IOBuffer(payload_read.size()); - rv = user_sock_->Read(buffer, payload_read.size(), &callback_); + rv = user_sock_->Read(buffer, payload_read.size(), callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback_.WaitForResult(); EXPECT_EQ(static_cast<int>(payload_read.size()), rv); @@ -213,7 +213,7 @@ TEST_F(SOCKSClientSocketTest, HandshakeFailures) { "localhost", 80, &log)); - int rv = user_sock_->Connect(&callback_); + int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); net::CapturingNetLog::EntryList entries; @@ -250,7 +250,7 @@ TEST_F(SOCKSClientSocketTest, PartialServerReads) { "localhost", 80, &log)); - int rv = user_sock_->Connect(&callback_); + int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); net::CapturingNetLog::EntryList entries; log.GetEntries(&entries); @@ -288,7 +288,7 @@ TEST_F(SOCKSClientSocketTest, PartialClientWrites) { "localhost", 80, &log)); - int rv = user_sock_->Connect(&callback_); + int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); net::CapturingNetLog::EntryList entries; log.GetEntries(&entries); @@ -320,7 +320,7 @@ TEST_F(SOCKSClientSocketTest, FailedSocketRead) { "localhost", 80, &log)); - int rv = user_sock_->Connect(&callback_); + int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); net::CapturingNetLog::EntryList entries; log.GetEntries(&entries); @@ -350,7 +350,7 @@ TEST_F(SOCKSClientSocketTest, FailedDNS) { hostname, 80, &log)); - int rv = user_sock_->Connect(&callback_); + int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); net::CapturingNetLog::EntryList entries; log.GetEntries(&entries); @@ -382,7 +382,7 @@ TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) { NULL)); // Start connecting (will get stuck waiting for the host to resolve). - int rv = user_sock_->Connect(&callback_); + int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(user_sock_->IsConnected()); diff --git a/net/socket/ssl_client_socket_mac.cc b/net/socket/ssl_client_socket_mac.cc index f58d340..b6e03f7 100644 --- a/net/socket/ssl_client_socket_mac.cc +++ b/net/socket/ssl_client_socket_mac.cc @@ -524,16 +524,9 @@ SSLClientSocketMac::SSLClientSocketMac(ClientSocketHandle* transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) - : transport_read_callback_(this, - &SSLClientSocketMac::OnTransportReadComplete), - transport_write_callback_(this, - &SSLClientSocketMac::OnTransportWriteComplete), - transport_(transport_socket), + : transport_(transport_socket), host_and_port_(host_and_port), ssl_config_(ssl_config), - old_user_connect_callback_(NULL), - old_user_read_callback_(NULL), - user_write_callback_(NULL), user_read_buf_len_(0), user_write_buf_len_(0), next_handshake_state_(STATE_NONE), @@ -555,32 +548,10 @@ SSLClientSocketMac::~SSLClientSocketMac() { Disconnect(); } -int SSLClientSocketMac::Connect(OldCompletionCallback* callback) { - DCHECK(transport_.get()); - DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); - - net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT, NULL); - - int rv = InitializeSSLContext(); - if (rv != OK) { - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); - return rv; - } - - next_handshake_state_ = STATE_HANDSHAKE; - rv = DoHandshakeLoop(OK); - if (rv == ERR_IO_PENDING) { - old_user_connect_callback_ = callback; - } else { - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); - } - return rv; -} int SSLClientSocketMac::Connect(const CompletionCallback& callback) { DCHECK(transport_.get()); DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); + DCHECK(user_connect_callback_.is_null()); net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT, NULL); @@ -698,27 +669,9 @@ base::TimeDelta SSLClientSocketMac::GetConnectTimeMicros() const { } int SSLClientSocketMac::Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - DCHECK(completed_handshake()); - DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); - DCHECK(!user_read_buf_); - - user_read_buf_ = buf; - user_read_buf_len_ = buf_len; - - int rv = DoPayloadRead(); - if (rv == ERR_IO_PENDING) { - old_user_read_callback_ = callback; - } else { - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - } - return rv; -} -int SSLClientSocketMac::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { DCHECK(completed_handshake()); - DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); + DCHECK(user_read_callback_.is_null()); DCHECK(!user_read_buf_); user_read_buf_ = buf; @@ -735,9 +688,9 @@ int SSLClientSocketMac::Read(IOBuffer* buf, int buf_len, } int SSLClientSocketMac::Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + const CompletionCallback& callback) { DCHECK(completed_handshake()); - DCHECK(!user_write_callback_); + DCHECK(user_write_callback_.is_null()); DCHECK(!user_write_buf_); user_write_buf_ = buf; @@ -936,51 +889,37 @@ int SSLClientSocketMac::InitializeSSLContext() { void SSLClientSocketMac::DoConnectCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); - DCHECK(old_user_connect_callback_ || !user_connect_callback_.is_null()); + DCHECK(!user_connect_callback_.is_null()); - if (old_user_connect_callback_) { - OldCompletionCallback* c = old_user_connect_callback_; - old_user_connect_callback_ = NULL; - c->Run(rv > OK ? OK : rv); - } else { - CompletionCallback c = user_connect_callback_; - user_connect_callback_.Reset(); - c.Run(rv > OK ? OK : rv); - } + CompletionCallback c = user_connect_callback_; + user_connect_callback_.Reset(); + c.Run(rv > OK ? OK : rv); } void SSLClientSocketMac::DoReadCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); - DCHECK(old_user_read_callback_ || !user_read_callback_.is_null()); + DCHECK(!user_read_callback_.is_null()); // Since Run may result in Read being called, clear user_read_callback_ up // front. - if (old_user_read_callback_) { - OldCompletionCallback* c = old_user_read_callback_; - old_user_read_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c->Run(rv); - } else { - CompletionCallback c = user_read_callback_; - user_read_callback_.Reset(); - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c.Run(rv); - } + CompletionCallback c = user_read_callback_; + user_read_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c.Run(rv); } void SSLClientSocketMac::DoWriteCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_write_callback_); + DCHECK(!user_write_callback_.is_null()); // Since Run may result in Write being called, clear user_write_callback_ up // front. - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); user_write_buf_ = NULL; user_write_buf_len_ = 0; - c->Run(rv); + c.Run(rv); } void SSLClientSocketMac::OnHandshakeIOComplete(int result) { @@ -990,7 +929,7 @@ void SSLClientSocketMac::OnHandshakeIOComplete(int result) { // renegotiating (which occurs because we are in the middle of a Read // when the renegotiation process starts). So we complete the Read // here. - if (!old_user_connect_callback_ && user_connect_callback_.is_null()) { + if (user_connect_callback_.is_null()) { DoReadCallback(rv); return; } @@ -1328,7 +1267,7 @@ int SSLClientSocketMac::DoCompletedRenegotiation(int result) { } void SSLClientSocketMac::DidCompleteRenegotiation() { - DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); + DCHECK(user_connect_callback_.is_null()); renegotiating_ = false; next_handshake_state_ = STATE_COMPLETED_RENEGOTIATION; } @@ -1408,9 +1347,11 @@ OSStatus SSLClientSocketMac::SSLReadCallback(SSLConnectionRef connection, int rv = 1; // any old value to spin the loop below while (rv > 0 && total_read < *data_length) { us->read_io_buf_ = new IOBuffer(*data_length - total_read); - rv = us->transport_->socket()->Read(us->read_io_buf_, - *data_length - total_read, - &us->transport_read_callback_); + rv = us->transport_->socket()->Read( + us->read_io_buf_, + *data_length - total_read, + base::Bind(&SSLClientSocketMac::OnTransportReadComplete, + base::Unretained(us))); if (rv >= 0) { us->recv_buffer_.insert(us->recv_buffer_.end(), @@ -1470,9 +1411,11 @@ OSStatus SSLClientSocketMac::SSLWriteCallback(SSLConnectionRef connection, us->write_io_buf_ = new IOBuffer(us->send_buffer_.size()); memcpy(us->write_io_buf_->data(), &us->send_buffer_[0], us->send_buffer_.size()); - rv = us->transport_->socket()->Write(us->write_io_buf_, - us->send_buffer_.size(), - &us->transport_write_callback_); + rv = us->transport_->socket()->Write( + us->write_io_buf_, + us->send_buffer_.size(), + base::Bind(&SSLClientSocketMac::OnTransportWriteComplete, + base::Unretained(us))); if (rv > 0) { us->send_buffer_.erase(us->send_buffer_.begin(), us->send_buffer_.begin() + rv); diff --git a/net/socket/ssl_client_socket_mac.h b/net/socket/ssl_client_socket_mac.h index 7792cb3..3a10ae3 100644 --- a/net/socket/ssl_client_socket_mac.h +++ b/net/socket/ssl_client_socket_mac.h @@ -52,7 +52,6 @@ class SSLClientSocketMac : public SSLClientSocket { std::string* server_protos) OVERRIDE; // StreamSocket implementation. - virtual int Connect(OldCompletionCallback* callback) OVERRIDE; virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; @@ -70,13 +69,10 @@ class SSLClientSocketMac : public SSLClientSocket { // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) OVERRIDE; - virtual int Read(IOBuffer* buf, - int buf_len, const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) OVERRIDE; + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -115,18 +111,13 @@ class SSLClientSocketMac : public SSLClientSocket { const void* data, size_t* data_length); - OldCompletionCallbackImpl<SSLClientSocketMac> transport_read_callback_; - OldCompletionCallbackImpl<SSLClientSocketMac> transport_write_callback_; - scoped_ptr<ClientSocketHandle> transport_; HostPortPair host_and_port_; SSLConfig ssl_config_; - OldCompletionCallback* old_user_connect_callback_; CompletionCallback user_connect_callback_; - OldCompletionCallback* old_user_read_callback_; CompletionCallback user_read_callback_; - OldCompletionCallback* user_write_callback_; + CompletionCallback user_write_callback_; // Used by Read function. scoped_refptr<IOBuffer> user_read_buf_; diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index 5279cf9..eb8c662 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -432,22 +432,12 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocketHandle* transport_socket, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, const SSLClientSocketContext& context) - : ALLOW_THIS_IN_INITIALIZER_LIST(buffer_send_callback_( - this, &SSLClientSocketNSS::BufferSendComplete)), - ALLOW_THIS_IN_INITIALIZER_LIST(buffer_recv_callback_( - this, &SSLClientSocketNSS::BufferRecvComplete)), - transport_send_busy_(false), + : transport_send_busy_(false), transport_recv_busy_(false), corked_(false), - ALLOW_THIS_IN_INITIALIZER_LIST(handshake_io_callback_( - base::Bind(&SSLClientSocketNSS::OnHandshakeIOComplete, - base::Unretained(this)))), transport_(transport_socket), host_and_port_(host_and_port), ssl_config_(ssl_config), - old_user_connect_callback_(NULL), - old_user_read_callback_(NULL), - user_write_callback_(NULL), user_read_buf_len_(0), user_write_buf_len_(0), server_cert_nss_(NULL), @@ -570,61 +560,13 @@ SSLClientSocketNSS::GetNextProto(std::string* proto, return next_proto_status_; } -int SSLClientSocketNSS::Connect(OldCompletionCallback* callback) { - EnterFunction(""); - DCHECK(transport_.get()); - DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); - DCHECK(!user_write_callback_); - DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); - DCHECK(!user_read_buf_); - DCHECK(!user_write_buf_); - - EnsureThreadIdAssigned(); - - net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT, NULL); - - int rv = Init(); - if (rv != OK) { - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); - return rv; - } - - rv = InitializeSSLOptions(); - if (rv != OK) { - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); - return rv; - } - - rv = InitializeSSLPeerName(); - if (rv != OK) { - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); - return rv; - } - - if (ssl_config_.cached_info_enabled && ssl_host_info_.get()) { - GotoState(STATE_LOAD_SSL_HOST_INFO); - } else { - GotoState(STATE_HANDSHAKE); - } - - rv = DoHandshakeLoop(OK); - if (rv == ERR_IO_PENDING) { - old_user_connect_callback_ = callback; - } else { - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); - } - - LeaveFunction(""); - return rv > OK ? OK : rv; -} int SSLClientSocketNSS::Connect(const CompletionCallback& callback) { EnterFunction(""); DCHECK(transport_.get()); DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); - DCHECK(!user_write_callback_); - DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); + DCHECK(user_read_callback_.is_null()); + DCHECK(user_write_callback_.is_null()); + DCHECK(user_connect_callback_.is_null()); DCHECK(!user_read_buf_); DCHECK(!user_write_buf_); @@ -672,8 +614,7 @@ void SSLClientSocketNSS::Disconnect() { CHECK(CalledOnValidThread()); - // Shut down anything that may call us back (through buffer_send_callback_, - // buffer_recv_callback, or handshake_io_callback_). + // Shut down anything that may call us back. verifier_.reset(); transport_->socket()->Disconnect(); @@ -688,14 +629,12 @@ void SSLClientSocketNSS::Disconnect() { nss_fd_ = NULL; } - // Reset object state - transport_send_busy_ = false; - transport_recv_busy_ = false; - old_user_connect_callback_ = NULL; + // Reset object state. user_connect_callback_.Reset(); - old_user_read_callback_ = NULL; user_read_callback_.Reset(); - user_write_callback_ = NULL; + user_write_callback_.Reset(); + transport_send_busy_ = false; + transport_recv_busy_ = false; user_read_buf_ = NULL; user_read_buf_len_ = 0; user_write_buf_ = NULL; @@ -810,36 +749,12 @@ base::TimeDelta SSLClientSocketNSS::GetConnectTimeMicros() const { } int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - EnterFunction(buf_len); - DCHECK(completed_handshake_); - DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); - DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); - DCHECK(!user_read_buf_); - DCHECK(nss_bufs_); - - user_read_buf_ = buf; - user_read_buf_len_ = buf_len; - - int rv = DoReadLoop(OK); - - if (rv == ERR_IO_PENDING) { - old_user_read_callback_ = callback; - } else { - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - } - LeaveFunction(rv); - return rv; -} -int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { EnterFunction(buf_len); DCHECK(completed_handshake_); DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); - DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); + DCHECK(user_read_callback_.is_null()); + DCHECK(user_connect_callback_.is_null()); DCHECK(!user_read_buf_); DCHECK(nss_bufs_); @@ -859,12 +774,12 @@ int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, } int SSLClientSocketNSS::Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + const CompletionCallback& callback) { EnterFunction(buf_len); DCHECK(completed_handshake_); DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!user_write_callback_); - DCHECK(!old_user_connect_callback_); + DCHECK(user_write_callback_.is_null()); + DCHECK(user_connect_callback_.is_null()); DCHECK(!user_write_buf_); DCHECK(nss_bufs_); @@ -1223,38 +1138,30 @@ void SSLClientSocketNSS::UpdateConnectionStatus() { void SSLClientSocketNSS::DoReadCallback(int rv) { EnterFunction(rv); DCHECK(rv != ERR_IO_PENDING); - DCHECK(old_user_read_callback_ || user_read_callback_.is_null()); + DCHECK(!user_read_callback_.is_null()); - // Since Run may result in Read being called, clear |old_user_read_callback_| + // Since Run may result in Read being called, clear |user_read_callback_| // up front. - if (old_user_read_callback_) { - OldCompletionCallback* c = old_user_read_callback_; - old_user_read_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c->Run(rv); - } else { - CompletionCallback c = user_read_callback_; - user_read_callback_.Reset(); - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c.Run(rv); - } + CompletionCallback c = user_read_callback_; + user_read_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c.Run(rv); LeaveFunction(""); } void SSLClientSocketNSS::DoWriteCallback(int rv) { EnterFunction(rv); DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_write_callback_); + DCHECK(!user_write_callback_.is_null()); // Since Run may result in Write being called, clear |user_write_callback_| // up front. - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); user_write_buf_ = NULL; user_write_buf_len_ = 0; - c->Run(rv); + c.Run(rv); LeaveFunction(""); } @@ -1268,17 +1175,11 @@ void SSLClientSocketNSS::DoWriteCallback(int rv) { void SSLClientSocketNSS::DoConnectCallback(int rv) { EnterFunction(rv); DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(old_user_connect_callback_ || !user_connect_callback_.is_null()); + DCHECK(!user_connect_callback_.is_null()); - if (old_user_connect_callback_) { - OldCompletionCallback* c = old_user_connect_callback_; - old_user_connect_callback_ = NULL; - c->Run(rv > OK ? OK : rv); - } else { - CompletionCallback c = user_connect_callback_; - user_connect_callback_.Reset(); - c.Run(rv > OK ? OK : rv); - } + CompletionCallback c = user_connect_callback_; + user_connect_callback_.Reset(); + c.Run(rv > OK ? OK : rv); LeaveFunction(""); } @@ -1483,7 +1384,9 @@ bool SSLClientSocketNSS::LoadSSLHostInfo() { int SSLClientSocketNSS::DoLoadSSLHostInfo() { EnterFunction(""); - int rv = ssl_host_info_->WaitForDataReady(handshake_io_callback_); + int rv = ssl_host_info_->WaitForDataReady( + base::Bind(&SSLClientSocketNSS::OnHandshakeIOComplete, + base::Unretained(this))); GotoState(STATE_HANDSHAKE); if (rv == OK) { @@ -1786,7 +1689,9 @@ int SSLClientSocketNSS::DoVerifyCert(int result) { UMA_HISTOGRAM_TIMES("Net.SSLVerificationMergedMsSaved", end_time - ssl_host_info_->verification_start_time()); server_cert_verify_result_ = &ssl_host_info_->cert_verify_result(); - return ssl_host_info_->WaitForCertVerification(handshake_io_callback_); + return ssl_host_info_->WaitForCertVerification( + base::Bind(&SSLClientSocketNSS::OnHandshakeIOComplete, + base::Unretained(this))); } else { UMA_HISTOGRAM_ENUMERATION("Net.SSLVerificationMerged", 0 /* false */, 2); } @@ -1889,7 +1794,7 @@ int SSLClientSocketNSS::DoVerifyCertComplete(int result) { completed_handshake_ = true; - if (old_user_read_callback_ || !user_read_callback_.is_null()) { + if (!user_read_callback_.is_null()) { int rv = DoReadLoop(OK); if (rv != ERR_IO_PENDING) DoReadCallback(rv); @@ -2106,8 +2011,10 @@ int SSLClientSocketNSS::BufferSend(void) { scoped_refptr<IOBuffer> send_buffer(new IOBuffer(len)); memcpy(send_buffer->data(), buf1, len1); memcpy(send_buffer->data() + len1, buf2, len2); - rv = transport_->socket()->Write(send_buffer, len, - &buffer_send_callback_); + rv = transport_->socket()->Write( + send_buffer, len, + base::Bind(&SSLClientSocketNSS::BufferSendComplete, + base::Unretained(this))); if (rv == ERR_IO_PENDING) { transport_send_busy_ = true; } else { @@ -2139,7 +2046,10 @@ int SSLClientSocketNSS::BufferRecv(void) { rv = ERR_IO_PENDING; } else { recv_buffer_ = new IOBuffer(nb); - rv = transport_->socket()->Read(recv_buffer_, nb, &buffer_recv_callback_); + rv = transport_->socket()->Read( + recv_buffer_, nb, + base::Bind(&SSLClientSocketNSS::BufferRecvComplete, + base::Unretained(this))); if (rv == ERR_IO_PENDING) { transport_recv_busy_ = true; } else { diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index 5c566e9..366aa7f 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -71,7 +71,6 @@ class SSLClientSocketNSS : public SSLClientSocket { std::string* server_protos) OVERRIDE; // StreamSocket implementation. - virtual int Connect(OldCompletionCallback* callback) OVERRIDE; virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; @@ -89,13 +88,10 @@ class SSLClientSocketNSS : public SSLClientSocket { // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) OVERRIDE; - virtual int Read(IOBuffer* buf, - int buf_len, const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) OVERRIDE; + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -212,8 +208,6 @@ class SSLClientSocketNSS : public SSLClientSocket { void EnsureThreadIdAssigned() const; bool CalledOnValidThread() const; - OldCompletionCallbackImpl<SSLClientSocketNSS> buffer_send_callback_; - OldCompletionCallbackImpl<SSLClientSocketNSS> buffer_recv_callback_; bool transport_send_busy_; bool transport_recv_busy_; // corked_ is true if we are currently suspending writes to the network. This @@ -224,16 +218,13 @@ class SSLClientSocketNSS : public SSLClientSocket { base::OneShotTimer<SSLClientSocketNSS> uncork_timer_; scoped_refptr<IOBuffer> recv_buffer_; - CompletionCallback handshake_io_callback_; scoped_ptr<ClientSocketHandle> transport_; HostPortPair host_and_port_; SSLConfig ssl_config_; - OldCompletionCallback* old_user_connect_callback_; CompletionCallback user_connect_callback_; - OldCompletionCallback* old_user_read_callback_; CompletionCallback user_read_callback_; - OldCompletionCallback* user_write_callback_; + CompletionCallback user_write_callback_; // Used by Read function. scoped_refptr<IOBuffer> user_read_buf_; diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc index 1237348..f933031 100644 --- a/net/socket/ssl_client_socket_openssl.cc +++ b/net/socket/ssl_client_socket_openssl.cc @@ -384,15 +384,8 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) - : ALLOW_THIS_IN_INITIALIZER_LIST(buffer_send_callback_( - this, &SSLClientSocketOpenSSL::BufferSendComplete)), - ALLOW_THIS_IN_INITIALIZER_LIST(buffer_recv_callback_( - this, &SSLClientSocketOpenSSL::BufferRecvComplete)), - transport_send_busy_(false), + : transport_send_busy_(false), transport_recv_busy_(false), - old_user_connect_callback_(NULL), - old_user_read_callback_(NULL), - user_write_callback_(NULL), completed_handshake_(false), client_auth_cert_needed_(false), cert_verifier_(context.cert_verifier), @@ -614,56 +607,24 @@ SSLClientSocket::NextProtoStatus SSLClientSocketOpenSSL::GetNextProto( void SSLClientSocketOpenSSL::DoReadCallback(int rv) { // Since Run may result in Read being called, clear |user_read_callback_| // up front. - if (old_user_read_callback_) { - OldCompletionCallback* c = old_user_read_callback_; - old_user_read_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c->Run(rv); - } else { - CompletionCallback c = user_read_callback_; - user_read_callback_.Reset(); - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c.Run(rv); - } + CompletionCallback c = user_read_callback_; + user_read_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c.Run(rv); } void SSLClientSocketOpenSSL::DoWriteCallback(int rv) { // Since Run may result in Write being called, clear |user_write_callback_| // up front. - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); user_write_buf_ = NULL; user_write_buf_len_ = 0; - c->Run(rv); + c.Run(rv); } -// StreamSocket methods - -int SSLClientSocketOpenSSL::Connect(OldCompletionCallback* callback) { - net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT, NULL); - - // Set up new ssl object. - if (!Init()) { - int result = ERR_UNEXPECTED; - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, result); - return result; - } - - // Set SSL to client mode. Handshake happens in the loop below. - SSL_set_connect_state(ssl_); - - GotoState(STATE_HANDSHAKE); - int rv = DoHandshakeLoop(net::OK); - if (rv == ERR_IO_PENDING) { - old_user_connect_callback_ = callback; - } else { - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); - } - - return rv > OK ? OK : rv; -} +// StreamSocket implementation. int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT, NULL); @@ -708,11 +669,9 @@ void SSLClientSocketOpenSSL::Disconnect() { transport_recv_busy_ = false; recv_buffer_ = NULL; - old_user_connect_callback_ = NULL; user_connect_callback_.Reset(); - old_user_read_callback_ = NULL; user_read_callback_.Reset(); - user_write_callback_ = NULL; + user_write_callback_.Reset(); user_read_buf_ = NULL; user_read_buf_len_ = 0; user_write_buf_ = NULL; @@ -971,9 +930,11 @@ int SSLClientSocketOpenSSL::BufferSend(void) { int rv = 0; while (send_buffer_) { - rv = transport_->socket()->Write(send_buffer_, - send_buffer_->BytesRemaining(), - &buffer_send_callback_); + rv = transport_->socket()->Write( + send_buffer_, + send_buffer_->BytesRemaining(), + base::Bind(&SSLClientSocketOpenSSL::BufferSendComplete, + base::Unretained(this))); if (rv == ERR_IO_PENDING) { transport_send_busy_ = true; return rv; @@ -1017,8 +978,10 @@ int SSLClientSocketOpenSSL::BufferRecv(void) { return ERR_IO_PENDING; recv_buffer_ = new IOBuffer(max_write); - int rv = transport_->socket()->Read(recv_buffer_, max_write, - &buffer_recv_callback_); + int rv = transport_->socket()->Read( + recv_buffer_, max_write, + base::Bind(&SSLClientSocketOpenSSL::BufferRecvComplete, + base::Unretained(this))); if (rv == ERR_IO_PENDING) { transport_recv_busy_ = true; } else { @@ -1052,11 +1015,7 @@ void SSLClientSocketOpenSSL::TransportReadComplete(int result) { } void SSLClientSocketOpenSSL::DoConnectCallback(int rv) { - if (old_user_connect_callback_) { - OldCompletionCallback* c = old_user_connect_callback_; - old_user_connect_callback_ = NULL; - c->Run(rv > OK ? OK : rv); - } else { + if (!user_connect_callback_.is_null()) { CompletionCallback c = user_connect_callback_; user_connect_callback_.Reset(); c.Run(rv > OK ? OK : rv); @@ -1190,23 +1149,6 @@ base::TimeDelta SSLClientSocketOpenSSL::GetConnectTimeMicros() const { int SSLClientSocketOpenSSL::Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - user_read_buf_ = buf; - user_read_buf_len_ = buf_len; - - int rv = DoReadLoop(OK); - - if (rv == ERR_IO_PENDING) { - old_user_read_callback_ = callback; - } else { - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - } - - return rv; -} -int SSLClientSocketOpenSSL::Read(IOBuffer* buf, - int buf_len, const CompletionCallback& callback) { user_read_buf_ = buf; user_read_buf_len_ = buf_len; @@ -1239,7 +1181,7 @@ int SSLClientSocketOpenSSL::DoReadLoop(int result) { int SSLClientSocketOpenSSL::Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + const CompletionCallback& callback) { user_write_buf_ = buf; user_write_buf_len_ = buf_len; diff --git a/net/socket/ssl_client_socket_openssl.h b/net/socket/ssl_client_socket_openssl.h index a15c0e3..e1e1778 100644 --- a/net/socket/ssl_client_socket_openssl.h +++ b/net/socket/ssl_client_socket_openssl.h @@ -63,7 +63,6 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { std::string* server_protos); // StreamSocket implementation. - virtual int Connect(OldCompletionCallback* callback); virtual int Connect(const CompletionCallback& callback); virtual void Disconnect(); virtual bool IsConnected() const; @@ -79,10 +78,10 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { virtual base::TimeDelta GetConnectTimeMicros() const; // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); - virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); @@ -115,18 +114,14 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { void TransportWriteComplete(int result); void TransportReadComplete(int result); - OldCompletionCallbackImpl<SSLClientSocketOpenSSL> buffer_send_callback_; - OldCompletionCallbackImpl<SSLClientSocketOpenSSL> buffer_recv_callback_; bool transport_send_busy_; scoped_refptr<DrainableIOBuffer> send_buffer_; bool transport_recv_busy_; scoped_refptr<IOBuffer> recv_buffer_; - OldCompletionCallback* old_user_connect_callback_; CompletionCallback user_connect_callback_; - OldCompletionCallback* old_user_read_callback_; CompletionCallback user_read_callback_; - OldCompletionCallback* user_write_callback_; + CompletionCallback user_write_callback_; // Used by Read function. scoped_refptr<IOBuffer> user_read_buf_; diff --git a/net/socket/ssl_client_socket_pool.cc b/net/socket/ssl_client_socket_pool.cc index dffe962..2cc1cde 100644 --- a/net/socket/ssl_client_socket_pool.cc +++ b/net/socket/ssl_client_socket_pool.cc @@ -94,7 +94,10 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name, host_resolver_(host_resolver), context_(context), ALLOW_THIS_IN_INITIALIZER_LIST( - callback_(this, &SSLConnectJob::OnIOComplete)) {} + callback_(base::Bind(&SSLConnectJob::OnIOComplete, + base::Unretained(this)))), + ALLOW_THIS_IN_INITIALIZER_LIST( + callback_old_(this, &SSLConnectJob::OnIOComplete)) {} SSLConnectJob::~SSLConnectJob() {} @@ -119,7 +122,7 @@ LoadState SSLConnectJob::GetLoadState() const { } } -void SSLConnectJob::GetAdditionalErrorState(ClientSocketHandle * handle) { +void SSLConnectJob::GetAdditionalErrorState(ClientSocketHandle* handle) { // Headers in |error_response_info_| indicate a proxy tunnel setup // problem. See DoTunnelConnectComplete. if (error_response_info_.headers) { @@ -207,7 +210,7 @@ int SSLConnectJob::DoTransportConnect() { group_name(), transport_params, transport_params->destination().priority(), - &callback_, transport_pool_, net_log()); + &callback_old_, transport_pool_, net_log()); } int SSLConnectJob::DoTransportConnectComplete(int result) { @@ -224,7 +227,7 @@ int SSLConnectJob::DoSOCKSConnect() { scoped_refptr<SOCKSSocketParams> socks_params = params_->socks_params(); return transport_socket_handle_->Init(group_name(), socks_params, socks_params->destination().priority(), - &callback_, socks_pool_, net_log()); + &callback_old_, socks_pool_, net_log()); } int SSLConnectJob::DoSOCKSConnectComplete(int result) { @@ -243,7 +246,7 @@ int SSLConnectJob::DoTunnelConnect() { params_->http_proxy_params(); return transport_socket_handle_->Init( group_name(), http_proxy_params, - http_proxy_params->destination().priority(), &callback_, + http_proxy_params->destination().priority(), &callback_old_, http_proxy_pool_, net_log()); } @@ -276,7 +279,7 @@ int SSLConnectJob::DoSSLConnect() { ssl_socket_.reset(client_socket_factory_->CreateSSLClientSocket( transport_socket_handle_.release(), params_->host_and_port(), params_->ssl_config(), ssl_host_info_.release(), context_)); - return ssl_socket_->Connect(&callback_); + return ssl_socket_->Connect(callback_); } int SSLConnectJob::DoSSLConnectComplete(int result) { diff --git a/net/socket/ssl_client_socket_pool.h b/net/socket/ssl_client_socket_pool.h index 2ca42b5..cec2e25 100644 --- a/net/socket/ssl_client_socket_pool.h +++ b/net/socket/ssl_client_socket_pool.h @@ -152,7 +152,8 @@ class SSLConnectJob : public ConnectJob { const SSLClientSocketContext context_; State next_state_; - OldCompletionCallbackImpl<SSLConnectJob> callback_; + CompletionCallback callback_; + OldCompletionCallbackImpl<SSLConnectJob> callback_old_; scoped_ptr<ClientSocketHandle> transport_socket_handle_; scoped_ptr<SSLClientSocket> ssl_socket_; scoped_ptr<SSLHostInfo> ssl_host_info_; diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc index 88af817..aadffe1 100644 --- a/net/socket/ssl_client_socket_unittest.cc +++ b/net/socket/ssl_client_socket_unittest.cc @@ -73,11 +73,11 @@ TEST_F(SSLClientSocketTest, Connect) { net::AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - TestOldCompletionCallback callback; + net::TestCompletionCallback callback; net::CapturingNetLog log(net::CapturingNetLog::kUnbounded); net::StreamSocket* transport = new net::TCPClientSocket( addr, &log, net::NetLog::Source()); - int rv = transport->Connect(&callback); + int rv = transport->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -91,7 +91,7 @@ TEST_F(SSLClientSocketTest, Connect) { EXPECT_FALSE(sock->IsConnected()); - rv = sock->Connect(&callback); + rv = sock->Connect(callback.callback()); net::CapturingNetLog::EntryList entries; log.GetEntries(&entries); @@ -117,11 +117,11 @@ TEST_F(SSLClientSocketTest, ConnectExpired) { net::AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - TestOldCompletionCallback callback; + net::TestCompletionCallback callback; net::CapturingNetLog log(net::CapturingNetLog::kUnbounded); net::StreamSocket* transport = new net::TCPClientSocket( addr, &log, net::NetLog::Source()); - int rv = transport->Connect(&callback); + int rv = transport->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -132,7 +132,7 @@ TEST_F(SSLClientSocketTest, ConnectExpired) { EXPECT_FALSE(sock->IsConnected()); - rv = sock->Connect(&callback); + rv = sock->Connect(callback.callback()); net::CapturingNetLog::EntryList entries; log.GetEntries(&entries); @@ -160,11 +160,11 @@ TEST_F(SSLClientSocketTest, ConnectMismatched) { net::AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - TestOldCompletionCallback callback; + net::TestCompletionCallback callback; net::CapturingNetLog log(net::CapturingNetLog::kUnbounded); net::StreamSocket* transport = new net::TCPClientSocket( addr, &log, net::NetLog::Source()); - int rv = transport->Connect(&callback); + int rv = transport->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -175,7 +175,7 @@ TEST_F(SSLClientSocketTest, ConnectMismatched) { EXPECT_FALSE(sock->IsConnected()); - rv = sock->Connect(&callback); + rv = sock->Connect(callback.callback()); net::CapturingNetLog::EntryList entries; log.GetEntries(&entries); @@ -205,11 +205,11 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) { net::AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - TestOldCompletionCallback callback; + net::TestCompletionCallback callback; net::CapturingNetLog log(net::CapturingNetLog::kUnbounded); net::StreamSocket* transport = new net::TCPClientSocket( addr, &log, net::NetLog::Source()); - int rv = transport->Connect(&callback); + int rv = transport->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -220,7 +220,7 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) { EXPECT_FALSE(sock->IsConnected()); - rv = sock->Connect(&callback); + rv = sock->Connect(callback.callback()); net::CapturingNetLog::EntryList entries; log.GetEntries(&entries); @@ -265,11 +265,11 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { net::AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - TestOldCompletionCallback callback; + net::TestCompletionCallback callback; net::CapturingNetLog log(net::CapturingNetLog::kUnbounded); net::StreamSocket* transport = new net::TCPClientSocket( addr, &log, net::NetLog::Source()); - int rv = transport->Connect(&callback); + int rv = transport->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -286,7 +286,7 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { // Our test server accepts certificate-less connections. // TODO(davidben): Add a test which requires them and verify the error. - rv = sock->Connect(&callback); + rv = sock->Connect(callback.callback()); net::CapturingNetLog::EntryList entries; log.GetEntries(&entries); @@ -323,10 +323,10 @@ TEST_F(SSLClientSocketTest, Read) { net::AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - TestOldCompletionCallback callback; + net::TestCompletionCallback callback; net::StreamSocket* transport = new net::TCPClientSocket( addr, NULL, net::NetLog::Source()); - int rv = transport->Connect(&callback); + int rv = transport->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -335,7 +335,7 @@ TEST_F(SSLClientSocketTest, Read) { CreateSSLClientSocket(transport, test_server.host_port_pair(), kDefaultSSLConfig)); - rv = sock->Connect(&callback); + rv = sock->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -346,7 +346,8 @@ TEST_F(SSLClientSocketTest, Read) { new net::IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); - rv = sock->Write(request_buffer, arraysize(request_text) - 1, &callback); + rv = sock->Write(request_buffer, arraysize(request_text) - 1, + callback.callback()); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -355,7 +356,7 @@ TEST_F(SSLClientSocketTest, Read) { scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096)); for (;;) { - rv = sock->Read(buf, 4096, &callback); + rv = sock->Read(buf, 4096, callback.callback()); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -376,12 +377,11 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) { net::AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - TestOldCompletionCallback callback; // Used for everything except Write. - TestOldCompletionCallback callback2; // Used for Write only. + net::TestCompletionCallback callback; // Used for everything except Write. net::StreamSocket* transport = new net::TCPClientSocket( addr, NULL, net::NetLog::Source()); - int rv = transport->Connect(&callback); + int rv = transport->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -393,7 +393,7 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) { transport, test_server.host_port_pair(), kDefaultSSLConfig, NULL, context)); - rv = sock->Connect(&callback); + rv = sock->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -401,7 +401,7 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) { // Issue a "hanging" Read first. scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096)); - rv = sock->Read(buf, 4096, &callback); + rv = sock->Read(buf, 4096, callback.callback()); // We haven't written the request, so there should be no response yet. ASSERT_EQ(net::ERR_IO_PENDING, rv); @@ -416,7 +416,8 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) { scoped_refptr<net::IOBuffer> request_buffer( new net::StringIOBuffer(request_text)); - rv = sock->Write(request_buffer, request_text.size(), &callback2); + net::TestCompletionCallback callback2; // Used for Write only. + rv = sock->Write(request_buffer, request_text.size(), callback2.callback()); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -435,10 +436,10 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) { net::AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - TestOldCompletionCallback callback; + net::TestCompletionCallback callback; net::StreamSocket* transport = new net::TCPClientSocket( addr, NULL, net::NetLog::Source()); - int rv = transport->Connect(&callback); + int rv = transport->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -447,7 +448,7 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) { CreateSSLClientSocket(transport, test_server.host_port_pair(), kDefaultSSLConfig)); - rv = sock->Connect(&callback); + rv = sock->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -457,7 +458,8 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) { new net::IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); - rv = sock->Write(request_buffer, arraysize(request_text) - 1, &callback); + rv = sock->Write(request_buffer, arraysize(request_text) - 1, + callback.callback()); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -466,7 +468,7 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) { scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(1)); for (;;) { - rv = sock->Read(buf, 1, &callback); + rv = sock->Read(buf, 1, callback.callback()); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -485,10 +487,10 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) { net::AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - TestOldCompletionCallback callback; + net::TestCompletionCallback callback; net::StreamSocket* transport = new net::TCPClientSocket( addr, NULL, net::NetLog::Source()); - int rv = transport->Connect(&callback); + int rv = transport->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -497,7 +499,7 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) { CreateSSLClientSocket(transport, test_server.host_port_pair(), kDefaultSSLConfig)); - rv = sock->Connect(&callback); + rv = sock->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -507,7 +509,8 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) { new net::IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); - rv = sock->Write(request_buffer, arraysize(request_text) - 1, &callback); + rv = sock->Write(request_buffer, arraysize(request_text) - 1, + callback.callback()); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -516,7 +519,7 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) { // Do a partial read and then exit. This test should not crash! scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(512)); - rv = sock->Read(buf, 512, &callback); + rv = sock->Read(buf, 512, callback.callback()); EXPECT_TRUE(rv > 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -532,12 +535,12 @@ TEST_F(SSLClientSocketTest, Read_FullLogging) { net::AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - TestOldCompletionCallback callback; + net::TestCompletionCallback callback; net::CapturingNetLog log(net::CapturingNetLog::kUnbounded); log.SetLogLevel(net::NetLog::LOG_ALL); net::StreamSocket* transport = new net::TCPClientSocket( addr, &log, net::NetLog::Source()); - int rv = transport->Connect(&callback); + int rv = transport->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -546,7 +549,7 @@ TEST_F(SSLClientSocketTest, Read_FullLogging) { CreateSSLClientSocket(transport, test_server.host_port_pair(), kDefaultSSLConfig)); - rv = sock->Connect(&callback); + rv = sock->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -557,7 +560,8 @@ TEST_F(SSLClientSocketTest, Read_FullLogging) { new net::IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); - rv = sock->Write(request_buffer, arraysize(request_text) - 1, &callback); + rv = sock->Write(request_buffer, arraysize(request_text) - 1, + callback.callback()); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -572,7 +576,7 @@ TEST_F(SSLClientSocketTest, Read_FullLogging) { scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096)); for (;;) { - rv = sock->Read(buf, 4096, &callback); + rv = sock->Read(buf, 4096, callback.callback()); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -595,7 +599,7 @@ TEST_F(SSLClientSocketTest, PrematureApplicationData) { ASSERT_TRUE(test_server.Start()); net::AddressList addr; - TestOldCompletionCallback callback; + net::TestCompletionCallback callback; static const unsigned char application_data[] = { 0x17, 0x03, 0x01, 0x00, 0x4a, 0x02, 0x00, 0x00, 0x46, 0x03, 0x01, 0x4b, @@ -622,7 +626,7 @@ TEST_F(SSLClientSocketTest, PrematureApplicationData) { net::StreamSocket* transport = new net::MockTCPClientSocket(addr, NULL, &data); - int rv = transport->Connect(&callback); + int rv = transport->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -631,7 +635,7 @@ TEST_F(SSLClientSocketTest, PrematureApplicationData) { CreateSSLClientSocket(transport, test_server.host_port_pair(), kDefaultSSLConfig)); - rv = sock->Connect(&callback); + rv = sock->Connect(callback.callback()); EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv); } @@ -656,11 +660,11 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) { net::AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - TestOldCompletionCallback callback; + net::TestCompletionCallback callback; net::CapturingNetLog log(net::CapturingNetLog::kUnbounded); net::StreamSocket* transport = new net::TCPClientSocket( addr, &log, net::NetLog::Source()); - int rv = transport->Connect(&callback); + int rv = transport->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -675,7 +679,7 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) { EXPECT_FALSE(sock->IsConnected()); - rv = sock->Connect(&callback); + rv = sock->Connect(callback.callback()); net::CapturingNetLog::EntryList entries; log.GetEntries(&entries); EXPECT_TRUE(net::LogContainsBeginEvent( @@ -725,10 +729,10 @@ TEST_F(SSLClientSocketTest, ClientSocketHandleNotFromPool) { net::AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - TestOldCompletionCallback callback; + net::TestCompletionCallback callback; net::StreamSocket* transport = new net::TCPClientSocket( addr, NULL, net::NetLog::Source()); - int rv = transport->Connect(&callback); + int rv = transport->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); @@ -744,7 +748,7 @@ TEST_F(SSLClientSocketTest, ClientSocketHandleNotFromPool) { NULL, context)); EXPECT_FALSE(ssl_socket->IsConnected()); - rv = ssl_socket->Connect(&callback); + rv = ssl_socket->Connect(callback.callback()); if (rv == net::ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); diff --git a/net/socket/ssl_client_socket_win.cc b/net/socket/ssl_client_socket_win.cc index 30f599d..e2509f2 100644 --- a/net/socket/ssl_client_socket_win.cc +++ b/net/socket/ssl_client_socket_win.cc @@ -387,20 +387,10 @@ SSLClientSocketWin::SSLClientSocketWin(ClientSocketHandle* transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) - : ALLOW_THIS_IN_INITIALIZER_LIST( - handshake_io_callback_(this, - &SSLClientSocketWin::OnHandshakeIOComplete)), - ALLOW_THIS_IN_INITIALIZER_LIST( - read_callback_(this, &SSLClientSocketWin::OnReadComplete)), - ALLOW_THIS_IN_INITIALIZER_LIST( - write_callback_(this, &SSLClientSocketWin::OnWriteComplete)), - transport_(transport_socket), + : transport_(transport_socket), host_and_port_(host_and_port), ssl_config_(ssl_config), - old_user_connect_callback_(NULL), - old_user_read_callback_(NULL), user_read_buf_len_(0), - user_write_callback_(NULL), user_write_buf_len_(0), next_state_(STATE_NONE), cert_verifier_(context.cert_verifier), @@ -562,33 +552,10 @@ SSLClientSocketWin::GetNextProto(std::string* proto, return kNextProtoUnsupported; } -int SSLClientSocketWin::Connect(OldCompletionCallback* callback) { - DCHECK(transport_.get()); - DCHECK(next_state_ == STATE_NONE); - DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); - - net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT, NULL); - - int rv = InitializeSSLContext(); - if (rv != OK) { - net_log_.EndEvent(NetLog::TYPE_SSL_CONNECT, NULL); - return rv; - } - - writing_first_token_ = true; - next_state_ = STATE_HANDSHAKE_WRITE; - rv = DoLoop(OK); - if (rv == ERR_IO_PENDING) { - old_user_connect_callback_ = callback; - } else { - net_log_.EndEvent(NetLog::TYPE_SSL_CONNECT, NULL); - } - return rv; -} int SSLClientSocketWin::Connect(const CompletionCallback& callback) { DCHECK(transport_.get()); DCHECK(next_state_ == STATE_NONE); - DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); + DCHECK(user_connect_callback_.is_null()); net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT, NULL); @@ -784,50 +751,9 @@ base::TimeDelta SSLClientSocketWin::GetConnectTimeMicros() const { } int SSLClientSocketWin::Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - DCHECK(completed_handshake()); - DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); - - // If we have surplus decrypted plaintext, satisfy the Read with it without - // reading more ciphertext from the transport socket. - if (bytes_decrypted_ != 0) { - int len = std::min(buf_len, bytes_decrypted_); - net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, len, - decrypted_ptr_); - memcpy(buf->data(), decrypted_ptr_, len); - decrypted_ptr_ += len; - bytes_decrypted_ -= len; - if (bytes_decrypted_ == 0) { - decrypted_ptr_ = NULL; - if (bytes_received_ != 0) { - memmove(recv_buffer_.get(), received_ptr_, bytes_received_); - received_ptr_ = recv_buffer_.get(); - } - } - return len; - } - - DCHECK(!user_read_buf_); - // http://crbug.com/16371: We're seeing |buf->data()| return NULL. See if the - // user is passing in an IOBuffer with a NULL |data_|. - CHECK(buf); - CHECK(buf->data()); - user_read_buf_ = buf; - user_read_buf_len_ = buf_len; - - int rv = DoPayloadRead(); - if (rv == ERR_IO_PENDING) { - old_user_read_callback_ = callback; - } else { - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - } - return rv; -} -int SSLClientSocketWin::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { DCHECK(completed_handshake()); - DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); + DCHECK(user_read_callback_.is_null()); // If we have surplus decrypted plaintext, satisfy the Read with it without // reading more ciphertext from the transport socket. @@ -867,9 +793,9 @@ int SSLClientSocketWin::Read(IOBuffer* buf, int buf_len, } int SSLClientSocketWin::Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + const CompletionCallback& callback) { DCHECK(completed_handshake()); - DCHECK(!user_write_callback_); + DCHECK(user_write_callback_.is_null()); DCHECK(!user_write_buf_); user_write_buf_ = buf; @@ -906,32 +832,18 @@ void SSLClientSocketWin::OnHandshakeIOComplete(int result) { // If there is no connect callback available to call, we are renegotiating // (which occurs because we are in the middle of a Read when the // renegotiation process starts). So we complete the Read here. - if (!old_user_connect_callback_ && user_connect_callback_.is_null()) { - if (old_user_read_callback_) { - OldCompletionCallback* c = old_user_read_callback_; - old_user_read_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c->Run(rv); - } else { - CompletionCallback c = user_read_callback_; - user_read_callback_.Reset(); - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c.Run(rv); - } + if (user_connect_callback_.is_null()) { + CompletionCallback c = user_read_callback_; + user_read_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c.Run(rv); return; } net_log_.EndEvent(NetLog::TYPE_SSL_CONNECT, NULL); - if (old_user_connect_callback_) { - OldCompletionCallback* c = old_user_connect_callback_; - old_user_connect_callback_ = NULL; - c->Run(rv); - } else { - CompletionCallback c = user_connect_callback_; - user_connect_callback_.Reset(); - c.Run(rv); - } + CompletionCallback c = user_connect_callback_; + user_connect_callback_.Reset(); + c.Run(rv); } } @@ -942,20 +854,12 @@ void SSLClientSocketWin::OnReadComplete(int result) { if (result > 0) result = DoPayloadDecrypt(); if (result != ERR_IO_PENDING) { - DCHECK(old_user_read_callback_ || !user_read_callback_.is_null()); - if (old_user_read_callback_) { - OldCompletionCallback* c = old_user_read_callback_; - old_user_read_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c->Run(result); - } else { - CompletionCallback c = user_read_callback_; - user_read_callback_.Reset(); - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c.Run(result); - } + DCHECK(!user_read_callback_.is_null()); + CompletionCallback c = user_read_callback_; + user_read_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c.Run(result); } } @@ -964,12 +868,12 @@ void SSLClientSocketWin::OnWriteComplete(int result) { int rv = DoPayloadWriteComplete(result); if (rv != ERR_IO_PENDING) { - DCHECK(user_write_callback_); - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; + DCHECK(!user_write_callback_.is_null()); + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); user_write_buf_ = NULL; user_write_buf_len_ = 0; - c->Run(rv); + c.Run(rv); } } @@ -1031,8 +935,10 @@ int SSLClientSocketWin::DoHandshakeRead() { DCHECK(!transport_read_buf_); transport_read_buf_ = new IOBuffer(buf_len); - return transport_->socket()->Read(transport_read_buf_, buf_len, - &handshake_io_callback_); + return transport_->socket()->Read( + transport_read_buf_, buf_len, + base::Bind(&SSLClientSocketWin::OnHandshakeIOComplete, + base::Unretained(this))); } int SSLClientSocketWin::DoHandshakeReadComplete(int result) { @@ -1209,8 +1115,10 @@ int SSLClientSocketWin::DoHandshakeWrite() { transport_write_buf_ = new IOBuffer(buf_len); memcpy(transport_write_buf_->data(), buf, buf_len); - return transport_->socket()->Write(transport_write_buf_, buf_len, - &handshake_io_callback_); + return transport_->socket()->Write( + transport_write_buf_, buf_len, + base::Bind(&SSLClientSocketWin::OnHandshakeIOComplete, + base::Unretained(this))); } int SSLClientSocketWin::DoHandshakeWriteComplete(int result) { @@ -1308,8 +1216,10 @@ int SSLClientSocketWin::DoPayloadRead() { DCHECK(!transport_read_buf_); transport_read_buf_ = new IOBuffer(buf_len); - rv = transport_->socket()->Read(transport_read_buf_, buf_len, - &read_callback_); + rv = transport_->socket()->Read( + transport_read_buf_, buf_len, + base::Bind(&SSLClientSocketWin::OnReadComplete, + base::Unretained(this))); if (rv != ERR_IO_PENDING) rv = DoPayloadReadComplete(rv); if (rv <= 0) @@ -1551,8 +1461,10 @@ int SSLClientSocketWin::DoPayloadWrite() { transport_write_buf_ = new IOBuffer(buf_len); memcpy(transport_write_buf_->data(), buf, buf_len); - int rv = transport_->socket()->Write(transport_write_buf_, buf_len, - &write_callback_); + int rv = transport_->socket()->Write( + transport_write_buf_, buf_len, + base::Bind(&SSLClientSocketWin::OnWriteComplete, + base::Unretained(this))); if (rv != ERR_IO_PENDING) rv = DoPayloadWriteComplete(rv); return rv; @@ -1635,8 +1547,8 @@ int SSLClientSocketWin::DidCompleteHandshake() { // Called when a renegotiation is completed. |result| is the verification // result of the server certificate received during renegotiation. void SSLClientSocketWin::DidCompleteRenegotiation() { - DCHECK(!old_user_connect_callback_ && user_connect_callback_.is_null()); - DCHECK(old_user_read_callback_ || !user_read_callback_.is_null()); + DCHECK(user_connect_callback_.is_null()); + DCHECK(!user_read_callback_.is_null()); renegotiating_ = false; next_state_ = STATE_COMPLETED_RENEGOTIATION; } diff --git a/net/socket/ssl_client_socket_win.h b/net/socket/ssl_client_socket_win.h index 27ce300..c59decf 100644 --- a/net/socket/ssl_client_socket_win.h +++ b/net/socket/ssl_client_socket_win.h @@ -56,29 +56,28 @@ class SSLClientSocketWin : public SSLClientSocket { std::string* server_protos); // StreamSocket implementation. - virtual int Connect(OldCompletionCallback* callback); - virtual int Connect(const CompletionCallback& callback); - virtual void Disconnect(); - virtual bool IsConnected() const; - virtual bool IsConnectedAndIdle() const; - virtual int GetPeerAddress(AddressList* address) const; - virtual int GetLocalAddress(IPEndPoint* address) const; - virtual const BoundNetLog& NetLog() const { return net_log_; } - virtual void SetSubresourceSpeculation(); - virtual void SetOmniboxSpeculation(); - virtual bool WasEverUsed() const; - virtual bool UsingTCPFastOpen() const; - virtual int64 NumBytesRead() const; - virtual base::TimeDelta GetConnectTimeMicros() const; + virtual int Connect(const CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual int GetPeerAddress(AddressList* address) const OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual const BoundNetLog& NetLog() const OVERRIDE{ return net_log_; } + virtual void SetSubresourceSpeculation() OVERRIDE; + virtual void SetOmniboxSpeculation() OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual int64 NumBytesRead() const OVERRIDE; + virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback); - virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; - virtual bool SetReceiveBufferSize(int32 size); - virtual bool SetSendBufferSize(int32 size); + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; private: bool completed_handshake() const { @@ -114,27 +113,20 @@ class SSLClientSocketWin : public SSLClientSocket { void LogConnectionTypeMetrics() const; void FreeSendBuffer(); - // Internal callbacks as async operations complete. - OldCompletionCallbackImpl<SSLClientSocketWin> handshake_io_callback_; - OldCompletionCallbackImpl<SSLClientSocketWin> read_callback_; - OldCompletionCallbackImpl<SSLClientSocketWin> write_callback_; - scoped_ptr<ClientSocketHandle> transport_; HostPortPair host_and_port_; SSLConfig ssl_config_; // User function to callback when the Connect() completes. - OldCompletionCallback* old_user_connect_callback_; CompletionCallback user_connect_callback_; // User function to callback when a Read() completes. - OldCompletionCallback* old_user_read_callback_; CompletionCallback user_read_callback_; scoped_refptr<IOBuffer> user_read_buf_; int user_read_buf_len_; // User function to callback when a Write() completes. - OldCompletionCallback* user_write_callback_; + CompletionCallback user_write_callback_; scoped_refptr<IOBuffer> user_write_buf_; int user_write_buf_len_; diff --git a/net/socket/ssl_server_socket_nss.cc b/net/socket/ssl_server_socket_nss.cc index 0785dd7d..5b57492 100644 --- a/net/socket/ssl_server_socket_nss.cc +++ b/net/socket/ssl_server_socket_nss.cc @@ -58,15 +58,8 @@ SSLServerSocketNSS::SSLServerSocketNSS( scoped_refptr<X509Certificate> cert, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config) - : ALLOW_THIS_IN_INITIALIZER_LIST(buffer_send_callback_( - this, &SSLServerSocketNSS::BufferSendComplete)), - ALLOW_THIS_IN_INITIALIZER_LIST(buffer_recv_callback_( - this, &SSLServerSocketNSS::BufferRecvComplete)), - transport_send_busy_(false), + : transport_send_busy_(false), transport_recv_busy_(false), - user_handshake_callback_(NULL), - old_user_read_callback_(NULL), - user_write_callback_(NULL), nss_fd_(NULL), nss_bufs_(NULL), transport_socket_(transport_socket), @@ -143,40 +136,14 @@ int SSLServerSocketNSS::ExportKeyingMaterial(const base::StringPiece& label, return OK; } -int SSLServerSocketNSS::Connect(OldCompletionCallback* callback) { - NOTIMPLEMENTED(); - return ERR_NOT_IMPLEMENTED; -} int SSLServerSocketNSS::Connect(const CompletionCallback& callback) { NOTIMPLEMENTED(); return ERR_NOT_IMPLEMENTED; } int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); - DCHECK(!user_handshake_callback_); - DCHECK(!user_read_buf_); - DCHECK(nss_bufs_); - - user_read_buf_ = buf; - user_read_buf_len_ = buf_len; - - DCHECK(completed_handshake_); - - int rv = DoReadLoop(OK); - - if (rv == ERR_IO_PENDING) { - old_user_read_callback_ = callback; - } else { - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - } - return rv; -} -int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { - DCHECK(!old_user_read_callback_ && user_read_callback_.is_null()); + DCHECK(user_read_callback_.is_null()); DCHECK(!user_handshake_callback_); DCHECK(!user_read_buf_); DCHECK(nss_bufs_); @@ -198,8 +165,8 @@ int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len, } int SSLServerSocketNSS::Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - DCHECK(!user_write_callback_); + const CompletionCallback& callback) { + DCHECK(user_write_callback_.is_null()); DCHECK(!user_write_buf_); DCHECK(nss_bufs_); @@ -521,8 +488,10 @@ int SSLServerSocketNSS::BufferSend(void) { scoped_refptr<IOBuffer> send_buffer(new IOBuffer(len)); memcpy(send_buffer->data(), buf1, len1); memcpy(send_buffer->data() + len1, buf2, len2); - rv = transport_socket_->Write(send_buffer, len, - &buffer_send_callback_); + rv = transport_socket_->Write( + send_buffer, len, + base::Bind(&SSLServerSocketNSS::BufferSendComplete, + base::Unretained(this))); if (rv == ERR_IO_PENDING) { transport_send_busy_ = true; } else { @@ -550,7 +519,10 @@ int SSLServerSocketNSS::BufferRecv(void) { rv = ERR_IO_PENDING; } else { recv_buffer_ = new IOBuffer(nb); - rv = transport_socket_->Read(recv_buffer_, nb, &buffer_recv_callback_); + rv = transport_socket_->Read( + recv_buffer_, nb, + base::Bind(&SSLServerSocketNSS::BufferRecvComplete, + base::Unretained(this))); if (rv == ERR_IO_PENDING) { transport_recv_busy_ = true; } else { @@ -739,36 +711,28 @@ void SSLServerSocketNSS::DoHandshakeCallback(int rv) { void SSLServerSocketNSS::DoReadCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); - DCHECK(old_user_read_callback_ || !user_read_callback_.is_null()); + DCHECK(!user_read_callback_.is_null()); // Since Run may result in Read being called, clear |user_read_callback_| // up front. - if (old_user_read_callback_) { - OldCompletionCallback* c = old_user_read_callback_; - old_user_read_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c->Run(rv); - } else { - CompletionCallback c = user_read_callback_; - user_read_callback_.Reset(); - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c.Run(rv); - } + CompletionCallback c = user_read_callback_; + user_read_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c.Run(rv); } void SSLServerSocketNSS::DoWriteCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_write_callback_); + DCHECK(!user_write_callback_.is_null()); // Since Run may result in Write being called, clear |user_write_callback_| // up front. - OldCompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); user_write_buf_ = NULL; user_write_buf_len_ = 0; - c->Run(rv); + c.Run(rv); } // static diff --git a/net/socket/ssl_server_socket_nss.h b/net/socket/ssl_server_socket_nss.h index 39283f6..1d716a1 100644 --- a/net/socket/ssl_server_socket_nss.h +++ b/net/socket/ssl_server_socket_nss.h @@ -40,16 +40,13 @@ class SSLServerSocketNSS : public SSLServerSocket { // Socket interface (via StreamSocket). virtual int Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) OVERRIDE; - virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) OVERRIDE; + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; // StreamSocket implementation. - virtual int Connect(OldCompletionCallback* callback) OVERRIDE; virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; @@ -101,8 +98,6 @@ class SSLServerSocketNSS : public SSLServerSocket { virtual int Init(); // Members used to send and receive buffer. - OldCompletionCallbackImpl<SSLServerSocketNSS> buffer_send_callback_; - OldCompletionCallbackImpl<SSLServerSocketNSS> buffer_recv_callback_; bool transport_send_busy_; bool transport_recv_busy_; @@ -111,9 +106,8 @@ class SSLServerSocketNSS : public SSLServerSocket { BoundNetLog net_log_; OldCompletionCallback* user_handshake_callback_; - OldCompletionCallback* old_user_read_callback_; CompletionCallback user_read_callback_; - OldCompletionCallback* user_write_callback_; + CompletionCallback user_write_callback_; // Used by Read function. scoped_refptr<IOBuffer> user_read_buf_; diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc index eb9dc7c..f5034c2 100644 --- a/net/socket/ssl_server_socket_unittest.cc +++ b/net/socket/ssl_server_socket_unittest.cc @@ -51,22 +51,11 @@ namespace { class FakeDataChannel { public: FakeDataChannel() - : old_read_callback_(NULL), - read_buf_len_(0), - ALLOW_THIS_IN_INITIALIZER_LIST(task_factory_(this)) { + : read_buf_len_(0), + ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { } virtual int Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - if (data_.empty()) { - old_read_callback_ = callback; - read_buf_ = buf; - read_buf_len_ = buf_len; - return net::ERR_IO_PENDING; - } - return PropogateData(buf, buf_len); - } - virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { if (data_.empty()) { read_callback_ = callback; @@ -78,33 +67,25 @@ class FakeDataChannel { } virtual int Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + const CompletionCallback& callback) { data_.push(new net::DrainableIOBuffer(buf, buf_len)); MessageLoop::current()->PostTask( - FROM_HERE, task_factory_.NewRunnableMethod( - &FakeDataChannel::DoReadCallback)); + FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback, + weak_factory_.GetWeakPtr())); return buf_len; } private: void DoReadCallback() { - if ((!old_read_callback_ && read_callback_.is_null()) || data_.empty()) + if (read_callback_.is_null() || data_.empty()) return; int copied = PropogateData(read_buf_, read_buf_len_); - if (old_read_callback_) { - net::OldCompletionCallback* callback = old_read_callback_; - old_read_callback_ = NULL; - read_buf_ = NULL; - read_buf_len_ = 0; - callback->Run(copied); - } else { - net::CompletionCallback callback = read_callback_; - read_callback_.Reset(); - read_buf_ = NULL; - read_buf_len_ = 0; - callback.Run(copied); - } + CompletionCallback callback = read_callback_; + read_callback_.Reset(); + read_buf_ = NULL; + read_buf_len_ = 0; + callback.Run(copied); } int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) { @@ -118,14 +99,13 @@ class FakeDataChannel { return copied; } - net::OldCompletionCallback* old_read_callback_; - net::CompletionCallback read_callback_; + CompletionCallback read_callback_; scoped_refptr<net::IOBuffer> read_buf_; int read_buf_len_; std::queue<scoped_refptr<net::DrainableIOBuffer> > data_; - ScopedRunnableMethodFactory<FakeDataChannel> task_factory_; + base::WeakPtrFactory<FakeDataChannel> weak_factory_; DISALLOW_COPY_AND_ASSIGN(FakeDataChannel); }; @@ -142,82 +122,73 @@ class FakeSocket : public StreamSocket { } virtual int Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - // Read random number of bytes. - buf_len = rand() % buf_len + 1; - return incoming_->Read(buf, buf_len, callback); - } - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) { + const CompletionCallback& callback) OVERRIDE { // Read random number of bytes. buf_len = rand() % buf_len + 1; return incoming_->Read(buf, buf_len, callback); } virtual int Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + const CompletionCallback& callback) OVERRIDE { // Write random number of bytes. buf_len = rand() % buf_len + 1; return outgoing_->Write(buf, buf_len, callback); } - virtual bool SetReceiveBufferSize(int32 size) { + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { return true; } - virtual bool SetSendBufferSize(int32 size) { + virtual bool SetSendBufferSize(int32 size) OVERRIDE { return true; } - virtual int Connect(OldCompletionCallback* callback) { - return net::OK; - } - virtual int Connect(const CompletionCallback& callback) { + virtual int Connect(const CompletionCallback& callback) OVERRIDE { return net::OK; } - virtual void Disconnect() {} + virtual void Disconnect() OVERRIDE {} - virtual bool IsConnected() const { + virtual bool IsConnected() const OVERRIDE { return true; } - virtual bool IsConnectedAndIdle() const { + virtual bool IsConnectedAndIdle() const OVERRIDE { return true; } - virtual int GetPeerAddress(AddressList* address) const { + virtual int GetPeerAddress(AddressList* address) const OVERRIDE { net::IPAddressNumber ip_address(4); *address = net::AddressList::CreateFromIPAddress(ip_address, 0 /*port*/); return net::OK; } - virtual int GetLocalAddress(IPEndPoint* address) const { + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { net::IPAddressNumber ip_address(4); *address = net::IPEndPoint(ip_address, 0); return net::OK; } - virtual const BoundNetLog& NetLog() const { + virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; } - virtual void SetSubresourceSpeculation() {} - virtual void SetOmniboxSpeculation() {} + virtual void SetSubresourceSpeculation() OVERRIDE {} + virtual void SetOmniboxSpeculation() OVERRIDE {} - virtual bool WasEverUsed() const { + virtual bool WasEverUsed() const OVERRIDE { return true; } - virtual bool UsingTCPFastOpen() const { + virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } - virtual int64 NumBytesRead() const { + virtual int64 NumBytesRead() const OVERRIDE { return -1; } - virtual base::TimeDelta GetConnectTimeMicros() const { + virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE { return base::TimeDelta::FromMicroseconds(-1); } @@ -246,21 +217,21 @@ TEST(FakeSocketTest, DataTransfer) { scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize); // Write then read. - int written = server.Write(write_buf, kTestDataSize, NULL); + int written = server.Write(write_buf, kTestDataSize, CompletionCallback()); EXPECT_GT(written, 0); EXPECT_LE(written, kTestDataSize); - int read = client.Read(read_buf, kReadBufSize, NULL); + int read = client.Read(read_buf, kReadBufSize, CompletionCallback()); EXPECT_GT(read, 0); EXPECT_LE(read, written); EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read)); // Read then write. - TestOldCompletionCallback callback; + TestCompletionCallback callback; EXPECT_EQ(net::ERR_IO_PENDING, - server.Read(read_buf, kReadBufSize, &callback)); + server.Read(read_buf, kReadBufSize, callback.callback())); - written = client.Write(write_buf, kTestDataSize, NULL); + written = client.Write(write_buf, kTestDataSize, CompletionCallback()); EXPECT_GT(written, 0); EXPECT_LE(written, kTestDataSize); @@ -354,13 +325,13 @@ TEST_F(SSLServerSocketTest, Initialize) { TEST_F(SSLServerSocketTest, Handshake) { Initialize(); - TestOldCompletionCallback connect_callback; + TestCompletionCallback connect_callback; TestOldCompletionCallback handshake_callback; int server_ret = server_socket_->Handshake(&handshake_callback); EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); - int client_ret = client_socket_->Connect(&connect_callback); + int client_ret = client_socket_->Connect(connect_callback.callback()); EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); if (client_ret == net::ERR_IO_PENDING) { @@ -379,11 +350,11 @@ TEST_F(SSLServerSocketTest, Handshake) { TEST_F(SSLServerSocketTest, DataTransfer) { Initialize(); - TestOldCompletionCallback connect_callback; + TestCompletionCallback connect_callback; TestOldCompletionCallback handshake_callback; // Establish connection. - int client_ret = client_socket_->Connect(&connect_callback); + int client_ret = client_socket_->Connect(connect_callback.callback()); ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); int server_ret = server_socket_->Handshake(&handshake_callback); @@ -402,13 +373,13 @@ TEST_F(SSLServerSocketTest, DataTransfer) { kReadBufSize); // Write then read. - TestOldCompletionCallback write_callback; - TestOldCompletionCallback read_callback; + TestCompletionCallback write_callback; + TestCompletionCallback read_callback; server_ret = server_socket_->Write(write_buf, write_buf->size(), - &write_callback); + write_callback.callback()); EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); client_ret = client_socket_->Read(read_buf, read_buf->BytesRemaining(), - &read_callback); + read_callback.callback()); EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); server_ret = write_callback.GetResult(server_ret); @@ -419,7 +390,7 @@ TEST_F(SSLServerSocketTest, DataTransfer) { read_buf->DidConsume(client_ret); while (read_buf->BytesConsumed() < write_buf->size()) { client_ret = client_socket_->Read(read_buf, read_buf->BytesRemaining(), - &read_callback); + read_callback.callback()); EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); client_ret = read_callback.GetResult(client_ret); ASSERT_GT(client_ret, 0); @@ -432,10 +403,10 @@ TEST_F(SSLServerSocketTest, DataTransfer) { // Read then write. write_buf = new net::StringIOBuffer("hello123"); server_ret = server_socket_->Read(read_buf, read_buf->BytesRemaining(), - &read_callback); + read_callback.callback()); EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); client_ret = client_socket_->Write(write_buf, write_buf->size(), - &write_callback); + write_callback.callback()); EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); server_ret = read_callback.GetResult(server_ret); @@ -446,7 +417,7 @@ TEST_F(SSLServerSocketTest, DataTransfer) { read_buf->DidConsume(server_ret); while (read_buf->BytesConsumed() < write_buf->size()) { server_ret = server_socket_->Read(read_buf, read_buf->BytesRemaining(), - &read_callback); + read_callback.callback()); EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); server_ret = read_callback.GetResult(server_ret); ASSERT_GT(server_ret, 0); @@ -463,10 +434,10 @@ TEST_F(SSLServerSocketTest, DataTransfer) { TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { Initialize(); - TestOldCompletionCallback connect_callback; + TestCompletionCallback connect_callback; TestOldCompletionCallback handshake_callback; - int client_ret = client_socket_->Connect(&connect_callback); + int client_ret = client_socket_->Connect(connect_callback.callback()); ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); int server_ret = server_socket_->Handshake(&handshake_callback); diff --git a/net/socket/stream_socket.h b/net/socket/stream_socket.h index 3ba5b42..a0e6f86 100644 --- a/net/socket/stream_socket.h +++ b/net/socket/stream_socket.h @@ -33,7 +33,6 @@ class NET_EXPORT_PRIVATE StreamSocket : public Socket { // // Connect may also be called again after a call to the Disconnect method. // - virtual int Connect(OldCompletionCallback* callback) = 0; virtual int Connect(const CompletionCallback& callback) = 0; // Called to disconnect a socket. Does nothing if the socket is already diff --git a/net/socket/tcp_client_socket_libevent.cc b/net/socket/tcp_client_socket_libevent.cc index 3c99ae5..9b47249 100644 --- a/net/socket/tcp_client_socket_libevent.cc +++ b/net/socket/tcp_client_socket_libevent.cc @@ -130,8 +130,6 @@ TCPClientSocketLibevent::TCPClientSocketLibevent( current_ai_(NULL), read_watcher_(this), write_watcher_(this), - old_read_callback_(NULL), - old_write_callback_(NULL), next_connect_state_(CONNECT_STATE_NONE), connect_os_error_(0), net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), @@ -202,38 +200,6 @@ int TCPClientSocketLibevent::Bind(const IPEndPoint& address) { return 0; } -int TCPClientSocketLibevent::Connect(OldCompletionCallback* callback) { - DCHECK(CalledOnValidThread()); - - // If already connected, then just return OK. - if (socket_ != kInvalidSocket) - return OK; - - base::StatsCounter connects("tcp.connect"); - connects.Increment(); - - DCHECK(!waiting_connect()); - - net_log_.BeginEvent( - NetLog::TYPE_TCP_CONNECT, - make_scoped_refptr(new AddressListNetLogParam(addresses_))); - - // We will try to connect to each address in addresses_. Start with the - // first one in the list. - next_connect_state_ = CONNECT_STATE_CONNECT; - current_ai_ = addresses_.head(); - - int rv = DoConnectLoop(OK); - if (rv == ERR_IO_PENDING) { - // Synchronous operation not supported. - DCHECK(callback); - old_write_callback_ = callback; - } else { - LogConnectCompletion(rv); - } - - return rv; -} int TCPClientSocketLibevent::Connect(const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); @@ -461,50 +427,11 @@ bool TCPClientSocketLibevent::IsConnectedAndIdle() const { int TCPClientSocketLibevent::Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - DCHECK(CalledOnValidThread()); - DCHECK_NE(kInvalidSocket, socket_); - DCHECK(!waiting_connect()); - DCHECK(!old_read_callback_ && read_callback_.is_null()); - // Synchronous operation not supported - DCHECK(callback); - DCHECK_GT(buf_len, 0); - - int nread = HANDLE_EINTR(read(socket_, buf->data(), buf_len)); - if (nread >= 0) { - base::StatsCounter read_bytes("tcp.read_bytes"); - read_bytes.Add(nread); - num_bytes_read_ += static_cast<int64>(nread); - if (nread > 0) - use_history_.set_was_used_to_convey_data(); - net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, nread, - buf->data()); - return nread; - } - if (errno != EAGAIN && errno != EWOULDBLOCK) { - DVLOG(1) << "read failed, errno " << errno; - return MapSystemError(errno); - } - - if (!MessageLoopForIO::current()->WatchFileDescriptor( - socket_, true, MessageLoopForIO::WATCH_READ, - &read_socket_watcher_, &read_watcher_)) { - DVLOG(1) << "WatchFileDescriptor failed on read, errno " << errno; - return MapSystemError(errno); - } - - read_buf_ = buf; - read_buf_len_ = buf_len; - old_read_callback_ = callback; - return ERR_IO_PENDING; -} -int TCPClientSocketLibevent::Read(IOBuffer* buf, - int buf_len, const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(kInvalidSocket, socket_); DCHECK(!waiting_connect()); - DCHECK(!old_read_callback_ && read_callback_.is_null()); + DCHECK(read_callback_.is_null()); // Synchronous operation not supported DCHECK(!callback.is_null()); DCHECK_GT(buf_len, 0); @@ -540,13 +467,13 @@ int TCPClientSocketLibevent::Read(IOBuffer* buf, int TCPClientSocketLibevent::Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(kInvalidSocket, socket_); DCHECK(!waiting_connect()); - DCHECK(!old_write_callback_ && write_callback_.is_null()); + DCHECK(write_callback_.is_null()); // Synchronous operation not supported - DCHECK(callback); + DCHECK(!callback.is_null()); DCHECK_GT(buf_len, 0); int nwrite = InternalWrite(buf, buf_len); @@ -571,7 +498,7 @@ int TCPClientSocketLibevent::Write(IOBuffer* buf, write_buf_ = buf; write_buf_len_ = buf_len; - old_write_callback_ = callback; + write_callback_ = callback; return ERR_IO_PENDING; } @@ -657,34 +584,22 @@ void TCPClientSocketLibevent::LogConnectCompletion(int net_error) { void TCPClientSocketLibevent::DoReadCallback(int rv) { DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(old_read_callback_ || !read_callback_.is_null()); + DCHECK(!read_callback_.is_null()); // since Run may result in Read being called, clear read_callback_ up front. - if (old_read_callback_) { - OldCompletionCallback* c = old_read_callback_; - old_read_callback_ = NULL; - c->Run(rv); - } else { - CompletionCallback c = read_callback_; - read_callback_.Reset(); - c.Run(rv); - } + CompletionCallback c = read_callback_; + read_callback_.Reset(); + c.Run(rv); } void TCPClientSocketLibevent::DoWriteCallback(int rv) { DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(old_write_callback_ || !write_callback_.is_null()); + DCHECK(!write_callback_.is_null()); // since Run may result in Write being called, clear write_callback_ up front. - if (old_write_callback_) { - OldCompletionCallback* c = old_write_callback_; - old_write_callback_ = NULL; - c->Run(rv); - } else { - CompletionCallback c = write_callback_; - write_callback_.Reset(); - c.Run(rv); - } + CompletionCallback c = write_callback_; + write_callback_.Reset(); + c.Run(rv); } void TCPClientSocketLibevent::DidCompleteConnect() { diff --git a/net/socket/tcp_client_socket_libevent.h b/net/socket/tcp_client_socket_libevent.h index 47f19a0..a6aa241 100644 --- a/net/socket/tcp_client_socket_libevent.h +++ b/net/socket/tcp_client_socket_libevent.h @@ -43,7 +43,6 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, int Bind(const IPEndPoint& address); // StreamSocket implementation. - virtual int Connect(OldCompletionCallback* callback) OVERRIDE; virtual int Connect(const CompletionCallback& callback) OVERRIDE; virtual void Disconnect() OVERRIDE; virtual bool IsConnected() const OVERRIDE; @@ -63,13 +62,10 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, // Full duplex mode (reading and writing at the same time) is supported virtual int Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) OVERRIDE; - virtual int Read(IOBuffer* buf, - int buf_len, const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) OVERRIDE; + const CompletionCallback& callback) OVERRIDE; virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; @@ -88,7 +84,7 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, // MessageLoopForIO::Watcher methods virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE { - if (socket_->old_read_callback_) + if (!socket_->read_callback_.is_null()) socket_->DidCompleteRead(); } @@ -109,8 +105,7 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, virtual void OnFileCanWriteWithoutBlocking(int /* fd */) OVERRIDE { if (socket_->waiting_connect()) { socket_->DidCompleteConnect(); - } else if (socket_->old_write_callback_ || - !socket_->write_callback_.is_null()) { + } else if (!socket_->write_callback_.is_null()) { socket_->DidCompleteWrite(); } } @@ -179,11 +174,9 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, int write_buf_len_; // External callback; called when read is complete. - OldCompletionCallback* old_read_callback_; CompletionCallback read_callback_; // External callback; called when write is complete. - OldCompletionCallback* old_write_callback_; CompletionCallback write_callback_; // The next state for the Connect() state machine. diff --git a/net/socket/tcp_client_socket_unittest.cc b/net/socket/tcp_client_socket_unittest.cc index 50a319b..991b645 100644 --- a/net/socket/tcp_client_socket_unittest.cc +++ b/net/socket/tcp_client_socket_unittest.cc @@ -38,8 +38,8 @@ TEST(TCPClientSocketTest, BindLoopbackToLoopback) { EXPECT_EQ(OK, socket.Bind(IPEndPoint(lo_address, 0))); - TestOldCompletionCallback connect_callback; - EXPECT_EQ(ERR_IO_PENDING, socket.Connect(&connect_callback)); + TestCompletionCallback connect_callback; + EXPECT_EQ(ERR_IO_PENDING, socket.Connect(connect_callback.callback())); TestCompletionCallback accept_callback; scoped_ptr<StreamSocket> accepted_socket; @@ -63,8 +63,8 @@ TEST(TCPClientSocketTest, BindLoopbackToExternal) { ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &lo_address)); EXPECT_EQ(OK, socket.Bind(IPEndPoint(lo_address, 0))); - TestOldCompletionCallback connect_callback; - int result = socket.Connect(&connect_callback); + TestCompletionCallback connect_callback; + int result = socket.Connect(connect_callback.callback()); if (result == ERR_IO_PENDING) result = connect_callback.WaitForResult(); @@ -97,8 +97,8 @@ TEST(TCPClientSocketTest, BindLoopbackToIPv6) { ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &ipv4_lo_ip)); EXPECT_EQ(OK, socket.Bind(IPEndPoint(ipv4_lo_ip, 0))); - TestOldCompletionCallback connect_callback; - int result = socket.Connect(&connect_callback); + TestCompletionCallback connect_callback; + int result = socket.Connect(connect_callback.callback()); if (result == ERR_IO_PENDING) result = connect_callback.WaitForResult(); diff --git a/net/socket/tcp_client_socket_win.cc b/net/socket/tcp_client_socket_win.cc index cb59a98..b5921b9 100644 --- a/net/socket/tcp_client_socket_win.cc +++ b/net/socket/tcp_client_socket_win.cc @@ -317,8 +317,6 @@ TCPClientSocketWin::TCPClientSocketWin(const AddressList& addresses, current_ai_(NULL), waiting_read_(false), waiting_write_(false), - old_read_callback_(NULL), - write_callback_(NULL), next_connect_state_(CONNECT_STATE_NONE), connect_os_error_(0), net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), @@ -382,35 +380,6 @@ int TCPClientSocketWin::Bind(const IPEndPoint& address) { } -int TCPClientSocketWin::Connect(OldCompletionCallback* callback) { - DCHECK(CalledOnValidThread()); - - // If already connected, then just return OK. - if (socket_ != INVALID_SOCKET) - return OK; - - base::StatsCounter connects("tcp.connect"); - connects.Increment(); - - net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, - new AddressListNetLogParam(addresses_)); - - // We will try to connect to each address in addresses_. Start with the - // first one in the list. - next_connect_state_ = CONNECT_STATE_CONNECT; - current_ai_ = addresses_.head(); - - int rv = DoConnectLoop(OK); - if (rv == ERR_IO_PENDING) { - // Synchronous operation not supported. - DCHECK(callback); - old_read_callback_ = callback; - } else { - LogConnectCompletion(rv); - } - - return rv; -} int TCPClientSocketWin::Connect(const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); @@ -433,6 +402,7 @@ int TCPClientSocketWin::Connect(const CompletionCallback& callback) { if (rv == ERR_IO_PENDING) { // Synchronous operation not supported. DCHECK(!callback.is_null()); + // TODO(ajwong): Is setting read_callback_ the right thing to do here?? read_callback_ = callback; } else { LogConnectCompletion(rv); @@ -705,52 +675,11 @@ base::TimeDelta TCPClientSocketWin::GetConnectTimeMicros() const { int TCPClientSocketWin::Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - DCHECK(CalledOnValidThread()); - DCHECK_NE(socket_, INVALID_SOCKET); - DCHECK(!waiting_read_); - DCHECK(!old_read_callback_ && read_callback_.is_null()); - DCHECK(!core_->read_iobuffer_); - - buf_len = core_->ThrottleReadSize(buf_len); - - core_->read_buffer_.len = buf_len; - core_->read_buffer_.buf = buf->data(); - - // TODO(wtc): Remove the assertion after enough testing. - AssertEventNotSignaled(core_->read_overlapped_.hEvent); - DWORD num, flags = 0; - int rv = WSARecv(socket_, &core_->read_buffer_, 1, &num, &flags, - &core_->read_overlapped_, NULL); - if (rv == 0) { - if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) { - base::StatsCounter read_bytes("tcp.read_bytes"); - read_bytes.Add(num); - num_bytes_read_ += num; - if (num > 0) - use_history_.set_was_used_to_convey_data(); - net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, num, - core_->read_buffer_.buf); - return static_cast<int>(num); - } - } else { - int os_error = WSAGetLastError(); - if (os_error != WSA_IO_PENDING) - return MapSystemError(os_error); - } - core_->WatchForRead(); - waiting_read_ = true; - old_read_callback_ = callback; - core_->read_iobuffer_ = buf; - return ERR_IO_PENDING; -} -int TCPClientSocketWin::Read(IOBuffer* buf, - int buf_len, const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(socket_, INVALID_SOCKET); DCHECK(!waiting_read_); - DCHECK(!old_read_callback_ && read_callback_.is_null()); + DCHECK(read_callback_.is_null()); DCHECK(!core_->read_iobuffer_); buf_len = core_->ThrottleReadSize(buf_len); @@ -788,11 +717,11 @@ int TCPClientSocketWin::Read(IOBuffer* buf, int TCPClientSocketWin::Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(socket_, INVALID_SOCKET); DCHECK(!waiting_write_); - DCHECK(!write_callback_); + DCHECK(write_callback_.is_null()); DCHECK_GT(buf_len, 0); DCHECK(!core_->write_iobuffer_); @@ -881,28 +810,22 @@ void TCPClientSocketWin::LogConnectCompletion(int net_error) { void TCPClientSocketWin::DoReadCallback(int rv) { DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(old_read_callback_ || !read_callback_.is_null()); + DCHECK(!read_callback_.is_null()); - // since Run may result in Read being called, clear read_callback_ up front. - if (old_read_callback_) { - OldCompletionCallback* c = old_read_callback_; - old_read_callback_ = NULL; - c->Run(rv); - } else { - CompletionCallback c = read_callback_; - read_callback_.Reset(); - c.Run(rv); - } + // Since Run may result in Read being called, clear read_callback_ up front. + CompletionCallback c = read_callback_; + read_callback_.Reset(); + c.Run(rv); } void TCPClientSocketWin::DoWriteCallback(int rv) { DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(write_callback_); + DCHECK(!write_callback_.is_null()); // since Run may result in Write being called, clear write_callback_ up front. - OldCompletionCallback* c = write_callback_; - write_callback_ = NULL; - c->Run(rv); + CompletionCallback c = write_callback_; + write_callback_.Reset(); + c.Run(rv); } void TCPClientSocketWin::DidCompleteConnect() { diff --git a/net/socket/tcp_client_socket_win.h b/net/socket/tcp_client_socket_win.h index 1e75933..7f681fa 100644 --- a/net/socket/tcp_client_socket_win.h +++ b/net/socket/tcp_client_socket_win.h @@ -42,7 +42,6 @@ class NET_EXPORT TCPClientSocketWin : public StreamSocket, int Bind(const IPEndPoint& address); // StreamSocket implementation. - virtual int Connect(OldCompletionCallback* callback); virtual int Connect(const CompletionCallback& callback); virtual void Disconnect(); virtual bool IsConnected() const; @@ -60,10 +59,10 @@ class NET_EXPORT TCPClientSocketWin : public StreamSocket, // Socket implementation. // Multiple outstanding requests are not supported. // Full duplex mode (reading and writing at the same time) is supported - virtual int Read(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); - virtual int Write(IOBuffer* buf, int buf_len, OldCompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); @@ -126,11 +125,10 @@ class NET_EXPORT TCPClientSocketWin : public StreamSocket, scoped_refptr<Core> core_; // External callback; called when connect or read is complete. - OldCompletionCallback* old_read_callback_; CompletionCallback read_callback_; // External callback; called when write is complete. - OldCompletionCallback* write_callback_; + CompletionCallback write_callback_; // The next state for the Connect() state machine. ConnectState next_connect_state_; diff --git a/net/socket/tcp_server_socket_unittest.cc b/net/socket/tcp_server_socket_unittest.cc index 573ff4c..2ccc70e 100644 --- a/net/socket/tcp_server_socket_unittest.cc +++ b/net/socket/tcp_server_socket_unittest.cc @@ -77,10 +77,10 @@ class TCPServerSocketTest : public PlatformTest { TEST_F(TCPServerSocketTest, Accept) { ASSERT_NO_FATAL_FAILURE(SetUpIPv4()); - TestOldCompletionCallback connect_callback; + TestCompletionCallback connect_callback; TCPClientSocket connecting_socket(local_address_list(), NULL, NetLog::Source()); - connecting_socket.Connect(&connect_callback); + connecting_socket.Connect(connect_callback.callback()); TestCompletionCallback accept_callback; scoped_ptr<StreamSocket> accepted_socket; @@ -108,10 +108,10 @@ TEST_F(TCPServerSocketTest, AcceptAsync) { ASSERT_EQ(ERR_IO_PENDING, socket_.Accept(&accepted_socket, accept_callback.callback())); - TestOldCompletionCallback connect_callback; + TestCompletionCallback connect_callback; TCPClientSocket connecting_socket(local_address_list(), NULL, NetLog::Source()); - connecting_socket.Connect(&connect_callback); + connecting_socket.Connect(connect_callback.callback()); EXPECT_EQ(OK, connect_callback.WaitForResult()); EXPECT_EQ(OK, accept_callback.WaitForResult()); @@ -133,15 +133,15 @@ TEST_F(TCPServerSocketTest, Accept2Connections) { ASSERT_EQ(ERR_IO_PENDING, socket_.Accept(&accepted_socket, accept_callback.callback())); - TestOldCompletionCallback connect_callback; + TestCompletionCallback connect_callback; TCPClientSocket connecting_socket(local_address_list(), NULL, NetLog::Source()); - connecting_socket.Connect(&connect_callback); + connecting_socket.Connect(connect_callback.callback()); - TestOldCompletionCallback connect_callback2; + TestCompletionCallback connect_callback2; TCPClientSocket connecting_socket2(local_address_list(), NULL, NetLog::Source()); - connecting_socket2.Connect(&connect_callback2); + connecting_socket2.Connect(connect_callback2.callback()); EXPECT_EQ(OK, accept_callback.WaitForResult()); @@ -170,10 +170,10 @@ TEST_F(TCPServerSocketTest, AcceptIPv6) { if (!initialized) return; - TestOldCompletionCallback connect_callback; + TestCompletionCallback connect_callback; TCPClientSocket connecting_socket(local_address_list(), NULL, NetLog::Source()); - connecting_socket.Connect(&connect_callback); + connecting_socket.Connect(connect_callback.callback()); TestCompletionCallback accept_callback; scoped_ptr<StreamSocket> accepted_socket; diff --git a/net/socket/transport_client_socket_pool.cc b/net/socket/transport_client_socket_pool.cc index a423ebd..faf12df 100644 --- a/net/socket/transport_client_socket_pool.cc +++ b/net/socket/transport_client_socket_pool.cc @@ -99,14 +99,7 @@ TransportConnectJob::TransportConnectJob( BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), params_(params), client_socket_factory_(client_socket_factory), - ALLOW_THIS_IN_INITIALIZER_LIST( - callback_(this, - &TransportConnectJob::OnIOComplete)), - resolver_(host_resolver), - ALLOW_THIS_IN_INITIALIZER_LIST( - fallback_callback_( - this, - &TransportConnectJob::DoIPv6FallbackTransportConnectComplete)) {} + resolver_(host_resolver) {} TransportConnectJob::~TransportConnectJob() { // We don't worry about cancelling the host resolution and TCP connect, since @@ -216,7 +209,8 @@ int TransportConnectJob::DoTransportConnect() { transport_socket_.reset(client_socket_factory_->CreateTransportClientSocket( addresses_, net_log().net_log(), net_log().source())); connect_start_time_ = base::TimeTicks::Now(); - int rv = transport_socket_->Connect(&callback_); + int rv = transport_socket_->Connect( + base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this))); if (rv == ERR_IO_PENDING && AddressListStartsWithIPv6AndHasAnIPv4Addr(addresses_)) { fallback_timer_.Start(FROM_HERE, @@ -296,7 +290,10 @@ void TransportConnectJob::DoIPv6FallbackTransportConnect() { client_socket_factory_->CreateTransportClientSocket( *fallback_addresses_, net_log().net_log(), net_log().source())); fallback_connect_start_time_ = base::TimeTicks::Now(); - int rv = fallback_transport_socket_->Connect(&fallback_callback_); + int rv = fallback_transport_socket_->Connect( + base::Bind( + &TransportConnectJob::DoIPv6FallbackTransportConnectComplete, + base::Unretained(this))); if (rv != ERR_IO_PENDING) DoIPv6FallbackTransportConnectComplete(rv); } diff --git a/net/socket/transport_client_socket_pool.h b/net/socket/transport_client_socket_pool.h index e3e6f02..44309d4 100644 --- a/net/socket/transport_client_socket_pool.h +++ b/net/socket/transport_client_socket_pool.h @@ -108,7 +108,6 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { scoped_refptr<TransportSocketParams> params_; ClientSocketFactory* const client_socket_factory_; - OldCompletionCallbackImpl<TransportConnectJob> callback_; SingleRequestHostResolver resolver_; AddressList addresses_; State next_state_; @@ -123,7 +122,6 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { scoped_ptr<StreamSocket> fallback_transport_socket_; scoped_ptr<AddressList> fallback_addresses_; - OldCompletionCallbackImpl<TransportConnectJob> fallback_callback_; base::TimeTicks fallback_connect_start_time_; base::OneShotTimer<TransportConnectJob> fallback_timer_; diff --git a/net/socket/transport_client_socket_pool_unittest.cc b/net/socket/transport_client_socket_pool_unittest.cc index 56b1fa9c..c5d1080 100644 --- a/net/socket/transport_client_socket_pool_unittest.cc +++ b/net/socket/transport_client_socket_pool_unittest.cc @@ -52,10 +52,6 @@ class MockClientSocket : public StreamSocket { addrlist_(addrlist) {} // StreamSocket implementation. - virtual int Connect(OldCompletionCallback* callback) { - connected_ = true; - return OK; - } virtual int Connect(const CompletionCallback& callback) { connected_ = true; return OK; @@ -96,15 +92,11 @@ class MockClientSocket : public StreamSocket { // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - return ERR_FAILED; - } - virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { return ERR_FAILED; } virtual int Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + const CompletionCallback& callback) { return ERR_FAILED; } virtual bool SetReceiveBufferSize(int32 size) { return true; } @@ -121,10 +113,7 @@ class MockFailingClientSocket : public StreamSocket { MockFailingClientSocket(const AddressList& addrlist) : addrlist_(addrlist) {} // StreamSocket implementation. - virtual int Connect(OldCompletionCallback* callback) { - return ERR_CONNECTION_FAILED; - } - virtual int Connect(const net::CompletionCallback& callback) { + virtual int Connect(const CompletionCallback& callback) { return ERR_CONNECTION_FAILED; } @@ -157,16 +146,12 @@ class MockFailingClientSocket : public StreamSocket { // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - return ERR_FAILED; - } - virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { return ERR_FAILED; } virtual int Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + const CompletionCallback& callback) { return ERR_FAILED; } virtual bool SetReceiveBufferSize(int32 size) { return true; } @@ -196,14 +181,6 @@ class MockPendingClientSocket : public StreamSocket { addrlist_(addrlist) {} // StreamSocket implementation. - virtual int Connect(OldCompletionCallback* callback) { - MessageLoop::current()->PostDelayedTask( - FROM_HERE, - base::Bind(&MockPendingClientSocket::DoOldCallback, - weak_factory_.GetWeakPtr(), callback), - delay_ms_); - return ERR_IO_PENDING; - } virtual int Connect(const CompletionCallback& callback) { MessageLoop::current()->PostDelayedTask( FROM_HERE, @@ -248,34 +225,18 @@ class MockPendingClientSocket : public StreamSocket { // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { - return ERR_FAILED; - } - virtual int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { return ERR_FAILED; } virtual int Write(IOBuffer* buf, int buf_len, - OldCompletionCallback* callback) { + const CompletionCallback& callback) { return ERR_FAILED; } virtual bool SetReceiveBufferSize(int32 size) { return true; } virtual bool SetSendBufferSize(int32 size) { return true; } private: - void DoOldCallback(OldCompletionCallback* callback) { - if (should_stall_) - return; - - if (should_connect_) { - is_connected_ = true; - callback->Run(OK); - } else { - is_connected_ = false; - callback->Run(ERR_CONNECTION_FAILED); - } - } void DoCallback(const CompletionCallback& callback) { if (should_stall_) return; diff --git a/net/socket/transport_client_socket_unittest.cc b/net/socket/transport_client_socket_unittest.cc index d661e02..4ba4c000 100644 --- a/net/socket/transport_client_socket_unittest.cc +++ b/net/socket/transport_client_socket_unittest.cc @@ -75,7 +75,7 @@ class TransportClientSocketTest int DrainClientSocket(IOBuffer* buf, uint32 buf_len, uint32 bytes_to_read, - TestOldCompletionCallback* callback); + TestCompletionCallback* callback); void SendClientRequest(); @@ -136,12 +136,12 @@ void TransportClientSocketTest::SetUp() { int TransportClientSocketTest::DrainClientSocket( IOBuffer* buf, uint32 buf_len, - uint32 bytes_to_read, TestOldCompletionCallback* callback) { + uint32 bytes_to_read, TestCompletionCallback* callback) { int rv = OK; uint32 bytes_read = 0; while (bytes_read < bytes_to_read) { - rv = sock_->Read(buf, buf_len, callback); + rv = sock_->Read(buf, buf_len, callback->callback()); EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); if (rv == ERR_IO_PENDING) @@ -158,11 +158,12 @@ void TransportClientSocketTest::SendClientRequest() { const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; scoped_refptr<IOBuffer> request_buffer( new IOBuffer(arraysize(request_text) - 1)); - TestOldCompletionCallback callback; + TestCompletionCallback callback; int rv; memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); - rv = sock_->Write(request_buffer, arraysize(request_text) - 1, &callback); + rv = sock_->Write(request_buffer, arraysize(request_text) - 1, + callback.callback()); EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); if (rv == ERR_IO_PENDING) @@ -176,10 +177,10 @@ INSTANTIATE_TEST_CASE_P(StreamSocket, ::testing::Values(TCP)); TEST_P(TransportClientSocketTest, Connect) { - TestOldCompletionCallback callback; + TestCompletionCallback callback; EXPECT_FALSE(sock_->IsConnected()); - int rv = sock_->Connect(&callback); + int rv = sock_->Connect(callback.callback()); net::CapturingNetLog::EntryList net_log_entries; net_log_.GetEntries(&net_log_entries); @@ -204,12 +205,12 @@ TEST_P(TransportClientSocketTest, Connect) { TEST_P(TransportClientSocketTest, IsConnected) { scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); - TestOldCompletionCallback callback; + TestCompletionCallback callback; uint32 bytes_read; EXPECT_FALSE(sock_->IsConnected()); EXPECT_FALSE(sock_->IsConnectedAndIdle()); - int rv = sock_->Connect(&callback); + int rv = sock_->Connect(callback.callback()); if (rv != OK) { ASSERT_EQ(rv, ERR_IO_PENDING); rv = callback.WaitForResult(); @@ -261,8 +262,8 @@ TEST_P(TransportClientSocketTest, IsConnected) { } TEST_P(TransportClientSocketTest, Read) { - TestOldCompletionCallback callback; - int rv = sock_->Connect(&callback); + TestCompletionCallback callback; + int rv = sock_->Connect(callback.callback()); if (rv != OK) { ASSERT_EQ(rv, ERR_IO_PENDING); @@ -279,7 +280,7 @@ TEST_P(TransportClientSocketTest, Read) { // All data has been read now. Read once more to force an ERR_IO_PENDING, and // then close the server socket, and note the close. - rv = sock_->Read(buf, 4096, &callback); + rv = sock_->Read(buf, 4096, callback.callback()); ASSERT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(static_cast<int64>(std::string(kServerReply).size()), sock_->NumBytesRead()); @@ -288,8 +289,8 @@ TEST_P(TransportClientSocketTest, Read) { } TEST_P(TransportClientSocketTest, Read_SmallChunks) { - TestOldCompletionCallback callback; - int rv = sock_->Connect(&callback); + TestCompletionCallback callback; + int rv = sock_->Connect(callback.callback()); if (rv != OK) { ASSERT_EQ(rv, ERR_IO_PENDING); @@ -301,7 +302,7 @@ TEST_P(TransportClientSocketTest, Read_SmallChunks) { scoped_refptr<IOBuffer> buf(new IOBuffer(1)); uint32 bytes_read = 0; while (bytes_read < arraysize(kServerReply) - 1) { - rv = sock_->Read(buf, 1, &callback); + rv = sock_->Read(buf, 1, callback.callback()); EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); if (rv == ERR_IO_PENDING) @@ -314,7 +315,7 @@ TEST_P(TransportClientSocketTest, Read_SmallChunks) { // All data has been read now. Read once more to force an ERR_IO_PENDING, and // then close the server socket, and note the close. - rv = sock_->Read(buf, 1, &callback); + rv = sock_->Read(buf, 1, callback.callback()); EXPECT_EQ(static_cast<int64>(std::string(kServerReply).size()), sock_->NumBytesRead()); ASSERT_EQ(ERR_IO_PENDING, rv); @@ -323,8 +324,8 @@ TEST_P(TransportClientSocketTest, Read_SmallChunks) { } TEST_P(TransportClientSocketTest, Read_Interrupted) { - TestOldCompletionCallback callback; - int rv = sock_->Connect(&callback); + TestCompletionCallback callback; + int rv = sock_->Connect(callback.callback()); if (rv != OK) { ASSERT_EQ(ERR_IO_PENDING, rv); @@ -335,7 +336,7 @@ TEST_P(TransportClientSocketTest, Read_Interrupted) { // Do a partial read and then exit. This test should not crash! scoped_refptr<IOBuffer> buf(new IOBuffer(16)); - rv = sock_->Read(buf, 16, &callback); + rv = sock_->Read(buf, 16, callback.callback()); EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); EXPECT_EQ(0, sock_->NumBytesRead()); @@ -348,8 +349,8 @@ TEST_P(TransportClientSocketTest, Read_Interrupted) { } TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_ReadFirst) { - TestOldCompletionCallback callback; - int rv = sock_->Connect(&callback); + TestCompletionCallback callback; + int rv = sock_->Connect(callback.callback()); if (rv != OK) { ASSERT_EQ(rv, ERR_IO_PENDING); @@ -360,7 +361,7 @@ TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_ReadFirst) { // Read first. There's no data, so it should return ERR_IO_PENDING. const int kBufLen = 4096; scoped_refptr<IOBuffer> buf(new IOBuffer(kBufLen)); - rv = sock_->Read(buf, kBufLen, &callback); + rv = sock_->Read(buf, kBufLen, callback.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); PauseServerReads(); @@ -368,10 +369,10 @@ TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_ReadFirst) { scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kWriteBufLen)); char* request_data = request_buffer->data(); memset(request_data, 'A', kWriteBufLen); - TestOldCompletionCallback write_callback; + TestCompletionCallback write_callback; while (true) { - rv = sock_->Write(request_buffer, kWriteBufLen, &write_callback); + rv = sock_->Write(request_buffer, kWriteBufLen, write_callback.callback()); ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); if (rv == ERR_IO_PENDING) { @@ -390,8 +391,8 @@ TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_ReadFirst) { } TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_WriteFirst) { - TestOldCompletionCallback callback; - int rv = sock_->Connect(&callback); + TestCompletionCallback callback; + int rv = sock_->Connect(callback.callback()); if (rv != OK) { ASSERT_EQ(ERR_IO_PENDING, rv); @@ -404,10 +405,10 @@ TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_WriteFirst) { scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kWriteBufLen)); char* request_data = request_buffer->data(); memset(request_data, 'A', kWriteBufLen); - TestOldCompletionCallback write_callback; + TestCompletionCallback write_callback; while (true) { - rv = sock_->Write(request_buffer, kWriteBufLen, &write_callback); + rv = sock_->Write(request_buffer, kWriteBufLen, write_callback.callback()); ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); if (rv == ERR_IO_PENDING) @@ -420,7 +421,7 @@ TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_WriteFirst) { const int kBufLen = 4096; scoped_refptr<IOBuffer> buf(new IOBuffer(kBufLen)); while (true) { - rv = sock_->Read(buf, kBufLen, &callback); + rv = sock_->Read(buf, kBufLen, callback.callback()); ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); if (rv == ERR_IO_PENDING) break; diff --git a/net/socket/web_socket_server_socket.cc b/net/socket/web_socket_server_socket.cc index d792689..08b7788 100644 --- a/net/socket/web_socket_server_socket.cc +++ b/net/socket/web_socket_server_socket.cc @@ -118,10 +118,6 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { handshake_buf_, kHandshakeLimitBytes)), process_handshake_buf_(new net::DrainableIOBuffer( handshake_buf_, kHandshakeLimitBytes)), - transport_read_callback_(NewCallback( - this, &WebSocketServerSocketImpl::OnRead)), - transport_write_callback_(NewCallback( - this, &WebSocketServerSocketImpl::OnWrite)), is_transport_read_pending_(false), is_transport_write_pending_(false), method_factory_(this) { @@ -135,11 +131,8 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { it->type == PendingReq::TYPE_READ && it->io_buf != NULL && it->io_buf->data() != NULL && - (it->old_callback || !it->callback.is_null())) { - if (it->old_callback) - it->old_callback->Run(0); // Report EOF. - else - it->callback.Run(0); + !it->callback.is_null()) { + it->callback.Run(0); // Report EOF. } } @@ -175,29 +168,9 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { }; PendingReq(Type type, net::DrainableIOBuffer* io_buf, - net::OldCompletionCallback* callback) - : type(type), - io_buf(io_buf), - old_callback(callback) { - switch (type) { - case PendingReq::TYPE_READ: - case PendingReq::TYPE_WRITE: - case PendingReq::TYPE_READ_METADATA: - case PendingReq::TYPE_WRITE_METADATA: { - DCHECK(io_buf); - break; - } - default: { - NOTREACHED(); - break; - } - } - } - PendingReq(Type type, net::DrainableIOBuffer* io_buf, const net::CompletionCallback& callback) : type(type), io_buf(io_buf), - old_callback(NULL), callback(callback) { switch (type) { case PendingReq::TYPE_READ: @@ -216,76 +189,11 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { Type type; scoped_refptr<net::DrainableIOBuffer> io_buf; - net::OldCompletionCallback* old_callback; net::CompletionCallback callback; }; // Socket implementation. virtual int Read(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) OVERRIDE { - if (buf_len == 0) - return 0; - if (buf == NULL || buf_len < 0) { - NOTREACHED(); - return net::ERR_INVALID_ARGUMENT; - } - while (int bytes_remaining = fill_handshake_buf_->BytesConsumed() - - process_handshake_buf_->BytesConsumed()) { - DCHECK(!is_transport_read_pending_); - DCHECK(GetPendingReq(PendingReq::TYPE_READ) == pending_reqs_.end()); - switch (phase_) { - case PHASE_FRAME_OUTSIDE: - case PHASE_FRAME_INSIDE: - case PHASE_FRAME_LENGTH: - case PHASE_FRAME_SKIP: { - int n = std::min(bytes_remaining, buf_len); - int rv = ProcessDataFrames( - process_handshake_buf_->data(), n, buf->data(), buf_len); - process_handshake_buf_->DidConsume(n); - if (rv == 0) { - // ProcessDataFrames may return zero for non-empty buffer if it - // contains only frame delimiters without real data. In this case: - // try again and do not just return zero (zero stands for EOF). - continue; - } - return rv; - } - case PHASE_SHUT: { - return 0; - } - case PHASE_NYMPH: - case PHASE_HANDSHAKE: - default: { - NOTREACHED(); - return net::ERR_UNEXPECTED; - } - } - } - switch (phase_) { - case PHASE_FRAME_OUTSIDE: - case PHASE_FRAME_INSIDE: - case PHASE_FRAME_LENGTH: - case PHASE_FRAME_SKIP: { - pending_reqs_.push_back(PendingReq( - PendingReq::TYPE_READ, - new net::DrainableIOBuffer(buf, buf_len), - callback)); - ConsiderTransportRead(); - break; - } - case PHASE_SHUT: { - return 0; - } - case PHASE_NYMPH: - case PHASE_HANDSHAKE: - default: { - NOTREACHED(); - return net::ERR_UNEXPECTED; - } - } - return net::ERR_IO_PENDING; - } - virtual int Read(net::IOBuffer* buf, int buf_len, const net::CompletionCallback& callback) OVERRIDE { if (buf_len == 0) return 0; @@ -351,7 +259,7 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { } virtual int Write(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) OVERRIDE { + const net::CompletionCallback& callback) OVERRIDE { if (buf_len == 0) return 0; if (buf == NULL || buf_len < 0) { @@ -382,7 +290,7 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { frame_start->data()[0] = '\x00'; pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, new net::DrainableIOBuffer(frame_start, 1), - NULL)); + net::CompletionCallback())); pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE, new net::DrainableIOBuffer(buf, buf_len), @@ -392,7 +300,7 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { frame_end->data()[0] = '\xff'; pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, new net::DrainableIOBuffer(frame_end, 1), - NULL)); + net::CompletionCallback())); ConsiderTransportWrite(); return net::ERR_IO_PENDING; @@ -411,8 +319,14 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { if (phase_ != PHASE_NYMPH) return net::ERR_UNEXPECTED; phase_ = PHASE_HANDSHAKE; + net::CompletionCallback cb; + if (callback) { + cb = base::Bind(&net::OldCompletionCallback::Run<int>, + base::Unretained(callback)); + } pending_reqs_.push_front(PendingReq( - PendingReq::TYPE_READ_METADATA, fill_handshake_buf_.get(), callback)); + PendingReq::TYPE_READ_METADATA, fill_handshake_buf_.get(), + cb)); ConsiderTransportRead(); return net::ERR_IO_PENDING; } @@ -441,7 +355,8 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { is_transport_read_pending_ = true; int rv = transport_socket_->Read( it->io_buf.get(), it->io_buf->BytesRemaining(), - transport_read_callback_.get()); + base::Bind(&WebSocketServerSocketImpl::OnRead, + base::Unretained(this))); if (rv != net::ERR_IO_PENDING) { // PostTask rather than direct call in order to: // (1) guarantee calling callback after returning from Read(); @@ -468,7 +383,8 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { is_transport_write_pending_ = true; int rv = transport_socket_->Write( it->io_buf.get(), it->io_buf->BytesRemaining(), - transport_write_callback_.get()); + base::Bind(&WebSocketServerSocketImpl::OnWrite, + base::Unretained(this))); if (rv != net::ERR_IO_PENDING) { // PostTask rather than direct call in order to: // (1) guarantee calling callback after returning from Read(); @@ -485,9 +401,7 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { if (result != 0) { while (!pending_reqs_.empty()) { PendingReq& req = pending_reqs_.front(); - if (req.old_callback) - req.old_callback->Run(result); - else if (!req.callback.is_null()) + if (!req.callback.is_null()) req.callback.Run(result); pending_reqs_.pop_front(); } @@ -537,14 +451,10 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { if (rv > 0) { process_handshake_buf_->DidConsume(rv); phase_ = PHASE_FRAME_OUTSIDE; - net::OldCompletionCallback* old_cb = - pending_reqs_.front().old_callback; net::CompletionCallback cb = pending_reqs_.front().callback; pending_reqs_.pop_front(); ConsiderTransportWrite(); // Schedule answer handshake. - if (old_cb) - old_cb->Run(0); - else if (!cb.is_null()) + if (!cb.is_null()) cb.Run(0); } else if (rv == net::ERR_IO_PENDING) { if (fill_handshake_buf_->BytesRemaining() < 1) @@ -568,12 +478,9 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { return; } if (rv > 0 || phase_ == PHASE_SHUT) { - net::OldCompletionCallback* old_cb = it->old_callback; net::CompletionCallback cb = it->callback; pending_reqs_.erase(it); - if (old_cb) - old_cb->Run(rv); - else if (!cb.is_null()) + if (!cb.is_null()) cb.Run(rv); } break; @@ -612,14 +519,11 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { DCHECK_LE(result, it->io_buf->BytesRemaining()); it->io_buf->DidConsume(result); if (it->io_buf->BytesRemaining() == 0) { - net::OldCompletionCallback* old_cb = it->old_callback; net::CompletionCallback cb = it->callback; int bytes_written = it->io_buf->BytesConsumed(); DCHECK_GT(bytes_written, 0); pending_reqs_.erase(it); - if (old_cb) - old_cb->Run(bytes_written); - else if (!cb.is_null()) + if (!cb.is_null()) cb.Run(bytes_written); } ConsiderTransportWrite(); @@ -885,7 +789,7 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { return net::ERR_LIMIT_VIOLATION; pending_reqs_.push_back(PendingReq( - PendingReq::TYPE_WRITE_METADATA, buffer, NULL)); + PendingReq::TYPE_WRITE_METADATA, buffer, net::CompletionCallback())); DCHECK_GT(term_pos - buf, 0); return term_pos - buf; } @@ -977,10 +881,6 @@ class WebSocketServerSocketImpl : public net::WebSocketServerSocket { // Pending io requests we need to complete. std::deque<PendingReq> pending_reqs_; - // Callbacks from transport to us. - scoped_ptr<net::OldCompletionCallback> transport_read_callback_; - scoped_ptr<net::OldCompletionCallback> transport_write_callback_; - // Whether transport requests are pending. bool is_transport_read_pending_; bool is_transport_write_pending_; diff --git a/net/socket/web_socket_server_socket_unittest.cc b/net/socket/web_socket_server_socket_unittest.cc index cabb4b9..d64cf19 100644 --- a/net/socket/web_socket_server_socket_unittest.cc +++ b/net/socket/web_socket_server_socket_unittest.cc @@ -79,53 +79,25 @@ class TestingTransportSocket : public net::Socket { net::DrainableIOBuffer* sample, net::DrainableIOBuffer* answer) : sample_(sample), answer_(answer), - old_final_read_callback_(NULL), - method_factory_(this) { + ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { } ~TestingTransportSocket() { - if (old_final_read_callback_) { + if (!final_read_callback_.is_null()) { MessageLoop::current()->PostTask(FROM_HERE, - method_factory_.NewRunnableMethod( - &TestingTransportSocket::DoOldReadCallback, - old_final_read_callback_, 0)); - } else if (!final_read_callback_.is_null()) { - MessageLoop::current()->PostTask( - FROM_HERE, - method_factory_.NewRunnableMethod( - &TestingTransportSocket::DoReadCallback, - final_read_callback_, 0)); + base::Bind(&TestingTransportSocket::DoReadCallback, + weak_factory_.GetWeakPtr(), + final_read_callback_, 0)); } } // Socket implementation. virtual int Read(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) { - CHECK_GT(buf_len, 0); - int remaining = sample_->BytesRemaining(); - if (remaining < 1) { - if (old_final_read_callback_ || !final_read_callback_.is_null()) - return 0; - old_final_read_callback_ = callback; - return net::ERR_IO_PENDING; - } - int lot = GetRand(1, std::min(remaining, buf_len)); - std::copy(sample_->data(), sample_->data() + lot, buf->data()); - sample_->DidConsume(lot); - if (GetRand(0, 1)) { - return lot; - } - MessageLoop::current()->PostTask(FROM_HERE, - method_factory_.NewRunnableMethod( - &TestingTransportSocket::DoOldReadCallback, callback, lot)); - return net::ERR_IO_PENDING; - } - virtual int Read(net::IOBuffer* buf, int buf_len, const net::CompletionCallback& callback) { CHECK_GT(buf_len, 0); int remaining = sample_->BytesRemaining(); if (remaining < 1) { - if (old_final_read_callback_ || !final_read_callback_.is_null()) + if (!final_read_callback_.is_null()) return 0; final_read_callback_ = callback; return net::ERR_IO_PENDING; @@ -136,14 +108,15 @@ class TestingTransportSocket : public net::Socket { if (GetRand(0, 1)) { return lot; } - MessageLoop::current()->PostTask(FROM_HERE, - method_factory_.NewRunnableMethod( - &TestingTransportSocket::DoReadCallback, callback, lot)); + MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&TestingTransportSocket::DoReadCallback, + weak_factory_.GetWeakPtr(), callback, lot)); return net::ERR_IO_PENDING; } virtual int Write(net::IOBuffer* buf, int buf_len, - net::OldCompletionCallback* callback) { + const net::CompletionCallback& callback) { CHECK_GT(buf_len, 0); int remaining = answer_->BytesRemaining(); CHECK_GE(remaining, buf_len); @@ -155,9 +128,10 @@ class TestingTransportSocket : public net::Socket { if (GetRand(0, 1)) { return lot; } - MessageLoop::current()->PostTask(FROM_HERE, - method_factory_.NewRunnableMethod( - &TestingTransportSocket::DoWriteCallback, callback, lot)); + MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&TestingTransportSocket::DoWriteCallback, + weak_factory_.GetWeakPtr(), callback, lot)); return net::ERR_IO_PENDING; } @@ -171,30 +145,22 @@ class TestingTransportSocket : public net::Socket { net::DrainableIOBuffer* answer() { return answer_.get(); } - void DoOldReadCallback(net::OldCompletionCallback* callback, int result) { - if (result == 0 && !is_closed_) { - MessageLoop::current()->PostTask(FROM_HERE, - method_factory_.NewRunnableMethod( - &TestingTransportSocket::DoOldReadCallback, callback, 0)); - } else { - if (callback) - callback->Run(result); - } - } void DoReadCallback(const net::CompletionCallback& callback, int result) { if (result == 0 && !is_closed_) { - MessageLoop::current()->PostTask(FROM_HERE, - method_factory_.NewRunnableMethod( - &TestingTransportSocket::DoReadCallback, callback, 0)); + MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind( + &TestingTransportSocket::DoReadCallback, + weak_factory_.GetWeakPtr(), callback, 0)); } else { if (!callback.is_null()) callback.Run(result); } } - void DoWriteCallback(net::OldCompletionCallback* callback, int result) { - if (callback) - callback->Run(result); + void DoWriteCallback(const net::CompletionCallback& callback, int result) { + if (!callback.is_null()) + callback.Run(result); } bool is_closed_; @@ -206,10 +172,9 @@ class TestingTransportSocket : public net::Socket { scoped_refptr<net::DrainableIOBuffer> answer_; // Final read callback to report zero (zero stands for EOF). - net::OldCompletionCallback* old_final_read_callback_; net::CompletionCallback final_read_callback_; - ScopedRunnableMethodFactory<TestingTransportSocket> method_factory_; + base::WeakPtrFactory<TestingTransportSocket> weak_factory_; }; class Validator : public net::WebSocketServerSocket::Delegate { @@ -256,9 +221,8 @@ class ReadWriteTracker { net::WebSocketServerSocket* ws, int bytes_to_read, int bytes_to_write) : ws_(ws), buf_size_(1 << 14), - accept_callback_(NewCallback(this, &ReadWriteTracker::OnAccept)), - read_callback_(NewCallback(this, &ReadWriteTracker::OnRead)), - write_callback_(NewCallback(this, &ReadWriteTracker::OnWrite)), + ALLOW_THIS_IN_INITIALIZER_LIST( + accept_callback_(this, &ReadWriteTracker::OnAccept)), read_buf_(new net::IOBuffer(buf_size_)), write_buf_(new net::IOBuffer(buf_size_)), bytes_remaining_to_read_(bytes_to_read), @@ -266,7 +230,7 @@ class ReadWriteTracker { read_initiated_(false), write_initiated_(false), got_final_zero_(false) { - int rv = ws_->Accept(accept_callback_.get()); + int rv = ws_->Accept(&accept_callback_); if (rv != net::ERR_IO_PENDING) OnAccept(rv); } @@ -295,7 +259,8 @@ class ReadWriteTracker { for (int i = 0; i < lot; ++i) write_buf_->data()[i] = ReferenceSeq( bytes_remaining_to_write_ - i - 1, kWriteSalt); - int rv = ws_->Write(write_buf_, lot, write_callback_.get()); + int rv = ws_->Write(write_buf_, lot, base::Bind(&ReadWriteTracker::OnWrite, + base::Unretained(this))); if (rv != net::ERR_IO_PENDING) OnWrite(rv); } @@ -309,7 +274,8 @@ class ReadWriteTracker { lot = GetRand(1, bytes_remaining_to_read_); lot = std::min(lot, buf_size_); } - int rv = ws_->Read(read_buf_, lot, read_callback_.get()); + int rv = ws_->Read(read_buf_, lot, base::Bind(&ReadWriteTracker::OnRead, + base::Unretained(this))); if (rv != net::ERR_IO_PENDING) OnRead(rv); } @@ -340,9 +306,7 @@ class ReadWriteTracker { private: net::WebSocketServerSocket* const ws_; int const buf_size_; - scoped_ptr<net::OldCompletionCallback> accept_callback_; - scoped_ptr<net::OldCompletionCallback> read_callback_; - scoped_ptr<net::OldCompletionCallback> write_callback_; + net::OldCompletionCallbackImpl<ReadWriteTracker> accept_callback_; scoped_refptr<net::IOBuffer> read_buf_; scoped_refptr<net::IOBuffer> write_buf_; int bytes_remaining_to_read_; |