diff options
Diffstat (limited to 'net/socket')
-rw-r--r-- | net/socket/socks5_client_socket.cc | 6 | ||||
-rw-r--r-- | net/socket/socks_client_socket.cc | 7 | ||||
-rw-r--r-- | net/socket/socks_client_socket_unittest.cc | 96 |
3 files changed, 101 insertions, 8 deletions
diff --git a/net/socket/socks5_client_socket.cc b/net/socket/socks5_client_socket.cc index a709f59..d546231 100644 --- a/net/socket/socks5_client_socket.cc +++ b/net/socket/socks5_client_socket.cc @@ -94,6 +94,12 @@ int SOCKS5ClientSocket::Connect(CompletionCallback* callback, void SOCKS5ClientSocket::Disconnect() { completed_handshake_ = false; transport_->Disconnect(); + + // Reset other states to make sure they aren't mistakenly used later. + // These are the states initialized by Connect(). + next_state_ = STATE_NONE; + user_callback_ = NULL; + load_log_ = NULL; } bool SOCKS5ClientSocket::IsConnected() const { diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc index db70461..4509bfa 100644 --- a/net/socket/socks_client_socket.cc +++ b/net/socket/socks_client_socket.cc @@ -107,7 +107,14 @@ int SOCKSClientSocket::Connect(CompletionCallback* callback, void SOCKSClientSocket::Disconnect() { completed_handshake_ = false; + host_resolver_.Cancel(); transport_->Disconnect(); + + // Reset other states to make sure they aren't mistakenly used later. + // These are the states initialized by Connect(). + next_state_ = STATE_NONE; + user_callback_ = NULL; + load_log_ = NULL; } bool SOCKSClientSocket::IsConnected() const { diff --git a/net/socket/socks_client_socket_unittest.cc b/net/socket/socks_client_socket_unittest.cc index 6fb7daf..6cf450e 100644 --- a/net/socket/socks_client_socket_unittest.cc +++ b/net/socket/socks_client_socket_unittest.cc @@ -30,6 +30,7 @@ class SOCKSClientSocketTest : public PlatformTest { SOCKSClientSocketTest(); // Create a SOCKSClientSocket on top of a MockSocket. SOCKSClientSocket* BuildMockSocket(MockRead reads[], MockWrite writes[], + HostResolver* host_resolver, const std::string& hostname, int port); virtual void SetUp(); @@ -57,6 +58,7 @@ void SOCKSClientSocketTest::SetUp() { SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket( MockRead reads[], MockWrite writes[], + HostResolver* host_resolver, const std::string& hostname, int port) { @@ -72,9 +74,48 @@ SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket( return new SOCKSClientSocket(tcp_sock_, HostResolver::RequestInfo(hostname, port), - host_resolver_); + host_resolver); } +// Implementation of HostResolver that never completes its resolve request. +// We use this in the test "DisconnectWhileHostResolveInProgress" to make +// sure that the outstanding resolve request gets cancelled. +class HangingHostResolver : public HostResolver { + public: + HangingHostResolver() : outstanding_request_(NULL) {} + + virtual int Resolve(const RequestInfo& info, + AddressList* addresses, + CompletionCallback* callback, + RequestHandle* out_req, + LoadLog* load_log) { + EXPECT_FALSE(HasOutstandingRequest()); + outstanding_request_ = reinterpret_cast<RequestHandle>(1); + *out_req = outstanding_request_; + return ERR_IO_PENDING; + } + + virtual void CancelRequest(RequestHandle req) { + EXPECT_TRUE(HasOutstandingRequest()); + EXPECT_EQ(outstanding_request_, req); + outstanding_request_ = NULL; + } + + virtual void AddObserver(Observer* observer) {} + virtual void RemoveObserver(Observer* observer) {} + virtual HostCache* GetHostCache() { return NULL; } + virtual void Shutdown() {} + + bool HasOutstandingRequest() { + return outstanding_request_ != NULL; + } + + private: + RequestHandle outstanding_request_; + + DISALLOW_COPY_AND_ASSIGN(HangingHostResolver); +}; + // Tests a complete handshake and the disconnection. TEST_F(SOCKSClientSocketTest, CompleteHandshake) { const std::string payload_write = "random data"; @@ -87,7 +128,8 @@ TEST_F(SOCKSClientSocketTest, CompleteHandshake) { MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)), MockRead(true, payload_read.data(), payload_read.size()) }; - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); + user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_, + "localhost", 80)); // At this state the TCP connection is completed but not the SOCKS handshake. EXPECT_TRUE(tcp_sock_->IsConnected()); @@ -153,7 +195,8 @@ TEST_F(SOCKSClientSocketTest, HandshakeFailures) { MockRead data_reads[] = { MockRead(false, tests[i].fail_reply, arraysize(tests[i].fail_reply)) }; - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); + user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_, + "localhost", 80)); scoped_refptr<LoadLog> log(new LoadLog(LoadLog::kUnbounded)); int rv = user_sock_->Connect(&callback_, log); @@ -181,7 +224,8 @@ TEST_F(SOCKSClientSocketTest, PartialServerReads) { MockRead(true, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)), MockRead(true, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) }; - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); + user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_, + "localhost", 80)); scoped_refptr<LoadLog> log(new LoadLog(LoadLog::kUnbounded)); int rv = user_sock_->Connect(&callback_, log); @@ -211,7 +255,8 @@ TEST_F(SOCKSClientSocketTest, PartialClientWrites) { MockRead data_reads[] = { MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); + user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_, + "localhost", 80)); scoped_refptr<LoadLog> log(new LoadLog(LoadLog::kUnbounded)); int rv = user_sock_->Connect(&callback_, log); @@ -235,7 +280,8 @@ TEST_F(SOCKSClientSocketTest, FailedSocketRead) { // close connection unexpectedly MockRead(false, 0) }; - user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80)); + user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_, + "localhost", 80)); scoped_refptr<LoadLog> log(new LoadLog(LoadLog::kUnbounded)); int rv = user_sock_->Connect(&callback_, log); @@ -264,7 +310,8 @@ TEST_F(SOCKSClientSocketTest, SOCKS4AFailedDNS) { MockRead data_reads[] = { MockRead(false, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; - user_sock_.reset(BuildMockSocket(data_reads, data_writes, hostname, 80)); + user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_, + hostname, 80)); scoped_refptr<LoadLog> log(new LoadLog(LoadLog::kUnbounded)); int rv = user_sock_->Connect(&callback_, log); @@ -295,7 +342,8 @@ TEST_F(SOCKSClientSocketTest, SOCKS4AIfDomainInIPv6) { MockRead data_reads[] = { MockRead(false, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; - user_sock_.reset(BuildMockSocket(data_reads, data_writes, hostname, 80)); + user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_, + hostname, 80)); scoped_refptr<LoadLog> log(new LoadLog(LoadLog::kUnbounded)); int rv = user_sock_->Connect(&callback_, log); @@ -310,4 +358,36 @@ TEST_F(SOCKSClientSocketTest, SOCKS4AIfDomainInIPv6) { *log, -1, LoadLog::TYPE_SOCKS_CONNECT, LoadLog::PHASE_END)); } +// Calls Disconnect() while a host resolve is in progress. The outstanding host +// resolve should be cancelled. +TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) { + scoped_refptr<HangingHostResolver> hanging_resolver = + new HangingHostResolver(); + + // Doesn't matter what the socket data is, we will never use it -- garbage. + MockWrite data_writes[] = { MockWrite(false, "", 0) }; + MockRead data_reads[] = { MockRead(false, "", 0) }; + + user_sock_.reset(BuildMockSocket(data_reads, data_writes, hanging_resolver, + "foo", 80)); + + // Start connecting (will get stuck waiting for the host to resolve). + int rv = user_sock_->Connect(&callback_, NULL); + EXPECT_EQ(ERR_IO_PENDING, rv); + + EXPECT_FALSE(user_sock_->IsConnected()); + EXPECT_FALSE(user_sock_->IsConnectedAndIdle()); + + // The host resolver should have received the resolve request. + EXPECT_TRUE(hanging_resolver->HasOutstandingRequest()); + + // Disconnect the SOCKS socket -- this should cancel the outstanding resolve. + user_sock_->Disconnect(); + + EXPECT_FALSE(hanging_resolver->HasOutstandingRequest()); + + EXPECT_FALSE(user_sock_->IsConnected()); + EXPECT_FALSE(user_sock_->IsConnectedAndIdle()); +} + } // namespace net |