From 05e3f96385c7a9808dec06f13419d7bb1996ec45 Mon Sep 17 00:00:00 2001 From: kmarshall Date: Thu, 3 Mar 2016 12:50:08 -0800 Subject: Blimp: add support for SSL connections. This CL allows the Blimp client to establish TLS-protected channels with the backend engine. The authenticity of the engine is validated by checking if its cert is an exact match of a certificate provided separately by the Assigner API. * Create new Blimp SSL transport class: SSLClientTransport. * Create custom CertValidator for checking an exact cert match against the SSL peer's cert * Integrate SSLClientTransport with BlimpClientSession. * Assignment: add certificate field. * AssignmentSource: add certificate file reading; PEM file parsing; X509 certificate parsing. * Created new DEPS entries as appropriate. R=wez@chromium.org CC=rsleevi@chromium.org BUG=585279,589202 Committed: https://crrev.com/c80f5095f045ad1712f1f1075a44547a561f774a Cr-Commit-Position: refs/heads/master@{#378839} Review URL: https://codereview.chromium.org/1696563002 Cr-Commit-Position: refs/heads/master@{#379081} --- blimp/net/BUILD.gn | 5 + blimp/net/DEPS | 3 +- blimp/net/blimp_transport.h | 2 +- blimp/net/exact_match_cert_verifier.cc | 56 ++++++++++ blimp/net/exact_match_cert_verifier.h | 50 +++++++++ blimp/net/ssl_client_transport.cc | 91 ++++++++++++++++ blimp/net/ssl_client_transport.h | 64 ++++++++++++ blimp/net/ssl_client_transport_unittest.cc | 162 +++++++++++++++++++++++++++++ blimp/net/tcp_client_transport.cc | 46 +++++--- blimp/net/tcp_client_transport.h | 32 ++++-- blimp/net/tcp_engine_transport.cc | 2 +- blimp/net/tcp_engine_transport.h | 2 +- blimp/net/tcp_transport_unittest.cc | 14 +-- blimp/net/test_common.cc | 2 +- blimp/net/test_common.h | 2 +- 15 files changed, 500 insertions(+), 33 deletions(-) create mode 100644 blimp/net/exact_match_cert_verifier.cc create mode 100644 blimp/net/exact_match_cert_verifier.h create mode 100644 blimp/net/ssl_client_transport.cc create mode 100644 blimp/net/ssl_client_transport.h create mode 100644 blimp/net/ssl_client_transport_unittest.cc (limited to 'blimp/net') diff --git a/blimp/net/BUILD.gn b/blimp/net/BUILD.gn index aeafbae..28c00cd 100644 --- a/blimp/net/BUILD.gn +++ b/blimp/net/BUILD.gn @@ -34,12 +34,16 @@ component("blimp_net") { "engine_authentication_handler.h", "engine_connection_manager.cc", "engine_connection_manager.h", + "exact_match_cert_verifier.cc", + "exact_match_cert_verifier.h", "input_message_converter.cc", "input_message_converter.h", "input_message_generator.cc", "input_message_generator.h", "null_blimp_message_processor.cc", "null_blimp_message_processor.h", + "ssl_client_transport.cc", + "ssl_client_transport.h", "stream_packet_reader.cc", "stream_packet_reader.h", "stream_packet_writer.cc", @@ -94,6 +98,7 @@ source_set("unit_tests") { "engine_authentication_handler_unittest.cc", "engine_connection_manager_unittest.cc", "input_message_unittest.cc", + "ssl_client_transport_unittest.cc", "stream_packet_reader_unittest.cc", "stream_packet_writer_unittest.cc", "tcp_transport_unittest.cc", diff --git a/blimp/net/DEPS b/blimp/net/DEPS index 9891086..211c82c 100644 --- a/blimp/net/DEPS +++ b/blimp/net/DEPS @@ -1,6 +1,5 @@ include_rules = [ - "+net/base", - "+net/socket", + "+net", "+third_party/WebKit/public/platform/WebGestureDevice.h", "+third_party/WebKit/public/web/WebInputEvent.h", ] diff --git a/blimp/net/blimp_transport.h b/blimp/net/blimp_transport.h index 8336912..8deabd5 100644 --- a/blimp/net/blimp_transport.h +++ b/blimp/net/blimp_transport.h @@ -35,7 +35,7 @@ class BlimpTransport { virtual scoped_ptr TakeConnection() = 0; // Gets transport name, e.g. "TCP", "SSL", "mock", etc. - virtual const std::string GetName() const = 0; + virtual const char* GetName() const = 0; }; } // namespace blimp diff --git a/blimp/net/exact_match_cert_verifier.cc b/blimp/net/exact_match_cert_verifier.cc new file mode 100644 index 0000000..c140c36 --- /dev/null +++ b/blimp/net/exact_match_cert_verifier.cc @@ -0,0 +1,56 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "blimp/net/exact_match_cert_verifier.h" + +#include "base/callback.h" +#include "base/macros.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/net_errors.h" +#include "net/cert/cert_verifier.h" +#include "net/cert/cert_verify_result.h" +#include "net/cert/x509_certificate.h" + +namespace blimp { + +ExactMatchCertVerifier::ExactMatchCertVerifier( + scoped_refptr engine_cert) + : engine_cert_(std::move(engine_cert)) { + DCHECK(engine_cert); + + net::SHA1HashValue sha1_hash; + sha1_hash = net::X509Certificate::CalculateFingerprint( + engine_cert_->os_cert_handle()); + engine_cert_hashes_.push_back(net::HashValue(sha1_hash)); + + net::SHA256HashValue sha256_hash; + sha256_hash = net::X509Certificate::CalculateFingerprint256( + engine_cert_->os_cert_handle()); + engine_cert_hashes_.push_back(net::HashValue(sha256_hash)); +} + +ExactMatchCertVerifier::~ExactMatchCertVerifier() {} + +int ExactMatchCertVerifier::Verify(net::X509Certificate* cert, + const std::string& hostname, + const std::string& ocsp_response, + int flags, + net::CRLSet* crl_set, + net::CertVerifyResult* verify_result, + const net::CompletionCallback& callback, + scoped_ptr* out_req, + const net::BoundNetLog& net_log) { + verify_result->Reset(); + verify_result->verified_cert = engine_cert_; + + if (!cert->Equals(engine_cert_.get())) { + verify_result->cert_status = net::CERT_STATUS_INVALID; + return net::ERR_CERT_INVALID; + } + + verify_result->public_key_hashes = engine_cert_hashes_; + return net::OK; +} + +} // namespace blimp diff --git a/blimp/net/exact_match_cert_verifier.h b/blimp/net/exact_match_cert_verifier.h new file mode 100644 index 0000000..d26d57a --- /dev/null +++ b/blimp/net/exact_match_cert_verifier.h @@ -0,0 +1,50 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef BLIMP_NET_EXACT_MATCH_CERT_VERIFIER_H_ +#define BLIMP_NET_EXACT_MATCH_CERT_VERIFIER_H_ + +#include +#include + +#include "blimp/net/blimp_net_export.h" +#include "net/cert/cert_verifier.h" + +namespace net { +class HashValue; +} // namespace net + +namespace blimp { + +// Checks if the peer certificate is an exact match to the X.509 certificate +// |engine_cert|, which is specified at class construction time. +class BLIMP_NET_EXPORT ExactMatchCertVerifier : public net::CertVerifier { + public: + // |engine_cert|: The single allowable certificate. + explicit ExactMatchCertVerifier( + scoped_refptr engine_cert); + + ~ExactMatchCertVerifier() override; + + // net::CertVerifier implementation. + int Verify(net::X509Certificate* cert, + const std::string& hostname, + const std::string& ocsp_response, + int flags, + net::CRLSet* crl_set, + net::CertVerifyResult* verify_result, + const net::CompletionCallback& callback, + scoped_ptr* out_req, + const net::BoundNetLog& net_log) override; + + private: + scoped_refptr engine_cert_; + std::vector engine_cert_hashes_; + + DISALLOW_COPY_AND_ASSIGN(ExactMatchCertVerifier); +}; + +} // namespace blimp + +#endif // BLIMP_NET_EXACT_MATCH_CERT_VERIFIER_H_ diff --git a/blimp/net/ssl_client_transport.cc b/blimp/net/ssl_client_transport.cc new file mode 100644 index 0000000..3364093 --- /dev/null +++ b/blimp/net/ssl_client_transport.cc @@ -0,0 +1,91 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "blimp/net/ssl_client_transport.h" + +#include "base/callback.h" +#include "base/callback_helpers.h" +#include "blimp/net/exact_match_cert_verifier.h" +#include "blimp/net/stream_socket_connection.h" +#include "net/base/host_port_pair.h" +#include "net/cert/x509_certificate.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/stream_socket.h" +#include "net/socket/tcp_client_socket.h" +#include "net/ssl/ssl_config.h" + +namespace blimp { + +SSLClientTransport::SSLClientTransport(const net::IPEndPoint& ip_endpoint, + scoped_refptr cert, + net::NetLog* net_log) + : TCPClientTransport(ip_endpoint, net_log), ip_endpoint_(ip_endpoint) { + // Test code may pass in a null value for |cert|. Only spin up a CertVerifier + // if there is a cert present. + if (cert) { + cert_verifier_.reset(new ExactMatchCertVerifier(std::move(cert))); + } +} + +SSLClientTransport::~SSLClientTransport() {} + +const char* SSLClientTransport::GetName() const { + return "SSL"; +} + +void SSLClientTransport::OnTCPConnectComplete(int result) { + DCHECK_NE(net::ERR_IO_PENDING, result); + + scoped_ptr tcp_socket = TCPClientTransport::TakeSocket(); + + DVLOG(1) << "TCP connection result=" << result; + if (result != net::OK) { + OnConnectComplete(result); + return; + } + + // Construct arguments to use for the SSL socket factory. + scoped_ptr socket_handle( + new net::ClientSocketHandle); + socket_handle->SetSocket(std::move(tcp_socket)); + + net::HostPortPair host_port_pair = + net::HostPortPair::FromIPEndPoint(ip_endpoint_); + + net::SSLClientSocketContext create_context; + create_context.cert_verifier = cert_verifier_.get(); + create_context.transport_security_state = &transport_security_state_; + + scoped_ptr ssl_socket( + socket_factory()->CreateSSLClientSocket(std::move(socket_handle), + host_port_pair, net::SSLConfig(), + create_context)); + + if (!ssl_socket) { + OnConnectComplete(net::ERR_SSL_PROTOCOL_ERROR); + return; + } + + result = ssl_socket->Connect(base::Bind( + &SSLClientTransport::OnSSLConnectComplete, base::Unretained(this))); + SetSocket(std::move(ssl_socket)); + + if (result == net::ERR_IO_PENDING) { + // SSL connection will complete asynchronously. + return; + } + + OnSSLConnectComplete(result); +} + +void SSLClientTransport::OnSSLConnectComplete(int result) { + DCHECK_NE(net::ERR_IO_PENDING, result); + DVLOG(1) << "SSL connection result=" << result; + + OnConnectComplete(result); +} + +} // namespace blimp diff --git a/blimp/net/ssl_client_transport.h b/blimp/net/ssl_client_transport.h new file mode 100644 index 0000000..0f0a1f4 --- /dev/null +++ b/blimp/net/ssl_client_transport.h @@ -0,0 +1,64 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef BLIMP_NET_SSL_CLIENT_TRANSPORT_H_ +#define BLIMP_NET_SSL_CLIENT_TRANSPORT_H_ + +#include +#include "base/callback_forward.h" +#include "base/macros.h" +#include "base/memory/scoped_ptr.h" +#include "blimp/net/blimp_net_export.h" +#include "blimp/net/blimp_transport.h" +#include "blimp/net/exact_match_cert_verifier.h" +#include "blimp/net/tcp_client_transport.h" +#include "net/base/address_list.h" +#include "net/base/net_errors.h" +#include "net/http/transport_security_state.h" + +namespace net { +class ClientSocketFactory; +class NetLog; +class StreamSocket; +class TCPClientSocket; +class TransportSecurityState; +} // namespace net + +namespace blimp { + +class BlimpConnection; + +// Creates and connects SSL socket connections to an Engine. +class BLIMP_NET_EXPORT SSLClientTransport : public TCPClientTransport { + public: + // |ip_endpoint|: the address to connect to. + // |cert|: the certificate required from the remote peer. + // SSL connections that use different certificates are rejected. + // |net_log|: the socket event log (optional). + SSLClientTransport(const net::IPEndPoint& ip_endpoint, + scoped_refptr cert, + net::NetLog* net_log); + + ~SSLClientTransport() override; + + // BlimpTransport implementation. + const char* GetName() const override; + + private: + // Method called after TCPClientSocket::Connect finishes. + void OnTCPConnectComplete(int result) override; + + // Method called after SSLClientSocket::Connect finishes. + void OnSSLConnectComplete(int result); + + net::IPEndPoint ip_endpoint_; + scoped_ptr cert_verifier_; + net::TransportSecurityState transport_security_state_; + + DISALLOW_COPY_AND_ASSIGN(SSLClientTransport); +}; + +} // namespace blimp + +#endif // BLIMP_NET_SSL_CLIENT_TRANSPORT_H_ diff --git a/blimp/net/ssl_client_transport_unittest.cc b/blimp/net/ssl_client_transport_unittest.cc new file mode 100644 index 0000000..47724ce --- /dev/null +++ b/blimp/net/ssl_client_transport_unittest.cc @@ -0,0 +1,162 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/bind.h" +#include "base/message_loop/message_loop.h" +#include "base/run_loop.h" +#include "blimp/net/blimp_connection.h" +#include "blimp/net/ssl_client_transport.h" +#include "net/base/address_list.h" +#include "net/base/ip_address.h" +#include "net/base/net_errors.h" +#include "net/socket/socket_test_util.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace blimp { +namespace { + +const uint8_t kIPV4Address[] = {127, 0, 0, 1}; +const uint16_t kPort = 6667; + +} // namespace + +class SSLClientTransportTest : public testing::Test { + public: + SSLClientTransportTest() {} + + ~SSLClientTransportTest() override {} + + void TearDown() override { base::RunLoop().RunUntilIdle(); } + + MOCK_METHOD1(ConnectComplete, void(int)); + + protected: + // Methods for provisioning simulated connection outcomes. + void SetupTCPSyncSocketConnect(net::IPEndPoint endpoint, int result) { + tcp_connect_.set_connect_data( + net::MockConnect(net::SYNCHRONOUS, result, endpoint)); + socket_factory_.AddSocketDataProvider(&tcp_connect_); + } + + void SetupTCPAsyncSocketConnect(net::IPEndPoint endpoint, int result) { + tcp_connect_.set_connect_data( + net::MockConnect(net::ASYNC, result, endpoint)); + socket_factory_.AddSocketDataProvider(&tcp_connect_); + } + + void SetupSSLSyncSocketConnect(int result) { + ssl_connect_.reset( + new net::SSLSocketDataProvider(net::SYNCHRONOUS, result)); + socket_factory_.AddSSLSocketDataProvider(ssl_connect_.get()); + } + + void SetupSSLAsyncSocketConnect(int result) { + ssl_connect_.reset(new net::SSLSocketDataProvider(net::ASYNC, result)); + socket_factory_.AddSSLSocketDataProvider(ssl_connect_.get()); + } + + void ConfigureTransport(const net::IPEndPoint& ip_endpoint) { + // The mock does not interact with the cert directly, so just leave it null. + scoped_refptr cert; + transport_.reset(new SSLClientTransport(ip_endpoint, cert, &net_log_)); + transport_->SetClientSocketFactoryForTest(&socket_factory_); + } + + base::MessageLoop message_loop; + net::NetLog net_log_; + net::StaticSocketDataProvider tcp_connect_; + scoped_ptr ssl_connect_; + net::MockClientSocketFactory socket_factory_; + scoped_ptr transport_; +}; + +TEST_F(SSLClientTransportTest, ConnectSyncOK) { + net::IPEndPoint endpoint(kIPV4Address, kPort); + ConfigureTransport(endpoint); + for (int i = 0; i < 5; ++i) { + EXPECT_CALL(*this, ConnectComplete(net::OK)); + SetupTCPSyncSocketConnect(endpoint, net::OK); + SetupSSLSyncSocketConnect(net::OK); + transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete, + base::Unretained(this))); + EXPECT_NE(nullptr, transport_->TakeConnection().get()); + base::RunLoop().RunUntilIdle(); + } +} + +TEST_F(SSLClientTransportTest, ConnectAsyncOK) { + net::IPEndPoint endpoint(kIPV4Address, kPort); + ConfigureTransport(endpoint); + for (int i = 0; i < 5; ++i) { + EXPECT_CALL(*this, ConnectComplete(net::OK)); + SetupTCPAsyncSocketConnect(endpoint, net::OK); + SetupSSLAsyncSocketConnect(net::OK); + transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete, + base::Unretained(this))); + base::RunLoop().RunUntilIdle(); + EXPECT_NE(nullptr, transport_->TakeConnection().get()); + } +} + +TEST_F(SSLClientTransportTest, ConnectSyncTCPError) { + net::IPEndPoint endpoint(kIPV4Address, kPort); + ConfigureTransport(endpoint); + EXPECT_CALL(*this, ConnectComplete(net::ERR_FAILED)); + SetupTCPSyncSocketConnect(endpoint, net::ERR_FAILED); + transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete, + base::Unretained(this))); +} + +TEST_F(SSLClientTransportTest, ConnectAsyncTCPError) { + net::IPEndPoint endpoint(kIPV4Address, kPort); + ConfigureTransport(endpoint); + EXPECT_CALL(*this, ConnectComplete(net::ERR_FAILED)); + SetupTCPAsyncSocketConnect(endpoint, net::ERR_FAILED); + transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete, + base::Unretained(this))); +} + +TEST_F(SSLClientTransportTest, ConnectSyncSSLError) { + net::IPEndPoint endpoint(kIPV4Address, kPort); + ConfigureTransport(endpoint); + EXPECT_CALL(*this, ConnectComplete(net::ERR_FAILED)); + SetupTCPSyncSocketConnect(endpoint, net::OK); + SetupSSLSyncSocketConnect(net::ERR_FAILED); + transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete, + base::Unretained(this))); +} + +TEST_F(SSLClientTransportTest, ConnectAsyncSSLError) { + net::IPEndPoint endpoint(kIPV4Address, kPort); + ConfigureTransport(endpoint); + EXPECT_CALL(*this, ConnectComplete(net::ERR_FAILED)); + SetupTCPAsyncSocketConnect(endpoint, net::OK); + SetupSSLAsyncSocketConnect(net::ERR_FAILED); + transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete, + base::Unretained(this))); +} + +TEST_F(SSLClientTransportTest, ConnectAfterError) { + net::IPEndPoint endpoint(kIPV4Address, kPort); + ConfigureTransport(endpoint); + + // TCP connection fails. + EXPECT_CALL(*this, ConnectComplete(net::ERR_FAILED)); + SetupTCPSyncSocketConnect(endpoint, net::ERR_FAILED); + transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete, + base::Unretained(this))); + base::RunLoop().RunUntilIdle(); + + // Subsequent TCP+SSL connections succeed. + EXPECT_CALL(*this, ConnectComplete(net::OK)); + SetupTCPSyncSocketConnect(endpoint, net::OK); + SetupSSLSyncSocketConnect(net::OK); + transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete, + base::Unretained(this))); + EXPECT_NE(nullptr, transport_->TakeConnection().get()); + base::RunLoop().RunUntilIdle(); +} + +} // namespace blimp diff --git a/blimp/net/tcp_client_transport.cc b/blimp/net/tcp_client_transport.cc index 9dbb8803..6a6cc50 100644 --- a/blimp/net/tcp_client_transport.cc +++ b/blimp/net/tcp_client_transport.cc @@ -9,38 +9,42 @@ #include "base/memory/scoped_ptr.h" #include "base/message_loop/message_loop.h" #include "blimp/net/stream_socket_connection.h" +#include "net/socket/client_socket_factory.h" #include "net/socket/stream_socket.h" #include "net/socket/tcp_client_socket.h" namespace blimp { -TCPClientTransport::TCPClientTransport(const net::AddressList& addresses, +TCPClientTransport::TCPClientTransport(const net::IPEndPoint& ip_endpoint, net::NetLog* net_log) - : addresses_(addresses), net_log_(net_log) {} + : ip_endpoint_(ip_endpoint), + net_log_(net_log), + socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) {} TCPClientTransport::~TCPClientTransport() {} +void TCPClientTransport::SetClientSocketFactoryForTest( + net::ClientSocketFactory* factory) { + DCHECK(factory); + socket_factory_ = factory; +} + void TCPClientTransport::Connect(const net::CompletionCallback& callback) { DCHECK(!socket_); DCHECK(!callback.is_null()); - socket_.reset( - new net::TCPClientSocket(addresses_, net_log_, net::NetLog::Source())); + connect_callback_ = callback; + socket_ = socket_factory_->CreateTransportClientSocket( + net::AddressList(ip_endpoint_), net_log_, net::NetLog::Source()); net::CompletionCallback completion_callback = base::Bind( &TCPClientTransport::OnTCPConnectComplete, base::Unretained(this)); int result = socket_->Connect(completion_callback); if (result == net::ERR_IO_PENDING) { - connect_callback_ = callback; return; } - if (result != net::OK) { - socket_ = nullptr; - } - - base::MessageLoop::current()->PostTask(FROM_HERE, - base::Bind(callback, result)); + OnTCPConnectComplete(result); } scoped_ptr TCPClientTransport::TakeConnection() { @@ -49,17 +53,33 @@ scoped_ptr TCPClientTransport::TakeConnection() { return make_scoped_ptr(new StreamSocketConnection(std::move(socket_))); } -const std::string TCPClientTransport::GetName() const { +const char* TCPClientTransport::GetName() const { return "TCP"; } void TCPClientTransport::OnTCPConnectComplete(int result) { DCHECK_NE(net::ERR_IO_PENDING, result); - DCHECK(socket_); + OnConnectComplete(result); +} + +void TCPClientTransport::OnConnectComplete(int result) { if (result != net::OK) { socket_ = nullptr; } base::ResetAndReturn(&connect_callback_).Run(result); } +scoped_ptr TCPClientTransport::TakeSocket() { + return std::move(socket_); +} + +void TCPClientTransport::SetSocket(scoped_ptr socket) { + DCHECK(socket); + socket_ = std::move(socket); +} + +net::ClientSocketFactory* TCPClientTransport::socket_factory() const { + return socket_factory_; +} + } // namespace blimp diff --git a/blimp/net/tcp_client_transport.h b/blimp/net/tcp_client_transport.h index 3242883..6a879fe 100644 --- a/blimp/net/tcp_client_transport.h +++ b/blimp/net/tcp_client_transport.h @@ -5,6 +5,8 @@ #ifndef BLIMP_NET_TCP_CLIENT_TRANSPORT_H_ #define BLIMP_NET_TCP_CLIENT_TRANSPORT_H_ +#include + #include "base/callback.h" #include "base/macros.h" #include "base/memory/scoped_ptr.h" @@ -14,6 +16,7 @@ #include "net/base/net_errors.h" namespace net { +class ClientSocketFactory; class NetLog; class StreamSocket; } // namespace net @@ -26,21 +29,38 @@ class BlimpConnection; // |addresses| on each call to Connect(). class BLIMP_NET_EXPORT TCPClientTransport : public BlimpTransport { public: - TCPClientTransport(const net::AddressList& addresses, net::NetLog* net_log); + TCPClientTransport(const net::IPEndPoint& ip_endpoint, net::NetLog* net_log); ~TCPClientTransport() override; + void SetClientSocketFactoryForTest(net::ClientSocketFactory* factory); + // BlimpTransport implementation. void Connect(const net::CompletionCallback& callback) override; scoped_ptr TakeConnection() override; - const std::string GetName() const override; + const char* GetName() const override; - private: - void OnTCPConnectComplete(int result); + protected: + // Called when the TCP connection completed. + virtual void OnTCPConnectComplete(int result); + + // Called when the connection attempt completed or failed. + // Resets |socket_| if |result| indicates a failure (!= net::OK). + void OnConnectComplete(int result); + + // Methods for taking and setting |socket_|. Can be used by subclasses to + // swap out a socket for an upgraded one, e.g. adding SSL encryption. + scoped_ptr TakeSocket(); + void SetSocket(scoped_ptr socket); - net::AddressList addresses_; + // Gets the socket factory instance. + net::ClientSocketFactory* socket_factory() const; + + private: + net::IPEndPoint ip_endpoint_; net::NetLog* net_log_; - scoped_ptr socket_; net::CompletionCallback connect_callback_; + net::ClientSocketFactory* socket_factory_ = nullptr; + scoped_ptr socket_; DISALLOW_COPY_AND_ASSIGN(TCPClientTransport); }; diff --git a/blimp/net/tcp_engine_transport.cc b/blimp/net/tcp_engine_transport.cc index 473a1bb..4606caf 100644 --- a/blimp/net/tcp_engine_transport.cc +++ b/blimp/net/tcp_engine_transport.cc @@ -61,7 +61,7 @@ scoped_ptr TCPEngineTransport::TakeConnection() { new StreamSocketConnection(std::move(accepted_socket_))); } -const std::string TCPEngineTransport::GetName() const { +const char* TCPEngineTransport::GetName() const { return "TCP"; } diff --git a/blimp/net/tcp_engine_transport.h b/blimp/net/tcp_engine_transport.h index bf8b7cf..ace223e 100644 --- a/blimp/net/tcp_engine_transport.h +++ b/blimp/net/tcp_engine_transport.h @@ -33,7 +33,7 @@ class BLIMP_NET_EXPORT TCPEngineTransport : public BlimpTransport { // BlimpTransport implementation. void Connect(const net::CompletionCallback& callback) override; scoped_ptr TakeConnection() override; - const std::string GetName() const override; + const char* GetName() const override; int GetLocalAddressForTesting(net::IPEndPoint* address) const; diff --git a/blimp/net/tcp_transport_unittest.cc b/blimp/net/tcp_transport_unittest.cc index cef15d3..8f1525a 100644 --- a/blimp/net/tcp_transport_unittest.cc +++ b/blimp/net/tcp_transport_unittest.cc @@ -35,10 +35,10 @@ class TCPTransportTest : public testing::Test { engine_.reset(new TCPEngineTransport(local_address, nullptr)); } - net::AddressList GetLocalAddressList() const { + net::IPEndPoint GetLocalEndpoint() const { net::IPEndPoint local_address; - engine_->GetLocalAddressForTesting(&local_address); - return net::AddressList(local_address); + CHECK_EQ(net::OK, engine_->GetLocalAddressForTesting(&local_address)); + return local_address; } base::MessageLoopForIO message_loop_; @@ -50,7 +50,7 @@ TEST_F(TCPTransportTest, Connect) { engine_->Connect(accept_callback.callback()); net::TestCompletionCallback connect_callback; - TCPClientTransport client(GetLocalAddressList(), nullptr); + TCPClientTransport client(GetLocalEndpoint(), nullptr); client.Connect(connect_callback.callback()); EXPECT_EQ(net::OK, connect_callback.WaitForResult()); @@ -63,11 +63,11 @@ TEST_F(TCPTransportTest, TwoClientConnections) { engine_->Connect(accept_callback1.callback()); net::TestCompletionCallback connect_callback1; - TCPClientTransport client1(GetLocalAddressList(), nullptr); + TCPClientTransport client1(GetLocalEndpoint(), nullptr); client1.Connect(connect_callback1.callback()); net::TestCompletionCallback connect_callback2; - TCPClientTransport client2(GetLocalAddressList(), nullptr); + TCPClientTransport client2(GetLocalEndpoint(), nullptr); client2.Connect(connect_callback2.callback()); EXPECT_EQ(net::OK, connect_callback1.WaitForResult()); @@ -86,7 +86,7 @@ TEST_F(TCPTransportTest, ExchangeMessages) { net::TestCompletionCallback accept_callback; engine_->Connect(accept_callback.callback()); net::TestCompletionCallback client_connect_callback; - TCPClientTransport client(GetLocalAddressList(), nullptr /* NetLog */); + TCPClientTransport client(GetLocalEndpoint(), nullptr /* NetLog */); client.Connect(client_connect_callback.callback()); EXPECT_EQ(net::OK, client_connect_callback.WaitForResult()); EXPECT_EQ(net::OK, accept_callback.WaitForResult()); diff --git a/blimp/net/test_common.cc b/blimp/net/test_common.cc index e6cd956..92a951d 100644 --- a/blimp/net/test_common.cc +++ b/blimp/net/test_common.cc @@ -26,7 +26,7 @@ scoped_ptr MockTransport::TakeConnection() { return make_scoped_ptr(TakeConnectionPtr()); } -const std::string MockTransport::GetName() const { +const char* MockTransport::GetName() const { return "mock"; } diff --git a/blimp/net/test_common.h b/blimp/net/test_common.h index 302060a..380465e 100644 --- a/blimp/net/test_common.h +++ b/blimp/net/test_common.h @@ -137,7 +137,7 @@ class MockTransport : public BlimpTransport { MOCK_METHOD0(TakeConnectionPtr, BlimpConnection*()); scoped_ptr TakeConnection() override; - const std::string GetName() const override; + const char* GetName() const override; }; class MockConnectionHandler : public ConnectionHandler { -- cgit v1.1