summaryrefslogtreecommitdiffstats
path: root/blimp/net
diff options
context:
space:
mode:
authorkmarshall <kmarshall@chromium.org>2016-03-02 13:41:39 -0800
committerCommit bot <commit-bot@chromium.org>2016-03-02 21:43:41 +0000
commitc80f5095f045ad1712f1f1075a44547a561f774a (patch)
tree9b23343cb2ad36e6e63ecc7793cb2f93bfd0b23e /blimp/net
parentf0444bfd6b3c6b43f34da7709debdbc248395ef5 (diff)
downloadchromium_src-c80f5095f045ad1712f1f1075a44547a561f774a.zip
chromium_src-c80f5095f045ad1712f1f1075a44547a561f774a.tar.gz
chromium_src-c80f5095f045ad1712f1f1075a44547a561f774a.tar.bz2
Blimp: add support for SSL connections.
This CL allows the Blimp client to establish TLS-protected channels with the backend engine. The authenticity of the engine is validated by checking if its cert is an exact match of a certificate provided separately by the Assigner API. * Create new Blimp SSL transport class: SSLClientTransport. * Create custom CertValidator for checking an exact cert match against the SSL peer's cert * Integrate SSLClientTransport with BlimpClientSession. * Assignment: add certificate field. * AssignmentSource: add certificate file reading; PEM file parsing; X509 certificate parsing. * Created new DEPS entries as appropriate. R=wez@chromium.org CC=rsleevi@chromium.org BUG=585279,589202 Review URL: https://codereview.chromium.org/1696563002 Cr-Commit-Position: refs/heads/master@{#378839}
Diffstat (limited to 'blimp/net')
-rw-r--r--blimp/net/BUILD.gn5
-rw-r--r--blimp/net/DEPS3
-rw-r--r--blimp/net/blimp_transport.h2
-rw-r--r--blimp/net/exact_match_cert_verifier.cc56
-rw-r--r--blimp/net/exact_match_cert_verifier.h50
-rw-r--r--blimp/net/ssl_client_transport.cc91
-rw-r--r--blimp/net/ssl_client_transport.h64
-rw-r--r--blimp/net/ssl_client_transport_unittest.cc162
-rw-r--r--blimp/net/tcp_client_transport.cc46
-rw-r--r--blimp/net/tcp_client_transport.h32
-rw-r--r--blimp/net/tcp_engine_transport.cc2
-rw-r--r--blimp/net/tcp_engine_transport.h2
-rw-r--r--blimp/net/tcp_transport_unittest.cc14
-rw-r--r--blimp/net/test_common.cc2
-rw-r--r--blimp/net/test_common.h2
15 files changed, 500 insertions, 33 deletions
diff --git a/blimp/net/BUILD.gn b/blimp/net/BUILD.gn
index aeafbae..28c00cd 100644
--- a/blimp/net/BUILD.gn
+++ b/blimp/net/BUILD.gn
@@ -34,12 +34,16 @@ component("blimp_net") {
"engine_authentication_handler.h",
"engine_connection_manager.cc",
"engine_connection_manager.h",
+ "exact_match_cert_verifier.cc",
+ "exact_match_cert_verifier.h",
"input_message_converter.cc",
"input_message_converter.h",
"input_message_generator.cc",
"input_message_generator.h",
"null_blimp_message_processor.cc",
"null_blimp_message_processor.h",
+ "ssl_client_transport.cc",
+ "ssl_client_transport.h",
"stream_packet_reader.cc",
"stream_packet_reader.h",
"stream_packet_writer.cc",
@@ -94,6 +98,7 @@ source_set("unit_tests") {
"engine_authentication_handler_unittest.cc",
"engine_connection_manager_unittest.cc",
"input_message_unittest.cc",
+ "ssl_client_transport_unittest.cc",
"stream_packet_reader_unittest.cc",
"stream_packet_writer_unittest.cc",
"tcp_transport_unittest.cc",
diff --git a/blimp/net/DEPS b/blimp/net/DEPS
index 9891086..211c82c 100644
--- a/blimp/net/DEPS
+++ b/blimp/net/DEPS
@@ -1,6 +1,5 @@
include_rules = [
- "+net/base",
- "+net/socket",
+ "+net",
"+third_party/WebKit/public/platform/WebGestureDevice.h",
"+third_party/WebKit/public/web/WebInputEvent.h",
]
diff --git a/blimp/net/blimp_transport.h b/blimp/net/blimp_transport.h
index 8336912..8deabd5 100644
--- a/blimp/net/blimp_transport.h
+++ b/blimp/net/blimp_transport.h
@@ -35,7 +35,7 @@ class BlimpTransport {
virtual scoped_ptr<BlimpConnection> TakeConnection() = 0;
// Gets transport name, e.g. "TCP", "SSL", "mock", etc.
- virtual const std::string GetName() const = 0;
+ virtual const char* GetName() const = 0;
};
} // namespace blimp
diff --git a/blimp/net/exact_match_cert_verifier.cc b/blimp/net/exact_match_cert_verifier.cc
new file mode 100644
index 0000000..c140c36
--- /dev/null
+++ b/blimp/net/exact_match_cert_verifier.cc
@@ -0,0 +1,56 @@
+// Copyright 2016 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "blimp/net/exact_match_cert_verifier.h"
+
+#include "base/callback.h"
+#include "base/macros.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/net_errors.h"
+#include "net/cert/cert_verifier.h"
+#include "net/cert/cert_verify_result.h"
+#include "net/cert/x509_certificate.h"
+
+namespace blimp {
+
+ExactMatchCertVerifier::ExactMatchCertVerifier(
+ scoped_refptr<net::X509Certificate> engine_cert)
+ : engine_cert_(std::move(engine_cert)) {
+ DCHECK(engine_cert);
+
+ net::SHA1HashValue sha1_hash;
+ sha1_hash = net::X509Certificate::CalculateFingerprint(
+ engine_cert_->os_cert_handle());
+ engine_cert_hashes_.push_back(net::HashValue(sha1_hash));
+
+ net::SHA256HashValue sha256_hash;
+ sha256_hash = net::X509Certificate::CalculateFingerprint256(
+ engine_cert_->os_cert_handle());
+ engine_cert_hashes_.push_back(net::HashValue(sha256_hash));
+}
+
+ExactMatchCertVerifier::~ExactMatchCertVerifier() {}
+
+int ExactMatchCertVerifier::Verify(net::X509Certificate* cert,
+ const std::string& hostname,
+ const std::string& ocsp_response,
+ int flags,
+ net::CRLSet* crl_set,
+ net::CertVerifyResult* verify_result,
+ const net::CompletionCallback& callback,
+ scoped_ptr<Request>* out_req,
+ const net::BoundNetLog& net_log) {
+ verify_result->Reset();
+ verify_result->verified_cert = engine_cert_;
+
+ if (!cert->Equals(engine_cert_.get())) {
+ verify_result->cert_status = net::CERT_STATUS_INVALID;
+ return net::ERR_CERT_INVALID;
+ }
+
+ verify_result->public_key_hashes = engine_cert_hashes_;
+ return net::OK;
+}
+
+} // namespace blimp
diff --git a/blimp/net/exact_match_cert_verifier.h b/blimp/net/exact_match_cert_verifier.h
new file mode 100644
index 0000000..d26d57a
--- /dev/null
+++ b/blimp/net/exact_match_cert_verifier.h
@@ -0,0 +1,50 @@
+// Copyright 2016 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BLIMP_NET_EXACT_MATCH_CERT_VERIFIER_H_
+#define BLIMP_NET_EXACT_MATCH_CERT_VERIFIER_H_
+
+#include <string>
+#include <vector>
+
+#include "blimp/net/blimp_net_export.h"
+#include "net/cert/cert_verifier.h"
+
+namespace net {
+class HashValue;
+} // namespace net
+
+namespace blimp {
+
+// Checks if the peer certificate is an exact match to the X.509 certificate
+// |engine_cert|, which is specified at class construction time.
+class BLIMP_NET_EXPORT ExactMatchCertVerifier : public net::CertVerifier {
+ public:
+ // |engine_cert|: The single allowable certificate.
+ explicit ExactMatchCertVerifier(
+ scoped_refptr<net::X509Certificate> engine_cert);
+
+ ~ExactMatchCertVerifier() override;
+
+ // net::CertVerifier implementation.
+ int Verify(net::X509Certificate* cert,
+ const std::string& hostname,
+ const std::string& ocsp_response,
+ int flags,
+ net::CRLSet* crl_set,
+ net::CertVerifyResult* verify_result,
+ const net::CompletionCallback& callback,
+ scoped_ptr<Request>* out_req,
+ const net::BoundNetLog& net_log) override;
+
+ private:
+ scoped_refptr<net::X509Certificate> engine_cert_;
+ std::vector<net::HashValue> engine_cert_hashes_;
+
+ DISALLOW_COPY_AND_ASSIGN(ExactMatchCertVerifier);
+};
+
+} // namespace blimp
+
+#endif // BLIMP_NET_EXACT_MATCH_CERT_VERIFIER_H_
diff --git a/blimp/net/ssl_client_transport.cc b/blimp/net/ssl_client_transport.cc
new file mode 100644
index 0000000..3364093
--- /dev/null
+++ b/blimp/net/ssl_client_transport.cc
@@ -0,0 +1,91 @@
+// Copyright 2016 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "blimp/net/ssl_client_transport.h"
+
+#include "base/callback.h"
+#include "base/callback_helpers.h"
+#include "blimp/net/exact_match_cert_verifier.h"
+#include "blimp/net/stream_socket_connection.h"
+#include "net/base/host_port_pair.h"
+#include "net/cert/x509_certificate.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/socket/stream_socket.h"
+#include "net/socket/tcp_client_socket.h"
+#include "net/ssl/ssl_config.h"
+
+namespace blimp {
+
+SSLClientTransport::SSLClientTransport(const net::IPEndPoint& ip_endpoint,
+ scoped_refptr<net::X509Certificate> cert,
+ net::NetLog* net_log)
+ : TCPClientTransport(ip_endpoint, net_log), ip_endpoint_(ip_endpoint) {
+ // Test code may pass in a null value for |cert|. Only spin up a CertVerifier
+ // if there is a cert present.
+ if (cert) {
+ cert_verifier_.reset(new ExactMatchCertVerifier(std::move(cert)));
+ }
+}
+
+SSLClientTransport::~SSLClientTransport() {}
+
+const char* SSLClientTransport::GetName() const {
+ return "SSL";
+}
+
+void SSLClientTransport::OnTCPConnectComplete(int result) {
+ DCHECK_NE(net::ERR_IO_PENDING, result);
+
+ scoped_ptr<net::StreamSocket> tcp_socket = TCPClientTransport::TakeSocket();
+
+ DVLOG(1) << "TCP connection result=" << result;
+ if (result != net::OK) {
+ OnConnectComplete(result);
+ return;
+ }
+
+ // Construct arguments to use for the SSL socket factory.
+ scoped_ptr<net::ClientSocketHandle> socket_handle(
+ new net::ClientSocketHandle);
+ socket_handle->SetSocket(std::move(tcp_socket));
+
+ net::HostPortPair host_port_pair =
+ net::HostPortPair::FromIPEndPoint(ip_endpoint_);
+
+ net::SSLClientSocketContext create_context;
+ create_context.cert_verifier = cert_verifier_.get();
+ create_context.transport_security_state = &transport_security_state_;
+
+ scoped_ptr<net::StreamSocket> ssl_socket(
+ socket_factory()->CreateSSLClientSocket(std::move(socket_handle),
+ host_port_pair, net::SSLConfig(),
+ create_context));
+
+ if (!ssl_socket) {
+ OnConnectComplete(net::ERR_SSL_PROTOCOL_ERROR);
+ return;
+ }
+
+ result = ssl_socket->Connect(base::Bind(
+ &SSLClientTransport::OnSSLConnectComplete, base::Unretained(this)));
+ SetSocket(std::move(ssl_socket));
+
+ if (result == net::ERR_IO_PENDING) {
+ // SSL connection will complete asynchronously.
+ return;
+ }
+
+ OnSSLConnectComplete(result);
+}
+
+void SSLClientTransport::OnSSLConnectComplete(int result) {
+ DCHECK_NE(net::ERR_IO_PENDING, result);
+ DVLOG(1) << "SSL connection result=" << result;
+
+ OnConnectComplete(result);
+}
+
+} // namespace blimp
diff --git a/blimp/net/ssl_client_transport.h b/blimp/net/ssl_client_transport.h
new file mode 100644
index 0000000..0f0a1f4
--- /dev/null
+++ b/blimp/net/ssl_client_transport.h
@@ -0,0 +1,64 @@
+// Copyright 2016 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BLIMP_NET_SSL_CLIENT_TRANSPORT_H_
+#define BLIMP_NET_SSL_CLIENT_TRANSPORT_H_
+
+#include <string>
+#include "base/callback_forward.h"
+#include "base/macros.h"
+#include "base/memory/scoped_ptr.h"
+#include "blimp/net/blimp_net_export.h"
+#include "blimp/net/blimp_transport.h"
+#include "blimp/net/exact_match_cert_verifier.h"
+#include "blimp/net/tcp_client_transport.h"
+#include "net/base/address_list.h"
+#include "net/base/net_errors.h"
+#include "net/http/transport_security_state.h"
+
+namespace net {
+class ClientSocketFactory;
+class NetLog;
+class StreamSocket;
+class TCPClientSocket;
+class TransportSecurityState;
+} // namespace net
+
+namespace blimp {
+
+class BlimpConnection;
+
+// Creates and connects SSL socket connections to an Engine.
+class BLIMP_NET_EXPORT SSLClientTransport : public TCPClientTransport {
+ public:
+ // |ip_endpoint|: the address to connect to.
+ // |cert|: the certificate required from the remote peer.
+ // SSL connections that use different certificates are rejected.
+ // |net_log|: the socket event log (optional).
+ SSLClientTransport(const net::IPEndPoint& ip_endpoint,
+ scoped_refptr<net::X509Certificate> cert,
+ net::NetLog* net_log);
+
+ ~SSLClientTransport() override;
+
+ // BlimpTransport implementation.
+ const char* GetName() const override;
+
+ private:
+ // Method called after TCPClientSocket::Connect finishes.
+ void OnTCPConnectComplete(int result) override;
+
+ // Method called after SSLClientSocket::Connect finishes.
+ void OnSSLConnectComplete(int result);
+
+ net::IPEndPoint ip_endpoint_;
+ scoped_ptr<ExactMatchCertVerifier> cert_verifier_;
+ net::TransportSecurityState transport_security_state_;
+
+ DISALLOW_COPY_AND_ASSIGN(SSLClientTransport);
+};
+
+} // namespace blimp
+
+#endif // BLIMP_NET_SSL_CLIENT_TRANSPORT_H_
diff --git a/blimp/net/ssl_client_transport_unittest.cc b/blimp/net/ssl_client_transport_unittest.cc
new file mode 100644
index 0000000..47724ce
--- /dev/null
+++ b/blimp/net/ssl_client_transport_unittest.cc
@@ -0,0 +1,162 @@
+// Copyright 2016 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/bind.h"
+#include "base/message_loop/message_loop.h"
+#include "base/run_loop.h"
+#include "blimp/net/blimp_connection.h"
+#include "blimp/net/ssl_client_transport.h"
+#include "net/base/address_list.h"
+#include "net/base/ip_address.h"
+#include "net/base/net_errors.h"
+#include "net/socket/socket_test_util.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace blimp {
+namespace {
+
+const uint8_t kIPV4Address[] = {127, 0, 0, 1};
+const uint16_t kPort = 6667;
+
+} // namespace
+
+class SSLClientTransportTest : public testing::Test {
+ public:
+ SSLClientTransportTest() {}
+
+ ~SSLClientTransportTest() override {}
+
+ void TearDown() override { base::RunLoop().RunUntilIdle(); }
+
+ MOCK_METHOD1(ConnectComplete, void(int));
+
+ protected:
+ // Methods for provisioning simulated connection outcomes.
+ void SetupTCPSyncSocketConnect(net::IPEndPoint endpoint, int result) {
+ tcp_connect_.set_connect_data(
+ net::MockConnect(net::SYNCHRONOUS, result, endpoint));
+ socket_factory_.AddSocketDataProvider(&tcp_connect_);
+ }
+
+ void SetupTCPAsyncSocketConnect(net::IPEndPoint endpoint, int result) {
+ tcp_connect_.set_connect_data(
+ net::MockConnect(net::ASYNC, result, endpoint));
+ socket_factory_.AddSocketDataProvider(&tcp_connect_);
+ }
+
+ void SetupSSLSyncSocketConnect(int result) {
+ ssl_connect_.reset(
+ new net::SSLSocketDataProvider(net::SYNCHRONOUS, result));
+ socket_factory_.AddSSLSocketDataProvider(ssl_connect_.get());
+ }
+
+ void SetupSSLAsyncSocketConnect(int result) {
+ ssl_connect_.reset(new net::SSLSocketDataProvider(net::ASYNC, result));
+ socket_factory_.AddSSLSocketDataProvider(ssl_connect_.get());
+ }
+
+ void ConfigureTransport(const net::IPEndPoint& ip_endpoint) {
+ // The mock does not interact with the cert directly, so just leave it null.
+ scoped_refptr<net::X509Certificate> cert;
+ transport_.reset(new SSLClientTransport(ip_endpoint, cert, &net_log_));
+ transport_->SetClientSocketFactoryForTest(&socket_factory_);
+ }
+
+ base::MessageLoop message_loop;
+ net::NetLog net_log_;
+ net::StaticSocketDataProvider tcp_connect_;
+ scoped_ptr<net::SSLSocketDataProvider> ssl_connect_;
+ net::MockClientSocketFactory socket_factory_;
+ scoped_ptr<SSLClientTransport> transport_;
+};
+
+TEST_F(SSLClientTransportTest, ConnectSyncOK) {
+ net::IPEndPoint endpoint(kIPV4Address, kPort);
+ ConfigureTransport(endpoint);
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_CALL(*this, ConnectComplete(net::OK));
+ SetupTCPSyncSocketConnect(endpoint, net::OK);
+ SetupSSLSyncSocketConnect(net::OK);
+ transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete,
+ base::Unretained(this)));
+ EXPECT_NE(nullptr, transport_->TakeConnection().get());
+ base::RunLoop().RunUntilIdle();
+ }
+}
+
+TEST_F(SSLClientTransportTest, ConnectAsyncOK) {
+ net::IPEndPoint endpoint(kIPV4Address, kPort);
+ ConfigureTransport(endpoint);
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_CALL(*this, ConnectComplete(net::OK));
+ SetupTCPAsyncSocketConnect(endpoint, net::OK);
+ SetupSSLAsyncSocketConnect(net::OK);
+ transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete,
+ base::Unretained(this)));
+ base::RunLoop().RunUntilIdle();
+ EXPECT_NE(nullptr, transport_->TakeConnection().get());
+ }
+}
+
+TEST_F(SSLClientTransportTest, ConnectSyncTCPError) {
+ net::IPEndPoint endpoint(kIPV4Address, kPort);
+ ConfigureTransport(endpoint);
+ EXPECT_CALL(*this, ConnectComplete(net::ERR_FAILED));
+ SetupTCPSyncSocketConnect(endpoint, net::ERR_FAILED);
+ transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete,
+ base::Unretained(this)));
+}
+
+TEST_F(SSLClientTransportTest, ConnectAsyncTCPError) {
+ net::IPEndPoint endpoint(kIPV4Address, kPort);
+ ConfigureTransport(endpoint);
+ EXPECT_CALL(*this, ConnectComplete(net::ERR_FAILED));
+ SetupTCPAsyncSocketConnect(endpoint, net::ERR_FAILED);
+ transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete,
+ base::Unretained(this)));
+}
+
+TEST_F(SSLClientTransportTest, ConnectSyncSSLError) {
+ net::IPEndPoint endpoint(kIPV4Address, kPort);
+ ConfigureTransport(endpoint);
+ EXPECT_CALL(*this, ConnectComplete(net::ERR_FAILED));
+ SetupTCPSyncSocketConnect(endpoint, net::OK);
+ SetupSSLSyncSocketConnect(net::ERR_FAILED);
+ transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete,
+ base::Unretained(this)));
+}
+
+TEST_F(SSLClientTransportTest, ConnectAsyncSSLError) {
+ net::IPEndPoint endpoint(kIPV4Address, kPort);
+ ConfigureTransport(endpoint);
+ EXPECT_CALL(*this, ConnectComplete(net::ERR_FAILED));
+ SetupTCPAsyncSocketConnect(endpoint, net::OK);
+ SetupSSLAsyncSocketConnect(net::ERR_FAILED);
+ transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete,
+ base::Unretained(this)));
+}
+
+TEST_F(SSLClientTransportTest, ConnectAfterError) {
+ net::IPEndPoint endpoint(kIPV4Address, kPort);
+ ConfigureTransport(endpoint);
+
+ // TCP connection fails.
+ EXPECT_CALL(*this, ConnectComplete(net::ERR_FAILED));
+ SetupTCPSyncSocketConnect(endpoint, net::ERR_FAILED);
+ transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete,
+ base::Unretained(this)));
+ base::RunLoop().RunUntilIdle();
+
+ // Subsequent TCP+SSL connections succeed.
+ EXPECT_CALL(*this, ConnectComplete(net::OK));
+ SetupTCPSyncSocketConnect(endpoint, net::OK);
+ SetupSSLSyncSocketConnect(net::OK);
+ transport_->Connect(base::Bind(&SSLClientTransportTest::ConnectComplete,
+ base::Unretained(this)));
+ EXPECT_NE(nullptr, transport_->TakeConnection().get());
+ base::RunLoop().RunUntilIdle();
+}
+
+} // namespace blimp
diff --git a/blimp/net/tcp_client_transport.cc b/blimp/net/tcp_client_transport.cc
index 9dbb8803..6a6cc50 100644
--- a/blimp/net/tcp_client_transport.cc
+++ b/blimp/net/tcp_client_transport.cc
@@ -9,38 +9,42 @@
#include "base/memory/scoped_ptr.h"
#include "base/message_loop/message_loop.h"
#include "blimp/net/stream_socket_connection.h"
+#include "net/socket/client_socket_factory.h"
#include "net/socket/stream_socket.h"
#include "net/socket/tcp_client_socket.h"
namespace blimp {
-TCPClientTransport::TCPClientTransport(const net::AddressList& addresses,
+TCPClientTransport::TCPClientTransport(const net::IPEndPoint& ip_endpoint,
net::NetLog* net_log)
- : addresses_(addresses), net_log_(net_log) {}
+ : ip_endpoint_(ip_endpoint),
+ net_log_(net_log),
+ socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) {}
TCPClientTransport::~TCPClientTransport() {}
+void TCPClientTransport::SetClientSocketFactoryForTest(
+ net::ClientSocketFactory* factory) {
+ DCHECK(factory);
+ socket_factory_ = factory;
+}
+
void TCPClientTransport::Connect(const net::CompletionCallback& callback) {
DCHECK(!socket_);
DCHECK(!callback.is_null());
- socket_.reset(
- new net::TCPClientSocket(addresses_, net_log_, net::NetLog::Source()));
+ connect_callback_ = callback;
+ socket_ = socket_factory_->CreateTransportClientSocket(
+ net::AddressList(ip_endpoint_), net_log_, net::NetLog::Source());
net::CompletionCallback completion_callback = base::Bind(
&TCPClientTransport::OnTCPConnectComplete, base::Unretained(this));
int result = socket_->Connect(completion_callback);
if (result == net::ERR_IO_PENDING) {
- connect_callback_ = callback;
return;
}
- if (result != net::OK) {
- socket_ = nullptr;
- }
-
- base::MessageLoop::current()->PostTask(FROM_HERE,
- base::Bind(callback, result));
+ OnTCPConnectComplete(result);
}
scoped_ptr<BlimpConnection> TCPClientTransport::TakeConnection() {
@@ -49,17 +53,33 @@ scoped_ptr<BlimpConnection> TCPClientTransport::TakeConnection() {
return make_scoped_ptr(new StreamSocketConnection(std::move(socket_)));
}
-const std::string TCPClientTransport::GetName() const {
+const char* TCPClientTransport::GetName() const {
return "TCP";
}
void TCPClientTransport::OnTCPConnectComplete(int result) {
DCHECK_NE(net::ERR_IO_PENDING, result);
- DCHECK(socket_);
+ OnConnectComplete(result);
+}
+
+void TCPClientTransport::OnConnectComplete(int result) {
if (result != net::OK) {
socket_ = nullptr;
}
base::ResetAndReturn(&connect_callback_).Run(result);
}
+scoped_ptr<net::StreamSocket> TCPClientTransport::TakeSocket() {
+ return std::move(socket_);
+}
+
+void TCPClientTransport::SetSocket(scoped_ptr<net::StreamSocket> socket) {
+ DCHECK(socket);
+ socket_ = std::move(socket);
+}
+
+net::ClientSocketFactory* TCPClientTransport::socket_factory() const {
+ return socket_factory_;
+}
+
} // namespace blimp
diff --git a/blimp/net/tcp_client_transport.h b/blimp/net/tcp_client_transport.h
index 3242883..6a879fe 100644
--- a/blimp/net/tcp_client_transport.h
+++ b/blimp/net/tcp_client_transport.h
@@ -5,6 +5,8 @@
#ifndef BLIMP_NET_TCP_CLIENT_TRANSPORT_H_
#define BLIMP_NET_TCP_CLIENT_TRANSPORT_H_
+#include <string>
+
#include "base/callback.h"
#include "base/macros.h"
#include "base/memory/scoped_ptr.h"
@@ -14,6 +16,7 @@
#include "net/base/net_errors.h"
namespace net {
+class ClientSocketFactory;
class NetLog;
class StreamSocket;
} // namespace net
@@ -26,21 +29,38 @@ class BlimpConnection;
// |addresses| on each call to Connect().
class BLIMP_NET_EXPORT TCPClientTransport : public BlimpTransport {
public:
- TCPClientTransport(const net::AddressList& addresses, net::NetLog* net_log);
+ TCPClientTransport(const net::IPEndPoint& ip_endpoint, net::NetLog* net_log);
~TCPClientTransport() override;
+ void SetClientSocketFactoryForTest(net::ClientSocketFactory* factory);
+
// BlimpTransport implementation.
void Connect(const net::CompletionCallback& callback) override;
scoped_ptr<BlimpConnection> TakeConnection() override;
- const std::string GetName() const override;
+ const char* GetName() const override;
- private:
- void OnTCPConnectComplete(int result);
+ protected:
+ // Called when the TCP connection completed.
+ virtual void OnTCPConnectComplete(int result);
+
+ // Called when the connection attempt completed or failed.
+ // Resets |socket_| if |result| indicates a failure (!= net::OK).
+ void OnConnectComplete(int result);
+
+ // Methods for taking and setting |socket_|. Can be used by subclasses to
+ // swap out a socket for an upgraded one, e.g. adding SSL encryption.
+ scoped_ptr<net::StreamSocket> TakeSocket();
+ void SetSocket(scoped_ptr<net::StreamSocket> socket);
- net::AddressList addresses_;
+ // Gets the socket factory instance.
+ net::ClientSocketFactory* socket_factory() const;
+
+ private:
+ net::IPEndPoint ip_endpoint_;
net::NetLog* net_log_;
- scoped_ptr<net::StreamSocket> socket_;
net::CompletionCallback connect_callback_;
+ net::ClientSocketFactory* socket_factory_ = nullptr;
+ scoped_ptr<net::StreamSocket> socket_;
DISALLOW_COPY_AND_ASSIGN(TCPClientTransport);
};
diff --git a/blimp/net/tcp_engine_transport.cc b/blimp/net/tcp_engine_transport.cc
index 473a1bb..4606caf 100644
--- a/blimp/net/tcp_engine_transport.cc
+++ b/blimp/net/tcp_engine_transport.cc
@@ -61,7 +61,7 @@ scoped_ptr<BlimpConnection> TCPEngineTransport::TakeConnection() {
new StreamSocketConnection(std::move(accepted_socket_)));
}
-const std::string TCPEngineTransport::GetName() const {
+const char* TCPEngineTransport::GetName() const {
return "TCP";
}
diff --git a/blimp/net/tcp_engine_transport.h b/blimp/net/tcp_engine_transport.h
index bf8b7cf..ace223e 100644
--- a/blimp/net/tcp_engine_transport.h
+++ b/blimp/net/tcp_engine_transport.h
@@ -33,7 +33,7 @@ class BLIMP_NET_EXPORT TCPEngineTransport : public BlimpTransport {
// BlimpTransport implementation.
void Connect(const net::CompletionCallback& callback) override;
scoped_ptr<BlimpConnection> TakeConnection() override;
- const std::string GetName() const override;
+ const char* GetName() const override;
int GetLocalAddressForTesting(net::IPEndPoint* address) const;
diff --git a/blimp/net/tcp_transport_unittest.cc b/blimp/net/tcp_transport_unittest.cc
index cef15d3..8f1525a 100644
--- a/blimp/net/tcp_transport_unittest.cc
+++ b/blimp/net/tcp_transport_unittest.cc
@@ -35,10 +35,10 @@ class TCPTransportTest : public testing::Test {
engine_.reset(new TCPEngineTransport(local_address, nullptr));
}
- net::AddressList GetLocalAddressList() const {
+ net::IPEndPoint GetLocalEndpoint() const {
net::IPEndPoint local_address;
- engine_->GetLocalAddressForTesting(&local_address);
- return net::AddressList(local_address);
+ CHECK_EQ(net::OK, engine_->GetLocalAddressForTesting(&local_address));
+ return local_address;
}
base::MessageLoopForIO message_loop_;
@@ -50,7 +50,7 @@ TEST_F(TCPTransportTest, Connect) {
engine_->Connect(accept_callback.callback());
net::TestCompletionCallback connect_callback;
- TCPClientTransport client(GetLocalAddressList(), nullptr);
+ TCPClientTransport client(GetLocalEndpoint(), nullptr);
client.Connect(connect_callback.callback());
EXPECT_EQ(net::OK, connect_callback.WaitForResult());
@@ -63,11 +63,11 @@ TEST_F(TCPTransportTest, TwoClientConnections) {
engine_->Connect(accept_callback1.callback());
net::TestCompletionCallback connect_callback1;
- TCPClientTransport client1(GetLocalAddressList(), nullptr);
+ TCPClientTransport client1(GetLocalEndpoint(), nullptr);
client1.Connect(connect_callback1.callback());
net::TestCompletionCallback connect_callback2;
- TCPClientTransport client2(GetLocalAddressList(), nullptr);
+ TCPClientTransport client2(GetLocalEndpoint(), nullptr);
client2.Connect(connect_callback2.callback());
EXPECT_EQ(net::OK, connect_callback1.WaitForResult());
@@ -86,7 +86,7 @@ TEST_F(TCPTransportTest, ExchangeMessages) {
net::TestCompletionCallback accept_callback;
engine_->Connect(accept_callback.callback());
net::TestCompletionCallback client_connect_callback;
- TCPClientTransport client(GetLocalAddressList(), nullptr /* NetLog */);
+ TCPClientTransport client(GetLocalEndpoint(), nullptr /* NetLog */);
client.Connect(client_connect_callback.callback());
EXPECT_EQ(net::OK, client_connect_callback.WaitForResult());
EXPECT_EQ(net::OK, accept_callback.WaitForResult());
diff --git a/blimp/net/test_common.cc b/blimp/net/test_common.cc
index e6cd956..92a951d 100644
--- a/blimp/net/test_common.cc
+++ b/blimp/net/test_common.cc
@@ -26,7 +26,7 @@ scoped_ptr<BlimpConnection> MockTransport::TakeConnection() {
return make_scoped_ptr(TakeConnectionPtr());
}
-const std::string MockTransport::GetName() const {
+const char* MockTransport::GetName() const {
return "mock";
}
diff --git a/blimp/net/test_common.h b/blimp/net/test_common.h
index 302060a..380465e 100644
--- a/blimp/net/test_common.h
+++ b/blimp/net/test_common.h
@@ -137,7 +137,7 @@ class MockTransport : public BlimpTransport {
MOCK_METHOD0(TakeConnectionPtr, BlimpConnection*());
scoped_ptr<BlimpConnection> TakeConnection() override;
- const std::string GetName() const override;
+ const char* GetName() const override;
};
class MockConnectionHandler : public ConnectionHandler {