diff options
author | noamsml@chromium.org <noamsml@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-06-19 03:43:31 +0000 |
---|---|---|
committer | noamsml@chromium.org <noamsml@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-06-19 03:43:31 +0000 |
commit | 5e6f42734932258c9d22ecb3065ff787fd3b7b58 (patch) | |
tree | 305b6b5b4e859e4d32f276e98192b936512c6159 /net | |
parent | e07f7b7b919c7c54bdc4c7b2433c2b79cefdb986 (diff) | |
download | chromium_src-5e6f42734932258c9d22ecb3065ff787fd3b7b58.zip chromium_src-5e6f42734932258c9d22ecb3065ff787fd3b7b58.tar.gz chromium_src-5e6f42734932258c9d22ecb3065ff787fd3b7b58.tar.bz2 |
MDnsClient: Process all records before alerting listeners
Change the notification structure in the MDnsListener to notify listeners after
the whole packet is processed. This should allow listeners to issue cache
queries for other records in the same packet in a synchronous manner.
In order to not issue invalid notifications in case the packet contains multiple
records that cancel each other out (and thus delete each other), store cache
keys in the list of notifications to send, rather than storing record pointers
directly.
BUG=
TEST=MDnsCacheTest.*,MDnsTest.*
Review URL: https://chromiumcodereview.appspot.com/17379009
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@207161 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
-rw-r--r-- | net/dns/mdns_cache.cc | 27 | ||||
-rw-r--r-- | net/dns/mdns_cache.h | 61 | ||||
-rw-r--r-- | net/dns/mdns_client_impl.cc | 31 | ||||
-rw-r--r-- | net/dns/mdns_client_unittest.cc | 100 |
4 files changed, 176 insertions, 43 deletions
diff --git a/net/dns/mdns_cache.cc b/net/dns/mdns_cache.cc index cb7dc117..f210d0ab 100644 --- a/net/dns/mdns_cache.cc +++ b/net/dns/mdns_cache.cc @@ -59,6 +59,14 @@ bool MDnsCache::Key::operator==(const MDnsCache::Key& key) const { return type_ == key.type_ && name_ == key.name_ && optional_ == key.optional_; } +// static +MDnsCache::Key MDnsCache::Key::CreateFor(const RecordParsed* record) { + return Key(record->type(), + record->name(), + GetOptionalFieldForRecord(record)); +} + + MDnsCache::MDnsCache() { } @@ -71,14 +79,19 @@ void MDnsCache::Clear() { STLDeleteValues(&mdns_cache_); } +const RecordParsed* MDnsCache::LookupKey(const Key& key) { + RecordMap::iterator found = mdns_cache_.find(key); + if (found != mdns_cache_.end()) { + return found->second; + } + return NULL; +} + MDnsCache::UpdateType MDnsCache::UpdateDnsRecord( scoped_ptr<const RecordParsed> record) { UpdateType type = NoChange; - MDnsCache::Key cache_key = MDnsCache::Key( - record->type(), - record->name(), - GetOptionalFieldForRecord(record.get())); + Key cache_key = Key::CreateFor(record.get()); base::Time expiration = GetEffectiveExpiration(record.get()); if (next_expiration_ == base::Time() || expiration < next_expiration_) { @@ -155,8 +168,9 @@ void MDnsCache::FindDnsRecords(unsigned type, } } +// static std::string MDnsCache::GetOptionalFieldForRecord( - const RecordParsed* record) const { + const RecordParsed* record) { switch (record->type()) { case PtrRecordRdata::kType: { const PtrRecordRdata* rdata = record->rdata<PtrRecordRdata>(); @@ -167,7 +181,8 @@ std::string MDnsCache::GetOptionalFieldForRecord( } } -base::Time MDnsCache::GetEffectiveExpiration(const RecordParsed* record) const { +// static +base::Time MDnsCache::GetEffectiveExpiration(const RecordParsed* record) { base::TimeDelta ttl; if (record->ttl()) { diff --git a/net/dns/mdns_cache.h b/net/dns/mdns_cache.h index 4505b85..373917e 100644 --- a/net/dns/mdns_cache.h +++ b/net/dns/mdns_cache.h @@ -24,7 +24,32 @@ class RecordParsed; // guaranteed not to return expired records. It also has facilities for timely // record expiration. class NET_EXPORT_PRIVATE MDnsCache { - public: +public: + // Key type for the record map. It is a 3-tuple of type, name and optional + // value ordered by type, then name, then optional value. This allows us to + // query for all records of a certain type and name, while also allowing us + // to set records of a certain type, name and optionally value as unique. + class Key { + public: + Key(unsigned type, const std::string& name, const std::string& optional); + Key(const Key&); + Key& operator=(const Key&); + ~Key(); + bool operator<(const Key& key) const; + bool operator==(const Key& key) const; + + unsigned type() const { return type_; } + const std::string& name() const { return name_; } + const std::string& optional() const { return optional_; } + + // Create the cache key corresponding to |record|. + static Key CreateFor(const RecordParsed* record); + private: + unsigned type_; + std::string name_; + std::string optional_; + }; + typedef base::Callback<void(const RecordParsed*)> RecordRemovedCallback; enum UpdateType { @@ -41,6 +66,10 @@ class NET_EXPORT_PRIVATE MDnsCache { // previously with same value). UpdateType UpdateDnsRecord(scoped_ptr<const RecordParsed> record); + // Check cache for record with key |key|. Return the record if it exists, or + // NULL if it doesn't. + const RecordParsed* LookupKey(const Key& key); + // Return records with type |type| and name |name|. Expired records will not // be returned. If |name| is empty, return all records with type |type|. void FindDnsRecords(unsigned type, @@ -60,41 +89,19 @@ class NET_EXPORT_PRIVATE MDnsCache { void Clear(); - private: - // Key type for the record map. It is a 3-tuple of type, name and optional - // value ordered by type, then name, then optional value. This allows us to - // query for all records of a certain type and name, while also allowing us - // to set records of a certain type, name and optionally value as unique. - class Key { - public: - Key(unsigned type, const std::string& name, const std::string& optional); - Key(const Key&); - Key& operator=(const Key&); - ~Key(); - bool operator<(const Key& key) const; - bool operator==(const Key& key) const; - - unsigned type() const { return type_; } - const std::string& name() const { return name_; } - const std::string& optional() const { return optional_; } - - private: - unsigned type_; - std::string name_; - std::string optional_; - }; - +private: typedef std::map<Key, const RecordParsed*> RecordMap; // Get the effective expiration of a cache entry, based on its creation time // and TTL. Does adjustments so entries with a TTL of zero will have a // nonzero TTL, as explained in RFC 6762 Section 10.1. - base::Time GetEffectiveExpiration(const RecordParsed* entry) const; + static base::Time GetEffectiveExpiration(const RecordParsed* entry); // Get optional part of the DNS key for shared records. For example, in PTR // records this is the pointed domain, since multiple PTR records may exist // for the same name. - std::string GetOptionalFieldForRecord(const RecordParsed* record) const; + static std::string GetOptionalFieldForRecord( + const RecordParsed* record); RecordMap mdns_cache_; diff --git a/net/dns/mdns_client_impl.cc b/net/dns/mdns_client_impl.cc index 16852e9..4aac454 100644 --- a/net/dns/mdns_client_impl.cc +++ b/net/dns/mdns_client_impl.cc @@ -213,6 +213,10 @@ bool MDnsClientImpl::Core::SendQuery(uint16 rrtype, std::string name) { void MDnsClientImpl::Core::HandlePacket(DnsResponse* response, int bytes_read) { unsigned offset; + // Note: We store cache keys rather than record pointers to avoid + // erroneous behavior in case a packet contains multiple exclusive + // records with the same type and name. + std::map<MDnsCache::Key, MDnsListener::UpdateType> update_keys; if (!response->InitParseWithoutQuery(bytes_read)) { LOG(WARNING) << "Could not understand an mDNS packet."; @@ -229,10 +233,10 @@ void MDnsClientImpl::Core::HandlePacket(DnsResponse* response, for (unsigned i = 0; i < answer_count; i++) { offset = parser.GetOffset(); - scoped_ptr<const RecordParsed> scoped_record = RecordParsed::CreateFrom( + scoped_ptr<const RecordParsed> record = RecordParsed::CreateFrom( &parser, base::Time::Now()); - if (!scoped_record) { + if (!record) { LOG(WARNING) << "Could not understand an mDNS record."; if (offset == parser.GetOffset()) { @@ -243,16 +247,14 @@ void MDnsClientImpl::Core::HandlePacket(DnsResponse* response, } } - if ((scoped_record->klass() & dns_protocol::kMDnsClassMask) != + if ((record->klass() & dns_protocol::kMDnsClassMask) != dns_protocol::kClassIN) { LOG(WARNING) << "Received an mDNS record with non-IN class. Ignoring."; continue; // Ignore all records not in the IN class. } - // We want to retain a copy of the record pointer for updating listeners - // but we are passing ownership to the cache. - const RecordParsed* record = scoped_record.get(); - MDnsCache::UpdateType update = cache_.UpdateDnsRecord(scoped_record.Pass()); + MDnsCache::Key update_key = MDnsCache::Key::CreateFor(record.get()); + MDnsCache::UpdateType update = cache_.UpdateDnsRecord(record.Pass()); // Cleanup time may have changed. ScheduleCleanup(cache_.next_expiration()); @@ -275,10 +277,19 @@ void MDnsClientImpl::Core::HandlePacket(DnsResponse* response, break; } - AlertListeners(update_external, - ListenerKey(record->type(), record->name()), record); + update_keys.insert(std::make_pair(update_key, update_external)); + } + } + + for (std::map<MDnsCache::Key, MDnsListener::UpdateType>::iterator i = + update_keys.begin(); i != update_keys.end(); i++) { + const RecordParsed* record = cache_.LookupKey(i->first); + if (record) { + AlertListeners(i->second, ListenerKey(record->type(), record->name()), + record); // Alert listeners listening only for rrtype and not for name. - AlertListeners(update_external, ListenerKey(record->type(), ""), record); + AlertListeners(i->second, ListenerKey(record->type(), ""), + record); } } } diff --git a/net/dns/mdns_client_unittest.cc b/net/dns/mdns_client_unittest.cc index aa1ea90..5b53d43 100644 --- a/net/dns/mdns_client_unittest.cc +++ b/net/dns/mdns_client_unittest.cc @@ -119,6 +119,40 @@ const char kCorruptedPacketUnsalvagable[] = { 0x08, '_', 'p', 'r', // Useless trailing data. }; +const char kCorruptedPacketDoubleRecord[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x02, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x06, 'p', 'r', 'i', 'v', 'e', 't', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x04, // RDLENGTH is 4 + 0x05, 0x03, + 0xc0, 0x0c, + + // Answer 2 -- Same key + 0x06, 'p', 'r', 'i', 'v', 'e', 't', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x04, // RDLENGTH is 4 + 0x02, 0x03, + 0x04, 0x05, +}; + const char kCorruptedPacketSalvagable[] = { // Header 0x00, 0x00, // ID is zeroed out @@ -411,6 +445,10 @@ class MDnsTest : public ::testing::Test { MOCK_METHOD2(MockableRecordCallback, void(MDnsTransaction::Result result, const RecordParsed* record)); + MOCK_METHOD2(MockableRecordCallback2, void(MDnsTransaction::Result result, + const RecordParsed* record)); + + protected: void ExpectPacket(const char* packet, unsigned size); void SimulatePacketReceive(const char* packet, unsigned size); @@ -839,6 +877,41 @@ TEST_F(MDnsTest, TransactionReentrantDeleteFromCache) { EXPECT_EQ(NULL, transaction_.get()); } +TEST_F(MDnsTest, TransactionReentrantCacheLookupStart) { + ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet)); + + scoped_ptr<MDnsTransaction> transaction1 = test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + scoped_ptr<MDnsTransaction> transaction2 = test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_printer._tcp.local", + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback2, + base::Unretained(this))); + + EXPECT_CALL(*this, MockableRecordCallback2(MDnsTransaction::RESULT_RECORD, + _)) + .Times(Exactly(1)); + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, + _)) + .Times(Exactly(1)) + .WillOnce(IgnoreResult(InvokeWithoutArgs(transaction2.get(), + &MDnsTransaction::Start))); + + ASSERT_TRUE(transaction1->Start()); + + EXPECT_TRUE(test_client_->IsListeningForTests()); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); +} + // In order to reliably test reentrant listener deletes, we create two listeners // and have each of them delete both, so we're guaranteed to try and deliver a // callback to at least one deleted listener. @@ -870,6 +943,33 @@ TEST_F(MDnsTest, ListenerReentrantDelete) { EXPECT_EQ(NULL, listener2_.get()); } +ACTION_P(SaveIPAddress, ip_container) { + ::testing::StaticAssertTypeEq<const RecordParsed*, arg1_type>(); + ::testing::StaticAssertTypeEq<IPAddressNumber*, ip_container_type>(); + + *ip_container = arg1->template rdata<ARecordRdata>()->address(); +} + +TEST_F(MDnsTest, DoubleRecordDisagreeing) { + IPAddressNumber address; + StrictMock<MockListenerDelegate> delegate_privet; + + scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener( + dns_protocol::kTypeA, "privet.local", &delegate_privet); + + ASSERT_TRUE(listener_privet->Start()); + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(SaveIPAddress(&address)); + + SimulatePacketReceive(kCorruptedPacketDoubleRecord, + sizeof(kCorruptedPacketDoubleRecord)); + + EXPECT_EQ("2.3.4.5", IPAddressToString(address)); +} + + // Note: These tests assume that the ipv4 socket will always be created first. // This is a simplifying assumption based on the way the code works now. |