diff options
Diffstat (limited to 'net/socket/socket_test_util.h')
-rw-r--r-- | net/socket/socket_test_util.h | 332 |
1 files changed, 266 insertions, 66 deletions
diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 2bd0d23..7f83a70 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -4,6 +4,7 @@ #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_ #define NET_SOCKET_SOCKET_TEST_UTIL_H_ +#pragma once #include <cstring> #include <deque> @@ -15,6 +16,8 @@ #include "base/logging.h" #include "base/scoped_ptr.h" #include "base/scoped_vector.h" +#include "base/string16.h" +#include "base/weak_ptr.h" #include "net/base/address_list.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" @@ -27,6 +30,7 @@ #include "net/socket/client_socket_handle.h" #include "net/socket/socks_client_socket_pool.h" #include "net/socket/ssl_client_socket.h" +#include "net/socket/ssl_client_socket_pool.h" #include "net/socket/tcp_client_socket_pool.h" #include "testing/gtest/include/gtest/gtest.h" @@ -42,8 +46,6 @@ enum { }; class ClientSocket; -class HttpRequestHeaders; -class HttpResponseHeaders; class MockClientSocket; class SSLClientSocket; @@ -166,6 +168,7 @@ class StaticSocketDataProvider : public SocketDataProvider { write_index_(0), write_count_(writes_count) { } + virtual ~StaticSocketDataProvider() {} // SocketDataProvider methods: virtual MockRead GetNextRead(); @@ -277,11 +280,13 @@ class DelayedSocketData : public StaticSocketDataProvider, DelayedSocketData(const MockConnect& connect, int write_delay, MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count); + ~DelayedSocketData(); virtual MockRead GetNextRead(); virtual MockWriteResult OnWrite(const std::string& data); virtual void Reset(); void CompleteRead(); + void ForceNextRead(); private: int write_delay_; @@ -340,6 +345,89 @@ class OrderedSocketData : public StaticSocketDataProvider, ScopedRunnableMethodFactory<OrderedSocketData> factory_; }; +class DeterministicMockTCPClientSocket; + +// This class gives the user full control over the mock socket reads and writes, +// including the timing of the callbacks. By default, synchronous reads and +// writes will force the callback for that read or write to complete before +// allowing another read or write to finish. +// +// Sequence numbers are preserved across both reads and writes. There should be +// no gaps in sequence numbers, and no repeated sequence numbers. i.e. +// MockWrite writes[] = { +// MockWrite(true, "first write", length, 0), +// MockWrite(false, "second write", length, 3), +// }; +// +// MockRead reads[] = { +// MockRead(false, "first read", length, 1) +// MockRead(false, "second read", length, 2) +// }; +// Example control flow: +// The first write completes. A call to read() returns ERR_IO_PENDING, since the +// first write's callback has not happened yet. The first write's callback is +// called. Now the first read's callback will be called. A call to write() will +// succeed, because the write() API requires this, but the callback will not be +// called until the second read has completed and its callback called. +class DeterministicSocketData : public StaticSocketDataProvider, + public base::RefCounted<DeterministicSocketData> { + public: + // |reads| the list of MockRead completions. + // |writes| the list of MockWrite completions. + DeterministicSocketData(MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count); + + // |connect| the result for the connect phase. + // |reads| the list of MockRead completions. + // |writes| the list of MockWrite completions. + DeterministicSocketData(const MockConnect& connect, + MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count); + + // When the socket calls Read(), that calls GetNextRead(), and expects either + // ERR_IO_PENDING or data. + virtual MockRead GetNextRead(); + + // When the socket calls Write(), it always completes synchronously. OnWrite() + // checks to make sure the written data matches the expected data. The + // callback will not be invoked until its sequence number is reached. + virtual MockWriteResult OnWrite(const std::string& data); + + virtual void Reset(); + + // Consume all the data up to the give stop point (via SetStop()). + void Run(); + + // Stop when Read() is about to consume a MockRead with sequence_number >= + // seq. Instead feed ERR_IO_PENDING to Read(). + virtual void SetStop(int seq) { stopping_sequence_number_ = seq; } + + void CompleteRead(); + bool stopped() const { return stopped_; } + void SetStopped(bool val) { stopped_ = val; } + MockRead& current_read() { return current_read_; } + MockRead& current_write() { return current_write_; } + int next_read_seq() const { return next_read_seq_; } + int sequence_number() const { return sequence_number_; } + void set_socket(base::WeakPtr<DeterministicMockTCPClientSocket> socket) { + socket_ = socket; + } + + private: + // Invoke the read and write callbacks, if the timing is appropriate. + void InvokeCallbacks(); + + int sequence_number_; + MockRead current_read_; + MockWrite current_write_; + int next_read_seq_; + int stopping_sequence_number_; + bool stopped_; + base::WeakPtr<DeterministicMockTCPClientSocket> socket_; + bool print_debug_; +}; + + // Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}ClientSocket // objects get instantiated, they take their data from the i'th element of this // array. @@ -395,12 +483,20 @@ class MockClientSocketFactory : public ClientSocketFactory { MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const; // ClientSocketFactory - virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses, - NetLog* net_log); + virtual ClientSocket* CreateTCPClientSocket( + const AddressList& addresses, + NetLog* net_log, + const NetLog::Source& source); virtual SSLClientSocket* CreateSSLClientSocket( ClientSocketHandle* transport_socket, const std::string& hostname, const SSLConfig& ssl_config); + SocketDataProviderArray<SocketDataProvider>& mock_data() { + return mock_data_; + } + std::vector<MockTCPClientSocket*>& tcp_client_sockets() { + return tcp_client_sockets_; + } private: SocketDataProviderArray<SocketDataProvider> mock_data_; @@ -414,7 +510,6 @@ class MockClientSocketFactory : public ClientSocketFactory { class MockClientSocket : public net::SSLClientSocket { public: explicit MockClientSocket(net::NetLog* net_log); - // ClientSocket methods: virtual int Connect(net::CompletionCallback* callback) = 0; virtual void Disconnect(); @@ -422,6 +517,8 @@ class MockClientSocket : public net::SSLClientSocket { virtual bool IsConnectedAndIdle() const; virtual int GetPeerAddress(AddressList* address) const; virtual const BoundNetLog& NetLog() const { return net_log_;} + virtual void SetSubresourceSpeculation() {} + virtual void SetOmniboxSpeculation() {} // SSLClientSocket methods: virtual void GetSSLInfo(net::SSLInfo* ssl_info); @@ -445,6 +542,7 @@ class MockClientSocket : public net::SSLClientSocket { virtual void OnReadComplete(const MockRead& data) = 0; protected: + virtual ~MockClientSocket() {} void RunCallbackAsync(net::CompletionCallback* callback, int result); void RunCallback(net::CompletionCallback*, int result); @@ -466,6 +564,7 @@ class MockTCPClientSocket : public MockClientSocket { virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const { return IsConnected(); } + virtual bool WasEverUsed() const { return was_used_to_convey_data_; } // Socket methods: virtual int Read(net::IOBuffer* buf, int buf_len, @@ -496,6 +595,48 @@ class MockTCPClientSocket : public MockClientSocket { net::IOBuffer* pending_buf_; int pending_buf_len_; net::CompletionCallback* pending_callback_; + bool was_used_to_convey_data_; +}; + +class DeterministicMockTCPClientSocket : public MockClientSocket, + public base::SupportsWeakPtr<DeterministicMockTCPClientSocket> { + public: + DeterministicMockTCPClientSocket(net::NetLog* net_log, + net::DeterministicSocketData* data); + + // ClientSocket methods: + virtual int Connect(net::CompletionCallback* callback); + virtual void Disconnect(); + virtual bool IsConnected() const; + virtual bool IsConnectedAndIdle() const { return IsConnected(); } + virtual bool WasEverUsed() const { return was_used_to_convey_data_; } + + // Socket methods: + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + + bool write_pending() const { return write_pending_; } + bool read_pending() const { return read_pending_; } + + void CompleteWrite(); + int CompleteRead(); + void OnReadComplete(const MockRead& data); + + private: + bool write_pending_; + net::CompletionCallback* write_callback_; + int write_result_; + + net::MockRead read_data_; + + net::IOBuffer* read_buf_; + int read_buf_len_; + bool read_pending_; + net::CompletionCallback* read_callback_; + net::DeterministicSocketData* data_; + bool was_used_to_convey_data_; }; class MockSSLClientSocket : public MockClientSocket { @@ -510,6 +651,8 @@ class MockSSLClientSocket : public MockClientSocket { // ClientSocket methods: virtual int Connect(net::CompletionCallback* callback); virtual void Disconnect(); + virtual bool IsConnected() const; + virtual bool WasEverUsed() const; // Socket methods: virtual int Read(net::IOBuffer* buf, int buf_len, @@ -520,8 +663,8 @@ class MockSSLClientSocket : public MockClientSocket { // SSLClientSocket methods: virtual void GetSSLInfo(net::SSLInfo* ssl_info); virtual NextProtoStatus GetNextProto(std::string* proto); - virtual bool wasNpnNegotiated() const; - virtual bool setWasNpnNegotiated(bool negotiated); + virtual bool was_npn_negotiated() const; + virtual bool set_was_npn_negotiated(bool negotiated); // This MockSocket does not implement the manual async IO feature. virtual void OnReadComplete(const MockRead& data) { NOTIMPLEMENTED(); } @@ -533,6 +676,7 @@ class MockSSLClientSocket : public MockClientSocket { net::SSLSocketDataProvider* data_; bool is_npn_state_set_; bool new_npn_value_; + bool was_used_to_convey_data_; }; class TestSocketRequest : public CallbackRunner< Tuple1<int> > { @@ -558,8 +702,8 @@ class TestSocketRequest : public CallbackRunner< Tuple1<int> > { TestCompletionCallback callback_; }; -class ClientSocketPoolTest : public testing::Test { - protected: +class ClientSocketPoolTest { + public: enum KeepAlive { KEEP_ALIVE, @@ -570,15 +714,15 @@ class ClientSocketPoolTest : public testing::Test { static const int kIndexOutOfBounds; static const int kRequestNotFound; - virtual void SetUp(); - virtual void TearDown(); + ClientSocketPoolTest(); + ~ClientSocketPoolTest(); template <typename PoolType, typename SocketParams> - int StartRequestUsingPool(const scoped_refptr<PoolType>& socket_pool, + int StartRequestUsingPool(PoolType* socket_pool, const std::string& group_name, RequestPriority priority, const scoped_refptr<SocketParams>& socket_params) { - DCHECK(socket_pool.get()); + DCHECK(socket_pool); TestSocketRequest* request = new TestSocketRequest(&request_order_, &completion_count_); requests_.push_back(request); @@ -594,7 +738,7 @@ class ClientSocketPoolTest : public testing::Test { // and returns order in which that request completed, in range 1..n, // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound // if that request did not complete (for example was canceled). - int GetOrderOfRequest(size_t index); + int GetOrderOfRequest(size_t index) const; // Resets first initialized socket handle from |requests_|. If found such // a handle, returns true. @@ -603,6 +747,12 @@ class ClientSocketPoolTest : public testing::Test { // Releases connections until there is nothing to release. void ReleaseAllConnections(KeepAlive keep_alive); + TestSocketRequest* request(int i) { return requests_[i]; } + size_t requests_size() const { return requests_.size(); } + ScopedVector<TestSocketRequest>* requests() { return &requests_; } + size_t completion_count() const { return completion_count_; } + + private: ScopedVector<TestSocketRequest> requests_; std::vector<TestSocketRequest*> request_order_; size_t completion_count_; @@ -632,11 +782,13 @@ class MockTCPClientSocketPool : public TCPClientSocketPool { MockTCPClientSocketPool( int max_sockets, int max_sockets_per_group, - const scoped_refptr<ClientSocketPoolHistograms>& histograms, + ClientSocketPoolHistograms* histograms, ClientSocketFactory* socket_factory); - int release_count() const { return release_count_; }; - int cancel_count() const { return cancel_count_; }; + virtual ~MockTCPClientSocketPool(); + + int release_count() const { return release_count_; } + int cancel_count() const { return cancel_count_; } // TCPClientSocketPool methods. virtual int RequestSocket(const std::string& group_name, @@ -651,25 +803,59 @@ class MockTCPClientSocketPool : public TCPClientSocketPool { virtual void ReleaseSocket(const std::string& group_name, ClientSocket* socket, int id); - protected: - virtual ~MockTCPClientSocketPool(); - private: ClientSocketFactory* client_socket_factory_; + ScopedVector<MockConnectJob> job_list_; int release_count_; int cancel_count_; - ScopedVector<MockConnectJob> job_list_; DISALLOW_COPY_AND_ASSIGN(MockTCPClientSocketPool); }; +class DeterministicMockClientSocketFactory : public ClientSocketFactory { + public: + void AddSocketDataProvider(DeterministicSocketData* socket); + void AddSSLSocketDataProvider(SSLSocketDataProvider* socket); + void ResetNextMockIndexes(); + + // Return |index|-th MockSSLClientSocket (starting from 0) that the factory + // created. + MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const; + + // ClientSocketFactory + virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses, + NetLog* net_log, + const NetLog::Source& source); + virtual SSLClientSocket* CreateSSLClientSocket( + ClientSocketHandle* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config); + + SocketDataProviderArray<DeterministicSocketData>& mock_data() { + return mock_data_; + } + std::vector<DeterministicMockTCPClientSocket*>& tcp_client_sockets() { + return tcp_client_sockets_; + } + + private: + SocketDataProviderArray<DeterministicSocketData> mock_data_; + SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_; + + // Store pointers to handed out sockets in case the test wants to get them. + std::vector<DeterministicMockTCPClientSocket*> tcp_client_sockets_; + std::vector<MockSSLClientSocket*> ssl_client_sockets_; +}; + class MockSOCKSClientSocketPool : public SOCKSClientSocketPool { public: MockSOCKSClientSocketPool( int max_sockets, int max_sockets_per_group, - const scoped_refptr<ClientSocketPoolHistograms>& histograms, - const scoped_refptr<TCPClientSocketPool>& tcp_pool); + ClientSocketPoolHistograms* histograms, + TCPClientSocketPool* tcp_pool); + + virtual ~MockSOCKSClientSocketPool(); // SOCKSClientSocketPool methods. virtual int RequestSocket(const std::string& group_name, @@ -684,54 +870,12 @@ class MockSOCKSClientSocketPool : public SOCKSClientSocketPool { virtual void ReleaseSocket(const std::string& group_name, ClientSocket* socket, int id); - protected: - virtual ~MockSOCKSClientSocketPool(); - private: - const scoped_refptr<TCPClientSocketPool> tcp_pool_; + TCPClientSocketPool* const tcp_pool_; DISALLOW_COPY_AND_ASSIGN(MockSOCKSClientSocketPool); }; -struct MockHttpAuthControllerData { - MockHttpAuthControllerData(std::string header) : auth_header(header) {} - - std::string auth_header; -}; - -class MockHttpAuthController : public HttpAuthController { - public: - MockHttpAuthController(); - void SetMockAuthControllerData(struct MockHttpAuthControllerData* data, - size_t data_length); - - // HttpAuthController methods. - virtual int MaybeGenerateAuthToken(const HttpRequestInfo* request, - CompletionCallback* callback, - const BoundNetLog& net_log); - virtual void AddAuthorizationHeader( - HttpRequestHeaders* authorization_headers); - virtual int HandleAuthChallenge(scoped_refptr<HttpResponseHeaders> headers, - bool do_not_send_server_auth, - bool establishing_tunnel, - const BoundNetLog& net_log); - virtual void ResetAuth(const std::wstring& username, - const std::wstring& password); - virtual bool HaveAuthHandler() const; - virtual bool HaveAuth() const; - - private: - virtual ~MockHttpAuthController() {} - const struct MockHttpAuthControllerData& CurrentData() const { - DCHECK(data_index_ < data_count_); - return data_[data_index_]; - } - - MockHttpAuthControllerData* data_; - size_t data_index_; - size_t data_count_; -}; - // Constants for a successful SOCKS v5 handshake. extern const char kSOCKS5GreetRequest[]; extern const int kSOCKS5GreetRequestLength; @@ -745,6 +889,62 @@ extern const int kSOCKS5OkRequestLength; extern const char kSOCKS5OkResponse[]; extern const int kSOCKS5OkResponseLength; +class MockSSLClientSocketPool : public SSLClientSocketPool { + public: + class MockConnectJob { + public: + MockConnectJob(ClientSocket* socket, ClientSocketHandle* handle, + CompletionCallback* callback); + + int Connect(); + bool CancelHandle(const ClientSocketHandle* handle); + + private: + void OnConnect(int rv); + + scoped_ptr<ClientSocket> socket_; + ClientSocketHandle* handle_; + CompletionCallback* user_callback_; + CompletionCallbackImpl<MockConnectJob> connect_callback_; + + DISALLOW_COPY_AND_ASSIGN(MockConnectJob); + }; + + MockSSLClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + ClientSocketFactory* socket_factory, + TCPClientSocketPool* tcp_pool); + + virtual ~MockSSLClientSocketPool(); + + int release_count() const { return release_count_; } + int cancel_count() const { return cancel_count_; } + + // SSLClientSocketPool methods. + virtual int RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + CompletionCallback* callback, + const BoundNetLog& net_log); + + virtual void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle); + virtual void ReleaseSocket(const std::string& group_name, + ClientSocket* socket, int id); + + private: + ClientSocketFactory* client_socket_factory_; + int release_count_; + int cancel_count_; + ScopedVector<MockConnectJob> job_list_; + + DISALLOW_COPY_AND_ASSIGN(MockSSLClientSocketPool); +}; + + } // namespace net #endif // NET_SOCKET_SOCKET_TEST_UTIL_H_ |