diff options
author | ericroman@google.com <ericroman@google.com@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-06-12 00:49:38 +0000 |
---|---|---|
committer | ericroman@google.com <ericroman@google.com@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-06-12 00:49:38 +0000 |
commit | 8a00f00ab5d68ffcc998fd04d2ca343af7cdf190 (patch) | |
tree | fd464ba49db4271c76c1cf8f769a22120ad631af /net | |
parent | 77ae132c1bfdd986228b6f1c0d8c63baa441afdf (diff) | |
download | chromium_src-8a00f00ab5d68ffcc998fd04d2ca343af7cdf190.zip chromium_src-8a00f00ab5d68ffcc998fd04d2ca343af7cdf190.tar.gz chromium_src-8a00f00ab5d68ffcc998fd04d2ca343af7cdf190.tar.bz2 |
* Avoid doing concurrent DNS resolves of the same hostname in HostResolver.
* Add a 1 minute cache for host resolves.
* Refactor HostResolver to handle multiple requests.
* Make HostResolver a dependency of URLRequestContext. operate the HostResolver
in async mode for proxy resolver (bridging to IO thread).
TEST=unittests
BUG=13163
Review URL: http://codereview.chromium.org/118100
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@18236 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
42 files changed, 1777 insertions, 179 deletions
diff --git a/net/base/address_list.cc b/net/base/address_list.cc index 1cd2f62..d92d280 100644 --- a/net/base/address_list.cc +++ b/net/base/address_list.cc @@ -11,14 +11,128 @@ #include <netdb.h> #endif +#include "base/logging.h" + namespace net { +namespace { + +// Make a deep copy of |info|. This copy should be deleted using +// DeleteCopyOfAddrinfo(), and NOT freeaddrinfo(). +struct addrinfo* CreateCopyOfAddrinfo(const struct addrinfo* info) { + struct addrinfo* copy = new struct addrinfo; + + // Copy all the fields (some of these are pointers, we will fix that next). + memcpy(copy, info, sizeof(addrinfo)); + + // ai_canonname is a NULL-terminated string. + if (info->ai_canonname) { +#ifdef OS_WIN + copy->ai_canonname = _strdup(info->ai_canonname); +#else + copy->ai_canonname = strdup(info->ai_canonname); +#endif + } + + // ai_addr is a buffer of length ai_addrlen. + if (info->ai_addr) { + copy->ai_addr = reinterpret_cast<sockaddr *>(new char[info->ai_addrlen]); + memcpy(copy->ai_addr, info->ai_addr, info->ai_addrlen); + } + + // Recursive copy. + if (info->ai_next) + copy->ai_next = CreateCopyOfAddrinfo(info->ai_next); + + return copy; +} + +// Free an addrinfo that was created by CreateCopyOfAddrinfo(). +void FreeMyAddrinfo(struct addrinfo* info) { + if (info->ai_canonname) + free(info->ai_canonname); // Allocated by strdup. + + if (info->ai_addr) + delete [] reinterpret_cast<char*>(info->ai_addr); + + struct addrinfo* next = info->ai_next; + + delete info; + + // Recursive free. + if (next) + FreeMyAddrinfo(next); +} + +// Returns the address to port field in |info|. +uint16* GetPortField(const struct addrinfo* info) { + if (info->ai_family == AF_INET) { + DCHECK_EQ(sizeof(sockaddr_in), info->ai_addrlen); + struct sockaddr_in* sockaddr = + reinterpret_cast<struct sockaddr_in*>(info->ai_addr); + return &sockaddr->sin_port; + } else if (info->ai_family == AF_INET6) { + DCHECK_EQ(sizeof(sockaddr_in6), info->ai_addrlen); + struct sockaddr_in6* sockaddr = + reinterpret_cast<struct sockaddr_in6*>(info->ai_addr); + return &sockaddr->sin6_port; + } else { + NOTREACHED(); + return NULL; + } +} + +// Assign the port for all addresses in the list. +void SetPortRecursive(struct addrinfo* info, int port) { + uint16* port_field = GetPortField(info); + *port_field = htons(port); + + // Assign recursively. + if (info->ai_next) + SetPortRecursive(info->ai_next, port); +} + +} // namespace + void AddressList::Adopt(struct addrinfo* head) { - data_ = new Data(head); + data_ = new Data(head, true /*is_system_created*/); +} + +void AddressList::Copy(const struct addrinfo* head) { + data_ = new Data(CreateCopyOfAddrinfo(head), false /*is_system_created*/); +} + +void AddressList::SetPort(int port) { + SetPortRecursive(data_->head, port); +} + +int AddressList::GetPort() const { + uint16* port_field = GetPortField(data_->head); + return ntohs(*port_field); +} + +void AddressList::SetFrom(const AddressList& src, int port) { + if (src.GetPort() == port) { + // We can reference the data from |src| directly. + *this = src; + } else { + // Otherwise we need to make a copy in order to change the port number. + Copy(src.head()); + SetPort(port); + } +} + +void AddressList::Reset() { + data_ = NULL; } AddressList::Data::~Data() { - freeaddrinfo(head); + // Call either freeaddrinfo(head), or FreeMyAddrinfo(head), depending who + // created the data. + if (is_system_created) + freeaddrinfo(head); + else + FreeMyAddrinfo(head); } } // namespace net diff --git a/net/base/address_list.h b/net/base/address_list.h index 5bffd4c..506350b 100644 --- a/net/base/address_list.h +++ b/net/base/address_list.h @@ -20,16 +20,39 @@ class AddressList { // object. void Adopt(struct addrinfo* head); + // Copies the given addrinfo rather than adopting it. + void Copy(const struct addrinfo* head); + + // Sets the port of all addresses in the list to |port| (that is the + // sin[6]_port field for the sockaddrs). + void SetPort(int port); + + // Retrieves the port number of the first sockaddr in the list. (If SetPort() + // was previously used on this list, then all the addresses will have this + // same port number.) + int GetPort() const; + + // Sets the address to match |src|, and have each sockaddr's port be |port|. + // If |src| already has the desired port this operation is cheap (just adds + // a reference to |src|'s data.) Otherwise we will make a copy. + void SetFrom(const AddressList& src, int port); + + // Clears all data from this address list. This leaves the list in the same + // empty state as when first constructed. + void Reset(); + // Get access to the head of the addrinfo list. const struct addrinfo* head() const { return data_->head; } private: struct Data : public base::RefCountedThreadSafe<Data> { - explicit Data(struct addrinfo* ai) : head(ai) {} + Data(struct addrinfo* ai, bool is_system_created) + : head(ai), is_system_created(is_system_created) {} ~Data(); struct addrinfo* head; - private: - Data(); + + // Indicates which free function to use for |head|. + bool is_system_created; }; scoped_refptr<Data> data_; }; diff --git a/net/base/address_list_unittest.cc b/net/base/address_list_unittest.cc new file mode 100644 index 0000000..d594c85 --- /dev/null +++ b/net/base/address_list_unittest.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/base/address_list.h" + +#if defined(OS_WIN) +#include <ws2tcpip.h> +#include <wspiapi.h> // Needed for Win2k compat. +#elif defined(OS_POSIX) +#include <netdb.h> +#include <sys/socket.h> +#endif + +#include "base/string_util.h" +#include "net/base/net_util.h" +#if defined(OS_WIN) +#include "net/base/winsock_init.h" +#endif +#include "testing/gtest/include/gtest/gtest.h" + +namespace { + +// Use getaddrinfo() to allocate an addrinfo structure. +void CreateAddressList(net::AddressList* addrlist, int port) { +#if defined(OS_WIN) + net::EnsureWinsockInit(); +#endif + std::string portstr = IntToString(port); + + struct addrinfo* result = NULL; + struct addrinfo hints = {0}; + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_NUMERICHOST; + hints.ai_socktype = SOCK_STREAM; + + int err = getaddrinfo("192.168.1.1", portstr.c_str(), &hints, &result); + EXPECT_EQ(0, err); + addrlist->Adopt(result); +} + +TEST(AddressListTest, GetPort) { + net::AddressList addrlist; + CreateAddressList(&addrlist, 81); + EXPECT_EQ(81, addrlist.GetPort()); + + addrlist.SetPort(83); + EXPECT_EQ(83, addrlist.GetPort()); +} + +TEST(AddressListTest, Assignment) { + net::AddressList addrlist1; + CreateAddressList(&addrlist1, 85); + EXPECT_EQ(85, addrlist1.GetPort()); + + // Should reference the same data as addrlist1 -- so when we change addrlist1 + // both are changed. + net::AddressList addrlist2 = addrlist1; + EXPECT_EQ(85, addrlist2.GetPort()); + + addrlist1.SetPort(80); + EXPECT_EQ(80, addrlist1.GetPort()); + EXPECT_EQ(80, addrlist2.GetPort()); +} + +TEST(AddressListTest, Copy) { + net::AddressList addrlist1; + CreateAddressList(&addrlist1, 85); + EXPECT_EQ(85, addrlist1.GetPort()); + + net::AddressList addrlist2; + addrlist2.Copy(addrlist1.head()); + + // addrlist1 is the same as addrlist2 at this point. + EXPECT_EQ(85, addrlist1.GetPort()); + EXPECT_EQ(85, addrlist2.GetPort()); + + // Changes to addrlist1 are not reflected in addrlist2. + addrlist1.SetPort(70); + addrlist2.SetPort(90); + + EXPECT_EQ(70, addrlist1.GetPort()); + EXPECT_EQ(90, addrlist2.GetPort()); +} + +} // namespace diff --git a/net/base/client_socket_pool.h b/net/base/client_socket_pool.h index 599b7a9..6d200c0 100644 --- a/net/base/client_socket_pool.h +++ b/net/base/client_socket_pool.h @@ -17,6 +17,7 @@ namespace net { class ClientSocket; class ClientSocketHandle; +class HostResolver; // A ClientSocketPool is used to restrict the number of sockets open at a time. // It also maintains a list of idle persistent sockets. @@ -69,6 +70,9 @@ class ClientSocketPool : public base::RefCounted<ClientSocketPool> { // Called to close any idle connections held by the connection manager. virtual void CloseIdleSockets() = 0; + // Returns the HostResolver that will be used for host lookups. + virtual HostResolver* GetHostResolver() const = 0; + // The total number of idle sockets in the pool. virtual int idle_socket_count() const = 0; diff --git a/net/base/host_cache.cc b/net/base/host_cache.cc new file mode 100644 index 0000000..53af5b4 --- /dev/null +++ b/net/base/host_cache.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/base/host_cache.h" + +#include "base/logging.h" +#include "net/base/net_errors.h" + +namespace net { + +//----------------------------------------------------------------------------- + +HostCache::Entry::Entry(int error, + const AddressList& addrlist, + base::TimeTicks expiration) + : error(error), addrlist(addrlist), expiration(expiration) { +} + +HostCache::Entry::~Entry() { +} + +//----------------------------------------------------------------------------- + +HostCache::HostCache(size_t max_entries, size_t cache_duration_ms) + : max_entries_(max_entries), cache_duration_ms_(cache_duration_ms) { +} + +HostCache::~HostCache() { +} + +const HostCache::Entry* HostCache::Lookup(const std::string& hostname, + base::TimeTicks now) const { + if (caching_is_disabled()) + return NULL; + + EntryMap::const_iterator it = entries_.find(hostname); + if (it == entries_.end()) + return NULL; // Not found. + + Entry* entry = it->second.get(); + if (CanUseEntry(entry, now)) + return entry; + + return NULL; +} + +HostCache::Entry* HostCache::Set(const std::string& hostname, + int error, + const AddressList addrlist, + base::TimeTicks now) { + if (caching_is_disabled()) + return NULL; + + base::TimeTicks expiration = now + + base::TimeDelta::FromMilliseconds(cache_duration_ms_); + + scoped_refptr<Entry>& entry = entries_[hostname]; + if (!entry) { + // Entry didn't exist, creating one now. + Entry* ptr = new Entry(error, addrlist, expiration); + entry = ptr; + + // Compact the cache if we grew it beyond limit -- exclude |entry| from + // being pruned though! + if (entries_.size() > max_entries_) + Compact(now, ptr); + return ptr; + } else { + // Update an existing cache entry. + entry->error = error; + entry->addrlist = addrlist; + entry->expiration = expiration; + return entry.get(); + } +} + +// static +bool HostCache::CanUseEntry(const Entry* entry, const base::TimeTicks now) { + return entry->error == OK && entry->expiration > now; +} + +void HostCache::Compact(base::TimeTicks now, const Entry* pinned_entry) { + // Clear out expired entries. + for (EntryMap::iterator it = entries_.begin(); it != entries_.end(); ) { + Entry* entry = (it->second).get(); + if (entry != pinned_entry && !CanUseEntry(entry, now)) { + entries_.erase(it++); + } else { + ++it; + } + } + + if (entries_.size() <= max_entries_) + return; + + // If we still have too many entries, start removing unexpired entries + // at random. + // TODO(eroman): this eviction policy could be better (access count FIFO + // or whatever). + for (EntryMap::iterator it = entries_.begin(); + it != entries_.end() && entries_.size() > max_entries_; ) { + Entry* entry = (it->second).get(); + if (entry != pinned_entry) { + entries_.erase(it++); + } else { + ++it; + } + } + + if (entries_.size() > max_entries_) + DLOG(WARNING) << "Still above max entries limit"; +} + +} // namespace net diff --git a/net/base/host_cache.h b/net/base/host_cache.h new file mode 100644 index 0000000..2085b7c --- /dev/null +++ b/net/base/host_cache.h @@ -0,0 +1,93 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_BASE_HOST_CACHE_H_ +#define NET_BASE_HOST_CACHE_H_ + +#include <string> + +#include "base/hash_tables.h" +#include "base/ref_counted.h" +#include "base/time.h" +#include "net/base/address_list.h" +#include "testing/gtest/include/gtest/gtest_prod.h" + +namespace net { + +// Cache used by HostResolver to map hostnames to their resolved result. +// If the resolve is still in progress, the entry will reference the job +// responsible for populating it. +class HostCache { + public: + // Stores the latest address list that was looked up for a hostname. + struct Entry : public base::RefCounted<Entry> { + Entry(int error, const AddressList& addrlist, base::TimeTicks expiration); + ~Entry(); + + // The resolve results for this entry. + int error; + AddressList addrlist; + + // The time when this entry expires. + base::TimeTicks expiration; + }; + + // Constructs a HostCache whose entries are valid for |cache_duration_ms| + // milliseconds. The cache will store up to |max_entries|. + HostCache(size_t max_entries, size_t cache_duration_ms); + + ~HostCache(); + + // Returns a pointer to the entry for |hostname|, which is valid at time + // |now|. If there is no such entry, returns NULL. + const Entry* Lookup(const std::string& hostname, base::TimeTicks now) const; + + // Overwrites or creates an entry for |hostname|. Returns the pointer to the + // entry, or NULL on failure (fails if caching is disabled). + // (|error|, |addrlist|) is the value to set, and |now| is the current + // timestamp. + Entry* Set(const std::string& hostname, + int error, + const AddressList addrlist, + base::TimeTicks now); + + // Returns true if this HostCache can contain no entries. + bool caching_is_disabled() const { + return max_entries_ == 0; + } + + // Returns the number of entries in the cache. + size_t size() const { + return entries_.size(); + } + + private: + FRIEND_TEST(HostCacheTest, Compact); + FRIEND_TEST(HostCacheTest, NoCache); + + typedef base::hash_map<std::string, scoped_refptr<Entry> > EntryMap; + + // Returns true if this cache entry's result is valid at time |now|. + static bool CanUseEntry(const Entry* entry, const base::TimeTicks now); + + // Prunes entries from the cache to bring it below max entry bound. Entries + // matching |pinned_entry| will NOT be pruned. + void Compact(base::TimeTicks now, const Entry* pinned_entry); + + // Bound on total size of the cache. + size_t max_entries_; + + // Time to live for cache entries in milliseconds. + size_t cache_duration_ms_; + + // Map from hostname (presumably in lowercase canonicalized format) to + // a resolved result entry. + EntryMap entries_; + + DISALLOW_COPY_AND_ASSIGN(HostCache); +}; + +} // namespace net + +#endif // NET_BASE_HOST_CACHE_H_ diff --git a/net/base/host_cache_unittest.cc b/net/base/host_cache_unittest.cc new file mode 100644 index 0000000..fe01e20 --- /dev/null +++ b/net/base/host_cache_unittest.cc @@ -0,0 +1,218 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/base/host_cache.h" + +#include "base/stl_util-inl.h" +#include "base/string_util.h" +#include "net/base/net_errors.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { +static const int kMaxCacheEntries = 10; +static const int kCacheDurationMs = 10000; // 10 seconds. +} + +TEST(HostCacheTest, Basic) { + HostCache cache(kMaxCacheEntries, kCacheDurationMs); + + // Start at t=0. + base::TimeTicks now; + + const HostCache::Entry* entry1 = NULL; // Entry for foobar.com. + const HostCache::Entry* entry2 = NULL; // Entry for foobar2.com. + + EXPECT_EQ(0U, cache.size()); + + // Add an entry for "foobar.com" at t=0. + EXPECT_EQ(NULL, cache.Lookup("foobar.com", base::TimeTicks())); + cache.Set("foobar.com", OK, AddressList(), now); + entry1 = cache.Lookup("foobar.com", base::TimeTicks()); + EXPECT_FALSE(NULL == entry1); + EXPECT_EQ(1U, cache.size()); + + // Advance to t=5. + now += base::TimeDelta::FromSeconds(5); + + // Add an entry for "foobar2.com" at t=5. + EXPECT_EQ(NULL, cache.Lookup("foobar2.com", base::TimeTicks())); + cache.Set("foobar2.com", OK, AddressList(), now); + entry2 = cache.Lookup("foobar2.com", base::TimeTicks()); + EXPECT_FALSE(NULL == entry1); + EXPECT_EQ(2U, cache.size()); + + // Advance to t=9 + now += base::TimeDelta::FromSeconds(4); + + // Verify that the entries we added are still retrievable, and usable. + EXPECT_EQ(entry1, cache.Lookup("foobar.com", now)); + EXPECT_EQ(entry2, cache.Lookup("foobar2.com", now)); + + // Advance to t=10; entry1 is now expired. + now += base::TimeDelta::FromSeconds(1); + + EXPECT_EQ(NULL, cache.Lookup("foobar.com", now)); + EXPECT_EQ(entry2, cache.Lookup("foobar2.com", now)); + + // Update entry1, so it is no longer expired. + cache.Set("foobar.com", OK, AddressList(), now); + // Re-uses existing entry storage. + EXPECT_EQ(entry1, cache.Lookup("foobar.com", now)); + EXPECT_EQ(2U, cache.size()); + + // Both entries should still be retrievable and usable. + EXPECT_EQ(entry1, cache.Lookup("foobar.com", now)); + EXPECT_EQ(entry2, cache.Lookup("foobar2.com", now)); + + // Advance to t=20; both entries are now expired. + now += base::TimeDelta::FromSeconds(10); + + EXPECT_EQ(NULL, cache.Lookup("foobar.com", now)); + EXPECT_EQ(NULL, cache.Lookup("foobar2.com", now)); +} + +// Try caching entries for a failed resolve attempt. +TEST(HostCacheTest, NegativeEntry) { + HostCache cache(kMaxCacheEntries, kCacheDurationMs); + + // Set t=0. + base::TimeTicks now; + + EXPECT_EQ(NULL, cache.Lookup("foobar.com", base::TimeTicks())); + cache.Set("foobar.com", ERR_NAME_NOT_RESOLVED, AddressList(), now); + EXPECT_EQ(1U, cache.size()); + + // We disallow use of negative entries. + EXPECT_EQ(NULL, cache.Lookup("foobar.com", now)); + + // Now overwrite with a valid entry, and then overwrite with negative entry + // again -- the valid entry should be kicked out. + cache.Set("foobar.com", OK, AddressList(), now); + EXPECT_FALSE(NULL == cache.Lookup("foobar.com", now)); + cache.Set("foobar.com", ERR_NAME_NOT_RESOLVED, AddressList(), now); + EXPECT_EQ(NULL, cache.Lookup("foobar.com", now)); +} + +TEST(HostCacheTest, Compact) { + // Initial entries limit is big enough to accomadate everything we add. + net::HostCache cache(kMaxCacheEntries, kCacheDurationMs); + + EXPECT_EQ(0U, cache.size()); + + // t=10 + base::TimeTicks now = base::TimeTicks() + base::TimeDelta::FromSeconds(10); + + // Add five valid entries at t=10. + for (int i = 0; i < 5; ++i) { + std::string hostname = StringPrintf("valid%d", i); + cache.Set(hostname, OK, AddressList(), now); + } + EXPECT_EQ(5U, cache.size()); + + // Add 3 expired entries at t=0. + for (int i = 0; i < 3; ++i) { + std::string hostname = StringPrintf("expired%d", i); + base::TimeTicks t = now - base::TimeDelta::FromSeconds(10); + cache.Set(hostname, OK, AddressList(), t); + } + EXPECT_EQ(8U, cache.size()); + + // Add 2 negative entries at t=10 + for (int i = 0; i < 2; ++i) { + std::string hostname = StringPrintf("negative%d", i); + cache.Set(hostname, ERR_NAME_NOT_RESOLVED, AddressList(), now); + } + EXPECT_EQ(10U, cache.size()); + + EXPECT_TRUE(ContainsKey(cache.entries_, "valid0")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid1")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid2")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid3")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid4")); + EXPECT_TRUE(ContainsKey(cache.entries_, "expired0")); + EXPECT_TRUE(ContainsKey(cache.entries_, "expired1")); + EXPECT_TRUE(ContainsKey(cache.entries_, "expired2")); + EXPECT_TRUE(ContainsKey(cache.entries_, "negative0")); + EXPECT_TRUE(ContainsKey(cache.entries_, "negative1")); + + // Shrink the max constraints bound and compact. We expect the "negative" + // and "expired" entries to have been dropped. + cache.max_entries_ = 5; + cache.Compact(now, NULL); + EXPECT_EQ(5U, cache.entries_.size()); + + EXPECT_TRUE(ContainsKey(cache.entries_, "valid0")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid1")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid2")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid3")); + EXPECT_TRUE(ContainsKey(cache.entries_, "valid4")); + EXPECT_FALSE(ContainsKey(cache.entries_, "expired0")); + EXPECT_FALSE(ContainsKey(cache.entries_, "expired1")); + EXPECT_FALSE(ContainsKey(cache.entries_, "expired2")); + EXPECT_FALSE(ContainsKey(cache.entries_, "negative0")); + EXPECT_FALSE(ContainsKey(cache.entries_, "negative1")); + + // Shrink further -- this time the compact will start dropping valid entries + // to make space. + cache.max_entries_ = 3; + cache.Compact(now, NULL); + EXPECT_EQ(3U, cache.size()); +} + +// Add entries while the cache is at capacity, causing evictions. +TEST(HostCacheTest, SetWithCompact) { + net::HostCache cache(3, kCacheDurationMs); + + // t=10 + base::TimeTicks now = + base::TimeTicks() + base::TimeDelta::FromMilliseconds(kCacheDurationMs); + + cache.Set("host1", OK, AddressList(), now); + cache.Set("host2", OK, AddressList(), now); + cache.Set("expired", OK, AddressList(), + now - base::TimeDelta::FromMilliseconds(kCacheDurationMs)); + + EXPECT_EQ(3U, cache.size()); + + // Should all be retrievable except "expired". + EXPECT_FALSE(NULL == cache.Lookup("host1", now)); + EXPECT_FALSE(NULL == cache.Lookup("host2", now)); + EXPECT_TRUE(NULL == cache.Lookup("expired", now)); + + // Adding the fourth entry will cause "expired" to be evicted. + cache.Set("host3", OK, AddressList(), now); + EXPECT_EQ(3U, cache.size()); + EXPECT_EQ(NULL, cache.Lookup("expired", now)); + EXPECT_FALSE(NULL == cache.Lookup("host1", now)); + EXPECT_FALSE(NULL == cache.Lookup("host2", now)); + EXPECT_FALSE(NULL == cache.Lookup("host3", now)); + + // Add two more entries. Something should be evicted, however "host5" + // should definitely be in there (since it was last inserted). + cache.Set("host4", OK, AddressList(), now); + EXPECT_EQ(3U, cache.size()); + cache.Set("host5", OK, AddressList(), now); + EXPECT_EQ(3U, cache.size()); + EXPECT_FALSE(NULL == cache.Lookup("host5", now)); +} + +TEST(HostCacheTest, NoCache) { + // Disable caching. + HostCache cache(0, kCacheDurationMs); + EXPECT_TRUE(cache.caching_is_disabled()); + + // Set t=0. + base::TimeTicks now; + + // Lookup and Set should have no effect. + EXPECT_EQ(NULL, cache.Lookup("foobar.com", base::TimeTicks())); + cache.Set("foobar.com", OK, AddressList(), now); + EXPECT_EQ(NULL, cache.Lookup("foobar.com", base::TimeTicks())); + + EXPECT_EQ(0U, cache.size()); +} + +} // namespace net diff --git a/net/base/host_resolver.cc b/net/base/host_resolver.cc index ec693ee..50329ff 100644 --- a/net/base/host_resolver.cc +++ b/net/base/host_resolver.cc @@ -15,8 +15,11 @@ #include <resolv.h> #endif +#include "base/compiler_specific.h" #include "base/message_loop.h" +#include "base/stl_util-inl.h" #include "base/string_util.h" +#include "base/time.h" #include "base/worker_pool.h" #include "net/base/address_list.h" #include "net/base/net_errors.h" @@ -24,7 +27,6 @@ #if defined(OS_LINUX) #include "base/singleton.h" #include "base/thread_local_storage.h" -#include "base/time.h" #endif #if defined(OS_WIN) @@ -110,8 +112,7 @@ ThreadLocalStorage::Slot DnsReloadTimer::tls_index_(base::LINKER_INITIALIZED); #endif // defined(OS_LINUX) -static int HostResolverProc( - const std::string& host, const std::string& port, struct addrinfo** out) { +static int HostResolverProc(const std::string& host, struct addrinfo** out) { struct addrinfo hints = {0}; hints.ai_family = AF_UNSPEC; @@ -144,7 +145,7 @@ static int HostResolverProc( // Restrict result set to only this socket type to avoid duplicates. hints.ai_socktype = SOCK_STREAM; - int err = getaddrinfo(host.c_str(), port.c_str(), &hints, out); + int err = getaddrinfo(host.c_str(), NULL, &hints, out); #if defined(OS_LINUX) net::DnsReloadTimer* dns_timer = Singleton<net::DnsReloadTimer>::get(); // If we fail, re-initialise the resolver just in case there have been any @@ -152,7 +153,7 @@ static int HostResolverProc( if (err && dns_timer->Expired()) { res_nclose(&_res); if (!res_ninit(&_res)) - err = getaddrinfo(host.c_str(), port.c_str(), &hints, out); + err = getaddrinfo(host.c_str(), NULL, &hints, out); } #endif @@ -160,50 +161,148 @@ static int HostResolverProc( } static int ResolveAddrInfo(HostMapper* mapper, const std::string& host, - const std::string& port, struct addrinfo** out) { + struct addrinfo** out) { if (mapper) { std::string mapped_host = mapper->Map(host); if (mapped_host.empty()) return ERR_NAME_NOT_RESOLVED; - return HostResolverProc(mapped_host, port, out); + return HostResolverProc(mapped_host, out); } else { - return HostResolverProc(host, port, out); + return HostResolverProc(host, out); } } //----------------------------------------------------------------------------- -class HostResolver::Request : - public base::RefCountedThreadSafe<HostResolver::Request> { +class HostResolver::Request { public: - Request(HostResolver* resolver, - const std::string& host, - const std::string& port, - AddressList* addresses, - CompletionCallback* callback) + Request(CompletionCallback* callback, AddressList* addresses, int port) + : job_(NULL), callback_(callback), addresses_(addresses), port_(port) {} + + // Mark the request as cancelled. + void Cancel() { + job_ = NULL; + callback_ = NULL; + addresses_ = NULL; + } + + bool was_cancelled() const { + return callback_ == NULL; + } + + void set_job(Job* job) { + DCHECK(job != NULL); + // Identify which job the request is waiting on. + job_ = job; + } + + void OnComplete(int error, const AddressList& addrlist) { + if (error == OK) + addresses_->SetFrom(addrlist, port_); + callback_->Run(error); + } + + int port() const { + return port_; + } + + Job* job() const { + return job_; + } + + private: + // The resolve job (running in worker pool) that this request is dependent on. + Job* job_; + + // The user's callback to invoke when the request completes. + CompletionCallback* callback_; + + // The address list to save result into. + AddressList* addresses_; + + // The desired port number for the socket addresses. + int port_; + + DISALLOW_COPY_AND_ASSIGN(Request); +}; + +//----------------------------------------------------------------------------- + +// This class represents a request to the worker pool for a "getaddrinfo()" +// call. +class HostResolver::Job : public base::RefCountedThreadSafe<HostResolver::Job> { + public: + Job(HostResolver* resolver, const std::string& host) : host_(host), - port_(port), resolver_(resolver), - addresses_(addresses), - callback_(callback), origin_loop_(MessageLoop::current()), host_mapper_(host_mapper), error_(OK), results_(NULL) { } - ~Request() { + ~Job() { if (results_) freeaddrinfo(results_); + + // Free the requests attached to this job. + STLDeleteElements(&requests_); } + // Attaches a request to this job. The job takes ownership of |req| and will + // take care to delete it. + void AddRequest(HostResolver::Request* req) { + req->set_job(this); + requests_.push_back(req); + } + + // Called from origin loop. + void Start() { + // Dispatch the job to a worker thread. + if (!WorkerPool::PostTask(FROM_HERE, + NewRunnableMethod(this, &Job::DoLookup), true)) { + NOTREACHED(); + + // Since we could be running within Resolve() right now, we can't just + // call OnLookupComplete(). Instead we must wait until Resolve() has + // returned (IO_PENDING). + error_ = ERR_UNEXPECTED; + MessageLoop::current()->PostTask( + FROM_HERE, NewRunnableMethod(this, &Job::OnLookupComplete)); + } + } + + // Cancels the current job. Callable from origin thread. + void Cancel() { + resolver_ = NULL; + + AutoLock locked(origin_loop_lock_); + origin_loop_ = NULL; + } + + // Called from origin thread. + bool was_cancelled() const { + return resolver_ == NULL; + } + + // Called from origin thread. + const std::string& host() const { + return host_; + } + + // Called from origin thread. + const RequestsList& requests() const { + return requests_; + } + + private: void DoLookup() { // Running on the worker thread - error_ = ResolveAddrInfo(host_mapper_, host_, port_, &results_); + error_ = ResolveAddrInfo(host_mapper_, host_, &results_); - Task* reply = NewRunnableMethod(this, &Request::DoCallback); + Task* reply = NewRunnableMethod(this, &Job::OnLookupComplete); // The origin loop could go away while we are trying to post to it, so we // need to call its PostTask method inside a lock. See ~HostResolver. @@ -219,43 +318,33 @@ class HostResolver::Request : delete reply; } - void DoCallback() { - // Running on the origin thread. + // Callback for when DoLookup() completes (runs on origin thread). + void OnLookupComplete() { + DCHECK_EQ(origin_loop_, MessageLoop::current()); DCHECK(error_ || results_); - // We may have been cancelled! - if (!resolver_) + if (was_cancelled()) return; - if (!error_) { - addresses_->Adopt(results_); + DCHECK(!requests_.empty()); + + // Adopt the address list using the port number of the first request. + AddressList addrlist; + if (error_ == OK) { + addrlist.Adopt(results_); + addrlist.SetPort(requests_[0]->port()); results_ = NULL; } - // Drop the resolver's reference to us. Do this before running the - // callback since the callback might result in the resolver being - // destroyed. - resolver_->request_ = NULL; - - callback_->Run(error_); + resolver_->OnJobComplete(this, error_, addrlist); } - void Cancel() { - resolver_ = NULL; - - AutoLock locked(origin_loop_lock_); - origin_loop_ = NULL; - } - - private: // Set on the origin thread, read on the worker thread. std::string host_; - std::string port_; // Only used on the origin thread (where Resolve was called). HostResolver* resolver_; - AddressList* addresses_; - CompletionCallback* callback_; + RequestsList requests_; // The requests waiting on this job. // Used to post ourselves onto the origin thread. Lock origin_loop_lock_; @@ -270,48 +359,205 @@ class HostResolver::Request : // Assigned on the worker thread, read on the origin thread. int error_; struct addrinfo* results_; + + DISALLOW_COPY_AND_ASSIGN(Job); }; //----------------------------------------------------------------------------- -HostResolver::HostResolver() { +HostResolver::HostResolver(int max_cache_entries, int cache_duration_ms) + : cache_(max_cache_entries, cache_duration_ms) { #if defined(OS_WIN) EnsureWinsockInit(); #endif } HostResolver::~HostResolver() { - if (request_) - request_->Cancel(); + // Cancel the outstanding jobs. Those jobs may contain several attached + // requests, which will now never be completed. + for (JobMap::iterator it = jobs_.begin(); it != jobs_.end(); ++it) + it->second->Cancel(); + + // In case we are being deleted during the processing of a callback. + if (cur_completing_job_) + cur_completing_job_->Cancel(); } +// TODO(eroman): Don't create cache entries for hostnames which are simply IP +// address literals. int HostResolver::Resolve(const std::string& hostname, int port, AddressList* addresses, - CompletionCallback* callback) { - DCHECK(!request_) << "resolver already in use"; - - const std::string& port_str = IntToString(port); + CompletionCallback* callback, + Request** out_req) { + // If we have an unexpired cache entry, use it. + const HostCache::Entry* cache_entry = cache_.Lookup( + hostname, base::TimeTicks::Now()); + if (cache_entry) { + addresses->SetFrom(cache_entry->addrlist, port); + return OK; + } - // Do a synchronous resolution. + // If no callback was specified, do a synchronous resolution. if (!callback) { struct addrinfo* results; - int rv = ResolveAddrInfo(host_mapper, hostname, port_str, &results); - if (rv == OK) - addresses->Adopt(results); - return rv; - } + int error = ResolveAddrInfo(host_mapper, hostname, &results); + + // Adopt the address list. + AddressList addrlist; + if (error == OK) { + addrlist.Adopt(results); + addrlist.SetPort(port); + *addresses = addrlist; + } + + // Write to cache. + cache_.Set(hostname, error, addrlist, base::TimeTicks::Now()); - request_ = new Request(this, hostname, port_str, addresses, callback); + return error; + } - // Dispatch to worker thread... - if (!WorkerPool::PostTask(FROM_HERE, - NewRunnableMethod(request_.get(), &Request::DoLookup), true)) { - NOTREACHED(); - request_ = NULL; - return ERR_FAILED; + // Create a handle for this request, and pass it back to the user if they + // asked for it (out_req != NULL). + Request* req = new Request(callback, addresses, port); + if (out_req) + *out_req = req; + + // Next we need to attach our request to a "job". This job is responsible for + // calling "getaddrinfo(hostname)" on a worker thread. + scoped_refptr<Job> job; + + // If there is already an outstanding job to resolve |hostname|, use it. + // This prevents starting concurrent resolves for the same hostname. + job = FindOutstandingJob(hostname); + if (job) { + job->AddRequest(req); + } else { + // Create a new job for this request. + job = new Job(this, hostname); + job->AddRequest(req); + AddOutstandingJob(job); + // TODO(eroman): Bound the total number of concurrent jobs. + // http://crbug.com/9598 + job->Start(); } + // Completion happens during OnJobComplete(Job*). return ERR_IO_PENDING; } +// See OnJobComplete(Job*) for why it is important not to clean out +// cancelled requests from Job::requests_. +void HostResolver::CancelRequest(Request* req) { + DCHECK(req); + DCHECK(req->job()); + // NULL out the fields of req, to mark it as cancelled. + req->Cancel(); +} + +void HostResolver::AddOutstandingJob(Job* job) { + scoped_refptr<Job>& found_job = jobs_[job->host()]; + DCHECK(!found_job); + found_job = job; +} + +HostResolver::Job* HostResolver::FindOutstandingJob( + const std::string& hostname) { + JobMap::iterator it = jobs_.find(hostname); + if (it != jobs_.end()) + return it->second; + return NULL; +} + +void HostResolver::RemoveOutstandingJob(Job* job) { + JobMap::iterator it = jobs_.find(job->host()); + DCHECK(it != jobs_.end()); + DCHECK_EQ(it->second.get(), job); + jobs_.erase(it); +} + +void HostResolver::OnJobComplete(Job* job, + int error, + const AddressList& addrlist) { + RemoveOutstandingJob(job); + + // Write result to the cache. + cache_.Set(job->host(), error, addrlist, base::TimeTicks::Now()); + + // Make a note that we are executing within OnJobComplete() in case the + // HostResolver is deleted by a callback invocation. + DCHECK(!cur_completing_job_); + cur_completing_job_ = job; + + // Complete all of the requests that were attached to the job. + for (RequestsList::const_iterator it = job->requests().begin(); + it != job->requests().end(); ++it) { + Request* req = *it; + if (!req->was_cancelled()) { + DCHECK_EQ(job, req->job()); + req->OnComplete(error, addrlist); + + // Check if the job was cancelled as a result of running the callback. + // (Meaning that |this| was deleted). + if (job->was_cancelled()) + return; + } + } + + cur_completing_job_ = NULL; +} + +//----------------------------------------------------------------------------- + +SingleRequestHostResolver::SingleRequestHostResolver(HostResolver* resolver) + : resolver_(resolver), + cur_request_(NULL), + cur_request_callback_(NULL), + ALLOW_THIS_IN_INITIALIZER_LIST( + callback_(this, &SingleRequestHostResolver::OnResolveCompletion)) { + DCHECK(resolver_ != NULL); +} + +SingleRequestHostResolver::~SingleRequestHostResolver() { + if (cur_request_) { + resolver_->CancelRequest(cur_request_); + } +} + +int SingleRequestHostResolver::Resolve( + const std::string& hostname, int port, + AddressList* addresses, + CompletionCallback* callback) { + DCHECK(!cur_request_ && !cur_request_callback_) << "resolver already in use"; + + HostResolver::Request* request = NULL; + + // We need to be notified of completion before |callback| is called, so that + // we can clear out |cur_request_*|. + CompletionCallback* transient_callback = callback ? &callback_ : NULL; + + int rv = resolver_->Resolve( + hostname, port, addresses, transient_callback, &request); + + if (rv == ERR_IO_PENDING) { + // Cleared in OnResolveCompletion(). + cur_request_ = request; + cur_request_callback_ = callback; + } + + return rv; +} + +void SingleRequestHostResolver::OnResolveCompletion(int result) { + DCHECK(cur_request_ && cur_request_callback_); + + CompletionCallback* callback = cur_request_callback_; + + // Clear the outstanding request information. + cur_request_ = NULL; + cur_request_callback_ = NULL; + + // Call the user's original callback. + callback->Run(result); +} + } // namespace net diff --git a/net/base/host_resolver.h b/net/base/host_resolver.h index c730f76..a67419a 100644 --- a/net/base/host_resolver.h +++ b/net/base/host_resolver.h @@ -6,34 +6,73 @@ #define NET_BASE_HOST_RESOLVER_H_ #include <string> +#include <vector> #include "base/basictypes.h" +#include "base/lock.h" #include "base/ref_counted.h" #include "net/base/completion_callback.h" +#include "net/base/host_cache.h" + +class MessageLoop; namespace net { class AddressList; +class HostMapper; -// This class represents the task of resolving a hostname (or IP address -// literal) to an AddressList object. It can only resolve a single hostname at -// a time, so if you need to resolve multiple hostnames at the same time, you -// will need to allocate a HostResolver object for each hostname. +// This class represents the task of resolving hostnames (or IP address +// literal) to an AddressList object. +// +// HostResolver handles multiple requests at a time, so when cancelling a +// request the Request* handle that was returned by Resolve() needs to be +// given. A simpler alternative for consumers that only have 1 outstanding +// request at a time is to create a SingleRequestHostResolver wrapper around +// HostResolver (which will automatically cancel the single request when it +// goes out of scope). +// +// For each hostname that is requested, HostResolver creates a +// HostResolver::Job. This job gets dispatched to a thread in the global +// WorkerPool, where it runs "getaddrinfo(hostname)". If requests for that same +// host are made while the job is already outstanding, then they are attached +// to the existing job rather than creating a new one. This avoids doing +// parallel resolves for the same host. +// +// The way these classes fit together is illustrated by: +// // -// No attempt is made at this level to cache or pin resolution results. For -// each request, this API talks directly to the underlying name resolver of -// the local system, which may or may not result in a DNS query. The exact -// behavior depends on the system configuration. +// +------------- HostResolver ---------------+ +// | | | +// Job Job Job +// (for host1) (for host2) (for hostX) +// / | | / | | / | | +// Request ... Request Request ... Request Request ... Request +// (port1) (port2) (port3) (port4) (port5) (portX) +// +// +// When a HostResolver::Job finishes its work in the threadpool, the callbacks +// of each waiting request are run on the origin thread. +// +// Thread safety: This class is not threadsafe, and must only be called +// from one thread! // class HostResolver { public: - HostResolver(); + // Creates a HostResolver that caches up to |max_cache_entries| for + // |cache_duration_ms| milliseconds. + // + // TODO(eroman): Get rid of the default parameters as it violate google + // style. This is temporary to help with refactoring. + HostResolver(int max_cache_entries = 100, int cache_duration_ms = 60000); - // If a completion callback is pending when the resolver is destroyed, the - // host resolution is cancelled, and the completion callback will not be - // called. + // If any completion callbacks are pending when the resolver is destroyed, + // the host resolutions are cancelled, and the completion callbacks will not + // be called. ~HostResolver(); + // Opaque type used to cancel a request. + class Request; + // Resolves the given hostname (or IP address literal), filling out the // |addresses| object upon success. The |port| parameter will be set as the // sin(6)_port field of the sockaddr_in{6} struct. Returns OK if successful @@ -43,17 +82,82 @@ class HostResolver { // // When callback is non-null, the operation will be performed asynchronously. // ERR_IO_PENDING is returned if it has been scheduled successfully. Real - // result code will be passed to the completion callback. + // result code will be passed to the completion callback. If |req| is + // non-NULL, then |*req| will be filled with a handle to the async request. + // This handle is not valid after the request has completed. int Resolve(const std::string& hostname, int port, - AddressList* addresses, CompletionCallback* callback); + AddressList* addresses, CompletionCallback* callback, + Request** req); + + // Cancels the specified request. |req| is the handle returned by Resolve(). + // After a request is cancelled, its completion callback will not be called. + void CancelRequest(Request* req); private: - class Request; - friend class Request; - scoped_refptr<Request> request_; + class Job; + typedef std::vector<Request*> RequestsList; + typedef base::hash_map<std::string, scoped_refptr<Job> > JobMap; + + // Adds a job to outstanding jobs list. + void AddOutstandingJob(Job* job); + + // Returns the outstanding job for |hostname|, or NULL if there is none. + Job* FindOutstandingJob(const std::string& hostname); + + // Removes |job| from the outstanding jobs list. + void RemoveOutstandingJob(Job* job); + + // Callback for when |job| has completed with |error| and |addrlist|. + void OnJobComplete(Job* job, int error, const AddressList& addrlist); + + // Cache of host resolution results. + HostCache cache_; + + // Map from hostname to outstanding job. + JobMap jobs_; + + // The job that OnJobComplete() is currently processing (needed in case + // HostResolver gets deleted from within the callback). + scoped_refptr<Job> cur_completing_job_; + DISALLOW_COPY_AND_ASSIGN(HostResolver); }; +// This class represents the task of resolving a hostname (or IP address +// literal) to an AddressList object. It wraps HostResolver to resolve only a +// single hostname at a time and cancels this request when going out of scope. +class SingleRequestHostResolver { + public: + explicit SingleRequestHostResolver(HostResolver* resolver); + + // If a completion callback is pending when the resolver is destroyed, the + // host resolution is cancelled, and the completion callback will not be + // called. + ~SingleRequestHostResolver(); + + // Resolves the given hostname (or IP address literal), filling out the + // |addresses| object upon success. See HostResolver::Resolve() for details. + int Resolve(const std::string& hostname, int port, + AddressList* addresses, CompletionCallback* callback); + + private: + // Callback for when the request to |resolver_| completes, so we dispatch + // to the user's callback. + void OnResolveCompletion(int result); + + // The actual host resolver that will handle the request. + HostResolver* resolver_; + + // The current request (if any). + HostResolver::Request* cur_request_; + CompletionCallback* cur_request_callback_; + + // Completion callback for when request to |resolver_| completes. + net::CompletionCallbackImpl<SingleRequestHostResolver> callback_; + + DISALLOW_COPY_AND_ASSIGN(SingleRequestHostResolver); +}; + // A helper class used in unit tests to alter hostname mappings. See // SetHostMapper for details. class HostMapper : public base::RefCountedThreadSafe<HostMapper> { diff --git a/net/base/host_resolver_unittest.cc b/net/base/host_resolver_unittest.cc index d24d28e..090fd5c 100644 --- a/net/base/host_resolver_unittest.cc +++ b/net/base/host_resolver_unittest.cc @@ -26,8 +26,118 @@ using net::RuleBasedHostMapper; using net::ScopedHostMapper; using net::WaitingHostMapper; +// TODO(eroman): +// - Test mixing async with sync (in particular how does sync update the +// cache while an async is already pending). + namespace { +// A variant of WaitingHostMapper that pushes each host mapped into a list. +// (and uses a manual-reset event rather than auto-reset). +class CapturingHostMapper : public net::HostMapper { + public: + CapturingHostMapper() : event_(true, false) { + } + + void Signal() { + event_.Signal(); + } + + virtual std::string Map(const std::string& host) { + event_.Wait(); + { + AutoLock l(lock_); + capture_list_.push_back(host); + } + return MapUsingPrevious(host); + } + + std::vector<std::string> GetCaptureList() const { + std::vector<std::string> copy; + { + AutoLock l(lock_); + copy = capture_list_; + } + return copy; + } + + private: + std::vector<std::string> capture_list_; + mutable Lock lock_; + base::WaitableEvent event_; +}; + +// Helper that represents a single Resolve() result, used to inspect all the +// resolve results by forwarding them to Delegate. +class ResolveRequest { + public: + // Delegate interface, for notification when the ResolveRequest completes. + class Delegate { + public: + virtual ~Delegate() {} + virtual void OnCompleted(ResolveRequest* resolve) = 0; + }; + + ResolveRequest(net::HostResolver* resolver, + const std::string& hostname, + int port, + Delegate* delegate) + : hostname_(hostname), port_(port), resolver_(resolver), + delegate_(delegate), + ALLOW_THIS_IN_INITIALIZER_LIST( + callback_(this, &ResolveRequest::OnLookupFinished)) { + // Start the request. + int err = resolver->Resolve(hostname, port, &addrlist_, &callback_, &req_); + EXPECT_EQ(net::ERR_IO_PENDING, err); + } + + void Cancel() { + resolver_->CancelRequest(req_); + } + + const std::string& hostname() const { + return hostname_; + } + + int port() const { + return port_; + } + + int result() const { + return result_; + } + + const net::AddressList& addrlist() const { + return addrlist_; + } + + net::HostResolver* resolver() const { + return resolver_; + } + + private: + void OnLookupFinished(int result) { + result_ = result; + delegate_->OnCompleted(this); + } + + // The request details. + std::string hostname_; + int port_; + net::HostResolver::Request* req_; + + // The result of the resolve. + int result_; + net::AddressList addrlist_; + + net::HostResolver* resolver_; + + Delegate* delegate_; + net::CompletionCallbackImpl<ResolveRequest> callback_; + + DISALLOW_COPY_AND_ASSIGN(ResolveRequest); +}; + class HostResolverTest : public testing::Test { public: HostResolverTest() @@ -58,7 +168,8 @@ TEST_F(HostResolverTest, SynchronousLookup) { mapper->AddRule("just.testing", "192.168.1.42"); ScopedHostMapper scoped_mapper(mapper.get()); - int err = host_resolver.Resolve("just.testing", kPortnum, &adrlist, NULL); + int err = host_resolver.Resolve("just.testing", kPortnum, &adrlist, NULL, + NULL); EXPECT_EQ(net::OK, err); const struct addrinfo* ainfo = adrlist.head(); @@ -81,7 +192,7 @@ TEST_F(HostResolverTest, AsynchronousLookup) { ScopedHostMapper scoped_mapper(mapper.get()); int err = host_resolver.Resolve("just.testing", kPortnum, &adrlist, - &callback_); + &callback_, NULL); EXPECT_EQ(net::ERR_IO_PENDING, err); MessageLoop::current()->Run(); @@ -109,7 +220,7 @@ TEST_F(HostResolverTest, CanceledAsynchronousLookup) { const int kPortnum = 80; int err = host_resolver.Resolve("just.testing", kPortnum, &adrlist, - &callback_); + &callback_, NULL); EXPECT_EQ(net::ERR_IO_PENDING, err); // Make sure we will exit the queue even when callback is not called. @@ -134,7 +245,7 @@ TEST_F(HostResolverTest, NumericIPv4Address) { net::HostResolver host_resolver; net::AddressList adrlist; const int kPortnum = 5555; - int err = host_resolver.Resolve("127.1.2.3", kPortnum, &adrlist, NULL); + int err = host_resolver.Resolve("127.1.2.3", kPortnum, &adrlist, NULL, NULL); EXPECT_EQ(net::OK, err); const struct addrinfo* ainfo = adrlist.head(); @@ -157,7 +268,8 @@ TEST_F(HostResolverTest, NumericIPv6Address) { net::HostResolver host_resolver; net::AddressList adrlist; const int kPortnum = 5555; - int err = host_resolver.Resolve("2001:db8::1", kPortnum, &adrlist, NULL); + int err = host_resolver.Resolve("2001:db8::1", kPortnum, &adrlist, NULL, + NULL); // On computers without IPv6 support, getaddrinfo cannot convert IPv6 // address literals to addresses (getaddrinfo returns EAI_NONAME). So this // test has to allow host_resolver.Resolve to fail. @@ -190,8 +302,320 @@ TEST_F(HostResolverTest, EmptyHost) { net::HostResolver host_resolver; net::AddressList adrlist; const int kPortnum = 5555; - int err = host_resolver.Resolve("", kPortnum, &adrlist, NULL); + int err = host_resolver.Resolve("", kPortnum, &adrlist, NULL, NULL); EXPECT_EQ(net::ERR_NAME_NOT_RESOLVED, err); } +// Helper class used by HostResolverTest.DeDupeRequests. It receives request +// completion notifications for all the resolves, so it can tally up and +// determine when we are done. +class DeDupeRequestsVerifier : public ResolveRequest::Delegate { + public: + explicit DeDupeRequestsVerifier(CapturingHostMapper* mapper) + : count_a_(0), count_b_(0), mapper_(mapper) {} + + // The test does 5 resolves (which can complete in any order). + virtual void OnCompleted(ResolveRequest* resolve) { + // Tally up how many requests we have seen. + if (resolve->hostname() == "a") { + count_a_++; + } else if (resolve->hostname() == "b") { + count_b_++; + } else { + FAIL() << "Unexpected hostname: " << resolve->hostname(); + } + + // Check that the port was set correctly. + EXPECT_EQ(resolve->port(), resolve->addrlist().GetPort()); + + // Check whether all the requests have finished yet. + int total_completions = count_a_ + count_b_; + if (total_completions == 5) { + EXPECT_EQ(2, count_a_); + EXPECT_EQ(3, count_b_); + + // The mapper should have been called only twice -- once with "a", once + // with "b". + std::vector<std::string> capture_list = mapper_->GetCaptureList(); + EXPECT_EQ(2U, capture_list.size()); + + // End this test, we are done. + MessageLoop::current()->Quit(); + } + } + + private: + int count_a_; + int count_b_; + CapturingHostMapper* mapper_; + + DISALLOW_COPY_AND_ASSIGN(DeDupeRequestsVerifier); +}; + +TEST_F(HostResolverTest, DeDupeRequests) { + // Use a capturing mapper, since the verifier needs to know what calls + // reached Map(). Also, the capturing mapper is initially blocked. + scoped_refptr<CapturingHostMapper> mapper = new CapturingHostMapper(); + ScopedHostMapper scoped_mapper(mapper.get()); + + net::HostResolver host_resolver; + + // The class will receive callbacks for when each resolve completes. It + // checks that the right things happened. + DeDupeRequestsVerifier verifier(mapper.get()); + + // Start 5 requests, duplicating hosts "a" and "b". Since the mapper is + // blocked, these should all pile up until we signal it. + + ResolveRequest req1(&host_resolver, "a", 80, &verifier); + ResolveRequest req2(&host_resolver, "b", 80, &verifier); + ResolveRequest req3(&host_resolver, "b", 81, &verifier); + ResolveRequest req4(&host_resolver, "a", 82, &verifier); + ResolveRequest req5(&host_resolver, "b", 83, &verifier); + + // Ready, Set, GO!!! + mapper->Signal(); + + // |verifier| will send quit message once all the requests have finished. + MessageLoop::current()->Run(); +} + +// Helper class used by HostResolverTest.CancelMultipleRequests. +class CancelMultipleRequestsVerifier : public ResolveRequest::Delegate { + public: + CancelMultipleRequestsVerifier() {} + + // The cancels kill all but one request. + virtual void OnCompleted(ResolveRequest* resolve) { + EXPECT_EQ("a", resolve->hostname()); + EXPECT_EQ(82, resolve->port()); + + // Check that the port was set correctly. + EXPECT_EQ(resolve->port(), resolve->addrlist().GetPort()); + + // End this test, we are done. + MessageLoop::current()->Quit(); + } + + private: + DISALLOW_COPY_AND_ASSIGN(CancelMultipleRequestsVerifier); +}; + +TEST_F(HostResolverTest, CancelMultipleRequests) { + // Use a capturing mapper, since the verifier needs to know what calls + // reached Map(). Also, the capturing mapper is initially blocked. + scoped_refptr<CapturingHostMapper> mapper = new CapturingHostMapper(); + ScopedHostMapper scoped_mapper(mapper.get()); + + net::HostResolver host_resolver; + + // The class will receive callbacks for when each resolve completes. It + // checks that the right things happened. + CancelMultipleRequestsVerifier verifier; + + // Start 5 requests, duplicating hosts "a" and "b". Since the mapper is + // blocked, these should all pile up until we signal it. + + ResolveRequest req1(&host_resolver, "a", 80, &verifier); + ResolveRequest req2(&host_resolver, "b", 80, &verifier); + ResolveRequest req3(&host_resolver, "b", 81, &verifier); + ResolveRequest req4(&host_resolver, "a", 82, &verifier); + ResolveRequest req5(&host_resolver, "b", 83, &verifier); + + // Cancel everything except request 4. + req1.Cancel(); + req2.Cancel(); + req3.Cancel(); + req5.Cancel(); + + // Ready, Set, GO!!! + mapper->Signal(); + + // |verifier| will send quit message once all the requests have finished. + MessageLoop::current()->Run(); +} + +// Helper class used by HostResolverTest.CancelWithinCallback. +class CancelWithinCallbackVerifier : public ResolveRequest::Delegate { + public: + CancelWithinCallbackVerifier() + : req_to_cancel1_(NULL), req_to_cancel2_(NULL), num_completions_(0) { + } + + virtual void OnCompleted(ResolveRequest* resolve) { + num_completions_++; + + // Port 80 is the first request that the callback will be invoked for. + // While we are executing within that callback, cancel the other requests + // in the job and start another request. + if (80 == resolve->port()) { + EXPECT_EQ("a", resolve->hostname()); + + req_to_cancel1_->Cancel(); + req_to_cancel2_->Cancel(); + + // Start a request (so we can make sure the canceled requests don't + // complete before "finalrequest" finishes. + final_request_.reset(new ResolveRequest( + resolve->resolver(), "finalrequest", 70, this)); + + } else if (83 == resolve->port()) { + EXPECT_EQ("a", resolve->hostname()); + } else if (resolve->hostname() == "finalrequest") { + EXPECT_EQ(70, resolve->addrlist().GetPort()); + + // End this test, we are done. + MessageLoop::current()->Quit(); + } else { + FAIL() << "Unexpected completion: " << resolve->hostname() << ", " + << resolve->port(); + } + } + + void SetRequestsToCancel(ResolveRequest* req_to_cancel1, + ResolveRequest* req_to_cancel2) { + req_to_cancel1_ = req_to_cancel1; + req_to_cancel2_ = req_to_cancel2; + } + + private: + scoped_ptr<ResolveRequest> final_request_; + ResolveRequest* req_to_cancel1_; + ResolveRequest* req_to_cancel2_; + int num_completions_; + DISALLOW_COPY_AND_ASSIGN(CancelWithinCallbackVerifier); +}; + +TEST_F(HostResolverTest, CancelWithinCallback) { + // Use a capturing mapper, since the verifier needs to know what calls + // reached Map(). Also, the capturing mapper is initially blocked. + scoped_refptr<CapturingHostMapper> mapper = new CapturingHostMapper(); + ScopedHostMapper scoped_mapper(mapper.get()); + + net::HostResolver host_resolver; + + // The class will receive callbacks for when each resolve completes. It + // checks that the right things happened. + CancelWithinCallbackVerifier verifier; + + // Start 4 requests, duplicating hosts "a". Since the mapper is + // blocked, these should all pile up until we signal it. + + ResolveRequest req1(&host_resolver, "a", 80, &verifier); + ResolveRequest req2(&host_resolver, "a", 81, &verifier); + ResolveRequest req3(&host_resolver, "a", 82, &verifier); + ResolveRequest req4(&host_resolver, "a", 83, &verifier); + + // Once "a:80" completes, it will cancel "a:81" and "a:82". + verifier.SetRequestsToCancel(&req2, &req3); + + // Ready, Set, GO!!! + mapper->Signal(); + + // |verifier| will send quit message once all the requests have finished. + MessageLoop::current()->Run(); +} + +// Helper class used by HostResolverTest.DeleteWithinCallback. +class DeleteWithinCallbackVerifier : public ResolveRequest::Delegate { + public: + DeleteWithinCallbackVerifier() {} + + virtual void OnCompleted(ResolveRequest* resolve) { + EXPECT_EQ("a", resolve->hostname()); + EXPECT_EQ(80, resolve->port()); + delete resolve->resolver(); + + // Quit after returning from OnCompleted (to give it a chance at + // incorrectly running the cancelled tasks). + MessageLoop::current()->PostTask(FROM_HERE, new MessageLoop::QuitTask()); + } + + private: + DISALLOW_COPY_AND_ASSIGN(DeleteWithinCallbackVerifier); +}; + +TEST_F(HostResolverTest, DeleteWithinCallback) { + // Use a capturing mapper, since the verifier needs to know what calls + // reached Map(). Also, the capturing mapper is initially blocked. + scoped_refptr<CapturingHostMapper> mapper = new CapturingHostMapper(); + ScopedHostMapper scoped_mapper(mapper.get()); + + // This should be deleted by DeleteWithinCallbackVerifier -- if it leaks + // then the test has failed. + net::HostResolver* host_resolver = new net::HostResolver; + + // The class will receive callbacks for when each resolve completes. It + // checks that the right things happened. + DeleteWithinCallbackVerifier verifier; + + // Start 4 requests, duplicating hosts "a". Since the mapper is + // blocked, these should all pile up until we signal it. + + ResolveRequest req1(host_resolver, "a", 80, &verifier); + ResolveRequest req2(host_resolver, "a", 81, &verifier); + ResolveRequest req3(host_resolver, "a", 82, &verifier); + ResolveRequest req4(host_resolver, "a", 83, &verifier); + + // Ready, Set, GO!!! + mapper->Signal(); + + // |verifier| will send quit message once all the requests have finished. + MessageLoop::current()->Run(); +} + +// Helper class used by HostResolverTest.StartWithinCallback. +class StartWithinCallbackVerifier : public ResolveRequest::Delegate { + public: + StartWithinCallbackVerifier() : num_requests_(0) {} + + virtual void OnCompleted(ResolveRequest* resolve) { + EXPECT_EQ("a", resolve->hostname()); + + if (80 == resolve->port()) { + // On completing the first request, start another request for "a". + // Since caching is disabled, this will result in another async request. + final_request_.reset(new ResolveRequest( + resolve->resolver(), "a", 70, this)); + } + if (++num_requests_ == 5) { + // Test is done. + MessageLoop::current()->Quit(); + } + } + + private: + int num_requests_; + scoped_ptr<ResolveRequest> final_request_; + DISALLOW_COPY_AND_ASSIGN(StartWithinCallbackVerifier); +}; + +TEST_F(HostResolverTest, StartWithinCallback) { + // Use a capturing mapper, since the verifier needs to know what calls + // reached Map(). Also, the capturing mapper is initially blocked. + scoped_refptr<CapturingHostMapper> mapper = new CapturingHostMapper(); + ScopedHostMapper scoped_mapper(mapper.get()); + + // Turn off caching for this host resolver. + net::HostResolver host_resolver(0, 0); + + // The class will receive callbacks for when each resolve completes. It + // checks that the right things happened. + StartWithinCallbackVerifier verifier; + + // Start 4 requests, duplicating hosts "a". Since the mapper is + // blocked, these should all pile up until we signal it. + + ResolveRequest req1(&host_resolver, "a", 80, &verifier); + ResolveRequest req2(&host_resolver, "a", 81, &verifier); + ResolveRequest req3(&host_resolver, "a", 82, &verifier); + ResolveRequest req4(&host_resolver, "a", 83, &verifier); + + // Ready, Set, GO!!! + mapper->Signal(); + + // |verifier| will send quit message once all the requests have finished. + MessageLoop::current()->Run(); +} + } // namespace diff --git a/net/base/ssl_client_socket_unittest.cc b/net/base/ssl_client_socket_unittest.cc index d372e95..ab29cc4 100644 --- a/net/base/ssl_client_socket_unittest.cc +++ b/net/base/ssl_client_socket_unittest.cc @@ -77,7 +77,7 @@ TEST_F(SSLClientSocketTest, MAYBE_Connect) { TestCompletionCallback callback; int rv = resolver.Resolve(server_.kHostName, server_.kOKHTTPSPort, - &addr, NULL); + &addr, NULL, NULL); EXPECT_EQ(net::OK, rv); net::ClientSocket *transport = new net::TCPClientSocket(addr); @@ -115,7 +115,7 @@ TEST_F(SSLClientSocketTest, MAYBE_ConnectExpired) { TestCompletionCallback callback; int rv = resolver.Resolve(server_.kHostName, server_.kBadHTTPSPort, - &addr, NULL); + &addr, NULL, NULL); EXPECT_EQ(net::OK, rv); net::ClientSocket *transport = new net::TCPClientSocket(addr); @@ -152,7 +152,7 @@ TEST_F(SSLClientSocketTest, MAYBE_ConnectMismatched) { TestCompletionCallback callback; int rv = resolver.Resolve(server_.kMismatchedHostName, server_.kOKHTTPSPort, - &addr, NULL); + &addr, NULL, NULL); EXPECT_EQ(net::OK, rv); net::ClientSocket *transport = new net::TCPClientSocket(addr); @@ -194,7 +194,7 @@ TEST_F(SSLClientSocketTest, MAYBE_Read) { TestCompletionCallback callback; int rv = resolver.Resolve(server_.kHostName, server_.kOKHTTPSPort, - &addr, &callback); + &addr, &callback, NULL); EXPECT_EQ(net::ERR_IO_PENDING, rv); rv = callback.WaitForResult(); @@ -255,7 +255,7 @@ TEST_F(SSLClientSocketTest, MAYBE_Read_SmallChunks) { TestCompletionCallback callback; int rv = resolver.Resolve(server_.kHostName, server_.kOKHTTPSPort, - &addr, NULL); + &addr, NULL, NULL); EXPECT_EQ(net::OK, rv); net::ClientSocket *transport = new net::TCPClientSocket(addr); @@ -311,7 +311,7 @@ TEST_F(SSLClientSocketTest, MAYBE_Read_Interrupted) { TestCompletionCallback callback; int rv = resolver.Resolve(server_.kHostName, server_.kOKHTTPSPort, - &addr, NULL); + &addr, NULL, NULL); EXPECT_EQ(net::OK, rv); net::ClientSocket *transport = new net::TCPClientSocket(addr); diff --git a/net/base/ssl_test_util.cc b/net/base/ssl_test_util.cc index a3fe3b9..cac5d04 100644 --- a/net/base/ssl_test_util.cc +++ b/net/base/ssl_test_util.cc @@ -247,7 +247,7 @@ bool TestServerLauncher::WaitToStart(const std::string& host_name, int port) { // Otherwise tests can fail if they run faster than Python can start. net::AddressList addr; net::HostResolver resolver; - int rv = resolver.Resolve(host_name, port, &addr, NULL); + int rv = resolver.Resolve(host_name, port, &addr, NULL, NULL); if (rv != net::OK) return false; diff --git a/net/base/tcp_client_socket_pool.cc b/net/base/tcp_client_socket_pool.cc index 75875f2..7f7f13f 100644 --- a/net/base/tcp_client_socket_pool.cc +++ b/net/base/tcp_client_socket_pool.cc @@ -46,6 +46,7 @@ TCPClientSocketPool::ConnectingSocket::ConnectingSocket( callback_(this, &TCPClientSocketPool::ConnectingSocket::OnIOComplete)), pool_(pool), + resolver_(pool->GetHostResolver()), canceled_(false) { DCHECK(!ContainsKey(pool_->connecting_socket_map_, handle)); pool_->connecting_socket_map_[handle] = this; @@ -158,10 +159,12 @@ void TCPClientSocketPool::ConnectingSocket::Cancel() { TCPClientSocketPool::TCPClientSocketPool( int max_sockets_per_group, + HostResolver* host_resolver, ClientSocketFactory* client_socket_factory) : client_socket_factory_(client_socket_factory), idle_socket_count_(0), - max_sockets_per_group_(max_sockets_per_group) { + max_sockets_per_group_(max_sockets_per_group), + host_resolver_(host_resolver) { } TCPClientSocketPool::~TCPClientSocketPool() { diff --git a/net/base/tcp_client_socket_pool.h b/net/base/tcp_client_socket_pool.h index 8d4a3bd..82255c5 100644 --- a/net/base/tcp_client_socket_pool.h +++ b/net/base/tcp_client_socket_pool.h @@ -25,6 +25,7 @@ class ClientSocketFactory; class TCPClientSocketPool : public ClientSocketPool { public: TCPClientSocketPool(int max_sockets_per_group, + HostResolver* host_resolver, ClientSocketFactory* client_socket_factory); // ClientSocketPool methods: @@ -44,6 +45,10 @@ class TCPClientSocketPool : public ClientSocketPool { virtual void CloseIdleSockets(); + virtual HostResolver* GetHostResolver() const { + return host_resolver_; + } + virtual int idle_socket_count() const { return idle_socket_count_; } @@ -137,7 +142,7 @@ class TCPClientSocketPool : public ClientSocketPool { CompletionCallbackImpl<ConnectingSocket> callback_; scoped_ptr<ClientSocket> socket_; scoped_refptr<TCPClientSocketPool> pool_; - HostResolver resolver_; + SingleRequestHostResolver resolver_; AddressList addresses_; bool canceled_; @@ -185,6 +190,10 @@ class TCPClientSocketPool : public ClientSocketPool { // The maximum number of sockets kept per group. const int max_sockets_per_group_; + // The host resolver that will be used to do DNS lookups for connecting + // sockets. + HostResolver* host_resolver_; + DISALLOW_COPY_AND_ASSIGN(TCPClientSocketPool); }; diff --git a/net/base/tcp_client_socket_pool_unittest.cc b/net/base/tcp_client_socket_pool_unittest.cc index ce85871..2073497 100644 --- a/net/base/tcp_client_socket_pool_unittest.cc +++ b/net/base/tcp_client_socket_pool_unittest.cc @@ -198,13 +198,19 @@ int TestSocketRequest::completion_count = 0; class TCPClientSocketPoolTest : public testing::Test { protected: TCPClientSocketPoolTest() - : pool_(new TCPClientSocketPool(kMaxSocketsPerGroup, + // We disable caching here since these unit tests don't expect + // host resolving to be able to complete synchronously. + // TODO(eroman): enable caching. + : host_resolver_(0, 0), + pool_(new TCPClientSocketPool(kMaxSocketsPerGroup, + &host_resolver_, &client_socket_factory_)) {} virtual void SetUp() { TestSocketRequest::completion_count = 0; } + HostResolver host_resolver_; MockClientSocketFactory client_socket_factory_; scoped_refptr<ClientSocketPool> pool_; std::vector<TestSocketRequest*> request_order_; diff --git a/net/base/tcp_client_socket_unittest.cc b/net/base/tcp_client_socket_unittest.cc index 9dad6e2..fe37c71 100644 --- a/net/base/tcp_client_socket_unittest.cc +++ b/net/base/tcp_client_socket_unittest.cc @@ -87,7 +87,7 @@ void TCPClientSocketTest::SetUp() { AddressList addr; HostResolver resolver; - int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL); + int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL, NULL); CHECK(rv == OK); sock_.reset(new TCPClientSocket(addr)); } diff --git a/net/base/tcp_pinger_unittest.cc b/net/base/tcp_pinger_unittest.cc index 2ae6b1f..edf29a5 100644 --- a/net/base/tcp_pinger_unittest.cc +++ b/net/base/tcp_pinger_unittest.cc @@ -67,7 +67,7 @@ TEST_F(TCPPingerTest, Ping) { net::AddressList addr; net::HostResolver resolver; - int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL); + int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL, NULL); EXPECT_EQ(rv, net::OK); net::TCPPinger pinger(addr); @@ -82,7 +82,7 @@ TEST_F(TCPPingerTest, PingFail) { // "Kill" "server" listen_sock_ = NULL; - int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL); + int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL, NULL); EXPECT_EQ(rv, net::OK); net::TCPPinger pinger(addr); diff --git a/net/ftp/ftp_network_layer.cc b/net/ftp/ftp_network_layer.cc index e05b67f..fa0e87c 100644 --- a/net/ftp/ftp_network_layer.cc +++ b/net/ftp/ftp_network_layer.cc @@ -10,17 +10,18 @@ namespace net { -FtpNetworkLayer::FtpNetworkLayer() - : suspended_(false) { - session_ = new FtpNetworkSession(); +FtpNetworkLayer::FtpNetworkLayer(HostResolver* host_resolver) + : session_(new FtpNetworkSession(host_resolver)), + suspended_(false) { } FtpNetworkLayer::~FtpNetworkLayer() { } // static -FtpTransactionFactory* FtpNetworkLayer::CreateFactory() { - return new FtpNetworkLayer(); +FtpTransactionFactory* FtpNetworkLayer::CreateFactory( + HostResolver* host_resolver) { + return new FtpNetworkLayer(host_resolver); } FtpTransaction* FtpNetworkLayer::CreateTransaction() { diff --git a/net/ftp/ftp_network_layer.h b/net/ftp/ftp_network_layer.h index 71bd3b9..2d37eae 100644 --- a/net/ftp/ftp_network_layer.h +++ b/net/ftp/ftp_network_layer.h @@ -12,13 +12,14 @@ namespace net { class FtpNetworkSession; class FtpAuthCache; +class HostResolver; class FtpNetworkLayer : public FtpTransactionFactory { public: - FtpNetworkLayer(); + explicit FtpNetworkLayer(HostResolver* host_resolver); ~FtpNetworkLayer(); - static FtpTransactionFactory* CreateFactory(); + static FtpTransactionFactory* CreateFactory(HostResolver* host_resolver); // FtpTransactionFactory methods: virtual FtpTransaction* CreateTransaction(); diff --git a/net/ftp/ftp_network_session.h b/net/ftp/ftp_network_session.h index 13ab216..29c34d2 100644 --- a/net/ftp/ftp_network_session.h +++ b/net/ftp/ftp_network_session.h @@ -10,14 +10,19 @@ namespace net { +class HostResolver; + // This class holds session objects used by FtpNetworkTransaction objects. class FtpNetworkSession : public base::RefCounted<FtpNetworkSession> { public: - FtpNetworkSession() {} + explicit FtpNetworkSession(HostResolver* host_resolver) + : host_resolver_(host_resolver) {} + HostResolver* host_resolver() { return host_resolver_; } FtpAuthCache* auth_cache() { return &auth_cache_; } private: + HostResolver* host_resolver_; FtpAuthCache auth_cache_; }; diff --git a/net/ftp/ftp_network_transaction.cc b/net/ftp/ftp_network_transaction.cc index ae6e4fc..87f9217 100644 --- a/net/ftp/ftp_network_transaction.cc +++ b/net/ftp/ftp_network_transaction.cc @@ -34,6 +34,7 @@ FtpNetworkTransaction::FtpNetworkTransaction( user_callback_(NULL), session_(session), request_(NULL), + resolver_(session->host_resolver()), read_ctrl_buf_size_(kCtrlBufLen), response_message_buf_len_(0), read_data_buf_len_(0), diff --git a/net/ftp/ftp_network_transaction.h b/net/ftp/ftp_network_transaction.h index 0bb67e5..53f0b29 100644 --- a/net/ftp/ftp_network_transaction.h +++ b/net/ftp/ftp_network_transaction.h @@ -142,7 +142,8 @@ class FtpNetworkTransaction : public FtpTransaction { const FtpRequestInfo* request_; FtpResponseInfo response_; - HostResolver resolver_; + // Cancels the outstanding request on destruction. + SingleRequestHostResolver resolver_; AddressList addresses_; // User buffer and length passed to the Read method. diff --git a/net/http/http_cache.cc b/net/http/http_cache.cc index 497e141..7baa3fc 100644 --- a/net/http/http_cache.cc +++ b/net/http/http_cache.cc @@ -960,13 +960,15 @@ void HttpCache::Transaction::OnCacheReadCompleted(int result) { //----------------------------------------------------------------------------- -HttpCache::HttpCache(ProxyService* proxy_service, +HttpCache::HttpCache(HostResolver* host_resolver, + ProxyService* proxy_service, const std::wstring& cache_dir, int cache_size) : disk_cache_dir_(cache_dir), mode_(NORMAL), type_(DISK_CACHE), - network_layer_(HttpNetworkLayer::CreateFactory(proxy_service)), + network_layer_(HttpNetworkLayer::CreateFactory( + host_resolver, proxy_service)), ALLOW_THIS_IN_INITIALIZER_LIST(task_factory_(this)), in_memory_cache_(false), cache_size_(cache_size) { @@ -984,10 +986,13 @@ HttpCache::HttpCache(HttpNetworkSession* session, cache_size_(cache_size) { } -HttpCache::HttpCache(ProxyService* proxy_service, int cache_size) +HttpCache::HttpCache(HostResolver* host_resolver, + ProxyService* proxy_service, + int cache_size) : mode_(NORMAL), type_(MEMORY_CACHE), - network_layer_(HttpNetworkLayer::CreateFactory(proxy_service)), + network_layer_(HttpNetworkLayer::CreateFactory( + host_resolver, proxy_service)), ALLOW_THIS_IN_INITIALIZER_LIST(task_factory_(this)), in_memory_cache_(true), cache_size_(cache_size) { diff --git a/net/http/http_cache.h b/net/http/http_cache.h index 70aaf0d..9fca9ae 100644 --- a/net/http/http_cache.h +++ b/net/http/http_cache.h @@ -31,6 +31,7 @@ class Entry; namespace net { +class HostResolver; class HttpNetworkSession; class HttpRequestInfo; class HttpResponseInfo; @@ -57,7 +58,8 @@ class HttpCache : public HttpTransactionFactory { // Initialize the cache from the directory where its data is stored. The // disk cache is initialized lazily (by CreateTransaction) in this case. If // |cache_size| is zero, a default value will be calculated automatically. - HttpCache(ProxyService* proxy_service, + HttpCache(HostResolver* host_resolver, + ProxyService* proxy_service, const std::wstring& cache_dir, int cache_size); @@ -73,7 +75,9 @@ class HttpCache : public HttpTransactionFactory { // Initialize using an in-memory cache. The cache is initialized lazily // (by CreateTransaction) in this case. If |cache_size| is zero, a default // value will be calculated automatically. - HttpCache(ProxyService* proxy_service, int cache_size); + HttpCache(HostResolver* host_resolver, + ProxyService* proxy_service, + int cache_size); // Initialize the cache from its component parts, which is useful for // testing. The lifetime of the network_layer and disk_cache are managed by diff --git a/net/http/http_network_layer.cc b/net/http/http_network_layer.cc index b46012f..3a8122f 100644 --- a/net/http/http_network_layer.cc +++ b/net/http/http_network_layer.cc @@ -15,10 +15,11 @@ namespace net { // static HttpTransactionFactory* HttpNetworkLayer::CreateFactory( + HostResolver* host_resolver, ProxyService* proxy_service) { DCHECK(proxy_service); - return new HttpNetworkLayer(proxy_service); + return new HttpNetworkLayer(host_resolver, proxy_service); } // static @@ -31,8 +32,12 @@ HttpTransactionFactory* HttpNetworkLayer::CreateFactory( //----------------------------------------------------------------------------- -HttpNetworkLayer::HttpNetworkLayer(ProxyService* proxy_service) - : proxy_service_(proxy_service), session_(NULL), suspended_(false) { +HttpNetworkLayer::HttpNetworkLayer(HostResolver* host_resolver, + ProxyService* proxy_service) + : host_resolver_(host_resolver), + proxy_service_(proxy_service), + session_(NULL), + suspended_(false) { DCHECK(proxy_service_); } @@ -66,7 +71,7 @@ void HttpNetworkLayer::Suspend(bool suspend) { HttpNetworkSession* HttpNetworkLayer::GetSession() { if (!session_) { DCHECK(proxy_service_); - session_ = new HttpNetworkSession(proxy_service_, + session_ = new HttpNetworkSession(host_resolver_, proxy_service_, ClientSocketFactory::GetDefaultFactory()); } return session_; diff --git a/net/http/http_network_layer.h b/net/http/http_network_layer.h index 2011a6c..acee57e 100644 --- a/net/http/http_network_layer.h +++ b/net/http/http_network_layer.h @@ -11,14 +11,16 @@ namespace net { +class HostResolver; class HttpNetworkSession; class ProxyInfo; class ProxyService; class HttpNetworkLayer : public HttpTransactionFactory { public: - // |proxy_service| must remain valid for the lifetime of HttpNetworkLayer. - explicit HttpNetworkLayer(ProxyService* proxy_service); + // |proxy_service| and |host_resolver| must remain valid for the lifetime of + // HttpNetworkLayer. + HttpNetworkLayer(HostResolver* host_resolver, ProxyService* proxy_service); // Construct a HttpNetworkLayer with an existing HttpNetworkSession which // contains a valid ProxyService. explicit HttpNetworkLayer(HttpNetworkSession* session); @@ -26,7 +28,8 @@ class HttpNetworkLayer : public HttpTransactionFactory { // This function hides the details of how a network layer gets instantiated // and allows other implementations to be substituted. - static HttpTransactionFactory* CreateFactory(ProxyService* proxy_service); + static HttpTransactionFactory* CreateFactory(HostResolver* host_resolver, + ProxyService* proxy_service); // Create a transaction factory that instantiate a network layer over an // existing network session. Network session contains some valuable // information (e.g. authentication data) that we want to share across @@ -43,6 +46,9 @@ class HttpNetworkLayer : public HttpTransactionFactory { HttpNetworkSession* GetSession(); private: + // The host resolver being used for the session. + HostResolver* host_resolver_; + // The proxy service being used for the session. ProxyService* proxy_service_; diff --git a/net/http/http_network_layer_unittest.cc b/net/http/http_network_layer_unittest.cc index 33891b4..335cb93 100644 --- a/net/http/http_network_layer_unittest.cc +++ b/net/http/http_network_layer_unittest.cc @@ -26,15 +26,17 @@ class HttpNetworkLayerTest : public PlatformTest { }; TEST_F(HttpNetworkLayerTest, CreateAndDestroy) { + net::HostResolver host_resolver; scoped_ptr<net::ProxyService> proxy_service(net::ProxyService::CreateNull()); - net::HttpNetworkLayer factory(proxy_service.get()); + net::HttpNetworkLayer factory(&host_resolver, proxy_service.get()); scoped_ptr<net::HttpTransaction> trans(factory.CreateTransaction()); } TEST_F(HttpNetworkLayerTest, Suspend) { + net::HostResolver host_resolver; scoped_ptr<net::ProxyService> proxy_service(net::ProxyService::CreateNull()); - net::HttpNetworkLayer factory(proxy_service.get()); + net::HttpNetworkLayer factory(&host_resolver, proxy_service.get()); scoped_ptr<net::HttpTransaction> trans(factory.CreateTransaction()); trans.reset(); @@ -50,8 +52,9 @@ TEST_F(HttpNetworkLayerTest, Suspend) { } TEST_F(HttpNetworkLayerTest, GoogleGET) { + net::HostResolver host_resolver; scoped_ptr<net::ProxyService> proxy_service(net::ProxyService::CreateNull()); - net::HttpNetworkLayer factory(proxy_service.get()); + net::HttpNetworkLayer factory(&host_resolver, proxy_service.get()); TestCompletionCallback callback; diff --git a/net/http/http_network_session.h b/net/http/http_network_session.h index cb023a0..c7ebb3c 100644 --- a/net/http/http_network_session.h +++ b/net/http/http_network_session.h @@ -13,21 +13,24 @@ namespace net { class ClientSocketFactory; +class HostResolver; class ProxyService; // This class holds session objects used by HttpNetworkTransaction objects. class HttpNetworkSession : public base::RefCounted<HttpNetworkSession> { public: - HttpNetworkSession(ProxyService* proxy_service, + HttpNetworkSession(HostResolver* host_resolver, ProxyService* proxy_service, ClientSocketFactory* client_socket_factory) : connection_pool_(new TCPClientSocketPool( - max_sockets_per_group_, client_socket_factory)), + max_sockets_per_group_, host_resolver, client_socket_factory)), + host_resolver_(host_resolver), proxy_service_(proxy_service) { DCHECK(proxy_service); } HttpAuthCache* auth_cache() { return &auth_cache_; } ClientSocketPool* connection_pool() { return connection_pool_; } + HostResolver* host_resolver() { return host_resolver_; } ProxyService* proxy_service() { return proxy_service_; } #if defined(OS_WIN) SSLConfigService* ssl_config_service() { return &ssl_config_service_; } @@ -43,6 +46,7 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession> { HttpAuthCache auth_cache_; scoped_refptr<ClientSocketPool> connection_pool_; + HostResolver* host_resolver_; ProxyService* proxy_service_; #if defined(OS_WIN) // TODO(port): Port the SSLConfigService class to Linux and Mac OS X. diff --git a/net/http/http_network_transaction.cc b/net/http/http_network_transaction.cc index 8bfe386..b4185df 100644 --- a/net/http/http_network_transaction.cc +++ b/net/http/http_network_transaction.cc @@ -13,7 +13,6 @@ #include "net/base/client_socket_factory.h" #include "net/base/connection_type_histograms.h" #include "net/base/dns_resolution_observer.h" -#include "net/base/host_resolver.h" #include "net/base/io_buffer.h" #include "net/base/load_flags.h" #include "net/base/net_errors.h" diff --git a/net/http/http_network_transaction.h b/net/http/http_network_transaction.h index af3a5f4..fb21384 100644 --- a/net/http/http_network_transaction.h +++ b/net/http/http_network_transaction.h @@ -292,9 +292,6 @@ class HttpNetworkTransaction : public HttpTransaction { ProxyService::PacRequest* pac_request_; ProxyInfo proxy_info_; - HostResolver resolver_; - AddressList addresses_; - ClientSocketFactory* socket_factory_; ClientSocketHandle connection_; scoped_ptr<HttpStream> http_stream_; diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index 852612f..e5c4ccc 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -41,6 +41,7 @@ class SessionDependencies { explicit SessionDependencies(ProxyService* proxy_service) : proxy_service(proxy_service) {} + HostResolver host_resolver; scoped_ptr<ProxyService> proxy_service; MockClientSocketFactory socket_factory; }; @@ -53,7 +54,8 @@ ProxyService* CreateFixedProxyService(const std::string& proxy) { HttpNetworkSession* CreateSession(SessionDependencies* session_deps) { - return new HttpNetworkSession(session_deps->proxy_service.get(), + return new HttpNetworkSession(&session_deps->host_resolver, + session_deps->proxy_service.get(), &session_deps->socket_factory); } diff --git a/net/net.gyp b/net/net.gyp index 37bc3a7..fbac3e0 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -80,6 +80,8 @@ 'base/gzip_filter.h', 'base/gzip_header.cc', 'base/gzip_header.h', + 'base/host_cache.cc', + 'base/host_cache.h', 'base/host_resolver.cc', 'base/host_resolver.h', 'base/io_buffer.cc', @@ -394,6 +396,7 @@ ], 'msvs_guid': 'E99DA267-BE90-4F45-88A1-6919DB2C7567', 'sources': [ + 'base/address_list_unittest.cc', 'base/base64_unittest.cc', 'base/bzip2_filter_unittest.cc', 'base/cookie_monster_unittest.cc', @@ -406,6 +409,7 @@ 'base/filter_unittest.h', 'base/force_tls_state_unittest.cc', 'base/gzip_filter_unittest.cc', + 'base/host_cache_unittest.cc', 'base/host_resolver_unittest.cc', 'base/listen_socket_unittest.cc', 'base/listen_socket_unittest.h', diff --git a/net/proxy/proxy_resolver_perftest.cc b/net/proxy/proxy_resolver_perftest.cc index 3ad7d2a0..85ad554 100644 --- a/net/proxy/proxy_resolver_perftest.cc +++ b/net/proxy/proxy_resolver_perftest.cc @@ -185,7 +185,12 @@ TEST(ProxyResolverPerfTest, ProxyResolverMac) { #endif TEST(ProxyResolverPerfTest, ProxyResolverV8) { - net::ProxyResolverV8 resolver; + net::HostResolver host_resolver; + + net::ProxyResolverV8::JSBindings* js_bindings = + net::ProxyResolverV8::CreateDefaultBindings(&host_resolver, NULL); + + net::ProxyResolverV8 resolver(js_bindings); PacPerfSuiteRunner runner(&resolver, "ProxyResolverV8"); runner.RunAllTests(); } diff --git a/net/proxy/proxy_resolver_v8.cc b/net/proxy/proxy_resolver_v8.cc index a5d1c76..38e0a83 100644 --- a/net/proxy/proxy_resolver_v8.cc +++ b/net/proxy/proxy_resolver_v8.cc @@ -4,7 +4,10 @@ #include "net/proxy/proxy_resolver_v8.h" +#include "base/compiler_specific.h" #include "base/logging.h" +#include "base/message_loop.h" +#include "base/waitable_event.h" #include "base/string_util.h" #include "googleurl/src/gurl.h" #include "net/base/address_list.h" @@ -50,9 +53,78 @@ bool V8ObjectToString(v8::Handle<v8::Value> object, std::string* result) { return true; } +// Wrapper around HostResolver to give a sync API while running the resolve +// in async mode on |host_resolver_loop|. If |host_resolver_loop| is NULL, +// runs sync on the current thread (this mode is just used by testing). +class SyncHostResolverBridge + : public base::RefCountedThreadSafe<SyncHostResolverBridge> { + public: + SyncHostResolverBridge(HostResolver* host_resolver, + MessageLoop* host_resolver_loop) + : host_resolver_(host_resolver), + host_resolver_loop_(host_resolver_loop), + event_(false, false), + ALLOW_THIS_IN_INITIALIZER_LIST( + callback_(this, &SyncHostResolverBridge::OnResolveCompletion)) { + } + + // Run the resolve on host_resolver_loop, and wait for result. + int Resolve(const std::string& hostname, net::AddressList* addresses) { + int kPort = 80; // Doesn't matter. + + // Hack for tests -- run synchronously on current thread. + if (!host_resolver_loop_) + return host_resolver_->Resolve(hostname, kPort, addresses, NULL, NULL); + + // Otherwise start an async resolve on the resolver's thread. + host_resolver_loop_->PostTask(FROM_HERE, NewRunnableMethod(this, + &SyncHostResolverBridge::StartResolve, hostname, kPort, addresses)); + + // Wait for the resolve to complete in the resolver's thread. + event_.Wait(); + return err_; + } + + private: + // Called on host_resolver_loop_. + void StartResolve(const std::string& hostname, + int port, + net::AddressList* addresses) { + DCHECK_EQ(host_resolver_loop_, MessageLoop::current()); + int error = host_resolver_->Resolve( + hostname, port, addresses, &callback_, NULL); + if (error != ERR_IO_PENDING) + OnResolveCompletion(error); // Completed synchronously. + } + + // Called on host_resolver_loop_. + void OnResolveCompletion(int result) { + DCHECK_EQ(host_resolver_loop_, MessageLoop::current()); + err_ = result; + event_.Signal(); + } + + HostResolver* host_resolver_; + MessageLoop* host_resolver_loop_; + + // Event to notify completion of resolve request. + base::WaitableEvent event_; + + // Callback for when the resolve completes on host_resolver_loop_. + net::CompletionCallbackImpl<SyncHostResolverBridge> callback_; + + // The result from the result request (set by in host_resolver_loop_). + int err_; +}; + // JSBIndings implementation. class DefaultJSBindings : public ProxyResolverV8::JSBindings { public: + DefaultJSBindings(HostResolver* host_resolver, + MessageLoop* host_resolver_loop) + : host_resolver_(new SyncHostResolverBridge( + host_resolver, host_resolver_loop)) {} + // Handler for "alert(message)". virtual void Alert(const std::string& message) { LOG(INFO) << "PAC-alert: " << message; @@ -71,10 +143,9 @@ class DefaultJSBindings : public ProxyResolverV8::JSBindings { if (host.empty()) return std::string(); - // Try to resolve synchronously. + // Do a sync resolve of the hostname. net::AddressList address_list; - const int kPort = 80; // Doesn't matter what this is. - int result = host_resolver_.Resolve(host, kPort, &address_list, NULL); + int result = host_resolver_->Resolve(host, &address_list); if (result != OK) return std::string(); // Failed. @@ -96,7 +167,7 @@ class DefaultJSBindings : public ProxyResolverV8::JSBindings { } private: - HostResolver host_resolver_; + scoped_refptr<SyncHostResolverBridge> host_resolver_; }; } // namespace @@ -106,7 +177,7 @@ class DefaultJSBindings : public ProxyResolverV8::JSBindings { class ProxyResolverV8::Context { public: Context(JSBindings* js_bindings, const std::string& pac_data) - : js_bindings_(js_bindings) { + : js_bindings_(js_bindings) { DCHECK(js_bindings != NULL); InitV8(pac_data); } @@ -266,20 +337,12 @@ class ProxyResolverV8::Context { } JSBindings* js_bindings_; - HostResolver host_resolver_; v8::Persistent<v8::External> v8_this_; v8::Persistent<v8::Context> v8_context_; }; // ProxyResolverV8 ------------------------------------------------------------ -// the |false| argument to ProxyResolver means the ProxyService will handle -// downloading of the PAC script, and notify changes through SetPacScript(). -ProxyResolverV8::ProxyResolverV8() - : ProxyResolver(false /*does_fetch*/), - js_bindings_(new DefaultJSBindings()) { -} - ProxyResolverV8::ProxyResolverV8( ProxyResolverV8::JSBindings* custom_js_bindings) : ProxyResolver(false), js_bindings_(custom_js_bindings) { @@ -305,4 +368,10 @@ void ProxyResolverV8::SetPacScript(const std::string& data) { context_.reset(new Context(js_bindings_.get(), data)); } +// static +ProxyResolverV8::JSBindings* ProxyResolverV8::CreateDefaultBindings( + HostResolver* host_resolver, MessageLoop* host_resolver_loop) { + return new DefaultJSBindings(host_resolver, host_resolver_loop); +} + } // namespace net diff --git a/net/proxy/proxy_resolver_v8.h b/net/proxy/proxy_resolver_v8.h index 219560c..3bf2bec 100644 --- a/net/proxy/proxy_resolver_v8.h +++ b/net/proxy/proxy_resolver_v8.h @@ -10,8 +10,12 @@ #include "base/scoped_ptr.h" #include "net/proxy/proxy_resolver.h" +class MessageLoop; + namespace net { +class HostResolver; + // Implementation of ProxyResolver that uses V8 to evaluate PAC scripts. // // ---------------------------------------------------------------------------- @@ -32,19 +36,6 @@ namespace net { // and does not use locking since it expects to be alone. class ProxyResolverV8 : public ProxyResolver { public: - // Constructs a ProxyResolverV8 with default javascript bindings. - // - // The default javascript bindings will: - // - Send script error messages to LOG(INFO) - // - Send script alert()s to LOG(INFO) - // - Use the default host mapper to service dnsResolve(), synchronously - // on the V8 thread. - // - // For clients that need more control (for example, sending the script output - // to a UI widget), use the ProxyResolverV8(JSBindings*) and specify your - // own bindings. - ProxyResolverV8(); - class JSBindings; // Constructs a ProxyResolverV8 with custom bindings. ProxyResolverV8 takes @@ -62,6 +53,21 @@ class ProxyResolverV8 : public ProxyResolver { JSBindings* js_bindings() const { return js_bindings_.get(); } + // Creates a default javascript bindings implementation that will: + // - Send script error messages to LOG(INFO) + // - Send script alert()s to LOG(INFO) + // - Use the provided host mapper to service dnsResolve(). + // + // For clients that need more control (for example, sending the script output + // to a UI widget), use the ProxyResolverV8(JSBindings*) and specify your + // own bindings. + // + // |host_resolver| will be used in async mode on |host_resolver_loop|. If + // |host_resolver_loop| is NULL, then |host_resolver| will be used in sync + // mode on the PAC thread. + static JSBindings* CreateDefaultBindings(HostResolver* host_resolver, + MessageLoop* host_resolver_loop); + private: // Context holds the Javascript state for the most recently loaded PAC // script. It corresponds with the data from the last call to diff --git a/net/proxy/proxy_resolver_v8_unittest.cc b/net/proxy/proxy_resolver_v8_unittest.cc index dff8843..0e8ba47 100644 --- a/net/proxy/proxy_resolver_v8_unittest.cc +++ b/net/proxy/proxy_resolver_v8_unittest.cc @@ -6,6 +6,7 @@ #include "base/string_util.h" #include "base/path_service.h" #include "googleurl/src/gurl.h" +#include "net/base/host_resolver.h" #include "net/base/net_errors.h" #include "net/proxy/proxy_resolver_v8.h" #include "net/proxy/proxy_info.h" @@ -377,8 +378,9 @@ TEST(ProxyResolverV8Test, V8Bindings) { TEST(ProxyResolverV8DefaultBindingsTest, DnsResolve) { // Get a hold of a DefaultJSBindings* (it is a hidden impl class). - net::ProxyResolverV8 resolver; - net::ProxyResolverV8::JSBindings* bindings = resolver.js_bindings(); + net::HostResolver host_resolver; + scoped_ptr<net::ProxyResolverV8::JSBindings> bindings( + net::ProxyResolverV8::CreateDefaultBindings(&host_resolver, NULL)); // Considered an error. EXPECT_EQ("", bindings->DnsResolve("")); @@ -428,8 +430,9 @@ TEST(ProxyResolverV8DefaultBindingsTest, DnsResolve) { TEST(ProxyResolverV8DefaultBindingsTest, MyIpAddress) { // Get a hold of a DefaultJSBindings* (it is a hidden impl class). - net::ProxyResolverV8 resolver; - net::ProxyResolverV8::JSBindings* bindings = resolver.js_bindings(); + net::HostResolver host_resolver; + scoped_ptr<net::ProxyResolverV8::JSBindings> bindings( + net::ProxyResolverV8::CreateDefaultBindings(&host_resolver, NULL)); // Our ip address is always going to be 127.0.0.1, since we are using a // mock host mapper when running in unit-test mode. diff --git a/net/proxy/proxy_script_fetcher_unittest.cc b/net/proxy/proxy_script_fetcher_unittest.cc index 3614cfe..5b16738b 100644 --- a/net/proxy/proxy_script_fetcher_unittest.cc +++ b/net/proxy/proxy_script_fetcher_unittest.cc @@ -30,15 +30,18 @@ class RequestContext : public URLRequestContext { public: RequestContext() { net::ProxyConfig no_proxy; + host_resolver_ = new net::HostResolver; proxy_service_ = net::ProxyService::CreateFixed(no_proxy); http_transaction_factory_ = - new net::HttpCache(net::HttpNetworkLayer::CreateFactory(proxy_service_), - disk_cache::CreateInMemoryCacheBackend(0)); + new net::HttpCache(net::HttpNetworkLayer::CreateFactory( + host_resolver_, proxy_service_), + disk_cache::CreateInMemoryCacheBackend(0)); } ~RequestContext() { delete http_transaction_factory_; delete proxy_service_; + delete host_resolver_; } }; diff --git a/net/proxy/proxy_service.cc b/net/proxy/proxy_service.cc index 65db636..3a0cd7f 100644 --- a/net/proxy/proxy_service.cc +++ b/net/proxy/proxy_service.cc @@ -23,6 +23,7 @@ #endif #include "net/proxy/proxy_resolver.h" #include "net/proxy/proxy_resolver_v8.h" +#include "net/url_request/url_request_context.h" using base::TimeDelta; using base::TimeTicks; @@ -217,8 +218,18 @@ ProxyService* ProxyService::Create( new ProxyConfigServiceFixed(*pc) : CreateSystemProxyConfigService(io_loop); - ProxyResolver* proxy_resolver = use_v8_resolver ? - new ProxyResolverV8() : CreateNonV8ProxyResolver(); + ProxyResolver* proxy_resolver; + + if (use_v8_resolver) { + // Send javascript errors and alerts to LOG(INFO). + HostResolver* host_resolver = url_request_context->host_resolver(); + ProxyResolverV8::JSBindings* js_bindings = + ProxyResolverV8::CreateDefaultBindings(host_resolver, io_loop); + + proxy_resolver = new ProxyResolverV8(js_bindings); + } else { + proxy_resolver = CreateNonV8ProxyResolver(); + } ProxyService* proxy_service = new ProxyService( proxy_config_service, proxy_resolver); diff --git a/net/tools/fetch/fetch_client.cc b/net/tools/fetch/fetch_client.cc index d49b3f1..349eed9 100644 --- a/net/tools/fetch/fetch_client.cc +++ b/net/tools/fetch/fetch_client.cc @@ -9,6 +9,7 @@ #include "base/stats_counters.h" #include "base/string_util.h" #include "net/base/completion_callback.h" +#include "net/base/host_resolver.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/http/http_cache.h" @@ -125,12 +126,13 @@ int main(int argc, char**argv) { // Do work here. MessageLoop loop; + net::HostResolver host_resolver; scoped_ptr<net::ProxyService> proxy_service(net::ProxyService::CreateNull()); net::HttpTransactionFactory* factory = NULL; if (use_cache) - factory = new net::HttpCache(proxy_service.get(), 0); + factory = new net::HttpCache(&host_resolver, proxy_service.get(), 0); else - factory = new net::HttpNetworkLayer(proxy_service.get()); + factory = new net::HttpNetworkLayer(&host_resolver, proxy_service.get()); { StatsCounterTimer driver_time("FetchClient.total_time"); diff --git a/net/url_request/url_request_context.h b/net/url_request/url_request_context.h index 8e32c97..2c9f6fa 100644 --- a/net/url_request/url_request_context.h +++ b/net/url_request/url_request_context.h @@ -19,6 +19,7 @@ namespace net { class CookieMonster; class ForceTLSState; class FtpTransactionFactory; +class HostResolver; class HttpTransactionFactory; class ProxyService; } @@ -28,13 +29,18 @@ class URLRequestContext : public base::RefCountedThreadSafe<URLRequestContext> { public: URLRequestContext() - : proxy_service_(NULL), + : host_resolver_(NULL), + proxy_service_(NULL), http_transaction_factory_(NULL), ftp_transaction_factory_(NULL), cookie_store_(NULL), force_tls_state_(NULL) { } + net::HostResolver* host_resolver() const { + return host_resolver_; + } + // Get the proxy service for this context. net::ProxyService* proxy_service() const { return proxy_service_; @@ -88,6 +94,7 @@ class URLRequestContext : // The following members are expected to be initialized and owned by // subclasses. + net::HostResolver* host_resolver_; net::ProxyService* proxy_service_; net::HttpTransactionFactory* http_transaction_factory_; net::FtpTransactionFactory* ftp_transaction_factory_; diff --git a/net/url_request/url_request_unittest.cc b/net/url_request/url_request_unittest.cc index 6551df1..18ff111 100644 --- a/net/url_request/url_request_unittest.cc +++ b/net/url_request/url_request_unittest.cc @@ -45,10 +45,12 @@ namespace { class URLRequestHttpCacheContext : public URLRequestContext { public: URLRequestHttpCacheContext() { + host_resolver_ = new net::HostResolver; proxy_service_ = net::ProxyService::CreateNull(); http_transaction_factory_ = - new net::HttpCache(net::HttpNetworkLayer::CreateFactory(proxy_service_), - disk_cache::CreateInMemoryCacheBackend(0)); + new net::HttpCache( + net::HttpNetworkLayer::CreateFactory(host_resolver_, proxy_service_), + disk_cache::CreateInMemoryCacheBackend(0)); // In-memory cookie store. cookie_store_ = new net::CookieMonster(); } @@ -57,6 +59,7 @@ class URLRequestHttpCacheContext : public URLRequestContext { delete cookie_store_; delete http_transaction_factory_; delete proxy_service_; + delete host_resolver_; } }; diff --git a/net/url_request/url_request_unittest.h b/net/url_request/url_request_unittest.h index d4e2b82..1c880fd 100644 --- a/net/url_request/url_request_unittest.h +++ b/net/url_request/url_request_unittest.h @@ -21,6 +21,7 @@ #include "base/thread.h" #include "base/time.h" #include "base/waitable_event.h" +#include "net/base/host_resolver.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/base/ssl_test_util.h" @@ -42,22 +43,27 @@ using base::TimeDelta; class TestURLRequestContext : public URLRequestContext { public: TestURLRequestContext() { + host_resolver_ = new net::HostResolver; proxy_service_ = net::ProxyService::CreateNull(); http_transaction_factory_ = - net::HttpNetworkLayer::CreateFactory(proxy_service_); + net::HttpNetworkLayer::CreateFactory(host_resolver_, + proxy_service_); } explicit TestURLRequestContext(const std::string& proxy) { + host_resolver_ = new net::HostResolver; net::ProxyConfig proxy_config; proxy_config.proxy_rules.ParseFromString(proxy); proxy_service_ = net::ProxyService::CreateFixed(proxy_config); http_transaction_factory_ = - net::HttpNetworkLayer::CreateFactory(proxy_service_); + net::HttpNetworkLayer::CreateFactory(host_resolver_, + proxy_service_); } virtual ~TestURLRequestContext() { delete http_transaction_factory_; delete proxy_service_; + delete host_resolver_; } }; |