diff options
author | mattm@chromium.org <mattm@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-01-25 03:57:30 +0000 |
---|---|---|
committer | mattm@chromium.org <mattm@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-01-25 03:57:30 +0000 |
commit | 646a620d85405183eb7de2613ff3ed64223a8dc0 (patch) | |
tree | 5859a208f1ded78c16db6d824553aef2f60274ec | |
parent | e0f104c61c526fc9d0a09bc221ec393acabd3c8f (diff) | |
download | chromium_src-646a620d85405183eb7de2613ff3ed64223a8dc0.zip chromium_src-646a620d85405183eb7de2613ff3ed64223a8dc0.tar.gz chromium_src-646a620d85405183eb7de2613ff3ed64223a8dc0.tar.bz2 |
Make ServerBoundCertStore interface async, move SQLiteServerBoundCertStore load onto DB thread.
Fix chromeos::ProfileAuthData::Transfer to only transfer server bound certs when cookies are being transferred.
BUG=89665,166919
Review URL: https://chromiumcodereview.appspot.com/11742037
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@178742 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r-- | chrome/browser/browsing_data/browsing_data_remover.cc | 8 | ||||
-rw-r--r-- | chrome/browser/browsing_data/browsing_data_remover.h | 5 | ||||
-rw-r--r-- | chrome/browser/browsing_data/browsing_data_remover_unittest.cc | 14 | ||||
-rw-r--r-- | chrome/browser/browsing_data/browsing_data_server_bound_cert_helper.cc | 58 | ||||
-rw-r--r-- | chrome/browser/chromeos/login/profile_auth_data.cc | 262 | ||||
-rw-r--r-- | chrome/browser/chromeos/login/profile_auth_data.h | 6 | ||||
-rw-r--r-- | chrome/browser/net/sqlite_server_bound_cert_store.cc | 62 | ||||
-rw-r--r-- | chrome/browser/net/sqlite_server_bound_cert_store.h | 4 | ||||
-rw-r--r-- | chrome/browser/net/sqlite_server_bound_cert_store_unittest.cc | 37 | ||||
-rw-r--r-- | net/base/default_server_bound_cert_store.cc | 410 | ||||
-rw-r--r-- | net/base/default_server_bound_cert_store.h | 84 | ||||
-rw-r--r-- | net/base/default_server_bound_cert_store_unittest.cc | 331 | ||||
-rw-r--r-- | net/base/server_bound_cert_service.cc | 186 | ||||
-rw-r--r-- | net/base/server_bound_cert_service.h | 17 | ||||
-rw-r--r-- | net/base/server_bound_cert_store.h | 45 |
15 files changed, 1189 insertions, 340 deletions
diff --git a/chrome/browser/browsing_data/browsing_data_remover.cc b/chrome/browser/browsing_data/browsing_data_remover.cc index d55c7bc..0476799 100644 --- a/chrome/browser/browsing_data/browsing_data_remover.cc +++ b/chrome/browser/browsing_data/browsing_data_remover.cc @@ -991,7 +991,13 @@ void BrowsingDataRemover::ClearServerBoundCertsOnIOThread( net::ServerBoundCertService* server_bound_cert_service = rq_context->GetURLRequestContext()->server_bound_cert_service(); server_bound_cert_service->GetCertStore()->DeleteAllCreatedBetween( - delete_begin_, delete_end_); + delete_begin_, delete_end_, + base::Bind(&BrowsingDataRemover::OnClearedServerBoundCertsOnIOThread, + base::Unretained(this), base::Unretained(rq_context))); +} + +void BrowsingDataRemover::OnClearedServerBoundCertsOnIOThread( + net::URLRequestContextGetter* rq_context) { // Need to close open SSL connections which may be using the channel ids we // are deleting. // TODO(mattm): http://crbug.com/166069 Make the server bound cert diff --git a/chrome/browser/browsing_data/browsing_data_remover.h b/chrome/browser/browsing_data/browsing_data_remover.h index 7222237..9cc02b0 100644 --- a/chrome/browser/browsing_data/browsing_data_remover.h +++ b/chrome/browser/browsing_data/browsing_data_remover.h @@ -321,6 +321,11 @@ class BrowsingDataRemover : public content::NotificationObserver, void ClearServerBoundCertsOnIOThread( net::URLRequestContextGetter* rq_context); + // Callback on IO Thread when server bound certs have been deleted. Clears SSL + // connection pool and posts to UI thread to run OnClearedServerBoundCerts. + void OnClearedServerBoundCertsOnIOThread( + net::URLRequestContextGetter* rq_context); + // Callback when server bound certs have been deleted. Invokes // NotifyAndDeleteIfDone. void OnClearedServerBoundCerts(); diff --git a/chrome/browser/browsing_data/browsing_data_remover_unittest.cc b/chrome/browser/browsing_data/browsing_data_remover_unittest.cc index 9f8cede..ad98c33 100644 --- a/chrome/browser/browsing_data/browsing_data_remover_unittest.cc +++ b/chrome/browser/browsing_data/browsing_data_remover_unittest.cc @@ -297,6 +297,11 @@ class RemoveServerBoundCertTester : public net::SSLConfigService::Observer { now + base::TimeDelta::FromDays(1)); } + void GetCertList(net::ServerBoundCertStore::ServerBoundCertList* certs) { + GetCertStore()->GetAllServerBoundCerts( + base::Bind(&RemoveServerBoundCertTester::GetAllCertsCallback, certs)); + } + net::ServerBoundCertStore* GetCertStore() { return server_bound_cert_service_->GetCertStore(); } @@ -311,6 +316,12 @@ class RemoveServerBoundCertTester : public net::SSLConfigService::Observer { } private: + static void GetAllCertsCallback( + net::ServerBoundCertStore::ServerBoundCertList* dest, + const net::ServerBoundCertStore::ServerBoundCertList& result) { + *dest = result; + } + net::ServerBoundCertService* server_bound_cert_service_; scoped_refptr<net::SSLConfigService> ssl_config_service_; int ssl_config_changed_count_; @@ -711,7 +722,8 @@ TEST_F(BrowsingDataRemoverTest, RemoveServerBoundCertLastHour) { EXPECT_EQ(1, tester.ssl_config_changed_count()); ASSERT_EQ(1, tester.ServerBoundCertCount()); net::ServerBoundCertStore::ServerBoundCertList certs; - tester.GetCertStore()->GetAllServerBoundCerts(&certs); + tester.GetCertList(&certs); + ASSERT_EQ(1U, certs.size()); EXPECT_EQ(kTestOrigin2, certs.front().server_identifier()); } diff --git a/chrome/browser/browsing_data/browsing_data_server_bound_cert_helper.cc b/chrome/browser/browsing_data/browsing_data_server_bound_cert_helper.cc index 3aafec6..7cad5c3 100644 --- a/chrome/browser/browsing_data/browsing_data_server_bound_cert_helper.cc +++ b/chrome/browser/browsing_data/browsing_data_server_bound_cert_helper.cc @@ -31,17 +31,18 @@ class BrowsingDataServerBoundCertHelperImpl // Fetch the certs. This must be called in the IO thread. void FetchOnIOThread(); + void OnFetchComplete( + const net::ServerBoundCertStore::ServerBoundCertList& cert_list); + // Notifies the completion callback. This must be called in the UI thread. - void NotifyInUIThread(); + void NotifyInUIThread( + const net::ServerBoundCertStore::ServerBoundCertList& cert_list); // Delete a single cert. This must be called in IO thread. void DeleteOnIOThread(const std::string& server_id); - // Access to |server_bound_cert_list_| is triggered indirectly via the UI - // thread and guarded by |is_fetching_|. This means |server_bound_cert_list_| - // is only accessed while |is_fetching_| is true. The flag |is_fetching_| is - // only accessed on the UI thread. - net::ServerBoundCertStore::ServerBoundCertList server_bound_cert_list_; + // Called when deletion is done. + void DeleteCallback(); // Indicates whether or not we're currently fetching information: // it's true when StartFetching() is called in the UI thread, and it's reset @@ -97,20 +98,28 @@ void BrowsingDataServerBoundCertHelperImpl::FetchOnIOThread() { request_context_getter_->GetURLRequestContext()-> server_bound_cert_service()->GetCertStore(); if (cert_store) { - server_bound_cert_list_.clear(); - cert_store->GetAllServerBoundCerts(&server_bound_cert_list_); - content::BrowserThread::PostTask( - content::BrowserThread::UI, FROM_HERE, - base::Bind(&BrowsingDataServerBoundCertHelperImpl::NotifyInUIThread, - this)); + cert_store->GetAllServerBoundCerts(base::Bind( + &BrowsingDataServerBoundCertHelperImpl::OnFetchComplete, this)); + } else { + OnFetchComplete(net::ServerBoundCertStore::ServerBoundCertList()); } } -void BrowsingDataServerBoundCertHelperImpl::NotifyInUIThread() { +void BrowsingDataServerBoundCertHelperImpl::OnFetchComplete( + const net::ServerBoundCertStore::ServerBoundCertList& cert_list) { + DCHECK(content::BrowserThread::CurrentlyOn(content::BrowserThread::IO)); + content::BrowserThread::PostTask( + content::BrowserThread::UI, FROM_HERE, + base::Bind(&BrowsingDataServerBoundCertHelperImpl::NotifyInUIThread, + this, cert_list)); +} + +void BrowsingDataServerBoundCertHelperImpl::NotifyInUIThread( + const net::ServerBoundCertStore::ServerBoundCertList& cert_list) { DCHECK(content::BrowserThread::CurrentlyOn(content::BrowserThread::UI)); DCHECK(is_fetching_); is_fetching_ = false; - completion_callback_.Run(server_bound_cert_list_); + completion_callback_.Run(cert_list); completion_callback_.Reset(); } @@ -121,16 +130,23 @@ void BrowsingDataServerBoundCertHelperImpl::DeleteOnIOThread( request_context_getter_->GetURLRequestContext()-> server_bound_cert_service()->GetCertStore(); if (cert_store) { - cert_store->DeleteServerBoundCert(server_id); - // Need to close open SSL connections which may be using the channel ids we - // are deleting. - // TODO(mattm): http://crbug.com/166069 Make the server bound cert - // service/store have observers that can notify relevant things directly. - request_context_getter_->GetURLRequestContext()->ssl_config_service()-> - NotifySSLConfigChange(); + cert_store->DeleteServerBoundCert( + server_id, + base::Bind(&BrowsingDataServerBoundCertHelperImpl::DeleteCallback, + this)); } } +void BrowsingDataServerBoundCertHelperImpl::DeleteCallback() { + DCHECK(content::BrowserThread::CurrentlyOn(content::BrowserThread::IO)); + // Need to close open SSL connections which may be using the channel ids we + // are deleting. + // TODO(mattm): http://crbug.com/166069 Make the server bound cert + // service/store have observers that can notify relevant things directly. + request_context_getter_->GetURLRequestContext()->ssl_config_service()-> + NotifySSLConfigChange(); +} + } // namespace // static diff --git a/chrome/browser/chromeos/login/profile_auth_data.cc b/chrome/browser/chromeos/login/profile_auth_data.cc index ab79b03..e83d945 100644 --- a/chrome/browser/chromeos/login/profile_auth_data.cc +++ b/chrome/browser/chromeos/login/profile_auth_data.cc @@ -22,127 +22,189 @@ namespace chromeos { namespace { -// Callback for transferring |cookies_to_transfer| into |cookie_monster| if -// its jar is completely empty. -void OnTransferCookiesIfEmptyJar( - net::CookieMonster* cookie_monster, - const net::CookieList& cookies_to_transfer, - const base::Callback<void()>& cookies_transfered_callback, - const net::CookieList& cookies_in_jar) { - std::string sid; - std::string lsid; - // Transfer only if the existing cookie jar is empty. - if (!cookies_in_jar.size()) - cookie_monster->InitializeFrom(cookies_to_transfer); +class ProfileAuthDataTransferer { + public: + ProfileAuthDataTransferer( + Profile* from_profile, + Profile* to_profile, + bool transfer_cookies, + const base::Closure& completion_callback); + + void BeginTransfer(); + + private: + void BeginTransferOnIOThread(); + void MaybeDoCookieAndCertTransfer(); + void Finish(); + + void OnTransferCookiesIfEmptyJar(const net::CookieList& cookies_in_jar); + void OnGetCookiesToTransfer(const net::CookieList& cookies_to_transfer); + void RetrieveDefaultCookies(); + void OnGetServerBoundCertsToTransfer( + const net::ServerBoundCertStore::ServerBoundCertList& certs); + void RetrieveDefaultServerBoundCerts(); + void TransferDefaultAuthCache(); + + scoped_refptr<net::URLRequestContextGetter> from_context_; + scoped_refptr<net::URLRequestContextGetter> to_context_; + bool transfer_cookies_; + base::Closure completion_callback_; + + net::CookieList cookies_to_transfer_; + net::ServerBoundCertStore::ServerBoundCertList certs_to_transfer_; + + bool got_cookies_; + bool got_server_bound_certs_; +}; + +ProfileAuthDataTransferer::ProfileAuthDataTransferer( + Profile* from_profile, + Profile* to_profile, + bool transfer_cookies, + const base::Closure& completion_callback) + : from_context_(from_profile->GetRequestContext()), + to_context_(to_profile->GetRequestContext()), + transfer_cookies_(transfer_cookies), + completion_callback_(completion_callback), + got_cookies_(false), + got_server_bound_certs_(false) { +} +void ProfileAuthDataTransferer::BeginTransfer() { + DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); + // If we aren't transferring cookies, post the completion callback + // immediately. Otherwise, it will be called when both cookies and channel + // ids are finished transferring. + if (!transfer_cookies_) { + BrowserThread::PostTask(BrowserThread::UI, FROM_HERE, completion_callback_); + // Null the callback so that when Finish is called the callback won't be + // called again. + completion_callback_.Reset(); + } BrowserThread::PostTask( - BrowserThread::UI, FROM_HERE, cookies_transfered_callback); - return; + BrowserThread::IO, FROM_HERE, + base::Bind(&ProfileAuthDataTransferer::BeginTransferOnIOThread, + base::Unretained(this))); } -// Callback for receiving |cookies_to_transfer| from the authentication profile -// cookie jar. -void OnGetCookiesToTransfer( - net::CookieMonster* cookie_monster, - const base::Callback<void()>& cookies_transfered_callback, - const net::CookieList& cookies_to_transfer) { +void ProfileAuthDataTransferer::BeginTransferOnIOThread() { DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); + TransferDefaultAuthCache(); + + if (transfer_cookies_) { + RetrieveDefaultCookies(); + RetrieveDefaultServerBoundCerts(); + } else { + Finish(); + } +} + +// If both cookies and server bound certs have been retrieved, see if we need to +// do the actual transfer. +void ProfileAuthDataTransferer::MaybeDoCookieAndCertTransfer() { + DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); + if (!(got_cookies_ && got_server_bound_certs_)) + return; // Nothing to transfer over? - if (!cookies_to_transfer.size()) { - BrowserThread::PostTask( - BrowserThread::UI, FROM_HERE, cookies_transfered_callback); + if (!cookies_to_transfer_.size()) { + Finish(); return; } + // Now let's see if the target cookie monster's jar is even empty. - cookie_monster->GetAllCookiesAsync( - base::Bind(&OnTransferCookiesIfEmptyJar, - make_scoped_refptr(cookie_monster), - cookies_to_transfer, - cookies_transfered_callback)); + net::CookieStore* to_store = + to_context_->GetURLRequestContext()->cookie_store(); + net::CookieMonster* to_monster = to_store->GetCookieMonster(); + to_monster->GetAllCookiesAsync( + base::Bind(&ProfileAuthDataTransferer::OnTransferCookiesIfEmptyJar, + base::Unretained(this))); } -// Transfers initial set of Profile cookies from the |from_context| to cookie -// jar of |to_context|. -void TransferDefaultCookiesOnIOThread( - net::URLRequestContextGetter* from_context, - net::URLRequestContextGetter* to_context, - const base::Callback<void()>& cookies_transfered_callback) { +// Post the |completion_callback_| and delete ourself. +void ProfileAuthDataTransferer::Finish() { DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); - net::CookieStore* to_store = - to_context->GetURLRequestContext()->cookie_store(); - net::CookieMonster* to_monster = to_store->GetCookieMonster(); + if (!completion_callback_.is_null()) + BrowserThread::PostTask(BrowserThread::UI, FROM_HERE, completion_callback_); + delete this; +} - net::CookieStore* from_store = - from_context->GetURLRequestContext()->cookie_store(); - net::CookieMonster* from_monster = from_store->GetCookieMonster(); - from_monster->SetKeepExpiredCookies(); - from_monster->GetAllCookiesAsync(base::Bind(&OnGetCookiesToTransfer, - make_scoped_refptr(to_monster), - cookies_transfered_callback)); +// Callback for transferring |cookies_to_transfer_| into |to_context_|'s +// CookieMonster if its jar is completely empty. If authentication was +// performed by an extension, then the set of cookies that was acquired through +// such that process will be automatically transfered into the profile. +void ProfileAuthDataTransferer::OnTransferCookiesIfEmptyJar( + const net::CookieList& cookies_in_jar) { + DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); + // Transfer only if the existing cookie jar is empty. + if (!cookies_in_jar.size()) { + net::CookieStore* to_store = + to_context_->GetURLRequestContext()->cookie_store(); + net::CookieMonster* to_monster = to_store->GetCookieMonster(); + to_monster->InitializeFrom(cookies_to_transfer_); + + net::ServerBoundCertService* to_cert_service = + to_context_->GetURLRequestContext()->server_bound_cert_service(); + to_cert_service->GetCertStore()->InitializeFrom(certs_to_transfer_); + } + + Finish(); } -// Transfers default server bound certs of |from_context| to server bound certs -// storage of |to_context|. -void TransferDefaultServerBoundCertsIOThread( - net::URLRequestContextGetter* from_context, - net::URLRequestContextGetter* to_context) { +// Callback for receiving |cookies_to_transfer| from the authentication profile +// cookie jar. +void ProfileAuthDataTransferer::OnGetCookiesToTransfer( + const net::CookieList& cookies_to_transfer) { DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); - net::ServerBoundCertService* default_service = - from_context->GetURLRequestContext()->server_bound_cert_service(); - net::ServerBoundCertStore::ServerBoundCertList server_bound_certs; - default_service->GetCertStore()->GetAllServerBoundCerts(&server_bound_certs); + got_cookies_ = true; + MaybeDoCookieAndCertTransfer(); +} + +// Retrieves initial set of Profile cookies from the |from_context_|. +void ProfileAuthDataTransferer::RetrieveDefaultCookies() { + DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); - net::ServerBoundCertService* new_service = - to_context->GetURLRequestContext()->server_bound_cert_service(); - new_service->GetCertStore()->InitializeFrom(server_bound_certs); + net::CookieStore* from_store = + from_context_->GetURLRequestContext()->cookie_store(); + net::CookieMonster* from_monster = from_store->GetCookieMonster(); + from_monster->SetKeepExpiredCookies(); + from_monster->GetAllCookiesAsync( + base::Bind(&ProfileAuthDataTransferer::OnGetCookiesToTransfer, + base::Unretained(this))); } -// Transfers default auth cache of |from_context| to auth cache storage of -// |to_context|. -void TransferDefaultAuthCacheOnIOThread( - net::URLRequestContextGetter* from_context, - net::URLRequestContextGetter* to_context) { +// Callback for receiving |cookies_to_transfer| from the authentication profile +// cookie jar. +void ProfileAuthDataTransferer::OnGetServerBoundCertsToTransfer( + const net::ServerBoundCertStore::ServerBoundCertList& certs) { DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); - net::HttpAuthCache* new_cache = to_context->GetURLRequestContext()-> - http_transaction_factory()->GetSession()->http_auth_cache(); - new_cache->UpdateAllFrom(*from_context->GetURLRequestContext()-> - http_transaction_factory()->GetSession()->http_auth_cache()); + certs_to_transfer_ = certs; + got_server_bound_certs_ = true; + MaybeDoCookieAndCertTransfer(); } -// Transfers cookies and server bound certs from the |from_profile| into -// the |to_profile|. If authentication was performed by an extension, then -// the set of cookies that was acquired through such that process will be -// automatically transfered into the profile. -void TransferDefaultCookiesAndServerBoundCerts( - Profile* from_profile, - Profile* to_profile, - const base::Callback<void()>& cookies_transfered_callback) { - BrowserThread::PostTask( - BrowserThread::IO, FROM_HERE, - base::Bind(&TransferDefaultCookiesOnIOThread, - make_scoped_refptr(from_profile->GetRequestContext()), - make_scoped_refptr(to_profile->GetRequestContext()), - cookies_transfered_callback)); - BrowserThread::PostTask( - BrowserThread::IO, FROM_HERE, - base::Bind(&TransferDefaultServerBoundCertsIOThread, - make_scoped_refptr(from_profile->GetRequestContext()), - make_scoped_refptr(to_profile->GetRequestContext()))); +// Retrieves server bound certs of |from_context_|. +void ProfileAuthDataTransferer::RetrieveDefaultServerBoundCerts() { + DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); + net::ServerBoundCertService* from_service = + from_context_->GetURLRequestContext()->server_bound_cert_service(); + + from_service->GetCertStore()->GetAllServerBoundCerts( + base::Bind(&ProfileAuthDataTransferer::OnGetServerBoundCertsToTransfer, + base::Unretained(this))); } -// Transfers HTTP authentication cache from the |from_profile| -// into the |to_profile|. If user was required to authenticate with a proxy +// Transfers HTTP authentication cache from the |from_context_| +// into the |to_context_|. If user was required to authenticate with a proxy // during the login, this authentication information will be transferred // into the new session. -void TransferDefaultAuthCache(Profile* from_profile, - Profile* to_profile) { - BrowserThread::PostTask( - BrowserThread::IO, FROM_HERE, - base::Bind(&TransferDefaultAuthCacheOnIOThread, - make_scoped_refptr(from_profile->GetRequestContext()), - make_scoped_refptr(to_profile->GetRequestContext()))); +void ProfileAuthDataTransferer::TransferDefaultAuthCache() { + DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); + net::HttpAuthCache* new_cache = to_context_->GetURLRequestContext()-> + http_transaction_factory()->GetSession()->http_auth_cache(); + new_cache->UpdateAllFrom(*from_context_->GetURLRequestContext()-> + http_transaction_factory()->GetSession()->http_auth_cache()); } } // namespace @@ -151,18 +213,10 @@ void ProfileAuthData::Transfer( Profile* from_profile, Profile* to_profile, bool transfer_cookies, - const base::Callback<void()>& cookies_transfered_callback) { + const base::Closure& completion_callback) { DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); - if (transfer_cookies) { - TransferDefaultCookiesAndServerBoundCerts(from_profile, - to_profile, - cookies_transfered_callback); - } else { - BrowserThread::PostTask( - BrowserThread::UI, FROM_HERE, cookies_transfered_callback); - } - - TransferDefaultAuthCache(from_profile, to_profile); + (new ProfileAuthDataTransferer(from_profile, to_profile, transfer_cookies, + completion_callback))->BeginTransfer(); } } // namespace chromeos diff --git a/chrome/browser/chromeos/login/profile_auth_data.h b/chrome/browser/chromeos/login/profile_auth_data.h index 9cbd7f0..a7805e4 100644 --- a/chrome/browser/chromeos/login/profile_auth_data.h +++ b/chrome/browser/chromeos/login/profile_auth_data.h @@ -18,12 +18,12 @@ class ProfileAuthData { public: // Transfers proxy authentication cache and optionally |transfer_cookies| and // server bound certs from the profile that was used for authentication. - // |cookies_transfered_callback| will be called on UI thread after cookie - // transfer part of this operation is completed. + // |completion_callback| will be called on UI thread after the operation is + // completed. static void Transfer(Profile* from_profile, Profile* to_profile, bool transfer_cookies, - const base::Closure& cookies_transfered_callback); + const base::Closure& completion_callback); private: DISALLOW_IMPLICIT_CONSTRUCTORS(ProfileAuthData); diff --git a/chrome/browser/net/sqlite_server_bound_cert_store.cc b/chrome/browser/net/sqlite_server_bound_cert_store.cc index b85e939..c4a004c 100644 --- a/chrome/browser/net/sqlite_server_bound_cert_store.cc +++ b/chrome/browser/net/sqlite_server_bound_cert_store.cc @@ -41,9 +41,8 @@ class SQLiteServerBoundCertStore::Backend clear_on_exit_policy_(clear_on_exit_policy) { } - // Creates or load the SQLite database. - bool Load( - std::vector<net::DefaultServerBoundCertStore::ServerBoundCert*>* certs); + // Creates or loads the SQLite database. + void Load(const LoadedCallback& loaded_callback); // Batch a server bound cert addition. void AddServerBoundCert( @@ -63,6 +62,10 @@ class SQLiteServerBoundCertStore::Backend void SetForceKeepSessionState(); private: + void LoadOnDBThreadAndNotify(const LoadedCallback& loaded_callback); + void LoadOnDBThread( + std::vector<net::DefaultServerBoundCertStore::ServerBoundCert*>* certs); + friend class base::RefCountedThreadSafe<SQLiteServerBoundCertStore::Backend>; // You should call Close() before destructing this object. @@ -155,15 +158,36 @@ bool InitTable(sql::Connection* db) { } // namespace -bool SQLiteServerBoundCertStore::Backend::Load( - std::vector<net::DefaultServerBoundCertStore::ServerBoundCert*>* certs) { +void SQLiteServerBoundCertStore::Backend::Load( + const LoadedCallback& loaded_callback) { // This function should be called only once per instance. DCHECK(!db_.get()); - // TODO(paivanof@gmail.com): We do a lot of disk access in this function, - // thus we do an exception to allow IO on the UI thread. This code will be - // moved to the DB thread as part of http://crbug.com/89665. - base::ThreadRestrictions::ScopedAllowIO allow_io; + BrowserThread::PostTask( + BrowserThread::DB, FROM_HERE, + base::Bind(&Backend::LoadOnDBThreadAndNotify, this, loaded_callback)); +} + +void SQLiteServerBoundCertStore::Backend::LoadOnDBThreadAndNotify( + const LoadedCallback& loaded_callback) { + DCHECK(BrowserThread::CurrentlyOn(BrowserThread::DB)); + scoped_ptr<ScopedVector<net::DefaultServerBoundCertStore::ServerBoundCert> > + certs(new ScopedVector<net::DefaultServerBoundCertStore::ServerBoundCert>( + )); + + LoadOnDBThread(&certs->get()); + + BrowserThread::PostTask( + BrowserThread::IO, FROM_HERE, + base::Bind(loaded_callback, base::Passed(&certs))); +} + +void SQLiteServerBoundCertStore::Backend::LoadOnDBThread( + std::vector<net::DefaultServerBoundCertStore::ServerBoundCert*>* certs) { + DCHECK(BrowserThread::CurrentlyOn(BrowserThread::DB)); + + // This method should be called only once per instance. + DCHECK(!db_.get()); base::TimeTicks start = base::TimeTicks::Now(); @@ -171,7 +195,7 @@ bool SQLiteServerBoundCertStore::Backend::Load( // from it. const FilePath dir = path_.DirName(); if (!file_util::PathExists(dir) && !file_util::CreateDirectory(dir)) - return false; + return; int64 db_size = 0; if (file_util::GetFileSize(path_, &db_size)) @@ -181,13 +205,13 @@ bool SQLiteServerBoundCertStore::Backend::Load( if (!db_->Open(path_)) { NOTREACHED() << "Unable to open cert DB."; db_.reset(); - return false; + return; } if (!EnsureDatabaseVersion() || !InitTable(db_.get())) { NOTREACHED() << "Unable to open cert DB."; db_.reset(); - return false; + return; } db_->Preload(); @@ -198,7 +222,7 @@ bool SQLiteServerBoundCertStore::Backend::Load( "creation_time FROM origin_bound_certs")); if (!smt.is_valid()) { db_.reset(); - return false; + return; } while (smt.Step()) { @@ -218,12 +242,14 @@ bool SQLiteServerBoundCertStore::Backend::Load( } UMA_HISTOGRAM_COUNTS_10000("DomainBoundCerts.DBLoadedCount", certs->size()); + base::TimeDelta load_time = base::TimeTicks::Now() - start; UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.DBLoadTime", - base::TimeTicks::Now() - start, + load_time, base::TimeDelta::FromMilliseconds(1), base::TimeDelta::FromMinutes(1), 50); - return true; + DVLOG(1) << "loaded " << certs->size() << " in " << load_time.InMilliseconds() + << " ms"; } bool SQLiteServerBoundCertStore::Backend::EnsureDatabaseVersion() { @@ -544,9 +570,9 @@ SQLiteServerBoundCertStore::SQLiteServerBoundCertStore( : backend_(new Backend(path, clear_on_exit_policy)) { } -bool SQLiteServerBoundCertStore::Load( - std::vector<net::DefaultServerBoundCertStore::ServerBoundCert*>* certs) { - return backend_->Load(certs); +void SQLiteServerBoundCertStore::Load( + const LoadedCallback& loaded_callback) { + backend_->Load(loaded_callback); } void SQLiteServerBoundCertStore::AddServerBoundCert( diff --git a/chrome/browser/net/sqlite_server_bound_cert_store.h b/chrome/browser/net/sqlite_server_bound_cert_store.h index 91cb4b4..183923a 100644 --- a/chrome/browser/net/sqlite_server_bound_cert_store.h +++ b/chrome/browser/net/sqlite_server_bound_cert_store.h @@ -28,9 +28,7 @@ class SQLiteServerBoundCertStore ClearOnExitPolicy* clear_on_exit_policy); // net::DefaultServerBoundCertStore::PersistentStore: - virtual bool Load( - std::vector<net::DefaultServerBoundCertStore::ServerBoundCert*>* certs) - OVERRIDE; + virtual void Load(const LoadedCallback& loaded_callback) OVERRIDE; virtual void AddServerBoundCert( const net::DefaultServerBoundCertStore::ServerBoundCert& cert) OVERRIDE; virtual void DeleteServerBoundCert( diff --git a/chrome/browser/net/sqlite_server_bound_cert_store_unittest.cc b/chrome/browser/net/sqlite_server_bound_cert_store_unittest.cc index 6c62fe7..3dc1b45 100644 --- a/chrome/browser/net/sqlite_server_bound_cert_store_unittest.cc +++ b/chrome/browser/net/sqlite_server_bound_cert_store_unittest.cc @@ -8,6 +8,7 @@ #include "base/memory/ref_counted.h" #include "base/memory/scoped_vector.h" #include "base/message_loop.h" +#include "base/run_loop.h" #include "base/stl_util.h" #include "base/test/thread_test_helper.h" #include "chrome/browser/net/clear_on_exit_policy.h" @@ -25,7 +26,26 @@ using content::BrowserThread; class SQLiteServerBoundCertStoreTest : public testing::Test { public: SQLiteServerBoundCertStoreTest() - : db_thread_(BrowserThread::DB) { + : db_thread_(BrowserThread::DB), + io_thread_(BrowserThread::IO, &message_loop_) {} + + void Load( + ScopedVector<net::DefaultServerBoundCertStore::ServerBoundCert>* certs) { + base::RunLoop run_loop; + store_->Load(base::Bind(&SQLiteServerBoundCertStoreTest::OnLoaded, + base::Unretained(this), + &run_loop)); + run_loop.Run(); + certs->swap(certs_); + certs_.clear(); + } + + void OnLoaded( + base::RunLoop* run_loop, + scoped_ptr<ScopedVector< + net::DefaultServerBoundCertStore::ServerBoundCert> > certs) { + certs_.swap(*certs); + run_loop->Quit(); } protected: @@ -66,7 +86,7 @@ class SQLiteServerBoundCertStoreTest : public testing::Test { store_ = new SQLiteServerBoundCertStore( temp_dir_.path().Append(chrome::kOBCertFilename), NULL); ScopedVector<net::DefaultServerBoundCertStore::ServerBoundCert> certs; - ASSERT_TRUE(store_->Load(&certs.get())); + Load(&certs); ASSERT_EQ(0u, certs.size()); // Make sure the store gets written at least once. store_->AddServerBoundCert( @@ -78,9 +98,12 @@ class SQLiteServerBoundCertStoreTest : public testing::Test { "a", "b")); } + MessageLoopForIO message_loop_; content::TestBrowserThread db_thread_; + content::TestBrowserThread io_thread_; base::ScopedTempDir temp_dir_; scoped_refptr<SQLiteServerBoundCertStore> store_; + ScopedVector<net::DefaultServerBoundCertStore::ServerBoundCert> certs_; }; // Test if data is stored as expected in the SQLite database. @@ -107,7 +130,7 @@ TEST_F(SQLiteServerBoundCertStoreTest, TestPersistence) { temp_dir_.path().Append(chrome::kOBCertFilename), NULL); // Reload and test for persistence - ASSERT_TRUE(store_->Load(&certs.get())); + Load(&certs); ASSERT_EQ(2U, certs.size()); net::DefaultServerBoundCertStore::ServerBoundCert* ec_cert; net::DefaultServerBoundCertStore::ServerBoundCert* rsa_cert; @@ -142,7 +165,7 @@ TEST_F(SQLiteServerBoundCertStoreTest, TestPersistence) { temp_dir_.path().Append(chrome::kOBCertFilename), NULL); // Reload and check if the cert has been removed. - ASSERT_TRUE(store_->Load(&certs.get())); + Load(&certs); ASSERT_EQ(0U, certs.size()); } @@ -193,7 +216,7 @@ TEST_F(SQLiteServerBoundCertStoreTest, TestUpgradeV1) { store_ = new SQLiteServerBoundCertStore(v1_db_path, NULL); // Load the database and ensure the certs can be read and are marked as RSA. - ASSERT_TRUE(store_->Load(&certs.get())); + Load(&certs); ASSERT_EQ(2U, certs.size()); ASSERT_STREQ("google.com", certs[0]->server_identifier().c_str()); @@ -281,7 +304,7 @@ TEST_F(SQLiteServerBoundCertStoreTest, TestUpgradeV2) { store_ = new SQLiteServerBoundCertStore(v2_db_path, NULL); // Load the database and ensure the certs can be read and are marked as RSA. - ASSERT_TRUE(store_->Load(&certs.get())); + Load(&certs); ASSERT_EQ(2U, certs.size()); ASSERT_STREQ("google.com", certs[0]->server_identifier().c_str()); @@ -371,7 +394,7 @@ TEST_F(SQLiteServerBoundCertStoreTest, TestUpgradeV3) { store_ = new SQLiteServerBoundCertStore(v3_db_path, NULL); // Load the database and ensure the certs can be read and are marked as RSA. - ASSERT_TRUE(store_->Load(&certs.get())); + Load(&certs); ASSERT_EQ(2U, certs.size()); ASSERT_STREQ("google.com", certs[0]->server_identifier().c_str()); diff --git a/net/base/default_server_bound_cert_store.cc b/net/base/default_server_bound_cert_store.cc index 05cd826..d4dd1d0 100644 --- a/net/base/default_server_bound_cert_store.cc +++ b/net/base/default_server_bound_cert_store.cc @@ -6,16 +6,235 @@ #include "base/bind.h" #include "base/message_loop.h" +#include "base/metrics/histogram.h" namespace net { +// -------------------------------------------------------------------------- +// Task +class DefaultServerBoundCertStore::Task { + public: + virtual ~Task(); + + // Runs the task and invokes the client callback on the thread that + // originally constructed the task. + virtual void Run(DefaultServerBoundCertStore* store) = 0; + + protected: + void InvokeCallback(base::Closure callback) const; +}; + +DefaultServerBoundCertStore::Task::~Task() { +} + +void DefaultServerBoundCertStore::Task::InvokeCallback( + base::Closure callback) const { + if (!callback.is_null()) + callback.Run(); +} + +// -------------------------------------------------------------------------- +// GetServerBoundCertTask +class DefaultServerBoundCertStore::GetServerBoundCertTask + : public DefaultServerBoundCertStore::Task { + public: + GetServerBoundCertTask(const std::string& server_identifier, + const GetCertCallback& callback); + virtual ~GetServerBoundCertTask(); + virtual void Run(DefaultServerBoundCertStore* store) OVERRIDE; + + private: + std::string server_identifier_; + GetCertCallback callback_; +}; + +DefaultServerBoundCertStore::GetServerBoundCertTask::GetServerBoundCertTask( + const std::string& server_identifier, + const GetCertCallback& callback) + : server_identifier_(server_identifier), + callback_(callback) { +} + +DefaultServerBoundCertStore::GetServerBoundCertTask::~GetServerBoundCertTask() { +} + +void DefaultServerBoundCertStore::GetServerBoundCertTask::Run( + DefaultServerBoundCertStore* store) { + SSLClientCertType type = CLIENT_CERT_INVALID_TYPE; + base::Time expiration_time; + std::string private_key_result; + std::string cert_result; + bool was_sync = store->GetServerBoundCert( + server_identifier_, &type, &expiration_time, &private_key_result, + &cert_result, GetCertCallback()); + DCHECK(was_sync); + + InvokeCallback(base::Bind(callback_, server_identifier_, type, + expiration_time, private_key_result, cert_result)); +} + +// -------------------------------------------------------------------------- +// SetServerBoundCertTask +class DefaultServerBoundCertStore::SetServerBoundCertTask + : public DefaultServerBoundCertStore::Task { + public: + SetServerBoundCertTask(const std::string& server_identifier, + SSLClientCertType type, + base::Time creation_time, + base::Time expiration_time, + const std::string& private_key, + const std::string& cert); + virtual ~SetServerBoundCertTask(); + virtual void Run(DefaultServerBoundCertStore* store) OVERRIDE; + + private: + std::string server_identifier_; + SSLClientCertType type_; + base::Time creation_time_; + base::Time expiration_time_; + std::string private_key_; + std::string cert_; +}; + +DefaultServerBoundCertStore::SetServerBoundCertTask::SetServerBoundCertTask( + const std::string& server_identifier, + SSLClientCertType type, + base::Time creation_time, + base::Time expiration_time, + const std::string& private_key, + const std::string& cert) + : server_identifier_(server_identifier), + type_(type), + creation_time_(creation_time), + expiration_time_(expiration_time), + private_key_(private_key), + cert_(cert) { +} + +DefaultServerBoundCertStore::SetServerBoundCertTask::~SetServerBoundCertTask() { +} + +void DefaultServerBoundCertStore::SetServerBoundCertTask::Run( + DefaultServerBoundCertStore* store) { + store->SyncSetServerBoundCert(server_identifier_, type_, creation_time_, + expiration_time_, private_key_, cert_); +} + +// -------------------------------------------------------------------------- +// DeleteServerBoundCertTask +class DefaultServerBoundCertStore::DeleteServerBoundCertTask + : public DefaultServerBoundCertStore::Task { + public: + DeleteServerBoundCertTask(const std::string& server_identifier, + const base::Closure& callback); + virtual ~DeleteServerBoundCertTask(); + virtual void Run(DefaultServerBoundCertStore* store) OVERRIDE; + + private: + std::string server_identifier_; + base::Closure callback_; +}; + +DefaultServerBoundCertStore::DeleteServerBoundCertTask:: + DeleteServerBoundCertTask( + const std::string& server_identifier, + const base::Closure& callback) + : server_identifier_(server_identifier), + callback_(callback) { +} + +DefaultServerBoundCertStore::DeleteServerBoundCertTask:: + ~DeleteServerBoundCertTask() { +} + +void DefaultServerBoundCertStore::DeleteServerBoundCertTask::Run( + DefaultServerBoundCertStore* store) { + store->SyncDeleteServerBoundCert(server_identifier_); + + InvokeCallback(callback_); +} + +// -------------------------------------------------------------------------- +// DeleteAllCreatedBetweenTask +class DefaultServerBoundCertStore::DeleteAllCreatedBetweenTask + : public DefaultServerBoundCertStore::Task { + public: + DeleteAllCreatedBetweenTask(base::Time delete_begin, + base::Time delete_end, + const base::Closure& callback); + virtual ~DeleteAllCreatedBetweenTask(); + virtual void Run(DefaultServerBoundCertStore* store) OVERRIDE; + + private: + base::Time delete_begin_; + base::Time delete_end_; + base::Closure callback_; +}; + +DefaultServerBoundCertStore::DeleteAllCreatedBetweenTask:: + DeleteAllCreatedBetweenTask( + base::Time delete_begin, + base::Time delete_end, + const base::Closure& callback) + : delete_begin_(delete_begin), + delete_end_(delete_end), + callback_(callback) { +} + +DefaultServerBoundCertStore::DeleteAllCreatedBetweenTask:: + ~DeleteAllCreatedBetweenTask() { +} + +void DefaultServerBoundCertStore::DeleteAllCreatedBetweenTask::Run( + DefaultServerBoundCertStore* store) { + store->SyncDeleteAllCreatedBetween(delete_begin_, delete_end_); + + InvokeCallback(callback_); +} + +// -------------------------------------------------------------------------- +// GetAllServerBoundCertsTask +class DefaultServerBoundCertStore::GetAllServerBoundCertsTask + : public DefaultServerBoundCertStore::Task { + public: + explicit GetAllServerBoundCertsTask(const GetCertListCallback& callback); + virtual ~GetAllServerBoundCertsTask(); + virtual void Run(DefaultServerBoundCertStore* store) OVERRIDE; + + private: + std::string server_identifier_; + GetCertListCallback callback_; +}; + +DefaultServerBoundCertStore::GetAllServerBoundCertsTask:: + GetAllServerBoundCertsTask(const GetCertListCallback& callback) + : callback_(callback) { +} + +DefaultServerBoundCertStore::GetAllServerBoundCertsTask:: + ~GetAllServerBoundCertsTask() { +} + +void DefaultServerBoundCertStore::GetAllServerBoundCertsTask::Run( + DefaultServerBoundCertStore* store) { + ServerBoundCertList cert_list; + store->SyncGetAllServerBoundCerts(&cert_list); + + InvokeCallback(base::Bind(callback_, cert_list)); +} + +// -------------------------------------------------------------------------- +// DefaultServerBoundCertStore + // static const size_t DefaultServerBoundCertStore::kMaxCerts = 3300; DefaultServerBoundCertStore::DefaultServerBoundCertStore( PersistentStore* store) : initialized_(false), - store_(store) {} + loaded_(false), + store_(store), + ALLOW_THIS_IN_INITIALIZER_LIST(weak_ptr_factory_(this)) {} void DefaultServerBoundCertStore::FlushStore( const base::Closure& completion_task) { @@ -30,21 +249,28 @@ void DefaultServerBoundCertStore::FlushStore( bool DefaultServerBoundCertStore::GetServerBoundCert( const std::string& server_identifier, SSLClientCertType* type, - base::Time* creation_time, base::Time* expiration_time, std::string* private_key_result, - std::string* cert_result) { + std::string* cert_result, + const GetCertCallback& callback) { DCHECK(CalledOnValidThread()); InitIfNecessary(); + if (!loaded_) { + EnqueueTask(scoped_ptr<Task>( + new GetServerBoundCertTask(server_identifier, callback))); + return false; + } + ServerBoundCertMap::iterator it = server_bound_certs_.find(server_identifier); - if (it == server_bound_certs_.end()) - return false; + if (it == server_bound_certs_.end()) { + *type = CLIENT_CERT_INVALID_TYPE; + return true; + } ServerBoundCert* cert = it->second; *type = cert->type(); - *creation_time = cert->creation_time(); *expiration_time = cert->expiration_time(); *private_key_result = cert->private_key(); *cert_result = cert->cert(); @@ -59,61 +285,38 @@ void DefaultServerBoundCertStore::SetServerBoundCert( base::Time expiration_time, const std::string& private_key, const std::string& cert) { - DCHECK(CalledOnValidThread()); - InitIfNecessary(); - - InternalDeleteServerBoundCert(server_identifier); - InternalInsertServerBoundCert( - server_identifier, - new ServerBoundCert( - server_identifier, type, creation_time, expiration_time, private_key, - cert)); + RunOrEnqueueTask(scoped_ptr<Task>(new SetServerBoundCertTask( + server_identifier, type, creation_time, expiration_time, private_key, + cert))); } void DefaultServerBoundCertStore::DeleteServerBoundCert( - const std::string& server_identifier) { - DCHECK(CalledOnValidThread()); - InitIfNecessary(); - InternalDeleteServerBoundCert(server_identifier); + const std::string& server_identifier, + const base::Closure& callback) { + RunOrEnqueueTask(scoped_ptr<Task>( + new DeleteServerBoundCertTask(server_identifier, callback))); } void DefaultServerBoundCertStore::DeleteAllCreatedBetween( base::Time delete_begin, - base::Time delete_end) { - DCHECK(CalledOnValidThread()); - InitIfNecessary(); - for (ServerBoundCertMap::iterator it = server_bound_certs_.begin(); - it != server_bound_certs_.end();) { - ServerBoundCertMap::iterator cur = it; - ++it; - ServerBoundCert* cert = cur->second; - if ((delete_begin.is_null() || cert->creation_time() >= delete_begin) && - (delete_end.is_null() || cert->creation_time() < delete_end)) { - if (store_) - store_->DeleteServerBoundCert(*cert); - delete cert; - server_bound_certs_.erase(cur); - } - } + base::Time delete_end, + const base::Closure& callback) { + RunOrEnqueueTask(scoped_ptr<Task>( + new DeleteAllCreatedBetweenTask(delete_begin, delete_end, callback))); } -void DefaultServerBoundCertStore::DeleteAll() { - DeleteAllCreatedBetween(base::Time(), base::Time()); +void DefaultServerBoundCertStore::DeleteAll( + const base::Closure& callback) { + DeleteAllCreatedBetween(base::Time(), base::Time(), callback); } void DefaultServerBoundCertStore::GetAllServerBoundCerts( - ServerBoundCertList* server_bound_certs) { - DCHECK(CalledOnValidThread()); - InitIfNecessary(); - for (ServerBoundCertMap::iterator it = server_bound_certs_.begin(); - it != server_bound_certs_.end(); ++it) { - server_bound_certs->push_back(*it->second); - } + const GetCertListCallback& callback) { + RunOrEnqueueTask(scoped_ptr<Task>(new GetAllServerBoundCertsTask(callback))); } int DefaultServerBoundCertStore::GetCertCount() { DCHECK(CalledOnValidThread()); - InitIfNecessary(); return server_bound_certs_.size(); } @@ -143,23 +346,123 @@ void DefaultServerBoundCertStore::DeleteAllInMemory() { void DefaultServerBoundCertStore::InitStore() { DCHECK(CalledOnValidThread()); DCHECK(store_) << "Store must exist to initialize"; + DCHECK(!loaded_); + + store_->Load(base::Bind(&DefaultServerBoundCertStore::OnLoaded, + weak_ptr_factory_.GetWeakPtr())); +} - // Initialize the store and sync in any saved persistent certs. - std::vector<ServerBoundCert*> certs; - // Reserve space for the maximum amount of certs a database should have. - // This prevents multiple vector growth / copies as we append certs. - certs.reserve(kMaxCerts); - store_->Load(&certs); +void DefaultServerBoundCertStore::OnLoaded( + scoped_ptr<ScopedVector<ServerBoundCert> > certs) { + DCHECK(CalledOnValidThread()); - for (std::vector<ServerBoundCert*>::const_iterator it = certs.begin(); - it != certs.end(); ++it) { + for (std::vector<ServerBoundCert*>::const_iterator it = certs->begin(); + it != certs->end(); ++it) { + DCHECK(server_bound_certs_.find((*it)->server_identifier()) == + server_bound_certs_.end()); server_bound_certs_[(*it)->server_identifier()] = *it; } + certs->weak_clear(); + + loaded_ = true; + + base::TimeDelta wait_time; + if (!waiting_tasks_.empty()) + wait_time = base::TimeTicks::Now() - waiting_tasks_start_time_; + DVLOG(1) << "Task delay " << wait_time.InMilliseconds(); + UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.TaskMaxWaitTime", + wait_time, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(1), + 50); + UMA_HISTOGRAM_COUNTS_100("DomainBoundCerts.TaskWaitCount", + waiting_tasks_.size()); + + + for (ScopedVector<Task>::iterator i = waiting_tasks_.begin(); + i != waiting_tasks_.end(); ++i) + (*i)->Run(this); + waiting_tasks_.clear(); +} + +void DefaultServerBoundCertStore::SyncSetServerBoundCert( + const std::string& server_identifier, + SSLClientCertType type, + base::Time creation_time, + base::Time expiration_time, + const std::string& private_key, + const std::string& cert) { + DCHECK(CalledOnValidThread()); + DCHECK(loaded_); + + InternalDeleteServerBoundCert(server_identifier); + InternalInsertServerBoundCert( + server_identifier, + new ServerBoundCert( + server_identifier, type, creation_time, expiration_time, private_key, + cert)); +} + +void DefaultServerBoundCertStore::SyncDeleteServerBoundCert( + const std::string& server_identifier) { + DCHECK(CalledOnValidThread()); + DCHECK(loaded_); + InternalDeleteServerBoundCert(server_identifier); +} + +void DefaultServerBoundCertStore::SyncDeleteAllCreatedBetween( + base::Time delete_begin, + base::Time delete_end) { + DCHECK(CalledOnValidThread()); + DCHECK(loaded_); + for (ServerBoundCertMap::iterator it = server_bound_certs_.begin(); + it != server_bound_certs_.end();) { + ServerBoundCertMap::iterator cur = it; + ++it; + ServerBoundCert* cert = cur->second; + if ((delete_begin.is_null() || cert->creation_time() >= delete_begin) && + (delete_end.is_null() || cert->creation_time() < delete_end)) { + if (store_) + store_->DeleteServerBoundCert(*cert); + delete cert; + server_bound_certs_.erase(cur); + } + } +} + +void DefaultServerBoundCertStore::SyncGetAllServerBoundCerts( + ServerBoundCertList* cert_list) { + DCHECK(CalledOnValidThread()); + DCHECK(loaded_); + for (ServerBoundCertMap::iterator it = server_bound_certs_.begin(); + it != server_bound_certs_.end(); ++it) + cert_list->push_back(*it->second); +} + +void DefaultServerBoundCertStore::EnqueueTask(scoped_ptr<Task> task) { + DCHECK(CalledOnValidThread()); + DCHECK(!loaded_); + if (waiting_tasks_.empty()) + waiting_tasks_start_time_ = base::TimeTicks::Now(); + waiting_tasks_.push_back(task.release()); +} + +void DefaultServerBoundCertStore::RunOrEnqueueTask(scoped_ptr<Task> task) { + DCHECK(CalledOnValidThread()); + InitIfNecessary(); + + if (!loaded_) { + EnqueueTask(task.Pass()); + return; + } + + task->Run(this); } void DefaultServerBoundCertStore::InternalDeleteServerBoundCert( const std::string& server_identifier) { DCHECK(CalledOnValidThread()); + DCHECK(loaded_); ServerBoundCertMap::iterator it = server_bound_certs_.find(server_identifier); if (it == server_bound_certs_.end()) @@ -176,6 +479,7 @@ void DefaultServerBoundCertStore::InternalInsertServerBoundCert( const std::string& server_identifier, ServerBoundCert* cert) { DCHECK(CalledOnValidThread()); + DCHECK(loaded_); if (store_) store_->AddServerBoundCert(*cert); diff --git a/net/base/default_server_bound_cert_store.h b/net/base/default_server_bound_cert_store.h index 48bf784..33b63ce 100644 --- a/net/base/default_server_bound_cert_store.h +++ b/net/base/default_server_bound_cert_store.h @@ -12,11 +12,12 @@ #include "base/callback_forward.h" #include "base/compiler_specific.h" #include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/scoped_vector.h" +#include "base/memory/weak_ptr.h" #include "net/base/net_export.h" #include "net/base/server_bound_cert_store.h" -class Task; - namespace net { // This class is the system for storing and retrieving server bound certs. @@ -52,10 +53,10 @@ class NET_EXPORT DefaultServerBoundCertStore : public ServerBoundCertStore { virtual bool GetServerBoundCert( const std::string& server_identifier, SSLClientCertType* type, - base::Time* creation_time, base::Time* expiration_time, std::string* private_key_result, - std::string* cert_result) OVERRIDE; + std::string* cert_result, + const GetCertCallback& callback) OVERRIDE; virtual void SetServerBoundCert( const std::string& server_identifier, SSLClientCertType type, @@ -63,29 +64,44 @@ class NET_EXPORT DefaultServerBoundCertStore : public ServerBoundCertStore { base::Time expiration_time, const std::string& private_key, const std::string& cert) OVERRIDE; - virtual void DeleteServerBoundCert(const std::string& server_identifier) - OVERRIDE; - virtual void DeleteAllCreatedBetween(base::Time delete_begin, - base::Time delete_end) OVERRIDE; - virtual void DeleteAll() OVERRIDE; + virtual void DeleteServerBoundCert( + const std::string& server_identifier, + const base::Closure& callback) OVERRIDE; + virtual void DeleteAllCreatedBetween( + base::Time delete_begin, + base::Time delete_end, + const base::Closure& callback) OVERRIDE; + virtual void DeleteAll(const base::Closure& callback) OVERRIDE; virtual void GetAllServerBoundCerts( - ServerBoundCertList* server_bound_certs) OVERRIDE; + const GetCertListCallback& callback) OVERRIDE; virtual int GetCertCount() OVERRIDE; virtual void SetForceKeepSessionState() OVERRIDE; private: + class Task; + class GetServerBoundCertTask; + class SetServerBoundCertTask; + class DeleteServerBoundCertTask; + class DeleteAllCreatedBetweenTask; + class GetAllServerBoundCertsTask; + static const size_t kMaxCerts; // Deletes all of the certs. Does not delete them from |store_|. void DeleteAllInMemory(); // Called by all non-static functions to ensure that the cert store has - // been initialized. This is not done during creating so it doesn't block - // the window showing. + // been initialized. + // TODO(mattm): since we load asynchronously now, maybe we should start + // loading immediately on construction, or provide some method to initiate + // loading? void InitIfNecessary() { if (!initialized_) { - if (store_) + if (store_) { InitStore(); + } else { + loaded_ = true; + } initialized_ = true; } } @@ -94,6 +110,29 @@ class NET_EXPORT DefaultServerBoundCertStore : public ServerBoundCertStore { // Should only be called by InitIfNecessary(). void InitStore(); + // Callback for backing store loading completion. + void OnLoaded(scoped_ptr<ScopedVector<ServerBoundCert> > certs); + + // Syncronous methods which do the actual work. Can only be called after + // initialization is complete. + void SyncSetServerBoundCert( + const std::string& server_identifier, + SSLClientCertType type, + base::Time creation_time, + base::Time expiration_time, + const std::string& private_key, + const std::string& cert); + void SyncDeleteServerBoundCert(const std::string& server_identifier); + void SyncDeleteAllCreatedBetween(base::Time delete_begin, + base::Time delete_end); + void SyncGetAllServerBoundCerts(ServerBoundCertList* cert_list); + + // Add |task| to |waiting_tasks_|. + void EnqueueTask(scoped_ptr<Task> task); + // If already initialized, run |task| immediately. Otherwise add it to + // |waiting_tasks_|. + void RunOrEnqueueTask(scoped_ptr<Task> task); + // Deletes the cert for the specified server, if such a cert exists, from the // in-memory store. Deletes it from |store_| if |store_| is not NULL. void InternalDeleteServerBoundCert(const std::string& server); @@ -105,13 +144,23 @@ class NET_EXPORT DefaultServerBoundCertStore : public ServerBoundCertStore { ServerBoundCert* cert); // Indicates whether the cert store has been initialized. This happens - // Lazily in InitStoreIfNecessary(). + // lazily in InitIfNecessary(). bool initialized_; + // Indicates whether loading from the backend store is completed and + // calls may be immediately processed. + bool loaded_; + + // Tasks that are waiting to be run once we finish loading. + ScopedVector<Task> waiting_tasks_; + base::TimeTicks waiting_tasks_start_time_; + scoped_refptr<PersistentStore> store_; ServerBoundCertMap server_bound_certs_; + base::WeakPtrFactory<DefaultServerBoundCertStore> weak_ptr_factory_; + DISALLOW_COPY_AND_ASSIGN(DefaultServerBoundCertStore); }; @@ -121,11 +170,14 @@ typedef base::RefCountedThreadSafe<DefaultServerBoundCertStore::PersistentStore> class NET_EXPORT DefaultServerBoundCertStore::PersistentStore : public RefcountedPersistentStore { public: + typedef base::Callback<void(scoped_ptr<ScopedVector<ServerBoundCert> >)> + LoadedCallback; + // Initializes the store and retrieves the existing certs. This will be // called only once at startup. Note that the certs are individually allocated // and that ownership is transferred to the caller upon return. - virtual bool Load( - std::vector<ServerBoundCert*>* certs) = 0; + // The |loaded_callback| must not be called synchronously. + virtual void Load(const LoadedCallback& loaded_callback) = 0; virtual void AddServerBoundCert(const ServerBoundCert& cert) = 0; diff --git a/net/base/default_server_bound_cert_store_unittest.cc b/net/base/default_server_bound_cert_store_unittest.cc index bc95398..df6009f 100644 --- a/net/base/default_server_bound_cert_store_unittest.cc +++ b/net/base/default_server_bound_cert_store_unittest.cc @@ -12,19 +12,67 @@ #include "base/compiler_specific.h" #include "base/logging.h" #include "base/memory/scoped_ptr.h" +#include "base/message_loop.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { +namespace { + +void CallCounter(int* counter) { + (*counter)++; +} + +void NotCalled() { + ADD_FAILURE() << "Unexpected callback execution."; +} + +void GetCertCallbackNotCalled(const std::string& server_identifier, + SSLClientCertType type, + base::Time expiration_time, + const std::string& private_key_result, + const std::string& cert_result) { + ADD_FAILURE() << "Unexpected callback execution."; +} + +class AsyncGetCertHelper { + public: + AsyncGetCertHelper() : called_(false) {} + + void Callback(const std::string& server_identifier, + SSLClientCertType type, + base::Time expiration_time, + const std::string& private_key_result, + const std::string& cert_result) { + server_identifier_ = server_identifier; + type_ = type; + expiration_time_ = expiration_time; + private_key_ = private_key_result; + cert_ = cert_result; + called_ = true; + } + + std::string server_identifier_; + SSLClientCertType type_; + base::Time expiration_time_; + std::string private_key_; + std::string cert_; + bool called_; +}; + +void GetAllCallback( + ServerBoundCertStore::ServerBoundCertList* dest, + const ServerBoundCertStore::ServerBoundCertList& result) { + *dest = result; +} + class MockPersistentStore : public DefaultServerBoundCertStore::PersistentStore { public: MockPersistentStore(); // DefaultServerBoundCertStore::PersistentStore implementation. - virtual bool Load( - std::vector<DefaultServerBoundCertStore::ServerBoundCert*>* certs) - OVERRIDE; + virtual void Load(const LoadedCallback& loaded_callback) OVERRIDE; virtual void AddServerBoundCert( const DefaultServerBoundCertStore::ServerBoundCert& cert) OVERRIDE; virtual void DeleteServerBoundCert( @@ -44,8 +92,9 @@ class MockPersistentStore MockPersistentStore::MockPersistentStore() {} -bool MockPersistentStore::Load( - std::vector<DefaultServerBoundCertStore::ServerBoundCert*>* certs) { +void MockPersistentStore::Load(const LoadedCallback& loaded_callback) { + scoped_ptr<ScopedVector<DefaultServerBoundCertStore::ServerBoundCert> > + certs(new ScopedVector<DefaultServerBoundCertStore::ServerBoundCert>()); ServerBoundCertMap::iterator it; for (it = origin_certs_.begin(); it != origin_certs_.end(); ++it) { @@ -53,7 +102,8 @@ bool MockPersistentStore::Load( new DefaultServerBoundCertStore::ServerBoundCert(it->second)); } - return true; + MessageLoop::current()->PostTask( + FROM_HERE, base::Bind(loaded_callback, base::Passed(&certs))); } void MockPersistentStore::AddServerBoundCert( @@ -74,6 +124,8 @@ void MockPersistentStore::Flush(const base::Closure& completion_task) { MockPersistentStore::~MockPersistentStore() {} +} // namespace + TEST(DefaultServerBoundCertStoreTest, TestLoading) { scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); @@ -94,13 +146,16 @@ TEST(DefaultServerBoundCertStoreTest, TestLoading) { // Make sure certs load properly. DefaultServerBoundCertStore store(persistent_store.get()); - EXPECT_EQ(2, store.GetCertCount()); + // Load has not occurred yet. + EXPECT_EQ(0, store.GetCertCount()); store.SetServerBoundCert( "verisign.com", CLIENT_CERT_RSA_SIGN, base::Time(), base::Time(), "e", "f"); + // Wait for load & queued set task. + MessageLoop::current()->RunUntilIdle(); EXPECT_EQ(2, store.GetCertCount()); store.SetServerBoundCert( "twitter.com", @@ -108,22 +163,25 @@ TEST(DefaultServerBoundCertStoreTest, TestLoading) { base::Time(), base::Time(), "g", "h"); + // Set should be synchronous now that load is done. EXPECT_EQ(3, store.GetCertCount()); } +//TODO(mattm): add more tests of without a persistent store? TEST(DefaultServerBoundCertStoreTest, TestSettingAndGetting) { + // No persistent store, all calls will be synchronous. DefaultServerBoundCertStore store(NULL); SSLClientCertType type; - base::Time creation_time; base::Time expiration_time; std::string private_key, cert; EXPECT_EQ(0, store.GetCertCount()); - EXPECT_FALSE(store.GetServerBoundCert("verisign.com", - &type, - &creation_time, - &expiration_time, - &private_key, - &cert)); + EXPECT_TRUE(store.GetServerBoundCert("verisign.com", + &type, + &expiration_time, + &private_key, + &cert, + base::Bind(&GetCertCallbackNotCalled))); + EXPECT_EQ(CLIENT_CERT_INVALID_TYPE, type); EXPECT_TRUE(private_key.empty()); EXPECT_TRUE(cert.empty()); store.SetServerBoundCert( @@ -134,12 +192,11 @@ TEST(DefaultServerBoundCertStoreTest, TestSettingAndGetting) { "i", "j"); EXPECT_TRUE(store.GetServerBoundCert("verisign.com", &type, - &creation_time, &expiration_time, &private_key, - &cert)); + &cert, + base::Bind(&GetCertCallbackNotCalled))); EXPECT_EQ(CLIENT_CERT_RSA_SIGN, type); - EXPECT_EQ(123, creation_time.ToInternalValue()); EXPECT_EQ(456, expiration_time.ToInternalValue()); EXPECT_EQ("i", private_key); EXPECT_EQ("j", cert); @@ -150,7 +207,6 @@ TEST(DefaultServerBoundCertStoreTest, TestDuplicateCerts) { DefaultServerBoundCertStore store(persistent_store.get()); SSLClientCertType type; - base::Time creation_time; base::Time expiration_time; std::string private_key, cert; EXPECT_EQ(0, store.GetCertCount()); @@ -167,25 +223,57 @@ TEST(DefaultServerBoundCertStoreTest, TestDuplicateCerts) { base::Time::FromInternalValue(4567), "c", "d"); + // Wait for load & queued set tasks. + MessageLoop::current()->RunUntilIdle(); EXPECT_EQ(1, store.GetCertCount()); EXPECT_TRUE(store.GetServerBoundCert("verisign.com", &type, - &creation_time, &expiration_time, &private_key, - &cert)); + &cert, + base::Bind(&GetCertCallbackNotCalled))); EXPECT_EQ(CLIENT_CERT_ECDSA_SIGN, type); - EXPECT_EQ(456, creation_time.ToInternalValue()); EXPECT_EQ(4567, expiration_time.ToInternalValue()); EXPECT_EQ("c", private_key); EXPECT_EQ("d", cert); } +TEST(DefaultServerBoundCertStoreTest, TestAsyncGet) { + scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); + persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( + "verisign.com", + CLIENT_CERT_RSA_SIGN, + base::Time::FromInternalValue(123), + base::Time::FromInternalValue(1234), + "a", "b")); + + DefaultServerBoundCertStore store(persistent_store.get()); + AsyncGetCertHelper helper; + SSLClientCertType type; + base::Time expiration_time; + std::string private_key; + std::string cert = "not set"; + EXPECT_EQ(0, store.GetCertCount()); + EXPECT_FALSE(store.GetServerBoundCert( + "verisign.com", &type, &expiration_time, &private_key, &cert, + base::Bind(&AsyncGetCertHelper::Callback, base::Unretained(&helper)))); + + // Wait for load & queued get tasks. + MessageLoop::current()->RunUntilIdle(); + EXPECT_EQ(1, store.GetCertCount()); + EXPECT_EQ("not set", cert); + EXPECT_TRUE(helper.called_); + EXPECT_EQ("verisign.com", helper.server_identifier_); + EXPECT_EQ(CLIENT_CERT_RSA_SIGN, helper.type_); + EXPECT_EQ(1234, helper.expiration_time_.ToInternalValue()); + EXPECT_EQ("a", helper.private_key_); + EXPECT_EQ("b", helper.cert_); +} + TEST(DefaultServerBoundCertStoreTest, TestDeleteAll) { scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); DefaultServerBoundCertStore store(persistent_store.get()); - EXPECT_EQ(0, store.GetCertCount()); store.SetServerBoundCert( "verisign.com", CLIENT_CERT_RSA_SIGN, @@ -204,18 +292,53 @@ TEST(DefaultServerBoundCertStoreTest, TestDeleteAll) { base::Time(), base::Time(), "e", "f"); + // Wait for load & queued set tasks. + MessageLoop::current()->RunUntilIdle(); EXPECT_EQ(3, store.GetCertCount()); - store.DeleteAll(); + int delete_finished = 0; + store.DeleteAll(base::Bind(&CallCounter, &delete_finished)); + ASSERT_EQ(1, delete_finished); EXPECT_EQ(0, store.GetCertCount()); } +TEST(DefaultServerBoundCertStoreTest, TestAsyncGetAndDeleteAll) { + scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); + persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( + "verisign.com", + CLIENT_CERT_RSA_SIGN, + base::Time(), + base::Time(), + "a", "b")); + persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( + "google.com", + CLIENT_CERT_RSA_SIGN, + base::Time(), + base::Time(), + "c", "d")); + + ServerBoundCertStore::ServerBoundCertList pre_certs; + ServerBoundCertStore::ServerBoundCertList post_certs; + int delete_finished = 0; + DefaultServerBoundCertStore store(persistent_store.get()); + + store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &pre_certs)); + store.DeleteAll(base::Bind(&CallCounter, &delete_finished)); + store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &post_certs)); + // Tasks have not run yet. + EXPECT_EQ(0u, pre_certs.size()); + // Wait for load & queued tasks. + MessageLoop::current()->RunUntilIdle(); + EXPECT_EQ(0, store.GetCertCount()); + EXPECT_EQ(2u, pre_certs.size()); + EXPECT_EQ(0u, post_certs.size()); +} + TEST(DefaultServerBoundCertStoreTest, TestDelete) { scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); DefaultServerBoundCertStore store(persistent_store.get()); SSLClientCertType type; - base::Time creation_time; base::Time expiration_time; std::string private_key, cert; EXPECT_EQ(0, store.GetCertCount()); @@ -225,6 +348,9 @@ TEST(DefaultServerBoundCertStoreTest, TestDelete) { base::Time(), base::Time(), "a", "b"); + // Wait for load & queued set task. + MessageLoop::current()->RunUntilIdle(); + store.SetServerBoundCert( "google.com", CLIENT_CERT_ECDSA_SIGN, @@ -233,28 +359,92 @@ TEST(DefaultServerBoundCertStoreTest, TestDelete) { "c", "d"); EXPECT_EQ(2, store.GetCertCount()); - store.DeleteServerBoundCert("verisign.com"); + int delete_finished = 0; + store.DeleteServerBoundCert("verisign.com", + base::Bind(&CallCounter, &delete_finished)); + ASSERT_EQ(1, delete_finished); EXPECT_EQ(1, store.GetCertCount()); - EXPECT_FALSE(store.GetServerBoundCert("verisign.com", - &type, - &creation_time, - &expiration_time, - &private_key, - &cert)); + EXPECT_TRUE(store.GetServerBoundCert("verisign.com", + &type, + &expiration_time, + &private_key, + &cert, + base::Bind(&GetCertCallbackNotCalled))); + EXPECT_EQ(CLIENT_CERT_INVALID_TYPE, type); EXPECT_TRUE(store.GetServerBoundCert("google.com", &type, - &creation_time, &expiration_time, &private_key, - &cert)); - store.DeleteServerBoundCert("google.com"); + &cert, + base::Bind(&GetCertCallbackNotCalled))); + EXPECT_EQ(CLIENT_CERT_ECDSA_SIGN, type); + int delete2_finished = 0; + store.DeleteServerBoundCert("google.com", + base::Bind(&CallCounter, &delete2_finished)); + ASSERT_EQ(1, delete2_finished); EXPECT_EQ(0, store.GetCertCount()); - EXPECT_FALSE(store.GetServerBoundCert("google.com", - &type, - &creation_time, - &expiration_time, - &private_key, - &cert)); + EXPECT_TRUE(store.GetServerBoundCert("google.com", + &type, + &expiration_time, + &private_key, + &cert, + base::Bind(&GetCertCallbackNotCalled))); + EXPECT_EQ(CLIENT_CERT_INVALID_TYPE, type); +} + +TEST(DefaultServerBoundCertStoreTest, TestAsyncDelete) { + scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); + persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( + "a.com", + CLIENT_CERT_RSA_SIGN, + base::Time::FromInternalValue(1), + base::Time::FromInternalValue(2), + "a", "b")); + persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( + "b.com", + CLIENT_CERT_RSA_SIGN, + base::Time::FromInternalValue(3), + base::Time::FromInternalValue(4), + "c", "d")); + DefaultServerBoundCertStore store(persistent_store.get()); + int delete_finished = 0; + store.DeleteServerBoundCert("a.com", + base::Bind(&CallCounter, &delete_finished)); + + AsyncGetCertHelper a_helper; + AsyncGetCertHelper b_helper; + SSLClientCertType type; + base::Time expiration_time; + std::string private_key; + std::string cert = "not set"; + EXPECT_EQ(0, store.GetCertCount()); + EXPECT_FALSE(store.GetServerBoundCert( + "a.com", &type, &expiration_time, &private_key, &cert, + base::Bind(&AsyncGetCertHelper::Callback, base::Unretained(&a_helper)))); + EXPECT_FALSE(store.GetServerBoundCert( + "b.com", &type, &expiration_time, &private_key, &cert, + base::Bind(&AsyncGetCertHelper::Callback, base::Unretained(&b_helper)))); + + EXPECT_EQ(0, delete_finished); + EXPECT_FALSE(a_helper.called_); + EXPECT_FALSE(b_helper.called_); + // Wait for load & queued tasks. + MessageLoop::current()->RunUntilIdle(); + EXPECT_EQ(1, delete_finished); + EXPECT_EQ(1, store.GetCertCount()); + EXPECT_EQ("not set", cert); + EXPECT_TRUE(a_helper.called_); + EXPECT_EQ("a.com", a_helper.server_identifier_); + EXPECT_EQ(CLIENT_CERT_INVALID_TYPE, a_helper.type_); + EXPECT_EQ(0, a_helper.expiration_time_.ToInternalValue()); + EXPECT_EQ("", a_helper.private_key_); + EXPECT_EQ("", a_helper.cert_); + EXPECT_TRUE(b_helper.called_); + EXPECT_EQ("b.com", b_helper.server_identifier_); + EXPECT_EQ(CLIENT_CERT_RSA_SIGN, b_helper.type_); + EXPECT_EQ(4, b_helper.expiration_time_.ToInternalValue()); + EXPECT_EQ("c", b_helper.private_key_); + EXPECT_EQ("d", b_helper.cert_); } TEST(DefaultServerBoundCertStoreTest, TestGetAll) { @@ -286,10 +476,12 @@ TEST(DefaultServerBoundCertStoreTest, TestGetAll) { base::Time(), base::Time(), "g", "h"); + // Wait for load & queued set tasks. + MessageLoop::current()->RunUntilIdle(); EXPECT_EQ(4, store.GetCertCount()); ServerBoundCertStore::ServerBoundCertList certs; - store.GetAllServerBoundCerts(&certs); + store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &certs)); EXPECT_EQ(4u, certs.size()); } @@ -309,6 +501,8 @@ TEST(DefaultServerBoundCertStoreTest, TestInitializeFrom) { base::Time(), base::Time(), "c", "d"); + // Wait for load & queued set tasks. + MessageLoop::current()->RunUntilIdle(); EXPECT_EQ(2, store.GetCertCount()); ServerBoundCertStore::ServerBoundCertList source_certs; @@ -329,7 +523,60 @@ TEST(DefaultServerBoundCertStoreTest, TestInitializeFrom) { EXPECT_EQ(3, store.GetCertCount()); ServerBoundCertStore::ServerBoundCertList certs; - store.GetAllServerBoundCerts(&certs); + store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &certs)); + ASSERT_EQ(3u, certs.size()); + + ServerBoundCertStore::ServerBoundCertList::iterator cert = certs.begin(); + EXPECT_EQ("both.com", cert->server_identifier()); + EXPECT_EQ("e", cert->private_key()); + + ++cert; + EXPECT_EQ("copied.com", cert->server_identifier()); + EXPECT_EQ("g", cert->private_key()); + + ++cert; + EXPECT_EQ("preexisting.com", cert->server_identifier()); + EXPECT_EQ("a", cert->private_key()); +} + +TEST(DefaultServerBoundCertStoreTest, TestAsyncInitializeFrom) { + scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); + persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( + "preexisting.com", + CLIENT_CERT_RSA_SIGN, + base::Time(), + base::Time(), + "a", "b")); + persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( + "both.com", + CLIENT_CERT_RSA_SIGN, + base::Time(), + base::Time(), + "c", "d")); + + DefaultServerBoundCertStore store(persistent_store.get()); + ServerBoundCertStore::ServerBoundCertList source_certs; + source_certs.push_back(ServerBoundCertStore::ServerBoundCert( + "both.com", + CLIENT_CERT_RSA_SIGN, + base::Time(), + base::Time(), + // Key differs from above to test that existing entries are overwritten. + "e", "f")); + source_certs.push_back(ServerBoundCertStore::ServerBoundCert( + "copied.com", + CLIENT_CERT_RSA_SIGN, + base::Time(), + base::Time(), + "g", "h")); + store.InitializeFrom(source_certs); + EXPECT_EQ(0, store.GetCertCount()); + // Wait for load & queued tasks. + MessageLoop::current()->RunUntilIdle(); + EXPECT_EQ(3, store.GetCertCount()); + + ServerBoundCertStore::ServerBoundCertList certs; + store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &certs)); ASSERT_EQ(3u, certs.size()); ServerBoundCertStore::ServerBoundCertList::iterator cert = certs.begin(); diff --git a/net/base/server_bound_cert_service.cc b/net/base/server_bound_cert_service.cc index 3194a66..e00d7bc 100644 --- a/net/base/server_bound_cert_service.cc +++ b/net/base/server_bound_cert_service.cc @@ -45,19 +45,36 @@ bool IsSupportedCertType(uint8 type) { switch(type) { case CLIENT_CERT_ECDSA_SIGN: return true; + // If we add any more supported types, CertIsValid will need to be updated + // to check that the returned type matches one of the requested types. default: return false; } } +bool CertIsValid(const std::string& domain, + SSLClientCertType type, + base::Time expiration_time) { + if (expiration_time < base::Time::Now()) { + DVLOG(1) << "Cert store had expired cert for " << domain; + return false; + } else if (!IsSupportedCertType(type)) { + DVLOG(1) << "Cert store had cert of wrong type " << type << " for " + << domain; + return false; + } + return true; +} + // Used by the GetDomainBoundCertResult histogram to record the final // outcome of each GetDomainBoundCert call. Do not re-use values. enum GetCertResult { // Synchronously found and returned an existing domain bound cert. SYNC_SUCCESS = 0, - // Generated and returned a domain bound cert asynchronously. + // Retrieved or generated and returned a domain bound cert asynchronously. ASYNC_SUCCESS = 1, - // Generation request was cancelled before the cert generation completed. + // Retrieval/generation request was cancelled before the cert generation + // completed. ASYNC_CANCELLED = 2, // Cert generation failed. ASYNC_FAILURE_KEYGEN = 3, @@ -204,6 +221,9 @@ class ServerBoundCertServiceRequest { case ERR_PRIVATE_KEY_EXPORT_FAILED: RecordGetDomainBoundCertResult(ASYNC_FAILURE_EXPORT_KEY); break; + case ERR_INSUFFICIENT_RESOURCES: + RecordGetDomainBoundCertResult(WORKER_FAILURE); + break; default: RecordGetDomainBoundCertResult(ASYNC_FAILURE_UNKNOWN); break; @@ -295,7 +315,8 @@ class ServerBoundCertServiceWorker { // origin message loop. class ServerBoundCertServiceJob { public: - ServerBoundCertServiceJob(SSLClientCertType type) : type_(type) { + ServerBoundCertServiceJob(SSLClientCertType type) + : type_(type) { } ~ServerBoundCertServiceJob() { @@ -450,38 +471,7 @@ int ServerBoundCertService::GetDomainBoundCert( requests_++; - // Check if a domain bound cert of an acceptable type already exists for this - // domain, and that it has not expired. - base::Time now = base::Time::Now(); - base::Time creation_time; - base::Time expiration_time; - if (server_bound_cert_store_->GetServerBoundCert(domain, - type, - &creation_time, - &expiration_time, - private_key, - cert)) { - if (expiration_time < now) { - DVLOG(1) << "Cert store had expired cert for " << domain; - } else if (!IsSupportedCertType(*type) || - std::find(requested_types.begin(), requested_types.end(), - *type) == requested_types.end()) { - DVLOG(1) << "Cert store had cert of wrong type " << *type << " for " - << domain; - } else { - DVLOG(1) << "Cert store had valid cert for " << domain - << " of type " << *type; - cert_store_hits_++; - RecordGetDomainBoundCertResult(SYNC_SUCCESS); - base::TimeDelta request_time = base::TimeTicks::Now() - request_start; - UMA_HISTOGRAM_TIMES("DomainBoundCerts.GetCertTimeSync", request_time); - RecordGetCertTime(request_time); - return OK; - } - } - - // |server_bound_cert_store_| has no cert for this domain. See if an - // identical request is currently in flight. + // See if an identical request is currently in flight. ServerBoundCertServiceJob* job = NULL; std::map<std::string, ServerBoundCertServiceJob*>::const_iterator j; j = inflight_.find(domain); @@ -503,23 +493,63 @@ int ServerBoundCertService::GetDomainBoundCert( return ERR_ORIGIN_BOUND_CERT_GENERATION_TYPE_MISMATCH; } inflight_joins_++; - } else { - // Need to make a new request. + + ServerBoundCertServiceRequest* request = new ServerBoundCertServiceRequest( + request_start, + base::Bind(&RequestHandle::OnRequestComplete, + base::Unretained(out_req)), + type, private_key, cert); + job->AddRequest(request); + out_req->RequestStarted(this, request, callback); + return ERR_IO_PENDING; + } + + // Check if a domain bound cert of an acceptable type already exists for this + // domain, and that it has not expired. + base::Time expiration_time; + if (server_bound_cert_store_->GetServerBoundCert( + domain, + type, + &expiration_time, + private_key, + cert, + base::Bind(&ServerBoundCertService::GotServerBoundCert, + weak_ptr_factory_.GetWeakPtr()))) { + if (*type != CLIENT_CERT_INVALID_TYPE) { + // Sync lookup found a cert. + if (CertIsValid(domain, *type, expiration_time)) { + DVLOG(1) << "Cert store had valid cert for " << domain + << " of type " << *type; + cert_store_hits_++; + RecordGetDomainBoundCertResult(SYNC_SUCCESS); + base::TimeDelta request_time = base::TimeTicks::Now() - request_start; + UMA_HISTOGRAM_TIMES("DomainBoundCerts.GetCertTimeSync", request_time); + RecordGetCertTime(request_time); + return OK; + } + } + + // Sync lookup did not find a cert, or it found an expired one. Start + // generating a new one. ServerBoundCertServiceWorker* worker = new ServerBoundCertServiceWorker( - domain, - preferred_type, - base::Bind(&ServerBoundCertService::HandleResult, - weak_ptr_factory_.GetWeakPtr())); + domain, + preferred_type, + base::Bind(&ServerBoundCertService::GeneratedServerBoundCert, + weak_ptr_factory_.GetWeakPtr())); if (!worker->Start(task_runner_)) { + delete worker; // TODO(rkn): Log to the NetLog. LOG(ERROR) << "ServerBoundCertServiceWorker couldn't be started."; RecordGetDomainBoundCertResult(WORKER_FAILURE); - return ERR_INSUFFICIENT_RESOURCES; // Just a guess. + return ERR_INSUFFICIENT_RESOURCES; } - job = new ServerBoundCertServiceJob(preferred_type); - inflight_[domain] = job; } + // We are either waiting for async DB lookup, or waiting for cert generation. + // Create a job & request to track it. + job = new ServerBoundCertServiceJob(preferred_type); + inflight_[domain] = job; + ServerBoundCertServiceRequest* request = new ServerBoundCertServiceRequest( request_start, base::Bind(&RequestHandle::OnRequestComplete, base::Unretained(out_req)), @@ -529,6 +559,51 @@ int ServerBoundCertService::GetDomainBoundCert( return ERR_IO_PENDING; } +void ServerBoundCertService::GotServerBoundCert( + const std::string& server_identifier, + SSLClientCertType type, + base::Time expiration_time, + const std::string& key, + const std::string& cert) { + DCHECK(CalledOnValidThread()); + + std::map<std::string, ServerBoundCertServiceJob*>::iterator j; + j = inflight_.find(server_identifier); + if (j == inflight_.end()) { + NOTREACHED(); + return; + } + ServerBoundCertServiceJob* job = j->second; + + if (type != CLIENT_CERT_INVALID_TYPE) { + // Async DB lookup found a cert. + if (CertIsValid(server_identifier, type, expiration_time)) { + DVLOG(1) << "Cert store had valid cert for " << server_identifier + << " of type " << type; + cert_store_hits_++; + // ServerBoundCertServiceRequest::Post will do the histograms and stuff. + HandleResult(OK, server_identifier, type, key, cert); + return; + } + } + + // Async lookup did not find a cert, or it found an expired one. Start + // generating a new one. + ServerBoundCertServiceWorker* worker = new ServerBoundCertServiceWorker( + server_identifier, + job->type(), + base::Bind(&ServerBoundCertService::GeneratedServerBoundCert, + weak_ptr_factory_.GetWeakPtr())); + if (!worker->Start(task_runner_)) { + delete worker; + // TODO(rkn): Log to the NetLog. + LOG(ERROR) << "ServerBoundCertServiceWorker couldn't be started."; + HandleResult(ERR_INSUFFICIENT_RESOURCES, server_identifier, + CLIENT_CERT_INVALID_TYPE, "", ""); + return; + } +} + ServerBoundCertStore* ServerBoundCertService::GetCertStore() { return server_bound_cert_store_.get(); } @@ -538,9 +613,7 @@ void ServerBoundCertService::CancelRequest(ServerBoundCertServiceRequest* req) { req->Cancel(); } -// HandleResult is called by ServerBoundCertServiceWorker on the origin message -// loop. It deletes ServerBoundCertServiceJob. -void ServerBoundCertService::HandleResult( +void ServerBoundCertService::GeneratedServerBoundCert( const std::string& server_identifier, int error, scoped_ptr<ServerBoundCertStore::ServerBoundCert> cert) { @@ -552,7 +625,21 @@ void ServerBoundCertService::HandleResult( server_bound_cert_store_->SetServerBoundCert( cert->server_identifier(), cert->type(), cert->creation_time(), cert->expiration_time(), cert->private_key(), cert->cert()); + + HandleResult(error, server_identifier, cert->type(), cert->private_key(), + cert->cert()); + } else { + HandleResult(error, server_identifier, CLIENT_CERT_INVALID_TYPE, "", ""); } +} + +void ServerBoundCertService::HandleResult( + int error, + const std::string& server_identifier, + SSLClientCertType type, + const std::string& private_key, + const std::string& cert) { + DCHECK(CalledOnValidThread()); std::map<std::string, ServerBoundCertServiceJob*>::iterator j; j = inflight_.find(server_identifier); @@ -563,10 +650,7 @@ void ServerBoundCertService::HandleResult( ServerBoundCertServiceJob* job = j->second; inflight_.erase(j); - if (cert) - job->HandleResult(error, cert->type(), cert->private_key(), cert->cert()); - else - job->HandleResult(error, CLIENT_CERT_INVALID_TYPE, "", ""); + job->HandleResult(error, type, private_key, cert); delete job; } diff --git a/net/base/server_bound_cert_service.h b/net/base/server_bound_cert_service.h index c8c9a05..151b279 100644 --- a/net/base/server_bound_cert_service.h +++ b/net/base/server_bound_cert_service.h @@ -128,9 +128,20 @@ class NET_EXPORT ServerBoundCertService // callback will not be called. void CancelRequest(ServerBoundCertServiceRequest* req); - void HandleResult(const std::string& server_identifier, - int error, - scoped_ptr<ServerBoundCertStore::ServerBoundCert> cert); + void GotServerBoundCert(const std::string& server_identifier, + SSLClientCertType type, + base::Time expiration_time, + const std::string& key, + const std::string& cert); + void GeneratedServerBoundCert( + const std::string& server_identifier, + int error, + scoped_ptr<ServerBoundCertStore::ServerBoundCert> cert); + void HandleResult(int error, + const std::string& server_identifier, + SSLClientCertType type, + const std::string& private_key, + const std::string& cert); scoped_ptr<ServerBoundCertStore> server_bound_cert_store_; scoped_refptr<base::TaskRunner> task_runner_; diff --git a/net/base/server_bound_cert_store.h b/net/base/server_bound_cert_store.h index 7a6f866..85e1035 100644 --- a/net/base/server_bound_cert_store.h +++ b/net/base/server_bound_cert_store.h @@ -8,6 +8,7 @@ #include <list> #include <string> +#include "base/callback.h" #include "base/threading/non_thread_safe.h" #include "base/time.h" #include "net/base/net_export.h" @@ -65,23 +66,29 @@ class NET_EXPORT ServerBoundCertStore typedef std::list<ServerBoundCert> ServerBoundCertList; + typedef base::Callback<void( + const std::string&, + SSLClientCertType, + base::Time, + const std::string&, + const std::string&)> GetCertCallback; + typedef base::Callback<void(const ServerBoundCertList&)> GetCertListCallback; + virtual ~ServerBoundCertStore() {} - // TODO(rkn): File I/O may be required, so this should have an asynchronous - // interface. - // Returns true on success. |private_key_result| stores a DER-encoded - // PrivateKeyInfo struct, |cert_result| stores a DER-encoded certificate, - // |type| is the ClientCertificateType of the returned certificate, - // |creation_time| stores the start of the validity period of the certificate - // and |expiration_time| is the expiration time of the certificate. - // Returns false if no server bound cert exists for the specified server. + // GetServerBoundCert may return the result synchronously through the + // output parameters, in which case it will return true. Otherwise it will + // return false and the callback will be called with the result + // asynchronously. + // In either case, the type will be CLIENT_CERT_INVALID_TYPE if no cert + // existed for the given |server_identifier|. virtual bool GetServerBoundCert( const std::string& server_identifier, SSLClientCertType* type, - base::Time* creation_time, base::Time* expiration_time, std::string* private_key_result, - std::string* cert_result) = 0; + std::string* cert_result, + const GetCertCallback& callback) = 0; // Adds a server bound cert and the corresponding private key to the store. virtual void SetServerBoundCert( @@ -94,26 +101,30 @@ class NET_EXPORT ServerBoundCertStore // Removes a server bound cert and the corresponding private key from the // store. - virtual void DeleteServerBoundCert(const std::string& server_identifier) = 0; + virtual void DeleteServerBoundCert( + const std::string& server_identifier, + const base::Closure& completion_callback) = 0; // Deletes all of the server bound certs that have a creation_date greater // than or equal to |delete_begin| and less than |delete_end|. If a // base::Time value is_null, that side of the comparison is unbounded. - virtual void DeleteAllCreatedBetween(base::Time delete_begin, - base::Time delete_end) = 0; + virtual void DeleteAllCreatedBetween( + base::Time delete_begin, + base::Time delete_end, + const base::Closure& completion_callback) = 0; // Removes all server bound certs and the corresponding private keys from // the store. - virtual void DeleteAll() = 0; + virtual void DeleteAll(const base::Closure& completion_callback) = 0; // Returns all server bound certs and the corresponding private keys. - virtual void GetAllServerBoundCerts( - ServerBoundCertList* server_bound_certs) = 0; + virtual void GetAllServerBoundCerts(const GetCertListCallback& callback) = 0; // Helper function that adds all certs from |list| into this instance. void InitializeFrom(const ServerBoundCertList& list); - // Returns the number of certs in the store. + // Returns the number of certs in the store. May return 0 if the backing + // store is not loaded yet. // Public only for unit testing. virtual int GetCertCount() = 0; |