summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorwtc@google.com <wtc@google.com@0039d316-1c4b-4281-b951-d872f2087c98>2008-10-15 00:20:11 +0000
committerwtc@google.com <wtc@google.com@0039d316-1c4b-4281-b951-d872f2087c98>2008-10-15 00:20:11 +0000
commitaaead5019627818c93693fdb6ec04d47b47c17f2 (patch)
tree6ff15880c597bb59a8c3def51d13ce492a4bb405
parent1ad083f293cd321fa7d7c8f14e71816571c6c54f (diff)
downloadchromium_src-aaead5019627818c93693fdb6ec04d47b47c17f2.zip
chromium_src-aaead5019627818c93693fdb6ec04d47b47c17f2.tar.gz
chromium_src-aaead5019627818c93693fdb6ec04d47b47c17f2.tar.bz2
Turn SSLClientSocket into an interface.
The original ssl_client_socket.{h,cc} are renamed ssl_client_socket_win.{h,cc}. The new ssl_client_socket.h defines the SSLClientSocket interface, which simply extends the ClientSocket interface with a new GetSSLInfo method. ClientSocketFactory::CreateSSLClientSocket returns SSLClientSocket* instead of ClientSocket*. Replace the SSL protocol version mask parameter to the constructor and factory method by a SSLConfig parameter. R=darin Review URL: http://codereview.chromium.org/7304 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@3387 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r--net/SConscript2
-rw-r--r--net/base/client_socket_factory.cc9
-rw-r--r--net/base/client_socket_factory.h9
-rw-r--r--net/base/ssl_client_socket.h125
-rw-r--r--net/base/ssl_client_socket_unittest.cc60
-rw-r--r--net/base/ssl_client_socket_win.cc (renamed from net/base/ssl_client_socket.cc)119
-rw-r--r--net/base/ssl_client_socket_win.h139
-rw-r--r--net/build/net.vcproj8
-rw-r--r--net/http/http_network_transaction.cc27
-rw-r--r--net/http/http_network_transaction.h3
-rw-r--r--net/http/http_network_transaction_unittest.cc12
-rw-r--r--net/net.xcodeproj/project.pbxproj2
12 files changed, 266 insertions, 249 deletions
diff --git a/net/SConscript b/net/SConscript
index b3dd629..37d6c18 100644
--- a/net/SConscript
+++ b/net/SConscript
@@ -94,7 +94,7 @@ if env['PLATFORM'] == 'win32':
'base/directory_lister.cc',
'base/dns_resolution_observer.cc',
'base/listen_socket.cc',
- 'base/ssl_client_socket.cc',
+ 'base/ssl_client_socket_win.cc',
'base/ssl_config_service.cc',
'base/tcp_client_socket.cc',
'base/telnet_server.cc',
diff --git a/net/base/client_socket_factory.cc b/net/base/client_socket_factory.cc
index 43df0e6..10f24df 100644
--- a/net/base/client_socket_factory.cc
+++ b/net/base/client_socket_factory.cc
@@ -7,7 +7,7 @@
#include "base/singleton.h"
#include "build/build_config.h"
#if defined(OS_WIN)
-#include "net/base/ssl_client_socket.h"
+#include "net/base/ssl_client_socket_win.h"
#endif
#include "net/base/tcp_client_socket.h"
@@ -20,13 +20,12 @@ class DefaultClientSocketFactory : public ClientSocketFactory {
return new TCPClientSocket(addresses);
}
- virtual ClientSocket* CreateSSLClientSocket(
+ virtual SSLClientSocket* CreateSSLClientSocket(
ClientSocket* transport_socket,
const std::string& hostname,
- int protocol_version_mask) {
+ const SSLConfig& ssl_config) {
#if defined(OS_WIN)
- return new SSLClientSocket(transport_socket, hostname,
- protocol_version_mask);
+ return new SSLClientSocketWin(transport_socket, hostname, ssl_config);
#else
// TODO(pinkerton): turn on when we port SSL socket from win32
NOTIMPLEMENTED();
diff --git a/net/base/client_socket_factory.h b/net/base/client_socket_factory.h
index ed371a3..3a3bb84 100644
--- a/net/base/client_socket_factory.h
+++ b/net/base/client_socket_factory.h
@@ -11,6 +11,8 @@ namespace net {
class AddressList;
class ClientSocket;
+class SSLClientSocket;
+struct SSLConfig;
// An interface used to instantiate ClientSocket objects. Used to facilitate
// testing code with mock socket implementations.
@@ -21,13 +23,10 @@ class ClientSocketFactory {
virtual ClientSocket* CreateTCPClientSocket(
const AddressList& addresses) = 0;
- // protocol_version_mask is a bitmask that specifies which versions of the
- // SSL protocol (SSL 2.0, SSL 3.0, and TLS 1.0) should be enabled. The bit
- // flags are defined in net/base/ssl_client_socket.h.
- virtual ClientSocket* CreateSSLClientSocket(
+ virtual SSLClientSocket* CreateSSLClientSocket(
ClientSocket* transport_socket,
const std::string& hostname,
- int protocol_version_mask) = 0;
+ const SSLConfig& ssl_config) = 0;
// Returns the default ClientSocketFactory.
static ClientSocketFactory* GetDefaultFactory();
diff --git a/net/base/ssl_client_socket.h b/net/base/ssl_client_socket.h
index 100e514..dca5ef3 100644
--- a/net/base/ssl_client_socket.h
+++ b/net/base/ssl_client_socket.h
@@ -5,17 +5,7 @@
#ifndef NET_BASE_SSL_CLIENT_SOCKET_H_
#define NET_BASE_SSL_CLIENT_SOCKET_H_
-#define SECURITY_WIN32 // Needs to be defined before including security.h
-
-#include <windows.h>
-#include <wincrypt.h>
-#include <security.h>
-
-#include <string>
-
-#include "base/scoped_ptr.h"
#include "net/base/client_socket.h"
-#include "net/base/completion_callback.h"
namespace net {
@@ -30,121 +20,8 @@ class SSLInfo;
//
class SSLClientSocket : public ClientSocket {
public:
- enum {
- SSL2 = 1 << 0,
- SSL3 = 1 << 1,
- TLS1 = 1 << 2
- };
-
- // Takes ownership of the transport_socket, which may already be connected.
- // The given hostname will be compared with the name(s) in the server's
- // certificate during the SSL handshake. protocol_version_mask is a bitwise
- // OR of SSL2, SSL3, and TLS1 that specifies which versions of the SSL
- // protocol should be enabled.
- SSLClientSocket(ClientSocket* transport_socket,
- const std::string& hostname,
- int protocol_version_mask);
- ~SSLClientSocket();
-
- // ClientSocket methods:
- virtual int Connect(CompletionCallback* callback);
- virtual int ReconnectIgnoringLastError(CompletionCallback* callback);
- virtual void Disconnect();
- virtual bool IsConnected() const;
-
- // Socket methods:
- virtual int Read(char* buf, int buf_len, CompletionCallback* callback);
- virtual int Write(const char* buf, int buf_len, CompletionCallback* callback);
-
// Gets the SSL connection information of the socket.
- void GetSSLInfo(SSLInfo* ssl_info);
-
- private:
- void DoCallback(int result);
- void OnIOComplete(int result);
-
- int DoLoop(int last_io_result);
- int DoConnect();
- int DoConnectComplete(int result);
- int DoHandshakeRead();
- int DoHandshakeReadComplete(int result);
- int DoHandshakeWrite();
- int DoHandshakeWriteComplete(int result);
- int DoPayloadRead();
- int DoPayloadReadComplete(int result);
- int DoPayloadEncrypt();
- int DoPayloadWrite();
- int DoPayloadWriteComplete(int result);
-
- int DidCompleteHandshake();
- int VerifyServerCert();
-
- CompletionCallbackImpl<SSLClientSocket> io_callback_;
- scoped_ptr<ClientSocket> transport_;
- std::string hostname_;
- int protocol_version_mask_;
-
- CompletionCallback* user_callback_;
-
- // Used by both Read and Write functions.
- char* user_buf_;
- int user_buf_len_;
-
- enum State {
- STATE_NONE,
- STATE_CONNECT,
- STATE_CONNECT_COMPLETE,
- STATE_HANDSHAKE_READ,
- STATE_HANDSHAKE_READ_COMPLETE,
- STATE_HANDSHAKE_WRITE,
- STATE_HANDSHAKE_WRITE_COMPLETE,
- STATE_PAYLOAD_ENCRYPT,
- STATE_PAYLOAD_WRITE,
- STATE_PAYLOAD_WRITE_COMPLETE,
- STATE_PAYLOAD_READ,
- STATE_PAYLOAD_READ_COMPLETE,
- };
- State next_state_;
-
- SecPkgContext_StreamSizes stream_sizes_;
- PCCERT_CONTEXT server_cert_;
- int server_cert_status_;
-
- CredHandle creds_;
- CtxtHandle ctxt_;
- SecBuffer send_buffer_;
- scoped_array<char> payload_send_buffer_;
- int payload_send_buffer_len_;
- int bytes_sent_;
-
- // recv_buffer_ holds the received ciphertext. Since Schannel decrypts
- // data in place, sometimes recv_buffer_ may contain decrypted plaintext and
- // any undecrypted ciphertext. (Ciphertext is decrypted one full SSL record
- // at a time.)
- //
- // If bytes_decrypted_ is 0, the received ciphertext is at the beginning of
- // recv_buffer_, ready to be passed to DecryptMessage.
- scoped_array<char> recv_buffer_;
- char* decrypted_ptr_; // Points to the decrypted plaintext in recv_buffer_
- int bytes_decrypted_; // The number of bytes of decrypted plaintext.
- char* received_ptr_; // Points to the received ciphertext in recv_buffer_
- int bytes_received_; // The number of bytes of received ciphertext.
-
- bool completed_handshake_;
-
- // Only used in the STATE_HANDSHAKE_READ_COMPLETE and
- // STATE_PAYLOAD_READ_COMPLETE states. True if a 'result' argument of OK
- // should be ignored, to prevent it from being interpreted as EOF.
- //
- // The reason we need this flag is that OK means not only "0 bytes of data
- // were read" but also EOF. We set ignore_ok_result_ to true when we need
- // to continue processing previously read data without reading more data.
- // We have to pass a 'result' of OK to the DoLoop method, and don't want it
- // to be interpreted as EOF.
- bool ignore_ok_result_;
-
- // True if the user has no client certificate.
- bool no_client_cert_;
+ virtual void GetSSLInfo(SSLInfo* ssl_info) = 0;
};
} // namespace net
diff --git a/net/base/ssl_client_socket_unittest.cc b/net/base/ssl_client_socket_unittest.cc
index 2aba7ab..d1f1f82 100644
--- a/net/base/ssl_client_socket_unittest.cc
+++ b/net/base/ssl_client_socket_unittest.cc
@@ -3,9 +3,11 @@
// found in the LICENSE file.
#include "net/base/address_list.h"
+#include "net/base/client_socket_factory.h"
#include "net/base/host_resolver.h"
#include "net/base/net_errors.h"
#include "net/base/ssl_client_socket.h"
+#include "net/base/ssl_config_service.h"
#include "net/base/tcp_client_socket.h"
#include "net/base/test_completion_callback.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -14,10 +16,16 @@
namespace {
-const unsigned int kDefaultSSLVersionMask = net::SSLClientSocket::SSL3 |
- net::SSLClientSocket::TLS1;
+const net::SSLConfig kDefaultSSLConfig;
class SSLClientSocketTest : public testing::Test {
+ public:
+ SSLClientSocketTest()
+ : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) {
+ }
+
+ protected:
+ net::ClientSocketFactory* socket_factory_;
};
} // namespace
@@ -34,12 +42,13 @@ TEST_F(SSLClientSocketTest, DISABLED_Connect) {
int rv = resolver.Resolve(hostname, 443, &addr, NULL);
EXPECT_EQ(net::OK, rv);
- net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname,
- kDefaultSSLVersionMask);
+ scoped_ptr<net::SSLClientSocket> sock(
+ socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr),
+ hostname, kDefaultSSLConfig));
- EXPECT_FALSE(sock.IsConnected());
+ EXPECT_FALSE(sock->IsConnected());
- rv = sock.Connect(&callback);
+ rv = sock->Connect(&callback);
if (rv != net::OK) {
ASSERT_EQ(net::ERR_IO_PENDING, rv);
@@ -47,10 +56,10 @@ TEST_F(SSLClientSocketTest, DISABLED_Connect) {
EXPECT_EQ(net::OK, rv);
}
- EXPECT_TRUE(sock.IsConnected());
+ EXPECT_TRUE(sock->IsConnected());
- sock.Disconnect();
- EXPECT_FALSE(sock.IsConnected());
+ sock->Disconnect();
+ EXPECT_FALSE(sock->IsConnected());
}
// bug 1354783
@@ -66,10 +75,11 @@ TEST_F(SSLClientSocketTest, DISABLED_Read) {
rv = callback.WaitForResult();
EXPECT_EQ(rv, net::OK);
- net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname,
- kDefaultSSLVersionMask);
+ scoped_ptr<net::SSLClientSocket> sock(
+ socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr),
+ hostname, kDefaultSSLConfig));
- rv = sock.Connect(&callback);
+ rv = sock->Connect(&callback);
if (rv != net::OK) {
ASSERT_EQ(rv, net::ERR_IO_PENDING);
@@ -78,7 +88,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read) {
}
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
- rv = sock.Write(request_text, arraysize(request_text) - 1, &callback);
+ rv = sock->Write(request_text, arraysize(request_text) - 1, &callback);
EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
if (rv == net::ERR_IO_PENDING) {
@@ -88,7 +98,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read) {
char buf[4096];
for (;;) {
- rv = sock.Read(buf, sizeof(buf), &callback);
+ rv = sock->Read(buf, sizeof(buf), &callback);
EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
if (rv == net::ERR_IO_PENDING)
@@ -110,10 +120,11 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_SmallChunks) {
int rv = resolver.Resolve(hostname, 443, &addr, NULL);
EXPECT_EQ(rv, net::OK);
- net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname,
- kDefaultSSLVersionMask);
+ scoped_ptr<net::SSLClientSocket> sock(
+ socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr),
+ hostname, kDefaultSSLConfig));
- rv = sock.Connect(&callback);
+ rv = sock->Connect(&callback);
if (rv != net::OK) {
ASSERT_EQ(rv, net::ERR_IO_PENDING);
@@ -122,7 +133,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_SmallChunks) {
}
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
- rv = sock.Write(request_text, arraysize(request_text) - 1, &callback);
+ rv = sock->Write(request_text, arraysize(request_text) - 1, &callback);
EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
if (rv == net::ERR_IO_PENDING) {
@@ -132,7 +143,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_SmallChunks) {
char buf[1];
for (;;) {
- rv = sock.Read(buf, sizeof(buf), &callback);
+ rv = sock->Read(buf, sizeof(buf), &callback);
EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
if (rv == net::ERR_IO_PENDING)
@@ -154,10 +165,11 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_Interrupted) {
int rv = resolver.Resolve(hostname, 443, &addr, NULL);
EXPECT_EQ(rv, net::OK);
- net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname,
- kDefaultSSLVersionMask);
+ scoped_ptr<net::SSLClientSocket> sock(
+ socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr),
+ hostname, kDefaultSSLConfig));
- rv = sock.Connect(&callback);
+ rv = sock->Connect(&callback);
if (rv != net::OK) {
ASSERT_EQ(rv, net::ERR_IO_PENDING);
@@ -166,7 +178,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_Interrupted) {
}
const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
- rv = sock.Write(request_text, arraysize(request_text) - 1, &callback);
+ rv = sock->Write(request_text, arraysize(request_text) - 1, &callback);
EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
if (rv == net::ERR_IO_PENDING) {
@@ -176,7 +188,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_Interrupted) {
// Do a partial read and then exit. This test should not crash!
char buf[512];
- rv = sock.Read(buf, sizeof(buf), &callback);
+ rv = sock->Read(buf, sizeof(buf), &callback);
EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
if (rv == net::ERR_IO_PENDING)
diff --git a/net/base/ssl_client_socket.cc b/net/base/ssl_client_socket_win.cc
index d155009..1eeb090 100644
--- a/net/base/ssl_client_socket.cc
+++ b/net/base/ssl_client_socket_win.cc
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#include "net/base/ssl_client_socket.h"
+#include "net/base/ssl_client_socket_win.h"
#include <schnlsp.h>
@@ -90,14 +90,14 @@ static int MapNetErrorToCertStatus(int error) {
// 64: >= SSL record trailer (16 or 20 have been observed)
static const int kRecvBufferSize = (5 + 16*1024 + 64);
-SSLClientSocket::SSLClientSocket(ClientSocket* transport_socket,
- const std::string& hostname,
- int protocol_version_mask)
+SSLClientSocketWin::SSLClientSocketWin(ClientSocket* transport_socket,
+ const std::string& hostname,
+ const SSLConfig& ssl_config)
#pragma warning(suppress: 4355)
- : io_callback_(this, &SSLClientSocket::OnIOComplete),
+ : io_callback_(this, &SSLClientSocketWin::OnIOComplete),
transport_(transport_socket),
hostname_(hostname),
- protocol_version_mask_(protocol_version_mask),
+ ssl_config_(ssl_config),
user_callback_(NULL),
user_buf_(NULL),
user_buf_len_(0),
@@ -119,11 +119,36 @@ SSLClientSocket::SSLClientSocket(ClientSocket* transport_socket,
memset(&ctxt_, 0, sizeof(ctxt_));
}
-SSLClientSocket::~SSLClientSocket() {
+SSLClientSocketWin::~SSLClientSocketWin() {
Disconnect();
}
-int SSLClientSocket::Connect(CompletionCallback* callback) {
+void SSLClientSocketWin::GetSSLInfo(SSLInfo* ssl_info) {
+ SECURITY_STATUS status = SEC_E_OK;
+ if (server_cert_ == NULL) {
+ status = QueryContextAttributes(&ctxt_,
+ SECPKG_ATTR_REMOTE_CERT_CONTEXT,
+ &server_cert_);
+ }
+ if (status == SEC_E_OK) {
+ DCHECK(server_cert_);
+ PCCERT_CONTEXT dup_cert = CertDuplicateCertificateContext(server_cert_);
+ ssl_info->cert = X509Certificate::CreateFromHandle(dup_cert);
+ }
+ SecPkgContext_ConnectionInfo connection_info;
+ status = QueryContextAttributes(&ctxt_,
+ SECPKG_ATTR_CONNECTION_INFO,
+ &connection_info);
+ if (status == SEC_E_OK) {
+ // TODO(wtc): compute the overall security strength, taking into account
+ // dwExchStrength and dwHashStrength. dwExchStrength needs to be
+ // normalized.
+ ssl_info->security_bits = connection_info.dwCipherStrength;
+ }
+ ssl_info->cert_status = server_cert_status_;
+}
+
+int SSLClientSocketWin::Connect(CompletionCallback* callback) {
DCHECK(transport_.get());
DCHECK(next_state_ == STATE_NONE);
DCHECK(!user_callback_);
@@ -135,12 +160,13 @@ int SSLClientSocket::Connect(CompletionCallback* callback) {
return rv;
}
-int SSLClientSocket::ReconnectIgnoringLastError(CompletionCallback* callback) {
+int SSLClientSocketWin::ReconnectIgnoringLastError(
+ CompletionCallback* callback) {
// TODO(darin): implement me!
return ERR_FAILED;
}
-void SSLClientSocket::Disconnect() {
+void SSLClientSocketWin::Disconnect() {
// TODO(wtc): Send SSL close_notify alert.
completed_handshake_ = false;
transport_->Disconnect();
@@ -165,7 +191,7 @@ void SSLClientSocket::Disconnect() {
bytes_received_ = 0;
}
-bool SSLClientSocket::IsConnected() const {
+bool SSLClientSocketWin::IsConnected() const {
// Ideally, we should also check if we have received the close_notify alert
// message from the server, and return false in that case. We're not doing
// that, so this function may return a false positive. Since the upper
@@ -175,8 +201,8 @@ bool SSLClientSocket::IsConnected() const {
return completed_handshake_ && transport_->IsConnected();
}
-int SSLClientSocket::Read(char* buf, int buf_len,
- CompletionCallback* callback) {
+int SSLClientSocketWin::Read(char* buf, int buf_len,
+ CompletionCallback* callback) {
DCHECK(completed_handshake_);
DCHECK(next_state_ == STATE_NONE);
DCHECK(!user_callback_);
@@ -213,8 +239,8 @@ int SSLClientSocket::Read(char* buf, int buf_len,
return rv;
}
-int SSLClientSocket::Write(const char* buf, int buf_len,
- CompletionCallback* callback) {
+int SSLClientSocketWin::Write(const char* buf, int buf_len,
+ CompletionCallback* callback) {
DCHECK(completed_handshake_);
DCHECK(next_state_ == STATE_NONE);
DCHECK(!user_callback_);
@@ -229,32 +255,7 @@ int SSLClientSocket::Write(const char* buf, int buf_len,
return rv;
}
-void SSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
- SECURITY_STATUS status = SEC_E_OK;
- if (server_cert_ == NULL) {
- status = QueryContextAttributes(&ctxt_,
- SECPKG_ATTR_REMOTE_CERT_CONTEXT,
- &server_cert_);
- }
- if (status == SEC_E_OK) {
- DCHECK(server_cert_);
- PCCERT_CONTEXT dup_cert = CertDuplicateCertificateContext(server_cert_);
- ssl_info->cert = X509Certificate::CreateFromHandle(dup_cert);
- }
- SecPkgContext_ConnectionInfo connection_info;
- status = QueryContextAttributes(&ctxt_,
- SECPKG_ATTR_CONNECTION_INFO,
- &connection_info);
- if (status == SEC_E_OK) {
- // TODO(wtc): compute the overall security strength, taking into account
- // dwExchStrength and dwHashStrength. dwExchStrength needs to be
- // normalized.
- ssl_info->security_bits = connection_info.dwCipherStrength;
- }
- ssl_info->cert_status = server_cert_status_;
-}
-
-void SSLClientSocket::DoCallback(int rv) {
+void SSLClientSocketWin::DoCallback(int rv) {
DCHECK(rv != ERR_IO_PENDING);
DCHECK(user_callback_);
@@ -264,13 +265,13 @@ void SSLClientSocket::DoCallback(int rv) {
c->Run(rv);
}
-void SSLClientSocket::OnIOComplete(int result) {
+void SSLClientSocketWin::OnIOComplete(int result) {
int rv = DoLoop(result);
if (rv != ERR_IO_PENDING)
DoCallback(rv);
}
-int SSLClientSocket::DoLoop(int last_io_result) {
+int SSLClientSocketWin::DoLoop(int last_io_result) {
DCHECK(next_state_ != STATE_NONE);
int rv = last_io_result;
do {
@@ -319,12 +320,12 @@ int SSLClientSocket::DoLoop(int last_io_result) {
return rv;
}
-int SSLClientSocket::DoConnect() {
+int SSLClientSocketWin::DoConnect() {
next_state_ = STATE_CONNECT_COMPLETE;
return transport_->Connect(&io_callback_);
}
-int SSLClientSocket::DoConnectComplete(int result) {
+int SSLClientSocketWin::DoConnectComplete(int result) {
if (result < 0)
return result;
@@ -337,11 +338,11 @@ int SSLClientSocket::DoConnectComplete(int result) {
// The global system registry settings take precedence over the value of
// schannel_cred.grbitEnabledProtocols.
schannel_cred.grbitEnabledProtocols = 0;
- if (protocol_version_mask_ & SSL2)
+ if (ssl_config_.ssl2_enabled)
schannel_cred.grbitEnabledProtocols |= SP_PROT_SSL2;
- if (protocol_version_mask_ & SSL3)
+ if (ssl_config_.ssl3_enabled)
schannel_cred.grbitEnabledProtocols |= SP_PROT_SSL3;
- if (protocol_version_mask_ & TLS1)
+ if (ssl_config_.tls1_enabled)
schannel_cred.grbitEnabledProtocols |= SP_PROT_TLS1;
// The default (0) means Schannel selects the protocol, rather than no
// protocols are selected. So we have to fail here.
@@ -429,7 +430,7 @@ int SSLClientSocket::DoConnectComplete(int result) {
return OK;
}
-int SSLClientSocket::DoHandshakeRead() {
+int SSLClientSocketWin::DoHandshakeRead() {
next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
if (!recv_buffer_.get())
@@ -446,7 +447,7 @@ int SSLClientSocket::DoHandshakeRead() {
return transport_->Read(buf, buf_len, &io_callback_);
}
-int SSLClientSocket::DoHandshakeReadComplete(int result) {
+int SSLClientSocketWin::DoHandshakeReadComplete(int result) {
if (result < 0)
return result;
if (result == 0 && !ignore_ok_result_)
@@ -576,7 +577,7 @@ int SSLClientSocket::DoHandshakeReadComplete(int result) {
return OK;
}
-int SSLClientSocket::DoHandshakeWrite() {
+int SSLClientSocketWin::DoHandshakeWrite() {
next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
// We should have something to send.
@@ -589,7 +590,7 @@ int SSLClientSocket::DoHandshakeWrite() {
return transport_->Write(buf, buf_len, &io_callback_);
}
-int SSLClientSocket::DoHandshakeWriteComplete(int result) {
+int SSLClientSocketWin::DoHandshakeWriteComplete(int result) {
if (result < 0)
return result;
@@ -615,7 +616,7 @@ int SSLClientSocket::DoHandshakeWriteComplete(int result) {
return OK;
}
-int SSLClientSocket::DoPayloadRead() {
+int SSLClientSocketWin::DoPayloadRead() {
next_state_ = STATE_PAYLOAD_READ_COMPLETE;
DCHECK(recv_buffer_.get());
@@ -631,7 +632,7 @@ int SSLClientSocket::DoPayloadRead() {
return transport_->Read(buf, buf_len, &io_callback_);
}
-int SSLClientSocket::DoPayloadReadComplete(int result) {
+int SSLClientSocketWin::DoPayloadReadComplete(int result) {
if (result < 0)
return result;
if (result == 0 && !ignore_ok_result_) {
@@ -728,7 +729,7 @@ int SSLClientSocket::DoPayloadReadComplete(int result) {
return len;
}
-int SSLClientSocket::DoPayloadEncrypt() {
+int SSLClientSocketWin::DoPayloadEncrypt() {
DCHECK(user_buf_);
DCHECK(user_buf_len_ > 0);
@@ -777,7 +778,7 @@ int SSLClientSocket::DoPayloadEncrypt() {
return OK;
}
-int SSLClientSocket::DoPayloadWrite() {
+int SSLClientSocketWin::DoPayloadWrite() {
next_state_ = STATE_PAYLOAD_WRITE_COMPLETE;
// We should have something to send.
@@ -790,7 +791,7 @@ int SSLClientSocket::DoPayloadWrite() {
return transport_->Write(buf, buf_len, &io_callback_);
}
-int SSLClientSocket::DoPayloadWriteComplete(int result) {
+int SSLClientSocketWin::DoPayloadWriteComplete(int result) {
if (result < 0)
return result;
@@ -815,7 +816,7 @@ int SSLClientSocket::DoPayloadWriteComplete(int result) {
return OK;
}
-int SSLClientSocket::DidCompleteHandshake() {
+int SSLClientSocketWin::DidCompleteHandshake() {
SECURITY_STATUS status = QueryContextAttributes(
&ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_);
if (status != SEC_E_OK) {
@@ -839,7 +840,7 @@ int SSLClientSocket::DidCompleteHandshake() {
return rv;
}
-int SSLClientSocket::VerifyServerCert() {
+int SSLClientSocketWin::VerifyServerCert() {
DCHECK(server_cert_);
// Build and validate certificate chain.
diff --git a/net/base/ssl_client_socket_win.h b/net/base/ssl_client_socket_win.h
new file mode 100644
index 0000000..403e7f3
--- /dev/null
+++ b/net/base/ssl_client_socket_win.h
@@ -0,0 +1,139 @@
+// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_BASE_SSL_CLIENT_SOCKET_WIN_H_
+#define NET_BASE_SSL_CLIENT_SOCKET_WIN_H_
+
+#define SECURITY_WIN32 // Needs to be defined before including security.h
+
+#include <windows.h>
+#include <wincrypt.h>
+#include <security.h>
+
+#include <string>
+
+#include "base/scoped_ptr.h"
+#include "net/base/completion_callback.h"
+#include "net/base/ssl_client_socket.h"
+#include "net/base/ssl_config_service.h"
+
+namespace net {
+
+// An SSL client socket implemented with the Windows Schannel.
+class SSLClientSocketWin : public SSLClientSocket {
+ public:
+ // Takes ownership of the transport_socket, which may already be connected.
+ // The given hostname will be compared with the name(s) in the server's
+ // certificate during the SSL handshake. ssl_config specifies the SSL
+ // settings.
+ SSLClientSocketWin(ClientSocket* transport_socket,
+ const std::string& hostname,
+ const SSLConfig& ssl_config);
+ ~SSLClientSocketWin();
+
+ // SSLClientSocket methods:
+ virtual void GetSSLInfo(SSLInfo* ssl_info);
+
+ // ClientSocket methods:
+ virtual int Connect(CompletionCallback* callback);
+ virtual int ReconnectIgnoringLastError(CompletionCallback* callback);
+ virtual void Disconnect();
+ virtual bool IsConnected() const;
+
+ // Socket methods:
+ virtual int Read(char* buf, int buf_len, CompletionCallback* callback);
+ virtual int Write(const char* buf, int buf_len, CompletionCallback* callback);
+
+ private:
+ void DoCallback(int result);
+ void OnIOComplete(int result);
+
+ int DoLoop(int last_io_result);
+ int DoConnect();
+ int DoConnectComplete(int result);
+ int DoHandshakeRead();
+ int DoHandshakeReadComplete(int result);
+ int DoHandshakeWrite();
+ int DoHandshakeWriteComplete(int result);
+ int DoPayloadRead();
+ int DoPayloadReadComplete(int result);
+ int DoPayloadEncrypt();
+ int DoPayloadWrite();
+ int DoPayloadWriteComplete(int result);
+
+ int DidCompleteHandshake();
+ int VerifyServerCert();
+
+ CompletionCallbackImpl<SSLClientSocketWin> io_callback_;
+ scoped_ptr<ClientSocket> transport_;
+ std::string hostname_;
+ SSLConfig ssl_config_;
+
+ CompletionCallback* user_callback_;
+
+ // Used by both Read and Write functions.
+ char* user_buf_;
+ int user_buf_len_;
+
+ enum State {
+ STATE_NONE,
+ STATE_CONNECT,
+ STATE_CONNECT_COMPLETE,
+ STATE_HANDSHAKE_READ,
+ STATE_HANDSHAKE_READ_COMPLETE,
+ STATE_HANDSHAKE_WRITE,
+ STATE_HANDSHAKE_WRITE_COMPLETE,
+ STATE_PAYLOAD_ENCRYPT,
+ STATE_PAYLOAD_WRITE,
+ STATE_PAYLOAD_WRITE_COMPLETE,
+ STATE_PAYLOAD_READ,
+ STATE_PAYLOAD_READ_COMPLETE,
+ };
+ State next_state_;
+
+ SecPkgContext_StreamSizes stream_sizes_;
+ PCCERT_CONTEXT server_cert_;
+ int server_cert_status_;
+
+ CredHandle creds_;
+ CtxtHandle ctxt_;
+ SecBuffer send_buffer_;
+ scoped_array<char> payload_send_buffer_;
+ int payload_send_buffer_len_;
+ int bytes_sent_;
+
+ // recv_buffer_ holds the received ciphertext. Since Schannel decrypts
+ // data in place, sometimes recv_buffer_ may contain decrypted plaintext and
+ // any undecrypted ciphertext. (Ciphertext is decrypted one full SSL record
+ // at a time.)
+ //
+ // If bytes_decrypted_ is 0, the received ciphertext is at the beginning of
+ // recv_buffer_, ready to be passed to DecryptMessage.
+ scoped_array<char> recv_buffer_;
+ char* decrypted_ptr_; // Points to the decrypted plaintext in recv_buffer_
+ int bytes_decrypted_; // The number of bytes of decrypted plaintext.
+ char* received_ptr_; // Points to the received ciphertext in recv_buffer_
+ int bytes_received_; // The number of bytes of received ciphertext.
+
+ bool completed_handshake_;
+
+ // Only used in the STATE_HANDSHAKE_READ_COMPLETE and
+ // STATE_PAYLOAD_READ_COMPLETE states. True if a 'result' argument of OK
+ // should be ignored, to prevent it from being interpreted as EOF.
+ //
+ // The reason we need this flag is that OK means not only "0 bytes of data
+ // were read" but also EOF. We set ignore_ok_result_ to true when we need
+ // to continue processing previously read data without reading more data.
+ // We have to pass a 'result' of OK to the DoLoop method, and don't want it
+ // to be interpreted as EOF.
+ bool ignore_ok_result_;
+
+ // True if the user has no client certificate.
+ bool no_client_cert_;
+};
+
+} // namespace net
+
+#endif // NET_BASE_SSL_CLIENT_SOCKET_WIN_H_
+
diff --git a/net/build/net.vcproj b/net/build/net.vcproj
index b56b05f..b711b49 100644
--- a/net/build/net.vcproj
+++ b/net/build/net.vcproj
@@ -433,11 +433,15 @@
>
</File>
<File
- RelativePath="..\base\ssl_client_socket.cc"
+ RelativePath="..\base\ssl_client_socket.h"
>
</File>
<File
- RelativePath="..\base\ssl_client_socket.h"
+ RelativePath="..\base\ssl_client_socket_win.cc"
+ >
+ </File>
+ <File
+ RelativePath="..\base\ssl_client_socket_win.h"
>
</File>
<File
diff --git a/net/http/http_network_transaction.cc b/net/http/http_network_transaction.cc
index 9d2c398..168b99d 100644
--- a/net/http/http_network_transaction.cc
+++ b/net/http/http_network_transaction.cc
@@ -13,9 +13,7 @@
#include "net/base/host_resolver.h"
#include "net/base/load_flags.h"
#include "net/base/net_util.h"
-#if defined(OS_WIN)
#include "net/base/ssl_client_socket.h"
-#endif
#include "net/base/upload_data_stream.h"
#include "net/http/http_auth.h"
#include "net/http/http_auth_handler.h"
@@ -58,12 +56,7 @@ HttpNetworkTransaction::HttpNetworkTransaction(HttpNetworkSession* session,
read_buf_(NULL),
read_buf_len_(0),
next_state_(STATE_NONE) {
-#if defined(OS_WIN)
- // TODO(wtc): Use SSL settings (bug 3003).
- ssl_version_mask_ = SSLClientSocket::SSL3 | SSLClientSocket::TLS1;
-#else
- ssl_version_mask_ = 0; // A dummy value so that the code compiles.
-#endif
+ // TODO(wtc): Initialize ssl_config_with SSL settings (bug 3003).
}
void HttpNetworkTransaction::Destroy() {
@@ -89,7 +82,7 @@ int HttpNetworkTransaction::RestartIgnoringLastError(
int rv = DoLoop(OK);
if (rv == ERR_IO_PENDING)
user_callback_ = callback;
- return rv;
+ return rv;
}
int HttpNetworkTransaction::RestartWithAuth(
@@ -482,7 +475,7 @@ int HttpNetworkTransaction::DoConnect() {
// wrapper socket now. Otherwise, we need to first issue a CONNECT request.
if (using_ssl_ && !using_tunnel_)
s = socket_factory_->CreateSSLClientSocket(s, request_->url.host(),
- ssl_version_mask_);
+ ssl_config_);
connection_.set_socket(s);
return connection_.socket()->Connect(&io_callback_);
@@ -510,7 +503,7 @@ int HttpNetworkTransaction::DoSSLConnectOverTunnel() {
// Add a SSL socket on top of our existing transport socket.
ClientSocket* s = connection_.release_socket();
s = socket_factory_->CreateSSLClientSocket(s, request_->url.host(),
- ssl_version_mask_);
+ ssl_config_);
connection_.set_socket(s);
return connection_.socket()->Connect(&io_callback_);
}
@@ -834,13 +827,11 @@ int HttpNetworkTransaction::DidReadResponseHeaders() {
}
}
-#if defined(OS_WIN)
if (using_ssl_ && !establishing_tunnel_) {
SSLClientSocket* ssl_socket =
reinterpret_cast<SSLClientSocket*>(connection_.socket());
ssl_socket->GetSSLInfo(&response_.ssl_info);
}
-#endif
return OK;
}
@@ -869,25 +860,22 @@ int HttpNetworkTransaction::HandleCertificateError(int error) {
}
}
-#if defined(OS_WIN)
if (error != OK) {
SSLClientSocket* ssl_socket =
reinterpret_cast<SSLClientSocket*>(connection_.socket());
ssl_socket->GetSSLInfo(&response_.ssl_info);
}
-#endif
return error;
}
int HttpNetworkTransaction::HandleSSLHandshakeError(int error) {
-#if defined(OS_WIN)
switch (error) {
case ERR_SSL_PROTOCOL_ERROR:
case ERR_SSL_VERSION_OR_CIPHER_MISMATCH:
- if (ssl_version_mask_ & SSLClientSocket::TLS1) {
+ if (ssl_config_.tls1_enabled) {
// This could be a TLS-intolerant server or an SSL 3.0 server that
// chose a TLS-only cipher suite. Turn off TLS 1.0 and retry.
- ssl_version_mask_ &= ~SSLClientSocket::TLS1;
+ ssl_config_.tls1_enabled = false;
connection_.set_socket(NULL);
connection_.Reset();
next_state_ = STATE_INIT_CONNECTION;
@@ -895,7 +883,6 @@ int HttpNetworkTransaction::HandleSSLHandshakeError(int error) {
}
break;
}
-#endif
return error;
}
@@ -1001,7 +988,7 @@ void HttpNetworkTransaction::AddAuthorizationHeader(HttpAuth::Target target) {
// Add auth data to cache
session_->auth_cache()->Add(auth_cache_key_[target], auth_data_[target]);
-
+
// Add a Authorization/Proxy-Authorization header line.
std::string credentials = auth_handler_[target]->GenerateCredentials(
auth_data_[target]->username,
diff --git a/net/http/http_network_transaction.h b/net/http/http_network_transaction.h
index bbbfb74..475056e 100644
--- a/net/http/http_network_transaction.h
+++ b/net/http/http_network_transaction.h
@@ -11,6 +11,7 @@
#include "net/base/address_list.h"
#include "net/base/client_socket_handle.h"
#include "net/base/host_resolver.h"
+#include "net/base/ssl_config_service.h"
#include "net/http/http_auth.h"
#include "net/http/http_auth_handler.h"
#include "net/http/http_response_info.h"
@@ -186,7 +187,7 @@ class HttpNetworkTransaction : public HttpTransaction {
// the real request/response of the transaction.
bool establishing_tunnel_;
- int ssl_version_mask_;
+ SSLConfig ssl_config_;
std::string request_headers_;
size_t request_headers_bytes_sent_;
diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc
index 56fb5cd..9c2d6a7 100644
--- a/net/http/http_network_transaction_unittest.cc
+++ b/net/http/http_network_transaction_unittest.cc
@@ -137,7 +137,7 @@ class MockTCPClientSocket : public net::ClientSocket {
// Not using mock writes; succeed synchronously.
if (!data_->writes)
return buf_len;
-
+
// Check that what we are writing matches the expectation.
// Then give the mocked return value.
MockWrite& w = data_->writes[write_index_];
@@ -185,10 +185,10 @@ class MockClientSocketFactory : public net::ClientSocketFactory {
const net::AddressList& addresses) {
return new MockTCPClientSocket(addresses);
}
- virtual net::ClientSocket* CreateSSLClientSocket(
+ virtual net::SSLClientSocket* CreateSSLClientSocket(
net::ClientSocket* transport_socket,
const std::string& hostname,
- int protocol_version_mask) {
+ const net::SSLConfig& ssl_config) {
return NULL;
}
};
@@ -623,7 +623,7 @@ TEST_F(HttpNetworkTransactionTest, BasicAuth) {
MockRead("HTTP/1.0 401 Unauthorized\r\n"),
// Give a couple authenticate options (only the middle one is actually
// supported).
- MockRead("WWW-Authenticate: Basic\r\n"), // Malformed
+ MockRead("WWW-Authenticate: Basic\r\n"), // Malformed
MockRead("WWW-Authenticate: Basic realm=\"MyRealm1\"\r\n"),
MockRead("WWW-Authenticate: UNSUPPORTED realm=\"FOO\"\r\n"),
MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"),
@@ -717,7 +717,7 @@ TEST_F(HttpNetworkTransactionTest, BasicAuthProxyThenServer) {
MockRead("HTTP/1.0 407 Unauthorized\r\n"),
// Give a couple authenticate options (only the middle one is actually
// supported).
- MockRead("Proxy-Authenticate: Basic\r\n"), // Malformed
+ MockRead("Proxy-Authenticate: Basic\r\n"), // Malformed
MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"),
MockRead("Proxy-Authenticate: UNSUPPORTED realm=\"FOO\"\r\n"),
MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"),
@@ -745,7 +745,7 @@ TEST_F(HttpNetworkTransactionTest, BasicAuthProxyThenServer) {
MockRead("WWW-Authenticate: Basic realm=\"MyRealm1\"\r\n"),
MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"),
MockRead("Content-Length: 2000\r\n\r\n"),
- MockRead(false, net::ERR_FAILED), // Won't be reached.
+ MockRead(false, net::ERR_FAILED), // Won't be reached.
};
// After calling trans->RestartWithAuth() the second time, we should send
diff --git a/net/net.xcodeproj/project.pbxproj b/net/net.xcodeproj/project.pbxproj
index 15f2211..a594020 100644
--- a/net/net.xcodeproj/project.pbxproj
+++ b/net/net.xcodeproj/project.pbxproj
@@ -491,7 +491,6 @@
7BED32940E5A181C00A747DB /* ssl_config_service.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ssl_config_service.cc; sourceTree = "<group>"; };
7BED32950E5A181C00A747DB /* ssl_client_socket_unittest.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ssl_client_socket_unittest.cc; sourceTree = "<group>"; };
7BED32960E5A181C00A747DB /* ssl_client_socket.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ssl_client_socket.h; sourceTree = "<group>"; };
- 7BED32970E5A181C00A747DB /* ssl_client_socket.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ssl_client_socket.cc; sourceTree = "<group>"; };
7BED32980E5A181C00A747DB /* socket.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = socket.h; sourceTree = "<group>"; };
7BED32990E5A181C00A747DB /* registry_controlled_domain_unittest.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = registry_controlled_domain_unittest.cc; sourceTree = "<group>"; };
7BED329A0E5A181C00A747DB /* registry_controlled_domain.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = registry_controlled_domain.h; sourceTree = "<group>"; };
@@ -946,7 +945,6 @@
E49DD2E80E892F8C003C7A87 /* sdch_manager.cc */,
E49DD2E90E892F8C003C7A87 /* sdch_manager.h */,
7BED32980E5A181C00A747DB /* socket.h */,
- 7BED32970E5A181C00A747DB /* ssl_client_socket.cc */,
7BED32960E5A181C00A747DB /* ssl_client_socket.h */,
7BED32950E5A181C00A747DB /* ssl_client_socket_unittest.cc */,
7BED32940E5A181C00A747DB /* ssl_config_service.cc */,