summaryrefslogtreecommitdiffstats
path: root/net/socket/socket_test_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'net/socket/socket_test_util.h')
-rw-r--r--net/socket/socket_test_util.h332
1 files changed, 266 insertions, 66 deletions
diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h
index 2bd0d23..7f83a70 100644
--- a/net/socket/socket_test_util.h
+++ b/net/socket/socket_test_util.h
@@ -4,6 +4,7 @@
#ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
#define NET_SOCKET_SOCKET_TEST_UTIL_H_
+#pragma once
#include <cstring>
#include <deque>
@@ -15,6 +16,8 @@
#include "base/logging.h"
#include "base/scoped_ptr.h"
#include "base/scoped_vector.h"
+#include "base/string16.h"
+#include "base/weak_ptr.h"
#include "net/base/address_list.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
@@ -27,6 +30,7 @@
#include "net/socket/client_socket_handle.h"
#include "net/socket/socks_client_socket_pool.h"
#include "net/socket/ssl_client_socket.h"
+#include "net/socket/ssl_client_socket_pool.h"
#include "net/socket/tcp_client_socket_pool.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -42,8 +46,6 @@ enum {
};
class ClientSocket;
-class HttpRequestHeaders;
-class HttpResponseHeaders;
class MockClientSocket;
class SSLClientSocket;
@@ -166,6 +168,7 @@ class StaticSocketDataProvider : public SocketDataProvider {
write_index_(0),
write_count_(writes_count) {
}
+ virtual ~StaticSocketDataProvider() {}
// SocketDataProvider methods:
virtual MockRead GetNextRead();
@@ -277,11 +280,13 @@ class DelayedSocketData : public StaticSocketDataProvider,
DelayedSocketData(const MockConnect& connect, int write_delay,
MockRead* reads, size_t reads_count,
MockWrite* writes, size_t writes_count);
+ ~DelayedSocketData();
virtual MockRead GetNextRead();
virtual MockWriteResult OnWrite(const std::string& data);
virtual void Reset();
void CompleteRead();
+ void ForceNextRead();
private:
int write_delay_;
@@ -340,6 +345,89 @@ class OrderedSocketData : public StaticSocketDataProvider,
ScopedRunnableMethodFactory<OrderedSocketData> factory_;
};
+class DeterministicMockTCPClientSocket;
+
+// This class gives the user full control over the mock socket reads and writes,
+// including the timing of the callbacks. By default, synchronous reads and
+// writes will force the callback for that read or write to complete before
+// allowing another read or write to finish.
+//
+// Sequence numbers are preserved across both reads and writes. There should be
+// no gaps in sequence numbers, and no repeated sequence numbers. i.e.
+// MockWrite writes[] = {
+// MockWrite(true, "first write", length, 0),
+// MockWrite(false, "second write", length, 3),
+// };
+//
+// MockRead reads[] = {
+// MockRead(false, "first read", length, 1)
+// MockRead(false, "second read", length, 2)
+// };
+// Example control flow:
+// The first write completes. A call to read() returns ERR_IO_PENDING, since the
+// first write's callback has not happened yet. The first write's callback is
+// called. Now the first read's callback will be called. A call to write() will
+// succeed, because the write() API requires this, but the callback will not be
+// called until the second read has completed and its callback called.
+class DeterministicSocketData : public StaticSocketDataProvider,
+ public base::RefCounted<DeterministicSocketData> {
+ public:
+ // |reads| the list of MockRead completions.
+ // |writes| the list of MockWrite completions.
+ DeterministicSocketData(MockRead* reads, size_t reads_count,
+ MockWrite* writes, size_t writes_count);
+
+ // |connect| the result for the connect phase.
+ // |reads| the list of MockRead completions.
+ // |writes| the list of MockWrite completions.
+ DeterministicSocketData(const MockConnect& connect,
+ MockRead* reads, size_t reads_count,
+ MockWrite* writes, size_t writes_count);
+
+ // When the socket calls Read(), that calls GetNextRead(), and expects either
+ // ERR_IO_PENDING or data.
+ virtual MockRead GetNextRead();
+
+ // When the socket calls Write(), it always completes synchronously. OnWrite()
+ // checks to make sure the written data matches the expected data. The
+ // callback will not be invoked until its sequence number is reached.
+ virtual MockWriteResult OnWrite(const std::string& data);
+
+ virtual void Reset();
+
+ // Consume all the data up to the give stop point (via SetStop()).
+ void Run();
+
+ // Stop when Read() is about to consume a MockRead with sequence_number >=
+ // seq. Instead feed ERR_IO_PENDING to Read().
+ virtual void SetStop(int seq) { stopping_sequence_number_ = seq; }
+
+ void CompleteRead();
+ bool stopped() const { return stopped_; }
+ void SetStopped(bool val) { stopped_ = val; }
+ MockRead& current_read() { return current_read_; }
+ MockRead& current_write() { return current_write_; }
+ int next_read_seq() const { return next_read_seq_; }
+ int sequence_number() const { return sequence_number_; }
+ void set_socket(base::WeakPtr<DeterministicMockTCPClientSocket> socket) {
+ socket_ = socket;
+ }
+
+ private:
+ // Invoke the read and write callbacks, if the timing is appropriate.
+ void InvokeCallbacks();
+
+ int sequence_number_;
+ MockRead current_read_;
+ MockWrite current_write_;
+ int next_read_seq_;
+ int stopping_sequence_number_;
+ bool stopped_;
+ base::WeakPtr<DeterministicMockTCPClientSocket> socket_;
+ bool print_debug_;
+};
+
+
// Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}ClientSocket
// objects get instantiated, they take their data from the i'th element of this
// array.
@@ -395,12 +483,20 @@ class MockClientSocketFactory : public ClientSocketFactory {
MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const;
// ClientSocketFactory
- virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses,
- NetLog* net_log);
+ virtual ClientSocket* CreateTCPClientSocket(
+ const AddressList& addresses,
+ NetLog* net_log,
+ const NetLog::Source& source);
virtual SSLClientSocket* CreateSSLClientSocket(
ClientSocketHandle* transport_socket,
const std::string& hostname,
const SSLConfig& ssl_config);
+ SocketDataProviderArray<SocketDataProvider>& mock_data() {
+ return mock_data_;
+ }
+ std::vector<MockTCPClientSocket*>& tcp_client_sockets() {
+ return tcp_client_sockets_;
+ }
private:
SocketDataProviderArray<SocketDataProvider> mock_data_;
@@ -414,7 +510,6 @@ class MockClientSocketFactory : public ClientSocketFactory {
class MockClientSocket : public net::SSLClientSocket {
public:
explicit MockClientSocket(net::NetLog* net_log);
-
// ClientSocket methods:
virtual int Connect(net::CompletionCallback* callback) = 0;
virtual void Disconnect();
@@ -422,6 +517,8 @@ class MockClientSocket : public net::SSLClientSocket {
virtual bool IsConnectedAndIdle() const;
virtual int GetPeerAddress(AddressList* address) const;
virtual const BoundNetLog& NetLog() const { return net_log_;}
+ virtual void SetSubresourceSpeculation() {}
+ virtual void SetOmniboxSpeculation() {}
// SSLClientSocket methods:
virtual void GetSSLInfo(net::SSLInfo* ssl_info);
@@ -445,6 +542,7 @@ class MockClientSocket : public net::SSLClientSocket {
virtual void OnReadComplete(const MockRead& data) = 0;
protected:
+ virtual ~MockClientSocket() {}
void RunCallbackAsync(net::CompletionCallback* callback, int result);
void RunCallback(net::CompletionCallback*, int result);
@@ -466,6 +564,7 @@ class MockTCPClientSocket : public MockClientSocket {
virtual void Disconnect();
virtual bool IsConnected() const;
virtual bool IsConnectedAndIdle() const { return IsConnected(); }
+ virtual bool WasEverUsed() const { return was_used_to_convey_data_; }
// Socket methods:
virtual int Read(net::IOBuffer* buf, int buf_len,
@@ -496,6 +595,48 @@ class MockTCPClientSocket : public MockClientSocket {
net::IOBuffer* pending_buf_;
int pending_buf_len_;
net::CompletionCallback* pending_callback_;
+ bool was_used_to_convey_data_;
+};
+
+class DeterministicMockTCPClientSocket : public MockClientSocket,
+ public base::SupportsWeakPtr<DeterministicMockTCPClientSocket> {
+ public:
+ DeterministicMockTCPClientSocket(net::NetLog* net_log,
+ net::DeterministicSocketData* data);
+
+ // ClientSocket methods:
+ virtual int Connect(net::CompletionCallback* callback);
+ virtual void Disconnect();
+ virtual bool IsConnected() const;
+ virtual bool IsConnectedAndIdle() const { return IsConnected(); }
+ virtual bool WasEverUsed() const { return was_used_to_convey_data_; }
+
+ // Socket methods:
+ virtual int Write(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback);
+ virtual int Read(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback);
+
+ bool write_pending() const { return write_pending_; }
+ bool read_pending() const { return read_pending_; }
+
+ void CompleteWrite();
+ int CompleteRead();
+ void OnReadComplete(const MockRead& data);
+
+ private:
+ bool write_pending_;
+ net::CompletionCallback* write_callback_;
+ int write_result_;
+
+ net::MockRead read_data_;
+
+ net::IOBuffer* read_buf_;
+ int read_buf_len_;
+ bool read_pending_;
+ net::CompletionCallback* read_callback_;
+ net::DeterministicSocketData* data_;
+ bool was_used_to_convey_data_;
};
class MockSSLClientSocket : public MockClientSocket {
@@ -510,6 +651,8 @@ class MockSSLClientSocket : public MockClientSocket {
// ClientSocket methods:
virtual int Connect(net::CompletionCallback* callback);
virtual void Disconnect();
+ virtual bool IsConnected() const;
+ virtual bool WasEverUsed() const;
// Socket methods:
virtual int Read(net::IOBuffer* buf, int buf_len,
@@ -520,8 +663,8 @@ class MockSSLClientSocket : public MockClientSocket {
// SSLClientSocket methods:
virtual void GetSSLInfo(net::SSLInfo* ssl_info);
virtual NextProtoStatus GetNextProto(std::string* proto);
- virtual bool wasNpnNegotiated() const;
- virtual bool setWasNpnNegotiated(bool negotiated);
+ virtual bool was_npn_negotiated() const;
+ virtual bool set_was_npn_negotiated(bool negotiated);
// This MockSocket does not implement the manual async IO feature.
virtual void OnReadComplete(const MockRead& data) { NOTIMPLEMENTED(); }
@@ -533,6 +676,7 @@ class MockSSLClientSocket : public MockClientSocket {
net::SSLSocketDataProvider* data_;
bool is_npn_state_set_;
bool new_npn_value_;
+ bool was_used_to_convey_data_;
};
class TestSocketRequest : public CallbackRunner< Tuple1<int> > {
@@ -558,8 +702,8 @@ class TestSocketRequest : public CallbackRunner< Tuple1<int> > {
TestCompletionCallback callback_;
};
-class ClientSocketPoolTest : public testing::Test {
- protected:
+class ClientSocketPoolTest {
+ public:
enum KeepAlive {
KEEP_ALIVE,
@@ -570,15 +714,15 @@ class ClientSocketPoolTest : public testing::Test {
static const int kIndexOutOfBounds;
static const int kRequestNotFound;
- virtual void SetUp();
- virtual void TearDown();
+ ClientSocketPoolTest();
+ ~ClientSocketPoolTest();
template <typename PoolType, typename SocketParams>
- int StartRequestUsingPool(const scoped_refptr<PoolType>& socket_pool,
+ int StartRequestUsingPool(PoolType* socket_pool,
const std::string& group_name,
RequestPriority priority,
const scoped_refptr<SocketParams>& socket_params) {
- DCHECK(socket_pool.get());
+ DCHECK(socket_pool);
TestSocketRequest* request = new TestSocketRequest(&request_order_,
&completion_count_);
requests_.push_back(request);
@@ -594,7 +738,7 @@ class ClientSocketPoolTest : public testing::Test {
// and returns order in which that request completed, in range 1..n,
// or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
// if that request did not complete (for example was canceled).
- int GetOrderOfRequest(size_t index);
+ int GetOrderOfRequest(size_t index) const;
// Resets first initialized socket handle from |requests_|. If found such
// a handle, returns true.
@@ -603,6 +747,12 @@ class ClientSocketPoolTest : public testing::Test {
// Releases connections until there is nothing to release.
void ReleaseAllConnections(KeepAlive keep_alive);
+ TestSocketRequest* request(int i) { return requests_[i]; }
+ size_t requests_size() const { return requests_.size(); }
+ ScopedVector<TestSocketRequest>* requests() { return &requests_; }
+ size_t completion_count() const { return completion_count_; }
+
+ private:
ScopedVector<TestSocketRequest> requests_;
std::vector<TestSocketRequest*> request_order_;
size_t completion_count_;
@@ -632,11 +782,13 @@ class MockTCPClientSocketPool : public TCPClientSocketPool {
MockTCPClientSocketPool(
int max_sockets,
int max_sockets_per_group,
- const scoped_refptr<ClientSocketPoolHistograms>& histograms,
+ ClientSocketPoolHistograms* histograms,
ClientSocketFactory* socket_factory);
- int release_count() const { return release_count_; };
- int cancel_count() const { return cancel_count_; };
+ virtual ~MockTCPClientSocketPool();
+
+ int release_count() const { return release_count_; }
+ int cancel_count() const { return cancel_count_; }
// TCPClientSocketPool methods.
virtual int RequestSocket(const std::string& group_name,
@@ -651,25 +803,59 @@ class MockTCPClientSocketPool : public TCPClientSocketPool {
virtual void ReleaseSocket(const std::string& group_name,
ClientSocket* socket, int id);
- protected:
- virtual ~MockTCPClientSocketPool();
-
private:
ClientSocketFactory* client_socket_factory_;
+ ScopedVector<MockConnectJob> job_list_;
int release_count_;
int cancel_count_;
- ScopedVector<MockConnectJob> job_list_;
DISALLOW_COPY_AND_ASSIGN(MockTCPClientSocketPool);
};
+class DeterministicMockClientSocketFactory : public ClientSocketFactory {
+ public:
+ void AddSocketDataProvider(DeterministicSocketData* socket);
+ void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
+ void ResetNextMockIndexes();
+
+ // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
+ // created.
+ MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const;
+
+ // ClientSocketFactory
+ virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses,
+ NetLog* net_log,
+ const NetLog::Source& source);
+ virtual SSLClientSocket* CreateSSLClientSocket(
+ ClientSocketHandle* transport_socket,
+ const std::string& hostname,
+ const SSLConfig& ssl_config);
+
+ SocketDataProviderArray<DeterministicSocketData>& mock_data() {
+ return mock_data_;
+ }
+ std::vector<DeterministicMockTCPClientSocket*>& tcp_client_sockets() {
+ return tcp_client_sockets_;
+ }
+
+ private:
+ SocketDataProviderArray<DeterministicSocketData> mock_data_;
+ SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
+
+ // Store pointers to handed out sockets in case the test wants to get them.
+ std::vector<DeterministicMockTCPClientSocket*> tcp_client_sockets_;
+ std::vector<MockSSLClientSocket*> ssl_client_sockets_;
+};
+
class MockSOCKSClientSocketPool : public SOCKSClientSocketPool {
public:
MockSOCKSClientSocketPool(
int max_sockets,
int max_sockets_per_group,
- const scoped_refptr<ClientSocketPoolHistograms>& histograms,
- const scoped_refptr<TCPClientSocketPool>& tcp_pool);
+ ClientSocketPoolHistograms* histograms,
+ TCPClientSocketPool* tcp_pool);
+
+ virtual ~MockSOCKSClientSocketPool();
// SOCKSClientSocketPool methods.
virtual int RequestSocket(const std::string& group_name,
@@ -684,54 +870,12 @@ class MockSOCKSClientSocketPool : public SOCKSClientSocketPool {
virtual void ReleaseSocket(const std::string& group_name,
ClientSocket* socket, int id);
- protected:
- virtual ~MockSOCKSClientSocketPool();
-
private:
- const scoped_refptr<TCPClientSocketPool> tcp_pool_;
+ TCPClientSocketPool* const tcp_pool_;
DISALLOW_COPY_AND_ASSIGN(MockSOCKSClientSocketPool);
};
-struct MockHttpAuthControllerData {
- MockHttpAuthControllerData(std::string header) : auth_header(header) {}
-
- std::string auth_header;
-};
-
-class MockHttpAuthController : public HttpAuthController {
- public:
- MockHttpAuthController();
- void SetMockAuthControllerData(struct MockHttpAuthControllerData* data,
- size_t data_length);
-
- // HttpAuthController methods.
- virtual int MaybeGenerateAuthToken(const HttpRequestInfo* request,
- CompletionCallback* callback,
- const BoundNetLog& net_log);
- virtual void AddAuthorizationHeader(
- HttpRequestHeaders* authorization_headers);
- virtual int HandleAuthChallenge(scoped_refptr<HttpResponseHeaders> headers,
- bool do_not_send_server_auth,
- bool establishing_tunnel,
- const BoundNetLog& net_log);
- virtual void ResetAuth(const std::wstring& username,
- const std::wstring& password);
- virtual bool HaveAuthHandler() const;
- virtual bool HaveAuth() const;
-
- private:
- virtual ~MockHttpAuthController() {}
- const struct MockHttpAuthControllerData& CurrentData() const {
- DCHECK(data_index_ < data_count_);
- return data_[data_index_];
- }
-
- MockHttpAuthControllerData* data_;
- size_t data_index_;
- size_t data_count_;
-};
-
// Constants for a successful SOCKS v5 handshake.
extern const char kSOCKS5GreetRequest[];
extern const int kSOCKS5GreetRequestLength;
@@ -745,6 +889,62 @@ extern const int kSOCKS5OkRequestLength;
extern const char kSOCKS5OkResponse[];
extern const int kSOCKS5OkResponseLength;
+class MockSSLClientSocketPool : public SSLClientSocketPool {
+ public:
+ class MockConnectJob {
+ public:
+ MockConnectJob(ClientSocket* socket, ClientSocketHandle* handle,
+ CompletionCallback* callback);
+
+ int Connect();
+ bool CancelHandle(const ClientSocketHandle* handle);
+
+ private:
+ void OnConnect(int rv);
+
+ scoped_ptr<ClientSocket> socket_;
+ ClientSocketHandle* handle_;
+ CompletionCallback* user_callback_;
+ CompletionCallbackImpl<MockConnectJob> connect_callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockConnectJob);
+ };
+
+ MockSSLClientSocketPool(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ ClientSocketFactory* socket_factory,
+ TCPClientSocketPool* tcp_pool);
+
+ virtual ~MockSSLClientSocketPool();
+
+ int release_count() const { return release_count_; }
+ int cancel_count() const { return cancel_count_; }
+
+ // SSLClientSocketPool methods.
+ virtual int RequestSocket(const std::string& group_name,
+ const void* socket_params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ CompletionCallback* callback,
+ const BoundNetLog& net_log);
+
+ virtual void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle);
+ virtual void ReleaseSocket(const std::string& group_name,
+ ClientSocket* socket, int id);
+
+ private:
+ ClientSocketFactory* client_socket_factory_;
+ int release_count_;
+ int cancel_count_;
+ ScopedVector<MockConnectJob> job_list_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockSSLClientSocketPool);
+};
+
+
} // namespace net
#endif // NET_SOCKET_SOCKET_TEST_UTIL_H_