diff options
Diffstat (limited to 'net/socket')
-rw-r--r-- | net/socket/socket_test_util.cc | 6 | ||||
-rw-r--r-- | net/socket/socket_test_util.h | 4 | ||||
-rw-r--r-- | net/socket/ssl_client_socket.h | 10 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_nss.cc | 5 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_nss.h | 5 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_openssl.cc | 165 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_openssl.h | 7 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_unittest.cc | 155 |
8 files changed, 292 insertions, 65 deletions
diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index bed8a0c..f1456f8 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -782,6 +782,12 @@ MockClientSocket::GetNextProto(std::string* proto, std::string* server_protos) { return SSLClientSocket::kNextProtoUnsupported; } +scoped_refptr<X509Certificate> +MockClientSocket::GetUnverifiedServerCertificateChain() const { + NOTREACHED(); + return NULL; +} + MockClientSocket::~MockClientSocket() {} void MockClientSocket::RunCallbackAsync(const CompletionCallback& callback, diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index e4f6553..17c98aa 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -707,6 +707,10 @@ class MockClientSocket : public SSLClientSocket { void RunCallbackAsync(const CompletionCallback& callback, int result); void RunCallback(const CompletionCallback& callback, int result); + // SSLClientSocket implementation. + virtual scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain() + const OVERRIDE; + // True if Connect completed successfully and Disconnect hasn't been called. bool connected_; diff --git a/net/socket/ssl_client_socket.h b/net/socket/ssl_client_socket.h index 410062d..a43e58c 100644 --- a/net/socket/ssl_client_socket.h +++ b/net/socket/ssl_client_socket.h @@ -23,6 +23,7 @@ class SSLCertRequestInfo; struct SSLConfig; class SSLInfo; class TransportSecurityState; +class X509Certificate; // This struct groups together several fields which are used by various // classes related to SSLClientSocket. @@ -154,6 +155,13 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { const SSLConfig& ssl_config, ServerBoundCertService* server_bound_cert_service); + // For unit testing only. + // Returns the unverified certificate chain as presented by server. + // Note that chain may be different than the verified chain returned by + // StreamSocket::GetSSLInfo(). + virtual scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain() + const = 0; + private: // For signed_cert_timestamps_received_ and stapled_ocsp_response_received_. FRIEND_TEST_ALL_PREFIXES(SSLClientSocketTest, @@ -162,6 +170,8 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { ConnectSignedCertTimestampsEnabledOCSP); FRIEND_TEST_ALL_PREFIXES(SSLClientSocketTest, ConnectSignedCertTimestampsDisabled); + FRIEND_TEST_ALL_PREFIXES(SSLClientSocketTest, + VerifyServerChainProperlyOrdered); // True if NPN was responded to, independent of selecting SPDY or HTTP. bool was_npn_negotiated_; diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index 598bb7c..ca5a689 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -3589,6 +3589,11 @@ void SSLClientSocketNSS::AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const { } } +scoped_refptr<X509Certificate> +SSLClientSocketNSS::GetUnverifiedServerCertificateChain() const { + return core_->state().server_cert.get(); +} + ServerBoundCertService* SSLClientSocketNSS::GetServerBoundCertService() const { return server_bound_cert_service_; } diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index 47fed74..acd5d37 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -106,6 +106,11 @@ class SSLClientSocketNSS : public SSLClientSocket { virtual bool SetSendBufferSize(int32 size) OVERRIDE; virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE; + protected: + // SSLClientSocket implementation. + virtual scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain() + const OVERRIDE; + private: // Helper class to handle marshalling any NSS interaction to and from the // NSS and network task runners. Not every call needs to happen on the Core diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc index ee07f19..d04670f 100644 --- a/net/socket/ssl_client_socket_openssl.cc +++ b/net/socket/ssl_client_socket_openssl.cc @@ -326,6 +326,140 @@ class SSLClientSocketOpenSSL::SSLContext { SSLSessionCacheOpenSSL session_cache_; }; +// PeerCertificateChain is a helper object which extracts the certificate +// chain, as given by the server, from an OpenSSL socket and performs the needed +// resource management. The first element of the chain is the leaf certificate +// and the other elements are in the order given by the server. +class SSLClientSocketOpenSSL::PeerCertificateChain { + public: + explicit PeerCertificateChain(SSL* ssl) { Reset(ssl); } + PeerCertificateChain(const PeerCertificateChain& other) { *this = other; } + ~PeerCertificateChain() {} + PeerCertificateChain& operator=(const PeerCertificateChain& other); + + // Resets the PeerCertificateChain to the set of certificates supplied by the + // peer of |ssl|, which may be NULL, indicating to empty the store + // certificates. Note: If an error occurs, such as being unable to parse the + // certificates, this will behave as if Reset(NULL) was called. + void Reset(SSL* ssl); + // Note that when USE_OPENSSL is defined, OSCertHandle is X509* + const scoped_refptr<X509Certificate>& AsOSChain() const { return os_chain_; } + + size_t size() const { + if (!openssl_chain_.get()) + return 0; + return sk_X509_num(openssl_chain_.get()); + } + + X509* operator[](size_t index) const { + DCHECK_LT(index, size()); + return sk_X509_value(openssl_chain_.get(), index); + } + + private: + static void FreeX509Stack(STACK_OF(X509)* cert_chain) { + sk_X509_pop_free(cert_chain, X509_free); + } + + friend class crypto::ScopedOpenSSL<STACK_OF(X509), FreeX509Stack>; + + crypto::ScopedOpenSSL<STACK_OF(X509), FreeX509Stack> openssl_chain_; + + scoped_refptr<X509Certificate> os_chain_; +}; + +SSLClientSocketOpenSSL::PeerCertificateChain& +SSLClientSocketOpenSSL::PeerCertificateChain::operator=( + const PeerCertificateChain& other) { + if (this == &other) + return *this; + + // os_chain_ is reference counted by scoped_refptr; + os_chain_ = other.os_chain_; + + // Must increase the reference count manually for sk_X509_dup + openssl_chain_.reset(sk_X509_dup(other.openssl_chain_.get())); + for (int i = 0; i < sk_X509_num(openssl_chain_.get()); ++i) { + X509* x = sk_X509_value(openssl_chain_.get(), i); + CRYPTO_add(&x->references, 1, CRYPTO_LOCK_X509); + } + return *this; +} + +#if defined(USE_OPENSSL) +// When OSCertHandle is typedef'ed to X509, this implementation does a short cut +// to avoid converting back and forth between der and X509 struct. +void SSLClientSocketOpenSSL::PeerCertificateChain::Reset(SSL* ssl) { + openssl_chain_.reset(NULL); + os_chain_ = NULL; + + if (ssl == NULL) + return; + + STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl); + if (!chain) + return; + + X509Certificate::OSCertHandles intermediates; + for (int i = 1; i < sk_X509_num(chain); ++i) + intermediates.push_back(sk_X509_value(chain, i)); + + os_chain_ = + X509Certificate::CreateFromHandle(sk_X509_value(chain, 0), intermediates); + + // sk_X509_dup does not increase reference count on the certs in the stack. + openssl_chain_.reset(sk_X509_dup(chain)); + + std::vector<base::StringPiece> der_chain; + for (int i = 0; i < sk_X509_num(openssl_chain_.get()); ++i) { + X509* x = sk_X509_value(openssl_chain_.get(), i); + // Increase the reference count for the certs in openssl_chain_. + CRYPTO_add(&x->references, 1, CRYPTO_LOCK_X509); + } +} +#else // !defined(USE_OPENSSL) +void SSLClientSocketOpenSSL::PeerCertificateChain::Reset(SSL* ssl) { + openssl_chain_.reset(NULL); + os_chain_ = NULL; + + if (ssl == NULL) + return; + + STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl); + if (!chain) + return; + + // sk_X509_dup does not increase reference count on the certs in the stack. + openssl_chain_.reset(sk_X509_dup(chain)); + + std::vector<base::StringPiece> der_chain; + for (int i = 0; i < sk_X509_num(openssl_chain_.get()); ++i) { + X509* x = sk_X509_value(openssl_chain_.get(), i); + + // Increase the reference count for the certs in openssl_chain_. + CRYPTO_add(&x->references, 1, CRYPTO_LOCK_X509); + + unsigned char* cert_data = NULL; + int cert_data_length = i2d_X509(x, &cert_data); + if (cert_data_length && cert_data) + der_chain.push_back(base::StringPiece(reinterpret_cast<char*>(cert_data), + cert_data_length)); + } + + os_chain_ = X509Certificate::CreateFromDERCertChain(der_chain); + + for (size_t i = 0; i < der_chain.size(); ++i) { + OPENSSL_free(const_cast<char*>(der_chain[i].data())); + } + + if (der_chain.size() != + static_cast<size_t>(sk_X509_num(openssl_chain_.get()))) { + openssl_chain_.reset(NULL); + os_chain_ = NULL; + } +} +#endif // USE_OPENSSL + // static SSLSessionCacheOpenSSL::Config SSLClientSocketOpenSSL::SSLContext::kDefaultSessionCacheConfig = { @@ -354,6 +488,7 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( weak_factory_(this), pending_read_error_(kNoPendingReadResult), transport_write_error_(OK), + server_cert_chain_(new PeerCertificateChain(NULL)), completed_handshake_(false), client_auth_cert_needed_(false), cert_verifier_(context.cert_verifier), @@ -369,8 +504,7 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( npn_status_(kNextProtoUnsupported), channel_id_request_return_value_(ERR_UNEXPECTED), channel_id_xtn_negotiated_(false), - net_log_(transport_->socket()->NetLog()) { -} + net_log_(transport_->socket()->NetLog()) {} SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() { Disconnect(); @@ -924,26 +1058,8 @@ void SSLClientSocketOpenSSL::DoConnectCallback(int rv) { } X509Certificate* SSLClientSocketOpenSSL::UpdateServerCert() { - if (server_cert_.get()) - return server_cert_.get(); - - crypto::ScopedOpenSSL<X509, X509_free> cert(SSL_get_peer_certificate(ssl_)); - if (!cert.get()) { - LOG(WARNING) << "SSL_get_peer_certificate returned NULL"; - return NULL; - } - - // Unlike SSL_get_peer_certificate, SSL_get_peer_cert_chain does not - // increment the reference so sk_X509_free does not need to be called. - STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_); - X509Certificate::OSCertHandles intermediates; - if (chain) { - for (int i = 0; i < sk_X509_num(chain); ++i) - intermediates.push_back(sk_X509_value(chain, i)); - } - server_cert_ = X509Certificate::CreateFromHandle(cert.get(), intermediates); - DCHECK(server_cert_.get()); - + server_cert_chain_->Reset(ssl_); + server_cert_ = server_cert_chain_->AsOSChain(); return server_cert_.get(); } @@ -1447,4 +1563,9 @@ int SSLClientSocketOpenSSL::SelectNextProtoCallback(unsigned char** out, return SSL_TLSEXT_ERR_OK; } +scoped_refptr<X509Certificate> +SSLClientSocketOpenSSL::GetUnverifiedServerCertificateChain() const { + return server_cert_; +} + } // namespace net diff --git a/net/socket/ssl_client_socket_openssl.h b/net/socket/ssl_client_socket_openssl.h index 0fc9cbe..8952fef 100644 --- a/net/socket/ssl_client_socket_openssl.h +++ b/net/socket/ssl_client_socket_openssl.h @@ -92,7 +92,13 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; virtual bool SetSendBufferSize(int32 size) OVERRIDE; + protected: + // SSLClientSocket implementation. + virtual scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain() + const OVERRIDE; + private: + class PeerCertificateChain; class SSLContext; friend class SSLClientSocket; friend class SSLContext; @@ -176,6 +182,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { int transport_write_error_; // Set when handshake finishes. + scoped_ptr<PeerCertificateChain> server_cert_chain_; scoped_refptr<X509Certificate> server_cert_; CertVerifyResult server_cert_verify_result_; bool completed_handshake_; diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc index af05a39..20ba896 100644 --- a/net/socket/ssl_client_socket_unittest.cc +++ b/net/socket/ssl_client_socket_unittest.cc @@ -524,6 +524,47 @@ class SSLClientSocketTest : public PlatformTest { SSLClientSocketContext context_; }; +// Verifies the correctness of GetSSLCertRequestInfo. +class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest { + protected: + // Creates a test server with the given SSLOptions, connects to it and returns + // the SSLCertRequestInfo reported by the socket. + scoped_refptr<SSLCertRequestInfo> GetCertRequest( + SpawnedTestServer::SSLOptions ssl_options) { + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); + if (!test_server.Start()) + return NULL; + + AddressList addr; + if (!test_server.GetAddressList(&addr)) + return NULL; + + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + EXPECT_FALSE(sock->IsConnected()); + + rv = sock->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo(); + sock->GetSSLCertRequestInfo(request_info.get()); + sock->Disconnect(); + EXPECT_FALSE(sock->IsConnected()); + + return request_info; + } +}; + //----------------------------------------------------------------------------- // LogContainsSSLConnectEndEvent returns true if the given index in the given @@ -541,6 +582,8 @@ static bool LogContainsSSLConnectEndEvent( log, i, NetLog::TYPE_SOCKET_BYTES_SENT, NetLog::PHASE_NONE); } +} // namespace + TEST_F(SSLClientSocketTest, Connect) { SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, SpawnedTestServer::kLocalhost, @@ -1708,6 +1751,75 @@ TEST(SSLClientSocket, ClearSessionCache) { SSLClientSocket::ClearSessionCache(); } +// Test that the server certificates are properly retrieved from the underlying +// SSL stack. +TEST_F(SSLClientSocketTest, VerifyServerChainProperlyOrdered) { + // The connection does not have to be successful. + cert_verifier_->set_default_result(ERR_CERT_INVALID); + + // Set up a test server with CERT_CHAIN_WRONG_ROOT. + // This makes the server present redundant-server-chain.pem, which contains + // intermediate certificates. + SpawnedTestServer::SSLOptions ssl_options( + SpawnedTestServer::SSLOptions::CERT_CHAIN_WRONG_ROOT); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + rv = callback.GetResult(rv); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + EXPECT_FALSE(sock->IsConnected()); + rv = sock->Connect(callback.callback()); + rv = callback.GetResult(rv); + + EXPECT_EQ(ERR_CERT_INVALID, rv); + EXPECT_TRUE(sock->IsConnected()); + + // When given option CERT_CHAIN_WRONG_ROOT, SpawnedTestServer will present + // certs from redundant-server-chain.pem. + CertificateList server_certs = + CreateCertificateListFromFile(GetTestCertsDirectory(), + "redundant-server-chain.pem", + X509Certificate::FORMAT_AUTO); + + // Get the server certificate as received client side. + scoped_refptr<X509Certificate> server_certificate = + sock->GetUnverifiedServerCertificateChain(); + + // Get the intermediates as received client side. + const X509Certificate::OSCertHandles& server_intermediates = + server_certificate->GetIntermediateCertificates(); + + // Check that the unverified server certificate chain is properly retrieved + // from the underlying ssl stack. + ASSERT_EQ(4U, server_certs.size()); + + EXPECT_TRUE(X509Certificate::IsSameOSCert( + server_certificate->os_cert_handle(), server_certs[0]->os_cert_handle())); + + ASSERT_EQ(3U, server_intermediates.size()); + + EXPECT_TRUE(X509Certificate::IsSameOSCert(server_intermediates[0], + server_certs[1]->os_cert_handle())); + EXPECT_TRUE(X509Certificate::IsSameOSCert(server_intermediates[1], + server_certs[2]->os_cert_handle())); + EXPECT_TRUE(X509Certificate::IsSameOSCert(server_intermediates[2], + server_certs[3]->os_cert_handle())); + + sock->Disconnect(); + EXPECT_FALSE(sock->IsConnected()); +} + // This tests that SSLInfo contains a properly re-constructed certificate // chain. That, in turn, verifies that GetSSLInfo is giving us the chain as // verified, not the chain as served by the server. (They may be different.) @@ -1806,47 +1918,6 @@ TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) { EXPECT_FALSE(sock->IsConnected()); } -// Verifies the correctness of GetSSLCertRequestInfo. -class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest { - protected: - // Creates a test server with the given SSLOptions, connects to it and returns - // the SSLCertRequestInfo reported by the socket. - scoped_refptr<SSLCertRequestInfo> GetCertRequest( - SpawnedTestServer::SSLOptions ssl_options) { - SpawnedTestServer test_server( - SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); - if (!test_server.Start()) - return NULL; - - AddressList addr; - if (!test_server.GetAddressList(&addr)) - return NULL; - - TestCompletionCallback callback; - CapturingNetLog log; - scoped_ptr<StreamSocket> transport( - new TCPClientSocket(addr, &log, NetLog::Source())); - int rv = transport->Connect(callback.callback()); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); - EXPECT_EQ(OK, rv); - - scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); - EXPECT_FALSE(sock->IsConnected()); - - rv = sock->Connect(callback.callback()); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); - scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo(); - sock->GetSSLCertRequestInfo(request_info.get()); - sock->Disconnect(); - EXPECT_FALSE(sock->IsConnected()); - - return request_info; - } -}; - TEST_F(SSLClientSocketCertRequestInfoTest, NoAuthorities) { SpawnedTestServer::SSLOptions ssl_options; ssl_options.request_client_certificate = true; @@ -1898,8 +1969,6 @@ TEST_F(SSLClientSocketCertRequestInfoTest, TwoAuthorities) { request_info->cert_authorities[1]); } -} // namespace - TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledTLSExtension) { SpawnedTestServer::SSLOptions ssl_options; ssl_options.signed_cert_timestamps_tls_ext = "test"; |