diff options
author | arindam@chromium.org <arindam@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-06-26 13:30:50 +0000 |
---|---|---|
committer | arindam@chromium.org <arindam@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-06-26 13:30:50 +0000 |
commit | c08aac2a6078d61d4fe594e52377c4ac839a9f76 (patch) | |
tree | 9ed90f24824e47821b0ca99bc2120a1f81600f41 /net/socket | |
parent | 001b694d745a07e3eb42d6f48e30b8f9f8901138 (diff) | |
download | chromium_src-c08aac2a6078d61d4fe594e52377c4ac839a9f76.zip chromium_src-c08aac2a6078d61d4fe594e52377c4ac839a9f76.tar.gz chromium_src-c08aac2a6078d61d4fe594e52377c4ac839a9f76.tar.bz2 |
Tests for socks4/4a implementation.
Refactoring of void BuildHandshakeWriteBuffer() to const std::string BuildHandshakeWriteBuffer() const,
and removing private members handshake_buf_len_ and buffer_len_ (since buffer_ is now std::string, buffer_.size()) is more than sufficient.
TEST=unittests
BUG=none
Review URL: http://codereview.chromium.org/139009
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@19354 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net/socket')
-rw-r--r-- | net/socket/socket_test_util.cc | 90 | ||||
-rw-r--r-- | net/socket/socket_test_util.h | 85 | ||||
-rw-r--r-- | net/socket/socks_client_socket.cc | 92 | ||||
-rw-r--r-- | net/socket/socks_client_socket.h | 20 | ||||
-rw-r--r-- | net/socket/socks_client_socket_unittest.cc | 291 |
5 files changed, 433 insertions, 145 deletions
diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 8e0cb46..fc5ddf3 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -7,95 +7,11 @@ #include "base/basictypes.h" #include "base/compiler_specific.h" #include "base/message_loop.h" -#include "net/base/io_buffer.h" #include "net/base/ssl_info.h" #include "net/socket/socket.h" -#include "net/socket/ssl_client_socket.h" #include "testing/gtest/include/gtest/gtest.h" -namespace { - -class MockClientSocket : public net::SSLClientSocket { - public: - MockClientSocket(); - - // ClientSocket methods: - virtual int Connect(net::CompletionCallback* callback) = 0; - - // SSLClientSocket methods: - virtual void GetSSLInfo(net::SSLInfo* ssl_info); - virtual void GetSSLCertRequestInfo( - net::SSLCertRequestInfo* cert_request_info); - virtual void Disconnect(); - virtual bool IsConnected() const; - virtual bool IsConnectedAndIdle() const; - - // Socket methods: - virtual int Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) = 0; - virtual int Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) = 0; - -#if defined(OS_LINUX) - virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen); -#endif - - protected: - void RunCallbackAsync(net::CompletionCallback* callback, int result); - void RunCallback(int result); - - ScopedRunnableMethodFactory<MockClientSocket> method_factory_; - net::CompletionCallback* callback_; - bool connected_; -}; - -class MockTCPClientSocket : public MockClientSocket { - public: - MockTCPClientSocket(const net::AddressList& addresses, - net::MockSocket* socket); - - // ClientSocket methods: - virtual int Connect(net::CompletionCallback* callback); - - // Socket methods: - virtual int Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - virtual int Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - - private: - net::MockSocket* data_; - int read_offset_; - net::MockRead read_data_; - bool need_read_data_; -}; - -class MockSSLClientSocket : public MockClientSocket { - public: - MockSSLClientSocket( - net::ClientSocket* transport_socket, - const std::string& hostname, - const net::SSLConfig& ssl_config, - net::MockSSLSocket* socket); - ~MockSSLClientSocket(); - - virtual void GetSSLInfo(net::SSLInfo* ssl_info); - - virtual int Connect(net::CompletionCallback* callback); - virtual void Disconnect(); - - // Socket methods: - virtual int Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - virtual int Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - - private: - class ConnectCallback; - - scoped_ptr<ClientSocket> transport_; - net::MockSSLSocket* data_; -}; +namespace net { MockClientSocket::MockClientSocket() : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), @@ -301,10 +217,6 @@ int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, return transport_->Write(buf, buf_len, callback); } -} // namespace - -namespace net { - MockRead StaticMockSocket::GetNextRead() { return reads_[read_index_++]; } diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index bd7c437..f710820 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -10,10 +10,13 @@ #include "base/basictypes.h" #include "base/logging.h" +#include "base/scoped_ptr.h" #include "net/base/address_list.h" +#include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/base/ssl_config_service.h" #include "net/socket/client_socket_factory.h" +#include "net/socket/ssl_client_socket.h" namespace net { @@ -200,6 +203,88 @@ class MockClientSocketFactory : public ClientSocketFactory { MockSocketArray<MockSSLSocket> mock_ssl_sockets_; }; +class MockClientSocket : public net::SSLClientSocket { + public: + MockClientSocket(); + + // ClientSocket methods: + virtual int Connect(net::CompletionCallback* callback) = 0; + + // SSLClientSocket methods: + virtual void GetSSLInfo(net::SSLInfo* ssl_info); + virtual void GetSSLCertRequestInfo( + net::SSLCertRequestInfo* cert_request_info); + virtual void Disconnect(); + virtual bool IsConnected() const; + virtual bool IsConnectedAndIdle() const; + + // Socket methods: + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) = 0; + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) = 0; + +#if defined(OS_LINUX) + virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen); +#endif + + protected: + void RunCallbackAsync(net::CompletionCallback* callback, int result); + void RunCallback(int result); + + ScopedRunnableMethodFactory<MockClientSocket> method_factory_; + net::CompletionCallback* callback_; + bool connected_; +}; + +class MockTCPClientSocket : public MockClientSocket { + public: + MockTCPClientSocket(const net::AddressList& addresses, + net::MockSocket* socket); + + // ClientSocket methods: + virtual int Connect(net::CompletionCallback* callback); + + // Socket methods: + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + + private: + net::MockSocket* data_; + int read_offset_; + net::MockRead read_data_; + bool need_read_data_; +}; + +class MockSSLClientSocket : public MockClientSocket { + public: + MockSSLClientSocket( + net::ClientSocket* transport_socket, + const std::string& hostname, + const net::SSLConfig& ssl_config, + net::MockSSLSocket* socket); + ~MockSSLClientSocket(); + + virtual void GetSSLInfo(net::SSLInfo* ssl_info); + + virtual int Connect(net::CompletionCallback* callback); + virtual void Disconnect(); + + // Socket methods: + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + + private: + class ConnectCallback; + + scoped_ptr<ClientSocket> transport_; + net::MockSSLSocket* data_; +}; + } // namespace net #endif // NET_SOCKET_SOCKET_TEST_UTIL_H_ diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc index 1bcf80c..0ccaeb6 100644 --- a/net/socket/socks_client_socket.cc +++ b/net/socket/socks_client_socket.cc @@ -72,9 +72,6 @@ SOCKSClientSocket::SOCKSClientSocket(ClientSocket* transport_socket, next_state_(STATE_NONE), socks_version_(kSOCKS4Unresolved), user_callback_(NULL), - handshake_buf_len_(0), - buffer_(NULL), - buffer_len_(0), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), @@ -233,65 +230,60 @@ int SOCKSClientSocket::DoResolveHostComplete(int result) { // Builds the buffer that is to be sent to the server. // We check whether the SOCKS proxy is 4 or 4A. // In case it is 4A, the record size increases by size of the hostname. -void SOCKSClientSocket::BuildHandshakeWriteBuffer() { +const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const { DCHECK_NE(kSOCKS4Unresolved, socks_version_); - int record_size = kWriteHeaderSize + arraysize(kEmptyUserId); - if (socks_version_ == kSOCKS4a) { - record_size += host_request_info_.hostname().size() + 1; - } - - buffer_len_ = record_size; - buffer_.reset(new char[buffer_len_]); - - SOCKS4ServerRequest* request = - reinterpret_cast<SOCKS4ServerRequest*>(buffer_.get()); - - request->version = kSOCKSVersion4; - request->command = kSOCKSStreamRequest; - request->nw_port = htons(host_request_info_.port()); + SOCKS4ServerRequest request; + request.version = kSOCKSVersion4; + request.command = kSOCKSStreamRequest; + request.nw_port = htons(host_request_info_.port()); if (socks_version_ == kSOCKS4) { const struct addrinfo* ai = addresses_.head(); DCHECK(ai); // If the sockaddr is IPv6, we have already marked the version to socks4a // and so this step does not get hit. - struct sockaddr_in *ipv4_host = + struct sockaddr_in* ipv4_host = reinterpret_cast<struct sockaddr_in*>(ai->ai_addr); - memcpy(&request->ip, &(ipv4_host->sin_addr), sizeof(ipv4_host->sin_addr)); + memcpy(&request.ip, &(ipv4_host->sin_addr), sizeof(ipv4_host->sin_addr)); DLOG(INFO) << "Resolved Host is : " << NetAddressToString(ai); } else if (socks_version_ == kSOCKS4a) { // invalid IP of the form 0.0.0.127 - memcpy(&request->ip, kInvalidIp, arraysize(kInvalidIp)); + memcpy(&request.ip, kInvalidIp, arraysize(kInvalidIp)); } else { NOTREACHED(); } - memcpy(&buffer_[kWriteHeaderSize], kEmptyUserId, arraysize(kEmptyUserId)); + std::string handshake_data(reinterpret_cast<char*>(&request), + sizeof(request)); + handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId)); + // In case we are passing the domain also, pass the hostname + // terminated with a null character. if (socks_version_ == kSOCKS4a) { - memcpy(&buffer_[kWriteHeaderSize + arraysize(kEmptyUserId)], - host_request_info_.hostname().c_str(), - host_request_info_.hostname().size() + 1); + handshake_data.append(host_request_info_.hostname()); + handshake_data.push_back('\0'); } + + return handshake_data; } // Writes the SOCKS handshake data to the underlying socket connection. int SOCKSClientSocket::DoHandshakeWrite() { next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; - if (!buffer_.get()) { - BuildHandshakeWriteBuffer(); + if (buffer_.empty()) { + buffer_ = BuildHandshakeWriteBuffer(); bytes_sent_ = 0; } - handshake_buf_len_ = buffer_len_ - bytes_sent_; - DCHECK_GT(handshake_buf_len_, 0); - handshake_buf_ = new IOBuffer(handshake_buf_len_); - memcpy(handshake_buf_.get()->data(), &buffer_[bytes_sent_], - handshake_buf_len_); - return transport_->Write(handshake_buf_, handshake_buf_len_, &io_callback_); + int handshake_buf_len = buffer_.size() - bytes_sent_; + DCHECK_GT(handshake_buf_len, 0); + handshake_buf_ = new IOBuffer(handshake_buf_len); + memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], + handshake_buf_len); + return transport_->Write(handshake_buf_, handshake_buf_len, &io_callback_); } int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { @@ -300,11 +292,14 @@ int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { if (result < 0) return result; + // We ignore the case when result is 0, since the underlying Write + // may return spurious writes while waiting on the socket. + bytes_sent_ += result; - if (bytes_sent_ == buffer_len_) { + if (bytes_sent_ == buffer_.size()) { next_state_ = STATE_HANDSHAKE_READ; - buffer_.reset(NULL); - } else if (bytes_sent_ < buffer_len_) { + buffer_.clear(); + } else if (bytes_sent_ < buffer_.size()) { next_state_ = STATE_HANDSHAKE_WRITE; } else { return ERR_UNEXPECTED; @@ -318,15 +313,13 @@ int SOCKSClientSocket::DoHandshakeRead() { next_state_ = STATE_HANDSHAKE_READ_COMPLETE; - if (!buffer_.get()) { - buffer_.reset(new char[kReadHeaderSize]); - buffer_len_ = kReadHeaderSize; + if (buffer_.empty()) { bytes_received_ = 0; } - handshake_buf_len_ = buffer_len_ - bytes_received_; - handshake_buf_ = new IOBuffer(handshake_buf_len_); - return transport_->Read(handshake_buf_, handshake_buf_len_, &io_callback_); + int handshake_buf_len = kReadHeaderSize - bytes_received_; + handshake_buf_ = new IOBuffer(handshake_buf_len); + return transport_->Read(handshake_buf_, handshake_buf_len, &io_callback_); } int SOCKSClientSocket::DoHandshakeReadComplete(int result) { @@ -334,18 +327,23 @@ int SOCKSClientSocket::DoHandshakeReadComplete(int result) { if (result < 0) return result; - if (bytes_received_ + result > buffer_len_) + + // The underlying socket closed unexpectedly. + if (result == 0) + return ERR_CONNECTION_CLOSED; + + if (bytes_received_ + result > kReadHeaderSize) return ERR_INVALID_RESPONSE; - memcpy(buffer_.get() + bytes_received_, handshake_buf_->data(), result); + buffer_.append(handshake_buf_->data(), result); bytes_received_ += result; - if (bytes_received_ < buffer_len_) { + if (bytes_received_ < kReadHeaderSize) { next_state_ = STATE_HANDSHAKE_READ; return OK; } - SOCKS4ServerResponse* response = - reinterpret_cast<SOCKS4ServerResponse*>(buffer_.get()); + const SOCKS4ServerResponse* response = + reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data()); if (response->reserved_null != 0x00) { LOG(ERROR) << "Unknown response from SOCKS server."; diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h index 03925ba..295000a 100644 --- a/net/socket/socks_client_socket.h +++ b/net/socket/socks_client_socket.h @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef NET_BASE_SOCKS_CLIENT_SOCKET_H_ -#define NET_BASE_SOCKS_CLIENT_SOCKET_H_ +#ifndef NET_SOCKET_SOCKS_CLIENT_SOCKET_H_ +#define NET_SOCKET_SOCKS_CLIENT_SOCKET_H_ #include <string> @@ -54,6 +54,10 @@ class SOCKSClientSocket : public ClientSocket { #endif private: + FRIEND_TEST(SOCKSClientSocketTest, CompleteHandshake); + FRIEND_TEST(SOCKSClientSocketTest, SOCKS4AFailedDNS); + FRIEND_TEST(SOCKSClientSocketTest, SOCKS4AIfDomainInIPv6); + enum State { STATE_RESOLVE_HOST, STATE_RESOLVE_HOST_COMPLETE, @@ -85,7 +89,7 @@ class SOCKSClientSocket : public ClientSocket { int DoHandshakeWrite(); int DoHandshakeWriteComplete(int result); - void BuildHandshakeWriteBuffer(); + const std::string BuildHandshakeWriteBuffer() const; CompletionCallbackImpl<SOCKSClientSocket> io_callback_; @@ -102,20 +106,18 @@ class SOCKSClientSocket : public ClientSocket { // SOCKS handshake data. The length contains the expected size to // read or write. scoped_refptr<IOBuffer> handshake_buf_; - int handshake_buf_len_; // While writing, this buffer stores the complete write handshake data. // While reading, it stores the handshake information received so far. - scoped_array<char> buffer_; - int buffer_len_; + std::string buffer_; // This becomes true when the SOCKS handshake has completed and the // overlying connection is free to communicate. bool completed_handshake_; // These contain the bytes sent / received by the SOCKS handshake. - int bytes_sent_; - int bytes_received_; + size_t bytes_sent_; + size_t bytes_received_; // Used to resolve the hostname to which the SOCKS proxy will connect. SingleRequestHostResolver resolver_; @@ -127,5 +129,5 @@ class SOCKSClientSocket : public ClientSocket { } // namespace net -#endif // NET_BASE_SOCKS_CLIENT_SOCKET_H_ +#endif // NET_SOCKET_SOCKS_CLIENT_SOCKET_H_ diff --git a/net/socket/socks_client_socket_unittest.cc b/net/socket/socks_client_socket_unittest.cc new file mode 100644 index 0000000..813a48e --- /dev/null +++ b/net/socket/socks_client_socket_unittest.cc @@ -0,0 +1,291 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socks_client_socket.h" + +#include "net/base/address_list.h" +#include "net/base/host_resolver_unittest.h" +#include "net/base/listen_socket.h" +#include "net/base/test_completion_callback.h" +#include "net/base/winsock_init.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/tcp_client_socket.h" +#include "net/socket/socket_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +//----------------------------------------------------------------------------- + +namespace net { + +const char kSOCKSOkRequest[] = { 0x04, 0x01, 0x00, 0x50, 127, 0, 0, 1, 0 }; +const char kSOCKS4aInitialRequest[] = + { 0x04, 0x01, 0x00, 0x50, 0, 0, 0, 127, 0 }; +const char kSOCKSOkReply[] = { 0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; + +class SOCKSClientSocketTest : public PlatformTest { + public: + SOCKSClientSocketTest(); + // Create a SOCKSClientSocket on top of a MockSocket. + SOCKSClientSocket* BuildMockSocket(MockRead reads[], MockWrite writes[], + const std::string& hostname, int port); + virtual void SetUp(); + + protected: + scoped_ptr<SOCKSClientSocket> user_sock_; + AddressList address_list_; + ClientSocket* tcp_sock_; + ScopedHostMapper host_mapper_; + TestCompletionCallback callback_; + scoped_refptr<RuleBasedHostMapper> mapper_; + HostResolver host_resolver_; + scoped_ptr<MockSocket> mock_socket_; + + private: + DISALLOW_COPY_AND_ASSIGN(SOCKSClientSocketTest); +}; + +SOCKSClientSocketTest::SOCKSClientSocketTest() + : host_resolver_(0, 0) { +} + +// Set up platform before every test case +void SOCKSClientSocketTest::SetUp() { + PlatformTest::SetUp(); + + // Resolve the "localhost" AddressList used by the tcp_connection to connect. + HostResolver resolver; + HostResolver::RequestInfo info("localhost", 1080); + int rv = resolver.Resolve(info, &address_list_, NULL, NULL); + ASSERT_EQ(OK, rv); + + // Create a new host mapping for the duration of this test case only. + mapper_ = new RuleBasedHostMapper(); + host_mapper_.Init(mapper_); + mapper_->AddRule("www.google.com", "127.0.0.1"); +} + +SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket( + MockRead reads[], + MockWrite writes[], + const std::string& hostname, + int port) { + + TestCompletionCallback callback; + mock_socket_.reset(new StaticMockSocket(reads, writes)); + tcp_sock_ = new MockTCPClientSocket(address_list_, mock_socket_.get()); + + int rv = tcp_sock_->Connect(&callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(tcp_sock_->IsConnected()); + + return new SOCKSClientSocket(tcp_sock_, + HostResolver::RequestInfo(hostname, port), + &host_resolver_); +} + +// Tests a complete handshake and the disconnection. +TEST_F(SOCKSClientSocketTest, CompleteHandshake) { + const std::string payload_write = "random data"; + const std::string payload_read = "moar random data"; + + MockWrite data_writes[] = { + MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)), + MockWrite(true, payload_write.data(), payload_write.size()) }; + MockRead data_reads[] = { + MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)), + MockRead(true, payload_read.data(), payload_read.size()) }; + + user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); + + // At this state the TCP connection is completed but not the SOCKS handshake. + EXPECT_TRUE(tcp_sock_->IsConnected()); + EXPECT_FALSE(user_sock_->IsConnected()); + + int rv = user_sock_->Connect(&callback_); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(user_sock_->IsConnected()); + rv = callback_.WaitForResult(); + + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); + EXPECT_EQ(SOCKSClientSocket::kSOCKS4, user_sock_->socks_version_); + + scoped_refptr<IOBuffer> buffer = new IOBuffer(payload_write.size()); + memcpy(buffer->data(), payload_write.data(), payload_write.size()); + rv = user_sock_->Write(buffer, payload_write.size(), &callback_); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback_.WaitForResult(); + EXPECT_EQ(payload_write.size(), rv); + + buffer = new IOBuffer(payload_read.size()); + rv = user_sock_->Read(buffer, payload_read.size(), &callback_); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback_.WaitForResult(); + EXPECT_EQ(payload_read.size(), rv); + EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size())); + + user_sock_->Disconnect(); + EXPECT_FALSE(tcp_sock_->IsConnected()); + EXPECT_FALSE(user_sock_->IsConnected()); +} + +// List of responses from the socks server and the errors they should +// throw up are tested here. +TEST_F(SOCKSClientSocketTest, HandshakeFailures) { + const struct { + const char fail_reply[8]; + Error fail_code; + } tests[] = { + // Failure of the server response code + { + { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }, + ERR_INVALID_RESPONSE, + }, + // Failure of the null byte + { + { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 }, + ERR_FAILED, + }, + }; + + //--------------------------------------- + + for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) { + MockWrite data_writes[] = { + MockWrite(false, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; + MockRead data_reads[] = { + MockRead(false, tests[i].fail_reply, arraysize(tests[i].fail_reply)) }; + + user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); + + int rv = user_sock_->Connect(&callback_); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback_.WaitForResult(); + EXPECT_EQ(tests[i].fail_code, rv); + EXPECT_FALSE(user_sock_->IsConnected()); + EXPECT_TRUE(tcp_sock_->IsConnected()); + } +} + +// Tests scenario when the server sends the handshake response in +// more than one packet. +TEST_F(SOCKSClientSocketTest, PartialServerReads) { + const char kSOCKSPartialReply1[] = { 0x00 }; + const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; + + MockWrite data_writes[] = { + MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; + MockRead data_reads[] = { + MockRead(true, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)), + MockRead(true, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) }; + + user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); + + int rv = user_sock_->Connect(&callback_); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback_.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); +} + +// Tests scenario when the client sends the handshake request in +// more than one packet. +TEST_F(SOCKSClientSocketTest, PartialClientWrites) { + const char kSOCKSPartialRequest1[] = { 0x04, 0x01 }; + const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 }; + + MockWrite data_writes[] = { + MockWrite(true, arraysize(kSOCKSPartialRequest1)), + // simulate some empty writes + MockWrite(true, 0), + MockWrite(true, 0), + MockWrite(true, kSOCKSPartialRequest2, + arraysize(kSOCKSPartialRequest2)) }; + MockRead data_reads[] = { + MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; + + user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); + + int rv = user_sock_->Connect(&callback_); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback_.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); +} + +// Tests the case when the server sends a smaller sized handshake data +// and closes the connection. +TEST_F(SOCKSClientSocketTest, FailedSocketRead) { + MockWrite data_writes[] = { + MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; + MockRead data_reads[] = { + MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2), + // close connection unexpectedly + MockRead(false, 0) }; + + user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); + + int rv = user_sock_->Connect(&callback_); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback_.WaitForResult(); + EXPECT_EQ(ERR_CONNECTION_CLOSED, rv); + EXPECT_FALSE(user_sock_->IsConnected()); +} + +// Tries to connect to an unknown DNS and on failure should revert to SOCKS4A. +TEST_F(SOCKSClientSocketTest, SOCKS4AFailedDNS) { + const char hostname[] = "unresolved.ipv4.address"; + + mapper_->AddSimulatedFailure(hostname); + + std::string request(kSOCKS4aInitialRequest, + arraysize(kSOCKS4aInitialRequest)); + request.append(hostname, arraysize(hostname)); + + MockWrite data_writes[] = { + MockWrite(false, request.data(), request.size()) }; + MockRead data_reads[] = { + MockRead(false, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; + + user_sock_.reset(BuildMockSocket(data_reads, data_writes, hostname, 80)); + + int rv = user_sock_->Connect(&callback_); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback_.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); + EXPECT_EQ(SOCKSClientSocket::kSOCKS4a, user_sock_->socks_version_); +} + +// Tries to connect to a domain that resolves to IPv6. +// Should revert to SOCKS4a. +TEST_F(SOCKSClientSocketTest, SOCKS4AIfDomainInIPv6) { + const char hostname[] = "an.ipv6.address"; + + mapper_->AddRule(hostname, "2001:db8:8714:3a90::12"); + + std::string request(kSOCKS4aInitialRequest, + arraysize(kSOCKS4aInitialRequest)); + request.append(hostname, arraysize(hostname)); + + MockWrite data_writes[] = { + MockWrite(false, request.data(), request.size()) }; + MockRead data_reads[] = { + MockRead(false, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; + + user_sock_.reset(BuildMockSocket(data_reads, data_writes, hostname, 80)); + + int rv = user_sock_->Connect(&callback_); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback_.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); + EXPECT_EQ(SOCKSClientSocket::kSOCKS4a, user_sock_->socks_version_); +} + +} // namespace net + |