diff options
author | agayev@chromium.org <agayev@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-06-29 03:47:04 +0000 |
---|---|---|
committer | agayev@chromium.org <agayev@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-06-29 03:47:04 +0000 |
commit | 5370c013eb6372dbffe91de3fde793da6b74e4e1 (patch) | |
tree | f4a14d0380e1c673c1e0d54f75a192fc1b504b9e /net | |
parent | 034bda715a6756a9b07de1fe9db9ceb6caf73123 (diff) | |
download | chromium_src-5370c013eb6372dbffe91de3fde793da6b74e4e1.zip chromium_src-5370c013eb6372dbffe91de3fde793da6b74e4e1.tar.gz chromium_src-5370c013eb6372dbffe91de3fde793da6b74e4e1.tar.bz2 |
Add support for random UDP source port selection to avoid birthday attacks in DNS implementation.
BUG=60149
TEST=net_unittests
Review URL: http://codereview.chromium.org/7202011
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@90925 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
-rw-r--r-- | net/base/net_error_list.h | 3 | ||||
-rw-r--r-- | net/base/net_errors_posix.cc | 2 | ||||
-rw-r--r-- | net/base/net_errors_win.cc | 2 | ||||
-rw-r--r-- | net/curvecp/client_packetizer.cc | 5 | ||||
-rw-r--r-- | net/socket/client_socket_factory.cc | 4 | ||||
-rw-r--r-- | net/socket/client_socket_factory.h | 4 | ||||
-rw-r--r-- | net/socket/client_socket_pool_base_unittest.cc | 2 | ||||
-rw-r--r-- | net/socket/socket_test_util.cc | 4 | ||||
-rw-r--r-- | net/socket/socket_test_util.h | 4 | ||||
-rw-r--r-- | net/socket/transport_client_socket_pool_unittest.cc | 2 | ||||
-rw-r--r-- | net/udp/datagram_socket.h | 6 | ||||
-rw-r--r-- | net/udp/udp_client_socket.cc | 9 | ||||
-rw-r--r-- | net/udp/udp_client_socket.h | 5 | ||||
-rw-r--r-- | net/udp/udp_server_socket.cc | 8 | ||||
-rw-r--r-- | net/udp/udp_socket_libevent.cc | 82 | ||||
-rw-r--r-- | net/udp/udp_socket_libevent.h | 18 | ||||
-rw-r--r-- | net/udp/udp_socket_unittest.cc | 97 | ||||
-rw-r--r-- | net/udp/udp_socket_win.cc | 61 | ||||
-rw-r--r-- | net/udp/udp_socket_win.h | 16 |
19 files changed, 285 insertions, 49 deletions
diff --git a/net/base/net_error_list.h b/net/base/net_error_list.h index 6cf561b..b6aaa5b 100644 --- a/net/base/net_error_list.h +++ b/net/base/net_error_list.h @@ -252,6 +252,9 @@ NET_ERROR(DNS_SERVER_FAILED, -145) // WebSocket abort SocketStream connection when alternate protocol is found. NET_ERROR(PROTOCOL_SWITCHED, -146) +// Returned when attempting to bind an address that is already in use. +NET_ERROR(ADDRESS_IN_USE, -147) + // Certificate error codes // // The values of certificate error codes must be consecutive. diff --git a/net/base/net_errors_posix.cc b/net/base/net_errors_posix.cc index 6958801..31450f9 100644 --- a/net/base/net_errors_posix.cc +++ b/net/base/net_errors_posix.cc @@ -46,6 +46,8 @@ Error MapSystemError(int os_error) { return ERR_SOCKET_NOT_CONNECTED; case EINVAL: return ERR_INVALID_ARGUMENT; + case EADDRINUSE: + return ERR_ADDRESS_IN_USE; case 0: return OK; default: diff --git a/net/base/net_errors_win.cc b/net/base/net_errors_win.cc index 0fff5d6..c290020 100644 --- a/net/base/net_errors_win.cc +++ b/net/base/net_errors_win.cc @@ -46,6 +46,8 @@ Error MapSystemError(int os_error) { return ERR_ADDRESS_UNREACHABLE; case WSAEINVAL: return ERR_INVALID_ARGUMENT; + case WSAEADDRINUSE: + return ERR_ADDRESS_IN_USE; case ERROR_SUCCESS: return OK; default: diff --git a/net/curvecp/client_packetizer.cc b/net/curvecp/client_packetizer.cc index 54efddd..c609e96 100644 --- a/net/curvecp/client_packetizer.cc +++ b/net/curvecp/client_packetizer.cc @@ -290,7 +290,10 @@ int ClientPacketizer::ConnectNextAddress() { DCHECK(addresses_.head()); - socket_.reset(new UDPClientSocket(NULL, NetLog::Source())); + socket_.reset(new UDPClientSocket(DatagramSocket::DEFAULT_BIND, + RandIntCallback(), + NULL, + NetLog::Source())); // Rotate to next address in the list. if (current_address_) diff --git a/net/socket/client_socket_factory.cc b/net/socket/client_socket_factory.cc index b6efe12..1104d25 100644 --- a/net/socket/client_socket_factory.cc +++ b/net/socket/client_socket_factory.cc @@ -54,9 +54,11 @@ class DefaultClientSocketFactory : public ClientSocketFactory, } virtual DatagramClientSocket* CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) { - return new UDPClientSocket(net_log, source); + return new UDPClientSocket(bind_type, rand_int_cb, net_log, source); } virtual StreamSocket* CreateTransportClientSocket( diff --git a/net/socket/client_socket_factory.h b/net/socket/client_socket_factory.h index e5a6956..d1fe2f7 100644 --- a/net/socket/client_socket_factory.h +++ b/net/socket/client_socket_factory.h @@ -11,6 +11,8 @@ #include "base/basictypes.h" #include "net/base/net_api.h" #include "net/base/net_log.h" +#include "net/base/rand_callback.h" +#include "net/udp/datagram_socket.h" namespace net { @@ -34,6 +36,8 @@ class NET_API ClientSocketFactory { // |source| is the NetLog::Source for the entity trying to create the socket, // if it has one. virtual DatagramClientSocket* CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) = 0; diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index 6cdbdd3..cb59623 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -115,6 +115,8 @@ class MockClientSocketFactory : public ClientSocketFactory { MockClientSocketFactory() : allocation_count_(0) {} virtual DatagramClientSocket* CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) { NOTREACHED(); diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 0573f73..474a739 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -589,6 +589,8 @@ MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket( } DatagramClientSocket* MockClientSocketFactory::CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) { SocketDataProvider* data_provider = mock_data_.GetNext(); @@ -1440,6 +1442,8 @@ MockSSLClientSocket* DeterministicMockClientSocketFactory:: DatagramClientSocket* DeterministicMockClientSocketFactory::CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) { NOTREACHED(); diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index d1f4816..52bfdcd 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -536,6 +536,8 @@ class MockClientSocketFactory : public ClientSocketFactory { // ClientSocketFactory virtual DatagramClientSocket* CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source); virtual StreamSocket* CreateTransportClientSocket( @@ -937,6 +939,8 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory { // ClientSocketFactory virtual DatagramClientSocket* CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source); virtual StreamSocket* CreateTransportClientSocket( diff --git a/net/socket/transport_client_socket_pool_unittest.cc b/net/socket/transport_client_socket_pool_unittest.cc index d12cd9d..5c4d967 100644 --- a/net/socket/transport_client_socket_pool_unittest.cc +++ b/net/socket/transport_client_socket_pool_unittest.cc @@ -278,6 +278,8 @@ class MockClientSocketFactory : public ClientSocketFactory { delay_ms_(ClientSocketPool::kMaxConnectRetryIntervalMs) {} virtual DatagramClientSocket* CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) { NOTREACHED(); diff --git a/net/udp/datagram_socket.h b/net/udp/datagram_socket.h index b60e079..65bc5d6 100644 --- a/net/udp/datagram_socket.h +++ b/net/udp/datagram_socket.h @@ -16,6 +16,12 @@ class IPEndPoint; // datagrams, like UDP. class NET_TEST DatagramSocket { public: + // Type of source port binding to use. + enum BindType { + RANDOM_BIND, + DEFAULT_BIND, + }; + virtual ~DatagramSocket() {} // Close the socket. diff --git a/net/udp/udp_client_socket.cc b/net/udp/udp_client_socket.cc index 912394b..8df1fca 100644 --- a/net/udp/udp_client_socket.cc +++ b/net/udp/udp_client_socket.cc @@ -8,10 +8,11 @@ namespace net { -UDPClientSocket::UDPClientSocket( - net::NetLog* net_log, - const net::NetLog::Source& source) - : socket_(net_log, source) { +UDPClientSocket::UDPClientSocket(DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + net::NetLog* net_log, + const net::NetLog::Source& source) + : socket_(bind_type, rand_int_cb, net_log, source) { } UDPClientSocket::~UDPClientSocket() { diff --git a/net/udp/udp_client_socket.h b/net/udp/udp_client_socket.h index 3dc8778..b26393c 100644 --- a/net/udp/udp_client_socket.h +++ b/net/udp/udp_client_socket.h @@ -7,6 +7,7 @@ #pragma once #include "net/base/net_log.h" +#include "net/base/rand_callback.h" #include "net/udp/datagram_client_socket.h" #include "net/udp/udp_socket.h" @@ -17,7 +18,9 @@ class BoundNetLog; // A client socket that uses UDP as the transport layer. class NET_TEST UDPClientSocket : public DatagramClientSocket { public: - UDPClientSocket(net::NetLog* net_log, + UDPClientSocket(DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + net::NetLog* net_log, const net::NetLog::Source& source); virtual ~UDPClientSocket(); diff --git a/net/udp/udp_server_socket.cc b/net/udp/udp_server_socket.cc index acab7e3..31df603 100644 --- a/net/udp/udp_server_socket.cc +++ b/net/udp/udp_server_socket.cc @@ -4,11 +4,16 @@ #include "net/udp/udp_server_socket.h" +#include "net/base/rand_callback.h" + namespace net { UDPServerSocket::UDPServerSocket(net::NetLog* net_log, const net::NetLog::Source& source) - : socket_(net_log, source) { + : socket_(DatagramSocket::DEFAULT_BIND, + RandIntCallback(), + net_log, + source) { } UDPServerSocket::~UDPServerSocket() { @@ -44,5 +49,4 @@ int UDPServerSocket::GetLocalAddress(IPEndPoint* address) const { return socket_.GetLocalAddress(address); } - } // namespace net diff --git a/net/udp/udp_socket_libevent.cc b/net/udp/udp_socket_libevent.cc index c451696..201f208 100644 --- a/net/udp/udp_socket_libevent.cc +++ b/net/udp/udp_socket_libevent.cc @@ -13,6 +13,7 @@ #include "base/logging.h" #include "base/message_loop.h" #include "base/metrics/stats_counters.h" +#include "base/rand_util.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" @@ -22,23 +23,38 @@ #include <netinet/in.h> #endif +namespace { + +static const int kBindRetries = 10; +static const int kPortStart = 1024; +static const int kPortEnd = 65535; + +} // namespace net + namespace net { -UDPSocketLibevent::UDPSocketLibevent(net::NetLog* net_log, - const net::NetLog::Source& source) - : socket_(kInvalidSocket), - read_watcher_(this), - write_watcher_(this), - read_buf_len_(0), - recv_from_address_(NULL), - write_buf_len_(0), - read_callback_(NULL), - write_callback_(NULL), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { +UDPSocketLibevent::UDPSocketLibevent( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + net::NetLog* net_log, + const net::NetLog::Source& source) + : socket_(kInvalidSocket), + bind_type_(bind_type), + rand_int_cb_(rand_int_cb), + read_watcher_(this), + write_watcher_(this), + read_buf_len_(0), + recv_from_address_(NULL), + write_buf_len_(0), + read_callback_(NULL), + write_callback_(NULL), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { scoped_refptr<NetLog::EventParameters> params; if (source.is_valid()) params = new NetLogSourceParameter("source_dependency", source); net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, params); + if (bind_type == DatagramSocket::RANDOM_BIND) + DCHECK(!rand_int_cb.is_null()); } UDPSocketLibevent::~UDPSocketLibevent() { @@ -208,6 +224,13 @@ int UDPSocketLibevent::Connect(const IPEndPoint& address) { if (rv < 0) return rv; + if (bind_type_ == DatagramSocket::RANDOM_BIND) + rv = RandomBind(address); + // else connect() does the DatagramSocket::DEFAULT_BIND + + if (rv < 0) + return rv; + struct sockaddr_storage addr_storage; size_t addr_len = sizeof(addr_storage); struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); @@ -224,21 +247,12 @@ int UDPSocketLibevent::Connect(const IPEndPoint& address) { int UDPSocketLibevent::Bind(const IPEndPoint& address) { DCHECK(!is_connected()); - DCHECK(!local_address_.get()); int rv = CreateSocket(address); if (rv < 0) return rv; - - struct sockaddr_storage addr_storage; - size_t addr_len = sizeof(addr_storage); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - if (!address.ToSockAddr(addr, &addr_len)) - return ERR_FAILED; - - rv = bind(socket_, addr, addr_len); + rv = DoBind(address); if (rv < 0) - return MapSystemError(errno); - + return rv; local_address_.reset(); return rv; } @@ -359,4 +373,28 @@ int UDPSocketLibevent::InternalSendTo(IOBuffer* buf, int buf_len, addr_len)); } +int UDPSocketLibevent::DoBind(const IPEndPoint& address) { + struct sockaddr_storage addr_storage; + size_t addr_len = sizeof(addr_storage); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + if (!address.ToSockAddr(addr, &addr_len)) + return ERR_UNEXPECTED; + int rv = bind(socket_, addr, addr_len); + return rv < 0 ? MapSystemError(errno) : rv; +} + +int UDPSocketLibevent::RandomBind(const IPEndPoint& address) { + DCHECK(bind_type_ == DatagramSocket::RANDOM_BIND && !rand_int_cb_.is_null()); + + // Construct IPAddressNumber of appropriate size (IPv4 or IPv6) of 0s. + IPAddressNumber ip(address.address().size()); + + for (int i = 0; i < kBindRetries; ++i) { + int rv = DoBind(IPEndPoint(ip, rand_int_cb_.Run(kPortStart, kPortEnd))); + if (rv == OK || rv != ERR_ADDRESS_IN_USE) + return rv; + } + return DoBind(IPEndPoint(ip, 0)); +} + } // namespace net diff --git a/net/udp/udp_socket_libevent.h b/net/udp/udp_socket_libevent.h index 55d4cca..6efaf49 100644 --- a/net/udp/udp_socket_libevent.h +++ b/net/udp/udp_socket_libevent.h @@ -11,9 +11,11 @@ #include "base/message_loop.h" #include "base/threading/non_thread_safe.h" #include "net/base/completion_callback.h" +#include "net/base/rand_callback.h" +#include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_log.h" -#include "net/socket/stream_socket.h" +#include "net/udp/datagram_socket.h" namespace net { @@ -21,7 +23,9 @@ class BoundNetLog; class UDPSocketLibevent : public base::NonThreadSafe { public: - UDPSocketLibevent(net::NetLog* net_log, + UDPSocketLibevent(DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + net::NetLog* net_log, const net::NetLog::Source& source); virtual ~UDPSocketLibevent(); @@ -153,8 +157,18 @@ class UDPSocketLibevent : public base::NonThreadSafe { int InternalRecvFrom(IOBuffer* buf, int buf_len, IPEndPoint* address); int InternalSendTo(IOBuffer* buf, int buf_len, const IPEndPoint* address); + int DoBind(const IPEndPoint& address); + int RandomBind(const IPEndPoint& address); + int socket_; + // How to do source port binding, used only when UDPSocket is part of + // UDPClientSocket, since UDPServerSocket provides Bind. + DatagramSocket::BindType bind_type_; + + // PRNG function for generating port numbers. + RandIntCallback rand_int_cb_; + // These are mutable since they're just cached copies to make // GetPeerAddress/GetLocalAddress smarter. mutable scoped_ptr<IPEndPoint> local_address_; diff --git a/net/udp/udp_socket_unittest.cc b/net/udp/udp_socket_unittest.cc index 459b35d..9fe870f 100644 --- a/net/udp/udp_socket_unittest.cc +++ b/net/udp/udp_socket_unittest.cc @@ -6,7 +6,10 @@ #include "net/udp/udp_server_socket.h" #include "base/basictypes.h" +#include "base/bind.h" +#include "base/callback.h" #include "base/metrics/histogram.h" +#include "base/stl_util-inl.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" @@ -136,7 +139,10 @@ TEST_F(UDPSocketTest, Connect) { // Setup the client. IPEndPoint server_address; CreateUDPAddress("127.0.0.1", kPort, &server_address); - UDPClientSocket client(NULL, NetLog::Source()); + UDPClientSocket client(DatagramSocket::DEFAULT_BIND, + RandIntCallback(), + NULL, + NetLog::Source()); rv = client.Connect(server_address); EXPECT_EQ(OK, rv); @@ -157,6 +163,85 @@ TEST_F(UDPSocketTest, Connect) { DCHECK(simple_message == str); } +// In this test, we verify that random binding logic works, which attempts +// to bind to a random port and returns if succeeds, otherwise retries for +// |kBindRetries| number of times. + +// To generate the scenario, we first create |kBindRetries| number of +// UDPClientSockets with default binding policy and connect to the same +// peer and save the used port numbers. Then we get rid of the last +// socket, making sure that the local port it was bound to is available. +// Finally, we create a socket with random binding policy, passing it a +// test PRNG that would serve used port numbers in the array, one after +// another. At the end, we make sure that the test socket was bound to the +// port that became available after deleting the last socket with default +// binding policy. + +// We do not test the randomness of bound ports, but that we are using +// passed in PRNG correctly, thus, it's the duty of PRNG to produce strong +// random numbers. +static const int kBindRetries = 10; + +class TestPrng { + public: + explicit TestPrng(const std::deque<int>& numbers) : numbers_(numbers) {} + int GetNext(int /* min */, int /* max */) { + DCHECK(!numbers_.empty()); + int rv = numbers_.front(); + numbers_.pop_front(); + return rv; + } + private: + std::deque<int> numbers_; + + DISALLOW_COPY_AND_ASSIGN(TestPrng); +}; + +TEST_F(UDPSocketTest, ConnectRandomBind) { + std::vector<UDPClientSocket*> sockets; + IPEndPoint peer_address; + CreateUDPAddress("192.168.1.13", 53, &peer_address); + + // Create and connect sockets and save port numbers. + std::deque<int> used_ports; + for (int i = 0; i < kBindRetries; ++i) { + UDPClientSocket* socket = + new UDPClientSocket(DatagramSocket::DEFAULT_BIND, + RandIntCallback(), + NULL, + NetLog::Source()); + sockets.push_back(socket); + EXPECT_EQ(OK, socket->Connect(peer_address)); + + IPEndPoint client_address; + EXPECT_EQ(OK, socket->GetLocalAddress(&client_address)); + used_ports.push_back(client_address.port()); + } + + // Free the last socket, its local port is still in |used_ports|. + delete sockets.back(); + sockets.pop_back(); + + TestPrng test_prng(used_ports); + RandIntCallback rand_int_cb = + base::Bind(&TestPrng::GetNext, base::Unretained(&test_prng)); + + // Create a socket with random binding policy and connect. + scoped_ptr<UDPClientSocket> test_socket( + new UDPClientSocket(DatagramSocket::RANDOM_BIND, + rand_int_cb, + NULL, + NetLog::Source())); + EXPECT_EQ(OK, test_socket->Connect(peer_address)); + + // Make sure that the last port number in the |used_ports| was used. + IPEndPoint client_address; + EXPECT_EQ(OK, test_socket->GetLocalAddress(&client_address)); + EXPECT_EQ(used_ports.back(), client_address.port()); + + STLDeleteElements(&sockets); +} + // In this test, we verify that connect() on a socket will have the effect // of filtering reads on this socket only to data read from the destination // we connected to. @@ -187,7 +272,10 @@ TEST_F(UDPSocketTest, VerifyConnectBindsAddr) { // Setup the client, connected to server 1. IPEndPoint server_address; CreateUDPAddress("127.0.0.1", kPort1, &server_address); - UDPClientSocket client(NULL, NetLog::Source()); + UDPClientSocket client(DatagramSocket::DEFAULT_BIND, + RandIntCallback(), + NULL, + NetLog::Source()); rv = client.Connect(server_address); EXPECT_EQ(OK, rv); @@ -240,7 +328,10 @@ TEST_F(UDPSocketTest, ClientGetLocalPeerAddresses) { net::ParseIPLiteralToNumber(tests[i].local_address, &ip_number); net::IPEndPoint local_address(ip_number, 80); - UDPClientSocket client(NULL, NetLog::Source()); + UDPClientSocket client(DatagramSocket::DEFAULT_BIND, + RandIntCallback(), + NULL, + NetLog::Source()); int rv = client.Connect(remote_address); if (tests[i].may_fail && rv == ERR_ADDRESS_UNREACHABLE) { // Connect() may return ERR_ADDRESS_UNREACHABLE for IPv6 diff --git a/net/udp/udp_socket_win.cc b/net/udp/udp_socket_win.cc index bf375dc..5840f77 100644 --- a/net/udp/udp_socket_win.cc +++ b/net/udp/udp_socket_win.cc @@ -11,6 +11,7 @@ #include "base/memory/memory_debug.h" #include "base/message_loop.h" #include "base/metrics/stats_counters.h" +#include "base/rand_util.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" @@ -19,6 +20,14 @@ #include "net/base/winsock_init.h" #include "net/base/winsock_util.h" +namespace { + +static const int kBindRetries = 10; +static const int kPortStart = 1024; +static const int kPortEnd = 65535; + +} // namespace net + namespace net { void UDPSocketWin::ReadDelegate::OnObjectSignaled(HANDLE object) { @@ -31,9 +40,13 @@ void UDPSocketWin::WriteDelegate::OnObjectSignaled(HANDLE object) { socket_->DidCompleteWrite(); } -UDPSocketWin::UDPSocketWin(net::NetLog* net_log, +UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + net::NetLog* net_log, const net::NetLog::Source& source) : socket_(INVALID_SOCKET), + bind_type_(bind_type), + rand_int_cb_(rand_int_cb), ALLOW_THIS_IN_INITIALIZER_LIST(read_delegate_(this)), ALLOW_THIS_IN_INITIALIZER_LIST(write_delegate_(this)), recv_from_address_(NULL), @@ -49,6 +62,8 @@ UDPSocketWin::UDPSocketWin(net::NetLog* net_log, read_overlapped_.hEvent = WSACreateEvent(); memset(&write_overlapped_, 0, sizeof(write_overlapped_)); write_overlapped_.hEvent = WSACreateEvent(); + if (bind_type == DatagramSocket::RANDOM_BIND) + DCHECK(!rand_int_cb.is_null()); } UDPSocketWin::~UDPSocketWin() { @@ -184,6 +199,13 @@ int UDPSocketWin::Connect(const IPEndPoint& address) { if (rv < 0) return rv; + if (bind_type_ == DatagramSocket::RANDOM_BIND) + rv = RandomBind(address); + // else connect() does the DatagramSocket::DEFAULT_BIND + + if (rv < 0) + return rv; + struct sockaddr_storage addr_storage; size_t addr_len = sizeof(addr_storage); struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); @@ -200,21 +222,12 @@ int UDPSocketWin::Connect(const IPEndPoint& address) { int UDPSocketWin::Bind(const IPEndPoint& address) { DCHECK(!is_connected()); - DCHECK(!local_address_.get()); int rv = CreateSocket(address); if (rv < 0) return rv; - - struct sockaddr_storage addr_storage; - size_t addr_len = sizeof(addr_storage); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - if (!address.ToSockAddr(addr, &addr_len)) - return ERR_FAILED; - - rv = bind(socket_, addr, addr_len); + rv = DoBind(address); if (rv < 0) - return MapSystemError(WSAGetLastError()); - + return rv; local_address_.reset(); return rv; } @@ -370,4 +383,28 @@ int UDPSocketWin::InternalSendTo(IOBuffer* buf, int buf_len, return ERR_IO_PENDING; } +int UDPSocketWin::DoBind(const IPEndPoint& address) { + struct sockaddr_storage addr_storage; + size_t addr_len = sizeof(addr_storage); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + if (!address.ToSockAddr(addr, &addr_len)) + return ERR_UNEXPECTED; + int rv = bind(socket_, addr, addr_len); + return rv < 0 ? MapSystemError(WSAGetLastError()) : rv; +} + +int UDPSocketWin::RandomBind(const IPEndPoint& address) { + DCHECK(bind_type_ == DatagramSocket::RANDOM_BIND && !rand_int_cb_.is_null()); + + // Construct IPAddressNumber of appropriate size (IPv4 or IPv6) of 0s. + IPAddressNumber ip(address.address().size()); + + for (int i = 0; i < kBindRetries; ++i) { + int rv = DoBind(IPEndPoint(ip, rand_int_cb_.Run(kPortStart, kPortEnd))); + if (rv == OK || rv != ERR_ADDRESS_IN_USE) + return rv; + } + return DoBind(IPEndPoint(ip, 0)); +} + } // namespace net diff --git a/net/udp/udp_socket_win.h b/net/udp/udp_socket_win.h index 96dc55f..01d5855 100644 --- a/net/udp/udp_socket_win.h +++ b/net/udp/udp_socket_win.h @@ -13,9 +13,11 @@ #include "base/threading/non_thread_safe.h" #include "base/win/object_watcher.h" #include "net/base/completion_callback.h" +#include "net/base/rand_callback.h" #include "net/base/ip_endpoint.h" #include "net/base/io_buffer.h" #include "net/base/net_log.h" +#include "net/udp/datagram_socket.h" namespace net { @@ -23,7 +25,9 @@ class BoundNetLog; class UDPSocketWin : public base::NonThreadSafe { public: - UDPSocketWin(net::NetLog* net_log, + UDPSocketWin(DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + net::NetLog* net_log, const net::NetLog::Source& source); virtual ~UDPSocketWin(); @@ -141,8 +145,18 @@ class UDPSocketWin : public base::NonThreadSafe { int InternalRecvFrom(IOBuffer* buf, int buf_len, IPEndPoint* address); int InternalSendTo(IOBuffer* buf, int buf_len, const IPEndPoint* address); + int DoBind(const IPEndPoint& address); + int RandomBind(const IPEndPoint& address); + SOCKET socket_; + // How to do source port binding, used only when UDPSocket is part of + // UDPClientSocket, since UDPServerSocket provides Bind. + DatagramSocket::BindType bind_type_; + + // PRNG function for generating port numbers. + RandIntCallback rand_int_cb_; + // These are mutable since they're just cached copies to make // GetPeerAddress/GetLocalAddress smarter. mutable scoped_ptr<IPEndPoint> local_address_; |