summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--net/socket/ssl_client_socket_nss.cc15
-rw-r--r--net/socket/ssl_client_socket_nss.h6
-rw-r--r--net/socket/ssl_client_socket_unittest.cc46
3 files changed, 64 insertions, 3 deletions
diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc
index d3c1f8d..16a7f5e 100644
--- a/net/socket/ssl_client_socket_nss.cc
+++ b/net/socket/ssl_client_socket_nss.cc
@@ -250,6 +250,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocket* transport_socket,
user_write_buf_len_(0),
server_cert_nss_(NULL),
client_auth_cert_needed_(false),
+ handshake_callback_called_(false),
completed_handshake_(false),
next_handshake_state_(STATE_NONE),
nss_fd_(NULL),
@@ -1267,6 +1268,8 @@ void SSLClientSocketNSS::HandshakeCallback(PRFileDesc* socket,
void* arg) {
SSLClientSocketNSS* that = reinterpret_cast<SSLClientSocketNSS*>(arg);
+ that->set_handshake_callback_called();
+
that->UpdateServerCert();
that->CheckSecureRenegotiation();
@@ -1288,9 +1291,15 @@ int SSLClientSocketNSS::DoHandshake() {
LOG(WARNING) << "Couldn't invalidate SSL session: " << PR_GetError();
}
} else if (rv == SECSuccess) {
- // SSL handshake is completed. Let's verify the certificate.
- GotoState(STATE_VERIFY_CERT);
- // Done!
+ if (handshake_callback_called_) {
+ // SSL handshake is completed. Let's verify the certificate.
+ GotoState(STATE_VERIFY_CERT);
+ // Done!
+ } else {
+ // SSL_ForceHandshake returned SECSuccess prematurely.
+ rv = SECFailure;
+ net_error = ERR_SSL_PROTOCOL_ERROR;
+ }
} else {
PRErrorCode prerr = PR_GetError();
net_error = MapHandshakeError(prerr);
diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h
index 3543df7..74f2003 100644
--- a/net/socket/ssl_client_socket_nss.h
+++ b/net/socket/ssl_client_socket_nss.h
@@ -58,6 +58,8 @@ class SSLClientSocketNSS : public SSLClientSocket {
virtual bool SetReceiveBufferSize(int32 size);
virtual bool SetSendBufferSize(int32 size);
+ void set_handshake_callback_called() { handshake_callback_called_ = true; }
+
private:
// Initializes NSS SSL options. Returns a net error code.
int InitializeSSLOptions();
@@ -140,6 +142,10 @@ class SSLClientSocketNSS : public SSLClientSocket {
scoped_ptr<CertVerifier> verifier_;
+ // True if NSS has called HandshakeCallback.
+ bool handshake_callback_called_;
+
+ // True if the SSL handshake has been completed.
bool completed_handshake_;
enum State {
diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc
index 3606c45..7b297f8 100644
--- a/net/socket/ssl_client_socket_unittest.cc
+++ b/net/socket/ssl_client_socket_unittest.cc
@@ -13,6 +13,7 @@
#include "net/base/ssl_config_service.h"
#include "net/base/test_completion_callback.h"
#include "net/socket/client_socket_factory.h"
+#include "net/socket/socket_test_util.h"
#include "net/socket/ssl_test_util.h"
#include "net/socket/tcp_client_socket.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -424,3 +425,48 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) {
EXPECT_GT(rv, 0);
}
+
+#if !defined(OS_WIN)
+// Regression test for http://crbug.com/42538
+TEST_F(SSLClientSocketTest, PrematureApplicationData) {
+ net::AddressList addr;
+ TestCompletionCallback callback;
+
+ static const unsigned char application_data[] = {
+ 0x17, 0x03, 0x01, 0x00, 0x4a, 0x02, 0x00, 0x00, 0x46, 0x03, 0x01, 0x4b,
+ 0xc2, 0xf8, 0xb2, 0xc1, 0x56, 0x42, 0xb9, 0x57, 0x7f, 0xde, 0x87, 0x46,
+ 0xf7, 0xa3, 0x52, 0x42, 0x21, 0xf0, 0x13, 0x1c, 0x9c, 0x83, 0x88, 0xd6,
+ 0x93, 0x0c, 0xf6, 0x36, 0x30, 0x05, 0x7e, 0x20, 0xb5, 0xb5, 0x73, 0x36,
+ 0x53, 0x83, 0x0a, 0xfc, 0x17, 0x63, 0xbf, 0xa0, 0xe4, 0x42, 0x90, 0x0d,
+ 0x2f, 0x18, 0x6d, 0x20, 0xd8, 0x36, 0x3f, 0xfc, 0xe6, 0x01, 0xfa, 0x0f,
+ 0xa5, 0x75, 0x7f, 0x09, 0x00, 0x04, 0x00, 0x16, 0x03, 0x01, 0x11, 0x57,
+ 0x0b, 0x00, 0x11, 0x53, 0x00, 0x11, 0x50, 0x00, 0x06, 0x22, 0x30, 0x82,
+ 0x06, 0x1e, 0x30, 0x82, 0x05, 0x06, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02,
+ 0x0a
+ };
+
+ // All reads and writes complete synchronously (async=false).
+ net::MockRead data_reads[] = {
+ net::MockRead(false, reinterpret_cast<const char*>(application_data),
+ arraysize(application_data)),
+ net::MockRead(false, net::OK),
+ };
+
+ net::StaticSocketDataProvider data(data_reads, arraysize(data_reads),
+ NULL, 0);
+
+ net::ClientSocket* transport =
+ new net::MockTCPClientSocket(addr, NULL, &data);
+ int rv = transport->Connect(&callback);
+ if (rv == net::ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(net::OK, rv);
+
+ scoped_ptr<net::SSLClientSocket> sock(
+ socket_factory_->CreateSSLClientSocket(
+ transport, server_.kHostName, kDefaultSSLConfig));
+
+ rv = sock->Connect(&callback);
+ EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv);
+}
+#endif // !defined(OS_WIN)