// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/dns/dns_socket_pool.h" #include "base/logging.h" #include "base/rand_util.h" #include "base/stl_util.h" #include "net/base/address_list.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/rand_callback.h" #include "net/socket/client_socket_factory.h" #include "net/socket/stream_socket.h" #include "net/udp/datagram_client_socket.h" namespace net { namespace { // When we initialize the SocketPool, we allocate kInitialPoolSize sockets. // When we allocate a socket, we ensure we have at least kAllocateMinSize // sockets to choose from. Freed sockets are not retained. // On Windows, we can't request specific (random) ports, since that will // trigger firewall prompts, so request default ones, but keep a pile of // them. Everywhere else, request fresh, random ports each time. #if defined(OS_WIN) const DatagramSocket::BindType kBindType = DatagramSocket::DEFAULT_BIND; const unsigned kInitialPoolSize = 256; const unsigned kAllocateMinSize = 256; #else const DatagramSocket::BindType kBindType = DatagramSocket::RANDOM_BIND; const unsigned kInitialPoolSize = 0; const unsigned kAllocateMinSize = 1; #endif } // namespace DnsSocketPool::DnsSocketPool(ClientSocketFactory* socket_factory) : socket_factory_(socket_factory), net_log_(NULL), nameservers_(NULL), initialized_(false) { } void DnsSocketPool::InitializeInternal( const std::vector* nameservers, NetLog* net_log) { DCHECK(nameservers); DCHECK(!initialized_); net_log_ = net_log; nameservers_ = nameservers; initialized_ = true; } scoped_ptr DnsSocketPool::CreateTCPSocket( unsigned server_index, const NetLog::Source& source) { DCHECK_LT(server_index, nameservers_->size()); return scoped_ptr( socket_factory_->CreateTransportClientSocket( AddressList((*nameservers_)[server_index]), net_log_, source)); } scoped_ptr DnsSocketPool::CreateConnectedSocket( unsigned server_index) { DCHECK_LT(server_index, nameservers_->size()); scoped_ptr socket; NetLog::Source no_source; socket = socket_factory_->CreateDatagramClientSocket( kBindType, base::Bind(&base::RandInt), net_log_, no_source); if (socket.get()) { int rv = socket->Connect((*nameservers_)[server_index]); if (rv != OK) { VLOG(1) << "Failed to connect socket: " << rv; socket.reset(); } } else { LOG(WARNING) << "Failed to create socket."; } return socket.Pass(); } class NullDnsSocketPool : public DnsSocketPool { public: NullDnsSocketPool(ClientSocketFactory* factory) : DnsSocketPool(factory) { } virtual void Initialize( const std::vector* nameservers, NetLog* net_log) OVERRIDE { InitializeInternal(nameservers, net_log); } virtual scoped_ptr AllocateSocket( unsigned server_index) OVERRIDE { return CreateConnectedSocket(server_index); } virtual void FreeSocket( unsigned server_index, scoped_ptr socket) OVERRIDE { } private: DISALLOW_COPY_AND_ASSIGN(NullDnsSocketPool); }; // static scoped_ptr DnsSocketPool::CreateNull( ClientSocketFactory* factory) { return scoped_ptr(new NullDnsSocketPool(factory)); } class DefaultDnsSocketPool : public DnsSocketPool { public: DefaultDnsSocketPool(ClientSocketFactory* factory) : DnsSocketPool(factory) { }; virtual ~DefaultDnsSocketPool(); virtual void Initialize( const std::vector* nameservers, NetLog* net_log) OVERRIDE; virtual scoped_ptr AllocateSocket( unsigned server_index) OVERRIDE; virtual void FreeSocket( unsigned server_index, scoped_ptr socket) OVERRIDE; private: void FillPool(unsigned server_index, unsigned size); typedef std::vector SocketVector; std::vector pools_; DISALLOW_COPY_AND_ASSIGN(DefaultDnsSocketPool); }; // static scoped_ptr DnsSocketPool::CreateDefault( ClientSocketFactory* factory) { return scoped_ptr(new DefaultDnsSocketPool(factory)); } void DefaultDnsSocketPool::Initialize( const std::vector* nameservers, NetLog* net_log) { InitializeInternal(nameservers, net_log); DCHECK(pools_.empty()); const unsigned num_servers = nameservers->size(); pools_.resize(num_servers); for (unsigned server_index = 0; server_index < num_servers; ++server_index) FillPool(server_index, kInitialPoolSize); } DefaultDnsSocketPool::~DefaultDnsSocketPool() { unsigned num_servers = pools_.size(); for (unsigned server_index = 0; server_index < num_servers; ++server_index) { SocketVector& pool = pools_[server_index]; STLDeleteElements(&pool); } } scoped_ptr DefaultDnsSocketPool::AllocateSocket( unsigned server_index) { DCHECK_LT(server_index, pools_.size()); SocketVector& pool = pools_[server_index]; FillPool(server_index, kAllocateMinSize); if (pool.size() == 0) { LOG(WARNING) << "No DNS sockets available in pool " << server_index << "!"; return scoped_ptr(); } if (pool.size() < kAllocateMinSize) { LOG(WARNING) << "Low DNS port entropy: wanted " << kAllocateMinSize << " sockets to choose from, but only have " << pool.size() << " in pool " << server_index << "."; } unsigned socket_index = base::RandInt(0, pool.size() - 1); DatagramClientSocket* socket = pool[socket_index]; pool[socket_index] = pool.back(); pool.pop_back(); return scoped_ptr(socket); } void DefaultDnsSocketPool::FreeSocket( unsigned server_index, scoped_ptr socket) { DCHECK_LT(server_index, pools_.size()); } void DefaultDnsSocketPool::FillPool(unsigned server_index, unsigned size) { SocketVector& pool = pools_[server_index]; for (unsigned pool_index = pool.size(); pool_index < size; ++pool_index) { DatagramClientSocket* socket = CreateConnectedSocket(server_index).release(); if (!socket) break; pool.push_back(socket); } } } // namespace net