diff options
-rw-r--r-- | net/socket/ssl_client_socket_unittest.cc | 382 | ||||
-rw-r--r-- | net/test/spawned_test_server/base_test_server.cc | 22 | ||||
-rw-r--r-- | net/test/spawned_test_server/base_test_server.h | 20 | ||||
-rwxr-xr-x | net/tools/testserver/testserver.py | 15 |
4 files changed, 334 insertions, 105 deletions
diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc index 7dd9d64..9d7f2e9 100644 --- a/net/socket/ssl_client_socket_unittest.cc +++ b/net/socket/ssl_client_socket_unittest.cc @@ -6,6 +6,7 @@ #include "base/callback_helpers.h" #include "base/memory/ref_counted.h" +#include "base/run_loop.h" #include "net/base/address_list.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" @@ -327,144 +328,195 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket { // Socket implementation: virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return read_state_.RunWrappedFunction(buf, buf_len, callback); - } + const CompletionCallback& callback) OVERRIDE; virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return write_state_.RunWrappedFunction(buf, buf_len, callback); - } + const CompletionCallback& callback) OVERRIDE; // Causes the next call to Read() to return ERR_IO_PENDING, not completing // (invoking the callback) until UnblockRead() has been called and the // underlying transport has completed. - void SetNextReadShouldBlock() { read_state_.SetShouldBlock(); } - void UnblockRead() { read_state_.Unblock(); } + void SetNextReadShouldBlock(); + void UnblockRead(); - // Causes the next call to Write() to return ERR_IO_PENDING, not completing - // (invoking the callback) until UnblockWrite() has been called and the - // underlying transport has completed. - void SetNextWriteShouldBlock() { write_state_.SetShouldBlock(); } - void UnblockWrite() { write_state_.Unblock(); } + // Waits for the blocked Read() call to be complete at the underlying + // transport. + void WaitForRead(); + + // Causes the next call to Write() to return ERR_IO_PENDING, not beginning the + // underlying transport until UnblockWrite() has been called. + void SetNextWriteShouldBlock(); + void UnblockWrite(); + + // Waits for the blocked Write() call to be scheduled. + void WaitForWrite(); private: - // Tracks the state for simulating a blocking Read/Write operation. - class BlockingState { - public: - // Wrapper for the underlying Socket function to call (ie: Read/Write). - typedef base::Callback<int(IOBuffer*, int, const CompletionCallback&)> - WrappedSocketFunction; - - explicit BlockingState(const WrappedSocketFunction& function); - ~BlockingState() {} - - // Sets the next call to RunWrappedFunction() to block, returning - // ERR_IO_PENDING and not invoking the user callback until Unblock() is - // called. - void SetShouldBlock(); - - // Unblocks the currently blocked pending function, invoking the user - // callback if the results are immediately available. - // Note: It's not valid to call this unless SetShouldBlock() has been - // called beforehand. - void Unblock(); - - // Performs the wrapped socket function on the underlying transport. If - // configured to block via SetShouldBlock(), then |user_callback| will not - // be invoked until Unblock() has been called. - int RunWrappedFunction(IOBuffer* buf, - int len, - const CompletionCallback& user_callback); - - private: - // Handles completion from the underlying wrapped socket function. - void OnCompleted(int result); - - WrappedSocketFunction wrapped_function_; - bool should_block_; - bool have_result_; - int pending_result_; - CompletionCallback user_callback_; - }; + // Handles completion from the underlying transport read. + void OnReadCompleted(int result); + + // True if read callbacks are blocked. + bool should_block_read_; + + // The user callback for the pending read call. + CompletionCallback pending_read_callback_; + + // The result for the blocked read callback, or ERR_IO_PENDING if not + // completed. + int pending_read_result_; - BlockingState read_state_; - BlockingState write_state_; + // WaitForRead() wait loop. + scoped_ptr<base::RunLoop> read_loop_; - DISALLOW_COPY_AND_ASSIGN(FakeBlockingStreamSocket); + // True if write calls are blocked. + bool should_block_write_; + + // The buffer for the pending write, or NULL if not scheduled. + scoped_refptr<IOBuffer> pending_write_buf_; + + // The callback for the pending write call. + CompletionCallback pending_write_callback_; + + // The length for the pending write, or -1 if not scheduled. + int pending_write_len_; + + // WaitForWrite() wait loop. + scoped_ptr<base::RunLoop> write_loop_; }; FakeBlockingStreamSocket::FakeBlockingStreamSocket( scoped_ptr<StreamSocket> transport) : WrappedStreamSocket(transport.Pass()), - read_state_(base::Bind(&Socket::Read, - base::Unretained(transport_.get()))), - write_state_(base::Bind(&Socket::Write, - base::Unretained(transport_.get()))) {} - -FakeBlockingStreamSocket::BlockingState::BlockingState( - const WrappedSocketFunction& function) - : wrapped_function_(function), - should_block_(false), - have_result_(false), - pending_result_(OK) {} - -void FakeBlockingStreamSocket::BlockingState::SetShouldBlock() { - DCHECK(!should_block_); - should_block_ = true; + should_block_read_(false), + pending_read_result_(ERR_IO_PENDING), + should_block_write_(false), + pending_write_len_(-1) {} + +int FakeBlockingStreamSocket::Read(IOBuffer* buf, + int len, + const CompletionCallback& callback) { + if (!should_block_read_) + return transport_->Read(buf, len, callback); + + DCHECK(pending_read_callback_.is_null()); + DCHECK_EQ(ERR_IO_PENDING, pending_read_result_); + + int rv = transport_->Read(buf, len, base::Bind( + &FakeBlockingStreamSocket::OnReadCompleted, base::Unretained(this))); + if (rv == ERR_IO_PENDING) { + pending_read_callback_ = callback; + } else { + OnReadCompleted(rv); + } + return ERR_IO_PENDING; +} + +int FakeBlockingStreamSocket::Write(IOBuffer* buf, + int len, + const CompletionCallback& callback) { + DCHECK(buf); + DCHECK_LE(0, len); + + if (!should_block_write_) + return transport_->Write(buf, len, callback); + + // Schedule the write, but do nothing. + DCHECK(!pending_write_buf_); + DCHECK_EQ(-1, pending_write_len_); + DCHECK(pending_write_callback_.is_null()); + pending_write_buf_ = buf; + pending_write_len_ = len; + pending_write_callback_ = callback; + + // Stop the write loop, if any. + if (write_loop_) + write_loop_->Quit(); + return ERR_IO_PENDING; } -void FakeBlockingStreamSocket::BlockingState::Unblock() { - DCHECK(should_block_); - should_block_ = false; +void FakeBlockingStreamSocket::SetNextReadShouldBlock() { + DCHECK(!should_block_read_); + should_block_read_ = true; +} + +void FakeBlockingStreamSocket::UnblockRead() { + DCHECK(should_block_read_); + should_block_read_ = false; // If the operation is still pending in the underlying transport, immediately - // return - OnCompleted() will handle invoking the callback once the transport - // has completed. - if (!have_result_) + // return - OnReadCompleted() will handle invoking the callback once the + // transport has completed. + if (pending_read_result_ == ERR_IO_PENDING) return; + int result = pending_read_result_; + pending_read_result_ = ERR_IO_PENDING; + base::ResetAndReturn(&pending_read_callback_).Run(result); +} - have_result_ = false; +void FakeBlockingStreamSocket::WaitForRead() { + DCHECK(should_block_read_); + DCHECK(!read_loop_); - base::ResetAndReturn(&user_callback_).Run(pending_result_); + if (pending_read_result_ != ERR_IO_PENDING) + return; + read_loop_.reset(new base::RunLoop); + read_loop_->Run(); + read_loop_.reset(); + DCHECK_NE(ERR_IO_PENDING, pending_read_result_); } -int FakeBlockingStreamSocket::BlockingState::RunWrappedFunction( - IOBuffer* buf, - int len, - const CompletionCallback& callback) { - - // The callback to be called by the underlying transport. Either forward - // directly to the user's callback if not set to block, or intercept it with - // OnCompleted so that the user's callback is not invoked until Unblock() is - // called. - CompletionCallback transport_callback = - !should_block_ ? callback : base::Bind(&BlockingState::OnCompleted, - base::Unretained(this)); - int rv = wrapped_function_.Run(buf, len, transport_callback); - if (should_block_) { - user_callback_ = callback; - // May have completed synchronously. - have_result_ = (rv != ERR_IO_PENDING); - pending_result_ = rv; - return ERR_IO_PENDING; +void FakeBlockingStreamSocket::SetNextWriteShouldBlock() { + DCHECK(!should_block_write_); + should_block_write_ = true; +} + +void FakeBlockingStreamSocket::UnblockWrite() { + DCHECK(should_block_write_); + should_block_write_ = false; + + // Do nothing if UnblockWrite() was called after SetNextWriteShouldBlock(), + // without a Write() in between. + if (!pending_write_buf_) + return; + + int rv = transport_->Write(pending_write_buf_, pending_write_len_, + pending_write_callback_); + pending_write_buf_ = NULL; + pending_write_len_ = -1; + if (rv == ERR_IO_PENDING) { + pending_write_callback_.Reset(); + } else { + base::ResetAndReturn(&pending_write_callback_).Run(rv); } +} - return rv; +void FakeBlockingStreamSocket::WaitForWrite() { + DCHECK(should_block_write_); + DCHECK(!write_loop_); + + if (pending_write_buf_) + return; + write_loop_.reset(new base::RunLoop); + write_loop_->Run(); + write_loop_.reset(); + DCHECK(pending_write_buf_); } -void FakeBlockingStreamSocket::BlockingState::OnCompleted(int result) { - if (should_block_) { +void FakeBlockingStreamSocket::OnReadCompleted(int result) { + if (should_block_read_) { // Store the result so that the callback can be invoked once Unblock() is // called. - have_result_ = true; - pending_result_ = result; + pending_read_result_ = result; + + // Stop the WaitForRead() call if any. + if (read_loop_) + read_loop_->Quit(); return; } // Otherwise, the Unblock() function was called before the underlying // transport completed, so run the user's callback immediately. - base::ResetAndReturn(&user_callback_).Run(result); + base::ResetAndReturn(&pending_read_callback_).Run(result); } // CompletionCallback that will delete the associated StreamSocket when @@ -565,6 +617,93 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest { } }; +class SSLClientSocketFalseStartTest : public SSLClientSocketTest { + protected: + void TestFalseStart(const SpawnedTestServer::SSLOptions& server_options, + const SSLConfig& client_config, + bool expect_false_start) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + server_options, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + scoped_ptr<FakeBlockingStreamSocket> transport( + new FakeBlockingStreamSocket(real_transport.Pass())); + int rv = callback.GetResult(transport->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + + FakeBlockingStreamSocket* raw_transport = transport.get(); + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), + client_config)); + + // Connect. Stop before the client processes the first server leg + // (ServerHello, etc.) + raw_transport->SetNextReadShouldBlock(); + rv = sock->Connect(callback.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + raw_transport->WaitForRead(); + + // Release the ServerHello and wait for the client to write + // ClientKeyExchange, etc. (A proxy for waiting for the entirety of the + // server's leg to complete, since it may span multiple reads.) + EXPECT_FALSE(callback.have_result()); + raw_transport->SetNextWriteShouldBlock(); + raw_transport->UnblockRead(); + raw_transport->WaitForWrite(); + + // And, finally, release that and block the next server leg + // (ChangeCipherSpec, Finished). Note: callback.have_result() may or may not + // be true at this point depending on whether the SSL implementation waits + // for the client second leg to clear the internal write buffer and hit the + // network. + raw_transport->SetNextReadShouldBlock(); + raw_transport->UnblockWrite(); + + if (expect_false_start) { + // When False Starting, the handshake should complete before receiving the + // Change Cipher Spec and Finished messages. + rv = callback.GetResult(rv); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + + const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; + static const int kRequestTextSize = + static_cast<int>(arraysize(request_text) - 1); + scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize)); + memcpy(request_buffer->data(), request_text, kRequestTextSize); + + // Write the request. + rv = callback.GetResult(sock->Write(request_buffer.get(), + kRequestTextSize, + callback.callback())); + EXPECT_EQ(kRequestTextSize, rv); + + // The read will hang; it's waiting for the peer to complete the + // handshake, and the handshake is still blocked. + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + rv = sock->Read(buf.get(), 4096, callback.callback()); + + // After releasing reads, the connection proceeds. + raw_transport->UnblockRead(); + rv = callback.GetResult(rv); + EXPECT_LT(0, rv); + } else { + // False Start is not enabled, so the handshake will not complete because + // the server second leg is blocked. + base::RunLoop().RunUntilIdle(); + EXPECT_FALSE(callback.have_result()); + } + } +}; + //----------------------------------------------------------------------------- // LogContainsSSLConnectEndEvent returns true if the given index in the given @@ -2183,4 +2322,43 @@ TEST_F(SSLClientSocketTest, ReuseStates) { // attempt to read one byte extra. } +// This test is only enabled on NSS until False Start support is added for +// OpenSSL. http://crbug.com/354132 +#if defined(USE_NSS) +#define MAYBE_FalseStartEnabled FalseStartEnabled +#else +#define MAYBE_FalseStartEnabled DISABLED_FalseStartEnabled +#endif // USE_NSS +TEST_F(SSLClientSocketFalseStartTest, MAYBE_FalseStartEnabled) { + // False Start requires NPN and a forward-secret cipher suite. + SpawnedTestServer::SSLOptions server_options; + server_options.key_exchanges = + SpawnedTestServer::SSLOptions::KEY_EXCHANGE_DHE_RSA; + server_options.enable_npn = true; + SSLConfig client_config; + client_config.next_protos.push_back("http/1.1"); + TestFalseStart(server_options, client_config, true); +} + +// Test that False Start is disabled without NPN. +TEST_F(SSLClientSocketFalseStartTest, NoNPN) { + SpawnedTestServer::SSLOptions server_options; + server_options.key_exchanges = + SpawnedTestServer::SSLOptions::KEY_EXCHANGE_DHE_RSA; + SSLConfig client_config; + client_config.next_protos.clear(); + TestFalseStart(server_options, client_config, false); +} + +// Test that False Start is disabled without a forward-secret cipher suite. +TEST_F(SSLClientSocketFalseStartTest, NoForwardSecrecy) { + SpawnedTestServer::SSLOptions server_options; + server_options.key_exchanges = + SpawnedTestServer::SSLOptions::KEY_EXCHANGE_RSA; + server_options.enable_npn = true; + SSLConfig client_config; + client_config.next_protos.push_back("http/1.1"); + TestFalseStart(server_options, client_config, false); +} + } // namespace net diff --git a/net/test/spawned_test_server/base_test_server.cc b/net/test/spawned_test_server/base_test_server.cc index ac37c70..a781c6e 100644 --- a/net/test/spawned_test_server/base_test_server.cc +++ b/net/test/spawned_test_server/base_test_server.cc @@ -40,6 +40,13 @@ std::string GetHostname(BaseTestServer::Type type, return BaseTestServer::kLocalhost; } +void GetKeyExchangesList(int key_exchange, base::ListValue* values) { + if (key_exchange & BaseTestServer::SSLOptions::KEY_EXCHANGE_RSA) + values->Append(new base::StringValue("rsa")); + if (key_exchange & BaseTestServer::SSLOptions::KEY_EXCHANGE_DHE_RSA) + values->Append(new base::StringValue("dhe_rsa")); +} + void GetCiphersList(int cipher, base::ListValue* values) { if (cipher & BaseTestServer::SSLOptions::BULK_CIPHER_RC4) values->Append(new base::StringValue("rc4")); @@ -58,11 +65,13 @@ BaseTestServer::SSLOptions::SSLOptions() ocsp_status(OCSP_OK), cert_serial(0), request_client_certificate(false), + key_exchanges(SSLOptions::KEY_EXCHANGE_ANY), bulk_ciphers(SSLOptions::BULK_CIPHER_ANY), record_resume(false), tls_intolerant(TLS_INTOLERANT_NONE), fallback_scsv_enabled(false), - staple_ocsp_response(false) {} + staple_ocsp_response(false), + enable_npn(false) {} BaseTestServer::SSLOptions::SSLOptions( BaseTestServer::SSLOptions::ServerCertificate cert) @@ -70,11 +79,13 @@ BaseTestServer::SSLOptions::SSLOptions( ocsp_status(OCSP_OK), cert_serial(0), request_client_certificate(false), + key_exchanges(SSLOptions::KEY_EXCHANGE_ANY), bulk_ciphers(SSLOptions::BULK_CIPHER_ANY), record_resume(false), tls_intolerant(TLS_INTOLERANT_NONE), fallback_scsv_enabled(false), - staple_ocsp_response(false) {} + staple_ocsp_response(false), + enable_npn(false) {} BaseTestServer::SSLOptions::~SSLOptions() {} @@ -389,6 +400,11 @@ bool BaseTestServer::GenerateArguments(base::DictionaryValue* arguments) const { base::Value::CreateIntegerValue(ssl_options_.cert_serial)); } + // Check key exchange argument. + scoped_ptr<base::ListValue> key_exchange_values(new base::ListValue()); + GetKeyExchangesList(ssl_options_.key_exchanges, key_exchange_values.get()); + if (key_exchange_values->GetSize()) + arguments->Set("ssl-key-exchange", key_exchange_values.release()); // Check bulk cipher argument. scoped_ptr<base::ListValue> bulk_cipher_values(new base::ListValue()); GetCiphersList(ssl_options_.bulk_ciphers, bulk_cipher_values.get()); @@ -410,6 +426,8 @@ bool BaseTestServer::GenerateArguments(base::DictionaryValue* arguments) const { } if (ssl_options_.staple_ocsp_response) arguments->Set("staple-ocsp-response", base::Value::CreateNullValue()); + if (ssl_options_.enable_npn) + arguments->Set("enable-npn", base::Value::CreateNullValue()); } return GenerateAdditionalArguments(arguments); diff --git a/net/test/spawned_test_server/base_test_server.h b/net/test/spawned_test_server/base_test_server.h index 163808c..392a72b 100644 --- a/net/test/spawned_test_server/base_test_server.h +++ b/net/test/spawned_test_server/base_test_server.h @@ -72,6 +72,18 @@ class BaseTestServer { OCSP_UNKNOWN, }; + // Bitmask of key exchange algorithms that the test server supports and that + // can be selectively enabled or disabled. + enum KeyExchange { + // Special value used to indicate that any algorithm the server supports + // is acceptable. Preferred over explicitly OR-ing all key exchange + // algorithms. + KEY_EXCHANGE_ANY = 0, + + KEY_EXCHANGE_RSA = (1 << 0), + KEY_EXCHANGE_DHE_RSA = (1 << 1), + }; + // Bitmask of bulk encryption algorithms that the test server supports // and that can be selectively enabled or disabled. enum BulkCipher { @@ -134,6 +146,11 @@ class BaseTestServer { // field of the CertificateRequest. std::vector<base::FilePath> client_authorities; + // A bitwise-OR of KeyExchnage that should be used by the + // HTTPS server, or KEY_EXCHANGE_ANY to indicate that all implemented + // key exchange algorithms are acceptable. + int key_exchanges; + // A bitwise-OR of BulkCipher that should be used by the // HTTPS server, or BULK_CIPHER_ANY to indicate that all implemented // ciphers are acceptable. @@ -165,6 +182,9 @@ class BaseTestServer { // Whether to staple the OCSP response. bool staple_ocsp_response; + + // Whether to enable NPN support. + bool enable_npn; }; // Pass as the 'host' parameter during construction to server on 127.0.0.1 diff --git a/net/tools/testserver/testserver.py b/net/tools/testserver/testserver.py index 2b0c36c..9e86cf2 100755 --- a/net/tools/testserver/testserver.py +++ b/net/tools/testserver/testserver.py @@ -153,7 +153,7 @@ class HTTPSServer(tlslite.api.TLSSocketServerMixIn, def __init__(self, server_address, request_hander_class, pem_cert_and_key, ssl_client_auth, ssl_client_cas, - ssl_bulk_ciphers, ssl_key_exchanges, + ssl_bulk_ciphers, ssl_key_exchanges, enable_npn, record_resume_info, tls_intolerant, signed_cert_timestamps, fallback_scsv_enabled, ocsp_response): self.cert_chain = tlslite.api.X509CertChain() @@ -167,6 +167,10 @@ class HTTPSServer(tlslite.api.TLSSocketServerMixIn, implementations=['python']) self.ssl_client_auth = ssl_client_auth self.ssl_client_cas = [] + if enable_npn: + self.next_protos = ['http/1.1'] + else: + self.next_protos = None if tls_intolerant == 0: self.tls_intolerant = None else: @@ -207,6 +211,7 @@ class HTTPSServer(tlslite.api.TLSSocketServerMixIn, reqCert=self.ssl_client_auth, settings=self.ssl_handshake_settings, reqCAs=self.ssl_client_cas, + nextProtos=self.next_protos, tlsIntolerant=self.tls_intolerant, signedCertTimestamps= self.signed_cert_timestamps, @@ -1986,6 +1991,7 @@ class ServerRunner(testserver_base.TestServerRunner): self.options.ssl_client_ca, self.options.ssl_bulk_cipher, self.options.ssl_key_exchange, + self.options.enable_npn, self.options.record_resume, self.options.tls_intolerant, self.options.signed_cert_timestamps_tls_ext.decode( @@ -2182,6 +2188,13 @@ class ServerRunner(testserver_base.TestServerRunner): 'option may appear multiple times, ' 'indicating multiple algorithms should be ' 'enabled.'); + # TODO(davidben): Add ALPN support to tlslite. + self.option_parser.add_option('--enable-npn', dest='enable_npn', + default=False, const=True, + action='store_const', + help='Enable server support for the NPN ' + 'extension. The server will advertise ' + 'support for exactly one protocol, http/1.1') self.option_parser.add_option('--file-root-url', default='/files/', help='Specify a root URL for files served.') |