summaryrefslogtreecommitdiffstats
path: root/net/dns
diff options
context:
space:
mode:
authorszym@chromium.org <szym@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2011-12-08 19:29:15 +0000
committerszym@chromium.org <szym@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2011-12-08 19:29:15 +0000
commit7556ea2db8b28d15a37d5680febadc5cf1a5cc5b (patch)
treef9bb844166850fca5a4c64adf61361e6b20a2032 /net/dns
parent334e0cd1ff440365b1f3b28630362ede5548f400 (diff)
downloadchromium_src-7556ea2db8b28d15a37d5680febadc5cf1a5cc5b.zip
chromium_src-7556ea2db8b28d15a37d5680febadc5cf1a5cc5b.tar.gz
chromium_src-7556ea2db8b28d15a37d5680febadc5cf1a5cc5b.tar.bz2
Isolates generic DnsClient from AsyncHostResolver.
DnsClient provides a generic DNS client that allows fetching resource records. DnsClient is very lightweight and does not support aggregation, queuing or prioritization of requests. This is the first CL in a series to merge AsyncHostResolver into HostResolverImpl. Also introduces general-purpose BigEndianReader/Writer. Removes DnsTransactionTest-related suppressions. BUG=90881,80225,106688 TEST=./net_unittests Review URL: http://codereview.chromium.org/8852009 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@113640 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net/dns')
-rw-r--r--net/dns/async_host_resolver.cc137
-rw-r--r--net/dns/async_host_resolver.h50
-rw-r--r--net/dns/async_host_resolver_unittest.cc171
-rw-r--r--net/dns/dns_client.cc91
-rw-r--r--net/dns/dns_client.h93
-rw-r--r--net/dns/dns_client_unittest.cc311
-rw-r--r--net/dns/dns_protocol.h122
-rw-r--r--net/dns/dns_query.cc111
-rw-r--r--net/dns/dns_query.h39
-rw-r--r--net/dns/dns_query_unittest.cc103
-rw-r--r--net/dns/dns_response.cc232
-rw-r--r--net/dns/dns_response.h92
-rw-r--r--net/dns/dns_response_unittest.cc295
-rw-r--r--net/dns/dns_session.cc47
-rw-r--r--net/dns/dns_session.h70
-rw-r--r--net/dns/dns_test_util.cc14
-rw-r--r--net/dns/dns_test_util.h32
-rw-r--r--net/dns/dns_transaction.cc184
-rw-r--r--net/dns/dns_transaction.h86
-rw-r--r--net/dns/dns_transaction_unittest.cc286
20 files changed, 1650 insertions, 916 deletions
diff --git a/net/dns/async_host_resolver.cc b/net/dns/async_host_resolver.cc
index 1666e04..ff93c3e 100644
--- a/net/dns/async_host_resolver.cc
+++ b/net/dns/async_host_resolver.cc
@@ -8,12 +8,16 @@
#include "base/bind.h"
#include "base/logging.h"
+#include "base/message_loop.h"
#include "base/rand_util.h"
#include "base/stl_util.h"
#include "base/values.h"
#include "net/base/address_list.h"
#include "net/base/dns_util.h"
#include "net/base/net_errors.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/dns_session.h"
#include "net/socket/client_socket_factory.h"
namespace net {
@@ -22,7 +26,7 @@ namespace {
// TODO(agayev): fix this when IPv6 support is added.
uint16 QueryTypeFromAddressFamily(AddressFamily address_family) {
- return kDNS_A;
+ return dns_protocol::kTypeA;
}
class RequestParameters : public NetLog::EventParameters {
@@ -56,17 +60,22 @@ class RequestParameters : public NetLog::EventParameters {
HostResolver* CreateAsyncHostResolver(size_t max_concurrent_resolves,
const IPAddressNumber& dns_ip,
NetLog* net_log) {
- size_t max_transactions = max_concurrent_resolves;
- if (max_transactions == 0)
- max_transactions = 20;
- size_t max_pending_requests = max_transactions * 100;
+ size_t max_dns_requests = max_concurrent_resolves;
+ if (max_dns_requests == 0)
+ max_dns_requests = 20;
+ size_t max_pending_requests = max_dns_requests * 100;
+ DnsConfig config;
+ config.nameservers.push_back(IPEndPoint(dns_ip, 53));
+ DnsSession* session = new DnsSession(
+ config,
+ ClientSocketFactory::GetDefaultFactory(),
+ base::Bind(&base::RandInt),
+ net_log);
HostResolver* resolver = new AsyncHostResolver(
- IPEndPoint(dns_ip, 53),
- max_transactions,
+ max_dns_requests,
max_pending_requests,
- base::Bind(&base::RandInt),
HostCache::CreateDefaultCache(),
- NULL,
+ DnsClient::CreateClient(session),
net_log);
return resolver;
}
@@ -193,19 +202,15 @@ class AsyncHostResolver::Request {
};
//-----------------------------------------------------------------------------
-AsyncHostResolver::AsyncHostResolver(const IPEndPoint& dns_server,
- size_t max_transactions,
+AsyncHostResolver::AsyncHostResolver(size_t max_dns_requests,
size_t max_pending_requests,
- const RandIntCallback& rand_int_cb,
HostCache* cache,
- ClientSocketFactory* factory,
+ DnsClient* client,
NetLog* net_log)
- : max_transactions_(max_transactions),
+ : max_dns_requests_(max_dns_requests),
max_pending_requests_(max_pending_requests),
- dns_server_(dns_server),
- rand_int_cb_(rand_int_cb),
cache_(cache),
- factory_(factory),
+ client_(client),
net_log_(net_log) {
}
@@ -215,8 +220,8 @@ AsyncHostResolver::~AsyncHostResolver() {
it != requestlist_map_.end(); ++it)
STLDeleteElements(&it->second);
- // Destroy transactions.
- STLDeleteElements(&transactions_);
+ // Destroy DNS requests.
+ STLDeleteElements(&dns_requests_);
// Destroy pending requests.
for (size_t i = 0; i < arraysize(pending_requests_); ++i)
@@ -240,8 +245,8 @@ int AsyncHostResolver::Resolve(const RequestInfo& info,
rv = request->result();
else if (AttachToRequestList(request.get()))
rv = ERR_IO_PENDING;
- else if (transactions_.size() < max_transactions_)
- rv = StartNewTransactionFor(request.get());
+ else if (dns_requests_.size() < max_dns_requests_)
+ rv = StartNewDnsRequestFor(request.get());
else
rv = Enqueue(request.get());
@@ -327,39 +332,56 @@ HostCache* AsyncHostResolver::GetHostCache() {
return cache_.get();
}
-void AsyncHostResolver::OnTransactionComplete(
+void AsyncHostResolver::OnDnsRequestComplete(
+ DnsClient::Request* dns_req,
int result,
- const DnsTransaction* transaction,
- const IPAddressList& ip_addresses) {
- DCHECK(std::find(transactions_.begin(), transactions_.end(), transaction)
- != transactions_.end());
- DCHECK(requestlist_map_.find(transaction->key()) != requestlist_map_.end());
+ const DnsResponse* response) {
+ DCHECK(std::find(dns_requests_.begin(), dns_requests_.end(), dns_req)
+ != dns_requests_.end());
- // If by the time requests that caused |transaction| are cancelled, we do
+ // If by the time requests that caused |dns_req| are cancelled, we do
// not have a port number to associate with the result, therefore, we
// assume the most common port, otherwise we use the port number of the
// first request.
- RequestList& requests = requestlist_map_[transaction->key()];
+ KeyRequestListMap::iterator rit = requestlist_map_.find(
+ std::make_pair(dns_req->qname(), dns_req->qtype()));
+ DCHECK(rit != requestlist_map_.end());
+ RequestList& requests = rit->second;
int port = requests.empty() ? 80 : requests.front()->info().port();
- // Run callback of every request that was depending on this transaction,
+ // Extract AddressList out of DnsResponse.
+ AddressList addr_list;
+ if (result == OK) {
+ IPAddressList ip_addresses;
+ DnsRecordParser parser = response->Parser();
+ DnsResourceRecord record;
+ // TODO(szym): Add stricter checking of names, aliases and address lengths.
+ while (parser.ParseRecord(&record)) {
+ if (record.type == dns_req->qtype() &&
+ (record.rdata.size() == kIPv4AddressSize ||
+ record.rdata.size() == kIPv6AddressSize)) {
+ ip_addresses.push_back(IPAddressNumber(record.rdata.begin(),
+ record.rdata.end()));
+ }
+ }
+ if (!ip_addresses.empty())
+ addr_list = AddressList::CreateFromIPAddressList(ip_addresses, port);
+ else
+ result = ERR_NAME_NOT_RESOLVED;
+ }
+
+ // Run callback of every request that was depending on this DNS request,
// also notify observers.
- AddressList addrlist;
- if (result == OK)
- addrlist = AddressList::CreateFromIPAddressList(ip_addresses, port);
- for (RequestList::iterator it = requests.begin(); it != requests.end();
- ++it)
- (*it)->OnAsyncComplete(result, addrlist);
-
- // It is possible that the requests that caused |transaction| to be
- // created are cancelled by the time |transaction| completes. In that
+ for (RequestList::iterator it = requests.begin(); it != requests.end(); ++it)
+ (*it)->OnAsyncComplete(result, addr_list);
+
+ // It is possible that the requests that caused |dns_req| to be
+ // created are cancelled by the time |dns_req| completes. In that
// case |requests| would be empty. We are knowingly throwing away the
// result of a DNS resolution in that case, because (a) if there are no
// requests, we do not have info to obtain a key from, (b) DnsTransaction
// does not have info(), adding one into it just temporarily doesn't make
- // sense, since HostCache will be replaced with RR cache soon, (c)
- // recreating info from DnsTransaction::Key adds a lot of temporary
- // code/functions (like converting back from qtype to AddressFamily.)
+ // sense, since HostCache will be replaced with RR cache soon.
// Also, we only cache positive results. All of this will change when RR
// cache is added.
if (result == OK && cache_.get() && !requests.empty()) {
@@ -367,16 +389,16 @@ void AsyncHostResolver::OnTransactionComplete(
HostResolver::RequestInfo info = request->info();
HostCache::Key key(
info.hostname(), info.address_family(), info.host_resolver_flags());
- cache_->Set(key, result, addrlist, base::TimeTicks::Now());
+ cache_->Set(key, result, addr_list, base::TimeTicks::Now());
}
// Cleanup requests.
STLDeleteElements(&requests);
- requestlist_map_.erase(transaction->key());
+ requestlist_map_.erase(rit);
- // Cleanup transaction and start a new one if there are pending requests.
- delete transaction;
- transactions_.remove(transaction);
+ // Cleanup |dns_req| and start a new one if there are pending requests.
+ delete dns_req;
+ dns_requests_.remove(dns_req);
ProcessPending();
}
@@ -399,25 +421,22 @@ bool AsyncHostResolver::AttachToRequestList(Request* request) {
return true;
}
-int AsyncHostResolver::StartNewTransactionFor(Request* request) {
+int AsyncHostResolver::StartNewDnsRequestFor(Request* request) {
DCHECK(requestlist_map_.find(request->key()) == requestlist_map_.end());
- DCHECK(transactions_.size() < max_transactions_);
+ DCHECK(dns_requests_.size() < max_dns_requests_);
request->request_net_log().AddEvent(
NetLog::TYPE_ASYNC_HOST_RESOLVER_CREATE_DNS_TRANSACTION, NULL);
requestlist_map_[request->key()].push_back(request);
- DnsTransaction* transaction = new DnsTransaction(
- dns_server_,
+ DnsClient::Request* dns_req = client_->CreateRequest(
request->key().first,
request->key().second,
- rand_int_cb_,
- factory_,
- request->request_net_log(),
- net_log_);
- transaction->SetDelegate(this);
- transactions_.push_back(transaction);
- return transaction->Start();
+ base::Bind(&AsyncHostResolver::OnDnsRequestComplete,
+ base::Unretained(this)),
+ request->request_net_log());
+ dns_requests_.push_back(dns_req);
+ return dns_req->Start();
}
int AsyncHostResolver::Enqueue(Request* request) {
@@ -490,7 +509,7 @@ void AsyncHostResolver::ProcessPending() {
}
}
}
- StartNewTransactionFor(request);
+ StartNewDnsRequestFor(request);
}
} // namespace net
diff --git a/net/dns/async_host_resolver.h b/net/dns/async_host_resolver.h
index c501d65..e8aeb8b 100644
--- a/net/dns/async_host_resolver.h
+++ b/net/dns/async_host_resolver.h
@@ -8,6 +8,8 @@
#include <list>
#include <map>
+#include <string>
+#include <utility>
#include "base/threading/non_thread_safe.h"
#include "net/base/address_family.h"
@@ -15,24 +17,18 @@
#include "net/base/host_resolver.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_log.h"
-#include "net/base/rand_callback.h"
-#include "net/dns/dns_transaction.h"
+#include "net/dns/dns_client.h"
namespace net {
-class ClientSocketFactory;
-
class NET_EXPORT AsyncHostResolver
: public HostResolver,
- public DnsTransaction::Delegate,
NON_EXPORTED_BASE(public base::NonThreadSafe) {
public:
- AsyncHostResolver(const IPEndPoint& dns_server,
- size_t max_transactions,
- size_t max_pending_requests_,
- const RandIntCallback& rand_int,
+ AsyncHostResolver(size_t max_dns_requests,
+ size_t max_pending_requests,
HostCache* cache,
- ClientSocketFactory* factory,
+ DnsClient* client,
NetLog* net_log);
virtual ~AsyncHostResolver();
@@ -50,11 +46,9 @@ class NET_EXPORT AsyncHostResolver
virtual AddressFamily GetDefaultAddressFamily() const OVERRIDE;
virtual HostCache* GetHostCache() OVERRIDE;
- // DnsTransaction::Delegate interface
- virtual void OnTransactionComplete(
- int result,
- const DnsTransaction* transaction,
- const IPAddressList& ip_addresses) OVERRIDE;
+ void OnDnsRequestComplete(DnsClient::Request* request,
+ int result,
+ const DnsResponse* transaction);
private:
FRIEND_TEST_ALL_PREFIXES(AsyncHostResolverTest, QueuedLookup);
@@ -68,9 +62,9 @@ class NET_EXPORT AsyncHostResolver
class Request;
- typedef DnsTransaction::Key Key;
+ typedef std::pair<std::string, uint16> Key;
typedef std::list<Request*> RequestList;
- typedef std::list<const DnsTransaction*> TransactionList;
+ typedef std::list<const DnsClient::Request*> DnsRequestList;
typedef std::map<Key, RequestList> KeyRequestListMap;
// Create a new request for the incoming Resolve() call.
@@ -92,9 +86,9 @@ class NET_EXPORT AsyncHostResolver
// attach |request| to the respective list.
bool AttachToRequestList(Request* request);
- // Will start a new transaction for |request|, will insert a new key in
+ // Will start a new DNS request for |request|, will insert a new key in
// |requestlist_map_| and append |request| to the respective list.
- int StartNewTransactionFor(Request* request);
+ int StartNewDnsRequestFor(Request* request);
// Will enqueue |request| in |pending_requests_|.
int Enqueue(Request* request);
@@ -114,11 +108,11 @@ class NET_EXPORT AsyncHostResolver
// there are pending requests.
void ProcessPending();
- // Maximum number of concurrent transactions.
- size_t max_transactions_;
+ // Maximum number of concurrent DNS requests.
+ size_t max_dns_requests_;
- // List of current transactions.
- TransactionList transactions_;
+ // List of current DNS requests.
+ DnsRequestList dns_requests_;
// A map from Key to a list of requests waiting for the Key to resolve.
KeyRequestListMap requestlist_map_;
@@ -129,18 +123,10 @@ class NET_EXPORT AsyncHostResolver
// Queues based on priority for putting pending requests.
RequestList pending_requests_[NUM_PRIORITIES];
- // DNS server to which queries will be setn.
- IPEndPoint dns_server_;
-
- // Callback to be passed to DnsTransaction for generating DNS query ids.
- RandIntCallback rand_int_cb_;
-
// Cache of host resolution results.
scoped_ptr<HostCache> cache_;
- // Also passed to DnsTransaction; it's a dependency injection to aid
- // testing, outside of unit tests, its value is always NULL.
- ClientSocketFactory* factory_;
+ DnsClient* client_;
NetLog* net_log_;
diff --git a/net/dns/async_host_resolver_unittest.cc b/net/dns/async_host_resolver_unittest.cc
index d92887e..21123bc 100644
--- a/net/dns/async_host_resolver_unittest.cc
+++ b/net/dns/async_host_resolver_unittest.cc
@@ -6,18 +6,27 @@
#include "base/bind.h"
#include "base/memory/scoped_ptr.h"
+#include "base/message_loop.h"
+#include "base/stl_util.h"
#include "net/base/host_cache.h"
+#include "net/base/net_errors.h"
#include "net/base/net_log.h"
-#include "net/base/rand_callback.h"
#include "net/base/sys_addrinfo.h"
+#include "net/base/test_completion_callback.h"
+#include "net/dns/dns_client.h"
+#include "net/dns/dns_query.h"
+#include "net/dns/dns_response.h"
#include "net/dns/dns_test_util.h"
-#include "net/socket/socket_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
namespace {
+const int kPortNum = 80;
+const size_t kMaxTransactions = 2;
+const size_t kMaxPendingRequests = 1;
+
void VerifyAddressList(const std::vector<const char*>& ip_addresses,
int port,
const AddressList& addrlist) {
@@ -39,12 +48,64 @@ void VerifyAddressList(const std::vector<const char*>& ip_addresses,
ASSERT_EQ(static_cast<addrinfo*>(NULL), ainfo);
}
+class MockDnsClient : public DnsClient,
+ public base::SupportsWeakPtr<MockDnsClient> {
+ public:
+ // Using WeakPtr to support cancellation.
+ // All MockRequests succeed unless canceled or MockDnsClient is destroyed.
+ class MockRequest : public DnsClient::Request,
+ public base::SupportsWeakPtr<MockRequest> {
+ public:
+ MockRequest(const base::StringPiece& qname,
+ uint16 qtype,
+ const RequestCallback& callback,
+ const base::WeakPtr<MockDnsClient>& client)
+ : Request(qname, qtype, callback), started_(false), client_(client) {
+ }
+
+ virtual int Start() OVERRIDE {
+ EXPECT_FALSE(started_);
+ started_ = true;
+ MessageLoop::current()->PostTask(
+ FROM_HERE,
+ base::Bind(&MockRequest::Finish, AsWeakPtr()));
+ return ERR_IO_PENDING;
+ }
+
+ private:
+ void Finish() {
+ if (!client_) {
+ DoCallback(ERR_DNS_SERVER_FAILED, NULL);
+ return;
+ }
+ DoCallback(OK, client_->responses[Key(qname(), qtype())]);
+ }
+
+ bool started_;
+ base::WeakPtr<MockDnsClient> client_;
+ };
+
+ typedef std::pair<std::string, uint16> Key;
+
+ MockDnsClient() : num_requests(0) {}
+ ~MockDnsClient() {
+ STLDeleteValues(&responses);
+ }
+
+ Request* CreateRequest(const base::StringPiece& qname,
+ uint16 qtype,
+ const RequestCallback& callback,
+ const BoundNetLog&) {
+ ++num_requests;
+ return new MockRequest(qname, qtype, callback, AsWeakPtr());
+ }
+
+ int num_requests;
+ std::map<Key, DnsResponse*> responses;
+};
+
} // namespace
-static const int kPortNum = 80;
-static const size_t kMaxTransactions = 2;
-static const size_t kMaxPendingRequests = 1;
-static int transaction_ids[] = {0, 1, 2, 3};
// The following fixture sets up an environment for four different lookups
// with their data defined in dns_test_util.h. All tests make use of these
@@ -69,84 +130,52 @@ class AsyncHostResolverTest : public testing::Test {
ip_addresses2_(kT2IpAddresses,
kT2IpAddresses + arraysize(kT2IpAddresses)),
ip_addresses3_(kT3IpAddresses,
- kT3IpAddresses + arraysize(kT3IpAddresses)),
- test_prng_(std::deque<int>(
- transaction_ids, transaction_ids + arraysize(transaction_ids))) {
- rand_int_cb_ = base::Bind(&TestPrng::GetNext,
- base::Unretained(&test_prng_));
+ kT3IpAddresses + arraysize(kT3IpAddresses)) {
// AF_INET only for now.
info0_.set_address_family(ADDRESS_FAMILY_IPV4);
info1_.set_address_family(ADDRESS_FAMILY_IPV4);
info2_.set_address_family(ADDRESS_FAMILY_IPV4);
info3_.set_address_family(ADDRESS_FAMILY_IPV4);
- // Setup socket read/writes for transaction 0.
- writes0_.push_back(
- MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram),
- arraysize(kT0QueryDatagram)));
- reads0_.push_back(
- MockRead(true, reinterpret_cast<const char*>(kT0ResponseDatagram),
- arraysize(kT0ResponseDatagram)));
- data0_.reset(new StaticSocketDataProvider(&reads0_[0], reads0_.size(),
- &writes0_[0], writes0_.size()));
-
- // Setup socket read/writes for transaction 1.
- writes1_.push_back(
- MockWrite(true, reinterpret_cast<const char*>(kT1QueryDatagram),
- arraysize(kT1QueryDatagram)));
- reads1_.push_back(
- MockRead(true, reinterpret_cast<const char*>(kT1ResponseDatagram),
- arraysize(kT1ResponseDatagram)));
- data1_.reset(new StaticSocketDataProvider(&reads1_[0], reads1_.size(),
- &writes1_[0], writes1_.size()));
-
- // Setup socket read/writes for transaction 2.
- writes2_.push_back(
- MockWrite(true, reinterpret_cast<const char*>(kT2QueryDatagram),
- arraysize(kT2QueryDatagram)));
- reads2_.push_back(
- MockRead(true, reinterpret_cast<const char*>(kT2ResponseDatagram),
- arraysize(kT2ResponseDatagram)));
- data2_.reset(new StaticSocketDataProvider(&reads2_[0], reads2_.size(),
- &writes2_[0], writes2_.size()));
-
- // Setup socket read/writes for transaction 3.
- writes3_.push_back(
- MockWrite(true, reinterpret_cast<const char*>(kT3QueryDatagram),
- arraysize(kT3QueryDatagram)));
- reads3_.push_back(
- MockRead(true, reinterpret_cast<const char*>(kT3ResponseDatagram),
- arraysize(kT3ResponseDatagram)));
- data3_.reset(new StaticSocketDataProvider(&reads3_[0], reads3_.size(),
- &writes3_[0], writes3_.size()));
-
- factory_.AddSocketDataProvider(data0_.get());
- factory_.AddSocketDataProvider(data1_.get());
- factory_.AddSocketDataProvider(data2_.get());
- factory_.AddSocketDataProvider(data3_.get());
-
- IPEndPoint dns_server;
- bool rv0 = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server);
- DCHECK(rv0);
+ client_.reset(new MockDnsClient());
+
+ AddResponse(std::string(kT0DnsName, arraysize(kT0DnsName)), kT0Qtype,
+ new DnsResponse(reinterpret_cast<const char*>(kT0ResponseDatagram),
+ arraysize(kT0ResponseDatagram),
+ arraysize(kT0QueryDatagram)));
+
+ AddResponse(std::string(kT1DnsName, arraysize(kT1DnsName)), kT1Qtype,
+ new DnsResponse(reinterpret_cast<const char*>(kT1ResponseDatagram),
+ arraysize(kT1ResponseDatagram),
+ arraysize(kT1QueryDatagram)));
+
+ AddResponse(std::string(kT2DnsName, arraysize(kT2DnsName)), kT2Qtype,
+ new DnsResponse(reinterpret_cast<const char*>(kT2ResponseDatagram),
+ arraysize(kT2ResponseDatagram),
+ arraysize(kT2QueryDatagram)));
+
+ AddResponse(std::string(kT3DnsName, arraysize(kT3DnsName)), kT3Qtype,
+ new DnsResponse(reinterpret_cast<const char*>(kT3ResponseDatagram),
+ arraysize(kT3ResponseDatagram),
+ arraysize(kT3QueryDatagram)));
resolver_.reset(
- new AsyncHostResolver(
- dns_server, kMaxTransactions, kMaxPendingRequests, rand_int_cb_,
- HostCache::CreateDefaultCache(), &factory_, NULL));
+ new AsyncHostResolver(kMaxTransactions, kMaxPendingRequests,
+ HostCache::CreateDefaultCache(),
+ client_.get(), NULL));
+ }
+
+ void AddResponse(const std::string& name, uint8 type, DnsResponse* response) {
+ client_->responses[MockDnsClient::Key(name, type)] = response;
}
protected:
AddressList addrlist0_, addrlist1_, addrlist2_, addrlist3_;
HostResolver::RequestInfo info0_, info1_, info2_, info3_;
- std::vector<MockWrite> writes0_, writes1_, writes2_, writes3_;
- std::vector<MockRead> reads0_, reads1_, reads2_, reads3_;
- scoped_ptr<StaticSocketDataProvider> data0_, data1_, data2_, data3_;
std::vector<const char*> ip_addresses0_, ip_addresses1_,
ip_addresses2_, ip_addresses3_;
- MockClientSocketFactory factory_;
- TestPrng test_prng_;
- RandIntCallback rand_int_cb_;
scoped_ptr<HostResolver> resolver_;
+ scoped_ptr<MockDnsClient> client_;
TestCompletionCallback callback0_, callback1_, callback2_, callback3_;
};
@@ -242,7 +271,7 @@ TEST_F(AsyncHostResolverTest, ConcurrentLookup) {
EXPECT_EQ(OK, rv2);
VerifyAddressList(ip_addresses2_, kPortNum, addrlist2_);
- EXPECT_EQ(3u, factory_.udp_client_sockets().size());
+ EXPECT_EQ(3, client_->num_requests);
}
TEST_F(AsyncHostResolverTest, SameHostLookupsConsumeSingleTransaction) {
@@ -270,7 +299,7 @@ TEST_F(AsyncHostResolverTest, SameHostLookupsConsumeSingleTransaction) {
VerifyAddressList(ip_addresses0_, kPortNum, addrlist2_);
// Although we have three lookups, a single UDP socket was used.
- EXPECT_EQ(1u, factory_.udp_client_sockets().size());
+ EXPECT_EQ(1, client_->num_requests);
}
TEST_F(AsyncHostResolverTest, CancelLookup) {
@@ -319,7 +348,7 @@ TEST_F(AsyncHostResolverTest, CancelSameHostLookup) {
EXPECT_EQ(OK, rv1);
VerifyAddressList(ip_addresses0_, kPortNum, addrlist1_);
- EXPECT_EQ(1u, factory_.udp_client_sockets().size());
+ EXPECT_EQ(1, client_->num_requests);
}
// Test that a queued lookup completes.
diff --git a/net/dns/dns_client.cc b/net/dns/dns_client.cc
new file mode 100644
index 0000000..de60cc3
--- /dev/null
+++ b/net/dns/dns_client.cc
@@ -0,0 +1,91 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_client.h"
+
+#include "base/bind.h"
+#include "base/string_piece.h"
+#include "net/base/net_errors.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/dns_session.h"
+#include "net/dns/dns_transaction.h"
+#include "net/socket/client_socket_factory.h"
+
+namespace net {
+
+DnsClient::Request::Request(const base::StringPiece& qname,
+ uint16 qtype,
+ const RequestCallback& callback)
+ : qname_(qname.data(), qname.size()),
+ qtype_(qtype),
+ callback_(callback) {
+}
+
+DnsClient::Request::~Request() {}
+
+// Implementation of DnsClient that uses DnsTransaction to serve requests.
+class DnsClientImpl : public DnsClient {
+ public:
+ class RequestImpl : public Request {
+ public:
+ RequestImpl(const base::StringPiece& qname,
+ uint16 qtype,
+ const RequestCallback& callback,
+ DnsSession* session,
+ const BoundNetLog& net_log)
+ : Request(qname, qtype, callback),
+ session_(session),
+ net_log_(net_log) {
+ }
+
+ virtual int Start() OVERRIDE {
+ transaction_.reset(new DnsTransaction(
+ session_,
+ qname(),
+ qtype(),
+ base::Bind(&RequestImpl::OnComplete, base::Unretained(this)),
+ net_log_));
+ return transaction_->Start();
+ }
+
+ void OnComplete(DnsTransaction* transaction, int rv) {
+ DCHECK_EQ(transaction_.get(), transaction);
+ // TODO(szym):
+ // - handle retransmissions here instead of DnsTransaction
+ // - handle rcode and flags here instead of DnsTransaction
+ // - update RTT in DnsSession
+ // - perform suffix search
+ // - handle DNS over TCP
+ DoCallback(rv, (rv == OK) ? transaction->response() : NULL);
+ }
+
+ private:
+ scoped_refptr<DnsSession> session_;
+ BoundNetLog net_log_;
+ scoped_ptr<DnsTransaction> transaction_;
+ };
+
+ explicit DnsClientImpl(DnsSession* session) {
+ session_ = session;
+ }
+
+ virtual Request* CreateRequest(
+ const base::StringPiece& qname,
+ uint16 qtype,
+ const RequestCallback& callback,
+ const BoundNetLog& source_net_log) OVERRIDE {
+ return new RequestImpl(qname, qtype, callback, session_, source_net_log);
+ }
+
+ private:
+ scoped_refptr<DnsSession> session_;
+};
+
+// static
+DnsClient* DnsClient::CreateClient(DnsSession* session) {
+ return new DnsClientImpl(session);
+}
+
+} // namespace net
+
diff --git a/net/dns/dns_client.h b/net/dns/dns_client.h
new file mode 100644
index 0000000..af75cf6
--- /dev/null
+++ b/net/dns/dns_client.h
@@ -0,0 +1,93 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_CLIENT_H_
+#define NET_DNS_DNS_CLIENT_H_
+#pragma once
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/callback.h"
+#include "base/memory/weak_ptr.h"
+#include "net/base/net_export.h"
+
+namespace base {
+class StringPiece;
+}
+
+namespace net {
+
+class BoundNetLog;
+class ClientSocketFactory;
+class DnsResponse;
+class DnsSession;
+
+// DnsClient performs asynchronous DNS queries. DnsClient takes care of
+// retransmissions, DNS server fallback (or round-robin), suffix search, and
+// simple response validation ("does it match the query") to fight poisoning.
+// It does NOT perform caching, aggregation or prioritization of requests.
+//
+// Destroying DnsClient does NOT affect any already created Requests.
+//
+// TODO(szym): consider adding flags to MakeRequest to indicate options:
+// -- don't perform suffix search
+// -- query both A and AAAA at once
+// -- accept truncated response (and/or forbid TCP)
+class NET_EXPORT_PRIVATE DnsClient {
+ public:
+ class Request;
+ // Callback for complete requests. Note that DnsResponse might be NULL if
+ // the DNS server(s) could not be reached.
+ typedef base::Callback<void(Request* req,
+ int result,
+ const DnsResponse* resp)> RequestCallback;
+
+ // A data-holder for a request made to the DnsClient.
+ // Destroying the request cancels the underlying network effort.
+ class NET_EXPORT_PRIVATE Request {
+ public:
+ Request(const base::StringPiece& qname,
+ uint16 qtype,
+ const RequestCallback& callback);
+ virtual ~Request();
+
+ const std::string& qname() const { return qname_; }
+
+ uint16 qtype() const { return qtype_; }
+
+ virtual int Start() = 0;
+
+ void DoCallback(int result, const DnsResponse* response) {
+ callback_.Run(this, result, response);
+ }
+
+ private:
+ std::string qname_;
+ uint16 qtype_;
+ RequestCallback callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(Request);
+ };
+
+ virtual ~DnsClient() {}
+
+ // Makes asynchronous DNS query for the given |qname| and |qtype| (assuming
+ // QCLASS == IN). The caller is responsible for destroying the returned
+ // request whether to cancel it or after its completion.
+ // (Destroying DnsClient does not abort the requests.)
+ virtual Request* CreateRequest(
+ const base::StringPiece& qname,
+ uint16 qtype,
+ const RequestCallback& callback,
+ const BoundNetLog& source_net_log) WARN_UNUSED_RESULT = 0;
+
+ // Creates a socket-based DnsClient using the |session|.
+ static DnsClient* CreateClient(DnsSession* session) WARN_UNUSED_RESULT;
+};
+
+} // namespace net
+
+#endif // NET_DNS_DNS_CLIENT_H_
+
diff --git a/net/dns/dns_client_unittest.cc b/net/dns/dns_client_unittest.cc
new file mode 100644
index 0000000..fdbae8f
--- /dev/null
+++ b/net/dns/dns_client_unittest.cc
@@ -0,0 +1,311 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_client.h"
+
+#include "base/bind.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/big_endian.h"
+#include "net/base/net_log.h"
+#include "net/base/sys_addrinfo.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/dns_session.h"
+#include "net/dns/dns_test_util.h"
+#include "net/socket/socket_test_util.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+// TODO(szym): test DnsClient::Request::Start with synchronous failure
+// TODO(szym): test suffix search and server fallback once implemented
+
+namespace net {
+
+namespace {
+
+class DnsClientTest : public testing::Test {
+ public:
+ class TestRequestHelper {
+ public:
+ // If |answer_count| < 0, it is the expected error code.
+ TestRequestHelper(const char* name,
+ uint16 type,
+ const MockWrite& write,
+ const MockRead& read,
+ int answer_count) {
+ // Must include the terminating \x00.
+ qname = std::string(name, strlen(name) + 1);
+ qtype = type;
+ expected_answer_count = answer_count;
+ completed = false;
+ writes.push_back(write);
+ reads.push_back(read);
+ ReadBigEndian<uint16>(write.data, &transaction_id);
+ data.reset(new StaticSocketDataProvider(&reads[0], reads.size(),
+ &writes[0], writes.size()));
+ }
+
+ void MakeRequest(DnsClient* client) {
+ EXPECT_EQ(NULL, request.get());
+ request.reset(client->CreateRequest(
+ qname,
+ qtype,
+ base::Bind(&TestRequestHelper::OnRequestComplete,
+ base::Unretained(this)),
+ BoundNetLog()));
+ EXPECT_EQ(qname, request->qname());
+ EXPECT_EQ(qtype, request->qtype());
+ EXPECT_EQ(ERR_IO_PENDING, request->Start());
+ }
+
+ void Cancel() {
+ ASSERT_TRUE(request.get() != NULL);
+ request.reset(NULL);
+ }
+
+ void OnRequestComplete(DnsClient::Request* req,
+ int rv,
+ const DnsResponse* response) {
+ EXPECT_FALSE(completed);
+ EXPECT_EQ(request.get(), req);
+
+ if (expected_answer_count >= 0) {
+ EXPECT_EQ(OK, rv);
+ EXPECT_EQ(expected_answer_count, response->answer_count());
+
+ DnsRecordParser parser = response->Parser();
+ DnsResourceRecord record;
+ for (int i = 0; i < expected_answer_count; ++i) {
+ EXPECT_TRUE(parser.ParseRecord(&record));
+ }
+ EXPECT_TRUE(parser.AtEnd());
+
+ } else {
+ EXPECT_EQ(expected_answer_count, rv);
+ EXPECT_EQ(NULL, response);
+ }
+
+ completed = true;
+ }
+
+ void CancelOnRequestComplete(DnsClient::Request* req,
+ int rv,
+ const DnsResponse* response) {
+ EXPECT_FALSE(completed);
+ Cancel();
+ }
+
+ std::string qname;
+ uint16 qtype;
+ std::vector<MockWrite> writes;
+ std::vector<MockRead> reads;
+ uint16 transaction_id; // Id from first write.
+ scoped_ptr<StaticSocketDataProvider> data;
+ scoped_ptr<DnsClient::Request> request;
+ int expected_answer_count;
+
+ bool completed;
+ };
+
+ virtual void SetUp() OVERRIDE {
+ helpers_.push_back(new TestRequestHelper(
+ kT0DnsName,
+ kT0Qtype,
+ MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram),
+ arraysize(kT0QueryDatagram)),
+ MockRead(true, reinterpret_cast<const char*>(kT0ResponseDatagram),
+ arraysize(kT0ResponseDatagram)),
+ arraysize(kT0IpAddresses) + 1)); // +1 for CNAME RR
+
+ helpers_.push_back(new TestRequestHelper(
+ kT1DnsName,
+ kT1Qtype,
+ MockWrite(true, reinterpret_cast<const char*>(kT1QueryDatagram),
+ arraysize(kT1QueryDatagram)),
+ MockRead(true, reinterpret_cast<const char*>(kT1ResponseDatagram),
+ arraysize(kT1ResponseDatagram)),
+ arraysize(kT1IpAddresses) + 1)); // +1 for CNAME RR
+
+ helpers_.push_back(new TestRequestHelper(
+ kT2DnsName,
+ kT2Qtype,
+ MockWrite(true, reinterpret_cast<const char*>(kT2QueryDatagram),
+ arraysize(kT2QueryDatagram)),
+ MockRead(true, reinterpret_cast<const char*>(kT2ResponseDatagram),
+ arraysize(kT2ResponseDatagram)),
+ arraysize(kT2IpAddresses) + 1)); // +1 for CNAME RR
+
+ helpers_.push_back(new TestRequestHelper(
+ kT3DnsName,
+ kT3Qtype,
+ MockWrite(true, reinterpret_cast<const char*>(kT3QueryDatagram),
+ arraysize(kT3QueryDatagram)),
+ MockRead(true, reinterpret_cast<const char*>(kT3ResponseDatagram),
+ arraysize(kT3ResponseDatagram)),
+ arraysize(kT3IpAddresses) + 2)); // +2 for CNAME RR
+
+ CreateClient();
+ }
+
+ void CreateClient() {
+ MockClientSocketFactory* factory = new MockClientSocketFactory();
+
+ transaction_ids_.clear();
+ for (unsigned i = 0; i < helpers_.size(); ++i) {
+ factory->AddSocketDataProvider(helpers_[i]->data.get());
+ transaction_ids_.push_back(static_cast<int>(helpers_[i]->transaction_id));
+ }
+
+ DnsConfig config;
+
+ IPEndPoint dns_server;
+ {
+ bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server);
+ EXPECT_TRUE(rv);
+ }
+ config.nameservers.push_back(dns_server);
+
+ DnsSession* session = new DnsSession(
+ config,
+ factory,
+ base::Bind(&DnsClientTest::GetNextId, base::Unretained(this)),
+ NULL /* NetLog */);
+
+ client_.reset(DnsClient::CreateClient(session));
+ }
+
+ virtual void TearDown() OVERRIDE {
+ STLDeleteElements(&helpers_);
+ }
+
+ int GetNextId(int min, int max) {
+ EXPECT_FALSE(transaction_ids_.empty());
+ int id = transaction_ids_.front();
+ transaction_ids_.pop_front();
+ EXPECT_GE(id, min);
+ EXPECT_LE(id, max);
+ return id;
+ }
+
+ protected:
+ std::vector<TestRequestHelper*> helpers_;
+ std::deque<int> transaction_ids_;
+ scoped_ptr<DnsClient> client_;
+};
+
+TEST_F(DnsClientTest, Lookup) {
+ helpers_[0]->MakeRequest(client_.get());
+
+ // Wait until result.
+ MessageLoop::current()->RunAllPending();
+
+ EXPECT_TRUE(helpers_[0]->completed);
+}
+
+TEST_F(DnsClientTest, ConcurrentLookup) {
+ for (unsigned i = 0; i < helpers_.size(); ++i) {
+ helpers_[i]->MakeRequest(client_.get());
+ }
+
+ MessageLoop::current()->RunAllPending();
+
+ for (unsigned i = 0; i < helpers_.size(); ++i) {
+ EXPECT_TRUE(helpers_[i]->completed);
+ }
+}
+
+TEST_F(DnsClientTest, CancelLookup) {
+ for (unsigned i = 0; i < helpers_.size(); ++i) {
+ helpers_[i]->MakeRequest(client_.get());
+ }
+
+ helpers_[0]->Cancel();
+ helpers_[2]->Cancel();
+
+ MessageLoop::current()->RunAllPending();
+
+ EXPECT_FALSE(helpers_[0]->completed);
+ EXPECT_TRUE(helpers_[1]->completed);
+ EXPECT_FALSE(helpers_[2]->completed);
+ EXPECT_TRUE(helpers_[3]->completed);
+}
+
+TEST_F(DnsClientTest, DestroyClient) {
+ for (unsigned i = 0; i < helpers_.size(); ++i) {
+ helpers_[i]->MakeRequest(client_.get());
+ }
+
+ // Destroying the client does not affect running requests.
+ client_.reset(NULL);
+
+ MessageLoop::current()->RunAllPending();
+
+ for (unsigned i = 0; i < helpers_.size(); ++i) {
+ EXPECT_TRUE(helpers_[i]->completed);
+ }
+}
+
+TEST_F(DnsClientTest, DestroyRequestFromCallback) {
+ // Custom callback to cancel the completing request.
+ helpers_[0]->request.reset(client_->CreateRequest(
+ helpers_[0]->qname,
+ helpers_[0]->qtype,
+ base::Bind(&TestRequestHelper::CancelOnRequestComplete,
+ base::Unretained(helpers_[0])),
+ BoundNetLog()));
+ helpers_[0]->request->Start();
+
+ for (unsigned i = 1; i < helpers_.size(); ++i) {
+ helpers_[i]->MakeRequest(client_.get());
+ }
+
+ MessageLoop::current()->RunAllPending();
+
+ EXPECT_FALSE(helpers_[0]->completed);
+ for (unsigned i = 1; i < helpers_.size(); ++i) {
+ EXPECT_TRUE(helpers_[i]->completed);
+ }
+}
+
+TEST_F(DnsClientTest, HandleFailure) {
+ STLDeleteElements(&helpers_);
+ // Wrong question.
+ helpers_.push_back(new TestRequestHelper(
+ kT0DnsName,
+ kT0Qtype,
+ MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram),
+ arraysize(kT0QueryDatagram)),
+ MockRead(true, reinterpret_cast<const char*>(kT1ResponseDatagram),
+ arraysize(kT1ResponseDatagram)),
+ ERR_DNS_MALFORMED_RESPONSE));
+
+ // Response with NXDOMAIN.
+ uint8 nxdomain_response[arraysize(kT0QueryDatagram)];
+ memcpy(nxdomain_response, kT0QueryDatagram, arraysize(nxdomain_response));
+ nxdomain_response[2] &= 0x80; // Response bit.
+ nxdomain_response[3] &= 0x03; // NXDOMAIN bit.
+ helpers_.push_back(new TestRequestHelper(
+ kT0DnsName,
+ kT0Qtype,
+ MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram),
+ arraysize(kT0QueryDatagram)),
+ MockRead(true, reinterpret_cast<const char*>(nxdomain_response),
+ arraysize(nxdomain_response)),
+ ERR_NAME_NOT_RESOLVED));
+
+ CreateClient();
+
+ for (unsigned i = 0; i < helpers_.size(); ++i) {
+ helpers_[i]->MakeRequest(client_.get());
+ }
+
+ MessageLoop::current()->RunAllPending();
+
+ for (unsigned i = 0; i < helpers_.size(); ++i) {
+ EXPECT_TRUE(helpers_[i]->completed);
+ }
+}
+
+} // namespace
+
+} // namespace net
+
diff --git a/net/dns/dns_protocol.h b/net/dns/dns_protocol.h
new file mode 100644
index 0000000..494429b
--- /dev/null
+++ b/net/dns/dns_protocol.h
@@ -0,0 +1,122 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_PROTOCOL_H_
+#define NET_DNS_DNS_PROTOCOL_H_
+#pragma once
+
+#include "base/basictypes.h"
+#include "net/base/net_export.h"
+
+namespace net {
+
+namespace dns_protocol {
+
+// DNS packet consists of a header followed by questions and/or answers.
+// For the meaning of specific fields, please see RFC 1035 and 2535
+
+// Header format.
+// 1 1 1 1 1 1
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | ID |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// |QR| Opcode |AA|TC|RD|RA| Z|AD|CD| RCODE |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | QDCOUNT |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | ANCOUNT |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | NSCOUNT |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | ARCOUNT |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+
+// Question format.
+// 1 1 1 1 1 1
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | |
+// / QNAME /
+// / /
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | QTYPE |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | QCLASS |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+
+// Answer format.
+// 1 1 1 1 1 1
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | |
+// / /
+// / NAME /
+// | |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | TYPE |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | CLASS |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | TTL |
+// | |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | RDLENGTH |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
+// / RDATA /
+// / /
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+
+#pragma pack(push)
+#pragma pack(1)
+
+// On-the-wire header. All uint16 are in network order.
+// Used internally in DnsQuery and DnsResponseParser.
+struct NET_EXPORT_PRIVATE Header {
+ uint16 id;
+ uint8 flags[2];
+ uint16 qdcount;
+ uint16 ancount;
+ uint16 nscount;
+ uint16 arcount;
+};
+
+#pragma pack(pop)
+
+static const uint8 kLabelMask = 0xc0;
+static const uint8 kLabelPointer = 0xc0;
+static const uint8 kLabelDirect = 0x0;
+static const uint16 kOffsetMask = 0x3fff;
+
+static const int kMaxNameLength = 255;
+
+// RFC 1035, section 4.2.1: Messages carried by UDP are restricted to 512
+// bytes (not counting the IP nor UDP headers).
+static const int kMaxUDPSize = 512;
+
+// DNS class types.
+static const uint16 kClassIN = 1;
+
+// DNS resource record types. See
+// http://www.iana.org/assignments/dns-parameters
+static const uint16 kTypeA = 1;
+static const uint16 kTypeCNAME = 5;
+static const uint16 kTypeTXT = 16;
+static const uint16 kTypeAAAA = 28;
+
+// DNS rcode values.
+static const uint8 kRcodeMask = 0xf;
+static const uint8 kRcodeNOERROR = 0;
+static const uint8 kRcodeFORMERR = 1;
+static const uint8 kRcodeSERVFAIL = 2;
+static const uint8 kRcodeNXDOMAIN = 3;
+static const uint8 kRcodeNOTIMP = 4;
+static const uint8 kRcodeREFUSED = 5;
+
+} // namespace dns_protocol
+
+} // namespace net
+
+#endif // NET_DNS_DNS_PROTOCOL_H_
+
diff --git a/net/dns/dns_query.cc b/net/dns/dns_query.cc
index 788d653..3cfb5cd 100644
--- a/net/dns/dns_query.cc
+++ b/net/dns/dns_query.cc
@@ -6,91 +6,78 @@
#include <limits>
+#include "net/base/big_endian.h"
#include "net/base/dns_util.h"
#include "net/base/io_buffer.h"
+#include "net/base/sys_byteorder.h"
+#include "net/dns/dns_protocol.h"
namespace net {
-namespace {
-
-void PackUint16BE(char buf[2], uint16 v) {
- buf[0] = v >> 8;
- buf[1] = v & 0xff;
-}
-
-uint16 UnpackUint16BE(char buf[2]) {
- return static_cast<uint8>(buf[0]) << 8 | static_cast<uint8>(buf[1]);
-}
-
-} // namespace
-
// DNS query consists of a 12-byte header followed by a question section.
// For details, see RFC 1035 section 4.1.1. This header template sets RD
// bit, which directs the name server to pursue query recursively, and sets
-// the QDCOUNT to 1, meaning the question section has a single entry. The
-// first two bytes of the header form a 16-bit random query ID to be copied
-// in the corresponding reply by the name server -- randomized during
-// DnsQuery construction.
-static const char kHeader[] = {0x00, 0x00, 0x01, 0x00, 0x00, 0x01,
- 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
-static const size_t kHeaderSize = arraysize(kHeader);
-
-DnsQuery::DnsQuery(const std::string& qname,
- uint16 qtype,
- const RandIntCallback& rand_int_cb)
- : qname_size_(qname.size()),
- rand_int_cb_(rand_int_cb) {
- DCHECK(DnsResponseBuffer(reinterpret_cast<const uint8*>(qname.c_str()),
- qname.size()).DNSName(NULL));
- DCHECK(qtype == kDNS_A || qtype == kDNS_AAAA);
-
- io_buffer_ = new IOBufferWithSize(kHeaderSize + question_size());
-
- int byte_offset = 0;
- char* buffer_head = io_buffer_->data();
- memcpy(&buffer_head[byte_offset], kHeader, kHeaderSize);
- byte_offset += kHeaderSize;
- memcpy(&buffer_head[byte_offset], &qname[0], qname_size_);
- byte_offset += qname_size_;
- PackUint16BE(&buffer_head[byte_offset], qtype);
- byte_offset += sizeof(qtype);
- PackUint16BE(&buffer_head[byte_offset], kClassIN);
- RandomizeId();
+// the QDCOUNT to 1, meaning the question section has a single entry.
+DnsQuery::DnsQuery(uint16 id, const base::StringPiece& qname, uint16 qtype)
+ : qname_size_(qname.size()) {
+ DCHECK(!DNSDomainToString(qname).empty());
+ // QNAME + QTYPE + QCLASS
+ size_t question_size = qname_size_ + sizeof(uint16) + sizeof(uint16);
+ io_buffer_ = new IOBufferWithSize(sizeof(dns_protocol::Header) +
+ question_size);
+ dns_protocol::Header* header =
+ reinterpret_cast<dns_protocol::Header*>(io_buffer_->data());
+ memset(header, 0, sizeof(dns_protocol::Header));
+ header->id = htons(id);
+ header->flags[0] = 0x1; // RD bit
+ header->qdcount = htons(1);
+
+ // Write question section after the header.
+ BigEndianWriter writer(reinterpret_cast<char*>(header + 1), question_size);
+ writer.WriteBytes(qname.data(), qname.size());
+ writer.WriteU16(qtype);
+ writer.WriteU16(dns_protocol::kClassIN);
}
DnsQuery::~DnsQuery() {
}
-uint16 DnsQuery::id() const {
- return UnpackUint16BE(&io_buffer_->data()[0]);
+DnsQuery* DnsQuery::CloneWithNewId(uint16 id) const {
+ return new DnsQuery(*this, id);
}
-uint16 DnsQuery::qtype() const {
- return UnpackUint16BE(&io_buffer_->data()[kHeaderSize + qname_size_]);
-}
-
-DnsQuery* DnsQuery::CloneWithNewId() const {
- return new DnsQuery(qname(), qtype(), rand_int_cb_);
+uint16 DnsQuery::id() const {
+ const dns_protocol::Header* header =
+ reinterpret_cast<const dns_protocol::Header*>(io_buffer_->data());
+ return ntohs(header->id);
}
-size_t DnsQuery::question_size() const {
- return qname_size_ // QNAME
- + sizeof(uint16) // QTYPE
- + sizeof(uint16); // QCLASS
+base::StringPiece DnsQuery::qname() const {
+ return base::StringPiece(io_buffer_->data() + sizeof(dns_protocol::Header),
+ qname_size_);
}
-const char* DnsQuery::question_data() const {
- return &io_buffer_->data()[kHeaderSize];
+uint16 DnsQuery::qtype() const {
+ uint16 type;
+ ReadBigEndian<uint16>(io_buffer_->data() +
+ sizeof(dns_protocol::Header) +
+ qname_size_, &type);
+ return type;
}
-const std::string DnsQuery::qname() const {
- return std::string(question_data(), qname_size_);
+base::StringPiece DnsQuery::question() const {
+ return base::StringPiece(io_buffer_->data() + sizeof(dns_protocol::Header),
+ qname_size_ + sizeof(uint16) + sizeof(uint16));
}
-void DnsQuery::RandomizeId() {
- PackUint16BE(&io_buffer_->data()[0], rand_int_cb_.Run(
- std::numeric_limits<uint16>::min(),
- std::numeric_limits<uint16>::max()));
+DnsQuery::DnsQuery(const DnsQuery& orig, uint16 id) {
+ qname_size_ = orig.qname_size_;
+ io_buffer_ = new IOBufferWithSize(orig.io_buffer()->size());
+ memcpy(io_buffer_.get()->data(), orig.io_buffer()->data(),
+ io_buffer_.get()->size());
+ dns_protocol::Header* header =
+ reinterpret_cast<dns_protocol::Header*>(io_buffer_->data());
+ header->id = htons(id);
}
} // namespace net
diff --git a/net/dns/dns_query.h b/net/dns/dns_query.h
index c6bd3bc..e851113 100644
--- a/net/dns/dns_query.h
+++ b/net/dns/dns_query.h
@@ -6,53 +6,41 @@
#define NET_DNS_DNS_QUERY_H_
#pragma once
-#include <string>
-
+#include "base/basictypes.h"
#include "base/memory/ref_counted.h"
+#include "base/string_piece.h"
#include "net/base/net_export.h"
-#include "net/base/rand_callback.h"
namespace net {
class IOBufferWithSize;
// Represents on-the-wire DNS query message as an object.
+// TODO(szym): add support for the OPT pseudo-RR (EDNS0/DNSSEC).
class NET_EXPORT_PRIVATE DnsQuery {
public:
// Constructs a query message from |qname| which *MUST* be in a valid
- // DNS name format, and |qtype| which must be either kDNS_A or kDNS_AAAA.
-
- // Every generated object has a random ID, hence two objects generated
- // with the same set of constructor arguments are generally not equal;
- // there is a 1/2^16 chance of them being equal due to size of |id_|.
- DnsQuery(const std::string& qname,
- uint16 qtype,
- const RandIntCallback& rand_int_cb);
+ // DNS name format, and |qtype|. The qclass is set to IN.
+ DnsQuery(uint16 id, const base::StringPiece& qname, uint16 qtype);
~DnsQuery();
- // Clones |this| verbatim, with ID field of the header regenerated.
- DnsQuery* CloneWithNewId() const;
+ // Clones |this| verbatim, with ID field of the header set to |id|.
+ DnsQuery* CloneWithNewId(uint16 id) const;
// DnsQuery field accessors.
uint16 id() const;
+ base::StringPiece qname() const;
uint16 qtype() const;
- // Returns the size of the Question section of the query. Used when
- // matching the response.
- size_t question_size() const;
-
- // Returns pointer to the Question section of the query. Used when
- // matching the response.
- const char* question_data() const;
+ // Returns the Question section of the query. Used when matching the
+ // response.
+ base::StringPiece question() const;
// IOBuffer accessor to be used for writing out the query.
IOBufferWithSize* io_buffer() const { return io_buffer_; }
private:
- const std::string qname() const;
-
- // Randomizes ID field of the query message.
- void RandomizeId();
+ DnsQuery(const DnsQuery& orig, uint16 id);
// Size of the DNS name (*NOT* hostname) we are trying to resolve; used
// to calculate offsets.
@@ -61,9 +49,6 @@ class NET_EXPORT_PRIVATE DnsQuery {
// Contains query bytes to be consumed by higher level Write() call.
scoped_refptr<IOBufferWithSize> io_buffer_;
- // PRNG function for generating IDs.
- RandIntCallback rand_int_cb_;
-
DISALLOW_COPY_AND_ASSIGN(DnsQuery);
};
diff --git a/net/dns/dns_query_unittest.cc b/net/dns/dns_query_unittest.cc
index d43cf8be..ffe02e7 100644
--- a/net/dns/dns_query_unittest.cc
+++ b/net/dns/dns_query_unittest.cc
@@ -5,58 +5,21 @@
#include "net/dns/dns_query.h"
#include "base/bind.h"
-#include "base/rand_util.h"
#include "net/base/dns_util.h"
#include "net/base/io_buffer.h"
+#include "net/dns/dns_protocol.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
-// DNS query consists of a header followed by a question. Header format
-// and question format are described below. For the meaning of specific
-// fields, please see RFC 1035.
+namespace {
-// Header format.
-// 1 1 1 1 1 1
-// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | ID |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// |QR| Opcode |AA|TC|RD|RA| Z | RCODE |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | QDCOUNT |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | ANCOUNT |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | NSCOUNT |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | ARCOUNT |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-
-// Question format.
-// 1 1 1 1 1 1
-// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | |
-// / QNAME /
-// / /
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | QTYPE |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | QCLASS |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-
-TEST(DnsQueryTest, ConstructorTest) {
- std::string kQname("\003www\006google\003com", 16);
- DnsQuery q1(kQname, kDNS_A, base::Bind(&base::RandInt));
- EXPECT_EQ(kDNS_A, q1.qtype());
-
- uint8 id_hi = q1.id() >> 8, id_lo = q1.id() & 0xff;
-
- // See the top of the file for the description of a DNS query.
+TEST(DnsQueryTest, Constructor) {
+ // This includes \0 at the end.
+ const char qname_data[] = "\x03""www""\x07""example""\x03""com";
const uint8 query_data[] = {
// Header
- id_hi, id_lo,
+ 0xbe, 0xef,
0x01, 0x00, // Flags -- set RD (recursion desired) bit.
0x00, 0x01, // Set QDCOUNT (question count) to 1, all the
// rest are 0 for a query.
@@ -65,46 +28,42 @@ TEST(DnsQueryTest, ConstructorTest) {
0x00, 0x00,
// Question
- 0x03, 0x77, 0x77, 0x77, // QNAME: www.google.com in DNS format.
- 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65,
- 0x03, 0x63, 0x6f, 0x6d, 0x00,
+ 0x03, 'w', 'w', 'w', // QNAME: www.example.com in DNS format.
+ 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ 0x00,
0x00, 0x01, // QTYPE: A query.
0x00, 0x01, // QCLASS: IN class.
};
- int expected_size = arraysize(query_data);
- EXPECT_EQ(expected_size, q1.io_buffer()->size());
- EXPECT_EQ(0, memcmp(q1.io_buffer()->data(), query_data, expected_size));
+ base::StringPiece qname(qname_data, sizeof(qname_data));
+ DnsQuery q1(0xbeef, qname, dns_protocol::kTypeA);
+ EXPECT_EQ(dns_protocol::kTypeA, q1.qtype());
+
+ ASSERT_EQ(static_cast<int>(sizeof(query_data)), q1.io_buffer()->size());
+ EXPECT_EQ(0, memcmp(q1.io_buffer()->data(), query_data, sizeof(query_data)));
+ EXPECT_EQ(qname, q1.qname());
+
+ base::StringPiece question(reinterpret_cast<const char*>(query_data) + 12,
+ 21);
+ EXPECT_EQ(question, q1.question());
}
-TEST(DnsQueryTest, CloneTest) {
- std::string kQname("\003www\006google\003com", 16);
- DnsQuery q1(kQname, kDNS_A, base::Bind(&base::RandInt));
+TEST(DnsQueryTest, Clone) {
+ // This includes \0 at the end.
+ const char qname_data[] = "\x03""www""\x07""example""\x03""com";
+ base::StringPiece qname(qname_data, sizeof(qname_data));
- scoped_ptr<DnsQuery> q2(q1.CloneWithNewId());
+ DnsQuery q1(0, qname, dns_protocol::kTypeA);
+ EXPECT_EQ(0, q1.id());
+ scoped_ptr<DnsQuery> q2(q1.CloneWithNewId(42));
+ EXPECT_EQ(42, q2->id());
EXPECT_EQ(q1.io_buffer()->size(), q2->io_buffer()->size());
EXPECT_EQ(q1.qtype(), q2->qtype());
- EXPECT_EQ(q1.question_size(), q2->question_size());
- EXPECT_EQ(0, memcmp(q1.question_data(), q2->question_data(),
- q1.question_size()));
+ EXPECT_EQ(q1.question(), q2->question());
}
-TEST(DnsQueryTest, RandomIdTest) {
- std::string kQname("\003www\006google\003com", 16);
-
- // Since id fields are 16-bit values, we iterate to reduce the
- // probability of collision, to avoid a flaky test.
- bool ids_are_random = false;
- for (int i = 0; i < 1000; ++i) {
- DnsQuery q1(kQname, kDNS_A, base::Bind(&base::RandInt));
- DnsQuery q2(kQname, kDNS_A, base::Bind(&base::RandInt));
- scoped_ptr<DnsQuery> q3(q1.CloneWithNewId());
- ids_are_random = q1.id () != q2.id() && q1.id() != q3->id();
- if (ids_are_random)
- break;
- }
- EXPECT_TRUE(ids_are_random);
-}
+} // namespace
} // namespace net
diff --git a/net/dns/dns_response.cc b/net/dns/dns_response.cc
index 3c5d605..805d6f3 100644
--- a/net/dns/dns_response.cc
+++ b/net/dns/dns_response.cc
@@ -4,97 +4,185 @@
#include "net/dns/dns_response.h"
-#include "net/base/dns_util.h"
+#include "net/base/big_endian.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
+#include "net/base/sys_byteorder.h"
+#include "net/dns/dns_protocol.h"
#include "net/dns/dns_query.h"
namespace net {
-// RFC 1035, section 4.2.1: Messages carried by UDP are restricted to 512
-// bytes (not counting the IP nor UDP headers).
-static const int kMaxResponseSize = 512;
+DnsRecordParser::DnsRecordParser() : packet_(NULL), length_(0), cur_(0) {
+}
+
+DnsRecordParser::DnsRecordParser(const void* packet,
+ size_t length,
+ size_t offset)
+ : packet_(reinterpret_cast<const char*>(packet)),
+ length_(length),
+ cur_(packet_ + offset) {
+ DCHECK_LE(offset, length);
+}
-DnsResponse::DnsResponse(DnsQuery* query)
- : query_(query),
- io_buffer_(new IOBufferWithSize(kMaxResponseSize + 1)) {
- DCHECK(query_);
+int DnsRecordParser::ParseName(const void* const vpos, std::string* out) const {
+ const char* const pos = reinterpret_cast<const char*>(vpos);
+ DCHECK(packet_);
+ DCHECK_LE(packet_, pos);
+ DCHECK_LE(pos, packet_ + length_);
+
+ const char* p = pos;
+ const char* end = packet_ + length_;
+ // Count number of seen bytes to detect loops.
+ size_t seen = 0;
+ // Remember how many bytes were consumed before first jump.
+ size_t consumed = 0;
+
+ if (pos >= end)
+ return 0;
+
+ if (out) {
+ out->clear();
+ out->reserve(dns_protocol::kMaxNameLength);
+ }
+
+ for (;;) {
+ // The two couple of bits of the length give the type of the length. It's
+ // either a direct length or a pointer to the remainder of the name.
+ switch (*p & dns_protocol::kLabelMask) {
+ case dns_protocol::kLabelPointer: {
+ if (p + sizeof(uint16) > end)
+ return 0;
+ if (consumed == 0) {
+ consumed = p - pos + sizeof(uint16);
+ if (!out)
+ return consumed; // If name is not stored, that's all we need.
+ }
+ seen += sizeof(uint16);
+ // If seen the whole packet, then we must be in a loop.
+ if (seen > length_)
+ return 0;
+ uint16 offset;
+ ReadBigEndian<uint16>(p, &offset);
+ offset &= dns_protocol::kOffsetMask;
+ p = packet_ + offset;
+ if (p >= end)
+ return 0;
+ break;
+ }
+ case dns_protocol::kLabelDirect: {
+ uint8 label_len = *p;
+ ++p;
+ // Note: root domain (".") is NOT included.
+ if (label_len == 0) {
+ if (consumed == 0) {
+ consumed = p - pos;
+ } // else we set |consumed| before first jump
+ return consumed;
+ }
+ if (p + label_len >= end)
+ return 0; // Truncated or missing label.
+ if (out) {
+ if (!out->empty())
+ out->append(".");
+ out->append(p, label_len);
+ }
+ p += label_len;
+ seen += 1 + label_len;
+ break;
+ }
+ default:
+ // unhandled label type
+ return 0;
+ }
+ }
+}
+
+bool DnsRecordParser::ParseRecord(DnsResourceRecord* out) {
+ DCHECK(packet_);
+ size_t consumed = ParseName(cur_, &out->name);
+ if (!consumed)
+ return false;
+ BigEndianReader reader(cur_ + consumed,
+ packet_ + length_ - (cur_ + consumed));
+ uint16 rdlen;
+ if (reader.ReadU16(&out->type) &&
+ reader.ReadU16(&out->klass) &&
+ reader.ReadU32(&out->ttl) &&
+ reader.ReadU16(&rdlen) &&
+ reader.ReadPiece(&out->rdata, rdlen)) {
+ cur_ = reader.ptr();
+ return true;
+ }
+ return false;
+}
+
+DnsResponse::DnsResponse()
+ : io_buffer_(new IOBufferWithSize(dns_protocol::kMaxUDPSize + 1)) {
+}
+
+DnsResponse::DnsResponse(const void* data,
+ size_t length,
+ size_t answer_offset)
+ : io_buffer_(new IOBufferWithSize(length)),
+ parser_(io_buffer_->data(), length, answer_offset) {
+ memcpy(io_buffer_->data(), data, length);
}
DnsResponse::~DnsResponse() {
}
-int DnsResponse::Parse(int nbytes, IPAddressList* ip_addresses) {
+bool DnsResponse::InitParse(int nbytes, const DnsQuery& query) {
// Response includes query, it should be at least that size.
- if (nbytes < query_->io_buffer()->size() || nbytes > kMaxResponseSize)
- return ERR_DNS_MALFORMED_RESPONSE;
-
- DnsResponseBuffer response(reinterpret_cast<uint8*>(io_buffer_->data()),
- io_buffer_->size());
- uint16 id;
- if (!response.U16(&id) || id != query_->id()) // Make sure IDs match.
- return ERR_DNS_MALFORMED_RESPONSE;
-
- uint8 flags, rcode;
- if (!response.U8(&flags) || !response.U8(&rcode))
- return ERR_DNS_MALFORMED_RESPONSE;
-
- if (flags & 2) // TC is set -- server wants TCP, we don't support it (yet?).
- return ERR_DNS_SERVER_REQUIRES_TCP;
-
- rcode &= 0x0f; // 3 means NXDOMAIN, the rest means server failed.
- if (rcode && (rcode != 3))
- return ERR_DNS_SERVER_FAILED;
-
- uint16 query_count, answer_count, authority_count, additional_count;
- if (!response.U16(&query_count) ||
- !response.U16(&answer_count) ||
- !response.U16(&authority_count) ||
- !response.U16(&additional_count)) {
- return ERR_DNS_MALFORMED_RESPONSE;
+ if (nbytes < query.io_buffer()->size() || nbytes > dns_protocol::kMaxUDPSize)
+ return false;
+
+ // Match the query id.
+ if (ntohs(header()->id) != query.id())
+ return false;
+
+ // Match question count.
+ if (ntohs(header()->qdcount) != 1)
+ return false;
+
+ // Match the question section.
+ const size_t hdr_size = sizeof(dns_protocol::Header);
+ const base::StringPiece question = query.question();
+ if (question != base::StringPiece(io_buffer_->data() + hdr_size,
+ question.size())) {
+ return false;
}
- if (query_count != 1) // Sent a single question, shouldn't have changed.
- return ERR_DNS_MALFORMED_RESPONSE;
+ // Construct the parser.
+ parser_ = DnsRecordParser(io_buffer_->data(),
+ nbytes,
+ hdr_size + question.size());
+ return true;
+}
- base::StringPiece question; // Make sure question section is echoed back.
- if (!response.Block(&question, query_->question_size()) ||
- memcmp(question.data(), query_->question_data(),
- query_->question_size())) {
- return ERR_DNS_MALFORMED_RESPONSE;
- }
+uint8 DnsResponse::flags0() const {
+ return header()->flags[0];
+}
- if (answer_count < 1)
- return ERR_NAME_NOT_RESOLVED;
-
- IPAddressList rdatas;
- while (answer_count--) {
- uint32 ttl;
- uint16 rdlength, qtype, qclass;
- if (!response.DNSName(NULL) ||
- !response.U16(&qtype) ||
- !response.U16(&qclass) ||
- !response.U32(&ttl) ||
- !response.U16(&rdlength)) {
- return ERR_DNS_MALFORMED_RESPONSE;
- }
- if (qtype == query_->qtype() &&
- qclass == kClassIN &&
- (rdlength == kIPv4AddressSize || rdlength == kIPv6AddressSize)) {
- base::StringPiece rdata;
- if (!response.Block(&rdata, rdlength))
- return ERR_DNS_MALFORMED_RESPONSE;
- rdatas.push_back(IPAddressNumber(rdata.begin(), rdata.end()));
- } else if (!response.Skip(rdlength))
- return ERR_DNS_MALFORMED_RESPONSE;
- }
+uint8 DnsResponse::flags1() const {
+ return header()->flags[1] & ~(dns_protocol::kRcodeMask);
+}
+
+uint8 DnsResponse::rcode() const {
+ return header()->flags[1] & dns_protocol::kRcodeMask;
+}
- if (rdatas.empty())
- return ERR_NAME_NOT_RESOLVED;
+int DnsResponse::answer_count() const {
+ return ntohs(header()->ancount);
+}
+
+DnsRecordParser DnsResponse::Parser() const {
+ DCHECK(parser_.IsValid());
+ return parser_;
+}
- if (ip_addresses)
- ip_addresses->swap(rdatas);
- return OK;
+const dns_protocol::Header* DnsResponse::header() const {
+ return reinterpret_cast<const dns_protocol::Header*>(io_buffer_->data());
}
} // namespace net
diff --git a/net/dns/dns_response.h b/net/dns/dns_response.h
index cc7c3f7..0fa3df2 100644
--- a/net/dns/dns_response.h
+++ b/net/dns/dns_response.h
@@ -6,43 +6,109 @@
#define NET_DNS_DNS_RESPONSE_H_
#pragma once
+#include <string>
+
+#include "base/basictypes.h"
#include "base/memory/ref_counted.h"
+#include "base/string_piece.h"
#include "net/base/net_export.h"
#include "net/base/net_util.h"
-namespace net{
+namespace net {
class DnsQuery;
class IOBufferWithSize;
-// Represents on-the-wire DNS response as an object; allows extracting
-// records.
+namespace dns_protocol {
+struct Header;
+}
+
+// Parsed resource record.
+struct NET_EXPORT_PRIVATE DnsResourceRecord {
+ std::string name; // in dotted form
+ uint16 type;
+ uint16 klass;
+ uint32 ttl;
+ base::StringPiece rdata; // points to the original response buffer
+};
+
+// Iterator to walk over resource records of the DNS response packet.
+class NET_EXPORT_PRIVATE DnsRecordParser {
+ public:
+ // Construct an uninitialized iterator.
+ DnsRecordParser();
+
+ // Construct an iterator to process the |packet| of given |length|.
+ // |offset| points to the beginning of the answer section.
+ DnsRecordParser(const void* packet, size_t length, size_t offset);
+
+ // Returns |true| if initialized.
+ bool IsValid() const { return packet_ != NULL; }
+
+ // Returns |true| if no more bytes remain in the packet.
+ bool AtEnd() const { return cur_ == packet_ + length_; }
+
+ // Parses a (possibly compressed) DNS name from the packet starting at
+ // |pos|. Stores output (even partial) in |out| unless |out| is NULL. |out|
+ // is stored in the dotted form, e.g., "example.com". Returns number of bytes
+ // consumed or 0 on failure.
+ // This is exposed to allow parsing compressed names within RRDATA for TYPEs
+ // such as NS, CNAME, PTR, MX, SOA.
+ // See RFC 1035 section 4.1.4.
+ int ParseName(const void* pos, std::string* out) const;
+
+ // Parses the next resource record. Returns true if succeeded.
+ bool ParseRecord(DnsResourceRecord* record);
+
+ private:
+ const char* packet_;
+ size_t length_;
+ // Current offset within the packet.
+ const char* cur_;
+};
+
+// Buffer-holder for the DNS response allowing easy access to the header fields
+// and resource records. After reading into |io_buffer| must call InitParse to
+// position the RR parser.
class NET_EXPORT_PRIVATE DnsResponse {
public:
// Constructs an object with an IOBuffer large enough to read
// one byte more than largest possible response, to detect malformed
- // responses; |query| is a pointer to the DnsQuery for which |this|
- // is supposed to be a response.
- explicit DnsResponse(DnsQuery* query);
+ // responses.
+ DnsResponse();
+ // Constructs response from |data|. Used for testing purposes only!
+ DnsResponse(const void* data, size_t length, size_t answer_offset);
~DnsResponse();
// Internal buffer accessor into which actual bytes of response will be
// read.
IOBufferWithSize* io_buffer() { return io_buffer_.get(); }
- // Parses response of size nbytes and puts address into |ip_addresses|,
- // returns net_error code in case of failure.
- int Parse(int nbytes, IPAddressList* ip_addresses);
+ // Returns false if the packet is shorter than the header or does not match
+ // |query| id or question.
+ bool InitParse(int nbytes, const DnsQuery& query);
+
+ // Accessors for the header.
+ uint8 flags0() const; // first byte of flags
+ uint8 flags1() const; // second byte of flags excluding rcode
+ uint8 rcode() const;
+ int answer_count() const;
+
+ // Returns an iterator to the resource records in the answer section. Must be
+ // called after InitParse. The iterator is valid only in the scope of the
+ // DnsResponse.
+ DnsRecordParser Parser() const;
private:
- // The matching query; |this| is the response for |query_|. We do not
- // own it, lifetime of |this| should be within the limits of lifetime of
- // |query_|.
- const DnsQuery* const query_;
+ // Convenience for header access.
+ const dns_protocol::Header* header() const;
// Buffer into which response bytes are read.
scoped_refptr<IOBufferWithSize> io_buffer_;
+ // Iterator constructed after InitParse positioned at the answer section.
+ DnsRecordParser parser_;
+
DISALLOW_COPY_AND_ASSIGN(DnsResponse);
};
diff --git a/net/dns/dns_response_unittest.cc b/net/dns/dns_response_unittest.cc
index 775cfc6..51d7cff 100644
--- a/net/dns/dns_response_unittest.cc
+++ b/net/dns/dns_response_unittest.cc
@@ -4,106 +4,176 @@
#include "net/dns/dns_response.h"
-#include "base/bind.h"
-#include "base/rand_util.h"
-#include "net/base/dns_util.h"
-#include "net/base/net_errors.h"
#include "net/base/io_buffer.h"
+#include "net/dns/dns_protocol.h"
#include "net/dns/dns_query.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
-// DNS response consists of a header followed by a question followed by
-// answer. Header format, question format and response format are
-// described below. For the meaning of specific fields, please see RFC
-// 1035.
-
-// Header format.
-// 1 1 1 1 1 1
-// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | ID |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// |QR| Opcode |AA|TC|RD|RA| Z | RCODE |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | QDCOUNT |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | ANCOUNT |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | NSCOUNT |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | ARCOUNT |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-
-// Question format.
-// 1 1 1 1 1 1
-// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | |
-// / QNAME /
-// / /
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | QTYPE |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | QCLASS |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-
-// Answser format.
-// 1 1 1 1 1 1
-// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | |
-// / /
-// / NAME /
-// | |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | TYPE |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | CLASS |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | TTL |
-// | |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-// | RDLENGTH |
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
-// / RDATA /
-// / /
-// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-
-// TODO(agayev): add more thorough tests.
-TEST(DnsResponseTest, ResponseWithCnameA) {
- const std::string kQname("\012codereview\010chromium\003org", 25);
- DnsQuery q1(kQname, kDNS_A, base::Bind(&base::RandInt));
-
- uint8 id_hi = q1.id() >> 8, id_lo = q1.id() & 0xff;
-
- uint8 ip[] = { // codereview.chromium.org resolves to
- 0x4a, 0x7d, 0x5f, 0x79 // 74.125.95.121
+namespace {
+
+TEST(DnsRecordParserTest, Constructor) {
+ const char data[] = { 0 };
+
+ EXPECT_FALSE(DnsRecordParser().IsValid());
+ EXPECT_TRUE(DnsRecordParser(data, 1, 0).IsValid());
+ EXPECT_TRUE(DnsRecordParser(data, 1, 1).IsValid());
+
+ EXPECT_FALSE(DnsRecordParser(data, 1, 0).AtEnd());
+ EXPECT_TRUE(DnsRecordParser(data, 1, 1).AtEnd());
+}
+
+TEST(DnsRecordParserTest, ParseName) {
+ const uint8 data[] = {
+ // all labels "foo.example.com"
+ 0x03, 'f', 'o', 'o',
+ 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ // byte 0x10
+ 0x00,
+ // byte 0x11
+ // part label, part pointer, "bar.example.com"
+ 0x03, 'b', 'a', 'r',
+ 0xc0, 0x04,
+ // byte 0x17
+ // all pointer to "bar.example.com", 2 jumps
+ 0xc0, 0x11,
+ // byte 0x1a
+ };
+
+ std::string out;
+ DnsRecordParser parser(data, sizeof(data), 0);
+ ASSERT_TRUE(parser.IsValid());
+
+ EXPECT_EQ(0x11, parser.ParseName(data + 0x00, &out));
+ EXPECT_EQ("foo.example.com", out);
+ // Check that the last "." is never stored.
+ out.clear();
+ EXPECT_EQ(0x1, parser.ParseName(data + 0x10, &out));
+ EXPECT_EQ("", out);
+ out.clear();
+ EXPECT_EQ(0x6, parser.ParseName(data + 0x11, &out));
+ EXPECT_EQ("bar.example.com", out);
+ out.clear();
+ EXPECT_EQ(0x2, parser.ParseName(data + 0x17, &out));
+ EXPECT_EQ("bar.example.com", out);
+
+ // Parse name without storing it.
+ EXPECT_EQ(0x11, parser.ParseName(data + 0x00, NULL));
+ EXPECT_EQ(0x1, parser.ParseName(data + 0x10, NULL));
+ EXPECT_EQ(0x6, parser.ParseName(data + 0x11, NULL));
+ EXPECT_EQ(0x2, parser.ParseName(data + 0x17, NULL));
+
+ // Check that it works even if initial position is different.
+ parser = DnsRecordParser(data, sizeof(data), 0x12);
+ EXPECT_EQ(0x6, parser.ParseName(data + 0x11, NULL));
+}
+
+TEST(DnsRecordParserTest, ParseNameFail) {
+ const uint8 data[] = {
+ // label length beyond packet
+ 0x30, 'x', 'x',
+ 0x00,
+ // pointer offset beyond packet
+ 0xc0, 0x20,
+ // pointer loop
+ 0xc0, 0x08,
+ 0xc0, 0x06,
+ // incorrect label type (currently supports only direct and pointer)
+ 0x80, 0x00,
+ // truncated name (missing root label)
+ 0x02, 'x', 'x',
+ };
+
+ DnsRecordParser parser(data, sizeof(data), 0);
+ ASSERT_TRUE(parser.IsValid());
+
+ std::string out;
+ EXPECT_EQ(0, parser.ParseName(data + 0x00, &out));
+ EXPECT_EQ(0, parser.ParseName(data + 0x04, &out));
+ EXPECT_EQ(0, parser.ParseName(data + 0x08, &out));
+ EXPECT_EQ(0, parser.ParseName(data + 0x0a, &out));
+ EXPECT_EQ(0, parser.ParseName(data + 0x0c, &out));
+ EXPECT_EQ(0, parser.ParseName(data + 0x0e, &out));
+}
+
+TEST(DnsRecordParserTest, ParseRecord) {
+ const uint8 data[] = {
+ // Type CNAME record.
+ 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ 0x00,
+ 0x00, 0x05, // TYPE is CNAME.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, 0x24, 0x74, // TTL is 0x00012474.
+ 0x00, 0x06, // RDLENGTH is 6 bytes.
+ 0x03, 'f', 'o', 'o', // compressed name in record
+ 0xc0, 0x00,
+ // Type A record.
+ 0x03, 'b', 'a', 'r', // compressed owner name
+ 0xc0, 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x20, 0x13, 0x55, // TTL is 0x00201355.
+ 0x00, 0x04, // RDLENGTH is 4 bytes.
+ 0x7f, 0x02, 0x04, 0x01, // IP is 127.2.4.1
};
- IPAddressList expected_ips;
- expected_ips.push_back(IPAddressNumber(ip, ip + arraysize(ip)));
+ std::string out;
+ DnsRecordParser parser(data, sizeof(data), 0);
+
+ DnsResourceRecord record;
+ EXPECT_TRUE(parser.ParseRecord(&record));
+ EXPECT_EQ("example.com", record.name);
+ EXPECT_EQ(dns_protocol::kTypeCNAME, record.type);
+ EXPECT_EQ(dns_protocol::kClassIN, record.klass);
+ EXPECT_EQ(0x00012474u, record.ttl);
+ EXPECT_EQ(6u, record.rdata.length());
+ EXPECT_EQ(6, parser.ParseName(record.rdata.data(), &out));
+ EXPECT_EQ("foo.example.com", out);
+ EXPECT_FALSE(parser.AtEnd());
+
+ EXPECT_TRUE(parser.ParseRecord(&record));
+ EXPECT_EQ("bar.example.com", record.name);
+ EXPECT_EQ(dns_protocol::kTypeA, record.type);
+ EXPECT_EQ(dns_protocol::kClassIN, record.klass);
+ EXPECT_EQ(0x00201355u, record.ttl);
+ EXPECT_EQ(4u, record.rdata.length());
+ EXPECT_EQ(base::StringPiece("\x7f\x02\x04\x01"), record.rdata);
+ EXPECT_TRUE(parser.AtEnd());
+
+ // Test truncated record.
+ parser = DnsRecordParser(data, sizeof(data) - 2, 0);
+ EXPECT_TRUE(parser.ParseRecord(&record));
+ EXPECT_FALSE(parser.AtEnd());
+ EXPECT_FALSE(parser.ParseRecord(&record));
+}
+
+TEST(DnsResponseTest, InitParse) {
+ // This includes \0 at the end.
+ const char qname_data[] = "\x0A""codereview""\x08""chromium""\x03""org";
+ const base::StringPiece qname(qname_data, sizeof(qname_data));
+ // Compilers want to copy when binding temporary to const &, so must use heap.
+ scoped_ptr<DnsQuery> query(new DnsQuery(0xcafe, qname, dns_protocol::kTypeA));
- uint8 response_data[] = {
+ const uint8 response_data[] = {
// Header
- id_hi, id_lo, // ID
- 0x81, 0x80, // Standard query response, no error
+ 0xca, 0xfe, // ID
+ 0x81, 0x80, // Standard query response, RA, no error
0x00, 0x01, // 1 question
0x00, 0x02, // 2 RRs (answers)
0x00, 0x00, // 0 authority RRs
0x00, 0x00, // 0 additional RRs
// Question
- 0x0a, 0x63, 0x6f, 0x64, // This part is echoed back from the
- 0x65, 0x72, 0x65, 0x76, // respective query.
- 0x69, 0x65, 0x77, 0x08,
- 0x63, 0x68, 0x72, 0x6f,
- 0x6d, 0x69, 0x75, 0x6d,
- 0x03, 0x6f, 0x72, 0x67,
+ // This part is echoed back from the respective query.
+ 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w',
+ 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm',
+ 0x03, 'o', 'r', 'g',
0x00,
- 0x00, 0x01,
- 0x00, 0x01,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
// Answer 1
0xc0, 0x0c, // NAME is a pointer to name in Question section.
@@ -111,33 +181,56 @@ TEST(DnsResponseTest, ResponseWithCnameA) {
0x00, 0x01, // CLASS is IN.
0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
0x24, 0x74,
- 0x00, 0x12, // RDLENGTH is 18 bytse.
- 0x03, 0x67, 0x68, 0x73, // ghs.l.google.com in DNS format.
- 0x01, 0x6c, 0x06, 0x67,
- 0x6f, 0x6f, 0x67, 0x6c,
- 0x65, 0x03, 0x63, 0x6f,
- 0x6d, 0x00,
+ 0x00, 0x12, // RDLENGTH is 18 bytes.
+ // ghs.l.google.com in DNS format.
+ 0x03, 'g', 'h', 's',
+ 0x01, 'l',
+ 0x06, 'g', 'o', 'o', 'g', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ 0x00,
// Answer 2
- 0xc0, 0x35, // NAME is a pointer to name in Question section.
+ 0xc0, 0x35, // NAME is a pointer to name in Answer 1.
0x00, 0x01, // TYPE is A.
0x00, 0x01, // CLASS is IN.
0x00, 0x00, // TTL (4 bytes) is 53 seconds.
0x00, 0x35,
- 0x00, 0x04, // RDLENGTH is 4 bytes.
- ip[0], ip[1], ip[2], ip[3], // RDATA is the IP.
+ 0x00, 0x04, // RDLENGTH is 4 bytes.
+ 0x4a, 0x7d, // RDATA is the IP: 74.125.95.121
+ 0x5f, 0x79,
};
- // Create a response object and simulate reading into it.
- DnsResponse r1(&q1);
- memcpy(r1.io_buffer()->data(), &response_data[0],
- r1.io_buffer()->size());
+ DnsResponse resp;
+ memcpy(resp.io_buffer()->data(), response_data, sizeof(response_data));
- // Verify resolved IPs.
- int response_size = arraysize(response_data);
- IPAddressList actual_ips;
- EXPECT_EQ(OK, r1.Parse(response_size, &actual_ips));
- EXPECT_EQ(expected_ips, actual_ips);
+ // Reject too short.
+ EXPECT_FALSE(resp.InitParse(query->io_buffer()->size() - 1, *query));
+
+ // Reject wrong id.
+ scoped_ptr<DnsQuery> other_query(query->CloneWithNewId(0xbeef));
+ EXPECT_FALSE(resp.InitParse(sizeof(response_data), *other_query));
+
+ // Reject wrong question.
+ scoped_ptr<DnsQuery> wrong_query(
+ new DnsQuery(0xcafe, qname, dns_protocol::kTypeCNAME));
+ EXPECT_FALSE(resp.InitParse(sizeof(response_data), *wrong_query));
+
+ // Accept matching question.
+ EXPECT_TRUE(resp.InitParse(sizeof(response_data), *query));
+
+ // Check header access.
+ EXPECT_EQ(0x81, resp.flags0());
+ EXPECT_EQ(0x80, resp.flags1());
+ EXPECT_EQ(0x0, resp.rcode());
+ EXPECT_EQ(2, resp.answer_count());
+
+ DnsResourceRecord record;
+ DnsRecordParser parser = resp.Parser();
+ EXPECT_TRUE(parser.ParseRecord(&record));
+ EXPECT_TRUE(parser.ParseRecord(&record));
+ EXPECT_FALSE(parser.ParseRecord(&record));
}
+} // namespace
+
} // namespace net
diff --git a/net/dns/dns_session.cc b/net/dns/dns_session.cc
new file mode 100644
index 0000000..15cbaa2
--- /dev/null
+++ b/net/dns/dns_session.cc
@@ -0,0 +1,47 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_session.h"
+
+#include "base/basictypes.h"
+#include "base/bind.h"
+#include "base/time.h"
+#include "net/base/ip_endpoint.h"
+#include "net/dns/dns_config_service.h"
+#include "net/socket/client_socket_factory.h"
+
+namespace net {
+
+DnsSession::DnsSession(const DnsConfig& config,
+ ClientSocketFactory* factory,
+ const RandIntCallback& rand_int_callback,
+ NetLog* net_log)
+ : config_(config),
+ socket_factory_(factory),
+ rand_callback_(base::Bind(rand_int_callback, 0, kuint16max)),
+ net_log_(net_log),
+ server_index_(0) {
+}
+
+int DnsSession::NextId() const {
+ return rand_callback_.Run();
+}
+
+const IPEndPoint& DnsSession::NextServer() {
+ // TODO(szym): Rotate servers on failures.
+ const IPEndPoint& ipe = config_.nameservers[server_index_];
+ if (config_.rotate)
+ server_index_ = (server_index_ + 1) % config_.nameservers.size();
+ return ipe;
+}
+
+base::TimeDelta DnsSession::NextTimeout(int attempt) {
+ // TODO(szym): Adapt timeout to observed RTT.
+ return config_.timeout * (attempt + 1);
+}
+
+DnsSession::~DnsSession() {}
+
+} // namespace net
+
diff --git a/net/dns/dns_session.h b/net/dns/dns_session.h
new file mode 100644
index 0000000..df986ea
--- /dev/null
+++ b/net/dns/dns_session.h
@@ -0,0 +1,70 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_SESSION_H_
+#define NET_DNS_DNS_SESSION_H_
+#pragma once
+
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/time.h"
+#include "net/base/net_export.h"
+#include "net/base/rand_callback.h"
+#include "net/dns/dns_config_service.h"
+
+namespace net {
+
+class ClientSocketFactory;
+class NetLog;
+
+// Session parameters and state shared between DNS transactions.
+// Ref-counted so that DnsClient::Request can keep working in absence of
+// DnsClient. A DnsSession must be recreated when DnsConfig changes.
+class NET_EXPORT_PRIVATE DnsSession
+ : NON_EXPORTED_BASE(public base::RefCounted<DnsSession>) {
+ public:
+ typedef base::Callback<int()> RandCallback;
+
+ DnsSession(const DnsConfig& config,
+ ClientSocketFactory* factory,
+ const RandIntCallback& rand_int_callback,
+ NetLog* net_log);
+
+ ClientSocketFactory* socket_factory() const { return socket_factory_.get(); }
+ const DnsConfig& config() const { return config_; }
+ NetLog* net_log() const { return net_log_; }
+
+ // Return the next random query ID.
+ int NextId() const;
+
+ // Return the next server address.
+ const IPEndPoint& NextServer();
+
+ // Return the timeout for the next transaction.
+ base::TimeDelta NextTimeout(int attempt);
+
+ private:
+ friend class base::RefCounted<DnsSession>;
+ ~DnsSession();
+
+ const DnsConfig config_;
+ scoped_ptr<ClientSocketFactory> socket_factory_;
+ RandCallback rand_callback_;
+ NetLog* net_log_;
+
+ // Current index into |config_.nameservers|.
+ int server_index_;
+
+ // TODO(szym): add current RTT estimate
+ // TODO(szym): add flag to indicate DNSSEC is supported
+ // TODO(szym): add TCP connection pool to support DNS over TCP
+ // TODO(szym): add UDP socket pool ?
+
+ DISALLOW_COPY_AND_ASSIGN(DnsSession);
+};
+
+} // namespace net
+
+#endif // NET_DNS_DNS_SESSION_H_
+
diff --git a/net/dns/dns_test_util.cc b/net/dns/dns_test_util.cc
index ac63131..a396252 100644
--- a/net/dns/dns_test_util.cc
+++ b/net/dns/dns_test_util.cc
@@ -8,20 +8,6 @@
namespace net {
-TestPrng::TestPrng(const std::deque<int>& numbers) : numbers_(numbers) {
-}
-
-TestPrng::~TestPrng() {
-}
-
-int TestPrng::GetNext(int min, int max) {
- DCHECK(!numbers_.empty());
- int rv = numbers_.front();
- numbers_.pop_front();
- DCHECK(rv >= min && rv <= max);
- return rv;
-}
-
bool ConvertStringsToIPAddressList(
const char* const ip_strings[], size_t size, IPAddressList* address_list) {
DCHECK(address_list);
diff --git a/net/dns/dns_test_util.h b/net/dns/dns_test_util.h
index 651cfcb..c59a58f 100644
--- a/net/dns/dns_test_util.h
+++ b/net/dns/dns_test_util.h
@@ -14,26 +14,10 @@
#include "net/base/host_resolver.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_util.h"
+#include "net/dns/dns_protocol.h"
namespace net {
-// DNS related classes make use of PRNG for various tasks. This class is
-// used as a PRNG for unit testing those tasks. It takes a deque of
-// integers |numbers| which should be returned by calls to GetNext.
-class TestPrng {
- public:
- explicit TestPrng(const std::deque<int>& numbers);
- ~TestPrng();
-
- // Pops and returns the next number from |numbers_| deque.
- int GetNext(int min, int max);
-
- private:
- std::deque<int> numbers_;
-
- DISALLOW_COPY_AND_ASSIGN(TestPrng);
-};
-
// A utility function for tests that given an array of IP literals,
// converts it to an IPAddressList.
bool ConvertStringsToIPAddressList(
@@ -49,7 +33,7 @@ static const uint16 kDnsPort = 53;
//-----------------------------------------------------------------------------
// Query/response set for www.google.com, ID is fixed to 0.
static const char kT0HostName[] = "www.google.com";
-static const uint16 kT0Qtype = kDNS_A;
+static const uint16 kT0Qtype = dns_protocol::kTypeA;
static const char kT0DnsName[] = {
0x03, 'w', 'w', 'w',
0x06, 'g', 'o', 'o', 'g', 'l', 'e',
@@ -92,10 +76,10 @@ static const char* const kT0IpAddresses[] = {
//-----------------------------------------------------------------------------
// Query/response set for codereview.chromium.org, ID is fixed to 1.
static const char kT1HostName[] = "codereview.chromium.org";
-static const uint16 kT1Qtype = kDNS_A;
+static const uint16 kT1Qtype = dns_protocol::kTypeA;
static const char kT1DnsName[] = {
- 0x12, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w',
- 0x10, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm',
+ 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w',
+ 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm',
0x03, 'o', 'r', 'g',
0x00
};
@@ -130,10 +114,10 @@ static const char* const kT1IpAddresses[] = {
//-----------------------------------------------------------------------------
// Query/response set for www.ccs.neu.edu, ID is fixed to 2.
static const char kT2HostName[] = "www.ccs.neu.edu";
-static const uint16 kT2Qtype = kDNS_A;
+static const uint16 kT2Qtype = dns_protocol::kTypeA;
static const char kT2DnsName[] = {
0x03, 'w', 'w', 'w',
- 0x03, 'c', 'c', 'c',
+ 0x03, 'c', 'c', 's',
0x03, 'n', 'e', 'u',
0x03, 'e', 'd', 'u',
0x00
@@ -166,7 +150,7 @@ static const char* const kT2IpAddresses[] = {
//-----------------------------------------------------------------------------
// Query/response set for www.google.az, ID is fixed to 3.
static const char kT3HostName[] = "www.google.az";
-static const uint16 kT3Qtype = kDNS_A;
+static const uint16 kT3Qtype = dns_protocol::kTypeA;
static const char kT3DnsName[] = {
0x03, 'w', 'w', 'w',
0x06, 'g', 'o', 'o', 'g', 'l', 'e',
diff --git a/net/dns/dns_transaction.cc b/net/dns/dns_transaction.cc
index a7fa922..5cbe39a 100644
--- a/net/dns/dns_transaction.cc
+++ b/net/dns/dns_transaction.cc
@@ -7,11 +7,12 @@
#include "base/bind.h"
#include "base/rand_util.h"
#include "base/values.h"
-#include "net/base/dns_util.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
+#include "net/dns/dns_protocol.h"
#include "net/dns/dns_query.h"
#include "net/dns/dns_response.h"
+#include "net/dns/dns_session.h"
#include "net/socket/client_socket_factory.h"
#include "net/udp/datagram_client_socket.h"
@@ -19,104 +20,51 @@ namespace net {
namespace {
-// Retry timeouts.
-const int kTimeoutsMs[] = {3000, 5000, 11000};
-const int kMaxAttempts = arraysize(kTimeoutsMs);
-
-// Returns the string representation of an IPAddressNumber.
-std::string IPAddressToString(const IPAddressNumber& ip_address) {
- IPEndPoint ip_endpoint(ip_address, 0);
- struct sockaddr_storage addr;
- size_t addr_len = sizeof(addr);
- struct sockaddr* sockaddr = reinterpret_cast<struct sockaddr*>(&addr);
- if (!ip_endpoint.ToSockAddr(sockaddr, &addr_len))
- return "";
- return NetAddressToString(sockaddr, addr_len);
-}
-
-}
-
-DnsTransaction::Delegate::Delegate() {
-}
-
-DnsTransaction::Delegate::~Delegate() {
- while (!registered_transactions_.empty()) {
- DnsTransaction* transaction = *registered_transactions_.begin();
- transaction->SetDelegate(NULL);
- }
- DCHECK(registered_transactions_.empty());
-}
-
-void DnsTransaction::Delegate::OnTransactionComplete(
- int result,
- const DnsTransaction* transaction,
- const IPAddressList& ip_addresses) {
-}
-
-void DnsTransaction::Delegate::Attach(DnsTransaction* transaction) {
- DCHECK(registered_transactions_.find(transaction) ==
- registered_transactions_.end());
- registered_transactions_.insert(transaction);
-}
-
-void DnsTransaction::Delegate::Detach(DnsTransaction* transaction) {
- DCHECK(registered_transactions_.find(transaction) !=
- registered_transactions_.end());
- registered_transactions_.erase(transaction);
-}
-
-namespace {
-
class DnsTransactionStartParameters : public NetLog::EventParameters {
public:
DnsTransactionStartParameters(const IPEndPoint& dns_server,
- const DnsTransaction::Key& key,
+ const base::StringPiece& qname,
+ uint16 qtype,
const NetLog::Source& source)
- : dns_server_(dns_server), key_(key), source_(source) {}
+ : dns_server_(dns_server),
+ qname_(qname.data(), qname.length()),
+ qtype_(qtype),
+ source_(source) {}
virtual Value* ToValue() const {
- std::string hostname;
- DnsResponseBuffer(
- reinterpret_cast<const uint8*>(key_.first.c_str()), key_.first.size()).
- DNSName(&hostname);
-
DictionaryValue* dict = new DictionaryValue();
dict->SetString("dns_server", dns_server_.ToString());
- dict->SetString("hostname", hostname);
- dict->SetInteger("query_type", key_.second);
+ dict->SetString("hostname", qname_);
+ dict->SetInteger("query_type", qtype_);
if (source_.is_valid())
dict->Set("source_dependency", source_.ToValue());
return dict;
}
private:
- const IPEndPoint dns_server_;
- const DnsTransaction::Key key_;
+ IPEndPoint dns_server_;
+ std::string qname_;
+ uint16 qtype_;
const NetLog::Source source_;
};
class DnsTransactionFinishParameters : public NetLog::EventParameters {
public:
- DnsTransactionFinishParameters(int net_error,
- const IPAddressList& ip_address_list)
- : net_error_(net_error), ip_address_list_(ip_address_list) {}
+ // TODO(szym): add rcode ?
+ DnsTransactionFinishParameters(int net_error, int answer_count)
+ : net_error_(net_error), answer_count_(answer_count) {}
virtual Value* ToValue() const {
- ListValue* list = new ListValue();
- for (IPAddressList::const_iterator it = ip_address_list_.begin();
- it != ip_address_list_.end(); ++it)
- list->Append(Value::CreateStringValue(IPAddressToString(*it)));
-
DictionaryValue* dict = new DictionaryValue();
if (net_error_)
dict->SetInteger("net_error", net_error_);
- dict->Set("address_list", list);
+ dict->SetInteger("answer_count", answer_count_);
return dict;
}
private:
const int net_error_;
- const IPAddressList ip_address_list_;
+ const int answer_count_;
};
class DnsTransactionRetryParameters : public NetLog::EventParameters {
@@ -139,47 +87,32 @@ class DnsTransactionRetryParameters : public NetLog::EventParameters {
} // namespace
-DnsTransaction::DnsTransaction(const IPEndPoint& dns_server,
- const std::string& dns_name,
- uint16 query_type,
- const RandIntCallback& rand_int,
- ClientSocketFactory* socket_factory,
- const BoundNetLog& source_net_log,
- NetLog* net_log)
- : dns_server_(dns_server),
- key_(dns_name, query_type),
- delegate_(NULL),
- query_(new DnsQuery(dns_name, query_type, rand_int)),
+
+DnsTransaction::DnsTransaction(DnsSession* session,
+ const base::StringPiece& qname,
+ uint16 qtype,
+ const ResultCallback& callback,
+ const BoundNetLog& source_net_log)
+ : session_(session),
+ dns_server_(session->NextServer()),
+ query_(new DnsQuery(session->NextId(), qname, qtype)),
+ callback_(callback),
attempts_(0),
next_state_(STATE_NONE),
- socket_factory_(socket_factory ? socket_factory :
- ClientSocketFactory::GetDefaultFactory()),
ALLOW_THIS_IN_INITIALIZER_LIST(
io_callback_(this, &DnsTransaction::OnIOComplete)),
- net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_DNS_TRANSACTION)) {
- DCHECK(!rand_int.is_null());
- for (size_t i = 0; i < arraysize(kTimeoutsMs); ++i)
- timeouts_ms_.push_back(base::TimeDelta::FromMilliseconds(kTimeoutsMs[i]));
+ net_log_(BoundNetLog::Make(session->net_log(),
+ NetLog::SOURCE_DNS_TRANSACTION)) {
net_log_.BeginEvent(
NetLog::TYPE_DNS_TRANSACTION,
make_scoped_refptr(
- new DnsTransactionStartParameters(dns_server_, key_,
+ new DnsTransactionStartParameters(dns_server_,
+ qname,
+ qtype,
source_net_log.source())));
}
-DnsTransaction::~DnsTransaction() {
- SetDelegate(NULL);
-}
-
-void DnsTransaction::SetDelegate(Delegate* delegate) {
- if (delegate == delegate_)
- return;
- if (delegate_)
- delegate_->Detach(this);
- delegate_ = delegate;
- if (delegate_)
- delegate_->Attach(this);
-}
+DnsTransaction::~DnsTransaction() {}
int DnsTransaction::Start() {
DCHECK_EQ(STATE_NONE, next_state_);
@@ -223,12 +156,12 @@ int DnsTransaction::DoLoop(int result) {
void DnsTransaction::DoCallback(int result) {
DCHECK_NE(result, ERR_IO_PENDING);
+ int answer_count = (result == OK) ? response()->answer_count() : 0;
net_log_.EndEvent(
NetLog::TYPE_DNS_TRANSACTION,
make_scoped_refptr(
- new DnsTransactionFinishParameters(result, ip_addresses_)));
- if (delegate_)
- delegate_->OnTransactionComplete(result, this, ip_addresses_);
+ new DnsTransactionFinishParameters(result, answer_count)));
+ callback_.Run(this, result);
}
void DnsTransaction::OnIOComplete(int result) {
@@ -240,13 +173,14 @@ void DnsTransaction::OnIOComplete(int result) {
int DnsTransaction::DoConnect() {
next_state_ = STATE_CONNECT_COMPLETE;
- DCHECK_LT(attempts_, timeouts_ms_.size());
- StartTimer(timeouts_ms_[attempts_]);
- attempts_++;
+ StartTimer(session_->NextTimeout(attempts_));
+ ++attempts_;
- // TODO(agayev): keep all sockets around in case the server responds
+ // TODO(szym): keep all sockets around in case the server responds
// after its timeout; state machine will need to change to handle that.
- socket_.reset(socket_factory_->CreateDatagramClientSocket(
+ // The current plan is to move socket management out to DnsSession.
+ // Hence also move retransmissions to DnsClient::Request.
+ socket_.reset(session_->socket_factory()->CreateDatagramClientSocket(
DatagramSocket::RANDOM_BIND,
base::Bind(&base::RandInt),
net_log_.net_log(),
@@ -281,7 +215,7 @@ int DnsTransaction::DoSendQueryComplete(int rv) {
// Writing to UDP should not result in a partial datagram.
if (rv != query_->io_buffer()->size())
- return ERR_NAME_NOT_RESOLVED;
+ return ERR_MSG_TOO_BIG;
next_state_ = STATE_READ_RESPONSE;
return OK;
@@ -289,7 +223,7 @@ int DnsTransaction::DoSendQueryComplete(int rv) {
int DnsTransaction::DoReadResponse() {
next_state_ = STATE_READ_RESPONSE_COMPLETE;
- response_.reset(new DnsResponse(query_.get()));
+ response_.reset(new DnsResponse());
return socket_->Read(response_->io_buffer(),
response_->io_buffer()->size(),
&io_callback_);
@@ -302,9 +236,21 @@ int DnsTransaction::DoReadResponseComplete(int rv) {
return rv;
DCHECK(rv);
- // TODO(agayev): when supporting EDNS0 we may need to do multiple reads
- // to read the whole response.
- return response_->Parse(rv, &ip_addresses_);
+ if (!response_->InitParse(rv, *query_))
+ return ERR_DNS_MALFORMED_RESPONSE;
+ // TODO(szym): define this flag value in dns_protocol
+ if (response_->flags1() & 2)
+ return ERR_DNS_SERVER_REQUIRES_TCP;
+ // TODO(szym): move this handling out of DnsTransaction?
+ if (response_->rcode() != dns_protocol::kRcodeNOERROR &&
+ response_->rcode() != dns_protocol::kRcodeNXDOMAIN) {
+ return ERR_DNS_SERVER_FAILED;
+ }
+ // TODO(szym): add ERR_DNS_RR_NOT_FOUND?
+ if (response_->answer_count() == 0)
+ return ERR_NAME_NOT_RESOLVED;
+
+ return OK;
}
void DnsTransaction::StartTimer(base::TimeDelta delay) {
@@ -318,21 +264,15 @@ void DnsTransaction::RevokeTimer() {
void DnsTransaction::OnTimeout() {
DCHECK(next_state_ == STATE_SEND_QUERY_COMPLETE ||
next_state_ == STATE_READ_RESPONSE_COMPLETE);
- if (attempts_ == timeouts_ms_.size()) {
+ if (attempts_ == session_->config().attempts) {
DoCallback(ERR_DNS_TIMED_OUT);
return;
}
next_state_ = STATE_CONNECT;
- query_.reset(query_->CloneWithNewId());
+ query_.reset(query_->CloneWithNewId(session_->NextId()));
int rv = DoLoop(OK);
if (rv != ERR_IO_PENDING)
DoCallback(rv);
}
-void DnsTransaction::set_timeouts_ms(
- const std::vector<base::TimeDelta>& timeouts_ms) {
- DCHECK_EQ(0u, attempts_);
- timeouts_ms_ = timeouts_ms;
-}
-
} // namespace net
diff --git a/net/dns/dns_transaction.h b/net/dns/dns_transaction.h
index c126193..d4078f0 100644
--- a/net/dns/dns_transaction.h
+++ b/net/dns/dns_transaction.h
@@ -6,12 +6,10 @@
#define NET_DNS_DNS_TRANSACTION_H_
#pragma once
-#include <set>
#include <string>
-#include <utility>
#include <vector>
-#include "base/gtest_prod_util.h"
+#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "base/timer.h"
#include "base/threading/non_thread_safe.h"
@@ -23,70 +21,39 @@
namespace net {
-class ClientSocketFactory;
class DatagramClientSocket;
class DnsQuery;
class DnsResponse;
+class DnsSession;
-// Performs (with fixed retries) a single asynchronous DNS transaction,
+// Performs a single asynchronous DNS transaction over UDP,
// which consists of sending out a DNS query, waiting for a response, and
-// parsing and returning the IP addresses that it matches.
+// returning the response that it matches.
class NET_EXPORT_PRIVATE DnsTransaction :
NON_EXPORTED_BASE(public base::NonThreadSafe) {
public:
- typedef std::pair<std::string, uint16> Key;
-
- // Interface that should be implemented by DnsTransaction consumers and
- // passed to the |Start| method to be notified when the transaction has
- // completed.
- class NET_EXPORT_PRIVATE Delegate {
- public:
- Delegate();
- virtual ~Delegate();
-
- // A consumer of DnsTransaction should override |OnTransactionComplete|
- // and call |set_delegate(this)|. The method will be called once the
- // resolution has completed, results passed in as arguments.
- virtual void OnTransactionComplete(
- int result,
- const DnsTransaction* transaction,
- const IPAddressList& ip_addresses);
-
- private:
- friend class DnsTransaction;
-
- void Attach(DnsTransaction* transaction);
- void Detach(DnsTransaction* transaction);
-
- std::set<DnsTransaction*> registered_transactions_;
+ typedef base::Callback<void(DnsTransaction*, int)> ResultCallback;
+
+ // Create new transaction using the parameters and state in |session|.
+ // Issues query for name |qname| (in DNS format) type |qtype| and class IN.
+ // Calls |callback| on completion or timeout.
+ // TODO(szym): change dependency to (IPEndPoint, Socket, DnsQuery, callback)
+ DnsTransaction(DnsSession* session,
+ const base::StringPiece& qname,
+ uint16 qtype,
+ const ResultCallback& callback,
+ const BoundNetLog& source_net_log);
+ ~DnsTransaction();
- DISALLOW_COPY_AND_ASSIGN(Delegate);
- };
+ const DnsQuery* query() const { return query_.get(); }
- // |dns_server| is the address of the DNS server, |dns_name| is the
- // hostname (in DNS format) to be resolved, |query_type| is the type of
- // the query, either kDNS_A or kDNS_AAAA, |rand_int| is the PRNG used for
- // generating DNS query.
- DnsTransaction(const IPEndPoint& dns_server,
- const std::string& dns_name,
- uint16 query_type,
- const RandIntCallback& rand_int,
- ClientSocketFactory* socket_factory,
- const BoundNetLog& source_net_log,
- NetLog* net_log);
- ~DnsTransaction();
- void SetDelegate(Delegate* delegate);
- const Key& key() const { return key_; }
+ const DnsResponse* response() const { return response_.get(); }
// Starts the resolution process. Will return ERR_IO_PENDING and will
// notify the caller via |delegate|. Should only be called once.
int Start();
private:
- FRIEND_TEST_ALL_PREFIXES(DnsTransactionTest, FirstTimeoutTest);
- FRIEND_TEST_ALL_PREFIXES(DnsTransactionTest, SecondTimeoutTest);
- FRIEND_TEST_ALL_PREFIXES(DnsTransactionTest, ThirdTimeoutTest);
-
enum State {
STATE_CONNECT,
STATE_CONNECT_COMPLETE,
@@ -114,26 +81,17 @@ class NET_EXPORT_PRIVATE DnsTransaction :
void RevokeTimer();
void OnTimeout();
- // This is to be used by unit tests only.
- void set_timeouts_ms(const std::vector<base::TimeDelta>& timeouts_ms);
-
- const IPEndPoint dns_server_;
- Key key_;
- IPAddressList ip_addresses_;
- Delegate* delegate_;
-
+ scoped_refptr<DnsSession> session_;
+ IPEndPoint dns_server_;
scoped_ptr<DnsQuery> query_;
+ ResultCallback callback_;
scoped_ptr<DnsResponse> response_;
scoped_ptr<DatagramClientSocket> socket_;
// Number of retry attempts so far.
- size_t attempts_;
-
- // Timeouts in milliseconds.
- std::vector<base::TimeDelta> timeouts_ms_;
+ int attempts_;
State next_state_;
- ClientSocketFactory* socket_factory_;
base::OneShotTimer<DnsTransaction> timer_;
OldCompletionCallbackImpl<DnsTransaction> io_callback_;
diff --git a/net/dns/dns_transaction_unittest.cc b/net/dns/dns_transaction_unittest.cc
index a99f55b3..3e3c2fe 100644
--- a/net/dns/dns_transaction_unittest.cc
+++ b/net/dns/dns_transaction_unittest.cc
@@ -8,6 +8,12 @@
#include <vector>
#include "base/bind.h"
+#include "base/test/test_timeouts.h"
+#include "base/time.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/dns_query.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/dns_session.h"
#include "net/dns/dns_test_util.h"
#include "net/socket/socket_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -16,43 +22,73 @@ namespace net {
namespace {
-static const base::TimeDelta kTimeoutsMs[] = {
- base::TimeDelta::FromMilliseconds(20),
- base::TimeDelta::FromMilliseconds(20),
- base::TimeDelta::FromMilliseconds(20),
-};
+// A mock for RandIntCallback that always returns 0.
+int ReturnZero(int min, int max) {
+ return 0;
+}
-} // namespace
+class DnsTransactionTest : public testing::Test {
+ protected:
+ virtual void SetUp() OVERRIDE {
+ callback_ = base::Bind(&DnsTransactionTest::OnTransactionComplete,
+ base::Unretained(this));
+ qname_ = std::string(kT0DnsName, arraysize(kT0DnsName));
+ // Use long timeout to prevent timing out on slow bots.
+ ConfigureSession(base::TimeDelta::FromMilliseconds(
+ TestTimeouts::action_timeout_ms()));
+ }
+
+ void ConfigureSession(const base::TimeDelta& timeout) {
+ IPEndPoint dns_server;
+ bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server);
+ ASSERT_TRUE(rv);
+
+ DnsConfig config;
+ config.nameservers.push_back(dns_server);
+ config.attempts = 3;
+ config.timeout = timeout;
+
+ session_ = new DnsSession(config,
+ new MockClientSocketFactory(),
+ base::Bind(&ReturnZero),
+ NULL /* NetLog */);
+ }
+
+ void StartTransaction() {
+ transaction_.reset(new DnsTransaction(session_.get(),
+ qname_,
+ kT0Qtype,
+ callback_,
+ BoundNetLog()));
+
+ int rv0 = transaction_->Start();
+ EXPECT_EQ(ERR_IO_PENDING, rv0);
+ }
-class TestDelegate : public DnsTransaction::Delegate {
- public:
- TestDelegate() : result_(ERR_UNEXPECTED), transaction_(NULL) {}
- virtual ~TestDelegate() {}
- virtual void OnTransactionComplete(
- int result,
- const DnsTransaction* transaction,
- const IPAddressList& ip_addresses) {
- result_ = result;
- transaction_ = transaction;
- ip_addresses_ = ip_addresses;
+ void OnTransactionComplete(DnsTransaction* transaction, int rv) {
+ EXPECT_EQ(transaction_.get(), transaction);
+ EXPECT_EQ(qname_, transaction->query()->qname().as_string());
+ EXPECT_EQ(kT0Qtype, transaction->query()->qtype());
+ rv_ = rv;
MessageLoop::current()->Quit();
}
- int result() const { return result_; }
- const DnsTransaction* transaction() const { return transaction_; }
- const IPAddressList& ip_addresses() const {
- return ip_addresses_;
+
+ MockClientSocketFactory& factory() {
+ return *static_cast<MockClientSocketFactory*>(session_->socket_factory());
}
+ int rv() const { return rv_; }
+
private:
- int result_;
- const DnsTransaction* transaction_;
- IPAddressList ip_addresses_;
+ DnsTransaction::ResultCallback callback_;
+ std::string qname_;
+ scoped_refptr<DnsSession> session_;
+ scoped_ptr<DnsTransaction> transaction_;
- DISALLOW_COPY_AND_ASSIGN(TestDelegate);
+ int rv_;
};
-
-TEST(DnsTransactionTest, NormalQueryResponseTest) {
+TEST_F(DnsTransactionTest, NormalQueryResponseTest) {
MockWrite writes0[] = {
MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram),
arraysize(kT0QueryDatagram))
@@ -65,45 +101,19 @@ TEST(DnsTransactionTest, NormalQueryResponseTest) {
StaticSocketDataProvider data(reads0, arraysize(reads0),
writes0, arraysize(writes0));
- MockClientSocketFactory factory;
- factory.AddSocketDataProvider(&data);
-
- TestPrng test_prng(std::deque<int>(1, 0));
- RandIntCallback rand_int_cb =
- base::Bind(&TestPrng::GetNext, base::Unretained(&test_prng));
- std::string t0_dns_name(kT0DnsName, arraysize(kT0DnsName));
-
- IPEndPoint dns_server;
- bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server);
- ASSERT_TRUE(rv);
-
- DnsTransaction t(dns_server, t0_dns_name, kT1Qtype, rand_int_cb, &factory,
- BoundNetLog(), NULL);
-
- TestDelegate delegate;
- t.SetDelegate(&delegate);
-
- IPAddressList expected_ip_addresses;
- rv = ConvertStringsToIPAddressList(kT0IpAddresses,
- arraysize(kT0IpAddresses),
- &expected_ip_addresses);
- ASSERT_TRUE(rv);
-
- int rv0 = t.Start();
- EXPECT_EQ(ERR_IO_PENDING, rv0);
+ factory().AddSocketDataProvider(&data);
+ StartTransaction();
MessageLoop::current()->Run();
- EXPECT_TRUE(DnsTransaction::Key(t0_dns_name, kT0Qtype) == t.key());
- EXPECT_EQ(OK, delegate.result());
- EXPECT_EQ(&t, delegate.transaction());
- EXPECT_TRUE(expected_ip_addresses == delegate.ip_addresses());
+ EXPECT_EQ(OK, rv());
+ // TODO(szym): test fields of |transaction_->response()|
EXPECT_TRUE(data.at_read_eof());
EXPECT_TRUE(data.at_write_eof());
}
-TEST(DnsTransactionTest, MismatchedQueryResponseTest) {
+TEST_F(DnsTransactionTest, MismatchedQueryResponseTest) {
MockWrite writes0[] = {
MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram),
arraysize(kT0QueryDatagram))
@@ -116,40 +126,20 @@ TEST(DnsTransactionTest, MismatchedQueryResponseTest) {
StaticSocketDataProvider data(reads1, arraysize(reads1),
writes0, arraysize(writes0));
- MockClientSocketFactory factory;
- factory.AddSocketDataProvider(&data);
-
- TestPrng test_prng(std::deque<int>(1, 0));
- RandIntCallback rand_int_cb =
- base::Bind(&TestPrng::GetNext, base::Unretained(&test_prng));
- std::string t0_dns_name(kT0DnsName, arraysize(kT0DnsName));
-
- IPEndPoint dns_server;
- bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server);
- ASSERT_TRUE(rv);
-
- DnsTransaction t(dns_server, t0_dns_name, kT1Qtype, rand_int_cb, &factory,
- BoundNetLog(), NULL);
-
- TestDelegate delegate;
- t.SetDelegate(&delegate);
-
- int rv0 = t.Start();
- EXPECT_EQ(ERR_IO_PENDING, rv0);
+ factory().AddSocketDataProvider(&data);
+ StartTransaction();
MessageLoop::current()->Run();
- EXPECT_TRUE(DnsTransaction::Key(t0_dns_name, kT0Qtype) == t.key());
- EXPECT_EQ(ERR_DNS_MALFORMED_RESPONSE, delegate.result());
- EXPECT_EQ(0u, delegate.ip_addresses().size());
- EXPECT_EQ(&t, delegate.transaction());
+ EXPECT_EQ(ERR_DNS_MALFORMED_RESPONSE, rv());
+
EXPECT_TRUE(data.at_read_eof());
EXPECT_TRUE(data.at_write_eof());
}
// Test that after the first timeout we do a fresh connection and if we get
// a response on the new connection, we return it.
-TEST(DnsTransactionTest, FirstTimeoutTest) {
+TEST_F(DnsTransactionTest, FirstTimeoutTest) {
MockWrite writes0[] = {
MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram),
arraysize(kT0QueryDatagram))
@@ -165,57 +155,30 @@ TEST(DnsTransactionTest, FirstTimeoutTest) {
scoped_refptr<DelayedSocketData> socket1_data(
new DelayedSocketData(0, reads0, arraysize(reads0),
writes0, arraysize(writes0)));
- MockClientSocketFactory factory;
- factory.AddSocketDataProvider(socket0_data.get());
- factory.AddSocketDataProvider(socket1_data.get());
-
- TestPrng test_prng(std::deque<int>(2, 0));
- RandIntCallback rand_int_cb =
- base::Bind(&TestPrng::GetNext, base::Unretained(&test_prng));
- std::string t0_dns_name(kT0DnsName, arraysize(kT0DnsName));
-
- IPEndPoint dns_server;
- bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server);
- ASSERT_TRUE(rv);
- DnsTransaction t(dns_server, t0_dns_name, kT1Qtype, rand_int_cb, &factory,
- BoundNetLog(), NULL);
-
- TestDelegate delegate;
- t.SetDelegate(&delegate);
-
- t.set_timeouts_ms(
- std::vector<base::TimeDelta>(kTimeoutsMs,
- kTimeoutsMs + arraysize(kTimeoutsMs)));
-
- IPAddressList expected_ip_addresses;
- rv = ConvertStringsToIPAddressList(kT0IpAddresses,
- arraysize(kT0IpAddresses),
- &expected_ip_addresses);
- ASSERT_TRUE(rv);
-
- int rv0 = t.Start();
- EXPECT_EQ(ERR_IO_PENDING, rv0);
+ // Use short timeout to speed up the test.
+ ConfigureSession(base::TimeDelta::FromMilliseconds(
+ TestTimeouts::tiny_timeout_ms()));
+ factory().AddSocketDataProvider(socket0_data.get());
+ factory().AddSocketDataProvider(socket1_data.get());
+ StartTransaction();
MessageLoop::current()->Run();
- EXPECT_TRUE(DnsTransaction::Key(t0_dns_name, kT0Qtype) == t.key());
- EXPECT_EQ(OK, delegate.result());
- EXPECT_EQ(&t, delegate.transaction());
- EXPECT_TRUE(expected_ip_addresses == delegate.ip_addresses());
+ EXPECT_EQ(OK, rv());
EXPECT_TRUE(socket0_data->at_read_eof());
EXPECT_TRUE(socket0_data->at_write_eof());
EXPECT_TRUE(socket1_data->at_read_eof());
EXPECT_TRUE(socket1_data->at_write_eof());
- EXPECT_EQ(2u, factory.udp_client_sockets().size());
+ EXPECT_EQ(2u, factory().udp_client_sockets().size());
}
// Test that after the first timeout we do a fresh connection, and after
// the second timeout we do another fresh connection, and if we get a
// response on the second connection, we return it.
-TEST(DnsTransactionTest, SecondTimeoutTest) {
+TEST_F(DnsTransactionTest, SecondTimeoutTest) {
MockWrite writes0[] = {
MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram),
arraysize(kT0QueryDatagram))
@@ -233,45 +196,19 @@ TEST(DnsTransactionTest, SecondTimeoutTest) {
scoped_refptr<DelayedSocketData> socket2_data(
new DelayedSocketData(0, reads0, arraysize(reads0),
writes0, arraysize(writes0)));
- MockClientSocketFactory factory;
- factory.AddSocketDataProvider(socket0_data.get());
- factory.AddSocketDataProvider(socket1_data.get());
- factory.AddSocketDataProvider(socket2_data.get());
-
- TestPrng test_prng(std::deque<int>(3, 0));
- RandIntCallback rand_int_cb =
- base::Bind(&TestPrng::GetNext, base::Unretained(&test_prng));
- std::string t0_dns_name(kT0DnsName, arraysize(kT0DnsName));
- IPEndPoint dns_server;
- bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server);
- ASSERT_TRUE(rv);
+ // Use short timeout to speed up the test.
+ ConfigureSession(base::TimeDelta::FromMilliseconds(
+ TestTimeouts::tiny_timeout_ms()));
+ factory().AddSocketDataProvider(socket0_data.get());
+ factory().AddSocketDataProvider(socket1_data.get());
+ factory().AddSocketDataProvider(socket2_data.get());
- DnsTransaction t(dns_server, t0_dns_name, kT1Qtype, rand_int_cb, &factory,
- BoundNetLog(), NULL);
-
- TestDelegate delegate;
- t.SetDelegate(&delegate);
-
- t.set_timeouts_ms(
- std::vector<base::TimeDelta>(kTimeoutsMs,
- kTimeoutsMs + arraysize(kTimeoutsMs)));
-
- IPAddressList expected_ip_addresses;
- rv = ConvertStringsToIPAddressList(kT0IpAddresses,
- arraysize(kT0IpAddresses),
- &expected_ip_addresses);
- ASSERT_TRUE(rv);
-
- int rv0 = t.Start();
- EXPECT_EQ(ERR_IO_PENDING, rv0);
+ StartTransaction();
MessageLoop::current()->Run();
- EXPECT_TRUE(DnsTransaction::Key(t0_dns_name, kT1Qtype) == t.key());
- EXPECT_EQ(OK, delegate.result());
- EXPECT_EQ(&t, delegate.transaction());
- EXPECT_TRUE(expected_ip_addresses == delegate.ip_addresses());
+ EXPECT_EQ(OK, rv());
EXPECT_TRUE(socket0_data->at_read_eof());
EXPECT_TRUE(socket0_data->at_write_eof());
@@ -279,13 +216,13 @@ TEST(DnsTransactionTest, SecondTimeoutTest) {
EXPECT_TRUE(socket1_data->at_write_eof());
EXPECT_TRUE(socket2_data->at_read_eof());
EXPECT_TRUE(socket2_data->at_write_eof());
- EXPECT_EQ(3u, factory.udp_client_sockets().size());
+ EXPECT_EQ(3u, factory().udp_client_sockets().size());
}
// Test that after the first timeout we do a fresh connection, and after
// the second timeout we do another fresh connection and after the third
// timeout we give up and return a timeout error.
-TEST(DnsTransactionTest, ThirdTimeoutTest) {
+TEST_F(DnsTransactionTest, ThirdTimeoutTest) {
MockWrite writes0[] = {
MockWrite(true, reinterpret_cast<const char*>(kT0QueryDatagram),
arraysize(kT0QueryDatagram))
@@ -297,38 +234,19 @@ TEST(DnsTransactionTest, ThirdTimeoutTest) {
new DelayedSocketData(2, NULL, 0, writes0, arraysize(writes0)));
scoped_refptr<DelayedSocketData> socket2_data(
new DelayedSocketData(2, NULL, 0, writes0, arraysize(writes0)));
- MockClientSocketFactory factory;
- factory.AddSocketDataProvider(socket0_data.get());
- factory.AddSocketDataProvider(socket1_data.get());
- factory.AddSocketDataProvider(socket2_data.get());
-
- TestPrng test_prng(std::deque<int>(3, 0));
- RandIntCallback rand_int_cb =
- base::Bind(&TestPrng::GetNext, base::Unretained(&test_prng));
- std::string t0_dns_name(kT0DnsName, arraysize(kT0DnsName));
- IPEndPoint dns_server;
- bool rv = CreateDnsAddress(kDnsIp, kDnsPort, &dns_server);
- ASSERT_TRUE(rv);
+ // Use short timeout to speed up the test.
+ ConfigureSession(base::TimeDelta::FromMilliseconds(
+ TestTimeouts::tiny_timeout_ms()));
+ factory().AddSocketDataProvider(socket0_data.get());
+ factory().AddSocketDataProvider(socket1_data.get());
+ factory().AddSocketDataProvider(socket2_data.get());
- DnsTransaction t(dns_server, t0_dns_name, kT1Qtype, rand_int_cb, &factory,
- BoundNetLog(), NULL);
-
- TestDelegate delegate;
- t.SetDelegate(&delegate);
-
- t.set_timeouts_ms(
- std::vector<base::TimeDelta>(kTimeoutsMs,
- kTimeoutsMs + arraysize(kTimeoutsMs)));
-
- int rv0 = t.Start();
- EXPECT_EQ(ERR_IO_PENDING, rv0);
+ StartTransaction();
MessageLoop::current()->Run();
- EXPECT_TRUE(DnsTransaction::Key(t0_dns_name, kT0Qtype) == t.key());
- EXPECT_EQ(ERR_DNS_TIMED_OUT, delegate.result());
- EXPECT_EQ(&t, delegate.transaction());
+ EXPECT_EQ(ERR_DNS_TIMED_OUT, rv());
EXPECT_TRUE(socket0_data->at_read_eof());
EXPECT_TRUE(socket0_data->at_write_eof());
@@ -336,7 +254,9 @@ TEST(DnsTransactionTest, ThirdTimeoutTest) {
EXPECT_TRUE(socket1_data->at_write_eof());
EXPECT_TRUE(socket2_data->at_read_eof());
EXPECT_TRUE(socket2_data->at_write_eof());
- EXPECT_EQ(3u, factory.udp_client_sockets().size());
+ EXPECT_EQ(3u, factory().udp_client_sockets().size());
}
+} // namespace
+
} // namespace net