diff options
51 files changed, 559 insertions, 56 deletions
diff --git a/content/browser/renderer_host/p2p/socket_host_test_utils.h b/content/browser/renderer_host/p2p/socket_host_test_utils.h index 3ab5de3..c3deea8 100644 --- a/content/browser/renderer_host/p2p/socket_host_test_utils.h +++ b/content/browser/renderer_host/p2p/socket_host_test_utils.h @@ -75,7 +75,9 @@ class FakeSocket : public net::StreamSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE; private: bool read_pending_; @@ -219,10 +221,18 @@ base::TimeDelta FakeSocket::GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } +bool FakeSocket::WasNpnNegotiated() const { + return false; +} + net::NextProto FakeSocket::GetNegotiatedProtocol() const { return net::kProtoUnknown; } +bool FakeSocket::GetSSLInfo(net::SSLInfo* ssl_info) { + return false; +} + void CreateRandomPacket(std::vector<char>* packet) { size_t size = kStunHeaderSize + rand() % 1000; packet->resize(size); diff --git a/jingle/glue/pseudotcp_adapter.cc b/jingle/glue/pseudotcp_adapter.cc index 65fd156..34b697e 100644 --- a/jingle/glue/pseudotcp_adapter.cc +++ b/jingle/glue/pseudotcp_adapter.cc @@ -573,11 +573,21 @@ base::TimeDelta PseudoTcpAdapter::GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } +bool PseudoTcpAdapter::WasNpnNegotiated() const { + DCHECK(CalledOnValidThread()); + return false; +} + net::NextProto PseudoTcpAdapter::GetNegotiatedProtocol() const { DCHECK(CalledOnValidThread()); return net::kProtoUnknown; } +bool PseudoTcpAdapter::GetSSLInfo(net::SSLInfo* ssl_info) { + DCHECK(CalledOnValidThread()); + return false; +} + void PseudoTcpAdapter::SetAckDelay(int delay_ms) { DCHECK(CalledOnValidThread()); core_->SetAckDelay(delay_ms); diff --git a/jingle/glue/pseudotcp_adapter.h b/jingle/glue/pseudotcp_adapter.h index c0114a8..e4e1a54 100644 --- a/jingle/glue/pseudotcp_adapter.h +++ b/jingle/glue/pseudotcp_adapter.h @@ -51,7 +51,9 @@ class PseudoTcpAdapter : public net::StreamSocket, base::NonThreadSafe { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE; // Set the delay for sending ACK. void SetAckDelay(int delay_ms); diff --git a/jingle/notifier/base/fake_ssl_client_socket.cc b/jingle/notifier/base/fake_ssl_client_socket.cc index ec8fd26..3329e31 100644 --- a/jingle/notifier/base/fake_ssl_client_socket.cc +++ b/jingle/notifier/base/fake_ssl_client_socket.cc @@ -338,8 +338,16 @@ base::TimeDelta FakeSSLClientSocket::GetConnectTimeMicros() const { return transport_socket_->GetConnectTimeMicros(); } +bool FakeSSLClientSocket::WasNpnNegotiated() const { + return transport_socket_->WasNpnNegotiated(); +} + net::NextProto FakeSSLClientSocket::GetNegotiatedProtocol() const { return transport_socket_->GetNegotiatedProtocol(); } +bool FakeSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { + return transport_socket_->GetSSLInfo(ssl_info); +} + } // namespace notifier diff --git a/jingle/notifier/base/fake_ssl_client_socket.h b/jingle/notifier/base/fake_ssl_client_socket.h index f7dc215..b9df8a6 100644 --- a/jingle/notifier/base/fake_ssl_client_socket.h +++ b/jingle/notifier/base/fake_ssl_client_socket.h @@ -29,6 +29,7 @@ namespace net { class DrainableIOBuffer; +class SSLInfo; } // namespace net namespace notifier { @@ -64,7 +65,9 @@ class FakeSSLClientSocket : public net::StreamSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE; private: enum HandshakeState { diff --git a/jingle/notifier/base/fake_ssl_client_socket_unittest.cc b/jingle/notifier/base/fake_ssl_client_socket_unittest.cc index fddb693..617e4d1 100644 --- a/jingle/notifier/base/fake_ssl_client_socket_unittest.cc +++ b/jingle/notifier/base/fake_ssl_client_socket_unittest.cc @@ -65,7 +65,9 @@ class MockClientSocket : public net::StreamSocket { MOCK_CONST_METHOD0(UsingTCPFastOpen, bool()); MOCK_CONST_METHOD0(NumBytesRead, int64()); MOCK_CONST_METHOD0(GetConnectTimeMicros, base::TimeDelta()); + MOCK_CONST_METHOD0(WasNpnNegotiated, bool()); MOCK_CONST_METHOD0(GetNegotiatedProtocol, net::NextProto()); + MOCK_METHOD1(GetSSLInfo, bool(net::SSLInfo*)); }; // Break up |data| into a bunch of chunked MockReads/Writes and push diff --git a/jingle/notifier/base/proxy_resolving_client_socket.cc b/jingle/notifier/base/proxy_resolving_client_socket.cc index 9abe6e3..eb14d56 100644 --- a/jingle/notifier/base/proxy_resolving_client_socket.cc +++ b/jingle/notifier/base/proxy_resolving_client_socket.cc @@ -368,6 +368,10 @@ base::TimeDelta ProxyResolvingClientSocket::GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } +bool ProxyResolvingClientSocket::WasNpnNegotiated() const { + return false; +} + net::NextProto ProxyResolvingClientSocket::GetNegotiatedProtocol() const { if (transport_.get() && transport_->socket()) return transport_->socket()->GetNegotiatedProtocol(); @@ -375,6 +379,10 @@ net::NextProto ProxyResolvingClientSocket::GetNegotiatedProtocol() const { return net::kProtoUnknown; } +bool ProxyResolvingClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { + return false; +} + void ProxyResolvingClientSocket::CloseTransportSocket() { if (transport_.get() && transport_->socket()) transport_->socket()->Disconnect(); diff --git a/jingle/notifier/base/proxy_resolving_client_socket.h b/jingle/notifier/base/proxy_resolving_client_socket.h index 0944a0b..e2426ec 100644 --- a/jingle/notifier/base/proxy_resolving_client_socket.h +++ b/jingle/notifier/base/proxy_resolving_client_socket.h @@ -66,7 +66,9 @@ class ProxyResolvingClientSocket : public net::StreamSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE; private: // Proxy resolution and connection functions. diff --git a/net/curvecp/curvecp_client_socket.cc b/net/curvecp/curvecp_client_socket.cc index 8da55b68..f042189 100644 --- a/net/curvecp/curvecp_client_socket.cc +++ b/net/curvecp/curvecp_client_socket.cc @@ -85,10 +85,18 @@ base::TimeDelta CurveCPClientSocket::GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } +bool CurveCPClientSocket::WasNpnNegotiated() const { + return false; +} + NextProto CurveCPClientSocket::GetNegotiatedProtocol() const { return kProtoUnknown; } +bool CurveCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + return false; +} + int CurveCPClientSocket::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { diff --git a/net/curvecp/curvecp_client_socket.h b/net/curvecp/curvecp_client_socket.h index de3e86a..80734b7 100644 --- a/net/curvecp/curvecp_client_socket.h +++ b/net/curvecp/curvecp_client_socket.h @@ -38,7 +38,9 @@ class CurveCPClientSocket : public StreamSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; - virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // Socket methods: virtual int Read(IOBuffer* buf, diff --git a/net/http/http_network_transaction_spdy2_unittest.cc b/net/http/http_network_transaction_spdy2_unittest.cc index 94cec09..d3148c0 100644 --- a/net/http/http_network_transaction_spdy2_unittest.cc +++ b/net/http/http_network_transaction_spdy2_unittest.cc @@ -21,6 +21,7 @@ #include "base/utf_string_conversions.h" #include "net/base/auth.h" #include "net/base/capturing_net_log.h" +#include "net/base/cert_test_util.h" #include "net/base/completion_callback.h" #include "net/base/host_cache.h" #include "net/base/mock_cert_verifier.h" @@ -10075,4 +10076,141 @@ TEST_F(HttpNetworkTransactionSpdy2Test, UseSpdySessionForHttpWhenForced) { EXPECT_TRUE(trans2.GetResponseInfo()->was_fetched_via_spdy); } +// Test that in the case where we have a SPDY session to a SPDY proxy +// that we do not pool other origins that resolve to the same IP when +// the certificate does not match the new origin. +// http://crbug.com/134690 +TEST_F(HttpNetworkTransactionSpdy2Test, DoNotUseSpdySessionIfCertDoesNotMatch) { + const std::string url1 = "http://www.google.com/"; + const std::string url2 = "https://mail.google.com/"; + const std::string ip_addr = "1.2.3.4"; + + // SPDY GET for HTTP URL (through SPDY proxy) + const char* const headers[] = { + "method", "GET", + "url", url1.c_str(), + "host", "www.google.com", + "scheme", "http", + "version", "HTTP/1.1" + }; + scoped_ptr<SpdyFrame> req1(ConstructSpdyControlFrame(NULL, 0, false, 1, + LOWEST, SYN_STREAM, + CONTROL_FLAG_FIN, + headers, + arraysize(headers))); + + MockWrite writes1[] = { + CreateMockWrite(*req1, 0), + }; + + scoped_ptr<SpdyFrame> resp1(ConstructSpdyGetSynReply(NULL, 0, 1)); + scoped_ptr<SpdyFrame> body1(ConstructSpdyBodyFrame(1, true)); + MockRead reads1[] = { + CreateMockRead(*resp1, 1), + CreateMockRead(*body1, 2), + MockRead(ASYNC, OK, 3) // EOF + }; + + scoped_ptr<DeterministicSocketData> data1( + new DeterministicSocketData(reads1, arraysize(reads1), + writes1, arraysize(writes1))); + IPAddressNumber ip; + ASSERT_TRUE(ParseIPLiteralToNumber(ip_addr, &ip)); + IPEndPoint peer_addr = IPEndPoint(ip, 443); + MockConnect connect_data1(ASYNC, OK, peer_addr); + data1->set_connect_data(connect_data1); + + // SPDY GET for HTTPS URL (direct) + scoped_ptr<SpdyFrame> req2(ConstructSpdyGet(url2.c_str(), + false, 1, MEDIUM)); + + MockWrite writes2[] = { + CreateMockWrite(*req2, 0), + }; + + scoped_ptr<SpdyFrame> resp2(ConstructSpdyGetSynReply(NULL, 0, 1)); + scoped_ptr<SpdyFrame> body2(ConstructSpdyBodyFrame(1, true)); + MockRead reads2[] = { + CreateMockRead(*resp2, 1), + CreateMockRead(*body2, 2), + MockRead(ASYNC, OK, 3) // EOF + }; + + scoped_ptr<DeterministicSocketData> data2( + new DeterministicSocketData(reads2, arraysize(reads2), + writes2, arraysize(writes2))); + MockConnect connect_data2(ASYNC, OK); + data2->set_connect_data(connect_data2); + + // Set up a proxy config that sends HTTP requests to a proxy, and + // all others direct. + ProxyConfig proxy_config; + proxy_config.proxy_rules().ParseFromString("http=https://proxy:443"); + CapturingProxyResolver* capturing_proxy_resolver = + new CapturingProxyResolver(); + SpdySessionDependencies session_deps(new ProxyService( + new ProxyConfigServiceFixed(proxy_config), capturing_proxy_resolver, + NULL)); + + // Load a valid cert. Note, that this does not need to + // be valid for proxy because the MockSSLClientSocket does + // not actually verify it. But SpdySession will use this + // to see if it is valid for the new origin + FilePath certs_dir = GetTestCertsDirectory(); + scoped_refptr<X509Certificate> server_cert( + ImportCertFromFile(certs_dir, "ok_cert.pem")); + ASSERT_NE(static_cast<X509Certificate*>(NULL), server_cert); + + SSLSocketDataProvider ssl1(ASYNC, OK); // to the proxy + ssl1.SetNextProto(kProtoSPDY2); + ssl1.cert = server_cert; + session_deps.deterministic_socket_factory->AddSSLSocketDataProvider(&ssl1); + session_deps.deterministic_socket_factory->AddSocketDataProvider(data1.get()); + + SSLSocketDataProvider ssl2(ASYNC, OK); // to the server + ssl2.SetNextProto(kProtoSPDY2); + session_deps.deterministic_socket_factory->AddSSLSocketDataProvider(&ssl2); + session_deps.deterministic_socket_factory->AddSocketDataProvider(data2.get()); + + session_deps.host_resolver.reset(new MockCachingHostResolver()); + session_deps.host_resolver->rules()->AddRule("mail.google.com", ip_addr); + session_deps.host_resolver->rules()->AddRule("proxy", ip_addr); + + scoped_refptr<HttpNetworkSession> session( + SpdySessionDependencies::SpdyCreateSessionDeterministic(&session_deps)); + + // Start the first transaction to set up the SpdySession + HttpRequestInfo request1; + request1.method = "GET"; + request1.url = GURL(url1); + request1.priority = LOWEST; + request1.load_flags = 0; + HttpNetworkTransaction trans1(session); + TestCompletionCallback callback1; + ASSERT_EQ(ERR_IO_PENDING, + trans1.Start(&request1, callback1.callback(), BoundNetLog())); + data1->RunFor(3); + + ASSERT_TRUE(callback1.have_result()); + EXPECT_EQ(OK, callback1.WaitForResult()); + EXPECT_TRUE(trans1.GetResponseInfo()->was_fetched_via_spdy); + + // Now, start the HTTP request + HttpRequestInfo request2; + request2.method = "GET"; + request2.url = GURL(url2); + request2.priority = MEDIUM; + request2.load_flags = 0; + HttpNetworkTransaction trans2(session); + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, + trans2.Start(&request2, callback2.callback(), BoundNetLog())); + MessageLoop::current()->RunAllPending(); + data2->RunFor(3); + + ASSERT_TRUE(callback2.have_result()); + EXPECT_EQ(OK, callback2.WaitForResult()); + EXPECT_TRUE(trans2.GetResponseInfo()->was_fetched_via_spdy); +} + } // namespace net diff --git a/net/http/http_network_transaction_spdy3_unittest.cc b/net/http/http_network_transaction_spdy3_unittest.cc index cfafd7e..6818cd5 100644 --- a/net/http/http_network_transaction_spdy3_unittest.cc +++ b/net/http/http_network_transaction_spdy3_unittest.cc @@ -21,6 +21,7 @@ #include "base/utf_string_conversions.h" #include "net/base/auth.h" #include "net/base/capturing_net_log.h" +#include "net/base/cert_test_util.h" #include "net/base/completion_callback.h" #include "net/base/host_cache.h" #include "net/base/mock_cert_verifier.h" @@ -10049,4 +10050,130 @@ TEST_F(HttpNetworkTransactionSpdy3Test, UseSpdySessionForHttpWhenForced) { EXPECT_TRUE(trans2.GetResponseInfo()->was_fetched_via_spdy); } +// Test that in the case where we have a SPDY session to a SPDY proxy +// that we do not pool other origins that resolve to the same IP when +// the certificate does not match the new origin. +// http://crbug.com/134690 +TEST_F(HttpNetworkTransactionSpdy3Test, DoNotUseSpdySessionIfCertDoesNotMatch) { + const std::string url1 = "http://www.google.com/"; + const std::string url2 = "https://mail.google.com/"; + const std::string ip_addr = "1.2.3.4"; + + scoped_ptr<SpdyFrame> req1(ConstructSpdyGet(url1.c_str(), + false, 1, LOWEST)); + + MockWrite writes1[] = { + CreateMockWrite(*req1, 0), + }; + + scoped_ptr<SpdyFrame> resp1(ConstructSpdyGetSynReply(NULL, 0, 1)); + scoped_ptr<SpdyFrame> body1(ConstructSpdyBodyFrame(1, true)); + MockRead reads1[] = { + CreateMockRead(*resp1, 1), + CreateMockRead(*body1, 2), + MockRead(ASYNC, OK, 3) // EOF + }; + + scoped_ptr<DeterministicSocketData> data1( + new DeterministicSocketData(reads1, arraysize(reads1), + writes1, arraysize(writes1))); + IPAddressNumber ip; + ASSERT_TRUE(ParseIPLiteralToNumber(ip_addr, &ip)); + IPEndPoint peer_addr = IPEndPoint(ip, 443); + MockConnect connect_data1(ASYNC, OK, peer_addr); + data1->set_connect_data(connect_data1); + + // SPDY GET for HTTPS URL (direct) + scoped_ptr<SpdyFrame> req2(ConstructSpdyGet(url2.c_str(), + false, 1, MEDIUM)); + + MockWrite writes2[] = { + CreateMockWrite(*req2, 0), + }; + + scoped_ptr<SpdyFrame> resp2(ConstructSpdyGetSynReply(NULL, 0, 1)); + scoped_ptr<SpdyFrame> body2(ConstructSpdyBodyFrame(1, true)); + MockRead reads2[] = { + CreateMockRead(*resp2, 1), + CreateMockRead(*body2, 2), + MockRead(ASYNC, OK, 3) // EOF + }; + + scoped_ptr<DeterministicSocketData> data2( + new DeterministicSocketData(reads2, arraysize(reads2), + writes2, arraysize(writes2))); + MockConnect connect_data2(ASYNC, OK); + data2->set_connect_data(connect_data2); + + // Set up a proxy config that sends HTTP requests to a proxy, and + // all others direct. + ProxyConfig proxy_config; + proxy_config.proxy_rules().ParseFromString("http=https://proxy:443"); + CapturingProxyResolver* capturing_proxy_resolver = + new CapturingProxyResolver(); + SpdySessionDependencies session_deps(new ProxyService( + new ProxyConfigServiceFixed(proxy_config), capturing_proxy_resolver, + NULL)); + + // Load a valid cert. Note, that this does not need to + // be valid for proxy because the MockSSLClientSocket does + // not actually verify it. But SpdySession will use this + // to see if it is valid for the new origin + FilePath certs_dir = GetTestCertsDirectory(); + scoped_refptr<X509Certificate> server_cert( + ImportCertFromFile(certs_dir, "ok_cert.pem")); + ASSERT_NE(static_cast<X509Certificate*>(NULL), server_cert); + + SSLSocketDataProvider ssl1(ASYNC, OK); // to the proxy + ssl1.SetNextProto(kProtoSPDY3); + ssl1.cert = server_cert; + session_deps.deterministic_socket_factory->AddSSLSocketDataProvider(&ssl1); + session_deps.deterministic_socket_factory->AddSocketDataProvider(data1.get()); + + SSLSocketDataProvider ssl2(ASYNC, OK); // to the server + ssl2.SetNextProto(kProtoSPDY3); + session_deps.deterministic_socket_factory->AddSSLSocketDataProvider(&ssl2); + session_deps.deterministic_socket_factory->AddSocketDataProvider(data2.get()); + + session_deps.host_resolver.reset(new MockCachingHostResolver()); + session_deps.host_resolver->rules()->AddRule("mail.google.com", ip_addr); + session_deps.host_resolver->rules()->AddRule("proxy", ip_addr); + + scoped_refptr<HttpNetworkSession> session( + SpdySessionDependencies::SpdyCreateSessionDeterministic(&session_deps)); + + // Start the first transaction to set up the SpdySession + HttpRequestInfo request1; + request1.method = "GET"; + request1.url = GURL(url1); + request1.priority = LOWEST; + request1.load_flags = 0; + HttpNetworkTransaction trans1(session); + TestCompletionCallback callback1; + ASSERT_EQ(ERR_IO_PENDING, + trans1.Start(&request1, callback1.callback(), BoundNetLog())); + data1->RunFor(3); + + ASSERT_TRUE(callback1.have_result()); + EXPECT_EQ(OK, callback1.WaitForResult()); + EXPECT_TRUE(trans1.GetResponseInfo()->was_fetched_via_spdy); + + // Now, start the HTTP request + HttpRequestInfo request2; + request2.method = "GET"; + request2.url = GURL(url2); + request2.priority = MEDIUM; + request2.load_flags = 0; + HttpNetworkTransaction trans2(session); + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, + trans2.Start(&request2, callback2.callback(), BoundNetLog())); + MessageLoop::current()->RunAllPending(); + data2->RunFor(3); + + ASSERT_TRUE(callback2.have_result()); + EXPECT_EQ(OK, callback2.WaitForResult()); + EXPECT_TRUE(trans2.GetResponseInfo()->was_fetched_via_spdy); +} + } // namespace net diff --git a/net/http/http_proxy_client_socket.cc b/net/http/http_proxy_client_socket.cc index 826bf09..e3dfa06 100644 --- a/net/http/http_proxy_client_socket.cc +++ b/net/http/http_proxy_client_socket.cc @@ -200,6 +200,14 @@ base::TimeDelta HttpProxyClientSocket::GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } +bool HttpProxyClientSocket::WasNpnNegotiated() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->WasNpnNegotiated(); + } + NOTREACHED(); + return false; +} + NextProto HttpProxyClientSocket::GetNegotiatedProtocol() const { if (transport_.get() && transport_->socket()) { return transport_->socket()->GetNegotiatedProtocol(); @@ -208,6 +216,14 @@ NextProto HttpProxyClientSocket::GetNegotiatedProtocol() const { return kProtoUnknown; } +bool HttpProxyClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->GetSSLInfo(ssl_info); + } + NOTREACHED(); + return false; +} + int HttpProxyClientSocket::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { DCHECK(user_callback_.is_null()); diff --git a/net/http/http_proxy_client_socket.h b/net/http/http_proxy_client_socket.h index d02e9ab..3740aed 100644 --- a/net/http/http_proxy_client_socket.h +++ b/net/http/http_proxy_client_socket.h @@ -72,7 +72,9 @@ class HttpProxyClientSocket : public ProxyClientSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // Socket implementation. virtual int Read(IOBuffer* buf, diff --git a/net/http/http_stream_factory_impl_job.cc b/net/http/http_stream_factory_impl_job.cc index 5ef3379..cdab39b 100644 --- a/net/http/http_stream_factory_impl_job.cc +++ b/net/http/http_stream_factory_impl_job.cc @@ -800,7 +800,7 @@ int HttpStreamFactoryImpl::Job::DoInitConnectionComplete(int result) { if (ssl_started && (result == OK || IsCertificateError(result))) { SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(connection_->socket()); - if (ssl_socket->was_npn_negotiated()) { + if (ssl_socket->WasNpnNegotiated()) { was_npn_negotiated_ = true; std::string proto; std::string server_protos; diff --git a/net/socket/buffered_write_stream_socket.cc b/net/socket/buffered_write_stream_socket.cc index ed7ce2e..3119985 100644 --- a/net/socket/buffered_write_stream_socket.cc +++ b/net/socket/buffered_write_stream_socket.cc @@ -119,10 +119,18 @@ base::TimeDelta BufferedWriteStreamSocket::GetConnectTimeMicros() const { return wrapped_socket_->GetConnectTimeMicros(); } +bool BufferedWriteStreamSocket::WasNpnNegotiated() const { + return wrapped_socket_->WasNpnNegotiated(); +} + NextProto BufferedWriteStreamSocket::GetNegotiatedProtocol() const { return wrapped_socket_->GetNegotiatedProtocol(); } +bool BufferedWriteStreamSocket::GetSSLInfo(SSLInfo* ssl_info) { + return wrapped_socket_->GetSSLInfo(ssl_info); +} + void BufferedWriteStreamSocket::DoDelayedWrite() { int result = wrapped_socket_->Write( io_buffer_, io_buffer_->RemainingCapacity(), diff --git a/net/socket/buffered_write_stream_socket.h b/net/socket/buffered_write_stream_socket.h index 6d41c07..5651d2a 100644 --- a/net/socket/buffered_write_stream_socket.h +++ b/net/socket/buffered_write_stream_socket.h @@ -58,7 +58,9 @@ class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; private: void DoDelayedWrite(); diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index b53c775..363010a 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -116,9 +116,15 @@ class MockClientSocket : public StreamSocket { base::TimeDelta::FromMicroseconds(10); return kDummyConnectTimeMicros; // Dummy value. } + virtual bool WasNpnNegotiated() const { + return false; + } virtual NextProto GetNegotiatedProtocol() const { return kProtoUnknown; } + virtual bool GetSSLInfo(SSLInfo* ssl_info) { + return false; + } private: bool connected_; diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index ef6d38c..ff66830 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -704,10 +704,6 @@ const BoundNetLog& MockClientSocket::NetLog() const { return net_log_; } -void MockClientSocket::GetSSLInfo(SSLInfo* ssl_info) { - NOTREACHED(); -} - void MockClientSocket::GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) { } @@ -872,6 +868,14 @@ base::TimeDelta MockTCPClientSocket::GetConnectTimeMicros() const { return kTestingConnectTimeMicros; } +bool MockTCPClientSocket::WasNpnNegotiated() const { + return false; +} + +bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + return false; +} + void MockTCPClientSocket::OnReadComplete(const MockRead& data) { // There must be a read pending. DCHECK(pending_buf_); @@ -1071,6 +1075,14 @@ base::TimeDelta DeterministicMockTCPClientSocket::GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } +bool DeterministicMockTCPClientSocket::WasNpnNegotiated() const { + return false; +} + +bool DeterministicMockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + return false; +} + void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {} // static @@ -1158,11 +1170,12 @@ base::TimeDelta MockSSLClientSocket::GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } -void MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) { +bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->Reset(); ssl_info->cert = data_->cert; ssl_info->client_cert_sent = data_->client_cert_sent; ssl_info->channel_id_sent = data_->channel_id_sent; + return true; } void MockSSLClientSocket::GetSSLCertRequestInfo( @@ -1184,17 +1197,17 @@ SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto( return data_->next_proto_status; } -bool MockSSLClientSocket::was_npn_negotiated() const { - if (is_npn_state_set_) - return new_npn_value_; - return data_->was_npn_negotiated; -} - bool MockSSLClientSocket::set_was_npn_negotiated(bool negotiated) { is_npn_state_set_ = true; return new_npn_value_ = negotiated; } +bool MockSSLClientSocket::WasNpnNegotiated() const { + if (is_npn_state_set_) + return new_npn_value_; + return data_->was_npn_negotiated; +} + NextProto MockSSLClientSocket::GetNegotiatedProtocol() const { if (is_protocol_negotiated_set_) return protocol_negotiated_; diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 384452c..0c7e4cf 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -595,7 +595,6 @@ class MockClientSocket : public SSLClientSocket { virtual void SetOmniboxSpeculation() OVERRIDE {} // SSLClientSocket implementation. - virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; virtual void GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) OVERRIDE; virtual int ExportKeyingMaterial(const base::StringPiece& label, @@ -647,6 +646,8 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // AsyncSocket: virtual void OnReadComplete(const MockRead& data) OVERRIDE; @@ -705,6 +706,8 @@ class DeterministicMockTCPClientSocket virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // AsyncSocket: virtual void OnReadComplete(const MockRead& data) OVERRIDE; @@ -748,14 +751,15 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { virtual int64 NumBytesRead() const OVERRIDE; virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // SSLClientSocket implementation. - virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; virtual void GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) OVERRIDE; virtual NextProtoStatus GetNextProto(std::string* proto, std::string* server_protos) OVERRIDE; - virtual bool was_npn_negotiated() const OVERRIDE; + //virtual bool was_npn_negotiated() const OVERRIDE; virtual bool set_was_npn_negotiated(bool negotiated) OVERRIDE; virtual void set_protocol_negotiated( NextProto protocol_negotiated) OVERRIDE; diff --git a/net/socket/socks5_client_socket.cc b/net/socket/socks5_client_socket.cc index 1b020d7..409968c 100644 --- a/net/socket/socks5_client_socket.cc +++ b/net/socket/socks5_client_socket.cc @@ -158,6 +158,14 @@ base::TimeDelta SOCKS5ClientSocket::GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } +bool SOCKS5ClientSocket::WasNpnNegotiated() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->WasNpnNegotiated(); + } + NOTREACHED(); + return false; +} + NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const { if (transport_.get() && transport_->socket()) { return transport_->socket()->GetNegotiatedProtocol(); @@ -166,6 +174,15 @@ NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const { return kProtoUnknown; } +bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->GetSSLInfo(ssl_info); + } + NOTREACHED(); + return false; + +} + // Read is called by the transport layer above to read. This can only be done // if the SOCKS handshake is complete. int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len, diff --git a/net/socket/socks5_client_socket.h b/net/socket/socks5_client_socket.h index fa76be2..38810df 100644 --- a/net/socket/socks5_client_socket.h +++ b/net/socket/socks5_client_socket.h @@ -61,7 +61,9 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // Socket implementation. virtual int Read(IOBuffer* buf, diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc index 6776b71..2842fd1 100644 --- a/net/socket/socks_client_socket.cc +++ b/net/socket/socks_client_socket.cc @@ -180,6 +180,14 @@ base::TimeDelta SOCKSClientSocket::GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } +bool SOCKSClientSocket::WasNpnNegotiated() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->WasNpnNegotiated(); + } + NOTREACHED(); + return false; +} + NextProto SOCKSClientSocket::GetNegotiatedProtocol() const { if (transport_.get() && transport_->socket()) { return transport_->socket()->GetNegotiatedProtocol(); @@ -188,6 +196,15 @@ NextProto SOCKSClientSocket::GetNegotiatedProtocol() const { return kProtoUnknown; } +bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->GetSSLInfo(ssl_info); + } + NOTREACHED(); + return false; + +} + // Read is called by the transport layer above to read. This can only be done // if the SOCKS handshake is complete. int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, diff --git a/net/socket/socks_client_socket.h b/net/socket/socks_client_socket.h index 3f0a086..6e74409 100644 --- a/net/socket/socks_client_socket.h +++ b/net/socket/socks_client_socket.h @@ -58,7 +58,9 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // Socket implementation. virtual int Read(IOBuffer* buf, diff --git a/net/socket/ssl_client_socket.cc b/net/socket/ssl_client_socket.cc index 6bcc96e..7f23258 100644 --- a/net/socket/ssl_client_socket.cc +++ b/net/socket/ssl_client_socket.cc @@ -77,6 +77,10 @@ std::string SSLClientSocket::ServerProtosToString( return JoinString(server_protos_with_commas, ','); } +bool SSLClientSocket::WasNpnNegotiated() const { + return was_npn_negotiated_; +} + NextProto SSLClientSocket::GetNegotiatedProtocol() const { return protocol_negotiated_; } @@ -100,10 +104,6 @@ bool SSLClientSocket::IgnoreCertError(int error, int load_flags) { return false; } -bool SSLClientSocket::was_npn_negotiated() const { - return was_npn_negotiated_; -} - bool SSLClientSocket::set_was_npn_negotiated(bool negotiated) { return was_npn_negotiated_ = negotiated; } diff --git a/net/socket/ssl_client_socket.h b/net/socket/ssl_client_socket.h index 6748e6e..41ee087 100644 --- a/net/socket/ssl_client_socket.h +++ b/net/socket/ssl_client_socket.h @@ -69,20 +69,15 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { // the first protocol in our list. }; - // Gets the SSL connection information of the socket. - // - // TODO(sergeyu): Move this method to the SSLSocket interface and - // implemented in SSLServerSocket too. - virtual void GetSSLInfo(SSLInfo* ssl_info) = 0; + // StreamSocket: + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual NextProto GetNegotiatedProtocol() const OVERRIDE; // Gets the SSL CertificateRequest info of the socket after Connect failed // with ERR_SSL_CLIENT_AUTH_CERT_NEEDED. virtual void GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) = 0; - // StreamSocket: - virtual NextProto GetNegotiatedProtocol() const OVERRIDE; - // Get the application level protocol that we negotiated with the server. // *proto is set to the resulting protocol (n.b. that the string may have // embedded NULs). @@ -110,8 +105,6 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { // sessions. static void ClearSessionCache(); - virtual bool was_npn_negotiated() const; - virtual bool set_was_npn_negotiated(bool negotiated); virtual bool was_spdy_negotiated() const; diff --git a/net/socket/ssl_client_socket_mac.cc b/net/socket/ssl_client_socket_mac.cc index bdca223..ff43849 100644 --- a/net/socket/ssl_client_socket_mac.cc +++ b/net/socket/ssl_client_socket_mac.cc @@ -715,10 +715,10 @@ bool SSLClientSocketMac::SetSendBufferSize(int32 size) { return transport_->socket()->SetSendBufferSize(size); } -void SSLClientSocketMac::GetSSLInfo(SSLInfo* ssl_info) { +bool SSLClientSocketMac::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->Reset(); if (!server_cert_) - return; + return false; ssl_info->cert = server_cert_verify_result_.verified_cert; ssl_info->cert_status = server_cert_verify_result_.cert_status; @@ -741,6 +741,8 @@ void SSLClientSocketMac::GetSSLInfo(SSLInfo* ssl_info) { if (ssl_config_.version_fallback) ssl_info->connection_status |= SSL_CONNECTION_VERSION_FALLBACK; + + return true; } void SSLClientSocketMac::GetSSLCertRequestInfo( diff --git a/net/socket/ssl_client_socket_mac.h b/net/socket/ssl_client_socket_mac.h index f923f3a..d7ced45 100644 --- a/net/socket/ssl_client_socket_mac.h +++ b/net/socket/ssl_client_socket_mac.h @@ -40,7 +40,6 @@ class SSLClientSocketMac : public SSLClientSocket { virtual ~SSLClientSocketMac(); // SSLClientSocket implementation. - virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; virtual void GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) OVERRIDE; virtual int ExportKeyingMaterial(const base::StringPiece& label, @@ -66,6 +65,7 @@ class SSLClientSocketMac : public SSLClientSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // Socket implementation. virtual int Read(IOBuffer* buf, diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index 0c3ea4e..9d0eea2 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -2750,12 +2750,12 @@ void SSLClientSocket::ClearSessionCache() { SSL_ClearSessionCache(); } -void SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { +bool SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { EnterFunction(""); ssl_info->Reset(); if (core_->state().server_cert_chain.empty() || !core_->state().server_cert_chain[0]) { - return; + return false; } ssl_info->cert_status = server_cert_verify_result_.cert_status; @@ -2791,6 +2791,7 @@ void SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { SSLInfo::HANDSHAKE_RESUME : SSLInfo::HANDSHAKE_FULL; LeaveFunction(""); + return true; } void SSLClientSocketNSS::GetSSLCertRequestInfo( diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index 79a72fe..434b7c6 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -66,7 +66,6 @@ class SSLClientSocketNSS : public SSLClientSocket { virtual ~SSLClientSocketNSS(); // SSLClientSocket implementation. - virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; virtual void GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) OVERRIDE; virtual int ExportKeyingMaterial(const base::StringPiece& label, @@ -91,6 +90,7 @@ class SSLClientSocketNSS : public SSLClientSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // Socket implementation. virtual int Read(IOBuffer* buf, diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc index e350ded..85d0d65 100644 --- a/net/socket/ssl_client_socket_openssl.cc +++ b/net/socket/ssl_client_socket_openssl.cc @@ -591,10 +591,10 @@ int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl, // SSLClientSocket methods -void SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { +bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->Reset(); if (!server_cert_) - return; + return false; ssl_info->cert = server_cert_verify_result_.verified_cert; ssl_info->cert_status = server_cert_verify_result_.cert_status; @@ -631,6 +631,7 @@ void SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { << SSLConnectionStatusToCompression(ssl_info->connection_status) << " version = " << SSLConnectionStatusToVersion(ssl_info->connection_status); + return true; } void SSLClientSocketOpenSSL::GetSSLCertRequestInfo( diff --git a/net/socket/ssl_client_socket_openssl.h b/net/socket/ssl_client_socket_openssl.h index 129d30e..d113f82 100644 --- a/net/socket/ssl_client_socket_openssl.h +++ b/net/socket/ssl_client_socket_openssl.h @@ -55,7 +55,6 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { const unsigned char* in, unsigned int inlen); // SSLClientSocket implementation. - virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; virtual void GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) OVERRIDE; virtual int ExportKeyingMaterial(const base::StringPiece& label, @@ -81,6 +80,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, diff --git a/net/socket/ssl_client_socket_pool_unittest.cc b/net/socket/ssl_client_socket_pool_unittest.cc index 34b9951..1dd9e55 100644 --- a/net/socket/ssl_client_socket_pool_unittest.cc +++ b/net/socket/ssl_client_socket_pool_unittest.cc @@ -330,7 +330,7 @@ TEST_F(SSLClientSocketPoolTest, DirectWithNPN) { EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket()); - EXPECT_TRUE(ssl_socket->was_npn_negotiated()); + EXPECT_TRUE(ssl_socket->WasNpnNegotiated()); } TEST_F(SSLClientSocketPoolTest, DirectNoSPDY) { @@ -382,7 +382,7 @@ TEST_F(SSLClientSocketPoolTest, DirectGotSPDY) { EXPECT_TRUE(handle.socket()); SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket()); - EXPECT_TRUE(ssl_socket->was_npn_negotiated()); + EXPECT_TRUE(ssl_socket->WasNpnNegotiated()); std::string proto; std::string server_protos; ssl_socket->GetNextProto(&proto, &server_protos); @@ -414,7 +414,7 @@ TEST_F(SSLClientSocketPoolTest, DirectGotBonusSPDY) { EXPECT_TRUE(handle.socket()); SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket()); - EXPECT_TRUE(ssl_socket->was_npn_negotiated()); + EXPECT_TRUE(ssl_socket->WasNpnNegotiated()); std::string proto; std::string server_protos; ssl_socket->GetNextProto(&proto, &server_protos); @@ -714,7 +714,7 @@ TEST_F(SSLClientSocketPoolTest, IPPooling) { EXPECT_TRUE(handle->socket()); SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle->socket()); - EXPECT_TRUE(ssl_socket->was_npn_negotiated()); + EXPECT_TRUE(ssl_socket->WasNpnNegotiated()); std::string proto; std::string server_protos; ssl_socket->GetNextProto(&proto, &server_protos); @@ -793,7 +793,7 @@ void SSLClientSocketPoolTest::TestIPPoolingDisabled( EXPECT_TRUE(handle->socket()); SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle->socket()); - EXPECT_TRUE(ssl_socket->was_npn_negotiated()); + EXPECT_TRUE(ssl_socket->WasNpnNegotiated()); std::string proto; std::string server_protos; ssl_socket->GetNextProto(&proto, &server_protos); diff --git a/net/socket/ssl_client_socket_win.cc b/net/socket/ssl_client_socket_win.cc index 3edad6b..d997380 100644 --- a/net/socket/ssl_client_socket_win.cc +++ b/net/socket/ssl_client_socket_win.cc @@ -404,10 +404,10 @@ SSLClientSocketWin::~SSLClientSocketWin() { Disconnect(); } -void SSLClientSocketWin::GetSSLInfo(SSLInfo* ssl_info) { +bool SSLClientSocketWin::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->Reset(); if (!server_cert_) - return; + return false; ssl_info->cert = server_cert_verify_result_.verified_cert; ssl_info->cert_status = server_cert_verify_result_.cert_status; @@ -448,6 +448,8 @@ void SSLClientSocketWin::GetSSLInfo(SSLInfo* ssl_info) { if (ssl_config_.version_fallback) ssl_info->connection_status |= SSL_CONNECTION_VERSION_FALLBACK; + + return true; } void SSLClientSocketWin::GetSSLCertRequestInfo( diff --git a/net/socket/ssl_client_socket_win.h b/net/socket/ssl_client_socket_win.h index 9013a9e..f5c0a4d 100644 --- a/net/socket/ssl_client_socket_win.h +++ b/net/socket/ssl_client_socket_win.h @@ -45,7 +45,6 @@ class SSLClientSocketWin : public SSLClientSocket { ~SSLClientSocketWin(); // SSLClientSocket implementation. - virtual void GetSSLInfo(SSLInfo* ssl_info); virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info); virtual int ExportKeyingMaterial(const base::StringPiece& label, bool has_context, @@ -70,6 +69,7 @@ class SSLClientSocketWin : public SSLClientSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, diff --git a/net/socket/ssl_server_socket_nss.cc b/net/socket/ssl_server_socket_nss.cc index 35ccdd6..a43dbfe 100644 --- a/net/socket/ssl_server_socket_nss.cc +++ b/net/socket/ssl_server_socket_nss.cc @@ -279,11 +279,20 @@ base::TimeDelta SSLServerSocketNSS::GetConnectTimeMicros() const { return transport_socket_->GetConnectTimeMicros(); } +bool SSLServerSocketNSS::WasNpnNegotiated() const { + return false; +} + NextProto SSLServerSocketNSS::GetNegotiatedProtocol() const { // NPN is not supported by this class. return kProtoUnknown; } +bool SSLServerSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { + NOTIMPLEMENTED(); + return false; +} + int SSLServerSocketNSS::InitializeSSLOptions() { // Transport connected, now hook it up to nss // TODO(port): specify rx and tx buffer sizes separately diff --git a/net/socket/ssl_server_socket_nss.h b/net/socket/ssl_server_socket_nss.h index e7da5ac..ba55649 100644 --- a/net/socket/ssl_server_socket_nss.h +++ b/net/socket/ssl_server_socket_nss.h @@ -60,7 +60,9 @@ class SSLServerSocketNSS : public SSLServerSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; private: enum State { diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc index 03a6db0..f93d7c6 100644 --- a/net/socket/ssl_server_socket_unittest.cc +++ b/net/socket/ssl_server_socket_unittest.cc @@ -236,10 +236,18 @@ class FakeSocket : public StreamSocket { return base::TimeDelta::FromMicroseconds(-1); } + virtual bool WasNpnNegotiated() const { + return false; + } + virtual NextProto GetNegotiatedProtocol() const { return kProtoUnknown; } + virtual bool GetSSLInfo(SSLInfo* ssl_info) { + return false; + } + private: net::BoundNetLog net_log_; FakeDataChannel* incoming_; diff --git a/net/socket/stream_socket.h b/net/socket/stream_socket.h index a513099..9a7d153 100644 --- a/net/socket/stream_socket.h +++ b/net/socket/stream_socket.h @@ -14,6 +14,7 @@ namespace net { class AddressList; class IPEndPoint; +class SSLInfo; class NET_EXPORT_PRIVATE StreamSocket : public Socket { public: @@ -86,10 +87,17 @@ class NET_EXPORT_PRIVATE StreamSocket : public Socket { // Returns the connection setup time of this socket. virtual base::TimeDelta GetConnectTimeMicros() const = 0; + // Returns true if NPN was negotiated during the connection of this socket. + virtual bool WasNpnNegotiated() const = 0; + // Returns the protocol negotiated via NPN for this socket, or // kProtoUnknown will be returned if NPN is not applicable. virtual NextProto GetNegotiatedProtocol() const = 0; + // Gets the SSL connection information of the socket. Returns false if + // SSL was not used by this socket. + virtual bool GetSSLInfo(SSLInfo* ssl_info) = 0; + protected: // The following class is only used to gather statistics about the history of // a socket. It is only instantiated and used in basic sockets, such as diff --git a/net/socket/tcp_client_socket_libevent.cc b/net/socket/tcp_client_socket_libevent.cc index d139a54..e54eb1e 100644 --- a/net/socket/tcp_client_socket_libevent.cc +++ b/net/socket/tcp_client_socket_libevent.cc @@ -758,8 +758,16 @@ base::TimeDelta TCPClientSocketLibevent::GetConnectTimeMicros() const { return connect_time_micros_; } +bool TCPClientSocketLibevent::WasNpnNegotiated() const { + return false; +} + NextProto TCPClientSocketLibevent::GetNegotiatedProtocol() const { return kProtoUnknown; } +bool TCPClientSocketLibevent::GetSSLInfo(SSLInfo* ssl_info) { + return false; +} + } // namespace net diff --git a/net/socket/tcp_client_socket_libevent.h b/net/socket/tcp_client_socket_libevent.h index 653ffe3..8ddb61b 100644 --- a/net/socket/tcp_client_socket_libevent.h +++ b/net/socket/tcp_client_socket_libevent.h @@ -55,7 +55,9 @@ class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // Socket implementation. // Multiple outstanding requests are not supported. diff --git a/net/socket/tcp_client_socket_win.cc b/net/socket/tcp_client_socket_win.cc index aed2adb..74fa99e 100644 --- a/net/socket/tcp_client_socket_win.cc +++ b/net/socket/tcp_client_socket_win.cc @@ -684,10 +684,18 @@ base::TimeDelta TCPClientSocketWin::GetConnectTimeMicros() const { return connect_time_micros_; } +bool TCPClientSocketWin::WasNpnNegotiated() const { + return false; +} + NextProto TCPClientSocketWin::GetNegotiatedProtocol() const { return kProtoUnknown; } +bool TCPClientSocketWin::GetSSLInfo(SSLInfo* ssl_info) { + return false; +} + int TCPClientSocketWin::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { diff --git a/net/socket/tcp_client_socket_win.h b/net/socket/tcp_client_socket_win.h index 1ca957f..9e95aae 100644 --- a/net/socket/tcp_client_socket_win.h +++ b/net/socket/tcp_client_socket_win.h @@ -54,7 +54,9 @@ class NET_EXPORT TCPClientSocketWin : public StreamSocket, virtual bool UsingTCPFastOpen() const; virtual int64 NumBytesRead() const; virtual base::TimeDelta GetConnectTimeMicros() const; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // Socket implementation. // Multiple outstanding requests are not supported. diff --git a/net/socket/transport_client_socket_pool_unittest.cc b/net/socket/transport_client_socket_pool_unittest.cc index 93e7d11..d7ff4e6 100644 --- a/net/socket/transport_client_socket_pool_unittest.cc +++ b/net/socket/transport_client_socket_pool_unittest.cc @@ -89,9 +89,15 @@ class MockClientSocket : public StreamSocket { virtual base::TimeDelta GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } + virtual bool WasNpnNegotiated() const { + return false; + } virtual NextProto GetNegotiatedProtocol() const { return kProtoUnknown; } + virtual bool GetSSLInfo(SSLInfo* ssl_info) { + return false; + } // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, @@ -146,9 +152,15 @@ class MockFailingClientSocket : public StreamSocket { virtual base::TimeDelta GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } + virtual bool WasNpnNegotiated() const { + return false; + } virtual NextProto GetNegotiatedProtocol() const { return kProtoUnknown; } + virtual bool GetSSLInfo(SSLInfo* ssl_info) { + return false; + } // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, @@ -228,9 +240,15 @@ class MockPendingClientSocket : public StreamSocket { virtual base::TimeDelta GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } + virtual bool WasNpnNegotiated() const { + return false; + } virtual NextProto GetNegotiatedProtocol() const { return kProtoUnknown; } + virtual bool GetSSLInfo(SSLInfo* ssl_info) { + return false; + } // Socket implementation. virtual int Read(IOBuffer* buf, int buf_len, diff --git a/net/spdy/spdy_proxy_client_socket.cc b/net/spdy/spdy_proxy_client_socket.cc index 759c84cf..251ada2 100644 --- a/net/spdy/spdy_proxy_client_socket.cc +++ b/net/spdy/spdy_proxy_client_socket.cc @@ -168,10 +168,21 @@ base::TimeDelta SpdyProxyClientSocket::GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } +bool SpdyProxyClientSocket::WasNpnNegotiated() const { + return false; +} + NextProto SpdyProxyClientSocket::GetNegotiatedProtocol() const { return kProtoUnknown; } +bool SpdyProxyClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + bool was_npn_negotiated; + NextProto protocol_negotiated; + return spdy_stream_->GetSSLInfo(ssl_info, &was_npn_negotiated, + &protocol_negotiated); +} + int SpdyProxyClientSocket::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { DCHECK(read_callback_.is_null()); diff --git a/net/spdy/spdy_proxy_client_socket.h b/net/spdy/spdy_proxy_client_socket.h index 70f6ae6..3859c61 100644 --- a/net/spdy/spdy_proxy_client_socket.h +++ b/net/spdy/spdy_proxy_client_socket.h @@ -74,7 +74,9 @@ class NET_EXPORT_PRIVATE SpdyProxyClientSocket : public ProxyClientSocket, virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // Socket implementation. virtual int Read(IOBuffer* buf, diff --git a/net/spdy/spdy_session.cc b/net/spdy/spdy_session.cc index 5528617..d0190df 100644 --- a/net/spdy/spdy_session.cc +++ b/net/spdy/spdy_session.cc @@ -1191,15 +1191,10 @@ scoped_refptr<SpdyStream> SpdySession::GetActivePushStream( bool SpdySession::GetSSLInfo(SSLInfo* ssl_info, bool* was_npn_negotiated, NextProto* protocol_negotiated) { - if (!is_secure_) { - *protocol_negotiated = kProtoUnknown; - return false; - } - SSLClientSocket* ssl_socket = GetSSLClientSocket(); - ssl_socket->GetSSLInfo(ssl_info); - *was_npn_negotiated = ssl_socket->was_npn_negotiated(); - *protocol_negotiated = ssl_socket->GetNegotiatedProtocol(); - return true; + + *was_npn_negotiated = connection_->socket()->WasNpnNegotiated(); + *protocol_negotiated = connection_->socket()->GetNegotiatedProtocol(); + return connection_->socket()->GetSSLInfo(ssl_info); } bool SpdySession::GetSSLCertRequestInfo( diff --git a/remoting/jingle_glue/ssl_socket_adapter.cc b/remoting/jingle_glue/ssl_socket_adapter.cc index 4ff09a5c9..08ba785 100644 --- a/remoting/jingle_glue/ssl_socket_adapter.cc +++ b/remoting/jingle_glue/ssl_socket_adapter.cc @@ -349,11 +349,21 @@ base::TimeDelta TransportSocket::GetConnectTimeMicros() const { return base::TimeDelta::FromMicroseconds(-1); } +bool TransportSocket::WasNpnNegotiated() const { + NOTREACHED(); + return false; +} + net::NextProto TransportSocket::GetNegotiatedProtocol() const { NOTREACHED(); return net::kProtoUnknown; } +bool TransportSocket::GetSSLInfo(net::SSLInfo* ssl_info) { + NOTREACHED(); + return false; +} + int TransportSocket::Read(net::IOBuffer* buf, int buf_len, const net::CompletionCallback& callback) { DCHECK(buf); diff --git a/remoting/jingle_glue/ssl_socket_adapter.h b/remoting/jingle_glue/ssl_socket_adapter.h index c95ac3a..e62c048 100644 --- a/remoting/jingle_glue/ssl_socket_adapter.h +++ b/remoting/jingle_glue/ssl_socket_adapter.h @@ -54,7 +54,9 @@ class TransportSocket : public net::StreamSocket, public sigslot::has_slots<> { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE; // net::Socket implementation. virtual int Read(net::IOBuffer* buf, int buf_len, diff --git a/remoting/protocol/fake_session.cc b/remoting/protocol/fake_session.cc index de6f7cf..69f87db 100644 --- a/remoting/protocol/fake_session.cc +++ b/remoting/protocol/fake_session.cc @@ -165,11 +165,19 @@ base::TimeDelta FakeSocket::GetConnectTimeMicros() const { return base::TimeDelta(); } +bool FakeSocket::WasNpnNegotiated() const { + return false; +} + net::NextProto FakeSocket::GetNegotiatedProtocol() const { NOTIMPLEMENTED(); return net::kProtoUnknown; } +bool FakeSocket::GetSSLInfo(net::SSLInfo* ssl_info) { + return false; +} + FakeUdpSocket::FakeUdpSocket() : read_pending_(false), input_pos_(0), diff --git a/remoting/protocol/fake_session.h b/remoting/protocol/fake_session.h index 4000cbc..5b59d90 100644 --- a/remoting/protocol/fake_session.h +++ b/remoting/protocol/fake_session.h @@ -69,7 +69,9 @@ class FakeSocket : public net::StreamSocket { virtual bool UsingTCPFastOpen() const OVERRIDE; virtual int64 NumBytesRead() const OVERRIDE; virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE; private: int next_read_error_; |