diff options
author | markus@chromium.org <markus@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-03-30 21:09:30 +0000 |
---|---|---|
committer | markus@chromium.org <markus@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-03-30 21:09:30 +0000 |
commit | 3fd49f9bdbe56a9648cddc015bdb8bac02fe1a7b (patch) | |
tree | ee0a2b866a939e678bedf426e858871706c41bff /net/http/http_network_transaction_unittest.cc | |
parent | f463787972e54c126d23d263613634d5fd777789 (diff) | |
download | chromium_src-3fd49f9bdbe56a9648cddc015bdb8bac02fe1a7b.zip chromium_src-3fd49f9bdbe56a9648cddc015bdb8bac02fe1a7b.tar.gz chromium_src-3fd49f9bdbe56a9648cddc015bdb8bac02fe1a7b.tar.bz2 |
Change the bad-certificate handler for SSL (using NSS) to return an
error.
This requires a few additional changes in the rest of the code. In
particular, we now have to teach HttpNetworkTransaction about how to
restart connections with bad certificates. This was originally
intended to be done by ReconnectIgnoringLastError(), but that API
turns out be very difficult to implement in the SSLClientSocket. So,
instead, we just create a completely new SSLClientSocket.
We also have to be careful to store a copy of the certificate from
within the bad-certificate handler, as it won't be available by the
time GetSSLInfo() is called.
And we fix a bug that would cause us to erroneously talk SSL on
reconnected TCP sockets, even though we were still supposed to
negotiate a proxy tunnel first.
Review URL: http://codereview.chromium.org/43115
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@12809 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net/http/http_network_transaction_unittest.cc')
-rw-r--r-- | net/http/http_network_transaction_unittest.cc | 371 |
1 files changed, 321 insertions, 50 deletions
diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index 4e8209c..de1174e 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -6,6 +6,9 @@ #include "base/compiler_specific.h" #include "net/base/client_socket_factory.h" +#include "net/base/completion_callback.h" +#include "net/base/ssl_client_socket.h" +#include "net/base/ssl_info.h" #include "net/base/test_completion_callback.h" #include "net/base/upload_data.h" #include "net/http/http_auth_handler_ntlm.h" @@ -25,7 +28,8 @@ namespace net { struct MockConnect { // Asynchronous connection success. - MockConnect() : async(true), result(net::OK) { } + MockConnect() : async(true), result(OK) { } + MockConnect(bool a, int r) : async(a), result(r) { } bool async; int result; @@ -62,6 +66,7 @@ typedef MockRead MockWrite; struct MockSocket { MockSocket() : reads(NULL), writes(NULL) { } + MockSocket(MockRead* r, MockWrite* w) : reads(r), writes(w) { } MockConnect connect; MockRead* reads; @@ -76,37 +81,35 @@ struct MockSocket { // MockSocket* mock_sockets[10]; +// MockSSLSockets only need to keep track of the return code from calls to +// Connect(). +struct MockSSLSocket { + MockSSLSocket(bool async, int result) : connect(async, result) { } + + MockConnect connect; +}; +MockSSLSocket* mock_ssl_sockets[10]; + // Index of the next mock_sockets element to use. int mock_sockets_index; +int mock_ssl_sockets_index; -class MockTCPClientSocket : public net::ClientSocket { +class MockClientSocket : public SSLClientSocket { public: - explicit MockTCPClientSocket(const net::AddressList& addresses) - : data_(mock_sockets[mock_sockets_index++]), - ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), + explicit MockClientSocket() + : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), callback_(NULL), - read_index_(0), - read_offset_(0), - write_index_(0), connected_(false) { - DCHECK(data_) << "overran mock_sockets array"; } + // ClientSocket methods: - virtual int Connect(net::CompletionCallback* callback) { - DCHECK(!callback_); - if (connected_) - return net::OK; - connected_ = true; - if (data_->connect.async) { - RunCallbackAsync(callback, data_->connect.result); - return net::ERR_IO_PENDING; - } - return data_->connect.result; - } - virtual int ReconnectIgnoringLastError(net::CompletionCallback* callback) { + virtual int Connect(CompletionCallback* callback) = 0; + + // SSLClientSocket methods: + virtual void GetSSLInfo(SSLInfo* ssl_info) { NOTREACHED(); - return net::ERR_FAILED; } + virtual void Disconnect() { connected_ = false; callback_ = NULL; @@ -118,7 +121,64 @@ class MockTCPClientSocket : public net::ClientSocket { return connected_; } // Socket methods: - virtual int Read(char* buf, int buf_len, net::CompletionCallback* callback) { + virtual int Read(char* buf, int buf_len, + CompletionCallback* callback) = 0; + virtual int Write(const char* buf, int buf_len, + CompletionCallback* callback) = 0; + +#if defined(OS_LINUX) + virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen) { + memset(reinterpret_cast<char *>(name), 0, *namelen); + return OK; + } +#endif + + + protected: + void RunCallbackAsync(CompletionCallback* callback, int result) { + callback_ = callback; + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( + &MockClientSocket::RunCallback, result)); + } + + void RunCallback(int result) { + CompletionCallback* c = callback_; + callback_ = NULL; + if (c) + c->Run(result); + } + + ScopedRunnableMethodFactory<MockClientSocket> method_factory_; + CompletionCallback* callback_; + bool connected_; +}; + +class MockTCPClientSocket : public MockClientSocket { + public: + explicit MockTCPClientSocket(const AddressList& addresses) + : data_(mock_sockets[mock_sockets_index++]), + read_index_(0), + read_offset_(0), + write_index_(0) { + DCHECK(data_) << "overran mock_sockets array"; + } + + // ClientSocket methods: + virtual int Connect(CompletionCallback* callback) { + DCHECK(!callback_); + if (connected_) + return OK; + connected_ = true; + if (data_->connect.async) { + RunCallbackAsync(callback, data_->connect.result); + return ERR_IO_PENDING; + } + return data_->connect.result; + } + + // Socket methods: + virtual int Read(char* buf, int buf_len, CompletionCallback* callback) { DCHECK(!callback_); MockRead& r = data_->reads[read_index_]; int result = r.result; @@ -137,12 +197,13 @@ class MockTCPClientSocket : public net::ClientSocket { } if (r.async) { RunCallbackAsync(callback, result); - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } return result; } + virtual int Write(const char* buf, int buf_len, - net::CompletionCallback* callback) { + CompletionCallback* callback) { DCHECK(buf); DCHECK(buf_len > 0); DCHECK(!callback_); @@ -159,49 +220,123 @@ class MockTCPClientSocket : public net::ClientSocket { std::string actual_data(buf, buf_len); EXPECT_EQ(expected_data, actual_data); if (expected_data != actual_data) - return net::ERR_UNEXPECTED; - if (result == net::OK) + return ERR_UNEXPECTED; + if (result == OK) result = w.data_len; } if (w.async) { RunCallbackAsync(callback, result); - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } return result; } + private: - void RunCallbackAsync(net::CompletionCallback* callback, int result) { - callback_ = callback; - MessageLoop::current()->PostTask(FROM_HERE, - method_factory_.NewRunnableMethod( - &MockTCPClientSocket::RunCallback, result)); - } - void RunCallback(int result) { - net::CompletionCallback* c = callback_; - callback_ = NULL; - if (c) - c->Run(result); - } MockSocket* data_; - ScopedRunnableMethodFactory<MockTCPClientSocket> method_factory_; - net::CompletionCallback* callback_; int read_index_; int read_offset_; int write_index_; - bool connected_; }; -class MockClientSocketFactory : public net::ClientSocketFactory { +class MockSSLClientSocket : public MockClientSocket { public: - virtual net::ClientSocket* CreateTCPClientSocket( - const net::AddressList& addresses) { + explicit MockSSLClientSocket( + ClientSocket* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config) + : transport_(transport_socket), + data_(mock_ssl_sockets[mock_ssl_sockets_index++]) { + DCHECK(data_) << "overran mock_ssl_sockets array"; + } + + ~MockSSLClientSocket() { + Disconnect(); + } + + virtual void GetSSLInfo(SSLInfo* ssl_info) { + ssl_info->Reset(); + } + + friend class ConnectCallback; + class ConnectCallback : + public CompletionCallbackImpl<ConnectCallback> { + public: + ConnectCallback(MockSSLClientSocket *ssl_client_socket, + CompletionCallback* user_callback, + int rv) + : ALLOW_THIS_IN_INITIALIZER_LIST( + CompletionCallbackImpl<ConnectCallback>( + this, &ConnectCallback::Wrapper)), + ssl_client_socket_(ssl_client_socket), + user_callback_(user_callback), + rv_(rv) { + } + + private: + void Wrapper(int rv) { + if (rv_ == OK) + ssl_client_socket_->connected_ = true; + user_callback_->Run(rv_); + delete this; + } + + MockSSLClientSocket* ssl_client_socket_; + CompletionCallback* user_callback_; + int rv_; + }; + + virtual int Connect(CompletionCallback* callback) { + DCHECK(!callback_); + ConnectCallback* connect_callback = new ConnectCallback( + this, callback, data_->connect.result); + int rv = transport_->Connect(connect_callback); + if (rv == OK) { + delete connect_callback; + if (data_->connect.async) { + RunCallbackAsync(callback, data_->connect.result); + return ERR_IO_PENDING; + } + if (data_->connect.result == OK) + connected_ = true; + return data_->connect.result; + } + return rv; + } + + virtual void Disconnect() { + MockClientSocket::Disconnect(); + if (transport_ != NULL) + transport_->Disconnect(); + } + + // Socket methods: + virtual int Read(char* buf, int buf_len, CompletionCallback* callback) { + DCHECK(!callback_); + return transport_->Read(buf, buf_len, callback); + } + + virtual int Write(const char* buf, int buf_len, + CompletionCallback* callback) { + DCHECK(!callback_); + return transport_->Write(buf, buf_len, callback); + } + + private: + scoped_ptr<ClientSocket> transport_; + MockSSLSocket* data_; +}; + +class MockClientSocketFactory : public ClientSocketFactory { + public: + virtual ClientSocket* CreateTCPClientSocket( + const AddressList& addresses) { return new MockTCPClientSocket(addresses); } - virtual net::SSLClientSocket* CreateSSLClientSocket( - net::ClientSocket* transport_socket, + virtual SSLClientSocket* CreateSSLClientSocket( + ClientSocket* transport_socket, const std::string& hostname, - const net::SSLConfig& ssl_config) { - return NULL; + const SSLConfig& ssl_config) { + return new MockSSLClientSocket(transport_socket, hostname, ssl_config); } }; @@ -229,6 +364,7 @@ class HttpNetworkTransactionTest : public PlatformTest { PlatformTest::SetUp(); mock_sockets[0] = NULL; mock_sockets_index = 0; + mock_ssl_sockets_index = 0; } virtual void TearDown() { @@ -2711,4 +2847,139 @@ TEST_F(HttpNetworkTransactionTest, ResetStateForRestart) { EXPECT_FALSE(trans->response_.vary_data.is_valid()); } +// Test HTTPS connections to a site with a bad certificate +TEST_F(HttpNetworkTransactionTest, HTTPSBadCertificate) { + scoped_ptr<ProxyService> proxy_service(CreateNullProxyService()); + scoped_ptr<HttpTransaction> trans(new HttpNetworkTransaction( + CreateSession(proxy_service.get()), &mock_socket_factory)); + + HttpRequestInfo request; + request.method = "GET"; + request.url = GURL("https://www.google.com/"); + request.load_flags = 0; + + MockWrite data_writes[] = { + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + + MockRead data_reads[] = { + MockRead("HTTP/1.0 200 OK\r\n"), + MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), + MockRead("Content-Length: 100\r\n\r\n"), + MockRead(false, OK), + }; + + MockSocket bad_certificate_socket; + MockSocket data_socket(data_reads, data_writes); + MockSSLSocket ssl_bad(true, ERR_CERT_AUTHORITY_INVALID); + MockSSLSocket ssl(true, OK); + + mock_sockets[0] = &bad_certificate_socket; + mock_sockets[1] = &data_socket; + mock_sockets[2] = NULL; + + mock_ssl_sockets[0] = &ssl_bad; + mock_ssl_sockets[1] = &ssl; + mock_ssl_sockets[2] = NULL; + + TestCompletionCallback callback; + + int rv = trans->Start(&request, &callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(ERR_CERT_AUTHORITY_INVALID, rv); + + rv = trans->RestartIgnoringLastError(&callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + const HttpResponseInfo* response = trans->GetResponseInfo(); + + EXPECT_FALSE(response == NULL); + EXPECT_EQ(100, response->headers->GetContentLength()); +} + +// Test HTTPS connections to a site with a bad certificate, going through a +// proxy +TEST_F(HttpNetworkTransactionTest, HTTPSBadCertificateViaProxy) { + scoped_ptr<ProxyService> proxy_service( + CreateFixedProxyService("myproxy:70")); + + HttpRequestInfo request; + request.method = "GET"; + request.url = GURL("https://www.google.com/"); + request.load_flags = 0; + + MockWrite proxy_writes[] = { + MockWrite("CONNECT www.google.com:443 HTTP/1.1\r\n" + "Host: www.google.com\r\n\r\n"), + }; + + MockRead proxy_reads[] = { + MockRead("HTTP/1.0 200 Connected\r\n\r\n"), + MockRead(false, net::OK) + }; + + MockWrite data_writes[] = { + MockWrite("CONNECT www.google.com:443 HTTP/1.1\r\n" + "Host: www.google.com\r\n\r\n"), + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + + MockRead data_reads[] = { + MockRead("HTTP/1.0 200 Connected\r\n\r\n"), + MockRead("HTTP/1.0 200 OK\r\n"), + MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), + MockRead("Content-Length: 100\r\n\r\n"), + MockRead(false, OK), + }; + + MockSocket bad_certificate_socket(proxy_reads, proxy_writes); + MockSocket data_socket(data_reads, data_writes); + MockSSLSocket ssl_bad(true, ERR_CERT_AUTHORITY_INVALID); + MockSSLSocket ssl(true, OK); + + mock_sockets[0] = &bad_certificate_socket; + mock_sockets[1] = &data_socket; + mock_sockets[2] = NULL; + + mock_ssl_sockets[0] = &ssl_bad; + mock_ssl_sockets[1] = &ssl; + mock_ssl_sockets[2] = NULL; + + TestCompletionCallback callback; + + for (int i = 0; i < 2; i++) { + mock_sockets_index = 0; + mock_ssl_sockets_index = 0; + + scoped_ptr<HttpTransaction> trans(new HttpNetworkTransaction( + CreateSession(proxy_service.get()), &mock_socket_factory)); + + int rv = trans->Start(&request, &callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(ERR_CERT_AUTHORITY_INVALID, rv); + + rv = trans->RestartIgnoringLastError(&callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + const HttpResponseInfo* response = trans->GetResponseInfo(); + + EXPECT_FALSE(response == NULL); + EXPECT_EQ(100, response->headers->GetContentLength()); + } +} + } // namespace net |