diff options
Diffstat (limited to 'remoting/protocol')
33 files changed, 549 insertions, 298 deletions
diff --git a/remoting/protocol/BUILD.gn b/remoting/protocol/BUILD.gn index 189e5884..4a7fa0e 100644 --- a/remoting/protocol/BUILD.gn +++ b/remoting/protocol/BUILD.gn @@ -44,6 +44,7 @@ static_library("protocol") { "connection_to_host.h", "content_description.cc", "content_description.h", + "datagram_channel_factory.h", "errors.h", "host_control_dispatcher.cc", "host_control_dispatcher.h", @@ -98,6 +99,10 @@ static_library("protocol") { "protobuf_video_reader.h", "protobuf_video_writer.cc", "protobuf_video_writer.h", + "pseudotcp_channel_factory.cc", + "pseudotcp_channel_factory.h", + "secure_channel_factory.cc", + "secure_channel_factory.h", "session.h", "session_config.cc", "session_config.h", @@ -106,6 +111,7 @@ static_library("protocol") { "socket_util.h", "ssl_hmac_channel_authenticator.cc", "ssl_hmac_channel_authenticator.h", + "stream_channel_factory.h", "third_party_authenticator_base.cc", "third_party_authenticator_base.h", "third_party_client_authenticator.cc", diff --git a/remoting/protocol/authenticator_test_base.cc b/remoting/protocol/authenticator_test_base.cc index 3f1d9f6..bd63cd7 100644 --- a/remoting/protocol/authenticator_test_base.cc +++ b/remoting/protocol/authenticator_test_base.cc @@ -10,6 +10,7 @@ #include "base/path_service.h" #include "base/test/test_timeouts.h" #include "base/timer/timer.h" +#include "net/base/net_errors.h" #include "net/base/test_data_directory.h" #include "remoting/base/rsa_key_pair.h" #include "remoting/protocol/authenticator.h" @@ -157,14 +158,14 @@ void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) { } void AuthenticatorTestBase::OnHostConnected( - net::Error error, + int error, scoped_ptr<net::StreamSocket> socket) { host_callback_.OnDone(error); host_socket_ = socket.Pass(); } void AuthenticatorTestBase::OnClientConnected( - net::Error error, + int error, scoped_ptr<net::StreamSocket> socket) { client_callback_.OnDone(error); client_socket_ = socket.Pass(); diff --git a/remoting/protocol/authenticator_test_base.h b/remoting/protocol/authenticator_test_base.h index e20774a..c5c7e6b 100644 --- a/remoting/protocol/authenticator_test_base.h +++ b/remoting/protocol/authenticator_test_base.h @@ -9,7 +9,6 @@ #include "base/memory/ref_counted.h" #include "base/message_loop/message_loop.h" -#include "net/base/net_errors.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" @@ -37,7 +36,7 @@ class AuthenticatorTestBase : public testing::Test { public: MockChannelDoneCallback(); ~MockChannelDoneCallback(); - MOCK_METHOD1(OnDone, void(net::Error error)); + MOCK_METHOD1(OnDone, void(int error)); }; static void ContinueAuthExchangeWith(Authenticator* sender, @@ -49,9 +48,9 @@ class AuthenticatorTestBase : public testing::Test { void RunHostInitiatedAuthExchange(); void RunChannelAuth(bool expected_fail); - void OnHostConnected(net::Error error, + void OnHostConnected(int error, scoped_ptr<net::StreamSocket> socket); - void OnClientConnected(net::Error error, + void OnClientConnected(int error, scoped_ptr<net::StreamSocket> socket); base::MessageLoop message_loop_; diff --git a/remoting/protocol/channel_authenticator.h b/remoting/protocol/channel_authenticator.h index 7466b09..8bef908 100644 --- a/remoting/protocol/channel_authenticator.h +++ b/remoting/protocol/channel_authenticator.h @@ -8,7 +8,6 @@ #include <string> #include "base/callback_forward.h" -#include "net/base/net_errors.h" namespace net { class StreamSocket; @@ -23,14 +22,14 @@ namespace protocol { // should be used only once for one channel. class ChannelAuthenticator { public: - typedef base::Callback<void(net::Error error, scoped_ptr<net::StreamSocket>)> + typedef base::Callback<void(int error, scoped_ptr<net::StreamSocket>)> DoneCallback; virtual ~ChannelAuthenticator() {} - // Start authentication of the given |socket|. |done_callback| is - // called when authentication is finished. Callback may be invoked - // before this method returns. + // Start authentication of the given |socket|. |done_callback| is called when + // authentication is finished. Callback may be invoked before this method + // returns, and may delete the calling authenticator. virtual void SecureAndAuthenticate( scoped_ptr<net::StreamSocket> socket, const DoneCallback& done_callback) = 0; diff --git a/remoting/protocol/channel_dispatcher_base.cc b/remoting/protocol/channel_dispatcher_base.cc index 10d64006..c292b2a 100644 --- a/remoting/protocol/channel_dispatcher_base.cc +++ b/remoting/protocol/channel_dispatcher_base.cc @@ -6,9 +6,9 @@ #include "base/bind.h" #include "net/socket/stream_socket.h" -#include "remoting/protocol/channel_factory.h" #include "remoting/protocol/session.h" #include "remoting/protocol/session_config.h" +#include "remoting/protocol/stream_channel_factory.h" namespace remoting { namespace protocol { diff --git a/remoting/protocol/channel_dispatcher_base.h b/remoting/protocol/channel_dispatcher_base.h index 906d0f9..f71d291 100644 --- a/remoting/protocol/channel_dispatcher_base.h +++ b/remoting/protocol/channel_dispatcher_base.h @@ -19,7 +19,7 @@ namespace remoting { namespace protocol { struct ChannelConfig; -class ChannelFactory; +class StreamChannelFactory; class Session; // Base class for channel message dispatchers. It's responsible for @@ -56,7 +56,7 @@ class ChannelDispatcherBase { void OnChannelReady(scoped_ptr<net::StreamSocket> socket); std::string channel_name_; - ChannelFactory* channel_factory_; + StreamChannelFactory* channel_factory_; InitializedCallback initialized_callback_; scoped_ptr<net::StreamSocket> channel_; diff --git a/remoting/protocol/channel_multiplexer.cc b/remoting/protocol/channel_multiplexer.cc index a6a80b4..5751440 100644 --- a/remoting/protocol/channel_multiplexer.cc +++ b/remoting/protocol/channel_multiplexer.cc @@ -353,7 +353,7 @@ void ChannelMultiplexer::MuxSocket::OnPacketReceived() { } } -ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory, +ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory, const std::string& base_channel_name) : base_channel_factory_(factory), base_channel_name_(base_channel_name), diff --git a/remoting/protocol/channel_multiplexer.h b/remoting/protocol/channel_multiplexer.h index 924f132..506dc4b 100644 --- a/remoting/protocol/channel_multiplexer.h +++ b/remoting/protocol/channel_multiplexer.h @@ -8,22 +8,22 @@ #include "base/memory/weak_ptr.h" #include "remoting/proto/mux.pb.h" #include "remoting/protocol/buffered_socket_writer.h" -#include "remoting/protocol/channel_factory.h" #include "remoting/protocol/message_reader.h" +#include "remoting/protocol/stream_channel_factory.h" namespace remoting { namespace protocol { -class ChannelMultiplexer : public ChannelFactory { +class ChannelMultiplexer : public StreamChannelFactory { public: static const char kMuxChannelName[]; // |factory| is used to create the channel upon which to multiplex. - ChannelMultiplexer(ChannelFactory* factory, + ChannelMultiplexer(StreamChannelFactory* factory, const std::string& base_channel_name); virtual ~ChannelMultiplexer(); - // ChannelFactory interface. + // StreamChannelFactory interface. virtual void CreateChannel(const std::string& name, const ChannelCreatedCallback& callback) OVERRIDE; virtual void CancelChannelCreation(const std::string& name) OVERRIDE; @@ -59,7 +59,7 @@ class ChannelMultiplexer : public ChannelFactory { // Factory used to create |base_channel_|. Set to NULL once creation is // finished or failed. - ChannelFactory* base_channel_factory_; + StreamChannelFactory* base_channel_factory_; // Name of the underlying channel. std::string base_channel_name_; diff --git a/remoting/protocol/datagram_channel_factory.h b/remoting/protocol/datagram_channel_factory.h new file mode 100644 index 0000000..41ade7f --- /dev/null +++ b/remoting/protocol/datagram_channel_factory.h @@ -0,0 +1,45 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef REMOTING_PROTOCOL_DATAGRAM_CHANNEL_FACTORY_H_ +#define REMOTING_PROTOCOL_DATAGRAM_CHANNEL_FACTORY_H_ + +namespace net { +class Socket; +} // namespace net + +namespace remoting { +namespace protocol { + +class DatagramChannelFactory { + public: + typedef base::Callback<void(scoped_ptr<net::Socket>)> + ChannelCreatedCallback; + + DatagramChannelFactory() {} + + // Creates new channels and calls the |callback| when then new channel is + // created and connected. The |callback| is called with NULL if channel setup + // failed for any reason. Callback may be called synchronously, before the + // call returns. All channels must be destroyed, and CancelChannelCreation() + // called for any pending channels, before the factory is destroyed. + virtual void CreateChannel(const std::string& name, + const ChannelCreatedCallback& callback) = 0; + + // Cancels a pending CreateChannel() operation for the named channel. If the + // channel creation already completed then canceling it has no effect. When + // shutting down this method must be called for each channel pending creation. + virtual void CancelChannelCreation(const std::string& name) = 0; + + protected: + virtual ~DatagramChannelFactory() {} + + private: + DISALLOW_COPY_AND_ASSIGN(DatagramChannelFactory); +}; + +} // namespace protocol +} // namespace remoting + +#endif // REMOTING_PROTOCOL_DATAGRAM_CHANNEL_FACTORY_H_ diff --git a/remoting/protocol/fake_authenticator.cc b/remoting/protocol/fake_authenticator.cc index bf06c56..9f2b4f6 100644 --- a/remoting/protocol/fake_authenticator.cc +++ b/remoting/protocol/fake_authenticator.cc @@ -7,6 +7,7 @@ #include "base/message_loop/message_loop.h" #include "base/strings/string_number_conversions.h" #include "net/base/io_buffer.h" +#include "net/base/net_errors.h" #include "net/socket/stream_socket.h" #include "remoting/base/constants.h" #include "testing/gtest/include/gtest/gtest.h" @@ -34,31 +35,33 @@ void FakeChannelAuthenticator::SecureAndAuthenticate( if (async_) { done_callback_ = done_callback; - scoped_refptr<net::IOBuffer> write_buf = new net::IOBuffer(1); - write_buf->data()[0] = 0; - int result = - socket_->Write(write_buf.get(), - 1, - base::Bind(&FakeChannelAuthenticator::OnAuthBytesWritten, - weak_factory_.GetWeakPtr())); - if (result != net::ERR_IO_PENDING) { - // This will not call the callback because |did_read_bytes_| is - // still set to false. - OnAuthBytesWritten(result); + if (result_ != net::OK) { + // Don't write anything if we are going to reject auth to make test + // ordering deterministic. + did_write_bytes_ = true; + } else { + scoped_refptr<net::IOBuffer> write_buf = new net::IOBuffer(1); + write_buf->data()[0] = 0; + int result = socket_->Write( + write_buf.get(), 1, + base::Bind(&FakeChannelAuthenticator::OnAuthBytesWritten, + weak_factory_.GetWeakPtr())); + if (result != net::ERR_IO_PENDING) { + // This will not call the callback because |did_read_bytes_| is + // still set to false. + OnAuthBytesWritten(result); + } } scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(1); - result = - socket_->Read(read_buf.get(), - 1, + int result = + socket_->Read(read_buf.get(), 1, base::Bind(&FakeChannelAuthenticator::OnAuthBytesRead, weak_factory_.GetWeakPtr())); if (result != net::ERR_IO_PENDING) OnAuthBytesRead(result); } else { - if (result_ != net::OK) - socket_.reset(); - done_callback.Run(result_, socket_.Pass()); + CallDoneCallback(); } } @@ -67,7 +70,7 @@ void FakeChannelAuthenticator::OnAuthBytesWritten(int result) { EXPECT_FALSE(did_write_bytes_); did_write_bytes_ = true; if (did_read_bytes_) - done_callback_.Run(result_, socket_.Pass()); + CallDoneCallback(); } void FakeChannelAuthenticator::OnAuthBytesRead(int result) { @@ -75,7 +78,15 @@ void FakeChannelAuthenticator::OnAuthBytesRead(int result) { EXPECT_FALSE(did_read_bytes_); did_read_bytes_ = true; if (did_write_bytes_) - done_callback_.Run(result_, socket_.Pass()); + CallDoneCallback(); +} + +void FakeChannelAuthenticator::CallDoneCallback() { + DoneCallback callback = done_callback_; + done_callback_.Reset(); + if (result_ != net::OK) + socket_.reset(); + callback.Run(result_, socket_.Pass()); } FakeAuthenticator::FakeAuthenticator( diff --git a/remoting/protocol/fake_authenticator.h b/remoting/protocol/fake_authenticator.h index b74654d0..a6ddc74 100644 --- a/remoting/protocol/fake_authenticator.h +++ b/remoting/protocol/fake_authenticator.h @@ -24,14 +24,12 @@ class FakeChannelAuthenticator : public ChannelAuthenticator { const DoneCallback& done_callback) OVERRIDE; private: - void CallCallback( - net::Error error, - scoped_ptr<net::StreamSocket> socket); - void OnAuthBytesWritten(int result); void OnAuthBytesRead(int result); - net::Error result_; + void CallDoneCallback(); + + int result_; bool async_; scoped_ptr<net::StreamSocket> socket_; diff --git a/remoting/protocol/fake_session.cc b/remoting/protocol/fake_session.cc index 7c62ed2..f02a47a 100644 --- a/remoting/protocol/fake_session.cc +++ b/remoting/protocol/fake_session.cc @@ -320,11 +320,11 @@ void FakeSession::set_config(const SessionConfig& config) { config_ = config; } -ChannelFactory* FakeSession::GetTransportChannelFactory() { +StreamChannelFactory* FakeSession::GetTransportChannelFactory() { return this; } -ChannelFactory* FakeSession::GetMultiplexedChannelFactory() { +StreamChannelFactory* FakeSession::GetMultiplexedChannelFactory() { return this; } diff --git a/remoting/protocol/fake_session.h b/remoting/protocol/fake_session.h index c7793f2..5240681 100644 --- a/remoting/protocol/fake_session.h +++ b/remoting/protocol/fake_session.h @@ -14,8 +14,8 @@ #include "net/base/completion_callback.h" #include "net/socket/socket.h" #include "net/socket/stream_socket.h" -#include "remoting/protocol/channel_factory.h" #include "remoting/protocol/session.h" +#include "remoting/protocol/stream_channel_factory.h" namespace base { class MessageLoop; @@ -148,7 +148,7 @@ class FakeUdpSocket : public net::Socket { // FakeSession is a dummy protocol::Session that uses FakeSocket for all // channels. class FakeSession : public Session, - public ChannelFactory { + public StreamChannelFactory { public: FakeSession(); virtual ~FakeSession(); @@ -173,11 +173,11 @@ class FakeSession : public Session, virtual const CandidateSessionConfig* candidate_config() OVERRIDE; virtual const SessionConfig& config() OVERRIDE; virtual void set_config(const SessionConfig& config) OVERRIDE; - virtual ChannelFactory* GetTransportChannelFactory() OVERRIDE; - virtual ChannelFactory* GetMultiplexedChannelFactory() OVERRIDE; + virtual StreamChannelFactory* GetTransportChannelFactory() OVERRIDE; + virtual StreamChannelFactory* GetMultiplexedChannelFactory() OVERRIDE; virtual void Close() OVERRIDE; - // ChannelFactory interface. + // StreamChannelFactory interface. virtual void CreateChannel(const std::string& name, const ChannelCreatedCallback& callback) OVERRIDE; virtual void CancelChannelCreation(const std::string& name) OVERRIDE; diff --git a/remoting/protocol/jingle_session.cc b/remoting/protocol/jingle_session.cc index 727be4b..e6eb1e8 100644 --- a/remoting/protocol/jingle_session.cc +++ b/remoting/protocol/jingle_session.cc @@ -18,7 +18,10 @@ #include "remoting/protocol/content_description.h" #include "remoting/protocol/jingle_messages.h" #include "remoting/protocol/jingle_session_manager.h" +#include "remoting/protocol/pseudotcp_channel_factory.h" +#include "remoting/protocol/secure_channel_factory.h" #include "remoting/protocol/session_config.h" +#include "remoting/protocol/stream_channel_factory.h" #include "remoting/signaling/iq_sender.h" #include "third_party/libjingle/source/talk/p2p/base/candidate.h" #include "third_party/webrtc/libjingle/xmllite/xmlelement.h" @@ -81,7 +84,7 @@ JingleSession::~JingleSession() { pending_requests_.end()); STLDeleteContainerPointers(transport_info_requests_.begin(), transport_info_requests_.end()); - STLDeleteContainerPairSecondPointers(channels_.begin(), channels_.end()); + DCHECK(channels_.empty()); session_manager_->SessionDestroyed(this); } @@ -187,7 +190,7 @@ void JingleSession::ContinueAcceptIncomingConnection() { SetState(CONNECTED); if (authenticator_->state() == Authenticator::ACCEPTED) { - SetState(AUTHENTICATED); + OnAuthenticated(); } else { DCHECK_EQ(authenticator_->state(), Authenticator::WAITING_MESSAGE); if (authenticator_->started()) { @@ -218,15 +221,17 @@ void JingleSession::set_config(const SessionConfig& config) { config_is_set_ = true; } -ChannelFactory* JingleSession::GetTransportChannelFactory() { +StreamChannelFactory* JingleSession::GetTransportChannelFactory() { DCHECK(CalledOnValidThread()); - return this; + return secure_channel_factory_.get(); } -ChannelFactory* JingleSession::GetMultiplexedChannelFactory() { +StreamChannelFactory* JingleSession::GetMultiplexedChannelFactory() { DCHECK(CalledOnValidThread()); - if (!channel_multiplexer_.get()) - channel_multiplexer_.reset(new ChannelMultiplexer(this, kMuxChannelName)); + if (!channel_multiplexer_.get()) { + channel_multiplexer_.reset( + new ChannelMultiplexer(GetTransportChannelFactory(), kMuxChannelName)); + } return channel_multiplexer_.get(); } @@ -254,19 +259,17 @@ void JingleSession::CreateChannel(const std::string& name, const ChannelCreatedCallback& callback) { DCHECK(!channels_[name]); - scoped_ptr<ChannelAuthenticator> channel_authenticator = - authenticator_->CreateChannelAuthenticator(); - scoped_ptr<StreamTransport> channel = - session_manager_->transport_factory_->CreateStreamTransport(); - channel->Initialize(name, this, channel_authenticator.Pass()); - channel->Connect(callback); + scoped_ptr<Transport> channel = + session_manager_->transport_factory_->CreateTransport(); + channel->Connect(name, this, callback); AddPendingRemoteCandidates(channel.get(), name); channels_[name] = channel.release(); } void JingleSession::CancelChannelCreation(const std::string& name) { ChannelsMap::iterator it = channels_.find(name); - if (it != channels_.end() && !it->second->is_connected()) { + if (it != channels_.end()) { + DCHECK(!it->second->is_connected()); delete it->second; DCHECK(!channels_[name]); } @@ -598,13 +601,22 @@ void JingleSession::ProcessAuthenticationStep() { void JingleSession::ContinueAuthenticationStep() { if (authenticator_->state() == Authenticator::ACCEPTED) { - SetState(AUTHENTICATED); + OnAuthenticated(); } else if (authenticator_->state() == Authenticator::REJECTED) { CloseInternal(AuthRejectionReasonToErrorCode( authenticator_->rejection_reason())); } } +void JingleSession::OnAuthenticated() { + pseudotcp_channel_factory_.reset(new PseudoTcpChannelFactory(this)); + secure_channel_factory_.reset( + new SecureChannelFactory(pseudotcp_channel_factory_.get(), + authenticator_.get())); + + SetState(AUTHENTICATED); +} + void JingleSession::CloseInternal(ErrorCode error) { DCHECK(CalledOnValidThread()); diff --git a/remoting/protocol/jingle_session.h b/remoting/protocol/jingle_session.h index dfb96cb..9ec6206 100644 --- a/remoting/protocol/jingle_session.h +++ b/remoting/protocol/jingle_session.h @@ -15,7 +15,7 @@ #include "crypto/rsa_private_key.h" #include "net/base/completion_callback.h" #include "remoting/protocol/authenticator.h" -#include "remoting/protocol/channel_factory.h" +#include "remoting/protocol/datagram_channel_factory.h" #include "remoting/protocol/jingle_messages.h" #include "remoting/protocol/session.h" #include "remoting/protocol/session_config.h" @@ -30,14 +30,17 @@ class StreamSocket; namespace remoting { namespace protocol { +class SecureChannelFactory; class ChannelMultiplexer; class JingleSessionManager; +class PseudoTcpChannelFactory; // JingleSessionManager and JingleSession implement the subset of the // Jingle protocol used in Chromoting. Instances of this class are // created by the JingleSessionManager. -class JingleSession : public Session, - public ChannelFactory, +class JingleSession : public base::NonThreadSafe, + public Session, + public DatagramChannelFactory, public Transport::EventHandler { public: virtual ~JingleSession(); @@ -49,11 +52,11 @@ class JingleSession : public Session, virtual const CandidateSessionConfig* candidate_config() OVERRIDE; virtual const SessionConfig& config() OVERRIDE; virtual void set_config(const SessionConfig& config) OVERRIDE; - virtual ChannelFactory* GetTransportChannelFactory() OVERRIDE; - virtual ChannelFactory* GetMultiplexedChannelFactory() OVERRIDE; + virtual StreamChannelFactory* GetTransportChannelFactory() OVERRIDE; + virtual StreamChannelFactory* GetMultiplexedChannelFactory() OVERRIDE; virtual void Close() OVERRIDE; - // ChannelFactory interface. + // DatagramChannelFactory interface. virtual void CreateChannel(const std::string& name, const ChannelCreatedCallback& callback) OVERRIDE; virtual void CancelChannelCreation(const std::string& name) OVERRIDE; @@ -133,6 +136,9 @@ class JingleSession : public Session, // Called after the authenticating step is finished. void ContinueAuthenticationStep(); + // Called when authentication is finished. + void OnAuthenticated(); + // Terminates the session and sends session-terminate if it is // necessary. |error| specifies the error code in case when the // session is being closed due to an error. @@ -165,6 +171,8 @@ class JingleSession : public Session, std::list<IqRequest*> transport_info_requests_; ChannelsMap channels_; + scoped_ptr<PseudoTcpChannelFactory> pseudotcp_channel_factory_; + scoped_ptr<SecureChannelFactory> secure_channel_factory_; scoped_ptr<ChannelMultiplexer> channel_multiplexer_; base::OneShotTimer<JingleSession> transport_infos_timer_; diff --git a/remoting/protocol/jingle_session_unittest.cc b/remoting/protocol/jingle_session_unittest.cc index 882e24d..d7ce228 100644 --- a/remoting/protocol/jingle_session_unittest.cc +++ b/remoting/protocol/jingle_session_unittest.cc @@ -22,6 +22,7 @@ #include "remoting/protocol/jingle_session_manager.h" #include "remoting/protocol/libjingle_transport_factory.h" #include "remoting/protocol/network_settings.h" +#include "remoting/protocol/stream_channel_factory.h" #include "remoting/signaling/fake_signal_strategy.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" @@ -139,7 +140,7 @@ class JingleSessionTest : public testing::Test { } void CreateSessionManagers(int auth_round_trips, int messages_till_start, - FakeAuthenticator::Action auth_action) { + FakeAuthenticator::Action auth_action) { host_signal_strategy_.reset(new FakeSignalStrategy(kHostJid)); client_signal_strategy_.reset(new FakeSignalStrategy(kClientJid)); FakeSignalStrategy::Connect(host_signal_strategy_.get(), @@ -510,12 +511,13 @@ TEST_F(JingleSessionTest, TestFailedChannelAuth) { // from the host. EXPECT_CALL(host_channel_callback_, OnDone(NULL)) .WillOnce(QuitThread()); - EXPECT_CALL(client_channel_callback_, OnDone(_)) - .Times(AtMost(1)); ExpectRouteChange(kChannelName); message_loop_->Run(); + client_session_->GetTransportChannelFactory()->CancelChannelCreation( + kChannelName); + EXPECT_TRUE(!host_socket_.get()); } diff --git a/remoting/protocol/libjingle_transport_factory.cc b/remoting/protocol/libjingle_transport_factory.cc index cc31440..61ed3ea 100644 --- a/remoting/protocol/libjingle_transport_factory.cc +++ b/remoting/protocol/libjingle_transport_factory.cc @@ -9,11 +9,8 @@ #include "base/thread_task_runner_handle.h" #include "base/timer/timer.h" #include "jingle/glue/channel_socket_adapter.h" -#include "jingle/glue/pseudotcp_adapter.h" #include "jingle/glue/utils.h" #include "net/base/net_errors.h" -#include "remoting/base/constants.h" -#include "remoting/protocol/channel_authenticator.h" #include "remoting/protocol/network_settings.h" #include "remoting/signaling/jingle_info_request.h" #include "third_party/libjingle/source/talk/p2p/base/constants.h" @@ -28,15 +25,6 @@ namespace protocol { namespace { -// Value is chosen to balance the extra latency against the reduced -// load due to ACK traffic. -const int kTcpAckDelayMilliseconds = 10; - -// Values for the TCP send and receive buffer size. This should be tuned to -// accommodate high latency network but not backlog the decoding pipeline. -const int kTcpReceiveBufferSize = 256 * 1024; -const int kTcpSendBufferSize = kTcpReceiveBufferSize + 30 * 1024; - // Try connecting ICE twice with timeout of 15 seconds for each attempt. const int kMaxReconnectAttempts = 2; const int kReconnectDelaySeconds = 15; @@ -44,25 +32,23 @@ const int kReconnectDelaySeconds = 15; // Get fresh STUN/Relay configuration every hour. const int kJingleInfoUpdatePeriodSeconds = 3600; -class LibjingleStreamTransport - : public StreamTransport, - public base::SupportsWeakPtr<LibjingleStreamTransport>, +class LibjingleTransport + : public Transport, + public base::SupportsWeakPtr<LibjingleTransport>, public sigslot::has_slots<> { public: - LibjingleStreamTransport(cricket::PortAllocator* port_allocator, + LibjingleTransport(cricket::PortAllocator* port_allocator, const NetworkSettings& network_settings); - virtual ~LibjingleStreamTransport(); + virtual ~LibjingleTransport(); // Called by JingleTransportFactory when it has fresh Jingle info. void OnCanStart(); - // StreamTransport interface. - virtual void Initialize( + // Transport interface. + virtual void Connect( const std::string& name, Transport::EventHandler* event_handler, - scoped_ptr<ChannelAuthenticator> authenticator) OVERRIDE; - virtual void Connect( - const StreamTransport::ConnectedCallback& callback) OVERRIDE; + const Transport::ConnectedCallback& callback) OVERRIDE; virtual void AddRemoteCandidate(const cricket::Candidate& candidate) OVERRIDE; virtual const std::string& name() const OVERRIDE; virtual bool is_connected() const OVERRIDE; @@ -78,13 +64,6 @@ class LibjingleStreamTransport const cricket::Candidate& candidate); void OnWritableState(cricket::TransportChannel* channel); - // Callback for PseudoTcpAdapter::Connect(). - void OnTcpConnected(int result); - - // Callback for Authenticator::SecureAndAuthenticate(); - void OnAuthenticationDone(net::Error error, - scoped_ptr<net::StreamSocket> socket); - // Callback for jingle_glue::TransportChannelSocketAdapter to notify when the // socket is destroyed. void OnChannelDestroyed(); @@ -92,17 +71,12 @@ class LibjingleStreamTransport // Tries to connect by restarting ICE. Called by |reconnect_timer_|. void TryReconnect(); - // Helper methods to call |callback_|. - void NotifyConnected(scoped_ptr<net::StreamSocket> socket); - void NotifyConnectFailed(); - cricket::PortAllocator* port_allocator_; NetworkSettings network_settings_; std::string name_; EventHandler* event_handler_; - StreamTransport::ConnectedCallback callback_; - scoped_ptr<ChannelAuthenticator> authenticator_; + Transport::ConnectedCallback callback_; std::string ice_username_fragment_; std::string ice_password_; @@ -112,15 +86,12 @@ class LibjingleStreamTransport scoped_ptr<cricket::P2PTransportChannel> channel_; bool channel_was_writable_; int connect_attempts_left_; - base::RepeatingTimer<LibjingleStreamTransport> reconnect_timer_; - - // We own |socket_| until it is connected. - scoped_ptr<jingle_glue::PseudoTcpAdapter> socket_; + base::RepeatingTimer<LibjingleTransport> reconnect_timer_; - DISALLOW_COPY_AND_ASSIGN(LibjingleStreamTransport); + DISALLOW_COPY_AND_ASSIGN(LibjingleTransport); }; -LibjingleStreamTransport::LibjingleStreamTransport( +LibjingleTransport::LibjingleTransport( cricket::PortAllocator* port_allocator, const NetworkSettings& network_settings) : port_allocator_(port_allocator), @@ -136,11 +107,10 @@ LibjingleStreamTransport::LibjingleStreamTransport( DCHECK(!ice_password_.empty()); } -LibjingleStreamTransport::~LibjingleStreamTransport() { +LibjingleTransport::~LibjingleTransport() { DCHECK(event_handler_); + event_handler_->OnTransportDeleted(this); - // Channel should be already destroyed if we were connected. - DCHECK(!is_connected() || socket_.get() == NULL); if (channel_.get()) { base::ThreadTaskRunnerHandle::Get()->DeleteSoon( @@ -148,7 +118,7 @@ LibjingleStreamTransport::~LibjingleStreamTransport() { } } -void LibjingleStreamTransport::OnCanStart() { +void LibjingleTransport::OnCanStart() { DCHECK(CalledOnValidThread()); DCHECK(!can_start_); @@ -164,33 +134,25 @@ void LibjingleStreamTransport::OnCanStart() { } } -void LibjingleStreamTransport::Initialize( +void LibjingleTransport::Connect( const std::string& name, Transport::EventHandler* event_handler, - scoped_ptr<ChannelAuthenticator> authenticator) { + const Transport::ConnectedCallback& callback) { DCHECK(CalledOnValidThread()); - DCHECK(!name.empty()); DCHECK(event_handler); + DCHECK(!callback.is_null()); - // Can be initialized only once. DCHECK(name_.empty()); - name_ = name; event_handler_ = event_handler; - authenticator_ = authenticator.Pass(); -} - -void LibjingleStreamTransport::Connect( - const StreamTransport::ConnectedCallback& callback) { - DCHECK(CalledOnValidThread()); callback_ = callback; if (can_start_) DoStart(); } -void LibjingleStreamTransport::DoStart() { +void LibjingleTransport::DoStart() { DCHECK(!channel_.get()); // Create P2PTransportChannel, attach signal handlers and connect it. @@ -200,13 +162,13 @@ void LibjingleStreamTransport::DoStart() { channel_->SetIceProtocolType(cricket::ICEPROTO_GOOGLE); channel_->SetIceCredentials(ice_username_fragment_, ice_password_); channel_->SignalRequestSignaling.connect( - this, &LibjingleStreamTransport::OnRequestSignaling); + this, &LibjingleTransport::OnRequestSignaling); channel_->SignalCandidateReady.connect( - this, &LibjingleStreamTransport::OnCandidateReady); + this, &LibjingleTransport::OnCandidateReady); channel_->SignalRouteChange.connect( - this, &LibjingleStreamTransport::OnRouteChange); + this, &LibjingleTransport::OnRouteChange); channel_->SignalWritableState.connect( - this, &LibjingleStreamTransport::OnWritableState); + this, &LibjingleTransport::OnWritableState); channel_->set_incoming_only( !(network_settings_.flags & NetworkSettings::NAT_TRAVERSAL_OUTGOING)); @@ -217,37 +179,20 @@ void LibjingleStreamTransport::DoStart() { // Start reconnection timer. reconnect_timer_.Start( FROM_HERE, base::TimeDelta::FromSeconds(kReconnectDelaySeconds), - this, &LibjingleStreamTransport::TryReconnect); + this, &LibjingleTransport::TryReconnect); // Create net::Socket adapter for the P2PTransportChannel. - scoped_ptr<jingle_glue::TransportChannelSocketAdapter> channel_adapter( + scoped_ptr<jingle_glue::TransportChannelSocketAdapter> socket( new jingle_glue::TransportChannelSocketAdapter(channel_.get())); + socket->SetOnDestroyedCallback(base::Bind( + &LibjingleTransport::OnChannelDestroyed, base::Unretained(this))); - channel_adapter->SetOnDestroyedCallback(base::Bind( - &LibjingleStreamTransport::OnChannelDestroyed, base::Unretained(this))); - - // Configure and connect PseudoTCP adapter. - socket_.reset( - new jingle_glue::PseudoTcpAdapter(channel_adapter.release())); - socket_->SetSendBufferSize(kTcpSendBufferSize); - socket_->SetReceiveBufferSize(kTcpReceiveBufferSize); - socket_->SetNoDelay(true); - socket_->SetAckDelay(kTcpAckDelayMilliseconds); - - // TODO(sergeyu): This is a hack to improve latency of the video - // channel. Consider removing it once we have better flow control - // implemented. - if (name_ == kVideoChannelName) - socket_->SetWriteWaitsForSend(true); - - int result = socket_->Connect( - base::Bind(&LibjingleStreamTransport::OnTcpConnected, - base::Unretained(this))); - if (result != net::ERR_IO_PENDING) - OnTcpConnected(result); + Transport::ConnectedCallback callback = callback_; + callback_.Reset(); + callback.Run(socket.PassAs<net::Socket>()); } -void LibjingleStreamTransport::AddRemoteCandidate( +void LibjingleTransport::AddRemoteCandidate( const cricket::Candidate& candidate) { DCHECK(CalledOnValidThread()); @@ -265,30 +210,30 @@ void LibjingleStreamTransport::AddRemoteCandidate( } } -const std::string& LibjingleStreamTransport::name() const { +const std::string& LibjingleTransport::name() const { DCHECK(CalledOnValidThread()); return name_; } -bool LibjingleStreamTransport::is_connected() const { +bool LibjingleTransport::is_connected() const { DCHECK(CalledOnValidThread()); return callback_.is_null(); } -void LibjingleStreamTransport::OnRequestSignaling( +void LibjingleTransport::OnRequestSignaling( cricket::TransportChannelImpl* channel) { DCHECK(CalledOnValidThread()); channel_->OnSignalingReady(); } -void LibjingleStreamTransport::OnCandidateReady( +void LibjingleTransport::OnCandidateReady( cricket::TransportChannelImpl* channel, const cricket::Candidate& candidate) { DCHECK(CalledOnValidThread()); event_handler_->OnTransportCandidate(this, candidate); } -void LibjingleStreamTransport::OnRouteChange( +void LibjingleTransport::OnRouteChange( cricket::TransportChannel* channel, const cricket::Candidate& candidate) { TransportRoute route; @@ -319,7 +264,7 @@ void LibjingleStreamTransport::OnRouteChange( event_handler_->OnTransportRouteChange(this, route); } -void LibjingleStreamTransport::OnWritableState( +void LibjingleTransport::OnWritableState( cricket::TransportChannel* channel) { DCHECK_EQ(channel, channel_.get()); @@ -333,39 +278,14 @@ void LibjingleStreamTransport::OnWritableState( } } -void LibjingleStreamTransport::OnTcpConnected(int result) { - DCHECK(CalledOnValidThread()); - - if (result != net::OK) { - NotifyConnectFailed(); - return; - } - - authenticator_->SecureAndAuthenticate( - socket_.PassAs<net::StreamSocket>(), - base::Bind(&LibjingleStreamTransport::OnAuthenticationDone, - base::Unretained(this))); -} - -void LibjingleStreamTransport::OnAuthenticationDone( - net::Error error, - scoped_ptr<net::StreamSocket> socket) { - if (error != net::OK) { - NotifyConnectFailed(); - return; - } - - NotifyConnected(socket.Pass()); -} - -void LibjingleStreamTransport::OnChannelDestroyed() { +void LibjingleTransport::OnChannelDestroyed() { if (is_connected()) { // The connection socket is being deleted, so delete the transport too. delete this; } } -void LibjingleStreamTransport::TryReconnect() { +void LibjingleTransport::TryReconnect() { DCHECK(!channel_->writable()); if (connect_attempts_left_ <= 0) { @@ -383,31 +303,6 @@ void LibjingleStreamTransport::TryReconnect() { channel_->SetIceCredentials(ice_username_fragment_, ice_password_); } -void LibjingleStreamTransport::NotifyConnected( - scoped_ptr<net::StreamSocket> socket) { - DCHECK(!is_connected()); - StreamTransport::ConnectedCallback callback = callback_; - callback_.Reset(); - callback.Run(socket.Pass()); -} - -void LibjingleStreamTransport::NotifyConnectFailed() { - DCHECK(!is_connected()); - - socket_.reset(); - - // This method may be called in response to a libjingle signal, so - // libjingle objects must be deleted asynchronously. - if (channel_.get()) { - base::ThreadTaskRunnerHandle::Get()->DeleteSoon( - FROM_HERE, channel_.release()); - } - - authenticator_.reset(); - - NotifyConnected(scoped_ptr<net::StreamSocket>()); -} - } // namespace LibjingleTransportFactory::LibjingleTransportFactory( @@ -431,9 +326,9 @@ void LibjingleTransportFactory::PrepareTokens() { EnsureFreshJingleInfo(); } -scoped_ptr<StreamTransport> LibjingleTransportFactory::CreateStreamTransport() { - scoped_ptr<LibjingleStreamTransport> result( - new LibjingleStreamTransport(port_allocator_.get(), network_settings_)); +scoped_ptr<Transport> LibjingleTransportFactory::CreateTransport() { + scoped_ptr<LibjingleTransport> result( + new LibjingleTransport(port_allocator_.get(), network_settings_)); EnsureFreshJingleInfo(); @@ -441,19 +336,13 @@ scoped_ptr<StreamTransport> LibjingleTransportFactory::CreateStreamTransport() { // transport until the request is finished. if (jingle_info_request_) { on_jingle_info_callbacks_.push_back( - base::Bind(&LibjingleStreamTransport::OnCanStart, + base::Bind(&LibjingleTransport::OnCanStart, result->AsWeakPtr())); } else { result->OnCanStart(); } - return result.PassAs<StreamTransport>(); -} - -scoped_ptr<DatagramTransport> -LibjingleTransportFactory::CreateDatagramTransport() { - NOTIMPLEMENTED(); - return scoped_ptr<DatagramTransport>(); + return result.PassAs<Transport>(); } void LibjingleTransportFactory::EnsureFreshJingleInfo() { diff --git a/remoting/protocol/libjingle_transport_factory.h b/remoting/protocol/libjingle_transport_factory.h index 08661df..0b20ff7 100644 --- a/remoting/protocol/libjingle_transport_factory.h +++ b/remoting/protocol/libjingle_transport_factory.h @@ -47,8 +47,7 @@ class LibjingleTransportFactory : public TransportFactory { // TransportFactory interface. virtual void PrepareTokens() OVERRIDE; - virtual scoped_ptr<StreamTransport> CreateStreamTransport() OVERRIDE; - virtual scoped_ptr<DatagramTransport> CreateDatagramTransport() OVERRIDE; + virtual scoped_ptr<Transport> CreateTransport() OVERRIDE; private: void EnsureFreshJingleInfo(); diff --git a/remoting/protocol/protobuf_video_reader.cc b/remoting/protocol/protobuf_video_reader.cc index 7fc9eed..f566156 100644 --- a/remoting/protocol/protobuf_video_reader.cc +++ b/remoting/protocol/protobuf_video_reader.cc @@ -8,8 +8,8 @@ #include "net/socket/stream_socket.h" #include "remoting/base/constants.h" #include "remoting/proto/video.pb.h" -#include "remoting/protocol/channel_factory.h" #include "remoting/protocol/session.h" +#include "remoting/protocol/stream_channel_factory.h" namespace remoting { namespace protocol { diff --git a/remoting/protocol/protobuf_video_reader.h b/remoting/protocol/protobuf_video_reader.h index f6bb55c..ad04273 100644 --- a/remoting/protocol/protobuf_video_reader.h +++ b/remoting/protocol/protobuf_video_reader.h @@ -17,7 +17,7 @@ class StreamSocket; namespace remoting { namespace protocol { -class ChannelFactory; +class StreamChannelFactory; class Session; class ProtobufVideoReader : public VideoReader { @@ -40,7 +40,7 @@ class ProtobufVideoReader : public VideoReader { VideoPacketFormat::Encoding encoding_; - ChannelFactory* channel_factory_; + StreamChannelFactory* channel_factory_; scoped_ptr<net::StreamSocket> channel_; ProtobufMessageReader<VideoPacket> reader_; diff --git a/remoting/protocol/protobuf_video_writer.cc b/remoting/protocol/protobuf_video_writer.cc index 851d1ec..81f2632 100644 --- a/remoting/protocol/protobuf_video_writer.cc +++ b/remoting/protocol/protobuf_video_writer.cc @@ -8,9 +8,9 @@ #include "net/socket/stream_socket.h" #include "remoting/base/constants.h" #include "remoting/proto/video.pb.h" -#include "remoting/protocol/channel_factory.h" #include "remoting/protocol/message_serialization.h" #include "remoting/protocol/session.h" +#include "remoting/protocol/stream_channel_factory.h" namespace remoting { namespace protocol { diff --git a/remoting/protocol/protobuf_video_writer.h b/remoting/protocol/protobuf_video_writer.h index 961e821..b139e15 100644 --- a/remoting/protocol/protobuf_video_writer.h +++ b/remoting/protocol/protobuf_video_writer.h @@ -20,7 +20,7 @@ class StreamSocket; namespace remoting { namespace protocol { -class ChannelFactory; +class StreamChannelFactory; class Session; class ProtobufVideoWriter : public VideoWriter { @@ -43,7 +43,7 @@ class ProtobufVideoWriter : public VideoWriter { InitializedCallback initialized_callback_; - ChannelFactory* channel_factory_; + StreamChannelFactory* channel_factory_; scoped_ptr<net::StreamSocket> channel_; BufferedSocketWriter buffered_writer_; diff --git a/remoting/protocol/protocol_mock_objects.h b/remoting/protocol/protocol_mock_objects.h index 07e03306..2818604 100644 --- a/remoting/protocol/protocol_mock_objects.h +++ b/remoting/protocol/protocol_mock_objects.h @@ -167,8 +167,8 @@ class MockSession : public Session { MOCK_METHOD1(SetEventHandler, void(Session::EventHandler* event_handler)); MOCK_METHOD0(error, ErrorCode()); - MOCK_METHOD0(GetTransportChannelFactory, ChannelFactory*()); - MOCK_METHOD0(GetMultiplexedChannelFactory, ChannelFactory*()); + MOCK_METHOD0(GetTransportChannelFactory, StreamChannelFactory*()); + MOCK_METHOD0(GetMultiplexedChannelFactory, StreamChannelFactory*()); MOCK_METHOD0(jid, const std::string&()); MOCK_METHOD0(candidate_config, const CandidateSessionConfig*()); MOCK_METHOD0(config, const SessionConfig&()); diff --git a/remoting/protocol/pseudotcp_channel_factory.cc b/remoting/protocol/pseudotcp_channel_factory.cc new file mode 100644 index 0000000..689db92 --- /dev/null +++ b/remoting/protocol/pseudotcp_channel_factory.cc @@ -0,0 +1,100 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "remoting/protocol/pseudotcp_channel_factory.h" + +#include "base/bind.h" +#include "jingle/glue/pseudotcp_adapter.h" +#include "net/base/net_errors.h" +#include "net/socket/stream_socket.h" +#include "remoting/base/constants.h" +#include "remoting/protocol/datagram_channel_factory.h" + +namespace remoting { +namespace protocol { + +namespace { + +// Value is chosen to balance the extra latency against the reduced +// load due to ACK traffic. +const int kTcpAckDelayMilliseconds = 10; + +// Values for the TCP send and receive buffer size. This should be tuned to +// accommodate high latency network but not backlog the decoding pipeline. +const int kTcpReceiveBufferSize = 256 * 1024; +const int kTcpSendBufferSize = kTcpReceiveBufferSize + 30 * 1024; + +} // namespace + +PseudoTcpChannelFactory::PseudoTcpChannelFactory( + DatagramChannelFactory* datagram_channel_factory) + : datagram_channel_factory_(datagram_channel_factory) { +} + +PseudoTcpChannelFactory::~PseudoTcpChannelFactory() { + // CancelChannelCreation() is expected to be called before destruction. + DCHECK(pending_sockets_.empty()); +} + +void PseudoTcpChannelFactory::CreateChannel( + const std::string& name, + const ChannelCreatedCallback& callback) { + datagram_channel_factory_->CreateChannel( + name, + base::Bind(&PseudoTcpChannelFactory::OnDatagramChannelCreated, + base::Unretained(this), name, callback)); +} + +void PseudoTcpChannelFactory::CancelChannelCreation(const std::string& name) { + PendingSocketsMap::iterator it = pending_sockets_.find(name); + if (it == pending_sockets_.end()) { + datagram_channel_factory_->CancelChannelCreation(name); + } else { + delete it->second; + pending_sockets_.erase(it); + } +} + +void PseudoTcpChannelFactory::OnDatagramChannelCreated( + const std::string& name, + const ChannelCreatedCallback& callback, + scoped_ptr<net::Socket> datagram_socket) { + jingle_glue::PseudoTcpAdapter* adapter = + new jingle_glue::PseudoTcpAdapter(datagram_socket.release()); + pending_sockets_[name] = adapter; + + adapter->SetSendBufferSize(kTcpSendBufferSize); + adapter->SetReceiveBufferSize(kTcpReceiveBufferSize); + adapter->SetNoDelay(true); + adapter->SetAckDelay(kTcpAckDelayMilliseconds); + + // TODO(sergeyu): This is a hack to improve latency of the video channel. + // Consider removing it once we have better flow control implemented. + if (name == kVideoChannelName) + adapter->SetWriteWaitsForSend(true); + + int result = adapter->Connect( + base::Bind(&PseudoTcpChannelFactory::OnPseudoTcpConnected, + base::Unretained(this), name, callback)); + if (result != net::ERR_IO_PENDING) + OnPseudoTcpConnected(name, callback, result); +} + +void PseudoTcpChannelFactory::OnPseudoTcpConnected( + const std::string& name, + const ChannelCreatedCallback& callback, + int result) { + PendingSocketsMap::iterator it = pending_sockets_.find(name); + DCHECK(it != pending_sockets_.end()); + scoped_ptr<net::StreamSocket> socket(it->second); + pending_sockets_.erase(it); + + if (result != net::OK) + socket.reset(); + + callback.Run(socket.Pass()); +} + +} // namespace protocol +} // namespace remoting diff --git a/remoting/protocol/pseudotcp_channel_factory.h b/remoting/protocol/pseudotcp_channel_factory.h new file mode 100644 index 0000000..701b5d7 --- /dev/null +++ b/remoting/protocol/pseudotcp_channel_factory.h @@ -0,0 +1,52 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef REMOTING_PROTOCOL_PSEUDOTCP_CHANNEL_FACTORY_H_ +#define REMOTING_PROTOCOL_PSEUDOTCP_CHANNEL_FACTORY_H_ + +#include <map> + +#include "base/basictypes.h" +#include "remoting/protocol/stream_channel_factory.h" + +namespace remoting { +namespace protocol { + +class DatagramChannelFactory; + +// StreamChannelFactory that creates PseudoTCP-based stream channels that run on +// top of datagram channels created using specified |datagram_channel_factory|. +class PseudoTcpChannelFactory : public StreamChannelFactory { + public: + // |datagram_channel_factory| must outlive this object. + explicit PseudoTcpChannelFactory( + DatagramChannelFactory* datagram_channel_factory); + virtual ~PseudoTcpChannelFactory(); + + // StreamChannelFactory interface. + virtual void CreateChannel(const std::string& name, + const ChannelCreatedCallback& callback) OVERRIDE; + virtual void CancelChannelCreation(const std::string& name) OVERRIDE; + + private: + typedef std::map<std::string, net::StreamSocket*> PendingSocketsMap; + + void OnDatagramChannelCreated(const std::string& name, + const ChannelCreatedCallback& callback, + scoped_ptr<net::Socket> socket); + void OnPseudoTcpConnected(const std::string& name, + const ChannelCreatedCallback& callback, + int result); + + DatagramChannelFactory* datagram_channel_factory_; + + PendingSocketsMap pending_sockets_; + + DISALLOW_COPY_AND_ASSIGN(PseudoTcpChannelFactory); +}; + +} // namespace protocol +} // namespace remoting + +#endif // REMOTING_PROTOCOL_PSEUDOTCP_CHANNEL_FACTORY_H_ diff --git a/remoting/protocol/secure_channel_factory.cc b/remoting/protocol/secure_channel_factory.cc new file mode 100644 index 0000000..df98378 --- /dev/null +++ b/remoting/protocol/secure_channel_factory.cc @@ -0,0 +1,83 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "remoting/protocol/secure_channel_factory.h" + +#include "base/bind.h" +#include "net/socket/stream_socket.h" +#include "remoting/protocol/authenticator.h" +#include "remoting/protocol/channel_authenticator.h" + +namespace remoting { +namespace protocol { + +SecureChannelFactory::SecureChannelFactory( + StreamChannelFactory* channel_factory, + Authenticator* authenticator) + : channel_factory_(channel_factory), + authenticator_(authenticator) { + DCHECK_EQ(authenticator_->state(), Authenticator::ACCEPTED); +} + +SecureChannelFactory::~SecureChannelFactory() { + // CancelChannelCreation() is expected to be called before destruction. + DCHECK(channel_authenticators_.empty()); +} + +void SecureChannelFactory::CreateChannel( + const std::string& name, + const ChannelCreatedCallback& callback) { + DCHECK(!callback.is_null()); + channel_factory_->CreateChannel( + name, + base::Bind(&SecureChannelFactory::OnBaseChannelCreated, + base::Unretained(this), name, callback)); +} + +void SecureChannelFactory::CancelChannelCreation( + const std::string& name) { + AuthenticatorMap::iterator it = channel_authenticators_.find(name); + if (it == channel_authenticators_.end()) { + channel_factory_->CancelChannelCreation(name); + } else { + delete it->second; + channel_authenticators_.erase(it); + } +} + +void SecureChannelFactory::OnBaseChannelCreated( + const std::string& name, + const ChannelCreatedCallback& callback, + scoped_ptr<net::StreamSocket> socket) { + if (!socket) { + callback.Run(scoped_ptr<net::StreamSocket>()); + return; + } + + ChannelAuthenticator* channel_authenticator = + authenticator_->CreateChannelAuthenticator().release(); + channel_authenticators_[name] = channel_authenticator; + channel_authenticator->SecureAndAuthenticate( + socket.Pass(), + base::Bind(&SecureChannelFactory::OnSecureChannelCreated, + base::Unretained(this), name, callback)); +} + +void SecureChannelFactory::OnSecureChannelCreated( + const std::string& name, + const ChannelCreatedCallback& callback, + int error, + scoped_ptr<net::StreamSocket> socket) { + DCHECK((socket && error == net::OK) || (!socket && error != net::OK)); + + AuthenticatorMap::iterator it = channel_authenticators_.find(name); + DCHECK(it != channel_authenticators_.end()); + delete it->second; + channel_authenticators_.erase(it); + + callback.Run(socket.Pass()); +} + +} // namespace protocol +} // namespace remoting diff --git a/remoting/protocol/secure_channel_factory.h b/remoting/protocol/secure_channel_factory.h new file mode 100644 index 0000000..8f8e12e --- /dev/null +++ b/remoting/protocol/secure_channel_factory.h @@ -0,0 +1,60 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef REMOTING_PROTOCOL_SECURE_CHANNEL_FACTORY_H_ +#define REMOTING_PROTOCOL_SECURE_CHANNEL_FACTORY_H_ + +#include <map> + +#include "base/basictypes.h" +#include "net/base/net_errors.h" +#include "remoting/protocol/stream_channel_factory.h" + +namespace remoting { +namespace protocol { + +class Authenticator; +class ChannelAuthenticator; + +// StreamChannelFactory wrapper that authenticates every channel it creates. +// When CreateChannel() is called it first calls the wrapped +// StreamChannelFactory to create a channel and then uses the specified +// Authenticator to secure and authenticate the new channel before returning it +// to the caller. +class SecureChannelFactory : public StreamChannelFactory { + public: + // Both parameters must outlive the object. + SecureChannelFactory(StreamChannelFactory* channel_factory, + Authenticator* authenticator); + virtual ~SecureChannelFactory(); + + // StreamChannelFactory interface. + virtual void CreateChannel(const std::string& name, + const ChannelCreatedCallback& callback) OVERRIDE; + virtual void CancelChannelCreation(const std::string& name) OVERRIDE; + + private: + typedef std::map<std::string, ChannelAuthenticator*> AuthenticatorMap; + + void OnBaseChannelCreated(const std::string& name, + const ChannelCreatedCallback& callback, + scoped_ptr<net::StreamSocket> socket); + + void OnSecureChannelCreated(const std::string& name, + const ChannelCreatedCallback& callback, + int error, + scoped_ptr<net::StreamSocket> socket); + + StreamChannelFactory* channel_factory_; + Authenticator* authenticator_; + + AuthenticatorMap channel_authenticators_; + + DISALLOW_COPY_AND_ASSIGN(SecureChannelFactory); +}; + +} // namespace protocol +} // namespace remoting + +#endif // REMOTING_PROTOCOL_SECURE_CHANNEL_FACTORY_H_ diff --git a/remoting/protocol/session.h b/remoting/protocol/session.h index 35b3740..f806e21 100644 --- a/remoting/protocol/session.h +++ b/remoting/protocol/session.h @@ -17,7 +17,7 @@ class IPEndPoint; namespace remoting { namespace protocol { -class ChannelFactory; +class StreamChannelFactory; struct TransportRoute; // Generic interface for Chromotocol connection used by both client and host. @@ -99,8 +99,8 @@ class Session { // GetTransportChannelFactory() returns a factory that creates a new transport // channel for each logical channel. GetMultiplexedChannelFactory() channels // share a single underlying transport channel - virtual ChannelFactory* GetTransportChannelFactory() = 0; - virtual ChannelFactory* GetMultiplexedChannelFactory() = 0; + virtual StreamChannelFactory* GetTransportChannelFactory() = 0; + virtual StreamChannelFactory* GetMultiplexedChannelFactory() = 0; // Closes connection. Callbacks are guaranteed not to be called // after this method returns. Must be called before the object is diff --git a/remoting/protocol/ssl_hmac_channel_authenticator.cc b/remoting/protocol/ssl_hmac_channel_authenticator.cc index d85ad5f..f4bedea 100644 --- a/remoting/protocol/ssl_hmac_channel_authenticator.cc +++ b/remoting/protocol/ssl_hmac_channel_authenticator.cc @@ -279,13 +279,21 @@ void SslHmacChannelAuthenticator::CheckDone(bool* callback_called) { DCHECK(socket_.get() != NULL); if (callback_called) *callback_called = true; - done_callback_.Run(net::OK, socket_.PassAs<net::StreamSocket>()); + + CallDoneCallback(net::OK, socket_.PassAs<net::StreamSocket>()); } } void SslHmacChannelAuthenticator::NotifyError(int error) { - done_callback_.Run(static_cast<net::Error>(error), - scoped_ptr<net::StreamSocket>()); + CallDoneCallback(error, scoped_ptr<net::StreamSocket>()); +} + +void SslHmacChannelAuthenticator::CallDoneCallback( + int error, + scoped_ptr<net::StreamSocket> socket) { + DoneCallback callback = done_callback_; + done_callback_.Reset(); + callback.Run(error, socket.Pass()); } } // namespace protocol diff --git a/remoting/protocol/ssl_hmac_channel_authenticator.h b/remoting/protocol/ssl_hmac_channel_authenticator.h index f4223c4..849dab3 100644 --- a/remoting/protocol/ssl_hmac_channel_authenticator.h +++ b/remoting/protocol/ssl_hmac_channel_authenticator.h @@ -78,6 +78,7 @@ class SslHmacChannelAuthenticator : public ChannelAuthenticator, void CheckDone(bool* callback_called); void NotifyError(int error); + void CallDoneCallback(int error, scoped_ptr<net::StreamSocket> socket); // The mutual secret used for authentication. std::string auth_key_; diff --git a/remoting/protocol/ssl_hmac_channel_authenticator_unittest.cc b/remoting/protocol/ssl_hmac_channel_authenticator_unittest.cc index cb239fb..3b0818e 100644 --- a/remoting/protocol/ssl_hmac_channel_authenticator_unittest.cc +++ b/remoting/protocol/ssl_hmac_channel_authenticator_unittest.cc @@ -36,7 +36,7 @@ const char kTestSharedSecretBad[] = "0000-0000-0001"; class MockChannelDoneCallback { public: - MOCK_METHOD2(OnDone, void(net::Error error, net::StreamSocket* socket)); + MOCK_METHOD2(OnDone, void(int error, net::StreamSocket* socket)); }; ACTION_P(QuitThreadOnCounter, counter) { @@ -82,7 +82,7 @@ class SslHmacChannelAuthenticatorTest : public testing::Test { host_auth_->SecureAndAuthenticate( host_fake_socket_.PassAs<net::StreamSocket>(), base::Bind(&SslHmacChannelAuthenticatorTest::OnHostConnected, - base::Unretained(this))); + base::Unretained(this), std::string("ref argument value"))); // Expect two callbacks to be called - the client callback and the host // callback. @@ -109,14 +109,20 @@ class SslHmacChannelAuthenticatorTest : public testing::Test { message_loop_.Run(); } - void OnHostConnected(net::Error error, + void OnHostConnected(const std::string& ref_argument, + int error, scoped_ptr<net::StreamSocket> socket) { + // Try deleting the authenticator and verify that this doesn't destroy + // reference parameters. + host_auth_.reset(); + DCHECK_EQ(ref_argument, "ref argument value"); + host_callback_.OnDone(error, socket.get()); host_socket_ = socket.Pass(); } - void OnClientConnected(net::Error error, - scoped_ptr<net::StreamSocket> socket) { + void OnClientConnected(int error, scoped_ptr<net::StreamSocket> socket) { + client_auth_.reset(); client_callback_.OnDone(error, socket.get()); client_socket_ = socket.Pass(); } diff --git a/remoting/protocol/channel_factory.h b/remoting/protocol/stream_channel_factory.h index 31bf3ec..6c3ebd1 100644 --- a/remoting/protocol/channel_factory.h +++ b/remoting/protocol/stream_channel_factory.h @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef REMOTING_PROTOCOL_CHANNEL_FACTORY_H_ -#define REMOTING_PROTOCOL_CHANNEL_FACTORY_H_ +#ifndef REMOTING_PROTOCOL_STREAM_CHANNEL_FACTORY_H_ +#define REMOTING_PROTOCOL_STREAM_CHANNEL_FACTORY_H_ #include "base/callback.h" #include "base/memory/scoped_ptr.h" @@ -17,22 +17,20 @@ class StreamSocket; namespace remoting { namespace protocol { -class ChannelFactory : public base::NonThreadSafe { +class StreamChannelFactory : public base::NonThreadSafe { public: // TODO(sergeyu): Specify connection error code when channel // connection fails. typedef base::Callback<void(scoped_ptr<net::StreamSocket>)> ChannelCreatedCallback; - ChannelFactory() {} + StreamChannelFactory() {} - // Creates new channels for this connection. The specified callback is called - // when then new channel is created and connected. The callback is called with - // NULL if connection failed for any reason. Callback may be called - // synchronously, before the call returns. All channels must be destroyed - // before the factory is destroyed and CancelChannelCreation() must be called - // to cancel creation of channels for which the |callback| hasn't been called - // yet. + // Creates new channels and calls the |callback| when then new channel is + // created and connected. The |callback| is called with NULL if connection + // failed for any reason. Callback may be called synchronously, before the + // call returns. All channels must be destroyed, and CancelChannelCreation() + // called for any pending channels, before the factory is destroyed. virtual void CreateChannel(const std::string& name, const ChannelCreatedCallback& callback) = 0; @@ -42,13 +40,13 @@ class ChannelFactory : public base::NonThreadSafe { virtual void CancelChannelCreation(const std::string& name) = 0; protected: - virtual ~ChannelFactory() {} + virtual ~StreamChannelFactory() {} private: - DISALLOW_COPY_AND_ASSIGN(ChannelFactory); + DISALLOW_COPY_AND_ASSIGN(StreamChannelFactory); }; } // namespace protocol } // namespace remoting -#endif // REMOTING_PROTOCOL_CHANNEL_FACTORY_H_ +#endif // REMOTING_PROTOCOL_STREAM_CHANNEL_FACTORY_H_ diff --git a/remoting/protocol/transport.h b/remoting/protocol/transport.h index eb20e12..d4c4b3f 100644 --- a/remoting/protocol/transport.h +++ b/remoting/protocol/transport.h @@ -87,14 +87,15 @@ class Transport : public base::NonThreadSafe { virtual void OnTransportDeleted(Transport* transport) = 0; }; + typedef base::Callback<void(scoped_ptr<net::Socket>)> ConnectedCallback; + Transport() {} virtual ~Transport() {} - // Intialize the transport with the specified parameters. - // |authenticator| is used to secure and authenticate the connection. - virtual void Initialize(const std::string& name, - Transport::EventHandler* event_handler, - scoped_ptr<ChannelAuthenticator> authenticator) = 0; + // Connects the transport and calls the |callback| after that. + virtual void Connect(const std::string& name, + Transport::EventHandler* event_handler, + const ConnectedCallback& callback) = 0; // Adds |candidate| received from the peer. virtual void AddRemoteCandidate(const cricket::Candidate& candidate) = 0; @@ -111,32 +112,6 @@ class Transport : public base::NonThreadSafe { DISALLOW_COPY_AND_ASSIGN(Transport); }; -class StreamTransport : public Transport { - public: - typedef base::Callback<void(scoped_ptr<net::StreamSocket>)> ConnectedCallback; - - StreamTransport() { } - virtual ~StreamTransport() { } - - virtual void Connect(const ConnectedCallback& callback) = 0; - - private: - DISALLOW_COPY_AND_ASSIGN(StreamTransport); -}; - -class DatagramTransport : public Transport { - public: - typedef base::Callback<void(scoped_ptr<net::Socket>)> ConnectedCallback; - - DatagramTransport() { } - virtual ~DatagramTransport() { } - - virtual void Connect(const ConnectedCallback& callback) = 0; - - private: - DISALLOW_COPY_AND_ASSIGN(DatagramTransport); -}; - class TransportFactory { public: TransportFactory() { } @@ -148,8 +123,7 @@ class TransportFactory { // necessary while the session is being authenticated. virtual void PrepareTokens() = 0; - virtual scoped_ptr<StreamTransport> CreateStreamTransport() = 0; - virtual scoped_ptr<DatagramTransport> CreateDatagramTransport() = 0; + virtual scoped_ptr<Transport> CreateTransport() = 0; private: DISALLOW_COPY_AND_ASSIGN(TransportFactory); |