diff options
-rw-r--r-- | net/socket/ssl_client_socket_snapstart_unittest.cc | 45 | ||||
-rw-r--r-- | net/test/openssl_helper.cc | 36 |
2 files changed, 75 insertions, 6 deletions
diff --git a/net/socket/ssl_client_socket_snapstart_unittest.cc b/net/socket/ssl_client_socket_snapstart_unittest.cc index b3fb07e..13b2636 100644 --- a/net/socket/ssl_client_socket_snapstart_unittest.cc +++ b/net/socket/ssl_client_socket_snapstart_unittest.cc @@ -226,6 +226,8 @@ class SSLClientSocketSnapStartTest : public PlatformTest { EXPECT_EQ(8, rv); EXPECT_TRUE(memcmp(reply_buffer->data(), "goodbye!", 8) == 0); + next_proto_status_ = sock->GetNextProto(&next_proto_); + sock->Disconnect(); } @@ -266,6 +268,8 @@ class SSLClientSocketSnapStartTest : public PlatformTest { int client_; SSLConfig ssl_config_; CapturingNetLog log_; + SSLClientSocket::NextProtoStatus next_proto_status_; + std::string next_proto_; }; TEST_F(SSLClientSocketSnapStartTest, Basic) { @@ -318,4 +322,45 @@ TEST_F(SSLClientSocketSnapStartTest, SnapStartResumeRecovery) { EXPECT_TRUE(DidMerge()); } +TEST_F(SSLClientSocketSnapStartTest, SnapStartWithNPN) { + ssl_config_.next_protos.assign("\003foo\003bar"); + StartSnapStartServer("snap-start", "npn", NULL); + PerformConnection(); + EXPECT_EQ(SSLClientSocket::kNextProtoNegotiated, next_proto_status_); + EXPECT_EQ("foo", next_proto_); + EXPECT_EQ(SSL_SNAP_START_NONE, SnapStartEventType()); + EXPECT_FALSE(DidMerge()); + SSLClientSocketNSS::ClearSessionCache(); + PerformConnection(); + EXPECT_EQ(SSL_SNAP_START_FULL, SnapStartEventType()); + EXPECT_EQ(SSLClientSocket::kNextProtoNegotiated, next_proto_status_); + EXPECT_EQ("foo", next_proto_); + EXPECT_TRUE(DidMerge()); +} + +TEST_F(SSLClientSocketSnapStartTest, SnapStartWithNPNMispredict) { + // This tests that we recover in the event of a misprediction. + ssl_config_.next_protos.assign("\003foo\003baz"); + StartSnapStartServer("snap-start", "npn-mispredict", NULL); + PerformConnection(); + EXPECT_EQ(SSLClientSocket::kNextProtoNegotiated, next_proto_status_); + EXPECT_EQ("foo", next_proto_); + EXPECT_EQ(SSL_SNAP_START_NONE, SnapStartEventType()); + EXPECT_FALSE(DidMerge()); + + SSLClientSocketNSS::ClearSessionCache(); + PerformConnection(); + EXPECT_EQ(SSL_SNAP_START_RECOVERY, SnapStartEventType()); + EXPECT_EQ(SSLClientSocket::kNextProtoNegotiated, next_proto_status_); + EXPECT_EQ("baz", next_proto_); + EXPECT_TRUE(DidMerge()); + + SSLClientSocketNSS::ClearSessionCache(); + PerformConnection(); + EXPECT_EQ(SSL_SNAP_START_FULL, SnapStartEventType()); + EXPECT_EQ(SSLClientSocket::kNextProtoNegotiated, next_proto_status_); + EXPECT_EQ("baz", next_proto_); + EXPECT_TRUE(DidMerge()); +} + } // namespace net diff --git a/net/test/openssl_helper.cc b/net/test/openssl_helper.cc index b3eb20f..25989cb 100644 --- a/net/test/openssl_helper.cc +++ b/net/test/openssl_helper.cc @@ -30,9 +30,19 @@ static int verify_cb(int preverify_ok, X509_STORE_CTX *ctx) { // Next Protocol Negotiation callback from OpenSSL static int next_proto_cb(SSL *ssl, const unsigned char **out, unsigned int *outlen, void *arg) { + bool* npn_mispredict = reinterpret_cast<bool*>(arg); static char kProtos[] = "\003foo\003bar"; - *out = (const unsigned char*) kProtos; - *outlen = sizeof(kProtos) - 1; + static char kProtos2[] = "\003baz\003boo"; + static unsigned count = 0; + + if (!*npn_mispredict || count == 0) { + *out = (const unsigned char*) kProtos; + *outlen = sizeof(kProtos) - 1; + } else { + *out = (const unsigned char*) kProtos2; + *outlen = sizeof(kProtos2) - 1; + } + count++; return SSL_TLSEXT_ERR_OK; } @@ -46,6 +56,7 @@ main(int argc, char **argv) { bool sni = false, sni_good = false, snap_start = false; bool snap_start_recovery = false, sslv3 = false, session_tickets = false; bool fail_resume = false, client_cert = false, npn = false; + bool npn_mispredict = false; const char* key_file = kDefaultPEMFile; const char* cert_file = kDefaultPEMFile; @@ -76,6 +87,10 @@ main(int argc, char **argv) { } else if (strcmp(argv[i], "npn") == 0) { // Advertise NPN npn = true; + } else if (strcmp(argv[i], "npn-mispredict") == 0) { + // Advertise NPN + npn = true; + npn_mispredict = true; } else if (strcmp(argv[i], "--key-file") == 0) { // Use alternative key file i++; @@ -165,11 +180,13 @@ main(int argc, char **argv) { } if (npn) - SSL_CTX_set_next_protos_advertised_cb(ctx, next_proto_cb, NULL); + SSL_CTX_set_next_protos_advertised_cb(ctx, next_proto_cb, &npn_mispredict); unsigned connection_limit = 1; if (snap_start || session_tickets) connection_limit = 2; + if (npn_mispredict) + connection_limit = 3; for (unsigned connections = 0; connections < connection_limit; connections++) { @@ -209,10 +226,17 @@ main(int argc, char **argv) { } if (npn) { - const unsigned char *data; - unsigned len; + const unsigned char *data, *expected_data; + unsigned len, expected_len; SSL_get0_next_proto_negotiated(server, &data, &len); - if (len != 3 || memcmp(data, "bar", 3) != 0) { + if (!npn_mispredict || connections == 0) { + expected_data = (unsigned char*) "foo"; + expected_len = 3; + } else { + expected_data = (unsigned char*) "baz"; + expected_len = 3; + } + if (len != expected_len || memcmp(data, expected_data, len) != 0) { fprintf(stderr, "Bad NPN: %d\n", len); return 1; } |