diff options
author | sergeyu <sergeyu@chromium.org> | 2014-09-11 14:45:02 -0700 |
---|---|---|
committer | Commit bot <commit-bot@chromium.org> | 2014-09-11 22:01:29 +0000 |
commit | 28d886c967e016a5d5812be43cd5916f577c2e10 (patch) | |
tree | 4510350de11125ab89cfcf60ae8a624a8659037b /remoting | |
parent | 042e7e077ee2cb726804c27313093241b97bf09e (diff) | |
download | chromium_src-28d886c967e016a5d5812be43cd5916f577c2e10.zip chromium_src-28d886c967e016a5d5812be43cd5916f577c2e10.tar.gz chromium_src-28d886c967e016a5d5812be43cd5916f577c2e10.tar.bz2 |
Move PseudoTCP and channel auth out of LibjingleTransportFactory.
Previously TransportFactory interface was responsible for creation
and initialization of several protocol layers, including PseudoTCP and
authentication (TLS). Simplified it so now it only creates raw datagram
transport channel. PseudoTcpChannelFactory is now responsible for
setting up PseudoTcpAdapter and AuthenticatingChannelFactory takes care
of channel authentication. Also added DatagramChannelFactory for
Datagram channels.
This change will make it possible to replace PseudoTcpChannelFactory
with an object that creates SCTP-based channels.
Also fixed a bug in SslHmacChannelAuthenticator. It wasn't working
properly when deleted from the callback. (base::Callback objects
shouldn't be deleted while being called because when deleted they
also destroy reference parameters values they are holding).
BUG=402993
Review URL: https://codereview.chromium.org/551173004
Cr-Commit-Position: refs/heads/master@{#294474}
Diffstat (limited to 'remoting')
34 files changed, 555 insertions, 298 deletions
diff --git a/remoting/protocol/BUILD.gn b/remoting/protocol/BUILD.gn index 189e5884..4a7fa0e 100644 --- a/remoting/protocol/BUILD.gn +++ b/remoting/protocol/BUILD.gn @@ -44,6 +44,7 @@ static_library("protocol") { "connection_to_host.h", "content_description.cc", "content_description.h", + "datagram_channel_factory.h", "errors.h", "host_control_dispatcher.cc", "host_control_dispatcher.h", @@ -98,6 +99,10 @@ static_library("protocol") { "protobuf_video_reader.h", "protobuf_video_writer.cc", "protobuf_video_writer.h", + "pseudotcp_channel_factory.cc", + "pseudotcp_channel_factory.h", + "secure_channel_factory.cc", + "secure_channel_factory.h", "session.h", "session_config.cc", "session_config.h", @@ -106,6 +111,7 @@ static_library("protocol") { "socket_util.h", "ssl_hmac_channel_authenticator.cc", "ssl_hmac_channel_authenticator.h", + "stream_channel_factory.h", "third_party_authenticator_base.cc", "third_party_authenticator_base.h", "third_party_client_authenticator.cc", diff --git a/remoting/protocol/authenticator_test_base.cc b/remoting/protocol/authenticator_test_base.cc index 3f1d9f6..bd63cd7 100644 --- a/remoting/protocol/authenticator_test_base.cc +++ b/remoting/protocol/authenticator_test_base.cc @@ -10,6 +10,7 @@ #include "base/path_service.h" #include "base/test/test_timeouts.h" #include "base/timer/timer.h" +#include "net/base/net_errors.h" #include "net/base/test_data_directory.h" #include "remoting/base/rsa_key_pair.h" #include "remoting/protocol/authenticator.h" @@ -157,14 +158,14 @@ void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) { } void AuthenticatorTestBase::OnHostConnected( - net::Error error, + int error, scoped_ptr<net::StreamSocket> socket) { host_callback_.OnDone(error); host_socket_ = socket.Pass(); } void AuthenticatorTestBase::OnClientConnected( - net::Error error, + int error, scoped_ptr<net::StreamSocket> socket) { client_callback_.OnDone(error); client_socket_ = socket.Pass(); diff --git a/remoting/protocol/authenticator_test_base.h b/remoting/protocol/authenticator_test_base.h index e20774a..c5c7e6b 100644 --- a/remoting/protocol/authenticator_test_base.h +++ b/remoting/protocol/authenticator_test_base.h @@ -9,7 +9,6 @@ #include "base/memory/ref_counted.h" #include "base/message_loop/message_loop.h" -#include "net/base/net_errors.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" @@ -37,7 +36,7 @@ class AuthenticatorTestBase : public testing::Test { public: MockChannelDoneCallback(); ~MockChannelDoneCallback(); - MOCK_METHOD1(OnDone, void(net::Error error)); + MOCK_METHOD1(OnDone, void(int error)); }; static void ContinueAuthExchangeWith(Authenticator* sender, @@ -49,9 +48,9 @@ class AuthenticatorTestBase : public testing::Test { void RunHostInitiatedAuthExchange(); void RunChannelAuth(bool expected_fail); - void OnHostConnected(net::Error error, + void OnHostConnected(int error, scoped_ptr<net::StreamSocket> socket); - void OnClientConnected(net::Error error, + void OnClientConnected(int error, scoped_ptr<net::StreamSocket> socket); base::MessageLoop message_loop_; diff --git a/remoting/protocol/channel_authenticator.h b/remoting/protocol/channel_authenticator.h index 7466b09..8bef908 100644 --- a/remoting/protocol/channel_authenticator.h +++ b/remoting/protocol/channel_authenticator.h @@ -8,7 +8,6 @@ #include <string> #include "base/callback_forward.h" -#include "net/base/net_errors.h" namespace net { class StreamSocket; @@ -23,14 +22,14 @@ namespace protocol { // should be used only once for one channel. class ChannelAuthenticator { public: - typedef base::Callback<void(net::Error error, scoped_ptr<net::StreamSocket>)> + typedef base::Callback<void(int error, scoped_ptr<net::StreamSocket>)> DoneCallback; virtual ~ChannelAuthenticator() {} - // Start authentication of the given |socket|. |done_callback| is - // called when authentication is finished. Callback may be invoked - // before this method returns. + // Start authentication of the given |socket|. |done_callback| is called when + // authentication is finished. Callback may be invoked before this method + // returns, and may delete the calling authenticator. virtual void SecureAndAuthenticate( scoped_ptr<net::StreamSocket> socket, const DoneCallback& done_callback) = 0; diff --git a/remoting/protocol/channel_dispatcher_base.cc b/remoting/protocol/channel_dispatcher_base.cc index 10d64006..c292b2a 100644 --- a/remoting/protocol/channel_dispatcher_base.cc +++ b/remoting/protocol/channel_dispatcher_base.cc @@ -6,9 +6,9 @@ #include "base/bind.h" #include "net/socket/stream_socket.h" -#include "remoting/protocol/channel_factory.h" #include "remoting/protocol/session.h" #include "remoting/protocol/session_config.h" +#include "remoting/protocol/stream_channel_factory.h" namespace remoting { namespace protocol { diff --git a/remoting/protocol/channel_dispatcher_base.h b/remoting/protocol/channel_dispatcher_base.h index 906d0f9..f71d291 100644 --- a/remoting/protocol/channel_dispatcher_base.h +++ b/remoting/protocol/channel_dispatcher_base.h @@ -19,7 +19,7 @@ namespace remoting { namespace protocol { struct ChannelConfig; -class ChannelFactory; +class StreamChannelFactory; class Session; // Base class for channel message dispatchers. It's responsible for @@ -56,7 +56,7 @@ class ChannelDispatcherBase { void OnChannelReady(scoped_ptr<net::StreamSocket> socket); std::string channel_name_; - ChannelFactory* channel_factory_; + StreamChannelFactory* channel_factory_; InitializedCallback initialized_callback_; scoped_ptr<net::StreamSocket> channel_; diff --git a/remoting/protocol/channel_multiplexer.cc b/remoting/protocol/channel_multiplexer.cc index a6a80b4..5751440 100644 --- a/remoting/protocol/channel_multiplexer.cc +++ b/remoting/protocol/channel_multiplexer.cc @@ -353,7 +353,7 @@ void ChannelMultiplexer::MuxSocket::OnPacketReceived() { } } -ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory, +ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory, const std::string& base_channel_name) : base_channel_factory_(factory), base_channel_name_(base_channel_name), diff --git a/remoting/protocol/channel_multiplexer.h b/remoting/protocol/channel_multiplexer.h index 924f132..506dc4b 100644 --- a/remoting/protocol/channel_multiplexer.h +++ b/remoting/protocol/channel_multiplexer.h @@ -8,22 +8,22 @@ #include "base/memory/weak_ptr.h" #include "remoting/proto/mux.pb.h" #include "remoting/protocol/buffered_socket_writer.h" -#include "remoting/protocol/channel_factory.h" #include "remoting/protocol/message_reader.h" +#include "remoting/protocol/stream_channel_factory.h" namespace remoting { namespace protocol { -class ChannelMultiplexer : public ChannelFactory { +class ChannelMultiplexer : public StreamChannelFactory { public: static const char kMuxChannelName[]; // |factory| is used to create the channel upon which to multiplex. - ChannelMultiplexer(ChannelFactory* factory, + ChannelMultiplexer(StreamChannelFactory* factory, const std::string& base_channel_name); virtual ~ChannelMultiplexer(); - // ChannelFactory interface. + // StreamChannelFactory interface. virtual void CreateChannel(const std::string& name, const ChannelCreatedCallback& callback) OVERRIDE; virtual void CancelChannelCreation(const std::string& name) OVERRIDE; @@ -59,7 +59,7 @@ class ChannelMultiplexer : public ChannelFactory { // Factory used to create |base_channel_|. Set to NULL once creation is // finished or failed. - ChannelFactory* base_channel_factory_; + StreamChannelFactory* base_channel_factory_; // Name of the underlying channel. std::string base_channel_name_; diff --git a/remoting/protocol/datagram_channel_factory.h b/remoting/protocol/datagram_channel_factory.h new file mode 100644 index 0000000..41ade7f --- /dev/null +++ b/remoting/protocol/datagram_channel_factory.h @@ -0,0 +1,45 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef REMOTING_PROTOCOL_DATAGRAM_CHANNEL_FACTORY_H_ +#define REMOTING_PROTOCOL_DATAGRAM_CHANNEL_FACTORY_H_ + +namespace net { +class Socket; +} // namespace net + +namespace remoting { +namespace protocol { + +class DatagramChannelFactory { + public: + typedef base::Callback<void(scoped_ptr<net::Socket>)> + ChannelCreatedCallback; + + DatagramChannelFactory() {} + + // Creates new channels and calls the |callback| when then new channel is + // created and connected. The |callback| is called with NULL if channel setup + // failed for any reason. Callback may be called synchronously, before the + // call returns. All channels must be destroyed, and CancelChannelCreation() + // called for any pending channels, before the factory is destroyed. + virtual void CreateChannel(const std::string& name, + const ChannelCreatedCallback& callback) = 0; + + // Cancels a pending CreateChannel() operation for the named channel. If the + // channel creation already completed then canceling it has no effect. When + // shutting down this method must be called for each channel pending creation. + virtual void CancelChannelCreation(const std::string& name) = 0; + + protected: + virtual ~DatagramChannelFactory() {} + + private: + DISALLOW_COPY_AND_ASSIGN(DatagramChannelFactory); +}; + +} // namespace protocol +} // namespace remoting + +#endif // REMOTING_PROTOCOL_DATAGRAM_CHANNEL_FACTORY_H_ diff --git a/remoting/protocol/fake_authenticator.cc b/remoting/protocol/fake_authenticator.cc index bf06c56..9f2b4f6 100644 --- a/remoting/protocol/fake_authenticator.cc +++ b/remoting/protocol/fake_authenticator.cc @@ -7,6 +7,7 @@ #include "base/message_loop/message_loop.h" #include "base/strings/string_number_conversions.h" #include "net/base/io_buffer.h" +#include "net/base/net_errors.h" #include "net/socket/stream_socket.h" #include "remoting/base/constants.h" #include "testing/gtest/include/gtest/gtest.h" @@ -34,31 +35,33 @@ void FakeChannelAuthenticator::SecureAndAuthenticate( if (async_) { done_callback_ = done_callback; - scoped_refptr<net::IOBuffer> write_buf = new net::IOBuffer(1); - write_buf->data()[0] = 0; - int result = - socket_->Write(write_buf.get(), - 1, - base::Bind(&FakeChannelAuthenticator::OnAuthBytesWritten, - weak_factory_.GetWeakPtr())); - if (result != net::ERR_IO_PENDING) { - // This will not call the callback because |did_read_bytes_| is - // still set to false. - OnAuthBytesWritten(result); + if (result_ != net::OK) { + // Don't write anything if we are going to reject auth to make test + // ordering deterministic. + did_write_bytes_ = true; + } else { + scoped_refptr<net::IOBuffer> write_buf = new net::IOBuffer(1); + write_buf->data()[0] = 0; + int result = socket_->Write( + write_buf.get(), 1, + base::Bind(&FakeChannelAuthenticator::OnAuthBytesWritten, + weak_factory_.GetWeakPtr())); + if (result != net::ERR_IO_PENDING) { + // This will not call the callback because |did_read_bytes_| is + // still set to false. + OnAuthBytesWritten(result); + } } scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(1); - result = - socket_->Read(read_buf.get(), - 1, + int result = + socket_->Read(read_buf.get(), 1, base::Bind(&FakeChannelAuthenticator::OnAuthBytesRead, weak_factory_.GetWeakPtr())); if (result != net::ERR_IO_PENDING) OnAuthBytesRead(result); } else { - if (result_ != net::OK) - socket_.reset(); - done_callback.Run(result_, socket_.Pass()); + CallDoneCallback(); } } @@ -67,7 +70,7 @@ void FakeChannelAuthenticator::OnAuthBytesWritten(int result) { EXPECT_FALSE(did_write_bytes_); did_write_bytes_ = true; if (did_read_bytes_) - done_callback_.Run(result_, socket_.Pass()); + CallDoneCallback(); } void FakeChannelAuthenticator::OnAuthBytesRead(int result) { @@ -75,7 +78,15 @@ void FakeChannelAuthenticator::OnAuthBytesRead(int result) { EXPECT_FALSE(did_read_bytes_); did_read_bytes_ = true; if (did_write_bytes_) - done_callback_.Run(result_, socket_.Pass()); + CallDoneCallback(); +} + +void FakeChannelAuthenticator::CallDoneCallback() { + DoneCallback callback = done_callback_; + done_callback_.Reset(); + if (result_ != net::OK) + socket_.reset(); + callback.Run(result_, socket_.Pass()); } FakeAuthenticator::FakeAuthenticator( diff --git a/remoting/protocol/fake_authenticator.h b/remoting/protocol/fake_authenticator.h index b74654d0..a6ddc74 100644 --- a/remoting/protocol/fake_authenticator.h +++ b/remoting/protocol/fake_authenticator.h @@ -24,14 +24,12 @@ class FakeChannelAuthenticator : public ChannelAuthenticator { const DoneCallback& done_callback) OVERRIDE; private: - void CallCallback( - net::Error error, - scoped_ptr<net::StreamSocket> socket); - void OnAuthBytesWritten(int result); void OnAuthBytesRead(int result); - net::Error result_; + void CallDoneCallback(); + + int result_; bool async_; scoped_ptr<net::StreamSocket> socket_; diff --git a/remoting/protocol/fake_session.cc b/remoting/protocol/fake_session.cc index 7c62ed2..f02a47a 100644 --- a/remoting/protocol/fake_session.cc +++ b/remoting/protocol/fake_session.cc @@ -320,11 +320,11 @@ void FakeSession::set_config(const SessionConfig& config) { config_ = config; } -ChannelFactory* FakeSession::GetTransportChannelFactory() { +StreamChannelFactory* FakeSession::GetTransportChannelFactory() { return this; } -ChannelFactory* FakeSession::GetMultiplexedChannelFactory() { +StreamChannelFactory* FakeSession::GetMultiplexedChannelFactory() { return this; } diff --git a/remoting/protocol/fake_session.h b/remoting/protocol/fake_session.h index c7793f2..5240681 100644 --- a/remoting/protocol/fake_session.h +++ b/remoting/protocol/fake_session.h @@ -14,8 +14,8 @@ #include "net/base/completion_callback.h" #include "net/socket/socket.h" #include "net/socket/stream_socket.h" -#include "remoting/protocol/channel_factory.h" #include "remoting/protocol/session.h" +#include "remoting/protocol/stream_channel_factory.h" namespace base { class MessageLoop; @@ -148,7 +148,7 @@ class FakeUdpSocket : public net::Socket { // FakeSession is a dummy protocol::Session that uses FakeSocket for all // channels. class FakeSession : public Session, - public ChannelFactory { + public StreamChannelFactory { public: FakeSession(); virtual ~FakeSession(); @@ -173,11 +173,11 @@ class FakeSession : public Session, virtual const CandidateSessionConfig* candidate_config() OVERRIDE; virtual const SessionConfig& config() OVERRIDE; virtual void set_config(const SessionConfig& config) OVERRIDE; - virtual ChannelFactory* GetTransportChannelFactory() OVERRIDE; - virtual ChannelFactory* GetMultiplexedChannelFactory() OVERRIDE; + virtual StreamChannelFactory* GetTransportChannelFactory() OVERRIDE; + virtual StreamChannelFactory* GetMultiplexedChannelFactory() OVERRIDE; virtual void Close() OVERRIDE; - // ChannelFactory interface. + // StreamChannelFactory interface. virtual void CreateChannel(const std::string& name, const ChannelCreatedCallback& callback) OVERRIDE; virtual void CancelChannelCreation(const std::string& name) OVERRIDE; diff --git a/remoting/protocol/jingle_session.cc b/remoting/protocol/jingle_session.cc index 727be4b..e6eb1e8 100644 --- a/remoting/protocol/jingle_session.cc +++ b/remoting/protocol/jingle_session.cc @@ -18,7 +18,10 @@ #include "remoting/protocol/content_description.h" #include "remoting/protocol/jingle_messages.h" #include "remoting/protocol/jingle_session_manager.h" +#include "remoting/protocol/pseudotcp_channel_factory.h" +#include "remoting/protocol/secure_channel_factory.h" #include "remoting/protocol/session_config.h" +#include "remoting/protocol/stream_channel_factory.h" #include "remoting/signaling/iq_sender.h" #include "third_party/libjingle/source/talk/p2p/base/candidate.h" #include "third_party/webrtc/libjingle/xmllite/xmlelement.h" @@ -81,7 +84,7 @@ JingleSession::~JingleSession() { pending_requests_.end()); STLDeleteContainerPointers(transport_info_requests_.begin(), transport_info_requests_.end()); - STLDeleteContainerPairSecondPointers(channels_.begin(), channels_.end()); + DCHECK(channels_.empty()); session_manager_->SessionDestroyed(this); } @@ -187,7 +190,7 @@ void JingleSession::ContinueAcceptIncomingConnection() { SetState(CONNECTED); if (authenticator_->state() == Authenticator::ACCEPTED) { - SetState(AUTHENTICATED); + OnAuthenticated(); } else { DCHECK_EQ(authenticator_->state(), Authenticator::WAITING_MESSAGE); if (authenticator_->started()) { @@ -218,15 +221,17 @@ void JingleSession::set_config(const SessionConfig& config) { config_is_set_ = true; } -ChannelFactory* JingleSession::GetTransportChannelFactory() { +StreamChannelFactory* JingleSession::GetTransportChannelFactory() { DCHECK(CalledOnValidThread()); - return this; + return secure_channel_factory_.get(); } -ChannelFactory* JingleSession::GetMultiplexedChannelFactory() { +StreamChannelFactory* JingleSession::GetMultiplexedChannelFactory() { DCHECK(CalledOnValidThread()); - if (!channel_multiplexer_.get()) - channel_multiplexer_.reset(new ChannelMultiplexer(this, kMuxChannelName)); + if (!channel_multiplexer_.get()) { + channel_multiplexer_.reset( + new ChannelMultiplexer(GetTransportChannelFactory(), kMuxChannelName)); + } return channel_multiplexer_.get(); } @@ -254,19 +259,17 @@ void JingleSession::CreateChannel(const std::string& name, const ChannelCreatedCallback& callback) { DCHECK(!channels_[name]); - scoped_ptr<ChannelAuthenticator> channel_authenticator = - authenticator_->CreateChannelAuthenticator(); - scoped_ptr<StreamTransport> channel = - session_manager_->transport_factory_->CreateStreamTransport(); - channel->Initialize(name, this, channel_authenticator.Pass()); - channel->Connect(callback); + scoped_ptr<Transport> channel = + session_manager_->transport_factory_->CreateTransport(); + channel->Connect(name, this, callback); AddPendingRemoteCandidates(channel.get(), name); channels_[name] = channel.release(); } void JingleSession::CancelChannelCreation(const std::string& name) { ChannelsMap::iterator it = channels_.find(name); - if (it != channels_.end() && !it->second->is_connected()) { + if (it != channels_.end()) { + DCHECK(!it->second->is_connected()); delete it->second; DCHECK(!channels_[name]); } @@ -598,13 +601,22 @@ void JingleSession::ProcessAuthenticationStep() { void JingleSession::ContinueAuthenticationStep() { if (authenticator_->state() == Authenticator::ACCEPTED) { - SetState(AUTHENTICATED); + OnAuthenticated(); } else if (authenticator_->state() == Authenticator::REJECTED) { CloseInternal(AuthRejectionReasonToErrorCode( authenticator_->rejection_reason())); } } +void JingleSession::OnAuthenticated() { + pseudotcp_channel_factory_.reset(new PseudoTcpChannelFactory(this)); + secure_channel_factory_.reset( + new SecureChannelFactory(pseudotcp_channel_factory_.get(), + authenticator_.get())); + + SetState(AUTHENTICATED); +} + void JingleSession::CloseInternal(ErrorCode error) { DCHECK(CalledOnValidThread()); diff --git a/remoting/protocol/jingle_session.h b/remoting/protocol/jingle_session.h index dfb96cb..9ec6206 100644 --- a/remoting/protocol/jingle_session.h +++ b/remoting/protocol/jingle_session.h @@ -15,7 +15,7 @@ #include "crypto/rsa_private_key.h" #include "net/base/completion_callback.h" #include "remoting/protocol/authenticator.h" -#include "remoting/protocol/channel_factory.h" +#include "remoting/protocol/datagram_channel_factory.h" #include "remoting/protocol/jingle_messages.h" #include "remoting/protocol/session.h" #include "remoting/protocol/session_config.h" @@ -30,14 +30,17 @@ class StreamSocket; namespace remoting { namespace protocol { +class SecureChannelFactory; class ChannelMultiplexer; class JingleSessionManager; +class PseudoTcpChannelFactory; // JingleSessionManager and JingleSession implement the subset of the // Jingle protocol used in Chromoting. Instances of this class are // created by the JingleSessionManager. -class JingleSession : public Session, - public ChannelFactory, +class JingleSession : public base::NonThreadSafe, + public Session, + public DatagramChannelFactory, public Transport::EventHandler { public: virtual ~JingleSession(); @@ -49,11 +52,11 @@ class JingleSession : public Session, virtual const CandidateSessionConfig* candidate_config() OVERRIDE; virtual const SessionConfig& config() OVERRIDE; virtual void set_config(const SessionConfig& config) OVERRIDE; - virtual ChannelFactory* GetTransportChannelFactory() OVERRIDE; - virtual ChannelFactory* GetMultiplexedChannelFactory() OVERRIDE; + virtual StreamChannelFactory* GetTransportChannelFactory() OVERRIDE; + virtual StreamChannelFactory* GetMultiplexedChannelFactory() OVERRIDE; virtual void Close() OVERRIDE; - // ChannelFactory interface. + // DatagramChannelFactory interface. virtual void CreateChannel(const std::string& name, const ChannelCreatedCallback& callback) OVERRIDE; virtual void CancelChannelCreation(const std::string& name) OVERRIDE; @@ -133,6 +136,9 @@ class JingleSession : public Session, // Called after the authenticating step is finished. void ContinueAuthenticationStep(); + // Called when authentication is finished. + void OnAuthenticated(); + // Terminates the session and sends session-terminate if it is // necessary. |error| specifies the error code in case when the // session is being closed due to an error. @@ -165,6 +171,8 @@ class JingleSession : public Session, std::list<IqRequest*> transport_info_requests_; ChannelsMap channels_; + scoped_ptr<PseudoTcpChannelFactory> pseudotcp_channel_factory_; + scoped_ptr<SecureChannelFactory> secure_channel_factory_; scoped_ptr<ChannelMultiplexer> channel_multiplexer_; base::OneShotTimer<JingleSession> transport_infos_timer_; diff --git a/remoting/protocol/jingle_session_unittest.cc b/remoting/protocol/jingle_session_unittest.cc index 882e24d..d7ce228 100644 --- a/remoting/protocol/jingle_session_unittest.cc +++ b/remoting/protocol/jingle_session_unittest.cc @@ -22,6 +22,7 @@ #include "remoting/protocol/jingle_session_manager.h" #include "remoting/protocol/libjingle_transport_factory.h" #include "remoting/protocol/network_settings.h" +#include "remoting/protocol/stream_channel_factory.h" #include "remoting/signaling/fake_signal_strategy.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" @@ -139,7 +140,7 @@ class JingleSessionTest : public testing::Test { } void CreateSessionManagers(int auth_round_trips, int messages_till_start, - FakeAuthenticator::Action auth_action) { + FakeAuthenticator::Action auth_action) { host_signal_strategy_.reset(new FakeSignalStrategy(kHostJid)); client_signal_strategy_.reset(new FakeSignalStrategy(kClientJid)); FakeSignalStrategy::Connect(host_signal_strategy_.get(), @@ -510,12 +511,13 @@ TEST_F(JingleSessionTest, TestFailedChannelAuth) { // from the host. EXPECT_CALL(host_channel_callback_, OnDone(NULL)) .WillOnce(QuitThread()); - EXPECT_CALL(client_channel_callback_, OnDone(_)) - .Times(AtMost(1)); ExpectRouteChange(kChannelName); message_loop_->Run(); + client_session_->GetTransportChannelFactory()->CancelChannelCreation( + kChannelName); + EXPECT_TRUE(!host_socket_.get()); } diff --git a/remoting/protocol/libjingle_transport_factory.cc b/remoting/protocol/libjingle_transport_factory.cc index cc31440..61ed3ea 100644 --- a/remoting/protocol/libjingle_transport_factory.cc +++ b/remoting/protocol/libjingle_transport_factory.cc @@ -9,11 +9,8 @@ #include "base/thread_task_runner_handle.h" #include "base/timer/timer.h" #include "jingle/glue/channel_socket_adapter.h" -#include "jingle/glue/pseudotcp_adapter.h" #include "jingle/glue/utils.h" #include "net/base/net_errors.h" -#include "remoting/base/constants.h" -#include "remoting/protocol/channel_authenticator.h" #include "remoting/protocol/network_settings.h" #include "remoting/signaling/jingle_info_request.h" #include "third_party/libjingle/source/talk/p2p/base/constants.h" @@ -28,15 +25,6 @@ namespace protocol { namespace { -// Value is chosen to balance the extra latency against the reduced -// load due to ACK traffic. -const int kTcpAckDelayMilliseconds = 10; - -// Values for the TCP send and receive buffer size. This should be tuned to -// accommodate high latency network but not backlog the decoding pipeline. -const int kTcpReceiveBufferSize = 256 * 1024; -const int kTcpSendBufferSize = kTcpReceiveBufferSize + 30 * 1024; - // Try connecting ICE twice with timeout of 15 seconds for each attempt. const int kMaxReconnectAttempts = 2; const int kReconnectDelaySeconds = 15; @@ -44,25 +32,23 @@ const int kReconnectDelaySeconds = 15; // Get fresh STUN/Relay configuration every hour. const int kJingleInfoUpdatePeriodSeconds = 3600; -class LibjingleStreamTransport - : public StreamTransport, - public base::SupportsWeakPtr<LibjingleStreamTransport>, +class LibjingleTransport + : public Transport, + public base::SupportsWeakPtr<LibjingleTransport>, public sigslot::has_slots<> { public: - LibjingleStreamTransport(cricket::PortAllocator* port_allocator, + LibjingleTransport(cricket::PortAllocator* port_allocator, const NetworkSettings& network_settings); - virtual ~LibjingleStreamTransport(); + virtual ~LibjingleTransport(); // Called by JingleTransportFactory when it has fresh Jingle info. void OnCanStart(); - // StreamTransport interface. - virtual void Initialize( + // Transport interface. + virtual void Connect( const std::string& name, Transport::EventHandler* event_handler, - scoped_ptr<ChannelAuthenticator> authenticator) OVERRIDE; - virtual void Connect( - const StreamTransport::ConnectedCallback& callback) OVERRIDE; + const Transport::ConnectedCallback& callback) OVERRIDE; virtual void AddRemoteCandidate(const cricket::Candidate& candidate) OVERRIDE; virtual const std::string& name() const OVERRIDE; virtual bool is_connected() const OVERRIDE; @@ -78,13 +64,6 @@ class LibjingleStreamTransport const cricket::Candidate& candidate); void OnWritableState(cricket::TransportChannel* channel); - // Callback for PseudoTcpAdapter::Connect(). - void OnTcpConnected(int result); - - // Callback for Authenticator::SecureAndAuthenticate(); - void OnAuthenticationDone(net::Error error, - scoped_ptr<net::StreamSocket> socket); - // Callback for jingle_glue::TransportChannelSocketAdapter to notify when the // socket is destroyed. void OnChannelDestroyed(); @@ -92,17 +71,12 @@ class LibjingleStreamTransport // Tries to connect by restarting ICE. Called by |reconnect_timer_|. void TryReconnect(); - // Helper methods to call |callback_|. - void NotifyConnected(scoped_ptr<net::StreamSocket> socket); - void NotifyConnectFailed(); - cricket::PortAllocator* port_allocator_; NetworkSettings network_settings_; std::string name_; EventHandler* event_handler_; - StreamTransport::ConnectedCallback callback_; - scoped_ptr<ChannelAuthenticator> authenticator_; + Transport::ConnectedCallback callback_; std::string ice_username_fragment_; std::string ice_password_; @@ -112,15 +86,12 @@ class LibjingleStreamTransport scoped_ptr<cricket::P2PTransportChannel> channel_; bool channel_was_writable_; int connect_attempts_left_; - base::RepeatingTimer<LibjingleStreamTransport> reconnect_timer_; - - // We own |socket_| until it is connected. - scoped_ptr<jingle_glue::PseudoTcpAdapter> socket_; + base::RepeatingTimer<LibjingleTransport> reconnect_timer_; - DISALLOW_COPY_AND_ASSIGN(LibjingleStreamTransport); + DISALLOW_COPY_AND_ASSIGN(LibjingleTransport); }; -LibjingleStreamTransport::LibjingleStreamTransport( +LibjingleTransport::LibjingleTransport( cricket::PortAllocator* port_allocator, const NetworkSettings& network_settings) : port_allocator_(port_allocator), @@ -136,11 +107,10 @@ LibjingleStreamTransport::LibjingleStreamTransport( DCHECK(!ice_password_.empty()); } -LibjingleStreamTransport::~LibjingleStreamTransport() { +LibjingleTransport::~LibjingleTransport() { DCHECK(event_handler_); + event_handler_->OnTransportDeleted(this); - // Channel should be already destroyed if we were connected. - DCHECK(!is_connected() || socket_.get() == NULL); if (channel_.get()) { base::ThreadTaskRunnerHandle::Get()->DeleteSoon( @@ -148,7 +118,7 @@ LibjingleStreamTransport::~LibjingleStreamTransport() { } } -void LibjingleStreamTransport::OnCanStart() { +void LibjingleTransport::OnCanStart() { DCHECK(CalledOnValidThread()); DCHECK(!can_start_); @@ -164,33 +134,25 @@ void LibjingleStreamTransport::OnCanStart() { } } -void LibjingleStreamTransport::Initialize( +void LibjingleTransport::Connect( const std::string& name, Transport::EventHandler* event_handler, - scoped_ptr<ChannelAuthenticator> authenticator) { + const Transport::ConnectedCallback& callback) { DCHECK(CalledOnValidThread()); - DCHECK(!name.empty()); DCHECK(event_handler); + DCHECK(!callback.is_null()); - // Can be initialized only once. DCHECK(name_.empty()); - name_ = name; event_handler_ = event_handler; - authenticator_ = authenticator.Pass(); -} - -void LibjingleStreamTransport::Connect( - const StreamTransport::ConnectedCallback& callback) { - DCHECK(CalledOnValidThread()); callback_ = callback; if (can_start_) DoStart(); } -void LibjingleStreamTransport::DoStart() { +void LibjingleTransport::DoStart() { DCHECK(!channel_.get()); // Create P2PTransportChannel, attach signal handlers and connect it. @@ -200,13 +162,13 @@ void LibjingleStreamTransport::DoStart() { channel_->SetIceProtocolType(cricket::ICEPROTO_GOOGLE); channel_->SetIceCredentials(ice_username_fragment_, ice_password_); channel_->SignalRequestSignaling.connect( - this, &LibjingleStreamTransport::OnRequestSignaling); + this, &LibjingleTransport::OnRequestSignaling); channel_->SignalCandidateReady.connect( - this, &LibjingleStreamTransport::OnCandidateReady); + this, &LibjingleTransport::OnCandidateReady); channel_->SignalRouteChange.connect( - this, &LibjingleStreamTransport::OnRouteChange); + this, &LibjingleTransport::OnRouteChange); channel_->SignalWritableState.connect( - this, &LibjingleStreamTransport::OnWritableState); + this, &LibjingleTransport::OnWritableState); channel_->set_incoming_only( !(network_settings_.flags & NetworkSettings::NAT_TRAVERSAL_OUTGOING)); @@ -217,37 +179,20 @@ void LibjingleStreamTransport::DoStart() { // Start reconnection timer. reconnect_timer_.Start( FROM_HERE, base::TimeDelta::FromSeconds(kReconnectDelaySeconds), - this, &LibjingleStreamTransport::TryReconnect); + this, &LibjingleTransport::TryReconnect); // Create net::Socket adapter for the P2PTransportChannel. - scoped_ptr<jingle_glue::TransportChannelSocketAdapter> channel_adapter( + scoped_ptr<jingle_glue::TransportChannelSocketAdapter> socket( new jingle_glue::TransportChannelSocketAdapter(channel_.get())); + socket->SetOnDestroyedCallback(base::Bind( + &LibjingleTransport::OnChannelDestroyed, base::Unretained(this))); - channel_adapter->SetOnDestroyedCallback(base::Bind( - &LibjingleStreamTransport::OnChannelDestroyed, base::Unretained(this))); - - // Configure and connect PseudoTCP adapter. - socket_.reset( - new jingle_glue::PseudoTcpAdapter(channel_adapter.release())); - socket_->SetSendBufferSize(kTcpSendBufferSize); - socket_->SetReceiveBufferSize(kTcpReceiveBufferSize); - socket_->SetNoDelay(true); - socket_->SetAckDelay(kTcpAckDelayMilliseconds); - - // TODO(sergeyu): This is a hack to improve latency of the video - // channel. Consider removing it once we have better flow control - // implemented. - if (name_ == kVideoChannelName) - socket_->SetWriteWaitsForSend(true); - - int result = socket_->Connect( - base::Bind(&LibjingleStreamTransport::OnTcpConnected, - base::Unretained(this))); - if (result != net::ERR_IO_PENDING) - OnTcpConnected(result); + Transport::ConnectedCallback callback = callback_; + callback_.Reset(); + callback.Run(socket.PassAs<net::Socket>()); } -void LibjingleStreamTransport::AddRemoteCandidate( +void LibjingleTransport::AddRemoteCandidate( const cricket::Candidate& candidate) { DCHECK(CalledOnValidThread()); @@ -265,30 +210,30 @@ void LibjingleStreamTransport::AddRemoteCandidate( } } -const std::string& LibjingleStreamTransport::name() const { +const std::string& LibjingleTransport::name() const { DCHECK(CalledOnValidThread()); return name_; } -bool LibjingleStreamTransport::is_connected() const { +bool LibjingleTransport::is_connected() const { DCHECK(CalledOnValidThread()); return callback_.is_null(); } -void LibjingleStreamTransport::OnRequestSignaling( +void LibjingleTransport::OnRequestSignaling( cricket::TransportChannelImpl* channel) { DCHECK(CalledOnValidThread()); channel_->OnSignalingReady(); } -void LibjingleStreamTransport::OnCandidateReady( +void LibjingleTransport::OnCandidateReady( cricket::TransportChannelImpl* channel, const cricket::Candidate& candidate) { DCHECK(CalledOnValidThread()); event_handler_->OnTransportCandidate(this, candidate); } -void LibjingleStreamTransport::OnRouteChange( +void LibjingleTransport::OnRouteChange( cricket::TransportChannel* channel, const cricket::Candidate& candidate) { TransportRoute route; @@ -319,7 +264,7 @@ void LibjingleStreamTransport::OnRouteChange( event_handler_->OnTransportRouteChange(this, route); } -void LibjingleStreamTransport::OnWritableState( +void LibjingleTransport::OnWritableState( cricket::TransportChannel* channel) { DCHECK_EQ(channel, channel_.get()); @@ -333,39 +278,14 @@ void LibjingleStreamTransport::OnWritableState( } } -void LibjingleStreamTransport::OnTcpConnected(int result) { - DCHECK(CalledOnValidThread()); - - if (result != net::OK) { - NotifyConnectFailed(); - return; - } - - authenticator_->SecureAndAuthenticate( - socket_.PassAs<net::StreamSocket>(), - base::Bind(&LibjingleStreamTransport::OnAuthenticationDone, - base::Unretained(this))); -} - -void LibjingleStreamTransport::OnAuthenticationDone( - net::Error error, - scoped_ptr<net::StreamSocket> socket) { - if (error != net::OK) { - NotifyConnectFailed(); - return; - } - - NotifyConnected(socket.Pass()); -} - -void LibjingleStreamTransport::OnChannelDestroyed() { +void LibjingleTransport::OnChannelDestroyed() { if (is_connected()) { // The connection socket is being deleted, so delete the transport too. delete this; } } -void LibjingleStreamTransport::TryReconnect() { +void LibjingleTransport::TryReconnect() { DCHECK(!channel_->writable()); if (connect_attempts_left_ <= 0) { @@ -383,31 +303,6 @@ void LibjingleStreamTransport::TryReconnect() { channel_->SetIceCredentials(ice_username_fragment_, ice_password_); } -void LibjingleStreamTransport::NotifyConnected( - scoped_ptr<net::StreamSocket> socket) { - DCHECK(!is_connected()); - StreamTransport::ConnectedCallback callback = callback_; - callback_.Reset(); - callback.Run(socket.Pass()); -} - -void LibjingleStreamTransport::NotifyConnectFailed() { - DCHECK(!is_connected()); - - socket_.reset(); - - // This method may be called in response to a libjingle signal, so - // libjingle objects must be deleted asynchronously. - if (channel_.get()) { - base::ThreadTaskRunnerHandle::Get()->DeleteSoon( - FROM_HERE, channel_.release()); - } - - authenticator_.reset(); - - NotifyConnected(scoped_ptr<net::StreamSocket>()); -} - } // namespace LibjingleTransportFactory::LibjingleTransportFactory( @@ -431,9 +326,9 @@ void LibjingleTransportFactory::PrepareTokens() { EnsureFreshJingleInfo(); } -scoped_ptr<StreamTransport> LibjingleTransportFactory::CreateStreamTransport() { - scoped_ptr<LibjingleStreamTransport> result( - new LibjingleStreamTransport(port_allocator_.get(), network_settings_)); +scoped_ptr<Transport> LibjingleTransportFactory::CreateTransport() { + scoped_ptr<LibjingleTransport> result( + new LibjingleTransport(port_allocator_.get(), network_settings_)); EnsureFreshJingleInfo(); @@ -441,19 +336,13 @@ scoped_ptr<StreamTransport> LibjingleTransportFactory::CreateStreamTransport() { // transport until the request is finished. if (jingle_info_request_) { on_jingle_info_callbacks_.push_back( - base::Bind(&LibjingleStreamTransport::OnCanStart, + base::Bind(&LibjingleTransport::OnCanStart, result->AsWeakPtr())); } else { result->OnCanStart(); } - return result.PassAs<StreamTransport>(); -} - -scoped_ptr<DatagramTransport> -LibjingleTransportFactory::CreateDatagramTransport() { - NOTIMPLEMENTED(); - return scoped_ptr<DatagramTransport>(); + return result.PassAs<Transport>(); } void LibjingleTransportFactory::EnsureFreshJingleInfo() { diff --git a/remoting/protocol/libjingle_transport_factory.h b/remoting/protocol/libjingle_transport_factory.h index 08661df..0b20ff7 100644 --- a/remoting/protocol/libjingle_transport_factory.h +++ b/remoting/protocol/libjingle_transport_factory.h @@ -47,8 +47,7 @@ class LibjingleTransportFactory : public TransportFactory { // TransportFactory interface. virtual void PrepareTokens() OVERRIDE; - virtual scoped_ptr<StreamTransport> CreateStreamTransport() OVERRIDE; - virtual scoped_ptr<DatagramTransport> CreateDatagramTransport() OVERRIDE; + virtual scoped_ptr<Transport> CreateTransport() OVERRIDE; private: void EnsureFreshJingleInfo(); diff --git a/remoting/protocol/protobuf_video_reader.cc b/remoting/protocol/protobuf_video_reader.cc index 7fc9eed..f566156 100644 --- a/remoting/protocol/protobuf_video_reader.cc +++ b/remoting/protocol/protobuf_video_reader.cc @@ -8,8 +8,8 @@ #include "net/socket/stream_socket.h" #include "remoting/base/constants.h" #include "remoting/proto/video.pb.h" -#include "remoting/protocol/channel_factory.h" #include "remoting/protocol/session.h" +#include "remoting/protocol/stream_channel_factory.h" namespace remoting { namespace protocol { diff --git a/remoting/protocol/protobuf_video_reader.h b/remoting/protocol/protobuf_video_reader.h index f6bb55c..ad04273 100644 --- a/remoting/protocol/protobuf_video_reader.h +++ b/remoting/protocol/protobuf_video_reader.h @@ -17,7 +17,7 @@ class StreamSocket; namespace remoting { namespace protocol { -class ChannelFactory; +class StreamChannelFactory; class Session; class ProtobufVideoReader : public VideoReader { @@ -40,7 +40,7 @@ class ProtobufVideoReader : public VideoReader { VideoPacketFormat::Encoding encoding_; - ChannelFactory* channel_factory_; + StreamChannelFactory* channel_factory_; scoped_ptr<net::StreamSocket> channel_; ProtobufMessageReader<VideoPacket> reader_; diff --git a/remoting/protocol/protobuf_video_writer.cc b/remoting/protocol/protobuf_video_writer.cc index 851d1ec..81f2632 100644 --- a/remoting/protocol/protobuf_video_writer.cc +++ b/remoting/protocol/protobuf_video_writer.cc @@ -8,9 +8,9 @@ #include "net/socket/stream_socket.h" #include "remoting/base/constants.h" #include "remoting/proto/video.pb.h" -#include "remoting/protocol/channel_factory.h" #include "remoting/protocol/message_serialization.h" #include "remoting/protocol/session.h" +#include "remoting/protocol/stream_channel_factory.h" namespace remoting { namespace protocol { diff --git a/remoting/protocol/protobuf_video_writer.h b/remoting/protocol/protobuf_video_writer.h index 961e821..b139e15 100644 --- a/remoting/protocol/protobuf_video_writer.h +++ b/remoting/protocol/protobuf_video_writer.h @@ -20,7 +20,7 @@ class StreamSocket; namespace remoting { namespace protocol { -class ChannelFactory; +class StreamChannelFactory; class Session; class ProtobufVideoWriter : public VideoWriter { @@ -43,7 +43,7 @@ class ProtobufVideoWriter : public VideoWriter { InitializedCallback initialized_callback_; - ChannelFactory* channel_factory_; + StreamChannelFactory* channel_factory_; scoped_ptr<net::StreamSocket> channel_; BufferedSocketWriter buffered_writer_; diff --git a/remoting/protocol/protocol_mock_objects.h b/remoting/protocol/protocol_mock_objects.h index 07e03306..2818604 100644 --- a/remoting/protocol/protocol_mock_objects.h +++ b/remoting/protocol/protocol_mock_objects.h @@ -167,8 +167,8 @@ class MockSession : public Session { MOCK_METHOD1(SetEventHandler, void(Session::EventHandler* event_handler)); MOCK_METHOD0(error, ErrorCode()); - MOCK_METHOD0(GetTransportChannelFactory, ChannelFactory*()); - MOCK_METHOD0(GetMultiplexedChannelFactory, ChannelFactory*()); + MOCK_METHOD0(GetTransportChannelFactory, StreamChannelFactory*()); + MOCK_METHOD0(GetMultiplexedChannelFactory, StreamChannelFactory*()); MOCK_METHOD0(jid, const std::string&()); MOCK_METHOD0(candidate_config, const CandidateSessionConfig*()); MOCK_METHOD0(config, const SessionConfig&()); diff --git a/remoting/protocol/pseudotcp_channel_factory.cc b/remoting/protocol/pseudotcp_channel_factory.cc new file mode 100644 index 0000000..689db92 --- /dev/null +++ b/remoting/protocol/pseudotcp_channel_factory.cc @@ -0,0 +1,100 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "remoting/protocol/pseudotcp_channel_factory.h" + +#include "base/bind.h" +#include "jingle/glue/pseudotcp_adapter.h" +#include "net/base/net_errors.h" +#include "net/socket/stream_socket.h" +#include "remoting/base/constants.h" +#include "remoting/protocol/datagram_channel_factory.h" + +namespace remoting { +namespace protocol { + +namespace { + +// Value is chosen to balance the extra latency against the reduced +// load due to ACK traffic. +const int kTcpAckDelayMilliseconds = 10; + +// Values for the TCP send and receive buffer size. This should be tuned to +// accommodate high latency network but not backlog the decoding pipeline. +const int kTcpReceiveBufferSize = 256 * 1024; +const int kTcpSendBufferSize = kTcpReceiveBufferSize + 30 * 1024; + +} // namespace + +PseudoTcpChannelFactory::PseudoTcpChannelFactory( + DatagramChannelFactory* datagram_channel_factory) + : datagram_channel_factory_(datagram_channel_factory) { +} + +PseudoTcpChannelFactory::~PseudoTcpChannelFactory() { + // CancelChannelCreation() is expected to be called before destruction. + DCHECK(pending_sockets_.empty()); +} + +void PseudoTcpChannelFactory::CreateChannel( + const std::string& name, + const ChannelCreatedCallback& callback) { + datagram_channel_factory_->CreateChannel( + name, + base::Bind(&PseudoTcpChannelFactory::OnDatagramChannelCreated, + base::Unretained(this), name, callback)); +} + +void PseudoTcpChannelFactory::CancelChannelCreation(const std::string& name) { + PendingSocketsMap::iterator it = pending_sockets_.find(name); + if (it == pending_sockets_.end()) { + datagram_channel_factory_->CancelChannelCreation(name); + } else { + delete it->second; + pending_sockets_.erase(it); + } +} + +void PseudoTcpChannelFactory::OnDatagramChannelCreated( + const std::string& name, + const ChannelCreatedCallback& callback, + scoped_ptr<net::Socket> datagram_socket) { + jingle_glue::PseudoTcpAdapter* adapter = + new jingle_glue::PseudoTcpAdapter(datagram_socket.release()); + pending_sockets_[name] = adapter; + + adapter->SetSendBufferSize(kTcpSendBufferSize); + adapter->SetReceiveBufferSize(kTcpReceiveBufferSize); + adapter->SetNoDelay(true); + adapter->SetAckDelay(kTcpAckDelayMilliseconds); + + // TODO(sergeyu): This is a hack to improve latency of the video channel. + // Consider removing it once we have better flow control implemented. + if (name == kVideoChannelName) + adapter->SetWriteWaitsForSend(true); + + int result = adapter->Connect( + base::Bind(&PseudoTcpChannelFactory::OnPseudoTcpConnected, + base::Unretained(this), name, callback)); + if (result != net::ERR_IO_PENDING) + OnPseudoTcpConnected(name, callback, result); +} + +void PseudoTcpChannelFactory::OnPseudoTcpConnected( + const std::string& name, + const ChannelCreatedCallback& callback, + int result) { + PendingSocketsMap::iterator it = pending_sockets_.find(name); + DCHECK(it != pending_sockets_.end()); + scoped_ptr<net::StreamSocket> socket(it->second); + pending_sockets_.erase(it); + + if (result != net::OK) + socket.reset(); + + callback.Run(socket.Pass()); +} + +} // namespace protocol +} // namespace remoting diff --git a/remoting/protocol/pseudotcp_channel_factory.h b/remoting/protocol/pseudotcp_channel_factory.h new file mode 100644 index 0000000..701b5d7 --- /dev/null +++ b/remoting/protocol/pseudotcp_channel_factory.h @@ -0,0 +1,52 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef REMOTING_PROTOCOL_PSEUDOTCP_CHANNEL_FACTORY_H_ +#define REMOTING_PROTOCOL_PSEUDOTCP_CHANNEL_FACTORY_H_ + +#include <map> + +#include "base/basictypes.h" +#include "remoting/protocol/stream_channel_factory.h" + +namespace remoting { +namespace protocol { + +class DatagramChannelFactory; + +// StreamChannelFactory that creates PseudoTCP-based stream channels that run on +// top of datagram channels created using specified |datagram_channel_factory|. +class PseudoTcpChannelFactory : public StreamChannelFactory { + public: + // |datagram_channel_factory| must outlive this object. + explicit PseudoTcpChannelFactory( + DatagramChannelFactory* datagram_channel_factory); + virtual ~PseudoTcpChannelFactory(); + + // StreamChannelFactory interface. + virtual void CreateChannel(const std::string& name, + const ChannelCreatedCallback& callback) OVERRIDE; + virtual void CancelChannelCreation(const std::string& name) OVERRIDE; + + private: + typedef std::map<std::string, net::StreamSocket*> PendingSocketsMap; + + void OnDatagramChannelCreated(const std::string& name, + const ChannelCreatedCallback& callback, + scoped_ptr<net::Socket> socket); + void OnPseudoTcpConnected(const std::string& name, + const ChannelCreatedCallback& callback, + int result); + + DatagramChannelFactory* datagram_channel_factory_; + + PendingSocketsMap pending_sockets_; + + DISALLOW_COPY_AND_ASSIGN(PseudoTcpChannelFactory); +}; + +} // namespace protocol +} // namespace remoting + +#endif // REMOTING_PROTOCOL_PSEUDOTCP_CHANNEL_FACTORY_H_ diff --git a/remoting/protocol/secure_channel_factory.cc b/remoting/protocol/secure_channel_factory.cc new file mode 100644 index 0000000..df98378 --- /dev/null +++ b/remoting/protocol/secure_channel_factory.cc @@ -0,0 +1,83 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "remoting/protocol/secure_channel_factory.h" + +#include "base/bind.h" +#include "net/socket/stream_socket.h" +#include "remoting/protocol/authenticator.h" +#include "remoting/protocol/channel_authenticator.h" + +namespace remoting { +namespace protocol { + +SecureChannelFactory::SecureChannelFactory( + StreamChannelFactory* channel_factory, + Authenticator* authenticator) + : channel_factory_(channel_factory), + authenticator_(authenticator) { + DCHECK_EQ(authenticator_->state(), Authenticator::ACCEPTED); +} + +SecureChannelFactory::~SecureChannelFactory() { + // CancelChannelCreation() is expected to be called before destruction. + DCHECK(channel_authenticators_.empty()); +} + +void SecureChannelFactory::CreateChannel( + const std::string& name, + const ChannelCreatedCallback& callback) { + DCHECK(!callback.is_null()); + channel_factory_->CreateChannel( + name, + base::Bind(&SecureChannelFactory::OnBaseChannelCreated, + base::Unretained(this), name, callback)); +} + +void SecureChannelFactory::CancelChannelCreation( + const std::string& name) { + AuthenticatorMap::iterator it = channel_authenticators_.find(name); + if (it == channel_authenticators_.end()) { + channel_factory_->CancelChannelCreation(name); + } else { + delete it->second; + channel_authenticators_.erase(it); + } +} + +void SecureChannelFactory::OnBaseChannelCreated( + const std::string& name, + const ChannelCreatedCallback& callback, + scoped_ptr<net::StreamSocket> socket) { + if (!socket) { + callback.Run(scoped_ptr<net::StreamSocket>()); + return; + } + + ChannelAuthenticator* channel_authenticator = + authenticator_->CreateChannelAuthenticator().release(); + channel_authenticators_[name] = channel_authenticator; + channel_authenticator->SecureAndAuthenticate( + socket.Pass(), + base::Bind(&SecureChannelFactory::OnSecureChannelCreated, + base::Unretained(this), name, callback)); +} + +void SecureChannelFactory::OnSecureChannelCreated( + const std::string& name, + const ChannelCreatedCallback& callback, + int error, + scoped_ptr<net::StreamSocket> socket) { + DCHECK((socket && error == net::OK) || (!socket && error != net::OK)); + + AuthenticatorMap::iterator it = channel_authenticators_.find(name); + DCHECK(it != channel_authenticators_.end()); + delete it->second; + channel_authenticators_.erase(it); + + callback.Run(socket.Pass()); +} + +} // namespace protocol +} // namespace remoting diff --git a/remoting/protocol/secure_channel_factory.h b/remoting/protocol/secure_channel_factory.h new file mode 100644 index 0000000..8f8e12e --- /dev/null +++ b/remoting/protocol/secure_channel_factory.h @@ -0,0 +1,60 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef REMOTING_PROTOCOL_SECURE_CHANNEL_FACTORY_H_ +#define REMOTING_PROTOCOL_SECURE_CHANNEL_FACTORY_H_ + +#include <map> + +#include "base/basictypes.h" +#include "net/base/net_errors.h" +#include "remoting/protocol/stream_channel_factory.h" + +namespace remoting { +namespace protocol { + +class Authenticator; +class ChannelAuthenticator; + +// StreamChannelFactory wrapper that authenticates every channel it creates. +// When CreateChannel() is called it first calls the wrapped +// StreamChannelFactory to create a channel and then uses the specified +// Authenticator to secure and authenticate the new channel before returning it +// to the caller. +class SecureChannelFactory : public StreamChannelFactory { + public: + // Both parameters must outlive the object. + SecureChannelFactory(StreamChannelFactory* channel_factory, + Authenticator* authenticator); + virtual ~SecureChannelFactory(); + + // StreamChannelFactory interface. + virtual void CreateChannel(const std::string& name, + const ChannelCreatedCallback& callback) OVERRIDE; + virtual void CancelChannelCreation(const std::string& name) OVERRIDE; + + private: + typedef std::map<std::string, ChannelAuthenticator*> AuthenticatorMap; + + void OnBaseChannelCreated(const std::string& name, + const ChannelCreatedCallback& callback, + scoped_ptr<net::StreamSocket> socket); + + void OnSecureChannelCreated(const std::string& name, + const ChannelCreatedCallback& callback, + int error, + scoped_ptr<net::StreamSocket> socket); + + StreamChannelFactory* channel_factory_; + Authenticator* authenticator_; + + AuthenticatorMap channel_authenticators_; + + DISALLOW_COPY_AND_ASSIGN(SecureChannelFactory); +}; + +} // namespace protocol +} // namespace remoting + +#endif // REMOTING_PROTOCOL_SECURE_CHANNEL_FACTORY_H_ diff --git a/remoting/protocol/session.h b/remoting/protocol/session.h index 35b3740..f806e21 100644 --- a/remoting/protocol/session.h +++ b/remoting/protocol/session.h @@ -17,7 +17,7 @@ class IPEndPoint; namespace remoting { namespace protocol { -class ChannelFactory; +class StreamChannelFactory; struct TransportRoute; // Generic interface for Chromotocol connection used by both client and host. @@ -99,8 +99,8 @@ class Session { // GetTransportChannelFactory() returns a factory that creates a new transport // channel for each logical channel. GetMultiplexedChannelFactory() channels // share a single underlying transport channel - virtual ChannelFactory* GetTransportChannelFactory() = 0; - virtual ChannelFactory* GetMultiplexedChannelFactory() = 0; + virtual StreamChannelFactory* GetTransportChannelFactory() = 0; + virtual StreamChannelFactory* GetMultiplexedChannelFactory() = 0; // Closes connection. Callbacks are guaranteed not to be called // after this method returns. Must be called before the object is diff --git a/remoting/protocol/ssl_hmac_channel_authenticator.cc b/remoting/protocol/ssl_hmac_channel_authenticator.cc index d85ad5f..f4bedea 100644 --- a/remoting/protocol/ssl_hmac_channel_authenticator.cc +++ b/remoting/protocol/ssl_hmac_channel_authenticator.cc @@ -279,13 +279,21 @@ void SslHmacChannelAuthenticator::CheckDone(bool* callback_called) { DCHECK(socket_.get() != NULL); if (callback_called) *callback_called = true; - done_callback_.Run(net::OK, socket_.PassAs<net::StreamSocket>()); + + CallDoneCallback(net::OK, socket_.PassAs<net::StreamSocket>()); } } void SslHmacChannelAuthenticator::NotifyError(int error) { - done_callback_.Run(static_cast<net::Error>(error), - scoped_ptr<net::StreamSocket>()); + CallDoneCallback(error, scoped_ptr<net::StreamSocket>()); +} + +void SslHmacChannelAuthenticator::CallDoneCallback( + int error, + scoped_ptr<net::StreamSocket> socket) { + DoneCallback callback = done_callback_; + done_callback_.Reset(); + callback.Run(error, socket.Pass()); } } // namespace protocol diff --git a/remoting/protocol/ssl_hmac_channel_authenticator.h b/remoting/protocol/ssl_hmac_channel_authenticator.h index f4223c4..849dab3 100644 --- a/remoting/protocol/ssl_hmac_channel_authenticator.h +++ b/remoting/protocol/ssl_hmac_channel_authenticator.h @@ -78,6 +78,7 @@ class SslHmacChannelAuthenticator : public ChannelAuthenticator, void CheckDone(bool* callback_called); void NotifyError(int error); + void CallDoneCallback(int error, scoped_ptr<net::StreamSocket> socket); // The mutual secret used for authentication. std::string auth_key_; diff --git a/remoting/protocol/ssl_hmac_channel_authenticator_unittest.cc b/remoting/protocol/ssl_hmac_channel_authenticator_unittest.cc index cb239fb..3b0818e 100644 --- a/remoting/protocol/ssl_hmac_channel_authenticator_unittest.cc +++ b/remoting/protocol/ssl_hmac_channel_authenticator_unittest.cc @@ -36,7 +36,7 @@ const char kTestSharedSecretBad[] = "0000-0000-0001"; class MockChannelDoneCallback { public: - MOCK_METHOD2(OnDone, void(net::Error error, net::StreamSocket* socket)); + MOCK_METHOD2(OnDone, void(int error, net::StreamSocket* socket)); }; ACTION_P(QuitThreadOnCounter, counter) { @@ -82,7 +82,7 @@ class SslHmacChannelAuthenticatorTest : public testing::Test { host_auth_->SecureAndAuthenticate( host_fake_socket_.PassAs<net::StreamSocket>(), base::Bind(&SslHmacChannelAuthenticatorTest::OnHostConnected, - base::Unretained(this))); + base::Unretained(this), std::string("ref argument value"))); // Expect two callbacks to be called - the client callback and the host // callback. @@ -109,14 +109,20 @@ class SslHmacChannelAuthenticatorTest : public testing::Test { message_loop_.Run(); } - void OnHostConnected(net::Error error, + void OnHostConnected(const std::string& ref_argument, + int error, scoped_ptr<net::StreamSocket> socket) { + // Try deleting the authenticator and verify that this doesn't destroy + // reference parameters. + host_auth_.reset(); + DCHECK_EQ(ref_argument, "ref argument value"); + host_callback_.OnDone(error, socket.get()); host_socket_ = socket.Pass(); } - void OnClientConnected(net::Error error, - scoped_ptr<net::StreamSocket> socket) { + void OnClientConnected(int error, scoped_ptr<net::StreamSocket> socket) { + client_auth_.reset(); client_callback_.OnDone(error, socket.get()); client_socket_ = socket.Pass(); } diff --git a/remoting/protocol/channel_factory.h b/remoting/protocol/stream_channel_factory.h index 31bf3ec..6c3ebd1 100644 --- a/remoting/protocol/channel_factory.h +++ b/remoting/protocol/stream_channel_factory.h @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef REMOTING_PROTOCOL_CHANNEL_FACTORY_H_ -#define REMOTING_PROTOCOL_CHANNEL_FACTORY_H_ +#ifndef REMOTING_PROTOCOL_STREAM_CHANNEL_FACTORY_H_ +#define REMOTING_PROTOCOL_STREAM_CHANNEL_FACTORY_H_ #include "base/callback.h" #include "base/memory/scoped_ptr.h" @@ -17,22 +17,20 @@ class StreamSocket; namespace remoting { namespace protocol { -class ChannelFactory : public base::NonThreadSafe { +class StreamChannelFactory : public base::NonThreadSafe { public: // TODO(sergeyu): Specify connection error code when channel // connection fails. typedef base::Callback<void(scoped_ptr<net::StreamSocket>)> ChannelCreatedCallback; - ChannelFactory() {} + StreamChannelFactory() {} - // Creates new channels for this connection. The specified callback is called - // when then new channel is created and connected. The callback is called with - // NULL if connection failed for any reason. Callback may be called - // synchronously, before the call returns. All channels must be destroyed - // before the factory is destroyed and CancelChannelCreation() must be called - // to cancel creation of channels for which the |callback| hasn't been called - // yet. + // Creates new channels and calls the |callback| when then new channel is + // created and connected. The |callback| is called with NULL if connection + // failed for any reason. Callback may be called synchronously, before the + // call returns. All channels must be destroyed, and CancelChannelCreation() + // called for any pending channels, before the factory is destroyed. virtual void CreateChannel(const std::string& name, const ChannelCreatedCallback& callback) = 0; @@ -42,13 +40,13 @@ class ChannelFactory : public base::NonThreadSafe { virtual void CancelChannelCreation(const std::string& name) = 0; protected: - virtual ~ChannelFactory() {} + virtual ~StreamChannelFactory() {} private: - DISALLOW_COPY_AND_ASSIGN(ChannelFactory); + DISALLOW_COPY_AND_ASSIGN(StreamChannelFactory); }; } // namespace protocol } // namespace remoting -#endif // REMOTING_PROTOCOL_CHANNEL_FACTORY_H_ +#endif // REMOTING_PROTOCOL_STREAM_CHANNEL_FACTORY_H_ diff --git a/remoting/protocol/transport.h b/remoting/protocol/transport.h index eb20e12..d4c4b3f 100644 --- a/remoting/protocol/transport.h +++ b/remoting/protocol/transport.h @@ -87,14 +87,15 @@ class Transport : public base::NonThreadSafe { virtual void OnTransportDeleted(Transport* transport) = 0; }; + typedef base::Callback<void(scoped_ptr<net::Socket>)> ConnectedCallback; + Transport() {} virtual ~Transport() {} - // Intialize the transport with the specified parameters. - // |authenticator| is used to secure and authenticate the connection. - virtual void Initialize(const std::string& name, - Transport::EventHandler* event_handler, - scoped_ptr<ChannelAuthenticator> authenticator) = 0; + // Connects the transport and calls the |callback| after that. + virtual void Connect(const std::string& name, + Transport::EventHandler* event_handler, + const ConnectedCallback& callback) = 0; // Adds |candidate| received from the peer. virtual void AddRemoteCandidate(const cricket::Candidate& candidate) = 0; @@ -111,32 +112,6 @@ class Transport : public base::NonThreadSafe { DISALLOW_COPY_AND_ASSIGN(Transport); }; -class StreamTransport : public Transport { - public: - typedef base::Callback<void(scoped_ptr<net::StreamSocket>)> ConnectedCallback; - - StreamTransport() { } - virtual ~StreamTransport() { } - - virtual void Connect(const ConnectedCallback& callback) = 0; - - private: - DISALLOW_COPY_AND_ASSIGN(StreamTransport); -}; - -class DatagramTransport : public Transport { - public: - typedef base::Callback<void(scoped_ptr<net::Socket>)> ConnectedCallback; - - DatagramTransport() { } - virtual ~DatagramTransport() { } - - virtual void Connect(const ConnectedCallback& callback) = 0; - - private: - DISALLOW_COPY_AND_ASSIGN(DatagramTransport); -}; - class TransportFactory { public: TransportFactory() { } @@ -148,8 +123,7 @@ class TransportFactory { // necessary while the session is being authenticated. virtual void PrepareTokens() = 0; - virtual scoped_ptr<StreamTransport> CreateStreamTransport() = 0; - virtual scoped_ptr<DatagramTransport> CreateDatagramTransport() = 0; + virtual scoped_ptr<Transport> CreateTransport() = 0; private: DISALLOW_COPY_AND_ASSIGN(TransportFactory); diff --git a/remoting/remoting_srcs.gypi b/remoting/remoting_srcs.gypi index e9440ad..e8611d6 100644 --- a/remoting/remoting_srcs.gypi +++ b/remoting/remoting_srcs.gypi @@ -113,6 +113,7 @@ 'protocol/connection_to_host.h', 'protocol/content_description.cc', 'protocol/content_description.h', + 'protocol/datagram_channel_factory.h', 'protocol/errors.h', 'protocol/host_control_dispatcher.cc', 'protocol/host_control_dispatcher.h', @@ -167,6 +168,10 @@ 'protocol/protobuf_video_reader.h', 'protocol/protobuf_video_writer.cc', 'protocol/protobuf_video_writer.h', + 'protocol/pseudotcp_channel_factory.cc', + 'protocol/pseudotcp_channel_factory.h', + 'protocol/secure_channel_factory.cc', + 'protocol/secure_channel_factory.h', 'protocol/session.h', 'protocol/session_config.cc', 'protocol/session_config.h', @@ -175,6 +180,7 @@ 'protocol/socket_util.h', 'protocol/ssl_hmac_channel_authenticator.cc', 'protocol/ssl_hmac_channel_authenticator.h', + 'protocol/stream_channel_factory.h', 'protocol/third_party_authenticator_base.cc', 'protocol/third_party_authenticator_base.h', 'protocol/third_party_client_authenticator.cc', |