diff options
-rw-r--r-- | net/socket/ssl_client_socket_nss.cc | 15 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_nss.h | 6 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_unittest.cc | 46 |
3 files changed, 64 insertions, 3 deletions
diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index d3c1f8d..16a7f5e 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -250,6 +250,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocket* transport_socket, user_write_buf_len_(0), server_cert_nss_(NULL), client_auth_cert_needed_(false), + handshake_callback_called_(false), completed_handshake_(false), next_handshake_state_(STATE_NONE), nss_fd_(NULL), @@ -1267,6 +1268,8 @@ void SSLClientSocketNSS::HandshakeCallback(PRFileDesc* socket, void* arg) { SSLClientSocketNSS* that = reinterpret_cast<SSLClientSocketNSS*>(arg); + that->set_handshake_callback_called(); + that->UpdateServerCert(); that->CheckSecureRenegotiation(); @@ -1288,9 +1291,15 @@ int SSLClientSocketNSS::DoHandshake() { LOG(WARNING) << "Couldn't invalidate SSL session: " << PR_GetError(); } } else if (rv == SECSuccess) { - // SSL handshake is completed. Let's verify the certificate. - GotoState(STATE_VERIFY_CERT); - // Done! + if (handshake_callback_called_) { + // SSL handshake is completed. Let's verify the certificate. + GotoState(STATE_VERIFY_CERT); + // Done! + } else { + // SSL_ForceHandshake returned SECSuccess prematurely. + rv = SECFailure; + net_error = ERR_SSL_PROTOCOL_ERROR; + } } else { PRErrorCode prerr = PR_GetError(); net_error = MapHandshakeError(prerr); diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index 3543df7..74f2003 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -58,6 +58,8 @@ class SSLClientSocketNSS : public SSLClientSocket { virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); + void set_handshake_callback_called() { handshake_callback_called_ = true; } + private: // Initializes NSS SSL options. Returns a net error code. int InitializeSSLOptions(); @@ -140,6 +142,10 @@ class SSLClientSocketNSS : public SSLClientSocket { scoped_ptr<CertVerifier> verifier_; + // True if NSS has called HandshakeCallback. + bool handshake_callback_called_; + + // True if the SSL handshake has been completed. bool completed_handshake_; enum State { diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc index 3606c45..7b297f8 100644 --- a/net/socket/ssl_client_socket_unittest.cc +++ b/net/socket/ssl_client_socket_unittest.cc @@ -13,6 +13,7 @@ #include "net/base/ssl_config_service.h" #include "net/base/test_completion_callback.h" #include "net/socket/client_socket_factory.h" +#include "net/socket/socket_test_util.h" #include "net/socket/ssl_test_util.h" #include "net/socket/tcp_client_socket.h" #include "testing/gtest/include/gtest/gtest.h" @@ -424,3 +425,48 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) { EXPECT_GT(rv, 0); } + +#if !defined(OS_WIN) +// Regression test for http://crbug.com/42538 +TEST_F(SSLClientSocketTest, PrematureApplicationData) { + net::AddressList addr; + TestCompletionCallback callback; + + static const unsigned char application_data[] = { + 0x17, 0x03, 0x01, 0x00, 0x4a, 0x02, 0x00, 0x00, 0x46, 0x03, 0x01, 0x4b, + 0xc2, 0xf8, 0xb2, 0xc1, 0x56, 0x42, 0xb9, 0x57, 0x7f, 0xde, 0x87, 0x46, + 0xf7, 0xa3, 0x52, 0x42, 0x21, 0xf0, 0x13, 0x1c, 0x9c, 0x83, 0x88, 0xd6, + 0x93, 0x0c, 0xf6, 0x36, 0x30, 0x05, 0x7e, 0x20, 0xb5, 0xb5, 0x73, 0x36, + 0x53, 0x83, 0x0a, 0xfc, 0x17, 0x63, 0xbf, 0xa0, 0xe4, 0x42, 0x90, 0x0d, + 0x2f, 0x18, 0x6d, 0x20, 0xd8, 0x36, 0x3f, 0xfc, 0xe6, 0x01, 0xfa, 0x0f, + 0xa5, 0x75, 0x7f, 0x09, 0x00, 0x04, 0x00, 0x16, 0x03, 0x01, 0x11, 0x57, + 0x0b, 0x00, 0x11, 0x53, 0x00, 0x11, 0x50, 0x00, 0x06, 0x22, 0x30, 0x82, + 0x06, 0x1e, 0x30, 0x82, 0x05, 0x06, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, + 0x0a + }; + + // All reads and writes complete synchronously (async=false). + net::MockRead data_reads[] = { + net::MockRead(false, reinterpret_cast<const char*>(application_data), + arraysize(application_data)), + net::MockRead(false, net::OK), + }; + + net::StaticSocketDataProvider data(data_reads, arraysize(data_reads), + NULL, 0); + + net::ClientSocket* transport = + new net::MockTCPClientSocket(addr, NULL, &data); + int rv = transport->Connect(&callback); + if (rv == net::ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(net::OK, rv); + + scoped_ptr<net::SSLClientSocket> sock( + socket_factory_->CreateSSLClientSocket( + transport, server_.kHostName, kDefaultSSLConfig)); + + rv = sock->Connect(&callback); + EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv); +} +#endif // !defined(OS_WIN) |