diff options
Diffstat (limited to 'net/dns/dns_test_util.cc')
-rw-r--r-- | net/dns/dns_test_util.cc | 157 |
1 files changed, 101 insertions, 56 deletions
diff --git a/net/dns/dns_test_util.cc b/net/dns/dns_test_util.cc index 051f595..73d208c 100644 --- a/net/dns/dns_test_util.cc +++ b/net/dns/dns_test_util.cc @@ -14,6 +14,7 @@ #include "net/base/dns_util.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" +#include "net/dns/address_sorter.h" #include "net/dns/dns_client.h" #include "net/dns/dns_config_service.h" #include "net/dns/dns_protocol.h" @@ -25,19 +26,29 @@ namespace net { namespace { -// A DnsTransaction which responds with loopback to all queries starting with -// "ok", fails synchronously on all queries starting with "er", and NXDOMAIN to -// all others. +// A DnsTransaction which uses MockDnsClientRuleList to determine the response. class MockTransaction : public DnsTransaction, public base::SupportsWeakPtr<MockTransaction> { public: - MockTransaction(const std::string& hostname, + MockTransaction(const MockDnsClientRuleList& rules, + const std::string& hostname, uint16 qtype, const DnsTransactionFactory::CallbackType& callback) - : hostname_(hostname), + : result_(MockDnsClientRule::FAIL_SYNC), + hostname_(hostname), qtype_(qtype), callback_(callback), started_(false) { + // Find the relevant rule which matches |qtype| and prefix of |hostname|. + for (size_t i = 0; i < rules.size(); ++i) { + const std::string& prefix = rules[i].prefix; + if ((rules[i].qtype == qtype) && + (hostname.size() >= prefix.size()) && + (hostname.compare(0, prefix.size(), prefix) == 0)) { + result_ = rules[i].result; + break; + } + } } virtual const std::string& GetHostname() const OVERRIDE { @@ -51,7 +62,7 @@ class MockTransaction : public DnsTransaction, virtual int Start() OVERRIDE { EXPECT_FALSE(started_); started_ = true; - if (hostname_.substr(0, 2) == "er") + if (MockDnsClientRule::FAIL_SYNC == result_) return ERR_NAME_NOT_RESOLVED; // Using WeakPtr to cleanly cancel when transaction is destroyed. MessageLoop::current()->PostTask( @@ -62,54 +73,66 @@ class MockTransaction : public DnsTransaction, private: void Finish() { - if (hostname_.substr(0, 2) == "ok") { - std::string qname; - DNSDomainFromDot(hostname_, &qname); - DnsQuery query(0, qname, qtype_); - - DnsResponse response; - char* buffer = response.io_buffer()->data(); - int nbytes = query.io_buffer()->size(); - memcpy(buffer, query.io_buffer()->data(), nbytes); - - const uint16 kPointerToQueryName = - static_cast<uint16>(0xc000 | sizeof(net::dns_protocol::Header)); - - const uint32 kTTL = 86400; // One day. - - // Size of RDATA which is a IPv4 or IPv6 address. - size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ? - net::kIPv4AddressSize : net::kIPv6AddressSize; - - // 12 is the sum of sizes of the compressed name reference, TYPE, - // CLASS, TTL and RDLENGTH. - size_t answer_size = 12 + rdata_size; - - // Write answer with loopback IP address. - reinterpret_cast<dns_protocol::Header*>(buffer)->ancount = - base::HostToNet16(1); - BigEndianWriter writer(buffer + nbytes, answer_size); - writer.WriteU16(kPointerToQueryName); - writer.WriteU16(qtype_); - writer.WriteU16(net::dns_protocol::kClassIN); - writer.WriteU32(kTTL); - writer.WriteU16(rdata_size); - if (qtype_ == net::dns_protocol::kTypeA) { - char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; - writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); - } else { - char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1 }; - writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); - } - - EXPECT_TRUE(response.InitParse(nbytes + answer_size, query)); - callback_.Run(this, OK, &response); - } else { - callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); + switch (result_) { + case MockDnsClientRule::EMPTY: + case MockDnsClientRule::OK: { + std::string qname; + DNSDomainFromDot(hostname_, &qname); + DnsQuery query(0, qname, qtype_); + + DnsResponse response; + char* buffer = response.io_buffer()->data(); + int nbytes = query.io_buffer()->size(); + memcpy(buffer, query.io_buffer()->data(), nbytes); + dns_protocol::Header* header = + reinterpret_cast<dns_protocol::Header*>(buffer); + header->flags |= dns_protocol::kFlagResponse; + + if (MockDnsClientRule::OK == result_) { + const uint16 kPointerToQueryName = + static_cast<uint16>(0xc000 | sizeof(*header)); + + const uint32 kTTL = 86400; // One day. + + // Size of RDATA which is a IPv4 or IPv6 address. + size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ? + net::kIPv4AddressSize : net::kIPv6AddressSize; + + // 12 is the sum of sizes of the compressed name reference, TYPE, + // CLASS, TTL and RDLENGTH. + size_t answer_size = 12 + rdata_size; + + // Write answer with loopback IP address. + header->ancount = base::HostToNet16(1); + BigEndianWriter writer(buffer + nbytes, answer_size); + writer.WriteU16(kPointerToQueryName); + writer.WriteU16(qtype_); + writer.WriteU16(net::dns_protocol::kClassIN); + writer.WriteU32(kTTL); + writer.WriteU16(rdata_size); + if (qtype_ == net::dns_protocol::kTypeA) { + char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; + writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); + } else { + char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1 }; + writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); + } + nbytes += answer_size; + } + EXPECT_TRUE(response.InitParse(nbytes, query)); + callback_.Run(this, OK, &response); + } break; + case MockDnsClientRule::FAIL_ASYNC: + callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); + break; + default: + NOTREACHED(); + break; } } + MockDnsClientRule::Result result_; const std::string hostname_; const uint16 qtype_; DnsTransactionFactory::CallbackType callback_; @@ -120,7 +143,8 @@ class MockTransaction : public DnsTransaction, // A DnsTransactionFactory which creates MockTransaction. class MockTransactionFactory : public DnsTransactionFactory { public: - MockTransactionFactory() {} + explicit MockTransactionFactory(const MockDnsClientRuleList& rules) + : rules_(rules) {} virtual ~MockTransactionFactory() {} virtual scoped_ptr<DnsTransaction> CreateTransaction( @@ -129,14 +153,29 @@ class MockTransactionFactory : public DnsTransactionFactory { const DnsTransactionFactory::CallbackType& callback, const BoundNetLog&) OVERRIDE { return scoped_ptr<DnsTransaction>( - new MockTransaction(hostname, qtype, callback)); + new MockTransaction(rules_, hostname, qtype, callback)); + } + + private: + MockDnsClientRuleList rules_; +}; + +class MockAddressSorter : public AddressSorter { + public: + virtual ~MockAddressSorter() {} + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE { + // Do nothing. + callback.Run(true, list); } }; // MockDnsClient provides MockTransactionFactory. class MockDnsClient : public DnsClient { public: - explicit MockDnsClient(const DnsConfig& config) : config_(config) {} + MockDnsClient(const DnsConfig& config, + const MockDnsClientRuleList& rules) + : config_(config), factory_(rules) {} virtual ~MockDnsClient() {} virtual void SetConfig(const DnsConfig& config) OVERRIDE { @@ -151,16 +190,22 @@ class MockDnsClient : public DnsClient { return config_.IsValid() ? &factory_ : NULL; } + virtual AddressSorter* GetAddressSorter() OVERRIDE { + return &address_sorter_; + } + private: DnsConfig config_; MockTransactionFactory factory_; + MockAddressSorter address_sorter_; }; } // namespace // static -scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config) { - return scoped_ptr<DnsClient>(new MockDnsClient(config)); +scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config, + const MockDnsClientRuleList& rules) { + return scoped_ptr<DnsClient>(new MockDnsClient(config, rules)); } MockDnsConfigService::~MockDnsConfigService() { |