diff options
author | noamsml@chromium.org <noamsml@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-06-13 22:31:42 +0000 |
---|---|---|
committer | noamsml@chromium.org <noamsml@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-06-13 22:31:42 +0000 |
commit | 245b164eab1be2a0e993a438c62c5733ec2a32e5 (patch) | |
tree | fa4175a36f31d8133f4ba601d1e10987aeb59c9e /net | |
parent | d23741d1102ab02aca4123dcb55ae125e6f3c463 (diff) | |
download | chromium_src-245b164eab1be2a0e993a438c62c5733ec2a32e5.zip chromium_src-245b164eab1be2a0e993a438c62c5733ec2a32e5.tar.gz chromium_src-245b164eab1be2a0e993a438c62c5733ec2a32e5.tar.bz2 |
Multicast DNS implementation (initial)
An implementation of multicast DNS in net/. Currently, the multicast DNS
implementation supports the following features:
- Listeners, which can be notified of any changes to records of a certain type
and name, or records of a certain type.
- Transactions, which allow users to query multicast DNS for a specific unique
record (e.g. an address)
BUG=233821
TEST=MDnsTest.*,MDnsConnectionTest.*
Review URL: https://chromiumcodereview.appspot.com/15733008
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@206170 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
-rw-r--r-- | net/dns/dns_protocol.h | 5 | ||||
-rw-r--r-- | net/dns/dns_query.cc | 6 | ||||
-rw-r--r-- | net/dns/dns_query.h | 2 | ||||
-rw-r--r-- | net/dns/dns_response.cc | 5 | ||||
-rw-r--r-- | net/dns/dns_response.h | 2 | ||||
-rw-r--r-- | net/dns/mdns_client.cc | 26 | ||||
-rw-r--r-- | net/dns/mdns_client.h | 156 | ||||
-rw-r--r-- | net/dns/mdns_client_impl.cc | 599 | ||||
-rw-r--r-- | net/dns/mdns_client_impl.h | 288 | ||||
-rw-r--r-- | net/dns/mdns_client_unittest.cc | 1070 | ||||
-rw-r--r-- | net/net.gyp | 11 |
11 files changed, 2170 insertions, 0 deletions
diff --git a/net/dns/dns_protocol.h b/net/dns/dns_protocol.h index b65f56d..4516b29 100644 --- a/net/dns/dns_protocol.h +++ b/net/dns/dns_protocol.h @@ -13,6 +13,7 @@ namespace net { namespace dns_protocol { static const uint16 kDefaultPort = 53; +static const uint16 kDefaultPortMulticast = 5353; // DNS packet consists of a header followed by questions and/or answers. // For the meaning of specific fields, please see RFC 1035 and 2535 @@ -101,6 +102,10 @@ static const int kMaxNameLength = 255; // bytes (not counting the IP nor UDP headers). static const int kMaxUDPSize = 512; +// RFC 6762, section 17: Messages over the local link are restricted by the +// medium's MTU, and must be under 9000 bytes +static const int kMaxMulticastSize = 9000; + // DNS class types. static const uint16 kClassIN = 1; diff --git a/net/dns/dns_query.cc b/net/dns/dns_query.cc index 72e97cf..270757e 100644 --- a/net/dns/dns_query.cc +++ b/net/dns/dns_query.cc @@ -80,4 +80,10 @@ DnsQuery::DnsQuery(const DnsQuery& orig, uint16 id) { header->id = base::HostToNet16(id); } +void DnsQuery::set_flags(uint16 flags) { + dns_protocol::Header* header = + reinterpret_cast<dns_protocol::Header*>(io_buffer_->data()); + header->flags = flags; +} + } // namespace net diff --git a/net/dns/dns_query.h b/net/dns/dns_query.h index a2ed868..e1469bd 100644 --- a/net/dns/dns_query.h +++ b/net/dns/dns_query.h @@ -38,6 +38,8 @@ class NET_EXPORT_PRIVATE DnsQuery { // IOBuffer accessor to be used for writing out the query. IOBufferWithSize* io_buffer() const { return io_buffer_.get(); } + void set_flags(uint16 flags); + private: DnsQuery(const DnsQuery& orig, uint16 id); diff --git a/net/dns/dns_response.cc b/net/dns/dns_response.cc index 7daf5ff..d29d3c4 100644 --- a/net/dns/dns_response.cc +++ b/net/dns/dns_response.cc @@ -231,6 +231,11 @@ unsigned DnsResponse::answer_count() const { return base::NetToHost16(header()->ancount); } +unsigned DnsResponse::additional_answer_count() const { + DCHECK(parser_.IsValid()); + return base::NetToHost16(header()->arcount); +} + base::StringPiece DnsResponse::qname() const { DCHECK(parser_.IsValid()); // The response is HEADER QNAME QTYPE QCLASS ANSWER. diff --git a/net/dns/dns_response.h b/net/dns/dns_response.h index 4a445d9..76c2215 100644 --- a/net/dns/dns_response.h +++ b/net/dns/dns_response.h @@ -130,7 +130,9 @@ class NET_EXPORT_PRIVATE DnsResponse { // Accessors for the header. uint16 flags() const; // excluding rcode uint8 rcode() const; + unsigned answer_count() const; + unsigned additional_answer_count() const; // Accessors to the question. The qname is unparsed. base::StringPiece qname() const; diff --git a/net/dns/mdns_client.cc b/net/dns/mdns_client.cc new file mode 100644 index 0000000..d4cf470 --- /dev/null +++ b/net/dns/mdns_client.cc @@ -0,0 +1,26 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/mdns_client.h" + +#include "net/dns/mdns_client_impl.h" + +namespace net { + +static MDnsClient* g_instance = NULL; + +MDnsClient* MDnsClient::GetInstance() { + if (!g_instance) { + g_instance = + new MDnsClientImpl(MDnsConnection::SocketFactory::CreateDefault()); + } + + return g_instance; +} + +void MDnsClient::SetInstance(MDnsClient* instance) { + g_instance = instance; +} + +} // namespace net diff --git a/net/dns/mdns_client.h b/net/dns/mdns_client.h new file mode 100644 index 0000000..4e4255d --- /dev/null +++ b/net/dns/mdns_client.h @@ -0,0 +1,156 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_MDNS_CLIENT_H_ +#define NET_DNS_MDNS_CLIENT_H_ + +#include <string> +#include <vector> + +#include "base/callback.h" +#include "net/dns/dns_query.h" +#include "net/dns/dns_response.h" +#include "net/dns/record_parsed.h" + +namespace net { + +class RecordParsed; + +// Represents a one-time record lookup. A transaction takes one +// associated callback (see |MDnsClient::CreateTransaction|) and calls it +// whenever a matching record has been found, either from the cache or +// by querying the network (it may choose to query either or both based on its +// creation flags, see MDnsTransactionFlags). Network-based transactions will +// time out after a reasonable number of seconds. +class MDnsTransaction { + public: + // Used to signify what type of result the transaction has recieved. + enum Result { + // Passed whenever a record is found. + RESULT_RECORD, + // The transaction is done. Applies to non-single-valued transactions. Is + // called when the transaction has finished (this is the last call to the + // callback). + RESULT_DONE, + // No results have been found. Applies to single-valued transactions. Is + // called when the transaction has finished without finding any results. + // For transactions that use the network, this happens when a timeout + // occurs, for transactions that are cache-only, this happens when no + // results are in the cache. + RESULT_NO_RESULTS, + // Called when an NSec record is read for this transaction's + // query. This means there cannot possibly be a record of the type + // and name for this transaction. + RESULT_NSEC + }; + + // Used when creating an MDnsTransaction. + enum Flags { + // Transaction should return only one result, and stop listening after it. + // Note that single result transactions will signal when their timeout is + // reached, whereas multi-result transactions will not. + SINGLE_RESULT = 1 << 0, + // Query the cache or the network. May both be used. One must be present. + QUERY_CACHE = 1 << 1, + QUERY_NETWORK = 1 << 2, + // TODO(noamsml): Add flag for flushing cache when feature is implemented + // Mask of all possible flags on MDnsTransaction. + FLAG_MASK = (1 << 3) - 1, + }; + + typedef base::Callback<void(Result, const RecordParsed*)> + ResultCallback; + + // Destroying the transaction cancels it. + virtual ~MDnsTransaction() {} + + // Start the transaction. Return true on success. Cache-based transactions + // will execute the callback synchronously. + virtual bool Start() = 0; + + // Get the host or service name for the transaction. + virtual const std::string& GetName() const = 0; + + // Get the type for this transaction (SRV, TXT, A, AAA, etc) + virtual uint16 GetType() const = 0; +}; + +// A listener listens for updates regarding a specific record or set of records. +// Created by the MDnsClient (see |MDnsClient::CreateListener|) and used to keep +// track of listeners. +class MDnsListener { + public: + // Used in the MDnsListener delegate to signify what type of change has been + // made to a record. + enum UpdateType { + RECORD_ADDED, + RECORD_CHANGED, + RECORD_REMOVED + }; + + class Delegate { + public: + virtual ~Delegate() {} + + // Called when a record is added, removed or updated. + virtual void OnRecordUpdate(UpdateType update, + const RecordParsed* record) = 0; + + // Called when a record is marked nonexistent by an NSEC record. + virtual void OnNsecRecord(const std::string& name, unsigned type) = 0; + + // Called when the cache is purged (due, for example, ot the network + // disconnecting). + virtual void OnCachePurged() = 0; + }; + + // Destroying the listener stops listening. + virtual ~MDnsListener() {} + + // Start the listener. Return true on success. + virtual bool Start() = 0; + + // Get the host or service name for this query. + // Return an empty string for no name. + virtual const std::string& GetName() const = 0; + + // Get the type for this query (SRV, TXT, A, AAA, etc) + virtual uint16 GetType() const = 0; +}; + +// Listens for Multicast DNS on the local network. You can access information +// regarding multicast DNS either by creating an |MDnsListener| to be notified +// of new records, or by creating an |MDnsTransaction| to look up the value of a +// specific records. When all listeners and active transactions are destroyed, +// the client stops listening on the network and destroys the cache. +class MDnsClient { + public: + virtual ~MDnsClient() {} + + // Create listener object for RRType |rrtype| and name |name|. If |name| is + // an empty string, listen to all notification of type |rrtype|. + virtual scoped_ptr<MDnsListener> CreateListener( + uint16 rrtype, + const std::string& name, + MDnsListener::Delegate* delegate) = 0; + + // Create a transaction that can be used to query either the MDns cache, the + // network, or both for records of type |rrtype| and name |name|. |flags| is + // defined by MDnsTransactionFlags. + virtual scoped_ptr<MDnsTransaction> CreateTransaction( + uint16 rrtype, + const std::string& name, + int flags, + const MDnsTransaction::ResultCallback& callback) = 0; + + // Lazily create and return static instance for MDnsClient. + static MDnsClient* GetInstance(); + + // Set the global instance (for testing). MUST be called before the first call + // to GetInstance. + static void SetInstance(MDnsClient* instance); +}; + +} // namespace net +#endif // NET_DNS_MDNS_CLIENT_H_ diff --git a/net/dns/mdns_client_impl.cc b/net/dns/mdns_client_impl.cc new file mode 100644 index 0000000..16852e9 --- /dev/null +++ b/net/dns/mdns_client_impl.cc @@ -0,0 +1,599 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/mdns_client_impl.h" + +#include "base/bind.h" +#include "base/message_loop_proxy.h" +#include "base/stl_util.h" +#include "base/time/default_clock.h" +#include "net/base/dns_util.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/rand_callback.h" +#include "net/dns/dns_protocol.h" +#include "net/udp/datagram_socket.h" + +namespace net { + +namespace { +const char kMDnsMulticastGroupIPv4[] = "224.0.0.251"; +const char kMDnsMulticastGroupIPv6[] = "FF02::FB"; +const unsigned MDnsTransactionTimeoutSeconds = 3; +} + +MDnsConnection::SocketHandler::SocketHandler( + MDnsConnection* connection, const IPEndPoint& multicast_addr, + MDnsConnection::SocketFactory* socket_factory) + : socket_(socket_factory->CreateSocket()), connection_(connection), + response_(new DnsResponse(dns_protocol::kMaxMulticastSize)), + multicast_addr_(multicast_addr) { +} + +MDnsConnection::SocketHandler::~SocketHandler() { +} + +int MDnsConnection::SocketHandler::Start() { + int rv = BindSocket(); + if (rv != OK) { + return rv; + } + + return DoLoop(0); +} + +int MDnsConnection::SocketHandler::DoLoop(int rv) { + do { + if (rv > 0) + connection_->OnDatagramReceived(response_.get(), recv_addr_, rv); + + rv = socket_->RecvFrom( + response_->io_buffer(), + response_->io_buffer()->size(), + &recv_addr_, + base::Bind(&MDnsConnection::SocketHandler::OnDatagramReceived, + base::Unretained(this))); + } while (rv > 0); + + if (rv != ERR_IO_PENDING) + return rv; + + return OK; +} + +void MDnsConnection::SocketHandler::OnDatagramReceived(int rv) { + if (rv >= OK) + rv = DoLoop(rv); + + if (rv != OK) + connection_->OnError(this, rv); +} + +int MDnsConnection::SocketHandler::Send(IOBuffer* buffer, unsigned size) { + return socket_->SendTo( + buffer, size, multicast_addr_, + base::Bind(&MDnsConnection::SocketHandler::SendDone, + base::Unretained(this) )); +} + +void MDnsConnection::SocketHandler::SendDone(int rv) { + // TODO(noamsml): Retry logic. +} + +int MDnsConnection::SocketHandler::BindSocket() { + IPAddressNumber address_any(multicast_addr_.address().size()); + + IPEndPoint bind_endpoint(address_any, multicast_addr_.port()); + + socket_->AllowAddressReuse(); + int rv = socket_->Listen(bind_endpoint); + + if (rv < OK) return rv; + + socket_->SetMulticastLoopbackMode(false); + + return socket_->JoinGroup(multicast_addr_.address()); +} + +MDnsConnection::MDnsConnection(MDnsConnection::SocketFactory* socket_factory, + MDnsConnection::Delegate* delegate) + : socket_handler_ipv4_(this, + GetMDnsIPEndPoint(kMDnsMulticastGroupIPv4), + socket_factory), + socket_handler_ipv6_(this, + GetMDnsIPEndPoint(kMDnsMulticastGroupIPv6), + socket_factory), + delegate_(delegate) { +} + +MDnsConnection::~MDnsConnection() { +} + +int MDnsConnection::Init() { + int rv; + + rv = socket_handler_ipv4_.Start(); + if (rv != OK) return rv; + rv = socket_handler_ipv6_.Start(); + if (rv != OK) return rv; + + return OK; +} + +int MDnsConnection::Send(IOBuffer* buffer, unsigned size) { + int rv; + + rv = socket_handler_ipv4_.Send(buffer, size); + if (rv < OK && rv != ERR_IO_PENDING) return rv; + + rv = socket_handler_ipv6_.Send(buffer, size); + if (rv < OK && rv != ERR_IO_PENDING) return rv; + + return OK; +} + +void MDnsConnection::OnError(SocketHandler* loop, + int error) { + // TODO(noamsml): Specific handling of intermittent errors that can be handled + // in the connection. + delegate_->OnConnectionError(error); +} + +IPEndPoint MDnsConnection::GetMDnsIPEndPoint(const char* address) { + IPAddressNumber multicast_group_number; + bool success = ParseIPLiteralToNumber(address, + &multicast_group_number); + DCHECK(success); + return IPEndPoint(multicast_group_number, + dns_protocol::kDefaultPortMulticast); +} + +void MDnsConnection::OnDatagramReceived( + DnsResponse* response, + const IPEndPoint& recv_addr, + int bytes_read) { + // TODO(noamsml): More sophisticated error handling. + DCHECK_GT(bytes_read, 0); + delegate_->HandlePacket(response, bytes_read); +} + +class MDnsConnectionSocketFactoryImpl + : public MDnsConnection::SocketFactory { + public: + MDnsConnectionSocketFactoryImpl(); + virtual ~MDnsConnectionSocketFactoryImpl(); + + virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE; +}; + +MDnsConnectionSocketFactoryImpl::MDnsConnectionSocketFactoryImpl() { +} + +MDnsConnectionSocketFactoryImpl::~MDnsConnectionSocketFactoryImpl() { +} + +scoped_ptr<DatagramServerSocket> +MDnsConnectionSocketFactoryImpl::CreateSocket() { + return scoped_ptr<DatagramServerSocket>(new UDPServerSocket( + NULL, NetLog::Source())); +} + +// static +scoped_ptr<MDnsConnection::SocketFactory> +MDnsConnection::SocketFactory::CreateDefault() { + return scoped_ptr<MDnsConnection::SocketFactory>( + new MDnsConnectionSocketFactoryImpl); +} + +MDnsClientImpl::Core::Core(MDnsClientImpl* client, + MDnsConnection::SocketFactory* socket_factory) + : client_(client), connection_(new MDnsConnection(socket_factory, this)) { +} + +MDnsClientImpl::Core::~Core() { + STLDeleteValues(&listeners_); +} + +bool MDnsClientImpl::Core::Init() { + return connection_->Init() == OK; +} + +bool MDnsClientImpl::Core::SendQuery(uint16 rrtype, std::string name) { + std::string name_dns; + if (!DNSDomainFromDot(name, &name_dns)) + return false; + + DnsQuery query(0, name_dns, rrtype); + query.set_flags(0); // Remove the RD flag from the query. It is unneeded. + + return connection_->Send(query.io_buffer(), query.io_buffer()->size()) == OK; +} + +void MDnsClientImpl::Core::HandlePacket(DnsResponse* response, + int bytes_read) { + unsigned offset; + + if (!response->InitParseWithoutQuery(bytes_read)) { + LOG(WARNING) << "Could not understand an mDNS packet."; + return; // Message is unreadable. + } + + // TODO(noamsml): duplicate query suppression. + if (!(response->flags() & dns_protocol::kFlagResponse)) + return; // Message is a query. ignore it. + + DnsRecordParser parser = response->Parser(); + unsigned answer_count = response->answer_count() + + response->additional_answer_count(); + + for (unsigned i = 0; i < answer_count; i++) { + offset = parser.GetOffset(); + scoped_ptr<const RecordParsed> scoped_record = RecordParsed::CreateFrom( + &parser, base::Time::Now()); + + if (!scoped_record) { + LOG(WARNING) << "Could not understand an mDNS record."; + + if (offset == parser.GetOffset()) { + LOG(WARNING) << "Abandoned parsing the rest of the packet."; + return; // The parser did not advance, abort reading the packet. + } else { + continue; // We may be able to extract other records from the packet. + } + } + + if ((scoped_record->klass() & dns_protocol::kMDnsClassMask) != + dns_protocol::kClassIN) { + LOG(WARNING) << "Received an mDNS record with non-IN class. Ignoring."; + continue; // Ignore all records not in the IN class. + } + + // We want to retain a copy of the record pointer for updating listeners + // but we are passing ownership to the cache. + const RecordParsed* record = scoped_record.get(); + MDnsCache::UpdateType update = cache_.UpdateDnsRecord(scoped_record.Pass()); + + // Cleanup time may have changed. + ScheduleCleanup(cache_.next_expiration()); + + if (update != MDnsCache::NoChange) { + MDnsListener::UpdateType update_external; + + switch (update) { + case MDnsCache::RecordAdded: + update_external = MDnsListener::RECORD_ADDED; + break; + case MDnsCache::RecordChanged: + update_external = MDnsListener::RECORD_CHANGED; + break; + case MDnsCache::NoChange: + default: + NOTREACHED(); + // Dummy assignment to suppress compiler warning. + update_external = MDnsListener::RECORD_CHANGED; + break; + } + + AlertListeners(update_external, + ListenerKey(record->type(), record->name()), record); + // Alert listeners listening only for rrtype and not for name. + AlertListeners(update_external, ListenerKey(record->type(), ""), record); + } + } +} + +void MDnsClientImpl::Core::OnConnectionError(int error) { + // TODO(noamsml): On connection error, recreate connection and flush cache. +} + +void MDnsClientImpl::Core::AlertListeners( + MDnsListener::UpdateType update_type, + const ListenerKey& key, + const RecordParsed* record) { + ListenerMap::iterator listener_map_iterator = listeners_.find(key); + if (listener_map_iterator == listeners_.end()) return; + + FOR_EACH_OBSERVER(MDnsListenerImpl, *listener_map_iterator->second, + AlertDelegate(update_type, record)); +} + +void MDnsClientImpl::Core::AddListener( + MDnsListenerImpl* listener) { + ListenerKey key(listener->GetType(), listener->GetName()); + std::pair<ListenerMap::iterator, bool> observer_insert_result = + listeners_.insert( + make_pair(key, static_cast<ObserverList<MDnsListenerImpl>*>(NULL))); + + // If an equivalent key does not exist, actually create the observer list. + if (observer_insert_result.second) + observer_insert_result.first->second = new ObserverList<MDnsListenerImpl>(); + + ObserverList<MDnsListenerImpl>* observer_list = + observer_insert_result.first->second; + + observer_list->AddObserver(listener); +} + +void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) { + ListenerKey key(listener->GetType(), listener->GetName()); + ListenerMap::iterator observer_list_iterator = listeners_.find(key); + + DCHECK(observer_list_iterator != listeners_.end()); + DCHECK(observer_list_iterator->second->HasObserver(listener)); + + observer_list_iterator->second->RemoveObserver(listener); + + // Remove the observer list from the map if it is empty + if (observer_list_iterator->second->size() == 0) { + delete observer_list_iterator->second; + listeners_.erase(observer_list_iterator); + } +} + +void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) { + // Cleanup is already scheduled, no need to do anything. + if (cleanup == scheduled_cleanup_) return; + scheduled_cleanup_ = cleanup; + + // This cancels the previously scheduled cleanup. + cleanup_callback_.Reset(base::Bind( + &MDnsClientImpl::Core::DoCleanup, base::Unretained(this))); + + // If |cleanup| is empty, then no cleanup necessary. + if (cleanup != base::Time()) { + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + cleanup_callback_.callback(), + cleanup - base::Time::Now()); + } +} + +void MDnsClientImpl::Core::DoCleanup() { + cache_.CleanupRecords(base::Time::Now(), base::Bind( + &MDnsClientImpl::Core::OnRecordRemoved, base::Unretained(this))); + + ScheduleCleanup(cache_.next_expiration()); +} + +void MDnsClientImpl::Core::OnRecordRemoved( + const RecordParsed* record) { + AlertListeners(MDnsListener::RECORD_REMOVED, + ListenerKey(record->type(), record->name()), record); + // Alert listeners listening only for rrtype and not for name. + AlertListeners(MDnsListener::RECORD_REMOVED, ListenerKey(record->type(), ""), + record); +} + +void MDnsClientImpl::Core::QueryCache( + uint16 rrtype, const std::string& name, + std::vector<const RecordParsed*>* records) const { + cache_.FindDnsRecords(rrtype, name, records, base::Time::Now()); +} + +MDnsClientImpl::MDnsClientImpl( + scoped_ptr<MDnsConnection::SocketFactory> socket_factory) + : listen_refs_(0), socket_factory_(socket_factory.Pass()) { +} + +MDnsClientImpl::~MDnsClientImpl() { +} + +bool MDnsClientImpl::AddListenRef() { + if (!core_.get()) { + core_.reset(new Core(this, socket_factory_.get())); + if (!core_->Init()) { + core_.reset(); + return false; + } + } + listen_refs_++; + return true; +} + +void MDnsClientImpl::SubtractListenRef() { + listen_refs_--; + if (listen_refs_ == 0) { + base::MessageLoop::current()->PostTask(FROM_HERE, base::Bind( + &MDnsClientImpl::Shutdown, base::Unretained(this))); + } +} + +void MDnsClientImpl::Shutdown() { + // We need to check that new listeners haven't been created. + if (listen_refs_ == 0) { + core_.reset(); + } +} + +bool MDnsClientImpl::IsListeningForTests() { + return core_.get() != NULL; +} + +scoped_ptr<MDnsListener> MDnsClientImpl::CreateListener( + uint16 rrtype, + const std::string& name, + MDnsListener::Delegate* delegate) { + return scoped_ptr<net::MDnsListener>( + new MDnsListenerImpl(rrtype, name, delegate, this)); +} + +scoped_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction( + uint16 rrtype, + const std::string& name, + int flags, + const MDnsTransaction::ResultCallback& callback) { + return scoped_ptr<MDnsTransaction>( + new MDnsTransactionImpl(rrtype, name, flags, callback, this)); +} + +MDnsListenerImpl::MDnsListenerImpl( + uint16 rrtype, + const std::string& name, + MDnsListener::Delegate* delegate, + MDnsClientImpl* client) + : rrtype_(rrtype), name_(name), client_(client), delegate_(delegate), + started_(false) { +} + +bool MDnsListenerImpl::Start() { + DCHECK(!started_); + + if (!client_->AddListenRef()) return false; + started_ = true; + + DCHECK(client_->core()); + client_->core()->AddListener(this); + + return true; +} + +MDnsListenerImpl::~MDnsListenerImpl() { + if (started_) { + DCHECK(client_->core()); + client_->core()->RemoveListener(this); + client_->SubtractListenRef(); + } +} + +const std::string& MDnsListenerImpl::GetName() const { + return name_; +} + +uint16 MDnsListenerImpl::GetType() const { + return rrtype_; +} + +void MDnsListenerImpl::AlertDelegate(MDnsListener::UpdateType update_type, + const RecordParsed* record) { + DCHECK(started_); + delegate_->OnRecordUpdate(update_type, record); +} + +MDnsTransactionImpl::MDnsTransactionImpl( + uint16 rrtype, + const std::string& name, + int flags, + const MDnsTransaction::ResultCallback& callback, + MDnsClientImpl* client) + : rrtype_(rrtype), name_(name), callback_(callback), client_(client), + started_(false), flags_(flags) { + DCHECK((flags_ & MDnsTransaction::FLAG_MASK) == flags_); + DCHECK(flags_ & MDnsTransaction::QUERY_CACHE || + flags_ & MDnsTransaction::QUERY_NETWORK); +} + +MDnsTransactionImpl::~MDnsTransactionImpl() { + timeout_.Cancel(); +} + +bool MDnsTransactionImpl::Start() { + DCHECK(!started_); + started_ = true; + std::vector<const RecordParsed*> records; + base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr(); + + if (flags_ & MDnsTransaction::QUERY_CACHE) { + if (client_->core()) { + client_->core()->QueryCache(rrtype_, name_, &records); + for (std::vector<const RecordParsed*>::iterator i = records.begin(); + i != records.end() && weak_this; ++i) { + weak_this->TriggerCallback(MDnsTransaction::RESULT_RECORD, + records.front()); + } + } + } + + if (!weak_this) return true; + + if (is_active() && (flags_ & MDnsTransaction::QUERY_NETWORK)) { + listener_ = client_->CreateListener(rrtype_, name_, this); + if (!listener_->Start()) return false; + + DCHECK(client_->core()); + if (!client_->core()->SendQuery(rrtype_, name_)) + return false; + + timeout_.Reset(base::Bind(&MDnsTransactionImpl::SignalTransactionOver, + weak_this)); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + timeout_.callback(), + base::TimeDelta::FromSeconds(MDnsTransactionTimeoutSeconds)); + + return listener_.get() != NULL; + } else { + // If this is a cache only query, signal that the transaction is over + // immediately. + SignalTransactionOver(); + } + + return true; +} + +const std::string& MDnsTransactionImpl::GetName() const { + return name_; +} + +uint16 MDnsTransactionImpl::GetType() const { + return rrtype_; +} + +void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) { + DCHECK(started_); + OnRecordUpdate(MDnsListener::RECORD_ADDED, record); +} + +void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result, + const RecordParsed* record) { + DCHECK(started_); + if (!is_active()) return; + + // Ensure callback is run after touching all class state, so that + // the callback can delete the transaction. + MDnsTransaction::ResultCallback callback = callback_; + + if (flags_ & MDnsTransaction::SINGLE_RESULT) + Reset(); + + callback.Run(result, record); +} + +void MDnsTransactionImpl::Reset() { + callback_.Reset(); + listener_.reset(); + timeout_.Cancel(); +} + +void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update, + const RecordParsed* record) { + DCHECK(started_); + if (update == MDnsListener::RECORD_ADDED || + update == MDnsListener::RECORD_CHANGED) + TriggerCallback(MDnsTransaction::RESULT_RECORD, record); +} + +void MDnsTransactionImpl::SignalTransactionOver() { + DCHECK(started_); + base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr(); + + if (flags_ & MDnsTransaction::SINGLE_RESULT) { + TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS, NULL); + } else { + TriggerCallback(MDnsTransaction::RESULT_DONE, NULL); + } + + if (weak_this) { + weak_this->Reset(); + } +} + +void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) { + // TODO(noamsml): NSEC records not yet implemented +} + +void MDnsTransactionImpl::OnCachePurged() { + // TODO(noamsml): Cache purge situations not yet implemented +} + +} // namespace net diff --git a/net/dns/mdns_client_impl.h b/net/dns/mdns_client_impl.h new file mode 100644 index 0000000..5a22894 --- /dev/null +++ b/net/dns/mdns_client_impl.h @@ -0,0 +1,288 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_MDNS_CLIENT_IMPL_H_ +#define NET_DNS_MDNS_CLIENT_IMPL_H_ + +#include <map> +#include <string> +#include <utility> +#include <vector> + +#include "base/cancelable_callback.h" +#include "base/observer_list.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/dns/mdns_cache.h" +#include "net/dns/mdns_client.h" +#include "net/udp/datagram_server_socket.h" +#include "net/udp/udp_server_socket.h" +#include "net/udp/udp_socket.h" + +namespace net { + +// A connection to the network for multicast DNS clients. It reads data into +// DnsResponse objects and alerts the delegate that a packet has been received. +class MDnsConnection { + public: + class SocketFactory { + public: + virtual ~SocketFactory() {} + + virtual scoped_ptr<DatagramServerSocket> CreateSocket() = 0; + + static scoped_ptr<SocketFactory> CreateDefault(); + }; + + class Delegate { + public: + // Handle an mDNS packet buffered in |response| with a size of |bytes_read|. + virtual void HandlePacket(DnsResponse* response, int bytes_read) = 0; + virtual void OnConnectionError(int error) = 0; + virtual ~Delegate() {} + }; + + explicit MDnsConnection(SocketFactory* socket_factory, + MDnsConnection::Delegate* delegate); + + virtual ~MDnsConnection(); + + int Init(); + int Send(IOBuffer* buffer, unsigned size); + + private: + class SocketHandler { + public: + SocketHandler(MDnsConnection* connection, + const IPEndPoint& multicast_addr, + SocketFactory* socket_factory); + ~SocketHandler(); + int DoLoop(int rv); + int Start(); + + int Send(IOBuffer* buffer, unsigned size); + + private: + int BindSocket(); + void OnDatagramReceived(int rv); + + // Callback for when sending a query has finished. + void SendDone(int rv); + + scoped_ptr<DatagramServerSocket> socket_; + + MDnsConnection* connection_; + IPEndPoint recv_addr_; + scoped_ptr<DnsResponse> response_; + IPEndPoint multicast_addr_; + }; + + // Callback for handling a datagram being received on either ipv4 or ipv6. + void OnDatagramReceived(DnsResponse* response, + const IPEndPoint& recv_addr, + int bytes_read); + + void OnError(SocketHandler* loop, int error); + + IPEndPoint GetMDnsIPEndPoint(const char* address); + + SocketHandler socket_handler_ipv4_; + SocketHandler socket_handler_ipv6_; + + Delegate* delegate_; + + DISALLOW_COPY_AND_ASSIGN(MDnsConnection); +}; + +class MDnsListenerImpl; + +class MDnsClientImpl : public MDnsClient { + public: + // The core object exists while the MDnsClient is listening, and is + // deleted whenever the number of listeners reaches zero. + class Core : public base::SupportsWeakPtr<Core>, MDnsConnection::Delegate { + public: + Core(MDnsClientImpl* client, + MDnsConnection::SocketFactory* socket_factory); + virtual ~Core(); + + // Initialize the core. Returns true on success. + bool Init(); + + // Send a query with a specific rrtype and name. Returns true on success. + bool SendQuery(uint16 rrtype, std::string name); + + // Add/remove a listener to the list of listener. May cause network traffic + // if listener is active. + void AddListener(MDnsListenerImpl* listener); + void RemoveListener(MDnsListenerImpl* listener); + + // Query the cache for records of a specific type and name. + void QueryCache(uint16 rrtype, const std::string& name, + std::vector<const RecordParsed*>* records) const; + + // Parse the response and alert relevant listeners. + virtual void HandlePacket(DnsResponse* response, int bytes_read) OVERRIDE; + + virtual void OnConnectionError(int error) OVERRIDE; + + private: + typedef std::pair<uint16, std::string> ListenerKey; + typedef std::map<ListenerKey, ObserverList<MDnsListenerImpl>* > + ListenerMap; + + // Alert listeners of an update to the cache. + void AlertListeners(MDnsListener::UpdateType update_type, + const ListenerKey& key, const RecordParsed* record); + + // Schedule a cache cleanup to a specific time, cancelling other cleanups. + void ScheduleCleanup(base::Time cleanup); + + // Clean up the cache and schedule a new cleanup. + void DoCleanup(); + + // Callback for when a record is removed from the cache. + void OnRecordRemoved(const RecordParsed* record); + + ListenerMap listeners_; + + MDnsClientImpl* client_; + MDnsCache cache_; + + base::CancelableCallback<void()> cleanup_callback_; + base::Time scheduled_cleanup_; + + scoped_ptr<MDnsConnection> connection_; + + DISALLOW_COPY_AND_ASSIGN(Core); + }; + + explicit MDnsClientImpl( + scoped_ptr<MDnsConnection::SocketFactory> socket_factory_); + virtual ~MDnsClientImpl(); + + // MDnsClient implementation: + virtual scoped_ptr<MDnsListener> CreateListener( + uint16 rrtype, + const std::string& name, + MDnsListener::Delegate* delegate) OVERRIDE; + + virtual scoped_ptr<MDnsTransaction> CreateTransaction( + uint16 rrtype, + const std::string& name, + int flags, + const MDnsTransaction::ResultCallback& callback) OVERRIDE; + + // Returns true when the client is listening for network packets. + bool IsListeningForTests(); + + bool AddListenRef(); + void SubtractListenRef(); + + Core* core() { return core_.get(); } + + private: + // This method causes the client to stop listening for packets. The + // call for it is deferred through the message loop after the last + // listener is removed. If another listener is added after a + // shutdown is scheduled but before it actually runs, the shutdown + // will be canceled. + void Shutdown(); + + scoped_ptr<Core> core_; + int listen_refs_; + + scoped_ptr<MDnsConnection::SocketFactory> socket_factory_; + + DISALLOW_COPY_AND_ASSIGN(MDnsClientImpl); +}; + +class MDnsListenerImpl : public MDnsListener, + public base::SupportsWeakPtr<MDnsListenerImpl> { + public: + MDnsListenerImpl(uint16 rrtype, + const std::string& name, + MDnsListener::Delegate* delegate, + MDnsClientImpl* client); + + virtual ~MDnsListenerImpl(); + + // MDnsListener implementation: + virtual bool Start() OVERRIDE; + + virtual const std::string& GetName() const OVERRIDE; + + virtual uint16 GetType() const OVERRIDE; + + MDnsListener::Delegate* delegate() { return delegate_; } + + // Alert the delegate of a record update. + void AlertDelegate(MDnsListener::UpdateType update_type, + const RecordParsed* record_parsed); + private: + uint16 rrtype_; + std::string name_; + MDnsClientImpl* client_; + MDnsListener::Delegate* delegate_; + + bool started_; + DISALLOW_COPY_AND_ASSIGN(MDnsListenerImpl); +}; + +class MDnsTransactionImpl : public base::SupportsWeakPtr<MDnsTransactionImpl>, + public MDnsTransaction, + public MDnsListener::Delegate { + public: + MDnsTransactionImpl(uint16 rrtype, + const std::string& name, + int flags, + const MDnsTransaction::ResultCallback& callback, + MDnsClientImpl* client); + virtual ~MDnsTransactionImpl(); + + // MDnsTransaction implementation: + virtual bool Start() OVERRIDE; + + virtual const std::string& GetName() const OVERRIDE; + virtual uint16 GetType() const OVERRIDE; + + // MDnsListener::Delegate implementation: + virtual void OnRecordUpdate(MDnsListener::UpdateType update, + const RecordParsed* record) OVERRIDE; + virtual void OnNsecRecord(const std::string& name, unsigned type) OVERRIDE; + + virtual void OnCachePurged() OVERRIDE; + + private: + bool is_active() { return !callback_.is_null(); } + + void Reset(); + + // Trigger the callback and reset all related variables. + void TriggerCallback(MDnsTransaction::Result result, + const RecordParsed* record); + + // Internal callback for when a cache record is found. + void CacheRecordFound(const RecordParsed* record); + + // Signal the transactionis over and release all related resources. + void SignalTransactionOver(); + + uint16 rrtype_; + std::string name_; + MDnsTransaction::ResultCallback callback_; + + scoped_ptr<MDnsListener> listener_; + base::CancelableCallback<void()> timeout_; + + MDnsClientImpl* client_; + + bool started_; + int flags_; + + DISALLOW_COPY_AND_ASSIGN(MDnsTransactionImpl); +}; + +} // namespace net +#endif // NET_DNS_MDNS_CLIENT_IMPL_H_ diff --git a/net/dns/mdns_client_unittest.cc b/net/dns/mdns_client_unittest.cc new file mode 100644 index 0000000..8f0b45a --- /dev/null +++ b/net/dns/mdns_client_unittest.cc @@ -0,0 +1,1070 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <queue> + +#include "base/memory/ref_counted.h" +#include "base/message_loop.h" +#include "net/base/rand_callback.h" +#include "net/base/test_completion_callback.h" +#include "net/dns/mdns_client_impl.h" +#include "net/dns/record_rdata.h" +#include "net/udp/udp_client_socket.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +using ::testing::Invoke; +using ::testing::InvokeWithoutArgs; +using ::testing::StrictMock; +using ::testing::NiceMock; +using ::testing::Exactly; +using ::testing::Return; +using ::testing::SaveArg; +using ::testing::_; + +namespace net { + +namespace { + +const char kSamplePacket1[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x02, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, // TTL (4 bytes) is 1 second; + 0x00, 0x01, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x0c, + + // Answer 2 + 0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r', + 0xc0, 0x14, // Pointer to "._tcp.local" + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 49 seconds. + 0x24, 0x75, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x32 +}; + +const char kCorruptedPacketBadQuestion[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x01, // One question + 0x00, 0x02, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Question is corrupted and cannot be read. + 0x99, 'h', 'e', 'l', 'l', 'o', + 0x00, + 0x00, 0x00, + 0x00, 0x00, + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x99, // RDLENGTH is impossible + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x0c, + + // Answer 2 + 0x08, '_', 'p', 'r', // Useless trailing data. +}; + +const char kCorruptedPacketUnsalvagable[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x02, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x99, // RDLENGTH is impossible + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x0c, + + // Answer 2 + 0x08, '_', 'p', 'r', // Useless trailing data. +}; + +const char kCorruptedPacketSalvagable[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x02, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x99, 'h', 'e', 'l', 'l', 'o', // Bad RDATA format. + 0xc0, 0x0c, + + // Answer 2 + 0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r', + 0xc0, 0x14, // Pointer to "._tcp.local" + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 49 seconds. + 0x24, 0x75, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x32 +}; + +const char kSamplePacket2[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x02, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'z', 'z', 'z', 'z', 'z', + 0xc0, 0x0c, + + // Answer 2 + 0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r', + 0xc0, 0x14, // Pointer to "._tcp.local" + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'z', 'z', 'z', 'z', 'z', + 0xc0, 0x32 +}; + +const char kQueryPacketPrivet[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x00, 0x00, // No flags. + 0x00, 0x01, // One question. + 0x00, 0x00, // 0 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Question + // This part is echoed back from the respective query. + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. +}; + +const char kSamplePacketAdditionalOnly[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x00, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x01, // 0 additional RRs + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x0c, +}; + +class MockDatagramServerSocket : public DatagramServerSocket { + public: + // DatagramServerSocket implementation: + int Listen(const IPEndPoint& address) { + return ListenInternal(address.ToString()); + } + + MOCK_METHOD1(ListenInternal, int(const std::string& address)); + + MOCK_METHOD4(RecvFrom, int(IOBuffer* buffer, int size, + IPEndPoint* address, + const CompletionCallback& callback)); + + int SendTo(IOBuffer* buf, int buf_len, const IPEndPoint& address, + const CompletionCallback& callback) { + return SendToInternal(std::string(buf->data(), buf_len), address.ToString(), + callback); + } + + MOCK_METHOD3(SendToInternal, int(const std::string& packet, + const std::string address, + const CompletionCallback& callback)); + + MOCK_METHOD1(SetReceiveBufferSize, bool(int32 size)); + MOCK_METHOD1(SetSendBufferSize, bool(int32 size)); + + MOCK_METHOD0(Close, void()); + + MOCK_CONST_METHOD1(GetPeerAddress, int(IPEndPoint* address)); + MOCK_CONST_METHOD1(GetLocalAddress, int(IPEndPoint* address)); + MOCK_CONST_METHOD0(NetLog, const BoundNetLog&()); + + MOCK_METHOD0(AllowAddressReuse, void()); + MOCK_METHOD0(AllowBroadcast, void()); + + int JoinGroup(const IPAddressNumber& group_address) const { + return JoinGroupInternal(IPAddressToString(group_address)); + } + + MOCK_CONST_METHOD1(JoinGroupInternal, int(const std::string& group)); + + int LeaveGroup(const IPAddressNumber& group_address) const { + return LeaveGroupInternal(IPAddressToString(group_address)); + } + + MOCK_CONST_METHOD1(LeaveGroupInternal, int(const std::string& group)); + + MOCK_METHOD1(SetMulticastTimeToLive, int(int ttl)); + + MOCK_METHOD1(SetMulticastLoopbackMode, int(bool loopback)); + + void SetResponsePacket(std::string response_packet) { + response_packet_ = response_packet; + } + + int HandleRecvNow(IOBuffer* buffer, int size, IPEndPoint* address, + const CompletionCallback& callback) { + int size_returned = + std::min(response_packet_.size(), static_cast<size_t>(size)); + memcpy(buffer->data(), response_packet_.data(), size_returned); + return size_returned; + } + + int HandleRecvLater(IOBuffer* buffer, int size, IPEndPoint* address, + const CompletionCallback& callback) { + int rv = HandleRecvNow(buffer, size, address, callback); + base::MessageLoop::current()->PostTask(FROM_HERE, base::Bind(callback, rv)); + return ERR_IO_PENDING; + } + + private: + std::string response_packet_; +}; + +class MockSocketFactory + : public MDnsConnection::SocketFactory { + public: + MockSocketFactory() { + } + + virtual ~MockSocketFactory() { + } + + virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE { + scoped_ptr<MockDatagramServerSocket> new_socket( + new NiceMock<MockDatagramServerSocket>); + + ON_CALL(*new_socket, SendToInternal(_, _, _)) + .WillByDefault(Invoke( + this, + &MockSocketFactory::SendToInternal)); + + ON_CALL(*new_socket, RecvFrom(_, _, _, _)) + .WillByDefault(Invoke( + this, + &MockSocketFactory::RecvFromInternal)); + + return new_socket.PassAs<DatagramServerSocket>(); + } + + void SimulateReceive(const char* packet, int size) { + DCHECK(recv_buffer_size_ >= size); + DCHECK(recv_buffer_.get()); + DCHECK(!recv_callback_.is_null()); + + memcpy(recv_buffer_->data(), packet, size); + CompletionCallback recv_callback = recv_callback_; + recv_callback_.Reset(); + recv_callback.Run(size); + } + + MOCK_METHOD1(OnSendTo, void(const std::string&)); + + private: + int SendToInternal(const std::string& packet, const std::string& address, + const CompletionCallback& callback) { + OnSendTo(packet); + return packet.size(); + } + + // The latest receive callback is always saved, since the MDnsConnection + // does not care which socket a packet is received on. + int RecvFromInternal(IOBuffer* buffer, int size, + IPEndPoint* address, + const CompletionCallback& callback) { + recv_buffer_ = buffer; + recv_buffer_size_ = size; + recv_callback_ = callback; + return ERR_IO_PENDING; + } + + scoped_refptr<IOBuffer> recv_buffer_; + int recv_buffer_size_; + CompletionCallback recv_callback_; +}; + +class PtrRecordCopyContainer { + public: + PtrRecordCopyContainer() {} + ~PtrRecordCopyContainer() {} + + bool is_set() const { return set_; } + + void SaveWithDummyArg(int unused, const RecordParsed* value) { + Save(value); + } + + void Save(const RecordParsed* value) { + set_ = true; + name_ = value->name(); + ptrdomain_ = value->rdata<PtrRecordRdata>()->ptrdomain(); + ttl_ = value->ttl(); + } + + bool IsRecordWith(std::string name, std::string ptrdomain) { + return set_ && name_ == name && ptrdomain_ == ptrdomain; + } + + const std::string& name() { return name_; } + const std::string& ptrdomain() { return ptrdomain_; } + int ttl() { return ttl_; } + + private: + bool set_; + std::string name_; + std::string ptrdomain_; + int ttl_; +}; + +class MDnsTest : public ::testing::Test { + public: + MDnsTest(); + virtual ~MDnsTest(); + void DeleteTransaction(); + void DeleteBothListeners(); + void RunUntilIdle(); + void RunFor(base::TimeDelta time_period); + void Stop(); + + MOCK_METHOD2(MockableRecordCallback, void(MDnsTransaction::Result result, + const RecordParsed* record)); + + protected: + void ExpectPacket(const char* packet, unsigned size); + void SimulatePacketReceive(const char* packet, unsigned size); + + base::MessageLoop* message_loop_; + + scoped_ptr<MDnsClientImpl> test_client_; + IPEndPoint mdns_ipv4_endpoint_; + StrictMock<MockSocketFactory>* socket_factory_; + + // Transactions and listeners that can be deleted by class methods for + // reentrancy tests. + scoped_ptr<MDnsTransaction> transaction_; + scoped_ptr<MDnsListener> listener1_; + scoped_ptr<MDnsListener> listener2_; +}; + +class MockListenerDelegate : public MDnsListener::Delegate { + public: + MOCK_METHOD2(OnRecordUpdate, + void(MDnsListener::UpdateType update, + const RecordParsed* records)); + MOCK_METHOD2(OnNsecRecord, void(const std::string&, unsigned)); + MOCK_METHOD0(OnCachePurged, void()); +}; + +MDnsTest::MDnsTest() + : message_loop_(base::MessageLoop::current()) { + socket_factory_ = new StrictMock<MockSocketFactory>(); + test_client_.reset(new MDnsClientImpl( + scoped_ptr<MDnsConnection::SocketFactory>(socket_factory_))); +} + +MDnsTest::~MDnsTest() { +} + +void MDnsTest::SimulatePacketReceive(const char* packet, unsigned size) { + socket_factory_->SimulateReceive(packet, size); +} + +void MDnsTest::ExpectPacket( + const char* packet, + unsigned size) { + EXPECT_CALL(*socket_factory_, OnSendTo(std::string(packet, size))) + .Times(2); +} + +void MDnsTest::DeleteTransaction() { + transaction_.reset(); +} + +void MDnsTest::DeleteBothListeners() { + listener1_.reset(); + listener2_.reset(); +} + +void MDnsTest::RunUntilIdle() { + base::MessageLoop::current()->RunUntilIdle(); +} + +void MDnsTest::RunFor(base::TimeDelta time_period) { + base::CancelableCallback<void()> callback(base::Bind(&MDnsTest::Stop, + base::Unretained(this))); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, callback.callback(), time_period); + + base::MessageLoop::current()->Run(); + callback.Cancel(); +} + +void MDnsTest::Stop() { + base::MessageLoop::current()->Quit(); +} + +TEST_F(MDnsTest, PassiveListeners) { + StrictMock<MockListenerDelegate> delegate_privet; + StrictMock<MockListenerDelegate> delegate_printer; + StrictMock<MockListenerDelegate> delegate_ptr; + + PtrRecordCopyContainer record_privet; + PtrRecordCopyContainer record_printer; + + scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener( + dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet); + scoped_ptr<MDnsListener> listener_printer = test_client_->CreateListener( + dns_protocol::kTypePTR, "_printer._tcp.local", &delegate_printer); + scoped_ptr<MDnsListener> listener_ptr = test_client_->CreateListener( + dns_protocol::kTypePTR, "", &delegate_ptr); + + ASSERT_TRUE(listener_privet->Start()); + ASSERT_TRUE(listener_printer->Start()); + ASSERT_TRUE(listener_ptr->Start()); + + ASSERT_TRUE(test_client_->IsListeningForTests()); + + // Send the same packet twice to ensure no records are double-counted. + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(Invoke( + &record_privet, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(Invoke( + &record_printer, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + EXPECT_CALL(delegate_ptr, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(2)); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + + EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); + + EXPECT_TRUE(record_printer.IsRecordWith("_printer._tcp.local", + "hello._printer._tcp.local")); + + listener_privet.reset(); + listener_printer.reset(); + + ASSERT_TRUE(test_client_->IsListeningForTests()); + + EXPECT_CALL(delegate_ptr, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(2)); + + SimulatePacketReceive(kSamplePacket2, sizeof(kSamplePacket2)); + + // Test to make sure mdns listener is not active with no listeners present. + listener_ptr.reset(); + + RunUntilIdle(); + + ASSERT_FALSE(test_client_->IsListeningForTests()); +} + +TEST_F(MDnsTest, PassiveListenersCacheCleanup) { + StrictMock<MockListenerDelegate> delegate_privet; + + PtrRecordCopyContainer record_privet; + PtrRecordCopyContainer record_privet2; + + scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener( + dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet); + + ASSERT_TRUE(listener_privet->Start()); + + ASSERT_TRUE(test_client_->IsListeningForTests()); + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(Invoke( + &record_privet, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + + RunUntilIdle(); + + EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); + + // Expect record is removed when its TTL expires. + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _)) + .Times(Exactly(1)) + .WillOnce(DoAll(InvokeWithoutArgs(this, &MDnsTest::Stop), + Invoke(&record_privet2, + &PtrRecordCopyContainer::SaveWithDummyArg))); + + RunFor(base::TimeDelta::FromSeconds(record_privet.ttl() + 1)); + + RunUntilIdle(); + + EXPECT_TRUE(record_privet2.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); +} + +TEST_F(MDnsTest, MalformedPacket) { + StrictMock<MockListenerDelegate> delegate_printer; + + PtrRecordCopyContainer record_printer; + + scoped_ptr<MDnsListener> listener_printer = test_client_->CreateListener( + dns_protocol::kTypePTR, "_printer._tcp.local", &delegate_printer); + + ASSERT_TRUE(listener_printer->Start()); + + ASSERT_TRUE(test_client_->IsListeningForTests()); + + EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(Invoke( + &record_printer, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + // First, send unsalvagable packet to ensure we can deal with it. + SimulatePacketReceive(kCorruptedPacketUnsalvagable, + sizeof(kCorruptedPacketUnsalvagable)); + + // Regression test: send a packet where the question cannot be read. + SimulatePacketReceive(kCorruptedPacketBadQuestion, + sizeof(kCorruptedPacketBadQuestion)); + + // Then send salvagable packet to ensure we can extract useful records. + SimulatePacketReceive(kCorruptedPacketSalvagable, + sizeof(kCorruptedPacketSalvagable)); + + RunUntilIdle(); + + EXPECT_TRUE(record_printer.IsRecordWith("_printer._tcp.local", + "hello._printer._tcp.local")); +} + +TEST_F(MDnsTest, TransactionWithEmptyCache) { + ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet)); + + scoped_ptr<MDnsTransaction> transaction_privet = + test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + ASSERT_TRUE(transaction_privet->Start()); + + EXPECT_TRUE(test_client_->IsListeningForTests()); + + PtrRecordCopyContainer record_privet; + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _)) + .Times(Exactly(1)) + .WillOnce(Invoke(&record_privet, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + + RunUntilIdle(); + + EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); + + EXPECT_FALSE(test_client_->IsListeningForTests()); +} + +TEST_F(MDnsTest, TransactionCacheOnlyNoResult) { + scoped_ptr<MDnsTransaction> transaction_privet = + test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + EXPECT_CALL(*this, + MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS, _)) + .Times(Exactly(1)); + + ASSERT_TRUE(transaction_privet->Start()); + + EXPECT_FALSE(test_client_->IsListeningForTests()); + + RunUntilIdle(); +} + +TEST_F(MDnsTest, TransactionWithCache) { + // Listener to force the client to listen + StrictMock<MockListenerDelegate> delegate_irrelevant; + scoped_ptr<MDnsListener> listener_irrelevant = test_client_->CreateListener( + dns_protocol::kTypeA, "codereview.chromium.local", + &delegate_irrelevant); + + ASSERT_TRUE(listener_irrelevant->Start()); + + EXPECT_TRUE(test_client_->IsListeningForTests()); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + + RunUntilIdle(); + + PtrRecordCopyContainer record_privet; + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _)) + .WillOnce(Invoke(&record_privet, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + scoped_ptr<MDnsTransaction> transaction_privet = + test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + ASSERT_TRUE(transaction_privet->Start()); + + RunUntilIdle(); + + EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); +} + +TEST_F(MDnsTest, AdditionalRecords) { + StrictMock<MockListenerDelegate> delegate_privet; + + PtrRecordCopyContainer record_privet; + + scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener( + dns_protocol::kTypePTR, "_privet._tcp.local", + &delegate_privet); + + ASSERT_TRUE(listener_privet->Start()); + + ASSERT_TRUE(test_client_->IsListeningForTests()); + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(Invoke( + &record_privet, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + SimulatePacketReceive(kSamplePacketAdditionalOnly, sizeof(kSamplePacket1)); + + RunUntilIdle(); + + EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); +} + +TEST_F(MDnsTest, TransactionTimeout) { + ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet)); + + scoped_ptr<MDnsTransaction> transaction_privet = + test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + ASSERT_TRUE(transaction_privet->Start()); + + EXPECT_TRUE(test_client_->IsListeningForTests()); + + EXPECT_CALL(*this, + MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS, NULL)) + .Times(Exactly(1)) + .WillOnce(InvokeWithoutArgs(this, &MDnsTest::Stop)); + + RunFor(base::TimeDelta::FromSeconds(4)); + + EXPECT_FALSE(test_client_->IsListeningForTests()); +} + +TEST_F(MDnsTest, TransactionMultipleRecords) { + ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet)); + + scoped_ptr<MDnsTransaction> transaction_privet = + test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE , + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + ASSERT_TRUE(transaction_privet->Start()); + + EXPECT_TRUE(test_client_->IsListeningForTests()); + + PtrRecordCopyContainer record_privet; + PtrRecordCopyContainer record_privet2; + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _)) + .Times(Exactly(2)) + .WillOnce(Invoke(&record_privet, + &PtrRecordCopyContainer::SaveWithDummyArg)) + .WillOnce(Invoke(&record_privet2, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + SimulatePacketReceive(kSamplePacket2, sizeof(kSamplePacket2)); + + RunUntilIdle(); + + EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); + + EXPECT_TRUE(record_privet2.IsRecordWith("_privet._tcp.local", + "zzzzz._privet._tcp.local")); + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_DONE, NULL)) + .WillOnce(InvokeWithoutArgs(this, &MDnsTest::Stop)); + + RunFor(base::TimeDelta::FromSeconds(4)); + + EXPECT_FALSE(test_client_->IsListeningForTests()); +} + +TEST_F(MDnsTest, TransactionReentrantDelete) { + ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet)); + + transaction_ = test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + ASSERT_TRUE(transaction_->Start()); + + EXPECT_TRUE(test_client_->IsListeningForTests()); + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS, + NULL)) + .Times(Exactly(1)) + .WillOnce(DoAll(InvokeWithoutArgs(this, &MDnsTest::DeleteTransaction), + InvokeWithoutArgs(this, &MDnsTest::Stop))); + + RunFor(base::TimeDelta::FromSeconds(4)); + + EXPECT_EQ(NULL, transaction_.get()); + + EXPECT_FALSE(test_client_->IsListeningForTests()); +} + +TEST_F(MDnsTest, TransactionReentrantDeleteFromCache) { + StrictMock<MockListenerDelegate> delegate_irrelevant; + scoped_ptr<MDnsListener> listener_irrelevant = test_client_->CreateListener( + dns_protocol::kTypeA, "codereview.chromium.local", + &delegate_irrelevant); + ASSERT_TRUE(listener_irrelevant->Start()); + + ASSERT_TRUE(test_client_->IsListeningForTests()); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + + transaction_ = test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _)) + .Times(Exactly(1)) + .WillOnce(InvokeWithoutArgs(this, &MDnsTest::DeleteTransaction)); + + ASSERT_TRUE(transaction_->Start()); + + RunUntilIdle(); + + EXPECT_EQ(NULL, transaction_.get()); +} + +// In order to reliably test reentrant listener deletes, we create two listeners +// and have each of them delete both, so we're guaranteed to try and deliver a +// callback to at least one deleted listener. + +TEST_F(MDnsTest, ListenerReentrantDelete) { + StrictMock<MockListenerDelegate> delegate_privet; + + listener1_ = test_client_->CreateListener( + dns_protocol::kTypePTR, "_privet._tcp.local", + &delegate_privet); + + listener2_ = test_client_->CreateListener( + dns_protocol::kTypePTR, "_privet._tcp.local", + &delegate_privet); + + ASSERT_TRUE(listener1_->Start()); + + ASSERT_TRUE(listener2_->Start()); + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(InvokeWithoutArgs(this, &MDnsTest::DeleteBothListeners)); + + EXPECT_TRUE(test_client_->IsListeningForTests()); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + + RunUntilIdle(); + + EXPECT_EQ(NULL, listener1_.get()); + EXPECT_EQ(NULL, listener2_.get()); + + EXPECT_FALSE(test_client_->IsListeningForTests()); +} + +// Note: These tests assume that the ipv4 socket will always be created first. +// This is a simplifying assumption based on the way the code works now. + +class SimpleMockSocketFactory + : public MDnsConnection::SocketFactory { + public: + SimpleMockSocketFactory() { + } + virtual ~SimpleMockSocketFactory() { + } + + virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE { + scoped_ptr<MockDatagramServerSocket> socket( + new StrictMock<MockDatagramServerSocket>); + sockets_.push(socket.get()); + return socket.PassAs<DatagramServerSocket>(); + } + + MockDatagramServerSocket* PopFirstSocket() { + MockDatagramServerSocket* socket = sockets_.front(); + sockets_.pop(); + return socket; + } + + size_t num_sockets() { + return sockets_.size(); + } + + private: + std::queue<MockDatagramServerSocket*> sockets_; +}; + +class MockMDnsConnectionDelegate : public MDnsConnection::Delegate { + public: + virtual void HandlePacket(DnsResponse* response, int size) { + HandlePacketInternal(std::string(response->io_buffer()->data(), size)); + } + + MOCK_METHOD1(HandlePacketInternal, void(std::string packet)); + + MOCK_METHOD1(OnConnectionError, void(int error)); +}; + +class MDnsConnectionTest : public ::testing::Test { + public: + MDnsConnectionTest() : connection_(&factory_, &delegate_) { + } + + protected: + // Follow successful connection initialization. + virtual void SetUp() OVERRIDE { + ASSERT_EQ(2u, factory_.num_sockets()); + + socket_ipv4_ = factory_.PopFirstSocket(); + socket_ipv6_ = factory_.PopFirstSocket(); + } + + bool InitConnection() { + EXPECT_CALL(*socket_ipv4_, AllowAddressReuse()); + EXPECT_CALL(*socket_ipv6_, AllowAddressReuse()); + + EXPECT_CALL(*socket_ipv4_, SetMulticastLoopbackMode(false)); + EXPECT_CALL(*socket_ipv6_, SetMulticastLoopbackMode(false)); + + EXPECT_CALL(*socket_ipv4_, ListenInternal("0.0.0.0:5353")) + .WillOnce(Return(OK)); + EXPECT_CALL(*socket_ipv6_, ListenInternal("[::]:5353")) + .WillOnce(Return(OK)); + + EXPECT_CALL(*socket_ipv4_, JoinGroupInternal("224.0.0.251")) + .WillOnce(Return(OK)); + EXPECT_CALL(*socket_ipv6_, JoinGroupInternal("ff02::fb")) + .WillOnce(Return(OK)); + + return connection_.Init() == OK; + } + + StrictMock<MockMDnsConnectionDelegate> delegate_; + + MockDatagramServerSocket* socket_ipv4_; + MockDatagramServerSocket* socket_ipv6_; + SimpleMockSocketFactory factory_; + MDnsConnection connection_; + TestCompletionCallback callback_; +}; + +TEST_F(MDnsConnectionTest, ReceiveSynchronous) { + std::string sample_packet = + std::string(kSamplePacket1, sizeof(kSamplePacket1)); + + socket_ipv6_->SetResponsePacket(sample_packet); + EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _)) + .WillOnce(Return(ERR_IO_PENDING)); + EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _)) + .WillOnce( + Invoke(socket_ipv6_, &MockDatagramServerSocket::HandleRecvNow)) + .WillOnce(Return(ERR_IO_PENDING)); + + EXPECT_CALL(delegate_, HandlePacketInternal(sample_packet)); + + ASSERT_TRUE(InitConnection()); +} + +TEST_F(MDnsConnectionTest, ReceiveAsynchronous) { + std::string sample_packet = + std::string(kSamplePacket1, sizeof(kSamplePacket1)); + socket_ipv6_->SetResponsePacket(sample_packet); + EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _)) + .WillOnce(Return(ERR_IO_PENDING)); + EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _)) + .WillOnce( + Invoke(socket_ipv6_, &MockDatagramServerSocket::HandleRecvLater)) + .WillOnce(Return(ERR_IO_PENDING)); + + ASSERT_TRUE(InitConnection()); + + EXPECT_CALL(delegate_, HandlePacketInternal(sample_packet)); + + base::MessageLoop::current()->RunUntilIdle(); +} + +TEST_F(MDnsConnectionTest, Send) { + std::string sample_packet = + std::string(kSamplePacket1, sizeof(kSamplePacket1)); + + scoped_refptr<IOBufferWithSize> buf( + new IOBufferWithSize(sizeof kSamplePacket1)); + memcpy(buf->data(), kSamplePacket1, sizeof(kSamplePacket1)); + + EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _)) + .WillOnce(Return(ERR_IO_PENDING)); + EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _)) + .WillOnce(Return(ERR_IO_PENDING)); + + ASSERT_TRUE(InitConnection()); + + EXPECT_CALL(*socket_ipv4_, + SendToInternal(sample_packet, "224.0.0.251:5353", _)); + EXPECT_CALL(*socket_ipv6_, + SendToInternal(sample_packet, "[ff02::fb]:5353", _)); + + connection_.Send(buf, buf->size()); +} + +TEST_F(MDnsConnectionTest, Error) { + CompletionCallback callback; + + EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _)) + .WillOnce(Return(ERR_IO_PENDING)); + EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _)) + .WillOnce(DoAll(SaveArg<3>(&callback), Return(ERR_IO_PENDING))); + + ASSERT_TRUE(InitConnection()); + + EXPECT_CALL(delegate_, OnConnectionError(ERR_SOCKET_NOT_CONNECTED)); + callback.Run(ERR_SOCKET_NOT_CONNECTED); +} + +} // namespace + +} // namespace net diff --git a/net/net.gyp b/net/net.gyp index e8f36ce..68d0114 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -447,6 +447,10 @@ 'dns/mapped_host_resolver.h', 'dns/mdns_cache.cc', 'dns/mdns_cache.h', + 'dns/mdns_client.cc', + 'dns/mdns_client.h', + 'dns/mdns_client_impl.cc', + 'dns/mdns_client_impl.h', 'dns/notify_watcher_mac.cc', 'dns/notify_watcher_mac.h', 'dns/record_parsed.cc', @@ -1314,6 +1318,10 @@ 'sources!' : [ 'dns/mdns_cache.cc', 'dns/mdns_cache.h', + 'dns/mdns_client.cc', + 'dns/mdns_client.h', + 'dns/mdns_client_impl.cc', + 'dns/mdns_client_impl.h', 'dns/record_parsed.cc', 'dns/record_parsed.h', 'dns/record_rdata.cc', @@ -1544,6 +1552,7 @@ 'dns/host_resolver_impl_unittest.cc', 'dns/mapped_host_resolver_unittest.cc', 'dns/mdns_cache_unittest.cc', + 'dns/mdns_client_unittest.cc', 'dns/serial_worker_unittest.cc', 'dns/record_parsed_unittest.cc', 'dns/record_rdata_unittest.cc', @@ -1949,6 +1958,8 @@ [ 'enable_mdns != 1', { 'sources!' : [ 'dns/mdns_cache_unittest.cc', + 'dns/mdns_client_unittest.cc', + 'dns/mdns_query_unittest.cc', 'dns/record_parsed_unittest.cc', 'dns/record_rdata_unittest.cc', ], |