// Copyright 2014 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "extensions/browser/api/socket/tls_socket.h" #include "base/callback_helpers.h" #include "base/logging.h" #include "extensions/browser/api/api_resource.h" #include "net/base/address_list.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/rand_callback.h" #include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" #include "net/socket/ssl_client_socket.h" #include "net/socket/tcp_client_socket.h" #include "url/url_canon.h" namespace { // Returns the SSL protocol version (as a uint16) represented by a string. // Returns 0 if the string is invalid. uint16 SSLProtocolVersionFromString(const std::string& version_str) { uint16 version = 0; // Invalid. if (version_str == "tls1") { version = net::SSL_PROTOCOL_VERSION_TLS1; } else if (version_str == "tls1.1") { version = net::SSL_PROTOCOL_VERSION_TLS1_1; } else if (version_str == "tls1.2") { version = net::SSL_PROTOCOL_VERSION_TLS1_2; } return version; } void TlsConnectDone(scoped_ptr ssl_socket, const std::string& extension_id, const extensions::TLSSocket::SecureCallback& callback, int result) { DVLOG(1) << "Got back result " << result << " " << net::ErrorToString(result); // No matter how the TLS connection attempt went, the underlying socket's // no longer bound to the original TCPSocket. It belongs to |ssl_socket|, // which is promoted here to a new API-accessible socket (via a TLSSocket // wrapper), or deleted. if (result != net::OK) { callback.Run(scoped_ptr(), result); return; }; // Wrap the StreamSocket in a TLSSocket, which matches the extension socket // API. Set the handle of the socket to the new value, so that it can be // used for read/write/close/etc. scoped_ptr wrapper( new extensions::TLSSocket(ssl_socket.Pass(), extension_id)); // Caller will end up deleting the prior TCPSocket, once it calls // SetSocket(..,wrapper). callback.Run(wrapper.Pass(), result); } } // namespace namespace extensions { const char kTLSSocketTypeInvalidError[] = "Cannot listen on a socket that is already connected."; TLSSocket::TLSSocket(scoped_ptr tls_socket, const std::string& owner_extension_id) : ResumableTCPSocket(owner_extension_id), tls_socket_(tls_socket.Pass()) { } TLSSocket::~TLSSocket() { Disconnect(); } void TLSSocket::Connect(const std::string& address, uint16 port, const CompletionCallback& callback) { callback.Run(net::ERR_CONNECTION_FAILED); } void TLSSocket::Disconnect() { if (tls_socket_) { tls_socket_->Disconnect(); tls_socket_.reset(); } } void TLSSocket::Read(int count, const ReadCompletionCallback& callback) { DCHECK(!callback.is_null()); if (!read_callback_.is_null()) { callback.Run(net::ERR_IO_PENDING, NULL); return; } if (count <= 0) { callback.Run(net::ERR_INVALID_ARGUMENT, NULL); return; } if (!tls_socket_.get() || !IsConnected()) { callback.Run(net::ERR_SOCKET_NOT_CONNECTED, NULL); return; } read_callback_ = callback; scoped_refptr io_buffer(new net::IOBuffer(count)); // |tls_socket_| is owned by this class and the callback won't be run once // |tls_socket_| is gone (as in an a call to Disconnect()). Therefore, it is // safe to use base::Unretained() here. int result = tls_socket_->Read( io_buffer.get(), count, base::Bind( &TLSSocket::OnReadComplete, base::Unretained(this), io_buffer)); if (result != net::ERR_IO_PENDING) { OnReadComplete(io_buffer, result); } } void TLSSocket::OnReadComplete(const scoped_refptr& io_buffer, int result) { DCHECK(!read_callback_.is_null()); base::ResetAndReturn(&read_callback_).Run(result, io_buffer); } int TLSSocket::WriteImpl(net::IOBuffer* io_buffer, int io_buffer_size, const net::CompletionCallback& callback) { if (!IsConnected()) { return net::ERR_SOCKET_NOT_CONNECTED; } return tls_socket_->Write(io_buffer, io_buffer_size, callback); } bool TLSSocket::SetKeepAlive(bool enable, int delay) { return false; } bool TLSSocket::SetNoDelay(bool no_delay) { return false; } int TLSSocket::Listen(const std::string& address, uint16 port, int backlog, std::string* error_msg) { *error_msg = kTLSSocketTypeInvalidError; return net::ERR_NOT_IMPLEMENTED; } void TLSSocket::Accept(const AcceptCompletionCallback& callback) { callback.Run(net::ERR_FAILED, NULL); } bool TLSSocket::IsConnected() { return tls_socket_.get() && tls_socket_->IsConnected(); } bool TLSSocket::GetPeerAddress(net::IPEndPoint* address) { return IsConnected() && tls_socket_->GetPeerAddress(address); } bool TLSSocket::GetLocalAddress(net::IPEndPoint* address) { return IsConnected() && tls_socket_->GetLocalAddress(address); } Socket::SocketType TLSSocket::GetSocketType() const { return Socket::TYPE_TLS; } // static void TLSSocket::UpgradeSocketToTLS( Socket* socket, scoped_refptr ssl_config_service, net::CertVerifier* cert_verifier, net::TransportSecurityState* transport_security_state, const std::string& extension_id, core_api::socket::SecureOptions* options, const TLSSocket::SecureCallback& callback) { DCHECK_CURRENTLY_ON(content::BrowserThread::IO); TCPSocket* tcp_socket = static_cast(socket); scoped_ptr null_sock; if (!tcp_socket || tcp_socket->GetSocketType() != Socket::TYPE_TCP || !tcp_socket->ClientStream() || !tcp_socket->IsConnected() || tcp_socket->HasPendingRead()) { DVLOG(1) << "Failing before trying. socket is " << tcp_socket; if (tcp_socket) { DVLOG(1) << "type: " << tcp_socket->GetSocketType() << ", ClientStream is " << tcp_socket->ClientStream() << ", IsConnected: " << tcp_socket->IsConnected() << ", HasPendingRead: " << tcp_socket->HasPendingRead(); } TlsConnectDone( null_sock.Pass(), extension_id, callback, net::ERR_INVALID_ARGUMENT); return; } net::IPEndPoint dest_host_port_pair; if (!tcp_socket->GetPeerAddress(&dest_host_port_pair)) { DVLOG(1) << "Could not get peer address."; TlsConnectDone( null_sock.Pass(), extension_id, callback, net::ERR_INVALID_ARGUMENT); return; } // Convert any U-LABELs to A-LABELs. url::CanonHostInfo host_info; std::string canon_host = net::CanonicalizeHost(tcp_socket->hostname(), &host_info); // Canonicalization shouldn't fail: the socket is already connected with a // host, using this hostname. if (host_info.family == url::CanonHostInfo::BROKEN) { DVLOG(1) << "Could not canonicalize hostname"; TlsConnectDone( null_sock.Pass(), extension_id, callback, net::ERR_INVALID_ARGUMENT); return; } net::HostPortPair host_and_port(canon_host, dest_host_port_pair.port()); scoped_ptr socket_handle( new net::ClientSocketHandle()); // Set the socket handle to the socket's client stream (that should be the // only one active here). Then have the old socket release ownership on // that client stream. socket_handle->SetSocket( scoped_ptr(tcp_socket->ClientStream())); tcp_socket->Release(); DCHECK(transport_security_state); net::SSLClientSocketContext context; context.cert_verifier = cert_verifier; context.transport_security_state = transport_security_state; // Fill in the SSL socket params. net::SSLConfig ssl_config; ssl_config_service->GetSSLConfig(&ssl_config); if (options && options->tls_version.get()) { uint16 version_min = 0, version_max = 0; core_api::socket::TLSVersionConstraints* versions = options->tls_version.get(); if (versions->min.get()) { version_min = SSLProtocolVersionFromString(*versions->min.get()); } if (versions->max.get()) { version_max = SSLProtocolVersionFromString(*versions->max.get()); } if (version_min) { ssl_config.version_min = version_min; } if (version_max) { ssl_config.version_max = version_max; } } net::ClientSocketFactory* socket_factory = net::ClientSocketFactory::GetDefaultFactory(); // Create the socket. scoped_ptr ssl_socket( socket_factory->CreateSSLClientSocket( socket_handle.Pass(), host_and_port, ssl_config, context)); DVLOG(1) << "Attempting to secure a connection to " << tcp_socket->hostname() << ":" << dest_host_port_pair.port(); // We need the contents of |ssl_socket| in order to invoke its Connect() // method. It belongs to |ssl_socket|, and we own that until our internal // callback (|connect_cb|, below) is invoked. net::SSLClientSocket* saved_ssl_socket = ssl_socket.get(); // Try establish a TLS connection. Pass ownership of |ssl_socket| to // TlsConnectDone, which will pass it on to |callback|. |connect_cb| below // is only for UpgradeSocketToTLS use, and not be confused with the // argument |callback|, which gets invoked by TlsConnectDone() after // Connect() below returns. base::Callback connect_cb(base::Bind( &TlsConnectDone, base::Passed(&ssl_socket), extension_id, callback)); int status = saved_ssl_socket->Connect(connect_cb); saved_ssl_socket = NULL; // Connect completed synchronously, or failed. if (status != net::ERR_IO_PENDING) { // Note: this can't recurse -- if |socket| is already a connected // TLSSocket, it will return TYPE_TLS instead of TYPE_TCP, causing // UpgradeSocketToTLS() to fail with an error above. If // UpgradeSocketToTLS() is called on |socket| twice, the call to // Release() on |socket| above causes the additional call to // fail with an error above. if (status != net::OK) { DVLOG(1) << "Status is not OK or IO-pending: " << net::ErrorToString(status); } connect_cb.Run(status); } } } // namespace extensions