diff options
-rw-r--r-- | net/socket/stream_socket.h | 2 | ||||
-rw-r--r-- | net/socket/tcp_client_socket_libevent.cc | 10 | ||||
-rw-r--r-- | net/socket/tcp_client_socket_unittest.cc | 10 | ||||
-rw-r--r-- | net/socket/tcp_client_socket_win.cc | 12 |
4 files changed, 29 insertions, 5 deletions
diff --git a/net/socket/stream_socket.h b/net/socket/stream_socket.h index 9a7d153..60a057d 100644 --- a/net/socket/stream_socket.h +++ b/net/socket/stream_socket.h @@ -60,7 +60,7 @@ class NET_EXPORT_PRIVATE StreamSocket : public Socket { virtual int GetPeerAddress(IPEndPoint* address) const = 0; // Copies the local address to |address| and returns a network error code. - // ERR_SOCKET_NOT_CONNECTED will be returned if the socket is not connected. + // ERR_SOCKET_NOT_CONNECTED will be returned if the socket is not bound. virtual int GetLocalAddress(IPEndPoint* address) const = 0; // Gets the NetLog for this socket. diff --git a/net/socket/tcp_client_socket_libevent.cc b/net/socket/tcp_client_socket_libevent.cc index e54eb1e..a3727ff 100644 --- a/net/socket/tcp_client_socket_libevent.cc +++ b/net/socket/tcp_client_socket_libevent.cc @@ -180,7 +180,7 @@ int TCPClientSocketLibevent::Bind(const IPEndPoint& address) { if (!address.ToSockAddr(storage.addr, &storage.addr_len)) return ERR_INVALID_ARGUMENT; - // Create |bound_socket_| and try to bound it to |address|. + // Create |bound_socket_| and try to bind it to |address|. int error = CreateSocket(address.GetFamily(), &bound_socket_); if (error) return MapSystemError(error); @@ -363,6 +363,7 @@ void TCPClientSocketLibevent::Disconnect() { DoDisconnect(); current_address_index_ = -1; + bind_address_.reset(); } void TCPClientSocketLibevent::DoDisconnect() { @@ -718,8 +719,13 @@ int TCPClientSocketLibevent::GetPeerAddress(IPEndPoint* address) const { int TCPClientSocketLibevent::GetLocalAddress(IPEndPoint* address) const { DCHECK(CalledOnValidThread()); DCHECK(address); - if (!IsConnected()) + if (socket_ == kInvalidSocket) { + if (bind_address_.get()) { + *address = *bind_address_; + return OK; + } return ERR_SOCKET_NOT_CONNECTED; + } SockaddrStorage storage; if (getsockname(socket_, storage.addr, &storage.addr_len)) diff --git a/net/socket/tcp_client_socket_unittest.cc b/net/socket/tcp_client_socket_unittest.cc index c2589b8..ce0c535 100644 --- a/net/socket/tcp_client_socket_unittest.cc +++ b/net/socket/tcp_client_socket_unittest.cc @@ -34,6 +34,10 @@ TEST(TCPClientSocketTest, BindLoopbackToLoopback) { EXPECT_EQ(OK, socket.Bind(IPEndPoint(lo_address, 0))); + IPEndPoint local_address_result; + EXPECT_EQ(OK, socket.GetLocalAddress(&local_address_result)); + EXPECT_EQ(lo_address, local_address_result.address()); + TestCompletionCallback connect_callback; EXPECT_EQ(ERR_IO_PENDING, socket.Connect(connect_callback.callback())); @@ -45,6 +49,12 @@ TEST(TCPClientSocketTest, BindLoopbackToLoopback) { ASSERT_EQ(OK, result); EXPECT_EQ(OK, connect_callback.WaitForResult()); + + EXPECT_TRUE(socket.IsConnected()); + socket.Disconnect(); + EXPECT_FALSE(socket.IsConnected()); + EXPECT_EQ(ERR_SOCKET_NOT_CONNECTED, + socket.GetLocalAddress(&local_address_result)); } // Try to bind socket to the loopback interface and connect to an diff --git a/net/socket/tcp_client_socket_win.cc b/net/socket/tcp_client_socket_win.cc index 74fa99e..3fda1bc 100644 --- a/net/socket/tcp_client_socket_win.cc +++ b/net/socket/tcp_client_socket_win.cc @@ -373,7 +373,7 @@ int TCPClientSocketWin::Bind(const IPEndPoint& address) { if (!address.ToSockAddr(storage.addr, &storage.addr_len)) return ERR_INVALID_ARGUMENT; - // Create |bound_socket_| and try to bound it to |address|. + // Create |bound_socket_| and try to bind it to |address|. int error = CreateSocket(address.GetFamily(), &bound_socket_); if (error) return MapSystemError(error); @@ -553,8 +553,11 @@ int TCPClientSocketWin::DoConnectComplete(int result) { } void TCPClientSocketWin::Disconnect() { + DCHECK(CalledOnValidThread()); + DoDisconnect(); current_address_index_ = -1; + bind_address_.reset(); } void TCPClientSocketWin::DoDisconnect() { @@ -646,8 +649,13 @@ int TCPClientSocketWin::GetPeerAddress(IPEndPoint* address) const { int TCPClientSocketWin::GetLocalAddress(IPEndPoint* address) const { DCHECK(CalledOnValidThread()); DCHECK(address); - if (!IsConnected()) + if (socket_ == INVALID_SOCKET) { + if (bind_address_.get()) { + *address = *bind_address_; + return OK; + } return ERR_SOCKET_NOT_CONNECTED; + } struct sockaddr_storage addr_storage; socklen_t addr_len = sizeof(addr_storage); |