summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorsergeyu@chromium.org <sergeyu@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2012-08-08 02:03:10 +0000
committersergeyu@chromium.org <sergeyu@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2012-08-08 02:03:10 +0000
commitdfcc8927737e14edd5890c509ae113794435a470 (patch)
tree8705cd5af8968a9fa995246df6f23275aa226a5c
parentd7f7f753ace38ad83299ba7b53aa85f373849c91 (diff)
downloadchromium_src-dfcc8927737e14edd5890c509ae113794435a470.zip
chromium_src-dfcc8927737e14edd5890c509ae113794435a470.tar.gz
chromium_src-dfcc8927737e14edd5890c509ae113794435a470.tar.bz2
Implement ChannelMultiplexer.
ChannelMultiplexer allows multiple logical channels to share a single underlying transport channel. BUG=137135 Review URL: https://chromiumcodereview.appspot.com/10830046 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@150484 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r--remoting/host/server_log_entry.cc1
-rw-r--r--remoting/proto/chromotocol.gyp1
-rw-r--r--remoting/proto/mux.proto27
-rw-r--r--remoting/protocol/buffered_socket_writer.cc4
-rw-r--r--remoting/protocol/channel_factory.h59
-rw-r--r--remoting/protocol/channel_multiplexer.cc513
-rw-r--r--remoting/protocol/channel_multiplexer.h88
-rw-r--r--remoting/protocol/channel_multiplexer_unittest.cc301
-rw-r--r--remoting/protocol/connection_tester.h1
-rw-r--r--remoting/protocol/session.h31
-rw-r--r--remoting/remoting.gyp7
11 files changed, 1001 insertions, 32 deletions
diff --git a/remoting/host/server_log_entry.cc b/remoting/host/server_log_entry.cc
index 98f39ee..dabcacf 100644
--- a/remoting/host/server_log_entry.cc
+++ b/remoting/host/server_log_entry.cc
@@ -4,6 +4,7 @@
#include "remoting/host/server_log_entry.h"
+#include "base/logging.h"
#include "base/sys_info.h"
#include "remoting/base/constants.h"
#include "remoting/protocol/session.h"
diff --git a/remoting/proto/chromotocol.gyp b/remoting/proto/chromotocol.gyp
index 98077ce..baf8c40 100644
--- a/remoting/proto/chromotocol.gyp
+++ b/remoting/proto/chromotocol.gyp
@@ -15,6 +15,7 @@
'control.proto',
'event.proto',
'internal.proto',
+ 'mux.proto',
'video.proto',
],
'variables': {
diff --git a/remoting/proto/mux.proto b/remoting/proto/mux.proto
new file mode 100644
index 0000000..ff0a8f6
--- /dev/null
+++ b/remoting/proto/mux.proto
@@ -0,0 +1,27 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Protocol for the mux channel that multiplexes multiple channels.
+
+syntax = "proto2";
+
+option optimize_for = LITE_RUNTIME;
+
+package remoting.protocol;
+
+message MultiplexPacket {
+ // Channel ID. Each peer choses this value when it sends first packet to
+ // the other peer. It unique identified channel this packet belongs to.
+ // Channel ID is direction-specific, i.e. each channel has two IDs
+ // assigned to it: one for receiving and one for sending.
+ optional int32 channel_id = 1;
+
+ // Channel name. The name is used to identify channels before channel ID
+ // is assigned in the first message. This value must be included only
+ // in the first packet for a given channel. All other packets must be
+ // identified using channel ID.
+ optional string channel_name = 2;
+
+ optional bytes data = 3;
+}
diff --git a/remoting/protocol/buffered_socket_writer.cc b/remoting/protocol/buffered_socket_writer.cc
index 178f14a..de11356 100644
--- a/remoting/protocol/buffered_socket_writer.cc
+++ b/remoting/protocol/buffered_socket_writer.cc
@@ -55,7 +55,9 @@ bool BufferedSocketWriterBase::Write(
buffer_size_ += data->size();
DoWrite();
- return true;
+
+ // DoWrite() may trigger OnWriteError() to be called.
+ return !closed_;
}
void BufferedSocketWriterBase::DoWrite() {
diff --git a/remoting/protocol/channel_factory.h b/remoting/protocol/channel_factory.h
new file mode 100644
index 0000000..7741e5d
--- /dev/null
+++ b/remoting/protocol/channel_factory.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef REMOTING_PROTOCOL_CHANNEL_FACTORY_H_
+#define REMOTING_PROTOCOL_CHANNEL_FACTORY_H_
+
+#include "base/callback.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/threading/non_thread_safe.h"
+
+namespace net {
+class Socket;
+class StreamSocket;
+} // namespace net
+
+namespace remoting {
+namespace protocol {
+
+class ChannelFactory : public base::NonThreadSafe {
+ public:
+ // TODO(sergeyu): Specify connection error code when channel
+ // connection fails.
+ typedef base::Callback<void(scoped_ptr<net::StreamSocket>)>
+ StreamChannelCallback;
+ typedef base::Callback<void(scoped_ptr<net::Socket>)>
+ DatagramChannelCallback;
+
+ ChannelFactory() {}
+
+ // Creates new channels for this connection. The specified callback is called
+ // when then new channel is created and connected. The callback is called with
+ // NULL if connection failed for any reason. Callback may be called
+ // synchronously, before the call returns. All channels must be destroyed
+ // before the factory is destroyed and CancelChannelCreation() must be called
+ // to cancel creation of channels for which the |callback| hasn't been called
+ // yet.
+ virtual void CreateStreamChannel(
+ const std::string& name, const StreamChannelCallback& callback) = 0;
+ virtual void CreateDatagramChannel(
+ const std::string& name, const DatagramChannelCallback& callback) = 0;
+
+ // Cancels a pending CreateStreamChannel() or CreateDatagramChannel()
+ // operation for the named channel. If the channel creation already
+ // completed then canceling it has no effect. When shutting down
+ // this method must be called for each channel pending creation.
+ virtual void CancelChannelCreation(const std::string& name) = 0;
+
+ protected:
+ virtual ~ChannelFactory() {}
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(ChannelFactory);
+};
+
+} // namespace protocol
+} // namespace remoting
+
+#endif // REMOTING_PROTOCOL_CHANNEL_FACTORY_H_
diff --git a/remoting/protocol/channel_multiplexer.cc b/remoting/protocol/channel_multiplexer.cc
new file mode 100644
index 0000000..71647bf
--- /dev/null
+++ b/remoting/protocol/channel_multiplexer.cc
@@ -0,0 +1,513 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "remoting/protocol/channel_multiplexer.h"
+
+#include <string.h>
+
+#include "base/bind.h"
+#include "base/callback.h"
+#include "base/location.h"
+#include "base/stl_util.h"
+#include "net/base/net_errors.h"
+#include "net/socket/stream_socket.h"
+#include "remoting/protocol/util.h"
+
+namespace remoting {
+namespace protocol {
+
+namespace {
+const int kChannelIdUnknown = -1;
+const int kMaxPacketSize = 1024;
+
+class PendingPacket {
+ public:
+ PendingPacket(scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task)
+ : packet(packet.Pass()),
+ done_task(done_task),
+ pos(0U) {
+ }
+ ~PendingPacket() {
+ done_task.Run();
+ }
+
+ bool is_empty() { return pos >= packet->data().size(); }
+
+ int Read(char* buffer, size_t size) {
+ size = std::min(size, packet->data().size() - pos);
+ memcpy(buffer, packet->data().data() + pos, size);
+ pos += size;
+ return size;
+ }
+
+ private:
+ scoped_ptr<MultiplexPacket> packet;
+ base::Closure done_task;
+ size_t pos;
+
+ DISALLOW_COPY_AND_ASSIGN(PendingPacket);
+};
+
+} // namespace
+
+const char ChannelMultiplexer::kMuxChannelName[] = "mux";
+
+struct ChannelMultiplexer::PendingChannel {
+ PendingChannel(const std::string& name,
+ const StreamChannelCallback& callback)
+ : name(name), callback(callback) {
+ }
+ std::string name;
+ StreamChannelCallback callback;
+};
+
+class ChannelMultiplexer::MuxChannel {
+ public:
+ MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
+ int send_id);
+ ~MuxChannel();
+
+ const std::string& name() { return name_; }
+ int receive_id() { return receive_id_; }
+ void set_receive_id(int id) { receive_id_ = id; }
+
+ // Called by ChannelMultiplexer.
+ scoped_ptr<net::StreamSocket> CreateSocket();
+ void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task);
+ void OnWriteFailed();
+
+ // Called by MuxSocket.
+ void OnSocketDestroyed();
+ bool DoWrite(scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task);
+ int DoRead(net::IOBuffer* buffer, int buffer_len);
+
+ private:
+ ChannelMultiplexer* multiplexer_;
+ std::string name_;
+ int send_id_;
+ bool id_sent_;
+ int receive_id_;
+ MuxSocket* socket_;
+ std::list<PendingPacket*> pending_packets_;
+
+ DISALLOW_COPY_AND_ASSIGN(MuxChannel);
+};
+
+class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
+ public base::NonThreadSafe,
+ public base::SupportsWeakPtr<MuxSocket> {
+ public:
+ MuxSocket(MuxChannel* channel);
+ ~MuxSocket();
+
+ void OnWriteComplete();
+ void OnWriteFailed();
+ void OnPacketReceived();
+
+ // net::StreamSocket interface.
+ virtual int Read(net::IOBuffer* buffer, int buffer_len,
+ const net::CompletionCallback& callback) OVERRIDE;
+ virtual int Write(net::IOBuffer* buffer, int buffer_len,
+ const net::CompletionCallback& callback) OVERRIDE;
+
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
+ NOTIMPLEMENTED();
+ return false;
+ }
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE {
+ NOTIMPLEMENTED();
+ return false;
+ }
+
+ virtual int Connect(const net::CompletionCallback& callback) OVERRIDE {
+ NOTIMPLEMENTED();
+ return net::ERR_FAILED;
+ }
+ virtual void Disconnect() OVERRIDE {
+ NOTIMPLEMENTED();
+ }
+ virtual bool IsConnected() const OVERRIDE {
+ NOTIMPLEMENTED();
+ return true;
+ }
+ virtual bool IsConnectedAndIdle() const OVERRIDE {
+ NOTIMPLEMENTED();
+ return false;
+ }
+ virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE {
+ NOTIMPLEMENTED();
+ return net::ERR_FAILED;
+ }
+ virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE {
+ NOTIMPLEMENTED();
+ return net::ERR_FAILED;
+ }
+ virtual const net::BoundNetLog& NetLog() const OVERRIDE {
+ NOTIMPLEMENTED();
+ return net_log_;
+ }
+ virtual void SetSubresourceSpeculation() OVERRIDE {
+ NOTIMPLEMENTED();
+ }
+ virtual void SetOmniboxSpeculation() OVERRIDE {
+ NOTIMPLEMENTED();
+ }
+ virtual bool WasEverUsed() const OVERRIDE {
+ return true;
+ }
+ virtual bool UsingTCPFastOpen() const OVERRIDE {
+ return false;
+ }
+ virtual int64 NumBytesRead() const OVERRIDE {
+ NOTIMPLEMENTED();
+ return 0;
+ }
+ virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE {
+ NOTIMPLEMENTED();
+ return base::TimeDelta();
+ }
+ virtual bool WasNpnNegotiated() const OVERRIDE {
+ return false;
+ }
+ virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE {
+ return net::kProtoUnknown;
+ }
+ virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE {
+ NOTIMPLEMENTED();
+ return false;
+ }
+
+ private:
+ MuxChannel* channel_;
+
+ net::CompletionCallback read_callback_;
+ scoped_refptr<net::IOBuffer> read_buffer_;
+ int read_buffer_size_;
+
+ bool write_pending_;
+ int write_result_;
+ net::CompletionCallback write_callback_;
+
+ net::BoundNetLog net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(MuxSocket);
+};
+
+
+ChannelMultiplexer::MuxChannel::MuxChannel(
+ ChannelMultiplexer* multiplexer,
+ const std::string& name,
+ int send_id)
+ : multiplexer_(multiplexer),
+ name_(name),
+ send_id_(send_id),
+ id_sent_(false),
+ receive_id_(kChannelIdUnknown),
+ socket_(NULL) {
+}
+
+ChannelMultiplexer::MuxChannel::~MuxChannel() {
+ // Socket must be destroyed before the channel.
+ DCHECK(!socket_);
+ STLDeleteElements(&pending_packets_);
+}
+
+scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
+ DCHECK(!socket_); // Can't create more than one socket per channel.
+ scoped_ptr<MuxSocket> result(new MuxSocket(this));
+ socket_ = result.get();
+ return result.PassAs<net::StreamSocket>();
+}
+
+void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
+ scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task) {
+ DCHECK_EQ(packet->channel_id(), receive_id_);
+ if (packet->data().size() > 0) {
+ pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
+ if (socket_) {
+ // Notify the socket that we have more data.
+ socket_->OnPacketReceived();
+ }
+ }
+}
+
+void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
+ if (socket_)
+ socket_->OnWriteFailed();
+}
+
+void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
+ DCHECK(socket_);
+ socket_ = NULL;
+}
+
+bool ChannelMultiplexer::MuxChannel::DoWrite(
+ scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task) {
+ packet->set_channel_id(send_id_);
+ if (!id_sent_) {
+ packet->set_channel_name(name_);
+ id_sent_ = true;
+ }
+ return multiplexer_->DoWrite(packet.Pass(), done_task);
+}
+
+int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer,
+ int buffer_len) {
+ int pos = 0;
+ while (buffer_len > 0 && !pending_packets_.empty()) {
+ DCHECK(!pending_packets_.front()->is_empty());
+ int result = pending_packets_.front()->Read(
+ buffer->data() + pos, buffer_len);
+ DCHECK_LE(result, buffer_len);
+ pos += result;
+ buffer_len -= pos;
+ if (pending_packets_.front()->is_empty()) {
+ delete pending_packets_.front();
+ pending_packets_.erase(pending_packets_.begin());
+ }
+ }
+ return pos;
+}
+
+ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
+ : channel_(channel),
+ read_buffer_size_(0),
+ write_pending_(false),
+ write_result_(0) {
+}
+
+ChannelMultiplexer::MuxSocket::~MuxSocket() {
+ channel_->OnSocketDestroyed();
+}
+
+int ChannelMultiplexer::MuxSocket::Read(
+ net::IOBuffer* buffer, int buffer_len,
+ const net::CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+ DCHECK(read_callback_.is_null());
+
+ int result = channel_->DoRead(buffer, buffer_len);
+ if (result == 0) {
+ read_buffer_ = buffer;
+ read_buffer_size_ = buffer_len;
+ read_callback_ = callback;
+ return net::ERR_IO_PENDING;
+ }
+ return result;
+}
+
+int ChannelMultiplexer::MuxSocket::Write(
+ net::IOBuffer* buffer, int buffer_len,
+ const net::CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+
+ scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
+ size_t size = std::min(kMaxPacketSize, buffer_len);
+ packet->mutable_data()->assign(buffer->data(), size);
+
+ write_pending_ = true;
+ bool result = channel_->DoWrite(packet.Pass(), base::Bind(
+ &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
+
+ if (!result) {
+ // Cannot complete the write, e.g. if the connection has been terminated.
+ return net::ERR_FAILED;
+ }
+
+ // OnWriteComplete() might be called above synchronously.
+ if (write_pending_) {
+ DCHECK(write_callback_.is_null());
+ write_callback_ = callback;
+ write_result_ = size;
+ return net::ERR_IO_PENDING;
+ }
+
+ return size;
+}
+
+void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
+ write_pending_ = false;
+ if (!write_callback_.is_null()) {
+ net::CompletionCallback cb;
+ std::swap(cb, write_callback_);
+ cb.Run(write_result_);
+ }
+}
+
+void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
+ if (!write_callback_.is_null()) {
+ net::CompletionCallback cb;
+ std::swap(cb, write_callback_);
+ cb.Run(net::ERR_FAILED);
+ }
+}
+
+void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
+ if (!read_callback_.is_null()) {
+ int result = channel_->DoRead(read_buffer_, read_buffer_size_);
+ read_buffer_ = NULL;
+ DCHECK_GT(result, 0);
+ net::CompletionCallback cb;
+ std::swap(cb, read_callback_);
+ cb.Run(result);
+ }
+}
+
+ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory,
+ const std::string& base_channel_name)
+ : base_channel_factory_(factory),
+ base_channel_name_(base_channel_name),
+ next_channel_id_(0),
+ destroyed_flag_(NULL) {
+ factory->CreateStreamChannel(
+ base_channel_name,
+ base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
+ base::Unretained(this)));
+}
+
+ChannelMultiplexer::~ChannelMultiplexer() {
+ DCHECK(pending_channels_.empty());
+ STLDeleteValues(&channels_);
+
+ // Cancel creation of the base channel if it hasn't finished.
+ if (base_channel_factory_)
+ base_channel_factory_->CancelChannelCreation(base_channel_name_);
+
+ if (destroyed_flag_)
+ *destroyed_flag_ = true;
+}
+
+void ChannelMultiplexer::CreateStreamChannel(
+ const std::string& name,
+ const StreamChannelCallback& callback) {
+ if (base_channel_.get()) {
+ // Already have |base_channel_|. Create new multiplexed channel
+ // synchronously.
+ callback.Run(GetOrCreateChannel(name)->CreateSocket());
+ } else if (!base_channel_.get() && !base_channel_factory_) {
+ // Fail synchronously if we failed to create |base_channel_|.
+ callback.Run(scoped_ptr<net::StreamSocket>());
+ } else {
+ // Still waiting for the |base_channel_|.
+ pending_channels_.push_back(PendingChannel(name, callback));
+ }
+}
+
+void ChannelMultiplexer::CreateDatagramChannel(
+ const std::string& name,
+ const DatagramChannelCallback& callback) {
+ NOTIMPLEMENTED();
+ callback.Run(scoped_ptr<net::Socket>());
+}
+
+void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
+ for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
+ it != pending_channels_.end(); ++it) {
+ if (it->name == name) {
+ pending_channels_.erase(it);
+ return;
+ }
+ }
+}
+
+void ChannelMultiplexer::OnBaseChannelReady(
+ scoped_ptr<net::StreamSocket> socket) {
+ base_channel_factory_ = NULL;
+ base_channel_ = socket.Pass();
+
+ if (!base_channel_.get()) {
+ // Notify all callers that we can't create any channels.
+ for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
+ it != pending_channels_.end(); ++it) {
+ it->callback.Run(scoped_ptr<net::StreamSocket>());
+ }
+ pending_channels_.clear();
+ return;
+ }
+
+ // Initialize reader and writer.
+ reader_.Init(base_channel_.get(),
+ base::Bind(&ChannelMultiplexer::OnIncomingPacket,
+ base::Unretained(this)));
+ writer_.Init(base_channel_.get(),
+ base::Bind(&ChannelMultiplexer::OnWriteFailed,
+ base::Unretained(this)));
+
+ // Now create all pending channels.
+ for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
+ it != pending_channels_.end(); ++it) {
+ it->callback.Run(GetOrCreateChannel(it->name)->CreateSocket());
+ }
+ pending_channels_.clear();
+}
+
+ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
+ const std::string& name) {
+ // Check if we already have a channel with the requested name.
+ std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
+ if (it != channels_.end())
+ return it->second;
+
+ // Create a new channel if we haven't found existing one.
+ MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
+ ++next_channel_id_;
+ channels_[channel->name()] = channel;
+ return channel;
+}
+
+
+void ChannelMultiplexer::OnWriteFailed(int error) {
+ bool destroyed = false;
+ destroyed_flag_ = &destroyed;
+ for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
+ it != channels_.end(); ++it) {
+ it->second->OnWriteFailed();
+ if (destroyed)
+ return;
+ }
+ destroyed_flag_ = NULL;
+}
+
+void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task) {
+ if (!packet->has_channel_id()) {
+ LOG(ERROR) << "Received packet without channel_id.";
+ done_task.Run();
+ return;
+ }
+
+ int receive_id = packet->channel_id();
+ MuxChannel* channel = NULL;
+ std::map<int, MuxChannel*>::iterator it =
+ channels_by_receive_id_.find(receive_id);
+ if (it != channels_by_receive_id_.end()) {
+ channel = it->second;
+ } else {
+ // This is a new |channel_id| we haven't seen before. Look it up by name.
+ if (!packet->has_channel_name()) {
+ LOG(ERROR) << "Received packet with unknown channel_id and "
+ "without channel_name.";
+ done_task.Run();
+ return;
+ }
+ channel = GetOrCreateChannel(packet->channel_name());
+ channel->set_receive_id(receive_id);
+ channels_by_receive_id_[receive_id] = channel;
+ }
+
+ channel->OnIncomingPacket(packet.Pass(), done_task);
+}
+
+bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task) {
+ return writer_.Write(SerializeAndFrameMessage(*packet), done_task);
+}
+
+} // namespace protocol
+} // namespace remoting
diff --git a/remoting/protocol/channel_multiplexer.h b/remoting/protocol/channel_multiplexer.h
new file mode 100644
index 0000000..0f16fb1
--- /dev/null
+++ b/remoting/protocol/channel_multiplexer.h
@@ -0,0 +1,88 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef REMOTING_PROTOCOL_CHANNEL_MULTIPLEXER_H_
+#define REMOTING_PROTOCOL_CHANNEL_MULTIPLEXER_H_
+
+#include "remoting/proto/mux.pb.h"
+#include "remoting/protocol/buffered_socket_writer.h"
+#include "remoting/protocol/channel_factory.h"
+#include "remoting/protocol/message_reader.h"
+
+namespace remoting {
+namespace protocol {
+
+class ChannelMultiplexer : public ChannelFactory {
+ public:
+ static const char kMuxChannelName[];
+
+ // |factory| is used to create the channel upon which to multiplex.
+ ChannelMultiplexer(ChannelFactory* factory,
+ const std::string& base_channel_name);
+ virtual ~ChannelMultiplexer();
+
+ // ChannelFactory interface.
+ virtual void CreateStreamChannel(
+ const std::string& name,
+ const StreamChannelCallback& callback) OVERRIDE;
+ virtual void CreateDatagramChannel(
+ const std::string& name,
+ const DatagramChannelCallback& callback) OVERRIDE;
+ virtual void CancelChannelCreation(const std::string& name) OVERRIDE;
+
+ private:
+ struct PendingChannel;
+ class MuxChannel;
+ class MuxSocket;
+ friend class MuxChannel;
+
+ // Callback for |base_channel_| creation.
+ void OnBaseChannelReady(scoped_ptr<net::StreamSocket> socket);
+
+ // Helper method used to create channels.
+ MuxChannel* GetOrCreateChannel(const std::string& name);
+
+ // Callbacks for |writer_| and |reader_|.
+ void OnWriteFailed(int error);
+ void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task);
+
+ // Called by MuxChannel.
+ bool DoWrite(scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task);
+
+ // Factory used to create |base_channel_|. Set to NULL once creation is
+ // finished or failed.
+ ChannelFactory* base_channel_factory_;
+
+ // Name of the underlying channel.
+ std::string base_channel_name_;
+
+ // The channel over which to multiplex.
+ scoped_ptr<net::StreamSocket> base_channel_;
+
+ // List of requested channels while we are waiting for |base_channel_|.
+ std::list<PendingChannel> pending_channels_;
+
+ int next_channel_id_;
+ std::map<std::string, MuxChannel*> channels_;
+
+ // Channels are added to |channels_by_receive_id_| only after we receive
+ // receive_id from the remote peer.
+ std::map<int, MuxChannel*> channels_by_receive_id_;
+
+ BufferedSocketWriter writer_;
+ ProtobufMessageReader<MultiplexPacket> reader_;
+
+ // Flag used by OnWriteFailed() to detect when the multiplexer is destroyed.
+ bool* destroyed_flag_;
+
+ DISALLOW_COPY_AND_ASSIGN(ChannelMultiplexer);
+};
+
+} // namespace protocol
+} // namespace remoting
+
+
+#endif // REMOTING_PROTOCOL_CHANNEL_MULTIPLEXER_H_
diff --git a/remoting/protocol/channel_multiplexer_unittest.cc b/remoting/protocol/channel_multiplexer_unittest.cc
new file mode 100644
index 0000000..11a459b
--- /dev/null
+++ b/remoting/protocol/channel_multiplexer_unittest.cc
@@ -0,0 +1,301 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "remoting/protocol/channel_multiplexer.h"
+
+#include "base/bind.h"
+#include "base/message_loop.h"
+#include "net/base/net_errors.h"
+#include "net/socket/socket.h"
+#include "net/socket/stream_socket.h"
+#include "remoting/base/constants.h"
+#include "remoting/protocol/connection_tester.h"
+#include "remoting/protocol/fake_session.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+using testing::_;
+using testing::AtMost;
+using testing::InvokeWithoutArgs;
+
+namespace remoting {
+namespace protocol {
+
+namespace {
+
+const int kMessageSize = 1024;
+const int kMessages = 100;
+const char kMuxChannelName[] = "mux";
+
+void QuitCurrentThread() {
+ MessageLoop::current()->PostTask(FROM_HERE, MessageLoop::QuitClosure());
+}
+
+class MockSocketCallback {
+ public:
+ MOCK_METHOD1(OnDone, void(int result));
+};
+
+} // namespace
+
+class ChannelMultiplexerTest : public testing::Test {
+ public:
+ void DeleteAll() {
+ host_socket1_.reset();
+ host_socket2_.reset();
+ client_socket1_.reset();
+ client_socket2_.reset();
+ host_mux_.reset();
+ client_mux_.reset();
+ }
+
+ protected:
+ virtual void SetUp() OVERRIDE {
+ // Create pair of multiplexers and connect them to each other.
+ host_mux_.reset(new ChannelMultiplexer(&host_session_, kMuxChannelName));
+ client_mux_.reset(new ChannelMultiplexer(&client_session_,
+ kMuxChannelName));
+ FakeSocket* host_socket =
+ host_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName);
+ FakeSocket* client_socket =
+ client_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName);
+ host_socket->PairWith(client_socket);
+
+ // Make writes asynchronous in one direction.
+ host_socket->set_async_write(true);
+ }
+
+ void CreateChannel(const std::string& name,
+ scoped_ptr<net::StreamSocket>* host_socket,
+ scoped_ptr<net::StreamSocket>* client_socket) {
+ int counter = 2;
+ host_mux_->CreateStreamChannel(name, base::Bind(
+ &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
+ host_socket, &counter));
+ client_mux_->CreateStreamChannel(name, base::Bind(
+ &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
+ client_socket, &counter));
+
+ message_loop_.Run();
+
+ EXPECT_TRUE(host_socket->get());
+ EXPECT_TRUE(client_socket->get());
+ }
+
+ void OnChannelConnected(
+ scoped_ptr<net::StreamSocket>* storage,
+ int* counter,
+ scoped_ptr<net::StreamSocket> socket) {
+ *storage = socket.Pass();
+ --(*counter);
+ EXPECT_GE(*counter, 0);
+ if (*counter == 0)
+ QuitCurrentThread();
+ }
+
+ scoped_refptr<net::IOBufferWithSize> CreateTestBuffer(int size) {
+ scoped_refptr<net::IOBufferWithSize> result =
+ new net::IOBufferWithSize(size);
+ for (int i = 0; i< size; ++i) {
+ result->data()[i] = rand() % 256;
+ }
+ return result;
+ }
+
+ MessageLoop message_loop_;
+
+ FakeSession host_session_;
+ FakeSession client_session_;
+
+ scoped_ptr<ChannelMultiplexer> host_mux_;
+ scoped_ptr<ChannelMultiplexer> client_mux_;
+
+ scoped_ptr<net::StreamSocket> host_socket1_;
+ scoped_ptr<net::StreamSocket> client_socket1_;
+ scoped_ptr<net::StreamSocket> host_socket2_;
+ scoped_ptr<net::StreamSocket> client_socket2_;
+};
+
+
+TEST_F(ChannelMultiplexerTest, OneChannel) {
+ scoped_ptr<net::StreamSocket> host_socket;
+ scoped_ptr<net::StreamSocket> client_socket;
+ ASSERT_NO_FATAL_FAILURE(CreateChannel("test", &host_socket, &client_socket));
+
+ StreamConnectionTester tester(host_socket.get(), client_socket.get(),
+ kMessageSize, kMessages);
+ tester.Start();
+ message_loop_.Run();
+ tester.CheckResults();
+}
+
+TEST_F(ChannelMultiplexerTest, TwoChannels) {
+ scoped_ptr<net::StreamSocket> host_socket1_;
+ scoped_ptr<net::StreamSocket> client_socket1_;
+ ASSERT_NO_FATAL_FAILURE(
+ CreateChannel("test", &host_socket1_, &client_socket1_));
+
+ scoped_ptr<net::StreamSocket> host_socket2_;
+ scoped_ptr<net::StreamSocket> client_socket2_;
+ ASSERT_NO_FATAL_FAILURE(
+ CreateChannel("ch2", &host_socket2_, &client_socket2_));
+
+ StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(),
+ kMessageSize, kMessages);
+ StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(),
+ kMessageSize, kMessages);
+ tester1.Start();
+ tester2.Start();
+ while (!tester1.done() || !tester2.done()) {
+ message_loop_.Run();
+ }
+ tester1.CheckResults();
+ tester2.CheckResults();
+}
+
+// Four channels, two in each direction
+TEST_F(ChannelMultiplexerTest, FourChannels) {
+ scoped_ptr<net::StreamSocket> host_socket1_;
+ scoped_ptr<net::StreamSocket> client_socket1_;
+ ASSERT_NO_FATAL_FAILURE(
+ CreateChannel("test", &host_socket1_, &client_socket1_));
+
+ scoped_ptr<net::StreamSocket> host_socket2_;
+ scoped_ptr<net::StreamSocket> client_socket2_;
+ ASSERT_NO_FATAL_FAILURE(
+ CreateChannel("ch2", &host_socket2_, &client_socket2_));
+
+ scoped_ptr<net::StreamSocket> host_socket3;
+ scoped_ptr<net::StreamSocket> client_socket3;
+ ASSERT_NO_FATAL_FAILURE(
+ CreateChannel("test3", &host_socket3, &client_socket3));
+
+ scoped_ptr<net::StreamSocket> host_socket4;
+ scoped_ptr<net::StreamSocket> client_socket4;
+ ASSERT_NO_FATAL_FAILURE(
+ CreateChannel("ch4", &host_socket4, &client_socket4));
+
+ StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(),
+ kMessageSize, kMessages);
+ StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(),
+ kMessageSize, kMessages);
+ StreamConnectionTester tester3(client_socket3.get(), host_socket3.get(),
+ kMessageSize, kMessages);
+ StreamConnectionTester tester4(client_socket4.get(), host_socket4.get(),
+ kMessageSize, kMessages);
+ tester1.Start();
+ tester2.Start();
+ tester3.Start();
+ tester4.Start();
+ while (!tester1.done() || !tester2.done() ||
+ !tester3.done() || !tester4.done()) {
+ message_loop_.Run();
+ }
+ tester1.CheckResults();
+ tester2.CheckResults();
+ tester3.CheckResults();
+ tester4.CheckResults();
+}
+
+TEST_F(ChannelMultiplexerTest, SyncFail) {
+ scoped_ptr<net::StreamSocket> host_socket1_;
+ scoped_ptr<net::StreamSocket> client_socket1_;
+ ASSERT_NO_FATAL_FAILURE(
+ CreateChannel("test", &host_socket1_, &client_socket1_));
+
+ scoped_ptr<net::StreamSocket> host_socket2_;
+ scoped_ptr<net::StreamSocket> client_socket2_;
+ ASSERT_NO_FATAL_FAILURE(
+ CreateChannel("ch2", &host_socket2_, &client_socket2_));
+
+ host_session_.GetStreamChannel(kMuxChannelName)->
+ set_next_write_error(net::ERR_FAILED);
+ host_session_.GetStreamChannel(kMuxChannelName)->
+ set_async_write(false);
+
+ scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
+
+ MockSocketCallback cb1;
+ MockSocketCallback cb2;
+
+ EXPECT_CALL(cb1, OnDone(_))
+ .Times(0);
+ EXPECT_CALL(cb2, OnDone(_))
+ .Times(0);
+
+ EXPECT_EQ(net::ERR_FAILED, host_socket1_->Write(buf, buf->size(), base::Bind(
+ &MockSocketCallback::OnDone, base::Unretained(&cb1))));
+ EXPECT_EQ(net::ERR_FAILED, host_socket2_->Write(buf, buf->size(), base::Bind(
+ &MockSocketCallback::OnDone, base::Unretained(&cb2))));
+
+ message_loop_.RunAllPending();
+}
+
+TEST_F(ChannelMultiplexerTest, AsyncFail) {
+ ASSERT_NO_FATAL_FAILURE(
+ CreateChannel("test", &host_socket1_, &client_socket1_));
+
+ ASSERT_NO_FATAL_FAILURE(
+ CreateChannel("ch2", &host_socket2_, &client_socket2_));
+
+ host_session_.GetStreamChannel(kMuxChannelName)->
+ set_next_write_error(net::ERR_FAILED);
+ host_session_.GetStreamChannel(kMuxChannelName)->
+ set_async_write(true);
+
+ scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
+
+ MockSocketCallback cb1;
+ MockSocketCallback cb2;
+ EXPECT_CALL(cb1, OnDone(net::ERR_FAILED));
+ EXPECT_CALL(cb2, OnDone(net::ERR_FAILED));
+
+ EXPECT_EQ(net::ERR_IO_PENDING,
+ host_socket1_->Write(buf, buf->size(), base::Bind(
+ &MockSocketCallback::OnDone, base::Unretained(&cb1))));
+ EXPECT_EQ(net::ERR_IO_PENDING,
+ host_socket2_->Write(buf, buf->size(), base::Bind(
+ &MockSocketCallback::OnDone, base::Unretained(&cb2))));
+
+ message_loop_.RunAllPending();
+}
+
+TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) {
+ ASSERT_NO_FATAL_FAILURE(
+ CreateChannel("test", &host_socket1_, &client_socket1_));
+ ASSERT_NO_FATAL_FAILURE(
+ CreateChannel("ch2", &host_socket2_, &client_socket2_));
+
+ host_session_.GetStreamChannel(kMuxChannelName)->
+ set_next_write_error(net::ERR_FAILED);
+ host_session_.GetStreamChannel(kMuxChannelName)->
+ set_async_write(true);
+
+ scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
+
+ MockSocketCallback cb1;
+ MockSocketCallback cb2;
+
+ EXPECT_CALL(cb1, OnDone(net::ERR_FAILED))
+ .Times(AtMost(1))
+ .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));
+ EXPECT_CALL(cb2, OnDone(net::ERR_FAILED))
+ .Times(AtMost(1))
+ .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));
+
+ EXPECT_EQ(net::ERR_IO_PENDING,
+ host_socket1_->Write(buf, buf->size(), base::Bind(
+ &MockSocketCallback::OnDone, base::Unretained(&cb1))));
+ EXPECT_EQ(net::ERR_IO_PENDING,
+ host_socket2_->Write(buf, buf->size(), base::Bind(
+ &MockSocketCallback::OnDone, base::Unretained(&cb2))));
+
+ message_loop_.RunAllPending();
+
+ // Check that the sockets were destroyed.
+ EXPECT_FALSE(host_mux_.get());
+}
+
+} // namespace protocol
+} // namespace remoting
diff --git a/remoting/protocol/connection_tester.h b/remoting/protocol/connection_tester.h
index b32cf62..f89cd7a 100644
--- a/remoting/protocol/connection_tester.h
+++ b/remoting/protocol/connection_tester.h
@@ -34,6 +34,7 @@ class StreamConnectionTester {
~StreamConnectionTester();
void Start();
+ bool done() { return done_; }
void CheckResults();
protected:
diff --git a/remoting/protocol/session.h b/remoting/protocol/session.h
index 9148318..7041486 100644
--- a/remoting/protocol/session.h
+++ b/remoting/protocol/session.h
@@ -7,16 +7,12 @@
#include <string>
-#include "base/callback.h"
-#include "base/threading/non_thread_safe.h"
-#include "remoting/protocol/buffered_socket_writer.h"
+#include "remoting/protocol/channel_factory.h"
#include "remoting/protocol/errors.h"
#include "remoting/protocol/session_config.h"
namespace net {
class IPEndPoint;
-class Socket;
-class StreamSocket;
} // namespace net
namespace remoting {
@@ -27,7 +23,7 @@ struct TransportRoute;
// Generic interface for Chromotocol connection used by both client and host.
// Provides access to the connection channels, but doesn't depend on the
// protocol used for each channel.
-class Session : public base::NonThreadSafe {
+class Session : public ChannelFactory {
public:
enum State {
// Created, but not connecting yet.
@@ -74,12 +70,6 @@ class Session : public base::NonThreadSafe {
bool ready) {}
};
- // TODO(sergeyu): Specify connection error code when channel
- // connection fails.
- typedef base::Callback<void(scoped_ptr<net::StreamSocket>)>
- StreamChannelCallback;
- typedef base::Callback<void(scoped_ptr<net::Socket>)>
- DatagramChannelCallback;
Session() {}
virtual ~Session() {}
@@ -91,23 +81,6 @@ class Session : public base::NonThreadSafe {
// Returns error code for a failed session.
virtual ErrorCode error() = 0;
- // Creates new channels for this connection. The specified callback
- // is called when then new channel is created and connected. The
- // callback is called with NULL if connection failed for any reason.
- // All channels must be destroyed before the session is
- // destroyed. Can be called only when in CONNECTING, CONNECTED or
- // AUTHENTICATED states.
- virtual void CreateStreamChannel(
- const std::string& name, const StreamChannelCallback& callback) = 0;
- virtual void CreateDatagramChannel(
- const std::string& name, const DatagramChannelCallback& callback) = 0;
-
- // Cancels a pending CreateStreamChannel() or CreateDatagramChannel()
- // operation for the named channel. If the channel creation already
- // completed then cancelling it has no effect. When shutting down
- // this method must be called for each channel pending creation.
- virtual void CancelChannelCreation(const std::string& name) = 0;
-
// JID of the other side.
virtual const std::string& jid() = 0;
diff --git a/remoting/remoting.gyp b/remoting/remoting.gyp
index fc8f683..48725f9 100644
--- a/remoting/remoting.gyp
+++ b/remoting/remoting.gyp
@@ -1616,6 +1616,8 @@
'protocol/channel_authenticator.h',
'protocol/channel_dispatcher_base.cc',
'protocol/channel_dispatcher_base.h',
+ 'protocol/channel_multiplexer.cc',
+ 'protocol/channel_multiplexer.h',
'protocol/client_control_dispatcher.cc',
'protocol/client_control_dispatcher.h',
'protocol/client_event_dispatcher.cc',
@@ -1623,11 +1625,11 @@
'protocol/client_stub.h',
'protocol/clipboard_echo_filter.cc',
'protocol/clipboard_echo_filter.h',
- 'protocol/clipboard_filter.h',
'protocol/clipboard_filter.cc',
+ 'protocol/clipboard_filter.h',
+ 'protocol/clipboard_stub.h',
'protocol/clipboard_thread_proxy.cc',
'protocol/clipboard_thread_proxy.h',
- 'protocol/clipboard_stub.h',
'protocol/connection_to_client.cc',
'protocol/connection_to_client.h',
'protocol/connection_to_host.cc',
@@ -1802,6 +1804,7 @@
'protocol/authenticator_test_base.cc',
'protocol/authenticator_test_base.h',
'protocol/buffered_socket_writer_unittest.cc',
+ 'protocol/channel_multiplexer_unittest.cc',
'protocol/clipboard_echo_filter_unittest.cc',
'protocol/connection_tester.cc',
'protocol/connection_tester.h',