summaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
authoragayev@chromium.org <agayev@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2011-06-29 03:47:04 +0000
committeragayev@chromium.org <agayev@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2011-06-29 03:47:04 +0000
commit5370c013eb6372dbffe91de3fde793da6b74e4e1 (patch)
treef4a14d0380e1c673c1e0d54f75a192fc1b504b9e /net
parent034bda715a6756a9b07de1fe9db9ceb6caf73123 (diff)
downloadchromium_src-5370c013eb6372dbffe91de3fde793da6b74e4e1.zip
chromium_src-5370c013eb6372dbffe91de3fde793da6b74e4e1.tar.gz
chromium_src-5370c013eb6372dbffe91de3fde793da6b74e4e1.tar.bz2
Add support for random UDP source port selection to avoid birthday attacks in DNS implementation.
BUG=60149 TEST=net_unittests Review URL: http://codereview.chromium.org/7202011 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@90925 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
-rw-r--r--net/base/net_error_list.h3
-rw-r--r--net/base/net_errors_posix.cc2
-rw-r--r--net/base/net_errors_win.cc2
-rw-r--r--net/curvecp/client_packetizer.cc5
-rw-r--r--net/socket/client_socket_factory.cc4
-rw-r--r--net/socket/client_socket_factory.h4
-rw-r--r--net/socket/client_socket_pool_base_unittest.cc2
-rw-r--r--net/socket/socket_test_util.cc4
-rw-r--r--net/socket/socket_test_util.h4
-rw-r--r--net/socket/transport_client_socket_pool_unittest.cc2
-rw-r--r--net/udp/datagram_socket.h6
-rw-r--r--net/udp/udp_client_socket.cc9
-rw-r--r--net/udp/udp_client_socket.h5
-rw-r--r--net/udp/udp_server_socket.cc8
-rw-r--r--net/udp/udp_socket_libevent.cc82
-rw-r--r--net/udp/udp_socket_libevent.h18
-rw-r--r--net/udp/udp_socket_unittest.cc97
-rw-r--r--net/udp/udp_socket_win.cc61
-rw-r--r--net/udp/udp_socket_win.h16
19 files changed, 285 insertions, 49 deletions
diff --git a/net/base/net_error_list.h b/net/base/net_error_list.h
index 6cf561b..b6aaa5b 100644
--- a/net/base/net_error_list.h
+++ b/net/base/net_error_list.h
@@ -252,6 +252,9 @@ NET_ERROR(DNS_SERVER_FAILED, -145)
// WebSocket abort SocketStream connection when alternate protocol is found.
NET_ERROR(PROTOCOL_SWITCHED, -146)
+// Returned when attempting to bind an address that is already in use.
+NET_ERROR(ADDRESS_IN_USE, -147)
+
// Certificate error codes
//
// The values of certificate error codes must be consecutive.
diff --git a/net/base/net_errors_posix.cc b/net/base/net_errors_posix.cc
index 6958801..31450f9 100644
--- a/net/base/net_errors_posix.cc
+++ b/net/base/net_errors_posix.cc
@@ -46,6 +46,8 @@ Error MapSystemError(int os_error) {
return ERR_SOCKET_NOT_CONNECTED;
case EINVAL:
return ERR_INVALID_ARGUMENT;
+ case EADDRINUSE:
+ return ERR_ADDRESS_IN_USE;
case 0:
return OK;
default:
diff --git a/net/base/net_errors_win.cc b/net/base/net_errors_win.cc
index 0fff5d6..c290020 100644
--- a/net/base/net_errors_win.cc
+++ b/net/base/net_errors_win.cc
@@ -46,6 +46,8 @@ Error MapSystemError(int os_error) {
return ERR_ADDRESS_UNREACHABLE;
case WSAEINVAL:
return ERR_INVALID_ARGUMENT;
+ case WSAEADDRINUSE:
+ return ERR_ADDRESS_IN_USE;
case ERROR_SUCCESS:
return OK;
default:
diff --git a/net/curvecp/client_packetizer.cc b/net/curvecp/client_packetizer.cc
index 54efddd..c609e96 100644
--- a/net/curvecp/client_packetizer.cc
+++ b/net/curvecp/client_packetizer.cc
@@ -290,7 +290,10 @@ int ClientPacketizer::ConnectNextAddress() {
DCHECK(addresses_.head());
- socket_.reset(new UDPClientSocket(NULL, NetLog::Source()));
+ socket_.reset(new UDPClientSocket(DatagramSocket::DEFAULT_BIND,
+ RandIntCallback(),
+ NULL,
+ NetLog::Source()));
// Rotate to next address in the list.
if (current_address_)
diff --git a/net/socket/client_socket_factory.cc b/net/socket/client_socket_factory.cc
index b6efe12..1104d25 100644
--- a/net/socket/client_socket_factory.cc
+++ b/net/socket/client_socket_factory.cc
@@ -54,9 +54,11 @@ class DefaultClientSocketFactory : public ClientSocketFactory,
}
virtual DatagramClientSocket* CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) {
- return new UDPClientSocket(net_log, source);
+ return new UDPClientSocket(bind_type, rand_int_cb, net_log, source);
}
virtual StreamSocket* CreateTransportClientSocket(
diff --git a/net/socket/client_socket_factory.h b/net/socket/client_socket_factory.h
index e5a6956..d1fe2f7 100644
--- a/net/socket/client_socket_factory.h
+++ b/net/socket/client_socket_factory.h
@@ -11,6 +11,8 @@
#include "base/basictypes.h"
#include "net/base/net_api.h"
#include "net/base/net_log.h"
+#include "net/base/rand_callback.h"
+#include "net/udp/datagram_socket.h"
namespace net {
@@ -34,6 +36,8 @@ class NET_API ClientSocketFactory {
// |source| is the NetLog::Source for the entity trying to create the socket,
// if it has one.
virtual DatagramClientSocket* CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) = 0;
diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc
index 6cdbdd3..cb59623 100644
--- a/net/socket/client_socket_pool_base_unittest.cc
+++ b/net/socket/client_socket_pool_base_unittest.cc
@@ -115,6 +115,8 @@ class MockClientSocketFactory : public ClientSocketFactory {
MockClientSocketFactory() : allocation_count_(0) {}
virtual DatagramClientSocket* CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) {
NOTREACHED();
diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc
index 0573f73..474a739 100644
--- a/net/socket/socket_test_util.cc
+++ b/net/socket/socket_test_util.cc
@@ -589,6 +589,8 @@ MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket(
}
DatagramClientSocket* MockClientSocketFactory::CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) {
SocketDataProvider* data_provider = mock_data_.GetNext();
@@ -1440,6 +1442,8 @@ MockSSLClientSocket* DeterministicMockClientSocketFactory::
DatagramClientSocket*
DeterministicMockClientSocketFactory::CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) {
NOTREACHED();
diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h
index d1f4816..52bfdcd 100644
--- a/net/socket/socket_test_util.h
+++ b/net/socket/socket_test_util.h
@@ -536,6 +536,8 @@ class MockClientSocketFactory : public ClientSocketFactory {
// ClientSocketFactory
virtual DatagramClientSocket* CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source);
virtual StreamSocket* CreateTransportClientSocket(
@@ -937,6 +939,8 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory {
// ClientSocketFactory
virtual DatagramClientSocket* CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source);
virtual StreamSocket* CreateTransportClientSocket(
diff --git a/net/socket/transport_client_socket_pool_unittest.cc b/net/socket/transport_client_socket_pool_unittest.cc
index d12cd9d..5c4d967 100644
--- a/net/socket/transport_client_socket_pool_unittest.cc
+++ b/net/socket/transport_client_socket_pool_unittest.cc
@@ -278,6 +278,8 @@ class MockClientSocketFactory : public ClientSocketFactory {
delay_ms_(ClientSocketPool::kMaxConnectRetryIntervalMs) {}
virtual DatagramClientSocket* CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
NetLog* net_log,
const NetLog::Source& source) {
NOTREACHED();
diff --git a/net/udp/datagram_socket.h b/net/udp/datagram_socket.h
index b60e079..65bc5d6 100644
--- a/net/udp/datagram_socket.h
+++ b/net/udp/datagram_socket.h
@@ -16,6 +16,12 @@ class IPEndPoint;
// datagrams, like UDP.
class NET_TEST DatagramSocket {
public:
+ // Type of source port binding to use.
+ enum BindType {
+ RANDOM_BIND,
+ DEFAULT_BIND,
+ };
+
virtual ~DatagramSocket() {}
// Close the socket.
diff --git a/net/udp/udp_client_socket.cc b/net/udp/udp_client_socket.cc
index 912394b..8df1fca 100644
--- a/net/udp/udp_client_socket.cc
+++ b/net/udp/udp_client_socket.cc
@@ -8,10 +8,11 @@
namespace net {
-UDPClientSocket::UDPClientSocket(
- net::NetLog* net_log,
- const net::NetLog::Source& source)
- : socket_(net_log, source) {
+UDPClientSocket::UDPClientSocket(DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source)
+ : socket_(bind_type, rand_int_cb, net_log, source) {
}
UDPClientSocket::~UDPClientSocket() {
diff --git a/net/udp/udp_client_socket.h b/net/udp/udp_client_socket.h
index 3dc8778..b26393c 100644
--- a/net/udp/udp_client_socket.h
+++ b/net/udp/udp_client_socket.h
@@ -7,6 +7,7 @@
#pragma once
#include "net/base/net_log.h"
+#include "net/base/rand_callback.h"
#include "net/udp/datagram_client_socket.h"
#include "net/udp/udp_socket.h"
@@ -17,7 +18,9 @@ class BoundNetLog;
// A client socket that uses UDP as the transport layer.
class NET_TEST UDPClientSocket : public DatagramClientSocket {
public:
- UDPClientSocket(net::NetLog* net_log,
+ UDPClientSocket(DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ net::NetLog* net_log,
const net::NetLog::Source& source);
virtual ~UDPClientSocket();
diff --git a/net/udp/udp_server_socket.cc b/net/udp/udp_server_socket.cc
index acab7e3..31df603 100644
--- a/net/udp/udp_server_socket.cc
+++ b/net/udp/udp_server_socket.cc
@@ -4,11 +4,16 @@
#include "net/udp/udp_server_socket.h"
+#include "net/base/rand_callback.h"
+
namespace net {
UDPServerSocket::UDPServerSocket(net::NetLog* net_log,
const net::NetLog::Source& source)
- : socket_(net_log, source) {
+ : socket_(DatagramSocket::DEFAULT_BIND,
+ RandIntCallback(),
+ net_log,
+ source) {
}
UDPServerSocket::~UDPServerSocket() {
@@ -44,5 +49,4 @@ int UDPServerSocket::GetLocalAddress(IPEndPoint* address) const {
return socket_.GetLocalAddress(address);
}
-
} // namespace net
diff --git a/net/udp/udp_socket_libevent.cc b/net/udp/udp_socket_libevent.cc
index c451696..201f208 100644
--- a/net/udp/udp_socket_libevent.cc
+++ b/net/udp/udp_socket_libevent.cc
@@ -13,6 +13,7 @@
#include "base/logging.h"
#include "base/message_loop.h"
#include "base/metrics/stats_counters.h"
+#include "base/rand_util.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
@@ -22,23 +23,38 @@
#include <netinet/in.h>
#endif
+namespace {
+
+static const int kBindRetries = 10;
+static const int kPortStart = 1024;
+static const int kPortEnd = 65535;
+
+} // namespace net
+
namespace net {
-UDPSocketLibevent::UDPSocketLibevent(net::NetLog* net_log,
- const net::NetLog::Source& source)
- : socket_(kInvalidSocket),
- read_watcher_(this),
- write_watcher_(this),
- read_buf_len_(0),
- recv_from_address_(NULL),
- write_buf_len_(0),
- read_callback_(NULL),
- write_callback_(NULL),
- net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
+UDPSocketLibevent::UDPSocketLibevent(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source)
+ : socket_(kInvalidSocket),
+ bind_type_(bind_type),
+ rand_int_cb_(rand_int_cb),
+ read_watcher_(this),
+ write_watcher_(this),
+ read_buf_len_(0),
+ recv_from_address_(NULL),
+ write_buf_len_(0),
+ read_callback_(NULL),
+ write_callback_(NULL),
+ net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
scoped_refptr<NetLog::EventParameters> params;
if (source.is_valid())
params = new NetLogSourceParameter("source_dependency", source);
net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, params);
+ if (bind_type == DatagramSocket::RANDOM_BIND)
+ DCHECK(!rand_int_cb.is_null());
}
UDPSocketLibevent::~UDPSocketLibevent() {
@@ -208,6 +224,13 @@ int UDPSocketLibevent::Connect(const IPEndPoint& address) {
if (rv < 0)
return rv;
+ if (bind_type_ == DatagramSocket::RANDOM_BIND)
+ rv = RandomBind(address);
+ // else connect() does the DatagramSocket::DEFAULT_BIND
+
+ if (rv < 0)
+ return rv;
+
struct sockaddr_storage addr_storage;
size_t addr_len = sizeof(addr_storage);
struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
@@ -224,21 +247,12 @@ int UDPSocketLibevent::Connect(const IPEndPoint& address) {
int UDPSocketLibevent::Bind(const IPEndPoint& address) {
DCHECK(!is_connected());
- DCHECK(!local_address_.get());
int rv = CreateSocket(address);
if (rv < 0)
return rv;
-
- struct sockaddr_storage addr_storage;
- size_t addr_len = sizeof(addr_storage);
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
- if (!address.ToSockAddr(addr, &addr_len))
- return ERR_FAILED;
-
- rv = bind(socket_, addr, addr_len);
+ rv = DoBind(address);
if (rv < 0)
- return MapSystemError(errno);
-
+ return rv;
local_address_.reset();
return rv;
}
@@ -359,4 +373,28 @@ int UDPSocketLibevent::InternalSendTo(IOBuffer* buf, int buf_len,
addr_len));
}
+int UDPSocketLibevent::DoBind(const IPEndPoint& address) {
+ struct sockaddr_storage addr_storage;
+ size_t addr_len = sizeof(addr_storage);
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ if (!address.ToSockAddr(addr, &addr_len))
+ return ERR_UNEXPECTED;
+ int rv = bind(socket_, addr, addr_len);
+ return rv < 0 ? MapSystemError(errno) : rv;
+}
+
+int UDPSocketLibevent::RandomBind(const IPEndPoint& address) {
+ DCHECK(bind_type_ == DatagramSocket::RANDOM_BIND && !rand_int_cb_.is_null());
+
+ // Construct IPAddressNumber of appropriate size (IPv4 or IPv6) of 0s.
+ IPAddressNumber ip(address.address().size());
+
+ for (int i = 0; i < kBindRetries; ++i) {
+ int rv = DoBind(IPEndPoint(ip, rand_int_cb_.Run(kPortStart, kPortEnd)));
+ if (rv == OK || rv != ERR_ADDRESS_IN_USE)
+ return rv;
+ }
+ return DoBind(IPEndPoint(ip, 0));
+}
+
} // namespace net
diff --git a/net/udp/udp_socket_libevent.h b/net/udp/udp_socket_libevent.h
index 55d4cca..6efaf49 100644
--- a/net/udp/udp_socket_libevent.h
+++ b/net/udp/udp_socket_libevent.h
@@ -11,9 +11,11 @@
#include "base/message_loop.h"
#include "base/threading/non_thread_safe.h"
#include "net/base/completion_callback.h"
+#include "net/base/rand_callback.h"
+#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_log.h"
-#include "net/socket/stream_socket.h"
+#include "net/udp/datagram_socket.h"
namespace net {
@@ -21,7 +23,9 @@ class BoundNetLog;
class UDPSocketLibevent : public base::NonThreadSafe {
public:
- UDPSocketLibevent(net::NetLog* net_log,
+ UDPSocketLibevent(DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ net::NetLog* net_log,
const net::NetLog::Source& source);
virtual ~UDPSocketLibevent();
@@ -153,8 +157,18 @@ class UDPSocketLibevent : public base::NonThreadSafe {
int InternalRecvFrom(IOBuffer* buf, int buf_len, IPEndPoint* address);
int InternalSendTo(IOBuffer* buf, int buf_len, const IPEndPoint* address);
+ int DoBind(const IPEndPoint& address);
+ int RandomBind(const IPEndPoint& address);
+
int socket_;
+ // How to do source port binding, used only when UDPSocket is part of
+ // UDPClientSocket, since UDPServerSocket provides Bind.
+ DatagramSocket::BindType bind_type_;
+
+ // PRNG function for generating port numbers.
+ RandIntCallback rand_int_cb_;
+
// These are mutable since they're just cached copies to make
// GetPeerAddress/GetLocalAddress smarter.
mutable scoped_ptr<IPEndPoint> local_address_;
diff --git a/net/udp/udp_socket_unittest.cc b/net/udp/udp_socket_unittest.cc
index 459b35d..9fe870f 100644
--- a/net/udp/udp_socket_unittest.cc
+++ b/net/udp/udp_socket_unittest.cc
@@ -6,7 +6,10 @@
#include "net/udp/udp_server_socket.h"
#include "base/basictypes.h"
+#include "base/bind.h"
+#include "base/callback.h"
#include "base/metrics/histogram.h"
+#include "base/stl_util-inl.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
@@ -136,7 +139,10 @@ TEST_F(UDPSocketTest, Connect) {
// Setup the client.
IPEndPoint server_address;
CreateUDPAddress("127.0.0.1", kPort, &server_address);
- UDPClientSocket client(NULL, NetLog::Source());
+ UDPClientSocket client(DatagramSocket::DEFAULT_BIND,
+ RandIntCallback(),
+ NULL,
+ NetLog::Source());
rv = client.Connect(server_address);
EXPECT_EQ(OK, rv);
@@ -157,6 +163,85 @@ TEST_F(UDPSocketTest, Connect) {
DCHECK(simple_message == str);
}
+// In this test, we verify that random binding logic works, which attempts
+// to bind to a random port and returns if succeeds, otherwise retries for
+// |kBindRetries| number of times.
+
+// To generate the scenario, we first create |kBindRetries| number of
+// UDPClientSockets with default binding policy and connect to the same
+// peer and save the used port numbers. Then we get rid of the last
+// socket, making sure that the local port it was bound to is available.
+// Finally, we create a socket with random binding policy, passing it a
+// test PRNG that would serve used port numbers in the array, one after
+// another. At the end, we make sure that the test socket was bound to the
+// port that became available after deleting the last socket with default
+// binding policy.
+
+// We do not test the randomness of bound ports, but that we are using
+// passed in PRNG correctly, thus, it's the duty of PRNG to produce strong
+// random numbers.
+static const int kBindRetries = 10;
+
+class TestPrng {
+ public:
+ explicit TestPrng(const std::deque<int>& numbers) : numbers_(numbers) {}
+ int GetNext(int /* min */, int /* max */) {
+ DCHECK(!numbers_.empty());
+ int rv = numbers_.front();
+ numbers_.pop_front();
+ return rv;
+ }
+ private:
+ std::deque<int> numbers_;
+
+ DISALLOW_COPY_AND_ASSIGN(TestPrng);
+};
+
+TEST_F(UDPSocketTest, ConnectRandomBind) {
+ std::vector<UDPClientSocket*> sockets;
+ IPEndPoint peer_address;
+ CreateUDPAddress("192.168.1.13", 53, &peer_address);
+
+ // Create and connect sockets and save port numbers.
+ std::deque<int> used_ports;
+ for (int i = 0; i < kBindRetries; ++i) {
+ UDPClientSocket* socket =
+ new UDPClientSocket(DatagramSocket::DEFAULT_BIND,
+ RandIntCallback(),
+ NULL,
+ NetLog::Source());
+ sockets.push_back(socket);
+ EXPECT_EQ(OK, socket->Connect(peer_address));
+
+ IPEndPoint client_address;
+ EXPECT_EQ(OK, socket->GetLocalAddress(&client_address));
+ used_ports.push_back(client_address.port());
+ }
+
+ // Free the last socket, its local port is still in |used_ports|.
+ delete sockets.back();
+ sockets.pop_back();
+
+ TestPrng test_prng(used_ports);
+ RandIntCallback rand_int_cb =
+ base::Bind(&TestPrng::GetNext, base::Unretained(&test_prng));
+
+ // Create a socket with random binding policy and connect.
+ scoped_ptr<UDPClientSocket> test_socket(
+ new UDPClientSocket(DatagramSocket::RANDOM_BIND,
+ rand_int_cb,
+ NULL,
+ NetLog::Source()));
+ EXPECT_EQ(OK, test_socket->Connect(peer_address));
+
+ // Make sure that the last port number in the |used_ports| was used.
+ IPEndPoint client_address;
+ EXPECT_EQ(OK, test_socket->GetLocalAddress(&client_address));
+ EXPECT_EQ(used_ports.back(), client_address.port());
+
+ STLDeleteElements(&sockets);
+}
+
// In this test, we verify that connect() on a socket will have the effect
// of filtering reads on this socket only to data read from the destination
// we connected to.
@@ -187,7 +272,10 @@ TEST_F(UDPSocketTest, VerifyConnectBindsAddr) {
// Setup the client, connected to server 1.
IPEndPoint server_address;
CreateUDPAddress("127.0.0.1", kPort1, &server_address);
- UDPClientSocket client(NULL, NetLog::Source());
+ UDPClientSocket client(DatagramSocket::DEFAULT_BIND,
+ RandIntCallback(),
+ NULL,
+ NetLog::Source());
rv = client.Connect(server_address);
EXPECT_EQ(OK, rv);
@@ -240,7 +328,10 @@ TEST_F(UDPSocketTest, ClientGetLocalPeerAddresses) {
net::ParseIPLiteralToNumber(tests[i].local_address, &ip_number);
net::IPEndPoint local_address(ip_number, 80);
- UDPClientSocket client(NULL, NetLog::Source());
+ UDPClientSocket client(DatagramSocket::DEFAULT_BIND,
+ RandIntCallback(),
+ NULL,
+ NetLog::Source());
int rv = client.Connect(remote_address);
if (tests[i].may_fail && rv == ERR_ADDRESS_UNREACHABLE) {
// Connect() may return ERR_ADDRESS_UNREACHABLE for IPv6
diff --git a/net/udp/udp_socket_win.cc b/net/udp/udp_socket_win.cc
index bf375dc..5840f77 100644
--- a/net/udp/udp_socket_win.cc
+++ b/net/udp/udp_socket_win.cc
@@ -11,6 +11,7 @@
#include "base/memory/memory_debug.h"
#include "base/message_loop.h"
#include "base/metrics/stats_counters.h"
+#include "base/rand_util.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
@@ -19,6 +20,14 @@
#include "net/base/winsock_init.h"
#include "net/base/winsock_util.h"
+namespace {
+
+static const int kBindRetries = 10;
+static const int kPortStart = 1024;
+static const int kPortEnd = 65535;
+
+} // namespace net
+
namespace net {
void UDPSocketWin::ReadDelegate::OnObjectSignaled(HANDLE object) {
@@ -31,9 +40,13 @@ void UDPSocketWin::WriteDelegate::OnObjectSignaled(HANDLE object) {
socket_->DidCompleteWrite();
}
-UDPSocketWin::UDPSocketWin(net::NetLog* net_log,
+UDPSocketWin::UDPSocketWin(DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ net::NetLog* net_log,
const net::NetLog::Source& source)
: socket_(INVALID_SOCKET),
+ bind_type_(bind_type),
+ rand_int_cb_(rand_int_cb),
ALLOW_THIS_IN_INITIALIZER_LIST(read_delegate_(this)),
ALLOW_THIS_IN_INITIALIZER_LIST(write_delegate_(this)),
recv_from_address_(NULL),
@@ -49,6 +62,8 @@ UDPSocketWin::UDPSocketWin(net::NetLog* net_log,
read_overlapped_.hEvent = WSACreateEvent();
memset(&write_overlapped_, 0, sizeof(write_overlapped_));
write_overlapped_.hEvent = WSACreateEvent();
+ if (bind_type == DatagramSocket::RANDOM_BIND)
+ DCHECK(!rand_int_cb.is_null());
}
UDPSocketWin::~UDPSocketWin() {
@@ -184,6 +199,13 @@ int UDPSocketWin::Connect(const IPEndPoint& address) {
if (rv < 0)
return rv;
+ if (bind_type_ == DatagramSocket::RANDOM_BIND)
+ rv = RandomBind(address);
+ // else connect() does the DatagramSocket::DEFAULT_BIND
+
+ if (rv < 0)
+ return rv;
+
struct sockaddr_storage addr_storage;
size_t addr_len = sizeof(addr_storage);
struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
@@ -200,21 +222,12 @@ int UDPSocketWin::Connect(const IPEndPoint& address) {
int UDPSocketWin::Bind(const IPEndPoint& address) {
DCHECK(!is_connected());
- DCHECK(!local_address_.get());
int rv = CreateSocket(address);
if (rv < 0)
return rv;
-
- struct sockaddr_storage addr_storage;
- size_t addr_len = sizeof(addr_storage);
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
- if (!address.ToSockAddr(addr, &addr_len))
- return ERR_FAILED;
-
- rv = bind(socket_, addr, addr_len);
+ rv = DoBind(address);
if (rv < 0)
- return MapSystemError(WSAGetLastError());
-
+ return rv;
local_address_.reset();
return rv;
}
@@ -370,4 +383,28 @@ int UDPSocketWin::InternalSendTo(IOBuffer* buf, int buf_len,
return ERR_IO_PENDING;
}
+int UDPSocketWin::DoBind(const IPEndPoint& address) {
+ struct sockaddr_storage addr_storage;
+ size_t addr_len = sizeof(addr_storage);
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ if (!address.ToSockAddr(addr, &addr_len))
+ return ERR_UNEXPECTED;
+ int rv = bind(socket_, addr, addr_len);
+ return rv < 0 ? MapSystemError(WSAGetLastError()) : rv;
+}
+
+int UDPSocketWin::RandomBind(const IPEndPoint& address) {
+ DCHECK(bind_type_ == DatagramSocket::RANDOM_BIND && !rand_int_cb_.is_null());
+
+ // Construct IPAddressNumber of appropriate size (IPv4 or IPv6) of 0s.
+ IPAddressNumber ip(address.address().size());
+
+ for (int i = 0; i < kBindRetries; ++i) {
+ int rv = DoBind(IPEndPoint(ip, rand_int_cb_.Run(kPortStart, kPortEnd)));
+ if (rv == OK || rv != ERR_ADDRESS_IN_USE)
+ return rv;
+ }
+ return DoBind(IPEndPoint(ip, 0));
+}
+
} // namespace net
diff --git a/net/udp/udp_socket_win.h b/net/udp/udp_socket_win.h
index 96dc55f..01d5855 100644
--- a/net/udp/udp_socket_win.h
+++ b/net/udp/udp_socket_win.h
@@ -13,9 +13,11 @@
#include "base/threading/non_thread_safe.h"
#include "base/win/object_watcher.h"
#include "net/base/completion_callback.h"
+#include "net/base/rand_callback.h"
#include "net/base/ip_endpoint.h"
#include "net/base/io_buffer.h"
#include "net/base/net_log.h"
+#include "net/udp/datagram_socket.h"
namespace net {
@@ -23,7 +25,9 @@ class BoundNetLog;
class UDPSocketWin : public base::NonThreadSafe {
public:
- UDPSocketWin(net::NetLog* net_log,
+ UDPSocketWin(DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ net::NetLog* net_log,
const net::NetLog::Source& source);
virtual ~UDPSocketWin();
@@ -141,8 +145,18 @@ class UDPSocketWin : public base::NonThreadSafe {
int InternalRecvFrom(IOBuffer* buf, int buf_len, IPEndPoint* address);
int InternalSendTo(IOBuffer* buf, int buf_len, const IPEndPoint* address);
+ int DoBind(const IPEndPoint& address);
+ int RandomBind(const IPEndPoint& address);
+
SOCKET socket_;
+ // How to do source port binding, used only when UDPSocket is part of
+ // UDPClientSocket, since UDPServerSocket provides Bind.
+ DatagramSocket::BindType bind_type_;
+
+ // PRNG function for generating port numbers.
+ RandIntCallback rand_int_cb_;
+
// These are mutable since they're just cached copies to make
// GetPeerAddress/GetLocalAddress smarter.
mutable scoped_ptr<IPEndPoint> local_address_;