diff options
Diffstat (limited to 'jingle/glue')
-rw-r--r-- | jingle/glue/chrome_async_socket.cc | 450 | ||||
-rw-r--r-- | jingle/glue/chrome_async_socket.h | 213 | ||||
-rw-r--r-- | jingle/glue/chrome_async_socket_unittest.cc | 1079 | ||||
-rw-r--r-- | jingle/glue/fake_ssl_client_socket.cc | 353 | ||||
-rw-r--r-- | jingle/glue/fake_ssl_client_socket.h | 114 | ||||
-rw-r--r-- | jingle/glue/fake_ssl_client_socket_unittest.cc | 348 | ||||
-rw-r--r-- | jingle/glue/proxy_resolving_client_socket.cc | 392 | ||||
-rw-r--r-- | jingle/glue/proxy_resolving_client_socket.h | 106 | ||||
-rw-r--r-- | jingle/glue/proxy_resolving_client_socket_unittest.cc | 117 | ||||
-rw-r--r-- | jingle/glue/resolving_client_socket_factory.h | 36 | ||||
-rw-r--r-- | jingle/glue/xmpp_client_socket_factory.cc | 56 | ||||
-rw-r--r-- | jingle/glue/xmpp_client_socket_factory.h | 56 |
12 files changed, 3320 insertions, 0 deletions
diff --git a/jingle/glue/chrome_async_socket.cc b/jingle/glue/chrome_async_socket.cc new file mode 100644 index 0000000..667ec7d --- /dev/null +++ b/jingle/glue/chrome_async_socket.cc @@ -0,0 +1,450 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "jingle/glue/chrome_async_socket.h" + +#include <algorithm> +#include <cstring> +#include <cstdlib> + +#include "base/basictypes.h" +#include "base/bind.h" +#include "base/compiler_specific.h" +#include "base/logging.h" +#include "base/message_loop.h" +#include "jingle/glue/resolving_client_socket_factory.h" +#include "net/base/address_list.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/net_util.h" +#include "net/base/ssl_config_service.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/tcp_client_socket.h" +#include "third_party/libjingle/source/talk/base/socketaddress.h" + +namespace jingle_glue { + +ChromeAsyncSocket::ChromeAsyncSocket( + ResolvingClientSocketFactory* resolving_client_socket_factory, + size_t read_buf_size, + size_t write_buf_size) + : ALLOW_THIS_IN_INITIALIZER_LIST(weak_ptr_factory_(this)), + resolving_client_socket_factory_(resolving_client_socket_factory), + state_(STATE_CLOSED), + error_(ERROR_NONE), + net_error_(net::OK), + read_state_(IDLE), + read_buf_(new net::IOBufferWithSize(read_buf_size)), + read_start_(0U), + read_end_(0U), + write_state_(IDLE), + write_buf_(new net::IOBufferWithSize(write_buf_size)), + write_end_(0U) { + DCHECK(resolving_client_socket_factory_.get()); + DCHECK_GT(read_buf_size, 0U); + DCHECK_GT(write_buf_size, 0U); +} + +ChromeAsyncSocket::~ChromeAsyncSocket() {} + +ChromeAsyncSocket::State ChromeAsyncSocket::state() { + return state_; +} + +ChromeAsyncSocket::Error ChromeAsyncSocket::error() { + return error_; +} + +int ChromeAsyncSocket::GetError() { + return net_error_; +} + +bool ChromeAsyncSocket::IsOpen() const { + return (state_ == STATE_OPEN) || (state_ == STATE_TLS_OPEN); +} + +void ChromeAsyncSocket::DoNonNetError(Error error) { + DCHECK_NE(error, ERROR_NONE); + DCHECK_NE(error, ERROR_WINSOCK); + error_ = error; + net_error_ = net::OK; +} + +void ChromeAsyncSocket::DoNetError(net::Error net_error) { + error_ = ERROR_WINSOCK; + net_error_ = net_error; +} + +void ChromeAsyncSocket::DoNetErrorFromStatus(int status) { + DCHECK_LT(status, net::OK); + DoNetError(static_cast<net::Error>(status)); +} + +// STATE_CLOSED -> STATE_CONNECTING + +bool ChromeAsyncSocket::Connect(const talk_base::SocketAddress& address) { + if (state_ != STATE_CLOSED) { + LOG(DFATAL) << "Connect() called on non-closed socket"; + DoNonNetError(ERROR_WRONGSTATE); + return false; + } + if (address.hostname().empty() || address.port() == 0) { + DoNonNetError(ERROR_DNS); + return false; + } + + DCHECK_EQ(state_, buzz::AsyncSocket::STATE_CLOSED); + DCHECK_EQ(read_state_, IDLE); + DCHECK_EQ(write_state_, IDLE); + + state_ = STATE_CONNECTING; + + DCHECK(!weak_ptr_factory_.HasWeakPtrs()); + weak_ptr_factory_.InvalidateWeakPtrs(); + + net::HostPortPair dest_host_port_pair(address.hostname(), address.port()); + + transport_socket_.reset( + resolving_client_socket_factory_->CreateTransportClientSocket( + dest_host_port_pair)); + int status = transport_socket_->Connect( + base::Bind(&ChromeAsyncSocket::ProcessConnectDone, + weak_ptr_factory_.GetWeakPtr())); + if (status != net::ERR_IO_PENDING) { + // We defer execution of ProcessConnectDone instead of calling it + // directly here as the caller may not expect an error/close to + // happen here. This is okay, as from the caller's point of view, + // the connect always happens asynchronously. + MessageLoop* message_loop = MessageLoop::current(); + CHECK(message_loop); + message_loop->PostTask( + FROM_HERE, + base::Bind(&ChromeAsyncSocket::ProcessConnectDone, + weak_ptr_factory_.GetWeakPtr(), status)); + } + return true; +} + +// STATE_CONNECTING -> STATE_OPEN +// read_state_ == IDLE -> read_state_ == POSTED (via PostDoRead()) + +void ChromeAsyncSocket::ProcessConnectDone(int status) { + DCHECK_NE(status, net::ERR_IO_PENDING); + DCHECK_EQ(read_state_, IDLE); + DCHECK_EQ(write_state_, IDLE); + DCHECK_EQ(state_, STATE_CONNECTING); + if (status != net::OK) { + DoNetErrorFromStatus(status); + DoClose(); + return; + } + state_ = STATE_OPEN; + PostDoRead(); + // Write buffer should be empty. + DCHECK_EQ(write_end_, 0U); + SignalConnected(); +} + +// read_state_ == IDLE -> read_state_ == POSTED + +void ChromeAsyncSocket::PostDoRead() { + DCHECK(IsOpen()); + DCHECK_EQ(read_state_, IDLE); + DCHECK_EQ(read_start_, 0U); + DCHECK_EQ(read_end_, 0U); + MessageLoop* message_loop = MessageLoop::current(); + CHECK(message_loop); + message_loop->PostTask( + FROM_HERE, + base::Bind(&ChromeAsyncSocket::DoRead, + weak_ptr_factory_.GetWeakPtr())); + read_state_ = POSTED; +} + +// read_state_ == POSTED -> read_state_ == PENDING + +void ChromeAsyncSocket::DoRead() { + DCHECK(IsOpen()); + DCHECK_EQ(read_state_, POSTED); + DCHECK_EQ(read_start_, 0U); + DCHECK_EQ(read_end_, 0U); + // Once we call Read(), we cannot call StartTls() until the read + // finishes. This is okay, as StartTls() is called only from a read + // handler (i.e., after a read finishes and before another read is + // done). + int status = + transport_socket_->Read( + read_buf_.get(), read_buf_->size(), + base::Bind(&ChromeAsyncSocket::ProcessReadDone, + weak_ptr_factory_.GetWeakPtr())); + read_state_ = PENDING; + if (status != net::ERR_IO_PENDING) { + ProcessReadDone(status); + } +} + +// read_state_ == PENDING -> read_state_ == IDLE + +void ChromeAsyncSocket::ProcessReadDone(int status) { + DCHECK_NE(status, net::ERR_IO_PENDING); + DCHECK(IsOpen()); + DCHECK_EQ(read_state_, PENDING); + DCHECK_EQ(read_start_, 0U); + DCHECK_EQ(read_end_, 0U); + read_state_ = IDLE; + if (status > 0) { + read_end_ = static_cast<size_t>(status); + SignalRead(); + } else if (status == 0) { + // Other side closed the connection. + error_ = ERROR_NONE; + net_error_ = net::OK; + DoClose(); + } else { // status < 0 + DoNetErrorFromStatus(status); + DoClose(); + } +} + +// (maybe) read_state_ == IDLE -> read_state_ == POSTED (via +// PostDoRead()) + +bool ChromeAsyncSocket::Read(char* data, size_t len, size_t* len_read) { + if (!IsOpen() && (state_ != STATE_TLS_CONNECTING)) { + // Read() may be called on a closed socket if a previous read + // causes a socket close (e.g., client sends wrong password and + // server terminates connection). + // + // TODO(akalin): Fix handling of this on the libjingle side. + if (state_ != STATE_CLOSED) { + LOG(DFATAL) << "Read() called on non-open non-tls-connecting socket"; + } + DoNonNetError(ERROR_WRONGSTATE); + return false; + } + DCHECK_LE(read_start_, read_end_); + if ((state_ == STATE_TLS_CONNECTING) || read_end_ == 0U) { + if (state_ == STATE_TLS_CONNECTING) { + DCHECK_EQ(read_state_, IDLE); + DCHECK_EQ(read_end_, 0U); + } else { + DCHECK_NE(read_state_, IDLE); + } + *len_read = 0; + return true; + } + DCHECK_EQ(read_state_, IDLE); + *len_read = std::min(len, read_end_ - read_start_); + DCHECK_GT(*len_read, 0U); + std::memcpy(data, read_buf_->data() + read_start_, *len_read); + read_start_ += *len_read; + if (read_start_ == read_end_) { + read_start_ = 0U; + read_end_ = 0U; + // We defer execution of DoRead() here for similar reasons as + // ProcessConnectDone(). + PostDoRead(); + } + return true; +} + +// (maybe) write_state_ == IDLE -> write_state_ == POSTED (via +// PostDoWrite()) + +bool ChromeAsyncSocket::Write(const char* data, size_t len) { + if (!IsOpen() && (state_ != STATE_TLS_CONNECTING)) { + LOG(DFATAL) << "Write() called on non-open non-tls-connecting socket"; + DoNonNetError(ERROR_WRONGSTATE); + return false; + } + // TODO(akalin): Avoid this check by modifying the interface to have + // a "ready for writing" signal. + if ((static_cast<size_t>(write_buf_->size()) - write_end_) < len) { + LOG(DFATAL) << "queueing " << len << " bytes would exceed the " + << "max write buffer size = " << write_buf_->size() + << " by " << (len - write_buf_->size()) << " bytes"; + DoNetError(net::ERR_INSUFFICIENT_RESOURCES); + return false; + } + std::memcpy(write_buf_->data() + write_end_, data, len); + write_end_ += len; + // If we're TLS-connecting, the write buffer will get flushed once + // the TLS-connect finishes. Otherwise, start writing if we're not + // already writing and we have something to write. + if ((state_ != STATE_TLS_CONNECTING) && + (write_state_ == IDLE) && (write_end_ > 0U)) { + // We defer execution of DoWrite() here for similar reasons as + // ProcessConnectDone(). + PostDoWrite(); + } + return true; +} + +// write_state_ == IDLE -> write_state_ == POSTED + +void ChromeAsyncSocket::PostDoWrite() { + DCHECK(IsOpen()); + DCHECK_EQ(write_state_, IDLE); + DCHECK_GT(write_end_, 0U); + MessageLoop* message_loop = MessageLoop::current(); + CHECK(message_loop); + message_loop->PostTask( + FROM_HERE, + base::Bind(&ChromeAsyncSocket::DoWrite, + weak_ptr_factory_.GetWeakPtr())); + write_state_ = POSTED; +} + +// write_state_ == POSTED -> write_state_ == PENDING + +void ChromeAsyncSocket::DoWrite() { + DCHECK(IsOpen()); + DCHECK_EQ(write_state_, POSTED); + DCHECK_GT(write_end_, 0U); + // Once we call Write(), we cannot call StartTls() until the write + // finishes. This is okay, as StartTls() is called only after we + // have received a reply to a message we sent to the server and + // before we send the next message. + int status = + transport_socket_->Write( + write_buf_.get(), write_end_, + base::Bind(&ChromeAsyncSocket::ProcessWriteDone, + weak_ptr_factory_.GetWeakPtr())); + write_state_ = PENDING; + if (status != net::ERR_IO_PENDING) { + ProcessWriteDone(status); + } +} + +// write_state_ == PENDING -> write_state_ == IDLE or POSTED (the +// latter via PostDoWrite()) + +void ChromeAsyncSocket::ProcessWriteDone(int status) { + DCHECK_NE(status, net::ERR_IO_PENDING); + DCHECK(IsOpen()); + DCHECK_EQ(write_state_, PENDING); + DCHECK_GT(write_end_, 0U); + write_state_ = IDLE; + if (status < net::OK) { + DoNetErrorFromStatus(status); + DoClose(); + return; + } + size_t written = static_cast<size_t>(status); + if (written > write_end_) { + LOG(DFATAL) << "bytes written = " << written + << " exceeds bytes requested = " << write_end_; + DoNetError(net::ERR_UNEXPECTED); + DoClose(); + return; + } + // TODO(akalin): Figure out a better way to do this; perhaps a queue + // of DrainableIOBuffers. This'll also allow us to not have an + // artificial buffer size limit. + std::memmove(write_buf_->data(), + write_buf_->data() + written, + write_end_ - written); + write_end_ -= written; + if (write_end_ > 0U) { + PostDoWrite(); + } +} + +// * -> STATE_CLOSED + +bool ChromeAsyncSocket::Close() { + DoClose(); + return true; +} + +// (not STATE_CLOSED) -> STATE_CLOSED + +void ChromeAsyncSocket::DoClose() { + weak_ptr_factory_.InvalidateWeakPtrs(); + if (transport_socket_.get()) { + transport_socket_->Disconnect(); + } + transport_socket_.reset(); + read_state_ = IDLE; + read_start_ = 0U; + read_end_ = 0U; + write_state_ = IDLE; + write_end_ = 0U; + if (state_ != STATE_CLOSED) { + state_ = STATE_CLOSED; + SignalClosed(); + } + // Reset error variables after SignalClosed() so slots connected + // to it can read it. + error_ = ERROR_NONE; + net_error_ = net::OK; +} + +// STATE_OPEN -> STATE_TLS_CONNECTING + +bool ChromeAsyncSocket::StartTls(const std::string& domain_name) { + if ((state_ != STATE_OPEN) || (read_state_ == PENDING) || + (write_state_ != IDLE)) { + LOG(DFATAL) << "StartTls() called in wrong state"; + DoNonNetError(ERROR_WRONGSTATE); + return false; + } + + state_ = STATE_TLS_CONNECTING; + read_state_ = IDLE; + read_start_ = 0U; + read_end_ = 0U; + DCHECK_EQ(write_end_, 0U); + + // Clear out any posted DoRead() tasks. + weak_ptr_factory_.InvalidateWeakPtrs(); + + DCHECK(transport_socket_.get()); + scoped_ptr<net::ClientSocketHandle> socket_handle( + new net::ClientSocketHandle()); + socket_handle->set_socket(transport_socket_.release()); + transport_socket_.reset( + resolving_client_socket_factory_->CreateSSLClientSocket( + socket_handle.release(), net::HostPortPair(domain_name, 443))); + int status = transport_socket_->Connect( + base::Bind(&ChromeAsyncSocket::ProcessSSLConnectDone, + weak_ptr_factory_.GetWeakPtr())); + if (status != net::ERR_IO_PENDING) { + MessageLoop* message_loop = MessageLoop::current(); + CHECK(message_loop); + message_loop->PostTask( + FROM_HERE, + base::Bind(&ChromeAsyncSocket::ProcessSSLConnectDone, + weak_ptr_factory_.GetWeakPtr(), status)); + } + return true; +} + +// STATE_TLS_CONNECTING -> STATE_TLS_OPEN +// read_state_ == IDLE -> read_state_ == POSTED (via PostDoRead()) +// (maybe) write_state_ == IDLE -> write_state_ == POSTED (via +// PostDoWrite()) + +void ChromeAsyncSocket::ProcessSSLConnectDone(int status) { + DCHECK_NE(status, net::ERR_IO_PENDING); + DCHECK_EQ(state_, STATE_TLS_CONNECTING); + DCHECK_EQ(read_state_, IDLE); + DCHECK_EQ(read_start_, 0U); + DCHECK_EQ(read_end_, 0U); + DCHECK_EQ(write_state_, IDLE); + if (status != net::OK) { + DoNetErrorFromStatus(status); + DoClose(); + return; + } + state_ = STATE_TLS_OPEN; + PostDoRead(); + if (write_end_ > 0U) { + PostDoWrite(); + } + SignalSSLConnected(); +} + +} // namespace jingle_glue diff --git a/jingle/glue/chrome_async_socket.h b/jingle/glue/chrome_async_socket.h new file mode 100644 index 0000000..1037d24 --- /dev/null +++ b/jingle/glue/chrome_async_socket.h @@ -0,0 +1,213 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// An implementation of buzz::AsyncSocket that uses Chrome sockets. + +#ifndef JINGLE_GLUE_CHROME_ASYNC_SOCKET_H_ +#define JINGLE_GLUE_CHROME_ASYNC_SOCKET_H_ + +#if !defined(FEATURE_ENABLE_SSL) +#error ChromeAsyncSocket expects FEATURE_ENABLE_SSL to be defined +#endif + +#include <string> +#include <vector> + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/weak_ptr.h" +#include "net/base/completion_callback.h" +#include "net/base/net_errors.h" +#include "third_party/libjingle/source/talk/xmpp/asyncsocket.h" + +namespace net { +class IOBufferWithSize; +class StreamSocket; +} // namespace net + +namespace jingle_glue { + +class ResolvingClientSocketFactory; + +class ChromeAsyncSocket : public buzz::AsyncSocket { + public: + // Takes ownership of |resolving_client_socket_factory|. + ChromeAsyncSocket( + ResolvingClientSocketFactory* resolving_client_socket_factory, + size_t read_buf_size, + size_t write_buf_size); + + // Does not raise any signals. + virtual ~ChromeAsyncSocket(); + + // buzz::AsyncSocket implementation. + + // The current state (see buzz::AsyncSocket::State; all but + // STATE_CLOSING is used). + virtual State state() OVERRIDE; + + // The last generated error. Errors are generated when the main + // functions below return false or when SignalClosed is raised due + // to an asynchronous error. + virtual Error error() OVERRIDE; + + // GetError() (which is of type net::Error) != net::OK only when + // error() == ERROR_WINSOCK. + virtual int GetError() OVERRIDE; + + // Tries to connect to the given address. + // + // If state() is not STATE_CLOSED, sets error to ERROR_WRONGSTATE + // and returns false. + // + // If |address| has an empty hostname or a zero port, sets error to + // ERROR_DNS and returns false. (We don't use the IP address even + // if it's present, as DNS resolution is done by + // |resolving_client_socket_factory_|. But it's perfectly fine if + // the hostname is a stringified IP address.) + // + // Otherwise, starts the connection process and returns true. + // SignalConnected will be raised when the connection is successful; + // otherwise, SignalClosed will be raised with a net error set. + virtual bool Connect(const talk_base::SocketAddress& address) OVERRIDE; + + // Tries to read at most |len| bytes into |data|. + // + // If state() is not STATE_TLS_CONNECTING, STATE_OPEN, or + // STATE_TLS_OPEN, sets error to ERROR_WRONGSTATE and returns false. + // + // Otherwise, fills in |len_read| with the number of bytes read and + // returns true. If this is called when state() is + // STATE_TLS_CONNECTING, reads 0 bytes. (We have to handle this + // case because StartTls() is called during a slot connected to + // SignalRead after parsing the final non-TLS reply from the server + // [see XmppClient::Private::OnSocketRead()].) + virtual bool Read(char* data, size_t len, size_t* len_read) OVERRIDE; + + // Queues up |len| bytes of |data| for writing. + // + // If state() is not STATE_TLS_CONNECTING, STATE_OPEN, or + // STATE_TLS_OPEN, sets error to ERROR_WRONGSTATE and returns false. + // + // If the given data is too big for the internal write buffer, sets + // error to ERROR_WINSOCK/net::ERR_INSUFFICIENT_RESOURCES and + // returns false. + // + // Otherwise, queues up the data and returns true. If this is + // called when state() == STATE_TLS_CONNECTING, the data is will be + // sent only after the TLS connection succeeds. (See StartTls() + // below for why this happens.) + // + // Note that there's no guarantee that the data will actually be + // sent; however, it is guaranteed that the any data sent will be + // sent in FIFO order. + virtual bool Write(const char* data, size_t len) OVERRIDE; + + // If the socket is not already closed, closes the socket and raises + // SignalClosed. Always returns true. + virtual bool Close() OVERRIDE; + + // Tries to change to a TLS connection with the given domain name. + // + // If state() is not STATE_OPEN or there are pending reads or + // writes, sets error to ERROR_WRONGSTATE and returns false. (In + // practice, this means that StartTls() can only be called from a + // slot connected to SignalRead.) + // + // Otherwise, starts the TLS connection process and returns true. + // SignalSSLConnected will be raised when the connection is + // successful; otherwise, SignalClosed will be raised with a net + // error set. + virtual bool StartTls(const std::string& domain_name) OVERRIDE; + + // Signal behavior: + // + // SignalConnected: raised whenever the connect initiated by a call + // to Connect() is complete. + // + // SignalSSLConnected: raised whenever the connect initiated by a + // call to StartTls() is complete. Not actually used by + // XmppClient. (It just assumes that if SignalRead is raised after a + // call to StartTls(), the connection has been successfully + // upgraded.) + // + // SignalClosed: raised whenever the socket is closed, either due to + // an asynchronous error, the other side closing the connection, or + // when Close() is called. + // + // SignalRead: raised whenever the next call to Read() will succeed + // with a non-zero |len_read| (assuming nothing else happens in the + // meantime). + // + // SignalError: not used. + + private: + enum AsyncIOState { + // An I/O op is not in progress. + IDLE, + // A function has been posted to do the I/O. + POSTED, + // An async I/O operation is pending. + PENDING, + }; + + bool IsOpen() const; + + // Error functions. + void DoNonNetError(Error error); + void DoNetError(net::Error net_error); + void DoNetErrorFromStatus(int status); + + // Connection functions. + void ProcessConnectDone(int status); + + // Read loop functions. + void PostDoRead(); + void DoRead(); + void ProcessReadDone(int status); + + // Write loop functions. + void PostDoWrite(); + void DoWrite(); + void ProcessWriteDone(int status); + + // SSL/TLS connection functions. + void ProcessSSLConnectDone(int status); + + // Close functions. + void DoClose(); + + base::WeakPtrFactory<ChromeAsyncSocket> weak_ptr_factory_; + scoped_ptr<ResolvingClientSocketFactory> resolving_client_socket_factory_; + + // buzz::AsyncSocket state. + buzz::AsyncSocket::State state_; + buzz::AsyncSocket::Error error_; + net::Error net_error_; + + // NULL iff state() == STATE_CLOSED. + scoped_ptr<net::StreamSocket> transport_socket_; + + // State for the read loop. |read_start_| <= |read_end_| <= + // |read_buf_->size()|. There's a read in flight (i.e., + // |read_state_| != IDLE) iff |read_end_| == 0. + AsyncIOState read_state_; + scoped_refptr<net::IOBufferWithSize> read_buf_; + size_t read_start_, read_end_; + + // State for the write loop. |write_end_| <= |write_buf_->size()|. + // There's a write in flight (i.e., |write_state_| != IDLE) iff + // |write_end_| > 0. + AsyncIOState write_state_; + scoped_refptr<net::IOBufferWithSize> write_buf_; + size_t write_end_; + + DISALLOW_COPY_AND_ASSIGN(ChromeAsyncSocket); +}; + +} // namespace jingle_glue + +#endif // JINGLE_GLUE_CHROME_ASYNC_SOCKET_H_ diff --git a/jingle/glue/chrome_async_socket_unittest.cc b/jingle/glue/chrome_async_socket_unittest.cc new file mode 100644 index 0000000..0e3274b --- /dev/null +++ b/jingle/glue/chrome_async_socket_unittest.cc @@ -0,0 +1,1079 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "jingle/glue/chrome_async_socket.h" + +#include <deque> +#include <string> + +#include "base/basictypes.h" +#include "base/logging.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop.h" +#include "jingle/glue/resolving_client_socket_factory.h" +#include "net/base/address_list.h" +#include "net/base/mock_cert_verifier.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/base/ssl_config_service.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" +#include "net/url_request/url_request_context_getter.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "third_party/libjingle/source/talk/base/ipaddress.h" +#include "third_party/libjingle/source/talk/base/sigslot.h" +#include "third_party/libjingle/source/talk/base/socketaddress.h" + +namespace jingle_glue { + +namespace { + +// Data provider that handles reads/writes for ChromeAsyncSocket. +class AsyncSocketDataProvider : public net::SocketDataProvider { + public: + AsyncSocketDataProvider() : has_pending_read_(false) {} + + virtual ~AsyncSocketDataProvider() { + EXPECT_TRUE(writes_.empty()); + EXPECT_TRUE(reads_.empty()); + } + + // If there's no read, sets the "has pending read" flag. Otherwise, + // pops the next read. + virtual net::MockRead GetNextRead() { + if (reads_.empty()) { + DCHECK(!has_pending_read_); + has_pending_read_ = true; + const net::MockRead pending_read(net::SYNCHRONOUS, net::ERR_IO_PENDING); + return pending_read; + } + net::MockRead mock_read = reads_.front(); + reads_.pop_front(); + return mock_read; + } + + // Simply pops the next write and, if applicable, compares it to + // |data|. + virtual net::MockWriteResult OnWrite(const std::string& data) { + DCHECK(!writes_.empty()); + net::MockWrite mock_write = writes_.front(); + writes_.pop_front(); + if (mock_write.result != net::OK) { + return net::MockWriteResult(mock_write.mode, mock_write.result); + } + std::string expected_data(mock_write.data, mock_write.data_len); + EXPECT_EQ(expected_data, data); + if (expected_data != data) { + return net::MockWriteResult(net::SYNCHRONOUS, net::ERR_UNEXPECTED); + } + return net::MockWriteResult(mock_write.mode, data.size()); + } + + // We ignore resets so we can pre-load the socket data provider with + // read/write events. + virtual void Reset() {} + + // If there is a pending read, completes it with the given read. + // Otherwise, queues up the given read. + void AddRead(const net::MockRead& mock_read) { + DCHECK_NE(mock_read.result, net::ERR_IO_PENDING); + if (has_pending_read_) { + socket()->OnReadComplete(mock_read); + has_pending_read_ = false; + return; + } + reads_.push_back(mock_read); + } + + // Simply queues up the given write. + void AddWrite(const net::MockWrite& mock_write) { + writes_.push_back(mock_write); + } + + private: + std::deque<net::MockRead> reads_; + bool has_pending_read_; + + std::deque<net::MockWrite> writes_; + + DISALLOW_COPY_AND_ASSIGN(AsyncSocketDataProvider); +}; + +class MockXmppClientSocketFactory : public ResolvingClientSocketFactory { + public: + MockXmppClientSocketFactory( + net::ClientSocketFactory* mock_client_socket_factory, + const net::AddressList& address_list) + : mock_client_socket_factory_(mock_client_socket_factory), + address_list_(address_list), + cert_verifier_(new net::MockCertVerifier) { + } + + // ResolvingClientSocketFactory implementation. + virtual net::StreamSocket* CreateTransportClientSocket( + const net::HostPortPair& host_and_port) { + return mock_client_socket_factory_->CreateTransportClientSocket( + address_list_, NULL, net::NetLog::Source()); + } + + virtual net::SSLClientSocket* CreateSSLClientSocket( + net::ClientSocketHandle* transport_socket, + const net::HostPortPair& host_and_port) { + net::SSLClientSocketContext context; + context.cert_verifier = cert_verifier_.get(); + return mock_client_socket_factory_->CreateSSLClientSocket( + transport_socket, host_and_port, ssl_config_, context); + } + + private: + scoped_ptr<net::ClientSocketFactory> mock_client_socket_factory_; + net::AddressList address_list_; + net::SSLConfig ssl_config_; + scoped_ptr<net::CertVerifier> cert_verifier_; +}; + +class ChromeAsyncSocketTest + : public testing::Test, + public sigslot::has_slots<> { + protected: + ChromeAsyncSocketTest() + : ssl_socket_data_provider_(net::ASYNC, net::OK), + addr_("localhost", 35) {} + + virtual ~ChromeAsyncSocketTest() {} + + virtual void SetUp() { + scoped_ptr<net::MockClientSocketFactory> mock_client_socket_factory( + new net::MockClientSocketFactory()); + mock_client_socket_factory->AddSocketDataProvider( + &async_socket_data_provider_); + mock_client_socket_factory->AddSSLSocketDataProvider( + &ssl_socket_data_provider_); + + // Fake DNS resolution for |addr_| and pass it to the factory. + net::IPAddressNumber resolved_addr; + EXPECT_TRUE(net::ParseIPLiteralToNumber("127.0.0.1", &resolved_addr)); + const net::AddressList address_list = + net::AddressList::CreateFromIPAddress(resolved_addr, addr_.port()); + scoped_ptr<MockXmppClientSocketFactory> mock_xmpp_client_socket_factory( + new MockXmppClientSocketFactory( + mock_client_socket_factory.release(), + address_list)); + chrome_async_socket_.reset( + new ChromeAsyncSocket(mock_xmpp_client_socket_factory.release(), + 14, 20)), + + chrome_async_socket_->SignalConnected.connect( + this, &ChromeAsyncSocketTest::OnConnect); + chrome_async_socket_->SignalSSLConnected.connect( + this, &ChromeAsyncSocketTest::OnSSLConnect); + chrome_async_socket_->SignalClosed.connect( + this, &ChromeAsyncSocketTest::OnClose); + chrome_async_socket_->SignalRead.connect( + this, &ChromeAsyncSocketTest::OnRead); + chrome_async_socket_->SignalError.connect( + this, &ChromeAsyncSocketTest::OnError); + } + + virtual void TearDown() { + // Run any tasks that we forgot to pump. + message_loop_.RunAllPending(); + ExpectClosed(); + ExpectNoSignal(); + chrome_async_socket_.reset(); + } + + enum Signal { + SIGNAL_CONNECT, + SIGNAL_SSL_CONNECT, + SIGNAL_CLOSE, + SIGNAL_READ, + SIGNAL_ERROR, + }; + + // Helper struct that records the state at the time of a signal. + + struct SignalSocketState { + SignalSocketState() + : signal(SIGNAL_ERROR), + state(ChromeAsyncSocket::STATE_CLOSED), + error(ChromeAsyncSocket::ERROR_NONE), + net_error(net::OK) {} + + SignalSocketState( + Signal signal, + ChromeAsyncSocket::State state, + ChromeAsyncSocket::Error error, + net::Error net_error) + : signal(signal), + state(state), + error(error), + net_error(net_error) {} + + bool IsEqual(const SignalSocketState& other) const { + return + (signal == other.signal) && + (state == other.state) && + (error == other.error) && + (net_error == other.net_error); + } + + static SignalSocketState FromAsyncSocket( + Signal signal, + buzz::AsyncSocket* async_socket) { + return SignalSocketState(signal, + async_socket->state(), + async_socket->error(), + static_cast<net::Error>( + async_socket->GetError())); + } + + static SignalSocketState NoError( + Signal signal, buzz::AsyncSocket::State state) { + return SignalSocketState(signal, state, + buzz::AsyncSocket::ERROR_NONE, + net::OK); + } + + Signal signal; + ChromeAsyncSocket::State state; + ChromeAsyncSocket::Error error; + net::Error net_error; + }; + + void CaptureSocketState(Signal signal) { + signal_socket_states_.push_back( + SignalSocketState::FromAsyncSocket( + signal, chrome_async_socket_.get())); + } + + void OnConnect() { + CaptureSocketState(SIGNAL_CONNECT); + } + + void OnSSLConnect() { + CaptureSocketState(SIGNAL_SSL_CONNECT); + } + + void OnClose() { + CaptureSocketState(SIGNAL_CLOSE); + } + + void OnRead() { + CaptureSocketState(SIGNAL_READ); + } + + void OnError() { + ADD_FAILURE(); + } + + // State expect functions. + + void ExpectState(ChromeAsyncSocket::State state, + ChromeAsyncSocket::Error error, + net::Error net_error) { + EXPECT_EQ(state, chrome_async_socket_->state()); + EXPECT_EQ(error, chrome_async_socket_->error()); + EXPECT_EQ(net_error, chrome_async_socket_->GetError()); + } + + void ExpectNonErrorState(ChromeAsyncSocket::State state) { + ExpectState(state, ChromeAsyncSocket::ERROR_NONE, net::OK); + } + + void ExpectErrorState(ChromeAsyncSocket::State state, + ChromeAsyncSocket::Error error) { + ExpectState(state, error, net::OK); + } + + void ExpectClosed() { + ExpectNonErrorState(ChromeAsyncSocket::STATE_CLOSED); + } + + // Signal expect functions. + + void ExpectNoSignal() { + if (!signal_socket_states_.empty()) { + ADD_FAILURE() << signal_socket_states_.front().signal; + } + } + + void ExpectSignalSocketState( + SignalSocketState expected_signal_socket_state) { + if (signal_socket_states_.empty()) { + ADD_FAILURE() << expected_signal_socket_state.signal; + return; + } + EXPECT_TRUE(expected_signal_socket_state.IsEqual( + signal_socket_states_.front())) + << signal_socket_states_.front().signal; + signal_socket_states_.pop_front(); + } + + void ExpectReadSignal() { + ExpectSignalSocketState( + SignalSocketState::NoError( + SIGNAL_READ, ChromeAsyncSocket::STATE_OPEN)); + } + + void ExpectSSLConnectSignal() { + ExpectSignalSocketState( + SignalSocketState::NoError(SIGNAL_SSL_CONNECT, + ChromeAsyncSocket::STATE_TLS_OPEN)); + } + + void ExpectSSLReadSignal() { + ExpectSignalSocketState( + SignalSocketState::NoError( + SIGNAL_READ, ChromeAsyncSocket::STATE_TLS_OPEN)); + } + + // Open/close utility functions. + + void DoOpenClosed() { + ExpectClosed(); + async_socket_data_provider_.set_connect_data( + net::MockConnect(net::SYNCHRONOUS, net::OK)); + EXPECT_TRUE(chrome_async_socket_->Connect(addr_)); + ExpectNonErrorState(ChromeAsyncSocket::STATE_CONNECTING); + + message_loop_.RunAllPending(); + // We may not necessarily be open; may have been other events + // queued up. + ExpectSignalSocketState( + SignalSocketState::NoError( + SIGNAL_CONNECT, ChromeAsyncSocket::STATE_OPEN)); + } + + void DoCloseOpened(SignalSocketState expected_signal_socket_state) { + // We may be in an error state, so just compare state(). + EXPECT_EQ(ChromeAsyncSocket::STATE_OPEN, chrome_async_socket_->state()); + EXPECT_TRUE(chrome_async_socket_->Close()); + ExpectSignalSocketState(expected_signal_socket_state); + ExpectNonErrorState(ChromeAsyncSocket::STATE_CLOSED); + } + + void DoCloseOpenedNoError() { + DoCloseOpened( + SignalSocketState::NoError( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED)); + } + + void DoSSLOpenClosed() { + const char kDummyData[] = "dummy_data"; + async_socket_data_provider_.AddRead(net::MockRead(kDummyData)); + DoOpenClosed(); + ExpectReadSignal(); + EXPECT_EQ(kDummyData, DrainRead(1)); + + EXPECT_TRUE(chrome_async_socket_->StartTls("fakedomain.com")); + message_loop_.RunAllPending(); + ExpectSSLConnectSignal(); + ExpectNoSignal(); + ExpectNonErrorState(ChromeAsyncSocket::STATE_TLS_OPEN); + } + + void DoSSLCloseOpened(SignalSocketState expected_signal_socket_state) { + // We may be in an error state, so just compare state(). + EXPECT_EQ(ChromeAsyncSocket::STATE_TLS_OPEN, + chrome_async_socket_->state()); + EXPECT_TRUE(chrome_async_socket_->Close()); + ExpectSignalSocketState(expected_signal_socket_state); + ExpectNonErrorState(ChromeAsyncSocket::STATE_CLOSED); + } + + void DoSSLCloseOpenedNoError() { + DoSSLCloseOpened( + SignalSocketState::NoError( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED)); + } + + // Read utility fucntions. + + std::string DrainRead(size_t buf_size) { + std::string read; + scoped_array<char> buf(new char[buf_size]); + size_t len_read; + while (true) { + bool success = + chrome_async_socket_->Read(buf.get(), buf_size, &len_read); + if (!success) { + ADD_FAILURE(); + break; + } + if (len_read == 0U) { + break; + } + read.append(buf.get(), len_read); + } + return read; + } + + // ChromeAsyncSocket expects a message loop. + MessageLoop message_loop_; + + AsyncSocketDataProvider async_socket_data_provider_; + net::SSLSocketDataProvider ssl_socket_data_provider_; + + scoped_ptr<ChromeAsyncSocket> chrome_async_socket_; + std::deque<SignalSocketState> signal_socket_states_; + const talk_base::SocketAddress addr_; + + private: + DISALLOW_COPY_AND_ASSIGN(ChromeAsyncSocketTest); +}; + +TEST_F(ChromeAsyncSocketTest, InitialState) { + ExpectClosed(); + ExpectNoSignal(); +} + +TEST_F(ChromeAsyncSocketTest, EmptyClose) { + ExpectClosed(); + EXPECT_TRUE(chrome_async_socket_->Close()); + ExpectClosed(); +} + +TEST_F(ChromeAsyncSocketTest, ImmediateConnectAndClose) { + DoOpenClosed(); + + ExpectNonErrorState(ChromeAsyncSocket::STATE_OPEN); + + DoCloseOpenedNoError(); +} + +// After this, no need to test immediate successful connect and +// Close() so thoroughly. + +TEST_F(ChromeAsyncSocketTest, DoubleClose) { + DoOpenClosed(); + + EXPECT_TRUE(chrome_async_socket_->Close()); + ExpectClosed(); + ExpectSignalSocketState( + SignalSocketState::NoError( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED)); + + EXPECT_TRUE(chrome_async_socket_->Close()); + ExpectClosed(); +} + +TEST_F(ChromeAsyncSocketTest, NoHostnameConnect) { + talk_base::IPAddress ip_address; + EXPECT_TRUE(talk_base::IPFromString("127.0.0.1", &ip_address)); + const talk_base::SocketAddress no_hostname_addr(ip_address, addr_.port()); + EXPECT_FALSE(chrome_async_socket_->Connect(no_hostname_addr)); + ExpectErrorState(ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_DNS); + + EXPECT_TRUE(chrome_async_socket_->Close()); + ExpectClosed(); +} + +TEST_F(ChromeAsyncSocketTest, ZeroPortConnect) { + const talk_base::SocketAddress zero_port_addr(addr_.hostname(), 0); + EXPECT_FALSE(chrome_async_socket_->Connect(zero_port_addr)); + ExpectErrorState(ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_DNS); + + EXPECT_TRUE(chrome_async_socket_->Close()); + ExpectClosed(); +} + +TEST_F(ChromeAsyncSocketTest, DoubleConnect) { + EXPECT_DEBUG_DEATH({ + DoOpenClosed(); + + EXPECT_FALSE(chrome_async_socket_->Connect(addr_)); + ExpectErrorState(ChromeAsyncSocket::STATE_OPEN, + ChromeAsyncSocket::ERROR_WRONGSTATE); + + DoCloseOpened( + SignalSocketState(SIGNAL_CLOSE, + ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WRONGSTATE, + net::OK)); + }, "non-closed socket"); +} + +TEST_F(ChromeAsyncSocketTest, ImmediateConnectCloseBeforeRead) { + DoOpenClosed(); + + EXPECT_TRUE(chrome_async_socket_->Close()); + ExpectClosed(); + ExpectSignalSocketState( + SignalSocketState::NoError( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED)); + + message_loop_.RunAllPending(); +} + +TEST_F(ChromeAsyncSocketTest, HangingConnect) { + EXPECT_TRUE(chrome_async_socket_->Connect(addr_)); + ExpectNonErrorState(ChromeAsyncSocket::STATE_CONNECTING); + ExpectNoSignal(); + + EXPECT_TRUE(chrome_async_socket_->Close()); + ExpectClosed(); + ExpectSignalSocketState( + SignalSocketState::NoError( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED)); +} + +TEST_F(ChromeAsyncSocketTest, PendingConnect) { + async_socket_data_provider_.set_connect_data( + net::MockConnect(net::ASYNC, net::OK)); + EXPECT_TRUE(chrome_async_socket_->Connect(addr_)); + ExpectNonErrorState(ChromeAsyncSocket::STATE_CONNECTING); + ExpectNoSignal(); + + message_loop_.RunAllPending(); + ExpectNonErrorState(ChromeAsyncSocket::STATE_OPEN); + ExpectSignalSocketState( + SignalSocketState::NoError( + SIGNAL_CONNECT, ChromeAsyncSocket::STATE_OPEN)); + ExpectNoSignal(); + + message_loop_.RunAllPending(); + + DoCloseOpenedNoError(); +} + +// After this no need to test successful pending connect so +// thoroughly. + +TEST_F(ChromeAsyncSocketTest, PendingConnectCloseBeforeRead) { + async_socket_data_provider_.set_connect_data( + net::MockConnect(net::ASYNC, net::OK)); + EXPECT_TRUE(chrome_async_socket_->Connect(addr_)); + + message_loop_.RunAllPending(); + ExpectSignalSocketState( + SignalSocketState::NoError( + SIGNAL_CONNECT, ChromeAsyncSocket::STATE_OPEN)); + + DoCloseOpenedNoError(); + + message_loop_.RunAllPending(); +} + +TEST_F(ChromeAsyncSocketTest, PendingConnectError) { + async_socket_data_provider_.set_connect_data( + net::MockConnect(net::ASYNC, net::ERR_TIMED_OUT)); + EXPECT_TRUE(chrome_async_socket_->Connect(addr_)); + + message_loop_.RunAllPending(); + + ExpectSignalSocketState( + SignalSocketState( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WINSOCK, net::ERR_TIMED_OUT)); +} + +// After this we can assume Connect() and Close() work as expected. + +TEST_F(ChromeAsyncSocketTest, EmptyRead) { + DoOpenClosed(); + + char buf[4096]; + size_t len_read = 10000U; + EXPECT_TRUE(chrome_async_socket_->Read(buf, sizeof(buf), &len_read)); + EXPECT_EQ(0U, len_read); + + DoCloseOpenedNoError(); +} + +TEST_F(ChromeAsyncSocketTest, WrongRead) { + EXPECT_DEBUG_DEATH({ + async_socket_data_provider_.set_connect_data( + net::MockConnect(net::ASYNC, net::OK)); + EXPECT_TRUE(chrome_async_socket_->Connect(addr_)); + ExpectNonErrorState(ChromeAsyncSocket::STATE_CONNECTING); + ExpectNoSignal(); + + char buf[4096]; + size_t len_read; + EXPECT_FALSE(chrome_async_socket_->Read(buf, sizeof(buf), &len_read)); + ExpectErrorState(ChromeAsyncSocket::STATE_CONNECTING, + ChromeAsyncSocket::ERROR_WRONGSTATE); + EXPECT_TRUE(chrome_async_socket_->Close()); + ExpectSignalSocketState( + SignalSocketState( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WRONGSTATE, net::OK)); + }, "non-open"); +} + +TEST_F(ChromeAsyncSocketTest, WrongReadClosed) { + char buf[4096]; + size_t len_read; + EXPECT_FALSE(chrome_async_socket_->Read(buf, sizeof(buf), &len_read)); + ExpectErrorState(ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WRONGSTATE); + EXPECT_TRUE(chrome_async_socket_->Close()); +} + +const char kReadData[] = "mydatatoread"; + +TEST_F(ChromeAsyncSocketTest, Read) { + async_socket_data_provider_.AddRead(net::MockRead(kReadData)); + DoOpenClosed(); + + ExpectReadSignal(); + ExpectNoSignal(); + + EXPECT_EQ(kReadData, DrainRead(1)); + + message_loop_.RunAllPending(); + + DoCloseOpenedNoError(); +} + +TEST_F(ChromeAsyncSocketTest, ReadTwice) { + async_socket_data_provider_.AddRead(net::MockRead(kReadData)); + DoOpenClosed(); + + ExpectReadSignal(); + ExpectNoSignal(); + + EXPECT_EQ(kReadData, DrainRead(1)); + + message_loop_.RunAllPending(); + + const char kReadData2[] = "mydatatoread2"; + async_socket_data_provider_.AddRead(net::MockRead(kReadData2)); + + ExpectReadSignal(); + ExpectNoSignal(); + + EXPECT_EQ(kReadData2, DrainRead(1)); + + DoCloseOpenedNoError(); +} + +TEST_F(ChromeAsyncSocketTest, ReadError) { + async_socket_data_provider_.AddRead(net::MockRead(kReadData)); + DoOpenClosed(); + + ExpectReadSignal(); + ExpectNoSignal(); + + EXPECT_EQ(kReadData, DrainRead(1)); + + message_loop_.RunAllPending(); + + async_socket_data_provider_.AddRead( + net::MockRead(net::SYNCHRONOUS, net::ERR_TIMED_OUT)); + + ExpectSignalSocketState( + SignalSocketState( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WINSOCK, net::ERR_TIMED_OUT)); +} + +TEST_F(ChromeAsyncSocketTest, ReadEmpty) { + async_socket_data_provider_.AddRead(net::MockRead("")); + DoOpenClosed(); + + ExpectSignalSocketState( + SignalSocketState::NoError( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED)); +} + +TEST_F(ChromeAsyncSocketTest, PendingRead) { + DoOpenClosed(); + + ExpectNoSignal(); + + async_socket_data_provider_.AddRead(net::MockRead(kReadData)); + + ExpectSignalSocketState( + SignalSocketState::NoError( + SIGNAL_READ, ChromeAsyncSocket::STATE_OPEN)); + ExpectNoSignal(); + + EXPECT_EQ(kReadData, DrainRead(1)); + + message_loop_.RunAllPending(); + + DoCloseOpenedNoError(); +} + +TEST_F(ChromeAsyncSocketTest, PendingEmptyRead) { + DoOpenClosed(); + + ExpectNoSignal(); + + async_socket_data_provider_.AddRead(net::MockRead("")); + + ExpectSignalSocketState( + SignalSocketState::NoError( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED)); +} + +TEST_F(ChromeAsyncSocketTest, PendingReadError) { + DoOpenClosed(); + + ExpectNoSignal(); + + async_socket_data_provider_.AddRead( + net::MockRead(net::ASYNC, net::ERR_TIMED_OUT)); + + ExpectSignalSocketState( + SignalSocketState( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WINSOCK, net::ERR_TIMED_OUT)); +} + +// After this we can assume non-SSL Read() works as expected. + +TEST_F(ChromeAsyncSocketTest, WrongWrite) { + EXPECT_DEBUG_DEATH({ + std::string data("foo"); + EXPECT_FALSE(chrome_async_socket_->Write(data.data(), data.size())); + ExpectErrorState(ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WRONGSTATE); + EXPECT_TRUE(chrome_async_socket_->Close()); + }, "non-open"); +} + +const char kWriteData[] = "mydatatowrite"; + +TEST_F(ChromeAsyncSocketTest, SyncWrite) { + async_socket_data_provider_.AddWrite( + net::MockWrite(net::SYNCHRONOUS, kWriteData, 3)); + async_socket_data_provider_.AddWrite( + net::MockWrite(net::SYNCHRONOUS, kWriteData + 3, 5)); + async_socket_data_provider_.AddWrite( + net::MockWrite(net::SYNCHRONOUS, + kWriteData + 8, arraysize(kWriteData) - 8)); + DoOpenClosed(); + + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData, 3)); + message_loop_.RunAllPending(); + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData + 3, 5)); + message_loop_.RunAllPending(); + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData + 8, + arraysize(kWriteData) - 8)); + message_loop_.RunAllPending(); + + ExpectNoSignal(); + + DoCloseOpenedNoError(); +} + +TEST_F(ChromeAsyncSocketTest, AsyncWrite) { + DoOpenClosed(); + + async_socket_data_provider_.AddWrite( + net::MockWrite(net::ASYNC, kWriteData, 3)); + async_socket_data_provider_.AddWrite( + net::MockWrite(net::ASYNC, kWriteData + 3, 5)); + async_socket_data_provider_.AddWrite( + net::MockWrite(net::ASYNC, kWriteData + 8, arraysize(kWriteData) - 8)); + + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData, 3)); + message_loop_.RunAllPending(); + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData + 3, 5)); + message_loop_.RunAllPending(); + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData + 8, + arraysize(kWriteData) - 8)); + message_loop_.RunAllPending(); + + ExpectNoSignal(); + + DoCloseOpenedNoError(); +} + +TEST_F(ChromeAsyncSocketTest, AsyncWriteError) { + DoOpenClosed(); + + async_socket_data_provider_.AddWrite( + net::MockWrite(net::ASYNC, kWriteData, 3)); + async_socket_data_provider_.AddWrite( + net::MockWrite(net::ASYNC, kWriteData + 3, 5)); + async_socket_data_provider_.AddWrite( + net::MockWrite(net::ASYNC, net::ERR_TIMED_OUT)); + + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData, 3)); + message_loop_.RunAllPending(); + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData + 3, 5)); + message_loop_.RunAllPending(); + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData + 8, + arraysize(kWriteData) - 8)); + message_loop_.RunAllPending(); + + ExpectSignalSocketState( + SignalSocketState( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WINSOCK, net::ERR_TIMED_OUT)); +} + +TEST_F(ChromeAsyncSocketTest, LargeWrite) { + EXPECT_DEBUG_DEATH({ + DoOpenClosed(); + + std::string large_data(100, 'x'); + EXPECT_FALSE(chrome_async_socket_->Write(large_data.data(), + large_data.size())); + ExpectState(ChromeAsyncSocket::STATE_OPEN, + ChromeAsyncSocket::ERROR_WINSOCK, + net::ERR_INSUFFICIENT_RESOURCES); + DoCloseOpened( + SignalSocketState( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WINSOCK, + net::ERR_INSUFFICIENT_RESOURCES)); + }, "exceed the max write buffer"); +} + +TEST_F(ChromeAsyncSocketTest, LargeAccumulatedWrite) { + EXPECT_DEBUG_DEATH({ + DoOpenClosed(); + + std::string data(15, 'x'); + EXPECT_TRUE(chrome_async_socket_->Write(data.data(), data.size())); + EXPECT_FALSE(chrome_async_socket_->Write(data.data(), data.size())); + ExpectState(ChromeAsyncSocket::STATE_OPEN, + ChromeAsyncSocket::ERROR_WINSOCK, + net::ERR_INSUFFICIENT_RESOURCES); + DoCloseOpened( + SignalSocketState( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WINSOCK, + net::ERR_INSUFFICIENT_RESOURCES)); + }, "exceed the max write buffer"); +} + +// After this we can assume non-SSL I/O works as expected. + +TEST_F(ChromeAsyncSocketTest, HangingSSLConnect) { + async_socket_data_provider_.AddRead(net::MockRead(kReadData)); + DoOpenClosed(); + ExpectReadSignal(); + + EXPECT_TRUE(chrome_async_socket_->StartTls("fakedomain.com")); + ExpectNoSignal(); + + ExpectNonErrorState(ChromeAsyncSocket::STATE_TLS_CONNECTING); + EXPECT_TRUE(chrome_async_socket_->Close()); + ExpectSignalSocketState( + SignalSocketState::NoError(SIGNAL_CLOSE, + ChromeAsyncSocket::STATE_CLOSED)); + ExpectNonErrorState(ChromeAsyncSocket::STATE_CLOSED); +} + +TEST_F(ChromeAsyncSocketTest, ImmediateSSLConnect) { + async_socket_data_provider_.AddRead(net::MockRead(kReadData)); + DoOpenClosed(); + ExpectReadSignal(); + + EXPECT_TRUE(chrome_async_socket_->StartTls("fakedomain.com")); + message_loop_.RunAllPending(); + ExpectSSLConnectSignal(); + ExpectNoSignal(); + ExpectNonErrorState(ChromeAsyncSocket::STATE_TLS_OPEN); + + DoSSLCloseOpenedNoError(); +} + +TEST_F(ChromeAsyncSocketTest, DoubleSSLConnect) { + EXPECT_DEBUG_DEATH({ + async_socket_data_provider_.AddRead(net::MockRead(kReadData)); + DoOpenClosed(); + ExpectReadSignal(); + + EXPECT_TRUE(chrome_async_socket_->StartTls("fakedomain.com")); + message_loop_.RunAllPending(); + ExpectSSLConnectSignal(); + ExpectNoSignal(); + ExpectNonErrorState(ChromeAsyncSocket::STATE_TLS_OPEN); + + EXPECT_FALSE(chrome_async_socket_->StartTls("fakedomain.com")); + + DoSSLCloseOpened( + SignalSocketState(SIGNAL_CLOSE, + ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WRONGSTATE, + net::OK)); + }, "wrong state"); +} + +TEST_F(ChromeAsyncSocketTest, FailedSSLConnect) { + ssl_socket_data_provider_.connect = + net::MockConnect(net::ASYNC, net::ERR_CERT_COMMON_NAME_INVALID), + + async_socket_data_provider_.AddRead(net::MockRead(kReadData)); + DoOpenClosed(); + ExpectReadSignal(); + + EXPECT_TRUE(chrome_async_socket_->StartTls("fakedomain.com")); + message_loop_.RunAllPending(); + ExpectSignalSocketState( + SignalSocketState( + SIGNAL_CLOSE, ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WINSOCK, + net::ERR_CERT_COMMON_NAME_INVALID)); + + EXPECT_TRUE(chrome_async_socket_->Close()); + ExpectClosed(); +} + +TEST_F(ChromeAsyncSocketTest, ReadDuringSSLConnecting) { + async_socket_data_provider_.AddRead(net::MockRead(kReadData)); + DoOpenClosed(); + ExpectReadSignal(); + EXPECT_EQ(kReadData, DrainRead(1)); + + EXPECT_TRUE(chrome_async_socket_->StartTls("fakedomain.com")); + ExpectNoSignal(); + + // Shouldn't do anything. + async_socket_data_provider_.AddRead(net::MockRead(kReadData)); + + char buf[4096]; + size_t len_read = 10000U; + EXPECT_TRUE(chrome_async_socket_->Read(buf, sizeof(buf), &len_read)); + EXPECT_EQ(0U, len_read); + + message_loop_.RunAllPending(); + ExpectSSLConnectSignal(); + ExpectSSLReadSignal(); + ExpectNoSignal(); + ExpectNonErrorState(ChromeAsyncSocket::STATE_TLS_OPEN); + + len_read = 10000U; + EXPECT_TRUE(chrome_async_socket_->Read(buf, sizeof(buf), &len_read)); + EXPECT_EQ(kReadData, std::string(buf, len_read)); + + DoSSLCloseOpenedNoError(); +} + +TEST_F(ChromeAsyncSocketTest, WriteDuringSSLConnecting) { + async_socket_data_provider_.AddRead(net::MockRead(kReadData)); + DoOpenClosed(); + ExpectReadSignal(); + + EXPECT_TRUE(chrome_async_socket_->StartTls("fakedomain.com")); + ExpectNoSignal(); + ExpectNonErrorState(ChromeAsyncSocket::STATE_TLS_CONNECTING); + + async_socket_data_provider_.AddWrite( + net::MockWrite(net::ASYNC, kWriteData, 3)); + + // Shouldn't do anything. + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData, 3)); + + // TODO(akalin): Figure out how to test that the write happens + // *after* the SSL connect. + + message_loop_.RunAllPending(); + ExpectSSLConnectSignal(); + ExpectNoSignal(); + + message_loop_.RunAllPending(); + + DoSSLCloseOpenedNoError(); +} + +TEST_F(ChromeAsyncSocketTest, SSLConnectDuringPendingRead) { + EXPECT_DEBUG_DEATH({ + DoOpenClosed(); + + EXPECT_FALSE(chrome_async_socket_->StartTls("fakedomain.com")); + + DoCloseOpened( + SignalSocketState(SIGNAL_CLOSE, + ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WRONGSTATE, + net::OK)); + }, "wrong state"); +} + +TEST_F(ChromeAsyncSocketTest, SSLConnectDuringPostedWrite) { + EXPECT_DEBUG_DEATH({ + DoOpenClosed(); + + async_socket_data_provider_.AddWrite( + net::MockWrite(net::ASYNC, kWriteData, 3)); + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData, 3)); + + EXPECT_FALSE(chrome_async_socket_->StartTls("fakedomain.com")); + + message_loop_.RunAllPending(); + + DoCloseOpened( + SignalSocketState(SIGNAL_CLOSE, + ChromeAsyncSocket::STATE_CLOSED, + ChromeAsyncSocket::ERROR_WRONGSTATE, + net::OK)); + }, "wrong state"); +} + +// After this we can assume SSL connect works as expected. + +TEST_F(ChromeAsyncSocketTest, SSLRead) { + DoSSLOpenClosed(); + async_socket_data_provider_.AddRead(net::MockRead(kReadData)); + message_loop_.RunAllPending(); + + ExpectSSLReadSignal(); + ExpectNoSignal(); + + EXPECT_EQ(kReadData, DrainRead(1)); + + message_loop_.RunAllPending(); + + DoSSLCloseOpenedNoError(); +} + +TEST_F(ChromeAsyncSocketTest, SSLSyncWrite) { + async_socket_data_provider_.AddWrite( + net::MockWrite(net::SYNCHRONOUS, kWriteData, 3)); + async_socket_data_provider_.AddWrite( + net::MockWrite(net::SYNCHRONOUS, kWriteData + 3, 5)); + async_socket_data_provider_.AddWrite( + net::MockWrite(net::SYNCHRONOUS, + kWriteData + 8, arraysize(kWriteData) - 8)); + DoSSLOpenClosed(); + + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData, 3)); + message_loop_.RunAllPending(); + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData + 3, 5)); + message_loop_.RunAllPending(); + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData + 8, + arraysize(kWriteData) - 8)); + message_loop_.RunAllPending(); + + ExpectNoSignal(); + + DoSSLCloseOpenedNoError(); +} + +TEST_F(ChromeAsyncSocketTest, SSLAsyncWrite) { + DoSSLOpenClosed(); + + async_socket_data_provider_.AddWrite( + net::MockWrite(net::ASYNC, kWriteData, 3)); + async_socket_data_provider_.AddWrite( + net::MockWrite(net::ASYNC, kWriteData + 3, 5)); + async_socket_data_provider_.AddWrite( + net::MockWrite(net::ASYNC, kWriteData + 8, arraysize(kWriteData) - 8)); + + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData, 3)); + message_loop_.RunAllPending(); + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData + 3, 5)); + message_loop_.RunAllPending(); + EXPECT_TRUE(chrome_async_socket_->Write(kWriteData + 8, + arraysize(kWriteData) - 8)); + message_loop_.RunAllPending(); + + ExpectNoSignal(); + + DoSSLCloseOpenedNoError(); +} + +} // namespace + +} // namespace jingle_glue diff --git a/jingle/glue/fake_ssl_client_socket.cc b/jingle/glue/fake_ssl_client_socket.cc new file mode 100644 index 0000000..19beb1f --- /dev/null +++ b/jingle/glue/fake_ssl_client_socket.cc @@ -0,0 +1,353 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "jingle/glue/fake_ssl_client_socket.h" + +#include <cstdlib> + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/logging.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" + +namespace jingle_glue { + +namespace { + +// The constants below were taken from libjingle's socketadapters.cc. +// Basically, we do a "fake" SSL handshake to fool proxies into +// thinking this is a real SSL connection. + +// This is a SSL v2 CLIENT_HELLO message. +// TODO(juberti): Should this have a session id? The response doesn't have a +// certificate, so the hello should have a session id. +static const uint8 kSslClientHello[] = { + 0x80, 0x46, // msg len + 0x01, // CLIENT_HELLO + 0x03, 0x01, // SSL 3.1 + 0x00, 0x2d, // ciphersuite len + 0x00, 0x00, // session id len + 0x00, 0x10, // challenge len + 0x01, 0x00, 0x80, 0x03, 0x00, 0x80, 0x07, 0x00, 0xc0, // ciphersuites + 0x06, 0x00, 0x40, 0x02, 0x00, 0x80, 0x04, 0x00, 0x80, // + 0x00, 0x00, 0x04, 0x00, 0xfe, 0xff, 0x00, 0x00, 0x0a, // + 0x00, 0xfe, 0xfe, 0x00, 0x00, 0x09, 0x00, 0x00, 0x64, // + 0x00, 0x00, 0x62, 0x00, 0x00, 0x03, 0x00, 0x00, 0x06, // + 0x1f, 0x17, 0x0c, 0xa6, 0x2f, 0x00, 0x78, 0xfc, // challenge + 0x46, 0x55, 0x2e, 0xb1, 0x83, 0x39, 0xf1, 0xea // +}; + +// This is a TLSv1 SERVER_HELLO message. +static const uint8 kSslServerHello[] = { + 0x16, // handshake message + 0x03, 0x01, // SSL 3.1 + 0x00, 0x4a, // message len + 0x02, // SERVER_HELLO + 0x00, 0x00, 0x46, // handshake len + 0x03, 0x01, // SSL 3.1 + 0x42, 0x85, 0x45, 0xa7, 0x27, 0xa9, 0x5d, 0xa0, // server random + 0xb3, 0xc5, 0xe7, 0x53, 0xda, 0x48, 0x2b, 0x3f, // + 0xc6, 0x5a, 0xca, 0x89, 0xc1, 0x58, 0x52, 0xa1, // + 0x78, 0x3c, 0x5b, 0x17, 0x46, 0x00, 0x85, 0x3f, // + 0x20, // session id len + 0x0e, 0xd3, 0x06, 0x72, 0x5b, 0x5b, 0x1b, 0x5f, // session id + 0x15, 0xac, 0x13, 0xf9, 0x88, 0x53, 0x9d, 0x9b, // + 0xe8, 0x3d, 0x7b, 0x0c, 0x30, 0x32, 0x6e, 0x38, // + 0x4d, 0xa2, 0x75, 0x57, 0x41, 0x6c, 0x34, 0x5c, // + 0x00, 0x04, // RSA/RC4-128/MD5 + 0x00 // null compression +}; + +net::DrainableIOBuffer* NewDrainableIOBufferWithSize(int size) { + return new net::DrainableIOBuffer(new net::IOBuffer(size), size); +} + +} // namespace + +base::StringPiece FakeSSLClientSocket::GetSslClientHello() { + return base::StringPiece(reinterpret_cast<const char*>(kSslClientHello), + arraysize(kSslClientHello)); +} + +base::StringPiece FakeSSLClientSocket::GetSslServerHello() { + return base::StringPiece(reinterpret_cast<const char*>(kSslServerHello), + arraysize(kSslServerHello)); +} + +FakeSSLClientSocket::FakeSSLClientSocket( + net::StreamSocket* transport_socket) + : transport_socket_(transport_socket), + next_handshake_state_(STATE_NONE), + handshake_completed_(false), + write_buf_(NewDrainableIOBufferWithSize(arraysize(kSslClientHello))), + read_buf_(NewDrainableIOBufferWithSize(arraysize(kSslServerHello))) { + CHECK(transport_socket_.get()); + std::memcpy(write_buf_->data(), kSslClientHello, arraysize(kSslClientHello)); +} + +FakeSSLClientSocket::~FakeSSLClientSocket() {} + +int FakeSSLClientSocket::Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + DCHECK_EQ(next_handshake_state_, STATE_NONE); + DCHECK(handshake_completed_); + return transport_socket_->Read(buf, buf_len, callback); +} + +int FakeSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + DCHECK_EQ(next_handshake_state_, STATE_NONE); + DCHECK(handshake_completed_); + return transport_socket_->Write(buf, buf_len, callback); +} + +bool FakeSSLClientSocket::SetReceiveBufferSize(int32 size) { + return transport_socket_->SetReceiveBufferSize(size); +} + +bool FakeSSLClientSocket::SetSendBufferSize(int32 size) { + return transport_socket_->SetSendBufferSize(size); +} + +int FakeSSLClientSocket::Connect(const net::CompletionCallback& callback) { + // We don't support synchronous operation, even if + // |transport_socket_| does. + DCHECK(!callback.is_null()); + DCHECK_EQ(next_handshake_state_, STATE_NONE); + DCHECK(!handshake_completed_); + DCHECK(user_connect_callback_.is_null()); + DCHECK_EQ(write_buf_->BytesConsumed(), 0); + DCHECK_EQ(read_buf_->BytesConsumed(), 0); + + next_handshake_state_ = STATE_CONNECT; + int status = DoHandshakeLoop(); + if (status == net::ERR_IO_PENDING) + user_connect_callback_ = callback; + + return status; +} + +int FakeSSLClientSocket::DoHandshakeLoop() { + DCHECK_NE(next_handshake_state_, STATE_NONE); + int status = net::OK; + do { + HandshakeState state = next_handshake_state_; + next_handshake_state_ = STATE_NONE; + switch (state) { + case STATE_CONNECT: + status = DoConnect(); + break; + case STATE_SEND_CLIENT_HELLO: + status = DoSendClientHello(); + break; + case STATE_VERIFY_SERVER_HELLO: + status = DoVerifyServerHello(); + break; + default: + status = net::ERR_UNEXPECTED; + LOG(DFATAL) << "unexpected state: " << state; + break; + } + } while ((status != net::ERR_IO_PENDING) && + (next_handshake_state_ != STATE_NONE)); + return status; +} + +void FakeSSLClientSocket::RunUserConnectCallback(int status) { + DCHECK_LE(status, net::OK); + next_handshake_state_ = STATE_NONE; + net::CompletionCallback user_connect_callback = user_connect_callback_; + user_connect_callback_.Reset(); + user_connect_callback.Run(status); +} + +void FakeSSLClientSocket::DoHandshakeLoopWithUserConnectCallback() { + int status = DoHandshakeLoop(); + if (status != net::ERR_IO_PENDING) { + RunUserConnectCallback(status); + } +} + +int FakeSSLClientSocket::DoConnect() { + int status = transport_socket_->Connect( + base::Bind(&FakeSSLClientSocket::OnConnectDone, base::Unretained(this))); + if (status != net::OK) { + return status; + } + ProcessConnectDone(); + return net::OK; +} + +void FakeSSLClientSocket::OnConnectDone(int status) { + DCHECK_NE(status, net::ERR_IO_PENDING); + DCHECK_LE(status, net::OK); + DCHECK(!user_connect_callback_.is_null()); + if (status != net::OK) { + RunUserConnectCallback(status); + return; + } + ProcessConnectDone(); + DoHandshakeLoopWithUserConnectCallback(); +} + +void FakeSSLClientSocket::ProcessConnectDone() { + DCHECK_EQ(write_buf_->BytesConsumed(), 0); + DCHECK_EQ(read_buf_->BytesConsumed(), 0); + next_handshake_state_ = STATE_SEND_CLIENT_HELLO; +} + +int FakeSSLClientSocket::DoSendClientHello() { + int status = transport_socket_->Write( + write_buf_, write_buf_->BytesRemaining(), + base::Bind(&FakeSSLClientSocket::OnSendClientHelloDone, + base::Unretained(this))); + if (status < net::OK) { + return status; + } + ProcessSendClientHelloDone(static_cast<size_t>(status)); + return net::OK; +} + +void FakeSSLClientSocket::OnSendClientHelloDone(int status) { + DCHECK_NE(status, net::ERR_IO_PENDING); + DCHECK(!user_connect_callback_.is_null()); + if (status < net::OK) { + RunUserConnectCallback(status); + return; + } + ProcessSendClientHelloDone(static_cast<size_t>(status)); + DoHandshakeLoopWithUserConnectCallback(); +} + +void FakeSSLClientSocket::ProcessSendClientHelloDone(size_t written) { + DCHECK_LE(written, static_cast<size_t>(write_buf_->BytesRemaining())); + DCHECK_EQ(read_buf_->BytesConsumed(), 0); + if (written < static_cast<size_t>(write_buf_->BytesRemaining())) { + next_handshake_state_ = STATE_SEND_CLIENT_HELLO; + write_buf_->DidConsume(written); + } else { + next_handshake_state_ = STATE_VERIFY_SERVER_HELLO; + } +} + +int FakeSSLClientSocket::DoVerifyServerHello() { + int status = transport_socket_->Read( + read_buf_, read_buf_->BytesRemaining(), + base::Bind(&FakeSSLClientSocket::OnVerifyServerHelloDone, + base::Unretained(this))); + if (status < net::OK) { + return status; + } + size_t read = static_cast<size_t>(status); + return ProcessVerifyServerHelloDone(read); +} + +void FakeSSLClientSocket::OnVerifyServerHelloDone(int status) { + DCHECK_NE(status, net::ERR_IO_PENDING); + DCHECK(!user_connect_callback_.is_null()); + if (status < net::OK) { + RunUserConnectCallback(status); + return; + } + size_t read = static_cast<size_t>(status); + status = ProcessVerifyServerHelloDone(read); + if (status < net::OK) { + RunUserConnectCallback(status); + return; + } + if (handshake_completed_) { + RunUserConnectCallback(net::OK); + } else { + DoHandshakeLoopWithUserConnectCallback(); + } +} + +net::Error FakeSSLClientSocket::ProcessVerifyServerHelloDone(size_t read) { + DCHECK_LE(read, static_cast<size_t>(read_buf_->BytesRemaining())); + if (read == 0U) { + return net::ERR_UNEXPECTED; + } + const uint8* expected_data_start = + &kSslServerHello[arraysize(kSslServerHello) - + read_buf_->BytesRemaining()]; + if (std::memcmp(expected_data_start, read_buf_->data(), read) != 0) { + return net::ERR_UNEXPECTED; + } + if (read < static_cast<size_t>(read_buf_->BytesRemaining())) { + next_handshake_state_ = STATE_VERIFY_SERVER_HELLO; + read_buf_->DidConsume(read); + } else { + next_handshake_state_ = STATE_NONE; + handshake_completed_ = true; + } + return net::OK; +} + +void FakeSSLClientSocket::Disconnect() { + transport_socket_->Disconnect(); + next_handshake_state_ = STATE_NONE; + handshake_completed_ = false; + user_connect_callback_.Reset(); + write_buf_->SetOffset(0); + read_buf_->SetOffset(0); +} + +bool FakeSSLClientSocket::IsConnected() const { + return handshake_completed_ && transport_socket_->IsConnected(); +} + +bool FakeSSLClientSocket::IsConnectedAndIdle() const { + return handshake_completed_ && transport_socket_->IsConnectedAndIdle(); +} + +int FakeSSLClientSocket::GetPeerAddress(net::IPEndPoint* address) const { + return transport_socket_->GetPeerAddress(address); +} + +int FakeSSLClientSocket::GetLocalAddress(net::IPEndPoint* address) const { + return transport_socket_->GetLocalAddress(address); +} + +const net::BoundNetLog& FakeSSLClientSocket::NetLog() const { + return transport_socket_->NetLog(); +} + +void FakeSSLClientSocket::SetSubresourceSpeculation() { + transport_socket_->SetSubresourceSpeculation(); +} + +void FakeSSLClientSocket::SetOmniboxSpeculation() { + transport_socket_->SetOmniboxSpeculation(); +} + +bool FakeSSLClientSocket::WasEverUsed() const { + return transport_socket_->WasEverUsed(); +} + +bool FakeSSLClientSocket::UsingTCPFastOpen() const { + return transport_socket_->UsingTCPFastOpen(); +} + +int64 FakeSSLClientSocket::NumBytesRead() const { + return transport_socket_->NumBytesRead(); +} + +base::TimeDelta FakeSSLClientSocket::GetConnectTimeMicros() const { + return transport_socket_->GetConnectTimeMicros(); +} + +bool FakeSSLClientSocket::WasNpnNegotiated() const { + return transport_socket_->WasNpnNegotiated(); +} + +net::NextProto FakeSSLClientSocket::GetNegotiatedProtocol() const { + return transport_socket_->GetNegotiatedProtocol(); +} + +bool FakeSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { + return transport_socket_->GetSSLInfo(ssl_info); +} + +} // namespace jingle_glue diff --git a/jingle/glue/fake_ssl_client_socket.h b/jingle/glue/fake_ssl_client_socket.h new file mode 100644 index 0000000..edd3267 --- /dev/null +++ b/jingle/glue/fake_ssl_client_socket.h @@ -0,0 +1,114 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// This StreamSocket implementation is to be used with servers that +// accept connections on port 443 but don't really use SSL. For +// example, the Google Talk servers do this to bypass proxies. (The +// connection is upgraded to TLS as part of the XMPP negotiation, so +// security is preserved.) A "fake" SSL handshake is done immediately +// after connection to fool proxies into thinking that this is a real +// SSL connection. +// +// NOTE: This StreamSocket implementation does *not* do a real SSL +// handshake nor does it do any encryption! + +#ifndef JINGLE_GLUE_FAKE_SSL_CLIENT_SOCKET_H_ +#define JINGLE_GLUE_FAKE_SSL_CLIENT_SOCKET_H_ + +#include <cstddef> + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/string_piece.h" +#include "net/base/completion_callback.h" +#include "net/base/net_errors.h" +#include "net/socket/stream_socket.h" + +namespace net { +class DrainableIOBuffer; +class SSLInfo; +} // namespace net + +namespace jingle_glue { + +class FakeSSLClientSocket : public net::StreamSocket { + public: + // Takes ownership of |transport_socket|. + explicit FakeSSLClientSocket(net::StreamSocket* transport_socket); + + virtual ~FakeSSLClientSocket(); + + // Exposed for testing. + static base::StringPiece GetSslClientHello(); + static base::StringPiece GetSslServerHello(); + + // net::StreamSocket implementation. + virtual int Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE; + virtual const net::BoundNetLog& NetLog() const OVERRIDE; + virtual void SetSubresourceSpeculation() OVERRIDE; + virtual void SetOmniboxSpeculation() OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual int64 NumBytesRead() const OVERRIDE; + virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE; + + private: + enum HandshakeState { + STATE_NONE, + STATE_CONNECT, + STATE_SEND_CLIENT_HELLO, + STATE_VERIFY_SERVER_HELLO, + }; + + int DoHandshakeLoop(); + void RunUserConnectCallback(int status); + void DoHandshakeLoopWithUserConnectCallback(); + + int DoConnect(); + void OnConnectDone(int status); + void ProcessConnectDone(); + + int DoSendClientHello(); + void OnSendClientHelloDone(int status); + void ProcessSendClientHelloDone(size_t written); + + int DoVerifyServerHello(); + void OnVerifyServerHelloDone(int status); + net::Error ProcessVerifyServerHelloDone(size_t read); + + scoped_ptr<net::StreamSocket> transport_socket_; + + // During the handshake process, holds a value from HandshakeState. + // STATE_NONE otherwise. + HandshakeState next_handshake_state_; + + // True iff we're connected and we've finished the handshake. + bool handshake_completed_; + + // The callback passed to Connect(). + net::CompletionCallback user_connect_callback_; + + scoped_refptr<net::DrainableIOBuffer> write_buf_; + scoped_refptr<net::DrainableIOBuffer> read_buf_; +}; + +} // namespace jingle_glue + +#endif // JINGLE_GLUE_FAKE_SSL_CLIENT_SOCKET_H_ diff --git a/jingle/glue/fake_ssl_client_socket_unittest.cc b/jingle/glue/fake_ssl_client_socket_unittest.cc new file mode 100644 index 0000000..a7ae4fa --- /dev/null +++ b/jingle/glue/fake_ssl_client_socket_unittest.cc @@ -0,0 +1,348 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "jingle/glue/fake_ssl_client_socket.h" + +#include <algorithm> +#include <vector> + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop.h" +#include "net/base/io_buffer.h" +#include "net/base/net_log.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/stream_socket.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace jingle_glue { + +namespace { + +using ::testing::Return; +using ::testing::ReturnRef; + +// Used by RunUnsuccessfulHandshakeTestHelper. Represents where in +// the handshake step an error should be inserted. +enum HandshakeErrorLocation { + CONNECT_ERROR, + SEND_CLIENT_HELLO_ERROR, + VERIFY_SERVER_HELLO_ERROR, +}; + +// Private error codes appended to the net::Error set. +enum { + // An error representing a server hello that has been corrupted in + // transit. + ERR_MALFORMED_SERVER_HELLO = -15000, +}; + +// Used by PassThroughMethods test. +class MockClientSocket : public net::StreamSocket { + public: + virtual ~MockClientSocket() {} + + MOCK_METHOD3(Read, int(net::IOBuffer*, int, + const net::CompletionCallback&)); + MOCK_METHOD3(Write, int(net::IOBuffer*, int, + const net::CompletionCallback&)); + MOCK_METHOD1(SetReceiveBufferSize, bool(int32)); + MOCK_METHOD1(SetSendBufferSize, bool(int32)); + MOCK_METHOD1(Connect, int(const net::CompletionCallback&)); + MOCK_METHOD0(Disconnect, void()); + MOCK_CONST_METHOD0(IsConnected, bool()); + MOCK_CONST_METHOD0(IsConnectedAndIdle, bool()); + MOCK_CONST_METHOD1(GetPeerAddress, int(net::IPEndPoint*)); + MOCK_CONST_METHOD1(GetLocalAddress, int(net::IPEndPoint*)); + MOCK_CONST_METHOD0(NetLog, const net::BoundNetLog&()); + MOCK_METHOD0(SetSubresourceSpeculation, void()); + MOCK_METHOD0(SetOmniboxSpeculation, void()); + MOCK_CONST_METHOD0(WasEverUsed, bool()); + MOCK_CONST_METHOD0(UsingTCPFastOpen, bool()); + MOCK_CONST_METHOD0(NumBytesRead, int64()); + MOCK_CONST_METHOD0(GetConnectTimeMicros, base::TimeDelta()); + MOCK_CONST_METHOD0(WasNpnNegotiated, bool()); + MOCK_CONST_METHOD0(GetNegotiatedProtocol, net::NextProto()); + MOCK_METHOD1(GetSSLInfo, bool(net::SSLInfo*)); +}; + +// Break up |data| into a bunch of chunked MockReads/Writes and push +// them onto |ops|. +void AddChunkedOps(base::StringPiece data, size_t chunk_size, net::IoMode mode, + std::vector<net::MockRead>* ops) { + DCHECK_GT(chunk_size, 0U); + size_t offset = 0; + while (offset < data.size()) { + size_t bounded_chunk_size = std::min(data.size() - offset, chunk_size); + // We take advantage of the fact that MockWrite is typedefed to + // MockRead. + ops->push_back(net::MockRead(mode, data.data() + offset, + bounded_chunk_size)); + offset += bounded_chunk_size; + } +} + +class FakeSSLClientSocketTest : public testing::Test { + protected: + FakeSSLClientSocketTest() {} + + virtual ~FakeSSLClientSocketTest() {} + + net::StreamSocket* MakeClientSocket() { + return mock_client_socket_factory_.CreateTransportClientSocket( + net::AddressList(), NULL, net::NetLog::Source()); + } + + void SetData(const net::MockConnect& mock_connect, + std::vector<net::MockRead>* reads, + std::vector<net::MockWrite>* writes) { + static_socket_data_provider_.reset( + new net::StaticSocketDataProvider( + reads->empty() ? NULL : &*reads->begin(), reads->size(), + writes->empty() ? NULL : &*writes->begin(), writes->size())); + static_socket_data_provider_->set_connect_data(mock_connect); + mock_client_socket_factory_.AddSocketDataProvider( + static_socket_data_provider_.get()); + } + + void ExpectStatus( + net::IoMode mode, int expected_status, int immediate_status, + net::TestCompletionCallback* test_completion_callback) { + if (mode == net::ASYNC) { + EXPECT_EQ(net::ERR_IO_PENDING, immediate_status); + int status = test_completion_callback->WaitForResult(); + EXPECT_EQ(expected_status, status); + } else { + EXPECT_EQ(expected_status, immediate_status); + } + } + + // Sets up the mock socket to generate a successful handshake + // (sliced up according to the parameters) and makes sure the + // FakeSSLClientSocket behaves as expected. + void RunSuccessfulHandshakeTest( + net::IoMode mode, size_t read_chunk_size, size_t write_chunk_size, + int num_resets) { + base::StringPiece ssl_client_hello = + FakeSSLClientSocket::GetSslClientHello(); + base::StringPiece ssl_server_hello = + FakeSSLClientSocket::GetSslServerHello(); + + net::MockConnect mock_connect(mode, net::OK); + std::vector<net::MockRead> reads; + std::vector<net::MockWrite> writes; + static const char kReadTestData[] = "read test data"; + static const char kWriteTestData[] = "write test data"; + for (int i = 0; i < num_resets + 1; ++i) { + SCOPED_TRACE(i); + AddChunkedOps(ssl_server_hello, read_chunk_size, mode, &reads); + AddChunkedOps(ssl_client_hello, write_chunk_size, mode, &writes); + reads.push_back( + net::MockRead(mode, kReadTestData, arraysize(kReadTestData))); + writes.push_back( + net::MockWrite(mode, kWriteTestData, arraysize(kWriteTestData))); + } + SetData(mock_connect, &reads, &writes); + + FakeSSLClientSocket fake_ssl_client_socket(MakeClientSocket()); + + for (int i = 0; i < num_resets + 1; ++i) { + SCOPED_TRACE(i); + net::TestCompletionCallback test_completion_callback; + int status = fake_ssl_client_socket.Connect( + test_completion_callback.callback()); + if (mode == net::ASYNC) { + EXPECT_FALSE(fake_ssl_client_socket.IsConnected()); + } + ExpectStatus(mode, net::OK, status, &test_completion_callback); + if (fake_ssl_client_socket.IsConnected()) { + int read_len = arraysize(kReadTestData); + int read_buf_len = 2 * read_len; + scoped_refptr<net::IOBuffer> read_buf( + new net::IOBuffer(read_buf_len)); + int read_status = fake_ssl_client_socket.Read( + read_buf, read_buf_len, test_completion_callback.callback()); + ExpectStatus(mode, read_len, read_status, &test_completion_callback); + + scoped_refptr<net::IOBuffer> write_buf( + new net::StringIOBuffer(kWriteTestData)); + int write_status = fake_ssl_client_socket.Write( + write_buf, arraysize(kWriteTestData), + test_completion_callback.callback()); + ExpectStatus(mode, arraysize(kWriteTestData), write_status, + &test_completion_callback); + } else { + ADD_FAILURE(); + } + fake_ssl_client_socket.Disconnect(); + EXPECT_FALSE(fake_ssl_client_socket.IsConnected()); + } + } + + // Sets up the mock socket to generate an unsuccessful handshake + // FakeSSLClientSocket fails as expected. + void RunUnsuccessfulHandshakeTestHelper( + net::IoMode mode, int error, HandshakeErrorLocation location) { + DCHECK_NE(error, net::OK); + base::StringPiece ssl_client_hello = + FakeSSLClientSocket::GetSslClientHello(); + base::StringPiece ssl_server_hello = + FakeSSLClientSocket::GetSslServerHello(); + + net::MockConnect mock_connect(mode, net::OK); + std::vector<net::MockRead> reads; + std::vector<net::MockWrite> writes; + const size_t kChunkSize = 1; + AddChunkedOps(ssl_server_hello, kChunkSize, mode, &reads); + AddChunkedOps(ssl_client_hello, kChunkSize, mode, &writes); + switch (location) { + case CONNECT_ERROR: + mock_connect.result = error; + writes.clear(); + reads.clear(); + break; + case SEND_CLIENT_HELLO_ERROR: { + // Use a fixed index for repeatability. + size_t index = 100 % writes.size(); + writes[index].result = error; + writes[index].data = NULL; + writes[index].data_len = 0; + writes.resize(index + 1); + reads.clear(); + break; + } + case VERIFY_SERVER_HELLO_ERROR: { + // Use a fixed index for repeatability. + size_t index = 50 % reads.size(); + if (error == ERR_MALFORMED_SERVER_HELLO) { + static const char kBadData[] = "BAD_DATA"; + reads[index].data = kBadData; + reads[index].data_len = arraysize(kBadData); + } else { + reads[index].result = error; + reads[index].data = NULL; + reads[index].data_len = 0; + } + reads.resize(index + 1); + if (error == + net::ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { + static const char kDummyData[] = "DUMMY"; + reads.push_back(net::MockRead(mode, kDummyData)); + } + break; + } + } + SetData(mock_connect, &reads, &writes); + + FakeSSLClientSocket fake_ssl_client_socket(MakeClientSocket()); + + // The two errors below are interpreted by FakeSSLClientSocket as + // an unexpected event. + int expected_status = + ((error == net::ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) || + (error == ERR_MALFORMED_SERVER_HELLO)) ? + net::ERR_UNEXPECTED : error; + + net::TestCompletionCallback test_completion_callback; + int status = fake_ssl_client_socket.Connect( + test_completion_callback.callback()); + EXPECT_FALSE(fake_ssl_client_socket.IsConnected()); + ExpectStatus(mode, expected_status, status, &test_completion_callback); + EXPECT_FALSE(fake_ssl_client_socket.IsConnected()); + } + + void RunUnsuccessfulHandshakeTest( + int error, HandshakeErrorLocation location) { + RunUnsuccessfulHandshakeTestHelper(net::SYNCHRONOUS, error, location); + RunUnsuccessfulHandshakeTestHelper(net::ASYNC, error, location); + } + + // MockTCPClientSocket needs a message loop. + MessageLoop message_loop_; + + net::MockClientSocketFactory mock_client_socket_factory_; + scoped_ptr<net::StaticSocketDataProvider> static_socket_data_provider_; +}; + +TEST_F(FakeSSLClientSocketTest, PassThroughMethods) { + MockClientSocket* mock_client_socket = new MockClientSocket(); + const int kReceiveBufferSize = 10; + const int kSendBufferSize = 20; + net::IPEndPoint ip_endpoint(net::IPAddressNumber(net::kIPv4AddressSize), 80); + const int kPeerAddress = 30; + net::BoundNetLog net_log; + EXPECT_CALL(*mock_client_socket, SetReceiveBufferSize(kReceiveBufferSize)); + EXPECT_CALL(*mock_client_socket, SetSendBufferSize(kSendBufferSize)); + EXPECT_CALL(*mock_client_socket, GetPeerAddress(&ip_endpoint)). + WillOnce(Return(kPeerAddress)); + EXPECT_CALL(*mock_client_socket, NetLog()).WillOnce(ReturnRef(net_log)); + EXPECT_CALL(*mock_client_socket, SetSubresourceSpeculation()); + EXPECT_CALL(*mock_client_socket, SetOmniboxSpeculation()); + + // Takes ownership of |mock_client_socket|. + FakeSSLClientSocket fake_ssl_client_socket(mock_client_socket); + fake_ssl_client_socket.SetReceiveBufferSize(kReceiveBufferSize); + fake_ssl_client_socket.SetSendBufferSize(kSendBufferSize); + EXPECT_EQ(kPeerAddress, + fake_ssl_client_socket.GetPeerAddress(&ip_endpoint)); + EXPECT_EQ(&net_log, &fake_ssl_client_socket.NetLog()); + fake_ssl_client_socket.SetSubresourceSpeculation(); + fake_ssl_client_socket.SetOmniboxSpeculation(); +} + +TEST_F(FakeSSLClientSocketTest, SuccessfulHandshakeSync) { + for (size_t i = 1; i < 100; i += 3) { + SCOPED_TRACE(i); + for (size_t j = 1; j < 100; j += 5) { + SCOPED_TRACE(j); + RunSuccessfulHandshakeTest(net::SYNCHRONOUS, i, j, 0); + } + } +} + +TEST_F(FakeSSLClientSocketTest, SuccessfulHandshakeAsync) { + for (size_t i = 1; i < 100; i += 7) { + SCOPED_TRACE(i); + for (size_t j = 1; j < 100; j += 9) { + SCOPED_TRACE(j); + RunSuccessfulHandshakeTest(net::ASYNC, i, j, 0); + } + } +} + +TEST_F(FakeSSLClientSocketTest, ResetSocket) { + RunSuccessfulHandshakeTest(net::ASYNC, 1, 2, 3); +} + +TEST_F(FakeSSLClientSocketTest, UnsuccessfulHandshakeConnectError) { + RunUnsuccessfulHandshakeTest(net::ERR_ACCESS_DENIED, CONNECT_ERROR); +} + +TEST_F(FakeSSLClientSocketTest, UnsuccessfulHandshakeWriteError) { + RunUnsuccessfulHandshakeTest(net::ERR_OUT_OF_MEMORY, + SEND_CLIENT_HELLO_ERROR); +} + +TEST_F(FakeSSLClientSocketTest, UnsuccessfulHandshakeReadError) { + RunUnsuccessfulHandshakeTest(net::ERR_CONNECTION_CLOSED, + VERIFY_SERVER_HELLO_ERROR); +} + +TEST_F(FakeSSLClientSocketTest, PeerClosedDuringHandshake) { + RunUnsuccessfulHandshakeTest( + net::ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ, + VERIFY_SERVER_HELLO_ERROR); +} + +TEST_F(FakeSSLClientSocketTest, MalformedServerHello) { + RunUnsuccessfulHandshakeTest(ERR_MALFORMED_SERVER_HELLO, + VERIFY_SERVER_HELLO_ERROR); +} + +} // namespace + +} // namespace jingle_glue diff --git a/jingle/glue/proxy_resolving_client_socket.cc b/jingle/glue/proxy_resolving_client_socket.cc new file mode 100644 index 0000000..479a781 --- /dev/null +++ b/jingle/glue/proxy_resolving_client_socket.cc @@ -0,0 +1,392 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "jingle/glue/proxy_resolving_client_socket.h" + +#include "base/basictypes.h" +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/compiler_specific.h" +#include "base/logging.h" +#include "googleurl/src/gurl.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/http/http_network_session.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/url_request/url_request_context.h" +#include "net/url_request/url_request_context_getter.h" + +namespace jingle_glue { + +ProxyResolvingClientSocket::ProxyResolvingClientSocket( + net::ClientSocketFactory* socket_factory, + const scoped_refptr<net::URLRequestContextGetter>& request_context_getter, + const net::SSLConfig& ssl_config, + const net::HostPortPair& dest_host_port_pair) + : ALLOW_THIS_IN_INITIALIZER_LIST(proxy_resolve_callback_( + base::Bind(&ProxyResolvingClientSocket::ProcessProxyResolveDone, + base::Unretained(this)))), + ALLOW_THIS_IN_INITIALIZER_LIST(connect_callback_( + base::Bind(&ProxyResolvingClientSocket::ProcessConnectDone, + base::Unretained(this)))), + ssl_config_(ssl_config), + pac_request_(NULL), + dest_host_port_pair_(dest_host_port_pair), + tried_direct_connect_fallback_(false), + bound_net_log_( + net::BoundNetLog::Make( + request_context_getter->GetURLRequestContext()->net_log(), + net::NetLog::SOURCE_SOCKET)), + ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { + DCHECK(request_context_getter); + net::URLRequestContext* request_context = + request_context_getter->GetURLRequestContext(); + DCHECK(request_context); + DCHECK(!dest_host_port_pair_.host().empty()); + DCHECK_GT(dest_host_port_pair_.port(), 0); + net::HttpNetworkSession::Params session_params; + session_params.client_socket_factory = socket_factory; + session_params.host_resolver = request_context->host_resolver(); + session_params.cert_verifier = request_context->cert_verifier(); + // TODO(rkn): This is NULL because ServerBoundCertService is not thread safe. + session_params.server_bound_cert_service = NULL; + // transport_security_state is NULL because it's not thread safe. + session_params.transport_security_state = NULL; + session_params.proxy_service = request_context->proxy_service(); + session_params.ssl_config_service = request_context->ssl_config_service(); + session_params.http_auth_handler_factory = + request_context->http_auth_handler_factory(); + session_params.network_delegate = request_context->network_delegate(); + session_params.http_server_properties = + request_context->http_server_properties(); + session_params.net_log = request_context->net_log(); + network_session_ = new net::HttpNetworkSession(session_params); +} + +ProxyResolvingClientSocket::~ProxyResolvingClientSocket() { + Disconnect(); +} + +int ProxyResolvingClientSocket::Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) { + if (transport_.get() && transport_->socket()) + return transport_->socket()->Read(buf, buf_len, callback); + NOTREACHED(); + return net::ERR_SOCKET_NOT_CONNECTED; +} + +int ProxyResolvingClientSocket::Write( + net::IOBuffer* buf, + int buf_len, + const net::CompletionCallback& callback) { + if (transport_.get() && transport_->socket()) + return transport_->socket()->Write(buf, buf_len, callback); + NOTREACHED(); + return net::ERR_SOCKET_NOT_CONNECTED; +} + +bool ProxyResolvingClientSocket::SetReceiveBufferSize(int32 size) { + if (transport_.get() && transport_->socket()) + return transport_->socket()->SetReceiveBufferSize(size); + NOTREACHED(); + return false; +} + +bool ProxyResolvingClientSocket::SetSendBufferSize(int32 size) { + if (transport_.get() && transport_->socket()) + return transport_->socket()->SetSendBufferSize(size); + NOTREACHED(); + return false; +} + +int ProxyResolvingClientSocket::Connect( + const net::CompletionCallback& callback) { + DCHECK(user_connect_callback_.is_null()); + + tried_direct_connect_fallback_ = false; + + // First we try and resolve the proxy. + GURL url("http://" + dest_host_port_pair_.ToString()); + DCHECK(url.is_valid()); + int status = network_session_->proxy_service()->ResolveProxy( + url, + &proxy_info_, + proxy_resolve_callback_, + &pac_request_, + bound_net_log_); + if (status != net::ERR_IO_PENDING) { + // We defer execution of ProcessProxyResolveDone instead of calling it + // directly here for simplicity. From the caller's point of view, + // the connect always happens asynchronously. + MessageLoop* message_loop = MessageLoop::current(); + CHECK(message_loop); + message_loop->PostTask( + FROM_HERE, + base::Bind(&ProxyResolvingClientSocket::ProcessProxyResolveDone, + weak_factory_.GetWeakPtr(), status)); + } + user_connect_callback_ = callback; + return net::ERR_IO_PENDING; +} + +void ProxyResolvingClientSocket::RunUserConnectCallback(int status) { + DCHECK_LE(status, net::OK); + net::CompletionCallback user_connect_callback = user_connect_callback_; + user_connect_callback_.Reset(); + user_connect_callback.Run(status); +} + +// Always runs asynchronously. +void ProxyResolvingClientSocket::ProcessProxyResolveDone(int status) { + pac_request_ = NULL; + + DCHECK_NE(status, net::ERR_IO_PENDING); + if (status == net::OK) { + // Remove unsupported proxies from the list. + proxy_info_.RemoveProxiesWithoutScheme( + net::ProxyServer::SCHEME_DIRECT | + net::ProxyServer::SCHEME_HTTP | net::ProxyServer::SCHEME_HTTPS | + net::ProxyServer::SCHEME_SOCKS4 | net::ProxyServer::SCHEME_SOCKS5); + + if (proxy_info_.is_empty()) { + // No proxies/direct to choose from. This happens when we don't support + // any of the proxies in the returned list. + status = net::ERR_NO_SUPPORTED_PROXIES; + } + } + + // Since we are faking the URL, it is possible that no proxies match our URL. + // Try falling back to a direct connection if we have not tried that before. + if (status != net::OK) { + if (!tried_direct_connect_fallback_) { + tried_direct_connect_fallback_ = true; + proxy_info_.UseDirect(); + } else { + CloseTransportSocket(); + RunUserConnectCallback(status); + return; + } + } + + transport_.reset(new net::ClientSocketHandle); + // Now that we have resolved the proxy, we need to connect. + status = net::InitSocketHandleForRawConnect( + dest_host_port_pair_, network_session_.get(), proxy_info_, ssl_config_, + ssl_config_, bound_net_log_, transport_.get(), connect_callback_); + if (status != net::ERR_IO_PENDING) { + // Since this method is always called asynchronously. it is OK to call + // ProcessConnectDone synchronously. + ProcessConnectDone(status); + } +} + +void ProxyResolvingClientSocket::ProcessConnectDone(int status) { + if (status != net::OK) { + // If the connection fails, try another proxy. + status = ReconsiderProxyAfterError(status); + // ReconsiderProxyAfterError either returns an error (in which case it is + // not reconsidering a proxy) or returns ERR_IO_PENDING if it is considering + // another proxy. + DCHECK_NE(status, net::OK); + if (status == net::ERR_IO_PENDING) + // Proxy reconsideration pending. Return. + return; + CloseTransportSocket(); + } else { + ReportSuccessfulProxyConnection(); + } + RunUserConnectCallback(status); +} + +// TODO(sanjeevr): This has largely been copied from +// HttpStreamFactoryImpl::Job::ReconsiderProxyAfterError. This should be +// refactored into some common place. +// This method reconsiders the proxy on certain errors. If it does reconsider +// a proxy it always returns ERR_IO_PENDING and posts a call to +// ProcessProxyResolveDone with the result of the reconsideration. +int ProxyResolvingClientSocket::ReconsiderProxyAfterError(int error) { + DCHECK(!pac_request_); + DCHECK_NE(error, net::OK); + DCHECK_NE(error, net::ERR_IO_PENDING); + // A failure to resolve the hostname or any error related to establishing a + // TCP connection could be grounds for trying a new proxy configuration. + // + // Why do this when a hostname cannot be resolved? Some URLs only make sense + // to proxy servers. The hostname in those URLs might fail to resolve if we + // are still using a non-proxy config. We need to check if a proxy config + // now exists that corresponds to a proxy server that could load the URL. + // + switch (error) { + case net::ERR_PROXY_CONNECTION_FAILED: + case net::ERR_NAME_NOT_RESOLVED: + case net::ERR_INTERNET_DISCONNECTED: + case net::ERR_ADDRESS_UNREACHABLE: + case net::ERR_CONNECTION_CLOSED: + case net::ERR_CONNECTION_RESET: + case net::ERR_CONNECTION_REFUSED: + case net::ERR_CONNECTION_ABORTED: + case net::ERR_TIMED_OUT: + case net::ERR_TUNNEL_CONNECTION_FAILED: + case net::ERR_SOCKS_CONNECTION_FAILED: + break; + case net::ERR_SOCKS_CONNECTION_HOST_UNREACHABLE: + // Remap the SOCKS-specific "host unreachable" error to a more + // generic error code (this way consumers like the link doctor + // know to substitute their error page). + // + // Note that if the host resolving was done by the SOCSK5 proxy, we can't + // differentiate between a proxy-side "host not found" versus a proxy-side + // "address unreachable" error, and will report both of these failures as + // ERR_ADDRESS_UNREACHABLE. + return net::ERR_ADDRESS_UNREACHABLE; + default: + return error; + } + + if (proxy_info_.is_https() && ssl_config_.send_client_cert) { + network_session_->ssl_client_auth_cache()->Remove( + proxy_info_.proxy_server().host_port_pair().ToString()); + } + + GURL url("http://" + dest_host_port_pair_.ToString()); + int rv = network_session_->proxy_service()->ReconsiderProxyAfterError( + url, &proxy_info_, proxy_resolve_callback_, &pac_request_, + bound_net_log_); + if (rv == net::OK || rv == net::ERR_IO_PENDING) { + CloseTransportSocket(); + } else { + // If ReconsiderProxyAfterError() failed synchronously, it means + // there was nothing left to fall-back to, so fail the transaction + // with the last connection error we got. + rv = error; + } + + // We either have new proxy info or there was an error in falling back. + // In both cases we want to post ProcessProxyResolveDone (in the error case + // we might still want to fall back a direct connection). + if (rv != net::ERR_IO_PENDING) { + MessageLoop* message_loop = MessageLoop::current(); + CHECK(message_loop); + message_loop->PostTask( + FROM_HERE, + base::Bind(&ProxyResolvingClientSocket::ProcessProxyResolveDone, + weak_factory_.GetWeakPtr(), rv)); + // Since we potentially have another try to go (trying the direct connect) + // set the return code code to ERR_IO_PENDING. + rv = net::ERR_IO_PENDING; + } + return rv; +} + +void ProxyResolvingClientSocket::ReportSuccessfulProxyConnection() { + network_session_->proxy_service()->ReportSuccess(proxy_info_); +} + +void ProxyResolvingClientSocket::Disconnect() { + CloseTransportSocket(); + if (pac_request_) + network_session_->proxy_service()->CancelPacRequest(pac_request_); + user_connect_callback_.Reset(); +} + +bool ProxyResolvingClientSocket::IsConnected() const { + if (!transport_.get() || !transport_->socket()) + return false; + return transport_->socket()->IsConnected(); +} + +bool ProxyResolvingClientSocket::IsConnectedAndIdle() const { + if (!transport_.get() || !transport_->socket()) + return false; + return transport_->socket()->IsConnectedAndIdle(); +} + +int ProxyResolvingClientSocket::GetPeerAddress( + net::IPEndPoint* address) const { + if (transport_.get() && transport_->socket()) + return transport_->socket()->GetPeerAddress(address); + NOTREACHED(); + return net::ERR_SOCKET_NOT_CONNECTED; +} + +int ProxyResolvingClientSocket::GetLocalAddress( + net::IPEndPoint* address) const { + if (transport_.get() && transport_->socket()) + return transport_->socket()->GetLocalAddress(address); + NOTREACHED(); + return net::ERR_SOCKET_NOT_CONNECTED; +} + +const net::BoundNetLog& ProxyResolvingClientSocket::NetLog() const { + if (transport_.get() && transport_->socket()) + return transport_->socket()->NetLog(); + NOTREACHED(); + return bound_net_log_; +} + +void ProxyResolvingClientSocket::SetSubresourceSpeculation() { + if (transport_.get() && transport_->socket()) + transport_->socket()->SetSubresourceSpeculation(); + else + NOTREACHED(); +} + +void ProxyResolvingClientSocket::SetOmniboxSpeculation() { + if (transport_.get() && transport_->socket()) + transport_->socket()->SetOmniboxSpeculation(); + else + NOTREACHED(); +} + +bool ProxyResolvingClientSocket::WasEverUsed() const { + if (transport_.get() && transport_->socket()) + return transport_->socket()->WasEverUsed(); + NOTREACHED(); + return false; +} + +bool ProxyResolvingClientSocket::UsingTCPFastOpen() const { + if (transport_.get() && transport_->socket()) + return transport_->socket()->UsingTCPFastOpen(); + NOTREACHED(); + return false; +} + +int64 ProxyResolvingClientSocket::NumBytesRead() const { + if (transport_.get() && transport_->socket()) + return transport_->socket()->NumBytesRead(); + NOTREACHED(); + return -1; +} + +base::TimeDelta ProxyResolvingClientSocket::GetConnectTimeMicros() const { + if (transport_.get() && transport_->socket()) + return transport_->socket()->GetConnectTimeMicros(); + NOTREACHED(); + return base::TimeDelta::FromMicroseconds(-1); +} + +bool ProxyResolvingClientSocket::WasNpnNegotiated() const { + return false; +} + +net::NextProto ProxyResolvingClientSocket::GetNegotiatedProtocol() const { + if (transport_.get() && transport_->socket()) + return transport_->socket()->GetNegotiatedProtocol(); + NOTREACHED(); + return net::kProtoUnknown; +} + +bool ProxyResolvingClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { + return false; +} + +void ProxyResolvingClientSocket::CloseTransportSocket() { + if (transport_.get() && transport_->socket()) + transport_->socket()->Disconnect(); + transport_.reset(); +} + +} // namespace jingle_glue diff --git a/jingle/glue/proxy_resolving_client_socket.h b/jingle/glue/proxy_resolving_client_socket.h new file mode 100644 index 0000000..99f342c --- /dev/null +++ b/jingle/glue/proxy_resolving_client_socket.h @@ -0,0 +1,106 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// This StreamSocket implementation wraps a ClientSocketHandle that is created +// from the client socket pool after resolving proxies. + +#ifndef JINGLE_GLUE_PROXY_RESOLVING_CLIENT_SOCKET_H_ +#define JINGLE_GLUE_PROXY_RESOLVING_CLIENT_SOCKET_H_ + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "base/memory/weak_ptr.h" +#include "net/base/completion_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/ssl_config_service.h" +#include "net/proxy/proxy_info.h" +#include "net/proxy/proxy_service.h" +#include "net/socket/stream_socket.h" + +namespace net { +class ClientSocketFactory; +class ClientSocketHandle; +class HttpNetworkSession; +class URLRequestContextGetter; +} // namespace net + +// TODO(sanjeevr): Move this to net/ +namespace jingle_glue { + +class ProxyResolvingClientSocket : public net::StreamSocket { + public: + // Constructs a new ProxyResolvingClientSocket. |socket_factory| is + // the ClientSocketFactory that will be used by the underlying + // HttpNetworkSession. If |socket_factory| is NULL, the default + // socket factory (net::ClientSocketFactory::GetDefaultFactory()) + // will be used. |dest_host_port_pair| is the destination for this + // socket. The hostname must be non-empty and the port must be > 0. + ProxyResolvingClientSocket( + net::ClientSocketFactory* socket_factory, + const scoped_refptr<net::URLRequestContextGetter>& request_context_getter, + const net::SSLConfig& ssl_config, + const net::HostPortPair& dest_host_port_pair); + virtual ~ProxyResolvingClientSocket(); + + // net::StreamSocket implementation. + virtual int Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; + virtual int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) OVERRIDE; + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE; + virtual const net::BoundNetLog& NetLog() const OVERRIDE; + virtual void SetSubresourceSpeculation() OVERRIDE; + virtual void SetOmniboxSpeculation() OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual int64 NumBytesRead() const OVERRIDE; + virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE; + + private: + // Proxy resolution and connection functions. + void ProcessProxyResolveDone(int status); + void ProcessConnectDone(int status); + + void CloseTransportSocket(); + void RunUserConnectCallback(int status); + int ReconsiderProxyAfterError(int error); + void ReportSuccessfulProxyConnection(); + + // Callbacks passed to net APIs. + net::CompletionCallback proxy_resolve_callback_; + net::CompletionCallback connect_callback_; + + scoped_refptr<net::HttpNetworkSession> network_session_; + + // The transport socket. + scoped_ptr<net::ClientSocketHandle> transport_; + + const net::SSLConfig ssl_config_; + net::ProxyService::PacRequest* pac_request_; + net::ProxyInfo proxy_info_; + net::HostPortPair dest_host_port_pair_; + bool tried_direct_connect_fallback_; + net::BoundNetLog bound_net_log_; + base::WeakPtrFactory<ProxyResolvingClientSocket> weak_factory_; + + // The callback passed to Connect(). + net::CompletionCallback user_connect_callback_; +}; + +} // namespace jingle_glue + +#endif // JINGLE_GLUE_PROXY_RESOLVING_CLIENT_SOCKET_H_ diff --git a/jingle/glue/proxy_resolving_client_socket_unittest.cc b/jingle/glue/proxy_resolving_client_socket_unittest.cc new file mode 100644 index 0000000..49d95f5 --- /dev/null +++ b/jingle/glue/proxy_resolving_client_socket_unittest.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "jingle/glue/proxy_resolving_client_socket.h" + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/message_loop.h" +#include "net/base/mock_host_resolver.h" +#include "net/base/test_completion_callback.h" +#include "net/proxy/proxy_service.h" +#include "net/socket/socket_test_util.h" +#include "net/url_request/url_request_context_getter.h" +#include "net/url_request/url_request_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace { + +class MyTestURLRequestContext : public TestURLRequestContext { + public: + MyTestURLRequestContext() : TestURLRequestContext(true) { + context_storage_.set_proxy_service( + net::ProxyService::CreateFixedFromPacResult( + "PROXY bad:99; PROXY maybe:80; DIRECT")); + Init(); + } + virtual ~MyTestURLRequestContext() {} +}; + +} // namespace + +namespace jingle_glue { + +class ProxyResolvingClientSocketTest : public testing::Test { + protected: + ProxyResolvingClientSocketTest() + : url_request_context_getter_(new TestURLRequestContextGetter( + base::MessageLoopProxy::current(), + scoped_ptr<TestURLRequestContext>(new MyTestURLRequestContext))) {} + + virtual ~ProxyResolvingClientSocketTest() {} + + virtual void TearDown() { + // Clear out any messages posted by ProxyResolvingClientSocket's + // destructor. + message_loop_.RunAllPending(); + } + + MessageLoop message_loop_; + scoped_refptr<TestURLRequestContextGetter> url_request_context_getter_; +}; + +// TODO(sanjeevr): Fix this test on Linux. +TEST_F(ProxyResolvingClientSocketTest, DISABLED_ConnectError) { + net::HostPortPair dest("0.0.0.0", 0); + ProxyResolvingClientSocket proxy_resolving_socket( + NULL, + url_request_context_getter_, + net::SSLConfig(), + dest); + net::TestCompletionCallback callback; + int status = proxy_resolving_socket.Connect(callback.callback()); + // Connect always returns ERR_IO_PENDING because it is always asynchronous. + EXPECT_EQ(net::ERR_IO_PENDING, status); + status = callback.WaitForResult(); + // ProxyResolvingClientSocket::Connect() will always return an error of + // ERR_ADDRESS_INVALID for a 0 IP address. + EXPECT_EQ(net::ERR_ADDRESS_INVALID, status); +} + +TEST_F(ProxyResolvingClientSocketTest, ReportsBadProxies) { + net::HostPortPair dest("example.com", 443); + net::MockClientSocketFactory socket_factory; + + net::StaticSocketDataProvider socket_data1; + socket_data1.set_connect_data( + net::MockConnect(net::ASYNC, net::ERR_ADDRESS_UNREACHABLE)); + socket_factory.AddSocketDataProvider(&socket_data1); + + net::MockRead reads[] = { + net::MockRead("HTTP/1.1 200 Success\r\n\r\n") + }; + net::MockWrite writes[] = { + net::MockWrite("CONNECT example.com:443 HTTP/1.1\r\n" + "Host: example.com:443\r\n" + "Proxy-Connection: keep-alive\r\n\r\n") + }; + net::StaticSocketDataProvider socket_data2(reads, arraysize(reads), + writes, arraysize(writes)); + socket_data2.set_connect_data(net::MockConnect(net::ASYNC, net::OK)); + socket_factory.AddSocketDataProvider(&socket_data2); + + ProxyResolvingClientSocket proxy_resolving_socket( + &socket_factory, + url_request_context_getter_, + net::SSLConfig(), + dest); + + net::TestCompletionCallback callback; + int status = proxy_resolving_socket.Connect(callback.callback()); + EXPECT_EQ(net::ERR_IO_PENDING, status); + status = callback.WaitForResult(); + EXPECT_EQ(net::OK, status); + + net::URLRequestContext* context = + url_request_context_getter_->GetURLRequestContext(); + const net::ProxyRetryInfoMap& retry_info = + context->proxy_service()->proxy_retry_info(); + + EXPECT_EQ(1u, retry_info.size()); + net::ProxyRetryInfoMap::const_iterator iter = retry_info.find("bad:99"); + EXPECT_TRUE(iter != retry_info.end()); +} + +// TODO(sanjeevr): Add more unit-tests. +} // namespace jingle_glue diff --git a/jingle/glue/resolving_client_socket_factory.h b/jingle/glue/resolving_client_socket_factory.h new file mode 100644 index 0000000..5be8bc8 --- /dev/null +++ b/jingle/glue/resolving_client_socket_factory.h @@ -0,0 +1,36 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef JINGLE_GLUE_RESOLVING_CLIENT_SOCKET_FACTORY_H_ +#define JINGLE_GLUE_RESOLVING_CLIENT_SOCKET_FACTORY_H_ + + +namespace net { +class ClientSocketHandle; +class HostPortPair; +class SSLClientSocket; +class StreamSocket; +} // namespace net + +// TODO(sanjeevr): Move this to net/ + +namespace jingle_glue { + +// Interface for a ClientSocketFactory that creates ClientSockets that can +// resolve host names and tunnel through proxies. +class ResolvingClientSocketFactory { + public: + virtual ~ResolvingClientSocketFactory() { } + // Method to create a transport socket using a HostPortPair. + virtual net::StreamSocket* CreateTransportClientSocket( + const net::HostPortPair& host_and_port) = 0; + + virtual net::SSLClientSocket* CreateSSLClientSocket( + net::ClientSocketHandle* transport_socket, + const net::HostPortPair& host_and_port) = 0; +}; + +} // namespace jingle_glue + +#endif // JINGLE_GLUE_RESOLVING_CLIENT_SOCKET_FACTORY_H_ diff --git a/jingle/glue/xmpp_client_socket_factory.cc b/jingle/glue/xmpp_client_socket_factory.cc new file mode 100644 index 0000000..13749e1 --- /dev/null +++ b/jingle/glue/xmpp_client_socket_factory.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "jingle/glue/xmpp_client_socket_factory.h" + +#include "base/logging.h" +#include "jingle/glue/fake_ssl_client_socket.h" +#include "jingle/glue/proxy_resolving_client_socket.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/ssl_client_socket.h" +#include "net/url_request/url_request_context.h" +#include "net/url_request/url_request_context_getter.h" + +namespace jingle_glue { + +XmppClientSocketFactory::XmppClientSocketFactory( + net::ClientSocketFactory* client_socket_factory, + const net::SSLConfig& ssl_config, + const scoped_refptr<net::URLRequestContextGetter>& request_context_getter, + bool use_fake_ssl_client_socket) + : client_socket_factory_(client_socket_factory), + request_context_getter_(request_context_getter), + ssl_config_(ssl_config), + use_fake_ssl_client_socket_(use_fake_ssl_client_socket) { + CHECK(client_socket_factory_); +} + +XmppClientSocketFactory::~XmppClientSocketFactory() {} + +net::StreamSocket* XmppClientSocketFactory::CreateTransportClientSocket( + const net::HostPortPair& host_and_port) { + // TODO(akalin): Use socket pools. + net::StreamSocket* transport_socket = new ProxyResolvingClientSocket( + NULL, + request_context_getter_, + ssl_config_, + host_and_port); + return (use_fake_ssl_client_socket_ ? + new FakeSSLClientSocket(transport_socket) : transport_socket); +} + +net::SSLClientSocket* XmppClientSocketFactory::CreateSSLClientSocket( + net::ClientSocketHandle* transport_socket, + const net::HostPortPair& host_and_port) { + net::SSLClientSocketContext context; + context.cert_verifier = + request_context_getter_->GetURLRequestContext()->cert_verifier(); + // TODO(rkn): context.server_bound_cert_service is NULL because the + // ServerBoundCertService class is not thread safe. + return client_socket_factory_->CreateSSLClientSocket( + transport_socket, host_and_port, ssl_config_, context); +} + + +} // namespace jingle_glue diff --git a/jingle/glue/xmpp_client_socket_factory.h b/jingle/glue/xmpp_client_socket_factory.h new file mode 100644 index 0000000..f03a04e --- /dev/null +++ b/jingle/glue/xmpp_client_socket_factory.h @@ -0,0 +1,56 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef JINGLE_GLUE_XMPP_CLIENT_SOCKET_FACTORY_H_ +#define JINGLE_GLUE_XMPP_CLIENT_SOCKET_FACTORY_H_ + +#include <string> + +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "jingle/glue/resolving_client_socket_factory.h" +#include "net/base/ssl_config_service.h" + +namespace net { +class ClientSocketFactory; +class ClientSocketHandle; +class HostPortPair; +class SSLClientSocket; +class StreamSocket; +class URLRequestContextGetter; +} // namespace net + +namespace jingle_glue { + +class XmppClientSocketFactory : public ResolvingClientSocketFactory { + public: + // Does not take ownership of |client_socket_factory|. + XmppClientSocketFactory( + net::ClientSocketFactory* client_socket_factory, + const net::SSLConfig& ssl_config, + const scoped_refptr<net::URLRequestContextGetter>& request_context_getter, + bool use_fake_ssl_client_socket); + + virtual ~XmppClientSocketFactory(); + + // ResolvingClientSocketFactory implementation. + virtual net::StreamSocket* CreateTransportClientSocket( + const net::HostPortPair& host_and_port) OVERRIDE; + + virtual net::SSLClientSocket* CreateSSLClientSocket( + net::ClientSocketHandle* transport_socket, + const net::HostPortPair& host_and_port) OVERRIDE; + + private: + net::ClientSocketFactory* const client_socket_factory_; + scoped_refptr<net::URLRequestContextGetter> request_context_getter_; + const net::SSLConfig ssl_config_; + const bool use_fake_ssl_client_socket_; + + DISALLOW_COPY_AND_ASSIGN(XmppClientSocketFactory); +}; + +} // namespace jingle_glue + +#endif // JINGLE_GLUE_XMPP_CLIENT_SOCKET_FACTORY_H_ |