diff options
Diffstat (limited to 'net/http/http_network_transaction_unittest.cc')
-rw-r--r-- | net/http/http_network_transaction_unittest.cc | 371 |
1 files changed, 321 insertions, 50 deletions
diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index 4e8209c..05d1ac7 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -6,6 +6,9 @@ #include "base/compiler_specific.h" #include "net/base/client_socket_factory.h" +#include "net/base/completion_callback.h" +#include "net/base/ssl_client_socket.h" +#include "net/base/ssl_info.h" #include "net/base/test_completion_callback.h" #include "net/base/upload_data.h" #include "net/http/http_auth_handler_ntlm.h" @@ -25,7 +28,8 @@ namespace net { struct MockConnect { // Asynchronous connection success. - MockConnect() : async(true), result(net::OK) { } + MockConnect() : async(true), result(OK) { } + MockConnect(bool a, int r) : async(a), result(r) { } bool async; int result; @@ -62,6 +66,7 @@ typedef MockRead MockWrite; struct MockSocket { MockSocket() : reads(NULL), writes(NULL) { } + MockSocket(MockRead* r, MockWrite* w) : reads(r), writes(w) { } MockConnect connect; MockRead* reads; @@ -76,37 +81,35 @@ struct MockSocket { // MockSocket* mock_sockets[10]; +// MockSSLSockets only need to keep track of the return code from calls to +// Connect(). +struct MockSSLSocket { + MockSSLSocket(bool async, int result) : connect(async, result) { } + + MockConnect connect; +}; +MockSSLSocket* mock_ssl_sockets[10]; + // Index of the next mock_sockets element to use. int mock_sockets_index; +int mock_ssl_sockets_index; -class MockTCPClientSocket : public net::ClientSocket { +class MockClientSocket : public SSLClientSocket { public: - explicit MockTCPClientSocket(const net::AddressList& addresses) - : data_(mock_sockets[mock_sockets_index++]), - ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), + explicit MockClientSocket() + : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), callback_(NULL), - read_index_(0), - read_offset_(0), - write_index_(0), connected_(false) { - DCHECK(data_) << "overran mock_sockets array"; } + // ClientSocket methods: - virtual int Connect(net::CompletionCallback* callback) { - DCHECK(!callback_); - if (connected_) - return net::OK; - connected_ = true; - if (data_->connect.async) { - RunCallbackAsync(callback, data_->connect.result); - return net::ERR_IO_PENDING; - } - return data_->connect.result; - } - virtual int ReconnectIgnoringLastError(net::CompletionCallback* callback) { + virtual int Connect(CompletionCallback* callback) = 0; + + // SSLClientSocket methods: + virtual void GetSSLInfo(SSLInfo* ssl_info) { NOTREACHED(); - return net::ERR_FAILED; } + virtual void Disconnect() { connected_ = false; callback_ = NULL; @@ -118,7 +121,64 @@ class MockTCPClientSocket : public net::ClientSocket { return connected_; } // Socket methods: - virtual int Read(char* buf, int buf_len, net::CompletionCallback* callback) { + virtual int Read(char* buf, int buf_len, + CompletionCallback* callback) = 0; + virtual int Write(const char* buf, int buf_len, + CompletionCallback* callback) = 0; + +#if defined(OS_LINUX) + virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen) { + memset(reinterpret_cast<char *>(name), 0, *namelen); + return OK; + } +#endif + + + protected: + void RunCallbackAsync(CompletionCallback* callback, int result) { + callback_ = callback; + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( + &MockClientSocket::RunCallback, result)); + } + + void RunCallback(int result) { + CompletionCallback* c = callback_; + callback_ = NULL; + if (c) + c->Run(result); + } + + ScopedRunnableMethodFactory<MockClientSocket> method_factory_; + CompletionCallback* callback_; + bool connected_; +}; + +class MockTCPClientSocket : public MockClientSocket { + public: + explicit MockTCPClientSocket(const AddressList& addresses) + : data_(mock_sockets[mock_sockets_index++]), + read_index_(0), + read_offset_(0), + write_index_(0) { + DCHECK(data_) << "overran mock_sockets array"; + } + + // ClientSocket methods: + virtual int Connect(CompletionCallback* callback) { + DCHECK(!callback_); + if (connected_) + return OK; + connected_ = true; + if (data_->connect.async) { + RunCallbackAsync(callback, data_->connect.result); + return ERR_IO_PENDING; + } + return data_->connect.result; + } + + // Socket methods: + virtual int Read(char* buf, int buf_len, CompletionCallback* callback) { DCHECK(!callback_); MockRead& r = data_->reads[read_index_]; int result = r.result; @@ -137,12 +197,13 @@ class MockTCPClientSocket : public net::ClientSocket { } if (r.async) { RunCallbackAsync(callback, result); - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } return result; } + virtual int Write(const char* buf, int buf_len, - net::CompletionCallback* callback) { + CompletionCallback* callback) { DCHECK(buf); DCHECK(buf_len > 0); DCHECK(!callback_); @@ -159,49 +220,123 @@ class MockTCPClientSocket : public net::ClientSocket { std::string actual_data(buf, buf_len); EXPECT_EQ(expected_data, actual_data); if (expected_data != actual_data) - return net::ERR_UNEXPECTED; - if (result == net::OK) + return ERR_UNEXPECTED; + if (result == OK) result = w.data_len; } if (w.async) { RunCallbackAsync(callback, result); - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } return result; } + private: - void RunCallbackAsync(net::CompletionCallback* callback, int result) { - callback_ = callback; - MessageLoop::current()->PostTask(FROM_HERE, - method_factory_.NewRunnableMethod( - &MockTCPClientSocket::RunCallback, result)); - } - void RunCallback(int result) { - net::CompletionCallback* c = callback_; - callback_ = NULL; - if (c) - c->Run(result); - } MockSocket* data_; - ScopedRunnableMethodFactory<MockTCPClientSocket> method_factory_; - net::CompletionCallback* callback_; int read_index_; int read_offset_; int write_index_; - bool connected_; }; -class MockClientSocketFactory : public net::ClientSocketFactory { +class MockSSLClientSocket : public MockClientSocket { public: - virtual net::ClientSocket* CreateTCPClientSocket( - const net::AddressList& addresses) { + explicit MockSSLClientSocket( + ClientSocket* transport_socket, + const std::string& hostname, + const SSLConfig& ssl_config) + : transport_(transport_socket), + data_(mock_ssl_sockets[mock_ssl_sockets_index++]) { + DCHECK(data_) << "overran mock_ssl_sockets array"; + } + + ~MockSSLClientSocket() { + Disconnect(); + } + + virtual void GetSSLInfo(SSLInfo* ssl_info) { + ssl_info->Reset(); + } + + friend class ConnectCallback; + class ConnectCallback : + public CompletionCallbackImpl<ConnectCallback> { + public: + ConnectCallback(MockSSLClientSocket *ssl_client_socket, + CompletionCallback* user_callback, + int rv) + : ALLOW_THIS_IN_INITIALIZER_LIST( + CompletionCallbackImpl<ConnectCallback>( + this, &ConnectCallback::Wrapper)), + ssl_client_socket_(ssl_client_socket), + user_callback_(user_callback), + rv_(rv) { + } + + private: + void Wrapper(int rv) { + if (rv_ == OK) + ssl_client_socket_->connected_ = true; + user_callback_->Run(rv_); + delete this; + } + + MockSSLClientSocket* ssl_client_socket_; + CompletionCallback* user_callback_; + int rv_; + }; + + virtual int Connect(CompletionCallback* callback) { + DCHECK(!callback_); + ConnectCallback* connect_callback = new ConnectCallback( + this, callback, data_->connect.result); + int rv = transport_->Connect(connect_callback); + if (rv == OK) { + delete connect_callback; + if (data_->connect.async) { + RunCallbackAsync(callback, data_->connect.result); + return ERR_IO_PENDING; + } + if (data_->connect.result == OK) + connected_ = true; + return data_->connect.result; + } + return rv; + } + + virtual void Disconnect() { + MockClientSocket::Disconnect(); + if (transport_ != NULL) + transport_->Disconnect(); + } + + // Socket methods: + virtual int Read(char* buf, int buf_len, CompletionCallback* callback) { + DCHECK(!callback_); + return transport_->Read(buf, buf_len, callback); + } + + virtual int Write(const char* buf, int buf_len, + CompletionCallback* callback) { + DCHECK(!callback_); + return transport_->Write(buf, buf_len, callback); + } + + private: + scoped_ptr<ClientSocket> transport_; + MockSSLSocket* data_; +}; + +class MockClientSocketFactory : public ClientSocketFactory { + public: + virtual ClientSocket* CreateTCPClientSocket( + const AddressList& addresses) { return new MockTCPClientSocket(addresses); } - virtual net::SSLClientSocket* CreateSSLClientSocket( - net::ClientSocket* transport_socket, + virtual SSLClientSocket* CreateSSLClientSocket( + ClientSocket* transport_socket, const std::string& hostname, - const net::SSLConfig& ssl_config) { - return NULL; + const SSLConfig& ssl_config) { + return new MockSSLClientSocket(transport_socket, hostname, ssl_config); } }; @@ -229,6 +364,7 @@ class HttpNetworkTransactionTest : public PlatformTest { PlatformTest::SetUp(); mock_sockets[0] = NULL; mock_sockets_index = 0; + mock_ssl_sockets_index = 0; } virtual void TearDown() { @@ -2711,4 +2847,139 @@ TEST_F(HttpNetworkTransactionTest, ResetStateForRestart) { EXPECT_FALSE(trans->response_.vary_data.is_valid()); } +// Test HTTPS connections to a site with a bad certificate +TEST_F(HttpNetworkTransactionTest, HTTPSBadCertificate) { + scoped_ptr<ProxyService> proxy_service(CreateNullProxyService()); + scoped_ptr<HttpTransaction> trans(new HttpNetworkTransaction( + CreateSession(proxy_service.get()), &mock_socket_factory)); + + HttpRequestInfo request; + request.method = "GET"; + request.url = GURL("https://www.google.com/"); + request.load_flags = 0; + + MockWrite data_writes[] = { + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + + MockRead data_reads[] = { + MockRead("HTTP/1.0 200 OK\r\n"), + MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), + MockRead("Content-Length: 100\r\n\r\n"), + MockRead(false, OK), + }; + + MockSocket ssl_bad_certificate; + MockSocket data(data_reads, data_writes); + MockSSLSocket ssl_bad(true, ERR_CERT_AUTHORITY_INVALID); + MockSSLSocket ssl(true, OK); + + mock_sockets[0] = &ssl_bad_certificate; + mock_sockets[1] = &data; + mock_sockets[2] = NULL; + + mock_ssl_sockets[0] = &ssl_bad; + mock_ssl_sockets[1] = &ssl; + mock_ssl_sockets[2] = NULL; + + TestCompletionCallback callback; + + int rv = trans->Start(&request, &callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(ERR_CERT_AUTHORITY_INVALID, rv); + + rv = trans->RestartIgnoringLastError(&callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + const HttpResponseInfo* response = trans->GetResponseInfo(); + + EXPECT_FALSE(response == NULL); + EXPECT_EQ(100, response->headers->GetContentLength()); +} + +// Test HTTPS connections to a site with a bad certificate, going through a +// proxy +TEST_F(HttpNetworkTransactionTest, HTTPSBadCertificateViaProxy) { + scoped_ptr<ProxyService> proxy_service( + CreateFixedProxyService("myproxy:70")); + + HttpRequestInfo request; + request.method = "GET"; + request.url = GURL("https://www.google.com/"); + request.load_flags = 0; + + MockWrite proxy_writes[] = { + MockWrite("CONNECT www.google.com:443 HTTP/1.1\r\n" + "Host: www.google.com\r\n\r\n"), + }; + + MockRead proxy_reads[] = { + MockRead("HTTP/1.0 200 Connected\r\n\r\n"), + MockRead(false, net::OK) + }; + + MockWrite data_writes[] = { + MockWrite("CONNECT www.google.com:443 HTTP/1.1\r\n" + "Host: www.google.com\r\n\r\n"), + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + + MockRead data_reads[] = { + MockRead("HTTP/1.0 200 Connected\r\n\r\n"), + MockRead("HTTP/1.0 200 OK\r\n"), + MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), + MockRead("Content-Length: 100\r\n\r\n"), + MockRead(false, OK), + }; + + MockSocket ssl_bad_certificate(proxy_reads, proxy_writes); + MockSocket data(data_reads, data_writes); + MockSSLSocket ssl_bad(true, ERR_CERT_AUTHORITY_INVALID); + MockSSLSocket ssl(true, OK); + + mock_sockets[0] = &ssl_bad_certificate; + mock_sockets[1] = &data; + mock_sockets[2] = NULL; + + mock_ssl_sockets[0] = &ssl_bad; + mock_ssl_sockets[1] = &ssl; + mock_ssl_sockets[2] = NULL; + + TestCompletionCallback callback; + + for (int i = 0; i < 2; i++) { + mock_sockets_index = 0; + mock_ssl_sockets_index = 0; + + scoped_ptr<HttpTransaction> trans(new HttpNetworkTransaction( + CreateSession(proxy_service.get()), &mock_socket_factory)); + + int rv = trans->Start(&request, &callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(ERR_CERT_AUTHORITY_INVALID, rv); + + rv = trans->RestartIgnoringLastError(&callback); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + const HttpResponseInfo* response = trans->GetResponseInfo(); + + EXPECT_FALSE(response == NULL); + EXPECT_EQ(100, response->headers->GetContentLength()); + } +} + } // namespace net |