diff options
author | wtc@google.com <wtc@google.com@0039d316-1c4b-4281-b951-d872f2087c98> | 2010-12-16 17:27:15 +0000 |
---|---|---|
committer | wtc@google.com <wtc@google.com@0039d316-1c4b-4281-b951-d872f2087c98> | 2010-12-16 17:27:15 +0000 |
commit | 822581d32a6836feae73b96a2ce494a058004423 (patch) | |
tree | 925796acd3c3aeaa357378c096c5d9efec31bf36 /net | |
parent | ae89b8d559bfa6b3a2c1d404b21386bcc8995472 (diff) | |
download | chromium_src-822581d32a6836feae73b96a2ce494a058004423.zip chromium_src-822581d32a6836feae73b96a2ce494a058004423.tar.gz chromium_src-822581d32a6836feae73b96a2ce494a058004423.tar.bz2 |
Cache certificate verification results in memory.
R=agl
BUG=63357
TEST=none
Review URL: http://codereview.chromium.org/5386001
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@69414 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
51 files changed, 1091 insertions, 198 deletions
diff --git a/net/base/cert_verifier.cc b/net/base/cert_verifier.cc index ae910b4..4b3d904 100644 --- a/net/base/cert_verifier.cc +++ b/net/base/cert_verifier.cc @@ -1,45 +1,158 @@ -// Copyright (c) 2008 The Chromium Authors. All rights reserved. +// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/base/cert_verifier.h" -#if defined(USE_NSS) -#include <private/pprthred.h> // PR_DetatchThread -#endif - +#include "base/compiler_specific.h" #include "base/lock.h" -#include "base/message_loop_proxy.h" -#include "base/scoped_ptr.h" +#include "base/message_loop.h" +#include "base/stl_util-inl.h" #include "base/worker_pool.h" -#include "net/base/cert_verify_result.h" #include "net/base/net_errors.h" #include "net/base/x509_certificate.h" +#if defined(USE_NSS) +#include <private/pprthred.h> // PR_DetachThread +#endif + namespace net { -class CertVerifier::Request : - public base::RefCountedThreadSafe<CertVerifier::Request> { +//////////////////////////////////////////////////////////////////////////// + +// Life of a request: +// +// CertVerifier CertVerifierJob CertVerifierWorker Request +// | (origin loop) (worker loop) +// | +// Verify() +// |---->-------------------<creates> +// | +// |---->----<creates> +// | +// |---->---------------------------------------------------<creates> +// | +// |---->--------------------Start +// | | +// | PostTask +// | +// | <starts verifying> +// |---->-----AddRequest | +// | +// | +// | +// Finish +// | +// PostTask +// +// | +// DoReply +// |----<-----------------------| +// HandleResult +// | +// |---->-----HandleResult +// | +// |------>-----------------------------------Post +// +// +// +// On a cache hit, CertVerifier::Verify() returns synchronously without +// posting a task to a worker thread. + +// The number of CachedCertVerifyResult objects that we'll cache. +static const unsigned kMaxCacheEntries = 256; + +// The number of seconds for which we'll cache a cache entry. +static const unsigned kTTLSecs = 1800; // 30 minutes. + +namespace { + +class DefaultTimeService : public CertVerifier::TimeService { + public: + // CertVerifier::TimeService methods: + virtual base::Time Now() { return base::Time::Now(); } +}; + +} // namespace + +CachedCertVerifyResult::CachedCertVerifyResult() : error(ERR_FAILED) { +} + +CachedCertVerifyResult::~CachedCertVerifyResult() {} + +bool CachedCertVerifyResult::HasExpired(const base::Time current_time) const { + return current_time >= expiry; +} + +// Represents the output and result callback of a request. +class CertVerifierRequest { public: - Request(CertVerifier* verifier, - X509Certificate* cert, - const std::string& hostname, - int flags, - CertVerifyResult* verify_result, - CompletionCallback* callback) + CertVerifierRequest(CompletionCallback* callback, + CertVerifyResult* verify_result) + : callback_(callback), + verify_result_(verify_result) { + } + + // Ensures that the result callback will never be made. + void Cancel() { + callback_ = NULL; + verify_result_ = NULL; + } + + // Copies the contents of |verify_result| to the caller's + // CertVerifyResult and calls the callback. + void Post(const CachedCertVerifyResult& verify_result) { + if (callback_) { + *verify_result_ = verify_result.result; + callback_->Run(verify_result.error); + } + delete this; + } + + private: + CompletionCallback* callback_; + CertVerifyResult* verify_result_; +}; + + +// CertVerifierWorker runs on a worker thread and takes care of the blocking +// process of performing the certificate verification. Deletes itself +// eventually if Start() succeeds. +class CertVerifierWorker { + public: + CertVerifierWorker(X509Certificate* cert, + const std::string& hostname, + int flags, + CertVerifier* cert_verifier) : cert_(cert), hostname_(hostname), flags_(flags), - verifier_(verifier), - verify_result_(verify_result), - callback_(callback), - origin_loop_proxy_(base::MessageLoopProxy::CreateForCurrentThread()), - error_(OK) { + origin_loop_(MessageLoop::current()), + cert_verifier_(cert_verifier), + canceled_(false), + error_(ERR_FAILED) { + } + + bool Start() { + DCHECK_EQ(MessageLoop::current(), origin_loop_); + + return WorkerPool::PostTask( + FROM_HERE, NewRunnableMethod(this, &CertVerifierWorker::Run), + true /* task is slow */); } - void DoVerify() { - // Running on the worker thread - error_ = cert_->Verify(hostname_, flags_, &result_); + // Cancel is called from the origin loop when the CertVerifier is getting + // deleted. + void Cancel() { + DCHECK_EQ(MessageLoop::current(), origin_loop_); + AutoLock locked(lock_); + canceled_ = true; + } + + private: + void Run() { + // Runs on a worker thread. + error_ = cert_->Verify(hostname_, flags_, &verify_result_); #if defined(USE_NSS) // Detach the thread from NSPR. // Calling NSS functions attaches the thread to NSPR, which stores @@ -50,109 +163,319 @@ class CertVerifier::Request : // destructors run. PR_DetachThread(); #endif + Finish(); + } - scoped_ptr<Task> reply(NewRunnableMethod(this, &Request::DoCallback)); - - // The origin loop could go away while we are trying to post to it, so we - // need to call its PostTask method inside a lock. See ~CertVerifier. - AutoLock locked(origin_loop_proxy_lock_); - if (origin_loop_proxy_) { - bool posted = origin_loop_proxy_->PostTask(FROM_HERE, reply.release()); - // TODO(willchan): Fix leaks and then change this to a DCHECK. - LOG_IF(ERROR, !posted) << "Leaked CertVerifier!"; + // DoReply runs on the origin thread. + void DoReply() { + DCHECK_EQ(MessageLoop::current(), origin_loop_); + { + // We lock here because the worker thread could still be in Finished, + // after the PostTask, but before unlocking |lock_|. If we do not lock in + // this case, we will end up deleting a locked Lock, which can lead to + // memory leaks or worse errors. + AutoLock locked(lock_); + if (!canceled_) { + cert_verifier_->HandleResult(cert_, hostname_, flags_, + error_, verify_result_); + } } + delete this; } - void DoCallback() { - // Running on the origin thread. + void Finish() { + // Runs on the worker thread. + // We assume that the origin loop outlives the CertVerifier. If the + // CertVerifier is deleted, it will call Cancel on us. If it does so + // before the Acquire, we'll delete ourselves and return. If it's trying to + // do so concurrently, then it'll block on the lock and we'll call PostTask + // while the CertVerifier (and therefore the MessageLoop) is still alive. + // If it does so after this function, we assume that the MessageLoop will + // process pending tasks. In which case we'll notice the |canceled_| flag + // in DoReply. - // We may have been cancelled! - if (!verifier_) - return; + bool canceled; + { + AutoLock locked(lock_); + canceled = canceled_; + if (!canceled) { + origin_loop_->PostTask( + FROM_HERE, NewRunnableMethod(this, &CertVerifierWorker::DoReply)); + } + } - *verify_result_ = result_; + if (canceled) + delete this; + } - // Drop the verifier's reference to us. Do this before running the - // callback since the callback might result in the verifier being - // destroyed. - verifier_->request_ = NULL; + scoped_refptr<X509Certificate> cert_; + const std::string hostname_; + const int flags_; + MessageLoop* const origin_loop_; + CertVerifier* const cert_verifier_; - callback_->Run(error_); - } + // lock_ protects canceled_. + Lock lock_; - void Cancel() { - verifier_ = NULL; + // If canceled_ is true, + // * origin_loop_ cannot be accessed by the worker thread, + // * cert_verifier_ cannot be accessed by any thread. + bool canceled_; + + int error_; + CertVerifyResult verify_result_; + + DISALLOW_COPY_AND_ASSIGN(CertVerifierWorker); +}; - AutoLock locked(origin_loop_proxy_lock_); - origin_loop_proxy_ = NULL; +// A CertVerifierJob is a one-to-one counterpart of a CertVerifierWorker. It +// lives only on the CertVerifier's origin message loop. +class CertVerifierJob { + public: + explicit CertVerifierJob(CertVerifierWorker* worker) : worker_(worker) { } - private: - friend class base::RefCountedThreadSafe<CertVerifier::Request>; + ~CertVerifierJob() { + if (worker_) + worker_->Cancel(); + } - ~Request() {} + void AddRequest(CertVerifierRequest* request) { + requests_.push_back(request); + } - // Set on the origin thread, read on the worker thread. - scoped_refptr<X509Certificate> cert_; - std::string hostname_; - // bitwise OR'd of X509Certificate::VerifyFlags. - int flags_; + void HandleResult(const CachedCertVerifyResult& verify_result) { + worker_ = NULL; + PostAll(verify_result); + } - // Only used on the origin thread (where Verify was called). - CertVerifier* verifier_; - CertVerifyResult* verify_result_; - CompletionCallback* callback_; + private: + void PostAll(const CachedCertVerifyResult& verify_result) { + std::vector<CertVerifierRequest*> requests; + requests_.swap(requests); - // Used to post ourselves onto the origin thread. - Lock origin_loop_proxy_lock_; - // Use a MessageLoopProxy in case the owner of the CertVerifier is leaked, so - // this code won't crash: http://crbug.com/42275. If this is leaked, then it - // doesn't get Cancel()'d, so |origin_loop_proxy_| doesn't get NULL'd out. If - // the MessageLoop goes away, then if we had used a MessageLoop, this would - // crash. - scoped_refptr<base::MessageLoopProxy> origin_loop_proxy_; + for (std::vector<CertVerifierRequest*>::iterator + i = requests.begin(); i != requests.end(); i++) { + (*i)->Post(verify_result); + // Post() causes the CertVerifierRequest to delete itself. + } + } - // Assigned on the worker thread, read on the origin thread. - int error_; - CertVerifyResult result_; + std::vector<CertVerifierRequest*> requests_; + CertVerifierWorker* worker_; }; -//----------------------------------------------------------------------------- -CertVerifier::CertVerifier() { +CertVerifier::CertVerifier() + : time_service_(new DefaultTimeService), + requests_(0), + cache_hits_(0), + inflight_joins_(0) { +} + +CertVerifier::CertVerifier(TimeService* time_service) + : time_service_(time_service), + requests_(0), + cache_hits_(0), + inflight_joins_(0) { } CertVerifier::~CertVerifier() { - if (request_) - request_->Cancel(); + STLDeleteValues(&inflight_); } int CertVerifier::Verify(X509Certificate* cert, const std::string& hostname, int flags, CertVerifyResult* verify_result, - CompletionCallback* callback) { - DCHECK(!request_) << "verifier already in use"; + CompletionCallback* callback, + RequestHandle* out_req) { + DCHECK(CalledOnValidThread()); - // Do a synchronous verification. - if (!callback) { - CertVerifyResult result; - int rv = cert->Verify(hostname, flags, &result); - *verify_result = result; - return rv; + if (!callback || !verify_result || hostname.empty()) { + *out_req = NULL; + return ERR_INVALID_ARGUMENT; } - request_ = new Request(this, cert, hostname, flags, verify_result, callback); + requests_++; - // Dispatch to worker thread... - if (!WorkerPool::PostTask(FROM_HERE, - NewRunnableMethod(request_.get(), &Request::DoVerify), true)) { - NOTREACHED(); - request_ = NULL; - return ERR_FAILED; + const RequestParams key = {cert->fingerprint(), hostname, flags}; + // First check the cache. + std::map<RequestParams, CachedCertVerifyResult>::iterator i; + i = cache_.find(key); + if (i != cache_.end()) { + if (!i->second.HasExpired(time_service_->Now())) { + cache_hits_++; + *out_req = NULL; + *verify_result = i->second.result; + return i->second.error; + } + // Cache entry has expired. + cache_.erase(i); } + // No cache hit. See if an identical request is currently in flight. + CertVerifierJob* job; + std::map<RequestParams, CertVerifierJob*>::const_iterator j; + j = inflight_.find(key); + if (j != inflight_.end()) { + // An identical request is in flight already. We'll just attach our + // callback. + inflight_joins_++; + job = j->second; + } else { + // Need to make a new request. + CertVerifierWorker* worker = new CertVerifierWorker(cert, hostname, flags, + this); + job = new CertVerifierJob(worker); + inflight_.insert(std::make_pair(key, job)); + if (!worker->Start()) { + inflight_.erase(key); + delete job; + delete worker; + *out_req = NULL; + return ERR_FAILED; // TODO(wtc): Log an error message. + } + } + + CertVerifierRequest* request = + new CertVerifierRequest(callback, verify_result); + job->AddRequest(request); + *out_req = request; return ERR_IO_PENDING; } +void CertVerifier::CancelRequest(RequestHandle req) { + DCHECK(CalledOnValidThread()); + CertVerifierRequest* request = reinterpret_cast<CertVerifierRequest*>(req); + request->Cancel(); +} + +void CertVerifier::ClearCache() { + DCHECK(CalledOnValidThread()); + + cache_.clear(); + // Leaves inflight_ alone. +} + +size_t CertVerifier::GetCacheSize() const { + DCHECK(CalledOnValidThread()); + + return cache_.size(); +} + +// HandleResult is called by CertVerifierWorker on the origin message loop. +// It deletes CertVerifierJob. +void CertVerifier::HandleResult(X509Certificate* cert, + const std::string& hostname, + int flags, + int error, + const CertVerifyResult& verify_result) { + DCHECK(CalledOnValidThread()); + + const base::Time current_time(time_service_->Now()); + + CachedCertVerifyResult cached_result; + cached_result.error = error; + cached_result.result = verify_result; + uint32 ttl = kTTLSecs; + cached_result.expiry = current_time + base::TimeDelta::FromSeconds(ttl); + + const RequestParams key = {cert->fingerprint(), hostname, flags}; + + DCHECK_GE(kMaxCacheEntries, 1u); + DCHECK_LE(cache_.size(), kMaxCacheEntries); + if (cache_.size() == kMaxCacheEntries) { + // Need to remove an element of the cache. + std::map<RequestParams, CachedCertVerifyResult>::iterator i, cur; + for (i = cache_.begin(); i != cache_.end(); ) { + cur = i++; + if (cur->second.HasExpired(current_time)) + cache_.erase(cur); + } + } + if (cache_.size() == kMaxCacheEntries) { + // If we didn't clear out any expired entries, we just remove the first + // element. Crummy but simple. + cache_.erase(cache_.begin()); + } + + cache_.insert(std::make_pair(key, cached_result)); + + std::map<RequestParams, CertVerifierJob*>::iterator j; + j = inflight_.find(key); + if (j == inflight_.end()) { + NOTREACHED(); + return; + } + CertVerifierJob* job = j->second; + inflight_.erase(j); + + job->HandleResult(cached_result); + delete job; +} + +///////////////////////////////////////////////////////////////////// + +SingleRequestCertVerifier::SingleRequestCertVerifier( + CertVerifier* cert_verifier) + : cert_verifier_(cert_verifier), + cur_request_(NULL), + cur_request_callback_(NULL), + ALLOW_THIS_IN_INITIALIZER_LIST( + callback_(this, &SingleRequestCertVerifier::OnVerifyCompletion)) { + DCHECK(cert_verifier_ != NULL); +} + +SingleRequestCertVerifier::~SingleRequestCertVerifier() { + if (cur_request_) { + cert_verifier_->CancelRequest(cur_request_); + cur_request_ = NULL; + } +} + +int SingleRequestCertVerifier::Verify(X509Certificate* cert, + const std::string& hostname, + int flags, + CertVerifyResult* verify_result, + CompletionCallback* callback) { + // Should not be already in use. + DCHECK(!cur_request_ && !cur_request_callback_); + + // Do a synchronous verification. + if (!callback) + return cert->Verify(hostname, flags, verify_result); + + CertVerifier::RequestHandle request = NULL; + + // We need to be notified of completion before |callback| is called, so that + // we can clear out |cur_request_*|. + int rv = cert_verifier_->Verify( + cert, hostname, flags, verify_result, &callback_, &request); + + if (rv == ERR_IO_PENDING) { + // Cleared in OnVerifyCompletion(). + cur_request_ = request; + cur_request_callback_ = callback; + } + + return rv; +} + +void SingleRequestCertVerifier::OnVerifyCompletion(int result) { + DCHECK(cur_request_ && cur_request_callback_); + + CompletionCallback* callback = cur_request_callback_; + + // Clear the outstanding request information. + cur_request_ = NULL; + cur_request_callback_ = NULL; + + // Call the user's original callback. + callback->Run(result); +} + } // namespace net + +DISABLE_RUNNABLE_METHOD_REFCOUNT(net::CertVerifierWorker); + diff --git a/net/base/cert_verifier.h b/net/base/cert_verifier.h index 791f8d3..3d19abb 100644 --- a/net/base/cert_verifier.h +++ b/net/base/cert_verifier.h @@ -1,4 +1,4 @@ -// Copyright (c) 2008-2009 The Chromium Authors. All rights reserved. +// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. @@ -6,29 +6,68 @@ #define NET_BASE_CERT_VERIFIER_H_ #pragma once +#include <map> #include <string> #include "base/basictypes.h" -#include "base/ref_counted.h" +#include "base/non_thread_safe.h" +#include "base/scoped_ptr.h" +#include "base/time.h" +#include "net/base/cert_verify_result.h" #include "net/base/completion_callback.h" +#include "net/base/x509_cert_types.h" namespace net { -class CertVerifyResult; +class CertVerifierJob; +class CertVerifierWorker; class X509Certificate; -// This class represents the task of verifying a certificate. It can only -// verify a single certificate at a time, so if you need to verify multiple -// certificates at the same time, you will need to allocate a CertVerifier -// object for each certificate. +// CachedCertVerifyResult contains the result of a certificate verification. +struct CachedCertVerifyResult { + CachedCertVerifyResult(); + ~CachedCertVerifyResult(); + + int error; // The return value of CertVerifier::Verify. + CertVerifyResult result; // The output of CertVerifier::Verify. + + // The time at which the certificate verification result expires. + base::Time expiry; + + // Returns true if |current_time| is greater than or equal to |expiry|. + bool HasExpired(base::Time current_time) const; +}; + +// CertVerifier represents a service for verifying certificates. // -class CertVerifier { +// CertVerifier can handle multiple requests at a time, so when canceling a +// request the RequestHandle that was returned by Verify() needs to be +// given. A simpler alternative for consumers that only have 1 outstanding +// request at a time is to create a SingleRequestCertVerifier wrapper around +// CertVerifier (which will automatically cancel the single request when it +// goes out of scope). +class CertVerifier : public NonThreadSafe { public: + // Opaque type used to cancel a request. + typedef void* RequestHandle; + + // CertVerifier must not call base::Time::Now() directly. It must call + // time_service_->Now(). This allows unit tests to mock the current time. + class TimeService { + public: + virtual ~TimeService() {} + + virtual base::Time Now() = 0; + }; + CertVerifier(); - // If a completion callback is pending when the verifier is destroyed, the - // certificate verification is cancelled, and the completion callback will - // not be called. + // Used by unit tests to mock the current time. Takes ownership of + // |time_service|. + explicit CertVerifier(TimeService* time_service); + + // When the verifier is destroyed, all certificate verifications requests are + // canceled, and their completion callbacks will not be called. ~CertVerifier(); // Verifies the given certificate against the given hostname. Returns OK if @@ -49,23 +88,128 @@ class CertVerifier { // VERIFY_REV_CHECKING_ENABLED is not set), EV certificate verification will // not be performed. // - // When callback is null, the operation completes synchronously. - // - // When callback is non-null, ERR_IO_PENDING is returned if the operation + // |callback| must not be null. ERR_IO_PENDING is returned if the operation // could not be completed synchronously, in which case the result code will // be passed to the callback when available. // - int Verify(X509Certificate* cert, const std::string& hostname, - int flags, CertVerifyResult* verify_result, - CompletionCallback* callback); + // If |out_req| is non-NULL, then |*out_req| will be filled with a handle to + // the async request. This handle is not valid after the request has + // completed. + int Verify(X509Certificate* cert, + const std::string& hostname, + int flags, + CertVerifyResult* verify_result, + CompletionCallback* callback, + RequestHandle* out_req); + + // Cancels the specified request. |req| is the handle returned by Verify(). + // After a request is canceled, its completion callback will not be called. + void CancelRequest(RequestHandle req); + + // Clears the verification result cache. + void ClearCache(); + + size_t GetCacheSize() const; + + uint64 requests() const { return requests_; } + uint64 cache_hits() const { return cache_hits_; } + uint64 inflight_joins() const { return inflight_joins_; } private: - class Request; - friend class Request; - scoped_refptr<Request> request_; + friend class CertVerifierWorker; // Calls HandleResult. + + // Input parameters of a certificate verification request. + struct RequestParams { + bool operator==(const RequestParams& other) const { + // |flags| is compared before |cert_fingerprint| and |hostname| under + // assumption that integer comparisons are faster than memory and string + // comparisons. + return (flags == other.flags && + memcmp(cert_fingerprint.data, other.cert_fingerprint.data, + sizeof(cert_fingerprint.data)) == 0 && + hostname == other.hostname); + } + + bool operator<(const RequestParams& other) const { + // |flags| is compared before |cert_fingerprint| and |hostname| under + // assumption that integer comparisons are faster than memory and string + // comparisons. + if (flags != other.flags) + return flags < other.flags; + int rv = memcmp(cert_fingerprint.data, other.cert_fingerprint.data, + sizeof(cert_fingerprint.data)); + if (rv != 0) + return rv < 0; + return hostname < other.hostname; + } + + SHA1Fingerprint cert_fingerprint; + std::string hostname; + int flags; + }; + + void HandleResult(X509Certificate* cert, + const std::string& hostname, + int flags, + int error, + const CertVerifyResult& verify_result); + + // cache_ maps from a request to a cached result. The cached result may + // have expired and the size of |cache_| must be <= kMaxCacheEntries. + std::map<RequestParams, CachedCertVerifyResult> cache_; + + // inflight_ maps from a request to an active verification which is taking + // place. + std::map<RequestParams, CertVerifierJob*> inflight_; + + scoped_ptr<TimeService> time_service_; + + uint64 requests_; + uint64 cache_hits_; + uint64 inflight_joins_; + DISALLOW_COPY_AND_ASSIGN(CertVerifier); }; +// This class represents the task of verifying a certificate. It wraps +// CertVerifier to verify only a single certificate at a time and cancels this +// request when going out of scope. +class SingleRequestCertVerifier { + public: + // |cert_verifier| must remain valid for the lifetime of |this|. + explicit SingleRequestCertVerifier(CertVerifier* cert_verifier); + + // If a completion callback is pending when the verifier is destroyed, the + // certificate verification is canceled, and the completion callback will + // not be called. + ~SingleRequestCertVerifier(); + + // Verifies the given certificate, filling out the |verify_result| object + // upon success. See CertVerifier::Verify() for details. + int Verify(X509Certificate* cert, + const std::string& hostname, + int flags, + CertVerifyResult* verify_result, + CompletionCallback* callback); + + private: + // Callback for when the request to |cert_verifier_| completes, so we + // dispatch to the user's callback. + void OnVerifyCompletion(int result); + + // The actual certificate verifier that will handle the request. + CertVerifier* const cert_verifier_; + + // The current request (if any). + CertVerifier::RequestHandle cur_request_; + CompletionCallback* cur_request_callback_; + + // Completion callback for when request to |cert_verifier_| completes. + net::CompletionCallbackImpl<SingleRequestCertVerifier> callback_; + + DISALLOW_COPY_AND_ASSIGN(SingleRequestCertVerifier); +}; + } // namespace net #endif // NET_BASE_CERT_VERIFIER_H_ diff --git a/net/base/cert_verifier_unittest.cc b/net/base/cert_verifier_unittest.cc new file mode 100644 index 0000000..ca5e1f4 --- /dev/null +++ b/net/base/cert_verifier_unittest.cc @@ -0,0 +1,260 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/base/cert_verifier.h" + +#include "base/callback.h" +#include "base/file_path.h" +#include "base/stringprintf.h" +#include "net/base/cert_test_util.h" +#include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" +#include "net/base/x509_certificate.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +class TestTimeService : public CertVerifier::TimeService { + public: + // CertVerifier::TimeService methods: + virtual base::Time Now() { return current_time_; } + + void set_current_time(base::Time now) { current_time_ = now; } + + private: + base::Time current_time_; +}; + +class CertVerifierTest : public testing::Test { +}; + +class ExplodingCallback : public CallbackRunner<Tuple1<int> > { + public: + virtual void RunWithParams(const Tuple1<int>& params) { + FAIL(); + } +}; + +// Tests a cache hit, which should results in synchronous completion. +TEST_F(CertVerifierTest, CacheHit) { + TestTimeService* time_service = new TestTimeService; + base::Time current_time = base::Time::Now(); + time_service->set_current_time(current_time); + CertVerifier verifier(time_service); + + FilePath certs_dir = GetTestCertsDirectory(); + scoped_refptr<X509Certificate> google_cert( + ImportCertFromFile(certs_dir, "google.single.der")); + ASSERT_NE(static_cast<X509Certificate*>(NULL), google_cert); + + int error; + CertVerifyResult verify_result; + TestCompletionCallback callback; + CertVerifier::RequestHandle request_handle; + + error = verifier.Verify(google_cert, "www.example.com", 0, &verify_result, + &callback, &request_handle); + ASSERT_EQ(ERR_IO_PENDING, error); + ASSERT_TRUE(request_handle != NULL); + error = callback.WaitForResult(); + ASSERT_TRUE(IsCertificateError(error)); + ASSERT_EQ(1u, verifier.requests()); + ASSERT_EQ(0u, verifier.cache_hits()); + ASSERT_EQ(0u, verifier.inflight_joins()); + + error = verifier.Verify(google_cert, "www.example.com", 0, &verify_result, + &callback, &request_handle); + // Synchronous completion. + ASSERT_NE(ERR_IO_PENDING, error); + ASSERT_TRUE(IsCertificateError(error)); + ASSERT_TRUE(request_handle == NULL); + ASSERT_EQ(2u, verifier.requests()); + ASSERT_EQ(1u, verifier.cache_hits()); + ASSERT_EQ(0u, verifier.inflight_joins()); +} + +// Tests an inflight join. +TEST_F(CertVerifierTest, InflightJoin) { + TestTimeService* time_service = new TestTimeService; + base::Time current_time = base::Time::Now(); + time_service->set_current_time(current_time); + CertVerifier verifier(time_service); + + FilePath certs_dir = GetTestCertsDirectory(); + scoped_refptr<X509Certificate> google_cert( + ImportCertFromFile(certs_dir, "google.single.der")); + ASSERT_NE(static_cast<X509Certificate*>(NULL), google_cert); + + int error; + CertVerifyResult verify_result; + TestCompletionCallback callback; + CertVerifier::RequestHandle request_handle; + CertVerifyResult verify_result2; + TestCompletionCallback callback2; + CertVerifier::RequestHandle request_handle2; + + error = verifier.Verify(google_cert, "www.example.com", 0, &verify_result, + &callback, &request_handle); + ASSERT_EQ(ERR_IO_PENDING, error); + ASSERT_TRUE(request_handle != NULL); + error = verifier.Verify(google_cert, "www.example.com", 0, &verify_result2, + &callback2, &request_handle2); + ASSERT_EQ(ERR_IO_PENDING, error); + ASSERT_TRUE(request_handle2 != NULL); + error = callback.WaitForResult(); + ASSERT_TRUE(IsCertificateError(error)); + error = callback2.WaitForResult(); + ASSERT_TRUE(IsCertificateError(error)); + ASSERT_EQ(2u, verifier.requests()); + ASSERT_EQ(0u, verifier.cache_hits()); + ASSERT_EQ(1u, verifier.inflight_joins()); +} + +// Tests cache entry expiration. +TEST_F(CertVerifierTest, ExpiredCacheEntry) { + TestTimeService* time_service = new TestTimeService; + base::Time current_time = base::Time::Now(); + time_service->set_current_time(current_time); + CertVerifier verifier(time_service); + + FilePath certs_dir = GetTestCertsDirectory(); + scoped_refptr<X509Certificate> google_cert( + ImportCertFromFile(certs_dir, "google.single.der")); + ASSERT_NE(static_cast<X509Certificate*>(NULL), google_cert); + + int error; + CertVerifyResult verify_result; + TestCompletionCallback callback; + CertVerifier::RequestHandle request_handle; + + error = verifier.Verify(google_cert, "www.example.com", 0, &verify_result, + &callback, &request_handle); + ASSERT_EQ(ERR_IO_PENDING, error); + ASSERT_TRUE(request_handle != NULL); + error = callback.WaitForResult(); + ASSERT_TRUE(IsCertificateError(error)); + ASSERT_EQ(1u, verifier.requests()); + ASSERT_EQ(0u, verifier.cache_hits()); + ASSERT_EQ(0u, verifier.inflight_joins()); + + // Before expiration, should have a cache hit. + error = verifier.Verify(google_cert, "www.example.com", 0, &verify_result, + &callback, &request_handle); + // Synchronous completion. + ASSERT_NE(ERR_IO_PENDING, error); + ASSERT_TRUE(IsCertificateError(error)); + ASSERT_TRUE(request_handle == NULL); + ASSERT_EQ(2u, verifier.requests()); + ASSERT_EQ(1u, verifier.cache_hits()); + ASSERT_EQ(0u, verifier.inflight_joins()); + + // After expiration, should not have a cache hit. + ASSERT_EQ(1u, verifier.GetCacheSize()); + current_time += base::TimeDelta::FromMinutes(60); + time_service->set_current_time(current_time); + error = verifier.Verify(google_cert, "www.example.com", 0, &verify_result, + &callback, &request_handle); + ASSERT_EQ(ERR_IO_PENDING, error); + ASSERT_TRUE(request_handle != NULL); + ASSERT_EQ(0u, verifier.GetCacheSize()); + error = callback.WaitForResult(); + ASSERT_TRUE(IsCertificateError(error)); + ASSERT_EQ(3u, verifier.requests()); + ASSERT_EQ(1u, verifier.cache_hits()); + ASSERT_EQ(0u, verifier.inflight_joins()); +} + +// Tests a full cache. +TEST_F(CertVerifierTest, FullCache) { + TestTimeService* time_service = new TestTimeService; + base::Time current_time = base::Time::Now(); + time_service->set_current_time(current_time); + CertVerifier verifier(time_service); + + FilePath certs_dir = GetTestCertsDirectory(); + scoped_refptr<X509Certificate> google_cert( + ImportCertFromFile(certs_dir, "google.single.der")); + ASSERT_NE(static_cast<X509Certificate*>(NULL), google_cert); + + int error; + CertVerifyResult verify_result; + TestCompletionCallback callback; + CertVerifier::RequestHandle request_handle; + + error = verifier.Verify(google_cert, "www.example.com", 0, &verify_result, + &callback, &request_handle); + ASSERT_EQ(ERR_IO_PENDING, error); + ASSERT_TRUE(request_handle != NULL); + error = callback.WaitForResult(); + ASSERT_TRUE(IsCertificateError(error)); + ASSERT_EQ(1u, verifier.requests()); + ASSERT_EQ(0u, verifier.cache_hits()); + ASSERT_EQ(0u, verifier.inflight_joins()); + + const unsigned kCacheSize = 256; + + for (unsigned i = 0; i < kCacheSize; i++) { + std::string hostname = base::StringPrintf("www%d.example.com", i + 1); + error = verifier.Verify(google_cert, hostname, 0, &verify_result, + &callback, &request_handle); + ASSERT_EQ(ERR_IO_PENDING, error); + ASSERT_TRUE(request_handle != NULL); + error = callback.WaitForResult(); + ASSERT_TRUE(IsCertificateError(error)); + } + ASSERT_EQ(kCacheSize + 1, verifier.requests()); + ASSERT_EQ(0u, verifier.cache_hits()); + ASSERT_EQ(0u, verifier.inflight_joins()); + + ASSERT_EQ(kCacheSize, verifier.GetCacheSize()); + current_time += base::TimeDelta::FromMinutes(60); + time_service->set_current_time(current_time); + error = verifier.Verify(google_cert, "www999.example.com", 0, &verify_result, + &callback, &request_handle); + ASSERT_EQ(ERR_IO_PENDING, error); + ASSERT_TRUE(request_handle != NULL); + ASSERT_EQ(kCacheSize, verifier.GetCacheSize()); + error = callback.WaitForResult(); + ASSERT_EQ(1u, verifier.GetCacheSize()); + ASSERT_TRUE(IsCertificateError(error)); + ASSERT_EQ(kCacheSize + 2, verifier.requests()); + ASSERT_EQ(0u, verifier.cache_hits()); + ASSERT_EQ(0u, verifier.inflight_joins()); +} + +// Tests that the callback of a canceled request is never made. +TEST_F(CertVerifierTest, CancelRequest) { + CertVerifier verifier; + + FilePath certs_dir = GetTestCertsDirectory(); + scoped_refptr<X509Certificate> google_cert( + ImportCertFromFile(certs_dir, "google.single.der")); + ASSERT_NE(static_cast<X509Certificate*>(NULL), google_cert); + + int error; + CertVerifyResult verify_result; + ExplodingCallback exploding_callback; + CertVerifier::RequestHandle request_handle; + + error = verifier.Verify(google_cert, "www.example.com", 0, &verify_result, + &exploding_callback, &request_handle); + ASSERT_EQ(ERR_IO_PENDING, error); + ASSERT_TRUE(request_handle != NULL); + verifier.CancelRequest(request_handle); + + // Issue a few more requests to the worker pool and wait for their + // completion, so that the task of the canceled request (which runs on a + // worker thread) is likely to complete by the end of this test. + TestCompletionCallback callback; + for (int i = 0; i < 5; ++i) { + error = verifier.Verify(google_cert, "www2.example.com", 0, &verify_result, + &callback, &request_handle); + ASSERT_EQ(ERR_IO_PENDING, error); + ASSERT_TRUE(request_handle != NULL); + error = callback.WaitForResult(); + verifier.ClearCache(); + } +} + +} // namespace net diff --git a/net/http/disk_cache_based_ssl_host_info.cc b/net/http/disk_cache_based_ssl_host_info.cc index 2b83f56..1b1dfaf 100644 --- a/net/http/disk_cache_based_ssl_host_info.cc +++ b/net/http/disk_cache_based_ssl_host_info.cc @@ -9,6 +9,7 @@ #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/http/http_cache.h" +#include "net/http/http_network_session.h" namespace net { @@ -16,7 +17,8 @@ DiskCacheBasedSSLHostInfo::DiskCacheBasedSSLHostInfo( const std::string& hostname, const SSLConfig& ssl_config, HttpCache* http_cache) - : SSLHostInfo(hostname, ssl_config), + : SSLHostInfo(hostname, ssl_config, + http_cache->network_layer()->GetSession()->cert_verifier()), weak_ptr_factory_(ALLOW_THIS_IN_INITIALIZER_LIST(this)), callback_(new CallbackImpl(weak_ptr_factory_.GetWeakPtr(), &DiskCacheBasedSSLHostInfo::DoLoop)), diff --git a/net/http/http_cache.cc b/net/http/http_cache.cc index ea4e48b..51cc55f 100644 --- a/net/http/http_cache.cc +++ b/net/http/http_cache.cc @@ -263,7 +263,7 @@ void HttpCache::MetadataWriter::OnIOComplete(int result) { class HttpCache::SSLHostInfoFactoryAdaptor : public SSLHostInfoFactory { public: - SSLHostInfoFactoryAdaptor(HttpCache* http_cache) + explicit SSLHostInfoFactoryAdaptor(HttpCache* http_cache) : http_cache_(http_cache) { } @@ -279,6 +279,7 @@ class HttpCache::SSLHostInfoFactoryAdaptor : public SSLHostInfoFactory { //----------------------------------------------------------------------------- HttpCache::HttpCache(HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker_, ProxyService* proxy_service, @@ -293,7 +294,7 @@ HttpCache::HttpCache(HostResolver* host_resolver, ssl_host_info_factory_(new SSLHostInfoFactoryAdaptor( ALLOW_THIS_IN_INITIALIZER_LIST(this))), network_layer_(HttpNetworkLayer::CreateFactory(host_resolver, - dnsrr_resolver, dns_cert_checker_, + cert_verifier, dnsrr_resolver, dns_cert_checker_, ssl_host_info_factory_.get(), proxy_service, ssl_config_service, http_auth_handler_factory, network_delegate, net_log)), diff --git a/net/http/http_cache.h b/net/http/http_cache.h index 4b7d736..5c812da 100644 --- a/net/http/http_cache.h +++ b/net/http/http_cache.h @@ -41,6 +41,7 @@ class Entry; namespace net { +class CertVerifier; class DnsCertProvenanceChecker; class DnsRRResolver; class HostResolver; @@ -117,6 +118,7 @@ class HttpCache : public HttpTransactionFactory, // The disk cache is initialized lazily (by CreateTransaction) in this case. // The HttpCache takes ownership of the |backend_factory|. HttpCache(HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, ProxyService* proxy_service, diff --git a/net/http/http_network_layer.cc b/net/http/http_network_layer.cc index 3da23c2..3d3c5dd 100644 --- a/net/http/http_network_layer.cc +++ b/net/http/http_network_layer.cc @@ -21,6 +21,7 @@ namespace net { // static HttpTransactionFactory* HttpNetworkLayer::CreateFactory( HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -32,7 +33,7 @@ HttpTransactionFactory* HttpNetworkLayer::CreateFactory( DCHECK(proxy_service); return new HttpNetworkLayer(ClientSocketFactory::GetDefaultFactory(), - host_resolver, dnsrr_resolver, + host_resolver, cert_verifier, dnsrr_resolver, dns_cert_checker, ssl_host_info_factory, proxy_service, ssl_config_service, http_auth_handler_factory, @@ -52,6 +53,7 @@ HttpTransactionFactory* HttpNetworkLayer::CreateFactory( HttpNetworkLayer::HttpNetworkLayer( ClientSocketFactory* socket_factory, HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -62,6 +64,7 @@ HttpNetworkLayer::HttpNetworkLayer( NetLog* net_log) : socket_factory_(socket_factory), host_resolver_(host_resolver), + cert_verifier_(cert_verifier), dnsrr_resolver_(dnsrr_resolver), dns_cert_checker_(dns_cert_checker), ssl_host_info_factory_(ssl_host_info_factory), @@ -80,6 +83,7 @@ HttpNetworkLayer::HttpNetworkLayer( HttpNetworkLayer::HttpNetworkLayer( ClientSocketFactory* socket_factory, HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -91,6 +95,7 @@ HttpNetworkLayer::HttpNetworkLayer( NetLog* net_log) : socket_factory_(socket_factory), host_resolver_(host_resolver), + cert_verifier_(cert_verifier), dnsrr_resolver_(dnsrr_resolver), dns_cert_checker_(dns_cert_checker), ssl_host_info_factory_(ssl_host_info_factory), @@ -108,6 +113,8 @@ HttpNetworkLayer::HttpNetworkLayer( HttpNetworkLayer::HttpNetworkLayer(HttpNetworkSession* session) : socket_factory_(ClientSocketFactory::GetDefaultFactory()), + host_resolver_(NULL), + cert_verifier_(NULL), dnsrr_resolver_(NULL), dns_cert_checker_(NULL), ssl_host_info_factory_(NULL), @@ -150,6 +157,7 @@ HttpNetworkSession* HttpNetworkLayer::GetSession() { spdy_session_pool_.reset(new SpdySessionPool(ssl_config_service_)); session_ = new HttpNetworkSession( host_resolver_, + cert_verifier_, dnsrr_resolver_, dns_cert_checker_, ssl_host_info_factory_, @@ -162,6 +170,7 @@ HttpNetworkSession* HttpNetworkLayer::GetSession() { net_log_); // These were just temps for lazy-initializing HttpNetworkSession. host_resolver_ = NULL; + cert_verifier_ = NULL; dnsrr_resolver_ = NULL; dns_cert_checker_ = NULL; ssl_host_info_factory_ = NULL; diff --git a/net/http/http_network_layer.h b/net/http/http_network_layer.h index 7781efb..91e1a86 100644 --- a/net/http/http_network_layer.h +++ b/net/http/http_network_layer.h @@ -15,6 +15,7 @@ namespace net { +class CertVerifier; class ClientSocketFactory; class DnsCertProvenanceChecker; class DnsRRResolver; @@ -30,10 +31,12 @@ class SSLHostInfoFactory; class HttpNetworkLayer : public HttpTransactionFactory, public NonThreadSafe { public: - // |socket_factory|, |proxy_service| and |host_resolver| must remain valid for - // the lifetime of HttpNetworkLayer. + // |socket_factory|, |proxy_service|, |host_resolver|, etc. must remain + // valid for the lifetime of HttpNetworkLayer. + // TODO(wtc): we only need the next constructor. HttpNetworkLayer(ClientSocketFactory* socket_factory, HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -42,11 +45,10 @@ class HttpNetworkLayer : public HttpTransactionFactory, public NonThreadSafe { HttpAuthHandlerFactory* http_auth_handler_factory, HttpNetworkDelegate* network_delegate, NetLog* net_log); - // Construct a HttpNetworkLayer with an existing HttpNetworkSession which - // contains a valid ProxyService. HttpNetworkLayer( ClientSocketFactory* socket_factory, HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -57,6 +59,8 @@ class HttpNetworkLayer : public HttpTransactionFactory, public NonThreadSafe { HttpNetworkDelegate* network_delegate, NetLog* net_log); + // Construct a HttpNetworkLayer with an existing HttpNetworkSession which + // contains a valid ProxyService. explicit HttpNetworkLayer(HttpNetworkSession* session); ~HttpNetworkLayer(); @@ -64,6 +68,7 @@ class HttpNetworkLayer : public HttpTransactionFactory, public NonThreadSafe { // and allows other implementations to be substituted. static HttpTransactionFactory* CreateFactory( HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -100,9 +105,10 @@ class HttpNetworkLayer : public HttpTransactionFactory, public NonThreadSafe { // The factory we will use to create network sockets. ClientSocketFactory* socket_factory_; - // The host resolver and proxy service that will be used when lazily + // The host resolver, proxy service, etc. that will be used when lazily // creating |session_|. HostResolver* host_resolver_; + CertVerifier* cert_verifier_; DnsRRResolver* dnsrr_resolver_; DnsCertProvenanceChecker* dns_cert_checker_; SSLHostInfoFactory* ssl_host_info_factory_; diff --git a/net/http/http_network_layer_unittest.cc b/net/http/http_network_layer_unittest.cc index 3ed54bf..2720c10 100644 --- a/net/http/http_network_layer_unittest.cc +++ b/net/http/http_network_layer_unittest.cc @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include "net/base/cert_verifier.h" #include "net/base/mock_host_resolver.h" #include "net/base/net_log.h" #include "net/base/ssl_config_service_defaults.h" @@ -21,9 +22,11 @@ class HttpNetworkLayerTest : public PlatformTest { TEST_F(HttpNetworkLayerTest, CreateAndDestroy) { MockHostResolver host_resolver; + net::CertVerifier cert_verifier; net::HttpNetworkLayer factory( NULL, &host_resolver, + &cert_verifier, NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, @@ -41,9 +44,11 @@ TEST_F(HttpNetworkLayerTest, CreateAndDestroy) { TEST_F(HttpNetworkLayerTest, Suspend) { MockHostResolver host_resolver; + net::CertVerifier cert_verifier; net::HttpNetworkLayer factory( NULL, &host_resolver, + &cert_verifier, NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, @@ -90,9 +95,11 @@ TEST_F(HttpNetworkLayerTest, GET) { mock_socket_factory.AddSocketDataProvider(&data); MockHostResolver host_resolver; + net::CertVerifier cert_verifier; net::HttpNetworkLayer factory( &mock_socket_factory, &host_resolver, + &cert_verifier, NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, diff --git a/net/http/http_network_session.cc b/net/http/http_network_session.cc index 1e77b49..e3de475 100644 --- a/net/http/http_network_session.cc +++ b/net/http/http_network_session.cc @@ -20,6 +20,7 @@ namespace net { // TODO(mbelshe): Move the socket factories into HttpStreamFactory. HttpNetworkSession::HttpNetworkSession( HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -32,6 +33,7 @@ HttpNetworkSession::HttpNetworkSession( NetLog* net_log) : socket_factory_(client_socket_factory), host_resolver_(host_resolver), + cert_verifier_(cert_verifier), dnsrr_resolver_(dnsrr_resolver), dns_cert_checker_(dns_cert_checker), proxy_service_(proxy_service), @@ -39,6 +41,7 @@ HttpNetworkSession::HttpNetworkSession( socket_pool_manager_(net_log, client_socket_factory, host_resolver, + cert_verifier, dnsrr_resolver, dns_cert_checker, ssl_host_info_factory, diff --git a/net/http/http_network_session.h b/net/http/http_network_session.h index 43424d2..2c923b6 100644 --- a/net/http/http_network_session.h +++ b/net/http/http_network_session.h @@ -28,6 +28,7 @@ class Value; namespace net { +class CertVerifier; class ClientSocketFactory; class DnsCertProvenanceChecker; class DnsRRResolver; @@ -48,6 +49,7 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession>, public: HttpNetworkSession( HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -109,6 +111,7 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession>, // SSL sockets come from the socket_factory(). ClientSocketFactory* socket_factory() { return socket_factory_; } HostResolver* host_resolver() { return host_resolver_; } + CertVerifier* cert_verifier() { return cert_verifier_; } DnsRRResolver* dnsrr_resolver() { return dnsrr_resolver_; } DnsCertProvenanceChecker* dns_cert_checker() { return dns_cert_checker_; @@ -152,6 +155,7 @@ class HttpNetworkSession : public base::RefCounted<HttpNetworkSession>, SSLClientAuthCache ssl_client_auth_cache_; HttpAlternateProtocols alternate_protocols_; HostResolver* const host_resolver_; + CertVerifier* cert_verifier_; DnsRRResolver* dnsrr_resolver_; DnsCertProvenanceChecker* dns_cert_checker_; scoped_refptr<ProxyService> proxy_service_; diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index 95a8599..79047e1 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -74,6 +74,7 @@ struct SessionDependencies { // Default set of dependencies -- "null" proxy service. SessionDependencies() : host_resolver(new MockHostResolver), + cert_verifier(new CertVerifier), proxy_service(ProxyService::CreateDirect()), ssl_config_service(new SSLConfigServiceDefaults), http_auth_handler_factory( @@ -83,6 +84,7 @@ struct SessionDependencies { // Custom proxy service dependency. explicit SessionDependencies(ProxyService* proxy_service) : host_resolver(new MockHostResolver), + cert_verifier(new CertVerifier), proxy_service(proxy_service), ssl_config_service(new SSLConfigServiceDefaults), http_auth_handler_factory( @@ -90,6 +92,7 @@ struct SessionDependencies { net_log(NULL) {} scoped_ptr<MockHostResolverBase> host_resolver; + scoped_ptr<CertVerifier> cert_verifier; scoped_refptr<ProxyService> proxy_service; scoped_refptr<SSLConfigService> ssl_config_service; MockClientSocketFactory socket_factory; @@ -99,6 +102,7 @@ struct SessionDependencies { HttpNetworkSession* CreateSession(SessionDependencies* session_deps) { return new HttpNetworkSession(session_deps->host_resolver.get(), + session_deps->cert_verifier.get(), NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, @@ -310,7 +314,8 @@ CaptureGroupNameHttpProxySocketPool::CaptureGroupNameSocketPool( template<> CaptureGroupNameSSLSocketPool::CaptureGroupNameSocketPool( HttpNetworkSession* session) - : SSLClientSocketPool(0, 0, NULL, session->host_resolver(), NULL, NULL, + : SSLClientSocketPool(0, 0, NULL, session->host_resolver(), + session->cert_verifier(), NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL) {} //----------------------------------------------------------------------------- @@ -6679,7 +6684,8 @@ TEST_F(HttpNetworkTransactionTest, session->ssl_config_service()->GetSSLConfig(&ssl_config); ClientSocket* socket = connection->release_socket(); socket = session->socket_factory()->CreateSSLClientSocket( - socket, HostPortPair("" , 443), ssl_config, NULL /* ssl_host_info */); + socket, HostPortPair("" , 443), ssl_config, NULL /* ssl_host_info */, + session->cert_verifier()); connection->set_socket(socket); EXPECT_EQ(ERR_IO_PENDING, socket->Connect(&callback)); EXPECT_EQ(OK, callback.WaitForResult()); diff --git a/net/http/http_proxy_client_socket_pool_unittest.cc b/net/http/http_proxy_client_socket_pool_unittest.cc index 56fae19..478a312 100644 --- a/net/http/http_proxy_client_socket_pool_unittest.cc +++ b/net/http/http_proxy_client_socket_pool_unittest.cc @@ -62,9 +62,11 @@ class HttpProxyClientSocketPoolTest : public TestWithHttpParam { ssl_histograms_("MockSSL"), ssl_config_service_(new SSLConfigServiceDefaults), host_resolver_(new MockHostResolver), + cert_verifier_(new CertVerifier), ssl_socket_pool_(kMaxSockets, kMaxSocketsPerGroup, &ssl_histograms_, host_resolver_.get(), + cert_verifier_.get(), NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, @@ -77,6 +79,7 @@ class HttpProxyClientSocketPoolTest : public TestWithHttpParam { http_auth_handler_factory_( HttpAuthHandlerFactory::CreateDefault(host_resolver_.get())), session_(new HttpNetworkSession(host_resolver_.get(), + cert_verifier_.get(), NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, @@ -192,6 +195,7 @@ class HttpProxyClientSocketPoolTest : public TestWithHttpParam { ClientSocketPoolHistograms ssl_histograms_; scoped_refptr<SSLConfigService> ssl_config_service_; scoped_ptr<HostResolver> host_resolver_; + scoped_ptr<CertVerifier> cert_verifier_; SSLClientSocketPool ssl_socket_pool_; scoped_ptr<HttpAuthHandlerFactory> http_auth_handler_factory_; diff --git a/net/http/http_response_body_drainer_unittest.cc b/net/http/http_response_body_drainer_unittest.cc index 75f099a..76304f8 100644 --- a/net/http/http_response_body_drainer_unittest.cc +++ b/net/http/http_response_body_drainer_unittest.cc @@ -178,6 +178,7 @@ class HttpResponseBodyDrainerTest : public testing::Test { NULL /* host_resolver */, NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, + NULL, NULL /* ssl_host_info_factory */, ProxyService::CreateDirect(), NULL, diff --git a/net/http/http_stream_factory_unittest.cc b/net/http/http_stream_factory_unittest.cc index 63fce33..646f79c 100644 --- a/net/http/http_stream_factory_unittest.cc +++ b/net/http/http_stream_factory_unittest.cc @@ -7,6 +7,7 @@ #include <string> #include "base/basictypes.h" +#include "net/base/cert_verifier.h" #include "net/base/mock_host_resolver.h" #include "net/base/net_log.h" #include "net/base/ssl_config_service_defaults.h" @@ -27,6 +28,7 @@ struct SessionDependencies { // Custom proxy service dependency. explicit SessionDependencies(ProxyService* proxy_service) : host_resolver(new MockHostResolver), + cert_verifier(new CertVerifier), proxy_service(proxy_service), ssl_config_service(new SSLConfigServiceDefaults), http_auth_handler_factory( @@ -34,6 +36,7 @@ struct SessionDependencies { net_log(NULL) {} scoped_ptr<MockHostResolverBase> host_resolver; + scoped_ptr<CertVerifier> cert_verifier; scoped_refptr<ProxyService> proxy_service; scoped_refptr<SSLConfigService> ssl_config_service; MockClientSocketFactory socket_factory; @@ -43,6 +46,7 @@ struct SessionDependencies { HttpNetworkSession* CreateSession(SessionDependencies* session_deps) { return new HttpNetworkSession(session_deps->host_resolver.get(), + session_deps->cert_verifier.get(), NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, @@ -170,7 +174,8 @@ CapturePreconnectsHttpProxySocketPool::CapturePreconnectsSocketPool( template<> CapturePreconnectsSSLSocketPool::CapturePreconnectsSocketPool( HttpNetworkSession* session) - : SSLClientSocketPool(0, 0, NULL, session->host_resolver(), NULL, NULL, + : SSLClientSocketPool(0, 0, NULL, session->host_resolver(), + session->cert_verifier(), NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL) {} TEST(HttpStreamFactoryTest, PreconnectDirect) { diff --git a/net/net.gyp b/net/net.gyp index 4460b41..7be250c 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -841,6 +841,7 @@ 'sources': [ 'base/address_list_unittest.cc', 'base/cert_database_nss_unittest.cc', + 'base/cert_verifier_unittest.cc', 'base/cookie_monster_unittest.cc', 'base/data_url_unittest.cc', 'base/directory_lister_unittest.cc', diff --git a/net/proxy/proxy_script_fetcher_impl_unittest.cc b/net/proxy/proxy_script_fetcher_impl_unittest.cc index dc7ac45..ec0fb58 100644 --- a/net/proxy/proxy_script_fetcher_impl_unittest.cc +++ b/net/proxy/proxy_script_fetcher_impl_unittest.cc @@ -39,18 +39,21 @@ class RequestContext : public URLRequestContext { host_resolver_ = net::CreateSystemHostResolver(net::HostResolver::kDefaultParallelism, NULL, NULL); + cert_verifier_ = new net::CertVerifier; proxy_service_ = net::ProxyService::CreateFixed(no_proxy); ssl_config_service_ = new net::SSLConfigServiceDefaults; http_transaction_factory_ = new net::HttpCache( - net::HttpNetworkLayer::CreateFactory(host_resolver_, NULL, NULL, NULL, - proxy_service_, ssl_config_service_, NULL, NULL, NULL), + net::HttpNetworkLayer::CreateFactory(host_resolver_, cert_verifier_, + NULL, NULL, NULL, proxy_service_, ssl_config_service_, NULL, NULL, + NULL), net::HttpCache::DefaultBackend::InMemory(0)); } private: ~RequestContext() { delete http_transaction_factory_; + delete cert_verifier_; delete host_resolver_; } }; diff --git a/net/socket/client_socket_factory.cc b/net/socket/client_socket_factory.cc index 1c998c6..f4da066 100644 --- a/net/socket/client_socket_factory.cc +++ b/net/socket/client_socket_factory.cc @@ -30,19 +30,21 @@ SSLClientSocket* DefaultSSLClientSocketFactory( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker) { scoped_ptr<SSLHostInfo> shi(ssl_host_info); #if defined(OS_WIN) - return new SSLClientSocketWin(transport_socket, host_and_port, ssl_config); + return new SSLClientSocketWin(transport_socket, host_and_port, ssl_config, + cert_verifier); #elif defined(USE_OPENSSL) return new SSLClientSocketOpenSSL(transport_socket, host_and_port, - ssl_config); + ssl_config, cert_verifier); #elif defined(USE_NSS) return new SSLClientSocketNSS(transport_socket, host_and_port, ssl_config, - shi.release(), dns_cert_checker); + shi.release(), cert_verifier, dns_cert_checker); #elif defined(OS_MACOSX) return new SSLClientSocketNSS(transport_socket, host_and_port, ssl_config, - shi.release(), dns_cert_checker); + shi.release(), cert_verifier, dns_cert_checker); #else NOTIMPLEMENTED(); return NULL; @@ -65,9 +67,10 @@ class DefaultClientSocketFactory : public ClientSocketFactory { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker) { return g_ssl_factory(transport_socket, host_and_port, ssl_config, - ssl_host_info, dns_cert_checker); + ssl_host_info, cert_verifier, dns_cert_checker); } }; @@ -92,11 +95,12 @@ SSLClientSocket* ClientSocketFactory::CreateSSLClientSocket( ClientSocket* transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, - SSLHostInfo* ssl_host_info) { + SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier) { ClientSocketHandle* socket_handle = new ClientSocketHandle(); socket_handle->set_socket(transport_socket); return CreateSSLClientSocket(socket_handle, host_and_port, ssl_config, - ssl_host_info, + ssl_host_info, cert_verifier, NULL /* DnsCertProvenanceChecker */); } diff --git a/net/socket/client_socket_factory.h b/net/socket/client_socket_factory.h index 0ab370a9..2a0cd7c 100644 --- a/net/socket/client_socket_factory.h +++ b/net/socket/client_socket_factory.h @@ -14,6 +14,7 @@ namespace net { class AddressList; +class CertVerifier; class ClientSocket; class ClientSocketHandle; class DnsCertProvenanceChecker; @@ -28,6 +29,7 @@ typedef SSLClientSocket* (*SSLClientSocketFactory)( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker); // An interface used to instantiate ClientSocket objects. Used to facilitate @@ -48,6 +50,7 @@ class ClientSocketFactory { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker) = 0; // Deprecated function (http://crbug.com/37810) that takes a ClientSocket. @@ -55,7 +58,8 @@ class ClientSocketFactory { ClientSocket* transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, - SSLHostInfo* ssl_host_info); + SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier); // Returns the default ClientSocketFactory. static ClientSocketFactory* GetDefaultFactory(); diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index 843b6be..7c0e2e1 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -110,6 +110,7 @@ class MockClientSocketFactory : public ClientSocketFactory { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker) { NOTIMPLEMENTED(); delete ssl_host_info; diff --git a/net/socket/client_socket_pool_manager.cc b/net/socket/client_socket_pool_manager.cc index 6c73c36..8516fbc 100644 --- a/net/socket/client_socket_pool_manager.cc +++ b/net/socket/client_socket_pool_manager.cc @@ -55,6 +55,7 @@ ClientSocketPoolManager::ClientSocketPoolManager( NetLog* net_log, ClientSocketFactory* socket_factory, HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -63,6 +64,7 @@ ClientSocketPoolManager::ClientSocketPoolManager( : net_log_(net_log), socket_factory_(socket_factory), host_resolver_(host_resolver), + cert_verifier_(cert_verifier), dnsrr_resolver_(dnsrr_resolver), dns_cert_checker_(dns_cert_checker), ssl_host_info_factory_(ssl_host_info_factory), @@ -80,6 +82,7 @@ ClientSocketPoolManager::ClientSocketPoolManager( g_max_sockets, g_max_sockets_per_group, &ssl_pool_histograms_, host_resolver, + cert_verifier, dnsrr_resolver, dns_cert_checker, ssl_host_info_factory, @@ -230,6 +233,7 @@ HttpProxyClientSocketPool* ClientSocketPoolManager::GetSocketPoolForHTTPProxy( g_max_sockets_per_proxy_server, g_max_sockets_per_group, &ssl_for_https_proxy_pool_histograms_, host_resolver_, + cert_verifier_, dnsrr_resolver_, dns_cert_checker_, ssl_host_info_factory_, @@ -266,6 +270,7 @@ SSLClientSocketPool* ClientSocketPoolManager::GetSocketPoolForSSLWithProxy( g_max_sockets_per_proxy_server, g_max_sockets_per_group, &ssl_pool_histograms_, host_resolver_, + cert_verifier_, dnsrr_resolver_, dns_cert_checker_, ssl_host_info_factory_, diff --git a/net/socket/client_socket_pool_manager.h b/net/socket/client_socket_pool_manager.h index 823213e..cfcb465 100644 --- a/net/socket/client_socket_pool_manager.h +++ b/net/socket/client_socket_pool_manager.h @@ -6,8 +6,8 @@ // simple container for all of them. Most importantly, it handles the lifetime // and destruction order properly. -#ifndef NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_ -#define NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_ +#ifndef NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_H_ +#define NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_H_ #pragma once #include <map> @@ -23,6 +23,7 @@ class Value; namespace net { +class CertVerifier; class ClientSocketFactory; class ClientSocketPoolHistograms; class DnsCertProvenanceChecker; @@ -54,13 +55,14 @@ class OwnedPoolMap : public std::map<Key, Value> { } }; -} // internal +} // namespace internal class ClientSocketPoolManager : public NonThreadSafe { public: ClientSocketPoolManager(NetLog* net_log, ClientSocketFactory* socket_factory, HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -106,6 +108,7 @@ class ClientSocketPoolManager : public NonThreadSafe { NetLog* const net_log_; ClientSocketFactory* const socket_factory_; HostResolver* const host_resolver_; + CertVerifier* const cert_verifier_; DnsRRResolver* const dnsrr_resolver_; DnsCertProvenanceChecker* const dns_cert_checker_; SSLHostInfoFactory* const ssl_host_info_factory_; @@ -146,4 +149,4 @@ class ClientSocketPoolManager : public NonThreadSafe { } // namespace net -#endif // NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_ +#endif // NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_H_ diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index b2e738a..d88399d 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -402,7 +402,7 @@ int DeterministicMockTCPClientSocket::Read( return CompleteRead(); } -void DeterministicMockTCPClientSocket::CompleteWrite(){ +void DeterministicMockTCPClientSocket::CompleteWrite() { was_used_to_convey_data_ = true; write_pending_ = false; write_callback_->Run(write_result_); @@ -1016,6 +1016,7 @@ SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker) { MockSSLClientSocket* socket = new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, @@ -1066,6 +1067,7 @@ SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker) { MockSSLClientSocket* socket = new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 0a01df3..73dd07c 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -537,6 +537,7 @@ class MockClientSocketFactory : public ClientSocketFactory { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker); SocketDataProviderArray<SocketDataProvider>& mock_data() { return mock_data_; @@ -882,6 +883,7 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker); SocketDataProviderArray<DeterministicSocketData>& mock_data() { diff --git a/net/socket/ssl_client_socket_mac.cc b/net/socket/ssl_client_socket_mac.cc index 488beeb..352b3b1 100644 --- a/net/socket/ssl_client_socket_mac.cc +++ b/net/socket/ssl_client_socket_mac.cc @@ -520,7 +520,8 @@ EnabledCipherSuites::EnabledCipherSuites() { SSLClientSocketMac::SSLClientSocketMac(ClientSocketHandle* transport_socket, const HostPortPair& host_and_port, - const SSLConfig& ssl_config) + const SSLConfig& ssl_config, + CertVerifier* cert_verifier) : handshake_io_callback_(this, &SSLClientSocketMac::OnHandshakeIOComplete), transport_read_callback_(this, &SSLClientSocketMac::OnTransportReadComplete), @@ -535,6 +536,7 @@ SSLClientSocketMac::SSLClientSocketMac(ClientSocketHandle* transport_socket, user_read_buf_len_(0), user_write_buf_len_(0), next_handshake_state_(STATE_NONE), + cert_verifier_(cert_verifier), renegotiating_(false), client_cert_requested_(false), ssl_context_(NULL), @@ -1066,7 +1068,7 @@ int SSLClientSocketMac::DoVerifyCert() { flags |= X509Certificate::VERIFY_REV_CHECKING_ENABLED; if (ssl_config_.verify_ev_cert) flags |= X509Certificate::VERIFY_EV_CERT; - verifier_.reset(new CertVerifier); + verifier_.reset(new SingleRequestCertVerifier(cert_verifier_)); return verifier_->Verify(server_cert_, host_and_port_.host(), flags, &server_cert_verify_result_, &handshake_io_callback_); diff --git a/net/socket/ssl_client_socket_mac.h b/net/socket/ssl_client_socket_mac.h index e84bee4..a94b2bd 100644 --- a/net/socket/ssl_client_socket_mac.h +++ b/net/socket/ssl_client_socket_mac.h @@ -23,6 +23,7 @@ namespace net { class CertVerifier; class ClientSocketHandle; +class SingleRequestCertVerifier; // An SSL client socket implemented with Secure Transport. class SSLClientSocketMac : public SSLClientSocket { @@ -35,7 +36,8 @@ class SSLClientSocketMac : public SSLClientSocket { // the SSL settings. SSLClientSocketMac(ClientSocketHandle* transport_socket, const HostPortPair& host_and_port, - const SSLConfig& ssl_config); + const SSLConfig& ssl_config, + CertVerifier* cert_verifier); ~SSLClientSocketMac(); // SSLClientSocket methods: @@ -137,7 +139,8 @@ class SSLClientSocketMac : public SSLClientSocket { State next_handshake_state_; scoped_refptr<X509Certificate> server_cert_; - scoped_ptr<CertVerifier> verifier_; + CertVerifier* const cert_verifier_; + scoped_ptr<SingleRequestCertVerifier> verifier_; CertVerifyResult server_cert_verify_result_; // The initial handshake has already completed, and the current handshake diff --git a/net/socket/ssl_client_socket_mac_factory.cc b/net/socket/ssl_client_socket_mac_factory.cc index bf732e6..211e2a4 100644 --- a/net/socket/ssl_client_socket_mac_factory.cc +++ b/net/socket/ssl_client_socket_mac_factory.cc @@ -14,9 +14,11 @@ SSLClientSocket* SSLClientSocketMacFactory( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker) { delete ssl_host_info; - return new SSLClientSocketMac(transport_socket, host_and_port, ssl_config); + return new SSLClientSocketMac(transport_socket, host_and_port, ssl_config, + cert_verifier); } } // namespace net diff --git a/net/socket/ssl_client_socket_mac_factory.h b/net/socket/ssl_client_socket_mac_factory.h index 5539136..ebda9c3 100644 --- a/net/socket/ssl_client_socket_mac_factory.h +++ b/net/socket/ssl_client_socket_mac_factory.h @@ -19,6 +19,7 @@ SSLClientSocket* SSLClientSocketMacFactory( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker); } // namespace net diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index bbfe12f..05cad27 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -408,6 +408,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocketHandle* transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_ctx) : ALLOW_THIS_IN_INITIALIZER_LIST(buffer_send_callback_( this, &SSLClientSocketNSS::BufferSendComplete)), @@ -430,6 +431,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocketHandle* transport_socket, server_cert_verify_result_(NULL), ssl_connection_status_(0), client_auth_cert_needed_(false), + cert_verifier_(cert_verifier), handshake_callback_called_(false), completed_handshake_(false), pseudo_connected_(false), @@ -2464,7 +2466,7 @@ int SSLClientSocketNSS::DoVerifyCert(int result) { flags |= X509Certificate::VERIFY_REV_CHECKING_ENABLED; if (ssl_config_.verify_ev_cert) flags |= X509Certificate::VERIFY_EV_CERT; - verifier_.reset(new CertVerifier); + verifier_.reset(new SingleRequestCertVerifier(cert_verifier_)); server_cert_verify_result_ = &local_server_cert_verify_result_; return verifier_->Verify(server_cert_, host_and_port_.host(), flags, &local_server_cert_verify_result_, diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index 8798361..bca4166 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -32,6 +32,7 @@ class BoundNetLog; class CertVerifier; class ClientSocketHandle; class DnsCertProvenanceChecker; +class SingleRequestCertVerifier; class SSLHostInfo; class X509Certificate; @@ -48,6 +49,7 @@ class SSLClientSocketNSS : public SSLClientSocket { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dnsrr_resolver); ~SSLClientSocketNSS(); @@ -193,7 +195,8 @@ class SSLClientSocketNSS : public SSLClientSocket { std::vector<scoped_refptr<X509Certificate> > client_certs_; bool client_auth_cert_needed_; - scoped_ptr<CertVerifier> verifier_; + CertVerifier* const cert_verifier_; + scoped_ptr<SingleRequestCertVerifier> verifier_; // True if NSS has called HandshakeCallback. bool handshake_callback_called_; diff --git a/net/socket/ssl_client_socket_nss_factory.cc b/net/socket/ssl_client_socket_nss_factory.cc index e4c01f0..435ddff 100644 --- a/net/socket/ssl_client_socket_nss_factory.cc +++ b/net/socket/ssl_client_socket_nss_factory.cc @@ -19,10 +19,11 @@ SSLClientSocket* SSLClientSocketNSSFactory( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker) { scoped_ptr<SSLHostInfo> shi(ssl_host_info); return new SSLClientSocketNSS(transport_socket, host_and_port, ssl_config, - shi.release(), dns_cert_checker); + shi.release(), cert_verifier, dns_cert_checker); } } // namespace net diff --git a/net/socket/ssl_client_socket_nss_factory.h b/net/socket/ssl_client_socket_nss_factory.h index 15b05b2..ed5e588 100644 --- a/net/socket/ssl_client_socket_nss_factory.h +++ b/net/socket/ssl_client_socket_nss_factory.h @@ -19,6 +19,7 @@ SSLClientSocket* SSLClientSocketNSSFactory( const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker); } // namespace net diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc index ab4ba6c..e485c8a 100644 --- a/net/socket/ssl_client_socket_openssl.cc +++ b/net/socket/ssl_client_socket_openssl.cc @@ -380,7 +380,8 @@ struct SslSetClearMask { SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( ClientSocketHandle* transport_socket, const HostPortPair& host_and_port, - const SSLConfig& ssl_config) + const SSLConfig& ssl_config, + CertVerifier* cert_verifier) : ALLOW_THIS_IN_INITIALIZER_LIST(buffer_send_callback_( this, &SSLClientSocketOpenSSL::BufferSendComplete)), ALLOW_THIS_IN_INITIALIZER_LIST(buffer_recv_callback_( @@ -392,6 +393,7 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( user_write_callback_(NULL), completed_handshake_(false), client_auth_cert_needed_(false), + cert_verifier_(cert_verifier), ALLOW_THIS_IN_INITIALIZER_LIST(handshake_io_callback_( this, &SSLClientSocketOpenSSL::OnHandshakeIOComplete)), ssl_(NULL), @@ -813,7 +815,7 @@ int SSLClientSocketOpenSSL::DoVerifyCert(int result) { flags |= X509Certificate::VERIFY_REV_CHECKING_ENABLED; if (ssl_config_.verify_ev_cert) flags |= X509Certificate::VERIFY_EV_CERT; - verifier_.reset(new CertVerifier); + verifier_.reset(new SingleRequestCertVerifier(cert_verifier_)); return verifier_->Verify(server_cert_, host_and_port_.host(), flags, &server_cert_verify_result_, &handshake_io_callback_); diff --git a/net/socket/ssl_client_socket_openssl.h b/net/socket/ssl_client_socket_openssl.h index 62cc4d4..d59b507 100644 --- a/net/socket/ssl_client_socket_openssl.h +++ b/net/socket/ssl_client_socket_openssl.h @@ -24,6 +24,7 @@ typedef struct x509_st X509; namespace net { class CertVerifier; +class SingleRequestCertVerifier; class SSLCertRequestInfo; class SSLConfig; class SSLInfo; @@ -37,7 +38,8 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // settings. SSLClientSocketOpenSSL(ClientSocketHandle* transport_socket, const HostPortPair& host_and_port, - const SSLConfig& ssl_config); + const SSLConfig& ssl_config, + CertVerifier* cert_verifier); ~SSLClientSocketOpenSSL(); const HostPortPair& host_and_port() const { return host_and_port_; } @@ -131,7 +133,8 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { std::vector<scoped_refptr<X509Certificate> > client_certs_; bool client_auth_cert_needed_; - scoped_ptr<CertVerifier> verifier_; + CertVerifier* const cert_verifier_; + scoped_ptr<SingleRequestCertVerifier> verifier_; CompletionCallbackImpl<SSLClientSocketOpenSSL> handshake_io_callback_; // OpenSSL stuff diff --git a/net/socket/ssl_client_socket_pool.cc b/net/socket/ssl_client_socket_pool.cc index 7124efa..deaf4f3 100644 --- a/net/socket/ssl_client_socket_pool.cc +++ b/net/socket/ssl_client_socket_pool.cc @@ -77,6 +77,7 @@ SSLConnectJob::SSLConnectJob( HttpProxyClientSocketPool* http_proxy_pool, ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -89,7 +90,8 @@ SSLConnectJob::SSLConnectJob( socks_pool_(socks_pool), http_proxy_pool_(http_proxy_pool), client_socket_factory_(client_socket_factory), - resolver_(host_resolver), + host_resolver_(host_resolver), + cert_verifier_(cert_verifier), dnsrr_resolver_(dnsrr_resolver), dns_cert_checker_(dns_cert_checker), ssl_host_info_factory_(ssl_host_info_factory), @@ -289,7 +291,8 @@ int SSLConnectJob::DoSSLConnect() { ssl_socket_.reset(client_socket_factory_->CreateSSLClientSocket( transport_socket_handle_.release(), params_->host_and_port(), - params_->ssl_config(), ssl_host_info_.release(), dns_cert_checker_)); + params_->ssl_config(), ssl_host_info_.release(), cert_verifier_, + dns_cert_checker_)); return ssl_socket_->Connect(&callback_); } @@ -360,7 +363,7 @@ ConnectJob* SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( return new SSLConnectJob(group_name, request.params(), ConnectionTimeout(), tcp_pool_, socks_pool_, http_proxy_pool_, client_socket_factory_, host_resolver_, - dnsrr_resolver_, dns_cert_checker_, + cert_verifier_, dnsrr_resolver_, dns_cert_checker_, ssl_host_info_factory_, delegate, net_log_); } @@ -370,6 +373,7 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( HttpProxyClientSocketPool* http_proxy_pool, ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -379,6 +383,7 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( http_proxy_pool_(http_proxy_pool), client_socket_factory_(client_socket_factory), host_resolver_(host_resolver), + cert_verifier_(cert_verifier), dnsrr_resolver_(dnsrr_resolver), dns_cert_checker_(dns_cert_checker), ssl_host_info_factory_(ssl_host_info_factory), @@ -406,6 +411,7 @@ SSLClientSocketPool::SSLClientSocketPool( int max_sockets_per_group, ClientSocketPoolHistograms* histograms, HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -424,8 +430,8 @@ SSLClientSocketPool::SSLClientSocketPool( base::TimeDelta::FromSeconds(kUsedIdleSocketTimeout), new SSLConnectJobFactory(tcp_pool, socks_pool, http_proxy_pool, client_socket_factory, host_resolver, - dnsrr_resolver, dns_cert_checker, - ssl_host_info_factory, + cert_verifier, dnsrr_resolver, + dns_cert_checker, ssl_host_info_factory, net_log)), ssl_config_service_(ssl_config_service) { if (ssl_config_service_) diff --git a/net/socket/ssl_client_socket_pool.h b/net/socket/ssl_client_socket_pool.h index 136516f..468d3ed1 100644 --- a/net/socket/ssl_client_socket_pool.h +++ b/net/socket/ssl_client_socket_pool.h @@ -22,6 +22,7 @@ namespace net { +class CertVerifier; class ClientSocketFactory; class ConnectJobFactory; class DnsCertProvenanceChecker; @@ -95,6 +96,7 @@ class SSLConnectJob : public ConnectJob { HttpProxyClientSocketPool* http_proxy_pool, ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -144,7 +146,8 @@ class SSLConnectJob : public ConnectJob { SOCKSClientSocketPool* const socks_pool_; HttpProxyClientSocketPool* const http_proxy_pool_; ClientSocketFactory* const client_socket_factory_; - HostResolver* const resolver_; + HostResolver* const host_resolver_; + CertVerifier* const cert_verifier_; DnsRRResolver* const dnsrr_resolver_; DnsCertProvenanceChecker* dns_cert_checker_; SSLHostInfoFactory* const ssl_host_info_factory_; @@ -173,6 +176,7 @@ class SSLClientSocketPool : public ClientSocketPool, int max_sockets_per_group, ClientSocketPoolHistograms* histograms, HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -241,6 +245,7 @@ class SSLClientSocketPool : public ClientSocketPool, HttpProxyClientSocketPool* http_proxy_pool, ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, + CertVerifier* cert_verifier, DnsRRResolver* dnsrr_resolver, DnsCertProvenanceChecker* dns_cert_checker, SSLHostInfoFactory* ssl_host_info_factory, @@ -262,6 +267,7 @@ class SSLClientSocketPool : public ClientSocketPool, HttpProxyClientSocketPool* const http_proxy_pool_; ClientSocketFactory* const client_socket_factory_; HostResolver* const host_resolver_; + CertVerifier* const cert_verifier_; DnsRRResolver* const dnsrr_resolver_; DnsCertProvenanceChecker* const dns_cert_checker_; SSLHostInfoFactory* const ssl_host_info_factory_; diff --git a/net/socket/ssl_client_socket_pool_unittest.cc b/net/socket/ssl_client_socket_pool_unittest.cc index 247638b..37e21ca 100644 --- a/net/socket/ssl_client_socket_pool_unittest.cc +++ b/net/socket/ssl_client_socket_pool_unittest.cc @@ -10,6 +10,7 @@ #include "base/time.h" #include "base/utf_string_conversions.h" #include "net/base/auth.h" +#include "net/base/cert_verifier.h" #include "net/base/mock_host_resolver.h" #include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" @@ -36,9 +37,11 @@ class SSLClientSocketPoolTest : public testing::Test { protected: SSLClientSocketPoolTest() : host_resolver_(new MockHostResolver), + cert_verifier_(new CertVerifier), http_auth_handler_factory_(HttpAuthHandlerFactory::CreateDefault( host_resolver_.get())), session_(new HttpNetworkSession(host_resolver_.get(), + cert_verifier_.get(), NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, @@ -96,7 +99,8 @@ class SSLClientSocketPoolTest : public testing::Test { kMaxSockets, kMaxSocketsPerGroup, ssl_histograms_.get(), - NULL, + NULL /* host_resolver */, + NULL /* cert_verifier */, NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, @@ -131,6 +135,7 @@ class SSLClientSocketPoolTest : public testing::Test { MockClientSocketFactory socket_factory_; scoped_ptr<HostResolver> host_resolver_; + scoped_ptr<CertVerifier> cert_verifier_; scoped_ptr<HttpAuthHandlerFactory> http_auth_handler_factory_; scoped_refptr<HttpNetworkSession> session_; diff --git a/net/socket/ssl_client_socket_snapstart_unittest.cc b/net/socket/ssl_client_socket_snapstart_unittest.cc index ecb9789..d782993 100644 --- a/net/socket/ssl_client_socket_snapstart_unittest.cc +++ b/net/socket/ssl_client_socket_snapstart_unittest.cc @@ -41,8 +41,8 @@ namespace net { // pretends that certificate verification always succeeds. class TestSSLHostInfo : public SSLHostInfo { public: - TestSSLHostInfo() - : SSLHostInfo("example.com", kDefaultSSLConfig) { + explicit TestSSLHostInfo(CertVerifier* cert_verifier) + : SSLHostInfo("example.com", kDefaultSSLConfig, cert_verifier) { if (!saved_.empty()) Parse(saved_); cert_verification_complete_ = true; @@ -194,7 +194,7 @@ class SSLClientSocketSnapStartTest : public PlatformTest { scoped_ptr<SSLClientSocket> sock( socket_factory_->CreateSSLClientSocket( transport, HostPortPair("example.com", 443), ssl_config_, - new TestSSLHostInfo())); + new TestSSLHostInfo(&cert_verifier_), &cert_verifier_)); TestCompletionCallback callback; int rv = sock->Connect(&callback); @@ -265,6 +265,7 @@ class SSLClientSocketSnapStartTest : public PlatformTest { } base::ProcessHandle child_; + CertVerifier cert_verifier_; ClientSocketFactory* const socket_factory_; struct sockaddr_in remote_; int client_; diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc index 0410a06..9ba5cbf 100644 --- a/net/socket/ssl_client_socket_unittest.cc +++ b/net/socket/ssl_client_socket_unittest.cc @@ -5,6 +5,7 @@ #include "net/socket/ssl_client_socket.h" #include "net/base/address_list.h" +#include "net/base/cert_verifier.h" #include "net/base/host_resolver.h" #include "net/base/io_buffer.h" #include "net/base/net_log.h" @@ -26,11 +27,24 @@ const net::SSLConfig kDefaultSSLConfig; class SSLClientSocketTest : public PlatformTest { public: SSLClientSocketTest() - : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) { + : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), + cert_verifier_(new net::CertVerifier) { } protected: + net::SSLClientSocket* CreateSSLClientSocket( + net::ClientSocket* transport_socket, + const net::HostPortPair& host_and_port, + const net::SSLConfig& ssl_config) { + return socket_factory_->CreateSSLClientSocket(transport_socket, + host_and_port, + ssl_config, + NULL, + cert_verifier_.get()); + } + net::ClientSocketFactory* socket_factory_; + scoped_ptr<net::CertVerifier> cert_verifier_; }; //----------------------------------------------------------------------------- @@ -67,7 +81,8 @@ TEST_F(SSLClientSocketTest, Connect) { scoped_ptr<net::SSLClientSocket> sock( socket_factory_->CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig, NULL)); + transport, test_server.host_port_pair(), kDefaultSSLConfig, + NULL, cert_verifier_.get())); EXPECT_FALSE(sock->IsConnected()); @@ -107,8 +122,8 @@ TEST_F(SSLClientSocketTest, ConnectExpired) { EXPECT_EQ(net::OK, rv); scoped_ptr<net::SSLClientSocket> sock( - socket_factory_->CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig, NULL)); + CreateSSLClientSocket(transport, test_server.host_port_pair(), + kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); @@ -150,8 +165,8 @@ TEST_F(SSLClientSocketTest, ConnectMismatched) { EXPECT_EQ(net::OK, rv); scoped_ptr<net::SSLClientSocket> sock( - socket_factory_->CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig, NULL)); + CreateSSLClientSocket(transport, test_server.host_port_pair(), + kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); @@ -196,8 +211,8 @@ TEST_F(SSLClientSocketTest, FLAKY_ConnectClientAuthCertRequested) { EXPECT_EQ(net::OK, rv); scoped_ptr<net::SSLClientSocket> sock( - socket_factory_->CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig, NULL)); + CreateSSLClientSocket(transport, test_server.host_port_pair(), + kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); @@ -243,8 +258,8 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { ssl_config.client_cert = NULL; scoped_ptr<net::SSLClientSocket> sock( - socket_factory_->CreateSSLClientSocket( - transport, test_server.host_port_pair(), ssl_config, NULL)); + CreateSSLClientSocket(transport, test_server.host_port_pair(), + ssl_config)); EXPECT_FALSE(sock->IsConnected()); @@ -289,8 +304,8 @@ TEST_F(SSLClientSocketTest, Read) { EXPECT_EQ(net::OK, rv); scoped_ptr<net::SSLClientSocket> sock( - socket_factory_->CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig, NULL)); + CreateSSLClientSocket(transport, test_server.host_port_pair(), + kDefaultSSLConfig)); rv = sock->Connect(&callback); if (rv == net::ERR_IO_PENDING) @@ -345,7 +360,8 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) { scoped_ptr<net::SSLClientSocket> sock( socket_factory_->CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig, NULL)); + transport, test_server.host_port_pair(), kDefaultSSLConfig, + NULL, cert_verifier_.get())); rv = sock->Connect(&callback); if (rv == net::ERR_IO_PENDING) @@ -398,8 +414,8 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) { EXPECT_EQ(net::OK, rv); scoped_ptr<net::SSLClientSocket> sock( - socket_factory_->CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig, NULL)); + CreateSSLClientSocket(transport, test_server.host_port_pair(), + kDefaultSSLConfig)); rv = sock->Connect(&callback); if (rv == net::ERR_IO_PENDING) @@ -448,8 +464,8 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) { EXPECT_EQ(net::OK, rv); scoped_ptr<net::SSLClientSocket> sock( - socket_factory_->CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig, NULL)); + CreateSSLClientSocket(transport, test_server.host_port_pair(), + kDefaultSSLConfig)); rv = sock->Connect(&callback); if (rv == net::ERR_IO_PENDING) @@ -518,8 +534,8 @@ TEST_F(SSLClientSocketTest, PrematureApplicationData) { EXPECT_EQ(net::OK, rv); scoped_ptr<net::SSLClientSocket> sock( - socket_factory_->CreateSSLClientSocket( - transport, test_server.host_port_pair(), kDefaultSSLConfig, NULL)); + CreateSSLClientSocket(transport, test_server.host_port_pair(), + kDefaultSSLConfig)); rv = sock->Connect(&callback); EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv); @@ -560,8 +576,8 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) { ssl_config.disabled_cipher_suites.push_back(kCiphersToDisable[i]); scoped_ptr<net::SSLClientSocket> sock( - socket_factory_->CreateSSLClientSocket( - transport, test_server.host_port_pair(), ssl_config, NULL)); + CreateSSLClientSocket(transport, test_server.host_port_pair(), + ssl_config)); EXPECT_FALSE(sock->IsConnected()); diff --git a/net/socket/ssl_client_socket_win.cc b/net/socket/ssl_client_socket_win.cc index 19c3814..ae4d4b5 100644 --- a/net/socket/ssl_client_socket_win.cc +++ b/net/socket/ssl_client_socket_win.cc @@ -376,7 +376,8 @@ static const int kRecvBufferSize = (5 + 16*1024 + 64); SSLClientSocketWin::SSLClientSocketWin(ClientSocketHandle* transport_socket, const HostPortPair& host_and_port, - const SSLConfig& ssl_config) + const SSLConfig& ssl_config, + CertVerifier* cert_verifier) : ALLOW_THIS_IN_INITIALIZER_LIST( handshake_io_callback_(this, &SSLClientSocketWin::OnHandshakeIOComplete)), @@ -393,6 +394,7 @@ SSLClientSocketWin::SSLClientSocketWin(ClientSocketHandle* transport_socket, user_write_callback_(NULL), user_write_buf_len_(0), next_state_(STATE_NONE), + cert_verifier_(cert_verifier), creds_(NULL), isc_status_(SEC_E_OK), payload_send_buffer_len_(0), @@ -1124,7 +1126,7 @@ int SSLClientSocketWin::DoVerifyCert() { flags |= X509Certificate::VERIFY_REV_CHECKING_ENABLED; if (ssl_config_.verify_ev_cert) flags |= X509Certificate::VERIFY_EV_CERT; - verifier_.reset(new CertVerifier); + verifier_.reset(new SingleRequestCertVerifier(cert_verifier_)); return verifier_->Verify(server_cert_, host_and_port_.host(), flags, &server_cert_verify_result_, &handshake_io_callback_); diff --git a/net/socket/ssl_client_socket_win.h b/net/socket/ssl_client_socket_win.h index 61c67f0..2bb1853 100644 --- a/net/socket/ssl_client_socket_win.h +++ b/net/socket/ssl_client_socket_win.h @@ -28,6 +28,7 @@ class BoundNetLog; class CertVerifier; class ClientSocketHandle; class HostPortPair; +class SingleRequestCertVerifier; // An SSL client socket implemented with the Windows Schannel. class SSLClientSocketWin : public SSLClientSocket { @@ -40,7 +41,8 @@ class SSLClientSocketWin : public SSLClientSocket { // the SSL settings. SSLClientSocketWin(ClientSocketHandle* transport_socket, const HostPortPair& host_and_port, - const SSLConfig& ssl_config); + const SSLConfig& ssl_config, + CertVerifier* cert_verifier); ~SSLClientSocketWin(); // SSLClientSocket methods: @@ -145,7 +147,8 @@ class SSLClientSocketWin : public SSLClientSocket { SecPkgContext_StreamSizes stream_sizes_; scoped_refptr<X509Certificate> server_cert_; - scoped_ptr<CertVerifier> verifier_; + CertVerifier* const cert_verifier_; + scoped_ptr<SingleRequestCertVerifier> verifier_; CertVerifyResult server_cert_verify_result_; CredHandle* creds_; diff --git a/net/socket/ssl_host_info.cc b/net/socket/ssl_host_info.cc index 8c1b79f..527c2db 100644 --- a/net/socket/ssl_host_info.cc +++ b/net/socket/ssl_host_info.cc @@ -7,7 +7,6 @@ #include "base/metrics/histogram.h" #include "base/pickle.h" #include "base/string_piece.h" -#include "net/base/cert_verifier.h" #include "net/base/ssl_config_service.h" #include "net/base/x509_certificate.h" #include "net/socket/ssl_client_socket.h" @@ -29,7 +28,8 @@ void SSLHostInfo::State::Clear() { SSLHostInfo::SSLHostInfo( const std::string& hostname, - const SSLConfig& ssl_config) + const SSLConfig& ssl_config, + CertVerifier* cert_verifier) : cert_verification_complete_(false), cert_verification_error_(ERR_CERT_INVALID), hostname_(hostname), @@ -37,6 +37,7 @@ SSLHostInfo::SSLHostInfo( cert_verification_callback_(NULL), rev_checking_enabled_(ssl_config.rev_checking_enabled), verify_ev_cert_(ssl_config.verify_ev_cert), + verifier_(cert_verifier), callback_(new CancelableCompletionCallback<SSLHostInfo>( ALLOW_THIS_IN_INITIALIZER_LIST(this), &SSLHostInfo::VerifyCallback)) { @@ -110,12 +111,11 @@ bool SSLHostInfo::ParseInner(const std::string& data) { flags |= X509Certificate::VERIFY_EV_CERT; if (rev_checking_enabled_) flags |= X509Certificate::VERIFY_REV_CHECKING_ENABLED; - verifier_.reset(new CertVerifier); VLOG(1) << "Kicking off verification for " << hostname_; verification_start_time_ = base::TimeTicks::Now(); verification_end_time_ = base::TimeTicks(); - if (verifier_->Verify(cert_.get(), hostname_, flags, - &cert_verify_result_, callback_) == OK) { + if (verifier_.Verify(cert_.get(), hostname_, flags, + &cert_verify_result_, callback_) == OK) { VerifyCallback(OK); } } else { diff --git a/net/socket/ssl_host_info.h b/net/socket/ssl_host_info.h index 782293e..8f1502b 100644 --- a/net/socket/ssl_host_info.h +++ b/net/socket/ssl_host_info.h @@ -11,13 +11,13 @@ #include "base/ref_counted.h" #include "base/scoped_ptr.h" #include "base/time.h" +#include "net/base/cert_verifier.h" #include "net/base/cert_verify_result.h" #include "net/base/completion_callback.h" #include "net/socket/ssl_client_socket.h" namespace net { -class CertVerifier; class X509Certificate; struct SSLConfig; @@ -27,7 +27,9 @@ struct SSLConfig; // certificates. class SSLHostInfo { public: - SSLHostInfo(const std::string& hostname, const SSLConfig& ssl_config); + SSLHostInfo(const std::string& hostname, + const SSLConfig& ssl_config, + CertVerifier *certVerifier); virtual ~SSLHostInfo(); // Start will commence the lookup. This must be called before any other @@ -127,7 +129,7 @@ class SSLHostInfo { base::TimeTicks verification_start_time_; base::TimeTicks verification_end_time_; CertVerifyResult cert_verify_result_; - scoped_ptr<CertVerifier> verifier_; + SingleRequestCertVerifier verifier_; scoped_refptr<X509Certificate> cert_; scoped_refptr<CancelableCompletionCallback<SSLHostInfo> > callback_; }; diff --git a/net/socket/tcp_client_socket_pool_unittest.cc b/net/socket/tcp_client_socket_pool_unittest.cc index c44815c..454f5b8 100644 --- a/net/socket/tcp_client_socket_pool_unittest.cc +++ b/net/socket/tcp_client_socket_pool_unittest.cc @@ -149,7 +149,7 @@ class MockPendingClientSocket : public ClientSocket { virtual bool IsConnectedAndIdle() const { return is_connected_; } - virtual int GetPeerAddress(AddressList* address) const{ + virtual int GetPeerAddress(AddressList* address) const { return ERR_UNEXPECTED; } virtual const BoundNetLog& NetLog() const { @@ -251,6 +251,7 @@ class MockClientSocketFactory : public ClientSocketFactory { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker) { NOTIMPLEMENTED(); delete ssl_host_info; diff --git a/net/socket_stream/socket_stream.cc b/net/socket_stream/socket_stream.cc index 4075e02..21c8e74 100644 --- a/net/socket_stream/socket_stream.cc +++ b/net/socket_stream/socket_stream.cc @@ -50,6 +50,8 @@ SocketStream::SocketStream(const GURL& url, Delegate* delegate) url_(url), max_pending_send_allowed_(kMaxPendingSendAllowed), next_state_(STATE_NONE), + host_resolver_(NULL), + cert_verifier_(NULL), http_auth_handler_factory_(NULL), factory_(ClientSocketFactory::GetDefaultFactory()), proxy_mode_(kDirectConnection), @@ -119,6 +121,7 @@ void SocketStream::set_context(URLRequestContext* context) { if (context_) { host_resolver_ = context_->host_resolver(); + cert_verifier_ = context_->cert_verifier(); http_auth_handler_factory_ = context_->http_auth_handler_factory(); } } @@ -800,7 +803,8 @@ int SocketStream::DoSSLConnect() { socket_.reset(factory_->CreateSSLClientSocket(socket_.release(), HostPortPair::FromURL(url_), ssl_config_, - NULL /* ssl_host_info */)); + NULL /* ssl_host_info */, + cert_verifier_)); next_state_ = STATE_SSL_CONNECT_COMPLETE; metrics_->OnSSLConnection(); return socket_->Connect(&io_callback_); diff --git a/net/socket_stream/socket_stream.h b/net/socket_stream/socket_stream.h index e1f2584..f485543 100644 --- a/net/socket_stream/socket_stream.h +++ b/net/socket_stream/socket_stream.h @@ -274,6 +274,7 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { State next_state_; HostResolver* host_resolver_; + CertVerifier* cert_verifier_; HttpAuthHandlerFactory* http_auth_handler_factory_; ClientSocketFactory* factory_; diff --git a/net/spdy/spdy_test_util.h b/net/spdy/spdy_test_util.h index 0a5d2e0..aeabe6a 100644 --- a/net/spdy/spdy_test_util.h +++ b/net/spdy/spdy_test_util.h @@ -7,6 +7,7 @@ #pragma once #include "base/basictypes.h" +#include "net/base/cert_verifier.h" #include "net/base/mock_host_resolver.h" #include "net/base/request_priority.h" #include "net/base/ssl_config_service_defaults.h" @@ -327,6 +328,7 @@ class SpdySessionDependencies { // Default set of dependencies -- "null" proxy service. SpdySessionDependencies() : host_resolver(new MockHostResolver), + cert_verifier(new CertVerifier), proxy_service(ProxyService::CreateDirect()), ssl_config_service(new SSLConfigServiceDefaults), socket_factory(new MockClientSocketFactory), @@ -345,6 +347,7 @@ class SpdySessionDependencies { // Custom proxy service dependency. explicit SpdySessionDependencies(ProxyService* proxy_service) : host_resolver(new MockHostResolver), + cert_verifier(new CertVerifier), proxy_service(proxy_service), ssl_config_service(new SSLConfigServiceDefaults), socket_factory(new MockClientSocketFactory), @@ -354,6 +357,7 @@ class SpdySessionDependencies { // NOTE: host_resolver must be ordered before http_auth_handler_factory. scoped_ptr<MockHostResolverBase> host_resolver; + scoped_ptr<CertVerifier> cert_verifier; scoped_refptr<ProxyService> proxy_service; scoped_refptr<SSLConfigService> ssl_config_service; scoped_ptr<MockClientSocketFactory> socket_factory; @@ -363,6 +367,7 @@ class SpdySessionDependencies { static HttpNetworkSession* SpdyCreateSession( SpdySessionDependencies* session_deps) { return new HttpNetworkSession(session_deps->host_resolver.get(), + session_deps->cert_verifier.get(), NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, @@ -377,6 +382,7 @@ class SpdySessionDependencies { static HttpNetworkSession* SpdyCreateSessionDeterministic( SpdySessionDependencies* session_deps) { return new HttpNetworkSession(session_deps->host_resolver.get(), + session_deps->cert_verifier.get(), NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, @@ -395,6 +401,7 @@ class SpdyURLRequestContext : public URLRequestContext { public: SpdyURLRequestContext() { host_resolver_ = new MockHostResolver(); + cert_verifier_ = new CertVerifier; proxy_service_ = ProxyService::CreateDirect(); ssl_config_service_ = new SSLConfigServiceDefaults; http_auth_handler_factory_ = HttpAuthHandlerFactory::CreateDefault( @@ -402,6 +409,7 @@ class SpdyURLRequestContext : public URLRequestContext { http_transaction_factory_ = new net::HttpCache( new HttpNetworkLayer(&socket_factory_, host_resolver_, + cert_verifier_, NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, @@ -420,6 +428,7 @@ class SpdyURLRequestContext : public URLRequestContext { virtual ~SpdyURLRequestContext() { delete http_transaction_factory_; delete http_auth_handler_factory_; + delete cert_verifier_; delete host_resolver_; } diff --git a/net/tools/fetch/fetch_client.cc b/net/tools/fetch/fetch_client.cc index 800f3070..0d9682f 100644 --- a/net/tools/fetch/fetch_client.cc +++ b/net/tools/fetch/fetch_client.cc @@ -11,6 +11,7 @@ #include "base/metrics/stats_counters.h" #include "base/string_number_conversions.h" #include "base/string_util.h" +#include "net/base/cert_verifier.h" #include "net/base/completion_callback.h" #include "net/base/host_resolver.h" #include "net/base/io_buffer.h" @@ -140,6 +141,7 @@ int main(int argc, char**argv) { net::CreateSystemHostResolver(net::HostResolver::kDefaultParallelism, NULL, NULL)); + scoped_ptr<net::CertVerifier> cert_verifier(new net::CertVerifier); scoped_refptr<net::ProxyService> proxy_service( net::ProxyService::CreateDirect()); scoped_refptr<net::SSLConfigService> ssl_config_service( @@ -148,13 +150,15 @@ int main(int argc, char**argv) { scoped_ptr<net::HttpAuthHandlerFactory> http_auth_handler_factory( net::HttpAuthHandlerFactory::CreateDefault(host_resolver.get())); if (use_cache) { - factory = new net::HttpCache(host_resolver.get(), NULL, NULL, proxy_service, - ssl_config_service, http_auth_handler_factory.get(), NULL, NULL, + factory = new net::HttpCache(host_resolver.get(), cert_verifier.get(), + NULL, NULL, proxy_service, ssl_config_service, + http_auth_handler_factory.get(), NULL, NULL, net::HttpCache::DefaultBackend::InMemory(0)); } else { factory = new net::HttpNetworkLayer( net::ClientSocketFactory::GetDefaultFactory(), host_resolver.get(), + cert_verifier.get(), NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, @@ -204,7 +208,7 @@ int main(int argc, char**argv) { // Dump the stats table. printf("<stats>\n"); int counter_max = table.GetMaxCounters(); - for (int index=0; index < counter_max; index++) { + for (int index = 0; index < counter_max; index++) { std::string name(table.GetRowName(index)); if (name.length() > 0) { int value = table.GetRowValue(index); diff --git a/net/url_request/url_request_context.cc b/net/url_request/url_request_context.cc index 281aa7e..04f0da0 100644 --- a/net/url_request/url_request_context.cc +++ b/net/url_request/url_request_context.cc @@ -11,6 +11,7 @@ URLRequestContext::URLRequestContext() : net_log_(NULL), host_resolver_(NULL), + cert_verifier_(NULL), dnsrr_resolver_(NULL), dns_cert_checker_(NULL), http_transaction_factory_(NULL), diff --git a/net/url_request/url_request_context.h b/net/url_request/url_request_context.h index f8a6c7d..d3ba85f 100644 --- a/net/url_request/url_request_context.h +++ b/net/url_request/url_request_context.h @@ -21,6 +21,7 @@ #include "net/socket/dns_cert_provenance_checker.h" namespace net { +class CertVerifier; class CookiePolicy; class CookieStore; class DnsCertProvenanceChecker; @@ -50,6 +51,10 @@ class URLRequestContext return host_resolver_; } + net::CertVerifier* cert_verifier() const { + return cert_verifier_; + } + net::DnsRRResolver* dnsrr_resolver() const { return dnsrr_resolver_; } @@ -130,6 +135,7 @@ class URLRequestContext // subclasses. net::NetLog* net_log_; net::HostResolver* host_resolver_; + net::CertVerifier* cert_verifier_; net::DnsRRResolver* dnsrr_resolver_; scoped_ptr<net::DnsCertProvenanceChecker> dns_cert_checker_; scoped_refptr<net::ProxyService> proxy_service_; diff --git a/net/url_request/url_request_unittest.h b/net/url_request/url_request_unittest.h index 16b4dc6..50236fd 100644 --- a/net/url_request/url_request_unittest.h +++ b/net/url_request/url_request_unittest.h @@ -20,6 +20,7 @@ #include "base/thread.h" #include "base/time.h" #include "base/utf_string_conversions.h" +#include "net/base/cert_verifier.h" #include "net/base/cookie_monster.h" #include "net/base/cookie_policy.h" #include "net/base/host_resolver.h" @@ -150,17 +151,20 @@ class TestURLRequestContext : public URLRequestContext { delete ftp_transaction_factory_; delete http_transaction_factory_; delete http_auth_handler_factory_; + delete cert_verifier_; delete host_resolver_; } private: void Init() { + cert_verifier_ = new net::CertVerifier; ftp_transaction_factory_ = new net::FtpNetworkLayer(host_resolver_); ssl_config_service_ = new net::SSLConfigServiceDefaults; http_auth_handler_factory_ = net::HttpAuthHandlerFactory::CreateDefault( host_resolver_); http_transaction_factory_ = new net::HttpCache( net::HttpNetworkLayer::CreateFactory(host_resolver_, + cert_verifier_, NULL /* dnsrr_resolver */, NULL /* dns_cert_checker */, NULL /* ssl_host_info_factory */, |