// Copyright (c) 2011 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "crypto/encryptor.h" #include <cryptohi.h> #include <vector> #include "base/logging.h" #include "crypto/nss_util.h" #include "crypto/symmetric_key.h" namespace crypto { namespace { inline CK_MECHANISM_TYPE GetMechanism(Encryptor::Mode mode) { switch (mode) { case Encryptor::CBC: return CKM_AES_CBC_PAD; case Encryptor::CTR: // AES-CTR encryption uses ECB encryptor as a building block since // NSS doesn't support CTR encryption mode. return CKM_AES_ECB; default: NOTREACHED() << "Unsupported mode of operation"; break; } return static_cast<CK_MECHANISM_TYPE>(-1); } } // namespace Encryptor::Encryptor() : key_(NULL), mode_(CBC) { EnsureNSSInit(); } Encryptor::~Encryptor() { } bool Encryptor::Init(SymmetricKey* key, Mode mode, const base::StringPiece& iv) { DCHECK(key); DCHECK(CBC == mode || CTR == mode) << "Unsupported mode of operation"; key_ = key; mode_ = mode; if (mode == CBC && iv.size() != AES_BLOCK_SIZE) return false; slot_.reset(PK11_GetBestSlot(GetMechanism(mode), NULL)); if (!slot_.get()) return false; switch (mode) { case CBC: SECItem iv_item; iv_item.type = siBuffer; iv_item.data = reinterpret_cast<unsigned char*>( const_cast<char *>(iv.data())); iv_item.len = iv.size(); param_.reset(PK11_ParamFromIV(GetMechanism(mode), &iv_item)); break; case CTR: param_.reset(PK11_ParamFromIV(GetMechanism(mode), NULL)); break; } return param_ != NULL; } bool Encryptor::Encrypt(const base::StringPiece& plaintext, std::string* ciphertext) { CHECK(!plaintext.empty() || (mode_ == CBC)); ScopedPK11Context context(PK11_CreateContextBySymKey(GetMechanism(mode_), CKA_ENCRYPT, key_->key(), param_.get())); if (!context.get()) return false; return (mode_ == CTR) ? CryptCTR(context.get(), plaintext, ciphertext) : Crypt(context.get(), plaintext, ciphertext); } bool Encryptor::Decrypt(const base::StringPiece& ciphertext, std::string* plaintext) { CHECK(!ciphertext.empty()); ScopedPK11Context context(PK11_CreateContextBySymKey( GetMechanism(mode_), (mode_ == CTR ? CKA_ENCRYPT : CKA_DECRYPT), key_->key(), param_.get())); if (!context.get()) return false; return (mode_ == CTR) ? CryptCTR(context.get(), ciphertext, plaintext) : Crypt(context.get(), ciphertext, plaintext); } bool Encryptor::Crypt(PK11Context* context, const base::StringPiece& input, std::string* output) { size_t output_len = input.size() + AES_BLOCK_SIZE; CHECK_GT(output_len, input.size()); output->resize(output_len); uint8* output_data = reinterpret_cast<uint8*>(const_cast<char*>(output->data())); int input_len = input.size(); uint8* input_data = reinterpret_cast<uint8*>(const_cast<char*>(input.data())); int op_len; SECStatus rv = PK11_CipherOp(context, output_data, &op_len, output_len, input_data, input_len); if (SECSuccess != rv) { output->clear(); return false; } unsigned int digest_len; rv = PK11_DigestFinal(context, output_data + op_len, &digest_len, output_len - op_len); if (SECSuccess != rv) { output->clear(); return false; } output->resize(op_len + digest_len); return true; } bool Encryptor::CryptCTR(PK11Context* context, const base::StringPiece& input, std::string* output) { if (!counter_.get()) { LOG(ERROR) << "Counter value not set in CTR mode."; return false; } size_t output_len = ((input.size() + AES_BLOCK_SIZE - 1) / AES_BLOCK_SIZE) * AES_BLOCK_SIZE; CHECK_GE(output_len, input.size()); output->resize(output_len); uint8* output_data = reinterpret_cast<uint8*>(const_cast<char*>(output->data())); size_t mask_len; bool ret = GenerateCounterMask(input.size(), output_data, &mask_len); if (!ret) return false; CHECK_EQ(mask_len, output_len); int op_len; SECStatus rv = PK11_CipherOp(context, output_data, &op_len, output_len, output_data, mask_len); if (SECSuccess != rv) return false; CHECK_EQ(static_cast<int>(mask_len), op_len); unsigned int digest_len; rv = PK11_DigestFinal(context, NULL, &digest_len, 0); if (SECSuccess != rv) return false; CHECK(!digest_len); // Use |output_data| to mask |input|. MaskMessage( reinterpret_cast<uint8*>(const_cast<char*>(input.data())), input.length(), output_data, output_data); output->resize(input.length()); return true; } } // namespace crypto