diff options
Diffstat (limited to 'net/http')
-rw-r--r-- | net/http/http_network_transaction.cc | 84 | ||||
-rw-r--r-- | net/http/http_network_transaction.h | 16 | ||||
-rw-r--r-- | net/http/http_network_transaction_unittest.cc | 371 |
3 files changed, 373 insertions, 98 deletions
diff --git a/net/http/http_network_transaction.cc b/net/http/http_network_transaction.cc index c549cb3..7ab3021 100644 --- a/net/http/http_network_transaction.cc +++ b/net/http/http_network_transaction.cc @@ -80,9 +80,13 @@ int HttpNetworkTransaction::Start(const HttpRequestInfo* request_info, int HttpNetworkTransaction::RestartIgnoringLastError( CompletionCallback* callback) { - // TODO(wtc): If the connection is no longer alive, call - // connection_.socket()->ReconnectIgnoringLastError(). - next_state_ = STATE_WRITE_HEADERS; + if (connection_.socket()->IsConnected()) { + next_state_ = STATE_WRITE_HEADERS; + } else { + connection_.set_socket(NULL); + connection_.Reset(); + next_state_ = STATE_INIT_CONNECTION; + } int rv = DoLoop(OK); if (rv == ERR_IO_PENDING) user_callback_ = callback; @@ -249,7 +253,7 @@ LoadState HttpNetworkTransaction::GetLoadState() const { return LOAD_STATE_RESOLVING_PROXY_FOR_URL; case STATE_RESOLVE_HOST_COMPLETE: return LOAD_STATE_RESOLVING_HOST; - case STATE_CONNECT_COMPLETE: + case STATE_TCP_CONNECT_COMPLETE: return LOAD_STATE_CONNECTING; case STATE_WRITE_HEADERS_COMPLETE: case STATE_WRITE_BODY_COMPLETE: @@ -408,23 +412,23 @@ int HttpNetworkTransaction::DoLoop(int result) { rv = DoResolveHostComplete(rv); TRACE_EVENT_END("http.resolve_host", request_, request_->url.spec()); break; - case STATE_CONNECT: + case STATE_TCP_CONNECT: DCHECK_EQ(OK, rv); - TRACE_EVENT_BEGIN("http.connect", request_, request_->url.spec()); - rv = DoConnect(); + TRACE_EVENT_BEGIN("http.tcp_connect", request_, request_->url.spec()); + rv = DoTCPConnect(); break; - case STATE_CONNECT_COMPLETE: - rv = DoConnectComplete(rv); - TRACE_EVENT_END("http.connect", request_, request_->url.spec()); + case STATE_TCP_CONNECT_COMPLETE: + rv = DoTCPConnectComplete(rv); + TRACE_EVENT_END("http.tcp_connect", request_, request_->url.spec()); break; - case STATE_SSL_CONNECT_OVER_TUNNEL: + case STATE_SSL_CONNECT: DCHECK_EQ(OK, rv); - TRACE_EVENT_BEGIN("http.ssl_tunnel", request_, request_->url.spec()); - rv = DoSSLConnectOverTunnel(); + TRACE_EVENT_BEGIN("http.ssl_connect", request_, request_->url.spec()); + rv = DoSSLConnect(); break; - case STATE_SSL_CONNECT_OVER_TUNNEL_COMPLETE: - rv = DoSSLConnectOverTunnelComplete(rv); - TRACE_EVENT_END("http.ssl_tunnel", request_, request_->url.spec()); + case STATE_SSL_CONNECT_COMPLETE: + rv = DoSSLConnectComplete(rv); + TRACE_EVENT_END("http.ssl_connect", request_, request_->url.spec()); break; case STATE_WRITE_HEADERS: DCHECK_EQ(OK, rv); @@ -578,48 +582,42 @@ int HttpNetworkTransaction::DoResolveHostComplete(int result) { bool ok = (result == OK); DidFinishDnsResolutionWithStatus(ok, request_->referrer, this); if (ok) { - next_state_ = STATE_CONNECT; + next_state_ = STATE_TCP_CONNECT; } else { result = ReconsiderProxyAfterError(result); } return result; } -int HttpNetworkTransaction::DoConnect() { - next_state_ = STATE_CONNECT_COMPLETE; +int HttpNetworkTransaction::DoTCPConnect() { + next_state_ = STATE_TCP_CONNECT_COMPLETE; DCHECK(!connection_.socket()); ClientSocket* s = socket_factory_->CreateTCPClientSocket(addresses_); - - // If we are using a direct SSL connection, then go ahead and create the SSL - // wrapper socket now. Otherwise, we need to first issue a CONNECT request. - if (using_ssl_ && !using_tunnel_) - s = socket_factory_->CreateSSLClientSocket(s, request_->url.host(), - ssl_config_); - connection_.set_socket(s); return connection_.socket()->Connect(&io_callback_); } -int HttpNetworkTransaction::DoConnectComplete(int result) { - if (IsCertificateError(result)) - result = HandleCertificateError(result); - +int HttpNetworkTransaction::DoTCPConnectComplete(int result) { + // If we are using a direct SSL connection, then go ahead and establish the + // SSL connection now. Otherwise, we need to first issue a CONNECT request. if (result == OK) { - next_state_ = STATE_WRITE_HEADERS; - if (using_tunnel_) - establishing_tunnel_ = true; + if (using_ssl_ && !using_tunnel_) { + next_state_ = STATE_SSL_CONNECT; + } else { + next_state_ = STATE_WRITE_HEADERS; + if (using_tunnel_) + establishing_tunnel_ = true; + } } else { - result = HandleSSLHandshakeError(result); - if (result != OK) - result = ReconsiderProxyAfterError(result); + result = ReconsiderProxyAfterError(result); } return result; } -int HttpNetworkTransaction::DoSSLConnectOverTunnel() { - next_state_ = STATE_SSL_CONNECT_OVER_TUNNEL_COMPLETE; +int HttpNetworkTransaction::DoSSLConnect() { + next_state_ = STATE_SSL_CONNECT_COMPLETE; // Add a SSL socket on top of our existing transport socket. ClientSocket* s = connection_.release_socket(); @@ -629,7 +627,7 @@ int HttpNetworkTransaction::DoSSLConnectOverTunnel() { return connection_.socket()->Connect(&io_callback_); } -int HttpNetworkTransaction::DoSSLConnectOverTunnelComplete(int result) { +int HttpNetworkTransaction::DoSSLConnectComplete(int result) { if (IsCertificateError(result)) result = HandleCertificateError(result); @@ -1019,7 +1017,7 @@ int HttpNetworkTransaction::DidReadResponseHeaders() { // The proxy sent extraneous data after the headers. return ERR_TUNNEL_CONNECTION_FAILED; } - next_state_ = STATE_SSL_CONNECT_OVER_TUNNEL; + next_state_ = STATE_SSL_CONNECT; // Reset for the real request and response headers. request_headers_.clear(); request_headers_bytes_sent_ = 0; @@ -1150,6 +1148,12 @@ int HttpNetworkTransaction::HandleCertificateError(int error) { SSLClientSocket* ssl_socket = reinterpret_cast<SSLClientSocket*>(connection_.socket()); ssl_socket->GetSSLInfo(&response_.ssl_info); + + // Add the bad certificate to the set of allowed certificates in the + // SSL info object. This data structure will be consulted after calling + // RestartIgnoringLastError(). And the user will be asked interactively + // before RestartIgnoringLastError() is ever called. + ssl_config_.allowed_bad_certs_.insert(response_.ssl_info.cert); } return error; } diff --git a/net/http/http_network_transaction.h b/net/http/http_network_transaction.h index 939a900..a966086 100644 --- a/net/http/http_network_transaction.h +++ b/net/http/http_network_transaction.h @@ -72,10 +72,10 @@ class HttpNetworkTransaction : public HttpTransaction { int DoInitConnectionComplete(int result); int DoResolveHost(); int DoResolveHostComplete(int result); - int DoConnect(); - int DoConnectComplete(int result); - int DoSSLConnectOverTunnel(); - int DoSSLConnectOverTunnelComplete(int result); + int DoTCPConnect(); + int DoTCPConnectComplete(int result); + int DoSSLConnect(); + int DoSSLConnectComplete(int result); int DoWriteHeaders(); int DoWriteHeadersComplete(int result); int DoWriteBody(); @@ -296,10 +296,10 @@ class HttpNetworkTransaction : public HttpTransaction { STATE_INIT_CONNECTION_COMPLETE, STATE_RESOLVE_HOST, STATE_RESOLVE_HOST_COMPLETE, - STATE_CONNECT, - STATE_CONNECT_COMPLETE, - STATE_SSL_CONNECT_OVER_TUNNEL, - STATE_SSL_CONNECT_OVER_TUNNEL_COMPLETE, + STATE_TCP_CONNECT, + STATE_TCP_CONNECT_COMPLETE, + STATE_SSL_CONNECT, + STATE_SSL_CONNECT_COMPLETE, STATE_WRITE_HEADERS, STATE_WRITE_HEADERS_COMPLETE, STATE_WRITE_BODY, diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index 4e8209c..de1174e 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -6,6 +6,9 @@ #include "base/compiler_specific.h" #include "net/base/client_socket_factory.h" +#include "net/base/completion_callback.h" +#include "net/base/ssl_client_socket.h" +#include "net/base/ssl_info.h" #include "net/base/test_completion_callback.h" #include "net/base/upload_data.h" #include "net/http/http_auth_handler_ntlm.h" @@ -25,7 +28,8 @@ namespace net { struct MockConnect { // Asynchronous connection success. - MockConnect() : async(true), result(net::OK) { } + MockConnect() : async(true), result(OK) { } + MockConnect(bool a, int r) : async(a), result(r) { } bool async; int result; @@ -62,6 +66,7 @@ typedef MockRead MockWrite; struct MockSocket { MockSocket() : reads(NULL), writes(NULL) { } + MockSocket(MockRead* r, MockWrite* w) : reads(r), writes(w) { } MockConnect connect; MockRead* reads; @@ -76,37 +81,35 @@ struct MockSocket { // MockSocket* mock_sockets[10]; +// MockSSLSockets only need to keep track of the return code from calls to +// Connect(). +struct MockSSLSocket { + MockSSLSocket(bool async, int result) : connect(async, result) { } + + MockConnect connect; +}; +MockSSLSocket* mock_ssl_sockets[10]; + // Index of the next mock_sockets element to use. int mock_sockets_index; +int mock_ssl_sockets_index; -class MockTCPClientSocket : public net::ClientSocket { +class MockClientSocket : public SSLClientSocket { public: - explicit MockTCPClientSocket(const net::AddressList& addresses) - : data_(mock_sockets[mock_sockets_index++]), - ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), + explicit MockClientSocket() + : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), callback_(NULL), - read_index_(0), - read_offset_(0), - write_index_(0), connected_(false) { - DCHECK(data_) << "overran mock_sockets array"; } + // ClientSocket methods: - virtual int Connect(net::CompletionCallback* callback) { - DCHECK(!callback_); - if (connected_) - return net::OK; - connected_ = true; - if (data_->connect.async) { - RunCallbackAsync(callback, data_->connect.result); - return net::ERR_IO_PENDING; - } - return data_->connect.result; - } - virtual int ReconnectIgnoringLastError(net::CompletionCallback* callback) { + virtual int Connect(CompletionCallback* callback) = 0; + + // SSLClientSocket methods: + virtual void GetSSLInfo(SSLInfo* ssl_info) { NOTREACHED(); - return net::ERR_FAILED; } + virtual void Disconnect() { connected_ = false; callback_ = NULL; @@ -118,7 +121,64 @@ class MockTCPClientSocket : public net::ClientSocket { return connected_; } // Socket methods: - virtual int Read(char* buf, int buf_len, net::CompletionCallback* callback) { + virtual int Read(char* buf, int buf_len, + CompletionCallback* callback) = 0; + virtual int Write(const char* buf, int buf_len, + CompletionCallback* callback) = 0; + +#if defined(OS_LINUX) + virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen) { + memset(reinterpret_cast<char *>(name), 0, *namelen); + return OK; + } +#endif + + + protected: + void RunCallbackAsync(CompletionCallback* callback, int result) { + callback_ = callback; + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( + &MockClientSocket::RunCallback, result)); + } + + void RunCallback(int result) { + CompletionCallback* c = callback_; + callback_ = NULL; + if (c) + c->Run(result); + } + + ScopedRunnableMethodFactory<MockClientSocket> method_factory_; + CompletionCallback* callback_; + bool connected_; +}; + +class MockTCPClientSocket : public MockClientSocket { + public: + explicit MockTCPClientSocket(const AddressList& addresses) + : data_(mock_sockets[mock_sockets_index++]), + read_index_(0), + read_offset_(0), + write_index_(0) { + DCHECK(data_) << "overran mock_sockets array"; + } + + // ClientSocket methods: + virtual int Connect(CompletionCallback* callback) { + DCHECK(!callback_); + if (connected_) + return OK; + connected_ = true; + if (data_->connect.async) { + RunCallbackAsync(callback, data_->connect.result); + return ERR_IO_PENDING; + } + return data_->connect.result; + } + + // Socket methods: + virtual int Read(char* buf, int buf_len, CompletionCallback* callback) { DCHECK(!callback_); MockRead& r = data_->reads[read_index_]; int result = r.result; @@ -137,12 +197,13 @@ class MockTCPClientSocket : public net::ClientSocket { } if (r.async) { RunCallbackAsync(callback, result); - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } return result; } + virtual int Write(const char* buf, int buf_len, - net::CompletionCallback* callback) { + CompletionCallback* callback) { DCHECK(buf); DCHECK(buf_len > 0); DCHECK(!callback_); @@ -159,49 +220,123 @@ class MockTCPClientSocket : public net::ClientSocket { std::string actual_data(buf, buf_len); EXPECT_EQ(expected_data, actual_data); if (expected_data != actual_data) - return net::ERR_UNEXPECTED; - if (result == net::OK) + return ERR_UNEXPECTED; + if (result == OK) result = w.data_len; } if (w.async) { RunCallbackAsync(callback, result); - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } return result; } + private: - void RunCallbackAsync(net::CompletionCallback* callback, int result) { - callback_ = callback; - MessageLoop::current()->PostTask(FROM_HERE, - method_factory_.NewRunnableMethod( - &MockTCPClientSocket::RunCallback, result)); - } - void RunCallback(int result) { - net::CompletionCallback* c = callback_; - callback_ = NULL; - if (c) - c->Run(result); - } MockSocket* data_; - ScopedRunnableMethodFactory<MockTCPClientSocket> method_factory_; - net::CompletionCallback* callback_; int read_index_; int read_offset_; int write_index_; - bool connected_; }; -class MockClientSocketFactory : public net::ClientSocketFactory { +class MockSSLClientSocket : public MockClientSocket { public: - virtual net::ClientSocket* CreateTCPClientSocket( - const net::AddressList& addresses) { + explicit MockSSLClientSocket( + ClientSocket* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config) + : transport_(transport_socket), + data_(mock_ssl_sockets[mock_ssl_sockets_index++]) { + DCHECK(data_) << "overran mock_ssl_sockets array"; + } + + ~MockSSLClientSocket() { + Disconnect(); + } + + virtual void GetSSLInfo(SSLInfo* ssl_info) { + ssl_info->Reset(); + } + + friend class ConnectCallback; + class ConnectCallback : + public CompletionCallbackImpl<ConnectCallback> { + public: + ConnectCallback(MockSSLClientSocket *ssl_client_socket, + CompletionCallback* user_callback, + int rv) + : ALLOW_THIS_IN_INITIALIZER_LIST( + CompletionCallbackImpl<ConnectCallback>( + this, &ConnectCallback::Wrapper)), + ssl_client_socket_(ssl_client_socket), + user_callback_(user_callback), + rv_(rv) { + } + + private: + void Wrapper(int rv) { + if (rv_ == OK) + ssl_client_socket_->connected_ = true; + user_callback_->Run(rv_); + delete this; + } + + MockSSLClientSocket* ssl_client_socket_; + CompletionCallback* user_callback_; + int rv_; + }; + + virtual int Connect(CompletionCallback* callback) { + DCHECK(!callback_); + ConnectCallback* connect_callback = new ConnectCallback( + this, callback, data_->connect.result); + int rv = transport_->Connect(connect_callback); + if (rv == OK) { + delete connect_callback; + if (data_->connect.async) { + RunCallbackAsync(callback, data_->connect.result); + return ERR_IO_PENDING; + } + if (data_->connect.result == OK) + connected_ = true; + return data_->connect.result; + } + return rv; + } + + virtual void Disconnect() { + MockClientSocket::Disconnect(); + if (transport_ != NULL) + transport_->Disconnect(); + } + + // Socket methods: + virtual int Read(char* buf, int buf_len, CompletionCallback* callback) { + DCHECK(!callback_); + return transport_->Read(buf, buf_len, callback); + } + + virtual int Write(const char* buf, int buf_len, + CompletionCallback* callback) { + DCHECK(!callback_); + return transport_->Write(buf, buf_len, callback); + } + + private: + scoped_ptr<ClientSocket> transport_; + MockSSLSocket* data_; +}; + +class MockClientSocketFactory : public ClientSocketFactory { + public: + virtual ClientSocket* CreateTCPClientSocket( + const AddressList& addresses) { return new MockTCPClientSocket(addresses); } - virtual net::SSLClientSocket* CreateSSLClientSocket( - net::ClientSocket* transport_socket, + virtual SSLClientSocket* CreateSSLClientSocket( + ClientSocket* transport_socket, const std::string& hostname, - const net::SSLConfig& ssl_config) { - return NULL; + const SSLConfig& ssl_config) { + return new MockSSLClientSocket(transport_socket, hostname, ssl_config); } }; @@ -229,6 +364,7 @@ class HttpNetworkTransactionTest : public PlatformTest { PlatformTest::SetUp(); mock_sockets[0] = NULL; mock_sockets_index = 0; + mock_ssl_sockets_index = 0; } virtual void TearDown() { @@ -2711,4 +2847,139 @@ TEST_F(HttpNetworkTransactionTest, ResetStateForRestart) { EXPECT_FALSE(trans->response_.vary_data.is_valid()); } +// Test HTTPS connections to a site with a bad certificate +TEST_F(HttpNetworkTransactionTest, HTTPSBadCertificate) { + scoped_ptr<ProxyService> proxy_service(CreateNullProxyService()); + scoped_ptr<HttpTransaction> trans(new HttpNetworkTransaction( + CreateSession(proxy_service.get()), &mock_socket_factory)); + + HttpRequestInfo request; + request.method = "GET"; + request.url = GURL("https://www.google.com/"); + request.load_flags = 0; + + MockWrite data_writes[] = { + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + + MockRead data_reads[] = { + MockRead("HTTP/1.0 200 OK\r\n"), + MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), + MockRead("Content-Length: 100\r\n\r\n"), + MockRead(false, OK), + }; + + MockSocket bad_certificate_socket; + MockSocket data_socket(data_reads, data_writes); + MockSSLSocket ssl_bad(true, ERR_CERT_AUTHORITY_INVALID); + MockSSLSocket ssl(true, OK); + + mock_sockets[0] = &bad_certificate_socket; + mock_sockets[1] = &data_socket; + mock_sockets[2] = NULL; + + mock_ssl_sockets[0] = &ssl_bad; + mock_ssl_sockets[1] = &ssl; + mock_ssl_sockets[2] = NULL; + + TestCompletionCallback callback; + + int rv = trans->Start(&request, &callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(ERR_CERT_AUTHORITY_INVALID, rv); + + rv = trans->RestartIgnoringLastError(&callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + const HttpResponseInfo* response = trans->GetResponseInfo(); + + EXPECT_FALSE(response == NULL); + EXPECT_EQ(100, response->headers->GetContentLength()); +} + +// Test HTTPS connections to a site with a bad certificate, going through a +// proxy +TEST_F(HttpNetworkTransactionTest, HTTPSBadCertificateViaProxy) { + scoped_ptr<ProxyService> proxy_service( + CreateFixedProxyService("myproxy:70")); + + HttpRequestInfo request; + request.method = "GET"; + request.url = GURL("https://www.google.com/"); + request.load_flags = 0; + + MockWrite proxy_writes[] = { + MockWrite("CONNECT www.google.com:443 HTTP/1.1\r\n" + "Host: www.google.com\r\n\r\n"), + }; + + MockRead proxy_reads[] = { + MockRead("HTTP/1.0 200 Connected\r\n\r\n"), + MockRead(false, net::OK) + }; + + MockWrite data_writes[] = { + MockWrite("CONNECT www.google.com:443 HTTP/1.1\r\n" + "Host: www.google.com\r\n\r\n"), + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + + MockRead data_reads[] = { + MockRead("HTTP/1.0 200 Connected\r\n\r\n"), + MockRead("HTTP/1.0 200 OK\r\n"), + MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), + MockRead("Content-Length: 100\r\n\r\n"), + MockRead(false, OK), + }; + + MockSocket bad_certificate_socket(proxy_reads, proxy_writes); + MockSocket data_socket(data_reads, data_writes); + MockSSLSocket ssl_bad(true, ERR_CERT_AUTHORITY_INVALID); + MockSSLSocket ssl(true, OK); + + mock_sockets[0] = &bad_certificate_socket; + mock_sockets[1] = &data_socket; + mock_sockets[2] = NULL; + + mock_ssl_sockets[0] = &ssl_bad; + mock_ssl_sockets[1] = &ssl; + mock_ssl_sockets[2] = NULL; + + TestCompletionCallback callback; + + for (int i = 0; i < 2; i++) { + mock_sockets_index = 0; + mock_ssl_sockets_index = 0; + + scoped_ptr<HttpTransaction> trans(new HttpNetworkTransaction( + CreateSession(proxy_service.get()), &mock_socket_factory)); + + int rv = trans->Start(&request, &callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(ERR_CERT_AUTHORITY_INVALID, rv); + + rv = trans->RestartIgnoringLastError(&callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + const HttpResponseInfo* response = trans->GetResponseInfo(); + + EXPECT_FALSE(response == NULL); + EXPECT_EQ(100, response->headers->GetContentLength()); + } +} + } // namespace net |