diff options
author | phajdan.jr@chromium.org <phajdan.jr@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2010-02-20 18:50:38 +0000 |
---|---|---|
committer | phajdan.jr@chromium.org <phajdan.jr@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2010-02-20 18:50:38 +0000 |
commit | ac9eec64de86f3d3a290a1a8b9321260cff7ed23 (patch) | |
tree | aac041c6ddaec400b6e2b6d3d982935aa7f69a9c /net | |
parent | 8c1ae5ec4d47638315096f54819793484383c91f (diff) | |
download | chromium_src-ac9eec64de86f3d3a290a1a8b9321260cff7ed23.zip chromium_src-ac9eec64de86f3d3a290a1a8b9321260cff7ed23.tar.gz chromium_src-ac9eec64de86f3d3a290a1a8b9321260cff7ed23.tar.bz2 |
Really connect to the same server in FTP network transaction.
Also create necessary infrastructure to know the address
a client socket is connected to.
TEST=Covered by net_unittests.
BUG=35670
Review URL: http://codereview.chromium.org/598071
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@39559 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
28 files changed, 218 insertions, 122 deletions
diff --git a/net/base/address_list.cc b/net/base/address_list.cc index 93ec009..d1624be 100644 --- a/net/base/address_list.cc +++ b/net/base/address_list.cc @@ -13,9 +13,12 @@ namespace net { namespace { -// Make a deep copy of |info|. This copy should be deleted using +// Make a copy of |info| (the dynamically-allocated parts are copied as well). +// If |recursive| is true, chained entries via ai_next are copied too. +// Copy returned by this function should be deleted using // DeleteCopyOfAddrinfo(), and NOT freeaddrinfo(). -struct addrinfo* CreateCopyOfAddrinfo(const struct addrinfo* info) { +struct addrinfo* CreateCopyOfAddrinfo(const struct addrinfo* info, + bool recursive) { struct addrinfo* copy = new addrinfo; // Copy all the fields (some of these are pointers, we will fix that next). @@ -37,8 +40,10 @@ struct addrinfo* CreateCopyOfAddrinfo(const struct addrinfo* info) { } // Recursive copy. - if (info->ai_next) - copy->ai_next = CreateCopyOfAddrinfo(info->ai_next); + if (recursive && info->ai_next) + copy->ai_next = CreateCopyOfAddrinfo(info->ai_next, recursive); + else + copy->ai_next = NULL; return copy; } @@ -81,7 +86,8 @@ uint16* GetPortField(const struct addrinfo* info) { // Assign the port for all addresses in the list. void SetPortRecursive(struct addrinfo* info, int port) { uint16* port_field = GetPortField(info); - *port_field = htons(port); + if (port_field) + *port_field = htons(port); // Assign recursively. if (info->ai_next) @@ -94,8 +100,25 @@ void AddressList::Adopt(struct addrinfo* head) { data_ = new Data(head, true /*is_system_created*/); } -void AddressList::Copy(const struct addrinfo* head) { - data_ = new Data(CreateCopyOfAddrinfo(head), false /*is_system_created*/); +void AddressList::Copy(const struct addrinfo* head, bool recursive) { + data_ = new Data(CreateCopyOfAddrinfo(head, recursive), + false /*is_system_created*/); +} + +void AddressList::Append(const struct addrinfo* head) { + struct addrinfo* new_head; + if (data_->is_system_created) { + new_head = CreateCopyOfAddrinfo(data_->head, true); + data_ = new Data(new_head, false /*is_system_created*/); + } else { + new_head = data_->head; + } + + // Find the end of current linked list and append new data there. + struct addrinfo* copy_ptr = new_head; + while (copy_ptr->ai_next) + copy_ptr = copy_ptr->ai_next; + copy_ptr->ai_next = CreateCopyOfAddrinfo(head, true); } void AddressList::SetPort(int port) { @@ -104,6 +127,9 @@ void AddressList::SetPort(int port) { int AddressList::GetPort() const { uint16* port_field = GetPortField(data_->head); + if (!port_field) + return -1; + return ntohs(*port_field); } @@ -113,7 +139,7 @@ void AddressList::SetFrom(const AddressList& src, int port) { *this = src; } else { // Otherwise we need to make a copy in order to change the port number. - Copy(src.head()); + Copy(src.head(), true); SetPort(port); } } diff --git a/net/base/address_list.h b/net/base/address_list.h index 3087472..b477987 100644 --- a/net/base/address_list.h +++ b/net/base/address_list.h @@ -23,8 +23,14 @@ class AddressList { // object. void Adopt(struct addrinfo* head); - // Copies the given addrinfo rather than adopting it. - void Copy(const struct addrinfo* head); + // Copies the given addrinfo rather than adopting it. If |recursive| is true, + // all linked struct addrinfos will be copied as well. Otherwise only the head + // will be copied, and the rest of linked entries will be ignored. + void Copy(const struct addrinfo* head, bool recursive); + + // Appends a copy of |head| and all its linked addrinfos to the stored + // addrinfo. + void Append(const struct addrinfo* head); // Sets the port of all addresses in the list to |port| (that is the // sin[6]_port field for the sockaddrs). diff --git a/net/base/address_list_unittest.cc b/net/base/address_list_unittest.cc index d440c15..dbd0f24 100644 --- a/net/base/address_list_unittest.cc +++ b/net/base/address_list_unittest.cc @@ -7,6 +7,7 @@ #include "base/string_util.h" #include "net/base/host_resolver_proc.h" #include "net/base/net_util.h" +#include "net/base/sys_addrinfo.h" #if defined(OS_WIN) #include "net/base/winsock_init.h" #endif @@ -15,20 +16,28 @@ namespace { // Use getaddrinfo() to allocate an addrinfo structure. -void CreateAddressList(net::AddressList* addrlist, int port) { +void CreateAddressList(const std::string& hostname, + net::AddressList* addrlist, int port) { #if defined(OS_WIN) net::EnsureWinsockInit(); #endif - int rv = SystemHostResolverProc("192.168.1.1", + int rv = SystemHostResolverProc(hostname, net::ADDRESS_FAMILY_UNSPECIFIED, addrlist); EXPECT_EQ(0, rv); addrlist->SetPort(port); } +void CreateLongAddressList(net::AddressList* addrlist, int port) { + CreateAddressList("192.168.1.1", addrlist, port); + net::AddressList second_list; + CreateAddressList("192.168.1.2", &second_list, port); + addrlist->Append(second_list.head()); +} + TEST(AddressListTest, GetPort) { net::AddressList addrlist; - CreateAddressList(&addrlist, 81); + CreateAddressList("192.168.1.1", &addrlist, 81); EXPECT_EQ(81, addrlist.GetPort()); addrlist.SetPort(83); @@ -37,7 +46,7 @@ TEST(AddressListTest, GetPort) { TEST(AddressListTest, Assignment) { net::AddressList addrlist1; - CreateAddressList(&addrlist1, 85); + CreateAddressList("192.168.1.1", &addrlist1, 85); EXPECT_EQ(85, addrlist1.GetPort()); // Should reference the same data as addrlist1 -- so when we change addrlist1 @@ -50,13 +59,15 @@ TEST(AddressListTest, Assignment) { EXPECT_EQ(80, addrlist2.GetPort()); } -TEST(AddressListTest, Copy) { +TEST(AddressListTest, CopyRecursive) { net::AddressList addrlist1; - CreateAddressList(&addrlist1, 85); + CreateLongAddressList(&addrlist1, 85); EXPECT_EQ(85, addrlist1.GetPort()); net::AddressList addrlist2; - addrlist2.Copy(addrlist1.head()); + addrlist2.Copy(addrlist1.head(), true); + + ASSERT_TRUE(addrlist2.head()->ai_next != NULL); // addrlist1 is the same as addrlist2 at this point. EXPECT_EQ(85, addrlist1.GetPort()); @@ -70,4 +81,43 @@ TEST(AddressListTest, Copy) { EXPECT_EQ(90, addrlist2.GetPort()); } +TEST(AddressListTest, CopyNonRecursive) { + net::AddressList addrlist1; + CreateLongAddressList(&addrlist1, 85); + EXPECT_EQ(85, addrlist1.GetPort()); + + net::AddressList addrlist2; + addrlist2.Copy(addrlist1.head(), false); + + ASSERT_TRUE(addrlist2.head()->ai_next == NULL); + + // addrlist1 is the same as addrlist2 at this point. + EXPECT_EQ(85, addrlist1.GetPort()); + EXPECT_EQ(85, addrlist2.GetPort()); + + // Changes to addrlist1 are not reflected in addrlist2. + addrlist1.SetPort(70); + addrlist2.SetPort(90); + + EXPECT_EQ(70, addrlist1.GetPort()); + EXPECT_EQ(90, addrlist2.GetPort()); +} + +TEST(AddressListTest, Append) { + net::AddressList addrlist1; + CreateAddressList("192.168.1.1", &addrlist1, 11); + EXPECT_EQ(11, addrlist1.GetPort()); + net::AddressList addrlist2; + CreateAddressList("192.168.1.2", &addrlist2, 12); + EXPECT_EQ(12, addrlist2.GetPort()); + + ASSERT_TRUE(addrlist1.head()->ai_next == NULL); + addrlist1.Append(addrlist2.head()); + ASSERT_TRUE(addrlist1.head()->ai_next != NULL); + + net::AddressList addrlist3; + addrlist3.Copy(addrlist1.head()->ai_next, false); + EXPECT_EQ(12, addrlist3.GetPort()); +} + } // namespace diff --git a/net/base/nss_memio.c b/net/base/nss_memio.c index 341cfee..6796882 100644 --- a/net/base/nss_memio.c +++ b/net/base/nss_memio.c @@ -359,11 +359,17 @@ PRFileDesc *memio_CreateIOLayer(int bufsize) return fd; } -void memio_SetPeerName(PRFileDesc *fd, const PRNetAddr *peername) +void memio_SetPeerName(PRFileDesc *fd, const struct sockaddr *peername, + size_t peername_len) { PRFileDesc *memiofd = PR_GetIdentitiesLayer(fd, memio_identity); struct PRFilePrivate *secret = memiofd->secret; - secret->peername = *peername; + size_t len; + + memset(&secret->peername, 0, sizeof(secret->peername)); + PR_ASSERT(peername_len <= sizeof(secret->peername)); + len = PR_MIN(peername_len, sizeof(secret->peername)); + memcpy(&secret->peername, peername, len); } memio_Private *memio_GetSecret(PRFileDesc *fd) diff --git a/net/base/nss_memio.h b/net/base/nss_memio.h index 0bee53e..a9e6e22 100644 --- a/net/base/nss_memio.h +++ b/net/base/nss_memio.h @@ -6,12 +6,16 @@ #ifndef __MEMIO_H #define __MEMIO_H +#include <stddef.h> + #ifdef __cplusplus extern "C" { #endif #include "prio.h" +struct sockaddr; + /* Opaque structure. Really just a more typesafe alias for PRFilePrivate. */ struct memio_Private; typedef struct memio_Private memio_Private; @@ -38,7 +42,8 @@ typedef struct memio_Private memio_Private; PRFileDesc *memio_CreateIOLayer(int bufsize); /* Must call before trying to make an ssl connection */ -void memio_SetPeerName(PRFileDesc *fd, const PRNetAddr *peername); +void memio_SetPeerName(PRFileDesc *fd, const struct sockaddr *peername, + size_t peername_len); /* Return a private pointer needed by the following * four functions. (We could have passed a PRFileDesc to diff --git a/net/ftp/ftp_network_transaction.cc b/net/ftp/ftp_network_transaction.cc index 2bfbda0..5448c34 100644 --- a/net/ftp/ftp_network_transaction.cc +++ b/net/ftp/ftp_network_transaction.cc @@ -1104,14 +1104,14 @@ int FtpNetworkTransaction::ProcessResponseQUIT( int FtpNetworkTransaction::DoDataConnect() { next_state_ = STATE_DATA_CONNECT_COMPLETE; - AddressList data_addresses; - // TODO(phajdan.jr): Use exactly same IP address as the control socket. - // If the DNS name resolves to several different IPs, and they are different - // physical servers, this will break. However, that configuration is very rare - // in practice. - data_addresses.Copy(addresses_.head()); - data_addresses.SetPort(data_connection_port_); - data_socket_.reset(socket_factory_->CreateTCPClientSocket(data_addresses)); + AddressList data_address; + // Connect to the same host as the control socket to prevent PASV port + // scanning attacks. + int rv = ctrl_socket_->GetPeerAddress(&data_address); + if (rv != OK) + return Stop(rv); + data_address.SetPort(data_connection_port_); + data_socket_.reset(socket_factory_->CreateTCPClientSocket(data_address)); return data_socket_->Connect(&io_callback_, load_log_); } diff --git a/net/ftp/ftp_network_transaction_unittest.cc b/net/ftp/ftp_network_transaction_unittest.cc index d9be57c..adf074e 100644 --- a/net/ftp/ftp_network_transaction_unittest.cc +++ b/net/ftp/ftp_network_transaction_unittest.cc @@ -610,14 +610,14 @@ class FtpNetworkTransactionTest : public PlatformTest { ASSERT_EQ(ERR_IO_PENDING, transaction_.Start(&request_info, &callback_, NULL)); EXPECT_NE(LOAD_STATE_IDLE, transaction_.GetLoadState()); - EXPECT_EQ(expected_result, callback_.WaitForResult()); + ASSERT_EQ(expected_result, callback_.WaitForResult()); EXPECT_EQ(FtpSocketDataProvider::QUIT, ctrl_socket->state()); if (expected_result == OK) { scoped_refptr<IOBuffer> io_buffer(new IOBuffer(kBufferSize)); memset(io_buffer->data(), 0, kBufferSize); ASSERT_EQ(ERR_IO_PENDING, transaction_.Read(io_buffer.get(), kBufferSize, &callback_)); - EXPECT_EQ(static_cast<int>(mock_data.length()), + ASSERT_EQ(static_cast<int>(mock_data.length()), callback_.WaitForResult()); EXPECT_EQ(mock_data, std::string(io_buffer->data(), mock_data.length())); if (transaction_.GetResponseInfo()->is_directory_listing) { @@ -653,7 +653,7 @@ TEST_F(FtpNetworkTransactionTest, FailedLookup) { EXPECT_EQ(LOAD_STATE_IDLE, transaction_.GetLoadState()); ASSERT_EQ(ERR_IO_PENDING, transaction_.Start(&request_info, &callback_, NULL)); - EXPECT_EQ(ERR_NAME_NOT_RESOLVED, callback_.WaitForResult()); + ASSERT_EQ(ERR_NAME_NOT_RESOLVED, callback_.WaitForResult()); EXPECT_EQ(LOAD_STATE_IDLE, transaction_.GetLoadState()); } @@ -748,7 +748,7 @@ TEST_F(FtpNetworkTransactionTest, DownloadTransactionAcceptedDataConnection) { // Start the transaction. ASSERT_EQ(ERR_IO_PENDING, transaction_.Start(&request_info, &callback_, NULL)); - EXPECT_EQ(OK, callback_.WaitForResult()); + ASSERT_EQ(OK, callback_.WaitForResult()); // The transaction fires the callback when we can start reading data. EXPECT_EQ(FtpSocketDataProvider::PRE_QUIT, ctrl_socket.state()); @@ -758,7 +758,7 @@ TEST_F(FtpNetworkTransactionTest, DownloadTransactionAcceptedDataConnection) { ASSERT_EQ(ERR_IO_PENDING, transaction_.Read(io_buffer.get(), kBufferSize, &callback_)); EXPECT_EQ(LOAD_STATE_READING_RESPONSE, transaction_.GetLoadState()); - EXPECT_EQ(static_cast<int>(mock_data.length()), + ASSERT_EQ(static_cast<int>(mock_data.length()), callback_.WaitForResult()); EXPECT_EQ(LOAD_STATE_READING_RESPONSE, transaction_.GetLoadState()); EXPECT_EQ(mock_data, std::string(io_buffer->data(), mock_data.length())); @@ -775,7 +775,7 @@ TEST_F(FtpNetworkTransactionTest, DownloadTransactionAcceptedDataConnection) { // Make sure the transaction finishes cleanly. EXPECT_EQ(LOAD_STATE_IDLE, transaction_.GetLoadState()); - EXPECT_EQ(OK, callback_.WaitForResult()); + ASSERT_EQ(OK, callback_.WaitForResult()); EXPECT_EQ(FtpSocketDataProvider::QUIT, ctrl_socket.state()); EXPECT_EQ(LOAD_STATE_IDLE, transaction_.GetLoadState()); } @@ -833,7 +833,7 @@ TEST_F(FtpNetworkTransactionTest, DownloadTransactionEvilPasvUnsafeHost) { // Start the transaction. ASSERT_EQ(ERR_IO_PENDING, transaction_.Start(&request_info, &callback_, NULL)); - EXPECT_EQ(OK, callback_.WaitForResult()); + ASSERT_EQ(OK, callback_.WaitForResult()); // The transaction fires the callback when we can start reading data. That // means that the data socket should be open. @@ -843,11 +843,13 @@ TEST_F(FtpNetworkTransactionTest, DownloadTransactionEvilPasvUnsafeHost) { ASSERT_TRUE(data_socket->IsConnected()); // Even if the PASV response specified some other address, we connect - // to the address we used for control connection. - EXPECT_EQ("127.0.0.1", NetAddressToString(data_socket->addresses().head())); - - // Make sure we have only one host entry in the AddressList. - EXPECT_FALSE(data_socket->addresses().head()->ai_next); + // to the address we used for control connection (which could be 127.0.0.1 + // or ::1 depending on whether we use IPv6). + const struct addrinfo* addrinfo = data_socket->addresses().head(); + while (addrinfo) { + EXPECT_NE("1.2.3.4", NetAddressToString(addrinfo)); + addrinfo = addrinfo->ai_next; + } } TEST_F(FtpNetworkTransactionTest, DownloadTransactionEvilLoginBadUsername) { @@ -881,7 +883,7 @@ TEST_F(FtpNetworkTransactionTest, EvilRestartUser) { ASSERT_EQ(ERR_IO_PENDING, transaction_.Start(&request_info, &callback_, NULL)); - EXPECT_EQ(ERR_FAILED, callback_.WaitForResult()); + ASSERT_EQ(ERR_FAILED, callback_.WaitForResult()); MockRead ctrl_reads[] = { MockRead("220 host TestFTPd\r\n"), @@ -911,7 +913,7 @@ TEST_F(FtpNetworkTransactionTest, EvilRestartPassword) { ASSERT_EQ(ERR_IO_PENDING, transaction_.Start(&request_info, &callback_, NULL)); - EXPECT_EQ(ERR_FAILED, callback_.WaitForResult()); + ASSERT_EQ(ERR_FAILED, callback_.WaitForResult()); MockRead ctrl_reads[] = { MockRead("220 host TestFTPd\r\n"), diff --git a/net/proxy/proxy_resolver_js_bindings_unittest.cc b/net/proxy/proxy_resolver_js_bindings_unittest.cc index 5035f3e..9231f93 100644 --- a/net/proxy/proxy_resolver_js_bindings_unittest.cc +++ b/net/proxy/proxy_resolver_js_bindings_unittest.cc @@ -67,7 +67,7 @@ class MockHostResolverWithMultipleResults : public HostResolver { // Make a copy of the concatenated list. AddressList concatenated; - concatenated.Copy(result.head()); + concatenated.Copy(result.head(), true); // Restore |result| (so it is freed properly). result_head->ai_next = NULL; diff --git a/net/socket/client_socket.h b/net/socket/client_socket.h index 28c7b4d..e696554 100644 --- a/net/socket/client_socket.h +++ b/net/socket/client_socket.h @@ -5,20 +5,11 @@ #ifndef NET_SOCKET_CLIENT_SOCKET_H_ #define NET_SOCKET_CLIENT_SOCKET_H_ -#include "build/build_config.h" - -// For struct sockaddr and socklen_t. -#if defined(OS_POSIX) -#include <sys/types.h> -#include <sys/socket.h> -#elif defined(OS_WIN) -#include <ws2tcpip.h> -#endif - #include "net/socket/socket.h" namespace net { +class AddressList; class LoadLog; class ClientSocket : public Socket { @@ -57,9 +48,8 @@ class ClientSocket : public Socket { // have been received. virtual bool IsConnectedAndIdle() const = 0; - // Identical to BSD socket call getpeername(). - // Needed by ssl_client_socket_nss and ssl_client_socket_mac. - virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen) = 0; + // Copies the peer address to |address| and returns a network error code. + virtual int GetPeerAddress(AddressList* address) const = 0; }; } // namespace net diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index 87e0250..dcc2c09 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -59,8 +59,7 @@ class MockClientSocket : public ClientSocket { virtual bool IsConnected() const { return connected_; } virtual bool IsConnectedAndIdle() const { return connected_; } - virtual int GetPeerName(struct sockaddr* /* name */, - socklen_t* /* namelen */) { + virtual int GetPeerAddress(AddressList* /* address */) const { return ERR_UNEXPECTED; } diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 1ea1ec9..1634106 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -9,6 +9,8 @@ #include "base/basictypes.h" #include "base/compiler_specific.h" #include "base/message_loop.h" +#include "net/base/address_family.h" +#include "net/base/host_resolver_proc.h" #include "net/base/ssl_info.h" #include "net/socket/socket.h" #include "testing/gtest/include/gtest/gtest.h" @@ -47,9 +49,9 @@ bool MockClientSocket::IsConnectedAndIdle() const { return connected_; } -int MockClientSocket::GetPeerName(struct sockaddr* name, socklen_t* namelen) { - memset(reinterpret_cast<char *>(name), 0, *namelen); - return net::OK; +int MockClientSocket::GetPeerAddress(AddressList* address) const { + return net::SystemHostResolverProc("localhost", ADDRESS_FAMILY_UNSPECIFIED, + address); } void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback, @@ -370,12 +372,14 @@ void MockClientSocketFactory::ResetNextMockIndexes() { } MockTCPClientSocket* MockClientSocketFactory::GetMockTCPClientSocket( - int index) const { + size_t index) const { + DCHECK_LT(index, tcp_client_sockets_.size()); return tcp_client_sockets_[index]; } MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket( - int index) const { + size_t index) const { + DCHECK_LT(index, ssl_client_sockets_.size()); return ssl_client_sockets_[index]; } diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index ff5384a..e2f3504 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -245,11 +245,11 @@ class MockClientSocketFactory : public ClientSocketFactory { // Return |index|-th MockTCPClientSocket (starting from 0) that the factory // created. - MockTCPClientSocket* GetMockTCPClientSocket(int index) const; + MockTCPClientSocket* GetMockTCPClientSocket(size_t index) const; // Return |index|-th MockSSLClientSocket (starting from 0) that the factory // created. - MockSSLClientSocket* GetMockSSLClientSocket(int index) const; + MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const; // ClientSocketFactory virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses); @@ -276,7 +276,7 @@ class MockClientSocket : public net::SSLClientSocket { virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; - virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen); + virtual int GetPeerAddress(AddressList* address) const; // SSLClientSocket methods: virtual void GetSSLInfo(net::SSLInfo* ssl_info); diff --git a/net/socket/socks5_client_socket.cc b/net/socket/socks5_client_socket.cc index cbb2dc7..fbaf77c 100644 --- a/net/socket/socks5_client_socket.cc +++ b/net/socket/socks5_client_socket.cc @@ -449,9 +449,8 @@ int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) { return OK; } -int SOCKS5ClientSocket::GetPeerName(struct sockaddr* name, - socklen_t* namelen) { - return transport_->GetPeerName(name, namelen); +int SOCKS5ClientSocket::GetPeerAddress(AddressList* address) const { + return transport_->GetPeerAddress(address); } } // namespace net diff --git a/net/socket/socks5_client_socket.h b/net/socket/socks5_client_socket.h index dec1cc4..629180a 100644 --- a/net/socket/socks5_client_socket.h +++ b/net/socket/socks5_client_socket.h @@ -56,7 +56,7 @@ class SOCKS5ClientSocket : public ClientSocket { virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); - virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen); + virtual int GetPeerAddress(AddressList* address) const; private: enum State { diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc index 7d8aaf9..a9135d2 100644 --- a/net/socket/socks_client_socket.cc +++ b/net/socket/socks_client_socket.cc @@ -398,9 +398,8 @@ int SOCKSClientSocket::DoHandshakeReadComplete(int result) { // Note: we ignore the last 6 bytes as specified by the SOCKS protocol } -int SOCKSClientSocket::GetPeerName(struct sockaddr* name, - socklen_t* namelen) { - return transport_->GetPeerName(name, namelen); +int SOCKSClientSocket::GetPeerAddress(AddressList* address) const { + return transport_->GetPeerAddress(address); } } // namespace net diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h index dc0b287..7dc998a 100644 --- a/net/socket/socks_client_socket.h +++ b/net/socket/socks_client_socket.h @@ -52,7 +52,7 @@ class SOCKSClientSocket : public ClientSocket { virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); - virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen); + virtual int GetPeerAddress(AddressList* address) const; private: FRIEND_TEST(SOCKSClientSocketTest, CompleteHandshake); diff --git a/net/socket/ssl_client_socket_mac.cc b/net/socket/ssl_client_socket_mac.cc index fd374b7..af0a11a 100644 --- a/net/socket/ssl_client_socket_mac.cc +++ b/net/socket/ssl_client_socket_mac.cc @@ -5,10 +5,14 @@ #include "net/socket/ssl_client_socket_mac.h" #include <CoreServices/CoreServices.h> +#include <netdb.h> +#include <sys/socket.h> +#include <sys/types.h> #include "base/scoped_cftyperef.h" #include "base/singleton.h" #include "base/string_util.h" +#include "net/base/address_list.h" #include "net/base/cert_verifier.h" #include "net/base/io_buffer.h" #include "net/base/load_log.h" @@ -579,8 +583,8 @@ bool SSLClientSocketMac::IsConnectedAndIdle() const { return completed_handshake_ && transport_->IsConnectedAndIdle(); } -int SSLClientSocketMac::GetPeerName(struct sockaddr* name, socklen_t* namelen) { - return transport_->GetPeerName(name, namelen); +int SSLClientSocketMac::GetPeerAddress(AddressList* address) const { + return transport_->GetPeerAddress(address); } int SSLClientSocketMac::Read(IOBuffer* buf, int buf_len, @@ -745,22 +749,20 @@ int SSLClientSocketMac::InitializeSSLContext() { // using the same hostname (i.e., localhost and 127.0.0.1 are considered // different peers, which puts us through certificate validation again // and catches hostname/certificate name mismatches. - struct sockaddr_storage addr; - socklen_t addr_length = sizeof(struct sockaddr_storage); - memset(&addr, 0, sizeof(addr)); - if (!transport_->GetPeerName(reinterpret_cast<struct sockaddr*>(&addr), - &addr_length)) { - // Assemble the socket hostname and address into a single buffer. - std::vector<char> peer_id(hostname_.begin(), hostname_.end()); - peer_id.insert(peer_id.end(), reinterpret_cast<char*>(&addr), - reinterpret_cast<char*>(&addr) + addr_length); - - // SSLSetPeerID() treats peer_id as a binary blob, and makes its - // own copy. - status = SSLSetPeerID(ssl_context_, &peer_id[0], peer_id.size()); - if (status) - return NetErrorFromOSStatus(status); - } + AddressList address; + int rv = transport_->GetPeerAddress(&address); + if (rv != OK) + return rv; + const struct addrinfo* ai = address.head(); + std::string peer_id(hostname_); + peer_id += std::string(reinterpret_cast<char*>(ai->ai_addr), + ai->ai_addrlen); + + // SSLSetPeerID() treats peer_id as a binary blob, and makes its + // own copy. + status = SSLSetPeerID(ssl_context_, peer_id.data(), peer_id.length()); + if (status) + return NetErrorFromOSStatus(status); } else { // If I can't break on cert-requested, then set the cert up-front: status = SetClientCert(); diff --git a/net/socket/ssl_client_socket_mac.h b/net/socket/ssl_client_socket_mac.h index 6b2eb48..3e6f97e 100644 --- a/net/socket/ssl_client_socket_mac.h +++ b/net/socket/ssl_client_socket_mac.h @@ -43,7 +43,7 @@ class SSLClientSocketMac : public SSLClientSocket { virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; - virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen); + virtual int GetPeerAddress(AddressList* address) const; // Socket methods: virtual int Read(IOBuffer* buf, int buf_len, CompletionCallback* callback); diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index d6c321f..2c703d6 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -64,12 +64,14 @@ #include "base/nss_util.h" #include "base/singleton.h" #include "base/string_util.h" +#include "net/base/address_list.h" #include "net/base/cert_verifier.h" #include "net/base/io_buffer.h" #include "net/base/load_log.h" #include "net/base/net_errors.h" #include "net/base/ssl_cert_request_info.h" #include "net/base/ssl_info.h" +#include "net/base/sys_addrinfo.h" #include "net/ocsp/nss_ocsp.h" static const int kRecvBufferSize = 4096; @@ -313,15 +315,12 @@ int SSLClientSocketNSS::InitializeSSLOptions() { } // Tell NSS who we're connected to - PRNetAddr peername; - socklen_t len = sizeof(PRNetAddr); - int err = transport_->GetPeerName((struct sockaddr *)&peername, &len); - if (err) { - DLOG(ERROR) << "GetPeerName failed"; - // TODO(wtc): Change GetPeerName to return a network error code. - return ERR_UNEXPECTED; - } - memio_SetPeerName(nss_fd_, &peername); + AddressList peer_address; + int err = transport_->GetPeerAddress(&peer_address); + if (err != OK) + return err; + const struct addrinfo* ai = peer_address.head(); + memio_SetPeerName(nss_fd_, ai->ai_addr, ai->ai_addrlen); // Grab pointer to buffers nss_bufs_ = memio_GetSecret(nss_fd_); @@ -429,9 +428,10 @@ int SSLClientSocketNSS::InitializeSSLOptions() { // Set the peer ID for session reuse. This is necessary when we create an // SSL tunnel through a proxy -- GetPeerName returns the proxy's address // rather than the destination server's address in that case. - // TODO(wtc): port in peername is not the server's port when a proxy is used. + // TODO(wtc): port in |peer_address| is not the server's port when a proxy is + // used. std::string peer_id = StringPrintf("%s:%d", hostname_.c_str(), - PR_ntohs(PR_NetAddrInetPort(&peername))); + peer_address.GetPort()); rv = SSL_SetSockPeerID(nss_fd_, const_cast<char*>(peer_id.c_str())); if (rv != SECSuccess) LOG(INFO) << "SSL_SetSockPeerID failed: peer_id=" << peer_id; @@ -515,8 +515,8 @@ bool SSLClientSocketNSS::IsConnectedAndIdle() const { return ret; } -int SSLClientSocketNSS::GetPeerName(struct sockaddr* name, socklen_t* namelen) { - return transport_->GetPeerName(name, namelen); +int SSLClientSocketNSS::GetPeerAddress(AddressList* address) const { + return transport_->GetPeerAddress(address); } int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index 7e59ea8..a33b703 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -48,7 +48,7 @@ class SSLClientSocketNSS : public SSLClientSocket { virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; - virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen); + virtual int GetPeerAddress(AddressList* address) const; // Socket methods: virtual int Read(IOBuffer* buf, int buf_len, CompletionCallback* callback); diff --git a/net/socket/ssl_client_socket_win.cc b/net/socket/ssl_client_socket_win.cc index 6e8d86d..5acfa0f 100644 --- a/net/socket/ssl_client_socket_win.cc +++ b/net/socket/ssl_client_socket_win.cc @@ -566,8 +566,8 @@ bool SSLClientSocketWin::IsConnectedAndIdle() const { return completed_handshake() && transport_->IsConnectedAndIdle(); } -int SSLClientSocketWin::GetPeerName(struct sockaddr* name, socklen_t* namelen) { - return transport_->GetPeerName(name, namelen); +int SSLClientSocketWin::GetPeerAddress(AddressList* address) const { + return transport_->GetPeerAddress(address); } int SSLClientSocketWin::Read(IOBuffer* buf, int buf_len, diff --git a/net/socket/ssl_client_socket_win.h b/net/socket/ssl_client_socket_win.h index c5d6cf7..1321c00 100644 --- a/net/socket/ssl_client_socket_win.h +++ b/net/socket/ssl_client_socket_win.h @@ -46,7 +46,7 @@ class SSLClientSocketWin : public SSLClientSocket { virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; - virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen); + virtual int GetPeerAddress(AddressList* address) const; // Socket methods: virtual int Read(IOBuffer* buf, int buf_len, CompletionCallback* callback); diff --git a/net/socket/tcp_client_socket_libevent.cc b/net/socket/tcp_client_socket_libevent.cc index 2c1c73d..3743495 100644 --- a/net/socket/tcp_client_socket_libevent.cc +++ b/net/socket/tcp_client_socket_libevent.cc @@ -468,9 +468,12 @@ void TCPClientSocketLibevent::DidCompleteWrite() { } } -int TCPClientSocketLibevent::GetPeerName(struct sockaddr* name, - socklen_t* namelen) { - return ::getpeername(socket_, name, namelen); +int TCPClientSocketLibevent::GetPeerAddress(AddressList* address) const { + DCHECK(address); + if (!current_ai_) + return ERR_UNEXPECTED; + address->Copy(current_ai_, false); + return OK; } } // namespace net diff --git a/net/socket/tcp_client_socket_libevent.h b/net/socket/tcp_client_socket_libevent.h index b054805..55c2fc4 100644 --- a/net/socket/tcp_client_socket_libevent.h +++ b/net/socket/tcp_client_socket_libevent.h @@ -33,7 +33,7 @@ class TCPClientSocketLibevent : public ClientSocket { virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; - virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen); + virtual int GetPeerAddress(AddressList* address) const; // Socket methods: // Multiple outstanding requests are not supported. diff --git a/net/socket/tcp_client_socket_pool_unittest.cc b/net/socket/tcp_client_socket_pool_unittest.cc index 20a42cb..2678848 100644 --- a/net/socket/tcp_client_socket_pool_unittest.cc +++ b/net/socket/tcp_client_socket_pool_unittest.cc @@ -43,7 +43,7 @@ class MockClientSocket : public ClientSocket { virtual bool IsConnectedAndIdle() const { return connected_; } - virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen) { + virtual int GetPeerAddress(AddressList* address) const { return ERR_UNEXPECTED; } @@ -80,7 +80,7 @@ class MockFailingClientSocket : public ClientSocket { virtual bool IsConnectedAndIdle() const { return false; } - virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen) { + virtual int GetPeerAddress(AddressList* address) const { return ERR_UNEXPECTED; } @@ -122,7 +122,7 @@ class MockPendingClientSocket : public ClientSocket { virtual bool IsConnectedAndIdle() const { return is_connected_; } - virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen) { + virtual int GetPeerAddress(AddressList* address) const{ return ERR_UNEXPECTED; } diff --git a/net/socket/tcp_client_socket_win.cc b/net/socket/tcp_client_socket_win.cc index 32c7725..ac5823da 100644 --- a/net/socket/tcp_client_socket_win.cc +++ b/net/socket/tcp_client_socket_win.cc @@ -435,9 +435,12 @@ bool TCPClientSocketWin::IsConnectedAndIdle() const { return true; } -int TCPClientSocketWin::GetPeerName(struct sockaddr* name, - socklen_t* namelen) { - return getpeername(socket_, name, namelen); +int TCPClientSocketWin::GetPeerAddress(AddressList* address) const { + DCHECK(address); + if (!current_ai_) + return ERR_FAILED; + address->Copy(current_ai_, false); + return OK; } int TCPClientSocketWin::Read(IOBuffer* buf, diff --git a/net/socket/tcp_client_socket_win.h b/net/socket/tcp_client_socket_win.h index 9ad1632..6acfa8e 100644 --- a/net/socket/tcp_client_socket_win.h +++ b/net/socket/tcp_client_socket_win.h @@ -5,6 +5,8 @@ #ifndef NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_ #define NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_ +#include <winsock2.h> + #include "base/object_watcher.h" #include "net/base/address_list.h" #include "net/base/completion_callback.h" @@ -28,7 +30,7 @@ class TCPClientSocketWin : public ClientSocket { virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; - virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen); + virtual int GetPeerAddress(AddressList* address) const; // Socket methods: // Multiple outstanding requests are not supported. diff --git a/net/socket_stream/socket_stream.cc b/net/socket_stream/socket_stream.cc index 4d80df2..162e7f3 100644 --- a/net/socket_stream/socket_stream.cc +++ b/net/socket_stream/socket_stream.cc @@ -252,7 +252,7 @@ void SocketStream::SetClientSocketFactory( } void SocketStream::CopyAddrInfo(struct addrinfo* head) { - addresses_.Copy(head); + addresses_.Copy(head, true); } int SocketStream::DidEstablishConnection() { |