diff options
-rw-r--r-- | net/net.gyp | 1 | ||||
-rw-r--r-- | net/socket/socket_test_util.cc | 90 | ||||
-rw-r--r-- | net/socket/socket_test_util.h | 85 | ||||
-rw-r--r-- | net/socket/socks_client_socket.h | 16 | ||||
-rw-r--r-- | net/socket/socks_client_socket_unittest.cc | 291 |
5 files changed, 96 insertions, 387 deletions
diff --git a/net/net.gyp b/net/net.gyp index fe05eaa..de904b6 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -482,7 +482,6 @@ 'proxy/proxy_script_fetcher_unittest.cc', 'proxy/proxy_server_unittest.cc', 'proxy/proxy_service_unittest.cc', - 'socket/socks_client_socket_unittest.cc', 'socket/ssl_client_socket_unittest.cc', 'socket/tcp_client_socket_pool_unittest.cc', 'socket/tcp_client_socket_unittest.cc', diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index fc5ddf3..8e0cb46 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -7,11 +7,95 @@ #include "base/basictypes.h" #include "base/compiler_specific.h" #include "base/message_loop.h" +#include "net/base/io_buffer.h" #include "net/base/ssl_info.h" #include "net/socket/socket.h" +#include "net/socket/ssl_client_socket.h" #include "testing/gtest/include/gtest/gtest.h" -namespace net { +namespace { + +class MockClientSocket : public net::SSLClientSocket { + public: + MockClientSocket(); + + // ClientSocket methods: + virtual int Connect(net::CompletionCallback* callback) = 0; + + // SSLClientSocket methods: + virtual void GetSSLInfo(net::SSLInfo* ssl_info); + virtual void GetSSLCertRequestInfo( + net::SSLCertRequestInfo* cert_request_info); + virtual void Disconnect(); + virtual bool IsConnected() const; + virtual bool IsConnectedAndIdle() const; + + // Socket methods: + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) = 0; + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) = 0; + +#if defined(OS_LINUX) + virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen); +#endif + + protected: + void RunCallbackAsync(net::CompletionCallback* callback, int result); + void RunCallback(int result); + + ScopedRunnableMethodFactory<MockClientSocket> method_factory_; + net::CompletionCallback* callback_; + bool connected_; +}; + +class MockTCPClientSocket : public MockClientSocket { + public: + MockTCPClientSocket(const net::AddressList& addresses, + net::MockSocket* socket); + + // ClientSocket methods: + virtual int Connect(net::CompletionCallback* callback); + + // Socket methods: + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + + private: + net::MockSocket* data_; + int read_offset_; + net::MockRead read_data_; + bool need_read_data_; +}; + +class MockSSLClientSocket : public MockClientSocket { + public: + MockSSLClientSocket( + net::ClientSocket* transport_socket, + const std::string& hostname, + const net::SSLConfig& ssl_config, + net::MockSSLSocket* socket); + ~MockSSLClientSocket(); + + virtual void GetSSLInfo(net::SSLInfo* ssl_info); + + virtual int Connect(net::CompletionCallback* callback); + virtual void Disconnect(); + + // Socket methods: + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + + private: + class ConnectCallback; + + scoped_ptr<ClientSocket> transport_; + net::MockSSLSocket* data_; +}; MockClientSocket::MockClientSocket() : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), @@ -217,6 +301,10 @@ int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, return transport_->Write(buf, buf_len, callback); } +} // namespace + +namespace net { + MockRead StaticMockSocket::GetNextRead() { return reads_[read_index_++]; } diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index f710820..bd7c437 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -10,13 +10,10 @@ #include "base/basictypes.h" #include "base/logging.h" -#include "base/scoped_ptr.h" #include "net/base/address_list.h" -#include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/base/ssl_config_service.h" #include "net/socket/client_socket_factory.h" -#include "net/socket/ssl_client_socket.h" namespace net { @@ -203,88 +200,6 @@ class MockClientSocketFactory : public ClientSocketFactory { MockSocketArray<MockSSLSocket> mock_ssl_sockets_; }; -class MockClientSocket : public net::SSLClientSocket { - public: - MockClientSocket(); - - // ClientSocket methods: - virtual int Connect(net::CompletionCallback* callback) = 0; - - // SSLClientSocket methods: - virtual void GetSSLInfo(net::SSLInfo* ssl_info); - virtual void GetSSLCertRequestInfo( - net::SSLCertRequestInfo* cert_request_info); - virtual void Disconnect(); - virtual bool IsConnected() const; - virtual bool IsConnectedAndIdle() const; - - // Socket methods: - virtual int Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) = 0; - virtual int Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) = 0; - -#if defined(OS_LINUX) - virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen); -#endif - - protected: - void RunCallbackAsync(net::CompletionCallback* callback, int result); - void RunCallback(int result); - - ScopedRunnableMethodFactory<MockClientSocket> method_factory_; - net::CompletionCallback* callback_; - bool connected_; -}; - -class MockTCPClientSocket : public MockClientSocket { - public: - MockTCPClientSocket(const net::AddressList& addresses, - net::MockSocket* socket); - - // ClientSocket methods: - virtual int Connect(net::CompletionCallback* callback); - - // Socket methods: - virtual int Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - virtual int Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - - private: - net::MockSocket* data_; - int read_offset_; - net::MockRead read_data_; - bool need_read_data_; -}; - -class MockSSLClientSocket : public MockClientSocket { - public: - MockSSLClientSocket( - net::ClientSocket* transport_socket, - const std::string& hostname, - const net::SSLConfig& ssl_config, - net::MockSSLSocket* socket); - ~MockSSLClientSocket(); - - virtual void GetSSLInfo(net::SSLInfo* ssl_info); - - virtual int Connect(net::CompletionCallback* callback); - virtual void Disconnect(); - - // Socket methods: - virtual int Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - virtual int Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - - private: - class ConnectCallback; - - scoped_ptr<ClientSocket> transport_; - net::MockSSLSocket* data_; -}; - } // namespace net #endif // NET_SOCKET_SOCKET_TEST_UTIL_H_ diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h index 53df119..03925ba 100644 --- a/net/socket/socks_client_socket.h +++ b/net/socket/socks_client_socket.h @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef NET_SOCKET_SOCKS_CLIENT_SOCKET_H_ -#define NET_SOCKET_SOCKS_CLIENT_SOCKET_H_ +#ifndef NET_BASE_SOCKS_CLIENT_SOCKET_H_ +#define NET_BASE_SOCKS_CLIENT_SOCKET_H_ #include <string> @@ -54,10 +54,6 @@ class SOCKSClientSocket : public ClientSocket { #endif private: - FRIEND_TEST(SOCKSClientSocketTest, CompleteHandshake); - FRIEND_TEST(SOCKSClientSocketTest, SOCKS4AFailedDNS); - FRIEND_TEST(SOCKSClientSocketTest, SOCKS4AIfDomainInIPv6); - enum State { STATE_RESOLVE_HOST, STATE_RESOLVE_HOST_COMPLETE, @@ -89,7 +85,7 @@ class SOCKSClientSocket : public ClientSocket { int DoHandshakeWrite(); int DoHandshakeWriteComplete(int result); - const std::string BuildHandshakeWriteBuffer() const; + void BuildHandshakeWriteBuffer(); CompletionCallbackImpl<SOCKSClientSocket> io_callback_; @@ -106,10 +102,12 @@ class SOCKSClientSocket : public ClientSocket { // SOCKS handshake data. The length contains the expected size to // read or write. scoped_refptr<IOBuffer> handshake_buf_; + int handshake_buf_len_; // While writing, this buffer stores the complete write handshake data. // While reading, it stores the handshake information received so far. - std::string buffer_; + scoped_array<char> buffer_; + int buffer_len_; // This becomes true when the SOCKS handshake has completed and the // overlying connection is free to communicate. @@ -129,5 +127,5 @@ class SOCKSClientSocket : public ClientSocket { } // namespace net -#endif // NET_SOCKET_SOCKS_CLIENT_SOCKET_H_ +#endif // NET_BASE_SOCKS_CLIENT_SOCKET_H_ diff --git a/net/socket/socks_client_socket_unittest.cc b/net/socket/socks_client_socket_unittest.cc deleted file mode 100644 index 813a48e..0000000 --- a/net/socket/socks_client_socket_unittest.cc +++ /dev/null @@ -1,291 +0,0 @@ -// Copyright (c) 2009 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/socks_client_socket.h" - -#include "net/base/address_list.h" -#include "net/base/host_resolver_unittest.h" -#include "net/base/listen_socket.h" -#include "net/base/test_completion_callback.h" -#include "net/base/winsock_init.h" -#include "net/socket/client_socket_factory.h" -#include "net/socket/tcp_client_socket.h" -#include "net/socket/socket_test_util.h" -#include "testing/gtest/include/gtest/gtest.h" -#include "testing/platform_test.h" - -//----------------------------------------------------------------------------- - -namespace net { - -const char kSOCKSOkRequest[] = { 0x04, 0x01, 0x00, 0x50, 127, 0, 0, 1, 0 }; -const char kSOCKS4aInitialRequest[] = - { 0x04, 0x01, 0x00, 0x50, 0, 0, 0, 127, 0 }; -const char kSOCKSOkReply[] = { 0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; - -class SOCKSClientSocketTest : public PlatformTest { - public: - SOCKSClientSocketTest(); - // Create a SOCKSClientSocket on top of a MockSocket. - SOCKSClientSocket* BuildMockSocket(MockRead reads[], MockWrite writes[], - const std::string& hostname, int port); - virtual void SetUp(); - - protected: - scoped_ptr<SOCKSClientSocket> user_sock_; - AddressList address_list_; - ClientSocket* tcp_sock_; - ScopedHostMapper host_mapper_; - TestCompletionCallback callback_; - scoped_refptr<RuleBasedHostMapper> mapper_; - HostResolver host_resolver_; - scoped_ptr<MockSocket> mock_socket_; - - private: - DISALLOW_COPY_AND_ASSIGN(SOCKSClientSocketTest); -}; - -SOCKSClientSocketTest::SOCKSClientSocketTest() - : host_resolver_(0, 0) { -} - -// Set up platform before every test case -void SOCKSClientSocketTest::SetUp() { - PlatformTest::SetUp(); - - // Resolve the "localhost" AddressList used by the tcp_connection to connect. - HostResolver resolver; - HostResolver::RequestInfo info("localhost", 1080); - int rv = resolver.Resolve(info, &address_list_, NULL, NULL); - ASSERT_EQ(OK, rv); - - // Create a new host mapping for the duration of this test case only. - mapper_ = new RuleBasedHostMapper(); - host_mapper_.Init(mapper_); - mapper_->AddRule("www.google.com", "127.0.0.1"); -} - -SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket( - MockRead reads[], - MockWrite writes[], - const std::string& hostname, - int port) { - - TestCompletionCallback callback; - mock_socket_.reset(new StaticMockSocket(reads, writes)); - tcp_sock_ = new MockTCPClientSocket(address_list_, mock_socket_.get()); - - int rv = tcp_sock_->Connect(&callback); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback.WaitForResult(); - EXPECT_EQ(OK, rv); - EXPECT_TRUE(tcp_sock_->IsConnected()); - - return new SOCKSClientSocket(tcp_sock_, - HostResolver::RequestInfo(hostname, port), - &host_resolver_); -} - -// Tests a complete handshake and the disconnection. -TEST_F(SOCKSClientSocketTest, CompleteHandshake) { - const std::string payload_write = "random data"; - const std::string payload_read = "moar random data"; - - MockWrite data_writes[] = { - MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)), - MockWrite(true, payload_write.data(), payload_write.size()) }; - MockRead data_reads[] = { - MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)), - MockRead(true, payload_read.data(), payload_read.size()) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); - - // At this state the TCP connection is completed but not the SOCKS handshake. - EXPECT_TRUE(tcp_sock_->IsConnected()); - EXPECT_FALSE(user_sock_->IsConnected()); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - EXPECT_FALSE(user_sock_->IsConnected()); - rv = callback_.WaitForResult(); - - EXPECT_EQ(OK, rv); - EXPECT_TRUE(user_sock_->IsConnected()); - EXPECT_EQ(SOCKSClientSocket::kSOCKS4, user_sock_->socks_version_); - - scoped_refptr<IOBuffer> buffer = new IOBuffer(payload_write.size()); - memcpy(buffer->data(), payload_write.data(), payload_write.size()); - rv = user_sock_->Write(buffer, payload_write.size(), &callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(payload_write.size(), rv); - - buffer = new IOBuffer(payload_read.size()); - rv = user_sock_->Read(buffer, payload_read.size(), &callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(payload_read.size(), rv); - EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size())); - - user_sock_->Disconnect(); - EXPECT_FALSE(tcp_sock_->IsConnected()); - EXPECT_FALSE(user_sock_->IsConnected()); -} - -// List of responses from the socks server and the errors they should -// throw up are tested here. -TEST_F(SOCKSClientSocketTest, HandshakeFailures) { - const struct { - const char fail_reply[8]; - Error fail_code; - } tests[] = { - // Failure of the server response code - { - { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }, - ERR_INVALID_RESPONSE, - }, - // Failure of the null byte - { - { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 }, - ERR_FAILED, - }, - }; - - //--------------------------------------- - - for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) { - MockWrite data_writes[] = { - MockWrite(false, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; - MockRead data_reads[] = { - MockRead(false, tests[i].fail_reply, arraysize(tests[i].fail_reply)) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(tests[i].fail_code, rv); - EXPECT_FALSE(user_sock_->IsConnected()); - EXPECT_TRUE(tcp_sock_->IsConnected()); - } -} - -// Tests scenario when the server sends the handshake response in -// more than one packet. -TEST_F(SOCKSClientSocketTest, PartialServerReads) { - const char kSOCKSPartialReply1[] = { 0x00 }; - const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; - - MockWrite data_writes[] = { - MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; - MockRead data_reads[] = { - MockRead(true, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)), - MockRead(true, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(OK, rv); - EXPECT_TRUE(user_sock_->IsConnected()); -} - -// Tests scenario when the client sends the handshake request in -// more than one packet. -TEST_F(SOCKSClientSocketTest, PartialClientWrites) { - const char kSOCKSPartialRequest1[] = { 0x04, 0x01 }; - const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 }; - - MockWrite data_writes[] = { - MockWrite(true, arraysize(kSOCKSPartialRequest1)), - // simulate some empty writes - MockWrite(true, 0), - MockWrite(true, 0), - MockWrite(true, kSOCKSPartialRequest2, - arraysize(kSOCKSPartialRequest2)) }; - MockRead data_reads[] = { - MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(OK, rv); - EXPECT_TRUE(user_sock_->IsConnected()); -} - -// Tests the case when the server sends a smaller sized handshake data -// and closes the connection. -TEST_F(SOCKSClientSocketTest, FailedSocketRead) { - MockWrite data_writes[] = { - MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; - MockRead data_reads[] = { - MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2), - // close connection unexpectedly - MockRead(false, 0) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(ERR_CONNECTION_CLOSED, rv); - EXPECT_FALSE(user_sock_->IsConnected()); -} - -// Tries to connect to an unknown DNS and on failure should revert to SOCKS4A. -TEST_F(SOCKSClientSocketTest, SOCKS4AFailedDNS) { - const char hostname[] = "unresolved.ipv4.address"; - - mapper_->AddSimulatedFailure(hostname); - - std::string request(kSOCKS4aInitialRequest, - arraysize(kSOCKS4aInitialRequest)); - request.append(hostname, arraysize(hostname)); - - MockWrite data_writes[] = { - MockWrite(false, request.data(), request.size()) }; - MockRead data_reads[] = { - MockRead(false, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, hostname, 80)); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(OK, rv); - EXPECT_TRUE(user_sock_->IsConnected()); - EXPECT_EQ(SOCKSClientSocket::kSOCKS4a, user_sock_->socks_version_); -} - -// Tries to connect to a domain that resolves to IPv6. -// Should revert to SOCKS4a. -TEST_F(SOCKSClientSocketTest, SOCKS4AIfDomainInIPv6) { - const char hostname[] = "an.ipv6.address"; - - mapper_->AddRule(hostname, "2001:db8:8714:3a90::12"); - - std::string request(kSOCKS4aInitialRequest, - arraysize(kSOCKS4aInitialRequest)); - request.append(hostname, arraysize(hostname)); - - MockWrite data_writes[] = { - MockWrite(false, request.data(), request.size()) }; - MockRead data_reads[] = { - MockRead(false, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, hostname, 80)); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(OK, rv); - EXPECT_TRUE(user_sock_->IsConnected()); - EXPECT_EQ(SOCKSClientSocket::kSOCKS4a, user_sock_->socks_version_); -} - -} // namespace net - |