summaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
authornoamsml@chromium.org <noamsml@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2013-06-13 22:31:42 +0000
committernoamsml@chromium.org <noamsml@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2013-06-13 22:31:42 +0000
commit245b164eab1be2a0e993a438c62c5733ec2a32e5 (patch)
treefa4175a36f31d8133f4ba601d1e10987aeb59c9e /net
parentd23741d1102ab02aca4123dcb55ae125e6f3c463 (diff)
downloadchromium_src-245b164eab1be2a0e993a438c62c5733ec2a32e5.zip
chromium_src-245b164eab1be2a0e993a438c62c5733ec2a32e5.tar.gz
chromium_src-245b164eab1be2a0e993a438c62c5733ec2a32e5.tar.bz2
Multicast DNS implementation (initial)
An implementation of multicast DNS in net/. Currently, the multicast DNS implementation supports the following features: - Listeners, which can be notified of any changes to records of a certain type and name, or records of a certain type. - Transactions, which allow users to query multicast DNS for a specific unique record (e.g. an address) BUG=233821 TEST=MDnsTest.*,MDnsConnectionTest.* Review URL: https://chromiumcodereview.appspot.com/15733008 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@206170 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
-rw-r--r--net/dns/dns_protocol.h5
-rw-r--r--net/dns/dns_query.cc6
-rw-r--r--net/dns/dns_query.h2
-rw-r--r--net/dns/dns_response.cc5
-rw-r--r--net/dns/dns_response.h2
-rw-r--r--net/dns/mdns_client.cc26
-rw-r--r--net/dns/mdns_client.h156
-rw-r--r--net/dns/mdns_client_impl.cc599
-rw-r--r--net/dns/mdns_client_impl.h288
-rw-r--r--net/dns/mdns_client_unittest.cc1070
-rw-r--r--net/net.gyp11
11 files changed, 2170 insertions, 0 deletions
diff --git a/net/dns/dns_protocol.h b/net/dns/dns_protocol.h
index b65f56d..4516b29 100644
--- a/net/dns/dns_protocol.h
+++ b/net/dns/dns_protocol.h
@@ -13,6 +13,7 @@ namespace net {
namespace dns_protocol {
static const uint16 kDefaultPort = 53;
+static const uint16 kDefaultPortMulticast = 5353;
// DNS packet consists of a header followed by questions and/or answers.
// For the meaning of specific fields, please see RFC 1035 and 2535
@@ -101,6 +102,10 @@ static const int kMaxNameLength = 255;
// bytes (not counting the IP nor UDP headers).
static const int kMaxUDPSize = 512;
+// RFC 6762, section 17: Messages over the local link are restricted by the
+// medium's MTU, and must be under 9000 bytes
+static const int kMaxMulticastSize = 9000;
+
// DNS class types.
static const uint16 kClassIN = 1;
diff --git a/net/dns/dns_query.cc b/net/dns/dns_query.cc
index 72e97cf..270757e 100644
--- a/net/dns/dns_query.cc
+++ b/net/dns/dns_query.cc
@@ -80,4 +80,10 @@ DnsQuery::DnsQuery(const DnsQuery& orig, uint16 id) {
header->id = base::HostToNet16(id);
}
+void DnsQuery::set_flags(uint16 flags) {
+ dns_protocol::Header* header =
+ reinterpret_cast<dns_protocol::Header*>(io_buffer_->data());
+ header->flags = flags;
+}
+
} // namespace net
diff --git a/net/dns/dns_query.h b/net/dns/dns_query.h
index a2ed868..e1469bd 100644
--- a/net/dns/dns_query.h
+++ b/net/dns/dns_query.h
@@ -38,6 +38,8 @@ class NET_EXPORT_PRIVATE DnsQuery {
// IOBuffer accessor to be used for writing out the query.
IOBufferWithSize* io_buffer() const { return io_buffer_.get(); }
+ void set_flags(uint16 flags);
+
private:
DnsQuery(const DnsQuery& orig, uint16 id);
diff --git a/net/dns/dns_response.cc b/net/dns/dns_response.cc
index 7daf5ff..d29d3c4 100644
--- a/net/dns/dns_response.cc
+++ b/net/dns/dns_response.cc
@@ -231,6 +231,11 @@ unsigned DnsResponse::answer_count() const {
return base::NetToHost16(header()->ancount);
}
+unsigned DnsResponse::additional_answer_count() const {
+ DCHECK(parser_.IsValid());
+ return base::NetToHost16(header()->arcount);
+}
+
base::StringPiece DnsResponse::qname() const {
DCHECK(parser_.IsValid());
// The response is HEADER QNAME QTYPE QCLASS ANSWER.
diff --git a/net/dns/dns_response.h b/net/dns/dns_response.h
index 4a445d9..76c2215 100644
--- a/net/dns/dns_response.h
+++ b/net/dns/dns_response.h
@@ -130,7 +130,9 @@ class NET_EXPORT_PRIVATE DnsResponse {
// Accessors for the header.
uint16 flags() const; // excluding rcode
uint8 rcode() const;
+
unsigned answer_count() const;
+ unsigned additional_answer_count() const;
// Accessors to the question. The qname is unparsed.
base::StringPiece qname() const;
diff --git a/net/dns/mdns_client.cc b/net/dns/mdns_client.cc
new file mode 100644
index 0000000..d4cf470
--- /dev/null
+++ b/net/dns/mdns_client.cc
@@ -0,0 +1,26 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/mdns_client.h"
+
+#include "net/dns/mdns_client_impl.h"
+
+namespace net {
+
+static MDnsClient* g_instance = NULL;
+
+MDnsClient* MDnsClient::GetInstance() {
+ if (!g_instance) {
+ g_instance =
+ new MDnsClientImpl(MDnsConnection::SocketFactory::CreateDefault());
+ }
+
+ return g_instance;
+}
+
+void MDnsClient::SetInstance(MDnsClient* instance) {
+ g_instance = instance;
+}
+
+} // namespace net
diff --git a/net/dns/mdns_client.h b/net/dns/mdns_client.h
new file mode 100644
index 0000000..4e4255d
--- /dev/null
+++ b/net/dns/mdns_client.h
@@ -0,0 +1,156 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_MDNS_CLIENT_H_
+#define NET_DNS_MDNS_CLIENT_H_
+
+#include <string>
+#include <vector>
+
+#include "base/callback.h"
+#include "net/dns/dns_query.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/record_parsed.h"
+
+namespace net {
+
+class RecordParsed;
+
+// Represents a one-time record lookup. A transaction takes one
+// associated callback (see |MDnsClient::CreateTransaction|) and calls it
+// whenever a matching record has been found, either from the cache or
+// by querying the network (it may choose to query either or both based on its
+// creation flags, see MDnsTransactionFlags). Network-based transactions will
+// time out after a reasonable number of seconds.
+class MDnsTransaction {
+ public:
+ // Used to signify what type of result the transaction has recieved.
+ enum Result {
+ // Passed whenever a record is found.
+ RESULT_RECORD,
+ // The transaction is done. Applies to non-single-valued transactions. Is
+ // called when the transaction has finished (this is the last call to the
+ // callback).
+ RESULT_DONE,
+ // No results have been found. Applies to single-valued transactions. Is
+ // called when the transaction has finished without finding any results.
+ // For transactions that use the network, this happens when a timeout
+ // occurs, for transactions that are cache-only, this happens when no
+ // results are in the cache.
+ RESULT_NO_RESULTS,
+ // Called when an NSec record is read for this transaction's
+ // query. This means there cannot possibly be a record of the type
+ // and name for this transaction.
+ RESULT_NSEC
+ };
+
+ // Used when creating an MDnsTransaction.
+ enum Flags {
+ // Transaction should return only one result, and stop listening after it.
+ // Note that single result transactions will signal when their timeout is
+ // reached, whereas multi-result transactions will not.
+ SINGLE_RESULT = 1 << 0,
+ // Query the cache or the network. May both be used. One must be present.
+ QUERY_CACHE = 1 << 1,
+ QUERY_NETWORK = 1 << 2,
+ // TODO(noamsml): Add flag for flushing cache when feature is implemented
+ // Mask of all possible flags on MDnsTransaction.
+ FLAG_MASK = (1 << 3) - 1,
+ };
+
+ typedef base::Callback<void(Result, const RecordParsed*)>
+ ResultCallback;
+
+ // Destroying the transaction cancels it.
+ virtual ~MDnsTransaction() {}
+
+ // Start the transaction. Return true on success. Cache-based transactions
+ // will execute the callback synchronously.
+ virtual bool Start() = 0;
+
+ // Get the host or service name for the transaction.
+ virtual const std::string& GetName() const = 0;
+
+ // Get the type for this transaction (SRV, TXT, A, AAA, etc)
+ virtual uint16 GetType() const = 0;
+};
+
+// A listener listens for updates regarding a specific record or set of records.
+// Created by the MDnsClient (see |MDnsClient::CreateListener|) and used to keep
+// track of listeners.
+class MDnsListener {
+ public:
+ // Used in the MDnsListener delegate to signify what type of change has been
+ // made to a record.
+ enum UpdateType {
+ RECORD_ADDED,
+ RECORD_CHANGED,
+ RECORD_REMOVED
+ };
+
+ class Delegate {
+ public:
+ virtual ~Delegate() {}
+
+ // Called when a record is added, removed or updated.
+ virtual void OnRecordUpdate(UpdateType update,
+ const RecordParsed* record) = 0;
+
+ // Called when a record is marked nonexistent by an NSEC record.
+ virtual void OnNsecRecord(const std::string& name, unsigned type) = 0;
+
+ // Called when the cache is purged (due, for example, ot the network
+ // disconnecting).
+ virtual void OnCachePurged() = 0;
+ };
+
+ // Destroying the listener stops listening.
+ virtual ~MDnsListener() {}
+
+ // Start the listener. Return true on success.
+ virtual bool Start() = 0;
+
+ // Get the host or service name for this query.
+ // Return an empty string for no name.
+ virtual const std::string& GetName() const = 0;
+
+ // Get the type for this query (SRV, TXT, A, AAA, etc)
+ virtual uint16 GetType() const = 0;
+};
+
+// Listens for Multicast DNS on the local network. You can access information
+// regarding multicast DNS either by creating an |MDnsListener| to be notified
+// of new records, or by creating an |MDnsTransaction| to look up the value of a
+// specific records. When all listeners and active transactions are destroyed,
+// the client stops listening on the network and destroys the cache.
+class MDnsClient {
+ public:
+ virtual ~MDnsClient() {}
+
+ // Create listener object for RRType |rrtype| and name |name|. If |name| is
+ // an empty string, listen to all notification of type |rrtype|.
+ virtual scoped_ptr<MDnsListener> CreateListener(
+ uint16 rrtype,
+ const std::string& name,
+ MDnsListener::Delegate* delegate) = 0;
+
+ // Create a transaction that can be used to query either the MDns cache, the
+ // network, or both for records of type |rrtype| and name |name|. |flags| is
+ // defined by MDnsTransactionFlags.
+ virtual scoped_ptr<MDnsTransaction> CreateTransaction(
+ uint16 rrtype,
+ const std::string& name,
+ int flags,
+ const MDnsTransaction::ResultCallback& callback) = 0;
+
+ // Lazily create and return static instance for MDnsClient.
+ static MDnsClient* GetInstance();
+
+ // Set the global instance (for testing). MUST be called before the first call
+ // to GetInstance.
+ static void SetInstance(MDnsClient* instance);
+};
+
+} // namespace net
+#endif // NET_DNS_MDNS_CLIENT_H_
diff --git a/net/dns/mdns_client_impl.cc b/net/dns/mdns_client_impl.cc
new file mode 100644
index 0000000..16852e9
--- /dev/null
+++ b/net/dns/mdns_client_impl.cc
@@ -0,0 +1,599 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/mdns_client_impl.h"
+
+#include "base/bind.h"
+#include "base/message_loop_proxy.h"
+#include "base/stl_util.h"
+#include "base/time/default_clock.h"
+#include "net/base/dns_util.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/rand_callback.h"
+#include "net/dns/dns_protocol.h"
+#include "net/udp/datagram_socket.h"
+
+namespace net {
+
+namespace {
+const char kMDnsMulticastGroupIPv4[] = "224.0.0.251";
+const char kMDnsMulticastGroupIPv6[] = "FF02::FB";
+const unsigned MDnsTransactionTimeoutSeconds = 3;
+}
+
+MDnsConnection::SocketHandler::SocketHandler(
+ MDnsConnection* connection, const IPEndPoint& multicast_addr,
+ MDnsConnection::SocketFactory* socket_factory)
+ : socket_(socket_factory->CreateSocket()), connection_(connection),
+ response_(new DnsResponse(dns_protocol::kMaxMulticastSize)),
+ multicast_addr_(multicast_addr) {
+}
+
+MDnsConnection::SocketHandler::~SocketHandler() {
+}
+
+int MDnsConnection::SocketHandler::Start() {
+ int rv = BindSocket();
+ if (rv != OK) {
+ return rv;
+ }
+
+ return DoLoop(0);
+}
+
+int MDnsConnection::SocketHandler::DoLoop(int rv) {
+ do {
+ if (rv > 0)
+ connection_->OnDatagramReceived(response_.get(), recv_addr_, rv);
+
+ rv = socket_->RecvFrom(
+ response_->io_buffer(),
+ response_->io_buffer()->size(),
+ &recv_addr_,
+ base::Bind(&MDnsConnection::SocketHandler::OnDatagramReceived,
+ base::Unretained(this)));
+ } while (rv > 0);
+
+ if (rv != ERR_IO_PENDING)
+ return rv;
+
+ return OK;
+}
+
+void MDnsConnection::SocketHandler::OnDatagramReceived(int rv) {
+ if (rv >= OK)
+ rv = DoLoop(rv);
+
+ if (rv != OK)
+ connection_->OnError(this, rv);
+}
+
+int MDnsConnection::SocketHandler::Send(IOBuffer* buffer, unsigned size) {
+ return socket_->SendTo(
+ buffer, size, multicast_addr_,
+ base::Bind(&MDnsConnection::SocketHandler::SendDone,
+ base::Unretained(this) ));
+}
+
+void MDnsConnection::SocketHandler::SendDone(int rv) {
+ // TODO(noamsml): Retry logic.
+}
+
+int MDnsConnection::SocketHandler::BindSocket() {
+ IPAddressNumber address_any(multicast_addr_.address().size());
+
+ IPEndPoint bind_endpoint(address_any, multicast_addr_.port());
+
+ socket_->AllowAddressReuse();
+ int rv = socket_->Listen(bind_endpoint);
+
+ if (rv < OK) return rv;
+
+ socket_->SetMulticastLoopbackMode(false);
+
+ return socket_->JoinGroup(multicast_addr_.address());
+}
+
+MDnsConnection::MDnsConnection(MDnsConnection::SocketFactory* socket_factory,
+ MDnsConnection::Delegate* delegate)
+ : socket_handler_ipv4_(this,
+ GetMDnsIPEndPoint(kMDnsMulticastGroupIPv4),
+ socket_factory),
+ socket_handler_ipv6_(this,
+ GetMDnsIPEndPoint(kMDnsMulticastGroupIPv6),
+ socket_factory),
+ delegate_(delegate) {
+}
+
+MDnsConnection::~MDnsConnection() {
+}
+
+int MDnsConnection::Init() {
+ int rv;
+
+ rv = socket_handler_ipv4_.Start();
+ if (rv != OK) return rv;
+ rv = socket_handler_ipv6_.Start();
+ if (rv != OK) return rv;
+
+ return OK;
+}
+
+int MDnsConnection::Send(IOBuffer* buffer, unsigned size) {
+ int rv;
+
+ rv = socket_handler_ipv4_.Send(buffer, size);
+ if (rv < OK && rv != ERR_IO_PENDING) return rv;
+
+ rv = socket_handler_ipv6_.Send(buffer, size);
+ if (rv < OK && rv != ERR_IO_PENDING) return rv;
+
+ return OK;
+}
+
+void MDnsConnection::OnError(SocketHandler* loop,
+ int error) {
+ // TODO(noamsml): Specific handling of intermittent errors that can be handled
+ // in the connection.
+ delegate_->OnConnectionError(error);
+}
+
+IPEndPoint MDnsConnection::GetMDnsIPEndPoint(const char* address) {
+ IPAddressNumber multicast_group_number;
+ bool success = ParseIPLiteralToNumber(address,
+ &multicast_group_number);
+ DCHECK(success);
+ return IPEndPoint(multicast_group_number,
+ dns_protocol::kDefaultPortMulticast);
+}
+
+void MDnsConnection::OnDatagramReceived(
+ DnsResponse* response,
+ const IPEndPoint& recv_addr,
+ int bytes_read) {
+ // TODO(noamsml): More sophisticated error handling.
+ DCHECK_GT(bytes_read, 0);
+ delegate_->HandlePacket(response, bytes_read);
+}
+
+class MDnsConnectionSocketFactoryImpl
+ : public MDnsConnection::SocketFactory {
+ public:
+ MDnsConnectionSocketFactoryImpl();
+ virtual ~MDnsConnectionSocketFactoryImpl();
+
+ virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE;
+};
+
+MDnsConnectionSocketFactoryImpl::MDnsConnectionSocketFactoryImpl() {
+}
+
+MDnsConnectionSocketFactoryImpl::~MDnsConnectionSocketFactoryImpl() {
+}
+
+scoped_ptr<DatagramServerSocket>
+MDnsConnectionSocketFactoryImpl::CreateSocket() {
+ return scoped_ptr<DatagramServerSocket>(new UDPServerSocket(
+ NULL, NetLog::Source()));
+}
+
+// static
+scoped_ptr<MDnsConnection::SocketFactory>
+MDnsConnection::SocketFactory::CreateDefault() {
+ return scoped_ptr<MDnsConnection::SocketFactory>(
+ new MDnsConnectionSocketFactoryImpl);
+}
+
+MDnsClientImpl::Core::Core(MDnsClientImpl* client,
+ MDnsConnection::SocketFactory* socket_factory)
+ : client_(client), connection_(new MDnsConnection(socket_factory, this)) {
+}
+
+MDnsClientImpl::Core::~Core() {
+ STLDeleteValues(&listeners_);
+}
+
+bool MDnsClientImpl::Core::Init() {
+ return connection_->Init() == OK;
+}
+
+bool MDnsClientImpl::Core::SendQuery(uint16 rrtype, std::string name) {
+ std::string name_dns;
+ if (!DNSDomainFromDot(name, &name_dns))
+ return false;
+
+ DnsQuery query(0, name_dns, rrtype);
+ query.set_flags(0); // Remove the RD flag from the query. It is unneeded.
+
+ return connection_->Send(query.io_buffer(), query.io_buffer()->size()) == OK;
+}
+
+void MDnsClientImpl::Core::HandlePacket(DnsResponse* response,
+ int bytes_read) {
+ unsigned offset;
+
+ if (!response->InitParseWithoutQuery(bytes_read)) {
+ LOG(WARNING) << "Could not understand an mDNS packet.";
+ return; // Message is unreadable.
+ }
+
+ // TODO(noamsml): duplicate query suppression.
+ if (!(response->flags() & dns_protocol::kFlagResponse))
+ return; // Message is a query. ignore it.
+
+ DnsRecordParser parser = response->Parser();
+ unsigned answer_count = response->answer_count() +
+ response->additional_answer_count();
+
+ for (unsigned i = 0; i < answer_count; i++) {
+ offset = parser.GetOffset();
+ scoped_ptr<const RecordParsed> scoped_record = RecordParsed::CreateFrom(
+ &parser, base::Time::Now());
+
+ if (!scoped_record) {
+ LOG(WARNING) << "Could not understand an mDNS record.";
+
+ if (offset == parser.GetOffset()) {
+ LOG(WARNING) << "Abandoned parsing the rest of the packet.";
+ return; // The parser did not advance, abort reading the packet.
+ } else {
+ continue; // We may be able to extract other records from the packet.
+ }
+ }
+
+ if ((scoped_record->klass() & dns_protocol::kMDnsClassMask) !=
+ dns_protocol::kClassIN) {
+ LOG(WARNING) << "Received an mDNS record with non-IN class. Ignoring.";
+ continue; // Ignore all records not in the IN class.
+ }
+
+ // We want to retain a copy of the record pointer for updating listeners
+ // but we are passing ownership to the cache.
+ const RecordParsed* record = scoped_record.get();
+ MDnsCache::UpdateType update = cache_.UpdateDnsRecord(scoped_record.Pass());
+
+ // Cleanup time may have changed.
+ ScheduleCleanup(cache_.next_expiration());
+
+ if (update != MDnsCache::NoChange) {
+ MDnsListener::UpdateType update_external;
+
+ switch (update) {
+ case MDnsCache::RecordAdded:
+ update_external = MDnsListener::RECORD_ADDED;
+ break;
+ case MDnsCache::RecordChanged:
+ update_external = MDnsListener::RECORD_CHANGED;
+ break;
+ case MDnsCache::NoChange:
+ default:
+ NOTREACHED();
+ // Dummy assignment to suppress compiler warning.
+ update_external = MDnsListener::RECORD_CHANGED;
+ break;
+ }
+
+ AlertListeners(update_external,
+ ListenerKey(record->type(), record->name()), record);
+ // Alert listeners listening only for rrtype and not for name.
+ AlertListeners(update_external, ListenerKey(record->type(), ""), record);
+ }
+ }
+}
+
+void MDnsClientImpl::Core::OnConnectionError(int error) {
+ // TODO(noamsml): On connection error, recreate connection and flush cache.
+}
+
+void MDnsClientImpl::Core::AlertListeners(
+ MDnsListener::UpdateType update_type,
+ const ListenerKey& key,
+ const RecordParsed* record) {
+ ListenerMap::iterator listener_map_iterator = listeners_.find(key);
+ if (listener_map_iterator == listeners_.end()) return;
+
+ FOR_EACH_OBSERVER(MDnsListenerImpl, *listener_map_iterator->second,
+ AlertDelegate(update_type, record));
+}
+
+void MDnsClientImpl::Core::AddListener(
+ MDnsListenerImpl* listener) {
+ ListenerKey key(listener->GetType(), listener->GetName());
+ std::pair<ListenerMap::iterator, bool> observer_insert_result =
+ listeners_.insert(
+ make_pair(key, static_cast<ObserverList<MDnsListenerImpl>*>(NULL)));
+
+ // If an equivalent key does not exist, actually create the observer list.
+ if (observer_insert_result.second)
+ observer_insert_result.first->second = new ObserverList<MDnsListenerImpl>();
+
+ ObserverList<MDnsListenerImpl>* observer_list =
+ observer_insert_result.first->second;
+
+ observer_list->AddObserver(listener);
+}
+
+void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) {
+ ListenerKey key(listener->GetType(), listener->GetName());
+ ListenerMap::iterator observer_list_iterator = listeners_.find(key);
+
+ DCHECK(observer_list_iterator != listeners_.end());
+ DCHECK(observer_list_iterator->second->HasObserver(listener));
+
+ observer_list_iterator->second->RemoveObserver(listener);
+
+ // Remove the observer list from the map if it is empty
+ if (observer_list_iterator->second->size() == 0) {
+ delete observer_list_iterator->second;
+ listeners_.erase(observer_list_iterator);
+ }
+}
+
+void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) {
+ // Cleanup is already scheduled, no need to do anything.
+ if (cleanup == scheduled_cleanup_) return;
+ scheduled_cleanup_ = cleanup;
+
+ // This cancels the previously scheduled cleanup.
+ cleanup_callback_.Reset(base::Bind(
+ &MDnsClientImpl::Core::DoCleanup, base::Unretained(this)));
+
+ // If |cleanup| is empty, then no cleanup necessary.
+ if (cleanup != base::Time()) {
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ cleanup_callback_.callback(),
+ cleanup - base::Time::Now());
+ }
+}
+
+void MDnsClientImpl::Core::DoCleanup() {
+ cache_.CleanupRecords(base::Time::Now(), base::Bind(
+ &MDnsClientImpl::Core::OnRecordRemoved, base::Unretained(this)));
+
+ ScheduleCleanup(cache_.next_expiration());
+}
+
+void MDnsClientImpl::Core::OnRecordRemoved(
+ const RecordParsed* record) {
+ AlertListeners(MDnsListener::RECORD_REMOVED,
+ ListenerKey(record->type(), record->name()), record);
+ // Alert listeners listening only for rrtype and not for name.
+ AlertListeners(MDnsListener::RECORD_REMOVED, ListenerKey(record->type(), ""),
+ record);
+}
+
+void MDnsClientImpl::Core::QueryCache(
+ uint16 rrtype, const std::string& name,
+ std::vector<const RecordParsed*>* records) const {
+ cache_.FindDnsRecords(rrtype, name, records, base::Time::Now());
+}
+
+MDnsClientImpl::MDnsClientImpl(
+ scoped_ptr<MDnsConnection::SocketFactory> socket_factory)
+ : listen_refs_(0), socket_factory_(socket_factory.Pass()) {
+}
+
+MDnsClientImpl::~MDnsClientImpl() {
+}
+
+bool MDnsClientImpl::AddListenRef() {
+ if (!core_.get()) {
+ core_.reset(new Core(this, socket_factory_.get()));
+ if (!core_->Init()) {
+ core_.reset();
+ return false;
+ }
+ }
+ listen_refs_++;
+ return true;
+}
+
+void MDnsClientImpl::SubtractListenRef() {
+ listen_refs_--;
+ if (listen_refs_ == 0) {
+ base::MessageLoop::current()->PostTask(FROM_HERE, base::Bind(
+ &MDnsClientImpl::Shutdown, base::Unretained(this)));
+ }
+}
+
+void MDnsClientImpl::Shutdown() {
+ // We need to check that new listeners haven't been created.
+ if (listen_refs_ == 0) {
+ core_.reset();
+ }
+}
+
+bool MDnsClientImpl::IsListeningForTests() {
+ return core_.get() != NULL;
+}
+
+scoped_ptr<MDnsListener> MDnsClientImpl::CreateListener(
+ uint16 rrtype,
+ const std::string& name,
+ MDnsListener::Delegate* delegate) {
+ return scoped_ptr<net::MDnsListener>(
+ new MDnsListenerImpl(rrtype, name, delegate, this));
+}
+
+scoped_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction(
+ uint16 rrtype,
+ const std::string& name,
+ int flags,
+ const MDnsTransaction::ResultCallback& callback) {
+ return scoped_ptr<MDnsTransaction>(
+ new MDnsTransactionImpl(rrtype, name, flags, callback, this));
+}
+
+MDnsListenerImpl::MDnsListenerImpl(
+ uint16 rrtype,
+ const std::string& name,
+ MDnsListener::Delegate* delegate,
+ MDnsClientImpl* client)
+ : rrtype_(rrtype), name_(name), client_(client), delegate_(delegate),
+ started_(false) {
+}
+
+bool MDnsListenerImpl::Start() {
+ DCHECK(!started_);
+
+ if (!client_->AddListenRef()) return false;
+ started_ = true;
+
+ DCHECK(client_->core());
+ client_->core()->AddListener(this);
+
+ return true;
+}
+
+MDnsListenerImpl::~MDnsListenerImpl() {
+ if (started_) {
+ DCHECK(client_->core());
+ client_->core()->RemoveListener(this);
+ client_->SubtractListenRef();
+ }
+}
+
+const std::string& MDnsListenerImpl::GetName() const {
+ return name_;
+}
+
+uint16 MDnsListenerImpl::GetType() const {
+ return rrtype_;
+}
+
+void MDnsListenerImpl::AlertDelegate(MDnsListener::UpdateType update_type,
+ const RecordParsed* record) {
+ DCHECK(started_);
+ delegate_->OnRecordUpdate(update_type, record);
+}
+
+MDnsTransactionImpl::MDnsTransactionImpl(
+ uint16 rrtype,
+ const std::string& name,
+ int flags,
+ const MDnsTransaction::ResultCallback& callback,
+ MDnsClientImpl* client)
+ : rrtype_(rrtype), name_(name), callback_(callback), client_(client),
+ started_(false), flags_(flags) {
+ DCHECK((flags_ & MDnsTransaction::FLAG_MASK) == flags_);
+ DCHECK(flags_ & MDnsTransaction::QUERY_CACHE ||
+ flags_ & MDnsTransaction::QUERY_NETWORK);
+}
+
+MDnsTransactionImpl::~MDnsTransactionImpl() {
+ timeout_.Cancel();
+}
+
+bool MDnsTransactionImpl::Start() {
+ DCHECK(!started_);
+ started_ = true;
+ std::vector<const RecordParsed*> records;
+ base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr();
+
+ if (flags_ & MDnsTransaction::QUERY_CACHE) {
+ if (client_->core()) {
+ client_->core()->QueryCache(rrtype_, name_, &records);
+ for (std::vector<const RecordParsed*>::iterator i = records.begin();
+ i != records.end() && weak_this; ++i) {
+ weak_this->TriggerCallback(MDnsTransaction::RESULT_RECORD,
+ records.front());
+ }
+ }
+ }
+
+ if (!weak_this) return true;
+
+ if (is_active() && (flags_ & MDnsTransaction::QUERY_NETWORK)) {
+ listener_ = client_->CreateListener(rrtype_, name_, this);
+ if (!listener_->Start()) return false;
+
+ DCHECK(client_->core());
+ if (!client_->core()->SendQuery(rrtype_, name_))
+ return false;
+
+ timeout_.Reset(base::Bind(&MDnsTransactionImpl::SignalTransactionOver,
+ weak_this));
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ timeout_.callback(),
+ base::TimeDelta::FromSeconds(MDnsTransactionTimeoutSeconds));
+
+ return listener_.get() != NULL;
+ } else {
+ // If this is a cache only query, signal that the transaction is over
+ // immediately.
+ SignalTransactionOver();
+ }
+
+ return true;
+}
+
+const std::string& MDnsTransactionImpl::GetName() const {
+ return name_;
+}
+
+uint16 MDnsTransactionImpl::GetType() const {
+ return rrtype_;
+}
+
+void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) {
+ DCHECK(started_);
+ OnRecordUpdate(MDnsListener::RECORD_ADDED, record);
+}
+
+void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result,
+ const RecordParsed* record) {
+ DCHECK(started_);
+ if (!is_active()) return;
+
+ // Ensure callback is run after touching all class state, so that
+ // the callback can delete the transaction.
+ MDnsTransaction::ResultCallback callback = callback_;
+
+ if (flags_ & MDnsTransaction::SINGLE_RESULT)
+ Reset();
+
+ callback.Run(result, record);
+}
+
+void MDnsTransactionImpl::Reset() {
+ callback_.Reset();
+ listener_.reset();
+ timeout_.Cancel();
+}
+
+void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update,
+ const RecordParsed* record) {
+ DCHECK(started_);
+ if (update == MDnsListener::RECORD_ADDED ||
+ update == MDnsListener::RECORD_CHANGED)
+ TriggerCallback(MDnsTransaction::RESULT_RECORD, record);
+}
+
+void MDnsTransactionImpl::SignalTransactionOver() {
+ DCHECK(started_);
+ base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr();
+
+ if (flags_ & MDnsTransaction::SINGLE_RESULT) {
+ TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS, NULL);
+ } else {
+ TriggerCallback(MDnsTransaction::RESULT_DONE, NULL);
+ }
+
+ if (weak_this) {
+ weak_this->Reset();
+ }
+}
+
+void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) {
+ // TODO(noamsml): NSEC records not yet implemented
+}
+
+void MDnsTransactionImpl::OnCachePurged() {
+ // TODO(noamsml): Cache purge situations not yet implemented
+}
+
+} // namespace net
diff --git a/net/dns/mdns_client_impl.h b/net/dns/mdns_client_impl.h
new file mode 100644
index 0000000..5a22894
--- /dev/null
+++ b/net/dns/mdns_client_impl.h
@@ -0,0 +1,288 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_MDNS_CLIENT_IMPL_H_
+#define NET_DNS_MDNS_CLIENT_IMPL_H_
+
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "base/cancelable_callback.h"
+#include "base/observer_list.h"
+#include "net/base/io_buffer.h"
+#include "net/base/ip_endpoint.h"
+#include "net/dns/mdns_cache.h"
+#include "net/dns/mdns_client.h"
+#include "net/udp/datagram_server_socket.h"
+#include "net/udp/udp_server_socket.h"
+#include "net/udp/udp_socket.h"
+
+namespace net {
+
+// A connection to the network for multicast DNS clients. It reads data into
+// DnsResponse objects and alerts the delegate that a packet has been received.
+class MDnsConnection {
+ public:
+ class SocketFactory {
+ public:
+ virtual ~SocketFactory() {}
+
+ virtual scoped_ptr<DatagramServerSocket> CreateSocket() = 0;
+
+ static scoped_ptr<SocketFactory> CreateDefault();
+ };
+
+ class Delegate {
+ public:
+ // Handle an mDNS packet buffered in |response| with a size of |bytes_read|.
+ virtual void HandlePacket(DnsResponse* response, int bytes_read) = 0;
+ virtual void OnConnectionError(int error) = 0;
+ virtual ~Delegate() {}
+ };
+
+ explicit MDnsConnection(SocketFactory* socket_factory,
+ MDnsConnection::Delegate* delegate);
+
+ virtual ~MDnsConnection();
+
+ int Init();
+ int Send(IOBuffer* buffer, unsigned size);
+
+ private:
+ class SocketHandler {
+ public:
+ SocketHandler(MDnsConnection* connection,
+ const IPEndPoint& multicast_addr,
+ SocketFactory* socket_factory);
+ ~SocketHandler();
+ int DoLoop(int rv);
+ int Start();
+
+ int Send(IOBuffer* buffer, unsigned size);
+
+ private:
+ int BindSocket();
+ void OnDatagramReceived(int rv);
+
+ // Callback for when sending a query has finished.
+ void SendDone(int rv);
+
+ scoped_ptr<DatagramServerSocket> socket_;
+
+ MDnsConnection* connection_;
+ IPEndPoint recv_addr_;
+ scoped_ptr<DnsResponse> response_;
+ IPEndPoint multicast_addr_;
+ };
+
+ // Callback for handling a datagram being received on either ipv4 or ipv6.
+ void OnDatagramReceived(DnsResponse* response,
+ const IPEndPoint& recv_addr,
+ int bytes_read);
+
+ void OnError(SocketHandler* loop, int error);
+
+ IPEndPoint GetMDnsIPEndPoint(const char* address);
+
+ SocketHandler socket_handler_ipv4_;
+ SocketHandler socket_handler_ipv6_;
+
+ Delegate* delegate_;
+
+ DISALLOW_COPY_AND_ASSIGN(MDnsConnection);
+};
+
+class MDnsListenerImpl;
+
+class MDnsClientImpl : public MDnsClient {
+ public:
+ // The core object exists while the MDnsClient is listening, and is
+ // deleted whenever the number of listeners reaches zero.
+ class Core : public base::SupportsWeakPtr<Core>, MDnsConnection::Delegate {
+ public:
+ Core(MDnsClientImpl* client,
+ MDnsConnection::SocketFactory* socket_factory);
+ virtual ~Core();
+
+ // Initialize the core. Returns true on success.
+ bool Init();
+
+ // Send a query with a specific rrtype and name. Returns true on success.
+ bool SendQuery(uint16 rrtype, std::string name);
+
+ // Add/remove a listener to the list of listener. May cause network traffic
+ // if listener is active.
+ void AddListener(MDnsListenerImpl* listener);
+ void RemoveListener(MDnsListenerImpl* listener);
+
+ // Query the cache for records of a specific type and name.
+ void QueryCache(uint16 rrtype, const std::string& name,
+ std::vector<const RecordParsed*>* records) const;
+
+ // Parse the response and alert relevant listeners.
+ virtual void HandlePacket(DnsResponse* response, int bytes_read) OVERRIDE;
+
+ virtual void OnConnectionError(int error) OVERRIDE;
+
+ private:
+ typedef std::pair<uint16, std::string> ListenerKey;
+ typedef std::map<ListenerKey, ObserverList<MDnsListenerImpl>* >
+ ListenerMap;
+
+ // Alert listeners of an update to the cache.
+ void AlertListeners(MDnsListener::UpdateType update_type,
+ const ListenerKey& key, const RecordParsed* record);
+
+ // Schedule a cache cleanup to a specific time, cancelling other cleanups.
+ void ScheduleCleanup(base::Time cleanup);
+
+ // Clean up the cache and schedule a new cleanup.
+ void DoCleanup();
+
+ // Callback for when a record is removed from the cache.
+ void OnRecordRemoved(const RecordParsed* record);
+
+ ListenerMap listeners_;
+
+ MDnsClientImpl* client_;
+ MDnsCache cache_;
+
+ base::CancelableCallback<void()> cleanup_callback_;
+ base::Time scheduled_cleanup_;
+
+ scoped_ptr<MDnsConnection> connection_;
+
+ DISALLOW_COPY_AND_ASSIGN(Core);
+ };
+
+ explicit MDnsClientImpl(
+ scoped_ptr<MDnsConnection::SocketFactory> socket_factory_);
+ virtual ~MDnsClientImpl();
+
+ // MDnsClient implementation:
+ virtual scoped_ptr<MDnsListener> CreateListener(
+ uint16 rrtype,
+ const std::string& name,
+ MDnsListener::Delegate* delegate) OVERRIDE;
+
+ virtual scoped_ptr<MDnsTransaction> CreateTransaction(
+ uint16 rrtype,
+ const std::string& name,
+ int flags,
+ const MDnsTransaction::ResultCallback& callback) OVERRIDE;
+
+ // Returns true when the client is listening for network packets.
+ bool IsListeningForTests();
+
+ bool AddListenRef();
+ void SubtractListenRef();
+
+ Core* core() { return core_.get(); }
+
+ private:
+ // This method causes the client to stop listening for packets. The
+ // call for it is deferred through the message loop after the last
+ // listener is removed. If another listener is added after a
+ // shutdown is scheduled but before it actually runs, the shutdown
+ // will be canceled.
+ void Shutdown();
+
+ scoped_ptr<Core> core_;
+ int listen_refs_;
+
+ scoped_ptr<MDnsConnection::SocketFactory> socket_factory_;
+
+ DISALLOW_COPY_AND_ASSIGN(MDnsClientImpl);
+};
+
+class MDnsListenerImpl : public MDnsListener,
+ public base::SupportsWeakPtr<MDnsListenerImpl> {
+ public:
+ MDnsListenerImpl(uint16 rrtype,
+ const std::string& name,
+ MDnsListener::Delegate* delegate,
+ MDnsClientImpl* client);
+
+ virtual ~MDnsListenerImpl();
+
+ // MDnsListener implementation:
+ virtual bool Start() OVERRIDE;
+
+ virtual const std::string& GetName() const OVERRIDE;
+
+ virtual uint16 GetType() const OVERRIDE;
+
+ MDnsListener::Delegate* delegate() { return delegate_; }
+
+ // Alert the delegate of a record update.
+ void AlertDelegate(MDnsListener::UpdateType update_type,
+ const RecordParsed* record_parsed);
+ private:
+ uint16 rrtype_;
+ std::string name_;
+ MDnsClientImpl* client_;
+ MDnsListener::Delegate* delegate_;
+
+ bool started_;
+ DISALLOW_COPY_AND_ASSIGN(MDnsListenerImpl);
+};
+
+class MDnsTransactionImpl : public base::SupportsWeakPtr<MDnsTransactionImpl>,
+ public MDnsTransaction,
+ public MDnsListener::Delegate {
+ public:
+ MDnsTransactionImpl(uint16 rrtype,
+ const std::string& name,
+ int flags,
+ const MDnsTransaction::ResultCallback& callback,
+ MDnsClientImpl* client);
+ virtual ~MDnsTransactionImpl();
+
+ // MDnsTransaction implementation:
+ virtual bool Start() OVERRIDE;
+
+ virtual const std::string& GetName() const OVERRIDE;
+ virtual uint16 GetType() const OVERRIDE;
+
+ // MDnsListener::Delegate implementation:
+ virtual void OnRecordUpdate(MDnsListener::UpdateType update,
+ const RecordParsed* record) OVERRIDE;
+ virtual void OnNsecRecord(const std::string& name, unsigned type) OVERRIDE;
+
+ virtual void OnCachePurged() OVERRIDE;
+
+ private:
+ bool is_active() { return !callback_.is_null(); }
+
+ void Reset();
+
+ // Trigger the callback and reset all related variables.
+ void TriggerCallback(MDnsTransaction::Result result,
+ const RecordParsed* record);
+
+ // Internal callback for when a cache record is found.
+ void CacheRecordFound(const RecordParsed* record);
+
+ // Signal the transactionis over and release all related resources.
+ void SignalTransactionOver();
+
+ uint16 rrtype_;
+ std::string name_;
+ MDnsTransaction::ResultCallback callback_;
+
+ scoped_ptr<MDnsListener> listener_;
+ base::CancelableCallback<void()> timeout_;
+
+ MDnsClientImpl* client_;
+
+ bool started_;
+ int flags_;
+
+ DISALLOW_COPY_AND_ASSIGN(MDnsTransactionImpl);
+};
+
+} // namespace net
+#endif // NET_DNS_MDNS_CLIENT_IMPL_H_
diff --git a/net/dns/mdns_client_unittest.cc b/net/dns/mdns_client_unittest.cc
new file mode 100644
index 0000000..8f0b45a
--- /dev/null
+++ b/net/dns/mdns_client_unittest.cc
@@ -0,0 +1,1070 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <queue>
+
+#include "base/memory/ref_counted.h"
+#include "base/message_loop.h"
+#include "net/base/rand_callback.h"
+#include "net/base/test_completion_callback.h"
+#include "net/dns/mdns_client_impl.h"
+#include "net/dns/record_rdata.h"
+#include "net/udp/udp_client_socket.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+using ::testing::Invoke;
+using ::testing::InvokeWithoutArgs;
+using ::testing::StrictMock;
+using ::testing::NiceMock;
+using ::testing::Exactly;
+using ::testing::Return;
+using ::testing::SaveArg;
+using ::testing::_;
+
+namespace net {
+
+namespace {
+
+const char kSamplePacket1[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x02, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x00, // TTL (4 bytes) is 1 second;
+ 0x00, 0x01,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x0c,
+
+ // Answer 2
+ 0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r',
+ 0xc0, 0x14, // Pointer to "._tcp.local"
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 49 seconds.
+ 0x24, 0x75,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x32
+};
+
+const char kCorruptedPacketBadQuestion[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x01, // One question
+ 0x00, 0x02, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Question is corrupted and cannot be read.
+ 0x99, 'h', 'e', 'l', 'l', 'o',
+ 0x00,
+ 0x00, 0x00,
+ 0x00, 0x00,
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x99, // RDLENGTH is impossible
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x0c,
+
+ // Answer 2
+ 0x08, '_', 'p', 'r', // Useless trailing data.
+};
+
+const char kCorruptedPacketUnsalvagable[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x02, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x99, // RDLENGTH is impossible
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x0c,
+
+ // Answer 2
+ 0x08, '_', 'p', 'r', // Useless trailing data.
+};
+
+const char kCorruptedPacketSalvagable[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x02, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x99, 'h', 'e', 'l', 'l', 'o', // Bad RDATA format.
+ 0xc0, 0x0c,
+
+ // Answer 2
+ 0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r',
+ 0xc0, 0x14, // Pointer to "._tcp.local"
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 49 seconds.
+ 0x24, 0x75,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x32
+};
+
+const char kSamplePacket2[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x02, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'z', 'z', 'z', 'z', 'z',
+ 0xc0, 0x0c,
+
+ // Answer 2
+ 0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r',
+ 0xc0, 0x14, // Pointer to "._tcp.local"
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'z', 'z', 'z', 'z', 'z',
+ 0xc0, 0x32
+};
+
+const char kQueryPacketPrivet[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x00, 0x00, // No flags.
+ 0x00, 0x01, // One question.
+ 0x00, 0x00, // 0 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Question
+ // This part is echoed back from the respective query.
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+};
+
+const char kSamplePacketAdditionalOnly[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x00, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x01, // 0 additional RRs
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x0c,
+};
+
+class MockDatagramServerSocket : public DatagramServerSocket {
+ public:
+ // DatagramServerSocket implementation:
+ int Listen(const IPEndPoint& address) {
+ return ListenInternal(address.ToString());
+ }
+
+ MOCK_METHOD1(ListenInternal, int(const std::string& address));
+
+ MOCK_METHOD4(RecvFrom, int(IOBuffer* buffer, int size,
+ IPEndPoint* address,
+ const CompletionCallback& callback));
+
+ int SendTo(IOBuffer* buf, int buf_len, const IPEndPoint& address,
+ const CompletionCallback& callback) {
+ return SendToInternal(std::string(buf->data(), buf_len), address.ToString(),
+ callback);
+ }
+
+ MOCK_METHOD3(SendToInternal, int(const std::string& packet,
+ const std::string address,
+ const CompletionCallback& callback));
+
+ MOCK_METHOD1(SetReceiveBufferSize, bool(int32 size));
+ MOCK_METHOD1(SetSendBufferSize, bool(int32 size));
+
+ MOCK_METHOD0(Close, void());
+
+ MOCK_CONST_METHOD1(GetPeerAddress, int(IPEndPoint* address));
+ MOCK_CONST_METHOD1(GetLocalAddress, int(IPEndPoint* address));
+ MOCK_CONST_METHOD0(NetLog, const BoundNetLog&());
+
+ MOCK_METHOD0(AllowAddressReuse, void());
+ MOCK_METHOD0(AllowBroadcast, void());
+
+ int JoinGroup(const IPAddressNumber& group_address) const {
+ return JoinGroupInternal(IPAddressToString(group_address));
+ }
+
+ MOCK_CONST_METHOD1(JoinGroupInternal, int(const std::string& group));
+
+ int LeaveGroup(const IPAddressNumber& group_address) const {
+ return LeaveGroupInternal(IPAddressToString(group_address));
+ }
+
+ MOCK_CONST_METHOD1(LeaveGroupInternal, int(const std::string& group));
+
+ MOCK_METHOD1(SetMulticastTimeToLive, int(int ttl));
+
+ MOCK_METHOD1(SetMulticastLoopbackMode, int(bool loopback));
+
+ void SetResponsePacket(std::string response_packet) {
+ response_packet_ = response_packet;
+ }
+
+ int HandleRecvNow(IOBuffer* buffer, int size, IPEndPoint* address,
+ const CompletionCallback& callback) {
+ int size_returned =
+ std::min(response_packet_.size(), static_cast<size_t>(size));
+ memcpy(buffer->data(), response_packet_.data(), size_returned);
+ return size_returned;
+ }
+
+ int HandleRecvLater(IOBuffer* buffer, int size, IPEndPoint* address,
+ const CompletionCallback& callback) {
+ int rv = HandleRecvNow(buffer, size, address, callback);
+ base::MessageLoop::current()->PostTask(FROM_HERE, base::Bind(callback, rv));
+ return ERR_IO_PENDING;
+ }
+
+ private:
+ std::string response_packet_;
+};
+
+class MockSocketFactory
+ : public MDnsConnection::SocketFactory {
+ public:
+ MockSocketFactory() {
+ }
+
+ virtual ~MockSocketFactory() {
+ }
+
+ virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE {
+ scoped_ptr<MockDatagramServerSocket> new_socket(
+ new NiceMock<MockDatagramServerSocket>);
+
+ ON_CALL(*new_socket, SendToInternal(_, _, _))
+ .WillByDefault(Invoke(
+ this,
+ &MockSocketFactory::SendToInternal));
+
+ ON_CALL(*new_socket, RecvFrom(_, _, _, _))
+ .WillByDefault(Invoke(
+ this,
+ &MockSocketFactory::RecvFromInternal));
+
+ return new_socket.PassAs<DatagramServerSocket>();
+ }
+
+ void SimulateReceive(const char* packet, int size) {
+ DCHECK(recv_buffer_size_ >= size);
+ DCHECK(recv_buffer_.get());
+ DCHECK(!recv_callback_.is_null());
+
+ memcpy(recv_buffer_->data(), packet, size);
+ CompletionCallback recv_callback = recv_callback_;
+ recv_callback_.Reset();
+ recv_callback.Run(size);
+ }
+
+ MOCK_METHOD1(OnSendTo, void(const std::string&));
+
+ private:
+ int SendToInternal(const std::string& packet, const std::string& address,
+ const CompletionCallback& callback) {
+ OnSendTo(packet);
+ return packet.size();
+ }
+
+ // The latest receive callback is always saved, since the MDnsConnection
+ // does not care which socket a packet is received on.
+ int RecvFromInternal(IOBuffer* buffer, int size,
+ IPEndPoint* address,
+ const CompletionCallback& callback) {
+ recv_buffer_ = buffer;
+ recv_buffer_size_ = size;
+ recv_callback_ = callback;
+ return ERR_IO_PENDING;
+ }
+
+ scoped_refptr<IOBuffer> recv_buffer_;
+ int recv_buffer_size_;
+ CompletionCallback recv_callback_;
+};
+
+class PtrRecordCopyContainer {
+ public:
+ PtrRecordCopyContainer() {}
+ ~PtrRecordCopyContainer() {}
+
+ bool is_set() const { return set_; }
+
+ void SaveWithDummyArg(int unused, const RecordParsed* value) {
+ Save(value);
+ }
+
+ void Save(const RecordParsed* value) {
+ set_ = true;
+ name_ = value->name();
+ ptrdomain_ = value->rdata<PtrRecordRdata>()->ptrdomain();
+ ttl_ = value->ttl();
+ }
+
+ bool IsRecordWith(std::string name, std::string ptrdomain) {
+ return set_ && name_ == name && ptrdomain_ == ptrdomain;
+ }
+
+ const std::string& name() { return name_; }
+ const std::string& ptrdomain() { return ptrdomain_; }
+ int ttl() { return ttl_; }
+
+ private:
+ bool set_;
+ std::string name_;
+ std::string ptrdomain_;
+ int ttl_;
+};
+
+class MDnsTest : public ::testing::Test {
+ public:
+ MDnsTest();
+ virtual ~MDnsTest();
+ void DeleteTransaction();
+ void DeleteBothListeners();
+ void RunUntilIdle();
+ void RunFor(base::TimeDelta time_period);
+ void Stop();
+
+ MOCK_METHOD2(MockableRecordCallback, void(MDnsTransaction::Result result,
+ const RecordParsed* record));
+
+ protected:
+ void ExpectPacket(const char* packet, unsigned size);
+ void SimulatePacketReceive(const char* packet, unsigned size);
+
+ base::MessageLoop* message_loop_;
+
+ scoped_ptr<MDnsClientImpl> test_client_;
+ IPEndPoint mdns_ipv4_endpoint_;
+ StrictMock<MockSocketFactory>* socket_factory_;
+
+ // Transactions and listeners that can be deleted by class methods for
+ // reentrancy tests.
+ scoped_ptr<MDnsTransaction> transaction_;
+ scoped_ptr<MDnsListener> listener1_;
+ scoped_ptr<MDnsListener> listener2_;
+};
+
+class MockListenerDelegate : public MDnsListener::Delegate {
+ public:
+ MOCK_METHOD2(OnRecordUpdate,
+ void(MDnsListener::UpdateType update,
+ const RecordParsed* records));
+ MOCK_METHOD2(OnNsecRecord, void(const std::string&, unsigned));
+ MOCK_METHOD0(OnCachePurged, void());
+};
+
+MDnsTest::MDnsTest()
+ : message_loop_(base::MessageLoop::current()) {
+ socket_factory_ = new StrictMock<MockSocketFactory>();
+ test_client_.reset(new MDnsClientImpl(
+ scoped_ptr<MDnsConnection::SocketFactory>(socket_factory_)));
+}
+
+MDnsTest::~MDnsTest() {
+}
+
+void MDnsTest::SimulatePacketReceive(const char* packet, unsigned size) {
+ socket_factory_->SimulateReceive(packet, size);
+}
+
+void MDnsTest::ExpectPacket(
+ const char* packet,
+ unsigned size) {
+ EXPECT_CALL(*socket_factory_, OnSendTo(std::string(packet, size)))
+ .Times(2);
+}
+
+void MDnsTest::DeleteTransaction() {
+ transaction_.reset();
+}
+
+void MDnsTest::DeleteBothListeners() {
+ listener1_.reset();
+ listener2_.reset();
+}
+
+void MDnsTest::RunUntilIdle() {
+ base::MessageLoop::current()->RunUntilIdle();
+}
+
+void MDnsTest::RunFor(base::TimeDelta time_period) {
+ base::CancelableCallback<void()> callback(base::Bind(&MDnsTest::Stop,
+ base::Unretained(this)));
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE, callback.callback(), time_period);
+
+ base::MessageLoop::current()->Run();
+ callback.Cancel();
+}
+
+void MDnsTest::Stop() {
+ base::MessageLoop::current()->Quit();
+}
+
+TEST_F(MDnsTest, PassiveListeners) {
+ StrictMock<MockListenerDelegate> delegate_privet;
+ StrictMock<MockListenerDelegate> delegate_printer;
+ StrictMock<MockListenerDelegate> delegate_ptr;
+
+ PtrRecordCopyContainer record_privet;
+ PtrRecordCopyContainer record_printer;
+
+ scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
+ scoped_ptr<MDnsListener> listener_printer = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_printer._tcp.local", &delegate_printer);
+ scoped_ptr<MDnsListener> listener_ptr = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "", &delegate_ptr);
+
+ ASSERT_TRUE(listener_privet->Start());
+ ASSERT_TRUE(listener_printer->Start());
+ ASSERT_TRUE(listener_ptr->Start());
+
+ ASSERT_TRUE(test_client_->IsListeningForTests());
+
+ // Send the same packet twice to ensure no records are double-counted.
+
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1))
+ .WillOnce(Invoke(
+ &record_privet,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1))
+ .WillOnce(Invoke(
+ &record_printer,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ EXPECT_CALL(delegate_ptr, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(2));
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+
+ EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+
+ EXPECT_TRUE(record_printer.IsRecordWith("_printer._tcp.local",
+ "hello._printer._tcp.local"));
+
+ listener_privet.reset();
+ listener_printer.reset();
+
+ ASSERT_TRUE(test_client_->IsListeningForTests());
+
+ EXPECT_CALL(delegate_ptr, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(2));
+
+ SimulatePacketReceive(kSamplePacket2, sizeof(kSamplePacket2));
+
+ // Test to make sure mdns listener is not active with no listeners present.
+ listener_ptr.reset();
+
+ RunUntilIdle();
+
+ ASSERT_FALSE(test_client_->IsListeningForTests());
+}
+
+TEST_F(MDnsTest, PassiveListenersCacheCleanup) {
+ StrictMock<MockListenerDelegate> delegate_privet;
+
+ PtrRecordCopyContainer record_privet;
+ PtrRecordCopyContainer record_privet2;
+
+ scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
+
+ ASSERT_TRUE(listener_privet->Start());
+
+ ASSERT_TRUE(test_client_->IsListeningForTests());
+
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1))
+ .WillOnce(Invoke(
+ &record_privet,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+
+ RunUntilIdle();
+
+ EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+
+ // Expect record is removed when its TTL expires.
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
+ .Times(Exactly(1))
+ .WillOnce(DoAll(InvokeWithoutArgs(this, &MDnsTest::Stop),
+ Invoke(&record_privet2,
+ &PtrRecordCopyContainer::SaveWithDummyArg)));
+
+ RunFor(base::TimeDelta::FromSeconds(record_privet.ttl() + 1));
+
+ RunUntilIdle();
+
+ EXPECT_TRUE(record_privet2.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+}
+
+TEST_F(MDnsTest, MalformedPacket) {
+ StrictMock<MockListenerDelegate> delegate_printer;
+
+ PtrRecordCopyContainer record_printer;
+
+ scoped_ptr<MDnsListener> listener_printer = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_printer._tcp.local", &delegate_printer);
+
+ ASSERT_TRUE(listener_printer->Start());
+
+ ASSERT_TRUE(test_client_->IsListeningForTests());
+
+ EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1))
+ .WillOnce(Invoke(
+ &record_printer,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ // First, send unsalvagable packet to ensure we can deal with it.
+ SimulatePacketReceive(kCorruptedPacketUnsalvagable,
+ sizeof(kCorruptedPacketUnsalvagable));
+
+ // Regression test: send a packet where the question cannot be read.
+ SimulatePacketReceive(kCorruptedPacketBadQuestion,
+ sizeof(kCorruptedPacketBadQuestion));
+
+ // Then send salvagable packet to ensure we can extract useful records.
+ SimulatePacketReceive(kCorruptedPacketSalvagable,
+ sizeof(kCorruptedPacketSalvagable));
+
+ RunUntilIdle();
+
+ EXPECT_TRUE(record_printer.IsRecordWith("_printer._tcp.local",
+ "hello._printer._tcp.local"));
+}
+
+TEST_F(MDnsTest, TransactionWithEmptyCache) {
+ ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));
+
+ scoped_ptr<MDnsTransaction> transaction_privet =
+ test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ ASSERT_TRUE(transaction_privet->Start());
+
+ EXPECT_TRUE(test_client_->IsListeningForTests());
+
+ PtrRecordCopyContainer record_privet;
+
+ EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
+ .Times(Exactly(1))
+ .WillOnce(Invoke(&record_privet,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+
+ RunUntilIdle();
+
+ EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+
+ EXPECT_FALSE(test_client_->IsListeningForTests());
+}
+
+TEST_F(MDnsTest, TransactionCacheOnlyNoResult) {
+ scoped_ptr<MDnsTransaction> transaction_privet =
+ test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ EXPECT_CALL(*this,
+ MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS, _))
+ .Times(Exactly(1));
+
+ ASSERT_TRUE(transaction_privet->Start());
+
+ EXPECT_FALSE(test_client_->IsListeningForTests());
+
+ RunUntilIdle();
+}
+
+TEST_F(MDnsTest, TransactionWithCache) {
+ // Listener to force the client to listen
+ StrictMock<MockListenerDelegate> delegate_irrelevant;
+ scoped_ptr<MDnsListener> listener_irrelevant = test_client_->CreateListener(
+ dns_protocol::kTypeA, "codereview.chromium.local",
+ &delegate_irrelevant);
+
+ ASSERT_TRUE(listener_irrelevant->Start());
+
+ EXPECT_TRUE(test_client_->IsListeningForTests());
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+
+ RunUntilIdle();
+
+ PtrRecordCopyContainer record_privet;
+
+ EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
+ .WillOnce(Invoke(&record_privet,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ scoped_ptr<MDnsTransaction> transaction_privet =
+ test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ ASSERT_TRUE(transaction_privet->Start());
+
+ RunUntilIdle();
+
+ EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+}
+
+TEST_F(MDnsTest, AdditionalRecords) {
+ StrictMock<MockListenerDelegate> delegate_privet;
+
+ PtrRecordCopyContainer record_privet;
+
+ scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ &delegate_privet);
+
+ ASSERT_TRUE(listener_privet->Start());
+
+ ASSERT_TRUE(test_client_->IsListeningForTests());
+
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1))
+ .WillOnce(Invoke(
+ &record_privet,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ SimulatePacketReceive(kSamplePacketAdditionalOnly, sizeof(kSamplePacket1));
+
+ RunUntilIdle();
+
+ EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+}
+
+TEST_F(MDnsTest, TransactionTimeout) {
+ ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));
+
+ scoped_ptr<MDnsTransaction> transaction_privet =
+ test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ ASSERT_TRUE(transaction_privet->Start());
+
+ EXPECT_TRUE(test_client_->IsListeningForTests());
+
+ EXPECT_CALL(*this,
+ MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS, NULL))
+ .Times(Exactly(1))
+ .WillOnce(InvokeWithoutArgs(this, &MDnsTest::Stop));
+
+ RunFor(base::TimeDelta::FromSeconds(4));
+
+ EXPECT_FALSE(test_client_->IsListeningForTests());
+}
+
+TEST_F(MDnsTest, TransactionMultipleRecords) {
+ ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));
+
+ scoped_ptr<MDnsTransaction> transaction_privet =
+ test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE ,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ ASSERT_TRUE(transaction_privet->Start());
+
+ EXPECT_TRUE(test_client_->IsListeningForTests());
+
+ PtrRecordCopyContainer record_privet;
+ PtrRecordCopyContainer record_privet2;
+
+ EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
+ .Times(Exactly(2))
+ .WillOnce(Invoke(&record_privet,
+ &PtrRecordCopyContainer::SaveWithDummyArg))
+ .WillOnce(Invoke(&record_privet2,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+ SimulatePacketReceive(kSamplePacket2, sizeof(kSamplePacket2));
+
+ RunUntilIdle();
+
+ EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+
+ EXPECT_TRUE(record_privet2.IsRecordWith("_privet._tcp.local",
+ "zzzzz._privet._tcp.local"));
+
+ EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_DONE, NULL))
+ .WillOnce(InvokeWithoutArgs(this, &MDnsTest::Stop));
+
+ RunFor(base::TimeDelta::FromSeconds(4));
+
+ EXPECT_FALSE(test_client_->IsListeningForTests());
+}
+
+TEST_F(MDnsTest, TransactionReentrantDelete) {
+ ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));
+
+ transaction_ = test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ ASSERT_TRUE(transaction_->Start());
+
+ EXPECT_TRUE(test_client_->IsListeningForTests());
+
+ EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS,
+ NULL))
+ .Times(Exactly(1))
+ .WillOnce(DoAll(InvokeWithoutArgs(this, &MDnsTest::DeleteTransaction),
+ InvokeWithoutArgs(this, &MDnsTest::Stop)));
+
+ RunFor(base::TimeDelta::FromSeconds(4));
+
+ EXPECT_EQ(NULL, transaction_.get());
+
+ EXPECT_FALSE(test_client_->IsListeningForTests());
+}
+
+TEST_F(MDnsTest, TransactionReentrantDeleteFromCache) {
+ StrictMock<MockListenerDelegate> delegate_irrelevant;
+ scoped_ptr<MDnsListener> listener_irrelevant = test_client_->CreateListener(
+ dns_protocol::kTypeA, "codereview.chromium.local",
+ &delegate_irrelevant);
+ ASSERT_TRUE(listener_irrelevant->Start());
+
+ ASSERT_TRUE(test_client_->IsListeningForTests());
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+
+ transaction_ = test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
+ .Times(Exactly(1))
+ .WillOnce(InvokeWithoutArgs(this, &MDnsTest::DeleteTransaction));
+
+ ASSERT_TRUE(transaction_->Start());
+
+ RunUntilIdle();
+
+ EXPECT_EQ(NULL, transaction_.get());
+}
+
+// In order to reliably test reentrant listener deletes, we create two listeners
+// and have each of them delete both, so we're guaranteed to try and deliver a
+// callback to at least one deleted listener.
+
+TEST_F(MDnsTest, ListenerReentrantDelete) {
+ StrictMock<MockListenerDelegate> delegate_privet;
+
+ listener1_ = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ &delegate_privet);
+
+ listener2_ = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ &delegate_privet);
+
+ ASSERT_TRUE(listener1_->Start());
+
+ ASSERT_TRUE(listener2_->Start());
+
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1))
+ .WillOnce(InvokeWithoutArgs(this, &MDnsTest::DeleteBothListeners));
+
+ EXPECT_TRUE(test_client_->IsListeningForTests());
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+
+ RunUntilIdle();
+
+ EXPECT_EQ(NULL, listener1_.get());
+ EXPECT_EQ(NULL, listener2_.get());
+
+ EXPECT_FALSE(test_client_->IsListeningForTests());
+}
+
+// Note: These tests assume that the ipv4 socket will always be created first.
+// This is a simplifying assumption based on the way the code works now.
+
+class SimpleMockSocketFactory
+ : public MDnsConnection::SocketFactory {
+ public:
+ SimpleMockSocketFactory() {
+ }
+ virtual ~SimpleMockSocketFactory() {
+ }
+
+ virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE {
+ scoped_ptr<MockDatagramServerSocket> socket(
+ new StrictMock<MockDatagramServerSocket>);
+ sockets_.push(socket.get());
+ return socket.PassAs<DatagramServerSocket>();
+ }
+
+ MockDatagramServerSocket* PopFirstSocket() {
+ MockDatagramServerSocket* socket = sockets_.front();
+ sockets_.pop();
+ return socket;
+ }
+
+ size_t num_sockets() {
+ return sockets_.size();
+ }
+
+ private:
+ std::queue<MockDatagramServerSocket*> sockets_;
+};
+
+class MockMDnsConnectionDelegate : public MDnsConnection::Delegate {
+ public:
+ virtual void HandlePacket(DnsResponse* response, int size) {
+ HandlePacketInternal(std::string(response->io_buffer()->data(), size));
+ }
+
+ MOCK_METHOD1(HandlePacketInternal, void(std::string packet));
+
+ MOCK_METHOD1(OnConnectionError, void(int error));
+};
+
+class MDnsConnectionTest : public ::testing::Test {
+ public:
+ MDnsConnectionTest() : connection_(&factory_, &delegate_) {
+ }
+
+ protected:
+ // Follow successful connection initialization.
+ virtual void SetUp() OVERRIDE {
+ ASSERT_EQ(2u, factory_.num_sockets());
+
+ socket_ipv4_ = factory_.PopFirstSocket();
+ socket_ipv6_ = factory_.PopFirstSocket();
+ }
+
+ bool InitConnection() {
+ EXPECT_CALL(*socket_ipv4_, AllowAddressReuse());
+ EXPECT_CALL(*socket_ipv6_, AllowAddressReuse());
+
+ EXPECT_CALL(*socket_ipv4_, SetMulticastLoopbackMode(false));
+ EXPECT_CALL(*socket_ipv6_, SetMulticastLoopbackMode(false));
+
+ EXPECT_CALL(*socket_ipv4_, ListenInternal("0.0.0.0:5353"))
+ .WillOnce(Return(OK));
+ EXPECT_CALL(*socket_ipv6_, ListenInternal("[::]:5353"))
+ .WillOnce(Return(OK));
+
+ EXPECT_CALL(*socket_ipv4_, JoinGroupInternal("224.0.0.251"))
+ .WillOnce(Return(OK));
+ EXPECT_CALL(*socket_ipv6_, JoinGroupInternal("ff02::fb"))
+ .WillOnce(Return(OK));
+
+ return connection_.Init() == OK;
+ }
+
+ StrictMock<MockMDnsConnectionDelegate> delegate_;
+
+ MockDatagramServerSocket* socket_ipv4_;
+ MockDatagramServerSocket* socket_ipv6_;
+ SimpleMockSocketFactory factory_;
+ MDnsConnection connection_;
+ TestCompletionCallback callback_;
+};
+
+TEST_F(MDnsConnectionTest, ReceiveSynchronous) {
+ std::string sample_packet =
+ std::string(kSamplePacket1, sizeof(kSamplePacket1));
+
+ socket_ipv6_->SetResponsePacket(sample_packet);
+ EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _))
+ .WillOnce(Return(ERR_IO_PENDING));
+ EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _))
+ .WillOnce(
+ Invoke(socket_ipv6_, &MockDatagramServerSocket::HandleRecvNow))
+ .WillOnce(Return(ERR_IO_PENDING));
+
+ EXPECT_CALL(delegate_, HandlePacketInternal(sample_packet));
+
+ ASSERT_TRUE(InitConnection());
+}
+
+TEST_F(MDnsConnectionTest, ReceiveAsynchronous) {
+ std::string sample_packet =
+ std::string(kSamplePacket1, sizeof(kSamplePacket1));
+ socket_ipv6_->SetResponsePacket(sample_packet);
+ EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _))
+ .WillOnce(Return(ERR_IO_PENDING));
+ EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _))
+ .WillOnce(
+ Invoke(socket_ipv6_, &MockDatagramServerSocket::HandleRecvLater))
+ .WillOnce(Return(ERR_IO_PENDING));
+
+ ASSERT_TRUE(InitConnection());
+
+ EXPECT_CALL(delegate_, HandlePacketInternal(sample_packet));
+
+ base::MessageLoop::current()->RunUntilIdle();
+}
+
+TEST_F(MDnsConnectionTest, Send) {
+ std::string sample_packet =
+ std::string(kSamplePacket1, sizeof(kSamplePacket1));
+
+ scoped_refptr<IOBufferWithSize> buf(
+ new IOBufferWithSize(sizeof kSamplePacket1));
+ memcpy(buf->data(), kSamplePacket1, sizeof(kSamplePacket1));
+
+ EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _))
+ .WillOnce(Return(ERR_IO_PENDING));
+ EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _))
+ .WillOnce(Return(ERR_IO_PENDING));
+
+ ASSERT_TRUE(InitConnection());
+
+ EXPECT_CALL(*socket_ipv4_,
+ SendToInternal(sample_packet, "224.0.0.251:5353", _));
+ EXPECT_CALL(*socket_ipv6_,
+ SendToInternal(sample_packet, "[ff02::fb]:5353", _));
+
+ connection_.Send(buf, buf->size());
+}
+
+TEST_F(MDnsConnectionTest, Error) {
+ CompletionCallback callback;
+
+ EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _))
+ .WillOnce(Return(ERR_IO_PENDING));
+ EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _))
+ .WillOnce(DoAll(SaveArg<3>(&callback), Return(ERR_IO_PENDING)));
+
+ ASSERT_TRUE(InitConnection());
+
+ EXPECT_CALL(delegate_, OnConnectionError(ERR_SOCKET_NOT_CONNECTED));
+ callback.Run(ERR_SOCKET_NOT_CONNECTED);
+}
+
+} // namespace
+
+} // namespace net
diff --git a/net/net.gyp b/net/net.gyp
index e8f36ce..68d0114 100644
--- a/net/net.gyp
+++ b/net/net.gyp
@@ -447,6 +447,10 @@
'dns/mapped_host_resolver.h',
'dns/mdns_cache.cc',
'dns/mdns_cache.h',
+ 'dns/mdns_client.cc',
+ 'dns/mdns_client.h',
+ 'dns/mdns_client_impl.cc',
+ 'dns/mdns_client_impl.h',
'dns/notify_watcher_mac.cc',
'dns/notify_watcher_mac.h',
'dns/record_parsed.cc',
@@ -1314,6 +1318,10 @@
'sources!' : [
'dns/mdns_cache.cc',
'dns/mdns_cache.h',
+ 'dns/mdns_client.cc',
+ 'dns/mdns_client.h',
+ 'dns/mdns_client_impl.cc',
+ 'dns/mdns_client_impl.h',
'dns/record_parsed.cc',
'dns/record_parsed.h',
'dns/record_rdata.cc',
@@ -1544,6 +1552,7 @@
'dns/host_resolver_impl_unittest.cc',
'dns/mapped_host_resolver_unittest.cc',
'dns/mdns_cache_unittest.cc',
+ 'dns/mdns_client_unittest.cc',
'dns/serial_worker_unittest.cc',
'dns/record_parsed_unittest.cc',
'dns/record_rdata_unittest.cc',
@@ -1949,6 +1958,8 @@
[ 'enable_mdns != 1', {
'sources!' : [
'dns/mdns_cache_unittest.cc',
+ 'dns/mdns_client_unittest.cc',
+ 'dns/mdns_query_unittest.cc',
'dns/record_parsed_unittest.cc',
'dns/record_rdata_unittest.cc',
],