diff options
Diffstat (limited to 'net/base')
-rw-r--r-- | net/base/address_list.cc | 118 | ||||
-rw-r--r-- | net/base/address_list.h | 29 | ||||
-rw-r--r-- | net/base/address_list_unittest.cc | 86 | ||||
-rw-r--r-- | net/base/client_socket_pool.h | 4 | ||||
-rw-r--r-- | net/base/host_cache.cc | 115 | ||||
-rw-r--r-- | net/base/host_cache.h | 93 | ||||
-rw-r--r-- | net/base/host_cache_unittest.cc | 218 | ||||
-rw-r--r-- | net/base/host_resolver.cc | 374 | ||||
-rw-r--r-- | net/base/host_resolver.h | 138 | ||||
-rw-r--r-- | net/base/host_resolver_unittest.cc | 436 | ||||
-rw-r--r-- | net/base/ssl_client_socket_unittest.cc | 12 | ||||
-rw-r--r-- | net/base/ssl_test_util.cc | 2 | ||||
-rw-r--r-- | net/base/tcp_client_socket_pool.cc | 5 | ||||
-rw-r--r-- | net/base/tcp_client_socket_pool.h | 11 | ||||
-rw-r--r-- | net/base/tcp_client_socket_pool_unittest.cc | 8 | ||||
-rw-r--r-- | net/base/tcp_client_socket_unittest.cc | 2 | ||||
-rw-r--r-- | net/base/tcp_pinger_unittest.cc | 4 |
17 files changed, 1550 insertions, 105 deletions
diff --git a/net/base/address_list.cc b/net/base/address_list.cc index 1cd2f62..d92d280 100644 --- a/net/base/address_list.cc +++ b/net/base/address_list.cc @@ -11,14 +11,128 @@ #include <netdb.h> #endif +#include "base/logging.h" + namespace net { +namespace { + +// Make a deep copy of |info|. This copy should be deleted using +// DeleteCopyOfAddrinfo(), and NOT freeaddrinfo(). +struct addrinfo* CreateCopyOfAddrinfo(const struct addrinfo* info) { + struct addrinfo* copy = new struct addrinfo; + + // Copy all the fields (some of these are pointers, we will fix that next). + memcpy(copy, info, sizeof(addrinfo)); + + // ai_canonname is a NULL-terminated string. + if (info->ai_canonname) { +#ifdef OS_WIN + copy->ai_canonname = _strdup(info->ai_canonname); +#else + copy->ai_canonname = strdup(info->ai_canonname); +#endif + } + + // ai_addr is a buffer of length ai_addrlen. + if (info->ai_addr) { + copy->ai_addr = reinterpret_cast<sockaddr *>(new char[info->ai_addrlen]); + memcpy(copy->ai_addr, info->ai_addr, info->ai_addrlen); + } + + // Recursive copy. + if (info->ai_next) + copy->ai_next = CreateCopyOfAddrinfo(info->ai_next); + + return copy; +} + +// Free an addrinfo that was created by CreateCopyOfAddrinfo(). +void FreeMyAddrinfo(struct addrinfo* info) { + if (info->ai_canonname) + free(info->ai_canonname); // Allocated by strdup. + + if (info->ai_addr) + delete [] reinterpret_cast<char*>(info->ai_addr); + + struct addrinfo* next = info->ai_next; + + delete info; + + // Recursive free. + if (next) + FreeMyAddrinfo(next); +} + +// Returns the address to port field in |info|. +uint16* GetPortField(const struct addrinfo* info) { + if (info->ai_family == AF_INET) { + DCHECK_EQ(sizeof(sockaddr_in), info->ai_addrlen); + struct sockaddr_in* sockaddr = + reinterpret_cast<struct sockaddr_in*>(info->ai_addr); + return &sockaddr->sin_port; + } else if (info->ai_family == AF_INET6) { + DCHECK_EQ(sizeof(sockaddr_in6), info->ai_addrlen); + struct sockaddr_in6* sockaddr = + reinterpret_cast<struct sockaddr_in6*>(info->ai_addr); + return &sockaddr->sin6_port; + } else { + NOTREACHED(); + return NULL; + } +} + +// Assign the port for all addresses in the list. +void SetPortRecursive(struct addrinfo* info, int port) { + uint16* port_field = GetPortField(info); + *port_field = htons(port); + + // Assign recursively. + if (info->ai_next) + SetPortRecursive(info->ai_next, port); +} + +} // namespace + void AddressList::Adopt(struct addrinfo* head) { - data_ = new Data(head); + data_ = new Data(head, true /*is_system_created*/); +} + +void AddressList::Copy(const struct addrinfo* head) { + data_ = new Data(CreateCopyOfAddrinfo(head), false /*is_system_created*/); +} + +void AddressList::SetPort(int port) { + SetPortRecursive(data_->head, port); +} + +int AddressList::GetPort() const { + uint16* port_field = GetPortField(data_->head); + return ntohs(*port_field); +} + +void AddressList::SetFrom(const AddressList& src, int port) { + if (src.GetPort() == port) { + // We can reference the data from |src| directly. + *this = src; + } else { + // Otherwise we need to make a copy in order to change the port number. + Copy(src.head()); + SetPort(port); + } +} + +void AddressList::Reset() { + data_ = NULL; } AddressList::Data::~Data() { - freeaddrinfo(head); + // Call either freeaddrinfo(head), or FreeMyAddrinfo(head), depending who + // created the data. + if (is_system_created) + freeaddrinfo(head); + else + FreeMyAddrinfo(head); } } // namespace net diff --git a/net/base/address_list.h b/net/base/address_list.h index 5bffd4c..506350b 100644 --- a/net/base/address_list.h +++ b/net/base/address_list.h @@ -20,16 +20,39 @@ class AddressList { // object. void Adopt(struct addrinfo* head); + // Copies the given addrinfo rather than adopting it. + void Copy(const struct addrinfo* head); + + // Sets the port of all addresses in the list to |port| (that is the + // sin[6]_port field for the sockaddrs). + void SetPort(int port); + + // Retrieves the port number of the first sockaddr in the list. (If SetPort() + // was previously used on this list, then all the addresses will have this + // same port number.) + int GetPort() const; + + // Sets the address to match |src|, and have each sockaddr's port be |port|. + // If |src| already has the desired port this operation is cheap (just adds + // a reference to |src|'s data.) Otherwise we will make a copy. + void SetFrom(const AddressList& src, int port); + + // Clears all data from this address list. This leaves the list in the same + // empty state as when first constructed. + void Reset(); + // Get access to the head of the addrinfo list. const struct addrinfo* head() const { return data_->head; } private: struct Data : public base::RefCountedThreadSafe<Data> { - explicit Data(struct addrinfo* ai) : head(ai) {} + Data(struct addrinfo* ai, bool is_system_created) + : head(ai), is_system_created(is_system_created) {} ~Data(); struct addrinfo* head; - private: - Data(); + + // Indicates which free function to use for |head|. + bool is_system_created; }; scoped_refptr<Data> data_; }; diff --git a/net/base/address_list_unittest.cc b/net/base/address_list_unittest.cc new file mode 100644 index 0000000..d594c85 --- /dev/null +++ b/net/base/address_list_unittest.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/base/address_list.h" + +#if defined(OS_WIN) +#include <ws2tcpip.h> +#include <wspiapi.h> // Needed for Win2k compat. +#elif defined(OS_POSIX) +#include <netdb.h> +#include <sys/socket.h> +#endif + +#include "base/string_util.h" +#include "net/base/net_util.h" +#if defined(OS_WIN) +#include "net/base/winsock_init.h" +#endif +#include "testing/gtest/include/gtest/gtest.h" + +namespace { + +// Use getaddrinfo() to allocate an addrinfo structure. +void CreateAddressList(net::AddressList* addrlist, int port) { +#if defined(OS_WIN) + net::EnsureWinsockInit(); +#endif + std::string portstr = IntToString(port); + + struct addrinfo* result = NULL; + struct addrinfo hints = {0}; + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_NUMERICHOST; + hints.ai_socktype = SOCK_STREAM; + + int err = getaddrinfo("192.168.1.1", portstr.c_str(), &hints, &result); + EXPECT_EQ(0, err); + addrlist->Adopt(result); +} + +TEST(AddressListTest, GetPort) { + net::AddressList addrlist; + CreateAddressList(&addrlist, 81); + EXPECT_EQ(81, addrlist.GetPort()); + + addrlist.SetPort(83); + EXPECT_EQ(83, addrlist.GetPort()); +} + +TEST(AddressListTest, Assignment) { + net::AddressList addrlist1; + CreateAddressList(&addrlist1, 85); + EXPECT_EQ(85, addrlist1.GetPort()); + + // Should reference the same data as addrlist1 -- so when we change addrlist1 + // both are changed. + net::AddressList addrlist2 = addrlist1; + EXPECT_EQ(85, addrlist2.GetPort()); + + addrlist1.SetPort(80); + EXPECT_EQ(80, addrlist1.GetPort()); + EXPECT_EQ(80, addrlist2.GetPort()); +} + +TEST(AddressListTest, Copy) { + net::AddressList addrlist1; + CreateAddressList(&addrlist1, 85); + EXPECT_EQ(85, addrlist1.GetPort()); + + net::AddressList addrlist2; + addrlist2.Copy(addrlist1.head()); + + // addrlist1 is the same as addrlist2 at this point. + EXPECT_EQ(85, addrlist1.GetPort()); + EXPECT_EQ(85, addrlist2.GetPort()); + + // Changes to addrlist1 are not reflected in addrlist2. + addrlist1.SetPort(70); + addrlist2.SetPort(90); + + EXPECT_EQ(70, addrlist1.GetPort()); + EXPECT_EQ(90, addrlist2.GetPort()); +} + +} // namespace diff --git a/net/base/client_socket_pool.h b/net/base/client_socket_pool.h index 599b7a9..6d200c0 100644 --- a/net/base/client_socket_pool.h +++ b/net/base/client_socket_pool.h @@ -17,6 +17,7 @@ namespace net { class ClientSocket; class ClientSocketHandle; +class HostResolver; // A ClientSocketPool is used to restrict the number of sockets open at a time. // It also maintains a list of idle persistent sockets. @@ -69,6 +70,9 @@ class ClientSocketPool : public base::RefCounted<ClientSocketPool> { // Called to close any idle connections held by the connection manager. virtual void CloseIdleSockets() = 0; + // Returns the HostResolver that will be used for host lookups. + virtual HostResolver* GetHostResolver() const = 0; + // The total number of idle sockets in the pool. virtual int idle_socket_count() const = 0; diff --git a/net/base/host_cache.cc b/net/base/host_cache.cc new file mode 100644 index 0000000..53af5b4 --- /dev/null +++ b/net/base/host_cache.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/base/host_cache.h" + +#include "base/logging.h" +#include "net/base/net_errors.h" + +namespace net { + +//----------------------------------------------------------------------------- + +HostCache::Entry::Entry(int error, + const AddressList& addrlist, + base::TimeTicks expiration) + : error(error), addrlist(addrlist), expiration(expiration) { +} + +HostCache::Entry::~Entry() { +} + +//----------------------------------------------------------------------------- + +HostCache::HostCache(size_t max_entries, size_t cache_duration_ms) + : max_entries_(max_entries), cache_duration_ms_(cache_duration_ms) { +} + +HostCache::~HostCache() { +} + +const HostCache::Entry* HostCache::Lookup(const std::string& hostname, + base::TimeTicks now) const { + if (caching_is_disabled()) + return NULL; + + EntryMap::const_iterator it = entries_.find(hostname); + if (it == entries_.end()) + return NULL; // Not found. + + Entry* entry = it->second.get(); + if (CanUseEntry(entry, now)) + return entry; + + return NULL; +} + +HostCache::Entry* HostCache::Set(const std::string& hostname, + int error, + const AddressList addrlist, + base::TimeTicks now) { + if (caching_is_disabled()) + return NULL; + + base::TimeTicks expiration = now + + base::TimeDelta::FromMilliseconds(cache_duration_ms_); + + scoped_refptr<Entry>& entry = entries_[hostname]; + if (!entry) { + // Entry didn't exist, creating one now. + Entry* ptr = new Entry(error, addrlist, expiration); + entry = ptr; + + // Compact the cache if we grew it beyond limit -- exclude |entry| from + // being pruned though! + if (entries_.size() > max_entries_) + Compact(now, ptr); + return ptr; + } else { + // Update an existing cache entry. + entry->error = error; + entry->addrlist = addrlist; + entry->expiration = expiration; + return entry.get(); + } +} + +// static +bool HostCache::CanUseEntry(const Entry* entry, const base::TimeTicks now) { + return entry->error == OK && entry->expiration > now; +} + +void HostCache::Compact(base::TimeTicks now, const Entry* pinned_entry) { + // Clear out expired entries. + for (EntryMap::iterator it = entries_.begin(); it != entries_.end(); ) { + Entry* entry = (it->second).get(); + if (entry != pinned_entry && !CanUseEntry(entry, now)) { + entries_.erase(it++); + } else { + ++it; + } + } + + if (entries_.size() <= max_entries_) + return; + + // If we still have too many entries, start removing unexpired entries + // at random. + // TODO(eroman): this eviction policy could be better (access count FIFO + // or whatever). + for (EntryMap::iterator it = entries_.begin(); + it != entries_.end() && entries_.size() > max_entries_; ) { + Entry* entry = (it->second).get(); + if (entry != pinned_entry) { + entries_.erase(it++); + } else { + ++it; + } + } + + if (entries_.size() > max_entries_) + DLOG(WARNING) << "Still above max entries limit"; +} + +} // namespace net diff --git a/net/base/host_cache.h b/net/base/host_cache.h new file mode 100644 index 0000000..2085b7c --- /dev/null +++ b/net/base/host_cache.h @@ -0,0 +1,93 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_BASE_HOST_CACHE_H_ +#define NET_BASE_HOST_CACHE_H_ + +#include <string> + +#include "base/hash_tables.h" +#include "base/ref_counted.h" +#include "base/time.h" +#include "net/base/address_list.h" +#include "testing/gtest/include/gtest/gtest_prod.h" + +namespace net { + +// Cache used by HostResolver to map hostnames to their resolved result. +// If the resolve is still in progress, the entry will reference the job +// responsible for populating it. +class HostCache { + public: + // Stores the latest address list that was looked up for a hostname. + struct Entry : public base::RefCounted<Entry> { + Entry(int error, const AddressList& addrlist, base::TimeTicks expiration); + ~Entry(); + + // The resolve results for this entry. + int error; + AddressList addrlist; + + // The time when this entry expires. + base::TimeTicks expiration; + }; + + // Constructs a HostCache whose entries are valid for |cache_duration_ms| + // milliseconds. The cache will store up to |max_entries|. + HostCache(size_t max_entries, size_t cache_duration_ms); + + ~HostCache(); + + // Returns a pointer to the entry for |hostname|, which is valid at time + // |now|. If there is no such entry, returns NULL. + const Entry* Lookup(const std::string& hostname, base::TimeTicks now) const; + + // Overwrites or creates an entry for |hostname|. Returns the pointer to the + // entry, or NULL on failure (fails if caching is disabled). + // (|error|, |addrlist|) is the value to set, and |now| is the current + // timestamp. + Entry* Set(const std::string& hostname, + int error, + const AddressList addrlist, + base::TimeTicks now); + + // Returns true if this HostCache can contain no entries. + bool caching_is_disabled() const { + return max_entries_ == 0; + } + + // Returns the number of entries in the cache. + size_t size() const { + return entries_.size(); + } + + private: + FRIEND_TEST(HostCacheTest, Compact); + FRIEND_TEST(HostCacheTest, NoCache); + + typedef base::hash_map<std::string, scoped_refptr<Entry> > EntryMap; + + // Returns true if this cache entry's result is valid at time |now|. + static bool CanUseEntry(const Entry* entry, const base::TimeTicks now); + + // Prunes entries from the cache to bring it below max entry bound. Entries + // matching |pinned_entry| will NOT be pruned. + void Compact(base::TimeTicks now, const Entry* pinned_entry); + + // Bound on total size of the cache. + size_t max_entries_; + + // Time to live for cache entries in milliseconds. + size_t cache_duration_ms_; + + // Map from hostname (presumably in lowercase canonicalized format) to + // a resolved result entry. + EntryMap entries_; + + DISALLOW_COPY_AND_ASSIGN(HostCache); +}; + +} // namespace net + +#endif // NET_BASE_HOST_CACHE_H_ diff --git a/net/base/host_cache_unittest.cc b/net/base/host_cache_unittest.cc new file mode 100644 index 0000000..fe01e20 --- /dev/null +++ b/net/base/host_cache_unittest.cc @@ -0,0 +1,218 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/base/host_cache.h" + +#include "base/stl_util-inl.h" +#include "base/string_util.h" +#include "net/base/net_errors.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { +static const int kMaxCacheEntries = 10; +static const int kCacheDurationMs = 10000; // 10 seconds. +} + +TEST(HostCacheTest, Basic) { + HostCache cache(kMaxCacheEntries, kCacheDurationMs); + + // Start at t=0. + base::TimeTicks now; + + const HostCache::Entry* entry1 = NULL; // Entry for foobar.com. + const HostCache::Entry* entry2 = NULL; // Entry for foobar2.com. + + EXPECT_EQ(0U, cache.size()); + + // Add an entry for "foobar.com" at t=0. + EXPECT_EQ(NULL, cache.Lookup("foobar.com", base::TimeTicks())); + cache.Set("foobar.com", OK, AddressList(), now); + entry1 = cache.Lookup("foobar.com", base::TimeTicks()); + EXPECT_FALSE(NULL == entry1); + EXPECT_EQ(1U, cache.size()); + + // Advance to t=5. + now += base::TimeDelta::FromSeconds(5); + + // Add an entry for "foobar2.com" at t=5. + EXPECT_EQ(NULL, cache.Lookup("foobar2.com", base::TimeTicks())); + cache.Set("foobar2.com", OK, AddressList(), now); + entry2 = cache.Lookup("foobar2.com", base::TimeTicks()); + EXPECT_FALSE(NULL == entry1); + EXPECT_EQ(2U, cache.size()); + + // Advance to t=9 + now += base::TimeDelta::FromSeconds(4); + + // Verify that the entries we added are still retrievable, and usable. + EXPECT_EQ(entry1, cache.Lookup("foobar.com", now)); + EXPECT_EQ(entry2, cache.Lookup("foobar2.com", now)); + + // Advance to t=10; entry1 is now expired. + now += base::TimeDelta::FromSeconds(1); + + EXPECT_EQ(NULL, cache.Lookup("foobar.com", now)); + EXPECT_EQ(entry2, cache.Lookup("foobar2.com", now)); + + // Update entry1, so it is no longer expired. + cache.Set("foobar.com", OK, AddressList(), now); + // Re-uses existing entry storage. + EXPECT_EQ(entry1, cache.Lookup("foobar.com", now)); + EXPECT_EQ(2U, cache.size()); + + // Both entries should still be retrievable and usable. + EXPECT_EQ(entry1, cache.Lookup("foobar.com", now)); + EXPECT_EQ(entry2, cache.Lookup("foobar2.com", now)); + + // Advance to t=20; both entries are now expired. + now += base::TimeDelta::FromSeconds(10); + + EXPECT_EQ(NULL, cache.Lookup("foobar.com", now)); + EXPECT_EQ(NULL, cache.Lookup("foobar2.com", now)); +} + +// Try caching entries for a failed resolve attempt. +TEST(HostCacheTest, NegativeEntry) { + HostCache cache(kMaxCacheEntries, kCacheDurationMs); + + // Set t=0. + base::TimeTicks now; + + EXPECT_EQ(NULL, cache.Lookup("foobar.com", base::TimeTicks())); + cache.Set("foobar.com", ERR_NAME_NOT_RESOLVED, AddressList(), now); + EXPECT_EQ(1U, cache.size()); + + // We disallow use of negative entries. + EXPECT_EQ(NULL, cache.Lookup("foobar.com", now)); + + // Now overwrite with a valid entry, and then overwrite with negative entry + // again -- the valid entry should be kicked out. + cache.Set("foobar.com", OK, AddressList(), now); + EXPECT_FALSE(NULL == cache.Lookup("foobar.com", now)); + cache.Set("foobar.com", ERR_NAME_NOT_RESOLVED, AddressList(), now); + EXPECT_EQ(NULL, cache.Lookup("foobar.com", now)); +} + +TEST(HostCacheTest, Compact) { + // Initial entries limit is big enough to accomadate everything we add. + net::HostCache cache(kMaxCacheEntries, kCacheDurationMs); + + EXPECT_EQ(0U, cache.size()); + + // t=10 + base::TimeTicks now = base::TimeTicks() + base::TimeDelta::FromSeconds(10); + + // Add five valid entries at t=10. + for (int i = 0; i < 5; ++i) { + std::string hostname = StringPrintf("valid%d", i); + cache.Set(hostname, OK, AddressList(), now); + } + EXPECT_EQ(5U, cache.size()); + + // Add 3 expired entries at t=0. + for (int i = 0; i < 3; ++i) { + std::string hostname = StringPrintf("expired%d", i); + base::TimeTicks t = now - base::TimeDelta::FromSeconds(10); + cache.Set(hostname, OK, AddressList(), t); + } + EXPECT_EQ(8U, cache.size()); + + // Add 2 negative entries at t=10 + for (int i = 0; i < 2; ++i) { + std::string hostname = StringPrintf("negative%d", i); + cache.Set(hostname, ERR_NAME_NOT_RESOLVED, AddressList(), now); + } + EXPECT_EQ(10U, cache.size()); + + EXPECT_TRUE(ContainsKey(cache.entries_, "valid0")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid1")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid2")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid3")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid4")); + EXPECT_TRUE(ContainsKey(cache.entries_, "expired0")); + EXPECT_TRUE(ContainsKey(cache.entries_, "expired1")); + EXPECT_TRUE(ContainsKey(cache.entries_, "expired2")); + EXPECT_TRUE(ContainsKey(cache.entries_, "negative0")); + EXPECT_TRUE(ContainsKey(cache.entries_, "negative1")); + + // Shrink the max constraints bound and compact. We expect the "negative" + // and "expired" entries to have been dropped. + cache.max_entries_ = 5; + cache.Compact(now, NULL); + EXPECT_EQ(5U, cache.entries_.size()); + + EXPECT_TRUE(ContainsKey(cache.entries_, "valid0")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid1")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid2")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid3")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid4")); + EXPECT_FALSE(ContainsKey(cache.entries_, "expired0")); + EXPECT_FALSE(ContainsKey(cache.entries_, "expired1")); + EXPECT_FALSE(ContainsKey(cache.entries_, "expired2")); + EXPECT_FALSE(ContainsKey(cache.entries_, "negative0")); + EXPECT_FALSE(ContainsKey(cache.entries_, "negative1")); + + // Shrink further -- this time the compact will start dropping valid entries + // to make space. + cache.max_entries_ = 3; + cache.Compact(now, NULL); + EXPECT_EQ(3U, cache.size()); +} + +// Add entries while the cache is at capacity, causing evictions. +TEST(HostCacheTest, SetWithCompact) { + net::HostCache cache(3, kCacheDurationMs); + + // t=10 + base::TimeTicks now = + base::TimeTicks() + base::TimeDelta::FromMilliseconds(kCacheDurationMs); + + cache.Set("host1", OK, AddressList(), now); + cache.Set("host2", OK, AddressList(), now); + cache.Set("expired", OK, AddressList(), + now - base::TimeDelta::FromMilliseconds(kCacheDurationMs)); + + EXPECT_EQ(3U, cache.size()); + + // Should all be retrievable except "expired". + EXPECT_FALSE(NULL == cache.Lookup("host1", now)); + EXPECT_FALSE(NULL == cache.Lookup("host2", now)); + EXPECT_TRUE(NULL == cache.Lookup("expired", now)); + + // Adding the fourth entry will cause "expired" to be evicted. + cache.Set("host3", OK, AddressList(), now); + EXPECT_EQ(3U, cache.size()); + EXPECT_EQ(NULL, cache.Lookup("expired", now)); + EXPECT_FALSE(NULL == cache.Lookup("host1", now)); + EXPECT_FALSE(NULL == cache.Lookup("host2", now)); + EXPECT_FALSE(NULL == cache.Lookup("host3", now)); + + // Add two more entries. Something should be evicted, however "host5" + // should definitely be in there (since it was last inserted). + cache.Set("host4", OK, AddressList(), now); + EXPECT_EQ(3U, cache.size()); + cache.Set("host5", OK, AddressList(), now); + EXPECT_EQ(3U, cache.size()); + EXPECT_FALSE(NULL == cache.Lookup("host5", now)); +} + +TEST(HostCacheTest, NoCache) { + // Disable caching. + HostCache cache(0, kCacheDurationMs); + EXPECT_TRUE(cache.caching_is_disabled()); + + // Set t=0. + base::TimeTicks now; + + // Lookup and Set should have no effect. + EXPECT_EQ(NULL, cache.Lookup("foobar.com", base::TimeTicks())); + cache.Set("foobar.com", OK, AddressList(), now); + EXPECT_EQ(NULL, cache.Lookup("foobar.com", base::TimeTicks())); + + EXPECT_EQ(0U, cache.size()); +} + +} // namespace net diff --git a/net/base/host_resolver.cc b/net/base/host_resolver.cc index ec693ee..50329ff 100644 --- a/net/base/host_resolver.cc +++ b/net/base/host_resolver.cc @@ -15,8 +15,11 @@ #include <resolv.h> #endif +#include "base/compiler_specific.h" #include "base/message_loop.h" +#include "base/stl_util-inl.h" #include "base/string_util.h" +#include "base/time.h" #include "base/worker_pool.h" #include "net/base/address_list.h" #include "net/base/net_errors.h" @@ -24,7 +27,6 @@ #if defined(OS_LINUX) #include "base/singleton.h" #include "base/thread_local_storage.h" -#include "base/time.h" #endif #if defined(OS_WIN) @@ -110,8 +112,7 @@ ThreadLocalStorage::Slot DnsReloadTimer::tls_index_(base::LINKER_INITIALIZED); #endif // defined(OS_LINUX) -static int HostResolverProc( - const std::string& host, const std::string& port, struct addrinfo** out) { +static int HostResolverProc(const std::string& host, struct addrinfo** out) { struct addrinfo hints = {0}; hints.ai_family = AF_UNSPEC; @@ -144,7 +145,7 @@ static int HostResolverProc( // Restrict result set to only this socket type to avoid duplicates. hints.ai_socktype = SOCK_STREAM; - int err = getaddrinfo(host.c_str(), port.c_str(), &hints, out); + int err = getaddrinfo(host.c_str(), NULL, &hints, out); #if defined(OS_LINUX) net::DnsReloadTimer* dns_timer = Singleton<net::DnsReloadTimer>::get(); // If we fail, re-initialise the resolver just in case there have been any @@ -152,7 +153,7 @@ static int HostResolverProc( if (err && dns_timer->Expired()) { res_nclose(&_res); if (!res_ninit(&_res)) - err = getaddrinfo(host.c_str(), port.c_str(), &hints, out); + err = getaddrinfo(host.c_str(), NULL, &hints, out); } #endif @@ -160,50 +161,148 @@ static int HostResolverProc( } static int ResolveAddrInfo(HostMapper* mapper, const std::string& host, - const std::string& port, struct addrinfo** out) { + struct addrinfo** out) { if (mapper) { std::string mapped_host = mapper->Map(host); if (mapped_host.empty()) return ERR_NAME_NOT_RESOLVED; - return HostResolverProc(mapped_host, port, out); + return HostResolverProc(mapped_host, out); } else { - return HostResolverProc(host, port, out); + return HostResolverProc(host, out); } } //----------------------------------------------------------------------------- -class HostResolver::Request : - public base::RefCountedThreadSafe<HostResolver::Request> { +class HostResolver::Request { public: - Request(HostResolver* resolver, - const std::string& host, - const std::string& port, - AddressList* addresses, - CompletionCallback* callback) + Request(CompletionCallback* callback, AddressList* addresses, int port) + : job_(NULL), callback_(callback), addresses_(addresses), port_(port) {} + + // Mark the request as cancelled. + void Cancel() { + job_ = NULL; + callback_ = NULL; + addresses_ = NULL; + } + + bool was_cancelled() const { + return callback_ == NULL; + } + + void set_job(Job* job) { + DCHECK(job != NULL); + // Identify which job the request is waiting on. + job_ = job; + } + + void OnComplete(int error, const AddressList& addrlist) { + if (error == OK) + addresses_->SetFrom(addrlist, port_); + callback_->Run(error); + } + + int port() const { + return port_; + } + + Job* job() const { + return job_; + } + + private: + // The resolve job (running in worker pool) that this request is dependent on. + Job* job_; + + // The user's callback to invoke when the request completes. + CompletionCallback* callback_; + + // The address list to save result into. + AddressList* addresses_; + + // The desired port number for the socket addresses. + int port_; + + DISALLOW_COPY_AND_ASSIGN(Request); +}; + +//----------------------------------------------------------------------------- + +// This class represents a request to the worker pool for a "getaddrinfo()" +// call. +class HostResolver::Job : public base::RefCountedThreadSafe<HostResolver::Job> { + public: + Job(HostResolver* resolver, const std::string& host) : host_(host), - port_(port), resolver_(resolver), - addresses_(addresses), - callback_(callback), origin_loop_(MessageLoop::current()), host_mapper_(host_mapper), error_(OK), results_(NULL) { } - ~Request() { + ~Job() { if (results_) freeaddrinfo(results_); + + // Free the requests attached to this job. + STLDeleteElements(&requests_); } + // Attaches a request to this job. The job takes ownership of |req| and will + // take care to delete it. + void AddRequest(HostResolver::Request* req) { + req->set_job(this); + requests_.push_back(req); + } + + // Called from origin loop. + void Start() { + // Dispatch the job to a worker thread. + if (!WorkerPool::PostTask(FROM_HERE, + NewRunnableMethod(this, &Job::DoLookup), true)) { + NOTREACHED(); + + // Since we could be running within Resolve() right now, we can't just + // call OnLookupComplete(). Instead we must wait until Resolve() has + // returned (IO_PENDING). + error_ = ERR_UNEXPECTED; + MessageLoop::current()->PostTask( + FROM_HERE, NewRunnableMethod(this, &Job::OnLookupComplete)); + } + } + + // Cancels the current job. Callable from origin thread. + void Cancel() { + resolver_ = NULL; + + AutoLock locked(origin_loop_lock_); + origin_loop_ = NULL; + } + + // Called from origin thread. + bool was_cancelled() const { + return resolver_ == NULL; + } + + // Called from origin thread. + const std::string& host() const { + return host_; + } + + // Called from origin thread. + const RequestsList& requests() const { + return requests_; + } + + private: void DoLookup() { // Running on the worker thread - error_ = ResolveAddrInfo(host_mapper_, host_, port_, &results_); + error_ = ResolveAddrInfo(host_mapper_, host_, &results_); - Task* reply = NewRunnableMethod(this, &Request::DoCallback); + Task* reply = NewRunnableMethod(this, &Job::OnLookupComplete); // The origin loop could go away while we are trying to post to it, so we // need to call its PostTask method inside a lock. See ~HostResolver. @@ -219,43 +318,33 @@ class HostResolver::Request : delete reply; } - void DoCallback() { - // Running on the origin thread. + // Callback for when DoLookup() completes (runs on origin thread). + void OnLookupComplete() { + DCHECK_EQ(origin_loop_, MessageLoop::current()); DCHECK(error_ || results_); - // We may have been cancelled! - if (!resolver_) + if (was_cancelled()) return; - if (!error_) { - addresses_->Adopt(results_); + DCHECK(!requests_.empty()); + + // Adopt the address list using the port number of the first request. + AddressList addrlist; + if (error_ == OK) { + addrlist.Adopt(results_); + addrlist.SetPort(requests_[0]->port()); results_ = NULL; } - // Drop the resolver's reference to us. Do this before running the - // callback since the callback might result in the resolver being - // destroyed. - resolver_->request_ = NULL; - - callback_->Run(error_); + resolver_->OnJobComplete(this, error_, addrlist); } - void Cancel() { - resolver_ = NULL; - - AutoLock locked(origin_loop_lock_); - origin_loop_ = NULL; - } - - private: // Set on the origin thread, read on the worker thread. std::string host_; - std::string port_; // Only used on the origin thread (where Resolve was called). HostResolver* resolver_; - AddressList* addresses_; - CompletionCallback* callback_; + RequestsList requests_; // The requests waiting on this job. // Used to post ourselves onto the origin thread. Lock origin_loop_lock_; @@ -270,48 +359,205 @@ class HostResolver::Request : // Assigned on the worker thread, read on the origin thread. int error_; struct addrinfo* results_; + + DISALLOW_COPY_AND_ASSIGN(Job); }; //----------------------------------------------------------------------------- -HostResolver::HostResolver() { +HostResolver::HostResolver(int max_cache_entries, int cache_duration_ms) + : cache_(max_cache_entries, cache_duration_ms) { #if defined(OS_WIN) EnsureWinsockInit(); #endif } HostResolver::~HostResolver() { - if (request_) - request_->Cancel(); + // Cancel the outstanding jobs. Those jobs may contain several attached + // requests, which will now never be completed. + for (JobMap::iterator it = jobs_.begin(); it != jobs_.end(); ++it) + it->second->Cancel(); + + // In case we are being deleted during the processing of a callback. + if (cur_completing_job_) + cur_completing_job_->Cancel(); } +// TODO(eroman): Don't create cache entries for hostnames which are simply IP +// address literals. int HostResolver::Resolve(const std::string& hostname, int port, AddressList* addresses, - CompletionCallback* callback) { - DCHECK(!request_) << "resolver already in use"; - - const std::string& port_str = IntToString(port); + CompletionCallback* callback, + Request** out_req) { + // If we have an unexpired cache entry, use it. + const HostCache::Entry* cache_entry = cache_.Lookup( + hostname, base::TimeTicks::Now()); + if (cache_entry) { + addresses->SetFrom(cache_entry->addrlist, port); + return OK; + } - // Do a synchronous resolution. + // If no callback was specified, do a synchronous resolution. if (!callback) { struct addrinfo* results; - int rv = ResolveAddrInfo(host_mapper, hostname, port_str, &results); - if (rv == OK) - addresses->Adopt(results); - return rv; - } + int error = ResolveAddrInfo(host_mapper, hostname, &results); + + // Adopt the address list. + AddressList addrlist; + if (error == OK) { + addrlist.Adopt(results); + addrlist.SetPort(port); + *addresses = addrlist; + } + + // Write to cache. + cache_.Set(hostname, error, addrlist, base::TimeTicks::Now()); - request_ = new Request(this, hostname, port_str, addresses, callback); + return error; + } - // Dispatch to worker thread... - if (!WorkerPool::PostTask(FROM_HERE, - NewRunnableMethod(request_.get(), &Request::DoLookup), true)) { - NOTREACHED(); - request_ = NULL; - return ERR_FAILED; + // Create a handle for this request, and pass it back to the user if they + // asked for it (out_req != NULL). + Request* req = new Request(callback, addresses, port); + if (out_req) + *out_req = req; + + // Next we need to attach our request to a "job". This job is responsible for + // calling "getaddrinfo(hostname)" on a worker thread. + scoped_refptr<Job> job; + + // If there is already an outstanding job to resolve |hostname|, use it. + // This prevents starting concurrent resolves for the same hostname. + job = FindOutstandingJob(hostname); + if (job) { + job->AddRequest(req); + } else { + // Create a new job for this request. + job = new Job(this, hostname); + job->AddRequest(req); + AddOutstandingJob(job); + // TODO(eroman): Bound the total number of concurrent jobs. + // http://crbug.com/9598 + job->Start(); } + // Completion happens during OnJobComplete(Job*). return ERR_IO_PENDING; } +// See OnJobComplete(Job*) for why it is important not to clean out +// cancelled requests from Job::requests_. +void HostResolver::CancelRequest(Request* req) { + DCHECK(req); + DCHECK(req->job()); + // NULL out the fields of req, to mark it as cancelled. + req->Cancel(); +} + +void HostResolver::AddOutstandingJob(Job* job) { + scoped_refptr<Job>& found_job = jobs_[job->host()]; + DCHECK(!found_job); + found_job = job; +} + +HostResolver::Job* HostResolver::FindOutstandingJob( + const std::string& hostname) { + JobMap::iterator it = jobs_.find(hostname); + if (it != jobs_.end()) + return it->second; + return NULL; +} + +void HostResolver::RemoveOutstandingJob(Job* job) { + JobMap::iterator it = jobs_.find(job->host()); + DCHECK(it != jobs_.end()); + DCHECK_EQ(it->second.get(), job); + jobs_.erase(it); +} + +void HostResolver::OnJobComplete(Job* job, + int error, + const AddressList& addrlist) { + RemoveOutstandingJob(job); + + // Write result to the cache. + cache_.Set(job->host(), error, addrlist, base::TimeTicks::Now()); + + // Make a note that we are executing within OnJobComplete() in case the + // HostResolver is deleted by a callback invocation. + DCHECK(!cur_completing_job_); + cur_completing_job_ = job; + + // Complete all of the requests that were attached to the job. + for (RequestsList::const_iterator it = job->requests().begin(); + it != job->requests().end(); ++it) { + Request* req = *it; + if (!req->was_cancelled()) { + DCHECK_EQ(job, req->job()); + req->OnComplete(error, addrlist); + + // Check if the job was cancelled as a result of running the callback. + // (Meaning that |this| was deleted). + if (job->was_cancelled()) + return; + } + } + + cur_completing_job_ = NULL; +} + +//----------------------------------------------------------------------------- + +SingleRequestHostResolver::SingleRequestHostResolver(HostResolver* resolver) + : resolver_(resolver), + cur_request_(NULL), + cur_request_callback_(NULL), + ALLOW_THIS_IN_INITIALIZER_LIST( + callback_(this, &SingleRequestHostResolver::OnResolveCompletion)) { + DCHECK(resolver_ != NULL); +} + +SingleRequestHostResolver::~SingleRequestHostResolver() { + if (cur_request_) { + resolver_->CancelRequest(cur_request_); + } +} + +int SingleRequestHostResolver::Resolve( + const std::string& hostname, int port, + AddressList* addresses, + CompletionCallback* callback) { + DCHECK(!cur_request_ && !cur_request_callback_) << "resolver already in use"; + + HostResolver::Request* request = NULL; + + // We need to be notified of completion before |callback| is called, so that + // we can clear out |cur_request_*|. + CompletionCallback* transient_callback = callback ? &callback_ : NULL; + + int rv = resolver_->Resolve( + hostname, port, addresses, transient_callback, &request); + + if (rv == ERR_IO_PENDING) { + // Cleared in OnResolveCompletion(). + cur_request_ = request; + cur_request_callback_ = callback; + } + + return rv; +} + +void SingleRequestHostResolver::OnResolveCompletion(int result) { + DCHECK(cur_request_ && cur_request_callback_); + + CompletionCallback* callback = cur_request_callback_; + + // Clear the outstanding request information. + cur_request_ = NULL; + cur_request_callback_ = NULL; + + // Call the user's original callback. + callback->Run(result); +} + } // namespace net diff --git a/net/base/host_resolver.h b/net/base/host_resolver.h index c730f76..a67419a 100644 --- a/net/base/host_resolver.h +++ b/net/base/host_resolver.h @@ -6,34 +6,73 @@ #define NET_BASE_HOST_RESOLVER_H_ #include <string> +#include <vector> #include "base/basictypes.h" +#include "base/lock.h" #include "base/ref_counted.h" #include "net/base/completion_callback.h" +#include "net/base/host_cache.h" + +class MessageLoop; namespace net { class AddressList; +class HostMapper; -// This class represents the task of resolving a hostname (or IP address -// literal) to an AddressList object. It can only resolve a single hostname at -// a time, so if you need to resolve multiple hostnames at the same time, you -// will need to allocate a HostResolver object for each hostname. +// This class represents the task of resolving hostnames (or IP address +// literal) to an AddressList object. +// +// HostResolver handles multiple requests at a time, so when cancelling a +// request the Request* handle that was returned by Resolve() needs to be +// given. A simpler alternative for consumers that only have 1 outstanding +// request at a time is to create a SingleRequestHostResolver wrapper around +// HostResolver (which will automatically cancel the single request when it +// goes out of scope). +// +// For each hostname that is requested, HostResolver creates a +// HostResolver::Job. This job gets dispatched to a thread in the global +// WorkerPool, where it runs "getaddrinfo(hostname)". If requests for that same +// host are made while the job is already outstanding, then they are attached +// to the existing job rather than creating a new one. This avoids doing +// parallel resolves for the same host. +// +// The way these classes fit together is illustrated by: +// // -// No attempt is made at this level to cache or pin resolution results. For -// each request, this API talks directly to the underlying name resolver of -// the local system, which may or may not result in a DNS query. The exact -// behavior depends on the system configuration. +// +------------- HostResolver ---------------+ +// | | | +// Job Job Job +// (for host1) (for host2) (for hostX) +// / | | / | | / | | +// Request ... Request Request ... Request Request ... Request +// (port1) (port2) (port3) (port4) (port5) (portX) +// +// +// When a HostResolver::Job finishes its work in the threadpool, the callbacks +// of each waiting request are run on the origin thread. +// +// Thread safety: This class is not threadsafe, and must only be called +// from one thread! // class HostResolver { public: - HostResolver(); + // Creates a HostResolver that caches up to |max_cache_entries| for + // |cache_duration_ms| milliseconds. + // + // TODO(eroman): Get rid of the default parameters as it violate google + // style. This is temporary to help with refactoring. + HostResolver(int max_cache_entries = 100, int cache_duration_ms = 60000); - // If a completion callback is pending when the resolver is destroyed, the - // host resolution is cancelled, and the completion callback will not be - // called. + // If any completion callbacks are pending when the resolver is destroyed, + // the host resolutions are cancelled, and the completion callbacks will not + // be called. ~HostResolver(); + // Opaque type used to cancel a request. + class Request; + // Resolves the given hostname (or IP address literal), filling out the // |addresses| object upon success. The |port| parameter will be set as the // sin(6)_port field of the sockaddr_in{6} struct. Returns OK if successful @@ -43,17 +82,82 @@ class HostResolver { // // When callback is non-null, the operation will be performed asynchronously. // ERR_IO_PENDING is returned if it has been scheduled successfully. Real - // result code will be passed to the completion callback. + // result code will be passed to the completion callback. If |req| is + // non-NULL, then |*req| will be filled with a handle to the async request. + // This handle is not valid after the request has completed. int Resolve(const std::string& hostname, int port, - AddressList* addresses, CompletionCallback* callback); + AddressList* addresses, CompletionCallback* callback, + Request** req); + + // Cancels the specified request. |req| is the handle returned by Resolve(). + // After a request is cancelled, its completion callback will not be called. + void CancelRequest(Request* req); private: - class Request; - friend class Request; - scoped_refptr<Request> request_; + class Job; + typedef std::vector<Request*> RequestsList; + typedef base::hash_map<std::string, scoped_refptr<Job> > JobMap; + + // Adds a job to outstanding jobs list. + void AddOutstandingJob(Job* job); + + // Returns the outstanding job for |hostname|, or NULL if there is none. + Job* FindOutstandingJob(const std::string& hostname); + + // Removes |job| from the outstanding jobs list. + void RemoveOutstandingJob(Job* job); + + // Callback for when |job| has completed with |error| and |addrlist|. + void OnJobComplete(Job* job, int error, const AddressList& addrlist); + + // Cache of host resolution results. + HostCache cache_; + + // Map from hostname to outstanding job. + JobMap jobs_; + + // The job that OnJobComplete() is currently processing (needed in case + // HostResolver gets deleted from within the callback). + scoped_refptr<Job> cur_completing_job_; + DISALLOW_COPY_AND_ASSIGN(HostResolver); }; +// This class represents the task of resolving a hostname (or IP address +// literal) to an AddressList object. It wraps HostResolver to resolve only a +// single hostname at a time and cancels this request when going out of scope. +class SingleRequestHostResolver { + public: + explicit SingleRequestHostResolver(HostResolver* resolver); + + // If a completion callback is pending when the resolver is destroyed, the + // host resolution is cancelled, and the completion callback will not be + // called. + ~SingleRequestHostResolver(); + + // Resolves the given hostname (or IP address literal), filling out the + // |addresses| object upon success. See HostResolver::Resolve() for details. + int Resolve(const std::string& hostname, int port, + AddressList* addresses, CompletionCallback* callback); + + private: + // Callback for when the request to |resolver_| completes, so we dispatch + // to the user's callback. + void OnResolveCompletion(int result); + + // The actual host resolver that will handle the request. + HostResolver* resolver_; + + // The current request (if any). + HostResolver::Request* cur_request_; + CompletionCallback* cur_request_callback_; + + // Completion callback for when request to |resolver_| completes. + net::CompletionCallbackImpl<SingleRequestHostResolver> callback_; + + DISALLOW_COPY_AND_ASSIGN(SingleRequestHostResolver); +}; + // A helper class used in unit tests to alter hostname mappings. See // SetHostMapper for details. class HostMapper : public base::RefCountedThreadSafe<HostMapper> { diff --git a/net/base/host_resolver_unittest.cc b/net/base/host_resolver_unittest.cc index d24d28e..090fd5c 100644 --- a/net/base/host_resolver_unittest.cc +++ b/net/base/host_resolver_unittest.cc @@ -26,8 +26,118 @@ using net::RuleBasedHostMapper; using net::ScopedHostMapper; using net::WaitingHostMapper; +// TODO(eroman): +// - Test mixing async with sync (in particular how does sync update the +// cache while an async is already pending). + namespace { +// A variant of WaitingHostMapper that pushes each host mapped into a list. +// (and uses a manual-reset event rather than auto-reset). +class CapturingHostMapper : public net::HostMapper { + public: + CapturingHostMapper() : event_(true, false) { + } + + void Signal() { + event_.Signal(); + } + + virtual std::string Map(const std::string& host) { + event_.Wait(); + { + AutoLock l(lock_); + capture_list_.push_back(host); + } + return MapUsingPrevious(host); + } + + std::vector<std::string> GetCaptureList() const { + std::vector<std::string> copy; + { + AutoLock l(lock_); + copy = capture_list_; + } + return copy; + } + + private: + std::vector<std::string> capture_list_; + mutable Lock lock_; + base::WaitableEvent event_; +}; + +// Helper that represents a single Resolve() result, used to inspect all the +// resolve results by forwarding them to Delegate. +class ResolveRequest { + public: + // Delegate interface, for notification when the ResolveRequest completes. + class Delegate { + public: + virtual ~Delegate() {} + virtual void OnCompleted(ResolveRequest* resolve) = 0; + }; + + ResolveRequest(net::HostResolver* resolver, + const std::string& hostname, + int port, + Delegate* delegate) + : hostname_(hostname), port_(port), resolver_(resolver), + delegate_(delegate), + ALLOW_THIS_IN_INITIALIZER_LIST( + callback_(this, &ResolveRequest::OnLookupFinished)) { + // Start the request. + int err = resolver->Resolve(hostname, port, &addrlist_, &callback_, &req_); + EXPECT_EQ(net::ERR_IO_PENDING, err); + } + + void Cancel() { + resolver_->CancelRequest(req_); + } + + const std::string& hostname() const { + return hostname_; + } + + int port() const { + return port_; + } + + int result() const { + return result_; + } + + const net::AddressList& addrlist() const { + return addrlist_; + } + + net::HostResolver* resolver() const { + return resolver_; + } + + private: + void OnLookupFinished(int result) { + result_ = result; + delegate_->OnCompleted(this); + } + + // The request details. + std::string hostname_; + int port_; + net::HostResolver::Request* req_; + + // The result of the resolve. + int result_; + net::AddressList addrlist_; + + net::HostResolver* resolver_; + + Delegate* delegate_; + net::CompletionCallbackImpl<ResolveRequest> callback_; + + DISALLOW_COPY_AND_ASSIGN(ResolveRequest); +}; + class HostResolverTest : public testing::Test { public: HostResolverTest() @@ -58,7 +168,8 @@ TEST_F(HostResolverTest, SynchronousLookup) { mapper->AddRule("just.testing", "192.168.1.42"); ScopedHostMapper scoped_mapper(mapper.get()); - int err = host_resolver.Resolve("just.testing", kPortnum, &adrlist, NULL); + int err = host_resolver.Resolve("just.testing", kPortnum, &adrlist, NULL, + NULL); EXPECT_EQ(net::OK, err); const struct addrinfo* ainfo = adrlist.head(); @@ -81,7 +192,7 @@ TEST_F(HostResolverTest, AsynchronousLookup) { ScopedHostMapper scoped_mapper(mapper.get()); int err = host_resolver.Resolve("just.testing", kPortnum, &adrlist, - &callback_); + &callback_, NULL); EXPECT_EQ(net::ERR_IO_PENDING, err); MessageLoop::current()->Run(); @@ -109,7 +220,7 @@ TEST_F(HostResolverTest, CanceledAsynchronousLookup) { const int kPortnum = 80; int err = host_resolver.Resolve("just.testing", kPortnum, &adrlist, - &callback_); + &callback_, NULL); EXPECT_EQ(net::ERR_IO_PENDING, err); // Make sure we will exit the queue even when callback is not called. @@ -134,7 +245,7 @@ TEST_F(HostResolverTest, NumericIPv4Address) { net::HostResolver host_resolver; net::AddressList adrlist; const int kPortnum = 5555; - int err = host_resolver.Resolve("127.1.2.3", kPortnum, &adrlist, NULL); + int err = host_resolver.Resolve("127.1.2.3", kPortnum, &adrlist, NULL, NULL); EXPECT_EQ(net::OK, err); const struct addrinfo* ainfo = adrlist.head(); @@ -157,7 +268,8 @@ TEST_F(HostResolverTest, NumericIPv6Address) { net::HostResolver host_resolver; net::AddressList adrlist; const int kPortnum = 5555; - int err = host_resolver.Resolve("2001:db8::1", kPortnum, &adrlist, NULL); + int err = host_resolver.Resolve("2001:db8::1", kPortnum, &adrlist, NULL, + NULL); // On computers without IPv6 support, getaddrinfo cannot convert IPv6 // address literals to addresses (getaddrinfo returns EAI_NONAME). So this // test has to allow host_resolver.Resolve to fail. @@ -190,8 +302,320 @@ TEST_F(HostResolverTest, EmptyHost) { net::HostResolver host_resolver; net::AddressList adrlist; const int kPortnum = 5555; - int err = host_resolver.Resolve("", kPortnum, &adrlist, NULL); + int err = host_resolver.Resolve("", kPortnum, &adrlist, NULL, NULL); EXPECT_EQ(net::ERR_NAME_NOT_RESOLVED, err); } +// Helper class used by HostResolverTest.DeDupeRequests. It receives request +// completion notifications for all the resolves, so it can tally up and +// determine when we are done. +class DeDupeRequestsVerifier : public ResolveRequest::Delegate { + public: + explicit DeDupeRequestsVerifier(CapturingHostMapper* mapper) + : count_a_(0), count_b_(0), mapper_(mapper) {} + + // The test does 5 resolves (which can complete in any order). + virtual void OnCompleted(ResolveRequest* resolve) { + // Tally up how many requests we have seen. + if (resolve->hostname() == "a") { + count_a_++; + } else if (resolve->hostname() == "b") { + count_b_++; + } else { + FAIL() << "Unexpected hostname: " << resolve->hostname(); + } + + // Check that the port was set correctly. + EXPECT_EQ(resolve->port(), resolve->addrlist().GetPort()); + + // Check whether all the requests have finished yet. + int total_completions = count_a_ + count_b_; + if (total_completions == 5) { + EXPECT_EQ(2, count_a_); + EXPECT_EQ(3, count_b_); + + // The mapper should have been called only twice -- once with "a", once + // with "b". + std::vector<std::string> capture_list = mapper_->GetCaptureList(); + EXPECT_EQ(2U, capture_list.size()); + + // End this test, we are done. + MessageLoop::current()->Quit(); + } + } + + private: + int count_a_; + int count_b_; + CapturingHostMapper* mapper_; + + DISALLOW_COPY_AND_ASSIGN(DeDupeRequestsVerifier); +}; + +TEST_F(HostResolverTest, DeDupeRequests) { + // Use a capturing mapper, since the verifier needs to know what calls + // reached Map(). Also, the capturing mapper is initially blocked. + scoped_refptr<CapturingHostMapper> mapper = new CapturingHostMapper(); + ScopedHostMapper scoped_mapper(mapper.get()); + + net::HostResolver host_resolver; + + // The class will receive callbacks for when each resolve completes. It + // checks that the right things happened. + DeDupeRequestsVerifier verifier(mapper.get()); + + // Start 5 requests, duplicating hosts "a" and "b". Since the mapper is + // blocked, these should all pile up until we signal it. + + ResolveRequest req1(&host_resolver, "a", 80, &verifier); + ResolveRequest req2(&host_resolver, "b", 80, &verifier); + ResolveRequest req3(&host_resolver, "b", 81, &verifier); + ResolveRequest req4(&host_resolver, "a", 82, &verifier); + ResolveRequest req5(&host_resolver, "b", 83, &verifier); + + // Ready, Set, GO!!! + mapper->Signal(); + + // |verifier| will send quit message once all the requests have finished. + MessageLoop::current()->Run(); +} + +// Helper class used by HostResolverTest.CancelMultipleRequests. +class CancelMultipleRequestsVerifier : public ResolveRequest::Delegate { + public: + CancelMultipleRequestsVerifier() {} + + // The cancels kill all but one request. + virtual void OnCompleted(ResolveRequest* resolve) { + EXPECT_EQ("a", resolve->hostname()); + EXPECT_EQ(82, resolve->port()); + + // Check that the port was set correctly. + EXPECT_EQ(resolve->port(), resolve->addrlist().GetPort()); + + // End this test, we are done. + MessageLoop::current()->Quit(); + } + + private: + DISALLOW_COPY_AND_ASSIGN(CancelMultipleRequestsVerifier); +}; + +TEST_F(HostResolverTest, CancelMultipleRequests) { + // Use a capturing mapper, since the verifier needs to know what calls + // reached Map(). Also, the capturing mapper is initially blocked. + scoped_refptr<CapturingHostMapper> mapper = new CapturingHostMapper(); + ScopedHostMapper scoped_mapper(mapper.get()); + + net::HostResolver host_resolver; + + // The class will receive callbacks for when each resolve completes. It + // checks that the right things happened. + CancelMultipleRequestsVerifier verifier; + + // Start 5 requests, duplicating hosts "a" and "b". Since the mapper is + // blocked, these should all pile up until we signal it. + + ResolveRequest req1(&host_resolver, "a", 80, &verifier); + ResolveRequest req2(&host_resolver, "b", 80, &verifier); + ResolveRequest req3(&host_resolver, "b", 81, &verifier); + ResolveRequest req4(&host_resolver, "a", 82, &verifier); + ResolveRequest req5(&host_resolver, "b", 83, &verifier); + + // Cancel everything except request 4. + req1.Cancel(); + req2.Cancel(); + req3.Cancel(); + req5.Cancel(); + + // Ready, Set, GO!!! + mapper->Signal(); + + // |verifier| will send quit message once all the requests have finished. + MessageLoop::current()->Run(); +} + +// Helper class used by HostResolverTest.CancelWithinCallback. +class CancelWithinCallbackVerifier : public ResolveRequest::Delegate { + public: + CancelWithinCallbackVerifier() + : req_to_cancel1_(NULL), req_to_cancel2_(NULL), num_completions_(0) { + } + + virtual void OnCompleted(ResolveRequest* resolve) { + num_completions_++; + + // Port 80 is the first request that the callback will be invoked for. + // While we are executing within that callback, cancel the other requests + // in the job and start another request. + if (80 == resolve->port()) { + EXPECT_EQ("a", resolve->hostname()); + + req_to_cancel1_->Cancel(); + req_to_cancel2_->Cancel(); + + // Start a request (so we can make sure the canceled requests don't + // complete before "finalrequest" finishes. + final_request_.reset(new ResolveRequest( + resolve->resolver(), "finalrequest", 70, this)); + + } else if (83 == resolve->port()) { + EXPECT_EQ("a", resolve->hostname()); + } else if (resolve->hostname() == "finalrequest") { + EXPECT_EQ(70, resolve->addrlist().GetPort()); + + // End this test, we are done. + MessageLoop::current()->Quit(); + } else { + FAIL() << "Unexpected completion: " << resolve->hostname() << ", " + << resolve->port(); + } + } + + void SetRequestsToCancel(ResolveRequest* req_to_cancel1, + ResolveRequest* req_to_cancel2) { + req_to_cancel1_ = req_to_cancel1; + req_to_cancel2_ = req_to_cancel2; + } + + private: + scoped_ptr<ResolveRequest> final_request_; + ResolveRequest* req_to_cancel1_; + ResolveRequest* req_to_cancel2_; + int num_completions_; + DISALLOW_COPY_AND_ASSIGN(CancelWithinCallbackVerifier); +}; + +TEST_F(HostResolverTest, CancelWithinCallback) { + // Use a capturing mapper, since the verifier needs to know what calls + // reached Map(). Also, the capturing mapper is initially blocked. + scoped_refptr<CapturingHostMapper> mapper = new CapturingHostMapper(); + ScopedHostMapper scoped_mapper(mapper.get()); + + net::HostResolver host_resolver; + + // The class will receive callbacks for when each resolve completes. It + // checks that the right things happened. + CancelWithinCallbackVerifier verifier; + + // Start 4 requests, duplicating hosts "a". Since the mapper is + // blocked, these should all pile up until we signal it. + + ResolveRequest req1(&host_resolver, "a", 80, &verifier); + ResolveRequest req2(&host_resolver, "a", 81, &verifier); + ResolveRequest req3(&host_resolver, "a", 82, &verifier); + ResolveRequest req4(&host_resolver, "a", 83, &verifier); + + // Once "a:80" completes, it will cancel "a:81" and "a:82". + verifier.SetRequestsToCancel(&req2, &req3); + + // Ready, Set, GO!!! + mapper->Signal(); + + // |verifier| will send quit message once all the requests have finished. + MessageLoop::current()->Run(); +} + +// Helper class used by HostResolverTest.DeleteWithinCallback. +class DeleteWithinCallbackVerifier : public ResolveRequest::Delegate { + public: + DeleteWithinCallbackVerifier() {} + + virtual void OnCompleted(ResolveRequest* resolve) { + EXPECT_EQ("a", resolve->hostname()); + EXPECT_EQ(80, resolve->port()); + delete resolve->resolver(); + + // Quit after returning from OnCompleted (to give it a chance at + // incorrectly running the cancelled tasks). + MessageLoop::current()->PostTask(FROM_HERE, new MessageLoop::QuitTask()); + } + + private: + DISALLOW_COPY_AND_ASSIGN(DeleteWithinCallbackVerifier); +}; + +TEST_F(HostResolverTest, DeleteWithinCallback) { + // Use a capturing mapper, since the verifier needs to know what calls + // reached Map(). Also, the capturing mapper is initially blocked. + scoped_refptr<CapturingHostMapper> mapper = new CapturingHostMapper(); + ScopedHostMapper scoped_mapper(mapper.get()); + + // This should be deleted by DeleteWithinCallbackVerifier -- if it leaks + // then the test has failed. + net::HostResolver* host_resolver = new net::HostResolver; + + // The class will receive callbacks for when each resolve completes. It + // checks that the right things happened. + DeleteWithinCallbackVerifier verifier; + + // Start 4 requests, duplicating hosts "a". Since the mapper is + // blocked, these should all pile up until we signal it. + + ResolveRequest req1(host_resolver, "a", 80, &verifier); + ResolveRequest req2(host_resolver, "a", 81, &verifier); + ResolveRequest req3(host_resolver, "a", 82, &verifier); + ResolveRequest req4(host_resolver, "a", 83, &verifier); + + // Ready, Set, GO!!! + mapper->Signal(); + + // |verifier| will send quit message once all the requests have finished. + MessageLoop::current()->Run(); +} + +// Helper class used by HostResolverTest.StartWithinCallback. +class StartWithinCallbackVerifier : public ResolveRequest::Delegate { + public: + StartWithinCallbackVerifier() : num_requests_(0) {} + + virtual void OnCompleted(ResolveRequest* resolve) { + EXPECT_EQ("a", resolve->hostname()); + + if (80 == resolve->port()) { + // On completing the first request, start another request for "a". + // Since caching is disabled, this will result in another async request. + final_request_.reset(new ResolveRequest( + resolve->resolver(), "a", 70, this)); + } + if (++num_requests_ == 5) { + // Test is done. + MessageLoop::current()->Quit(); + } + } + + private: + int num_requests_; + scoped_ptr<ResolveRequest> final_request_; + DISALLOW_COPY_AND_ASSIGN(StartWithinCallbackVerifier); +}; + +TEST_F(HostResolverTest, StartWithinCallback) { + // Use a capturing mapper, since the verifier needs to know what calls + // reached Map(). Also, the capturing mapper is initially blocked. + scoped_refptr<CapturingHostMapper> mapper = new CapturingHostMapper(); + ScopedHostMapper scoped_mapper(mapper.get()); + + // Turn off caching for this host resolver. + net::HostResolver host_resolver(0, 0); + + // The class will receive callbacks for when each resolve completes. It + // checks that the right things happened. + StartWithinCallbackVerifier verifier; + + // Start 4 requests, duplicating hosts "a". Since the mapper is + // blocked, these should all pile up until we signal it. + + ResolveRequest req1(&host_resolver, "a", 80, &verifier); + ResolveRequest req2(&host_resolver, "a", 81, &verifier); + ResolveRequest req3(&host_resolver, "a", 82, &verifier); + ResolveRequest req4(&host_resolver, "a", 83, &verifier); + + // Ready, Set, GO!!! + mapper->Signal(); + + // |verifier| will send quit message once all the requests have finished. + MessageLoop::current()->Run(); +} + } // namespace diff --git a/net/base/ssl_client_socket_unittest.cc b/net/base/ssl_client_socket_unittest.cc index d372e95..ab29cc4 100644 --- a/net/base/ssl_client_socket_unittest.cc +++ b/net/base/ssl_client_socket_unittest.cc @@ -77,7 +77,7 @@ TEST_F(SSLClientSocketTest, MAYBE_Connect) { TestCompletionCallback callback; int rv = resolver.Resolve(server_.kHostName, server_.kOKHTTPSPort, - &addr, NULL); + &addr, NULL, NULL); EXPECT_EQ(net::OK, rv); net::ClientSocket *transport = new net::TCPClientSocket(addr); @@ -115,7 +115,7 @@ TEST_F(SSLClientSocketTest, MAYBE_ConnectExpired) { TestCompletionCallback callback; int rv = resolver.Resolve(server_.kHostName, server_.kBadHTTPSPort, - &addr, NULL); + &addr, NULL, NULL); EXPECT_EQ(net::OK, rv); net::ClientSocket *transport = new net::TCPClientSocket(addr); @@ -152,7 +152,7 @@ TEST_F(SSLClientSocketTest, MAYBE_ConnectMismatched) { TestCompletionCallback callback; int rv = resolver.Resolve(server_.kMismatchedHostName, server_.kOKHTTPSPort, - &addr, NULL); + &addr, NULL, NULL); EXPECT_EQ(net::OK, rv); net::ClientSocket *transport = new net::TCPClientSocket(addr); @@ -194,7 +194,7 @@ TEST_F(SSLClientSocketTest, MAYBE_Read) { TestCompletionCallback callback; int rv = resolver.Resolve(server_.kHostName, server_.kOKHTTPSPort, - &addr, &callback); + &addr, &callback, NULL); EXPECT_EQ(net::ERR_IO_PENDING, rv); rv = callback.WaitForResult(); @@ -255,7 +255,7 @@ TEST_F(SSLClientSocketTest, MAYBE_Read_SmallChunks) { TestCompletionCallback callback; int rv = resolver.Resolve(server_.kHostName, server_.kOKHTTPSPort, - &addr, NULL); + &addr, NULL, NULL); EXPECT_EQ(net::OK, rv); net::ClientSocket *transport = new net::TCPClientSocket(addr); @@ -311,7 +311,7 @@ TEST_F(SSLClientSocketTest, MAYBE_Read_Interrupted) { TestCompletionCallback callback; int rv = resolver.Resolve(server_.kHostName, server_.kOKHTTPSPort, - &addr, NULL); + &addr, NULL, NULL); EXPECT_EQ(net::OK, rv); net::ClientSocket *transport = new net::TCPClientSocket(addr); diff --git a/net/base/ssl_test_util.cc b/net/base/ssl_test_util.cc index a3fe3b9..cac5d04 100644 --- a/net/base/ssl_test_util.cc +++ b/net/base/ssl_test_util.cc @@ -247,7 +247,7 @@ bool TestServerLauncher::WaitToStart(const std::string& host_name, int port) { // Otherwise tests can fail if they run faster than Python can start. net::AddressList addr; net::HostResolver resolver; - int rv = resolver.Resolve(host_name, port, &addr, NULL); + int rv = resolver.Resolve(host_name, port, &addr, NULL, NULL); if (rv != net::OK) return false; diff --git a/net/base/tcp_client_socket_pool.cc b/net/base/tcp_client_socket_pool.cc index 75875f2..7f7f13f 100644 --- a/net/base/tcp_client_socket_pool.cc +++ b/net/base/tcp_client_socket_pool.cc @@ -46,6 +46,7 @@ TCPClientSocketPool::ConnectingSocket::ConnectingSocket( callback_(this, &TCPClientSocketPool::ConnectingSocket::OnIOComplete)), pool_(pool), + resolver_(pool->GetHostResolver()), canceled_(false) { DCHECK(!ContainsKey(pool_->connecting_socket_map_, handle)); pool_->connecting_socket_map_[handle] = this; @@ -158,10 +159,12 @@ void TCPClientSocketPool::ConnectingSocket::Cancel() { TCPClientSocketPool::TCPClientSocketPool( int max_sockets_per_group, + HostResolver* host_resolver, ClientSocketFactory* client_socket_factory) : client_socket_factory_(client_socket_factory), idle_socket_count_(0), - max_sockets_per_group_(max_sockets_per_group) { + max_sockets_per_group_(max_sockets_per_group), + host_resolver_(host_resolver) { } TCPClientSocketPool::~TCPClientSocketPool() { diff --git a/net/base/tcp_client_socket_pool.h b/net/base/tcp_client_socket_pool.h index 8d4a3bd..82255c5 100644 --- a/net/base/tcp_client_socket_pool.h +++ b/net/base/tcp_client_socket_pool.h @@ -25,6 +25,7 @@ class ClientSocketFactory; class TCPClientSocketPool : public ClientSocketPool { public: TCPClientSocketPool(int max_sockets_per_group, + HostResolver* host_resolver, ClientSocketFactory* client_socket_factory); // ClientSocketPool methods: @@ -44,6 +45,10 @@ class TCPClientSocketPool : public ClientSocketPool { virtual void CloseIdleSockets(); + virtual HostResolver* GetHostResolver() const { + return host_resolver_; + } + virtual int idle_socket_count() const { return idle_socket_count_; } @@ -137,7 +142,7 @@ class TCPClientSocketPool : public ClientSocketPool { CompletionCallbackImpl<ConnectingSocket> callback_; scoped_ptr<ClientSocket> socket_; scoped_refptr<TCPClientSocketPool> pool_; - HostResolver resolver_; + SingleRequestHostResolver resolver_; AddressList addresses_; bool canceled_; @@ -185,6 +190,10 @@ class TCPClientSocketPool : public ClientSocketPool { // The maximum number of sockets kept per group. const int max_sockets_per_group_; + // The host resolver that will be used to do DNS lookups for connecting + // sockets. + HostResolver* host_resolver_; + DISALLOW_COPY_AND_ASSIGN(TCPClientSocketPool); }; diff --git a/net/base/tcp_client_socket_pool_unittest.cc b/net/base/tcp_client_socket_pool_unittest.cc index ce85871..2073497 100644 --- a/net/base/tcp_client_socket_pool_unittest.cc +++ b/net/base/tcp_client_socket_pool_unittest.cc @@ -198,13 +198,19 @@ int TestSocketRequest::completion_count = 0; class TCPClientSocketPoolTest : public testing::Test { protected: TCPClientSocketPoolTest() - : pool_(new TCPClientSocketPool(kMaxSocketsPerGroup, + // We disable caching here since these unit tests don't expect + // host resolving to be able to complete synchronously. + // TODO(eroman): enable caching. + : host_resolver_(0, 0), + pool_(new TCPClientSocketPool(kMaxSocketsPerGroup, + &host_resolver_, &client_socket_factory_)) {} virtual void SetUp() { TestSocketRequest::completion_count = 0; } + HostResolver host_resolver_; MockClientSocketFactory client_socket_factory_; scoped_refptr<ClientSocketPool> pool_; std::vector<TestSocketRequest*> request_order_; diff --git a/net/base/tcp_client_socket_unittest.cc b/net/base/tcp_client_socket_unittest.cc index 9dad6e2..fe37c71 100644 --- a/net/base/tcp_client_socket_unittest.cc +++ b/net/base/tcp_client_socket_unittest.cc @@ -87,7 +87,7 @@ void TCPClientSocketTest::SetUp() { AddressList addr; HostResolver resolver; - int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL); + int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL, NULL); CHECK(rv == OK); sock_.reset(new TCPClientSocket(addr)); } diff --git a/net/base/tcp_pinger_unittest.cc b/net/base/tcp_pinger_unittest.cc index 2ae6b1f..edf29a5 100644 --- a/net/base/tcp_pinger_unittest.cc +++ b/net/base/tcp_pinger_unittest.cc @@ -67,7 +67,7 @@ TEST_F(TCPPingerTest, Ping) { net::AddressList addr; net::HostResolver resolver; - int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL); + int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL, NULL); EXPECT_EQ(rv, net::OK); net::TCPPinger pinger(addr); @@ -82,7 +82,7 @@ TEST_F(TCPPingerTest, PingFail) { // "Kill" "server" listen_sock_ = NULL; - int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL); + int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL, NULL); EXPECT_EQ(rv, net::OK); net::TCPPinger pinger(addr); |