diff options
author | szym@chromium.org <szym@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2012-08-15 21:30:46 +0000 |
---|---|---|
committer | szym@chromium.org <szym@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2012-08-15 21:30:46 +0000 |
commit | 0adcb2b7566fc42a11409b79f81a8bbf4f1eed32 (patch) | |
tree | 650466bc57a4bf5082d4a0291dc030ada658f6ca /net/dns | |
parent | 69cc9fd8cf3a8b7a67109d13a687dc3f5d80626b (diff) | |
download | chromium_src-0adcb2b7566fc42a11409b79f81a8bbf4f1eed32.zip chromium_src-0adcb2b7566fc42a11409b79f81a8bbf4f1eed32.tar.gz chromium_src-0adcb2b7566fc42a11409b79f81a8bbf4f1eed32.tar.bz2 |
[net/dns] Resolve AF_UNSPEC on dual-stacked systems. Sort addresses according to RFC3484.
Original review: http://codereview.chromium.org/10442098/
BUG=113993
TEST=./net_unittests --gtest_filter=AddressSorter*:HostResolverImplDnsTest.DnsTaskUnspec
Review URL: https://chromiumcodereview.appspot.com/10855179
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@151750 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net/dns')
-rw-r--r-- | net/dns/address_sorter.h | 46 | ||||
-rw-r--r-- | net/dns/address_sorter_posix.cc | 428 | ||||
-rw-r--r-- | net/dns/address_sorter_posix.h | 94 | ||||
-rw-r--r-- | net/dns/address_sorter_posix_unittest.cc | 325 | ||||
-rw-r--r-- | net/dns/address_sorter_unittest.cc | 66 | ||||
-rw-r--r-- | net/dns/address_sorter_win.cc | 198 | ||||
-rw-r--r-- | net/dns/dns_client.cc | 10 | ||||
-rw-r--r-- | net/dns/dns_client.h | 9 | ||||
-rw-r--r-- | net/dns/dns_response.cc | 4 | ||||
-rw-r--r-- | net/dns/dns_response.h | 4 | ||||
-rw-r--r-- | net/dns/dns_response_unittest.cc | 5 | ||||
-rw-r--r-- | net/dns/dns_test_util.cc | 157 | ||||
-rw-r--r-- | net/dns/dns_test_util.h | 22 |
13 files changed, 1301 insertions, 67 deletions
diff --git a/net/dns/address_sorter.h b/net/dns/address_sorter.h new file mode 100644 index 0000000..6ac9430 --- /dev/null +++ b/net/dns/address_sorter.h @@ -0,0 +1,46 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_ADDRESS_SORTER_H_ +#define NET_DNS_ADDRESS_SORTER_H_ + +#include "base/basictypes.h" +#include "base/callback.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/net_export.h" + +namespace net { + +class AddressList; + +// Sorts AddressList according to RFC3484, by likelihood of successful +// connection. Depending on the platform, the sort could be performed +// asynchronously by the OS, or synchronously by local implementation. +// AddressSorter does not necessarily preserve port numbers on the sorted list. +class NET_EXPORT AddressSorter { + public: + typedef base::Callback<void(bool success, + const AddressList& list)> CallbackType; + + virtual ~AddressSorter() {} + + // Sorts |list|, which must include at least one IPv6 address. + // Calls |callback| upon completion. Could complete synchronously. Could + // complete after this AddressSorter is destroyed. + virtual void Sort(const AddressList& list, + const CallbackType& callback) const = 0; + + // Creates platform-dependent AddressSorter. + static scoped_ptr<AddressSorter> CreateAddressSorter(); + + protected: + AddressSorter() {} + + private: + DISALLOW_COPY_AND_ASSIGN(AddressSorter); +}; + +} // namespace net + +#endif // NET_DNS_ADDRESS_SORTER_H_ diff --git a/net/dns/address_sorter_posix.cc b/net/dns/address_sorter_posix.cc new file mode 100644 index 0000000..807cbf8 --- /dev/null +++ b/net/dns/address_sorter_posix.cc @@ -0,0 +1,428 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/address_sorter_posix.h" + +#include <netinet/in.h> + +#if defined(OS_MACOSX) || defined(OS_BSD) +#include <sys/socket.h> // Must be included before ifaddrs.h. +#include <ifaddrs.h> +#include <net/if.h> +#include <netinet/in_var.h> +#include <string.h> +#include <sys/ioctl.h> +#endif + +#include <algorithm> + +#include "base/eintr_wrapper.h" +#include "base/logging.h" +#include "base/memory/scoped_vector.h" +#include "net/socket/client_socket_factory.h" +#include "net/udp/datagram_client_socket.h" + +#if defined(OS_LINUX) +#include "net/base/address_tracker_linux.h" +#endif + +namespace net { + +namespace { + +// Address sorting is performed according to RFC3484 with revisions. +// http://tools.ietf.org/html/draft-ietf-6man-rfc3484bis-06 +// Precedence and label are separate to support override through /etc/gai.conf. + +// Returns true if |p1| should precede |p2| in the table. +// Sorts table by decreasing prefix size to allow longest prefix matching. +bool ComparePolicy(const AddressSorterPosix::PolicyEntry& p1, + const AddressSorterPosix::PolicyEntry& p2) { + return p1.prefix_length > p2.prefix_length; +} + +// Creates sorted PolicyTable from |table| with |size| entries. +AddressSorterPosix::PolicyTable LoadPolicy( + AddressSorterPosix::PolicyEntry* table, + size_t size) { + AddressSorterPosix::PolicyTable result(table, table + size); + std::sort(result.begin(), result.end(), ComparePolicy); + return result; +} + +// Search |table| for matching prefix of |address|. |table| must be sorted by +// descending prefix (prefix of another prefix must be later in table). +unsigned GetPolicyValue(const AddressSorterPosix::PolicyTable& table, + const IPAddressNumber& address) { + if (address.size() == kIPv4AddressSize) + return GetPolicyValue(table, ConvertIPv4NumberToIPv6Number(address)); + for (unsigned i = 0; i < table.size(); ++i) { + const AddressSorterPosix::PolicyEntry& entry = table[i]; + IPAddressNumber prefix(entry.prefix, entry.prefix + kIPv6AddressSize); + if (IPNumberMatchesPrefix(address, prefix, entry.prefix_length)) + return entry.value; + } + NOTREACHED(); + // The last entry is the least restrictive, so assume it's default. + return table.back().value; +} + +bool IsIPv6Multicast(const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + return address[0] == 0xFF; +} + +AddressSorterPosix::AddressScope GetIPv6MulticastScope( + const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + return static_cast<AddressSorterPosix::AddressScope>(address[1] & 0x0F); +} + +bool IsIPv6Loopback(const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + // IN6_IS_ADDR_LOOPBACK + unsigned char kLoopback[kIPv6AddressSize] = { + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, + }; + return address == IPAddressNumber(kLoopback, kLoopback + kIPv6AddressSize); +} + +bool IsIPv6LinkLocal(const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + // IN6_IS_ADDR_LINKLOCAL + return (address[0] == 0xFE) && ((address[1] & 0xC0) == 0x80); +} + +bool IsIPv6SiteLocal(const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + // IN6_IS_ADDR_SITELOCAL + return (address[0] == 0xFE) && ((address[1] & 0xC0) == 0xC0); +} + +AddressSorterPosix::AddressScope GetScope( + const AddressSorterPosix::PolicyTable& ipv4_scope_table, + const IPAddressNumber& address) { + if (address.size() == kIPv6AddressSize) { + if (IsIPv6Multicast(address)) { + return GetIPv6MulticastScope(address); + } else if (IsIPv6Loopback(address) || IsIPv6LinkLocal(address)) { + return AddressSorterPosix::SCOPE_LINKLOCAL; + } else if (IsIPv6SiteLocal(address)) { + return AddressSorterPosix::SCOPE_SITELOCAL; + } else { + return AddressSorterPosix::SCOPE_GLOBAL; + } + } else if (address.size() == kIPv4AddressSize) { + return static_cast<AddressSorterPosix::AddressScope>( + GetPolicyValue(ipv4_scope_table, address)); + } else { + NOTREACHED(); + return AddressSorterPosix::SCOPE_NODELOCAL; + } +} + +// Default policy table. RFC 3484, Section 2.1. +AddressSorterPosix::PolicyEntry kDefaultPrecedenceTable[] = { + // ::1/128 -- loopback + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, 128, 50 }, + // ::/0 -- any + { { }, 0, 40 }, + // ::ffff:0:0/96 -- IPv4 mapped + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF }, 96, 35 }, + // 2002::/16 -- 6to4 + { { 0x20, 0x02, }, 16, 30 }, + // 2001::/32 -- Teredo + { { 0x20, 0x01, 0, 0 }, 32, 5 }, + // fc00::/7 -- unique local address + { { 0xFC }, 7, 3 }, + // ::/96 -- IPv4 compatible + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, 96, 1 }, + // fec0::/10 -- site-local expanded scope + { { 0xFE, 0xC0 }, 10, 1 }, + // 3ffe::/16 -- 6bone + { { 0x3F, 0xFE }, 16, 1 }, +}; + +AddressSorterPosix::PolicyEntry kDefaultLabelTable[] = { + // ::1/128 -- loopback + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, 128, 0 }, + // ::/0 -- any + { { }, 0, 1 }, + // ::ffff:0:0/96 -- IPv4 mapped + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF }, 96, 4 }, + // 2002::/16 -- 6to4 + { { 0x20, 0x02, }, 16, 2 }, + // 2001::/32 -- Teredo + { { 0x20, 0x01, 0, 0 }, 32, 5 }, + // fc00::/7 -- unique local address + { { 0xFC }, 7, 13 }, + // ::/96 -- IPv4 compatible + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, 96, 3 }, + // fec0::/10 -- site-local expanded scope + { { 0xFE, 0xC0 }, 10, 11 }, + // 3ffe::/16 -- 6bone + { { 0x3F, 0xFE }, 16, 12 }, +}; + +// Default mapping of IPv4 addresses to scope. +AddressSorterPosix::PolicyEntry kDefaultIPv4ScopeTable[] = { + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0x7F }, 104, + AddressSorterPosix::SCOPE_LINKLOCAL }, + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xA9, 0xFE }, 112, + AddressSorterPosix::SCOPE_LINKLOCAL }, + { { }, 0, AddressSorterPosix::SCOPE_GLOBAL }, +}; + +// Returns number of matching initial bits between the addresses |a1| and |a2|. +unsigned CommonPrefixLength(const IPAddressNumber& a1, + const IPAddressNumber& a2) { + DCHECK_EQ(a1.size(), a2.size()); + for (size_t i = 0; i < a1.size(); ++i) { + unsigned diff = a1[i] ^ a2[i]; + if (!diff) + continue; + for (unsigned j = 0; j < CHAR_BIT; ++j) { + if (diff & (1 << (CHAR_BIT - 1))) + return i * CHAR_BIT + j; + diff <<= 1; + } + NOTREACHED(); + } + return a1.size() * CHAR_BIT; +} + +// Computes the number of leading 1-bits in |mask|. +unsigned MaskPrefixLength(const IPAddressNumber& mask) { + IPAddressNumber all_ones(mask.size(), 0xFF); + return CommonPrefixLength(mask, all_ones); +} + +struct DestinationInfo { + IPAddressNumber address; + AddressSorterPosix::AddressScope scope; + unsigned precedence; + unsigned label; + const AddressSorterPosix::SourceAddressInfo* src; + unsigned common_prefix_length; +}; + +// Returns true iff |dst_a| should precede |dst_b| in the address list. +// RFC 3484, section 6. +bool CompareDestinations(const DestinationInfo* dst_a, + const DestinationInfo* dst_b) { + // Rule 1: Avoid unusable destinations. + // Unusable destinations are already filtered out. + DCHECK(dst_a->src); + DCHECK(dst_b->src); + + // Rule 2: Prefer matching scope. + bool scope_match1 = (dst_a->src->scope == dst_a->scope); + bool scope_match2 = (dst_b->src->scope == dst_b->scope); + if (scope_match1 != scope_match2) + return scope_match1; + + // Rule 3: Avoid deprecated addresses. + if (dst_a->src->deprecated != dst_b->src->deprecated) + return !dst_a->src->deprecated; + + // Rule 4: Prefer home addresses. + if (dst_a->src->home != dst_b->src->home) + return dst_a->src->home; + + // Rule 5: Prefer matching label. + bool label_match1 = (dst_a->src->label == dst_a->label); + bool label_match2 = (dst_b->src->label == dst_b->label); + if (label_match1 != label_match2) + return label_match1; + + // Rule 6: Prefer higher precedence. + if (dst_a->precedence != dst_b->precedence) + return dst_a->precedence > dst_b->precedence; + + // Rule 7: Prefer native transport. + if (dst_a->src->native != dst_b->src->native) + return dst_a->src->native; + + // Rule 8: Prefer smaller scope. + if (dst_a->scope != dst_b->scope) + return dst_a->scope < dst_b->scope; + + // Rule 9: Use longest matching prefix. Only for matching address families. + if (dst_a->address.size() == dst_b->address.size()) { + if (dst_a->common_prefix_length != dst_b->common_prefix_length) + return dst_a->common_prefix_length > dst_b->common_prefix_length; + } + + // Rule 10: Leave the order unchanged. + // stable_sort takes care of that. + return false; +} + +} // namespace + +AddressSorterPosix::AddressSorterPosix(ClientSocketFactory* socket_factory) + : socket_factory_(socket_factory), + precedence_table_(LoadPolicy(kDefaultPrecedenceTable, + arraysize(kDefaultPrecedenceTable))), + label_table_(LoadPolicy(kDefaultLabelTable, + arraysize(kDefaultLabelTable))), + ipv4_scope_table_(LoadPolicy(kDefaultIPv4ScopeTable, + arraysize(kDefaultIPv4ScopeTable))) { + NetworkChangeNotifier::AddIPAddressObserver(this); + OnIPAddressChanged(); +} + +AddressSorterPosix::~AddressSorterPosix() { + NetworkChangeNotifier::RemoveIPAddressObserver(this); +} + +void AddressSorterPosix::Sort(const AddressList& list, + const CallbackType& callback) const { + DCHECK(CalledOnValidThread()); + ScopedVector<DestinationInfo> sort_list; + + for (size_t i = 0; i < list.size(); ++i) { + scoped_ptr<DestinationInfo> info(new DestinationInfo()); + info->address = list[i].address(); + info->scope = GetScope(ipv4_scope_table_, info->address); + info->precedence = GetPolicyValue(precedence_table_, info->address); + info->label = GetPolicyValue(label_table_, info->address); + + // Each socket can only be bound once. + scoped_ptr<DatagramClientSocket> socket( + socket_factory_->CreateDatagramClientSocket( + DatagramSocket::DEFAULT_BIND, + RandIntCallback(), + NULL /* NetLog */, + NetLog::Source())); + + // Even though no packets are sent, cannot use port 0 in Connect. + IPEndPoint dest(info->address, 80 /* port */); + int rv = socket->Connect(dest); + if (rv != OK) { + LOG(WARNING) << "Could not connect to " << dest.ToStringWithoutPort() + << " reason " << rv; + continue; + } + // Filter out unusable destinations. + IPEndPoint src; + rv = socket->GetLocalAddress(&src); + if (rv != OK) { + LOG(WARNING) << "Could not get local address for " + << src.ToStringWithoutPort() << " reason " << rv; + continue; + } + + SourceAddressInfo& src_info = source_map_[src.address()]; + if (src_info.scope == SCOPE_UNDEFINED) { + // If |source_info_| is out of date, |src| might be missing, but we still + // want to sort, even though the HostCache will be cleared soon. + FillPolicy(src.address(), &src_info); + } + info->src = &src_info; + + if (info->address.size() == src.address().size()) { + info->common_prefix_length = std::min( + CommonPrefixLength(info->address, src.address()), + info->src->prefix_length); + } + sort_list.push_back(info.release()); + } + + std::stable_sort(sort_list.begin(), sort_list.end(), CompareDestinations); + + AddressList result; + for (size_t i = 0; i < sort_list.size(); ++i) + result.push_back(IPEndPoint(sort_list[i]->address, 0 /* port */)); + + callback.Run(true, result); +} + +void AddressSorterPosix::OnIPAddressChanged() { + DCHECK(CalledOnValidThread()); + source_map_.clear(); +#if defined(OS_LINUX) + const internal::AddressTrackerLinux* tracker = + NetworkChangeNotifier::GetAddressTracker(); + if (!tracker) + return; + typedef internal::AddressTrackerLinux::AddressMap AddressMap; + AddressMap map = tracker->GetAddressMap(); + for (AddressMap::const_iterator it = map.begin(); it != map.end(); ++it) { + const IPAddressNumber& address = it->first; + const struct ifaddrmsg& msg = it->second; + SourceAddressInfo& info = source_map_[address]; + info.native = false; // TODO(szym): obtain this via netlink. + info.deprecated = msg.ifa_flags & IFA_F_DEPRECATED; + info.home = msg.ifa_flags & IFA_F_HOMEADDRESS; + info.prefix_length = msg.ifa_prefixlen; + FillPolicy(address, &info); + } +#elif defined(OS_MACOSX) || defined(OS_BSD) + // It's not clear we will receive notification when deprecated flag changes. + // Socket for ioctl. + int ioctl_socket = socket(AF_INET6, SOCK_DGRAM, 0); + if (ioctl_socket < 0) + return; + + struct ifaddrs* addrs; + int rv = getifaddrs(&addrs); + if (rv < 0) { + LOG(WARNING) << "getifaddrs failed " << rv; + close(ioctl_socket); + return; + } + + for (struct ifaddrs* ifa = addrs; ifa != NULL; ifa = ifa->ifa_next) { + IPEndPoint src; + if (!src.FromSockAddr(ifa->ifa_addr, ifa->ifa_addr->sa_len)) { + LOG(WARNING) << "FromSockAddr failed"; + continue; + } + SourceAddressInfo& info = source_map_[src.address()]; + // Note: no known way to fill in |native| and |home|. + info.native = info.home = info.deprecated = false; + if (ifa->ifa_addr->sa_family == AF_INET6) { + struct in6_ifreq ifr = {}; + strncpy(ifr.ifr_name, ifa->ifa_name, sizeof(ifr.ifr_name) - 1); + DCHECK_LE(ifa->ifa_addr->sa_len, sizeof(ifr.ifr_ifru.ifru_addr)); + memcpy(&ifr.ifr_ifru.ifru_addr, ifa->ifa_addr, ifa->ifa_addr->sa_len); + int rv = ioctl(ioctl_socket, SIOCGIFAFLAG_IN6, &ifr); + if (rv > 0) { + info.deprecated = ifr.ifr_ifru.ifru_flags & IN6_IFF_DEPRECATED; + } else { + LOG(WARNING) << "SIOCGIFAFLAG_IN6 failed " << rv; + } + } + if (ifa->ifa_netmask) { + IPEndPoint netmask; + if (netmask.FromSockAddr(ifa->ifa_netmask, ifa->ifa_addr->sa_len)) { + info.prefix_length = MaskPrefixLength(netmask.address()); + } else { + LOG(WARNING) << "FromSockAddr failed on netmask"; + } + } + FillPolicy(src.address(), &info); + } + freeifaddrs(addrs); + close(ioctl_socket); +#endif +} + +void AddressSorterPosix::FillPolicy(const IPAddressNumber& address, + SourceAddressInfo* info) const { + DCHECK(CalledOnValidThread()); + info->scope = GetScope(ipv4_scope_table_, address); + info->label = GetPolicyValue(label_table_, address); +} + +// static +scoped_ptr<AddressSorter> AddressSorter::CreateAddressSorter() { + return scoped_ptr<AddressSorter>( + new AddressSorterPosix(ClientSocketFactory::GetDefaultFactory())); +} + +} // namespace net + diff --git a/net/dns/address_sorter_posix.h b/net/dns/address_sorter_posix.h new file mode 100644 index 0000000..1c88ad2 --- /dev/null +++ b/net/dns/address_sorter_posix.h @@ -0,0 +1,94 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_ADDRESS_SORTER_POSIX_H_ +#define NET_DNS_ADDRESS_SORTER_POSIX_H_ + +#include <map> +#include <vector> + +#include "base/threading/non_thread_safe.h" +#include "net/base/address_list.h" +#include "net/base/net_errors.h" +#include "net/base/net_export.h" +#include "net/base/net_util.h" +#include "net/base/network_change_notifier.h" +#include "net/dns/address_sorter.h" + +namespace net { + +class ClientSocketFactory; + +// This implementation uses explicit policy to perform the sorting. It is not +// thread-safe and always completes synchronously. +class NET_EXPORT_PRIVATE AddressSorterPosix + : public AddressSorter, + public base::NonThreadSafe, + public NetworkChangeNotifier::IPAddressObserver { + public: + // Generic policy entry. + struct PolicyEntry { + // IPv4 addresses must be mapped to IPv6. + unsigned char prefix[kIPv6AddressSize]; + unsigned prefix_length; + unsigned value; + }; + + typedef std::vector<PolicyEntry> PolicyTable; + + enum AddressScope { + SCOPE_UNDEFINED = 0, + SCOPE_NODELOCAL = 1, + SCOPE_LINKLOCAL = 2, + SCOPE_SITELOCAL = 5, + SCOPE_ORGLOCAL = 8, + SCOPE_GLOBAL = 14, + }; + + struct SourceAddressInfo { + // Values read from policy tables. + AddressScope scope; + unsigned label; + + // Values from the OS, matter only if more than one source address is used. + unsigned prefix_length; + bool deprecated; // vs. preferred RFC4862 + bool home; // vs. care-of RFC6275 + bool native; + }; + + typedef std::map<IPAddressNumber, SourceAddressInfo> SourceAddressMap; + + explicit AddressSorterPosix(ClientSocketFactory* socket_factory); + virtual ~AddressSorterPosix(); + + // AddressSorter: + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE; + + private: + friend class AddressSorterPosixTest; + + // NetworkChangeNotifier::IPAddressObserver: + virtual void OnIPAddressChanged() OVERRIDE; + + // Fills |info| with values for |address| from policy tables. + void FillPolicy(const IPAddressNumber& address, + SourceAddressInfo* info) const; + + // Mutable to allow using default values for source addresses which were not + // found in most recent OnIPAddressChanged. + mutable SourceAddressMap source_map_; + + ClientSocketFactory* socket_factory_; + PolicyTable precedence_table_; + PolicyTable label_table_; + PolicyTable ipv4_scope_table_; + + DISALLOW_COPY_AND_ASSIGN(AddressSorterPosix); +}; + +} // namespace net + +#endif // NET_DNS_ADDRESS_SORTER_POSIX_H_ diff --git a/net/dns/address_sorter_posix_unittest.cc b/net/dns/address_sorter_posix_unittest.cc new file mode 100644 index 0000000..96cbfc6 --- /dev/null +++ b/net/dns/address_sorter_posix_unittest.cc @@ -0,0 +1,325 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/address_sorter_posix.h" + +#include "base/bind.h" +#include "base/logging.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/client_socket_factory.h" +#include "net/udp/datagram_client_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { +namespace { + +// Used to map destination address to source address. +typedef std::map<IPAddressNumber, IPAddressNumber> AddressMapping; + +IPAddressNumber ParseIP(const std::string& str) { + IPAddressNumber addr; + CHECK(ParseIPLiteralToNumber(str, &addr)); + return addr; +} + +// A mock socket which binds to source address according to AddressMapping. +class TestUDPClientSocket : public DatagramClientSocket { + public: + explicit TestUDPClientSocket(const AddressMapping* mapping) + : mapping_(mapping), connected_(false) {} + + virtual ~TestUDPClientSocket() {} + + virtual int Read(IOBuffer*, int, const CompletionCallback&) OVERRIDE { + NOTIMPLEMENTED(); + return OK; + } + virtual int Write(IOBuffer*, int, const CompletionCallback&) OVERRIDE { + NOTIMPLEMENTED(); + return OK; + } + virtual bool SetReceiveBufferSize(int32) OVERRIDE { + return true; + } + virtual bool SetSendBufferSize(int32) OVERRIDE { + return true; + } + + virtual void Close() OVERRIDE {} + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + NOTIMPLEMENTED(); + return OK; + } + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + if (!connected_) + return ERR_UNEXPECTED; + *address = local_endpoint_; + return OK; + } + + virtual int Connect(const IPEndPoint& remote) OVERRIDE { + if (connected_) + return ERR_UNEXPECTED; + AddressMapping::const_iterator it = mapping_->find(remote.address()); + if (it == mapping_->end()) + return ERR_FAILED; + connected_ = true; + local_endpoint_ = IPEndPoint(it->second, 39874 /* arbitrary port */); + return OK; + } + + virtual const BoundNetLog& NetLog() const OVERRIDE { + return net_log_; + } + + private: + BoundNetLog net_log_; + const AddressMapping* mapping_; + bool connected_; + IPEndPoint local_endpoint_; + + DISALLOW_COPY_AND_ASSIGN(TestUDPClientSocket); +}; + +// Creates TestUDPClientSockets and maintains an AddressMapping. +class TestSocketFactory : public ClientSocketFactory { + public: + TestSocketFactory() {} + virtual ~TestSocketFactory() {} + + virtual DatagramClientSocket* CreateDatagramClientSocket( + DatagramSocket::BindType, + const RandIntCallback&, + NetLog*, + const NetLog::Source&) OVERRIDE { + return new TestUDPClientSocket(&mapping_); + } + virtual StreamSocket* CreateTransportClientSocket( + const AddressList&, + NetLog*, + const NetLog::Source&) OVERRIDE { + NOTIMPLEMENTED(); + return NULL; + } + virtual SSLClientSocket* CreateSSLClientSocket( + ClientSocketHandle*, + const HostPortPair&, + const SSLConfig&, + const SSLClientSocketContext&) OVERRIDE { + NOTIMPLEMENTED(); + return NULL; + } + virtual void ClearSSLSessionCache() OVERRIDE { + NOTIMPLEMENTED(); + } + + void AddMapping(const IPAddressNumber& dst, const IPAddressNumber& src) { + mapping_[dst] = src; + } + + private: + AddressMapping mapping_; + + DISALLOW_COPY_AND_ASSIGN(TestSocketFactory); +}; + +void OnSortComplete(AddressList* result_buf, + const CompletionCallback& callback, + bool success, + const AddressList& result) { + EXPECT_TRUE(success); + if (success) + *result_buf = result; + callback.Run(OK); +} + +} // namespace + +class AddressSorterPosixTest : public testing::Test { + protected: + AddressSorterPosixTest() : sorter_(&socket_factory_) {} + + void AddMapping(const std::string& dst, const std::string& src) { + socket_factory_.AddMapping(ParseIP(dst), ParseIP(src)); + } + + AddressSorterPosix::SourceAddressInfo* GetSourceInfo( + const std::string& addr) { + IPAddressNumber address = ParseIP(addr); + AddressSorterPosix::SourceAddressInfo* info = &sorter_.source_map_[address]; + if (info->scope == AddressSorterPosix::SCOPE_UNDEFINED) + sorter_.FillPolicy(address, info); + return info; + } + + // Verify that NULL-terminated |addresses| matches (-1)-terminated |order| + // after sorting. + void Verify(const char* addresses[], const int order[]) { + AddressList list; + for (const char** addr = addresses; *addr != NULL; ++addr) + list.push_back(IPEndPoint(ParseIP(*addr), 80)); + for (size_t i = 0; order[i] >= 0; ++i) + CHECK_LT(order[i], static_cast<int>(list.size())); + + AddressList result; + TestCompletionCallback callback; + sorter_.Sort(list, base::Bind(&OnSortComplete, &result, + callback.callback())); + callback.WaitForResult(); + + for (size_t i = 0; (i < result.size()) || (order[i] >= 0); ++i) { + IPEndPoint expected = order[i] >= 0 ? list[order[i]] : IPEndPoint(); + IPEndPoint actual = i < result.size() ? result[i] : IPEndPoint(); + EXPECT_TRUE(expected.address() == actual.address()) << + "Address out of order at position " << i << "\n" << + " Actual: " << actual.ToStringWithoutPort() << "\n" << + "Expected: " << expected.ToStringWithoutPort(); + } + } + + TestSocketFactory socket_factory_; + AddressSorterPosix sorter_; +}; + +// Rule 1: Avoid unusable destinations. +TEST_F(AddressSorterPosixTest, Rule1) { + AddMapping("10.0.0.231", "10.0.0.1"); + const char* addresses[] = { "::1", "10.0.0.231", "127.0.0.1", NULL }; + const int order[] = { 1, -1 }; + Verify(addresses, order); +} + +// Rule 2: Prefer matching scope. +TEST_F(AddressSorterPosixTest, Rule2) { + AddMapping("3002::1", "4000::10"); // matching global + AddMapping("ff32::1", "fe81::10"); // matching link-local + AddMapping("fec1::1", "fec1::10"); // matching node-local + AddMapping("3002::2", "::1"); // global vs. link-local + AddMapping("fec1::2", "fe81::10"); // site-local vs. link-local + AddMapping("8.0.0.1", "169.254.0.10"); // global vs. link-local + // In all three cases, matching scope is preferred. + const int order[] = { 1, 0, -1 }; + const char* addresses1[] = { "3002::2", "3002::1", NULL }; + Verify(addresses1, order); + const char* addresses2[] = { "fec1::2", "ff32::1", NULL }; + Verify(addresses2, order); + const char* addresses3[] = { "8.0.0.1", "fec1::1", NULL }; + Verify(addresses3, order); +} + +// Rule 3: Avoid deprecated addresses. +TEST_F(AddressSorterPosixTest, Rule3) { + // Matching scope. + AddMapping("3002::1", "4000::10"); + GetSourceInfo("4000::10")->deprecated = true; + AddMapping("3002::2", "4000::20"); + const char* addresses[] = { "3002::1", "3002::2", NULL }; + const int order[] = { 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 4: Prefer home addresses. +TEST_F(AddressSorterPosixTest, Rule4) { + AddMapping("3002::1", "4000::10"); + AddMapping("3002::2", "4000::20"); + GetSourceInfo("4000::20")->home = true; + const char* addresses[] = { "3002::1", "3002::2", NULL }; + const int order[] = { 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 5: Prefer matching label. +TEST_F(AddressSorterPosixTest, Rule5) { + AddMapping("::1", "::1"); // matching loopback + AddMapping("::ffff:1234:1", "::ffff:1234:10"); // matching IPv4-mapped + AddMapping("2001::1", "::ffff:1234:10"); // Teredo vs. IPv4-mapped + AddMapping("2002::1", "2001::10"); // 6to4 vs. Teredo + const int order[] = { 1, 0, -1 }; + { + const char* addresses[] = { "2001::1", "::1", NULL }; + Verify(addresses, order); + } + { + const char* addresses[] = { "2002::1", "::ffff:1234:1", NULL }; + Verify(addresses, order); + } +} + +// Rule 6: Prefer higher precedence. +TEST_F(AddressSorterPosixTest, Rule6) { + AddMapping("::1", "::1"); // loopback + AddMapping("ff32::1", "fe81::10"); // multicast + AddMapping("::ffff:1234:1", "::ffff:1234:10"); // IPv4-mapped + AddMapping("2001::1", "2001::10"); // Teredo + const char* addresses[] = { "2001::1", "::ffff:1234:1", "ff32::1", "::1", + NULL }; + const int order[] = { 3, 2, 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 7: Prefer native transport. +TEST_F(AddressSorterPosixTest, Rule7) { + AddMapping("3002::1", "4000::10"); + AddMapping("3002::2", "4000::20"); + GetSourceInfo("4000::20")->native = true; + const char* addresses[] = { "3002::1", "3002::2", NULL }; + const int order[] = { 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 8: Prefer smaller scope. +TEST_F(AddressSorterPosixTest, Rule8) { + // Matching scope. Should precede the others by Rule 2. + AddMapping("fe81::1", "fe81::10"); // link-local + AddMapping("3000::1", "4000::10"); // global + // Mismatched scope. + AddMapping("ff32::1", "4000::10"); // link-local + AddMapping("ff35::1", "4000::10"); // site-local + AddMapping("ff38::1", "4000::10"); // org-local + const char* addresses[] = { "ff38::1", "3000::1", "ff35::1", "ff32::1", + "fe81::1", NULL }; + const int order[] = { 4, 1, 3, 2, 0, -1 }; + Verify(addresses, order); +} + +// Rule 9: Use longest matching prefix. +TEST_F(AddressSorterPosixTest, Rule9) { + AddMapping("3000::1", "3000:ffff::10"); // 16 bit match + GetSourceInfo("3000:ffff::10")->prefix_length = 16; + AddMapping("4000::1", "4000::10"); // 123 bit match, limited to 15 + GetSourceInfo("4000::10")->prefix_length = 15; + AddMapping("4002::1", "4000::10"); // 14 bit match + AddMapping("4080::1", "4000::10"); // 8 bit match + const char* addresses[] = { "4080::1", "4002::1", "4000::1", "3000::1", + NULL }; + const int order[] = { 3, 2, 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 10: Leave the order unchanged. +TEST_F(AddressSorterPosixTest, Rule10) { + AddMapping("4000::1", "4000::10"); + AddMapping("4000::2", "4000::10"); + AddMapping("4000::3", "4000::10"); + const char* addresses[] = { "4000::1", "4000::2", "4000::3", NULL }; + const int order[] = { 0, 1, 2, -1 }; + Verify(addresses, order); +} + +TEST_F(AddressSorterPosixTest, MultipleRules) { + AddMapping("::1", "::1"); // loopback + AddMapping("ff32::1", "fe81::10"); // link-local multicast + AddMapping("ff3e::1", "4000::10"); // global multicast + AddMapping("4000::1", "4000::10"); // global unicast + AddMapping("ff32::2", "fe81::20"); // deprecated link-local multicast + GetSourceInfo("fe81::20")->deprecated = true; + const char* addresses[] = { "ff3e::1", "ff32::2", "4000::1", "ff32::1", "::1", + "8.0.0.1", NULL }; + const int order[] = { 4, 3, 0, 2, 1, -1 }; + Verify(addresses, order); +} + +} // namespace net diff --git a/net/dns/address_sorter_unittest.cc b/net/dns/address_sorter_unittest.cc new file mode 100644 index 0000000..0c2be88 --- /dev/null +++ b/net/dns/address_sorter_unittest.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/address_sorter.h" + +#if defined(OS_WIN) +#include <winsock2.h> +#endif + +#include "base/bind.h" +#include "base/logging.h" +#include "net/base/address_list.h" +#include "net/base/net_util.h" +#include "net/base/test_completion_callback.h" +#include "testing/gtest/include/gtest/gtest.h" + +#if defined(OS_WIN) +#include "net/base/winsock_init.h" +#endif + +namespace net { +namespace { + +IPEndPoint MakeEndPoint(const std::string& str) { + IPAddressNumber addr; + CHECK(ParseIPLiteralToNumber(str, &addr)); + return IPEndPoint(addr, 0); +} + +void OnSortComplete(AddressList* result_buf, + const CompletionCallback& callback, + bool success, + const AddressList& result) { + if (success) + *result_buf = result; + callback.Run(success ? OK : ERR_FAILED); +} + +TEST(AddressSorterTest, Sort) { + int expected_result = OK; +#if defined(OS_WIN) + EnsureWinsockInit(); + SOCKET sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP); + if (sock == INVALID_SOCKET) { + expected_result = ERR_FAILED; + } else { + closesocket(sock); + } +#endif + scoped_ptr<AddressSorter> sorter(AddressSorter::CreateAddressSorter()); + AddressList list; + list.push_back(MakeEndPoint("10.0.0.1")); + list.push_back(MakeEndPoint("8.8.8.8")); + list.push_back(MakeEndPoint("::1")); + list.push_back(MakeEndPoint("2001:4860:4860::8888")); + + AddressList result; + TestCompletionCallback callback; + sorter->Sort(list, base::Bind(&OnSortComplete, &result, + callback.callback())); + EXPECT_EQ(expected_result, callback.WaitForResult()); +} + +} // namespace +} // namespace net diff --git a/net/dns/address_sorter_win.cc b/net/dns/address_sorter_win.cc new file mode 100644 index 0000000..0f1c504 --- /dev/null +++ b/net/dns/address_sorter_win.cc @@ -0,0 +1,198 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/address_sorter.h" + +#include <winsock2.h> + +#include <algorithm> + +#include "base/bind.h" +#include "base/location.h" +#include "base/logging.h" +#include "base/threading/worker_pool.h" +#include "base/win/windows_version.h" +#include "net/base/address_list.h" +#include "net/base/ip_endpoint.h" +#include "net/base/winsock_init.h" + +namespace net { + +namespace { + +class AddressSorterWin : public AddressSorter { + public: + AddressSorterWin() { + EnsureWinsockInit(); + } + + virtual ~AddressSorterWin() {} + + // AddressSorter: + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE { + DCHECK(!list.empty()); + scoped_refptr<Job> job = new Job(list, callback); + } + + private: + // Executes the SIO_ADDRESS_LIST_SORT ioctl on the WorkerPool, and + // performs the necessary conversions to/from AddressList. + class Job : public base::RefCountedThreadSafe<Job> { + public: + Job(const AddressList& list, const CallbackType& callback) + : callback_(callback), + buffer_size_(sizeof(SOCKET_ADDRESS_LIST) + + list.size() * (sizeof(SOCKET_ADDRESS) + + sizeof(SOCKADDR_STORAGE))), + input_buffer_(reinterpret_cast<SOCKET_ADDRESS_LIST*>( + malloc(buffer_size_))), + output_buffer_(reinterpret_cast<SOCKET_ADDRESS_LIST*>( + malloc(buffer_size_))), + success_(false) { + input_buffer_->iAddressCount = list.size(); + SOCKADDR_STORAGE* storage = reinterpret_cast<SOCKADDR_STORAGE*>( + input_buffer_->Address + input_buffer_->iAddressCount); + + for (size_t i = 0; i < list.size(); ++i) { + IPEndPoint ipe = list[i]; + // Addresses must be sockaddr_in6. + if (ipe.GetFamily() == AF_INET) { + ipe = IPEndPoint(ConvertIPv4NumberToIPv6Number(ipe.address()), + ipe.port()); + } + + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(storage + i); + socklen_t addr_len = sizeof(SOCKADDR_STORAGE); + bool result = ipe.ToSockAddr(addr, &addr_len); + DCHECK(result); + input_buffer_->Address[i].lpSockaddr = addr; + input_buffer_->Address[i].iSockaddrLength = addr_len; + } + + if (!base::WorkerPool::PostTaskAndReply( + FROM_HERE, + base::Bind(&Job::Run, this), + base::Bind(&Job::OnComplete, this), + false /* task is slow */)) { + LOG(ERROR) << "WorkerPool::PostTaskAndReply failed"; + OnComplete(); + } + } + + private: + friend class base::RefCountedThreadSafe<Job>; + ~Job() {} + + // Executed on the WorkerPool. + void Run() { + SOCKET sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP); + if (sock == INVALID_SOCKET) + return; + DWORD result_size = 0; + int result = WSAIoctl(sock, SIO_ADDRESS_LIST_SORT, input_buffer_.get(), + buffer_size_, output_buffer_.get(), buffer_size_, + &result_size, NULL, NULL); + if (result == SOCKET_ERROR) { + LOG(ERROR) << "SIO_ADDRESS_LIST_SORT failed " << WSAGetLastError(); + } else { + success_ = true; + } + closesocket(sock); + } + + // Executed on the calling thread. + void OnComplete() { + AddressList list; + if (success_) { + list.reserve(output_buffer_->iAddressCount); + for (int i = 0; i < output_buffer_->iAddressCount; ++i) { + IPEndPoint ipe; + ipe.FromSockAddr(output_buffer_->Address[i].lpSockaddr, + output_buffer_->Address[i].iSockaddrLength); + // Unmap V4MAPPED IPv6 addresses so that Happy Eyeballs works. + if (IsIPv4Mapped(ipe.address())) { + ipe = IPEndPoint(ConvertIPv4MappedToIPv4(ipe.address()), + ipe.port()); + } + list.push_back(ipe); + } + } + callback_.Run(success_, list); + } + + const CallbackType callback_; + const size_t buffer_size_; + scoped_ptr_malloc<SOCKET_ADDRESS_LIST> input_buffer_; + scoped_ptr_malloc<SOCKET_ADDRESS_LIST> output_buffer_; + bool success_; + + DISALLOW_COPY_AND_ASSIGN(Job); + }; + + DISALLOW_COPY_AND_ASSIGN(AddressSorterWin); +}; + +// Merges |list_ipv4| and |list_ipv6| before passing it to |callback|, but +// only if |success| is true. +void MergeResults(const AddressSorter::CallbackType& callback, + const AddressList& list_ipv4, + bool success, + const AddressList& list_ipv6) { + if (!success) { + callback.Run(false, AddressList()); + return; + } + AddressList list; + list.insert(list.end(), list_ipv6.begin(), list_ipv6.end()); + list.insert(list.end(), list_ipv4.begin(), list_ipv4.end()); + callback.Run(true, list); +} + +// Wrapper for AddressSorterWin which does not sort IPv4 or IPv4-mapped +// addresses but always puts them at the end of the list. Needed because the +// SIO_ADDRESS_LIST_SORT does not support IPv4 addresses on Windows XP. +class AddressSorterWinXP : public AddressSorter { + public: + AddressSorterWinXP() {} + virtual ~AddressSorterWinXP() {} + + // AddressSorter: + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE { + AddressList list_ipv4; + AddressList list_ipv6; + for (size_t i = 0; i < list.size(); ++i) { + const IPEndPoint& ipe = list[i]; + if (ipe.GetFamily() == AF_INET) { + list_ipv4.push_back(ipe); + } else { + list_ipv6.push_back(ipe); + } + } + if (!list_ipv6.empty()) { + sorter_.Sort(list_ipv6, base::Bind(&MergeResults, callback, list_ipv4)); + } else { + NOTREACHED() << "Should not be called with IPv4-only addresses."; + callback.Run(true, list); + } + } + + private: + AddressSorterWin sorter_; + + DISALLOW_COPY_AND_ASSIGN(AddressSorterWinXP); +}; + +} // namespace + +// static +scoped_ptr<AddressSorter> AddressSorter::CreateAddressSorter() { + if (base::win::GetVersion() < base::win::VERSION_VISTA) + return scoped_ptr<AddressSorter>(new AddressSorterWinXP()); + return scoped_ptr<AddressSorter>(new AddressSorterWin()); +} + +} // namespace net + diff --git a/net/dns/dns_client.cc b/net/dns/dns_client.cc index 5381452..859fc27 100644 --- a/net/dns/dns_client.cc +++ b/net/dns/dns_client.cc @@ -7,6 +7,7 @@ #include "base/bind.h" #include "base/rand_util.h" #include "net/base/net_log.h" +#include "net/dns/address_sorter.h" #include "net/dns/dns_config_service.h" #include "net/dns/dns_session.h" #include "net/dns/dns_transaction.h" @@ -18,7 +19,9 @@ namespace { class DnsClientImpl : public DnsClient { public: - explicit DnsClientImpl(NetLog* net_log) : net_log_(net_log) {} + explicit DnsClientImpl(NetLog* net_log) + : address_sorter_(AddressSorter::CreateAddressSorter()), + net_log_(net_log) {} virtual void SetConfig(const DnsConfig& config) OVERRIDE { factory_.reset(); @@ -40,9 +43,14 @@ class DnsClientImpl : public DnsClient { return session_.get() ? factory_.get() : NULL; } + virtual AddressSorter* GetAddressSorter() OVERRIDE { + return address_sorter_.get(); + } + private: scoped_refptr<DnsSession> session_; scoped_ptr<DnsTransactionFactory> factory_; + scoped_ptr<AddressSorter> address_sorter_; NetLog* net_log_; }; diff --git a/net/dns/dns_client.h b/net/dns/dns_client.h index 13aa0bf8..650c7d0 100644 --- a/net/dns/dns_client.h +++ b/net/dns/dns_client.h @@ -10,12 +10,14 @@ namespace net { +class AddressSorter; struct DnsConfig; class DnsTransactionFactory; class NetLog; -// Convenience wrapper allows easy injection of DnsTransaction into -// HostResolverImpl. +// Convenience wrapper which allows easy injection of DnsTransaction into +// HostResolverImpl. Pointers returned by the Get* methods are only guaranteed +// to remain valid until next time SetConfig is called. class NET_EXPORT DnsClient { public: virtual ~DnsClient() {} @@ -29,6 +31,9 @@ class NET_EXPORT DnsClient { // Returns NULL if the current config is not valid. virtual DnsTransactionFactory* GetTransactionFactory() = 0; + // Returns NULL if the current config is not valid. + virtual AddressSorter* GetAddressSorter() = 0; + // Creates default client. static scoped_ptr<DnsClient> CreateClient(NetLog* net_log); }; diff --git a/net/dns/dns_response.cc b/net/dns/dns_response.cc index 4ad6465..4ab7296 100644 --- a/net/dns/dns_response.cc +++ b/net/dns/dns_response.cc @@ -280,15 +280,13 @@ DnsResponse::Result DnsResponse::ParseToAddressList( } // TODO(szym): Extract TTL for NODATA results. http://crbug.com/115051 - if (ip_addresses.empty()) - return DNS_NO_ADDRESSES; // getcanonname in eglibc returns the first owner name of an A or AAAA RR. // If the response passed all the checks so far, then |expected_name| is it. *addr_list = AddressList::CreateFromIPAddressList(ip_addresses, expected_name); *ttl = base::TimeDelta::FromSeconds(std::min(cname_ttl_sec, addr_ttl_sec)); - return DNS_SUCCESS; + return DNS_PARSE_OK; } } // namespace net diff --git a/net/dns/dns_response.h b/net/dns/dns_response.h index 5ade458..fe80c30 100644 --- a/net/dns/dns_response.h +++ b/net/dns/dns_response.h @@ -81,7 +81,7 @@ class NET_EXPORT_PRIVATE DnsResponse { public: // Possible results from ParseToAddressList. enum Result { - DNS_SUCCESS = 0, + DNS_PARSE_OK = 0, DNS_MALFORMED_RESPONSE, // DnsRecordParser failed before the end of // packet. DNS_MALFORMED_CNAME, // Could not parse CNAME out of RRDATA. @@ -90,7 +90,7 @@ class NET_EXPORT_PRIVATE DnsResponse { DNS_SIZE_MISMATCH, // Got an address but size does not match. DNS_CNAME_AFTER_ADDRESS, // Found CNAME after an address record. DNS_ADDRESS_TTL_MISMATCH, // TTL of all address records are not identical. - DNS_NO_ADDRESSES, // No address records found. + DNS_NO_ADDRESSES, // OBSOLETE. No longer used. // Only add new values here. DNS_PARSE_RESULT_MAX, // Bounding value for histograms. }; diff --git a/net/dns/dns_response_unittest.cc b/net/dns/dns_response_unittest.cc index bb053bc..6ed6f06 100644 --- a/net/dns/dns_response_unittest.cc +++ b/net/dns/dns_response_unittest.cc @@ -299,7 +299,7 @@ TEST(DnsResponseTest, ParseToAddressList) { DnsResponse response(t.response_data, t.response_size, t.query_size); AddressList addr_list; base::TimeDelta ttl; - EXPECT_EQ(DnsResponse::DNS_SUCCESS, + EXPECT_EQ(DnsResponse::DNS_PARSE_OK, response.ParseToAddressList(&addr_list, &ttl)); std::vector<const char*> expected_addresses( t.expected_addresses, @@ -429,8 +429,9 @@ TEST(DnsResponseTest, ParseToAddressListFail) { DnsResponse::DNS_CNAME_AFTER_ADDRESS }, { kResponseTTLMismatch, arraysize(kResponseTTLMismatch), DnsResponse::DNS_ADDRESS_TTL_MISMATCH }, + // Not actually a failure, just an empty result. { kResponseNoAddresses, arraysize(kResponseNoAddresses), - DnsResponse::DNS_NO_ADDRESSES }, + DnsResponse::DNS_PARSE_OK }, }; const size_t kQuerySize = 12 + 7; diff --git a/net/dns/dns_test_util.cc b/net/dns/dns_test_util.cc index 051f595..73d208c 100644 --- a/net/dns/dns_test_util.cc +++ b/net/dns/dns_test_util.cc @@ -14,6 +14,7 @@ #include "net/base/dns_util.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" +#include "net/dns/address_sorter.h" #include "net/dns/dns_client.h" #include "net/dns/dns_config_service.h" #include "net/dns/dns_protocol.h" @@ -25,19 +26,29 @@ namespace net { namespace { -// A DnsTransaction which responds with loopback to all queries starting with -// "ok", fails synchronously on all queries starting with "er", and NXDOMAIN to -// all others. +// A DnsTransaction which uses MockDnsClientRuleList to determine the response. class MockTransaction : public DnsTransaction, public base::SupportsWeakPtr<MockTransaction> { public: - MockTransaction(const std::string& hostname, + MockTransaction(const MockDnsClientRuleList& rules, + const std::string& hostname, uint16 qtype, const DnsTransactionFactory::CallbackType& callback) - : hostname_(hostname), + : result_(MockDnsClientRule::FAIL_SYNC), + hostname_(hostname), qtype_(qtype), callback_(callback), started_(false) { + // Find the relevant rule which matches |qtype| and prefix of |hostname|. + for (size_t i = 0; i < rules.size(); ++i) { + const std::string& prefix = rules[i].prefix; + if ((rules[i].qtype == qtype) && + (hostname.size() >= prefix.size()) && + (hostname.compare(0, prefix.size(), prefix) == 0)) { + result_ = rules[i].result; + break; + } + } } virtual const std::string& GetHostname() const OVERRIDE { @@ -51,7 +62,7 @@ class MockTransaction : public DnsTransaction, virtual int Start() OVERRIDE { EXPECT_FALSE(started_); started_ = true; - if (hostname_.substr(0, 2) == "er") + if (MockDnsClientRule::FAIL_SYNC == result_) return ERR_NAME_NOT_RESOLVED; // Using WeakPtr to cleanly cancel when transaction is destroyed. MessageLoop::current()->PostTask( @@ -62,54 +73,66 @@ class MockTransaction : public DnsTransaction, private: void Finish() { - if (hostname_.substr(0, 2) == "ok") { - std::string qname; - DNSDomainFromDot(hostname_, &qname); - DnsQuery query(0, qname, qtype_); - - DnsResponse response; - char* buffer = response.io_buffer()->data(); - int nbytes = query.io_buffer()->size(); - memcpy(buffer, query.io_buffer()->data(), nbytes); - - const uint16 kPointerToQueryName = - static_cast<uint16>(0xc000 | sizeof(net::dns_protocol::Header)); - - const uint32 kTTL = 86400; // One day. - - // Size of RDATA which is a IPv4 or IPv6 address. - size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ? - net::kIPv4AddressSize : net::kIPv6AddressSize; - - // 12 is the sum of sizes of the compressed name reference, TYPE, - // CLASS, TTL and RDLENGTH. - size_t answer_size = 12 + rdata_size; - - // Write answer with loopback IP address. - reinterpret_cast<dns_protocol::Header*>(buffer)->ancount = - base::HostToNet16(1); - BigEndianWriter writer(buffer + nbytes, answer_size); - writer.WriteU16(kPointerToQueryName); - writer.WriteU16(qtype_); - writer.WriteU16(net::dns_protocol::kClassIN); - writer.WriteU32(kTTL); - writer.WriteU16(rdata_size); - if (qtype_ == net::dns_protocol::kTypeA) { - char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; - writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); - } else { - char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1 }; - writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); - } - - EXPECT_TRUE(response.InitParse(nbytes + answer_size, query)); - callback_.Run(this, OK, &response); - } else { - callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); + switch (result_) { + case MockDnsClientRule::EMPTY: + case MockDnsClientRule::OK: { + std::string qname; + DNSDomainFromDot(hostname_, &qname); + DnsQuery query(0, qname, qtype_); + + DnsResponse response; + char* buffer = response.io_buffer()->data(); + int nbytes = query.io_buffer()->size(); + memcpy(buffer, query.io_buffer()->data(), nbytes); + dns_protocol::Header* header = + reinterpret_cast<dns_protocol::Header*>(buffer); + header->flags |= dns_protocol::kFlagResponse; + + if (MockDnsClientRule::OK == result_) { + const uint16 kPointerToQueryName = + static_cast<uint16>(0xc000 | sizeof(*header)); + + const uint32 kTTL = 86400; // One day. + + // Size of RDATA which is a IPv4 or IPv6 address. + size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ? + net::kIPv4AddressSize : net::kIPv6AddressSize; + + // 12 is the sum of sizes of the compressed name reference, TYPE, + // CLASS, TTL and RDLENGTH. + size_t answer_size = 12 + rdata_size; + + // Write answer with loopback IP address. + header->ancount = base::HostToNet16(1); + BigEndianWriter writer(buffer + nbytes, answer_size); + writer.WriteU16(kPointerToQueryName); + writer.WriteU16(qtype_); + writer.WriteU16(net::dns_protocol::kClassIN); + writer.WriteU32(kTTL); + writer.WriteU16(rdata_size); + if (qtype_ == net::dns_protocol::kTypeA) { + char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; + writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); + } else { + char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1 }; + writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); + } + nbytes += answer_size; + } + EXPECT_TRUE(response.InitParse(nbytes, query)); + callback_.Run(this, OK, &response); + } break; + case MockDnsClientRule::FAIL_ASYNC: + callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); + break; + default: + NOTREACHED(); + break; } } + MockDnsClientRule::Result result_; const std::string hostname_; const uint16 qtype_; DnsTransactionFactory::CallbackType callback_; @@ -120,7 +143,8 @@ class MockTransaction : public DnsTransaction, // A DnsTransactionFactory which creates MockTransaction. class MockTransactionFactory : public DnsTransactionFactory { public: - MockTransactionFactory() {} + explicit MockTransactionFactory(const MockDnsClientRuleList& rules) + : rules_(rules) {} virtual ~MockTransactionFactory() {} virtual scoped_ptr<DnsTransaction> CreateTransaction( @@ -129,14 +153,29 @@ class MockTransactionFactory : public DnsTransactionFactory { const DnsTransactionFactory::CallbackType& callback, const BoundNetLog&) OVERRIDE { return scoped_ptr<DnsTransaction>( - new MockTransaction(hostname, qtype, callback)); + new MockTransaction(rules_, hostname, qtype, callback)); + } + + private: + MockDnsClientRuleList rules_; +}; + +class MockAddressSorter : public AddressSorter { + public: + virtual ~MockAddressSorter() {} + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE { + // Do nothing. + callback.Run(true, list); } }; // MockDnsClient provides MockTransactionFactory. class MockDnsClient : public DnsClient { public: - explicit MockDnsClient(const DnsConfig& config) : config_(config) {} + MockDnsClient(const DnsConfig& config, + const MockDnsClientRuleList& rules) + : config_(config), factory_(rules) {} virtual ~MockDnsClient() {} virtual void SetConfig(const DnsConfig& config) OVERRIDE { @@ -151,16 +190,22 @@ class MockDnsClient : public DnsClient { return config_.IsValid() ? &factory_ : NULL; } + virtual AddressSorter* GetAddressSorter() OVERRIDE { + return &address_sorter_; + } + private: DnsConfig config_; MockTransactionFactory factory_; + MockAddressSorter address_sorter_; }; } // namespace // static -scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config) { - return scoped_ptr<DnsClient>(new MockDnsClient(config)); +scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config, + const MockDnsClientRuleList& rules) { + return scoped_ptr<DnsClient>(new MockDnsClient(config, rules)); } MockDnsConfigService::~MockDnsConfigService() { diff --git a/net/dns/dns_test_util.h b/net/dns/dns_test_util.h index 53a0620..2d03d0e 100644 --- a/net/dns/dns_test_util.h +++ b/net/dns/dns_test_util.h @@ -5,6 +5,9 @@ #ifndef NET_DNS_DNS_TEST_UTIL_H_ #define NET_DNS_DNS_TEST_UTIL_H_ +#include <string> +#include <vector> + #include "base/basictypes.h" #include "base/memory/scoped_ptr.h" #include "net/dns/dns_config_service.h" @@ -166,8 +169,25 @@ static const int kT3TTL = 0x00000015; static const unsigned kT3RecordCount = arraysize(kT3IpAddresses) + 2; class DnsClient; + +struct MockDnsClientRule { + enum Result { + FAIL_SYNC, // Fail synchronously with ERR_NAME_NOT_RESOLVED. + FAIL_ASYNC, // Fail asynchronously with ERR_NAME_NOT_RESOLVED. + EMPTY, // Return an empty response. + OK, // Return a response with loopback address. + }; + + std::string prefix; + uint16 qtype; + Result result; +}; + +typedef std::vector<MockDnsClientRule> MockDnsClientRuleList; + // Creates mock DnsClient for testing HostResolverImpl. -scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config); +scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config, + const MockDnsClientRuleList& rules); class MockDnsConfigService : public DnsConfigService { public: |