From f61c397ae7c8d07762b02d6578928163e2a8eca0 Mon Sep 17 00:00:00 2001 From: "hclam@chromium.org" Date: Thu, 23 Dec 2010 09:54:15 +0000 Subject: Defines SSLServerSocket and implements SSLServerSocketNSS Defines a SSLServerSocket interface. Implement this interface using NSS as SSLServerSocketNSS. This is the first version of the code. It disables several functions of NSS like caching, session ticket, reneogotiation, etc. This is implemented to suit the needs of Chromoting. Additional features of this socket will be added when necessary. BUG=None TEST=None Review URL: http://codereview.chromium.org/5746003 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@70041 0039d316-1c4b-4281-b951-d872f2087c98 --- net/base/net_log_event_type_list.h | 3 + net/base/ssl_config_service.cc | 5 +- net/base/ssl_config_service.h | 5 + net/base/x509_certificate.h | 5 + net/base/x509_certificate_mac.cc | 11 + net/base/x509_certificate_nss.cc | 9 + net/base/x509_certificate_openssl.cc | 5 + net/base/x509_certificate_unittest.cc | 12 + net/base/x509_certificate_win.cc | 9 + net/net.gyp | 8 + net/socket/nss_ssl_util.cc | 240 +++++++++++ net/socket/nss_ssl_util.h | 36 ++ net/socket/ssl_client_socket_nss.cc | 232 +---------- net/socket/ssl_server_socket.h | 53 +++ net/socket/ssl_server_socket_nss.cc | 677 +++++++++++++++++++++++++++++++ net/socket/ssl_server_socket_nss.h | 133 ++++++ net/socket/ssl_server_socket_unittest.cc | 369 +++++++++++++++++ 17 files changed, 1589 insertions(+), 223 deletions(-) create mode 100644 net/socket/nss_ssl_util.cc create mode 100644 net/socket/nss_ssl_util.h create mode 100644 net/socket/ssl_server_socket.h create mode 100644 net/socket/ssl_server_socket_nss.cc create mode 100644 net/socket/ssl_server_socket_nss.h create mode 100644 net/socket/ssl_server_socket_unittest.cc (limited to 'net') diff --git a/net/base/net_log_event_type_list.h b/net/base/net_log_event_type_list.h index e571685..f1bc4f8 100644 --- a/net/base/net_log_event_type_list.h +++ b/net/base/net_log_event_type_list.h @@ -325,6 +325,9 @@ EVENT_TYPE(SOCKS_UNKNOWN_ADDRESS_TYPE) // The start/end of a SSL connect(). EVENT_TYPE(SSL_CONNECT) +// The start/end of a SSL accept(). +EVENT_TYPE(SSL_ACCEPT) + // An SSL error occurred while trying to do the indicated activity. // The following parameters are attached to the event: // { diff --git a/net/base/ssl_config_service.cc b/net/base/ssl_config_service.cc index 9b0a903..d02df38 100644 --- a/net/base/ssl_config_service.cc +++ b/net/base/ssl_config_service.cc @@ -23,8 +23,9 @@ SSLConfig::SSLConfig() : rev_checking_enabled(true), ssl3_enabled(true), tls1_enabled(true), dnssec_enabled(false), snap_start_enabled(false), dns_cert_provenance_checking_enabled(false), - mitm_proxies_allowed(false), false_start_enabled(true), - send_client_cert(false), verify_ev_cert(false), ssl3_fallback(false) { + session_resume_disabled(false), mitm_proxies_allowed(false), + false_start_enabled(true), send_client_cert(false), + verify_ev_cert(false), ssl3_fallback(false) { } SSLConfig::~SSLConfig() { diff --git a/net/base/ssl_config_service.h b/net/base/ssl_config_service.h index c1ae553..de2ebef 100644 --- a/net/base/ssl_config_service.h +++ b/net/base/ssl_config_service.h @@ -32,6 +32,11 @@ struct SSLConfig { // True if we'll do async checks for certificate provenance using DNS. bool dns_cert_provenance_checking_enabled; + // TODO(hclam): This option is used to simplify the SSLServerSocketNSS + // implementation and should be removed when session caching is implemented. + // See http://crbug.com/67236 for more details. + bool session_resume_disabled; // Don't allow session resume. + // Cipher suites which should be explicitly prevented from being used in // addition to those disabled by the net built-in policy -- by default, all // cipher suites supported by the underlying SSL implementation will be diff --git a/net/base/x509_certificate.h b/net/base/x509_certificate.h index 3ee7304..c59c33c 100644 --- a/net/base/x509_certificate.h +++ b/net/base/x509_certificate.h @@ -287,6 +287,11 @@ class X509Certificate : public base::RefCountedThreadSafe { int flags, CertVerifyResult* verify_result) const; + // This method returns the DER encoded certificate. + // If the return value is true then the DER encoded certificate is available. + // The content of the DER encoded certificate is written to |encoded|. + bool GetDEREncoded(std::string* encoded); + OSCertHandle os_cert_handle() const { return cert_handle_; } // Returns true if two OSCertHandles refer to identical certificates. diff --git a/net/base/x509_certificate_mac.cc b/net/base/x509_certificate_mac.cc index f7c89e4..fd965cb3 100644 --- a/net/base/x509_certificate_mac.cc +++ b/net/base/x509_certificate_mac.cc @@ -651,6 +651,17 @@ int X509Certificate::Verify(const std::string& hostname, int flags, return OK; } +bool X509Certificate::GetDEREncoded(std::string* encoded) { + encoded->clear(); + CSSM_DATA der_data; + if(SecCertificateGetData(cert_handle_, &der_data) == noErr) { + encoded->append(reinterpret_cast(der_data.Data), + der_data.Length); + return true; + } + return false; +} + bool X509Certificate::VerifyEV() const { // We don't call this private method, but we do need to implement it because // it's defined in x509_certificate.h. We perform EV checking in the diff --git a/net/base/x509_certificate_nss.cc b/net/base/x509_certificate_nss.cc index 2962cb5..05e736c 100644 --- a/net/base/x509_certificate_nss.cc +++ b/net/base/x509_certificate_nss.cc @@ -829,6 +829,15 @@ bool X509Certificate::VerifyEV() const { return true; } +bool X509Certificate::GetDEREncoded(std::string* encoded) { + if (!cert_handle_->derCert.len) + return false; + encoded->clear(); + encoded->append(reinterpret_cast(cert_handle_->derCert.data), + cert_handle_->derCert.len); + return true; +} + // static bool X509Certificate::IsSameOSCert(X509Certificate::OSCertHandle a, X509Certificate::OSCertHandle b) { diff --git a/net/base/x509_certificate_openssl.cc b/net/base/x509_certificate_openssl.cc index c6ffb2c..cf43610 100644 --- a/net/base/x509_certificate_openssl.cc +++ b/net/base/x509_certificate_openssl.cc @@ -462,6 +462,11 @@ int X509Certificate::Verify(const std::string& hostname, return OK; } +bool X509Certificate::GetDEREncoded(std::string* encoded) { + // TODO(port): Implement. + return false; +} + // static bool X509Certificate::IsSameOSCert(X509Certificate::OSCertHandle a, X509Certificate::OSCertHandle b) { diff --git a/net/base/x509_certificate_unittest.cc b/net/base/x509_certificate_unittest.cc index 83c11fa..dba5ef3 100644 --- a/net/base/x509_certificate_unittest.cc +++ b/net/base/x509_certificate_unittest.cc @@ -672,6 +672,18 @@ TEST(X509CertificateTest, CreateSelfSigned) { EXPECT_EQ("subject", cert->subject().GetDisplayName()); EXPECT_FALSE(cert->HasExpired()); } + +TEST(X509CertificateTest, GetDEREncoded) { + scoped_ptr private_key( + base::RSAPrivateKey::Create(1024)); + scoped_refptr cert = + net::X509Certificate::CreateSelfSigned( + private_key.get(), "CN=subject", 0, base::TimeDelta::FromDays(1)); + + std::string der_cert; + EXPECT_TRUE(cert->GetDEREncoded(&der_cert)); + EXPECT_FALSE(der_cert.empty()); +} #endif class X509CertificateParseTest diff --git a/net/base/x509_certificate_win.cc b/net/base/x509_certificate_win.cc index 568c1fd..663563d 100644 --- a/net/base/x509_certificate_win.cc +++ b/net/base/x509_certificate_win.cc @@ -843,6 +843,15 @@ int X509Certificate::Verify(const std::string& hostname, return OK; } +bool X509Certificate::GetDEREncoded(std::string* encoded) { + if (!cert_handle_->pbCertEncoded || !cert_handle_->cbCertEncoded) + return false; + encoded->clear(); + encoded->append(reinterpret_cast(cert_handle_->pbCertEncoded), + cert_handle_->cbCertEncoded); + return true; +} + // Returns true if the certificate is an extended-validation certificate. // // This function checks the certificatePolicies extensions of the diff --git a/net/net.gyp b/net/net.gyp index 5e58283..2c24152 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -594,6 +594,8 @@ 'socket/client_socket_pool_manager.h', 'socket/dns_cert_provenance_checker.cc', 'socket/dns_cert_provenance_checker.h', + 'socket/nss_ssl_util.cc', + 'socket/nss_ssl_util.h', 'socket/socket.h', 'socket/socks5_client_socket.cc', 'socket/socks5_client_socket.h', @@ -619,6 +621,9 @@ 'socket/ssl_client_socket_win.h', 'socket/ssl_error_params.cc', 'socket/ssl_error_params.h', + 'socket/ssl_server_socket.h', + 'socket/ssl_server_socket_nss.cc', + 'socket/ssl_server_socket_nss.h', 'socket/ssl_host_info.cc', 'socket/ssl_host_info.h', 'socket/tcp_client_socket.cc', @@ -750,6 +755,8 @@ 'socket/ssl_client_socket_nss.h', 'socket/ssl_client_socket_nss_factory.cc', 'socket/ssl_client_socket_nss_factory.h', + 'socket/ssl_server_socket_nss.cc', + 'socket/ssl_server_socket_nss.h', ], }, { # else !use_openssl: remove the unneeded files @@ -960,6 +967,7 @@ 'socket/socks_client_socket_unittest.cc', 'socket/ssl_client_socket_unittest.cc', 'socket/ssl_client_socket_pool_unittest.cc', + 'socket/ssl_server_socket_unittest.cc', 'socket/tcp_client_socket_pool_unittest.cc', 'socket/tcp_client_socket_unittest.cc', 'socket_stream/socket_stream_metrics_unittest.cc', diff --git a/net/socket/nss_ssl_util.cc b/net/socket/nss_ssl_util.cc new file mode 100644 index 0000000..eb8bafb --- /dev/null +++ b/net/socket/nss_ssl_util.cc @@ -0,0 +1,240 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/nss_ssl_util.h" + +#include +#include +#include +#include + +#include "base/lazy_instance.h" +#include "base/logging.h" +#include "base/nss_util.h" +#include "base/singleton.h" +#include "base/thread_restrictions.h" +#include "base/values.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" + +namespace net { + +class NSSSSLInitSingleton { + public: + NSSSSLInitSingleton() { + base::EnsureNSSInit(); + + NSS_SetDomesticPolicy(); + +#if defined(USE_SYSTEM_SSL) + // Use late binding to avoid scary but benign warning + // "Symbol `SSL_ImplementedCiphers' has different size in shared object, + // consider re-linking" + // TODO(wtc): Use the new SSL_GetImplementedCiphers and + // SSL_GetNumImplementedCiphers functions when we require NSS 3.12.6. + // See https://bugzilla.mozilla.org/show_bug.cgi?id=496993. + const PRUint16* pSSL_ImplementedCiphers = static_cast( + dlsym(RTLD_DEFAULT, "SSL_ImplementedCiphers")); + if (pSSL_ImplementedCiphers == NULL) { + NOTREACHED() << "Can't get list of supported ciphers"; + return; + } +#else +#define pSSL_ImplementedCiphers SSL_ImplementedCiphers +#endif + + // Explicitly enable exactly those ciphers with keys of at least 80 bits + for (int i = 0; i < SSL_NumImplementedCiphers; i++) { + SSLCipherSuiteInfo info; + if (SSL_GetCipherSuiteInfo(pSSL_ImplementedCiphers[i], &info, + sizeof(info)) == SECSuccess) { + SSL_CipherPrefSetDefault(pSSL_ImplementedCiphers[i], + (info.effectiveKeyBits >= 80)); + } + } + + // Enable SSL. + SSL_OptionSetDefault(SSL_SECURITY, PR_TRUE); + + // All other SSL options are set per-session by SSLClientSocket and + // SSLServerSocket. + } + + ~NSSSSLInitSingleton() { + // Have to clear the cache, or NSS_Shutdown fails with SEC_ERROR_BUSY. + SSL_ClearSessionCache(); + } +}; + +static base::LazyInstance g_nss_ssl_init_singleton( + base::LINKER_INITIALIZED); + +// Initialize the NSS SSL library if it isn't already initialized. This must +// be called before any other NSS SSL functions. This function is +// thread-safe, and the NSS SSL library will only ever be initialized once. +// The NSS SSL library will be properly shut down on program exit. +void EnsureNSSSSLInit() { + // Initializing SSL causes us to do blocking IO. + // Temporarily allow it until we fix + // http://code.google.com/p/chromium/issues/detail?id=59847 + base::ThreadRestrictions::ScopedAllowIO allow_io; + + g_nss_ssl_init_singleton.Get(); +} + +// Map a Chromium net error code to an NSS error code. +// See _MD_unix_map_default_error in the NSS source +// tree for inspiration. +PRErrorCode MapErrorToNSS(int result) { + if (result >=0) + return result; + + switch (result) { + case ERR_IO_PENDING: + return PR_WOULD_BLOCK_ERROR; + case ERR_ACCESS_DENIED: + case ERR_NETWORK_ACCESS_DENIED: + // For connect, this could be mapped to PR_ADDRESS_NOT_SUPPORTED_ERROR. + return PR_NO_ACCESS_RIGHTS_ERROR; + case ERR_NOT_IMPLEMENTED: + return PR_NOT_IMPLEMENTED_ERROR; + case ERR_INTERNET_DISCONNECTED: // Equivalent to ENETDOWN. + return PR_NETWORK_UNREACHABLE_ERROR; // Best approximation. + case ERR_CONNECTION_TIMED_OUT: + case ERR_TIMED_OUT: + return PR_IO_TIMEOUT_ERROR; + case ERR_CONNECTION_RESET: + return PR_CONNECT_RESET_ERROR; + case ERR_CONNECTION_ABORTED: + return PR_CONNECT_ABORTED_ERROR; + case ERR_CONNECTION_REFUSED: + return PR_CONNECT_REFUSED_ERROR; + case ERR_ADDRESS_UNREACHABLE: + return PR_HOST_UNREACHABLE_ERROR; // Also PR_NETWORK_UNREACHABLE_ERROR. + case ERR_ADDRESS_INVALID: + return PR_ADDRESS_NOT_AVAILABLE_ERROR; + case ERR_NAME_NOT_RESOLVED: + return PR_DIRECTORY_LOOKUP_ERROR; + default: + LOG(WARNING) << "MapErrorToNSS " << result + << " mapped to PR_UNKNOWN_ERROR"; + return PR_UNKNOWN_ERROR; + } +} + +// The default error mapping function. +// Maps an NSS error code to a network error code. +int MapNSSError(PRErrorCode err) { + // TODO(port): fill this out as we learn what's important + switch (err) { + case PR_WOULD_BLOCK_ERROR: + return ERR_IO_PENDING; + case PR_ADDRESS_NOT_SUPPORTED_ERROR: // For connect. + case PR_NO_ACCESS_RIGHTS_ERROR: + return ERR_ACCESS_DENIED; + case PR_IO_TIMEOUT_ERROR: + return ERR_TIMED_OUT; + case PR_CONNECT_RESET_ERROR: + return ERR_CONNECTION_RESET; + case PR_CONNECT_ABORTED_ERROR: + return ERR_CONNECTION_ABORTED; + case PR_CONNECT_REFUSED_ERROR: + return ERR_CONNECTION_REFUSED; + case PR_HOST_UNREACHABLE_ERROR: + case PR_NETWORK_UNREACHABLE_ERROR: + return ERR_ADDRESS_UNREACHABLE; + case PR_ADDRESS_NOT_AVAILABLE_ERROR: + return ERR_ADDRESS_INVALID; + case PR_INVALID_ARGUMENT_ERROR: + return ERR_INVALID_ARGUMENT; + case PR_END_OF_FILE_ERROR: + return ERR_CONNECTION_CLOSED; + case PR_NOT_IMPLEMENTED_ERROR: + return ERR_NOT_IMPLEMENTED; + + case SEC_ERROR_INVALID_ARGS: + return ERR_INVALID_ARGUMENT; + + case SSL_ERROR_SSL_DISABLED: + return ERR_NO_SSL_VERSIONS_ENABLED; + case SSL_ERROR_NO_CYPHER_OVERLAP: + case SSL_ERROR_UNSUPPORTED_VERSION: + return ERR_SSL_VERSION_OR_CIPHER_MISMATCH; + case SSL_ERROR_HANDSHAKE_FAILURE_ALERT: + case SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT: + case SSL_ERROR_ILLEGAL_PARAMETER_ALERT: + return ERR_SSL_PROTOCOL_ERROR; + case SSL_ERROR_DECOMPRESSION_FAILURE_ALERT: + return ERR_SSL_DECOMPRESSION_FAILURE_ALERT; + case SSL_ERROR_BAD_MAC_ALERT: + return ERR_SSL_BAD_RECORD_MAC_ALERT; + case SSL_ERROR_UNSAFE_NEGOTIATION: + return ERR_SSL_UNSAFE_NEGOTIATION; + case SSL_ERROR_WEAK_SERVER_KEY: + return ERR_SSL_WEAK_SERVER_EPHEMERAL_DH_KEY; + + default: { + if (IS_SSL_ERROR(err)) { + LOG(WARNING) << "Unknown SSL error " << err << + " mapped to net::ERR_SSL_PROTOCOL_ERROR"; + return ERR_SSL_PROTOCOL_ERROR; + } + LOG(WARNING) << "Unknown error " << err << + " mapped to net::ERR_FAILED"; + return ERR_FAILED; + } + } +} + +// Context-sensitive error mapping functions. +int MapNSSHandshakeError(PRErrorCode err) { + switch (err) { + // If the server closed on us, it is a protocol error. + // Some TLS-intolerant servers do this when we request TLS. + case PR_END_OF_FILE_ERROR: + // The handshake may fail because some signature (for example, the + // signature in the ServerKeyExchange message for an ephemeral + // Diffie-Hellman cipher suite) is invalid. + case SEC_ERROR_BAD_SIGNATURE: + return ERR_SSL_PROTOCOL_ERROR; + default: + return MapNSSError(err); + } +} + +// Extra parameters to attach to the NetLog when we receive an error in response +// to a call to an NSS function. Used instead of SSLErrorParams with +// events of type TYPE_SSL_NSS_ERROR. Automatically looks up last PR error. +class SSLFailedNSSFunctionParams : public NetLog::EventParameters { + public: + // |param| is ignored if it has a length of 0. + SSLFailedNSSFunctionParams(const std::string& function, + const std::string& param) + : function_(function), param_(param), ssl_lib_error_(PR_GetError()) { + } + + virtual Value* ToValue() const { + DictionaryValue* dict = new DictionaryValue(); + dict->SetString("function", function_); + if (!param_.empty()) + dict->SetString("param", param_); + dict->SetInteger("ssl_lib_error", ssl_lib_error_); + return dict; + } + + private: + const std::string function_; + const std::string param_; + const PRErrorCode ssl_lib_error_; +}; + +void LogFailedNSSFunction(const BoundNetLog& net_log, + const char* function, + const char* param) { + net_log.AddEvent( + NetLog::TYPE_SSL_NSS_ERROR, + make_scoped_refptr(new SSLFailedNSSFunctionParams(function, param))); +} + +} // namespace net diff --git a/net/socket/nss_ssl_util.h b/net/socket/nss_ssl_util.h new file mode 100644 index 0000000..64bf3cf --- /dev/null +++ b/net/socket/nss_ssl_util.h @@ -0,0 +1,36 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file is only inclued in ssl_client_socket_nss.cc and +// ssl_server_socket_nss.cc to share common functions of NSS. + +#ifndef NET_SOCKET_NSS_SSL_UTIL_H_ +#define NEt_SOCKET_NSS_SSL_UTIL_H_ + +#include + +namespace net { + +class BoundNetLog; + +// Initalize NSS SSL library. +void EnsureNSSSSLInit(); + +// Log a failed NSS funcion call. +void LogFailedNSSFunction(const BoundNetLog& net_log, + const char* function, + const char* param); + +// Map network error code to NSS error code. +PRErrorCode MapErrorToNSS(int result); + +// Map NSS error code to network error code. +int MapNSSError(PRErrorCode err); + +// Map NSS handshake error to network error code. +int MapNSSHandshakeError(PRErrorCode err); + +} // namespace net + +#endif // NET_SOCKET_NSS_SSL_UTIL_H_ diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index 3f56b59..49e065f 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -63,7 +63,6 @@ #include "base/compiler_specific.h" #include "base/metrics/histogram.h" -#include "base/lazy_instance.h" #include "base/logging.h" #include "base/nss_util.h" #include "base/string_number_conversions.h" @@ -88,6 +87,7 @@ #include "net/ocsp/nss_ocsp.h" #include "net/socket/client_socket_handle.h" #include "net/socket/dns_cert_provenance_checker.h" +#include "net/socket/nss_ssl_util.h" #include "net/socket/ssl_error_params.h" #include "net/socket/ssl_host_info.h" @@ -139,183 +139,6 @@ namespace net { namespace { -class NSSSSLInitSingleton { - public: - NSSSSLInitSingleton() { - base::EnsureNSSInit(); - - NSS_SetDomesticPolicy(); - -#if defined(USE_SYSTEM_SSL) - // Use late binding to avoid scary but benign warning - // "Symbol `SSL_ImplementedCiphers' has different size in shared object, - // consider re-linking" - // TODO(wtc): Use the new SSL_GetImplementedCiphers and - // SSL_GetNumImplementedCiphers functions when we require NSS 3.12.6. - // See https://bugzilla.mozilla.org/show_bug.cgi?id=496993. - const PRUint16* pSSL_ImplementedCiphers = static_cast( - dlsym(RTLD_DEFAULT, "SSL_ImplementedCiphers")); - if (pSSL_ImplementedCiphers == NULL) { - NOTREACHED() << "Can't get list of supported ciphers"; - return; - } -#else -#define pSSL_ImplementedCiphers SSL_ImplementedCiphers -#endif - - // Explicitly enable exactly those ciphers with keys of at least 80 bits - for (int i = 0; i < SSL_NumImplementedCiphers; i++) { - SSLCipherSuiteInfo info; - if (SSL_GetCipherSuiteInfo(pSSL_ImplementedCiphers[i], &info, - sizeof(info)) == SECSuccess) { - SSL_CipherPrefSetDefault(pSSL_ImplementedCiphers[i], - (info.effectiveKeyBits >= 80)); - } - } - - // Enable SSL. - SSL_OptionSetDefault(SSL_SECURITY, PR_TRUE); - - // All other SSL options are set per-session by SSLClientSocket. - } - - ~NSSSSLInitSingleton() { - // Have to clear the cache, or NSS_Shutdown fails with SEC_ERROR_BUSY. - SSL_ClearSessionCache(); - } -}; - -static base::LazyInstance g_nss_ssl_init_singleton( - base::LINKER_INITIALIZED); - -// Initialize the NSS SSL library if it isn't already initialized. This must -// be called before any other NSS SSL functions. This function is -// thread-safe, and the NSS SSL library will only ever be initialized once. -// The NSS SSL library will be properly shut down on program exit. -void EnsureNSSSSLInit() { - // Initializing SSL causes us to do blocking IO. - // Temporarily allow it until we fix - // http://code.google.com/p/chromium/issues/detail?id=59847 - base::ThreadRestrictions::ScopedAllowIO allow_io; - - g_nss_ssl_init_singleton.Get(); -} - -// The default error mapping function. -// Maps an NSPR error code to a network error code. -int MapNSPRError(PRErrorCode err) { - // TODO(port): fill this out as we learn what's important - switch (err) { - case PR_WOULD_BLOCK_ERROR: - return ERR_IO_PENDING; - case PR_ADDRESS_NOT_SUPPORTED_ERROR: // For connect. - case PR_NO_ACCESS_RIGHTS_ERROR: - return ERR_ACCESS_DENIED; - case PR_IO_TIMEOUT_ERROR: - return ERR_TIMED_OUT; - case PR_CONNECT_RESET_ERROR: - return ERR_CONNECTION_RESET; - case PR_CONNECT_ABORTED_ERROR: - return ERR_CONNECTION_ABORTED; - case PR_CONNECT_REFUSED_ERROR: - return ERR_CONNECTION_REFUSED; - case PR_HOST_UNREACHABLE_ERROR: - case PR_NETWORK_UNREACHABLE_ERROR: - return ERR_ADDRESS_UNREACHABLE; - case PR_ADDRESS_NOT_AVAILABLE_ERROR: - return ERR_ADDRESS_INVALID; - case PR_INVALID_ARGUMENT_ERROR: - return ERR_INVALID_ARGUMENT; - case PR_END_OF_FILE_ERROR: - return ERR_CONNECTION_CLOSED; - case PR_NOT_IMPLEMENTED_ERROR: - return ERR_NOT_IMPLEMENTED; - - case SEC_ERROR_INVALID_ARGS: - return ERR_INVALID_ARGUMENT; - - case SSL_ERROR_SSL_DISABLED: - return ERR_NO_SSL_VERSIONS_ENABLED; - case SSL_ERROR_NO_CYPHER_OVERLAP: - case SSL_ERROR_UNSUPPORTED_VERSION: - return ERR_SSL_VERSION_OR_CIPHER_MISMATCH; - case SSL_ERROR_HANDSHAKE_FAILURE_ALERT: - case SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT: - case SSL_ERROR_ILLEGAL_PARAMETER_ALERT: - return ERR_SSL_PROTOCOL_ERROR; - case SSL_ERROR_DECOMPRESSION_FAILURE_ALERT: - return ERR_SSL_DECOMPRESSION_FAILURE_ALERT; - case SSL_ERROR_BAD_MAC_ALERT: - return ERR_SSL_BAD_RECORD_MAC_ALERT; - case SSL_ERROR_UNSAFE_NEGOTIATION: - return ERR_SSL_UNSAFE_NEGOTIATION; - case SSL_ERROR_WEAK_SERVER_KEY: - return ERR_SSL_WEAK_SERVER_EPHEMERAL_DH_KEY; - - default: { - if (IS_SSL_ERROR(err)) { - LOG(WARNING) << "Unknown SSL error " << err << - " mapped to net::ERR_SSL_PROTOCOL_ERROR"; - return ERR_SSL_PROTOCOL_ERROR; - } - LOG(WARNING) << "Unknown error " << err << - " mapped to net::ERR_FAILED"; - return ERR_FAILED; - } - } -} - -// Context-sensitive error mapping functions. - -int MapHandshakeError(PRErrorCode err) { - switch (err) { - // If the server closed on us, it is a protocol error. - // Some TLS-intolerant servers do this when we request TLS. - case PR_END_OF_FILE_ERROR: - // The handshake may fail because some signature (for example, the - // signature in the ServerKeyExchange message for an ephemeral - // Diffie-Hellman cipher suite) is invalid. - case SEC_ERROR_BAD_SIGNATURE: - return ERR_SSL_PROTOCOL_ERROR; - default: - return MapNSPRError(err); - } -} - -// Extra parameters to attach to the NetLog when we receive an error in response -// to a call to an NSS function. Used instead of SSLErrorParams with -// events of type TYPE_SSL_NSS_ERROR. Automatically looks up last PR error. -class SSLFailedNSSFunctionParams : public NetLog::EventParameters { - public: - // |param| is ignored if it has a length of 0. - SSLFailedNSSFunctionParams(const std::string& function, - const std::string& param) - : function_(function), param_(param), ssl_lib_error_(PR_GetError()) { - } - - virtual Value* ToValue() const { - DictionaryValue* dict = new DictionaryValue(); - dict->SetString("function", function_); - if (!param_.empty()) - dict->SetString("param", param_); - dict->SetInteger("ssl_lib_error", ssl_lib_error_); - return dict; - } - - private: - const std::string function_; - const std::string param_; - const PRErrorCode ssl_lib_error_; -}; - -void LogFailedNSSFunction(const BoundNetLog& net_log, - const char* function, - const char* param) { - net_log.AddEvent( - NetLog::TYPE_SSL_NSS_ERROR, - make_scoped_refptr(new SSLFailedNSSFunctionParams(function, param))); -} - #if defined(OS_WIN) // This callback is intended to be used with CertFindChainInStore. In addition @@ -736,6 +559,13 @@ int SSLClientSocketNSS::InitializeSSLOptions() { #error "You need to install NSS-3.12 or later to build chromium" #endif + rv = SSL_OptionSet(nss_fd_, SSL_NO_CACHE, + ssl_config_.session_resume_disabled); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_NO_CACHE"); + return ERR_UNEXPECTED; + } + #ifdef SSL_ENABLE_DEFLATE // Some web servers have been found to break if TLS is used *or* if DEFLATE // is advertised. Thus, if TLS is disabled (probably because we are doing @@ -1363,46 +1193,6 @@ void SSLClientSocketNSS::OnRecvComplete(int result) { LeaveFunction(""); } -// Map a Chromium net error code to an NSS error code. -// See _MD_unix_map_default_error in the NSS source -// tree for inspiration. -static PRErrorCode MapErrorToNSS(int result) { - if (result >=0) - return result; - - switch (result) { - case ERR_IO_PENDING: - return PR_WOULD_BLOCK_ERROR; - case ERR_ACCESS_DENIED: - case ERR_NETWORK_ACCESS_DENIED: - // For connect, this could be mapped to PR_ADDRESS_NOT_SUPPORTED_ERROR. - return PR_NO_ACCESS_RIGHTS_ERROR; - case ERR_NOT_IMPLEMENTED: - return PR_NOT_IMPLEMENTED_ERROR; - case ERR_INTERNET_DISCONNECTED: // Equivalent to ENETDOWN. - return PR_NETWORK_UNREACHABLE_ERROR; // Best approximation. - case ERR_CONNECTION_TIMED_OUT: - case ERR_TIMED_OUT: - return PR_IO_TIMEOUT_ERROR; - case ERR_CONNECTION_RESET: - return PR_CONNECT_RESET_ERROR; - case ERR_CONNECTION_ABORTED: - return PR_CONNECT_ABORTED_ERROR; - case ERR_CONNECTION_REFUSED: - return PR_CONNECT_REFUSED_ERROR; - case ERR_ADDRESS_UNREACHABLE: - return PR_HOST_UNREACHABLE_ERROR; // Also PR_NETWORK_UNREACHABLE_ERROR. - case ERR_ADDRESS_INVALID: - return PR_ADDRESS_NOT_AVAILABLE_ERROR; - case ERR_NAME_NOT_RESOLVED: - return PR_DIRECTORY_LOOKUP_ERROR; - default: - LOG(WARNING) << "MapErrorToNSS " << result - << " mapped to PR_UNKNOWN_ERROR"; - return PR_UNKNOWN_ERROR; - } -} - // Do network I/O between the given buffer and the given socket. // Return true if some I/O performed, false otherwise (error or ERR_IO_PENDING) bool SSLClientSocketNSS::DoTransportIO() { @@ -2191,7 +1981,7 @@ int SSLClientSocketNSS::DoHandshake() { } } else { PRErrorCode prerr = PR_GetError(); - net_error = MapHandshakeError(prerr); + net_error = MapNSSHandshakeError(prerr); // If not done, stay in this state if (net_error == ERR_IO_PENDING) { @@ -2580,7 +2370,7 @@ int SSLClientSocketNSS::DoPayloadRead() { return ERR_IO_PENDING; } LeaveFunction(""); - rv = MapNSPRError(prerr); + rv = MapNSSError(prerr); net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR, make_scoped_refptr(new SSLErrorParams(rv, prerr))); return rv; @@ -2601,7 +2391,7 @@ int SSLClientSocketNSS::DoPayloadWrite() { return ERR_IO_PENDING; } LeaveFunction(""); - rv = MapNSPRError(prerr); + rv = MapNSSError(prerr); net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR, make_scoped_refptr(new SSLErrorParams(rv, prerr))); return rv; diff --git a/net/socket/ssl_server_socket.h b/net/socket/ssl_server_socket.h new file mode 100644 index 0000000..b689c71 --- /dev/null +++ b/net/socket/ssl_server_socket.h @@ -0,0 +1,53 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_SERVER_SOCKET_H_ +#define NET_SOCKET_SSL_SERVER_SOCKET_H_ + +#include "base/basictypes.h" +#include "net/base/completion_callback.h" +#include "net/socket/socket.h" + +namespace base { +class RSAPrivateKey; +} // namespace base + +namespace net { + +class IOBuffer; +struct SSLConfig; +class X509Certificate; + +// SSLServerSocket takes an already connected socket and performs SSL on top of +// it. +// +// This class is designed to work in a peer-to-peer connection and is not +// intended to be used as a standalone SSL server. +class SSLServerSocket : public Socket { + public: + virtual ~SSLServerSocket() {} + + // Performs an SSL server handshake on the existing socket. The given socket + // must have already been connected. + // + // Accept either returns ERR_IO_PENDING, in which case the given callback + // will be called in the future with the real result, or it completes + // synchronously, returning the result immediately. + virtual int Accept(CompletionCallback* callback) = 0; +}; + +// Creates an SSL server socket using an already connected socket. A certificate +// and private key needs to be provided. +// +// This created server socket will take ownership of |socket|. However |key| +// is copied. +// TODO(hclam): Defines ServerSocketFactory to create SSLServerSocket. This will +// make mocking easier. +SSLServerSocket* CreateSSLServerSocket( + Socket* socket, X509Certificate* certificate, base::RSAPrivateKey* key, + const SSLConfig& ssl_config); + +} // namespace net + +#endif // NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_ diff --git a/net/socket/ssl_server_socket_nss.cc b/net/socket/ssl_server_socket_nss.cc new file mode 100644 index 0000000..2e47fb8 --- /dev/null +++ b/net/socket/ssl_server_socket_nss.cc @@ -0,0 +1,677 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/ssl_server_socket_nss.h" + +#if defined(OS_WIN) +#include +#endif + +#if defined(USE_SYSTEM_SSL) +#include +#endif +#if defined(OS_MACOSX) +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "base/crypto/rsa_private_key.h" +#include "base/nss_util_internal.h" +#include "base/ref_counted.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/ocsp/nss_ocsp.h" +#include "net/socket/nss_ssl_util.h" +#include "net/socket/ssl_error_params.h" + +static const int kRecvBufferSize = 4096; + +#define GotoState(s) next_handshake_state_ = s + +namespace net { + +SSLServerSocket* CreateSSLServerSocket( + Socket* socket, X509Certificate* cert, base::RSAPrivateKey* key, + const SSLConfig& ssl_config) { + return new SSLServerSocketNSS(socket, cert, key, ssl_config); +} + +SSLServerSocketNSS::SSLServerSocketNSS( + Socket* transport_socket, + scoped_refptr cert, + base::RSAPrivateKey* key, + const SSLConfig& ssl_config) + : ALLOW_THIS_IN_INITIALIZER_LIST(buffer_send_callback_( + this, &SSLServerSocketNSS::BufferSendComplete)), + ALLOW_THIS_IN_INITIALIZER_LIST(buffer_recv_callback_( + this, &SSLServerSocketNSS::BufferRecvComplete)), + transport_send_busy_(false), + transport_recv_busy_(false), + user_accept_callback_(NULL), + user_read_callback_(NULL), + user_write_callback_(NULL), + nss_fd_(NULL), + nss_bufs_(NULL), + transport_socket_(transport_socket), + ssl_config_(ssl_config), + cert_(cert), + next_handshake_state_(STATE_NONE), + completed_handshake_(false) { + ssl_config_.false_start_enabled = false; + ssl_config_.ssl3_enabled = true; + ssl_config_.tls1_enabled = true; + + // TODO(hclam): Need a better way to clone a key. + std::vector key_bytes; + CHECK(key->ExportPrivateKey(&key_bytes)); + key_.reset(base::RSAPrivateKey::CreateFromPrivateKeyInfo(key_bytes)); + CHECK(key_.get()); +} + +SSLServerSocketNSS::~SSLServerSocketNSS() { + if (nss_fd_ != NULL) { + PR_Close(nss_fd_); + nss_fd_ = NULL; + } +} + +int SSLServerSocketNSS::Init() { + // Initialize the NSS SSL library in a threadsafe way. This also + // initializes the NSS base library. + EnsureNSSSSLInit(); + if (!NSS_IsInitialized()) + return ERR_UNEXPECTED; +#if !defined(OS_MACOSX) && !defined(OS_WIN) + // We must call EnsureOCSPInit() here, on the IO thread, to get the IO loop + // by MessageLoopForIO::current(). + // X509Certificate::Verify() runs on a worker thread of CertVerifier. + EnsureOCSPInit(); +#endif + + return OK; +} + +int SSLServerSocketNSS::Accept(CompletionCallback* callback) { + net_log_.BeginEvent(NetLog::TYPE_SSL_ACCEPT, NULL); + + int rv = Init(); + if (rv != OK) { + LOG(ERROR) << "Failed to initialize NSS"; + net_log_.EndEvent(NetLog::TYPE_SSL_ACCEPT, NULL); + return rv; + } + + rv = InitializeSSLOptions(); + if (rv != OK) { + LOG(ERROR) << "Failed to initialize SSL options"; + net_log_.EndEvent(NetLog::TYPE_SSL_ACCEPT, NULL); + return rv; + } + + // Set peer address. TODO(hclam): This should be in a separate method. + PRNetAddr peername; + memset(&peername, 0, sizeof(peername)); + peername.raw.family = AF_INET; + memio_SetPeerName(nss_fd_, &peername); + + GotoState(STATE_HANDSHAKE); + rv = DoHandshakeLoop(net::OK); + if (rv == ERR_IO_PENDING) { + user_accept_callback_ = callback; + } else { + net_log_.EndEvent(NetLog::TYPE_SSL_ACCEPT, NULL); + } + + return rv > OK ? OK : rv; +} + +int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + DCHECK(!user_read_callback_); + DCHECK(!user_accept_callback_); + DCHECK(!user_read_buf_); + DCHECK(nss_bufs_); + + user_read_buf_ = buf; + user_read_buf_len_ = buf_len; + + DCHECK(completed_handshake_); + + int rv = DoReadLoop(OK); + + if (rv == ERR_IO_PENDING) { + user_read_callback_ = callback; + } else { + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + } + return rv; +} + +int SSLServerSocketNSS::Write(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + DCHECK(!user_write_callback_); + DCHECK(!user_write_buf_); + DCHECK(nss_bufs_); + + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoWriteLoop(OK); + + if (rv == ERR_IO_PENDING) { + user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + return rv; +} + +// static +// NSS calls this if an incoming certificate needs to be verified. +// Do nothing but return SECSuccess. +// This is called only in full handshake mode. +// Peer certificate is retrieved in HandshakeCallback() later, which is called +// in full handshake mode or in resumption handshake mode. +SECStatus SSLServerSocketNSS::OwnAuthCertHandler(void* arg, + PRFileDesc* socket, + PRBool checksig, + PRBool is_server) { + // TODO(hclam): Implement. + // Tell NSS to not verify the certificate. + return SECSuccess; +} + +// static +// NSS calls this when handshake is completed. +// After the SSL handshake is finished we need to verify the certificate. +void SSLServerSocketNSS::HandshakeCallback(PRFileDesc* socket, + void* arg) { + // TODO(hclam): Implement. +} + +int SSLServerSocketNSS::InitializeSSLOptions() { + // Transport connected, now hook it up to nss + // TODO(port): specify rx and tx buffer sizes separately + nss_fd_ = memio_CreateIOLayer(kRecvBufferSize); + if (nss_fd_ == NULL) { + return ERR_OUT_OF_MEMORY; // TODO(port): map NSPR error code. + } + + // Grab pointer to buffers + nss_bufs_ = memio_GetSecret(nss_fd_); + + /* Create SSL state machine */ + /* Push SSL onto our fake I/O socket */ + nss_fd_ = SSL_ImportFD(NULL, nss_fd_); + if (nss_fd_ == NULL) { + LogFailedNSSFunction(net_log_, "SSL_ImportFD", ""); + return ERR_OUT_OF_MEMORY; // TODO(port): map NSPR/NSS error code. + } + // TODO(port): set more ssl options! Check errors! + + int rv; + + rv = SSL_OptionSet(nss_fd_, SSL_SECURITY, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_SECURITY"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SSL2, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_SSL2"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SSL3, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_SSL3"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_TLS, ssl_config_.tls1_enabled); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_TLS"); + return ERR_UNEXPECTED; + } + + for (std::vector::const_iterator it = + ssl_config_.disabled_cipher_suites.begin(); + it != ssl_config_.disabled_cipher_suites.end(); ++it) { + // This will fail if the specified cipher is not implemented by NSS, but + // the failure is harmless. + SSL_CipherPrefSet(nss_fd_, *it, PR_FALSE); + } + + // Server socket doesn't need session tickets. + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SESSION_TICKETS, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction( + net_log_, "SSL_OptionSet", "SSL_ENABLE_SESSION_TICKETS"); + } + + // Doing this will force PR_Accept perform handshake as server. + rv = SSL_OptionSet(nss_fd_, SSL_HANDSHAKE_AS_CLIENT, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_HANDSHAKE_AS_CLIENT"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_HANDSHAKE_AS_SERVER, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_HANDSHAKE_AS_SERVER"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_REQUEST_CERTIFICATE, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_REQUEST_CERTIFICATE"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_REQUIRE_CERTIFICATE, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_REQUIRE_CERTIFICATE"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_NO_CACHE, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_NO_CACHE"); + return ERR_UNEXPECTED; + } + + rv = SSL_ConfigServerSessionIDCache(1024, 5, 5, NULL); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_ConfigureServerSessionIDCache", ""); + return ERR_UNEXPECTED; + } + + rv = SSL_AuthCertificateHook(nss_fd_, OwnAuthCertHandler, this); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_AuthCertificateHook", ""); + return ERR_UNEXPECTED; + } + + rv = SSL_HandshakeCallback(nss_fd_, HandshakeCallback, this); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_HandshakeCallback", ""); + return ERR_UNEXPECTED; + } + + // Get a certificate of CERTCertificate structure. + std::string der_string; + if (!cert_->GetDEREncoded(&der_string)) + return ERR_UNEXPECTED; + + SECItem der_cert; + der_cert.data = reinterpret_cast(const_cast( + der_string.data())); + der_cert.len = der_string.length(); + der_cert.type = siDERCertBuffer; + + // Parse into a CERTCertificate structure. + CERTCertificate* cert = CERT_NewTempCertificate( + CERT_GetDefaultCertDB(), &der_cert, NULL, PR_FALSE, PR_TRUE); + + // Get a key of SECKEYPrivateKey* structure. + std::vector key_vector; + if (!key_->ExportPrivateKey(&key_vector)) { + CERT_DestroyCertificate(cert); + return ERR_UNEXPECTED; + } + + SECKEYPrivateKeyStr* private_key = NULL; + PK11SlotInfo *slot = base::GetDefaultNSSKeySlot(); + if (!slot) { + CERT_DestroyCertificate(cert); + return ERR_UNEXPECTED; + } + + SECItem der_private_key_info; + der_private_key_info.data = + const_cast(&key_vector.front()); + der_private_key_info.len = key_vector.size(); + rv = PK11_ImportDERPrivateKeyInfoAndReturnKey( + slot, &der_private_key_info, NULL, NULL, PR_FALSE, PR_FALSE, + KU_DIGITAL_SIGNATURE, &private_key, NULL); + PK11_FreeSlot(slot); + if (rv != SECSuccess) { + CERT_DestroyCertificate(cert); + return ERR_UNEXPECTED; + } + + // Assign server certificate and private key. + SSLKEAType cert_kea = NSS_FindCertKEAType(cert); + rv = SSL_ConfigSecureServer(nss_fd_, cert, private_key, cert_kea); + CERT_DestroyCertificate(cert); + SECKEY_DestroyPrivateKey(private_key); + + if (rv != SECSuccess) { + PRErrorCode prerr = PR_GetError(); + LOG(ERROR) << "Failed to config SSL server: " << prerr; + LogFailedNSSFunction(net_log_, "SSL_ConfigureSecureServer", ""); + return ERR_UNEXPECTED; + } + + // Tell SSL we're a server; needed if not letting NSPR do socket I/O + rv = SSL_ResetHandshake(nss_fd_, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_ResetHandshake", ""); + return ERR_UNEXPECTED; + } + + return OK; +} + +// Return 0 for EOF, +// > 0 for bytes transferred immediately, +// < 0 for error (or the non-error ERR_IO_PENDING). +int SSLServerSocketNSS::BufferSend(void) { + if (transport_send_busy_) + return ERR_IO_PENDING; + + const char* buf1; + const char* buf2; + unsigned int len1, len2; + memio_GetWriteParams(nss_bufs_, &buf1, &len1, &buf2, &len2); + const unsigned int len = len1 + len2; + + int rv = 0; + if (len) { + scoped_refptr send_buffer(new IOBuffer(len)); + memcpy(send_buffer->data(), buf1, len1); + memcpy(send_buffer->data() + len1, buf2, len2); + rv = transport_socket_->Write(send_buffer, len, + &buffer_send_callback_); + if (rv == ERR_IO_PENDING) { + transport_send_busy_ = true; + } else { + memio_PutWriteResult(nss_bufs_, MapErrorToNSS(rv)); + } + } + + return rv; +} + +void SSLServerSocketNSS::BufferSendComplete(int result) { + memio_PutWriteResult(nss_bufs_, MapErrorToNSS(result)); + transport_send_busy_ = false; + OnSendComplete(result); +} + +int SSLServerSocketNSS::BufferRecv(void) { + if (transport_recv_busy_) return ERR_IO_PENDING; + + char *buf; + int nb = memio_GetReadParams(nss_bufs_, &buf); + int rv; + if (!nb) { + // buffer too full to read into, so no I/O possible at moment + rv = ERR_IO_PENDING; + } else { + recv_buffer_ = new IOBuffer(nb); + rv = transport_socket_->Read(recv_buffer_, nb, &buffer_recv_callback_); + if (rv == ERR_IO_PENDING) { + transport_recv_busy_ = true; + } else { + if (rv > 0) + memcpy(buf, recv_buffer_->data(), rv); + memio_PutReadResult(nss_bufs_, MapErrorToNSS(rv)); + recv_buffer_ = NULL; + } + } + return rv; +} + +void SSLServerSocketNSS::BufferRecvComplete(int result) { + if (result > 0) { + char *buf; + memio_GetReadParams(nss_bufs_, &buf); + memcpy(buf, recv_buffer_->data(), result); + } + recv_buffer_ = NULL; + memio_PutReadResult(nss_bufs_, MapErrorToNSS(result)); + transport_recv_busy_ = false; + OnRecvComplete(result); +} + +void SSLServerSocketNSS::OnSendComplete(int result) { + if (next_handshake_state_ == STATE_HANDSHAKE) { + // In handshake phase. + OnHandshakeIOComplete(result); + return; + } + + if (!user_write_buf_ || !completed_handshake_) + return; + + int rv = DoWriteLoop(result); + if (rv != ERR_IO_PENDING) + DoWriteCallback(rv); +} + +void SSLServerSocketNSS::OnRecvComplete(int result) { + if (next_handshake_state_ == STATE_HANDSHAKE) { + // In handshake phase. + OnHandshakeIOComplete(result); + return; + } + + // Network layer received some data, check if client requested to read + // decrypted data. + if (!user_read_buf_ || !completed_handshake_) + return; + + int rv = DoReadLoop(result); + if (rv != ERR_IO_PENDING) + DoReadCallback(rv); +} + +void SSLServerSocketNSS::OnHandshakeIOComplete(int result) { + int rv = DoHandshakeLoop(result); + if (rv != ERR_IO_PENDING) { + net_log_.EndEvent(net::NetLog::TYPE_SSL_ACCEPT, NULL); + if (user_accept_callback_) + DoAcceptCallback(rv); + } +} + +void SSLServerSocketNSS::DoAcceptCallback(int rv) { + DCHECK_NE(rv, ERR_IO_PENDING); + + CompletionCallback* c = user_accept_callback_; + user_accept_callback_ = NULL; + c->Run(rv > OK ? OK : rv); +} + +void SSLServerSocketNSS::DoReadCallback(int rv) { + DCHECK(rv != ERR_IO_PENDING); + DCHECK(user_read_callback_); + + // Since Run may result in Read being called, clear |user_read_callback_| + // up front. + CompletionCallback* c = user_read_callback_; + user_read_callback_ = NULL; + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c->Run(rv); +} + +void SSLServerSocketNSS::DoWriteCallback(int rv) { + DCHECK(rv != ERR_IO_PENDING); + DCHECK(user_write_callback_); + + // Since Run may result in Write being called, clear |user_write_callback_| + // up front. + CompletionCallback* c = user_write_callback_; + user_write_callback_ = NULL; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c->Run(rv); +} + +// Do network I/O between the given buffer and the given socket. +// Return true if some I/O performed, false otherwise (error or ERR_IO_PENDING) +bool SSLServerSocketNSS::DoTransportIO() { + bool network_moved = false; + if (nss_bufs_ != NULL) { + int nsent = BufferSend(); + int nreceived = BufferRecv(); + network_moved = (nsent > 0 || nreceived >= 0); + } + return network_moved; +} + +int SSLServerSocketNSS::DoPayloadRead() { + DCHECK(user_read_buf_); + DCHECK_GT(user_read_buf_len_, 0); + int rv = PR_Read(nss_fd_, user_read_buf_->data(), user_read_buf_len_); + if (rv >= 0) + return rv; + PRErrorCode prerr = PR_GetError(); + if (prerr == PR_WOULD_BLOCK_ERROR) { + return ERR_IO_PENDING; + } + rv = MapNSSError(prerr); + net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR, + make_scoped_refptr(new SSLErrorParams(rv, prerr))); + return rv; +} + +int SSLServerSocketNSS::DoPayloadWrite() { + DCHECK(user_write_buf_); + int rv = PR_Write(nss_fd_, user_write_buf_->data(), user_write_buf_len_); + if (rv >= 0) + return rv; + PRErrorCode prerr = PR_GetError(); + if (prerr == PR_WOULD_BLOCK_ERROR) { + return ERR_IO_PENDING; + } + rv = MapNSSError(prerr); + net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR, + make_scoped_refptr(new SSLErrorParams(rv, prerr))); + return rv; +} + +int SSLServerSocketNSS::DoHandshakeLoop(int last_io_result) { + bool network_moved; + int rv = last_io_result; + do { + // Default to STATE_NONE for next state. + // (This is a quirk carried over from the windows + // implementation. It makes reading the logs a bit harder.) + // State handlers can and often do call GotoState just + // to stay in the current state. + State state = next_handshake_state_; + GotoState(STATE_NONE); + switch (state) { + case STATE_NONE: + // we're just pumping data between the buffer and the network + break; + case STATE_HANDSHAKE: + rv = DoHandshake(); + break; + default: + rv = ERR_UNEXPECTED; + LOG(DFATAL) << "unexpected state " << state; + break; + } + + // Do the actual network I/O + network_moved = DoTransportIO(); + } while ((rv != ERR_IO_PENDING || network_moved) && + next_handshake_state_ != STATE_NONE); + return rv; +} + +int SSLServerSocketNSS::DoReadLoop(int result) { + DCHECK(completed_handshake_); + DCHECK(next_handshake_state_ == STATE_NONE); + + if (result < 0) + return result; + + if (!nss_bufs_) { + LOG(DFATAL) << "!nss_bufs_"; + int rv = ERR_UNEXPECTED; + net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR, + make_scoped_refptr(new SSLErrorParams(rv, 0))); + return rv; + } + + bool network_moved; + int rv; + do { + rv = DoPayloadRead(); + network_moved = DoTransportIO(); + } while (rv == ERR_IO_PENDING && network_moved); + return rv; +} + +int SSLServerSocketNSS::DoWriteLoop(int result) { + DCHECK(completed_handshake_); + DCHECK(next_handshake_state_ == STATE_NONE); + + if (result < 0) + return result; + + if (!nss_bufs_) { + LOG(DFATAL) << "!nss_bufs_"; + int rv = ERR_UNEXPECTED; + net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR, + make_scoped_refptr(new SSLErrorParams(rv, 0))); + return rv; + } + + bool network_moved; + int rv; + do { + rv = DoPayloadWrite(); + network_moved = DoTransportIO(); + } while (rv == ERR_IO_PENDING && network_moved); + return rv; +} + +int SSLServerSocketNSS::DoHandshake() { + int net_error = net::OK; + SECStatus rv = SSL_ForceHandshake(nss_fd_); + + if (rv == SECSuccess) { + completed_handshake_ = true; + } else { + PRErrorCode prerr = PR_GetError(); + net_error = MapNSSHandshakeError(prerr); + + // If not done, stay in this state + if (net_error == ERR_IO_PENDING) { + GotoState(STATE_HANDSHAKE); + } else { + LOG(ERROR) << "handshake failed; NSS error code " << prerr + << ", net_error " << net_error; + net_log_.AddEvent( + NetLog::TYPE_SSL_HANDSHAKE_ERROR, + make_scoped_refptr(new SSLErrorParams(net_error, prerr))); + } + } + return net_error; +} + +} // namespace net diff --git a/net/socket/ssl_server_socket_nss.h b/net/socket/ssl_server_socket_nss.h new file mode 100644 index 0000000..3883c9b --- /dev/null +++ b/net/socket/ssl_server_socket_nss.h @@ -0,0 +1,133 @@ +// Copyright (c) 2006-2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_ +#define NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_ +#pragma once + +#include +#include +#include +#include + +#include "base/scoped_ptr.h" +#include "net/base/completion_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/net_log.h" +#include "net/base/nss_memio.h" +#include "net/base/ssl_config_service.h" +#include "net/socket/ssl_server_socket.h" + +namespace net { + +class SSLServerSocketNSS : public SSLServerSocket { + public: + // This object takes ownership of the following parameters: + // |socket| - A socket that is already connected. + // |cert| - The certificate to be used by the server. + // + // The following parameters are copied in the constructor. + // |ssl_config| - Options for SSL socket. + // |key| - The private key used by the server. + SSLServerSocketNSS(Socket* transport_socket, + scoped_refptr cert, + base::RSAPrivateKey* key, + const SSLConfig& ssl_config); + virtual ~SSLServerSocketNSS(); + + // SSLServerSocket implementation. + virtual int Accept(CompletionCallback* callback); + virtual int Read(IOBuffer* buf, int buf_len, + CompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + CompletionCallback* callback); + virtual bool SetReceiveBufferSize(int32 size) { return false; } + virtual bool SetSendBufferSize(int32 size) { return false; } + + private: + virtual int Init(); + + int InitializeSSLOptions(); + + void OnSendComplete(int result); + void OnRecvComplete(int result); + void OnHandshakeIOComplete(int result); + + int BufferSend(); + void BufferSendComplete(int result); + int BufferRecv(); + void BufferRecvComplete(int result); + bool DoTransportIO(); + int DoPayloadWrite(); + int DoPayloadRead(); + + int DoHandshakeLoop(int last_io_result); + int DoReadLoop(int result); + int DoWriteLoop(int result); + int DoHandshake(); + void DoAcceptCallback(int result); + void DoReadCallback(int result); + void DoWriteCallback(int result); + + static SECStatus OwnAuthCertHandler(void* arg, + PRFileDesc* socket, + PRBool checksig, + PRBool is_server); + static void HandshakeCallback(PRFileDesc* socket, void* arg); + + // Members used to send and receive buffer. + CompletionCallbackImpl buffer_send_callback_; + CompletionCallbackImpl buffer_recv_callback_; + bool transport_send_busy_; + bool transport_recv_busy_; + + scoped_refptr recv_buffer_; + + BoundNetLog net_log_; + + CompletionCallback* user_accept_callback_; + CompletionCallback* user_read_callback_; + CompletionCallback* user_write_callback_; + + // Used by Read function. + scoped_refptr user_read_buf_; + int user_read_buf_len_; + + // Used by Write function. + scoped_refptr user_write_buf_; + int user_write_buf_len_; + + // The NSS SSL state machine + PRFileDesc* nss_fd_; + + // Buffers for the network end of the SSL state machine + memio_Private* nss_bufs_; + + // Socket for sending and receiving data. + scoped_ptr transport_socket_; + + // Options for the SSL socket. + // TODO(hclam): This memeber is currently not used. Should make use of this + // member to configure the socket. + SSLConfig ssl_config_; + + // Certificate for the server. + scoped_refptr cert_; + + // Private key used by the server. + scoped_ptr key_; + + enum State { + STATE_NONE, + STATE_HANDSHAKE, + }; + State next_handshake_state_; + bool completed_handshake_; + + DISALLOW_COPY_AND_ASSIGN(SSLServerSocketNSS); +}; + +} // namespace net + +#endif // NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_ diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc new file mode 100644 index 0000000..781a3f4 --- /dev/null +++ b/net/socket/ssl_server_socket_unittest.cc @@ -0,0 +1,369 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This test suite uses SSLClientSocket to test the implementation of +// SSLServerSocket. In order to establish connections between the sockets +// we need two additional classes: +// 1. FakeSocket +// Connects SSL socket to FakeDataChannel. This class is just a stub. +// +// 2. FakeDataChannel +// Implements the actual exchange of data between two FakeSockets. +// +// Implementations of these two classes are included in this file. + +#include "net/socket/ssl_server_socket.h" + +#include + +#include "base/crypto/rsa_private_key.h" +#include "base/file_path.h" +#include "base/file_util.h" +#include "base/nss_util.h" +#include "base/path_service.h" +#include "net/base/address_list.h" +#include "net/base/cert_verifier.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/ssl_config_service.h" +#include "net/base/x509_certificate.h" +#include "net/socket/client_socket.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +namespace net { + +namespace { + +class FakeDataChannel { + public: + FakeDataChannel() : read_callback_(NULL), read_buf_len_(0) { + } + + virtual int Read(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + if (data_.empty()) { + read_callback_ = callback; + read_buf_ = buf; + read_buf_len_ = buf_len; + return net::ERR_IO_PENDING; + } + return PropogateData(buf, buf_len); + } + + virtual int Write(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + data_.push(new net::DrainableIOBuffer(buf, buf_len)); + DoReadCallback(); + return buf_len; + } + + private: + void DoReadCallback() { + if (!read_callback_) + return; + + int copied = PropogateData(read_buf_, read_buf_len_); + net::CompletionCallback* callback = read_callback_; + read_callback_ = NULL; + read_buf_ = NULL; + read_buf_len_ = 0; + callback->Run(copied); + } + + int PropogateData(scoped_refptr read_buf, int read_buf_len) { + scoped_refptr buf = data_.front(); + int copied = std::min(buf->BytesRemaining(), read_buf_len); + memcpy(read_buf->data(), buf->data(), copied); + buf->DidConsume(copied); + + if (!buf->BytesRemaining()) + data_.pop(); + return copied; + } + + net::CompletionCallback* read_callback_; + scoped_refptr read_buf_; + int read_buf_len_; + + std::queue > data_; + + DISALLOW_COPY_AND_ASSIGN(FakeDataChannel); +}; + +class FakeSocket : public ClientSocket { + public: + FakeSocket(FakeDataChannel* incoming_channel, + FakeDataChannel* outgoing_channel) + : incoming_(incoming_channel), + outgoing_(outgoing_channel) { + } + + virtual ~FakeSocket() { + + } + + virtual int Read(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + return incoming_->Read(buf, buf_len, callback); + } + + virtual int Write(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + return outgoing_->Write(buf, buf_len, callback); + } + + virtual bool SetReceiveBufferSize(int32 size) { + return true; + } + + virtual bool SetSendBufferSize(int32 size) { + return true; + } + + virtual int Connect(CompletionCallback* callback) { + return net::OK; + } + + virtual void Disconnect() {} + + virtual bool IsConnected() const { + return true; + } + + virtual bool IsConnectedAndIdle() const { + return true; + } + + virtual int GetPeerAddress(AddressList* address) const { + net::IPAddressNumber ip_address(4); + *address = net::AddressList(ip_address, 0, false); + return net::OK; + } + + virtual const BoundNetLog& NetLog() const { + return net_log_; + } + + virtual void SetSubresourceSpeculation() {} + virtual void SetOmniboxSpeculation() {} + + virtual bool WasEverUsed() const { + return true; + } + + virtual bool UsingTCPFastOpen() const { + return false; + } + + private: + net::BoundNetLog net_log_; + FakeDataChannel* incoming_; + FakeDataChannel* outgoing_; + + DISALLOW_COPY_AND_ASSIGN(FakeSocket); +}; + +} // namespace + +// Verify the correctness of the test helper classes first. +TEST(FakeSocketTest, DataTransfer) { + // Establish channels between two sockets. + FakeDataChannel channel_1; + FakeDataChannel channel_2; + FakeSocket client(&channel_1, &channel_2); + FakeSocket server(&channel_2, &channel_1); + + const char kTestData[] = "testing123"; + const int kTestDataSize = strlen(kTestData); + const int kReadBufSize = 1024; + scoped_refptr write_buf = new net::StringIOBuffer(kTestData); + scoped_refptr read_buf = new net::IOBuffer(kReadBufSize); + + // Write then read. + EXPECT_EQ(kTestDataSize, server.Write(write_buf, kTestDataSize, NULL)); + EXPECT_EQ(kTestDataSize, client.Read(read_buf, kReadBufSize, NULL)); + EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize)); + + // Read then write. + TestCompletionCallback callback; + EXPECT_EQ(net::ERR_IO_PENDING, + server.Read(read_buf, kReadBufSize, &callback)); + EXPECT_EQ(kTestDataSize, client.Write(write_buf, kTestDataSize, NULL)); + EXPECT_EQ(kTestDataSize, callback.WaitForResult()); + EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize)); +} + +class SSLServerSocketTest : public PlatformTest { + public: + SSLServerSocketTest() + : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) { + } + + protected: + void Initialize() { + FakeSocket* fake_client_socket = new FakeSocket(&channel_1_, &channel_2_); + FakeSocket* fake_server_socket = new FakeSocket(&channel_2_, &channel_1_); + + FilePath certs_dir; + PathService::Get(base::DIR_SOURCE_ROOT, &certs_dir); + certs_dir = certs_dir.AppendASCII("net"); + certs_dir = certs_dir.AppendASCII("data"); + certs_dir = certs_dir.AppendASCII("ssl"); + certs_dir = certs_dir.AppendASCII("certificates"); + + FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); + std::string cert_der; + ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der)); + + scoped_refptr cert = + X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); + + FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); + std::string key_string; + ASSERT_TRUE(file_util::ReadFileToString(key_path, &key_string)); + std::vector key_vector( + reinterpret_cast(key_string.data()), + reinterpret_cast(key_string.data() + + key_string.length())); + + scoped_ptr private_key( + base::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); + + net::SSLConfig ssl_config; + ssl_config.false_start_enabled = false; + ssl_config.snap_start_enabled = false; + ssl_config.ssl3_enabled = true; + ssl_config.tls1_enabled = true; + ssl_config.session_resume_disabled = true; + + // Certificate provided by the host doesn't need authority. + net::SSLConfig::CertAndStatus cert_and_status; + cert_and_status.cert_status = net::ERR_CERT_AUTHORITY_INVALID; + cert_and_status.cert = cert; + ssl_config.allowed_bad_certs.push_back(cert_and_status); + + net::HostPortPair host_and_pair("unittest", 0); + client_socket_.reset( + socket_factory_->CreateSSLClientSocket( + fake_client_socket, host_and_pair, ssl_config, NULL, + &cert_verifier_)); + server_socket_.reset(net::CreateSSLServerSocket(fake_server_socket, + cert, private_key.get(), + net::SSLConfig())); + } + + FakeDataChannel channel_1_; + FakeDataChannel channel_2_; + scoped_ptr client_socket_; + scoped_ptr server_socket_; + net::ClientSocketFactory* socket_factory_; + net::CertVerifier cert_verifier_; +}; + +// SSLServerSocket is only implemented using NSS. +#if defined(USE_NSS) || defined(OS_WIN) || defined(OS_MACOSX) + +// This test only executes creation of client and server sockets. This is to +// test that creation of sockets doesn't crash and have minimal code to run +// under valgrind in order to help debugging memory problems. +TEST_F(SSLServerSocketTest, Initialize) { + Initialize(); +} + +// This test executes Connect() of SSLClientSocket and Accept() of +// SSLServerSocket to make sure handshaking between the two sockets are +// completed successfully. +TEST_F(SSLServerSocketTest, Handshake) { + Initialize(); + + if (!base::CheckNSSVersion("3.12.8")) + return; + + TestCompletionCallback connect_callback; + TestCompletionCallback accept_callback; + + int server_ret = server_socket_->Accept(&accept_callback); + EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); + + int client_ret = client_socket_->Connect(&connect_callback); + EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); + + if (client_ret == net::ERR_IO_PENDING) { + EXPECT_EQ(net::OK, connect_callback.WaitForResult()); + } + if (server_ret == net::ERR_IO_PENDING) { + EXPECT_EQ(net::OK, accept_callback.WaitForResult()); + } +} + +TEST_F(SSLServerSocketTest, DataTransfer) { + Initialize(); + + if (!base::CheckNSSVersion("3.12.8")) + return; + + TestCompletionCallback connect_callback; + TestCompletionCallback accept_callback; + + // Establish connection. + int client_ret = client_socket_->Connect(&connect_callback); + EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); + + int server_ret = server_socket_->Accept(&accept_callback); + EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); + + if (client_ret == net::ERR_IO_PENDING) { + EXPECT_EQ(net::OK, connect_callback.WaitForResult()); + } + if (server_ret == net::ERR_IO_PENDING) { + EXPECT_EQ(net::OK, accept_callback.WaitForResult()); + } + + const int kReadBufSize = 1024; + scoped_refptr write_buf = + new net::StringIOBuffer("testing123"); + scoped_refptr read_buf = new net::IOBuffer(kReadBufSize); + + // Write then read. + TestCompletionCallback write_callback; + TestCompletionCallback read_callback; + server_ret = server_socket_->Write(write_buf, write_buf->size(), + &write_callback); + EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); + client_ret = client_socket_->Read(read_buf, kReadBufSize, &read_callback); + EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); + + if (server_ret == net::ERR_IO_PENDING) { + EXPECT_GT(write_callback.WaitForResult(), 0); + } + if (client_ret == net::ERR_IO_PENDING) { + EXPECT_GT(read_callback.WaitForResult(), 0); + } + EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); + + // Read then write. + write_buf = new net::StringIOBuffer("hello123"); + server_ret = server_socket_->Read(read_buf, kReadBufSize, &read_callback); + EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); + client_ret = client_socket_->Write(write_buf, write_buf->size(), + &write_callback); + EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); + + if (server_ret == net::ERR_IO_PENDING) { + EXPECT_GT(read_callback.WaitForResult(), 0); + } + if (client_ret == net::ERR_IO_PENDING) { + EXPECT_GT(write_callback.WaitForResult(), 0); + } + EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); +} +#endif + +} // namespace net -- cgit v1.1