diff options
author | erikchen@google.com <erikchen@google.com@0039d316-1c4b-4281-b951-d872f2087c98> | 2010-08-18 18:33:27 +0000 |
---|---|---|
committer | erikchen@google.com <erikchen@google.com@0039d316-1c4b-4281-b951-d872f2087c98> | 2010-08-18 18:33:27 +0000 |
commit | 3b78284325c04007a09ba8b47ac2980997e39e1d (patch) | |
tree | 10c1d14f07aa4c43f133820d7cf561071838e66e /net/socket | |
parent | cce6f1b3b3d54334bf116084fadfe5b828c87227 (diff) | |
download | chromium_src-3b78284325c04007a09ba8b47ac2980997e39e1d.zip chromium_src-3b78284325c04007a09ba8b47ac2980997e39e1d.tar.gz chromium_src-3b78284325c04007a09ba8b47ac2980997e39e1d.tar.bz2 |
Added a new MockSocket that enforces ordering of reads/writes.
Re-enabled two broken SPDY tests, using new MockSocket. All SPDY tests should eventually convert to using this new object.
TEST=none
BUG=none
Review URL: http://codereview.chromium.org/3179016
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@56560 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net/socket')
-rw-r--r-- | net/socket/socket_test_util.cc | 263 | ||||
-rw-r--r-- | net/socket/socket_test_util.h | 163 |
2 files changed, 424 insertions, 2 deletions
diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 1cce947..069c293 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -312,6 +312,120 @@ int MockTCPClientSocket::CompleteRead() { return result; } +DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket( + net::NetLog* net_log, net::DeterministicSocketData* data) + : MockClientSocket(net_log), + write_pending_(false), + write_callback_(NULL), + write_result_(0), + read_data_(), + read_buf_(NULL), + read_buf_len_(0), + read_pending_(false), + read_callback_(NULL), + data_(data) {} + +void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {} + +// TODO(erikchen): Support connect sequencing. +int DeterministicMockTCPClientSocket::Connect( + net::CompletionCallback* callback) { + if (connected_) + return net::OK; + connected_ = true; + if (data_->connect_data().async) { + RunCallbackAsync(callback, data_->connect_data().result); + return net::ERR_IO_PENDING; + } + return data_->connect_data().result; +} + +void DeterministicMockTCPClientSocket::Disconnect() { + MockClientSocket::Disconnect(); +} + +bool DeterministicMockTCPClientSocket::IsConnected() const { + return connected_; +} + +int DeterministicMockTCPClientSocket::Write( + net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!connected_) + return net::ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + net::MockWriteResult write_result = data_->OnWrite(data); + + if (write_result.async) { + write_callback_ = callback; + write_result_ = write_result.result; + DCHECK(write_callback_ != NULL); + write_pending_ = true; + return net::ERR_IO_PENDING; + } + write_pending_ = false; + return write_result.result; +} + +int DeterministicMockTCPClientSocket::Read( + net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) { + if (!connected_) + return net::ERR_UNEXPECTED; + + read_data_ = data_->GetNextRead(); + // The buffer should always be big enough to contain all the MockRead data. To + // use small buffers, split the data into multiple MockReads. + DCHECK_LE(read_data_.data_len, buf_len); + + read_buf_ = buf; + read_buf_len_ = buf_len; + read_callback_ = callback; + + if (read_data_.async || (read_data_.result == ERR_IO_PENDING)) { + read_pending_ = true; + DCHECK(read_callback_); + return ERR_IO_PENDING; + } + + return CompleteRead(); +} + +void DeterministicMockTCPClientSocket::CompleteWrite(){ + write_pending_ = false; + write_callback_->Run(write_result_); +} + +int DeterministicMockTCPClientSocket::CompleteRead() { + DCHECK_GT(read_buf_len_, 0); + DCHECK_LE(read_data_.data_len, read_buf_len_); + DCHECK(read_buf_); + + if (read_data_.result == ERR_IO_PENDING) + read_data_ = data_->GetNextRead(); + DCHECK_NE(ERR_IO_PENDING, read_data_.result); + // If read_data_.async is true, we do not need to wait, since this is already + // the callback. Therefore we don't even bother to check it. + int result = read_data_.result; + + if (read_data_.data_len > 0) { + DCHECK(read_data_.data); + result = std::min(read_buf_len_, read_data_.data_len); + memcpy(read_buf_->data(), read_data_.data, result); + } else { + result = 0; + } + + if (read_pending_) { + read_pending_ = false; + read_callback_->Run(result); + } + + return result; +} + class MockSSLClientSocket::ConnectCallback : public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> { public: @@ -650,6 +764,113 @@ void OrderedSocketData::CompleteRead() { } } +DeterministicSocketData::DeterministicSocketData(MockRead* reads, + size_t reads_count, MockWrite* writes, size_t writes_count) + : StaticSocketDataProvider(reads, reads_count, writes, writes_count), + sequence_number_(0), + current_read_(), + current_write_(), + next_read_seq_(0), + stopping_sequence_number_(1<<31), + stopped_(false), + print_debug_(false) {} + +MockRead DeterministicSocketData::GetNextRead() { + const MockRead& next_read = StaticSocketDataProvider::PeekRead(); + EXPECT_LE(sequence_number_, next_read.sequence_number); + current_read_ = next_read; + next_read_seq_ = current_read_.sequence_number; + if (sequence_number_ >= stopping_sequence_number_) { + SetStopped(true); + NET_TRACE(INFO, " *** ") << "Force Stop. I/O Pending on read. Stage " + << sequence_number_; + MockRead result = MockRead(false, ERR_IO_PENDING); + if (print_debug_) + DumpMockRead(result); + return result; + } + if (sequence_number_ < next_read.sequence_number) { + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ + << ": I/O Pending"; + MockRead result = MockRead(false, ERR_IO_PENDING); + if (print_debug_) + DumpMockRead(result); + return result; + } + + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ + << ": Read " << read_index(); + if (print_debug_) + DumpMockRead(next_read); + sequence_number_++; + StaticSocketDataProvider::GetNextRead(); + if (current_read_.result == ERR_IO_PENDING) + current_read_ = StaticSocketDataProvider::GetNextRead(); + + if (!at_read_eof()) + next_read_seq_ = PeekRead().sequence_number; + + return next_read; +} + +MockWriteResult DeterministicSocketData::OnWrite(const std::string& data) { + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ + << ": Write " << write_index(); + const MockWrite& next_write = StaticSocketDataProvider::PeekWrite(); + DCHECK_LE(next_write.sequence_number, sequence_number_); + if (print_debug_) + DumpMockRead(next_write); + ++sequence_number_; + current_write_ = next_write; + return StaticSocketDataProvider::OnWrite(data); +} + +void DeterministicSocketData::Reset(){ + NET_TRACE(INFO, " *** ") << "Stage " + << sequence_number_ << ": Reset()"; + sequence_number_ = 0; + StaticSocketDataProvider::Reset(); + NOTREACHED(); +} + +void DeterministicSocketData::Run(){ + int counter = 0; + // Continue to consume data until all data has run out, or the stopped_ flag + // has been set. Consuming data requires two separate operations -- running + // the tasks in the message loop, and explicitly invoking the read/write + // callbacks (simulating network I/O). We check our conditions between each, + // since they can change in either. + while ((!at_write_eof() || !at_read_eof()) && + !stopped()) { + if (counter % 2 == 0) + MessageLoop::current()->RunAllPending(); + if (counter % 2 == 1) + InvokeCallbacks(); + counter++; + } + // We're done consuming new data, but it is possible there are still some + // pending callbacks which we expect to complete before returning. + while (socket_ && (socket_->write_pending() || socket_->read_pending()) && + !stopped()) { + InvokeCallbacks(); + MessageLoop::current()->RunAllPending(); + } + SetStopped(false); +} + +void DeterministicSocketData::InvokeCallbacks(){ + if (socket_ && socket_->write_pending() && + (current_write().sequence_number == sequence_number())) { + socket_->CompleteWrite(); + return; + } + if (socket_ && socket_->read_pending() && + (next_read_seq() == sequence_number())) { + socket_->CompleteRead(); + return; + } +} + void MockClientSocketFactory::AddSocketDataProvider( SocketDataProvider* data) { mock_data_.Add(data); @@ -698,6 +919,48 @@ SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( return socket; } +void DeterministicMockClientSocketFactory::AddSocketDataProvider( + DeterministicSocketData* data) { + mock_data_.Add(data); +} + +void DeterministicMockClientSocketFactory::AddSSLSocketDataProvider( + SSLSocketDataProvider* data) { + mock_ssl_data_.Add(data); +} + +void DeterministicMockClientSocketFactory::ResetNextMockIndexes() { + mock_data_.ResetNextIndex(); + mock_ssl_data_.ResetNextIndex(); +} + +MockSSLClientSocket* DeterministicMockClientSocketFactory:: + GetMockSSLClientSocket(size_t index) const { + DCHECK_LT(index, ssl_client_sockets_.size()); + return ssl_client_sockets_[index]; +} + +ClientSocket* DeterministicMockClientSocketFactory::CreateTCPClientSocket( + const AddressList& addresses, net::NetLog* net_log) { + DeterministicSocketData* data_provider = mock_data().GetNext(); + DeterministicMockTCPClientSocket* socket = + new DeterministicMockTCPClientSocket(net_log, data_provider); + data_provider->set_socket(socket->AsWeakPtr()); + tcp_client_sockets().push_back(socket); + return socket; +} + +SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket( + ClientSocketHandle* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config) { + MockSSLClientSocket* socket = + new MockSSLClientSocket(transport_socket, hostname, ssl_config, + mock_ssl_data_.GetNext()); + ssl_client_sockets_.push_back(socket); + return socket; +} + int TestSocketRequest::WaitForResult() { return callback_.WaitForResult(); } diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 9367349..0904dbe 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -17,6 +17,7 @@ #include "base/scoped_ptr.h" #include "base/scoped_vector.h" #include "base/string16.h" +#include "base/weak_ptr.h" #include "net/base/address_list.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" @@ -168,6 +169,7 @@ class StaticSocketDataProvider : public SocketDataProvider { write_index_(0), write_count_(writes_count) { } + virtual ~StaticSocketDataProvider() {} // SocketDataProvider methods: virtual MockRead GetNextRead(); @@ -343,6 +345,89 @@ class OrderedSocketData : public StaticSocketDataProvider, ScopedRunnableMethodFactory<OrderedSocketData> factory_; }; +class DeterministicMockTCPClientSocket; + +// This class gives the user full control over the mock socket reads and writes, +// including the timing of the callbacks. By default, synchronous reads and +// writes will force the callback for that read or write to complete before +// allowing another read or write to finish. +// +// Sequence numbers are preserved across both reads and writes. There should be +// no gaps in sequence numbers, and no repeated sequence numbers. i.e. +// MockWrite writes[] = { +// MockWrite(true, "first write", length, 0), +// MockWrite(false, "second write", length, 3), +// }; +// +// MockRead reads[] = { +// MockRead(false, "first read", length, 1) +// MockRead(false, "second read", length, 2) +// }; +// Example control flow: +// The first write completes. A call to read() returns ERR_IO_PENDING, since the +// first write's callback has not happened yet. The first write's callback is +// called. Now the first read's callback will be called. A call to write() will +// succeed, because the write() API requires this, but the callback will not be +// called until the second read has completed and its callback called. +class DeterministicSocketData : public StaticSocketDataProvider, + public base::RefCounted<DeterministicSocketData> { + public: + // |reads| the list of MockRead completions. + // |writes| the list of MockWrite completions. + DeterministicSocketData(MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count); + + // |connect| the result for the connect phase. + // |reads| the list of MockRead completions. + // |writes| the list of MockWrite completions. + DeterministicSocketData(const MockConnect& connect, + MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count); + + // When the socket calls Read(), that calls GetNextRead(), and expects either + // ERR_IO_PENDING or data. + virtual MockRead GetNextRead(); + + // When the socket calls Write(), it always completes synchronously. OnWrite() + // checks to make sure the written data matches the expected data. The + // callback will not be invoked until its sequence number is reached. + virtual MockWriteResult OnWrite(const std::string& data); + + virtual void Reset(); + + // Consume all the data up to the give stop point (via SetStop()). + void Run(); + + // Stop when Read() is about to consume a MockRead with sequence_number >= + // seq. Instead feed ERR_IO_PENDING to Read(). + virtual void SetStop(int seq) { stopping_sequence_number_ = seq; } + + void CompleteRead(); + bool stopped() const { return stopped_; } + void SetStopped(bool val) { stopped_ = val; } + MockRead& current_read() { return current_read_; } + MockRead& current_write() { return current_write_; } + int next_read_seq() const { return next_read_seq_; } + int sequence_number() const { return sequence_number_; } + void set_socket(base::WeakPtr<DeterministicMockTCPClientSocket> socket) { + socket_ = socket; + } + + private: + // Invoke the read and write callbacks, if the timing is appropriate. + void InvokeCallbacks(); + + int sequence_number_; + MockRead current_read_; + MockWrite current_write_; + int next_read_seq_; + int stopping_sequence_number_; + bool stopped_; + base::WeakPtr<DeterministicMockTCPClientSocket> socket_; + bool print_debug_; +}; + + // Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}ClientSocket // objects get instantiated, they take their data from the i'th element of this // array. @@ -404,6 +489,12 @@ class MockClientSocketFactory : public ClientSocketFactory { ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config); + SocketDataProviderArray<SocketDataProvider>& mock_data() { + return mock_data_; + } + std::vector<MockTCPClientSocket*>& tcp_client_sockets() { + return tcp_client_sockets_; + } private: SocketDataProviderArray<SocketDataProvider> mock_data_; @@ -417,7 +508,6 @@ class MockClientSocketFactory : public ClientSocketFactory { class MockClientSocket : public net::SSLClientSocket { public: explicit MockClientSocket(net::NetLog* net_log); - // ClientSocket methods: virtual int Connect(net::CompletionCallback* callback) = 0; virtual void Disconnect(); @@ -448,6 +538,7 @@ class MockClientSocket : public net::SSLClientSocket { virtual void OnReadComplete(const MockRead& data) = 0; protected: + virtual ~MockClientSocket() {} void RunCallbackAsync(net::CompletionCallback* callback, int result); void RunCallback(net::CompletionCallback*, int result); @@ -501,6 +592,40 @@ class MockTCPClientSocket : public MockClientSocket { net::CompletionCallback* pending_callback_; }; +class DeterministicMockTCPClientSocket : public MockClientSocket, + public base::SupportsWeakPtr<DeterministicMockTCPClientSocket> { + public: + DeterministicMockTCPClientSocket(net::NetLog* net_log, + net::DeterministicSocketData* data); + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + virtual void CompleteWrite(); + virtual int CompleteRead(); + virtual void OnReadComplete(const MockRead& data); + + virtual int Connect(net::CompletionCallback* callback); + virtual void Disconnect(); + virtual bool IsConnected() const; + virtual bool IsConnectedAndIdle() const { return IsConnected(); } + bool write_pending() { return write_pending_; } + bool read_pending() { return read_pending_; } + + private: + bool write_pending_; + net::CompletionCallback* write_callback_; + int write_result_; + + net::MockRead read_data_; + + net::IOBuffer* read_buf_; + int read_buf_len_; + bool read_pending_; + net::CompletionCallback* read_callback_; + net::DeterministicSocketData* data_; +}; + class MockSSLClientSocket : public MockClientSocket { public: MockSSLClientSocket( @@ -659,13 +784,47 @@ class MockTCPClientSocketPool : public TCPClientSocketPool { private: ClientSocketFactory* client_socket_factory_; + ScopedVector<MockConnectJob> job_list_; int release_count_; int cancel_count_; - ScopedVector<MockConnectJob> job_list_; DISALLOW_COPY_AND_ASSIGN(MockTCPClientSocketPool); }; +class DeterministicMockClientSocketFactory : public ClientSocketFactory { + public: + void AddSocketDataProvider(DeterministicSocketData* socket); + void AddSSLSocketDataProvider(SSLSocketDataProvider* socket); + void ResetNextMockIndexes(); + + // Return |index|-th MockSSLClientSocket (starting from 0) that the factory + // created. + MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const; + + // ClientSocketFactory + virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses, + NetLog* net_log); + virtual SSLClientSocket* CreateSSLClientSocket( + ClientSocketHandle* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config); + + SocketDataProviderArray<DeterministicSocketData>& mock_data() { + return mock_data_; + } + std::vector<DeterministicMockTCPClientSocket*>& tcp_client_sockets() { + return tcp_client_sockets_; + } + + private: + SocketDataProviderArray<DeterministicSocketData> mock_data_; + SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_; + + // Store pointers to handed out sockets in case the test wants to get them. + std::vector<DeterministicMockTCPClientSocket*> tcp_client_sockets_; + std::vector<MockSSLClientSocket*> ssl_client_sockets_; +}; + class MockSOCKSClientSocketPool : public SOCKSClientSocketPool { public: MockSOCKSClientSocketPool( |