diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/base/big_endian.cc | 98 | ||||
-rw-r--r-- | net/base/big_endian.h | 107 | ||||
-rw-r--r-- | net/base/big_endian_unittest.cc | 100 | ||||
-rw-r--r-- | net/base/dns_util.cc | 17 | ||||
-rw-r--r-- | net/base/dns_util.h | 10 | ||||
-rw-r--r-- | net/dns/async_host_resolver.cc | 137 | ||||
-rw-r--r-- | net/dns/async_host_resolver.h | 50 | ||||
-rw-r--r-- | net/dns/async_host_resolver_unittest.cc | 171 | ||||
-rw-r--r-- | net/dns/dns_client.cc | 91 | ||||
-rw-r--r-- | net/dns/dns_client.h | 93 | ||||
-rw-r--r-- | net/dns/dns_client_unittest.cc | 311 | ||||
-rw-r--r-- | net/dns/dns_protocol.h | 122 | ||||
-rw-r--r-- | net/dns/dns_query.cc | 111 | ||||
-rw-r--r-- | net/dns/dns_query.h | 39 | ||||
-rw-r--r-- | net/dns/dns_query_unittest.cc | 103 | ||||
-rw-r--r-- | net/dns/dns_response.cc | 232 | ||||
-rw-r--r-- | net/dns/dns_response.h | 92 | ||||
-rw-r--r-- | net/dns/dns_response_unittest.cc | 295 | ||||
-rw-r--r-- | net/dns/dns_session.cc | 47 | ||||
-rw-r--r-- | net/dns/dns_session.h | 70 | ||||
-rw-r--r-- | net/dns/dns_test_util.cc | 14 | ||||
-rw-r--r-- | net/dns/dns_test_util.h | 32 | ||||
-rw-r--r-- | net/dns/dns_transaction.cc | 184 | ||||
-rw-r--r-- | net/dns/dns_transaction.h | 86 | ||||
-rw-r--r-- | net/dns/dns_transaction_unittest.cc | 286 | ||||
-rw-r--r-- | net/net.gyp | 8 |
26 files changed, 1979 insertions, 927 deletions
diff --git a/net/base/big_endian.cc b/net/base/big_endian.cc new file mode 100644 index 0000000..17acc59 --- /dev/null +++ b/net/base/big_endian.cc @@ -0,0 +1,98 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/base/big_endian.h" + +#include "base/string_piece.h" + +namespace net { + +BigEndianReader::BigEndianReader(const void* buf, size_t len) + : ptr_(reinterpret_cast<const char*>(buf)), end_(ptr_ + len) {} + +bool BigEndianReader::Skip(size_t len) { + if (ptr_ + len > end_) + return false; + ptr_ += len; + return true; +} + +bool BigEndianReader::ReadBytes(void* out, size_t len) { + if (ptr_ + len > end_) + return false; + memcpy(out, ptr_, len); + ptr_ += len; + return true; +} + +bool BigEndianReader::ReadPiece(base::StringPiece* out, size_t len) { + if (ptr_ + len > end_) + return false; + *out = base::StringPiece(ptr_, len); + ptr_ += len; + return true; +} + +template<typename T> +bool BigEndianReader::Read(T* value) { + if (ptr_ + sizeof(T) > end_) + return false; + ReadBigEndian<T>(ptr_, value); + ptr_ += sizeof(T); + return true; +} + +bool BigEndianReader::ReadU8(uint8* value) { + return Read(value); +} + +bool BigEndianReader::ReadU16(uint16* value) { + return Read(value); +} + +bool BigEndianReader::ReadU32(uint32* value) { + return Read(value); +} + +BigEndianWriter::BigEndianWriter(void* buf, size_t len) + : ptr_(reinterpret_cast<char*>(buf)), end_(ptr_ + len) {} + +bool BigEndianWriter::Skip(size_t len) { + if (ptr_ + len > end_) + return false; + ptr_ += len; + return true; +} + +bool BigEndianWriter::WriteBytes(const void* buf, size_t len) { + if (ptr_ + len > end_) + return false; + memcpy(ptr_, buf, len); + ptr_ += len; + return true; +} + +template<typename T> +bool BigEndianWriter::Write(T value) { + if (ptr_ + sizeof(T) > end_) + return false; + WriteBigEndian<T>(ptr_, value); + ptr_ += sizeof(T); + return true; +} + +bool BigEndianWriter::WriteU8(uint8 value) { + return Write(value); +} + +bool BigEndianWriter::WriteU16(uint16 value) { + return Write(value); +} + +bool BigEndianWriter::WriteU32(uint32 value) { + return Write(value); +} + +} // namespace net + diff --git a/net/base/big_endian.h b/net/base/big_endian.h new file mode 100644 index 0000000..0786f62 --- /dev/null +++ b/net/base/big_endian.h @@ -0,0 +1,107 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_BASE_BIG_ENDIAN_H_ +#define NET_BASE_BIG_ENDIAN_H_ +#pragma once + +#include "base/basictypes.h" +#include "net/base/net_export.h" + +namespace base { +class StringPiece; +} + +namespace net { + +// Read an integer (signed or unsigned) from |buf| in Big Endian order. +// Note: this loop is unrolled with -O1 and above. +// NOTE(szym): glibc dns-canon.c and SpdyFrameBuilder use +// ntohs(*(uint16_t*)ptr) which is potentially unaligned. +// This would cause SIGBUS on ARMv5 or earlier and ARMv6-M. +template<typename T> +inline void ReadBigEndian(const char buf[], T* out) { + *out = buf[0]; + for (size_t i = 1; i < sizeof(T); ++i) { + *out <<= 8; + // Must cast to uint8 to avoid clobbering by sign extension. + *out |= static_cast<uint8>(buf[i]); + } +} + +// Write an integer (signed or unsigned) |val| to |buf| in Big Endian order. +// Note: this loop is unrolled with -O1 and above. +template<typename T> +inline void WriteBigEndian(char buf[], T val) { + for (size_t i = 0; i < sizeof(T); ++i) { + buf[sizeof(T)-i-1] = static_cast<char>(val & 0xFF); + val >>= 8; + } +} + +// Specializations to make clang happy about the (dead code) shifts above. +template<> +inline void ReadBigEndian<uint8>(const char buf[], uint8* out) { + *out = buf[0]; +} + +template<> +inline void WriteBigEndian<uint8>(char buf[], uint8 val) { + buf[0] = static_cast<char>(val); +} + +// Allows reading integers in network order (big endian) while iterating over +// an underlying buffer. All the reading functions advance the internal pointer. +class NET_EXPORT BigEndianReader { + public: + BigEndianReader(const void* buf, size_t len); + + const char* ptr() const { return ptr_; } + int remaining() const { return end_ - ptr_; } + + bool Skip(size_t len); + bool ReadBytes(void* out, size_t len); + // Creates a StringPiece in |out| that points to the underlying buffer. + bool ReadPiece(base::StringPiece* out, size_t len); + bool ReadU8(uint8* value); + bool ReadU16(uint16* value); + bool ReadU32(uint32* value); + + private: + // Hidden to promote type safety. + template<typename T> + bool Read(T* v); + + const char* ptr_; + const char* end_; +}; + +// Allows writing integers in network order (big endian) while iterating over +// an underlying buffer. All the writing functions advance the internal pointer. +class NET_EXPORT BigEndianWriter { + public: + BigEndianWriter(void* buf, size_t len); + + char* ptr() const { return ptr_; } + int remaining() const { return end_ - ptr_; } + + bool Skip(size_t len); + bool WriteBytes(const void* buf, size_t len); + bool WriteU8(uint8 value); + bool WriteU16(uint16 value); + bool WriteU32(uint32 value); + + private: + // Hidden to promote type safety. + template<typename T> + bool Write(T v); + + char* ptr_; + char* end_; +}; + +} // namespace net + +#endif // NET_BASE_BIG_ENDIAN_H_ + diff --git a/net/base/big_endian_unittest.cc b/net/base/big_endian_unittest.cc new file mode 100644 index 0000000..cb4a6d4 --- /dev/null +++ b/net/base/big_endian_unittest.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/string_piece.h" +#include "net/base/big_endian.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +TEST(BigEndianReaderTest, ReadsValues) { + char data[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xA, 0xB, 0xC }; + char buf[2]; + uint8 u8; + uint16 u16; + uint32 u32; + base::StringPiece piece; + BigEndianReader reader(data, sizeof(data)); + + EXPECT_TRUE(reader.Skip(2)); + EXPECT_EQ(data + 2, reader.ptr()); + EXPECT_EQ(reader.remaining(), static_cast<int>(sizeof(data)) - 2); + EXPECT_TRUE(reader.ReadBytes(buf, sizeof(buf))); + EXPECT_EQ(0x2, buf[0]); + EXPECT_EQ(0x3, buf[1]); + EXPECT_TRUE(reader.ReadU8(&u8)); + EXPECT_EQ(0x4, u8); + EXPECT_TRUE(reader.ReadU16(&u16)); + EXPECT_EQ(0x0506, u16); + EXPECT_TRUE(reader.ReadU32(&u32)); + EXPECT_EQ(0x0708090Au, u32); + base::StringPiece expected(reader.ptr(), 2); + EXPECT_TRUE(reader.ReadPiece(&piece, 2)); + EXPECT_EQ(2u, piece.size()); + EXPECT_EQ(expected.data(), piece.data()); +} + +TEST(BigEndianReaderTest, RespectsLength) { + char data[4]; + char buf[2]; + uint8 u8; + uint16 u16; + uint32 u32; + base::StringPiece piece; + BigEndianReader reader(data, sizeof(data)); + // 4 left + EXPECT_FALSE(reader.Skip(6)); + EXPECT_TRUE(reader.Skip(1)); + // 3 left + EXPECT_FALSE(reader.ReadU32(&u32)); + EXPECT_FALSE(reader.ReadPiece(&piece, 4)); + EXPECT_TRUE(reader.Skip(2)); + // 1 left + EXPECT_FALSE(reader.ReadU16(&u16)); + EXPECT_FALSE(reader.ReadBytes(buf, 2)); + EXPECT_TRUE(reader.Skip(1)); + // 0 left + EXPECT_FALSE(reader.ReadU8(&u8)); + EXPECT_EQ(0, reader.remaining()); +} + +TEST(BigEndianWriterTest, WritesValues) { + char expected[] = { 0, 0, 2, 3, 4, 5, 6, 7, 8, 9, 0xA }; + char data[sizeof(expected)]; + char buf[] = { 0x2, 0x3 }; + memset(data, 0, sizeof(data)); + BigEndianWriter writer(data, sizeof(data)); + + EXPECT_TRUE(writer.Skip(2)); + EXPECT_TRUE(writer.WriteBytes(buf, sizeof(buf))); + EXPECT_TRUE(writer.WriteU8(0x4)); + EXPECT_TRUE(writer.WriteU16(0x0506)); + EXPECT_TRUE(writer.WriteU32(0x0708090A)); + EXPECT_EQ(0, memcmp(expected, data, sizeof(expected))); +} + +TEST(BigEndianWriterTest, RespectsLength) { + char data[4]; + char buf[2]; + uint8 u8 = 0; + uint16 u16 = 0; + uint32 u32 = 0; + BigEndianWriter writer(data, sizeof(data)); + // 4 left + EXPECT_FALSE(writer.Skip(6)); + EXPECT_TRUE(writer.Skip(1)); + // 3 left + EXPECT_FALSE(writer.WriteU32(u32)); + EXPECT_TRUE(writer.Skip(2)); + // 1 left + EXPECT_FALSE(writer.WriteU16(u16)); + EXPECT_FALSE(writer.WriteBytes(buf, 2)); + EXPECT_TRUE(writer.Skip(1)); + // 0 left + EXPECT_FALSE(writer.WriteU8(u8)); + EXPECT_EQ(0, writer.remaining()); +} + +} // namespace net + diff --git a/net/base/dns_util.cc b/net/base/dns_util.cc index 93d789e..a49ada8 100644 --- a/net/base/dns_util.cc +++ b/net/base/dns_util.cc @@ -9,7 +9,7 @@ namespace net { // Based on DJB's public domain code. -bool DNSDomainFromDot(const std::string& dotted, std::string* out) { +bool DNSDomainFromDot(const base::StringPiece& dotted, std::string* out) { const char* buf = dotted.data(); unsigned n = dotted.size(); char label[63]; @@ -56,7 +56,7 @@ bool DNSDomainFromDot(const std::string& dotted, std::string* out) { return true; } -std::string DNSDomainToString(const std::string& domain) { +std::string DNSDomainToString(const base::StringPiece& domain) { std::string ret; for (unsigned i = 0; i < domain.size() && domain[i]; i += domain[i] + 1) { @@ -73,7 +73,7 @@ std::string DNSDomainToString(const std::string& domain) { if (static_cast<unsigned>(domain[i]) + i + 1 > domain.size()) return ""; - ret += domain.substr(i + 1, domain[i]); + domain.substr(i + 1, domain[i]).AppendToString(&ret); } return ret; } @@ -92,12 +92,13 @@ bool IsSTD3ASCIIValidCharacter(char c) { return true; } -std::string TrimEndingDot(const std::string& host) { - std::string host_trimmed = host; +std::string TrimEndingDot(const base::StringPiece& host) { + base::StringPiece host_trimmed = host; size_t len = host_trimmed.length(); - if (len > 1 && host_trimmed[len - 1] == '.') - host_trimmed.erase(len - 1); - return host_trimmed; + if (len > 1 && host_trimmed[len - 1] == '.') { + host_trimmed.remove_suffix(1); + } + return host_trimmed.as_string(); } bool DnsResponseBuffer::U8(uint8* v) { diff --git a/net/base/dns_util.h b/net/base/dns_util.h index f09e906..c01b2a2 100644 --- a/net/base/dns_util.h +++ b/net/base/dns_util.h @@ -19,18 +19,22 @@ namespace net { // // dotted: a string in dotted form: "www.google.com" // out: a result in DNS form: "\x03www\x06google\x03com\x00" -NET_EXPORT_PRIVATE bool DNSDomainFromDot(const std::string& dotted, +NET_EXPORT_PRIVATE bool DNSDomainFromDot(const base::StringPiece& dotted, std::string* out); // DNSDomainToString coverts a domain in DNS format to a dotted string. -NET_EXPORT_PRIVATE std::string DNSDomainToString(const std::string& domain); +NET_EXPORT_PRIVATE std::string DNSDomainToString( + const base::StringPiece& domain); // Returns true iff the given character is in the set of valid DNS label // characters as given in RFC 3490, 4.1, 3(a) NET_EXPORT_PRIVATE bool IsSTD3ASCIIValidCharacter(char c); // Returns the hostname by trimming the ending dot, if one exists. -NET_EXPORT std::string TrimEndingDot(const std::string& host); +NET_EXPORT std::string TrimEndingDot(const base::StringPiece& host); + +// TODO(szym): remove all definitions below once DnsRRResolver migrates to +// DnsClient // DNS class types. static const uint16 kClassIN = 1; diff --git a/net/dns/async_host_resolver.cc b/net/dns/async_host_resolver.cc index 1666e04..ff93c3e 100644 --- a/net/dns/async_host_resolver.cc +++ b/net/dns/async_host_resolver.cc @@ -8,12 +8,16 @@ #include "base/bind.h" #include "base/logging.h" +#include "base/message_loop.h" #include "base/rand_util.h" #include "base/stl_util.h" #include "base/values.h" #include "net/base/address_list.h" #include "net/base/dns_util.h" #include "net/base/net_errors.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/dns_response.h" +#include "net/dns/dns_session.h" #include "net/socket/client_socket_factory.h" namespace net { @@ -22,7 +26,7 @@ namespace { // TODO(agayev): fix this when IPv6 support is added. uint16 QueryTypeFromAddressFamily(AddressFamily address_family) { - return kDNS_A; + return dns_protocol::kTypeA; } class RequestParameters : public NetLog::EventParameters { @@ -56,17 +60,22 @@ class RequestParameters : public NetLog::EventParameters { HostResolver* CreateAsyncHostResolver(size_t max_concurrent_resolves, const IPAddressNumber& dns_ip, NetLog* net_log) { - size_t max_transactions = max_concurrent_resolves; - if (max_transactions == 0) - max_transactions = 20; - size_t max_pending_requests = max_transactions * 100; + size_t max_dns_requests = max_concurrent_resolves; + if (max_dns_requests == 0) + max_dns_requests = 20; + size_t max_pending_requests = max_dns_requests * 100; + DnsConfig config; + config.nameservers.push_back(IPEndPoint(dns_ip, 53)); + DnsSession* session = new DnsSession( + config, + ClientSocketFactory::GetDefaultFactory(), + base::Bind(&base::RandInt), + net_log); HostResolver* resolver = new AsyncHostResolver( - IPEndPoint(dns_ip, 53), - max_transactions, + max_dns_requests, max_pending_requests, - base::Bind(&base::RandInt), HostCache::CreateDefaultCache(), - NULL, + DnsClient::CreateClient(session), net_log); return resolver; } @@ -193,19 +202,15 @@ class AsyncHostResolver::Request { }; //----------------------------------------------------------------------------- -AsyncHostResolver::AsyncHostResolver(const IPEndPoint& dns_server, - size_t max_transactions, +AsyncHostResolver::AsyncHostResolver(size_t max_dns_requests, size_t max_pending_requests, - const RandIntCallback& rand_int_cb, HostCache* cache, - ClientSocketFactory* factory, + DnsClient* client, NetLog* net_log) - : max_transactions_(max_transactions), + : max_dns_requests_(max_dns_requests), max_pending_requests_(max_pending_requests), - dns_server_(dns_server), - rand_int_cb_(rand_int_cb), cache_(cache), - factory_(factory), + client_(client), net_log_(net_log) { } @@ -215,8 +220,8 @@ AsyncHostResolver::~AsyncHostResolver() { it != requestlist_map_.end(); ++it) STLDeleteElements(&it->second); - // Destroy transactions. - STLDeleteElements(&transactions_); + // Destroy DNS requests. + STLDeleteElements(&dns_requests_); // Destroy pending requests. for (size_t i = 0; i < arraysize(pending_requests_); ++i) @@ -240,8 +245,8 @@ int AsyncHostResolver::Resolve(const RequestInfo& info, rv = request->result(); else if (AttachToRequestList(request.get())) rv = ERR_IO_PENDING; - else if (transactions_.size() < max_transactions_) - rv = StartNewTransactionFor(request.get()); + else if (dns_requests_.size() < max_dns_requests_) + rv = StartNewDnsRequestFor(request.get()); else rv = Enqueue(request.get()); @@ -327,39 +332,56 @@ HostCache* AsyncHostResolver::GetHostCache() { return cache_.get(); } -void AsyncHostResolver::OnTransactionComplete( +void AsyncHostResolver::OnDnsRequestComplete( + DnsClient::Request* dns_req, int result, - const DnsTransaction* transaction, - const IPAddressList& ip_addresses) { - DCHECK(std::find(transactions_.begin(), transactions_.end(), transaction) - != transactions_.end()); - DCHECK(requestlist_map_.find(transaction->key()) != requestlist_map_.end()); + const DnsResponse* response) { + DCHECK(std::find(dns_requests_.begin(), dns_requests_.end(), dns_req) + != dns_requests_.end()); - // If by the time requests that caused |transaction| are cancelled, we do + // If by the time requests that caused |dns_req| are cancelled, we do // not have a port number to associate with the result, therefore, we // assume the most common port, otherwise we use the port number of the // first request. - RequestList& requests = requestlist_map_[transaction->key()]; + KeyRequestListMap::iterator rit = requestlist_map_.find( + std::make_pair(dns_req->qname(), dns_req->qtype())); + DCHECK(rit != requestlist_map_.end()); + RequestList& requests = rit->second; int port = requests.empty() ? 80 : requests.front()->info().port(); - // Run callback of every request that was depending on this transaction, + // Extract AddressList out of DnsResponse. + AddressList addr_list; + if (result == OK) { + IPAddressList ip_addresses; + DnsRecordParser parser = response->Parser(); + DnsResourceRecord record; + // TODO(szym): Add stricter checking of names, aliases and address lengths. + while (parser.ParseRecord(&record)) { + if (record.type == dns_req->qtype() && + (record.rdata.size() == kIPv4AddressSize || + record.rdata.size() == kIPv6AddressSize)) { + ip_addresses.push_back(IPAddressNumber(record.rdata.begin(), + record.rdata.end())); + } + } + if (!ip_addresses.empty()) + addr_list = AddressList::CreateFromIPAddressList(ip_addresses, port); + else + result = ERR_NAME_NOT_RESOLVED; + } + + // Run callback of every request that was depending on this DNS request, // also notify observers. - AddressList addrlist; - if (result == OK) - addrlist = AddressList::CreateFromIPAddressList(ip_addresses, port); - for (RequestList::iterator it = requests.begin(); it != requests.end(); - ++it) - (*it)->OnAsyncComplete(result, addrlist); - - // It is possible that the requests that caused |transaction| to be - // created are cancelled by the time |transaction| completes. In that + for (RequestList::iterator it = requests.begin(); it != requests.end(); ++it) + (*it)->OnAsyncComplete(result, addr_list); + + // It is possible that the requests that caused |dns_req| to be + // created are cancelled by the time |dns_req| completes. In that // case |requests| would be empty. We are knowingly throwing away the // result of a DNS resolution in that case, because (a) if there are no // requests, we do not have info to obtain a key from, (b) DnsTransaction // does not have info(), adding one into it just temporarily doesn't make - // sense, since HostCache will be replaced with RR cache soon, (c) - // recreating info from DnsTransaction::Key adds a lot of temporary - // code/functions (like converting back from qtype to AddressFamily.) + // sense, since HostCache will be replaced with RR cache soon. // Also, we only cache positive results. All of this will change when RR // cache is added. if (result == OK && cache_.get() && !requests.empty()) { @@ -367,16 +389,16 @@ void AsyncHostResolver::OnTransactionComplete( HostResolver::RequestInfo info = request->info(); HostCache::Key key( info.hostname(), info.address_family(), info.host_resolver_flags()); - cache_->Set(key, result, addrlist, base::TimeTicks::Now()); + cache_->Set(key, result, addr_list, base::TimeTicks::Now()); } // Cleanup requests. STLDeleteElements(&requests); - requestlist_map_.erase(transaction->key()); + requestlist_map_.erase(rit); - // Cleanup transaction and start a new one if there are pending requests. - delete transaction; - transactions_.remove(transaction); + // Cleanup |dns_req| and start a new one if there are pending requests. + delete dns_req; + dns_requests_.remove(dns_req); ProcessPending(); } @@ -399,25 +421,22 @@ bool AsyncHostResolver::AttachToRequestList(Request* request) { return true; } -int AsyncHostResolver::StartNewTransactionFor(Request* request) { +int AsyncHostResolver::StartNewDnsRequestFor(Request* request) { DCHECK(requestlist_map_.find(request->key()) == requestlist_map_.end()); - DCHECK(transactions_.size() < max_transactions_); + DCHECK(dns_requests_.size() < max_dns_requests_); request->request_net_log().AddEvent( NetLog::TYPE_ASYNC_HOST_RESOLVER_CREATE_DNS_TRANSACTION, NULL); requestlist_map_[request->key()].push_back(request); - DnsTransaction* transaction = new DnsTransaction( - dns_server_, + DnsClient::Request* dns_req = client_->CreateRequest( request->key().first, request->key().second, - rand_int_cb_, - factory_, - request->request_net_log(), - net_log_); - transaction->SetDelegate(this); - transactions_.push_back(transaction); - return transaction->Start(); + base::Bind(&AsyncHostResolver::OnDnsRequestComplete, + base::Unretained(this)), + request->request_net_log()); + dns_requests_.push_back(dns_req); + return dns_req->Start(); } int AsyncHostResolver::Enqueue(Request* request) { @@ -490,7 +509,7 @@ void AsyncHostResolver::ProcessPending() { } } } - StartNewTransactionFor(request); + StartNewDnsRequestFor(request); } } // namespace net diff --git a/net/dns/async_host_resolver.h b/net/dns/async_host_resolver.h index c501d65..e8aeb8b 100644 --- a/net/dns/async_host_resolver.h +++ b/net/dns/async_host_resolver.h @@ -8,6 +8,8 @@ #include <list> #include <map> +#include <string> +#include <utility> #include "base/threading/non_thread_safe.h" #include "net/base/address_family.h" @@ -15,24 +17,18 @@ #include "net/base/host_resolver.h" #include "net/base/ip_endpoint.h" #include "net/base/net_log.h" -#include "net/base/rand_callback.h" -#include "net/dns/dns_transaction.h" +#include "net/dns/dns_client.h" namespace net { -class ClientSocketFactory; - class NET_EXPORT AsyncHostResolver : public HostResolver, - public DnsTransaction::Delegate, NON_EXPORTED_BASE(public base::NonThreadSafe) { public: - AsyncHostResolver(const IPEndPoint& dns_server, - size_t max_transactions, - size_t max_pending_requests_, - const RandIntCallback& rand_int, + AsyncHostResolver(size_t max_dns_requests, + size_t max_pending_requests, HostCache* cache, - ClientSocketFactory* factory, + DnsClient* client, NetLog* net_log); virtual ~AsyncHostResolver(); @@ -50,11 +46,9 @@ class NET_EXPORT AsyncHostResolver virtual AddressFamily GetDefaultAddressFamily() const OVERRIDE; virtual HostCache* GetHostCache() OVERRIDE; - // DnsTransaction::Delegate interface - virtual void OnTransactionComplete( - int result, - const DnsTransaction* transaction, - const IPAddressList& ip_addresses) OVERRIDE; + void OnDnsRequestComplete(DnsClient::Request* request, + int result, + const DnsResponse* transaction); private: FRIEND_TEST_ALL_PREFIXES(AsyncHostResolverTest, QueuedLookup); @@ -68,9 +62,9 @@ class NET_EXPORT AsyncHostResolver class Request; - typedef DnsTransaction::Key Key; + typedef std::pair<std::string, uint16> Key; typedef std::list<Request*> RequestList; - typedef std::list<const DnsTransaction*> TransactionList; + typedef std::list<const DnsClient::Request*> DnsRequestList; typedef std::map<Key, RequestList> KeyRequestListMap; // Create a new request for the incoming Resolve() call. @@ -92,9 +86,9 @@ class NET_EXPORT AsyncHostResolver // attach |request| to the respective list. bool AttachToRequestList(Request* request); - // Will start a new transaction for |request|, will insert a new key in + // Will start a new DNS request for |request|, will insert a new key in // |requestlist_map_| and append |request| to the respective list. - int StartNewTransactionFor(Request* request); + int StartNewDnsRequestFor(Request* request); // Will enqueue |request| in |pending_requests_|. int Enqueue(Request* request); @@ -114,11 +108,11 @@ class NET_EXPORT AsyncHostResolver // there are pending requests. void ProcessPending(); - // Maximum number of concurrent transactions. - size_t max_transactions_; + // Maximum number of concurrent DNS requests. + size_t max_dns_requests_; - // List of current transactions. - TransactionList transactions_; + // List of current DNS requests. + DnsRequestList dns_requests_; // A map from Key to a list of requests waiting for the Key to resolve. KeyRequestListMap requestlist_map_; @@ -129,18 +123,10 @@ class NET_EXPORT AsyncHostResolver // Queues based on priority for putting pending requests. RequestList pending_requests_[NUM_PRIORITIES]; - // DNS server to which queries will be setn. - IPEndPoint dns_server_; - - // Callback to be passed to DnsTransaction for generating DNS query ids. - RandIntCallback rand_int_cb_; - // Cache of host resolution results. scoped_ptr<HostCache> cache_; - // Also passed to DnsTransaction; it's a dependency injection to aid - // testing, outside of unit tests, its value is always NULL. - ClientSocketFactory* factory_; + DnsClient* client_; NetLog* net_log_; diff --git a/net/dns/async_host_resolver_unittest.cc b/net/dns/async_host_resolver_unittest.cc index d92887e..21123bc 100644 --- a/net/dns/async_host_resolver_unittest.cc +++ b/net/dns/async_host_resolver_unittest.cc @@ -6,18 +6,27 @@ #include "base/bind.h" #include "base/memory/scoped_ptr.h" +#include "base/message_loop.h" +#include "base/stl_util.h" #include "net/base/host_cache.h" +#include "net/base/net_errors.h" #include "net/base/net_log.h" -#include "net/base/rand_callback.h" #include "net/base/sys_addrinfo.h" +#include "net/base/test_completion_callback.h" +#include "net/dns/dns_client.h" +#include "net/dns/dns_query.h" +#include "net/dns/dns_response.h" #include "net/dns/dns_test_util.h" -#include "net/socket/socket_test_util.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { namespace { +const int kPortNum = 80; +const size_t kMaxTransactions = 2; +const size_t kMaxPendingRequests = 1; + void VerifyAddressList(const std::vector<const char*>& ip_addresses, int port, const AddressList& addrlist) { @@ -39,12 +48,64 @@ void VerifyAddressList(const std::vector<const char*>& ip_addresses, ASSERT_EQ(static_cast<addrinfo*>(NULL), ainfo); } +class MockDnsClient : public DnsClient, + public base::SupportsWeakPtr<MockDnsClient> { + public: + // Using WeakPtr to support cancellation. + // All MockRequests succeed unless canceled or MockDnsClient is destroyed. + class MockRequest : public DnsClient::Request, + public base::SupportsWeakPtr<MockRequest> { + public: + MockRequest(const base::StringPiece& qname, + uint16 qtype, + const RequestCallback& callback, + const base::WeakPtr<MockDnsClient>& client) + : Request(qname, qtype, callback), started_(false), client_(client) { + } + + virtual int Start() OVERRIDE { + EXPECT_FALSE(started_); + started_ = true; + MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&MockRequest::Finish, AsWeakPtr())); + return ERR_IO_PENDING; + } + + private: + void Finish() { + if (!client_) { + DoCallback(ERR_DNS_SERVER_FAILED, NULL); + return; + } + DoCallback(OK, client_->responses[Key(qname(), qtype())]); + } + + bool started_; + base::WeakPtr<MockDnsClient> client_; + }; + + typedef std::pair<std::string, uint16> Key; + + MockDnsClient() : num_requests(0) {} + ~MockDnsClient() { + STLDeleteValues(&responses); + } + + Request* CreateRequest(const base::StringPiece& qname, + uint16 qtype, + const RequestCallback& callback, + const BoundNetLog&) { + ++num_requests; + return new MockRequest(qname, qtype, callback, AsWeakPtr()); + } + + int num_requests; + std::map<Key, DnsResponse*> responses; +}; + } // namespace -static const int kPortNum = 80; -static const size_t kMaxTransactions = 2; -static const size_t kMaxPendingRequests = 1; -static int transaction_ids[] = {0, 1, 2, 3}; // The following fixture sets up an environment for four different lookups // with their data defined in dns_test_util.h. All tests make use of these @@ -69,84 +130,52 @@ class AsyncHostResolverTest : public testing::Test { ip_addresses2_(kT2IpAddresses, kT2IpAddresses + arraysize(kT2IpAddresses)), ip_addresses3_(kT3IpAddresses, - kT3IpAddresses + arraysize(kT3IpAddresses)), - test_prng_(std::deque<int>( - transaction_ids, transaction_ids + arraysize(transaction_ids))) { - rand_int_cb_ = base::Bind(&TestPrng::GetNext, - base::Unretained(&test_prng_)); + kT3IpAddresses + arraysize(kT3IpAddresses)) { // AF_INET only for now. info0_.set_address_family(ADDRESS_FAMILY_IPV4); info1_.set_address_family(ADDRESS_FAMILY_IPV4); info2_.set_address_family(ADDRESS_FAMILY_IPV4); info3_.set_address_family(ADDRESS_FAMILY_IPV4); - // Setup socket read/writes for transaction 0. - writes0_.push_back( - MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram), - arraysize(kT0QueryDatagram))); - reads0_.push_back( - MockRead(true, reinterpret_cast<const char*>(kT0ResponseDatagram), - arraysize(kT0ResponseDatagram))); - data0_.reset(new StaticSocketDataProvider(&reads0_[0], reads0_.size(), - &writes0_[0], writes0_.size())); - - // Setup socket read/writes for transaction 1. - writes1_.push_back( - MockWrite(true, reinterpret_cast<const char*>(kT1QueryDatagram), - arraysize(kT1QueryDatagram))); - reads1_.push_back( - MockRead(true, reinterpret_cast<const char*>(kT1ResponseDatagram), - arraysize(kT1ResponseDatagram))); - data1_.reset(new StaticSocketDataProvider(&reads1_[0], reads1_.size(), - &writes1_[0], writes1_.size())); - - // Setup socket read/writes for transaction 2. - writes2_.push_back( - MockWrite(true, reinterpret_cast<const char*>(kT2QueryDatagram), - arraysize(kT2QueryDatagram))); - reads2_.push_back( - MockRead(true, reinterpret_cast<const char*>(kT2ResponseDatagram), - arraysize(kT2ResponseDatagram))); - data2_.reset(new StaticSocketDataProvider(&reads2_[0], reads2_.size(), - &writes2_[0], writes2_.size())); - - // Setup socket read/writes for transaction 3. - writes3_.push_back( - MockWrite(true, reinterpret_cast<const char*>(kT3QueryDatagram), - arraysize(kT3QueryDatagram))); - reads3_.push_back( - MockRead(true, reinterpret_cast<const char*>(kT3ResponseDatagram), - arraysize(kT3ResponseDatagram))); - data3_.reset(new StaticSocketDataProvider(&reads3_[0], reads3_.size(), - &writes3_[0], writes3_.size())); - - factory_.AddSocketDataProvider(data0_.get()); - factory_.AddSocketDataProvider(data1_.get()); - factory_.AddSocketDataProvider(data2_.get()); - factory_.AddSocketDataProvider(data3_.get()); - - IPEndPoint dns_server; - bool rv0 = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server); - DCHECK(rv0); + client_.reset(new MockDnsClient()); + + AddResponse(std::string(kT0DnsName, arraysize(kT0DnsName)), kT0Qtype, + new DnsResponse(reinterpret_cast<const char*>(kT0ResponseDatagram), + arraysize(kT0ResponseDatagram), + arraysize(kT0QueryDatagram))); + + AddResponse(std::string(kT1DnsName, arraysize(kT1DnsName)), kT1Qtype, + new DnsResponse(reinterpret_cast<const char*>(kT1ResponseDatagram), + arraysize(kT1ResponseDatagram), + arraysize(kT1QueryDatagram))); + + AddResponse(std::string(kT2DnsName, arraysize(kT2DnsName)), kT2Qtype, + new DnsResponse(reinterpret_cast<const char*>(kT2ResponseDatagram), + arraysize(kT2ResponseDatagram), + arraysize(kT2QueryDatagram))); + + AddResponse(std::string(kT3DnsName, arraysize(kT3DnsName)), kT3Qtype, + new DnsResponse(reinterpret_cast<const char*>(kT3ResponseDatagram), + arraysize(kT3ResponseDatagram), + arraysize(kT3QueryDatagram))); resolver_.reset( - new AsyncHostResolver( - dns_server, kMaxTransactions, kMaxPendingRequests, rand_int_cb_, - HostCache::CreateDefaultCache(), &factory_, NULL)); + new AsyncHostResolver(kMaxTransactions, kMaxPendingRequests, + HostCache::CreateDefaultCache(), + client_.get(), NULL)); + } + + void AddResponse(const std::string& name, uint8 type, DnsResponse* response) { + client_->responses[MockDnsClient::Key(name, type)] = response; } protected: AddressList addrlist0_, addrlist1_, addrlist2_, addrlist3_; HostResolver::RequestInfo info0_, info1_, info2_, info3_; - std::vector<MockWrite> writes0_, writes1_, writes2_, writes3_; - std::vector<MockRead> reads0_, reads1_, reads2_, reads3_; - scoped_ptr<StaticSocketDataProvider> data0_, data1_, data2_, data3_; std::vector<const char*> ip_addresses0_, ip_addresses1_, ip_addresses2_, ip_addresses3_; - MockClientSocketFactory factory_; - TestPrng test_prng_; - RandIntCallback rand_int_cb_; scoped_ptr<HostResolver> resolver_; + scoped_ptr<MockDnsClient> client_; TestCompletionCallback callback0_, callback1_, callback2_, callback3_; }; @@ -242,7 +271,7 @@ TEST_F(AsyncHostResolverTest, ConcurrentLookup) { EXPECT_EQ(OK, rv2); VerifyAddressList(ip_addresses2_, kPortNum, addrlist2_); - EXPECT_EQ(3u, factory_.udp_client_sockets().size()); + EXPECT_EQ(3, client_->num_requests); } TEST_F(AsyncHostResolverTest, SameHostLookupsConsumeSingleTransaction) { @@ -270,7 +299,7 @@ TEST_F(AsyncHostResolverTest, SameHostLookupsConsumeSingleTransaction) { VerifyAddressList(ip_addresses0_, kPortNum, addrlist2_); // Although we have three lookups, a single UDP socket was used. - EXPECT_EQ(1u, factory_.udp_client_sockets().size()); + EXPECT_EQ(1, client_->num_requests); } TEST_F(AsyncHostResolverTest, CancelLookup) { @@ -319,7 +348,7 @@ TEST_F(AsyncHostResolverTest, CancelSameHostLookup) { EXPECT_EQ(OK, rv1); VerifyAddressList(ip_addresses0_, kPortNum, addrlist1_); - EXPECT_EQ(1u, factory_.udp_client_sockets().size()); + EXPECT_EQ(1, client_->num_requests); } // Test that a queued lookup completes. diff --git a/net/dns/dns_client.cc b/net/dns/dns_client.cc new file mode 100644 index 0000000..de60cc3 --- /dev/null +++ b/net/dns/dns_client.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_client.h" + +#include "base/bind.h" +#include "base/string_piece.h" +#include "net/base/net_errors.h" +#include "net/dns/dns_response.h" +#include "net/dns/dns_session.h" +#include "net/dns/dns_transaction.h" +#include "net/socket/client_socket_factory.h" + +namespace net { + +DnsClient::Request::Request(const base::StringPiece& qname, + uint16 qtype, + const RequestCallback& callback) + : qname_(qname.data(), qname.size()), + qtype_(qtype), + callback_(callback) { +} + +DnsClient::Request::~Request() {} + +// Implementation of DnsClient that uses DnsTransaction to serve requests. +class DnsClientImpl : public DnsClient { + public: + class RequestImpl : public Request { + public: + RequestImpl(const base::StringPiece& qname, + uint16 qtype, + const RequestCallback& callback, + DnsSession* session, + const BoundNetLog& net_log) + : Request(qname, qtype, callback), + session_(session), + net_log_(net_log) { + } + + virtual int Start() OVERRIDE { + transaction_.reset(new DnsTransaction( + session_, + qname(), + qtype(), + base::Bind(&RequestImpl::OnComplete, base::Unretained(this)), + net_log_)); + return transaction_->Start(); + } + + void OnComplete(DnsTransaction* transaction, int rv) { + DCHECK_EQ(transaction_.get(), transaction); + // TODO(szym): + // - handle retransmissions here instead of DnsTransaction + // - handle rcode and flags here instead of DnsTransaction + // - update RTT in DnsSession + // - perform suffix search + // - handle DNS over TCP + DoCallback(rv, (rv == OK) ? transaction->response() : NULL); + } + + private: + scoped_refptr<DnsSession> session_; + BoundNetLog net_log_; + scoped_ptr<DnsTransaction> transaction_; + }; + + explicit DnsClientImpl(DnsSession* session) { + session_ = session; + } + + virtual Request* CreateRequest( + const base::StringPiece& qname, + uint16 qtype, + const RequestCallback& callback, + const BoundNetLog& source_net_log) OVERRIDE { + return new RequestImpl(qname, qtype, callback, session_, source_net_log); + } + + private: + scoped_refptr<DnsSession> session_; +}; + +// static +DnsClient* DnsClient::CreateClient(DnsSession* session) { + return new DnsClientImpl(session); +} + +} // namespace net + diff --git a/net/dns/dns_client.h b/net/dns/dns_client.h new file mode 100644 index 0000000..af75cf6 --- /dev/null +++ b/net/dns/dns_client.h @@ -0,0 +1,93 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_CLIENT_H_ +#define NET_DNS_DNS_CLIENT_H_ +#pragma once + +#include <string> + +#include "base/basictypes.h" +#include "base/callback.h" +#include "base/memory/weak_ptr.h" +#include "net/base/net_export.h" + +namespace base { +class StringPiece; +} + +namespace net { + +class BoundNetLog; +class ClientSocketFactory; +class DnsResponse; +class DnsSession; + +// DnsClient performs asynchronous DNS queries. DnsClient takes care of +// retransmissions, DNS server fallback (or round-robin), suffix search, and +// simple response validation ("does it match the query") to fight poisoning. +// It does NOT perform caching, aggregation or prioritization of requests. +// +// Destroying DnsClient does NOT affect any already created Requests. +// +// TODO(szym): consider adding flags to MakeRequest to indicate options: +// -- don't perform suffix search +// -- query both A and AAAA at once +// -- accept truncated response (and/or forbid TCP) +class NET_EXPORT_PRIVATE DnsClient { + public: + class Request; + // Callback for complete requests. Note that DnsResponse might be NULL if + // the DNS server(s) could not be reached. + typedef base::Callback<void(Request* req, + int result, + const DnsResponse* resp)> RequestCallback; + + // A data-holder for a request made to the DnsClient. + // Destroying the request cancels the underlying network effort. + class NET_EXPORT_PRIVATE Request { + public: + Request(const base::StringPiece& qname, + uint16 qtype, + const RequestCallback& callback); + virtual ~Request(); + + const std::string& qname() const { return qname_; } + + uint16 qtype() const { return qtype_; } + + virtual int Start() = 0; + + void DoCallback(int result, const DnsResponse* response) { + callback_.Run(this, result, response); + } + + private: + std::string qname_; + uint16 qtype_; + RequestCallback callback_; + + DISALLOW_COPY_AND_ASSIGN(Request); + }; + + virtual ~DnsClient() {} + + // Makes asynchronous DNS query for the given |qname| and |qtype| (assuming + // QCLASS == IN). The caller is responsible for destroying the returned + // request whether to cancel it or after its completion. + // (Destroying DnsClient does not abort the requests.) + virtual Request* CreateRequest( + const base::StringPiece& qname, + uint16 qtype, + const RequestCallback& callback, + const BoundNetLog& source_net_log) WARN_UNUSED_RESULT = 0; + + // Creates a socket-based DnsClient using the |session|. + static DnsClient* CreateClient(DnsSession* session) WARN_UNUSED_RESULT; +}; + +} // namespace net + +#endif // NET_DNS_DNS_CLIENT_H_ + diff --git a/net/dns/dns_client_unittest.cc b/net/dns/dns_client_unittest.cc new file mode 100644 index 0000000..fdbae8f --- /dev/null +++ b/net/dns/dns_client_unittest.cc @@ -0,0 +1,311 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_client.h" + +#include "base/bind.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/big_endian.h" +#include "net/base/net_log.h" +#include "net/base/sys_addrinfo.h" +#include "net/dns/dns_response.h" +#include "net/dns/dns_session.h" +#include "net/dns/dns_test_util.h" +#include "net/socket/socket_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" + +// TODO(szym): test DnsClient::Request::Start with synchronous failure +// TODO(szym): test suffix search and server fallback once implemented + +namespace net { + +namespace { + +class DnsClientTest : public testing::Test { + public: + class TestRequestHelper { + public: + // If |answer_count| < 0, it is the expected error code. + TestRequestHelper(const char* name, + uint16 type, + const MockWrite& write, + const MockRead& read, + int answer_count) { + // Must include the terminating \x00. + qname = std::string(name, strlen(name) + 1); + qtype = type; + expected_answer_count = answer_count; + completed = false; + writes.push_back(write); + reads.push_back(read); + ReadBigEndian<uint16>(write.data, &transaction_id); + data.reset(new StaticSocketDataProvider(&reads[0], reads.size(), + &writes[0], writes.size())); + } + + void MakeRequest(DnsClient* client) { + EXPECT_EQ(NULL, request.get()); + request.reset(client->CreateRequest( + qname, + qtype, + base::Bind(&TestRequestHelper::OnRequestComplete, + base::Unretained(this)), + BoundNetLog())); + EXPECT_EQ(qname, request->qname()); + EXPECT_EQ(qtype, request->qtype()); + EXPECT_EQ(ERR_IO_PENDING, request->Start()); + } + + void Cancel() { + ASSERT_TRUE(request.get() != NULL); + request.reset(NULL); + } + + void OnRequestComplete(DnsClient::Request* req, + int rv, + const DnsResponse* response) { + EXPECT_FALSE(completed); + EXPECT_EQ(request.get(), req); + + if (expected_answer_count >= 0) { + EXPECT_EQ(OK, rv); + EXPECT_EQ(expected_answer_count, response->answer_count()); + + DnsRecordParser parser = response->Parser(); + DnsResourceRecord record; + for (int i = 0; i < expected_answer_count; ++i) { + EXPECT_TRUE(parser.ParseRecord(&record)); + } + EXPECT_TRUE(parser.AtEnd()); + + } else { + EXPECT_EQ(expected_answer_count, rv); + EXPECT_EQ(NULL, response); + } + + completed = true; + } + + void CancelOnRequestComplete(DnsClient::Request* req, + int rv, + const DnsResponse* response) { + EXPECT_FALSE(completed); + Cancel(); + } + + std::string qname; + uint16 qtype; + std::vector<MockWrite> writes; + std::vector<MockRead> reads; + uint16 transaction_id; // Id from first write. + scoped_ptr<StaticSocketDataProvider> data; + scoped_ptr<DnsClient::Request> request; + int expected_answer_count; + + bool completed; + }; + + virtual void SetUp() OVERRIDE { + helpers_.push_back(new TestRequestHelper( + kT0DnsName, + kT0Qtype, + MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram), + arraysize(kT0QueryDatagram)), + MockRead(true, reinterpret_cast<const char*>(kT0ResponseDatagram), + arraysize(kT0ResponseDatagram)), + arraysize(kT0IpAddresses) + 1)); // +1 for CNAME RR + + helpers_.push_back(new TestRequestHelper( + kT1DnsName, + kT1Qtype, + MockWrite(true, reinterpret_cast<const char*>(kT1QueryDatagram), + arraysize(kT1QueryDatagram)), + MockRead(true, reinterpret_cast<const char*>(kT1ResponseDatagram), + arraysize(kT1ResponseDatagram)), + arraysize(kT1IpAddresses) + 1)); // +1 for CNAME RR + + helpers_.push_back(new TestRequestHelper( + kT2DnsName, + kT2Qtype, + MockWrite(true, reinterpret_cast<const char*>(kT2QueryDatagram), + arraysize(kT2QueryDatagram)), + MockRead(true, reinterpret_cast<const char*>(kT2ResponseDatagram), + arraysize(kT2ResponseDatagram)), + arraysize(kT2IpAddresses) + 1)); // +1 for CNAME RR + + helpers_.push_back(new TestRequestHelper( + kT3DnsName, + kT3Qtype, + MockWrite(true, reinterpret_cast<const char*>(kT3QueryDatagram), + arraysize(kT3QueryDatagram)), + MockRead(true, reinterpret_cast<const char*>(kT3ResponseDatagram), + arraysize(kT3ResponseDatagram)), + arraysize(kT3IpAddresses) + 2)); // +2 for CNAME RR + + CreateClient(); + } + + void CreateClient() { + MockClientSocketFactory* factory = new MockClientSocketFactory(); + + transaction_ids_.clear(); + for (unsigned i = 0; i < helpers_.size(); ++i) { + factory->AddSocketDataProvider(helpers_[i]->data.get()); + transaction_ids_.push_back(static_cast<int>(helpers_[i]->transaction_id)); + } + + DnsConfig config; + + IPEndPoint dns_server; + { + bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server); + EXPECT_TRUE(rv); + } + config.nameservers.push_back(dns_server); + + DnsSession* session = new DnsSession( + config, + factory, + base::Bind(&DnsClientTest::GetNextId, base::Unretained(this)), + NULL /* NetLog */); + + client_.reset(DnsClient::CreateClient(session)); + } + + virtual void TearDown() OVERRIDE { + STLDeleteElements(&helpers_); + } + + int GetNextId(int min, int max) { + EXPECT_FALSE(transaction_ids_.empty()); + int id = transaction_ids_.front(); + transaction_ids_.pop_front(); + EXPECT_GE(id, min); + EXPECT_LE(id, max); + return id; + } + + protected: + std::vector<TestRequestHelper*> helpers_; + std::deque<int> transaction_ids_; + scoped_ptr<DnsClient> client_; +}; + +TEST_F(DnsClientTest, Lookup) { + helpers_[0]->MakeRequest(client_.get()); + + // Wait until result. + MessageLoop::current()->RunAllPending(); + + EXPECT_TRUE(helpers_[0]->completed); +} + +TEST_F(DnsClientTest, ConcurrentLookup) { + for (unsigned i = 0; i < helpers_.size(); ++i) { + helpers_[i]->MakeRequest(client_.get()); + } + + MessageLoop::current()->RunAllPending(); + + for (unsigned i = 0; i < helpers_.size(); ++i) { + EXPECT_TRUE(helpers_[i]->completed); + } +} + +TEST_F(DnsClientTest, CancelLookup) { + for (unsigned i = 0; i < helpers_.size(); ++i) { + helpers_[i]->MakeRequest(client_.get()); + } + + helpers_[0]->Cancel(); + helpers_[2]->Cancel(); + + MessageLoop::current()->RunAllPending(); + + EXPECT_FALSE(helpers_[0]->completed); + EXPECT_TRUE(helpers_[1]->completed); + EXPECT_FALSE(helpers_[2]->completed); + EXPECT_TRUE(helpers_[3]->completed); +} + +TEST_F(DnsClientTest, DestroyClient) { + for (unsigned i = 0; i < helpers_.size(); ++i) { + helpers_[i]->MakeRequest(client_.get()); + } + + // Destroying the client does not affect running requests. + client_.reset(NULL); + + MessageLoop::current()->RunAllPending(); + + for (unsigned i = 0; i < helpers_.size(); ++i) { + EXPECT_TRUE(helpers_[i]->completed); + } +} + +TEST_F(DnsClientTest, DestroyRequestFromCallback) { + // Custom callback to cancel the completing request. + helpers_[0]->request.reset(client_->CreateRequest( + helpers_[0]->qname, + helpers_[0]->qtype, + base::Bind(&TestRequestHelper::CancelOnRequestComplete, + base::Unretained(helpers_[0])), + BoundNetLog())); + helpers_[0]->request->Start(); + + for (unsigned i = 1; i < helpers_.size(); ++i) { + helpers_[i]->MakeRequest(client_.get()); + } + + MessageLoop::current()->RunAllPending(); + + EXPECT_FALSE(helpers_[0]->completed); + for (unsigned i = 1; i < helpers_.size(); ++i) { + EXPECT_TRUE(helpers_[i]->completed); + } +} + +TEST_F(DnsClientTest, HandleFailure) { + STLDeleteElements(&helpers_); + // Wrong question. + helpers_.push_back(new TestRequestHelper( + kT0DnsName, + kT0Qtype, + MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram), + arraysize(kT0QueryDatagram)), + MockRead(true, reinterpret_cast<const char*>(kT1ResponseDatagram), + arraysize(kT1ResponseDatagram)), + ERR_DNS_MALFORMED_RESPONSE)); + + // Response with NXDOMAIN. + uint8 nxdomain_response[arraysize(kT0QueryDatagram)]; + memcpy(nxdomain_response, kT0QueryDatagram, arraysize(nxdomain_response)); + nxdomain_response[2] &= 0x80; // Response bit. + nxdomain_response[3] &= 0x03; // NXDOMAIN bit. + helpers_.push_back(new TestRequestHelper( + kT0DnsName, + kT0Qtype, + MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram), + arraysize(kT0QueryDatagram)), + MockRead(true, reinterpret_cast<const char*>(nxdomain_response), + arraysize(nxdomain_response)), + ERR_NAME_NOT_RESOLVED)); + + CreateClient(); + + for (unsigned i = 0; i < helpers_.size(); ++i) { + helpers_[i]->MakeRequest(client_.get()); + } + + MessageLoop::current()->RunAllPending(); + + for (unsigned i = 0; i < helpers_.size(); ++i) { + EXPECT_TRUE(helpers_[i]->completed); + } +} + +} // namespace + +} // namespace net + diff --git a/net/dns/dns_protocol.h b/net/dns/dns_protocol.h new file mode 100644 index 0000000..494429b --- /dev/null +++ b/net/dns/dns_protocol.h @@ -0,0 +1,122 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_PROTOCOL_H_ +#define NET_DNS_DNS_PROTOCOL_H_ +#pragma once + +#include "base/basictypes.h" +#include "net/base/net_export.h" + +namespace net { + +namespace dns_protocol { + +// DNS packet consists of a header followed by questions and/or answers. +// For the meaning of specific fields, please see RFC 1035 and 2535 + +// Header format. +// 1 1 1 1 1 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | ID | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// |QR| Opcode |AA|TC|RD|RA| Z|AD|CD| RCODE | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | QDCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | ANCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | NSCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | ARCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + +// Question format. +// 1 1 1 1 1 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | | +// / QNAME / +// / / +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | QTYPE | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | QCLASS | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + +// Answer format. +// 1 1 1 1 1 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | | +// / / +// / NAME / +// | | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | TYPE | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | CLASS | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | TTL | +// | | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | RDLENGTH | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--| +// / RDATA / +// / / +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + +#pragma pack(push) +#pragma pack(1) + +// On-the-wire header. All uint16 are in network order. +// Used internally in DnsQuery and DnsResponseParser. +struct NET_EXPORT_PRIVATE Header { + uint16 id; + uint8 flags[2]; + uint16 qdcount; + uint16 ancount; + uint16 nscount; + uint16 arcount; +}; + +#pragma pack(pop) + +static const uint8 kLabelMask = 0xc0; +static const uint8 kLabelPointer = 0xc0; +static const uint8 kLabelDirect = 0x0; +static const uint16 kOffsetMask = 0x3fff; + +static const int kMaxNameLength = 255; + +// RFC 1035, section 4.2.1: Messages carried by UDP are restricted to 512 +// bytes (not counting the IP nor UDP headers). +static const int kMaxUDPSize = 512; + +// DNS class types. +static const uint16 kClassIN = 1; + +// DNS resource record types. See +// http://www.iana.org/assignments/dns-parameters +static const uint16 kTypeA = 1; +static const uint16 kTypeCNAME = 5; +static const uint16 kTypeTXT = 16; +static const uint16 kTypeAAAA = 28; + +// DNS rcode values. +static const uint8 kRcodeMask = 0xf; +static const uint8 kRcodeNOERROR = 0; +static const uint8 kRcodeFORMERR = 1; +static const uint8 kRcodeSERVFAIL = 2; +static const uint8 kRcodeNXDOMAIN = 3; +static const uint8 kRcodeNOTIMP = 4; +static const uint8 kRcodeREFUSED = 5; + +} // namespace dns_protocol + +} // namespace net + +#endif // NET_DNS_DNS_PROTOCOL_H_ + diff --git a/net/dns/dns_query.cc b/net/dns/dns_query.cc index 788d653..3cfb5cd 100644 --- a/net/dns/dns_query.cc +++ b/net/dns/dns_query.cc @@ -6,91 +6,78 @@ #include <limits> +#include "net/base/big_endian.h" #include "net/base/dns_util.h" #include "net/base/io_buffer.h" +#include "net/base/sys_byteorder.h" +#include "net/dns/dns_protocol.h" namespace net { -namespace { - -void PackUint16BE(char buf[2], uint16 v) { - buf[0] = v >> 8; - buf[1] = v & 0xff; -} - -uint16 UnpackUint16BE(char buf[2]) { - return static_cast<uint8>(buf[0]) << 8 | static_cast<uint8>(buf[1]); -} - -} // namespace - // DNS query consists of a 12-byte header followed by a question section. // For details, see RFC 1035 section 4.1.1. This header template sets RD // bit, which directs the name server to pursue query recursively, and sets -// the QDCOUNT to 1, meaning the question section has a single entry. The -// first two bytes of the header form a 16-bit random query ID to be copied -// in the corresponding reply by the name server -- randomized during -// DnsQuery construction. -static const char kHeader[] = {0x00, 0x00, 0x01, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; -static const size_t kHeaderSize = arraysize(kHeader); - -DnsQuery::DnsQuery(const std::string& qname, - uint16 qtype, - const RandIntCallback& rand_int_cb) - : qname_size_(qname.size()), - rand_int_cb_(rand_int_cb) { - DCHECK(DnsResponseBuffer(reinterpret_cast<const uint8*>(qname.c_str()), - qname.size()).DNSName(NULL)); - DCHECK(qtype == kDNS_A || qtype == kDNS_AAAA); - - io_buffer_ = new IOBufferWithSize(kHeaderSize + question_size()); - - int byte_offset = 0; - char* buffer_head = io_buffer_->data(); - memcpy(&buffer_head[byte_offset], kHeader, kHeaderSize); - byte_offset += kHeaderSize; - memcpy(&buffer_head[byte_offset], &qname[0], qname_size_); - byte_offset += qname_size_; - PackUint16BE(&buffer_head[byte_offset], qtype); - byte_offset += sizeof(qtype); - PackUint16BE(&buffer_head[byte_offset], kClassIN); - RandomizeId(); +// the QDCOUNT to 1, meaning the question section has a single entry. +DnsQuery::DnsQuery(uint16 id, const base::StringPiece& qname, uint16 qtype) + : qname_size_(qname.size()) { + DCHECK(!DNSDomainToString(qname).empty()); + // QNAME + QTYPE + QCLASS + size_t question_size = qname_size_ + sizeof(uint16) + sizeof(uint16); + io_buffer_ = new IOBufferWithSize(sizeof(dns_protocol::Header) + + question_size); + dns_protocol::Header* header = + reinterpret_cast<dns_protocol::Header*>(io_buffer_->data()); + memset(header, 0, sizeof(dns_protocol::Header)); + header->id = htons(id); + header->flags[0] = 0x1; // RD bit + header->qdcount = htons(1); + + // Write question section after the header. + BigEndianWriter writer(reinterpret_cast<char*>(header + 1), question_size); + writer.WriteBytes(qname.data(), qname.size()); + writer.WriteU16(qtype); + writer.WriteU16(dns_protocol::kClassIN); } DnsQuery::~DnsQuery() { } -uint16 DnsQuery::id() const { - return UnpackUint16BE(&io_buffer_->data()[0]); +DnsQuery* DnsQuery::CloneWithNewId(uint16 id) const { + return new DnsQuery(*this, id); } -uint16 DnsQuery::qtype() const { - return UnpackUint16BE(&io_buffer_->data()[kHeaderSize + qname_size_]); -} - -DnsQuery* DnsQuery::CloneWithNewId() const { - return new DnsQuery(qname(), qtype(), rand_int_cb_); +uint16 DnsQuery::id() const { + const dns_protocol::Header* header = + reinterpret_cast<const dns_protocol::Header*>(io_buffer_->data()); + return ntohs(header->id); } -size_t DnsQuery::question_size() const { - return qname_size_ // QNAME - + sizeof(uint16) // QTYPE - + sizeof(uint16); // QCLASS +base::StringPiece DnsQuery::qname() const { + return base::StringPiece(io_buffer_->data() + sizeof(dns_protocol::Header), + qname_size_); } -const char* DnsQuery::question_data() const { - return &io_buffer_->data()[kHeaderSize]; +uint16 DnsQuery::qtype() const { + uint16 type; + ReadBigEndian<uint16>(io_buffer_->data() + + sizeof(dns_protocol::Header) + + qname_size_, &type); + return type; } -const std::string DnsQuery::qname() const { - return std::string(question_data(), qname_size_); +base::StringPiece DnsQuery::question() const { + return base::StringPiece(io_buffer_->data() + sizeof(dns_protocol::Header), + qname_size_ + sizeof(uint16) + sizeof(uint16)); } -void DnsQuery::RandomizeId() { - PackUint16BE(&io_buffer_->data()[0], rand_int_cb_.Run( - std::numeric_limits<uint16>::min(), - std::numeric_limits<uint16>::max())); +DnsQuery::DnsQuery(const DnsQuery& orig, uint16 id) { + qname_size_ = orig.qname_size_; + io_buffer_ = new IOBufferWithSize(orig.io_buffer()->size()); + memcpy(io_buffer_.get()->data(), orig.io_buffer()->data(), + io_buffer_.get()->size()); + dns_protocol::Header* header = + reinterpret_cast<dns_protocol::Header*>(io_buffer_->data()); + header->id = htons(id); } } // namespace net diff --git a/net/dns/dns_query.h b/net/dns/dns_query.h index c6bd3bc..e851113 100644 --- a/net/dns/dns_query.h +++ b/net/dns/dns_query.h @@ -6,53 +6,41 @@ #define NET_DNS_DNS_QUERY_H_ #pragma once -#include <string> - +#include "base/basictypes.h" #include "base/memory/ref_counted.h" +#include "base/string_piece.h" #include "net/base/net_export.h" -#include "net/base/rand_callback.h" namespace net { class IOBufferWithSize; // Represents on-the-wire DNS query message as an object. +// TODO(szym): add support for the OPT pseudo-RR (EDNS0/DNSSEC). class NET_EXPORT_PRIVATE DnsQuery { public: // Constructs a query message from |qname| which *MUST* be in a valid - // DNS name format, and |qtype| which must be either kDNS_A or kDNS_AAAA. - - // Every generated object has a random ID, hence two objects generated - // with the same set of constructor arguments are generally not equal; - // there is a 1/2^16 chance of them being equal due to size of |id_|. - DnsQuery(const std::string& qname, - uint16 qtype, - const RandIntCallback& rand_int_cb); + // DNS name format, and |qtype|. The qclass is set to IN. + DnsQuery(uint16 id, const base::StringPiece& qname, uint16 qtype); ~DnsQuery(); - // Clones |this| verbatim, with ID field of the header regenerated. - DnsQuery* CloneWithNewId() const; + // Clones |this| verbatim, with ID field of the header set to |id|. + DnsQuery* CloneWithNewId(uint16 id) const; // DnsQuery field accessors. uint16 id() const; + base::StringPiece qname() const; uint16 qtype() const; - // Returns the size of the Question section of the query. Used when - // matching the response. - size_t question_size() const; - - // Returns pointer to the Question section of the query. Used when - // matching the response. - const char* question_data() const; + // Returns the Question section of the query. Used when matching the + // response. + base::StringPiece question() const; // IOBuffer accessor to be used for writing out the query. IOBufferWithSize* io_buffer() const { return io_buffer_; } private: - const std::string qname() const; - - // Randomizes ID field of the query message. - void RandomizeId(); + DnsQuery(const DnsQuery& orig, uint16 id); // Size of the DNS name (*NOT* hostname) we are trying to resolve; used // to calculate offsets. @@ -61,9 +49,6 @@ class NET_EXPORT_PRIVATE DnsQuery { // Contains query bytes to be consumed by higher level Write() call. scoped_refptr<IOBufferWithSize> io_buffer_; - // PRNG function for generating IDs. - RandIntCallback rand_int_cb_; - DISALLOW_COPY_AND_ASSIGN(DnsQuery); }; diff --git a/net/dns/dns_query_unittest.cc b/net/dns/dns_query_unittest.cc index d43cf8be..ffe02e7 100644 --- a/net/dns/dns_query_unittest.cc +++ b/net/dns/dns_query_unittest.cc @@ -5,58 +5,21 @@ #include "net/dns/dns_query.h" #include "base/bind.h" -#include "base/rand_util.h" #include "net/base/dns_util.h" #include "net/base/io_buffer.h" +#include "net/dns/dns_protocol.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { -// DNS query consists of a header followed by a question. Header format -// and question format are described below. For the meaning of specific -// fields, please see RFC 1035. +namespace { -// Header format. -// 1 1 1 1 1 1 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | ID | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// |QR| Opcode |AA|TC|RD|RA| Z | RCODE | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | QDCOUNT | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | ANCOUNT | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | NSCOUNT | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | ARCOUNT | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - -// Question format. -// 1 1 1 1 1 1 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | | -// / QNAME / -// / / -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | QTYPE | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | QCLASS | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - -TEST(DnsQueryTest, ConstructorTest) { - std::string kQname("\003www\006google\003com", 16); - DnsQuery q1(kQname, kDNS_A, base::Bind(&base::RandInt)); - EXPECT_EQ(kDNS_A, q1.qtype()); - - uint8 id_hi = q1.id() >> 8, id_lo = q1.id() & 0xff; - - // See the top of the file for the description of a DNS query. +TEST(DnsQueryTest, Constructor) { + // This includes \0 at the end. + const char qname_data[] = "\x03""www""\x07""example""\x03""com"; const uint8 query_data[] = { // Header - id_hi, id_lo, + 0xbe, 0xef, 0x01, 0x00, // Flags -- set RD (recursion desired) bit. 0x00, 0x01, // Set QDCOUNT (question count) to 1, all the // rest are 0 for a query. @@ -65,46 +28,42 @@ TEST(DnsQueryTest, ConstructorTest) { 0x00, 0x00, // Question - 0x03, 0x77, 0x77, 0x77, // QNAME: www.google.com in DNS format. - 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, - 0x03, 0x63, 0x6f, 0x6d, 0x00, + 0x03, 'w', 'w', 'w', // QNAME: www.example.com in DNS format. + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00, 0x00, 0x01, // QTYPE: A query. 0x00, 0x01, // QCLASS: IN class. }; - int expected_size = arraysize(query_data); - EXPECT_EQ(expected_size, q1.io_buffer()->size()); - EXPECT_EQ(0, memcmp(q1.io_buffer()->data(), query_data, expected_size)); + base::StringPiece qname(qname_data, sizeof(qname_data)); + DnsQuery q1(0xbeef, qname, dns_protocol::kTypeA); + EXPECT_EQ(dns_protocol::kTypeA, q1.qtype()); + + ASSERT_EQ(static_cast<int>(sizeof(query_data)), q1.io_buffer()->size()); + EXPECT_EQ(0, memcmp(q1.io_buffer()->data(), query_data, sizeof(query_data))); + EXPECT_EQ(qname, q1.qname()); + + base::StringPiece question(reinterpret_cast<const char*>(query_data) + 12, + 21); + EXPECT_EQ(question, q1.question()); } -TEST(DnsQueryTest, CloneTest) { - std::string kQname("\003www\006google\003com", 16); - DnsQuery q1(kQname, kDNS_A, base::Bind(&base::RandInt)); +TEST(DnsQueryTest, Clone) { + // This includes \0 at the end. + const char qname_data[] = "\x03""www""\x07""example""\x03""com"; + base::StringPiece qname(qname_data, sizeof(qname_data)); - scoped_ptr<DnsQuery> q2(q1.CloneWithNewId()); + DnsQuery q1(0, qname, dns_protocol::kTypeA); + EXPECT_EQ(0, q1.id()); + scoped_ptr<DnsQuery> q2(q1.CloneWithNewId(42)); + EXPECT_EQ(42, q2->id()); EXPECT_EQ(q1.io_buffer()->size(), q2->io_buffer()->size()); EXPECT_EQ(q1.qtype(), q2->qtype()); - EXPECT_EQ(q1.question_size(), q2->question_size()); - EXPECT_EQ(0, memcmp(q1.question_data(), q2->question_data(), - q1.question_size())); + EXPECT_EQ(q1.question(), q2->question()); } -TEST(DnsQueryTest, RandomIdTest) { - std::string kQname("\003www\006google\003com", 16); - - // Since id fields are 16-bit values, we iterate to reduce the - // probability of collision, to avoid a flaky test. - bool ids_are_random = false; - for (int i = 0; i < 1000; ++i) { - DnsQuery q1(kQname, kDNS_A, base::Bind(&base::RandInt)); - DnsQuery q2(kQname, kDNS_A, base::Bind(&base::RandInt)); - scoped_ptr<DnsQuery> q3(q1.CloneWithNewId()); - ids_are_random = q1.id () != q2.id() && q1.id() != q3->id(); - if (ids_are_random) - break; - } - EXPECT_TRUE(ids_are_random); -} +} // namespace } // namespace net diff --git a/net/dns/dns_response.cc b/net/dns/dns_response.cc index 3c5d605..805d6f3 100644 --- a/net/dns/dns_response.cc +++ b/net/dns/dns_response.cc @@ -4,97 +4,185 @@ #include "net/dns/dns_response.h" -#include "net/base/dns_util.h" +#include "net/base/big_endian.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" +#include "net/base/sys_byteorder.h" +#include "net/dns/dns_protocol.h" #include "net/dns/dns_query.h" namespace net { -// RFC 1035, section 4.2.1: Messages carried by UDP are restricted to 512 -// bytes (not counting the IP nor UDP headers). -static const int kMaxResponseSize = 512; +DnsRecordParser::DnsRecordParser() : packet_(NULL), length_(0), cur_(0) { +} + +DnsRecordParser::DnsRecordParser(const void* packet, + size_t length, + size_t offset) + : packet_(reinterpret_cast<const char*>(packet)), + length_(length), + cur_(packet_ + offset) { + DCHECK_LE(offset, length); +} -DnsResponse::DnsResponse(DnsQuery* query) - : query_(query), - io_buffer_(new IOBufferWithSize(kMaxResponseSize + 1)) { - DCHECK(query_); +int DnsRecordParser::ParseName(const void* const vpos, std::string* out) const { + const char* const pos = reinterpret_cast<const char*>(vpos); + DCHECK(packet_); + DCHECK_LE(packet_, pos); + DCHECK_LE(pos, packet_ + length_); + + const char* p = pos; + const char* end = packet_ + length_; + // Count number of seen bytes to detect loops. + size_t seen = 0; + // Remember how many bytes were consumed before first jump. + size_t consumed = 0; + + if (pos >= end) + return 0; + + if (out) { + out->clear(); + out->reserve(dns_protocol::kMaxNameLength); + } + + for (;;) { + // The two couple of bits of the length give the type of the length. It's + // either a direct length or a pointer to the remainder of the name. + switch (*p & dns_protocol::kLabelMask) { + case dns_protocol::kLabelPointer: { + if (p + sizeof(uint16) > end) + return 0; + if (consumed == 0) { + consumed = p - pos + sizeof(uint16); + if (!out) + return consumed; // If name is not stored, that's all we need. + } + seen += sizeof(uint16); + // If seen the whole packet, then we must be in a loop. + if (seen > length_) + return 0; + uint16 offset; + ReadBigEndian<uint16>(p, &offset); + offset &= dns_protocol::kOffsetMask; + p = packet_ + offset; + if (p >= end) + return 0; + break; + } + case dns_protocol::kLabelDirect: { + uint8 label_len = *p; + ++p; + // Note: root domain (".") is NOT included. + if (label_len == 0) { + if (consumed == 0) { + consumed = p - pos; + } // else we set |consumed| before first jump + return consumed; + } + if (p + label_len >= end) + return 0; // Truncated or missing label. + if (out) { + if (!out->empty()) + out->append("."); + out->append(p, label_len); + } + p += label_len; + seen += 1 + label_len; + break; + } + default: + // unhandled label type + return 0; + } + } +} + +bool DnsRecordParser::ParseRecord(DnsResourceRecord* out) { + DCHECK(packet_); + size_t consumed = ParseName(cur_, &out->name); + if (!consumed) + return false; + BigEndianReader reader(cur_ + consumed, + packet_ + length_ - (cur_ + consumed)); + uint16 rdlen; + if (reader.ReadU16(&out->type) && + reader.ReadU16(&out->klass) && + reader.ReadU32(&out->ttl) && + reader.ReadU16(&rdlen) && + reader.ReadPiece(&out->rdata, rdlen)) { + cur_ = reader.ptr(); + return true; + } + return false; +} + +DnsResponse::DnsResponse() + : io_buffer_(new IOBufferWithSize(dns_protocol::kMaxUDPSize + 1)) { +} + +DnsResponse::DnsResponse(const void* data, + size_t length, + size_t answer_offset) + : io_buffer_(new IOBufferWithSize(length)), + parser_(io_buffer_->data(), length, answer_offset) { + memcpy(io_buffer_->data(), data, length); } DnsResponse::~DnsResponse() { } -int DnsResponse::Parse(int nbytes, IPAddressList* ip_addresses) { +bool DnsResponse::InitParse(int nbytes, const DnsQuery& query) { // Response includes query, it should be at least that size. - if (nbytes < query_->io_buffer()->size() || nbytes > kMaxResponseSize) - return ERR_DNS_MALFORMED_RESPONSE; - - DnsResponseBuffer response(reinterpret_cast<uint8*>(io_buffer_->data()), - io_buffer_->size()); - uint16 id; - if (!response.U16(&id) || id != query_->id()) // Make sure IDs match. - return ERR_DNS_MALFORMED_RESPONSE; - - uint8 flags, rcode; - if (!response.U8(&flags) || !response.U8(&rcode)) - return ERR_DNS_MALFORMED_RESPONSE; - - if (flags & 2) // TC is set -- server wants TCP, we don't support it (yet?). - return ERR_DNS_SERVER_REQUIRES_TCP; - - rcode &= 0x0f; // 3 means NXDOMAIN, the rest means server failed. - if (rcode && (rcode != 3)) - return ERR_DNS_SERVER_FAILED; - - uint16 query_count, answer_count, authority_count, additional_count; - if (!response.U16(&query_count) || - !response.U16(&answer_count) || - !response.U16(&authority_count) || - !response.U16(&additional_count)) { - return ERR_DNS_MALFORMED_RESPONSE; + if (nbytes < query.io_buffer()->size() || nbytes > dns_protocol::kMaxUDPSize) + return false; + + // Match the query id. + if (ntohs(header()->id) != query.id()) + return false; + + // Match question count. + if (ntohs(header()->qdcount) != 1) + return false; + + // Match the question section. + const size_t hdr_size = sizeof(dns_protocol::Header); + const base::StringPiece question = query.question(); + if (question != base::StringPiece(io_buffer_->data() + hdr_size, + question.size())) { + return false; } - if (query_count != 1) // Sent a single question, shouldn't have changed. - return ERR_DNS_MALFORMED_RESPONSE; + // Construct the parser. + parser_ = DnsRecordParser(io_buffer_->data(), + nbytes, + hdr_size + question.size()); + return true; +} - base::StringPiece question; // Make sure question section is echoed back. - if (!response.Block(&question, query_->question_size()) || - memcmp(question.data(), query_->question_data(), - query_->question_size())) { - return ERR_DNS_MALFORMED_RESPONSE; - } +uint8 DnsResponse::flags0() const { + return header()->flags[0]; +} - if (answer_count < 1) - return ERR_NAME_NOT_RESOLVED; - - IPAddressList rdatas; - while (answer_count--) { - uint32 ttl; - uint16 rdlength, qtype, qclass; - if (!response.DNSName(NULL) || - !response.U16(&qtype) || - !response.U16(&qclass) || - !response.U32(&ttl) || - !response.U16(&rdlength)) { - return ERR_DNS_MALFORMED_RESPONSE; - } - if (qtype == query_->qtype() && - qclass == kClassIN && - (rdlength == kIPv4AddressSize || rdlength == kIPv6AddressSize)) { - base::StringPiece rdata; - if (!response.Block(&rdata, rdlength)) - return ERR_DNS_MALFORMED_RESPONSE; - rdatas.push_back(IPAddressNumber(rdata.begin(), rdata.end())); - } else if (!response.Skip(rdlength)) - return ERR_DNS_MALFORMED_RESPONSE; - } +uint8 DnsResponse::flags1() const { + return header()->flags[1] & ~(dns_protocol::kRcodeMask); +} + +uint8 DnsResponse::rcode() const { + return header()->flags[1] & dns_protocol::kRcodeMask; +} - if (rdatas.empty()) - return ERR_NAME_NOT_RESOLVED; +int DnsResponse::answer_count() const { + return ntohs(header()->ancount); +} + +DnsRecordParser DnsResponse::Parser() const { + DCHECK(parser_.IsValid()); + return parser_; +} - if (ip_addresses) - ip_addresses->swap(rdatas); - return OK; +const dns_protocol::Header* DnsResponse::header() const { + return reinterpret_cast<const dns_protocol::Header*>(io_buffer_->data()); } } // namespace net diff --git a/net/dns/dns_response.h b/net/dns/dns_response.h index cc7c3f7..0fa3df2 100644 --- a/net/dns/dns_response.h +++ b/net/dns/dns_response.h @@ -6,43 +6,109 @@ #define NET_DNS_DNS_RESPONSE_H_ #pragma once +#include <string> + +#include "base/basictypes.h" #include "base/memory/ref_counted.h" +#include "base/string_piece.h" #include "net/base/net_export.h" #include "net/base/net_util.h" -namespace net{ +namespace net { class DnsQuery; class IOBufferWithSize; -// Represents on-the-wire DNS response as an object; allows extracting -// records. +namespace dns_protocol { +struct Header; +} + +// Parsed resource record. +struct NET_EXPORT_PRIVATE DnsResourceRecord { + std::string name; // in dotted form + uint16 type; + uint16 klass; + uint32 ttl; + base::StringPiece rdata; // points to the original response buffer +}; + +// Iterator to walk over resource records of the DNS response packet. +class NET_EXPORT_PRIVATE DnsRecordParser { + public: + // Construct an uninitialized iterator. + DnsRecordParser(); + + // Construct an iterator to process the |packet| of given |length|. + // |offset| points to the beginning of the answer section. + DnsRecordParser(const void* packet, size_t length, size_t offset); + + // Returns |true| if initialized. + bool IsValid() const { return packet_ != NULL; } + + // Returns |true| if no more bytes remain in the packet. + bool AtEnd() const { return cur_ == packet_ + length_; } + + // Parses a (possibly compressed) DNS name from the packet starting at + // |pos|. Stores output (even partial) in |out| unless |out| is NULL. |out| + // is stored in the dotted form, e.g., "example.com". Returns number of bytes + // consumed or 0 on failure. + // This is exposed to allow parsing compressed names within RRDATA for TYPEs + // such as NS, CNAME, PTR, MX, SOA. + // See RFC 1035 section 4.1.4. + int ParseName(const void* pos, std::string* out) const; + + // Parses the next resource record. Returns true if succeeded. + bool ParseRecord(DnsResourceRecord* record); + + private: + const char* packet_; + size_t length_; + // Current offset within the packet. + const char* cur_; +}; + +// Buffer-holder for the DNS response allowing easy access to the header fields +// and resource records. After reading into |io_buffer| must call InitParse to +// position the RR parser. class NET_EXPORT_PRIVATE DnsResponse { public: // Constructs an object with an IOBuffer large enough to read // one byte more than largest possible response, to detect malformed - // responses; |query| is a pointer to the DnsQuery for which |this| - // is supposed to be a response. - explicit DnsResponse(DnsQuery* query); + // responses. + DnsResponse(); + // Constructs response from |data|. Used for testing purposes only! + DnsResponse(const void* data, size_t length, size_t answer_offset); ~DnsResponse(); // Internal buffer accessor into which actual bytes of response will be // read. IOBufferWithSize* io_buffer() { return io_buffer_.get(); } - // Parses response of size nbytes and puts address into |ip_addresses|, - // returns net_error code in case of failure. - int Parse(int nbytes, IPAddressList* ip_addresses); + // Returns false if the packet is shorter than the header or does not match + // |query| id or question. + bool InitParse(int nbytes, const DnsQuery& query); + + // Accessors for the header. + uint8 flags0() const; // first byte of flags + uint8 flags1() const; // second byte of flags excluding rcode + uint8 rcode() const; + int answer_count() const; + + // Returns an iterator to the resource records in the answer section. Must be + // called after InitParse. The iterator is valid only in the scope of the + // DnsResponse. + DnsRecordParser Parser() const; private: - // The matching query; |this| is the response for |query_|. We do not - // own it, lifetime of |this| should be within the limits of lifetime of - // |query_|. - const DnsQuery* const query_; + // Convenience for header access. + const dns_protocol::Header* header() const; // Buffer into which response bytes are read. scoped_refptr<IOBufferWithSize> io_buffer_; + // Iterator constructed after InitParse positioned at the answer section. + DnsRecordParser parser_; + DISALLOW_COPY_AND_ASSIGN(DnsResponse); }; diff --git a/net/dns/dns_response_unittest.cc b/net/dns/dns_response_unittest.cc index 775cfc6..51d7cff 100644 --- a/net/dns/dns_response_unittest.cc +++ b/net/dns/dns_response_unittest.cc @@ -4,106 +4,176 @@ #include "net/dns/dns_response.h" -#include "base/bind.h" -#include "base/rand_util.h" -#include "net/base/dns_util.h" -#include "net/base/net_errors.h" #include "net/base/io_buffer.h" +#include "net/dns/dns_protocol.h" #include "net/dns/dns_query.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { -// DNS response consists of a header followed by a question followed by -// answer. Header format, question format and response format are -// described below. For the meaning of specific fields, please see RFC -// 1035. - -// Header format. -// 1 1 1 1 1 1 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | ID | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// |QR| Opcode |AA|TC|RD|RA| Z | RCODE | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | QDCOUNT | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | ANCOUNT | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | NSCOUNT | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | ARCOUNT | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - -// Question format. -// 1 1 1 1 1 1 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | | -// / QNAME / -// / / -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | QTYPE | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | QCLASS | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - -// Answser format. -// 1 1 1 1 1 1 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | | -// / / -// / NAME / -// | | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | TYPE | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | CLASS | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | TTL | -// | | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | RDLENGTH | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--| -// / RDATA / -// / / -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - -// TODO(agayev): add more thorough tests. -TEST(DnsResponseTest, ResponseWithCnameA) { - const std::string kQname("\012codereview\010chromium\003org", 25); - DnsQuery q1(kQname, kDNS_A, base::Bind(&base::RandInt)); - - uint8 id_hi = q1.id() >> 8, id_lo = q1.id() & 0xff; - - uint8 ip[] = { // codereview.chromium.org resolves to - 0x4a, 0x7d, 0x5f, 0x79 // 74.125.95.121 +namespace { + +TEST(DnsRecordParserTest, Constructor) { + const char data[] = { 0 }; + + EXPECT_FALSE(DnsRecordParser().IsValid()); + EXPECT_TRUE(DnsRecordParser(data, 1, 0).IsValid()); + EXPECT_TRUE(DnsRecordParser(data, 1, 1).IsValid()); + + EXPECT_FALSE(DnsRecordParser(data, 1, 0).AtEnd()); + EXPECT_TRUE(DnsRecordParser(data, 1, 1).AtEnd()); +} + +TEST(DnsRecordParserTest, ParseName) { + const uint8 data[] = { + // all labels "foo.example.com" + 0x03, 'f', 'o', 'o', + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x03, 'c', 'o', 'm', + // byte 0x10 + 0x00, + // byte 0x11 + // part label, part pointer, "bar.example.com" + 0x03, 'b', 'a', 'r', + 0xc0, 0x04, + // byte 0x17 + // all pointer to "bar.example.com", 2 jumps + 0xc0, 0x11, + // byte 0x1a + }; + + std::string out; + DnsRecordParser parser(data, sizeof(data), 0); + ASSERT_TRUE(parser.IsValid()); + + EXPECT_EQ(0x11, parser.ParseName(data + 0x00, &out)); + EXPECT_EQ("foo.example.com", out); + // Check that the last "." is never stored. + out.clear(); + EXPECT_EQ(0x1, parser.ParseName(data + 0x10, &out)); + EXPECT_EQ("", out); + out.clear(); + EXPECT_EQ(0x6, parser.ParseName(data + 0x11, &out)); + EXPECT_EQ("bar.example.com", out); + out.clear(); + EXPECT_EQ(0x2, parser.ParseName(data + 0x17, &out)); + EXPECT_EQ("bar.example.com", out); + + // Parse name without storing it. + EXPECT_EQ(0x11, parser.ParseName(data + 0x00, NULL)); + EXPECT_EQ(0x1, parser.ParseName(data + 0x10, NULL)); + EXPECT_EQ(0x6, parser.ParseName(data + 0x11, NULL)); + EXPECT_EQ(0x2, parser.ParseName(data + 0x17, NULL)); + + // Check that it works even if initial position is different. + parser = DnsRecordParser(data, sizeof(data), 0x12); + EXPECT_EQ(0x6, parser.ParseName(data + 0x11, NULL)); +} + +TEST(DnsRecordParserTest, ParseNameFail) { + const uint8 data[] = { + // label length beyond packet + 0x30, 'x', 'x', + 0x00, + // pointer offset beyond packet + 0xc0, 0x20, + // pointer loop + 0xc0, 0x08, + 0xc0, 0x06, + // incorrect label type (currently supports only direct and pointer) + 0x80, 0x00, + // truncated name (missing root label) + 0x02, 'x', 'x', + }; + + DnsRecordParser parser(data, sizeof(data), 0); + ASSERT_TRUE(parser.IsValid()); + + std::string out; + EXPECT_EQ(0, parser.ParseName(data + 0x00, &out)); + EXPECT_EQ(0, parser.ParseName(data + 0x04, &out)); + EXPECT_EQ(0, parser.ParseName(data + 0x08, &out)); + EXPECT_EQ(0, parser.ParseName(data + 0x0a, &out)); + EXPECT_EQ(0, parser.ParseName(data + 0x0c, &out)); + EXPECT_EQ(0, parser.ParseName(data + 0x0e, &out)); +} + +TEST(DnsRecordParserTest, ParseRecord) { + const uint8 data[] = { + // Type CNAME record. + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00, + 0x00, 0x05, // TYPE is CNAME. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, 0x24, 0x74, // TTL is 0x00012474. + 0x00, 0x06, // RDLENGTH is 6 bytes. + 0x03, 'f', 'o', 'o', // compressed name in record + 0xc0, 0x00, + // Type A record. + 0x03, 'b', 'a', 'r', // compressed owner name + 0xc0, 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x20, 0x13, 0x55, // TTL is 0x00201355. + 0x00, 0x04, // RDLENGTH is 4 bytes. + 0x7f, 0x02, 0x04, 0x01, // IP is 127.2.4.1 }; - IPAddressList expected_ips; - expected_ips.push_back(IPAddressNumber(ip, ip + arraysize(ip))); + std::string out; + DnsRecordParser parser(data, sizeof(data), 0); + + DnsResourceRecord record; + EXPECT_TRUE(parser.ParseRecord(&record)); + EXPECT_EQ("example.com", record.name); + EXPECT_EQ(dns_protocol::kTypeCNAME, record.type); + EXPECT_EQ(dns_protocol::kClassIN, record.klass); + EXPECT_EQ(0x00012474u, record.ttl); + EXPECT_EQ(6u, record.rdata.length()); + EXPECT_EQ(6, parser.ParseName(record.rdata.data(), &out)); + EXPECT_EQ("foo.example.com", out); + EXPECT_FALSE(parser.AtEnd()); + + EXPECT_TRUE(parser.ParseRecord(&record)); + EXPECT_EQ("bar.example.com", record.name); + EXPECT_EQ(dns_protocol::kTypeA, record.type); + EXPECT_EQ(dns_protocol::kClassIN, record.klass); + EXPECT_EQ(0x00201355u, record.ttl); + EXPECT_EQ(4u, record.rdata.length()); + EXPECT_EQ(base::StringPiece("\x7f\x02\x04\x01"), record.rdata); + EXPECT_TRUE(parser.AtEnd()); + + // Test truncated record. + parser = DnsRecordParser(data, sizeof(data) - 2, 0); + EXPECT_TRUE(parser.ParseRecord(&record)); + EXPECT_FALSE(parser.AtEnd()); + EXPECT_FALSE(parser.ParseRecord(&record)); +} + +TEST(DnsResponseTest, InitParse) { + // This includes \0 at the end. + const char qname_data[] = "\x0A""codereview""\x08""chromium""\x03""org"; + const base::StringPiece qname(qname_data, sizeof(qname_data)); + // Compilers want to copy when binding temporary to const &, so must use heap. + scoped_ptr<DnsQuery> query(new DnsQuery(0xcafe, qname, dns_protocol::kTypeA)); - uint8 response_data[] = { + const uint8 response_data[] = { // Header - id_hi, id_lo, // ID - 0x81, 0x80, // Standard query response, no error + 0xca, 0xfe, // ID + 0x81, 0x80, // Standard query response, RA, no error 0x00, 0x01, // 1 question 0x00, 0x02, // 2 RRs (answers) 0x00, 0x00, // 0 authority RRs 0x00, 0x00, // 0 additional RRs // Question - 0x0a, 0x63, 0x6f, 0x64, // This part is echoed back from the - 0x65, 0x72, 0x65, 0x76, // respective query. - 0x69, 0x65, 0x77, 0x08, - 0x63, 0x68, 0x72, 0x6f, - 0x6d, 0x69, 0x75, 0x6d, - 0x03, 0x6f, 0x72, 0x67, + // This part is echoed back from the respective query. + 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w', + 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm', + 0x03, 'o', 'r', 'g', 0x00, - 0x00, 0x01, - 0x00, 0x01, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. // Answer 1 0xc0, 0x0c, // NAME is a pointer to name in Question section. @@ -111,33 +181,56 @@ TEST(DnsResponseTest, ResponseWithCnameA) { 0x00, 0x01, // CLASS is IN. 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. 0x24, 0x74, - 0x00, 0x12, // RDLENGTH is 18 bytse. - 0x03, 0x67, 0x68, 0x73, // ghs.l.google.com in DNS format. - 0x01, 0x6c, 0x06, 0x67, - 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x03, 0x63, 0x6f, - 0x6d, 0x00, + 0x00, 0x12, // RDLENGTH is 18 bytes. + // ghs.l.google.com in DNS format. + 0x03, 'g', 'h', 's', + 0x01, 'l', + 0x06, 'g', 'o', 'o', 'g', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00, // Answer 2 - 0xc0, 0x35, // NAME is a pointer to name in Question section. + 0xc0, 0x35, // NAME is a pointer to name in Answer 1. 0x00, 0x01, // TYPE is A. 0x00, 0x01, // CLASS is IN. 0x00, 0x00, // TTL (4 bytes) is 53 seconds. 0x00, 0x35, - 0x00, 0x04, // RDLENGTH is 4 bytes. - ip[0], ip[1], ip[2], ip[3], // RDATA is the IP. + 0x00, 0x04, // RDLENGTH is 4 bytes. + 0x4a, 0x7d, // RDATA is the IP: 74.125.95.121 + 0x5f, 0x79, }; - // Create a response object and simulate reading into it. - DnsResponse r1(&q1); - memcpy(r1.io_buffer()->data(), &response_data[0], - r1.io_buffer()->size()); + DnsResponse resp; + memcpy(resp.io_buffer()->data(), response_data, sizeof(response_data)); - // Verify resolved IPs. - int response_size = arraysize(response_data); - IPAddressList actual_ips; - EXPECT_EQ(OK, r1.Parse(response_size, &actual_ips)); - EXPECT_EQ(expected_ips, actual_ips); + // Reject too short. + EXPECT_FALSE(resp.InitParse(query->io_buffer()->size() - 1, *query)); + + // Reject wrong id. + scoped_ptr<DnsQuery> other_query(query->CloneWithNewId(0xbeef)); + EXPECT_FALSE(resp.InitParse(sizeof(response_data), *other_query)); + + // Reject wrong question. + scoped_ptr<DnsQuery> wrong_query( + new DnsQuery(0xcafe, qname, dns_protocol::kTypeCNAME)); + EXPECT_FALSE(resp.InitParse(sizeof(response_data), *wrong_query)); + + // Accept matching question. + EXPECT_TRUE(resp.InitParse(sizeof(response_data), *query)); + + // Check header access. + EXPECT_EQ(0x81, resp.flags0()); + EXPECT_EQ(0x80, resp.flags1()); + EXPECT_EQ(0x0, resp.rcode()); + EXPECT_EQ(2, resp.answer_count()); + + DnsResourceRecord record; + DnsRecordParser parser = resp.Parser(); + EXPECT_TRUE(parser.ParseRecord(&record)); + EXPECT_TRUE(parser.ParseRecord(&record)); + EXPECT_FALSE(parser.ParseRecord(&record)); } +} // namespace + } // namespace net diff --git a/net/dns/dns_session.cc b/net/dns/dns_session.cc new file mode 100644 index 0000000..15cbaa2 --- /dev/null +++ b/net/dns/dns_session.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_session.h" + +#include "base/basictypes.h" +#include "base/bind.h" +#include "base/time.h" +#include "net/base/ip_endpoint.h" +#include "net/dns/dns_config_service.h" +#include "net/socket/client_socket_factory.h" + +namespace net { + +DnsSession::DnsSession(const DnsConfig& config, + ClientSocketFactory* factory, + const RandIntCallback& rand_int_callback, + NetLog* net_log) + : config_(config), + socket_factory_(factory), + rand_callback_(base::Bind(rand_int_callback, 0, kuint16max)), + net_log_(net_log), + server_index_(0) { +} + +int DnsSession::NextId() const { + return rand_callback_.Run(); +} + +const IPEndPoint& DnsSession::NextServer() { + // TODO(szym): Rotate servers on failures. + const IPEndPoint& ipe = config_.nameservers[server_index_]; + if (config_.rotate) + server_index_ = (server_index_ + 1) % config_.nameservers.size(); + return ipe; +} + +base::TimeDelta DnsSession::NextTimeout(int attempt) { + // TODO(szym): Adapt timeout to observed RTT. + return config_.timeout * (attempt + 1); +} + +DnsSession::~DnsSession() {} + +} // namespace net + diff --git a/net/dns/dns_session.h b/net/dns/dns_session.h new file mode 100644 index 0000000..df986ea --- /dev/null +++ b/net/dns/dns_session.h @@ -0,0 +1,70 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_SESSION_H_ +#define NET_DNS_DNS_SESSION_H_ +#pragma once + +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/time.h" +#include "net/base/net_export.h" +#include "net/base/rand_callback.h" +#include "net/dns/dns_config_service.h" + +namespace net { + +class ClientSocketFactory; +class NetLog; + +// Session parameters and state shared between DNS transactions. +// Ref-counted so that DnsClient::Request can keep working in absence of +// DnsClient. A DnsSession must be recreated when DnsConfig changes. +class NET_EXPORT_PRIVATE DnsSession + : NON_EXPORTED_BASE(public base::RefCounted<DnsSession>) { + public: + typedef base::Callback<int()> RandCallback; + + DnsSession(const DnsConfig& config, + ClientSocketFactory* factory, + const RandIntCallback& rand_int_callback, + NetLog* net_log); + + ClientSocketFactory* socket_factory() const { return socket_factory_.get(); } + const DnsConfig& config() const { return config_; } + NetLog* net_log() const { return net_log_; } + + // Return the next random query ID. + int NextId() const; + + // Return the next server address. + const IPEndPoint& NextServer(); + + // Return the timeout for the next transaction. + base::TimeDelta NextTimeout(int attempt); + + private: + friend class base::RefCounted<DnsSession>; + ~DnsSession(); + + const DnsConfig config_; + scoped_ptr<ClientSocketFactory> socket_factory_; + RandCallback rand_callback_; + NetLog* net_log_; + + // Current index into |config_.nameservers|. + int server_index_; + + // TODO(szym): add current RTT estimate + // TODO(szym): add flag to indicate DNSSEC is supported + // TODO(szym): add TCP connection pool to support DNS over TCP + // TODO(szym): add UDP socket pool ? + + DISALLOW_COPY_AND_ASSIGN(DnsSession); +}; + +} // namespace net + +#endif // NET_DNS_DNS_SESSION_H_ + diff --git a/net/dns/dns_test_util.cc b/net/dns/dns_test_util.cc index ac63131..a396252 100644 --- a/net/dns/dns_test_util.cc +++ b/net/dns/dns_test_util.cc @@ -8,20 +8,6 @@ namespace net { -TestPrng::TestPrng(const std::deque<int>& numbers) : numbers_(numbers) { -} - -TestPrng::~TestPrng() { -} - -int TestPrng::GetNext(int min, int max) { - DCHECK(!numbers_.empty()); - int rv = numbers_.front(); - numbers_.pop_front(); - DCHECK(rv >= min && rv <= max); - return rv; -} - bool ConvertStringsToIPAddressList( const char* const ip_strings[], size_t size, IPAddressList* address_list) { DCHECK(address_list); diff --git a/net/dns/dns_test_util.h b/net/dns/dns_test_util.h index 651cfcb..c59a58f 100644 --- a/net/dns/dns_test_util.h +++ b/net/dns/dns_test_util.h @@ -14,26 +14,10 @@ #include "net/base/host_resolver.h" #include "net/base/ip_endpoint.h" #include "net/base/net_util.h" +#include "net/dns/dns_protocol.h" namespace net { -// DNS related classes make use of PRNG for various tasks. This class is -// used as a PRNG for unit testing those tasks. It takes a deque of -// integers |numbers| which should be returned by calls to GetNext. -class TestPrng { - public: - explicit TestPrng(const std::deque<int>& numbers); - ~TestPrng(); - - // Pops and returns the next number from |numbers_| deque. - int GetNext(int min, int max); - - private: - std::deque<int> numbers_; - - DISALLOW_COPY_AND_ASSIGN(TestPrng); -}; - // A utility function for tests that given an array of IP literals, // converts it to an IPAddressList. bool ConvertStringsToIPAddressList( @@ -49,7 +33,7 @@ static const uint16 kDnsPort = 53; //----------------------------------------------------------------------------- // Query/response set for www.google.com, ID is fixed to 0. static const char kT0HostName[] = "www.google.com"; -static const uint16 kT0Qtype = kDNS_A; +static const uint16 kT0Qtype = dns_protocol::kTypeA; static const char kT0DnsName[] = { 0x03, 'w', 'w', 'w', 0x06, 'g', 'o', 'o', 'g', 'l', 'e', @@ -92,10 +76,10 @@ static const char* const kT0IpAddresses[] = { //----------------------------------------------------------------------------- // Query/response set for codereview.chromium.org, ID is fixed to 1. static const char kT1HostName[] = "codereview.chromium.org"; -static const uint16 kT1Qtype = kDNS_A; +static const uint16 kT1Qtype = dns_protocol::kTypeA; static const char kT1DnsName[] = { - 0x12, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w', - 0x10, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm', + 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w', + 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm', 0x03, 'o', 'r', 'g', 0x00 }; @@ -130,10 +114,10 @@ static const char* const kT1IpAddresses[] = { //----------------------------------------------------------------------------- // Query/response set for www.ccs.neu.edu, ID is fixed to 2. static const char kT2HostName[] = "www.ccs.neu.edu"; -static const uint16 kT2Qtype = kDNS_A; +static const uint16 kT2Qtype = dns_protocol::kTypeA; static const char kT2DnsName[] = { 0x03, 'w', 'w', 'w', - 0x03, 'c', 'c', 'c', + 0x03, 'c', 'c', 's', 0x03, 'n', 'e', 'u', 0x03, 'e', 'd', 'u', 0x00 @@ -166,7 +150,7 @@ static const char* const kT2IpAddresses[] = { //----------------------------------------------------------------------------- // Query/response set for www.google.az, ID is fixed to 3. static const char kT3HostName[] = "www.google.az"; -static const uint16 kT3Qtype = kDNS_A; +static const uint16 kT3Qtype = dns_protocol::kTypeA; static const char kT3DnsName[] = { 0x03, 'w', 'w', 'w', 0x06, 'g', 'o', 'o', 'g', 'l', 'e', diff --git a/net/dns/dns_transaction.cc b/net/dns/dns_transaction.cc index a7fa922..5cbe39a 100644 --- a/net/dns/dns_transaction.cc +++ b/net/dns/dns_transaction.cc @@ -7,11 +7,12 @@ #include "base/bind.h" #include "base/rand_util.h" #include "base/values.h" -#include "net/base/dns_util.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" +#include "net/dns/dns_protocol.h" #include "net/dns/dns_query.h" #include "net/dns/dns_response.h" +#include "net/dns/dns_session.h" #include "net/socket/client_socket_factory.h" #include "net/udp/datagram_client_socket.h" @@ -19,104 +20,51 @@ namespace net { namespace { -// Retry timeouts. -const int kTimeoutsMs[] = {3000, 5000, 11000}; -const int kMaxAttempts = arraysize(kTimeoutsMs); - -// Returns the string representation of an IPAddressNumber. -std::string IPAddressToString(const IPAddressNumber& ip_address) { - IPEndPoint ip_endpoint(ip_address, 0); - struct sockaddr_storage addr; - size_t addr_len = sizeof(addr); - struct sockaddr* sockaddr = reinterpret_cast<struct sockaddr*>(&addr); - if (!ip_endpoint.ToSockAddr(sockaddr, &addr_len)) - return ""; - return NetAddressToString(sockaddr, addr_len); -} - -} - -DnsTransaction::Delegate::Delegate() { -} - -DnsTransaction::Delegate::~Delegate() { - while (!registered_transactions_.empty()) { - DnsTransaction* transaction = *registered_transactions_.begin(); - transaction->SetDelegate(NULL); - } - DCHECK(registered_transactions_.empty()); -} - -void DnsTransaction::Delegate::OnTransactionComplete( - int result, - const DnsTransaction* transaction, - const IPAddressList& ip_addresses) { -} - -void DnsTransaction::Delegate::Attach(DnsTransaction* transaction) { - DCHECK(registered_transactions_.find(transaction) == - registered_transactions_.end()); - registered_transactions_.insert(transaction); -} - -void DnsTransaction::Delegate::Detach(DnsTransaction* transaction) { - DCHECK(registered_transactions_.find(transaction) != - registered_transactions_.end()); - registered_transactions_.erase(transaction); -} - -namespace { - class DnsTransactionStartParameters : public NetLog::EventParameters { public: DnsTransactionStartParameters(const IPEndPoint& dns_server, - const DnsTransaction::Key& key, + const base::StringPiece& qname, + uint16 qtype, const NetLog::Source& source) - : dns_server_(dns_server), key_(key), source_(source) {} + : dns_server_(dns_server), + qname_(qname.data(), qname.length()), + qtype_(qtype), + source_(source) {} virtual Value* ToValue() const { - std::string hostname; - DnsResponseBuffer( - reinterpret_cast<const uint8*>(key_.first.c_str()), key_.first.size()). - DNSName(&hostname); - DictionaryValue* dict = new DictionaryValue(); dict->SetString("dns_server", dns_server_.ToString()); - dict->SetString("hostname", hostname); - dict->SetInteger("query_type", key_.second); + dict->SetString("hostname", qname_); + dict->SetInteger("query_type", qtype_); if (source_.is_valid()) dict->Set("source_dependency", source_.ToValue()); return dict; } private: - const IPEndPoint dns_server_; - const DnsTransaction::Key key_; + IPEndPoint dns_server_; + std::string qname_; + uint16 qtype_; const NetLog::Source source_; }; class DnsTransactionFinishParameters : public NetLog::EventParameters { public: - DnsTransactionFinishParameters(int net_error, - const IPAddressList& ip_address_list) - : net_error_(net_error), ip_address_list_(ip_address_list) {} + // TODO(szym): add rcode ? + DnsTransactionFinishParameters(int net_error, int answer_count) + : net_error_(net_error), answer_count_(answer_count) {} virtual Value* ToValue() const { - ListValue* list = new ListValue(); - for (IPAddressList::const_iterator it = ip_address_list_.begin(); - it != ip_address_list_.end(); ++it) - list->Append(Value::CreateStringValue(IPAddressToString(*it))); - DictionaryValue* dict = new DictionaryValue(); if (net_error_) dict->SetInteger("net_error", net_error_); - dict->Set("address_list", list); + dict->SetInteger("answer_count", answer_count_); return dict; } private: const int net_error_; - const IPAddressList ip_address_list_; + const int answer_count_; }; class DnsTransactionRetryParameters : public NetLog::EventParameters { @@ -139,47 +87,32 @@ class DnsTransactionRetryParameters : public NetLog::EventParameters { } // namespace -DnsTransaction::DnsTransaction(const IPEndPoint& dns_server, - const std::string& dns_name, - uint16 query_type, - const RandIntCallback& rand_int, - ClientSocketFactory* socket_factory, - const BoundNetLog& source_net_log, - NetLog* net_log) - : dns_server_(dns_server), - key_(dns_name, query_type), - delegate_(NULL), - query_(new DnsQuery(dns_name, query_type, rand_int)), + +DnsTransaction::DnsTransaction(DnsSession* session, + const base::StringPiece& qname, + uint16 qtype, + const ResultCallback& callback, + const BoundNetLog& source_net_log) + : session_(session), + dns_server_(session->NextServer()), + query_(new DnsQuery(session->NextId(), qname, qtype)), + callback_(callback), attempts_(0), next_state_(STATE_NONE), - socket_factory_(socket_factory ? socket_factory : - ClientSocketFactory::GetDefaultFactory()), ALLOW_THIS_IN_INITIALIZER_LIST( io_callback_(this, &DnsTransaction::OnIOComplete)), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_DNS_TRANSACTION)) { - DCHECK(!rand_int.is_null()); - for (size_t i = 0; i < arraysize(kTimeoutsMs); ++i) - timeouts_ms_.push_back(base::TimeDelta::FromMilliseconds(kTimeoutsMs[i])); + net_log_(BoundNetLog::Make(session->net_log(), + NetLog::SOURCE_DNS_TRANSACTION)) { net_log_.BeginEvent( NetLog::TYPE_DNS_TRANSACTION, make_scoped_refptr( - new DnsTransactionStartParameters(dns_server_, key_, + new DnsTransactionStartParameters(dns_server_, + qname, + qtype, source_net_log.source()))); } -DnsTransaction::~DnsTransaction() { - SetDelegate(NULL); -} - -void DnsTransaction::SetDelegate(Delegate* delegate) { - if (delegate == delegate_) - return; - if (delegate_) - delegate_->Detach(this); - delegate_ = delegate; - if (delegate_) - delegate_->Attach(this); -} +DnsTransaction::~DnsTransaction() {} int DnsTransaction::Start() { DCHECK_EQ(STATE_NONE, next_state_); @@ -223,12 +156,12 @@ int DnsTransaction::DoLoop(int result) { void DnsTransaction::DoCallback(int result) { DCHECK_NE(result, ERR_IO_PENDING); + int answer_count = (result == OK) ? response()->answer_count() : 0; net_log_.EndEvent( NetLog::TYPE_DNS_TRANSACTION, make_scoped_refptr( - new DnsTransactionFinishParameters(result, ip_addresses_))); - if (delegate_) - delegate_->OnTransactionComplete(result, this, ip_addresses_); + new DnsTransactionFinishParameters(result, answer_count))); + callback_.Run(this, result); } void DnsTransaction::OnIOComplete(int result) { @@ -240,13 +173,14 @@ void DnsTransaction::OnIOComplete(int result) { int DnsTransaction::DoConnect() { next_state_ = STATE_CONNECT_COMPLETE; - DCHECK_LT(attempts_, timeouts_ms_.size()); - StartTimer(timeouts_ms_[attempts_]); - attempts_++; + StartTimer(session_->NextTimeout(attempts_)); + ++attempts_; - // TODO(agayev): keep all sockets around in case the server responds + // TODO(szym): keep all sockets around in case the server responds // after its timeout; state machine will need to change to handle that. - socket_.reset(socket_factory_->CreateDatagramClientSocket( + // The current plan is to move socket management out to DnsSession. + // Hence also move retransmissions to DnsClient::Request. + socket_.reset(session_->socket_factory()->CreateDatagramClientSocket( DatagramSocket::RANDOM_BIND, base::Bind(&base::RandInt), net_log_.net_log(), @@ -281,7 +215,7 @@ int DnsTransaction::DoSendQueryComplete(int rv) { // Writing to UDP should not result in a partial datagram. if (rv != query_->io_buffer()->size()) - return ERR_NAME_NOT_RESOLVED; + return ERR_MSG_TOO_BIG; next_state_ = STATE_READ_RESPONSE; return OK; @@ -289,7 +223,7 @@ int DnsTransaction::DoSendQueryComplete(int rv) { int DnsTransaction::DoReadResponse() { next_state_ = STATE_READ_RESPONSE_COMPLETE; - response_.reset(new DnsResponse(query_.get())); + response_.reset(new DnsResponse()); return socket_->Read(response_->io_buffer(), response_->io_buffer()->size(), &io_callback_); @@ -302,9 +236,21 @@ int DnsTransaction::DoReadResponseComplete(int rv) { return rv; DCHECK(rv); - // TODO(agayev): when supporting EDNS0 we may need to do multiple reads - // to read the whole response. - return response_->Parse(rv, &ip_addresses_); + if (!response_->InitParse(rv, *query_)) + return ERR_DNS_MALFORMED_RESPONSE; + // TODO(szym): define this flag value in dns_protocol + if (response_->flags1() & 2) + return ERR_DNS_SERVER_REQUIRES_TCP; + // TODO(szym): move this handling out of DnsTransaction? + if (response_->rcode() != dns_protocol::kRcodeNOERROR && + response_->rcode() != dns_protocol::kRcodeNXDOMAIN) { + return ERR_DNS_SERVER_FAILED; + } + // TODO(szym): add ERR_DNS_RR_NOT_FOUND? + if (response_->answer_count() == 0) + return ERR_NAME_NOT_RESOLVED; + + return OK; } void DnsTransaction::StartTimer(base::TimeDelta delay) { @@ -318,21 +264,15 @@ void DnsTransaction::RevokeTimer() { void DnsTransaction::OnTimeout() { DCHECK(next_state_ == STATE_SEND_QUERY_COMPLETE || next_state_ == STATE_READ_RESPONSE_COMPLETE); - if (attempts_ == timeouts_ms_.size()) { + if (attempts_ == session_->config().attempts) { DoCallback(ERR_DNS_TIMED_OUT); return; } next_state_ = STATE_CONNECT; - query_.reset(query_->CloneWithNewId()); + query_.reset(query_->CloneWithNewId(session_->NextId())); int rv = DoLoop(OK); if (rv != ERR_IO_PENDING) DoCallback(rv); } -void DnsTransaction::set_timeouts_ms( - const std::vector<base::TimeDelta>& timeouts_ms) { - DCHECK_EQ(0u, attempts_); - timeouts_ms_ = timeouts_ms; -} - } // namespace net diff --git a/net/dns/dns_transaction.h b/net/dns/dns_transaction.h index c126193..d4078f0 100644 --- a/net/dns/dns_transaction.h +++ b/net/dns/dns_transaction.h @@ -6,12 +6,10 @@ #define NET_DNS_DNS_TRANSACTION_H_ #pragma once -#include <set> #include <string> -#include <utility> #include <vector> -#include "base/gtest_prod_util.h" +#include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/timer.h" #include "base/threading/non_thread_safe.h" @@ -23,70 +21,39 @@ namespace net { -class ClientSocketFactory; class DatagramClientSocket; class DnsQuery; class DnsResponse; +class DnsSession; -// Performs (with fixed retries) a single asynchronous DNS transaction, +// Performs a single asynchronous DNS transaction over UDP, // which consists of sending out a DNS query, waiting for a response, and -// parsing and returning the IP addresses that it matches. +// returning the response that it matches. class NET_EXPORT_PRIVATE DnsTransaction : NON_EXPORTED_BASE(public base::NonThreadSafe) { public: - typedef std::pair<std::string, uint16> Key; - - // Interface that should be implemented by DnsTransaction consumers and - // passed to the |Start| method to be notified when the transaction has - // completed. - class NET_EXPORT_PRIVATE Delegate { - public: - Delegate(); - virtual ~Delegate(); - - // A consumer of DnsTransaction should override |OnTransactionComplete| - // and call |set_delegate(this)|. The method will be called once the - // resolution has completed, results passed in as arguments. - virtual void OnTransactionComplete( - int result, - const DnsTransaction* transaction, - const IPAddressList& ip_addresses); - - private: - friend class DnsTransaction; - - void Attach(DnsTransaction* transaction); - void Detach(DnsTransaction* transaction); - - std::set<DnsTransaction*> registered_transactions_; + typedef base::Callback<void(DnsTransaction*, int)> ResultCallback; + + // Create new transaction using the parameters and state in |session|. + // Issues query for name |qname| (in DNS format) type |qtype| and class IN. + // Calls |callback| on completion or timeout. + // TODO(szym): change dependency to (IPEndPoint, Socket, DnsQuery, callback) + DnsTransaction(DnsSession* session, + const base::StringPiece& qname, + uint16 qtype, + const ResultCallback& callback, + const BoundNetLog& source_net_log); + ~DnsTransaction(); - DISALLOW_COPY_AND_ASSIGN(Delegate); - }; + const DnsQuery* query() const { return query_.get(); } - // |dns_server| is the address of the DNS server, |dns_name| is the - // hostname (in DNS format) to be resolved, |query_type| is the type of - // the query, either kDNS_A or kDNS_AAAA, |rand_int| is the PRNG used for - // generating DNS query. - DnsTransaction(const IPEndPoint& dns_server, - const std::string& dns_name, - uint16 query_type, - const RandIntCallback& rand_int, - ClientSocketFactory* socket_factory, - const BoundNetLog& source_net_log, - NetLog* net_log); - ~DnsTransaction(); - void SetDelegate(Delegate* delegate); - const Key& key() const { return key_; } + const DnsResponse* response() const { return response_.get(); } // Starts the resolution process. Will return ERR_IO_PENDING and will // notify the caller via |delegate|. Should only be called once. int Start(); private: - FRIEND_TEST_ALL_PREFIXES(DnsTransactionTest, FirstTimeoutTest); - FRIEND_TEST_ALL_PREFIXES(DnsTransactionTest, SecondTimeoutTest); - FRIEND_TEST_ALL_PREFIXES(DnsTransactionTest, ThirdTimeoutTest); - enum State { STATE_CONNECT, STATE_CONNECT_COMPLETE, @@ -114,26 +81,17 @@ class NET_EXPORT_PRIVATE DnsTransaction : void RevokeTimer(); void OnTimeout(); - // This is to be used by unit tests only. - void set_timeouts_ms(const std::vector<base::TimeDelta>& timeouts_ms); - - const IPEndPoint dns_server_; - Key key_; - IPAddressList ip_addresses_; - Delegate* delegate_; - + scoped_refptr<DnsSession> session_; + IPEndPoint dns_server_; scoped_ptr<DnsQuery> query_; + ResultCallback callback_; scoped_ptr<DnsResponse> response_; scoped_ptr<DatagramClientSocket> socket_; // Number of retry attempts so far. - size_t attempts_; - - // Timeouts in milliseconds. - std::vector<base::TimeDelta> timeouts_ms_; + int attempts_; State next_state_; - ClientSocketFactory* socket_factory_; base::OneShotTimer<DnsTransaction> timer_; OldCompletionCallbackImpl<DnsTransaction> io_callback_; diff --git a/net/dns/dns_transaction_unittest.cc b/net/dns/dns_transaction_unittest.cc index a99f55b3..3e3c2fe 100644 --- a/net/dns/dns_transaction_unittest.cc +++ b/net/dns/dns_transaction_unittest.cc @@ -8,6 +8,12 @@ #include <vector> #include "base/bind.h" +#include "base/test/test_timeouts.h" +#include "base/time.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/dns_query.h" +#include "net/dns/dns_response.h" +#include "net/dns/dns_session.h" #include "net/dns/dns_test_util.h" #include "net/socket/socket_test_util.h" #include "testing/gtest/include/gtest/gtest.h" @@ -16,43 +22,73 @@ namespace net { namespace { -static const base::TimeDelta kTimeoutsMs[] = { - base::TimeDelta::FromMilliseconds(20), - base::TimeDelta::FromMilliseconds(20), - base::TimeDelta::FromMilliseconds(20), -}; +// A mock for RandIntCallback that always returns 0. +int ReturnZero(int min, int max) { + return 0; +} -} // namespace +class DnsTransactionTest : public testing::Test { + protected: + virtual void SetUp() OVERRIDE { + callback_ = base::Bind(&DnsTransactionTest::OnTransactionComplete, + base::Unretained(this)); + qname_ = std::string(kT0DnsName, arraysize(kT0DnsName)); + // Use long timeout to prevent timing out on slow bots. + ConfigureSession(base::TimeDelta::FromMilliseconds( + TestTimeouts::action_timeout_ms())); + } + + void ConfigureSession(const base::TimeDelta& timeout) { + IPEndPoint dns_server; + bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server); + ASSERT_TRUE(rv); + + DnsConfig config; + config.nameservers.push_back(dns_server); + config.attempts = 3; + config.timeout = timeout; + + session_ = new DnsSession(config, + new MockClientSocketFactory(), + base::Bind(&ReturnZero), + NULL /* NetLog */); + } + + void StartTransaction() { + transaction_.reset(new DnsTransaction(session_.get(), + qname_, + kT0Qtype, + callback_, + BoundNetLog())); + + int rv0 = transaction_->Start(); + EXPECT_EQ(ERR_IO_PENDING, rv0); + } -class TestDelegate : public DnsTransaction::Delegate { - public: - TestDelegate() : result_(ERR_UNEXPECTED), transaction_(NULL) {} - virtual ~TestDelegate() {} - virtual void OnTransactionComplete( - int result, - const DnsTransaction* transaction, - const IPAddressList& ip_addresses) { - result_ = result; - transaction_ = transaction; - ip_addresses_ = ip_addresses; + void OnTransactionComplete(DnsTransaction* transaction, int rv) { + EXPECT_EQ(transaction_.get(), transaction); + EXPECT_EQ(qname_, transaction->query()->qname().as_string()); + EXPECT_EQ(kT0Qtype, transaction->query()->qtype()); + rv_ = rv; MessageLoop::current()->Quit(); } - int result() const { return result_; } - const DnsTransaction* transaction() const { return transaction_; } - const IPAddressList& ip_addresses() const { - return ip_addresses_; + + MockClientSocketFactory& factory() { + return *static_cast<MockClientSocketFactory*>(session_->socket_factory()); } + int rv() const { return rv_; } + private: - int result_; - const DnsTransaction* transaction_; - IPAddressList ip_addresses_; + DnsTransaction::ResultCallback callback_; + std::string qname_; + scoped_refptr<DnsSession> session_; + scoped_ptr<DnsTransaction> transaction_; - DISALLOW_COPY_AND_ASSIGN(TestDelegate); + int rv_; }; - -TEST(DnsTransactionTest, NormalQueryResponseTest) { +TEST_F(DnsTransactionTest, NormalQueryResponseTest) { MockWrite writes0[] = { MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram), arraysize(kT0QueryDatagram)) @@ -65,45 +101,19 @@ TEST(DnsTransactionTest, NormalQueryResponseTest) { StaticSocketDataProvider data(reads0, arraysize(reads0), writes0, arraysize(writes0)); - MockClientSocketFactory factory; - factory.AddSocketDataProvider(&data); - - TestPrng test_prng(std::deque<int>(1, 0)); - RandIntCallback rand_int_cb = - base::Bind(&TestPrng::GetNext, base::Unretained(&test_prng)); - std::string t0_dns_name(kT0DnsName, arraysize(kT0DnsName)); - - IPEndPoint dns_server; - bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server); - ASSERT_TRUE(rv); - - DnsTransaction t(dns_server, t0_dns_name, kT1Qtype, rand_int_cb, &factory, - BoundNetLog(), NULL); - - TestDelegate delegate; - t.SetDelegate(&delegate); - - IPAddressList expected_ip_addresses; - rv = ConvertStringsToIPAddressList(kT0IpAddresses, - arraysize(kT0IpAddresses), - &expected_ip_addresses); - ASSERT_TRUE(rv); - - int rv0 = t.Start(); - EXPECT_EQ(ERR_IO_PENDING, rv0); + factory().AddSocketDataProvider(&data); + StartTransaction(); MessageLoop::current()->Run(); - EXPECT_TRUE(DnsTransaction::Key(t0_dns_name, kT0Qtype) == t.key()); - EXPECT_EQ(OK, delegate.result()); - EXPECT_EQ(&t, delegate.transaction()); - EXPECT_TRUE(expected_ip_addresses == delegate.ip_addresses()); + EXPECT_EQ(OK, rv()); + // TODO(szym): test fields of |transaction_->response()| EXPECT_TRUE(data.at_read_eof()); EXPECT_TRUE(data.at_write_eof()); } -TEST(DnsTransactionTest, MismatchedQueryResponseTest) { +TEST_F(DnsTransactionTest, MismatchedQueryResponseTest) { MockWrite writes0[] = { MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram), arraysize(kT0QueryDatagram)) @@ -116,40 +126,20 @@ TEST(DnsTransactionTest, MismatchedQueryResponseTest) { StaticSocketDataProvider data(reads1, arraysize(reads1), writes0, arraysize(writes0)); - MockClientSocketFactory factory; - factory.AddSocketDataProvider(&data); - - TestPrng test_prng(std::deque<int>(1, 0)); - RandIntCallback rand_int_cb = - base::Bind(&TestPrng::GetNext, base::Unretained(&test_prng)); - std::string t0_dns_name(kT0DnsName, arraysize(kT0DnsName)); - - IPEndPoint dns_server; - bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server); - ASSERT_TRUE(rv); - - DnsTransaction t(dns_server, t0_dns_name, kT1Qtype, rand_int_cb, &factory, - BoundNetLog(), NULL); - - TestDelegate delegate; - t.SetDelegate(&delegate); - - int rv0 = t.Start(); - EXPECT_EQ(ERR_IO_PENDING, rv0); + factory().AddSocketDataProvider(&data); + StartTransaction(); MessageLoop::current()->Run(); - EXPECT_TRUE(DnsTransaction::Key(t0_dns_name, kT0Qtype) == t.key()); - EXPECT_EQ(ERR_DNS_MALFORMED_RESPONSE, delegate.result()); - EXPECT_EQ(0u, delegate.ip_addresses().size()); - EXPECT_EQ(&t, delegate.transaction()); + EXPECT_EQ(ERR_DNS_MALFORMED_RESPONSE, rv()); + EXPECT_TRUE(data.at_read_eof()); EXPECT_TRUE(data.at_write_eof()); } // Test that after the first timeout we do a fresh connection and if we get // a response on the new connection, we return it. -TEST(DnsTransactionTest, FirstTimeoutTest) { +TEST_F(DnsTransactionTest, FirstTimeoutTest) { MockWrite writes0[] = { MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram), arraysize(kT0QueryDatagram)) @@ -165,57 +155,30 @@ TEST(DnsTransactionTest, FirstTimeoutTest) { scoped_refptr<DelayedSocketData> socket1_data( new DelayedSocketData(0, reads0, arraysize(reads0), writes0, arraysize(writes0))); - MockClientSocketFactory factory; - factory.AddSocketDataProvider(socket0_data.get()); - factory.AddSocketDataProvider(socket1_data.get()); - - TestPrng test_prng(std::deque<int>(2, 0)); - RandIntCallback rand_int_cb = - base::Bind(&TestPrng::GetNext, base::Unretained(&test_prng)); - std::string t0_dns_name(kT0DnsName, arraysize(kT0DnsName)); - - IPEndPoint dns_server; - bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server); - ASSERT_TRUE(rv); - DnsTransaction t(dns_server, t0_dns_name, kT1Qtype, rand_int_cb, &factory, - BoundNetLog(), NULL); - - TestDelegate delegate; - t.SetDelegate(&delegate); - - t.set_timeouts_ms( - std::vector<base::TimeDelta>(kTimeoutsMs, - kTimeoutsMs + arraysize(kTimeoutsMs))); - - IPAddressList expected_ip_addresses; - rv = ConvertStringsToIPAddressList(kT0IpAddresses, - arraysize(kT0IpAddresses), - &expected_ip_addresses); - ASSERT_TRUE(rv); - - int rv0 = t.Start(); - EXPECT_EQ(ERR_IO_PENDING, rv0); + // Use short timeout to speed up the test. + ConfigureSession(base::TimeDelta::FromMilliseconds( + TestTimeouts::tiny_timeout_ms())); + factory().AddSocketDataProvider(socket0_data.get()); + factory().AddSocketDataProvider(socket1_data.get()); + StartTransaction(); MessageLoop::current()->Run(); - EXPECT_TRUE(DnsTransaction::Key(t0_dns_name, kT0Qtype) == t.key()); - EXPECT_EQ(OK, delegate.result()); - EXPECT_EQ(&t, delegate.transaction()); - EXPECT_TRUE(expected_ip_addresses == delegate.ip_addresses()); + EXPECT_EQ(OK, rv()); EXPECT_TRUE(socket0_data->at_read_eof()); EXPECT_TRUE(socket0_data->at_write_eof()); EXPECT_TRUE(socket1_data->at_read_eof()); EXPECT_TRUE(socket1_data->at_write_eof()); - EXPECT_EQ(2u, factory.udp_client_sockets().size()); + EXPECT_EQ(2u, factory().udp_client_sockets().size()); } // Test that after the first timeout we do a fresh connection, and after // the second timeout we do another fresh connection, and if we get a // response on the second connection, we return it. -TEST(DnsTransactionTest, SecondTimeoutTest) { +TEST_F(DnsTransactionTest, SecondTimeoutTest) { MockWrite writes0[] = { MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram), arraysize(kT0QueryDatagram)) @@ -233,45 +196,19 @@ TEST(DnsTransactionTest, SecondTimeoutTest) { scoped_refptr<DelayedSocketData> socket2_data( new DelayedSocketData(0, reads0, arraysize(reads0), writes0, arraysize(writes0))); - MockClientSocketFactory factory; - factory.AddSocketDataProvider(socket0_data.get()); - factory.AddSocketDataProvider(socket1_data.get()); - factory.AddSocketDataProvider(socket2_data.get()); - - TestPrng test_prng(std::deque<int>(3, 0)); - RandIntCallback rand_int_cb = - base::Bind(&TestPrng::GetNext, base::Unretained(&test_prng)); - std::string t0_dns_name(kT0DnsName, arraysize(kT0DnsName)); - IPEndPoint dns_server; - bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server); - ASSERT_TRUE(rv); + // Use short timeout to speed up the test. + ConfigureSession(base::TimeDelta::FromMilliseconds( + TestTimeouts::tiny_timeout_ms())); + factory().AddSocketDataProvider(socket0_data.get()); + factory().AddSocketDataProvider(socket1_data.get()); + factory().AddSocketDataProvider(socket2_data.get()); - DnsTransaction t(dns_server, t0_dns_name, kT1Qtype, rand_int_cb, &factory, - BoundNetLog(), NULL); - - TestDelegate delegate; - t.SetDelegate(&delegate); - - t.set_timeouts_ms( - std::vector<base::TimeDelta>(kTimeoutsMs, - kTimeoutsMs + arraysize(kTimeoutsMs))); - - IPAddressList expected_ip_addresses; - rv = ConvertStringsToIPAddressList(kT0IpAddresses, - arraysize(kT0IpAddresses), - &expected_ip_addresses); - ASSERT_TRUE(rv); - - int rv0 = t.Start(); - EXPECT_EQ(ERR_IO_PENDING, rv0); + StartTransaction(); MessageLoop::current()->Run(); - EXPECT_TRUE(DnsTransaction::Key(t0_dns_name, kT1Qtype) == t.key()); - EXPECT_EQ(OK, delegate.result()); - EXPECT_EQ(&t, delegate.transaction()); - EXPECT_TRUE(expected_ip_addresses == delegate.ip_addresses()); + EXPECT_EQ(OK, rv()); EXPECT_TRUE(socket0_data->at_read_eof()); EXPECT_TRUE(socket0_data->at_write_eof()); @@ -279,13 +216,13 @@ TEST(DnsTransactionTest, SecondTimeoutTest) { EXPECT_TRUE(socket1_data->at_write_eof()); EXPECT_TRUE(socket2_data->at_read_eof()); EXPECT_TRUE(socket2_data->at_write_eof()); - EXPECT_EQ(3u, factory.udp_client_sockets().size()); + EXPECT_EQ(3u, factory().udp_client_sockets().size()); } // Test that after the first timeout we do a fresh connection, and after // the second timeout we do another fresh connection and after the third // timeout we give up and return a timeout error. -TEST(DnsTransactionTest, ThirdTimeoutTest) { +TEST_F(DnsTransactionTest, ThirdTimeoutTest) { MockWrite writes0[] = { MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram), arraysize(kT0QueryDatagram)) @@ -297,38 +234,19 @@ TEST(DnsTransactionTest, ThirdTimeoutTest) { new DelayedSocketData(2, NULL, 0, writes0, arraysize(writes0))); scoped_refptr<DelayedSocketData> socket2_data( new DelayedSocketData(2, NULL, 0, writes0, arraysize(writes0))); - MockClientSocketFactory factory; - factory.AddSocketDataProvider(socket0_data.get()); - factory.AddSocketDataProvider(socket1_data.get()); - factory.AddSocketDataProvider(socket2_data.get()); - - TestPrng test_prng(std::deque<int>(3, 0)); - RandIntCallback rand_int_cb = - base::Bind(&TestPrng::GetNext, base::Unretained(&test_prng)); - std::string t0_dns_name(kT0DnsName, arraysize(kT0DnsName)); - IPEndPoint dns_server; - bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server); - ASSERT_TRUE(rv); + // Use short timeout to speed up the test. + ConfigureSession(base::TimeDelta::FromMilliseconds( + TestTimeouts::tiny_timeout_ms())); + factory().AddSocketDataProvider(socket0_data.get()); + factory().AddSocketDataProvider(socket1_data.get()); + factory().AddSocketDataProvider(socket2_data.get()); - DnsTransaction t(dns_server, t0_dns_name, kT1Qtype, rand_int_cb, &factory, - BoundNetLog(), NULL); - - TestDelegate delegate; - t.SetDelegate(&delegate); - - t.set_timeouts_ms( - std::vector<base::TimeDelta>(kTimeoutsMs, - kTimeoutsMs + arraysize(kTimeoutsMs))); - - int rv0 = t.Start(); - EXPECT_EQ(ERR_IO_PENDING, rv0); + StartTransaction(); MessageLoop::current()->Run(); - EXPECT_TRUE(DnsTransaction::Key(t0_dns_name, kT0Qtype) == t.key()); - EXPECT_EQ(ERR_DNS_TIMED_OUT, delegate.result()); - EXPECT_EQ(&t, delegate.transaction()); + EXPECT_EQ(ERR_DNS_TIMED_OUT, rv()); EXPECT_TRUE(socket0_data->at_read_eof()); EXPECT_TRUE(socket0_data->at_write_eof()); @@ -336,7 +254,9 @@ TEST(DnsTransactionTest, ThirdTimeoutTest) { EXPECT_TRUE(socket1_data->at_write_eof()); EXPECT_TRUE(socket2_data->at_read_eof()); EXPECT_TRUE(socket2_data->at_write_eof()); - EXPECT_EQ(3u, factory.udp_client_sockets().size()); + EXPECT_EQ(3u, factory().udp_client_sockets().size()); } +} // namespace + } // namespace net diff --git a/net/net.gyp b/net/net.gyp index 8eb4948..446c628 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -55,6 +55,8 @@ 'base/backoff_entry.h', 'base/bandwidth_metrics.cc', 'base/bandwidth_metrics.h', + 'base/big_endian.cc', + 'base/big_endian.h', 'base/cache_type.h', 'base/capturing_net_log.cc', 'base/capturing_net_log.h', @@ -323,6 +325,8 @@ 'disk_cache/trace.h', 'dns/async_host_resolver.cc', 'dns/async_host_resolver.h', + 'dns/dns_client.cc', + 'dns/dns_client.h', 'dns/dns_config_service.cc', 'dns/dns_config_service.h', 'dns/dns_config_service_posix.cc', @@ -335,6 +339,8 @@ 'dns/dns_query.h', 'dns/dns_response.cc', 'dns/dns_response.h', + 'dns/dns_session.cc', + 'dns/dns_session.h', 'dns/dns_transaction.cc', 'dns/dns_transaction.h', 'dns/serial_worker.cc', @@ -992,6 +998,7 @@ 'sources': [ 'base/address_list_unittest.cc', 'base/backoff_entry_unittest.cc', + 'base/big_endian_unittest.cc', 'base/cert_database_nss_unittest.cc', 'base/cert_verifier_unittest.cc', 'base/cookie_monster_unittest.cc', @@ -1052,6 +1059,7 @@ 'disk_cache/mapped_file_unittest.cc', 'disk_cache/storage_block_unittest.cc', 'dns/async_host_resolver_unittest.cc', + 'dns/dns_client_unittest.cc', 'dns/dns_config_service_posix_unittest.cc', 'dns/dns_config_service_unittest.cc', 'dns/dns_config_service_win_unittest.cc', |