summaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
authorwtc@google.com <wtc@google.com@0039d316-1c4b-4281-b951-d872f2087c98>2008-08-14 20:33:25 +0000
committerwtc@google.com <wtc@google.com@0039d316-1c4b-4281-b951-d872f2087c98>2008-08-14 20:33:25 +0000
commit4628a2a1293b6661630162edfce543998c69f105 (patch)
tree40b2ffaf5d0f56c27b56aa44391bd0271bcb8302 /net
parent0297f4f98eedef215784483827deb2356f44e7ca (diff)
downloadchromium_src-4628a2a1293b6661630162edfce543998c69f105.zip
chromium_src-4628a2a1293b6661630162edfce543998c69f105.tar.gz
chromium_src-4628a2a1293b6661630162edfce543998c69f105.tar.bz2
First cut at implementing SSLClientSocket using Schannel.
Not implemented: - Handling certificate errors - Handling session renegotiation - Sending the close_notify alert - Miscellaneous TODOs and DCHECKs in the code. R=darin git-svn-id: svn://svn.chromium.org/chrome/trunk/src@884 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
-rw-r--r--net/base/ssl_client_socket.cc390
-rw-r--r--net/base/ssl_client_socket.h32
-rw-r--r--net/base/ssl_client_socket_unittest.cc16
-rw-r--r--net/http/http_network_transaction.cc9
4 files changed, 376 insertions, 71 deletions
diff --git a/net/base/ssl_client_socket.cc b/net/base/ssl_client_socket.cc
index 711a114..706034f 100644
--- a/net/base/ssl_client_socket.cc
+++ b/net/base/ssl_client_socket.cc
@@ -34,6 +34,7 @@
#include "base/singleton.h"
#include "base/string_util.h"
#include "net/base/net_errors.h"
+#include "net/base/ssl_info.h"
namespace net {
@@ -41,38 +42,43 @@ namespace net {
class SChannelLib {
public:
- SecurityFunctionTable funcs;
+ PSecurityFunctionTable funcs;
- SChannelLib() {
- memset(&funcs, 0, sizeof(funcs));
- lib_ = LoadLibrary(L"SCHANNEL.DLL");
+ SChannelLib() : funcs(NULL) {
+ lib_ = LoadLibrary(L"secur32.dll");
if (lib_) {
INIT_SECURITY_INTERFACE init_security_interface =
reinterpret_cast<INIT_SECURITY_INTERFACE>(
GetProcAddress(lib_, "InitSecurityInterfaceW"));
- if (init_security_interface) {
- PSecurityFunctionTable funcs_ptr = init_security_interface();
- if (funcs_ptr)
- memcpy(&funcs, funcs_ptr, sizeof(funcs));
- }
+ if (init_security_interface)
+ funcs = init_security_interface();
}
}
~SChannelLib() {
- FreeLibrary(lib_);
+ if (lib_)
+ FreeLibrary(lib_);
}
private:
HMODULE lib_;
};
-static inline SecurityFunctionTable& SChannel() {
+static inline PSecurityFunctionTable SChannel() {
return Singleton<SChannelLib>()->funcs;
}
//-----------------------------------------------------------------------------
-static const int kRecvBufferSize = 0x10000;
+// Size of recv_buffer_
+//
+// Ciphertext is decrypted one SSL record at a time, so recv_buffer_ needs to
+// have room for a full SSL record, with the header and trailer. Here is the
+// breakdown of the size:
+// 5: SSL record header
+// 16K: SSL record maximum size
+// 64: >= SSL record trailer (16 or 20 have been observed)
+static const int kRecvBufferSize = (5 + 16*1024 + 64);
SSLClientSocket::SSLClientSocket(ClientSocket* transport_socket,
const std::string& hostname)
@@ -84,9 +90,14 @@ SSLClientSocket::SSLClientSocket(ClientSocket* transport_socket,
user_buf_(NULL),
user_buf_len_(0),
next_state_(STATE_NONE),
+ payload_send_buffer_len_(0),
bytes_sent_(0),
+ decrypted_ptr_(NULL),
+ bytes_decrypted_(0),
+ received_ptr_(NULL),
bytes_received_(0),
- completed_handshake_(false) {
+ completed_handshake_(false),
+ ignore_ok_result_(false) {
memset(&stream_sizes_, 0, sizeof(stream_sizes_));
memset(&send_buffer_, 0, sizeof(send_buffer_));
memset(&creds_, 0, sizeof(creds_));
@@ -115,20 +126,25 @@ int SSLClientSocket::ReconnectIgnoringLastError(CompletionCallback* callback) {
}
void SSLClientSocket::Disconnect() {
+ // TODO(wtc): Send SSL close_notify alert.
+ completed_handshake_ = false;
transport_->Disconnect();
if (send_buffer_.pvBuffer) {
- SChannel().FreeContextBuffer(send_buffer_.pvBuffer);
+ SChannel()->FreeContextBuffer(send_buffer_.pvBuffer);
memset(&send_buffer_, 0, sizeof(send_buffer_));
}
if (creds_.dwLower || creds_.dwUpper) {
- SChannel().FreeCredentialsHandle(&creds_);
+ SChannel()->FreeCredentialsHandle(&creds_);
memset(&creds_, 0, sizeof(creds_));
}
if (ctxt_.dwLower || ctxt_.dwUpper) {
- SChannel().DeleteSecurityContext(&ctxt_);
+ SChannel()->DeleteSecurityContext(&ctxt_);
memset(&ctxt_, 0, sizeof(ctxt_));
}
+ // TODO(wtc): reset more members?
+ bytes_decrypted_ = 0;
+ bytes_received_ = 0;
}
bool SSLClientSocket::IsConnected() const {
@@ -141,10 +157,32 @@ int SSLClientSocket::Read(char* buf, int buf_len,
DCHECK(next_state_ == STATE_NONE);
DCHECK(!user_callback_);
+ // If we have surplus decrypted plaintext, satisfy the Read with it without
+ // reading more ciphertext from the transport socket.
+ if (bytes_decrypted_ != 0) {
+ int len = std::min(buf_len, bytes_decrypted_);
+ memcpy(buf, decrypted_ptr_, len);
+ decrypted_ptr_ += len;
+ bytes_decrypted_ -= len;
+ if (bytes_decrypted_ == 0) {
+ decrypted_ptr_ = NULL;
+ if (bytes_received_ != 0) {
+ memmove(recv_buffer_.get(), received_ptr_, bytes_received_);
+ received_ptr_ = recv_buffer_.get();
+ }
+ }
+ return len;
+ }
+
user_buf_ = buf;
user_buf_len_ = buf_len;
- next_state_ = STATE_PAYLOAD_READ;
+ if (bytes_received_ == 0) {
+ next_state_ = STATE_PAYLOAD_READ;
+ } else {
+ next_state_ = STATE_PAYLOAD_READ_COMPLETE;
+ ignore_ok_result_ = true; // OK doesn't mean EOF.
+ }
int rv = DoLoop(OK);
if (rv == ERR_IO_PENDING)
user_callback_ = callback;
@@ -160,18 +198,41 @@ int SSLClientSocket::Write(const char* buf, int buf_len,
user_buf_ = const_cast<char*>(buf);
user_buf_len_ = buf_len;
- next_state_ = STATE_PAYLOAD_WRITE;
+ next_state_ = STATE_PAYLOAD_ENCRYPT;
int rv = DoLoop(OK);
if (rv == ERR_IO_PENDING)
user_callback_ = callback;
return rv;
}
+void SSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
+ SECURITY_STATUS status;
+ PCCERT_CONTEXT server_cert = NULL;
+ status = SChannel()->QueryContextAttributes(&ctxt_,
+ SECPKG_ATTR_REMOTE_CERT_CONTEXT,
+ &server_cert);
+ if (status == SEC_E_OK) {
+ DCHECK(server_cert);
+ ssl_info->cert = X509Certificate::CreateFromHandle(server_cert);
+ }
+ SecPkgContext_ConnectionInfo connection_info;
+ status = SChannel()->QueryContextAttributes(&ctxt_,
+ SECPKG_ATTR_CONNECTION_INFO,
+ &connection_info);
+ if (status == SEC_E_OK) {
+ // TODO(wtc): compute the overall security strength, taking into account
+ // dwExchStrength and dwHashStrength. dwExchStrength needs to be
+ // normalized.
+ ssl_info->security_bits = connection_info.dwCipherStrength;
+ }
+ ssl_info->cert_status = 0;
+}
+
void SSLClientSocket::DoCallback(int rv) {
DCHECK(rv != ERR_IO_PENDING);
DCHECK(user_callback_);
- // since Run may result in Read being called, clear callback_ up front.
+ // since Run may result in Read being called, clear user_callback_ up front.
CompletionCallback* c = user_callback_;
user_callback_ = NULL;
c->Run(rv);
@@ -214,6 +275,9 @@ int SSLClientSocket::DoLoop(int last_io_result) {
case STATE_PAYLOAD_READ_COMPLETE:
rv = DoPayloadReadComplete(rv);
break;
+ case STATE_PAYLOAD_ENCRYPT:
+ rv = DoPayloadEncrypt();
+ break;
case STATE_PAYLOAD_WRITE:
rv = DoPayloadWrite();
break;
@@ -242,25 +306,51 @@ int SSLClientSocket::DoConnectComplete(int result) {
SCHANNEL_CRED schannel_cred = {0};
schannel_cred.dwVersion = SCHANNEL_CRED_VERSION;
+
+ // TODO(wtc): This should be configurable. Hardcoded to do SSL 3.0 and
+ // TLS 1.0 for now. The default (0) means Schannel selects the protocol.
+ // The global system registry settings take precedence over this value.
+ schannel_cred.grbitEnabledProtocols = SP_PROT_SSL3TLS1;
+
+ // The default session lifetime is 36000000 milliseconds (ten hours). Set
+ // schannel_cred.dwSessionLifespan to change the number of milliseconds that
+ // Schannel keeps the session in its session cache.
+
+ // We can set the key exchange algorithms (RSA or DH) in
+ // schannel_cred.{cSupportedAlgs,palgSupportedAlgs}.
+
+ // TODO(wtc): We may need to use SCH_CRED_IGNORE_NO_REVOCATION_CHECK and
+ // SCH_CRED_IGNORE_REVOCATION_OFFLINE, but only after getting the
+ // CRYPT_E_NO_REVOCATION_CHECK and CRYPT_E_REVOCATION_OFFLINE errors.
+ //
+ // Look into undocumented or poorly documented flags:
+ // SCH_CRED_RESTRICTED_ROOTS
+ // SCH_CRED_REVOCATION_CHECK_CACHE_ONLY
+ // SCH_CRED_CACHE_ONLY_URL_RETRIEVAL
+ // SCH_CRED_MEMORY_STORE_CERT
+ // SCH_CRED_CACHE_ONLY_URL_RETRIEVAL_ON_CREATE
+ //
+ // SCH_CRED_NO_SERVERNAME_CHECK can be useful during testing.
schannel_cred.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS |
- SCH_CRED_NO_SYSTEM_MAPPER |
- SCH_CRED_REVOCATION_CHECK_CHAIN;
+ SCH_CRED_NO_SERVERNAME_CHECK | // Remove me!
+ SCH_CRED_AUTO_CRED_VALIDATION |
+ SCH_CRED_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT;
TimeStamp expiry;
SECURITY_STATUS status;
- status = SChannel().AcquireCredentialsHandle(
- NULL,
- UNISP_NAME,
+ status = SChannel()->AcquireCredentialsHandle(
+ NULL, // Not used
+ UNISP_NAME, // Microsoft Unified Security Protocol Provider
SECPKG_CRED_OUTBOUND,
- NULL,
+ NULL, // Not used
&schannel_cred,
- NULL,
- NULL,
+ NULL, // Not used
+ NULL, // Not used
&creds_,
- &expiry);
+ &expiry); // Optional
if (status != SEC_E_OK) {
DLOG(ERROR) << "AcquireCredentialsHandle failed: " << status;
- return ERR_FAILED;
+ return ERR_FAILED; // TODO(wtc): map SEC_E_xxx error codes.
}
SecBufferDesc buffer_desc;
@@ -280,22 +370,22 @@ int SSLClientSocket::DoConnectComplete(int result) {
buffer_desc.pBuffers = &send_buffer_;
buffer_desc.ulVersion = SECBUFFER_VERSION;
- status = SChannel().InitializeSecurityContext(
+ status = SChannel()->InitializeSecurityContext(
&creds_,
- NULL,
+ NULL, // NULL on the first call
const_cast<wchar_t*>(ASCIIToWide(hostname_).c_str()),
flags,
- 0,
- SECURITY_NATIVE_DREP,
- NULL,
- 0,
- &ctxt_,
+ 0, // Reserved
+ SECURITY_NATIVE_DREP, // TODO(wtc): MSDN says this should be set to 0.
+ NULL, // NULL on the first call
+ 0, // Reserved
+ &ctxt_, // Receives the new context handle
&buffer_desc,
&out_flags,
&expiry);
if (status != SEC_I_CONTINUE_NEEDED) {
DLOG(ERROR) << "InitializeSecurityContext failed: " << status;
- return ERR_FAILED;
+ return ERR_FAILED; // TODO(wtc): map SEC_E_xxx error codes.
}
next_state_ = STATE_HANDSHAKE_WRITE;
@@ -322,9 +412,11 @@ int SSLClientSocket::DoHandshakeRead() {
int SSLClientSocket::DoHandshakeReadComplete(int result) {
if (result < 0)
return result;
- if (result == 0)
+ if (result == 0 && !ignore_ok_result_)
return ERR_FAILED; // Incomplete response :(
+ ignore_ok_result_ = false;
+
bytes_received_ += result;
// Process the contents of recv_buffer_.
@@ -362,7 +454,7 @@ int SSLClientSocket::DoHandshakeReadComplete(int result) {
send_buffer_.BufferType = SECBUFFER_TOKEN;
send_buffer_.cbBuffer = 0;
- status = SChannel().InitializeSecurityContext(
+ status = SChannel()->InitializeSecurityContext(
&creds_,
&ctxt_,
NULL,
@@ -377,18 +469,25 @@ int SSLClientSocket::DoHandshakeReadComplete(int result) {
&expiry);
if (status == SEC_E_INCOMPLETE_MESSAGE) {
+ DCHECK(FAILED(status));
+ DCHECK(send_buffer_.cbBuffer == 0 ||
+ !(out_flags & ISC_RET_EXTENDED_ERROR));
next_state_ = STATE_HANDSHAKE_READ;
return OK;
}
- // OK, all of the received data was consumed.
- bytes_received_ = 0;
-
if (send_buffer_.cbBuffer != 0 &&
(status == SEC_E_OK ||
status == SEC_I_CONTINUE_NEEDED ||
FAILED(status) && (out_flags & ISC_RET_EXTENDED_ERROR))) {
+ // TODO(wtc): if status is SEC_E_OK, we should finish the handshake
+ // successfully after sending send_buffer_.
+ // If FAILED(status) is true, we should terminate the connection after
+ // sending send_buffer_.
+ DCHECK(status == SEC_I_CONTINUE_NEEDED); // We only handle this case
+ // correctly.
next_state_ = STATE_HANDSHAKE_WRITE;
+ bytes_received_ = 0;
return OK;
}
@@ -396,13 +495,28 @@ int SSLClientSocket::DoHandshakeReadComplete(int result) {
if (in_buffers[1].BufferType == SECBUFFER_EXTRA) {
// TODO(darin) need to save this data for later.
NOTREACHED() << "should not occur for HTTPS traffic";
+ return ERR_FAILED;
}
+ bytes_received_ = 0;
return DidCompleteHandshake();
}
if (FAILED(status))
- return ERR_FAILED;
+ return ERR_FAILED; // TODO(wtc): map error codes, in particular cert
+ // errors such as SEC_E_UNTRUSTED_ROOT.
+
+ DCHECK(status == SEC_I_CONTINUE_NEEDED);
+ if (in_buffers[1].BufferType == SECBUFFER_EXTRA) {
+ memmove(&recv_buffer_[0],
+ &recv_buffer_[0] + (bytes_received_ - in_buffers[1].cbBuffer),
+ in_buffers[1].cbBuffer);
+ bytes_received_ = in_buffers[1].cbBuffer;
+ next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
+ ignore_ok_result_ = true; // OK doesn't mean EOF.
+ return OK;
+ }
+ bytes_received_ = 0;
next_state_ = STATE_HANDSHAKE_READ;
return OK;
}
@@ -426,14 +540,18 @@ int SSLClientSocket::DoHandshakeWriteComplete(int result) {
DCHECK(result != 0);
- // TODO(darin): worry about overflow?
bytes_sent_ += result;
DCHECK(bytes_sent_ <= static_cast<int>(send_buffer_.cbBuffer));
- if (bytes_sent_ == static_cast<int>(send_buffer_.cbBuffer)) {
- SChannel().FreeContextBuffer(send_buffer_.pvBuffer);
+ if (bytes_sent_ >= static_cast<int>(send_buffer_.cbBuffer)) {
+ bool overflow = (bytes_sent_ > static_cast<int>(send_buffer_.cbBuffer));
+ SECURITY_STATUS status =
+ SChannel()->FreeContextBuffer(send_buffer_.pvBuffer);
+ DCHECK(status == SEC_E_OK);
memset(&send_buffer_, 0, sizeof(send_buffer_));
bytes_sent_ = 0;
+ if (overflow) // Bug!
+ return ERR_UNEXPECTED;
next_state_ = STATE_HANDSHAKE_READ;
} else {
// Send the remaining bytes.
@@ -446,46 +564,196 @@ int SSLClientSocket::DoHandshakeWriteComplete(int result) {
int SSLClientSocket::DoPayloadRead() {
next_state_ = STATE_PAYLOAD_READ_COMPLETE;
- return ERR_FAILED;
+ DCHECK(recv_buffer_.get());
+
+ char* buf = recv_buffer_.get() + bytes_received_;
+ int buf_len = kRecvBufferSize - bytes_received_;
+
+ if (buf_len <= 0) {
+ NOTREACHED() << "Receive buffer is too small!";
+ return ERR_FAILED;
+ }
+
+ return transport_->Read(buf, buf_len, &io_callback_);
}
int SSLClientSocket::DoPayloadReadComplete(int result) {
- return ERR_FAILED;
+ if (result < 0)
+ return result;
+ if (result == 0 && !ignore_ok_result_) {
+ // TODO(wtc): Unless we have received the close_notify alert, we need to
+ // return an error code indicating that the SSL connection ended
+ // uncleanly, a potential truncation attack.
+ if (bytes_received_ != 0)
+ return ERR_FAILED;
+ return OK;
+ }
+
+ ignore_ok_result_ = false;
+
+ bytes_received_ += result;
+
+ // Process the contents of recv_buffer_.
+ SecBuffer buffers[4];
+ buffers[0].pvBuffer = recv_buffer_.get();
+ buffers[0].cbBuffer = bytes_received_;
+ buffers[0].BufferType = SECBUFFER_DATA;
+
+ buffers[1].BufferType = SECBUFFER_EMPTY;
+ buffers[2].BufferType = SECBUFFER_EMPTY;
+ buffers[3].BufferType = SECBUFFER_EMPTY;
+
+ SecBufferDesc buffer_desc;
+ buffer_desc.cBuffers = 4;
+ buffer_desc.pBuffers = buffers;
+ buffer_desc.ulVersion = SECBUFFER_VERSION;
+
+ SECURITY_STATUS status;
+ status = SChannel()->DecryptMessage(&ctxt_, &buffer_desc, 0, NULL);
+
+ if (status == SEC_E_INCOMPLETE_MESSAGE) {
+ next_state_ = STATE_PAYLOAD_READ;
+ return OK;
+ }
+
+ if (status == SEC_I_CONTEXT_EXPIRED) {
+ // Received the close_notify alert.
+ bytes_received_ = 0;
+ return OK;
+ }
+
+ if (status != SEC_E_OK && status != SEC_I_RENEGOTIATE) {
+ DCHECK(status != SEC_E_MESSAGE_ALTERED);
+ return ERR_FAILED; // TODO(wtc): map error code
+ }
+
+ // The received ciphertext was decrypted in place in recv_buffer_. Remember
+ // the location and length of the decrypted plaintext and any unused
+ // ciphertext.
+ decrypted_ptr_ = NULL;
+ bytes_decrypted_ = 0;
+ received_ptr_ = NULL;
+ bytes_received_ = 0;
+ for (int i = 1; i < 4; i++) {
+ if (!decrypted_ptr_ && buffers[i].BufferType == SECBUFFER_DATA) {
+ decrypted_ptr_ = static_cast<char*>(buffers[i].pvBuffer);
+ bytes_decrypted_ = buffers[i].cbBuffer;
+ }
+ if (!received_ptr_ && buffers[i].BufferType == SECBUFFER_EXTRA) {
+ received_ptr_ = static_cast<char*>(buffers[i].pvBuffer);
+ bytes_received_ = buffers[i].cbBuffer;
+ }
+ }
+
+ int len = 0;
+ if (bytes_decrypted_ != 0) {
+ len = std::min(user_buf_len_, bytes_decrypted_);
+ memcpy(user_buf_, decrypted_ptr_, len);
+ decrypted_ptr_ += len;
+ bytes_decrypted_ -= len;
+ }
+ if (bytes_decrypted_ == 0) {
+ decrypted_ptr_ = NULL;
+ if (bytes_received_ != 0) {
+ memmove(recv_buffer_.get(), received_ptr_, bytes_received_);
+ received_ptr_ = recv_buffer_.get();
+ }
+ }
+ // TODO(wtc): need to handle SEC_I_RENEGOTIATE.
+ DCHECK(status == SEC_E_OK);
+ return len;
}
-int SSLClientSocket::DoPayloadWrite() {
+int SSLClientSocket::DoPayloadEncrypt() {
DCHECK(user_buf_);
DCHECK(user_buf_len_ > 0);
- next_state_ = STATE_PAYLOAD_WRITE_COMPLETE;
-
- size_t message_len = std::min(
+ ULONG message_len = std::min(
stream_sizes_.cbMaximumMessage, static_cast<ULONG>(user_buf_len_));
- size_t alloc_len =
+ ULONG alloc_len =
message_len + stream_sizes_.cbHeader + stream_sizes_.cbTrailer;
+ user_buf_len_ = message_len;
+
+ payload_send_buffer_.reset(new char[alloc_len]);
+ memcpy(&payload_send_buffer_[stream_sizes_.cbHeader],
+ user_buf_, message_len);
- /*
SecBuffer buffers[4];
- buffers[0].
+ buffers[0].pvBuffer = payload_send_buffer_.get();
+ buffers[0].cbBuffer = stream_sizes_.cbHeader;
+ buffers[0].BufferType = SECBUFFER_STREAM_HEADER;
+
+ buffers[1].pvBuffer = &payload_send_buffer_[stream_sizes_.cbHeader];
+ buffers[1].cbBuffer = message_len;
+ buffers[1].BufferType = SECBUFFER_DATA;
+
+ buffers[2].pvBuffer = &payload_send_buffer_[stream_sizes_.cbHeader +
+ message_len];
+ buffers[2].cbBuffer = stream_sizes_.cbTrailer;
+ buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
+
+ buffers[3].BufferType = SECBUFFER_EMPTY;
SecBufferDesc buffer_desc;
buffer_desc.cBuffers = 4;
- buffer_desc.pBuffers = //XXX
+ buffer_desc.pBuffers = buffers;
buffer_desc.ulVersion = SECBUFFER_VERSION;
- SECURITY_STATUS status = SChannel().EncryptMessage(
+ SECURITY_STATUS status = SChannel()->EncryptMessage(
&ctxt_, 0, &buffer_desc, 0);
- */
- return ERR_FAILED;
+ if (FAILED(status))
+ return ERR_FAILED;
+
+ payload_send_buffer_len_ = buffers[0].cbBuffer +
+ buffers[1].cbBuffer +
+ buffers[2].cbBuffer;
+ DCHECK(bytes_sent_ == 0);
+
+ next_state_ = STATE_PAYLOAD_WRITE;
+ return OK;
+}
+
+int SSLClientSocket::DoPayloadWrite() {
+ next_state_ = STATE_PAYLOAD_WRITE_COMPLETE;
+
+ // We should have something to send.
+ DCHECK(payload_send_buffer_.get());
+ DCHECK(payload_send_buffer_len_ > 0);
+
+ const char* buf = payload_send_buffer_.get() + bytes_sent_;
+ int buf_len = payload_send_buffer_len_ - bytes_sent_;
+
+ return transport_->Write(buf, buf_len, &io_callback_);
}
int SSLClientSocket::DoPayloadWriteComplete(int result) {
- return ERR_FAILED;
+ if (result < 0)
+ return result;
+
+ DCHECK(result != 0);
+
+ bytes_sent_ += result;
+ DCHECK(bytes_sent_ <= payload_send_buffer_len_);
+
+ if (bytes_sent_ >= payload_send_buffer_len_) {
+ bool overflow = (bytes_sent_ > payload_send_buffer_len_);
+ payload_send_buffer_.reset();
+ payload_send_buffer_len_ = 0;
+ bytes_sent_ = 0;
+ if (overflow) // Bug!
+ return ERR_UNEXPECTED;
+ // Done
+ return user_buf_len_;
+ }
+
+ // Send the remaining bytes.
+ next_state_ = STATE_PAYLOAD_WRITE;
+ return OK;
}
int SSLClientSocket::DidCompleteHandshake() {
- SECURITY_STATUS status = SChannel().QueryContextAttributes(
+ SECURITY_STATUS status = SChannel()->QueryContextAttributes(
&ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_);
if (status != SEC_E_OK) {
DLOG(ERROR) << "QueryContextAttributes failed: " << status;
diff --git a/net/base/ssl_client_socket.h b/net/base/ssl_client_socket.h
index 3e1779f..8711e2a 100644
--- a/net/base/ssl_client_socket.h
+++ b/net/base/ssl_client_socket.h
@@ -43,6 +43,8 @@
namespace net {
+class SSLInfo;
+
// A client socket that uses SSL as the transport layer.
//
// NOTE: The SSL handshake occurs within the Connect method after a TCP
@@ -68,6 +70,9 @@ class SSLClientSocket : public ClientSocket {
virtual int Read(char* buf, int buf_len, CompletionCallback* callback);
virtual int Write(const char* buf, int buf_len, CompletionCallback* callback);
+ // Gets the SSL connection information of the socket.
+ void GetSSLInfo(SSLInfo* ssl_info);
+
private:
void DoCallback(int result);
void OnIOComplete(int result);
@@ -81,6 +86,7 @@ class SSLClientSocket : public ClientSocket {
int DoHandshakeWriteComplete(int result);
int DoPayloadRead();
int DoPayloadReadComplete(int result);
+ int DoPayloadEncrypt();
int DoPayloadWrite();
int DoPayloadWriteComplete(int result);
@@ -104,6 +110,7 @@ class SSLClientSocket : public ClientSocket {
STATE_HANDSHAKE_READ_COMPLETE,
STATE_HANDSHAKE_WRITE,
STATE_HANDSHAKE_WRITE_COMPLETE,
+ STATE_PAYLOAD_ENCRYPT,
STATE_PAYLOAD_WRITE,
STATE_PAYLOAD_WRITE_COMPLETE,
STATE_PAYLOAD_READ,
@@ -116,12 +123,35 @@ class SSLClientSocket : public ClientSocket {
CredHandle creds_;
CtxtHandle ctxt_;
SecBuffer send_buffer_;
+ scoped_array<char> payload_send_buffer_;
+ int payload_send_buffer_len_;
int bytes_sent_;
+ // recv_buffer_ holds the received ciphertext. Since Schannel decrypts
+ // data in place, sometimes recv_buffer_ may contain decrypted plaintext and
+ // any undecrypted ciphertext. (Ciphertext is decrypted one full SSL record
+ // at a time.)
+ //
+ // If bytes_decrypted_ is 0, the received ciphertext is at the beginning of
+ // recv_buffer_, ready to be passed to DecryptMessage.
scoped_array<char> recv_buffer_;
- int bytes_received_;
+ char* decrypted_ptr_; // Points to the decrypted plaintext in recv_buffer_
+ int bytes_decrypted_; // The number of bytes of decrypted plaintext.
+ char* received_ptr_; // Points to the received ciphertext in recv_buffer_
+ int bytes_received_; // The number of bytes of received ciphertext.
bool completed_handshake_;
+
+ // Only used in the STATE_HANDSHAKE_READ_COMPLETE and
+ // STATE_PAYLOAD_READ_COMPLETE states. True if a 'result' argument of OK
+ // should be ignored, to prevent it from being interpreted as EOF.
+ //
+ // The reason we need this flag is that OK means not only "0 bytes of data
+ // were read" but also EOF. We set ignore_ok_result_ to true when we need
+ // to continue processing previously read data without reading more data.
+ // We have to pass a 'result' of OK to the DoLoop method, and don't want it
+ // to be interpreted as EOF.
+ bool ignore_ok_result_;
};
} // namespace net
diff --git a/net/base/ssl_client_socket_unittest.cc b/net/base/ssl_client_socket_unittest.cc
index ebcce65..7bf8ea3 100644
--- a/net/base/ssl_client_socket_unittest.cc
+++ b/net/base/ssl_client_socket_unittest.cc
@@ -51,7 +51,7 @@ TEST_F(SSLClientSocketTest, Connect) {
net::HostResolver resolver;
TestCompletionCallback callback;
- std::string hostname = "www.verisign.com";
+ std::string hostname = "bugs.webkit.org";
int rv = resolver.Resolve(hostname, 443, &addr, NULL);
EXPECT_EQ(net::OK, rv);
@@ -73,13 +73,12 @@ TEST_F(SSLClientSocketTest, Connect) {
EXPECT_FALSE(sock.IsConnected());
}
-#if 0
TEST_F(SSLClientSocketTest, Read) {
net::AddressList addr;
net::HostResolver resolver;
TestCompletionCallback callback;
- std::string hostname = "www.google.com";
+ std::string hostname = "bugs.webkit.org";
int rv = resolver.Resolve(hostname, 443, &addr, &callback);
EXPECT_EQ(rv, net::ERR_IO_PENDING);
@@ -124,10 +123,11 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) {
net::HostResolver resolver;
TestCompletionCallback callback;
- int rv = resolver.Resolve("www.google.com", 80, &addr, NULL);
+ std::string hostname = "bugs.webkit.org";
+ int rv = resolver.Resolve(hostname, 443, &addr, NULL);
EXPECT_EQ(rv, net::OK);
- net::TCPClientSocket sock(addr);
+ net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname);
rv = sock.Connect(&callback);
if (rv != net::OK) {
@@ -165,10 +165,11 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) {
net::HostResolver resolver;
TestCompletionCallback callback;
- int rv = resolver.Resolve("www.google.com", 80, &addr, NULL);
+ std::string hostname = "bugs.webkit.org";
+ int rv = resolver.Resolve(hostname, 443, &addr, NULL);
EXPECT_EQ(rv, net::OK);
- net::TCPClientSocket sock(addr);
+ net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname);
rv = sock.Connect(&callback);
if (rv != net::OK) {
@@ -197,4 +198,3 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) {
EXPECT_NE(rv, 0);
}
-#endif
diff --git a/net/http/http_network_transaction.cc b/net/http/http_network_transaction.cc
index 56ec449..69af354 100644
--- a/net/http/http_network_transaction.cc
+++ b/net/http/http_network_transaction.cc
@@ -33,6 +33,7 @@
#include "net/base/client_socket_factory.h"
#include "net/base/host_resolver.h"
#include "net/base/load_flags.h"
+#include "net/base/ssl_client_socket.h"
#include "net/base/upload_data_stream.h"
#include "net/http/http_chunked_decoder.h"
#include "net/http/http_network_session.h"
@@ -123,7 +124,7 @@ int HttpNetworkTransaction::Read(char* buf, int buf_len,
}
const HttpResponseInfo* HttpNetworkTransaction::GetResponseInfo() const {
- return response_.headers ? &response_ : NULL;
+ return (response_.headers || response_.ssl_info.cert) ? &response_ : NULL;
}
LoadState HttpNetworkTransaction::GetLoadState() const {
@@ -663,6 +664,12 @@ int HttpNetworkTransaction::DidReadResponseHeaders() {
}
}
+ if (using_ssl_) {
+ SSLClientSocket* ssl_socket =
+ reinterpret_cast<SSLClientSocket*>(connection_.socket());
+ ssl_socket->GetSSLInfo(&response_.ssl_info);
+ }
+
return OK;
}