diff options
author | szym@chromium.org <szym@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2012-08-14 22:36:36 +0000 |
---|---|---|
committer | szym@chromium.org <szym@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2012-08-14 22:36:36 +0000 |
commit | 6c41190e39d8ecdc0dbcc50bd87125f39c39a101 (patch) | |
tree | 80be44493d34177ca8eb0e1c0f0fe00eaeac1845 /net | |
parent | 5d4b7adcca413109257be1a72a18d883480c3cc0 (diff) | |
download | chromium_src-6c41190e39d8ecdc0dbcc50bd87125f39c39a101.zip chromium_src-6c41190e39d8ecdc0dbcc50bd87125f39c39a101.tar.gz chromium_src-6c41190e39d8ecdc0dbcc50bd87125f39c39a101.tar.bz2 |
[net/dns] Resolve AF_UNSPEC on dual-stacked systems. Sort addresses according to RFC3484.
BUG=113993
TEST=./net_unittests --gtest_filter=AddressSorter*:HostResolverImplDnsTest.DnsTaskUnspec
Review URL: https://chromiumcodereview.appspot.com/10442098
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@151586 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
-rw-r--r-- | net/base/host_resolver_impl.cc | 162 | ||||
-rw-r--r-- | net/base/host_resolver_impl.h | 9 | ||||
-rw-r--r-- | net/base/host_resolver_impl_unittest.cc | 58 | ||||
-rw-r--r-- | net/base/net_error_list.h | 3 | ||||
-rw-r--r-- | net/base/net_util.cc | 26 | ||||
-rw-r--r-- | net/base/net_util.h | 8 | ||||
-rw-r--r-- | net/base/net_util_unittest.cc | 23 | ||||
-rw-r--r-- | net/dns/address_sorter.h | 46 | ||||
-rw-r--r-- | net/dns/address_sorter_posix.cc | 428 | ||||
-rw-r--r-- | net/dns/address_sorter_posix.h | 94 | ||||
-rw-r--r-- | net/dns/address_sorter_posix_unittest.cc | 325 | ||||
-rw-r--r-- | net/dns/address_sorter_unittest.cc | 49 | ||||
-rw-r--r-- | net/dns/address_sorter_win.cc | 197 | ||||
-rw-r--r-- | net/dns/dns_client.cc | 10 | ||||
-rw-r--r-- | net/dns/dns_client.h | 9 | ||||
-rw-r--r-- | net/dns/dns_response.cc | 4 | ||||
-rw-r--r-- | net/dns/dns_response.h | 4 | ||||
-rw-r--r-- | net/dns/dns_response_unittest.cc | 5 | ||||
-rw-r--r-- | net/dns/dns_test_util.cc | 157 | ||||
-rw-r--r-- | net/dns/dns_test_util.h | 22 | ||||
-rw-r--r-- | net/net.gyp | 6 | ||||
-rw-r--r-- | net/tools/dns_fuzz_stub/dns_fuzz_stub.cc | 2 |
22 files changed, 1536 insertions, 111 deletions
diff --git a/net/base/host_resolver_impl.cc b/net/base/host_resolver_impl.cc index 68b6bd5..8b42a22 100644 --- a/net/base/host_resolver_impl.cc +++ b/net/base/host_resolver_impl.cc @@ -38,6 +38,7 @@ #include "net/base/net_errors.h" #include "net/base/net_log.h" #include "net/base/net_util.h" +#include "net/dns/address_sorter.h" #include "net/dns/dns_client.h" #include "net/dns/dns_config_service.h" #include "net/dns/dns_protocol.h" @@ -609,7 +610,7 @@ class HostResolverImpl::ProcTask void Cancel() { DCHECK(origin_loop_->BelongsToCurrentThread()); - if (was_canceled()) + if (was_canceled() || was_completed()) return; callback_.Reset(); @@ -1042,32 +1043,33 @@ class HostResolverImpl::IPv6ProbeJob // Resolves the hostname using DnsTransaction. // TODO(szym): This could be moved to separate source file as well. -class HostResolverImpl::DnsTask { +class HostResolverImpl::DnsTask : public base::SupportsWeakPtr<DnsTask> { public: typedef base::Callback<void(int net_error, const AddressList& addr_list, base::TimeDelta ttl)> Callback; - DnsTask(DnsTransactionFactory* factory, + DnsTask(DnsClient* client, const Key& key, const Callback& callback, const BoundNetLog& job_net_log) - : callback_(callback), net_log_(job_net_log) { - DCHECK(factory); + : client_(client), + family_(key.address_family), + callback_(callback), + net_log_(job_net_log) { + DCHECK(client); DCHECK(!callback.is_null()); - // For now we treat ADDRESS_FAMILY_UNSPEC as if it was IPV4. - uint16 qtype = (key.address_family == ADDRESS_FAMILY_IPV6) - ? dns_protocol::kTypeAAAA - : dns_protocol::kTypeA; - // TODO(szym): Implement "happy eyeballs". - transaction_ = factory->CreateTransaction( + // If unspecified, do IPv4 first, because suffix search will be faster. + uint16 qtype = (family_ == ADDRESS_FAMILY_IPV6) ? + dns_protocol::kTypeAAAA : + dns_protocol::kTypeA; + transaction_ = client_->GetTransactionFactory()->CreateTransaction( key.hostname, qtype, base::Bind(&DnsTask::OnTransactionComplete, base::Unretained(this), - base::TimeTicks::Now()), + true /* first_query */, base::TimeTicks::Now()), net_log_); - DCHECK(transaction_.get()); } int Start() { @@ -1075,47 +1077,138 @@ class HostResolverImpl::DnsTask { return transaction_->Start(); } - void OnTransactionComplete(const base::TimeTicks& start_time, + private: + void OnTransactionComplete(bool first_query, + const base::TimeTicks& start_time, DnsTransaction* transaction, int net_error, const DnsResponse* response) { DCHECK(transaction); // Run |callback_| last since the owning Job will then delete this DnsTask. - DnsResponse::Result result = DnsResponse::DNS_SUCCESS; - if (net_error == OK) { - CHECK(response); - DNS_HISTOGRAM("AsyncDNS.TransactionSuccess", + if (net_error != OK) { + DNS_HISTOGRAM("AsyncDNS.TransactionFailure", base::TimeTicks::Now() - start_time); - AddressList addr_list; - base::TimeDelta ttl; - result = response->ParseToAddressList(&addr_list, &ttl); - UMA_HISTOGRAM_ENUMERATION("AsyncDNS.ParseToAddressList", - result, - DnsResponse::DNS_PARSE_RESULT_MAX); - if (result == DnsResponse::DNS_SUCCESS) { - net_log_.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_DNS_TASK, - addr_list.CreateNetLogCallback()); - callback_.Run(net_error, addr_list, ttl); + OnFailure(net_error, DnsResponse::DNS_PARSE_OK); + return; + } + + CHECK(response); + DNS_HISTOGRAM("AsyncDNS.TransactionSuccess", + base::TimeTicks::Now() - start_time); + AddressList addr_list; + base::TimeDelta ttl; + DnsResponse::Result result = response->ParseToAddressList(&addr_list, &ttl); + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.ParseToAddressList", + result, + DnsResponse::DNS_PARSE_RESULT_MAX); + if (result != DnsResponse::DNS_PARSE_OK) { + // Fail even if the other query succeeds. + OnFailure(ERR_DNS_MALFORMED_RESPONSE, result); + return; + } + + bool needs_sort = false; + if (first_query) { + DCHECK(client_->GetConfig()) << + "Transaction should have been aborted when config changed!"; + if (family_ == ADDRESS_FAMILY_IPV6) { + needs_sort = (addr_list.size() > 1); + } else if (family_ == ADDRESS_FAMILY_UNSPECIFIED) { + first_addr_list_ = addr_list; + first_ttl_ = ttl; + // Use fully-qualified domain name to avoid search. + transaction_ = client_->GetTransactionFactory()->CreateTransaction( + response->GetDottedName() + ".", + dns_protocol::kTypeAAAA, + base::Bind(&DnsTask::OnTransactionComplete, base::Unretained(this), + false /* first_query */, base::TimeTicks::Now()), + net_log_); + net_error = transaction_->Start(); + if (net_error != ERR_IO_PENDING) + OnFailure(net_error, DnsResponse::DNS_PARSE_OK); return; } - net_error = ERR_DNS_MALFORMED_RESPONSE; } else { - DNS_HISTOGRAM("AsyncDNS.TransactionFailure", + DCHECK_EQ(ADDRESS_FAMILY_UNSPECIFIED, family_); + bool has_ipv6_addresses = !addr_list.empty(); + if (!first_addr_list_.empty()) { + ttl = std::min(ttl, first_ttl_); + // Place IPv4 addresses after IPv6. + addr_list.insert(addr_list.end(), first_addr_list_.begin(), + first_addr_list_.end()); + } + needs_sort = (has_ipv6_addresses && addr_list.size() > 1); + } + + if (addr_list.empty()) { + // TODO(szym): Don't fallback to ProcTask in this case. + OnFailure(ERR_NAME_NOT_RESOLVED, DnsResponse::DNS_PARSE_OK); + return; + } + + if (needs_sort) { + // Sort could complete synchronously. + client_->GetAddressSorter()->Sort( + addr_list, + base::Bind(&DnsTask::OnSortComplete, AsWeakPtr(), + base::TimeTicks::Now(), + ttl)); + } else { + OnSuccess(addr_list, ttl); + } + } + + void OnSortComplete(base::TimeTicks start_time, + base::TimeDelta ttl, + bool success, + const AddressList& addr_list) { + if (!success) { + DNS_HISTOGRAM("AsyncDNS.SortFailure", base::TimeTicks::Now() - start_time); + OnFailure(ERR_DNS_SORT_ERROR, DnsResponse::DNS_PARSE_OK); + return; + } + + DNS_HISTOGRAM("AsyncDNS.SortSuccess", + base::TimeTicks::Now() - start_time); + + // AddressSorter prunes unusable destinations. + if (addr_list.empty()) { + LOG(WARNING) << "Address list empty after RFC3484 sort"; + OnFailure(ERR_NAME_NOT_RESOLVED, DnsResponse::DNS_PARSE_OK); + return; } + + OnSuccess(addr_list, ttl); + } + + void OnFailure(int net_error, DnsResponse::Result result) { + DCHECK_NE(OK, net_error); net_log_.EndEvent( NetLog::TYPE_HOST_RESOLVER_IMPL_DNS_TASK, base::Bind(&NetLogDnsTaskFailedCallback, net_error, result)); callback_.Run(net_error, AddressList(), base::TimeDelta()); } - private: + void OnSuccess(const AddressList& addr_list, base::TimeDelta ttl) { + net_log_.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_DNS_TASK, + addr_list.CreateNetLogCallback()); + callback_.Run(OK, addr_list, ttl); + } + + DnsClient* client_; + AddressFamily family_; // The listener to the results of this DnsTask. Callback callback_; - const BoundNetLog net_log_; scoped_ptr<DnsTransaction> transaction_; + + // Results from the first transaction. Used only if |family_| is unspecified. + AddressList first_addr_list_; + base::TimeDelta first_ttl_; + + DISALLOW_COPY_AND_ASSIGN(DnsTask); }; //----------------------------------------------------------------------------- @@ -1214,7 +1307,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { } // Marks |req| as cancelled. If it was the last active Request, also finishes - // this Job marking it either as aborted or cancelled, and deletes it. + // this Job, marking it as cancelled, and deletes it. void CancelRequest(Request* req) { DCHECK_EQ(key_.hostname, req->info().hostname()); DCHECK(!req->was_canceled()); @@ -1381,7 +1474,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { void StartDnsTask() { DCHECK(resolver_->HaveDnsConfig()); dns_task_.reset(new DnsTask( - resolver_->dns_client_->GetTransactionFactory(), + resolver_->dns_client_.get(), key_, base::Bind(&Job::OnDnsTaskComplete, base::Unretained(this)), net_log_)); @@ -1415,6 +1508,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { } UmaAsyncDnsResolveStatus(RESOLVE_STATUS_DNS_SUCCESS); + CompleteRequests(net_error, addr_list, ttl); } diff --git a/net/base/host_resolver_impl.h b/net/base/host_resolver_impl.h index 3c1b5d2..7044be3 100644 --- a/net/base/host_resolver_impl.h +++ b/net/base/host_resolver_impl.h @@ -19,14 +19,17 @@ #include "net/base/host_resolver.h" #include "net/base/host_resolver_proc.h" #include "net/base/net_export.h" -#include "net/base/net_log.h" #include "net/base/network_change_notifier.h" #include "net/base/prioritized_dispatcher.h" -#include "net/dns/dns_client.h" -#include "net/dns/dns_config_service.h" namespace net { +class BoundNetLog; +class DnsClient; +struct DnsConfig; +class DnsConfigService; +class NetLog; + // For each hostname that is requested, HostResolver creates a // HostResolverImpl::Job. When this job gets dispatched it creates a ProcTask // which runs the given HostResolverProc on a WorkerPool thread. If requests for diff --git a/net/base/host_resolver_impl_unittest.cc b/net/base/host_resolver_impl_unittest.cc index 91236b3..6571c46 100644 --- a/net/base/host_resolver_impl_unittest.cc +++ b/net/base/host_resolver_impl_unittest.cc @@ -1236,21 +1236,46 @@ DnsConfig CreateValidDnsConfig() { class HostResolverImplDnsTest : public HostResolverImplTest { protected: virtual void SetUp() OVERRIDE { + AddDnsRule("er", dns_protocol::kTypeA, MockDnsClientRule::FAIL_SYNC); + AddDnsRule("er", dns_protocol::kTypeAAAA, MockDnsClientRule::FAIL_SYNC); + AddDnsRule("nx", dns_protocol::kTypeA, MockDnsClientRule::FAIL_ASYNC); + AddDnsRule("nx", dns_protocol::kTypeAAAA, MockDnsClientRule::FAIL_ASYNC); + AddDnsRule("ok", dns_protocol::kTypeA, MockDnsClientRule::OK); + AddDnsRule("ok", dns_protocol::kTypeAAAA, MockDnsClientRule::OK); + AddDnsRule("4ok", dns_protocol::kTypeA, MockDnsClientRule::OK); + AddDnsRule("4ok", dns_protocol::kTypeAAAA, MockDnsClientRule::EMPTY); + AddDnsRule("6ok", dns_protocol::kTypeA, MockDnsClientRule::EMPTY); + AddDnsRule("6ok", dns_protocol::kTypeAAAA, MockDnsClientRule::OK); + AddDnsRule("4nx", dns_protocol::kTypeA, MockDnsClientRule::OK); + AddDnsRule("4nx", dns_protocol::kTypeAAAA, MockDnsClientRule::FAIL_ASYNC); + CreateResolver(); + } + + void CreateResolver() { config_service_ = new MockDnsConfigService(); resolver_.reset(new HostResolverImpl( HostCache::CreateDefaultCache(), DefaultLimits(), DefaultParams(proc_), scoped_ptr<DnsConfigService>(config_service_), - CreateMockDnsClient(DnsConfig()), + CreateMockDnsClient(DnsConfig(), dns_rules_), NULL)); } + // Adds a rule to |dns_rules_|. Must be followed by |CreateResolver| to apply. + void AddDnsRule(const std::string& prefix, + uint16 qtype, + MockDnsClientRule::Result result) { + MockDnsClientRule rule = { prefix, qtype, result }; + dns_rules_.push_back(rule); + } + void ChangeDnsConfig(const DnsConfig& config) { config_service_->ChangeConfig(config); config_service_->ChangeHosts(config.hosts); } + MockDnsClientRuleList dns_rules_; // Owned by |resolver_|. MockDnsConfigService* config_service_; }; @@ -1298,11 +1323,38 @@ TEST_F(HostResolverImplDnsTest, DnsTask) { EXPECT_TRUE(requests_[5]->HasOneAddress("192.168.1.102", 80)); } +TEST_F(HostResolverImplDnsTest, DnsTaskUnspec) { + ChangeDnsConfig(CreateValidDnsConfig()); + + proc_->AddRuleForAllFamilies("4nx", "192.168.1.101"); + // All other hostnames will fail in proc_. + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("4ok", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("6ok", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("4nx", 80)->Resolve()); + + proc_->SignalMultiple(requests_.size()); + + for (size_t i = 0; i < requests_.size(); ++i) + EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i; + + EXPECT_EQ(2u, requests_[0]->NumberOfAddresses()); + EXPECT_TRUE(requests_[0]->HasAddress("127.0.0.1", 80)); + EXPECT_TRUE(requests_[0]->HasAddress("::1", 80)); + EXPECT_EQ(1u, requests_[1]->NumberOfAddresses()); + EXPECT_TRUE(requests_[1]->HasAddress("127.0.0.1", 80)); + EXPECT_EQ(1u, requests_[2]->NumberOfAddresses()); + EXPECT_TRUE(requests_[2]->HasAddress("::1", 80)); + EXPECT_EQ(1u, requests_[3]->NumberOfAddresses()); + EXPECT_TRUE(requests_[3]->HasAddress("192.168.1.101", 80)); +} + TEST_F(HostResolverImplDnsTest, ServeFromHosts) { // Initially, use empty HOSTS file. ChangeDnsConfig(CreateValidDnsConfig()); - proc_->AddRuleForAllFamilies("", "0.0.0.0"); // Default to failures. + proc_->AddRuleForAllFamilies("", ""); // Default to failures. proc_->SignalMultiple(1u); // For the first request which misses. Request* req0 = CreateRequest("er_ipv4", 80); @@ -1353,7 +1405,7 @@ TEST_F(HostResolverImplDnsTest, ServeFromHosts) { TEST_F(HostResolverImplDnsTest, BypassDnsTask) { ChangeDnsConfig(CreateValidDnsConfig()); - proc_->AddRuleForAllFamilies("", "0.0.0.0"); // Default to failures. + proc_->AddRuleForAllFamilies("", ""); // Default to failures. EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok.local", 80)->Resolve()); EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok.local.", 80)->Resolve()); diff --git a/net/base/net_error_list.h b/net/base/net_error_list.h index 82745fe..63ec73a 100644 --- a/net/base/net_error_list.h +++ b/net/base/net_error_list.h @@ -658,3 +658,6 @@ NET_ERROR(DNS_CACHE_MISS, -804) // Suffix search list rules prevent resolution of the given host name. NET_ERROR(DNS_SEARCH_EMPTY, -805) + +// Failed to sort addresses according to RFC3484. +NET_ERROR(DNS_SORT_ERROR, -806) diff --git a/net/base/net_util.cc b/net/base/net_util.cc index 56e779a..d39d7ab 100644 --- a/net/base/net_util.cc +++ b/net/base/net_util.cc @@ -2216,6 +2216,12 @@ bool ParseIPLiteralToNumber(const std::string& ip_literal, return family == url_canon::CanonHostInfo::IPV4; } +namespace { + +const unsigned char kIPv4MappedPrefix[] = + { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF }; +} + IPAddressNumber ConvertIPv4NumberToIPv6Number( const IPAddressNumber& ipv4_number) { DCHECK(ipv4_number.size() == 4); @@ -2224,13 +2230,27 @@ IPAddressNumber ConvertIPv4NumberToIPv6Number( // <80 bits of zeros> + <16 bits of ones> + <32-bit IPv4 address>. IPAddressNumber ipv6_number; ipv6_number.reserve(16); - ipv6_number.insert(ipv6_number.end(), 10, 0); - ipv6_number.push_back(0xFF); - ipv6_number.push_back(0xFF); + ipv6_number.insert(ipv6_number.end(), + kIPv4MappedPrefix, + kIPv4MappedPrefix + arraysize(kIPv4MappedPrefix)); ipv6_number.insert(ipv6_number.end(), ipv4_number.begin(), ipv4_number.end()); return ipv6_number; } +bool IsIPv4Mapped(const IPAddressNumber& address) { + if (address.size() != kIPv6AddressSize) + return false; + return std::equal(address.begin(), + address.begin() + arraysize(kIPv4MappedPrefix), + kIPv4MappedPrefix); +} + +IPAddressNumber ConvertIPv4MappedToIPv4(const IPAddressNumber& address) { + DCHECK(IsIPv4Mapped(address)); + return IPAddressNumber(address.begin() + arraysize(kIPv4MappedPrefix), + address.end()); +} + bool ParseCIDRBlock(const std::string& cidr_literal, IPAddressNumber* ip_number, size_t* prefix_length_in_bits) { diff --git a/net/base/net_util.h b/net/base/net_util.h index 69c9a97..fef0a78 100644 --- a/net/base/net_util.h +++ b/net/base/net_util.h @@ -424,6 +424,14 @@ NET_EXPORT_PRIVATE bool ParseIPLiteralToNumber(const std::string& ip_literal, NET_EXPORT_PRIVATE IPAddressNumber ConvertIPv4NumberToIPv6Number( const IPAddressNumber& ipv4_number); +// Returns true iff |address| is an IPv4-mapped IPv6 address. +NET_EXPORT_PRIVATE bool IsIPv4Mapped(const IPAddressNumber& address); + +// Converts an IPv4-mapped IPv6 address to IPv4 address. Should only be called +// on IPv4-mapped IPv6 addresses. +NET_EXPORT_PRIVATE IPAddressNumber ConvertIPv4MappedToIPv4( + const IPAddressNumber& address); + // Parses an IP block specifier from CIDR notation to an // (IP address, prefix length) pair. Returns true on success and fills // |*ip_number| with the numeric value of the IP address and sets diff --git a/net/base/net_util_unittest.cc b/net/base/net_util_unittest.cc index 283ac6c..ffab57e 100644 --- a/net/base/net_util_unittest.cc +++ b/net/base/net_util_unittest.cc @@ -3049,6 +3049,29 @@ TEST(NetUtilTest, ConvertIPv4NumberToIPv6Number) { DumpIPNumber(ipv6_number)); } +TEST(NetUtilTest, IsIPv4Mapped) { + IPAddressNumber ipv4_number; + EXPECT_TRUE(ParseIPLiteralToNumber("192.168.0.1", &ipv4_number)); + EXPECT_FALSE(IsIPv4Mapped(ipv4_number)); + + IPAddressNumber ipv6_number; + EXPECT_TRUE(ParseIPLiteralToNumber("::1", &ipv4_number)); + EXPECT_FALSE(IsIPv4Mapped(ipv6_number)); + + IPAddressNumber ipv4mapped_number; + EXPECT_TRUE(ParseIPLiteralToNumber("::ffff:0101:1", &ipv4mapped_number)); + EXPECT_TRUE(IsIPv4Mapped(ipv4mapped_number)); +} + +TEST(NetUtilTest, ConvertIPv4MappedToIPv4) { + IPAddressNumber ipv4mapped_number; + EXPECT_TRUE(ParseIPLiteralToNumber("::ffff:0101:1", &ipv4mapped_number)); + IPAddressNumber expected; + EXPECT_TRUE(ParseIPLiteralToNumber("1.1.0.1", &expected)); + IPAddressNumber result = ConvertIPv4MappedToIPv4(ipv4mapped_number); + EXPECT_EQ(expected, result); +} + // Test parsing invalid CIDR notation literals. TEST(NetUtilTest, ParseCIDRBlock_Invalid) { const char* bad_literals[] = { diff --git a/net/dns/address_sorter.h b/net/dns/address_sorter.h new file mode 100644 index 0000000..6ac9430 --- /dev/null +++ b/net/dns/address_sorter.h @@ -0,0 +1,46 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_ADDRESS_SORTER_H_ +#define NET_DNS_ADDRESS_SORTER_H_ + +#include "base/basictypes.h" +#include "base/callback.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/net_export.h" + +namespace net { + +class AddressList; + +// Sorts AddressList according to RFC3484, by likelihood of successful +// connection. Depending on the platform, the sort could be performed +// asynchronously by the OS, or synchronously by local implementation. +// AddressSorter does not necessarily preserve port numbers on the sorted list. +class NET_EXPORT AddressSorter { + public: + typedef base::Callback<void(bool success, + const AddressList& list)> CallbackType; + + virtual ~AddressSorter() {} + + // Sorts |list|, which must include at least one IPv6 address. + // Calls |callback| upon completion. Could complete synchronously. Could + // complete after this AddressSorter is destroyed. + virtual void Sort(const AddressList& list, + const CallbackType& callback) const = 0; + + // Creates platform-dependent AddressSorter. + static scoped_ptr<AddressSorter> CreateAddressSorter(); + + protected: + AddressSorter() {} + + private: + DISALLOW_COPY_AND_ASSIGN(AddressSorter); +}; + +} // namespace net + +#endif // NET_DNS_ADDRESS_SORTER_H_ diff --git a/net/dns/address_sorter_posix.cc b/net/dns/address_sorter_posix.cc new file mode 100644 index 0000000..807cbf8 --- /dev/null +++ b/net/dns/address_sorter_posix.cc @@ -0,0 +1,428 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/address_sorter_posix.h" + +#include <netinet/in.h> + +#if defined(OS_MACOSX) || defined(OS_BSD) +#include <sys/socket.h> // Must be included before ifaddrs.h. +#include <ifaddrs.h> +#include <net/if.h> +#include <netinet/in_var.h> +#include <string.h> +#include <sys/ioctl.h> +#endif + +#include <algorithm> + +#include "base/eintr_wrapper.h" +#include "base/logging.h" +#include "base/memory/scoped_vector.h" +#include "net/socket/client_socket_factory.h" +#include "net/udp/datagram_client_socket.h" + +#if defined(OS_LINUX) +#include "net/base/address_tracker_linux.h" +#endif + +namespace net { + +namespace { + +// Address sorting is performed according to RFC3484 with revisions. +// http://tools.ietf.org/html/draft-ietf-6man-rfc3484bis-06 +// Precedence and label are separate to support override through /etc/gai.conf. + +// Returns true if |p1| should precede |p2| in the table. +// Sorts table by decreasing prefix size to allow longest prefix matching. +bool ComparePolicy(const AddressSorterPosix::PolicyEntry& p1, + const AddressSorterPosix::PolicyEntry& p2) { + return p1.prefix_length > p2.prefix_length; +} + +// Creates sorted PolicyTable from |table| with |size| entries. +AddressSorterPosix::PolicyTable LoadPolicy( + AddressSorterPosix::PolicyEntry* table, + size_t size) { + AddressSorterPosix::PolicyTable result(table, table + size); + std::sort(result.begin(), result.end(), ComparePolicy); + return result; +} + +// Search |table| for matching prefix of |address|. |table| must be sorted by +// descending prefix (prefix of another prefix must be later in table). +unsigned GetPolicyValue(const AddressSorterPosix::PolicyTable& table, + const IPAddressNumber& address) { + if (address.size() == kIPv4AddressSize) + return GetPolicyValue(table, ConvertIPv4NumberToIPv6Number(address)); + for (unsigned i = 0; i < table.size(); ++i) { + const AddressSorterPosix::PolicyEntry& entry = table[i]; + IPAddressNumber prefix(entry.prefix, entry.prefix + kIPv6AddressSize); + if (IPNumberMatchesPrefix(address, prefix, entry.prefix_length)) + return entry.value; + } + NOTREACHED(); + // The last entry is the least restrictive, so assume it's default. + return table.back().value; +} + +bool IsIPv6Multicast(const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + return address[0] == 0xFF; +} + +AddressSorterPosix::AddressScope GetIPv6MulticastScope( + const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + return static_cast<AddressSorterPosix::AddressScope>(address[1] & 0x0F); +} + +bool IsIPv6Loopback(const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + // IN6_IS_ADDR_LOOPBACK + unsigned char kLoopback[kIPv6AddressSize] = { + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, + }; + return address == IPAddressNumber(kLoopback, kLoopback + kIPv6AddressSize); +} + +bool IsIPv6LinkLocal(const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + // IN6_IS_ADDR_LINKLOCAL + return (address[0] == 0xFE) && ((address[1] & 0xC0) == 0x80); +} + +bool IsIPv6SiteLocal(const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + // IN6_IS_ADDR_SITELOCAL + return (address[0] == 0xFE) && ((address[1] & 0xC0) == 0xC0); +} + +AddressSorterPosix::AddressScope GetScope( + const AddressSorterPosix::PolicyTable& ipv4_scope_table, + const IPAddressNumber& address) { + if (address.size() == kIPv6AddressSize) { + if (IsIPv6Multicast(address)) { + return GetIPv6MulticastScope(address); + } else if (IsIPv6Loopback(address) || IsIPv6LinkLocal(address)) { + return AddressSorterPosix::SCOPE_LINKLOCAL; + } else if (IsIPv6SiteLocal(address)) { + return AddressSorterPosix::SCOPE_SITELOCAL; + } else { + return AddressSorterPosix::SCOPE_GLOBAL; + } + } else if (address.size() == kIPv4AddressSize) { + return static_cast<AddressSorterPosix::AddressScope>( + GetPolicyValue(ipv4_scope_table, address)); + } else { + NOTREACHED(); + return AddressSorterPosix::SCOPE_NODELOCAL; + } +} + +// Default policy table. RFC 3484, Section 2.1. +AddressSorterPosix::PolicyEntry kDefaultPrecedenceTable[] = { + // ::1/128 -- loopback + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, 128, 50 }, + // ::/0 -- any + { { }, 0, 40 }, + // ::ffff:0:0/96 -- IPv4 mapped + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF }, 96, 35 }, + // 2002::/16 -- 6to4 + { { 0x20, 0x02, }, 16, 30 }, + // 2001::/32 -- Teredo + { { 0x20, 0x01, 0, 0 }, 32, 5 }, + // fc00::/7 -- unique local address + { { 0xFC }, 7, 3 }, + // ::/96 -- IPv4 compatible + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, 96, 1 }, + // fec0::/10 -- site-local expanded scope + { { 0xFE, 0xC0 }, 10, 1 }, + // 3ffe::/16 -- 6bone + { { 0x3F, 0xFE }, 16, 1 }, +}; + +AddressSorterPosix::PolicyEntry kDefaultLabelTable[] = { + // ::1/128 -- loopback + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, 128, 0 }, + // ::/0 -- any + { { }, 0, 1 }, + // ::ffff:0:0/96 -- IPv4 mapped + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF }, 96, 4 }, + // 2002::/16 -- 6to4 + { { 0x20, 0x02, }, 16, 2 }, + // 2001::/32 -- Teredo + { { 0x20, 0x01, 0, 0 }, 32, 5 }, + // fc00::/7 -- unique local address + { { 0xFC }, 7, 13 }, + // ::/96 -- IPv4 compatible + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, 96, 3 }, + // fec0::/10 -- site-local expanded scope + { { 0xFE, 0xC0 }, 10, 11 }, + // 3ffe::/16 -- 6bone + { { 0x3F, 0xFE }, 16, 12 }, +}; + +// Default mapping of IPv4 addresses to scope. +AddressSorterPosix::PolicyEntry kDefaultIPv4ScopeTable[] = { + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0x7F }, 104, + AddressSorterPosix::SCOPE_LINKLOCAL }, + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xA9, 0xFE }, 112, + AddressSorterPosix::SCOPE_LINKLOCAL }, + { { }, 0, AddressSorterPosix::SCOPE_GLOBAL }, +}; + +// Returns number of matching initial bits between the addresses |a1| and |a2|. +unsigned CommonPrefixLength(const IPAddressNumber& a1, + const IPAddressNumber& a2) { + DCHECK_EQ(a1.size(), a2.size()); + for (size_t i = 0; i < a1.size(); ++i) { + unsigned diff = a1[i] ^ a2[i]; + if (!diff) + continue; + for (unsigned j = 0; j < CHAR_BIT; ++j) { + if (diff & (1 << (CHAR_BIT - 1))) + return i * CHAR_BIT + j; + diff <<= 1; + } + NOTREACHED(); + } + return a1.size() * CHAR_BIT; +} + +// Computes the number of leading 1-bits in |mask|. +unsigned MaskPrefixLength(const IPAddressNumber& mask) { + IPAddressNumber all_ones(mask.size(), 0xFF); + return CommonPrefixLength(mask, all_ones); +} + +struct DestinationInfo { + IPAddressNumber address; + AddressSorterPosix::AddressScope scope; + unsigned precedence; + unsigned label; + const AddressSorterPosix::SourceAddressInfo* src; + unsigned common_prefix_length; +}; + +// Returns true iff |dst_a| should precede |dst_b| in the address list. +// RFC 3484, section 6. +bool CompareDestinations(const DestinationInfo* dst_a, + const DestinationInfo* dst_b) { + // Rule 1: Avoid unusable destinations. + // Unusable destinations are already filtered out. + DCHECK(dst_a->src); + DCHECK(dst_b->src); + + // Rule 2: Prefer matching scope. + bool scope_match1 = (dst_a->src->scope == dst_a->scope); + bool scope_match2 = (dst_b->src->scope == dst_b->scope); + if (scope_match1 != scope_match2) + return scope_match1; + + // Rule 3: Avoid deprecated addresses. + if (dst_a->src->deprecated != dst_b->src->deprecated) + return !dst_a->src->deprecated; + + // Rule 4: Prefer home addresses. + if (dst_a->src->home != dst_b->src->home) + return dst_a->src->home; + + // Rule 5: Prefer matching label. + bool label_match1 = (dst_a->src->label == dst_a->label); + bool label_match2 = (dst_b->src->label == dst_b->label); + if (label_match1 != label_match2) + return label_match1; + + // Rule 6: Prefer higher precedence. + if (dst_a->precedence != dst_b->precedence) + return dst_a->precedence > dst_b->precedence; + + // Rule 7: Prefer native transport. + if (dst_a->src->native != dst_b->src->native) + return dst_a->src->native; + + // Rule 8: Prefer smaller scope. + if (dst_a->scope != dst_b->scope) + return dst_a->scope < dst_b->scope; + + // Rule 9: Use longest matching prefix. Only for matching address families. + if (dst_a->address.size() == dst_b->address.size()) { + if (dst_a->common_prefix_length != dst_b->common_prefix_length) + return dst_a->common_prefix_length > dst_b->common_prefix_length; + } + + // Rule 10: Leave the order unchanged. + // stable_sort takes care of that. + return false; +} + +} // namespace + +AddressSorterPosix::AddressSorterPosix(ClientSocketFactory* socket_factory) + : socket_factory_(socket_factory), + precedence_table_(LoadPolicy(kDefaultPrecedenceTable, + arraysize(kDefaultPrecedenceTable))), + label_table_(LoadPolicy(kDefaultLabelTable, + arraysize(kDefaultLabelTable))), + ipv4_scope_table_(LoadPolicy(kDefaultIPv4ScopeTable, + arraysize(kDefaultIPv4ScopeTable))) { + NetworkChangeNotifier::AddIPAddressObserver(this); + OnIPAddressChanged(); +} + +AddressSorterPosix::~AddressSorterPosix() { + NetworkChangeNotifier::RemoveIPAddressObserver(this); +} + +void AddressSorterPosix::Sort(const AddressList& list, + const CallbackType& callback) const { + DCHECK(CalledOnValidThread()); + ScopedVector<DestinationInfo> sort_list; + + for (size_t i = 0; i < list.size(); ++i) { + scoped_ptr<DestinationInfo> info(new DestinationInfo()); + info->address = list[i].address(); + info->scope = GetScope(ipv4_scope_table_, info->address); + info->precedence = GetPolicyValue(precedence_table_, info->address); + info->label = GetPolicyValue(label_table_, info->address); + + // Each socket can only be bound once. + scoped_ptr<DatagramClientSocket> socket( + socket_factory_->CreateDatagramClientSocket( + DatagramSocket::DEFAULT_BIND, + RandIntCallback(), + NULL /* NetLog */, + NetLog::Source())); + + // Even though no packets are sent, cannot use port 0 in Connect. + IPEndPoint dest(info->address, 80 /* port */); + int rv = socket->Connect(dest); + if (rv != OK) { + LOG(WARNING) << "Could not connect to " << dest.ToStringWithoutPort() + << " reason " << rv; + continue; + } + // Filter out unusable destinations. + IPEndPoint src; + rv = socket->GetLocalAddress(&src); + if (rv != OK) { + LOG(WARNING) << "Could not get local address for " + << src.ToStringWithoutPort() << " reason " << rv; + continue; + } + + SourceAddressInfo& src_info = source_map_[src.address()]; + if (src_info.scope == SCOPE_UNDEFINED) { + // If |source_info_| is out of date, |src| might be missing, but we still + // want to sort, even though the HostCache will be cleared soon. + FillPolicy(src.address(), &src_info); + } + info->src = &src_info; + + if (info->address.size() == src.address().size()) { + info->common_prefix_length = std::min( + CommonPrefixLength(info->address, src.address()), + info->src->prefix_length); + } + sort_list.push_back(info.release()); + } + + std::stable_sort(sort_list.begin(), sort_list.end(), CompareDestinations); + + AddressList result; + for (size_t i = 0; i < sort_list.size(); ++i) + result.push_back(IPEndPoint(sort_list[i]->address, 0 /* port */)); + + callback.Run(true, result); +} + +void AddressSorterPosix::OnIPAddressChanged() { + DCHECK(CalledOnValidThread()); + source_map_.clear(); +#if defined(OS_LINUX) + const internal::AddressTrackerLinux* tracker = + NetworkChangeNotifier::GetAddressTracker(); + if (!tracker) + return; + typedef internal::AddressTrackerLinux::AddressMap AddressMap; + AddressMap map = tracker->GetAddressMap(); + for (AddressMap::const_iterator it = map.begin(); it != map.end(); ++it) { + const IPAddressNumber& address = it->first; + const struct ifaddrmsg& msg = it->second; + SourceAddressInfo& info = source_map_[address]; + info.native = false; // TODO(szym): obtain this via netlink. + info.deprecated = msg.ifa_flags & IFA_F_DEPRECATED; + info.home = msg.ifa_flags & IFA_F_HOMEADDRESS; + info.prefix_length = msg.ifa_prefixlen; + FillPolicy(address, &info); + } +#elif defined(OS_MACOSX) || defined(OS_BSD) + // It's not clear we will receive notification when deprecated flag changes. + // Socket for ioctl. + int ioctl_socket = socket(AF_INET6, SOCK_DGRAM, 0); + if (ioctl_socket < 0) + return; + + struct ifaddrs* addrs; + int rv = getifaddrs(&addrs); + if (rv < 0) { + LOG(WARNING) << "getifaddrs failed " << rv; + close(ioctl_socket); + return; + } + + for (struct ifaddrs* ifa = addrs; ifa != NULL; ifa = ifa->ifa_next) { + IPEndPoint src; + if (!src.FromSockAddr(ifa->ifa_addr, ifa->ifa_addr->sa_len)) { + LOG(WARNING) << "FromSockAddr failed"; + continue; + } + SourceAddressInfo& info = source_map_[src.address()]; + // Note: no known way to fill in |native| and |home|. + info.native = info.home = info.deprecated = false; + if (ifa->ifa_addr->sa_family == AF_INET6) { + struct in6_ifreq ifr = {}; + strncpy(ifr.ifr_name, ifa->ifa_name, sizeof(ifr.ifr_name) - 1); + DCHECK_LE(ifa->ifa_addr->sa_len, sizeof(ifr.ifr_ifru.ifru_addr)); + memcpy(&ifr.ifr_ifru.ifru_addr, ifa->ifa_addr, ifa->ifa_addr->sa_len); + int rv = ioctl(ioctl_socket, SIOCGIFAFLAG_IN6, &ifr); + if (rv > 0) { + info.deprecated = ifr.ifr_ifru.ifru_flags & IN6_IFF_DEPRECATED; + } else { + LOG(WARNING) << "SIOCGIFAFLAG_IN6 failed " << rv; + } + } + if (ifa->ifa_netmask) { + IPEndPoint netmask; + if (netmask.FromSockAddr(ifa->ifa_netmask, ifa->ifa_addr->sa_len)) { + info.prefix_length = MaskPrefixLength(netmask.address()); + } else { + LOG(WARNING) << "FromSockAddr failed on netmask"; + } + } + FillPolicy(src.address(), &info); + } + freeifaddrs(addrs); + close(ioctl_socket); +#endif +} + +void AddressSorterPosix::FillPolicy(const IPAddressNumber& address, + SourceAddressInfo* info) const { + DCHECK(CalledOnValidThread()); + info->scope = GetScope(ipv4_scope_table_, address); + info->label = GetPolicyValue(label_table_, address); +} + +// static +scoped_ptr<AddressSorter> AddressSorter::CreateAddressSorter() { + return scoped_ptr<AddressSorter>( + new AddressSorterPosix(ClientSocketFactory::GetDefaultFactory())); +} + +} // namespace net + diff --git a/net/dns/address_sorter_posix.h b/net/dns/address_sorter_posix.h new file mode 100644 index 0000000..1c88ad2 --- /dev/null +++ b/net/dns/address_sorter_posix.h @@ -0,0 +1,94 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_ADDRESS_SORTER_POSIX_H_ +#define NET_DNS_ADDRESS_SORTER_POSIX_H_ + +#include <map> +#include <vector> + +#include "base/threading/non_thread_safe.h" +#include "net/base/address_list.h" +#include "net/base/net_errors.h" +#include "net/base/net_export.h" +#include "net/base/net_util.h" +#include "net/base/network_change_notifier.h" +#include "net/dns/address_sorter.h" + +namespace net { + +class ClientSocketFactory; + +// This implementation uses explicit policy to perform the sorting. It is not +// thread-safe and always completes synchronously. +class NET_EXPORT_PRIVATE AddressSorterPosix + : public AddressSorter, + public base::NonThreadSafe, + public NetworkChangeNotifier::IPAddressObserver { + public: + // Generic policy entry. + struct PolicyEntry { + // IPv4 addresses must be mapped to IPv6. + unsigned char prefix[kIPv6AddressSize]; + unsigned prefix_length; + unsigned value; + }; + + typedef std::vector<PolicyEntry> PolicyTable; + + enum AddressScope { + SCOPE_UNDEFINED = 0, + SCOPE_NODELOCAL = 1, + SCOPE_LINKLOCAL = 2, + SCOPE_SITELOCAL = 5, + SCOPE_ORGLOCAL = 8, + SCOPE_GLOBAL = 14, + }; + + struct SourceAddressInfo { + // Values read from policy tables. + AddressScope scope; + unsigned label; + + // Values from the OS, matter only if more than one source address is used. + unsigned prefix_length; + bool deprecated; // vs. preferred RFC4862 + bool home; // vs. care-of RFC6275 + bool native; + }; + + typedef std::map<IPAddressNumber, SourceAddressInfo> SourceAddressMap; + + explicit AddressSorterPosix(ClientSocketFactory* socket_factory); + virtual ~AddressSorterPosix(); + + // AddressSorter: + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE; + + private: + friend class AddressSorterPosixTest; + + // NetworkChangeNotifier::IPAddressObserver: + virtual void OnIPAddressChanged() OVERRIDE; + + // Fills |info| with values for |address| from policy tables. + void FillPolicy(const IPAddressNumber& address, + SourceAddressInfo* info) const; + + // Mutable to allow using default values for source addresses which were not + // found in most recent OnIPAddressChanged. + mutable SourceAddressMap source_map_; + + ClientSocketFactory* socket_factory_; + PolicyTable precedence_table_; + PolicyTable label_table_; + PolicyTable ipv4_scope_table_; + + DISALLOW_COPY_AND_ASSIGN(AddressSorterPosix); +}; + +} // namespace net + +#endif // NET_DNS_ADDRESS_SORTER_POSIX_H_ diff --git a/net/dns/address_sorter_posix_unittest.cc b/net/dns/address_sorter_posix_unittest.cc new file mode 100644 index 0000000..96cbfc6 --- /dev/null +++ b/net/dns/address_sorter_posix_unittest.cc @@ -0,0 +1,325 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/address_sorter_posix.h" + +#include "base/bind.h" +#include "base/logging.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/client_socket_factory.h" +#include "net/udp/datagram_client_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { +namespace { + +// Used to map destination address to source address. +typedef std::map<IPAddressNumber, IPAddressNumber> AddressMapping; + +IPAddressNumber ParseIP(const std::string& str) { + IPAddressNumber addr; + CHECK(ParseIPLiteralToNumber(str, &addr)); + return addr; +} + +// A mock socket which binds to source address according to AddressMapping. +class TestUDPClientSocket : public DatagramClientSocket { + public: + explicit TestUDPClientSocket(const AddressMapping* mapping) + : mapping_(mapping), connected_(false) {} + + virtual ~TestUDPClientSocket() {} + + virtual int Read(IOBuffer*, int, const CompletionCallback&) OVERRIDE { + NOTIMPLEMENTED(); + return OK; + } + virtual int Write(IOBuffer*, int, const CompletionCallback&) OVERRIDE { + NOTIMPLEMENTED(); + return OK; + } + virtual bool SetReceiveBufferSize(int32) OVERRIDE { + return true; + } + virtual bool SetSendBufferSize(int32) OVERRIDE { + return true; + } + + virtual void Close() OVERRIDE {} + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + NOTIMPLEMENTED(); + return OK; + } + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + if (!connected_) + return ERR_UNEXPECTED; + *address = local_endpoint_; + return OK; + } + + virtual int Connect(const IPEndPoint& remote) OVERRIDE { + if (connected_) + return ERR_UNEXPECTED; + AddressMapping::const_iterator it = mapping_->find(remote.address()); + if (it == mapping_->end()) + return ERR_FAILED; + connected_ = true; + local_endpoint_ = IPEndPoint(it->second, 39874 /* arbitrary port */); + return OK; + } + + virtual const BoundNetLog& NetLog() const OVERRIDE { + return net_log_; + } + + private: + BoundNetLog net_log_; + const AddressMapping* mapping_; + bool connected_; + IPEndPoint local_endpoint_; + + DISALLOW_COPY_AND_ASSIGN(TestUDPClientSocket); +}; + +// Creates TestUDPClientSockets and maintains an AddressMapping. +class TestSocketFactory : public ClientSocketFactory { + public: + TestSocketFactory() {} + virtual ~TestSocketFactory() {} + + virtual DatagramClientSocket* CreateDatagramClientSocket( + DatagramSocket::BindType, + const RandIntCallback&, + NetLog*, + const NetLog::Source&) OVERRIDE { + return new TestUDPClientSocket(&mapping_); + } + virtual StreamSocket* CreateTransportClientSocket( + const AddressList&, + NetLog*, + const NetLog::Source&) OVERRIDE { + NOTIMPLEMENTED(); + return NULL; + } + virtual SSLClientSocket* CreateSSLClientSocket( + ClientSocketHandle*, + const HostPortPair&, + const SSLConfig&, + const SSLClientSocketContext&) OVERRIDE { + NOTIMPLEMENTED(); + return NULL; + } + virtual void ClearSSLSessionCache() OVERRIDE { + NOTIMPLEMENTED(); + } + + void AddMapping(const IPAddressNumber& dst, const IPAddressNumber& src) { + mapping_[dst] = src; + } + + private: + AddressMapping mapping_; + + DISALLOW_COPY_AND_ASSIGN(TestSocketFactory); +}; + +void OnSortComplete(AddressList* result_buf, + const CompletionCallback& callback, + bool success, + const AddressList& result) { + EXPECT_TRUE(success); + if (success) + *result_buf = result; + callback.Run(OK); +} + +} // namespace + +class AddressSorterPosixTest : public testing::Test { + protected: + AddressSorterPosixTest() : sorter_(&socket_factory_) {} + + void AddMapping(const std::string& dst, const std::string& src) { + socket_factory_.AddMapping(ParseIP(dst), ParseIP(src)); + } + + AddressSorterPosix::SourceAddressInfo* GetSourceInfo( + const std::string& addr) { + IPAddressNumber address = ParseIP(addr); + AddressSorterPosix::SourceAddressInfo* info = &sorter_.source_map_[address]; + if (info->scope == AddressSorterPosix::SCOPE_UNDEFINED) + sorter_.FillPolicy(address, info); + return info; + } + + // Verify that NULL-terminated |addresses| matches (-1)-terminated |order| + // after sorting. + void Verify(const char* addresses[], const int order[]) { + AddressList list; + for (const char** addr = addresses; *addr != NULL; ++addr) + list.push_back(IPEndPoint(ParseIP(*addr), 80)); + for (size_t i = 0; order[i] >= 0; ++i) + CHECK_LT(order[i], static_cast<int>(list.size())); + + AddressList result; + TestCompletionCallback callback; + sorter_.Sort(list, base::Bind(&OnSortComplete, &result, + callback.callback())); + callback.WaitForResult(); + + for (size_t i = 0; (i < result.size()) || (order[i] >= 0); ++i) { + IPEndPoint expected = order[i] >= 0 ? list[order[i]] : IPEndPoint(); + IPEndPoint actual = i < result.size() ? result[i] : IPEndPoint(); + EXPECT_TRUE(expected.address() == actual.address()) << + "Address out of order at position " << i << "\n" << + " Actual: " << actual.ToStringWithoutPort() << "\n" << + "Expected: " << expected.ToStringWithoutPort(); + } + } + + TestSocketFactory socket_factory_; + AddressSorterPosix sorter_; +}; + +// Rule 1: Avoid unusable destinations. +TEST_F(AddressSorterPosixTest, Rule1) { + AddMapping("10.0.0.231", "10.0.0.1"); + const char* addresses[] = { "::1", "10.0.0.231", "127.0.0.1", NULL }; + const int order[] = { 1, -1 }; + Verify(addresses, order); +} + +// Rule 2: Prefer matching scope. +TEST_F(AddressSorterPosixTest, Rule2) { + AddMapping("3002::1", "4000::10"); // matching global + AddMapping("ff32::1", "fe81::10"); // matching link-local + AddMapping("fec1::1", "fec1::10"); // matching node-local + AddMapping("3002::2", "::1"); // global vs. link-local + AddMapping("fec1::2", "fe81::10"); // site-local vs. link-local + AddMapping("8.0.0.1", "169.254.0.10"); // global vs. link-local + // In all three cases, matching scope is preferred. + const int order[] = { 1, 0, -1 }; + const char* addresses1[] = { "3002::2", "3002::1", NULL }; + Verify(addresses1, order); + const char* addresses2[] = { "fec1::2", "ff32::1", NULL }; + Verify(addresses2, order); + const char* addresses3[] = { "8.0.0.1", "fec1::1", NULL }; + Verify(addresses3, order); +} + +// Rule 3: Avoid deprecated addresses. +TEST_F(AddressSorterPosixTest, Rule3) { + // Matching scope. + AddMapping("3002::1", "4000::10"); + GetSourceInfo("4000::10")->deprecated = true; + AddMapping("3002::2", "4000::20"); + const char* addresses[] = { "3002::1", "3002::2", NULL }; + const int order[] = { 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 4: Prefer home addresses. +TEST_F(AddressSorterPosixTest, Rule4) { + AddMapping("3002::1", "4000::10"); + AddMapping("3002::2", "4000::20"); + GetSourceInfo("4000::20")->home = true; + const char* addresses[] = { "3002::1", "3002::2", NULL }; + const int order[] = { 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 5: Prefer matching label. +TEST_F(AddressSorterPosixTest, Rule5) { + AddMapping("::1", "::1"); // matching loopback + AddMapping("::ffff:1234:1", "::ffff:1234:10"); // matching IPv4-mapped + AddMapping("2001::1", "::ffff:1234:10"); // Teredo vs. IPv4-mapped + AddMapping("2002::1", "2001::10"); // 6to4 vs. Teredo + const int order[] = { 1, 0, -1 }; + { + const char* addresses[] = { "2001::1", "::1", NULL }; + Verify(addresses, order); + } + { + const char* addresses[] = { "2002::1", "::ffff:1234:1", NULL }; + Verify(addresses, order); + } +} + +// Rule 6: Prefer higher precedence. +TEST_F(AddressSorterPosixTest, Rule6) { + AddMapping("::1", "::1"); // loopback + AddMapping("ff32::1", "fe81::10"); // multicast + AddMapping("::ffff:1234:1", "::ffff:1234:10"); // IPv4-mapped + AddMapping("2001::1", "2001::10"); // Teredo + const char* addresses[] = { "2001::1", "::ffff:1234:1", "ff32::1", "::1", + NULL }; + const int order[] = { 3, 2, 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 7: Prefer native transport. +TEST_F(AddressSorterPosixTest, Rule7) { + AddMapping("3002::1", "4000::10"); + AddMapping("3002::2", "4000::20"); + GetSourceInfo("4000::20")->native = true; + const char* addresses[] = { "3002::1", "3002::2", NULL }; + const int order[] = { 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 8: Prefer smaller scope. +TEST_F(AddressSorterPosixTest, Rule8) { + // Matching scope. Should precede the others by Rule 2. + AddMapping("fe81::1", "fe81::10"); // link-local + AddMapping("3000::1", "4000::10"); // global + // Mismatched scope. + AddMapping("ff32::1", "4000::10"); // link-local + AddMapping("ff35::1", "4000::10"); // site-local + AddMapping("ff38::1", "4000::10"); // org-local + const char* addresses[] = { "ff38::1", "3000::1", "ff35::1", "ff32::1", + "fe81::1", NULL }; + const int order[] = { 4, 1, 3, 2, 0, -1 }; + Verify(addresses, order); +} + +// Rule 9: Use longest matching prefix. +TEST_F(AddressSorterPosixTest, Rule9) { + AddMapping("3000::1", "3000:ffff::10"); // 16 bit match + GetSourceInfo("3000:ffff::10")->prefix_length = 16; + AddMapping("4000::1", "4000::10"); // 123 bit match, limited to 15 + GetSourceInfo("4000::10")->prefix_length = 15; + AddMapping("4002::1", "4000::10"); // 14 bit match + AddMapping("4080::1", "4000::10"); // 8 bit match + const char* addresses[] = { "4080::1", "4002::1", "4000::1", "3000::1", + NULL }; + const int order[] = { 3, 2, 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 10: Leave the order unchanged. +TEST_F(AddressSorterPosixTest, Rule10) { + AddMapping("4000::1", "4000::10"); + AddMapping("4000::2", "4000::10"); + AddMapping("4000::3", "4000::10"); + const char* addresses[] = { "4000::1", "4000::2", "4000::3", NULL }; + const int order[] = { 0, 1, 2, -1 }; + Verify(addresses, order); +} + +TEST_F(AddressSorterPosixTest, MultipleRules) { + AddMapping("::1", "::1"); // loopback + AddMapping("ff32::1", "fe81::10"); // link-local multicast + AddMapping("ff3e::1", "4000::10"); // global multicast + AddMapping("4000::1", "4000::10"); // global unicast + AddMapping("ff32::2", "fe81::20"); // deprecated link-local multicast + GetSourceInfo("fe81::20")->deprecated = true; + const char* addresses[] = { "ff3e::1", "ff32::2", "4000::1", "ff32::1", "::1", + "8.0.0.1", NULL }; + const int order[] = { 4, 3, 0, 2, 1, -1 }; + Verify(addresses, order); +} + +} // namespace net diff --git a/net/dns/address_sorter_unittest.cc b/net/dns/address_sorter_unittest.cc new file mode 100644 index 0000000..93c841f --- /dev/null +++ b/net/dns/address_sorter_unittest.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/address_sorter.h" + +#include "base/bind.h" +#include "base/logging.h" +#include "net/base/address_list.h" +#include "net/base/net_util.h" +#include "net/base/test_completion_callback.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { +namespace { + +IPEndPoint MakeEndPoint(const std::string& str) { + IPAddressNumber addr; + CHECK(ParseIPLiteralToNumber(str, &addr)); + return IPEndPoint(addr, 0); +} + +void OnSortComplete(AddressList* result_buf, + const CompletionCallback& callback, + bool success, + const AddressList& result) { + EXPECT_TRUE(success); + if (success) + *result_buf = result; + callback.Run(OK); +} + +TEST(AddressSorterTest, Sort) { + scoped_ptr<AddressSorter> sorter(AddressSorter::CreateAddressSorter()); + AddressList list; + list.push_back(MakeEndPoint("10.0.0.1")); + list.push_back(MakeEndPoint("8.8.8.8")); + list.push_back(MakeEndPoint("::1")); + list.push_back(MakeEndPoint("2001:4860:4860::8888")); + + AddressList result; + TestCompletionCallback callback; + sorter->Sort(list, base::Bind(&OnSortComplete, &result, + callback.callback())); + callback.WaitForResult(); +} + +} // namespace +} // namespace net diff --git a/net/dns/address_sorter_win.cc b/net/dns/address_sorter_win.cc new file mode 100644 index 0000000..116834a --- /dev/null +++ b/net/dns/address_sorter_win.cc @@ -0,0 +1,197 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/address_sorter.h" + +#include <winsock2.h> + +#include <algorithm> + +#include "base/bind.h" +#include "base/location.h" +#include "base/logging.h" +#include "base/threading/worker_pool.h" +#include "base/win/windows_version.h" +#include "net/base/address_list.h" +#include "net/base/ip_endpoint.h" +#include "net/base/winsock_init.h" + +namespace net { + +namespace { + +class AddressSorterWin : public AddressSorter { + public: + AddressSorterWin() { + EnsureWinsockInit(); + } + + virtual ~AddressSorterWin() {} + + // AddressSorter: + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE { + DCHECK(!list.empty()); + scoped_refptr<Job> job = new Job(list, callback); + } + + private: + // Executes the SIO_ADDRESS_LIST_SORT ioctl on the WorkerPool, and + // performs the necessary conversions to/from AddressList. + class Job : public base::RefCountedThreadSafe<Job> { + public: + Job(const AddressList& list, const CallbackType& callback) + : callback_(callback), + buffer_size_(sizeof(SOCKET_ADDRESS_LIST) + + list.size() * (sizeof(SOCKET_ADDRESS) + + sizeof(SOCKADDR_STORAGE))), + input_buffer_(reinterpret_cast<SOCKET_ADDRESS_LIST*>( + malloc(buffer_size_))), + output_buffer_(reinterpret_cast<SOCKET_ADDRESS_LIST*>( + malloc(buffer_size_))), + success_(false) { + input_buffer_->iAddressCount = list.size(); + SOCKADDR_STORAGE* storage = reinterpret_cast<SOCKADDR_STORAGE*>( + input_buffer_->Address + input_buffer_->iAddressCount); + + for (size_t i = 0; i < list.size(); ++i) { + IPEndPoint ipe = list[i]; + // Addresses must be sockaddr_in6. + if (ipe.GetFamily() == AF_INET) { + ipe = IPEndPoint(ConvertIPv4NumberToIPv6Number(ipe.address()), + ipe.port()); + } + + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(storage + i); + socklen_t addr_len = sizeof(SOCKADDR_STORAGE); + bool result = ipe.ToSockAddr(addr, &addr_len); + DCHECK(result); + input_buffer_->Address[i].lpSockaddr = addr; + input_buffer_->Address[i].iSockaddrLength = addr_len; + } + + if (!base::WorkerPool::PostTaskAndReply( + FROM_HERE, + base::Bind(&Job::Run, this), + base::Bind(&Job::OnComplete, this), + false /* task is slow */)) { + LOG(ERROR) << "WorkerPool::PostTaskAndReply failed"; + OnComplete(); + } + } + + private: + friend class base::RefCountedThreadSafe<Job>; + ~Job() {} + + // Executed on the WorkerPool. + void Run() { + SOCKET sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP); + DCHECK_NE(INVALID_SOCKET, sock); + DWORD result_size = 0; + int result = WSAIoctl(sock, SIO_ADDRESS_LIST_SORT, input_buffer_.get(), + buffer_size_, output_buffer_.get(), buffer_size_, + &result_size, NULL, NULL); + if (result == SOCKET_ERROR) { + LOG(ERROR) << "SIO_ADDRESS_LIST_SORT failed " << WSAGetLastError(); + } else { + success_ = true; + } + closesocket(sock); + } + + // Executed on the calling thread. + void OnComplete() { + AddressList list; + if (success_) { + list.reserve(output_buffer_->iAddressCount); + for (int i = 0; i < output_buffer_->iAddressCount; ++i) { + IPEndPoint ipe; + ipe.FromSockAddr(output_buffer_->Address[i].lpSockaddr, + output_buffer_->Address[i].iSockaddrLength); + // Unmap V4MAPPED IPv6 addresses so that Happy Eyeballs works. + if (IsIPv4Mapped(ipe.address())) { + ipe = IPEndPoint(ConvertIPv4MappedToIPv4(ipe.address()), + ipe.port()); + } + list.push_back(ipe); + } + } + callback_.Run(success_, list); + } + + const CallbackType callback_; + const size_t buffer_size_; + scoped_ptr_malloc<SOCKET_ADDRESS_LIST> input_buffer_; + scoped_ptr_malloc<SOCKET_ADDRESS_LIST> output_buffer_; + bool success_; + + DISALLOW_COPY_AND_ASSIGN(Job); + }; + + DISALLOW_COPY_AND_ASSIGN(AddressSorterWin); +}; + +// Merges |list_ipv4| and |list_ipv6| before passing it to |callback|, but +// only if |success| is true. +void MergeResults(const AddressSorter::CallbackType& callback, + const AddressList& list_ipv4, + bool success, + const AddressList& list_ipv6) { + if (!success) { + callback.Run(false, AddressList()); + return; + } + AddressList list; + list.insert(list.end(), list_ipv6.begin(), list_ipv6.end()); + list.insert(list.end(), list_ipv4.begin(), list_ipv4.end()); + callback.Run(true, list); +} + +// Wrapper for AddressSorterWin which does not sort IPv4 or IPv4-mapped +// addresses but always puts them at the end of the list. Needed because the +// SIO_ADDRESS_LIST_SORT does not support IPv4 addresses on Windows XP. +class AddressSorterWinXP : public AddressSorter { + public: + AddressSorterWinXP() {} + virtual ~AddressSorterWinXP() {} + + // AddressSorter: + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE { + AddressList list_ipv4; + AddressList list_ipv6; + for (size_t i = 0; i < list.size(); ++i) { + const IPEndPoint& ipe = list[i]; + if (ipe.GetFamily() == AF_INET) { + list_ipv4.push_back(ipe); + } else { + list_ipv6.push_back(ipe); + } + } + if (!list_ipv6.empty()) { + sorter_.Sort(list_ipv6, base::Bind(&MergeResults, callback, list_ipv4)); + } else { + NOTREACHED() << "Should not be called with IPv4-only addresses."; + callback.Run(true, list); + } + } + + private: + AddressSorterWin sorter_; + + DISALLOW_COPY_AND_ASSIGN(AddressSorterWinXP); +}; + +} // namespace + +// static +scoped_ptr<AddressSorter> AddressSorter::CreateAddressSorter() { + if (base::win::GetVersion() < base::win::VERSION_VISTA) + return scoped_ptr<AddressSorter>(new AddressSorterWinXP()); + return scoped_ptr<AddressSorter>(new AddressSorterWin()); +} + +} // namespace net + diff --git a/net/dns/dns_client.cc b/net/dns/dns_client.cc index 5381452..859fc27 100644 --- a/net/dns/dns_client.cc +++ b/net/dns/dns_client.cc @@ -7,6 +7,7 @@ #include "base/bind.h" #include "base/rand_util.h" #include "net/base/net_log.h" +#include "net/dns/address_sorter.h" #include "net/dns/dns_config_service.h" #include "net/dns/dns_session.h" #include "net/dns/dns_transaction.h" @@ -18,7 +19,9 @@ namespace { class DnsClientImpl : public DnsClient { public: - explicit DnsClientImpl(NetLog* net_log) : net_log_(net_log) {} + explicit DnsClientImpl(NetLog* net_log) + : address_sorter_(AddressSorter::CreateAddressSorter()), + net_log_(net_log) {} virtual void SetConfig(const DnsConfig& config) OVERRIDE { factory_.reset(); @@ -40,9 +43,14 @@ class DnsClientImpl : public DnsClient { return session_.get() ? factory_.get() : NULL; } + virtual AddressSorter* GetAddressSorter() OVERRIDE { + return address_sorter_.get(); + } + private: scoped_refptr<DnsSession> session_; scoped_ptr<DnsTransactionFactory> factory_; + scoped_ptr<AddressSorter> address_sorter_; NetLog* net_log_; }; diff --git a/net/dns/dns_client.h b/net/dns/dns_client.h index 13aa0bf8..650c7d0 100644 --- a/net/dns/dns_client.h +++ b/net/dns/dns_client.h @@ -10,12 +10,14 @@ namespace net { +class AddressSorter; struct DnsConfig; class DnsTransactionFactory; class NetLog; -// Convenience wrapper allows easy injection of DnsTransaction into -// HostResolverImpl. +// Convenience wrapper which allows easy injection of DnsTransaction into +// HostResolverImpl. Pointers returned by the Get* methods are only guaranteed +// to remain valid until next time SetConfig is called. class NET_EXPORT DnsClient { public: virtual ~DnsClient() {} @@ -29,6 +31,9 @@ class NET_EXPORT DnsClient { // Returns NULL if the current config is not valid. virtual DnsTransactionFactory* GetTransactionFactory() = 0; + // Returns NULL if the current config is not valid. + virtual AddressSorter* GetAddressSorter() = 0; + // Creates default client. static scoped_ptr<DnsClient> CreateClient(NetLog* net_log); }; diff --git a/net/dns/dns_response.cc b/net/dns/dns_response.cc index 4ad6465..4ab7296 100644 --- a/net/dns/dns_response.cc +++ b/net/dns/dns_response.cc @@ -280,15 +280,13 @@ DnsResponse::Result DnsResponse::ParseToAddressList( } // TODO(szym): Extract TTL for NODATA results. http://crbug.com/115051 - if (ip_addresses.empty()) - return DNS_NO_ADDRESSES; // getcanonname in eglibc returns the first owner name of an A or AAAA RR. // If the response passed all the checks so far, then |expected_name| is it. *addr_list = AddressList::CreateFromIPAddressList(ip_addresses, expected_name); *ttl = base::TimeDelta::FromSeconds(std::min(cname_ttl_sec, addr_ttl_sec)); - return DNS_SUCCESS; + return DNS_PARSE_OK; } } // namespace net diff --git a/net/dns/dns_response.h b/net/dns/dns_response.h index 5ade458..fe80c30 100644 --- a/net/dns/dns_response.h +++ b/net/dns/dns_response.h @@ -81,7 +81,7 @@ class NET_EXPORT_PRIVATE DnsResponse { public: // Possible results from ParseToAddressList. enum Result { - DNS_SUCCESS = 0, + DNS_PARSE_OK = 0, DNS_MALFORMED_RESPONSE, // DnsRecordParser failed before the end of // packet. DNS_MALFORMED_CNAME, // Could not parse CNAME out of RRDATA. @@ -90,7 +90,7 @@ class NET_EXPORT_PRIVATE DnsResponse { DNS_SIZE_MISMATCH, // Got an address but size does not match. DNS_CNAME_AFTER_ADDRESS, // Found CNAME after an address record. DNS_ADDRESS_TTL_MISMATCH, // TTL of all address records are not identical. - DNS_NO_ADDRESSES, // No address records found. + DNS_NO_ADDRESSES, // OBSOLETE. No longer used. // Only add new values here. DNS_PARSE_RESULT_MAX, // Bounding value for histograms. }; diff --git a/net/dns/dns_response_unittest.cc b/net/dns/dns_response_unittest.cc index bb053bc..6ed6f06 100644 --- a/net/dns/dns_response_unittest.cc +++ b/net/dns/dns_response_unittest.cc @@ -299,7 +299,7 @@ TEST(DnsResponseTest, ParseToAddressList) { DnsResponse response(t.response_data, t.response_size, t.query_size); AddressList addr_list; base::TimeDelta ttl; - EXPECT_EQ(DnsResponse::DNS_SUCCESS, + EXPECT_EQ(DnsResponse::DNS_PARSE_OK, response.ParseToAddressList(&addr_list, &ttl)); std::vector<const char*> expected_addresses( t.expected_addresses, @@ -429,8 +429,9 @@ TEST(DnsResponseTest, ParseToAddressListFail) { DnsResponse::DNS_CNAME_AFTER_ADDRESS }, { kResponseTTLMismatch, arraysize(kResponseTTLMismatch), DnsResponse::DNS_ADDRESS_TTL_MISMATCH }, + // Not actually a failure, just an empty result. { kResponseNoAddresses, arraysize(kResponseNoAddresses), - DnsResponse::DNS_NO_ADDRESSES }, + DnsResponse::DNS_PARSE_OK }, }; const size_t kQuerySize = 12 + 7; diff --git a/net/dns/dns_test_util.cc b/net/dns/dns_test_util.cc index 051f595..73d208c 100644 --- a/net/dns/dns_test_util.cc +++ b/net/dns/dns_test_util.cc @@ -14,6 +14,7 @@ #include "net/base/dns_util.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" +#include "net/dns/address_sorter.h" #include "net/dns/dns_client.h" #include "net/dns/dns_config_service.h" #include "net/dns/dns_protocol.h" @@ -25,19 +26,29 @@ namespace net { namespace { -// A DnsTransaction which responds with loopback to all queries starting with -// "ok", fails synchronously on all queries starting with "er", and NXDOMAIN to -// all others. +// A DnsTransaction which uses MockDnsClientRuleList to determine the response. class MockTransaction : public DnsTransaction, public base::SupportsWeakPtr<MockTransaction> { public: - MockTransaction(const std::string& hostname, + MockTransaction(const MockDnsClientRuleList& rules, + const std::string& hostname, uint16 qtype, const DnsTransactionFactory::CallbackType& callback) - : hostname_(hostname), + : result_(MockDnsClientRule::FAIL_SYNC), + hostname_(hostname), qtype_(qtype), callback_(callback), started_(false) { + // Find the relevant rule which matches |qtype| and prefix of |hostname|. + for (size_t i = 0; i < rules.size(); ++i) { + const std::string& prefix = rules[i].prefix; + if ((rules[i].qtype == qtype) && + (hostname.size() >= prefix.size()) && + (hostname.compare(0, prefix.size(), prefix) == 0)) { + result_ = rules[i].result; + break; + } + } } virtual const std::string& GetHostname() const OVERRIDE { @@ -51,7 +62,7 @@ class MockTransaction : public DnsTransaction, virtual int Start() OVERRIDE { EXPECT_FALSE(started_); started_ = true; - if (hostname_.substr(0, 2) == "er") + if (MockDnsClientRule::FAIL_SYNC == result_) return ERR_NAME_NOT_RESOLVED; // Using WeakPtr to cleanly cancel when transaction is destroyed. MessageLoop::current()->PostTask( @@ -62,54 +73,66 @@ class MockTransaction : public DnsTransaction, private: void Finish() { - if (hostname_.substr(0, 2) == "ok") { - std::string qname; - DNSDomainFromDot(hostname_, &qname); - DnsQuery query(0, qname, qtype_); - - DnsResponse response; - char* buffer = response.io_buffer()->data(); - int nbytes = query.io_buffer()->size(); - memcpy(buffer, query.io_buffer()->data(), nbytes); - - const uint16 kPointerToQueryName = - static_cast<uint16>(0xc000 | sizeof(net::dns_protocol::Header)); - - const uint32 kTTL = 86400; // One day. - - // Size of RDATA which is a IPv4 or IPv6 address. - size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ? - net::kIPv4AddressSize : net::kIPv6AddressSize; - - // 12 is the sum of sizes of the compressed name reference, TYPE, - // CLASS, TTL and RDLENGTH. - size_t answer_size = 12 + rdata_size; - - // Write answer with loopback IP address. - reinterpret_cast<dns_protocol::Header*>(buffer)->ancount = - base::HostToNet16(1); - BigEndianWriter writer(buffer + nbytes, answer_size); - writer.WriteU16(kPointerToQueryName); - writer.WriteU16(qtype_); - writer.WriteU16(net::dns_protocol::kClassIN); - writer.WriteU32(kTTL); - writer.WriteU16(rdata_size); - if (qtype_ == net::dns_protocol::kTypeA) { - char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; - writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); - } else { - char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1 }; - writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); - } - - EXPECT_TRUE(response.InitParse(nbytes + answer_size, query)); - callback_.Run(this, OK, &response); - } else { - callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); + switch (result_) { + case MockDnsClientRule::EMPTY: + case MockDnsClientRule::OK: { + std::string qname; + DNSDomainFromDot(hostname_, &qname); + DnsQuery query(0, qname, qtype_); + + DnsResponse response; + char* buffer = response.io_buffer()->data(); + int nbytes = query.io_buffer()->size(); + memcpy(buffer, query.io_buffer()->data(), nbytes); + dns_protocol::Header* header = + reinterpret_cast<dns_protocol::Header*>(buffer); + header->flags |= dns_protocol::kFlagResponse; + + if (MockDnsClientRule::OK == result_) { + const uint16 kPointerToQueryName = + static_cast<uint16>(0xc000 | sizeof(*header)); + + const uint32 kTTL = 86400; // One day. + + // Size of RDATA which is a IPv4 or IPv6 address. + size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ? + net::kIPv4AddressSize : net::kIPv6AddressSize; + + // 12 is the sum of sizes of the compressed name reference, TYPE, + // CLASS, TTL and RDLENGTH. + size_t answer_size = 12 + rdata_size; + + // Write answer with loopback IP address. + header->ancount = base::HostToNet16(1); + BigEndianWriter writer(buffer + nbytes, answer_size); + writer.WriteU16(kPointerToQueryName); + writer.WriteU16(qtype_); + writer.WriteU16(net::dns_protocol::kClassIN); + writer.WriteU32(kTTL); + writer.WriteU16(rdata_size); + if (qtype_ == net::dns_protocol::kTypeA) { + char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; + writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); + } else { + char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1 }; + writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); + } + nbytes += answer_size; + } + EXPECT_TRUE(response.InitParse(nbytes, query)); + callback_.Run(this, OK, &response); + } break; + case MockDnsClientRule::FAIL_ASYNC: + callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); + break; + default: + NOTREACHED(); + break; } } + MockDnsClientRule::Result result_; const std::string hostname_; const uint16 qtype_; DnsTransactionFactory::CallbackType callback_; @@ -120,7 +143,8 @@ class MockTransaction : public DnsTransaction, // A DnsTransactionFactory which creates MockTransaction. class MockTransactionFactory : public DnsTransactionFactory { public: - MockTransactionFactory() {} + explicit MockTransactionFactory(const MockDnsClientRuleList& rules) + : rules_(rules) {} virtual ~MockTransactionFactory() {} virtual scoped_ptr<DnsTransaction> CreateTransaction( @@ -129,14 +153,29 @@ class MockTransactionFactory : public DnsTransactionFactory { const DnsTransactionFactory::CallbackType& callback, const BoundNetLog&) OVERRIDE { return scoped_ptr<DnsTransaction>( - new MockTransaction(hostname, qtype, callback)); + new MockTransaction(rules_, hostname, qtype, callback)); + } + + private: + MockDnsClientRuleList rules_; +}; + +class MockAddressSorter : public AddressSorter { + public: + virtual ~MockAddressSorter() {} + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE { + // Do nothing. + callback.Run(true, list); } }; // MockDnsClient provides MockTransactionFactory. class MockDnsClient : public DnsClient { public: - explicit MockDnsClient(const DnsConfig& config) : config_(config) {} + MockDnsClient(const DnsConfig& config, + const MockDnsClientRuleList& rules) + : config_(config), factory_(rules) {} virtual ~MockDnsClient() {} virtual void SetConfig(const DnsConfig& config) OVERRIDE { @@ -151,16 +190,22 @@ class MockDnsClient : public DnsClient { return config_.IsValid() ? &factory_ : NULL; } + virtual AddressSorter* GetAddressSorter() OVERRIDE { + return &address_sorter_; + } + private: DnsConfig config_; MockTransactionFactory factory_; + MockAddressSorter address_sorter_; }; } // namespace // static -scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config) { - return scoped_ptr<DnsClient>(new MockDnsClient(config)); +scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config, + const MockDnsClientRuleList& rules) { + return scoped_ptr<DnsClient>(new MockDnsClient(config, rules)); } MockDnsConfigService::~MockDnsConfigService() { diff --git a/net/dns/dns_test_util.h b/net/dns/dns_test_util.h index 53a0620..2d03d0e 100644 --- a/net/dns/dns_test_util.h +++ b/net/dns/dns_test_util.h @@ -5,6 +5,9 @@ #ifndef NET_DNS_DNS_TEST_UTIL_H_ #define NET_DNS_DNS_TEST_UTIL_H_ +#include <string> +#include <vector> + #include "base/basictypes.h" #include "base/memory/scoped_ptr.h" #include "net/dns/dns_config_service.h" @@ -166,8 +169,25 @@ static const int kT3TTL = 0x00000015; static const unsigned kT3RecordCount = arraysize(kT3IpAddresses) + 2; class DnsClient; + +struct MockDnsClientRule { + enum Result { + FAIL_SYNC, // Fail synchronously with ERR_NAME_NOT_RESOLVED. + FAIL_ASYNC, // Fail asynchronously with ERR_NAME_NOT_RESOLVED. + EMPTY, // Return an empty response. + OK, // Return a response with loopback address. + }; + + std::string prefix; + uint16 qtype; + Result result; +}; + +typedef std::vector<MockDnsClientRule> MockDnsClientRuleList; + // Creates mock DnsClient for testing HostResolverImpl. -scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config); +scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config, + const MockDnsClientRuleList& rules); class MockDnsConfigService : public DnsConfigService { public: diff --git a/net/net.gyp b/net/net.gyp index 6adaf04..a75369d 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -364,6 +364,10 @@ 'disk_cache/stress_support.h', 'disk_cache/trace.cc', 'disk_cache/trace.h', + 'dns/address_sorter.h', + 'dns/address_sorter_posix.cc', + 'dns/address_sorter_posix.h', + 'dns/address_sorter_win.cc', 'dns/dns_client.cc', 'dns/dns_client.h', 'dns/dns_config_service.cc', @@ -1195,6 +1199,8 @@ 'disk_cache/entry_unittest.cc', 'disk_cache/mapped_file_unittest.cc', 'disk_cache/storage_block_unittest.cc', + 'dns/address_sorter_posix_unittest.cc', + 'dns/address_sorter_unittest.cc', 'dns/dns_config_service_posix_unittest.cc', 'dns/dns_config_service_unittest.cc', 'dns/dns_config_service_win_unittest.cc', diff --git a/net/tools/dns_fuzz_stub/dns_fuzz_stub.cc b/net/tools/dns_fuzz_stub/dns_fuzz_stub.cc index e4573b1..9884ded 100644 --- a/net/tools/dns_fuzz_stub/dns_fuzz_stub.cc +++ b/net/tools/dns_fuzz_stub/dns_fuzz_stub.cc @@ -149,7 +149,7 @@ void RunTestCase(uint16 id, std::string& qname, uint16 qtype, base::TimeDelta ttl; net::DnsResponse::Result result = response.ParseToAddressList( &address_list, &ttl); - if (result != net::DnsResponse::DNS_SUCCESS) { + if (result != net::DnsResponse::DNS_PARSE_OK) { LOG(INFO) << "ParseToAddressList failed: " << result; return; } |