diff options
author | thakis@chromium.org <thakis@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-06-27 18:20:07 +0000 |
---|---|---|
committer | thakis@chromium.org <thakis@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-06-27 18:20:07 +0000 |
commit | 37a955a80fe1ebb08570dda4eba2e30861e9256d (patch) | |
tree | 8c63d89aca3ef5c7a53a7897a6677a4126b134ed /net | |
parent | b77656d8e282ad9b2055c912c4631a84c542497b (diff) | |
download | chromium_src-37a955a80fe1ebb08570dda4eba2e30861e9256d.zip chromium_src-37a955a80fe1ebb08570dda4eba2e30861e9256d.tar.gz chromium_src-37a955a80fe1ebb08570dda4eba2e30861e9256d.tar.bz2 |
Reverting 19466.
Review URL: http://codereview.chromium.org/150003
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@19468 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
-rw-r--r-- | net/net.gyp | 1 | ||||
-rw-r--r-- | net/socket/socket_test_util.cc | 90 | ||||
-rw-r--r-- | net/socket/socket_test_util.h | 85 | ||||
-rw-r--r-- | net/socket/socks_client_socket.cc | 96 | ||||
-rw-r--r-- | net/socket/socks_client_socket.h | 20 | ||||
-rw-r--r-- | net/socket/socks_client_socket_unittest.cc | 291 |
6 files changed, 147 insertions, 436 deletions
diff --git a/net/net.gyp b/net/net.gyp index fe05eaa..de904b6 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -482,7 +482,6 @@ 'proxy/proxy_script_fetcher_unittest.cc', 'proxy/proxy_server_unittest.cc', 'proxy/proxy_service_unittest.cc', - 'socket/socks_client_socket_unittest.cc', 'socket/ssl_client_socket_unittest.cc', 'socket/tcp_client_socket_pool_unittest.cc', 'socket/tcp_client_socket_unittest.cc', diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index fc5ddf3..8e0cb46 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -7,11 +7,95 @@ #include "base/basictypes.h" #include "base/compiler_specific.h" #include "base/message_loop.h" +#include "net/base/io_buffer.h" #include "net/base/ssl_info.h" #include "net/socket/socket.h" +#include "net/socket/ssl_client_socket.h" #include "testing/gtest/include/gtest/gtest.h" -namespace net { +namespace { + +class MockClientSocket : public net::SSLClientSocket { + public: + MockClientSocket(); + + // ClientSocket methods: + virtual int Connect(net::CompletionCallback* callback) = 0; + + // SSLClientSocket methods: + virtual void GetSSLInfo(net::SSLInfo* ssl_info); + virtual void GetSSLCertRequestInfo( + net::SSLCertRequestInfo* cert_request_info); + virtual void Disconnect(); + virtual bool IsConnected() const; + virtual bool IsConnectedAndIdle() const; + + // Socket methods: + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) = 0; + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) = 0; + +#if defined(OS_LINUX) + virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen); +#endif + + protected: + void RunCallbackAsync(net::CompletionCallback* callback, int result); + void RunCallback(int result); + + ScopedRunnableMethodFactory<MockClientSocket> method_factory_; + net::CompletionCallback* callback_; + bool connected_; +}; + +class MockTCPClientSocket : public MockClientSocket { + public: + MockTCPClientSocket(const net::AddressList& addresses, + net::MockSocket* socket); + + // ClientSocket methods: + virtual int Connect(net::CompletionCallback* callback); + + // Socket methods: + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + + private: + net::MockSocket* data_; + int read_offset_; + net::MockRead read_data_; + bool need_read_data_; +}; + +class MockSSLClientSocket : public MockClientSocket { + public: + MockSSLClientSocket( + net::ClientSocket* transport_socket, + const std::string& hostname, + const net::SSLConfig& ssl_config, + net::MockSSLSocket* socket); + ~MockSSLClientSocket(); + + virtual void GetSSLInfo(net::SSLInfo* ssl_info); + + virtual int Connect(net::CompletionCallback* callback); + virtual void Disconnect(); + + // Socket methods: + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + + private: + class ConnectCallback; + + scoped_ptr<ClientSocket> transport_; + net::MockSSLSocket* data_; +}; MockClientSocket::MockClientSocket() : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), @@ -217,6 +301,10 @@ int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, return transport_->Write(buf, buf_len, callback); } +} // namespace + +namespace net { + MockRead StaticMockSocket::GetNextRead() { return reads_[read_index_++]; } diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index f710820..bd7c437 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -10,13 +10,10 @@ #include "base/basictypes.h" #include "base/logging.h" -#include "base/scoped_ptr.h" #include "net/base/address_list.h" -#include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/base/ssl_config_service.h" #include "net/socket/client_socket_factory.h" -#include "net/socket/ssl_client_socket.h" namespace net { @@ -203,88 +200,6 @@ class MockClientSocketFactory : public ClientSocketFactory { MockSocketArray<MockSSLSocket> mock_ssl_sockets_; }; -class MockClientSocket : public net::SSLClientSocket { - public: - MockClientSocket(); - - // ClientSocket methods: - virtual int Connect(net::CompletionCallback* callback) = 0; - - // SSLClientSocket methods: - virtual void GetSSLInfo(net::SSLInfo* ssl_info); - virtual void GetSSLCertRequestInfo( - net::SSLCertRequestInfo* cert_request_info); - virtual void Disconnect(); - virtual bool IsConnected() const; - virtual bool IsConnectedAndIdle() const; - - // Socket methods: - virtual int Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) = 0; - virtual int Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) = 0; - -#if defined(OS_LINUX) - virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen); -#endif - - protected: - void RunCallbackAsync(net::CompletionCallback* callback, int result); - void RunCallback(int result); - - ScopedRunnableMethodFactory<MockClientSocket> method_factory_; - net::CompletionCallback* callback_; - bool connected_; -}; - -class MockTCPClientSocket : public MockClientSocket { - public: - MockTCPClientSocket(const net::AddressList& addresses, - net::MockSocket* socket); - - // ClientSocket methods: - virtual int Connect(net::CompletionCallback* callback); - - // Socket methods: - virtual int Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - virtual int Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - - private: - net::MockSocket* data_; - int read_offset_; - net::MockRead read_data_; - bool need_read_data_; -}; - -class MockSSLClientSocket : public MockClientSocket { - public: - MockSSLClientSocket( - net::ClientSocket* transport_socket, - const std::string& hostname, - const net::SSLConfig& ssl_config, - net::MockSSLSocket* socket); - ~MockSSLClientSocket(); - - virtual void GetSSLInfo(net::SSLInfo* ssl_info); - - virtual int Connect(net::CompletionCallback* callback); - virtual void Disconnect(); - - // Socket methods: - virtual int Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - virtual int Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - - private: - class ConnectCallback; - - scoped_ptr<ClientSocket> transport_; - net::MockSSLSocket* data_; -}; - } // namespace net #endif // NET_SOCKET_SOCKET_TEST_UTIL_H_ diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc index 80411b7..1bcf80c 100644 --- a/net/socket/socks_client_socket.cc +++ b/net/socket/socks_client_socket.cc @@ -28,10 +28,10 @@ static const uint8 kInvalidIp[] = { 0, 0, 0, 127 }; // For SOCKS4, the client sends 8 bytes plus the size of the user-id. // For SOCKS4A, this increases to accomodate the unresolved hostname. -static const unsigned int kWriteHeaderSize = 8; +static const int kWriteHeaderSize = 8; // For SOCKS4 and SOCKS4a, the server sends 8 bytes for acknowledgement. -static const unsigned int kReadHeaderSize = 8; +static const int kReadHeaderSize = 8; // Server Response codes for SOCKS. static const uint8 kServerResponseOk = 0x5A; @@ -72,6 +72,9 @@ SOCKSClientSocket::SOCKSClientSocket(ClientSocket* transport_socket, next_state_(STATE_NONE), socks_version_(kSOCKS4Unresolved), user_callback_(NULL), + handshake_buf_len_(0), + buffer_(NULL), + buffer_len_(0), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), @@ -230,60 +233,65 @@ int SOCKSClientSocket::DoResolveHostComplete(int result) { // Builds the buffer that is to be sent to the server. // We check whether the SOCKS proxy is 4 or 4A. // In case it is 4A, the record size increases by size of the hostname. -const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const { +void SOCKSClientSocket::BuildHandshakeWriteBuffer() { DCHECK_NE(kSOCKS4Unresolved, socks_version_); - SOCKS4ServerRequest request; - request.version = kSOCKSVersion4; - request.command = kSOCKSStreamRequest; - request.nw_port = htons(host_request_info_.port()); + int record_size = kWriteHeaderSize + arraysize(kEmptyUserId); + if (socks_version_ == kSOCKS4a) { + record_size += host_request_info_.hostname().size() + 1; + } + + buffer_len_ = record_size; + buffer_.reset(new char[buffer_len_]); + + SOCKS4ServerRequest* request = + reinterpret_cast<SOCKS4ServerRequest*>(buffer_.get()); + + request->version = kSOCKSVersion4; + request->command = kSOCKSStreamRequest; + request->nw_port = htons(host_request_info_.port()); if (socks_version_ == kSOCKS4) { const struct addrinfo* ai = addresses_.head(); DCHECK(ai); // If the sockaddr is IPv6, we have already marked the version to socks4a // and so this step does not get hit. - struct sockaddr_in* ipv4_host = + struct sockaddr_in *ipv4_host = reinterpret_cast<struct sockaddr_in*>(ai->ai_addr); - memcpy(&request.ip, &(ipv4_host->sin_addr), sizeof(ipv4_host->sin_addr)); + memcpy(&request->ip, &(ipv4_host->sin_addr), sizeof(ipv4_host->sin_addr)); DLOG(INFO) << "Resolved Host is : " << NetAddressToString(ai); } else if (socks_version_ == kSOCKS4a) { // invalid IP of the form 0.0.0.127 - memcpy(&request.ip, kInvalidIp, arraysize(kInvalidIp)); + memcpy(&request->ip, kInvalidIp, arraysize(kInvalidIp)); } else { NOTREACHED(); } - std::string handshake_data(reinterpret_cast<char*>(&request), - sizeof(request)); - handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId)); + memcpy(&buffer_[kWriteHeaderSize], kEmptyUserId, arraysize(kEmptyUserId)); - // In case we are passing the domain also, pass the hostname - // terminated with a null character. if (socks_version_ == kSOCKS4a) { - handshake_data.append(host_request_info_.hostname()); - handshake_data.push_back('\0'); + memcpy(&buffer_[kWriteHeaderSize + arraysize(kEmptyUserId)], + host_request_info_.hostname().c_str(), + host_request_info_.hostname().size() + 1); } - - return handshake_data; } // Writes the SOCKS handshake data to the underlying socket connection. int SOCKSClientSocket::DoHandshakeWrite() { next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; - if (buffer_.empty()) { - buffer_ = BuildHandshakeWriteBuffer(); + if (!buffer_.get()) { + BuildHandshakeWriteBuffer(); bytes_sent_ = 0; } - int handshake_buf_len = buffer_.size() - bytes_sent_; - DCHECK_GT(handshake_buf_len, 0); - handshake_buf_ = new IOBuffer(handshake_buf_len); - memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], - handshake_buf_len); - return transport_->Write(handshake_buf_, handshake_buf_len, &io_callback_); + handshake_buf_len_ = buffer_len_ - bytes_sent_; + DCHECK_GT(handshake_buf_len_, 0); + handshake_buf_ = new IOBuffer(handshake_buf_len_); + memcpy(handshake_buf_.get()->data(), &buffer_[bytes_sent_], + handshake_buf_len_); + return transport_->Write(handshake_buf_, handshake_buf_len_, &io_callback_); } int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { @@ -292,14 +300,11 @@ int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { if (result < 0) return result; - // We ignore the case when result is 0, since the underlying Write - // may return spurious writes while waiting on the socket. - bytes_sent_ += result; - if (bytes_sent_ == buffer_.size()) { + if (bytes_sent_ == buffer_len_) { next_state_ = STATE_HANDSHAKE_READ; - buffer_.clear(); - } else if (bytes_sent_ < buffer_.size()) { + buffer_.reset(NULL); + } else if (bytes_sent_ < buffer_len_) { next_state_ = STATE_HANDSHAKE_WRITE; } else { return ERR_UNEXPECTED; @@ -313,13 +318,15 @@ int SOCKSClientSocket::DoHandshakeRead() { next_state_ = STATE_HANDSHAKE_READ_COMPLETE; - if (buffer_.empty()) { + if (!buffer_.get()) { + buffer_.reset(new char[kReadHeaderSize]); + buffer_len_ = kReadHeaderSize; bytes_received_ = 0; } - int handshake_buf_len = kReadHeaderSize - bytes_received_; - handshake_buf_ = new IOBuffer(handshake_buf_len); - return transport_->Read(handshake_buf_, handshake_buf_len, &io_callback_); + handshake_buf_len_ = buffer_len_ - bytes_received_; + handshake_buf_ = new IOBuffer(handshake_buf_len_); + return transport_->Read(handshake_buf_, handshake_buf_len_, &io_callback_); } int SOCKSClientSocket::DoHandshakeReadComplete(int result) { @@ -327,23 +334,18 @@ int SOCKSClientSocket::DoHandshakeReadComplete(int result) { if (result < 0) return result; - - // The underlying socket closed unexpectedly. - if (result == 0) - return ERR_CONNECTION_CLOSED; - - if (bytes_received_ + result > kReadHeaderSize) + if (bytes_received_ + result > buffer_len_) return ERR_INVALID_RESPONSE; - buffer_.append(handshake_buf_->data(), result); + memcpy(buffer_.get() + bytes_received_, handshake_buf_->data(), result); bytes_received_ += result; - if (bytes_received_ < kReadHeaderSize) { + if (bytes_received_ < buffer_len_) { next_state_ = STATE_HANDSHAKE_READ; return OK; } - const SOCKS4ServerResponse* response = - reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data()); + SOCKS4ServerResponse* response = + reinterpret_cast<SOCKS4ServerResponse*>(buffer_.get()); if (response->reserved_null != 0x00) { LOG(ERROR) << "Unknown response from SOCKS server."; diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h index 295000a..03925ba 100644 --- a/net/socket/socks_client_socket.h +++ b/net/socket/socks_client_socket.h @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef NET_SOCKET_SOCKS_CLIENT_SOCKET_H_ -#define NET_SOCKET_SOCKS_CLIENT_SOCKET_H_ +#ifndef NET_BASE_SOCKS_CLIENT_SOCKET_H_ +#define NET_BASE_SOCKS_CLIENT_SOCKET_H_ #include <string> @@ -54,10 +54,6 @@ class SOCKSClientSocket : public ClientSocket { #endif private: - FRIEND_TEST(SOCKSClientSocketTest, CompleteHandshake); - FRIEND_TEST(SOCKSClientSocketTest, SOCKS4AFailedDNS); - FRIEND_TEST(SOCKSClientSocketTest, SOCKS4AIfDomainInIPv6); - enum State { STATE_RESOLVE_HOST, STATE_RESOLVE_HOST_COMPLETE, @@ -89,7 +85,7 @@ class SOCKSClientSocket : public ClientSocket { int DoHandshakeWrite(); int DoHandshakeWriteComplete(int result); - const std::string BuildHandshakeWriteBuffer() const; + void BuildHandshakeWriteBuffer(); CompletionCallbackImpl<SOCKSClientSocket> io_callback_; @@ -106,18 +102,20 @@ class SOCKSClientSocket : public ClientSocket { // SOCKS handshake data. The length contains the expected size to // read or write. scoped_refptr<IOBuffer> handshake_buf_; + int handshake_buf_len_; // While writing, this buffer stores the complete write handshake data. // While reading, it stores the handshake information received so far. - std::string buffer_; + scoped_array<char> buffer_; + int buffer_len_; // This becomes true when the SOCKS handshake has completed and the // overlying connection is free to communicate. bool completed_handshake_; // These contain the bytes sent / received by the SOCKS handshake. - size_t bytes_sent_; - size_t bytes_received_; + int bytes_sent_; + int bytes_received_; // Used to resolve the hostname to which the SOCKS proxy will connect. SingleRequestHostResolver resolver_; @@ -129,5 +127,5 @@ class SOCKSClientSocket : public ClientSocket { } // namespace net -#endif // NET_SOCKET_SOCKS_CLIENT_SOCKET_H_ +#endif // NET_BASE_SOCKS_CLIENT_SOCKET_H_ diff --git a/net/socket/socks_client_socket_unittest.cc b/net/socket/socks_client_socket_unittest.cc deleted file mode 100644 index 4d6a624..0000000 --- a/net/socket/socks_client_socket_unittest.cc +++ /dev/null @@ -1,291 +0,0 @@ -// Copyright (c) 2009 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/socks_client_socket.h" - -#include "net/base/address_list.h" -#include "net/base/host_resolver_unittest.h" -#include "net/base/listen_socket.h" -#include "net/base/test_completion_callback.h" -#include "net/base/winsock_init.h" -#include "net/socket/client_socket_factory.h" -#include "net/socket/tcp_client_socket.h" -#include "net/socket/socket_test_util.h" -#include "testing/gtest/include/gtest/gtest.h" -#include "testing/platform_test.h" - -//----------------------------------------------------------------------------- - -namespace net { - -const char kSOCKSOkRequest[] = { 0x04, 0x01, 0x00, 0x50, 127, 0, 0, 1, 0 }; -const char kSOCKS4aInitialRequest[] = - { 0x04, 0x01, 0x00, 0x50, 0, 0, 0, 127, 0 }; -const char kSOCKSOkReply[] = { 0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; - -class SOCKSClientSocketTest : public PlatformTest { - public: - SOCKSClientSocketTest(); - // Create a SOCKSClientSocket on top of a MockSocket. - SOCKSClientSocket* BuildMockSocket(MockRead reads[], MockWrite writes[], - const std::string& hostname, int port); - virtual void SetUp(); - - protected: - scoped_ptr<SOCKSClientSocket> user_sock_; - AddressList address_list_; - ClientSocket* tcp_sock_; - ScopedHostMapper host_mapper_; - TestCompletionCallback callback_; - scoped_refptr<RuleBasedHostMapper> mapper_; - HostResolver host_resolver_; - scoped_ptr<MockSocket> mock_socket_; - - private: - DISALLOW_COPY_AND_ASSIGN(SOCKSClientSocketTest); -}; - -SOCKSClientSocketTest::SOCKSClientSocketTest() - : host_resolver_(0, 0) { -} - -// Set up platform before every test case -void SOCKSClientSocketTest::SetUp() { - PlatformTest::SetUp(); - - // Resolve the "localhost" AddressList used by the tcp_connection to connect. - HostResolver resolver; - HostResolver::RequestInfo info("localhost", 1080); - int rv = resolver.Resolve(info, &address_list_, NULL, NULL); - ASSERT_EQ(OK, rv); - - // Create a new host mapping for the duration of this test case only. - mapper_ = new RuleBasedHostMapper(); - host_mapper_.Init(mapper_); - mapper_->AddRule("www.google.com", "127.0.0.1"); -} - -SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket( - MockRead reads[], - MockWrite writes[], - const std::string& hostname, - int port) { - - TestCompletionCallback callback; - mock_socket_.reset(new StaticMockSocket(reads, writes)); - tcp_sock_ = new MockTCPClientSocket(address_list_, mock_socket_.get()); - - int rv = tcp_sock_->Connect(&callback); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback.WaitForResult(); - EXPECT_EQ(OK, rv); - EXPECT_TRUE(tcp_sock_->IsConnected()); - - return new SOCKSClientSocket(tcp_sock_, - HostResolver::RequestInfo(hostname, port), - &host_resolver_); -} - -// Tests a complete handshake and the disconnection. -TEST_F(SOCKSClientSocketTest, CompleteHandshake) { - const std::string payload_write = "random data"; - const std::string payload_read = "moar random data"; - - MockWrite data_writes[] = { - MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)), - MockWrite(true, payload_write.data(), payload_write.size()) }; - MockRead data_reads[] = { - MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)), - MockRead(true, payload_read.data(), payload_read.size()) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); - - // At this state the TCP connection is completed but not the SOCKS handshake. - EXPECT_TRUE(tcp_sock_->IsConnected()); - EXPECT_FALSE(user_sock_->IsConnected()); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - EXPECT_FALSE(user_sock_->IsConnected()); - rv = callback_.WaitForResult(); - - EXPECT_EQ(OK, rv); - EXPECT_TRUE(user_sock_->IsConnected()); - EXPECT_EQ(SOCKSClientSocket::kSOCKS4, user_sock_->socks_version_); - - scoped_refptr<IOBuffer> buffer = new IOBuffer(payload_write.size()); - memcpy(buffer->data(), payload_write.data(), payload_write.size()); - rv = user_sock_->Write(buffer, payload_write.size(), &callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(static_cast<int>(payload_write.size()), rv); - - buffer = new IOBuffer(payload_read.size()); - rv = user_sock_->Read(buffer, payload_read.size(), &callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(static_cast<int>(payload_read.size()), rv); - EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size())); - - user_sock_->Disconnect(); - EXPECT_FALSE(tcp_sock_->IsConnected()); - EXPECT_FALSE(user_sock_->IsConnected()); -} - -// List of responses from the socks server and the errors they should -// throw up are tested here. -TEST_F(SOCKSClientSocketTest, HandshakeFailures) { - const struct { - const char fail_reply[8]; - Error fail_code; - } tests[] = { - // Failure of the server response code - { - { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }, - ERR_INVALID_RESPONSE, - }, - // Failure of the null byte - { - { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 }, - ERR_FAILED, - }, - }; - - //--------------------------------------- - - for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) { - MockWrite data_writes[] = { - MockWrite(false, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; - MockRead data_reads[] = { - MockRead(false, tests[i].fail_reply, arraysize(tests[i].fail_reply)) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(tests[i].fail_code, rv); - EXPECT_FALSE(user_sock_->IsConnected()); - EXPECT_TRUE(tcp_sock_->IsConnected()); - } -} - -// Tests scenario when the server sends the handshake response in -// more than one packet. -TEST_F(SOCKSClientSocketTest, PartialServerReads) { - const char kSOCKSPartialReply1[] = { 0x00 }; - const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; - - MockWrite data_writes[] = { - MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; - MockRead data_reads[] = { - MockRead(true, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)), - MockRead(true, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(OK, rv); - EXPECT_TRUE(user_sock_->IsConnected()); -} - -// Tests scenario when the client sends the handshake request in -// more than one packet. -TEST_F(SOCKSClientSocketTest, PartialClientWrites) { - const char kSOCKSPartialRequest1[] = { 0x04, 0x01 }; - const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 }; - - MockWrite data_writes[] = { - MockWrite(true, arraysize(kSOCKSPartialRequest1)), - // simulate some empty writes - MockWrite(true, 0), - MockWrite(true, 0), - MockWrite(true, kSOCKSPartialRequest2, - arraysize(kSOCKSPartialRequest2)) }; - MockRead data_reads[] = { - MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(OK, rv); - EXPECT_TRUE(user_sock_->IsConnected()); -} - -// Tests the case when the server sends a smaller sized handshake data -// and closes the connection. -TEST_F(SOCKSClientSocketTest, FailedSocketRead) { - MockWrite data_writes[] = { - MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; - MockRead data_reads[] = { - MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2), - // close connection unexpectedly - MockRead(false, 0) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(ERR_CONNECTION_CLOSED, rv); - EXPECT_FALSE(user_sock_->IsConnected()); -} - -// Tries to connect to an unknown DNS and on failure should revert to SOCKS4A. -TEST_F(SOCKSClientSocketTest, SOCKS4AFailedDNS) { - const char hostname[] = "unresolved.ipv4.address"; - - mapper_->AddSimulatedFailure(hostname); - - std::string request(kSOCKS4aInitialRequest, - arraysize(kSOCKS4aInitialRequest)); - request.append(hostname, arraysize(hostname)); - - MockWrite data_writes[] = { - MockWrite(false, request.data(), request.size()) }; - MockRead data_reads[] = { - MockRead(false, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, hostname, 80)); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(OK, rv); - EXPECT_TRUE(user_sock_->IsConnected()); - EXPECT_EQ(SOCKSClientSocket::kSOCKS4a, user_sock_->socks_version_); -} - -// Tries to connect to a domain that resolves to IPv6. -// Should revert to SOCKS4a. -TEST_F(SOCKSClientSocketTest, SOCKS4AIfDomainInIPv6) { - const char hostname[] = "an.ipv6.address"; - - mapper_->AddRule(hostname, "2001:db8:8714:3a90::12"); - - std::string request(kSOCKS4aInitialRequest, - arraysize(kSOCKS4aInitialRequest)); - request.append(hostname, arraysize(hostname)); - - MockWrite data_writes[] = { - MockWrite(false, request.data(), request.size()) }; - MockRead data_reads[] = { - MockRead(false, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; - - user_sock_.reset(BuildMockSocket(data_reads, data_writes, hostname, 80)); - - int rv = user_sock_->Connect(&callback_); - EXPECT_EQ(ERR_IO_PENDING, rv); - rv = callback_.WaitForResult(); - EXPECT_EQ(OK, rv); - EXPECT_TRUE(user_sock_->IsConnected()); - EXPECT_EQ(SOCKSClientSocket::kSOCKS4a, user_sock_->socks_version_); -} - -} // namespace net - |