summaryrefslogtreecommitdiffstats
path: root/base/sha1_win.cc
diff options
context:
space:
mode:
Diffstat (limited to 'base/sha1_win.cc')
-rw-r--r--base/sha1_win.cc139
1 files changed, 45 insertions, 94 deletions
diff --git a/base/sha1_win.cc b/base/sha1_win.cc
index 0d0cf2c..853c244 100644
--- a/base/sha1_win.cc
+++ b/base/sha1_win.cc
@@ -7,108 +7,59 @@
#include <windows.h>
#include <wincrypt.h>
+#include "base/crypto/scoped_capi_types.h"
#include "base/logging.h"
namespace base {
-// Implementation of SHA-1 using Windows CryptoAPI.
-
-// Usage example:
-//
-// SecureHashAlgorithm sha;
-// while(there is data to hash)
-// sha.Update(moredata, size of data);
-// sha.Final();
-// memcpy(somewhere, sha.Digest(), 20);
-//
-// to reuse the instance of sha, call sha.Init();
-
-class SecureHashAlgorithm {
- public:
- SecureHashAlgorithm() : prov_(NULL), hash_(NULL) { Init(); }
-
- void Init();
- void Update(const void* data, size_t nbytes);
- void Final();
-
- // 20 bytes of message digest.
- const unsigned char* Digest() const {
- return reinterpret_cast<const unsigned char*>(result_.data());
- }
-
- private:
- // Cleans up prov_, hash_, and result_.
- void Cleanup();
-
- HCRYPTPROV prov_;
- HCRYPTHASH hash_;
- std::string result_;
-};
-
-void SecureHashAlgorithm::Init() {
- Cleanup();
-
- if (!CryptAcquireContext(&prov_, 0, 0, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT)) {
+std::string SHA1HashString(const std::string& str) {
+ ScopedHCRYPTPROV provider;
+ if (!CryptAcquireContext(provider.receive(), NULL, NULL, PROV_RSA_FULL,
+ CRYPT_VERIFYCONTEXT)) {
LOG(ERROR) << "CryptAcquireContext failed: " << GetLastError();
- return;
- }
-
- // Initialize the hash.
- if (!CryptCreateHash(prov_, CALG_SHA1, 0, 0, &hash_)) {
- LOG(ERROR) << "CryptCreateHash failed: " << GetLastError();
- return;
+ return std::string(SHA1_LENGTH, '\0');
}
-}
-
-void SecureHashAlgorithm::Update(const void* data, size_t nbytes) {
- BOOL ok = CryptHashData(hash_, reinterpret_cast<CONST BYTE*>(data),
- static_cast<DWORD>(nbytes), 0);
- CHECK(ok) << "CryptHashData failed: " << GetLastError();
-}
-
-void SecureHashAlgorithm::Final() {
- DWORD hash_len = 0;
- DWORD buffer_size = sizeof(hash_len);
- if (!CryptGetHashParam(hash_, HP_HASHSIZE,
- reinterpret_cast<unsigned char*>(&hash_len),
- &buffer_size, 0)) {
- LOG(ERROR) << "CryptGetHashParam(HP_HASHSIZE) failed: " << GetLastError();
- result_.assign(SHA1_LENGTH, '\0');
- return;
- }
-
- // Get the hash data.
- if (!CryptGetHashParam(hash_, HP_HASHVAL,
- reinterpret_cast<BYTE*>(WriteInto(&result_,
- hash_len + 1)),
- &hash_len, 0)) {
- LOG(ERROR) << "CryptGetHashParam(HP_HASHVAL) failed: " << GetLastError();
- result_.assign(SHA1_LENGTH, '\0');
- return;
- }
-}
-void SecureHashAlgorithm::Cleanup() {
- BOOL ok;
- if (hash_) {
- ok = CryptDestroyHash(hash_);
- DCHECK(ok);
- hash_ = NULL;
- }
- if (prov_) {
- ok = CryptReleaseContext(prov_, 0);
- DCHECK(ok);
- prov_ = NULL;
+ {
+ ScopedHCRYPTHASH hash;
+ if (!CryptCreateHash(provider, CALG_SHA1, 0, 0, hash.receive())) {
+ LOG(ERROR) << "CryptCreateHash failed: " << GetLastError();
+ return std::string(SHA1_LENGTH, '\0');
+ }
+
+ if (!CryptHashData(hash, reinterpret_cast<CONST BYTE*>(str.data()),
+ static_cast<DWORD>(str.length()), 0)) {
+ LOG(ERROR) << "CryptHashData failed: " << GetLastError();
+ return std::string(SHA1_LENGTH, '\0');
+ }
+
+ DWORD hash_len = 0;
+ DWORD buffer_size = sizeof hash_len;
+ if (!CryptGetHashParam(hash, HP_HASHSIZE,
+ reinterpret_cast<unsigned char*>(&hash_len),
+ &buffer_size, 0)) {
+ LOG(ERROR) << "CryptGetHashParam(HP_HASHSIZE) failed: " << GetLastError();
+ return std::string(SHA1_LENGTH, '\0');
+ }
+
+ std::string result;
+ if (!CryptGetHashParam(hash, HP_HASHVAL,
+ // We need the + 1 here not because the call will write a trailing \0,
+ // but so that result.length() is correctly set to |hash_len|.
+ reinterpret_cast<BYTE*>(WriteInto(&result, hash_len + 1)), &hash_len,
+ 0))) {
+ LOG(ERROR) << "CryptGetHashParam(HP_HASHVAL) failed: " << GetLastError();
+ return std::string(SHA1_LENGTH, '\0');
+ }
+
+ if (hash_len != SHA1_LENGTH) {
+ LOG(ERROR) << "Returned hash value is wrong length: " << hash_len
+ << " should be " << SHA1_LENGTH;
+ return std::string(SHA1_LENGTH, '\0');
+ }
+
+ return result;
}
- result_.clear();
-}
-
-std::string SHA1HashString(const std::string& str) {
- SecureHashAlgorithm sha;
- sha.Update(str.c_str(), str.length());
- sha.Final();
- std::string out(reinterpret_cast<const char*>(sha.Digest()), SHA1_LENGTH);
- return out;
}
} // namespace base