diff options
author | Ben Murdoch <benm@google.com> | 2011-01-07 14:18:56 +0000 |
---|---|---|
committer | Ben Murdoch <benm@google.com> | 2011-01-11 10:23:13 +0000 |
commit | 201ade2fbba22bfb27ae029f4d23fca6ded109a0 (patch) | |
tree | b793f4ed916f73cf18357ea467ff3deb5ffb5b52 /net | |
parent | d8c4c37a7d0961944bfdfaa117d5c68c8e129c97 (diff) | |
download | external_chromium-201ade2fbba22bfb27ae029f4d23fca6ded109a0.zip external_chromium-201ade2fbba22bfb27ae029f4d23fca6ded109a0.tar.gz external_chromium-201ade2fbba22bfb27ae029f4d23fca6ded109a0.tar.bz2 |
Merge chromium at 9.0.597.55: Initial merge by git.
Change-Id: Id686a88437441ec7e17abb3328a404c7b6c3c6ad
Diffstat (limited to 'net')
173 files changed, 8332 insertions, 2039 deletions
diff --git a/net/base/cert_database_nss_unittest.cc b/net/base/cert_database_nss_unittest.cc index c68b6fd..5056e5d 100644 --- a/net/base/cert_database_nss_unittest.cc +++ b/net/base/cert_database_nss_unittest.cc @@ -14,6 +14,7 @@ #include "base/nss_util_internal.h" #include "base/path_service.h" #include "base/scoped_temp_dir.h" +#include "base/singleton.h" #include "base/string_util.h" #include "base/utf_string_conversions.h" #include "net/base/cert_database.h" diff --git a/net/base/cert_test_util.cc b/net/base/cert_test_util.cc index d5c678e..df00b9d 100644 --- a/net/base/cert_test_util.cc +++ b/net/base/cert_test_util.cc @@ -32,7 +32,7 @@ X509Certificate* AddTemporaryRootCertToStore(X509* x509_cert) { unsigned long error_code = ERR_get_error(); if (ERR_GET_LIB(error_code) != ERR_LIB_X509 || ERR_GET_REASON(error_code) != X509_R_CERT_ALREADY_IN_HASH_TABLE) { - base::ClearOpenSSLERRStack(); + base::ClearOpenSSLERRStack(FROM_HERE); return NULL; } } diff --git a/net/base/cert_verifier.cc b/net/base/cert_verifier.cc index 4e94133..ae910b4 100644 --- a/net/base/cert_verifier.cc +++ b/net/base/cert_verifier.cc @@ -8,7 +8,9 @@ #include <private/pprthred.h> // PR_DetatchThread #endif -#include "base/message_loop.h" +#include "base/lock.h" +#include "base/message_loop_proxy.h" +#include "base/scoped_ptr.h" #include "base/worker_pool.h" #include "net/base/cert_verify_result.h" #include "net/base/net_errors.h" @@ -31,7 +33,7 @@ class CertVerifier::Request : verifier_(verifier), verify_result_(verify_result), callback_(callback), - origin_loop_(MessageLoop::current()), + origin_loop_proxy_(base::MessageLoopProxy::CreateForCurrentThread()), error_(OK) { } @@ -49,20 +51,16 @@ class CertVerifier::Request : PR_DetachThread(); #endif - Task* reply = NewRunnableMethod(this, &Request::DoCallback); + scoped_ptr<Task> reply(NewRunnableMethod(this, &Request::DoCallback)); // The origin loop could go away while we are trying to post to it, so we // need to call its PostTask method inside a lock. See ~CertVerifier. - { - AutoLock locked(origin_loop_lock_); - if (origin_loop_) { - origin_loop_->PostTask(FROM_HERE, reply); - reply = NULL; - } + AutoLock locked(origin_loop_proxy_lock_); + if (origin_loop_proxy_) { + bool posted = origin_loop_proxy_->PostTask(FROM_HERE, reply.release()); + // TODO(willchan): Fix leaks and then change this to a DCHECK. + LOG_IF(ERROR, !posted) << "Leaked CertVerifier!"; } - - // Does nothing if it got posted. - delete reply; } void DoCallback() { @@ -85,8 +83,8 @@ class CertVerifier::Request : void Cancel() { verifier_ = NULL; - AutoLock locked(origin_loop_lock_); - origin_loop_ = NULL; + AutoLock locked(origin_loop_proxy_lock_); + origin_loop_proxy_ = NULL; } private: @@ -106,8 +104,13 @@ class CertVerifier::Request : CompletionCallback* callback_; // Used to post ourselves onto the origin thread. - Lock origin_loop_lock_; - MessageLoop* origin_loop_; + Lock origin_loop_proxy_lock_; + // Use a MessageLoopProxy in case the owner of the CertVerifier is leaked, so + // this code won't crash: http://crbug.com/42275. If this is leaked, then it + // doesn't get Cancel()'d, so |origin_loop_proxy_| doesn't get NULL'd out. If + // the MessageLoop goes away, then if we had used a MessageLoop, this would + // crash. + scoped_refptr<base::MessageLoopProxy> origin_loop_proxy_; // Assigned on the worker thread, read on the origin thread. int error_; diff --git a/net/base/dns_reload_timer.cc b/net/base/dns_reload_timer.cc index 5931c5b..1bfe535 100644 --- a/net/base/dns_reload_timer.cc +++ b/net/base/dns_reload_timer.cc @@ -5,11 +5,11 @@ #include "net/base/dns_reload_timer.h" #if defined(OS_POSIX) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) -#include "base/singleton.h" +#include "base/lazy_instance.h" #include "base/thread_local_storage.h" #include "base/time.h" -namespace net { +namespace { // On Linux/BSD, changes to /etc/resolv.conf can go unnoticed thus resulting // in DNS queries failing either because nameservers are unknown on startup @@ -58,7 +58,7 @@ class DnsReloadTimer { } private: - friend struct DefaultSingletonTraits<DnsReloadTimer>; + friend struct base::DefaultLazyInstanceTraits<DnsReloadTimer>; DnsReloadTimer() { // During testing the DnsReloadTimer Singleton may be created and destroyed @@ -81,8 +81,16 @@ class DnsReloadTimer { // static ThreadLocalStorage::Slot DnsReloadTimer::tls_index_(base::LINKER_INITIALIZED); +base::LazyInstance<DnsReloadTimer, + base::LeakyLazyInstanceTraits<DnsReloadTimer> > + g_dns_reload_timer(base::LINKER_INITIALIZED); + +} // namespace + +namespace net { + bool DnsReloadTimerHasExpired() { - DnsReloadTimer* dns_timer = Singleton<DnsReloadTimer>::get(); + DnsReloadTimer* dns_timer = g_dns_reload_timer.Pointer(); return dns_timer->Expired(); } diff --git a/net/base/ev_root_ca_metadata.cc b/net/base/ev_root_ca_metadata.cc index 661b652..a721357 100644 --- a/net/base/ev_root_ca_metadata.cc +++ b/net/base/ev_root_ca_metadata.cc @@ -13,8 +13,8 @@ #include <stdlib.h> #endif +#include "base/lazy_instance.h" #include "base/logging.h" -#include "base/singleton.h" namespace net { @@ -283,9 +283,13 @@ const EVRootCAMetadata::PolicyOID EVRootCAMetadata::policy_oids_[] = { }; #endif +static base::LazyInstance<EVRootCAMetadata, + base::LeakyLazyInstanceTraits<EVRootCAMetadata> > + g_ev_root_ca_metadata(base::LINKER_INITIALIZED); + // static EVRootCAMetadata* EVRootCAMetadata::GetInstance() { - return Singleton<EVRootCAMetadata>::get(); + return g_ev_root_ca_metadata.Pointer(); } bool EVRootCAMetadata::GetPolicyOID( diff --git a/net/base/ev_root_ca_metadata.h b/net/base/ev_root_ca_metadata.h index e0961f3..832ebe2 100644 --- a/net/base/ev_root_ca_metadata.h +++ b/net/base/ev_root_ca_metadata.h @@ -17,8 +17,10 @@ #include "net/base/x509_certificate.h" +namespace base { template <typename T> -struct DefaultSingletonTraits; +struct DefaultLazyInstanceTraits; +} // namespace base namespace net { @@ -55,7 +57,7 @@ class EVRootCAMetadata { PolicyOID policy_oid) const; private: - friend struct DefaultSingletonTraits<EVRootCAMetadata>; + friend struct base::DefaultLazyInstanceTraits<EVRootCAMetadata>; typedef std::map<SHA1Fingerprint, PolicyOID, SHA1FingerprintLessThan> PolicyOidMap; diff --git a/net/base/keygen_handler_unittest.cc b/net/base/keygen_handler_unittest.cc index 62c5191..d3bf4f5 100644 --- a/net/base/keygen_handler_unittest.cc +++ b/net/base/keygen_handler_unittest.cc @@ -16,6 +16,7 @@ #include "base/logging.h" #include "base/nss_util.h" #include "base/task.h" +#include "base/thread_restrictions.h" #include "base/waitable_event.h" #include "base/worker_pool.h" #include "testing/gtest/include/gtest/gtest.h" @@ -90,6 +91,9 @@ class ConcurrencyTestTask : public Task { } virtual void Run() { + // We allow Singleton use on the worker thread here since we use a + // WaitableEvent to synchronize, so it's safe. + base::ThreadRestrictions::ScopedAllowSingleton scoped_allow_singleton; KeygenHandler handler(768, "some challenge", GURL("http://www.example.com")); handler.set_stores_key(false); // Don't leave the key-pair behind. diff --git a/net/base/load_flags_list.h b/net/base/load_flags_list.h index a7181f2..4f5460f 100644 --- a/net/base/load_flags_list.h +++ b/net/base/load_flags_list.h @@ -97,3 +97,8 @@ LOAD_FLAG(SUB_FRAME, 1 << 21) // respected if renderer has CanReadRawCookies capability in the security // policy. LOAD_FLAG(REPORT_RAW_HEADERS, 1 << 22) + +// Indicates that this load was motivated by the rel=prefetch feature, +// and is (in theory) not intended for the current frame. +LOAD_FLAG(PREFETCH, 1 << 23) + diff --git a/net/base/net_error_list.h b/net/base/net_error_list.h index 96b19ad..bc7065d 100644 --- a/net/base/net_error_list.h +++ b/net/base/net_error_list.h @@ -191,7 +191,9 @@ NET_ERROR(SSL_SNAP_START_NPN_MISPREDICTION, -131) // give the user a helpful error message rather than have the connection hang. NET_ERROR(ESET_ANTI_VIRUS_SSL_INTERCEPTION, -132) -// Missing -133. Feel free to reuse in the future. +// We've hit the max socket limit for the socket pool while preconnecting. We +// don't bother trying to preconnect more sockets. +NET_ERROR(PRECONNECT_MAX_SOCKET_LIMIT, -133) // The permission to use the SSL client certificate's private key was denied. NET_ERROR(SSL_CLIENT_AUTH_PRIVATE_KEY_ACCESS_DENIED, -134) @@ -210,6 +212,9 @@ NET_ERROR(NAME_RESOLUTION_FAILED, -137) // errors. See also ERR_ACCESS_DENIED. NET_ERROR(NETWORK_ACCESS_DENIED, -138) +// The request throttler module cancelled this request to avoid DDOS. +NET_ERROR(TEMPORARILY_THROTTLED, -139) + // Certificate error codes // // The values of certificate error codes must be consecutive. @@ -421,6 +426,10 @@ NET_ERROR(RESPONSE_BODY_TOO_BIG_TO_DRAIN, -345) // The HTTP response was too big to drain. NET_ERROR(RESPONSE_HEADERS_MULTIPLE_CONTENT_LENGTH, -346) +// SPDY Headers have been received, but not all of them - status or version +// headers are missing, so we're expecting additional frames to complete them. +NET_ERROR(INCOMPLETE_SPDY_HEADERS, -347) + // The cache does not have the requested entry. NET_ERROR(CACHE_MISS, -400) diff --git a/net/base/net_log.h b/net/base/net_log.h index aa0b70e..ad775fa 100644 --- a/net/base/net_log.h +++ b/net/base/net_log.h @@ -81,7 +81,7 @@ class NetLog { // Base class for associating additional parameters with an event. Log // observers need to know what specific derivations of EventParameters a // particular EventType uses, in order to get at the individual components. - class EventParameters : public base::RefCounted<EventParameters> { + class EventParameters : public base::RefCountedThreadSafe<EventParameters> { public: EventParameters() {} virtual ~EventParameters() {} diff --git a/net/base/net_log_event_type_list.h b/net/base/net_log_event_type_list.h index 6183749..0021c0d 100644 --- a/net/base/net_log_event_type_list.h +++ b/net/base/net_log_event_type_list.h @@ -595,9 +595,19 @@ EVENT_TYPE(SPDY_SESSION_SYN_STREAM) // "flags": <The control frame flags> // "headers": <The list of header:value pairs> // "id": <The stream id> +// "associated_stream": <The stream id> // } EVENT_TYPE(SPDY_SESSION_PUSHED_SYN_STREAM) +// This event is sent for a SPDY HEADERS frame. +// The following parameters are attached: +// { +// "flags": <The control frame flags> +// "headers": <The list of header:value pairs> +// "id": <The stream id> +// } +EVENT_TYPE(SPDY_SESSION_HEADERS) + // This event is sent for a SPDY SYN_REPLY. // The following parameters are attached: // { diff --git a/net/base/network_change_notifier_linux.cc b/net/base/network_change_notifier_linux.cc index 95f230d..1db4bd1 100644 --- a/net/base/network_change_notifier_linux.cc +++ b/net/base/network_change_notifier_linux.cc @@ -102,11 +102,11 @@ void NetworkChangeNotifierLinux::Thread::ListenForNotifications() { if (HandleNetlinkMessage(buf, rv)) { VLOG(1) << "Detected IP address changes."; #if defined(OS_CHROMEOS) - // TODO(zelidrag): chromium-os:3996 - introduced artificial delay to - // work around the issue of proxy initialization before name resolving - // is functional in ChromeOS. This should be removed once this bug - // is properly fixed. - const int kObserverNotificationDelayMS = 500; + // TODO(oshima): chromium-os:8285 - introduced artificial delay to + // work around the issue of network load issue after connection + // restored. See the bug for more details. + // This should be removed once this bug is properly fixed. + const int kObserverNotificationDelayMS = 200; message_loop()->PostDelayedTask( FROM_HERE, method_factory_.NewRunnableMethod( diff --git a/net/base/ssl_client_auth_cache.cc b/net/base/ssl_client_auth_cache.cc index d2f47cc..355073f 100644 --- a/net/base/ssl_client_auth_cache.cc +++ b/net/base/ssl_client_auth_cache.cc @@ -4,15 +4,26 @@ #include "net/base/ssl_client_auth_cache.h" +#include "base/logging.h" +#include "net/base/x509_certificate.h" + namespace net { SSLClientAuthCache::SSLClientAuthCache() {} SSLClientAuthCache::~SSLClientAuthCache() {} -X509Certificate* SSLClientAuthCache::Lookup(const std::string& server) { +bool SSLClientAuthCache::Lookup( + const std::string& server, + scoped_refptr<X509Certificate>* certificate) { + DCHECK(certificate); + AuthCacheMap::iterator iter = cache_.find(server); - return (iter == cache_.end()) ? NULL : iter->second; + if (iter == cache_.end()) + return false; + + *certificate = iter->second; + return true; } void SSLClientAuthCache::Add(const std::string& server, diff --git a/net/base/ssl_client_auth_cache.h b/net/base/ssl_client_auth_cache.h index 023480b..2b276a2 100644 --- a/net/base/ssl_client_auth_cache.h +++ b/net/base/ssl_client_auth_cache.h @@ -10,10 +10,11 @@ #include <map> #include "base/ref_counted.h" -#include "net/base/x509_certificate.h" namespace net { +class X509Certificate; + // The SSLClientAuthCache class is a simple cache structure to store SSL // client certificates. Provides lookup, insertion, and deletion of entries. // The parameter for doing lookups, insertions, and deletions is the server's @@ -26,13 +27,18 @@ class SSLClientAuthCache { SSLClientAuthCache(); ~SSLClientAuthCache(); - // Check if we have a client certificate for SSL server at |server|. - // Returns the client certificate (if found) or NULL (if not found). - X509Certificate* Lookup(const std::string& server); + // Checks for a client certificate preference for SSL server at |server|. + // Returns true if a preference is found, and sets |*certificate| to the + // desired client certificate. The desired certificate may be NULL, which + // indicates a preference to not send any certificate to |server|. + // If a certificate preference is not found, returns false. + bool Lookup(const std::string& server, + scoped_refptr<X509Certificate>* certificate); // Add a client certificate for |server| to the cache. If there is already - // a client certificate for |server|, it will be overwritten. Both parameters - // are IN only. + // a client certificate for |server|, it will be overwritten. A NULL + // |client_cert| indicates a preference that no client certificate should + // be sent to |server|. void Add(const std::string& server, X509Certificate* client_cert); // Remove the client certificate for |server| from the cache, if one exists. diff --git a/net/base/ssl_client_auth_cache_unittest.cc b/net/base/ssl_client_auth_cache_unittest.cc index 85b3d5e..f528d58 100644 --- a/net/base/ssl_client_auth_cache_unittest.cc +++ b/net/base/ssl_client_auth_cache_unittest.cc @@ -5,6 +5,7 @@ #include "net/base/ssl_client_auth_cache.h" #include "base/time.h" +#include "net/base/x509_certificate.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { @@ -27,32 +28,50 @@ TEST(SSLClientAuthCacheTest, LookupAddRemove) { scoped_refptr<X509Certificate> cert3( new X509Certificate("foo3", "CA", start_date, expiration_date)); + scoped_refptr<X509Certificate> cached_cert; // Lookup non-existent client certificate. - EXPECT_TRUE(cache.Lookup(server1) == NULL); + cached_cert = NULL; + EXPECT_FALSE(cache.Lookup(server1, &cached_cert)); // Add client certificate for server1. - cache.Add(server1, cert1.get()); - EXPECT_EQ(cert1.get(), cache.Lookup(server1)); + cache.Add(server1, cert1); + cached_cert = NULL; + EXPECT_TRUE(cache.Lookup(server1, &cached_cert)); + EXPECT_EQ(cert1, cached_cert); // Add client certificate for server2. - cache.Add(server2, cert2.get()); - EXPECT_EQ(cert1.get(), cache.Lookup(server1)); - EXPECT_EQ(cert2.get(), cache.Lookup(server2)); + cache.Add(server2, cert2); + cached_cert = NULL; + EXPECT_TRUE(cache.Lookup(server1, &cached_cert)); + EXPECT_EQ(cert1, cached_cert.get()); + cached_cert = NULL; + EXPECT_TRUE(cache.Lookup(server2, &cached_cert)); + EXPECT_EQ(cert2, cached_cert); // Overwrite the client certificate for server1. - cache.Add(server1, cert3.get()); - EXPECT_EQ(cert3.get(), cache.Lookup(server1)); - EXPECT_EQ(cert2.get(), cache.Lookup(server2)); + cache.Add(server1, cert3); + cached_cert = NULL; + EXPECT_TRUE(cache.Lookup(server1, &cached_cert)); + EXPECT_EQ(cert3, cached_cert); + cached_cert = NULL; + EXPECT_TRUE(cache.Lookup(server2, &cached_cert)); + EXPECT_EQ(cert2, cached_cert); // Remove client certificate of server1. cache.Remove(server1); - EXPECT_TRUE(cache.Lookup(server1) == NULL); - EXPECT_EQ(cert2.get(), cache.Lookup(server2)); + cached_cert = NULL; + EXPECT_FALSE(cache.Lookup(server1, &cached_cert)); + cached_cert = NULL; + EXPECT_TRUE(cache.Lookup(server2, &cached_cert)); + EXPECT_EQ(cert2, cached_cert); // Remove non-existent client certificate. cache.Remove(server1); - EXPECT_TRUE(cache.Lookup(server1) == NULL); - EXPECT_EQ(cert2.get(), cache.Lookup(server2)); + cached_cert = NULL; + EXPECT_FALSE(cache.Lookup(server1, &cached_cert)); + cached_cert = NULL; + EXPECT_TRUE(cache.Lookup(server2, &cached_cert)); + EXPECT_EQ(cert2, cached_cert); } // Check that if the server differs only by port number, it is considered @@ -74,8 +93,48 @@ TEST(SSLClientAuthCacheTest, LookupWithPort) { cache.Add(server1, cert1.get()); cache.Add(server2, cert2.get()); - EXPECT_EQ(cert1.get(), cache.Lookup(server1)); - EXPECT_EQ(cert2.get(), cache.Lookup(server2)); + scoped_refptr<X509Certificate> cached_cert; + EXPECT_TRUE(cache.Lookup(server1, &cached_cert)); + EXPECT_EQ(cert1.get(), cached_cert); + EXPECT_TRUE(cache.Lookup(server2, &cached_cert)); + EXPECT_EQ(cert2.get(), cached_cert); +} + +// Check that the a NULL certificate, indicating the user has declined to send +// a certificate, is properly cached. +TEST(SSLClientAuthCacheTest, LookupNullPreference) { + SSLClientAuthCache cache; + base::Time start_date = base::Time::Now(); + base::Time expiration_date = start_date + base::TimeDelta::FromDays(1); + + std::string server1("foo:443"); + scoped_refptr<X509Certificate> cert1( + new X509Certificate("foo", "CA", start_date, expiration_date)); + + cache.Add(server1, NULL); + + scoped_refptr<X509Certificate> cached_cert(cert1); + // Make sure that |cached_cert| is updated to NULL, indicating the user + // declined to send a certificate to |server1|. + EXPECT_TRUE(cache.Lookup(server1, &cached_cert)); + EXPECT_EQ(NULL, cached_cert.get()); + + // Remove the existing cached certificate. + cache.Remove(server1); + cached_cert = NULL; + EXPECT_FALSE(cache.Lookup(server1, &cached_cert)); + + // Add a new preference for a specific certificate. + cache.Add(server1, cert1); + cached_cert = NULL; + EXPECT_TRUE(cache.Lookup(server1, &cached_cert)); + EXPECT_EQ(cert1, cached_cert); + + // Replace the specific preference with a NULL certificate. + cache.Add(server1, NULL); + cached_cert = NULL; + EXPECT_TRUE(cache.Lookup(server1, &cached_cert)); + EXPECT_EQ(NULL, cached_cert.get()); } } // namespace net diff --git a/net/base/ssl_false_start_blacklist.txt b/net/base/ssl_false_start_blacklist.txt index 26147de..e78b354 100644 --- a/net/base/ssl_false_start_blacklist.txt +++ b/net/base/ssl_false_start_blacklist.txt @@ -768,6 +768,7 @@ compus.de computerstore.be computerstore.nl coms.industrialcontrolrepair.com +comservicing.org concat.de concur.csmc.edu conduxio.com @@ -1451,6 +1452,7 @@ ggu.edu ggusd.us ggy.com gilmorehealth.com +giltcdn.com global2.mtsallstream.com glove.mizunoballpark.com glowinghealth.com.au diff --git a/net/base/transport_security_state.cc b/net/base/transport_security_state.cc index 3014e21..69c915d 100644 --- a/net/base/transport_security_state.cc +++ b/net/base/transport_security_state.cc @@ -411,6 +411,7 @@ bool TransportSecurityState::IsPreloadedSTS( {19, true, "\015sunshinepress\003org"}, {21, false, "\003www\013noisebridge\003net"}, {10, false, "\004neg9\003org"}, + {11, false, "\006factor\002cc"}, }; static const size_t kNumPreloadedSTS = ARRAYSIZE_UNSAFE(kPreloadedSTS); diff --git a/net/base/transport_security_state.h b/net/base/transport_security_state.h index b7db72c..49b44d7 100644 --- a/net/base/transport_security_state.h +++ b/net/base/transport_security_state.h @@ -11,7 +11,6 @@ #include "base/basictypes.h" #include "base/gtest_prod_util.h" -#include "base/lock.h" #include "base/ref_counted.h" #include "base/time.h" diff --git a/net/base/transport_security_state_unittest.cc b/net/base/transport_security_state_unittest.cc index 2a06501..47a3562 100644 --- a/net/base/transport_security_state_unittest.cc +++ b/net/base/transport_security_state_unittest.cc @@ -345,6 +345,9 @@ TEST_F(TransportSecurityStateTest, Preloaded) { EXPECT_TRUE(state->IsEnabledForHost(&domain_state, "neg9.org")); EXPECT_FALSE(state->IsEnabledForHost(&domain_state, "www.neg9.org")); + + EXPECT_TRUE(state->IsEnabledForHost(&domain_state, "factor.cc")); + EXPECT_FALSE(state->IsEnabledForHost(&domain_state, "www.factor.cc")); } TEST_F(TransportSecurityStateTest, LongNames) { diff --git a/net/base/x509_cert_types.cc b/net/base/x509_cert_types.cc index 5dfc57a..cdfbdaa 100644 --- a/net/base/x509_cert_types.cc +++ b/net/base/x509_cert_types.cc @@ -4,38 +4,11 @@ #include "net/base/x509_cert_types.h" -#include <ostream> - #include "net/base/x509_certificate.h" #include "base/logging.h" namespace net { -bool match(const std::string &str, const std::string &against) { - // TODO(snej): Use the full matching rules specified in RFC 5280 sec. 7.1 - // including trimming and case-folding: <http://www.ietf.org/rfc/rfc5280.txt>. - return against == str; -} - -bool match(const std::vector<std::string> &rdn1, - const std::vector<std::string> &rdn2) { - // "Two relative distinguished names RDN1 and RDN2 match if they have the - // same number of naming attributes and for each naming attribute in RDN1 - // there is a matching naming attribute in RDN2." --RFC 5280 sec. 7.1. - if (rdn1.size() != rdn2.size()) - return false; - for (unsigned i1 = 0; i1 < rdn1.size(); ++i1) { - unsigned i2; - for (i2 = 0; i2 < rdn2.size(); ++i2) { - if (match(rdn1[i1], rdn2[i2])) - break; - } - if (i2 == rdn2.size()) - return false; - } - return true; -} - CertPrincipal::CertPrincipal() { } @@ -44,18 +17,6 @@ CertPrincipal::CertPrincipal(const std::string& name) : common_name(name) {} CertPrincipal::~CertPrincipal() { } -bool CertPrincipal::Matches(const CertPrincipal& against) const { - return match(common_name, against.common_name) && - match(common_name, against.common_name) && - match(locality_name, against.locality_name) && - match(state_or_province_name, against.state_or_province_name) && - match(country_name, against.country_name) && - match(street_addresses, against.street_addresses) && - match(organization_names, against.organization_names) && - match(organization_unit_names, against.organization_unit_names) && - match(domain_components, against.domain_components); -} - std::string CertPrincipal::GetDisplayName() const { if (!common_name.empty()) return common_name; @@ -67,27 +28,6 @@ std::string CertPrincipal::GetDisplayName() const { return std::string(); } -std::ostream& operator<<(std::ostream& s, const CertPrincipal& p) { - s << "CertPrincipal["; - if (!p.common_name.empty()) - s << "cn=\"" << p.common_name << "\" "; - for (unsigned i = 0; i < p.street_addresses.size(); ++i) - s << "street=\"" << p.street_addresses[i] << "\" "; - if (!p.locality_name.empty()) - s << "l=\"" << p.locality_name << "\" "; - for (unsigned i = 0; i < p.organization_names.size(); ++i) - s << "o=\"" << p.organization_names[i] << "\" "; - for (unsigned i = 0; i < p.organization_unit_names.size(); ++i) - s << "ou=\"" << p.organization_unit_names[i] << "\" "; - if (!p.state_or_province_name.empty()) - s << "st=\"" << p.state_or_province_name << "\" "; - if (!p.country_name.empty()) - s << "c=\"" << p.country_name << "\" "; - for (unsigned i = 0; i < p.domain_components.size(); ++i) - s << "dc=\"" << p.domain_components[i] << "\" "; - return s << "]"; -} - CertPolicy::CertPolicy() { } diff --git a/net/base/x509_cert_types.h b/net/base/x509_cert_types.h index 7723c22..f762e56 100644 --- a/net/base/x509_cert_types.h +++ b/net/base/x509_cert_types.h @@ -8,25 +8,14 @@ #include <string.h> -#include <functional> -#include <iosfwd> #include <set> #include <string> #include <vector> -#include "base/ref_counted.h" -#include "base/singleton.h" -#include "base/time.h" -#include "testing/gtest/include/gtest/gtest_prod.h" +#include "build/build_config.h" -#if defined(OS_WIN) -#include <windows.h> -#include <wincrypt.h> -#elif defined(OS_MACOSX) +#if defined(OS_MACOSX) #include <Security/x509defs.h> -#elif defined(USE_NSS) -// Forward declaration; real one in <cert.h> -struct CERTCertificateStr; #endif namespace net { @@ -56,17 +45,19 @@ struct CertPrincipal { explicit CertPrincipal(const std::string& name); ~CertPrincipal(); +#if defined(OS_MACOSX) // Parses a BER-format DistinguishedName. bool ParseDistinguishedName(const void* ber_name_data, size_t length); -#if defined(OS_MACOSX) // Parses a CSSM_X509_NAME struct. void Parse(const CSSM_X509_NAME* name); -#endif - // Returns true if all attributes of the two objects match, - // where "match" is defined in RFC 5280 sec. 7.1. + // Compare this CertPrincipal with |against|, returning true if they're + // equal enough to be a possible match. This should NOT be used for any + // security relevant decisions. + // TODO(rsleevi): Remove once Mac client auth uses NSS for name comparison. bool Matches(const CertPrincipal& against) const; +#endif // Returns a name that can be used to represent the issuer. It tries in this // order: CN, O and OU and returns the first non-empty one found. @@ -86,9 +77,6 @@ struct CertPrincipal { std::vector<std::string> domain_components; }; -// Writes a human-readable description of a CertPrincipal, for debugging. -std::ostream& operator<<(std::ostream& s, const CertPrincipal& p); - // This class is useful for maintaining policies about which certificates are // permitted or forbidden for a particular purpose. class CertPolicy { diff --git a/net/base/x509_cert_types_mac.cc b/net/base/x509_cert_types_mac.cc index 14d5eee..c672863 100644 --- a/net/base/x509_cert_types_mac.cc +++ b/net/base/x509_cert_types_mac.cc @@ -14,7 +14,9 @@ namespace net { -static const CSSM_OID* kOIDs[] = { +namespace { + +const CSSM_OID* kOIDs[] = { &CSSMOID_CommonName, &CSSMOID_LocalityName, &CSSMOID_StateProvinceName, @@ -25,65 +27,6 @@ static const CSSM_OID* kOIDs[] = { &CSSMOID_DNQualifier // This should be "DC" but is undoubtedly wrong. }; // TODO(avi): Find the right OID. -// Converts raw CSSM_DATA to a std::string. (Char encoding is unaltered.) -static std::string DataToString(CSSM_DATA data); - -// Converts raw CSSM_DATA in ISO-8859-1 to a std::string in UTF-8. -static std::string Latin1DataToUTF8String(CSSM_DATA data); - -// Converts big-endian UTF-16 to UTF-8 in a std::string. -// Note: The byte-order flipping is done in place on the input buffer! -static bool UTF16BigEndianToUTF8(char16* chars, size_t length, - std::string* out_string); - -// Converts big-endian UTF-32 to UTF-8 in a std::string. -// Note: The byte-order flipping is done in place on the input buffer! -static bool UTF32BigEndianToUTF8(char32* chars, size_t length, - std::string* out_string); - -// Adds a type+value pair to the appropriate vector from a C array. -// The array is keyed by the matching OIDs from kOIDS[]. - static void AddTypeValuePair(const CSSM_OID type, - const std::string& value, - std::vector<std::string>* values[]); - -// Stores the first string of the vector, if any, to *single_value. -static void SetSingle(const std::vector<std::string> &values, - std::string* single_value); - - -void CertPrincipal::Parse(const CSSM_X509_NAME* name) { - std::vector<std::string> common_names, locality_names, state_names, - country_names; - - std::vector<std::string>* values[] = { - &common_names, &locality_names, - &state_names, &country_names, - &(this->street_addresses), - &(this->organization_names), - &(this->organization_unit_names), - &(this->domain_components) - }; - DCHECK(arraysize(kOIDs) == arraysize(values)); - - for (size_t rdn = 0; rdn < name->numberOfRDNs; ++rdn) { - CSSM_X509_RDN rdn_struct = name->RelativeDistinguishedName[rdn]; - for (size_t pair = 0; pair < rdn_struct.numberOfPairs; ++pair) { - CSSM_X509_TYPE_VALUE_PAIR pair_struct = - rdn_struct.AttributeTypeAndValue[pair]; - AddTypeValuePair(pair_struct.type, - DataToString(pair_struct.value), - values); - } - } - - SetSingle(common_names, &this->common_name); - SetSingle(locality_names, &this->locality_name); - SetSingle(state_names, &this->state_or_province_name); - SetSingle(country_names, &this->country_name); -} - - // The following structs and templates work with Apple's very arcane and under- // documented SecAsn1Parser API, which is apparently the same as NSS's ASN.1 // decoder: @@ -108,7 +51,7 @@ struct KeyValuePair { }; }; -static const SecAsn1Template kStringValueTemplate[] = { +const SecAsn1Template kStringValueTemplate[] = { { SEC_ASN1_CHOICE, offsetof(KeyValuePair, value_type), }, { SEC_ASN1_PRINTABLE_STRING, offsetof(KeyValuePair, value), 0, KeyValuePair::kTypePrintableString }, @@ -125,7 +68,7 @@ static const SecAsn1Template kStringValueTemplate[] = { { 0, } }; -static const SecAsn1Template kKeyValuePairTemplate[] = { +const SecAsn1Template kKeyValuePairTemplate[] = { { SEC_ASN1_SEQUENCE, 0, NULL, sizeof(KeyValuePair) }, { SEC_ASN1_OBJECT_ID, offsetof(KeyValuePair, key), }, { SEC_ASN1_INLINE, 0, &kStringValueTemplate, }, @@ -136,8 +79,8 @@ struct KeyValuePairs { KeyValuePair* pairs; }; -static const SecAsn1Template kKeyValuePairSetTemplate[] = { - { SEC_ASN1_SET_OF, offsetof(KeyValuePairs,pairs), +const SecAsn1Template kKeyValuePairSetTemplate[] = { + { SEC_ASN1_SET_OF, offsetof(KeyValuePairs, pairs), kKeyValuePairTemplate, sizeof(KeyValuePairs) } }; @@ -145,11 +88,99 @@ struct X509Name { KeyValuePairs** pairs_list; }; -static const SecAsn1Template kNameTemplate[] = { - { SEC_ASN1_SEQUENCE_OF, offsetof(X509Name,pairs_list), +const SecAsn1Template kNameTemplate[] = { + { SEC_ASN1_SEQUENCE_OF, offsetof(X509Name, pairs_list), kKeyValuePairSetTemplate, sizeof(X509Name) } }; +// Converts raw CSSM_DATA to a std::string. (Char encoding is unaltered.) +std::string DataToString(CSSM_DATA data) { + return std::string( + reinterpret_cast<std::string::value_type*>(data.Data), + data.Length); +} + +// Converts raw CSSM_DATA in ISO-8859-1 to a std::string in UTF-8. +std::string Latin1DataToUTF8String(CSSM_DATA data) { + string16 utf16; + if (!CodepageToUTF16(DataToString(data), base::kCodepageLatin1, + base::OnStringConversionError::FAIL, &utf16)) + return ""; + return UTF16ToUTF8(utf16); +} + +// Converts big-endian UTF-16 to UTF-8 in a std::string. +// Note: The byte-order flipping is done in place on the input buffer! +bool UTF16BigEndianToUTF8(char16* chars, size_t length, + std::string* out_string) { + for (size_t i = 0; i < length; i++) + chars[i] = EndianU16_BtoN(chars[i]); + return UTF16ToUTF8(chars, length, out_string); +} + +// Converts big-endian UTF-32 to UTF-8 in a std::string. +// Note: The byte-order flipping is done in place on the input buffer! +bool UTF32BigEndianToUTF8(char32* chars, size_t length, + std::string* out_string) { + for (size_t i = 0; i < length; ++i) + chars[i] = EndianS32_BtoN(chars[i]); +#if defined(WCHAR_T_IS_UTF32) + return WideToUTF8(reinterpret_cast<const wchar_t*>(chars), + length, out_string); +#else +#error This code doesn't handle 16-bit wchar_t. +#endif +} + +// Adds a type+value pair to the appropriate vector from a C array. +// The array is keyed by the matching OIDs from kOIDS[]. +void AddTypeValuePair(const CSSM_OID type, + const std::string& value, + std::vector<std::string>* values[]) { + for (size_t oid = 0; oid < arraysize(kOIDs); ++oid) { + if (CSSMOIDEqual(&type, kOIDs[oid])) { + values[oid]->push_back(value); + break; + } + } +} + +// Stores the first string of the vector, if any, to *single_value. +void SetSingle(const std::vector<std::string>& values, + std::string* single_value) { + // We don't expect to have more than one CN, L, S, and C. + LOG_IF(WARNING, values.size() > 1) << "Didn't expect multiple values"; + if (values.size() > 0) + *single_value = values[0]; +} + +bool match(const std::string& str, const std::string& against) { + // TODO(snej): Use the full matching rules specified in RFC 5280 sec. 7.1 + // including trimming and case-folding: <http://www.ietf.org/rfc/rfc5280.txt>. + return against == str; +} + +bool match(const std::vector<std::string>& rdn1, + const std::vector<std::string>& rdn2) { + // "Two relative distinguished names RDN1 and RDN2 match if they have the + // same number of naming attributes and for each naming attribute in RDN1 + // there is a matching naming attribute in RDN2." --RFC 5280 sec. 7.1. + if (rdn1.size() != rdn2.size()) + return false; + for (unsigned i1 = 0; i1 < rdn1.size(); ++i1) { + unsigned i2; + for (i2 = 0; i2 < rdn2.size(); ++i2) { + if (match(rdn1[i1], rdn2[i2])) + break; + } + if (i2 == rdn2.size()) + return false; + } + return true; +} + +} // namespace + bool CertPrincipal::ParseDistinguishedName(const void* ber_name_data, size_t length) { DCHECK(ber_name_data); @@ -182,7 +213,7 @@ bool CertPrincipal::ParseDistinguishedName(const void* ber_name_data, }; DCHECK(arraysize(kOIDs) == arraysize(values)); - for (int rdn=0; name[rdn].pairs_list; ++rdn) { + for (int rdn = 0; name[rdn].pairs_list; ++rdn) { KeyValuePair *pair; for (int pair_index = 0; NULL != (pair = name[rdn].pairs_list[0][pair_index].pairs); @@ -235,59 +266,46 @@ bool CertPrincipal::ParseDistinguishedName(const void* ber_name_data, return true; } +void CertPrincipal::Parse(const CSSM_X509_NAME* name) { + std::vector<std::string> common_names, locality_names, state_names, + country_names; -// SUBROUTINES: - -static std::string DataToString(CSSM_DATA data) { - return std::string( - reinterpret_cast<std::string::value_type*>(data.Data), - data.Length); -} - -static std::string Latin1DataToUTF8String(CSSM_DATA data) { - string16 utf16; - if (!CodepageToUTF16(DataToString(data), base::kCodepageLatin1, - base::OnStringConversionError::FAIL, &utf16)) - return ""; - return UTF16ToUTF8(utf16); -} - -bool UTF16BigEndianToUTF8(char16* chars, size_t length, - std::string* out_string) { - for (size_t i = 0; i < length; i++) - chars[i] = EndianU16_BtoN(chars[i]); - return UTF16ToUTF8(chars, length, out_string); -} - -bool UTF32BigEndianToUTF8(char32* chars, size_t length, - std::string* out_string) { - for (size_t i = 0; i < length; i++) - chars[i] = EndianS32_BtoN(chars[i]); -#if defined(WCHAR_T_IS_UTF32) - return WideToUTF8(reinterpret_cast<const wchar_t*>(chars), - length, out_string); -#else -#error This code doesn't handle 16-bit wchar_t. -#endif -} + std::vector<std::string>* values[] = { + &common_names, &locality_names, + &state_names, &country_names, + &(this->street_addresses), + &(this->organization_names), + &(this->organization_unit_names), + &(this->domain_components) + }; + DCHECK(arraysize(kOIDs) == arraysize(values)); - static void AddTypeValuePair(const CSSM_OID type, - const std::string& value, - std::vector<std::string>* values[]) { - for (size_t oid = 0; oid < arraysize(kOIDs); ++oid) { - if (CSSMOIDEqual(&type, kOIDs[oid])) { - values[oid]->push_back(value); - break; + for (size_t rdn = 0; rdn < name->numberOfRDNs; ++rdn) { + CSSM_X509_RDN rdn_struct = name->RelativeDistinguishedName[rdn]; + for (size_t pair = 0; pair < rdn_struct.numberOfPairs; ++pair) { + CSSM_X509_TYPE_VALUE_PAIR pair_struct = + rdn_struct.AttributeTypeAndValue[pair]; + AddTypeValuePair(pair_struct.type, + DataToString(pair_struct.value), + values); } } + + SetSingle(common_names, &this->common_name); + SetSingle(locality_names, &this->locality_name); + SetSingle(state_names, &this->state_or_province_name); + SetSingle(country_names, &this->country_name); } -static void SetSingle(const std::vector<std::string> &values, - std::string* single_value) { - // We don't expect to have more than one CN, L, S, and C. - LOG_IF(WARNING, values.size() > 1) << "Didn't expect multiple values"; - if (values.size() > 0) - *single_value = values[0]; +bool CertPrincipal::Matches(const CertPrincipal& against) const { + return match(common_name, against.common_name) && + match(locality_name, against.locality_name) && + match(state_or_province_name, against.state_or_province_name) && + match(country_name, against.country_name) && + match(street_addresses, against.street_addresses) && + match(organization_names, against.organization_names) && + match(organization_unit_names, against.organization_unit_names) && + match(domain_components, against.domain_components); } } // namespace net diff --git a/net/base/x509_cert_types_unittest.cc b/net/base/x509_cert_types_mac_unittest.cc index 50012b1..e4809b0 100644 --- a/net/base/x509_cert_types_unittest.cc +++ b/net/base/x509_cert_types_mac_unittest.cc @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include "base/basictypes.h" #include "net/base/x509_cert_types.h" #include "testing/gtest/include/gtest/gtest.h" @@ -249,8 +250,6 @@ TEST(X509TypesTest, Matching) { EXPECT_FALSE(spamco.Matches(bogus)); } -#if defined(OS_MACOSX) // ParseDistinguishedName not implemented for Win/Linux - TEST(X509TypesTest, ParseDNVerisign) { CertPrincipal verisign; EXPECT_TRUE(verisign.ParseDistinguishedName(VerisignDN, sizeof(VerisignDN))); @@ -339,6 +338,4 @@ TEST(X509TypesTest, ParseDNEntrust) { entrust.organization_unit_names[1]); } -#endif - -} +} // namespace net diff --git a/net/base/x509_certificate.cc b/net/base/x509_certificate.cc index d93d270..7385743 100644 --- a/net/base/x509_certificate.cc +++ b/net/base/x509_certificate.cc @@ -6,6 +6,7 @@ #include <map> +#include "base/lazy_instance.h" #include "base/logging.h" #include "base/metrics/histogram.h" #include "base/singleton.h" @@ -39,17 +40,6 @@ const char kCertificateHeader[] = "CERTIFICATE"; // The PEM block header used for PKCS#7 data const char kPKCS7Header[] = "PKCS7"; -} // namespace - -bool X509Certificate::LessThan::operator()(X509Certificate* lhs, - X509Certificate* rhs) const { - if (lhs == rhs) - return false; - - SHA1FingerprintLessThan fingerprint_functor; - return fingerprint_functor(lhs->fingerprint_, rhs->fingerprint_); -} - // A thread-safe cache for X509Certificate objects. // // The cache does not hold a reference to the certificate objects. The objects @@ -57,9 +47,8 @@ bool X509Certificate::LessThan::operator()(X509Certificate* lhs, // will be holding dead pointers to the objects). // TODO(rsleevi): There exists a chance of a use-after-free, due to a race // between Find() and Remove(). See http://crbug.com/49377 -class X509Certificate::Cache { +class X509CertificateCache { public: - static Cache* GetInstance(); void Insert(X509Certificate* cert); void Remove(X509Certificate* cert); X509Certificate* Find(const SHA1Fingerprint& fingerprint); @@ -68,9 +57,10 @@ class X509Certificate::Cache { typedef std::map<SHA1Fingerprint, X509Certificate*, SHA1FingerprintLessThan> CertMap; - // Obtain an instance of X509Certificate::Cache via GetInstance(). - Cache() {} - friend struct DefaultSingletonTraits<Cache>; + // Obtain an instance of X509CertificateCache via a LazyInstance. + X509CertificateCache() {} + ~X509CertificateCache() {} + friend struct base::DefaultLazyInstanceTraits<X509CertificateCache>; // You must acquire this lock before using any private data of this object. // You must not block while holding this lock. @@ -79,18 +69,16 @@ class X509Certificate::Cache { // The certificate cache. You must acquire |lock_| before using |cache_|. CertMap cache_; - DISALLOW_COPY_AND_ASSIGN(Cache); + DISALLOW_COPY_AND_ASSIGN(X509CertificateCache); }; -// Get the singleton object for the cache. -// static -X509Certificate::Cache* X509Certificate::Cache::GetInstance() { - return Singleton<X509Certificate::Cache>::get(); -} +base::LazyInstance<X509CertificateCache, + base::LeakyLazyInstanceTraits<X509CertificateCache> > + g_x509_certificate_cache(base::LINKER_INITIALIZED); // Insert |cert| into the cache. The cache does NOT AddRef |cert|. // Any existing certificate with the same fingerprint will be replaced. -void X509Certificate::Cache::Insert(X509Certificate* cert) { +void X509CertificateCache::Insert(X509Certificate* cert) { AutoLock lock(lock_); DCHECK(!IsNullFingerprint(cert->fingerprint())) << @@ -100,7 +88,7 @@ void X509Certificate::Cache::Insert(X509Certificate* cert) { // Remove |cert| from the cache. The cache does not assume that |cert| is // already in the cache. -void X509Certificate::Cache::Remove(X509Certificate* cert) { +void X509CertificateCache::Remove(X509Certificate* cert) { AutoLock lock(lock_); CertMap::iterator pos(cache_.find(cert->fingerprint())); @@ -111,7 +99,7 @@ void X509Certificate::Cache::Remove(X509Certificate* cert) { // Find a certificate in the cache with the given fingerprint. If one does // not exist, this method returns NULL. -X509Certificate* X509Certificate::Cache::Find( +X509Certificate* X509CertificateCache::Find( const SHA1Fingerprint& fingerprint) { AutoLock lock(lock_); @@ -122,6 +110,17 @@ X509Certificate* X509Certificate::Cache::Find( return pos->second; }; +} // namespace + +bool X509Certificate::LessThan::operator()(X509Certificate* lhs, + X509Certificate* rhs) const { + if (lhs == rhs) + return false; + + SHA1FingerprintLessThan fingerprint_functor; + return fingerprint_functor(lhs->fingerprint_, rhs->fingerprint_); +} + // static X509Certificate* X509Certificate::CreateFromHandle( OSCertHandle cert_handle, @@ -131,7 +130,7 @@ X509Certificate* X509Certificate::CreateFromHandle( DCHECK(source != SOURCE_UNUSED); // Check if we already have this certificate in memory. - X509Certificate::Cache* cache = X509Certificate::Cache::GetInstance(); + X509CertificateCache* cache = g_x509_certificate_cache.Pointer(); X509Certificate* cached_cert = cache->Find(CalculateFingerprint(cert_handle)); if (cached_cert) { @@ -311,7 +310,7 @@ X509Certificate::X509Certificate(const std::string& subject, X509Certificate::~X509Certificate() { // We might not be in the cache, but it is safe to remove ourselves anyway. - X509Certificate::Cache::GetInstance()->Remove(this); + g_x509_certificate_cache.Get().Remove(this); if (cert_handle_) FreeOSCertHandle(cert_handle_); for (size_t i = 0; i < intermediate_ca_certs_.size(); ++i) diff --git a/net/base/x509_certificate.h b/net/base/x509_certificate.h index 7da6ccb..763bf9d 100644 --- a/net/base/x509_certificate.h +++ b/net/base/x509_certificate.h @@ -23,6 +23,8 @@ #elif defined(OS_MACOSX) #include <CoreFoundation/CFArray.h> #include <Security/SecBase.h> + +#include "base/lock.h" #elif defined(USE_OPENSSL) // Forward declaration; real one in <x509.h> struct x509_st; @@ -290,8 +292,6 @@ class X509Certificate : public base::RefCountedThreadSafe<X509Certificate> { FRIEND_TEST_ALL_PREFIXES(X509CertificateTest, Cache); FRIEND_TEST_ALL_PREFIXES(X509CertificateTest, IntermediateCertificates); - class Cache; - // Construct an X509Certificate from a handle to the certificate object // in the underlying crypto library. X509Certificate(OSCertHandle cert_handle, Source source, diff --git a/net/base/x509_certificate_mac.cc b/net/base/x509_certificate_mac.cc index a2a0eea..5a5d457 100644 --- a/net/base/x509_certificate_mac.cc +++ b/net/base/x509_certificate_mac.cc @@ -8,8 +8,10 @@ #include <Security/Security.h> #include <time.h> +#include "base/lazy_instance.h" #include "base/logging.h" #include "base/pickle.h" +#include "base/singleton.h" #include "base/mac/scoped_cftyperef.h" #include "base/sys_string_conversions.h" #include "net/base/cert_status_flags.h" @@ -21,6 +23,8 @@ using base::Time; namespace net { +namespace { + class MacTrustedCertificates { public: // Sets the trusted root certificate used by tests. Call with |cert| set @@ -57,7 +61,7 @@ class MacTrustedCertificates { return merged_array; } private: - friend struct DefaultSingletonTraits<MacTrustedCertificates>; + friend struct base::DefaultLazyInstanceTraits<MacTrustedCertificates>; // Obtain an instance of MacTrustedCertificates via the singleton // interface. @@ -73,11 +77,9 @@ class MacTrustedCertificates { DISALLOW_COPY_AND_ASSIGN(MacTrustedCertificates); }; -void SetMacTestCertificate(X509Certificate* cert) { - Singleton<MacTrustedCertificates>::get()->SetTestCertificate(cert); -} - -namespace { +base::LazyInstance<MacTrustedCertificates, + base::LeakyLazyInstanceTraits<MacTrustedCertificates> > + g_mac_trusted_certificates(base::LINKER_INITIALIZED); typedef OSStatus (*SecTrustCopyExtendedResultFuncPtr)(SecTrustRef, CFDictionaryRef*); @@ -443,6 +445,10 @@ void AddCertificatesFromBytes(const char* data, size_t length, } // namespace +void SetMacTestCertificate(X509Certificate* cert) { + g_mac_trusted_certificates.Get().SetTestCertificate(cert); +} + void X509Certificate::Initialize() { const CSSM_X509_NAME* name; OSStatus status = SecCertificateGetSubject(cert_handle_, &name); @@ -545,7 +551,7 @@ int X509Certificate::Verify(const std::string& hostname, int flags, // Set the trusted anchor certificates for the SecTrustRef by merging the // system trust anchors and the test root certificate. CFArrayRef anchor_array = - Singleton<MacTrustedCertificates>::get()->CopyTrustedCertificateArray(); + g_mac_trusted_certificates.Get().CopyTrustedCertificateArray(); ScopedCFTypeRef<CFArrayRef> scoped_anchor_array(anchor_array); if (anchor_array) { status = SecTrustSetAnchorCertificates(trust_ref, anchor_array); diff --git a/net/base/x509_certificate_openssl.cc b/net/base/x509_certificate_openssl.cc index 84a47ec..5b2d365 100644 --- a/net/base/x509_certificate_openssl.cc +++ b/net/base/x509_certificate_openssl.cc @@ -206,6 +206,13 @@ void DERCache_free(void* parent, void* ptr, CRYPTO_EX_DATA* ad, int idx, class X509InitSingleton { public: + static X509InitSingleton* Get() { + // We allow the X509 store to leak, because it is used from a non-joinable + // worker that is not stopped on shutdown, hence may still be using + // OpenSSL library after the AtExit runner has completed. + return Singleton<X509InitSingleton, + LeakySingletonTraits<X509InitSingleton> >::get(); + } int der_cache_ex_index() const { return der_cache_ex_index_; } X509_STORE* store() const { return store_.get(); } @@ -252,8 +259,7 @@ DERCache* SetDERCache(X509Certificate::OSCertHandle cert, // not free it). bool GetDERAndCacheIfNeeded(X509Certificate::OSCertHandle cert, DERCache* der_cache) { - int x509_der_cache_index = - Singleton<X509InitSingleton>::get()->der_cache_ex_index(); + int x509_der_cache_index = X509InitSingleton::Get()->der_cache_ex_index(); // Re-encoding the DER data via i2d_X509 is an expensive operation, but it's // necessary for comparing two certificates. We re-encode at most once per @@ -386,7 +392,7 @@ void X509Certificate::GetDNSNames(std::vector<std::string>* dns_names) const { // static X509_STORE* X509Certificate::cert_store() { - return Singleton<X509InitSingleton>::get()->store(); + return X509InitSingleton::Get()->store(); } #ifndef ANDROID diff --git a/net/base/x509_certificate_win.cc b/net/base/x509_certificate_win.cc index 9e018fd..75cdf40 100644 --- a/net/base/x509_certificate_win.cc +++ b/net/base/x509_certificate_win.cc @@ -6,6 +6,7 @@ #include "base/logging.h" #include "base/pickle.h" +#include "base/singleton.h" #include "base/string_tokenizer.h" #include "base/string_util.h" #include "base/utf_string_conversions.h" diff --git a/net/data/ftp/dir-listing-ls-20 b/net/data/ftp/dir-listing-ls-20 new file mode 100644 index 0000000..18d5bb2 --- /dev/null +++ b/net/data/ftp/dir-listing-ls-20 @@ -0,0 +1,18 @@ +drwxrwxr-x 17 ftp ftp 4096 Nov 01 16:27 .
+drwxr-xr-x 7 ftp ftp 4096 Apr 03 2010 ..
+drwxrwxrwx 5 ftp ftp 4096 May 26 15:51 2012_-_2012-(2009)-[1080p]_[BD]
+-rw-rw-rw- 1 ftp ftp 4931 Jun 08 15:24 _READ_ME.txt
+drwxrwxrwx 5 ftp ftp 4096 Nov 01 16:27 Áåç_ëèöà_-_Face_Off-(1997)-[1080p]_[BD]
+drwxrwxrwx 4 ftp ftp 4096 Jan 22 2010 Ââåðõ_-_Up-(2009)-[1080p]_[BD]
+drwxrwxrwx 5 ftp ftp 4096 May 27 18:12 Âñïîìíèòü_âñå_-_Total_Recall-(1990)-[1080p]_[BD]
+drwxrwxrwx 3 ftp ftp 4096 May 21 14:28 Çàêîíîïîñëóøíûé_ãðàæäàíèí_[Ðàñøèðåííàÿ_âåðñèÿ]_-_Law_Abiding_Citizen_[Unrated_Edition]-(2009)-[1080p]_[BD_Remux]
+drwxrwxrwx 5 ftp ftp 4096 Jan 21 2010 Ìîíñòðî_-_Cloverfield-(2008)-[1080p]_[BD]
+drwxrwxrwx 4 ftp ftp 4096 May 26 13:43 Ïàíäîðóì_-_Pandorum-(2009)-[1080p]_[BD]
+drwxrwxrwx 4 ftp ftp 4096 May 27 14:18 Ïîñëåäíèé_ñàìóðàé_-_The_Last_Samurai-(2003)-[1080p]_[BD]
+drwxrwxrwx 6 ftp ftp 4096 Jan 27 2010 Ðàéîí_9_-_District_9-(2009)-[1080p]_[BD]
+drwxrwxrwx 7 ftp ftp 4096 Jan 01 2010 Ðîêêè_Àíòîëîãèÿ_-_Rocky_The_Undisputed_Collection-(1976_1979_1982_1985_1990)-[1080p]_[BD]
+drwxrwxrwx 4 ftp ftp 4096 May 31 12:57 Ñóððîãàòû_-_Surrogates-(2009)-[1080p]_[BD]
+drwxrwxrwx 4 ftp ftp 4096 Jan 28 2010 Òðîéíîé_Ôîðñàæ-Òîêèéñêèé_Äðèôò_-_The_Fast_and_the_Furious-Tokyo_Drift-(2006)-[1080p]_[BD]
+drwxrwxrwx 4 ftp ftp 4096 Jan 08 2010 Ôîðñàæ_-_The_Fast_and_the_Furious-(2001)-[1080p]_[BD]
+drwxrwxrwx 4 ftp ftp 4096 Jan 06 2010 Ôîðñàæ_2_-_2_Fast_2_Furious-(2003)-[1080p]_[BD]
+drwxrwxrwx 4 ftp ftp 4096 Jun 08 15:26 Ôîðñàæ_4_-_Fast_&_Furious-(2009)-[1080p]_[BD]
diff --git a/net/data/ftp/dir-listing-ls-20.expected b/net/data/ftp/dir-listing-ls-20.expected new file mode 100644 index 0000000..9c636b9 --- /dev/null +++ b/net/data/ftp/dir-listing-ls-20.expected @@ -0,0 +1,161 @@ +d
+.
+-1
+1994
+11
+1
+16
+27
+
+d
+..
+-1
+2010
+4
+3
+0
+0
+
+d
+2012_-_2012-(2009)-[1080p]_[BD]
+-1
+1994
+5
+26
+15
+51
+
+-
+_READ_ME.txt
+4931
+1994
+6
+8
+15
+24
+
+d
+Ãåç_ëèöà _-_Face_Off-(1997)-[1080p]_[BD]
+-1
+1994
+11
+1
+16
+27
+
+d
+Ââåðõ_-_Up-(2009)-[1080p]_[BD]
+-1
+2010
+1
+22
+0
+0
+
+d
+ÂñïîìÃèòü_âñå_-_Total_Recall-(1990)-[1080p]_[BD]
+-1
+1994
+5
+27
+18
+12
+
+d
+Çà êîÃîïîñëóøÃûé_ãðà æäà ÃèÃ_[Ãà ñøèðåÃÃà ÿ_âåðñèÿ]_-_Law_Abiding_Citizen_[Unrated_Edition]-(2009)-[1080p]_[BD_Remux]
+-1
+1994
+5
+21
+14
+28
+
+d
+ÌîÃñòðî_-_Cloverfield-(2008)-[1080p]_[BD]
+-1
+2010
+1
+21
+0
+0
+
+d
+Ãà Ãäîðóì_-_Pandorum-(2009)-[1080p]_[BD]
+-1
+1994
+5
+26
+13
+43
+
+d
+ÃîñëåäÃèé_ñà ìóðà é_-_The_Last_Samurai-(2003)-[1080p]_[BD]
+-1
+1994
+5
+27
+14
+18
+
+d
+Ãà éîÃ_9_-_District_9-(2009)-[1080p]_[BD]
+-1
+2010
+1
+27
+0
+0
+
+d
+Ãîêêè_ÀÃòîëîãèÿ_-_Rocky_The_Undisputed_Collection-(1976_1979_1982_1985_1990)-[1080p]_[BD]
+-1
+2010
+1
+1
+0
+0
+
+d
+Ñóððîãà òû_-_Surrogates-(2009)-[1080p]_[BD]
+-1
+1994
+5
+31
+12
+57
+
+d
+ÒðîéÃîé_Ôîðñà æ-Òîêèéñêèé_Äðèôò_-_The_Fast_and_the_Furious-Tokyo_Drift-(2006)-[1080p]_[BD]
+-1
+2010
+1
+28
+0
+0
+
+d
+Ôîðñà æ_-_The_Fast_and_the_Furious-(2001)-[1080p]_[BD]
+-1
+2010
+1
+8
+0
+0
+
+d
+Ôîðñà æ_2_-_2_Fast_2_Furious-(2003)-[1080p]_[BD]
+-1
+2010
+1
+6
+0
+0
+
+d
+Ôîðñà æ_4_-_Fast_&_Furious-(2009)-[1080p]_[BD]
+-1
+1994
+6
+8
+15
+26
diff --git a/net/data/ftp/dir-listing-ls-21 b/net/data/ftp/dir-listing-ls-21 new file mode 100644 index 0000000..1246efd --- /dev/null +++ b/net/data/ftp/dir-listing-ls-21 @@ -0,0 +1,27 @@ +drwxrwxr-x 26 ftp ftp 4096 Jul 15 2009 .
+drwxr-xr-x 7 ftp ftp 4096 Apr 03 2010 ..
+-rw-rw-rw- 1 ftp ftp 4931 Jun 08 15:24 _READ_ME.txt
+drwxr-xr-x 5 ftp ftp 4096 Apr 27 2009 Àâàëîí_-_Avalon-(2001)-[1080p]_[BD_Remux]
+drwxrwxrwx 5 ftp ftp 4096 Jun 15 2009 Áðþñ_âñåìîãóùèé_-_Bruce_Almighty-(2003)-[1080p]_[BD]
+drwxr-xr-x 4 ftp ftp 4096 Apr 15 2009 ÂÀËË-È_-_WALL-E-(2008)-[1080p]_[BD]
+drwxr-xr-x 4 ftp ftp 4096 Apr 28 2009 Äæåéìñ_Áîíä_007-Êâàíò_ìèëîñåðäèÿ_-_James_Bond_007-Quantum_of_Solace-(2008)-[1080p]_[BD]
+drwxr-xr-x 4 ftp ftp 4096 Apr 15 2009 Êîñìîñ-Òåððèòîðèÿ_cìåðòè_-_Dead_Space-Downfall-(2008)-[1080p]_[BD]
+drwxrwxrwx 5 ftp ftp 4096 Jul 03 2009 Ìàäàãàñêàð_1_-_Madagascar_1-(2005)-[1080p]_[BD]
+drwxrwxrwx 5 ftp ftp 4096 Jul 03 2009 Ìàäàãàñêàð_2_-_Madagascar-Escape_2_Africa-(2008)-[1080p]_[BD]
+drwxrwxrwx 4 ftp ftp 4096 Jun 13 2009 Ìàòðèöà-Ïåðåçàãðóçêà_-_The_Matrix-Reloaded-(2003)-[1080p]_[BD]
+drwxrwxrwx 4 ftp ftp 4096 Jun 14 2009 Ìàòðèöà-Ðåâîëþöèÿ_-_The_Matrix-Revolutions-(2003)-[1080p]_[BD]
+drwxrwxrwx 4 ftp ftp 4096 Jun 12 2009 Ìàòðèöà_-_The_Matrix-(1999)-[1080p]_[BD]
+drwxrwxrwx 4 ftp ftp 4096 Jul 02 2009 Îáèòåëü_çëà_3_-_Resident_Evil-Extinction-(2007)-[1080p]_[BD]
+drwxr-xr-x 3 ftp ftp 4096 May 01 2009 Îñòðîâ_-_The_Island-(2005)-[1080p]_[BD]
+drwxrwxrwx 5 ftp ftp 4096 Jul 03 2009 Ïåðåâîç÷èê_3_-_Transporter_3-(2008)-[1080p]_[BD]
+drwxrwxr-x 5 ftp ftp 4096 May 02 2009 Ïèðàòû_Êàðèáñêîãî_ìîðÿ-Íà_êðàþ_Ñâåòà_-_Pirates_of_the_Caribbean-At_World's_End-(2007)-[1080p]_[BD]
+drwxrwxr-x 5 ftp ftp 4096 May 03 2009 Ïèðàòû_Êàðèáñêîãî_ìîðÿ-Ïðîêëÿòèå_×åðíîé_Æåì÷óæèíû_-_Pirates_of_the_Caribbean-The_Curse_of_the_Black_Pearl-(2003)-[1080p]_[BD]
+drwxrwxr-x 5 ftp ftp 4096 May 02 2009 Ïèðàòû_Êàðèáñêîãî_ìîðÿ-Ñóíäóê_ìåðòâåöà_-_Pirates_of_the_Caribbean-Dead_Man's_Chest-(2006)-[1080p]_[BD]
+drwxrwxr-x 3 ftp ftp 4096 May 01 2009 Ïðèçðà÷íûé_ãîíùèê_-_Ghost_Rider-(2007)-[1080p]_[BD]
+drwxr-xr-x 5 ftp ftp 4096 Apr 29 2009 Ïðèíöåññà-íåâåñòà_-_The_Princess_Bride-(1987)-[1080p]_[BD]
+drwxrwxrwx 5 ftp ftp 4096 Jun 08 2009 Ñåêñ_è_101_ñìåðòü_-_Sex_and_Death_101-(2007)-[1080p]_[BD]
+drwxr-xr-x 4 ftp ftp 4096 May 01 2009 Òðàíñôîðìåðû-Áîíóñ_äèñê_-_Transformers-Bonus_Disk-(2007)-[1080p]_[BD]
+drwxr-xr-x 4 ftp ftp 4096 Apr 30 2009 Òðàíñôîðìåðû_-_Transformers-(2007)-[1080p]_[BD]
+drwxrwxrwx 6 ftp ftp 4096 Jun 07 2009 Òðèíàäöàòûé_ýòàæ_-_The_Thirteenth_Floor-(1999)-[1080p]_[BD]
+drwxrwxr-x 3 ftp ftp 4096 May 04 2009 Óëè÷íûé_áîåö_-_Street_Fighter-(1994)-[1080p]_[BD_Remux]
+drwxr-xr-x 5 ftp ftp 4096 Mar 15 2009 ×åãî_õîòÿò_æåíùèíû_-_What_Woman_Want-(2000)-[1080p]_[BD]
diff --git a/net/data/ftp/dir-listing-ls-21.expected b/net/data/ftp/dir-listing-ls-21.expected new file mode 100644 index 0000000..0111be1 --- /dev/null +++ b/net/data/ftp/dir-listing-ls-21.expected @@ -0,0 +1,242 @@ +d
+.
+-1
+2009
+7
+15
+0
+0
+
+d
+..
+-1
+2010
+4
+3
+0
+0
+
+-
+_READ_ME.txt
+4931
+1994
+6
+8
+15
+24
+
+d
+Àâà ëîÃ_-_Avalon-(2001)-[1080p]_[BD_Remux]
+-1
+2009
+4
+27
+0
+0
+
+d
+Ãðþñ_âñåìîãóùèé_-_Bruce_Almighty-(2003)-[1080p]_[BD]
+-1
+2009
+6
+15
+0
+0
+
+d
+ÂÀËË-È_-_WALL-E-(2008)-[1080p]_[BD]
+-1
+2009
+4
+15
+0
+0
+
+d
+Äæåéìñ_ÃîÃä_007-Êâà Ãò_ìèëîñåðäèÿ_-_James_Bond_007-Quantum_of_Solace-(2008)-[1080p]_[BD]
+-1
+2009
+4
+28
+0
+0
+
+d
+Êîñìîñ-Òåððèòîðèÿ_cìåðòè_-_Dead_Space-Downfall-(2008)-[1080p]_[BD]
+-1
+2009
+4
+15
+0
+0
+
+d
+Ìà äà ãà ñêà ð_1_-_Madagascar_1-(2005)-[1080p]_[BD]
+-1
+2009
+7
+3
+0
+0
+
+d
+Ìà äà ãà ñêà ð_2_-_Madagascar-Escape_2_Africa-(2008)-[1080p]_[BD]
+-1
+2009
+7
+3
+0
+0
+
+d
+Ìà òðèöà -Ãåðåçà ãðóçêà _-_The_Matrix-Reloaded-(2003)-[1080p]_[BD]
+-1
+2009
+6
+13
+0
+0
+
+d
+Ìà òðèöà -Ãåâîëþöèÿ_-_The_Matrix-Revolutions-(2003)-[1080p]_[BD]
+-1
+2009
+6
+14
+0
+0
+
+d
+Ìà òðèöà _-_The_Matrix-(1999)-[1080p]_[BD]
+-1
+2009
+6
+12
+0
+0
+
+d
+Îáèòåëü_çëà _3_-_Resident_Evil-Extinction-(2007)-[1080p]_[BD]
+-1
+2009
+7
+2
+0
+0
+
+d
+Îñòðîâ_-_The_Island-(2005)-[1080p]_[BD]
+-1
+2009
+5
+1
+0
+0
+
+d
+Ãåðåâîç÷èê_3_-_Transporter_3-(2008)-[1080p]_[BD]
+-1
+2009
+7
+3
+0
+0
+
+d
+Ãèðà òû_Êà ðèáñêîãî_ìîðÿ-Ãà _êðà þ_Ñâåòà _-_Pirates_of_the_Caribbean-At_World's_End-(2007)-[1080p]_[BD]
+-1
+2009
+5
+2
+0
+0
+
+d
+Ãèðà òû_Êà ðèáñêîãî_ìîðÿ-Ãðîêëÿòèå_×åðÃîé_Æåì÷óæèÃû_-_Pirates_of_the_Caribbean-The_Curse_of_the_Black_Pearl-(2003)-[1080p]_[BD]
+-1
+2009
+5
+3
+0
+0
+
+d
+Ãèðà òû_Êà ðèáñêîãî_ìîðÿ-ÑóÃäóê_ìåðòâåöà _-_Pirates_of_the_Caribbean-Dead_Man's_Chest-(2006)-[1080p]_[BD]
+-1
+2009
+5
+2
+0
+0
+
+d
+Ãðèçðà ÷Ãûé_ãîÃùèê_-_Ghost_Rider-(2007)-[1080p]_[BD]
+-1
+2009
+5
+1
+0
+0
+
+d
+ÃðèÃöåññà -Ãåâåñòà _-_The_Princess_Bride-(1987)-[1080p]_[BD]
+-1
+2009
+4
+29
+0
+0
+
+d
+Ñåêñ_è_101_ñìåðòü_-_Sex_and_Death_101-(2007)-[1080p]_[BD]
+-1
+2009
+6
+8
+0
+0
+
+d
+Òðà Ãñôîðìåðû-ÃîÃóñ_äèñê_-_Transformers-Bonus_Disk-(2007)-[1080p]_[BD]
+-1
+2009
+5
+1
+0
+0
+
+d
+Òðà Ãñôîðìåðû_-_Transformers-(2007)-[1080p]_[BD]
+-1
+2009
+4
+30
+0
+0
+
+d
+ÒðèÃà äöà òûé_ýòà æ_-_The_Thirteenth_Floor-(1999)-[1080p]_[BD]
+-1
+2009
+6
+7
+0
+0
+
+d
+Óëè÷Ãûé_áîåö_-_Street_Fighter-(1994)-[1080p]_[BD_Remux]
+-1
+2009
+5
+4
+0
+0
+
+d
+×åãî_õîòÿò_æåÃùèÃû_-_What_Woman_Want-(2000)-[1080p]_[BD]
+-1
+2009
+3
+15
+0
+0
diff --git a/net/data/ftp/dir-listing-ls-22 b/net/data/ftp/dir-listing-ls-22 new file mode 100644 index 0000000..df44141 --- /dev/null +++ b/net/data/ftp/dir-listing-ls-22 @@ -0,0 +1,32 @@ +drwxrwxr-x 5 ftp ftp 12288 Oct 20 17:04 .
+drwxr-xr-x 7 ftp ftp 4096 Apr 03 2010 ..
+-rw-rw-rw- 1 ftp ftp 4931 Jun 08 15:23 _READ_ME.txt
+drwxrwxrwx 4 ftp ftp 4096 May 31 16:26 Àâàòàð_-_Avatar-(2009)-[1080p]_[BD]
+-rwxrwxr-x 1 ftp ftp 17577705968 Mar 08 2009 Àìåðèêàíñêèé_ïèðîã_1_[Ðàñøèðåííàÿ_âåðñèÿ]_-_American_Pie_1_[Unrated_Edition]-(1999)-[1080p]_[BD_remux].ts
+-rwxrwxr-x 1 ftp ftp 15512934868 Mar 16 2009 Áîëüøîé_êóø_-_Snatch-(2000)-[1080i]_[HDTV].ts
+drwxrwxrwx 2 ftp ftp 4096 Jun 03 19:07 Áîëüøîé_êóø_-_Snatch-(2000)-[1080p]_[BD_Remux]
+-rwxrwxr-x 1 ftp ftp 8900589105 Mar 24 2009 Âîéíà_ìèðîâ_-_War_of_the_Worlds-(2005)-[720p]_[HDTV].mkv
+-rwxrwxr-x 1 ftp ftp 27728321654 Mar 09 2009 Ãàíãñòåð_-_American_Gangster-(2007)-[1080p]_[BD_remux].mkv
+-rwxrwxr-x 1 ftp ftp 31731782861 Mar 09 2009 Ãàíãñòåð_[Ðàñøèðåííàÿ_âåðñèÿ]_-_American_Gangster_[Unrated_Edition]-(2007)-[1080p]_[BD_remux].mkv
+-rwxrwxr-x 1 ftp ftp 5009104014 Mar 24 2009 Äîðîæíîå_ïðèêëþ÷åíèå_-_Road_Trip-(2000)-[720p]_[HDTV_Rip].mkv
+-rwxrwxr-x 1 ftp ftp 21410583980 Mar 11 2009 Çâ¸çäíûå_âîéíû-Ýïèçîä_2-Àòàêà_êëîíîâ_-_Star_Wars-Episode_2-Attack_of_the_Clones-(2002)-[1080i]_[HDTV].ts
+-rwxrwxr-x 1 ftp ftp 19858181688 Mar 11 2009 Çâ¸çäíûå_âîéíû-Ýïèçîä_3-Ìåñòü_Ñèòõîâ_-_Star_Wars-Episode_3-Revenge_of_the_Sith-(2005)-[1080i]_[HDTV].ts
+-rwxrwxr-x 1 ftp ftp 29026065728 Mar 16 2009 Çâ¸çäíûé_äåñàíò_-_Starship_Troopers-(1997)-[1080p]_[BD_remux].mkv
+-rwxrwxr-x 1 ftp ftp 22169179449 Mar 16 2009 Çåðêàëà_[Ðàñøèðåííàÿ_âåðñèÿ]_-_Mirrors_[Unrated_Edition]-(2008)-[1080p]_[BD_remux].mkv
+drwxrwxrwx 4 ftp ftp 4096 Jun 15 14:56 Íèíäçÿ-óáèéöà_-_Ninja_Assassin-(2009)-[1080p]_[BD]
+-rwxrwxr-x 1 ftp ftp 19717173247 Mar 11 2009 Îáèòåëü_çëà_3_-_Resident_Evil-Extinction-(2007)-[1080p]_[BD_remux].mkv
+-rwxrwxr-x 1 ftp ftp 18660904388 Mar 11 2009 Ïàòîëîãèÿ_-_Pathology-(2008)-[1080p]_[BD_remux].mkv
+-rwxrwxr-x 1 ftp ftp 16476154520 Mar 05 2009 Ïèëà_1_[Ðåæèññ¸ðñêàÿ_âåðñèÿ]_-_Saw_I_[Director's_Cut]-(2004)-[1080p]_[HDDVD_remux].mkv
+-rwxrwxr-x 1 ftp ftp 19917510515 Mar 05 2009 Ïèëà_2_[Ðåæèññ¸ðñêàÿ_âåðñèÿ]_-_Saw_II_[Director's_Cut]-(2005)-[1080p]_[BD_remux].mkv
+-rwxrwxr-x 1 ftp ftp 18085592265 Mar 05 2009 Ïèëà_3_[Ðåæèññ¸ðñêàÿ_âåðñèÿ]_-_Saw_III_[Director's_Cut]-(2006)-[1080p]_[BD_remux].mkv
+-rwxrwxr-x 1 ftp ftp 3473582701 Mar 05 2009 Ïèëà_4_[Ðåæèññ¸ðñêàÿ_âåðñèÿ]_-_Saw_IV_[Director's_Cut]-(2007)-[1080p]_[BD_remux].flac
+-rwxrwxr-x 1 ftp ftp 15263958421 Mar 05 2009 Ïèëà_4_[Ðåæèññ¸ðñêàÿ_âåðñèÿ]_-_Saw_IV_[Director's_Cut]-(2007)-[1080p]_[BD_remux].mkv
+-rwxrwxr-x 1 ftp ftp 19944605507 Mar 16 2009 Ïèëà_5_[Ðåæèññ¸ðñêàÿ_âåðñèÿ]_-_Saw_V_[Director's_Cut]-(2008)-[1080p]_[BD_remux].mkv
+-rwxrwxr-x 1 ftp ftp 3024333064 Mar 24 2009 Ïèíãâèíû_èç_Ìàäàãàñêàðà-Îïåðàöèÿ_Ñ_Íîâûì_Ãîäîì!_-_The_Madagascar_Penguins_in_A_Christmas_Caper-(2005)-[1080p]_[BD_remux].ts
+-rwxrwxr-x 1 ftp ftp 125961 Mar 05 2009 Ïëîõîé_Ñàíòà_[Ðàñøèðåííàÿ_âåðñèÿ]_-_Bad_Santa_[Unrated_Edition]-(2003)-[1080p]_[BD_remux].srt
+-rwxrwxr-x 1 ftp ftp 19908695408 Mar 05 2009 Ïëîõîé_Ñàíòà_[Ðàñøèðåííàÿ_âåðñèÿ]_-_Bad_Santa_[Unrated_Edition]-(2003)-[1080p]_[BD_remux].ts
+-rwxrwxr-x 1 ftp ftp 23185439267 Mar 11 2009 Ïîáåã_èç_Øîóøåíêà_-_The_Shawshank_Redemption-(1994)-[1080p]_[BD_remux].mkv
+-rwxrwxr-x 1 ftp ftp 19567287274 Mar 16 2009 Òóïîé_è_åùå_òóïåå_[Ðàñøèðåííàÿ_âåðñèÿ]_-_Dumb_and_Dumber_[Unrated_Edition]-(1994)-[1080p]_[BD_remux].mkv
+-rwxrwxr-x 1 ftp ftp 14773061093 Mar 16 2009 Óðàãàí_-_The_Hurricane-(1999)-[1080p]_[HDDVD_Rip].mkv
+-rwxrwxr-x 1 ftp ftp 22411268500 Mar 11 2009 Õîñòåë_2_[Ðåæèññ¸ðñêàÿ_âåðñèÿ]_-_Hostel_2_[Director's_Cut]-(2007)-[1080p]_[BD_remux].ts
+-rwxrwxr-x 1 ftp ftp 23712519861 Mar 11 2009 ×óæîé_ïðîòèâ_Õèùíèêà_[Ðàñøèðåííàÿ_âåðñèÿ]_-_Alien_vs_Predator_[Unrated_Edition]-(2004)-[1080p]_[BD_remux].mkv
diff --git a/net/data/ftp/dir-listing-ls-22.expected b/net/data/ftp/dir-listing-ls-22.expected new file mode 100644 index 0000000..c5a02b7 --- /dev/null +++ b/net/data/ftp/dir-listing-ls-22.expected @@ -0,0 +1,287 @@ +d
+.
+-1
+1994
+10
+20
+17
+4
+
+d
+..
+-1
+2010
+4
+3
+0
+0
+
+-
+_READ_ME.txt
+4931
+1994
+6
+8
+15
+23
+
+d
+Àâà òà ð_-_Avatar-(2009)-[1080p]_[BD]
+-1
+1994
+5
+31
+16
+26
+
+-
+Àìåðèêà Ãñêèé_ïèðîã_1_[Ãà ñøèðåÃÃà ÿ_âåðñèÿ]_-_American_Pie_1_[Unrated_Edition]-(1999)-[1080p]_[BD_remux].ts
+17577705968
+2009
+3
+8
+0
+0
+
+-
+Ãîëüøîé_êóø_-_Snatch-(2000)-[1080i]_[HDTV].ts
+15512934868
+2009
+3
+16
+0
+0
+
+d
+Ãîëüøîé_êóø_-_Snatch-(2000)-[1080p]_[BD_Remux]
+-1
+1994
+6
+3
+19
+7
+
+-
+ÂîéÃà _ìèðîâ_-_War_of_the_Worlds-(2005)-[720p]_[HDTV].mkv
+8900589105
+2009
+3
+24
+0
+0
+
+-
+Ãà Ããñòåð_-_American_Gangster-(2007)-[1080p]_[BD_remux].mkv
+27728321654
+2009
+3
+9
+0
+0
+
+-
+Ãà Ããñòåð_[Ãà ñøèðåÃÃà ÿ_âåðñèÿ]_-_American_Gangster_[Unrated_Edition]-(2007)-[1080p]_[BD_remux].mkv
+31731782861
+2009
+3
+9
+0
+0
+
+-
+ÄîðîæÃîå_ïðèêëþ÷åÃèå_-_Road_Trip-(2000)-[720p]_[HDTV_Rip].mkv
+5009104014
+2009
+3
+24
+0
+0
+
+-
+Çâ¸çäÃûå_âîéÃû-Ãïèçîä_2-Àòà êà _êëîÃîâ_-_Star_Wars-Episode_2-Attack_of_the_Clones-(2002)-[1080i]_[HDTV].ts
+21410583980
+2009
+3
+11
+0
+0
+
+-
+Çâ¸çäÃûå_âîéÃû-Ãïèçîä_3-Ìåñòü_Ñèòõîâ_-_Star_Wars-Episode_3-Revenge_of_the_Sith-(2005)-[1080i]_[HDTV].ts
+19858181688
+2009
+3
+11
+0
+0
+
+-
+Çâ¸çäÃûé_äåñà Ãò_-_Starship_Troopers-(1997)-[1080p]_[BD_remux].mkv
+29026065728
+2009
+3
+16
+0
+0
+
+-
+Çåðêà ëà _[Ãà ñøèðåÃÃà ÿ_âåðñèÿ]_-_Mirrors_[Unrated_Edition]-(2008)-[1080p]_[BD_remux].mkv
+22169179449
+2009
+3
+16
+0
+0
+
+d
+ÃèÃäçÿ-óáèéöà _-_Ninja_Assassin-(2009)-[1080p]_[BD]
+-1
+1994
+6
+15
+14
+56
+
+-
+Îáèòåëü_çëà _3_-_Resident_Evil-Extinction-(2007)-[1080p]_[BD_remux].mkv
+19717173247
+2009
+3
+11
+0
+0
+
+-
+Ãà òîëîãèÿ_-_Pathology-(2008)-[1080p]_[BD_remux].mkv
+18660904388
+2009
+3
+11
+0
+0
+
+-
+Ãèëà _1_[Ãåæèññ¸ðñêà ÿ_âåðñèÿ]_-_Saw_I_[Director's_Cut]-(2004)-[1080p]_[HDDVD_remux].mkv
+16476154520
+2009
+3
+5
+0
+0
+
+-
+Ãèëà _2_[Ãåæèññ¸ðñêà ÿ_âåðñèÿ]_-_Saw_II_[Director's_Cut]-(2005)-[1080p]_[BD_remux].mkv
+19917510515
+2009
+3
+5
+0
+0
+
+-
+Ãèëà _3_[Ãåæèññ¸ðñêà ÿ_âåðñèÿ]_-_Saw_III_[Director's_Cut]-(2006)-[1080p]_[BD_remux].mkv
+18085592265
+2009
+3
+5
+0
+0
+
+-
+Ãèëà _4_[Ãåæèññ¸ðñêà ÿ_âåðñèÿ]_-_Saw_IV_[Director's_Cut]-(2007)-[1080p]_[BD_remux].flac
+3473582701
+2009
+3
+5
+0
+0
+
+-
+Ãèëà _4_[Ãåæèññ¸ðñêà ÿ_âåðñèÿ]_-_Saw_IV_[Director's_Cut]-(2007)-[1080p]_[BD_remux].mkv
+15263958421
+2009
+3
+5
+0
+0
+
+-
+Ãèëà _5_[Ãåæèññ¸ðñêà ÿ_âåðñèÿ]_-_Saw_V_[Director's_Cut]-(2008)-[1080p]_[BD_remux].mkv
+19944605507
+2009
+3
+16
+0
+0
+
+-
+ÃèÃãâèÃû_èç_Ìà äà ãà ñêà ðà -Îïåðà öèÿ_Ñ_Ãîâûì_Ãîäîì!_-_The_Madagascar_Penguins_in_A_Christmas_Caper-(2005)-[1080p]_[BD_remux].ts
+3024333064
+2009
+3
+24
+0
+0
+
+-
+Ãëîõîé_Ñà Ãòà _[Ãà ñøèðåÃÃà ÿ_âåðñèÿ]_-_Bad_Santa_[Unrated_Edition]-(2003)-[1080p]_[BD_remux].srt
+125961
+2009
+3
+5
+0
+0
+
+-
+Ãëîõîé_Ñà Ãòà _[Ãà ñøèðåÃÃà ÿ_âåðñèÿ]_-_Bad_Santa_[Unrated_Edition]-(2003)-[1080p]_[BD_remux].ts
+19908695408
+2009
+3
+5
+0
+0
+
+-
+Ãîáåã_èç_ØîóøåÃêà _-_The_Shawshank_Redemption-(1994)-[1080p]_[BD_remux].mkv
+23185439267
+2009
+3
+11
+0
+0
+
+-
+Òóïîé_è_åùå_òóïåå_[Ãà ñøèðåÃÃà ÿ_âåðñèÿ]_-_Dumb_and_Dumber_[Unrated_Edition]-(1994)-[1080p]_[BD_remux].mkv
+19567287274
+2009
+3
+16
+0
+0
+
+-
+Óðà ãà Ã_-_The_Hurricane-(1999)-[1080p]_[HDDVD_Rip].mkv
+14773061093
+2009
+3
+16
+0
+0
+
+-
+Õîñòåë_2_[Ãåæèññ¸ðñêà ÿ_âåðñèÿ]_-_Hostel_2_[Director's_Cut]-(2007)-[1080p]_[BD_remux].ts
+22411268500
+2009
+3
+11
+0
+0
+
+-
+×óæîé_ïðîòèâ_ÕèùÃèêà _[Ãà ñøèðåÃÃà ÿ_âåðñèÿ]_-_Alien_vs_Predator_[Unrated_Edition]-(2004)-[1080p]_[BD_remux].mkv
+23712519861
+2009
+3
+11
+0
+0
diff --git a/net/disk_cache/backend_impl.cc b/net/disk_cache/backend_impl.cc index 6162a77..78ebcc0 100644 --- a/net/disk_cache/backend_impl.cc +++ b/net/disk_cache/backend_impl.cc @@ -179,12 +179,7 @@ bool SetFieldTrialInfo(int size_group) { std::string group1 = base::StringPrintf("CacheSizeGroup_%d", size_group); trial1->AppendGroup(group1, base::FieldTrial::kAllRemainingProbability); - scoped_refptr<base::FieldTrial> trial2( - new base::FieldTrial("CacheThrottle", 100)); - int group2a = trial2->AppendGroup("CacheThrottle_On", 10); // 10 % in. - trial2->AppendGroup("CacheThrottle_Off", 10); // 10 % control. - - return trial2->group() == group2a; + return false; } // ------------------------------------------------------------------------ @@ -1219,10 +1214,6 @@ void BackendImpl::OnOperationCompleted(base::TimeDelta elapsed_time) { if (cache_type() != net::DISK_CACHE) return; - UMA_HISTOGRAM_TIMES(base::FieldTrial::MakeName("DiskCache.TotalIOTime", - "CacheThrottle").data(), - elapsed_time); - if (!throttle_requests_) return; diff --git a/net/ftp/ftp_directory_listing_buffer.cc b/net/ftp/ftp_directory_listing_buffer.cc index a173399..f6e8748 100644 --- a/net/ftp/ftp_directory_listing_buffer.cc +++ b/net/ftp/ftp_directory_listing_buffer.cc @@ -37,7 +37,7 @@ int FtpDirectoryListingBuffer::ConsumeData(const char* data, int data_length) { buffer_.append(data, data_length); if (!encoding_.empty() || buffer_.length() > 1024) { - int rv = ExtractFullLinesFromBuffer(); + int rv = ConsumeBuffer(); if (rv != OK) return rv; } @@ -46,11 +46,12 @@ int FtpDirectoryListingBuffer::ConsumeData(const char* data, int data_length) { } int FtpDirectoryListingBuffer::ProcessRemainingData() { - int rv = ExtractFullLinesFromBuffer(); + int rv = ConsumeBuffer(); if (rv != OK) return rv; - if (!buffer_.empty()) + DCHECK(buffer_.empty()); + if (!converted_buffer_.empty()) return ERR_INVALID_RESPONSE; rv = ParseLines(); @@ -77,38 +78,62 @@ FtpServerType FtpDirectoryListingBuffer::GetServerType() const { return (current_parser_ ? current_parser_->GetServerType() : SERVER_UNKNOWN); } -bool FtpDirectoryListingBuffer::ConvertToDetectedEncoding( - const std::string& from, string16* to) { - std::string encoding(encoding_.empty() ? "ascii" : encoding_); - return base::CodepageToUTF16(from, encoding.c_str(), - base::OnStringConversionError::FAIL, to); +int FtpDirectoryListingBuffer::DecodeBufferUsingEncoding( + const std::string& encoding) { + string16 converted; + if (!base::CodepageToUTF16(buffer_, + encoding.c_str(), + base::OnStringConversionError::FAIL, + &converted)) + return ERR_ENCODING_CONVERSION_FAILED; + + buffer_.clear(); + converted_buffer_ += converted; + return OK; } -int FtpDirectoryListingBuffer::ExtractFullLinesFromBuffer() { +int FtpDirectoryListingBuffer::ConvertBufferToUTF16() { if (encoding_.empty()) { - if (!base::DetectEncoding(buffer_, &encoding_)) + std::vector<std::string> encodings; + if (!base::DetectAllEncodings(buffer_, &encodings)) return ERR_ENCODING_DETECTION_FAILED; + + // Use first encoding that can be used to decode the buffer. + for (size_t i = 0; i < encodings.size(); i++) { + if (DecodeBufferUsingEncoding(encodings[i]) == OK) { + encoding_ = encodings[i]; + return OK; + } + } + + return ERR_ENCODING_DETECTION_FAILED; } + return DecodeBufferUsingEncoding(encoding_); +} + +void FtpDirectoryListingBuffer::ExtractFullLinesFromBuffer() { int cut_pos = 0; // TODO(phajdan.jr): This code accepts all endlines matching \r*\n. Should it // be more strict, or enforce consistent line endings? - for (size_t i = 0; i < buffer_.length(); ++i) { - if (buffer_[i] != '\n') + for (size_t i = 0; i < converted_buffer_.length(); ++i) { + if (converted_buffer_[i] != '\n') continue; int line_length = i - cut_pos; - if (i >= 1 && buffer_[i - 1] == '\r') + if (i >= 1 && converted_buffer_[i - 1] == '\r') line_length--; - std::string line(buffer_.substr(cut_pos, line_length)); + lines_.push_back(converted_buffer_.substr(cut_pos, line_length)); cut_pos = i + 1; - string16 line_converted; - if (!ConvertToDetectedEncoding(line, &line_converted)) { - buffer_.erase(0, cut_pos); - return ERR_ENCODING_CONVERSION_FAILED; - } - lines_.push_back(line_converted); } - buffer_.erase(0, cut_pos); + converted_buffer_.erase(0, cut_pos); +} + +int FtpDirectoryListingBuffer::ConsumeBuffer() { + int rv = ConvertBufferToUTF16(); + if (rv != OK) + return rv; + + ExtractFullLinesFromBuffer(); return OK; } diff --git a/net/ftp/ftp_directory_listing_buffer.h b/net/ftp/ftp_directory_listing_buffer.h index 0a25fff..ea68932 100644 --- a/net/ftp/ftp_directory_listing_buffer.h +++ b/net/ftp/ftp_directory_listing_buffer.h @@ -51,13 +51,20 @@ class FtpDirectoryListingBuffer { private: typedef std::set<FtpDirectoryListingParser*> ParserSet; - // Converts the string |from| to detected encoding and stores it in |to|. - // Returns true on success. - bool ConvertToDetectedEncoding(const std::string& from, string16* to); + // Decodes the raw buffer using specified |encoding|. On success + // clears the raw buffer and appends data to |converted_buffer_|. + // Returns network error code. + int DecodeBufferUsingEncoding(const std::string& encoding); - // Tries to extract full lines from the raw buffer, converting them to the - // detected encoding. Returns network error code. - int ExtractFullLinesFromBuffer(); + // Converts the raw buffer to UTF-16. Returns network error code. + int ConvertBufferToUTF16(); + + // Extracts lines from the converted buffer, and puts them in |lines_|. + void ExtractFullLinesFromBuffer(); + + // Consumes the raw buffer (i.e. does the character set conversion + // and line splitting). Returns network error code. + int ConsumeBuffer(); // Tries to parse full lines stored in |lines_|. Returns network error code. int ParseLines(); @@ -66,12 +73,15 @@ class FtpDirectoryListingBuffer { // parsers. Returns network error code. int OnEndOfInput(); - // Detected encoding of the response (empty if unknown or ASCII). + // Detected encoding of the response (empty if unknown). std::string encoding_; - // Buffer to keep not-yet-split data. + // Buffer to keep data before character set conversion. std::string buffer_; + // Buffer to keep data before line splitting. + string16 converted_buffer_; + // CRLF-delimited lines, without the CRLF, not yet consumed by parser. std::deque<string16> lines_; diff --git a/net/ftp/ftp_directory_listing_buffer_unittest.cc b/net/ftp/ftp_directory_listing_buffer_unittest.cc index 683e2f7..ceddfc4 100644 --- a/net/ftp/ftp_directory_listing_buffer_unittest.cc +++ b/net/ftp/ftp_directory_listing_buffer_unittest.cc @@ -42,6 +42,9 @@ TEST(FtpDirectoryListingBufferTest, Parse) { "dir-listing-ls-17", "dir-listing-ls-18", "dir-listing-ls-19", + "dir-listing-ls-20", // TODO(phajdan.jr): should use windows-1251 encoding. + "dir-listing-ls-21", // TODO(phajdan.jr): should use windows-1251 encoding. + "dir-listing-ls-22", // TODO(phajdan.jr): should use windows-1251 encoding. "dir-listing-mlsd-1", "dir-listing-mlsd-2", "dir-listing-netware-1", diff --git a/net/ftp/ftp_network_transaction.cc b/net/ftp/ftp_network_transaction.cc index bc1c2a9..bfda5bd 100644 --- a/net/ftp/ftp_network_transaction.cc +++ b/net/ftp/ftp_network_transaction.cc @@ -978,13 +978,10 @@ int FtpNetworkTransaction::ProcessResponseSIZE( if (size < 0) return Stop(ERR_INVALID_RESPONSE); - // Some FTP servers respond with success to the SIZE command - // for directories, and return 0 size. Make sure we don't set - // the resource type to file if that's the case. - if (size > 0) { - response_.expected_content_size = size; - resource_type_ = RESOURCE_TYPE_FILE; - } + // A successful response to SIZE does not mean the resource is a file. + // Some FTP servers (for example, the qnx one) send a SIZE even for + // directories. + response_.expected_content_size = size; break; case ERROR_CLASS_INFO_NEEDED: break; diff --git a/net/ftp/ftp_network_transaction_unittest.cc b/net/ftp/ftp_network_transaction_unittest.cc index 9c9f62e..1c22c5b 100644 --- a/net/ftp/ftp_network_transaction_unittest.cc +++ b/net/ftp/ftp_network_transaction_unittest.cc @@ -331,6 +331,30 @@ class FtpSocketDataProviderVMSDirectoryListingRootDirectory FtpSocketDataProviderVMSDirectoryListingRootDirectory); }; +class FtpSocketDataProviderFileDownloadWithFileTypecode + : public FtpSocketDataProvider { + public: + FtpSocketDataProviderFileDownloadWithFileTypecode() { + } + + virtual MockWriteResult OnWrite(const std::string& data) { + if (InjectFault()) + return MockWriteResult(true, data.length()); + switch (state()) { + case PRE_SIZE: + return Verify("SIZE /file\r\n", data, PRE_RETR, + "213 18\r\n"); + case PRE_RETR: + return Verify("RETR /file\r\n", data, PRE_QUIT, "200 OK\r\n"); + default: + return FtpSocketDataProvider::OnWrite(data); + } + } + + private: + DISALLOW_COPY_AND_ASSIGN(FtpSocketDataProviderFileDownloadWithFileTypecode); +}; + class FtpSocketDataProviderFileDownload : public FtpSocketDataProvider { public: FtpSocketDataProviderFileDownload() { @@ -341,8 +365,11 @@ class FtpSocketDataProviderFileDownload : public FtpSocketDataProvider { return MockWriteResult(true, data.length()); switch (state()) { case PRE_SIZE: - return Verify("SIZE /file\r\n", data, PRE_RETR, + return Verify("SIZE /file\r\n", data, PRE_CWD, "213 18\r\n"); + case PRE_CWD: + return Verify("CWD /file\r\n", data, PRE_RETR, + "550 Not a directory\r\n"); case PRE_RETR: return Verify("RETR /file\r\n", data, PRE_QUIT, "200 OK\r\n"); default: @@ -452,8 +479,11 @@ class FtpSocketDataProviderVMSFileDownload : public FtpSocketDataProvider { return Verify("PASV\r\n", data, PRE_SIZE, "227 Entering Passive Mode 127,0,0,1,123,456\r\n"); case PRE_SIZE: - return Verify("SIZE ANONYMOUS_ROOT:[000000]file\r\n", data, PRE_RETR, + return Verify("SIZE ANONYMOUS_ROOT:[000000]file\r\n", data, PRE_CWD, "213 18\r\n"); + case PRE_CWD: + return Verify("CWD ANONYMOUS_ROOT:[file]\r\n", data, PRE_RETR, + "550 Not a directory\r\n"); case PRE_RETR: return Verify("RETR ANONYMOUS_ROOT:[000000]file\r\n", data, PRE_QUIT, "200 OK\r\n"); @@ -476,8 +506,11 @@ class FtpSocketDataProviderEscaping : public FtpSocketDataProviderFileDownload { return MockWriteResult(true, data.length()); switch (state()) { case PRE_SIZE: - return Verify("SIZE / !\"#$%y\200\201\r\n", data, PRE_RETR, + return Verify("SIZE / !\"#$%y\200\201\r\n", data, PRE_CWD, "213 18\r\n"); + case PRE_CWD: + return Verify("CWD / !\"#$%y\200\201\r\n", data, PRE_RETR, + "550 Not a directory\r\n"); case PRE_RETR: return Verify("RETR / !\"#$%y\200\201\r\n", data, PRE_QUIT, "200 OK\r\n"); @@ -891,7 +924,7 @@ TEST_F(FtpNetworkTransactionTest, DownloadTransactionWithPasvFallback) { } TEST_F(FtpNetworkTransactionTest, DownloadTransactionWithTypecodeA) { - FtpSocketDataProviderFileDownload ctrl_socket; + FtpSocketDataProviderFileDownloadWithFileTypecode ctrl_socket; ctrl_socket.set_data_type('A'); ExecuteTransaction(&ctrl_socket, "ftp://host/file;type=a", OK); @@ -900,7 +933,7 @@ TEST_F(FtpNetworkTransactionTest, DownloadTransactionWithTypecodeA) { } TEST_F(FtpNetworkTransactionTest, DownloadTransactionWithTypecodeI) { - FtpSocketDataProviderFileDownload ctrl_socket; + FtpSocketDataProviderFileDownloadWithFileTypecode ctrl_socket; ExecuteTransaction(&ctrl_socket, "ftp://host/file;type=i", OK); // We pass an artificial value of 18 as a response to the SIZE command. @@ -1191,7 +1224,7 @@ TEST_F(FtpNetworkTransactionTest, DownloadTransactionBigSize) { // Pass a valid, but large file size. The transaction should not fail. FtpSocketDataProviderEvilSize ctrl_socket( "213 3204427776\r\n", - FtpSocketDataProvider::PRE_RETR); + FtpSocketDataProvider::PRE_CWD); ExecuteTransaction(&ctrl_socket, "ftp://host/file", OK); EXPECT_EQ(3204427776LL, transaction_.GetResponseInfo()->expected_content_size); diff --git a/net/http/http_auth_handler_digest_unittest.cc b/net/http/http_auth_handler_digest_unittest.cc index ff37e99..9338bca 100644 --- a/net/http/http_auth_handler_digest_unittest.cc +++ b/net/http/http_auth_handler_digest_unittest.cc @@ -8,7 +8,9 @@ #include "base/string_util.h" #include "base/utf_string_conversions.h" #include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" #include "net/http/http_auth_handler_digest.h" +#include "net/http/http_request_info.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { @@ -291,6 +293,10 @@ TEST(HttpAuthHandlerDigestTest, ParseChallenge) { EXPECT_EQ(tests[i].parsed_stale, digest->stale_); EXPECT_EQ(tests[i].parsed_algorithm, digest->algorithm_); EXPECT_EQ(tests[i].parsed_qop, digest->qop_); + EXPECT_TRUE(handler->encrypts_identity()); + EXPECT_FALSE(handler->is_connection_based()); + EXPECT_TRUE(handler->NeedsIdentity()); + EXPECT_FALSE(handler->AllowsDefaultCredentials()); } } @@ -482,4 +488,159 @@ TEST(HttpAuthHandlerDigest, HandleAnotherChallenge) { handler->HandleAnotherChallenge(&tok_stale_false)); } +namespace { + +const char* const kSimpleChallenge = + "Digest realm=\"Oblivion\", nonce=\"nonce-value\""; + +// RespondToChallenge creates an HttpAuthHandlerDigest for the specified +// |challenge|, and generates a response to the challenge which is returned in +// |token|. +// +// The return value is an error string - an empty string indicates no errors. +// +// If |target| is HttpAuth::AUTH_PROXY, then |proxy_name| specifies the source +// of the |challenge|. Otherwise, the scheme and host and port of |request_url| +// indicates the origin of the challenge. +std::string RespondToChallenge(HttpAuth::Target target, + const std::string& proxy_name, + const std::string& request_url, + const std::string& challenge, + std::string* token) { + // Input validation. + DCHECK(token); + DCHECK(target != HttpAuth::AUTH_PROXY || !proxy_name.empty()); + DCHECK(!request_url.empty()); + DCHECK(!challenge.empty()); + + token->clear(); + scoped_ptr<HttpAuthHandlerDigest::Factory> factory( + new HttpAuthHandlerDigest::Factory()); + HttpAuthHandlerDigest::NonceGenerator* nonce_generator = + new HttpAuthHandlerDigest::FixedNonceGenerator("client_nonce"); + factory->set_nonce_generator(nonce_generator); + scoped_ptr<HttpAuthHandler> handler; + + // Create a handler for a particular challenge. + GURL url_origin(target == HttpAuth::AUTH_SERVER ? request_url : proxy_name); + int rv_create = factory->CreateAuthHandlerFromString( + challenge, target, url_origin.GetOrigin(), BoundNetLog(), &handler); + if (rv_create != OK || handler.get() == NULL) + return "Unable to create auth handler."; + + // Create a token in response to the challenge. + // NOTE: HttpAuthHandlerDigest's implementation of GenerateAuthToken always + // completes synchronously. That's why this test can get away with a + // TestCompletionCallback without an IO thread. + TestCompletionCallback callback; + scoped_ptr<HttpRequestInfo> request(new HttpRequestInfo()); + request->url = GURL(request_url); + const string16 kFoo = ASCIIToUTF16("foo"); + const string16 kBar = ASCIIToUTF16("bar"); + int rv_generate = handler->GenerateAuthToken( + &kFoo, &kBar, request.get(), &callback, token); + if (rv_generate != OK) + return "Problems generating auth token"; + + // The token was correctly generated and is returned in |token|. + return std::string(); +} + +} // namespace + +TEST(HttpAuthHandlerDigest, RespondToServerChallenge) { + std::string auth_token; + std::string error_text = RespondToChallenge( + HttpAuth::AUTH_SERVER, + std::string(), + "http://www.example.com/path/to/resource", + kSimpleChallenge, + &auth_token); + EXPECT_EQ("", error_text); + EXPECT_EQ("Digest username=\"foo\", realm=\"Oblivion\", " + "nonce=\"nonce-value\", uri=\"/path/to/resource\", " + "response=\"6779f90bd0d658f937c1af967614fe84\"", + auth_token); +} + +TEST(HttpAuthHandlerDigest, RespondToHttpsServerChallenge) { + std::string auth_token; + std::string error_text = RespondToChallenge( + HttpAuth::AUTH_SERVER, + std::string(), + "https://www.example.com/path/to/resource", + kSimpleChallenge, + &auth_token); + EXPECT_EQ("", error_text); + EXPECT_EQ("Digest username=\"foo\", realm=\"Oblivion\", " + "nonce=\"nonce-value\", uri=\"/path/to/resource\", " + "response=\"6779f90bd0d658f937c1af967614fe84\"", + auth_token); +} + +TEST(HttpAuthHandlerDigest, RespondToProxyChallenge) { + std::string auth_token; + std::string error_text = RespondToChallenge( + HttpAuth::AUTH_PROXY, + "http://proxy.intranet.corp.com:3128", + "http://www.example.com/path/to/resource", + kSimpleChallenge, + &auth_token); + EXPECT_EQ("", error_text); + EXPECT_EQ("Digest username=\"foo\", realm=\"Oblivion\", " + "nonce=\"nonce-value\", uri=\"/path/to/resource\", " + "response=\"6779f90bd0d658f937c1af967614fe84\"", + auth_token); +} + +TEST(HttpAuthHandlerDigest, RespondToProxyChallengeHttps) { + std::string auth_token; + std::string error_text = RespondToChallenge( + HttpAuth::AUTH_PROXY, + "http://proxy.intranet.corp.com:3128", + "https://www.example.com/path/to/resource", + kSimpleChallenge, + &auth_token); + EXPECT_EQ("", error_text); + EXPECT_EQ("Digest username=\"foo\", realm=\"Oblivion\", " + "nonce=\"nonce-value\", uri=\"www.example.com:443\", " + "response=\"3270da8467afbe9ddf2334a48d46e9b9\"", + auth_token); +} + +TEST(HttpAuthHandlerDigest, RespondToChallengeAuthQop) { + std::string auth_token; + std::string error_text = RespondToChallenge( + HttpAuth::AUTH_SERVER, + std::string(), + "http://www.example.com/path/to/resource", + "Digest realm=\"Oblivion\", nonce=\"nonce-value\", qop=\"auth\"", + &auth_token); + EXPECT_EQ("", error_text); + EXPECT_EQ("Digest username=\"foo\", realm=\"Oblivion\", " + "nonce=\"nonce-value\", uri=\"/path/to/resource\", " + "response=\"5b1459beda5cee30d6ff9e970a69c0ea\", " + "qop=auth, nc=00000001, cnonce=\"client_nonce\"", + auth_token); +} + +TEST(HttpAuthHandlerDigest, RespondToChallengeOpaque) { + std::string auth_token; + std::string error_text = RespondToChallenge( + HttpAuth::AUTH_SERVER, + std::string(), + "http://www.example.com/path/to/resource", + "Digest realm=\"Oblivion\", nonce=\"nonce-value\", " + "qop=\"auth\", opaque=\"opaque text\"", + &auth_token); + EXPECT_EQ("", error_text); + EXPECT_EQ("Digest username=\"foo\", realm=\"Oblivion\", " + "nonce=\"nonce-value\", uri=\"/path/to/resource\", " + "response=\"5b1459beda5cee30d6ff9e970a69c0ea\", " + "opaque=\"opaque text\", " + "qop=auth, nc=00000001, cnonce=\"client_nonce\"", + auth_token); +} + + } // namespace net diff --git a/net/http/http_cache.cc b/net/http/http_cache.cc index 1342afa..896a6ac 100644 --- a/net/http/http_cache.cc +++ b/net/http/http_cache.cc @@ -280,6 +280,7 @@ class HttpCache::SSLHostInfoFactoryAdaptor : public SSLHostInfoFactory { HttpCache::HttpCache(HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker_, ProxyService* proxy_service, SSLConfigService* ssl_config_service, HttpAuthHandlerFactory* http_auth_handler_factory, @@ -292,7 +293,8 @@ HttpCache::HttpCache(HostResolver* host_resolver, ssl_host_info_factory_(new SSLHostInfoFactoryAdaptor( ALLOW_THIS_IN_INITIALIZER_LIST(this))), network_layer_(HttpNetworkLayer::CreateFactory(host_resolver, - dnsrr_resolver, ssl_host_info_factory_.get(), + dnsrr_resolver, dns_cert_checker_, + ssl_host_info_factory_.get(), proxy_service, ssl_config_service, http_auth_handler_factory, network_delegate, net_log)), ALLOW_THIS_IN_INITIALIZER_LIST(task_factory_(this)), diff --git a/net/http/http_cache.h b/net/http/http_cache.h index 0ce22e5..06c2ab9 100644 --- a/net/http/http_cache.h +++ b/net/http/http_cache.h @@ -41,6 +41,7 @@ class Entry; namespace net { +class DnsCertProvenanceChecker; class DnsRRResolver; class HostResolver; class HttpAuthHandlerFactory; @@ -117,6 +118,7 @@ class HttpCache : public HttpTransactionFactory, // The HttpCache takes ownership of the |backend_factory|. HttpCache(HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, ProxyService* proxy_service, SSLConfigService* ssl_config_service, HttpAuthHandlerFactory* http_auth_handler_factory, diff --git a/net/http/http_cache_transaction.cc b/net/http/http_cache_transaction.cc index 873ccf4..1720509 100644 --- a/net/http/http_cache_transaction.cc +++ b/net/http/http_cache_transaction.cc @@ -112,6 +112,7 @@ HttpCache::Transaction::Transaction(HttpCache* cache, bool enable_range_support) invalid_range_(false), enable_range_support_(enable_range_support), truncated_(false), + is_sparse_(false), server_responded_206_(false), cache_pending_(false), read_offset_(0), @@ -676,7 +677,7 @@ int HttpCache::Transaction::DoSuccessfulSendRequest() { return OK; } if (server_responded_206_ && mode_ == READ_WRITE && !truncated_ && - response_.headers->response_code() == 200) { + !is_sparse_) { // We have stored the full entry, but it changed and the server is // sending a range. We have to delete the old entry. DoneWritingToEntry(false); @@ -886,6 +887,7 @@ int HttpCache::Transaction::DoAddToEntryComplete(int result) { return OK; } +// We may end up here multiple times for a given request. int HttpCache::Transaction::DoStartPartialCacheValidation() { if (mode_ == NONE) return OK; @@ -1446,6 +1448,7 @@ int HttpCache::Transaction::BeginPartialCacheValidation() { return ValidateEntryHeadersAndContinue(false); } +// This should only be called once per request. int HttpCache::Transaction::ValidateEntryHeadersAndContinue( bool byte_range_requested) { DCHECK(mode_ == READ_WRITE); @@ -1461,6 +1464,9 @@ int HttpCache::Transaction::ValidateEntryHeadersAndContinue( return OK; } + if (response_.headers->response_code() == 206) + is_sparse_ = true; + if (!partial_->IsRequestedRangeOK()) { // The stored data is fine, but the request may be invalid. invalid_range_ = true; @@ -1842,6 +1848,7 @@ void HttpCache::Transaction::DoomPartialEntry(bool delete_object) { DCHECK_EQ(OK, rv); cache_->DoneWithEntry(entry_, this, false); entry_ = NULL; + is_sparse_ = false; if (delete_object) partial_.reset(NULL); } diff --git a/net/http/http_cache_transaction.h b/net/http/http_cache_transaction.h index 4d3673c..a842ade 100644 --- a/net/http/http_cache_transaction.h +++ b/net/http/http_cache_transaction.h @@ -341,6 +341,7 @@ class HttpCache::Transaction : public HttpTransaction { bool invalid_range_; // We may bypass the cache for this request. bool enable_range_support_; bool truncated_; // We don't have all the response data. + bool is_sparse_; // The data is stored in sparse byte ranges. bool server_responded_206_; bool cache_pending_; // We are waiting for the HttpCache. scoped_refptr<IOBuffer> read_buf_; diff --git a/net/http/http_cache_unittest.cc b/net/http/http_cache_unittest.cc index 6ca1da1..b4dde9b 100644 --- a/net/http/http_cache_unittest.cc +++ b/net/http/http_cache_unittest.cc @@ -3216,36 +3216,41 @@ TEST(HttpCache, GET_Previous206) { } // Tests that we can handle non-range requests when we have cached the first -// part of the object and server replies with 304 (Not Modified). +// part of the object and the server replies with 304 (Not Modified). TEST(HttpCache, GET_Previous206_NotModified) { MockHttpCache cache; cache.http_cache()->set_enable_range_support(true); MockTransaction transaction(kRangeGET_TransactionOK); - transaction.request_headers = "Range: bytes = 0-9\r\n" EXTRA_HEADER; - transaction.data = "rg: 00-09 "; AddMockTransaction(&transaction); std::string headers; // Write to the cache (0-9). + transaction.request_headers = "Range: bytes = 0-9\r\n" EXTRA_HEADER; + transaction.data = "rg: 00-09 "; RunTransactionTestWithResponse(cache.http_cache(), transaction, &headers); - Verify206Response(headers, 0, 9); - EXPECT_EQ(1, cache.network_layer()->transaction_count()); - EXPECT_EQ(0, cache.disk_cache()->open_count()); + + // Write to the cache (70-79). + transaction.request_headers = "Range: bytes = 70-79\r\n" EXTRA_HEADER; + transaction.data = "rg: 70-79 "; + RunTransactionTestWithResponse(cache.http_cache(), transaction, &headers); + Verify206Response(headers, 70, 79); + + EXPECT_EQ(2, cache.network_layer()->transaction_count()); + EXPECT_EQ(1, cache.disk_cache()->open_count()); EXPECT_EQ(1, cache.disk_cache()->create_count()); - // Read from the cache (0-9), write and read from cache (10 - 79), - MockTransaction transaction2(kRangeGET_TransactionOK); - transaction2.load_flags |= net::LOAD_VALIDATE_CACHE; - transaction2.request_headers = "Foo: bar\r\n" EXTRA_HEADER; - transaction2.data = "rg: 00-09 rg: 10-19 rg: 20-29 rg: 30-39 rg: 40-49 " + // Read from the cache (0-9), write and read from cache (10 - 79). + transaction.load_flags |= net::LOAD_VALIDATE_CACHE; + transaction.request_headers = "Foo: bar\r\n" EXTRA_HEADER; + transaction.data = "rg: 00-09 rg: 10-19 rg: 20-29 rg: 30-39 rg: 40-49 " "rg: 50-59 rg: 60-69 rg: 70-79 "; - RunTransactionTestWithResponse(cache.http_cache(), transaction2, &headers); + RunTransactionTestWithResponse(cache.http_cache(), transaction, &headers); EXPECT_EQ(0U, headers.find("HTTP/1.1 200 OK\n")); - EXPECT_EQ(3, cache.network_layer()->transaction_count()); - EXPECT_EQ(1, cache.disk_cache()->open_count()); + EXPECT_EQ(4, cache.network_layer()->transaction_count()); + EXPECT_EQ(2, cache.disk_cache()->open_count()); EXPECT_EQ(1, cache.disk_cache()->create_count()); RemoveMockTransaction(&transaction); diff --git a/net/http/http_network_layer.cc b/net/http/http_network_layer.cc index 5322e85..3da23c2 100644 --- a/net/http/http_network_layer.cc +++ b/net/http/http_network_layer.cc @@ -22,6 +22,7 @@ namespace net { HttpTransactionFactory* HttpNetworkLayer::CreateFactory( HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, ProxyService* proxy_service, SSLConfigService* ssl_config_service, @@ -32,6 +33,7 @@ HttpTransactionFactory* HttpNetworkLayer::CreateFactory( return new HttpNetworkLayer(ClientSocketFactory::GetDefaultFactory(), host_resolver, dnsrr_resolver, + dns_cert_checker, ssl_host_info_factory, proxy_service, ssl_config_service, http_auth_handler_factory, network_delegate, @@ -51,6 +53,7 @@ HttpNetworkLayer::HttpNetworkLayer( ClientSocketFactory* socket_factory, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, ProxyService* proxy_service, SSLConfigService* ssl_config_service, @@ -60,6 +63,7 @@ HttpNetworkLayer::HttpNetworkLayer( : socket_factory_(socket_factory), host_resolver_(host_resolver), dnsrr_resolver_(dnsrr_resolver), + dns_cert_checker_(dns_cert_checker), ssl_host_info_factory_(ssl_host_info_factory), proxy_service_(proxy_service), ssl_config_service_(ssl_config_service), @@ -77,6 +81,7 @@ HttpNetworkLayer::HttpNetworkLayer( ClientSocketFactory* socket_factory, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, ProxyService* proxy_service, SSLConfigService* ssl_config_service, @@ -87,6 +92,7 @@ HttpNetworkLayer::HttpNetworkLayer( : socket_factory_(socket_factory), host_resolver_(host_resolver), dnsrr_resolver_(dnsrr_resolver), + dns_cert_checker_(dns_cert_checker), ssl_host_info_factory_(ssl_host_info_factory), proxy_service_(proxy_service), ssl_config_service_(ssl_config_service), @@ -103,6 +109,7 @@ HttpNetworkLayer::HttpNetworkLayer( HttpNetworkLayer::HttpNetworkLayer(HttpNetworkSession* session) : socket_factory_(ClientSocketFactory::GetDefaultFactory()), dnsrr_resolver_(NULL), + dns_cert_checker_(NULL), ssl_host_info_factory_(NULL), ssl_config_service_(NULL), session_(session), @@ -144,6 +151,7 @@ HttpNetworkSession* HttpNetworkLayer::GetSession() { session_ = new HttpNetworkSession( host_resolver_, dnsrr_resolver_, + dns_cert_checker_, ssl_host_info_factory_, proxy_service_, socket_factory_, @@ -155,6 +163,7 @@ HttpNetworkSession* HttpNetworkLayer::GetSession() { // These were just temps for lazy-initializing HttpNetworkSession. host_resolver_ = NULL; dnsrr_resolver_ = NULL; + dns_cert_checker_ = NULL; ssl_host_info_factory_ = NULL; proxy_service_ = NULL; socket_factory_ = NULL; diff --git a/net/http/http_network_layer.h b/net/http/http_network_layer.h index 63ae3f2..7781efb 100644 --- a/net/http/http_network_layer.h +++ b/net/http/http_network_layer.h @@ -16,6 +16,7 @@ namespace net { class ClientSocketFactory; +class DnsCertProvenanceChecker; class DnsRRResolver; class HostResolver; class HttpAuthHandlerFactory; @@ -34,6 +35,7 @@ class HttpNetworkLayer : public HttpTransactionFactory, public NonThreadSafe { HttpNetworkLayer(ClientSocketFactory* socket_factory, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, ProxyService* proxy_service, SSLConfigService* ssl_config_service, @@ -46,6 +48,7 @@ class HttpNetworkLayer : public HttpTransactionFactory, public NonThreadSafe { ClientSocketFactory* socket_factory, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, ProxyService* proxy_service, SSLConfigService* ssl_config_service, @@ -62,6 +65,7 @@ class HttpNetworkLayer : public HttpTransactionFactory, public NonThreadSafe { static HttpTransactionFactory* CreateFactory( HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, ProxyService* proxy_service, SSLConfigService* ssl_config_service, @@ -100,6 +104,7 @@ class HttpNetworkLayer : public HttpTransactionFactory, public NonThreadSafe { // creating |session_|. HostResolver* host_resolver_; DnsRRResolver* dnsrr_resolver_; + DnsCertProvenanceChecker* dns_cert_checker_; SSLHostInfoFactory* ssl_host_info_factory_; scoped_refptr<ProxyService> proxy_service_; diff --git a/net/http/http_network_layer_unittest.cc b/net/http/http_network_layer_unittest.cc index 2850404..3ed54bf 100644 --- a/net/http/http_network_layer_unittest.cc +++ b/net/http/http_network_layer_unittest.cc @@ -25,6 +25,7 @@ TEST_F(HttpNetworkLayerTest, CreateAndDestroy) { NULL, &host_resolver, NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, net::ProxyService::CreateDirect(), new net::SSLConfigServiceDefaults, @@ -44,6 +45,7 @@ TEST_F(HttpNetworkLayerTest, Suspend) { NULL, &host_resolver, NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, net::ProxyService::CreateDirect(), new net::SSLConfigServiceDefaults, @@ -92,6 +94,7 @@ TEST_F(HttpNetworkLayerTest, GET) { &mock_socket_factory, &host_resolver, NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, net::ProxyService::CreateDirect(), new net::SSLConfigServiceDefaults, diff --git a/net/http/http_network_session.cc b/net/http/http_network_session.cc index d96f901..1e77b49 100644 --- a/net/http/http_network_session.cc +++ b/net/http/http_network_session.cc @@ -21,6 +21,7 @@ namespace net { HttpNetworkSession::HttpNetworkSession( HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, ProxyService* proxy_service, ClientSocketFactory* client_socket_factory, @@ -32,12 +33,14 @@ HttpNetworkSession::HttpNetworkSession( : socket_factory_(client_socket_factory), host_resolver_(host_resolver), dnsrr_resolver_(dnsrr_resolver), + dns_cert_checker_(dns_cert_checker), proxy_service_(proxy_service), ssl_config_service_(ssl_config_service), socket_pool_manager_(net_log, client_socket_factory, host_resolver, dnsrr_resolver, + dns_cert_checker, ssl_host_info_factory, proxy_service, ssl_config_service), diff --git a/net/http/http_network_session.h b/net/http/http_network_session.h index 53ae36a..43424d2 100644 --- a/net/http/http_network_session.h +++ b/net/http/http_network_session.h @@ -29,6 +29,7 @@ class Value; namespace net { class ClientSocketFactory; +class DnsCertProvenanceChecker; class DnsRRResolver; class HttpAuthHandlerFactory; class HttpNetworkDelegate; @@ -48,6 +49,7 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession>, HttpNetworkSession( HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, ProxyService* proxy_service, ClientSocketFactory* client_socket_factory, @@ -108,6 +110,9 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession>, ClientSocketFactory* socket_factory() { return socket_factory_; } HostResolver* host_resolver() { return host_resolver_; } DnsRRResolver* dnsrr_resolver() { return dnsrr_resolver_; } + DnsCertProvenanceChecker* dns_cert_checker() { + return dns_cert_checker_; + } ProxyService* proxy_service() { return proxy_service_; } SSLConfigService* ssl_config_service() { return ssl_config_service_; } SpdySessionPool* spdy_session_pool() { return spdy_session_pool_.get(); } @@ -148,6 +153,7 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession>, HttpAlternateProtocols alternate_protocols_; HostResolver* const host_resolver_; DnsRRResolver* dnsrr_resolver_; + DnsCertProvenanceChecker* dns_cert_checker_; scoped_refptr<ProxyService> proxy_service_; scoped_refptr<SSLConfigService> ssl_config_service_; ClientSocketPoolManager socket_pool_manager_; diff --git a/net/http/http_network_transaction.cc b/net/http/http_network_transaction.cc index 3e1d8d1..c84dfa4 100644 --- a/net/http/http_network_transaction.cc +++ b/net/http/http_network_transaction.cc @@ -96,7 +96,6 @@ HttpNetworkTransaction::HttpNetworkTransaction(HttpNetworkSession* session) session->ssl_config_service()->GetSSLConfig(&ssl_config_); if (session->http_stream_factory()->next_protos()) ssl_config_.next_protos = *session->http_stream_factory()->next_protos(); - } HttpNetworkTransaction::~HttpNetworkTransaction() { @@ -171,10 +170,8 @@ int HttpNetworkTransaction::RestartWithCertificate( DCHECK_EQ(STATE_NONE, next_state_); ssl_config_.client_cert = client_cert; - if (client_cert) { - session_->ssl_client_auth_cache()->Add( - response_.cert_request_info->host_and_port, client_cert); - } + session_->ssl_client_auth_cache()->Add( + response_.cert_request_info->host_and_port, client_cert); ssl_config_.send_client_cert = true; // Reset the other member variables. // Note: this is necessary only with SSL renegotiation. @@ -971,29 +968,43 @@ int HttpNetworkTransaction::HandleCertificateRequest(int error) { // handshake. stream_request_.reset(); - // If the user selected one of the certificate in client_certs for this - // server before, use it automatically. - X509Certificate* client_cert = session_->ssl_client_auth_cache()->Lookup( - response_.cert_request_info->host_and_port); + // If the user selected one of the certificates in client_certs or declined + // to provide one for this server before, use the past decision + // automatically. + scoped_refptr<X509Certificate> client_cert; + bool found_cached_cert = session_->ssl_client_auth_cache()->Lookup( + response_.cert_request_info->host_and_port, &client_cert); + if (!found_cached_cert) + return error; + + // Check that the certificate selected is still a certificate the server + // is likely to accept, based on the criteria supplied in the + // CertificateRequest message. if (client_cert) { const std::vector<scoped_refptr<X509Certificate> >& client_certs = response_.cert_request_info->client_certs; + bool cert_still_valid = false; for (size_t i = 0; i < client_certs.size(); ++i) { - if (client_cert->fingerprint().Equals(client_certs[i]->fingerprint())) { - // TODO(davidben): Add a unit test which covers this path; we need to be - // able to send a legitimate certificate and also bypass/clear the - // SSL session cache. - ssl_config_.client_cert = client_cert; - ssl_config_.send_client_cert = true; - next_state_ = STATE_CREATE_STREAM; - // Reset the other member variables. - // Note: this is necessary only with SSL renegotiation. - ResetStateForRestart(); - return OK; + if (client_cert->Equals(client_certs[i])) { + cert_still_valid = true; + break; } } + + if (!cert_still_valid) + return error; } - return error; + + // TODO(davidben): Add a unit test which covers this path; we need to be + // able to send a legitimate certificate and also bypass/clear the + // SSL session cache. + ssl_config_.client_cert = client_cert; + ssl_config_.send_client_cert = true; + next_state_ = STATE_CREATE_STREAM; + // Reset the other member variables. + // Note: this is necessary only with SSL renegotiation. + ResetStateForRestart(); + return OK; } // This method determines whether it is safe to resend the request after an diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index 6a389af..f765696 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -100,6 +100,7 @@ struct SessionDependencies { HttpNetworkSession* CreateSession(SessionDependencies* session_deps) { return new HttpNetworkSession(session_deps->host_resolver.get(), NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, session_deps->proxy_service, &session_deps->socket_factory, @@ -307,7 +308,7 @@ template<> CaptureGroupNameSSLSocketPool::CaptureGroupNameSocketPool( HttpNetworkSession* session) : SSLClientSocketPool(0, 0, NULL, session->host_resolver(), NULL, NULL, - NULL, NULL, NULL, NULL, NULL, NULL) {} + NULL, NULL, NULL, NULL, NULL, NULL, NULL) {} //----------------------------------------------------------------------------- diff --git a/net/http/http_proxy_client_socket_pool.cc b/net/http/http_proxy_client_socket_pool.cc index e43d02b..d0c177d 100644 --- a/net/http/http_proxy_client_socket_pool.cc +++ b/net/http/http_proxy_client_socket_pool.cc @@ -283,7 +283,8 @@ int HttpProxyConnectJob::DoSpdyProxyCreateStream() { next_state_ = STATE_SPDY_PROXY_CREATE_STREAM_COMPLETE; return spdy_session->CreateStream(params_->request_url(), params_->destination().priority(), - &spdy_stream_, net_log(), &callback_); + &spdy_stream_, spdy_session->net_log(), + &callback_); } int HttpProxyConnectJob::DoSpdyProxyCreateStreamComplete(int result) { diff --git a/net/http/http_proxy_client_socket_pool_unittest.cc b/net/http/http_proxy_client_socket_pool_unittest.cc index f5bc2e7..56fae19 100644 --- a/net/http/http_proxy_client_socket_pool_unittest.cc +++ b/net/http/http_proxy_client_socket_pool_unittest.cc @@ -66,6 +66,7 @@ class HttpProxyClientSocketPoolTest : public TestWithHttpParam { &ssl_histograms_, host_resolver_.get(), NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, &socket_factory_, &tcp_socket_pool_, @@ -77,6 +78,7 @@ class HttpProxyClientSocketPoolTest : public TestWithHttpParam { HttpAuthHandlerFactory::CreateDefault(host_resolver_.get())), session_(new HttpNetworkSession(host_resolver_.get(), NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, ProxyService::CreateDirect(), &socket_factory_, diff --git a/net/http/http_request_info.h b/net/http/http_request_info.h index f740d7c..7c73aa6 100644 --- a/net/http/http_request_info.h +++ b/net/http/http_request_info.h @@ -18,9 +18,12 @@ namespace net { struct HttpRequestInfo { enum RequestMotivation{ // TODO(mbelshe): move these into Client Socket. - PRECONNECT_MOTIVATED, // This request was motivated by a prefetch. - OMNIBOX_MOTIVATED, // This request was motivated by the omnibox. - NORMAL_MOTIVATION // No special motivation associated with the request. + PRECONNECT_MOTIVATED, // Request was motivated by a prefetch. + OMNIBOX_MOTIVATED, // Request was motivated by the omnibox. + NORMAL_MOTIVATION, // No special motivation associated with the request. + EARLY_LOAD_MOTIVATED, // When browser asks a tab to open an URL, this short + // circuits that path (of waiting for the renderer to + // do the URL request), and starts loading ASAP. }; HttpRequestInfo(); diff --git a/net/http/http_response_body_drainer_unittest.cc b/net/http/http_response_body_drainer_unittest.cc index d8c9bb7..75f099a 100644 --- a/net/http/http_response_body_drainer_unittest.cc +++ b/net/http/http_response_body_drainer_unittest.cc @@ -175,8 +175,9 @@ class HttpResponseBodyDrainerTest : public testing::Test { protected: HttpResponseBodyDrainerTest() : session_(new HttpNetworkSession( - NULL, - NULL, + NULL /* host_resolver */, + NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, ProxyService::CreateDirect(), NULL, diff --git a/net/http/http_stream_factory_unittest.cc b/net/http/http_stream_factory_unittest.cc index c295363..63fce33 100644 --- a/net/http/http_stream_factory_unittest.cc +++ b/net/http/http_stream_factory_unittest.cc @@ -44,6 +44,7 @@ struct SessionDependencies { HttpNetworkSession* CreateSession(SessionDependencies* session_deps) { return new HttpNetworkSession(session_deps->host_resolver.get(), NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, session_deps->proxy_service, &session_deps->socket_factory, @@ -170,7 +171,7 @@ template<> CapturePreconnectsSSLSocketPool::CapturePreconnectsSocketPool( HttpNetworkSession* session) : SSLClientSocketPool(0, 0, NULL, session->host_resolver(), NULL, NULL, - NULL, NULL, NULL, NULL, NULL, NULL) {} + NULL, NULL, NULL, NULL, NULL, NULL, NULL) {} TEST(HttpStreamFactoryTest, PreconnectDirect) { for (size_t i = 0; i < arraysize(kTests); ++i) { diff --git a/net/http/http_stream_request.cc b/net/http/http_stream_request.cc index eaaad16..6f0a39e 100644 --- a/net/http/http_stream_request.cc +++ b/net/http/http_stream_request.cc @@ -759,7 +759,12 @@ int HttpStreamRequest::DoCreateStream() { direct = false; } - if (!spdy_session.get()) { + if (spdy_session.get()) { + // We picked up an existing session, so we don't need our socket. + if (connection_->socket()) + connection_->socket()->Disconnect(); + connection_->Reset(); + } else { // SPDY can be negotiated using the TLS next protocol negotiation (NPN) // extension, or just directly using SSL. Either way, |connection_| must // contain an SSLClientSocket. diff --git a/net/http/stream_factory.h b/net/http/stream_factory.h index 480df06..94b3cc8 100644 --- a/net/http/stream_factory.h +++ b/net/http/stream_factory.h @@ -122,7 +122,8 @@ class StreamFactory { // Requests that enough connections for |num_streams| be opened. If // ERR_IO_PENDING is returned, |info|, |ssl_config|, and |proxy_info| must - // be kept alive until |callback| is invoked. + // be kept alive until |callback| is invoked. That callback will be given the + // final error code. virtual int PreconnectStreams(int num_streams, const HttpRequestInfo* info, SSLConfig* ssl_config, diff --git a/net/net.gyp b/net/net.gyp index b9e3776..a5134da 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -582,8 +582,8 @@ 'socket/client_socket_pool_histograms.h', 'socket/client_socket_pool_manager.cc', 'socket/client_socket_pool_manager.h', - 'socket/dns_cert_provenance_check.cc', - 'socket/dns_cert_provenance_check.h', + 'socket/dns_cert_provenance_checker.cc', + 'socket/dns_cert_provenance_checker.h', 'socket/socket.h', 'socket/socks5_client_socket.cc', 'socket/socks5_client_socket.h', @@ -685,6 +685,14 @@ 'url_request/url_request_status.h', 'url_request/url_request_test_job.cc', 'url_request/url_request_test_job.h', + 'url_request/url_request_throttler_entry.cc', + 'url_request/url_request_throttler_entry.h', + 'url_request/url_request_throttler_entry_interface.h', + 'url_request/url_request_throttler_header_adapter.h', + 'url_request/url_request_throttler_header_adapter.cc', + 'url_request/url_request_throttler_header_interface.h', + 'url_request/url_request_throttler_manager.cc', + 'url_request/url_request_throttler_manager.h', 'url_request/view_cache_helper.cc', 'url_request/view_cache_helper.h', 'websockets/websocket.cc', @@ -862,7 +870,7 @@ 'base/test_completion_callback_unittest.cc', 'base/upload_data_stream_unittest.cc', 'base/x509_certificate_unittest.cc', - 'base/x509_cert_types_unittest.cc', + 'base/x509_cert_types_mac_unittest.cc', 'base/x509_openssl_util_unittest.cc', 'disk_cache/addr_unittest.cc', 'disk_cache/backend_unittest.cc', @@ -960,6 +968,7 @@ 'tools/dump_cache/url_utilities.cc', 'tools/dump_cache/url_utilities_unittest.cc', 'url_request/url_request_job_tracker_unittest.cc', + 'url_request/url_request_throttler_unittest.cc', 'url_request/url_request_unittest.cc', 'url_request/url_request_unittest.h', 'url_request/view_cache_helper_unittest.cc', @@ -1334,61 +1343,57 @@ }, ], 'conditions': [ - # ['OS=="linux"', { - # 'targets': [ - # { - # 'target_name': 'flip_in_mem_edsm_server', - # 'type': 'executable', - # 'dependencies': [ - # '../base/base.gyp:base', - # 'net.gyp:net', - # ], - # 'link_settings': { - # 'ldflags': [ - # '-lssl' - # ], - # 'libraries': [ - # '-lssl' - # ], - # }, - # 'sources': [ - # 'tools/dump_cache/url_to_filename_encoder.cc', - # 'tools/dump_cache/url_to_filename_encoder.h', - # 'tools/dump_cache/url_utilities.h', - # 'tools/dump_cache/url_utilities.cc', + ['OS=="linux"', { + 'targets': [ + { + 'target_name': 'flip_in_mem_edsm_server', + 'type': 'executable', + 'cflags': [ + '-Wno-deprecated', + ], + 'dependencies': [ + '../base/base.gyp:base', + 'net.gyp:net', + '../third_party/openssl/openssl.gyp:openssl', + ], + 'sources': [ + 'tools/dump_cache/url_to_filename_encoder.cc', + 'tools/dump_cache/url_to_filename_encoder.h', + 'tools/dump_cache/url_utilities.h', + 'tools/dump_cache/url_utilities.cc', - # 'tools/flip_server/balsa_enums.h', - # 'tools/flip_server/balsa_frame.cc', - # 'tools/flip_server/balsa_frame.h', - # 'tools/flip_server/balsa_headers.cc', - # 'tools/flip_server/balsa_headers.h', - # 'tools/flip_server/balsa_headers_token_utils.cc', - # 'tools/flip_server/balsa_headers_token_utils.h', - # 'tools/flip_server/balsa_visitor_interface.h', - # 'tools/flip_server/buffer_interface.h', - # 'tools/flip_server/create_listener.cc', - # 'tools/flip_server/create_listener.h', - # 'tools/flip_server/epoll_server.cc', - # 'tools/flip_server/epoll_server.h', - # 'tools/flip_server/flip_in_mem_edsm_server.cc', - # 'tools/flip_server/http_message_constants.cc', - # 'tools/flip_server/http_message_constants.h', - # 'tools/flip_server/loadtime_measurement.h', - # 'tools/flip_server/porting.txt', - # 'tools/flip_server/ring_buffer.cc', - # 'tools/flip_server/ring_buffer.h', - # 'tools/flip_server/simple_buffer.cc', - # 'tools/flip_server/simple_buffer.h', - # 'tools/flip_server/split.h', - # 'tools/flip_server/split.cc', - # 'tools/flip_server/string_piece_utils.h', - # 'tools/flip_server/thread.h', - # 'tools/flip_server/url_to_filename_encoder.h', - # 'tools/flip_server/url_utilities.h', - # ], - # }, - # ] - # }], + 'tools/flip_server/balsa_enums.h', + 'tools/flip_server/balsa_frame.cc', + 'tools/flip_server/balsa_frame.h', + 'tools/flip_server/balsa_headers.cc', + 'tools/flip_server/balsa_headers.h', + 'tools/flip_server/balsa_headers_token_utils.cc', + 'tools/flip_server/balsa_headers_token_utils.h', + 'tools/flip_server/balsa_visitor_interface.h', + 'tools/flip_server/buffer_interface.h', + 'tools/flip_server/create_listener.cc', + 'tools/flip_server/create_listener.h', + 'tools/flip_server/epoll_server.cc', + 'tools/flip_server/epoll_server.h', + 'tools/flip_server/flip_in_mem_edsm_server.cc', + 'tools/flip_server/http_message_constants.cc', + 'tools/flip_server/http_message_constants.h', + 'tools/flip_server/loadtime_measurement.h', + 'tools/flip_server/porting.txt', + 'tools/flip_server/ring_buffer.cc', + 'tools/flip_server/ring_buffer.h', + 'tools/flip_server/simple_buffer.cc', + 'tools/flip_server/simple_buffer.h', + 'tools/flip_server/split.h', + 'tools/flip_server/split.cc', + 'tools/flip_server/string_piece_utils.h', + 'tools/flip_server/thread.h', + 'tools/flip_server/url_to_filename_encoder.h', + 'tools/flip_server/url_utilities.h', + ], + }, + ] + }], ['OS=="linux"', { 'targets': [ { diff --git a/net/ocsp/nss_ocsp.cc b/net/ocsp/nss_ocsp.cc index fafaa68..7618f9e 100644 --- a/net/ocsp/nss_ocsp.cc +++ b/net/ocsp/nss_ocsp.cc @@ -85,7 +85,8 @@ class OCSPIOLoop { DISALLOW_COPY_AND_ASSIGN(OCSPIOLoop); }; -base::LazyInstance<OCSPIOLoop> g_ocsp_io_loop(base::LINKER_INITIALIZED); +base::LazyInstance<OCSPIOLoop, base::LeakyLazyInstanceTraits<OCSPIOLoop> > + g_ocsp_io_loop(base::LINKER_INITIALIZED); const int kRecvBufferSize = 4096; @@ -560,7 +561,6 @@ OCSPNSSInitialization::~OCSPNSSInitialization() {} SECStatus OCSPCreateSession(const char* host, PRUint16 portnum, SEC_HTTP_SERVER_SESSION* pSession) { VLOG(1) << "OCSP create session: host=" << host << " port=" << portnum; - DCHECK(!MessageLoop::current()); pthread_mutex_lock(&g_request_context_lock); URLRequestContext* request_context = g_request_context; pthread_mutex_unlock(&g_request_context_lock); @@ -579,7 +579,6 @@ SECStatus OCSPCreateSession(const char* host, PRUint16 portnum, SECStatus OCSPKeepAliveSession(SEC_HTTP_SERVER_SESSION session, PRPollDesc **pPollDesc) { VLOG(1) << "OCSP keep alive"; - DCHECK(!MessageLoop::current()); if (pPollDesc) *pPollDesc = NULL; return SECSuccess; @@ -587,7 +586,6 @@ SECStatus OCSPKeepAliveSession(SEC_HTTP_SERVER_SESSION session, SECStatus OCSPFreeSession(SEC_HTTP_SERVER_SESSION session) { VLOG(1) << "OCSP free session"; - DCHECK(!MessageLoop::current()); delete reinterpret_cast<OCSPServerSession*>(session); return SECSuccess; } @@ -602,7 +600,6 @@ SECStatus OCSPCreate(SEC_HTTP_SERVER_SESSION session, << " path_and_query=" << path_and_query_string << " http_request_method=" << http_request_method << " timeout=" << timeout; - DCHECK(!MessageLoop::current()); OCSPServerSession* ocsp_session = reinterpret_cast<OCSPServerSession*>(session); @@ -624,7 +621,6 @@ SECStatus OCSPSetPostData(SEC_HTTP_REQUEST_SESSION request, const PRUint32 http_data_len, const char* http_content_type) { VLOG(1) << "OCSP set post data len=" << http_data_len; - DCHECK(!MessageLoop::current()); OCSPRequestSession* req = reinterpret_cast<OCSPRequestSession*>(request); req->SetPostData(http_data, http_data_len, http_content_type); @@ -636,7 +632,6 @@ SECStatus OCSPAddHeader(SEC_HTTP_REQUEST_SESSION request, const char* http_header_value) { VLOG(1) << "OCSP add header name=" << http_header_name << " value=" << http_header_value; - DCHECK(!MessageLoop::current()); OCSPRequestSession* req = reinterpret_cast<OCSPRequestSession*>(request); req->AddHeader(http_header_name, http_header_value); @@ -696,7 +691,6 @@ SECStatus OCSPTrySendAndReceive(SEC_HTTP_REQUEST_SESSION request, } VLOG(1) << "OCSP try send and receive"; - DCHECK(!MessageLoop::current()); OCSPRequestSession* req = reinterpret_cast<OCSPRequestSession*>(request); // We support blocking mode only. if (pPollDesc) @@ -774,7 +768,6 @@ SECStatus OCSPTrySendAndReceive(SEC_HTTP_REQUEST_SESSION request, SECStatus OCSPFree(SEC_HTTP_REQUEST_SESSION request) { VLOG(1) << "OCSP free"; - DCHECK(!MessageLoop::current()); OCSPRequestSession* req = reinterpret_cast<OCSPRequestSession*>(request); req->Cancel(); req->Release(); diff --git a/net/proxy/proxy_script_fetcher_impl.cc b/net/proxy/proxy_script_fetcher_impl.cc index 035622f..221e5c0 100644 --- a/net/proxy/proxy_script_fetcher_impl.cc +++ b/net/proxy/proxy_script_fetcher_impl.cc @@ -18,8 +18,6 @@ // TODO(eroman): // - Support auth-prompts. -// TODO(eroman): Use a state machine rather than recursion. Recursion could -// lead to lots of frames. namespace net { namespace { @@ -142,7 +140,7 @@ URLRequestContext* ProxyScriptFetcherImpl::GetRequestContext() { void ProxyScriptFetcherImpl::OnAuthRequired(URLRequest* request, AuthChallengeInfo* auth_info) { - DCHECK(request == cur_request_.get()); + DCHECK_EQ(request, cur_request_.get()); // TODO(eroman): LOG(WARNING) << "Auth required to fetch PAC script, aborting."; result_code_ = ERR_NOT_IMPLEMENTED; @@ -152,7 +150,7 @@ void ProxyScriptFetcherImpl::OnAuthRequired(URLRequest* request, void ProxyScriptFetcherImpl::OnSSLCertificateError(URLRequest* request, int cert_error, X509Certificate* cert) { - DCHECK(request == cur_request_.get()); + DCHECK_EQ(request, cur_request_.get()); LOG(WARNING) << "SSL certificate error when fetching PAC script, aborting."; // Certificate errors are in same space as net errors. result_code_ = cert_error; @@ -160,7 +158,7 @@ void ProxyScriptFetcherImpl::OnSSLCertificateError(URLRequest* request, } void ProxyScriptFetcherImpl::OnResponseStarted(URLRequest* request) { - DCHECK(request == cur_request_.get()); + DCHECK_EQ(request, cur_request_.get()); if (!request->status().is_success()) { OnResponseCompleted(request); @@ -195,24 +193,15 @@ void ProxyScriptFetcherImpl::OnResponseStarted(URLRequest* request) { void ProxyScriptFetcherImpl::OnReadCompleted(URLRequest* request, int num_bytes) { - DCHECK(request == cur_request_.get()); - if (num_bytes > 0) { - // Enforce maximum size bound. - if (num_bytes + bytes_read_so_far_.size() > - static_cast<size_t>(max_response_bytes_)) { - result_code_ = ERR_FILE_TOO_BIG; - request->Cancel(); - return; - } - bytes_read_so_far_.append(buf_->data(), num_bytes); + DCHECK_EQ(request, cur_request_.get()); + if (ConsumeBytesRead(request, num_bytes)) { + // Keep reading. ReadBody(request); - } else { // Error while reading, or EOF - OnResponseCompleted(request); } } void ProxyScriptFetcherImpl::OnResponseCompleted(URLRequest* request) { - DCHECK(request == cur_request_.get()); + DCHECK_EQ(request, cur_request_.get()); // Use |result_code_| as the request's error if we have already set it to // something specific. @@ -223,13 +212,38 @@ void ProxyScriptFetcherImpl::OnResponseCompleted(URLRequest* request) { } void ProxyScriptFetcherImpl::ReadBody(URLRequest* request) { - int num_bytes; - if (request->Read(buf_, kBufSize, &num_bytes)) { - OnReadCompleted(request, num_bytes); - } else if (!request->status().is_io_pending()) { - // Read failed synchronously. + // Read as many bytes as are available synchronously. + while (true) { + int num_bytes; + if (!request->Read(buf_, kBufSize, &num_bytes)) { + // Check whether the read failed synchronously. + if (!request->status().is_io_pending()) + OnResponseCompleted(request); + return; + } + if (!ConsumeBytesRead(request, num_bytes)) + return; + } +} + +bool ProxyScriptFetcherImpl::ConsumeBytesRead(URLRequest* request, + int num_bytes) { + if (num_bytes <= 0) { + // Error while reading, or EOF. OnResponseCompleted(request); + return false; } + + // Enforce maximum size bound. + if (num_bytes + bytes_read_so_far_.size() > + static_cast<size_t>(max_response_bytes_)) { + result_code_ = ERR_FILE_TOO_BIG; + request->Cancel(); + return false; + } + + bytes_read_so_far_.append(buf_->data(), num_bytes); + return true; } void ProxyScriptFetcherImpl::FetchCompleted() { @@ -246,6 +260,9 @@ void ProxyScriptFetcherImpl::FetchCompleted() { int result_code = result_code_; CompletionCallback* callback = callback_; + // Hold a reference to the URLRequestContext to prevent re-entrancy from + // ~URLRequestContext. + scoped_refptr<URLRequestContext> context(cur_request_->context()); ResetCurRequestState(); callback->Run(result_code); diff --git a/net/proxy/proxy_script_fetcher_impl.h b/net/proxy/proxy_script_fetcher_impl.h index 7d2e1c6..b671f6d 100644 --- a/net/proxy/proxy_script_fetcher_impl.h +++ b/net/proxy/proxy_script_fetcher_impl.h @@ -8,6 +8,7 @@ #include "base/basictypes.h" #include "base/ref_counted.h" +#include "base/scoped_ptr.h" #include "base/string16.h" #include "base/task.h" #include "base/time.h" @@ -60,6 +61,10 @@ class ProxyScriptFetcherImpl : public ProxyScriptFetcher, // Read more bytes from the response. void ReadBody(URLRequest* request); + // Handles a response from Read(). Returns true if we should continue trying + // to read. |num_bytes| is 0 for EOF, and < 0 on errors. + bool ConsumeBytesRead(URLRequest* request, int num_bytes); + // Called once the request has completed to notify the caller of // |response_code_| and |response_text_|. void FetchCompleted(); diff --git a/net/proxy/proxy_script_fetcher_impl_unittest.cc b/net/proxy/proxy_script_fetcher_impl_unittest.cc index 4734997..347bbe9 100644 --- a/net/proxy/proxy_script_fetcher_impl_unittest.cc +++ b/net/proxy/proxy_script_fetcher_impl_unittest.cc @@ -43,7 +43,7 @@ class RequestContext : public URLRequestContext { ssl_config_service_ = new net::SSLConfigServiceDefaults; http_transaction_factory_ = new net::HttpCache( - net::HttpNetworkLayer::CreateFactory(host_resolver_, NULL, NULL, + net::HttpNetworkLayer::CreateFactory(host_resolver_, NULL, NULL, NULL, proxy_service_, ssl_config_service_, NULL, NULL, NULL), net::HttpCache::DefaultBackend::InMemory(0)); } @@ -85,14 +85,13 @@ class ProxyScriptFetcherImplTest : public PlatformTest { TEST_F(ProxyScriptFetcherImplTest, FileUrl) { scoped_refptr<URLRequestContext> context(new RequestContext); - scoped_ptr<ProxyScriptFetcher> pac_fetcher( - new ProxyScriptFetcherImpl(context)); + ProxyScriptFetcherImpl pac_fetcher(context); { // Fetch a non-existent file. string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(GetTestFileUrl("does-not-exist"), - &text, &callback); + int result = pac_fetcher.Fetch(GetTestFileUrl("does-not-exist"), + &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(ERR_FILE_NOT_FOUND, callback.WaitForResult()); EXPECT_TRUE(text.empty()); @@ -100,8 +99,8 @@ TEST_F(ProxyScriptFetcherImplTest, FileUrl) { { // Fetch a file that exists. string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(GetTestFileUrl("pac.txt"), - &text, &callback); + int result = pac_fetcher.Fetch(GetTestFileUrl("pac.txt"), + &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_EQ(ASCIIToUTF16("-pac.txt-\n"), text); @@ -114,14 +113,13 @@ TEST_F(ProxyScriptFetcherImplTest, HttpMimeType) { ASSERT_TRUE(test_server_.Start()); scoped_refptr<URLRequestContext> context(new RequestContext); - scoped_ptr<ProxyScriptFetcher> pac_fetcher( - new ProxyScriptFetcherImpl(context)); + ProxyScriptFetcherImpl pac_fetcher(context); { // Fetch a PAC with mime type "text/plain" GURL url(test_server_.GetURL("files/pac.txt")); string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_EQ(ASCIIToUTF16("-pac.txt-\n"), text); @@ -130,7 +128,7 @@ TEST_F(ProxyScriptFetcherImplTest, HttpMimeType) { GURL url(test_server_.GetURL("files/pac.html")); string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_EQ(ASCIIToUTF16("-pac.html-\n"), text); @@ -139,7 +137,7 @@ TEST_F(ProxyScriptFetcherImplTest, HttpMimeType) { GURL url(test_server_.GetURL("files/pac.nsproxy")); string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_EQ(ASCIIToUTF16("-pac.nsproxy-\n"), text); @@ -150,14 +148,13 @@ TEST_F(ProxyScriptFetcherImplTest, HttpStatusCode) { ASSERT_TRUE(test_server_.Start()); scoped_refptr<URLRequestContext> context(new RequestContext); - scoped_ptr<ProxyScriptFetcher> pac_fetcher( - new ProxyScriptFetcherImpl(context)); + ProxyScriptFetcherImpl pac_fetcher(context); { // Fetch a PAC which gives a 500 -- FAIL GURL url(test_server_.GetURL("files/500.pac")); string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(ERR_PAC_STATUS_NOT_OK, callback.WaitForResult()); EXPECT_TRUE(text.empty()); @@ -166,7 +163,7 @@ TEST_F(ProxyScriptFetcherImplTest, HttpStatusCode) { GURL url(test_server_.GetURL("files/404.pac")); string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(ERR_PAC_STATUS_NOT_OK, callback.WaitForResult()); EXPECT_TRUE(text.empty()); @@ -177,15 +174,14 @@ TEST_F(ProxyScriptFetcherImplTest, ContentDisposition) { ASSERT_TRUE(test_server_.Start()); scoped_refptr<URLRequestContext> context(new RequestContext); - scoped_ptr<ProxyScriptFetcher> pac_fetcher( - new ProxyScriptFetcherImpl(context)); + ProxyScriptFetcherImpl pac_fetcher(context); // Fetch PAC scripts via HTTP with a Content-Disposition header -- should // have no effect. GURL url(test_server_.GetURL("files/downloadable.pac")); string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_EQ(ASCIIToUTF16("-downloadable.pac-\n"), text); @@ -195,15 +191,14 @@ TEST_F(ProxyScriptFetcherImplTest, NoCache) { ASSERT_TRUE(test_server_.Start()); scoped_refptr<URLRequestContext> context(new RequestContext); - scoped_ptr<ProxyScriptFetcher> pac_fetcher( - new ProxyScriptFetcherImpl(context)); + ProxyScriptFetcherImpl pac_fetcher(context); // Fetch a PAC script whose HTTP headers make it cacheable for 1 hour. GURL url(test_server_.GetURL("files/cacheable_1hr.pac")); { string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_EQ(ASCIIToUTF16("-cacheable_1hr.pac-\n"), text); @@ -218,7 +213,7 @@ TEST_F(ProxyScriptFetcherImplTest, NoCache) { { string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(ERR_CONNECTION_REFUSED, callback.WaitForResult()); } @@ -228,11 +223,10 @@ TEST_F(ProxyScriptFetcherImplTest, TooLarge) { ASSERT_TRUE(test_server_.Start()); scoped_refptr<URLRequestContext> context(new RequestContext); - scoped_ptr<ProxyScriptFetcherImpl> pac_fetcher( - new ProxyScriptFetcherImpl(context)); + ProxyScriptFetcherImpl pac_fetcher(context); // Set the maximum response size to 50 bytes. - int prev_size = pac_fetcher->SetSizeConstraint(50); + int prev_size = pac_fetcher.SetSizeConstraint(50); // These two URLs are the same file, but are http:// vs file:// GURL urls[] = { @@ -246,20 +240,20 @@ TEST_F(ProxyScriptFetcherImplTest, TooLarge) { const GURL& url = urls[i]; string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(ERR_FILE_TOO_BIG, callback.WaitForResult()); EXPECT_TRUE(text.empty()); } // Restore the original size bound. - pac_fetcher->SetSizeConstraint(prev_size); + pac_fetcher.SetSizeConstraint(prev_size); { // Make sure we can still fetch regular URLs. GURL url(test_server_.GetURL("files/pac.nsproxy")); string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_EQ(ASCIIToUTF16("-pac.nsproxy-\n"), text); @@ -270,11 +264,10 @@ TEST_F(ProxyScriptFetcherImplTest, Hang) { ASSERT_TRUE(test_server_.Start()); scoped_refptr<URLRequestContext> context(new RequestContext); - scoped_ptr<ProxyScriptFetcherImpl> pac_fetcher( - new ProxyScriptFetcherImpl(context)); + ProxyScriptFetcherImpl pac_fetcher(context); // Set the timeout period to 0.5 seconds. - base::TimeDelta prev_timeout = pac_fetcher->SetTimeoutConstraint( + base::TimeDelta prev_timeout = pac_fetcher.SetTimeoutConstraint( base::TimeDelta::FromMilliseconds(500)); // Try fetching a URL which takes 1.2 seconds. We should abort the request @@ -282,20 +275,20 @@ TEST_F(ProxyScriptFetcherImplTest, Hang) { { GURL url(test_server_.GetURL("slow/proxy.pac?1.2")); string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(ERR_TIMED_OUT, callback.WaitForResult()); EXPECT_TRUE(text.empty()); } // Restore the original timeout period. - pac_fetcher->SetTimeoutConstraint(prev_timeout); + pac_fetcher.SetTimeoutConstraint(prev_timeout); { // Make sure we can still fetch regular URLs. GURL url(test_server_.GetURL("files/pac.nsproxy")); string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_EQ(ASCIIToUTF16("-pac.nsproxy-\n"), text); @@ -309,15 +302,14 @@ TEST_F(ProxyScriptFetcherImplTest, Encodings) { ASSERT_TRUE(test_server_.Start()); scoped_refptr<URLRequestContext> context(new RequestContext); - scoped_ptr<ProxyScriptFetcher> pac_fetcher( - new ProxyScriptFetcherImpl(context)); + ProxyScriptFetcherImpl pac_fetcher(context); // Test a response that is gzip-encoded -- should get inflated. { GURL url(test_server_.GetURL("files/gzipped_pac")); string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_EQ(ASCIIToUTF16("This data was gzipped.\n"), text); @@ -329,7 +321,7 @@ TEST_F(ProxyScriptFetcherImplTest, Encodings) { GURL url(test_server_.GetURL("files/utf16be_pac")); string16 text; TestCompletionCallback callback; - int result = pac_fetcher->Fetch(url, &text, &callback); + int result = pac_fetcher.Fetch(url, &text, &callback); EXPECT_EQ(ERR_IO_PENDING, result); EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_EQ(ASCIIToUTF16("This was encoded as UTF-16BE.\n"), text); diff --git a/net/socket/client_socket_factory.cc b/net/socket/client_socket_factory.cc index 72afd63..8965630 100644 --- a/net/socket/client_socket_factory.cc +++ b/net/socket/client_socket_factory.cc @@ -21,7 +21,7 @@ namespace net { -class DnsRRResolver; +class DnsCertProvenanceChecker; namespace { @@ -30,7 +30,7 @@ SSLClientSocket* DefaultSSLClientSocketFactory( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver) { + DnsCertProvenanceChecker* dns_cert_checker) { scoped_ptr<SSLHostInfo> shi(ssl_host_info); #if defined(OS_WIN) return new SSLClientSocketWin(transport_socket, host_and_port, ssl_config); @@ -39,10 +39,10 @@ SSLClientSocket* DefaultSSLClientSocketFactory( ssl_config); #elif defined(USE_NSS) return new SSLClientSocketNSS(transport_socket, host_and_port, ssl_config, - shi.release(), dnsrr_resolver); + shi.release(), dns_cert_checker); #elif defined(OS_MACOSX) return new SSLClientSocketNSS(transport_socket, host_and_port, ssl_config, - shi.release(), dnsrr_resolver); + shi.release(), dns_cert_checker); #else NOTIMPLEMENTED(); return NULL; @@ -65,9 +65,9 @@ class DefaultClientSocketFactory : public ClientSocketFactory { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver) { + DnsCertProvenanceChecker* dns_cert_checker) { return g_ssl_factory(transport_socket, host_and_port, ssl_config, - ssl_host_info, dnsrr_resolver); + ssl_host_info, dns_cert_checker); } }; @@ -93,7 +93,8 @@ SSLClientSocket* ClientSocketFactory::CreateSSLClientSocket( ClientSocketHandle* socket_handle = new ClientSocketHandle(); socket_handle->set_socket(transport_socket); return CreateSSLClientSocket(socket_handle, host_and_port, ssl_config, - ssl_host_info, NULL /* DnsRRResolver */); + ssl_host_info, + NULL /* DnsCertProvenanceChecker */); } } // namespace net diff --git a/net/socket/client_socket_factory.h b/net/socket/client_socket_factory.h index 196b2ab..0ab370a 100644 --- a/net/socket/client_socket_factory.h +++ b/net/socket/client_socket_factory.h @@ -16,7 +16,7 @@ namespace net { class AddressList; class ClientSocket; class ClientSocketHandle; -class DnsRRResolver; +class DnsCertProvenanceChecker; class HostPortPair; class SSLClientSocket; struct SSLConfig; @@ -28,7 +28,7 @@ typedef SSLClientSocket* (*SSLClientSocketFactory)( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver); + DnsCertProvenanceChecker* dns_cert_checker); // An interface used to instantiate ClientSocket objects. Used to facilitate // testing code with mock socket implementations. @@ -48,7 +48,7 @@ class ClientSocketFactory { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver) = 0; + DnsCertProvenanceChecker* dns_cert_checker) = 0; // Deprecated function (http://crbug.com/37810) that takes a ClientSocket. virtual SSLClientSocket* CreateSSLClientSocket( diff --git a/net/socket/client_socket_pool_base.cc b/net/socket/client_socket_pool_base.cc index 2228729..86ba2dd 100644 --- a/net/socket/client_socket_pool_base.cc +++ b/net/socket/client_socket_pool_base.cc @@ -233,7 +233,6 @@ void ClientSocketPoolBaseHelper::RequestSockets( DCHECK(!request.handle()); if (num_sockets > max_sockets_per_group_) { - NOTREACHED(); num_sockets = max_sockets_per_group_; } @@ -244,17 +243,31 @@ void ClientSocketPoolBaseHelper::RequestSockets( Group* group = GetOrCreateGroup(group_name); + // RequestSocketsInternal() may delete the group. + bool deleted_group = false; + for (int num_iterations_left = num_sockets; group->NumActiveSocketSlots() < num_sockets && num_iterations_left > 0 ; num_iterations_left--) { int rv = RequestSocketInternal(group_name, &request); + // TODO(willchan): Possibly check for ERR_PRECONNECT_MAX_SOCKET_LIMIT so we + // can log it into the NetLog. if (rv < 0 && rv != ERR_IO_PENDING) { // We're encountering a synchronous error. Give up. + if (!ContainsKey(group_map_, group_name)) + deleted_group = true; + break; + } + if (!ContainsKey(group_map_, group_name)) { + // Unexpected. The group should only be getting deleted on synchronous + // error. + NOTREACHED(); + deleted_group = true; break; } } - if (group->IsEmpty()) + if (!deleted_group && group->IsEmpty()) RemoveGroup(group_name); request.net_log().EndEvent( @@ -287,7 +300,9 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( if (ReachedMaxSocketsLimit()) { if (idle_socket_count() > 0) { - CloseOneIdleSocket(); + bool closed = CloseOneIdleSocketExceptInGroup(group); + if (preconnecting && !closed) + return ERR_PRECONNECT_MAX_SOCKET_LIMIT; } else { // We could check if we really have a stalled group here, but it requires // a scan of all groups, so just flip a flag here, and do the check later. @@ -929,10 +944,17 @@ bool ClientSocketPoolBaseHelper::ReachedMaxSocketsLimit() const { } void ClientSocketPoolBaseHelper::CloseOneIdleSocket() { + CloseOneIdleSocketExceptInGroup(NULL); +} + +bool ClientSocketPoolBaseHelper::CloseOneIdleSocketExceptInGroup( + const Group* exception_group) { CHECK_GT(idle_socket_count(), 0); for (GroupMap::iterator i = group_map_.begin(); i != group_map_.end(); ++i) { Group* group = i->second; + if (exception_group == group) + continue; std::list<IdleSocket>* idle_sockets = group->mutable_idle_sockets(); if (!idle_sockets->empty()) { @@ -942,11 +964,14 @@ void ClientSocketPoolBaseHelper::CloseOneIdleSocket() { if (group->IsEmpty()) RemoveGroup(i); - return; + return true; } } - LOG(DFATAL) << "No idle socket found to close!."; + if (!exception_group) + LOG(DFATAL) << "No idle socket found to close!."; + + return false; } void ClientSocketPoolBaseHelper::InvokeUserCallbackLater( diff --git a/net/socket/client_socket_pool_base.h b/net/socket/client_socket_pool_base.h index 2e8c618..8e6eb13 100644 --- a/net/socket/client_socket_pool_base.h +++ b/net/socket/client_socket_pool_base.h @@ -480,6 +480,11 @@ class ClientSocketPoolBaseHelper // I'm not sure if we hit this situation often. void CloseOneIdleSocket(); + // Same as CloseOneIdleSocket() except it won't close an idle socket in + // |group|. If |group| is NULL, it is ignored. Returns true if it closed a + // socket. + bool CloseOneIdleSocketExceptInGroup(const Group* group); + // Checks if there are stalled socket groups that should be notified // for possible wakeup. void CheckForStalledSocketGroups(); diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index 5e7eb7f..5f5a636 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -110,7 +110,7 @@ class MockClientSocketFactory : public ClientSocketFactory { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver) { + DnsCertProvenanceChecker* dns_cert_checker) { NOTIMPLEMENTED(); delete ssl_host_info; return NULL; @@ -2898,6 +2898,16 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsSynchronous) { EXPECT_EQ(kDefaultMaxSocketsPerGroup, pool_->IdleSocketCountInGroup("b")); } +TEST_F(ClientSocketPoolBaseTest, RequestSocketsSynchronousError) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockFailingJob); + + pool_->RequestSockets("a", ¶ms_, kDefaultMaxSocketsPerGroup, + BoundNetLog()); + + ASSERT_FALSE(pool_->HasGroup("a")); +} + TEST_F(ClientSocketPoolBaseTest, RequestSocketsMultipleTimesDoesNothing) { CreatePool(4, 4); connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); @@ -2997,6 +3007,81 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectJobsTakenByNormalRequests) { EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a")); } +// http://crbug.com/64940 regression test. +TEST_F(ClientSocketPoolBaseTest, PreconnectClosesIdleSocketRemovesGroup) { + const int kMaxTotalSockets = 3; + const int kMaxSocketsPerGroup = 2; + CreatePool(kMaxTotalSockets, kMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + // Note that group name ordering matters here. "a" comes before "b", so + // CloseOneIdleSocket() will try to close "a"'s idle socket. + + // Set up one idle socket in "a". + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", + params_, + kDefaultPriority, + &callback1, + pool_.get(), + BoundNetLog())); + + ASSERT_EQ(OK, callback1.WaitForResult()); + handle1.Reset(); + EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a")); + + // Set up two active sockets in "b". + ClientSocketHandle handle2; + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, handle1.Init("b", + params_, + kDefaultPriority, + &callback1, + pool_.get(), + BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, handle2.Init("b", + params_, + kDefaultPriority, + &callback2, + pool_.get(), + BoundNetLog())); + + ASSERT_EQ(OK, callback1.WaitForResult()); + ASSERT_EQ(OK, callback2.WaitForResult()); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("b")); + EXPECT_EQ(2, pool_->NumActiveSocketsInGroup("b")); + + // Now we have 1 idle socket in "a" and 2 active sockets in "b". This means + // we've maxed out on sockets, since we set |kMaxTotalSockets| to 3. + // Requesting 2 preconnected sockets for "a" should fail to allocate any more + // sockets for "a", and "b" should still have 2 active sockets. + + pool_->RequestSockets("a", ¶ms_, 2, BoundNetLog()); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a")); + EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("a")); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("b")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("b")); + EXPECT_EQ(2, pool_->NumActiveSocketsInGroup("b")); + + // Now release the 2 active sockets for "b". This will give us 1 idle socket + // in "a" and 2 idle sockets in "b". Requesting 2 preconnected sockets for + // "a" should result in closing 1 for "b". + handle1.Reset(); + handle2.Reset(); + EXPECT_EQ(2, pool_->IdleSocketCountInGroup("b")); + EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("b")); + + pool_->RequestSockets("a", ¶ms_, 2, BoundNetLog()); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a")); + EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("a")); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("b")); + EXPECT_EQ(1, pool_->IdleSocketCountInGroup("b")); + EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("b")); +} + } // namespace } // namespace net diff --git a/net/socket/client_socket_pool_manager.cc b/net/socket/client_socket_pool_manager.cc index 512360b..6c73c36 100644 --- a/net/socket/client_socket_pool_manager.cc +++ b/net/socket/client_socket_pool_manager.cc @@ -56,6 +56,7 @@ ClientSocketPoolManager::ClientSocketPoolManager( ClientSocketFactory* socket_factory, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, ProxyService* proxy_service, SSLConfigService* ssl_config_service) @@ -63,6 +64,7 @@ ClientSocketPoolManager::ClientSocketPoolManager( socket_factory_(socket_factory), host_resolver_(host_resolver), dnsrr_resolver_(dnsrr_resolver), + dns_cert_checker_(dns_cert_checker), ssl_host_info_factory_(ssl_host_info_factory), proxy_service_(proxy_service), ssl_config_service_(ssl_config_service), @@ -79,6 +81,7 @@ ClientSocketPoolManager::ClientSocketPoolManager( &ssl_pool_histograms_, host_resolver, dnsrr_resolver, + dns_cert_checker, ssl_host_info_factory, socket_factory, tcp_socket_pool_.get(), @@ -228,6 +231,7 @@ HttpProxyClientSocketPool* ClientSocketPoolManager::GetSocketPoolForHTTPProxy( &ssl_for_https_proxy_pool_histograms_, host_resolver_, dnsrr_resolver_, + dns_cert_checker_, ssl_host_info_factory_, socket_factory_, tcp_https_ret.first->second /* https proxy */, @@ -263,6 +267,7 @@ SSLClientSocketPool* ClientSocketPoolManager::GetSocketPoolForSSLWithProxy( &ssl_pool_histograms_, host_resolver_, dnsrr_resolver_, + dns_cert_checker_, ssl_host_info_factory_, socket_factory_, NULL, /* no tcp pool, we always go through a proxy */ diff --git a/net/socket/client_socket_pool_manager.h b/net/socket/client_socket_pool_manager.h index c6d8f6f..823213e 100644 --- a/net/socket/client_socket_pool_manager.h +++ b/net/socket/client_socket_pool_manager.h @@ -25,6 +25,7 @@ namespace net { class ClientSocketFactory; class ClientSocketPoolHistograms; +class DnsCertProvenanceChecker; class DnsRRResolver; class HostPortPair; class HttpProxyClientSocketPool; @@ -61,6 +62,7 @@ class ClientSocketPoolManager : public NonThreadSafe { ClientSocketFactory* socket_factory, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, ProxyService* proxy_service, SSLConfigService* ssl_config_service); @@ -105,6 +107,7 @@ class ClientSocketPoolManager : public NonThreadSafe { ClientSocketFactory* const socket_factory_; HostResolver* const host_resolver_; DnsRRResolver* const dnsrr_resolver_; + DnsCertProvenanceChecker* const dns_cert_checker_; SSLHostInfoFactory* const ssl_host_info_factory_; const scoped_refptr<ProxyService> proxy_service_; const scoped_refptr<SSLConfigService> ssl_config_service_; diff --git a/net/socket/dns_cert_provenance_check.cc b/net/socket/dns_cert_provenance_check.cc deleted file mode 100644 index 61b9a04..0000000 --- a/net/socket/dns_cert_provenance_check.cc +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright (c) 2010 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/dns_cert_provenance_check.h" - -#include <nspr.h> - -#include <hasht.h> -#include <keyhi.h> -#include <pk11pub.h> -#include <sechash.h> - -#include <string> - -#include "base/crypto/encryptor.h" -#include "base/crypto/symmetric_key.h" -#include "base/non_thread_safe.h" -#include "base/pickle.h" -#include "net/base/completion_callback.h" -#include "net/base/dns_util.h" -#include "net/base/dnsrr_resolver.h" -#include "net/base/net_errors.h" -#include "net/base/net_log.h" - -namespace net { - -namespace { - -// A DER encoded SubjectPublicKeyInfo structure containing the server's public -// key. -const uint8 kServerPublicKey[] = { - 0x30, 0x59, 0x30, 0x13, 0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01, - 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07, 0x03, 0x42, 0x00, - 0x04, 0xc7, 0xea, 0x88, 0x60, 0x52, 0xe3, 0xa3, 0x3e, 0x39, 0x92, 0x0f, 0xa4, - 0x3d, 0xba, 0xd8, 0x02, 0x2d, 0x06, 0x4d, 0x64, 0x98, 0x66, 0xb4, 0x82, 0xf0, - 0x23, 0xa6, 0xd8, 0x37, 0x55, 0x7c, 0x01, 0xbf, 0x18, 0xd8, 0x16, 0x9e, 0x66, - 0xdc, 0x49, 0xbf, 0x2e, 0x86, 0xe3, 0x99, 0xbd, 0xb3, 0x75, 0x25, 0x61, 0x04, - 0x6c, 0x2e, 0xfb, 0x32, 0x42, 0x27, 0xe4, 0x23, 0xea, 0xcd, 0x81, 0x62, 0xc1, -}; - -class DNSCertProvenanceChecker : public NonThreadSafe { - public: - DNSCertProvenanceChecker(const std::string hostname, - DnsRRResolver* dnsrr_resolver, - const std::vector<base::StringPiece>& der_certs) - : hostname_(hostname), - dnsrr_resolver_(dnsrr_resolver), - der_certs_(der_certs.size()), - handle_(DnsRRResolver::kInvalidHandle), - ALLOW_THIS_IN_INITIALIZER_LIST(callback_( - this, &DNSCertProvenanceChecker::ResolutionComplete)) { - for (size_t i = 0; i < der_certs.size(); i++) - der_certs_[i] = der_certs[i].as_string(); - } - - void Start() { - DCHECK(CalledOnValidThread()); - - if (der_certs_.empty()) - return; - - uint8 fingerprint[SHA1_LENGTH]; - SECStatus rv = HASH_HashBuf( - HASH_AlgSHA1, fingerprint, (uint8*) der_certs_[0].data(), - der_certs_[0].size()); - DCHECK_EQ(SECSuccess, rv); - char fingerprint_hex[SHA1_LENGTH * 2 + 1]; - for (unsigned i = 0; i < sizeof(fingerprint); i++) { - static const char hextable[] = "0123456789abcdef"; - fingerprint_hex[i*2] = hextable[fingerprint[i] >> 4]; - fingerprint_hex[i*2 + 1] = hextable[fingerprint[i] & 15]; - } - fingerprint_hex[SHA1_LENGTH * 2] = 0; - - static const char kBaseCertName[] = ".certs.links.org"; - domain_.assign(fingerprint_hex); - domain_.append(kBaseCertName); - - handle_ = dnsrr_resolver_->Resolve( - domain_, kDNS_TXT, 0 /* flags */, &callback_, &response_, - 0 /* priority */, BoundNetLog()); - if (handle_ == DnsRRResolver::kInvalidHandle) { - LOG(ERROR) << "Failed to resolve " << domain_ << " for " << hostname_; - delete this; - } - } - - private: - void ResolutionComplete(int status) { - DCHECK(CalledOnValidThread()); - - if (status == ERR_NAME_NOT_RESOLVED || - (status == OK && response_.rrdatas.empty())) { - LOG(ERROR) << "FAILED" - << " hostname:" << hostname_ - << " domain:" << domain_; - BuildRecord(); - } else if (status == OK) { - LOG(ERROR) << "GOOD" - << " hostname:" << hostname_ - << " resp:" << response_.rrdatas[0]; - } else { - LOG(ERROR) << "Unknown error " << status << " for " << domain_; - } - - delete this; - } - - // BuildRecord encrypts the certificate chain to a fixed public key and - // returns the encrypted blob. Since this code is reporting a possible HTTPS - // failure, it would seem silly to use HTTPS to protect the uploaded report. - std::string BuildRecord() { - static const int kVersion = 0; - static const unsigned kKeySizeInBytes = 16; // AES-128 - static const unsigned kIVSizeInBytes = 16; // AES's block size - static const unsigned kPadSize = 4096; // we pad up to 4KB, - // This is a DER encoded, ANSI X9.62 CurveParams object which simply - // specifies P256. - static const uint8 kANSIX962CurveParams[] = { - 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07 - }; - - DCHECK(CalledOnValidThread()); - - Pickle p; - p.WriteString(hostname_); - p.WriteInt(der_certs_.size()); - for (std::vector<std::string>::const_iterator - i = der_certs_.begin(); i != der_certs_.end(); i++) { - p.WriteString(*i); - } - // We pad to eliminate the possibility that someone could see the size of - // an upload and use that information to reduce the anonymity set of the - // certificate chain. - // The "2*sizeof(uint32)" here covers the padding length which we add next - // and Pickle's internal length which it includes at the beginning of the - // data. - unsigned pad_bytes = kPadSize - ((p.size() + 2*sizeof(uint32)) % kPadSize); - p.WriteUInt32(pad_bytes); - char* padding = new char[pad_bytes]; - memset(padding, 0, pad_bytes); - p.WriteData(padding, pad_bytes); - delete[] padding; - - // We generate a random public value and perform a DH key agreement with - // the server's fixed value. - SECKEYPublicKey* pub_key = NULL; - SECKEYPrivateKey* priv_key = NULL; - SECItem ec_der_params; - memset(&ec_der_params, 0, sizeof(ec_der_params)); - ec_der_params.data = const_cast<uint8*>(kANSIX962CurveParams); - ec_der_params.len = sizeof(kANSIX962CurveParams); - priv_key = SECKEY_CreateECPrivateKey(&ec_der_params, &pub_key, NULL); - SECKEYPublicKey* server_pub_key = GetServerPubKey(); - - // This extracts the big-endian, x value of the shared point. - // The values of the arguments match ssl3_SendECDHClientKeyExchange in NSS - // 3.12.8's lib/ssl/ssl3ecc.c - PK11SymKey* pms = PK11_PubDeriveWithKDF( - priv_key, server_pub_key, PR_FALSE /* is sender */, - NULL /* random a */, NULL /* random b */, CKM_ECDH1_DERIVE, - CKM_TLS_MASTER_KEY_DERIVE_DH, CKA_DERIVE, 0 /* key size */, - CKD_NULL /* KDF */, NULL /* shared data */, NULL /* wincx */); - SECKEY_DestroyPublicKey(server_pub_key); - SECStatus rv = PK11_ExtractKeyValue(pms); - DCHECK_EQ(SECSuccess, rv); - SECItem* x_data = PK11_GetKeyData(pms); - - // The key and IV are 128-bits and generated from a SHA256 hash of the x - // value. - char key_data[SHA256_LENGTH]; - HASH_HashBuf(HASH_AlgSHA256, reinterpret_cast<uint8*>(key_data), - x_data->data, x_data->len); - PK11_FreeSymKey(pms); - - DCHECK_GE(sizeof(key_data), kKeySizeInBytes + kIVSizeInBytes); - std::string raw_key(key_data, kKeySizeInBytes); - - scoped_ptr<base::SymmetricKey> symkey( - base::SymmetricKey::Import(base::SymmetricKey::AES, raw_key)); - std::string iv(key_data + kKeySizeInBytes, kIVSizeInBytes); - - base::Encryptor encryptor; - bool r = encryptor.Init(symkey.get(), base::Encryptor::CBC, iv); - CHECK(r); - - std::string plaintext(reinterpret_cast<const char*>(p.data()), p.size()); - std::string ciphertext; - encryptor.Encrypt(plaintext, &ciphertext); - - // We use another Pickle object to serialise the 'outer' wrapping of the - // plaintext. - Pickle outer; - outer.WriteInt(kVersion); - - SECItem* pub_key_serialized = SECKEY_EncodeDERSubjectPublicKeyInfo(pub_key); - outer.WriteString( - std::string(reinterpret_cast<char*>(pub_key_serialized->data), - pub_key_serialized->len)); - SECITEM_FreeItem(pub_key_serialized, PR_TRUE); - - outer.WriteString(ciphertext); - - SECKEY_DestroyPublicKey(pub_key); - SECKEY_DestroyPrivateKey(priv_key); - - return std::string(reinterpret_cast<const char*>(outer.data()), - outer.size()); - } - - SECKEYPublicKey* GetServerPubKey() { - DCHECK(CalledOnValidThread()); - - SECItem der; - memset(&der, 0, sizeof(der)); - der.data = const_cast<uint8*>(kServerPublicKey); - der.len = sizeof(kServerPublicKey); - - CERTSubjectPublicKeyInfo* spki = SECKEY_DecodeDERSubjectPublicKeyInfo(&der); - SECKEYPublicKey* public_key = SECKEY_ExtractPublicKey(spki); - SECKEY_DestroySubjectPublicKeyInfo(spki); - - return public_key; - } - - const std::string hostname_; - std::string domain_; - DnsRRResolver* const dnsrr_resolver_; - std::vector<std::string> der_certs_; - RRResponse response_; - DnsRRResolver::Handle handle_; - CompletionCallbackImpl<DNSCertProvenanceChecker> callback_; -}; - -} // anonymous namespace - -void DoAsyncDNSCertProvenanceVerification( - const std::string& hostname, - DnsRRResolver* dnsrr_resolver, - const std::vector<base::StringPiece>& der_certs) { - DNSCertProvenanceChecker* c(new DNSCertProvenanceChecker( - hostname, dnsrr_resolver, der_certs)); - c->Start(); -} - -} // namespace net diff --git a/net/socket/dns_cert_provenance_check.h b/net/socket/dns_cert_provenance_check.h deleted file mode 100644 index 289cccf..0000000 --- a/net/socket/dns_cert_provenance_check.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2010 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_DNS_CERT_PROVENANCE_CHECK_H -#define NET_SOCKET_DNS_CERT_PROVENANCE_CHECK_H - -#include <string> -#include <vector> - -#include "base/string_piece.h" - -namespace net { - -class DnsRRResolver; - -// DoAsyncDNSCertProvenanceVerification starts an asynchronous check for the -// given certificate chain. It must be run on the network thread. -void DoAsyncDNSCertProvenanceVerification( - const std::string& hostname, - DnsRRResolver* dnsrr_resolver, - const std::vector<base::StringPiece>& der_certs); - -} // namespace net - -#endif // NET_SOCKET_DNS_CERT_PROVENANCE_CHECK_H diff --git a/net/socket/dns_cert_provenance_checker.cc b/net/socket/dns_cert_provenance_checker.cc new file mode 100644 index 0000000..27c4982 --- /dev/null +++ b/net/socket/dns_cert_provenance_checker.cc @@ -0,0 +1,330 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/dns_cert_provenance_checker.h" + +#if !defined(USE_OPENSSL) + +#include <nspr.h> + +#include <hasht.h> +#include <keyhi.h> +#include <pk11pub.h> +#include <sechash.h> + +#include <set> +#include <string> + +#include "base/basictypes.h" +#include "base/crypto/encryptor.h" +#include "base/crypto/symmetric_key.h" +#include "base/non_thread_safe.h" +#include "base/pickle.h" +#include "base/scoped_ptr.h" +#include "base/singleton.h" +#include "net/base/completion_callback.h" +#include "net/base/dns_util.h" +#include "net/base/dnsrr_resolver.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" + +namespace net { + +namespace { + +// A DER encoded SubjectPublicKeyInfo structure containing the server's public +// key. +const uint8 kServerPublicKey[] = { + 0x30, 0x59, 0x30, 0x13, 0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01, + 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07, 0x03, 0x42, 0x00, + 0x04, 0xc7, 0xea, 0x88, 0x60, 0x52, 0xe3, 0xa3, 0x3e, 0x39, 0x92, 0x0f, 0xa4, + 0x3d, 0xba, 0xd8, 0x02, 0x2d, 0x06, 0x4d, 0x64, 0x98, 0x66, 0xb4, 0x82, 0xf0, + 0x23, 0xa6, 0xd8, 0x37, 0x55, 0x7c, 0x01, 0xbf, 0x18, 0xd8, 0x16, 0x9e, 0x66, + 0xdc, 0x49, 0xbf, 0x2e, 0x86, 0xe3, 0x99, 0xbd, 0xb3, 0x75, 0x25, 0x61, 0x04, + 0x6c, 0x2e, 0xfb, 0x32, 0x42, 0x27, 0xe4, 0x23, 0xea, 0xcd, 0x81, 0x62, 0xc1, +}; + +const unsigned kMaxUploadsPerSession = 10; + +// DnsCertLimits is a singleton class which keeps track of which hosts we have +// uploaded reports for in this session. Since some users will be behind MITM +// proxies, they would otherwise upload for every host and we don't wish to +// spam the upload server. +class DnsCertLimits { + public: + DnsCertLimits() { } + + // HaveReachedMaxUploads returns true iff we have uploaded the maximum number + // of DNS certificate reports for this session. + bool HaveReachedMaxUploads() { + return uploaded_hostnames_.size() >= kMaxUploadsPerSession; + } + + // HaveReachedMaxUploads returns true iff we have already uploaded a report + // about the given hostname in this session. + bool HaveUploadedForHostname(const std::string& hostname) { + return uploaded_hostnames_.count(hostname) > 0; + } + + void DidUpload(const std::string& hostname) { + uploaded_hostnames_.insert(hostname); + } + + private: + friend struct DefaultSingletonTraits<DnsCertLimits>; + + std::set<std::string> uploaded_hostnames_; + + DISALLOW_COPY_AND_ASSIGN(DnsCertLimits); +}; + +// DnsCertProvenanceCheck performs the DNS lookup of the certificate. This +// class is self-deleting. +class DnsCertProvenanceCheck : public NonThreadSafe { + public: + DnsCertProvenanceCheck( + const std::string& hostname, + DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker::Delegate* delegate, + const std::vector<base::StringPiece>& der_certs) + : hostname_(hostname), + dnsrr_resolver_(dnsrr_resolver), + delegate_(delegate), + der_certs_(der_certs.size()), + handle_(DnsRRResolver::kInvalidHandle), + ALLOW_THIS_IN_INITIALIZER_LIST(callback_( + this, &DnsCertProvenanceCheck::ResolutionComplete)) { + for (size_t i = 0; i < der_certs.size(); i++) + der_certs_[i] = der_certs[i].as_string(); + } + + void Start() { + DCHECK(CalledOnValidThread()); + + if (der_certs_.empty()) + return; + + DnsCertLimits* const limits = Singleton<DnsCertLimits>::get(); + if (limits->HaveReachedMaxUploads() || + limits->HaveUploadedForHostname(hostname_)) { + return; + } + + uint8 fingerprint[SHA1_LENGTH]; + SECStatus rv = HASH_HashBuf( + HASH_AlgSHA1, fingerprint, (uint8*) der_certs_[0].data(), + der_certs_[0].size()); + DCHECK_EQ(SECSuccess, rv); + char fingerprint_hex[SHA1_LENGTH * 2 + 1]; + for (unsigned i = 0; i < sizeof(fingerprint); i++) { + static const char hextable[] = "0123456789abcdef"; + fingerprint_hex[i*2] = hextable[fingerprint[i] >> 4]; + fingerprint_hex[i*2 + 1] = hextable[fingerprint[i] & 15]; + } + fingerprint_hex[SHA1_LENGTH * 2] = 0; + + static const char kBaseCertName[] = ".certs.links.org"; + domain_.assign(fingerprint_hex); + domain_.append(kBaseCertName); + + handle_ = dnsrr_resolver_->Resolve( + domain_, kDNS_TXT, 0 /* flags */, &callback_, &response_, + 0 /* priority */, BoundNetLog()); + if (handle_ == DnsRRResolver::kInvalidHandle) { + LOG(ERROR) << "Failed to resolve " << domain_ << " for " << hostname_; + delete this; + } + } + + private: + void ResolutionComplete(int status) { + DCHECK(CalledOnValidThread()); + + if (status == ERR_NAME_NOT_RESOLVED || + (status == OK && response_.rrdatas.empty())) { + LOG(ERROR) << "FAILED" + << " hostname:" << hostname_ + << " domain:" << domain_; + Singleton<DnsCertLimits>::get()->DidUpload(hostname_); + delegate_->OnDnsCertLookupFailed(hostname_, der_certs_); + } else if (status == OK) { + LOG(ERROR) << "GOOD" + << " hostname:" << hostname_ + << " resp:" << response_.rrdatas[0]; + } else { + LOG(ERROR) << "Unknown error " << status << " for " << domain_; + } + + delete this; + } + + const std::string hostname_; + std::string domain_; + DnsRRResolver* dnsrr_resolver_; + DnsCertProvenanceChecker::Delegate* const delegate_; + std::vector<std::string> der_certs_; + RRResponse response_; + DnsRRResolver::Handle handle_; + CompletionCallbackImpl<DnsCertProvenanceCheck> callback_; +}; + +SECKEYPublicKey* GetServerPubKey() { + SECItem der; + memset(&der, 0, sizeof(der)); + der.data = const_cast<uint8*>(kServerPublicKey); + der.len = sizeof(kServerPublicKey); + + CERTSubjectPublicKeyInfo* spki = SECKEY_DecodeDERSubjectPublicKeyInfo(&der); + SECKEYPublicKey* public_key = SECKEY_ExtractPublicKey(spki); + SECKEY_DestroySubjectPublicKeyInfo(spki); + + return public_key; +} + +} // namespace + +// static +std::string DnsCertProvenanceChecker::BuildEncryptedReport( + const std::string& hostname, + const std::vector<std::string>& der_certs) { + static const int kVersion = 0; + static const unsigned kKeySizeInBytes = 16; // AES-128 + static const unsigned kIVSizeInBytes = 16; // AES's block size + static const unsigned kPadSize = 4096; // we pad up to 4KB, + // This is a DER encoded, ANSI X9.62 CurveParams object which simply + // specifies P256. + static const uint8 kANSIX962CurveParams[] = { + 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07 + }; + + Pickle p; + p.WriteString(hostname); + p.WriteInt(der_certs.size()); + for (std::vector<std::string>::const_iterator + i = der_certs.begin(); i != der_certs.end(); i++) { + p.WriteString(*i); + } + // We pad to eliminate the possibility that someone could see the size of + // an upload and use that information to reduce the anonymity set of the + // certificate chain. + // The "2*sizeof(uint32)" here covers the padding length which we add next + // and Pickle's internal length which it includes at the beginning of the + // data. + unsigned pad_bytes = kPadSize - ((p.size() + 2*sizeof(uint32)) % kPadSize); + p.WriteUInt32(pad_bytes); + char* padding = new char[pad_bytes]; + memset(padding, 0, pad_bytes); + p.WriteData(padding, pad_bytes); + delete[] padding; + + // We generate a random public value and perform a DH key agreement with + // the server's fixed value. + SECKEYPublicKey* pub_key = NULL; + SECKEYPrivateKey* priv_key = NULL; + SECItem ec_der_params; + memset(&ec_der_params, 0, sizeof(ec_der_params)); + ec_der_params.data = const_cast<uint8*>(kANSIX962CurveParams); + ec_der_params.len = sizeof(kANSIX962CurveParams); + priv_key = SECKEY_CreateECPrivateKey(&ec_der_params, &pub_key, NULL); + SECKEYPublicKey* server_pub_key = GetServerPubKey(); + + // This extracts the big-endian, x value of the shared point. + // The values of the arguments match ssl3_SendECDHClientKeyExchange in NSS + // 3.12.8's lib/ssl/ssl3ecc.c + PK11SymKey* pms = PK11_PubDeriveWithKDF( + priv_key, server_pub_key, PR_FALSE /* is sender */, + NULL /* random a */, NULL /* random b */, CKM_ECDH1_DERIVE, + CKM_TLS_MASTER_KEY_DERIVE_DH, CKA_DERIVE, 0 /* key size */, + CKD_NULL /* KDF */, NULL /* shared data */, NULL /* wincx */); + SECKEY_DestroyPublicKey(server_pub_key); + SECStatus rv = PK11_ExtractKeyValue(pms); + DCHECK_EQ(SECSuccess, rv); + SECItem* x_data = PK11_GetKeyData(pms); + + // The key and IV are 128-bits and generated from a SHA256 hash of the x + // value. + char key_data[SHA256_LENGTH]; + HASH_HashBuf(HASH_AlgSHA256, reinterpret_cast<uint8*>(key_data), + x_data->data, x_data->len); + PK11_FreeSymKey(pms); + + DCHECK_GE(sizeof(key_data), kKeySizeInBytes + kIVSizeInBytes); + std::string raw_key(key_data, kKeySizeInBytes); + + scoped_ptr<base::SymmetricKey> symkey( + base::SymmetricKey::Import(base::SymmetricKey::AES, raw_key)); + std::string iv(key_data + kKeySizeInBytes, kIVSizeInBytes); + + base::Encryptor encryptor; + bool r = encryptor.Init(symkey.get(), base::Encryptor::CBC, iv); + CHECK(r); + + std::string plaintext(reinterpret_cast<const char*>(p.data()), p.size()); + std::string ciphertext; + encryptor.Encrypt(plaintext, &ciphertext); + + // We use another Pickle object to serialise the 'outer' wrapping of the + // plaintext. + Pickle outer; + outer.WriteInt(kVersion); + + SECItem* pub_key_serialized = SECKEY_EncodeDERSubjectPublicKeyInfo(pub_key); + outer.WriteString( + std::string(reinterpret_cast<char*>(pub_key_serialized->data), + pub_key_serialized->len)); + SECITEM_FreeItem(pub_key_serialized, PR_TRUE); + + outer.WriteString(ciphertext); + + SECKEY_DestroyPublicKey(pub_key); + SECKEY_DestroyPrivateKey(priv_key); + + return std::string(reinterpret_cast<const char*>(outer.data()), + outer.size()); +} + +void DnsCertProvenanceChecker::DoAsyncLookup( + const std::string& hostname, + const std::vector<base::StringPiece>& der_certs, + DnsRRResolver* dnsrr_resolver, + Delegate* delegate) { + DnsCertProvenanceCheck* check = new DnsCertProvenanceCheck( + hostname, dnsrr_resolver, delegate, der_certs); + check->Start(); +} + +DnsCertProvenanceChecker::Delegate::~Delegate() { +} + +DnsCertProvenanceChecker::~DnsCertProvenanceChecker() { +} + +} // namespace net + +#else // USE_OPENSSL + +namespace net { + +std::string DnsCertProvenanceChecker::BuildEncryptedReport( + const std::string& hostname, + const std::vector<std::string>& der_certs) { + return ""; +} + +void DnsCertProvenanceChecker::DoAsyncLookup( + const std::string& hostname, + const std::vector<base::StringPiece>& der_certs, + DnsRRResolver* dnsrr_resolver, + Delegate* delegate) { +} + +DnsCertProvenanceChecker::Delegate::~Delegate() { +} + +DnsCertProvenanceChecker::~DnsCertProvenanceChecker() { +} + +} // namespace net + +#endif // USE_OPENSSL diff --git a/net/socket/dns_cert_provenance_checker.h b/net/socket/dns_cert_provenance_checker.h new file mode 100644 index 0000000..810e272 --- /dev/null +++ b/net/socket/dns_cert_provenance_checker.h @@ -0,0 +1,62 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_DNS_CERT_PROVENANCE_CHECKER_H +#define NET_SOCKET_DNS_CERT_PROVENANCE_CHECKER_H + +#include <string> +#include <vector> + +#include "base/string_piece.h" + +namespace net { + +class DnsRRResolver; + +// DnsCertProvenanceChecker is an interface for asynchronously checking HTTPS +// certificates via a DNS side-channel. +class DnsCertProvenanceChecker { + public: + class Delegate { + public: + virtual ~Delegate(); + + virtual void OnDnsCertLookupFailed( + const std::string& hostname, + const std::vector<std::string>& der_certs) = 0; + }; + + virtual void Shutdown() = 0; + + virtual ~DnsCertProvenanceChecker(); + + // DoAsyncVerification starts an asynchronous check for the given certificate + // chain. It must be run on the network thread. + virtual void DoAsyncVerification( + const std::string& hostname, + const std::vector<base::StringPiece>& der_certs) = 0; + + + protected: + // DoAsyncLookup performs a DNS lookup for the given name and certificate + // chain. In the event that the lookup reports a failure, the Delegate is + // called back. + static void DoAsyncLookup( + const std::string& hostname, + const std::vector<base::StringPiece>& der_certs, + DnsRRResolver* dnsrr_resolver, + Delegate* delegate); + + // BuildEncryptedRecord encrypts the certificate chain to a fixed public key + // and returns the encrypted blob. Since this code is reporting a possible + // HTTPS failure, it would seem silly to use HTTPS to protect the uploaded + // report. + static std::string BuildEncryptedReport( + const std::string& hostname, + const std::vector<std::string>& der_certs); +}; + +} // namespace net + +#endif // NET_SOCKET_DNS_CERT_PROVENANCE_CHECK_H diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index 8378c1d..b2e738a 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -1016,7 +1016,7 @@ SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver) { + DnsCertProvenanceChecker* dns_cert_checker) { MockSSLClientSocket* socket = new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, ssl_host_info, mock_ssl_data_.GetNext()); @@ -1066,7 +1066,7 @@ SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver) { + DnsCertProvenanceChecker* dns_cert_checker) { MockSSLClientSocket* socket = new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, ssl_host_info, mock_ssl_data_.GetNext()); diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index ba0b94a..0a01df3 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -172,6 +172,7 @@ class StaticSocketDataProvider : public SocketDataProvider { virtual MockRead GetNextRead(); virtual MockWriteResult OnWrite(const std::string& data); virtual void Reset(); + virtual void CompleteRead() {} // These functions get access to the next available read and write data. const MockRead& PeekRead() const; @@ -284,7 +285,7 @@ class DelayedSocketData : public StaticSocketDataProvider, virtual MockRead GetNextRead(); virtual MockWriteResult OnWrite(const std::string& data); virtual void Reset(); - void CompleteRead(); + virtual void CompleteRead(); void ForceNextRead(); private: @@ -327,6 +328,8 @@ class OrderedSocketData : public StaticSocketDataProvider, virtual MockRead GetNextRead(); virtual MockWriteResult OnWrite(const std::string& data); virtual void Reset(); + virtual void CompleteRead(); + void SetCompletionCallback(CompletionCallback* callback) { callback_ = callback; } @@ -334,8 +337,6 @@ class OrderedSocketData : public StaticSocketDataProvider, // Posts a quit message to the current message loop, if one is running. void EndLoop(); - void CompleteRead(); - private: friend class base::RefCounted<OrderedSocketData>; virtual ~OrderedSocketData(); @@ -425,6 +426,8 @@ class DeterministicSocketData : public StaticSocketDataProvider, virtual void Reset(); + virtual void CompleteRead() {} + // Consume all the data up to the give stop point (via SetStop()). void Run(); @@ -442,7 +445,6 @@ class DeterministicSocketData : public StaticSocketDataProvider, virtual void StopAfter(int seq) { SetStop(sequence_number_ + seq); } - void CompleteRead(); bool stopped() const { return stopped_; } void SetStopped(bool val) { stopped_ = val; } MockRead& current_read() { return current_read_; } @@ -535,7 +537,7 @@ class MockClientSocketFactory : public ClientSocketFactory { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver); + DnsCertProvenanceChecker* dns_cert_checker); SocketDataProviderArray<SocketDataProvider>& mock_data() { return mock_data_; } @@ -880,7 +882,7 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver); + DnsCertProvenanceChecker* dns_cert_checker); SocketDataProviderArray<DeterministicSocketData>& mock_data() { return mock_data_; diff --git a/net/socket/ssl_client_socket_mac_factory.cc b/net/socket/ssl_client_socket_mac_factory.cc index a4ffb78..bf732e6 100644 --- a/net/socket/ssl_client_socket_mac_factory.cc +++ b/net/socket/ssl_client_socket_mac_factory.cc @@ -14,7 +14,7 @@ SSLClientSocket* SSLClientSocketMacFactory( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver) { + DnsCertProvenanceChecker* dns_cert_checker) { delete ssl_host_info; return new SSLClientSocketMac(transport_socket, host_and_port, ssl_config); } diff --git a/net/socket/ssl_client_socket_mac_factory.h b/net/socket/ssl_client_socket_mac_factory.h index c8f48ea..5539136 100644 --- a/net/socket/ssl_client_socket_mac_factory.h +++ b/net/socket/ssl_client_socket_mac_factory.h @@ -10,7 +10,7 @@ namespace net { -class DnsRRResolver; +class DnsCertProvenanceChecker; class SSLHostInfo; // Creates SSLClientSocketMac objects. @@ -19,7 +19,7 @@ SSLClientSocket* SSLClientSocketMacFactory( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver); + DnsCertProvenanceChecker* dns_cert_checker); } // namespace net diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index 3234320..b9c6dff 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -47,12 +47,6 @@ #include "net/socket/ssl_client_socket_nss.h" -#if defined(USE_SYSTEM_SSL) -#include <dlfcn.h> -#endif -#if defined(OS_MACOSX) -#include <Security/Security.h> -#endif #include <certdb.h> #include <hasht.h> #include <keyhi.h> @@ -93,10 +87,22 @@ #include "net/base/sys_addrinfo.h" #include "net/ocsp/nss_ocsp.h" #include "net/socket/client_socket_handle.h" -#include "net/socket/dns_cert_provenance_check.h" +#include "net/socket/dns_cert_provenance_checker.h" #include "net/socket/ssl_error_params.h" #include "net/socket/ssl_host_info.h" +#if defined(USE_SYSTEM_SSL) +#include <dlfcn.h> +#endif +#if defined(OS_WIN) +#include <windows.h> +#include <wincrypt.h> +#elif defined(OS_MACOSX) +#include <Security/SecBase.h> +#include <Security/SecCertificate.h> +#include <Security/SecIdentity.h> +#endif + static const int kRecvBufferSize = 4096; // kCorkTimeoutMs is the number of milliseconds for which we'll wait for a @@ -399,7 +405,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocketHandle* transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver) + DnsCertProvenanceChecker* dns_ctx) : ALLOW_THIS_IN_INITIALIZER_LIST(buffer_send_callback_( this, &SSLClientSocketNSS::BufferSendComplete)), ALLOW_THIS_IN_INITIALIZER_LIST(buffer_recv_callback_( @@ -435,7 +441,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocketHandle* transport_socket, predicted_npn_status_(kNextProtoUnsupported), predicted_npn_proto_used_(false), ssl_host_info_(ssl_host_info), - dnsrr_resolver_(dnsrr_resolver) { + dns_cert_checker_(dns_ctx) { EnterFunction(""); } @@ -2348,6 +2354,13 @@ static DNSValidationResult CheckDNSSECChain( } int SSLClientSocketNSS::DoVerifyDNSSEC(int result) { + if (ssl_config_.dns_cert_provenance_checking_enabled && + dns_cert_checker_) { + PeerCertificateChain certs(nss_fd_); + dns_cert_checker_->DoAsyncVerification( + host_and_port_.host(), certs.AsStringPieceVector()); + } + if (ssl_config_.dnssec_enabled) { DNSValidationResult r = CheckDNSSECChain(host_and_port_.host(), server_cert_nss_); diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index b2725f6..7743097 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -31,7 +31,7 @@ namespace net { class BoundNetLog; class CertVerifier; class ClientSocketHandle; -class DnsRRResolver; +class DnsCertProvenanceChecker; class SSLHostInfo; class X509Certificate; @@ -48,7 +48,7 @@ class SSLClientSocketNSS : public SSLClientSocket { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver); + DnsCertProvenanceChecker* dnsrr_resolver); ~SSLClientSocketNSS(); // SSLClientSocket methods: @@ -250,7 +250,7 @@ class SSLClientSocketNSS : public SSLClientSocket { bool predicted_npn_proto_used_; scoped_ptr<SSLHostInfo> ssl_host_info_; - DnsRRResolver* const dnsrr_resolver_; + DnsCertProvenanceChecker* const dns_cert_checker_; }; } // namespace net diff --git a/net/socket/ssl_client_socket_nss_factory.cc b/net/socket/ssl_client_socket_nss_factory.cc index f7fc435..e4c01f0 100644 --- a/net/socket/ssl_client_socket_nss_factory.cc +++ b/net/socket/ssl_client_socket_nss_factory.cc @@ -19,10 +19,10 @@ SSLClientSocket* SSLClientSocketNSSFactory( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver) { + DnsCertProvenanceChecker* dns_cert_checker) { scoped_ptr<SSLHostInfo> shi(ssl_host_info); return new SSLClientSocketNSS(transport_socket, host_and_port, ssl_config, - shi.release(), dnsrr_resolver); + shi.release(), dns_cert_checker); } } // namespace net diff --git a/net/socket/ssl_client_socket_nss_factory.h b/net/socket/ssl_client_socket_nss_factory.h index c51b5be..15b05b2 100644 --- a/net/socket/ssl_client_socket_nss_factory.h +++ b/net/socket/ssl_client_socket_nss_factory.h @@ -10,7 +10,7 @@ namespace net { -class DnsRRResolver; +class DnsCertProvenanceChecker; class SSLHostInfo; // Creates SSLClientSocketNSS objects. @@ -19,7 +19,7 @@ SSLClientSocket* SSLClientSocketNSSFactory( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver); + DnsCertProvenanceChecker* dns_cert_checker); } // namespace net diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc index 62f3dbb..9aaca41 100644 --- a/net/socket/ssl_client_socket_openssl.cc +++ b/net/socket/ssl_client_socket_openssl.cc @@ -10,8 +10,10 @@ #include <openssl/ssl.h> #include <openssl/err.h> +#include "base/lock.h" #include "base/metrics/histogram.h" #include "base/openssl_util.h" +#include "base/singleton.h" #include "net/base/cert_verifier.h" #include "net/base/net_errors.h" #include "net/base/ssl_connection_status_flags.h" @@ -31,15 +33,8 @@ namespace { #endif const size_t kMaxRecvBufferSize = 4096; - -void MaybeLogSSLError() { - int error_num; - while ((error_num = ERR_get_error()) != 0) { - char buf[128]; // this buffer must be at least 120 chars long. - ERR_error_string_n(error_num, buf, arraysize(buf)); - DVLOG(1) << "SSL error " << error_num << ": " << buf; - } -} +const int kSessionCacheTimeoutSeconds = 60 * 60; +const size_t kSessionCacheMaxEntires = 1024; int MapOpenSSLError(int err) { switch (err) { @@ -48,12 +43,10 @@ int MapOpenSSLError(int err) { return ERR_IO_PENDING; case SSL_ERROR_SYSCALL: DVLOG(1) << "OpenSSL SYSCALL error, errno " << errno; - MaybeLogSSLError(); return ERR_SSL_PROTOCOL_ERROR; default: // TODO(joth): Implement full mapping. LOG(WARNING) << "Unknown OpenSSL error " << err; - MaybeLogSSLError(); return ERR_SSL_PROTOCOL_ERROR; } } @@ -65,21 +58,149 @@ int NoOpVerifyCallback(X509_STORE_CTX*, void *) { return 1; } -struct SSLContextSingletonTraits : public DefaultSingletonTraits<SSL_CTX> { - static SSL_CTX* New() { - base::EnsureOpenSSLInit(); - SSL_CTX* self = SSL_CTX_new(SSLv23_client_method()); - SSL_CTX_set_cert_verify_callback(self, NoOpVerifyCallback, NULL); - return self; +// OpenSSL manages a cache of SSL_SESSION, this class provides the application +// side policy for that cache about session re-use: we retain one session per +// unique HostPortPair. +class SSLSessionCache { + public: + SSLSessionCache() {} + + void OnSessionAdded(const HostPortPair& host_and_port, SSL_SESSION* session) { + // Declare the session cleaner-upper before the lock, so any call into + // OpenSSL to free the session will happen after the lock is released. + base::ScopedOpenSSL<SSL_SESSION, SSL_SESSION_free> session_to_free; + AutoLock lock(lock_); + + DCHECK_EQ(0U, session_map_.count(session)); + std::pair<HostPortMap::iterator, bool> res = + host_port_map_.insert(std::make_pair(host_and_port, session)); + if (!res.second) { // Already exists: replace old entry. + session_to_free.reset(res.first->second); + session_map_.erase(session_to_free.get()); + res.first->second = session; + } + DVLOG(2) << "Adding session " << session << " => " + << host_and_port.ToString() << ", new entry = " << res.second; + DCHECK(host_port_map_[host_and_port] == session); + session_map_[session] = res.first; + DCHECK_EQ(host_port_map_.size(), session_map_.size()); + DCHECK_LE(host_port_map_.size(), kSessionCacheMaxEntires); } - static void Delete(SSL_CTX* self) { - SSL_CTX_free(self); + + void OnSessionRemoved(SSL_SESSION* session) { + // Declare the session cleaner-upper before the lock, so any call into + // OpenSSL to free the session will happen after the lock is released. + base::ScopedOpenSSL<SSL_SESSION, SSL_SESSION_free> session_to_free; + AutoLock lock(lock_); + + SessionMap::iterator it = session_map_.find(session); + if (it == session_map_.end()) + return; + DVLOG(2) << "Remove session " << session << " => " + << it->second->first.ToString(); + DCHECK(it->second->second == session); + host_port_map_.erase(it->second); + session_map_.erase(it); + session_to_free.reset(session); + DCHECK_EQ(host_port_map_.size(), session_map_.size()); } + + // Looks up the host:port in the cache, and if a session is found it is added + // to |ssl|, returning true on success. + bool SetSSLSession(SSL* ssl, const HostPortPair& host_and_port) { + AutoLock lock(lock_); + HostPortMap::iterator it = host_port_map_.find(host_and_port); + if (it == host_port_map_.end()) + return false; + DVLOG(2) << "Lookup session: " << it->second << " => " + << host_and_port.ToString(); + SSL_SESSION* session = it->second; + DCHECK(session); + DCHECK(session_map_[session] == it); + // Ideally we'd release |lock_| before calling into OpenSSL here, however + // that opens a small risk |session| will go out of scope before it is used. + // Alternatively we would take a temporary local refcount on |session|, + // except OpenSSL does not provide a public API for adding a ref (c.f. + // SSL_SESSION_free which decrements the ref). + return SSL_set_session(ssl, session) == 1; + } + + private: + // A pair of maps to allow bi-directional lookups between host:port and an + // associated seesion. + // TODO(joth): When client certificates are implemented we should key the + // cache on the client certificate used in addition to the host-port pair. + typedef std::map<HostPortPair, SSL_SESSION*> HostPortMap; + typedef std::map<SSL_SESSION*, HostPortMap::iterator> SessionMap; + HostPortMap host_port_map_; + SessionMap session_map_; + + // Protects access to both the above maps. + Lock lock_; + + DISALLOW_COPY_AND_ASSIGN(SSLSessionCache); }; -SSL_CTX* GetSSLContext() { - return Singleton<SSL_CTX, SSLContextSingletonTraits>::get(); -} +class SSLContext { + public: + static SSLContext* Get() { return Singleton<SSLContext>::get(); } + SSL_CTX* ssl_ctx() { return ssl_ctx_.get(); } + SSLSessionCache* session_cache() { return &session_cache_; } + + SSLClientSocketOpenSSL* GetClientSocketFromSSL(SSL* ssl) { + DCHECK(ssl); + SSLClientSocketOpenSSL* socket = static_cast<SSLClientSocketOpenSSL*>( + SSL_get_ex_data(ssl, ssl_socket_data_index_)); + DCHECK(socket); + return socket; + } + + bool SetClientSocketForSSL(SSL* ssl, SSLClientSocketOpenSSL* socket) { + return SSL_set_ex_data(ssl, ssl_socket_data_index_, socket) != 0; + } + + private: + friend struct DefaultSingletonTraits<SSLContext>; + + SSLContext() { + base::EnsureOpenSSLInit(); + ssl_socket_data_index_ = SSL_get_ex_new_index(0, 0, 0, 0, 0); + DCHECK_NE(ssl_socket_data_index_, -1); + ssl_ctx_.reset(SSL_CTX_new(SSLv23_client_method())); + SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), NoOpVerifyCallback, NULL); + SSL_CTX_set_session_cache_mode(ssl_ctx_.get(), SSL_SESS_CACHE_CLIENT); + SSL_CTX_sess_set_new_cb(ssl_ctx_.get(), NewSessionCallbackStatic); + SSL_CTX_sess_set_remove_cb(ssl_ctx_.get(), RemoveSessionCallbackStatic); + SSL_CTX_set_timeout(ssl_ctx_.get(), kSessionCacheTimeoutSeconds); + SSL_CTX_sess_set_cache_size(ssl_ctx_.get(), kSessionCacheMaxEntires); + } + + static int NewSessionCallbackStatic(SSL* ssl, SSL_SESSION* session) { + return Get()->NewSessionCallback(ssl, session); + } + + int NewSessionCallback(SSL* ssl, SSL_SESSION* session) { + SSLClientSocketOpenSSL* socket = GetClientSocketFromSSL(ssl); + session_cache_.OnSessionAdded(socket->host_and_port(), session); + return 1; // 1 => We took ownership of |session|. + } + + static void RemoveSessionCallbackStatic(SSL_CTX* ctx, SSL_SESSION* session) { + return Get()->RemoveSessionCallback(ctx, session); + } + + void RemoveSessionCallback(SSL_CTX* ctx, SSL_SESSION* session) { + DCHECK(ctx == ssl_ctx()); + session_cache_.OnSessionRemoved(session); + } + + // This is the index used with SSL_get_ex_data to retrieve the owner + // SSLClientSocketOpenSSL object from an SSL instance. + int ssl_socket_data_index_; + + base::ScopedOpenSSL<SSL_CTX, SSL_CTX_free> ssl_ctx_; + SSLSessionCache session_cache_; +}; } // namespace @@ -96,6 +217,7 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( user_connect_callback_(NULL), user_read_callback_(NULL), user_write_callback_(NULL), + completed_handshake_(false), client_auth_cert_needed_(false), ALLOW_THIS_IN_INITIALIZER_LIST(handshake_io_callback_( this, &SSLClientSocketOpenSSL::OnHandshakeIOComplete)), @@ -104,7 +226,7 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( transport_(transport_socket), host_and_port_(host_and_port), ssl_config_(ssl_config), - completed_handshake_(false), + trying_cached_session_(false), net_log_(transport_socket->socket()->NetLog()) { } @@ -116,23 +238,23 @@ bool SSLClientSocketOpenSSL::Init() { DCHECK(!ssl_); DCHECK(!transport_bio_); - ssl_ = SSL_new(GetSSLContext()); - if (!ssl_) { - MaybeLogSSLError(); + SSLContext* context = SSLContext::Get(); + base::OpenSSLErrStackTracer err_tracer(FROM_HERE); + + ssl_ = SSL_new(context->ssl_ctx()); + if (!ssl_ || !context->SetClientSocketForSSL(ssl_, this)) return false; - } - if (!SSL_set_tlsext_host_name(ssl_, host_and_port_.host().c_str())) { - MaybeLogSSLError(); + if (!SSL_set_tlsext_host_name(ssl_, host_and_port_.host().c_str())) return false; - } + + trying_cached_session_ = + context->session_cache()->SetSSLSession(ssl_, host_and_port_); BIO* ssl_bio = NULL; - // TODO(joth): Provide explicit write buffer sizes, rather than use defaults? - if (!BIO_new_bio_pair(&ssl_bio, 0, &transport_bio_, 0)) { - MaybeLogSSLError(); + // 0 => use default buffer sizes. + if (!BIO_new_bio_pair(&ssl_bio, 0, &transport_bio_, 0)) return false; - } DCHECK(ssl_bio); DCHECK(transport_bio_); @@ -286,11 +408,11 @@ void SSLClientSocketOpenSSL::Disconnect() { user_write_buf_ = NULL; user_write_buf_len_ = 0; - client_certs_.clear(); - client_auth_cert_needed_ = false; - server_cert_verify_result_.Reset(); completed_handshake_ = false; + + client_certs_.clear(); + client_auth_cert_needed_ = false; } int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) { @@ -336,10 +458,16 @@ int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) { } int SSLClientSocketOpenSSL::DoHandshake() { + base::OpenSSLErrStackTracer err_tracer(FROM_HERE); int net_error = net::OK; int rv = SSL_do_handshake(ssl_); if (rv == 1) { + if (trying_cached_session_ && logging::DEBUG_MODE) { + DVLOG(2) << "Result of session reuse for " << host_and_port_.ToString() + << " is: " << (SSL_session_reused(ssl_) ? "Success" : "Fail"); + } + // SSL handshake is completed. Let's verify the certificate. const bool got_cert = !!UpdateServerCert(); DCHECK(got_cert); @@ -355,7 +483,6 @@ int SSLClientSocketOpenSSL::DoHandshake() { LOG(ERROR) << "handshake failed; returned " << rv << ", SSL error code " << ssl_error << ", net_error " << net_error; - MaybeLogSSLError(); } } return net_error; @@ -400,28 +527,11 @@ int SSLClientSocketOpenSSL::DoVerifyCertComplete(int result) { } completed_handshake_ = true; - // The NSS version has a comment that we may not need this call because it is - // now harmless to have a session with a bad cert. - // This may or may not apply here, but let's invalidate it anyway. - InvalidateSessionIfBadCertificate(); // Exit DoHandshakeLoop and return the result to the caller to Connect. DCHECK_EQ(STATE_NONE, next_handshake_state_); return result; } -void SSLClientSocketOpenSSL::InvalidateSessionIfBadCertificate() { - if (UpdateServerCert() != NULL && - ssl_config_.IsAllowedBadCert(server_cert_)) { - // Remove from session cache but don't clear this connection. - // TODO(joth): This should be a no-op until we enable session caching, - // see SSL_CTX_set_session_cache_mode(SSL_SESS_CACHE_CLIENT). - SSL_SESSION* session = SSL_get_session(ssl_); - LOG_IF(ERROR, session) << "Connection has a session?? " << session; - int rv = SSL_CTX_remove_session(GetSSLContext(), session); - LOG_IF(ERROR, rv) << "Session was cached?? " << rv; - } -} - X509Certificate* SSLClientSocketOpenSSL::UpdateServerCert() { if (server_cert_) return server_cert_; @@ -736,6 +846,7 @@ bool SSLClientSocketOpenSSL::SetSendBufferSize(int32 size) { } int SSLClientSocketOpenSSL::DoPayloadRead() { + base::OpenSSLErrStackTracer err_tracer(FROM_HERE); int rv = SSL_read(ssl_, user_read_buf_->data(), user_read_buf_len_); // We don't need to invalidate the non-client-authenticated SSL session // because the server will renegotiate anyway. @@ -750,6 +861,7 @@ int SSLClientSocketOpenSSL::DoPayloadRead() { } int SSLClientSocketOpenSSL::DoPayloadWrite() { + base::OpenSSLErrStackTracer err_tracer(FROM_HERE); int rv = SSL_write(ssl_, user_write_buf_->data(), user_write_buf_len_); if (rv >= 0) diff --git a/net/socket/ssl_client_socket_openssl.h b/net/socket/ssl_client_socket_openssl.h index 783132d..e7bfe3c 100644 --- a/net/socket/ssl_client_socket_openssl.h +++ b/net/socket/ssl_client_socket_openssl.h @@ -36,6 +36,8 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { const SSLConfig& ssl_config); ~SSLClientSocketOpenSSL(); + const HostPortPair& host_and_port() const { return host_and_port_; } + // SSLClientSocket methods: virtual void GetSSLInfo(SSLInfo* ssl_info); virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info); @@ -111,6 +113,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // Set when handshake finishes. scoped_refptr<X509Certificate> server_cert_; CertVerifyResult server_cert_verify_result_; + bool completed_handshake_; // Stores client authentication information between ClientAuthHandler and // GetSSLCertRequestInfo calls. @@ -128,7 +131,8 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { const HostPortPair host_and_port_; SSLConfig ssl_config_; - bool completed_handshake_; + // Used for session cache diagnostics. + bool trying_cached_session_; enum State { STATE_NONE, diff --git a/net/socket/ssl_client_socket_pool.cc b/net/socket/ssl_client_socket_pool.cc index 785faab..5b21005 100644 --- a/net/socket/ssl_client_socket_pool.cc +++ b/net/socket/ssl_client_socket_pool.cc @@ -78,6 +78,7 @@ SSLConnectJob::SSLConnectJob( ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, Delegate* delegate, NetLog* net_log) @@ -90,6 +91,7 @@ SSLConnectJob::SSLConnectJob( client_socket_factory_(client_socket_factory), resolver_(host_resolver), dnsrr_resolver_(dnsrr_resolver), + dns_cert_checker_(dns_cert_checker), ssl_host_info_factory_(ssl_host_info_factory), ALLOW_THIS_IN_INITIALIZER_LIST( callback_(this, &SSLConnectJob::OnIOComplete)) {} @@ -287,7 +289,7 @@ int SSLConnectJob::DoSSLConnect() { ssl_socket_.reset(client_socket_factory_->CreateSSLClientSocket( transport_socket_handle_.release(), params_->host_and_port(), - params_->ssl_config(), ssl_host_info_.release(), dnsrr_resolver_)); + params_->ssl_config(), ssl_host_info_.release(), dns_cert_checker_)); return ssl_socket_->Connect(&callback_); } @@ -358,8 +360,8 @@ ConnectJob* SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( return new SSLConnectJob(group_name, request.params(), ConnectionTimeout(), tcp_pool_, socks_pool_, http_proxy_pool_, client_socket_factory_, host_resolver_, - dnsrr_resolver_, ssl_host_info_factory_, delegate, - net_log_); + dnsrr_resolver_, dns_cert_checker_, + ssl_host_info_factory_, delegate, net_log_); } SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( @@ -369,6 +371,7 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, NetLog* net_log) : tcp_pool_(tcp_pool), @@ -377,6 +380,7 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( client_socket_factory_(client_socket_factory), host_resolver_(host_resolver), dnsrr_resolver_(dnsrr_resolver), + dns_cert_checker_(dns_cert_checker), ssl_host_info_factory_(ssl_host_info_factory), net_log_(net_log) { base::TimeDelta max_transport_timeout = base::TimeDelta(); @@ -403,6 +407,7 @@ SSLClientSocketPool::SSLClientSocketPool( ClientSocketPoolHistograms* histograms, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, ClientSocketFactory* client_socket_factory, TCPClientSocketPool* tcp_pool, @@ -419,7 +424,8 @@ SSLClientSocketPool::SSLClientSocketPool( base::TimeDelta::FromSeconds(kUsedIdleSocketTimeout), new SSLConnectJobFactory(tcp_pool, socks_pool, http_proxy_pool, client_socket_factory, host_resolver, - dnsrr_resolver, ssl_host_info_factory, + dnsrr_resolver, dns_cert_checker, + ssl_host_info_factory, net_log)), ssl_config_service_(ssl_config_service) { if (ssl_config_service_) diff --git a/net/socket/ssl_client_socket_pool.h b/net/socket/ssl_client_socket_pool.h index 11cf250..5eb8594 100644 --- a/net/socket/ssl_client_socket_pool.h +++ b/net/socket/ssl_client_socket_pool.h @@ -24,6 +24,7 @@ namespace net { class ClientSocketFactory; class ConnectJobFactory; +class DnsCertProvenanceChecker; class DnsRRResolver; class HostPortPair; class HttpProxyClientSocketPool; @@ -95,6 +96,7 @@ class SSLConnectJob : public ConnectJob { ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, Delegate* delegate, NetLog* net_log); @@ -144,6 +146,7 @@ class SSLConnectJob : public ConnectJob { ClientSocketFactory* const client_socket_factory_; HostResolver* const resolver_; DnsRRResolver* const dnsrr_resolver_; + DnsCertProvenanceChecker* dns_cert_checker_; SSLHostInfoFactory* const ssl_host_info_factory_; State next_state_; @@ -171,6 +174,7 @@ class SSLClientSocketPool : public ClientSocketPool, ClientSocketPoolHistograms* histograms, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, ClientSocketFactory* client_socket_factory, TCPClientSocketPool* tcp_pool, @@ -244,6 +248,7 @@ class SSLClientSocketPool : public ClientSocketPool, ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, NetLog* net_log); @@ -264,6 +269,7 @@ class SSLClientSocketPool : public ClientSocketPool, ClientSocketFactory* const client_socket_factory_; HostResolver* const host_resolver_; DnsRRResolver* const dnsrr_resolver_; + DnsCertProvenanceChecker* const dns_cert_checker_; SSLHostInfoFactory* const ssl_host_info_factory_; base::TimeDelta timeout_; NetLog* net_log_; diff --git a/net/socket/ssl_client_socket_pool_unittest.cc b/net/socket/ssl_client_socket_pool_unittest.cc index f58a762..247638b 100644 --- a/net/socket/ssl_client_socket_pool_unittest.cc +++ b/net/socket/ssl_client_socket_pool_unittest.cc @@ -40,6 +40,7 @@ class SSLClientSocketPoolTest : public testing::Test { host_resolver_.get())), session_(new HttpNetworkSession(host_resolver_.get(), NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, ProxyService::CreateDirect(), &socket_factory_, @@ -97,6 +98,7 @@ class SSLClientSocketPoolTest : public testing::Test { ssl_histograms_.get(), NULL, NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, &socket_factory_, tcp_pool ? &tcp_socket_pool_ : NULL, diff --git a/net/socket/tcp_client_socket_pool_unittest.cc b/net/socket/tcp_client_socket_pool_unittest.cc index 215b9ba..c44815c 100644 --- a/net/socket/tcp_client_socket_pool_unittest.cc +++ b/net/socket/tcp_client_socket_pool_unittest.cc @@ -251,7 +251,7 @@ class MockClientSocketFactory : public ClientSocketFactory { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - DnsRRResolver* dnsrr_resolver) { + DnsCertProvenanceChecker* dns_cert_checker) { NOTIMPLEMENTED(); delete ssl_host_info; return NULL; diff --git a/net/socket_stream/socket_stream_job.cc b/net/socket_stream/socket_stream_job.cc index 0913015..8d1da73 100644 --- a/net/socket_stream/socket_stream_job.cc +++ b/net/socket_stream/socket_stream_job.cc @@ -4,6 +4,7 @@ #include "net/socket_stream/socket_stream_job.h" +#include "base/singleton.h" #include "net/socket_stream/socket_stream_job_manager.h" namespace net { diff --git a/net/spdy/spdy_framer.cc b/net/spdy/spdy_framer.cc index ea58559..ed21610 100644 --- a/net/spdy/spdy_framer.cc +++ b/net/spdy/spdy_framer.cc @@ -2,6 +2,10 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +// TODO(rtenhove) clean up frame buffer size calculations so that we aren't +// constantly adding and subtracting header sizes; this is ugly and error- +// prone. + #include "net/spdy/spdy_framer.h" #include "base/metrics/stats_counters.h" @@ -17,23 +21,26 @@ namespace spdy { -// The initial size of the control frame buffer; this is used internally -// as we parse through control frames. -static const size_t kControlFrameBufferInitialSize = 32 * 1024; -// The maximum size of the control frame buffer that we support. -// TODO(mbelshe): We should make this stream-based so there are no limits. -static const size_t kControlFrameBufferMaxSize = 64 * 1024; - // By default is compression on or off. bool SpdyFramer::compression_default_ = true; int SpdyFramer::spdy_version_ = kSpdyProtocolVersion; +// The initial size of the control frame buffer; this is used internally +// as we parse through control frames. (It is exposed here for unit test +// purposes.) +size_t SpdyFramer::kControlFrameBufferInitialSize = 8 * 1024; + +// The maximum size of the control frame buffer that we support. +// TODO(mbelshe): We should make this stream-based so there are no limits. +size_t SpdyFramer::kControlFrameBufferMaxSize = 16 * 1024; + #ifdef DEBUG_SPDY_STATE_CHANGES #define CHANGE_STATE(newstate) \ { \ do { \ - VLOG(1) << "Changing state from: " << StateToString(state_) \ - << " to " << StateToString(newstate); \ + LOG(INFO) << "Changing state from: " \ + << StateToString(state_) \ + << " to " << StateToString(newstate) << "\n"; \ state_ = newstate; \ } while (false); \ } @@ -72,7 +79,7 @@ void SpdyFramer::Reset() { current_frame_len_ = 0; if (current_frame_capacity_ != kControlFrameBufferInitialSize) { delete [] current_frame_buffer_; - current_frame_buffer_ = NULL; + current_frame_buffer_ = 0; current_frame_capacity_ = 0; ExpandControlFrameBuffer(kControlFrameBufferInitialSize); } @@ -289,6 +296,11 @@ void SpdyFramer::ProcessControlFrameHeader() { SpdyRstStreamControlFrame::size() - SpdyFrame::size()) set_error(SPDY_INVALID_CONTROL_FRAME); break; + case SETTINGS: + if (current_control_frame.length() < + SpdySettingsControlFrame::size() - SpdyControlFrame::size()) + set_error(SPDY_INVALID_CONTROL_FRAME); + break; case NOOP: // NOOP. Swallow it. CHANGE_STATE(SPDY_AUTO_RESET); @@ -298,14 +310,14 @@ void SpdyFramer::ProcessControlFrameHeader() { SpdyGoAwayControlFrame::size() - SpdyFrame::size()) set_error(SPDY_INVALID_CONTROL_FRAME); break; - case SETTINGS: + case HEADERS: if (current_control_frame.length() < - SpdySettingsControlFrame::size() - SpdyControlFrame::size()) + SpdyHeadersControlFrame::size() - SpdyControlFrame::size()) set_error(SPDY_INVALID_CONTROL_FRAME); break; case WINDOW_UPDATE: if (current_control_frame.length() != - SpdyWindowUpdateControlFrame::size() - SpdyFrame::size()) + SpdyWindowUpdateControlFrame::size() - SpdyControlFrame::size()) set_error(SPDY_INVALID_CONTROL_FRAME); break; default: @@ -435,7 +447,7 @@ bool SpdyFramer::ParseHeaderBlock(const SpdyFrame* frame, SpdyHeaderBlock* block) { SpdyControlFrame control_frame(frame->data(), false); uint32 type = control_frame.type(); - if (type != SYN_STREAM && type != SYN_REPLY) + if (type != SYN_STREAM && type != SYN_REPLY && type != HEADERS) return false; // Find the header data within the control frame. @@ -461,14 +473,21 @@ bool SpdyFramer::ParseHeaderBlock(const SpdyFrame* frame, header_length = syn_frame.header_block_len(); } break; + case HEADERS: + { + SpdyHeadersControlFrame header_frame(decompressed_frame->data(), false); + header_data = header_frame.header_block(); + header_length = header_frame.header_block_len(); + } + break; } SpdyFrameBuilder builder(header_data, header_length); void* iter = NULL; uint16 num_headers; if (builder.ReadUInt16(&iter, &num_headers)) { - int index = 0; - for ( ; index < num_headers; ++index) { + int index; + for (index = 0; index < num_headers; ++index) { std::string name; std::string value; if (!builder.ReadString(&iter, &name)) @@ -551,55 +570,60 @@ SpdySynStreamControlFrame* SpdyFramer::CreateSynStream( return reinterpret_cast<SpdySynStreamControlFrame*>(syn_frame.release()); } -/* static */ -SpdyRstStreamControlFrame* SpdyFramer::CreateRstStream(SpdyStreamId stream_id, - SpdyStatusCodes status) { +SpdySynReplyControlFrame* SpdyFramer::CreateSynReply(SpdyStreamId stream_id, + SpdyControlFlags flags, bool compressed, SpdyHeaderBlock* headers) { DCHECK_GT(stream_id, 0u); DCHECK_EQ(0u, stream_id & ~kStreamIdMask); - DCHECK_NE(status, INVALID); - DCHECK_LT(status, NUM_STATUS_CODES); SpdyFrameBuilder frame; + frame.WriteUInt16(kControlFlagMask | spdy_version_); - frame.WriteUInt16(RST_STREAM); - frame.WriteUInt32(8); + frame.WriteUInt16(SYN_REPLY); + frame.WriteUInt32(0); // Placeholder for the length and flags. frame.WriteUInt32(stream_id); - frame.WriteUInt32(status); - return reinterpret_cast<SpdyRstStreamControlFrame*>(frame.take()); -} + frame.WriteUInt16(0); // Unused -/* static */ -SpdyGoAwayControlFrame* SpdyFramer::CreateGoAway( - SpdyStreamId last_accepted_stream_id) { - DCHECK_EQ(0u, last_accepted_stream_id & ~kStreamIdMask); + frame.WriteUInt16(headers->size()); // Number of headers. + SpdyHeaderBlock::iterator it; + for (it = headers->begin(); it != headers->end(); ++it) { + bool wrote_header; + wrote_header = frame.WriteString(it->first); + wrote_header &= frame.WriteString(it->second); + DCHECK(wrote_header); + } - SpdyFrameBuilder frame; - frame.WriteUInt16(kControlFlagMask | spdy_version_); - frame.WriteUInt16(GOAWAY); - size_t go_away_size = SpdyGoAwayControlFrame::size() - SpdyFrame::size(); - frame.WriteUInt32(go_away_size); - frame.WriteUInt32(last_accepted_stream_id); - return reinterpret_cast<SpdyGoAwayControlFrame*>(frame.take()); + // Write the length and flags. + size_t length = frame.length() - SpdyFrame::size(); + DCHECK_EQ(0u, length & ~static_cast<size_t>(kLengthMask)); + FlagsAndLength flags_length; + flags_length.length_ = htonl(static_cast<uint32>(length)); + DCHECK_EQ(0, flags & ~kControlFlagsMask); + flags_length.flags_[0] = flags; + frame.WriteBytesToOffset(4, &flags_length, sizeof(flags_length)); + + scoped_ptr<SpdyFrame> reply_frame(frame.take()); + if (compressed) { + return reinterpret_cast<SpdySynReplyControlFrame*>( + CompressFrame(*reply_frame.get())); + } + return reinterpret_cast<SpdySynReplyControlFrame*>(reply_frame.release()); } /* static */ -SpdyWindowUpdateControlFrame* SpdyFramer::CreateWindowUpdate( - SpdyStreamId stream_id, - uint32 delta_window_size) { +SpdyRstStreamControlFrame* SpdyFramer::CreateRstStream(SpdyStreamId stream_id, + SpdyStatusCodes status) { DCHECK_GT(stream_id, 0u); DCHECK_EQ(0u, stream_id & ~kStreamIdMask); - DCHECK_GT(delta_window_size, 0u); - DCHECK_LT(delta_window_size, 0x80000000u); // 2^31 + DCHECK_NE(status, INVALID); + DCHECK_LT(status, NUM_STATUS_CODES); SpdyFrameBuilder frame; frame.WriteUInt16(kControlFlagMask | spdy_version_); - frame.WriteUInt16(WINDOW_UPDATE); - size_t window_update_size = SpdyWindowUpdateControlFrame::size() - - SpdyFrame::size(); - frame.WriteUInt32(window_update_size); + frame.WriteUInt16(RST_STREAM); + frame.WriteUInt32(8); frame.WriteUInt32(stream_id); - frame.WriteUInt32(delta_window_size); - return reinterpret_cast<SpdyWindowUpdateControlFrame*>(frame.take()); + frame.WriteUInt32(status); + return reinterpret_cast<SpdyRstStreamControlFrame*>(frame.take()); } /* static */ @@ -621,15 +645,38 @@ SpdySettingsControlFrame* SpdyFramer::CreateSettings( return reinterpret_cast<SpdySettingsControlFrame*>(frame.take()); } -SpdySynReplyControlFrame* SpdyFramer::CreateSynReply(SpdyStreamId stream_id, +/* static */ +SpdyControlFrame* SpdyFramer::CreateNopFrame() { + SpdyFrameBuilder frame; + frame.WriteUInt16(kControlFlagMask | spdy_version_); + frame.WriteUInt16(NOOP); + frame.WriteUInt32(0); + return reinterpret_cast<SpdyControlFrame*>(frame.take()); +} + +/* static */ +SpdyGoAwayControlFrame* SpdyFramer::CreateGoAway( + SpdyStreamId last_accepted_stream_id) { + DCHECK_EQ(0u, last_accepted_stream_id & ~kStreamIdMask); + + SpdyFrameBuilder frame; + frame.WriteUInt16(kControlFlagMask | spdy_version_); + frame.WriteUInt16(GOAWAY); + size_t go_away_size = SpdyGoAwayControlFrame::size() - SpdyFrame::size(); + frame.WriteUInt32(go_away_size); + frame.WriteUInt32(last_accepted_stream_id); + return reinterpret_cast<SpdyGoAwayControlFrame*>(frame.take()); +} + +SpdyHeadersControlFrame* SpdyFramer::CreateHeaders(SpdyStreamId stream_id, SpdyControlFlags flags, bool compressed, SpdyHeaderBlock* headers) { + // Basically the same as CreateSynReply(). DCHECK_GT(stream_id, 0u); DCHECK_EQ(0u, stream_id & ~kStreamIdMask); SpdyFrameBuilder frame; - - frame.WriteUInt16(kControlFlagMask | spdy_version_); - frame.WriteUInt16(SYN_REPLY); + frame.WriteUInt16(kControlFlagMask | kSpdyProtocolVersion); + frame.WriteUInt16(HEADERS); frame.WriteUInt32(0); // Placeholder for the length and flags. frame.WriteUInt32(stream_id); frame.WriteUInt16(0); // Unused @@ -652,12 +699,32 @@ SpdySynReplyControlFrame* SpdyFramer::CreateSynReply(SpdyStreamId stream_id, flags_length.flags_[0] = flags; frame.WriteBytesToOffset(4, &flags_length, sizeof(flags_length)); - scoped_ptr<SpdyFrame> reply_frame(frame.take()); + scoped_ptr<SpdyFrame> headers_frame(frame.take()); if (compressed) { - return reinterpret_cast<SpdySynReplyControlFrame*>( - CompressFrame(*reply_frame.get())); + return reinterpret_cast<SpdyHeadersControlFrame*>( + CompressFrame(*headers_frame.get())); } - return reinterpret_cast<SpdySynReplyControlFrame*>(reply_frame.release()); + return reinterpret_cast<SpdyHeadersControlFrame*>(headers_frame.release()); +} + +/* static */ +SpdyWindowUpdateControlFrame* SpdyFramer::CreateWindowUpdate( + SpdyStreamId stream_id, + uint32 delta_window_size) { + DCHECK_GT(stream_id, 0u); + DCHECK_EQ(0u, stream_id & ~kStreamIdMask); + DCHECK_GT(delta_window_size, 0u); + DCHECK_LE(delta_window_size, spdy::kSpdyStreamMaximumWindowSize); + + SpdyFrameBuilder frame; + frame.WriteUInt16(kControlFlagMask | spdy_version_); + frame.WriteUInt16(WINDOW_UPDATE); + size_t window_update_size = SpdyWindowUpdateControlFrame::size() - + SpdyFrame::size(); + frame.WriteUInt32(window_update_size); + frame.WriteUInt32(stream_id); + frame.WriteUInt32(delta_window_size); + return reinterpret_cast<SpdyWindowUpdateControlFrame*>(frame.take()); } SpdyDataFrame* SpdyFramer::CreateDataFrame(SpdyStreamId stream_id, @@ -692,15 +759,6 @@ SpdyDataFrame* SpdyFramer::CreateDataFrame(SpdyStreamId stream_id, return rv; } -/* static */ -SpdyControlFrame* SpdyFramer::CreateNopFrame() { - SpdyFrameBuilder frame; - frame.WriteUInt16(kControlFlagMask | spdy_version_); - frame.WriteUInt16(NOOP); - frame.WriteUInt32(0); - return reinterpret_cast<SpdyControlFrame*>(frame.take()); -} - // The following compression setting are based on Brian Olson's analysis. See // https://groups.google.com/group/spdy-dev/browse_thread/thread/dfaf498542fac792 // for more details. @@ -844,6 +902,16 @@ bool SpdyFramer::GetFrameBoundaries(const SpdyFrame& frame, *payload = frame.data() + *header_length; } break; + case HEADERS: + { + const SpdyHeadersControlFrame& headers_frame = + reinterpret_cast<const SpdyHeadersControlFrame&>(frame); + frame_size = SpdyHeadersControlFrame::size(); + *payload_length = headers_frame.header_block_len(); + *header_length = frame_size; + *payload = frame.data() + *header_length; + } + break; default: // TODO(mbelshe): set an error? return false; // We can't compress this frame! @@ -985,8 +1053,10 @@ SpdyFrame* SpdyFramer::DecompressFrameWithZStream(const SpdyFrame& frame, // Create an output frame. Assume it does not need to be longer than // the input data. - int decompressed_max_size = kControlFrameBufferInitialSize; + size_t decompressed_max_size = kControlFrameBufferInitialSize; int new_frame_size = header_length + decompressed_max_size; + if (frame.length() > decompressed_max_size) + return NULL; scoped_ptr<SpdyFrame> new_frame(new SpdyFrame(new_frame_size)); memcpy(new_frame->data(), frame.data(), frame.length() + SpdyFrame::size()); diff --git a/net/spdy/spdy_framer.h b/net/spdy/spdy_framer.h index 9b290cd..85805f3 100644 --- a/net/spdy/spdy_framer.h +++ b/net/spdy/spdy_framer.h @@ -157,9 +157,27 @@ class SpdyFramer { bool compressed, SpdyHeaderBlock* headers); + // Create a SpdySynReplyControlFrame. + // |stream_id| is the stream for this frame. + // |flags| is the flags to use with the data. + // To mark this frame as the last frame, enable CONTROL_FLAG_FIN. + // |compressed| specifies whether the frame should be compressed. + // |headers| is the header block to include in the frame. + SpdySynReplyControlFrame* CreateSynReply(SpdyStreamId stream_id, + SpdyControlFlags flags, + bool compressed, + SpdyHeaderBlock* headers); + static SpdyRstStreamControlFrame* CreateRstStream(SpdyStreamId stream_id, SpdyStatusCodes status); + // Creates an instance of SpdySettingsControlFrame. The SETTINGS frame is + // used to communicate name/value pairs relevant to the communication channel. + // TODO(mbelshe): add the name/value pairs!! + static SpdySettingsControlFrame* CreateSettings(const SpdySettings& values); + + static SpdyControlFrame* CreateNopFrame(); + // Creates an instance of SpdyGoAwayControlFrame. The GOAWAY frame is used // prior to the shutting down of the TCP connection, and includes the // stream_id of the last stream the sender of the frame is willing to process @@ -167,32 +185,25 @@ class SpdyFramer { static SpdyGoAwayControlFrame* CreateGoAway( SpdyStreamId last_accepted_stream_id); + // Creates an instance of SpdyHeadersControlFrame. The HEADERS frame is used + // for sending additional headers outside of a SYN_STREAM/SYN_REPLY. The + // arguments are the same as for CreateSynReply. + SpdyHeadersControlFrame* CreateHeaders(SpdyStreamId stream_id, + SpdyControlFlags flags, + bool compressed, + SpdyHeaderBlock* headers); + // Creates an instance of SpdyWindowUpdateControlFrame. The WINDOW_UPDATE // frame is used to implement per stream flow control in SPDY. static SpdyWindowUpdateControlFrame* CreateWindowUpdate( - SpdyStreamId stream_id, uint32 delta_window_size); - - // Creates an instance of SpdySettingsControlFrame. The SETTINGS frame is - // used to communicate name/value pairs relevant to the communication channel. - // TODO(mbelshe): add the name/value pairs!! - static SpdySettingsControlFrame* CreateSettings(const SpdySettings& values); + SpdyStreamId stream_id, + uint32 delta_window_size); // Given a SpdySettingsControlFrame, extract the settings. // Returns true on successful parse, false otherwise. static bool ParseSettings(const SpdySettingsControlFrame* frame, SpdySettings* settings); - // Create a SpdySynReplyControlFrame. - // |stream_id| is the stream for this frame. - // |flags| is the flags to use with the data. - // To mark this frame as the last frame, enable CONTROL_FLAG_FIN. - // |compressed| specifies whether the frame should be compressed. - // |headers| is the header block to include in the frame. - SpdySynReplyControlFrame* CreateSynReply(SpdyStreamId stream_id, - SpdyControlFlags flags, - bool compressed, - SpdyHeaderBlock* headers); - // Create a data frame. // |stream_id| is the stream for this frame // |data| is the data to be included in the frame. @@ -203,8 +214,6 @@ class SpdyFramer { SpdyDataFrame* CreateDataFrame(SpdyStreamId stream_id, const char* data, uint32 len, SpdyDataFlags flags); - static SpdyControlFrame* CreateNopFrame(); - // NOTES about frame compression. // We want spdy to compress headers across the entire session. As long as // the session is over TCP, frames are sent serially. The client & server @@ -249,7 +258,11 @@ class SpdyFramer { protected: FRIEND_TEST_ALL_PREFIXES(SpdyFramerTest, DataCompression); + FRIEND_TEST_ALL_PREFIXES(SpdyFramerTest, ExpandBuffer_HeapSmash); + FRIEND_TEST_ALL_PREFIXES(SpdyFramerTest, HugeHeaderBlock); FRIEND_TEST_ALL_PREFIXES(SpdyFramerTest, UnclosedStreamDataCompressors); + FRIEND_TEST_ALL_PREFIXES(SpdyFramerTest, + UncompressLargerThanFrameBufferInitialSize); friend class net::HttpNetworkLayer; // This is temporary for the server. friend class net::HttpNetworkTransactionTest; friend class net::HttpProxyClientSocketPoolTest; @@ -266,6 +279,16 @@ class SpdyFramer { void set_enable_compression(bool value); static void set_enable_compression_default(bool value); + + // The initial size of the control frame buffer; this is used internally + // as we parse through control frames. (It is exposed here for unit test + // purposes.) + static size_t kControlFrameBufferInitialSize; + + // The maximum size of the control frame buffer that we support. + // TODO(mbelshe): We should make this stream-based so there are no limits. + static size_t kControlFrameBufferMaxSize; + private: typedef std::map<SpdyStreamId, z_stream*> CompressorMap; diff --git a/net/spdy/spdy_framer_test.cc b/net/spdy/spdy_framer_test.cc index d238062..b4345c1 100644 --- a/net/spdy/spdy_framer_test.cc +++ b/net/spdy/spdy_framer_test.cc @@ -86,6 +86,7 @@ class TestSpdyVisitor : public SpdyFramerVisitorInterface { : error_count_(0), syn_frame_count_(0), syn_reply_frame_count_(0), + headers_frame_count_(0), data_bytes_(0), fin_frame_count_(0), fin_flag_count_(0), @@ -129,6 +130,11 @@ class TestSpdyVisitor : public SpdyFramerVisitorInterface { case RST_STREAM: fin_frame_count_++; break; + case HEADERS: + parsed_headers = framer_.ParseHeaderBlock(frame, &headers); + DCHECK(parsed_headers); + headers_frame_count_++; + break; default: DCHECK(false); // Error! } @@ -163,6 +169,7 @@ class TestSpdyVisitor : public SpdyFramerVisitorInterface { int error_count_; int syn_frame_count_; int syn_reply_frame_count_; + int headers_frame_count_; int data_bytes_; int fin_frame_count_; // The count of RST_STREAM type frames received. int fin_flag_count_; // The count of frames with the FIN flag set. @@ -264,7 +271,6 @@ TEST_F(SpdyFramerTest, WrongNumberOfHeaders) { frame1.WriteUInt16(SYN_STREAM); frame1.WriteUInt32(0); // Placeholder for the length. frame1.WriteUInt32(3); // stream_id - frame1.WriteUInt32(0); // associated stream id frame1.WriteUInt16(0); // Priority. frame1.WriteUInt16(1); // Wrong number of headers (underflow) @@ -280,7 +286,6 @@ TEST_F(SpdyFramerTest, WrongNumberOfHeaders) { frame2.WriteUInt16(SYN_STREAM); frame2.WriteUInt32(0); // Placeholder for the length. frame2.WriteUInt32(3); // stream_id - frame2.WriteUInt32(0); // associated stream id frame2.WriteUInt16(0); // Priority. frame2.WriteUInt16(100); // Wrong number of headers (overflow) @@ -446,6 +451,15 @@ TEST_F(SpdyFramerTest, Basic) { 0x00, 0x02, 'h', 'h', 0x00, 0x02, 'v', 'v', + 0x80, 0x02, 0x00, 0x08, // HEADERS on Stream #1 + 0x00, 0x00, 0x00, 0x18, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x02, 'h', '2', + 0x00, 0x02, 'v', '2', + 0x00, 0x02, 'h', '3', + 0x00, 0x02, 'v', '3', + 0x00, 0x00, 0x00, 0x01, // DATA on Stream #1 0x00, 0x00, 0x00, 0x0c, 0xde, 0xad, 0xbe, 0xef, @@ -487,6 +501,7 @@ TEST_F(SpdyFramerTest, Basic) { EXPECT_EQ(0, visitor.error_count_); EXPECT_EQ(2, visitor.syn_frame_count_); EXPECT_EQ(0, visitor.syn_reply_frame_count_); + EXPECT_EQ(1, visitor.headers_frame_count_); EXPECT_EQ(24, visitor.data_bytes_); EXPECT_EQ(2, visitor.fin_frame_count_); EXPECT_EQ(0, visitor.fin_flag_count_); @@ -528,6 +543,7 @@ TEST_F(SpdyFramerTest, FinOnDataFrame) { EXPECT_EQ(0, visitor.error_count_); EXPECT_EQ(1, visitor.syn_frame_count_); EXPECT_EQ(1, visitor.syn_reply_frame_count_); + EXPECT_EQ(0, visitor.headers_frame_count_); EXPECT_EQ(16, visitor.data_bytes_); EXPECT_EQ(0, visitor.fin_frame_count_); EXPECT_EQ(0, visitor.fin_flag_count_); @@ -559,6 +575,7 @@ TEST_F(SpdyFramerTest, FinOnSynReplyFrame) { EXPECT_EQ(0, visitor.error_count_); EXPECT_EQ(1, visitor.syn_frame_count_); EXPECT_EQ(1, visitor.syn_reply_frame_count_); + EXPECT_EQ(0, visitor.headers_frame_count_); EXPECT_EQ(0, visitor.data_bytes_); EXPECT_EQ(0, visitor.fin_frame_count_); EXPECT_EQ(1, visitor.fin_flag_count_); @@ -683,7 +700,9 @@ TEST_F(SpdyFramerTest, UnclosedStreamDataCompressors) { const char bytes[] = "this is a test test test test test!"; scoped_ptr<SpdyFrame> send_frame( - send_framer.CreateDataFrame(1, bytes, arraysize(bytes), + send_framer.CreateDataFrame(1, + bytes, + arraysize(bytes), DATA_FLAG_FIN)); EXPECT_TRUE(send_frame.get() != NULL); @@ -698,6 +717,7 @@ TEST_F(SpdyFramerTest, UnclosedStreamDataCompressors) { EXPECT_EQ(0, visitor.error_count_); EXPECT_EQ(1, visitor.syn_frame_count_); EXPECT_EQ(0, visitor.syn_reply_frame_count_); + EXPECT_EQ(0, visitor.headers_frame_count_); EXPECT_EQ(arraysize(bytes), static_cast<unsigned>(visitor.data_bytes_)); EXPECT_EQ(0, visitor.fin_frame_count_); EXPECT_EQ(0, visitor.fin_flag_count_); @@ -1140,6 +1160,112 @@ TEST_F(SpdyFramerTest, CreateGoAway) { } } +TEST_F(SpdyFramerTest, CreateHeadersUncompressed) { + SpdyFramer framer; + FramerSetEnableCompressionHelper(&framer, false); + + { + const char kDescription[] = "HEADERS frame, no FIN"; + + SpdyHeaderBlock headers; + headers["bar"] = "foo"; + headers["foo"] = "bar"; + + const unsigned char kFrameData[] = { + 0x80, 0x02, 0x00, 0x08, + 0x00, 0x00, 0x00, 0x1C, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x03, 'b', 'a', + 'r', 0x00, 0x03, 'f', + 'o', 'o', 0x00, 0x03, + 'f', 'o', 'o', 0x00, + 0x03, 'b', 'a', 'r' + }; + scoped_ptr<SpdyFrame> frame(framer.CreateHeaders( + 1, CONTROL_FLAG_NONE, false, &headers)); + CompareFrame(kDescription, *frame, kFrameData, arraysize(kFrameData)); + } + + { + const char kDescription[] = + "HEADERS frame with a 0-length header name, FIN, max stream ID"; + + SpdyHeaderBlock headers; + headers[""] = "foo"; + headers["foo"] = "bar"; + + const unsigned char kFrameData[] = { + 0x80, 0x02, 0x00, 0x08, + 0x01, 0x00, 0x00, 0x19, + 0x7f, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x03, + 'f', 'o', 'o', 0x00, + 0x03, 'f', 'o', 'o', + 0x00, 0x03, 'b', 'a', + 'r' + }; + scoped_ptr<SpdyFrame> frame(framer.CreateHeaders( + 0x7fffffff, CONTROL_FLAG_FIN, false, &headers)); + CompareFrame(kDescription, *frame, kFrameData, arraysize(kFrameData)); + } + + { + const char kDescription[] = + "HEADERS frame with a 0-length header val, FIN, max stream ID"; + + SpdyHeaderBlock headers; + headers["bar"] = "foo"; + headers["foo"] = ""; + + const unsigned char kFrameData[] = { + 0x80, 0x02, 0x00, 0x08, + 0x01, 0x00, 0x00, 0x19, + 0x7f, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x03, 'b', 'a', + 'r', 0x00, 0x03, 'f', + 'o', 'o', 0x00, 0x03, + 'f', 'o', 'o', 0x00, + 0x00 + }; + scoped_ptr<SpdyFrame> frame(framer.CreateHeaders( + 0x7fffffff, CONTROL_FLAG_FIN, false, &headers)); + CompareFrame(kDescription, *frame, kFrameData, arraysize(kFrameData)); + } +} + +TEST_F(SpdyFramerTest, CreateHeadersCompressed) { + SpdyFramer framer; + FramerSetEnableCompressionHelper(&framer, true); + + { + const char kDescription[] = "HEADERS frame, no FIN"; + + SpdyHeaderBlock headers; + headers["bar"] = "foo"; + headers["foo"] = "bar"; + + const unsigned char kFrameData[] = { + 0x80, 0x02, 0x00, 0x08, + 0x00, 0x00, 0x00, 0x21, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x38, 0xea, + 0xdf, 0xa2, 0x51, 0xb2, + 0x62, 0x60, 0x62, 0x60, + 0x4e, 0x4a, 0x2c, 0x62, + 0x60, 0x4e, 0xcb, 0xcf, + 0x87, 0x12, 0x40, 0x2e, + 0x00, 0x00, 0x00, 0xff, + 0xff + }; + scoped_ptr<SpdyFrame> frame(framer.CreateHeaders( + 1, CONTROL_FLAG_NONE, true, &headers)); + CompareFrame(kDescription, *frame, kFrameData, arraysize(kFrameData)); + } +} + TEST_F(SpdyFramerTest, CreateWindowUpdate) { SpdyFramer framer; @@ -1180,4 +1306,80 @@ TEST_F(SpdyFramerTest, CreateWindowUpdate) { } } +// This test case reproduces conditions that caused ExpandControlFrameBuffer to +// fail to expand the buffer control frame buffer when it should have, allowing +// the framer to overrun the buffer, and smash other heap contents. This test +// relies on the debug version of the heap manager, which checks for buffer +// overrun errors during delete processing. Regression test for b/2974814. +TEST_F(SpdyFramerTest, ExpandBuffer_HeapSmash) { + // Sweep through the area of problematic values, to make sure we always cover + // the danger zone, even if it moves around at bit due to SPDY changes. + for (uint16 val2_len = SpdyFramer::kControlFrameBufferInitialSize - 50; + val2_len < SpdyFramer::kControlFrameBufferInitialSize; + val2_len++) { + std::string val2 = std::string(val2_len, 'a'); + SpdyHeaderBlock headers; + headers["bar"] = "foo"; + headers["foo"] = "baz"; + headers["grue"] = val2.c_str(); + SpdyFramer framer; + scoped_ptr<SpdySynStreamControlFrame> template_frame( + framer.CreateSynStream(1, // stream_id + 0, // associated_stream_id + 1, // priority + CONTROL_FLAG_NONE, + false, // compress + &headers)); + EXPECT_TRUE(template_frame.get() != NULL); + TestSpdyVisitor visitor; + visitor.SimulateInFramer( + reinterpret_cast<unsigned char*>(template_frame.get()->data()), + template_frame.get()->length() + SpdyControlFrame::size()); + EXPECT_EQ(1, visitor.syn_frame_count_); + } +} + +std::string RandomString(int length) { + std::string rv; + for (int index = 0; index < length; index++) + rv += static_cast<char>('a' + (rand() % 26)); + return rv; +} + +// Stress that we can handle a really large header block compression and +// decompression. +TEST_F(SpdyFramerTest, HugeHeaderBlock) { + // Loop targetting various sizes which will potentially jam up the + // frame compressor/decompressor. + SpdyFramer compress_framer; + SpdyFramer decompress_framer; + for (size_t target_size = 1024; + target_size < SpdyFramer::kControlFrameBufferInitialSize; + target_size += 1024) { + SpdyHeaderBlock headers; + for (size_t index = 0; index < target_size; ++index) { + std::string name = RandomString(4); + std::string value = RandomString(8); + headers[name] = value; + } + + // Encode the header block into a SynStream frame. + scoped_ptr<SpdySynStreamControlFrame> frame( + compress_framer.CreateSynStream(1, + 0, + 1, + CONTROL_FLAG_NONE, + true, + &headers)); + // The point of this test is to exercise the limits. So, it is ok if the + // frame was too large to encode, or if the decompress fails. We just want + // to make sure we don't crash. + if (frame.get() != NULL) { + // Now that same header block should decompress just fine. + SpdyHeaderBlock new_headers; + decompress_framer.ParseHeaderBlock(frame.get(), &new_headers); + } + } +} + } // namespace diff --git a/net/spdy/spdy_http_stream.cc b/net/spdy/spdy_http_stream.cc index 89002cb..1b72f7c 100644 --- a/net/spdy/spdy_http_stream.cc +++ b/net/spdy/spdy_http_stream.cc @@ -27,6 +27,7 @@ SpdyHttpStream::SpdyHttpStream(SpdySession* spdy_session, bool direct) spdy_session_(spdy_session), response_info_(NULL), download_finished_(false), + response_headers_received_(false), user_callback_(NULL), user_buffer_len_(0), buffered_read_callback_pending_(false), @@ -242,29 +243,41 @@ int SpdyHttpStream::OnResponseReceived(const spdy::SpdyHeaderBlock& response, response_info_ = push_response_info_.get(); } + // If the response is already received, these headers are too late. + if (response_headers_received_) { + LOG(WARNING) << "SpdyHttpStream headers received after response started."; + return OK; + } + // TODO(mbelshe): This is the time of all headers received, not just time // to first byte. - DCHECK(response_info_->response_time.is_null()); response_info_->response_time = base::Time::Now(); if (!SpdyHeadersToHttpResponse(response, response_info_)) { - status = ERR_INVALID_RESPONSE; - } else { - stream_->GetSSLInfo(&response_info_->ssl_info, - &response_info_->was_npn_negotiated); - response_info_->request_time = stream_->GetRequestTime(); - response_info_->vary_data.Init(*request_info_, *response_info_->headers); - // TODO(ahendrickson): This is recorded after the entire SYN_STREAM control - // frame has been received and processed. Move to framer? - response_info_->response_time = response_time; + // We might not have complete headers yet. + return ERR_INCOMPLETE_SPDY_HEADERS; } + response_headers_received_ = true; + stream_->GetSSLInfo(&response_info_->ssl_info, + &response_info_->was_npn_negotiated); + response_info_->request_time = stream_->GetRequestTime(); + response_info_->vary_data.Init(*request_info_, *response_info_->headers); + // TODO(ahendrickson): This is recorded after the entire SYN_STREAM control + // frame has been received and processed. Move to framer? + response_info_->response_time = response_time; + if (user_callback_) DoCallback(status); return status; } void SpdyHttpStream::OnDataReceived(const char* data, int length) { + // SpdyStream won't call us with data if the header block didn't contain a + // valid set of headers. So we don't expect to not have headers received + // here. + DCHECK(response_headers_received_); + // Note that data may be received for a SpdyStream prior to the user calling // ReadResponseBody(), therefore user_buffer_ may be NULL. This may often // happen for server initiated streams. diff --git a/net/spdy/spdy_http_stream.h b/net/spdy/spdy_http_stream.h index a878ff9..cd351cd 100644 --- a/net/spdy/spdy_http_stream.h +++ b/net/spdy/spdy_http_stream.h @@ -37,99 +37,50 @@ class SpdyHttpStream : public SpdyStream::Delegate, public HttpStream { SpdyStream* stream() { return stream_.get(); } - // =================================================== - // HttpStream methods: + // Cancels any callbacks from being invoked and deletes the stream. + void Cancel(); - // Initialize stream. Must be called before calling SendRequest(). + // HttpStream methods: virtual int InitializeStream(const HttpRequestInfo* request_info, const BoundNetLog& net_log, CompletionCallback* callback); - - // Sends the request. - // |callback| is used when this completes asynchronously. - // SpdyHttpStream takes ownership of |upload_data|. |upload_data| may be NULL. - // The actual SYN_STREAM packet will be sent if the stream is non-pushed. virtual int SendRequest(const HttpRequestHeaders& headers, UploadDataStream* request_body, HttpResponseInfo* response, CompletionCallback* callback); - - // Returns the number of bytes uploaded. virtual uint64 GetUploadProgress() const; - - // Reads the response headers. Returns a net error code. virtual int ReadResponseHeaders(CompletionCallback* callback); - virtual const HttpResponseInfo* GetResponseInfo() const; - - // Reads the response body. Returns a net error code or the number of bytes - // read. - virtual int ReadResponseBody( - IOBuffer* buf, int buf_len, CompletionCallback* callback); - - // Closes the stream. + virtual int ReadResponseBody(IOBuffer* buf, + int buf_len, + CompletionCallback* callback); virtual void Close(bool not_reusable); - virtual HttpStream* RenewStreamForAuth() { return NULL; } - - // Indicates if the response body has been completely read. virtual bool IsResponseBodyComplete() const { if (!stream_) return false; return stream_->closed(); } - - // With SPDY the end of response is always detectable. virtual bool CanFindEndOfResponse() const { return true; } - - // A SPDY stream never has more data after the FIN. virtual bool IsMoreDataBuffered() const { return false; } - virtual bool IsConnectionReused() const { return spdy_session_->IsReused(); } - virtual void SetConnectionReused() { // SPDY doesn't need an indicator here. } - virtual void GetSSLInfo(SSLInfo* ssl_info); virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info); - // =================================================== - // SpdyStream::Delegate. - - // Cancels any callbacks from being invoked and deletes the stream. - void Cancel(); - + // SpdyStream::Delegate methods: virtual bool OnSendHeadersComplete(int status); virtual int OnSendBody(); virtual bool OnSendBodyComplete(int status); - - // Called by the SpdySession when a response (e.g. a SYN_REPLY) has been - // received for this stream. - // SpdyHttpSession calls back |callback| set by SendRequest or - // ReadResponseHeaders. virtual int OnResponseReceived(const spdy::SpdyHeaderBlock& response, base::Time response_time, int status); - - // Called by the SpdySession when response data has been received for this - // stream. This callback may be called multiple times as data arrives - // from the network, and will never be called prior to OnResponseReceived. - // SpdyHttpSession schedule to call back |callback| set by ReadResponseBody. virtual void OnDataReceived(const char* buffer, int bytes); - - // For HTTP streams, no data is sent from the client while in the OPEN state, - // so OnDataSent is never called. virtual void OnDataSent(int length); - - // Called by the SpdySession when the request is finished. This callback - // will always be called at the end of the request and signals to the - // stream that the stream has no more network events. No further callbacks - // to the stream will be made after this call. - // SpdyHttpSession call back |callback| set by SendRequest, - // ReadResponseHeaders or ReadResponseBody. virtual void OnClose(int status); private: @@ -161,6 +112,7 @@ class SpdyHttpStream : public SpdyStream::Delegate, public HttpStream { scoped_ptr<HttpResponseInfo> push_response_info_; bool download_finished_; + bool response_headers_received_; // Indicates waiting for more HEADERS. // We buffer the response body as it arrives asynchronously from the stream. // TODO(mbelshe): is this infinite buffering? diff --git a/net/spdy/spdy_http_utils.cc b/net/spdy/spdy_http_utils.cc index 09ac79c..e367c42 100644 --- a/net/spdy/spdy_http_utils.cc +++ b/net/spdy/spdy_http_utils.cc @@ -27,18 +27,14 @@ bool SpdyHeadersToHttpResponse(const spdy::SpdyHeaderBlock& headers, // The "status" and "version" headers are required. spdy::SpdyHeaderBlock::const_iterator it; it = headers.find("status"); - if (it == headers.end()) { - LOG(ERROR) << "SpdyHeaderBlock without status header."; + if (it == headers.end()) return false; - } status = it->second; // Grab the version. If not provided by the server, it = headers.find("version"); - if (it == headers.end()) { - LOG(ERROR) << "SpdyHeaderBlock without version header."; + if (it == headers.end()) return false; - } version = it->second; response->response_time = base::Time::Now(); diff --git a/net/spdy/spdy_network_transaction_unittest.cc b/net/spdy/spdy_network_transaction_unittest.cc index 1aaf657..075b872 100644 --- a/net/spdy/spdy_network_transaction_unittest.cc +++ b/net/spdy/spdy_network_transaction_unittest.cc @@ -177,7 +177,6 @@ class SpdyNetworkTransactionTest output_.status_line = response->headers->GetStatusLine(); output_.response_info = *response; // Make a copy so we can verify. output_.rv = ReadTransaction(trans_.get(), &output_.response_data); - EXPECT_EQ(OK, output_.rv); return; } @@ -352,7 +351,7 @@ class SpdyNetworkTransactionTest // to skip over data destined for other transactions while we consume // the data for |trans|. int ReadResult(HttpNetworkTransaction* trans, - OrderedSocketData* data, + StaticSocketDataProvider* data, std::string* result) { const int kSize = 3000; @@ -397,19 +396,14 @@ class SpdyNetworkTransactionTest EXPECT_EQ(0u, spdy_session->num_unclaimed_pushed_streams()); } - void RunServerPushTest(MockWrite writes[], int writes_length, - MockRead reads[], int reads_length, + void RunServerPushTest(OrderedSocketData* data, HttpResponseInfo* response, - HttpResponseInfo* response2, + HttpResponseInfo* push_response, std::string& expected) { - scoped_refptr<OrderedSocketData> data( - new OrderedSocketData(reads, reads_length, - writes, writes_length)); NormalSpdyTransactionHelper helper(CreateGetRequest(), BoundNetLog(), GetParam()); - helper.RunPreTestSetup(); - helper.AddData(data.get()); + helper.AddData(data); HttpNetworkTransaction* trans = helper.trans(); @@ -449,7 +443,7 @@ class SpdyNetworkTransactionTest // Verify the SYN_REPLY. // Copy the response info, because trans goes away. *response = *trans->GetResponseInfo(); - *response2 = *trans2->GetResponseInfo(); + *push_response = *trans2->GetResponseInfo(); VerifyStreamsClosed(helper); } @@ -1198,7 +1192,7 @@ TEST_P(SpdyNetworkTransactionTest, Put) { "content-length", "0" }; scoped_ptr<spdy::SpdyFrame> req(ConstructSpdyPacket(kSynStartHeader, NULL, 0, - kPutHeaders, arraysize(kPutHeaders)/2)); + kPutHeaders, arraysize(kPutHeaders) / 2)); MockWrite writes[] = { CreateMockWrite(*req) }; @@ -1222,7 +1216,7 @@ TEST_P(SpdyNetworkTransactionTest, Put) { "content-length", "1234" }; scoped_ptr<spdy::SpdyFrame> resp(ConstructSpdyPacket(kSynReplyHeader, - NULL, 0, kStandardGetHeaders, arraysize(kStandardGetHeaders)/2)); + NULL, 0, kStandardGetHeaders, arraysize(kStandardGetHeaders) / 2)); MockRead reads[] = { CreateMockRead(*resp), CreateMockRead(*body), @@ -1269,7 +1263,7 @@ TEST_P(SpdyNetworkTransactionTest, Head) { "content-length", "0" }; scoped_ptr<spdy::SpdyFrame> req(ConstructSpdyPacket(kSynStartHeader, NULL, 0, - kHeadHeaders, arraysize(kHeadHeaders)/2)); + kHeadHeaders, arraysize(kHeadHeaders) / 2)); MockWrite writes[] = { CreateMockWrite(*req) }; @@ -1293,7 +1287,7 @@ TEST_P(SpdyNetworkTransactionTest, Head) { "content-length", "1234" }; scoped_ptr<spdy::SpdyFrame> resp(ConstructSpdyPacket(kSynReplyHeader, - NULL, 0, kStandardGetHeaders, arraysize(kStandardGetHeaders)/2)); + NULL, 0, kStandardGetHeaders, arraysize(kStandardGetHeaders) / 2)); MockRead reads[] = { CreateMockRead(*resp), CreateMockRead(*body), @@ -1650,7 +1644,7 @@ TEST_P(SpdyNetworkTransactionTest, WindowUpdateReceived) { SpdyHttpStream* stream = static_cast<SpdyHttpStream*>(trans->stream_.get()); ASSERT_TRUE(stream != NULL); ASSERT_TRUE(stream->stream() != NULL); - EXPECT_EQ(spdy::kInitialWindowSize + + EXPECT_EQ(static_cast<int>(spdy::kSpdyStreamInitialWindowSize) + kDeltaWindowSize * kDeltaCount - kMaxSpdyFrameChunkSize * kFrameCount, stream->stream()->send_window_size()); @@ -1709,8 +1703,9 @@ TEST_P(SpdyNetworkTransactionTest, WindowUpdateSent) { ASSERT_TRUE(stream != NULL); ASSERT_TRUE(stream->stream() != NULL); - EXPECT_EQ(spdy::kInitialWindowSize - kUploadDataSize, - stream->stream()->recv_window_size()); + EXPECT_EQ( + static_cast<int>(spdy::kSpdyStreamInitialWindowSize) - kUploadDataSize, + stream->stream()->recv_window_size()); const HttpResponseInfo* response = trans->GetResponseInfo(); ASSERT_TRUE(response != NULL); @@ -1841,17 +1836,19 @@ TEST_P(SpdyNetworkTransactionTest, FlowControlStallResume) { // frames plus SYN_STREAM plus the last data frame; also we need another // data frame that we will send once the WINDOW_UPDATE is received, // therefore +3. - size_t nwrites = spdy::kInitialWindowSize / kMaxSpdyFrameChunkSize + 3; + size_t nwrites = + spdy::kSpdyStreamInitialWindowSize / kMaxSpdyFrameChunkSize + 3; // Calculate last frame's size; 0 size data frame is legal. - size_t last_frame_size = spdy::kInitialWindowSize % kMaxSpdyFrameChunkSize; + size_t last_frame_size = + spdy::kSpdyStreamInitialWindowSize % kMaxSpdyFrameChunkSize; // Construct content for a data frame of maximum size. scoped_ptr<std::string> content( new std::string(kMaxSpdyFrameChunkSize, 'a')); scoped_ptr<spdy::SpdyFrame> req(ConstructSpdyPost( - spdy::kInitialWindowSize + kUploadDataSize, NULL, 0)); + spdy::kSpdyStreamInitialWindowSize + kUploadDataSize, NULL, 0)); // Full frames. scoped_ptr<spdy::SpdyFrame> body1( @@ -1898,7 +1895,7 @@ TEST_P(SpdyNetworkTransactionTest, FlowControlStallResume) { request.url = GURL("http://www.google.com/"); request.upload_data = new UploadData(); scoped_ptr<std::string> upload_data( - new std::string(spdy::kInitialWindowSize, 'a')); + new std::string(spdy::kSpdyStreamInitialWindowSize, 'a')); upload_data->append(kUploadData, kUploadDataSize); request.upload_data->AppendBytes(upload_data->c_str(), upload_data->size()); NormalSpdyTransactionHelper helper(request, @@ -2199,11 +2196,11 @@ TEST_P(SpdyNetworkTransactionTest, RedirectGetRequest) { // Setup writes/reads to www.google.com scoped_ptr<spdy::SpdyFrame> req(ConstructSpdyPacket( - kSynStartHeader, kExtraHeaders, arraysize(kExtraHeaders)/2, - kStandardGetHeaders, arraysize(kStandardGetHeaders)/2)); + kSynStartHeader, kExtraHeaders, arraysize(kExtraHeaders) / 2, + kStandardGetHeaders, arraysize(kStandardGetHeaders) / 2)); scoped_ptr<spdy::SpdyFrame> req2(ConstructSpdyPacket( - kSynStartHeader, kExtraHeaders, arraysize(kExtraHeaders)/2, - kStandardGetHeaders2, arraysize(kStandardGetHeaders2)/2)); + kSynStartHeader, kExtraHeaders, arraysize(kExtraHeaders) / 2, + kStandardGetHeaders2, arraysize(kStandardGetHeaders2) / 2)); scoped_ptr<spdy::SpdyFrame> resp(ConstructSpdyGetSynReplyRedirect(1)); MockWrite writes[] = { CreateMockWrite(*req, 1), @@ -2308,16 +2305,27 @@ TEST_P(SpdyNetworkTransactionTest, RedirectServerPush) { }; // Setup writes/reads to www.google.com - scoped_ptr<spdy::SpdyFrame> req(ConstructSpdyPacket( - kSynStartHeader, kExtraHeaders, arraysize(kExtraHeaders)/2, - kStandardGetHeaders, arraysize(kStandardGetHeaders)/2)); - scoped_ptr<spdy::SpdyFrame> req2(ConstructSpdyPacket( - kSynStartHeader, kExtraHeaders, arraysize(kExtraHeaders)/2, - kStandardGetHeaders2, arraysize(kStandardGetHeaders2)/2)); + scoped_ptr<spdy::SpdyFrame> req( + ConstructSpdyPacket(kSynStartHeader, + kExtraHeaders, + arraysize(kExtraHeaders) / 2, + kStandardGetHeaders, + arraysize(kStandardGetHeaders) / 2)); + scoped_ptr<spdy::SpdyFrame> req2( + ConstructSpdyPacket(kSynStartHeader, + kExtraHeaders, + arraysize(kExtraHeaders) / 2, + kStandardGetHeaders2, + arraysize(kStandardGetHeaders2) / 2)); scoped_ptr<spdy::SpdyFrame> resp(ConstructSpdyGetSynReply(NULL, 0, 1)); - scoped_ptr<spdy::SpdyFrame> rep(ConstructSpdyPush(NULL, 0, 2, 1, "/foo.dat", - "301 Moved Permanently", "http://www.foo.com/index.php", - "http://www.foo.com/index.php")); + scoped_ptr<spdy::SpdyFrame> rep( + ConstructSpdyPush(NULL, + 0, + 2, + 1, + "http://www.google.com/foo.dat", + "301 Moved Permanently", + "http://www.foo.com/index.php")); scoped_ptr<spdy::SpdyFrame> body(ConstructSpdyBodyFrame(1, true)); MockWrite writes[] = { CreateMockWrite(*req, 1), @@ -2411,7 +2419,11 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushSingleDataFrame) { scoped_ptr<spdy::SpdyFrame> stream1_reply(ConstructSpdyGetSynReply(NULL, 0, 1)); scoped_ptr<spdy::SpdyFrame> - stream2_syn(ConstructSpdyPush(NULL, 0, 2, 1, "/foo.dat")); + stream2_syn(ConstructSpdyPush(NULL, + 0, + 2, + 1, + "http://www.google.com/foo.dat")); MockRead reads[] = { CreateMockRead(*stream1_reply, 2), CreateMockRead(*stream2_syn, 3), @@ -2424,8 +2436,15 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushSingleDataFrame) { HttpResponseInfo response; HttpResponseInfo response2; std::string expected_push_result("pushed"); - RunServerPushTest(writes, arraysize(writes), reads, arraysize(reads), - &response, &response2, expected_push_result); + scoped_refptr<OrderedSocketData> data(new OrderedSocketData( + reads, + arraysize(reads), + writes, + arraysize(writes))); + RunServerPushTest(data.get(), + &response, + &response2, + expected_push_result); // Verify the SYN_REPLY. EXPECT_TRUE(response.headers != NULL); @@ -2451,7 +2470,11 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushSingleDataFrame2) { scoped_ptr<spdy::SpdyFrame> stream1_reply(ConstructSpdyGetSynReply(NULL, 0, 1)); scoped_ptr<spdy::SpdyFrame> - stream2_syn(ConstructSpdyPush(NULL, 0, 2, 1, "/foo.dat")); + stream2_syn(ConstructSpdyPush(NULL, + 0, + 2, + 1, + "http://www.google.com/foo.dat")); scoped_ptr<spdy::SpdyFrame> stream1_body(ConstructSpdyBodyFrame(1, true)); MockRead reads[] = { @@ -2466,8 +2489,15 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushSingleDataFrame2) { HttpResponseInfo response; HttpResponseInfo response2; std::string expected_push_result("pushed"); - RunServerPushTest(writes, arraysize(writes), reads, arraysize(reads), - &response, &response2, expected_push_result); + scoped_refptr<OrderedSocketData> data(new OrderedSocketData( + reads, + arraysize(reads), + writes, + arraysize(writes))); + RunServerPushTest(data.get(), + &response, + &response2, + expected_push_result); // Verify the SYN_REPLY. EXPECT_TRUE(response.headers != NULL); @@ -2490,7 +2520,11 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushServerAborted) { scoped_ptr<spdy::SpdyFrame> stream1_reply(ConstructSpdyGetSynReply(NULL, 0, 1)); scoped_ptr<spdy::SpdyFrame> - stream2_syn(ConstructSpdyPush(NULL, 0, 2, 1, "/foo.dat")); + stream2_syn(ConstructSpdyPush(NULL, + 0, + 2, + 1, + "http://www.google.com/foo.dat")); scoped_ptr<spdy::SpdyFrame> stream2_rst(ConstructSpdyRstStream(2, spdy::PROTOCOL_ERROR)); MockRead reads[] = { @@ -2558,9 +2592,17 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushDuplicate) { scoped_ptr<spdy::SpdyFrame> stream1_reply(ConstructSpdyGetSynReply(NULL, 0, 1)); scoped_ptr<spdy::SpdyFrame> - stream2_syn(ConstructSpdyPush(NULL, 0, 2, 1, "/foo.dat")); + stream2_syn(ConstructSpdyPush(NULL, + 0, + 2, + 1, + "http://www.google.com/foo.dat")); scoped_ptr<spdy::SpdyFrame> - stream3_syn(ConstructSpdyPush(NULL, 0, 4, 1, "/foo.dat")); + stream3_syn(ConstructSpdyPush(NULL, + 0, + 4, + 1, + "http://www.google.com/foo.dat")); MockRead reads[] = { CreateMockRead(*stream1_reply, 2), CreateMockRead(*stream2_syn, 3), @@ -2574,8 +2616,15 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushDuplicate) { HttpResponseInfo response; HttpResponseInfo response2; std::string expected_push_result("pushed"); - RunServerPushTest(writes, arraysize(writes), reads, arraysize(reads), - &response, &response2, expected_push_result); + scoped_refptr<OrderedSocketData> data(new OrderedSocketData( + reads, + arraysize(reads), + writes, + arraysize(writes))); + RunServerPushTest(data.get(), + &response, + &response2, + expected_push_result); // Verify the SYN_REPLY. EXPECT_TRUE(response.headers != NULL); @@ -2607,7 +2656,11 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushMultipleDataFrame) { scoped_ptr<spdy::SpdyFrame> stream1_reply(ConstructSpdyGetSynReply(NULL, 0, 1)); scoped_ptr<spdy::SpdyFrame> - stream2_syn(ConstructSpdyPush(NULL, 0, 2, 1, "/foo.dat")); + stream2_syn(ConstructSpdyPush(NULL, + 0, + 2, + 1, + "http://www.google.com/foo.dat")); MockRead reads[] = { CreateMockRead(*stream1_reply, 2), CreateMockRead(*stream2_syn, 3), @@ -2626,8 +2679,15 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushMultipleDataFrame) { HttpResponseInfo response; HttpResponseInfo response2; std::string expected_push_result("pushed my darling hello my baby"); - RunServerPushTest(writes, arraysize(writes), reads, arraysize(reads), - &response, &response2, expected_push_result); + scoped_refptr<OrderedSocketData> data(new OrderedSocketData( + reads, + arraysize(reads), + writes, + arraysize(writes))); + RunServerPushTest(data.get(), + &response, + &response2, + expected_push_result); // Verify the SYN_REPLY. EXPECT_TRUE(response.headers != NULL); @@ -2659,7 +2719,11 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushMultipleDataFrameInterrupted) { scoped_ptr<spdy::SpdyFrame> stream1_reply(ConstructSpdyGetSynReply(NULL, 0, 1)); scoped_ptr<spdy::SpdyFrame> - stream2_syn(ConstructSpdyPush(NULL, 0, 2, 1, "/foo.dat")); + stream2_syn(ConstructSpdyPush(NULL, + 0, + 2, + 1, + "http://www.google.com/foo.dat")); MockRead reads[] = { CreateMockRead(*stream1_reply, 2), CreateMockRead(*stream2_syn, 3), @@ -2679,8 +2743,15 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushMultipleDataFrameInterrupted) { HttpResponseInfo response; HttpResponseInfo response2; std::string expected_push_result("pushed my darling hello my baby"); - RunServerPushTest(writes, arraysize(writes), reads, arraysize(reads), - &response, &response2, expected_push_result); + scoped_refptr<OrderedSocketData> data(new OrderedSocketData( + reads, + arraysize(reads), + writes, + arraysize(writes))); + RunServerPushTest(data.get(), + &response, + &response2, + expected_push_result); // Verify the SYN_REPLY. EXPECT_TRUE(response.headers != NULL); @@ -2706,7 +2777,11 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushInvalidAssociatedStreamID0) { scoped_ptr<spdy::SpdyFrame> stream1_reply(ConstructSpdyGetSynReply(NULL, 0, 1)); scoped_ptr<spdy::SpdyFrame> - stream2_syn(ConstructSpdyPush(NULL, 0, 2, 0, "/foo.dat")); + stream2_syn(ConstructSpdyPush(NULL, + 0, + 2, + 0, + "http://www.google.com/foo.dat")); MockRead reads[] = { CreateMockRead(*stream1_reply, 2), CreateMockRead(*stream2_syn, 3), @@ -2763,7 +2838,11 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushInvalidAssociatedStreamID9) { scoped_ptr<spdy::SpdyFrame> stream1_reply(ConstructSpdyGetSynReply(NULL, 0, 1)); scoped_ptr<spdy::SpdyFrame> - stream2_syn(ConstructSpdyPush(NULL, 0, 2, 9, "/foo.dat")); + stream2_syn(ConstructSpdyPush(NULL, + 0, + 2, + 9, + "http://www.google.com/foo.dat")); MockRead reads[] = { CreateMockRead(*stream1_reply, 2), CreateMockRead(*stream2_syn, 3), @@ -2879,7 +2958,6 @@ TEST_P(SpdyNetworkTransactionTest, SynReplyHeaders) { "cookie: val2\n" "hello: bye\n" "status: 200\n" - "url: /index.php\n" "version: HTTP/1.1\n" }, // This is the minimalist set of headers. @@ -2887,7 +2965,6 @@ TEST_P(SpdyNetworkTransactionTest, SynReplyHeaders) { { NULL }, "hello: bye\n" "status: 200\n" - "url: /index.php\n" "version: HTTP/1.1\n" }, // Headers with a comma separated list. @@ -2898,7 +2975,6 @@ TEST_P(SpdyNetworkTransactionTest, SynReplyHeaders) { "cookie: val1,val2\n" "hello: bye\n" "status: 200\n" - "url: /index.php\n" "version: HTTP/1.1\n" } }; @@ -3181,7 +3257,7 @@ TEST_P(SpdyNetworkTransactionTest, InvalidSynReply) { BoundNetLog(), GetParam()); helper.RunToCompletion(data.get()); TransactionHelperResult out = helper.output(); - EXPECT_EQ(ERR_INVALID_RESPONSE, out.rv); + EXPECT_EQ(ERR_INCOMPLETE_SPDY_HEADERS, out.rv); } } @@ -4599,7 +4675,8 @@ TEST_P(SpdyNetworkTransactionTest, SpdyBasicAuth) { }; scoped_ptr<spdy::SpdyFrame> req_get_authorization( ConstructSpdyGet( - kExtraAuthorizationHeaders, arraysize(kExtraAuthorizationHeaders)/2, + kExtraAuthorizationHeaders, + arraysize(kExtraAuthorizationHeaders) / 2, false, 3, LOWEST)); MockWrite spdy_writes[] = { CreateMockWrite(*req_get, 1), @@ -4616,7 +4693,8 @@ TEST_P(SpdyNetworkTransactionTest, SpdyBasicAuth) { scoped_ptr<spdy::SpdyFrame> resp_authentication( ConstructSpdySynReplyError( "401 Authentication Required", - kExtraAuthenticationHeaders, arraysize(kExtraAuthenticationHeaders)/2, + kExtraAuthenticationHeaders, + arraysize(kExtraAuthenticationHeaders) / 2, 1)); scoped_ptr<spdy::SpdyFrame> body_authentication( ConstructSpdyBodyFrame(1, true)); @@ -4674,4 +4752,661 @@ TEST_P(SpdyNetworkTransactionTest, SpdyBasicAuth) { EXPECT_TRUE(response_restart->auth_challenge.get() == NULL); } +TEST_P(SpdyNetworkTransactionTest, ServerPushWithHeaders) { + static const unsigned char kPushBodyFrame[] = { + 0x00, 0x00, 0x00, 0x02, // header, ID + 0x01, 0x00, 0x00, 0x06, // FIN, length + 'p', 'u', 's', 'h', 'e', 'd' // "pushed" + }; + scoped_ptr<spdy::SpdyFrame> + stream1_syn(ConstructSpdyGet(NULL, 0, false, 1, LOWEST)); + scoped_ptr<spdy::SpdyFrame> + stream1_body(ConstructSpdyBodyFrame(1, true)); + MockWrite writes[] = { + CreateMockWrite(*stream1_syn, 1), + }; + + static const char* const kInitialHeaders[] = { + "url", + "http://www.google.com/foo.dat", + }; + static const char* const kLateHeaders[] = { + "hello", + "bye", + "status", + "200", + "version", + "HTTP/1.1" + }; + scoped_ptr<spdy::SpdyFrame> + stream2_syn(ConstructSpdyControlFrame(kInitialHeaders, + arraysize(kInitialHeaders) / 2, + false, + 2, + LOWEST, + spdy::SYN_STREAM, + spdy::CONTROL_FLAG_NONE, + NULL, + 0, + 1)); + scoped_ptr<spdy::SpdyFrame> + stream2_headers(ConstructSpdyControlFrame(kLateHeaders, + arraysize(kLateHeaders) / 2, + false, + 2, + LOWEST, + spdy::HEADERS, + spdy::CONTROL_FLAG_NONE, + NULL, + 0, + 0)); + + scoped_ptr<spdy::SpdyFrame> + stream1_reply(ConstructSpdyGetSynReply(NULL, 0, 1)); + MockRead reads[] = { + CreateMockRead(*stream1_reply, 2), + CreateMockRead(*stream2_syn, 3), + CreateMockRead(*stream2_headers, 4), + CreateMockRead(*stream1_body, 5, false), + MockRead(true, reinterpret_cast<const char*>(kPushBodyFrame), + arraysize(kPushBodyFrame), 6), + MockRead(true, ERR_IO_PENDING, 7), // Force a pause + }; + + HttpResponseInfo response; + HttpResponseInfo response2; + std::string expected_push_result("pushed"); + scoped_refptr<OrderedSocketData> data(new OrderedSocketData( + reads, + arraysize(reads), + writes, + arraysize(writes))); + RunServerPushTest(data.get(), + &response, + &response2, + expected_push_result); + + // Verify the SYN_REPLY. + EXPECT_TRUE(response.headers != NULL); + EXPECT_EQ("HTTP/1.1 200 OK", response.headers->GetStatusLine()); + + // Verify the pushed stream. + EXPECT_TRUE(response2.headers != NULL); + EXPECT_EQ("HTTP/1.1 200 OK", response2.headers->GetStatusLine()); +} + +TEST_P(SpdyNetworkTransactionTest, ServerPushClaimBeforeHeaders) { + // We push a stream and attempt to claim it before the headers come down. + static const unsigned char kPushBodyFrame[] = { + 0x00, 0x00, 0x00, 0x02, // header, ID + 0x01, 0x00, 0x00, 0x06, // FIN, length + 'p', 'u', 's', 'h', 'e', 'd' // "pushed" + }; + scoped_ptr<spdy::SpdyFrame> + stream1_syn(ConstructSpdyGet(NULL, 0, false, 1, LOWEST)); + scoped_ptr<spdy::SpdyFrame> + stream1_body(ConstructSpdyBodyFrame(1, true)); + MockWrite writes[] = { + CreateMockWrite(*stream1_syn, 0, false), + }; + + static const char* const kInitialHeaders[] = { + "url", + "http://www.google.com/foo.dat", + }; + static const char* const kLateHeaders[] = { + "hello", + "bye", + "status", + "200", + "version", + "HTTP/1.1" + }; + scoped_ptr<spdy::SpdyFrame> + stream2_syn(ConstructSpdyControlFrame(kInitialHeaders, + arraysize(kInitialHeaders) / 2, + false, + 2, + LOWEST, + spdy::SYN_STREAM, + spdy::CONTROL_FLAG_NONE, + NULL, + 0, + 1)); + scoped_ptr<spdy::SpdyFrame> + stream2_headers(ConstructSpdyControlFrame(kLateHeaders, + arraysize(kLateHeaders) / 2, + false, + 2, + LOWEST, + spdy::HEADERS, + spdy::CONTROL_FLAG_NONE, + NULL, + 0, + 0)); + + scoped_ptr<spdy::SpdyFrame> + stream1_reply(ConstructSpdyGetSynReply(NULL, 0, 1)); + MockRead reads[] = { + CreateMockRead(*stream1_reply, 1), + CreateMockRead(*stream2_syn, 2), + CreateMockRead(*stream1_body, 3), + CreateMockRead(*stream2_headers, 4), + MockRead(true, reinterpret_cast<const char*>(kPushBodyFrame), + arraysize(kPushBodyFrame), 5), + MockRead(true, 0, 5), // EOF + }; + + HttpResponseInfo response; + HttpResponseInfo response2; + std::string expected_push_result("pushed"); + scoped_refptr<DeterministicSocketData> data(new DeterministicSocketData( + reads, + arraysize(reads), + writes, + arraysize(writes))); + + NormalSpdyTransactionHelper helper(CreateGetRequest(), + BoundNetLog(), GetParam()); + helper.SetDeterministic(); + helper.AddDeterministicData(static_cast<DeterministicSocketData*>(data)); + helper.RunPreTestSetup(); + + HttpNetworkTransaction* trans = helper.trans(); + + // Run until we've received the primary SYN_STREAM, the pushed SYN_STREAM, + // and the body of the primary stream, but before we've received the HEADERS + // for the pushed stream. + data->SetStop(3); + + // Start the transaction. + TestCompletionCallback callback; + int rv = trans->Start(&CreateGetRequest(), &callback, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + data->Run(); + rv = callback.WaitForResult(); + EXPECT_EQ(0, rv); + + // Request the pushed path. At this point, we've received the push, but the + // headers are not yet complete. + scoped_ptr<HttpNetworkTransaction> trans2( + new HttpNetworkTransaction(helper.session())); + rv = trans2->Start(&CreateGetPushRequest(), &callback, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + data->RunFor(3); + MessageLoop::current()->RunAllPending(); + + // Read the server push body. + std::string result2; + ReadResult(trans2.get(), data.get(), &result2); + // Read the response body. + std::string result; + ReadResult(trans, data, &result); + + // Verify that we consumed all test data. + EXPECT_TRUE(data->at_read_eof()); + EXPECT_TRUE(data->at_write_eof()); + + // Verify that the received push data is same as the expected push data. + EXPECT_EQ(result2.compare(expected_push_result), 0) + << "Received data: " + << result2 + << "||||| Expected data: " + << expected_push_result; + + // Verify the SYN_REPLY. + // Copy the response info, because trans goes away. + response = *trans->GetResponseInfo(); + response2 = *trans2->GetResponseInfo(); + + VerifyStreamsClosed(helper); + + // Verify the SYN_REPLY. + EXPECT_TRUE(response.headers != NULL); + EXPECT_EQ("HTTP/1.1 200 OK", response.headers->GetStatusLine()); + + // Verify the pushed stream. + EXPECT_TRUE(response2.headers != NULL); + EXPECT_EQ("HTTP/1.1 200 OK", response2.headers->GetStatusLine()); +} + +TEST_P(SpdyNetworkTransactionTest, ServerPushWithTwoHeaderFrames) { + // We push a stream and attempt to claim it before the headers come down. + static const unsigned char kPushBodyFrame[] = { + 0x00, 0x00, 0x00, 0x02, // header, ID + 0x01, 0x00, 0x00, 0x06, // FIN, length + 'p', 'u', 's', 'h', 'e', 'd' // "pushed" + }; + scoped_ptr<spdy::SpdyFrame> + stream1_syn(ConstructSpdyGet(NULL, 0, false, 1, LOWEST)); + scoped_ptr<spdy::SpdyFrame> + stream1_body(ConstructSpdyBodyFrame(1, true)); + MockWrite writes[] = { + CreateMockWrite(*stream1_syn, 0, false), + }; + + static const char* const kInitialHeaders[] = { + "url", + "http://www.google.com/foo.dat", + }; + static const char* const kMiddleHeaders[] = { + "hello", + "bye", + }; + static const char* const kLateHeaders[] = { + "status", + "200", + "version", + "HTTP/1.1" + }; + scoped_ptr<spdy::SpdyFrame> + stream2_syn(ConstructSpdyControlFrame(kInitialHeaders, + arraysize(kInitialHeaders) / 2, + false, + 2, + LOWEST, + spdy::SYN_STREAM, + spdy::CONTROL_FLAG_NONE, + NULL, + 0, + 1)); + scoped_ptr<spdy::SpdyFrame> + stream2_headers1(ConstructSpdyControlFrame(kMiddleHeaders, + arraysize(kMiddleHeaders) / 2, + false, + 2, + LOWEST, + spdy::HEADERS, + spdy::CONTROL_FLAG_NONE, + NULL, + 0, + 0)); + scoped_ptr<spdy::SpdyFrame> + stream2_headers2(ConstructSpdyControlFrame(kLateHeaders, + arraysize(kLateHeaders) / 2, + false, + 2, + LOWEST, + spdy::HEADERS, + spdy::CONTROL_FLAG_NONE, + NULL, + 0, + 0)); + + scoped_ptr<spdy::SpdyFrame> + stream1_reply(ConstructSpdyGetSynReply(NULL, 0, 1)); + MockRead reads[] = { + CreateMockRead(*stream1_reply, 1), + CreateMockRead(*stream2_syn, 2), + CreateMockRead(*stream1_body, 3), + CreateMockRead(*stream2_headers1, 4), + CreateMockRead(*stream2_headers2, 5), + MockRead(true, reinterpret_cast<const char*>(kPushBodyFrame), + arraysize(kPushBodyFrame), 6), + MockRead(true, 0, 6), // EOF + }; + + HttpResponseInfo response; + HttpResponseInfo response2; + std::string expected_push_result("pushed"); + scoped_refptr<DeterministicSocketData> data(new DeterministicSocketData( + reads, + arraysize(reads), + writes, + arraysize(writes))); + + NormalSpdyTransactionHelper helper(CreateGetRequest(), + BoundNetLog(), GetParam()); + helper.SetDeterministic(); + helper.AddDeterministicData(static_cast<DeterministicSocketData*>(data)); + helper.RunPreTestSetup(); + + HttpNetworkTransaction* trans = helper.trans(); + + // Run until we've received the primary SYN_STREAM, the pushed SYN_STREAM, + // the first HEADERS frame, and the body of the primary stream, but before + // we've received the final HEADERS for the pushed stream. + data->SetStop(4); + + // Start the transaction. + TestCompletionCallback callback; + int rv = trans->Start(&CreateGetRequest(), &callback, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + data->Run(); + rv = callback.WaitForResult(); + EXPECT_EQ(0, rv); + + // Request the pushed path. At this point, we've received the push, but the + // headers are not yet complete. + scoped_ptr<HttpNetworkTransaction> trans2( + new HttpNetworkTransaction(helper.session())); + rv = trans2->Start(&CreateGetPushRequest(), &callback, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + data->RunFor(3); + MessageLoop::current()->RunAllPending(); + + // Read the server push body. + std::string result2; + ReadResult(trans2.get(), data, &result2); + // Read the response body. + std::string result; + ReadResult(trans, data, &result); + + // Verify that we consumed all test data. + EXPECT_TRUE(data->at_read_eof()); + EXPECT_TRUE(data->at_write_eof()); + + // Verify that the received push data is same as the expected push data. + EXPECT_EQ(result2.compare(expected_push_result), 0) + << "Received data: " + << result2 + << "||||| Expected data: " + << expected_push_result; + + // Verify the SYN_REPLY. + // Copy the response info, because trans goes away. + response = *trans->GetResponseInfo(); + response2 = *trans2->GetResponseInfo(); + + VerifyStreamsClosed(helper); + + // Verify the SYN_REPLY. + EXPECT_TRUE(response.headers != NULL); + EXPECT_EQ("HTTP/1.1 200 OK", response.headers->GetStatusLine()); + + // Verify the pushed stream. + EXPECT_TRUE(response2.headers != NULL); + EXPECT_EQ("HTTP/1.1 200 OK", response2.headers->GetStatusLine()); + + // Verify we got all the headers + EXPECT_TRUE(response2.headers->HasHeaderValue( + "url", + "http://www.google.com/foo.dat")); + EXPECT_TRUE(response2.headers->HasHeaderValue("hello", "bye")); + EXPECT_TRUE(response2.headers->HasHeaderValue("status", "200")); + EXPECT_TRUE(response2.headers->HasHeaderValue("version", "HTTP/1.1")); +} + +TEST_P(SpdyNetworkTransactionTest, SynReplyWithHeaders) { + scoped_ptr<spdy::SpdyFrame> req(ConstructSpdyGet(NULL, 0, false, 1, LOWEST)); + MockWrite writes[] = { CreateMockWrite(*req) }; + + static const char* const kInitialHeaders[] = { + "status", + "200 OK", + "version", + "HTTP/1.1" + }; + static const char* const kLateHeaders[] = { + "hello", + "bye", + }; + scoped_ptr<spdy::SpdyFrame> + stream1_reply(ConstructSpdyControlFrame(kInitialHeaders, + arraysize(kInitialHeaders) / 2, + false, + 1, + LOWEST, + spdy::SYN_REPLY, + spdy::CONTROL_FLAG_NONE, + NULL, + 0, + 0)); + scoped_ptr<spdy::SpdyFrame> + stream1_headers(ConstructSpdyControlFrame(kLateHeaders, + arraysize(kLateHeaders) / 2, + false, + 1, + LOWEST, + spdy::HEADERS, + spdy::CONTROL_FLAG_NONE, + NULL, + 0, + 0)); + scoped_ptr<spdy::SpdyFrame> stream1_body(ConstructSpdyBodyFrame(1, true)); + MockRead reads[] = { + CreateMockRead(*stream1_reply), + CreateMockRead(*stream1_headers), + CreateMockRead(*stream1_body), + MockRead(true, 0, 0) // EOF + }; + + scoped_refptr<DelayedSocketData> data( + new DelayedSocketData(1, reads, arraysize(reads), + writes, arraysize(writes))); + NormalSpdyTransactionHelper helper(CreateGetRequest(), + BoundNetLog(), GetParam()); + helper.RunToCompletion(data.get()); + TransactionHelperResult out = helper.output(); + EXPECT_EQ(OK, out.rv); + EXPECT_EQ("HTTP/1.1 200 OK", out.status_line); + EXPECT_EQ("hello!", out.response_data); +} + +TEST_P(SpdyNetworkTransactionTest, SynReplyWithLateHeaders) { + scoped_ptr<spdy::SpdyFrame> req(ConstructSpdyGet(NULL, 0, false, 1, LOWEST)); + MockWrite writes[] = { CreateMockWrite(*req) }; + + static const char* const kInitialHeaders[] = { + "status", + "200 OK", + "version", + "HTTP/1.1" + }; + static const char* const kLateHeaders[] = { + "hello", + "bye", + }; + scoped_ptr<spdy::SpdyFrame> + stream1_reply(ConstructSpdyControlFrame(kInitialHeaders, + arraysize(kInitialHeaders) / 2, + false, + 1, + LOWEST, + spdy::SYN_REPLY, + spdy::CONTROL_FLAG_NONE, + NULL, + 0, + 0)); + scoped_ptr<spdy::SpdyFrame> + stream1_headers(ConstructSpdyControlFrame(kLateHeaders, + arraysize(kLateHeaders) / 2, + false, + 1, + LOWEST, + spdy::HEADERS, + spdy::CONTROL_FLAG_NONE, + NULL, + 0, + 0)); + scoped_ptr<spdy::SpdyFrame> stream1_body(ConstructSpdyBodyFrame(1, false)); + scoped_ptr<spdy::SpdyFrame> stream1_body2(ConstructSpdyBodyFrame(1, true)); + MockRead reads[] = { + CreateMockRead(*stream1_reply), + CreateMockRead(*stream1_body), + CreateMockRead(*stream1_headers), + CreateMockRead(*stream1_body2), + MockRead(true, 0, 0) // EOF + }; + + scoped_refptr<DelayedSocketData> data( + new DelayedSocketData(1, reads, arraysize(reads), + writes, arraysize(writes))); + NormalSpdyTransactionHelper helper(CreateGetRequest(), + BoundNetLog(), GetParam()); + helper.RunToCompletion(data.get()); + TransactionHelperResult out = helper.output(); + EXPECT_EQ(OK, out.rv); + EXPECT_EQ("HTTP/1.1 200 OK", out.status_line); + EXPECT_EQ("hello!hello!", out.response_data); +} + +TEST_P(SpdyNetworkTransactionTest, SynReplyWithDuplicateLateHeaders) { + scoped_ptr<spdy::SpdyFrame> req(ConstructSpdyGet(NULL, 0, false, 1, LOWEST)); + MockWrite writes[] = { CreateMockWrite(*req) }; + + static const char* const kInitialHeaders[] = { + "status", + "200 OK", + "version", + "HTTP/1.1" + }; + static const char* const kLateHeaders[] = { + "status", + "500 Server Error", + }; + scoped_ptr<spdy::SpdyFrame> + stream1_reply(ConstructSpdyControlFrame(kInitialHeaders, + arraysize(kInitialHeaders) / 2, + false, + 1, + LOWEST, + spdy::SYN_REPLY, + spdy::CONTROL_FLAG_NONE, + NULL, + 0, + 0)); + scoped_ptr<spdy::SpdyFrame> + stream1_headers(ConstructSpdyControlFrame(kLateHeaders, + arraysize(kLateHeaders) / 2, + false, + 1, + LOWEST, + spdy::HEADERS, + spdy::CONTROL_FLAG_NONE, + NULL, + 0, + 0)); + scoped_ptr<spdy::SpdyFrame> stream1_body(ConstructSpdyBodyFrame(1, false)); + scoped_ptr<spdy::SpdyFrame> stream1_body2(ConstructSpdyBodyFrame(1, true)); + MockRead reads[] = { + CreateMockRead(*stream1_reply), + CreateMockRead(*stream1_body), + CreateMockRead(*stream1_headers), + CreateMockRead(*stream1_body2), + MockRead(true, 0, 0) // EOF + }; + + scoped_refptr<DelayedSocketData> data( + new DelayedSocketData(1, reads, arraysize(reads), + writes, arraysize(writes))); + NormalSpdyTransactionHelper helper(CreateGetRequest(), + BoundNetLog(), GetParam()); + helper.RunToCompletion(data.get()); + TransactionHelperResult out = helper.output(); + EXPECT_EQ(ERR_SPDY_PROTOCOL_ERROR, out.rv); +} + +TEST_P(SpdyNetworkTransactionTest, ServerPushCrossOriginCorrectness) { + // In this test we want to verify that we can't accidentally push content + // which can't be pushed by this content server. + // This test assumes that: + // - if we're requesting http://www.foo.com/barbaz + // - the browser has made a connection to "www.foo.com". + + // A list of the URL to fetch, followed by the URL being pushed. + static const char* const kTestCases[] = { + "http://www.google.com/foo.html", + "http://www.google.com:81/foo.js", // Bad port + + "http://www.google.com/foo.html", + "https://www.google.com/foo.js", // Bad protocol + + "http://www.google.com/foo.html", + "ftp://www.google.com/foo.js", // Invalid Protocol + + "http://www.google.com/foo.html", + "http://blat.www.google.com/foo.js", // Cross subdomain + + "http://www.google.com/foo.html", + "http://www.foo.com/foo.js", // Cross domain + }; + + + static const unsigned char kPushBodyFrame[] = { + 0x00, 0x00, 0x00, 0x02, // header, ID + 0x01, 0x00, 0x00, 0x06, // FIN, length + 'p', 'u', 's', 'h', 'e', 'd' // "pushed" + }; + + for (size_t index = 0; index < arraysize(kTestCases); index += 2) { + const char* url_to_fetch = kTestCases[index]; + const char* url_to_push = kTestCases[index + 1]; + + scoped_ptr<spdy::SpdyFrame> + stream1_syn(ConstructSpdyGet(url_to_fetch, false, 1, LOWEST)); + scoped_ptr<spdy::SpdyFrame> + stream1_body(ConstructSpdyBodyFrame(1, true)); + scoped_ptr<spdy::SpdyFrame> push_rst( + ConstructSpdyRstStream(2, spdy::REFUSED_STREAM)); + MockWrite writes[] = { + CreateMockWrite(*stream1_syn, 1), + CreateMockWrite(*push_rst, 4), + }; + + scoped_ptr<spdy::SpdyFrame> + stream1_reply(ConstructSpdyGetSynReply(NULL, 0, 1)); + scoped_ptr<spdy::SpdyFrame> + stream2_syn(ConstructSpdyPush(NULL, + 0, + 2, + 1, + url_to_push)); + scoped_ptr<spdy::SpdyFrame> rst( + ConstructSpdyRstStream(2, spdy::CANCEL)); + + MockRead reads[] = { + CreateMockRead(*stream1_reply, 2), + CreateMockRead(*stream2_syn, 3), + CreateMockRead(*stream1_body, 5, false), + MockRead(true, reinterpret_cast<const char*>(kPushBodyFrame), + arraysize(kPushBodyFrame), 6), + MockRead(true, ERR_IO_PENDING, 7), // Force a pause + }; + + HttpResponseInfo response; + scoped_refptr<OrderedSocketData> data(new OrderedSocketData( + reads, + arraysize(reads), + writes, + arraysize(writes))); + + HttpRequestInfo request; + request.method = "GET"; + request.url = GURL(url_to_fetch); + request.load_flags = 0; + NormalSpdyTransactionHelper helper(request, + BoundNetLog(), GetParam()); + helper.RunPreTestSetup(); + helper.AddData(data); + + HttpNetworkTransaction* trans = helper.trans(); + + // Start the transaction with basic parameters. + TestCompletionCallback callback; + + int rv = trans->Start(&request, &callback, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + + // Read the response body. + std::string result; + ReadResult(trans, data, &result); + + // Verify that we consumed all test data. + EXPECT_TRUE(data->at_read_eof()); + EXPECT_TRUE(data->at_write_eof()); + + // Verify the SYN_REPLY. + // Copy the response info, because trans goes away. + response = *trans->GetResponseInfo(); + + VerifyStreamsClosed(helper); + + // Verify the SYN_REPLY. + EXPECT_TRUE(response.headers != NULL); + EXPECT_EQ("HTTP/1.1 200 OK", response.headers->GetStatusLine()); + } +} + } // namespace net diff --git a/net/spdy/spdy_protocol.h b/net/spdy/spdy_protocol.h index ce074c4..9834a11 100644 --- a/net/spdy/spdy_protocol.h +++ b/net/spdy/spdy_protocol.h @@ -14,6 +14,8 @@ #include <arpa/inet.h> #endif +#include <limits> + #include "base/basictypes.h" #include "base/logging.h" #include "net/spdy/spdy_bitmasks.h" @@ -40,7 +42,7 @@ // +----------------------------------+ // |1|000000000000001|0000000000000001| // +----------------------------------+ -// | flags (8) | Length (24 bits) | >= 8 +// | flags (8) | Length (24 bits) | >= 12 // +----------------------------------+ // |X| Stream-ID(31bits) | // +----------------------------------+ @@ -105,6 +107,17 @@ // |X| Last-accepted-stream-id | // +----------------------------------+ // +// Control Frame: HEADERS +// +----------------------------------+ +// |1|000000000000001|0000000000001000| +// +----------------------------------+ +// | flags (8) | Length (24 bits) | >= 8 +// +----------------------------------+ +// |X| Stream-ID (31 bits) | +// +----------------------------------+ +// | unused (16 bits)| Length (16bits)| +// +----------------------------------+ +// // Control Frame: WINDOW_UPDATE // +----------------------------------+ // |1|000000000000001|0000000000001001| @@ -115,15 +128,26 @@ // +----------------------------------+ // | Delta-Window-Size (32 bits) | // +----------------------------------+ - - namespace spdy { -// The SPDY version of this implementation. +// This implementation of Spdy is version 2; It's like version 1, with some +// minor tweaks. const int kSpdyProtocolVersion = 2; -// Default initial window size. -const int kInitialWindowSize = 64 * 1024; +// Initial window size for a Spdy stream +const size_t kSpdyStreamInitialWindowSize = 64 * 1024; // 64 KBytes + +// Maximum window size for a Spdy stream +const size_t kSpdyStreamMaximumWindowSize = std::numeric_limits<int32>::max(); + +// HTTP-over-SPDY header constants +const char kMethod[] = "method"; +const char kStatus[] = "status"; +const char kUrl[] = "url"; +const char kVersion[] = "version"; +// When we server push, we will add [path: fully/qualified/url] to the server +// push headers so that the client will know what url the data corresponds to. +const char kPath[] = "path"; // Note: all protocol data structures are on-the-wire format. That means that // data is stored in network-normalized order. Readers must use the @@ -176,7 +200,9 @@ enum SpdySettingsIds { SETTINGS_MAX_CONCURRENT_STREAMS = 0x4, SETTINGS_CURRENT_CWND = 0x5, // Downstream byte retransmission rate in percentage. - SETTINGS_DOWNLOAD_RETRANS_RATE = 0x6 + SETTINGS_DOWNLOAD_RETRANS_RATE = 0x6, + // Initial window size in bytes + SETTINGS_INITIAL_WINDOW_SIZE = 0x7 }; // Status codes, as used in control frames (primarily RST_STREAM). @@ -196,7 +222,7 @@ enum SpdyStatusCodes { // A SPDY stream id is a 31 bit entity. typedef uint32 SpdyStreamId; -// A SPDY priority is a number between 0 and 4. +// A SPDY priority is a number between 0 and 3 (inclusive). typedef uint8 SpdyPriority; // SPDY Priorities. (there are only 2 bits) @@ -251,20 +277,38 @@ struct SpdyRstStreamControlFrameBlock : SpdyFrameBlock { uint32 status_; }; +// A SETTINGS Control Frame structure. +struct SpdySettingsControlFrameBlock : SpdyFrameBlock { + uint32 num_entries_; + // Variable data here. +}; + // A GOAWAY Control Frame structure. struct SpdyGoAwayControlFrameBlock : SpdyFrameBlock { SpdyStreamId last_accepted_stream_id_; }; +// A HEADERS Control Frame structure. +struct SpdyHeadersControlFrameBlock : SpdyFrameBlock { + SpdyStreamId stream_id_; + uint16 unused_; +}; + +// A WINDOW_UPDATE Control Frame structure +struct SpdyWindowUpdateControlFrameBlock : SpdyFrameBlock { + SpdyStreamId stream_id_; + uint32 delta_window_size_; +}; + // A structure for the 8 bit flags and 24 bit ID fields. union SettingsFlagsAndId { uint8 flags_[4]; // 8 bits uint32 id_; // 24 bits - SettingsFlagsAndId(uint32 val) : id_(val) {}; + SettingsFlagsAndId(uint32 val) : id_(val) {} uint8 flags() const { return flags_[0]; } void set_flags(uint8 flags) { flags_[0] = flags; } - uint32 id() const { return (ntohl(id_) & kSettingsIdMask); }; + uint32 id() const { return (ntohl(id_) & kSettingsIdMask); } void set_id(uint32 id) { DCHECK_EQ(0u, (id & ~kSettingsIdMask)); id = htonl(id & kSettingsIdMask); @@ -272,18 +316,6 @@ union SettingsFlagsAndId { } }; -// A SETTINGS Control Frame structure. -struct SpdySettingsControlFrameBlock : SpdyFrameBlock { - uint32 num_entries_; - // Variable data here. -}; - -// A WINDOW_UPDATE Control Frame structure -struct SpdyWindowUpdateControlFrameBlock : SpdyFrameBlock { - SpdyStreamId stream_id_; - uint32 delta_window_size_; -}; - #pragma pack(pop) // ------------------------------------------------------------------------- @@ -561,6 +593,42 @@ class SpdyRstStreamControlFrame : public SpdyControlFrame { DISALLOW_COPY_AND_ASSIGN(SpdyRstStreamControlFrame); }; +class SpdySettingsControlFrame : public SpdyControlFrame { + public: + SpdySettingsControlFrame() : SpdyControlFrame(size()) {} + SpdySettingsControlFrame(char* data, bool owns_buffer) + : SpdyControlFrame(data, owns_buffer) {} + + uint32 num_entries() const { + return ntohl(block()->num_entries_); + } + + void set_num_entries(int val) { + mutable_block()->num_entries_ = htonl(val); + } + + int header_block_len() const { + return length() - (size() - SpdyFrame::size()); + } + + const char* header_block() const { + return reinterpret_cast<const char*>(block()) + size(); + } + + // Returns the size of the SpdySettingsControlFrameBlock structure. + // Note: this is not the size of the SpdySettingsControlFrameBlock class. + static size_t size() { return sizeof(SpdySettingsControlFrameBlock); } + + private: + const struct SpdySettingsControlFrameBlock* block() const { + return static_cast<SpdySettingsControlFrameBlock*>(frame_); + } + struct SpdySettingsControlFrameBlock* mutable_block() { + return static_cast<SpdySettingsControlFrameBlock*>(frame_); + } + DISALLOW_COPY_AND_ASSIGN(SpdySettingsControlFrame); +}; + class SpdyGoAwayControlFrame : public SpdyControlFrame { public: SpdyGoAwayControlFrame() : SpdyControlFrame(size()) {} @@ -587,20 +655,22 @@ class SpdyGoAwayControlFrame : public SpdyControlFrame { DISALLOW_COPY_AND_ASSIGN(SpdyGoAwayControlFrame); }; -class SpdySettingsControlFrame : public SpdyControlFrame { +// A HEADERS frame. +class SpdyHeadersControlFrame : public SpdyControlFrame { public: - SpdySettingsControlFrame() : SpdyControlFrame(size()) {} - SpdySettingsControlFrame(char* data, bool owns_buffer) + SpdyHeadersControlFrame() : SpdyControlFrame(size()) {} + SpdyHeadersControlFrame(char* data, bool owns_buffer) : SpdyControlFrame(data, owns_buffer) {} - uint32 num_entries() const { - return ntohl(block()->num_entries_); + SpdyStreamId stream_id() const { + return ntohl(block()->stream_id_) & kStreamIdMask; } - void set_num_entries(int val) { - mutable_block()->num_entries_ = htonl(val); + void set_stream_id(SpdyStreamId id) { + mutable_block()->stream_id_ = htonl(id & kStreamIdMask); } + // The number of bytes in the header block beyond the frame header length. int header_block_len() const { return length() - (size() - SpdyFrame::size()); } @@ -609,18 +679,18 @@ class SpdySettingsControlFrame : public SpdyControlFrame { return reinterpret_cast<const char*>(block()) + size(); } - // Returns the size of the SpdySettingsControlFrameBlock structure. - // Note: this is not the size of the SpdySettingsControlFrameBlock class. - static size_t size() { return sizeof(SpdySettingsControlFrameBlock); } + // Returns the size of the SpdyHeadersControlFrameBlock structure. + // Note: this is not the size of the SpdyHeadersControlFrame class. + static size_t size() { return sizeof(SpdyHeadersControlFrameBlock); } private: - const struct SpdySettingsControlFrameBlock* block() const { - return static_cast<SpdySettingsControlFrameBlock*>(frame_); + const struct SpdyHeadersControlFrameBlock* block() const { + return static_cast<SpdyHeadersControlFrameBlock*>(frame_); } - struct SpdySettingsControlFrameBlock* mutable_block() { - return static_cast<SpdySettingsControlFrameBlock*>(frame_); + struct SpdyHeadersControlFrameBlock* mutable_block() { + return static_cast<SpdyHeadersControlFrameBlock*>(frame_); } - DISALLOW_COPY_AND_ASSIGN(SpdySettingsControlFrame); + DISALLOW_COPY_AND_ASSIGN(SpdyHeadersControlFrame); }; // A WINDOW_UPDATE frame. diff --git a/net/spdy/spdy_proxy_client_socket.cc b/net/spdy/spdy_proxy_client_socket.cc index 8066007..1cddb5e 100644 --- a/net/spdy/spdy_proxy_client_socket.cc +++ b/net/spdy/spdy_proxy_client_socket.cc @@ -390,13 +390,18 @@ int SpdyProxyClientSocket::OnResponseReceived( const spdy::SpdyHeaderBlock& response, base::Time response_time, int status) { - // Save the response - SpdyHeadersToHttpResponse(response, &response_); + // If we've already received the reply, existing headers are too late. + // TODO(mbelshe): figure out a way to make HEADERS frames useful after the + // initial response. + if (next_state_ != STATE_READ_REPLY_COMPLETE) + return OK; - DCHECK_EQ(next_state_, STATE_READ_REPLY_COMPLETE); + // Save the response + int rv = SpdyHeadersToHttpResponse(response, &response_); + if (rv == ERR_INCOMPLETE_SPDY_HEADERS) + return rv; // More headers are coming. OnIOComplete(status); - return OK; } diff --git a/net/spdy/spdy_proxy_client_socket.h b/net/spdy/spdy_proxy_client_socket.h index 4a0747e..5f2dd6f 100644 --- a/net/spdy/spdy_proxy_client_socket.h +++ b/net/spdy/spdy_proxy_client_socket.h @@ -62,7 +62,6 @@ class SpdyProxyClientSocket : public ClientSocket, public SpdyStream::Delegate { } // ClientSocket methods: - virtual int Connect(CompletionCallback* callback); virtual void Disconnect(); virtual bool IsConnected() const; @@ -74,43 +73,21 @@ class SpdyProxyClientSocket : public ClientSocket, public SpdyStream::Delegate { virtual bool UsingTCPFastOpen() const; // Socket methods: - virtual int Read(IOBuffer* buf, int buf_len, CompletionCallback* callback); virtual int Write(IOBuffer* buf, int buf_len, CompletionCallback* callback); - virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); - virtual int GetPeerAddress(AddressList* address) const; // SpdyStream::Delegate methods: - - // Called when SYN frame has been sent. - // Returns true if no more data to be sent after SYN frame. virtual bool OnSendHeadersComplete(int status); - - // Called when stream is ready to send data. - // Returns network error code. OK when it successfully sent data. virtual int OnSendBody(); - - // Called when data has been sent. |status| indicates network error - // or number of bytes has been sent. - // Returns true if no more data to be sent. virtual bool OnSendBodyComplete(int status); - - // Called when SYN_STREAM or SYN_REPLY received. |status| indicates network - // error. Returns network error code. virtual int OnResponseReceived(const spdy::SpdyHeaderBlock& response, base::Time response_time, int status); - - // Called when data is received. virtual void OnDataReceived(const char* data, int length); - - // Called when data is sent. virtual void OnDataSent(int length); - - // Called when SpdyStream is closed. virtual void OnClose(int status); private: diff --git a/net/spdy/spdy_session.cc b/net/spdy/spdy_session.cc index 472e3ae..225c159 100644 --- a/net/spdy/spdy_session.cc +++ b/net/spdy/spdy_session.cc @@ -31,8 +31,12 @@ namespace net { NetLogSpdySynParameter::NetLogSpdySynParameter( const linked_ptr<spdy::SpdyHeaderBlock>& headers, spdy::SpdyControlFlags flags, - spdy::SpdyStreamId id) - : headers_(headers), flags_(flags), id_(id) { + spdy::SpdyStreamId id, + spdy::SpdyStreamId associated_stream) + : headers_(headers), + flags_(flags), + id_(id), + associated_stream_(associated_stream) { } NetLogSpdySynParameter::~NetLogSpdySynParameter() { @@ -49,6 +53,8 @@ Value* NetLogSpdySynParameter::ToValue() const { dict->SetInteger("flags", flags_); dict->Set("headers", headers_list); dict->SetInteger("id", id_); + if (associated_stream_) + dict->SetInteger("associated_stream", associated_stream_); return dict; } @@ -242,8 +248,8 @@ SpdySession::SpdySession(const HostPortProxyPair& host_port_proxy_pair, frames_received_(0), sent_settings_(false), received_settings_(false), - initial_send_window_size_(spdy::kInitialWindowSize), - initial_recv_window_size_(spdy::kInitialWindowSize), + initial_send_window_size_(spdy::kSpdyStreamInitialWindowSize), + initial_recv_window_size_(spdy::kSpdyStreamInitialWindowSize), net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SPDY_SESSION)) { DCHECK(HttpStreamFactory::spdy_enabled()); net_log_.BeginEvent( @@ -273,6 +279,8 @@ SpdySession::~SpdySession() { DCHECK_EQ(0u, num_active_streams()); DCHECK_EQ(0u, num_unclaimed_pushed_streams()); + DCHECK(pending_callback_map_.empty()); + RecordHistograms(); net_log_.EndEvent(NetLog::TYPE_SPDY_SESSION, NULL); @@ -292,9 +300,8 @@ net::Error SpdySession::InitializeWithSocket( is_secure_ = is_secure; certificate_error_code_ = certificate_error_code; - // This is a newly initialized session that no client should have a handle to - // yet, so there's no need to start writing data as in OnTCPConnect(), but we - // should start reading data. + // Write out any data that we might have to send, such as the settings frame. + WriteSocketLater(); net::Error error = ReadSocket(); if (error == ERR_IO_PENDING) return OK; @@ -313,14 +320,13 @@ int SpdySession::GetPushStream( // encrypted SSL socket. if (is_secure_ && certificate_error_code_ != OK && (url.SchemeIs("https") || url.SchemeIs("wss"))) { - LOG(DFATAL) << "Tried to get pushed spdy stream for secure content over an " - << "unauthenticated session."; - return certificate_error_code_; + LOG(ERROR) << "Tried to get pushed spdy stream for secure content over an " + << "unauthenticated session."; + CloseSessionOnError(static_cast<net::Error>(certificate_error_code_), true); + return ERR_SPDY_PROTOCOL_ERROR; } - const std::string& path = url.PathForRequest(); - - *stream = GetActivePushStream(path); + *stream = GetActivePushStream(url.spec()); if (stream->get()) { DCHECK(streams_pushed_and_claimed_count_ < streams_pushed_count_); streams_pushed_and_claimed_count_++; @@ -360,11 +366,14 @@ void SpdySession::ProcessPendingCreateStreams() { pending_create.priority, pending_create.spdy_stream, *pending_create.stream_net_log); + scoped_refptr<SpdyStream>* stream = pending_create.spdy_stream; + DCHECK(!ContainsKey(pending_callback_map_, stream)); + pending_callback_map_[stream] = + CallbackResultPair(pending_create.callback, error); MessageLoop::current()->PostTask( FROM_HERE, method_factory_.NewRunnableMethod( - &SpdySession::InvokeUserStreamCreationCallback, - pending_create.callback, error)); + &SpdySession::InvokeUserStreamCreationCallback, stream)); break; } } @@ -375,6 +384,12 @@ void SpdySession::ProcessPendingCreateStreams() { void SpdySession::CancelPendingCreateStreams( const scoped_refptr<SpdyStream>* spdy_stream) { + PendingCallbackMap::iterator it = pending_callback_map_.find(spdy_stream); + if (it != pending_callback_map_.end()) { + pending_callback_map_.erase(it); + return; + } + for (int i = 0;i < NUM_PRIORITIES;++i) { PendingCreateStreamQueue tmp; // Make a copy removing this trans @@ -401,9 +416,10 @@ int SpdySession::CreateStreamImpl( // encrypted SSL socket. if (is_secure_ && certificate_error_code_ != OK && (url.SchemeIs("https") || url.SchemeIs("wss"))) { - LOG(DFATAL) << "Tried to create spdy stream for secure content over an " - << "unauthenticated session."; - return certificate_error_code_; + LOG(ERROR) << "Tried to create spdy stream for secure content over an " + << "unauthenticated session."; + CloseSessionOnError(static_cast<net::Error>(certificate_error_code_), true); + return ERR_SPDY_PROTOCOL_ERROR; } const std::string& path = url.PathForRequest(); @@ -458,7 +474,7 @@ int SpdySession::WriteSynStream( net_log().AddEvent( NetLog::TYPE_SPDY_SESSION_SYN_STREAM, make_scoped_refptr( - new NetLogSpdySynParameter(headers, flags, stream_id))); + new NetLogSpdySynParameter(headers, flags, stream_id, 0))); } return ERR_IO_PENDING; @@ -983,7 +999,6 @@ void SpdySession::OnStreamFrameData(spdy::SpdyStreamId stream_id, bool SpdySession::Respond(const spdy::SpdyHeaderBlock& headers, const scoped_refptr<SpdyStream> stream) { int rv = OK; - rv = stream->OnResponseReceived(headers); if (rv < 0) { DCHECK_NE(rv, ERR_IO_PENDING); @@ -1004,24 +1019,24 @@ void SpdySession::OnSyn(const spdy::SpdySynStreamControlFrame& frame, NetLog::TYPE_SPDY_SESSION_PUSHED_SYN_STREAM, make_scoped_refptr(new NetLogSpdySynParameter( headers, static_cast<spdy::SpdyControlFlags>(frame.flags()), - stream_id))); + stream_id, associated_stream_id))); } // Server-initiated streams should have even sequence numbers. if ((stream_id & 0x1) != 0) { - LOG(ERROR) << "Received invalid OnSyn stream id " << stream_id; + LOG(WARNING) << "Received invalid OnSyn stream id " << stream_id; return; } if (IsStreamActive(stream_id)) { - LOG(ERROR) << "Received OnSyn for active stream " << stream_id; + LOG(WARNING) << "Received OnSyn for active stream " << stream_id; return; } if (associated_stream_id == 0) { - LOG(ERROR) << "Received invalid OnSyn associated stream id " - << associated_stream_id - << " for stream " << stream_id; + LOG(WARNING) << "Received invalid OnSyn associated stream id " + << associated_stream_id + << " for stream " << stream_id; ResetStream(stream_id, spdy::INVALID_STREAM); return; } @@ -1030,29 +1045,44 @@ void SpdySession::OnSyn(const spdy::SpdySynStreamControlFrame& frame, // TODO(mbelshe): DCHECK that this is a GET method? - const std::string& path = ContainsKey(*headers, "path") ? - headers->find("path")->second : ""; - // Verify that the response had a URL for us. - if (path.empty()) { + const std::string& url = ContainsKey(*headers, "url") ? + headers->find("url")->second : ""; + if (url.empty()) { + ResetStream(stream_id, spdy::PROTOCOL_ERROR); + LOG(WARNING) << "Pushed stream did not contain a url."; + return; + } + + GURL gurl(url); + if (!gurl.is_valid()) { ResetStream(stream_id, spdy::PROTOCOL_ERROR); - LOG(WARNING) << "Pushed stream did not contain a path."; + LOG(WARNING) << "Pushed stream url was invalid: " << url; return; } + // Verify we have a valid stream association. if (!IsStreamActive(associated_stream_id)) { - LOG(ERROR) << "Received OnSyn with inactive associated stream " + LOG(WARNING) << "Received OnSyn with inactive associated stream " << associated_stream_id; ResetStream(stream_id, spdy::INVALID_ASSOCIATED_STREAM); return; } - // TODO(erikchen): Actually do something with the associated id. + scoped_refptr<SpdyStream> associated_stream = + active_streams_[associated_stream_id]; + GURL associated_url(associated_stream->GetUrl()); + if (associated_url.GetOrigin() != gurl.GetOrigin()) { + LOG(WARNING) << "Rejected Cross Origin Push Stream " + << associated_stream_id; + ResetStream(stream_id, spdy::REFUSED_STREAM); + return; + } // There should not be an existing pushed stream with the same path. - PushedStreamMap::iterator it = unclaimed_pushed_streams_.find(path); + PushedStreamMap::iterator it = unclaimed_pushed_streams_.find(url); if (it != unclaimed_pushed_streams_.end()) { - LOG(ERROR) << "Received duplicate pushed stream with path: " << path; + LOG(WARNING) << "Received duplicate pushed stream with url: " << url; ResetStream(stream_id, spdy::PROTOCOL_ERROR); return; } @@ -1060,9 +1090,9 @@ void SpdySession::OnSyn(const spdy::SpdySynStreamControlFrame& frame, scoped_refptr<SpdyStream> stream( new SpdyStream(this, stream_id, true, net_log_)); - stream->set_path(path); + stream->set_path(gurl.PathForRequest()); - unclaimed_pushed_streams_[path] = stream; + unclaimed_pushed_streams_[url] = stream; ActivateStream(stream); stream->set_response_received(); @@ -1102,25 +1132,62 @@ void SpdySession::OnSynReply(const spdy::SpdySynReplyControlFrame& frame, NetLog::TYPE_SPDY_SESSION_SYN_REPLY, make_scoped_refptr(new NetLogSpdySynParameter( headers, static_cast<spdy::SpdyControlFlags>(frame.flags()), - stream_id))); + stream_id, 0))); } Respond(*headers, stream); } +void SpdySession::OnHeaders(const spdy::SpdyHeadersControlFrame& frame, + const linked_ptr<spdy::SpdyHeaderBlock>& headers) { + spdy::SpdyStreamId stream_id = frame.stream_id(); + + bool valid_stream = IsStreamActive(stream_id); + if (!valid_stream) { + // NOTE: it may just be that the stream was cancelled. + LOG(WARNING) << "Received HEADERS for invalid stream " << stream_id; + return; + } + + scoped_refptr<SpdyStream> stream = active_streams_[stream_id]; + CHECK_EQ(stream->stream_id(), stream_id); + CHECK(!stream->cancelled()); + + if (net_log().IsLoggingAllEvents()) { + net_log().AddEvent( + NetLog::TYPE_SPDY_SESSION_HEADERS, + make_scoped_refptr(new NetLogSpdySynParameter( + headers, static_cast<spdy::SpdyControlFlags>(frame.flags()), + stream_id, 0))); + } + + int rv = stream->OnHeaders(*headers); + if (rv < 0) { + DCHECK_NE(rv, ERR_IO_PENDING); + const spdy::SpdyStreamId stream_id = stream->stream_id(); + DeleteStream(stream_id, rv); + } +} + void SpdySession::OnControl(const spdy::SpdyControlFrame* frame) { const linked_ptr<spdy::SpdyHeaderBlock> headers(new spdy::SpdyHeaderBlock); uint32 type = frame->type(); - if (type == spdy::SYN_STREAM || type == spdy::SYN_REPLY) { + if (type == spdy::SYN_STREAM || + type == spdy::SYN_REPLY || + type == spdy::HEADERS) { if (!spdy_framer_.ParseHeaderBlock(frame, headers.get())) { LOG(WARNING) << "Could not parse Spdy Control Frame Header."; int stream_id = 0; - if (type == spdy::SYN_STREAM) + if (type == spdy::SYN_STREAM) { stream_id = (reinterpret_cast<const spdy::SpdySynStreamControlFrame*> (frame))->stream_id(); - if (type == spdy::SYN_REPLY) + } else if (type == spdy::SYN_REPLY) { stream_id = (reinterpret_cast<const spdy::SpdySynReplyControlFrame*> (frame))->stream_id(); + } else if (type == spdy::HEADERS) { + stream_id = (reinterpret_cast<const spdy::SpdyHeadersControlFrame*> + (frame))->stream_id(); + } if(IsStreamActive(stream_id)) ResetStream(stream_id, spdy::PROTOCOL_ERROR); return; @@ -1144,6 +1211,10 @@ void SpdySession::OnControl(const spdy::SpdyControlFrame* frame) { OnSyn(*reinterpret_cast<const spdy::SpdySynStreamControlFrame*>(frame), headers); break; + case spdy::HEADERS: + OnHeaders(*reinterpret_cast<const spdy::SpdyHeadersControlFrame*>(frame), + headers); + break; case spdy::SYN_REPLY: OnSynReply( *reinterpret_cast<const spdy::SpdySynReplyControlFrame*>(frame), @@ -1342,8 +1413,17 @@ void SpdySession::RecordHistograms() { } void SpdySession::InvokeUserStreamCreationCallback( - CompletionCallback* callback, int rv) { - callback->Run(rv); + scoped_refptr<SpdyStream>* stream) { + PendingCallbackMap::iterator it = pending_callback_map_.find(stream); + + // Exit if the request has already been cancelled. + if (it == pending_callback_map_.end()) + return; + + CompletionCallback* callback = it->second.callback; + int result = it->second.result; + pending_callback_map_.erase(it); + callback->Run(result); } } // namespace net diff --git a/net/spdy/spdy_session.h b/net/spdy/spdy_session.h index 24c8e42..210e9af 100644 --- a/net/spdy/spdy_session.h +++ b/net/spdy/spdy_session.h @@ -230,6 +230,18 @@ class SpdySession : public base::RefCounted<SpdySession>, typedef std::map<std::string, scoped_refptr<SpdyStream> > PushedStreamMap; typedef std::priority_queue<SpdyIOBuffer> OutputQueue; + struct CallbackResultPair { + CallbackResultPair() : callback(NULL), result(OK) {} + CallbackResultPair(CompletionCallback* callback_in, int result_in) + : callback(callback_in), result(result_in) {} + + CompletionCallback* callback; + int result; + }; + + typedef std::map<const scoped_refptr<SpdyStream>*, CallbackResultPair> + PendingCallbackMap; + virtual ~SpdySession(); void ProcessPendingCreateStreams(); @@ -251,6 +263,8 @@ class SpdySession : public base::RefCounted<SpdySession>, const linked_ptr<spdy::SpdyHeaderBlock>& headers); void OnSynReply(const spdy::SpdySynReplyControlFrame& frame, const linked_ptr<spdy::SpdyHeaderBlock>& headers); + void OnHeaders(const spdy::SpdyHeadersControlFrame& frame, + const linked_ptr<spdy::SpdyHeaderBlock>& headers); void OnRst(const spdy::SpdyRstStreamControlFrame& frame); void OnGoAway(const spdy::SpdyGoAwayControlFrame& frame); void OnSettings(const spdy::SpdySettingsControlFrame& frame); @@ -309,7 +323,7 @@ class SpdySession : public base::RefCounted<SpdySession>, // Invokes a user callback for stream creation. We provide this method so it // can be deferred to the MessageLoop, so we avoid re-entrancy problems. - void InvokeUserStreamCreationCallback(CompletionCallback* callback, int rv); + void InvokeUserStreamCreationCallback(scoped_refptr<SpdyStream>* stream); // Callbacks for the Spdy session. CompletionCallbackImpl<SpdySession> read_callback_; @@ -321,6 +335,11 @@ class SpdySession : public base::RefCounted<SpdySession>, // method. ScopedRunnableMethodFactory<SpdySession> method_factory_; + // Map of the SpdyStreams for which we have a pending Task to invoke a + // callback. This is necessary since, before we invoke said callback, it's + // possible that the request is cancelled. + PendingCallbackMap pending_callback_map_; + // The domain this session is connected to. const HostPortProxyPair host_port_proxy_pair_; @@ -415,7 +434,8 @@ class NetLogSpdySynParameter : public NetLog::EventParameters { public: NetLogSpdySynParameter(const linked_ptr<spdy::SpdyHeaderBlock>& headers, spdy::SpdyControlFlags flags, - spdy::SpdyStreamId id); + spdy::SpdyStreamId id, + spdy::SpdyStreamId associated_stream); virtual Value* ToValue() const; @@ -429,6 +449,7 @@ class NetLogSpdySynParameter : public NetLog::EventParameters { const linked_ptr<spdy::SpdyHeaderBlock> headers_; const spdy::SpdyControlFlags flags_; const spdy::SpdyStreamId id_; + const spdy::SpdyStreamId associated_stream_; DISALLOW_COPY_AND_ASSIGN(NetLogSpdySynParameter); }; diff --git a/net/spdy/spdy_session_unittest.cc b/net/spdy/spdy_session_unittest.cc index 0cfbe48..8b76cd9 100644 --- a/net/spdy/spdy_session_unittest.cc +++ b/net/spdy/spdy_session_unittest.cc @@ -236,6 +236,158 @@ TEST_F(SpdySessionTest, OnSettings) { EXPECT_EQ(OK, stream_releaser.WaitForResult()); } +// Start with max concurrent streams set to 1. Request two streams. When the +// first completes, have the callback close itself, which should trigger the +// second stream creation. Then cancel that one immediately. Don't crash. +// http://crbug.com/63532 +TEST_F(SpdySessionTest, CancelPendingCreateStream) { + SpdySessionDependencies session_deps; + session_deps.host_resolver->set_synchronous_mode(true); + + MockRead reads[] = { + MockRead(false, ERR_IO_PENDING) // Stall forever. + }; + + StaticSocketDataProvider data(reads, arraysize(reads), NULL, 0); + MockConnect connect_data(false, OK); + + data.set_connect_data(connect_data); + session_deps.socket_factory->AddSocketDataProvider(&data); + + SSLSocketDataProvider ssl(false, OK); + session_deps.socket_factory->AddSSLSocketDataProvider(&ssl); + + scoped_refptr<HttpNetworkSession> http_session( + SpdySessionDependencies::SpdyCreateSession(&session_deps)); + + const std::string kTestHost("www.foo.com"); + const int kTestPort = 80; + HostPortPair test_host_port_pair(kTestHost, kTestPort); + HostPortProxyPair pair(test_host_port_pair, ProxyServer::Direct()); + + // Initialize the SpdySettingsStorage with 1 max concurrent streams. + spdy::SpdySettings settings; + spdy::SettingsFlagsAndId id(spdy::SETTINGS_MAX_CONCURRENT_STREAMS); + id.set_id(spdy::SETTINGS_MAX_CONCURRENT_STREAMS); + id.set_flags(spdy::SETTINGS_FLAG_PLEASE_PERSIST); + settings.push_back(spdy::SpdySetting(id, 1)); + http_session->mutable_spdy_settings()->Set(test_host_port_pair, settings); + + // Create a session. + SpdySessionPool* spdy_session_pool(http_session->spdy_session_pool()); + EXPECT_FALSE(spdy_session_pool->HasSession(pair)); + scoped_refptr<SpdySession> session = + spdy_session_pool->Get(pair, http_session->mutable_spdy_settings(), + BoundNetLog()); + ASSERT_TRUE(spdy_session_pool->HasSession(pair)); + + scoped_refptr<TCPSocketParams> tcp_params( + new TCPSocketParams(kTestHost, kTestPort, MEDIUM, GURL(), false)); + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + EXPECT_EQ(OK, + connection->Init(test_host_port_pair.ToString(), tcp_params, MEDIUM, + NULL, http_session->tcp_socket_pool(), + BoundNetLog())); + EXPECT_EQ(OK, session->InitializeWithSocket(connection.release(), false, OK)); + + // Use scoped_ptr to let us invalidate the memory when we want to, to trigger + // a valgrind error if the callback is invoked when it's not supposed to be. + scoped_ptr<TestCompletionCallback> callback(new TestCompletionCallback); + + // Create 2 streams. First will succeed. Second will be pending. + scoped_refptr<SpdyStream> spdy_stream1; + GURL url("http://www.google.com"); + ASSERT_EQ(OK, + session->CreateStream(url, + MEDIUM, /* priority, not important */ + &spdy_stream1, + BoundNetLog(), + callback.get())); + + scoped_refptr<SpdyStream> spdy_stream2; + ASSERT_EQ(ERR_IO_PENDING, + session->CreateStream(url, + MEDIUM, /* priority, not important */ + &spdy_stream2, + BoundNetLog(), + callback.get())); + + // Release the first one, this will allow the second to be created. + spdy_stream1->Cancel(); + spdy_stream1 = NULL; + + session->CancelPendingCreateStreams(&spdy_stream2); + callback.reset(); + + // Should not crash when running the pending callback. + MessageLoop::current()->RunAllPending(); +} + +TEST_F(SpdySessionTest, SendSettingsOnNewSession) { + SpdySessionDependencies session_deps; + session_deps.host_resolver->set_synchronous_mode(true); + + MockRead reads[] = { + MockRead(false, ERR_IO_PENDING) // Stall forever. + }; + + // Create the bogus setting that we want to verify is sent out. + // Note that it will be marked as SETTINGS_FLAG_PERSISTED when sent out. But + // to set it into the SpdySettingsStorage, we need to mark as + // SETTINGS_FLAG_PLEASE_PERSIST. + spdy::SpdySettings settings; + const uint32 kBogusSettingId = 0xABAB; + const uint32 kBogusSettingValue = 0xCDCD; + spdy::SettingsFlagsAndId id(kBogusSettingId); + id.set_id(kBogusSettingId); + id.set_flags(spdy::SETTINGS_FLAG_PERSISTED); + settings.push_back(spdy::SpdySetting(id, kBogusSettingValue)); + MockConnect connect_data(false, OK); + scoped_ptr<spdy::SpdyFrame> settings_frame( + ConstructSpdySettings(settings)); + MockWrite writes[] = { + CreateMockWrite(*settings_frame), + }; + + StaticSocketDataProvider data( + reads, arraysize(reads), writes, arraysize(writes)); + data.set_connect_data(connect_data); + session_deps.socket_factory->AddSocketDataProvider(&data); + + SSLSocketDataProvider ssl(false, OK); + session_deps.socket_factory->AddSSLSocketDataProvider(&ssl); + + scoped_refptr<HttpNetworkSession> http_session( + SpdySessionDependencies::SpdyCreateSession(&session_deps)); + + const std::string kTestHost("www.foo.com"); + const int kTestPort = 80; + HostPortPair test_host_port_pair(kTestHost, kTestPort); + HostPortProxyPair pair(test_host_port_pair, ProxyServer::Direct()); + + id.set_flags(spdy::SETTINGS_FLAG_PLEASE_PERSIST); + settings.clear(); + settings.push_back(spdy::SpdySetting(id, kBogusSettingValue)); + http_session->mutable_spdy_settings()->Set(test_host_port_pair, settings); + SpdySessionPool* spdy_session_pool(http_session->spdy_session_pool()); + EXPECT_FALSE(spdy_session_pool->HasSession(pair)); + scoped_refptr<SpdySession> session = + spdy_session_pool->Get(pair, http_session->mutable_spdy_settings(), + BoundNetLog()); + EXPECT_TRUE(spdy_session_pool->HasSession(pair)); + + scoped_refptr<TCPSocketParams> tcp_params( + new TCPSocketParams(kTestHost, kTestPort, MEDIUM, GURL(), false)); + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + EXPECT_EQ(OK, + connection->Init(test_host_port_pair.ToString(), tcp_params, MEDIUM, + NULL, http_session->tcp_socket_pool(), + BoundNetLog())); + EXPECT_EQ(OK, session->InitializeWithSocket(connection.release(), false, OK)); + MessageLoop::current()->RunAllPending(); + EXPECT_TRUE(data.at_write_eof()); +} + } // namespace } // namespace net diff --git a/net/spdy/spdy_stream.cc b/net/spdy/spdy_stream.cc index f26dd6d..8a5d4a4 100644 --- a/net/spdy/spdy_stream.cc +++ b/net/spdy/spdy_stream.cc @@ -44,8 +44,8 @@ SpdyStream::SpdyStream(SpdySession* session, stream_id_(stream_id), priority_(0), stalled_by_flow_control_(false), - send_window_size_(spdy::kInitialWindowSize), - recv_window_size_(spdy::kInitialWindowSize), + send_window_size_(spdy::kSpdyStreamInitialWindowSize), + recv_window_size_(spdy::kSpdyStreamInitialWindowSize), pushed_(pushed), metrics_(Singleton<BandwidthMetrics>::get()), response_received_(false), @@ -60,14 +60,10 @@ SpdyStream::SpdyStream(SpdySession* session, net_log_(net_log), send_bytes_(0), recv_bytes_(0) { - net_log_.BeginEvent( - NetLog::TYPE_SPDY_STREAM, - make_scoped_refptr(new NetLogIntegerParameter("stream_id", stream_id_))); } SpdyStream::~SpdyStream() { UpdateHistograms(); - net_log_.EndEvent(NetLog::TYPE_SPDY_STREAM, NULL); } void SpdyStream::SetDelegate(Delegate* delegate) { @@ -88,9 +84,17 @@ void SpdyStream::PushedStreamReplayData() { if (cancelled_ || !delegate_) return; - delegate_->OnResponseReceived(*response_, response_time_, OK); - continue_buffering_data_ = false; + + int rv = delegate_->OnResponseReceived(*response_, response_time_, OK); + if (rv == ERR_INCOMPLETE_SPDY_HEADERS) { + // We don't have complete headers. Assume we're waiting for another + // HEADERS frame. Since we don't have headers, we had better not have + // any pending data frames. + DCHECK_EQ(0U, pending_buffers_.size()); + return; + } + std::vector<scoped_refptr<IOBufferWithSize> > buffers; buffers.swap(pending_buffers_); for (size_t i = 0; i < buffers.size(); ++i) { @@ -253,9 +257,44 @@ int SpdyStream::OnResponseReceived(const spdy::SpdyHeaderBlock& response) { return rv; } +int SpdyStream::OnHeaders(const spdy::SpdyHeaderBlock& headers) { + DCHECK(!response_->empty()); + + // Append all the headers into the response header block. + for (spdy::SpdyHeaderBlock::const_iterator it = headers.begin(); + it != headers.end(); ++it) { + // Disallow duplicate headers. This is just to be conservative. + if ((*response_).find(it->first) != (*response_).end()) { + LOG(WARNING) << "HEADERS duplicate header"; + response_status_ = ERR_SPDY_PROTOCOL_ERROR; + return ERR_SPDY_PROTOCOL_ERROR; + } + + (*response_)[it->first] = it->second; + } + + int rv = OK; + if (delegate_) { + rv = delegate_->OnResponseReceived(*response_, response_time_, rv); + // ERR_INCOMPLETE_SPDY_HEADERS means that we are waiting for more + // headers before the response header block is complete. + if (rv == ERR_INCOMPLETE_SPDY_HEADERS) + rv = OK; + } + return rv; +} + void SpdyStream::OnDataReceived(const char* data, int length) { DCHECK_GE(length, 0); + // If we don't have a response, then the SYN_REPLY did not come through. + // We cannot pass data up to the caller unless the reply headers have been + // received. + if (!response_received()) { + session_->CloseStream(stream_id_, ERR_SYN_REPLY_NOT_RECEIVED); + return; + } + if (!delegate_ || continue_buffering_data_) { // It should be valid for this to happen in the server push case. // We'll return received data when delegate gets attached to the stream. @@ -272,15 +311,7 @@ void SpdyStream::OnDataReceived(const char* data, int length) { return; } - CHECK(!closed()); - - // If we don't have a response, then the SYN_REPLY did not come through. - // We cannot pass data up to the caller unless the reply headers have been - // received. - if (!response_received()) { - session_->CloseStream(stream_id_, ERR_SYN_REPLY_NOT_RECEIVED); - return; - } + CHECK(!closed()); // A zero-length read means that the stream is being closed. if (!length) { @@ -366,6 +397,43 @@ bool SpdyStream::GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) { return session_->GetSSLCertRequestInfo(cert_request_info); } +bool SpdyStream::HasUrl() const { + if (pushed_) + return response_received(); + return request_.get() != NULL; +} + +GURL SpdyStream::GetUrl() const { + DCHECK(HasUrl()); + + if (pushed_) { + // assemble from the response + std::string url; + spdy::SpdyHeaderBlock::const_iterator it; + it = response_->find("url"); + if (it != (*response_).end()) + url = it->second; + return GURL(url); + } + + // assemble from the request + std::string scheme; + std::string host_port; + std::string path; + spdy::SpdyHeaderBlock::const_iterator it; + it = request_->find("scheme"); + if (it != (*request_).end()) + scheme = it->second; + it = request_->find("host"); + if (it != (*request_).end()) + host_port = it->second; + it = request_->find("path"); + if (it != (*request_).end()) + path = it->second; + std::string url = scheme + "://" + host_port + path; + return GURL(url); +} + int SpdyStream::DoLoop(int result) { do { State state = io_state_; diff --git a/net/spdy/spdy_stream.h b/net/spdy/spdy_stream.h index 4e30a60..9a055c2 100644 --- a/net/spdy/spdy_stream.h +++ b/net/spdy/spdy_stream.h @@ -13,6 +13,7 @@ #include "base/linked_ptr.h" #include "base/ref_counted.h" #include "base/scoped_ptr.h" +#include "googleurl/src/gurl.h" #include "net/base/bandwidth_metrics.h" #include "net/base/io_buffer.h" #include "net/base/net_log.h" @@ -53,8 +54,12 @@ class SpdyStream : public base::RefCounted<SpdyStream> { // Returns true if no more data to be sent. virtual bool OnSendBodyComplete(int status) = 0; - // Called when SYN_STREAM or SYN_REPLY received. |status| indicates network - // error. Returns network error code. + // Called when the SYN_STREAM, SYN_REPLY, or HEADERS frames are received. + // Normal streams will receive a SYN_REPLY and optional HEADERS frames. + // Pushed streams will receive a SYN_STREAM and optional HEADERS frames. + // Because a stream may have a SYN_* frame and multiple HEADERS frames, + // this callback may be called multiple times. + // |status| indicates network error. Returns network error code. virtual int OnResponseReceived(const spdy::SpdyHeaderBlock& response, base::Time response_time, int status) = 0; @@ -157,6 +162,10 @@ class SpdyStream : public base::RefCounted<SpdyStream> { // has been received for this stream. Returns a status code. int OnResponseReceived(const spdy::SpdyHeaderBlock& response); + // Called by the SpdySession when late-bound headers are received for a + // stream. Returns a status code. + int OnHeaders(const spdy::SpdyHeaderBlock& headers); + // Called by the SpdySession when response data has been received for this // stream. This callback may be called multiple times as data arrives // from the network, and will never be called prior to OnResponseReceived. @@ -206,6 +215,13 @@ class SpdyStream : public base::RefCounted<SpdyStream> { int response_status() const { return response_status_; } + // Returns true if the URL for this stream is known. + bool HasUrl() const; + + // Get the URL associated with this stream. Only valid when has_url() is + // true. + GURL GetUrl() const; + private: enum State { STATE_NONE, diff --git a/net/spdy/spdy_stream_unittest.cc b/net/spdy/spdy_stream_unittest.cc index cdb116f..3bf0a23 100644 --- a/net/spdy/spdy_stream_unittest.cc +++ b/net/spdy/spdy_stream_unittest.cc @@ -150,8 +150,12 @@ TEST_F(SpdyStreamTest, SendDataAfterOpen) { static const char* const kGetHeaders[] = { "method", "GET", - "url", - "http://www.google.com/", + "scheme", + "http", + "host", + "www.google.com", + "path", + "/", "version", "HTTP/1.1", }; @@ -189,7 +193,8 @@ TEST_F(SpdyStreamTest, SendDataAfterOpen) { SpdySession::SetSSLMode(false); scoped_refptr<SpdySession> session(CreateSpdySession()); - GURL url("http://www.google.com/"); + const char* kStreamUrl = "http://www.google.com/"; + GURL url(kStreamUrl); HostPortPair host_port_pair("www.google.com", 80); scoped_refptr<TCPSocketParams> tcp_params( @@ -213,11 +218,17 @@ TEST_F(SpdyStreamTest, SendDataAfterOpen) { new TestSpdyStreamDelegate(stream.get(), buf.get(), &callback)); stream->SetDelegate(delegate.get()); + EXPECT_FALSE(stream->HasUrl()); + linked_ptr<spdy::SpdyHeaderBlock> headers(new spdy::SpdyHeaderBlock); (*headers)["method"] = "GET"; - (*headers)["url"] = "http://www.google.com/"; + (*headers)["scheme"] = url.scheme(); + (*headers)["host"] = url.host(); + (*headers)["path"] = url.path(); (*headers)["version"] = "HTTP/1.1"; stream->set_spdy_headers(headers); + EXPECT_TRUE(stream->HasUrl()); + EXPECT_EQ(kStreamUrl, stream->GetUrl().spec()); EXPECT_EQ(ERR_IO_PENDING, stream->SendRequest(true)); @@ -231,4 +242,39 @@ TEST_F(SpdyStreamTest, SendDataAfterOpen) { EXPECT_TRUE(delegate->closed()); } +TEST_F(SpdyStreamTest, PushedStream) { + const char kStreamUrl[] = "http://www.google.com/"; + + SpdySessionDependencies session_deps; + session_ = SpdySessionDependencies::SpdyCreateSession(&session_deps); + SpdySessionPoolPeer pool_peer_(session_->spdy_session_pool()); + scoped_refptr<SpdySession> spdy_session(CreateSpdySession()); + BoundNetLog net_log; + + // Conjure up a stream. + scoped_refptr<SpdyStream> stream = new SpdyStream(spdy_session, + 2, + true, + net_log); + EXPECT_FALSE(stream->response_received()); + EXPECT_FALSE(stream->HasUrl()); + + // Set a couple of headers. + spdy::SpdyHeaderBlock response; + response["url"] = kStreamUrl; + stream->OnResponseReceived(response); + + // Send some basic headers. + spdy::SpdyHeaderBlock headers; + response["status"] = "200"; + response["version"] = "OK"; + stream->OnHeaders(headers); + + stream->set_response_received(); + EXPECT_TRUE(stream->response_received()); + EXPECT_TRUE(stream->HasUrl()); + EXPECT_EQ(kStreamUrl, stream->GetUrl().spec()); +} + + } // namespace net diff --git a/net/spdy/spdy_test_util.cc b/net/spdy/spdy_test_util.cc index 896cd9d..3a7f771 100644 --- a/net/spdy/spdy_test_util.cc +++ b/net/spdy/spdy_test_util.cc @@ -171,6 +171,10 @@ spdy::SpdyFrame* ConstructSpdyPacket(const SpdyHeaderInfo& header_info, case spdy::RST_STREAM: frame = framer.CreateRstStream(header_info.id, header_info.status); break; + case spdy::HEADERS: + frame = framer.CreateHeaders(header_info.id, header_info.control_flags, + header_info.compressed, &headers); + break; default: frame = framer.CreateDataFrame(header_info.id, header_info.data, header_info.data_length, @@ -330,7 +334,11 @@ spdy::SpdyFrame* ConstructSpdyGet(const char* const url, // This is so ugly. Why are we using char* in here again? std::string str_path = gurl.PathForRequest(); std::string str_scheme = gurl.scheme(); - std::string str_host = gurl.host(); // TODO(mbelshe): should have a port. + std::string str_host = gurl.host(); + if (gurl.has_port()) { + str_host += ":"; + str_host += gurl.port(); + } scoped_array<char> req(new char[str_path.size() + 1]); scoped_array<char> scheme(new char[str_scheme.size() + 1]); scoped_array<char> host(new char[str_host.size() + 1]); @@ -460,16 +468,14 @@ spdy::SpdyFrame* ConstructSpdyPush(const char* const extra_headers[], int extra_header_count, int stream_id, int associated_stream_id, - const char* path) { + const char* url) { const char* const kStandardGetHeaders[] = { "hello", "bye", - "path", - path, "status", "200 OK", "url", - path, + url, "version", "HTTP/1.1" }; @@ -489,15 +495,12 @@ spdy::SpdyFrame* ConstructSpdyPush(const char* const extra_headers[], int extra_header_count, int stream_id, int associated_stream_id, - const char* path, + const char* url, const char* status, - const char* location, - const char* url) { + const char* location) { const char* const kStandardGetHeaders[] = { "hello", "bye", - "path", - path, "status", status, "location", @@ -519,6 +522,45 @@ spdy::SpdyFrame* ConstructSpdyPush(const char* const extra_headers[], associated_stream_id); } +spdy::SpdyFrame* ConstructSpdyPush(int stream_id, + int associated_stream_id, + const char* url) { + const char* const kStandardGetHeaders[] = { + "url", + url + }; + return ConstructSpdyControlFrame(0, + 0, + false, + stream_id, + LOWEST, + spdy::SYN_STREAM, + spdy::CONTROL_FLAG_NONE, + kStandardGetHeaders, + arraysize(kStandardGetHeaders), + associated_stream_id); +} + +spdy::SpdyFrame* ConstructSpdyPushHeaders(int stream_id, + const char* const extra_headers[], + int extra_header_count) { + const char* const kStandardGetHeaders[] = { + "status", + "200 OK", + "version", + "HTTP/1.1" + }; + return ConstructSpdyControlFrame(extra_headers, + extra_header_count, + false, + stream_id, + LOWEST, + spdy::HEADERS, + spdy::CONTROL_FLAG_NONE, + kStandardGetHeaders, + arraysize(kStandardGetHeaders)); +} + // Constructs a standard SPDY SYN_REPLY packet with the specified status code. // Returns a SpdyFrame. spdy::SpdyFrame* ConstructSpdySynReplyError( @@ -580,8 +622,6 @@ spdy::SpdyFrame* ConstructSpdyGetSynReply(const char* const extra_headers[], "bye", "status", "200", - "url", - "/index.php", "version", "HTTP/1.1" }; diff --git a/net/spdy/spdy_test_util.h b/net/spdy/spdy_test_util.h index aecf08e..0a5d2e0 100644 --- a/net/spdy/spdy_test_util.h +++ b/net/spdy/spdy_test_util.h @@ -229,16 +229,22 @@ spdy::SpdyFrame* ConstructSpdyPush(const char* const extra_headers[], int extra_header_count, int stream_id, int associated_stream_id, - const char* path); + const char* url); spdy::SpdyFrame* ConstructSpdyPush(const char* const extra_headers[], int extra_header_count, int stream_id, int associated_stream_id, - const char* path, + const char* url, const char* status, - const char* location, + const char* location); +spdy::SpdyFrame* ConstructSpdyPush(int stream_id, + int associated_stream_id, const char* url); +spdy::SpdyFrame* ConstructSpdyPushHeaders(int stream_id, + const char* const extra_headers[], + int extra_header_count); + // Constructs a standard SPDY SYN_REPLY packet to match the SPDY GET. // |extra_headers| are the extra header-value pairs, which typically // will vary the most between calls. @@ -358,6 +364,7 @@ class SpdySessionDependencies { SpdySessionDependencies* session_deps) { return new HttpNetworkSession(session_deps->host_resolver.get(), NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, session_deps->proxy_service, session_deps->socket_factory.get(), @@ -371,6 +378,7 @@ class SpdySessionDependencies { SpdySessionDependencies* session_deps) { return new HttpNetworkSession(session_deps->host_resolver.get(), NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, session_deps->proxy_service, session_deps-> @@ -395,6 +403,7 @@ class SpdyURLRequestContext : public URLRequestContext { new HttpNetworkLayer(&socket_factory_, host_resolver_, NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, proxy_service_, ssl_config_service_, diff --git a/net/test/python_utils.cc b/net/test/python_utils.cc index 13438f7..0c61b48 100644 --- a/net/test/python_utils.cc +++ b/net/test/python_utils.cc @@ -70,7 +70,7 @@ bool GetPythonRunTime(FilePath* dir) { if (!PathService::Get(base::DIR_SOURCE_ROOT, dir)) return false; *dir = dir->Append(FILE_PATH_LITERAL("third_party")) - .Append(FILE_PATH_LITERAL("python_24")) + .Append(FILE_PATH_LITERAL("python_26")) .Append(FILE_PATH_LITERAL("python.exe")); #elif defined(OS_POSIX) *dir = FilePath("python"); diff --git a/net/test/test_server.cc b/net/test/test_server.cc index 0e8c461..0eaf8b5 100644 --- a/net/test/test_server.cc +++ b/net/test/test_server.cc @@ -17,11 +17,14 @@ #include "base/base64.h" #include "base/command_line.h" #include "base/debug/leak_annotations.h" +#include "base/json/json_reader.h" #include "base/file_util.h" #include "base/logging.h" #include "base/path_service.h" +#include "base/scoped_ptr.h" #include "base/string_number_conversions.h" #include "base/utf_string_conversions.h" +#include "base/values.h" #include "googleurl/src/gurl.h" #include "net/base/cert_test_util.h" #include "net/base/host_port_pair.h" @@ -186,6 +189,11 @@ const HostPortPair& TestServer::host_port_pair() const { return host_port_pair_; } +const DictionaryValue& TestServer::server_data() const { + DCHECK(started_); + return *server_data_; +} + std::string TestServer::GetScheme() const { switch (type_) { case TYPE_FTP: @@ -278,6 +286,11 @@ bool TestServer::SetPythonPath() { } third_party_dir = third_party_dir.Append(FILE_PATH_LITERAL("third_party")); + // For simplejson. (simplejson, unlike all the other python modules + // we include, doesn't have an extra 'simplejson' directory, so we + // need to include its parent directory, i.e. third_party_dir). + AppendToPythonPath(third_party_dir); + AppendToPythonPath(third_party_dir.Append(FILE_PATH_LITERAL("tlslite"))); AppendToPythonPath(third_party_dir.Append(FILE_PATH_LITERAL("pyftpdlib"))); @@ -373,4 +386,28 @@ bool TestServer::AddCommandLineArguments(CommandLine* command_line) const { return true; } +bool TestServer::ParseServerData(const std::string& server_data) { + VLOG(1) << "Server data: " << server_data; + base::JSONReader json_reader; + scoped_ptr<Value> value(json_reader.JsonToValue(server_data, true, false)); + if (!value.get() || + !value->IsType(Value::TYPE_DICTIONARY)) { + LOG(ERROR) << "Could not parse server data: " + << json_reader.GetErrorMessage(); + return false; + } + server_data_.reset(static_cast<DictionaryValue*>(value.release())); + int port = 0; + if (!server_data_->GetInteger("port", &port)) { + LOG(ERROR) << "Could not find port value"; + return false; + } + if ((port <= 0) || (port > kuint16max)) { + LOG(ERROR) << "Invalid port value: " << port; + return false; + } + host_port_pair_.set_port(port); + return true; +} + } // namespace net diff --git a/net/test/test_server.h b/net/test/test_server.h index f819365..00a8fc9 100644 --- a/net/test/test_server.h +++ b/net/test/test_server.h @@ -28,6 +28,7 @@ #endif class CommandLine; +class DictionaryValue; class GURL; namespace net { @@ -116,6 +117,7 @@ class TestServer { const FilePath& document_root() const { return document_root_; } const HostPortPair& host_port_pair() const; + const DictionaryValue& server_data() const; std::string GetScheme() const; bool GetAddressList(AddressList* address_list) const WARN_UNUSED_RESULT; @@ -146,6 +148,10 @@ class TestServer { // Waits for the server to start. Returns true on success. bool WaitToStart() WARN_UNUSED_RESULT; + // Parses the server data read from the test server. Returns true + // on success. + bool ParseServerData(const std::string& server_data) WARN_UNUSED_RESULT; + // Returns path to the root certificate. FilePath GetRootCertificatePath(); @@ -168,6 +174,9 @@ class TestServer { // Address the test server listens on. HostPortPair host_port_pair_; + // Holds the data sent from the server (e.g., port number). + scoped_ptr<DictionaryValue> server_data_; + // Handle of the Python process running the test server. base::ProcessHandle process_handle_; diff --git a/net/test/test_server_posix.cc b/net/test/test_server_posix.cc index 9c0210b..43bdb10 100644 --- a/net/test/test_server_posix.cc +++ b/net/test/test_server_posix.cc @@ -53,6 +53,42 @@ class OrphanedTestServerFilter : public base::ProcessFilter { DISALLOW_COPY_AND_ASSIGN(OrphanedTestServerFilter); }; +// Given a file descriptor, reads into |buffer| until |bytes_max| +// bytes has been read or an error has been encountered. Returns true +// if the read was successful. |remaining_time| is used as a timeout. +bool ReadData(int fd, ssize_t bytes_max, uint8* buffer, + base::TimeDelta* remaining_time) { + ssize_t bytes_read = 0; + base::Time previous_time = base::Time::Now(); + while (bytes_read < bytes_max) { + struct pollfd poll_fds[1]; + + poll_fds[0].fd = fd; + poll_fds[0].events = POLLIN | POLLPRI; + poll_fds[0].revents = 0; + + int rv = HANDLE_EINTR(poll(poll_fds, 1, + remaining_time->InMilliseconds())); + if (rv != 1) { + LOG(ERROR) << "Failed to poll for the child file descriptor."; + return false; + } + + base::Time current_time = base::Time::Now(); + base::TimeDelta elapsed_time_cycle = current_time - previous_time; + DCHECK(elapsed_time_cycle.InMilliseconds() >= 0); + *remaining_time -= elapsed_time_cycle; + previous_time = current_time; + + ssize_t num_bytes = HANDLE_EINTR(read(fd, buffer + bytes_read, + bytes_max - bytes_read)); + if (num_bytes <= 0) + return false; + bytes_read += num_bytes; + } + return true; +} + } // namespace namespace net { @@ -98,44 +134,32 @@ bool TestServer::LaunchPython(const FilePath& testserver_path) { } bool TestServer::WaitToStart() { - uint16 port; - uint8* buffer = reinterpret_cast<uint8*>(&port); - ssize_t bytes_read = 0; - ssize_t bytes_max = sizeof(port); + file_util::ScopedFD child_fd_closer(child_fd_closer_.release()); + base::TimeDelta remaining_time = base::TimeDelta::FromMilliseconds( TestTimeouts::action_max_timeout_ms()); - base::Time previous_time = base::Time::Now(); - while (bytes_read < bytes_max) { - struct pollfd poll_fds[1]; - - poll_fds[0].fd = child_fd_; - poll_fds[0].events = POLLIN | POLLPRI; - poll_fds[0].revents = 0; - - int rv = HANDLE_EINTR(poll(poll_fds, 1, remaining_time.InMilliseconds())); - if (rv != 1) { - LOG(ERROR) << "Failed to poll for the child file descriptor."; - return false; - } - base::Time current_time = base::Time::Now(); - base::TimeDelta elapsed_time_cycle = current_time - previous_time; - DCHECK(elapsed_time_cycle.InMilliseconds() >= 0); - remaining_time -= elapsed_time_cycle; - previous_time = current_time; - - ssize_t num_bytes = HANDLE_EINTR(read(child_fd_, buffer + bytes_read, - bytes_max - bytes_read)); - if (num_bytes <= 0) - break; - bytes_read += num_bytes; + uint32 server_data_len = 0; + if (!ReadData(child_fd_, sizeof(server_data_len), + reinterpret_cast<uint8*>(&server_data_len), + &remaining_time)) { + LOG(ERROR) << "Could not read server_data_len"; + return false; + } + std::string server_data(server_data_len, '\0'); + if (!ReadData(child_fd_, server_data_len, + reinterpret_cast<uint8*>(&server_data[0]), + &remaining_time)) { + LOG(ERROR) << "Could not read server_data (" << server_data_len + << " bytes)"; + return false; } - // We don't need the FD anymore. - child_fd_closer_.reset(NULL); - if (bytes_read < bytes_max) + if (!ParseServerData(server_data)) { + LOG(ERROR) << "Could not parse server_data: " << server_data; return false; - host_port_pair_.set_port(port); + } + return true; } diff --git a/net/test/test_server_win.cc b/net/test/test_server_win.cc index 64437cd..e1c54e9 100644 --- a/net/test/test_server_win.cc +++ b/net/test/test_server_win.cc @@ -75,17 +75,62 @@ bool LaunchTestServerAsJob(const CommandLine& cmdline, return true; } -void UnblockPipe(HANDLE handle, bool* unblocked) { - static const char kUnblock[] = "UNBLOCK"; +// Writes |size| bytes to |handle| and sets |*unblocked| to true. +// Used as a crude timeout mechanism by ReadData(). +void UnblockPipe(HANDLE handle, DWORD size, bool* unblocked) { + std::string unblock_data(size, '\0'); // Unblock the ReadFile in TestServer::WaitToStart by writing to the pipe. // Make sure the call succeeded, otherwise we are very likely to hang. DWORD bytes_written = 0; - CHECK(WriteFile(handle, kUnblock, arraysize(kUnblock), &bytes_written, + LOG(WARNING) << "Timeout reached; unblocking pipe by writing " + << size << " bytes"; + CHECK(WriteFile(handle, unblock_data.data(), size, &bytes_written, NULL)); - CHECK_EQ(arraysize(kUnblock), bytes_written); + CHECK_EQ(size, bytes_written); *unblocked = true; } +// Given a file handle, reads into |buffer| until |bytes_max| bytes +// has been read or an error has been encountered. Returns +// true if the read was successful. +bool ReadData(HANDLE read_fd, HANDLE write_fd, + DWORD bytes_max, uint8* buffer) { + base::Thread thread("test_server_watcher"); + if (!thread.Start()) + return false; + + // Prepare a timeout in case the server fails to start. + bool unblocked = false; + thread.message_loop()->PostDelayedTask( + FROM_HERE, + NewRunnableFunction(UnblockPipe, write_fd, bytes_max, &unblocked), + TestTimeouts::action_max_timeout_ms()); + + DWORD bytes_read = 0; + while (bytes_read < bytes_max) { + DWORD num_bytes; + if (!ReadFile(read_fd, buffer + bytes_read, bytes_max - bytes_read, + &num_bytes, NULL)) { + PLOG(ERROR) << "ReadFile failed"; + return false; + } + if (num_bytes <= 0) { + LOG(ERROR) << "ReadFile returned invalid byte count: " << num_bytes; + return false; + } + bytes_read += num_bytes; + } + + thread.Stop(); + // If the timeout kicked in, abort. + if (unblocked) { + LOG(ERROR) << "Timeout exceeded for ReadData"; + return false; + } + + return true; +} + } // namespace namespace net { @@ -96,7 +141,7 @@ bool TestServer::LaunchPython(const FilePath& testserver_path) { return false; python_exe = python_exe .Append(FILE_PATH_LITERAL("third_party")) - .Append(FILE_PATH_LITERAL("python_24")) + .Append(FILE_PATH_LITERAL("python_26")) .Append(FILE_PATH_LITERAL("python.exe")); CommandLine python_command(python_exe); @@ -146,43 +191,28 @@ bool TestServer::LaunchPython(const FilePath& testserver_path) { } bool TestServer::WaitToStart() { - base::Thread thread("test_server_watcher"); - if (!thread.Start()) - return false; - - // Prepare a timeout in case the server fails to start. - bool unblocked = false; - thread.message_loop()->PostDelayedTask(FROM_HERE, - NewRunnableFunction(UnblockPipe, child_write_fd_.Get(), &unblocked), - TestTimeouts::action_max_timeout_ms()); + ScopedHandle read_fd(child_read_fd_.Take()); + ScopedHandle write_fd(child_write_fd_.Take()); - // Try to read two bytes from the pipe indicating the ephemeral port number. - uint16 port; - uint8* buffer = reinterpret_cast<uint8*>(&port); - DWORD bytes_read = 0; - DWORD bytes_max = sizeof(port); - while (bytes_read < bytes_max) { - DWORD num_bytes; - if (!ReadFile(child_read_fd_, buffer + bytes_read, bytes_max - bytes_read, - &num_bytes, NULL)) - break; - if (num_bytes <= 0) - break; - bytes_read += num_bytes; + uint32 server_data_len = 0; + if (!ReadData(read_fd.Get(), write_fd.Get(), sizeof(server_data_len), + reinterpret_cast<uint8*>(&server_data_len))) { + LOG(ERROR) << "Could not read server_data_len"; + return false; } - thread.Stop(); - child_read_fd_.Close(); - child_write_fd_.Close(); - - // If we hit the timeout, fail. - if (unblocked) + std::string server_data(server_data_len, '\0'); + if (!ReadData(read_fd.Get(), write_fd.Get(), server_data_len, + reinterpret_cast<uint8*>(&server_data[0]))) { + LOG(ERROR) << "Could not read server_data (" << server_data_len + << " bytes)"; return false; + } - // If not enough bytes were read, fail. - if (bytes_read < bytes_max) + if (!ParseServerData(server_data)) { + LOG(ERROR) << "Could not parse server_data: " << server_data; return false; + } - host_port_pair_.set_port(port); return true; } diff --git a/net/third_party/mozilla_security_manager/nsPKCS12Blob.cpp b/net/third_party/mozilla_security_manager/nsPKCS12Blob.cpp index 35170cc..aae8d90 100644 --- a/net/third_party/mozilla_security_manager/nsPKCS12Blob.cpp +++ b/net/third_party/mozilla_security_manager/nsPKCS12Blob.cpp @@ -45,6 +45,7 @@ #include "base/crypto/scoped_nss_types.h" #include "base/logging.h" #include "base/nss_util_internal.h" +#include "base/singleton.h" #include "base/string_util.h" #include "net/base/net_errors.h" #include "net/base/x509_certificate.h" diff --git a/net/third_party/nss/README.chromium b/net/third_party/nss/README.chromium index d7f242f..7dc9de0 100644 --- a/net/third_party/nss/README.chromium +++ b/net/third_party/nss/README.chromium @@ -42,5 +42,14 @@ Patches: patches/snapstart.patch http://tools.ietf.org/html/draft-agl-tls-snapstart-00 + * Add OCSP stapling support + patches/ocspstapling.patch + + * Don't send a client certificate when renegotiating if the peer does not + request one. This only happened if the previous key exchange algorithm + was non-RSA. + patches/dheclientauth.patch + https://bugzilla.mozilla.org/show_bug.cgi?id=616757 + The ssl/bodge directory contains files taken from the NSS repo that we required for building libssl outside of its usual build environment. diff --git a/net/third_party/nss/patches/dheclientauth.patch b/net/third_party/nss/patches/dheclientauth.patch new file mode 100644 index 0000000..92d1d97 --- /dev/null +++ b/net/third_party/nss/patches/dheclientauth.patch @@ -0,0 +1,98 @@ +Index: mozilla/security/nss/lib/ssl/ssl3con.c +=================================================================== +RCS file: /cvsroot/mozilla/security/nss/lib/ssl/ssl3con.c,v +retrieving revision 1.142.2.4 +diff -u -p -u -8 -r1.142.2.4 ssl3con.c +--- mozilla/security/nss/lib/ssl/ssl3con.c 1 Sep 2010 19:47:11 -0000 1.142.2.4 ++++ mozilla/security/nss/lib/ssl/ssl3con.c 8 Dec 2010 06:55:49 -0000 +@@ -4832,24 +4832,18 @@ ssl3_SendCertificateVerify(sslSocket *ss + */ + slot = PK11_GetSlotFromPrivateKey(ss->ssl3.clientPrivateKey); + sid->u.ssl3.clAuthSeries = PK11_GetSlotSeries(slot); + sid->u.ssl3.clAuthSlotID = PK11_GetSlotID(slot); + sid->u.ssl3.clAuthModuleID = PK11_GetModuleID(slot); + sid->u.ssl3.clAuthValid = PR_TRUE; + PK11_FreeSlot(slot); + } +- /* If we're doing RSA key exchange, we're all done with the private key +- * here. Diffie-Hellman key exchanges need the client's +- * private key for the key exchange. +- */ +- if (ss->ssl3.hs.kea_def->exchKeyType == kt_rsa) { +- SECKEY_DestroyPrivateKey(ss->ssl3.clientPrivateKey); +- ss->ssl3.clientPrivateKey = NULL; +- } ++ SECKEY_DestroyPrivateKey(ss->ssl3.clientPrivateKey); ++ ss->ssl3.clientPrivateKey = NULL; + if (rv != SECSuccess) { + goto done; /* err code was set by ssl3_SignHashes */ + } + + rv = ssl3_AppendHandshakeHeader(ss, certificate_verify, buf.len + 2); + if (rv != SECSuccess) { + goto done; /* error code set by AppendHandshake */ + } +@@ -4894,16 +4888,30 @@ ssl3_HandleServerHello(sslSocket *ss, SS + goto alert_loser; + } + if (ss->ssl3.hs.ws != wait_server_hello) { + errCode = SSL_ERROR_RX_UNEXPECTED_SERVER_HELLO; + desc = unexpected_message; + goto alert_loser; + } + ++ /* clean up anything left from previous handshake. */ ++ if (ss->ssl3.clientCertChain != NULL) { ++ CERT_DestroyCertificateList(ss->ssl3.clientCertChain); ++ ss->ssl3.clientCertChain = NULL; ++ } ++ if (ss->ssl3.clientCertificate != NULL) { ++ CERT_DestroyCertificate(ss->ssl3.clientCertificate); ++ ss->ssl3.clientCertificate = NULL; ++ } ++ if (ss->ssl3.clientPrivateKey != NULL) { ++ SECKEY_DestroyPrivateKey(ss->ssl3.clientPrivateKey); ++ ss->ssl3.clientPrivateKey = NULL; ++ } ++ + temp = ssl3_ConsumeHandshakeNumber(ss, 2, &b, &length); + if (temp < 0) { + goto loser; /* alert has been sent */ + } + version = (SSL3ProtocolVersion)temp; + + /* this is appropriate since the negotiation is complete, and we only + ** know SSL 3.x. +@@ -5449,29 +5457,19 @@ ssl3_HandleCertificateRequest(sslSocket + + if (ss->ssl3.hs.ws != wait_cert_request && + ss->ssl3.hs.ws != wait_server_key) { + desc = unexpected_message; + errCode = SSL_ERROR_RX_UNEXPECTED_CERT_REQUEST; + goto alert_loser; + } + +- /* clean up anything left from previous handshake. */ +- if (ss->ssl3.clientCertChain != NULL) { +- CERT_DestroyCertificateList(ss->ssl3.clientCertChain); +- ss->ssl3.clientCertChain = NULL; +- } +- if (ss->ssl3.clientCertificate != NULL) { +- CERT_DestroyCertificate(ss->ssl3.clientCertificate); +- ss->ssl3.clientCertificate = NULL; +- } +- if (ss->ssl3.clientPrivateKey != NULL) { +- SECKEY_DestroyPrivateKey(ss->ssl3.clientPrivateKey); +- ss->ssl3.clientPrivateKey = NULL; +- } ++ PORT_Assert(ss->ssl3.clientCertChain == NULL); ++ PORT_Assert(ss->ssl3.clientCertificate == NULL); ++ PORT_Assert(ss->ssl3.clientPrivateKey == NULL); + + isTLS = (PRBool)(ss->ssl3.prSpec->version > SSL_LIBRARY_VERSION_3_0); + rv = ssl3_ConsumeHandshakeVariable(ss, &cert_types, 1, &b, &length); + if (rv != SECSuccess) + goto loser; /* malformed, alert has been sent */ + + arena = ca_list.arena = PORT_NewArena(DER_DEFAULT_CHUNKSIZE); + if (arena == NULL) diff --git a/net/third_party/nss/patches/ocspstapling.patch b/net/third_party/nss/patches/ocspstapling.patch new file mode 100644 index 0000000..13de561 --- /dev/null +++ b/net/third_party/nss/patches/ocspstapling.patch @@ -0,0 +1,508 @@ +commit aa046eb9a2f5bd6fb027a1a516c01ec2a093d287 +Author: Adam Langley <agl@chromium.org> +Date: Mon Nov 22 16:40:05 2010 -0500 + + nss: add support for OCSP stapling. + + This patch adds support in libssl for requesting and storing OCSP + stapled responses. + + BUG=none + TEST=none (yet) + + http://codereview.chromium.org/5045001 + +diff --git a/net/third_party/nss/ssl/ssl.def b/net/third_party/nss/ssl/ssl.def +index 60ebbb1..76417d0 100644 +--- a/net/third_party/nss/ssl/ssl.def ++++ b/net/third_party/nss/ssl/ssl.def +@@ -163,6 +163,7 @@ SSL_SetNextProtoNego; + ;+ global: + SSL_GetPredictedServerHelloData; + SSL_GetSnapStartResult; ++SSL_GetStapledOCSPResponse; + SSL_PeerCertificateChain; + SSL_SetPredictedPeerCertificates; + SSL_SetPredictedServerHelloData; +diff --git a/net/third_party/nss/ssl/ssl.h b/net/third_party/nss/ssl/ssl.h +index 9d3da0c..3515007 100644 +--- a/net/third_party/nss/ssl/ssl.h ++++ b/net/third_party/nss/ssl/ssl.h +@@ -148,6 +148,7 @@ SSL_IMPORT PRFileDesc *SSL_ImportFD(PRFileDesc *model, PRFileDesc *fd); + /* previous connection to the same server is required. See */ + /* SSL_GetPredictedServerHelloData, SSL_SetPredictedPeerCertificates and */ + /* SSL_SetSnapStartApplicationData. */ ++#define SSL_ENABLE_OCSP_STAPLING 24 /* Request OCSP stapling (client) */ + + #ifdef SSL_DEPRECATED_FUNCTION + /* Old deprecated function names */ +@@ -283,6 +284,23 @@ SSL_IMPORT CERTCertificate *SSL_PeerCertificate(PRFileDesc *fd); + SSL_IMPORT SECStatus SSL_PeerCertificateChain( + PRFileDesc *fd, CERTCertificate **certs, unsigned int *certs_size); + ++/* SSL_GetStapledOCSPResponse returns the OCSP response that was provided by ++ * the TLS server. The resulting data is copied to |out_data|. On entry, |*len| ++ * must contain the size of |out_data|. On exit, |*len| will contain the size ++ * of the OCSP stapled response. If the stapled response is too large to fit in ++ * |out_data| then it will be truncated. If no OCSP response was given by the ++ * server then it has zero length. ++ * ++ * You must set the SSL_ENABLE_OCSP_STAPLING option in order for OCSP responses ++ * to be provided by a server. ++ * ++ * You can call this function during the certificate verification callback or ++ * any time afterwards. ++ */ ++SSL_IMPORT SECStatus SSL_GetStapledOCSPResponse(PRFileDesc *fd, ++ unsigned char *out_data, ++ unsigned int *len); ++ + /* + ** Authenticate certificate hook. Called when a certificate comes in + ** (because of SSL_REQUIRE_CERTIFICATE in SSL_Enable) to authenticate the +diff --git a/net/third_party/nss/ssl/ssl3con.c b/net/third_party/nss/ssl/ssl3con.c +index c5ea79f..d56bb97 100644 +--- a/net/third_party/nss/ssl/ssl3con.c ++++ b/net/third_party/nss/ssl/ssl3con.c +@@ -7945,6 +7945,57 @@ ssl3_CopyPeerCertsToSID(ssl3CertNode *certs, sslSessionID *sid) + } + + /* Called from ssl3_HandleHandshakeMessage() when it has deciphered a complete ++ * ssl3 CertificateStatus message. ++ * Caller must hold Handshake and RecvBuf locks. ++ * This is always called before ssl3_HandleCertificate, even if the Certificate ++ * message is sent first. ++ */ ++static SECStatus ++ssl3_HandleCertificateStatus(sslSocket *ss, SSL3Opaque *b, PRUint32 length) ++{ ++ PRInt32 status, len; ++ int errCode; ++ SSL3AlertDescription desc; ++ ++ if (!ss->ssl3.hs.may_get_cert_status || ++ ss->ssl3.hs.ws != wait_server_cert || ++ !ss->ssl3.hs.pending_cert_msg.data || ++ ss->ssl3.hs.cert_status.data) { ++ errCode = SSL_ERROR_RX_UNEXPECTED_CERT_STATUS; ++ desc = unexpected_message; ++ goto alert_loser; ++ } ++ ++ /* Consume the CertificateStatusType enum */ ++ status = ssl3_ConsumeHandshakeNumber(ss, 1, &b, &length); ++ if (status != 1 /* ocsp */) { ++ goto format_loser; ++ } ++ ++ len = ssl3_ConsumeHandshakeNumber(ss, 3, &b, &length); ++ if (len != length) { ++ goto format_loser; ++ } ++ ++ if (SECITEM_AllocItem(NULL, &ss->ssl3.hs.cert_status, length) == NULL) { ++ return SECFailure; ++ } ++ ss->ssl3.hs.cert_status.type = siBuffer; ++ PORT_Memcpy(ss->ssl3.hs.cert_status.data, b, length); ++ ++ return SECSuccess; ++ ++format_loser: ++ errCode = SSL_ERROR_BAD_CERT_STATUS_RESPONSE_ALERT; ++ desc = bad_certificate_status_response; ++ ++alert_loser: ++ (void)SSL3_SendAlert(ss, alert_fatal, desc); ++ (void)ssl_MapLowLevelError(errCode); ++ return SECFailure; ++} ++ ++/* Called from ssl3_HandleHandshakeMessage() when it has deciphered a complete + * ssl3 Certificate message. + * Caller must hold Handshake and RecvBuf locks. + */ +@@ -8773,6 +8824,26 @@ xmit_loser: + return SECSuccess; + } + ++/* This function handles any pending Certificate messages. Certificate messages ++ * can be pending if we expect a possible CertificateStatus message to follow. ++ * ++ * This function must be called immediately after handling the ++ * CertificateStatus message, and before handling any ServerKeyExchange or ++ * CertificateRequest messages. ++ */ ++static SECStatus ++ssl3_MaybeHandlePendingCertificateMessage(sslSocket *ss) ++{ ++ SECStatus rv = SECSuccess; ++ ++ if (ss->ssl3.hs.pending_cert_msg.data) { ++ rv = ssl3_HandleCertificate(ss, ss->ssl3.hs.pending_cert_msg.data, ++ ss->ssl3.hs.pending_cert_msg.len); ++ SECITEM_FreeItem(&ss->ssl3.hs.pending_cert_msg, PR_FALSE); ++ } ++ return rv; ++} ++ + /* Called from ssl3_HandleHandshake() when it has gathered a complete ssl3 + * hanshake message. + * Caller must hold Handshake and RecvBuf locks. +@@ -8872,14 +8943,42 @@ ssl3_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *b, PRUint32 length) + rv = ssl3_HandleServerHello(ss, b, length); + break; + case certificate: ++ if (ss->ssl3.hs.may_get_cert_status) { ++ /* If we might get a CertificateStatus then we want to postpone the ++ * processing of the Certificate message until after we have ++ * processed the CertificateStatus */ ++ if (ss->ssl3.hs.pending_cert_msg.data || ++ ss->ssl3.hs.ws != wait_server_cert) { ++ (void)SSL3_SendAlert(ss, alert_fatal, unexpected_message); ++ (void)ssl_MapLowLevelError(SSL_ERROR_RX_UNEXPECTED_CERTIFICATE); ++ return SECFailure; ++ } ++ if (SECITEM_AllocItem(NULL, &ss->ssl3.hs.pending_cert_msg, ++ length) == NULL) { ++ return SECFailure; ++ } ++ ss->ssl3.hs.pending_cert_msg.type = siBuffer; ++ PORT_Memcpy(ss->ssl3.hs.pending_cert_msg.data, b, length); ++ break; ++ } + rv = ssl3_HandleCertificate(ss, b, length); + break; ++ case certificate_status: ++ rv = ssl3_HandleCertificateStatus(ss, b, length); ++ if (rv != SECSuccess) ++ break; ++ PORT_Assert(ss->ssl3.hs.pending_cert_msg.data); ++ rv = ssl3_MaybeHandlePendingCertificateMessage(ss); ++ break; + case server_key_exchange: + if (ss->sec.isServer) { + (void)SSL3_SendAlert(ss, alert_fatal, unexpected_message); + PORT_SetError(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); + return SECFailure; + } ++ rv = ssl3_MaybeHandlePendingCertificateMessage(ss); ++ if (rv != SECSuccess) ++ break; + rv = ssl3_HandleServerKeyExchange(ss, b, length); + break; + case certificate_request: +@@ -8888,6 +8987,9 @@ ssl3_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *b, PRUint32 length) + PORT_SetError(SSL_ERROR_RX_UNEXPECTED_CERT_REQUEST); + return SECFailure; + } ++ rv = ssl3_MaybeHandlePendingCertificateMessage(ss); ++ if (rv != SECSuccess) ++ break; + rv = ssl3_HandleCertificateRequest(ss, b, length); + break; + case server_hello_done: +@@ -8901,6 +9003,9 @@ ssl3_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *b, PRUint32 length) + PORT_SetError(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); + return SECFailure; + } ++ rv = ssl3_MaybeHandlePendingCertificateMessage(ss); ++ if (rv != SECSuccess) ++ break; + rv = ssl3_HandleServerHelloDone(ss); + break; + case certificate_verify: +@@ -9767,6 +9872,12 @@ ssl3_DestroySSL3Info(sslSocket *ss) + if (ss->ssl3.hs.origClientHello.data) { + SECITEM_FreeItem(&ss->ssl3.hs.origClientHello, PR_FALSE); + } ++ if (ss->ssl3.hs.pending_cert_msg.data) { ++ SECITEM_FreeItem(&ss->ssl3.hs.pending_cert_msg, PR_FALSE); ++ } ++ if (ss->ssl3.hs.cert_status.data) { ++ SECITEM_FreeItem(&ss->ssl3.hs.cert_status, PR_FALSE); ++ } + + /* free the SSL3Buffer (msg_body) */ + PORT_Free(ss->ssl3.hs.msg_body.buf); +diff --git a/net/third_party/nss/ssl/ssl3ext.c b/net/third_party/nss/ssl/ssl3ext.c +index f044e1c..b93671e 100644 +--- a/net/third_party/nss/ssl/ssl3ext.c ++++ b/net/third_party/nss/ssl/ssl3ext.c +@@ -247,6 +247,7 @@ static const ssl3HelloExtensionHandler serverHelloHandlersTLS[] = { + { ssl_session_ticket_xtn, &ssl3_ClientHandleSessionTicketXtn }, + { ssl_renegotiation_info_xtn, &ssl3_HandleRenegotiationInfoXtn }, + { ssl_next_proto_neg_xtn, &ssl3_ClientHandleNextProtoNegoXtn }, ++ { ssl_cert_status_xtn, &ssl3_ClientHandleStatusRequestXtn }, + { ssl_snap_start_xtn, &ssl3_ClientHandleSnapStartXtn }, + { -1, NULL } + }; +@@ -272,6 +273,7 @@ ssl3HelloExtensionSender clientHelloSendersTLS[SSL_MAX_EXTENSIONS] = { + #endif + { ssl_session_ticket_xtn, &ssl3_SendSessionTicketXtn }, + { ssl_next_proto_neg_xtn, &ssl3_ClientSendNextProtoNegoXtn }, ++ { ssl_cert_status_xtn, &ssl3_ClientSendStatusRequestXtn }, + { ssl_snap_start_xtn, &ssl3_SendSnapStartXtn } + /* NOTE: The Snap Start sender MUST be the last extension in the list. */ + /* any extra entries will appear as { 0, NULL } */ +@@ -659,6 +661,80 @@ ssl3_ClientSendNextProtoNegoXtn(sslSocket * ss, + return -1; + } + ++SECStatus ++ssl3_ClientHandleStatusRequestXtn(sslSocket *ss, PRUint16 ex_type, ++ SECItem *data) ++{ ++ /* If we didn't request this extension, then the server may not echo it. */ ++ if (!ss->opt.enableOCSPStapling) ++ return SECFailure; ++ ++ /* The echoed extension must be empty. */ ++ if (data->len != 0) ++ return SECFailure; ++ ++ ss->ssl3.hs.may_get_cert_status = PR_TRUE; ++ ++ /* Keep track of negotiated extensions. */ ++ ss->xtnData.negotiated[ss->xtnData.numNegotiated++] = ex_type; ++ ++ return SECSuccess; ++} ++ ++/* ssl3_ClientSendStatusRequestXtn builds the status_request extension on the ++ * client side. See RFC 4366 section 3.6. */ ++PRInt32 ++ssl3_ClientSendStatusRequestXtn(sslSocket * ss, PRBool append, ++ PRUint32 maxBytes) ++{ ++ PRInt32 extension_length; ++ ++ if (!ss->opt.enableOCSPStapling) ++ return 0; ++ ++ /* extension_type (2-bytes) + ++ * length(extension_data) (2-bytes) + ++ * status_type (1) + ++ * responder_id_list length (2) + ++ * request_extensions length (2) ++ */ ++ extension_length = 9; ++ ++ if (append && maxBytes >= extension_length) { ++ SECStatus rv; ++ TLSExtensionData *xtnData; ++ ++ /* extension_type */ ++ rv = ssl3_AppendHandshakeNumber(ss, ssl_cert_status_xtn, 2); ++ if (rv != SECSuccess) ++ return -1; ++ rv = ssl3_AppendHandshakeNumber(ss, extension_length - 4, 2); ++ if (rv != SECSuccess) ++ return -1; ++ rv = ssl3_AppendHandshakeNumber(ss, 1 /* status_type ocsp */, 1); ++ if (rv != SECSuccess) ++ return -1; ++ /* A zero length responder_id_list means that the responders are ++ * implicitly known to the server. */ ++ rv = ssl3_AppendHandshakeNumber(ss, 0, 2); ++ if (rv != SECSuccess) ++ return -1; ++ /* A zero length request_extensions means that there are no extensions. ++ * Specifically, we don't set the id-pkix-ocsp-nonce extension. This ++ * means that the server can replay a cached OCSP response to us. */ ++ rv = ssl3_AppendHandshakeNumber(ss, 0, 2); ++ if (rv != SECSuccess) ++ return -1; ++ ++ xtnData = &ss->xtnData; ++ xtnData->advertised[xtnData->numAdvertised++] = ssl_cert_status_xtn; ++ } else if (maxBytes < extension_length) { ++ PORT_Assert(0); ++ return 0; ++ } ++ return extension_length; ++} ++ + /* + * NewSessionTicket + * Called from ssl3_HandleFinished +diff --git a/net/third_party/nss/ssl/ssl3prot.h b/net/third_party/nss/ssl/ssl3prot.h +index f3c950e..aeaacdd 100644 +--- a/net/third_party/nss/ssl/ssl3prot.h ++++ b/net/third_party/nss/ssl/ssl3prot.h +@@ -158,6 +158,7 @@ typedef enum { + certificate_verify = 15, + client_key_exchange = 16, + finished = 20, ++ certificate_status = 22, + next_proto = 67 + } SSL3HandshakeType; + +diff --git a/net/third_party/nss/ssl/sslerr.h b/net/third_party/nss/ssl/sslerr.h +index bd72f97..eb56ea9 100644 +--- a/net/third_party/nss/ssl/sslerr.h ++++ b/net/third_party/nss/ssl/sslerr.h +@@ -203,6 +203,8 @@ SSL_ERROR_RX_UNEXPECTED_UNCOMPRESSED_RECORD = (SSL_ERROR_BASE + 114), + + SSL_ERROR_WEAK_SERVER_KEY = (SSL_ERROR_BASE + 115), + ++SSL_ERROR_RX_UNEXPECTED_CERT_STATUS = (SSL_ERROR_BASE + 116), ++ + SSL_ERROR_END_OF_LIST /* let the c compiler determine the value of this. */ + } SSLErrorCodes; + #endif /* NO_SECURITY_ERROR_ENUM */ +diff --git a/net/third_party/nss/ssl/sslimpl.h b/net/third_party/nss/ssl/sslimpl.h +index b84511b..c656f65 100644 +--- a/net/third_party/nss/ssl/sslimpl.h ++++ b/net/third_party/nss/ssl/sslimpl.h +@@ -350,6 +350,7 @@ typedef struct sslOptionsStr { + unsigned int requireSafeNegotiation : 1; /* 22 */ + unsigned int enableFalseStart : 1; /* 23 */ + unsigned int enableSnapStart : 1; /* 24 */ ++ unsigned int enableOCSPStapling : 1; /* 25 */ + } sslOptions; + + typedef enum { sslHandshakingUndetermined = 0, +@@ -820,6 +821,14 @@ const ssl3CipherSuiteDef *suite_def; + * when this one finishes */ + PRBool usedStepDownKey; /* we did a server key exchange. */ + PRBool sendingSCSV; /* instead of empty RI */ ++ PRBool may_get_cert_status; /* the server echoed a ++ * status_request extension so ++ * may send a CertificateStatus ++ * handshake message. */ ++ SECItem pending_cert_msg; /* a Certificate message which we ++ * save temporarily if we may get ++ * a CertificateStatus message */ ++ SECItem cert_status; /* an OCSP response */ + sslBuffer msgState; /* current state for handshake messages*/ + /* protected by recvBufLock */ + sslBuffer messages; /* Accumulated handshake messages */ +@@ -1620,6 +1629,8 @@ extern SECStatus ssl3_ClientHandleSessionTicketXtn(sslSocket *ss, + PRUint16 ex_type, SECItem *data); + extern SECStatus ssl3_ClientHandleNextProtoNegoXtn(sslSocket *ss, + PRUint16 ex_type, SECItem *data); ++extern SECStatus ssl3_ClientHandleStatusRequestXtn(sslSocket *ss, ++ PRUint16 ex_type, SECItem *data); + extern SECStatus ssl3_ServerHandleSessionTicketXtn(sslSocket *ss, + PRUint16 ex_type, SECItem *data); + extern SECStatus ssl3_ServerHandleNextProtoNegoXtn(sslSocket *ss, +@@ -1631,6 +1642,8 @@ extern SECStatus ssl3_ServerHandleNextProtoNegoXtn(sslSocket *ss, + */ + extern PRInt32 ssl3_SendSessionTicketXtn(sslSocket *ss, PRBool append, + PRUint32 maxBytes); ++extern PRInt32 ssl3_ClientSendStatusRequestXtn(sslSocket *ss, PRBool append, ++ PRUint32 maxBytes); + + /* ClientHello and ServerHello extension senders. + * The code is in ssl3ext.c. +diff --git a/net/third_party/nss/ssl/sslsock.c b/net/third_party/nss/ssl/sslsock.c +index 33e7f3e..b14a935 100644 +--- a/net/third_party/nss/ssl/sslsock.c ++++ b/net/third_party/nss/ssl/sslsock.c +@@ -185,6 +185,7 @@ static sslOptions ssl_defaults = { + 2, /* enableRenegotiation (default: requires extension) */ + PR_FALSE, /* requireSafeNegotiation */ + PR_FALSE, /* enableFalseStart */ ++ PR_FALSE, /* enableOCSPStapling */ + }; + + sslSessionIDLookupFunc ssl_sid_lookup; +@@ -746,6 +747,10 @@ SSL_OptionSet(PRFileDesc *fd, PRInt32 which, PRBool on) + ss->opt.enableSnapStart = on; + break; + ++ case SSL_ENABLE_OCSP_STAPLING: ++ ss->opt.enableOCSPStapling = on; ++ break; ++ + default: + PORT_SetError(SEC_ERROR_INVALID_ARGS); + rv = SECFailure; +@@ -811,6 +816,7 @@ SSL_OptionGet(PRFileDesc *fd, PRInt32 which, PRBool *pOn) + on = ss->opt.requireSafeNegotiation; break; + case SSL_ENABLE_FALSE_START: on = ss->opt.enableFalseStart; break; + case SSL_ENABLE_SNAP_START: on = ss->opt.enableSnapStart; break; ++ case SSL_ENABLE_OCSP_STAPLING: on = ss->opt.enableOCSPStapling; break; + + default: + PORT_SetError(SEC_ERROR_INVALID_ARGS); +@@ -863,6 +869,9 @@ SSL_OptionGetDefault(PRInt32 which, PRBool *pOn) + break; + case SSL_ENABLE_FALSE_START: on = ssl_defaults.enableFalseStart; break; + case SSL_ENABLE_SNAP_START: on = ssl_defaults.enableSnapStart; break; ++ case SSL_ENABLE_OCSP_STAPLING: ++ on = ssl_defaults.enableOCSPStapling; ++ break; + + default: + PORT_SetError(SEC_ERROR_INVALID_ARGS); +@@ -1014,6 +1023,10 @@ SSL_OptionSetDefault(PRInt32 which, PRBool on) + ssl_defaults.enableSnapStart = on; + break; + ++ case SSL_ENABLE_OCSP_STAPLING: ++ ssl_defaults.enableOCSPStapling = on; ++ break; ++ + default: + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; +@@ -1473,6 +1486,36 @@ loser: + #endif + } + ++SECStatus ++SSL_GetStapledOCSPResponse(PRFileDesc *fd, unsigned char *out_data, ++ unsigned int *len) { ++ sslSocket *ss = ssl_FindSocket(fd); ++ ++ if (!ss) { ++ SSL_DBG(("%d: SSL[%d]: bad socket in SSL_GetStapledOCSPResponse", ++ SSL_GETPID(), fd)); ++ return SECFailure; ++ } ++ ++ ssl_Get1stHandshakeLock(ss); ++ ssl_GetSSL3HandshakeLock(ss); ++ ++ if (ss->ssl3.hs.cert_status.data) { ++ unsigned int todo = ss->ssl3.hs.cert_status.len; ++ if (todo > *len) ++ todo = *len; ++ *len = ss->ssl3.hs.cert_status.len; ++ PORT_Memcpy(out_data, ss->ssl3.hs.cert_status.data, todo); ++ } else { ++ *len = 0; ++ } ++ ++ ssl_ReleaseSSL3HandshakeLock(ss); ++ ssl_Release1stHandshakeLock(ss); ++ ++ return SECSuccess; ++} ++ + /************************************************************************/ + /* The following functions are the TOP LEVEL SSL functions. + ** They all get called through the NSPRIOMethods table below. +diff --git a/net/third_party/nss/ssl/sslt.h b/net/third_party/nss/ssl/sslt.h +index 68cbf87..3fa3f9b 100644 +--- a/net/third_party/nss/ssl/sslt.h ++++ b/net/third_party/nss/ssl/sslt.h +@@ -198,6 +198,7 @@ typedef enum { + /* Update SSL_MAX_EXTENSIONS whenever a new extension type is added. */ + typedef enum { + ssl_server_name_xtn = 0, ++ ssl_cert_status_xtn = 5, + #ifdef NSS_ENABLE_ECC + ssl_elliptic_curves_xtn = 10, + ssl_ec_point_formats_xtn = 11, +@@ -208,7 +209,7 @@ typedef enum { + ssl_renegotiation_info_xtn = 0xff01 /* experimental number */ + } SSLExtensionType; + +-#define SSL_MAX_EXTENSIONS 7 ++#define SSL_MAX_EXTENSIONS 8 + + typedef enum { + /* No Snap Start handshake was attempted. */ diff --git a/net/third_party/nss/ssl/ssl.def b/net/third_party/nss/ssl/ssl.def index 60ebbb1..76417d0 100644 --- a/net/third_party/nss/ssl/ssl.def +++ b/net/third_party/nss/ssl/ssl.def @@ -163,6 +163,7 @@ SSL_SetNextProtoNego; ;+ global: SSL_GetPredictedServerHelloData; SSL_GetSnapStartResult; +SSL_GetStapledOCSPResponse; SSL_PeerCertificateChain; SSL_SetPredictedPeerCertificates; SSL_SetPredictedServerHelloData; diff --git a/net/third_party/nss/ssl/ssl.h b/net/third_party/nss/ssl/ssl.h index 9d3da0c..3515007 100644 --- a/net/third_party/nss/ssl/ssl.h +++ b/net/third_party/nss/ssl/ssl.h @@ -148,6 +148,7 @@ SSL_IMPORT PRFileDesc *SSL_ImportFD(PRFileDesc *model, PRFileDesc *fd); /* previous connection to the same server is required. See */ /* SSL_GetPredictedServerHelloData, SSL_SetPredictedPeerCertificates and */ /* SSL_SetSnapStartApplicationData. */ +#define SSL_ENABLE_OCSP_STAPLING 24 /* Request OCSP stapling (client) */ #ifdef SSL_DEPRECATED_FUNCTION /* Old deprecated function names */ @@ -283,6 +284,23 @@ SSL_IMPORT CERTCertificate *SSL_PeerCertificate(PRFileDesc *fd); SSL_IMPORT SECStatus SSL_PeerCertificateChain( PRFileDesc *fd, CERTCertificate **certs, unsigned int *certs_size); +/* SSL_GetStapledOCSPResponse returns the OCSP response that was provided by + * the TLS server. The resulting data is copied to |out_data|. On entry, |*len| + * must contain the size of |out_data|. On exit, |*len| will contain the size + * of the OCSP stapled response. If the stapled response is too large to fit in + * |out_data| then it will be truncated. If no OCSP response was given by the + * server then it has zero length. + * + * You must set the SSL_ENABLE_OCSP_STAPLING option in order for OCSP responses + * to be provided by a server. + * + * You can call this function during the certificate verification callback or + * any time afterwards. + */ +SSL_IMPORT SECStatus SSL_GetStapledOCSPResponse(PRFileDesc *fd, + unsigned char *out_data, + unsigned int *len); + /* ** Authenticate certificate hook. Called when a certificate comes in ** (because of SSL_REQUIRE_CERTIFICATE in SSL_Enable) to authenticate the diff --git a/net/third_party/nss/ssl/ssl3con.c b/net/third_party/nss/ssl/ssl3con.c index c5ea79f..f5c0880 100644 --- a/net/third_party/nss/ssl/ssl3con.c +++ b/net/third_party/nss/ssl/ssl3con.c @@ -4843,10 +4843,8 @@ ssl3_SendCertificateVerify(sslSocket *ss) &sid->u.ssl3.clPlatformAuthInfo); sid->u.ssl3.clPlatformAuthValid = PR_TRUE; } - if (ss->ssl3.hs.kea_def->exchKeyType == kt_rsa) { - ssl_FreePlatformKey(ss->ssl3.platformClientKey); - ss->ssl3.platformClientKey = (PlatformKey)NULL; - } + ssl_FreePlatformKey(ss->ssl3.platformClientKey); + ss->ssl3.platformClientKey = (PlatformKey)NULL; #else /* NSS_PLATFORM_CLIENT_AUTH */ rv = ssl3_SignHashes(&hashes, ss->ssl3.clientPrivateKey, &buf, isTLS); if (rv == SECSuccess) { @@ -4864,14 +4862,8 @@ ssl3_SendCertificateVerify(sslSocket *ss) sid->u.ssl3.clAuthValid = PR_TRUE; PK11_FreeSlot(slot); } - /* If we're doing RSA key exchange, we're all done with the private key - * here. Diffie-Hellman key exchanges need the client's - * private key for the key exchange. - */ - if (ss->ssl3.hs.kea_def->exchKeyType == kt_rsa) { - SECKEY_DestroyPrivateKey(ss->ssl3.clientPrivateKey); - ss->ssl3.clientPrivateKey = NULL; - } + SECKEY_DestroyPrivateKey(ss->ssl3.clientPrivateKey); + ss->ssl3.clientPrivateKey = NULL; #endif /* NSS_PLATFORM_CLIENT_AUTH */ if (rv != SECSuccess) { goto done; /* err code was set by ssl3_SignHashes */ @@ -5022,6 +5014,26 @@ ssl3_HandleServerHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length) desc = unexpected_message; goto alert_loser; } + + /* clean up anything left from previous handshake. */ + if (ss->ssl3.clientCertChain != NULL) { + CERT_DestroyCertificateList(ss->ssl3.clientCertChain); + ss->ssl3.clientCertChain = NULL; + } + if (ss->ssl3.clientCertificate != NULL) { + CERT_DestroyCertificate(ss->ssl3.clientCertificate); + ss->ssl3.clientCertificate = NULL; + } + if (ss->ssl3.clientPrivateKey != NULL) { + SECKEY_DestroyPrivateKey(ss->ssl3.clientPrivateKey); + ss->ssl3.clientPrivateKey = NULL; + } +#ifdef NSS_PLATFORM_CLIENT_AUTH + if (ss->ssl3.platformClientKey) { + ssl_FreePlatformKey(ss->ssl3.platformClientKey); + ss->ssl3.platformClientKey = (PlatformKey)NULL; + } +#endif /* NSS_PLATFORM_CLIENT_AUTH */ if (ss->ssl3.serverHelloPredictionData.data) SECITEM_FreeItem(&ss->ssl3.serverHelloPredictionData, PR_FALSE); @@ -5519,26 +5531,13 @@ ssl3_HandleCertificateRequest(sslSocket *ss, SSL3Opaque *b, PRUint32 length) errCode = SSL_ERROR_RX_UNEXPECTED_CERT_REQUEST; goto alert_loser; } - - /* clean up anything left from previous handshake. */ - if (ss->ssl3.clientCertChain != NULL) { - CERT_DestroyCertificateList(ss->ssl3.clientCertChain); - ss->ssl3.clientCertChain = NULL; - } - if (ss->ssl3.clientCertificate != NULL) { - CERT_DestroyCertificate(ss->ssl3.clientCertificate); - ss->ssl3.clientCertificate = NULL; - } - if (ss->ssl3.clientPrivateKey != NULL) { - SECKEY_DestroyPrivateKey(ss->ssl3.clientPrivateKey); - ss->ssl3.clientPrivateKey = NULL; - } + + PORT_Assert(ss->ssl3.clientCertChain == NULL); + PORT_Assert(ss->ssl3.clientCertificate == NULL); + PORT_Assert(ss->ssl3.clientPrivateKey == NULL); #ifdef NSS_PLATFORM_CLIENT_AUTH - if (ss->ssl3.platformClientKey) { - ssl_FreePlatformKey(ss->ssl3.platformClientKey); - ss->ssl3.platformClientKey = (PlatformKey)NULL; - } -#endif /* NSS_PLATFORM_CLIENT_AUTH */ + PORT_Assert(ss->ssl3.platformClientKey == (PlatformKey)NULL); +#endif /* NSS_PLATFORM_CLIENT_AUTH */ isTLS = (PRBool)(ss->ssl3.prSpec->version > SSL_LIBRARY_VERSION_3_0); rv = ssl3_ConsumeHandshakeVariable(ss, &cert_types, 1, &b, &length); @@ -7945,6 +7944,57 @@ ssl3_CopyPeerCertsToSID(ssl3CertNode *certs, sslSessionID *sid) } /* Called from ssl3_HandleHandshakeMessage() when it has deciphered a complete + * ssl3 CertificateStatus message. + * Caller must hold Handshake and RecvBuf locks. + * This is always called before ssl3_HandleCertificate, even if the Certificate + * message is sent first. + */ +static SECStatus +ssl3_HandleCertificateStatus(sslSocket *ss, SSL3Opaque *b, PRUint32 length) +{ + PRInt32 status, len; + int errCode; + SSL3AlertDescription desc; + + if (!ss->ssl3.hs.may_get_cert_status || + ss->ssl3.hs.ws != wait_server_cert || + !ss->ssl3.hs.pending_cert_msg.data || + ss->ssl3.hs.cert_status.data) { + errCode = SSL_ERROR_RX_UNEXPECTED_CERT_STATUS; + desc = unexpected_message; + goto alert_loser; + } + + /* Consume the CertificateStatusType enum */ + status = ssl3_ConsumeHandshakeNumber(ss, 1, &b, &length); + if (status != 1 /* ocsp */) { + goto format_loser; + } + + len = ssl3_ConsumeHandshakeNumber(ss, 3, &b, &length); + if (len != length) { + goto format_loser; + } + + if (SECITEM_AllocItem(NULL, &ss->ssl3.hs.cert_status, length) == NULL) { + return SECFailure; + } + ss->ssl3.hs.cert_status.type = siBuffer; + PORT_Memcpy(ss->ssl3.hs.cert_status.data, b, length); + + return SECSuccess; + +format_loser: + errCode = SSL_ERROR_BAD_CERT_STATUS_RESPONSE_ALERT; + desc = bad_certificate_status_response; + +alert_loser: + (void)SSL3_SendAlert(ss, alert_fatal, desc); + (void)ssl_MapLowLevelError(errCode); + return SECFailure; +} + +/* Called from ssl3_HandleHandshakeMessage() when it has deciphered a complete * ssl3 Certificate message. * Caller must hold Handshake and RecvBuf locks. */ @@ -8773,6 +8823,26 @@ xmit_loser: return SECSuccess; } +/* This function handles any pending Certificate messages. Certificate messages + * can be pending if we expect a possible CertificateStatus message to follow. + * + * This function must be called immediately after handling the + * CertificateStatus message, and before handling any ServerKeyExchange or + * CertificateRequest messages. + */ +static SECStatus +ssl3_MaybeHandlePendingCertificateMessage(sslSocket *ss) +{ + SECStatus rv = SECSuccess; + + if (ss->ssl3.hs.pending_cert_msg.data) { + rv = ssl3_HandleCertificate(ss, ss->ssl3.hs.pending_cert_msg.data, + ss->ssl3.hs.pending_cert_msg.len); + SECITEM_FreeItem(&ss->ssl3.hs.pending_cert_msg, PR_FALSE); + } + return rv; +} + /* Called from ssl3_HandleHandshake() when it has gathered a complete ssl3 * hanshake message. * Caller must hold Handshake and RecvBuf locks. @@ -8872,14 +8942,42 @@ ssl3_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *b, PRUint32 length) rv = ssl3_HandleServerHello(ss, b, length); break; case certificate: + if (ss->ssl3.hs.may_get_cert_status) { + /* If we might get a CertificateStatus then we want to postpone the + * processing of the Certificate message until after we have + * processed the CertificateStatus */ + if (ss->ssl3.hs.pending_cert_msg.data || + ss->ssl3.hs.ws != wait_server_cert) { + (void)SSL3_SendAlert(ss, alert_fatal, unexpected_message); + (void)ssl_MapLowLevelError(SSL_ERROR_RX_UNEXPECTED_CERTIFICATE); + return SECFailure; + } + if (SECITEM_AllocItem(NULL, &ss->ssl3.hs.pending_cert_msg, + length) == NULL) { + return SECFailure; + } + ss->ssl3.hs.pending_cert_msg.type = siBuffer; + PORT_Memcpy(ss->ssl3.hs.pending_cert_msg.data, b, length); + break; + } rv = ssl3_HandleCertificate(ss, b, length); break; + case certificate_status: + rv = ssl3_HandleCertificateStatus(ss, b, length); + if (rv != SECSuccess) + break; + PORT_Assert(ss->ssl3.hs.pending_cert_msg.data); + rv = ssl3_MaybeHandlePendingCertificateMessage(ss); + break; case server_key_exchange: if (ss->sec.isServer) { (void)SSL3_SendAlert(ss, alert_fatal, unexpected_message); PORT_SetError(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); return SECFailure; } + rv = ssl3_MaybeHandlePendingCertificateMessage(ss); + if (rv != SECSuccess) + break; rv = ssl3_HandleServerKeyExchange(ss, b, length); break; case certificate_request: @@ -8888,6 +8986,9 @@ ssl3_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *b, PRUint32 length) PORT_SetError(SSL_ERROR_RX_UNEXPECTED_CERT_REQUEST); return SECFailure; } + rv = ssl3_MaybeHandlePendingCertificateMessage(ss); + if (rv != SECSuccess) + break; rv = ssl3_HandleCertificateRequest(ss, b, length); break; case server_hello_done: @@ -8901,6 +9002,9 @@ ssl3_HandleHandshakeMessage(sslSocket *ss, SSL3Opaque *b, PRUint32 length) PORT_SetError(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); return SECFailure; } + rv = ssl3_MaybeHandlePendingCertificateMessage(ss); + if (rv != SECSuccess) + break; rv = ssl3_HandleServerHelloDone(ss); break; case certificate_verify: @@ -9767,6 +9871,12 @@ ssl3_DestroySSL3Info(sslSocket *ss) if (ss->ssl3.hs.origClientHello.data) { SECITEM_FreeItem(&ss->ssl3.hs.origClientHello, PR_FALSE); } + if (ss->ssl3.hs.pending_cert_msg.data) { + SECITEM_FreeItem(&ss->ssl3.hs.pending_cert_msg, PR_FALSE); + } + if (ss->ssl3.hs.cert_status.data) { + SECITEM_FreeItem(&ss->ssl3.hs.cert_status, PR_FALSE); + } /* free the SSL3Buffer (msg_body) */ PORT_Free(ss->ssl3.hs.msg_body.buf); diff --git a/net/third_party/nss/ssl/ssl3ext.c b/net/third_party/nss/ssl/ssl3ext.c index f044e1c..b93671e 100644 --- a/net/third_party/nss/ssl/ssl3ext.c +++ b/net/third_party/nss/ssl/ssl3ext.c @@ -247,6 +247,7 @@ static const ssl3HelloExtensionHandler serverHelloHandlersTLS[] = { { ssl_session_ticket_xtn, &ssl3_ClientHandleSessionTicketXtn }, { ssl_renegotiation_info_xtn, &ssl3_HandleRenegotiationInfoXtn }, { ssl_next_proto_neg_xtn, &ssl3_ClientHandleNextProtoNegoXtn }, + { ssl_cert_status_xtn, &ssl3_ClientHandleStatusRequestXtn }, { ssl_snap_start_xtn, &ssl3_ClientHandleSnapStartXtn }, { -1, NULL } }; @@ -272,6 +273,7 @@ ssl3HelloExtensionSender clientHelloSendersTLS[SSL_MAX_EXTENSIONS] = { #endif { ssl_session_ticket_xtn, &ssl3_SendSessionTicketXtn }, { ssl_next_proto_neg_xtn, &ssl3_ClientSendNextProtoNegoXtn }, + { ssl_cert_status_xtn, &ssl3_ClientSendStatusRequestXtn }, { ssl_snap_start_xtn, &ssl3_SendSnapStartXtn } /* NOTE: The Snap Start sender MUST be the last extension in the list. */ /* any extra entries will appear as { 0, NULL } */ @@ -659,6 +661,80 @@ ssl3_ClientSendNextProtoNegoXtn(sslSocket * ss, return -1; } +SECStatus +ssl3_ClientHandleStatusRequestXtn(sslSocket *ss, PRUint16 ex_type, + SECItem *data) +{ + /* If we didn't request this extension, then the server may not echo it. */ + if (!ss->opt.enableOCSPStapling) + return SECFailure; + + /* The echoed extension must be empty. */ + if (data->len != 0) + return SECFailure; + + ss->ssl3.hs.may_get_cert_status = PR_TRUE; + + /* Keep track of negotiated extensions. */ + ss->xtnData.negotiated[ss->xtnData.numNegotiated++] = ex_type; + + return SECSuccess; +} + +/* ssl3_ClientSendStatusRequestXtn builds the status_request extension on the + * client side. See RFC 4366 section 3.6. */ +PRInt32 +ssl3_ClientSendStatusRequestXtn(sslSocket * ss, PRBool append, + PRUint32 maxBytes) +{ + PRInt32 extension_length; + + if (!ss->opt.enableOCSPStapling) + return 0; + + /* extension_type (2-bytes) + + * length(extension_data) (2-bytes) + + * status_type (1) + + * responder_id_list length (2) + + * request_extensions length (2) + */ + extension_length = 9; + + if (append && maxBytes >= extension_length) { + SECStatus rv; + TLSExtensionData *xtnData; + + /* extension_type */ + rv = ssl3_AppendHandshakeNumber(ss, ssl_cert_status_xtn, 2); + if (rv != SECSuccess) + return -1; + rv = ssl3_AppendHandshakeNumber(ss, extension_length - 4, 2); + if (rv != SECSuccess) + return -1; + rv = ssl3_AppendHandshakeNumber(ss, 1 /* status_type ocsp */, 1); + if (rv != SECSuccess) + return -1; + /* A zero length responder_id_list means that the responders are + * implicitly known to the server. */ + rv = ssl3_AppendHandshakeNumber(ss, 0, 2); + if (rv != SECSuccess) + return -1; + /* A zero length request_extensions means that there are no extensions. + * Specifically, we don't set the id-pkix-ocsp-nonce extension. This + * means that the server can replay a cached OCSP response to us. */ + rv = ssl3_AppendHandshakeNumber(ss, 0, 2); + if (rv != SECSuccess) + return -1; + + xtnData = &ss->xtnData; + xtnData->advertised[xtnData->numAdvertised++] = ssl_cert_status_xtn; + } else if (maxBytes < extension_length) { + PORT_Assert(0); + return 0; + } + return extension_length; +} + /* * NewSessionTicket * Called from ssl3_HandleFinished diff --git a/net/third_party/nss/ssl/ssl3prot.h b/net/third_party/nss/ssl/ssl3prot.h index f3c950e..aeaacdd 100644 --- a/net/third_party/nss/ssl/ssl3prot.h +++ b/net/third_party/nss/ssl/ssl3prot.h @@ -158,6 +158,7 @@ typedef enum { certificate_verify = 15, client_key_exchange = 16, finished = 20, + certificate_status = 22, next_proto = 67 } SSL3HandshakeType; diff --git a/net/third_party/nss/ssl/sslerr.h b/net/third_party/nss/ssl/sslerr.h index bd72f97..eb56ea9 100644 --- a/net/third_party/nss/ssl/sslerr.h +++ b/net/third_party/nss/ssl/sslerr.h @@ -203,6 +203,8 @@ SSL_ERROR_RX_UNEXPECTED_UNCOMPRESSED_RECORD = (SSL_ERROR_BASE + 114), SSL_ERROR_WEAK_SERVER_KEY = (SSL_ERROR_BASE + 115), +SSL_ERROR_RX_UNEXPECTED_CERT_STATUS = (SSL_ERROR_BASE + 116), + SSL_ERROR_END_OF_LIST /* let the c compiler determine the value of this. */ } SSLErrorCodes; #endif /* NO_SECURITY_ERROR_ENUM */ diff --git a/net/third_party/nss/ssl/sslimpl.h b/net/third_party/nss/ssl/sslimpl.h index b84511b..c656f65 100644 --- a/net/third_party/nss/ssl/sslimpl.h +++ b/net/third_party/nss/ssl/sslimpl.h @@ -350,6 +350,7 @@ typedef struct sslOptionsStr { unsigned int requireSafeNegotiation : 1; /* 22 */ unsigned int enableFalseStart : 1; /* 23 */ unsigned int enableSnapStart : 1; /* 24 */ + unsigned int enableOCSPStapling : 1; /* 25 */ } sslOptions; typedef enum { sslHandshakingUndetermined = 0, @@ -820,6 +821,14 @@ const ssl3CipherSuiteDef *suite_def; * when this one finishes */ PRBool usedStepDownKey; /* we did a server key exchange. */ PRBool sendingSCSV; /* instead of empty RI */ + PRBool may_get_cert_status; /* the server echoed a + * status_request extension so + * may send a CertificateStatus + * handshake message. */ + SECItem pending_cert_msg; /* a Certificate message which we + * save temporarily if we may get + * a CertificateStatus message */ + SECItem cert_status; /* an OCSP response */ sslBuffer msgState; /* current state for handshake messages*/ /* protected by recvBufLock */ sslBuffer messages; /* Accumulated handshake messages */ @@ -1620,6 +1629,8 @@ extern SECStatus ssl3_ClientHandleSessionTicketXtn(sslSocket *ss, PRUint16 ex_type, SECItem *data); extern SECStatus ssl3_ClientHandleNextProtoNegoXtn(sslSocket *ss, PRUint16 ex_type, SECItem *data); +extern SECStatus ssl3_ClientHandleStatusRequestXtn(sslSocket *ss, + PRUint16 ex_type, SECItem *data); extern SECStatus ssl3_ServerHandleSessionTicketXtn(sslSocket *ss, PRUint16 ex_type, SECItem *data); extern SECStatus ssl3_ServerHandleNextProtoNegoXtn(sslSocket *ss, @@ -1631,6 +1642,8 @@ extern SECStatus ssl3_ServerHandleNextProtoNegoXtn(sslSocket *ss, */ extern PRInt32 ssl3_SendSessionTicketXtn(sslSocket *ss, PRBool append, PRUint32 maxBytes); +extern PRInt32 ssl3_ClientSendStatusRequestXtn(sslSocket *ss, PRBool append, + PRUint32 maxBytes); /* ClientHello and ServerHello extension senders. * The code is in ssl3ext.c. diff --git a/net/third_party/nss/ssl/sslsock.c b/net/third_party/nss/ssl/sslsock.c index 33e7f3e..b14a935 100644 --- a/net/third_party/nss/ssl/sslsock.c +++ b/net/third_party/nss/ssl/sslsock.c @@ -185,6 +185,7 @@ static sslOptions ssl_defaults = { 2, /* enableRenegotiation (default: requires extension) */ PR_FALSE, /* requireSafeNegotiation */ PR_FALSE, /* enableFalseStart */ + PR_FALSE, /* enableOCSPStapling */ }; sslSessionIDLookupFunc ssl_sid_lookup; @@ -746,6 +747,10 @@ SSL_OptionSet(PRFileDesc *fd, PRInt32 which, PRBool on) ss->opt.enableSnapStart = on; break; + case SSL_ENABLE_OCSP_STAPLING: + ss->opt.enableOCSPStapling = on; + break; + default: PORT_SetError(SEC_ERROR_INVALID_ARGS); rv = SECFailure; @@ -811,6 +816,7 @@ SSL_OptionGet(PRFileDesc *fd, PRInt32 which, PRBool *pOn) on = ss->opt.requireSafeNegotiation; break; case SSL_ENABLE_FALSE_START: on = ss->opt.enableFalseStart; break; case SSL_ENABLE_SNAP_START: on = ss->opt.enableSnapStart; break; + case SSL_ENABLE_OCSP_STAPLING: on = ss->opt.enableOCSPStapling; break; default: PORT_SetError(SEC_ERROR_INVALID_ARGS); @@ -863,6 +869,9 @@ SSL_OptionGetDefault(PRInt32 which, PRBool *pOn) break; case SSL_ENABLE_FALSE_START: on = ssl_defaults.enableFalseStart; break; case SSL_ENABLE_SNAP_START: on = ssl_defaults.enableSnapStart; break; + case SSL_ENABLE_OCSP_STAPLING: + on = ssl_defaults.enableOCSPStapling; + break; default: PORT_SetError(SEC_ERROR_INVALID_ARGS); @@ -1014,6 +1023,10 @@ SSL_OptionSetDefault(PRInt32 which, PRBool on) ssl_defaults.enableSnapStart = on; break; + case SSL_ENABLE_OCSP_STAPLING: + ssl_defaults.enableOCSPStapling = on; + break; + default: PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; @@ -1473,6 +1486,36 @@ loser: #endif } +SECStatus +SSL_GetStapledOCSPResponse(PRFileDesc *fd, unsigned char *out_data, + unsigned int *len) { + sslSocket *ss = ssl_FindSocket(fd); + + if (!ss) { + SSL_DBG(("%d: SSL[%d]: bad socket in SSL_GetStapledOCSPResponse", + SSL_GETPID(), fd)); + return SECFailure; + } + + ssl_Get1stHandshakeLock(ss); + ssl_GetSSL3HandshakeLock(ss); + + if (ss->ssl3.hs.cert_status.data) { + unsigned int todo = ss->ssl3.hs.cert_status.len; + if (todo > *len) + todo = *len; + *len = ss->ssl3.hs.cert_status.len; + PORT_Memcpy(out_data, ss->ssl3.hs.cert_status.data, todo); + } else { + *len = 0; + } + + ssl_ReleaseSSL3HandshakeLock(ss); + ssl_Release1stHandshakeLock(ss); + + return SECSuccess; +} + /************************************************************************/ /* The following functions are the TOP LEVEL SSL functions. ** They all get called through the NSPRIOMethods table below. diff --git a/net/third_party/nss/ssl/sslt.h b/net/third_party/nss/ssl/sslt.h index 68cbf87..3fa3f9b 100644 --- a/net/third_party/nss/ssl/sslt.h +++ b/net/third_party/nss/ssl/sslt.h @@ -198,6 +198,7 @@ typedef enum { /* Update SSL_MAX_EXTENSIONS whenever a new extension type is added. */ typedef enum { ssl_server_name_xtn = 0, + ssl_cert_status_xtn = 5, #ifdef NSS_ENABLE_ECC ssl_elliptic_curves_xtn = 10, ssl_ec_point_formats_xtn = 11, @@ -208,7 +209,7 @@ typedef enum { ssl_renegotiation_info_xtn = 0xff01 /* experimental number */ } SSLExtensionType; -#define SSL_MAX_EXTENSIONS 7 +#define SSL_MAX_EXTENSIONS 8 typedef enum { /* No Snap Start handshake was attempted. */ diff --git a/net/tools/fetch/fetch_client.cc b/net/tools/fetch/fetch_client.cc index 3bdbcbf..138bed3 100644 --- a/net/tools/fetch/fetch_client.cc +++ b/net/tools/fetch/fetch_client.cc @@ -147,7 +147,7 @@ int main(int argc, char**argv) { scoped_ptr<net::HttpAuthHandlerFactory> http_auth_handler_factory( net::HttpAuthHandlerFactory::CreateDefault(host_resolver.get())); if (use_cache) { - factory = new net::HttpCache(host_resolver.get(), NULL, proxy_service, + factory = new net::HttpCache(host_resolver.get(), NULL, NULL, proxy_service, ssl_config_service, http_auth_handler_factory.get(), NULL, NULL, net::HttpCache::DefaultBackend::InMemory(0)); } else { @@ -155,6 +155,7 @@ int main(int argc, char**argv) { net::ClientSocketFactory::GetDefaultFactory(), host_resolver.get(), NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, proxy_service, ssl_config_service, diff --git a/net/tools/flip_server/balsa_headers.cc b/net/tools/flip_server/balsa_headers.cc index 2196cd4..74364a2 100644 --- a/net/tools/flip_server/balsa_headers.cc +++ b/net/tools/flip_server/balsa_headers.cc @@ -616,7 +616,7 @@ void BalsaHeaders::SetContentLength(size_t length) { content_length_ = length; // FastUInt64ToBuffer is supposed to use a maximum of kFastToBufferSize bytes. char buffer[kFastToBufferSize]; - int len_converted = snprintf(buffer, sizeof(buffer), "%d", length); + int len_converted = snprintf(buffer, sizeof(buffer), "%zd", length); CHECK_GT(len_converted, 0); const base::StringPiece length_str(buffer, len_converted); AppendHeader(content_length, length_str); @@ -725,7 +725,7 @@ void BalsaHeaders::SetParsedResponseCodeAndUpdateFirstline( size_t parsed_response_code) { char buffer[kFastToBufferSize]; int len_converted = snprintf(buffer, sizeof(buffer), - "%d", parsed_response_code); + "%zd", parsed_response_code); CHECK_GT(len_converted, 0); SetResponseCode(base::StringPiece(buffer, len_converted)); } diff --git a/net/tools/flip_server/create_listener.cc b/net/tools/flip_server/create_listener.cc index 3538261..59a03a6 100644 --- a/net/tools/flip_server/create_listener.cc +++ b/net/tools/flip_server/create_listener.cc @@ -7,8 +7,10 @@ #include <netdb.h> // for getaddrinfo and getnameinfo #include <netinet/in.h> // for IPPROTO_*, etc. #include <stdlib.h> // for EXIT_FAILURE +#include <netinet/tcp.h> // For TCP_NODELAY #include <sys/socket.h> // for getaddrinfo and getnameinfo #include <sys/types.h> // " +#include <fcntl.h> #include <unistd.h> // for exit() #include <ostream> @@ -64,15 +66,47 @@ bool CloseSocket(int *fd, int tries) { //////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// +// Sets an FD to be nonblocking. +void SetNonBlocking(int fd) { + DCHECK_GE(fd, 0); + + int fcntl_return = fcntl(fd, F_GETFL, 0); + CHECK_NE(fcntl_return, -1) + << "error doing fcntl(fd, F_GETFL, 0) fd: " << fd + << " errno=" << errno; + + if (fcntl_return & O_NONBLOCK) + return; + + fcntl_return = fcntl(fd, F_SETFL, fcntl_return | O_NONBLOCK); + CHECK_NE(fcntl_return, -1) + << "error doing fcntl(fd, F_SETFL, fcntl_return) fd: " << fd + << " errno=" << errno; +} + +int SetDisableNagle(int fd) { + int on = 1; + int rc; + rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, + reinterpret_cast<char*>(&on), sizeof(on)); + if (rc < 0) { + close(fd); + LOG(FATAL) << "setsockopt() TCP_NODELAY: failed on fd " << fd; + return 0; + } + return 1; +} + // see header for documentation of this function. -void CreateListeningSocket(const std::string& host, - const std::string& port, - bool is_numeric_host_address, - int backlog, - int * listen_fd, - bool reuseaddr, - bool reuseport, - std::ostream* error_stream) { +int CreateListeningSocket(const std::string& host, + const std::string& port, + bool is_numeric_host_address, + int backlog, + bool reuseaddr, + bool reuseport, + bool wait_for_iface, + bool disable_nagle, + int * listen_fd ) { // start out by assuming things will fail. *listen_fd = -1; @@ -91,16 +125,15 @@ void CreateListeningSocket(const std::string& host, } hints.ai_flags |= AI_PASSIVE; - hints.ai_family = PF_INET; // we know it'll be IPv4, but if we didn't - // hints.ai_family = PF_UNSPEC; // know we'd use this. <--- + hints.ai_family = PF_INET; hints.ai_socktype = SOCK_STREAM; int err = 0; if ((err=getaddrinfo(node, service, &hints, &results))) { // gai_strerror -is- threadsafe, so we get to use it here. - *error_stream << "getaddrinfo " << " for (" << host << ":" << port + LOG(ERROR) << "getaddrinfo " << " for (" << host << ":" << port << ") " << gai_strerror(err) << "\n"; - return; + return -1; } // this will delete the addrinfo memory when we return from this function. AddrinfoGuard addrinfo_guard(results); @@ -109,9 +142,9 @@ void CreateListeningSocket(const std::string& host, results->ai_socktype, results->ai_protocol); if (sock == -1) { - *error_stream << "Unable to create socket for (" << host << ":" + LOG(ERROR) << "Unable to create socket for (" << host << ":" << port << "): " << strerror(errno) << "\n"; - return; + return -1; } if (reuseaddr) { @@ -141,39 +174,126 @@ void CreateListeningSocket(const std::string& host, } if (bind(sock, results->ai_addr, results->ai_addrlen)) { - *error_stream << "Bind was unsuccessful for (" << host << ":" - << port << "): " << strerror(errno) << "\n"; + // If we are waiting for the interface to be raised, such as in an + // HA environment, ignore reporting any errors. + int saved_errno = errno; + if ( !wait_for_iface || errno != EADDRNOTAVAIL) { + LOG(ERROR) << "Bind was unsuccessful for (" << host << ":" + << port << "): " << strerror(errno) << "\n"; + } // if we knew that we were not multithreaded, we could do the following: // " : " << strerror(errno) << "\n"; if (CloseSocket(&sock, 100)) { - return; + if ( saved_errno == EADDRNOTAVAIL ) { + return -3; + } + return -2; } else { // couldn't even close the dang socket?! - *error_stream << "Unable to close the socket.. Considering this a fatal " + LOG(ERROR) << "Unable to close the socket.. Considering this a fatal " "error, and exiting\n"; exit(EXIT_FAILURE); + return -1; + } + } + + if (disable_nagle) { + if (!SetDisableNagle(sock)) { + return -1; } } if (listen(sock, backlog)) { // listen was unsuccessful. - *error_stream << "Listen was unsuccessful for (" << host << ":" + LOG(ERROR) << "Listen was unsuccessful for (" << host << ":" << port << "): " << strerror(errno) << "\n"; // if we knew that we were not multithreaded, we could do the following: // " : " << strerror(errno) << "\n"; if (CloseSocket(&sock, 100)) { sock = -1; - return; + return -1; } else { // couldn't even close the dang socket?! - *error_stream << "Unable to close the socket.. Considering this a fatal " + LOG(FATAL) << "Unable to close the socket.. Considering this a fatal " "error, and exiting\n"; - exit(EXIT_FAILURE); } } + // If we've gotten to here, Yeay! Success! *listen_fd = sock; + + return 0; +} + +int CreateConnectedSocket( int *connect_fd, + const std::string& host, + const std::string& port, + bool is_numeric_host_address, + bool disable_nagle ) { + const char* node = NULL; + const char* service = NULL; + + *connect_fd = -1; + if (!host.empty()) + node = host.c_str(); + if (!port.empty()) + service = port.c_str(); + + struct addrinfo *results = 0; + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + + if (is_numeric_host_address) + hints.ai_flags = AI_NUMERICHOST; // iff you know the name is numeric. + hints.ai_flags |= AI_PASSIVE; + + hints.ai_family = PF_INET; + hints.ai_socktype = SOCK_STREAM; + + int err = 0; + if ((err=getaddrinfo(node, service, &hints, &results))) { + // gai_strerror -is- threadsafe, so we get to use it here. + LOG(ERROR) << "getaddrinfo for (" << node << ":" << service << "): " + << gai_strerror(err); + return -1; + } + // this will delete the addrinfo memory when we return from this function. + AddrinfoGuard addrinfo_guard(results); + + int sock = socket(results->ai_family, + results->ai_socktype, + results->ai_protocol); + if (sock == -1) { + LOG(ERROR) << "Unable to create socket for (" << node << ":" << service + << "): " << strerror( errno ); + return -1; + } + + SetNonBlocking( sock ); + + if (disable_nagle) { + if (!SetDisableNagle(sock)) { + return -1; + } + } + + int ret_val = 0; + if ( connect( sock, results->ai_addr, results->ai_addrlen ) ) { + if ( errno != EINPROGRESS ) { + LOG(ERROR) << "Connect was unsuccessful for (" << node << ":" << service + << "): " << strerror(errno); + close( sock ); + return -1; + } + } else { + ret_val = 1; + } + + // If we've gotten to here, Yeay! Success! + *connect_fd = sock; + + return ret_val; } } // namespace net diff --git a/net/tools/flip_server/create_listener.h b/net/tools/flip_server/create_listener.h index 4c0a277..3a7b16e 100644 --- a/net/tools/flip_server/create_listener.h +++ b/net/tools/flip_server/create_listener.h @@ -11,6 +11,8 @@ namespace net { +void SetNonBlocking(int fd); + // Summary: // creates a socket for listening, and bind()s and listen()s it. // Args: @@ -27,19 +29,29 @@ namespace net { // backlog - passed into listen. This is the number of pending incoming // connections a socket which is listening may have acquired before // the OS starts rejecting new incoming connections. +// reuseaddr - if true sets SO_REUSEADDR on the listening socket +// reuseport - if true sets SO_REUSEPORT on the listening socket +// wait_for_iface - A boolean indicating that CreateListeningSocket should +// block until the interface that it will bind to has been +// raised. This is intended for HA environments. +// disable_nagle - if true sets TCP_NODELAY on the listening socket. // listen_fd - this will be assigned a positive value if the socket is // successfully created, else it will be assigned -1. -// error_stream - in the case of errors, output describing the error will -// be written into error_stream. -void CreateListeningSocket(const std::string& host, - const std::string& port, - bool is_numeric_host_address, - int backlog, - int * listen_fd, - bool reuseaddr, - bool reuseport, - std::ostream* error_stream); +int CreateListeningSocket(const std::string& host, + const std::string& port, + bool is_numeric_host_address, + int backlog, + bool reuseaddr, + bool reuseport, + bool wait_for_iface, + bool disable_nagle, + int * listen_fd); +int CreateConnectedSocket(int *connect_fd, + const std::string& host, + const std::string& port, + bool is_numeric_host_address, + bool disable_nagle); } // namespace net #endif // NET_TOOLS_FLIP_SERVER_CREATE_LISTENER_H__ diff --git a/net/tools/flip_server/epoll_server.cc b/net/tools/flip_server/epoll_server.cc index a1d6ca1..f78663c 100644 --- a/net/tools/flip_server/epoll_server.cc +++ b/net/tools/flip_server/epoll_server.cc @@ -13,7 +13,6 @@ #include "base/logging.h" #include "base/timer.h" -#include "net/tools/flip_server/other_defines.h" // Design notes: An efficient implementation of ready list has the following // desirable properties: @@ -478,7 +477,8 @@ int EpollServer::NumFDsRegistered() const { void EpollServer::Wake() { char data = 'd'; // 'd' is for data. It's good enough for me. - write(write_fd_, &data, 1); + int rv = write(write_fd_, &data, 1); + DCHECK(rv == 1); } int64 EpollServer::NowInUsec() const { diff --git a/net/tools/flip_server/flip_config.h b/net/tools/flip_server/flip_config.h new file mode 100644 index 0000000..3f202f8 --- /dev/null +++ b/net/tools/flip_server/flip_config.h @@ -0,0 +1,150 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_FLIP_PROXY_CONFIG_H +#define NET_TOOLS_FLIP_PROXY_CONFIG_H +#pragma once + +#include <arpa/inet.h> // in_addr_t + +#include "base/logging.h" +#include "net/tools/flip_server/create_listener.h" + +#include <vector> +#include <string> + +using std::string; +using std::vector; + +enum FlipHandlerType { + FLIP_HANDLER_PROXY, + FLIP_HANDLER_SPDY_SERVER, + FLIP_HANDLER_HTTP_SERVER +}; + +class FlipAcceptor { +public: + enum FlipHandlerType flip_handler_type_; + string listen_ip_; + string listen_port_; + string ssl_cert_filename_; + string ssl_key_filename_; + string server_ip_; + string server_port_; + int accept_backlog_size_; + bool disable_nagle_; + int accepts_per_wake_; + int listen_fd_; + void* memory_cache_; + + FlipAcceptor(enum FlipHandlerType flip_handler_type, + string listen_ip, + string listen_port, + string ssl_cert_filename, + string ssl_key_filename, + string server_ip, + string server_port, + int accept_backlog_size, + bool disable_nagle, + int accepts_per_wake, + bool reuseport, + bool wait_for_iface, + void *memory_cache) : + flip_handler_type_(flip_handler_type), + listen_ip_(listen_ip), + listen_port_(listen_port), + ssl_cert_filename_(ssl_cert_filename), + ssl_key_filename_(ssl_key_filename), + server_ip_(server_ip), + server_port_(server_port), + accept_backlog_size_(accept_backlog_size), + disable_nagle_(disable_nagle), + accepts_per_wake_(accepts_per_wake), + memory_cache_(memory_cache) + { + VLOG(1) << "Attempting to listen on " << listen_ip_.c_str() << ":" + << listen_port_.c_str(); + while (1) { + int ret = net::CreateListeningSocket(listen_ip_, + listen_port_, + true, + accept_backlog_size_, + true, + reuseport, + wait_for_iface, + disable_nagle_, + &listen_fd_); + if ( ret == 0 ) { + break; + } else if ( ret == -3 && wait_for_iface ) { + // Binding error EADDRNOTAVAIL was encounted. We need + // to wait for the interfaces to raised. try again. + usleep(200000); + } else { + LOG(ERROR) << "Unable to create listening socket for: ret = " << ret + << ": " << listen_ip_.c_str() << ":" + << listen_port_.c_str(); + return; + } + } + net::SetNonBlocking(listen_fd_); + VLOG(1) << "Listening for spdy on port: " << listen_ip_ << ":" + << listen_port_; + } + ~FlipAcceptor () {} +}; + +class FlipConfig { +public: + std::vector <FlipAcceptor*> acceptors_; + double server_think_time_in_s_; + enum logging::LoggingDestination log_destination_; + string log_filename_; + bool forward_ip_header_enabled_; + string forward_ip_header_; + bool wait_for_iface_; + int ssl_session_expiry_; + + FlipConfig() : + server_think_time_in_s_(0), + log_destination_(logging::LOG_ONLY_TO_SYSTEM_DEBUG_LOG), + forward_ip_header_enabled_(false), + wait_for_iface_(false), + ssl_session_expiry_(300) + {} + + ~FlipConfig() {} + + void AddAcceptor(enum FlipHandlerType flip_handler_type, + string listen_ip, + string listen_port, + string ssl_cert_filename, + string ssl_key_filename, + string server_ip, + string server_port, + int accept_backlog_size, + bool disable_nagle, + int accepts_per_wake, + bool reuseport, + bool wait_for_iface, + void *memory_cache) { + // TODO(mbelshe): create a struct FlipConfigArgs{} for the arguments. + acceptors_.push_back(new FlipAcceptor(flip_handler_type, + listen_ip, + listen_port, + ssl_cert_filename, + ssl_key_filename, + server_ip, + server_port, + accept_backlog_size, + disable_nagle, + accepts_per_wake, + reuseport, + wait_for_iface, + memory_cache)); + } + +}; + +#endif diff --git a/net/tools/flip_server/flip_in_mem_edsm_server.cc b/net/tools/flip_server/flip_in_mem_edsm_server.cc index 1365818..7848861 100644 --- a/net/tools/flip_server/flip_in_mem_edsm_server.cc +++ b/net/tools/flip_server/flip_in_mem_edsm_server.cc @@ -9,6 +9,7 @@ #include <unistd.h> #include <openssl/err.h> #include <openssl/ssl.h> +#include <signal.h> #include <deque> #include <iostream> @@ -16,6 +17,7 @@ #include <vector> #include <list> +#include "base/command_line.h" #include "base/logging.h" #include "base/simple_thread.h" #include "base/timer.h" @@ -30,16 +32,14 @@ #include "net/tools/flip_server/balsa_headers.h" #include "net/tools/flip_server/balsa_visitor_interface.h" #include "net/tools/flip_server/buffer_interface.h" -#include "net/tools/flip_server/create_listener.h" #include "net/tools/flip_server/epoll_server.h" -#include "net/tools/flip_server/other_defines.h" #include "net/tools/flip_server/ring_buffer.h" #include "net/tools/flip_server/simple_buffer.h" #include "net/tools/flip_server/split.h" +#include "net/tools/flip_server/flip_config.h" //////////////////////////////////////////////////////////////////////////////// -using std::cerr; using std::deque; using std::list; using std::map; @@ -47,26 +47,24 @@ using std::ostream; using std::pair; using std::string; using std::vector; +using std::cout; //////////////////////////////////////////////////////////////////////////////// +#define IPV4_PRINTABLE_FORMAT(IP) (((IP)>>0)&0xff),(((IP)>>8)&0xff), \ + (((IP)>>16)&0xff),(((IP)>>24)&0xff) -// If set to true, then the server will act as an SSL server for both -// HTTP and SPDY); -bool FLAGS_use_ssl = true; +#define ACCEPTOR_CLIENT_IDENT acceptor_->listen_ip_ << ":" \ + << acceptor_->listen_port_ << " " +#define ACCEPTOR_SERVER_IDENT acceptor_->server_ip_ << ":" \ + << acceptor_->server_port_ << " " -// The name of the cert .pem file); -string FLAGS_ssl_cert_name = "cert.pem"; +#define NEXT_PROTO_STRING "\x06spdy/2\x08http/1.1\x08http/1.0" -// The name of the key .pem file); -string FLAGS_ssl_key_name = "key.pem"; - -// The number of responses given before the server closes the -// connection); -int32 FLAGS_response_count_until_close = 1000*1000; +#define SSL_CTX_DEFAULT_CIPHER_LIST "RC4:!aNULL:!eNULL" // If true, then disables the nagle algorithm); -bool FLAGS_no_nagle = true; +bool FLAGS_disable_nagle = true; // The number of times that accept() will be called when the // alarm goes off when the accept_using_alarm flag is set to true. @@ -74,12 +72,6 @@ bool FLAGS_no_nagle = true; // is completely drained and the accept() call returns an error); int32 FLAGS_accepts_per_wake = 0; -// The port on which the spdy server listens); -int32 FLAGS_spdy_port = 10040; - -// The port on which the http server listens); -int32 FLAGS_port = 16002; - // The size of the TCP accept backlog); int32 FLAGS_accept_backlog_size = 1024; @@ -87,7 +79,7 @@ int32 FLAGS_accept_backlog_size = 1024; string FLAGS_cache_base_dir = "."; // If true, then encode url to filename); -bool FLAGS_need_to_encode_url = true; +bool FLAGS_need_to_encode_url = false; // If set to false a single socket will be used. If set to true // then a new socket will be created for each accept thread. @@ -105,9 +97,6 @@ bool FLAGS_use_xsub = false; // Does the server send X-Associated-Content headers); bool FLAGS_use_xac = false; -// Does the server advance cwnd by sending no-op packets); -bool FLAGS_use_cwnd_opener = false; - // Does the server compress data frames); bool FLAGS_use_compression = false; @@ -115,8 +104,6 @@ bool FLAGS_use_compression = false; using base::StringPiece; using base::SimpleThread; -// using base::Lock; // heh, this isn't in base namespace?! -// using base::AutoLock; // ditto! using net::BalsaFrame; using net::BalsaFrameEnums; using net::BalsaHeaders; @@ -137,6 +124,7 @@ using spdy::RST_STREAM; using spdy::SYN_REPLY; using spdy::SYN_STREAM; using spdy::SpdyControlFrame; +using spdy::SpdySettingsControlFrame; using spdy::SpdyDataFlags; using spdy::SpdyDataFrame; using spdy::SpdyRstStreamControlFrame; @@ -149,6 +137,7 @@ using spdy::SpdyStreamId; using spdy::SpdySynReplyControlFrame; using spdy::SpdySynStreamControlFrame; +FlipConfig g_proxy_config; //////////////////////////////////////////////////////////////////////////////// @@ -156,11 +145,21 @@ void PrintSslError() { char buf[128]; // this buffer must be at least 120 chars long. int error_num = ERR_get_error(); while (error_num != 0) { - LOG(INFO)<< ERR_error_string(error_num, buf); + LOG(ERROR) << ERR_error_string(error_num, buf); error_num = ERR_get_error(); } } +static int ssl_set_npn_callback(SSL *s, + const unsigned char **data, + unsigned int *len, + void *arg) +{ + VLOG(1) << "SSL NPN callback: advertising protocols."; + *data = (const unsigned char *) NEXT_PROTO_STRING; + *len = strlen(NEXT_PROTO_STRING); + return SSL_TLSEXT_ERR_OK; +} //////////////////////////////////////////////////////////////////////////////// // Creates a socket with domain, type and protocol parameters. @@ -172,26 +171,6 @@ int CreateSocket(int domain, int type, int protocol, int *fd) { return (*fd == -1) ? errno : 0; } -//////////////////////////////////////////////////////////////////////////////// - -// Sets an FD to be nonblocking. -void SetNonBlocking(int fd) { - DCHECK(fd >= 0); - - int fcntl_return = fcntl(fd, F_GETFL, 0); - CHECK_NE(fcntl_return, -1) - << "error doing fcntl(fd, F_GETFL, 0) fd: " << fd - << " errno=" << errno; - - if (fcntl_return & O_NONBLOCK) - return; - - fcntl_return = fcntl(fd, F_SETFL, fcntl_return | O_NONBLOCK); - CHECK_NE(fcntl_return, -1) - << "error doing fcntl(fd, F_SETFL, fcntl_return) fd: " << fd - << " errno=" << errno; -} - // Encode the URL. string EncodeURL(string uri, string host, string method) { if (!FLAGS_need_to_encode_url) { @@ -212,20 +191,16 @@ string EncodeURL(string uri, string host, string method) { //////////////////////////////////////////////////////////////////////////////// - -struct GlobalSSLState { +struct SSLState { SSL_METHOD* ssl_method; SSL_CTX* ssl_ctx; }; -//////////////////////////////////////////////////////////////////////////////// - -GlobalSSLState* global_ssl_state = NULL; - -//////////////////////////////////////////////////////////////////////////////// - // SSL stuff -void spdy_init_ssl(GlobalSSLState* state) { +void spdy_init_ssl(SSLState* state, + string ssl_cert_name, + string ssl_key_name, + bool use_npn) { SSL_library_init(); PrintSslError(); @@ -241,13 +216,13 @@ void spdy_init_ssl(GlobalSSLState* state) { // Disable SSLv2 support. SSL_CTX_set_options(state->ssl_ctx, SSL_OP_NO_SSLv2); if (SSL_CTX_use_certificate_file(state->ssl_ctx, - FLAGS_ssl_cert_name.c_str(), + ssl_cert_name.c_str(), SSL_FILETYPE_PEM) <= 0) { PrintSslError(); LOG(FATAL) << "Unable to use cert.pem as SSL cert."; } if (SSL_CTX_use_PrivateKey_file(state->ssl_ctx, - FLAGS_ssl_key_name.c_str(), + ssl_key_name.c_str(), SSL_FILETYPE_PEM) <= 0) { PrintSslError(); LOG(FATAL) << "Unable to use key.pem as SSL key."; @@ -256,6 +231,21 @@ void spdy_init_ssl(GlobalSSLState* state) { PrintSslError(); LOG(FATAL) << "The cert.pem and key.pem files don't match"; } + if (use_npn) { + SSL_CTX_set_next_protos_advertised_cb(state->ssl_ctx, + ssl_set_npn_callback, NULL); + } + VLOG(1) << "SSL CTX default cipher list: " << SSL_CTX_DEFAULT_CIPHER_LIST; + SSL_CTX_set_cipher_list(state->ssl_ctx, SSL_CTX_DEFAULT_CIPHER_LIST); + + VLOG(1) << "SSL CTX session expiry: " << g_proxy_config.ssl_session_expiry_ + << " seconds"; + SSL_CTX_set_timeout(state->ssl_ctx, g_proxy_config.ssl_session_expiry_); + +#ifdef SSL_MODE_RELEASE_BUFFERS + VLOG(1) << "SSL CTX: Setting Release Buffers mode."; + SSL_CTX_set_mode(state->ssl_ctx, SSL_MODE_RELEASE_BUFFERS); +#endif } SSL* spdy_new_ssl(SSL_CTX* ssl_ctx) { @@ -264,6 +254,7 @@ SSL* spdy_new_ssl(SSL_CTX* ssl_ctx) { SSL_set_accept_state(ssl); PrintSslError(); + return ssl; } @@ -404,7 +395,6 @@ class MemoryCache { } void AddFiles() { - LOG(INFO) << "Adding files!"; deque<string> paths; cwd_ = FLAGS_cache_base_dir; paths.push_back(cwd_ + "/GET_"); @@ -580,7 +570,6 @@ class MemoryCache { LOG(ERROR) << "Skipping subresource with unknown content-type"; return; } - // Now, lets see if this is the same host or not bool same_host = (UrlUtilities::GetUrlHost(referrer) == UrlUtilities::GetUrlHost(url)); @@ -673,8 +662,6 @@ class MemoryCache { } }; -//////////////////////////////////////////////////////////////////////////////// - class NotifierInterface { public: virtual ~NotifierInterface() {} @@ -683,16 +670,33 @@ class NotifierInterface { //////////////////////////////////////////////////////////////////////////////// +class SMConnectionPoolInterface; + class SMInterface { public: - virtual size_t ProcessInput(const char* data, size_t len) = 0; + virtual void InitSMInterface(SMInterface* sm_other_interface, + int32 server_idx) = 0; + virtual void InitSMConnection(SMConnectionPoolInterface* connection_pool, + SMInterface* sm_interface, + EpollServer* epoll_server, + int fd, + bool use_ssl) = 0; + virtual size_t ProcessReadInput(const char* data, size_t len) = 0; + virtual size_t ProcessWriteInput(const char* data, size_t len) = 0; + virtual void SetStreamID(uint32 stream_id) = 0; virtual bool MessageFullyRead() const = 0; virtual bool Error() const = 0; virtual const char* ErrorAsString() const = 0; virtual void Reset() = 0; + virtual void ResetForNewInterface(int32 server_idx) = 0; + // ResetForNewConnection is used for interfaces which control SMConnection + // objects. When called an interface may put its connection object into + // a reusable instance pool. Currently this is what the HttpSM interface + // does. virtual void ResetForNewConnection() = 0; + virtual void Cleanup() = 0; - virtual void PostAcceptHook() = 0; + virtual int PostAcceptHook() = 0; virtual void NewStream(uint32 stream_id, uint32 priority, const string& filename) = 0; @@ -709,65 +713,98 @@ class SMInterface { virtual ~SMInterface() {} }; -//////////////////////////////////////////////////////////////////////////////// +class SMConnectionInterface { + public: + virtual ~SMConnectionInterface() {} + virtual void ReadyToSend() = 0; + virtual EpollServer* epoll_server() = 0; +}; -class SMServerConnection; -typedef SMInterface*(SMInterfaceFactory)(SMServerConnection*); - -//////////////////////////////////////////////////////////////////////////////// +class HttpSM; +class SMConnection; typedef list<DataFrame> OutputList; -//////////////////////////////////////////////////////////////////////////////// - -class SMServerConnection; - -class SMServerConnectionPoolInterface { +class SMConnectionPoolInterface { public: - virtual ~SMServerConnectionPoolInterface() {} - // SMServerConnections will use this: - virtual void SMServerConnectionDone(SMServerConnection* connection) = 0; + virtual ~SMConnectionPoolInterface() {} + // SMConnections will use this: + virtual void SMConnectionDone(SMConnection* connection) = 0; }; -//////////////////////////////////////////////////////////////////////////////// - -class SMServerConnection: public EpollCallbackInterface, - public NotifierInterface { - private: - SMServerConnection(SMInterfaceFactory* sm_interface_factory, - MemoryCache* memory_cache, - EpollServer* epoll_server) : - fd_(-1), - events_(0), - - registered_in_epoll_server_(false), - initialized_(false), +SMInterface* NewStreamerSM(SMConnection* connection, + SMInterface* sm_interface, + EpollServer* epoll_server, + FlipAcceptor* acceptor); - connection_pool_(NULL), - epoll_server_(epoll_server), +SMInterface* NewSpdySM(SMConnection* connection, + SMInterface* sm_interface, + EpollServer* epoll_server, + MemoryCache* memory_cache, + FlipAcceptor* acceptor); - read_buffer_(4096*10), - memory_cache_(memory_cache), - sm_interface_(sm_interface_factory(this)), +SMInterface* NewHttpSM(SMConnection* connection, + SMInterface* sm_interface, + EpollServer* epoll_server, + MemoryCache* memory_cache, + FlipAcceptor* acceptor); - max_bytes_sent_per_dowrite_(4096), +//////////////////////////////////////////////////////////////////////////////// - ssl_(NULL) {} +class SMConnection: public SMConnectionInterface, + public EpollCallbackInterface, + public NotifierInterface { + private: + SMConnection(EpollServer* epoll_server, + SSLState* ssl_state, + MemoryCache* memory_cache, + FlipAcceptor* acceptor, + string log_prefix) + : fd_(-1), + events_(0), + registered_in_epoll_server_(false), + initialized_(false), + protocol_detected_(false), + connection_complete_(false), + connection_pool_(NULL), + epoll_server_(epoll_server), + ssl_state_(ssl_state), + memory_cache_(memory_cache), + acceptor_(acceptor), + read_buffer_(4096*10), + sm_spdy_interface_(NULL), + sm_http_interface_(NULL), + sm_streamer_interface_(NULL), + sm_interface_(NULL), + log_prefix_(log_prefix), + max_bytes_sent_per_dowrite_(4096), + ssl_(NULL) + {} int fd_; int events_; bool registered_in_epoll_server_; bool initialized_; + bool protocol_detected_; + bool connection_complete_; - SMServerConnectionPoolInterface* connection_pool_; - EpollServer* epoll_server_; + SMConnectionPoolInterface* connection_pool_; + + EpollServer *epoll_server_; + SSLState *ssl_state_; + MemoryCache* memory_cache_; + FlipAcceptor *acceptor_; + string client_ip_; RingBuffer read_buffer_; OutputList output_list_; - MemoryCache* memory_cache_; + SMInterface* sm_spdy_interface_; + SMInterface* sm_http_interface_; + SMInterface* sm_streamer_interface_; SMInterface* sm_interface_; + string log_prefix_; size_t max_bytes_sent_per_dowrite_; @@ -777,45 +814,101 @@ class SMServerConnection: public EpollCallbackInterface, OutputList* output_list() { return &output_list_; } MemoryCache* memory_cache() { return memory_cache_; } void ReadyToSend() { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Setting ready to send: EPOLLIN | EPOLLOUT"; epoll_server_->SetFDReady(fd_, EPOLLIN | EPOLLOUT); } void EnqueueDataFrame(const DataFrame& df) { output_list_.push_back(df); - VLOG(2) << "EnqueueDataFrame. Setting FD ready."; + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "EnqueueDataFrame: " + << "size = " << df.size << ": Setting FD ready."; ReadyToSend(); } + int fd() { return fd_; } public: - ~SMServerConnection() { + ~SMConnection() { if (initialized()) { Reset(); } } - static SMServerConnection* NewSMServerConnection(SMInterfaceFactory* smif, - MemoryCache* memory_cache, - EpollServer* epoll_server) { - return new SMServerConnection(smif, memory_cache, epoll_server); + static SMConnection* NewSMConnection(EpollServer* epoll_server, + SSLState *ssl_state, + MemoryCache* memory_cache, + FlipAcceptor *acceptor, + string log_prefix) { + return new SMConnection(epoll_server, ssl_state, memory_cache, + acceptor, log_prefix); } bool initialized() const { return initialized_; } + string client_ip() const { return client_ip_; } - void InitSMServerConnection(SMServerConnectionPoolInterface* connection_pool, - EpollServer* epoll_server, - int fd) { + void InitSMConnection(SMConnectionPoolInterface* connection_pool, + SMInterface* sm_interface, + EpollServer* epoll_server, + int fd, + bool use_ssl) { if (initialized_) { LOG(FATAL) << "Attempted to initialize already initialized server"; return; } - if (epoll_server_ && registered_in_epoll_server_ && fd_ != -1) { - epoll_server_->UnregisterFD(fd_); - } - if (fd_ != -1) { - VLOG(2) << "Closing pre-existing fd"; - close(fd_); - fd_ = -1; - } - fd_ = fd; + if (fd == -1) { + // If fd == -1, then we are initializing a new connection that will + // connect to the backend. + // + // ret: -1 == error + // 0 == connection in progress + // 1 == connection complete + // TODO: is_numeric_host_address value needs to be detected + int ret = net::CreateConnectedSocket(&fd_, + acceptor_->server_ip_, + acceptor_->server_port_, + true, + acceptor_->disable_nagle_); + + if (ret < 0) { + LOG(ERROR) << "-1 Could not create connected socket"; + return; + } else if (ret == 1) { + DCHECK_NE(-1, fd_); + connection_complete_ = true; + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Connection complete to: " << ACCEPTOR_SERVER_IDENT; + } + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Connecting to server: " << ACCEPTOR_SERVER_IDENT; + } else { + // If fd != -1 then we are initializing a connection that has just been + // accepted from the listen socket. + connection_complete_ = true; + if (epoll_server_ && registered_in_epoll_server_ && fd_ != -1) { + epoll_server_->UnregisterFD(fd_); + } + if (fd_ != -1) { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Closing pre-existing fd"; + close(fd_); + fd_ = -1; + } + + fd_ = fd; + struct sockaddr sock_addr; + socklen_t addr_size = sizeof(sock_addr); + addr_size = sizeof(sock_addr); + int res = getsockname(fd_, &sock_addr, &addr_size); + if (res < 0) { + LOG(ERROR) << "Could not get socket address for fd " << fd_ + << ": getsockname: " << strerror(errno); + } else { + struct sockaddr_in *sock_addr_in = (struct sockaddr_in *)&sock_addr; + char ip[16]; + snprintf(ip, sizeof(ip), "%d.%d.%d.%d", + IPV4_PRINTABLE_FORMAT(sock_addr_in->sin_addr.s_addr)); + client_ip_ = ip; + } + } registered_in_epoll_server_ = false; initialized_ = true; @@ -823,21 +916,42 @@ class SMServerConnection: public EpollCallbackInterface, connection_pool_ = connection_pool; epoll_server_ = epoll_server; - sm_interface_->Reset(); + if (sm_interface) { + sm_interface_ = sm_interface; + protocol_detected_ = true; + } + read_buffer_.Clear(); epoll_server_->RegisterFD(fd_, this, EPOLLIN | EPOLLOUT | EPOLLET); - if (global_ssl_state) { - ssl_ = spdy_new_ssl(global_ssl_state->ssl_ctx); + if (use_ssl) { + ssl_ = spdy_new_ssl(ssl_state_->ssl_ctx); SSL_set_fd(ssl_, fd_); PrintSslError(); } - sm_interface_->PostAcceptHook(); } int Send(const char* bytes, int len, int flags) { - return send(fd_, bytes, len, flags); + ssize_t bytes_written = 0; + if (ssl_) { + bytes_written = SSL_write(ssl_, bytes, len); + if (bytes_written < 0) { + switch(SSL_get_error(ssl_, bytes_written)) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_ACCEPT: + case SSL_ERROR_WANT_CONNECT: + return -2; + default: + PrintSslError(); + break; + } + } + } else { + bytes_written = send(fd_, bytes, len, flags); + } + return bytes_written; } // the following are from the EpollCallbackInterface @@ -861,43 +975,95 @@ class SMServerConnection: public EpollCallbackInterface, return; } + void Cleanup(const char* cleanup) { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Cleanup"; + if (!initialized_) { + return; + } + Reset(); + if (connection_pool_) { + connection_pool_->SMConnectionDone(this); + } + if (sm_interface_) { + sm_interface_->ResetForNewConnection(); + } + } + private: void HandleEvents() { - VLOG(1) << "Received: " << EpollServer::EventMaskToString(events_); + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Received: " + << EpollServer::EventMaskToString(events_).c_str(); + if (events_ & EPOLLIN) { if (!DoRead()) goto handle_close_or_error; } if (events_ & EPOLLOUT) { + // Check if we have connected or not + if (connection_complete_ == false) { + int sock_error; + socklen_t sock_error_len = sizeof(sock_error); + int ret = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &sock_error, + &sock_error_len); + if (ret != 0) { + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "getsockopt error: " << errno << ": " << strerror(errno); + goto handle_close_or_error; + } + if (sock_error == 0) { + connection_complete_ = true; + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Connection complete to " << ACCEPTOR_SERVER_IDENT; + } else if (sock_error == EINPROGRESS) { + return; + } else { + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "error connecting to server"; + goto handle_close_or_error; + } + } if (!DoWrite()) goto handle_close_or_error; } if (events_ & (EPOLLHUP | EPOLLERR)) { - VLOG(2) << "!!!! Got HUP or ERR"; + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "!!! Got HUP or ERR"; goto handle_close_or_error; } return; - handle_close_or_error: + handle_close_or_error: Cleanup("HandleEvents"); } bool DoRead() { - VLOG(2) << "DoRead()"; - if (fd_ == -1) { - VLOG(2) << "DoRead(): fd_ == -1. Invalid FD. Returning false"; - return false; - } + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "DoRead()"; while (!read_buffer_.Full()) { char* bytes; int size; + if (fd_ == -1) { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "DoRead(): fd_ == -1. Invalid FD. Returning false"; + return false; + } read_buffer_.GetWritablePtr(&bytes, &size); ssize_t bytes_read = 0; if (ssl_) { bytes_read = SSL_read(ssl_, bytes, size); - PrintSslError(); + if (bytes_read < 0) { + switch(SSL_get_error(ssl_, bytes_read)) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_ACCEPT: + case SSL_ERROR_WANT_CONNECT: + events_ &= ~EPOLLIN; + goto done; + default: + PrintSslError(); + break; + } + } } else { bytes_read = recv(fd_, bytes, size, MSG_DONTWAIT); } @@ -906,33 +1072,107 @@ class SMServerConnection: public EpollCallbackInterface, switch (stored_errno) { case EAGAIN: events_ &= ~EPOLLIN; - VLOG(2) << "Got EAGAIN while reading"; + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Got EAGAIN while reading"; goto done; case EINTR: - VLOG(2) << "Got EINTR while reading"; + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Got EINTR while reading"; continue; default: - VLOG(2) << "While calling recv, got error: " << stored_errno - << " " << strerror(stored_errno); + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "While calling recv, got error: " + << (ssl_?"(ssl error)":strerror(stored_errno)); goto error_or_close; } } else if (bytes_read > 0) { - VLOG(2) << "Read: " << bytes_read << " bytes from fd: " << fd_; + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "read " << bytes_read + << " bytes"; + if (!protocol_detected_) { + if (acceptor_->flip_handler_type_ == FLIP_HANDLER_HTTP_SERVER) { + // Http Server + protocol_detected_ = true; + if (!sm_http_interface_) { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Created HTTP interface."; + sm_http_interface_ = NewHttpSM(this, NULL, epoll_server_, + memory_cache_, acceptor_); + } else { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Reusing HTTP interface."; + } + sm_interface_ = sm_http_interface_; + } else if (ssl_) { + protocol_detected_ = true; + if (SSL_session_reused(ssl_) == 0) { + VLOG(1) << "Session status: renegotiated"; + } else { + VLOG(1) << "Session status: resumed"; + } + const unsigned char *npn_proto; + unsigned int npn_proto_len; + SSL_get0_next_proto_negotiated(ssl_, &npn_proto, &npn_proto_len); + if (npn_proto_len > 0) { + string npn_proto_str((const char *)npn_proto, npn_proto_len); + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "NPN protocol detected: " << npn_proto_str; + } else { + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "NPN protocol detected: none"; + if (acceptor_->flip_handler_type_ == FLIP_HANDLER_SPDY_SERVER) { + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "NPN protocol: Could not negotiate SPDY protocol."; + goto error_or_close; + } + } + if (npn_proto_len > 0 && + !strncmp((char *)npn_proto, "spdy/2", npn_proto_len)) { + if (!sm_spdy_interface_) { + sm_spdy_interface_ = NewSpdySM(this, NULL, epoll_server_, + memory_cache_, acceptor_); + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Created SPDY interface."; + } else { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Reusing SPDY interface."; + } + sm_interface_ = sm_spdy_interface_; + } else { + if (!sm_streamer_interface_) { + sm_streamer_interface_ = NewStreamerSM(this, NULL, + epoll_server_, + acceptor_); + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Created Streamer interface."; + } else { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Reusing Streamer interface: "; + } + sm_interface_ = sm_streamer_interface_; + } + } + if (sm_interface_->PostAcceptHook() == 0) { + goto error_or_close; + } + } read_buffer_.AdvanceWritablePtr(bytes_read); if (!DoConsumeReadData()) { goto error_or_close; } continue; } else { // bytes_read == 0 - VLOG(2) << "0 bytes read with recv call."; + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "0 bytes read with recv call."; } goto error_or_close; } done: return true; - error_or_close: - VLOG(2) << "DoRead(): error_or_close. Cleaning up, then returning false"; + error_or_close: + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "DoRead(): error_or_close. " + << "Cleaning up, then returning false"; Cleanup("DoRead"); return false; } @@ -942,19 +1182,21 @@ class SMServerConnection: public EpollCallbackInterface, int size; read_buffer_.GetReadablePtr(&bytes, &size); while (size != 0) { - size_t bytes_consumed = sm_interface_->ProcessInput(bytes, size); - VLOG(2) << "consumed: " << bytes_consumed << " from socket fd: " << fd_; + size_t bytes_consumed = sm_interface_->ProcessReadInput(bytes, size); + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "consumed " + << bytes_consumed << " bytes"; if (bytes_consumed == 0) { break; } read_buffer_.AdvanceReadablePtr(bytes_consumed); if (sm_interface_->MessageFullyRead()) { - VLOG(2) << "HandleRequestFullyRead"; - HandleRequestFullyRead(); - sm_interface_->Reset(); + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "HandleRequestFullyRead: Setting EPOLLOUT"; + HandleResponseFullyRead(); events_ |= EPOLLOUT; } else if (sm_interface_->Error()) { - LOG(ERROR) << "Framer error detected: " + LOG(ERROR) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Framer error detected: Setting EPOLLOUT: " << sm_interface_->ErrorAsString(); // this causes everything to be closed/cleaned up. events_ |= EPOLLOUT; @@ -970,7 +1212,8 @@ class SMServerConnection: public EpollCallbackInterface, // feeding files into the output buffer. } - void HandleRequestFullyRead() { + void HandleResponseFullyRead() { + sm_interface_->Cleanup(); } void Notify() { @@ -980,20 +1223,29 @@ class SMServerConnection: public EpollCallbackInterface, size_t bytes_sent = 0; int flags = MSG_NOSIGNAL | MSG_DONTWAIT; if (fd_ == -1) { - VLOG(2) << "DoWrite: fd == -1. Returning false."; + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "DoWrite: fd == -1. Returning false."; return false; } if (output_list_.empty()) { - sm_interface_->GetOutput(); - if (output_list_.empty()) + VLOG(2) << log_prefix_ << "DoWrite: Output list empty."; + if (sm_interface_) { + sm_interface_->GetOutput(); + } + if (output_list_.empty()) { events_ &= ~EPOLLOUT; + } } while (!output_list_.empty()) { + VLOG(2) << log_prefix_ << "DoWrite: Items in output list: " + << output_list_.size(); if (bytes_sent >= max_bytes_sent_per_dowrite_) { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << " byte sent >= max bytes sent per write: Setting EPOLLOUT"; events_ |= EPOLLOUT; break; } - if (output_list_.size() < 2) { + if (sm_interface_ && output_list_.size() < 2) { sm_interface_->GetOutput(); } DataFrame& data_frame = output_list_.front(); @@ -1003,6 +1255,11 @@ class SMServerConnection: public EpollCallbackInterface, size -= data_frame.index; DCHECK_GE(size, 0); if (size <= 0) { + // Empty data frame. Indicates end of data from client. + // Uncork the socket. + int state = 0; + VLOG(2) << log_prefix_ << "Empty data frame, uncorking socket."; + setsockopt( fd_, IPPROTO_TCP, TCP_CORK, &state, sizeof( state ) ); data_frame.MaybeDelete(); output_list_.pop_front(); continue; @@ -1010,56 +1267,63 @@ class SMServerConnection: public EpollCallbackInterface, flags = MSG_NOSIGNAL | MSG_DONTWAIT; if (output_list_.size() > 1) { + VLOG(2) << log_prefix_ << "Outlist size: " << output_list_.size() + << ": Adding MSG_MORE flag"; flags |= MSG_MORE; } - ssize_t bytes_written = 0; - if (ssl_) { - bytes_written = SSL_write(ssl_, bytes, size); - PrintSslError(); - } else { - bytes_written = send(fd_, bytes, size, flags); - } + VLOG(2) << log_prefix_ << "Attempting to send " << size << " bytes."; + ssize_t bytes_written = Send(bytes, size, flags); int stored_errno = errno; if (bytes_written == -1) { switch (stored_errno) { case EAGAIN: events_ &= ~EPOLLOUT; - VLOG(2) << " Got EAGAIN while writing"; + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Got EAGAIN while writing"; goto done; case EINTR: - VLOG(2) << " Got EINTR while writing"; + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Got EINTR while writing"; continue; default: - VLOG(2) << "While calling send, got error: " << stored_errno - << " " << strerror(stored_errno); + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "While calling send, got error: " << stored_errno + << ": " << (ssl_?"":strerror(stored_errno)); goto error_or_close; } } else if (bytes_written > 0) { - VLOG(1) << "Wrote: " << bytes_written << " bytes to socket fd: " - << fd_; + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Wrote: " + << bytes_written << " bytes"; data_frame.index += bytes_written; bytes_sent += bytes_written; continue; + } else if (bytes_written == -2) { + // -2 handles SSL_ERROR_WANT_* errors + events_ &= ~EPOLLOUT; + goto done; } - VLOG(2) << "0 bytes written to socket " << fd_ << " with send call."; + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "0 bytes written with send call."; goto error_or_close; } done: return true; error_or_close: - VLOG(2) << "DoWrite: error_or_close. Returning false after cleaning up"; + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "DoWrite: error_or_close. Returning false " + << "after cleaning up"; Cleanup("DoWrite"); return false; } - friend ostream& operator<<(ostream& os, const SMServerConnection& c) { + friend ostream& operator<<(ostream& os, const SMConnection& c) { os << &c << "\n"; return os; } void Reset() { - VLOG(2) << "Resetting"; + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Resetting"; if (ssl_) { SSL_shutdown(ssl_); PrintSslError(); @@ -1071,25 +1335,17 @@ class SMServerConnection: public EpollCallbackInterface, registered_in_epoll_server_ = false; } if (fd_ >= 0) { - VLOG(2) << "Closing connection"; + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Closing connection"; close(fd_); fd_ = -1; } - sm_interface_->ResetForNewConnection(); read_buffer_.Clear(); initialized_ = false; + protocol_detected_ = false; events_ = 0; output_list_.clear(); } - void Cleanup(const char* cleanup) { - VLOG(2) << "Cleaning up: " << cleanup; - if (!initialized_) { - return; - } - Reset(); - connection_pool_->SMServerConnectionDone(this); - } }; //////////////////////////////////////////////////////////////////////////////// @@ -1114,13 +1370,15 @@ class OutputOrdering { PriorityRing first_data_senders_; uint32 first_data_senders_threshold_; // when you've passed this, you're no // longer a first_data_sender... - SMServerConnection* connection_; + SMConnectionInterface* connection_; EpollServer* epoll_server_; - explicit OutputOrdering(SMServerConnection* connection) : - first_data_senders_threshold_(kInitialDataSendersThreshold), - connection_(connection), - epoll_server_(connection->epoll_server()) { + explicit OutputOrdering(SMConnectionInterface* connection) : + first_data_senders_threshold_(kInitialDataSendersThreshold), + connection_(connection) { + if (connection) { + epoll_server_ = connection->epoll_server(); + } } void Reset() { @@ -1151,7 +1409,7 @@ class OutputOrdering { int64 OnAlarm() { OnUnregistration(); output_ordering_->MoveToActive(pmp_, mci_); - VLOG(1) << "ON ALARM! Should now start to output..."; + VLOG(2) << "ON ALARM! Should now start to output..."; delete this; return 0; } @@ -1179,7 +1437,7 @@ class OutputOrdering { }; void MoveToActive(PriorityMapPointer* pmp, MemCacheIter mci) { - VLOG(1) <<"Moving to active!"; + VLOG(2) << "Moving to active!"; first_data_senders_.push_back(mci); pmp->ring = &first_data_senders_; pmp->it = first_data_senders_.end(); @@ -1189,9 +1447,9 @@ class OutputOrdering { void AddToOutputOrder(const MemCacheIter& mci) { if (ExistsInPriorityMaps(mci.stream_id)) - LOG(FATAL) << "OOps, already was inserted here?!"; + LOG(ERROR) << "OOps, already was inserted here?!"; - double think_time_in_s = FLAGS_server_think_time_in_s; + double think_time_in_s = g_proxy_config.server_think_time_in_s_; string x_server_latency = mci.file_data->headers->GetHeader("X-Server-Latency").as_string(); if (x_server_latency.size() != 0) { @@ -1199,7 +1457,8 @@ class OutputOrdering { double tmp_think_time_in_s = strtod(x_server_latency.c_str(), &endp); if (endp != x_server_latency.c_str() + x_server_latency.size()) { LOG(ERROR) << "Unable to understand X-Server-Latency of: " - << x_server_latency << " for resource: " << mci.file_data->filename; + << x_server_latency << " for resource: " + << mci.file_data->filename.c_str(); } else { think_time_in_s = tmp_think_time_in_s; } @@ -1211,7 +1470,7 @@ class OutputOrdering { PriorityMapPointer& pmp = sitpmi->second; BeginOutputtingAlarm* boa = new BeginOutputtingAlarm(this, &pmp, mci); - VLOG(2) << "Server think time: " << think_time_in_s; + VLOG(1) << "Server think time: " << think_time_in_s; epoll_server_->RegisterAlarmApproximateDelta( think_time_in_s * 1000000, boa); } @@ -1275,34 +1534,169 @@ class OutputOrdering { } }; + //////////////////////////////////////////////////////////////////////////////// class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { private: uint64 seq_num_; - SpdyFramer* framer_; + SpdyFramer* spdy_framer_; - SMServerConnection* connection_; - OutputList* output_list_; - OutputOrdering output_ordering_; - MemoryCache* memory_cache_; + SMConnection* connection_; + OutputList* client_output_list_; + OutputOrdering client_output_ordering_; uint32 next_outgoing_stream_id_; + EpollServer* epoll_server_; + FlipAcceptor* acceptor_; + MemoryCache* memory_cache_; + vector<SMInterface*> server_interface_list; + vector<int32> unused_server_interface_list; + typedef map<uint32,SMInterface*> StreamToSmif; + StreamToSmif stream_to_smif_; public: - explicit SpdySM(SMServerConnection* connection) : - seq_num_(0), - framer_(new SpdyFramer), - connection_(connection), - output_list_(connection->output_list()), - output_ordering_(connection), - memory_cache_(connection->memory_cache()), - next_outgoing_stream_id_(2) { - framer_->set_visitor(this); + SpdySM(SMConnection* connection, + SMInterface* sm_http_interface, + EpollServer* epoll_server, + MemoryCache* memory_cache, + FlipAcceptor* acceptor) + : seq_num_(0), + spdy_framer_(new SpdyFramer), + connection_(connection), + client_output_list_(connection->output_list()), + client_output_ordering_(connection), + next_outgoing_stream_id_(2), + epoll_server_(epoll_server), + acceptor_(acceptor), + memory_cache_(memory_cache) { + spdy_framer_->set_visitor(this); } + + ~SpdySM() { + delete spdy_framer_; + } + + void InitSMInterface(SMInterface* sm_http_interface, + int32 server_idx) { } + + void InitSMConnection(SMConnectionPoolInterface* connection_pool, + SMInterface* sm_interface, + EpollServer* epoll_server, + int fd, + bool use_ssl) { + VLOG(2) << ACCEPTOR_CLIENT_IDENT + << "SpdySM: Initializing server connection."; + connection_->InitSMConnection(connection_pool, sm_interface, + epoll_server, fd, use_ssl); + } + private: virtual void OnError(SpdyFramer* framer) { /* do nothing with this right now */ } + SMInterface* NewConnectionInterface() { + SMConnection* server_connection = + SMConnection::NewSMConnection(epoll_server_, NULL, + memory_cache_, acceptor_, + "http_conn: "); + if (server_connection == NULL) { + LOG(ERROR) << "SpdySM: Could not create server connection"; + return NULL; + } + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: Creating new HTTP interface"; + SMInterface *sm_http_interface = NewHttpSM(server_connection, this, + epoll_server_, memory_cache_, + acceptor_); + return sm_http_interface; + } + + SMInterface* FindOrMakeNewSMConnectionInterface() { + SMInterface *sm_http_interface; + int32 server_idx; + if (unused_server_interface_list.empty()) { + sm_http_interface = NewConnectionInterface(); + server_idx = server_interface_list.size(); + server_interface_list.push_back(sm_http_interface); + VLOG(2) << ACCEPTOR_CLIENT_IDENT + << "SpdySM: Making new server connection on index: " + << server_idx; + } else { + server_idx = unused_server_interface_list.back(); + unused_server_interface_list.pop_back(); + sm_http_interface = server_interface_list.at(server_idx); + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: Reusing connection on " + << "index: " << server_idx; + } + + sm_http_interface->InitSMInterface(this, server_idx); + sm_http_interface->InitSMConnection(NULL, sm_http_interface, + epoll_server_, -1, false); + + return sm_http_interface; + } + + int SpdyHandleNewStream(const SpdyControlFrame* frame, + string *http_data) + { + bool parsed_headers = false; + SpdyHeaderBlock headers; + const SpdySynStreamControlFrame* syn_stream = + reinterpret_cast<const SpdySynStreamControlFrame*>(frame); + + parsed_headers = spdy_framer_->ParseHeaderBlock(frame, &headers); + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: OnSyn(" + << syn_stream->stream_id() << ")"; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: headers parsed?: " + << (parsed_headers? "yes": "no"); + if (parsed_headers) { + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: # headers: " + << headers.size(); + } + SpdyHeaderBlock::iterator url = headers.find("url"); + SpdyHeaderBlock::iterator method = headers.find("method"); + if (url == headers.end() || method == headers.end()) { + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: didn't find method or url " + << "or method. Not creating stream"; + return 0; + } + + string uri = UrlUtilities::GetUrlPath(url->second); + if (acceptor_->flip_handler_type_ == FLIP_HANDLER_SPDY_SERVER) { + SpdyHeaderBlock::iterator referer = headers.find("referer"); + if (referer != headers.end() && method->second == "GET") { + memory_cache_->UpdateHeaders(referer->second, url->second); + } + string host = UrlUtilities::GetUrlHost(url->second); + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Request: " << method->second + << " " << uri; + string filename = EncodeURL(uri, host, method->second); + NewStream(syn_stream->stream_id(), + reinterpret_cast<const SpdySynStreamControlFrame*>(frame)-> + priority(), + filename); + } else { + SpdyHeaderBlock::iterator version = headers.find("version"); + *http_data += method->second + " " + uri + " " + version->second + "\r\n"; + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Request: " << method->second << " " + << uri << " " << version->second; + for (SpdyHeaderBlock::iterator i = headers.begin(); + i != headers.end(); ++i) { + *http_data += i->first + ": " + i->second + "\r\n"; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << i->first.c_str() << ":" + << i->second.c_str(); + } + if (g_proxy_config.forward_ip_header_enabled_) { + // X-Client-Cluster-IP header + *http_data += g_proxy_config.forward_ip_header_ + ": " + + connection_->client_ip() + "\r\n"; + } + *http_data += "\r\n"; + } + + VLOG(3) << ACCEPTOR_CLIENT_IDENT << "SpdySM: HTTP Request:\n" << http_data; + return 1; + } + virtual void OnControl(const SpdyControlFrame* frame) { SpdyHeaderBlock headers; bool parsed_headers = false; @@ -1311,122 +1705,126 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { { const SpdySynStreamControlFrame* syn_stream = reinterpret_cast<const SpdySynStreamControlFrame*>(frame); - parsed_headers = framer_->ParseHeaderBlock(frame, &headers); - VLOG(2) << "OnSyn(" << syn_stream->stream_id() << ")"; - VLOG(2) << "headers parsed?: " << (parsed_headers? "yes": "no"); - if (parsed_headers) { - VLOG(2) << "# headers: " << headers.size(); - } - for (SpdyHeaderBlock::iterator i = headers.begin(); - i != headers.end(); - ++i) { - VLOG(2) << i->first << ": " << i->second; - } - SpdyHeaderBlock::iterator method = headers.find("method"); - SpdyHeaderBlock::iterator url = headers.find("url"); - if (url == headers.end() || method == headers.end()) { - VLOG(2) << "didn't find method or url or method. Not creating stream"; - break; - } + string http_data; + int ret = SpdyHandleNewStream(frame, &http_data); + if (!ret) { + LOG(ERROR) << "SpdySM: Could not convert spdy into http."; + break; + } - SpdyHeaderBlock::iterator referer = headers.find("referer"); - if (referer != headers.end() && method->second == "GET") { - memory_cache_->UpdateHeaders(referer->second, url->second); - } - string uri = UrlUtilities::GetUrlPath(url->second); - string host = UrlUtilities::GetUrlHost(url->second); - - string filename = EncodeURL(uri, host, method->second); - NewStream(syn_stream->stream_id(), - reinterpret_cast<const SpdySynStreamControlFrame*>(frame)-> - priority(), - filename); + if (acceptor_->flip_handler_type_ == FLIP_HANDLER_PROXY) { + SMInterface *sm_http_interface = + FindOrMakeNewSMConnectionInterface(); + stream_to_smif_[syn_stream->stream_id()] = sm_http_interface; + sm_http_interface->SetStreamID(syn_stream->stream_id()); + sm_http_interface->ProcessWriteInput(http_data.c_str(), + http_data.size()); + } } break; case SYN_REPLY: - parsed_headers = framer_->ParseHeaderBlock(frame, &headers); - VLOG(2) << "OnSynReply(" - << reinterpret_cast<const SpdySynReplyControlFrame*>( - frame)->stream_id() << ")"; + parsed_headers = spdy_framer_->ParseHeaderBlock(frame, &headers); + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: OnSynReply(" << + reinterpret_cast<const SpdySynReplyControlFrame*>(frame)->stream_id() + << ")"; break; case RST_STREAM: { const SpdyRstStreamControlFrame* rst_stream = reinterpret_cast<const SpdyRstStreamControlFrame*>(frame); - VLOG(2) << "OnRst(" << rst_stream->stream_id() << ")"; - output_ordering_.RemoveStreamId(rst_stream ->stream_id()); + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: OnRst(" + << rst_stream->stream_id() << ")"; + client_output_ordering_.RemoveStreamId(rst_stream ->stream_id()); } break; default: - LOG(DFATAL) << "Unknown control frame type"; + LOG(ERROR) << "SpdySM: Unknown control frame type"; } } - virtual void OnStreamFrameData( - SpdyStreamId stream_id, - const char* data, size_t len) { - VLOG(2) << "StreamData(" << stream_id << ", [" << len << "])"; - /* do nothing with this right now */ - } - virtual void OnLameDuck() { - /* do nothing with this right now */ + virtual void OnStreamFrameData(SpdyStreamId stream_id, + const char* data, size_t len) { + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: StreamData(" << stream_id + << ", [" << len << "])"; + if (acceptor_->flip_handler_type_ == FLIP_HANDLER_PROXY) { + stream_to_smif_[stream_id]->ProcessWriteInput(data, len); + } } public: - ~SpdySM() { - Reset(); + size_t ProcessReadInput(const char* data, size_t len) { + return spdy_framer_->ProcessInput(data, len); } - size_t ProcessInput(const char* data, size_t len) { - return framer_->ProcessInput(data, len); + + size_t ProcessWriteInput(const char* data, size_t len) { + return 0; } bool MessageFullyRead() const { - return framer_->MessageFullyRead(); + return spdy_framer_->MessageFullyRead(); } + void SetStreamID(uint32 stream_id) {} + bool Error() const { - return framer_->HasError(); + return spdy_framer_->HasError(); } const char* ErrorAsString() const { - return SpdyFramer::ErrorCodeToString(framer_->error_code()); + DCHECK(Error()); + return SpdyFramer::ErrorCodeToString(spdy_framer_->error_code()); + } + + void Reset() { + } + + void ResetForNewInterface(int32 server_idx) { + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: Reset for new interface: " + << "server_idx: " << server_idx; + unused_server_interface_list.push_back(server_idx); } - void Reset() {} void ResetForNewConnection() { // seq_num is not cleared, intentionally. - delete framer_; - framer_ = new SpdyFramer; - framer_->set_visitor(this); - output_ordering_.Reset(); + delete spdy_framer_; + spdy_framer_ = new SpdyFramer; + spdy_framer_->set_visitor(this); + client_output_ordering_.Reset(); next_outgoing_stream_id_ = 2; } - // Send a couple of NOOP packets to force opening of cwnd. - void PostAcceptHook() { - if (!FLAGS_use_cwnd_opener) - return; - - // We send 2 because that is the initial cwnd, and also because - // we have to in order to get an ACK back from the client due to - // delayed ACK. - const int kPkts = 2; - - LOG(ERROR) << "Sending NOP FRAMES"; - - scoped_ptr<SpdyControlFrame> frame(SpdyFramer::CreateNopFrame()); - for (int i = 0; i < kPkts; ++i) { - char* bytes = frame->data(); - size_t size = SpdyFrame::size(); - ssize_t bytes_written = connection_->Send(bytes, size, MSG_DONTWAIT); - if (static_cast<size_t>(bytes_written) != size) { - LOG(ERROR) << "Trouble sending Nop packet! (" << errno << ")"; - if (errno == EAGAIN) - break; + // SMInterface's Cleanup is currently only called by SMConnection after a + // protocol message as been fully read. Spdy's SMInterface does not need + // to do any cleanup at this time. + // TODO (klindsay) This method is probably not being used properly and + // some logic review and method renaming is probably in order. + void Cleanup() {} + + // Send a settings frame and possibly some NOOP packets to force + // opening of cwnd + int PostAcceptHook() { + ssize_t bytes_written; + spdy::SpdySettings settings; + spdy::SettingsFlagsAndId settings_id(0); + settings_id.set_id(spdy::SETTINGS_MAX_CONCURRENT_STREAMS); + settings.push_back(spdy::SpdySetting(settings_id, 100)); + scoped_ptr<SpdySettingsControlFrame> + settings_frame(spdy_framer_->CreateSettings(settings)); + + char* bytes = settings_frame->data(); + size_t size = SpdyFrame::size() + settings_frame->length(); + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Sending Settings Frame"; + bytes_written = connection_->Send(bytes, size, + MSG_NOSIGNAL | MSG_DONTWAIT); + if (static_cast<size_t>(bytes_written) != size) { + LOG(ERROR) << "Trouble sending SETTINGS frame! (" << errno << ")"; + if (errno == EAGAIN) { + return 0; } } + return 1; } void AddAssociatedContent(FileData* file_data) { @@ -1436,10 +1834,12 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { string filename = "GET_"; filename += related_file.second; if (!memory_cache_->AssignFileData(filename, &mci)) { - VLOG(1) << "Unable to find associated content for: " << filename; + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Unable to find associated " + << "content for: " << filename; continue; } - VLOG(1) << "Adding associated content: " << filename; + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Adding associated content: " + << filename; mci.stream_id = next_outgoing_stream_id_; next_outgoing_stream_id_ += 2; mci.priority = related_file.first; @@ -1451,20 +1851,24 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { MemCacheIter mci; mci.stream_id = stream_id; mci.priority = priority; - if (!memory_cache_->AssignFileData(filename, &mci)) { - // error creating new stream. - VLOG(2) << "Sending ErrorNotFound"; - SendErrorNotFound(stream_id); + if (acceptor_->flip_handler_type_ == FLIP_HANDLER_SPDY_SERVER) { + if (!memory_cache_->AssignFileData(filename, &mci)) { + // error creating new stream. + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Sending ErrorNotFound"; + SendErrorNotFound(stream_id); + } else { + AddToOutputOrder(mci); + if (FLAGS_use_xac) { + AddAssociatedContent(mci.file_data); + } + } } else { AddToOutputOrder(mci); - if (FLAGS_use_xac) { - AddAssociatedContent(mci.file_data); - } } } void AddToOutputOrder(const MemCacheIter& mci) { - output_ordering_.AddToOutputOrder(mci); + client_output_ordering_.AddToOutputOrder(mci); } void SendEOF(uint32 stream_id) { @@ -1493,12 +1897,12 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { SendDataFrameImpl(stream_id, data, len, spdy_flags, compress); } - SpdyFramer* spdy_framer() { return framer_; } + SpdyFramer* spdy_framer() { return spdy_framer_; } private: void SendEOFImpl(uint32 stream_id) { SendDataFrame(stream_id, NULL, 0, DATA_FLAG_FIN, false); - VLOG(2) << "Sending EOF: " << stream_id; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: Sending EOF: " << stream_id; KillStream(stream_id); } @@ -1507,7 +1911,7 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { my_headers.SetFirstlineFromStringPieces("HTTP/1.1", "404", "Not Found"); SendSynReplyImpl(stream_id, my_headers); SendDataFrame(stream_id, "wtf?", 4, DATA_FLAG_FIN, false); - output_ordering_.RemoveStreamId(stream_id); + client_output_ordering_.RemoveStreamId(stream_id); } void SendOKResponseImpl(uint32 stream_id, string* output) { @@ -1516,11 +1920,11 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { SendSynReplyImpl(stream_id, my_headers); SendDataFrame( stream_id, output->c_str(), output->size(), DATA_FLAG_FIN, false); - output_ordering_.RemoveStreamId(stream_id); + client_output_ordering_.RemoveStreamId(stream_id); } void KillStream(uint32 stream_id) { - output_ordering_.RemoveStreamId(stream_id); + client_output_ordering_.RemoveStreamId(stream_id); } void CopyHeaders(SpdyHeaderBlock& dest, const BalsaHeaders& headers) { @@ -1559,8 +1963,8 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { CopyHeaders(block, headers); SpdySynStreamControlFrame* fsrcf = - framer_->CreateSynStream(stream_id, 0, 0, CONTROL_FLAG_NONE, true, - &block); + spdy_framer_->CreateSynStream(stream_id, 0, 0, CONTROL_FLAG_NONE, true, + &block); DataFrame df; df.size = fsrcf->length() + SpdyFrame::size(); size_t df_size = df.size; @@ -1568,7 +1972,8 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { df.delete_when_done = true; EnqueueDataFrame(df); - VLOG(2) << "Sending SynStreamheader " << stream_id; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: Sending SynStreamheader " + << stream_id; return df_size; } @@ -1580,7 +1985,7 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { block["version"] = headers.response_version().as_string(); SpdySynReplyControlFrame* fsrcf = - framer_->CreateSynReply(stream_id, CONTROL_FLAG_NONE, true, &block); + spdy_framer_->CreateSynReply(stream_id, CONTROL_FLAG_NONE, true, &block); DataFrame df; df.size = fsrcf->length() + SpdyFrame::size(); size_t df_size = df.size; @@ -1588,7 +1993,8 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { df.delete_when_done = true; EnqueueDataFrame(df); - VLOG(2) << "Sending SynReplyheader " << stream_id; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: Sending SynReplyheader " + << stream_id; return df_size; } @@ -1601,16 +2007,16 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { // TODO(mbelshe): We can't compress here - before going into the // priority queue. Compression needs to be done // with late binding. - SpdyDataFrame* fdf = framer_->CreateDataFrame(stream_id, data, len, - flags); + SpdyDataFrame* fdf = spdy_framer_->CreateDataFrame(stream_id, data, len, + flags); DataFrame df; df.size = fdf->length() + SpdyFrame::size(); df.data = fdf->data(); df.delete_when_done = true; EnqueueDataFrame(df); - VLOG(2) << "Sending data frame" << stream_id << " [" << len << "]" - << " shrunk to " << fdf->length(); + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: Sending data frame " + << stream_id << " [" << len << "] shrunk to " << fdf->length(); } void EnqueueDataFrame(const DataFrame& df) { @@ -1618,16 +2024,17 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { } void GetOutput() { - while (output_list_->size() < 2) { - MemCacheIter* mci = output_ordering_.GetIter(); + while (client_output_list_->size() < 2) { + MemCacheIter* mci = client_output_ordering_.GetIter(); if (mci == NULL) { - VLOG(2) << "GetOutput: nothing to output!?"; + VLOG(2) << ACCEPTOR_CLIENT_IDENT + << "SpdySM: GetOutput: nothing to output!?"; return; } if (!mci->transformed_header) { mci->transformed_header = true; - VLOG(2) << "GetOutput transformed header stream_id: [" - << mci->stream_id << "]"; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: GetOutput transformed " + << "header stream_id: [" << mci->stream_id << "]"; if ((mci->stream_id % 2) == 0) { // this is a server initiated stream. // Ideally, we'd do a 'syn-push' here, instead of a syn-reply. @@ -1647,7 +2054,8 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { return; } if (mci->body_bytes_consumed >= mci->file_data->body.size()) { - VLOG(2) << "GetOutput remove_stream_id: [" << mci->stream_id << "]"; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: GetOutput " + << "remove_stream_id: [" << mci->stream_id << "]"; SendEOF(mci->stream_id); return; } @@ -1669,8 +2077,8 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { SendDataFrame(mci->stream_id, mci->file_data->body.data() + mci->body_bytes_consumed, num_to_write, 0, should_compress); - VLOG(2) << "GetOutput SendDataFrame[" << mci->stream_id - << "]: " << num_to_write; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: GetOutput SendDataFrame[" + << mci->stream_id << "]: " << num_to_write; mci->body_bytes_consumed += num_to_write; mci->bytes_sent += num_to_write; } @@ -1679,28 +2087,41 @@ class SpdySM : public SpdyFramerVisitorInterface, public SMInterface { //////////////////////////////////////////////////////////////////////////////// -class HTTPSM : public BalsaVisitorInterface, public SMInterface { +class HttpSM : public BalsaVisitorInterface, public SMInterface { private: uint64 seq_num_; - BalsaFrame* framer_; + BalsaFrame* http_framer_; BalsaHeaders headers_; uint32 stream_id_; + int32 server_idx_; - SMServerConnection* connection_; + SMConnection* connection_; + SMInterface* sm_spdy_interface_; OutputList* output_list_; OutputOrdering output_ordering_; MemoryCache* memory_cache_; + FlipAcceptor* acceptor_; public: - explicit HTTPSM(SMServerConnection* connection) : + explicit HttpSM(SMConnection* connection, + SMInterface* sm_spdy_interface, + EpollServer* epoll_server, + MemoryCache* memory_cache, + FlipAcceptor* acceptor) : seq_num_(0), - framer_(new BalsaFrame), - stream_id_(1), + http_framer_(new BalsaFrame), + stream_id_(0), + server_idx_(-1), connection_(connection), + sm_spdy_interface_(sm_spdy_interface), output_list_(connection->output_list()), output_ordering_(connection), - memory_cache_(connection->memory_cache()) { - framer_->set_balsa_visitor(this); - framer_->set_balsa_headers(&headers_); + memory_cache_(connection->memory_cache()), + acceptor_(acceptor) { + http_framer_->set_balsa_visitor(this); + http_framer_->set_balsa_headers(&headers_); + if (acceptor_->flip_handler_type_ == FLIP_HANDLER_PROXY) { + http_framer_->set_is_request(false); + } } private: typedef map<string, uint32> ClientTokenMap; @@ -1708,20 +2129,31 @@ class HTTPSM : public BalsaVisitorInterface, public SMInterface { virtual void ProcessBodyInput(const char *input, size_t size) { } virtual void ProcessBodyData(const char *input, size_t size) { - // ignoring this. + if (acceptor_->flip_handler_type_ == FLIP_HANDLER_PROXY) { + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "HttpSM: Process Body Data: stream " + << stream_id_ << ": size " << size; + sm_spdy_interface_->SendDataFrame(stream_id_, input, size, 0, false); + } } virtual void ProcessHeaderInput(const char *input, size_t size) { } virtual void ProcessTrailerInput(const char *input, size_t size) {} virtual void ProcessHeaders(const BalsaHeaders& headers) { - VLOG(2) << "Got new request!"; - string host = UrlUtilities::GetUrlHost( - headers.GetHeader("Host").as_string()); - string method = headers.request_method().as_string(); - string filename = EncodeURL(headers.request_uri().as_string(), host, - method); - NewStream(stream_id_, 0, filename); - stream_id_ += 2; + if (acceptor_->flip_handler_type_ == FLIP_HANDLER_HTTP_SERVER) { + string host = + UrlUtilities::GetUrlHost(headers.GetHeader("Host").as_string()); + string method = headers.request_method().as_string(); + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Received Request: " + << headers.request_uri().as_string() << " " << method; + string filename = EncodeURL(headers.request_uri().as_string(), + host, method); + NewStream(stream_id_, 0, filename); + stream_id_ += 2; + } else { + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "HttpSM: Received Response from " + << ACCEPTOR_SERVER_IDENT; + sm_spdy_interface_->SendSynReply(stream_id_, headers); + } } virtual void ProcessRequestFirstLine(const char* line_input, size_t line_length, @@ -1743,7 +2175,13 @@ class HTTPSM : public BalsaVisitorInterface, public SMInterface { virtual void ProcessChunkExtensions(const char *input, size_t size) {} virtual void HeaderDone() {} virtual void MessageDone() { - VLOG(2) << "MessageDone!"; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "HttpSM: MessageDone. Sending EOF: " + << "stream " << stream_id_; + if (acceptor_->flip_handler_type_ == FLIP_HANDLER_PROXY) { + sm_spdy_interface_->SendEOF(stream_id_); + } else { + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "HttpSM: MessageDone."; + } } virtual void HandleHeaderError(BalsaFrame* framer) { HandleError(); @@ -1757,40 +2195,98 @@ class HTTPSM : public BalsaVisitorInterface, public SMInterface { } void HandleError() { - VLOG(2) << "Error detected"; + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Error detected"; } public: - ~HTTPSM() { + ~HttpSM() { Reset(); + delete http_framer_; } - size_t ProcessInput(const char* data, size_t len) { - return framer_->ProcessInput(data, len); + + void InitSMInterface(SMInterface* sm_spdy_interface, + int32 server_idx) + { + sm_spdy_interface_ = sm_spdy_interface; + server_idx_ = server_idx; + } + + void InitSMConnection(SMConnectionPoolInterface* connection_pool, + SMInterface* sm_interface, + EpollServer* epoll_server, + int fd, + bool use_ssl) + { + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "HttpSM: Initializing server " + << "connection."; + connection_->InitSMConnection(connection_pool, sm_interface, + epoll_server, fd, use_ssl); + } + + size_t ProcessReadInput(const char* data, size_t len) { + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "HttpSM: Process read input: stream " + << stream_id_; + return http_framer_->ProcessInput(data, len); + } + + size_t ProcessWriteInput(const char* data, size_t len) { + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "HttpSM: Process write input: size " + << len << ": stream " << stream_id_; + char * dataPtr = new char[len]; + memcpy(dataPtr, data, len); + DataFrame data_frame; + data_frame.data = (const char *)dataPtr; + data_frame.size = len; + data_frame.delete_when_done = true; + connection_->EnqueueDataFrame(data_frame); + return len; } bool MessageFullyRead() const { - return framer_->MessageFullyRead(); + return http_framer_->MessageFullyRead(); + } + + void SetStreamID(uint32 stream_id) { + stream_id_ = stream_id; } bool Error() const { - return framer_->Error(); + return http_framer_->Error(); } const char* ErrorAsString() const { - return BalsaFrameEnums::ErrorCodeToString(framer_->ErrorCode()); + return BalsaFrameEnums::ErrorCodeToString(http_framer_->ErrorCode()); } void Reset() { - framer_->Reset(); + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "HttpSM: Reset: stream %d " + << stream_id_; + http_framer_->Reset(); + } + + void ResetForNewInterface(int32 server_idx) { } void ResetForNewConnection() { + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "HttpSM: Server connection closing " + << "to: " << ACCEPTOR_SERVER_IDENT; seq_num_ = 0; output_ordering_.Reset(); - framer_->Reset(); + http_framer_->Reset(); + if (sm_spdy_interface_) { + sm_spdy_interface_->ResetForNewInterface(server_idx_); + } + } + + void Cleanup() { + if (!(acceptor_->flip_handler_type_ == FLIP_HANDLER_HTTP_SERVER)) { + connection_->Cleanup("HttpSM Request Fully Read: stream_id " + + stream_id_); + } } - void PostAcceptHook() { + int PostAcceptHook() { + return 1; } void NewStream(uint32 stream_id, uint32 priority, const string& filename) { @@ -1798,6 +2294,8 @@ class HTTPSM : public BalsaVisitorInterface, public SMInterface { mci.stream_id = stream_id; mci.priority = priority; if (!memory_cache_->AssignFileData(filename, &mci)) { + // error creating new stream. + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "Sending ErrorNotFound"; SendErrorNotFound(stream_id); } else { AddToOutputOrder(mci); @@ -1810,6 +2308,9 @@ class HTTPSM : public BalsaVisitorInterface, public SMInterface { void SendEOF(uint32 stream_id) { SendEOFImpl(stream_id); + if (acceptor_->flip_handler_type_ == FLIP_HANDLER_PROXY) { + sm_spdy_interface_->ResetForNewInterface(server_idx_); + } } void SendErrorNotFound(uint32 stream_id) { @@ -1833,7 +2334,7 @@ class HTTPSM : public BalsaVisitorInterface, public SMInterface { SendDataFrameImpl(stream_id, data, len, flags, compress); } - BalsaFrame* spdy_framer() { return framer_; } + BalsaFrame* spdy_framer() { return http_framer_; } private: void SendEOFImpl(uint32 stream_id) { @@ -1842,6 +2343,9 @@ class HTTPSM : public BalsaVisitorInterface, public SMInterface { df.size = 5; df.delete_when_done = false; EnqueueDataFrame(df); + if (acceptor_->flip_handler_type_ == FLIP_HANDLER_HTTP_SERVER) { + Reset(); + } } void SendErrorNotFoundImpl(uint32 stream_id) { @@ -1875,7 +2379,8 @@ class HTTPSM : public BalsaVisitorInterface, public SMInterface { df.data = buffer; df.delete_when_done = true; sb.Read(buffer, df.size); - VLOG(2) << "******************Sending HTTP Reply header " << stream_id; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "Sending HTTP Reply header " + << stream_id_; size_t df_size = df.size; EnqueueDataFrame(df); return df_size; @@ -1890,7 +2395,8 @@ class HTTPSM : public BalsaVisitorInterface, public SMInterface { df.data = buffer; df.delete_when_done = true; sb.Read(buffer, df.size); - VLOG(2) << "******************Sending HTTP Reply header " << stream_id; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "Sending HTTP Reply header " + << stream_id_; size_t df_size = df.size; EnqueueDataFrame(df); return df_size; @@ -1913,27 +2419,31 @@ class HTTPSM : public BalsaVisitorInterface, public SMInterface { } void EnqueueDataFrame(const DataFrame& df) { + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "HttpSM: Enqueue data frame: stream " + << stream_id_; connection_->EnqueueDataFrame(df); } void GetOutput() { MemCacheIter* mci = output_ordering_.GetIter(); if (mci == NULL) { - VLOG(2) << "GetOutput: nothing to output!?"; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "HttpSM: GetOutput: nothing to " + << "output!?: stream " << stream_id_; return; } if (!mci->transformed_header) { mci->bytes_sent = SendSynReply(mci->stream_id, *(mci->file_data->headers)); mci->transformed_header = true; - VLOG(2) << "GetOutput transformed header stream_id: [" - << mci->stream_id << "]"; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "HttpSM: GetOutput transformed " + << "header stream_id: [" << mci->stream_id << "]"; return; } if (mci->body_bytes_consumed >= mci->file_data->body.size()) { SendEOF(mci->stream_id); output_ordering_.RemoveStreamId(mci->stream_id); - VLOG(2) << "GetOutput remove_stream_id: [" << mci->stream_id << "]"; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "GetOutput remove_stream_id: [" + << mci->stream_id << "]"; return; } size_t num_to_write = @@ -1943,8 +2453,8 @@ class HTTPSM : public BalsaVisitorInterface, public SMInterface { SendDataFrame(mci->stream_id, mci->file_data->body.data() + mci->body_bytes_consumed, num_to_write, 0, true); - VLOG(2) << "GetOutput SendDataFrame[" << mci->stream_id - << "]: " << num_to_write; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "HttpSM: GetOutput SendDataFrame[" + << mci->stream_id << "]: " << num_to_write; mci->body_bytes_consumed += num_to_write; mci->bytes_sent += num_to_write; } @@ -1952,119 +2462,312 @@ class HTTPSM : public BalsaVisitorInterface, public SMInterface { //////////////////////////////////////////////////////////////////////////////// -class Notification { +class StreamerSM : public SMInterface { + private: + SMConnection* connection_; + SMInterface* sm_other_interface_; + EpollServer* epoll_server_; + FlipAcceptor* acceptor_; public: - explicit Notification(bool value) : value_(value) {} + explicit StreamerSM(SMConnection* connection, + SMInterface* sm_other_interface, + EpollServer* epoll_server, + FlipAcceptor* acceptor) : + connection_(connection), + sm_other_interface_(sm_other_interface), + epoll_server_(epoll_server), + acceptor_(acceptor) + { + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "Creating StreamerSM object"; + } + ~StreamerSM() { + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Destroying StreamerSM object"; + Reset(); + } + + void InitSMInterface(SMInterface* sm_other_interface, + int32 server_idx) + { + sm_other_interface_ = sm_other_interface; + } + + void InitSMConnection(SMConnectionPoolInterface* connection_pool, + SMInterface* sm_interface, + EpollServer* epoll_server, + int fd, + bool use_ssl) + { + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "StreamerSM: Initializing server " + << "connection."; + connection_->InitSMConnection(connection_pool, sm_interface, + epoll_server, fd, use_ssl); + } + + size_t ProcessReadInput(const char* data, size_t len) { + return sm_other_interface_->ProcessWriteInput(data, len); + } + + size_t ProcessWriteInput(const char* data, size_t len) { + char * dataPtr = new char[len]; + memcpy(dataPtr, data, len); + DataFrame df; + df.data = (const char *)dataPtr; + df.size = len; + df.delete_when_done = true; + connection_->EnqueueDataFrame(df); + return len; + } + + bool MessageFullyRead() const { + return false; + } + + void SetStreamID(uint32 stream_id) {} + + bool Error() const { + return false; + } + + const char* ErrorAsString() const { + return "(none)"; + } + + void Reset() { + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "StreamerSM: Reset"; + connection_->Cleanup("Server Reset"); + } + + void ResetForNewInterface(int32 server_idx) { + } + + void ResetForNewConnection() { + sm_other_interface_->Reset(); + } + + void Cleanup() { + } + + int PostAcceptHook() { + if (!sm_other_interface_) { + SMConnection *server_connection = + SMConnection::NewSMConnection(epoll_server_, NULL, NULL, + acceptor_, "server_conn: "); + if (server_connection == NULL) { + LOG(ERROR) << "StreamerSM: Could not create server conenction."; + return 0; + } + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "StreamerSM: Creating new server " + << "connection."; + sm_other_interface_ = new StreamerSM(server_connection, this, + epoll_server_, acceptor_); + sm_other_interface_->InitSMInterface(this, 0); + } + sm_other_interface_->InitSMConnection(NULL, sm_other_interface_, + epoll_server_, -1, false); + + return 1; + } + + void NewStream(uint32 stream_id, uint32 priority, const string& filename) { + } + + void AddToOutputOrder(const MemCacheIter& mci) { + } + + void SendEOF(uint32 stream_id) { + } + + void SendErrorNotFound(uint32 stream_id) { + } + + void SendOKResponse(uint32 stream_id, string output) { + } + + size_t SendSynStream(uint32 stream_id, const BalsaHeaders& headers) { + return 0; + } + + size_t SendSynReply(uint32 stream_id, const BalsaHeaders& headers) { + return 0; + } + + void SendDataFrame(uint32 stream_id, const char* data, int64 len, + uint32 flags, bool compress) { + } - void Notify() { - AutoLock al(lock_); - value_ = true; - } - bool HasBeenNotified() { - AutoLock al(lock_); - return value_; - } - bool value_; - Lock lock_; + private: + void SendEOFImpl(uint32 stream_id) { + } + + void SendErrorNotFoundImpl(uint32 stream_id) { + } + + void SendOKResponseImpl(uint32 stream_id, string* output) { + } + + size_t SendSynReplyImpl(uint32 stream_id, const BalsaHeaders& headers) { + return 0; + } + + size_t SendSynStreamImpl(uint32 stream_id, const BalsaHeaders& headers) { + return 0; + } + + void SendDataFrameImpl(uint32 stream_id, const char* data, int64 len, + uint32 flags, bool compress) { + } + + void GetOutput() { + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +class Notification { + public: + explicit Notification(bool value) : value_(value) {} + + void Notify() { + AutoLock al(lock_); + value_ = true; + } + bool HasBeenNotified() { + AutoLock al(lock_); + return value_; + } + bool value_; + Lock lock_; }; //////////////////////////////////////////////////////////////////////////////// class SMAcceptorThread : public SimpleThread, public EpollCallbackInterface, - public SMServerConnectionPoolInterface { + public SMConnectionPoolInterface { EpollServer epoll_server_; - int listen_fd_; - int accepts_per_wake_; + FlipAcceptor *acceptor_; + SSLState *ssl_state_; + bool use_ssl_; - vector<SMServerConnection*> unused_server_connections_; - vector<SMServerConnection*> tmp_unused_server_connections_; - vector<SMServerConnection*> allocated_server_connections_; + vector<SMConnection*> unused_server_connections_; + vector<SMConnection*> tmp_unused_server_connections_; + vector<SMConnection*> allocated_server_connections_; Notification quitting_; - SMInterfaceFactory* sm_interface_factory_; MemoryCache* memory_cache_; public: - SMAcceptorThread(int listen_fd, - int accepts_per_wake, - SMInterfaceFactory* smif, + SMAcceptorThread(FlipAcceptor *acceptor, MemoryCache* memory_cache) : SimpleThread("SMAcceptorThread"), - listen_fd_(listen_fd), - accepts_per_wake_(accepts_per_wake), + acceptor_(acceptor), + ssl_state_(NULL), + use_ssl_(false), quitting_(false), - sm_interface_factory_(smif), - memory_cache_(memory_cache) { + memory_cache_(memory_cache) + { + if (!acceptor->ssl_cert_filename_.empty() && + !acceptor->ssl_cert_filename_.empty()) { + ssl_state_ = new SSLState; + bool use_npn = true; + if (acceptor_->flip_handler_type_ == FLIP_HANDLER_HTTP_SERVER) { + use_npn = false; + } + spdy_init_ssl(ssl_state_, acceptor_->ssl_cert_filename_, + acceptor_->ssl_key_filename_, use_npn); + use_ssl_ = true; + } } ~SMAcceptorThread() { - for (vector<SMServerConnection*>::iterator i = - allocated_server_connections_.begin(); + for (vector<SMConnection*>::iterator i = + allocated_server_connections_.begin(); i != allocated_server_connections_.end(); ++i) { delete *i; } } - SMServerConnection* NewConnection() { - SMServerConnection* server = - SMServerConnection::NewSMServerConnection(sm_interface_factory_, - memory_cache_, - &epoll_server_); + SMConnection* NewConnection() { + SMConnection* server = + SMConnection::NewSMConnection(&epoll_server_, ssl_state_, + memory_cache_, acceptor_, + "client_conn: "); allocated_server_connections_.push_back(server); - VLOG(3) << "Making new server: " << server; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "Acceptor: Making new server."; return server; } - SMServerConnection* FindOrMakeNewSMServerConnection() { + SMConnection* FindOrMakeNewSMConnection() { if (unused_server_connections_.empty()) { return NewConnection(); } - SMServerConnection* retval = unused_server_connections_.back(); + SMConnection* server = unused_server_connections_.back(); unused_server_connections_.pop_back(); - return retval; + VLOG(2) << ACCEPTOR_CLIENT_IDENT << "Acceptor: Reusing server."; + return server; } - void InitWorker() { - epoll_server_.RegisterFD(listen_fd_, this, EPOLLIN | EPOLLET); + epoll_server_.RegisterFD(acceptor_->listen_fd_, this, EPOLLIN | EPOLLET); } - void HandleConnection(int client_fd) { - SMServerConnection* server_connection = FindOrMakeNewSMServerConnection(); + void HandleConnection(int server_fd) { + int on = 1; + int rc; + if (acceptor_->disable_nagle_) { + rc = setsockopt(server_fd, IPPROTO_TCP, TCP_NODELAY, + reinterpret_cast<char*>(&on), sizeof(on)); + if (rc < 0) { + close(server_fd); + LOG(ERROR) << "setsockopt() failed fd=" + server_fd; + return; + } + } + + SMConnection* server_connection = FindOrMakeNewSMConnection(); if (server_connection == NULL) { - VLOG(2) << "Closing " << client_fd; - close(client_fd); + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Acceptor: Closing fd " << server_fd; + close(server_fd); return; } - server_connection->InitSMServerConnection(this, - &epoll_server_, - client_fd); + server_connection->InitSMConnection(this, + NULL, + &epoll_server_, + server_fd, + use_ssl_); } void AcceptFromListenFD() { - if (accepts_per_wake_ > 0) { - for (int i = 0; i < accepts_per_wake_; ++i) { + if (acceptor_->accepts_per_wake_ > 0) { + for (int i = 0; i < acceptor_->accepts_per_wake_; ++i) { struct sockaddr address; socklen_t socklen = sizeof(address); - int fd = accept(listen_fd_, &address, &socklen); + int fd = accept(acceptor_->listen_fd_, &address, &socklen); if (fd == -1) { - VLOG(2) << "accept fail(" << listen_fd_ << "): " << errno; + if (errno != 11) { + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Acceptor: accept fail(" + << acceptor_->listen_fd_ << "): " << errno << ": " + << strerror(errno); + } break; } - VLOG(2) << "********************Accepted fd: " << fd << "\n\n\n"; + VLOG(1) << ACCEPTOR_CLIENT_IDENT << " Accepted connection"; HandleConnection(fd); } } else { while (true) { struct sockaddr address; socklen_t socklen = sizeof(address); - int fd = accept(listen_fd_, &address, &socklen); + int fd = accept(acceptor_->listen_fd_, &address, &socklen); if (fd == -1) { - VLOG(2) << "accept fail(" << listen_fd_ << "): " << errno; + if (errno != 11) { + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Acceptor: accept fail(" + << acceptor_->listen_fd_ << "): " << errno << ": " + << strerror(errno); + } break; } - VLOG(2) << "********************Accepted fd: " << fd << "\n\n\n"; + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Accepted connection"; HandleConnection(fd); } } @@ -2075,7 +2778,8 @@ class SMAcceptorThread : public SimpleThread, virtual void OnModification(int fd, int event_mask) { } virtual void OnEvent(int fd, EpollEvent* event) { if (event->in_events | EPOLLIN) { - VLOG(2) << "Accepting based upon epoll events"; + VLOG(2) << ACCEPTOR_CLIENT_IDENT + << "Acceptor: Accepting based upon epoll events"; AcceptFromListenFD(); } } @@ -2097,56 +2801,58 @@ class SMAcceptorThread : public SimpleThread, } } - // SMServerConnections will use this: - virtual void SMServerConnectionDone(SMServerConnection* sc) { - VLOG(3) << "Done with server connection: " << sc; + // SMConnections will use this: + virtual void SMConnectionDone(SMConnection* sc) { + VLOG(1) << ACCEPTOR_CLIENT_IDENT << "Done with connection."; tmp_unused_server_connections_.push_back(sc); } }; //////////////////////////////////////////////////////////////////////////////// -SMInterface* NewSpdySM(SMServerConnection* connection) { - return new SpdySM(connection); +SMInterface* NewStreamerSM(SMConnection* connection, + SMInterface* sm_interface, + EpollServer* epoll_server, + FlipAcceptor* acceptor) { + return new StreamerSM(connection, sm_interface, epoll_server, acceptor); } -SMInterface* NewHTTPSM(SMServerConnection* connection) { - return new HTTPSM(connection); + +SMInterface* NewSpdySM(SMConnection* connection, + SMInterface* sm_interface, + EpollServer* epoll_server, + MemoryCache* memory_cache, + FlipAcceptor* acceptor) { + return new SpdySM(connection, sm_interface, epoll_server, + memory_cache, acceptor); +} + +SMInterface* NewHttpSM(SMConnection* connection, + SMInterface* sm_interface, + EpollServer* epoll_server, + MemoryCache* memory_cache, + FlipAcceptor* acceptor) { + return new HttpSM(connection, sm_interface, epoll_server, + memory_cache, acceptor); } //////////////////////////////////////////////////////////////////////////////// -int CreateListeningSocket(int port, int backlog_size, - bool reuseport, bool no_nagle) { - int listening_socket = 0; - char port_buf[256]; - snprintf(port_buf, sizeof(port_buf), "%d", port); - cerr <<" Attempting to listen on port: " << port_buf << "\n"; - cerr <<" input port: " << port << "\n"; - net::CreateListeningSocket("", - port_buf, - true, - backlog_size, - &listening_socket, - true, - reuseport, - &cerr); - SetNonBlocking(listening_socket); - if (no_nagle) { - // set SO_REUSEADDR on the listening socket. - int on = 1; - int rc; - rc = setsockopt(listening_socket, IPPROTO_TCP, TCP_NODELAY, - reinterpret_cast<char*>(&on), sizeof(on)); - if (rc < 0) { - close(listening_socket); - LOG(FATAL) << "setsockopt() failed fd=" << listening_socket << "\n"; - } +std::vector<std::string> &split(const std::string &s, + char delim, + std::vector<std::string> &elems) { + std::stringstream ss(s); + std::string item; + while(std::getline(ss, item, delim)) { + elems.push_back(item); } - return listening_socket; + return elems; } -//////////////////////////////////////////////////////////////////////////////// +std::vector<std::string> split(const std::string &s, char delim) { + std::vector<std::string> elems; + return split(s, delim, elems); +} bool GotQuitFromStdin() { // Make stdin nonblocking. Yes this is done each time. Oh well. @@ -2157,16 +2863,13 @@ bool GotQuitFromStdin() { maybequit += c; } if (maybequit.size()) { - VLOG(2) << "scanning string: \"" << maybequit << "\""; + VLOG(1) << "scanning string: \"" << maybequit << "\""; } return (maybequit.size() > 1 && (maybequit.c_str()[0] == 'q' || maybequit.c_str()[0] == 'Q')); } - -//////////////////////////////////////////////////////////////////////////////// - const char* BoolToStr(bool b) { if (b) return "true"; @@ -2175,103 +2878,171 @@ const char* BoolToStr(bool b) { //////////////////////////////////////////////////////////////////////////////// -int main(int argc, char**argv) { - bool use_ssl = FLAGS_use_ssl; - int response_count_until_close = FLAGS_response_count_until_close; - int spdy_port = FLAGS_spdy_port; - int port = FLAGS_port; - int backlog_size = FLAGS_accept_backlog_size; - bool reuseport = FLAGS_reuseport; - bool no_nagle = FLAGS_no_nagle; - double server_think_time_in_s = FLAGS_server_think_time_in_s; - int accepts_per_wake = FLAGS_accepts_per_wake; - int num_threads = 1; +int main (int argc, char**argv) +{ + unsigned int i = 0; + bool wait_for_iface = false; + + CommandLine::Init(argc, argv); + CommandLine cl(argc, argv); + + if (cl.HasSwitch("--help") || argc < 2) { + cout << argv[0] << " <options>\n"; + cout << "\t--proxy<1..n>=\"<listen ip>,<listen port>,<ssl cert filename>," + << "<ssl key filename>,<server ip>,<server port>\"\n"; + cout << "\t--spdy-server=\"<listen ip>,<listen port>,<ssl cert filename>," + << "<ssl key filename>\"\n"; + cout << "\t--http-server=\"<listen ip>,<listen port>,<ssl cert filename>," + << "<ssl key filename>\"\n"; + cout << "\t--forward-ip-header=<header name>\n"; + cout << "\t--logdest=file|system|both\n"; + cout << "\t--logfile=<logfile>\n"; + cout << "\t--wait-for-iface\n"; + cout << "\t--ssl-session-expiry=<seconds> (default is 300)\n"; + cout << "\t--help\n"; + exit(0); + } + + g_proxy_config.server_think_time_in_s_ = FLAGS_server_think_time_in_s; + + if (cl.HasSwitch("forward-ip-header")) { + g_proxy_config.forward_ip_header_enabled_ = true; + g_proxy_config.forward_ip_header_ = + cl.GetSwitchValueASCII("forward-ip-header"); + } else { + g_proxy_config.forward_ip_header_enabled_ = false; + } + + if (cl.HasSwitch("logdest")) { + string log_dest_value = cl.GetSwitchValueASCII("logdest"); + if (log_dest_value.compare("file") == 0) { + g_proxy_config.log_destination_ = logging::LOG_ONLY_TO_FILE; + } else if (log_dest_value.compare("system") == 0) { + g_proxy_config.log_destination_ = logging::LOG_ONLY_TO_SYSTEM_DEBUG_LOG; + } else if (log_dest_value.compare("both") == 0) { + g_proxy_config.log_destination_ = + logging::LOG_TO_BOTH_FILE_AND_SYSTEM_DEBUG_LOG; + } else { + LOG(FATAL) << "Invalid logging destination value: " << log_dest_value; + } + } else { + g_proxy_config.log_destination_ = logging::LOG_NONE; + } + if (cl.HasSwitch("logfile")) { + g_proxy_config.log_filename_ = cl.GetSwitchValueASCII("logfile"); + if (g_proxy_config.log_destination_ == logging::LOG_NONE) { + g_proxy_config.log_destination_ = logging::LOG_ONLY_TO_FILE; + } + } else if (g_proxy_config.log_destination_ == logging::LOG_ONLY_TO_FILE || + g_proxy_config.log_destination_ == + logging::LOG_TO_BOTH_FILE_AND_SYSTEM_DEBUG_LOG) { + LOG(FATAL) << "Logging destination requires a log file to be specified."; + } - MemoryCache spdy_memory_cache; - spdy_memory_cache.AddFiles(); + if (cl.HasSwitch("wait-for-iface")) { + wait_for_iface = true; + } + if (cl.HasSwitch("ssl-session-expiry")) { + string session_expiry = cl.GetSwitchValueASCII("ssl-session-expiry"); + g_proxy_config.ssl_session_expiry_ = atoi( session_expiry.c_str() ); + } + + InitLogging(g_proxy_config.log_filename_.c_str(), + g_proxy_config.log_destination_, + logging::DONT_LOCK_LOG_FILE, + logging::APPEND_TO_OLD_LOG_FILE); + + LOG(INFO) << "Flip SPDY proxy started with configuration:"; + LOG(INFO) << "Logging destination : " << g_proxy_config.log_destination_; + LOG(INFO) << "Log file : " << g_proxy_config.log_filename_; + LOG(INFO) << "Forward IP Header : " + << (g_proxy_config.forward_ip_header_enabled_ ? + g_proxy_config.forward_ip_header_ : "(disabled)"); + LOG(INFO) << "Wait for interfaces : " << (wait_for_iface?"true":"false"); + LOG(INFO) << "Accept backlog size : " << FLAGS_accept_backlog_size; + LOG(INFO) << "Accepts per wake : " << FLAGS_accepts_per_wake; + LOG(INFO) << "Disable nagle : " + << (FLAGS_disable_nagle?"true":"false"); + LOG(INFO) << "Reuseport : " << (FLAGS_reuseport?"true":"false"); + + // Proxy Acceptors + while (true) { + i += 1; + std::stringstream name; + name << "proxy" << i; + if (!cl.HasSwitch(name.str())) { + break; + } + string value = cl.GetSwitchValueASCII(name.str()); + vector<std::string> valueArgs = split(value, ','); + CHECK_EQ((unsigned int)6, valueArgs.size()); + // If wait_for_iface is enabled, then this call will block + // indefinitely until the interface is raised. + g_proxy_config.AddAcceptor(FLIP_HANDLER_PROXY, + valueArgs[0], valueArgs[1], + valueArgs[2], valueArgs[3], + valueArgs[4], valueArgs[5], + FLAGS_accept_backlog_size, + FLAGS_disable_nagle, + FLAGS_accepts_per_wake, + FLAGS_reuseport, + wait_for_iface, + NULL); + } + + // Spdy Server Acceptor + MemoryCache spdy_memory_cache; + if (cl.HasSwitch("spdy-server")) { + spdy_memory_cache.AddFiles(); + string value = cl.GetSwitchValueASCII("spdy-server"); + vector<std::string> valueArgs = split(value, ','); + g_proxy_config.AddAcceptor(FLIP_HANDLER_SPDY_SERVER, + valueArgs[0], valueArgs[1], + valueArgs[2], valueArgs[3], + "", "", + FLAGS_accept_backlog_size, + FLAGS_disable_nagle, + FLAGS_accepts_per_wake, + FLAGS_reuseport, + wait_for_iface, + &spdy_memory_cache); + } + + // Spdy Server Acceptor MemoryCache http_memory_cache; - http_memory_cache.CloneFrom(spdy_memory_cache); - - LOG(INFO) << - "Starting up with the following state: \n" - " use_ssl: " << use_ssl << "\n" - " response_count_until_close: " << response_count_until_close << "\n" - " port: " << port << "\n" - " spdy_port: " << spdy_port << "\n" - " backlog_size: " << backlog_size << "\n" - " reuseport: " << BoolToStr(reuseport) << "\n" - " no_nagle: " << BoolToStr(no_nagle) << "\n" - " server_think_time_in_s: " << server_think_time_in_s << "\n" - " accepts_per_wake: " << accepts_per_wake << "\n" - " num_threads: " << num_threads << "\n" - " use_xsub: " << BoolToStr(FLAGS_use_xsub) << "\n" - " use_xac: " << BoolToStr(FLAGS_use_xac) << "\n"; - - if (use_ssl) { - global_ssl_state = new GlobalSSLState; - spdy_init_ssl(global_ssl_state); - } else { - global_ssl_state = NULL; + if (cl.HasSwitch("http-server")) { + http_memory_cache.AddFiles(); + string value = cl.GetSwitchValueASCII("http-server"); + vector<std::string> valueArgs = split(value, ','); + g_proxy_config.AddAcceptor(FLIP_HANDLER_HTTP_SERVER, + valueArgs[0], valueArgs[1], + valueArgs[2], valueArgs[3], + "", "", + FLAGS_accept_backlog_size, + FLAGS_disable_nagle, + FLAGS_accepts_per_wake, + FLAGS_reuseport, + wait_for_iface, + &http_memory_cache); } - EpollServer epoll_server; + vector<SMAcceptorThread*> sm_worker_threads_; - { - // spdy - int listen_fd = -1; - - if (reuseport || listen_fd == -1) { - listen_fd = CreateListeningSocket(spdy_port, backlog_size, - reuseport, no_nagle); - if (listen_fd < 0) { - LOG(FATAL) << "Unable to open listening socket on spdy_port: " - << spdy_port; - } else { - LOG(INFO) << "Listening for spdy on port: " << spdy_port; - } - } - sm_worker_threads_.push_back( - new SMAcceptorThread(listen_fd, - accepts_per_wake, - &NewSpdySM, - &spdy_memory_cache)); - // Note that spdy_memory_cache is not threadsafe, it is merely - // thread compatible. Thus, if ever we are to spawn multiple threads, - // we either must make the MemoryCache threadsafe, or use - // a separate MemoryCache for each thread. - // - // The latter is what is currently being done as we spawn - // two threads (one for spdy, one for http). - sm_worker_threads_.back()->InitWorker(); - sm_worker_threads_.back()->Start(); - } + for (i = 0; i < g_proxy_config.acceptors_.size(); i++) { + FlipAcceptor *acceptor = g_proxy_config.acceptors_[i]; - { - // http - int listen_fd = -1; - if (reuseport || listen_fd == -1) { - listen_fd = CreateListeningSocket(port, backlog_size, - reuseport, no_nagle); - if (listen_fd < 0) { - LOG(FATAL) << "Unable to open listening socket on port: " << port; - } else { - LOG(INFO) << "Listening for HTTP on port: " << port; - } - } - sm_worker_threads_.push_back( - new SMAcceptorThread(listen_fd, - accepts_per_wake, - &NewHTTPSM, - &http_memory_cache)); + sm_worker_threads_.push_back(new SMAcceptorThread(acceptor, + (MemoryCache *)acceptor->memory_cache_)); // Note that spdy_memory_cache is not threadsafe, it is merely // thread compatible. Thus, if ever we are to spawn multiple threads, // we either must make the MemoryCache threadsafe, or use // a separate MemoryCache for each thread. // // The latter is what is currently being done as we spawn - // two threads (one for spdy, one for http). + // a separate thread for each http and spdy server acceptor. + sm_worker_threads_.back()->InitWorker(); sm_worker_threads_.back()->Start(); } @@ -2288,5 +3059,6 @@ int main(int argc, char**argv) { } usleep(1000*10); // 10 ms } + return 0; } diff --git a/net/tools/flip_server/other_defines.h b/net/tools/flip_server/other_defines.h deleted file mode 100644 index dda2151..0000000 --- a/net/tools/flip_server/other_defines.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) 2010 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_TOOLS_FLIP_SERVER_OTHER_DEFINES -#define NET_TOOLS_FLIP_SERVER_OTHER_DEFINES -#pragma once - -class NullStream { - public: - NullStream() {} - template <typename T> - NullStream operator<<(T t) { return *this;} -}; - -#define VLOG(X) NullStream() -#define DVLOG(X) NullStream() - - -#endif // NET_TOOLS_FLIP_SERVER_OTHER_DEFINES - diff --git a/net/tools/flip_server/ring_buffer.h b/net/tools/flip_server/ring_buffer.h index ee6f360..75ad0ec 100644 --- a/net/tools/flip_server/ring_buffer.h +++ b/net/tools/flip_server/ring_buffer.h @@ -8,7 +8,6 @@ #include "base/scoped_ptr.h" #include "net/tools/flip_server/buffer_interface.h" -#include "net/tools/flip_server/other_defines.h" namespace net { diff --git a/net/tools/flip_server/simple_buffer.h b/net/tools/flip_server/simple_buffer.h index 71fafc9..fee7d4b 100644 --- a/net/tools/flip_server/simple_buffer.h +++ b/net/tools/flip_server/simple_buffer.h @@ -9,7 +9,6 @@ #include <string> #include "net/tools/flip_server/buffer_interface.h" -#include "net/tools/flip_server/other_defines.h" namespace net { diff --git a/net/tools/testserver/device_management.py b/net/tools/testserver/device_management.py index 7037608..d715227 100644 --- a/net/tools/testserver/device_management.py +++ b/net/tools/testserver/device_management.py @@ -186,7 +186,7 @@ class RequestHandler(object): # Respond only if the client requested policy for the cros/device scope, # since that's where chrome policy is supposed to live in. - if msg.policy_scope == 'cros/device': + if msg.policy_scope == 'chromeos/device': setting = response.policy_response.setting.add() setting.policy_key = 'chrome-policy' policy_value = dm.GenericSetting() @@ -226,17 +226,19 @@ class RequestHandler(object): will contain the error response to send back. """ error = None - + dmtoken = None + request_device_id = self.GetUniqueParam('deviceid') match = re.match('GoogleDMToken token=(\\w+)', self._headers.getheader('Authorization', '')) if match: dmtoken = match.group(1) - if not dmtoken: - error = dm.DeviceManagementResponse.DEVICE_MANAGEMENT_TOKEN_INVALID - elif not self._server.LookupDevice(dmtoken): - error = dm.DeviceManagementResponse.DEVICE_NOT_FOUND - else: - return (dmtoken, None) + if not dmtoken: + error = dm.DeviceManagementResponse.DEVICE_MANAGEMENT_TOKEN_INVALID + elif (not request_device_id or + not self._server.LookupDevice(dmtoken) == request_device_id): + error = dm.DeviceManagementResponse.DEVICE_NOT_FOUND + else: + return (dmtoken, None) response = dm.DeviceManagementResponse() response.error = error diff --git a/net/tools/testserver/run_testserver.cc b/net/tools/testserver/run_testserver.cc index e92dfb2..650b415 100644 --- a/net/tools/testserver/run_testserver.cc +++ b/net/tools/testserver/run_testserver.cc @@ -6,6 +6,7 @@ #include "base/at_exit.h" #include "base/command_line.h" +#include "base/file_path.h" #include "base/logging.h" #include "base/message_loop.h" #include "net/test/test_server.h" @@ -23,6 +24,14 @@ int main(int argc, const char* argv[]) { CommandLine::Init(argc, argv); CommandLine* command_line = CommandLine::ForCurrentProcess(); + if (!logging::InitLogging(FILE_PATH_LITERAL("testserver.log"), + logging::LOG_TO_BOTH_FILE_AND_SYSTEM_DEBUG_LOG, + logging::LOCK_LOG_FILE, + logging::APPEND_TO_OLD_LOG_FILE)) { + printf("Error: could not initialize logging. Exiting.\n"); + return -1; + } + if (command_line->GetSwitchCount() == 0 || command_line->HasSwitch("help")) { PrintUsage(); diff --git a/net/tools/testserver/testserver.py b/net/tools/testserver/testserver.py index 2551139..c0dc930 100755 --- a/net/tools/testserver/testserver.py +++ b/net/tools/testserver/testserver.py @@ -13,14 +13,18 @@ It can use https if you specify the flag --https=CERT where CERT is the path to a pem file containing the certificate and private key that should be used. """ +import asyncore import base64 import BaseHTTPServer import cgi +import errno import optparse import os import re -import shutil +import select +import simplejson import SocketServer +import socket import sys import struct import time @@ -48,6 +52,7 @@ SERVER_HTTP = 0 SERVER_FTP = 1 SERVER_SYNC = 2 +# Using debug() seems to cause hangs on XP: see http://crbug.com/64515 . debug_output = sys.stderr def debug(str): debug_output.write(str + "\n") @@ -115,12 +120,91 @@ class SyncHTTPServer(StoppableHTTPServer): # We import here to avoid pulling in chromiumsync's dependencies # unless strictly necessary. import chromiumsync - self._sync_handler = chromiumsync.TestServer() + import xmppserver StoppableHTTPServer.__init__(self, server_address, request_handler_class) + self._sync_handler = chromiumsync.TestServer() + self._xmpp_socket_map = {} + self._xmpp_server = xmppserver.XmppServer( + self._xmpp_socket_map, ('localhost', 0)) + self.xmpp_port = self._xmpp_server.getsockname()[1] def HandleCommand(self, query, raw_request): return self._sync_handler.HandleCommand(query, raw_request) + def HandleRequestNoBlock(self): + """Handles a single request. + + Copied from SocketServer._handle_request_noblock(). + """ + try: + request, client_address = self.get_request() + except socket.error: + return + if self.verify_request(request, client_address): + try: + self.process_request(request, client_address) + except: + self.handle_error(request, client_address) + self.close_request(request) + + def serve_forever(self): + """This is a merge of asyncore.loop() and SocketServer.serve_forever(). + """ + + def RunDispatcherHandler(dispatcher, handler): + """Handles a single event for an asyncore.dispatcher. + + Adapted from asyncore.read() et al. + """ + try: + handler(dispatcher) + except (asyncore.ExitNow, KeyboardInterrupt, SystemExit): + raise + except: + dispatcher.handle_error() + + while True: + read_fds = [ self.fileno() ] + write_fds = [] + exceptional_fds = [] + + for fd, xmpp_connection in self._xmpp_socket_map.items(): + is_r = xmpp_connection.readable() + is_w = xmpp_connection.writable() + if is_r: + read_fds.append(fd) + if is_w: + write_fds.append(fd) + if is_r or is_w: + exceptional_fds.append(fd) + + try: + read_fds, write_fds, exceptional_fds = ( + select.select(read_fds, write_fds, exceptional_fds)) + except select.error, err: + if err.args[0] != errno.EINTR: + raise + else: + continue + + for fd in read_fds: + if fd == self.fileno(): + self.HandleRequestNoBlock() + continue + xmpp_connection = self._xmpp_socket_map.get(fd) + RunDispatcherHandler(xmpp_connection, + asyncore.dispatcher.handle_read_event) + + for fd in write_fds: + xmpp_connection = self._xmpp_socket_map.get(fd) + RunDispatcherHandler(xmpp_connection, + asyncore.dispatcher.handle_write_event) + + for fd in exceptional_fds: + xmpp_connection = self._xmpp_socket_map.get(fd) + RunDispatcherHandler(xmpp_connection, + asyncore.dispatcher.handle_expt_event) + class BasePageHandler(BaseHTTPServer.BaseHTTPRequestHandler): @@ -1263,6 +1347,8 @@ def main(options, args): port = options.port + server_data = {} + if options.server_type == SERVER_HTTP: if options.cert: # let's make sure the cert file exists. @@ -1285,12 +1371,14 @@ def main(options, args): server.data_dir = MakeDataDir() server.file_root_url = options.file_root_url - listen_port = server.server_port + server_data['port'] = server.server_port server._device_management_handler = None elif options.server_type == SERVER_SYNC: server = SyncHTTPServer(('127.0.0.1', port), SyncPageHandler) print 'Sync HTTP server started on port %d...' % server.server_port - listen_port = server.server_port + print 'Sync XMPP server started on port %d...' % server.xmpp_port + server_data['port'] = server.server_port + server_data['xmpp_port'] = server.xmpp_port # means FTP Server else: my_data_dir = MakeDataDir() @@ -1315,21 +1403,26 @@ def main(options, args): # Instantiate FTP server class and listen to 127.0.0.1:port address = ('127.0.0.1', port) server = pyftpdlib.ftpserver.FTPServer(address, ftp_handler) - listen_port = server.socket.getsockname()[1] - print 'FTP server started on port %d...' % listen_port + server_data['port'] = server.socket.getsockname()[1] + print 'FTP server started on port %d...' % server_data['port'] # Notify the parent that we've started. (BaseServer subclasses # bind their sockets on construction.) if options.startup_pipe is not None: + server_data_json = simplejson.dumps(server_data) + server_data_len = len(server_data_json) + print 'sending server_data: %s (%d bytes)' % ( + server_data_json, server_data_len) if sys.platform == 'win32': fd = msvcrt.open_osfhandle(options.startup_pipe, 0) else: fd = options.startup_pipe startup_pipe = os.fdopen(fd, "w") - # Write the listening port as a 2 byte value. This is _not_ using - # network byte ordering since the other end of the pipe is on the same - # machine. - startup_pipe.write(struct.pack('@H', listen_port)) + # First write the data length as an unsigned 4-byte value. This + # is _not_ using network byte ordering since the other end of the + # pipe is on the same machine. + startup_pipe.write(struct.pack('=L', server_data_len)) + startup_pipe.write(server_data_json) startup_pipe.close() try: diff --git a/net/tools/testserver/xmppserver.py b/net/tools/testserver/xmppserver.py index ad99571..ac9c276 100644 --- a/net/tools/testserver/xmppserver.py +++ b/net/tools/testserver/xmppserver.py @@ -520,7 +520,6 @@ class XmppServer(asyncore.dispatcher): self._socket_map = socket_map self._socket_map[self.fileno()] = self self._connections = set() - print 'XMPP server running at %s' % AddrString(addr) def handle_accept(self): (sock, addr) = self.accept() diff --git a/net/url_request/url_request.cc b/net/url_request/url_request.cc index 7898848..ac54a1f 100644 --- a/net/url_request/url_request.cc +++ b/net/url_request/url_request.cc @@ -89,6 +89,7 @@ void URLRequest::Delegate::OnGetCookies(URLRequest* request, void URLRequest::Delegate::OnSetCookie(URLRequest* request, const std::string& cookie_line, + const net::CookieOptions& options, bool blocked_by_policy) { } diff --git a/net/url_request/url_request.h b/net/url_request/url_request.h index 01f5984..fb81500 100644 --- a/net/url_request/url_request.h +++ b/net/url_request/url_request.h @@ -29,21 +29,23 @@ class Time; } // namespace base namespace net { +class CookieOptions; class IOBuffer; class SSLCertRequestInfo; class UploadData; +class URLRequestJob; class X509Certificate; } // namespace net class FilePath; class URLRequestContext; -class URLRequestJob; // This stores the values of the Set-Cookie headers received during the request. // Each item in the vector corresponds to a Set-Cookie: line received, // excluding the "Set-Cookie:" part. typedef std::vector<std::string> ResponseCookies; +namespace net { //----------------------------------------------------------------------------- // A class representing the asynchronous load of a data stream from an URL. // @@ -187,6 +189,7 @@ class URLRequest : public NonThreadSafe { // when LOAD_DO_NOT_SAVE_COOKIES is specified. virtual void OnSetCookie(URLRequest* request, const std::string& cookie_line, + const net::CookieOptions& options, bool blocked_by_policy); // After calling Start(), the delegate will receive an OnResponseStarted @@ -302,7 +305,7 @@ class URLRequest : public NonThreadSafe { // expected modification time is provided (non-zero), it will be used to // check if the underlying file has been changed or not. The granularity of // the time comparison is 1 second since time_t precision is used in WebKit. - void AppendBytesToUpload(const char* bytes, int bytes_len); + void AppendBytesToUpload(const char* bytes, int bytes_len); // takes a copy void AppendFileRangeToUpload(const FilePath& file_path, uint64 offset, uint64 length, const base::Time& expected_modification_time); @@ -431,9 +434,7 @@ class URLRequest : public NonThreadSafe { // and the response has not yet been called). bool is_pending() const { return is_pending_; } - // Returns the error status of the request. This value is 0 if there is no - // error. Otherwise, it is a value defined by the operating system (e.g., an - // error code returned by GetLastError() on windows). + // Returns the error status of the request. const URLRequestStatus& status() const { return status_; } // This method is called to start the request. The delegate will receive @@ -643,4 +644,8 @@ class URLRequest : public NonThreadSafe { DISALLOW_COPY_AND_ASSIGN(URLRequest); }; +} // namespace net + +typedef net::URLRequest URLRequest; + #endif // NET_URL_REQUEST_URL_REQUEST_H_ diff --git a/net/url_request/url_request_context.cc b/net/url_request/url_request_context.cc index 137901d..281aa7e 100644 --- a/net/url_request/url_request_context.cc +++ b/net/url_request/url_request_context.cc @@ -12,6 +12,7 @@ URLRequestContext::URLRequestContext() : net_log_(NULL), host_resolver_(NULL), dnsrr_resolver_(NULL), + dns_cert_checker_(NULL), http_transaction_factory_(NULL), ftp_transaction_factory_(NULL), http_auth_handler_factory_(NULL), diff --git a/net/url_request/url_request_context.h b/net/url_request/url_request_context.h index a8b19eb..bc601b3 100644 --- a/net/url_request/url_request_context.h +++ b/net/url_request/url_request_context.h @@ -18,10 +18,12 @@ #include "net/base/transport_security_state.h" #include "net/ftp/ftp_auth_cache.h" #include "net/proxy/proxy_service.h" +#include "net/socket/dns_cert_provenance_checker.h" namespace net { class CookiePolicy; class CookieStore; +class DnsCertProvenanceChecker; class DnsRRResolver; class FtpTransactionFactory; class HostResolver; @@ -29,8 +31,8 @@ class HttpAuthHandlerFactory; class HttpNetworkDelegate; class HttpTransactionFactory; class SSLConfigService; -} class URLRequest; +} // namespace net // Subclass to provide application-specific context for URLRequest instances. class URLRequestContext @@ -51,6 +53,10 @@ class URLRequestContext return dnsrr_resolver_; } + net::DnsCertProvenanceChecker* dns_cert_checker() const { + return dns_cert_checker_.get(); + } + // Get the proxy service for this context. net::ProxyService* proxy_service() const { return proxy_service_; @@ -129,6 +135,7 @@ class URLRequestContext net::NetLog* net_log_; net::HostResolver* host_resolver_; net::DnsRRResolver* dnsrr_resolver_; + scoped_ptr<net::DnsCertProvenanceChecker> dns_cert_checker_; scoped_refptr<net::ProxyService> proxy_service_; scoped_refptr<net::SSLConfigService> ssl_config_service_; net::HttpTransactionFactory* http_transaction_factory_; diff --git a/net/url_request/url_request_data_job.h b/net/url_request/url_request_data_job.h index 1bb868a..7171088 100644 --- a/net/url_request/url_request_data_job.h +++ b/net/url_request/url_request_data_job.h @@ -1,4 +1,4 @@ -// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. @@ -11,17 +11,19 @@ #include "net/url_request/url_request.h" #include "net/url_request/url_request_simple_job.h" +namespace net { class URLRequest; +} // namespace net class URLRequestDataJob : public URLRequestSimpleJob { public: - explicit URLRequestDataJob(URLRequest* request); + explicit URLRequestDataJob(net::URLRequest* request); virtual bool GetData(std::string* mime_type, std::string* charset, std::string* data) const; - static URLRequest::ProtocolFactory Factory; + static net::URLRequest::ProtocolFactory Factory; private: ~URLRequestDataJob(); @@ -30,4 +32,3 @@ class URLRequestDataJob : public URLRequestSimpleJob { }; #endif // NET_URL_REQUEST_URL_REQUEST_DATA_JOB_H_ - diff --git a/net/url_request/url_request_error_job.cc b/net/url_request/url_request_error_job.cc index 4bb7195..1aeffac 100644 --- a/net/url_request/url_request_error_job.cc +++ b/net/url_request/url_request_error_job.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved. +// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. @@ -8,7 +8,7 @@ #include "net/base/net_errors.h" #include "net/url_request/url_request_status.h" -URLRequestErrorJob::URLRequestErrorJob(URLRequest* request, int error) +URLRequestErrorJob::URLRequestErrorJob(net::URLRequest* request, int error) : URLRequestJob(request), error_(error) { } diff --git a/net/url_request/url_request_error_job.h b/net/url_request/url_request_error_job.h index efaea0c..6e7c879 100644 --- a/net/url_request/url_request_error_job.h +++ b/net/url_request/url_request_error_job.h @@ -1,4 +1,4 @@ -// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved. +// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // @@ -13,15 +13,16 @@ class URLRequestErrorJob : public URLRequestJob { public: - URLRequestErrorJob(URLRequest* request, int error); + URLRequestErrorJob(net::URLRequest* request, int error); virtual void Start(); private: ~URLRequestErrorJob() {} - int error_; void StartAsync(); + + int error_; }; #endif // NET_URL_REQUEST_URL_REQUEST_ERROR_JOB_H_ diff --git a/net/url_request/url_request_file_dir_job.h b/net/url_request/url_request_file_dir_job.h index c78de97..aefaf5b 100644 --- a/net/url_request/url_request_file_dir_job.h +++ b/net/url_request/url_request_file_dir_job.h @@ -17,7 +17,7 @@ class URLRequestFileDirJob : public URLRequestJob, public net::DirectoryLister::DirectoryListerDelegate { public: - URLRequestFileDirJob(URLRequest* request, const FilePath& dir_path); + URLRequestFileDirJob(net::URLRequest* request, const FilePath& dir_path); // URLRequestJob methods: virtual void Start(); diff --git a/net/url_request/url_request_file_job.h b/net/url_request/url_request_file_job.h index 4512c1c..e745cfd 100644 --- a/net/url_request/url_request_file_job.h +++ b/net/url_request/url_request_file_job.h @@ -36,6 +36,10 @@ class URLRequestFileJob : public URLRequestJob { static URLRequest::ProtocolFactory Factory; +#if defined(OS_CHROMEOS) + static bool AccessDisabled(const FilePath& file_path); +#endif + protected: virtual ~URLRequestFileJob(); @@ -45,9 +49,6 @@ class URLRequestFileJob : public URLRequestJob { private: void DidResolve(bool exists, const base::PlatformFileInfo& file_info); void DidRead(int result); -#if defined(OS_CHROMEOS) - static bool AccessDisabled(const FilePath& file_path); -#endif net::CompletionCallbackImpl<URLRequestFileJob> io_callback_; net::FileStream stream_; diff --git a/net/url_request/url_request_filter.cc b/net/url_request/url_request_filter.cc index 44863e3..fb305b2 100644 --- a/net/url_request/url_request_filter.cc +++ b/net/url_request/url_request_filter.cc @@ -18,8 +18,8 @@ URLRequestFilter* URLRequestFilter::GetInstance() { } /* static */ -URLRequestJob* URLRequestFilter::Factory(URLRequest* request, - const std::string& scheme) { +net::URLRequestJob* URLRequestFilter::Factory(net::URLRequest* request, + const std::string& scheme) { // Returning null here just means that the built-in handler will be used. return GetInstance()->FindRequestHandler(request, scheme); } @@ -112,9 +112,10 @@ void URLRequestFilter::ClearHandlers() { URLRequestFilter::URLRequestFilter() : hit_count_(0) { } -URLRequestJob* URLRequestFilter::FindRequestHandler(URLRequest* request, - const std::string& scheme) { - URLRequestJob* job = NULL; +net::URLRequestJob* URLRequestFilter::FindRequestHandler( + net::URLRequest* request, + const std::string& scheme) { + net::URLRequestJob* job = NULL; if (request->url().is_valid()) { // Check the hostname map first. const std::string& hostname = request->url().host(); diff --git a/net/url_request/url_request_filter.h b/net/url_request/url_request_filter.h index 17eb4bb..c3021cb 100644 --- a/net/url_request/url_request_filter.h +++ b/net/url_request/url_request_filter.h @@ -27,7 +27,10 @@ #include "net/url_request/url_request.h" class GURL; + +namespace net { class URLRequestJob; +} // namespace net class URLRequestFilter { public: @@ -67,8 +70,8 @@ class URLRequestFilter { URLRequestFilter(); // Helper method that looks up the request in the url_handler_map_. - URLRequestJob* FindRequestHandler(URLRequest* request, - const std::string& scheme); + net::URLRequestJob* FindRequestHandler(URLRequest* request, + const std::string& scheme); // Maps hostnames to factories. Hostnames take priority over URLs. HostnameHandlerMap hostname_handler_map_; diff --git a/net/url_request/url_request_ftp_job.h b/net/url_request/url_request_ftp_job.h index 48f963d..bd2fdc8 100644 --- a/net/url_request/url_request_ftp_job.h +++ b/net/url_request/url_request_ftp_job.h @@ -22,9 +22,10 @@ class URLRequestContext; class URLRequestFtpJob : public URLRequestJob { public: - explicit URLRequestFtpJob(URLRequest* request); + explicit URLRequestFtpJob(net::URLRequest* request); - static URLRequestJob* Factory(URLRequest* request, const std::string& scheme); + static URLRequestJob* Factory(net::URLRequest* request, + const std::string& scheme); // URLRequestJob methods: virtual bool GetMimeType(std::string* mime_type) const; diff --git a/net/url_request/url_request_http_job.cc b/net/url_request/url_request_http_job.cc index 025415b..72cd9a6 100644 --- a/net/url_request/url_request_http_job.cc +++ b/net/url_request/url_request_http_job.cc @@ -33,6 +33,8 @@ #include "net/url_request/url_request_context.h" #include "net/url_request/url_request_error_job.h" #include "net/url_request/url_request_redirect_job.h" +#include "net/url_request/url_request_throttler_header_adapter.h" +#include "net/url_request/url_request_throttler_manager.h" static const char kAvailDictionaryHeader[] = "Avail-Dictionary"; @@ -77,7 +79,6 @@ URLRequestJob* URLRequestHttpJob::Factory(URLRequest* request, URLRequestHttpJob::URLRequestHttpJob(URLRequest* request) : URLRequestJob(request), - context_(request->context()), response_info_(NULL), response_cookies_save_index_(0), proxy_auth_state_(net::AUTH_STATE_DONT_NEED_AUTH), @@ -92,6 +93,8 @@ URLRequestHttpJob::URLRequestHttpJob(URLRequest* request) this, &URLRequestHttpJob::OnReadCompleted)), read_in_progress_(false), transaction_(NULL), + throttling_entry_(net::URLRequestThrottlerManager::GetInstance()-> + RegisterRequestUrl(request->url())), sdch_dictionary_advertised_(false), sdch_test_activated_(false), sdch_test_control_(false), @@ -472,6 +475,7 @@ void URLRequestHttpJob::OnCanSetCookieCompleted(int policy) { request_->delegate()->OnSetCookie( request_, response_cookies_[response_cookies_save_index_], + net::CookieOptions(), true); } else if ((policy == net::OK || policy == net::OK_FOR_SESSION_ONLY) && request_->context()->cookie_store()) { @@ -486,6 +490,7 @@ void URLRequestHttpJob::OnCanSetCookieCompleted(int policy) { request_->delegate()->OnSetCookie( request_, response_cookies_[response_cookies_save_index_], + options, false); } response_cookies_save_index_++; @@ -570,6 +575,12 @@ void URLRequestHttpJob::NotifyHeadersComplete() { // also need this info. is_cached_content_ = response_info_->was_cached; + if (!is_cached_content_) { + net::URLRequestThrottlerHeaderAdapter response_adapter( + response_info_->headers); + throttling_entry_->UpdateWithResponse(&response_adapter); + } + ProcessStrictTransportSecurityHeader(); if (SdchManager::Global() && @@ -609,6 +620,7 @@ void URLRequestHttpJob::DestroyTransaction() { transaction_.reset(); response_info_ = NULL; + context_ = NULL; } void URLRequestHttpJob::StartTransaction() { @@ -618,6 +630,7 @@ void URLRequestHttpJob::StartTransaction() { // with auth provided by username_ and password_. int rv; + if (transaction_.get()) { rv = transaction_->RestartWithAuth(username_, password_, &start_callback_); username_.clear(); @@ -629,8 +642,16 @@ void URLRequestHttpJob::StartTransaction() { rv = request_->context()->http_transaction_factory()->CreateTransaction( &transaction_); if (rv == net::OK) { - rv = transaction_->Start( - &request_info_, &start_callback_, request_->net_log()); + if (!throttling_entry_->IsDuringExponentialBackoff()) { + rv = transaction_->Start( + &request_info_, &start_callback_, request_->net_log()); + } else { + // Special error code for the exponential back-off module. + rv = net::ERR_TEMPORARILY_THROTTLED; + } + // Make sure the context is alive for the duration of the + // transaction. + context_ = request_->context(); } } diff --git a/net/url_request/url_request_http_job.h b/net/url_request/url_request_http_job.h index 431756a..c9139b0 100644 --- a/net/url_request/url_request_http_job.h +++ b/net/url_request/url_request_http_job.h @@ -15,6 +15,7 @@ #include "net/base/completion_callback.h" #include "net/http/http_request_info.h" #include "net/url_request/url_request_job.h" +#include "net/url_request/url_request_throttler_entry_interface.h" namespace net { class HttpResponseInfo; @@ -26,10 +27,11 @@ class URLRequestContext; // provides an implementation for both HTTP and HTTPS. class URLRequestHttpJob : public URLRequestJob { public: - static URLRequestJob* Factory(URLRequest* request, const std::string& scheme); + static URLRequestJob* Factory(net::URLRequest* request, + const std::string& scheme); protected: - explicit URLRequestHttpJob(URLRequest* request); + explicit URLRequestHttpJob(net::URLRequest* request); // URLRequestJob methods: virtual void SetUpload(net::UploadData* upload); @@ -112,6 +114,9 @@ class URLRequestHttpJob : public URLRequestJob { scoped_ptr<net::HttpTransaction> transaction_; + // This is used to supervise traffic and enforce exponential back-off. + scoped_refptr<net::URLRequestThrottlerEntryInterface> throttling_entry_; + // Indicated if an SDCH dictionary was advertised, and hence an SDCH // compressed response is expected. We use this to help detect (accidental?) // proxy corruption of a response, which sometimes marks SDCH content as diff --git a/net/url_request/url_request_job.h b/net/url_request/url_request_job.h index ca69940..239f5e9 100644 --- a/net/url_request/url_request_job.h +++ b/net/url_request/url_request_job.h @@ -23,17 +23,16 @@ class HttpRequestHeaders; class HttpResponseInfo; class IOBuffer; class UploadData; +class URLRequest; class X509Certificate; -} +} // namespace net -class URLRequest; class URLRequestStatus; class URLRequestJobMetrics; -// The URLRequestJob is using RefCounterThreadSafe because some sub classes -// can be destroyed on multiple threads. This is the case of the -// UrlRequestFileJob. -class URLRequestJob : public base::RefCountedThreadSafe<URLRequestJob>, +namespace net { + +class URLRequestJob : public base::RefCounted<URLRequestJob>, public FilterContext { public: // When histogramming results related to SDCH and/or an SDCH latency test, the @@ -43,11 +42,11 @@ class URLRequestJob : public base::RefCountedThreadSafe<URLRequestJob>, // congestion window on stalling of transmissions. static const size_t kSdchPacketHistogramCount = 5; - explicit URLRequestJob(URLRequest* request); + explicit URLRequestJob(net::URLRequest* request); // Returns the request that owns this job. THIS POINTER MAY BE NULL if the // request was destroyed. - URLRequest* request() const { + net::URLRequest* request() const { return request_; } @@ -218,7 +217,7 @@ class URLRequestJob : public base::RefCountedThreadSafe<URLRequestJob>, virtual void RecordPacketStats(StatisticSelector statistic) const; protected: - friend class base::RefCountedThreadSafe<URLRequestJob>; + friend class base::RefCounted<URLRequestJob>; virtual ~URLRequestJob(); // Notifies the job that headers have been received. @@ -281,7 +280,7 @@ class URLRequestJob : public base::RefCountedThreadSafe<URLRequestJob>, // The request that initiated this job. This value MAY BE NULL if the // request was released by DetachRequest(). - URLRequest* request_; + net::URLRequest* request_; // The status of the job. const URLRequestStatus GetStatus(); @@ -423,4 +422,8 @@ class URLRequestJob : public base::RefCountedThreadSafe<URLRequestJob>, DISALLOW_COPY_AND_ASSIGN(URLRequestJob); }; +} // namespace net + +typedef net::URLRequestJob URLRequestJob; + #endif // NET_URL_REQUEST_URL_REQUEST_JOB_H_ diff --git a/net/url_request/url_request_job_manager.h b/net/url_request/url_request_job_manager.h index f459f40..0fbc31e 100644 --- a/net/url_request/url_request_job_manager.h +++ b/net/url_request/url_request_job_manager.h @@ -33,18 +33,18 @@ class URLRequestJobManager { // Instantiate an URLRequestJob implementation based on the registered // interceptors and protocol factories. This will always succeed in // returning a job unless we are--in the extreme case--out of memory. - URLRequestJob* CreateJob(URLRequest* request) const; + net::URLRequestJob* CreateJob(net::URLRequest* request) const; // Allows interceptors to hijack the request after examining the new location // of a redirect. Returns NULL if no interceptor intervenes. - URLRequestJob* MaybeInterceptRedirect(URLRequest* request, - const GURL& location) const; + net::URLRequestJob* MaybeInterceptRedirect(net::URLRequest* request, + const GURL& location) const; // Allows interceptors to hijack the request after examining the response // status and headers. This is also called when there is no server response // at all to allow interception of failed requests due to network errors. // Returns NULL if no interceptor intervenes. - URLRequestJob* MaybeInterceptResponse(URLRequest* request) const; + net::URLRequestJob* MaybeInterceptResponse(net::URLRequest* request) const; // Returns true if there is a protocol factory registered for the given // scheme. Note: also returns true if there is a built-in handler for the diff --git a/net/url_request/url_request_job_tracker.h b/net/url_request/url_request_job_tracker.h index 8b554b9..cd0bd86 100644 --- a/net/url_request/url_request_job_tracker.h +++ b/net/url_request/url_request_job_tracker.h @@ -11,7 +11,10 @@ #include "base/observer_list.h" #include "net/url_request/url_request_status.h" +namespace net { class URLRequestJob; +} // namespace net + class GURL; // This class maintains a list of active URLRequestJobs for debugging purposes. @@ -23,25 +26,25 @@ class GURL; // class URLRequestJobTracker { public: - typedef std::vector<URLRequestJob*> JobList; + typedef std::vector<net::URLRequestJob*> JobList; typedef JobList::const_iterator JobIterator; // The observer's methods are called on the thread that called AddObserver. class JobObserver { public: // Called after the given job has been added to the list - virtual void OnJobAdded(URLRequestJob* job) = 0; + virtual void OnJobAdded(net::URLRequestJob* job) = 0; // Called after the given job has been removed from the list - virtual void OnJobRemoved(URLRequestJob* job) = 0; + virtual void OnJobRemoved(net::URLRequestJob* job) = 0; // Called when the given job has completed, before notifying the request - virtual void OnJobDone(URLRequestJob* job, + virtual void OnJobDone(net::URLRequestJob* job, const URLRequestStatus& status) = 0; // Called when the given job is about to follow a redirect to the given // new URL. The redirect type is given in status_code - virtual void OnJobRedirect(URLRequestJob* job, const GURL& location, + virtual void OnJobRedirect(net::URLRequestJob* job, const GURL& location, int status_code) = 0; // Called when a new chunk of unfiltered bytes has been read for @@ -49,7 +52,7 @@ class URLRequestJobTracker { // read event only. |buf| is a pointer to the data buffer that // contains those bytes. The data in |buf| is only valid for the // duration of the OnBytesRead callback. - virtual void OnBytesRead(URLRequestJob* job, const char* buf, + virtual void OnBytesRead(net::URLRequestJob* job, const char* buf, int byte_count) = 0; virtual ~JobObserver() {} @@ -70,16 +73,16 @@ class URLRequestJobTracker { // adds or removes the job from the active list, should be called by the // job constructor and destructor. Note: don't use "AddJob" since that // is #defined by windows.h :( - void AddNewJob(URLRequestJob* job); - void RemoveJob(URLRequestJob* job); + void AddNewJob(net::URLRequestJob* job); + void RemoveJob(net::URLRequestJob* job); // Job status change notifications - void OnJobDone(URLRequestJob* job, const URLRequestStatus& status); - void OnJobRedirect(URLRequestJob* job, const GURL& location, + void OnJobDone(net::URLRequestJob* job, const URLRequestStatus& status); + void OnJobRedirect(net::URLRequestJob* job, const GURL& location, int status_code); // Bytes read notifications. - void OnBytesRead(URLRequestJob* job, const char* buf, int byte_count); + void OnBytesRead(net::URLRequestJob* job, const char* buf, int byte_count); // allows iteration over all active jobs JobIterator begin() const { diff --git a/net/url_request/url_request_redirect_job.cc b/net/url_request/url_request_redirect_job.cc index d8a1a3e..001da10 100644 --- a/net/url_request/url_request_redirect_job.cc +++ b/net/url_request/url_request_redirect_job.cc @@ -6,7 +6,7 @@ #include "base/message_loop.h" -URLRequestRedirectJob::URLRequestRedirectJob(URLRequest* request, +URLRequestRedirectJob::URLRequestRedirectJob(net::URLRequest* request, GURL redirect_destination) : URLRequestJob(request), redirect_destination_(redirect_destination) { } diff --git a/net/url_request/url_request_redirect_job.h b/net/url_request/url_request_redirect_job.h index 55c34a7..7466cec 100644 --- a/net/url_request/url_request_redirect_job.h +++ b/net/url_request/url_request_redirect_job.h @@ -1,4 +1,4 @@ -// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved. +// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. @@ -16,7 +16,7 @@ class GURL; class URLRequestRedirectJob : public URLRequestJob { public: // Constructs a job that redirects to the specified URL. - URLRequestRedirectJob(URLRequest* request, GURL redirect_destination); + URLRequestRedirectJob(net::URLRequest* request, GURL redirect_destination); virtual void Start(); bool IsRedirectResponse(GURL* location, int* http_status_code); @@ -30,4 +30,3 @@ class URLRequestRedirectJob : public URLRequestJob { }; #endif // NET_URL_REQUEST_URL_REQUEST_REDIRECT_JOB_H_ - diff --git a/net/url_request/url_request_simple_job.cc b/net/url_request/url_request_simple_job.cc index 2f23d58..38e0c4d 100644 --- a/net/url_request/url_request_simple_job.cc +++ b/net/url_request/url_request_simple_job.cc @@ -9,7 +9,7 @@ #include "net/base/net_errors.h" #include "net/url_request/url_request_status.h" -URLRequestSimpleJob::URLRequestSimpleJob(URLRequest* request) +URLRequestSimpleJob::URLRequestSimpleJob(net::URLRequest* request) : URLRequestJob(request), data_offset_(0) { } diff --git a/net/url_request/url_request_simple_job.h b/net/url_request/url_request_simple_job.h index bcc4047..877b081 100644 --- a/net/url_request/url_request_simple_job.h +++ b/net/url_request/url_request_simple_job.h @@ -10,11 +10,13 @@ #include "net/url_request/url_request_job.h" +namespace net { class URLRequest; +} // namespace net -class URLRequestSimpleJob : public URLRequestJob { +class URLRequestSimpleJob : public net::URLRequestJob { public: - explicit URLRequestSimpleJob(URLRequest* request); + explicit URLRequestSimpleJob(net::URLRequest* request); virtual void Start(); virtual bool ReadRawData(net::IOBuffer* buf, int buf_size, int *bytes_read); diff --git a/net/url_request/url_request_test_job.h b/net/url_request/url_request_test_job.h index 7cb4777..14b90d4 100644 --- a/net/url_request/url_request_test_job.h +++ b/net/url_request/url_request_test_job.h @@ -33,7 +33,7 @@ // // Optionally, you can also construct test jobs that advance automatically // without having to call ProcessOnePendingMessage. -class URLRequestTestJob : public URLRequestJob { +class URLRequestTestJob : public net::URLRequestJob { public: // Constructs a job to return one of the canned responses depending on the // request url, with auto advance disabled. diff --git a/net/url_request/url_request_throttler_entry.cc b/net/url_request/url_request_throttler_entry.cc new file mode 100644 index 0000000..4abb438 --- /dev/null +++ b/net/url_request/url_request_throttler_entry.cc @@ -0,0 +1,242 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/url_request/url_request_throttler_entry.h" + +#include <cmath> + +#include "base/logging.h" +#include "base/rand_util.h" +#include "base/string_number_conversions.h" +#include "net/url_request/url_request_throttler_header_interface.h" + +namespace net { + +const int URLRequestThrottlerEntry::kDefaultSlidingWindowPeriodMs = 2000; +const int URLRequestThrottlerEntry::kDefaultMaxSendThreshold = 20; +const int URLRequestThrottlerEntry::kDefaultInitialBackoffMs = 700; +const int URLRequestThrottlerEntry::kDefaultAdditionalConstantMs = 100; +const double URLRequestThrottlerEntry::kDefaultMultiplyFactor = 1.4; +const double URLRequestThrottlerEntry::kDefaultJitterFactor = 0.4; +const int URLRequestThrottlerEntry::kDefaultMaximumBackoffMs = 60 * 60 * 1000; +const int URLRequestThrottlerEntry::kDefaultEntryLifetimeMs = 120000; +const char URLRequestThrottlerEntry::kRetryHeaderName[] = "X-Retry-After"; + +URLRequestThrottlerEntry::URLRequestThrottlerEntry() + : sliding_window_period_( + base::TimeDelta::FromMilliseconds(kDefaultSlidingWindowPeriodMs)), + max_send_threshold_(kDefaultMaxSendThreshold), + initial_backoff_ms_(kDefaultInitialBackoffMs), + additional_constant_ms_(kDefaultAdditionalConstantMs), + multiply_factor_(kDefaultMultiplyFactor), + jitter_factor_(kDefaultJitterFactor), + maximum_backoff_ms_(kDefaultMaximumBackoffMs), + entry_lifetime_ms_(kDefaultEntryLifetimeMs) { + Initialize(); +} + +URLRequestThrottlerEntry::URLRequestThrottlerEntry( + int sliding_window_period_ms, + int max_send_threshold, + int initial_backoff_ms, + int additional_constant_ms, + double multiply_factor, + double jitter_factor, + int maximum_backoff_ms) + : sliding_window_period_( + base::TimeDelta::FromMilliseconds(sliding_window_period_ms)), + max_send_threshold_(max_send_threshold), + initial_backoff_ms_(initial_backoff_ms), + additional_constant_ms_(additional_constant_ms), + multiply_factor_(multiply_factor), + jitter_factor_(jitter_factor), + maximum_backoff_ms_(maximum_backoff_ms), + entry_lifetime_ms_(-1) { + DCHECK_GT(sliding_window_period_ms, 0); + DCHECK_GT(max_send_threshold_, 0); + DCHECK_GE(initial_backoff_ms_, 0); + DCHECK_GE(additional_constant_ms_, 0); + DCHECK_GT(multiply_factor_, 0); + DCHECK_GE(jitter_factor_, 0); + DCHECK_LT(jitter_factor_, 1); + DCHECK_GE(maximum_backoff_ms_, 0); + + Initialize(); +} + +URLRequestThrottlerEntry::~URLRequestThrottlerEntry() { +} + +void URLRequestThrottlerEntry::Initialize() { + // Since this method is called by the constructors, GetTimeNow() (a virtual + // method) is not used. + exponential_backoff_release_time_ = base::TimeTicks::Now(); + failure_count_ = 0; + latest_response_was_failure_ = false; + + sliding_window_release_time_ = base::TimeTicks::Now(); +} + +bool URLRequestThrottlerEntry::IsDuringExponentialBackoff() const { + return exponential_backoff_release_time_ > GetTimeNow(); +} + +int64 URLRequestThrottlerEntry::ReserveSendingTimeForNextRequest( + const base::TimeTicks& earliest_time) { + base::TimeTicks now = GetTimeNow(); + // If a lot of requests were successfully made recently, + // sliding_window_release_time_ may be greater than + // exponential_backoff_release_time_. + base::TimeTicks recommended_sending_time = + std::max(std::max(now, earliest_time), + std::max(exponential_backoff_release_time_, + sliding_window_release_time_)); + + DCHECK(send_log_.empty() || + recommended_sending_time >= send_log_.back()); + // Log the new send event. + send_log_.push(recommended_sending_time); + + sliding_window_release_time_ = recommended_sending_time; + + // Drop the out-of-date events in the event list. + // We don't need to worry that the queue may become empty during this + // operation, since the last element is sliding_window_release_time_. + while ((send_log_.front() + sliding_window_period_ <= + sliding_window_release_time_) || + send_log_.size() > static_cast<unsigned>(max_send_threshold_)) { + send_log_.pop(); + } + + // Check if there are too many send events in recent time. + if (send_log_.size() == static_cast<unsigned>(max_send_threshold_)) + sliding_window_release_time_ = send_log_.front() + sliding_window_period_; + + return (recommended_sending_time - now).InMillisecondsRoundedUp(); +} + +base::TimeTicks + URLRequestThrottlerEntry::GetExponentialBackoffReleaseTime() const { + return exponential_backoff_release_time_; +} + +void URLRequestThrottlerEntry::UpdateWithResponse( + const URLRequestThrottlerHeaderInterface* response) { + if (response->GetResponseCode() >= 500) { + failure_count_++; + latest_response_was_failure_ = true; + exponential_backoff_release_time_ = + CalculateExponentialBackoffReleaseTime(); + } else { + // We slowly decay the number of times delayed instead of resetting it to 0 + // in order to stay stable if we received lots of requests with + // malformed bodies at the same time. + if (failure_count_ > 0) + failure_count_--; + + latest_response_was_failure_ = false; + + // The reason why we are not just cutting the release time to GetTimeNow() + // is on the one hand, it would unset delay put by our custom retry-after + // header and on the other we would like to push every request up to our + // "horizon" when dealing with multiple in-flight requests. Ex: If we send + // three requests and we receive 2 failures and 1 success. The success that + // follows those failures will not reset the release time, further requests + // will then need to wait the delay caused by the 2 failures. + exponential_backoff_release_time_ = std::max( + GetTimeNow(), exponential_backoff_release_time_); + + std::string retry_header = response->GetNormalizedValue(kRetryHeaderName); + if (!retry_header.empty()) + HandleCustomRetryAfter(retry_header); + } +} + +bool URLRequestThrottlerEntry::IsEntryOutdated() const { + if (entry_lifetime_ms_ == -1) + return false; + + base::TimeTicks now = GetTimeNow(); + + // If there are send events in the sliding window period, we still need this + // entry. + if (send_log_.size() > 0 && + send_log_.back() + sliding_window_period_ > now) { + return false; + } + + int64 unused_since_ms = + (now - exponential_backoff_release_time_).InMilliseconds(); + + // Release time is further than now, we are managing it. + if (unused_since_ms < 0) + return false; + + // latest_response_was_failure_ is true indicates that the latest one or + // more requests encountered server errors or had malformed response bodies. + // In that case, we don't want to collect the entry unless it hasn't been used + // for longer than the maximum allowed back-off. + if (latest_response_was_failure_) + return unused_since_ms > std::max(maximum_backoff_ms_, entry_lifetime_ms_); + + // Otherwise, consider the entry is outdated if it hasn't been used for the + // specified lifetime period. + return unused_since_ms > entry_lifetime_ms_; +} + +void URLRequestThrottlerEntry::ReceivedContentWasMalformed() { + // For any response that is marked as malformed now, we have probably + // considered it as a success when receiving it and decreased the failure + // count by 1. As a result, we increase the failure count by 2 here to undo + // the effect and record a failure. + // + // Please note that this may lead to a larger failure count than expected, + // because we don't decrease the failure count for successful responses when + // it has already reached 0. + failure_count_ += 2; + latest_response_was_failure_ = true; + exponential_backoff_release_time_ = CalculateExponentialBackoffReleaseTime(); +} + +base::TimeTicks + URLRequestThrottlerEntry::CalculateExponentialBackoffReleaseTime() { + double delay = initial_backoff_ms_; + delay *= pow(multiply_factor_, failure_count_); + delay += additional_constant_ms_; + delay -= base::RandDouble() * jitter_factor_ * delay; + + // Ensure that we do not exceed maximum delay. + int64 delay_int = static_cast<int64>(delay + 0.5); + delay_int = std::min(delay_int, static_cast<int64>(maximum_backoff_ms_)); + + return std::max(GetTimeNow() + base::TimeDelta::FromMilliseconds(delay_int), + exponential_backoff_release_time_); +} + +base::TimeTicks URLRequestThrottlerEntry::GetTimeNow() const { + return base::TimeTicks::Now(); +} + +void URLRequestThrottlerEntry::HandleCustomRetryAfter( + const std::string& header_value) { + // Input parameter is the number of seconds to wait in a floating point value. + double time_in_sec = 0; + bool conversion_is_ok = base::StringToDouble(header_value, &time_in_sec); + + // Conversion of custom retry-after header value failed. + if (!conversion_is_ok) + return; + + // We must use an int value later so we transform this in milliseconds. + int64 value_ms = static_cast<int64>(0.5 + time_in_sec * 1000); + + if (maximum_backoff_ms_ < value_ms || value_ms < 0) + return; + + exponential_backoff_release_time_ = std::max( + (GetTimeNow() + base::TimeDelta::FromMilliseconds(value_ms)), + exponential_backoff_release_time_); +} + +} // namespace net diff --git a/net/url_request/url_request_throttler_entry.h b/net/url_request/url_request_throttler_entry.h new file mode 100644 index 0000000..9b8955d --- /dev/null +++ b/net/url_request/url_request_throttler_entry.h @@ -0,0 +1,157 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_URL_REQUEST_URL_REQUEST_THROTTLER_ENTRY_H_ +#define NET_URL_REQUEST_URL_REQUEST_THROTTLER_ENTRY_H_ + +#include <queue> +#include <string> + +#include "base/basictypes.h" +#include "net/url_request/url_request_throttler_entry_interface.h" + +namespace net { + +// URLRequestThrottlerEntry represents an entry of URLRequestThrottlerManager. +// It analyzes requests of a specific URL over some period of time, in order to +// deduce the back-off time for every request. +// The back-off algorithm consists of two parts. Firstly, exponential back-off +// is used when receiving 5XX server errors or malformed response bodies. +// The exponential back-off rule is enforced by URLRequestHttpJob. Any request +// sent during the back-off period will be cancelled. +// Secondly, a sliding window is used to count recent requests to a given +// destination and provide guidance (to the application level only) on whether +// too many requests have been sent and when a good time to send the next one +// would be. This is never used to deny requests at the network level. +class URLRequestThrottlerEntry : public URLRequestThrottlerEntryInterface { + public: + // Sliding window period. + static const int kDefaultSlidingWindowPeriodMs; + + // Maximum number of requests allowed in sliding window period. + static const int kDefaultMaxSendThreshold; + + // Initial delay for exponential back-off. + static const int kDefaultInitialBackoffMs; + + // Additional constant to adjust back-off. + static const int kDefaultAdditionalConstantMs; + + // Factor by which the waiting time will be multiplied. + static const double kDefaultMultiplyFactor; + + // Fuzzing percentage. ex: 10% will spread requests randomly + // between 90%-100% of the calculated time. + static const double kDefaultJitterFactor; + + // Maximum amount of time we are willing to delay our request. + static const int kDefaultMaximumBackoffMs; + + // Time after which the entry is considered outdated. + static const int kDefaultEntryLifetimeMs; + + // Name of the header that servers can use to ask clients to delay their next + // request. + static const char kRetryHeaderName[]; + + URLRequestThrottlerEntry(); + + // The life span of instances created with this constructor is set to + // infinite. + // It is only used by unit tests. + URLRequestThrottlerEntry(int sliding_window_period_ms, + int max_send_threshold, + int initial_backoff_ms, + int additional_constant_ms, + double multiply_factor, + double jitter_factor, + int maximum_backoff_ms); + + // Implementation of URLRequestThrottlerEntryInterface. + virtual bool IsDuringExponentialBackoff() const; + virtual int64 ReserveSendingTimeForNextRequest( + const base::TimeTicks& earliest_time); + virtual base::TimeTicks GetExponentialBackoffReleaseTime() const; + virtual void UpdateWithResponse( + const URLRequestThrottlerHeaderInterface* response); + virtual void ReceivedContentWasMalformed(); + + // Used by the manager, returns true if the entry needs to be garbage + // collected. + bool IsEntryOutdated() const; + + protected: + virtual ~URLRequestThrottlerEntry(); + + void Initialize(); + + // Calculates the release time for exponential back-off. + base::TimeTicks CalculateExponentialBackoffReleaseTime(); + + // Equivalent to TimeTicks::Now(), virtual to be mockable for testing purpose. + virtual base::TimeTicks GetTimeNow() const; + + // Used internally to increase release time following a retry-after header. + void HandleCustomRetryAfter(const std::string& header_value); + + // Used by tests. + void set_exponential_backoff_release_time( + const base::TimeTicks& release_time) { + exponential_backoff_release_time_ = release_time; + } + + // Used by tests. + base::TimeTicks sliding_window_release_time() const { + return sliding_window_release_time_; + } + + // Used by tests. + void set_sliding_window_release_time(const base::TimeTicks& release_time) { + sliding_window_release_time_ = release_time; + } + + // Used by tests. + void set_failure_count(int failure_count) { + failure_count_ = failure_count; + } + + private: + // Timestamp calculated by the exponential back-off algorithm at which we are + // allowed to start sending requests again. + base::TimeTicks exponential_backoff_release_time_; + + // Number of times we encounter server errors or malformed response bodies. + int failure_count_; + + // If true, the last request response was a failure. + // Note that this member can be false at the same time as failure_count_ can + // be greater than 0, since we gradually decrease failure_count_, instead of + // resetting it to 0 directly, when we receive successful responses. + bool latest_response_was_failure_; + + // Timestamp calculated by the sliding window algorithm for when we advise + // clients the next request should be made, at the earliest. Advisory only, + // not used to deny requests. + base::TimeTicks sliding_window_release_time_; + + // A list of the recent send events. We use them to decide whether there are + // too many requests sent in sliding window. + std::queue<base::TimeTicks> send_log_; + + const base::TimeDelta sliding_window_period_; + const int max_send_threshold_; + const int initial_backoff_ms_; + const int additional_constant_ms_; + const double multiply_factor_; + const double jitter_factor_; + const int maximum_backoff_ms_; + // Set to -1 if the entry never expires. + const int entry_lifetime_ms_; + + DISALLOW_COPY_AND_ASSIGN(URLRequestThrottlerEntry); +}; + +} // namespace net + +#endif // NET_URL_REQUEST_URL_REQUEST_THROTTLER_ENTRY_H_ diff --git a/net/url_request/url_request_throttler_entry_interface.h b/net/url_request/url_request_throttler_entry_interface.h new file mode 100644 index 0000000..f443b29 --- /dev/null +++ b/net/url_request/url_request_throttler_entry_interface.h @@ -0,0 +1,64 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_URL_REQUEST_URL_REQUEST_THROTTLER_ENTRY_INTERFACE_H_ +#define NET_URL_REQUEST_URL_REQUEST_THROTTLER_ENTRY_INTERFACE_H_ + +#include "base/basictypes.h" +#include "base/ref_counted.h" +#include "base/time.h" + +namespace net { + +class URLRequestThrottlerHeaderInterface; + +// Interface provided on entries of the URL request throttler manager. +class URLRequestThrottlerEntryInterface + : public base::RefCountedThreadSafe<URLRequestThrottlerEntryInterface> { + public: + URLRequestThrottlerEntryInterface() {} + + // Returns true when we have encountered server errors and are doing + // exponential back-off. + // URLRequestHttpJob checks this method prior to every request; it cancels + // requests if this method returns true. + virtual bool IsDuringExponentialBackoff() const = 0; + + // Calculates a recommended sending time for the next request and reserves it. + // The sending time is not earlier than the current exponential back-off + // release time or |earliest_time|. Moreover, the previous results of + // the method are taken into account, in order to make sure they are spread + // properly over time. + // Returns the recommended delay before sending the next request, in + // milliseconds. The return value is always positive or 0. + // Although it is not mandatory, respecting the value returned by this method + // is helpful to avoid traffic overload. + virtual int64 ReserveSendingTimeForNextRequest( + const base::TimeTicks& earliest_time) = 0; + + // Returns the time after which requests are allowed. + virtual base::TimeTicks GetExponentialBackoffReleaseTime() const = 0; + + // This method needs to be called each time a response is received. + virtual void UpdateWithResponse( + const URLRequestThrottlerHeaderInterface* response) = 0; + + // Lets higher-level modules, that know how to parse particular response + // bodies, notify of receiving malformed content for the given URL. This will + // be handled by the throttler as if an HTTP 5xx response had been received to + // the request, i.e. it will count as a failure. + virtual void ReceivedContentWasMalformed() = 0; + + protected: + friend class base::RefCountedThreadSafe<URLRequestThrottlerEntryInterface>; + virtual ~URLRequestThrottlerEntryInterface() {} + + private: + friend class base::RefCounted<URLRequestThrottlerEntryInterface>; + DISALLOW_COPY_AND_ASSIGN(URLRequestThrottlerEntryInterface); +}; + +} // namespace net + +#endif // NET_URL_REQUEST_URL_REQUEST_THROTTLER_ENTRY_INTERFACE_H_ diff --git a/net/url_request/url_request_throttler_header_adapter.cc b/net/url_request/url_request_throttler_header_adapter.cc new file mode 100644 index 0000000..e453071 --- /dev/null +++ b/net/url_request/url_request_throttler_header_adapter.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/url_request/url_request_throttler_header_adapter.h" + +#include "net/http/http_response_headers.h" + +namespace net { + +URLRequestThrottlerHeaderAdapter::URLRequestThrottlerHeaderAdapter( + net::HttpResponseHeaders* headers) + : response_header_(headers) { +} + +std::string URLRequestThrottlerHeaderAdapter::GetNormalizedValue( + const std::string& key) const { + std::string return_value; + response_header_->GetNormalizedHeader(key, &return_value); + return return_value; +} + +int URLRequestThrottlerHeaderAdapter::GetResponseCode() const { + return response_header_->response_code(); +} + +} // namespace net diff --git a/net/url_request/url_request_throttler_header_adapter.h b/net/url_request/url_request_throttler_header_adapter.h new file mode 100644 index 0000000..599a9f6 --- /dev/null +++ b/net/url_request/url_request_throttler_header_adapter.h @@ -0,0 +1,34 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_URL_REQUEST_URL_REQUEST_THROTTLER_HEADER_ADAPTER_H_ +#define NET_URL_REQUEST_URL_REQUEST_THROTTLER_HEADER_ADAPTER_H_ + +#include <string> + +#include "base/ref_counted.h" +#include "net/url_request/url_request_throttler_header_interface.h" + +namespace net { + +class HttpResponseHeaders; + +// Adapter for the HTTP header interface of the URL request throttler component. +class URLRequestThrottlerHeaderAdapter + : public URLRequestThrottlerHeaderInterface { + public: + explicit URLRequestThrottlerHeaderAdapter(net::HttpResponseHeaders* headers); + virtual ~URLRequestThrottlerHeaderAdapter() {} + + // Implementation of URLRequestThrottlerHeaderInterface + virtual std::string GetNormalizedValue(const std::string& key) const; + virtual int GetResponseCode() const; + + private: + const scoped_refptr<net::HttpResponseHeaders> response_header_; +}; + +} // namespace net + +#endif // NET_URL_REQUEST_URL_REQUEST_THROTTLER_HEADER_ADAPTER_H_ diff --git a/net/url_request/url_request_throttler_header_interface.h b/net/url_request/url_request_throttler_header_interface.h new file mode 100644 index 0000000..c69d185 --- /dev/null +++ b/net/url_request/url_request_throttler_header_interface.h @@ -0,0 +1,28 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_URL_REQUEST_URL_REQUEST_THROTTLER_HEADER_INTERFACE_H_ +#define NET_URL_REQUEST_URL_REQUEST_THROTTLER_HEADER_INTERFACE_H_ + +#include <string> + +namespace net { + +// Interface to an HTTP header to enforce we have the methods we need. +class URLRequestThrottlerHeaderInterface { + public: + virtual ~URLRequestThrottlerHeaderInterface() {} + + // Method that enables us to fetch the header value by its key. + // ex: location: www.example.com -> key = "location" value = "www.example.com" + // If the key does not exist, it returns an empty string. + virtual std::string GetNormalizedValue(const std::string& key) const = 0; + + // Returns the HTTP response code associated with the request. + virtual int GetResponseCode() const = 0; +}; + +} // namespace net + +#endif // NET_URL_REQUEST_URL_REQUEST_THROTTLER_HEADER_INTERFACE_H_ diff --git a/net/url_request/url_request_throttler_manager.cc b/net/url_request/url_request_throttler_manager.cc new file mode 100644 index 0000000..5428d9a --- /dev/null +++ b/net/url_request/url_request_throttler_manager.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/url_request/url_request_throttler_manager.h" + +#include "base/string_util.h" + +namespace net { + +const unsigned int URLRequestThrottlerManager::kMaximumNumberOfEntries = 1500; +const unsigned int URLRequestThrottlerManager::kRequestsBetweenCollecting = 200; + +URLRequestThrottlerManager* URLRequestThrottlerManager::GetInstance() { + return Singleton<URLRequestThrottlerManager>::get(); +} + +scoped_refptr<URLRequestThrottlerEntryInterface> + URLRequestThrottlerManager::RegisterRequestUrl(const GURL &url) { + // Normalize the url. + std::string url_id = GetIdFromUrl(url); + + // Periodically garbage collect old entries. + GarbageCollectEntriesIfNecessary(); + + // Find the entry in the map or create it. + scoped_refptr<URLRequestThrottlerEntry>& entry = url_entries_[url_id]; + if (entry == NULL) + entry = new URLRequestThrottlerEntry(); + + return entry; +} + +URLRequestThrottlerManager::URLRequestThrottlerManager() + : requests_since_last_gc_(0) { +} + +URLRequestThrottlerManager::~URLRequestThrottlerManager() { + // Delete all entries. + url_entries_.clear(); +} + +std::string URLRequestThrottlerManager::GetIdFromUrl(const GURL& url) const { + if (!url.is_valid()) + return url.possibly_invalid_spec(); + + if (url_id_replacements_ == NULL) { + url_id_replacements_.reset(new GURL::Replacements()); + + url_id_replacements_->ClearPassword(); + url_id_replacements_->ClearUsername(); + url_id_replacements_->ClearQuery(); + url_id_replacements_->ClearRef(); + } + + GURL id = url.ReplaceComponents(*url_id_replacements_); + return StringToLowerASCII(id.spec()); +} + +void URLRequestThrottlerManager::GarbageCollectEntries() { + UrlEntryMap::iterator i = url_entries_.begin(); + + while (i != url_entries_.end()) { + if ((i->second)->IsEntryOutdated()) { + url_entries_.erase(i++); + } else { + ++i; + } + } + + // In case something broke we want to make sure not to grow indefinitely. + while (url_entries_.size() > kMaximumNumberOfEntries) { + url_entries_.erase(url_entries_.begin()); + } +} + +void URLRequestThrottlerManager::GarbageCollectEntriesIfNecessary() { + requests_since_last_gc_++; + if (requests_since_last_gc_ < kRequestsBetweenCollecting) + return; + + requests_since_last_gc_ = 0; + GarbageCollectEntries(); +} + +void URLRequestThrottlerManager::OverrideEntryForTests( + const GURL& url, + URLRequestThrottlerEntry* entry) { + if (entry == NULL) + return; + + // Normalize the url. + std::string url_id = GetIdFromUrl(url); + + // Periodically garbage collect old entries. + GarbageCollectEntriesIfNecessary(); + + url_entries_[url_id] = entry; +} + +void URLRequestThrottlerManager::EraseEntryForTests(const GURL& url) { + // Normalize the url. + std::string url_id = GetIdFromUrl(url); + url_entries_.erase(url_id); +} + +} // namespace net diff --git a/net/url_request/url_request_throttler_manager.h b/net/url_request/url_request_throttler_manager.h new file mode 100644 index 0000000..6c8cd2f --- /dev/null +++ b/net/url_request/url_request_throttler_manager.h @@ -0,0 +1,101 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_URL_REQUEST_URL_REQUEST_THROTTLER_MANAGER_H_ +#define NET_URL_REQUEST_URL_REQUEST_THROTTLER_MANAGER_H_ + +#include <map> +#include <string> + +#include "base/basictypes.h" +#include "base/scoped_ptr.h" +#include "base/singleton.h" +#include "googleurl/src/gurl.h" +#include "net/url_request/url_request_throttler_entry.h" + +namespace net { + +// Class that registers URL request throttler entries for URLs being accessed in +// order to supervise traffic. URL requests for HTTP contents should register +// their URLs in this manager on each request. +// URLRequestThrottlerManager maintains a map of URL IDs to URL request +// throttler entries. It creates URL request throttler entries when new URLs are +// registered, and does garbage collection from time to time in order to clean +// out outdated entries. URL ID consists of lowercased scheme, host, port and +// path. All URLs converted to the same ID will share the same entry. +// +// NOTE: All usage of the singleton object of this class should be on the same +// thread. +class URLRequestThrottlerManager { + public: + static URLRequestThrottlerManager* GetInstance(); + + // Must be called for every request, returns the URL request throttler entry + // associated with the URL. The caller must inform this entry of some events. + // Please refer to url_request_throttler_entry_interface.h for further + // informations. + scoped_refptr<URLRequestThrottlerEntryInterface> RegisterRequestUrl( + const GURL& url); + + // Registers a new entry in this service and overrides the existing entry (if + // any) for the URL. The service will hold a reference to the entry. + // It is only used by unit tests. + void OverrideEntryForTests(const GURL& url, URLRequestThrottlerEntry* entry); + + // Explicitly erases an entry. + // This is useful to remove those entries which have got infinite lifetime and + // thus won't be garbage collected. + // It is only used by unit tests. + void EraseEntryForTests(const GURL& url); + + protected: + URLRequestThrottlerManager(); + ~URLRequestThrottlerManager(); + + // Method that allows us to transform a URL into an ID that can be used in our + // map. Resulting IDs will be lowercase and consist of the scheme, host, port + // and path (without query string, fragment, etc.). + // If the URL is invalid, the invalid spec will be returned, without any + // transformation. + std::string GetIdFromUrl(const GURL& url) const; + + // Method that ensures the map gets cleaned from time to time. The period at + // which garbage collecting happens is adjustable with the + // kRequestBetweenCollecting constant. + void GarbageCollectEntriesIfNecessary(); + // Method that does the actual work of garbage collecting. + void GarbageCollectEntries(); + + // Used by tests. + int GetNumberOfEntriesForTests() const { return url_entries_.size(); } + + private: + friend struct DefaultSingletonTraits<URLRequestThrottlerManager>; + + // From each URL we generate an ID composed of the scheme, host, port and path + // that allows us to uniquely map an entry to it. + typedef std::map<std::string, scoped_refptr<URLRequestThrottlerEntry> > + UrlEntryMap; + + // Maximum number of entries that we are willing to collect in our map. + static const unsigned int kMaximumNumberOfEntries; + // Number of requests that will be made between garbage collection. + static const unsigned int kRequestsBetweenCollecting; + + // Map that contains a list of URL ID and their matching + // URLRequestThrottlerEntry. + UrlEntryMap url_entries_; + + // This keeps track of how many requests have been made. Used with + // GarbageCollectEntries. + unsigned int requests_since_last_gc_; + + mutable scoped_ptr<GURL::Replacements> url_id_replacements_; + + DISALLOW_COPY_AND_ASSIGN(URLRequestThrottlerManager); +}; + +} // namespace net + +#endif // NET_URL_REQUEST_URL_REQUEST_THROTTLER_MANAGER_H_ diff --git a/net/url_request/url_request_throttler_unittest.cc b/net/url_request/url_request_throttler_unittest.cc new file mode 100644 index 0000000..0683f91 --- /dev/null +++ b/net/url_request/url_request_throttler_unittest.cc @@ -0,0 +1,346 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/pickle.h" +#include "base/scoped_ptr.h" +#include "base/string_number_conversions.h" +#include "base/time.h" +#include "net/base/test_completion_callback.h" +#include "net/url_request/url_request_context.h" +#include "net/url_request/url_request_throttler_header_interface.h" +#include "net/url_request/url_request_throttler_manager.h" +#include "testing/gtest/include/gtest/gtest.h" + +using base::TimeDelta; +using base::TimeTicks; + +namespace { +class MockURLRequestThrottlerManager; + +class MockURLRequestThrottlerEntry : public net::URLRequestThrottlerEntry { + public : + MockURLRequestThrottlerEntry() {} + MockURLRequestThrottlerEntry( + const TimeTicks& exponential_backoff_release_time, + const TimeTicks& sliding_window_release_time, + const TimeTicks& fake_now) + : fake_time_now_(fake_now) { + set_exponential_backoff_release_time(exponential_backoff_release_time); + set_sliding_window_release_time(sliding_window_release_time); + } + virtual ~MockURLRequestThrottlerEntry() {} + + void ResetToBlank(const TimeTicks& time_now) { + fake_time_now_ = time_now; + set_exponential_backoff_release_time(time_now); + set_failure_count(0); + set_sliding_window_release_time(time_now); + } + + // Overridden for tests. + virtual TimeTicks GetTimeNow() const { return fake_time_now_; } + + void set_exponential_backoff_release_time( + const base::TimeTicks& release_time) { + net::URLRequestThrottlerEntry::set_exponential_backoff_release_time( + release_time); + } + + base::TimeTicks sliding_window_release_time() const { + return net::URLRequestThrottlerEntry::sliding_window_release_time(); + } + + void set_sliding_window_release_time( + const base::TimeTicks& release_time) { + net::URLRequestThrottlerEntry::set_sliding_window_release_time( + release_time); + } + + TimeTicks fake_time_now_; +}; + +class MockURLRequestThrottlerHeaderAdapter + : public net::URLRequestThrottlerHeaderInterface { + public: + MockURLRequestThrottlerHeaderAdapter() + : fake_retry_value_("0.0"), + fake_response_code_(0) { + } + + MockURLRequestThrottlerHeaderAdapter(const std::string& retry_value, + int response_code) + : fake_retry_value_(retry_value), + fake_response_code_(response_code) { + } + + virtual ~MockURLRequestThrottlerHeaderAdapter() {} + + virtual std::string GetNormalizedValue(const std::string& key) const { + if (key == MockURLRequestThrottlerEntry::kRetryHeaderName) + return fake_retry_value_; + return ""; + } + + virtual int GetResponseCode() const { return fake_response_code_; } + + std::string fake_retry_value_; + int fake_response_code_; +}; + +class MockURLRequestThrottlerManager : public net::URLRequestThrottlerManager { + public: + MockURLRequestThrottlerManager() : create_entry_index_(0) {} + + // Method to process the URL using URLRequestThrottlerManager protected + // method. + std::string DoGetUrlIdFromUrl(const GURL& url) { return GetIdFromUrl(url); } + + // Method to use the garbage collecting method of URLRequestThrottlerManager. + void DoGarbageCollectEntries() { GarbageCollectEntries(); } + + // Returns the number of entries in the map. + int GetNumberOfEntries() const { return GetNumberOfEntriesForTests(); } + + void CreateEntry(bool is_outdated) { + TimeTicks time = TimeTicks::Now(); + if (is_outdated) { + time -= TimeDelta::FromMilliseconds( + MockURLRequestThrottlerEntry::kDefaultEntryLifetimeMs + 1000); + } + std::string fake_url_string("http://www.fakeurl.com/"); + fake_url_string.append(base::IntToString(create_entry_index_++)); + GURL fake_url(fake_url_string); + OverrideEntryForTests( + fake_url, + new MockURLRequestThrottlerEntry(time, TimeTicks::Now(), + TimeTicks::Now())); + } + + private: + int create_entry_index_; +}; + +struct TimeAndBool { + TimeAndBool(const TimeTicks& time_value, bool expected, int line_num) { + time = time_value; + result = expected; + line = line_num; + } + TimeTicks time; + bool result; + int line; +}; + +struct GurlAndString { + GurlAndString(const GURL& url_value, + const std::string& expected, + int line_num) { + url = url_value; + result = expected; + line = line_num; + } + GURL url; + std::string result; + int line; +}; + +} // namespace + +class URLRequestThrottlerEntryTest : public testing::Test { + protected: + virtual void SetUp(); + TimeTicks now_; + scoped_refptr<MockURLRequestThrottlerEntry> entry_; +}; + +void URLRequestThrottlerEntryTest::SetUp() { + now_ = TimeTicks::Now(); + entry_ = new MockURLRequestThrottlerEntry(); + entry_->ResetToBlank(now_); +} + +std::ostream& operator<<(std::ostream& out, const base::TimeTicks& time) { + return out << time.ToInternalValue(); +} + +TEST_F(URLRequestThrottlerEntryTest, InterfaceDuringExponentialBackoff) { + entry_->set_exponential_backoff_release_time( + entry_->fake_time_now_ + TimeDelta::FromMilliseconds(1)); + EXPECT_TRUE(entry_->IsDuringExponentialBackoff()); +} + +TEST_F(URLRequestThrottlerEntryTest, InterfaceNotDuringExponentialBackoff) { + entry_->set_exponential_backoff_release_time(entry_->fake_time_now_); + EXPECT_FALSE(entry_->IsDuringExponentialBackoff()); + entry_->set_exponential_backoff_release_time( + entry_->fake_time_now_ - TimeDelta::FromMilliseconds(1)); + EXPECT_FALSE(entry_->IsDuringExponentialBackoff()); +} + +TEST_F(URLRequestThrottlerEntryTest, InterfaceUpdateRetryAfter) { + // If the response we received has a retry-after field, + // the request should be delayed. + MockURLRequestThrottlerHeaderAdapter header_w_delay_header("5.5", 200); + entry_->UpdateWithResponse(&header_w_delay_header); + EXPECT_GT(entry_->GetExponentialBackoffReleaseTime(), entry_->fake_time_now_) + << "When the server put a positive value in retry-after we should " + "increase release_time"; + + entry_->ResetToBlank(now_); + header_w_delay_header.fake_retry_value_ = "-5.5"; + EXPECT_EQ(entry_->GetExponentialBackoffReleaseTime(), entry_->fake_time_now_) + << "When given a negative value, it should not change the release_time"; +} + +TEST_F(URLRequestThrottlerEntryTest, InterfaceUpdateFailure) { + MockURLRequestThrottlerHeaderAdapter failure_response("0", 505); + entry_->UpdateWithResponse(&failure_response); + EXPECT_GT(entry_->GetExponentialBackoffReleaseTime(), entry_->fake_time_now_) + << "A failure should increase the release_time"; +} + +TEST_F(URLRequestThrottlerEntryTest, InterfaceUpdateSuccess) { + MockURLRequestThrottlerHeaderAdapter success_response("0", 200); + entry_->UpdateWithResponse(&success_response); + EXPECT_EQ(entry_->GetExponentialBackoffReleaseTime(), entry_->fake_time_now_) + << "A success should not add any delay"; +} + +TEST_F(URLRequestThrottlerEntryTest, InterfaceUpdateSuccessThenFailure) { + MockURLRequestThrottlerHeaderAdapter failure_response("0", 500); + MockURLRequestThrottlerHeaderAdapter success_response("0", 200); + entry_->UpdateWithResponse(&success_response); + entry_->UpdateWithResponse(&failure_response); + EXPECT_GT(entry_->GetExponentialBackoffReleaseTime(), entry_->fake_time_now_) + << "This scenario should add delay"; +} + +TEST_F(URLRequestThrottlerEntryTest, IsEntryReallyOutdated) { + TimeDelta lifetime = TimeDelta::FromMilliseconds( + MockURLRequestThrottlerEntry::kDefaultEntryLifetimeMs); + const TimeDelta kFiveMs = TimeDelta::FromMilliseconds(5); + + TimeAndBool test_values[] = { + TimeAndBool(now_, false, __LINE__), + TimeAndBool(now_ - kFiveMs, false, __LINE__), + TimeAndBool(now_ + kFiveMs, false, __LINE__), + TimeAndBool(now_ - lifetime, false, __LINE__), + TimeAndBool(now_ - (lifetime + kFiveMs), true, __LINE__)}; + + for (unsigned int i = 0; i < arraysize(test_values); ++i) { + entry_->set_exponential_backoff_release_time(test_values[i].time); + EXPECT_EQ(entry_->IsEntryOutdated(), test_values[i].result) << + "Test case #" << i << " line " << test_values[i].line << " failed"; + } +} + +TEST_F(URLRequestThrottlerEntryTest, MaxAllowedBackoff) { + for (int i = 0; i < 30; ++i) { + MockURLRequestThrottlerHeaderAdapter response_adapter("0.0", 505); + entry_->UpdateWithResponse(&response_adapter); + } + + TimeDelta delay = entry_->GetExponentialBackoffReleaseTime() - now_; + EXPECT_EQ(delay.InMilliseconds(), + MockURLRequestThrottlerEntry::kDefaultMaximumBackoffMs); +} + +TEST_F(URLRequestThrottlerEntryTest, MalformedContent) { + MockURLRequestThrottlerHeaderAdapter response_adapter("0.0", 505); + for (int i = 0; i < 5; ++i) + entry_->UpdateWithResponse(&response_adapter); + + TimeTicks release_after_failures = entry_->GetExponentialBackoffReleaseTime(); + + // Inform the entry that a response body was malformed, which is supposed to + // increase the back-off time. + entry_->ReceivedContentWasMalformed(); + EXPECT_GT(entry_->GetExponentialBackoffReleaseTime(), release_after_failures); +} + +TEST_F(URLRequestThrottlerEntryTest, SlidingWindow) { + int max_send = net::URLRequestThrottlerEntry::kDefaultMaxSendThreshold; + int sliding_window = + net::URLRequestThrottlerEntry::kDefaultSlidingWindowPeriodMs; + + TimeTicks time_1 = entry_->fake_time_now_ + + TimeDelta::FromMilliseconds(sliding_window / 3); + TimeTicks time_2 = entry_->fake_time_now_ + + TimeDelta::FromMilliseconds(2 * sliding_window / 3); + TimeTicks time_3 = entry_->fake_time_now_ + + TimeDelta::FromMilliseconds(sliding_window); + TimeTicks time_4 = entry_->fake_time_now_ + + TimeDelta::FromMilliseconds(sliding_window + 2 * sliding_window / 3); + + entry_->set_exponential_backoff_release_time(time_1); + + for (int i = 0; i < max_send / 2; ++i) { + EXPECT_EQ(2 * sliding_window / 3, + entry_->ReserveSendingTimeForNextRequest(time_2)); + } + EXPECT_EQ(time_2, entry_->sliding_window_release_time()); + + entry_->fake_time_now_ = time_3; + + for (int i = 0; i < (max_send + 1) / 2; ++i) + EXPECT_EQ(0, entry_->ReserveSendingTimeForNextRequest(TimeTicks())); + + EXPECT_EQ(time_4, entry_->sliding_window_release_time()); +} + +TEST(URLRequestThrottlerManager, IsUrlStandardised) { + MockURLRequestThrottlerManager manager; + GurlAndString test_values[] = { + GurlAndString(GURL("http://www.example.com"), + std::string("http://www.example.com/"), __LINE__), + GurlAndString(GURL("http://www.Example.com"), + std::string("http://www.example.com/"), __LINE__), + GurlAndString(GURL("http://www.ex4mple.com/Pr4c71c41"), + std::string("http://www.ex4mple.com/pr4c71c41"), __LINE__), + GurlAndString(GURL("http://www.example.com/0/token/false"), + std::string("http://www.example.com/0/token/false"), + __LINE__), + GurlAndString(GURL("http://www.example.com/index.php?code=javascript"), + std::string("http://www.example.com/index.php"), __LINE__), + GurlAndString(GURL("http://www.example.com/index.php?code=1#superEntry"), + std::string("http://www.example.com/index.php"), + __LINE__), + GurlAndString(GURL("http://www.example.com:1234/"), + std::string("http://www.example.com:1234/"), __LINE__)}; + + for (unsigned int i = 0; i < arraysize(test_values); ++i) { + std::string temp = manager.DoGetUrlIdFromUrl(test_values[i].url); + EXPECT_EQ(temp, test_values[i].result) << + "Test case #" << i << " line " << test_values[i].line << " failed"; + } +} + +TEST(URLRequestThrottlerManager, AreEntriesBeingCollected) { + MockURLRequestThrottlerManager manager; + + manager.CreateEntry(true); // true = Entry is outdated. + manager.CreateEntry(true); + manager.CreateEntry(true); + manager.DoGarbageCollectEntries(); + EXPECT_EQ(0, manager.GetNumberOfEntries()); + + manager.CreateEntry(false); + manager.CreateEntry(false); + manager.CreateEntry(false); + manager.CreateEntry(true); + manager.DoGarbageCollectEntries(); + EXPECT_EQ(3, manager.GetNumberOfEntries()); +} + +TEST(URLRequestThrottlerManager, IsHostBeingRegistered) { + MockURLRequestThrottlerManager manager; + + manager.RegisterRequestUrl(GURL("http://www.example.com/")); + manager.RegisterRequestUrl(GURL("http://www.google.com/")); + manager.RegisterRequestUrl(GURL("http://www.google.com/index/0")); + manager.RegisterRequestUrl(GURL("http://www.google.com/index/0?code=1")); + manager.RegisterRequestUrl(GURL("http://www.google.com/index/0#lolsaure")); + + EXPECT_EQ(3, manager.GetNumberOfEntries()); +} diff --git a/net/url_request/url_request_unittest.h b/net/url_request/url_request_unittest.h index abb6ab5..af8f49e 100644 --- a/net/url_request/url_request_unittest.h +++ b/net/url_request/url_request_unittest.h @@ -162,6 +162,7 @@ class TestURLRequestContext : public URLRequestContext { http_transaction_factory_ = new net::HttpCache( net::HttpNetworkLayer::CreateFactory(host_resolver_, NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, proxy_service_, ssl_config_service_, @@ -318,6 +319,7 @@ class TestDelegate : public URLRequest::Delegate { virtual void OnSetCookie(URLRequest* request, const std::string& cookie_line, + const net::CookieOptions& options, bool blocked_by_policy) { if (blocked_by_policy) { blocked_set_cookie_count_++; |