diff options
author | wtc@google.com <wtc@google.com@0039d316-1c4b-4281-b951-d872f2087c98> | 2008-08-14 20:33:25 +0000 |
---|---|---|
committer | wtc@google.com <wtc@google.com@0039d316-1c4b-4281-b951-d872f2087c98> | 2008-08-14 20:33:25 +0000 |
commit | 4628a2a1293b6661630162edfce543998c69f105 (patch) | |
tree | 40b2ffaf5d0f56c27b56aa44391bd0271bcb8302 /net/base | |
parent | 0297f4f98eedef215784483827deb2356f44e7ca (diff) | |
download | chromium_src-4628a2a1293b6661630162edfce543998c69f105.zip chromium_src-4628a2a1293b6661630162edfce543998c69f105.tar.gz chromium_src-4628a2a1293b6661630162edfce543998c69f105.tar.bz2 |
First cut at implementing SSLClientSocket using Schannel.
Not implemented:
- Handling certificate errors
- Handling session renegotiation
- Sending the close_notify alert
- Miscellaneous TODOs and DCHECKs in the code.
R=darin
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@884 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net/base')
-rw-r--r-- | net/base/ssl_client_socket.cc | 390 | ||||
-rw-r--r-- | net/base/ssl_client_socket.h | 32 | ||||
-rw-r--r-- | net/base/ssl_client_socket_unittest.cc | 16 |
3 files changed, 368 insertions, 70 deletions
diff --git a/net/base/ssl_client_socket.cc b/net/base/ssl_client_socket.cc index 711a114..706034f 100644 --- a/net/base/ssl_client_socket.cc +++ b/net/base/ssl_client_socket.cc @@ -34,6 +34,7 @@ #include "base/singleton.h" #include "base/string_util.h" #include "net/base/net_errors.h" +#include "net/base/ssl_info.h" namespace net { @@ -41,38 +42,43 @@ namespace net { class SChannelLib { public: - SecurityFunctionTable funcs; + PSecurityFunctionTable funcs; - SChannelLib() { - memset(&funcs, 0, sizeof(funcs)); - lib_ = LoadLibrary(L"SCHANNEL.DLL"); + SChannelLib() : funcs(NULL) { + lib_ = LoadLibrary(L"secur32.dll"); if (lib_) { INIT_SECURITY_INTERFACE init_security_interface = reinterpret_cast<INIT_SECURITY_INTERFACE>( GetProcAddress(lib_, "InitSecurityInterfaceW")); - if (init_security_interface) { - PSecurityFunctionTable funcs_ptr = init_security_interface(); - if (funcs_ptr) - memcpy(&funcs, funcs_ptr, sizeof(funcs)); - } + if (init_security_interface) + funcs = init_security_interface(); } } ~SChannelLib() { - FreeLibrary(lib_); + if (lib_) + FreeLibrary(lib_); } private: HMODULE lib_; }; -static inline SecurityFunctionTable& SChannel() { +static inline PSecurityFunctionTable SChannel() { return Singleton<SChannelLib>()->funcs; } //----------------------------------------------------------------------------- -static const int kRecvBufferSize = 0x10000; +// Size of recv_buffer_ +// +// Ciphertext is decrypted one SSL record at a time, so recv_buffer_ needs to +// have room for a full SSL record, with the header and trailer. Here is the +// breakdown of the size: +// 5: SSL record header +// 16K: SSL record maximum size +// 64: >= SSL record trailer (16 or 20 have been observed) +static const int kRecvBufferSize = (5 + 16*1024 + 64); SSLClientSocket::SSLClientSocket(ClientSocket* transport_socket, const std::string& hostname) @@ -84,9 +90,14 @@ SSLClientSocket::SSLClientSocket(ClientSocket* transport_socket, user_buf_(NULL), user_buf_len_(0), next_state_(STATE_NONE), + payload_send_buffer_len_(0), bytes_sent_(0), + decrypted_ptr_(NULL), + bytes_decrypted_(0), + received_ptr_(NULL), bytes_received_(0), - completed_handshake_(false) { + completed_handshake_(false), + ignore_ok_result_(false) { memset(&stream_sizes_, 0, sizeof(stream_sizes_)); memset(&send_buffer_, 0, sizeof(send_buffer_)); memset(&creds_, 0, sizeof(creds_)); @@ -115,20 +126,25 @@ int SSLClientSocket::ReconnectIgnoringLastError(CompletionCallback* callback) { } void SSLClientSocket::Disconnect() { + // TODO(wtc): Send SSL close_notify alert. + completed_handshake_ = false; transport_->Disconnect(); if (send_buffer_.pvBuffer) { - SChannel().FreeContextBuffer(send_buffer_.pvBuffer); + SChannel()->FreeContextBuffer(send_buffer_.pvBuffer); memset(&send_buffer_, 0, sizeof(send_buffer_)); } if (creds_.dwLower || creds_.dwUpper) { - SChannel().FreeCredentialsHandle(&creds_); + SChannel()->FreeCredentialsHandle(&creds_); memset(&creds_, 0, sizeof(creds_)); } if (ctxt_.dwLower || ctxt_.dwUpper) { - SChannel().DeleteSecurityContext(&ctxt_); + SChannel()->DeleteSecurityContext(&ctxt_); memset(&ctxt_, 0, sizeof(ctxt_)); } + // TODO(wtc): reset more members? + bytes_decrypted_ = 0; + bytes_received_ = 0; } bool SSLClientSocket::IsConnected() const { @@ -141,10 +157,32 @@ int SSLClientSocket::Read(char* buf, int buf_len, DCHECK(next_state_ == STATE_NONE); DCHECK(!user_callback_); + // If we have surplus decrypted plaintext, satisfy the Read with it without + // reading more ciphertext from the transport socket. + if (bytes_decrypted_ != 0) { + int len = std::min(buf_len, bytes_decrypted_); + memcpy(buf, decrypted_ptr_, len); + decrypted_ptr_ += len; + bytes_decrypted_ -= len; + if (bytes_decrypted_ == 0) { + decrypted_ptr_ = NULL; + if (bytes_received_ != 0) { + memmove(recv_buffer_.get(), received_ptr_, bytes_received_); + received_ptr_ = recv_buffer_.get(); + } + } + return len; + } + user_buf_ = buf; user_buf_len_ = buf_len; - next_state_ = STATE_PAYLOAD_READ; + if (bytes_received_ == 0) { + next_state_ = STATE_PAYLOAD_READ; + } else { + next_state_ = STATE_PAYLOAD_READ_COMPLETE; + ignore_ok_result_ = true; // OK doesn't mean EOF. + } int rv = DoLoop(OK); if (rv == ERR_IO_PENDING) user_callback_ = callback; @@ -160,18 +198,41 @@ int SSLClientSocket::Write(const char* buf, int buf_len, user_buf_ = const_cast<char*>(buf); user_buf_len_ = buf_len; - next_state_ = STATE_PAYLOAD_WRITE; + next_state_ = STATE_PAYLOAD_ENCRYPT; int rv = DoLoop(OK); if (rv == ERR_IO_PENDING) user_callback_ = callback; return rv; } +void SSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + SECURITY_STATUS status; + PCCERT_CONTEXT server_cert = NULL; + status = SChannel()->QueryContextAttributes(&ctxt_, + SECPKG_ATTR_REMOTE_CERT_CONTEXT, + &server_cert); + if (status == SEC_E_OK) { + DCHECK(server_cert); + ssl_info->cert = X509Certificate::CreateFromHandle(server_cert); + } + SecPkgContext_ConnectionInfo connection_info; + status = SChannel()->QueryContextAttributes(&ctxt_, + SECPKG_ATTR_CONNECTION_INFO, + &connection_info); + if (status == SEC_E_OK) { + // TODO(wtc): compute the overall security strength, taking into account + // dwExchStrength and dwHashStrength. dwExchStrength needs to be + // normalized. + ssl_info->security_bits = connection_info.dwCipherStrength; + } + ssl_info->cert_status = 0; +} + void SSLClientSocket::DoCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); DCHECK(user_callback_); - // since Run may result in Read being called, clear callback_ up front. + // since Run may result in Read being called, clear user_callback_ up front. CompletionCallback* c = user_callback_; user_callback_ = NULL; c->Run(rv); @@ -214,6 +275,9 @@ int SSLClientSocket::DoLoop(int last_io_result) { case STATE_PAYLOAD_READ_COMPLETE: rv = DoPayloadReadComplete(rv); break; + case STATE_PAYLOAD_ENCRYPT: + rv = DoPayloadEncrypt(); + break; case STATE_PAYLOAD_WRITE: rv = DoPayloadWrite(); break; @@ -242,25 +306,51 @@ int SSLClientSocket::DoConnectComplete(int result) { SCHANNEL_CRED schannel_cred = {0}; schannel_cred.dwVersion = SCHANNEL_CRED_VERSION; + + // TODO(wtc): This should be configurable. Hardcoded to do SSL 3.0 and + // TLS 1.0 for now. The default (0) means Schannel selects the protocol. + // The global system registry settings take precedence over this value. + schannel_cred.grbitEnabledProtocols = SP_PROT_SSL3TLS1; + + // The default session lifetime is 36000000 milliseconds (ten hours). Set + // schannel_cred.dwSessionLifespan to change the number of milliseconds that + // Schannel keeps the session in its session cache. + + // We can set the key exchange algorithms (RSA or DH) in + // schannel_cred.{cSupportedAlgs,palgSupportedAlgs}. + + // TODO(wtc): We may need to use SCH_CRED_IGNORE_NO_REVOCATION_CHECK and + // SCH_CRED_IGNORE_REVOCATION_OFFLINE, but only after getting the + // CRYPT_E_NO_REVOCATION_CHECK and CRYPT_E_REVOCATION_OFFLINE errors. + // + // Look into undocumented or poorly documented flags: + // SCH_CRED_RESTRICTED_ROOTS + // SCH_CRED_REVOCATION_CHECK_CACHE_ONLY + // SCH_CRED_CACHE_ONLY_URL_RETRIEVAL + // SCH_CRED_MEMORY_STORE_CERT + // SCH_CRED_CACHE_ONLY_URL_RETRIEVAL_ON_CREATE + // + // SCH_CRED_NO_SERVERNAME_CHECK can be useful during testing. schannel_cred.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS | - SCH_CRED_NO_SYSTEM_MAPPER | - SCH_CRED_REVOCATION_CHECK_CHAIN; + SCH_CRED_NO_SERVERNAME_CHECK | // Remove me! + SCH_CRED_AUTO_CRED_VALIDATION | + SCH_CRED_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT; TimeStamp expiry; SECURITY_STATUS status; - status = SChannel().AcquireCredentialsHandle( - NULL, - UNISP_NAME, + status = SChannel()->AcquireCredentialsHandle( + NULL, // Not used + UNISP_NAME, // Microsoft Unified Security Protocol Provider SECPKG_CRED_OUTBOUND, - NULL, + NULL, // Not used &schannel_cred, - NULL, - NULL, + NULL, // Not used + NULL, // Not used &creds_, - &expiry); + &expiry); // Optional if (status != SEC_E_OK) { DLOG(ERROR) << "AcquireCredentialsHandle failed: " << status; - return ERR_FAILED; + return ERR_FAILED; // TODO(wtc): map SEC_E_xxx error codes. } SecBufferDesc buffer_desc; @@ -280,22 +370,22 @@ int SSLClientSocket::DoConnectComplete(int result) { buffer_desc.pBuffers = &send_buffer_; buffer_desc.ulVersion = SECBUFFER_VERSION; - status = SChannel().InitializeSecurityContext( + status = SChannel()->InitializeSecurityContext( &creds_, - NULL, + NULL, // NULL on the first call const_cast<wchar_t*>(ASCIIToWide(hostname_).c_str()), flags, - 0, - SECURITY_NATIVE_DREP, - NULL, - 0, - &ctxt_, + 0, // Reserved + SECURITY_NATIVE_DREP, // TODO(wtc): MSDN says this should be set to 0. + NULL, // NULL on the first call + 0, // Reserved + &ctxt_, // Receives the new context handle &buffer_desc, &out_flags, &expiry); if (status != SEC_I_CONTINUE_NEEDED) { DLOG(ERROR) << "InitializeSecurityContext failed: " << status; - return ERR_FAILED; + return ERR_FAILED; // TODO(wtc): map SEC_E_xxx error codes. } next_state_ = STATE_HANDSHAKE_WRITE; @@ -322,9 +412,11 @@ int SSLClientSocket::DoHandshakeRead() { int SSLClientSocket::DoHandshakeReadComplete(int result) { if (result < 0) return result; - if (result == 0) + if (result == 0 && !ignore_ok_result_) return ERR_FAILED; // Incomplete response :( + ignore_ok_result_ = false; + bytes_received_ += result; // Process the contents of recv_buffer_. @@ -362,7 +454,7 @@ int SSLClientSocket::DoHandshakeReadComplete(int result) { send_buffer_.BufferType = SECBUFFER_TOKEN; send_buffer_.cbBuffer = 0; - status = SChannel().InitializeSecurityContext( + status = SChannel()->InitializeSecurityContext( &creds_, &ctxt_, NULL, @@ -377,18 +469,25 @@ int SSLClientSocket::DoHandshakeReadComplete(int result) { &expiry); if (status == SEC_E_INCOMPLETE_MESSAGE) { + DCHECK(FAILED(status)); + DCHECK(send_buffer_.cbBuffer == 0 || + !(out_flags & ISC_RET_EXTENDED_ERROR)); next_state_ = STATE_HANDSHAKE_READ; return OK; } - // OK, all of the received data was consumed. - bytes_received_ = 0; - if (send_buffer_.cbBuffer != 0 && (status == SEC_E_OK || status == SEC_I_CONTINUE_NEEDED || FAILED(status) && (out_flags & ISC_RET_EXTENDED_ERROR))) { + // TODO(wtc): if status is SEC_E_OK, we should finish the handshake + // successfully after sending send_buffer_. + // If FAILED(status) is true, we should terminate the connection after + // sending send_buffer_. + DCHECK(status == SEC_I_CONTINUE_NEEDED); // We only handle this case + // correctly. next_state_ = STATE_HANDSHAKE_WRITE; + bytes_received_ = 0; return OK; } @@ -396,13 +495,28 @@ int SSLClientSocket::DoHandshakeReadComplete(int result) { if (in_buffers[1].BufferType == SECBUFFER_EXTRA) { // TODO(darin) need to save this data for later. NOTREACHED() << "should not occur for HTTPS traffic"; + return ERR_FAILED; } + bytes_received_ = 0; return DidCompleteHandshake(); } if (FAILED(status)) - return ERR_FAILED; + return ERR_FAILED; // TODO(wtc): map error codes, in particular cert + // errors such as SEC_E_UNTRUSTED_ROOT. + + DCHECK(status == SEC_I_CONTINUE_NEEDED); + if (in_buffers[1].BufferType == SECBUFFER_EXTRA) { + memmove(&recv_buffer_[0], + &recv_buffer_[0] + (bytes_received_ - in_buffers[1].cbBuffer), + in_buffers[1].cbBuffer); + bytes_received_ = in_buffers[1].cbBuffer; + next_state_ = STATE_HANDSHAKE_READ_COMPLETE; + ignore_ok_result_ = true; // OK doesn't mean EOF. + return OK; + } + bytes_received_ = 0; next_state_ = STATE_HANDSHAKE_READ; return OK; } @@ -426,14 +540,18 @@ int SSLClientSocket::DoHandshakeWriteComplete(int result) { DCHECK(result != 0); - // TODO(darin): worry about overflow? bytes_sent_ += result; DCHECK(bytes_sent_ <= static_cast<int>(send_buffer_.cbBuffer)); - if (bytes_sent_ == static_cast<int>(send_buffer_.cbBuffer)) { - SChannel().FreeContextBuffer(send_buffer_.pvBuffer); + if (bytes_sent_ >= static_cast<int>(send_buffer_.cbBuffer)) { + bool overflow = (bytes_sent_ > static_cast<int>(send_buffer_.cbBuffer)); + SECURITY_STATUS status = + SChannel()->FreeContextBuffer(send_buffer_.pvBuffer); + DCHECK(status == SEC_E_OK); memset(&send_buffer_, 0, sizeof(send_buffer_)); bytes_sent_ = 0; + if (overflow) // Bug! + return ERR_UNEXPECTED; next_state_ = STATE_HANDSHAKE_READ; } else { // Send the remaining bytes. @@ -446,46 +564,196 @@ int SSLClientSocket::DoHandshakeWriteComplete(int result) { int SSLClientSocket::DoPayloadRead() { next_state_ = STATE_PAYLOAD_READ_COMPLETE; - return ERR_FAILED; + DCHECK(recv_buffer_.get()); + + char* buf = recv_buffer_.get() + bytes_received_; + int buf_len = kRecvBufferSize - bytes_received_; + + if (buf_len <= 0) { + NOTREACHED() << "Receive buffer is too small!"; + return ERR_FAILED; + } + + return transport_->Read(buf, buf_len, &io_callback_); } int SSLClientSocket::DoPayloadReadComplete(int result) { - return ERR_FAILED; + if (result < 0) + return result; + if (result == 0 && !ignore_ok_result_) { + // TODO(wtc): Unless we have received the close_notify alert, we need to + // return an error code indicating that the SSL connection ended + // uncleanly, a potential truncation attack. + if (bytes_received_ != 0) + return ERR_FAILED; + return OK; + } + + ignore_ok_result_ = false; + + bytes_received_ += result; + + // Process the contents of recv_buffer_. + SecBuffer buffers[4]; + buffers[0].pvBuffer = recv_buffer_.get(); + buffers[0].cbBuffer = bytes_received_; + buffers[0].BufferType = SECBUFFER_DATA; + + buffers[1].BufferType = SECBUFFER_EMPTY; + buffers[2].BufferType = SECBUFFER_EMPTY; + buffers[3].BufferType = SECBUFFER_EMPTY; + + SecBufferDesc buffer_desc; + buffer_desc.cBuffers = 4; + buffer_desc.pBuffers = buffers; + buffer_desc.ulVersion = SECBUFFER_VERSION; + + SECURITY_STATUS status; + status = SChannel()->DecryptMessage(&ctxt_, &buffer_desc, 0, NULL); + + if (status == SEC_E_INCOMPLETE_MESSAGE) { + next_state_ = STATE_PAYLOAD_READ; + return OK; + } + + if (status == SEC_I_CONTEXT_EXPIRED) { + // Received the close_notify alert. + bytes_received_ = 0; + return OK; + } + + if (status != SEC_E_OK && status != SEC_I_RENEGOTIATE) { + DCHECK(status != SEC_E_MESSAGE_ALTERED); + return ERR_FAILED; // TODO(wtc): map error code + } + + // The received ciphertext was decrypted in place in recv_buffer_. Remember + // the location and length of the decrypted plaintext and any unused + // ciphertext. + decrypted_ptr_ = NULL; + bytes_decrypted_ = 0; + received_ptr_ = NULL; + bytes_received_ = 0; + for (int i = 1; i < 4; i++) { + if (!decrypted_ptr_ && buffers[i].BufferType == SECBUFFER_DATA) { + decrypted_ptr_ = static_cast<char*>(buffers[i].pvBuffer); + bytes_decrypted_ = buffers[i].cbBuffer; + } + if (!received_ptr_ && buffers[i].BufferType == SECBUFFER_EXTRA) { + received_ptr_ = static_cast<char*>(buffers[i].pvBuffer); + bytes_received_ = buffers[i].cbBuffer; + } + } + + int len = 0; + if (bytes_decrypted_ != 0) { + len = std::min(user_buf_len_, bytes_decrypted_); + memcpy(user_buf_, decrypted_ptr_, len); + decrypted_ptr_ += len; + bytes_decrypted_ -= len; + } + if (bytes_decrypted_ == 0) { + decrypted_ptr_ = NULL; + if (bytes_received_ != 0) { + memmove(recv_buffer_.get(), received_ptr_, bytes_received_); + received_ptr_ = recv_buffer_.get(); + } + } + // TODO(wtc): need to handle SEC_I_RENEGOTIATE. + DCHECK(status == SEC_E_OK); + return len; } -int SSLClientSocket::DoPayloadWrite() { +int SSLClientSocket::DoPayloadEncrypt() { DCHECK(user_buf_); DCHECK(user_buf_len_ > 0); - next_state_ = STATE_PAYLOAD_WRITE_COMPLETE; - - size_t message_len = std::min( + ULONG message_len = std::min( stream_sizes_.cbMaximumMessage, static_cast<ULONG>(user_buf_len_)); - size_t alloc_len = + ULONG alloc_len = message_len + stream_sizes_.cbHeader + stream_sizes_.cbTrailer; + user_buf_len_ = message_len; + + payload_send_buffer_.reset(new char[alloc_len]); + memcpy(&payload_send_buffer_[stream_sizes_.cbHeader], + user_buf_, message_len); - /* SecBuffer buffers[4]; - buffers[0]. + buffers[0].pvBuffer = payload_send_buffer_.get(); + buffers[0].cbBuffer = stream_sizes_.cbHeader; + buffers[0].BufferType = SECBUFFER_STREAM_HEADER; + + buffers[1].pvBuffer = &payload_send_buffer_[stream_sizes_.cbHeader]; + buffers[1].cbBuffer = message_len; + buffers[1].BufferType = SECBUFFER_DATA; + + buffers[2].pvBuffer = &payload_send_buffer_[stream_sizes_.cbHeader + + message_len]; + buffers[2].cbBuffer = stream_sizes_.cbTrailer; + buffers[2].BufferType = SECBUFFER_STREAM_TRAILER; + + buffers[3].BufferType = SECBUFFER_EMPTY; SecBufferDesc buffer_desc; buffer_desc.cBuffers = 4; - buffer_desc.pBuffers = //XXX + buffer_desc.pBuffers = buffers; buffer_desc.ulVersion = SECBUFFER_VERSION; - SECURITY_STATUS status = SChannel().EncryptMessage( + SECURITY_STATUS status = SChannel()->EncryptMessage( &ctxt_, 0, &buffer_desc, 0); - */ - return ERR_FAILED; + if (FAILED(status)) + return ERR_FAILED; + + payload_send_buffer_len_ = buffers[0].cbBuffer + + buffers[1].cbBuffer + + buffers[2].cbBuffer; + DCHECK(bytes_sent_ == 0); + + next_state_ = STATE_PAYLOAD_WRITE; + return OK; +} + +int SSLClientSocket::DoPayloadWrite() { + next_state_ = STATE_PAYLOAD_WRITE_COMPLETE; + + // We should have something to send. + DCHECK(payload_send_buffer_.get()); + DCHECK(payload_send_buffer_len_ > 0); + + const char* buf = payload_send_buffer_.get() + bytes_sent_; + int buf_len = payload_send_buffer_len_ - bytes_sent_; + + return transport_->Write(buf, buf_len, &io_callback_); } int SSLClientSocket::DoPayloadWriteComplete(int result) { - return ERR_FAILED; + if (result < 0) + return result; + + DCHECK(result != 0); + + bytes_sent_ += result; + DCHECK(bytes_sent_ <= payload_send_buffer_len_); + + if (bytes_sent_ >= payload_send_buffer_len_) { + bool overflow = (bytes_sent_ > payload_send_buffer_len_); + payload_send_buffer_.reset(); + payload_send_buffer_len_ = 0; + bytes_sent_ = 0; + if (overflow) // Bug! + return ERR_UNEXPECTED; + // Done + return user_buf_len_; + } + + // Send the remaining bytes. + next_state_ = STATE_PAYLOAD_WRITE; + return OK; } int SSLClientSocket::DidCompleteHandshake() { - SECURITY_STATUS status = SChannel().QueryContextAttributes( + SECURITY_STATUS status = SChannel()->QueryContextAttributes( &ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_); if (status != SEC_E_OK) { DLOG(ERROR) << "QueryContextAttributes failed: " << status; diff --git a/net/base/ssl_client_socket.h b/net/base/ssl_client_socket.h index 3e1779f..8711e2a 100644 --- a/net/base/ssl_client_socket.h +++ b/net/base/ssl_client_socket.h @@ -43,6 +43,8 @@ namespace net { +class SSLInfo; + // A client socket that uses SSL as the transport layer. // // NOTE: The SSL handshake occurs within the Connect method after a TCP @@ -68,6 +70,9 @@ class SSLClientSocket : public ClientSocket { virtual int Read(char* buf, int buf_len, CompletionCallback* callback); virtual int Write(const char* buf, int buf_len, CompletionCallback* callback); + // Gets the SSL connection information of the socket. + void GetSSLInfo(SSLInfo* ssl_info); + private: void DoCallback(int result); void OnIOComplete(int result); @@ -81,6 +86,7 @@ class SSLClientSocket : public ClientSocket { int DoHandshakeWriteComplete(int result); int DoPayloadRead(); int DoPayloadReadComplete(int result); + int DoPayloadEncrypt(); int DoPayloadWrite(); int DoPayloadWriteComplete(int result); @@ -104,6 +110,7 @@ class SSLClientSocket : public ClientSocket { STATE_HANDSHAKE_READ_COMPLETE, STATE_HANDSHAKE_WRITE, STATE_HANDSHAKE_WRITE_COMPLETE, + STATE_PAYLOAD_ENCRYPT, STATE_PAYLOAD_WRITE, STATE_PAYLOAD_WRITE_COMPLETE, STATE_PAYLOAD_READ, @@ -116,12 +123,35 @@ class SSLClientSocket : public ClientSocket { CredHandle creds_; CtxtHandle ctxt_; SecBuffer send_buffer_; + scoped_array<char> payload_send_buffer_; + int payload_send_buffer_len_; int bytes_sent_; + // recv_buffer_ holds the received ciphertext. Since Schannel decrypts + // data in place, sometimes recv_buffer_ may contain decrypted plaintext and + // any undecrypted ciphertext. (Ciphertext is decrypted one full SSL record + // at a time.) + // + // If bytes_decrypted_ is 0, the received ciphertext is at the beginning of + // recv_buffer_, ready to be passed to DecryptMessage. scoped_array<char> recv_buffer_; - int bytes_received_; + char* decrypted_ptr_; // Points to the decrypted plaintext in recv_buffer_ + int bytes_decrypted_; // The number of bytes of decrypted plaintext. + char* received_ptr_; // Points to the received ciphertext in recv_buffer_ + int bytes_received_; // The number of bytes of received ciphertext. bool completed_handshake_; + + // Only used in the STATE_HANDSHAKE_READ_COMPLETE and + // STATE_PAYLOAD_READ_COMPLETE states. True if a 'result' argument of OK + // should be ignored, to prevent it from being interpreted as EOF. + // + // The reason we need this flag is that OK means not only "0 bytes of data + // were read" but also EOF. We set ignore_ok_result_ to true when we need + // to continue processing previously read data without reading more data. + // We have to pass a 'result' of OK to the DoLoop method, and don't want it + // to be interpreted as EOF. + bool ignore_ok_result_; }; } // namespace net diff --git a/net/base/ssl_client_socket_unittest.cc b/net/base/ssl_client_socket_unittest.cc index ebcce65..7bf8ea3 100644 --- a/net/base/ssl_client_socket_unittest.cc +++ b/net/base/ssl_client_socket_unittest.cc @@ -51,7 +51,7 @@ TEST_F(SSLClientSocketTest, Connect) { net::HostResolver resolver; TestCompletionCallback callback; - std::string hostname = "www.verisign.com"; + std::string hostname = "bugs.webkit.org"; int rv = resolver.Resolve(hostname, 443, &addr, NULL); EXPECT_EQ(net::OK, rv); @@ -73,13 +73,12 @@ TEST_F(SSLClientSocketTest, Connect) { EXPECT_FALSE(sock.IsConnected()); } -#if 0 TEST_F(SSLClientSocketTest, Read) { net::AddressList addr; net::HostResolver resolver; TestCompletionCallback callback; - std::string hostname = "www.google.com"; + std::string hostname = "bugs.webkit.org"; int rv = resolver.Resolve(hostname, 443, &addr, &callback); EXPECT_EQ(rv, net::ERR_IO_PENDING); @@ -124,10 +123,11 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) { net::HostResolver resolver; TestCompletionCallback callback; - int rv = resolver.Resolve("www.google.com", 80, &addr, NULL); + std::string hostname = "bugs.webkit.org"; + int rv = resolver.Resolve(hostname, 443, &addr, NULL); EXPECT_EQ(rv, net::OK); - net::TCPClientSocket sock(addr); + net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname); rv = sock.Connect(&callback); if (rv != net::OK) { @@ -165,10 +165,11 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) { net::HostResolver resolver; TestCompletionCallback callback; - int rv = resolver.Resolve("www.google.com", 80, &addr, NULL); + std::string hostname = "bugs.webkit.org"; + int rv = resolver.Resolve(hostname, 443, &addr, NULL); EXPECT_EQ(rv, net::OK); - net::TCPClientSocket sock(addr); + net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname); rv = sock.Connect(&callback); if (rv != net::OK) { @@ -197,4 +198,3 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) { EXPECT_NE(rv, 0); } -#endif |