summaryrefslogtreecommitdiffstats
path: root/chrome/common/local_discovery
diff options
context:
space:
mode:
authorvitalybuka@chromium.org <vitalybuka@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2014-04-28 08:14:37 +0000
committervitalybuka@chromium.org <vitalybuka@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2014-04-28 08:14:37 +0000
commit040e90ddcb6c491db29fc917424e177bba40508c (patch)
tree84a76dfd29aa19bc42fc3b1f7cd8db847f27635e /chrome/common/local_discovery
parent2ea5fc1d74afbd7bf3f4d88a18b437f05e20d58b (diff)
downloadchromium_src-040e90ddcb6c491db29fc917424e177bba40508c.zip
chromium_src-040e90ddcb6c491db29fc917424e177bba40508c.tar.gz
chromium_src-040e90ddcb6c491db29fc917424e177bba40508c.tar.bz2
Move service_discovery_client_impl.* into chrome/common.
It's need to be used by utility and browser processes. BUG=349645 Review URL: https://codereview.chromium.org/256923003 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@266487 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'chrome/common/local_discovery')
-rw-r--r--chrome/common/local_discovery/local_domain_resolver_unittest.cc183
-rw-r--r--chrome/common/local_discovery/service_discovery_client_impl.cc584
-rw-r--r--chrome/common/local_discovery/service_discovery_client_impl.h273
-rw-r--r--chrome/common/local_discovery/service_discovery_client_unittest.cc518
4 files changed, 1558 insertions, 0 deletions
diff --git a/chrome/common/local_discovery/local_domain_resolver_unittest.cc b/chrome/common/local_discovery/local_domain_resolver_unittest.cc
new file mode 100644
index 0000000..cc65ab3
--- /dev/null
+++ b/chrome/common/local_discovery/local_domain_resolver_unittest.cc
@@ -0,0 +1,183 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "chrome/common/local_discovery/service_discovery_client_impl.h"
+#include "net/dns/mdns_client_impl.h"
+#include "net/dns/mock_mdns_socket_factory.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+using ::testing::_;
+
+namespace local_discovery {
+
+namespace {
+
+const uint8 kSamplePacketA[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x01, // 1 RR (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x00, // TTL (4 bytes) is 16 seconds.
+ 0x00, 0x10,
+ 0x00, 0x04, // RDLENGTH is 4 bytes.
+ 0x01, 0x02,
+ 0x03, 0x04,
+};
+
+const uint8 kSamplePacketAAAA[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x01, // 1 RR (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x1C, // TYPE is AAAA.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x00, // TTL (4 bytes) is 16 seconds.
+ 0x00, 0x10,
+ 0x00, 0x10, // RDLENGTH is 4 bytes.
+ 0x00, 0x0A, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x01, 0x00, 0x02,
+ 0x00, 0x03, 0x00, 0x04,
+};
+
+class LocalDomainResolverTest : public testing::Test {
+ public:
+ virtual void SetUp() OVERRIDE {
+ mdns_client_.StartListening(&socket_factory_);
+ }
+
+ std::string IPAddressToStringWithEmpty(const net::IPAddressNumber& address) {
+ if (address.empty()) return "";
+ return net::IPAddressToString(address);
+ }
+
+ void AddressCallback(bool resolved,
+ const net::IPAddressNumber& address_ipv4,
+ const net::IPAddressNumber& address_ipv6) {
+ AddressCallbackInternal(resolved,
+ IPAddressToStringWithEmpty(address_ipv4),
+ IPAddressToStringWithEmpty(address_ipv6));
+ }
+
+ void RunFor(base::TimeDelta time_period) {
+ base::CancelableCallback<void()> callback(base::Bind(
+ &base::MessageLoop::Quit,
+ base::Unretained(base::MessageLoop::current())));
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE, callback.callback(), time_period);
+
+ base::MessageLoop::current()->Run();
+ callback.Cancel();
+ }
+
+ MOCK_METHOD3(AddressCallbackInternal,
+ void(bool resolved,
+ std::string address_ipv4,
+ std::string address_ipv6));
+
+ net::MockMDnsSocketFactory socket_factory_;
+ net::MDnsClientImpl mdns_client_;
+ base::MessageLoop message_loop_;
+};
+
+TEST_F(LocalDomainResolverTest, ResolveDomainA) {
+ LocalDomainResolverImpl resolver(
+ "myhello.local", net::ADDRESS_FAMILY_IPV4,
+ base::Bind(&LocalDomainResolverTest::AddressCallback,
+ base::Unretained(this)), &mdns_client_);
+
+ EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2); // Twice per query
+
+ resolver.Start();
+
+ EXPECT_CALL(*this, AddressCallbackInternal(true, "1.2.3.4", ""));
+
+ socket_factory_.SimulateReceive(kSamplePacketA, sizeof(kSamplePacketA));
+}
+
+TEST_F(LocalDomainResolverTest, ResolveDomainAAAA) {
+ LocalDomainResolverImpl resolver(
+ "myhello.local", net::ADDRESS_FAMILY_IPV6,
+ base::Bind(&LocalDomainResolverTest::AddressCallback,
+ base::Unretained(this)), &mdns_client_);
+
+ EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2); // Twice per query
+
+ resolver.Start();
+
+ EXPECT_CALL(*this, AddressCallbackInternal(true, "", "a::1:2:3:4"));
+
+ socket_factory_.SimulateReceive(kSamplePacketAAAA, sizeof(kSamplePacketAAAA));
+}
+
+TEST_F(LocalDomainResolverTest, ResolveDomainAnyOneAvailable) {
+ LocalDomainResolverImpl resolver(
+ "myhello.local", net::ADDRESS_FAMILY_UNSPECIFIED,
+ base::Bind(&LocalDomainResolverTest::AddressCallback,
+ base::Unretained(this)), &mdns_client_);
+
+ EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); // Twice per query
+
+ resolver.Start();
+
+ socket_factory_.SimulateReceive(kSamplePacketAAAA, sizeof(kSamplePacketAAAA));
+
+ EXPECT_CALL(*this, AddressCallbackInternal(true, "", "a::1:2:3:4"));
+
+ RunFor(base::TimeDelta::FromMilliseconds(150));
+}
+
+
+TEST_F(LocalDomainResolverTest, ResolveDomainAnyBothAvailable) {
+ LocalDomainResolverImpl resolver(
+ "myhello.local", net::ADDRESS_FAMILY_UNSPECIFIED,
+ base::Bind(&LocalDomainResolverTest::AddressCallback,
+ base::Unretained(this)), &mdns_client_);
+
+ EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); // Twice per query
+
+ resolver.Start();
+
+ EXPECT_CALL(*this, AddressCallbackInternal(true, "1.2.3.4", "a::1:2:3:4"));
+
+ socket_factory_.SimulateReceive(kSamplePacketAAAA, sizeof(kSamplePacketAAAA));
+
+ socket_factory_.SimulateReceive(kSamplePacketA, sizeof(kSamplePacketA));
+}
+
+TEST_F(LocalDomainResolverTest, ResolveDomainNone) {
+ LocalDomainResolverImpl resolver(
+ "myhello.local", net::ADDRESS_FAMILY_UNSPECIFIED,
+ base::Bind(&LocalDomainResolverTest::AddressCallback,
+ base::Unretained(this)), &mdns_client_);
+
+ EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); // Twice per query
+
+ resolver.Start();
+
+ EXPECT_CALL(*this, AddressCallbackInternal(false, "", ""));
+
+ RunFor(base::TimeDelta::FromSeconds(4));
+}
+
+} // namespace
+
+} // namespace local_discovery
diff --git a/chrome/common/local_discovery/service_discovery_client_impl.cc b/chrome/common/local_discovery/service_discovery_client_impl.cc
new file mode 100644
index 0000000..c86f726
--- /dev/null
+++ b/chrome/common/local_discovery/service_discovery_client_impl.cc
@@ -0,0 +1,584 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <utility>
+
+#include "base/logging.h"
+#include "base/memory/singleton.h"
+#include "base/message_loop/message_loop_proxy.h"
+#include "base/stl_util.h"
+#include "chrome/common/local_discovery/service_discovery_client_impl.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/record_rdata.h"
+
+namespace local_discovery {
+
+namespace {
+// TODO(noamsml): Make this configurable through the LocalDomainResolver
+// interface.
+const int kLocalDomainSecondAddressTimeoutMs = 100;
+
+const int kInitialRequeryTimeSeconds = 1;
+const int kMaxRequeryTimeSeconds = 2; // Time for last requery
+}
+
+ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl(
+ net::MDnsClient* mdns_client) : mdns_client_(mdns_client) {
+}
+
+ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() {
+}
+
+scoped_ptr<ServiceWatcher> ServiceDiscoveryClientImpl::CreateServiceWatcher(
+ const std::string& service_type,
+ const ServiceWatcher::UpdatedCallback& callback) {
+ return scoped_ptr<ServiceWatcher>(new ServiceWatcherImpl(
+ service_type, callback, mdns_client_));
+}
+
+scoped_ptr<ServiceResolver> ServiceDiscoveryClientImpl::CreateServiceResolver(
+ const std::string& service_name,
+ const ServiceResolver::ResolveCompleteCallback& callback) {
+ return scoped_ptr<ServiceResolver>(new ServiceResolverImpl(
+ service_name, callback, mdns_client_));
+}
+
+scoped_ptr<LocalDomainResolver>
+ServiceDiscoveryClientImpl::CreateLocalDomainResolver(
+ const std::string& domain,
+ net::AddressFamily address_family,
+ const LocalDomainResolver::IPAddressCallback& callback) {
+ return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverImpl(
+ domain, address_family, callback, mdns_client_));
+}
+
+ServiceWatcherImpl::ServiceWatcherImpl(
+ const std::string& service_type,
+ const ServiceWatcher::UpdatedCallback& callback,
+ net::MDnsClient* mdns_client)
+ : service_type_(service_type), callback_(callback), started_(false),
+ actively_refresh_services_(false), mdns_client_(mdns_client) {
+}
+
+void ServiceWatcherImpl::Start() {
+ DCHECK(!started_);
+ listener_ = mdns_client_->CreateListener(
+ net::dns_protocol::kTypePTR, service_type_, this);
+ started_ = listener_->Start();
+ if (started_)
+ ReadCachedServices();
+}
+
+ServiceWatcherImpl::~ServiceWatcherImpl() {
+}
+
+void ServiceWatcherImpl::DiscoverNewServices(bool force_update) {
+ DCHECK(started_);
+ if (force_update)
+ services_.clear();
+ SendQuery(kInitialRequeryTimeSeconds, force_update);
+}
+
+void ServiceWatcherImpl::SetActivelyRefreshServices(
+ bool actively_refresh_services) {
+ DCHECK(started_);
+ actively_refresh_services_ = actively_refresh_services;
+
+ for (ServiceListenersMap::iterator i = services_.begin();
+ i != services_.end(); i++) {
+ i->second->SetActiveRefresh(actively_refresh_services);
+ }
+}
+
+void ServiceWatcherImpl::ReadCachedServices() {
+ DCHECK(started_);
+ CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/,
+ &transaction_cache_);
+}
+
+bool ServiceWatcherImpl::CreateTransaction(
+ bool network, bool cache, bool force_refresh,
+ scoped_ptr<net::MDnsTransaction>* transaction) {
+ int transaction_flags = 0;
+ if (network)
+ transaction_flags |= net::MDnsTransaction::QUERY_NETWORK;
+
+ if (cache)
+ transaction_flags |= net::MDnsTransaction::QUERY_CACHE;
+
+ // TODO(noamsml): Add flag for force_refresh when supported.
+
+ if (transaction_flags) {
+ *transaction = mdns_client_->CreateTransaction(
+ net::dns_protocol::kTypePTR, service_type_, transaction_flags,
+ base::Bind(&ServiceWatcherImpl::OnTransactionResponse,
+ base::Unretained(this), transaction));
+ return (*transaction)->Start();
+ }
+
+ return true;
+}
+
+std::string ServiceWatcherImpl::GetServiceType() const {
+ return listener_->GetName();
+}
+
+void ServiceWatcherImpl::OnRecordUpdate(
+ net::MDnsListener::UpdateType update,
+ const net::RecordParsed* record) {
+ DCHECK(started_);
+ if (record->type() == net::dns_protocol::kTypePTR) {
+ DCHECK(record->name() == GetServiceType());
+ const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
+
+ switch (update) {
+ case net::MDnsListener::RECORD_ADDED:
+ AddService(rdata->ptrdomain());
+ break;
+ case net::MDnsListener::RECORD_CHANGED:
+ NOTREACHED();
+ break;
+ case net::MDnsListener::RECORD_REMOVED:
+ RemovePTR(rdata->ptrdomain());
+ break;
+ }
+ } else {
+ DCHECK(record->type() == net::dns_protocol::kTypeSRV ||
+ record->type() == net::dns_protocol::kTypeTXT);
+ DCHECK(services_.find(record->name()) != services_.end());
+
+ if (record->type() == net::dns_protocol::kTypeSRV) {
+ if (update == net::MDnsListener::RECORD_REMOVED) {
+ RemoveSRV(record->name());
+ } else if (update == net::MDnsListener::RECORD_ADDED) {
+ AddSRV(record->name());
+ }
+ }
+
+ // If this is the first time we see an SRV record, do not send
+ // an UPDATE_CHANGED.
+ if (record->type() != net::dns_protocol::kTypeSRV ||
+ update != net::MDnsListener::RECORD_ADDED) {
+ DeferUpdate(UPDATE_CHANGED, record->name());
+ }
+ }
+}
+
+void ServiceWatcherImpl::OnCachePurged() {
+ // Not yet implemented.
+}
+
+void ServiceWatcherImpl::OnTransactionResponse(
+ scoped_ptr<net::MDnsTransaction>* transaction,
+ net::MDnsTransaction::Result result,
+ const net::RecordParsed* record) {
+ DCHECK(started_);
+ if (result == net::MDnsTransaction::RESULT_RECORD) {
+ const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
+ DCHECK(rdata);
+ AddService(rdata->ptrdomain());
+ } else if (result == net::MDnsTransaction::RESULT_DONE) {
+ transaction->reset();
+ }
+
+ // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC
+ // record for PTR records on any name.
+}
+
+ServiceWatcherImpl::ServiceListeners::ServiceListeners(
+ const std::string& service_name,
+ ServiceWatcherImpl* watcher,
+ net::MDnsClient* mdns_client)
+ : service_name_(service_name), mdns_client_(mdns_client),
+ update_pending_(false), has_ptr_(true), has_srv_(false) {
+ srv_listener_ = mdns_client->CreateListener(
+ net::dns_protocol::kTypeSRV, service_name, watcher);
+ txt_listener_ = mdns_client->CreateListener(
+ net::dns_protocol::kTypeTXT, service_name, watcher);
+}
+
+ServiceWatcherImpl::ServiceListeners::~ServiceListeners() {
+}
+
+bool ServiceWatcherImpl::ServiceListeners::Start() {
+ if (!srv_listener_->Start())
+ return false;
+ return txt_listener_->Start();
+}
+
+void ServiceWatcherImpl::ServiceListeners::SetActiveRefresh(
+ bool active_refresh) {
+ srv_listener_->SetActiveRefresh(active_refresh);
+
+ if (active_refresh && !has_srv_) {
+ DCHECK(has_ptr_);
+ srv_transaction_ = mdns_client_->CreateTransaction(
+ net::dns_protocol::kTypeSRV, service_name_,
+ net::MDnsTransaction::SINGLE_RESULT |
+ net::MDnsTransaction::QUERY_CACHE | net::MDnsTransaction::QUERY_NETWORK,
+ base::Bind(&ServiceWatcherImpl::ServiceListeners::OnSRVRecord,
+ base::Unretained(this)));
+ srv_transaction_->Start();
+ } else if (!active_refresh) {
+ srv_transaction_.reset();
+ }
+}
+
+void ServiceWatcherImpl::ServiceListeners::OnSRVRecord(
+ net::MDnsTransaction::Result result,
+ const net::RecordParsed* record) {
+ set_has_srv(record != NULL);
+}
+
+void ServiceWatcherImpl::ServiceListeners::set_has_srv(bool has_srv) {
+ has_srv_ = has_srv;
+
+ srv_transaction_.reset();
+}
+
+void ServiceWatcherImpl::AddService(const std::string& service) {
+ DCHECK(started_);
+ std::pair<ServiceListenersMap::iterator, bool> found = services_.insert(
+ make_pair(service, linked_ptr<ServiceListeners>(NULL)));
+
+ if (found.second) { // Newly inserted.
+ found.first->second = linked_ptr<ServiceListeners>(
+ new ServiceListeners(service, this, mdns_client_));
+ bool success = found.first->second->Start();
+ found.first->second->SetActiveRefresh(actively_refresh_services_);
+ DeferUpdate(UPDATE_ADDED, service);
+
+ DCHECK(success);
+ }
+
+ found.first->second->set_has_ptr(true);
+}
+
+void ServiceWatcherImpl::AddSRV(const std::string& service) {
+ DCHECK(started_);
+
+ ServiceListenersMap::iterator found = services_.find(service);
+ if (found != services_.end()) {
+ found->second->set_has_srv(true);
+ }
+}
+
+void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type,
+ const std::string& service_name) {
+ ServiceListenersMap::iterator found = services_.find(service_name);
+
+ if (found != services_.end() && !found->second->update_pending()) {
+ found->second->set_update_pending(true);
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE,
+ base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate, AsWeakPtr(),
+ update_type, service_name));
+ }
+}
+
+void ServiceWatcherImpl::DeliverDeferredUpdate(
+ ServiceWatcher::UpdateType update_type, const std::string& service_name) {
+ ServiceListenersMap::iterator found = services_.find(service_name);
+
+ if (found != services_.end()) {
+ found->second->set_update_pending(false);
+ if (!callback_.is_null())
+ callback_.Run(update_type, service_name);
+ }
+}
+
+void ServiceWatcherImpl::RemovePTR(const std::string& service) {
+ DCHECK(started_);
+
+ ServiceListenersMap::iterator found = services_.find(service);
+ if (found != services_.end()) {
+ found->second->set_has_ptr(false);
+
+ if (!found->second->has_ptr_or_srv()) {
+ services_.erase(found);
+ if (!callback_.is_null())
+ callback_.Run(UPDATE_REMOVED, service);
+ }
+ }
+}
+
+void ServiceWatcherImpl::RemoveSRV(const std::string& service) {
+ DCHECK(started_);
+
+ ServiceListenersMap::iterator found = services_.find(service);
+ if (found != services_.end()) {
+ found->second->set_has_srv(false);
+
+ if (!found->second->has_ptr_or_srv()) {
+ services_.erase(found);
+ if (!callback_.is_null())
+ callback_.Run(UPDATE_REMOVED, service);
+ }
+ }
+}
+
+void ServiceWatcherImpl::OnNsecRecord(const std::string& name,
+ unsigned rrtype) {
+ // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR
+ // on any name.
+}
+
+void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds) {
+ if (timeout_seconds <= kMaxRequeryTimeSeconds) {
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(&ServiceWatcherImpl::SendQuery,
+ AsWeakPtr(),
+ timeout_seconds * 2 /*next_timeout_seconds*/,
+ false /*force_update*/),
+ base::TimeDelta::FromSeconds(timeout_seconds));
+ }
+}
+
+void ServiceWatcherImpl::SendQuery(int next_timeout_seconds,
+ bool force_update) {
+ CreateTransaction(true /*network*/, false /*cache*/, force_update,
+ &transaction_network_);
+ ScheduleQuery(next_timeout_seconds);
+}
+
+ServiceResolverImpl::ServiceResolverImpl(
+ const std::string& service_name,
+ const ResolveCompleteCallback& callback,
+ net::MDnsClient* mdns_client)
+ : service_name_(service_name), callback_(callback),
+ metadata_resolved_(false), address_resolved_(false),
+ mdns_client_(mdns_client) {
+}
+
+void ServiceResolverImpl::StartResolving() {
+ address_resolved_ = false;
+ metadata_resolved_ = false;
+ service_staging_ = ServiceDescription();
+ service_staging_.service_name = service_name_;
+
+ if (!CreateTxtTransaction() || !CreateSrvTransaction()) {
+ ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT);
+ }
+}
+
+ServiceResolverImpl::~ServiceResolverImpl() {
+}
+
+bool ServiceResolverImpl::CreateTxtTransaction() {
+ txt_transaction_ = mdns_client_->CreateTransaction(
+ net::dns_protocol::kTypeTXT, service_name_,
+ net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
+ net::MDnsTransaction::QUERY_NETWORK,
+ base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse,
+ AsWeakPtr()));
+ return txt_transaction_->Start();
+}
+
+// TODO(noamsml): quick-resolve for AAAA records. Since A records tend to be in
+void ServiceResolverImpl::CreateATransaction() {
+ a_transaction_ = mdns_client_->CreateTransaction(
+ net::dns_protocol::kTypeA,
+ service_staging_.address.host(),
+ net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE,
+ base::Bind(&ServiceResolverImpl::ARecordTransactionResponse,
+ AsWeakPtr()));
+ a_transaction_->Start();
+}
+
+bool ServiceResolverImpl::CreateSrvTransaction() {
+ srv_transaction_ = mdns_client_->CreateTransaction(
+ net::dns_protocol::kTypeSRV, service_name_,
+ net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
+ net::MDnsTransaction::QUERY_NETWORK,
+ base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse,
+ AsWeakPtr()));
+ return srv_transaction_->Start();
+}
+
+std::string ServiceResolverImpl::GetName() const {
+ return service_name_;
+}
+
+void ServiceResolverImpl::SrvRecordTransactionResponse(
+ net::MDnsTransaction::Result status, const net::RecordParsed* record) {
+ srv_transaction_.reset();
+ if (status == net::MDnsTransaction::RESULT_RECORD) {
+ DCHECK(record);
+ service_staging_.address = RecordToAddress(record);
+ service_staging_.last_seen = record->time_created();
+ CreateATransaction();
+ } else {
+ ServiceNotFound(MDnsStatusToRequestStatus(status));
+ }
+}
+
+void ServiceResolverImpl::TxtRecordTransactionResponse(
+ net::MDnsTransaction::Result status, const net::RecordParsed* record) {
+ txt_transaction_.reset();
+ if (status == net::MDnsTransaction::RESULT_RECORD) {
+ DCHECK(record);
+ service_staging_.metadata = RecordToMetadata(record);
+ } else {
+ service_staging_.metadata = std::vector<std::string>();
+ }
+
+ metadata_resolved_ = true;
+ AlertCallbackIfReady();
+}
+
+void ServiceResolverImpl::ARecordTransactionResponse(
+ net::MDnsTransaction::Result status, const net::RecordParsed* record) {
+ a_transaction_.reset();
+
+ if (status == net::MDnsTransaction::RESULT_RECORD) {
+ DCHECK(record);
+ service_staging_.ip_address = RecordToIPAddress(record);
+ } else {
+ service_staging_.ip_address = net::IPAddressNumber();
+ }
+
+ address_resolved_ = true;
+ AlertCallbackIfReady();
+}
+
+void ServiceResolverImpl::AlertCallbackIfReady() {
+ if (metadata_resolved_ && address_resolved_) {
+ txt_transaction_.reset();
+ srv_transaction_.reset();
+ a_transaction_.reset();
+ if (!callback_.is_null())
+ callback_.Run(STATUS_SUCCESS, service_staging_);
+ }
+}
+
+void ServiceResolverImpl::ServiceNotFound(
+ ServiceResolver::RequestStatus status) {
+ txt_transaction_.reset();
+ srv_transaction_.reset();
+ a_transaction_.reset();
+ if (!callback_.is_null())
+ callback_.Run(status, ServiceDescription());
+}
+
+ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus(
+ net::MDnsTransaction::Result status) const {
+ switch (status) {
+ case net::MDnsTransaction::RESULT_RECORD:
+ return ServiceResolver::STATUS_SUCCESS;
+ case net::MDnsTransaction::RESULT_NO_RESULTS:
+ return ServiceResolver::STATUS_REQUEST_TIMEOUT;
+ case net::MDnsTransaction::RESULT_NSEC:
+ return ServiceResolver::STATUS_KNOWN_NONEXISTENT;
+ case net::MDnsTransaction::RESULT_DONE: // Pass through.
+ default:
+ NOTREACHED();
+ return ServiceResolver::STATUS_REQUEST_TIMEOUT;
+ }
+}
+
+const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata(
+ const net::RecordParsed* record) const {
+ DCHECK(record->type() == net::dns_protocol::kTypeTXT);
+ const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>();
+ DCHECK(txt_rdata);
+ return txt_rdata->texts();
+}
+
+net::HostPortPair ServiceResolverImpl::RecordToAddress(
+ const net::RecordParsed* record) const {
+ DCHECK(record->type() == net::dns_protocol::kTypeSRV);
+ const net::SrvRecordRdata* srv_rdata = record->rdata<net::SrvRecordRdata>();
+ DCHECK(srv_rdata);
+ return net::HostPortPair(srv_rdata->target(), srv_rdata->port());
+}
+
+const net::IPAddressNumber& ServiceResolverImpl::RecordToIPAddress(
+ const net::RecordParsed* record) const {
+ DCHECK(record->type() == net::dns_protocol::kTypeA);
+ const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>();
+ DCHECK(a_rdata);
+ return a_rdata->address();
+}
+
+LocalDomainResolverImpl::LocalDomainResolverImpl(
+ const std::string& domain,
+ net::AddressFamily address_family,
+ const IPAddressCallback& callback,
+ net::MDnsClient* mdns_client)
+ : domain_(domain), address_family_(address_family), callback_(callback),
+ transactions_finished_(0), mdns_client_(mdns_client) {
+}
+
+LocalDomainResolverImpl::~LocalDomainResolverImpl() {
+ timeout_callback_.Cancel();
+}
+
+void LocalDomainResolverImpl::Start() {
+ if (address_family_ == net::ADDRESS_FAMILY_IPV4 ||
+ address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
+ transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA);
+ transaction_a_->Start();
+ }
+
+ if (address_family_ == net::ADDRESS_FAMILY_IPV6 ||
+ address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
+ transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA);
+ transaction_aaaa_->Start();
+ }
+}
+
+scoped_ptr<net::MDnsTransaction> LocalDomainResolverImpl::CreateTransaction(
+ uint16 type) {
+ return mdns_client_->CreateTransaction(
+ type, domain_, net::MDnsTransaction::SINGLE_RESULT |
+ net::MDnsTransaction::QUERY_CACHE |
+ net::MDnsTransaction::QUERY_NETWORK,
+ base::Bind(&LocalDomainResolverImpl::OnTransactionComplete,
+ base::Unretained(this)));
+}
+
+void LocalDomainResolverImpl::OnTransactionComplete(
+ net::MDnsTransaction::Result result, const net::RecordParsed* record) {
+ transactions_finished_++;
+
+ if (result == net::MDnsTransaction::RESULT_RECORD) {
+ if (record->type() == net::dns_protocol::kTypeA) {
+ const net::ARecordRdata* rdata = record->rdata<net::ARecordRdata>();
+ address_ipv4_ = rdata->address();
+ } else {
+ DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type());
+ const net::AAAARecordRdata* rdata = record->rdata<net::AAAARecordRdata>();
+ address_ipv6_ = rdata->address();
+ }
+ }
+
+ if (transactions_finished_ == 1 &&
+ address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
+ timeout_callback_.Reset(base::Bind(
+ &LocalDomainResolverImpl::SendResolvedAddresses,
+ base::Unretained(this)));
+
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ timeout_callback_.callback(),
+ base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs));
+ } else if (transactions_finished_ == 2
+ || address_family_ != net::ADDRESS_FAMILY_UNSPECIFIED) {
+ SendResolvedAddresses();
+ }
+}
+
+bool LocalDomainResolverImpl::IsSuccess() {
+ return !address_ipv4_.empty() || !address_ipv6_.empty();
+}
+
+void LocalDomainResolverImpl::SendResolvedAddresses() {
+ transaction_a_.reset();
+ transaction_aaaa_.reset();
+ timeout_callback_.Cancel();
+ callback_.Run(IsSuccess(), address_ipv4_, address_ipv6_);
+}
+
+} // namespace local_discovery
diff --git a/chrome/common/local_discovery/service_discovery_client_impl.h b/chrome/common/local_discovery/service_discovery_client_impl.h
new file mode 100644
index 0000000..29e42c7
--- /dev/null
+++ b/chrome/common/local_discovery/service_discovery_client_impl.h
@@ -0,0 +1,273 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CHROME_COMMON_LOCAL_DISCOVERY_SERVICE_DISCOVERY_CLIENT_IMPL_H_
+#define CHROME_COMMON_LOCAL_DISCOVERY_SERVICE_DISCOVERY_CLIENT_IMPL_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "base/callback.h"
+#include "base/cancelable_callback.h"
+#include "base/memory/linked_ptr.h"
+#include "base/memory/weak_ptr.h"
+#include "base/message_loop/message_loop.h"
+#include "chrome/common/local_discovery/service_discovery_client.h"
+#include "net/dns/mdns_client.h"
+
+namespace local_discovery {
+
+class ServiceDiscoveryClientImpl : public ServiceDiscoveryClient {
+ public:
+ // |mdns_client| must outlive the Service Discovery Client.
+ explicit ServiceDiscoveryClientImpl(net::MDnsClient* mdns_client);
+ virtual ~ServiceDiscoveryClientImpl();
+
+ // ServiceDiscoveryClient implementation:
+ virtual scoped_ptr<ServiceWatcher> CreateServiceWatcher(
+ const std::string& service_type,
+ const ServiceWatcher::UpdatedCallback& callback) OVERRIDE;
+
+ virtual scoped_ptr<ServiceResolver> CreateServiceResolver(
+ const std::string& service_name,
+ const ServiceResolver::ResolveCompleteCallback& callback) OVERRIDE;
+
+ virtual scoped_ptr<LocalDomainResolver> CreateLocalDomainResolver(
+ const std::string& domain,
+ net::AddressFamily address_family,
+ const LocalDomainResolver::IPAddressCallback& callback) OVERRIDE;
+
+ private:
+ net::MDnsClient* mdns_client_;
+
+ DISALLOW_COPY_AND_ASSIGN(ServiceDiscoveryClientImpl);
+};
+
+class ServiceWatcherImpl : public ServiceWatcher,
+ public net::MDnsListener::Delegate,
+ public base::SupportsWeakPtr<ServiceWatcherImpl> {
+ public:
+ ServiceWatcherImpl(const std::string& service_type,
+ const ServiceWatcher::UpdatedCallback& callback,
+ net::MDnsClient* mdns_client);
+ // Listening will automatically stop when the destructor is called.
+ virtual ~ServiceWatcherImpl();
+
+ // ServiceWatcher implementation:
+ virtual void Start() OVERRIDE;
+
+ virtual void DiscoverNewServices(bool force_update) OVERRIDE;
+
+ virtual void SetActivelyRefreshServices(
+ bool actively_refresh_services) OVERRIDE;
+
+ virtual std::string GetServiceType() const OVERRIDE;
+
+ virtual void OnRecordUpdate(net::MDnsListener::UpdateType update,
+ const net::RecordParsed* record) OVERRIDE;
+
+ virtual void OnNsecRecord(const std::string& name, unsigned rrtype) OVERRIDE;
+
+ virtual void OnCachePurged() OVERRIDE;
+
+ virtual void OnTransactionResponse(
+ scoped_ptr<net::MDnsTransaction>* transaction,
+ net::MDnsTransaction::Result result,
+ const net::RecordParsed* record);
+
+ private:
+ struct ServiceListeners {
+ ServiceListeners(const std::string& service_name,
+ ServiceWatcherImpl* watcher,
+ net::MDnsClient* mdns_client);
+ ~ServiceListeners();
+ bool Start();
+ void SetActiveRefresh(bool auto_update);
+
+ void set_update_pending(bool update_pending) {
+ update_pending_ = update_pending;
+ }
+
+ bool update_pending() { return update_pending_; }
+
+ void set_has_ptr(bool has_ptr) {
+ has_ptr_ = has_ptr;
+ }
+
+ void set_has_srv(bool has_srv);
+
+ bool has_ptr_or_srv() { return has_ptr_ || has_srv_; }
+
+ private:
+ void OnSRVRecord(net::MDnsTransaction::Result result,
+ const net::RecordParsed* record);
+
+ void DoQuerySRV();
+
+ scoped_ptr<net::MDnsListener> srv_listener_;
+ scoped_ptr<net::MDnsListener> txt_listener_;
+ scoped_ptr<net::MDnsTransaction> srv_transaction_;
+
+ std::string service_name_;
+ net::MDnsClient* mdns_client_;
+ bool update_pending_;
+
+ bool has_ptr_;
+ bool has_srv_;
+ };
+
+ typedef std::map<std::string, linked_ptr<ServiceListeners> >
+ ServiceListenersMap;
+
+ void ReadCachedServices();
+ void AddService(const std::string& service);
+ void RemovePTR(const std::string& service);
+ void RemoveSRV(const std::string& service);
+ void AddSRV(const std::string& service);
+ bool CreateTransaction(bool active, bool alert_existing_services,
+ bool force_refresh,
+ scoped_ptr<net::MDnsTransaction>* transaction);
+
+ void DeferUpdate(ServiceWatcher::UpdateType update_type,
+ const std::string& service_name);
+ void DeliverDeferredUpdate(ServiceWatcher::UpdateType update_type,
+ const std::string& service_name);
+
+ void ScheduleQuery(int timeout_seconds);
+
+ void SendQuery(int next_timeout_seconds, bool force_update);
+
+ std::string service_type_;
+ ServiceListenersMap services_;
+ scoped_ptr<net::MDnsTransaction> transaction_network_;
+ scoped_ptr<net::MDnsTransaction> transaction_cache_;
+ scoped_ptr<net::MDnsListener> listener_;
+
+ ServiceWatcher::UpdatedCallback callback_;
+ bool started_;
+ bool actively_refresh_services_;
+
+ net::MDnsClient* mdns_client_;
+
+ DISALLOW_COPY_AND_ASSIGN(ServiceWatcherImpl);
+};
+
+class ServiceResolverImpl
+ : public ServiceResolver,
+ public base::SupportsWeakPtr<ServiceResolverImpl> {
+ public:
+ ServiceResolverImpl(const std::string& service_name,
+ const ServiceResolver::ResolveCompleteCallback& callback,
+ net::MDnsClient* mdns_client);
+
+ virtual ~ServiceResolverImpl();
+
+ // ServiceResolver implementation:
+ virtual void StartResolving() OVERRIDE;
+
+ virtual std::string GetName() const OVERRIDE;
+
+ private:
+ // Respond to transaction finishing for SRV records.
+ void SrvRecordTransactionResponse(net::MDnsTransaction::Result status,
+ const net::RecordParsed* record);
+
+ // Respond to transaction finishing for TXT records.
+ void TxtRecordTransactionResponse(net::MDnsTransaction::Result status,
+ const net::RecordParsed* record);
+
+ // Respond to transaction finishing for A records.
+ void ARecordTransactionResponse(net::MDnsTransaction::Result status,
+ const net::RecordParsed* record);
+
+ void AlertCallbackIfReady();
+
+ void ServiceNotFound(RequestStatus status);
+
+ // Convert a TXT record to a vector of strings (metadata).
+ const std::vector<std::string>& RecordToMetadata(
+ const net::RecordParsed* record) const;
+
+ // Convert an SRV record to a host and port pair.
+ net::HostPortPair RecordToAddress(
+ const net::RecordParsed* record) const;
+
+ // Convert an A record to an IP address.
+ const net::IPAddressNumber& RecordToIPAddress(
+ const net::RecordParsed* record) const;
+
+ // Convert an MDns status to a service discovery status.
+ RequestStatus MDnsStatusToRequestStatus(
+ net::MDnsTransaction::Result status) const;
+
+ bool CreateTxtTransaction();
+ bool CreateSrvTransaction();
+ void CreateATransaction();
+
+ std::string service_name_;
+ ResolveCompleteCallback callback_;
+
+ bool has_resolved_;
+
+ bool metadata_resolved_;
+ bool address_resolved_;
+
+ scoped_ptr<net::MDnsTransaction> txt_transaction_;
+ scoped_ptr<net::MDnsTransaction> srv_transaction_;
+ scoped_ptr<net::MDnsTransaction> a_transaction_;
+
+ ServiceDescription service_staging_;
+
+ net::MDnsClient* mdns_client_;
+
+ DISALLOW_COPY_AND_ASSIGN(ServiceResolverImpl);
+};
+
+class LocalDomainResolverImpl : public LocalDomainResolver {
+ public:
+ LocalDomainResolverImpl(const std::string& domain,
+ net::AddressFamily address_family,
+ const IPAddressCallback& callback,
+ net::MDnsClient* mdns_client);
+ virtual ~LocalDomainResolverImpl();
+
+ virtual void Start() OVERRIDE;
+
+ const std::string& domain() { return domain_; }
+
+ private:
+ void OnTransactionComplete(
+ net::MDnsTransaction::Result result,
+ const net::RecordParsed* record);
+
+ scoped_ptr<net::MDnsTransaction> CreateTransaction(uint16 type);
+
+ bool IsSuccess();
+
+ void SendResolvedAddresses();
+
+ std::string domain_;
+ net::AddressFamily address_family_;
+ IPAddressCallback callback_;
+
+ scoped_ptr<net::MDnsTransaction> transaction_a_;
+ scoped_ptr<net::MDnsTransaction> transaction_aaaa_;
+
+ int transactions_finished_;
+
+ net::MDnsClient* mdns_client_;
+
+ net::IPAddressNumber address_ipv4_;
+ net::IPAddressNumber address_ipv6_;
+
+ base::CancelableCallback<void()> timeout_callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(LocalDomainResolverImpl);
+};
+
+
+} // namespace local_discovery
+
+#endif // CHROME_COMMON_LOCAL_DISCOVERY_SERVICE_DISCOVERY_CLIENT_IMPL_H_
diff --git a/chrome/common/local_discovery/service_discovery_client_unittest.cc b/chrome/common/local_discovery/service_discovery_client_unittest.cc
new file mode 100644
index 0000000..5033fbdd
--- /dev/null
+++ b/chrome/common/local_discovery/service_discovery_client_unittest.cc
@@ -0,0 +1,518 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/memory/weak_ptr.h"
+#include "base/run_loop.h"
+#include "chrome/common/local_discovery/service_discovery_client_impl.h"
+#include "net/base/net_errors.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/mdns_client_impl.h"
+#include "net/dns/mock_mdns_socket_factory.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+using ::testing::_;
+using ::testing::Invoke;
+using ::testing::StrictMock;
+using ::testing::NiceMock;
+using ::testing::Mock;
+using ::testing::SaveArg;
+using ::testing::SetArgPointee;
+using ::testing::Return;
+using ::testing::Exactly;
+
+namespace local_discovery {
+
+namespace {
+
+const uint8 kSamplePacketPTR[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x01, // 1 RR (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x00, // TTL (4 bytes) is 1 second.
+ 0x00, 0x01,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x0c
+};
+
+const uint8 kSamplePacketSRV[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x01, // 1 RR (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x21, // TYPE is SRV.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x00, // TTL (4 bytes) is 1 second.
+ 0x00, 0x01,
+ 0x00, 0x15, // RDLENGTH is 21 bytes.
+ 0x00, 0x00,
+ 0x00, 0x00,
+ 0x22, 0xb8, // port 8888
+ 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+};
+
+const uint8 kSamplePacketTXT[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x01, // 1 RR (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x10, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x00, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x00, 0x01,
+ 0x00, 0x06, // RDLENGTH is 21 bytes.
+ 0x05, 'h', 'e', 'l', 'l', 'o'
+};
+
+const uint8 kSamplePacketSRVA[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x02, // 2 RR (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x21, // TYPE is SRV.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x00, // TTL (4 bytes) is 16 seconds.
+ 0x00, 0x10,
+ 0x00, 0x15, // RDLENGTH is 21 bytes.
+ 0x00, 0x00,
+ 0x00, 0x00,
+ 0x22, 0xb8, // port 8888
+ 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+
+ 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x00, // TTL (4 bytes) is 16 seconds.
+ 0x00, 0x10,
+ 0x00, 0x04, // RDLENGTH is 4 bytes.
+ 0x01, 0x02,
+ 0x03, 0x04,
+};
+
+const uint8 kSamplePacketPTR2[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x02, // 2 RR (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x02, 0x00, // TTL (4 bytes) is 1 second.
+ 0x00, 0x01,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'g', 'd', 'b', 'y', 'e',
+ 0xc0, 0x0c,
+
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x02, 0x00, // TTL (4 bytes) is 1 second.
+ 0x00, 0x01,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x0c
+};
+
+const uint8 kSamplePacketQuerySRV[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x00, 0x00, // No flags.
+ 0x00, 0x01, // One question.
+ 0x00, 0x00, // 0 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Question
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x21, // TYPE is SRV.
+ 0x00, 0x01, // CLASS is IN.
+};
+
+
+class MockServiceWatcherClient {
+ public:
+ MOCK_METHOD2(OnServiceUpdated,
+ void(ServiceWatcher::UpdateType, const std::string&));
+
+ ServiceWatcher::UpdatedCallback GetCallback() {
+ return base::Bind(&MockServiceWatcherClient::OnServiceUpdated,
+ base::Unretained(this));
+ }
+};
+
+class ServiceDiscoveryTest : public ::testing::Test {
+ public:
+ ServiceDiscoveryTest()
+ : service_discovery_client_(&mdns_client_) {
+ mdns_client_.StartListening(&socket_factory_);
+ }
+
+ virtual ~ServiceDiscoveryTest() {
+ }
+
+ protected:
+ void RunFor(base::TimeDelta time_period) {
+ base::CancelableCallback<void()> callback(base::Bind(
+ &ServiceDiscoveryTest::Stop, base::Unretained(this)));
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE, callback.callback(), time_period);
+
+ base::MessageLoop::current()->Run();
+ callback.Cancel();
+ }
+
+ void Stop() {
+ base::MessageLoop::current()->Quit();
+ }
+
+ net::MockMDnsSocketFactory socket_factory_;
+ net::MDnsClientImpl mdns_client_;
+ ServiceDiscoveryClientImpl service_discovery_client_;
+ base::MessageLoop loop_;
+};
+
+TEST_F(ServiceDiscoveryTest, AddRemoveService) {
+ StrictMock<MockServiceWatcherClient> delegate;
+
+ scoped_ptr<ServiceWatcher> watcher(
+ service_discovery_client_.CreateServiceWatcher(
+ "_privet._tcp.local", delegate.GetCallback()));
+
+ watcher->Start();
+
+ EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
+ "hello._privet._tcp.local"))
+ .Times(Exactly(1));
+
+ socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
+
+ EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED,
+ "hello._privet._tcp.local"))
+ .Times(Exactly(1));
+
+ RunFor(base::TimeDelta::FromSeconds(2));
+};
+
+TEST_F(ServiceDiscoveryTest, DiscoverNewServices) {
+ StrictMock<MockServiceWatcherClient> delegate;
+
+ scoped_ptr<ServiceWatcher> watcher(
+ service_discovery_client_.CreateServiceWatcher(
+ "_privet._tcp.local", delegate.GetCallback()));
+
+ watcher->Start();
+
+ EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2);
+
+ watcher->DiscoverNewServices(false);
+
+ EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2);
+
+ RunFor(base::TimeDelta::FromSeconds(2));
+};
+
+TEST_F(ServiceDiscoveryTest, ReadCachedServices) {
+ socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
+
+ StrictMock<MockServiceWatcherClient> delegate;
+
+ scoped_ptr<ServiceWatcher> watcher(
+ service_discovery_client_.CreateServiceWatcher(
+ "_privet._tcp.local", delegate.GetCallback()));
+
+ watcher->Start();
+
+ EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
+ "hello._privet._tcp.local"))
+ .Times(Exactly(1));
+
+ base::MessageLoop::current()->RunUntilIdle();
+};
+
+
+TEST_F(ServiceDiscoveryTest, ReadCachedServicesMultiple) {
+ socket_factory_.SimulateReceive(kSamplePacketPTR2, sizeof(kSamplePacketPTR2));
+
+ StrictMock<MockServiceWatcherClient> delegate;
+ scoped_ptr<ServiceWatcher> watcher =
+ service_discovery_client_.CreateServiceWatcher(
+ "_privet._tcp.local", delegate.GetCallback());
+
+ watcher->Start();
+
+ EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
+ "hello._privet._tcp.local"))
+ .Times(Exactly(1));
+
+ EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
+ "gdbye._privet._tcp.local"))
+ .Times(Exactly(1));
+
+ base::MessageLoop::current()->RunUntilIdle();
+};
+
+
+TEST_F(ServiceDiscoveryTest, OnServiceChanged) {
+ StrictMock<MockServiceWatcherClient> delegate;
+ scoped_ptr<ServiceWatcher> watcher(
+ service_discovery_client_.CreateServiceWatcher(
+ "_privet._tcp.local", delegate.GetCallback()));
+
+ watcher->Start();
+
+ EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
+ "hello._privet._tcp.local"))
+ .Times(Exactly(1));
+
+ socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
+
+ base::MessageLoop::current()->RunUntilIdle();
+
+ EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED,
+ "hello._privet._tcp.local"))
+ .Times(Exactly(1));
+
+ socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));
+
+ socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));
+
+ base::MessageLoop::current()->RunUntilIdle();
+};
+
+TEST_F(ServiceDiscoveryTest, SinglePacket) {
+ StrictMock<MockServiceWatcherClient> delegate;
+ scoped_ptr<ServiceWatcher> watcher(
+ service_discovery_client_.CreateServiceWatcher(
+ "_privet._tcp.local", delegate.GetCallback()));
+
+ watcher->Start();
+
+ EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
+ "hello._privet._tcp.local"))
+ .Times(Exactly(1));
+
+ socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
+
+ // Reset the "already updated" flag.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED,
+ "hello._privet._tcp.local"))
+ .Times(Exactly(1));
+
+ socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));
+
+ socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));
+
+ base::MessageLoop::current()->RunUntilIdle();
+};
+
+TEST_F(ServiceDiscoveryTest, ActivelyRefreshServices) {
+ StrictMock<MockServiceWatcherClient> delegate;
+ scoped_ptr<ServiceWatcher> watcher(
+ service_discovery_client_.CreateServiceWatcher(
+ "_privet._tcp.local", delegate.GetCallback()));
+
+ watcher->Start();
+ watcher->SetActivelyRefreshServices(true);
+
+ EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
+ "hello._privet._tcp.local"))
+ .Times(Exactly(1));
+
+ std::string query_packet = std::string((const char*)(kSamplePacketQuerySRV),
+ sizeof(kSamplePacketQuerySRV));
+
+ EXPECT_CALL(socket_factory_, OnSendTo(query_packet))
+ .Times(2);
+
+ socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
+
+ base::MessageLoop::current()->RunUntilIdle();
+
+ socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));
+
+ EXPECT_CALL(socket_factory_, OnSendTo(query_packet))
+ .Times(4); // IPv4 and IPv6 at 85% and 95%
+
+ EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED,
+ "hello._privet._tcp.local"))
+ .Times(Exactly(1));
+
+ RunFor(base::TimeDelta::FromSeconds(2));
+
+ base::MessageLoop::current()->RunUntilIdle();
+};
+
+
+class ServiceResolverTest : public ServiceDiscoveryTest {
+ public:
+ ServiceResolverTest() {
+ metadata_expected_.push_back("hello");
+ address_expected_ = net::HostPortPair("myhello.local", 8888);
+ ip_address_expected_.push_back(1);
+ ip_address_expected_.push_back(2);
+ ip_address_expected_.push_back(3);
+ ip_address_expected_.push_back(4);
+ }
+
+ ~ServiceResolverTest() {
+ }
+
+ void SetUp() {
+ resolver_ = service_discovery_client_.CreateServiceResolver(
+ "hello._privet._tcp.local",
+ base::Bind(&ServiceResolverTest::OnFinishedResolving,
+ base::Unretained(this)));
+ }
+
+ void OnFinishedResolving(ServiceResolver::RequestStatus request_status,
+ const ServiceDescription& service_description) {
+ OnFinishedResolvingInternal(request_status,
+ service_description.address.ToString(),
+ service_description.metadata,
+ service_description.ip_address);
+ }
+
+ MOCK_METHOD4(OnFinishedResolvingInternal,
+ void(ServiceResolver::RequestStatus,
+ const std::string&,
+ const std::vector<std::string>&,
+ const net::IPAddressNumber&));
+
+ protected:
+ scoped_ptr<ServiceResolver> resolver_;
+ net::IPAddressNumber ip_address_;
+ net::HostPortPair address_expected_;
+ std::vector<std::string> metadata_expected_;
+ net::IPAddressNumber ip_address_expected_;
+};
+
+TEST_F(ServiceResolverTest, TxtAndSrvButNoA) {
+ EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);
+
+ resolver_->StartResolving();
+
+ socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));
+
+ base::MessageLoop::current()->RunUntilIdle();
+
+ EXPECT_CALL(*this,
+ OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
+ address_expected_.ToString(),
+ metadata_expected_,
+ net::IPAddressNumber()));
+
+ socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));
+};
+
+TEST_F(ServiceResolverTest, TxtSrvAndA) {
+ EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);
+
+ resolver_->StartResolving();
+
+ EXPECT_CALL(*this,
+ OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
+ address_expected_.ToString(),
+ metadata_expected_,
+ ip_address_expected_));
+
+ socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));
+
+ socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA));
+};
+
+TEST_F(ServiceResolverTest, JustSrv) {
+ EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);
+
+ resolver_->StartResolving();
+
+ EXPECT_CALL(*this,
+ OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
+ address_expected_.ToString(),
+ std::vector<std::string>(),
+ ip_address_expected_));
+
+ socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA));
+
+ // TODO(noamsml): When NSEC record support is added, change this to use an
+ // NSEC record.
+ RunFor(base::TimeDelta::FromSeconds(4));
+};
+
+TEST_F(ServiceResolverTest, WithNothing) {
+ EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);
+
+ resolver_->StartResolving();
+
+ EXPECT_CALL(*this, OnFinishedResolvingInternal(
+ ServiceResolver::STATUS_REQUEST_TIMEOUT, _, _, _));
+
+ // TODO(noamsml): When NSEC record support is added, change this to use an
+ // NSEC record.
+ RunFor(base::TimeDelta::FromSeconds(4));
+};
+
+} // namespace
+
+} // namespace local_discovery