From d91f84376fe8dd249770ac19b7c08f8fcc20f446 Mon Sep 17 00:00:00 2001
From: "wtc@chromium.org"
 <wtc@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>
Date: Tue, 5 May 2009 23:55:59 +0000
Subject: Separate the key setting code in the constructor of HMAC class into
 the Init method.

Overload the Init method for char* and std::string.

Add DCHECKs to the destruction methods in ~HMAC in hmac_win.cc.

The patch is written by Takeshi Yoshino <tyoshino@google.com>.
Original code review: http://codereview.chromium.org/88062

R=wtc
http://crbug.com/2297
TEST=base_unittests should pass.  Safe browsing should continue to work.
Review URL: http://codereview.chromium.org/113001

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@15353 0039d316-1c4b-4281-b951-d872f2087c98
---
 base/hmac.h                                        | 19 +++++++---
 base/hmac_mac.cc                                   | 15 +++++++-
 base/hmac_nss.cc                                   | 34 ++++++++++++++++--
 base/hmac_unittest.cc                              | 13 +++----
 base/hmac_win.cc                                   | 41 +++++++++++++++++-----
 chrome/browser/safe_browsing/safe_browsing_util.cc |  6 ++--
 6 files changed, 103 insertions(+), 25 deletions(-)

diff --git a/base/hmac.h b/base/hmac.h
index bbfe855..bdd23a9 100644
--- a/base/hmac.h
+++ b/base/hmac.h
@@ -25,12 +25,23 @@ class HMAC {
     SHA1
   };
 
-  HMAC(HashAlgorithm hash_alg, const unsigned char* key, int key_length);
+  explicit HMAC(HashAlgorithm hash_alg);
   ~HMAC();
 
-  // Calculates the HMAC for the message in |data| using the algorithm and key
-  // supplied to the constructor.  The HMAC is returned in |digest|, which
-  // has |digest_length| bytes of storage available.
+  // Initializes this instance using |key| of the length |key_length|. Call Init
+  // only once. It returns false on the second or later calls.
+  bool Init(const unsigned char* key, int key_length);
+
+  // Initializes this instance using |key|. Call Init only once. It returns
+  // false on the second or later calls.
+  bool Init(const std::string& key) {
+    return Init(reinterpret_cast<const unsigned char*>(key.data()),
+                static_cast<int>(key.size()));
+  }
+
+  // Calculates the HMAC for the message in |data| using the algorithm supplied
+  // to the constructor and the key supplied to the Init method. The HMAC is
+  // returned in |digest|, which has |digest_length| bytes of storage available.
   bool Sign(const std::string& data, unsigned char* digest, int digest_length);
 
  private:
diff --git a/base/hmac_mac.cc b/base/hmac_mac.cc
index 668b306..bbc9330 100644
--- a/base/hmac_mac.cc
+++ b/base/hmac_mac.cc
@@ -14,9 +14,22 @@ struct HMACPlatformData {
   std::string key_;
 };
 
-HMAC::HMAC(HashAlgorithm hash_alg, const unsigned char* key, int key_length)
+HMAC::HMAC(HashAlgorithm hash_alg)
     : hash_alg_(hash_alg), plat_(new HMACPlatformData()) {
+  // Only SHA-1 digest is supported now.
+  DCHECK(hash_alg_ == SHA1);
+}
+
+bool HMAC::Init(const unsigned char *key, int key_length) {
+  if (!plat_->key_.empty()) {
+    // Init must not be called more than once on the same HMAC object.
+    NOTREACHED();
+    return false;
+  }
+
   plat_->key_.assign(reinterpret_cast<const char*>(key), key_length);
+
+  return true;
 }
 
 HMAC::~HMAC() {
diff --git a/base/hmac_nss.cc b/base/hmac_nss.cc
index 293f61d..f56a9fc 100644
--- a/base/hmac_nss.cc
+++ b/base/hmac_nss.cc
@@ -42,14 +42,31 @@ struct HMACPlatformData {
   ScopedNSSSymKey sym_key_;
 };
 
-HMAC::HMAC(HashAlgorithm hash_alg, const unsigned char* key, int key_length)
+HMAC::HMAC(HashAlgorithm hash_alg)
     : hash_alg_(hash_alg), plat_(new HMACPlatformData()) {
+  // Only SHA-1 digest is supported now.
   DCHECK(hash_alg_ == SHA1);
+}
 
+bool HMAC::Init(const unsigned char *key, int key_length) {
   base::EnsureNSSInit();
 
+  if (hash_alg_ != SHA1) {
+    NOTREACHED();
+    return false;
+  }
+
+  if (plat_->slot_.get() || plat_->slot_.get()) {
+    // Init must not be called more than twice on the same HMAC object.
+    NOTREACHED();
+    return false;
+  }
+
   plat_->slot_.reset(PK11_GetBestSlot(CKM_SHA_1_HMAC, NULL));
-  CHECK(plat_->slot_.get());
+  if (!plat_->slot_.get()) {
+    NOTREACHED();
+    return false;
+  }
 
   SECItem key_item;
   key_item.type = siBuffer;
@@ -62,7 +79,12 @@ HMAC::HMAC(HashAlgorithm hash_alg, const unsigned char* key, int key_length)
                                           CKA_SIGN,
                                           &key_item,
                                           NULL));
-  CHECK(plat_->sym_key_.get());
+  if (!plat_->sym_key_.get()) {
+    NOTREACHED();
+    return false;
+  }
+
+  return true;
 }
 
 HMAC::~HMAC() {
@@ -71,6 +93,12 @@ HMAC::~HMAC() {
 bool HMAC::Sign(const std::string& data,
                 unsigned char* digest,
                 int digest_length) {
+  if (!plat_->sym_key_.get()) {
+    // Init has not been called before Sign.
+    NOTREACHED();
+    return false;
+  }
+
   SECItem param = { siBuffer, NULL, 0 };
   ScopedNSSContext context(PK11_CreateContextBySymKey(CKM_SHA_1_HMAC,
                                                       CKA_SIGN,
diff --git a/base/hmac_unittest.cc b/base/hmac_unittest.cc
index 9881369..56b811a 100644
--- a/base/hmac_unittest.cc
+++ b/base/hmac_unittest.cc
@@ -51,7 +51,8 @@ TEST(HMACTest, HmacSafeBrowsingResponseTest) {
 
   std::string message_data(kMessage);
 
-  base::HMAC hmac(base::HMAC::SHA1, kClientKey, kKeySize);
+  base::HMAC hmac(base::HMAC::SHA1);
+  ASSERT_TRUE(hmac.Init(kClientKey, kKeySize));
   unsigned char calculated_hmac[kDigestSize];
 
   EXPECT_TRUE(hmac.Sign(message_data, calculated_hmac, kDigestSize));
@@ -119,9 +120,9 @@ TEST(HMACTest, RFC2202TestCases) {
   };
 
   for (size_t i = 0; i < ARRAYSIZE_UNSAFE(cases); ++i) {
-    base::HMAC hmac(base::HMAC::SHA1,
-                    reinterpret_cast<const unsigned char*>(cases[i].key),
-                    cases[i].key_len);
+    base::HMAC hmac(base::HMAC::SHA1);
+    ASSERT_TRUE(hmac.Init(reinterpret_cast<const unsigned char*>(cases[i].key),
+                          cases[i].key_len));
     std::string data_string(cases[i].data, cases[i].data_len);
     unsigned char digest[kDigestSize];
     EXPECT_TRUE(hmac.Sign(data_string, digest, kDigestSize));
@@ -152,8 +153,8 @@ TEST(HMACTest, HMACObjectReuse) {
           "\xBB\xFF\x1A\x91" }
   };
 
-  base::HMAC hmac(base::HMAC::SHA1,
-                  reinterpret_cast<const unsigned char*>(key), key_len);
+  base::HMAC hmac(base::HMAC::SHA1);
+  ASSERT_TRUE(hmac.Init(reinterpret_cast<const unsigned char*>(key), key_len));
   for (size_t i = 0; i < ARRAYSIZE_UNSAFE(cases); ++i) {
     std::string data_string(cases[i].data, cases[i].data_len);
     unsigned char digest[kDigestSize];
diff --git a/base/hmac_win.cc b/base/hmac_win.cc
index d927ac1..2b2e9cc 100644
--- a/base/hmac_win.cc
+++ b/base/hmac_win.cc
@@ -21,11 +21,25 @@ struct HMACPlatformData {
   HCRYPTKEY hkey_;
 };
 
-HMAC::HMAC(HashAlgorithm hash_alg, const unsigned char* key, int key_length)
+HMAC::HMAC(HashAlgorithm hash_alg)
     : hash_alg_(hash_alg), plat_(new HMACPlatformData()) {
+  // Only SHA-1 digest is supported now.
+  DCHECK(hash_alg_ == SHA1);
+}
+
+bool HMAC::Init(const unsigned char *key, int key_length) {
+  if (plat_->provider_ || plat_->hkey_) {
+    // Init must not be called more than once on the same HMAC object.
+    NOTREACHED();
+    return false;
+  }
+
   if (!CryptAcquireContext(&plat_->provider_, NULL, NULL,
-                           PROV_RSA_FULL, CRYPT_VERIFYCONTEXT))
+                           PROV_RSA_FULL, CRYPT_VERIFYCONTEXT)) {
+    NOTREACHED();
     plat_->provider_ = NULL;
+    return false;
+  }
 
   // This code doesn't work on Win2k because PLAINTEXTKEYBLOB and
   // CRYPT_IPSEC_HMAC_KEY are not supported on Windows 2000.  PLAINTEXTKEYBLOB
@@ -53,20 +67,31 @@ HMAC::HMAC(HashAlgorithm hash_alg, const unsigned char* key, int key_length)
   if (!CryptImportKey(plat_->provider_, &key_blob_storage[0],
                       key_blob_storage.size(), 0, CRYPT_IPSEC_HMAC_KEY,
                       &plat_->hkey_)) {
+    NOTREACHED();
     plat_->hkey_ = NULL;
+    return false;
   }
 
   // Destroy the copy of the key.
   SecureZeroMemory(key_blob->key_data, key_length);
+
+  return true;
 }
 
 HMAC::~HMAC() {
-  if (plat_->hkey_)
-    CryptDestroyKey(plat_->hkey_);
-  if (plat_->hash_)
-    CryptDestroyHash(plat_->hash_);
-  if (plat_->provider_)
-    CryptReleaseContext(plat_->provider_, 0);
+  BOOL ok;
+  if (plat_->hkey_) {
+    ok = CryptDestroyKey(plat_->hkey_);
+    DCHECK(ok);
+  }
+  if (plat_->hash_) {
+    ok = CryptDestroyHash(plat_->hash_);
+    DCHECK(ok);
+  }
+  if (plat_->provider_) {
+    ok = CryptReleaseContext(plat_->provider_, 0);
+    DCHECK(ok);
+  }
 }
 
 bool HMAC::Sign(const std::string& data,
diff --git a/chrome/browser/safe_browsing/safe_browsing_util.cc b/chrome/browser/safe_browsing/safe_browsing_util.cc
index ea19736..4f9d776 100644
--- a/chrome/browser/safe_browsing/safe_browsing_util.cc
+++ b/chrome/browser/safe_browsing/safe_browsing_util.cc
@@ -170,9 +170,9 @@ bool VerifyMAC(const std::string& key, const std::string& mac,
   std::string decoded_mac;
   net::Base64Decode(mac_copy, &decoded_mac);
 
-  base::HMAC hmac(base::HMAC::SHA1,
-                  reinterpret_cast<const unsigned char*>(decoded_key.data()),
-                  static_cast<int>(decoded_key.length()));
+  base::HMAC hmac(base::HMAC::SHA1);
+  if (!hmac.Init(decoded_key))
+    return false;
   const std::string data_str(data, data_length);
   unsigned char digest[kSafeBrowsingMacDigestSize];
   if (!hmac.Sign(data_str, digest, kSafeBrowsingMacDigestSize))
-- 
cgit v1.1