From aaead5019627818c93693fdb6ec04d47b47c17f2 Mon Sep 17 00:00:00 2001 From: "wtc@google.com" Date: Wed, 15 Oct 2008 00:20:11 +0000 Subject: Turn SSLClientSocket into an interface. The original ssl_client_socket.{h,cc} are renamed ssl_client_socket_win.{h,cc}. The new ssl_client_socket.h defines the SSLClientSocket interface, which simply extends the ClientSocket interface with a new GetSSLInfo method. ClientSocketFactory::CreateSSLClientSocket returns SSLClientSocket* instead of ClientSocket*. Replace the SSL protocol version mask parameter to the constructor and factory method by a SSLConfig parameter. R=darin Review URL: http://codereview.chromium.org/7304 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@3387 0039d316-1c4b-4281-b951-d872f2087c98 --- net/SConscript | 2 +- net/base/client_socket_factory.cc | 9 +- net/base/client_socket_factory.h | 9 +- net/base/ssl_client_socket.cc | 916 ------------------------- net/base/ssl_client_socket.h | 125 +--- net/base/ssl_client_socket_unittest.cc | 60 +- net/base/ssl_client_socket_win.cc | 917 ++++++++++++++++++++++++++ net/base/ssl_client_socket_win.h | 139 ++++ net/build/net.vcproj | 8 +- net/http/http_network_transaction.cc | 27 +- net/http/http_network_transaction.h | 3 +- net/http/http_network_transaction_unittest.cc | 12 +- net/net.xcodeproj/project.pbxproj | 2 - 13 files changed, 1123 insertions(+), 1106 deletions(-) delete mode 100644 net/base/ssl_client_socket.cc create mode 100644 net/base/ssl_client_socket_win.cc create mode 100644 net/base/ssl_client_socket_win.h diff --git a/net/SConscript b/net/SConscript index b3dd629..37d6c18 100644 --- a/net/SConscript +++ b/net/SConscript @@ -94,7 +94,7 @@ if env['PLATFORM'] == 'win32': 'base/directory_lister.cc', 'base/dns_resolution_observer.cc', 'base/listen_socket.cc', - 'base/ssl_client_socket.cc', + 'base/ssl_client_socket_win.cc', 'base/ssl_config_service.cc', 'base/tcp_client_socket.cc', 'base/telnet_server.cc', diff --git a/net/base/client_socket_factory.cc b/net/base/client_socket_factory.cc index 43df0e6..10f24df 100644 --- a/net/base/client_socket_factory.cc +++ b/net/base/client_socket_factory.cc @@ -7,7 +7,7 @@ #include "base/singleton.h" #include "build/build_config.h" #if defined(OS_WIN) -#include "net/base/ssl_client_socket.h" +#include "net/base/ssl_client_socket_win.h" #endif #include "net/base/tcp_client_socket.h" @@ -20,13 +20,12 @@ class DefaultClientSocketFactory : public ClientSocketFactory { return new TCPClientSocket(addresses); } - virtual ClientSocket* CreateSSLClientSocket( + virtual SSLClientSocket* CreateSSLClientSocket( ClientSocket* transport_socket, const std::string& hostname, - int protocol_version_mask) { + const SSLConfig& ssl_config) { #if defined(OS_WIN) - return new SSLClientSocket(transport_socket, hostname, - protocol_version_mask); + return new SSLClientSocketWin(transport_socket, hostname, ssl_config); #else // TODO(pinkerton): turn on when we port SSL socket from win32 NOTIMPLEMENTED(); diff --git a/net/base/client_socket_factory.h b/net/base/client_socket_factory.h index ed371a3..3a3bb84 100644 --- a/net/base/client_socket_factory.h +++ b/net/base/client_socket_factory.h @@ -11,6 +11,8 @@ namespace net { class AddressList; class ClientSocket; +class SSLClientSocket; +struct SSLConfig; // An interface used to instantiate ClientSocket objects. Used to facilitate // testing code with mock socket implementations. @@ -21,13 +23,10 @@ class ClientSocketFactory { virtual ClientSocket* CreateTCPClientSocket( const AddressList& addresses) = 0; - // protocol_version_mask is a bitmask that specifies which versions of the - // SSL protocol (SSL 2.0, SSL 3.0, and TLS 1.0) should be enabled. The bit - // flags are defined in net/base/ssl_client_socket.h. - virtual ClientSocket* CreateSSLClientSocket( + virtual SSLClientSocket* CreateSSLClientSocket( ClientSocket* transport_socket, const std::string& hostname, - int protocol_version_mask) = 0; + const SSLConfig& ssl_config) = 0; // Returns the default ClientSocketFactory. static ClientSocketFactory* GetDefaultFactory(); diff --git a/net/base/ssl_client_socket.cc b/net/base/ssl_client_socket.cc deleted file mode 100644 index d155009..0000000 --- a/net/base/ssl_client_socket.cc +++ /dev/null @@ -1,916 +0,0 @@ -// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/base/ssl_client_socket.h" - -#include - -#include "base/singleton.h" -#include "base/string_util.h" -#include "net/base/net_errors.h" -#include "net/base/ssl_info.h" - -#pragma comment(lib, "secur32.lib") - -namespace net { - -//----------------------------------------------------------------------------- - -// TODO(wtc): See http://msdn.microsoft.com/en-us/library/aa377188(VS.85).aspx -// for the other error codes we may need to map. -static int MapSecurityError(SECURITY_STATUS err) { - // There are numerous security error codes, but these are the ones we thus - // far find interesting. - switch (err) { - case SEC_E_WRONG_PRINCIPAL: // Schannel - case CERT_E_CN_NO_MATCH: // CryptoAPI - return ERR_CERT_COMMON_NAME_INVALID; - case SEC_E_UNTRUSTED_ROOT: // Schannel - case CERT_E_UNTRUSTEDROOT: // CryptoAPI - return ERR_CERT_AUTHORITY_INVALID; - case SEC_E_CERT_EXPIRED: // Schannel - case CERT_E_EXPIRED: // CryptoAPI - return ERR_CERT_DATE_INVALID; - case CRYPT_E_NO_REVOCATION_CHECK: - return ERR_CERT_NO_REVOCATION_MECHANISM; - case CRYPT_E_REVOKED: // Schannel and CryptoAPI - return ERR_CERT_REVOKED; - case SEC_E_CERT_UNKNOWN: - return ERR_CERT_INVALID; - // We received an unexpected_message or illegal_parameter alert message - // from the server. - case SEC_E_ILLEGAL_MESSAGE: - return ERR_SSL_PROTOCOL_ERROR; - case SEC_E_ALGORITHM_MISMATCH: - return ERR_SSL_VERSION_OR_CIPHER_MISMATCH; - case SEC_E_OK: - return OK; - default: - LOG(WARNING) << "Unknown error " << err << " mapped to net::ERR_FAILED"; - return ERR_FAILED; - } -} - -// Map a network error code to the equivalent certificate status flag. If -// the error code is not a certificate error, it is mapped to 0. -static int MapNetErrorToCertStatus(int error) { - switch (error) { - case ERR_CERT_COMMON_NAME_INVALID: - return CERT_STATUS_COMMON_NAME_INVALID; - case ERR_CERT_DATE_INVALID: - return CERT_STATUS_DATE_INVALID; - case ERR_CERT_AUTHORITY_INVALID: - return CERT_STATUS_AUTHORITY_INVALID; - case ERR_CERT_NO_REVOCATION_MECHANISM: - return CERT_STATUS_NO_REVOCATION_MECHANISM; - case ERR_CERT_UNABLE_TO_CHECK_REVOCATION: - return CERT_STATUS_UNABLE_TO_CHECK_REVOCATION; - case ERR_CERT_REVOKED: - return CERT_STATUS_REVOKED; - case ERR_CERT_CONTAINS_ERRORS: - NOTREACHED(); - // Falls through. - case ERR_CERT_INVALID: - return CERT_STATUS_INVALID; - default: - return 0; - } -} - -//----------------------------------------------------------------------------- - -// Size of recv_buffer_ -// -// Ciphertext is decrypted one SSL record at a time, so recv_buffer_ needs to -// have room for a full SSL record, with the header and trailer. Here is the -// breakdown of the size: -// 5: SSL record header -// 16K: SSL record maximum size -// 64: >= SSL record trailer (16 or 20 have been observed) -static const int kRecvBufferSize = (5 + 16*1024 + 64); - -SSLClientSocket::SSLClientSocket(ClientSocket* transport_socket, - const std::string& hostname, - int protocol_version_mask) -#pragma warning(suppress: 4355) - : io_callback_(this, &SSLClientSocket::OnIOComplete), - transport_(transport_socket), - hostname_(hostname), - protocol_version_mask_(protocol_version_mask), - user_callback_(NULL), - user_buf_(NULL), - user_buf_len_(0), - next_state_(STATE_NONE), - server_cert_(NULL), - server_cert_status_(0), - payload_send_buffer_len_(0), - bytes_sent_(0), - decrypted_ptr_(NULL), - bytes_decrypted_(0), - received_ptr_(NULL), - bytes_received_(0), - completed_handshake_(false), - ignore_ok_result_(false), - no_client_cert_(false) { - memset(&stream_sizes_, 0, sizeof(stream_sizes_)); - memset(&send_buffer_, 0, sizeof(send_buffer_)); - memset(&creds_, 0, sizeof(creds_)); - memset(&ctxt_, 0, sizeof(ctxt_)); -} - -SSLClientSocket::~SSLClientSocket() { - Disconnect(); -} - -int SSLClientSocket::Connect(CompletionCallback* callback) { - DCHECK(transport_.get()); - DCHECK(next_state_ == STATE_NONE); - DCHECK(!user_callback_); - - next_state_ = STATE_CONNECT; - int rv = DoLoop(OK); - if (rv == ERR_IO_PENDING) - user_callback_ = callback; - return rv; -} - -int SSLClientSocket::ReconnectIgnoringLastError(CompletionCallback* callback) { - // TODO(darin): implement me! - return ERR_FAILED; -} - -void SSLClientSocket::Disconnect() { - // TODO(wtc): Send SSL close_notify alert. - completed_handshake_ = false; - transport_->Disconnect(); - - if (send_buffer_.pvBuffer) { - FreeContextBuffer(send_buffer_.pvBuffer); - memset(&send_buffer_, 0, sizeof(send_buffer_)); - } - if (creds_.dwLower || creds_.dwUpper) { - FreeCredentialsHandle(&creds_); - memset(&creds_, 0, sizeof(creds_)); - } - if (ctxt_.dwLower || ctxt_.dwUpper) { - DeleteSecurityContext(&ctxt_); - memset(&ctxt_, 0, sizeof(ctxt_)); - } - if (server_cert_) - CertFreeCertificateContext(server_cert_); - - // TODO(wtc): reset more members? - bytes_decrypted_ = 0; - bytes_received_ = 0; -} - -bool SSLClientSocket::IsConnected() const { - // Ideally, we should also check if we have received the close_notify alert - // message from the server, and return false in that case. We're not doing - // that, so this function may return a false positive. Since the upper - // layer (HttpNetworkTransaction) needs to handle a persistent connection - // closed by the server when we send a request anyway, a false positive in - // exchange for simpler code is a good trade-off. - return completed_handshake_ && transport_->IsConnected(); -} - -int SSLClientSocket::Read(char* buf, int buf_len, - CompletionCallback* callback) { - DCHECK(completed_handshake_); - DCHECK(next_state_ == STATE_NONE); - DCHECK(!user_callback_); - - // If we have surplus decrypted plaintext, satisfy the Read with it without - // reading more ciphertext from the transport socket. - if (bytes_decrypted_ != 0) { - int len = std::min(buf_len, bytes_decrypted_); - memcpy(buf, decrypted_ptr_, len); - decrypted_ptr_ += len; - bytes_decrypted_ -= len; - if (bytes_decrypted_ == 0) { - decrypted_ptr_ = NULL; - if (bytes_received_ != 0) { - memmove(recv_buffer_.get(), received_ptr_, bytes_received_); - received_ptr_ = recv_buffer_.get(); - } - } - return len; - } - - user_buf_ = buf; - user_buf_len_ = buf_len; - - if (bytes_received_ == 0) { - next_state_ = STATE_PAYLOAD_READ; - } else { - next_state_ = STATE_PAYLOAD_READ_COMPLETE; - ignore_ok_result_ = true; // OK doesn't mean EOF. - } - int rv = DoLoop(OK); - if (rv == ERR_IO_PENDING) - user_callback_ = callback; - return rv; -} - -int SSLClientSocket::Write(const char* buf, int buf_len, - CompletionCallback* callback) { - DCHECK(completed_handshake_); - DCHECK(next_state_ == STATE_NONE); - DCHECK(!user_callback_); - - user_buf_ = const_cast(buf); - user_buf_len_ = buf_len; - - next_state_ = STATE_PAYLOAD_ENCRYPT; - int rv = DoLoop(OK); - if (rv == ERR_IO_PENDING) - user_callback_ = callback; - return rv; -} - -void SSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) { - SECURITY_STATUS status = SEC_E_OK; - if (server_cert_ == NULL) { - status = QueryContextAttributes(&ctxt_, - SECPKG_ATTR_REMOTE_CERT_CONTEXT, - &server_cert_); - } - if (status == SEC_E_OK) { - DCHECK(server_cert_); - PCCERT_CONTEXT dup_cert = CertDuplicateCertificateContext(server_cert_); - ssl_info->cert = X509Certificate::CreateFromHandle(dup_cert); - } - SecPkgContext_ConnectionInfo connection_info; - status = QueryContextAttributes(&ctxt_, - SECPKG_ATTR_CONNECTION_INFO, - &connection_info); - if (status == SEC_E_OK) { - // TODO(wtc): compute the overall security strength, taking into account - // dwExchStrength and dwHashStrength. dwExchStrength needs to be - // normalized. - ssl_info->security_bits = connection_info.dwCipherStrength; - } - ssl_info->cert_status = server_cert_status_; -} - -void SSLClientSocket::DoCallback(int rv) { - DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_callback_); - - // since Run may result in Read being called, clear user_callback_ up front. - CompletionCallback* c = user_callback_; - user_callback_ = NULL; - c->Run(rv); -} - -void SSLClientSocket::OnIOComplete(int result) { - int rv = DoLoop(result); - if (rv != ERR_IO_PENDING) - DoCallback(rv); -} - -int SSLClientSocket::DoLoop(int last_io_result) { - DCHECK(next_state_ != STATE_NONE); - int rv = last_io_result; - do { - State state = next_state_; - next_state_ = STATE_NONE; - switch (state) { - case STATE_CONNECT: - rv = DoConnect(); - break; - case STATE_CONNECT_COMPLETE: - rv = DoConnectComplete(rv); - break; - case STATE_HANDSHAKE_READ: - rv = DoHandshakeRead(); - break; - case STATE_HANDSHAKE_READ_COMPLETE: - rv = DoHandshakeReadComplete(rv); - break; - case STATE_HANDSHAKE_WRITE: - rv = DoHandshakeWrite(); - break; - case STATE_HANDSHAKE_WRITE_COMPLETE: - rv = DoHandshakeWriteComplete(rv); - break; - case STATE_PAYLOAD_READ: - rv = DoPayloadRead(); - break; - case STATE_PAYLOAD_READ_COMPLETE: - rv = DoPayloadReadComplete(rv); - break; - case STATE_PAYLOAD_ENCRYPT: - rv = DoPayloadEncrypt(); - break; - case STATE_PAYLOAD_WRITE: - rv = DoPayloadWrite(); - break; - case STATE_PAYLOAD_WRITE_COMPLETE: - rv = DoPayloadWriteComplete(rv); - break; - default: - rv = ERR_UNEXPECTED; - NOTREACHED() << "unexpected state"; - break; - } - } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); - return rv; -} - -int SSLClientSocket::DoConnect() { - next_state_ = STATE_CONNECT_COMPLETE; - return transport_->Connect(&io_callback_); -} - -int SSLClientSocket::DoConnectComplete(int result) { - if (result < 0) - return result; - - memset(&ctxt_, 0, sizeof(ctxt_)); - memset(&creds_, 0, sizeof(creds_)); - - SCHANNEL_CRED schannel_cred = {0}; - schannel_cred.dwVersion = SCHANNEL_CRED_VERSION; - - // The global system registry settings take precedence over the value of - // schannel_cred.grbitEnabledProtocols. - schannel_cred.grbitEnabledProtocols = 0; - if (protocol_version_mask_ & SSL2) - schannel_cred.grbitEnabledProtocols |= SP_PROT_SSL2; - if (protocol_version_mask_ & SSL3) - schannel_cred.grbitEnabledProtocols |= SP_PROT_SSL3; - if (protocol_version_mask_ & TLS1) - schannel_cred.grbitEnabledProtocols |= SP_PROT_TLS1; - // The default (0) means Schannel selects the protocol, rather than no - // protocols are selected. So we have to fail here. - if (schannel_cred.grbitEnabledProtocols == 0) - return ERR_NO_SSL_VERSIONS_ENABLED; - - // The default session lifetime is 36000000 milliseconds (ten hours). Set - // schannel_cred.dwSessionLifespan to change the number of milliseconds that - // Schannel keeps the session in its session cache. - - // We can set the key exchange algorithms (RSA or DH) in - // schannel_cred.{cSupportedAlgs,palgSupportedAlgs}. - - // Although SCH_CRED_AUTO_CRED_VALIDATION is convenient, we have to use - // SCH_CRED_MANUAL_CRED_VALIDATION for three reasons. - // 1. SCH_CRED_AUTO_CRED_VALIDATION doesn't allow us to get the certificate - // context if the certificate validation fails. - // 2. SCH_CRED_AUTO_CRED_VALIDATION returns only one error even if the - // certificate has multiple errors. - // 3. SCH_CRED_AUTO_CRED_VALIDATION doesn't allow us to ignore untrusted CA - // and expired certificate errors. There are only flags to ignore the - // name mismatch and unable-to-check-revocation errors. - // - // TODO(wtc): Look into undocumented or poorly documented flags: - // SCH_CRED_RESTRICTED_ROOTS - // SCH_CRED_REVOCATION_CHECK_CACHE_ONLY - // SCH_CRED_CACHE_ONLY_URL_RETRIEVAL - // SCH_CRED_MEMORY_STORE_CERT - schannel_cred.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS | - SCH_CRED_MANUAL_CRED_VALIDATION; - TimeStamp expiry; - SECURITY_STATUS status; - - status = AcquireCredentialsHandle( - NULL, // Not used - UNISP_NAME, // Microsoft Unified Security Protocol Provider - SECPKG_CRED_OUTBOUND, - NULL, // Not used - &schannel_cred, - NULL, // Not used - NULL, // Not used - &creds_, - &expiry); // Optional - if (status != SEC_E_OK) { - DLOG(ERROR) << "AcquireCredentialsHandle failed: " << status; - return MapSecurityError(status); - } - - SecBufferDesc buffer_desc; - DWORD out_flags; - DWORD flags = ISC_REQ_SEQUENCE_DETECT | - ISC_REQ_REPLAY_DETECT | - ISC_REQ_CONFIDENTIALITY | - ISC_RET_EXTENDED_ERROR | - ISC_REQ_ALLOCATE_MEMORY | - ISC_REQ_STREAM; - - send_buffer_.pvBuffer = NULL; - send_buffer_.BufferType = SECBUFFER_TOKEN; - send_buffer_.cbBuffer = 0; - - buffer_desc.cBuffers = 1; - buffer_desc.pBuffers = &send_buffer_; - buffer_desc.ulVersion = SECBUFFER_VERSION; - - status = InitializeSecurityContext( - &creds_, - NULL, // NULL on the first call - const_cast(ASCIIToWide(hostname_).c_str()), - flags, - 0, // Reserved - SECURITY_NATIVE_DREP, // TODO(wtc): MSDN says this should be set to 0. - NULL, // NULL on the first call - 0, // Reserved - &ctxt_, // Receives the new context handle - &buffer_desc, - &out_flags, - &expiry); - if (status != SEC_I_CONTINUE_NEEDED) { - DLOG(ERROR) << "InitializeSecurityContext failed: " << status; - return MapSecurityError(status); - } - - next_state_ = STATE_HANDSHAKE_WRITE; - return OK; -} - -int SSLClientSocket::DoHandshakeRead() { - next_state_ = STATE_HANDSHAKE_READ_COMPLETE; - - if (!recv_buffer_.get()) - recv_buffer_.reset(new char[kRecvBufferSize]); - - char* buf = recv_buffer_.get() + bytes_received_; - int buf_len = kRecvBufferSize - bytes_received_; - - if (buf_len <= 0) { - NOTREACHED() << "Receive buffer is too small!"; - return ERR_UNEXPECTED; - } - - return transport_->Read(buf, buf_len, &io_callback_); -} - -int SSLClientSocket::DoHandshakeReadComplete(int result) { - if (result < 0) - return result; - if (result == 0 && !ignore_ok_result_) - return ERR_FAILED; // Incomplete response :( - - ignore_ok_result_ = false; - - bytes_received_ += result; - - // Process the contents of recv_buffer_. - SECURITY_STATUS status; - TimeStamp expiry; - DWORD out_flags; - - DWORD flags = ISC_REQ_SEQUENCE_DETECT | - ISC_REQ_REPLAY_DETECT | - ISC_REQ_CONFIDENTIALITY | - ISC_RET_EXTENDED_ERROR | - ISC_REQ_ALLOCATE_MEMORY | - ISC_REQ_STREAM; - - // When InitializeSecurityContext returns SEC_I_INCOMPLETE_CREDENTIALS, - // John Banes (a Microsoft security developer) said we need to pass in the - // ISC_REQ_USE_SUPPLIED_CREDS flag if we skip finding a client certificate - // and just call InitializeSecurityContext again. (See - // (http://www.derkeiler.com/Newsgroups/microsoft.public.platformsdk.security/2004-08/0187.html.) - // My testing on XP SP2 and Vista SP1 shows that it still works without - // passing in this flag, but I pass it in to be safe. - if (no_client_cert_) - flags |= ISC_REQ_USE_SUPPLIED_CREDS; - - SecBufferDesc in_buffer_desc, out_buffer_desc; - SecBuffer in_buffers[2]; - - in_buffer_desc.cBuffers = 2; - in_buffer_desc.pBuffers = in_buffers; - in_buffer_desc.ulVersion = SECBUFFER_VERSION; - - in_buffers[0].pvBuffer = &recv_buffer_[0]; - in_buffers[0].cbBuffer = bytes_received_; - in_buffers[0].BufferType = SECBUFFER_TOKEN; - - in_buffers[1].pvBuffer = NULL; - in_buffers[1].cbBuffer = 0; - in_buffers[1].BufferType = SECBUFFER_EMPTY; - - out_buffer_desc.cBuffers = 1; - out_buffer_desc.pBuffers = &send_buffer_; - out_buffer_desc.ulVersion = SECBUFFER_VERSION; - - send_buffer_.pvBuffer = NULL; - send_buffer_.BufferType = SECBUFFER_TOKEN; - send_buffer_.cbBuffer = 0; - - status = InitializeSecurityContext( - &creds_, - &ctxt_, - NULL, - flags, - 0, - SECURITY_NATIVE_DREP, - &in_buffer_desc, - 0, - NULL, - &out_buffer_desc, - &out_flags, - &expiry); - - if (status == SEC_E_INCOMPLETE_MESSAGE) { - DCHECK(FAILED(status)); - DCHECK(send_buffer_.cbBuffer == 0 || - !(out_flags & ISC_RET_EXTENDED_ERROR)); - next_state_ = STATE_HANDSHAKE_READ; - return OK; - } - - if (send_buffer_.cbBuffer != 0 && - (status == SEC_E_OK || - status == SEC_I_CONTINUE_NEEDED || - FAILED(status) && (out_flags & ISC_RET_EXTENDED_ERROR))) { - // TODO(wtc): if status is SEC_E_OK, we should finish the handshake - // successfully after sending send_buffer_. - // If FAILED(status) is true, we should terminate the connection after - // sending send_buffer_. - DCHECK(status == SEC_I_CONTINUE_NEEDED); // We only handle this case - // correctly. - next_state_ = STATE_HANDSHAKE_WRITE; - bytes_received_ = 0; - return OK; - } - - if (status == SEC_E_OK) { - if (in_buffers[1].BufferType == SECBUFFER_EXTRA) { - // TODO(darin) need to save this data for later. - NOTREACHED() << "should not occur for HTTPS traffic"; - return ERR_FAILED; - } - bytes_received_ = 0; - return DidCompleteHandshake(); - } - - if (FAILED(status)) - return MapSecurityError(status); - - if (status == SEC_I_INCOMPLETE_CREDENTIALS) { - // We don't support SSL client authentication yet. For now we just set - // no_client_cert_ to true and call InitializeSecurityContext again. - no_client_cert_ = true; - next_state_ = STATE_HANDSHAKE_READ_COMPLETE; - ignore_ok_result_ = true; // OK doesn't mean EOF. - return OK; - } - - DCHECK(status == SEC_I_CONTINUE_NEEDED); - if (in_buffers[1].BufferType == SECBUFFER_EXTRA) { - memmove(&recv_buffer_[0], - &recv_buffer_[0] + (bytes_received_ - in_buffers[1].cbBuffer), - in_buffers[1].cbBuffer); - bytes_received_ = in_buffers[1].cbBuffer; - next_state_ = STATE_HANDSHAKE_READ_COMPLETE; - ignore_ok_result_ = true; // OK doesn't mean EOF. - return OK; - } - - bytes_received_ = 0; - next_state_ = STATE_HANDSHAKE_READ; - return OK; -} - -int SSLClientSocket::DoHandshakeWrite() { - next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; - - // We should have something to send. - DCHECK(send_buffer_.pvBuffer); - DCHECK(send_buffer_.cbBuffer > 0); - - const char* buf = static_cast(send_buffer_.pvBuffer) + bytes_sent_; - int buf_len = send_buffer_.cbBuffer - bytes_sent_; - - return transport_->Write(buf, buf_len, &io_callback_); -} - -int SSLClientSocket::DoHandshakeWriteComplete(int result) { - if (result < 0) - return result; - - DCHECK(result != 0); - - bytes_sent_ += result; - DCHECK(bytes_sent_ <= static_cast(send_buffer_.cbBuffer)); - - if (bytes_sent_ >= static_cast(send_buffer_.cbBuffer)) { - bool overflow = (bytes_sent_ > static_cast(send_buffer_.cbBuffer)); - SECURITY_STATUS status = FreeContextBuffer(send_buffer_.pvBuffer); - DCHECK(status == SEC_E_OK); - memset(&send_buffer_, 0, sizeof(send_buffer_)); - bytes_sent_ = 0; - if (overflow) // Bug! - return ERR_UNEXPECTED; - next_state_ = STATE_HANDSHAKE_READ; - } else { - // Send the remaining bytes. - next_state_ = STATE_HANDSHAKE_WRITE; - } - - return OK; -} - -int SSLClientSocket::DoPayloadRead() { - next_state_ = STATE_PAYLOAD_READ_COMPLETE; - - DCHECK(recv_buffer_.get()); - - char* buf = recv_buffer_.get() + bytes_received_; - int buf_len = kRecvBufferSize - bytes_received_; - - if (buf_len <= 0) { - NOTREACHED() << "Receive buffer is too small!"; - return ERR_FAILED; - } - - return transport_->Read(buf, buf_len, &io_callback_); -} - -int SSLClientSocket::DoPayloadReadComplete(int result) { - if (result < 0) - return result; - if (result == 0 && !ignore_ok_result_) { - // TODO(wtc): Unless we have received the close_notify alert, we need to - // return an error code indicating that the SSL connection ended - // uncleanly, a potential truncation attack. - if (bytes_received_ != 0) - return ERR_FAILED; - return OK; - } - - ignore_ok_result_ = false; - - bytes_received_ += result; - - // Process the contents of recv_buffer_. - SecBuffer buffers[4]; - buffers[0].pvBuffer = recv_buffer_.get(); - buffers[0].cbBuffer = bytes_received_; - buffers[0].BufferType = SECBUFFER_DATA; - - buffers[1].BufferType = SECBUFFER_EMPTY; - buffers[2].BufferType = SECBUFFER_EMPTY; - buffers[3].BufferType = SECBUFFER_EMPTY; - - SecBufferDesc buffer_desc; - buffer_desc.cBuffers = 4; - buffer_desc.pBuffers = buffers; - buffer_desc.ulVersion = SECBUFFER_VERSION; - - SECURITY_STATUS status; - status = DecryptMessage(&ctxt_, &buffer_desc, 0, NULL); - - if (status == SEC_E_INCOMPLETE_MESSAGE) { - next_state_ = STATE_PAYLOAD_READ; - return OK; - } - - if (status == SEC_I_CONTEXT_EXPIRED) { - // Received the close_notify alert. - bytes_received_ = 0; - return OK; - } - - if (status != SEC_E_OK && status != SEC_I_RENEGOTIATE) { - DCHECK(status != SEC_E_MESSAGE_ALTERED); - return MapSecurityError(status); - } - - // The received ciphertext was decrypted in place in recv_buffer_. Remember - // the location and length of the decrypted plaintext and any unused - // ciphertext. - decrypted_ptr_ = NULL; - bytes_decrypted_ = 0; - received_ptr_ = NULL; - bytes_received_ = 0; - for (int i = 1; i < 4; i++) { - if (!decrypted_ptr_ && buffers[i].BufferType == SECBUFFER_DATA) { - decrypted_ptr_ = static_cast(buffers[i].pvBuffer); - bytes_decrypted_ = buffers[i].cbBuffer; - } - if (!received_ptr_ && buffers[i].BufferType == SECBUFFER_EXTRA) { - received_ptr_ = static_cast(buffers[i].pvBuffer); - bytes_received_ = buffers[i].cbBuffer; - } - } - - int len = 0; - if (bytes_decrypted_ != 0) { - len = std::min(user_buf_len_, bytes_decrypted_); - memcpy(user_buf_, decrypted_ptr_, len); - decrypted_ptr_ += len; - bytes_decrypted_ -= len; - } - if (bytes_decrypted_ == 0) { - decrypted_ptr_ = NULL; - if (bytes_received_ != 0) { - memmove(recv_buffer_.get(), received_ptr_, bytes_received_); - received_ptr_ = recv_buffer_.get(); - } - } - // TODO(wtc): need to handle SEC_I_RENEGOTIATE. - DCHECK(status == SEC_E_OK); - // If we decrypted 0 bytes, don't report 0 bytes read, which would be - // mistaken for EOF. Continue decrypting or read more. - if (len == 0) { - if (bytes_received_ == 0) { - next_state_ = STATE_PAYLOAD_READ; - } else { - next_state_ = STATE_PAYLOAD_READ_COMPLETE; - ignore_ok_result_ = true; // OK doesn't mean EOF. - } - } - return len; -} - -int SSLClientSocket::DoPayloadEncrypt() { - DCHECK(user_buf_); - DCHECK(user_buf_len_ > 0); - - ULONG message_len = std::min( - stream_sizes_.cbMaximumMessage, static_cast(user_buf_len_)); - ULONG alloc_len = - message_len + stream_sizes_.cbHeader + stream_sizes_.cbTrailer; - user_buf_len_ = message_len; - - payload_send_buffer_.reset(new char[alloc_len]); - memcpy(&payload_send_buffer_[stream_sizes_.cbHeader], - user_buf_, message_len); - - SecBuffer buffers[4]; - buffers[0].pvBuffer = payload_send_buffer_.get(); - buffers[0].cbBuffer = stream_sizes_.cbHeader; - buffers[0].BufferType = SECBUFFER_STREAM_HEADER; - - buffers[1].pvBuffer = &payload_send_buffer_[stream_sizes_.cbHeader]; - buffers[1].cbBuffer = message_len; - buffers[1].BufferType = SECBUFFER_DATA; - - buffers[2].pvBuffer = &payload_send_buffer_[stream_sizes_.cbHeader + - message_len]; - buffers[2].cbBuffer = stream_sizes_.cbTrailer; - buffers[2].BufferType = SECBUFFER_STREAM_TRAILER; - - buffers[3].BufferType = SECBUFFER_EMPTY; - - SecBufferDesc buffer_desc; - buffer_desc.cBuffers = 4; - buffer_desc.pBuffers = buffers; - buffer_desc.ulVersion = SECBUFFER_VERSION; - - SECURITY_STATUS status = EncryptMessage(&ctxt_, 0, &buffer_desc, 0); - - if (FAILED(status)) - return MapSecurityError(status); - - payload_send_buffer_len_ = buffers[0].cbBuffer + - buffers[1].cbBuffer + - buffers[2].cbBuffer; - DCHECK(bytes_sent_ == 0); - - next_state_ = STATE_PAYLOAD_WRITE; - return OK; -} - -int SSLClientSocket::DoPayloadWrite() { - next_state_ = STATE_PAYLOAD_WRITE_COMPLETE; - - // We should have something to send. - DCHECK(payload_send_buffer_.get()); - DCHECK(payload_send_buffer_len_ > 0); - - const char* buf = payload_send_buffer_.get() + bytes_sent_; - int buf_len = payload_send_buffer_len_ - bytes_sent_; - - return transport_->Write(buf, buf_len, &io_callback_); -} - -int SSLClientSocket::DoPayloadWriteComplete(int result) { - if (result < 0) - return result; - - DCHECK(result != 0); - - bytes_sent_ += result; - DCHECK(bytes_sent_ <= payload_send_buffer_len_); - - if (bytes_sent_ >= payload_send_buffer_len_) { - bool overflow = (bytes_sent_ > payload_send_buffer_len_); - payload_send_buffer_.reset(); - payload_send_buffer_len_ = 0; - bytes_sent_ = 0; - if (overflow) // Bug! - return ERR_UNEXPECTED; - // Done - return user_buf_len_; - } - - // Send the remaining bytes. - next_state_ = STATE_PAYLOAD_WRITE; - return OK; -} - -int SSLClientSocket::DidCompleteHandshake() { - SECURITY_STATUS status = QueryContextAttributes( - &ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_); - if (status != SEC_E_OK) { - DLOG(ERROR) << "QueryContextAttributes failed: " << status; - return MapSecurityError(status); - } - DCHECK(!server_cert_); - status = QueryContextAttributes( - &ctxt_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &server_cert_); - if (status != SEC_E_OK) { - DLOG(ERROR) << "QueryContextAttributes failed: " << status; - return MapSecurityError(status); - } - - completed_handshake_ = true; - int rv = VerifyServerCert(); - // TODO(wtc): for now, always check revocation. - server_cert_status_ = CERT_STATUS_REV_CHECKING_ENABLED; - if (rv) - server_cert_status_ |= MapNetErrorToCertStatus(rv); - return rv; -} - -int SSLClientSocket::VerifyServerCert() { - DCHECK(server_cert_); - - // Build and validate certificate chain. - - CERT_CHAIN_PARA chain_para; - memset(&chain_para, 0, sizeof(chain_para)); - chain_para.cbSize = sizeof(chain_para); - // TODO(wtc): consider requesting the usage szOID_PKIX_KP_SERVER_AUTH - // or szOID_SERVER_GATED_CRYPTO or szOID_SGC_NETSCAPE - chain_para.RequestedUsage.dwType = USAGE_MATCH_TYPE_AND; - chain_para.RequestedUsage.Usage.cUsageIdentifier = 0; - chain_para.RequestedUsage.Usage.rgpszUsageIdentifier = NULL; // LPSTR* - PCCERT_CHAIN_CONTEXT chain_context; - // TODO(wtc): for now, always check revocation. If we don't want to check - // revocation, use the CERT_CHAIN_REVOCATION_CHECK_CACHE_ONLY flag. - if (!CertGetCertificateChain( - NULL, // default chain engine, HCCE_CURRENT_USER - server_cert_, - NULL, // current system time - server_cert_->hCertStore, // search this store - &chain_para, - CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT | - CERT_CHAIN_CACHE_END_CERT, - NULL, // reserved - &chain_context)) { - return MapSecurityError(GetLastError()); - } - - std::wstring wstr_hostname = ASCIIToWide(hostname_); - - SSL_EXTRA_CERT_CHAIN_POLICY_PARA extra_policy_para; - memset(&extra_policy_para, 0, sizeof(extra_policy_para)); - extra_policy_para.cbSize = sizeof(extra_policy_para); - extra_policy_para.dwAuthType = AUTHTYPE_SERVER; - // TODO(wtc): Set these flags in fdwChecks to ignore cert errors. - // SECURITY_FLAG_IGNORE_REVOCATION - // SECURITY_FLAG_IGNORE_UNKNOWN_CA - // SECURITY_FLAG_IGNORE_WRONG_USAGE - // SECURITY_FLAG_IGNORE_CERT_CN_INVALID - // SECURITY_FLAG_IGNORE_CERT_DATE_INVALID - extra_policy_para.fdwChecks = 0; - extra_policy_para.pwszServerName = - const_cast(wstr_hostname.c_str()); - - CERT_CHAIN_POLICY_PARA policy_para; - memset(&policy_para, 0, sizeof(policy_para)); - policy_para.cbSize = sizeof(policy_para); - // TODO(wtc): It seems that we can also ignore cert errors by setting - // dwFlags. - policy_para.dwFlags = 0; - policy_para.pvExtraPolicyPara = &extra_policy_para; - - CERT_CHAIN_POLICY_STATUS policy_status; - memset(&policy_status, 0, sizeof(policy_status)); - policy_status.cbSize = sizeof(policy_status); - - if (!CertVerifyCertificateChainPolicy( - CERT_CHAIN_POLICY_SSL, - chain_context, - &policy_para, - &policy_status)) { - return MapSecurityError(GetLastError()); - } - - if (policy_status.dwError) - return MapSecurityError(policy_status.dwError); - - CertFreeCertificateChain(chain_context); - - return OK; -} - -} // namespace net - diff --git a/net/base/ssl_client_socket.h b/net/base/ssl_client_socket.h index 100e514..dca5ef3 100644 --- a/net/base/ssl_client_socket.h +++ b/net/base/ssl_client_socket.h @@ -5,17 +5,7 @@ #ifndef NET_BASE_SSL_CLIENT_SOCKET_H_ #define NET_BASE_SSL_CLIENT_SOCKET_H_ -#define SECURITY_WIN32 // Needs to be defined before including security.h - -#include -#include -#include - -#include - -#include "base/scoped_ptr.h" #include "net/base/client_socket.h" -#include "net/base/completion_callback.h" namespace net { @@ -30,121 +20,8 @@ class SSLInfo; // class SSLClientSocket : public ClientSocket { public: - enum { - SSL2 = 1 << 0, - SSL3 = 1 << 1, - TLS1 = 1 << 2 - }; - - // Takes ownership of the transport_socket, which may already be connected. - // The given hostname will be compared with the name(s) in the server's - // certificate during the SSL handshake. protocol_version_mask is a bitwise - // OR of SSL2, SSL3, and TLS1 that specifies which versions of the SSL - // protocol should be enabled. - SSLClientSocket(ClientSocket* transport_socket, - const std::string& hostname, - int protocol_version_mask); - ~SSLClientSocket(); - - // ClientSocket methods: - virtual int Connect(CompletionCallback* callback); - virtual int ReconnectIgnoringLastError(CompletionCallback* callback); - virtual void Disconnect(); - virtual bool IsConnected() const; - - // Socket methods: - virtual int Read(char* buf, int buf_len, CompletionCallback* callback); - virtual int Write(const char* buf, int buf_len, CompletionCallback* callback); - // Gets the SSL connection information of the socket. - void GetSSLInfo(SSLInfo* ssl_info); - - private: - void DoCallback(int result); - void OnIOComplete(int result); - - int DoLoop(int last_io_result); - int DoConnect(); - int DoConnectComplete(int result); - int DoHandshakeRead(); - int DoHandshakeReadComplete(int result); - int DoHandshakeWrite(); - int DoHandshakeWriteComplete(int result); - int DoPayloadRead(); - int DoPayloadReadComplete(int result); - int DoPayloadEncrypt(); - int DoPayloadWrite(); - int DoPayloadWriteComplete(int result); - - int DidCompleteHandshake(); - int VerifyServerCert(); - - CompletionCallbackImpl io_callback_; - scoped_ptr transport_; - std::string hostname_; - int protocol_version_mask_; - - CompletionCallback* user_callback_; - - // Used by both Read and Write functions. - char* user_buf_; - int user_buf_len_; - - enum State { - STATE_NONE, - STATE_CONNECT, - STATE_CONNECT_COMPLETE, - STATE_HANDSHAKE_READ, - STATE_HANDSHAKE_READ_COMPLETE, - STATE_HANDSHAKE_WRITE, - STATE_HANDSHAKE_WRITE_COMPLETE, - STATE_PAYLOAD_ENCRYPT, - STATE_PAYLOAD_WRITE, - STATE_PAYLOAD_WRITE_COMPLETE, - STATE_PAYLOAD_READ, - STATE_PAYLOAD_READ_COMPLETE, - }; - State next_state_; - - SecPkgContext_StreamSizes stream_sizes_; - PCCERT_CONTEXT server_cert_; - int server_cert_status_; - - CredHandle creds_; - CtxtHandle ctxt_; - SecBuffer send_buffer_; - scoped_array payload_send_buffer_; - int payload_send_buffer_len_; - int bytes_sent_; - - // recv_buffer_ holds the received ciphertext. Since Schannel decrypts - // data in place, sometimes recv_buffer_ may contain decrypted plaintext and - // any undecrypted ciphertext. (Ciphertext is decrypted one full SSL record - // at a time.) - // - // If bytes_decrypted_ is 0, the received ciphertext is at the beginning of - // recv_buffer_, ready to be passed to DecryptMessage. - scoped_array recv_buffer_; - char* decrypted_ptr_; // Points to the decrypted plaintext in recv_buffer_ - int bytes_decrypted_; // The number of bytes of decrypted plaintext. - char* received_ptr_; // Points to the received ciphertext in recv_buffer_ - int bytes_received_; // The number of bytes of received ciphertext. - - bool completed_handshake_; - - // Only used in the STATE_HANDSHAKE_READ_COMPLETE and - // STATE_PAYLOAD_READ_COMPLETE states. True if a 'result' argument of OK - // should be ignored, to prevent it from being interpreted as EOF. - // - // The reason we need this flag is that OK means not only "0 bytes of data - // were read" but also EOF. We set ignore_ok_result_ to true when we need - // to continue processing previously read data without reading more data. - // We have to pass a 'result' of OK to the DoLoop method, and don't want it - // to be interpreted as EOF. - bool ignore_ok_result_; - - // True if the user has no client certificate. - bool no_client_cert_; + virtual void GetSSLInfo(SSLInfo* ssl_info) = 0; }; } // namespace net diff --git a/net/base/ssl_client_socket_unittest.cc b/net/base/ssl_client_socket_unittest.cc index 2aba7ab..d1f1f82 100644 --- a/net/base/ssl_client_socket_unittest.cc +++ b/net/base/ssl_client_socket_unittest.cc @@ -3,9 +3,11 @@ // found in the LICENSE file. #include "net/base/address_list.h" +#include "net/base/client_socket_factory.h" #include "net/base/host_resolver.h" #include "net/base/net_errors.h" #include "net/base/ssl_client_socket.h" +#include "net/base/ssl_config_service.h" #include "net/base/tcp_client_socket.h" #include "net/base/test_completion_callback.h" #include "testing/gtest/include/gtest/gtest.h" @@ -14,10 +16,16 @@ namespace { -const unsigned int kDefaultSSLVersionMask = net::SSLClientSocket::SSL3 | - net::SSLClientSocket::TLS1; +const net::SSLConfig kDefaultSSLConfig; class SSLClientSocketTest : public testing::Test { + public: + SSLClientSocketTest() + : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) { + } + + protected: + net::ClientSocketFactory* socket_factory_; }; } // namespace @@ -34,12 +42,13 @@ TEST_F(SSLClientSocketTest, DISABLED_Connect) { int rv = resolver.Resolve(hostname, 443, &addr, NULL); EXPECT_EQ(net::OK, rv); - net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname, - kDefaultSSLVersionMask); + scoped_ptr sock( + socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + hostname, kDefaultSSLConfig)); - EXPECT_FALSE(sock.IsConnected()); + EXPECT_FALSE(sock->IsConnected()); - rv = sock.Connect(&callback); + rv = sock->Connect(&callback); if (rv != net::OK) { ASSERT_EQ(net::ERR_IO_PENDING, rv); @@ -47,10 +56,10 @@ TEST_F(SSLClientSocketTest, DISABLED_Connect) { EXPECT_EQ(net::OK, rv); } - EXPECT_TRUE(sock.IsConnected()); + EXPECT_TRUE(sock->IsConnected()); - sock.Disconnect(); - EXPECT_FALSE(sock.IsConnected()); + sock->Disconnect(); + EXPECT_FALSE(sock->IsConnected()); } // bug 1354783 @@ -66,10 +75,11 @@ TEST_F(SSLClientSocketTest, DISABLED_Read) { rv = callback.WaitForResult(); EXPECT_EQ(rv, net::OK); - net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname, - kDefaultSSLVersionMask); + scoped_ptr sock( + socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + hostname, kDefaultSSLConfig)); - rv = sock.Connect(&callback); + rv = sock->Connect(&callback); if (rv != net::OK) { ASSERT_EQ(rv, net::ERR_IO_PENDING); @@ -78,7 +88,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read) { } const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - rv = sock.Write(request_text, arraysize(request_text) - 1, &callback); + rv = sock->Write(request_text, arraysize(request_text) - 1, &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) { @@ -88,7 +98,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read) { char buf[4096]; for (;;) { - rv = sock.Read(buf, sizeof(buf), &callback); + rv = sock->Read(buf, sizeof(buf), &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -110,10 +120,11 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_SmallChunks) { int rv = resolver.Resolve(hostname, 443, &addr, NULL); EXPECT_EQ(rv, net::OK); - net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname, - kDefaultSSLVersionMask); + scoped_ptr sock( + socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + hostname, kDefaultSSLConfig)); - rv = sock.Connect(&callback); + rv = sock->Connect(&callback); if (rv != net::OK) { ASSERT_EQ(rv, net::ERR_IO_PENDING); @@ -122,7 +133,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_SmallChunks) { } const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - rv = sock.Write(request_text, arraysize(request_text) - 1, &callback); + rv = sock->Write(request_text, arraysize(request_text) - 1, &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) { @@ -132,7 +143,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_SmallChunks) { char buf[1]; for (;;) { - rv = sock.Read(buf, sizeof(buf), &callback); + rv = sock->Read(buf, sizeof(buf), &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -154,10 +165,11 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_Interrupted) { int rv = resolver.Resolve(hostname, 443, &addr, NULL); EXPECT_EQ(rv, net::OK); - net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname, - kDefaultSSLVersionMask); + scoped_ptr sock( + socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + hostname, kDefaultSSLConfig)); - rv = sock.Connect(&callback); + rv = sock->Connect(&callback); if (rv != net::OK) { ASSERT_EQ(rv, net::ERR_IO_PENDING); @@ -166,7 +178,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_Interrupted) { } const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - rv = sock.Write(request_text, arraysize(request_text) - 1, &callback); + rv = sock->Write(request_text, arraysize(request_text) - 1, &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) { @@ -176,7 +188,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_Interrupted) { // Do a partial read and then exit. This test should not crash! char buf[512]; - rv = sock.Read(buf, sizeof(buf), &callback); + rv = sock->Read(buf, sizeof(buf), &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) diff --git a/net/base/ssl_client_socket_win.cc b/net/base/ssl_client_socket_win.cc new file mode 100644 index 0000000..1eeb090 --- /dev/null +++ b/net/base/ssl_client_socket_win.cc @@ -0,0 +1,917 @@ +// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/base/ssl_client_socket_win.h" + +#include + +#include "base/singleton.h" +#include "base/string_util.h" +#include "net/base/net_errors.h" +#include "net/base/ssl_info.h" + +#pragma comment(lib, "secur32.lib") + +namespace net { + +//----------------------------------------------------------------------------- + +// TODO(wtc): See http://msdn.microsoft.com/en-us/library/aa377188(VS.85).aspx +// for the other error codes we may need to map. +static int MapSecurityError(SECURITY_STATUS err) { + // There are numerous security error codes, but these are the ones we thus + // far find interesting. + switch (err) { + case SEC_E_WRONG_PRINCIPAL: // Schannel + case CERT_E_CN_NO_MATCH: // CryptoAPI + return ERR_CERT_COMMON_NAME_INVALID; + case SEC_E_UNTRUSTED_ROOT: // Schannel + case CERT_E_UNTRUSTEDROOT: // CryptoAPI + return ERR_CERT_AUTHORITY_INVALID; + case SEC_E_CERT_EXPIRED: // Schannel + case CERT_E_EXPIRED: // CryptoAPI + return ERR_CERT_DATE_INVALID; + case CRYPT_E_NO_REVOCATION_CHECK: + return ERR_CERT_NO_REVOCATION_MECHANISM; + case CRYPT_E_REVOKED: // Schannel and CryptoAPI + return ERR_CERT_REVOKED; + case SEC_E_CERT_UNKNOWN: + return ERR_CERT_INVALID; + // We received an unexpected_message or illegal_parameter alert message + // from the server. + case SEC_E_ILLEGAL_MESSAGE: + return ERR_SSL_PROTOCOL_ERROR; + case SEC_E_ALGORITHM_MISMATCH: + return ERR_SSL_VERSION_OR_CIPHER_MISMATCH; + case SEC_E_OK: + return OK; + default: + LOG(WARNING) << "Unknown error " << err << " mapped to net::ERR_FAILED"; + return ERR_FAILED; + } +} + +// Map a network error code to the equivalent certificate status flag. If +// the error code is not a certificate error, it is mapped to 0. +static int MapNetErrorToCertStatus(int error) { + switch (error) { + case ERR_CERT_COMMON_NAME_INVALID: + return CERT_STATUS_COMMON_NAME_INVALID; + case ERR_CERT_DATE_INVALID: + return CERT_STATUS_DATE_INVALID; + case ERR_CERT_AUTHORITY_INVALID: + return CERT_STATUS_AUTHORITY_INVALID; + case ERR_CERT_NO_REVOCATION_MECHANISM: + return CERT_STATUS_NO_REVOCATION_MECHANISM; + case ERR_CERT_UNABLE_TO_CHECK_REVOCATION: + return CERT_STATUS_UNABLE_TO_CHECK_REVOCATION; + case ERR_CERT_REVOKED: + return CERT_STATUS_REVOKED; + case ERR_CERT_CONTAINS_ERRORS: + NOTREACHED(); + // Falls through. + case ERR_CERT_INVALID: + return CERT_STATUS_INVALID; + default: + return 0; + } +} + +//----------------------------------------------------------------------------- + +// Size of recv_buffer_ +// +// Ciphertext is decrypted one SSL record at a time, so recv_buffer_ needs to +// have room for a full SSL record, with the header and trailer. Here is the +// breakdown of the size: +// 5: SSL record header +// 16K: SSL record maximum size +// 64: >= SSL record trailer (16 or 20 have been observed) +static const int kRecvBufferSize = (5 + 16*1024 + 64); + +SSLClientSocketWin::SSLClientSocketWin(ClientSocket* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config) +#pragma warning(suppress: 4355) + : io_callback_(this, &SSLClientSocketWin::OnIOComplete), + transport_(transport_socket), + hostname_(hostname), + ssl_config_(ssl_config), + user_callback_(NULL), + user_buf_(NULL), + user_buf_len_(0), + next_state_(STATE_NONE), + server_cert_(NULL), + server_cert_status_(0), + payload_send_buffer_len_(0), + bytes_sent_(0), + decrypted_ptr_(NULL), + bytes_decrypted_(0), + received_ptr_(NULL), + bytes_received_(0), + completed_handshake_(false), + ignore_ok_result_(false), + no_client_cert_(false) { + memset(&stream_sizes_, 0, sizeof(stream_sizes_)); + memset(&send_buffer_, 0, sizeof(send_buffer_)); + memset(&creds_, 0, sizeof(creds_)); + memset(&ctxt_, 0, sizeof(ctxt_)); +} + +SSLClientSocketWin::~SSLClientSocketWin() { + Disconnect(); +} + +void SSLClientSocketWin::GetSSLInfo(SSLInfo* ssl_info) { + SECURITY_STATUS status = SEC_E_OK; + if (server_cert_ == NULL) { + status = QueryContextAttributes(&ctxt_, + SECPKG_ATTR_REMOTE_CERT_CONTEXT, + &server_cert_); + } + if (status == SEC_E_OK) { + DCHECK(server_cert_); + PCCERT_CONTEXT dup_cert = CertDuplicateCertificateContext(server_cert_); + ssl_info->cert = X509Certificate::CreateFromHandle(dup_cert); + } + SecPkgContext_ConnectionInfo connection_info; + status = QueryContextAttributes(&ctxt_, + SECPKG_ATTR_CONNECTION_INFO, + &connection_info); + if (status == SEC_E_OK) { + // TODO(wtc): compute the overall security strength, taking into account + // dwExchStrength and dwHashStrength. dwExchStrength needs to be + // normalized. + ssl_info->security_bits = connection_info.dwCipherStrength; + } + ssl_info->cert_status = server_cert_status_; +} + +int SSLClientSocketWin::Connect(CompletionCallback* callback) { + DCHECK(transport_.get()); + DCHECK(next_state_ == STATE_NONE); + DCHECK(!user_callback_); + + next_state_ = STATE_CONNECT; + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) + user_callback_ = callback; + return rv; +} + +int SSLClientSocketWin::ReconnectIgnoringLastError( + CompletionCallback* callback) { + // TODO(darin): implement me! + return ERR_FAILED; +} + +void SSLClientSocketWin::Disconnect() { + // TODO(wtc): Send SSL close_notify alert. + completed_handshake_ = false; + transport_->Disconnect(); + + if (send_buffer_.pvBuffer) { + FreeContextBuffer(send_buffer_.pvBuffer); + memset(&send_buffer_, 0, sizeof(send_buffer_)); + } + if (creds_.dwLower || creds_.dwUpper) { + FreeCredentialsHandle(&creds_); + memset(&creds_, 0, sizeof(creds_)); + } + if (ctxt_.dwLower || ctxt_.dwUpper) { + DeleteSecurityContext(&ctxt_); + memset(&ctxt_, 0, sizeof(ctxt_)); + } + if (server_cert_) + CertFreeCertificateContext(server_cert_); + + // TODO(wtc): reset more members? + bytes_decrypted_ = 0; + bytes_received_ = 0; +} + +bool SSLClientSocketWin::IsConnected() const { + // Ideally, we should also check if we have received the close_notify alert + // message from the server, and return false in that case. We're not doing + // that, so this function may return a false positive. Since the upper + // layer (HttpNetworkTransaction) needs to handle a persistent connection + // closed by the server when we send a request anyway, a false positive in + // exchange for simpler code is a good trade-off. + return completed_handshake_ && transport_->IsConnected(); +} + +int SSLClientSocketWin::Read(char* buf, int buf_len, + CompletionCallback* callback) { + DCHECK(completed_handshake_); + DCHECK(next_state_ == STATE_NONE); + DCHECK(!user_callback_); + + // If we have surplus decrypted plaintext, satisfy the Read with it without + // reading more ciphertext from the transport socket. + if (bytes_decrypted_ != 0) { + int len = std::min(buf_len, bytes_decrypted_); + memcpy(buf, decrypted_ptr_, len); + decrypted_ptr_ += len; + bytes_decrypted_ -= len; + if (bytes_decrypted_ == 0) { + decrypted_ptr_ = NULL; + if (bytes_received_ != 0) { + memmove(recv_buffer_.get(), received_ptr_, bytes_received_); + received_ptr_ = recv_buffer_.get(); + } + } + return len; + } + + user_buf_ = buf; + user_buf_len_ = buf_len; + + if (bytes_received_ == 0) { + next_state_ = STATE_PAYLOAD_READ; + } else { + next_state_ = STATE_PAYLOAD_READ_COMPLETE; + ignore_ok_result_ = true; // OK doesn't mean EOF. + } + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) + user_callback_ = callback; + return rv; +} + +int SSLClientSocketWin::Write(const char* buf, int buf_len, + CompletionCallback* callback) { + DCHECK(completed_handshake_); + DCHECK(next_state_ == STATE_NONE); + DCHECK(!user_callback_); + + user_buf_ = const_cast(buf); + user_buf_len_ = buf_len; + + next_state_ = STATE_PAYLOAD_ENCRYPT; + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) + user_callback_ = callback; + return rv; +} + +void SSLClientSocketWin::DoCallback(int rv) { + DCHECK(rv != ERR_IO_PENDING); + DCHECK(user_callback_); + + // since Run may result in Read being called, clear user_callback_ up front. + CompletionCallback* c = user_callback_; + user_callback_ = NULL; + c->Run(rv); +} + +void SSLClientSocketWin::OnIOComplete(int result) { + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) + DoCallback(rv); +} + +int SSLClientSocketWin::DoLoop(int last_io_result) { + DCHECK(next_state_ != STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_CONNECT: + rv = DoConnect(); + break; + case STATE_CONNECT_COMPLETE: + rv = DoConnectComplete(rv); + break; + case STATE_HANDSHAKE_READ: + rv = DoHandshakeRead(); + break; + case STATE_HANDSHAKE_READ_COMPLETE: + rv = DoHandshakeReadComplete(rv); + break; + case STATE_HANDSHAKE_WRITE: + rv = DoHandshakeWrite(); + break; + case STATE_HANDSHAKE_WRITE_COMPLETE: + rv = DoHandshakeWriteComplete(rv); + break; + case STATE_PAYLOAD_READ: + rv = DoPayloadRead(); + break; + case STATE_PAYLOAD_READ_COMPLETE: + rv = DoPayloadReadComplete(rv); + break; + case STATE_PAYLOAD_ENCRYPT: + rv = DoPayloadEncrypt(); + break; + case STATE_PAYLOAD_WRITE: + rv = DoPayloadWrite(); + break; + case STATE_PAYLOAD_WRITE_COMPLETE: + rv = DoPayloadWriteComplete(rv); + break; + default: + rv = ERR_UNEXPECTED; + NOTREACHED() << "unexpected state"; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int SSLClientSocketWin::DoConnect() { + next_state_ = STATE_CONNECT_COMPLETE; + return transport_->Connect(&io_callback_); +} + +int SSLClientSocketWin::DoConnectComplete(int result) { + if (result < 0) + return result; + + memset(&ctxt_, 0, sizeof(ctxt_)); + memset(&creds_, 0, sizeof(creds_)); + + SCHANNEL_CRED schannel_cred = {0}; + schannel_cred.dwVersion = SCHANNEL_CRED_VERSION; + + // The global system registry settings take precedence over the value of + // schannel_cred.grbitEnabledProtocols. + schannel_cred.grbitEnabledProtocols = 0; + if (ssl_config_.ssl2_enabled) + schannel_cred.grbitEnabledProtocols |= SP_PROT_SSL2; + if (ssl_config_.ssl3_enabled) + schannel_cred.grbitEnabledProtocols |= SP_PROT_SSL3; + if (ssl_config_.tls1_enabled) + schannel_cred.grbitEnabledProtocols |= SP_PROT_TLS1; + // The default (0) means Schannel selects the protocol, rather than no + // protocols are selected. So we have to fail here. + if (schannel_cred.grbitEnabledProtocols == 0) + return ERR_NO_SSL_VERSIONS_ENABLED; + + // The default session lifetime is 36000000 milliseconds (ten hours). Set + // schannel_cred.dwSessionLifespan to change the number of milliseconds that + // Schannel keeps the session in its session cache. + + // We can set the key exchange algorithms (RSA or DH) in + // schannel_cred.{cSupportedAlgs,palgSupportedAlgs}. + + // Although SCH_CRED_AUTO_CRED_VALIDATION is convenient, we have to use + // SCH_CRED_MANUAL_CRED_VALIDATION for three reasons. + // 1. SCH_CRED_AUTO_CRED_VALIDATION doesn't allow us to get the certificate + // context if the certificate validation fails. + // 2. SCH_CRED_AUTO_CRED_VALIDATION returns only one error even if the + // certificate has multiple errors. + // 3. SCH_CRED_AUTO_CRED_VALIDATION doesn't allow us to ignore untrusted CA + // and expired certificate errors. There are only flags to ignore the + // name mismatch and unable-to-check-revocation errors. + // + // TODO(wtc): Look into undocumented or poorly documented flags: + // SCH_CRED_RESTRICTED_ROOTS + // SCH_CRED_REVOCATION_CHECK_CACHE_ONLY + // SCH_CRED_CACHE_ONLY_URL_RETRIEVAL + // SCH_CRED_MEMORY_STORE_CERT + schannel_cred.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS | + SCH_CRED_MANUAL_CRED_VALIDATION; + TimeStamp expiry; + SECURITY_STATUS status; + + status = AcquireCredentialsHandle( + NULL, // Not used + UNISP_NAME, // Microsoft Unified Security Protocol Provider + SECPKG_CRED_OUTBOUND, + NULL, // Not used + &schannel_cred, + NULL, // Not used + NULL, // Not used + &creds_, + &expiry); // Optional + if (status != SEC_E_OK) { + DLOG(ERROR) << "AcquireCredentialsHandle failed: " << status; + return MapSecurityError(status); + } + + SecBufferDesc buffer_desc; + DWORD out_flags; + DWORD flags = ISC_REQ_SEQUENCE_DETECT | + ISC_REQ_REPLAY_DETECT | + ISC_REQ_CONFIDENTIALITY | + ISC_RET_EXTENDED_ERROR | + ISC_REQ_ALLOCATE_MEMORY | + ISC_REQ_STREAM; + + send_buffer_.pvBuffer = NULL; + send_buffer_.BufferType = SECBUFFER_TOKEN; + send_buffer_.cbBuffer = 0; + + buffer_desc.cBuffers = 1; + buffer_desc.pBuffers = &send_buffer_; + buffer_desc.ulVersion = SECBUFFER_VERSION; + + status = InitializeSecurityContext( + &creds_, + NULL, // NULL on the first call + const_cast(ASCIIToWide(hostname_).c_str()), + flags, + 0, // Reserved + SECURITY_NATIVE_DREP, // TODO(wtc): MSDN says this should be set to 0. + NULL, // NULL on the first call + 0, // Reserved + &ctxt_, // Receives the new context handle + &buffer_desc, + &out_flags, + &expiry); + if (status != SEC_I_CONTINUE_NEEDED) { + DLOG(ERROR) << "InitializeSecurityContext failed: " << status; + return MapSecurityError(status); + } + + next_state_ = STATE_HANDSHAKE_WRITE; + return OK; +} + +int SSLClientSocketWin::DoHandshakeRead() { + next_state_ = STATE_HANDSHAKE_READ_COMPLETE; + + if (!recv_buffer_.get()) + recv_buffer_.reset(new char[kRecvBufferSize]); + + char* buf = recv_buffer_.get() + bytes_received_; + int buf_len = kRecvBufferSize - bytes_received_; + + if (buf_len <= 0) { + NOTREACHED() << "Receive buffer is too small!"; + return ERR_UNEXPECTED; + } + + return transport_->Read(buf, buf_len, &io_callback_); +} + +int SSLClientSocketWin::DoHandshakeReadComplete(int result) { + if (result < 0) + return result; + if (result == 0 && !ignore_ok_result_) + return ERR_FAILED; // Incomplete response :( + + ignore_ok_result_ = false; + + bytes_received_ += result; + + // Process the contents of recv_buffer_. + SECURITY_STATUS status; + TimeStamp expiry; + DWORD out_flags; + + DWORD flags = ISC_REQ_SEQUENCE_DETECT | + ISC_REQ_REPLAY_DETECT | + ISC_REQ_CONFIDENTIALITY | + ISC_RET_EXTENDED_ERROR | + ISC_REQ_ALLOCATE_MEMORY | + ISC_REQ_STREAM; + + // When InitializeSecurityContext returns SEC_I_INCOMPLETE_CREDENTIALS, + // John Banes (a Microsoft security developer) said we need to pass in the + // ISC_REQ_USE_SUPPLIED_CREDS flag if we skip finding a client certificate + // and just call InitializeSecurityContext again. (See + // (http://www.derkeiler.com/Newsgroups/microsoft.public.platformsdk.security/2004-08/0187.html.) + // My testing on XP SP2 and Vista SP1 shows that it still works without + // passing in this flag, but I pass it in to be safe. + if (no_client_cert_) + flags |= ISC_REQ_USE_SUPPLIED_CREDS; + + SecBufferDesc in_buffer_desc, out_buffer_desc; + SecBuffer in_buffers[2]; + + in_buffer_desc.cBuffers = 2; + in_buffer_desc.pBuffers = in_buffers; + in_buffer_desc.ulVersion = SECBUFFER_VERSION; + + in_buffers[0].pvBuffer = &recv_buffer_[0]; + in_buffers[0].cbBuffer = bytes_received_; + in_buffers[0].BufferType = SECBUFFER_TOKEN; + + in_buffers[1].pvBuffer = NULL; + in_buffers[1].cbBuffer = 0; + in_buffers[1].BufferType = SECBUFFER_EMPTY; + + out_buffer_desc.cBuffers = 1; + out_buffer_desc.pBuffers = &send_buffer_; + out_buffer_desc.ulVersion = SECBUFFER_VERSION; + + send_buffer_.pvBuffer = NULL; + send_buffer_.BufferType = SECBUFFER_TOKEN; + send_buffer_.cbBuffer = 0; + + status = InitializeSecurityContext( + &creds_, + &ctxt_, + NULL, + flags, + 0, + SECURITY_NATIVE_DREP, + &in_buffer_desc, + 0, + NULL, + &out_buffer_desc, + &out_flags, + &expiry); + + if (status == SEC_E_INCOMPLETE_MESSAGE) { + DCHECK(FAILED(status)); + DCHECK(send_buffer_.cbBuffer == 0 || + !(out_flags & ISC_RET_EXTENDED_ERROR)); + next_state_ = STATE_HANDSHAKE_READ; + return OK; + } + + if (send_buffer_.cbBuffer != 0 && + (status == SEC_E_OK || + status == SEC_I_CONTINUE_NEEDED || + FAILED(status) && (out_flags & ISC_RET_EXTENDED_ERROR))) { + // TODO(wtc): if status is SEC_E_OK, we should finish the handshake + // successfully after sending send_buffer_. + // If FAILED(status) is true, we should terminate the connection after + // sending send_buffer_. + DCHECK(status == SEC_I_CONTINUE_NEEDED); // We only handle this case + // correctly. + next_state_ = STATE_HANDSHAKE_WRITE; + bytes_received_ = 0; + return OK; + } + + if (status == SEC_E_OK) { + if (in_buffers[1].BufferType == SECBUFFER_EXTRA) { + // TODO(darin) need to save this data for later. + NOTREACHED() << "should not occur for HTTPS traffic"; + return ERR_FAILED; + } + bytes_received_ = 0; + return DidCompleteHandshake(); + } + + if (FAILED(status)) + return MapSecurityError(status); + + if (status == SEC_I_INCOMPLETE_CREDENTIALS) { + // We don't support SSL client authentication yet. For now we just set + // no_client_cert_ to true and call InitializeSecurityContext again. + no_client_cert_ = true; + next_state_ = STATE_HANDSHAKE_READ_COMPLETE; + ignore_ok_result_ = true; // OK doesn't mean EOF. + return OK; + } + + DCHECK(status == SEC_I_CONTINUE_NEEDED); + if (in_buffers[1].BufferType == SECBUFFER_EXTRA) { + memmove(&recv_buffer_[0], + &recv_buffer_[0] + (bytes_received_ - in_buffers[1].cbBuffer), + in_buffers[1].cbBuffer); + bytes_received_ = in_buffers[1].cbBuffer; + next_state_ = STATE_HANDSHAKE_READ_COMPLETE; + ignore_ok_result_ = true; // OK doesn't mean EOF. + return OK; + } + + bytes_received_ = 0; + next_state_ = STATE_HANDSHAKE_READ; + return OK; +} + +int SSLClientSocketWin::DoHandshakeWrite() { + next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; + + // We should have something to send. + DCHECK(send_buffer_.pvBuffer); + DCHECK(send_buffer_.cbBuffer > 0); + + const char* buf = static_cast(send_buffer_.pvBuffer) + bytes_sent_; + int buf_len = send_buffer_.cbBuffer - bytes_sent_; + + return transport_->Write(buf, buf_len, &io_callback_); +} + +int SSLClientSocketWin::DoHandshakeWriteComplete(int result) { + if (result < 0) + return result; + + DCHECK(result != 0); + + bytes_sent_ += result; + DCHECK(bytes_sent_ <= static_cast(send_buffer_.cbBuffer)); + + if (bytes_sent_ >= static_cast(send_buffer_.cbBuffer)) { + bool overflow = (bytes_sent_ > static_cast(send_buffer_.cbBuffer)); + SECURITY_STATUS status = FreeContextBuffer(send_buffer_.pvBuffer); + DCHECK(status == SEC_E_OK); + memset(&send_buffer_, 0, sizeof(send_buffer_)); + bytes_sent_ = 0; + if (overflow) // Bug! + return ERR_UNEXPECTED; + next_state_ = STATE_HANDSHAKE_READ; + } else { + // Send the remaining bytes. + next_state_ = STATE_HANDSHAKE_WRITE; + } + + return OK; +} + +int SSLClientSocketWin::DoPayloadRead() { + next_state_ = STATE_PAYLOAD_READ_COMPLETE; + + DCHECK(recv_buffer_.get()); + + char* buf = recv_buffer_.get() + bytes_received_; + int buf_len = kRecvBufferSize - bytes_received_; + + if (buf_len <= 0) { + NOTREACHED() << "Receive buffer is too small!"; + return ERR_FAILED; + } + + return transport_->Read(buf, buf_len, &io_callback_); +} + +int SSLClientSocketWin::DoPayloadReadComplete(int result) { + if (result < 0) + return result; + if (result == 0 && !ignore_ok_result_) { + // TODO(wtc): Unless we have received the close_notify alert, we need to + // return an error code indicating that the SSL connection ended + // uncleanly, a potential truncation attack. + if (bytes_received_ != 0) + return ERR_FAILED; + return OK; + } + + ignore_ok_result_ = false; + + bytes_received_ += result; + + // Process the contents of recv_buffer_. + SecBuffer buffers[4]; + buffers[0].pvBuffer = recv_buffer_.get(); + buffers[0].cbBuffer = bytes_received_; + buffers[0].BufferType = SECBUFFER_DATA; + + buffers[1].BufferType = SECBUFFER_EMPTY; + buffers[2].BufferType = SECBUFFER_EMPTY; + buffers[3].BufferType = SECBUFFER_EMPTY; + + SecBufferDesc buffer_desc; + buffer_desc.cBuffers = 4; + buffer_desc.pBuffers = buffers; + buffer_desc.ulVersion = SECBUFFER_VERSION; + + SECURITY_STATUS status; + status = DecryptMessage(&ctxt_, &buffer_desc, 0, NULL); + + if (status == SEC_E_INCOMPLETE_MESSAGE) { + next_state_ = STATE_PAYLOAD_READ; + return OK; + } + + if (status == SEC_I_CONTEXT_EXPIRED) { + // Received the close_notify alert. + bytes_received_ = 0; + return OK; + } + + if (status != SEC_E_OK && status != SEC_I_RENEGOTIATE) { + DCHECK(status != SEC_E_MESSAGE_ALTERED); + return MapSecurityError(status); + } + + // The received ciphertext was decrypted in place in recv_buffer_. Remember + // the location and length of the decrypted plaintext and any unused + // ciphertext. + decrypted_ptr_ = NULL; + bytes_decrypted_ = 0; + received_ptr_ = NULL; + bytes_received_ = 0; + for (int i = 1; i < 4; i++) { + if (!decrypted_ptr_ && buffers[i].BufferType == SECBUFFER_DATA) { + decrypted_ptr_ = static_cast(buffers[i].pvBuffer); + bytes_decrypted_ = buffers[i].cbBuffer; + } + if (!received_ptr_ && buffers[i].BufferType == SECBUFFER_EXTRA) { + received_ptr_ = static_cast(buffers[i].pvBuffer); + bytes_received_ = buffers[i].cbBuffer; + } + } + + int len = 0; + if (bytes_decrypted_ != 0) { + len = std::min(user_buf_len_, bytes_decrypted_); + memcpy(user_buf_, decrypted_ptr_, len); + decrypted_ptr_ += len; + bytes_decrypted_ -= len; + } + if (bytes_decrypted_ == 0) { + decrypted_ptr_ = NULL; + if (bytes_received_ != 0) { + memmove(recv_buffer_.get(), received_ptr_, bytes_received_); + received_ptr_ = recv_buffer_.get(); + } + } + // TODO(wtc): need to handle SEC_I_RENEGOTIATE. + DCHECK(status == SEC_E_OK); + // If we decrypted 0 bytes, don't report 0 bytes read, which would be + // mistaken for EOF. Continue decrypting or read more. + if (len == 0) { + if (bytes_received_ == 0) { + next_state_ = STATE_PAYLOAD_READ; + } else { + next_state_ = STATE_PAYLOAD_READ_COMPLETE; + ignore_ok_result_ = true; // OK doesn't mean EOF. + } + } + return len; +} + +int SSLClientSocketWin::DoPayloadEncrypt() { + DCHECK(user_buf_); + DCHECK(user_buf_len_ > 0); + + ULONG message_len = std::min( + stream_sizes_.cbMaximumMessage, static_cast(user_buf_len_)); + ULONG alloc_len = + message_len + stream_sizes_.cbHeader + stream_sizes_.cbTrailer; + user_buf_len_ = message_len; + + payload_send_buffer_.reset(new char[alloc_len]); + memcpy(&payload_send_buffer_[stream_sizes_.cbHeader], + user_buf_, message_len); + + SecBuffer buffers[4]; + buffers[0].pvBuffer = payload_send_buffer_.get(); + buffers[0].cbBuffer = stream_sizes_.cbHeader; + buffers[0].BufferType = SECBUFFER_STREAM_HEADER; + + buffers[1].pvBuffer = &payload_send_buffer_[stream_sizes_.cbHeader]; + buffers[1].cbBuffer = message_len; + buffers[1].BufferType = SECBUFFER_DATA; + + buffers[2].pvBuffer = &payload_send_buffer_[stream_sizes_.cbHeader + + message_len]; + buffers[2].cbBuffer = stream_sizes_.cbTrailer; + buffers[2].BufferType = SECBUFFER_STREAM_TRAILER; + + buffers[3].BufferType = SECBUFFER_EMPTY; + + SecBufferDesc buffer_desc; + buffer_desc.cBuffers = 4; + buffer_desc.pBuffers = buffers; + buffer_desc.ulVersion = SECBUFFER_VERSION; + + SECURITY_STATUS status = EncryptMessage(&ctxt_, 0, &buffer_desc, 0); + + if (FAILED(status)) + return MapSecurityError(status); + + payload_send_buffer_len_ = buffers[0].cbBuffer + + buffers[1].cbBuffer + + buffers[2].cbBuffer; + DCHECK(bytes_sent_ == 0); + + next_state_ = STATE_PAYLOAD_WRITE; + return OK; +} + +int SSLClientSocketWin::DoPayloadWrite() { + next_state_ = STATE_PAYLOAD_WRITE_COMPLETE; + + // We should have something to send. + DCHECK(payload_send_buffer_.get()); + DCHECK(payload_send_buffer_len_ > 0); + + const char* buf = payload_send_buffer_.get() + bytes_sent_; + int buf_len = payload_send_buffer_len_ - bytes_sent_; + + return transport_->Write(buf, buf_len, &io_callback_); +} + +int SSLClientSocketWin::DoPayloadWriteComplete(int result) { + if (result < 0) + return result; + + DCHECK(result != 0); + + bytes_sent_ += result; + DCHECK(bytes_sent_ <= payload_send_buffer_len_); + + if (bytes_sent_ >= payload_send_buffer_len_) { + bool overflow = (bytes_sent_ > payload_send_buffer_len_); + payload_send_buffer_.reset(); + payload_send_buffer_len_ = 0; + bytes_sent_ = 0; + if (overflow) // Bug! + return ERR_UNEXPECTED; + // Done + return user_buf_len_; + } + + // Send the remaining bytes. + next_state_ = STATE_PAYLOAD_WRITE; + return OK; +} + +int SSLClientSocketWin::DidCompleteHandshake() { + SECURITY_STATUS status = QueryContextAttributes( + &ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_); + if (status != SEC_E_OK) { + DLOG(ERROR) << "QueryContextAttributes failed: " << status; + return MapSecurityError(status); + } + DCHECK(!server_cert_); + status = QueryContextAttributes( + &ctxt_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &server_cert_); + if (status != SEC_E_OK) { + DLOG(ERROR) << "QueryContextAttributes failed: " << status; + return MapSecurityError(status); + } + + completed_handshake_ = true; + int rv = VerifyServerCert(); + // TODO(wtc): for now, always check revocation. + server_cert_status_ = CERT_STATUS_REV_CHECKING_ENABLED; + if (rv) + server_cert_status_ |= MapNetErrorToCertStatus(rv); + return rv; +} + +int SSLClientSocketWin::VerifyServerCert() { + DCHECK(server_cert_); + + // Build and validate certificate chain. + + CERT_CHAIN_PARA chain_para; + memset(&chain_para, 0, sizeof(chain_para)); + chain_para.cbSize = sizeof(chain_para); + // TODO(wtc): consider requesting the usage szOID_PKIX_KP_SERVER_AUTH + // or szOID_SERVER_GATED_CRYPTO or szOID_SGC_NETSCAPE + chain_para.RequestedUsage.dwType = USAGE_MATCH_TYPE_AND; + chain_para.RequestedUsage.Usage.cUsageIdentifier = 0; + chain_para.RequestedUsage.Usage.rgpszUsageIdentifier = NULL; // LPSTR* + PCCERT_CHAIN_CONTEXT chain_context; + // TODO(wtc): for now, always check revocation. If we don't want to check + // revocation, use the CERT_CHAIN_REVOCATION_CHECK_CACHE_ONLY flag. + if (!CertGetCertificateChain( + NULL, // default chain engine, HCCE_CURRENT_USER + server_cert_, + NULL, // current system time + server_cert_->hCertStore, // search this store + &chain_para, + CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT | + CERT_CHAIN_CACHE_END_CERT, + NULL, // reserved + &chain_context)) { + return MapSecurityError(GetLastError()); + } + + std::wstring wstr_hostname = ASCIIToWide(hostname_); + + SSL_EXTRA_CERT_CHAIN_POLICY_PARA extra_policy_para; + memset(&extra_policy_para, 0, sizeof(extra_policy_para)); + extra_policy_para.cbSize = sizeof(extra_policy_para); + extra_policy_para.dwAuthType = AUTHTYPE_SERVER; + // TODO(wtc): Set these flags in fdwChecks to ignore cert errors. + // SECURITY_FLAG_IGNORE_REVOCATION + // SECURITY_FLAG_IGNORE_UNKNOWN_CA + // SECURITY_FLAG_IGNORE_WRONG_USAGE + // SECURITY_FLAG_IGNORE_CERT_CN_INVALID + // SECURITY_FLAG_IGNORE_CERT_DATE_INVALID + extra_policy_para.fdwChecks = 0; + extra_policy_para.pwszServerName = + const_cast(wstr_hostname.c_str()); + + CERT_CHAIN_POLICY_PARA policy_para; + memset(&policy_para, 0, sizeof(policy_para)); + policy_para.cbSize = sizeof(policy_para); + // TODO(wtc): It seems that we can also ignore cert errors by setting + // dwFlags. + policy_para.dwFlags = 0; + policy_para.pvExtraPolicyPara = &extra_policy_para; + + CERT_CHAIN_POLICY_STATUS policy_status; + memset(&policy_status, 0, sizeof(policy_status)); + policy_status.cbSize = sizeof(policy_status); + + if (!CertVerifyCertificateChainPolicy( + CERT_CHAIN_POLICY_SSL, + chain_context, + &policy_para, + &policy_status)) { + return MapSecurityError(GetLastError()); + } + + if (policy_status.dwError) + return MapSecurityError(policy_status.dwError); + + CertFreeCertificateChain(chain_context); + + return OK; +} + +} // namespace net + diff --git a/net/base/ssl_client_socket_win.h b/net/base/ssl_client_socket_win.h new file mode 100644 index 0000000..403e7f3 --- /dev/null +++ b/net/base/ssl_client_socket_win.h @@ -0,0 +1,139 @@ +// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_BASE_SSL_CLIENT_SOCKET_WIN_H_ +#define NET_BASE_SSL_CLIENT_SOCKET_WIN_H_ + +#define SECURITY_WIN32 // Needs to be defined before including security.h + +#include +#include +#include + +#include + +#include "base/scoped_ptr.h" +#include "net/base/completion_callback.h" +#include "net/base/ssl_client_socket.h" +#include "net/base/ssl_config_service.h" + +namespace net { + +// An SSL client socket implemented with the Windows Schannel. +class SSLClientSocketWin : public SSLClientSocket { + public: + // Takes ownership of the transport_socket, which may already be connected. + // The given hostname will be compared with the name(s) in the server's + // certificate during the SSL handshake. ssl_config specifies the SSL + // settings. + SSLClientSocketWin(ClientSocket* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config); + ~SSLClientSocketWin(); + + // SSLClientSocket methods: + virtual void GetSSLInfo(SSLInfo* ssl_info); + + // ClientSocket methods: + virtual int Connect(CompletionCallback* callback); + virtual int ReconnectIgnoringLastError(CompletionCallback* callback); + virtual void Disconnect(); + virtual bool IsConnected() const; + + // Socket methods: + virtual int Read(char* buf, int buf_len, CompletionCallback* callback); + virtual int Write(const char* buf, int buf_len, CompletionCallback* callback); + + private: + void DoCallback(int result); + void OnIOComplete(int result); + + int DoLoop(int last_io_result); + int DoConnect(); + int DoConnectComplete(int result); + int DoHandshakeRead(); + int DoHandshakeReadComplete(int result); + int DoHandshakeWrite(); + int DoHandshakeWriteComplete(int result); + int DoPayloadRead(); + int DoPayloadReadComplete(int result); + int DoPayloadEncrypt(); + int DoPayloadWrite(); + int DoPayloadWriteComplete(int result); + + int DidCompleteHandshake(); + int VerifyServerCert(); + + CompletionCallbackImpl io_callback_; + scoped_ptr transport_; + std::string hostname_; + SSLConfig ssl_config_; + + CompletionCallback* user_callback_; + + // Used by both Read and Write functions. + char* user_buf_; + int user_buf_len_; + + enum State { + STATE_NONE, + STATE_CONNECT, + STATE_CONNECT_COMPLETE, + STATE_HANDSHAKE_READ, + STATE_HANDSHAKE_READ_COMPLETE, + STATE_HANDSHAKE_WRITE, + STATE_HANDSHAKE_WRITE_COMPLETE, + STATE_PAYLOAD_ENCRYPT, + STATE_PAYLOAD_WRITE, + STATE_PAYLOAD_WRITE_COMPLETE, + STATE_PAYLOAD_READ, + STATE_PAYLOAD_READ_COMPLETE, + }; + State next_state_; + + SecPkgContext_StreamSizes stream_sizes_; + PCCERT_CONTEXT server_cert_; + int server_cert_status_; + + CredHandle creds_; + CtxtHandle ctxt_; + SecBuffer send_buffer_; + scoped_array payload_send_buffer_; + int payload_send_buffer_len_; + int bytes_sent_; + + // recv_buffer_ holds the received ciphertext. Since Schannel decrypts + // data in place, sometimes recv_buffer_ may contain decrypted plaintext and + // any undecrypted ciphertext. (Ciphertext is decrypted one full SSL record + // at a time.) + // + // If bytes_decrypted_ is 0, the received ciphertext is at the beginning of + // recv_buffer_, ready to be passed to DecryptMessage. + scoped_array recv_buffer_; + char* decrypted_ptr_; // Points to the decrypted plaintext in recv_buffer_ + int bytes_decrypted_; // The number of bytes of decrypted plaintext. + char* received_ptr_; // Points to the received ciphertext in recv_buffer_ + int bytes_received_; // The number of bytes of received ciphertext. + + bool completed_handshake_; + + // Only used in the STATE_HANDSHAKE_READ_COMPLETE and + // STATE_PAYLOAD_READ_COMPLETE states. True if a 'result' argument of OK + // should be ignored, to prevent it from being interpreted as EOF. + // + // The reason we need this flag is that OK means not only "0 bytes of data + // were read" but also EOF. We set ignore_ok_result_ to true when we need + // to continue processing previously read data without reading more data. + // We have to pass a 'result' of OK to the DoLoop method, and don't want it + // to be interpreted as EOF. + bool ignore_ok_result_; + + // True if the user has no client certificate. + bool no_client_cert_; +}; + +} // namespace net + +#endif // NET_BASE_SSL_CLIENT_SOCKET_WIN_H_ + diff --git a/net/build/net.vcproj b/net/build/net.vcproj index b56b05f..b711b49 100644 --- a/net/build/net.vcproj +++ b/net/build/net.vcproj @@ -433,11 +433,15 @@ > + + CreateSSLClientSocket(s, request_->url.host(), - ssl_version_mask_); + ssl_config_); connection_.set_socket(s); return connection_.socket()->Connect(&io_callback_); @@ -510,7 +503,7 @@ int HttpNetworkTransaction::DoSSLConnectOverTunnel() { // Add a SSL socket on top of our existing transport socket. ClientSocket* s = connection_.release_socket(); s = socket_factory_->CreateSSLClientSocket(s, request_->url.host(), - ssl_version_mask_); + ssl_config_); connection_.set_socket(s); return connection_.socket()->Connect(&io_callback_); } @@ -834,13 +827,11 @@ int HttpNetworkTransaction::DidReadResponseHeaders() { } } -#if defined(OS_WIN) if (using_ssl_ && !establishing_tunnel_) { SSLClientSocket* ssl_socket = reinterpret_cast(connection_.socket()); ssl_socket->GetSSLInfo(&response_.ssl_info); } -#endif return OK; } @@ -869,25 +860,22 @@ int HttpNetworkTransaction::HandleCertificateError(int error) { } } -#if defined(OS_WIN) if (error != OK) { SSLClientSocket* ssl_socket = reinterpret_cast(connection_.socket()); ssl_socket->GetSSLInfo(&response_.ssl_info); } -#endif return error; } int HttpNetworkTransaction::HandleSSLHandshakeError(int error) { -#if defined(OS_WIN) switch (error) { case ERR_SSL_PROTOCOL_ERROR: case ERR_SSL_VERSION_OR_CIPHER_MISMATCH: - if (ssl_version_mask_ & SSLClientSocket::TLS1) { + if (ssl_config_.tls1_enabled) { // This could be a TLS-intolerant server or an SSL 3.0 server that // chose a TLS-only cipher suite. Turn off TLS 1.0 and retry. - ssl_version_mask_ &= ~SSLClientSocket::TLS1; + ssl_config_.tls1_enabled = false; connection_.set_socket(NULL); connection_.Reset(); next_state_ = STATE_INIT_CONNECTION; @@ -895,7 +883,6 @@ int HttpNetworkTransaction::HandleSSLHandshakeError(int error) { } break; } -#endif return error; } @@ -1001,7 +988,7 @@ void HttpNetworkTransaction::AddAuthorizationHeader(HttpAuth::Target target) { // Add auth data to cache session_->auth_cache()->Add(auth_cache_key_[target], auth_data_[target]); - + // Add a Authorization/Proxy-Authorization header line. std::string credentials = auth_handler_[target]->GenerateCredentials( auth_data_[target]->username, diff --git a/net/http/http_network_transaction.h b/net/http/http_network_transaction.h index bbbfb74..475056e 100644 --- a/net/http/http_network_transaction.h +++ b/net/http/http_network_transaction.h @@ -11,6 +11,7 @@ #include "net/base/address_list.h" #include "net/base/client_socket_handle.h" #include "net/base/host_resolver.h" +#include "net/base/ssl_config_service.h" #include "net/http/http_auth.h" #include "net/http/http_auth_handler.h" #include "net/http/http_response_info.h" @@ -186,7 +187,7 @@ class HttpNetworkTransaction : public HttpTransaction { // the real request/response of the transaction. bool establishing_tunnel_; - int ssl_version_mask_; + SSLConfig ssl_config_; std::string request_headers_; size_t request_headers_bytes_sent_; diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index 56fb5cd..9c2d6a7 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -137,7 +137,7 @@ class MockTCPClientSocket : public net::ClientSocket { // Not using mock writes; succeed synchronously. if (!data_->writes) return buf_len; - + // Check that what we are writing matches the expectation. // Then give the mocked return value. MockWrite& w = data_->writes[write_index_]; @@ -185,10 +185,10 @@ class MockClientSocketFactory : public net::ClientSocketFactory { const net::AddressList& addresses) { return new MockTCPClientSocket(addresses); } - virtual net::ClientSocket* CreateSSLClientSocket( + virtual net::SSLClientSocket* CreateSSLClientSocket( net::ClientSocket* transport_socket, const std::string& hostname, - int protocol_version_mask) { + const net::SSLConfig& ssl_config) { return NULL; } }; @@ -623,7 +623,7 @@ TEST_F(HttpNetworkTransactionTest, BasicAuth) { MockRead("HTTP/1.0 401 Unauthorized\r\n"), // Give a couple authenticate options (only the middle one is actually // supported). - MockRead("WWW-Authenticate: Basic\r\n"), // Malformed + MockRead("WWW-Authenticate: Basic\r\n"), // Malformed MockRead("WWW-Authenticate: Basic realm=\"MyRealm1\"\r\n"), MockRead("WWW-Authenticate: UNSUPPORTED realm=\"FOO\"\r\n"), MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), @@ -717,7 +717,7 @@ TEST_F(HttpNetworkTransactionTest, BasicAuthProxyThenServer) { MockRead("HTTP/1.0 407 Unauthorized\r\n"), // Give a couple authenticate options (only the middle one is actually // supported). - MockRead("Proxy-Authenticate: Basic\r\n"), // Malformed + MockRead("Proxy-Authenticate: Basic\r\n"), // Malformed MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"), MockRead("Proxy-Authenticate: UNSUPPORTED realm=\"FOO\"\r\n"), MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), @@ -745,7 +745,7 @@ TEST_F(HttpNetworkTransactionTest, BasicAuthProxyThenServer) { MockRead("WWW-Authenticate: Basic realm=\"MyRealm1\"\r\n"), MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), MockRead("Content-Length: 2000\r\n\r\n"), - MockRead(false, net::ERR_FAILED), // Won't be reached. + MockRead(false, net::ERR_FAILED), // Won't be reached. }; // After calling trans->RestartWithAuth() the second time, we should send diff --git a/net/net.xcodeproj/project.pbxproj b/net/net.xcodeproj/project.pbxproj index 15f2211..a594020 100644 --- a/net/net.xcodeproj/project.pbxproj +++ b/net/net.xcodeproj/project.pbxproj @@ -491,7 +491,6 @@ 7BED32940E5A181C00A747DB /* ssl_config_service.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ssl_config_service.cc; sourceTree = ""; }; 7BED32950E5A181C00A747DB /* ssl_client_socket_unittest.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ssl_client_socket_unittest.cc; sourceTree = ""; }; 7BED32960E5A181C00A747DB /* ssl_client_socket.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ssl_client_socket.h; sourceTree = ""; }; - 7BED32970E5A181C00A747DB /* ssl_client_socket.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ssl_client_socket.cc; sourceTree = ""; }; 7BED32980E5A181C00A747DB /* socket.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = socket.h; sourceTree = ""; }; 7BED32990E5A181C00A747DB /* registry_controlled_domain_unittest.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = registry_controlled_domain_unittest.cc; sourceTree = ""; }; 7BED329A0E5A181C00A747DB /* registry_controlled_domain.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = registry_controlled_domain.h; sourceTree = ""; }; @@ -946,7 +945,6 @@ E49DD2E80E892F8C003C7A87 /* sdch_manager.cc */, E49DD2E90E892F8C003C7A87 /* sdch_manager.h */, 7BED32980E5A181C00A747DB /* socket.h */, - 7BED32970E5A181C00A747DB /* ssl_client_socket.cc */, 7BED32960E5A181C00A747DB /* ssl_client_socket.h */, 7BED32950E5A181C00A747DB /* ssl_client_socket_unittest.cc */, 7BED32940E5A181C00A747DB /* ssl_config_service.cc */, -- cgit v1.1