summaryrefslogtreecommitdiffstats
path: root/net/socket
diff options
context:
space:
mode:
Diffstat (limited to 'net/socket')
-rw-r--r--net/socket/socks5_client_socket.cc6
-rw-r--r--net/socket/socks_client_socket.cc7
-rw-r--r--net/socket/socks_client_socket_unittest.cc96
3 files changed, 101 insertions, 8 deletions
diff --git a/net/socket/socks5_client_socket.cc b/net/socket/socks5_client_socket.cc
index a709f59..d546231 100644
--- a/net/socket/socks5_client_socket.cc
+++ b/net/socket/socks5_client_socket.cc
@@ -94,6 +94,12 @@ int SOCKS5ClientSocket::Connect(CompletionCallback* callback,
void SOCKS5ClientSocket::Disconnect() {
completed_handshake_ = false;
transport_->Disconnect();
+
+ // Reset other states to make sure they aren't mistakenly used later.
+ // These are the states initialized by Connect().
+ next_state_ = STATE_NONE;
+ user_callback_ = NULL;
+ load_log_ = NULL;
}
bool SOCKS5ClientSocket::IsConnected() const {
diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc
index db70461..4509bfa 100644
--- a/net/socket/socks_client_socket.cc
+++ b/net/socket/socks_client_socket.cc
@@ -107,7 +107,14 @@ int SOCKSClientSocket::Connect(CompletionCallback* callback,
void SOCKSClientSocket::Disconnect() {
completed_handshake_ = false;
+ host_resolver_.Cancel();
transport_->Disconnect();
+
+ // Reset other states to make sure they aren't mistakenly used later.
+ // These are the states initialized by Connect().
+ next_state_ = STATE_NONE;
+ user_callback_ = NULL;
+ load_log_ = NULL;
}
bool SOCKSClientSocket::IsConnected() const {
diff --git a/net/socket/socks_client_socket_unittest.cc b/net/socket/socks_client_socket_unittest.cc
index 6fb7daf..6cf450e 100644
--- a/net/socket/socks_client_socket_unittest.cc
+++ b/net/socket/socks_client_socket_unittest.cc
@@ -30,6 +30,7 @@ class SOCKSClientSocketTest : public PlatformTest {
SOCKSClientSocketTest();
// Create a SOCKSClientSocket on top of a MockSocket.
SOCKSClientSocket* BuildMockSocket(MockRead reads[], MockWrite writes[],
+ HostResolver* host_resolver,
const std::string& hostname, int port);
virtual void SetUp();
@@ -57,6 +58,7 @@ void SOCKSClientSocketTest::SetUp() {
SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket(
MockRead reads[],
MockWrite writes[],
+ HostResolver* host_resolver,
const std::string& hostname,
int port) {
@@ -72,9 +74,48 @@ SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket(
return new SOCKSClientSocket(tcp_sock_,
HostResolver::RequestInfo(hostname, port),
- host_resolver_);
+ host_resolver);
}
+// Implementation of HostResolver that never completes its resolve request.
+// We use this in the test "DisconnectWhileHostResolveInProgress" to make
+// sure that the outstanding resolve request gets cancelled.
+class HangingHostResolver : public HostResolver {
+ public:
+ HangingHostResolver() : outstanding_request_(NULL) {}
+
+ virtual int Resolve(const RequestInfo& info,
+ AddressList* addresses,
+ CompletionCallback* callback,
+ RequestHandle* out_req,
+ LoadLog* load_log) {
+ EXPECT_FALSE(HasOutstandingRequest());
+ outstanding_request_ = reinterpret_cast<RequestHandle>(1);
+ *out_req = outstanding_request_;
+ return ERR_IO_PENDING;
+ }
+
+ virtual void CancelRequest(RequestHandle req) {
+ EXPECT_TRUE(HasOutstandingRequest());
+ EXPECT_EQ(outstanding_request_, req);
+ outstanding_request_ = NULL;
+ }
+
+ virtual void AddObserver(Observer* observer) {}
+ virtual void RemoveObserver(Observer* observer) {}
+ virtual HostCache* GetHostCache() { return NULL; }
+ virtual void Shutdown() {}
+
+ bool HasOutstandingRequest() {
+ return outstanding_request_ != NULL;
+ }
+
+ private:
+ RequestHandle outstanding_request_;
+
+ DISALLOW_COPY_AND_ASSIGN(HangingHostResolver);
+};
+
// Tests a complete handshake and the disconnection.
TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
const std::string payload_write = "random data";
@@ -87,7 +128,8 @@ TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)),
MockRead(true, payload_read.data(), payload_read.size()) };
- user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80));
+ user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_,
+ "localhost", 80));
// At this state the TCP connection is completed but not the SOCKS handshake.
EXPECT_TRUE(tcp_sock_->IsConnected());
@@ -153,7 +195,8 @@ TEST_F(SOCKSClientSocketTest, HandshakeFailures) {
MockRead data_reads[] = {
MockRead(false, tests[i].fail_reply, arraysize(tests[i].fail_reply)) };
- user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80));
+ user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_,
+ "localhost", 80));
scoped_refptr<LoadLog> log(new LoadLog(LoadLog::kUnbounded));
int rv = user_sock_->Connect(&callback_, log);
@@ -181,7 +224,8 @@ TEST_F(SOCKSClientSocketTest, PartialServerReads) {
MockRead(true, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)),
MockRead(true, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) };
- user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80));
+ user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_,
+ "localhost", 80));
scoped_refptr<LoadLog> log(new LoadLog(LoadLog::kUnbounded));
int rv = user_sock_->Connect(&callback_, log);
@@ -211,7 +255,8 @@ TEST_F(SOCKSClientSocketTest, PartialClientWrites) {
MockRead data_reads[] = {
MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
- user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80));
+ user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_,
+ "localhost", 80));
scoped_refptr<LoadLog> log(new LoadLog(LoadLog::kUnbounded));
int rv = user_sock_->Connect(&callback_, log);
@@ -235,7 +280,8 @@ TEST_F(SOCKSClientSocketTest, FailedSocketRead) {
// close connection unexpectedly
MockRead(false, 0) };
- user_sock_.reset(BuildMockSocket(data_reads, data_writes, "localhost", 80));
+ user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_,
+ "localhost", 80));
scoped_refptr<LoadLog> log(new LoadLog(LoadLog::kUnbounded));
int rv = user_sock_->Connect(&callback_, log);
@@ -264,7 +310,8 @@ TEST_F(SOCKSClientSocketTest, SOCKS4AFailedDNS) {
MockRead data_reads[] = {
MockRead(false, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
- user_sock_.reset(BuildMockSocket(data_reads, data_writes, hostname, 80));
+ user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_,
+ hostname, 80));
scoped_refptr<LoadLog> log(new LoadLog(LoadLog::kUnbounded));
int rv = user_sock_->Connect(&callback_, log);
@@ -295,7 +342,8 @@ TEST_F(SOCKSClientSocketTest, SOCKS4AIfDomainInIPv6) {
MockRead data_reads[] = {
MockRead(false, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
- user_sock_.reset(BuildMockSocket(data_reads, data_writes, hostname, 80));
+ user_sock_.reset(BuildMockSocket(data_reads, data_writes, host_resolver_,
+ hostname, 80));
scoped_refptr<LoadLog> log(new LoadLog(LoadLog::kUnbounded));
int rv = user_sock_->Connect(&callback_, log);
@@ -310,4 +358,36 @@ TEST_F(SOCKSClientSocketTest, SOCKS4AIfDomainInIPv6) {
*log, -1, LoadLog::TYPE_SOCKS_CONNECT, LoadLog::PHASE_END));
}
+// Calls Disconnect() while a host resolve is in progress. The outstanding host
+// resolve should be cancelled.
+TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) {
+ scoped_refptr<HangingHostResolver> hanging_resolver =
+ new HangingHostResolver();
+
+ // Doesn't matter what the socket data is, we will never use it -- garbage.
+ MockWrite data_writes[] = { MockWrite(false, "", 0) };
+ MockRead data_reads[] = { MockRead(false, "", 0) };
+
+ user_sock_.reset(BuildMockSocket(data_reads, data_writes, hanging_resolver,
+ "foo", 80));
+
+ // Start connecting (will get stuck waiting for the host to resolve).
+ int rv = user_sock_->Connect(&callback_, NULL);
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ EXPECT_FALSE(user_sock_->IsConnected());
+ EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
+
+ // The host resolver should have received the resolve request.
+ EXPECT_TRUE(hanging_resolver->HasOutstandingRequest());
+
+ // Disconnect the SOCKS socket -- this should cancel the outstanding resolve.
+ user_sock_->Disconnect();
+
+ EXPECT_FALSE(hanging_resolver->HasOutstandingRequest());
+
+ EXPECT_FALSE(user_sock_->IsConnected());
+ EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
+}
+
} // namespace net