diff options
Diffstat (limited to 'remoting/jingle_glue/ssl_socket_adapter.cc')
-rw-r--r-- | remoting/jingle_glue/ssl_socket_adapter.cc | 467 |
1 files changed, 467 insertions, 0 deletions
diff --git a/remoting/jingle_glue/ssl_socket_adapter.cc b/remoting/jingle_glue/ssl_socket_adapter.cc new file mode 100644 index 0000000..08ba785 --- /dev/null +++ b/remoting/jingle_glue/ssl_socket_adapter.cc @@ -0,0 +1,467 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "remoting/jingle_glue/ssl_socket_adapter.h" + +#include "base/base64.h" +#include "base/compiler_specific.h" +#include "base/message_loop.h" +#include "jingle/glue/utils.h" +#include "net/base/address_list.h" +#include "net/base/cert_verifier.h" +#include "net/base/host_port_pair.h" +#include "net/base/net_errors.h" +#include "net/base/ssl_config_service.h" +#include "net/base/transport_security_state.h" +#include "net/socket/client_socket_factory.h" +#include "net/url_request/url_request_context.h" + +namespace remoting { + +SSLSocketAdapter* SSLSocketAdapter::Create(AsyncSocket* socket) { + return new SSLSocketAdapter(socket); +} + +SSLSocketAdapter::SSLSocketAdapter(AsyncSocket* socket) + : SSLAdapter(socket), + ignore_bad_cert_(false), + cert_verifier_(net::CertVerifier::CreateDefault()), + transport_security_state_(new net::TransportSecurityState()), + ssl_state_(SSLSTATE_NONE), + read_pending_(false), + write_pending_(false) { + transport_socket_ = new TransportSocket(socket, this); +} + +SSLSocketAdapter::~SSLSocketAdapter() { +} + +int SSLSocketAdapter::StartSSL(const char* hostname, bool restartable) { + DCHECK(!restartable); + hostname_ = hostname; + + if (socket_->GetState() != Socket::CS_CONNECTED) { + ssl_state_ = SSLSTATE_WAIT; + return 0; + } else { + return BeginSSL(); + } +} + +int SSLSocketAdapter::BeginSSL() { + if (!MessageLoop::current()) { + // Certificate verification is done via the Chrome message loop. + // Without this check, if we don't have a chrome message loop the + // SSL connection just hangs silently. + LOG(DFATAL) << "Chrome message loop (needed by SSL certificate " + << "verification) does not exist"; + return net::ERR_UNEXPECTED; + } + + // SSLConfigService is not thread-safe, and the default values for SSLConfig + // are correct for us, so we don't use the config service to initialize this + // object. + net::SSLConfig ssl_config; + net::SSLClientSocketContext context( + cert_verifier_.get(), NULL, transport_security_state_.get(), ""); + + transport_socket_->set_addr(talk_base::SocketAddress(hostname_, 0)); + ssl_socket_.reset( + net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket( + transport_socket_, net::HostPortPair(hostname_, 443), ssl_config, + context)); + + int result = ssl_socket_->Connect( + base::Bind(&SSLSocketAdapter::OnConnected, base::Unretained(this))); + + if (result == net::ERR_IO_PENDING || result == net::OK) { + return 0; + } else { + LOG(ERROR) << "Could not start SSL: " << net::ErrorToString(result); + return result; + } +} + +int SSLSocketAdapter::Send(const void* buf, size_t len) { + if (ssl_state_ == SSLSTATE_ERROR) { + SetError(EINVAL); + return -1; + } + + if (ssl_state_ == SSLSTATE_NONE) { + // Propagate the call to underlying socket if SSL is not connected + // yet (connection is not encrypted until StartSSL() is called). + return AsyncSocketAdapter::Send(buf, len); + } + + if (write_pending_) { + SetError(EWOULDBLOCK); + return -1; + } + + write_buffer_ = new net::DrainableIOBuffer(new net::IOBuffer(len), len); + memcpy(write_buffer_->data(), buf, len); + + DoWrite(); + + return len; +} + +int SSLSocketAdapter::Recv(void* buf, size_t len) { + switch (ssl_state_) { + case SSLSTATE_NONE: { + return AsyncSocketAdapter::Recv(buf, len); + } + + case SSLSTATE_WAIT: { + SetError(EWOULDBLOCK); + return -1; + } + + case SSLSTATE_CONNECTED: { + if (read_pending_) { + SetError(EWOULDBLOCK); + return -1; + } + + int bytes_read = 0; + + // Process any data we have left from the previous read. + if (read_buffer_) { + int size = std::min(read_buffer_->RemainingCapacity(), + static_cast<int>(len)); + memcpy(buf, read_buffer_->data(), size); + read_buffer_->set_offset(read_buffer_->offset() + size); + if (!read_buffer_->RemainingCapacity()) + read_buffer_ = NULL; + + if (size == static_cast<int>(len)) + return size; + + // If we didn't fill the caller's buffer then dispatch a new + // Read() in case there's more data ready. + buf = reinterpret_cast<char*>(buf) + size; + len -= size; + bytes_read = size; + DCHECK(!read_buffer_); + } + + // Dispatch a Read() request to the SSL layer. + read_buffer_ = new net::GrowableIOBuffer(); + read_buffer_->SetCapacity(len); + int result = ssl_socket_->Read( + read_buffer_, len, + base::Bind(&SSLSocketAdapter::OnRead, base::Unretained(this))); + if (result >= 0) + memcpy(buf, read_buffer_->data(), len); + + if (result == net::ERR_IO_PENDING) { + read_pending_ = true; + if (bytes_read) { + return bytes_read; + } else { + SetError(EWOULDBLOCK); + return -1; + } + } + + if (result < 0) { + SetError(EINVAL); + ssl_state_ = SSLSTATE_ERROR; + LOG(ERROR) << "Error reading from SSL socket " << result; + return -1; + } + read_buffer_ = NULL; + return result + bytes_read; + } + + case SSLSTATE_ERROR: { + SetError(EINVAL); + return -1; + } + } + + NOTREACHED(); + return -1; +} + +void SSLSocketAdapter::OnConnected(int result) { + if (result == net::OK) { + ssl_state_ = SSLSTATE_CONNECTED; + OnConnectEvent(this); + } else { + LOG(WARNING) << "OnConnected failed with error " << result; + } +} + +void SSLSocketAdapter::OnRead(int result) { + DCHECK(read_pending_); + read_pending_ = false; + if (result > 0) { + DCHECK_GE(read_buffer_->capacity(), result); + read_buffer_->SetCapacity(result); + } else { + if (result < 0) + ssl_state_ = SSLSTATE_ERROR; + } + AsyncSocketAdapter::OnReadEvent(this); +} + +void SSLSocketAdapter::OnWritten(int result) { + DCHECK(write_pending_); + write_pending_ = false; + if (result >= 0) { + write_buffer_->DidConsume(result); + if (!write_buffer_->BytesRemaining()) { + write_buffer_ = NULL; + } else { + DoWrite(); + } + } else { + ssl_state_ = SSLSTATE_ERROR; + } + AsyncSocketAdapter::OnWriteEvent(this); +} + +void SSLSocketAdapter::DoWrite() { + DCHECK_GT(write_buffer_->BytesRemaining(), 0); + DCHECK(!write_pending_); + + while (true) { + int result = ssl_socket_->Write( + write_buffer_, write_buffer_->BytesRemaining(), + base::Bind(&SSLSocketAdapter::OnWritten, base::Unretained(this))); + + if (result > 0) { + write_buffer_->DidConsume(result); + if (!write_buffer_->BytesRemaining()) { + write_buffer_ = NULL; + return; + } + continue; + } + + if (result == net::ERR_IO_PENDING) { + write_pending_ = true; + } else { + SetError(EINVAL); + ssl_state_ = SSLSTATE_ERROR; + } + return; + } +} + +void SSLSocketAdapter::OnConnectEvent(talk_base::AsyncSocket* socket) { + if (ssl_state_ != SSLSTATE_WAIT) { + AsyncSocketAdapter::OnConnectEvent(socket); + } else { + ssl_state_ = SSLSTATE_NONE; + int result = BeginSSL(); + if (0 != result) { + // TODO(zork): Handle this case gracefully. + LOG(WARNING) << "BeginSSL() failed with " << result; + } + } +} + +TransportSocket::TransportSocket(talk_base::AsyncSocket* socket, + SSLSocketAdapter *ssl_adapter) + : read_buffer_len_(0), + write_buffer_len_(0), + socket_(socket), + was_used_to_convey_data_(false) { + socket_->SignalReadEvent.connect(this, &TransportSocket::OnReadEvent); + socket_->SignalWriteEvent.connect(this, &TransportSocket::OnWriteEvent); +} + +TransportSocket::~TransportSocket() { +} + +int TransportSocket::Connect(const net::CompletionCallback& callback) { + // Connect is never called by SSLClientSocket, instead SSLSocketAdapter + // calls Connect() on socket_ directly. + NOTREACHED(); + return false; +} + +void TransportSocket::Disconnect() { + socket_->Close(); +} + +bool TransportSocket::IsConnected() const { + return (socket_->GetState() == talk_base::Socket::CS_CONNECTED); +} + +bool TransportSocket::IsConnectedAndIdle() const { + // Not implemented. + NOTREACHED(); + return false; +} + +int TransportSocket::GetPeerAddress(net::IPEndPoint* address) const { + talk_base::SocketAddress socket_address = socket_->GetRemoteAddress(); + if (jingle_glue::SocketAddressToIPEndPoint(socket_address, address)) { + return net::OK; + } else { + return net::ERR_FAILED; + } +} + +int TransportSocket::GetLocalAddress(net::IPEndPoint* address) const { + talk_base::SocketAddress socket_address = socket_->GetLocalAddress(); + if (jingle_glue::SocketAddressToIPEndPoint(socket_address, address)) { + return net::OK; + } else { + return net::ERR_FAILED; + } +} + +const net::BoundNetLog& TransportSocket::NetLog() const { + return net_log_; +} + +void TransportSocket::SetSubresourceSpeculation() { + NOTREACHED(); +} + +void TransportSocket::SetOmniboxSpeculation() { + NOTREACHED(); +} + +bool TransportSocket::WasEverUsed() const { + // We don't use this in ClientSocketPools, so this should never be used. + NOTREACHED(); + return was_used_to_convey_data_; +} + +bool TransportSocket::UsingTCPFastOpen() const { + return false; +} + +int64 TransportSocket::NumBytesRead() const { + NOTREACHED(); + return -1; +} + +base::TimeDelta TransportSocket::GetConnectTimeMicros() const { + NOTREACHED(); + return base::TimeDelta::FromMicroseconds(-1); +} + +bool TransportSocket::WasNpnNegotiated() const { + NOTREACHED(); + return false; +} + +net::NextProto TransportSocket::GetNegotiatedProtocol() const { + NOTREACHED(); + return net::kProtoUnknown; +} + +bool TransportSocket::GetSSLInfo(net::SSLInfo* ssl_info) { + NOTREACHED(); + return false; +} + +int TransportSocket::Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + DCHECK(buf); + DCHECK(read_callback_.is_null()); + DCHECK(!read_buffer_.get()); + int result = socket_->Recv(buf->data(), buf_len); + if (result < 0) { + result = net::MapSystemError(socket_->GetError()); + if (result == net::ERR_IO_PENDING) { + read_callback_ = callback; + read_buffer_ = buf; + read_buffer_len_ = buf_len; + } + } + if (result != net::ERR_IO_PENDING) + was_used_to_convey_data_ = true; + return result; +} + +int TransportSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + DCHECK(buf); + DCHECK(write_callback_.is_null()); + DCHECK(!write_buffer_.get()); + int result = socket_->Send(buf->data(), buf_len); + if (result < 0) { + result = net::MapSystemError(socket_->GetError()); + if (result == net::ERR_IO_PENDING) { + write_callback_ = callback; + write_buffer_ = buf; + write_buffer_len_ = buf_len; + } + } + if (result != net::ERR_IO_PENDING) + was_used_to_convey_data_ = true; + return result; +} + +bool TransportSocket::SetReceiveBufferSize(int32 size) { + // Not implemented. + return false; +} + +bool TransportSocket::SetSendBufferSize(int32 size) { + // Not implemented. + return false; +} + +void TransportSocket::OnReadEvent(talk_base::AsyncSocket* socket) { + if (!read_callback_.is_null()) { + DCHECK(read_buffer_.get()); + net::CompletionCallback callback = read_callback_; + scoped_refptr<net::IOBuffer> buffer = read_buffer_; + int buffer_len = read_buffer_len_; + + read_callback_.Reset(); + read_buffer_ = NULL; + read_buffer_len_ = 0; + + int result = socket_->Recv(buffer->data(), buffer_len); + if (result < 0) { + result = net::MapSystemError(socket_->GetError()); + if (result == net::ERR_IO_PENDING) { + read_callback_ = callback; + read_buffer_ = buffer; + read_buffer_len_ = buffer_len; + return; + } + } + was_used_to_convey_data_ = true; + callback.Run(result); + } +} + +void TransportSocket::OnWriteEvent(talk_base::AsyncSocket* socket) { + if (!write_callback_.is_null()) { + DCHECK(write_buffer_.get()); + net::CompletionCallback callback = write_callback_; + scoped_refptr<net::IOBuffer> buffer = write_buffer_; + int buffer_len = write_buffer_len_; + + write_callback_.Reset(); + write_buffer_ = NULL; + write_buffer_len_ = 0; + + int result = socket_->Send(buffer->data(), buffer_len); + if (result < 0) { + result = net::MapSystemError(socket_->GetError()); + if (result == net::ERR_IO_PENDING) { + write_callback_ = callback; + write_buffer_ = buffer; + write_buffer_len_ = buffer_len; + return; + } + } + was_used_to_convey_data_ = true; + callback.Run(result); + } +} + +} // namespace remoting |