// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "chrome/utility/local_discovery/service_discovery_message_handler.h" #include #include "base/lazy_instance.h" #include "chrome/common/local_discovery/local_discovery_messages.h" #include "chrome/common/local_discovery/service_discovery_client_impl.h" #include "content/public/utility/utility_thread.h" #include "net/socket/socket_descriptor.h" #include "net/udp/datagram_server_socket.h" namespace local_discovery { namespace { void ClosePlatformSocket(net::SocketDescriptor socket); // Sets socket factory used by |net::CreatePlatformSocket|. Implemetation // keeps single socket that will be returned to the first call to // |net::CreatePlatformSocket| during object lifetime. class ScopedSocketFactory : public net::PlatformSocketFactory { public: explicit ScopedSocketFactory(net::SocketDescriptor socket) : socket_(socket) { net::PlatformSocketFactory::SetInstance(this); } virtual ~ScopedSocketFactory() { net::PlatformSocketFactory::SetInstance(NULL); ClosePlatformSocket(socket_); socket_ = net::kInvalidSocket; } virtual net::SocketDescriptor CreateSocket(int family, int type, int protocol) OVERRIDE { DCHECK_EQ(type, SOCK_DGRAM); DCHECK(family == AF_INET || family == AF_INET6); net::SocketDescriptor result = net::kInvalidSocket; std::swap(result, socket_); return result; } private: net::SocketDescriptor socket_; DISALLOW_COPY_AND_ASSIGN(ScopedSocketFactory); }; struct SocketInfo { SocketInfo(net::SocketDescriptor socket, net::AddressFamily address_family, uint32 interface_index) : socket(socket), address_family(address_family), interface_index(interface_index) { } net::SocketDescriptor socket; net::AddressFamily address_family; uint32 interface_index; }; // Returns list of sockets preallocated before. class PreCreatedMDnsSocketFactory : public net::MDnsSocketFactory { public: PreCreatedMDnsSocketFactory() {} virtual ~PreCreatedMDnsSocketFactory() { // Not empty if process exits too fast, before starting mDns code. If // happened, destructors may crash accessing destroyed global objects. sockets_.weak_clear(); } // net::MDnsSocketFactory implementation: virtual void CreateSockets( ScopedVector* sockets) OVERRIDE { sockets->swap(sockets_); Reset(); } void AddSocket(const SocketInfo& socket_info) { // Takes ownership of socket_info.socket; ScopedSocketFactory platform_factory(socket_info.socket); scoped_ptr socket( net::CreateAndBindMDnsSocket(socket_info.address_family, socket_info.interface_index)); if (socket) { socket->DetachFromThread(); sockets_.push_back(socket.release()); } } void Reset() { sockets_.clear(); } private: ScopedVector sockets_; DISALLOW_COPY_AND_ASSIGN(PreCreatedMDnsSocketFactory); }; base::LazyInstance g_local_discovery_socket_factory = LAZY_INSTANCE_INITIALIZER; #if defined(OS_WIN) void ClosePlatformSocket(net::SocketDescriptor socket) { ::closesocket(socket); } void StaticInitializeSocketFactory() { net::InterfaceIndexFamilyList interfaces(net::GetMDnsInterfacesToBind()); for (size_t i = 0; i < interfaces.size(); ++i) { DCHECK(interfaces[i].second == net::ADDRESS_FAMILY_IPV4 || interfaces[i].second == net::ADDRESS_FAMILY_IPV6); net::SocketDescriptor descriptor = net::CreatePlatformSocket( net::ConvertAddressFamily(interfaces[i].second), SOCK_DGRAM, IPPROTO_UDP); g_local_discovery_socket_factory.Get().AddSocket( SocketInfo(descriptor, interfaces[i].second, interfaces[i].first)); } } #else // OS_WIN void ClosePlatformSocket(net::SocketDescriptor socket) { ::close(socket); } void StaticInitializeSocketFactory() { } #endif // OS_WIN void SendHostMessageOnUtilityThread(IPC::Message* msg) { content::UtilityThread::Get()->Send(msg); } std::string WatcherUpdateToString(ServiceWatcher::UpdateType update) { switch (update) { case ServiceWatcher::UPDATE_ADDED: return "UPDATE_ADDED"; case ServiceWatcher::UPDATE_CHANGED: return "UPDATE_CHANGED"; case ServiceWatcher::UPDATE_REMOVED: return "UPDATE_REMOVED"; case ServiceWatcher::UPDATE_INVALIDATED: return "UPDATE_INVALIDATED"; } return "Unknown Update"; } std::string ResolverStatusToString(ServiceResolver::RequestStatus status) { switch (status) { case ServiceResolver::STATUS_SUCCESS: return "STATUS_SUCESS"; case ServiceResolver::STATUS_REQUEST_TIMEOUT: return "STATUS_REQUEST_TIMEOUT"; case ServiceResolver::STATUS_KNOWN_NONEXISTENT: return "STATUS_KNOWN_NONEXISTENT"; } return "Unknown Status"; } } // namespace ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() { } ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() { DCHECK(!discovery_thread_); } void ServiceDiscoveryMessageHandler::PreSandboxStartup() { StaticInitializeSocketFactory(); } void ServiceDiscoveryMessageHandler::InitializeMdns() { if (service_discovery_client_ || mdns_client_) return; mdns_client_ = net::MDnsClient::CreateDefault(); bool result = mdns_client_->StartListening(g_local_discovery_socket_factory.Pointer()); // Close unused sockets. g_local_discovery_socket_factory.Get().Reset(); if (!result) { VLOG(1) << "Failed to start MDnsClient"; Send(new LocalDiscoveryHostMsg_Error()); return; } service_discovery_client_.reset( new local_discovery::ServiceDiscoveryClientImpl(mdns_client_.get())); } bool ServiceDiscoveryMessageHandler::InitializeThread() { if (discovery_task_runner_) return true; if (discovery_thread_) return false; utility_task_runner_ = base::MessageLoop::current()->message_loop_proxy(); discovery_thread_.reset(new base::Thread("ServiceDiscoveryThread")); base::Thread::Options thread_options(base::MessageLoop::TYPE_IO, 0); if (discovery_thread_->StartWithOptions(thread_options)) { discovery_task_runner_ = discovery_thread_->message_loop_proxy(); discovery_task_runner_->PostTask(FROM_HERE, base::Bind(&ServiceDiscoveryMessageHandler::InitializeMdns, base::Unretained(this))); } return discovery_task_runner_ != NULL; } bool ServiceDiscoveryMessageHandler::OnMessageReceived( const IPC::Message& message) { bool handled = true; IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler, message) #if defined(OS_POSIX) IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetSockets, OnSetSockets) #endif // OS_POSIX IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher, OnStartWatcher) IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices, OnDiscoverServices) IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetActivelyRefreshServices, OnSetActivelyRefreshServices) IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher, OnDestroyWatcher) IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService, OnResolveService) IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver, OnDestroyResolver) IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveLocalDomain, OnResolveLocalDomain) IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyLocalDomainResolver, OnDestroyLocalDomainResolver) IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ShutdownLocalDiscovery, ShutdownLocalDiscovery) IPC_MESSAGE_UNHANDLED(handled = false) IPC_END_MESSAGE_MAP() return handled; } void ServiceDiscoveryMessageHandler::PostTask( const tracked_objects::Location& from_here, const base::Closure& task) { if (!InitializeThread()) return; discovery_task_runner_->PostTask(from_here, task); } #if defined(OS_POSIX) void ServiceDiscoveryMessageHandler::OnSetSockets( const std::vector& sockets) { for (size_t i = 0; i < sockets.size(); ++i) { g_local_discovery_socket_factory.Get().AddSocket( SocketInfo(sockets[i].descriptor.fd, sockets[i].address_family, sockets[i].interface_index)); } } #endif // OS_POSIX void ServiceDiscoveryMessageHandler::OnStartWatcher( uint64 id, const std::string& service_type) { PostTask(FROM_HERE, base::Bind(&ServiceDiscoveryMessageHandler::StartWatcher, base::Unretained(this), id, service_type)); } void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id, bool force_update) { PostTask(FROM_HERE, base::Bind(&ServiceDiscoveryMessageHandler::DiscoverServices, base::Unretained(this), id, force_update)); } void ServiceDiscoveryMessageHandler::OnSetActivelyRefreshServices( uint64 id, bool actively_refresh_services) { PostTask(FROM_HERE, base::Bind( &ServiceDiscoveryMessageHandler::SetActivelyRefreshServices, base::Unretained(this), id, actively_refresh_services)); } void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id) { PostTask(FROM_HERE, base::Bind(&ServiceDiscoveryMessageHandler::DestroyWatcher, base::Unretained(this), id)); } void ServiceDiscoveryMessageHandler::OnResolveService( uint64 id, const std::string& service_name) { PostTask(FROM_HERE, base::Bind(&ServiceDiscoveryMessageHandler::ResolveService, base::Unretained(this), id, service_name)); } void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id) { PostTask(FROM_HERE, base::Bind(&ServiceDiscoveryMessageHandler::DestroyResolver, base::Unretained(this), id)); } void ServiceDiscoveryMessageHandler::OnResolveLocalDomain( uint64 id, const std::string& domain, net::AddressFamily address_family) { PostTask(FROM_HERE, base::Bind(&ServiceDiscoveryMessageHandler::ResolveLocalDomain, base::Unretained(this), id, domain, address_family)); } void ServiceDiscoveryMessageHandler::OnDestroyLocalDomainResolver(uint64 id) { PostTask(FROM_HERE, base::Bind( &ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver, base::Unretained(this), id)); } void ServiceDiscoveryMessageHandler::StartWatcher( uint64 id, const std::string& service_type) { VLOG(1) << "StartWatcher, id=" << id << ", type=" << service_type; if (!service_discovery_client_) return; DCHECK(!ContainsKey(service_watchers_, id)); scoped_ptr watcher( service_discovery_client_->CreateServiceWatcher( service_type, base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated, base::Unretained(this), id))); watcher->Start(); service_watchers_[id].reset(watcher.release()); } void ServiceDiscoveryMessageHandler::DiscoverServices(uint64 id, bool force_update) { VLOG(1) << "DiscoverServices, id=" << id; if (!service_discovery_client_) return; DCHECK(ContainsKey(service_watchers_, id)); service_watchers_[id]->DiscoverNewServices(force_update); } void ServiceDiscoveryMessageHandler::SetActivelyRefreshServices( uint64 id, bool actively_refresh_services) { VLOG(1) << "ActivelyRefreshServices, id=" << id; if (!service_discovery_client_) return; DCHECK(ContainsKey(service_watchers_, id)); service_watchers_[id]->SetActivelyRefreshServices(actively_refresh_services); } void ServiceDiscoveryMessageHandler::DestroyWatcher(uint64 id) { VLOG(1) << "DestoryWatcher, id=" << id; if (!service_discovery_client_) return; service_watchers_.erase(id); } void ServiceDiscoveryMessageHandler::ResolveService( uint64 id, const std::string& service_name) { VLOG(1) << "ResolveService, id=" << id << ", name=" << service_name; if (!service_discovery_client_) return; DCHECK(!ContainsKey(service_resolvers_, id)); scoped_ptr resolver( service_discovery_client_->CreateServiceResolver( service_name, base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved, base::Unretained(this), id))); resolver->StartResolving(); service_resolvers_[id].reset(resolver.release()); } void ServiceDiscoveryMessageHandler::DestroyResolver(uint64 id) { VLOG(1) << "DestroyResolver, id=" << id; if (!service_discovery_client_) return; service_resolvers_.erase(id); } void ServiceDiscoveryMessageHandler::ResolveLocalDomain( uint64 id, const std::string& domain, net::AddressFamily address_family) { VLOG(1) << "ResolveLocalDomain, id=" << id << ", domain=" << domain; if (!service_discovery_client_) return; DCHECK(!ContainsKey(local_domain_resolvers_, id)); scoped_ptr resolver( service_discovery_client_->CreateLocalDomainResolver( domain, address_family, base::Bind(&ServiceDiscoveryMessageHandler::OnLocalDomainResolved, base::Unretained(this), id))); resolver->Start(); local_domain_resolvers_[id].reset(resolver.release()); } void ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver(uint64 id) { VLOG(1) << "DestroyLocalDomainResolver, id=" << id; if (!service_discovery_client_) return; local_domain_resolvers_.erase(id); } void ServiceDiscoveryMessageHandler::ShutdownLocalDiscovery() { if (!discovery_task_runner_) return; discovery_task_runner_->PostTask( FROM_HERE, base::Bind(&ServiceDiscoveryMessageHandler::ShutdownOnIOThread, base::Unretained(this))); // This will wait for message loop to drain, so ShutdownOnIOThread will // definitely be called. discovery_thread_.reset(); } void ServiceDiscoveryMessageHandler::ShutdownOnIOThread() { VLOG(1) << "ShutdownLocalDiscovery"; service_watchers_.clear(); service_resolvers_.clear(); local_domain_resolvers_.clear(); service_discovery_client_.reset(); mdns_client_.reset(); } void ServiceDiscoveryMessageHandler::OnServiceUpdated( uint64 id, ServiceWatcher::UpdateType update, const std::string& name) { VLOG(1) << "OnServiceUpdated, id=" << id << ", status=" << WatcherUpdateToString(update) << ", name=" << name; DCHECK(service_discovery_client_); Send(new LocalDiscoveryHostMsg_WatcherCallback(id, update, name)); } void ServiceDiscoveryMessageHandler::OnServiceResolved( uint64 id, ServiceResolver::RequestStatus status, const ServiceDescription& description) { VLOG(1) << "OnServiceResolved, id=" << id << ", status=" << ResolverStatusToString(status) << ", name=" << description.service_name; DCHECK(service_discovery_client_); Send(new LocalDiscoveryHostMsg_ResolverCallback(id, status, description)); } void ServiceDiscoveryMessageHandler::OnLocalDomainResolved( uint64 id, bool success, const net::IPAddressNumber& address_ipv4, const net::IPAddressNumber& address_ipv6) { VLOG(1) << "OnLocalDomainResolved, id=" << id << ", IPv4=" << (address_ipv4.empty() ? "" : net::IPAddressToString(address_ipv4)) << ", IPv6=" << (address_ipv6.empty() ? "" : net::IPAddressToString(address_ipv6)); DCHECK(service_discovery_client_); Send(new LocalDiscoveryHostMsg_LocalDomainResolverCallback( id, success, address_ipv4, address_ipv6)); } void ServiceDiscoveryMessageHandler::Send(IPC::Message* msg) { utility_task_runner_->PostTask(FROM_HERE, base::Bind(&SendHostMessageOnUtilityThread, msg)); } } // namespace local_discovery