diff options
author | sergeyu@chromium.org <sergeyu@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2012-08-08 02:03:10 +0000 |
---|---|---|
committer | sergeyu@chromium.org <sergeyu@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2012-08-08 02:03:10 +0000 |
commit | dfcc8927737e14edd5890c509ae113794435a470 (patch) | |
tree | 8705cd5af8968a9fa995246df6f23275aa226a5c | |
parent | d7f7f753ace38ad83299ba7b53aa85f373849c91 (diff) | |
download | chromium_src-dfcc8927737e14edd5890c509ae113794435a470.zip chromium_src-dfcc8927737e14edd5890c509ae113794435a470.tar.gz chromium_src-dfcc8927737e14edd5890c509ae113794435a470.tar.bz2 |
Implement ChannelMultiplexer.
ChannelMultiplexer allows multiple logical channels to share a
single underlying transport channel.
BUG=137135
Review URL: https://chromiumcodereview.appspot.com/10830046
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@150484 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r-- | remoting/host/server_log_entry.cc | 1 | ||||
-rw-r--r-- | remoting/proto/chromotocol.gyp | 1 | ||||
-rw-r--r-- | remoting/proto/mux.proto | 27 | ||||
-rw-r--r-- | remoting/protocol/buffered_socket_writer.cc | 4 | ||||
-rw-r--r-- | remoting/protocol/channel_factory.h | 59 | ||||
-rw-r--r-- | remoting/protocol/channel_multiplexer.cc | 513 | ||||
-rw-r--r-- | remoting/protocol/channel_multiplexer.h | 88 | ||||
-rw-r--r-- | remoting/protocol/channel_multiplexer_unittest.cc | 301 | ||||
-rw-r--r-- | remoting/protocol/connection_tester.h | 1 | ||||
-rw-r--r-- | remoting/protocol/session.h | 31 | ||||
-rw-r--r-- | remoting/remoting.gyp | 7 |
11 files changed, 1001 insertions, 32 deletions
diff --git a/remoting/host/server_log_entry.cc b/remoting/host/server_log_entry.cc index 98f39ee..dabcacf 100644 --- a/remoting/host/server_log_entry.cc +++ b/remoting/host/server_log_entry.cc @@ -4,6 +4,7 @@ #include "remoting/host/server_log_entry.h" +#include "base/logging.h" #include "base/sys_info.h" #include "remoting/base/constants.h" #include "remoting/protocol/session.h" diff --git a/remoting/proto/chromotocol.gyp b/remoting/proto/chromotocol.gyp index 98077ce..baf8c40 100644 --- a/remoting/proto/chromotocol.gyp +++ b/remoting/proto/chromotocol.gyp @@ -15,6 +15,7 @@ 'control.proto', 'event.proto', 'internal.proto', + 'mux.proto', 'video.proto', ], 'variables': { diff --git a/remoting/proto/mux.proto b/remoting/proto/mux.proto new file mode 100644 index 0000000..ff0a8f6 --- /dev/null +++ b/remoting/proto/mux.proto @@ -0,0 +1,27 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// Protocol for the mux channel that multiplexes multiple channels. + +syntax = "proto2"; + +option optimize_for = LITE_RUNTIME; + +package remoting.protocol; + +message MultiplexPacket { + // Channel ID. Each peer choses this value when it sends first packet to + // the other peer. It unique identified channel this packet belongs to. + // Channel ID is direction-specific, i.e. each channel has two IDs + // assigned to it: one for receiving and one for sending. + optional int32 channel_id = 1; + + // Channel name. The name is used to identify channels before channel ID + // is assigned in the first message. This value must be included only + // in the first packet for a given channel. All other packets must be + // identified using channel ID. + optional string channel_name = 2; + + optional bytes data = 3; +} diff --git a/remoting/protocol/buffered_socket_writer.cc b/remoting/protocol/buffered_socket_writer.cc index 178f14a..de11356 100644 --- a/remoting/protocol/buffered_socket_writer.cc +++ b/remoting/protocol/buffered_socket_writer.cc @@ -55,7 +55,9 @@ bool BufferedSocketWriterBase::Write( buffer_size_ += data->size(); DoWrite(); - return true; + + // DoWrite() may trigger OnWriteError() to be called. + return !closed_; } void BufferedSocketWriterBase::DoWrite() { diff --git a/remoting/protocol/channel_factory.h b/remoting/protocol/channel_factory.h new file mode 100644 index 0000000..7741e5d --- /dev/null +++ b/remoting/protocol/channel_factory.h @@ -0,0 +1,59 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef REMOTING_PROTOCOL_CHANNEL_FACTORY_H_ +#define REMOTING_PROTOCOL_CHANNEL_FACTORY_H_ + +#include "base/callback.h" +#include "base/memory/scoped_ptr.h" +#include "base/threading/non_thread_safe.h" + +namespace net { +class Socket; +class StreamSocket; +} // namespace net + +namespace remoting { +namespace protocol { + +class ChannelFactory : public base::NonThreadSafe { + public: + // TODO(sergeyu): Specify connection error code when channel + // connection fails. + typedef base::Callback<void(scoped_ptr<net::StreamSocket>)> + StreamChannelCallback; + typedef base::Callback<void(scoped_ptr<net::Socket>)> + DatagramChannelCallback; + + ChannelFactory() {} + + // Creates new channels for this connection. The specified callback is called + // when then new channel is created and connected. The callback is called with + // NULL if connection failed for any reason. Callback may be called + // synchronously, before the call returns. All channels must be destroyed + // before the factory is destroyed and CancelChannelCreation() must be called + // to cancel creation of channels for which the |callback| hasn't been called + // yet. + virtual void CreateStreamChannel( + const std::string& name, const StreamChannelCallback& callback) = 0; + virtual void CreateDatagramChannel( + const std::string& name, const DatagramChannelCallback& callback) = 0; + + // Cancels a pending CreateStreamChannel() or CreateDatagramChannel() + // operation for the named channel. If the channel creation already + // completed then canceling it has no effect. When shutting down + // this method must be called for each channel pending creation. + virtual void CancelChannelCreation(const std::string& name) = 0; + + protected: + virtual ~ChannelFactory() {} + + private: + DISALLOW_COPY_AND_ASSIGN(ChannelFactory); +}; + +} // namespace protocol +} // namespace remoting + +#endif // REMOTING_PROTOCOL_CHANNEL_FACTORY_H_ diff --git a/remoting/protocol/channel_multiplexer.cc b/remoting/protocol/channel_multiplexer.cc new file mode 100644 index 0000000..71647bf --- /dev/null +++ b/remoting/protocol/channel_multiplexer.cc @@ -0,0 +1,513 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "remoting/protocol/channel_multiplexer.h" + +#include <string.h> + +#include "base/bind.h" +#include "base/callback.h" +#include "base/location.h" +#include "base/stl_util.h" +#include "net/base/net_errors.h" +#include "net/socket/stream_socket.h" +#include "remoting/protocol/util.h" + +namespace remoting { +namespace protocol { + +namespace { +const int kChannelIdUnknown = -1; +const int kMaxPacketSize = 1024; + +class PendingPacket { + public: + PendingPacket(scoped_ptr<MultiplexPacket> packet, + const base::Closure& done_task) + : packet(packet.Pass()), + done_task(done_task), + pos(0U) { + } + ~PendingPacket() { + done_task.Run(); + } + + bool is_empty() { return pos >= packet->data().size(); } + + int Read(char* buffer, size_t size) { + size = std::min(size, packet->data().size() - pos); + memcpy(buffer, packet->data().data() + pos, size); + pos += size; + return size; + } + + private: + scoped_ptr<MultiplexPacket> packet; + base::Closure done_task; + size_t pos; + + DISALLOW_COPY_AND_ASSIGN(PendingPacket); +}; + +} // namespace + +const char ChannelMultiplexer::kMuxChannelName[] = "mux"; + +struct ChannelMultiplexer::PendingChannel { + PendingChannel(const std::string& name, + const StreamChannelCallback& callback) + : name(name), callback(callback) { + } + std::string name; + StreamChannelCallback callback; +}; + +class ChannelMultiplexer::MuxChannel { + public: + MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name, + int send_id); + ~MuxChannel(); + + const std::string& name() { return name_; } + int receive_id() { return receive_id_; } + void set_receive_id(int id) { receive_id_ = id; } + + // Called by ChannelMultiplexer. + scoped_ptr<net::StreamSocket> CreateSocket(); + void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, + const base::Closure& done_task); + void OnWriteFailed(); + + // Called by MuxSocket. + void OnSocketDestroyed(); + bool DoWrite(scoped_ptr<MultiplexPacket> packet, + const base::Closure& done_task); + int DoRead(net::IOBuffer* buffer, int buffer_len); + + private: + ChannelMultiplexer* multiplexer_; + std::string name_; + int send_id_; + bool id_sent_; + int receive_id_; + MuxSocket* socket_; + std::list<PendingPacket*> pending_packets_; + + DISALLOW_COPY_AND_ASSIGN(MuxChannel); +}; + +class ChannelMultiplexer::MuxSocket : public net::StreamSocket, + public base::NonThreadSafe, + public base::SupportsWeakPtr<MuxSocket> { + public: + MuxSocket(MuxChannel* channel); + ~MuxSocket(); + + void OnWriteComplete(); + void OnWriteFailed(); + void OnPacketReceived(); + + // net::StreamSocket interface. + virtual int Read(net::IOBuffer* buffer, int buffer_len, + const net::CompletionCallback& callback) OVERRIDE; + virtual int Write(net::IOBuffer* buffer, int buffer_len, + const net::CompletionCallback& callback) OVERRIDE; + + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { + NOTIMPLEMENTED(); + return false; + } + virtual bool SetSendBufferSize(int32 size) OVERRIDE { + NOTIMPLEMENTED(); + return false; + } + + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { + NOTIMPLEMENTED(); + return net::ERR_FAILED; + } + virtual void Disconnect() OVERRIDE { + NOTIMPLEMENTED(); + } + virtual bool IsConnected() const OVERRIDE { + NOTIMPLEMENTED(); + return true; + } + virtual bool IsConnectedAndIdle() const OVERRIDE { + NOTIMPLEMENTED(); + return false; + } + virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { + NOTIMPLEMENTED(); + return net::ERR_FAILED; + } + virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { + NOTIMPLEMENTED(); + return net::ERR_FAILED; + } + virtual const net::BoundNetLog& NetLog() const OVERRIDE { + NOTIMPLEMENTED(); + return net_log_; + } + virtual void SetSubresourceSpeculation() OVERRIDE { + NOTIMPLEMENTED(); + } + virtual void SetOmniboxSpeculation() OVERRIDE { + NOTIMPLEMENTED(); + } + virtual bool WasEverUsed() const OVERRIDE { + return true; + } + virtual bool UsingTCPFastOpen() const OVERRIDE { + return false; + } + virtual int64 NumBytesRead() const OVERRIDE { + NOTIMPLEMENTED(); + return 0; + } + virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE { + NOTIMPLEMENTED(); + return base::TimeDelta(); + } + virtual bool WasNpnNegotiated() const OVERRIDE { + return false; + } + virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { + return net::kProtoUnknown; + } + virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { + NOTIMPLEMENTED(); + return false; + } + + private: + MuxChannel* channel_; + + net::CompletionCallback read_callback_; + scoped_refptr<net::IOBuffer> read_buffer_; + int read_buffer_size_; + + bool write_pending_; + int write_result_; + net::CompletionCallback write_callback_; + + net::BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(MuxSocket); +}; + + +ChannelMultiplexer::MuxChannel::MuxChannel( + ChannelMultiplexer* multiplexer, + const std::string& name, + int send_id) + : multiplexer_(multiplexer), + name_(name), + send_id_(send_id), + id_sent_(false), + receive_id_(kChannelIdUnknown), + socket_(NULL) { +} + +ChannelMultiplexer::MuxChannel::~MuxChannel() { + // Socket must be destroyed before the channel. + DCHECK(!socket_); + STLDeleteElements(&pending_packets_); +} + +scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() { + DCHECK(!socket_); // Can't create more than one socket per channel. + scoped_ptr<MuxSocket> result(new MuxSocket(this)); + socket_ = result.get(); + return result.PassAs<net::StreamSocket>(); +} + +void ChannelMultiplexer::MuxChannel::OnIncomingPacket( + scoped_ptr<MultiplexPacket> packet, + const base::Closure& done_task) { + DCHECK_EQ(packet->channel_id(), receive_id_); + if (packet->data().size() > 0) { + pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); + if (socket_) { + // Notify the socket that we have more data. + socket_->OnPacketReceived(); + } + } +} + +void ChannelMultiplexer::MuxChannel::OnWriteFailed() { + if (socket_) + socket_->OnWriteFailed(); +} + +void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { + DCHECK(socket_); + socket_ = NULL; +} + +bool ChannelMultiplexer::MuxChannel::DoWrite( + scoped_ptr<MultiplexPacket> packet, + const base::Closure& done_task) { + packet->set_channel_id(send_id_); + if (!id_sent_) { + packet->set_channel_name(name_); + id_sent_ = true; + } + return multiplexer_->DoWrite(packet.Pass(), done_task); +} + +int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer, + int buffer_len) { + int pos = 0; + while (buffer_len > 0 && !pending_packets_.empty()) { + DCHECK(!pending_packets_.front()->is_empty()); + int result = pending_packets_.front()->Read( + buffer->data() + pos, buffer_len); + DCHECK_LE(result, buffer_len); + pos += result; + buffer_len -= pos; + if (pending_packets_.front()->is_empty()) { + delete pending_packets_.front(); + pending_packets_.erase(pending_packets_.begin()); + } + } + return pos; +} + +ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel) + : channel_(channel), + read_buffer_size_(0), + write_pending_(false), + write_result_(0) { +} + +ChannelMultiplexer::MuxSocket::~MuxSocket() { + channel_->OnSocketDestroyed(); +} + +int ChannelMultiplexer::MuxSocket::Read( + net::IOBuffer* buffer, int buffer_len, + const net::CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK(read_callback_.is_null()); + + int result = channel_->DoRead(buffer, buffer_len); + if (result == 0) { + read_buffer_ = buffer; + read_buffer_size_ = buffer_len; + read_callback_ = callback; + return net::ERR_IO_PENDING; + } + return result; +} + +int ChannelMultiplexer::MuxSocket::Write( + net::IOBuffer* buffer, int buffer_len, + const net::CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + + scoped_ptr<MultiplexPacket> packet(new MultiplexPacket()); + size_t size = std::min(kMaxPacketSize, buffer_len); + packet->mutable_data()->assign(buffer->data(), size); + + write_pending_ = true; + bool result = channel_->DoWrite(packet.Pass(), base::Bind( + &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); + + if (!result) { + // Cannot complete the write, e.g. if the connection has been terminated. + return net::ERR_FAILED; + } + + // OnWriteComplete() might be called above synchronously. + if (write_pending_) { + DCHECK(write_callback_.is_null()); + write_callback_ = callback; + write_result_ = size; + return net::ERR_IO_PENDING; + } + + return size; +} + +void ChannelMultiplexer::MuxSocket::OnWriteComplete() { + write_pending_ = false; + if (!write_callback_.is_null()) { + net::CompletionCallback cb; + std::swap(cb, write_callback_); + cb.Run(write_result_); + } +} + +void ChannelMultiplexer::MuxSocket::OnWriteFailed() { + if (!write_callback_.is_null()) { + net::CompletionCallback cb; + std::swap(cb, write_callback_); + cb.Run(net::ERR_FAILED); + } +} + +void ChannelMultiplexer::MuxSocket::OnPacketReceived() { + if (!read_callback_.is_null()) { + int result = channel_->DoRead(read_buffer_, read_buffer_size_); + read_buffer_ = NULL; + DCHECK_GT(result, 0); + net::CompletionCallback cb; + std::swap(cb, read_callback_); + cb.Run(result); + } +} + +ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory, + const std::string& base_channel_name) + : base_channel_factory_(factory), + base_channel_name_(base_channel_name), + next_channel_id_(0), + destroyed_flag_(NULL) { + factory->CreateStreamChannel( + base_channel_name, + base::Bind(&ChannelMultiplexer::OnBaseChannelReady, + base::Unretained(this))); +} + +ChannelMultiplexer::~ChannelMultiplexer() { + DCHECK(pending_channels_.empty()); + STLDeleteValues(&channels_); + + // Cancel creation of the base channel if it hasn't finished. + if (base_channel_factory_) + base_channel_factory_->CancelChannelCreation(base_channel_name_); + + if (destroyed_flag_) + *destroyed_flag_ = true; +} + +void ChannelMultiplexer::CreateStreamChannel( + const std::string& name, + const StreamChannelCallback& callback) { + if (base_channel_.get()) { + // Already have |base_channel_|. Create new multiplexed channel + // synchronously. + callback.Run(GetOrCreateChannel(name)->CreateSocket()); + } else if (!base_channel_.get() && !base_channel_factory_) { + // Fail synchronously if we failed to create |base_channel_|. + callback.Run(scoped_ptr<net::StreamSocket>()); + } else { + // Still waiting for the |base_channel_|. + pending_channels_.push_back(PendingChannel(name, callback)); + } +} + +void ChannelMultiplexer::CreateDatagramChannel( + const std::string& name, + const DatagramChannelCallback& callback) { + NOTIMPLEMENTED(); + callback.Run(scoped_ptr<net::Socket>()); +} + +void ChannelMultiplexer::CancelChannelCreation(const std::string& name) { + for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); + it != pending_channels_.end(); ++it) { + if (it->name == name) { + pending_channels_.erase(it); + return; + } + } +} + +void ChannelMultiplexer::OnBaseChannelReady( + scoped_ptr<net::StreamSocket> socket) { + base_channel_factory_ = NULL; + base_channel_ = socket.Pass(); + + if (!base_channel_.get()) { + // Notify all callers that we can't create any channels. + for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); + it != pending_channels_.end(); ++it) { + it->callback.Run(scoped_ptr<net::StreamSocket>()); + } + pending_channels_.clear(); + return; + } + + // Initialize reader and writer. + reader_.Init(base_channel_.get(), + base::Bind(&ChannelMultiplexer::OnIncomingPacket, + base::Unretained(this))); + writer_.Init(base_channel_.get(), + base::Bind(&ChannelMultiplexer::OnWriteFailed, + base::Unretained(this))); + + // Now create all pending channels. + for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); + it != pending_channels_.end(); ++it) { + it->callback.Run(GetOrCreateChannel(it->name)->CreateSocket()); + } + pending_channels_.clear(); +} + +ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel( + const std::string& name) { + // Check if we already have a channel with the requested name. + std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); + if (it != channels_.end()) + return it->second; + + // Create a new channel if we haven't found existing one. + MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); + ++next_channel_id_; + channels_[channel->name()] = channel; + return channel; +} + + +void ChannelMultiplexer::OnWriteFailed(int error) { + bool destroyed = false; + destroyed_flag_ = &destroyed; + for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); + it != channels_.end(); ++it) { + it->second->OnWriteFailed(); + if (destroyed) + return; + } + destroyed_flag_ = NULL; +} + +void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, + const base::Closure& done_task) { + if (!packet->has_channel_id()) { + LOG(ERROR) << "Received packet without channel_id."; + done_task.Run(); + return; + } + + int receive_id = packet->channel_id(); + MuxChannel* channel = NULL; + std::map<int, MuxChannel*>::iterator it = + channels_by_receive_id_.find(receive_id); + if (it != channels_by_receive_id_.end()) { + channel = it->second; + } else { + // This is a new |channel_id| we haven't seen before. Look it up by name. + if (!packet->has_channel_name()) { + LOG(ERROR) << "Received packet with unknown channel_id and " + "without channel_name."; + done_task.Run(); + return; + } + channel = GetOrCreateChannel(packet->channel_name()); + channel->set_receive_id(receive_id); + channels_by_receive_id_[receive_id] = channel; + } + + channel->OnIncomingPacket(packet.Pass(), done_task); +} + +bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, + const base::Closure& done_task) { + return writer_.Write(SerializeAndFrameMessage(*packet), done_task); +} + +} // namespace protocol +} // namespace remoting diff --git a/remoting/protocol/channel_multiplexer.h b/remoting/protocol/channel_multiplexer.h new file mode 100644 index 0000000..0f16fb1 --- /dev/null +++ b/remoting/protocol/channel_multiplexer.h @@ -0,0 +1,88 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef REMOTING_PROTOCOL_CHANNEL_MULTIPLEXER_H_ +#define REMOTING_PROTOCOL_CHANNEL_MULTIPLEXER_H_ + +#include "remoting/proto/mux.pb.h" +#include "remoting/protocol/buffered_socket_writer.h" +#include "remoting/protocol/channel_factory.h" +#include "remoting/protocol/message_reader.h" + +namespace remoting { +namespace protocol { + +class ChannelMultiplexer : public ChannelFactory { + public: + static const char kMuxChannelName[]; + + // |factory| is used to create the channel upon which to multiplex. + ChannelMultiplexer(ChannelFactory* factory, + const std::string& base_channel_name); + virtual ~ChannelMultiplexer(); + + // ChannelFactory interface. + virtual void CreateStreamChannel( + const std::string& name, + const StreamChannelCallback& callback) OVERRIDE; + virtual void CreateDatagramChannel( + const std::string& name, + const DatagramChannelCallback& callback) OVERRIDE; + virtual void CancelChannelCreation(const std::string& name) OVERRIDE; + + private: + struct PendingChannel; + class MuxChannel; + class MuxSocket; + friend class MuxChannel; + + // Callback for |base_channel_| creation. + void OnBaseChannelReady(scoped_ptr<net::StreamSocket> socket); + + // Helper method used to create channels. + MuxChannel* GetOrCreateChannel(const std::string& name); + + // Callbacks for |writer_| and |reader_|. + void OnWriteFailed(int error); + void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, + const base::Closure& done_task); + + // Called by MuxChannel. + bool DoWrite(scoped_ptr<MultiplexPacket> packet, + const base::Closure& done_task); + + // Factory used to create |base_channel_|. Set to NULL once creation is + // finished or failed. + ChannelFactory* base_channel_factory_; + + // Name of the underlying channel. + std::string base_channel_name_; + + // The channel over which to multiplex. + scoped_ptr<net::StreamSocket> base_channel_; + + // List of requested channels while we are waiting for |base_channel_|. + std::list<PendingChannel> pending_channels_; + + int next_channel_id_; + std::map<std::string, MuxChannel*> channels_; + + // Channels are added to |channels_by_receive_id_| only after we receive + // receive_id from the remote peer. + std::map<int, MuxChannel*> channels_by_receive_id_; + + BufferedSocketWriter writer_; + ProtobufMessageReader<MultiplexPacket> reader_; + + // Flag used by OnWriteFailed() to detect when the multiplexer is destroyed. + bool* destroyed_flag_; + + DISALLOW_COPY_AND_ASSIGN(ChannelMultiplexer); +}; + +} // namespace protocol +} // namespace remoting + + +#endif // REMOTING_PROTOCOL_CHANNEL_MULTIPLEXER_H_ diff --git a/remoting/protocol/channel_multiplexer_unittest.cc b/remoting/protocol/channel_multiplexer_unittest.cc new file mode 100644 index 0000000..11a459b --- /dev/null +++ b/remoting/protocol/channel_multiplexer_unittest.cc @@ -0,0 +1,301 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "remoting/protocol/channel_multiplexer.h" + +#include "base/bind.h" +#include "base/message_loop.h" +#include "net/base/net_errors.h" +#include "net/socket/socket.h" +#include "net/socket/stream_socket.h" +#include "remoting/base/constants.h" +#include "remoting/protocol/connection_tester.h" +#include "remoting/protocol/fake_session.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +using testing::_; +using testing::AtMost; +using testing::InvokeWithoutArgs; + +namespace remoting { +namespace protocol { + +namespace { + +const int kMessageSize = 1024; +const int kMessages = 100; +const char kMuxChannelName[] = "mux"; + +void QuitCurrentThread() { + MessageLoop::current()->PostTask(FROM_HERE, MessageLoop::QuitClosure()); +} + +class MockSocketCallback { + public: + MOCK_METHOD1(OnDone, void(int result)); +}; + +} // namespace + +class ChannelMultiplexerTest : public testing::Test { + public: + void DeleteAll() { + host_socket1_.reset(); + host_socket2_.reset(); + client_socket1_.reset(); + client_socket2_.reset(); + host_mux_.reset(); + client_mux_.reset(); + } + + protected: + virtual void SetUp() OVERRIDE { + // Create pair of multiplexers and connect them to each other. + host_mux_.reset(new ChannelMultiplexer(&host_session_, kMuxChannelName)); + client_mux_.reset(new ChannelMultiplexer(&client_session_, + kMuxChannelName)); + FakeSocket* host_socket = + host_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName); + FakeSocket* client_socket = + client_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName); + host_socket->PairWith(client_socket); + + // Make writes asynchronous in one direction. + host_socket->set_async_write(true); + } + + void CreateChannel(const std::string& name, + scoped_ptr<net::StreamSocket>* host_socket, + scoped_ptr<net::StreamSocket>* client_socket) { + int counter = 2; + host_mux_->CreateStreamChannel(name, base::Bind( + &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this), + host_socket, &counter)); + client_mux_->CreateStreamChannel(name, base::Bind( + &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this), + client_socket, &counter)); + + message_loop_.Run(); + + EXPECT_TRUE(host_socket->get()); + EXPECT_TRUE(client_socket->get()); + } + + void OnChannelConnected( + scoped_ptr<net::StreamSocket>* storage, + int* counter, + scoped_ptr<net::StreamSocket> socket) { + *storage = socket.Pass(); + --(*counter); + EXPECT_GE(*counter, 0); + if (*counter == 0) + QuitCurrentThread(); + } + + scoped_refptr<net::IOBufferWithSize> CreateTestBuffer(int size) { + scoped_refptr<net::IOBufferWithSize> result = + new net::IOBufferWithSize(size); + for (int i = 0; i< size; ++i) { + result->data()[i] = rand() % 256; + } + return result; + } + + MessageLoop message_loop_; + + FakeSession host_session_; + FakeSession client_session_; + + scoped_ptr<ChannelMultiplexer> host_mux_; + scoped_ptr<ChannelMultiplexer> client_mux_; + + scoped_ptr<net::StreamSocket> host_socket1_; + scoped_ptr<net::StreamSocket> client_socket1_; + scoped_ptr<net::StreamSocket> host_socket2_; + scoped_ptr<net::StreamSocket> client_socket2_; +}; + + +TEST_F(ChannelMultiplexerTest, OneChannel) { + scoped_ptr<net::StreamSocket> host_socket; + scoped_ptr<net::StreamSocket> client_socket; + ASSERT_NO_FATAL_FAILURE(CreateChannel("test", &host_socket, &client_socket)); + + StreamConnectionTester tester(host_socket.get(), client_socket.get(), + kMessageSize, kMessages); + tester.Start(); + message_loop_.Run(); + tester.CheckResults(); +} + +TEST_F(ChannelMultiplexerTest, TwoChannels) { + scoped_ptr<net::StreamSocket> host_socket1_; + scoped_ptr<net::StreamSocket> client_socket1_; + ASSERT_NO_FATAL_FAILURE( + CreateChannel("test", &host_socket1_, &client_socket1_)); + + scoped_ptr<net::StreamSocket> host_socket2_; + scoped_ptr<net::StreamSocket> client_socket2_; + ASSERT_NO_FATAL_FAILURE( + CreateChannel("ch2", &host_socket2_, &client_socket2_)); + + StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(), + kMessageSize, kMessages); + StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(), + kMessageSize, kMessages); + tester1.Start(); + tester2.Start(); + while (!tester1.done() || !tester2.done()) { + message_loop_.Run(); + } + tester1.CheckResults(); + tester2.CheckResults(); +} + +// Four channels, two in each direction +TEST_F(ChannelMultiplexerTest, FourChannels) { + scoped_ptr<net::StreamSocket> host_socket1_; + scoped_ptr<net::StreamSocket> client_socket1_; + ASSERT_NO_FATAL_FAILURE( + CreateChannel("test", &host_socket1_, &client_socket1_)); + + scoped_ptr<net::StreamSocket> host_socket2_; + scoped_ptr<net::StreamSocket> client_socket2_; + ASSERT_NO_FATAL_FAILURE( + CreateChannel("ch2", &host_socket2_, &client_socket2_)); + + scoped_ptr<net::StreamSocket> host_socket3; + scoped_ptr<net::StreamSocket> client_socket3; + ASSERT_NO_FATAL_FAILURE( + CreateChannel("test3", &host_socket3, &client_socket3)); + + scoped_ptr<net::StreamSocket> host_socket4; + scoped_ptr<net::StreamSocket> client_socket4; + ASSERT_NO_FATAL_FAILURE( + CreateChannel("ch4", &host_socket4, &client_socket4)); + + StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(), + kMessageSize, kMessages); + StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(), + kMessageSize, kMessages); + StreamConnectionTester tester3(client_socket3.get(), host_socket3.get(), + kMessageSize, kMessages); + StreamConnectionTester tester4(client_socket4.get(), host_socket4.get(), + kMessageSize, kMessages); + tester1.Start(); + tester2.Start(); + tester3.Start(); + tester4.Start(); + while (!tester1.done() || !tester2.done() || + !tester3.done() || !tester4.done()) { + message_loop_.Run(); + } + tester1.CheckResults(); + tester2.CheckResults(); + tester3.CheckResults(); + tester4.CheckResults(); +} + +TEST_F(ChannelMultiplexerTest, SyncFail) { + scoped_ptr<net::StreamSocket> host_socket1_; + scoped_ptr<net::StreamSocket> client_socket1_; + ASSERT_NO_FATAL_FAILURE( + CreateChannel("test", &host_socket1_, &client_socket1_)); + + scoped_ptr<net::StreamSocket> host_socket2_; + scoped_ptr<net::StreamSocket> client_socket2_; + ASSERT_NO_FATAL_FAILURE( + CreateChannel("ch2", &host_socket2_, &client_socket2_)); + + host_session_.GetStreamChannel(kMuxChannelName)-> + set_next_write_error(net::ERR_FAILED); + host_session_.GetStreamChannel(kMuxChannelName)-> + set_async_write(false); + + scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); + + MockSocketCallback cb1; + MockSocketCallback cb2; + + EXPECT_CALL(cb1, OnDone(_)) + .Times(0); + EXPECT_CALL(cb2, OnDone(_)) + .Times(0); + + EXPECT_EQ(net::ERR_FAILED, host_socket1_->Write(buf, buf->size(), base::Bind( + &MockSocketCallback::OnDone, base::Unretained(&cb1)))); + EXPECT_EQ(net::ERR_FAILED, host_socket2_->Write(buf, buf->size(), base::Bind( + &MockSocketCallback::OnDone, base::Unretained(&cb2)))); + + message_loop_.RunAllPending(); +} + +TEST_F(ChannelMultiplexerTest, AsyncFail) { + ASSERT_NO_FATAL_FAILURE( + CreateChannel("test", &host_socket1_, &client_socket1_)); + + ASSERT_NO_FATAL_FAILURE( + CreateChannel("ch2", &host_socket2_, &client_socket2_)); + + host_session_.GetStreamChannel(kMuxChannelName)-> + set_next_write_error(net::ERR_FAILED); + host_session_.GetStreamChannel(kMuxChannelName)-> + set_async_write(true); + + scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); + + MockSocketCallback cb1; + MockSocketCallback cb2; + EXPECT_CALL(cb1, OnDone(net::ERR_FAILED)); + EXPECT_CALL(cb2, OnDone(net::ERR_FAILED)); + + EXPECT_EQ(net::ERR_IO_PENDING, + host_socket1_->Write(buf, buf->size(), base::Bind( + &MockSocketCallback::OnDone, base::Unretained(&cb1)))); + EXPECT_EQ(net::ERR_IO_PENDING, + host_socket2_->Write(buf, buf->size(), base::Bind( + &MockSocketCallback::OnDone, base::Unretained(&cb2)))); + + message_loop_.RunAllPending(); +} + +TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) { + ASSERT_NO_FATAL_FAILURE( + CreateChannel("test", &host_socket1_, &client_socket1_)); + ASSERT_NO_FATAL_FAILURE( + CreateChannel("ch2", &host_socket2_, &client_socket2_)); + + host_session_.GetStreamChannel(kMuxChannelName)-> + set_next_write_error(net::ERR_FAILED); + host_session_.GetStreamChannel(kMuxChannelName)-> + set_async_write(true); + + scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); + + MockSocketCallback cb1; + MockSocketCallback cb2; + + EXPECT_CALL(cb1, OnDone(net::ERR_FAILED)) + .Times(AtMost(1)) + .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll)); + EXPECT_CALL(cb2, OnDone(net::ERR_FAILED)) + .Times(AtMost(1)) + .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll)); + + EXPECT_EQ(net::ERR_IO_PENDING, + host_socket1_->Write(buf, buf->size(), base::Bind( + &MockSocketCallback::OnDone, base::Unretained(&cb1)))); + EXPECT_EQ(net::ERR_IO_PENDING, + host_socket2_->Write(buf, buf->size(), base::Bind( + &MockSocketCallback::OnDone, base::Unretained(&cb2)))); + + message_loop_.RunAllPending(); + + // Check that the sockets were destroyed. + EXPECT_FALSE(host_mux_.get()); +} + +} // namespace protocol +} // namespace remoting diff --git a/remoting/protocol/connection_tester.h b/remoting/protocol/connection_tester.h index b32cf62..f89cd7a 100644 --- a/remoting/protocol/connection_tester.h +++ b/remoting/protocol/connection_tester.h @@ -34,6 +34,7 @@ class StreamConnectionTester { ~StreamConnectionTester(); void Start(); + bool done() { return done_; } void CheckResults(); protected: diff --git a/remoting/protocol/session.h b/remoting/protocol/session.h index 9148318..7041486 100644 --- a/remoting/protocol/session.h +++ b/remoting/protocol/session.h @@ -7,16 +7,12 @@ #include <string> -#include "base/callback.h" -#include "base/threading/non_thread_safe.h" -#include "remoting/protocol/buffered_socket_writer.h" +#include "remoting/protocol/channel_factory.h" #include "remoting/protocol/errors.h" #include "remoting/protocol/session_config.h" namespace net { class IPEndPoint; -class Socket; -class StreamSocket; } // namespace net namespace remoting { @@ -27,7 +23,7 @@ struct TransportRoute; // Generic interface for Chromotocol connection used by both client and host. // Provides access to the connection channels, but doesn't depend on the // protocol used for each channel. -class Session : public base::NonThreadSafe { +class Session : public ChannelFactory { public: enum State { // Created, but not connecting yet. @@ -74,12 +70,6 @@ class Session : public base::NonThreadSafe { bool ready) {} }; - // TODO(sergeyu): Specify connection error code when channel - // connection fails. - typedef base::Callback<void(scoped_ptr<net::StreamSocket>)> - StreamChannelCallback; - typedef base::Callback<void(scoped_ptr<net::Socket>)> - DatagramChannelCallback; Session() {} virtual ~Session() {} @@ -91,23 +81,6 @@ class Session : public base::NonThreadSafe { // Returns error code for a failed session. virtual ErrorCode error() = 0; - // Creates new channels for this connection. The specified callback - // is called when then new channel is created and connected. The - // callback is called with NULL if connection failed for any reason. - // All channels must be destroyed before the session is - // destroyed. Can be called only when in CONNECTING, CONNECTED or - // AUTHENTICATED states. - virtual void CreateStreamChannel( - const std::string& name, const StreamChannelCallback& callback) = 0; - virtual void CreateDatagramChannel( - const std::string& name, const DatagramChannelCallback& callback) = 0; - - // Cancels a pending CreateStreamChannel() or CreateDatagramChannel() - // operation for the named channel. If the channel creation already - // completed then cancelling it has no effect. When shutting down - // this method must be called for each channel pending creation. - virtual void CancelChannelCreation(const std::string& name) = 0; - // JID of the other side. virtual const std::string& jid() = 0; diff --git a/remoting/remoting.gyp b/remoting/remoting.gyp index fc8f683..48725f9 100644 --- a/remoting/remoting.gyp +++ b/remoting/remoting.gyp @@ -1616,6 +1616,8 @@ 'protocol/channel_authenticator.h', 'protocol/channel_dispatcher_base.cc', 'protocol/channel_dispatcher_base.h', + 'protocol/channel_multiplexer.cc', + 'protocol/channel_multiplexer.h', 'protocol/client_control_dispatcher.cc', 'protocol/client_control_dispatcher.h', 'protocol/client_event_dispatcher.cc', @@ -1623,11 +1625,11 @@ 'protocol/client_stub.h', 'protocol/clipboard_echo_filter.cc', 'protocol/clipboard_echo_filter.h', - 'protocol/clipboard_filter.h', 'protocol/clipboard_filter.cc', + 'protocol/clipboard_filter.h', + 'protocol/clipboard_stub.h', 'protocol/clipboard_thread_proxy.cc', 'protocol/clipboard_thread_proxy.h', - 'protocol/clipboard_stub.h', 'protocol/connection_to_client.cc', 'protocol/connection_to_client.h', 'protocol/connection_to_host.cc', @@ -1802,6 +1804,7 @@ 'protocol/authenticator_test_base.cc', 'protocol/authenticator_test_base.h', 'protocol/buffered_socket_writer_unittest.cc', + 'protocol/channel_multiplexer_unittest.cc', 'protocol/clipboard_echo_filter_unittest.cc', 'protocol/connection_tester.cc', 'protocol/connection_tester.h', |