diff options
42 files changed, 546 insertions, 642 deletions
diff --git a/remoting/client/chromoting_client.cc b/remoting/client/chromoting_client.cc index 4e5b53e..9a9506f 100644 --- a/remoting/client/chromoting_client.cc +++ b/remoting/client/chromoting_client.cc @@ -70,9 +70,7 @@ void ChromotingClient::Start( connection_->set_video_stub(video_renderer_->GetVideoStub()); connection_->set_audio_stub(audio_decode_scheduler_.get()); - session_manager_.reset(new protocol::JingleSessionManager( - make_scoped_ptr(new protocol::IceTransportFactory(transport_context)), - signal_strategy)); + session_manager_.reset(new protocol::JingleSessionManager(signal_strategy)); if (!protocol_config_) protocol_config_ = protocol::CandidateSessionConfig::CreateDefault(); @@ -81,6 +79,7 @@ void ChromotingClient::Start( session_manager_->set_protocol_config(std::move(protocol_config_)); authenticator_ = std::move(authenticator); + transport_context_ = transport_context; signal_strategy_ = signal_strategy; signal_strategy_->AddListener(this); @@ -202,7 +201,8 @@ bool ChromotingClient::OnSignalStrategyIncomingStanza( void ChromotingClient::StartConnection() { DCHECK(thread_checker_.CalledOnValidThread()); connection_->Connect( - session_manager_->Connect(host_jid_, std::move(authenticator_)), this); + session_manager_->Connect(host_jid_, std::move(authenticator_)), + transport_context_, this); } void ChromotingClient::OnAuthenticated() { diff --git a/remoting/client/chromoting_client.h b/remoting/client/chromoting_client.h index 1fe71fa..dffbc5a 100644 --- a/remoting/client/chromoting_client.h +++ b/remoting/client/chromoting_client.h @@ -123,6 +123,7 @@ class ChromotingClient : public SignalStrategy::Listener, std::string host_jid_; scoped_ptr<protocol::Authenticator> authenticator_; + scoped_refptr<protocol::TransportContext> transport_context_; scoped_ptr<protocol::SessionManager> session_manager_; scoped_ptr<protocol::ConnectionToHost> connection_; diff --git a/remoting/host/chromoting_host.cc b/remoting/host/chromoting_host.cc index 3fdfaaa..885f2e9 100644 --- a/remoting/host/chromoting_host.cc +++ b/remoting/host/chromoting_host.cc @@ -25,6 +25,7 @@ #include "remoting/protocol/host_stub.h" #include "remoting/protocol/ice_connection_to_client.h" #include "remoting/protocol/input_stub.h" +#include "remoting/protocol/transport_context.h" #include "remoting/protocol/webrtc_connection_to_client.h" using remoting::protocol::ConnectionToClient; @@ -65,6 +66,7 @@ const net::BackoffEntry::Policy kDefaultBackoffPolicy = { ChromotingHost::ChromotingHost( DesktopEnvironmentFactory* desktop_environment_factory, scoped_ptr<protocol::SessionManager> session_manager, + scoped_refptr<protocol::TransportContext> transport_context, scoped_refptr<base::SingleThreadTaskRunner> audio_task_runner, scoped_refptr<base::SingleThreadTaskRunner> input_task_runner, scoped_refptr<base::SingleThreadTaskRunner> video_capture_task_runner, @@ -73,6 +75,7 @@ ChromotingHost::ChromotingHost( scoped_refptr<base::SingleThreadTaskRunner> ui_task_runner) : desktop_environment_factory_(desktop_environment_factory), session_manager_(std::move(session_manager)), + transport_context_(transport_context), audio_task_runner_(audio_task_runner), input_task_runner_(input_task_runner), video_capture_task_runner_(video_capture_task_runner), @@ -273,11 +276,12 @@ void ChromotingHost::OnIncomingSession( scoped_ptr<protocol::ConnectionToClient> connection; if (session->config().protocol() == protocol::SessionConfig::Protocol::WEBRTC) { - connection.reset( - new protocol::WebrtcConnectionToClient(make_scoped_ptr(session))); + connection.reset(new protocol::WebrtcConnectionToClient( + make_scoped_ptr(session), transport_context_)); } else { connection.reset(new protocol::IceConnectionToClient( - make_scoped_ptr(session), video_encode_task_runner_)); + make_scoped_ptr(session), transport_context_, + video_encode_task_runner_)); } // Create a ClientSession object. diff --git a/remoting/host/chromoting_host.h b/remoting/host/chromoting_host.h index 6b23d4d..3b1a1fa 100644 --- a/remoting/host/chromoting_host.h +++ b/remoting/host/chromoting_host.h @@ -34,6 +34,7 @@ namespace remoting { namespace protocol { class InputStub; +class TransportContext; } // namespace protocol class DesktopEnvironmentFactory; @@ -69,6 +70,7 @@ class ChromotingHost : public base::NonThreadSafe, ChromotingHost( DesktopEnvironmentFactory* desktop_environment_factory, scoped_ptr<protocol::SessionManager> session_manager, + scoped_refptr<protocol::TransportContext> transport_context, scoped_refptr<base::SingleThreadTaskRunner> audio_task_runner, scoped_refptr<base::SingleThreadTaskRunner> input_task_runner, scoped_refptr<base::SingleThreadTaskRunner> video_capture_task_runner, @@ -156,6 +158,7 @@ class ChromotingHost : public base::NonThreadSafe, // Parameters specified when the host was created. DesktopEnvironmentFactory* desktop_environment_factory_; scoped_ptr<protocol::SessionManager> session_manager_; + scoped_refptr<protocol::TransportContext> transport_context_; scoped_refptr<base::SingleThreadTaskRunner> audio_task_runner_; scoped_refptr<base::SingleThreadTaskRunner> input_task_runner_; scoped_refptr<base::SingleThreadTaskRunner> video_capture_task_runner_; diff --git a/remoting/host/chromoting_host_unittest.cc b/remoting/host/chromoting_host_unittest.cc index fc77198..f033eee 100644 --- a/remoting/host/chromoting_host_unittest.cc +++ b/remoting/host/chromoting_host_unittest.cc @@ -21,6 +21,7 @@ #include "remoting/protocol/fake_desktop_capturer.h" #include "remoting/protocol/protocol_mock_objects.h" #include "remoting/protocol/session_config.h" +#include "remoting/protocol/transport_context.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gmock_mutant.h" #include "testing/gtest/include/gtest/gtest.h" @@ -53,8 +54,7 @@ namespace remoting { class ChromotingHostTest : public testing::Test { public: - ChromotingHostTest() { - } + ChromotingHostTest() {} void SetUp() override { task_runner_ = new AutoThreadTaskRunner(message_loop_.task_runner(), @@ -63,14 +63,15 @@ class ChromotingHostTest : public testing::Test { desktop_environment_factory_.reset(new FakeDesktopEnvironmentFactory()); session_manager_ = new protocol::MockSessionManager(); - host_.reset(new ChromotingHost(desktop_environment_factory_.get(), - make_scoped_ptr(session_manager_), - task_runner_, // Audio - task_runner_, // Input - task_runner_, // Video capture - task_runner_, // Video encode - task_runner_, // Network - task_runner_)); // UI + host_.reset(new ChromotingHost( + desktop_environment_factory_.get(), make_scoped_ptr(session_manager_), + protocol::TransportContext::ForTests(protocol::TransportRole::SERVER), + task_runner_, // Audio + task_runner_, // Input + task_runner_, // Video capture + task_runner_, // Video encode + task_runner_, // Network + task_runner_)); // UI host_->AddStatusObserver(&host_status_observer_); xmpp_login_ = "host@domain"; diff --git a/remoting/host/it2me/it2me_host.cc b/remoting/host/it2me/it2me_host.cc index 426c06b..c7733fe 100644 --- a/remoting/host/it2me/it2me_host.cc +++ b/remoting/host/it2me/it2me_host.cc @@ -235,16 +235,15 @@ void It2MeHost::FinishConnect() { protocol::NetworkSettings::kDefaultMaxPort; } - scoped_ptr<protocol::TransportFactory> transport_factory( - new protocol::IceTransportFactory(new protocol::TransportContext( + scoped_refptr<protocol::TransportContext> transport_context = + new protocol::TransportContext( signal_strategy_.get(), make_scoped_ptr(new protocol::ChromiumPortAllocatorFactory( host_context_->url_request_context_getter())), - network_settings, protocol::TransportRole::SERVER))); + network_settings, protocol::TransportRole::SERVER); scoped_ptr<protocol::SessionManager> session_manager( - new protocol::JingleSessionManager(std::move(transport_factory), - signal_strategy.get())); + new protocol::JingleSessionManager(signal_strategy.get())); scoped_ptr<protocol::CandidateSessionConfig> protocol_config = protocol::CandidateSessionConfig::CreateDefault(); @@ -256,7 +255,8 @@ void It2MeHost::FinishConnect() { // Create the host. host_.reset(new ChromotingHost( desktop_environment_factory_.get(), std::move(session_manager), - host_context_->audio_task_runner(), host_context_->input_task_runner(), + transport_context, host_context_->audio_task_runner(), + host_context_->input_task_runner(), host_context_->video_capture_task_runner(), host_context_->video_encode_task_runner(), host_context_->network_task_runner(), host_context_->ui_task_runner())); diff --git a/remoting/host/remoting_me2me_host.cc b/remoting/host/remoting_me2me_host.cc index 04be3d3..6628902 100644 --- a/remoting/host/remoting_me2me_host.cc +++ b/remoting/host/remoting_me2me_host.cc @@ -80,7 +80,6 @@ #include "remoting/protocol/authenticator.h" #include "remoting/protocol/channel_authenticator.h" #include "remoting/protocol/chromium_port_allocator.h" -#include "remoting/protocol/ice_transport.h" #include "remoting/protocol/jingle_session_manager.h" #include "remoting/protocol/me2me_host_authenticator_factory.h" #include "remoting/protocol/network_settings.h" @@ -88,7 +87,6 @@ #include "remoting/protocol/port_range.h" #include "remoting/protocol/token_validator.h" #include "remoting/protocol/transport_context.h" -#include "remoting/protocol/webrtc_transport.h" #include "remoting/signaling/push_notification_subscriber.h" #include "remoting/signaling/xmpp_signal_strategy.h" #include "third_party/webrtc/base/scoped_ref_ptr.h" @@ -885,6 +883,14 @@ void HostProcess::StartOnUiThread() { remoting::GnubbyAuthHandler::SetGnubbySocketName(gnubby_socket_name); #endif // defined(OS_LINUX) +#if defined(NDEBUG) + if (base::CommandLine::ForCurrentProcess()->HasSwitch(kEnableWebrtc)) { + LOG(ERROR) << "WebRTC is enabled only in debug builds."; + ShutdownHost(kUsageExitCode); + return; + } +#endif // defined(NDEBUG) + // Create a desktop environment factory appropriate to the build type & // platform. #if defined(OS_WIN) @@ -1507,29 +1513,9 @@ void HostProcess::StartHost() { make_scoped_ptr(new protocol::ChromiumPortAllocatorFactory( context_->url_request_context_getter())), network_settings, protocol::TransportRole::SERVER); - scoped_ptr<protocol::TransportFactory> transport_factory; - if (base::CommandLine::ForCurrentProcess()->HasSwitch(kEnableWebrtc)) { -#if !defined(NDEBUG) - jingle_glue::JingleThreadWrapper::EnsureForCurrentMessageLoop(); - // The network thread is also used as worker thread for webrtc. - // - // TODO(sergeyu): Figure out if we would benefit from using a separate - // thread as a worker thread. - transport_factory.reset(new protocol::WebrtcTransportFactory( - jingle_glue::JingleThreadWrapper::current(), transport_context)); -#else // !defined(NDEBUG) - LOG(ERROR) << "WebRTC is enabled only in debug builds."; - ShutdownHost(kUsageExitCode); - return; -#endif // defined(NDEBUG) - } else { - transport_factory.reset( - new protocol::IceTransportFactory(transport_context)); - } scoped_ptr<protocol::SessionManager> session_manager( - new protocol::JingleSessionManager(std::move(transport_factory), - signal_strategy_.get())); + new protocol::JingleSessionManager(signal_strategy_.get())); scoped_ptr<protocol::CandidateSessionConfig> protocol_config = protocol::CandidateSessionConfig::CreateDefault(); @@ -1546,8 +1532,8 @@ void HostProcess::StartHost() { host_.reset(new ChromotingHost( desktop_environment_factory_.get(), std::move(session_manager), - context_->audio_task_runner(), context_->input_task_runner(), - context_->video_capture_task_runner(), + transport_context, context_->audio_task_runner(), + context_->input_task_runner(), context_->video_capture_task_runner(), context_->video_encode_task_runner(), context_->network_task_runner(), context_->ui_task_runner())); diff --git a/remoting/protocol/connection_to_host.h b/remoting/protocol/connection_to_host.h index b41df20..d92905b 100644 --- a/remoting/protocol/connection_to_host.h +++ b/remoting/protocol/connection_to_host.h @@ -7,6 +7,7 @@ #include <string> +#include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "remoting/protocol/errors.h" @@ -21,6 +22,7 @@ class HostStub; class InputStub; class Session; class SessionConfig; +class TransportContext; struct TransportRoute; class VideoStub; @@ -76,6 +78,7 @@ class ConnectionToHost { // of changes in the state of the connection and must outlive the // ConnectionToHost. Caller must set stubs (see below) before calling Connect. virtual void Connect(scoped_ptr<Session> session, + scoped_refptr<TransportContext> transport_context, HostEventCallback* event_callback) = 0; // Returns the session configuration that was negotiated with the host. diff --git a/remoting/protocol/connection_unittest.cc b/remoting/protocol/connection_unittest.cc index 5a496c1..12e94f8 100644 --- a/remoting/protocol/connection_unittest.cc +++ b/remoting/protocol/connection_unittest.cc @@ -13,6 +13,7 @@ #include "remoting/protocol/ice_connection_to_client.h" #include "remoting/protocol/ice_connection_to_host.h" #include "remoting/protocol/protocol_mock_objects.h" +#include "remoting/protocol/transport_context.h" #include "remoting/protocol/webrtc_connection_to_client.h" #include "remoting/protocol/webrtc_connection_to_host.h" #include "testing/gmock/include/gmock/gmock.h" @@ -67,13 +68,16 @@ class ConnectionTest : public testing::Test, // Create Connection objects if (GetParam()) { - host_connection_.reset( - new WebrtcConnectionToClient(make_scoped_ptr(host_session_))); + host_connection_.reset(new WebrtcConnectionToClient( + make_scoped_ptr(host_session_), + TransportContext::ForTests(protocol::TransportRole::SERVER))); client_connection_.reset(new WebrtcConnectionToHost()); } else { host_connection_.reset(new IceConnectionToClient( - make_scoped_ptr(host_session_), message_loop_.task_runner())); + make_scoped_ptr(host_session_), + TransportContext::ForTests(protocol::TransportRole::SERVER), + message_loop_.task_runner())); client_connection_.reset(new IceConnectionToHost()); } @@ -98,7 +102,11 @@ class ConnectionTest : public testing::Test, OnConnectionAuthenticated(host_connection_.get())); } EXPECT_CALL(host_event_handler_, - OnConnectionChannelsConnected(host_connection_.get())); + OnConnectionChannelsConnected(host_connection_.get())) + .WillOnce( + InvokeWithoutArgs(this, &ConnectionTest::OnHostConnected)); + EXPECT_CALL(host_event_handler_, OnRouteChange(_, _, _)) + .Times(testing::AnyNumber()); { testing::InSequence sequence; @@ -107,13 +115,24 @@ class ConnectionTest : public testing::Test, EXPECT_CALL(client_event_handler_, OnConnectionState(ConnectionToHost::AUTHENTICATED, OK)); EXPECT_CALL(client_event_handler_, - OnConnectionState(ConnectionToHost::CONNECTED, OK)); + OnConnectionState(ConnectionToHost::CONNECTED, OK)) + .WillOnce(InvokeWithoutArgs( + this, &ConnectionTest::OnClientConnected)); } + EXPECT_CALL(client_event_handler_, OnRouteChanged(_, _)) + .Times(testing::AnyNumber()); - client_connection_->Connect(std::move(owned_client_session_), - &client_event_handler_); + client_connection_->Connect( + std::move(owned_client_session_), + TransportContext::ForTests(protocol::TransportRole::CLIENT), + &client_event_handler_); client_session_->SimulateConnection(host_session_); - base::RunLoop().RunUntilIdle(); + + run_loop_.reset(new base::RunLoop()); + run_loop_->Run(); + + EXPECT_TRUE(client_connected_); + EXPECT_TRUE(host_connected_); } void TearDown() override { @@ -122,7 +141,20 @@ class ConnectionTest : public testing::Test, base::RunLoop().RunUntilIdle(); } - base::MessageLoop message_loop_; + void OnHostConnected() { + host_connected_ = true; + if (client_connected_ && run_loop_) + run_loop_->Quit(); + } + + void OnClientConnected() { + client_connected_ = true; + if (host_connected_ && run_loop_) + run_loop_->Quit(); + } + + base::MessageLoopForIO message_loop_; + scoped_ptr<base::RunLoop> run_loop_; MockConnectionToClientEventHandler host_event_handler_; MockClipboardStub host_clipboard_stub_; @@ -130,6 +162,7 @@ class ConnectionTest : public testing::Test, MockInputStub host_input_stub_; scoped_ptr<ConnectionToClient> host_connection_; FakeSession* host_session_; // Owned by |host_connection_|. + bool host_connected_ = false; MockConnectionToHostEventCallback client_event_handler_; MockClientStub client_stub_; @@ -138,6 +171,7 @@ class ConnectionTest : public testing::Test, scoped_ptr<ConnectionToHost> client_connection_; FakeSession* client_session_; // Owned by |client_connection_|. scoped_ptr<FakeSession> owned_client_session_; + bool client_connected_ = false; private: DISALLOW_COPY_AND_ASSIGN(ConnectionTest); @@ -152,8 +186,10 @@ TEST_P(ConnectionTest, RejectConnection) { EXPECT_CALL(client_event_handler_, OnConnectionState(ConnectionToHost::CLOSED, OK)); - client_connection_->Connect(std::move(owned_client_session_), - &client_event_handler_); + client_connection_->Connect( + std::move(owned_client_session_), + TransportContext::ForTests(protocol::TransportRole::CLIENT), + &client_event_handler_); client_session_->event_handler()->OnSessionStateChange(Session::CLOSED); } diff --git a/remoting/protocol/fake_authenticator.cc b/remoting/protocol/fake_authenticator.cc index 2d25898..416d391 100644 --- a/remoting/protocol/fake_authenticator.cc +++ b/remoting/protocol/fake_authenticator.cc @@ -37,9 +37,9 @@ void FakeChannelAuthenticator::SecureAndAuthenticate( const DoneCallback& done_callback) { socket_ = std::move(socket); - if (async_) { - done_callback_ = done_callback; + done_callback_ = done_callback; + if (async_) { if (result_ != net::OK) { // Don't write anything if we are going to reject auth to make test // ordering deterministic. diff --git a/remoting/protocol/fake_connection_to_host.cc b/remoting/protocol/fake_connection_to_host.cc index f818730..cbec769 100644 --- a/remoting/protocol/fake_connection_to_host.cc +++ b/remoting/protocol/fake_connection_to_host.cc @@ -23,8 +23,10 @@ void FakeConnectionToHost::set_video_stub(protocol::VideoStub* video_stub) {} void FakeConnectionToHost::set_audio_stub(protocol::AudioStub* audio_stub) {} -void FakeConnectionToHost::Connect(scoped_ptr<protocol::Session> session, - HostEventCallback* event_callback) { +void FakeConnectionToHost::Connect( + scoped_ptr<protocol::Session> session, + scoped_refptr<protocol::TransportContext> transport_context, + HostEventCallback* event_callback) { DCHECK(event_callback); event_callback_ = event_callback; @@ -41,7 +43,6 @@ void FakeConnectionToHost::SignalStateChange(protocol::Session::State state, case protocol::Session::CONNECTING: case protocol::Session::ACCEPTING: case protocol::Session::AUTHENTICATING: - case protocol::Session::CONNECTED: // No updates for these events. break; diff --git a/remoting/protocol/fake_connection_to_host.h b/remoting/protocol/fake_connection_to_host.h index 668196e..62f93b1 100644 --- a/remoting/protocol/fake_connection_to_host.h +++ b/remoting/protocol/fake_connection_to_host.h @@ -25,6 +25,7 @@ class FakeConnectionToHost : public protocol::ConnectionToHost { void set_video_stub(protocol::VideoStub* video_stub) override; void set_audio_stub(protocol::AudioStub* audio_stub) override; void Connect(scoped_ptr<protocol::Session> session, + scoped_refptr<protocol::TransportContext> transport_context, HostEventCallback* event_callback) override; const protocol::SessionConfig& config() override; protocol::ClipboardStub* clipboard_forwarder() override; diff --git a/remoting/protocol/fake_session.cc b/remoting/protocol/fake_session.cc index 46edfe9..616d12d 100644 --- a/remoting/protocol/fake_session.cc +++ b/remoting/protocol/fake_session.cc @@ -4,57 +4,46 @@ #include "remoting/protocol/fake_session.h" +#include "base/location.h" +#include "base/thread_task_runner_handle.h" +#include "remoting/protocol/fake_authenticator.h" +#include "third_party/webrtc/libjingle/xmllite/xmlelement.h" + namespace remoting { namespace protocol { const char kTestJid[] = "host1@gmail.com/chromoting123"; -FakeTransport::FakeTransport() {} -FakeTransport::~FakeTransport() {} - -void FakeTransport::Start(EventHandler* event_handler, - Authenticator* authenticator) { - NOTREACHED(); -} - -bool FakeTransport::ProcessTransportInfo( - buzz::XmlElement* transport_info) { - NOTREACHED(); - return true; -} - -FakeStreamChannelFactory* FakeTransport::GetStreamChannelFactory() { - return &channel_factory_; -} - -FakeStreamChannelFactory* FakeTransport::GetMultiplexedChannelFactory() { - return &channel_factory_; -} - FakeSession::FakeSession() : config_(SessionConfig::ForTest()), jid_(kTestJid), weak_factory_(this) {} - FakeSession::~FakeSession() {} void FakeSession::SimulateConnection(FakeSession* peer) { peer_ = peer->weak_factory_.GetWeakPtr(); peer->peer_ = weak_factory_.GetWeakPtr(); - transport_.GetStreamChannelFactory()->PairWith( - peer->transport_.GetStreamChannelFactory()); - transport_.GetMultiplexedChannelFactory()->PairWith( - peer->transport_.GetMultiplexedChannelFactory()); - event_handler_->OnSessionStateChange(CONNECTING); peer->event_handler_->OnSessionStateChange(ACCEPTING); peer->event_handler_->OnSessionStateChange(ACCEPTED); event_handler_->OnSessionStateChange(ACCEPTED); event_handler_->OnSessionStateChange(AUTHENTICATING); peer->event_handler_->OnSessionStateChange(AUTHENTICATING); - event_handler_->OnSessionStateChange(AUTHENTICATED); + + // Initialize transport and authenticator on the client. + authenticator_.reset(new FakeAuthenticator(FakeAuthenticator::CLIENT, 0, + FakeAuthenticator::ACCEPT, false)); + transport_->Start(authenticator_.get(), + base::Bind(&FakeSession::SendTransportInfo, + weak_factory_.GetWeakPtr())); + + // Initialize transport and authenticator on the host. + peer->authenticator_.reset(new FakeAuthenticator( + FakeAuthenticator::HOST, 0, FakeAuthenticator::ACCEPT, false)); + peer->transport_->Start(peer->authenticator_.get(), + base::Bind(&FakeSession::SendTransportInfo, peer_)); + peer->event_handler_->OnSessionStateChange(AUTHENTICATED); - event_handler_->OnSessionStateChange(CONNECTED); - peer->event_handler_->OnSessionStateChange(CONNECTED); + event_handler_->OnSessionStateChange(AUTHENTICATED); } void FakeSession::SetEventHandler(EventHandler* event_handler) { @@ -73,8 +62,8 @@ const SessionConfig& FakeSession::config() { return *config_; } -FakeTransport* FakeSession::GetTransport() { - return &transport_; +void FakeSession::SetTransport(Transport* transport) { + transport_ = transport; } void FakeSession::Close(ErrorCode error) { @@ -90,5 +79,12 @@ void FakeSession::Close(ErrorCode error) { } } +void FakeSession::SendTransportInfo( + scoped_ptr<buzz::XmlElement> transport_info) { + if (!peer_) + return; + peer_->transport_->ProcessTransportInfo(transport_info.get()); +} + } // namespace protocol } // namespace remoting diff --git a/remoting/protocol/fake_session.h b/remoting/protocol/fake_session.h index ef0aea4..5717a97 100644 --- a/remoting/protocol/fake_session.h +++ b/remoting/protocol/fake_session.h @@ -21,24 +21,8 @@ namespace protocol { extern const char kTestJid[]; -class FakeTransport : public Transport { - public: - FakeTransport(); - ~FakeTransport() override; - - // Transport interface. - void Start(EventHandler* event_handler, - Authenticator* authenticator) override; - bool ProcessTransportInfo(buzz::XmlElement* transport_info) override; - FakeStreamChannelFactory* GetStreamChannelFactory() override; - FakeStreamChannelFactory* GetMultiplexedChannelFactory() override; - - private: - FakeStreamChannelFactory channel_factory_; -}; +class FakeAuthenticator; -// FakeSession is a dummy protocol::Session that uses FakeStreamSocket for all -// channels. class FakeSession : public Session { public: FakeSession(); @@ -55,16 +39,20 @@ class FakeSession : public Session { ErrorCode error() override; const std::string& jid() override; const SessionConfig& config() override; - FakeTransport* GetTransport() override; + void SetTransport(Transport* transport) override; void Close(ErrorCode error) override; private: + // Callback provided to the |transport_|. + void SendTransportInfo(scoped_ptr<buzz::XmlElement> transport_info); + EventHandler* event_handler_ = nullptr; scoped_ptr<SessionConfig> config_; std::string jid_; - FakeTransport transport_; + scoped_ptr<FakeAuthenticator> authenticator_; + Transport* transport_; ErrorCode error_ = OK; bool closed_ = false; diff --git a/remoting/protocol/ice_connection_to_client.cc b/remoting/protocol/ice_connection_to_client.cc index 41130a6..a5bae2d 100644 --- a/remoting/protocol/ice_connection_to_client.cc +++ b/remoting/protocol/ice_connection_to_client.cc @@ -19,6 +19,7 @@ #include "remoting/protocol/host_stub.h" #include "remoting/protocol/host_video_dispatcher.h" #include "remoting/protocol/input_stub.h" +#include "remoting/protocol/transport_context.h" #include "remoting/protocol/video_frame_pump.h" namespace remoting { @@ -46,14 +47,17 @@ scoped_ptr<VideoEncoder> CreateVideoEncoder( IceConnectionToClient::IceConnectionToClient( scoped_ptr<protocol::Session> session, + scoped_refptr<TransportContext> transport_context, scoped_refptr<base::SingleThreadTaskRunner> video_encode_task_runner) : event_handler_(nullptr), session_(std::move(session)), video_encode_task_runner_(video_encode_task_runner), + transport_(transport_context, this), control_dispatcher_(new HostControlDispatcher()), event_dispatcher_(new HostEventDispatcher()), video_dispatcher_(new HostVideoDispatcher()) { session_->SetEventHandler(this); + session_->SetTransport(&transport_); } IceConnectionToClient::~IceConnectionToClient() {} @@ -72,8 +76,6 @@ protocol::Session* IceConnectionToClient::session() { void IceConnectionToClient::Disconnect(ErrorCode error) { DCHECK(thread_checker_.CalledOnValidThread()); - CloseChannels(); - // This should trigger OnConnectionClosed() event and this object // may be destroyed as the result. session_->Close(error); @@ -136,7 +138,6 @@ void IceConnectionToClient::OnSessionStateChange(Session::State state) { case Session::CONNECTING: case Session::ACCEPTING: case Session::ACCEPTED: - case Session::CONNECTED: // Don't care about these events. break; case Session::AUTHENTICATING: @@ -144,23 +145,19 @@ void IceConnectionToClient::OnSessionStateChange(Session::State state) { break; case Session::AUTHENTICATED: // Initialize channels. - control_dispatcher_->Init( - session_->GetTransport()->GetMultiplexedChannelFactory(), this); + control_dispatcher_->Init(transport_.GetMultiplexedChannelFactory(), + this); - event_dispatcher_->Init( - session_->GetTransport()->GetMultiplexedChannelFactory(), this); + event_dispatcher_->Init(transport_.GetMultiplexedChannelFactory(), this); event_dispatcher_->set_on_input_event_callback( base::Bind(&IceConnectionToClient::OnInputEventReceived, base::Unretained(this))); - video_dispatcher_->Init( - session_->GetTransport()->GetStreamChannelFactory(), this); + video_dispatcher_->Init(transport_.GetStreamChannelFactory(), this); audio_writer_ = AudioWriter::Create(session_->config()); - if (audio_writer_.get()) { - audio_writer_->Init( - session_->GetTransport()->GetMultiplexedChannelFactory(), this); - } + if (audio_writer_) + audio_writer_->Init(transport_.GetMultiplexedChannelFactory(), this); // Notify the handler after initializing the channels, so that // ClientSession can get a client clipboard stub. @@ -168,21 +165,27 @@ void IceConnectionToClient::OnSessionStateChange(Session::State state) { break; case Session::CLOSED: - Close(OK); - break; - case Session::FAILED: - Close(session_->error()); + CloseChannels(); + event_handler_->OnConnectionClosed( + this, state == Session::FAILED ? session_->error() : OK); break; } } -void IceConnectionToClient::OnSessionRouteChange( + +void IceConnectionToClient::OnIceTransportRouteChange( const std::string& channel_name, const TransportRoute& route) { event_handler_->OnRouteChange(this, channel_name, route); } +void IceConnectionToClient::OnIceTransportError(ErrorCode error) { + DCHECK(thread_checker_.CalledOnValidThread()); + + Disconnect(error); +} + void IceConnectionToClient::OnChannelInitialized( ChannelDispatcherBase* channel_dispatcher) { DCHECK(thread_checker_.CalledOnValidThread()); @@ -197,7 +200,7 @@ void IceConnectionToClient::OnChannelError( LOG(ERROR) << "Failed to connect channel " << channel_dispatcher->channel_name(); - Close(CHANNEL_CONNECTION_ERROR); + Disconnect(error); } void IceConnectionToClient::NotifyIfChannelsReady() { @@ -216,11 +219,6 @@ void IceConnectionToClient::NotifyIfChannelsReady() { event_handler_->OnConnectionChannelsConnected(this); } -void IceConnectionToClient::Close(ErrorCode error) { - CloseChannels(); - event_handler_->OnConnectionClosed(this, error); -} - void IceConnectionToClient::CloseChannels() { control_dispatcher_.reset(); event_dispatcher_.reset(); diff --git a/remoting/protocol/ice_connection_to_client.h b/remoting/protocol/ice_connection_to_client.h index cf03496..1ea8677 100644 --- a/remoting/protocol/ice_connection_to_client.h +++ b/remoting/protocol/ice_connection_to_client.h @@ -14,6 +14,7 @@ #include "base/threading/thread_checker.h" #include "remoting/protocol/channel_dispatcher_base.h" #include "remoting/protocol/connection_to_client.h" +#include "remoting/protocol/ice_transport.h" #include "remoting/protocol/session.h" namespace remoting { @@ -31,10 +32,12 @@ class VideoFramePump; // stubs. class IceConnectionToClient : public ConnectionToClient, public Session::EventHandler, + public IceTransport::EventHandler, public ChannelDispatcherBase::EventHandler { public: IceConnectionToClient( scoped_ptr<Session> session, + scoped_refptr<TransportContext> transport_context, scoped_refptr<base::SingleThreadTaskRunner> video_encode_task_runner); ~IceConnectionToClient() override; @@ -54,8 +57,11 @@ class IceConnectionToClient : public ConnectionToClient, // Session::EventHandler interface. void OnSessionStateChange(Session::State state) override; - void OnSessionRouteChange(const std::string& channel_name, - const TransportRoute& route) override; + + // IceTransport::EventHandler interface. + void OnIceTransportRouteChange(const std::string& channel_name, + const TransportRoute& route) override; + void OnIceTransportError(ErrorCode error) override; // ChannelDispatcherBase::EventHandler interface. void OnChannelInitialized(ChannelDispatcherBase* channel_dispatcher) override; @@ -65,9 +71,6 @@ class IceConnectionToClient : public ConnectionToClient, private: void NotifyIfChannelsReady(); - void Close(ErrorCode error); - - // Stops writing in the channels. void CloseChannels(); base::ThreadChecker thread_checker_; @@ -79,6 +82,8 @@ class IceConnectionToClient : public ConnectionToClient, scoped_refptr<base::SingleThreadTaskRunner> video_encode_task_runner_; + IceTransport transport_; + scoped_ptr<HostControlDispatcher> control_dispatcher_; scoped_ptr<HostEventDispatcher> event_dispatcher_; scoped_ptr<HostVideoDispatcher> video_dispatcher_; diff --git a/remoting/protocol/ice_connection_to_host.cc b/remoting/protocol/ice_connection_to_host.cc index 39694af..ddb77c7 100644 --- a/remoting/protocol/ice_connection_to_host.cc +++ b/remoting/protocol/ice_connection_to_host.cc @@ -20,7 +20,7 @@ #include "remoting/protocol/clipboard_stub.h" #include "remoting/protocol/errors.h" #include "remoting/protocol/ice_transport.h" -#include "remoting/protocol/transport.h" +#include "remoting/protocol/transport_context.h" #include "remoting/protocol/video_stub.h" namespace remoting { @@ -29,14 +29,19 @@ namespace protocol { IceConnectionToHost::IceConnectionToHost() {} IceConnectionToHost::~IceConnectionToHost() {} -void IceConnectionToHost::Connect(scoped_ptr<Session> session, - HostEventCallback* event_callback) { +void IceConnectionToHost::Connect( + scoped_ptr<Session> session, + scoped_refptr<TransportContext> transport_context, + HostEventCallback* event_callback) { DCHECK(client_stub_); DCHECK(clipboard_stub_); DCHECK(monitored_video_stub_); + transport_.reset(new IceTransport(transport_context, this)); + session_ = std::move(session); session_->SetEventHandler(this); + session_->SetTransport(transport_.get()); event_callback_ = event_callback; @@ -91,32 +96,27 @@ void IceConnectionToHost::OnSessionStateChange(Session::State state) { case Session::ACCEPTING: case Session::ACCEPTED: case Session::AUTHENTICATING: - case Session::CONNECTED: // Don't care about these events. break; case Session::AUTHENTICATED: SetState(AUTHENTICATED, OK); - control_dispatcher_.reset(new ClientControlDispatcher()); - control_dispatcher_->Init( - session_->GetTransport()->GetMultiplexedChannelFactory(), this); + control_dispatcher_->Init(transport_->GetMultiplexedChannelFactory(), + this); control_dispatcher_->set_client_stub(client_stub_); control_dispatcher_->set_clipboard_stub(clipboard_stub_); event_dispatcher_.reset(new ClientEventDispatcher()); - event_dispatcher_->Init( - session_->GetTransport()->GetMultiplexedChannelFactory(), this); + event_dispatcher_->Init(transport_->GetMultiplexedChannelFactory(), this); video_dispatcher_.reset( new ClientVideoDispatcher(monitored_video_stub_.get())); - video_dispatcher_->Init( - session_->GetTransport()->GetStreamChannelFactory(), this); + video_dispatcher_->Init(transport_->GetStreamChannelFactory(), this); if (session_->config().is_audio_enabled()) { audio_reader_.reset(new AudioReader(audio_stub_)); - audio_reader_->Init( - session_->GetTransport()->GetMultiplexedChannelFactory(), this); + audio_reader_->Init(transport_->GetMultiplexedChannelFactory(), this); } break; @@ -144,11 +144,16 @@ void IceConnectionToHost::OnSessionStateChange(Session::State state) { } } -void IceConnectionToHost::OnSessionRouteChange(const std::string& channel_name, - const TransportRoute& route) { +void IceConnectionToHost::OnIceTransportRouteChange( + const std::string& channel_name, + const TransportRoute& route) { event_callback_->OnRouteChanged(channel_name, route); } +void IceConnectionToHost::OnIceTransportError(ErrorCode error) { + session_->Close(error); +} + void IceConnectionToHost::OnChannelInitialized( ChannelDispatcherBase* channel_dispatcher) { NotifyIfChannelsReady(); @@ -157,9 +162,9 @@ void IceConnectionToHost::OnChannelInitialized( void IceConnectionToHost::OnChannelError( ChannelDispatcherBase* channel_dispatcher, ErrorCode error) { - LOG(ERROR) << "Failed to connect channel " << channel_dispatcher; + LOG(ERROR) << "Failed to connect channel " + << channel_dispatcher->channel_name(); CloseOnError(CHANNEL_CONNECTION_ERROR); - return; } void IceConnectionToHost::OnVideoChannelStatus(bool active) { diff --git a/remoting/protocol/ice_connection_to_host.h b/remoting/protocol/ice_connection_to_host.h index 2ebc584..6cc8e15 100644 --- a/remoting/protocol/ice_connection_to_host.h +++ b/remoting/protocol/ice_connection_to_host.h @@ -18,6 +18,7 @@ #include "remoting/protocol/clipboard_filter.h" #include "remoting/protocol/connection_to_host.h" #include "remoting/protocol/errors.h" +#include "remoting/protocol/ice_transport.h" #include "remoting/protocol/input_filter.h" #include "remoting/protocol/message_reader.h" #include "remoting/protocol/monitored_video_stub.h" @@ -34,6 +35,7 @@ class ClientVideoDispatcher; class IceConnectionToHost : public ConnectionToHost, public Session::EventHandler, + public IceTransport::EventHandler, public ChannelDispatcherBase::EventHandler, public base::NonThreadSafe { public: @@ -46,6 +48,7 @@ class IceConnectionToHost : public ConnectionToHost, void set_video_stub(VideoStub* video_stub) override; void set_audio_stub(AudioStub* audio_stub) override; void Connect(scoped_ptr<Session> session, + scoped_refptr<TransportContext> transport_context, HostEventCallback* event_callback) override; const SessionConfig& config() override; ClipboardStub* clipboard_forwarder() override; @@ -56,8 +59,11 @@ class IceConnectionToHost : public ConnectionToHost, private: // Session::EventHandler interface. void OnSessionStateChange(Session::State state) override; - void OnSessionRouteChange(const std::string& channel_name, - const TransportRoute& route) override; + + // IceTransport::EventHandler interface. + void OnIceTransportRouteChange(const std::string& channel_name, + const TransportRoute& route) override; + void OnIceTransportError(ErrorCode error) override; // ChannelDispatcherBase::EventHandler interface. void OnChannelInitialized(ChannelDispatcherBase* channel_dispatcher) override; @@ -71,7 +77,7 @@ class IceConnectionToHost : public ConnectionToHost, void CloseOnError(ErrorCode error); - // Stops writing in the channels. + // Closes the P2P connection. void CloseChannels(); void SetState(State state, ErrorCode error); @@ -84,8 +90,9 @@ class IceConnectionToHost : public ConnectionToHost, AudioStub* audio_stub_ = nullptr; scoped_ptr<Session> session_; - scoped_ptr<MonitoredVideoStub> monitored_video_stub_; + scoped_ptr<IceTransport> transport_; + scoped_ptr<MonitoredVideoStub> monitored_video_stub_; scoped_ptr<ClientVideoDispatcher> video_dispatcher_; scoped_ptr<AudioReader> audio_reader_; scoped_ptr<ClientControlDispatcher> control_dispatcher_; diff --git a/remoting/protocol/ice_transport.cc b/remoting/protocol/ice_transport.cc index 08e95d3..858500e 100644 --- a/remoting/protocol/ice_transport.cc +++ b/remoting/protocol/ice_transport.cc @@ -23,8 +23,11 @@ const int kTransportInfoSendDelayMs = 20; // Name of the multiplexed channel. static const char kMuxChannelName[] = "mux"; -IceTransport::IceTransport(scoped_refptr<TransportContext> transport_context) - : transport_context_(transport_context), weak_factory_(this) { +IceTransport::IceTransport(scoped_refptr<TransportContext> transport_context, + EventHandler* event_handler) + : transport_context_(transport_context), + event_handler_(event_handler), + weak_factory_(this) { transport_context->Prepare(); } @@ -33,12 +36,12 @@ IceTransport::~IceTransport() { DCHECK(channels_.empty()); } -void IceTransport::Start(Transport::EventHandler* event_handler, - Authenticator* authenticator) { - DCHECK(event_handler); - DCHECK(!event_handler_); +void IceTransport::Start( + Authenticator* authenticator, + SendTransportInfoCallback send_transport_info_callback) { + DCHECK(!pseudotcp_channel_factory_); - event_handler_ = event_handler; + send_transport_info_callback_ = std::move(send_transport_info_callback); pseudotcp_channel_factory_.reset(new PseudoTcpChannelFactory(this)); secure_channel_factory_.reset(new SecureChannelFactory( pseudotcp_channel_factory_.get(), authenticator)); @@ -132,32 +135,32 @@ void IceTransport::AddPendingRemoteTransportInfo(IceTransportChannel* channel) { } } -void IceTransport::OnTransportIceCredentials(IceTransportChannel* channel, - const std::string& ufrag, - const std::string& password) { +void IceTransport::OnChannelIceCredentials(IceTransportChannel* channel, + const std::string& ufrag, + const std::string& password) { EnsurePendingTransportInfoMessage(); pending_transport_info_message_->ice_credentials.push_back( IceTransportInfo::IceCredentials(channel->name(), ufrag, password)); } -void IceTransport::OnTransportCandidate(IceTransportChannel* channel, - const cricket::Candidate& candidate) { +void IceTransport::OnChannelCandidate(IceTransportChannel* channel, + const cricket::Candidate& candidate) { EnsurePendingTransportInfoMessage(); pending_transport_info_message_->candidates.push_back( IceTransportInfo::NamedCandidate(channel->name(), candidate)); } -void IceTransport::OnTransportRouteChange(IceTransportChannel* channel, - const TransportRoute& route) { +void IceTransport::OnChannelRouteChange(IceTransportChannel* channel, + const TransportRoute& route) { if (event_handler_) - event_handler_->OnTransportRouteChange(channel->name(), route); + event_handler_->OnIceTransportRouteChange(channel->name(), route); } -void IceTransport::OnTransportFailed(IceTransportChannel* channel) { - event_handler_->OnTransportError(CHANNEL_CONNECTION_ERROR); +void IceTransport::OnChannelFailed(IceTransportChannel* channel) { + event_handler_->OnIceTransportError(CHANNEL_CONNECTION_ERROR); } -void IceTransport::OnTransportDeleted(IceTransportChannel* channel) { +void IceTransport::OnChannelDeleted(IceTransportChannel* channel) { ChannelsMap::iterator it = channels_.find(channel->name()); DCHECK_EQ(it->second, channel); channels_.erase(it); @@ -181,19 +184,11 @@ void IceTransport::EnsurePendingTransportInfoMessage() { void IceTransport::SendTransportInfo() { DCHECK(pending_transport_info_message_); - event_handler_->OnOutgoingTransportInfo( - pending_transport_info_message_->ToXml()); - pending_transport_info_message_.reset(); -} - -IceTransportFactory::IceTransportFactory( - scoped_refptr<TransportContext> transport_context) - : transport_context_(transport_context) {} -IceTransportFactory::~IceTransportFactory() {} - -scoped_ptr<Transport> IceTransportFactory::CreateTransport() { - return make_scoped_ptr(new IceTransport(transport_context_.get())); + scoped_ptr<buzz::XmlElement> transport_info_xml = + pending_transport_info_message_->ToXml(); + pending_transport_info_message_.reset(); + send_transport_info_callback_.Run(std::move(transport_info_xml)); } } // namespace protocol diff --git a/remoting/protocol/ice_transport.h b/remoting/protocol/ice_transport.h index fc72b29..60e4211 100644 --- a/remoting/protocol/ice_transport.h +++ b/remoting/protocol/ice_transport.h @@ -27,16 +27,28 @@ class IceTransport : public Transport, public IceTransportChannel::Delegate, public DatagramChannelFactory { public: + class EventHandler { + public: + // Called when transport route changes. + virtual void OnIceTransportRouteChange(const std::string& channel_name, + const TransportRoute& route) = 0; + + // Called when there is an error connecting the session. + virtual void OnIceTransportError(ErrorCode error) = 0; + }; + // |transport_context| must outlive the session. - explicit IceTransport(scoped_refptr<TransportContext> transport_context); + IceTransport(scoped_refptr<TransportContext> transport_context, + EventHandler* event_handler); ~IceTransport() override; + StreamChannelFactory* GetStreamChannelFactory(); + StreamChannelFactory* GetMultiplexedChannelFactory(); + // Transport interface. - void Start(EventHandler* event_handler, - Authenticator* authenticator) override; + void Start(Authenticator* authenticator, + SendTransportInfoCallback send_transport_info_callback) override; bool ProcessTransportInfo(buzz::XmlElement* transport_info) override; - StreamChannelFactory* GetStreamChannelFactory() override; - StreamChannelFactory* GetMultiplexedChannelFactory() override; private: typedef std::map<std::string, IceTransportChannel*> ChannelsMap; @@ -51,15 +63,15 @@ class IceTransport : public Transport, void AddPendingRemoteTransportInfo(IceTransportChannel* channel); // IceTransportChannel::Delegate interface. - void OnTransportIceCredentials(IceTransportChannel* transport, - const std::string& ufrag, - const std::string& password) override; - void OnTransportCandidate(IceTransportChannel* transport, - const cricket::Candidate& candidate) override; - void OnTransportRouteChange(IceTransportChannel* transport, - const TransportRoute& route) override; - void OnTransportFailed(IceTransportChannel* transport) override; - void OnTransportDeleted(IceTransportChannel* transport) override; + void OnChannelIceCredentials(IceTransportChannel* transport, + const std::string& ufrag, + const std::string& password) override; + void OnChannelCandidate(IceTransportChannel* transport, + const cricket::Candidate& candidate) override; + void OnChannelRouteChange(IceTransportChannel* transport, + const TransportRoute& route) override; + void OnChannelFailed(IceTransportChannel* transport) override; + void OnChannelDeleted(IceTransportChannel* transport) override; // Creates empty |pending_transport_info_message_| and schedules timer for // SentTransportInfo() to sent the message later. @@ -69,8 +81,9 @@ class IceTransport : public Transport, void SendTransportInfo(); scoped_refptr<TransportContext> transport_context_; + EventHandler* event_handler_; - Transport::EventHandler* event_handler_ = nullptr; + SendTransportInfoCallback send_transport_info_callback_; ChannelsMap channels_; scoped_ptr<PseudoTcpChannelFactory> pseudotcp_channel_factory_; @@ -90,20 +103,6 @@ class IceTransport : public Transport, DISALLOW_COPY_AND_ASSIGN(IceTransport); }; -class IceTransportFactory : public TransportFactory { - public: - IceTransportFactory(scoped_refptr<TransportContext> transport_context); - ~IceTransportFactory() override; - - // TransportFactory interface. - scoped_ptr<Transport> CreateTransport() override; - - private: - scoped_refptr<TransportContext> transport_context_; - - DISALLOW_COPY_AND_ASSIGN(IceTransportFactory); -}; - } // namespace protocol } // namespace remoting diff --git a/remoting/protocol/ice_transport_channel.cc b/remoting/protocol/ice_transport_channel.cc index 9774aee..45919ab 100644 --- a/remoting/protocol/ice_transport_channel.cc +++ b/remoting/protocol/ice_transport_channel.cc @@ -61,7 +61,7 @@ IceTransportChannel::IceTransportChannel( IceTransportChannel::~IceTransportChannel() { DCHECK(delegate_); - delegate_->OnTransportDeleted(this); + delegate_->OnChannelDeleted(this); auto task_runner = base::ThreadTaskRunnerHandle::Get(); if (channel_) @@ -103,8 +103,8 @@ void IceTransportChannel::OnPortAllocatorCreated( channel_->SetIceRole((transport_context_->role() == TransportRole::CLIENT) ? cricket::ICEROLE_CONTROLLING : cricket::ICEROLE_CONTROLLED); - delegate_->OnTransportIceCredentials(this, ice_username_fragment_, - ice_password); + delegate_->OnChannelIceCredentials(this, ice_username_fragment_, + ice_password); channel_->SetIceCredentials(ice_username_fragment_, ice_password); channel_->SignalCandidateGathered.connect( this, &IceTransportChannel::OnCandidateGathered); @@ -191,7 +191,7 @@ void IceTransportChannel::OnCandidateGathered( cricket::TransportChannelImpl* channel, const cricket::Candidate& candidate) { DCHECK(thread_checker_.CalledOnValidThread()); - delegate_->OnTransportCandidate(this, candidate); + delegate_->OnChannelCandidate(this, candidate); } void IceTransportChannel::OnRouteChange( @@ -261,7 +261,7 @@ void IceTransportChannel::NotifyRouteChanged() { LOG(FATAL) << "Failed to convert local IP address."; } - delegate_->OnTransportRouteChange(this, route); + delegate_->OnChannelRouteChange(this, route); } void IceTransportChannel::TryReconnect() { @@ -272,15 +272,15 @@ void IceTransportChannel::TryReconnect() { // Notify the caller that ICE connection has failed - normally that will // terminate Jingle connection (i.e. the transport will be destroyed). - delegate_->OnTransportFailed(this); + delegate_->OnChannelFailed(this); return; } --connect_attempts_left_; // Restart ICE by resetting ICE password. std::string ice_password = rtc::CreateRandomString(cricket::ICE_PWD_LENGTH); - delegate_->OnTransportIceCredentials(this, ice_username_fragment_, - ice_password); + delegate_->OnChannelIceCredentials(this, ice_username_fragment_, + ice_password); channel_->SetIceCredentials(ice_username_fragment_, ice_password); } diff --git a/remoting/protocol/ice_transport_channel.h b/remoting/protocol/ice_transport_channel.h index 7bde3e5..08acad9 100644 --- a/remoting/protocol/ice_transport_channel.h +++ b/remoting/protocol/ice_transport_channel.h @@ -37,26 +37,26 @@ class IceTransportChannel : public sigslot::has_slots<> { // Called to pass ICE credentials to the session. Used only for STANDARD // version of ICE, see SetIceVersion(). - virtual void OnTransportIceCredentials(IceTransportChannel* transport, + virtual void OnChannelIceCredentials(IceTransportChannel* transport, const std::string& ufrag, const std::string& password) = 0; // Called when the transport generates a new candidate that needs // to be passed to the AddRemoteCandidate() method on the remote // end of the connection. - virtual void OnTransportCandidate(IceTransportChannel* transport, + virtual void OnChannelCandidate(IceTransportChannel* transport, const cricket::Candidate& candidate) = 0; // Called when transport route changes. Can be called even before // the transport is connected. - virtual void OnTransportRouteChange(IceTransportChannel* transport, + virtual void OnChannelRouteChange(IceTransportChannel* transport, const TransportRoute& route) = 0; - // Called when when the transport has failed to connect or reconnect. - virtual void OnTransportFailed(IceTransportChannel* transport) = 0; + // Called when when the channel has failed to connect or reconnect. + virtual void OnChannelFailed(IceTransportChannel* transport) = 0; - // Called when the transport is about to be deleted. - virtual void OnTransportDeleted(IceTransportChannel* transport) = 0; + // Called when the channel is about to be deleted. + virtual void OnChannelDeleted(IceTransportChannel* transport) = 0; }; typedef base::Callback<void(scoped_ptr<P2PDatagramSocket>)> ConnectedCallback; diff --git a/remoting/protocol/ice_transport_unittest.cc b/remoting/protocol/ice_transport_unittest.cc index c98a1dc..212c1f3 100644 --- a/remoting/protocol/ice_transport_unittest.cc +++ b/remoting/protocol/ice_transport_unittest.cc @@ -21,7 +21,6 @@ #include "remoting/protocol/p2p_stream_socket.h" #include "remoting/protocol/stream_channel_factory.h" #include "remoting/protocol/transport_context.h" -#include "remoting/signaling/fake_signal_strategy.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" #include "third_party/webrtc/libjingle/xmllite/xmlelement.h" @@ -55,43 +54,25 @@ class MockChannelCreatedCallback { MOCK_METHOD1(OnDone, void(P2PStreamSocket* socket)); }; -class TestTransportEventHandler : public Transport::EventHandler { +class TestTransportEventHandler : public IceTransport::EventHandler { public: - typedef base::Callback<void(scoped_ptr<buzz::XmlElement> message)> - TransportInfoCallback; typedef base::Callback<void(ErrorCode error)> ErrorCallback; TestTransportEventHandler() {} ~TestTransportEventHandler() {} - // Both callback must be set before the test handler is passed to a Transport - // object. - void set_transport_info_callback(const TransportInfoCallback& callback) { - transport_info_callback_ = callback; - } - void set_connected_callback(const base::Closure& callback) { - connected_callback_ = callback; - } void set_error_callback(const ErrorCallback& callback) { error_callback_ = callback; } - // Transport::EventHandler interface. - void OnOutgoingTransportInfo(scoped_ptr<buzz::XmlElement> message) override { - transport_info_callback_.Run(std::move(message)); - } - void OnTransportRouteChange(const std::string& channel_name, + // IceTransport::EventHandler interface. + void OnIceTransportRouteChange(const std::string& channel_name, const TransportRoute& route) override {} - void OnTransportConnected() override { - connected_callback_.Run(); - } - void OnTransportError(ErrorCode error) override { + void OnIceTransportError(ErrorCode error) override { error_callback_.Run(error); } private: - TransportInfoCallback transport_info_callback_; - base::Closure connected_callback_; ErrorCallback error_callback_; DISALLOW_COPY_AND_ASSIGN(TestTransportEventHandler); @@ -132,43 +113,36 @@ class IceTransportTest : public testing::Test { } void InitializeConnection() { - host_transport_.reset(new IceTransport(new TransportContext( - signal_strategy_.get(), - make_scoped_ptr(new ChromiumPortAllocatorFactory(nullptr)), - network_settings_, TransportRole::SERVER))); + host_transport_.reset( + new IceTransport(TransportContext::ForTests(TransportRole::SERVER), + &host_event_handler_)); if (!host_authenticator_) { host_authenticator_.reset(new FakeAuthenticator( FakeAuthenticator::HOST, 0, FakeAuthenticator::ACCEPT, true)); } - client_transport_.reset(new IceTransport(new TransportContext( - signal_strategy_.get(), - make_scoped_ptr(new ChromiumPortAllocatorFactory(nullptr)), - network_settings_, TransportRole::CLIENT))); + client_transport_.reset( + new IceTransport(TransportContext::ForTests(TransportRole::CLIENT), + &client_event_handler_)); if (!client_authenticator_) { client_authenticator_.reset(new FakeAuthenticator( FakeAuthenticator::CLIENT, 0, FakeAuthenticator::ACCEPT, true)); } - // Connect signaling between the two IceTransport objects. - host_event_handler_.set_transport_info_callback( - base::Bind(&IceTransportTest::ProcessTransportInfo, - base::Unretained(this), &client_transport_)); - client_event_handler_.set_transport_info_callback( - base::Bind(&IceTransportTest::ProcessTransportInfo, - base::Unretained(this), &host_transport_)); - - host_event_handler_.set_connected_callback(base::Bind(&base::DoNothing)); host_event_handler_.set_error_callback(base::Bind( &IceTransportTest::OnTransportError, base::Unretained(this))); - - client_event_handler_.set_connected_callback(base::Bind(&base::DoNothing)); client_event_handler_.set_error_callback(base::Bind( &IceTransportTest::OnTransportError, base::Unretained(this))); - host_transport_->Start(&host_event_handler_, host_authenticator_.get()); - client_transport_->Start(&client_event_handler_, - client_authenticator_.get()); + // Start both transports. + host_transport_->Start( + host_authenticator_.get(), + base::Bind(&IceTransportTest::ProcessTransportInfo, + base::Unretained(this), &client_transport_)); + client_transport_->Start( + client_authenticator_.get(), + base::Bind(&IceTransportTest::ProcessTransportInfo, + base::Unretained(this), &host_transport_)); } void WaitUntilConnected() { @@ -207,8 +181,6 @@ class IceTransportTest : public testing::Test { NetworkSettings network_settings_; - scoped_ptr<FakeSignalStrategy> signal_strategy_; - base::TimeDelta transport_info_delay_; scoped_ptr<IceTransport> host_transport_; diff --git a/remoting/protocol/jingle_session.cc b/remoting/protocol/jingle_session.cc index 319f09a..0576489 100644 --- a/remoting/protocol/jingle_session.cc +++ b/remoting/protocol/jingle_session.cc @@ -22,6 +22,7 @@ #include "remoting/protocol/jingle_messages.h" #include "remoting/protocol/jingle_session_manager.h" #include "remoting/protocol/session_config.h" +#include "remoting/protocol/transport.h" #include "remoting/signaling/iq_sender.h" #include "third_party/webrtc/libjingle/xmllite/xmlelement.h" #include "third_party/webrtc/p2p/base/candidate.h" @@ -69,8 +70,6 @@ JingleSession::JingleSession(JingleSessionManager* session_manager) } JingleSession::~JingleSession() { - transport_.reset(); - STLDeleteContainerPointers(pending_requests_.begin(), pending_requests_.end()); STLDeleteContainerPointers(transport_info_requests_.begin(), @@ -80,19 +79,19 @@ JingleSession::~JingleSession() { } void JingleSession::SetEventHandler(Session::EventHandler* event_handler) { - DCHECK(CalledOnValidThread()); + DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(event_handler); event_handler_ = event_handler; } ErrorCode JingleSession::error() { - DCHECK(CalledOnValidThread()); + DCHECK(thread_checker_.CalledOnValidThread()); return error_; } void JingleSession::StartConnection(const std::string& peer_jid, scoped_ptr<Authenticator> authenticator) { - DCHECK(CalledOnValidThread()); + DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(authenticator.get()); DCHECK_EQ(authenticator->state(), Authenticator::MESSAGE_READY); @@ -106,8 +105,6 @@ void JingleSession::StartConnection(const std::string& peer_jid, session_id_ = base::Uint64ToString( base::RandGenerator(std::numeric_limits<uint64_t>::max())); - transport_ = session_manager_->transport_factory_->CreateTransport(); - // Send session-initiate message. JingleMessage message(peer_jid_, JingleMessage::SESSION_INITIATE, session_id_); @@ -123,7 +120,7 @@ void JingleSession::StartConnection(const std::string& peer_jid, void JingleSession::InitializeIncomingConnection( const JingleMessage& initiate_message, scoped_ptr<Authenticator> authenticator) { - DCHECK(CalledOnValidThread()); + DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(initiate_message.description.get()); DCHECK(authenticator.get()); DCHECK_EQ(authenticator->state(), Authenticator::WAITING_MESSAGE); @@ -143,8 +140,6 @@ void JingleSession::InitializeIncomingConnection( Close(INCOMPATIBLE_PROTOCOL); return; } - - transport_ = session_manager_->transport_factory_->CreateTransport(); } void JingleSession::AcceptIncomingConnection( @@ -200,24 +195,43 @@ void JingleSession::ContinueAcceptIncomingConnection() { } const std::string& JingleSession::jid() { - DCHECK(CalledOnValidThread()); + DCHECK(thread_checker_.CalledOnValidThread()); return peer_jid_; } const SessionConfig& JingleSession::config() { - DCHECK(CalledOnValidThread()); + DCHECK(thread_checker_.CalledOnValidThread()); return *config_; } -Transport* JingleSession::GetTransport() { - DCHECK(CalledOnValidThread()); - return transport_.get(); +void JingleSession::SetTransport(Transport* transport) { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK(!transport_); + DCHECK(transport); + transport_ = transport; } -void JingleSession::Close(protocol::ErrorCode error) { - DCHECK(CalledOnValidThread()); +void JingleSession::SendTransportInfo( + scoped_ptr<buzz::XmlElement> transport_info) { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_EQ(state_, AUTHENTICATED); + + JingleMessage message(peer_jid_, JingleMessage::TRANSPORT_INFO, session_id_); + message.transport_info = std::move(transport_info); + + scoped_ptr<IqRequest> request = session_manager_->iq_sender()->SendIq( + message.ToXml(), base::Bind(&JingleSession::OnTransportInfoResponse, + base::Unretained(this))); + if (request) { + request->SetTimeout(base::TimeDelta::FromSeconds(kTransportInfoTimeout)); + transport_info_requests_.push_back(request.release()); + } else { + LOG(ERROR) << "Failed to send a transport-info message"; + } +} - transport_.reset(); +void JingleSession::Close(protocol::ErrorCode error) { + DCHECK(thread_checker_.CalledOnValidThread()); if (is_session_active()) { // Send session-terminate message with the appropriate error code. @@ -264,7 +278,7 @@ void JingleSession::Close(protocol::ErrorCode error) { } void JingleSession::SendMessage(const JingleMessage& message) { - DCHECK(CalledOnValidThread()); + DCHECK(thread_checker_.CalledOnValidThread()); scoped_ptr<IqRequest> request = session_manager_->iq_sender()->SendIq( message.ToXml(), @@ -290,7 +304,7 @@ void JingleSession::OnMessageResponse( JingleMessage::ActionType request_type, IqRequest* request, const buzz::XmlElement* response) { - DCHECK(CalledOnValidThread()); + DCHECK(thread_checker_.CalledOnValidThread()); // Delete the request from the list of pending requests. pending_requests_.erase(request); @@ -322,46 +336,9 @@ void JingleSession::OnMessageResponse( } } -void JingleSession::OnOutgoingTransportInfo( - scoped_ptr<XmlElement> transport_info) { - DCHECK(CalledOnValidThread()); - - JingleMessage message(peer_jid_, JingleMessage::TRANSPORT_INFO, session_id_); - message.transport_info = std::move(transport_info); - - scoped_ptr<IqRequest> request = session_manager_->iq_sender()->SendIq( - message.ToXml(), base::Bind(&JingleSession::OnTransportInfoResponse, - base::Unretained(this))); - if (request) { - request->SetTimeout(base::TimeDelta::FromSeconds(kTransportInfoTimeout)); - transport_info_requests_.push_back(request.release()); - } else { - LOG(ERROR) << "Failed to send a transport-info message"; - } -} - -void JingleSession::OnTransportRouteChange(const std::string& channel_name, - const TransportRoute& route) { - DCHECK(CalledOnValidThread()); - - event_handler_->OnSessionRouteChange(channel_name, route); -} - -void JingleSession::OnTransportConnected() { - DCHECK(CalledOnValidThread()); - DCHECK_EQ(state_, AUTHENTICATED); - SetState(CONNECTED); -} - -void JingleSession::OnTransportError(ErrorCode error) { - DCHECK(CalledOnValidThread()); - - Close(error); -} - void JingleSession::OnTransportInfoResponse(IqRequest* request, const buzz::XmlElement* response) { - DCHECK(CalledOnValidThread()); + DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(!transport_info_requests_.empty()); // Consider transport-info requests sent before this one lost and delete @@ -392,7 +369,7 @@ void JingleSession::OnTransportInfoResponse(IqRequest* request, void JingleSession::OnIncomingMessage(const JingleMessage& message, const ReplyCallback& reply_callback) { - DCHECK(CalledOnValidThread()); + DCHECK(thread_checker_.CalledOnValidThread()); if (message.from != peer_jid_) { // Ignore messages received from a different Jid. @@ -410,12 +387,20 @@ void JingleSession::OnIncomingMessage(const JingleMessage& message, break; case JingleMessage::TRANSPORT_INFO: - if (message.transport_info && - transport_->ProcessTransportInfo(message.transport_info.get())) { + if (!transport_) { + LOG(ERROR) << "Received unexpected transport-info message."; reply_callback.Run(JingleMessageReply::NONE); - } else { + return; + } + + if (!message.transport_info || + !transport_->ProcessTransportInfo( + message.transport_info.get())) { reply_callback.Run(JingleMessageReply::BAD_REQUEST); + return; } + + reply_callback.Run(JingleMessageReply::NONE); break; case JingleMessage::SESSION_TERMINATE: @@ -543,7 +528,7 @@ bool JingleSession::InitializeConfigFromDescription( } void JingleSession::ProcessAuthenticationStep() { - DCHECK(CalledOnValidThread()); + DCHECK(thread_checker_.CalledOnValidThread()); DCHECK_NE(authenticator_->state(), Authenticator::PROCESSING_MESSAGE); if (state_ != ACCEPTED && state_ != AUTHENTICATING) { @@ -585,13 +570,15 @@ void JingleSession::ContinueAuthenticationStep() { } void JingleSession::OnAuthenticated() { - transport_->Start(this, authenticator_.get()); + transport_->Start(authenticator_.get(), + base::Bind(&JingleSession::SendTransportInfo, + weak_factory_.GetWeakPtr())); SetState(AUTHENTICATED); } void JingleSession::SetState(State new_state) { - DCHECK(CalledOnValidThread()); + DCHECK(thread_checker_.CalledOnValidThread()); if (new_state != state_) { DCHECK_NE(state_, CLOSED); diff --git a/remoting/protocol/jingle_session.h b/remoting/protocol/jingle_session.h index ba1172d..87a98c7 100644 --- a/remoting/protocol/jingle_session.h +++ b/remoting/protocol/jingle_session.h @@ -11,7 +11,7 @@ #include "base/macros.h" #include "base/memory/ref_counted.h" -#include "base/threading/non_thread_safe.h" +#include "base/threading/thread_checker.h" #include "base/timer/timer.h" #include "crypto/rsa_private_key.h" #include "net/base/completion_callback.h" @@ -20,20 +20,18 @@ #include "remoting/protocol/jingle_messages.h" #include "remoting/protocol/session.h" #include "remoting/protocol/session_config.h" -#include "remoting/protocol/transport.h" #include "remoting/signaling/iq_sender.h" namespace remoting { namespace protocol { class JingleSessionManager; +class Transport; // JingleSessionManager and JingleSession implement the subset of the // Jingle protocol used in Chromoting. Instances of this class are // created by the JingleSessionManager. -class JingleSession : public base::NonThreadSafe, - public Session, - public Transport::EventHandler { +class JingleSession : public Session { public: ~JingleSession() override; @@ -42,7 +40,7 @@ class JingleSession : public base::NonThreadSafe, ErrorCode error() override; const std::string& jid() override; const SessionConfig& config() override; - Transport* GetTransport() override; + void SetTransport(Transport* transport) override; void Close(protocol::ErrorCode error) override; private: @@ -61,6 +59,9 @@ class JingleSession : public base::NonThreadSafe, scoped_ptr<Authenticator> authenticator); void AcceptIncomingConnection(const JingleMessage& initiate_message); + // Callback for Transport interface to send transport-info messages. + void SendTransportInfo(scoped_ptr<buzz::XmlElement> transport_info); + // Sends |message| to the peer. The session is closed if the send fails or no // response is received within a reasonable time. All other responses are // ignored. @@ -71,14 +72,6 @@ class JingleSession : public base::NonThreadSafe, IqRequest* request, const buzz::XmlElement* response); - // Transport::EventHandler interface. - void OnOutgoingTransportInfo( - scoped_ptr<buzz::XmlElement> transport_info) override; - void OnTransportRouteChange(const std::string& component, - const TransportRoute& route) override; - void OnTransportConnected() override; - void OnTransportError(ErrorCode error) override; - // Response handler for transport-info responses. Transport-info timeouts are // ignored and don't terminate connection. void OnTransportInfoResponse(IqRequest* request, @@ -119,6 +112,8 @@ class JingleSession : public base::NonThreadSafe, // Returns true if the state of the session is not CLOSED or FAILED bool is_session_active(); + base::ThreadChecker thread_checker_; + JingleSessionManager* session_manager_; std::string peer_jid_; Session::EventHandler* event_handler_; @@ -131,7 +126,7 @@ class JingleSession : public base::NonThreadSafe, scoped_ptr<Authenticator> authenticator_; - scoped_ptr<Transport> transport_; + Transport* transport_ = nullptr; // Pending Iq requests. Used for all messages except transport-info. std::set<IqRequest*> pending_requests_; diff --git a/remoting/protocol/jingle_session_manager.cc b/remoting/protocol/jingle_session_manager.cc index 7319235..f3c53b9 100644 --- a/remoting/protocol/jingle_session_manager.cc +++ b/remoting/protocol/jingle_session_manager.cc @@ -22,11 +22,8 @@ using buzz::QName; namespace remoting { namespace protocol { -JingleSessionManager::JingleSessionManager( - scoped_ptr<TransportFactory> transport_factory, - SignalStrategy* signal_strategy) - : transport_factory_(std::move(transport_factory)), - signal_strategy_(signal_strategy), +JingleSessionManager::JingleSessionManager(SignalStrategy* signal_strategy) + : signal_strategy_(signal_strategy), protocol_config_(CandidateSessionConfig::CreateDefault()), iq_sender_(new IqSender(signal_strategy_)) { signal_strategy_->AddListener(this); @@ -40,7 +37,6 @@ JingleSessionManager::~JingleSessionManager() { void JingleSessionManager::AcceptIncoming( const IncomingSessionCallback& incoming_session_callback) { incoming_session_callback_ = incoming_session_callback; - } void JingleSessionManager::set_protocol_config( diff --git a/remoting/protocol/jingle_session_manager.h b/remoting/protocol/jingle_session_manager.h index f52f7c5..3ac1906 100644 --- a/remoting/protocol/jingle_session_manager.h +++ b/remoting/protocol/jingle_session_manager.h @@ -33,8 +33,7 @@ class TransportFactory; class JingleSessionManager : public SessionManager, public SignalStrategy::Listener { public: - JingleSessionManager(scoped_ptr<TransportFactory> transport_factory, - SignalStrategy* signal_strategy); + explicit JingleSessionManager(SignalStrategy* signal_strategy); ~JingleSessionManager() override; // SessionManager interface. @@ -63,8 +62,7 @@ class JingleSessionManager : public SessionManager, // Called by JingleSession when it is being destroyed. void SessionDestroyed(JingleSession* session); - scoped_ptr<TransportFactory> transport_factory_; - SignalStrategy* signal_strategy_; + SignalStrategy* signal_strategy_ = nullptr; IncomingSessionCallback incoming_session_callback_; scoped_ptr<CandidateSessionConfig> protocol_config_; diff --git a/remoting/protocol/jingle_session_unittest.cc b/remoting/protocol/jingle_session_unittest.cc index 171a1a1..adb1fc4 100644 --- a/remoting/protocol/jingle_session_unittest.cc +++ b/remoting/protocol/jingle_session_unittest.cc @@ -11,7 +11,6 @@ #include "base/run_loop.h" #include "base/test/test_timeouts.h" #include "base/time/time.h" -#include "jingle/glue/thread_wrapper.h" #include "net/socket/socket.h" #include "net/socket/stream_socket.h" #include "net/url_request/url_request_context_getter.h" @@ -21,9 +20,9 @@ #include "remoting/protocol/chromium_port_allocator.h" #include "remoting/protocol/connection_tester.h" #include "remoting/protocol/fake_authenticator.h" -#include "remoting/protocol/ice_transport.h" #include "remoting/protocol/jingle_session_manager.h" #include "remoting/protocol/network_settings.h" +#include "remoting/protocol/transport.h" #include "remoting/protocol/transport_context.h" #include "remoting/signaling/fake_signal_strategy.h" #include "testing/gmock/include/gmock/gmock.h" @@ -64,13 +63,20 @@ class MockSessionEventHandler : public Session::EventHandler { const TransportRoute& route)); }; +class MockTransport : public Transport { + public: + MOCK_METHOD2(Start, + void(Authenticator* authenticator, + SendTransportInfoCallback send_transport_info_callback)); + MOCK_METHOD1(ProcessTransportInfo, bool(buzz::XmlElement* transport_info)); +}; + } // namespace class JingleSessionTest : public testing::Test { public: JingleSessionTest() { message_loop_.reset(new base::MessageLoopForIO()); - jingle_glue::JingleThreadWrapper::EnsureForCurrentMessageLoop(); network_settings_ = NetworkSettings(NetworkSettings::NAT_TRAVERSAL_OUTGOING); } @@ -80,6 +86,7 @@ class JingleSessionTest : public testing::Test { DCHECK(session); host_session_.reset(session); host_session_->SetEventHandler(&host_session_event_handler_); + host_session_->SetTransport(&host_transport_); } void DeleteSession() { @@ -105,11 +112,7 @@ class JingleSessionTest : public testing::Test { FakeSignalStrategy::Connect(host_signal_strategy_.get(), client_signal_strategy_.get()); - host_server_.reset(new JingleSessionManager( - make_scoped_ptr(new IceTransportFactory(new TransportContext( - nullptr, make_scoped_ptr(new ChromiumPortAllocatorFactory(nullptr)), - network_settings_, TransportRole::SERVER))), - host_signal_strategy_.get())); + host_server_.reset(new JingleSessionManager(host_signal_strategy_.get())); host_server_->AcceptIncoming( base::Bind(&MockSessionManagerListener::OnIncomingSession, base::Unretained(&host_server_listener_))); @@ -119,11 +122,8 @@ class JingleSessionTest : public testing::Test { messages_till_start, auth_action, true)); host_server_->set_authenticator_factory(std::move(factory)); - client_server_.reset(new JingleSessionManager( - make_scoped_ptr(new IceTransportFactory(new TransportContext( - nullptr, make_scoped_ptr(new ChromiumPortAllocatorFactory(nullptr)), - network_settings_, TransportRole::CLIENT))), - client_signal_strategy_.get())); + client_server_.reset( + new JingleSessionManager(client_signal_strategy_.get())); } void CreateSessionManagers(int auth_round_trips, @@ -160,9 +160,11 @@ class JingleSessionTest : public testing::Test { OnSessionStateChange(Session::FAILED)) .Times(1); } else { + EXPECT_CALL(host_transport_, Start(_, _)).Times(1); EXPECT_CALL(host_session_event_handler_, OnSessionStateChange(Session::AUTHENTICATED)) .Times(1); + // Expect that the connection will be closed eventually. EXPECT_CALL(host_session_event_handler_, OnSessionStateChange(Session::CLOSED)) @@ -184,9 +186,11 @@ class JingleSessionTest : public testing::Test { OnSessionStateChange(Session::FAILED)) .Times(1); } else { + EXPECT_CALL(client_transport_, Start(_, _)).Times(1); EXPECT_CALL(client_session_event_handler_, OnSessionStateChange(Session::AUTHENTICATED)) .Times(1); + // Expect that the connection will be closed eventually. EXPECT_CALL(client_session_event_handler_, OnSessionStateChange(Session::CLOSED)) @@ -200,6 +204,7 @@ class JingleSessionTest : public testing::Test { client_session_ = client_server_->Connect(kHostJid, std::move(authenticator)); client_session_->SetEventHandler(&client_session_event_handler_); + client_session_->SetTransport(&client_transport_); base::RunLoop().RunUntilIdle(); } @@ -226,8 +231,10 @@ class JingleSessionTest : public testing::Test { scoped_ptr<Session> host_session_; MockSessionEventHandler host_session_event_handler_; + MockTransport host_transport_; scoped_ptr<Session> client_session_; MockSessionEventHandler client_session_event_handler_; + MockTransport client_transport_; }; diff --git a/remoting/protocol/protocol_mock_objects.h b/remoting/protocol/protocol_mock_objects.h index 0351ed6..0958396 100644 --- a/remoting/protocol/protocol_mock_objects.h +++ b/remoting/protocol/protocol_mock_objects.h @@ -168,17 +168,9 @@ class MockSession : public Session { MOCK_METHOD1(SetEventHandler, void(Session::EventHandler* event_handler)); MOCK_METHOD0(error, ErrorCode()); - MOCK_METHOD0(GetTransport, Transport*()); - MOCK_METHOD0(GetQuicChannelFactory, StreamChannelFactory*()); + MOCK_METHOD1(SetTransport, void(Transport*)); MOCK_METHOD0(jid, const std::string&()); - MOCK_METHOD0(candidate_config, const CandidateSessionConfig*()); MOCK_METHOD0(config, const SessionConfig&()); - MOCK_METHOD0(initiator_token, const std::string&()); - MOCK_METHOD1(set_initiator_token, void(const std::string& initiator_token)); - MOCK_METHOD0(receiver_token, const std::string&()); - MOCK_METHOD1(set_receiver_token, void(const std::string& receiver_token)); - MOCK_METHOD1(set_shared_secret, void(const std::string& secret)); - MOCK_METHOD0(shared_secret, const std::string&()); MOCK_METHOD1(Close, void(ErrorCode error)); private: diff --git a/remoting/protocol/session.h b/remoting/protocol/session.h index a87375f..917dd80 100644 --- a/remoting/protocol/session.h +++ b/remoting/protocol/session.h @@ -10,17 +10,23 @@ #include "base/macros.h" #include "remoting/protocol/errors.h" #include "remoting/protocol/session_config.h" +#include "remoting/protocol/transport.h" + +namespace buzz { +class XmlElement; +} // namespace buzz namespace remoting { namespace protocol { +class Authenicator; class StreamChannelFactory; class Transport; struct TransportRoute; -// Generic interface for Chromotocol connection used by both client and host. -// Provides access to the connection channels, but doesn't depend on the -// protocol used for each channel. +// Session is responsible for initializing and authenticating both incoming and +// outgoing connections. It uses TransportInfoSink interface to pass +// transport-info messages to the transport. class Session { public: enum State { @@ -42,9 +48,6 @@ class Session { // Session has been connected and authenticated. AUTHENTICATED, - // Session has been connected. - CONNECTED, - // Session has been closed. CLOSED, @@ -61,12 +64,6 @@ class Session { // the session from within the handler if |state| is AUTHENTICATING // or CLOSED or FAILED. virtual void OnSessionStateChange(State state) = 0; - - // Called whenever route for the channel specified with - // |channel_name| changes. Session must not be destroyed by the - // handler of this event. - virtual void OnSessionRouteChange(const std::string& channel_name, - const TransportRoute& route) = 0; }; Session() {} @@ -86,10 +83,11 @@ class Session { // Returned pointer is valid until connection is closed. virtual const SessionConfig& config() = 0; - // Returns Transport that can be used to create transport channels. - virtual Transport* GetTransport() = 0; + // Sets Transport to be used by the session. Must be called before the + // session becomes AUTHENTICATED. The transport must outlive the session. + virtual void SetTransport(Transport* transport) = 0; - // Closes connection. Callbacks are guaranteed not to be called after this + // Closes connection. EventHandler is guaranteed not to be called after this // method returns. |error| specifies the error code in case when the session // is being closed due to an error. virtual void Close(ErrorCode error) = 0; diff --git a/remoting/protocol/transport.cc b/remoting/protocol/transport.cc index 41eef83..e122ca0 100644 --- a/remoting/protocol/transport.cc +++ b/remoting/protocol/transport.cc @@ -26,9 +26,5 @@ std::string TransportRoute::GetTypeString(RouteType type) { TransportRoute::TransportRoute() : type(DIRECT) {} TransportRoute::~TransportRoute() {} -WebrtcTransport* Transport::AsWebrtcTransport() { - return nullptr; -} - } // namespace protocol } // namespace remoting diff --git a/remoting/protocol/transport.h b/remoting/protocol/transport.h index 99aab7b..c9f24d3 100644 --- a/remoting/protocol/transport.h +++ b/remoting/protocol/transport.h @@ -14,10 +14,6 @@ #include "net/base/ip_endpoint.h" #include "remoting/protocol/errors.h" -namespace cricket { -class Candidate; -} // namespace cricket - namespace buzz { class XmlElement; } // namespace buzz @@ -33,7 +29,6 @@ class Authenticator; class DatagramChannelFactory; class P2PDatagramSocket; class StreamChannelFactory; -class WebrtcTransport; enum class TransportRole { SERVER, @@ -58,68 +53,22 @@ struct TransportRoute { net::IPEndPoint local_address; }; -// Transport represents a P2P connection that consists of one or more -// channels. +// Transport represents a P2P connection that consists of one or more channels. +// This interface is used just to send and receive transport-info messages. +// Implementations should provide other methods to send and receive data. class Transport { public: - class EventHandler { - public: - // Called to send a transport-info message. - virtual void OnOutgoingTransportInfo( - scoped_ptr<buzz::XmlElement> message) = 0; - - // Called when transport route changes. - virtual void OnTransportRouteChange(const std::string& channel_name, - const TransportRoute& route) = 0; - - // Called when the transport is connected. - virtual void OnTransportConnected() = 0; - - // Called when there is an error connecting the session. - virtual void OnTransportError(ErrorCode error) = 0; - }; + typedef base::Callback<void(scoped_ptr<buzz::XmlElement> transport_info)> + SendTransportInfoCallback; - Transport() {} virtual ~Transport() {} - // Starts transport session. Both parameters must outlive Transport. - virtual void Start(EventHandler* event_handler, - Authenticator* authenticator) = 0; - - // Called to process incoming transport message. Returns false if - // |transport_info| is in invalid format. + // Sets the object responsible for delivering outgoing transport-info messages + // to the peer. + virtual void Start( + Authenticator* authenticator, + SendTransportInfoCallback send_transport_info_callback) = 0; virtual bool ProcessTransportInfo(buzz::XmlElement* transport_info) = 0; - - // Channel factory for the session that creates stream channels. - virtual StreamChannelFactory* GetStreamChannelFactory() = 0; - - // Returns a factory that creates multiplexed channels over a single stream - // channel. - virtual StreamChannelFactory* GetMultiplexedChannelFactory() = 0; - - // Returns the transport as WebrtcTransport or nullptr if this is not a - // WebrtcTransport. - // - // TODO(sergeyu): Move creation and ownership of Transport objects to the - // Connection classes. That way the Connection classes will be able to ensure - // that correct transport implementation is used for the connection and this - // method will not be necessary. - virtual WebrtcTransport* AsWebrtcTransport(); - - private: - DISALLOW_COPY_AND_ASSIGN(Transport); -}; - -class TransportFactory { - public: - TransportFactory() { } - virtual ~TransportFactory() { } - - // Creates a new Transport. The factory must outlive the session. - virtual scoped_ptr<Transport> CreateTransport() = 0; - - private: - DISALLOW_COPY_AND_ASSIGN(TransportFactory); }; } // namespace protocol diff --git a/remoting/protocol/transport_context.cc b/remoting/protocol/transport_context.cc index 205df73..8fa25ad 100644 --- a/remoting/protocol/transport_context.cc +++ b/remoting/protocol/transport_context.cc @@ -13,12 +13,31 @@ #include "remoting/protocol/port_allocator_factory.h" #include "third_party/webrtc/p2p/client/httpportallocator.h" +#if !defined(OS_NACL) +#include "jingle/glue/thread_wrapper.h" +#include "net/url_request/url_request_context_getter.h" +#include "remoting/protocol/chromium_port_allocator.h" +#endif // !defined(OS_NACL) + namespace remoting { namespace protocol { // Get fresh STUN/Relay configuration every hour. static const int kJingleInfoUpdatePeriodSeconds = 3600; +#if !defined(OS_NACL) +// static +scoped_refptr<TransportContext> TransportContext::ForTests(TransportRole role) { + jingle_glue::JingleThreadWrapper::EnsureForCurrentMessageLoop(); + return new protocol::TransportContext( + nullptr, make_scoped_ptr( + new protocol::ChromiumPortAllocatorFactory(nullptr)), + protocol::NetworkSettings( + protocol::NetworkSettings::NAT_TRAVERSAL_OUTGOING), + role); +} +#endif // !defined(OS_NACL) + TransportContext::TransportContext( SignalStrategy* signal_strategy, scoped_ptr<PortAllocatorFactory> port_allocator_factory, diff --git a/remoting/protocol/transport_context.h b/remoting/protocol/transport_context.h index faac2d2..12a6c3d 100644 --- a/remoting/protocol/transport_context.h +++ b/remoting/protocol/transport_context.h @@ -37,6 +37,8 @@ class TransportContext : public base::RefCountedThreadSafe<TransportContext> { typedef base::Callback<void(scoped_ptr<cricket::PortAllocator> port_allocator)> CreatePortAllocatorCallback; + static scoped_refptr<TransportContext> ForTests(TransportRole role); + TransportContext( SignalStrategy* signal_strategy, scoped_ptr<PortAllocatorFactory> port_allocator_factory, diff --git a/remoting/protocol/webrtc_connection_to_client.cc b/remoting/protocol/webrtc_connection_to_client.cc index 95360a0..9d732e0 100644 --- a/remoting/protocol/webrtc_connection_to_client.cc +++ b/remoting/protocol/webrtc_connection_to_client.cc @@ -8,6 +8,7 @@ #include "base/bind.h" #include "base/location.h" +#include "jingle/glue/thread_wrapper.h" #include "net/base/io_buffer.h" #include "remoting/codec/video_encoder.h" #include "remoting/codec/video_encoder_verbatim.h" @@ -18,6 +19,7 @@ #include "remoting/protocol/host_event_dispatcher.h" #include "remoting/protocol/host_stub.h" #include "remoting/protocol/input_stub.h" +#include "remoting/protocol/transport_context.h" #include "remoting/protocol/webrtc_transport.h" #include "remoting/protocol/webrtc_video_capturer_adapter.h" #include "remoting/protocol/webrtc_video_stream.h" @@ -32,12 +34,21 @@ namespace protocol { const char kStreamLabel[] = "screen_stream"; const char kVideoLabel[] = "screen_video"; +// Currently the network thread is also used as worker thread for webrtc. +// +// TODO(sergeyu): Figure out if we would benefit from using a separate +// thread as a worker thread. WebrtcConnectionToClient::WebrtcConnectionToClient( - scoped_ptr<protocol::Session> session) - : session_(std::move(session)), + scoped_ptr<protocol::Session> session, + scoped_refptr<protocol::TransportContext> transport_context) + : transport_(jingle_glue::JingleThreadWrapper::current(), + transport_context, + this), + session_(std::move(session)), control_dispatcher_(new HostControlDispatcher()), event_dispatcher_(new HostEventDispatcher()) { session_->SetEventHandler(this); + session_->SetTransport(&transport_); } WebrtcConnectionToClient::~WebrtcConnectionToClient() {} @@ -56,9 +67,6 @@ protocol::Session* WebrtcConnectionToClient::session() { void WebrtcConnectionToClient::Disconnect(ErrorCode error) { DCHECK(thread_checker_.CalledOnValidThread()); - control_dispatcher_.reset(); - event_dispatcher_.reset(); - // This should trigger OnConnectionClosed() event and this object // may be destroyed as the result. session_->Close(error); @@ -71,10 +79,6 @@ void WebrtcConnectionToClient::OnInputEventReceived(int64_t timestamp) { scoped_ptr<VideoStream> WebrtcConnectionToClient::StartVideoStream( scoped_ptr<webrtc::DesktopCapturer> desktop_capturer) { - // TODO(sergeyu): Reconsider Transport interface and how it's used here. - WebrtcTransport* transport = session_->GetTransport()->AsWebrtcTransport(); - CHECK(transport); - scoped_ptr<WebrtcVideoCapturerAdapter> video_capturer_adapter( new WebrtcVideoCapturerAdapter(std::move(desktop_capturer))); @@ -84,22 +88,22 @@ scoped_ptr<VideoStream> WebrtcConnectionToClient::StartVideoStream( webrtc::MediaConstraintsInterface::kMinFrameRate, 5); rtc::scoped_refptr<webrtc::VideoTrackInterface> video_track = - transport->peer_connection_factory()->CreateVideoTrack( + transport_.peer_connection_factory()->CreateVideoTrack( kVideoLabel, - transport->peer_connection_factory()->CreateVideoSource( + transport_.peer_connection_factory()->CreateVideoSource( video_capturer_adapter.release(), &video_constraints)); rtc::scoped_refptr<webrtc::MediaStreamInterface> video_stream = - transport->peer_connection_factory()->CreateLocalMediaStream( + transport_.peer_connection_factory()->CreateLocalMediaStream( kStreamLabel); if (!video_stream->AddTrack(video_track) || - !transport->peer_connection()->AddStream(video_stream)) { + !transport_.peer_connection()->AddStream(video_stream)) { return nullptr; } return make_scoped_ptr( - new WebrtcVideoStream(transport->peer_connection(), video_stream)); + new WebrtcVideoStream(transport_.peer_connection(), video_stream)); } AudioStub* WebrtcConnectionToClient::audio_stub() { @@ -145,12 +149,9 @@ void WebrtcConnectionToClient::OnSessionStateChange(Session::State state) { break; case Session::AUTHENTICATED: { // Initialize channels. - control_dispatcher_->Init( - session_->GetTransport()->GetStreamChannelFactory(), - this); + control_dispatcher_->Init(transport_.GetStreamChannelFactory(), this); - event_dispatcher_->Init( - session_->GetTransport()->GetStreamChannelFactory(), this); + event_dispatcher_->Init(transport_.GetStreamChannelFactory(), this); event_dispatcher_->set_on_input_event_callback(base::Bind( &ConnectionToClient::OnInputEventReceived, base::Unretained(this))); @@ -160,10 +161,6 @@ void WebrtcConnectionToClient::OnSessionStateChange(Session::State state) { break; } - case Session::CONNECTED: - event_handler_->OnConnectionChannelsConnected(this); - break; - case Session::CLOSED: case Session::FAILED: control_dispatcher_.reset(); @@ -174,10 +171,13 @@ void WebrtcConnectionToClient::OnSessionStateChange(Session::State state) { } } -void WebrtcConnectionToClient::OnSessionRouteChange( - const std::string& channel_name, - const TransportRoute& route) { - event_handler_->OnRouteChange(this, channel_name, route); +void WebrtcConnectionToClient::OnWebrtcTransportConnected() { + event_handler_->OnConnectionChannelsConnected(this); +} + +void WebrtcConnectionToClient::OnWebrtcTransportError(ErrorCode error) { + DCHECK(thread_checker_.CalledOnValidThread()); + Disconnect(error); } void WebrtcConnectionToClient::OnChannelInitialized( @@ -192,7 +192,7 @@ void WebrtcConnectionToClient::OnChannelError( LOG(ERROR) << "Failed to connect channel " << channel_dispatcher->channel_name(); - session_->Close(CHANNEL_CONNECTION_ERROR); + Disconnect(error); } } // namespace protocol diff --git a/remoting/protocol/webrtc_connection_to_client.h b/remoting/protocol/webrtc_connection_to_client.h index 2902db9..39e2433 100644 --- a/remoting/protocol/webrtc_connection_to_client.h +++ b/remoting/protocol/webrtc_connection_to_client.h @@ -15,6 +15,7 @@ #include "remoting/protocol/channel_dispatcher_base.h" #include "remoting/protocol/connection_to_client.h" #include "remoting/protocol/session.h" +#include "remoting/protocol/webrtc_transport.h" namespace remoting { namespace protocol { @@ -24,9 +25,12 @@ class HostEventDispatcher; class WebrtcConnectionToClient : public ConnectionToClient, public Session::EventHandler, + public WebrtcTransport::EventHandler, public ChannelDispatcherBase::EventHandler { public: - explicit WebrtcConnectionToClient(scoped_ptr<Session> session); + WebrtcConnectionToClient( + scoped_ptr<Session> session, + scoped_refptr<protocol::TransportContext> transport_context); ~WebrtcConnectionToClient() override; // ConnectionToClient interface. @@ -45,8 +49,10 @@ class WebrtcConnectionToClient : public ConnectionToClient, // Session::EventHandler interface. void OnSessionStateChange(Session::State state) override; - void OnSessionRouteChange(const std::string& channel_name, - const TransportRoute& route) override; + + // WebrtcTransport::EventHandler interface + void OnWebrtcTransportConnected() override; + void OnWebrtcTransportError(ErrorCode error) override; // ChannelDispatcherBase::EventHandler interface. void OnChannelInitialized(ChannelDispatcherBase* channel_dispatcher) override; @@ -59,6 +65,8 @@ class WebrtcConnectionToClient : public ConnectionToClient, // Event handler for handling events sent from this object. ConnectionToClient::EventHandler* event_handler_ = nullptr; + WebrtcTransport transport_; + scoped_ptr<Session> session_; scoped_ptr<HostControlDispatcher> control_dispatcher_; diff --git a/remoting/protocol/webrtc_connection_to_host.cc b/remoting/protocol/webrtc_connection_to_host.cc index 113da4d..43b7a4f 100644 --- a/remoting/protocol/webrtc_connection_to_host.cc +++ b/remoting/protocol/webrtc_connection_to_host.cc @@ -6,10 +6,12 @@ #include <utility> +#include "jingle/glue/thread_wrapper.h" #include "remoting/protocol/client_control_dispatcher.h" #include "remoting/protocol/client_event_dispatcher.h" #include "remoting/protocol/client_stub.h" #include "remoting/protocol/clipboard_stub.h" +#include "remoting/protocol/transport_context.h" #include "remoting/protocol/webrtc_transport.h" namespace remoting { @@ -18,13 +20,19 @@ namespace protocol { WebrtcConnectionToHost::WebrtcConnectionToHost() {} WebrtcConnectionToHost::~WebrtcConnectionToHost() {} -void WebrtcConnectionToHost::Connect(scoped_ptr<Session> session, - HostEventCallback* event_callback) { +void WebrtcConnectionToHost::Connect( + scoped_ptr<Session> session, + scoped_refptr<TransportContext> transport_context, + HostEventCallback* event_callback) { DCHECK(client_stub_); DCHECK(clipboard_stub_); + transport_.reset(new WebrtcTransport( + jingle_glue::JingleThreadWrapper::current(), transport_context, this)); + session_ = std::move(session); session_->SetEventHandler(this); + session_->SetTransport(transport_.get()); event_callback_ = event_callback; @@ -72,7 +80,6 @@ void WebrtcConnectionToHost::OnSessionStateChange(Session::State state) { case Session::ACCEPTING: case Session::ACCEPTED: case Session::AUTHENTICATING: - case Session::CONNECTED: // Don't care about these events. break; @@ -80,14 +87,12 @@ void WebrtcConnectionToHost::OnSessionStateChange(Session::State state) { SetState(AUTHENTICATED, OK); control_dispatcher_.reset(new ClientControlDispatcher()); - control_dispatcher_->Init( - session_->GetTransport()->GetStreamChannelFactory(), this); + control_dispatcher_->Init(transport_->GetStreamChannelFactory(), this); control_dispatcher_->set_client_stub(client_stub_); control_dispatcher_->set_clipboard_stub(clipboard_stub_); event_dispatcher_.reset(new ClientEventDispatcher()); - event_dispatcher_->Init( - session_->GetTransport()->GetStreamChannelFactory(), this); + event_dispatcher_->Init(transport_->GetStreamChannelFactory(), this); break; case Session::CLOSED: @@ -98,10 +103,11 @@ void WebrtcConnectionToHost::OnSessionStateChange(Session::State state) { } } -void WebrtcConnectionToHost::OnSessionRouteChange( - const std::string& channel_name, - const TransportRoute& route) { - event_callback_->OnRouteChanged(channel_name, route); +void WebrtcConnectionToHost::OnWebrtcTransportConnected() {} + +void WebrtcConnectionToHost::OnWebrtcTransportError(ErrorCode error) { + CloseChannels(); + SetState(FAILED, error); } void WebrtcConnectionToHost::OnChannelInitialized( @@ -115,7 +121,6 @@ void WebrtcConnectionToHost::OnChannelError( LOG(ERROR) << "Failed to connect channel " << channel_dispatcher; CloseChannels(); SetState(FAILED, CHANNEL_CONNECTION_ERROR); - return; } ConnectionToHost::State WebrtcConnectionToHost::state() const { diff --git a/remoting/protocol/webrtc_connection_to_host.h b/remoting/protocol/webrtc_connection_to_host.h index 3e3e4bc..79b97e0 100644 --- a/remoting/protocol/webrtc_connection_to_host.h +++ b/remoting/protocol/webrtc_connection_to_host.h @@ -15,6 +15,7 @@ #include "remoting/protocol/errors.h" #include "remoting/protocol/input_filter.h" #include "remoting/protocol/session.h" +#include "remoting/protocol/webrtc_transport.h" namespace remoting { namespace protocol { @@ -25,6 +26,7 @@ class SessionConfig; class WebrtcConnectionToHost : public ConnectionToHost, public Session::EventHandler, + public WebrtcTransport::EventHandler, public ChannelDispatcherBase::EventHandler { public: WebrtcConnectionToHost(); @@ -36,6 +38,7 @@ class WebrtcConnectionToHost : public ConnectionToHost, void set_video_stub(VideoStub* video_stub) override; void set_audio_stub(AudioStub* audio_stub) override; void Connect(scoped_ptr<Session> session, + scoped_refptr<TransportContext> transport_context, HostEventCallback* event_callback) override; const SessionConfig& config() override; ClipboardStub* clipboard_forwarder() override; @@ -46,8 +49,10 @@ class WebrtcConnectionToHost : public ConnectionToHost, private: // Session::EventHandler interface. void OnSessionStateChange(Session::State state) override; - void OnSessionRouteChange(const std::string& channel_name, - const TransportRoute& route) override; + + // WebrtcTransport::EventHandler interface. + void OnWebrtcTransportConnected() override; + void OnWebrtcTransportError(ErrorCode error) override; // ChannelDispatcherBase::EventHandler interface. void OnChannelInitialized(ChannelDispatcherBase* channel_dispatcher) override; @@ -67,6 +72,7 @@ class WebrtcConnectionToHost : public ConnectionToHost, ClipboardStub* clipboard_stub_ = nullptr; scoped_ptr<Session> session_; + scoped_ptr<WebrtcTransport> transport_; scoped_ptr<ClientControlDispatcher> control_dispatcher_; scoped_ptr<ClientEventDispatcher> event_dispatcher_; diff --git a/remoting/protocol/webrtc_transport.cc b/remoting/protocol/webrtc_transport.cc index 537557e..59711e0 100644 --- a/remoting/protocol/webrtc_transport.cc +++ b/remoting/protocol/webrtc_transport.cc @@ -105,18 +105,23 @@ class SetSessionDescriptionObserver WebrtcTransport::WebrtcTransport( rtc::Thread* worker_thread, - scoped_refptr<TransportContext> transport_context) - : transport_context_(transport_context), - worker_thread_(worker_thread), + scoped_refptr<TransportContext> transport_context, + EventHandler* event_handler) + : worker_thread_(worker_thread), + transport_context_(transport_context), + event_handler_(event_handler), weak_factory_(this) {} WebrtcTransport::~WebrtcTransport() {} -void WebrtcTransport::Start(EventHandler* event_handler, - Authenticator* authenticator) { +void WebrtcTransport::Start( + Authenticator* authenticator, + SendTransportInfoCallback send_transport_info_callback) { DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK(send_transport_info_callback_.is_null()); + + send_transport_info_callback_ = std::move(send_transport_info_callback); - event_handler_ = event_handler; // TODO(sergeyu): Use the |authenticator| to authenticate PeerConnection. transport_context_->CreatePortAllocator(base::Bind( @@ -248,15 +253,6 @@ StreamChannelFactory* WebrtcTransport::GetStreamChannelFactory() { return &data_stream_adapter_; } -StreamChannelFactory* WebrtcTransport::GetMultiplexedChannelFactory() { - DCHECK(thread_checker_.CalledOnValidThread()); - return GetStreamChannelFactory(); -} - -WebrtcTransport* WebrtcTransport::AsWebrtcTransport() { - return this; -} - void WebrtcTransport::OnLocalSessionDescriptionCreated( scoped_ptr<webrtc::SessionDescriptionInterface> description, const std::string& error) { @@ -287,7 +283,7 @@ void WebrtcTransport::OnLocalSessionDescriptionCreated( offer_tag->SetAttr(QName(std::string(), "type"), description->type()); offer_tag->SetBodyText(description_sdp); - event_handler_->OnOutgoingTransportInfo(std::move(transport_info)); + send_transport_info_callback_.Run(std::move(transport_info)); peer_connection_->SetLocalDescription( SetSessionDescriptionObserver::Create(base::Bind( @@ -346,7 +342,7 @@ void WebrtcTransport::Close(ErrorCode error) { peer_connection_factory_ = nullptr; if (error != OK) - event_handler_->OnTransportError(error); + event_handler_->OnWebrtcTransportError(error); } void WebrtcTransport::OnSignalingChange( @@ -398,7 +394,7 @@ void WebrtcTransport::OnIceConnectionChange( DCHECK(thread_checker_.CalledOnValidThread()); if (new_state == webrtc::PeerConnectionInterface::kIceConnectionConnected) - event_handler_->OnTransportConnected(); + event_handler_->OnWebrtcTransportConnected(); } void WebrtcTransport::OnIceGatheringChange( @@ -473,9 +469,7 @@ void WebrtcTransport::SendTransportInfo() { DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(pending_transport_info_message_); - event_handler_->OnOutgoingTransportInfo( - std::move(pending_transport_info_message_)); - pending_transport_info_message_.reset(); + send_transport_info_callback_.Run(std::move(pending_transport_info_message_)); } void WebrtcTransport::AddPendingCandidatesIfPossible() { @@ -494,18 +488,5 @@ void WebrtcTransport::AddPendingCandidatesIfPossible() { } } -WebrtcTransportFactory::WebrtcTransportFactory( - rtc::Thread* worker_thread, - scoped_refptr<TransportContext> transport_context) - : worker_thread_(worker_thread), - transport_context_(transport_context) {} - -WebrtcTransportFactory::~WebrtcTransportFactory() {} - -scoped_ptr<Transport> WebrtcTransportFactory::CreateTransport() { - return make_scoped_ptr( - new WebrtcTransport(worker_thread_, transport_context_.get())); -} - } // namespace protocol } // namespace remoting diff --git a/remoting/protocol/webrtc_transport.h b/remoting/protocol/webrtc_transport.h index 827d2ef..a000029 100644 --- a/remoting/protocol/webrtc_transport.h +++ b/remoting/protocol/webrtc_transport.h @@ -30,8 +30,18 @@ class TransportContext; class WebrtcTransport : public Transport, public webrtc::PeerConnectionObserver { public: + class EventHandler { + public: + // Called when the transport is connected. + virtual void OnWebrtcTransportConnected() = 0; + + // Called when there is an error connecting the session. + virtual void OnWebrtcTransportError(ErrorCode error) = 0; + }; + WebrtcTransport(rtc::Thread* worker_thread, - scoped_refptr<TransportContext> transport_context); + scoped_refptr<TransportContext> transport_context, + EventHandler* event_handler); ~WebrtcTransport() override; webrtc::PeerConnectionInterface* peer_connection() { @@ -41,13 +51,12 @@ class WebrtcTransport : public Transport, return peer_connection_factory_; } + StreamChannelFactory* GetStreamChannelFactory(); + // Transport interface. - void Start(EventHandler* event_handler, - Authenticator* authenticator) override; + void Start(Authenticator* authenticator, + SendTransportInfoCallback send_transport_info_callback) override; bool ProcessTransportInfo(buzz::XmlElement* transport_info) override; - StreamChannelFactory* GetStreamChannelFactory() override; - StreamChannelFactory* GetMultiplexedChannelFactory() override; - WebrtcTransport* AsWebrtcTransport() override; private: void OnPortAllocatorCreated( @@ -84,9 +93,10 @@ class WebrtcTransport : public Transport, base::ThreadChecker thread_checker_; + rtc::Thread* worker_thread_; scoped_refptr<TransportContext> transport_context_; EventHandler* event_handler_ = nullptr; - rtc::Thread* worker_thread_; + SendTransportInfoCallback send_transport_info_callback_; scoped_ptr<webrtc::FakeAudioDeviceModule> fake_audio_device_module_; @@ -111,22 +121,6 @@ class WebrtcTransport : public Transport, DISALLOW_COPY_AND_ASSIGN(WebrtcTransport); }; -class WebrtcTransportFactory : public TransportFactory { - public: - WebrtcTransportFactory(rtc::Thread* worker_thread, - scoped_refptr<TransportContext> transport_context); - ~WebrtcTransportFactory() override; - - // TransportFactory interface. - scoped_ptr<Transport> CreateTransport() override; - - private: - rtc::Thread* worker_thread_; - scoped_refptr<TransportContext> transport_context_; - - DISALLOW_COPY_AND_ASSIGN(WebrtcTransportFactory); -}; - } // namespace protocol } // namespace remoting diff --git a/remoting/protocol/webrtc_transport_unittest.cc b/remoting/protocol/webrtc_transport_unittest.cc index 6b73087..4c70e1f 100644 --- a/remoting/protocol/webrtc_transport_unittest.cc +++ b/remoting/protocol/webrtc_transport_unittest.cc @@ -27,23 +27,17 @@ namespace protocol { namespace { -const char kTestJid[] = "client@gmail.com/321"; const char kChannelName[] = "test_channel"; -class TestTransportEventHandler : public Transport::EventHandler { +class TestTransportEventHandler : public WebrtcTransport::EventHandler { public: - typedef base::Callback<void(scoped_ptr<buzz::XmlElement> message)> - TransportInfoCallback; typedef base::Callback<void(ErrorCode error)> ErrorCallback; TestTransportEventHandler() {} ~TestTransportEventHandler() {} - // Both callback must be set before the test handler is passed to a Transport + // Both callbacks must be set before the test handler is passed to a Transport // object. - void set_transport_info_callback(const TransportInfoCallback& callback) { - transport_info_callback_ = callback; - } void set_connected_callback(const base::Closure& callback) { connected_callback_ = callback; } @@ -51,21 +45,15 @@ class TestTransportEventHandler : public Transport::EventHandler { error_callback_ = callback; } - // Transport::EventHandler interface. - void OnOutgoingTransportInfo(scoped_ptr<buzz::XmlElement> message) override { - transport_info_callback_.Run(std::move(message)); - } - void OnTransportRouteChange(const std::string& channel_name, - const TransportRoute& route) override {} - void OnTransportConnected() override { + // WebrtcTransport::EventHandler interface. + void OnWebrtcTransportConnected() override { connected_callback_.Run(); } - void OnTransportError(ErrorCode error) override { + void OnWebrtcTransportError(ErrorCode error) override { error_callback_.Run(error); } private: - TransportInfoCallback transport_info_callback_; base::Closure connected_callback_; ErrorCallback error_callback_; @@ -82,7 +70,7 @@ class WebrtcTransportTest : public testing::Test { NetworkSettings(NetworkSettings::NAT_TRAVERSAL_OUTGOING); } - void ProcessTransportInfo(scoped_ptr<Transport>* target_transport, + void ProcessTransportInfo(scoped_ptr<WebrtcTransport>* target_transport, scoped_ptr<buzz::XmlElement> transport_info) { ASSERT_TRUE(target_transport); EXPECT_TRUE((*target_transport) @@ -91,35 +79,19 @@ class WebrtcTransportTest : public testing::Test { protected: void InitializeConnection() { - signal_strategy_.reset(new FakeSignalStrategy(kTestJid)); - - host_transport_factory_.reset(new WebrtcTransportFactory( - jingle_glue::JingleThreadWrapper::current(), - new TransportContext( - signal_strategy_.get(), - make_scoped_ptr(new ChromiumPortAllocatorFactory(nullptr)), - network_settings_, TransportRole::SERVER))); - host_transport_ = host_transport_factory_->CreateTransport(); + host_transport_.reset( + new WebrtcTransport(jingle_glue::JingleThreadWrapper::current(), + TransportContext::ForTests(TransportRole::SERVER), + &host_event_handler_)); host_authenticator_.reset(new FakeAuthenticator( FakeAuthenticator::HOST, 0, FakeAuthenticator::ACCEPT, false)); - client_transport_factory_.reset(new WebrtcTransportFactory( - jingle_glue::JingleThreadWrapper::current(), - new TransportContext( - signal_strategy_.get(), - make_scoped_ptr(new ChromiumPortAllocatorFactory(nullptr)), - network_settings_, TransportRole::CLIENT))); - client_transport_ = client_transport_factory_->CreateTransport(); - host_authenticator_.reset(new FakeAuthenticator( + client_transport_.reset( + new WebrtcTransport(jingle_glue::JingleThreadWrapper::current(), + TransportContext::ForTests(TransportRole::CLIENT), + &client_event_handler_)); + client_authenticator_.reset(new FakeAuthenticator( FakeAuthenticator::CLIENT, 0, FakeAuthenticator::ACCEPT, false)); - - // Connect signaling between the two WebrtcTransport objects. - host_event_handler_.set_transport_info_callback( - base::Bind(&WebrtcTransportTest::ProcessTransportInfo, - base::Unretained(this), &client_transport_)); - client_event_handler_.set_transport_info_callback( - base::Bind(&WebrtcTransportTest::ProcessTransportInfo, - base::Unretained(this), &host_transport_)); } void StartConnection() { @@ -131,9 +103,15 @@ class WebrtcTransportTest : public testing::Test { client_event_handler_.set_error_callback(base::Bind( &WebrtcTransportTest::OnSessionError, base::Unretained(this))); - host_transport_->Start(&host_event_handler_, host_authenticator_.get()); - client_transport_->Start(&client_event_handler_, - client_authenticator_.get()); + // Start both transports. + host_transport_->Start( + host_authenticator_.get(), + base::Bind(&WebrtcTransportTest::ProcessTransportInfo, + base::Unretained(this), &client_transport_)); + client_transport_->Start( + client_authenticator_.get(), + base::Bind(&WebrtcTransportTest::ProcessTransportInfo, + base::Unretained(this), &host_transport_)); } void WaitUntilConnected() { @@ -189,15 +167,11 @@ class WebrtcTransportTest : public testing::Test { NetworkSettings network_settings_; - scoped_ptr< FakeSignalStrategy> signal_strategy_; - - scoped_ptr<WebrtcTransportFactory> host_transport_factory_; - scoped_ptr<Transport> host_transport_; + scoped_ptr<WebrtcTransport> host_transport_; TestTransportEventHandler host_event_handler_; scoped_ptr<FakeAuthenticator> host_authenticator_; - scoped_ptr<WebrtcTransportFactory> client_transport_factory_; - scoped_ptr<Transport> client_transport_; + scoped_ptr<WebrtcTransport> client_transport_; TestTransportEventHandler client_event_handler_; scoped_ptr<FakeAuthenticator> client_authenticator_; diff --git a/remoting/test/protocol_perftest.cc b/remoting/test/protocol_perftest.cc index 04703db..91b3eaa 100644 --- a/remoting/test/protocol_perftest.cc +++ b/remoting/test/protocol_perftest.cc @@ -25,7 +25,6 @@ #include "remoting/host/chromoting_host.h" #include "remoting/host/chromoting_host_context.h" #include "remoting/host/fake_desktop_environment.h" -#include "remoting/protocol/ice_transport.h" #include "remoting/protocol/jingle_session_manager.h" #include "remoting/protocol/me2me_host_authenticator_factory.h" #include "remoting/protocol/negotiating_client_authenticator.h" @@ -240,21 +239,18 @@ class ProtocolPerfTest new protocol::TransportContext( host_signaling_.get(), std::move(port_allocator_factory), network_settings, protocol::TransportRole::SERVER)); - scoped_ptr<protocol::SessionManager> session_manager( - new protocol::JingleSessionManager( - make_scoped_ptr( - new protocol::IceTransportFactory(transport_context)), - host_signaling_.get())); + new protocol::JingleSessionManager(host_signaling_.get())); session_manager->set_protocol_config(protocol_config_->Clone()); // Encoder runs on a separate thread, main thread is used for everything // else. host_.reset(new ChromotingHost( &desktop_environment_factory_, std::move(session_manager), - host_thread_.task_runner(), host_thread_.task_runner(), - capture_thread_.task_runner(), encode_thread_.task_runner(), - host_thread_.task_runner(), host_thread_.task_runner())); + transport_context, host_thread_.task_runner(), + host_thread_.task_runner(), capture_thread_.task_runner(), + encode_thread_.task_runner(), host_thread_.task_runner(), + host_thread_.task_runner())); base::FilePath certs_dir(net::GetTestCertsDirectory()); |