summaryrefslogtreecommitdiffstats
path: root/remoting/host/client_session_unittest.cc
diff options
context:
space:
mode:
Diffstat (limited to 'remoting/host/client_session_unittest.cc')
-rw-r--r--remoting/host/client_session_unittest.cc201
1 files changed, 200 insertions, 1 deletions
diff --git a/remoting/host/client_session_unittest.cc b/remoting/host/client_session_unittest.cc
index f28050a..d648ba1 100644
--- a/remoting/host/client_session_unittest.cc
+++ b/remoting/host/client_session_unittest.cc
@@ -2,17 +2,24 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
+#include <algorithm>
+#include <string>
+#include <vector>
+
#include "base/message_loop/message_loop.h"
+#include "base/strings/string_util.h"
#include "base/test/test_simple_task_runner.h"
#include "remoting/base/auto_thread_task_runner.h"
#include "remoting/base/constants.h"
#include "remoting/host/audio_capturer.h"
#include "remoting/host/client_session.h"
#include "remoting/host/desktop_environment.h"
+#include "remoting/host/host_extension.h"
#include "remoting/host/host_mock_objects.h"
#include "remoting/host/screen_capturer_fake.h"
#include "remoting/protocol/protocol_mock_objects.h"
#include "testing/gmock/include/gmock/gmock-matchers.h"
+#include "testing/gmock_mutant.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/webrtc/modules/desktop_capture/desktop_geometry.h"
#include "third_party/webrtc/modules/desktop_capture/desktop_region.h"
@@ -31,9 +38,11 @@ using protocol::SessionConfig;
using testing::_;
using testing::AnyNumber;
using testing::AtMost;
+using testing::CreateFunctor;
using testing::DeleteArg;
using testing::DoAll;
using testing::Expectation;
+using testing::Invoke;
using testing::Return;
using testing::ReturnRef;
using testing::Sequence;
@@ -42,6 +51,8 @@ using testing::StrictMock;
namespace {
+const char kDefaultTestCapability[] = "default";
+
ACTION_P2(InjectClipboardEvent, connection, event) {
connection->clipboard_stub()->InjectClipboardEvent(event);
}
@@ -67,8 +78,75 @@ ACTION_P2(DeliverClientMessage, client_session, message) {
client_session->DeliverClientMessage(message);
}
+ACTION_P2(AddHostCapabilities, client_session, capability) {
+ client_session->AddHostCapabilities(capability);
}
+// Matches a |protocol::Capabilities| argument against a list of capabilities
+// formatted as a space-separated string.
+MATCHER_P(EqCapabilities, expected_capabilities, "") {
+ if (!arg.has_capabilities())
+ return false;
+
+ std::vector<std::string> words_args;
+ std::vector<std::string> words_expected;
+ Tokenize(arg.capabilities(), " ", &words_args);
+ Tokenize(expected_capabilities, " ", &words_expected);
+ std::sort(words_args.begin(), words_args.end());
+ std::sort(words_expected.begin(), words_expected.end());
+ return words_args == words_expected;
+}
+
+// |HostExtension| implementation that can handle an extension message type and
+// provide capabilities.
+class FakeExtension : public HostExtension {
+ public:
+ FakeExtension(const std::string& message_type,
+ const std::string& capabilities);
+ virtual ~FakeExtension();
+
+ virtual std::string GetCapabilities() OVERRIDE;
+ virtual scoped_ptr<HostExtensionSession> CreateExtensionSession(
+ ClientSession* client_session) OVERRIDE;
+
+ bool message_handled() {
+ return message_handled_;
+ }
+
+ private:
+ class FakeExtensionSession : public HostExtensionSession {
+ public:
+ FakeExtensionSession(FakeExtension* extension);
+ virtual ~FakeExtensionSession();
+
+ virtual bool OnExtensionMessage(
+ ClientSession* client_session,
+ const protocol::ExtensionMessage& message) OVERRIDE;
+
+ private:
+ FakeExtension* extension_;
+ };
+
+ std::string message_type_;
+ std::string capabilities_;
+ bool message_handled_;
+};
+
+typedef std::vector<HostExtension*> HostExtensionList;
+
+void CreateExtensionSessions(const HostExtensionList& extensions,
+ ClientSession* client_session) {
+ for (HostExtensionList::const_iterator extension = extensions.begin();
+ extension != extensions.end(); ++extension) {
+ scoped_ptr<HostExtensionSession> extension_session =
+ (*extension)->CreateExtensionSession(client_session);
+ if (extension_session)
+ client_session->AddExtensionSession(extension_session.Pass());
+ }
+}
+
+} // namespace
+
class ClientSessionTest : public testing::Test {
public:
ClientSessionTest() : client_jid_("user@domain/rest-of-jid") {}
@@ -100,6 +178,10 @@ class ClientSessionTest : public testing::Test {
// the input pipe line and starts video capturing.
void ConnectClientSession();
+ // Creates expectations to send an extension message and to disconnect
+ // afterwards.
+ void SetSendMessageAndDisconnectExpectation(const std::string& message_type);
+
// Invoked when the last reference to the AutoThreadTaskRunner has been
// released and quits the message loop to finish the test.
void QuitMainMessageLoop();
@@ -131,6 +213,41 @@ class ClientSessionTest : public testing::Test {
scoped_ptr<MockDesktopEnvironmentFactory> desktop_environment_factory_;
};
+FakeExtension::FakeExtension(const std::string& message_type,
+ const std::string& capabilities)
+ : message_type_(message_type),
+ capabilities_(capabilities),
+ message_handled_(false) {
+}
+
+FakeExtension::~FakeExtension() {}
+
+std::string FakeExtension::GetCapabilities() {
+ return capabilities_;
+}
+
+scoped_ptr<HostExtensionSession> FakeExtension::CreateExtensionSession(
+ ClientSession* client_session) {
+ return scoped_ptr<HostExtensionSession>(new FakeExtensionSession(this));
+}
+
+FakeExtension::FakeExtensionSession::FakeExtensionSession(
+ FakeExtension* extension)
+ : extension_(extension) {
+}
+
+FakeExtension::FakeExtensionSession::~FakeExtensionSession() {}
+
+bool FakeExtension::FakeExtensionSession::OnExtensionMessage(
+ ClientSession* client_session,
+ const protocol::ExtensionMessage& message) {
+ if (message.type() == extension_->message_type_) {
+ extension_->message_handled_ = true;
+ return true;
+ }
+ return false;
+}
+
void ClientSessionTest::SetUp() {
// Arrange to run |message_loop_| until no components depend on it.
scoped_refptr<AutoThreadTaskRunner> ui_task_runner = new AutoThreadTaskRunner(
@@ -180,6 +297,11 @@ void ClientSessionTest::SetUp() {
desktop_environment_factory_.get(),
base::TimeDelta(),
NULL));
+
+ // By default, client will report the same capabilities as the host.
+ EXPECT_CALL(client_stub_, SetCapabilities(_))
+ .Times(AtMost(1))
+ .WillOnce(Invoke(client_session_.get(), &ClientSession::SetCapabilities));
}
void ClientSessionTest::TearDown() {
@@ -211,7 +333,8 @@ DesktopEnvironment* ClientSessionTest::CreateDesktopEnvironment() {
EXPECT_CALL(*desktop_environment, CreateVideoCapturerPtr())
.WillOnce(Invoke(this, &ClientSessionTest::CreateVideoCapturer));
EXPECT_CALL(*desktop_environment, GetCapabilities())
- .Times(AtMost(1));
+ .Times(AtMost(1))
+ .WillOnce(Return(kDefaultTestCapability));
EXPECT_CALL(*desktop_environment, SetCapabilities(_))
.Times(AtMost(1));
@@ -232,6 +355,23 @@ void ClientSessionTest::ConnectClientSession() {
client_session_->OnConnectionChannelsConnected(client_session_->connection());
}
+void ClientSessionTest::SetSendMessageAndDisconnectExpectation(
+ const std::string& message_type) {
+ protocol::ExtensionMessage message;
+ message.set_type(message_type);
+ message.set_data("data");
+
+ Expectation authenticated =
+ EXPECT_CALL(session_event_handler_, OnSessionAuthenticated(_))
+ .WillOnce(Return(true));
+ EXPECT_CALL(session_event_handler_, OnSessionChannelsConnected(_))
+ .After(authenticated)
+ .WillOnce(DoAll(
+ DeliverClientMessage(client_session_.get(), message),
+ InvokeWithoutArgs(this, &ClientSessionTest::DisconnectClientSession),
+ InvokeWithoutArgs(this, &ClientSessionTest::StopClientSession)));
+}
+
void ClientSessionTest::QuitMainMessageLoop() {
message_loop_.PostTask(FROM_HERE, base::MessageLoop::QuitClosure());
}
@@ -599,4 +739,63 @@ TEST_F(ClientSessionTest, EnableGnubbyAuth) {
message_loop_.Run();
}
+// Verifies that messages can be handled by extensions.
+TEST_F(ClientSessionTest, ExtensionMessages_MessageHandled) {
+ FakeExtension extension1("ext1", "cap1");
+ FakeExtension extension2("ext2", "cap2");
+ FakeExtension extension3("ext3", "cap3");
+ HostExtensionList extensions;
+ extensions.push_back(&extension1);
+ extensions.push_back(&extension2);
+ extensions.push_back(&extension3);
+
+ EXPECT_CALL(session_event_handler_, OnSessionClientCapabilities(_))
+ .WillOnce(Invoke(CreateFunctor(&CreateExtensionSessions, extensions)));
+
+ SetSendMessageAndDisconnectExpectation("ext2");
+ ConnectClientSession();
+ message_loop_.Run();
+
+ EXPECT_FALSE(extension1.message_handled());
+ EXPECT_TRUE(extension2.message_handled());
+ EXPECT_FALSE(extension3.message_handled());
+}
+
+// Verifies that extension messages not handled by extensions don't result in a
+// crash.
+TEST_F(ClientSessionTest, ExtensionMessages_MessageNotHandled) {
+ FakeExtension extension1("ext1", "cap1");
+ HostExtensionList extensions;
+ extensions.push_back(&extension1);
+
+ EXPECT_CALL(session_event_handler_, OnSessionClientCapabilities(_))
+ .WillOnce(Invoke(CreateFunctor(&CreateExtensionSessions, extensions)));
+
+ SetSendMessageAndDisconnectExpectation("extX");
+ ConnectClientSession();
+ message_loop_.Run();
+
+ EXPECT_FALSE(extension1.message_handled());
+}
+
+TEST_F(ClientSessionTest, ReportCapabilities) {
+ Expectation authenticated =
+ EXPECT_CALL(session_event_handler_, OnSessionAuthenticated(_))
+ .WillOnce(DoAll(
+ AddHostCapabilities(client_session_.get(), "capX capZ"),
+ AddHostCapabilities(client_session_.get(), ""),
+ AddHostCapabilities(client_session_.get(), "capY"),
+ Return(true)));
+ EXPECT_CALL(client_stub_,
+ SetCapabilities(EqCapabilities("capX capY capZ default")));
+ EXPECT_CALL(session_event_handler_, OnSessionChannelsConnected(_))
+ .After(authenticated)
+ .WillOnce(DoAll(
+ InvokeWithoutArgs(this, &ClientSessionTest::DisconnectClientSession),
+ InvokeWithoutArgs(this, &ClientSessionTest::StopClientSession)));
+
+ ConnectClientSession();
+ message_loop_.Run();
+}
+
} // namespace remoting