diff options
Diffstat (limited to 'remoting/host/client_session_unittest.cc')
-rw-r--r-- | remoting/host/client_session_unittest.cc | 201 |
1 files changed, 200 insertions, 1 deletions
diff --git a/remoting/host/client_session_unittest.cc b/remoting/host/client_session_unittest.cc index f28050a..d648ba1 100644 --- a/remoting/host/client_session_unittest.cc +++ b/remoting/host/client_session_unittest.cc @@ -2,17 +2,24 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <algorithm> +#include <string> +#include <vector> + #include "base/message_loop/message_loop.h" +#include "base/strings/string_util.h" #include "base/test/test_simple_task_runner.h" #include "remoting/base/auto_thread_task_runner.h" #include "remoting/base/constants.h" #include "remoting/host/audio_capturer.h" #include "remoting/host/client_session.h" #include "remoting/host/desktop_environment.h" +#include "remoting/host/host_extension.h" #include "remoting/host/host_mock_objects.h" #include "remoting/host/screen_capturer_fake.h" #include "remoting/protocol/protocol_mock_objects.h" #include "testing/gmock/include/gmock/gmock-matchers.h" +#include "testing/gmock_mutant.h" #include "testing/gtest/include/gtest/gtest.h" #include "third_party/webrtc/modules/desktop_capture/desktop_geometry.h" #include "third_party/webrtc/modules/desktop_capture/desktop_region.h" @@ -31,9 +38,11 @@ using protocol::SessionConfig; using testing::_; using testing::AnyNumber; using testing::AtMost; +using testing::CreateFunctor; using testing::DeleteArg; using testing::DoAll; using testing::Expectation; +using testing::Invoke; using testing::Return; using testing::ReturnRef; using testing::Sequence; @@ -42,6 +51,8 @@ using testing::StrictMock; namespace { +const char kDefaultTestCapability[] = "default"; + ACTION_P2(InjectClipboardEvent, connection, event) { connection->clipboard_stub()->InjectClipboardEvent(event); } @@ -67,8 +78,75 @@ ACTION_P2(DeliverClientMessage, client_session, message) { client_session->DeliverClientMessage(message); } +ACTION_P2(AddHostCapabilities, client_session, capability) { + client_session->AddHostCapabilities(capability); } +// Matches a |protocol::Capabilities| argument against a list of capabilities +// formatted as a space-separated string. +MATCHER_P(EqCapabilities, expected_capabilities, "") { + if (!arg.has_capabilities()) + return false; + + std::vector<std::string> words_args; + std::vector<std::string> words_expected; + Tokenize(arg.capabilities(), " ", &words_args); + Tokenize(expected_capabilities, " ", &words_expected); + std::sort(words_args.begin(), words_args.end()); + std::sort(words_expected.begin(), words_expected.end()); + return words_args == words_expected; +} + +// |HostExtension| implementation that can handle an extension message type and +// provide capabilities. +class FakeExtension : public HostExtension { + public: + FakeExtension(const std::string& message_type, + const std::string& capabilities); + virtual ~FakeExtension(); + + virtual std::string GetCapabilities() OVERRIDE; + virtual scoped_ptr<HostExtensionSession> CreateExtensionSession( + ClientSession* client_session) OVERRIDE; + + bool message_handled() { + return message_handled_; + } + + private: + class FakeExtensionSession : public HostExtensionSession { + public: + FakeExtensionSession(FakeExtension* extension); + virtual ~FakeExtensionSession(); + + virtual bool OnExtensionMessage( + ClientSession* client_session, + const protocol::ExtensionMessage& message) OVERRIDE; + + private: + FakeExtension* extension_; + }; + + std::string message_type_; + std::string capabilities_; + bool message_handled_; +}; + +typedef std::vector<HostExtension*> HostExtensionList; + +void CreateExtensionSessions(const HostExtensionList& extensions, + ClientSession* client_session) { + for (HostExtensionList::const_iterator extension = extensions.begin(); + extension != extensions.end(); ++extension) { + scoped_ptr<HostExtensionSession> extension_session = + (*extension)->CreateExtensionSession(client_session); + if (extension_session) + client_session->AddExtensionSession(extension_session.Pass()); + } +} + +} // namespace + class ClientSessionTest : public testing::Test { public: ClientSessionTest() : client_jid_("user@domain/rest-of-jid") {} @@ -100,6 +178,10 @@ class ClientSessionTest : public testing::Test { // the input pipe line and starts video capturing. void ConnectClientSession(); + // Creates expectations to send an extension message and to disconnect + // afterwards. + void SetSendMessageAndDisconnectExpectation(const std::string& message_type); + // Invoked when the last reference to the AutoThreadTaskRunner has been // released and quits the message loop to finish the test. void QuitMainMessageLoop(); @@ -131,6 +213,41 @@ class ClientSessionTest : public testing::Test { scoped_ptr<MockDesktopEnvironmentFactory> desktop_environment_factory_; }; +FakeExtension::FakeExtension(const std::string& message_type, + const std::string& capabilities) + : message_type_(message_type), + capabilities_(capabilities), + message_handled_(false) { +} + +FakeExtension::~FakeExtension() {} + +std::string FakeExtension::GetCapabilities() { + return capabilities_; +} + +scoped_ptr<HostExtensionSession> FakeExtension::CreateExtensionSession( + ClientSession* client_session) { + return scoped_ptr<HostExtensionSession>(new FakeExtensionSession(this)); +} + +FakeExtension::FakeExtensionSession::FakeExtensionSession( + FakeExtension* extension) + : extension_(extension) { +} + +FakeExtension::FakeExtensionSession::~FakeExtensionSession() {} + +bool FakeExtension::FakeExtensionSession::OnExtensionMessage( + ClientSession* client_session, + const protocol::ExtensionMessage& message) { + if (message.type() == extension_->message_type_) { + extension_->message_handled_ = true; + return true; + } + return false; +} + void ClientSessionTest::SetUp() { // Arrange to run |message_loop_| until no components depend on it. scoped_refptr<AutoThreadTaskRunner> ui_task_runner = new AutoThreadTaskRunner( @@ -180,6 +297,11 @@ void ClientSessionTest::SetUp() { desktop_environment_factory_.get(), base::TimeDelta(), NULL)); + + // By default, client will report the same capabilities as the host. + EXPECT_CALL(client_stub_, SetCapabilities(_)) + .Times(AtMost(1)) + .WillOnce(Invoke(client_session_.get(), &ClientSession::SetCapabilities)); } void ClientSessionTest::TearDown() { @@ -211,7 +333,8 @@ DesktopEnvironment* ClientSessionTest::CreateDesktopEnvironment() { EXPECT_CALL(*desktop_environment, CreateVideoCapturerPtr()) .WillOnce(Invoke(this, &ClientSessionTest::CreateVideoCapturer)); EXPECT_CALL(*desktop_environment, GetCapabilities()) - .Times(AtMost(1)); + .Times(AtMost(1)) + .WillOnce(Return(kDefaultTestCapability)); EXPECT_CALL(*desktop_environment, SetCapabilities(_)) .Times(AtMost(1)); @@ -232,6 +355,23 @@ void ClientSessionTest::ConnectClientSession() { client_session_->OnConnectionChannelsConnected(client_session_->connection()); } +void ClientSessionTest::SetSendMessageAndDisconnectExpectation( + const std::string& message_type) { + protocol::ExtensionMessage message; + message.set_type(message_type); + message.set_data("data"); + + Expectation authenticated = + EXPECT_CALL(session_event_handler_, OnSessionAuthenticated(_)) + .WillOnce(Return(true)); + EXPECT_CALL(session_event_handler_, OnSessionChannelsConnected(_)) + .After(authenticated) + .WillOnce(DoAll( + DeliverClientMessage(client_session_.get(), message), + InvokeWithoutArgs(this, &ClientSessionTest::DisconnectClientSession), + InvokeWithoutArgs(this, &ClientSessionTest::StopClientSession))); +} + void ClientSessionTest::QuitMainMessageLoop() { message_loop_.PostTask(FROM_HERE, base::MessageLoop::QuitClosure()); } @@ -599,4 +739,63 @@ TEST_F(ClientSessionTest, EnableGnubbyAuth) { message_loop_.Run(); } +// Verifies that messages can be handled by extensions. +TEST_F(ClientSessionTest, ExtensionMessages_MessageHandled) { + FakeExtension extension1("ext1", "cap1"); + FakeExtension extension2("ext2", "cap2"); + FakeExtension extension3("ext3", "cap3"); + HostExtensionList extensions; + extensions.push_back(&extension1); + extensions.push_back(&extension2); + extensions.push_back(&extension3); + + EXPECT_CALL(session_event_handler_, OnSessionClientCapabilities(_)) + .WillOnce(Invoke(CreateFunctor(&CreateExtensionSessions, extensions))); + + SetSendMessageAndDisconnectExpectation("ext2"); + ConnectClientSession(); + message_loop_.Run(); + + EXPECT_FALSE(extension1.message_handled()); + EXPECT_TRUE(extension2.message_handled()); + EXPECT_FALSE(extension3.message_handled()); +} + +// Verifies that extension messages not handled by extensions don't result in a +// crash. +TEST_F(ClientSessionTest, ExtensionMessages_MessageNotHandled) { + FakeExtension extension1("ext1", "cap1"); + HostExtensionList extensions; + extensions.push_back(&extension1); + + EXPECT_CALL(session_event_handler_, OnSessionClientCapabilities(_)) + .WillOnce(Invoke(CreateFunctor(&CreateExtensionSessions, extensions))); + + SetSendMessageAndDisconnectExpectation("extX"); + ConnectClientSession(); + message_loop_.Run(); + + EXPECT_FALSE(extension1.message_handled()); +} + +TEST_F(ClientSessionTest, ReportCapabilities) { + Expectation authenticated = + EXPECT_CALL(session_event_handler_, OnSessionAuthenticated(_)) + .WillOnce(DoAll( + AddHostCapabilities(client_session_.get(), "capX capZ"), + AddHostCapabilities(client_session_.get(), ""), + AddHostCapabilities(client_session_.get(), "capY"), + Return(true))); + EXPECT_CALL(client_stub_, + SetCapabilities(EqCapabilities("capX capY capZ default"))); + EXPECT_CALL(session_event_handler_, OnSessionChannelsConnected(_)) + .After(authenticated) + .WillOnce(DoAll( + InvokeWithoutArgs(this, &ClientSessionTest::DisconnectClientSession), + InvokeWithoutArgs(this, &ClientSessionTest::StopClientSession))); + + ConnectClientSession(); + message_loop_.Run(); +} + } // namespace remoting |