diff options
| author | vitalybuka@chromium.org <vitalybuka@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2014-04-28 08:14:37 +0000 |
|---|---|---|
| committer | vitalybuka@chromium.org <vitalybuka@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2014-04-28 08:14:37 +0000 |
| commit | 040e90ddcb6c491db29fc917424e177bba40508c (patch) | |
| tree | 84a76dfd29aa19bc42fc3b1f7cd8db847f27635e /chrome/common/local_discovery | |
| parent | 2ea5fc1d74afbd7bf3f4d88a18b437f05e20d58b (diff) | |
| download | chromium_src-040e90ddcb6c491db29fc917424e177bba40508c.zip chromium_src-040e90ddcb6c491db29fc917424e177bba40508c.tar.gz chromium_src-040e90ddcb6c491db29fc917424e177bba40508c.tar.bz2 | |
Move service_discovery_client_impl.* into chrome/common.
It's need to be used by utility and browser processes.
BUG=349645
Review URL: https://codereview.chromium.org/256923003
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@266487 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'chrome/common/local_discovery')
4 files changed, 1558 insertions, 0 deletions
diff --git a/chrome/common/local_discovery/local_domain_resolver_unittest.cc b/chrome/common/local_discovery/local_domain_resolver_unittest.cc new file mode 100644 index 0000000..cc65ab3 --- /dev/null +++ b/chrome/common/local_discovery/local_domain_resolver_unittest.cc @@ -0,0 +1,183 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "chrome/common/local_discovery/service_discovery_client_impl.h" +#include "net/dns/mdns_client_impl.h" +#include "net/dns/mock_mdns_socket_factory.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +using ::testing::_; + +namespace local_discovery { + +namespace { + +const uint8 kSamplePacketA[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x01, // 1 RR (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, // TTL (4 bytes) is 16 seconds. + 0x00, 0x10, + 0x00, 0x04, // RDLENGTH is 4 bytes. + 0x01, 0x02, + 0x03, 0x04, +}; + +const uint8 kSamplePacketAAAA[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x01, // 1 RR (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x1C, // TYPE is AAAA. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, // TTL (4 bytes) is 16 seconds. + 0x00, 0x10, + 0x00, 0x10, // RDLENGTH is 4 bytes. + 0x00, 0x0A, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x02, + 0x00, 0x03, 0x00, 0x04, +}; + +class LocalDomainResolverTest : public testing::Test { + public: + virtual void SetUp() OVERRIDE { + mdns_client_.StartListening(&socket_factory_); + } + + std::string IPAddressToStringWithEmpty(const net::IPAddressNumber& address) { + if (address.empty()) return ""; + return net::IPAddressToString(address); + } + + void AddressCallback(bool resolved, + const net::IPAddressNumber& address_ipv4, + const net::IPAddressNumber& address_ipv6) { + AddressCallbackInternal(resolved, + IPAddressToStringWithEmpty(address_ipv4), + IPAddressToStringWithEmpty(address_ipv6)); + } + + void RunFor(base::TimeDelta time_period) { + base::CancelableCallback<void()> callback(base::Bind( + &base::MessageLoop::Quit, + base::Unretained(base::MessageLoop::current()))); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, callback.callback(), time_period); + + base::MessageLoop::current()->Run(); + callback.Cancel(); + } + + MOCK_METHOD3(AddressCallbackInternal, + void(bool resolved, + std::string address_ipv4, + std::string address_ipv6)); + + net::MockMDnsSocketFactory socket_factory_; + net::MDnsClientImpl mdns_client_; + base::MessageLoop message_loop_; +}; + +TEST_F(LocalDomainResolverTest, ResolveDomainA) { + LocalDomainResolverImpl resolver( + "myhello.local", net::ADDRESS_FAMILY_IPV4, + base::Bind(&LocalDomainResolverTest::AddressCallback, + base::Unretained(this)), &mdns_client_); + + EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2); // Twice per query + + resolver.Start(); + + EXPECT_CALL(*this, AddressCallbackInternal(true, "1.2.3.4", "")); + + socket_factory_.SimulateReceive(kSamplePacketA, sizeof(kSamplePacketA)); +} + +TEST_F(LocalDomainResolverTest, ResolveDomainAAAA) { + LocalDomainResolverImpl resolver( + "myhello.local", net::ADDRESS_FAMILY_IPV6, + base::Bind(&LocalDomainResolverTest::AddressCallback, + base::Unretained(this)), &mdns_client_); + + EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2); // Twice per query + + resolver.Start(); + + EXPECT_CALL(*this, AddressCallbackInternal(true, "", "a::1:2:3:4")); + + socket_factory_.SimulateReceive(kSamplePacketAAAA, sizeof(kSamplePacketAAAA)); +} + +TEST_F(LocalDomainResolverTest, ResolveDomainAnyOneAvailable) { + LocalDomainResolverImpl resolver( + "myhello.local", net::ADDRESS_FAMILY_UNSPECIFIED, + base::Bind(&LocalDomainResolverTest::AddressCallback, + base::Unretained(this)), &mdns_client_); + + EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); // Twice per query + + resolver.Start(); + + socket_factory_.SimulateReceive(kSamplePacketAAAA, sizeof(kSamplePacketAAAA)); + + EXPECT_CALL(*this, AddressCallbackInternal(true, "", "a::1:2:3:4")); + + RunFor(base::TimeDelta::FromMilliseconds(150)); +} + + +TEST_F(LocalDomainResolverTest, ResolveDomainAnyBothAvailable) { + LocalDomainResolverImpl resolver( + "myhello.local", net::ADDRESS_FAMILY_UNSPECIFIED, + base::Bind(&LocalDomainResolverTest::AddressCallback, + base::Unretained(this)), &mdns_client_); + + EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); // Twice per query + + resolver.Start(); + + EXPECT_CALL(*this, AddressCallbackInternal(true, "1.2.3.4", "a::1:2:3:4")); + + socket_factory_.SimulateReceive(kSamplePacketAAAA, sizeof(kSamplePacketAAAA)); + + socket_factory_.SimulateReceive(kSamplePacketA, sizeof(kSamplePacketA)); +} + +TEST_F(LocalDomainResolverTest, ResolveDomainNone) { + LocalDomainResolverImpl resolver( + "myhello.local", net::ADDRESS_FAMILY_UNSPECIFIED, + base::Bind(&LocalDomainResolverTest::AddressCallback, + base::Unretained(this)), &mdns_client_); + + EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); // Twice per query + + resolver.Start(); + + EXPECT_CALL(*this, AddressCallbackInternal(false, "", "")); + + RunFor(base::TimeDelta::FromSeconds(4)); +} + +} // namespace + +} // namespace local_discovery diff --git a/chrome/common/local_discovery/service_discovery_client_impl.cc b/chrome/common/local_discovery/service_discovery_client_impl.cc new file mode 100644 index 0000000..c86f726 --- /dev/null +++ b/chrome/common/local_discovery/service_discovery_client_impl.cc @@ -0,0 +1,584 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <utility> + +#include "base/logging.h" +#include "base/memory/singleton.h" +#include "base/message_loop/message_loop_proxy.h" +#include "base/stl_util.h" +#include "chrome/common/local_discovery/service_discovery_client_impl.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/record_rdata.h" + +namespace local_discovery { + +namespace { +// TODO(noamsml): Make this configurable through the LocalDomainResolver +// interface. +const int kLocalDomainSecondAddressTimeoutMs = 100; + +const int kInitialRequeryTimeSeconds = 1; +const int kMaxRequeryTimeSeconds = 2; // Time for last requery +} + +ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl( + net::MDnsClient* mdns_client) : mdns_client_(mdns_client) { +} + +ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() { +} + +scoped_ptr<ServiceWatcher> ServiceDiscoveryClientImpl::CreateServiceWatcher( + const std::string& service_type, + const ServiceWatcher::UpdatedCallback& callback) { + return scoped_ptr<ServiceWatcher>(new ServiceWatcherImpl( + service_type, callback, mdns_client_)); +} + +scoped_ptr<ServiceResolver> ServiceDiscoveryClientImpl::CreateServiceResolver( + const std::string& service_name, + const ServiceResolver::ResolveCompleteCallback& callback) { + return scoped_ptr<ServiceResolver>(new ServiceResolverImpl( + service_name, callback, mdns_client_)); +} + +scoped_ptr<LocalDomainResolver> +ServiceDiscoveryClientImpl::CreateLocalDomainResolver( + const std::string& domain, + net::AddressFamily address_family, + const LocalDomainResolver::IPAddressCallback& callback) { + return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverImpl( + domain, address_family, callback, mdns_client_)); +} + +ServiceWatcherImpl::ServiceWatcherImpl( + const std::string& service_type, + const ServiceWatcher::UpdatedCallback& callback, + net::MDnsClient* mdns_client) + : service_type_(service_type), callback_(callback), started_(false), + actively_refresh_services_(false), mdns_client_(mdns_client) { +} + +void ServiceWatcherImpl::Start() { + DCHECK(!started_); + listener_ = mdns_client_->CreateListener( + net::dns_protocol::kTypePTR, service_type_, this); + started_ = listener_->Start(); + if (started_) + ReadCachedServices(); +} + +ServiceWatcherImpl::~ServiceWatcherImpl() { +} + +void ServiceWatcherImpl::DiscoverNewServices(bool force_update) { + DCHECK(started_); + if (force_update) + services_.clear(); + SendQuery(kInitialRequeryTimeSeconds, force_update); +} + +void ServiceWatcherImpl::SetActivelyRefreshServices( + bool actively_refresh_services) { + DCHECK(started_); + actively_refresh_services_ = actively_refresh_services; + + for (ServiceListenersMap::iterator i = services_.begin(); + i != services_.end(); i++) { + i->second->SetActiveRefresh(actively_refresh_services); + } +} + +void ServiceWatcherImpl::ReadCachedServices() { + DCHECK(started_); + CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/, + &transaction_cache_); +} + +bool ServiceWatcherImpl::CreateTransaction( + bool network, bool cache, bool force_refresh, + scoped_ptr<net::MDnsTransaction>* transaction) { + int transaction_flags = 0; + if (network) + transaction_flags |= net::MDnsTransaction::QUERY_NETWORK; + + if (cache) + transaction_flags |= net::MDnsTransaction::QUERY_CACHE; + + // TODO(noamsml): Add flag for force_refresh when supported. + + if (transaction_flags) { + *transaction = mdns_client_->CreateTransaction( + net::dns_protocol::kTypePTR, service_type_, transaction_flags, + base::Bind(&ServiceWatcherImpl::OnTransactionResponse, + base::Unretained(this), transaction)); + return (*transaction)->Start(); + } + + return true; +} + +std::string ServiceWatcherImpl::GetServiceType() const { + return listener_->GetName(); +} + +void ServiceWatcherImpl::OnRecordUpdate( + net::MDnsListener::UpdateType update, + const net::RecordParsed* record) { + DCHECK(started_); + if (record->type() == net::dns_protocol::kTypePTR) { + DCHECK(record->name() == GetServiceType()); + const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>(); + + switch (update) { + case net::MDnsListener::RECORD_ADDED: + AddService(rdata->ptrdomain()); + break; + case net::MDnsListener::RECORD_CHANGED: + NOTREACHED(); + break; + case net::MDnsListener::RECORD_REMOVED: + RemovePTR(rdata->ptrdomain()); + break; + } + } else { + DCHECK(record->type() == net::dns_protocol::kTypeSRV || + record->type() == net::dns_protocol::kTypeTXT); + DCHECK(services_.find(record->name()) != services_.end()); + + if (record->type() == net::dns_protocol::kTypeSRV) { + if (update == net::MDnsListener::RECORD_REMOVED) { + RemoveSRV(record->name()); + } else if (update == net::MDnsListener::RECORD_ADDED) { + AddSRV(record->name()); + } + } + + // If this is the first time we see an SRV record, do not send + // an UPDATE_CHANGED. + if (record->type() != net::dns_protocol::kTypeSRV || + update != net::MDnsListener::RECORD_ADDED) { + DeferUpdate(UPDATE_CHANGED, record->name()); + } + } +} + +void ServiceWatcherImpl::OnCachePurged() { + // Not yet implemented. +} + +void ServiceWatcherImpl::OnTransactionResponse( + scoped_ptr<net::MDnsTransaction>* transaction, + net::MDnsTransaction::Result result, + const net::RecordParsed* record) { + DCHECK(started_); + if (result == net::MDnsTransaction::RESULT_RECORD) { + const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>(); + DCHECK(rdata); + AddService(rdata->ptrdomain()); + } else if (result == net::MDnsTransaction::RESULT_DONE) { + transaction->reset(); + } + + // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC + // record for PTR records on any name. +} + +ServiceWatcherImpl::ServiceListeners::ServiceListeners( + const std::string& service_name, + ServiceWatcherImpl* watcher, + net::MDnsClient* mdns_client) + : service_name_(service_name), mdns_client_(mdns_client), + update_pending_(false), has_ptr_(true), has_srv_(false) { + srv_listener_ = mdns_client->CreateListener( + net::dns_protocol::kTypeSRV, service_name, watcher); + txt_listener_ = mdns_client->CreateListener( + net::dns_protocol::kTypeTXT, service_name, watcher); +} + +ServiceWatcherImpl::ServiceListeners::~ServiceListeners() { +} + +bool ServiceWatcherImpl::ServiceListeners::Start() { + if (!srv_listener_->Start()) + return false; + return txt_listener_->Start(); +} + +void ServiceWatcherImpl::ServiceListeners::SetActiveRefresh( + bool active_refresh) { + srv_listener_->SetActiveRefresh(active_refresh); + + if (active_refresh && !has_srv_) { + DCHECK(has_ptr_); + srv_transaction_ = mdns_client_->CreateTransaction( + net::dns_protocol::kTypeSRV, service_name_, + net::MDnsTransaction::SINGLE_RESULT | + net::MDnsTransaction::QUERY_CACHE | net::MDnsTransaction::QUERY_NETWORK, + base::Bind(&ServiceWatcherImpl::ServiceListeners::OnSRVRecord, + base::Unretained(this))); + srv_transaction_->Start(); + } else if (!active_refresh) { + srv_transaction_.reset(); + } +} + +void ServiceWatcherImpl::ServiceListeners::OnSRVRecord( + net::MDnsTransaction::Result result, + const net::RecordParsed* record) { + set_has_srv(record != NULL); +} + +void ServiceWatcherImpl::ServiceListeners::set_has_srv(bool has_srv) { + has_srv_ = has_srv; + + srv_transaction_.reset(); +} + +void ServiceWatcherImpl::AddService(const std::string& service) { + DCHECK(started_); + std::pair<ServiceListenersMap::iterator, bool> found = services_.insert( + make_pair(service, linked_ptr<ServiceListeners>(NULL))); + + if (found.second) { // Newly inserted. + found.first->second = linked_ptr<ServiceListeners>( + new ServiceListeners(service, this, mdns_client_)); + bool success = found.first->second->Start(); + found.first->second->SetActiveRefresh(actively_refresh_services_); + DeferUpdate(UPDATE_ADDED, service); + + DCHECK(success); + } + + found.first->second->set_has_ptr(true); +} + +void ServiceWatcherImpl::AddSRV(const std::string& service) { + DCHECK(started_); + + ServiceListenersMap::iterator found = services_.find(service); + if (found != services_.end()) { + found->second->set_has_srv(true); + } +} + +void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type, + const std::string& service_name) { + ServiceListenersMap::iterator found = services_.find(service_name); + + if (found != services_.end() && !found->second->update_pending()) { + found->second->set_update_pending(true); + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate, AsWeakPtr(), + update_type, service_name)); + } +} + +void ServiceWatcherImpl::DeliverDeferredUpdate( + ServiceWatcher::UpdateType update_type, const std::string& service_name) { + ServiceListenersMap::iterator found = services_.find(service_name); + + if (found != services_.end()) { + found->second->set_update_pending(false); + if (!callback_.is_null()) + callback_.Run(update_type, service_name); + } +} + +void ServiceWatcherImpl::RemovePTR(const std::string& service) { + DCHECK(started_); + + ServiceListenersMap::iterator found = services_.find(service); + if (found != services_.end()) { + found->second->set_has_ptr(false); + + if (!found->second->has_ptr_or_srv()) { + services_.erase(found); + if (!callback_.is_null()) + callback_.Run(UPDATE_REMOVED, service); + } + } +} + +void ServiceWatcherImpl::RemoveSRV(const std::string& service) { + DCHECK(started_); + + ServiceListenersMap::iterator found = services_.find(service); + if (found != services_.end()) { + found->second->set_has_srv(false); + + if (!found->second->has_ptr_or_srv()) { + services_.erase(found); + if (!callback_.is_null()) + callback_.Run(UPDATE_REMOVED, service); + } + } +} + +void ServiceWatcherImpl::OnNsecRecord(const std::string& name, + unsigned rrtype) { + // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR + // on any name. +} + +void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds) { + if (timeout_seconds <= kMaxRequeryTimeSeconds) { + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(&ServiceWatcherImpl::SendQuery, + AsWeakPtr(), + timeout_seconds * 2 /*next_timeout_seconds*/, + false /*force_update*/), + base::TimeDelta::FromSeconds(timeout_seconds)); + } +} + +void ServiceWatcherImpl::SendQuery(int next_timeout_seconds, + bool force_update) { + CreateTransaction(true /*network*/, false /*cache*/, force_update, + &transaction_network_); + ScheduleQuery(next_timeout_seconds); +} + +ServiceResolverImpl::ServiceResolverImpl( + const std::string& service_name, + const ResolveCompleteCallback& callback, + net::MDnsClient* mdns_client) + : service_name_(service_name), callback_(callback), + metadata_resolved_(false), address_resolved_(false), + mdns_client_(mdns_client) { +} + +void ServiceResolverImpl::StartResolving() { + address_resolved_ = false; + metadata_resolved_ = false; + service_staging_ = ServiceDescription(); + service_staging_.service_name = service_name_; + + if (!CreateTxtTransaction() || !CreateSrvTransaction()) { + ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT); + } +} + +ServiceResolverImpl::~ServiceResolverImpl() { +} + +bool ServiceResolverImpl::CreateTxtTransaction() { + txt_transaction_ = mdns_client_->CreateTransaction( + net::dns_protocol::kTypeTXT, service_name_, + net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE | + net::MDnsTransaction::QUERY_NETWORK, + base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse, + AsWeakPtr())); + return txt_transaction_->Start(); +} + +// TODO(noamsml): quick-resolve for AAAA records. Since A records tend to be in +void ServiceResolverImpl::CreateATransaction() { + a_transaction_ = mdns_client_->CreateTransaction( + net::dns_protocol::kTypeA, + service_staging_.address.host(), + net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE, + base::Bind(&ServiceResolverImpl::ARecordTransactionResponse, + AsWeakPtr())); + a_transaction_->Start(); +} + +bool ServiceResolverImpl::CreateSrvTransaction() { + srv_transaction_ = mdns_client_->CreateTransaction( + net::dns_protocol::kTypeSRV, service_name_, + net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE | + net::MDnsTransaction::QUERY_NETWORK, + base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse, + AsWeakPtr())); + return srv_transaction_->Start(); +} + +std::string ServiceResolverImpl::GetName() const { + return service_name_; +} + +void ServiceResolverImpl::SrvRecordTransactionResponse( + net::MDnsTransaction::Result status, const net::RecordParsed* record) { + srv_transaction_.reset(); + if (status == net::MDnsTransaction::RESULT_RECORD) { + DCHECK(record); + service_staging_.address = RecordToAddress(record); + service_staging_.last_seen = record->time_created(); + CreateATransaction(); + } else { + ServiceNotFound(MDnsStatusToRequestStatus(status)); + } +} + +void ServiceResolverImpl::TxtRecordTransactionResponse( + net::MDnsTransaction::Result status, const net::RecordParsed* record) { + txt_transaction_.reset(); + if (status == net::MDnsTransaction::RESULT_RECORD) { + DCHECK(record); + service_staging_.metadata = RecordToMetadata(record); + } else { + service_staging_.metadata = std::vector<std::string>(); + } + + metadata_resolved_ = true; + AlertCallbackIfReady(); +} + +void ServiceResolverImpl::ARecordTransactionResponse( + net::MDnsTransaction::Result status, const net::RecordParsed* record) { + a_transaction_.reset(); + + if (status == net::MDnsTransaction::RESULT_RECORD) { + DCHECK(record); + service_staging_.ip_address = RecordToIPAddress(record); + } else { + service_staging_.ip_address = net::IPAddressNumber(); + } + + address_resolved_ = true; + AlertCallbackIfReady(); +} + +void ServiceResolverImpl::AlertCallbackIfReady() { + if (metadata_resolved_ && address_resolved_) { + txt_transaction_.reset(); + srv_transaction_.reset(); + a_transaction_.reset(); + if (!callback_.is_null()) + callback_.Run(STATUS_SUCCESS, service_staging_); + } +} + +void ServiceResolverImpl::ServiceNotFound( + ServiceResolver::RequestStatus status) { + txt_transaction_.reset(); + srv_transaction_.reset(); + a_transaction_.reset(); + if (!callback_.is_null()) + callback_.Run(status, ServiceDescription()); +} + +ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus( + net::MDnsTransaction::Result status) const { + switch (status) { + case net::MDnsTransaction::RESULT_RECORD: + return ServiceResolver::STATUS_SUCCESS; + case net::MDnsTransaction::RESULT_NO_RESULTS: + return ServiceResolver::STATUS_REQUEST_TIMEOUT; + case net::MDnsTransaction::RESULT_NSEC: + return ServiceResolver::STATUS_KNOWN_NONEXISTENT; + case net::MDnsTransaction::RESULT_DONE: // Pass through. + default: + NOTREACHED(); + return ServiceResolver::STATUS_REQUEST_TIMEOUT; + } +} + +const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata( + const net::RecordParsed* record) const { + DCHECK(record->type() == net::dns_protocol::kTypeTXT); + const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>(); + DCHECK(txt_rdata); + return txt_rdata->texts(); +} + +net::HostPortPair ServiceResolverImpl::RecordToAddress( + const net::RecordParsed* record) const { + DCHECK(record->type() == net::dns_protocol::kTypeSRV); + const net::SrvRecordRdata* srv_rdata = record->rdata<net::SrvRecordRdata>(); + DCHECK(srv_rdata); + return net::HostPortPair(srv_rdata->target(), srv_rdata->port()); +} + +const net::IPAddressNumber& ServiceResolverImpl::RecordToIPAddress( + const net::RecordParsed* record) const { + DCHECK(record->type() == net::dns_protocol::kTypeA); + const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>(); + DCHECK(a_rdata); + return a_rdata->address(); +} + +LocalDomainResolverImpl::LocalDomainResolverImpl( + const std::string& domain, + net::AddressFamily address_family, + const IPAddressCallback& callback, + net::MDnsClient* mdns_client) + : domain_(domain), address_family_(address_family), callback_(callback), + transactions_finished_(0), mdns_client_(mdns_client) { +} + +LocalDomainResolverImpl::~LocalDomainResolverImpl() { + timeout_callback_.Cancel(); +} + +void LocalDomainResolverImpl::Start() { + if (address_family_ == net::ADDRESS_FAMILY_IPV4 || + address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) { + transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA); + transaction_a_->Start(); + } + + if (address_family_ == net::ADDRESS_FAMILY_IPV6 || + address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) { + transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA); + transaction_aaaa_->Start(); + } +} + +scoped_ptr<net::MDnsTransaction> LocalDomainResolverImpl::CreateTransaction( + uint16 type) { + return mdns_client_->CreateTransaction( + type, domain_, net::MDnsTransaction::SINGLE_RESULT | + net::MDnsTransaction::QUERY_CACHE | + net::MDnsTransaction::QUERY_NETWORK, + base::Bind(&LocalDomainResolverImpl::OnTransactionComplete, + base::Unretained(this))); +} + +void LocalDomainResolverImpl::OnTransactionComplete( + net::MDnsTransaction::Result result, const net::RecordParsed* record) { + transactions_finished_++; + + if (result == net::MDnsTransaction::RESULT_RECORD) { + if (record->type() == net::dns_protocol::kTypeA) { + const net::ARecordRdata* rdata = record->rdata<net::ARecordRdata>(); + address_ipv4_ = rdata->address(); + } else { + DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type()); + const net::AAAARecordRdata* rdata = record->rdata<net::AAAARecordRdata>(); + address_ipv6_ = rdata->address(); + } + } + + if (transactions_finished_ == 1 && + address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) { + timeout_callback_.Reset(base::Bind( + &LocalDomainResolverImpl::SendResolvedAddresses, + base::Unretained(this))); + + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + timeout_callback_.callback(), + base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs)); + } else if (transactions_finished_ == 2 + || address_family_ != net::ADDRESS_FAMILY_UNSPECIFIED) { + SendResolvedAddresses(); + } +} + +bool LocalDomainResolverImpl::IsSuccess() { + return !address_ipv4_.empty() || !address_ipv6_.empty(); +} + +void LocalDomainResolverImpl::SendResolvedAddresses() { + transaction_a_.reset(); + transaction_aaaa_.reset(); + timeout_callback_.Cancel(); + callback_.Run(IsSuccess(), address_ipv4_, address_ipv6_); +} + +} // namespace local_discovery diff --git a/chrome/common/local_discovery/service_discovery_client_impl.h b/chrome/common/local_discovery/service_discovery_client_impl.h new file mode 100644 index 0000000..29e42c7 --- /dev/null +++ b/chrome/common/local_discovery/service_discovery_client_impl.h @@ -0,0 +1,273 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CHROME_COMMON_LOCAL_DISCOVERY_SERVICE_DISCOVERY_CLIENT_IMPL_H_ +#define CHROME_COMMON_LOCAL_DISCOVERY_SERVICE_DISCOVERY_CLIENT_IMPL_H_ + +#include <map> +#include <string> +#include <vector> + +#include "base/callback.h" +#include "base/cancelable_callback.h" +#include "base/memory/linked_ptr.h" +#include "base/memory/weak_ptr.h" +#include "base/message_loop/message_loop.h" +#include "chrome/common/local_discovery/service_discovery_client.h" +#include "net/dns/mdns_client.h" + +namespace local_discovery { + +class ServiceDiscoveryClientImpl : public ServiceDiscoveryClient { + public: + // |mdns_client| must outlive the Service Discovery Client. + explicit ServiceDiscoveryClientImpl(net::MDnsClient* mdns_client); + virtual ~ServiceDiscoveryClientImpl(); + + // ServiceDiscoveryClient implementation: + virtual scoped_ptr<ServiceWatcher> CreateServiceWatcher( + const std::string& service_type, + const ServiceWatcher::UpdatedCallback& callback) OVERRIDE; + + virtual scoped_ptr<ServiceResolver> CreateServiceResolver( + const std::string& service_name, + const ServiceResolver::ResolveCompleteCallback& callback) OVERRIDE; + + virtual scoped_ptr<LocalDomainResolver> CreateLocalDomainResolver( + const std::string& domain, + net::AddressFamily address_family, + const LocalDomainResolver::IPAddressCallback& callback) OVERRIDE; + + private: + net::MDnsClient* mdns_client_; + + DISALLOW_COPY_AND_ASSIGN(ServiceDiscoveryClientImpl); +}; + +class ServiceWatcherImpl : public ServiceWatcher, + public net::MDnsListener::Delegate, + public base::SupportsWeakPtr<ServiceWatcherImpl> { + public: + ServiceWatcherImpl(const std::string& service_type, + const ServiceWatcher::UpdatedCallback& callback, + net::MDnsClient* mdns_client); + // Listening will automatically stop when the destructor is called. + virtual ~ServiceWatcherImpl(); + + // ServiceWatcher implementation: + virtual void Start() OVERRIDE; + + virtual void DiscoverNewServices(bool force_update) OVERRIDE; + + virtual void SetActivelyRefreshServices( + bool actively_refresh_services) OVERRIDE; + + virtual std::string GetServiceType() const OVERRIDE; + + virtual void OnRecordUpdate(net::MDnsListener::UpdateType update, + const net::RecordParsed* record) OVERRIDE; + + virtual void OnNsecRecord(const std::string& name, unsigned rrtype) OVERRIDE; + + virtual void OnCachePurged() OVERRIDE; + + virtual void OnTransactionResponse( + scoped_ptr<net::MDnsTransaction>* transaction, + net::MDnsTransaction::Result result, + const net::RecordParsed* record); + + private: + struct ServiceListeners { + ServiceListeners(const std::string& service_name, + ServiceWatcherImpl* watcher, + net::MDnsClient* mdns_client); + ~ServiceListeners(); + bool Start(); + void SetActiveRefresh(bool auto_update); + + void set_update_pending(bool update_pending) { + update_pending_ = update_pending; + } + + bool update_pending() { return update_pending_; } + + void set_has_ptr(bool has_ptr) { + has_ptr_ = has_ptr; + } + + void set_has_srv(bool has_srv); + + bool has_ptr_or_srv() { return has_ptr_ || has_srv_; } + + private: + void OnSRVRecord(net::MDnsTransaction::Result result, + const net::RecordParsed* record); + + void DoQuerySRV(); + + scoped_ptr<net::MDnsListener> srv_listener_; + scoped_ptr<net::MDnsListener> txt_listener_; + scoped_ptr<net::MDnsTransaction> srv_transaction_; + + std::string service_name_; + net::MDnsClient* mdns_client_; + bool update_pending_; + + bool has_ptr_; + bool has_srv_; + }; + + typedef std::map<std::string, linked_ptr<ServiceListeners> > + ServiceListenersMap; + + void ReadCachedServices(); + void AddService(const std::string& service); + void RemovePTR(const std::string& service); + void RemoveSRV(const std::string& service); + void AddSRV(const std::string& service); + bool CreateTransaction(bool active, bool alert_existing_services, + bool force_refresh, + scoped_ptr<net::MDnsTransaction>* transaction); + + void DeferUpdate(ServiceWatcher::UpdateType update_type, + const std::string& service_name); + void DeliverDeferredUpdate(ServiceWatcher::UpdateType update_type, + const std::string& service_name); + + void ScheduleQuery(int timeout_seconds); + + void SendQuery(int next_timeout_seconds, bool force_update); + + std::string service_type_; + ServiceListenersMap services_; + scoped_ptr<net::MDnsTransaction> transaction_network_; + scoped_ptr<net::MDnsTransaction> transaction_cache_; + scoped_ptr<net::MDnsListener> listener_; + + ServiceWatcher::UpdatedCallback callback_; + bool started_; + bool actively_refresh_services_; + + net::MDnsClient* mdns_client_; + + DISALLOW_COPY_AND_ASSIGN(ServiceWatcherImpl); +}; + +class ServiceResolverImpl + : public ServiceResolver, + public base::SupportsWeakPtr<ServiceResolverImpl> { + public: + ServiceResolverImpl(const std::string& service_name, + const ServiceResolver::ResolveCompleteCallback& callback, + net::MDnsClient* mdns_client); + + virtual ~ServiceResolverImpl(); + + // ServiceResolver implementation: + virtual void StartResolving() OVERRIDE; + + virtual std::string GetName() const OVERRIDE; + + private: + // Respond to transaction finishing for SRV records. + void SrvRecordTransactionResponse(net::MDnsTransaction::Result status, + const net::RecordParsed* record); + + // Respond to transaction finishing for TXT records. + void TxtRecordTransactionResponse(net::MDnsTransaction::Result status, + const net::RecordParsed* record); + + // Respond to transaction finishing for A records. + void ARecordTransactionResponse(net::MDnsTransaction::Result status, + const net::RecordParsed* record); + + void AlertCallbackIfReady(); + + void ServiceNotFound(RequestStatus status); + + // Convert a TXT record to a vector of strings (metadata). + const std::vector<std::string>& RecordToMetadata( + const net::RecordParsed* record) const; + + // Convert an SRV record to a host and port pair. + net::HostPortPair RecordToAddress( + const net::RecordParsed* record) const; + + // Convert an A record to an IP address. + const net::IPAddressNumber& RecordToIPAddress( + const net::RecordParsed* record) const; + + // Convert an MDns status to a service discovery status. + RequestStatus MDnsStatusToRequestStatus( + net::MDnsTransaction::Result status) const; + + bool CreateTxtTransaction(); + bool CreateSrvTransaction(); + void CreateATransaction(); + + std::string service_name_; + ResolveCompleteCallback callback_; + + bool has_resolved_; + + bool metadata_resolved_; + bool address_resolved_; + + scoped_ptr<net::MDnsTransaction> txt_transaction_; + scoped_ptr<net::MDnsTransaction> srv_transaction_; + scoped_ptr<net::MDnsTransaction> a_transaction_; + + ServiceDescription service_staging_; + + net::MDnsClient* mdns_client_; + + DISALLOW_COPY_AND_ASSIGN(ServiceResolverImpl); +}; + +class LocalDomainResolverImpl : public LocalDomainResolver { + public: + LocalDomainResolverImpl(const std::string& domain, + net::AddressFamily address_family, + const IPAddressCallback& callback, + net::MDnsClient* mdns_client); + virtual ~LocalDomainResolverImpl(); + + virtual void Start() OVERRIDE; + + const std::string& domain() { return domain_; } + + private: + void OnTransactionComplete( + net::MDnsTransaction::Result result, + const net::RecordParsed* record); + + scoped_ptr<net::MDnsTransaction> CreateTransaction(uint16 type); + + bool IsSuccess(); + + void SendResolvedAddresses(); + + std::string domain_; + net::AddressFamily address_family_; + IPAddressCallback callback_; + + scoped_ptr<net::MDnsTransaction> transaction_a_; + scoped_ptr<net::MDnsTransaction> transaction_aaaa_; + + int transactions_finished_; + + net::MDnsClient* mdns_client_; + + net::IPAddressNumber address_ipv4_; + net::IPAddressNumber address_ipv6_; + + base::CancelableCallback<void()> timeout_callback_; + + DISALLOW_COPY_AND_ASSIGN(LocalDomainResolverImpl); +}; + + +} // namespace local_discovery + +#endif // CHROME_COMMON_LOCAL_DISCOVERY_SERVICE_DISCOVERY_CLIENT_IMPL_H_ diff --git a/chrome/common/local_discovery/service_discovery_client_unittest.cc b/chrome/common/local_discovery/service_discovery_client_unittest.cc new file mode 100644 index 0000000..5033fbdd --- /dev/null +++ b/chrome/common/local_discovery/service_discovery_client_unittest.cc @@ -0,0 +1,518 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/memory/weak_ptr.h" +#include "base/run_loop.h" +#include "chrome/common/local_discovery/service_discovery_client_impl.h" +#include "net/base/net_errors.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/mdns_client_impl.h" +#include "net/dns/mock_mdns_socket_factory.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +using ::testing::_; +using ::testing::Invoke; +using ::testing::StrictMock; +using ::testing::NiceMock; +using ::testing::Mock; +using ::testing::SaveArg; +using ::testing::SetArgPointee; +using ::testing::Return; +using ::testing::Exactly; + +namespace local_discovery { + +namespace { + +const uint8 kSamplePacketPTR[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x01, // 1 RR (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, // TTL (4 bytes) is 1 second. + 0x00, 0x01, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x0c +}; + +const uint8 kSamplePacketSRV[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x01, // 1 RR (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + 0x05, 'h', 'e', 'l', 'l', 'o', + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x21, // TYPE is SRV. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, // TTL (4 bytes) is 1 second. + 0x00, 0x01, + 0x00, 0x15, // RDLENGTH is 21 bytes. + 0x00, 0x00, + 0x00, 0x00, + 0x22, 0xb8, // port 8888 + 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, +}; + +const uint8 kSamplePacketTXT[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x01, // 1 RR (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + 0x05, 'h', 'e', 'l', 'l', 'o', + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x10, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x00, 0x01, + 0x00, 0x06, // RDLENGTH is 21 bytes. + 0x05, 'h', 'e', 'l', 'l', 'o' +}; + +const uint8 kSamplePacketSRVA[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x02, // 2 RR (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + 0x05, 'h', 'e', 'l', 'l', 'o', + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x21, // TYPE is SRV. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, // TTL (4 bytes) is 16 seconds. + 0x00, 0x10, + 0x00, 0x15, // RDLENGTH is 21 bytes. + 0x00, 0x00, + 0x00, 0x00, + 0x22, 0xb8, // port 8888 + 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + + 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, // TTL (4 bytes) is 16 seconds. + 0x00, 0x10, + 0x00, 0x04, // RDLENGTH is 4 bytes. + 0x01, 0x02, + 0x03, 0x04, +}; + +const uint8 kSamplePacketPTR2[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x02, // 2 RR (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x02, 0x00, // TTL (4 bytes) is 1 second. + 0x00, 0x01, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'g', 'd', 'b', 'y', 'e', + 0xc0, 0x0c, + + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x02, 0x00, // TTL (4 bytes) is 1 second. + 0x00, 0x01, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x0c +}; + +const uint8 kSamplePacketQuerySRV[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x00, 0x00, // No flags. + 0x00, 0x01, // One question. + 0x00, 0x00, // 0 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Question + 0x05, 'h', 'e', 'l', 'l', 'o', + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x21, // TYPE is SRV. + 0x00, 0x01, // CLASS is IN. +}; + + +class MockServiceWatcherClient { + public: + MOCK_METHOD2(OnServiceUpdated, + void(ServiceWatcher::UpdateType, const std::string&)); + + ServiceWatcher::UpdatedCallback GetCallback() { + return base::Bind(&MockServiceWatcherClient::OnServiceUpdated, + base::Unretained(this)); + } +}; + +class ServiceDiscoveryTest : public ::testing::Test { + public: + ServiceDiscoveryTest() + : service_discovery_client_(&mdns_client_) { + mdns_client_.StartListening(&socket_factory_); + } + + virtual ~ServiceDiscoveryTest() { + } + + protected: + void RunFor(base::TimeDelta time_period) { + base::CancelableCallback<void()> callback(base::Bind( + &ServiceDiscoveryTest::Stop, base::Unretained(this))); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, callback.callback(), time_period); + + base::MessageLoop::current()->Run(); + callback.Cancel(); + } + + void Stop() { + base::MessageLoop::current()->Quit(); + } + + net::MockMDnsSocketFactory socket_factory_; + net::MDnsClientImpl mdns_client_; + ServiceDiscoveryClientImpl service_discovery_client_; + base::MessageLoop loop_; +}; + +TEST_F(ServiceDiscoveryTest, AddRemoveService) { + StrictMock<MockServiceWatcherClient> delegate; + + scoped_ptr<ServiceWatcher> watcher( + service_discovery_client_.CreateServiceWatcher( + "_privet._tcp.local", delegate.GetCallback())); + + watcher->Start(); + + EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, + "hello._privet._tcp.local")) + .Times(Exactly(1)); + + socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); + + EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED, + "hello._privet._tcp.local")) + .Times(Exactly(1)); + + RunFor(base::TimeDelta::FromSeconds(2)); +}; + +TEST_F(ServiceDiscoveryTest, DiscoverNewServices) { + StrictMock<MockServiceWatcherClient> delegate; + + scoped_ptr<ServiceWatcher> watcher( + service_discovery_client_.CreateServiceWatcher( + "_privet._tcp.local", delegate.GetCallback())); + + watcher->Start(); + + EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2); + + watcher->DiscoverNewServices(false); + + EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2); + + RunFor(base::TimeDelta::FromSeconds(2)); +}; + +TEST_F(ServiceDiscoveryTest, ReadCachedServices) { + socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); + + StrictMock<MockServiceWatcherClient> delegate; + + scoped_ptr<ServiceWatcher> watcher( + service_discovery_client_.CreateServiceWatcher( + "_privet._tcp.local", delegate.GetCallback())); + + watcher->Start(); + + EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, + "hello._privet._tcp.local")) + .Times(Exactly(1)); + + base::MessageLoop::current()->RunUntilIdle(); +}; + + +TEST_F(ServiceDiscoveryTest, ReadCachedServicesMultiple) { + socket_factory_.SimulateReceive(kSamplePacketPTR2, sizeof(kSamplePacketPTR2)); + + StrictMock<MockServiceWatcherClient> delegate; + scoped_ptr<ServiceWatcher> watcher = + service_discovery_client_.CreateServiceWatcher( + "_privet._tcp.local", delegate.GetCallback()); + + watcher->Start(); + + EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, + "hello._privet._tcp.local")) + .Times(Exactly(1)); + + EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, + "gdbye._privet._tcp.local")) + .Times(Exactly(1)); + + base::MessageLoop::current()->RunUntilIdle(); +}; + + +TEST_F(ServiceDiscoveryTest, OnServiceChanged) { + StrictMock<MockServiceWatcherClient> delegate; + scoped_ptr<ServiceWatcher> watcher( + service_discovery_client_.CreateServiceWatcher( + "_privet._tcp.local", delegate.GetCallback())); + + watcher->Start(); + + EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, + "hello._privet._tcp.local")) + .Times(Exactly(1)); + + socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); + + base::MessageLoop::current()->RunUntilIdle(); + + EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED, + "hello._privet._tcp.local")) + .Times(Exactly(1)); + + socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV)); + + socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT)); + + base::MessageLoop::current()->RunUntilIdle(); +}; + +TEST_F(ServiceDiscoveryTest, SinglePacket) { + StrictMock<MockServiceWatcherClient> delegate; + scoped_ptr<ServiceWatcher> watcher( + service_discovery_client_.CreateServiceWatcher( + "_privet._tcp.local", delegate.GetCallback())); + + watcher->Start(); + + EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, + "hello._privet._tcp.local")) + .Times(Exactly(1)); + + socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); + + // Reset the "already updated" flag. + base::MessageLoop::current()->RunUntilIdle(); + + EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED, + "hello._privet._tcp.local")) + .Times(Exactly(1)); + + socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV)); + + socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT)); + + base::MessageLoop::current()->RunUntilIdle(); +}; + +TEST_F(ServiceDiscoveryTest, ActivelyRefreshServices) { + StrictMock<MockServiceWatcherClient> delegate; + scoped_ptr<ServiceWatcher> watcher( + service_discovery_client_.CreateServiceWatcher( + "_privet._tcp.local", delegate.GetCallback())); + + watcher->Start(); + watcher->SetActivelyRefreshServices(true); + + EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, + "hello._privet._tcp.local")) + .Times(Exactly(1)); + + std::string query_packet = std::string((const char*)(kSamplePacketQuerySRV), + sizeof(kSamplePacketQuerySRV)); + + EXPECT_CALL(socket_factory_, OnSendTo(query_packet)) + .Times(2); + + socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); + + base::MessageLoop::current()->RunUntilIdle(); + + socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV)); + + EXPECT_CALL(socket_factory_, OnSendTo(query_packet)) + .Times(4); // IPv4 and IPv6 at 85% and 95% + + EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED, + "hello._privet._tcp.local")) + .Times(Exactly(1)); + + RunFor(base::TimeDelta::FromSeconds(2)); + + base::MessageLoop::current()->RunUntilIdle(); +}; + + +class ServiceResolverTest : public ServiceDiscoveryTest { + public: + ServiceResolverTest() { + metadata_expected_.push_back("hello"); + address_expected_ = net::HostPortPair("myhello.local", 8888); + ip_address_expected_.push_back(1); + ip_address_expected_.push_back(2); + ip_address_expected_.push_back(3); + ip_address_expected_.push_back(4); + } + + ~ServiceResolverTest() { + } + + void SetUp() { + resolver_ = service_discovery_client_.CreateServiceResolver( + "hello._privet._tcp.local", + base::Bind(&ServiceResolverTest::OnFinishedResolving, + base::Unretained(this))); + } + + void OnFinishedResolving(ServiceResolver::RequestStatus request_status, + const ServiceDescription& service_description) { + OnFinishedResolvingInternal(request_status, + service_description.address.ToString(), + service_description.metadata, + service_description.ip_address); + } + + MOCK_METHOD4(OnFinishedResolvingInternal, + void(ServiceResolver::RequestStatus, + const std::string&, + const std::vector<std::string>&, + const net::IPAddressNumber&)); + + protected: + scoped_ptr<ServiceResolver> resolver_; + net::IPAddressNumber ip_address_; + net::HostPortPair address_expected_; + std::vector<std::string> metadata_expected_; + net::IPAddressNumber ip_address_expected_; +}; + +TEST_F(ServiceResolverTest, TxtAndSrvButNoA) { + EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); + + resolver_->StartResolving(); + + socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV)); + + base::MessageLoop::current()->RunUntilIdle(); + + EXPECT_CALL(*this, + OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS, + address_expected_.ToString(), + metadata_expected_, + net::IPAddressNumber())); + + socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT)); +}; + +TEST_F(ServiceResolverTest, TxtSrvAndA) { + EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); + + resolver_->StartResolving(); + + EXPECT_CALL(*this, + OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS, + address_expected_.ToString(), + metadata_expected_, + ip_address_expected_)); + + socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT)); + + socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA)); +}; + +TEST_F(ServiceResolverTest, JustSrv) { + EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); + + resolver_->StartResolving(); + + EXPECT_CALL(*this, + OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS, + address_expected_.ToString(), + std::vector<std::string>(), + ip_address_expected_)); + + socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA)); + + // TODO(noamsml): When NSEC record support is added, change this to use an + // NSEC record. + RunFor(base::TimeDelta::FromSeconds(4)); +}; + +TEST_F(ServiceResolverTest, WithNothing) { + EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); + + resolver_->StartResolving(); + + EXPECT_CALL(*this, OnFinishedResolvingInternal( + ServiceResolver::STATUS_REQUEST_TIMEOUT, _, _, _)); + + // TODO(noamsml): When NSEC record support is added, change this to use an + // NSEC record. + RunFor(base::TimeDelta::FromSeconds(4)); +}; + +} // namespace + +} // namespace local_discovery |
