summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--net/socket/ssl_client_socket_snapstart_unittest.cc45
-rw-r--r--net/test/openssl_helper.cc36
2 files changed, 75 insertions, 6 deletions
diff --git a/net/socket/ssl_client_socket_snapstart_unittest.cc b/net/socket/ssl_client_socket_snapstart_unittest.cc
index b3fb07e..13b2636 100644
--- a/net/socket/ssl_client_socket_snapstart_unittest.cc
+++ b/net/socket/ssl_client_socket_snapstart_unittest.cc
@@ -226,6 +226,8 @@ class SSLClientSocketSnapStartTest : public PlatformTest {
EXPECT_EQ(8, rv);
EXPECT_TRUE(memcmp(reply_buffer->data(), "goodbye!", 8) == 0);
+ next_proto_status_ = sock->GetNextProto(&next_proto_);
+
sock->Disconnect();
}
@@ -266,6 +268,8 @@ class SSLClientSocketSnapStartTest : public PlatformTest {
int client_;
SSLConfig ssl_config_;
CapturingNetLog log_;
+ SSLClientSocket::NextProtoStatus next_proto_status_;
+ std::string next_proto_;
};
TEST_F(SSLClientSocketSnapStartTest, Basic) {
@@ -318,4 +322,45 @@ TEST_F(SSLClientSocketSnapStartTest, SnapStartResumeRecovery) {
EXPECT_TRUE(DidMerge());
}
+TEST_F(SSLClientSocketSnapStartTest, SnapStartWithNPN) {
+ ssl_config_.next_protos.assign("\003foo\003bar");
+ StartSnapStartServer("snap-start", "npn", NULL);
+ PerformConnection();
+ EXPECT_EQ(SSLClientSocket::kNextProtoNegotiated, next_proto_status_);
+ EXPECT_EQ("foo", next_proto_);
+ EXPECT_EQ(SSL_SNAP_START_NONE, SnapStartEventType());
+ EXPECT_FALSE(DidMerge());
+ SSLClientSocketNSS::ClearSessionCache();
+ PerformConnection();
+ EXPECT_EQ(SSL_SNAP_START_FULL, SnapStartEventType());
+ EXPECT_EQ(SSLClientSocket::kNextProtoNegotiated, next_proto_status_);
+ EXPECT_EQ("foo", next_proto_);
+ EXPECT_TRUE(DidMerge());
+}
+
+TEST_F(SSLClientSocketSnapStartTest, SnapStartWithNPNMispredict) {
+ // This tests that we recover in the event of a misprediction.
+ ssl_config_.next_protos.assign("\003foo\003baz");
+ StartSnapStartServer("snap-start", "npn-mispredict", NULL);
+ PerformConnection();
+ EXPECT_EQ(SSLClientSocket::kNextProtoNegotiated, next_proto_status_);
+ EXPECT_EQ("foo", next_proto_);
+ EXPECT_EQ(SSL_SNAP_START_NONE, SnapStartEventType());
+ EXPECT_FALSE(DidMerge());
+
+ SSLClientSocketNSS::ClearSessionCache();
+ PerformConnection();
+ EXPECT_EQ(SSL_SNAP_START_RECOVERY, SnapStartEventType());
+ EXPECT_EQ(SSLClientSocket::kNextProtoNegotiated, next_proto_status_);
+ EXPECT_EQ("baz", next_proto_);
+ EXPECT_TRUE(DidMerge());
+
+ SSLClientSocketNSS::ClearSessionCache();
+ PerformConnection();
+ EXPECT_EQ(SSL_SNAP_START_FULL, SnapStartEventType());
+ EXPECT_EQ(SSLClientSocket::kNextProtoNegotiated, next_proto_status_);
+ EXPECT_EQ("baz", next_proto_);
+ EXPECT_TRUE(DidMerge());
+}
+
} // namespace net
diff --git a/net/test/openssl_helper.cc b/net/test/openssl_helper.cc
index b3eb20f..25989cb 100644
--- a/net/test/openssl_helper.cc
+++ b/net/test/openssl_helper.cc
@@ -30,9 +30,19 @@ static int verify_cb(int preverify_ok, X509_STORE_CTX *ctx) {
// Next Protocol Negotiation callback from OpenSSL
static int next_proto_cb(SSL *ssl, const unsigned char **out,
unsigned int *outlen, void *arg) {
+ bool* npn_mispredict = reinterpret_cast<bool*>(arg);
static char kProtos[] = "\003foo\003bar";
- *out = (const unsigned char*) kProtos;
- *outlen = sizeof(kProtos) - 1;
+ static char kProtos2[] = "\003baz\003boo";
+ static unsigned count = 0;
+
+ if (!*npn_mispredict || count == 0) {
+ *out = (const unsigned char*) kProtos;
+ *outlen = sizeof(kProtos) - 1;
+ } else {
+ *out = (const unsigned char*) kProtos2;
+ *outlen = sizeof(kProtos2) - 1;
+ }
+ count++;
return SSL_TLSEXT_ERR_OK;
}
@@ -46,6 +56,7 @@ main(int argc, char **argv) {
bool sni = false, sni_good = false, snap_start = false;
bool snap_start_recovery = false, sslv3 = false, session_tickets = false;
bool fail_resume = false, client_cert = false, npn = false;
+ bool npn_mispredict = false;
const char* key_file = kDefaultPEMFile;
const char* cert_file = kDefaultPEMFile;
@@ -76,6 +87,10 @@ main(int argc, char **argv) {
} else if (strcmp(argv[i], "npn") == 0) {
// Advertise NPN
npn = true;
+ } else if (strcmp(argv[i], "npn-mispredict") == 0) {
+ // Advertise NPN
+ npn = true;
+ npn_mispredict = true;
} else if (strcmp(argv[i], "--key-file") == 0) {
// Use alternative key file
i++;
@@ -165,11 +180,13 @@ main(int argc, char **argv) {
}
if (npn)
- SSL_CTX_set_next_protos_advertised_cb(ctx, next_proto_cb, NULL);
+ SSL_CTX_set_next_protos_advertised_cb(ctx, next_proto_cb, &npn_mispredict);
unsigned connection_limit = 1;
if (snap_start || session_tickets)
connection_limit = 2;
+ if (npn_mispredict)
+ connection_limit = 3;
for (unsigned connections = 0; connections < connection_limit;
connections++) {
@@ -209,10 +226,17 @@ main(int argc, char **argv) {
}
if (npn) {
- const unsigned char *data;
- unsigned len;
+ const unsigned char *data, *expected_data;
+ unsigned len, expected_len;
SSL_get0_next_proto_negotiated(server, &data, &len);
- if (len != 3 || memcmp(data, "bar", 3) != 0) {
+ if (!npn_mispredict || connections == 0) {
+ expected_data = (unsigned char*) "foo";
+ expected_len = 3;
+ } else {
+ expected_data = (unsigned char*) "baz";
+ expected_len = 3;
+ }
+ if (len != expected_len || memcmp(data, expected_data, len) != 0) {
fprintf(stderr, "Bad NPN: %d\n", len);
return 1;
}