From bacff6543fbb0df029aae780589c4a7274b5ce34 Mon Sep 17 00:00:00 2001 From: "markus@chromium.org" Date: Tue, 31 Mar 2009 17:50:33 +0000 Subject: Resubmitted code from revision 12809. The bug in the Windows SSL stack that this code originally uncovered has been fixed in a separate changelist. git-svn-id: svn://svn.chromium.org/chrome/trunk/src@12876 0039d316-1c4b-4281-b951-d872f2087c98 --- net/base/client_socket.h | 5 - net/base/client_socket_pool_unittest.cc | 3 - net/base/ssl_client_socket.h | 3 +- net/base/ssl_client_socket_mac.cc | 14 +- net/base/ssl_client_socket_mac.h | 1 - net/base/ssl_client_socket_nss.cc | 129 +++++++-- net/base/ssl_client_socket_nss.h | 15 +- net/base/ssl_client_socket_unittest.cc | 62 ++++- net/base/ssl_client_socket_win.cc | 14 +- net/base/ssl_client_socket_win.h | 1 - net/base/ssl_config_service.h | 9 + net/base/tcp_client_socket.h | 1 - net/base/tcp_client_socket_libevent.cc | 5 - net/base/tcp_client_socket_win.cc | 5 - net/http/http_network_transaction.cc | 80 +++--- net/http/http_network_transaction.h | 16 +- net/http/http_network_transaction_unittest.cc | 371 ++++++++++++++++++++++---- net/url_request/url_request_unittest.cc | 66 ++++- net/url_request/url_request_unittest.h | 20 +- 19 files changed, 643 insertions(+), 177 deletions(-) (limited to 'net') diff --git a/net/base/client_socket.h b/net/base/client_socket.h index 4d25b38..3553d52 100644 --- a/net/base/client_socket.h +++ b/net/base/client_socket.h @@ -33,11 +33,6 @@ class ClientSocket : public Socket { // virtual int Connect(CompletionCallback* callback) = 0; - // If a non-fatal error occurs during Connect, the consumer can call this - // method to re-Connect ignoring the error that occured. This call is only - // valid for certain errors. - virtual int ReconnectIgnoringLastError(CompletionCallback* callback) = 0; - // Called to disconnect a connected socket. Does nothing if the socket is // already disconnected. After calling Disconnect it is possible to call // Connect again to establish a new connection. diff --git a/net/base/client_socket_pool_unittest.cc b/net/base/client_socket_pool_unittest.cc index 8849781..dd51819 100644 --- a/net/base/client_socket_pool_unittest.cc +++ b/net/base/client_socket_pool_unittest.cc @@ -32,9 +32,6 @@ class MockClientSocket : public net::ClientSocket { connected_ = true; return net::OK; } - virtual int ReconnectIgnoringLastError(net::CompletionCallback* callback) { - return net::ERR_FAILED; - } virtual void Disconnect() { connected_ = false; } diff --git a/net/base/ssl_client_socket.h b/net/base/ssl_client_socket.h index 21b120b..8c9f05b 100644 --- a/net/base/ssl_client_socket.h +++ b/net/base/ssl_client_socket.h @@ -15,8 +15,7 @@ class SSLInfo; // // NOTE: The SSL handshake occurs within the Connect method after a TCP // connection is established. If a SSL error occurs during the handshake, -// Connect will fail. The consumer may choose to ignore certain SSL errors, -// such as a name mismatch, by calling ReconnectIgnoringLastError. +// Connect will fail. // class SSLClientSocket : public ClientSocket { public: diff --git a/net/base/ssl_client_socket_mac.cc b/net/base/ssl_client_socket_mac.cc index 5ed6c67..c43a4d1 100644 --- a/net/base/ssl_client_socket_mac.cc +++ b/net/base/ssl_client_socket_mac.cc @@ -279,12 +279,6 @@ int SSLClientSocketMac::Connect(CompletionCallback* callback) { return rv; } -int SSLClientSocketMac::ReconnectIgnoringLastError( - CompletionCallback* callback) { - // TODO(darin): implement me! - return ERR_FAILED; -} - void SSLClientSocketMac::Disconnect() { completed_handshake_ = false; @@ -451,7 +445,13 @@ int SSLClientSocketMac::DoLoop(int last_io_result) { int SSLClientSocketMac::DoConnect() { next_state_ = STATE_CONNECT_COMPLETE; - return transport_->Connect(&io_callback_); + + // The caller has to make sure that the transport socket is connected. If + // it isn't, we will eventually fail when trying to negotiate an SSL session. + // But we cannot call transport_->Connect(), as we do not know if there is + // any proxy negotiation that needs to be performed prior to establishing + // the SSL session. + return OK; } int SSLClientSocketMac::DoConnectComplete(int result) { diff --git a/net/base/ssl_client_socket_mac.h b/net/base/ssl_client_socket_mac.h index f195940..2efa031c 100644 --- a/net/base/ssl_client_socket_mac.h +++ b/net/base/ssl_client_socket_mac.h @@ -34,7 +34,6 @@ class SSLClientSocketMac : public SSLClientSocket { // ClientSocket methods: virtual int Connect(CompletionCallback* callback); - virtual int ReconnectIgnoringLastError(CompletionCallback* callback); virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; diff --git a/net/base/ssl_client_socket_nss.cc b/net/base/ssl_client_socket_nss.cc index 4777ddc..aeccee3 100644 --- a/net/base/ssl_client_socket_nss.cc +++ b/net/base/ssl_client_socket_nss.cc @@ -15,6 +15,7 @@ #include #undef Lock +#include "base/compiler_specific.h" #include "base/logging.h" #include "base/nss_init.h" #include "base/string_util.h" @@ -129,6 +130,34 @@ int SSLClientSocketNSS::Init() { return OK; } +// As part of Connect(), the SSLClientSocketNSS object performs an SSL +// handshake. This requires network IO, which in turn calls +// BufferRecvComplete() with a non-zero byte count. This byte count eventually +// winds its way through the state machine and ends up being passed to the +// callback. For Read() and Write(), that's what we want. But for Connect(), +// the caller expects OK (i.e. 0) for success. +// +// The ConnectCallbackWrapper object changes the argument that gets passed +// to the callback function. Any positive value gets turned into OK. +class ConnectCallbackWrapper : + public CompletionCallbackImpl { + public: + ConnectCallbackWrapper(CompletionCallback* user_callback) + : ALLOW_THIS_IN_INITIALIZER_LIST( + CompletionCallbackImpl(this, + &ConnectCallbackWrapper::ReturnValueWrapper)), + user_callback_(user_callback) { + } + + private: + void ReturnValueWrapper(int rv) { + user_callback_->Run(rv > OK ? OK : rv); + delete this; + } + + CompletionCallback* user_callback_; +}; + int SSLClientSocketNSS::Connect(CompletionCallback* callback) { EnterFunction(""); DCHECK(transport_.get()); @@ -138,28 +167,38 @@ int SSLClientSocketNSS::Connect(CompletionCallback* callback) { GotoState(STATE_CONNECT); int rv = DoLoop(OK); if (rv == ERR_IO_PENDING) - user_callback_ = callback; + user_callback_ = new ConnectCallbackWrapper(callback); LeaveFunction(""); - return rv; + return rv > OK ? OK : rv; } -int SSLClientSocketNSS::ReconnectIgnoringLastError( - CompletionCallback* callback) { - EnterFunction(""); - // TODO(darin): implement me! - LeaveFunction(""); - return ERR_FAILED; +void SSLClientSocketNSS::InvalidateSessionIfBadCertificate() { + if (UpdateServerCert() != NULL && + ssl_config_.allowed_bad_certs_.count(server_cert_)) { + SSL_InvalidateSession(nss_fd_); + } } void SSLClientSocketNSS::Disconnect() { EnterFunction(""); + + // Reset object state + transport_send_busy_ = false; + transport_recv_busy_ = false; + user_buf_ = NULL; + user_buf_len_ = 0; + server_cert_error_ = OK; + completed_handshake_ = false; + nss_bufs_ = NULL; + // TODO(wtc): Send SSL close_notify alert. if (nss_fd_ != NULL) { + InvalidateSessionIfBadCertificate(); PR_Close(nss_fd_); nss_fd_ = NULL; } - completed_handshake_ = false; + transport_->Disconnect(); LeaveFunction(""); } @@ -229,6 +268,20 @@ int SSLClientSocketNSS::Write(const char* buf, int buf_len, return rv; } +X509Certificate *SSLClientSocketNSS::UpdateServerCert() { + // We set the server_cert_ from OwnAuthCertHandler(), but this handler + // does not necessarily get called if we are continuing a cached SSL + // session. + if (server_cert_ == NULL) { + X509Certificate::OSCertHandle nss_cert = SSL_PeerCertificate(nss_fd_); + if (nss_cert) { + server_cert_ = X509Certificate::CreateFromHandle( + nss_cert, X509Certificate::SOURCE_FROM_NETWORK); + } + } + return server_cert_; +} + void SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { EnterFunction(""); ssl_info->Reset(); @@ -248,13 +301,12 @@ void SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { LOG(DFATAL) << "SSL_GetCipherSuiteInfo returned " << PR_GetError() << " for cipherSuite " << channel_info.cipherSuite; } + UpdateServerCert(); } if (server_cert_error_ != net::OK) ssl_info->SetCertError(server_cert_error_); - X509Certificate::OSCertHandle nss_cert = SSL_PeerCertificate(nss_fd_); - if (nss_cert) - ssl_info->cert = X509Certificate::CreateFromHandle(nss_cert, - X509Certificate::SOURCE_FROM_NETWORK); + DCHECK(server_cert_ != NULL); + ssl_info->cert = server_cert_; LeaveFunction(""); } @@ -355,7 +407,6 @@ void SSLClientSocketNSS::BufferRecvComplete(int result) { LeaveFunction(""); } - int SSLClientSocketNSS::DoLoop(int last_io_result) { EnterFunction(last_io_result); bool network_moved; @@ -409,20 +460,49 @@ int SSLClientSocketNSS::DoLoop(int last_io_result) { int SSLClientSocketNSS::DoConnect() { EnterFunction(""); GotoState(STATE_CONNECT_COMPLETE); - return transport_->Connect(&io_callback_); + + // The caller has to make sure that the transport socket is connected. If + // it isn't, we will eventually fail when trying to negotiate an SSL session. + // But we cannot call transport_->Connect(), as we do not know if there is + // any proxy negotiation that needs to be performed prior to establishing + // the SSL session. + return OK; +} + +// static +// NSS calls this if an incoming certificate needs to be verified. +SECStatus SSLClientSocketNSS::OwnAuthCertHandler(void* arg, + PRFileDesc* socket, + PRBool checksig, + PRBool is_server) { + SSLClientSocketNSS* that = reinterpret_cast(arg); + + // Remember the certificate as it will no longer be accessible if the + // handshake fails. + that->UpdateServerCert(); + + return SSL_AuthCertificate(CERT_GetDefaultCertDB(), socket, checksig, + is_server); } // static // NSS calls this if an incoming certificate is invalid. -SECStatus SSLClientSocketNSS::OwnBadCertHandler(void* arg, PRFileDesc* socket) { +SECStatus SSLClientSocketNSS::OwnBadCertHandler(void* arg, + PRFileDesc* socket) { SSLClientSocketNSS* that = reinterpret_cast(arg); + + if (that->server_cert_ && + that->ssl_config_.allowed_bad_certs_.count(that->server_cert_)) { + LOG(INFO) << "accepting bad SSL certificate, as user told us to"; + + return SECSuccess; + } PRErrorCode prerr = PR_GetError(); that->server_cert_error_ = NetErrorFromNSPRError(prerr); LOG(INFO) << "server certificate is invalid; NSS error code " << prerr << ", net error " << that->server_cert_error_; - // Return SECSuccess to override the problem. - // Chromium wants it to succeed here, and may abort the connection later. - return SECSuccess; + + return SECFailure; } int SSLClientSocketNSS::DoConnectComplete(int result) { @@ -503,6 +583,10 @@ int SSLClientSocketNSS::DoConnectComplete(int result) { if (rv != SECSuccess) return ERR_UNEXPECTED; + rv = SSL_AuthCertificateHook(nss_fd_, OwnAuthCertHandler, this); + if (rv != SECSuccess) + return ERR_UNEXPECTED; + rv = SSL_BadCertHook(nss_fd_, OwnBadCertHandler, this); if (rv != SECSuccess) return ERR_UNEXPECTED; @@ -520,11 +604,14 @@ int SSLClientSocketNSS::DoConnectComplete(int result) { int SSLClientSocketNSS::DoHandshakeRead() { EnterFunction(""); - int net_error; + int net_error = net::OK; int rv = SSL_ForceHandshake(nss_fd_); if (rv == SECSuccess) { - net_error = server_cert_error_; + DCHECK(server_cert_error_ == net::OK); + + InvalidateSessionIfBadCertificate(); + // there's a callback for this, too completed_handshake_ = true; // Done! diff --git a/net/base/ssl_client_socket_nss.h b/net/base/ssl_client_socket_nss.h index 9b77f94..89fe99e 100644 --- a/net/base/ssl_client_socket_nss.h +++ b/net/base/ssl_client_socket_nss.h @@ -17,6 +17,8 @@ namespace net { +class X509Certificate; + // An SSL client socket implemented with Mozilla NSS. class SSLClientSocketNSS : public SSLClientSocket { public: @@ -34,7 +36,6 @@ class SSLClientSocketNSS : public SSLClientSocket { // ClientSocket methods: virtual int Connect(CompletionCallback* callback); - virtual int ReconnectIgnoringLastError(CompletionCallback* callback); virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; @@ -44,6 +45,8 @@ class SSLClientSocketNSS : public SSLClientSocket { virtual int Write(const char* buf, int buf_len, CompletionCallback* callback); private: + void InvalidateSessionIfBadCertificate(); + X509Certificate* UpdateServerCert(); void DoCallback(int result); void OnIOComplete(int result); @@ -59,7 +62,12 @@ class SSLClientSocketNSS : public SSLClientSocket { void BufferSendComplete(int result); void BufferRecvComplete(int result); - // nss calls this on error. We pass 'this' as the first argument. + // NSS calls this when checking certificates. We pass 'this' as the first + // argument. + static SECStatus OwnAuthCertHandler(void* arg, PRFileDesc* socket, + PRBool checksig, PRBool is_server); + + // NSS calls this on error. We pass 'this' as the first argument. static SECStatus OwnBadCertHandler(void* arg, PRFileDesc* socket); CompletionCallbackImpl buffer_send_callback_; @@ -81,6 +89,9 @@ class SSLClientSocketNSS : public SSLClientSocket { // Set when handshake finishes. Value is net error code, see net_errors.h int server_cert_error_; + // Set during handshake. + scoped_refptr server_cert_; + bool completed_handshake_; enum State { diff --git a/net/base/ssl_client_socket_unittest.cc b/net/base/ssl_client_socket_unittest.cc index 7b551fe..7c7b170 100644 --- a/net/base/ssl_client_socket_unittest.cc +++ b/net/base/ssl_client_socket_unittest.cc @@ -79,8 +79,14 @@ TEST_F(SSLClientSocketTest, MAYBE_Connect) { &addr, NULL); EXPECT_EQ(net::OK, rv); + net::ClientSocket *transport = new net::TCPClientSocket(addr); + rv = transport->Connect(&callback); + if (rv == net::ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(net::OK, rv); + scoped_ptr sock( - socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + socket_factory_->CreateSSLClientSocket(transport, server_.kHostName, kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); @@ -111,8 +117,14 @@ TEST_F(SSLClientSocketTest, MAYBE_ConnectExpired) { &addr, NULL); EXPECT_EQ(net::OK, rv); + net::ClientSocket *transport = new net::TCPClientSocket(addr); + rv = transport->Connect(&callback); + if (rv == net::ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(net::OK, rv); + scoped_ptr sock( - socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + socket_factory_->CreateSSLClientSocket(transport, server_.kHostName, kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); @@ -126,7 +138,9 @@ TEST_F(SSLClientSocketTest, MAYBE_ConnectExpired) { EXPECT_EQ(net::ERR_CERT_DATE_INVALID, rv); } - EXPECT_TRUE(sock->IsConnected()); + // We cannot test sock->IsConnected(), as the NSS implementation disconnects + // the socket when it encounters an error, whereas other implementations + // leave it connected. } TEST_F(SSLClientSocketTest, MAYBE_ConnectMismatched) { @@ -140,8 +154,14 @@ TEST_F(SSLClientSocketTest, MAYBE_ConnectMismatched) { &addr, NULL); EXPECT_EQ(net::OK, rv); + net::ClientSocket *transport = new net::TCPClientSocket(addr); + rv = transport->Connect(&callback); + if (rv == net::ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(net::OK, rv); + scoped_ptr sock( - socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + socket_factory_->CreateSSLClientSocket(transport, server_.kMismatchedHostName, kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); @@ -155,13 +175,9 @@ TEST_F(SSLClientSocketTest, MAYBE_ConnectMismatched) { EXPECT_EQ(net::ERR_CERT_COMMON_NAME_INVALID, rv); } - // The Windows code happens to keep the connection - // open now in spite of an error. The designers of - // this API intended to also allow the connection - // to be closed on error, in which case the caller - // should call ReconnectIgnoringLastError, but - // that's currently unimplemented. - EXPECT_TRUE(sock->IsConnected()); + // We cannot test sock->IsConnected(), as the NSS implementation disconnects + // the socket when it encounters an error, whereas other implementations + // leave it connected. } // TODO(wtc): Add unit tests for IsConnectedAndIdle: @@ -183,8 +199,14 @@ TEST_F(SSLClientSocketTest, MAYBE_Read) { rv = callback.WaitForResult(); EXPECT_EQ(net::OK, rv); + net::ClientSocket *transport = new net::TCPClientSocket(addr); + rv = transport->Connect(&callback); + if (rv == net::ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(net::OK, rv); + scoped_ptr sock( - socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + socket_factory_->CreateSSLClientSocket(transport, server_.kHostName, kDefaultSSLConfig)); @@ -231,8 +253,14 @@ TEST_F(SSLClientSocketTest, MAYBE_Read_SmallChunks) { &addr, NULL); EXPECT_EQ(net::OK, rv); + net::ClientSocket *transport = new net::TCPClientSocket(addr); + rv = transport->Connect(&callback); + if (rv == net::ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(net::OK, rv); + scoped_ptr sock( - socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + socket_factory_->CreateSSLClientSocket(transport, server_.kHostName, kDefaultSSLConfig)); rv = sock->Connect(&callback); @@ -277,8 +305,14 @@ TEST_F(SSLClientSocketTest, MAYBE_Read_Interrupted) { &addr, NULL); EXPECT_EQ(net::OK, rv); + net::ClientSocket *transport = new net::TCPClientSocket(addr); + rv = transport->Connect(&callback); + if (rv == net::ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(net::OK, rv); + scoped_ptr sock( - socket_factory_->CreateSSLClientSocket(new net::TCPClientSocket(addr), + socket_factory_->CreateSSLClientSocket(transport, server_.kHostName, kDefaultSSLConfig)); rv = sock->Connect(&callback); diff --git a/net/base/ssl_client_socket_win.cc b/net/base/ssl_client_socket_win.cc index 96484fa..70ef6a5 100644 --- a/net/base/ssl_client_socket_win.cc +++ b/net/base/ssl_client_socket_win.cc @@ -280,12 +280,6 @@ int SSLClientSocketWin::Connect(CompletionCallback* callback) { return rv; } -int SSLClientSocketWin::ReconnectIgnoringLastError( - CompletionCallback* callback) { - // TODO(darin): implement me! - return ERR_FAILED; -} - void SSLClientSocketWin::Disconnect() { // TODO(wtc): Send SSL close_notify alert. completed_handshake_ = false; @@ -450,7 +444,13 @@ int SSLClientSocketWin::DoLoop(int last_io_result) { int SSLClientSocketWin::DoConnect() { next_state_ = STATE_CONNECT_COMPLETE; - return transport_->Connect(&io_callback_); + + // The caller has to make sure that the transport socket is connected. If + // it isn't, we will eventually fail when trying to negotiate an SSL session. + // But we cannot call transport_->Connect(), as we do not know if there is + // any proxy negotiation that needs to be performed prior to establishing + // the SSL session. + return OK; } int SSLClientSocketWin::DoConnectComplete(int result) { diff --git a/net/base/ssl_client_socket_win.h b/net/base/ssl_client_socket_win.h index 68eb24d..a3935f6 100644 --- a/net/base/ssl_client_socket_win.h +++ b/net/base/ssl_client_socket_win.h @@ -39,7 +39,6 @@ class SSLClientSocketWin : public SSLClientSocket { // ClientSocket methods: virtual int Connect(CompletionCallback* callback); - virtual int ReconnectIgnoringLastError(CompletionCallback* callback); virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; diff --git a/net/base/ssl_config_service.h b/net/base/ssl_config_service.h index 96d42a2..dec6fdb 100644 --- a/net/base/ssl_config_service.h +++ b/net/base/ssl_config_service.h @@ -5,7 +5,10 @@ #ifndef NET_BASE_SSL_CONFIG_SERVICE_H__ #define NET_BASE_SSL_CONFIG_SERVICE_H__ +#include + #include "base/time.h" +#include "net/base/x509_certificate.h" namespace net { @@ -23,6 +26,12 @@ struct SSLConfig { bool ssl2_enabled; // True if SSL 2.0 is enabled. bool ssl3_enabled; // True if SSL 3.0 is enabled. bool tls1_enabled; // True if TLS 1.0 is enabled. + + // Add any known-bad SSL certificates to allowed_bad_certs_ that should not + // trigger an ERR_CERT_*_INVALID error when calling SSLClientSocket::Connect. + // This would normally be done in response to the user explicitly accepting + // the bad certificate. + std::set > allowed_bad_certs_; }; // This class is responsible for getting and setting the SSL configuration. diff --git a/net/base/tcp_client_socket.h b/net/base/tcp_client_socket.h index b846836..729bbbc 100644 --- a/net/base/tcp_client_socket.h +++ b/net/base/tcp_client_socket.h @@ -46,7 +46,6 @@ class TCPClientSocket : public ClientSocket, // ClientSocket methods: virtual int Connect(CompletionCallback* callback); - virtual int ReconnectIgnoringLastError(CompletionCallback* callback); virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; diff --git a/net/base/tcp_client_socket_libevent.cc b/net/base/tcp_client_socket_libevent.cc index 204b700..8d51dbf 100644 --- a/net/base/tcp_client_socket_libevent.cc +++ b/net/base/tcp_client_socket_libevent.cc @@ -126,11 +126,6 @@ int TCPClientSocket::Connect(CompletionCallback* callback) { return ERR_IO_PENDING; } -int TCPClientSocket::ReconnectIgnoringLastError(CompletionCallback* callback) { - // No ignorable errors! - return ERR_UNEXPECTED; -} - void TCPClientSocket::Disconnect() { if (socket_ == kInvalidSocket) return; diff --git a/net/base/tcp_client_socket_win.cc b/net/base/tcp_client_socket_win.cc index 7c10d2d..9d69505 100644 --- a/net/base/tcp_client_socket_win.cc +++ b/net/base/tcp_client_socket_win.cc @@ -106,11 +106,6 @@ int TCPClientSocket::Connect(CompletionCallback* callback) { return ERR_IO_PENDING; } -int TCPClientSocket::ReconnectIgnoringLastError(CompletionCallback* callback) { - // No ignorable errors! - return ERR_UNEXPECTED; -} - void TCPClientSocket::Disconnect() { if (socket_ == INVALID_SOCKET) return; diff --git a/net/http/http_network_transaction.cc b/net/http/http_network_transaction.cc index c549cb3..36a752f 100644 --- a/net/http/http_network_transaction.cc +++ b/net/http/http_network_transaction.cc @@ -80,9 +80,13 @@ int HttpNetworkTransaction::Start(const HttpRequestInfo* request_info, int HttpNetworkTransaction::RestartIgnoringLastError( CompletionCallback* callback) { - // TODO(wtc): If the connection is no longer alive, call - // connection_.socket()->ReconnectIgnoringLastError(). - next_state_ = STATE_WRITE_HEADERS; + if (connection_.socket()->IsConnected()) { + next_state_ = STATE_WRITE_HEADERS; + } else { + connection_.set_socket(NULL); + connection_.Reset(); + next_state_ = STATE_INIT_CONNECTION; + } int rv = DoLoop(OK); if (rv == ERR_IO_PENDING) user_callback_ = callback; @@ -249,7 +253,7 @@ LoadState HttpNetworkTransaction::GetLoadState() const { return LOAD_STATE_RESOLVING_PROXY_FOR_URL; case STATE_RESOLVE_HOST_COMPLETE: return LOAD_STATE_RESOLVING_HOST; - case STATE_CONNECT_COMPLETE: + case STATE_TCP_CONNECT_COMPLETE: return LOAD_STATE_CONNECTING; case STATE_WRITE_HEADERS_COMPLETE: case STATE_WRITE_BODY_COMPLETE: @@ -408,23 +412,23 @@ int HttpNetworkTransaction::DoLoop(int result) { rv = DoResolveHostComplete(rv); TRACE_EVENT_END("http.resolve_host", request_, request_->url.spec()); break; - case STATE_CONNECT: + case STATE_TCP_CONNECT: DCHECK_EQ(OK, rv); TRACE_EVENT_BEGIN("http.connect", request_, request_->url.spec()); - rv = DoConnect(); + rv = DoTCPConnect(); break; - case STATE_CONNECT_COMPLETE: - rv = DoConnectComplete(rv); + case STATE_TCP_CONNECT_COMPLETE: + rv = DoTCPConnectComplete(rv); TRACE_EVENT_END("http.connect", request_, request_->url.spec()); break; - case STATE_SSL_CONNECT_OVER_TUNNEL: + case STATE_SSL_CONNECT: DCHECK_EQ(OK, rv); - TRACE_EVENT_BEGIN("http.ssl_tunnel", request_, request_->url.spec()); - rv = DoSSLConnectOverTunnel(); + TRACE_EVENT_BEGIN("http.ssl_connect", request_, request_->url.spec()); + rv = DoSSLConnect(); break; - case STATE_SSL_CONNECT_OVER_TUNNEL_COMPLETE: - rv = DoSSLConnectOverTunnelComplete(rv); - TRACE_EVENT_END("http.ssl_tunnel", request_, request_->url.spec()); + case STATE_SSL_CONNECT_COMPLETE: + rv = DoSSLConnectComplete(rv); + TRACE_EVENT_END("http.ssl_connect", request_, request_->url.spec()); break; case STATE_WRITE_HEADERS: DCHECK_EQ(OK, rv); @@ -578,48 +582,42 @@ int HttpNetworkTransaction::DoResolveHostComplete(int result) { bool ok = (result == OK); DidFinishDnsResolutionWithStatus(ok, request_->referrer, this); if (ok) { - next_state_ = STATE_CONNECT; + next_state_ = STATE_TCP_CONNECT; } else { result = ReconsiderProxyAfterError(result); } return result; } -int HttpNetworkTransaction::DoConnect() { - next_state_ = STATE_CONNECT_COMPLETE; +int HttpNetworkTransaction::DoTCPConnect() { + next_state_ = STATE_TCP_CONNECT_COMPLETE; DCHECK(!connection_.socket()); ClientSocket* s = socket_factory_->CreateTCPClientSocket(addresses_); - - // If we are using a direct SSL connection, then go ahead and create the SSL - // wrapper socket now. Otherwise, we need to first issue a CONNECT request. - if (using_ssl_ && !using_tunnel_) - s = socket_factory_->CreateSSLClientSocket(s, request_->url.host(), - ssl_config_); - connection_.set_socket(s); return connection_.socket()->Connect(&io_callback_); } -int HttpNetworkTransaction::DoConnectComplete(int result) { - if (IsCertificateError(result)) - result = HandleCertificateError(result); - +int HttpNetworkTransaction::DoTCPConnectComplete(int result) { + // If we are using a direct SSL connection, then go ahead and establish the + // SSL connection, now. Otherwise, we need to first issue a CONNECT request. if (result == OK) { - next_state_ = STATE_WRITE_HEADERS; - if (using_tunnel_) - establishing_tunnel_ = true; + if (using_ssl_ && !using_tunnel_) { + next_state_ = STATE_SSL_CONNECT; + } else { + next_state_ = STATE_WRITE_HEADERS; + if (using_tunnel_) + establishing_tunnel_ = true; + } } else { - result = HandleSSLHandshakeError(result); - if (result != OK) - result = ReconsiderProxyAfterError(result); + result = ReconsiderProxyAfterError(result); } return result; } -int HttpNetworkTransaction::DoSSLConnectOverTunnel() { - next_state_ = STATE_SSL_CONNECT_OVER_TUNNEL_COMPLETE; +int HttpNetworkTransaction::DoSSLConnect() { + next_state_ = STATE_SSL_CONNECT_COMPLETE; // Add a SSL socket on top of our existing transport socket. ClientSocket* s = connection_.release_socket(); @@ -629,7 +627,7 @@ int HttpNetworkTransaction::DoSSLConnectOverTunnel() { return connection_.socket()->Connect(&io_callback_); } -int HttpNetworkTransaction::DoSSLConnectOverTunnelComplete(int result) { +int HttpNetworkTransaction::DoSSLConnectComplete(int result) { if (IsCertificateError(result)) result = HandleCertificateError(result); @@ -1019,7 +1017,7 @@ int HttpNetworkTransaction::DidReadResponseHeaders() { // The proxy sent extraneous data after the headers. return ERR_TUNNEL_CONNECTION_FAILED; } - next_state_ = STATE_SSL_CONNECT_OVER_TUNNEL; + next_state_ = STATE_SSL_CONNECT; // Reset for the real request and response headers. request_headers_.clear(); request_headers_bytes_sent_ = 0; @@ -1150,6 +1148,12 @@ int HttpNetworkTransaction::HandleCertificateError(int error) { SSLClientSocket* ssl_socket = reinterpret_cast(connection_.socket()); ssl_socket->GetSSLInfo(&response_.ssl_info); + + // Add the bad certificate to the set of allowed certificates in the + // SSL info object. This data structure will be consulted after calling + // RestartIgnoringLastError(). And the user will be asked interactively + // before RestartIgnoringLastError() is ever called. + ssl_config_.allowed_bad_certs_.insert(response_.ssl_info.cert); } return error; } diff --git a/net/http/http_network_transaction.h b/net/http/http_network_transaction.h index 939a900..a966086 100644 --- a/net/http/http_network_transaction.h +++ b/net/http/http_network_transaction.h @@ -72,10 +72,10 @@ class HttpNetworkTransaction : public HttpTransaction { int DoInitConnectionComplete(int result); int DoResolveHost(); int DoResolveHostComplete(int result); - int DoConnect(); - int DoConnectComplete(int result); - int DoSSLConnectOverTunnel(); - int DoSSLConnectOverTunnelComplete(int result); + int DoTCPConnect(); + int DoTCPConnectComplete(int result); + int DoSSLConnect(); + int DoSSLConnectComplete(int result); int DoWriteHeaders(); int DoWriteHeadersComplete(int result); int DoWriteBody(); @@ -296,10 +296,10 @@ class HttpNetworkTransaction : public HttpTransaction { STATE_INIT_CONNECTION_COMPLETE, STATE_RESOLVE_HOST, STATE_RESOLVE_HOST_COMPLETE, - STATE_CONNECT, - STATE_CONNECT_COMPLETE, - STATE_SSL_CONNECT_OVER_TUNNEL, - STATE_SSL_CONNECT_OVER_TUNNEL_COMPLETE, + STATE_TCP_CONNECT, + STATE_TCP_CONNECT_COMPLETE, + STATE_SSL_CONNECT, + STATE_SSL_CONNECT_COMPLETE, STATE_WRITE_HEADERS, STATE_WRITE_HEADERS_COMPLETE, STATE_WRITE_BODY, diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index 4e8209c..05d1ac7 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -6,6 +6,9 @@ #include "base/compiler_specific.h" #include "net/base/client_socket_factory.h" +#include "net/base/completion_callback.h" +#include "net/base/ssl_client_socket.h" +#include "net/base/ssl_info.h" #include "net/base/test_completion_callback.h" #include "net/base/upload_data.h" #include "net/http/http_auth_handler_ntlm.h" @@ -25,7 +28,8 @@ namespace net { struct MockConnect { // Asynchronous connection success. - MockConnect() : async(true), result(net::OK) { } + MockConnect() : async(true), result(OK) { } + MockConnect(bool a, int r) : async(a), result(r) { } bool async; int result; @@ -62,6 +66,7 @@ typedef MockRead MockWrite; struct MockSocket { MockSocket() : reads(NULL), writes(NULL) { } + MockSocket(MockRead* r, MockWrite* w) : reads(r), writes(w) { } MockConnect connect; MockRead* reads; @@ -76,37 +81,35 @@ struct MockSocket { // MockSocket* mock_sockets[10]; +// MockSSLSockets only need to keep track of the return code from calls to +// Connect(). +struct MockSSLSocket { + MockSSLSocket(bool async, int result) : connect(async, result) { } + + MockConnect connect; +}; +MockSSLSocket* mock_ssl_sockets[10]; + // Index of the next mock_sockets element to use. int mock_sockets_index; +int mock_ssl_sockets_index; -class MockTCPClientSocket : public net::ClientSocket { +class MockClientSocket : public SSLClientSocket { public: - explicit MockTCPClientSocket(const net::AddressList& addresses) - : data_(mock_sockets[mock_sockets_index++]), - ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), + explicit MockClientSocket() + : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), callback_(NULL), - read_index_(0), - read_offset_(0), - write_index_(0), connected_(false) { - DCHECK(data_) << "overran mock_sockets array"; } + // ClientSocket methods: - virtual int Connect(net::CompletionCallback* callback) { - DCHECK(!callback_); - if (connected_) - return net::OK; - connected_ = true; - if (data_->connect.async) { - RunCallbackAsync(callback, data_->connect.result); - return net::ERR_IO_PENDING; - } - return data_->connect.result; - } - virtual int ReconnectIgnoringLastError(net::CompletionCallback* callback) { + virtual int Connect(CompletionCallback* callback) = 0; + + // SSLClientSocket methods: + virtual void GetSSLInfo(SSLInfo* ssl_info) { NOTREACHED(); - return net::ERR_FAILED; } + virtual void Disconnect() { connected_ = false; callback_ = NULL; @@ -118,7 +121,64 @@ class MockTCPClientSocket : public net::ClientSocket { return connected_; } // Socket methods: - virtual int Read(char* buf, int buf_len, net::CompletionCallback* callback) { + virtual int Read(char* buf, int buf_len, + CompletionCallback* callback) = 0; + virtual int Write(const char* buf, int buf_len, + CompletionCallback* callback) = 0; + +#if defined(OS_LINUX) + virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen) { + memset(reinterpret_cast(name), 0, *namelen); + return OK; + } +#endif + + + protected: + void RunCallbackAsync(CompletionCallback* callback, int result) { + callback_ = callback; + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( + &MockClientSocket::RunCallback, result)); + } + + void RunCallback(int result) { + CompletionCallback* c = callback_; + callback_ = NULL; + if (c) + c->Run(result); + } + + ScopedRunnableMethodFactory method_factory_; + CompletionCallback* callback_; + bool connected_; +}; + +class MockTCPClientSocket : public MockClientSocket { + public: + explicit MockTCPClientSocket(const AddressList& addresses) + : data_(mock_sockets[mock_sockets_index++]), + read_index_(0), + read_offset_(0), + write_index_(0) { + DCHECK(data_) << "overran mock_sockets array"; + } + + // ClientSocket methods: + virtual int Connect(CompletionCallback* callback) { + DCHECK(!callback_); + if (connected_) + return OK; + connected_ = true; + if (data_->connect.async) { + RunCallbackAsync(callback, data_->connect.result); + return ERR_IO_PENDING; + } + return data_->connect.result; + } + + // Socket methods: + virtual int Read(char* buf, int buf_len, CompletionCallback* callback) { DCHECK(!callback_); MockRead& r = data_->reads[read_index_]; int result = r.result; @@ -137,12 +197,13 @@ class MockTCPClientSocket : public net::ClientSocket { } if (r.async) { RunCallbackAsync(callback, result); - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } return result; } + virtual int Write(const char* buf, int buf_len, - net::CompletionCallback* callback) { + CompletionCallback* callback) { DCHECK(buf); DCHECK(buf_len > 0); DCHECK(!callback_); @@ -159,49 +220,123 @@ class MockTCPClientSocket : public net::ClientSocket { std::string actual_data(buf, buf_len); EXPECT_EQ(expected_data, actual_data); if (expected_data != actual_data) - return net::ERR_UNEXPECTED; - if (result == net::OK) + return ERR_UNEXPECTED; + if (result == OK) result = w.data_len; } if (w.async) { RunCallbackAsync(callback, result); - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } return result; } + private: - void RunCallbackAsync(net::CompletionCallback* callback, int result) { - callback_ = callback; - MessageLoop::current()->PostTask(FROM_HERE, - method_factory_.NewRunnableMethod( - &MockTCPClientSocket::RunCallback, result)); - } - void RunCallback(int result) { - net::CompletionCallback* c = callback_; - callback_ = NULL; - if (c) - c->Run(result); - } MockSocket* data_; - ScopedRunnableMethodFactory method_factory_; - net::CompletionCallback* callback_; int read_index_; int read_offset_; int write_index_; - bool connected_; }; -class MockClientSocketFactory : public net::ClientSocketFactory { +class MockSSLClientSocket : public MockClientSocket { public: - virtual net::ClientSocket* CreateTCPClientSocket( - const net::AddressList& addresses) { + explicit MockSSLClientSocket( + ClientSocket* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config) + : transport_(transport_socket), + data_(mock_ssl_sockets[mock_ssl_sockets_index++]) { + DCHECK(data_) << "overran mock_ssl_sockets array"; + } + + ~MockSSLClientSocket() { + Disconnect(); + } + + virtual void GetSSLInfo(SSLInfo* ssl_info) { + ssl_info->Reset(); + } + + friend class ConnectCallback; + class ConnectCallback : + public CompletionCallbackImpl { + public: + ConnectCallback(MockSSLClientSocket *ssl_client_socket, + CompletionCallback* user_callback, + int rv) + : ALLOW_THIS_IN_INITIALIZER_LIST( + CompletionCallbackImpl( + this, &ConnectCallback::Wrapper)), + ssl_client_socket_(ssl_client_socket), + user_callback_(user_callback), + rv_(rv) { + } + + private: + void Wrapper(int rv) { + if (rv_ == OK) + ssl_client_socket_->connected_ = true; + user_callback_->Run(rv_); + delete this; + } + + MockSSLClientSocket* ssl_client_socket_; + CompletionCallback* user_callback_; + int rv_; + }; + + virtual int Connect(CompletionCallback* callback) { + DCHECK(!callback_); + ConnectCallback* connect_callback = new ConnectCallback( + this, callback, data_->connect.result); + int rv = transport_->Connect(connect_callback); + if (rv == OK) { + delete connect_callback; + if (data_->connect.async) { + RunCallbackAsync(callback, data_->connect.result); + return ERR_IO_PENDING; + } + if (data_->connect.result == OK) + connected_ = true; + return data_->connect.result; + } + return rv; + } + + virtual void Disconnect() { + MockClientSocket::Disconnect(); + if (transport_ != NULL) + transport_->Disconnect(); + } + + // Socket methods: + virtual int Read(char* buf, int buf_len, CompletionCallback* callback) { + DCHECK(!callback_); + return transport_->Read(buf, buf_len, callback); + } + + virtual int Write(const char* buf, int buf_len, + CompletionCallback* callback) { + DCHECK(!callback_); + return transport_->Write(buf, buf_len, callback); + } + + private: + scoped_ptr transport_; + MockSSLSocket* data_; +}; + +class MockClientSocketFactory : public ClientSocketFactory { + public: + virtual ClientSocket* CreateTCPClientSocket( + const AddressList& addresses) { return new MockTCPClientSocket(addresses); } - virtual net::SSLClientSocket* CreateSSLClientSocket( - net::ClientSocket* transport_socket, + virtual SSLClientSocket* CreateSSLClientSocket( + ClientSocket* transport_socket, const std::string& hostname, - const net::SSLConfig& ssl_config) { - return NULL; + const SSLConfig& ssl_config) { + return new MockSSLClientSocket(transport_socket, hostname, ssl_config); } }; @@ -229,6 +364,7 @@ class HttpNetworkTransactionTest : public PlatformTest { PlatformTest::SetUp(); mock_sockets[0] = NULL; mock_sockets_index = 0; + mock_ssl_sockets_index = 0; } virtual void TearDown() { @@ -2711,4 +2847,139 @@ TEST_F(HttpNetworkTransactionTest, ResetStateForRestart) { EXPECT_FALSE(trans->response_.vary_data.is_valid()); } +// Test HTTPS connections to a site with a bad certificate +TEST_F(HttpNetworkTransactionTest, HTTPSBadCertificate) { + scoped_ptr proxy_service(CreateNullProxyService()); + scoped_ptr trans(new HttpNetworkTransaction( + CreateSession(proxy_service.get()), &mock_socket_factory)); + + HttpRequestInfo request; + request.method = "GET"; + request.url = GURL("https://www.google.com/"); + request.load_flags = 0; + + MockWrite data_writes[] = { + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + + MockRead data_reads[] = { + MockRead("HTTP/1.0 200 OK\r\n"), + MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), + MockRead("Content-Length: 100\r\n\r\n"), + MockRead(false, OK), + }; + + MockSocket ssl_bad_certificate; + MockSocket data(data_reads, data_writes); + MockSSLSocket ssl_bad(true, ERR_CERT_AUTHORITY_INVALID); + MockSSLSocket ssl(true, OK); + + mock_sockets[0] = &ssl_bad_certificate; + mock_sockets[1] = &data; + mock_sockets[2] = NULL; + + mock_ssl_sockets[0] = &ssl_bad; + mock_ssl_sockets[1] = &ssl; + mock_ssl_sockets[2] = NULL; + + TestCompletionCallback callback; + + int rv = trans->Start(&request, &callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(ERR_CERT_AUTHORITY_INVALID, rv); + + rv = trans->RestartIgnoringLastError(&callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + const HttpResponseInfo* response = trans->GetResponseInfo(); + + EXPECT_FALSE(response == NULL); + EXPECT_EQ(100, response->headers->GetContentLength()); +} + +// Test HTTPS connections to a site with a bad certificate, going through a +// proxy +TEST_F(HttpNetworkTransactionTest, HTTPSBadCertificateViaProxy) { + scoped_ptr proxy_service( + CreateFixedProxyService("myproxy:70")); + + HttpRequestInfo request; + request.method = "GET"; + request.url = GURL("https://www.google.com/"); + request.load_flags = 0; + + MockWrite proxy_writes[] = { + MockWrite("CONNECT www.google.com:443 HTTP/1.1\r\n" + "Host: www.google.com\r\n\r\n"), + }; + + MockRead proxy_reads[] = { + MockRead("HTTP/1.0 200 Connected\r\n\r\n"), + MockRead(false, net::OK) + }; + + MockWrite data_writes[] = { + MockWrite("CONNECT www.google.com:443 HTTP/1.1\r\n" + "Host: www.google.com\r\n\r\n"), + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + + MockRead data_reads[] = { + MockRead("HTTP/1.0 200 Connected\r\n\r\n"), + MockRead("HTTP/1.0 200 OK\r\n"), + MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), + MockRead("Content-Length: 100\r\n\r\n"), + MockRead(false, OK), + }; + + MockSocket ssl_bad_certificate(proxy_reads, proxy_writes); + MockSocket data(data_reads, data_writes); + MockSSLSocket ssl_bad(true, ERR_CERT_AUTHORITY_INVALID); + MockSSLSocket ssl(true, OK); + + mock_sockets[0] = &ssl_bad_certificate; + mock_sockets[1] = &data; + mock_sockets[2] = NULL; + + mock_ssl_sockets[0] = &ssl_bad; + mock_ssl_sockets[1] = &ssl; + mock_ssl_sockets[2] = NULL; + + TestCompletionCallback callback; + + for (int i = 0; i < 2; i++) { + mock_sockets_index = 0; + mock_ssl_sockets_index = 0; + + scoped_ptr trans(new HttpNetworkTransaction( + CreateSession(proxy_service.get()), &mock_socket_factory)); + + int rv = trans->Start(&request, &callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(ERR_CERT_AUTHORITY_INVALID, rv); + + rv = trans->RestartIgnoringLastError(&callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + const HttpResponseInfo* response = trans->GetResponseInfo(); + + EXPECT_FALSE(response == NULL); + EXPECT_EQ(100, response->headers->GetContentLength()); + } +} + } // namespace net diff --git a/net/url_request/url_request_unittest.cc b/net/url_request/url_request_unittest.cc index fb4a650..fbfaafe 100644 --- a/net/url_request/url_request_unittest.cc +++ b/net/url_request/url_request_unittest.cc @@ -201,9 +201,13 @@ class HTTPSRequestTest : public testing::Test { #if defined(OS_MACOSX) // ssl_client_socket_mac.cc crashes currently in GetSSLInfo // when called on a connection with an unrecognized certificate -#define MAYBE_HTTPSGetTest DISABLED_HTTPSGetTest +#define MAYBE_HTTPSGetTest DISABLED_HTTPSGetTest +#define MAYBE_HTTPSMismatchedTest DISABLED_HTTPSMismatchedTest +#define MAYBE_HTTPSExpiredTest DISABLED_HTTPSExpiredTest #else -#define MAYBE_HTTPSGetTest HTTPSGetTest +#define MAYBE_HTTPSGetTest HTTPSGetTest +#define MAYBE_HTTPSMismatchedTest HTTPSMismatchedTest +#define MAYBE_HTTPSExpiredTest HTTPSExpiredTest #endif TEST_F(HTTPSRequestTest, MAYBE_HTTPSGetTest) { @@ -233,7 +237,63 @@ TEST_F(HTTPSRequestTest, MAYBE_HTTPSGetTest) { #endif } -// TODO(dkegel): add test for expired and mismatched certificates here +TEST_F(HTTPSRequestTest, MAYBE_HTTPSMismatchedTest) { + scoped_refptr server = + HTTPSTestServer::CreateMismatchedServer(L"net/data/ssl"); + ASSERT_TRUE(NULL != server.get()); + + bool err_allowed = true; + for (int i = 0; i < 2 ; i++, err_allowed = !err_allowed) { + TestDelegate d; + { + d.set_allow_certificate_errors(err_allowed); + TestURLRequest r(server->TestServerPage(""), &d); + + r.Start(); + EXPECT_TRUE(r.is_pending()); + + MessageLoop::current()->Run(); + + EXPECT_EQ(1, d.response_started_count()); + EXPECT_FALSE(d.received_data_before_response()); + EXPECT_TRUE(d.have_certificate_errors()); + if (err_allowed) + EXPECT_NE(0, d.bytes_received()); + else + EXPECT_EQ(0, d.bytes_received()); + } + } +} + +TEST_F(HTTPSRequestTest, MAYBE_HTTPSExpiredTest) { + scoped_refptr server = + HTTPSTestServer::CreateExpiredServer(L"net/data/ssl"); + ASSERT_TRUE(NULL != server.get()); + + // Iterate from false to true, just so that we do the opposite of the + // previous test in order to increase test coverage. + bool err_allowed = false; + for (int i = 0; i < 2 ; i++, err_allowed = !err_allowed) { + TestDelegate d; + { + d.set_allow_certificate_errors(err_allowed); + TestURLRequest r(server->TestServerPage(""), &d); + + r.Start(); + EXPECT_TRUE(r.is_pending()); + + MessageLoop::current()->Run(); + + EXPECT_EQ(1, d.response_started_count()); + EXPECT_FALSE(d.received_data_before_response()); + EXPECT_TRUE(d.have_certificate_errors()); + if (err_allowed) + EXPECT_NE(0, d.bytes_received()); + else + EXPECT_EQ(0, d.bytes_received()); + } + } +} TEST_F(URLRequestTest, CancelTest) { TestDelegate d; diff --git a/net/url_request/url_request_unittest.h b/net/url_request/url_request_unittest.h index a3627ad..88b83a6 100644 --- a/net/url_request/url_request_unittest.h +++ b/net/url_request/url_request_unittest.h @@ -69,11 +69,13 @@ class TestDelegate : public URLRequest::Delegate { cancel_in_rd_(false), cancel_in_rd_pending_(false), quit_on_complete_(true), + allow_certificate_errors_(false), response_started_count_(0), received_bytes_count_(0), received_redirect_count_(0), received_data_before_response_(false), request_failed_(false), + have_certificate_errors_(false), buf_(new net::IOBuffer(kBufferSize)) { } @@ -158,10 +160,14 @@ class TestDelegate : public URLRequest::Delegate { virtual void OnSSLCertificateError(URLRequest* request, int cert_error, net::X509Certificate* cert) { - // Ignore SSL errors, we test the server is started and shut it down by - // performing GETs, no security restrictions should apply as we always want - // these GETs to go through. - request->ContinueDespiteLastError(); + // The caller can control whether it needs all SSL requests to go through, + // independent of any possible errors, or whether it wants SSL errors to + // cancel the request. + have_certificate_errors_ = true; + if (allow_certificate_errors_) + request->ContinueDespiteLastError(); + else + request->Cancel(); } void set_cancel_in_received_redirect(bool val) { cancel_in_rr_ = val; } @@ -171,6 +177,9 @@ class TestDelegate : public URLRequest::Delegate { cancel_in_rd_pending_ = val; } void set_quit_on_complete(bool val) { quit_on_complete_ = val; } + void set_allow_certificate_errors(bool val) { + allow_certificate_errors_ = val; + } void set_username(const std::wstring& u) { username_ = u; } void set_password(const std::wstring& p) { password_ = p; } @@ -183,6 +192,7 @@ class TestDelegate : public URLRequest::Delegate { return received_data_before_response_; } bool request_failed() const { return request_failed_; } + bool have_certificate_errors() const { return have_certificate_errors_; } private: static const int kBufferSize = 4096; @@ -192,6 +202,7 @@ class TestDelegate : public URLRequest::Delegate { bool cancel_in_rd_; bool cancel_in_rd_pending_; bool quit_on_complete_; + bool allow_certificate_errors_; std::wstring username_; std::wstring password_; @@ -202,6 +213,7 @@ class TestDelegate : public URLRequest::Delegate { int received_redirect_count_; bool received_data_before_response_; bool request_failed_; + bool have_certificate_errors_; std::string data_received_; // our read buffer -- cgit v1.1