diff options
author | wtc@google.com <wtc@google.com@0039d316-1c4b-4281-b951-d872f2087c98> | 2008-10-15 00:20:11 +0000 |
---|---|---|
committer | wtc@google.com <wtc@google.com@0039d316-1c4b-4281-b951-d872f2087c98> | 2008-10-15 00:20:11 +0000 |
commit | aaead5019627818c93693fdb6ec04d47b47c17f2 (patch) | |
tree | 6ff15880c597bb59a8c3def51d13ce492a4bb405 | |
parent | 1ad083f293cd321fa7d7c8f14e71816571c6c54f (diff) | |
download | chromium_src-aaead5019627818c93693fdb6ec04d47b47c17f2.zip chromium_src-aaead5019627818c93693fdb6ec04d47b47c17f2.tar.gz chromium_src-aaead5019627818c93693fdb6ec04d47b47c17f2.tar.bz2 |
Turn SSLClientSocket into an interface.
The original ssl_client_socket.{h,cc} are renamed
ssl_client_socket_win.{h,cc}.
The new ssl_client_socket.h defines the SSLClientSocket
interface, which simply extends the ClientSocket interface
with a new GetSSLInfo method.
ClientSocketFactory::CreateSSLClientSocket returns
SSLClientSocket* instead of ClientSocket*.
Replace the SSL protocol version mask parameter to the
constructor and factory method by a SSLConfig parameter.
R=darin
Review URL: http://codereview.chromium.org/7304
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@3387 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r-- | net/SConscript | 2 | ||||
-rw-r--r-- | net/base/client_socket_factory.cc | 9 | ||||
-rw-r--r-- | net/base/client_socket_factory.h | 9 | ||||
-rw-r--r-- | net/base/ssl_client_socket.h | 125 | ||||
-rw-r--r-- | net/base/ssl_client_socket_unittest.cc | 60 | ||||
-rw-r--r-- | net/base/ssl_client_socket_win.cc (renamed from net/base/ssl_client_socket.cc) | 119 | ||||
-rw-r--r-- | net/base/ssl_client_socket_win.h | 139 | ||||
-rw-r--r-- | net/build/net.vcproj | 8 | ||||
-rw-r--r-- | net/http/http_network_transaction.cc | 27 | ||||
-rw-r--r-- | net/http/http_network_transaction.h | 3 | ||||
-rw-r--r-- | net/http/http_network_transaction_unittest.cc | 12 | ||||
-rw-r--r-- | net/net.xcodeproj/project.pbxproj | 2 |
12 files changed, 266 insertions, 249 deletions
diff --git a/net/SConscript b/net/SConscript index b3dd629..37d6c18 100644 --- a/net/SConscript +++ b/net/SConscript @@ -94,7 +94,7 @@ if env['PLATFORM'] == 'win32': 'base/directory_lister.cc', 'base/dns_resolution_observer.cc', 'base/listen_socket.cc', - 'base/ssl_client_socket.cc', + 'base/ssl_client_socket_win.cc', 'base/ssl_config_service.cc', 'base/tcp_client_socket.cc', 'base/telnet_server.cc', diff --git a/net/base/client_socket_factory.cc b/net/base/client_socket_factory.cc index 43df0e6..10f24df 100644 --- a/net/base/client_socket_factory.cc +++ b/net/base/client_socket_factory.cc @@ -7,7 +7,7 @@ #include "base/singleton.h" #include "build/build_config.h" #if defined(OS_WIN) -#include "net/base/ssl_client_socket.h" +#include "net/base/ssl_client_socket_win.h" #endif #include "net/base/tcp_client_socket.h" @@ -20,13 +20,12 @@ class DefaultClientSocketFactory : public ClientSocketFactory { return new TCPClientSocket(addresses); } - virtual ClientSocket* CreateSSLClientSocket( + virtual SSLClientSocket* CreateSSLClientSocket( ClientSocket* transport_socket, const std::string& hostname, - int protocol_version_mask) { + const SSLConfig& ssl_config) { #if defined(OS_WIN) - return new SSLClientSocket(transport_socket, hostname, - protocol_version_mask); + return new SSLClientSocketWin(transport_socket, hostname, ssl_config); #else // TODO(pinkerton): turn on when we port SSL socket from win32 NOTIMPLEMENTED(); diff --git a/net/base/client_socket_factory.h b/net/base/client_socket_factory.h index ed371a3..3a3bb84 100644 --- a/net/base/client_socket_factory.h +++ b/net/base/client_socket_factory.h @@ -11,6 +11,8 @@ namespace net { class AddressList; class ClientSocket; +class SSLClientSocket; +struct SSLConfig; // An interface used to instantiate ClientSocket objects. Used to facilitate // testing code with mock socket implementations. @@ -21,13 +23,10 @@ class ClientSocketFactory { virtual ClientSocket* CreateTCPClientSocket( const AddressList& addresses) = 0; - // protocol_version_mask is a bitmask that specifies which versions of the - // SSL protocol (SSL 2.0, SSL 3.0, and TLS 1.0) should be enabled. The bit - // flags are defined in net/base/ssl_client_socket.h. - virtual ClientSocket* CreateSSLClientSocket( + virtual SSLClientSocket* CreateSSLClientSocket( ClientSocket* transport_socket, const std::string& hostname, - int protocol_version_mask) = 0; + const SSLConfig& ssl_config) = 0; // Returns the default ClientSocketFactory. static ClientSocketFactory* GetDefaultFactory(); diff --git a/net/base/ssl_client_socket.h b/net/base/ssl_client_socket.h index 100e514..dca5ef3 100644 --- a/net/base/ssl_client_socket.h +++ b/net/base/ssl_client_socket.h @@ -5,17 +5,7 @@ #ifndef NET_BASE_SSL_CLIENT_SOCKET_H_ #define NET_BASE_SSL_CLIENT_SOCKET_H_ -#define SECURITY_WIN32 // Needs to be defined before including security.h - -#include <windows.h> -#include <wincrypt.h> -#include <security.h> - -#include <string> - -#include "base/scoped_ptr.h" #include "net/base/client_socket.h" -#include "net/base/completion_callback.h" namespace net { @@ -30,121 +20,8 @@ class SSLInfo; // class SSLClientSocket : public ClientSocket { public: - enum { - SSL2 = 1 << 0, - SSL3 = 1 << 1, - TLS1 = 1 << 2 - }; - - // Takes ownership of the transport_socket, which may already be connected. - // The given hostname will be compared with the name(s) in the server's - // certificate during the SSL handshake. protocol_version_mask is a bitwise - // OR of SSL2, SSL3, and TLS1 that specifies which versions of the SSL - // protocol should be enabled. - SSLClientSocket(ClientSocket* transport_socket, - const std::string& hostname, - int protocol_version_mask); - ~SSLClientSocket(); - - // ClientSocket methods: - virtual int Connect(CompletionCallback* callback); - virtual int ReconnectIgnoringLastError(CompletionCallback* callback); - virtual void Disconnect(); - virtual bool IsConnected() const; - - // Socket methods: - virtual int Read(char* buf, int buf_len, CompletionCallback* callback); - virtual int Write(const char* buf, int buf_len, CompletionCallback* callback); - // Gets the SSL connection information of the socket. - void GetSSLInfo(SSLInfo* ssl_info); - - private: - void DoCallback(int result); - void OnIOComplete(int result); - - int DoLoop(int last_io_result); - int DoConnect(); - int DoConnectComplete(int result); - int DoHandshakeRead(); - int DoHandshakeReadComplete(int result); - int DoHandshakeWrite(); - int DoHandshakeWriteComplete(int result); - int DoPayloadRead(); - int DoPayloadReadComplete(int result); - int DoPayloadEncrypt(); - int DoPayloadWrite(); - int DoPayloadWriteComplete(int result); - - int DidCompleteHandshake(); - int VerifyServerCert(); - - CompletionCallbackImpl<SSLClientSocket> io_callback_; - scoped_ptr<ClientSocket> transport_; - std::string hostname_; - int protocol_version_mask_; - - CompletionCallback* user_callback_; - - // Used by both Read and Write functions. - char* user_buf_; - int user_buf_len_; - - enum State { - STATE_NONE, - STATE_CONNECT, - STATE_CONNECT_COMPLETE, - STATE_HANDSHAKE_READ, - STATE_HANDSHAKE_READ_COMPLETE, - STATE_HANDSHAKE_WRITE, - STATE_HANDSHAKE_WRITE_COMPLETE, - STATE_PAYLOAD_ENCRYPT, - STATE_PAYLOAD_WRITE, - STATE_PAYLOAD_WRITE_COMPLETE, - STATE_PAYLOAD_READ, - STATE_PAYLOAD_READ_COMPLETE, - }; - State next_state_; - - SecPkgContext_StreamSizes stream_sizes_; - PCCERT_CONTEXT server_cert_; - int server_cert_status_; - - CredHandle creds_; - CtxtHandle ctxt_; - SecBuffer send_buffer_; - scoped_array<char> payload_send_buffer_; - int payload_send_buffer_len_; - int bytes_sent_; - - // recv_buffer_ holds the received ciphertext. Since Schannel decrypts - // data in place, sometimes recv_buffer_ may contain decrypted plaintext and - // any undecrypted ciphertext. (Ciphertext is decrypted one full SSL record - // at a time.) - // - // If bytes_decrypted_ is 0, the received ciphertext is at the beginning of - // recv_buffer_, ready to be passed to DecryptMessage. - scoped_array<char> recv_buffer_; - char* decrypted_ptr_; // Points to the decrypted plaintext in recv_buffer_ - int bytes_decrypted_; // The number of bytes of decrypted plaintext. - char* received_ptr_; // Points to the received ciphertext in recv_buffer_ - int bytes_received_; // The number of bytes of received ciphertext. - - bool completed_handshake_; - - // Only used in the STATE_HANDSHAKE_READ_COMPLETE and - // STATE_PAYLOAD_READ_COMPLETE states. True if a 'result' argument of OK - // should be ignored, to prevent it from being interpreted as EOF. - // - // The reason we need this flag is that OK means not only "0 bytes of data - // were read" but also EOF. We set ignore_ok_result_ to true when we need - // to continue processing previously read data without reading more data. - // We have to pass a 'result' of OK to the DoLoop method, and don't want it - // to be interpreted as EOF. - bool ignore_ok_result_; - - // True if the user has no client certificate. - bool no_client_cert_; + virtual void GetSSLInfo(SSLInfo* ssl_info) = 0; }; } // namespace net diff --git a/net/base/ssl_client_socket_unittest.cc b/net/base/ssl_client_socket_unittest.cc index 2aba7ab..d1f1f82 100644 --- a/net/base/ssl_client_socket_unittest.cc +++ b/net/base/ssl_client_socket_unittest.cc @@ -3,9 +3,11 @@ // found in the LICENSE file. #include "net/base/address_list.h" +#include "net/base/client_socket_factory.h" #include "net/base/host_resolver.h" #include "net/base/net_errors.h" #include "net/base/ssl_client_socket.h" +#include "net/base/ssl_config_service.h" #include "net/base/tcp_client_socket.h" #include "net/base/test_completion_callback.h" #include "testing/gtest/include/gtest/gtest.h" @@ -14,10 +16,16 @@ namespace { -const unsigned int kDefaultSSLVersionMask = net::SSLClientSocket::SSL3 | - net::SSLClientSocket::TLS1; +const net::SSLConfig kDefaultSSLConfig; class SSLClientSocketTest : public testing::Test { + public: + SSLClientSocketTest() + : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) { + } + + protected: + net::ClientSocketFactory* socket_factory_; }; } // namespace @@ -34,12 +42,13 @@ TEST_F(SSLClientSocketTest, DISABLED_Connect) { int rv = resolver.Resolve(hostname, 443, &addr, NULL); EXPECT_EQ(net::OK, rv); - net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname, - kDefaultSSLVersionMask); + scoped_ptr<net::SSLClientSocket> sock( + socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + hostname, kDefaultSSLConfig)); - EXPECT_FALSE(sock.IsConnected()); + EXPECT_FALSE(sock->IsConnected()); - rv = sock.Connect(&callback); + rv = sock->Connect(&callback); if (rv != net::OK) { ASSERT_EQ(net::ERR_IO_PENDING, rv); @@ -47,10 +56,10 @@ TEST_F(SSLClientSocketTest, DISABLED_Connect) { EXPECT_EQ(net::OK, rv); } - EXPECT_TRUE(sock.IsConnected()); + EXPECT_TRUE(sock->IsConnected()); - sock.Disconnect(); - EXPECT_FALSE(sock.IsConnected()); + sock->Disconnect(); + EXPECT_FALSE(sock->IsConnected()); } // bug 1354783 @@ -66,10 +75,11 @@ TEST_F(SSLClientSocketTest, DISABLED_Read) { rv = callback.WaitForResult(); EXPECT_EQ(rv, net::OK); - net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname, - kDefaultSSLVersionMask); + scoped_ptr<net::SSLClientSocket> sock( + socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + hostname, kDefaultSSLConfig)); - rv = sock.Connect(&callback); + rv = sock->Connect(&callback); if (rv != net::OK) { ASSERT_EQ(rv, net::ERR_IO_PENDING); @@ -78,7 +88,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read) { } const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - rv = sock.Write(request_text, arraysize(request_text) - 1, &callback); + rv = sock->Write(request_text, arraysize(request_text) - 1, &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) { @@ -88,7 +98,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read) { char buf[4096]; for (;;) { - rv = sock.Read(buf, sizeof(buf), &callback); + rv = sock->Read(buf, sizeof(buf), &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -110,10 +120,11 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_SmallChunks) { int rv = resolver.Resolve(hostname, 443, &addr, NULL); EXPECT_EQ(rv, net::OK); - net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname, - kDefaultSSLVersionMask); + scoped_ptr<net::SSLClientSocket> sock( + socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + hostname, kDefaultSSLConfig)); - rv = sock.Connect(&callback); + rv = sock->Connect(&callback); if (rv != net::OK) { ASSERT_EQ(rv, net::ERR_IO_PENDING); @@ -122,7 +133,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_SmallChunks) { } const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - rv = sock.Write(request_text, arraysize(request_text) - 1, &callback); + rv = sock->Write(request_text, arraysize(request_text) - 1, &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) { @@ -132,7 +143,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_SmallChunks) { char buf[1]; for (;;) { - rv = sock.Read(buf, sizeof(buf), &callback); + rv = sock->Read(buf, sizeof(buf), &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -154,10 +165,11 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_Interrupted) { int rv = resolver.Resolve(hostname, 443, &addr, NULL); EXPECT_EQ(rv, net::OK); - net::SSLClientSocket sock(new net::TCPClientSocket(addr), hostname, - kDefaultSSLVersionMask); + scoped_ptr<net::SSLClientSocket> sock( + socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + hostname, kDefaultSSLConfig)); - rv = sock.Connect(&callback); + rv = sock->Connect(&callback); if (rv != net::OK) { ASSERT_EQ(rv, net::ERR_IO_PENDING); @@ -166,7 +178,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_Interrupted) { } const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - rv = sock.Write(request_text, arraysize(request_text) - 1, &callback); + rv = sock->Write(request_text, arraysize(request_text) - 1, &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) { @@ -176,7 +188,7 @@ TEST_F(SSLClientSocketTest, DISABLED_Read_Interrupted) { // Do a partial read and then exit. This test should not crash! char buf[512]; - rv = sock.Read(buf, sizeof(buf), &callback); + rv = sock->Read(buf, sizeof(buf), &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) diff --git a/net/base/ssl_client_socket.cc b/net/base/ssl_client_socket_win.cc index d155009..1eeb090 100644 --- a/net/base/ssl_client_socket.cc +++ b/net/base/ssl_client_socket_win.cc @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "net/base/ssl_client_socket.h" +#include "net/base/ssl_client_socket_win.h" #include <schnlsp.h> @@ -90,14 +90,14 @@ static int MapNetErrorToCertStatus(int error) { // 64: >= SSL record trailer (16 or 20 have been observed) static const int kRecvBufferSize = (5 + 16*1024 + 64); -SSLClientSocket::SSLClientSocket(ClientSocket* transport_socket, - const std::string& hostname, - int protocol_version_mask) +SSLClientSocketWin::SSLClientSocketWin(ClientSocket* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config) #pragma warning(suppress: 4355) - : io_callback_(this, &SSLClientSocket::OnIOComplete), + : io_callback_(this, &SSLClientSocketWin::OnIOComplete), transport_(transport_socket), hostname_(hostname), - protocol_version_mask_(protocol_version_mask), + ssl_config_(ssl_config), user_callback_(NULL), user_buf_(NULL), user_buf_len_(0), @@ -119,11 +119,36 @@ SSLClientSocket::SSLClientSocket(ClientSocket* transport_socket, memset(&ctxt_, 0, sizeof(ctxt_)); } -SSLClientSocket::~SSLClientSocket() { +SSLClientSocketWin::~SSLClientSocketWin() { Disconnect(); } -int SSLClientSocket::Connect(CompletionCallback* callback) { +void SSLClientSocketWin::GetSSLInfo(SSLInfo* ssl_info) { + SECURITY_STATUS status = SEC_E_OK; + if (server_cert_ == NULL) { + status = QueryContextAttributes(&ctxt_, + SECPKG_ATTR_REMOTE_CERT_CONTEXT, + &server_cert_); + } + if (status == SEC_E_OK) { + DCHECK(server_cert_); + PCCERT_CONTEXT dup_cert = CertDuplicateCertificateContext(server_cert_); + ssl_info->cert = X509Certificate::CreateFromHandle(dup_cert); + } + SecPkgContext_ConnectionInfo connection_info; + status = QueryContextAttributes(&ctxt_, + SECPKG_ATTR_CONNECTION_INFO, + &connection_info); + if (status == SEC_E_OK) { + // TODO(wtc): compute the overall security strength, taking into account + // dwExchStrength and dwHashStrength. dwExchStrength needs to be + // normalized. + ssl_info->security_bits = connection_info.dwCipherStrength; + } + ssl_info->cert_status = server_cert_status_; +} + +int SSLClientSocketWin::Connect(CompletionCallback* callback) { DCHECK(transport_.get()); DCHECK(next_state_ == STATE_NONE); DCHECK(!user_callback_); @@ -135,12 +160,13 @@ int SSLClientSocket::Connect(CompletionCallback* callback) { return rv; } -int SSLClientSocket::ReconnectIgnoringLastError(CompletionCallback* callback) { +int SSLClientSocketWin::ReconnectIgnoringLastError( + CompletionCallback* callback) { // TODO(darin): implement me! return ERR_FAILED; } -void SSLClientSocket::Disconnect() { +void SSLClientSocketWin::Disconnect() { // TODO(wtc): Send SSL close_notify alert. completed_handshake_ = false; transport_->Disconnect(); @@ -165,7 +191,7 @@ void SSLClientSocket::Disconnect() { bytes_received_ = 0; } -bool SSLClientSocket::IsConnected() const { +bool SSLClientSocketWin::IsConnected() const { // Ideally, we should also check if we have received the close_notify alert // message from the server, and return false in that case. We're not doing // that, so this function may return a false positive. Since the upper @@ -175,8 +201,8 @@ bool SSLClientSocket::IsConnected() const { return completed_handshake_ && transport_->IsConnected(); } -int SSLClientSocket::Read(char* buf, int buf_len, - CompletionCallback* callback) { +int SSLClientSocketWin::Read(char* buf, int buf_len, + CompletionCallback* callback) { DCHECK(completed_handshake_); DCHECK(next_state_ == STATE_NONE); DCHECK(!user_callback_); @@ -213,8 +239,8 @@ int SSLClientSocket::Read(char* buf, int buf_len, return rv; } -int SSLClientSocket::Write(const char* buf, int buf_len, - CompletionCallback* callback) { +int SSLClientSocketWin::Write(const char* buf, int buf_len, + CompletionCallback* callback) { DCHECK(completed_handshake_); DCHECK(next_state_ == STATE_NONE); DCHECK(!user_callback_); @@ -229,32 +255,7 @@ int SSLClientSocket::Write(const char* buf, int buf_len, return rv; } -void SSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) { - SECURITY_STATUS status = SEC_E_OK; - if (server_cert_ == NULL) { - status = QueryContextAttributes(&ctxt_, - SECPKG_ATTR_REMOTE_CERT_CONTEXT, - &server_cert_); - } - if (status == SEC_E_OK) { - DCHECK(server_cert_); - PCCERT_CONTEXT dup_cert = CertDuplicateCertificateContext(server_cert_); - ssl_info->cert = X509Certificate::CreateFromHandle(dup_cert); - } - SecPkgContext_ConnectionInfo connection_info; - status = QueryContextAttributes(&ctxt_, - SECPKG_ATTR_CONNECTION_INFO, - &connection_info); - if (status == SEC_E_OK) { - // TODO(wtc): compute the overall security strength, taking into account - // dwExchStrength and dwHashStrength. dwExchStrength needs to be - // normalized. - ssl_info->security_bits = connection_info.dwCipherStrength; - } - ssl_info->cert_status = server_cert_status_; -} - -void SSLClientSocket::DoCallback(int rv) { +void SSLClientSocketWin::DoCallback(int rv) { DCHECK(rv != ERR_IO_PENDING); DCHECK(user_callback_); @@ -264,13 +265,13 @@ void SSLClientSocket::DoCallback(int rv) { c->Run(rv); } -void SSLClientSocket::OnIOComplete(int result) { +void SSLClientSocketWin::OnIOComplete(int result) { int rv = DoLoop(result); if (rv != ERR_IO_PENDING) DoCallback(rv); } -int SSLClientSocket::DoLoop(int last_io_result) { +int SSLClientSocketWin::DoLoop(int last_io_result) { DCHECK(next_state_ != STATE_NONE); int rv = last_io_result; do { @@ -319,12 +320,12 @@ int SSLClientSocket::DoLoop(int last_io_result) { return rv; } -int SSLClientSocket::DoConnect() { +int SSLClientSocketWin::DoConnect() { next_state_ = STATE_CONNECT_COMPLETE; return transport_->Connect(&io_callback_); } -int SSLClientSocket::DoConnectComplete(int result) { +int SSLClientSocketWin::DoConnectComplete(int result) { if (result < 0) return result; @@ -337,11 +338,11 @@ int SSLClientSocket::DoConnectComplete(int result) { // The global system registry settings take precedence over the value of // schannel_cred.grbitEnabledProtocols. schannel_cred.grbitEnabledProtocols = 0; - if (protocol_version_mask_ & SSL2) + if (ssl_config_.ssl2_enabled) schannel_cred.grbitEnabledProtocols |= SP_PROT_SSL2; - if (protocol_version_mask_ & SSL3) + if (ssl_config_.ssl3_enabled) schannel_cred.grbitEnabledProtocols |= SP_PROT_SSL3; - if (protocol_version_mask_ & TLS1) + if (ssl_config_.tls1_enabled) schannel_cred.grbitEnabledProtocols |= SP_PROT_TLS1; // The default (0) means Schannel selects the protocol, rather than no // protocols are selected. So we have to fail here. @@ -429,7 +430,7 @@ int SSLClientSocket::DoConnectComplete(int result) { return OK; } -int SSLClientSocket::DoHandshakeRead() { +int SSLClientSocketWin::DoHandshakeRead() { next_state_ = STATE_HANDSHAKE_READ_COMPLETE; if (!recv_buffer_.get()) @@ -446,7 +447,7 @@ int SSLClientSocket::DoHandshakeRead() { return transport_->Read(buf, buf_len, &io_callback_); } -int SSLClientSocket::DoHandshakeReadComplete(int result) { +int SSLClientSocketWin::DoHandshakeReadComplete(int result) { if (result < 0) return result; if (result == 0 && !ignore_ok_result_) @@ -576,7 +577,7 @@ int SSLClientSocket::DoHandshakeReadComplete(int result) { return OK; } -int SSLClientSocket::DoHandshakeWrite() { +int SSLClientSocketWin::DoHandshakeWrite() { next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; // We should have something to send. @@ -589,7 +590,7 @@ int SSLClientSocket::DoHandshakeWrite() { return transport_->Write(buf, buf_len, &io_callback_); } -int SSLClientSocket::DoHandshakeWriteComplete(int result) { +int SSLClientSocketWin::DoHandshakeWriteComplete(int result) { if (result < 0) return result; @@ -615,7 +616,7 @@ int SSLClientSocket::DoHandshakeWriteComplete(int result) { return OK; } -int SSLClientSocket::DoPayloadRead() { +int SSLClientSocketWin::DoPayloadRead() { next_state_ = STATE_PAYLOAD_READ_COMPLETE; DCHECK(recv_buffer_.get()); @@ -631,7 +632,7 @@ int SSLClientSocket::DoPayloadRead() { return transport_->Read(buf, buf_len, &io_callback_); } -int SSLClientSocket::DoPayloadReadComplete(int result) { +int SSLClientSocketWin::DoPayloadReadComplete(int result) { if (result < 0) return result; if (result == 0 && !ignore_ok_result_) { @@ -728,7 +729,7 @@ int SSLClientSocket::DoPayloadReadComplete(int result) { return len; } -int SSLClientSocket::DoPayloadEncrypt() { +int SSLClientSocketWin::DoPayloadEncrypt() { DCHECK(user_buf_); DCHECK(user_buf_len_ > 0); @@ -777,7 +778,7 @@ int SSLClientSocket::DoPayloadEncrypt() { return OK; } -int SSLClientSocket::DoPayloadWrite() { +int SSLClientSocketWin::DoPayloadWrite() { next_state_ = STATE_PAYLOAD_WRITE_COMPLETE; // We should have something to send. @@ -790,7 +791,7 @@ int SSLClientSocket::DoPayloadWrite() { return transport_->Write(buf, buf_len, &io_callback_); } -int SSLClientSocket::DoPayloadWriteComplete(int result) { +int SSLClientSocketWin::DoPayloadWriteComplete(int result) { if (result < 0) return result; @@ -815,7 +816,7 @@ int SSLClientSocket::DoPayloadWriteComplete(int result) { return OK; } -int SSLClientSocket::DidCompleteHandshake() { +int SSLClientSocketWin::DidCompleteHandshake() { SECURITY_STATUS status = QueryContextAttributes( &ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_); if (status != SEC_E_OK) { @@ -839,7 +840,7 @@ int SSLClientSocket::DidCompleteHandshake() { return rv; } -int SSLClientSocket::VerifyServerCert() { +int SSLClientSocketWin::VerifyServerCert() { DCHECK(server_cert_); // Build and validate certificate chain. diff --git a/net/base/ssl_client_socket_win.h b/net/base/ssl_client_socket_win.h new file mode 100644 index 0000000..403e7f3 --- /dev/null +++ b/net/base/ssl_client_socket_win.h @@ -0,0 +1,139 @@ +// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_BASE_SSL_CLIENT_SOCKET_WIN_H_ +#define NET_BASE_SSL_CLIENT_SOCKET_WIN_H_ + +#define SECURITY_WIN32 // Needs to be defined before including security.h + +#include <windows.h> +#include <wincrypt.h> +#include <security.h> + +#include <string> + +#include "base/scoped_ptr.h" +#include "net/base/completion_callback.h" +#include "net/base/ssl_client_socket.h" +#include "net/base/ssl_config_service.h" + +namespace net { + +// An SSL client socket implemented with the Windows Schannel. +class SSLClientSocketWin : public SSLClientSocket { + public: + // Takes ownership of the transport_socket, which may already be connected. + // The given hostname will be compared with the name(s) in the server's + // certificate during the SSL handshake. ssl_config specifies the SSL + // settings. + SSLClientSocketWin(ClientSocket* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config); + ~SSLClientSocketWin(); + + // SSLClientSocket methods: + virtual void GetSSLInfo(SSLInfo* ssl_info); + + // ClientSocket methods: + virtual int Connect(CompletionCallback* callback); + virtual int ReconnectIgnoringLastError(CompletionCallback* callback); + virtual void Disconnect(); + virtual bool IsConnected() const; + + // Socket methods: + virtual int Read(char* buf, int buf_len, CompletionCallback* callback); + virtual int Write(const char* buf, int buf_len, CompletionCallback* callback); + + private: + void DoCallback(int result); + void OnIOComplete(int result); + + int DoLoop(int last_io_result); + int DoConnect(); + int DoConnectComplete(int result); + int DoHandshakeRead(); + int DoHandshakeReadComplete(int result); + int DoHandshakeWrite(); + int DoHandshakeWriteComplete(int result); + int DoPayloadRead(); + int DoPayloadReadComplete(int result); + int DoPayloadEncrypt(); + int DoPayloadWrite(); + int DoPayloadWriteComplete(int result); + + int DidCompleteHandshake(); + int VerifyServerCert(); + + CompletionCallbackImpl<SSLClientSocketWin> io_callback_; + scoped_ptr<ClientSocket> transport_; + std::string hostname_; + SSLConfig ssl_config_; + + CompletionCallback* user_callback_; + + // Used by both Read and Write functions. + char* user_buf_; + int user_buf_len_; + + enum State { + STATE_NONE, + STATE_CONNECT, + STATE_CONNECT_COMPLETE, + STATE_HANDSHAKE_READ, + STATE_HANDSHAKE_READ_COMPLETE, + STATE_HANDSHAKE_WRITE, + STATE_HANDSHAKE_WRITE_COMPLETE, + STATE_PAYLOAD_ENCRYPT, + STATE_PAYLOAD_WRITE, + STATE_PAYLOAD_WRITE_COMPLETE, + STATE_PAYLOAD_READ, + STATE_PAYLOAD_READ_COMPLETE, + }; + State next_state_; + + SecPkgContext_StreamSizes stream_sizes_; + PCCERT_CONTEXT server_cert_; + int server_cert_status_; + + CredHandle creds_; + CtxtHandle ctxt_; + SecBuffer send_buffer_; + scoped_array<char> payload_send_buffer_; + int payload_send_buffer_len_; + int bytes_sent_; + + // recv_buffer_ holds the received ciphertext. Since Schannel decrypts + // data in place, sometimes recv_buffer_ may contain decrypted plaintext and + // any undecrypted ciphertext. (Ciphertext is decrypted one full SSL record + // at a time.) + // + // If bytes_decrypted_ is 0, the received ciphertext is at the beginning of + // recv_buffer_, ready to be passed to DecryptMessage. + scoped_array<char> recv_buffer_; + char* decrypted_ptr_; // Points to the decrypted plaintext in recv_buffer_ + int bytes_decrypted_; // The number of bytes of decrypted plaintext. + char* received_ptr_; // Points to the received ciphertext in recv_buffer_ + int bytes_received_; // The number of bytes of received ciphertext. + + bool completed_handshake_; + + // Only used in the STATE_HANDSHAKE_READ_COMPLETE and + // STATE_PAYLOAD_READ_COMPLETE states. True if a 'result' argument of OK + // should be ignored, to prevent it from being interpreted as EOF. + // + // The reason we need this flag is that OK means not only "0 bytes of data + // were read" but also EOF. We set ignore_ok_result_ to true when we need + // to continue processing previously read data without reading more data. + // We have to pass a 'result' of OK to the DoLoop method, and don't want it + // to be interpreted as EOF. + bool ignore_ok_result_; + + // True if the user has no client certificate. + bool no_client_cert_; +}; + +} // namespace net + +#endif // NET_BASE_SSL_CLIENT_SOCKET_WIN_H_ + diff --git a/net/build/net.vcproj b/net/build/net.vcproj index b56b05f..b711b49 100644 --- a/net/build/net.vcproj +++ b/net/build/net.vcproj @@ -433,11 +433,15 @@ > </File> <File - RelativePath="..\base\ssl_client_socket.cc" + RelativePath="..\base\ssl_client_socket.h" > </File> <File - RelativePath="..\base\ssl_client_socket.h" + RelativePath="..\base\ssl_client_socket_win.cc" + > + </File> + <File + RelativePath="..\base\ssl_client_socket_win.h" > </File> <File diff --git a/net/http/http_network_transaction.cc b/net/http/http_network_transaction.cc index 9d2c398..168b99d 100644 --- a/net/http/http_network_transaction.cc +++ b/net/http/http_network_transaction.cc @@ -13,9 +13,7 @@ #include "net/base/host_resolver.h" #include "net/base/load_flags.h" #include "net/base/net_util.h" -#if defined(OS_WIN) #include "net/base/ssl_client_socket.h" -#endif #include "net/base/upload_data_stream.h" #include "net/http/http_auth.h" #include "net/http/http_auth_handler.h" @@ -58,12 +56,7 @@ HttpNetworkTransaction::HttpNetworkTransaction(HttpNetworkSession* session, read_buf_(NULL), read_buf_len_(0), next_state_(STATE_NONE) { -#if defined(OS_WIN) - // TODO(wtc): Use SSL settings (bug 3003). - ssl_version_mask_ = SSLClientSocket::SSL3 | SSLClientSocket::TLS1; -#else - ssl_version_mask_ = 0; // A dummy value so that the code compiles. -#endif + // TODO(wtc): Initialize ssl_config_with SSL settings (bug 3003). } void HttpNetworkTransaction::Destroy() { @@ -89,7 +82,7 @@ int HttpNetworkTransaction::RestartIgnoringLastError( int rv = DoLoop(OK); if (rv == ERR_IO_PENDING) user_callback_ = callback; - return rv; + return rv; } int HttpNetworkTransaction::RestartWithAuth( @@ -482,7 +475,7 @@ int HttpNetworkTransaction::DoConnect() { // wrapper socket now. Otherwise, we need to first issue a CONNECT request. if (using_ssl_ && !using_tunnel_) s = socket_factory_->CreateSSLClientSocket(s, request_->url.host(), - ssl_version_mask_); + ssl_config_); connection_.set_socket(s); return connection_.socket()->Connect(&io_callback_); @@ -510,7 +503,7 @@ int HttpNetworkTransaction::DoSSLConnectOverTunnel() { // Add a SSL socket on top of our existing transport socket. ClientSocket* s = connection_.release_socket(); s = socket_factory_->CreateSSLClientSocket(s, request_->url.host(), - ssl_version_mask_); + ssl_config_); connection_.set_socket(s); return connection_.socket()->Connect(&io_callback_); } @@ -834,13 +827,11 @@ int HttpNetworkTransaction::DidReadResponseHeaders() { } } -#if defined(OS_WIN) if (using_ssl_ && !establishing_tunnel_) { SSLClientSocket* ssl_socket = reinterpret_cast<SSLClientSocket*>(connection_.socket()); ssl_socket->GetSSLInfo(&response_.ssl_info); } -#endif return OK; } @@ -869,25 +860,22 @@ int HttpNetworkTransaction::HandleCertificateError(int error) { } } -#if defined(OS_WIN) if (error != OK) { SSLClientSocket* ssl_socket = reinterpret_cast<SSLClientSocket*>(connection_.socket()); ssl_socket->GetSSLInfo(&response_.ssl_info); } -#endif return error; } int HttpNetworkTransaction::HandleSSLHandshakeError(int error) { -#if defined(OS_WIN) switch (error) { case ERR_SSL_PROTOCOL_ERROR: case ERR_SSL_VERSION_OR_CIPHER_MISMATCH: - if (ssl_version_mask_ & SSLClientSocket::TLS1) { + if (ssl_config_.tls1_enabled) { // This could be a TLS-intolerant server or an SSL 3.0 server that // chose a TLS-only cipher suite. Turn off TLS 1.0 and retry. - ssl_version_mask_ &= ~SSLClientSocket::TLS1; + ssl_config_.tls1_enabled = false; connection_.set_socket(NULL); connection_.Reset(); next_state_ = STATE_INIT_CONNECTION; @@ -895,7 +883,6 @@ int HttpNetworkTransaction::HandleSSLHandshakeError(int error) { } break; } -#endif return error; } @@ -1001,7 +988,7 @@ void HttpNetworkTransaction::AddAuthorizationHeader(HttpAuth::Target target) { // Add auth data to cache session_->auth_cache()->Add(auth_cache_key_[target], auth_data_[target]); - + // Add a Authorization/Proxy-Authorization header line. std::string credentials = auth_handler_[target]->GenerateCredentials( auth_data_[target]->username, diff --git a/net/http/http_network_transaction.h b/net/http/http_network_transaction.h index bbbfb74..475056e 100644 --- a/net/http/http_network_transaction.h +++ b/net/http/http_network_transaction.h @@ -11,6 +11,7 @@ #include "net/base/address_list.h" #include "net/base/client_socket_handle.h" #include "net/base/host_resolver.h" +#include "net/base/ssl_config_service.h" #include "net/http/http_auth.h" #include "net/http/http_auth_handler.h" #include "net/http/http_response_info.h" @@ -186,7 +187,7 @@ class HttpNetworkTransaction : public HttpTransaction { // the real request/response of the transaction. bool establishing_tunnel_; - int ssl_version_mask_; + SSLConfig ssl_config_; std::string request_headers_; size_t request_headers_bytes_sent_; diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index 56fb5cd..9c2d6a7 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -137,7 +137,7 @@ class MockTCPClientSocket : public net::ClientSocket { // Not using mock writes; succeed synchronously. if (!data_->writes) return buf_len; - + // Check that what we are writing matches the expectation. // Then give the mocked return value. MockWrite& w = data_->writes[write_index_]; @@ -185,10 +185,10 @@ class MockClientSocketFactory : public net::ClientSocketFactory { const net::AddressList& addresses) { return new MockTCPClientSocket(addresses); } - virtual net::ClientSocket* CreateSSLClientSocket( + virtual net::SSLClientSocket* CreateSSLClientSocket( net::ClientSocket* transport_socket, const std::string& hostname, - int protocol_version_mask) { + const net::SSLConfig& ssl_config) { return NULL; } }; @@ -623,7 +623,7 @@ TEST_F(HttpNetworkTransactionTest, BasicAuth) { MockRead("HTTP/1.0 401 Unauthorized\r\n"), // Give a couple authenticate options (only the middle one is actually // supported). - MockRead("WWW-Authenticate: Basic\r\n"), // Malformed + MockRead("WWW-Authenticate: Basic\r\n"), // Malformed MockRead("WWW-Authenticate: Basic realm=\"MyRealm1\"\r\n"), MockRead("WWW-Authenticate: UNSUPPORTED realm=\"FOO\"\r\n"), MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), @@ -717,7 +717,7 @@ TEST_F(HttpNetworkTransactionTest, BasicAuthProxyThenServer) { MockRead("HTTP/1.0 407 Unauthorized\r\n"), // Give a couple authenticate options (only the middle one is actually // supported). - MockRead("Proxy-Authenticate: Basic\r\n"), // Malformed + MockRead("Proxy-Authenticate: Basic\r\n"), // Malformed MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"), MockRead("Proxy-Authenticate: UNSUPPORTED realm=\"FOO\"\r\n"), MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), @@ -745,7 +745,7 @@ TEST_F(HttpNetworkTransactionTest, BasicAuthProxyThenServer) { MockRead("WWW-Authenticate: Basic realm=\"MyRealm1\"\r\n"), MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), MockRead("Content-Length: 2000\r\n\r\n"), - MockRead(false, net::ERR_FAILED), // Won't be reached. + MockRead(false, net::ERR_FAILED), // Won't be reached. }; // After calling trans->RestartWithAuth() the second time, we should send diff --git a/net/net.xcodeproj/project.pbxproj b/net/net.xcodeproj/project.pbxproj index 15f2211..a594020 100644 --- a/net/net.xcodeproj/project.pbxproj +++ b/net/net.xcodeproj/project.pbxproj @@ -491,7 +491,6 @@ 7BED32940E5A181C00A747DB /* ssl_config_service.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ssl_config_service.cc; sourceTree = "<group>"; }; 7BED32950E5A181C00A747DB /* ssl_client_socket_unittest.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ssl_client_socket_unittest.cc; sourceTree = "<group>"; }; 7BED32960E5A181C00A747DB /* ssl_client_socket.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ssl_client_socket.h; sourceTree = "<group>"; }; - 7BED32970E5A181C00A747DB /* ssl_client_socket.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ssl_client_socket.cc; sourceTree = "<group>"; }; 7BED32980E5A181C00A747DB /* socket.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = socket.h; sourceTree = "<group>"; }; 7BED32990E5A181C00A747DB /* registry_controlled_domain_unittest.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = registry_controlled_domain_unittest.cc; sourceTree = "<group>"; }; 7BED329A0E5A181C00A747DB /* registry_controlled_domain.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = registry_controlled_domain.h; sourceTree = "<group>"; }; @@ -946,7 +945,6 @@ E49DD2E80E892F8C003C7A87 /* sdch_manager.cc */, E49DD2E90E892F8C003C7A87 /* sdch_manager.h */, 7BED32980E5A181C00A747DB /* socket.h */, - 7BED32970E5A181C00A747DB /* ssl_client_socket.cc */, 7BED32960E5A181C00A747DB /* ssl_client_socket.h */, 7BED32950E5A181C00A747DB /* ssl_client_socket_unittest.cc */, 7BED32940E5A181C00A747DB /* ssl_config_service.cc */, |