diff options
Diffstat (limited to 'net/socket')
-rw-r--r-- | net/socket/socket_test_util.cc | 21 | ||||
-rw-r--r-- | net/socket/socket_test_util.h | 11 |
2 files changed, 29 insertions, 3 deletions
diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index b1af59f..4b6a3ad 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -290,17 +290,32 @@ void MockClientSocketFactory::ResetNextMockIndexes() { mock_ssl_sockets_.ResetNextIndex(); } +ClientSocket* MockClientSocketFactory::GetMockTCPClientSocket(int index) const { + return tcp_client_sockets_[index]; +} + +SSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket( + int index) const { + return ssl_client_sockets_[index]; +} + ClientSocket* MockClientSocketFactory::CreateTCPClientSocket( const AddressList& addresses) { - return new MockTCPClientSocket(addresses, mock_sockets_.GetNext()); + ClientSocket* socket = + new MockTCPClientSocket(addresses, mock_sockets_.GetNext()); + tcp_client_sockets_.push_back(socket); + return socket; } SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( ClientSocket* transport_socket, const std::string& hostname, const SSLConfig& ssl_config) { - return new MockSSLClientSocket(transport_socket, hostname, ssl_config, - mock_ssl_sockets_.GetNext()); + SSLClientSocket* socket = + new MockSSLClientSocket(transport_socket, hostname, ssl_config, + mock_ssl_sockets_.GetNext()); + ssl_client_sockets_.push_back(socket); + return socket; } int TestSocketRequest::WaitForResult() { diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 76d4df1f..75f3a37 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -200,6 +200,13 @@ class MockClientSocketFactory : public ClientSocketFactory { void AddMockSSLSocket(MockSSLSocket* socket); void ResetNextMockIndexes(); + // Return |index|-th ClientSocket (starting from 0) that the factory created. + ClientSocket* GetMockTCPClientSocket(int index) const; + + // Return |index|-th SSLClientSocket (starting from 0) that the factory + // created. + SSLClientSocket* GetMockSSLClientSocket(int index) const; + // ClientSocketFactory virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses); virtual SSLClientSocket* CreateSSLClientSocket( @@ -210,6 +217,10 @@ class MockClientSocketFactory : public ClientSocketFactory { private: MockSocketArray<MockSocket> mock_sockets_; MockSocketArray<MockSSLSocket> mock_ssl_sockets_; + + // Store pointers to handed out sockets in case the test wants to get them. + std::vector<ClientSocket*> tcp_client_sockets_; + std::vector<SSLClientSocket*> ssl_client_sockets_; }; class MockClientSocket : public net::SSLClientSocket { |