diff options
-rw-r--r-- | chrome/browser/net/sqlite_origin_bound_cert_store.cc | 36 | ||||
-rw-r--r-- | chrome/browser/net/sqlite_origin_bound_cert_store_unittest.cc | 102 | ||||
-rw-r--r-- | crypto/ec_private_key.h | 17 | ||||
-rw-r--r-- | crypto/ec_private_key_nss.cc | 134 | ||||
-rw-r--r-- | net/base/default_origin_bound_cert_store.cc | 29 | ||||
-rw-r--r-- | net/base/default_origin_bound_cert_store.h | 40 | ||||
-rw-r--r-- | net/base/default_origin_bound_cert_store_unittest.cc | 81 | ||||
-rw-r--r-- | net/base/net_error_list.h | 7 | ||||
-rw-r--r-- | net/base/origin_bound_cert_service.cc | 180 | ||||
-rw-r--r-- | net/base/origin_bound_cert_service.h | 35 | ||||
-rw-r--r-- | net/base/origin_bound_cert_service_unittest.cc | 279 | ||||
-rw-r--r-- | net/base/origin_bound_cert_store.cc | 23 | ||||
-rw-r--r-- | net/base/origin_bound_cert_store.h | 52 | ||||
-rw-r--r-- | net/base/ssl_client_cert_type.h | 22 | ||||
-rw-r--r-- | net/net.gyp | 2 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_nss.cc | 70 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_nss.h | 7 |
17 files changed, 875 insertions, 241 deletions
diff --git a/chrome/browser/net/sqlite_origin_bound_cert_store.cc b/chrome/browser/net/sqlite_origin_bound_cert_store.cc index 6c1d519..2223394 100644 --- a/chrome/browser/net/sqlite_origin_bound_cert_store.cc +++ b/chrome/browser/net/sqlite_origin_bound_cert_store.cc @@ -17,6 +17,7 @@ #include "base/threading/thread_restrictions.h" #include "chrome/browser/diagnostics/sqlite_diagnostics.h" #include "content/public/browser/browser_thread.h" +#include "net/base/ssl_client_cert_type.h" #include "sql/meta_table.h" #include "sql/statement.h" #include "sql/transaction.h" @@ -116,7 +117,7 @@ class SQLiteOriginBoundCertStore::Backend }; // Version number of the database. -static const int kCurrentVersionNumber = 1; +static const int kCurrentVersionNumber = 2; static const int kCompatibleVersionNumber = 1; namespace { @@ -127,7 +128,8 @@ bool InitTable(sql::Connection* db) { if (!db->Execute("CREATE TABLE origin_bound_certs (" "origin TEXT NOT NULL UNIQUE PRIMARY KEY," "private_key BLOB NOT NULL," - "cert BLOB NOT NULL)")) + "cert BLOB NOT NULL," + "cert_type INTEGER)")) return false; } @@ -169,7 +171,7 @@ bool SQLiteOriginBoundCertStore::Backend::Load( // Slurp all the certs into the out-vector. sql::Statement smt(db_->GetUniqueStatement( - "SELECT origin, private_key, cert FROM origin_bound_certs")); + "SELECT origin, private_key, cert, cert_type FROM origin_bound_certs")); if (!smt) { NOTREACHED() << "select statement prep failed"; db_.reset(); @@ -183,6 +185,7 @@ bool SQLiteOriginBoundCertStore::Backend::Load( scoped_ptr<net::DefaultOriginBoundCertStore::OriginBoundCert> cert( new net::DefaultOriginBoundCertStore::OriginBoundCert( smt.ColumnString(0), // origin + static_cast<net::SSLClientCertType>(smt.ColumnInt(3)), private_key_from_db, cert_from_db)); certs->push_back(cert.release()); @@ -204,6 +207,28 @@ bool SQLiteOriginBoundCertStore::Backend::EnsureDatabaseVersion() { } int cur_version = meta_table_.GetVersionNumber(); + if (cur_version == 1) { + sql::Transaction transaction(db_.get()); + if (!transaction.Begin()) + return false; + if (!db_->Execute("ALTER TABLE origin_bound_certs ADD COLUMN cert_type " + "INTEGER")) { + LOG(WARNING) << "Unable to update origin bound cert database to " + << "version 2."; + return false; + } + // All certs in version 1 database are rsa_sign, which has a value of 1. + if (!db_->Execute("UPDATE origin_bound_certs SET cert_type = 1")) { + LOG(WARNING) << "Unable to update origin bound cert database to " + << "version 2."; + return false; + } + ++cur_version; + meta_table_.SetVersionNumber(cur_version); + meta_table_.SetCompatibleVersionNumber( + std::min(cur_version, kCompatibleVersionNumber)); + transaction.Commit(); + } // Put future migration cases here. @@ -273,8 +298,8 @@ void SQLiteOriginBoundCertStore::Backend::Commit() { return; sql::Statement add_smt(db_->GetCachedStatement(SQL_FROM_HERE, - "INSERT INTO origin_bound_certs (origin, private_key, cert) " - "VALUES (?,?,?)")); + "INSERT INTO origin_bound_certs (origin, private_key, cert, cert_type) " + "VALUES (?,?,?,?)")); if (!add_smt) { NOTREACHED(); return; @@ -304,6 +329,7 @@ void SQLiteOriginBoundCertStore::Backend::Commit() { add_smt.BindBlob(1, private_key.data(), private_key.size()); const std::string& cert = po->cert().cert(); add_smt.BindBlob(2, cert.data(), cert.size()); + add_smt.BindInt(3, po->cert().type()); if (!add_smt.Run()) NOTREACHED() << "Could not add an origin bound cert to the DB."; break; diff --git a/chrome/browser/net/sqlite_origin_bound_cert_store_unittest.cc b/chrome/browser/net/sqlite_origin_bound_cert_store_unittest.cc index 4768395..cc3130d 100644 --- a/chrome/browser/net/sqlite_origin_bound_cert_store_unittest.cc +++ b/chrome/browser/net/sqlite_origin_bound_cert_store_unittest.cc @@ -12,6 +12,7 @@ #include "chrome/browser/net/sqlite_origin_bound_cert_store.h" #include "chrome/common/chrome_constants.h" #include "content/test/test_browser_thread.h" +#include "sql/statement.h" #include "testing/gtest/include/gtest/gtest.h" using content::BrowserThread; @@ -34,7 +35,8 @@ class SQLiteOriginBoundCertStoreTest : public testing::Test { // Make sure the store gets written at least once. store_->AddOriginBoundCert( net::DefaultOriginBoundCertStore::OriginBoundCert( - "https://encrypted.google.com:8443", "a", "b")); + "https://encrypted.google.com:8443", + net::CLIENT_CERT_RSA_SIGN, "a", "b")); } content::TestBrowserThread db_thread_; @@ -75,6 +77,10 @@ TEST_F(SQLiteOriginBoundCertStoreTest, RemoveOnDestruction) { // Test if data is stored as expected in the SQLite database. TEST_F(SQLiteOriginBoundCertStoreTest, TestPersistence) { + store_->AddOriginBoundCert( + net::DefaultOriginBoundCertStore::OriginBoundCert( + "https://www.google.com/", net::CLIENT_CERT_ECDSA_SIGN, "c", "d")); + std::vector<net::DefaultOriginBoundCertStore::OriginBoundCert*> certs; // Replace the store effectively destroying the current one and forcing it // to write it's data to disk. Then we can see if after loading it again it @@ -90,13 +96,28 @@ TEST_F(SQLiteOriginBoundCertStoreTest, TestPersistence) { // Reload and test for persistence ASSERT_TRUE(store_->Load(&certs)); - ASSERT_EQ(1U, certs.size()); - ASSERT_STREQ("https://encrypted.google.com:8443", certs[0]->origin().c_str()); - ASSERT_STREQ("a", certs[0]->private_key().c_str()); - ASSERT_STREQ("b", certs[0]->cert().c_str()); + ASSERT_EQ(2U, certs.size()); + net::DefaultOriginBoundCertStore::OriginBoundCert* ec_cert; + net::DefaultOriginBoundCertStore::OriginBoundCert* rsa_cert; + if (net::CLIENT_CERT_RSA_SIGN == certs[0]->type()) { + rsa_cert = certs[0]; + ec_cert = certs[1]; + } else { + rsa_cert = certs[1]; + ec_cert = certs[0]; + } + ASSERT_STREQ("https://encrypted.google.com:8443", rsa_cert->origin().c_str()); + ASSERT_EQ(net::CLIENT_CERT_RSA_SIGN, rsa_cert->type()); + ASSERT_STREQ("a", rsa_cert->private_key().c_str()); + ASSERT_STREQ("b", rsa_cert->cert().c_str()); + ASSERT_STREQ("https://www.google.com/", ec_cert->origin().c_str()); + ASSERT_EQ(net::CLIENT_CERT_ECDSA_SIGN, ec_cert->type()); + ASSERT_STREQ("c", ec_cert->private_key().c_str()); + ASSERT_STREQ("d", ec_cert->cert().c_str()); // Now delete the cert and check persistence again. store_->DeleteOriginBoundCert(*certs[0]); + store_->DeleteOriginBoundCert(*certs[1]); store_ = NULL; // Make sure we wait until the destructor has run. ASSERT_TRUE(helper->Run()); @@ -110,6 +131,69 @@ TEST_F(SQLiteOriginBoundCertStoreTest, TestPersistence) { ASSERT_EQ(0U, certs.size()); } +TEST_F(SQLiteOriginBoundCertStoreTest, TestUpgrade) { + // Reset the store. We'll be using a different database for this test. + store_ = NULL; + + FilePath v1_db_path(temp_dir_.path().AppendASCII("v1db")); + + // Create a version 1 database. + { + sql::Connection db; + ASSERT_TRUE(db.Open(v1_db_path)); + ASSERT_TRUE(db.Execute( + "CREATE TABLE meta(key LONGVARCHAR NOT NULL UNIQUE PRIMARY KEY," + "value LONGVARCHAR);" + "INSERT INTO \"meta\" VALUES('version','1');" + "INSERT INTO \"meta\" VALUES('last_compatible_version','1');" + "CREATE TABLE origin_bound_certs (" + "origin TEXT NOT NULL UNIQUE PRIMARY KEY," + "private_key BLOB NOT NULL,cert BLOB NOT NULL);" + "INSERT INTO \"origin_bound_certs\" VALUES(" + "'https://google.com',X'AA',X'BB');" + "INSERT INTO \"origin_bound_certs\" VALUES(" + "'https://foo.com',X'CC',X'DD');" + )); + } + + std::vector<net::DefaultOriginBoundCertStore::OriginBoundCert*> certs; + store_ = new SQLiteOriginBoundCertStore(v1_db_path); + + // Load the database and ensure the certs can be read and are marked as RSA. + ASSERT_TRUE(store_->Load(&certs)); + ASSERT_EQ(2U, certs.size()); + ASSERT_STREQ("https://google.com", certs[0]->origin().c_str()); + ASSERT_EQ(net::CLIENT_CERT_RSA_SIGN, certs[0]->type()); + ASSERT_STREQ("\xaa", certs[0]->private_key().c_str()); + ASSERT_STREQ("\xbb", certs[0]->cert().c_str()); + ASSERT_STREQ("https://foo.com", certs[1]->origin().c_str()); + ASSERT_EQ(net::CLIENT_CERT_RSA_SIGN, certs[1]->type()); + ASSERT_STREQ("\xcc", certs[1]->private_key().c_str()); + ASSERT_STREQ("\xdd", certs[1]->cert().c_str()); + + STLDeleteContainerPointers(certs.begin(), certs.end()); + certs.clear(); + + store_ = NULL; + // Make sure we wait until the destructor has run. + scoped_refptr<base::ThreadTestHelper> helper( + new base::ThreadTestHelper( + BrowserThread::GetMessageLoopProxyForThread(BrowserThread::DB))); + ASSERT_TRUE(helper->Run()); + + // Verify the database version is updated. + { + sql::Connection db; + ASSERT_TRUE(db.Open(v1_db_path)); + sql::Statement smt(db.GetUniqueStatement( + "SELECT value FROM meta WHERE key = \"version\"")); + ASSERT_TRUE(smt); + ASSERT_TRUE(smt.Step()); + EXPECT_EQ(2, smt.ColumnInt(0)); + EXPECT_FALSE(smt.Step()); + } +} + // Test that we can force the database to be written by calling Flush(). TEST_F(SQLiteOriginBoundCertStoreTest, TestFlush) { // File timestamps don't work well on all platforms, so we'll determine @@ -125,9 +209,11 @@ TEST_F(SQLiteOriginBoundCertStoreTest, TestFlush) { std::string private_key(1000, c); std::string cert(1000, c); store_->AddOriginBoundCert( - net::DefaultOriginBoundCertStore::OriginBoundCert(origin, - private_key, - cert)); + net::DefaultOriginBoundCertStore::OriginBoundCert( + origin, + net::CLIENT_CERT_RSA_SIGN, + private_key, + cert)); } // Call Flush() and wait until the DB thread is idle. diff --git a/crypto/ec_private_key.h b/crypto/ec_private_key.h index 0d287de..44f754b 100644 --- a/crypto/ec_private_key.h +++ b/crypto/ec_private_key.h @@ -18,6 +18,7 @@ typedef struct evp_pkey_st EVP_PKEY; #else // Forward declaration. +typedef struct CERTSubjectPublicKeyInfoStr CERTSubjectPublicKeyInfo; typedef struct SECKEYPrivateKeyStr SECKEYPrivateKey; typedef struct SECKEYPublicKeyStr SECKEYPublicKey; #endif @@ -65,6 +66,22 @@ class CRYPTO_EXPORT ECPrivateKey { const std::vector<uint8>& encrypted_private_key_info, const std::vector<uint8>& subject_public_key_info); +#if !defined(USE_OPENSSL) + // Imports the key pair and returns in |public_key| and |key|. + // Shortcut for code that needs to keep a reference directly to NSS types + // without having to create a ECPrivateKey object and make a copy of them. + // TODO(mattm): move this function to some NSS util file. + static bool ImportFromEncryptedPrivateKeyInfo( + const std::string& password, + const uint8* encrypted_private_key_info, + size_t encrypted_private_key_info_len, + CERTSubjectPublicKeyInfo* decoded_spki, + bool permanent, + bool sensitive, + SECKEYPrivateKey** key, + SECKEYPublicKey** public_key); +#endif + #if defined(USE_OPENSSL) EVP_PKEY* key() { return key_; } #else diff --git a/crypto/ec_private_key_nss.cc b/crypto/ec_private_key_nss.cc index cc46101..1fb13e7 100644 --- a/crypto/ec_private_key_nss.cc +++ b/crypto/ec_private_key_nss.cc @@ -104,6 +104,76 @@ ECPrivateKey* ECPrivateKey::CreateSensitiveFromEncryptedPrivateKeyInfo( #endif } +// static +bool ECPrivateKey::ImportFromEncryptedPrivateKeyInfo( + const std::string& password, + const uint8* encrypted_private_key_info, + size_t encrypted_private_key_info_len, + CERTSubjectPublicKeyInfo* decoded_spki, + bool permanent, + bool sensitive, + SECKEYPrivateKey** key, + SECKEYPublicKey** public_key) { + ScopedPK11Slot slot(GetPrivateNSSKeySlot()); + if (!slot.get()) + return false; + + *public_key = SECKEY_ExtractPublicKey(decoded_spki); + + if (!*public_key) { + DLOG(ERROR) << "SECKEY_ExtractPublicKey: " << PORT_GetError(); + return false; + } + + SECItem encoded_epki = { + siBuffer, + const_cast<unsigned char*>(encrypted_private_key_info), + encrypted_private_key_info_len + }; + SECKEYEncryptedPrivateKeyInfo epki; + memset(&epki, 0, sizeof(epki)); + + ScopedPLArenaPool arena(PORT_NewArena(DER_DEFAULT_CHUNKSIZE)); + + SECStatus rv = SEC_QuickDERDecodeItem( + arena.get(), + &epki, + SEC_ASN1_GET(SECKEY_EncryptedPrivateKeyInfoTemplate), + &encoded_epki); + if (rv != SECSuccess) { + DLOG(ERROR) << "SEC_QuickDERDecodeItem: " << PORT_GetError(); + SECKEY_DestroyPublicKey(*public_key); + *public_key = NULL; + return false; + } + + SECItem password_item = { + siBuffer, + reinterpret_cast<unsigned char*>(const_cast<char*>(password.data())), + password.size() + }; + + rv = ImportEncryptedECPrivateKeyInfoAndReturnKey( + slot.get(), + &epki, + &password_item, + NULL, // nickname + &(*public_key)->u.ec.publicValue, + permanent, + sensitive, + key, + NULL); // wincx + if (rv != SECSuccess) { + DLOG(ERROR) << "ImportEncryptedECPrivateKeyInfoAndReturnKey: " + << PORT_GetError(); + SECKEY_DestroyPublicKey(*public_key); + *public_key = NULL; + return false; + } + + return true; +} + bool ECPrivateKey::ExportEncryptedPrivateKey( const std::string& password, int iterations, @@ -227,10 +297,6 @@ ECPrivateKey* ECPrivateKey::CreateFromEncryptedPrivateKeyInfoWithParams( scoped_ptr<ECPrivateKey> result(new ECPrivateKey); - ScopedPK11Slot slot(GetPrivateNSSKeySlot()); - if (!slot.get()) - return NULL; - SECItem encoded_spki = { siBuffer, const_cast<unsigned char*>(&subject_public_key_info[0]), @@ -243,58 +309,22 @@ ECPrivateKey* ECPrivateKey::CreateFromEncryptedPrivateKeyInfoWithParams( return NULL; } - result->public_key_ = SECKEY_ExtractPublicKey(decoded_spki); - - SECKEY_DestroySubjectPublicKeyInfo(decoded_spki); - - if (!result->public_key_) { - DLOG(ERROR) << "SECKEY_ExtractPublicKey: " << PORT_GetError(); - return NULL; - } - - SECItem encoded_epki = { - siBuffer, - const_cast<unsigned char*>(&encrypted_private_key_info[0]), - encrypted_private_key_info.size() - }; - SECKEYEncryptedPrivateKeyInfo epki; - memset(&epki, 0, sizeof(epki)); - - ScopedPLArenaPool arena(PORT_NewArena(DER_DEFAULT_CHUNKSIZE)); - - SECStatus rv = SEC_QuickDERDecodeItem( - arena.get(), - &epki, - SEC_ASN1_GET(SECKEY_EncryptedPrivateKeyInfoTemplate), - &encoded_epki); - if (rv != SECSuccess) { - DLOG(ERROR) << "SEC_ASN1DecodeItem: " << PORT_GetError(); - return NULL; - } - - SECItem password_item = { - siBuffer, - reinterpret_cast<unsigned char*>(const_cast<char*>(password.data())), - password.size() - }; - - rv = ImportEncryptedECPrivateKeyInfoAndReturnKey( - slot.get(), - &epki, - &password_item, - NULL, // nickname - &result->public_key_->u.ec.publicValue, + bool success = ECPrivateKey::ImportFromEncryptedPrivateKeyInfo( + password, + &encrypted_private_key_info[0], + encrypted_private_key_info.size(), + decoded_spki, permanent, sensitive, &result->key_, - NULL); // wincx - if (rv != SECSuccess) { - DLOG(ERROR) << "ImportEncryptedECPrivateKeyInfoAndReturnKey: " - << PORT_GetError(); - return NULL; - } + &result->public_key_); - return result.release(); + SECKEY_DestroySubjectPublicKeyInfo(decoded_spki); + + if (success) + return result.release(); + + return NULL; } } // namespace crypto diff --git a/net/base/default_origin_bound_cert_store.cc b/net/base/default_origin_bound_cert_store.cc index 82aec7d..8104658 100644 --- a/net/base/default_origin_bound_cert_store.cc +++ b/net/base/default_origin_bound_cert_store.cc @@ -29,6 +29,7 @@ void DefaultOriginBoundCertStore::FlushStore( bool DefaultOriginBoundCertStore::GetOriginBoundCert( const std::string& origin, + SSLClientCertType* type, std::string* private_key_result, std::string* cert_result) { base::AutoLock autolock(lock_); @@ -40,6 +41,7 @@ bool DefaultOriginBoundCertStore::GetOriginBoundCert( return false; OriginBoundCert* cert = it->second; + *type = cert->type(); *private_key_result = cert->private_key(); *cert_result = cert->cert(); @@ -48,14 +50,15 @@ bool DefaultOriginBoundCertStore::GetOriginBoundCert( void DefaultOriginBoundCertStore::SetOriginBoundCert( const std::string& origin, + SSLClientCertType type, const std::string& private_key, const std::string& cert) { base::AutoLock autolock(lock_); InitIfNecessary(); InternalDeleteOriginBoundCert(origin); - InternalInsertOriginBoundCert(origin, - new OriginBoundCert(origin, private_key, cert)); + InternalInsertOriginBoundCert( + origin, new OriginBoundCert(origin, type, private_key, cert)); } void DefaultOriginBoundCertStore::DeleteOriginBoundCert( @@ -79,20 +82,12 @@ void DefaultOriginBoundCertStore::DeleteAll() { } void DefaultOriginBoundCertStore::GetAllOriginBoundCerts( - std::vector<OriginBoundCertInfo>* origin_bound_certs) { + std::vector<OriginBoundCert>* origin_bound_certs) { base::AutoLock autolock(lock_); InitIfNecessary(); for (OriginBoundCertMap::iterator it = origin_bound_certs_.begin(); it != origin_bound_certs_.end(); ++it) { - OriginBoundCertInfo cert_info = { - it->second->origin(), - it->second->private_key(), - it->second->cert() - }; - // TODO(rkn): Make changes so we can simply write - // origin_bound_certs->push_back(*it->second). This is probably best done - // by unnesting the OriginBoundCert class. - origin_bound_certs->push_back(cert_info); + origin_bound_certs->push_back(*it->second); } } @@ -160,16 +155,6 @@ void DefaultOriginBoundCertStore::InternalInsertOriginBoundCert( origin_bound_certs_[origin] = cert; } -DefaultOriginBoundCertStore::OriginBoundCert::OriginBoundCert() {} - -DefaultOriginBoundCertStore::OriginBoundCert::OriginBoundCert( - const std::string& origin, - const std::string& private_key, - const std::string& cert) - : origin_(origin), - private_key_(private_key), - cert_(cert) {} - DefaultOriginBoundCertStore::PersistentStore::PersistentStore() {} } // namespace net diff --git a/net/base/default_origin_bound_cert_store.h b/net/base/default_origin_bound_cert_store.h index 05dd70b..f5e0394 100644 --- a/net/base/default_origin_bound_cert_store.h +++ b/net/base/default_origin_bound_cert_store.h @@ -31,7 +31,6 @@ namespace net { // by IO and origin bound cert management UI. class NET_EXPORT DefaultOriginBoundCertStore : public OriginBoundCertStore { public: - class OriginBoundCert; class PersistentStore; // The key for each OriginBoundCert* in OriginBoundCertMap is the @@ -55,16 +54,20 @@ class NET_EXPORT DefaultOriginBoundCertStore : public OriginBoundCertStore { void FlushStore(const base::Closure& completion_task); // OriginBoundCertStore implementation. - virtual bool GetOriginBoundCert(const std::string& origin, - std::string* private_key_result, - std::string* cert_result) OVERRIDE; - virtual void SetOriginBoundCert(const std::string& origin, - const std::string& private_key, - const std::string& cert) OVERRIDE; + virtual bool GetOriginBoundCert( + const std::string& origin, + SSLClientCertType* type, + std::string* private_key_result, + std::string* cert_result) OVERRIDE; + virtual void SetOriginBoundCert( + const std::string& origin, + SSLClientCertType type, + const std::string& private_key, + const std::string& cert) OVERRIDE; virtual void DeleteOriginBoundCert(const std::string& origin) OVERRIDE; virtual void DeleteAll() OVERRIDE; virtual void GetAllOriginBoundCerts( - std::vector<OriginBoundCertInfo>* origin_bound_certs) OVERRIDE; + std::vector<OriginBoundCert>* origin_bound_certs) OVERRIDE; virtual int GetCertCount() OVERRIDE; private: @@ -113,25 +116,6 @@ class NET_EXPORT DefaultOriginBoundCertStore : public OriginBoundCertStore { DISALLOW_COPY_AND_ASSIGN(DefaultOriginBoundCertStore); }; -// The OriginBoundCert class contains a private key in addition to the origin -// and the cert. -class NET_EXPORT DefaultOriginBoundCertStore::OriginBoundCert { - public: - OriginBoundCert(); - OriginBoundCert(const std::string& origin, - const std::string& privatekey, - const std::string& cert); - - const std::string& origin() const { return origin_; } - const std::string& private_key() const { return private_key_; } - const std::string& cert() const { return cert_; } - - private: - std::string origin_; - std::string private_key_; - std::string cert_; -}; - typedef base::RefCountedThreadSafe<DefaultOriginBoundCertStore::PersistentStore> RefcountedPersistentStore; @@ -144,7 +128,7 @@ class NET_EXPORT DefaultOriginBoundCertStore::PersistentStore // called only once at startup. Note that the certs are individually allocated // and that ownership is transferred to the caller upon return. virtual bool Load( - std::vector<DefaultOriginBoundCertStore::OriginBoundCert*>* certs) = 0; + std::vector<OriginBoundCert*>* certs) = 0; virtual void AddOriginBoundCert(const OriginBoundCert& cert) = 0; diff --git a/net/base/default_origin_bound_cert_store_unittest.cc b/net/base/default_origin_bound_cert_store_unittest.cc index 45356e8..6b888e0 100644 --- a/net/base/default_origin_bound_cert_store_unittest.cc +++ b/net/base/default_origin_bound_cert_store_unittest.cc @@ -77,33 +77,40 @@ TEST(DefaultOriginBoundCertStoreTest, TestLoading) { persistent_store->AddOriginBoundCert( DefaultOriginBoundCertStore::OriginBoundCert( - "https://encrypted.google.com/", "a", "b")); + "https://encrypted.google.com/", CLIENT_CERT_RSA_SIGN, "a", "b")); persistent_store->AddOriginBoundCert( DefaultOriginBoundCertStore::OriginBoundCert( - "https://www.verisign.com/", "c", "d")); + "https://www.verisign.com/", CLIENT_CERT_ECDSA_SIGN, "c", "d")); // Make sure certs load properly. DefaultOriginBoundCertStore store(persistent_store.get()); EXPECT_EQ(2, store.GetCertCount()); - store.SetOriginBoundCert("https://www.verisign.com/", "e", "f"); + store.SetOriginBoundCert( + "https://www.verisign.com/", CLIENT_CERT_RSA_SIGN, "e", "f"); EXPECT_EQ(2, store.GetCertCount()); - store.SetOriginBoundCert("https://www.twitter.com/", "g", "h"); + store.SetOriginBoundCert( + "https://www.twitter.com/", CLIENT_CERT_RSA_SIGN, "g", "h"); EXPECT_EQ(3, store.GetCertCount()); } TEST(DefaultOriginBoundCertStoreTest, TestSettingAndGetting) { DefaultOriginBoundCertStore store(NULL); + SSLClientCertType type; std::string private_key, cert; EXPECT_EQ(0, store.GetCertCount()); EXPECT_FALSE(store.GetOriginBoundCert("https://www.verisign.com/", - &private_key, - &cert)); + &type, + &private_key, + &cert)); EXPECT_TRUE(private_key.empty()); EXPECT_TRUE(cert.empty()); - store.SetOriginBoundCert("https://www.verisign.com/", "i", "j"); + store.SetOriginBoundCert( + "https://www.verisign.com/", CLIENT_CERT_RSA_SIGN, "i", "j"); EXPECT_TRUE(store.GetOriginBoundCert("https://www.verisign.com/", - &private_key, - &cert)); + &type, + &private_key, + &cert)); + EXPECT_EQ(CLIENT_CERT_RSA_SIGN, type); EXPECT_EQ("i", private_key); EXPECT_EQ("j", cert); } @@ -112,15 +119,20 @@ TEST(DefaultOriginBoundCertStoreTest, TestDuplicateCerts) { scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); DefaultOriginBoundCertStore store(persistent_store.get()); + SSLClientCertType type; std::string private_key, cert; EXPECT_EQ(0, store.GetCertCount()); - store.SetOriginBoundCert("https://www.verisign.com/", "a", "b"); - store.SetOriginBoundCert("https://www.verisign.com/", "c", "d"); + store.SetOriginBoundCert( + "https://www.verisign.com/", CLIENT_CERT_RSA_SIGN, "a", "b"); + store.SetOriginBoundCert( + "https://www.verisign.com/", CLIENT_CERT_ECDSA_SIGN, "c", "d"); EXPECT_EQ(1, store.GetCertCount()); EXPECT_TRUE(store.GetOriginBoundCert("https://www.verisign.com/", - &private_key, - &cert)); + &type, + &private_key, + &cert)); + EXPECT_EQ(CLIENT_CERT_ECDSA_SIGN, type); EXPECT_EQ("c", private_key); EXPECT_EQ("d", cert); } @@ -130,9 +142,12 @@ TEST(DefaultOriginBoundCertStoreTest, TestDeleteAll) { DefaultOriginBoundCertStore store(persistent_store.get()); EXPECT_EQ(0, store.GetCertCount()); - store.SetOriginBoundCert("https://www.verisign.com/", "a", "b"); - store.SetOriginBoundCert("https://www.google.com/", "c", "d"); - store.SetOriginBoundCert("https://www.harvard.com/", "e", "f"); + store.SetOriginBoundCert( + "https://www.verisign.com/", CLIENT_CERT_RSA_SIGN, "a", "b"); + store.SetOriginBoundCert( + "https://www.google.com/", CLIENT_CERT_RSA_SIGN, "c", "d"); + store.SetOriginBoundCert( + "https://www.harvard.com/", CLIENT_CERT_RSA_SIGN, "e", "f"); EXPECT_EQ(3, store.GetCertCount()); store.DeleteAll(); @@ -143,25 +158,31 @@ TEST(DefaultOriginBoundCertStoreTest, TestDelete) { scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); DefaultOriginBoundCertStore store(persistent_store.get()); + SSLClientCertType type; std::string private_key, cert; EXPECT_EQ(0, store.GetCertCount()); - store.SetOriginBoundCert("https://www.verisign.com/", "a", "b"); - store.SetOriginBoundCert("https://www.google.com/", "c", "d"); + store.SetOriginBoundCert( + "https://www.verisign.com/", CLIENT_CERT_RSA_SIGN, "a", "b"); + store.SetOriginBoundCert( + "https://www.google.com/", CLIENT_CERT_ECDSA_SIGN, "c", "d"); EXPECT_EQ(2, store.GetCertCount()); store.DeleteOriginBoundCert("https://www.verisign.com/"); EXPECT_EQ(1, store.GetCertCount()); EXPECT_FALSE(store.GetOriginBoundCert("https://www.verisign.com/", - &private_key, - &cert)); - EXPECT_TRUE(store.GetOriginBoundCert("https://www.google.com/", + &type, &private_key, &cert)); + EXPECT_TRUE(store.GetOriginBoundCert("https://www.google.com/", + &type, + &private_key, + &cert)); store.DeleteOriginBoundCert("https://www.google.com/"); EXPECT_EQ(0, store.GetCertCount()); EXPECT_FALSE(store.GetOriginBoundCert("https://www.google.com/", - &private_key, - &cert)); + &type, + &private_key, + &cert)); } TEST(DefaultOriginBoundCertStoreTest, TestGetAll) { @@ -169,13 +190,17 @@ TEST(DefaultOriginBoundCertStoreTest, TestGetAll) { DefaultOriginBoundCertStore store(persistent_store.get()); EXPECT_EQ(0, store.GetCertCount()); - store.SetOriginBoundCert("https://www.verisign.com/", "a", "b"); - store.SetOriginBoundCert("https://www.google.com/", "c", "d"); - store.SetOriginBoundCert("https://www.harvard.com/", "e", "f"); - store.SetOriginBoundCert("https://www.mit.com/", "g", "h"); + store.SetOriginBoundCert( + "https://www.verisign.com/", CLIENT_CERT_RSA_SIGN, "a", "b"); + store.SetOriginBoundCert( + "https://www.google.com/", CLIENT_CERT_ECDSA_SIGN, "c", "d"); + store.SetOriginBoundCert( + "https://www.harvard.com/", CLIENT_CERT_RSA_SIGN, "e", "f"); + store.SetOriginBoundCert( + "https://www.mit.com/", CLIENT_CERT_RSA_SIGN, "g", "h"); EXPECT_EQ(4, store.GetCertCount()); - std::vector<OriginBoundCertStore::OriginBoundCertInfo> certs; + std::vector<OriginBoundCertStore::OriginBoundCert> certs; store.GetAllOriginBoundCerts(&certs); EXPECT_EQ(4u, certs.size()); } diff --git a/net/base/net_error_list.h b/net/base/net_error_list.h index 430d7ae..f9734e0 100644 --- a/net/base/net_error_list.h +++ b/net/base/net_error_list.h @@ -278,6 +278,13 @@ NET_ERROR(SSL_BAD_PEER_PUBLIC_KEY, -149) // one of a set of public keys exist on the path from the leaf to the root. NET_ERROR(SSL_PINNED_KEY_NOT_IN_CERT_CHAIN, -150) +// Server request for client certificate did not contain any types we support. +NET_ERROR(CLIENT_AUTH_CERT_TYPE_UNSUPPORTED, -151) + +// Server requested one type of cert, then requested a different type while the +// first was still being generated. +NET_ERROR(ORIGIN_BOUND_CERT_GENERATION_TYPE_MISMATCH, -152) + // Certificate error codes // // The values of certificate error codes must be consecutive. diff --git a/net/base/origin_bound_cert_service.cc b/net/base/origin_bound_cert_service.cc index 4d1af28..f86d82c 100644 --- a/net/base/origin_bound_cert_service.cc +++ b/net/base/origin_bound_cert_service.cc @@ -4,6 +4,7 @@ #include "net/base/origin_bound_cert_service.h" +#include <algorithm> #include <limits> #include "base/compiler_specific.h" @@ -15,6 +16,7 @@ #include "base/rand_util.h" #include "base/stl_util.h" #include "base/threading/worker_pool.h" +#include "crypto/ec_private_key.h" #include "crypto/rsa_private_key.h" #include "net/base/net_errors.h" #include "net/base/origin_bound_cert_store.h" @@ -32,15 +34,27 @@ namespace { const int kKeySizeInBits = 1024; const int kValidityPeriodInDays = 365; +bool IsSupportedCertType(uint8 type) { + switch(type) { + case CLIENT_CERT_RSA_SIGN: + case CLIENT_CERT_ECDSA_SIGN: + return true; + default: + return false; + } +} + } // namespace // Represents the output and result callback of a request. class OriginBoundCertServiceRequest { public: OriginBoundCertServiceRequest(const CompletionCallback& callback, + SSLClientCertType* type, std::string* private_key, std::string* cert) : callback_(callback), + type_(type), private_key_(private_key), cert_(cert) { } @@ -48,6 +62,7 @@ class OriginBoundCertServiceRequest { // Ensures that the result callback will never be made. void Cancel() { callback_.Reset(); + type_ = NULL; private_key_ = NULL; cert_ = NULL; } @@ -55,9 +70,11 @@ class OriginBoundCertServiceRequest { // Copies the contents of |private_key| and |cert| to the caller's output // arguments and calls the callback. void Post(int error, + SSLClientCertType type, const std::string& private_key, const std::string& cert) { if (!callback_.is_null()) { + *type_ = type; *private_key_ = private_key; *cert_ = cert; callback_.Run(error); @@ -69,6 +86,7 @@ class OriginBoundCertServiceRequest { private: CompletionCallback callback_; + SSLClientCertType* type_; std::string* private_key_; std::string* cert_; }; @@ -80,8 +98,10 @@ class OriginBoundCertServiceWorker { public: OriginBoundCertServiceWorker( const std::string& origin, + SSLClientCertType type, OriginBoundCertService* origin_bound_cert_service) : origin_(origin), + type_(type), serial_number_(base::RandInt(0, std::numeric_limits<int>::max())), origin_loop_(MessageLoop::current()), origin_bound_cert_service_(origin_bound_cert_service), @@ -110,6 +130,7 @@ class OriginBoundCertServiceWorker { void Run() { // Runs on a worker thread. error_ = OriginBoundCertService::GenerateCert(origin_, + type_, serial_number_, &private_key_, &cert_); @@ -136,7 +157,7 @@ class OriginBoundCertServiceWorker { // memory leaks or worse errors. base::AutoLock locked(lock_); if (!canceled_) { - origin_bound_cert_service_->HandleResult(origin_, error_, + origin_bound_cert_service_->HandleResult(origin_, error_, type_, private_key_, cert_); } } @@ -169,6 +190,7 @@ class OriginBoundCertServiceWorker { } const std::string origin_; + const SSLClientCertType type_; // Note that serial_number_ must be initialized on a non-worker thread // (see documentation for OriginBoundCertService::GenerateCert). uint32 serial_number_; @@ -195,8 +217,9 @@ class OriginBoundCertServiceWorker { // origin message loop. class OriginBoundCertServiceJob { public: - explicit OriginBoundCertServiceJob(OriginBoundCertServiceWorker* worker) - : worker_(worker) { + OriginBoundCertServiceJob(OriginBoundCertServiceWorker* worker, + SSLClientCertType type) + : worker_(worker), type_(type) { } ~OriginBoundCertServiceJob() { @@ -206,19 +229,23 @@ class OriginBoundCertServiceJob { } } + SSLClientCertType type() const { return type_; } + void AddRequest(OriginBoundCertServiceRequest* request) { requests_.push_back(request); } void HandleResult(int error, + SSLClientCertType type, const std::string& private_key, const std::string& cert) { worker_ = NULL; - PostAll(error, private_key, cert); + PostAll(error, type, private_key, cert); } private: void PostAll(int error, + SSLClientCertType type, const std::string& private_key, const std::string& cert) { std::vector<OriginBoundCertServiceRequest*> requests; @@ -226,7 +253,7 @@ class OriginBoundCertServiceJob { for (std::vector<OriginBoundCertServiceRequest*>::iterator i = requests.begin(); i != requests.end(); i++) { - (*i)->Post(error, private_key, cert); + (*i)->Post(error, type, private_key, cert); // Post() causes the OriginBoundCertServiceRequest to delete itself. } } @@ -244,8 +271,12 @@ class OriginBoundCertServiceJob { std::vector<OriginBoundCertServiceRequest*> requests_; OriginBoundCertServiceWorker* worker_; + SSLClientCertType type_; }; +// static +const char OriginBoundCertService::kEPKIPassword[] = ""; + OriginBoundCertService::OriginBoundCertService( OriginBoundCertStore* origin_bound_cert_store) : origin_bound_cert_store_(origin_bound_cert_store), @@ -259,43 +290,80 @@ OriginBoundCertService::~OriginBoundCertService() { int OriginBoundCertService::GetOriginBoundCert( const std::string& origin, + const std::vector<uint8>& requested_types, + SSLClientCertType* type, std::string* private_key, std::string* cert, const CompletionCallback& callback, RequestHandle* out_req) { DCHECK(CalledOnValidThread()); - if (callback.is_null() || !private_key || !cert || origin.empty()) { + if (callback.is_null() || !private_key || !cert || origin.empty() || + requested_types.empty()) { *out_req = NULL; return ERR_INVALID_ARGUMENT; } + SSLClientCertType preferred_type = CLIENT_CERT_INVALID_TYPE; + for (size_t i = 0; i < requested_types.size(); ++i) { + if (IsSupportedCertType(requested_types[i])) { + preferred_type = static_cast<SSLClientCertType>(requested_types[i]); + break; + } + } + if (preferred_type == CLIENT_CERT_INVALID_TYPE) { + // None of the requested types are supported. + *out_req = NULL; + return ERR_CLIENT_AUTH_CERT_TYPE_UNSUPPORTED; + } + requests_++; - // Check if an origin bound cert already exists for this origin. + // Check if an origin bound cert of an acceptable type already exists for this + // origin. if (origin_bound_cert_store_->GetOriginBoundCert(origin, + type, private_key, cert)) { - cert_store_hits_++; - *out_req = NULL; - return OK; + if (IsSupportedCertType(*type) && + std::find(requested_types.begin(), requested_types.end(), *type) != + requested_types.end()) { + cert_store_hits_++; + *out_req = NULL; + return OK; + } + DVLOG(1) << "Cert store had cert of wrong type " << *type << " for " + << origin; } // |origin_bound_cert_store_| has no cert for this origin. See if an // identical request is currently in flight. - OriginBoundCertServiceJob* job; + OriginBoundCertServiceJob* job = NULL; std::map<std::string, OriginBoundCertServiceJob*>::const_iterator j; j = inflight_.find(origin); if (j != inflight_.end()) { // An identical request is in flight already. We'll just attach our // callback. - inflight_joins_++; job = j->second; + // Check that the job is for an acceptable type of cert. + if (std::find(requested_types.begin(), requested_types.end(), job->type()) + == requested_types.end()) { + DVLOG(1) << "Found inflight job of wrong type " << job->type() + << " for " << origin; + *out_req = NULL; + // If we get here, the server is asking for different types of certs in + // short succession. This probably means the server is broken or + // misconfigured. Since we only store one type of cert per origin, we + // are unable to handle this well. Just return an error and let the first + // job finish. + return ERR_ORIGIN_BOUND_CERT_GENERATION_TYPE_MISMATCH; + } + inflight_joins_++; } else { // Need to make a new request. OriginBoundCertServiceWorker* worker = - new OriginBoundCertServiceWorker(origin, this); - job = new OriginBoundCertServiceJob(worker); + new OriginBoundCertServiceWorker(origin, preferred_type, this); + job = new OriginBoundCertServiceJob(worker, preferred_type); if (!worker->Start()) { delete job; delete worker; @@ -308,7 +376,7 @@ int OriginBoundCertService::GetOriginBoundCert( } OriginBoundCertServiceRequest* request = - new OriginBoundCertServiceRequest(callback, private_key, cert); + new OriginBoundCertServiceRequest(callback, type, private_key, cert); job->AddRequest(request); *out_req = request; return ERR_IO_PENDING; @@ -316,31 +384,64 @@ int OriginBoundCertService::GetOriginBoundCert( // static int OriginBoundCertService::GenerateCert(const std::string& origin, + SSLClientCertType type, uint32 serial_number, std::string* private_key, std::string* cert) { - scoped_ptr<crypto::RSAPrivateKey> key( - crypto::RSAPrivateKey::Create(kKeySizeInBits)); - if (!key.get()) { - LOG(WARNING) << "Unable to create key pair for client"; - return ERR_KEY_GENERATION_FAILED; - } std::string der_cert; - if (!x509_util::CreateOriginBoundCertRSA( - key.get(), - origin, - serial_number, - base::TimeDelta::FromDays(kValidityPeriodInDays), - &der_cert)) { - LOG(WARNING) << "Unable to create x509 cert for client"; - return ERR_ORIGIN_BOUND_CERT_GENERATION_FAILED; - } - std::vector<uint8> private_key_info; - if (!key->ExportPrivateKey(&private_key_info)) { - LOG(WARNING) << "Unable to export private key"; - return ERR_PRIVATE_KEY_EXPORT_FAILED; + switch (type) { + case CLIENT_CERT_RSA_SIGN: { + scoped_ptr<crypto::RSAPrivateKey> key( + crypto::RSAPrivateKey::Create(kKeySizeInBits)); + if (!key.get()) { + DLOG(ERROR) << "Unable to create key pair for client"; + return ERR_KEY_GENERATION_FAILED; + } + if (!x509_util::CreateOriginBoundCertRSA( + key.get(), + origin, + serial_number, + base::TimeDelta::FromDays(kValidityPeriodInDays), + &der_cert)) { + DLOG(ERROR) << "Unable to create x509 cert for client"; + return ERR_ORIGIN_BOUND_CERT_GENERATION_FAILED; + } + + if (!key->ExportPrivateKey(&private_key_info)) { + DLOG(ERROR) << "Unable to export private key"; + return ERR_PRIVATE_KEY_EXPORT_FAILED; + } + break; + } + case CLIENT_CERT_ECDSA_SIGN: { + scoped_ptr<crypto::ECPrivateKey> key(crypto::ECPrivateKey::Create()); + if (!key.get()) { + DLOG(ERROR) << "Unable to create key pair for client"; + return ERR_KEY_GENERATION_FAILED; + } + if (!x509_util::CreateOriginBoundCertEC( + key.get(), + origin, + serial_number, + base::TimeDelta::FromDays(kValidityPeriodInDays), + &der_cert)) { + DLOG(ERROR) << "Unable to create x509 cert for client"; + return ERR_ORIGIN_BOUND_CERT_GENERATION_FAILED; + } + + if (!key->ExportEncryptedPrivateKey( + kEPKIPassword, 1, &private_key_info)) { + DLOG(ERROR) << "Unable to export private key"; + return ERR_PRIVATE_KEY_EXPORT_FAILED; + } + break; + } + default: + NOTREACHED(); + return ERR_INVALID_ARGUMENT; } + // TODO(rkn): Perhaps ExportPrivateKey should be changed to output a // std::string* to prevent this copying. std::string key_out(private_key_info.begin(), private_key_info.end()); @@ -360,12 +461,13 @@ void OriginBoundCertService::CancelRequest(RequestHandle req) { // HandleResult is called by OriginBoundCertServiceWorker on the origin message // loop. It deletes OriginBoundCertServiceJob. void OriginBoundCertService::HandleResult(const std::string& origin, - int error, - const std::string& private_key, - const std::string& cert) { + int error, + SSLClientCertType type, + const std::string& private_key, + const std::string& cert) { DCHECK(CalledOnValidThread()); - origin_bound_cert_store_->SetOriginBoundCert(origin, private_key, cert); + origin_bound_cert_store_->SetOriginBoundCert(origin, type, private_key, cert); std::map<std::string, OriginBoundCertServiceJob*>::iterator j; j = inflight_.find(origin); @@ -376,7 +478,7 @@ void OriginBoundCertService::HandleResult(const std::string& origin, OriginBoundCertServiceJob* job = j->second; inflight_.erase(j); - job->HandleResult(error, private_key, cert); + job->HandleResult(error, type, private_key, cert); delete job; } diff --git a/net/base/origin_bound_cert_service.h b/net/base/origin_bound_cert_service.h index 861602f..c3861e6 100644 --- a/net/base/origin_bound_cert_service.h +++ b/net/base/origin_bound_cert_service.h @@ -8,12 +8,14 @@ #include <map> #include <string> +#include <vector> #include "base/basictypes.h" #include "base/memory/scoped_ptr.h" #include "base/threading/non_thread_safe.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" +#include "net/base/ssl_client_cert_type.h" namespace net { @@ -30,20 +32,28 @@ class NET_EXPORT OriginBoundCertService // Opaque type used to cancel a request. typedef void* RequestHandle; + // Password used on EncryptedPrivateKeyInfo data stored in EC private_key + // values. (This is not used to provide any security, but to workaround NSS + // being unable to import unencrypted PrivateKeyInfo for EC keys.) + static const char kEPKIPassword[]; + // This object owns origin_bound_cert_store. explicit OriginBoundCertService( OriginBoundCertStore* origin_bound_cert_store); ~OriginBoundCertService(); - // TODO(rkn): Specify certificate type (RSA or DSA). + // Fetches the origin bound cert for the specified origin of the specified + // type if one exists and creates one otherwise. Returns OK if successful or + // an error code upon failure. // - // Fetches the origin bound cert for the specified origin if one exists - // and creates one otherwise. Returns OK if successful or an error code upon - // failure. + // |requested_types| is a list of the TLS ClientCertificateTypes the site will + // accept, ordered from most preferred to least preferred. Types we don't + // support will be ignored. See ssl_client_cert_type.h. // // On successful completion, |private_key| stores a DER-encoded - // PrivateKeyInfo struct, and |cert| stores a DER-encoded certificate. + // PrivateKeyInfo struct, and |cert| stores a DER-encoded certificate, and + // |type| specifies the type of certificate that was returned. // // |callback| must not be null. ERR_IO_PENDING is returned if the operation // could not be completed immediately, in which case the result code will @@ -52,11 +62,14 @@ class NET_EXPORT OriginBoundCertService // If |out_req| is non-NULL, then |*out_req| will be filled with a handle to // the async request. This handle is not valid after the request has // completed. - int GetOriginBoundCert(const std::string& origin, - std::string* private_key, - std::string* cert, - const CompletionCallback& callback, - RequestHandle* out_req); + int GetOriginBoundCert( + const std::string& origin, + const std::vector<uint8>& requested_types, + SSLClientCertType* type, + std::string* private_key, + std::string* cert, + const CompletionCallback& callback, + RequestHandle* out_req); // Cancels the specified request. |req| is the handle returned by // GetOriginBoundCert(). After a request is canceled, its completion @@ -79,12 +92,14 @@ class NET_EXPORT OriginBoundCertService // base::RandInt, which opens the file /dev/urandom. /dev/urandom is opened // with a LazyInstance, which is not allowed on a worker thread. static int GenerateCert(const std::string& origin, + SSLClientCertType type, uint32 serial_number, std::string* private_key, std::string* cert); void HandleResult(const std::string& origin, int error, + SSLClientCertType type, const std::string& private_key, const std::string& cert); diff --git a/net/base/origin_bound_cert_service_unittest.cc b/net/base/origin_bound_cert_service_unittest.cc index 3fdb443..1adedfa 100644 --- a/net/base/origin_bound_cert_service_unittest.cc +++ b/net/base/origin_bound_cert_service_unittest.cc @@ -9,7 +9,9 @@ #include "base/bind.h" #include "base/memory/scoped_ptr.h" +#include "crypto/ec_private_key.h" #include "crypto/rsa_private_key.h" +#include "net/base/asn1_util.h" #include "net/base/default_origin_bound_cert_store.h" #include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" @@ -33,53 +35,178 @@ TEST(OriginBoundCertServiceTest, CacheHit) { std::string origin("https://encrypted.google.com:443"); int error; + std::vector<uint8> types; + types.push_back(CLIENT_CERT_RSA_SIGN); TestCompletionCallback callback; OriginBoundCertService::RequestHandle request_handle; // Asynchronous completion. + SSLClientCertType type1; std::string private_key_info1, der_cert1; EXPECT_EQ(0, service->cert_count()); error = service->GetOriginBoundCert( - origin, &private_key_info1, &der_cert1, callback.callback(), - &request_handle); + origin, types, &type1, &private_key_info1, &der_cert1, + callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle != NULL); error = callback.WaitForResult(); EXPECT_EQ(OK, error); EXPECT_EQ(1, service->cert_count()); + EXPECT_EQ(CLIENT_CERT_RSA_SIGN, type1); EXPECT_FALSE(private_key_info1.empty()); EXPECT_FALSE(der_cert1.empty()); // Synchronous completion. + SSLClientCertType type2; + // If we request EC and RSA, should still retrieve the RSA cert. + types.insert(types.begin(), CLIENT_CERT_ECDSA_SIGN); std::string private_key_info2, der_cert2; error = service->GetOriginBoundCert( - origin, &private_key_info2, &der_cert2, callback.callback(), - &request_handle); + origin, types, &type2, &private_key_info2, &der_cert2, + callback.callback(), &request_handle); EXPECT_TRUE(request_handle == NULL); EXPECT_EQ(OK, error); EXPECT_EQ(1, service->cert_count()); - + EXPECT_EQ(CLIENT_CERT_RSA_SIGN, type2); EXPECT_EQ(private_key_info1, private_key_info2); EXPECT_EQ(der_cert1, der_cert2); - EXPECT_EQ(2u, service->requests()); - EXPECT_EQ(1u, service->cert_store_hits()); + // Request only EC. Should generate a new EC cert and discard the old RSA + // cert. + SSLClientCertType type3; + types.pop_back(); // Remove CLIENT_CERT_RSA_SIGN from requested types. + std::string private_key_info3, der_cert3; + EXPECT_EQ(1, service->cert_count()); + error = service->GetOriginBoundCert( + origin, types, &type3, &private_key_info3, &der_cert3, + callback.callback(), &request_handle); + EXPECT_EQ(ERR_IO_PENDING, error); + EXPECT_TRUE(request_handle != NULL); + error = callback.WaitForResult(); + EXPECT_EQ(OK, error); + EXPECT_EQ(1, service->cert_count()); + EXPECT_EQ(CLIENT_CERT_ECDSA_SIGN, type3); + EXPECT_FALSE(private_key_info1.empty()); + EXPECT_FALSE(der_cert1.empty()); + EXPECT_NE(private_key_info1, private_key_info3); + EXPECT_NE(der_cert1, der_cert3); + + // Synchronous completion. + // If we request RSA and EC, should now retrieve the EC cert. + SSLClientCertType type4; + types.insert(types.begin(), CLIENT_CERT_RSA_SIGN); + std::string private_key_info4, der_cert4; + error = service->GetOriginBoundCert( + origin, types, &type4, &private_key_info4, &der_cert4, + callback.callback(), &request_handle); + EXPECT_TRUE(request_handle == NULL); + EXPECT_EQ(OK, error); + EXPECT_EQ(1, service->cert_count()); + EXPECT_EQ(CLIENT_CERT_ECDSA_SIGN, type4); + EXPECT_EQ(private_key_info3, private_key_info4); + EXPECT_EQ(der_cert3, der_cert4); + + EXPECT_EQ(4u, service->requests()); + EXPECT_EQ(2u, service->cert_store_hits()); EXPECT_EQ(0u, service->inflight_joins()); } +TEST(OriginBoundCertServiceTest, UnsupportedTypes) { + scoped_ptr<OriginBoundCertService> service( + new OriginBoundCertService(new DefaultOriginBoundCertStore(NULL))); + std::string origin("https://encrypted.google.com:443"); + + int error; + std::vector<uint8> types; + TestCompletionCallback callback; + OriginBoundCertService::RequestHandle request_handle; + + // Empty requested_types. + SSLClientCertType type1; + std::string private_key_info1, der_cert1; + error = service->GetOriginBoundCert( + origin, types, &type1, &private_key_info1, &der_cert1, + callback.callback(), &request_handle); + EXPECT_EQ(ERR_INVALID_ARGUMENT, error); + EXPECT_TRUE(request_handle == NULL); + + // No supported types in requested_types. + types.push_back(2); + types.push_back(3); + error = service->GetOriginBoundCert( + origin, types, &type1, &private_key_info1, &der_cert1, + callback.callback(), &request_handle); + EXPECT_EQ(ERR_CLIENT_AUTH_CERT_TYPE_UNSUPPORTED, error); + EXPECT_TRUE(request_handle == NULL); + + // Supported types after unsupported ones in requested_types. + types.push_back(CLIENT_CERT_ECDSA_SIGN); + types.push_back(CLIENT_CERT_RSA_SIGN); + // Asynchronous completion. + EXPECT_EQ(0, service->cert_count()); + error = service->GetOriginBoundCert( + origin, types, &type1, &private_key_info1, &der_cert1, + callback.callback(), &request_handle); + EXPECT_EQ(ERR_IO_PENDING, error); + EXPECT_TRUE(request_handle != NULL); + error = callback.WaitForResult(); + EXPECT_EQ(OK, error); + EXPECT_EQ(1, service->cert_count()); + EXPECT_EQ(CLIENT_CERT_ECDSA_SIGN, type1); + EXPECT_FALSE(private_key_info1.empty()); + EXPECT_FALSE(der_cert1.empty()); + + // Now that the cert is created, doing requests for unsupported types + // shouldn't affect the created cert. + // Empty requested_types. + types.clear(); + SSLClientCertType type2; + std::string private_key_info2, der_cert2; + error = service->GetOriginBoundCert( + origin, types, &type2, &private_key_info2, &der_cert2, + callback.callback(), &request_handle); + EXPECT_EQ(ERR_INVALID_ARGUMENT, error); + EXPECT_TRUE(request_handle == NULL); + + // No supported types in requested_types. + types.push_back(2); + types.push_back(3); + error = service->GetOriginBoundCert( + origin, types, &type2, &private_key_info2, &der_cert2, + callback.callback(), &request_handle); + EXPECT_EQ(ERR_CLIENT_AUTH_CERT_TYPE_UNSUPPORTED, error); + EXPECT_TRUE(request_handle == NULL); + + // If we request EC, the cert we created before should still be there. + types.push_back(CLIENT_CERT_RSA_SIGN); + types.push_back(CLIENT_CERT_ECDSA_SIGN); + error = service->GetOriginBoundCert( + origin, types, &type2, &private_key_info2, &der_cert2, + callback.callback(), &request_handle); + EXPECT_TRUE(request_handle == NULL); + EXPECT_EQ(OK, error); + EXPECT_EQ(1, service->cert_count()); + EXPECT_EQ(CLIENT_CERT_ECDSA_SIGN, type2); + EXPECT_EQ(private_key_info1, private_key_info2); + EXPECT_EQ(der_cert1, der_cert2); +} + TEST(OriginBoundCertServiceTest, StoreCerts) { scoped_ptr<OriginBoundCertService> service( new OriginBoundCertService(new DefaultOriginBoundCertStore(NULL))); int error; + std::vector<uint8> types; + types.push_back(CLIENT_CERT_RSA_SIGN); TestCompletionCallback callback; OriginBoundCertService::RequestHandle request_handle; std::string origin1("https://encrypted.google.com:443"); + SSLClientCertType type1; std::string private_key_info1, der_cert1; EXPECT_EQ(0, service->cert_count()); error = service->GetOriginBoundCert( - origin1, &private_key_info1, &der_cert1, callback.callback(), - &request_handle); + origin1, types, &type1, &private_key_info1, &der_cert1, + callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle != NULL); error = callback.WaitForResult(); @@ -87,10 +214,11 @@ TEST(OriginBoundCertServiceTest, StoreCerts) { EXPECT_EQ(1, service->cert_count()); std::string origin2("https://www.verisign.com:443"); + SSLClientCertType type2; std::string private_key_info2, der_cert2; error = service->GetOriginBoundCert( - origin2, &private_key_info2, &der_cert2, callback.callback(), - &request_handle); + origin2, types, &type2, &private_key_info2, &der_cert2, + callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle != NULL); error = callback.WaitForResult(); @@ -98,10 +226,12 @@ TEST(OriginBoundCertServiceTest, StoreCerts) { EXPECT_EQ(2, service->cert_count()); std::string origin3("https://www.twitter.com:443"); + SSLClientCertType type3; std::string private_key_info3, der_cert3; + types[0] = CLIENT_CERT_ECDSA_SIGN; error = service->GetOriginBoundCert( - origin3, &private_key_info3, &der_cert3, callback.callback(), - &request_handle); + origin3, types, &type3, &private_key_info3, &der_cert3, + callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle != NULL); error = callback.WaitForResult(); @@ -114,6 +244,9 @@ TEST(OriginBoundCertServiceTest, StoreCerts) { EXPECT_NE(der_cert1, der_cert3); EXPECT_NE(private_key_info2, private_key_info3); EXPECT_NE(der_cert2, der_cert3); + EXPECT_EQ(CLIENT_CERT_RSA_SIGN, type1); + EXPECT_EQ(CLIENT_CERT_RSA_SIGN, type2); + EXPECT_EQ(CLIENT_CERT_ECDSA_SIGN, type3); } // Tests an inflight join. @@ -122,23 +255,30 @@ TEST(OriginBoundCertServiceTest, InflightJoin) { new OriginBoundCertService(new DefaultOriginBoundCertStore(NULL))); std::string origin("https://encrypted.google.com:443"); int error; + std::vector<uint8> types; + types.push_back(CLIENT_CERT_RSA_SIGN); + SSLClientCertType type1; std::string private_key_info1, der_cert1; TestCompletionCallback callback1; OriginBoundCertService::RequestHandle request_handle1; + SSLClientCertType type2; std::string private_key_info2, der_cert2; TestCompletionCallback callback2; OriginBoundCertService::RequestHandle request_handle2; error = service->GetOriginBoundCert( - origin, &private_key_info1, &der_cert1, callback1.callback(), - &request_handle1); + origin, types, &type1, &private_key_info1, &der_cert1, + callback1.callback(), &request_handle1); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle1 != NULL); + // If we request EC and RSA in the 2nd request, should still join with the + // original request. + types.insert(types.begin(), CLIENT_CERT_ECDSA_SIGN); error = service->GetOriginBoundCert( - origin, &private_key_info2, &der_cert2, callback2.callback(), - &request_handle2); + origin, types, &type2, &private_key_info2, &der_cert2, + callback2.callback(), &request_handle2); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle2 != NULL); @@ -147,22 +287,73 @@ TEST(OriginBoundCertServiceTest, InflightJoin) { error = callback2.WaitForResult(); EXPECT_EQ(OK, error); + EXPECT_EQ(CLIENT_CERT_RSA_SIGN, type1); + EXPECT_EQ(CLIENT_CERT_RSA_SIGN, type2); EXPECT_EQ(2u, service->requests()); EXPECT_EQ(0u, service->cert_store_hits()); EXPECT_EQ(1u, service->inflight_joins()); } -TEST(OriginBoundCertServiceTest, ExtractValuesFromBytes) { +// Tests an inflight join with mismatching request types. +TEST(OriginBoundCertServiceTest, InflightJoinTypeMismatch) { scoped_ptr<OriginBoundCertService> service( new OriginBoundCertService(new DefaultOriginBoundCertStore(NULL))); std::string origin("https://encrypted.google.com:443"); + int error; + std::vector<uint8> types1; + types1.push_back(CLIENT_CERT_RSA_SIGN); + std::vector<uint8> types2; + types2.push_back(CLIENT_CERT_ECDSA_SIGN); + + SSLClientCertType type1; + std::string private_key_info1, der_cert1; + TestCompletionCallback callback1; + OriginBoundCertService::RequestHandle request_handle1; + + SSLClientCertType type2; + std::string private_key_info2, der_cert2; + TestCompletionCallback callback2; + OriginBoundCertService::RequestHandle request_handle2; + + error = service->GetOriginBoundCert( + origin, types1, &type1, &private_key_info1, &der_cert1, + callback1.callback(), &request_handle1); + EXPECT_EQ(ERR_IO_PENDING, error); + EXPECT_TRUE(request_handle1 != NULL); + // If we request only EC in the 2nd request, it should return an error. + error = service->GetOriginBoundCert( + origin, types2, &type2, &private_key_info2, &der_cert2, + callback2.callback(), &request_handle2); + EXPECT_EQ(ERR_ORIGIN_BOUND_CERT_GENERATION_TYPE_MISMATCH, error); + EXPECT_TRUE(request_handle2 == NULL); + + error = callback1.WaitForResult(); + EXPECT_EQ(OK, error); + + EXPECT_FALSE(private_key_info1.empty()); + EXPECT_FALSE(der_cert1.empty()); + EXPECT_TRUE(private_key_info2.empty()); + EXPECT_TRUE(der_cert2.empty()); + EXPECT_EQ(CLIENT_CERT_RSA_SIGN, type1); + EXPECT_EQ(2u, service->requests()); + EXPECT_EQ(0u, service->cert_store_hits()); + EXPECT_EQ(0u, service->inflight_joins()); +} + +TEST(OriginBoundCertServiceTest, ExtractValuesFromBytesRSA) { + scoped_ptr<OriginBoundCertService> service( + new OriginBoundCertService(new DefaultOriginBoundCertStore(NULL))); + std::string origin("https://encrypted.google.com:443"); + SSLClientCertType type; std::string private_key_info, der_cert; int error; + std::vector<uint8> types; + types.push_back(CLIENT_CERT_RSA_SIGN); TestCompletionCallback callback; OriginBoundCertService::RequestHandle request_handle; error = service->GetOriginBoundCert( - origin, &private_key_info, &der_cert, callback.callback(), + origin, types, &type, &private_key_info, &der_cert, callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle != NULL); @@ -181,16 +372,60 @@ TEST(OriginBoundCertServiceTest, ExtractValuesFromBytes) { EXPECT_TRUE(x509cert != NULL); } +TEST(OriginBoundCertServiceTest, ExtractValuesFromBytesEC) { + scoped_ptr<OriginBoundCertService> service( + new OriginBoundCertService(new DefaultOriginBoundCertStore(NULL))); + std::string origin("https://encrypted.google.com:443"); + SSLClientCertType type; + std::string private_key_info, der_cert; + int error; + std::vector<uint8> types; + types.push_back(CLIENT_CERT_ECDSA_SIGN); + TestCompletionCallback callback; + OriginBoundCertService::RequestHandle request_handle; + + error = service->GetOriginBoundCert( + origin, types, &type, &private_key_info, &der_cert, callback.callback(), + &request_handle); + EXPECT_EQ(ERR_IO_PENDING, error); + EXPECT_TRUE(request_handle != NULL); + error = callback.WaitForResult(); + EXPECT_EQ(OK, error); + + base::StringPiece spki_piece; + ASSERT_TRUE(asn1::ExtractSPKIFromDERCert(der_cert, &spki_piece)); + std::vector<uint8> spki( + spki_piece.data(), + spki_piece.data() + spki_piece.size()); + + // Check that we can retrieve the key from the bytes. + std::vector<uint8> key_vec(private_key_info.begin(), private_key_info.end()); + scoped_ptr<crypto::ECPrivateKey> private_key( + crypto::ECPrivateKey::CreateFromEncryptedPrivateKeyInfo( + OriginBoundCertService::kEPKIPassword, key_vec, spki)); + EXPECT_TRUE(private_key != NULL); + + // Check that we can retrieve the cert from the bytes. + scoped_refptr<X509Certificate> x509cert( + X509Certificate::CreateFromBytes(der_cert.data(), der_cert.size())); + EXPECT_TRUE(x509cert != NULL); +} + // Tests that the callback of a canceled request is never made. TEST(OriginBoundCertServiceTest, CancelRequest) { scoped_ptr<OriginBoundCertService> service( new OriginBoundCertService(new DefaultOriginBoundCertStore(NULL))); std::string origin("https://encrypted.google.com:443"); + SSLClientCertType type; std::string private_key_info, der_cert; int error; + std::vector<uint8> types; + types.push_back(CLIENT_CERT_RSA_SIGN); OriginBoundCertService::RequestHandle request_handle; error = service->GetOriginBoundCert(origin, + types, + &type, &private_key_info, &der_cert, base::Bind(&FailTest), @@ -206,6 +441,8 @@ TEST(OriginBoundCertServiceTest, CancelRequest) { for (int i = 0; i < 5; ++i) { error = service->GetOriginBoundCert( "https://encrypted.google.com:" + std::string(1, (char) ('1' + i)), + types, + &type, &private_key_info, &der_cert, callback.callback(), @@ -214,6 +451,10 @@ TEST(OriginBoundCertServiceTest, CancelRequest) { EXPECT_TRUE(request_handle != NULL); error = callback.WaitForResult(); } + + // Even though the original request was cancelled, the service will still + // store the result, it just doesn't call the callback. + EXPECT_EQ(6, service->cert_count()); } #endif // !defined(USE_OPENSSL) diff --git a/net/base/origin_bound_cert_store.cc b/net/base/origin_bound_cert_store.cc new file mode 100644 index 0000000..16d054a --- /dev/null +++ b/net/base/origin_bound_cert_store.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/base/origin_bound_cert_store.h" + +namespace net { + +OriginBoundCertStore::OriginBoundCert::OriginBoundCert() {} + +OriginBoundCertStore::OriginBoundCert::OriginBoundCert( + const std::string& origin, + SSLClientCertType type, + const std::string& private_key, + const std::string& cert) + : origin_(origin), + type_(type), + private_key_(private_key), + cert_(cert) {} + +OriginBoundCertStore::OriginBoundCert::~OriginBoundCert() {} + +} // namespace net diff --git a/net/base/origin_bound_cert_store.h b/net/base/origin_bound_cert_store.h index 4cb1132..094839b 100644 --- a/net/base/origin_bound_cert_store.h +++ b/net/base/origin_bound_cert_store.h @@ -10,6 +10,7 @@ #include <vector> #include "net/base/net_export.h" +#include "net/base/ssl_client_cert_type.h" namespace net { @@ -22,30 +23,55 @@ namespace net { class NET_EXPORT OriginBoundCertStore { public: - // Used by GetAllOriginBoundCerts. - struct OriginBoundCertInfo { - std::string origin; // Origin, for instance "https://www.verisign.com:443". - std::string private_key; // DER-encoded PrivateKeyInfo struct. - std::string cert; // DER-encoded certificate. + // The OriginBoundCert class contains a private key in addition to the origin + // cert, and cert type. + class NET_EXPORT OriginBoundCert { + public: + OriginBoundCert(); + OriginBoundCert(const std::string& origin, + SSLClientCertType type, + const std::string& private_key, + const std::string& cert); + ~OriginBoundCert(); + + // Origin, for instance "https://www.verisign.com:443" + const std::string& origin() const { return origin_; } + // TLS ClientCertificateType. + SSLClientCertType type() const { return type_; } + // The encoding of the private key depends on the type. + // rsa_sign: DER-encoded PrivateKeyInfo struct. + // ecdsa_sign: DER-encoded EncryptedPrivateKeyInfo struct. + const std::string& private_key() const { return private_key_; } + // DER-encoded certificate. + const std::string& cert() const { return cert_; } + + private: + std::string origin_; + SSLClientCertType type_; + std::string private_key_; + std::string cert_; }; virtual ~OriginBoundCertStore() {} - // TODO(rkn): Specify certificate type (RSA or DSA). // TODO(rkn): File I/O may be required, so this should have an asynchronous // interface. // Returns true on success. |private_key_result| stores a DER-encoded // PrivateKeyInfo struct and |cert_result| stores a DER-encoded // certificate. Returns false if no origin bound cert exists for the // specified origin. - virtual bool GetOriginBoundCert(const std::string& origin, - std::string* private_key_result, - std::string* cert_result) = 0; + virtual bool GetOriginBoundCert( + const std::string& origin, + SSLClientCertType* type, + std::string* private_key_result, + std::string* cert_result) = 0; // Adds an origin bound cert and the corresponding private key to the store. - virtual void SetOriginBoundCert(const std::string& origin, - const std::string& private_key, - const std::string& cert) = 0; + virtual void SetOriginBoundCert( + const std::string& origin, + SSLClientCertType type, + const std::string& private_key, + const std::string& cert) = 0; // Removes an origin bound cert and the corresponding private key from the // store. @@ -57,7 +83,7 @@ class NET_EXPORT OriginBoundCertStore { // Returns all origin bound certs and the corresponding private keys. virtual void GetAllOriginBoundCerts( - std::vector<OriginBoundCertInfo>* origin_bound_certs) = 0; + std::vector<OriginBoundCert>* origin_bound_certs) = 0; // Returns the number of certs in the store. // Public only for unit testing. diff --git a/net/base/ssl_client_cert_type.h b/net/base/ssl_client_cert_type.h new file mode 100644 index 0000000..921a6f4 --- /dev/null +++ b/net/base/ssl_client_cert_type.h @@ -0,0 +1,22 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_BASE_SSL_CLIENT_CERT_TYPE_H_ +#define NET_BASE_SSL_CLIENT_CERT_TYPE_H_ +#pragma once + +namespace net { + +// TLS ClientCertificateType Identifiers +// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-1 +enum SSLClientCertType { + CLIENT_CERT_RSA_SIGN = 1, + CLIENT_CERT_ECDSA_SIGN = 64, + // 224-255 are Reserved for Private Use, we pick one to use as "invalid". + CLIENT_CERT_INVALID_TYPE = 255, +}; + +} // namespace net + +#endif // NET_BASE_SSL_CLIENT_CERT_TYPE_H_ diff --git a/net/net.gyp b/net/net.gyp index 777f0f6..7c897ce 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -190,6 +190,7 @@ 'base/openssl_private_key_store_android.cc', 'base/origin_bound_cert_service.cc', 'base/origin_bound_cert_service.h', + 'base/origin_bound_cert_store.cc', 'base/origin_bound_cert_store.h', 'base/pem_tokenizer.cc', 'base/pem_tokenizer.h', @@ -213,6 +214,7 @@ 'base/ssl_cipher_suite_names.h', 'base/ssl_client_auth_cache.cc', 'base/ssl_client_auth_cache.h', + 'base/ssl_client_cert_type.h', 'base/ssl_config_service.cc', 'base/ssl_config_service.h', 'base/ssl_config_service_defaults.cc', diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index 16b8da2..d2991ba 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -75,6 +75,7 @@ #include "base/stringprintf.h" #include "base/threading/thread_restrictions.h" #include "base/values.h" +#include "crypto/ec_private_key.h" #include "crypto/rsa_private_key.h" #include "crypto/scoped_nss_types.h" #include "net/base/address_list.h" @@ -1549,20 +1550,48 @@ int SSLClientSocketNSS::ImportOBCertAndKey(CERTCertificate** cert, return MapNSSError(PORT_GetError()); // Set the private key. - SECItem der_private_key_info; - der_private_key_info.data = (unsigned char*)ob_private_key_.data(); - der_private_key_info.len = ob_private_key_.size(); - const unsigned int key_usage = KU_DIGITAL_SIGNATURE; - crypto::ScopedPK11Slot slot(PK11_GetInternalSlot()); - SECStatus rv = PK11_ImportDERPrivateKeyInfoAndReturnKey( - slot.get(), &der_private_key_info, NULL, NULL, PR_FALSE, PR_FALSE, - key_usage, key, NULL); + switch (ob_cert_type_) { + case CLIENT_CERT_RSA_SIGN: { + SECItem der_private_key_info; + der_private_key_info.data = (unsigned char*)ob_private_key_.data(); + der_private_key_info.len = ob_private_key_.size(); + const unsigned int key_usage = KU_DIGITAL_SIGNATURE; + crypto::ScopedPK11Slot slot(PK11_GetInternalSlot()); + SECStatus rv = PK11_ImportDERPrivateKeyInfoAndReturnKey( + slot.get(), &der_private_key_info, NULL, NULL, PR_FALSE, PR_FALSE, + key_usage, key, NULL); + + if (rv != SECSuccess) { + int error = MapNSSError(PORT_GetError()); + CERT_DestroyCertificate(*cert); + *cert = NULL; + return error; + } + break; + } - if (rv != SECSuccess) { - int error = MapNSSError(PORT_GetError()); - CERT_DestroyCertificate(*cert); - *cert = NULL; - return error; + case CLIENT_CERT_ECDSA_SIGN: { + SECKEYPublicKey* public_key = NULL; + if (!crypto::ECPrivateKey::ImportFromEncryptedPrivateKeyInfo( + OriginBoundCertService::kEPKIPassword, + reinterpret_cast<const unsigned char*>(ob_private_key_.data()), + ob_private_key_.size(), + &(*cert)->subjectPublicKeyInfo, + false, + false, + key, + &public_key)) { + CERT_DestroyCertificate(*cert); + *cert = NULL; + return MapNSSError(PORT_GetError()); + } + SECKEY_DestroyPublicKey(public_key); + break; + } + + default: + NOTREACHED(); + return ERR_INVALID_ARGUMENT; } return OK; @@ -2117,6 +2146,7 @@ bool SSLClientSocketNSS::OriginBoundCertNegotiated(PRFileDesc* socket) { } SECStatus SSLClientSocketNSS::OriginBoundClientAuthHandler( + const std::vector<uint8>& requested_cert_types, CERTCertificate** result_certificate, SECKEYPrivateKey** result_private_key) { ob_cert_xtn_negotiated_ = true; @@ -2126,6 +2156,8 @@ SECStatus SSLClientSocketNSS::OriginBoundClientAuthHandler( net_log_.BeginEvent(NetLog::TYPE_SSL_GET_ORIGIN_BOUND_CERT, NULL); int error = origin_bound_cert_service_->GetOriginBoundCert( origin, + requested_cert_types, + &ob_cert_type_, &ob_private_key_, &ob_cert_, base::Bind(&SSLClientSocketNSS::OnHandshakeIOComplete, @@ -2175,8 +2207,12 @@ SECStatus SSLClientSocketNSS::PlatformClientAuthHandler( // Check if an origin-bound certificate is requested. if (OriginBoundCertNegotiated(socket)) { + // TODO(mattm): Once NSS supports it, pass the actual requested types. + std::vector<uint8> requested_cert_types; + requested_cert_types.push_back(CLIENT_CERT_ECDSA_SIGN); + requested_cert_types.push_back(CLIENT_CERT_RSA_SIGN); return that->OriginBoundClientAuthHandler( - result_nss_certificate, result_nss_private_key); + requested_cert_types, result_nss_certificate, result_nss_private_key); } that->client_auth_cert_needed_ = !that->ssl_config_.send_client_cert; @@ -2480,8 +2516,12 @@ SECStatus SSLClientSocketNSS::ClientAuthHandler( // Check if an origin-bound certificate is requested. if (OriginBoundCertNegotiated(socket)) { + // TODO(mattm): Once NSS supports it, pass the actual requested types. + std::vector<uint8> requested_cert_types; + requested_cert_types.push_back(CLIENT_CERT_ECDSA_SIGN); + requested_cert_types.push_back(CLIENT_CERT_RSA_SIGN); return that->OriginBoundClientAuthHandler( - result_certificate, result_private_key); + requested_cert_types, result_certificate, result_private_key); } // Regular client certificate requested. diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index 0eddd76..7b56844 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -167,8 +167,10 @@ class SSLClientSocketNSS : public SSLClientSocket { static bool OriginBoundCertNegotiated(PRFileDesc* socket); // Origin bound cert client auth handler. // Returns the value the ClientAuthHandler function should return. - SECStatus OriginBoundClientAuthHandler(CERTCertificate** result_certificate, - SECKEYPrivateKey** result_private_key); + SECStatus OriginBoundClientAuthHandler( + const std::vector<uint8>& requested_cert_types, + CERTCertificate** result_certificate, + SECKEYPrivateKey** result_private_key); #if defined(NSS_PLATFORM_CLIENT_AUTH) // On platforms where we use the native certificate store, NSS calls this // instead when client authentication is requested. At most one of @@ -259,6 +261,7 @@ class SSLClientSocketNSS : public SSLClientSocket { // For origin bound certificates in client auth. bool ob_cert_xtn_negotiated_; OriginBoundCertService* origin_bound_cert_service_; + SSLClientCertType ob_cert_type_; std::string ob_private_key_; std::string ob_cert_; OriginBoundCertService::RequestHandle ob_cert_request_handle_; |