diff options
author | sergeyu <sergeyu@chromium.org> | 2015-08-14 11:38:17 -0700 |
---|---|---|
committer | Commit bot <commit-bot@chromium.org> | 2015-08-14 18:38:58 +0000 |
commit | 1c9834963f5c48eaf50d9676a61dec0ed1edae49 (patch) | |
tree | 55436989fb8f85ec17bd797d99b54d0718206477 /remoting/protocol | |
parent | e6642ef4bd6cc11fa85dab1bb727ebdd31053348 (diff) | |
download | chromium_src-1c9834963f5c48eaf50d9676a61dec0ed1edae49.zip chromium_src-1c9834963f5c48eaf50d9676a61dec0ed1edae49.tar.gz chromium_src-1c9834963f5c48eaf50d9676a61dec0ed1edae49.tar.bz2 |
Implement QuicChannel and QuicChannelFactory
QuicChannelFactory implements StreamChannelFactory to create channels
over a QUIC connection.
The new code will be hooked up in a separate CL.
BUG=448838
Review URL: https://codereview.chromium.org/1273233002
Cr-Commit-Position: refs/heads/master@{#343447}
Diffstat (limited to 'remoting/protocol')
-rw-r--r-- | remoting/protocol/BUILD.gn | 1 | ||||
-rw-r--r-- | remoting/protocol/fake_datagram_socket.cc | 34 | ||||
-rw-r--r-- | remoting/protocol/fake_datagram_socket.h | 15 | ||||
-rw-r--r-- | remoting/protocol/quic_channel.cc | 179 | ||||
-rw-r--r-- | remoting/protocol/quic_channel.h | 104 | ||||
-rw-r--r-- | remoting/protocol/quic_channel_factory.cc | 536 | ||||
-rw-r--r-- | remoting/protocol/quic_channel_factory.h | 58 | ||||
-rw-r--r-- | remoting/protocol/quic_channel_factory_unittest.cc | 353 |
8 files changed, 1280 insertions, 0 deletions
diff --git a/remoting/protocol/BUILD.gn b/remoting/protocol/BUILD.gn index 819ff5d..bb1942b 100644 --- a/remoting/protocol/BUILD.gn +++ b/remoting/protocol/BUILD.gn @@ -83,6 +83,7 @@ source_set("unit_tests") { "port_range_unittest.cc", "ppapi_module_stub.cc", "pseudotcp_adapter_unittest.cc", + "quic_channel_factory_unittest.cc", "ssl_hmac_channel_authenticator_unittest.cc", "third_party_authenticator_unittest.cc", "v2_authenticator_unittest.cc", diff --git a/remoting/protocol/fake_datagram_socket.cc b/remoting/protocol/fake_datagram_socket.cc index 53d31d0..010cb36 100644 --- a/remoting/protocol/fake_datagram_socket.cc +++ b/remoting/protocol/fake_datagram_socket.cc @@ -69,6 +69,40 @@ int FakeDatagramSocket::Send(const scoped_refptr<net::IOBuffer>& buf, int buf_len, const net::CompletionCallback& callback) { EXPECT_TRUE(task_runner_->BelongsToCurrentThread()); + EXPECT_FALSE(send_pending_); + + if (async_send_) { + send_pending_ = true; + task_runner_->PostTask( + FROM_HERE, + base::Bind(&FakeDatagramSocket::DoAsyncSend, weak_factory_.GetWeakPtr(), + buf, buf_len, callback)); + return net::ERR_IO_PENDING; + } else { + return DoSend(buf, buf_len); + } +} + +void FakeDatagramSocket::DoAsyncSend(const scoped_refptr<net::IOBuffer>& buf, + int buf_len, + const net::CompletionCallback& callback) { + EXPECT_TRUE(task_runner_->BelongsToCurrentThread()); + + EXPECT_TRUE(send_pending_); + send_pending_ = false; + callback.Run(DoSend(buf, buf_len)); +} + +int FakeDatagramSocket::DoSend(const scoped_refptr<net::IOBuffer>& buf, + int buf_len) { + EXPECT_TRUE(task_runner_->BelongsToCurrentThread()); + + if (next_send_error_ != net::OK) { + int r = next_send_error_; + next_send_error_ = net::OK; + return r; + } + written_packets_.push_back(std::string()); written_packets_.back().assign(buf->data(), buf->data() + buf_len); diff --git a/remoting/protocol/fake_datagram_socket.h b/remoting/protocol/fake_datagram_socket.h index 7270e1d..a22751d 100644 --- a/remoting/protocol/fake_datagram_socket.h +++ b/remoting/protocol/fake_datagram_socket.h @@ -41,6 +41,13 @@ class FakeDatagramSocket : public P2PDatagramSocket { return written_packets_; } + // Enables asynchronous Write(). + void set_async_send(bool async_send) { async_send_ = async_send; } + + // Set error codes for the next Write() call. Once returned the + // value is automatically reset to net::OK . + void set_next_send_error(int error) { next_send_error_ = error; } + void AppendInputPacket(const std::string& data); // Current position in the input in number of packets, i.e. number of finished @@ -62,6 +69,14 @@ class FakeDatagramSocket : public P2PDatagramSocket { private: int CopyReadData(const scoped_refptr<net::IOBuffer>& buf, int buf_len); + void DoAsyncSend(const scoped_refptr<net::IOBuffer>& buf, int buf_len, + const net::CompletionCallback& callback); + int DoSend(const scoped_refptr<net::IOBuffer>& buf, int buf_len); + + bool async_send_ = false; + bool send_pending_ = false; + int next_send_error_ = 0; + base::WeakPtr<FakeDatagramSocket> peer_socket_; scoped_refptr<net::IOBuffer> read_buffer_; diff --git a/remoting/protocol/quic_channel.cc b/remoting/protocol/quic_channel.cc new file mode 100644 index 0000000..66fa39b --- /dev/null +++ b/remoting/protocol/quic_channel.cc @@ -0,0 +1,179 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "remoting/protocol/quic_channel.h" + +#include "base/callback_helpers.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" + +namespace remoting { +namespace protocol { + +static const size_t kNamePrefixLength = 1; +static const size_t kMaxNameLength = 255; + +QuicChannel::QuicChannel(net::QuicP2PStream* stream, + const base::Closure& on_destroyed_callback) + : stream_(stream), on_destroyed_callback_(on_destroyed_callback) { + DCHECK(stream_); + stream_->SetDelegate(this); +} + +QuicChannel::~QuicChannel() { + // Don't call the read callback when destroying the stream. + read_callback_.Reset(); + + on_destroyed_callback_.Run(); + + // The callback must destroy the stream which must result in OnClose(). + DCHECK(!stream_); +} + +int QuicChannel::Read(const scoped_refptr<net::IOBuffer>& buffer, + int buffer_len, + const net::CompletionCallback& callback) { + DCHECK(read_callback_.is_null()); + + if (error_ != net::OK) + return error_; + + if (data_received_.total_bytes() == 0) { + read_buffer_ = buffer; + read_buffer_size_ = buffer_len; + read_callback_ = callback; + return net::ERR_IO_PENDING; + } + + int result = std::min(buffer_len, data_received_.total_bytes()); + data_received_.CopyTo(buffer->data(), result); + data_received_.CropFront(result); + return result; +} + +int QuicChannel::Write(const scoped_refptr<net::IOBuffer>& buffer, + int buffer_len, + const net::CompletionCallback& callback) { + if (error_ != net::OK) + return error_; + + return stream_->Write(base::StringPiece(buffer->data(), buffer_len), + callback); +} + +void QuicChannel::SetName(const std::string& name) { + DCHECK(name_.empty()); + + name_ = name; +} + +void QuicChannel::OnDataReceived(const char* data, int length) { + if (read_callback_.is_null()) { + data_received_.AppendCopyOf(data, length); + return; + } + + DCHECK_EQ(data_received_.total_bytes(), 0); + int bytes_to_read = std::min(length, read_buffer_size_); + memcpy(read_buffer_->data(), data, bytes_to_read); + read_buffer_ = nullptr; + + // Copy leftover data to |data_received_|. + if (length > bytes_to_read) + data_received_.AppendCopyOf(data + bytes_to_read, length - bytes_to_read); + + base::ResetAndReturn(&read_callback_).Run(bytes_to_read); +} + +void QuicChannel::OnClose(net::QuicErrorCode error) { + error_ = (error == net::QUIC_NO_ERROR) ? net::ERR_CONNECTION_CLOSED + : net::ERR_QUIC_PROTOCOL_ERROR; + stream_ = nullptr; + if (!read_callback_.is_null()) { + base::ResetAndReturn(&read_callback_).Run(error_); + } +} + +QuicClientChannel::QuicClientChannel(net::QuicP2PStream* stream, + const base::Closure& on_destroyed_callback, + const std::string& name) + : QuicChannel(stream, on_destroyed_callback) { + CHECK_LE(name.size(), kMaxNameLength); + + SetName(name); + + // Send the name to the host. + stream_->WriteHeader( + std::string(kNamePrefixLength, static_cast<char>(name.size())) + name); +} + +QuicClientChannel::~QuicClientChannel() {} + +QuicServerChannel::QuicServerChannel( + net::QuicP2PStream* stream, + const base::Closure& on_destroyed_callback) + : QuicChannel(stream, on_destroyed_callback) {} + +void QuicServerChannel::ReceiveName( + const base::Closure& name_received_callback) { + name_received_callback_ = name_received_callback; + + // First read 1 byte containing name length. + name_read_buffer_ = new net::DrainableIOBuffer( + new net::IOBuffer(kNamePrefixLength), kNamePrefixLength); + int result = Read(name_read_buffer_, kNamePrefixLength, + base::Bind(&QuicServerChannel::OnNameSizeReadResult, + base::Unretained(this))); + if (result != net::ERR_IO_PENDING) + OnNameSizeReadResult(result); +} + +QuicServerChannel::~QuicServerChannel() {} + +void QuicServerChannel::OnNameSizeReadResult(int result) { + if (result < 0) { + base::ResetAndReturn(&name_received_callback_).Run(); + return; + } + + DCHECK_EQ(result, static_cast<int>(kNamePrefixLength)); + name_length_ = *reinterpret_cast<uint8_t*>(name_read_buffer_->data()); + name_read_buffer_ = + new net::DrainableIOBuffer(new net::IOBuffer(name_length_), name_length_); + ReadNameLoop(0); +} + +void QuicServerChannel::ReadNameLoop(int result) { + while (result >= 0 && name_read_buffer_->BytesRemaining() > 0) { + result = Read(name_read_buffer_, name_read_buffer_->BytesRemaining(), + base::Bind(&QuicServerChannel::OnNameReadResult, + base::Unretained(this))); + if (result >= 0) { + name_read_buffer_->DidConsume(result); + } + } + + if (result < 0 && result != net::ERR_IO_PENDING) { + // Failed to read name for the stream. + base::ResetAndReturn(&name_received_callback_).Run(); + return; + } + + if (name_read_buffer_->BytesRemaining() == 0) { + name_read_buffer_->SetOffset(0); + SetName(std::string(name_read_buffer_->data(), + name_read_buffer_->data() + name_length_)); + base::ResetAndReturn(&name_received_callback_).Run(); + } +} + +void QuicServerChannel::OnNameReadResult(int result) { + if (result > 0) + name_read_buffer_->DidConsume(result); + + ReadNameLoop(result); +} + +} // namespace protocol +} // namespace remoting diff --git a/remoting/protocol/quic_channel.h b/remoting/protocol/quic_channel.h new file mode 100644 index 0000000..64c346a --- /dev/null +++ b/remoting/protocol/quic_channel.h @@ -0,0 +1,104 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef REMOTING_PROTOCOL_QUIC_CHANNEL_H_ +#define REMOTING_PROTOCOL_QUIC_CHANNEL_H_ + +#include "net/quic/p2p/quic_p2p_stream.h" +#include "remoting/base/compound_buffer.h" +#include "remoting/protocol/p2p_stream_socket.h" + +namespace net { +class DrainableIOBuffer; +} // namespace net + +namespace remoting { +namespace protocol { + +// QuicChannel implements P2PStreamSocket interface for a QuicP2PStream. +class QuicChannel : public net::QuicP2PStream::Delegate, + public P2PStreamSocket { + public: + QuicChannel(net::QuicP2PStream* stream, + const base::Closure& on_destroyed_callback); + ~QuicChannel() override; + + const std::string& name() { return name_; } + + // P2PStreamSocket interface. + int Read(const scoped_refptr<net::IOBuffer>& buffer, + int buffer_len, + const net::CompletionCallback& callback) override; + int Write(const scoped_refptr<net::IOBuffer>& buffer, + int buffer_len, + const net::CompletionCallback& callback) override; + + protected: + void SetName(const std::string& name); + + // Owned by QuicSession. + net::QuicP2PStream* stream_; + + private: + // net::QuicP2PStream::Delegate interface. + void OnDataReceived(const char* data, int length) override; + void OnClose(net::QuicErrorCode error) override; + + base::Closure on_destroyed_callback_; + + std::string name_; + + CompoundBuffer data_received_; + + net::CompletionCallback read_callback_; + scoped_refptr<net::IOBuffer> read_buffer_; + int read_buffer_size_ = 0; + + int error_ = 0; + + DISALLOW_COPY_AND_ASSIGN(QuicChannel); +}; + +// Client side of a channel. Sends the |name| specified in the constructor to +// the peer. +class QuicClientChannel : public QuicChannel { + public: + QuicClientChannel(net::QuicP2PStream* stream, + const base::Closure& on_destroyed_callback, + const std::string& name); + ~QuicClientChannel() override; + + private: + DISALLOW_COPY_AND_ASSIGN(QuicClientChannel); +}; + +// Host side of a channel. Receives name from the peer after ReceiveName is +// called. Read() can be called only after the name is received. +class QuicServerChannel : public QuicChannel { + public: + QuicServerChannel(net::QuicP2PStream* stream, + const base::Closure& on_destroyed_callback); + ~QuicServerChannel() override; + + // Must be called after the constructor to receive channel name. + // |name_received_callback| must use QuicChannel::name() to get the name. + // Empty name() indicates failure to receive it. + void ReceiveName(const base::Closure& name_received_callback); + + private: + void OnNameSizeReadResult(int result); + void ReadNameLoop(int result); + void OnNameReadResult(int result); + + base::Closure name_received_callback_; + uint8_t name_length_ = 0; + scoped_refptr<net::DrainableIOBuffer> name_read_buffer_; + + DISALLOW_COPY_AND_ASSIGN(QuicServerChannel); +}; + +} // namespace protocol +} // namespace remoting + +#endif // REMOTING_PROTOCOL_QUIC_CHANNEL_H_ diff --git a/remoting/protocol/quic_channel_factory.cc b/remoting/protocol/quic_channel_factory.cc new file mode 100644 index 0000000..1046e87 --- /dev/null +++ b/remoting/protocol/quic_channel_factory.cc @@ -0,0 +1,536 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "remoting/protocol/quic_channel_factory.h" + +#include <vector> + +#include "base/bind.h" +#include "base/location.h" +#include "base/single_thread_task_runner.h" +#include "base/stl_util.h" +#include "base/thread_task_runner_handle.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/quic/crypto/crypto_framer.h" +#include "net/quic/crypto/crypto_handshake_message.h" +#include "net/quic/crypto/crypto_protocol.h" +#include "net/quic/crypto/quic_random.h" +#include "net/quic/p2p/quic_p2p_crypto_config.h" +#include "net/quic/p2p/quic_p2p_session.h" +#include "net/quic/p2p/quic_p2p_stream.h" +#include "net/quic/quic_clock.h" +#include "net/quic/quic_connection_helper.h" +#include "net/quic/quic_default_packet_writer.h" +#include "net/socket/stream_socket.h" +#include "remoting/base/constants.h" +#include "remoting/protocol/datagram_channel_factory.h" +#include "remoting/protocol/p2p_datagram_socket.h" +#include "remoting/protocol/quic_channel.h" + +namespace remoting { +namespace protocol { + +namespace { + +// The maximum receive window sizes for QUIC sessions and streams. These are +// the same values that are used in chrome. +const int kQuicSessionMaxRecvWindowSize = 15 * 1024 * 1024; // 15 MB +const int kQuicStreamMaxRecvWindowSize = 6 * 1024 * 1024; // 6 MB + +class P2PQuicPacketWriter : public net::QuicPacketWriter { + public: + P2PQuicPacketWriter(net::QuicConnection* connection, + P2PDatagramSocket* socket) + : connection_(connection), socket_(socket), weak_factory_(this) {} + ~P2PQuicPacketWriter() override {} + + // QuicPacketWriter interface. + net::WriteResult WritePacket(const char* buffer, + size_t buf_len, + const net::IPAddressNumber& self_address, + const net::IPEndPoint& peer_address) override { + DCHECK(!write_blocked_); + + scoped_refptr<net::StringIOBuffer> buf( + new net::StringIOBuffer(std::string(buffer, buf_len))); + int result = socket_->Send(buf, buf_len, + base::Bind(&P2PQuicPacketWriter::OnSendComplete, + weak_factory_.GetWeakPtr())); + net::WriteStatus status = net::WRITE_STATUS_OK; + if (result < 0) { + if (result == net::ERR_IO_PENDING) { + status = net::WRITE_STATUS_BLOCKED; + write_blocked_ = true; + } else { + status = net::WRITE_STATUS_ERROR; + } + } + + return net::WriteResult(status, result); + } + bool IsWriteBlockedDataBuffered() const override { + // P2PDatagramSocket::Send() method buffer the data until the Send is + // unblocked. + return true; + } + bool IsWriteBlocked() const override { return write_blocked_; } + void SetWritable() override { write_blocked_ = false; } + + private: + void OnSendComplete(int result){ + DCHECK_NE(result, net::ERR_IO_PENDING); + write_blocked_ = false; + if (result < 0) { + connection_->OnWriteError(result); + } + connection_->OnCanWrite(); + } + + net::QuicConnection* connection_; + P2PDatagramSocket* socket_; + + // Whether a write is currently in flight. + bool write_blocked_ = false; + + base::WeakPtrFactory<P2PQuicPacketWriter> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(P2PQuicPacketWriter); +}; + +class QuicPacketWriterFactory + : public net::QuicConnection::PacketWriterFactory { + public: + explicit QuicPacketWriterFactory(P2PDatagramSocket* socket) + : socket_(socket) {} + ~QuicPacketWriterFactory() override {} + + net::QuicPacketWriter* Create( + net::QuicConnection* connection) const override { + return new P2PQuicPacketWriter(connection, socket_); + } + + private: + P2PDatagramSocket* socket_; +}; + +class P2PDatagramSocketAdapter : public net::Socket { + public: + explicit P2PDatagramSocketAdapter(scoped_ptr<P2PDatagramSocket> socket) + : socket_(socket.Pass()) {} + ~P2PDatagramSocketAdapter() override {} + + int Read(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) override { + return socket_->Recv(buf, buf_len, callback); + } + int Write(net::IOBuffer* buf, int buf_len, + const net::CompletionCallback& callback) override { + return socket_->Send(buf, buf_len, callback); + } + + int SetReceiveBufferSize(int32_t size) override { + NOTREACHED(); + return net::ERR_FAILED; + } + + int SetSendBufferSize(int32_t size) override { + NOTREACHED(); + return net::ERR_FAILED; + } + + private: + scoped_ptr<P2PDatagramSocket> socket_; +}; + +} // namespace + +class QuicChannelFactory::Core : public net::QuicP2PSession::Delegate { + public: + Core(const std::string& session_id, bool is_server); + virtual ~Core(); + + // Called from ~QuicChannelFactory() to synchronously release underlying + // socket. Core is destroyed later asynchronously. + void Close(); + + // Implementation of all all methods for QuicChannelFactory. + const std::string& CreateSessionInitiateConfigMessage(); + bool ProcessSessionAcceptConfigMessage(const std::string& message); + + bool ProcessSessionInitiateConfigMessage(const std::string& message); + const std::string& CreateSessionAcceptConfigMessage(); + + void Start(DatagramChannelFactory* factory, const std::string& shared_secret); + + void CreateChannel(const std::string& name, + const ChannelCreatedCallback& callback); + void CancelChannelCreation(const std::string& name); + + private: + friend class QuicChannelFactory; + + struct PendingChannel { + PendingChannel(const std::string& name, + const ChannelCreatedCallback& callback) + : name(name), callback(callback) {} + + std::string name; + ChannelCreatedCallback callback; + }; + + // QuicP2PSession::Delegate interface. + void OnIncomingStream(net::QuicP2PStream* stream) override; + void OnConnectionClosed(net::QuicErrorCode error) override; + + void OnBaseChannelReady(scoped_ptr<P2PDatagramSocket> socket); + + void OnNameReceived(QuicChannel* channel); + + void OnChannelDestroyed(int stream_id); + + std::string session_id_; + bool is_server_; + DatagramChannelFactory* base_channel_factory_ = nullptr; + + net::QuicConfig quic_config_; + std::string shared_secret_; + std::string session_initiate_quic_config_message_; + std::string session_accept_quic_config_message_; + + net::QuicClock quic_clock_; + net::QuicConnectionHelper quic_helper_; + scoped_ptr<net::QuicP2PSession> quic_session_; + bool connected_ = false; + + std::vector<PendingChannel*> pending_channels_; + std::vector<QuicChannel*> unnamed_incoming_channels_; + + base::WeakPtrFactory<Core> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(Core); +}; + +QuicChannelFactory::Core::Core(const std::string& session_id, bool is_server) + : session_id_(session_id), + is_server_(is_server), + quic_helper_(base::ThreadTaskRunnerHandle::Get().get(), + &quic_clock_, + net::QuicRandom::GetInstance()), + weak_factory_(this) { + quic_config_.SetInitialSessionFlowControlWindowToSend( + kQuicSessionMaxRecvWindowSize); + quic_config_.SetInitialStreamFlowControlWindowToSend( + kQuicStreamMaxRecvWindowSize); +} + +QuicChannelFactory::Core::~Core() {} + +void QuicChannelFactory::Core::Close() { + DCHECK(pending_channels_.empty()); + + // Cancel creation of the base channel if it hasn't finished. + if (base_channel_factory_) + base_channel_factory_->CancelChannelCreation(kQuicChannelName); + + if (quic_session_ && quic_session_->connection()->connected()) + quic_session_->connection()->CloseConnection(net::QUIC_NO_ERROR, false); + + DCHECK(unnamed_incoming_channels_.empty()); +} + +void QuicChannelFactory::Core::Start(DatagramChannelFactory* factory, + const std::string& shared_secret) { + base_channel_factory_ = factory; + shared_secret_ = shared_secret; + + base_channel_factory_->CreateChannel( + kQuicChannelName, + base::Bind(&Core::OnBaseChannelReady, weak_factory_.GetWeakPtr())); +} + +const std::string& +QuicChannelFactory::Core::CreateSessionInitiateConfigMessage() { + DCHECK(!is_server_); + + net::CryptoHandshakeMessage handshake_message; + handshake_message.set_tag(net::kCHLO); + quic_config_.ToHandshakeMessage(&handshake_message); + + session_initiate_quic_config_message_ = + handshake_message.GetSerialized().AsStringPiece().as_string(); + return session_initiate_quic_config_message_; +} + +bool QuicChannelFactory::Core::ProcessSessionAcceptConfigMessage( + const std::string& message) { + DCHECK(!is_server_); + + session_accept_quic_config_message_ = message; + + scoped_ptr<net::CryptoHandshakeMessage> parsed_message( + net::CryptoFramer::ParseMessage(message)); + if (!parsed_message) { + LOG(ERROR) << "Received invalid QUIC config."; + return false; + } + + if (parsed_message->tag() != net::kSHLO) { + LOG(ERROR) << "Received QUIC handshake message with unexpected tag " + << parsed_message->tag(); + return false; + } + + std::string error_message; + net::QuicErrorCode error = quic_config_.ProcessPeerHello( + *parsed_message, net::SERVER, &error_message); + if (error != net::QUIC_NO_ERROR) { + LOG(ERROR) << "Failed to process QUIC handshake message: " + << error_message; + return false; + } + + return true; +} + +bool QuicChannelFactory::Core::ProcessSessionInitiateConfigMessage( + const std::string& message) { + DCHECK(is_server_); + + session_initiate_quic_config_message_ = message; + + scoped_ptr<net::CryptoHandshakeMessage> parsed_message( + net::CryptoFramer::ParseMessage(message)); + if (!parsed_message) { + LOG(ERROR) << "Received invalid QUIC config."; + return false; + } + + if (parsed_message->tag() != net::kCHLO) { + LOG(ERROR) << "Received QUIC handshake message with unexpected tag " + << parsed_message->tag(); + return false; + } + + std::string error_message; + net::QuicErrorCode error = quic_config_.ProcessPeerHello( + *parsed_message, net::CLIENT, &error_message); + if (error != net::QUIC_NO_ERROR) { + LOG(ERROR) << "Failed to process QUIC handshake message: " + << error_message; + return false; + } + + return true; +} + +const std::string& +QuicChannelFactory::Core::CreateSessionAcceptConfigMessage() { + DCHECK(is_server_); + + if (session_initiate_quic_config_message_.empty()) { + // Don't send quic-config to the client if the client didn't include the + // config in the session-initiate message. + DCHECK(session_accept_quic_config_message_.empty()); + return session_accept_quic_config_message_; + } + + net::CryptoHandshakeMessage handshake_message; + handshake_message.set_tag(net::kSHLO); + quic_config_.ToHandshakeMessage(&handshake_message); + + session_accept_quic_config_message_ = + handshake_message.GetSerialized().AsStringPiece().as_string(); + return session_accept_quic_config_message_; +} + +// StreamChannelFactory interface. +void QuicChannelFactory::Core::CreateChannel( + const std::string& name, + const ChannelCreatedCallback& callback) { + if (quic_session_ && quic_session_->connection()->connected()) { + if (!is_server_) { + net::QuicP2PStream* stream = quic_session_->CreateOutgoingDynamicStream(); + scoped_ptr<QuicChannel> channel(new QuicClientChannel( + stream, base::Bind(&Core::OnChannelDestroyed, base::Unretained(this), + stream->id()), + name)); + callback.Run(channel.Pass()); + } else { + // On the server side wait for the client to create a QUIC stream and + // send the name. The channel will be connected in OnNameReceived(). + pending_channels_.push_back(new PendingChannel(name, callback)); + } + } else if (!base_channel_factory_) { + // Fail synchronously if we failed to connect transport. + callback.Run(nullptr); + } else { + // Still waiting for the transport. + pending_channels_.push_back(new PendingChannel(name, callback)); + } +} + +void QuicChannelFactory::Core::CancelChannelCreation(const std::string& name) { + for (auto it = pending_channels_.begin(); it != pending_channels_.end(); + ++it) { + if ((*it)->name == name) { + delete *it; + pending_channels_.erase(it); + return; + } + } +} + +void QuicChannelFactory::Core::OnBaseChannelReady( + scoped_ptr<P2PDatagramSocket> socket) { + base_channel_factory_ = nullptr; + + // Failed to connect underlying transport connection. Fail all pending + // channel. + if (!socket) { + while (!pending_channels_.empty()) { + scoped_ptr<PendingChannel> pending_channel(pending_channels_.front()); + pending_channels_.erase(pending_channels_.begin()); + pending_channel->callback.Run(nullptr); + } + return; + } + + QuicPacketWriterFactory writer_factory(socket.get()); + net::IPAddressNumber ip(net::kIPv4AddressSize, 0); + scoped_ptr<net::QuicConnection> quic_connection(new net::QuicConnection( + 0, net::IPEndPoint(ip, 0), &quic_helper_, writer_factory, + true /* owns_writer */, + is_server_ ? net::Perspective::IS_SERVER : net::Perspective::IS_CLIENT, + true /* is_secure */, net::QuicSupportedVersions())); + + net::QuicP2PCryptoConfig quic_crypto_config(shared_secret_); + quic_crypto_config.set_hkdf_input_suffix( + session_id_ + "\0" + kQuicChannelName + + session_initiate_quic_config_message_ + + session_accept_quic_config_message_); + + quic_session_.reset(new net::QuicP2PSession( + quic_config_, quic_crypto_config, quic_connection.Pass(), + make_scoped_ptr(new P2PDatagramSocketAdapter(socket.Pass())))); + quic_session_->SetDelegate(this); + quic_session_->Initialize(); + + if (!is_server_) { + // On the client create streams for all pending channels and send a name for + // each channel. + while (!pending_channels_.empty()) { + scoped_ptr<PendingChannel> pending_channel(pending_channels_.front()); + pending_channels_.erase(pending_channels_.begin()); + + net::QuicP2PStream* stream = quic_session_->CreateOutgoingDynamicStream(); + scoped_ptr<QuicChannel> channel(new QuicClientChannel( + stream, base::Bind(&Core::OnChannelDestroyed, base::Unretained(this), + stream->id()), + pending_channel->name)); + pending_channel->callback.Run(channel.Pass()); + } + } +} + +void QuicChannelFactory::Core::OnIncomingStream(net::QuicP2PStream* stream) { + QuicServerChannel* channel = new QuicServerChannel( + stream, base::Bind(&Core::OnChannelDestroyed, base::Unretained(this), + stream->id())); + unnamed_incoming_channels_.push_back(channel); + channel->ReceiveName( + base::Bind(&Core::OnNameReceived, base::Unretained(this), channel)); +} + +void QuicChannelFactory::Core::OnConnectionClosed(net::QuicErrorCode error) { + if (error != net::QUIC_NO_ERROR) + LOG(ERROR) << "QUIC connection was closed, error_code=" << error; + + while (!pending_channels_.empty()) { + scoped_ptr<PendingChannel> pending_channel(pending_channels_.front()); + pending_channels_.erase(pending_channels_.begin()); + pending_channel->callback.Run(nullptr); + } +} + +void QuicChannelFactory::Core::OnNameReceived(QuicChannel* channel) { + DCHECK(is_server_); + + scoped_ptr<QuicChannel> owned_channel(channel); + + auto it = std::find(unnamed_incoming_channels_.begin(), + unnamed_incoming_channels_.end(), channel); + DCHECK(it != unnamed_incoming_channels_.end()); + unnamed_incoming_channels_.erase(it); + + if (channel->name().empty()) { + // Failed to read a name for incoming channel. + return; + } + + for (auto it = pending_channels_.begin(); + it != pending_channels_.end(); ++it) { + if ((*it)->name == channel->name()) { + scoped_ptr<PendingChannel> pending_channel(*it); + pending_channels_.erase(it); + pending_channel->callback.Run(owned_channel.Pass()); + return; + } + } + + LOG(ERROR) << "Unexpected incoming channel: " << channel->name(); +} + +void QuicChannelFactory::Core::OnChannelDestroyed(int stream_id) { + if (quic_session_) + quic_session_->CloseStream(stream_id); +} + +QuicChannelFactory::QuicChannelFactory(const std::string& session_id, + bool is_server) + : core_(new Core(session_id, is_server)) {} + +QuicChannelFactory::~QuicChannelFactory() { + core_->Close(); + base::ThreadTaskRunnerHandle::Get()->DeleteSoon(FROM_HERE, core_.release()); +} + +const std::string& QuicChannelFactory::CreateSessionInitiateConfigMessage() { + return core_->CreateSessionInitiateConfigMessage(); +} + +bool QuicChannelFactory::ProcessSessionAcceptConfigMessage( + const std::string& message) { + return core_->ProcessSessionAcceptConfigMessage(message); +} + +bool QuicChannelFactory::ProcessSessionInitiateConfigMessage( + const std::string& message) { + return core_->ProcessSessionInitiateConfigMessage(message); +} + +const std::string& QuicChannelFactory::CreateSessionAcceptConfigMessage() { + return core_->CreateSessionAcceptConfigMessage(); +} + +void QuicChannelFactory::Start(DatagramChannelFactory* factory, + const std::string& shared_secret) { + core_->Start(factory, shared_secret); +} + +void QuicChannelFactory::CreateChannel(const std::string& name, + const ChannelCreatedCallback& callback) { + core_->CreateChannel(name, callback); +} + +void QuicChannelFactory::CancelChannelCreation(const std::string& name) { + core_->CancelChannelCreation(name); +} + +net::QuicP2PSession* QuicChannelFactory::GetP2PSessionForTests() { + return core_->quic_session_.get(); +} + +} // namespace protocol +} // namespace remoting diff --git a/remoting/protocol/quic_channel_factory.h b/remoting/protocol/quic_channel_factory.h new file mode 100644 index 0000000..fd12f4c --- /dev/null +++ b/remoting/protocol/quic_channel_factory.h @@ -0,0 +1,58 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef REMOTING_PROTOCOL_QUIC_CHANNEL_FACTORY_H_ +#define REMOTING_PROTOCOL_QUIC_CHANNEL_FACTORY_H_ + +#include "base/memory/scoped_ptr.h" +#include "remoting/protocol/stream_channel_factory.h" + +namespace net { +class QuicP2PSession; +} // namespace net + +namespace remoting { +namespace protocol { + +class DatagramChannelFactory; + +// QuicChannelFactory is responsible for QUIC connection between client and +// host. +class QuicChannelFactory : public StreamChannelFactory { + public: + QuicChannelFactory(const std::string& session_id, bool is_server); + ~QuicChannelFactory() override; + + // QuicConfig handshake handlers for the client side. + const std::string& CreateSessionInitiateConfigMessage(); + bool ProcessSessionAcceptConfigMessage(const std::string& message); + + // QuicConfig handshake handlers for the server side. + bool ProcessSessionInitiateConfigMessage(const std::string& message); + const std::string& CreateSessionAcceptConfigMessage(); + + // Creates a QUIC connection using a datagram channel created using |factory|. + // Must be called after successful handshake using the methods above. + // |shared_secret| must contain the shared key generated by the authentication + // handshake. + void Start(DatagramChannelFactory* factory, const std::string& shared_secret); + + // StreamChannelFactory interface. + void CreateChannel(const std::string& name, + const ChannelCreatedCallback& callback) override; + void CancelChannelCreation(const std::string& name) override; + + net::QuicP2PSession* GetP2PSessionForTests(); + + private: + class Core; + scoped_ptr<Core> core_; + + DISALLOW_COPY_AND_ASSIGN(QuicChannelFactory); +}; + +} // namespace protocol +} // namespace remoting + +#endif // REMOTING_PROTOCOL_QUIC_CHANNEL_FACTORY_H_ diff --git a/remoting/protocol/quic_channel_factory_unittest.cc b/remoting/protocol/quic_channel_factory_unittest.cc new file mode 100644 index 0000000..e860664 --- /dev/null +++ b/remoting/protocol/quic_channel_factory_unittest.cc @@ -0,0 +1,353 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "remoting/protocol/quic_channel_factory.h" + +#include "base/bind.h" +#include "base/message_loop/message_loop.h" +#include "base/run_loop.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" +#include "net/quic/p2p/quic_p2p_session.h" +#include "net/quic/p2p/quic_p2p_stream.h" +#include "net/socket/socket.h" +#include "remoting/base/constants.h" +#include "remoting/protocol/connection_tester.h" +#include "remoting/protocol/fake_datagram_socket.h" +#include "remoting/protocol/p2p_stream_socket.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +using testing::_; +using testing::AtMost; +using testing::InvokeWithoutArgs; + +namespace remoting { +namespace protocol { + +namespace { + +const int kMessageSize = 1024; +const int kMessages = 100; + +const char kTestChannelName[] = "test"; +const char kTestChannelName2[] = "test2"; + +} // namespace + +class QuicChannelFactoryTest : public testing::Test, + public testing::WithParamInterface<bool> { + public: + void DeleteAll() { + host_channel1_.reset(); + host_channel2_.reset(); + client_channel1_.reset(); + client_channel2_.reset(); + host_quic_.reset(); + client_quic_.reset(); + } + + void FailedReadDeleteAll(int result) { + EXPECT_NE(net::OK, result); + DeleteAll(); + } + + void OnChannelConnected(scoped_ptr<P2PStreamSocket>* storage, + int* counter, + base::RunLoop* run_loop, + scoped_ptr<P2PStreamSocket> socket) { + *storage = socket.Pass(); + if (counter) { + --(*counter); + EXPECT_GE(*counter, 0); + if (*counter == 0) + run_loop->Quit(); + } + } + + void OnChannelConnectedExpectFail(scoped_ptr<P2PStreamSocket> socket) { + EXPECT_FALSE(socket); + host_quic_->CancelChannelCreation(kTestChannelName2); + DeleteAll(); + } + + void OnChannelConnectedNotReached(scoped_ptr<P2PStreamSocket> socket) { + NOTREACHED(); + } + + protected: + void TearDown() override { + DeleteAll(); + // QuicChannelFactory destroys the internals asynchronously. Run all pending + // tasks to avoid leaking memory. + base::RunLoop().RunUntilIdle(); + } + + void Initialize() { + host_base_channel_factory_.PairWith(&client_base_channel_factory_); + host_base_channel_factory_.set_asynchronous_create(GetParam()); + client_base_channel_factory_.set_asynchronous_create(GetParam()); + + const char kTestSessionId[] = "123123"; + host_quic_.reset(new QuicChannelFactory(kTestSessionId, true)); + client_quic_.reset(new QuicChannelFactory(kTestSessionId, false)); + + std::string message = client_quic_->CreateSessionInitiateConfigMessage(); + EXPECT_TRUE(host_quic_->ProcessSessionInitiateConfigMessage(message)); + message = host_quic_->CreateSessionAcceptConfigMessage(); + EXPECT_TRUE(client_quic_->ProcessSessionAcceptConfigMessage(message)); + + const char kTestSharedSecret[] = "Shared Secret"; + host_quic_->Start(&host_base_channel_factory_, kTestSharedSecret); + client_quic_->Start(&client_base_channel_factory_, kTestSharedSecret); + + FakeDatagramSocket* host_base_channel = + host_base_channel_factory_.GetFakeChannel(kQuicChannelName); + if (host_base_channel) + host_base_channel->set_async_send(GetParam()); + + FakeDatagramSocket* client_base_channel = + client_base_channel_factory_.GetFakeChannel(kQuicChannelName); + if (client_base_channel) + client_base_channel->set_async_send(GetParam()); + } + + void CreateChannel(const std::string& name, + scoped_ptr<P2PStreamSocket>* host_channel, + scoped_ptr<P2PStreamSocket>* client_channel) { + int counter = 2; + base::RunLoop run_loop; + host_quic_->CreateChannel( + name, + base::Bind(&QuicChannelFactoryTest::OnChannelConnected, + base::Unretained(this), host_channel, &counter, &run_loop)); + client_quic_->CreateChannel( + name, base::Bind(&QuicChannelFactoryTest::OnChannelConnected, + base::Unretained(this), client_channel, &counter, + &run_loop)); + + run_loop.Run(); + + EXPECT_TRUE(host_channel->get()); + EXPECT_TRUE(client_channel->get()); + } + + scoped_refptr<net::IOBufferWithSize> CreateTestBuffer(int size) { + scoped_refptr<net::IOBufferWithSize> result = + new net::IOBufferWithSize(size); + for (int i = 0; i < size; ++i) { + result->data()[i] = rand() % 256; + } + return result; + } + + base::MessageLoop message_loop_; + + FakeDatagramChannelFactory host_base_channel_factory_; + FakeDatagramChannelFactory client_base_channel_factory_; + + scoped_ptr<QuicChannelFactory> host_quic_; + scoped_ptr<QuicChannelFactory> client_quic_; + + scoped_ptr<P2PStreamSocket> host_channel1_; + scoped_ptr<P2PStreamSocket> client_channel1_; + scoped_ptr<P2PStreamSocket> host_channel2_; + scoped_ptr<P2PStreamSocket> client_channel2_; +}; + +INSTANTIATE_TEST_CASE_P(SyncWrite, + QuicChannelFactoryTest, + ::testing::Values(false)); +INSTANTIATE_TEST_CASE_P(AsyncWrite, + QuicChannelFactoryTest, + ::testing::Values(true)); + +TEST_P(QuicChannelFactoryTest, OneChannel) { + Initialize(); + + scoped_ptr<P2PStreamSocket> host_channel; + scoped_ptr<P2PStreamSocket> client_channel; + ASSERT_NO_FATAL_FAILURE( + CreateChannel(kTestChannelName, &host_channel, &client_channel)); + + StreamConnectionTester tester(host_channel.get(), client_channel.get(), + kMessageSize, kMessages); + tester.Start(); + message_loop_.Run(); + tester.CheckResults(); +} + +TEST_P(QuicChannelFactoryTest, TwoChannels) { + Initialize(); + + scoped_ptr<P2PStreamSocket> host_channel1_; + scoped_ptr<P2PStreamSocket> client_channel1_; + ASSERT_NO_FATAL_FAILURE( + CreateChannel(kTestChannelName, &host_channel1_, &client_channel1_)); + + scoped_ptr<P2PStreamSocket> host_channel2_; + scoped_ptr<P2PStreamSocket> client_channel2_; + ASSERT_NO_FATAL_FAILURE( + CreateChannel(kTestChannelName2, &host_channel2_, &client_channel2_)); + + StreamConnectionTester tester1(host_channel1_.get(), client_channel1_.get(), + kMessageSize, kMessages); + StreamConnectionTester tester2(host_channel2_.get(), client_channel2_.get(), + kMessageSize, kMessages); + tester1.Start(); + tester2.Start(); + while (!tester1.done() || !tester2.done()) { + message_loop_.Run(); + } + tester1.CheckResults(); + tester2.CheckResults(); +} + +TEST_P(QuicChannelFactoryTest, SendFail) { + Initialize(); + + scoped_ptr<P2PStreamSocket> host_channel1_; + scoped_ptr<P2PStreamSocket> client_channel1_; + ASSERT_NO_FATAL_FAILURE( + CreateChannel(kTestChannelName, &host_channel1_, &client_channel1_)); + + scoped_ptr<P2PStreamSocket> host_channel2_; + scoped_ptr<P2PStreamSocket> client_channel2_; + ASSERT_NO_FATAL_FAILURE( + CreateChannel(kTestChannelName2, &host_channel2_, &client_channel2_)); + + host_base_channel_factory_.GetFakeChannel(kQuicChannelName) + ->set_next_send_error(net::ERR_FAILED); + + scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); + + + // Try writing to a channel. This should result in all stream being closed due + // to an error. + { + net::TestCompletionCallback write_cb_1; + host_channel1_->Write(buf.get(), buf->size(), write_cb_1.callback()); + base::RunLoop().RunUntilIdle(); + } + + // Repeated attempt to write should result in an error. + { + net::TestCompletionCallback write_cb_1; + net::TestCompletionCallback write_cb_2; + EXPECT_NE(net::OK, host_channel1_->Write(buf.get(), buf->size(), + write_cb_1.callback())); + EXPECT_FALSE(write_cb_1.have_result()); + EXPECT_NE(net::OK, host_channel1_->Write(buf.get(), buf->size(), + write_cb_2.callback())); + EXPECT_FALSE(write_cb_2.have_result()); + } +} + +TEST_P(QuicChannelFactoryTest, DeleteWhenFailed) { + Initialize(); + + ASSERT_NO_FATAL_FAILURE( + CreateChannel(kTestChannelName, &host_channel1_, &client_channel1_)); + ASSERT_NO_FATAL_FAILURE( + CreateChannel(kTestChannelName2, &host_channel2_, &client_channel2_)); + + host_base_channel_factory_.GetFakeChannel(kQuicChannelName) + ->set_next_send_error(net::ERR_FAILED); + + scoped_refptr<net::IOBufferWithSize> read_buf = + new net::IOBufferWithSize(100); + + EXPECT_EQ(net::ERR_IO_PENDING, + host_channel1_->Read( + read_buf.get(), read_buf->size(), + base::Bind(&QuicChannelFactoryTest::FailedReadDeleteAll, + base::Unretained(this)))); + + // Try writing to a channel. This should result it DeleteAll() called and the + // connection torn down. + scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); + net::TestCompletionCallback write_cb_1; + host_channel1_->Write(buf.get(), buf->size(), write_cb_1.callback()); + + base::RunLoop().RunUntilIdle(); + + // Check that the connection was torn down. + EXPECT_FALSE(host_quic_); +} + +TEST_P(QuicChannelFactoryTest, SessionFail) { + host_base_channel_factory_.set_fail_create(true); + Initialize(); + + host_quic_->CreateChannel( + kTestChannelName, + base::Bind(&QuicChannelFactoryTest::OnChannelConnectedExpectFail, + base::Unretained(this))); + + // host_quic_ may be destroyed at this point in sync mode. + if (host_quic_) { + host_quic_->CreateChannel( + kTestChannelName2, + base::Bind(&QuicChannelFactoryTest::OnChannelConnectedNotReached, + base::Unretained(this))); + } + + base::RunLoop().RunUntilIdle(); + + // Check that DeleteAll() was called and the connection was torn down. + EXPECT_FALSE(host_quic_); +} + +// Verify that the host just ignores incoming stream with unexpected name. +TEST_P(QuicChannelFactoryTest, UnknownName) { + Initialize(); + + // Create a new channel from the client side. + client_quic_->CreateChannel( + kTestChannelName, base::Bind(&QuicChannelFactoryTest::OnChannelConnected, + base::Unretained(this), &client_channel1_, + nullptr, nullptr)); + base::RunLoop().RunUntilIdle(); + + EXPECT_EQ(0U, host_quic_->GetP2PSessionForTests()->GetNumOpenStreams()); +} + +// Verify that incoming streams that have received only partial name are +// destroyed correctly. +TEST_P(QuicChannelFactoryTest, SendPartialName) { + Initialize(); + + base::RunLoop().RunUntilIdle(); + + net::QuicP2PSession* session = client_quic_->GetP2PSessionForTests(); + net::QuicP2PStream* stream = session->CreateOutgoingDynamicStream(); + + std::string name = kTestChannelName; + // Send only half of the name to the host. + stream->WriteHeader(std::string(1, static_cast<char>(name.size())) + + name.substr(0, name.size() / 2)); + + base::RunLoop().RunUntilIdle(); + + // Host should have received the new stream and is still waiting for the name. + EXPECT_EQ(1U, host_quic_->GetP2PSessionForTests()->GetNumOpenStreams()); + + session->CloseStream(stream->id()); + base::RunLoop().RunUntilIdle(); + + // Verify that the stream was closed on the host side. + EXPECT_EQ(0U, host_quic_->GetP2PSessionForTests()->GetNumOpenStreams()); + + // Create another stream with only partial name and tear down connection while + // it's still pending. + stream = session->CreateOutgoingDynamicStream(); + stream->WriteHeader(std::string(1, static_cast<char>(name.size())) + + name.substr(0, name.size() / 2)); + base::RunLoop().RunUntilIdle(); + EXPECT_EQ(1U, host_quic_->GetP2PSessionForTests()->GetNumOpenStreams()); +} + +} // namespace protocol +} // namespace remoting |