diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/base/net_log_event_type_list.h | 3 | ||||
-rw-r--r-- | net/base/ssl_config_service.cc | 5 | ||||
-rw-r--r-- | net/base/ssl_config_service.h | 5 | ||||
-rw-r--r-- | net/base/x509_certificate.h | 5 | ||||
-rw-r--r-- | net/base/x509_certificate_mac.cc | 11 | ||||
-rw-r--r-- | net/base/x509_certificate_nss.cc | 9 | ||||
-rw-r--r-- | net/base/x509_certificate_openssl.cc | 5 | ||||
-rw-r--r-- | net/base/x509_certificate_unittest.cc | 12 | ||||
-rw-r--r-- | net/base/x509_certificate_win.cc | 9 | ||||
-rw-r--r-- | net/net.gyp | 8 | ||||
-rw-r--r-- | net/socket/nss_ssl_util.cc | 240 | ||||
-rw-r--r-- | net/socket/nss_ssl_util.h | 36 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_nss.cc | 232 | ||||
-rw-r--r-- | net/socket/ssl_server_socket.h | 53 | ||||
-rw-r--r-- | net/socket/ssl_server_socket_nss.cc | 677 | ||||
-rw-r--r-- | net/socket/ssl_server_socket_nss.h | 133 | ||||
-rw-r--r-- | net/socket/ssl_server_socket_unittest.cc | 369 |
17 files changed, 1589 insertions, 223 deletions
diff --git a/net/base/net_log_event_type_list.h b/net/base/net_log_event_type_list.h index e571685..f1bc4f8 100644 --- a/net/base/net_log_event_type_list.h +++ b/net/base/net_log_event_type_list.h @@ -325,6 +325,9 @@ EVENT_TYPE(SOCKS_UNKNOWN_ADDRESS_TYPE) // The start/end of a SSL connect(). EVENT_TYPE(SSL_CONNECT) +// The start/end of a SSL accept(). +EVENT_TYPE(SSL_ACCEPT) + // An SSL error occurred while trying to do the indicated activity. // The following parameters are attached to the event: // { diff --git a/net/base/ssl_config_service.cc b/net/base/ssl_config_service.cc index 9b0a903..d02df38 100644 --- a/net/base/ssl_config_service.cc +++ b/net/base/ssl_config_service.cc @@ -23,8 +23,9 @@ SSLConfig::SSLConfig() : rev_checking_enabled(true), ssl3_enabled(true), tls1_enabled(true), dnssec_enabled(false), snap_start_enabled(false), dns_cert_provenance_checking_enabled(false), - mitm_proxies_allowed(false), false_start_enabled(true), - send_client_cert(false), verify_ev_cert(false), ssl3_fallback(false) { + session_resume_disabled(false), mitm_proxies_allowed(false), + false_start_enabled(true), send_client_cert(false), + verify_ev_cert(false), ssl3_fallback(false) { } SSLConfig::~SSLConfig() { diff --git a/net/base/ssl_config_service.h b/net/base/ssl_config_service.h index c1ae553..de2ebef 100644 --- a/net/base/ssl_config_service.h +++ b/net/base/ssl_config_service.h @@ -32,6 +32,11 @@ struct SSLConfig { // True if we'll do async checks for certificate provenance using DNS. bool dns_cert_provenance_checking_enabled; + // TODO(hclam): This option is used to simplify the SSLServerSocketNSS + // implementation and should be removed when session caching is implemented. + // See http://crbug.com/67236 for more details. + bool session_resume_disabled; // Don't allow session resume. + // Cipher suites which should be explicitly prevented from being used in // addition to those disabled by the net built-in policy -- by default, all // cipher suites supported by the underlying SSL implementation will be diff --git a/net/base/x509_certificate.h b/net/base/x509_certificate.h index 3ee7304..c59c33c 100644 --- a/net/base/x509_certificate.h +++ b/net/base/x509_certificate.h @@ -287,6 +287,11 @@ class X509Certificate : public base::RefCountedThreadSafe<X509Certificate> { int flags, CertVerifyResult* verify_result) const; + // This method returns the DER encoded certificate. + // If the return value is true then the DER encoded certificate is available. + // The content of the DER encoded certificate is written to |encoded|. + bool GetDEREncoded(std::string* encoded); + OSCertHandle os_cert_handle() const { return cert_handle_; } // Returns true if two OSCertHandles refer to identical certificates. diff --git a/net/base/x509_certificate_mac.cc b/net/base/x509_certificate_mac.cc index f7c89e4..fd965cb3 100644 --- a/net/base/x509_certificate_mac.cc +++ b/net/base/x509_certificate_mac.cc @@ -651,6 +651,17 @@ int X509Certificate::Verify(const std::string& hostname, int flags, return OK; } +bool X509Certificate::GetDEREncoded(std::string* encoded) { + encoded->clear(); + CSSM_DATA der_data; + if(SecCertificateGetData(cert_handle_, &der_data) == noErr) { + encoded->append(reinterpret_cast<char*>(der_data.Data), + der_data.Length); + return true; + } + return false; +} + bool X509Certificate::VerifyEV() const { // We don't call this private method, but we do need to implement it because // it's defined in x509_certificate.h. We perform EV checking in the diff --git a/net/base/x509_certificate_nss.cc b/net/base/x509_certificate_nss.cc index 2962cb5..05e736c 100644 --- a/net/base/x509_certificate_nss.cc +++ b/net/base/x509_certificate_nss.cc @@ -829,6 +829,15 @@ bool X509Certificate::VerifyEV() const { return true; } +bool X509Certificate::GetDEREncoded(std::string* encoded) { + if (!cert_handle_->derCert.len) + return false; + encoded->clear(); + encoded->append(reinterpret_cast<char*>(cert_handle_->derCert.data), + cert_handle_->derCert.len); + return true; +} + // static bool X509Certificate::IsSameOSCert(X509Certificate::OSCertHandle a, X509Certificate::OSCertHandle b) { diff --git a/net/base/x509_certificate_openssl.cc b/net/base/x509_certificate_openssl.cc index c6ffb2c..cf43610 100644 --- a/net/base/x509_certificate_openssl.cc +++ b/net/base/x509_certificate_openssl.cc @@ -462,6 +462,11 @@ int X509Certificate::Verify(const std::string& hostname, return OK; } +bool X509Certificate::GetDEREncoded(std::string* encoded) { + // TODO(port): Implement. + return false; +} + // static bool X509Certificate::IsSameOSCert(X509Certificate::OSCertHandle a, X509Certificate::OSCertHandle b) { diff --git a/net/base/x509_certificate_unittest.cc b/net/base/x509_certificate_unittest.cc index 83c11fa..dba5ef3 100644 --- a/net/base/x509_certificate_unittest.cc +++ b/net/base/x509_certificate_unittest.cc @@ -672,6 +672,18 @@ TEST(X509CertificateTest, CreateSelfSigned) { EXPECT_EQ("subject", cert->subject().GetDisplayName()); EXPECT_FALSE(cert->HasExpired()); } + +TEST(X509CertificateTest, GetDEREncoded) { + scoped_ptr<base::RSAPrivateKey> private_key( + base::RSAPrivateKey::Create(1024)); + scoped_refptr<net::X509Certificate> cert = + net::X509Certificate::CreateSelfSigned( + private_key.get(), "CN=subject", 0, base::TimeDelta::FromDays(1)); + + std::string der_cert; + EXPECT_TRUE(cert->GetDEREncoded(&der_cert)); + EXPECT_FALSE(der_cert.empty()); +} #endif class X509CertificateParseTest diff --git a/net/base/x509_certificate_win.cc b/net/base/x509_certificate_win.cc index 568c1fd..663563d 100644 --- a/net/base/x509_certificate_win.cc +++ b/net/base/x509_certificate_win.cc @@ -843,6 +843,15 @@ int X509Certificate::Verify(const std::string& hostname, return OK; } +bool X509Certificate::GetDEREncoded(std::string* encoded) { + if (!cert_handle_->pbCertEncoded || !cert_handle_->cbCertEncoded) + return false; + encoded->clear(); + encoded->append(reinterpret_cast<char*>(cert_handle_->pbCertEncoded), + cert_handle_->cbCertEncoded); + return true; +} + // Returns true if the certificate is an extended-validation certificate. // // This function checks the certificatePolicies extensions of the diff --git a/net/net.gyp b/net/net.gyp index 5e58283..2c24152 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -594,6 +594,8 @@ 'socket/client_socket_pool_manager.h', 'socket/dns_cert_provenance_checker.cc', 'socket/dns_cert_provenance_checker.h', + 'socket/nss_ssl_util.cc', + 'socket/nss_ssl_util.h', 'socket/socket.h', 'socket/socks5_client_socket.cc', 'socket/socks5_client_socket.h', @@ -619,6 +621,9 @@ 'socket/ssl_client_socket_win.h', 'socket/ssl_error_params.cc', 'socket/ssl_error_params.h', + 'socket/ssl_server_socket.h', + 'socket/ssl_server_socket_nss.cc', + 'socket/ssl_server_socket_nss.h', 'socket/ssl_host_info.cc', 'socket/ssl_host_info.h', 'socket/tcp_client_socket.cc', @@ -750,6 +755,8 @@ 'socket/ssl_client_socket_nss.h', 'socket/ssl_client_socket_nss_factory.cc', 'socket/ssl_client_socket_nss_factory.h', + 'socket/ssl_server_socket_nss.cc', + 'socket/ssl_server_socket_nss.h', ], }, { # else !use_openssl: remove the unneeded files @@ -960,6 +967,7 @@ 'socket/socks_client_socket_unittest.cc', 'socket/ssl_client_socket_unittest.cc', 'socket/ssl_client_socket_pool_unittest.cc', + 'socket/ssl_server_socket_unittest.cc', 'socket/tcp_client_socket_pool_unittest.cc', 'socket/tcp_client_socket_unittest.cc', 'socket_stream/socket_stream_metrics_unittest.cc', diff --git a/net/socket/nss_ssl_util.cc b/net/socket/nss_ssl_util.cc new file mode 100644 index 0000000..eb8bafb --- /dev/null +++ b/net/socket/nss_ssl_util.cc @@ -0,0 +1,240 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/nss_ssl_util.h" + +#include <nss.h> +#include <secerr.h> +#include <ssl.h> +#include <sslerr.h> + +#include "base/lazy_instance.h" +#include "base/logging.h" +#include "base/nss_util.h" +#include "base/singleton.h" +#include "base/thread_restrictions.h" +#include "base/values.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" + +namespace net { + +class NSSSSLInitSingleton { + public: + NSSSSLInitSingleton() { + base::EnsureNSSInit(); + + NSS_SetDomesticPolicy(); + +#if defined(USE_SYSTEM_SSL) + // Use late binding to avoid scary but benign warning + // "Symbol `SSL_ImplementedCiphers' has different size in shared object, + // consider re-linking" + // TODO(wtc): Use the new SSL_GetImplementedCiphers and + // SSL_GetNumImplementedCiphers functions when we require NSS 3.12.6. + // See https://bugzilla.mozilla.org/show_bug.cgi?id=496993. + const PRUint16* pSSL_ImplementedCiphers = static_cast<const PRUint16*>( + dlsym(RTLD_DEFAULT, "SSL_ImplementedCiphers")); + if (pSSL_ImplementedCiphers == NULL) { + NOTREACHED() << "Can't get list of supported ciphers"; + return; + } +#else +#define pSSL_ImplementedCiphers SSL_ImplementedCiphers +#endif + + // Explicitly enable exactly those ciphers with keys of at least 80 bits + for (int i = 0; i < SSL_NumImplementedCiphers; i++) { + SSLCipherSuiteInfo info; + if (SSL_GetCipherSuiteInfo(pSSL_ImplementedCiphers[i], &info, + sizeof(info)) == SECSuccess) { + SSL_CipherPrefSetDefault(pSSL_ImplementedCiphers[i], + (info.effectiveKeyBits >= 80)); + } + } + + // Enable SSL. + SSL_OptionSetDefault(SSL_SECURITY, PR_TRUE); + + // All other SSL options are set per-session by SSLClientSocket and + // SSLServerSocket. + } + + ~NSSSSLInitSingleton() { + // Have to clear the cache, or NSS_Shutdown fails with SEC_ERROR_BUSY. + SSL_ClearSessionCache(); + } +}; + +static base::LazyInstance<NSSSSLInitSingleton> g_nss_ssl_init_singleton( + base::LINKER_INITIALIZED); + +// Initialize the NSS SSL library if it isn't already initialized. This must +// be called before any other NSS SSL functions. This function is +// thread-safe, and the NSS SSL library will only ever be initialized once. +// The NSS SSL library will be properly shut down on program exit. +void EnsureNSSSSLInit() { + // Initializing SSL causes us to do blocking IO. + // Temporarily allow it until we fix + // http://code.google.com/p/chromium/issues/detail?id=59847 + base::ThreadRestrictions::ScopedAllowIO allow_io; + + g_nss_ssl_init_singleton.Get(); +} + +// Map a Chromium net error code to an NSS error code. +// See _MD_unix_map_default_error in the NSS source +// tree for inspiration. +PRErrorCode MapErrorToNSS(int result) { + if (result >=0) + return result; + + switch (result) { + case ERR_IO_PENDING: + return PR_WOULD_BLOCK_ERROR; + case ERR_ACCESS_DENIED: + case ERR_NETWORK_ACCESS_DENIED: + // For connect, this could be mapped to PR_ADDRESS_NOT_SUPPORTED_ERROR. + return PR_NO_ACCESS_RIGHTS_ERROR; + case ERR_NOT_IMPLEMENTED: + return PR_NOT_IMPLEMENTED_ERROR; + case ERR_INTERNET_DISCONNECTED: // Equivalent to ENETDOWN. + return PR_NETWORK_UNREACHABLE_ERROR; // Best approximation. + case ERR_CONNECTION_TIMED_OUT: + case ERR_TIMED_OUT: + return PR_IO_TIMEOUT_ERROR; + case ERR_CONNECTION_RESET: + return PR_CONNECT_RESET_ERROR; + case ERR_CONNECTION_ABORTED: + return PR_CONNECT_ABORTED_ERROR; + case ERR_CONNECTION_REFUSED: + return PR_CONNECT_REFUSED_ERROR; + case ERR_ADDRESS_UNREACHABLE: + return PR_HOST_UNREACHABLE_ERROR; // Also PR_NETWORK_UNREACHABLE_ERROR. + case ERR_ADDRESS_INVALID: + return PR_ADDRESS_NOT_AVAILABLE_ERROR; + case ERR_NAME_NOT_RESOLVED: + return PR_DIRECTORY_LOOKUP_ERROR; + default: + LOG(WARNING) << "MapErrorToNSS " << result + << " mapped to PR_UNKNOWN_ERROR"; + return PR_UNKNOWN_ERROR; + } +} + +// The default error mapping function. +// Maps an NSS error code to a network error code. +int MapNSSError(PRErrorCode err) { + // TODO(port): fill this out as we learn what's important + switch (err) { + case PR_WOULD_BLOCK_ERROR: + return ERR_IO_PENDING; + case PR_ADDRESS_NOT_SUPPORTED_ERROR: // For connect. + case PR_NO_ACCESS_RIGHTS_ERROR: + return ERR_ACCESS_DENIED; + case PR_IO_TIMEOUT_ERROR: + return ERR_TIMED_OUT; + case PR_CONNECT_RESET_ERROR: + return ERR_CONNECTION_RESET; + case PR_CONNECT_ABORTED_ERROR: + return ERR_CONNECTION_ABORTED; + case PR_CONNECT_REFUSED_ERROR: + return ERR_CONNECTION_REFUSED; + case PR_HOST_UNREACHABLE_ERROR: + case PR_NETWORK_UNREACHABLE_ERROR: + return ERR_ADDRESS_UNREACHABLE; + case PR_ADDRESS_NOT_AVAILABLE_ERROR: + return ERR_ADDRESS_INVALID; + case PR_INVALID_ARGUMENT_ERROR: + return ERR_INVALID_ARGUMENT; + case PR_END_OF_FILE_ERROR: + return ERR_CONNECTION_CLOSED; + case PR_NOT_IMPLEMENTED_ERROR: + return ERR_NOT_IMPLEMENTED; + + case SEC_ERROR_INVALID_ARGS: + return ERR_INVALID_ARGUMENT; + + case SSL_ERROR_SSL_DISABLED: + return ERR_NO_SSL_VERSIONS_ENABLED; + case SSL_ERROR_NO_CYPHER_OVERLAP: + case SSL_ERROR_UNSUPPORTED_VERSION: + return ERR_SSL_VERSION_OR_CIPHER_MISMATCH; + case SSL_ERROR_HANDSHAKE_FAILURE_ALERT: + case SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT: + case SSL_ERROR_ILLEGAL_PARAMETER_ALERT: + return ERR_SSL_PROTOCOL_ERROR; + case SSL_ERROR_DECOMPRESSION_FAILURE_ALERT: + return ERR_SSL_DECOMPRESSION_FAILURE_ALERT; + case SSL_ERROR_BAD_MAC_ALERT: + return ERR_SSL_BAD_RECORD_MAC_ALERT; + case SSL_ERROR_UNSAFE_NEGOTIATION: + return ERR_SSL_UNSAFE_NEGOTIATION; + case SSL_ERROR_WEAK_SERVER_KEY: + return ERR_SSL_WEAK_SERVER_EPHEMERAL_DH_KEY; + + default: { + if (IS_SSL_ERROR(err)) { + LOG(WARNING) << "Unknown SSL error " << err << + " mapped to net::ERR_SSL_PROTOCOL_ERROR"; + return ERR_SSL_PROTOCOL_ERROR; + } + LOG(WARNING) << "Unknown error " << err << + " mapped to net::ERR_FAILED"; + return ERR_FAILED; + } + } +} + +// Context-sensitive error mapping functions. +int MapNSSHandshakeError(PRErrorCode err) { + switch (err) { + // If the server closed on us, it is a protocol error. + // Some TLS-intolerant servers do this when we request TLS. + case PR_END_OF_FILE_ERROR: + // The handshake may fail because some signature (for example, the + // signature in the ServerKeyExchange message for an ephemeral + // Diffie-Hellman cipher suite) is invalid. + case SEC_ERROR_BAD_SIGNATURE: + return ERR_SSL_PROTOCOL_ERROR; + default: + return MapNSSError(err); + } +} + +// Extra parameters to attach to the NetLog when we receive an error in response +// to a call to an NSS function. Used instead of SSLErrorParams with +// events of type TYPE_SSL_NSS_ERROR. Automatically looks up last PR error. +class SSLFailedNSSFunctionParams : public NetLog::EventParameters { + public: + // |param| is ignored if it has a length of 0. + SSLFailedNSSFunctionParams(const std::string& function, + const std::string& param) + : function_(function), param_(param), ssl_lib_error_(PR_GetError()) { + } + + virtual Value* ToValue() const { + DictionaryValue* dict = new DictionaryValue(); + dict->SetString("function", function_); + if (!param_.empty()) + dict->SetString("param", param_); + dict->SetInteger("ssl_lib_error", ssl_lib_error_); + return dict; + } + + private: + const std::string function_; + const std::string param_; + const PRErrorCode ssl_lib_error_; +}; + +void LogFailedNSSFunction(const BoundNetLog& net_log, + const char* function, + const char* param) { + net_log.AddEvent( + NetLog::TYPE_SSL_NSS_ERROR, + make_scoped_refptr(new SSLFailedNSSFunctionParams(function, param))); +} + +} // namespace net diff --git a/net/socket/nss_ssl_util.h b/net/socket/nss_ssl_util.h new file mode 100644 index 0000000..64bf3cf --- /dev/null +++ b/net/socket/nss_ssl_util.h @@ -0,0 +1,36 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file is only inclued in ssl_client_socket_nss.cc and +// ssl_server_socket_nss.cc to share common functions of NSS. + +#ifndef NET_SOCKET_NSS_SSL_UTIL_H_ +#define NEt_SOCKET_NSS_SSL_UTIL_H_ + +#include <prerror.h> + +namespace net { + +class BoundNetLog; + +// Initalize NSS SSL library. +void EnsureNSSSSLInit(); + +// Log a failed NSS funcion call. +void LogFailedNSSFunction(const BoundNetLog& net_log, + const char* function, + const char* param); + +// Map network error code to NSS error code. +PRErrorCode MapErrorToNSS(int result); + +// Map NSS error code to network error code. +int MapNSSError(PRErrorCode err); + +// Map NSS handshake error to network error code. +int MapNSSHandshakeError(PRErrorCode err); + +} // namespace net + +#endif // NET_SOCKET_NSS_SSL_UTIL_H_ diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index 3f56b59..49e065f 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -63,7 +63,6 @@ #include "base/compiler_specific.h" #include "base/metrics/histogram.h" -#include "base/lazy_instance.h" #include "base/logging.h" #include "base/nss_util.h" #include "base/string_number_conversions.h" @@ -88,6 +87,7 @@ #include "net/ocsp/nss_ocsp.h" #include "net/socket/client_socket_handle.h" #include "net/socket/dns_cert_provenance_checker.h" +#include "net/socket/nss_ssl_util.h" #include "net/socket/ssl_error_params.h" #include "net/socket/ssl_host_info.h" @@ -139,183 +139,6 @@ namespace net { namespace { -class NSSSSLInitSingleton { - public: - NSSSSLInitSingleton() { - base::EnsureNSSInit(); - - NSS_SetDomesticPolicy(); - -#if defined(USE_SYSTEM_SSL) - // Use late binding to avoid scary but benign warning - // "Symbol `SSL_ImplementedCiphers' has different size in shared object, - // consider re-linking" - // TODO(wtc): Use the new SSL_GetImplementedCiphers and - // SSL_GetNumImplementedCiphers functions when we require NSS 3.12.6. - // See https://bugzilla.mozilla.org/show_bug.cgi?id=496993. - const PRUint16* pSSL_ImplementedCiphers = static_cast<const PRUint16*>( - dlsym(RTLD_DEFAULT, "SSL_ImplementedCiphers")); - if (pSSL_ImplementedCiphers == NULL) { - NOTREACHED() << "Can't get list of supported ciphers"; - return; - } -#else -#define pSSL_ImplementedCiphers SSL_ImplementedCiphers -#endif - - // Explicitly enable exactly those ciphers with keys of at least 80 bits - for (int i = 0; i < SSL_NumImplementedCiphers; i++) { - SSLCipherSuiteInfo info; - if (SSL_GetCipherSuiteInfo(pSSL_ImplementedCiphers[i], &info, - sizeof(info)) == SECSuccess) { - SSL_CipherPrefSetDefault(pSSL_ImplementedCiphers[i], - (info.effectiveKeyBits >= 80)); - } - } - - // Enable SSL. - SSL_OptionSetDefault(SSL_SECURITY, PR_TRUE); - - // All other SSL options are set per-session by SSLClientSocket. - } - - ~NSSSSLInitSingleton() { - // Have to clear the cache, or NSS_Shutdown fails with SEC_ERROR_BUSY. - SSL_ClearSessionCache(); - } -}; - -static base::LazyInstance<NSSSSLInitSingleton> g_nss_ssl_init_singleton( - base::LINKER_INITIALIZED); - -// Initialize the NSS SSL library if it isn't already initialized. This must -// be called before any other NSS SSL functions. This function is -// thread-safe, and the NSS SSL library will only ever be initialized once. -// The NSS SSL library will be properly shut down on program exit. -void EnsureNSSSSLInit() { - // Initializing SSL causes us to do blocking IO. - // Temporarily allow it until we fix - // http://code.google.com/p/chromium/issues/detail?id=59847 - base::ThreadRestrictions::ScopedAllowIO allow_io; - - g_nss_ssl_init_singleton.Get(); -} - -// The default error mapping function. -// Maps an NSPR error code to a network error code. -int MapNSPRError(PRErrorCode err) { - // TODO(port): fill this out as we learn what's important - switch (err) { - case PR_WOULD_BLOCK_ERROR: - return ERR_IO_PENDING; - case PR_ADDRESS_NOT_SUPPORTED_ERROR: // For connect. - case PR_NO_ACCESS_RIGHTS_ERROR: - return ERR_ACCESS_DENIED; - case PR_IO_TIMEOUT_ERROR: - return ERR_TIMED_OUT; - case PR_CONNECT_RESET_ERROR: - return ERR_CONNECTION_RESET; - case PR_CONNECT_ABORTED_ERROR: - return ERR_CONNECTION_ABORTED; - case PR_CONNECT_REFUSED_ERROR: - return ERR_CONNECTION_REFUSED; - case PR_HOST_UNREACHABLE_ERROR: - case PR_NETWORK_UNREACHABLE_ERROR: - return ERR_ADDRESS_UNREACHABLE; - case PR_ADDRESS_NOT_AVAILABLE_ERROR: - return ERR_ADDRESS_INVALID; - case PR_INVALID_ARGUMENT_ERROR: - return ERR_INVALID_ARGUMENT; - case PR_END_OF_FILE_ERROR: - return ERR_CONNECTION_CLOSED; - case PR_NOT_IMPLEMENTED_ERROR: - return ERR_NOT_IMPLEMENTED; - - case SEC_ERROR_INVALID_ARGS: - return ERR_INVALID_ARGUMENT; - - case SSL_ERROR_SSL_DISABLED: - return ERR_NO_SSL_VERSIONS_ENABLED; - case SSL_ERROR_NO_CYPHER_OVERLAP: - case SSL_ERROR_UNSUPPORTED_VERSION: - return ERR_SSL_VERSION_OR_CIPHER_MISMATCH; - case SSL_ERROR_HANDSHAKE_FAILURE_ALERT: - case SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT: - case SSL_ERROR_ILLEGAL_PARAMETER_ALERT: - return ERR_SSL_PROTOCOL_ERROR; - case SSL_ERROR_DECOMPRESSION_FAILURE_ALERT: - return ERR_SSL_DECOMPRESSION_FAILURE_ALERT; - case SSL_ERROR_BAD_MAC_ALERT: - return ERR_SSL_BAD_RECORD_MAC_ALERT; - case SSL_ERROR_UNSAFE_NEGOTIATION: - return ERR_SSL_UNSAFE_NEGOTIATION; - case SSL_ERROR_WEAK_SERVER_KEY: - return ERR_SSL_WEAK_SERVER_EPHEMERAL_DH_KEY; - - default: { - if (IS_SSL_ERROR(err)) { - LOG(WARNING) << "Unknown SSL error " << err << - " mapped to net::ERR_SSL_PROTOCOL_ERROR"; - return ERR_SSL_PROTOCOL_ERROR; - } - LOG(WARNING) << "Unknown error " << err << - " mapped to net::ERR_FAILED"; - return ERR_FAILED; - } - } -} - -// Context-sensitive error mapping functions. - -int MapHandshakeError(PRErrorCode err) { - switch (err) { - // If the server closed on us, it is a protocol error. - // Some TLS-intolerant servers do this when we request TLS. - case PR_END_OF_FILE_ERROR: - // The handshake may fail because some signature (for example, the - // signature in the ServerKeyExchange message for an ephemeral - // Diffie-Hellman cipher suite) is invalid. - case SEC_ERROR_BAD_SIGNATURE: - return ERR_SSL_PROTOCOL_ERROR; - default: - return MapNSPRError(err); - } -} - -// Extra parameters to attach to the NetLog when we receive an error in response -// to a call to an NSS function. Used instead of SSLErrorParams with -// events of type TYPE_SSL_NSS_ERROR. Automatically looks up last PR error. -class SSLFailedNSSFunctionParams : public NetLog::EventParameters { - public: - // |param| is ignored if it has a length of 0. - SSLFailedNSSFunctionParams(const std::string& function, - const std::string& param) - : function_(function), param_(param), ssl_lib_error_(PR_GetError()) { - } - - virtual Value* ToValue() const { - DictionaryValue* dict = new DictionaryValue(); - dict->SetString("function", function_); - if (!param_.empty()) - dict->SetString("param", param_); - dict->SetInteger("ssl_lib_error", ssl_lib_error_); - return dict; - } - - private: - const std::string function_; - const std::string param_; - const PRErrorCode ssl_lib_error_; -}; - -void LogFailedNSSFunction(const BoundNetLog& net_log, - const char* function, - const char* param) { - net_log.AddEvent( - NetLog::TYPE_SSL_NSS_ERROR, - make_scoped_refptr(new SSLFailedNSSFunctionParams(function, param))); -} - #if defined(OS_WIN) // This callback is intended to be used with CertFindChainInStore. In addition @@ -736,6 +559,13 @@ int SSLClientSocketNSS::InitializeSSLOptions() { #error "You need to install NSS-3.12 or later to build chromium" #endif + rv = SSL_OptionSet(nss_fd_, SSL_NO_CACHE, + ssl_config_.session_resume_disabled); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_NO_CACHE"); + return ERR_UNEXPECTED; + } + #ifdef SSL_ENABLE_DEFLATE // Some web servers have been found to break if TLS is used *or* if DEFLATE // is advertised. Thus, if TLS is disabled (probably because we are doing @@ -1363,46 +1193,6 @@ void SSLClientSocketNSS::OnRecvComplete(int result) { LeaveFunction(""); } -// Map a Chromium net error code to an NSS error code. -// See _MD_unix_map_default_error in the NSS source -// tree for inspiration. -static PRErrorCode MapErrorToNSS(int result) { - if (result >=0) - return result; - - switch (result) { - case ERR_IO_PENDING: - return PR_WOULD_BLOCK_ERROR; - case ERR_ACCESS_DENIED: - case ERR_NETWORK_ACCESS_DENIED: - // For connect, this could be mapped to PR_ADDRESS_NOT_SUPPORTED_ERROR. - return PR_NO_ACCESS_RIGHTS_ERROR; - case ERR_NOT_IMPLEMENTED: - return PR_NOT_IMPLEMENTED_ERROR; - case ERR_INTERNET_DISCONNECTED: // Equivalent to ENETDOWN. - return PR_NETWORK_UNREACHABLE_ERROR; // Best approximation. - case ERR_CONNECTION_TIMED_OUT: - case ERR_TIMED_OUT: - return PR_IO_TIMEOUT_ERROR; - case ERR_CONNECTION_RESET: - return PR_CONNECT_RESET_ERROR; - case ERR_CONNECTION_ABORTED: - return PR_CONNECT_ABORTED_ERROR; - case ERR_CONNECTION_REFUSED: - return PR_CONNECT_REFUSED_ERROR; - case ERR_ADDRESS_UNREACHABLE: - return PR_HOST_UNREACHABLE_ERROR; // Also PR_NETWORK_UNREACHABLE_ERROR. - case ERR_ADDRESS_INVALID: - return PR_ADDRESS_NOT_AVAILABLE_ERROR; - case ERR_NAME_NOT_RESOLVED: - return PR_DIRECTORY_LOOKUP_ERROR; - default: - LOG(WARNING) << "MapErrorToNSS " << result - << " mapped to PR_UNKNOWN_ERROR"; - return PR_UNKNOWN_ERROR; - } -} - // Do network I/O between the given buffer and the given socket. // Return true if some I/O performed, false otherwise (error or ERR_IO_PENDING) bool SSLClientSocketNSS::DoTransportIO() { @@ -2191,7 +1981,7 @@ int SSLClientSocketNSS::DoHandshake() { } } else { PRErrorCode prerr = PR_GetError(); - net_error = MapHandshakeError(prerr); + net_error = MapNSSHandshakeError(prerr); // If not done, stay in this state if (net_error == ERR_IO_PENDING) { @@ -2580,7 +2370,7 @@ int SSLClientSocketNSS::DoPayloadRead() { return ERR_IO_PENDING; } LeaveFunction(""); - rv = MapNSPRError(prerr); + rv = MapNSSError(prerr); net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR, make_scoped_refptr(new SSLErrorParams(rv, prerr))); return rv; @@ -2601,7 +2391,7 @@ int SSLClientSocketNSS::DoPayloadWrite() { return ERR_IO_PENDING; } LeaveFunction(""); - rv = MapNSPRError(prerr); + rv = MapNSSError(prerr); net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR, make_scoped_refptr(new SSLErrorParams(rv, prerr))); return rv; diff --git a/net/socket/ssl_server_socket.h b/net/socket/ssl_server_socket.h new file mode 100644 index 0000000..b689c71 --- /dev/null +++ b/net/socket/ssl_server_socket.h @@ -0,0 +1,53 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_SERVER_SOCKET_H_ +#define NET_SOCKET_SSL_SERVER_SOCKET_H_ + +#include "base/basictypes.h" +#include "net/base/completion_callback.h" +#include "net/socket/socket.h" + +namespace base { +class RSAPrivateKey; +} // namespace base + +namespace net { + +class IOBuffer; +struct SSLConfig; +class X509Certificate; + +// SSLServerSocket takes an already connected socket and performs SSL on top of +// it. +// +// This class is designed to work in a peer-to-peer connection and is not +// intended to be used as a standalone SSL server. +class SSLServerSocket : public Socket { + public: + virtual ~SSLServerSocket() {} + + // Performs an SSL server handshake on the existing socket. The given socket + // must have already been connected. + // + // Accept either returns ERR_IO_PENDING, in which case the given callback + // will be called in the future with the real result, or it completes + // synchronously, returning the result immediately. + virtual int Accept(CompletionCallback* callback) = 0; +}; + +// Creates an SSL server socket using an already connected socket. A certificate +// and private key needs to be provided. +// +// This created server socket will take ownership of |socket|. However |key| +// is copied. +// TODO(hclam): Defines ServerSocketFactory to create SSLServerSocket. This will +// make mocking easier. +SSLServerSocket* CreateSSLServerSocket( + Socket* socket, X509Certificate* certificate, base::RSAPrivateKey* key, + const SSLConfig& ssl_config); + +} // namespace net + +#endif // NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_ diff --git a/net/socket/ssl_server_socket_nss.cc b/net/socket/ssl_server_socket_nss.cc new file mode 100644 index 0000000..2e47fb8 --- /dev/null +++ b/net/socket/ssl_server_socket_nss.cc @@ -0,0 +1,677 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/ssl_server_socket_nss.h" + +#if defined(OS_WIN) +#include <winsock2.h> +#endif + +#if defined(USE_SYSTEM_SSL) +#include <dlfcn.h> +#endif +#if defined(OS_MACOSX) +#include <Security/Security.h> +#endif +#include <certdb.h> +#include <cryptohi.h> +#include <hasht.h> +#include <keyhi.h> +#include <nspr.h> +#include <nss.h> +#include <pk11pub.h> +#include <secerr.h> +#include <sechash.h> +#include <ssl.h> +#include <sslerr.h> +#include <sslproto.h> + +#include <limits> + +#include "base/crypto/rsa_private_key.h" +#include "base/nss_util_internal.h" +#include "base/ref_counted.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/ocsp/nss_ocsp.h" +#include "net/socket/nss_ssl_util.h" +#include "net/socket/ssl_error_params.h" + +static const int kRecvBufferSize = 4096; + +#define GotoState(s) next_handshake_state_ = s + +namespace net { + +SSLServerSocket* CreateSSLServerSocket( + Socket* socket, X509Certificate* cert, base::RSAPrivateKey* key, + const SSLConfig& ssl_config) { + return new SSLServerSocketNSS(socket, cert, key, ssl_config); +} + +SSLServerSocketNSS::SSLServerSocketNSS( + Socket* transport_socket, + scoped_refptr<X509Certificate> cert, + base::RSAPrivateKey* key, + const SSLConfig& ssl_config) + : ALLOW_THIS_IN_INITIALIZER_LIST(buffer_send_callback_( + this, &SSLServerSocketNSS::BufferSendComplete)), + ALLOW_THIS_IN_INITIALIZER_LIST(buffer_recv_callback_( + this, &SSLServerSocketNSS::BufferRecvComplete)), + transport_send_busy_(false), + transport_recv_busy_(false), + user_accept_callback_(NULL), + user_read_callback_(NULL), + user_write_callback_(NULL), + nss_fd_(NULL), + nss_bufs_(NULL), + transport_socket_(transport_socket), + ssl_config_(ssl_config), + cert_(cert), + next_handshake_state_(STATE_NONE), + completed_handshake_(false) { + ssl_config_.false_start_enabled = false; + ssl_config_.ssl3_enabled = true; + ssl_config_.tls1_enabled = true; + + // TODO(hclam): Need a better way to clone a key. + std::vector<uint8> key_bytes; + CHECK(key->ExportPrivateKey(&key_bytes)); + key_.reset(base::RSAPrivateKey::CreateFromPrivateKeyInfo(key_bytes)); + CHECK(key_.get()); +} + +SSLServerSocketNSS::~SSLServerSocketNSS() { + if (nss_fd_ != NULL) { + PR_Close(nss_fd_); + nss_fd_ = NULL; + } +} + +int SSLServerSocketNSS::Init() { + // Initialize the NSS SSL library in a threadsafe way. This also + // initializes the NSS base library. + EnsureNSSSSLInit(); + if (!NSS_IsInitialized()) + return ERR_UNEXPECTED; +#if !defined(OS_MACOSX) && !defined(OS_WIN) + // We must call EnsureOCSPInit() here, on the IO thread, to get the IO loop + // by MessageLoopForIO::current(). + // X509Certificate::Verify() runs on a worker thread of CertVerifier. + EnsureOCSPInit(); +#endif + + return OK; +} + +int SSLServerSocketNSS::Accept(CompletionCallback* callback) { + net_log_.BeginEvent(NetLog::TYPE_SSL_ACCEPT, NULL); + + int rv = Init(); + if (rv != OK) { + LOG(ERROR) << "Failed to initialize NSS"; + net_log_.EndEvent(NetLog::TYPE_SSL_ACCEPT, NULL); + return rv; + } + + rv = InitializeSSLOptions(); + if (rv != OK) { + LOG(ERROR) << "Failed to initialize SSL options"; + net_log_.EndEvent(NetLog::TYPE_SSL_ACCEPT, NULL); + return rv; + } + + // Set peer address. TODO(hclam): This should be in a separate method. + PRNetAddr peername; + memset(&peername, 0, sizeof(peername)); + peername.raw.family = AF_INET; + memio_SetPeerName(nss_fd_, &peername); + + GotoState(STATE_HANDSHAKE); + rv = DoHandshakeLoop(net::OK); + if (rv == ERR_IO_PENDING) { + user_accept_callback_ = callback; + } else { + net_log_.EndEvent(NetLog::TYPE_SSL_ACCEPT, NULL); + } + + return rv > OK ? OK : rv; +} + +int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + DCHECK(!user_read_callback_); + DCHECK(!user_accept_callback_); + DCHECK(!user_read_buf_); + DCHECK(nss_bufs_); + + user_read_buf_ = buf; + user_read_buf_len_ = buf_len; + + DCHECK(completed_handshake_); + + int rv = DoReadLoop(OK); + + if (rv == ERR_IO_PENDING) { + user_read_callback_ = callback; + } else { + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + } + return rv; +} + +int SSLServerSocketNSS::Write(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + DCHECK(!user_write_callback_); + DCHECK(!user_write_buf_); + DCHECK(nss_bufs_); + + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoWriteLoop(OK); + + if (rv == ERR_IO_PENDING) { + user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + return rv; +} + +// static +// NSS calls this if an incoming certificate needs to be verified. +// Do nothing but return SECSuccess. +// This is called only in full handshake mode. +// Peer certificate is retrieved in HandshakeCallback() later, which is called +// in full handshake mode or in resumption handshake mode. +SECStatus SSLServerSocketNSS::OwnAuthCertHandler(void* arg, + PRFileDesc* socket, + PRBool checksig, + PRBool is_server) { + // TODO(hclam): Implement. + // Tell NSS to not verify the certificate. + return SECSuccess; +} + +// static +// NSS calls this when handshake is completed. +// After the SSL handshake is finished we need to verify the certificate. +void SSLServerSocketNSS::HandshakeCallback(PRFileDesc* socket, + void* arg) { + // TODO(hclam): Implement. +} + +int SSLServerSocketNSS::InitializeSSLOptions() { + // Transport connected, now hook it up to nss + // TODO(port): specify rx and tx buffer sizes separately + nss_fd_ = memio_CreateIOLayer(kRecvBufferSize); + if (nss_fd_ == NULL) { + return ERR_OUT_OF_MEMORY; // TODO(port): map NSPR error code. + } + + // Grab pointer to buffers + nss_bufs_ = memio_GetSecret(nss_fd_); + + /* Create SSL state machine */ + /* Push SSL onto our fake I/O socket */ + nss_fd_ = SSL_ImportFD(NULL, nss_fd_); + if (nss_fd_ == NULL) { + LogFailedNSSFunction(net_log_, "SSL_ImportFD", ""); + return ERR_OUT_OF_MEMORY; // TODO(port): map NSPR/NSS error code. + } + // TODO(port): set more ssl options! Check errors! + + int rv; + + rv = SSL_OptionSet(nss_fd_, SSL_SECURITY, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_SECURITY"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SSL2, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_SSL2"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SSL3, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_SSL3"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_TLS, ssl_config_.tls1_enabled); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_TLS"); + return ERR_UNEXPECTED; + } + + for (std::vector<uint16>::const_iterator it = + ssl_config_.disabled_cipher_suites.begin(); + it != ssl_config_.disabled_cipher_suites.end(); ++it) { + // This will fail if the specified cipher is not implemented by NSS, but + // the failure is harmless. + SSL_CipherPrefSet(nss_fd_, *it, PR_FALSE); + } + + // Server socket doesn't need session tickets. + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SESSION_TICKETS, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction( + net_log_, "SSL_OptionSet", "SSL_ENABLE_SESSION_TICKETS"); + } + + // Doing this will force PR_Accept perform handshake as server. + rv = SSL_OptionSet(nss_fd_, SSL_HANDSHAKE_AS_CLIENT, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_HANDSHAKE_AS_CLIENT"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_HANDSHAKE_AS_SERVER, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_HANDSHAKE_AS_SERVER"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_REQUEST_CERTIFICATE, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_REQUEST_CERTIFICATE"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_REQUIRE_CERTIFICATE, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_REQUIRE_CERTIFICATE"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_NO_CACHE, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_NO_CACHE"); + return ERR_UNEXPECTED; + } + + rv = SSL_ConfigServerSessionIDCache(1024, 5, 5, NULL); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_ConfigureServerSessionIDCache", ""); + return ERR_UNEXPECTED; + } + + rv = SSL_AuthCertificateHook(nss_fd_, OwnAuthCertHandler, this); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_AuthCertificateHook", ""); + return ERR_UNEXPECTED; + } + + rv = SSL_HandshakeCallback(nss_fd_, HandshakeCallback, this); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_HandshakeCallback", ""); + return ERR_UNEXPECTED; + } + + // Get a certificate of CERTCertificate structure. + std::string der_string; + if (!cert_->GetDEREncoded(&der_string)) + return ERR_UNEXPECTED; + + SECItem der_cert; + der_cert.data = reinterpret_cast<unsigned char*>(const_cast<char*>( + der_string.data())); + der_cert.len = der_string.length(); + der_cert.type = siDERCertBuffer; + + // Parse into a CERTCertificate structure. + CERTCertificate* cert = CERT_NewTempCertificate( + CERT_GetDefaultCertDB(), &der_cert, NULL, PR_FALSE, PR_TRUE); + + // Get a key of SECKEYPrivateKey* structure. + std::vector<uint8> key_vector; + if (!key_->ExportPrivateKey(&key_vector)) { + CERT_DestroyCertificate(cert); + return ERR_UNEXPECTED; + } + + SECKEYPrivateKeyStr* private_key = NULL; + PK11SlotInfo *slot = base::GetDefaultNSSKeySlot(); + if (!slot) { + CERT_DestroyCertificate(cert); + return ERR_UNEXPECTED; + } + + SECItem der_private_key_info; + der_private_key_info.data = + const_cast<unsigned char*>(&key_vector.front()); + der_private_key_info.len = key_vector.size(); + rv = PK11_ImportDERPrivateKeyInfoAndReturnKey( + slot, &der_private_key_info, NULL, NULL, PR_FALSE, PR_FALSE, + KU_DIGITAL_SIGNATURE, &private_key, NULL); + PK11_FreeSlot(slot); + if (rv != SECSuccess) { + CERT_DestroyCertificate(cert); + return ERR_UNEXPECTED; + } + + // Assign server certificate and private key. + SSLKEAType cert_kea = NSS_FindCertKEAType(cert); + rv = SSL_ConfigSecureServer(nss_fd_, cert, private_key, cert_kea); + CERT_DestroyCertificate(cert); + SECKEY_DestroyPrivateKey(private_key); + + if (rv != SECSuccess) { + PRErrorCode prerr = PR_GetError(); + LOG(ERROR) << "Failed to config SSL server: " << prerr; + LogFailedNSSFunction(net_log_, "SSL_ConfigureSecureServer", ""); + return ERR_UNEXPECTED; + } + + // Tell SSL we're a server; needed if not letting NSPR do socket I/O + rv = SSL_ResetHandshake(nss_fd_, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_ResetHandshake", ""); + return ERR_UNEXPECTED; + } + + return OK; +} + +// Return 0 for EOF, +// > 0 for bytes transferred immediately, +// < 0 for error (or the non-error ERR_IO_PENDING). +int SSLServerSocketNSS::BufferSend(void) { + if (transport_send_busy_) + return ERR_IO_PENDING; + + const char* buf1; + const char* buf2; + unsigned int len1, len2; + memio_GetWriteParams(nss_bufs_, &buf1, &len1, &buf2, &len2); + const unsigned int len = len1 + len2; + + int rv = 0; + if (len) { + scoped_refptr<IOBuffer> send_buffer(new IOBuffer(len)); + memcpy(send_buffer->data(), buf1, len1); + memcpy(send_buffer->data() + len1, buf2, len2); + rv = transport_socket_->Write(send_buffer, len, + &buffer_send_callback_); + if (rv == ERR_IO_PENDING) { + transport_send_busy_ = true; + } else { + memio_PutWriteResult(nss_bufs_, MapErrorToNSS(rv)); + } + } + + return rv; +} + +void SSLServerSocketNSS::BufferSendComplete(int result) { + memio_PutWriteResult(nss_bufs_, MapErrorToNSS(result)); + transport_send_busy_ = false; + OnSendComplete(result); +} + +int SSLServerSocketNSS::BufferRecv(void) { + if (transport_recv_busy_) return ERR_IO_PENDING; + + char *buf; + int nb = memio_GetReadParams(nss_bufs_, &buf); + int rv; + if (!nb) { + // buffer too full to read into, so no I/O possible at moment + rv = ERR_IO_PENDING; + } else { + recv_buffer_ = new IOBuffer(nb); + rv = transport_socket_->Read(recv_buffer_, nb, &buffer_recv_callback_); + if (rv == ERR_IO_PENDING) { + transport_recv_busy_ = true; + } else { + if (rv > 0) + memcpy(buf, recv_buffer_->data(), rv); + memio_PutReadResult(nss_bufs_, MapErrorToNSS(rv)); + recv_buffer_ = NULL; + } + } + return rv; +} + +void SSLServerSocketNSS::BufferRecvComplete(int result) { + if (result > 0) { + char *buf; + memio_GetReadParams(nss_bufs_, &buf); + memcpy(buf, recv_buffer_->data(), result); + } + recv_buffer_ = NULL; + memio_PutReadResult(nss_bufs_, MapErrorToNSS(result)); + transport_recv_busy_ = false; + OnRecvComplete(result); +} + +void SSLServerSocketNSS::OnSendComplete(int result) { + if (next_handshake_state_ == STATE_HANDSHAKE) { + // In handshake phase. + OnHandshakeIOComplete(result); + return; + } + + if (!user_write_buf_ || !completed_handshake_) + return; + + int rv = DoWriteLoop(result); + if (rv != ERR_IO_PENDING) + DoWriteCallback(rv); +} + +void SSLServerSocketNSS::OnRecvComplete(int result) { + if (next_handshake_state_ == STATE_HANDSHAKE) { + // In handshake phase. + OnHandshakeIOComplete(result); + return; + } + + // Network layer received some data, check if client requested to read + // decrypted data. + if (!user_read_buf_ || !completed_handshake_) + return; + + int rv = DoReadLoop(result); + if (rv != ERR_IO_PENDING) + DoReadCallback(rv); +} + +void SSLServerSocketNSS::OnHandshakeIOComplete(int result) { + int rv = DoHandshakeLoop(result); + if (rv != ERR_IO_PENDING) { + net_log_.EndEvent(net::NetLog::TYPE_SSL_ACCEPT, NULL); + if (user_accept_callback_) + DoAcceptCallback(rv); + } +} + +void SSLServerSocketNSS::DoAcceptCallback(int rv) { + DCHECK_NE(rv, ERR_IO_PENDING); + + CompletionCallback* c = user_accept_callback_; + user_accept_callback_ = NULL; + c->Run(rv > OK ? OK : rv); +} + +void SSLServerSocketNSS::DoReadCallback(int rv) { + DCHECK(rv != ERR_IO_PENDING); + DCHECK(user_read_callback_); + + // Since Run may result in Read being called, clear |user_read_callback_| + // up front. + CompletionCallback* c = user_read_callback_; + user_read_callback_ = NULL; + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c->Run(rv); +} + +void SSLServerSocketNSS::DoWriteCallback(int rv) { + DCHECK(rv != ERR_IO_PENDING); + DCHECK(user_write_callback_); + + // Since Run may result in Write being called, clear |user_write_callback_| + // up front. + CompletionCallback* c = user_write_callback_; + user_write_callback_ = NULL; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c->Run(rv); +} + +// Do network I/O between the given buffer and the given socket. +// Return true if some I/O performed, false otherwise (error or ERR_IO_PENDING) +bool SSLServerSocketNSS::DoTransportIO() { + bool network_moved = false; + if (nss_bufs_ != NULL) { + int nsent = BufferSend(); + int nreceived = BufferRecv(); + network_moved = (nsent > 0 || nreceived >= 0); + } + return network_moved; +} + +int SSLServerSocketNSS::DoPayloadRead() { + DCHECK(user_read_buf_); + DCHECK_GT(user_read_buf_len_, 0); + int rv = PR_Read(nss_fd_, user_read_buf_->data(), user_read_buf_len_); + if (rv >= 0) + return rv; + PRErrorCode prerr = PR_GetError(); + if (prerr == PR_WOULD_BLOCK_ERROR) { + return ERR_IO_PENDING; + } + rv = MapNSSError(prerr); + net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR, + make_scoped_refptr(new SSLErrorParams(rv, prerr))); + return rv; +} + +int SSLServerSocketNSS::DoPayloadWrite() { + DCHECK(user_write_buf_); + int rv = PR_Write(nss_fd_, user_write_buf_->data(), user_write_buf_len_); + if (rv >= 0) + return rv; + PRErrorCode prerr = PR_GetError(); + if (prerr == PR_WOULD_BLOCK_ERROR) { + return ERR_IO_PENDING; + } + rv = MapNSSError(prerr); + net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR, + make_scoped_refptr(new SSLErrorParams(rv, prerr))); + return rv; +} + +int SSLServerSocketNSS::DoHandshakeLoop(int last_io_result) { + bool network_moved; + int rv = last_io_result; + do { + // Default to STATE_NONE for next state. + // (This is a quirk carried over from the windows + // implementation. It makes reading the logs a bit harder.) + // State handlers can and often do call GotoState just + // to stay in the current state. + State state = next_handshake_state_; + GotoState(STATE_NONE); + switch (state) { + case STATE_NONE: + // we're just pumping data between the buffer and the network + break; + case STATE_HANDSHAKE: + rv = DoHandshake(); + break; + default: + rv = ERR_UNEXPECTED; + LOG(DFATAL) << "unexpected state " << state; + break; + } + + // Do the actual network I/O + network_moved = DoTransportIO(); + } while ((rv != ERR_IO_PENDING || network_moved) && + next_handshake_state_ != STATE_NONE); + return rv; +} + +int SSLServerSocketNSS::DoReadLoop(int result) { + DCHECK(completed_handshake_); + DCHECK(next_handshake_state_ == STATE_NONE); + + if (result < 0) + return result; + + if (!nss_bufs_) { + LOG(DFATAL) << "!nss_bufs_"; + int rv = ERR_UNEXPECTED; + net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR, + make_scoped_refptr(new SSLErrorParams(rv, 0))); + return rv; + } + + bool network_moved; + int rv; + do { + rv = DoPayloadRead(); + network_moved = DoTransportIO(); + } while (rv == ERR_IO_PENDING && network_moved); + return rv; +} + +int SSLServerSocketNSS::DoWriteLoop(int result) { + DCHECK(completed_handshake_); + DCHECK(next_handshake_state_ == STATE_NONE); + + if (result < 0) + return result; + + if (!nss_bufs_) { + LOG(DFATAL) << "!nss_bufs_"; + int rv = ERR_UNEXPECTED; + net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR, + make_scoped_refptr(new SSLErrorParams(rv, 0))); + return rv; + } + + bool network_moved; + int rv; + do { + rv = DoPayloadWrite(); + network_moved = DoTransportIO(); + } while (rv == ERR_IO_PENDING && network_moved); + return rv; +} + +int SSLServerSocketNSS::DoHandshake() { + int net_error = net::OK; + SECStatus rv = SSL_ForceHandshake(nss_fd_); + + if (rv == SECSuccess) { + completed_handshake_ = true; + } else { + PRErrorCode prerr = PR_GetError(); + net_error = MapNSSHandshakeError(prerr); + + // If not done, stay in this state + if (net_error == ERR_IO_PENDING) { + GotoState(STATE_HANDSHAKE); + } else { + LOG(ERROR) << "handshake failed; NSS error code " << prerr + << ", net_error " << net_error; + net_log_.AddEvent( + NetLog::TYPE_SSL_HANDSHAKE_ERROR, + make_scoped_refptr(new SSLErrorParams(net_error, prerr))); + } + } + return net_error; +} + +} // namespace net diff --git a/net/socket/ssl_server_socket_nss.h b/net/socket/ssl_server_socket_nss.h new file mode 100644 index 0000000..3883c9b --- /dev/null +++ b/net/socket/ssl_server_socket_nss.h @@ -0,0 +1,133 @@ +// Copyright (c) 2006-2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_ +#define NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_ +#pragma once + +#include <certt.h> +#include <keyt.h> +#include <nspr.h> +#include <nss.h> + +#include "base/scoped_ptr.h" +#include "net/base/completion_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/net_log.h" +#include "net/base/nss_memio.h" +#include "net/base/ssl_config_service.h" +#include "net/socket/ssl_server_socket.h" + +namespace net { + +class SSLServerSocketNSS : public SSLServerSocket { + public: + // This object takes ownership of the following parameters: + // |socket| - A socket that is already connected. + // |cert| - The certificate to be used by the server. + // + // The following parameters are copied in the constructor. + // |ssl_config| - Options for SSL socket. + // |key| - The private key used by the server. + SSLServerSocketNSS(Socket* transport_socket, + scoped_refptr<X509Certificate> cert, + base::RSAPrivateKey* key, + const SSLConfig& ssl_config); + virtual ~SSLServerSocketNSS(); + + // SSLServerSocket implementation. + virtual int Accept(CompletionCallback* callback); + virtual int Read(IOBuffer* buf, int buf_len, + CompletionCallback* callback); + virtual int Write(IOBuffer* buf, int buf_len, + CompletionCallback* callback); + virtual bool SetReceiveBufferSize(int32 size) { return false; } + virtual bool SetSendBufferSize(int32 size) { return false; } + + private: + virtual int Init(); + + int InitializeSSLOptions(); + + void OnSendComplete(int result); + void OnRecvComplete(int result); + void OnHandshakeIOComplete(int result); + + int BufferSend(); + void BufferSendComplete(int result); + int BufferRecv(); + void BufferRecvComplete(int result); + bool DoTransportIO(); + int DoPayloadWrite(); + int DoPayloadRead(); + + int DoHandshakeLoop(int last_io_result); + int DoReadLoop(int result); + int DoWriteLoop(int result); + int DoHandshake(); + void DoAcceptCallback(int result); + void DoReadCallback(int result); + void DoWriteCallback(int result); + + static SECStatus OwnAuthCertHandler(void* arg, + PRFileDesc* socket, + PRBool checksig, + PRBool is_server); + static void HandshakeCallback(PRFileDesc* socket, void* arg); + + // Members used to send and receive buffer. + CompletionCallbackImpl<SSLServerSocketNSS> buffer_send_callback_; + CompletionCallbackImpl<SSLServerSocketNSS> buffer_recv_callback_; + bool transport_send_busy_; + bool transport_recv_busy_; + + scoped_refptr<IOBuffer> recv_buffer_; + + BoundNetLog net_log_; + + CompletionCallback* user_accept_callback_; + CompletionCallback* user_read_callback_; + CompletionCallback* user_write_callback_; + + // Used by Read function. + scoped_refptr<IOBuffer> user_read_buf_; + int user_read_buf_len_; + + // Used by Write function. + scoped_refptr<IOBuffer> user_write_buf_; + int user_write_buf_len_; + + // The NSS SSL state machine + PRFileDesc* nss_fd_; + + // Buffers for the network end of the SSL state machine + memio_Private* nss_bufs_; + + // Socket for sending and receiving data. + scoped_ptr<Socket> transport_socket_; + + // Options for the SSL socket. + // TODO(hclam): This memeber is currently not used. Should make use of this + // member to configure the socket. + SSLConfig ssl_config_; + + // Certificate for the server. + scoped_refptr<X509Certificate> cert_; + + // Private key used by the server. + scoped_ptr<base::RSAPrivateKey> key_; + + enum State { + STATE_NONE, + STATE_HANDSHAKE, + }; + State next_handshake_state_; + bool completed_handshake_; + + DISALLOW_COPY_AND_ASSIGN(SSLServerSocketNSS); +}; + +} // namespace net + +#endif // NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_ diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc new file mode 100644 index 0000000..781a3f4 --- /dev/null +++ b/net/socket/ssl_server_socket_unittest.cc @@ -0,0 +1,369 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This test suite uses SSLClientSocket to test the implementation of +// SSLServerSocket. In order to establish connections between the sockets +// we need two additional classes: +// 1. FakeSocket +// Connects SSL socket to FakeDataChannel. This class is just a stub. +// +// 2. FakeDataChannel +// Implements the actual exchange of data between two FakeSockets. +// +// Implementations of these two classes are included in this file. + +#include "net/socket/ssl_server_socket.h" + +#include <queue> + +#include "base/crypto/rsa_private_key.h" +#include "base/file_path.h" +#include "base/file_util.h" +#include "base/nss_util.h" +#include "base/path_service.h" +#include "net/base/address_list.h" +#include "net/base/cert_verifier.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/ssl_config_service.h" +#include "net/base/x509_certificate.h" +#include "net/socket/client_socket.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +namespace net { + +namespace { + +class FakeDataChannel { + public: + FakeDataChannel() : read_callback_(NULL), read_buf_len_(0) { + } + + virtual int Read(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + if (data_.empty()) { + read_callback_ = callback; + read_buf_ = buf; + read_buf_len_ = buf_len; + return net::ERR_IO_PENDING; + } + return PropogateData(buf, buf_len); + } + + virtual int Write(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + data_.push(new net::DrainableIOBuffer(buf, buf_len)); + DoReadCallback(); + return buf_len; + } + + private: + void DoReadCallback() { + if (!read_callback_) + return; + + int copied = PropogateData(read_buf_, read_buf_len_); + net::CompletionCallback* callback = read_callback_; + read_callback_ = NULL; + read_buf_ = NULL; + read_buf_len_ = 0; + callback->Run(copied); + } + + int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) { + scoped_refptr<net::DrainableIOBuffer> buf = data_.front(); + int copied = std::min(buf->BytesRemaining(), read_buf_len); + memcpy(read_buf->data(), buf->data(), copied); + buf->DidConsume(copied); + + if (!buf->BytesRemaining()) + data_.pop(); + return copied; + } + + net::CompletionCallback* read_callback_; + scoped_refptr<net::IOBuffer> read_buf_; + int read_buf_len_; + + std::queue<scoped_refptr<net::DrainableIOBuffer> > data_; + + DISALLOW_COPY_AND_ASSIGN(FakeDataChannel); +}; + +class FakeSocket : public ClientSocket { + public: + FakeSocket(FakeDataChannel* incoming_channel, + FakeDataChannel* outgoing_channel) + : incoming_(incoming_channel), + outgoing_(outgoing_channel) { + } + + virtual ~FakeSocket() { + + } + + virtual int Read(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + return incoming_->Read(buf, buf_len, callback); + } + + virtual int Write(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + return outgoing_->Write(buf, buf_len, callback); + } + + virtual bool SetReceiveBufferSize(int32 size) { + return true; + } + + virtual bool SetSendBufferSize(int32 size) { + return true; + } + + virtual int Connect(CompletionCallback* callback) { + return net::OK; + } + + virtual void Disconnect() {} + + virtual bool IsConnected() const { + return true; + } + + virtual bool IsConnectedAndIdle() const { + return true; + } + + virtual int GetPeerAddress(AddressList* address) const { + net::IPAddressNumber ip_address(4); + *address = net::AddressList(ip_address, 0, false); + return net::OK; + } + + virtual const BoundNetLog& NetLog() const { + return net_log_; + } + + virtual void SetSubresourceSpeculation() {} + virtual void SetOmniboxSpeculation() {} + + virtual bool WasEverUsed() const { + return true; + } + + virtual bool UsingTCPFastOpen() const { + return false; + } + + private: + net::BoundNetLog net_log_; + FakeDataChannel* incoming_; + FakeDataChannel* outgoing_; + + DISALLOW_COPY_AND_ASSIGN(FakeSocket); +}; + +} // namespace + +// Verify the correctness of the test helper classes first. +TEST(FakeSocketTest, DataTransfer) { + // Establish channels between two sockets. + FakeDataChannel channel_1; + FakeDataChannel channel_2; + FakeSocket client(&channel_1, &channel_2); + FakeSocket server(&channel_2, &channel_1); + + const char kTestData[] = "testing123"; + const int kTestDataSize = strlen(kTestData); + const int kReadBufSize = 1024; + scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData); + scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize); + + // Write then read. + EXPECT_EQ(kTestDataSize, server.Write(write_buf, kTestDataSize, NULL)); + EXPECT_EQ(kTestDataSize, client.Read(read_buf, kReadBufSize, NULL)); + EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize)); + + // Read then write. + TestCompletionCallback callback; + EXPECT_EQ(net::ERR_IO_PENDING, + server.Read(read_buf, kReadBufSize, &callback)); + EXPECT_EQ(kTestDataSize, client.Write(write_buf, kTestDataSize, NULL)); + EXPECT_EQ(kTestDataSize, callback.WaitForResult()); + EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize)); +} + +class SSLServerSocketTest : public PlatformTest { + public: + SSLServerSocketTest() + : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) { + } + + protected: + void Initialize() { + FakeSocket* fake_client_socket = new FakeSocket(&channel_1_, &channel_2_); + FakeSocket* fake_server_socket = new FakeSocket(&channel_2_, &channel_1_); + + FilePath certs_dir; + PathService::Get(base::DIR_SOURCE_ROOT, &certs_dir); + certs_dir = certs_dir.AppendASCII("net"); + certs_dir = certs_dir.AppendASCII("data"); + certs_dir = certs_dir.AppendASCII("ssl"); + certs_dir = certs_dir.AppendASCII("certificates"); + + FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); + std::string cert_der; + ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der)); + + scoped_refptr<net::X509Certificate> cert = + X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); + + FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); + std::string key_string; + ASSERT_TRUE(file_util::ReadFileToString(key_path, &key_string)); + std::vector<uint8> key_vector( + reinterpret_cast<const uint8*>(key_string.data()), + reinterpret_cast<const uint8*>(key_string.data() + + key_string.length())); + + scoped_ptr<base::RSAPrivateKey> private_key( + base::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); + + net::SSLConfig ssl_config; + ssl_config.false_start_enabled = false; + ssl_config.snap_start_enabled = false; + ssl_config.ssl3_enabled = true; + ssl_config.tls1_enabled = true; + ssl_config.session_resume_disabled = true; + + // Certificate provided by the host doesn't need authority. + net::SSLConfig::CertAndStatus cert_and_status; + cert_and_status.cert_status = net::ERR_CERT_AUTHORITY_INVALID; + cert_and_status.cert = cert; + ssl_config.allowed_bad_certs.push_back(cert_and_status); + + net::HostPortPair host_and_pair("unittest", 0); + client_socket_.reset( + socket_factory_->CreateSSLClientSocket( + fake_client_socket, host_and_pair, ssl_config, NULL, + &cert_verifier_)); + server_socket_.reset(net::CreateSSLServerSocket(fake_server_socket, + cert, private_key.get(), + net::SSLConfig())); + } + + FakeDataChannel channel_1_; + FakeDataChannel channel_2_; + scoped_ptr<net::SSLClientSocket> client_socket_; + scoped_ptr<net::SSLServerSocket> server_socket_; + net::ClientSocketFactory* socket_factory_; + net::CertVerifier cert_verifier_; +}; + +// SSLServerSocket is only implemented using NSS. +#if defined(USE_NSS) || defined(OS_WIN) || defined(OS_MACOSX) + +// This test only executes creation of client and server sockets. This is to +// test that creation of sockets doesn't crash and have minimal code to run +// under valgrind in order to help debugging memory problems. +TEST_F(SSLServerSocketTest, Initialize) { + Initialize(); +} + +// This test executes Connect() of SSLClientSocket and Accept() of +// SSLServerSocket to make sure handshaking between the two sockets are +// completed successfully. +TEST_F(SSLServerSocketTest, Handshake) { + Initialize(); + + if (!base::CheckNSSVersion("3.12.8")) + return; + + TestCompletionCallback connect_callback; + TestCompletionCallback accept_callback; + + int server_ret = server_socket_->Accept(&accept_callback); + EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); + + int client_ret = client_socket_->Connect(&connect_callback); + EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); + + if (client_ret == net::ERR_IO_PENDING) { + EXPECT_EQ(net::OK, connect_callback.WaitForResult()); + } + if (server_ret == net::ERR_IO_PENDING) { + EXPECT_EQ(net::OK, accept_callback.WaitForResult()); + } +} + +TEST_F(SSLServerSocketTest, DataTransfer) { + Initialize(); + + if (!base::CheckNSSVersion("3.12.8")) + return; + + TestCompletionCallback connect_callback; + TestCompletionCallback accept_callback; + + // Establish connection. + int client_ret = client_socket_->Connect(&connect_callback); + EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); + + int server_ret = server_socket_->Accept(&accept_callback); + EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); + + if (client_ret == net::ERR_IO_PENDING) { + EXPECT_EQ(net::OK, connect_callback.WaitForResult()); + } + if (server_ret == net::ERR_IO_PENDING) { + EXPECT_EQ(net::OK, accept_callback.WaitForResult()); + } + + const int kReadBufSize = 1024; + scoped_refptr<net::StringIOBuffer> write_buf = + new net::StringIOBuffer("testing123"); + scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize); + + // Write then read. + TestCompletionCallback write_callback; + TestCompletionCallback read_callback; + server_ret = server_socket_->Write(write_buf, write_buf->size(), + &write_callback); + EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); + client_ret = client_socket_->Read(read_buf, kReadBufSize, &read_callback); + EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); + + if (server_ret == net::ERR_IO_PENDING) { + EXPECT_GT(write_callback.WaitForResult(), 0); + } + if (client_ret == net::ERR_IO_PENDING) { + EXPECT_GT(read_callback.WaitForResult(), 0); + } + EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); + + // Read then write. + write_buf = new net::StringIOBuffer("hello123"); + server_ret = server_socket_->Read(read_buf, kReadBufSize, &read_callback); + EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); + client_ret = client_socket_->Write(write_buf, write_buf->size(), + &write_callback); + EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); + + if (server_ret == net::ERR_IO_PENDING) { + EXPECT_GT(read_callback.WaitForResult(), 0); + } + if (client_ret == net::ERR_IO_PENDING) { + EXPECT_GT(write_callback.WaitForResult(), 0); + } + EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); +} +#endif + +} // namespace net |