diff options
-rw-r--r-- | net/socket/socket_test_util.cc | 100 | ||||
-rw-r--r-- | net/socket/socket_test_util.h | 30 |
2 files changed, 108 insertions, 22 deletions
diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 8b7fcbf..239a81d 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -67,7 +67,10 @@ MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses, data_(data), read_offset_(0), read_data_(true, net::ERR_UNEXPECTED), - need_read_data_(true) { + need_read_data_(true), + pending_buf_(NULL), + pending_buf_len_(0), + pending_callback_(NULL) { DCHECK(data_); data_->Reset(); } @@ -89,11 +92,77 @@ int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, if (!IsConnected()) return net::ERR_UNEXPECTED; + // If the buffer is already in use, a read is already in progress! + DCHECK(pending_buf_ == NULL); + + // Store our async IO data. + pending_buf_ = buf; + pending_buf_len_ = buf_len; + pending_callback_ = callback; + if (need_read_data_) { read_data_ = data_->GetNextRead(); + // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility + // to complete the async IO manually later (via OnReadComplete). + if (read_data_.result == ERR_IO_PENDING) { + DCHECK(callback); // We need to be using async IO in this case. + return ERR_IO_PENDING; + } need_read_data_ = false; } + + return CompleteRead(); +} + +int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!IsConnected()) + return net::ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + net::MockWriteResult write_result = data_->OnWrite(data); + + if (write_result.async) { + RunCallbackAsync(callback, write_result.result); + return net::ERR_IO_PENDING; + } + return write_result.result; +} + +void MockTCPClientSocket::OnReadComplete(const MockRead& data) { + // There must be a read pending. + DCHECK(pending_buf_); + // You can't complete a read with another ERR_IO_PENDING status code. + DCHECK_NE(ERR_IO_PENDING, data.result); + // Since we've been waiting for data, need_read_data_ should be true. + DCHECK(need_read_data_); + // In order to fire the callback, this IO needs to be marked as async. + DCHECK(data.async); + + read_data_ = data; + need_read_data_ = false; + + CompleteRead(); +} + +int MockTCPClientSocket::CompleteRead() { + DCHECK(pending_buf_); + DCHECK(pending_buf_len_ > 0); + + // Save the pending async IO data and reset our |pending_| state. + net::IOBuffer* buf = pending_buf_; + int buf_len = pending_buf_len_; + net::CompletionCallback* callback = pending_callback_; + pending_buf_ = NULL; + pending_buf_len_ = 0; + pending_callback_ = NULL; + int result = read_data_.result; + DCHECK(result != ERR_IO_PENDING); + if (read_data_.data) { if (read_data_.data_len - read_offset_ > 0) { result = std::min(buf_len, read_data_.data_len - read_offset_); @@ -107,31 +176,15 @@ int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, result = 0; // EOF } } + if (read_data_.async) { + DCHECK(callback); RunCallbackAsync(callback, result); return net::ERR_IO_PENDING; } return result; } -int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) { - DCHECK(buf); - DCHECK_GT(buf_len, 0); - - if (!IsConnected()) - return net::ERR_UNEXPECTED; - - std::string data(buf->data(), buf_len); - net::MockWriteResult write_result = data_->OnWrite(data); - - if (write_result.async) { - RunCallbackAsync(callback, write_result.result); - return net::ERR_IO_PENDING; - } - return write_result.result; -} - class MockSSLClientSocket::ConnectCallback : public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> { public: @@ -212,7 +265,10 @@ int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, } MockRead StaticSocketDataProvider::GetNextRead() { - return reads_[read_index_++]; + MockRead rv = reads_[read_index_]; + if (reads_[read_index_].data_len != 0) + read_index_++; // Don't advance past an EOF. + return rv; } MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) { @@ -298,8 +354,10 @@ MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket( ClientSocket* MockClientSocketFactory::CreateTCPClientSocket( const AddressList& addresses) { + SocketDataProvider* data_provider = mock_data_.GetNext(); MockTCPClientSocket* socket = - new MockTCPClientSocket(addresses, mock_data_.GetNext()); + new MockTCPClientSocket(addresses, data_provider); + data_provider->set_socket(socket); tcp_client_sockets_.push_back(socket); return socket; } diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 0e1380c..c369435 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -27,6 +27,7 @@ namespace net { class ClientSocket; class LoadLog; +class MockClientSocket; class SSLClientSocket; struct MockConnect { @@ -78,17 +79,27 @@ struct MockWriteResult { // for getting data about individual reads and writes on the socket. class SocketDataProvider { public: - SocketDataProvider() {} + SocketDataProvider() : socket_(NULL) {} virtual ~SocketDataProvider() {} + + // Returns the buffer and result code for the next simulated read. + // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller + // that it will be called via the MockClientSocket::OnReadComplete() + // function at a later time. virtual MockRead GetNextRead() = 0; virtual MockWriteResult OnWrite(const std::string& data) = 0; virtual void Reset() = 0; + // Accessor for the socket which is using the SocketDataProvider. + MockClientSocket* socket() { return socket_; } + void set_socket(MockClientSocket* socket) { socket_ = socket; } + MockConnect connect_data() const { return connect_; } private: MockConnect connect_; + MockClientSocket* socket_; DISALLOW_COPY_AND_ASSIGN(SocketDataProvider); }; @@ -264,6 +275,11 @@ class MockClientSocket : public net::SSLClientSocket { virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen); #endif + // If an async IO is pending because the SocketDataProvider returned + // ERR_IO_PENDING, then the MockClientSocket waits until this OnReadComplete + // is called to complete the asynchronous read operation. + virtual void OnReadComplete(const MockRead& data) = 0; + protected: void RunCallbackAsync(net::CompletionCallback* callback, int result); void RunCallback(net::CompletionCallback*, int result); @@ -287,15 +303,24 @@ class MockTCPClientSocket : public MockClientSocket { virtual int Write(net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback); + virtual void OnReadComplete(const MockRead& data); + net::AddressList addresses() const { return addresses_; } private: + int CompleteRead(); + net::AddressList addresses_; net::SocketDataProvider* data_; int read_offset_; net::MockRead read_data_; bool need_read_data_; + + // While an asynchronous IO is pending, we save our user-buffer state. + net::IOBuffer* pending_buf_; + int pending_buf_len_; + net::CompletionCallback* pending_callback_; }; class MockSSLClientSocket : public MockClientSocket { @@ -318,6 +343,9 @@ class MockSSLClientSocket : public MockClientSocket { virtual int Write(net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback); + // This MockSocket does not implement the manual async IO feature. + virtual void OnReadComplete(const MockRead& data) { NOTIMPLEMENTED(); } + private: class ConnectCallback; |