diff options
author | alexeypa@chromium.org <alexeypa@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-08-02 01:35:16 +0000 |
---|---|---|
committer | alexeypa@chromium.org <alexeypa@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-08-02 01:35:16 +0000 |
commit | 3784660319748aed36a0d243969feed846aad380 (patch) | |
tree | d6280de3ff37d1c5b3dd40a86a135f94c49785af /remoting | |
parent | da12d4cbd28bb214212db5bac896699f77efae9c (diff) | |
download | chromium_src-3784660319748aed36a0d243969feed846aad380.zip chromium_src-3784660319748aed36a0d243969feed846aad380.tar.gz chromium_src-3784660319748aed36a0d243969feed846aad380.tar.bz2 |
Refactored PairingRegistry::Delegate such that it can retrieve/modify for a single client.
The delegate implementation for Linux now keeps a separate JSON file per each pairing under "paired-clients" in the config directory.
This CL also makes the delegate completely synchronous moving jumping between task runners to the parent PairingRegistry.
BUG=156182
Review URL: https://chromiumcodereview.appspot.com/21128006
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@215187 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'remoting')
-rw-r--r-- | remoting/host/pairing_registry_delegate.cc | 8 | ||||
-rw-r--r-- | remoting/host/pairing_registry_delegate.h | 12 | ||||
-rw-r--r-- | remoting/host/pairing_registry_delegate_linux.cc | 182 | ||||
-rw-r--r-- | remoting/host/pairing_registry_delegate_linux.h | 51 | ||||
-rw-r--r-- | remoting/host/pairing_registry_delegate_linux_unittest.cc | 113 | ||||
-rw-r--r-- | remoting/host/pairing_registry_delegate_mac.cc | 3 | ||||
-rw-r--r-- | remoting/host/pairing_registry_delegate_win.cc | 3 | ||||
-rw-r--r-- | remoting/host/remoting_me2me_host.cc | 8 | ||||
-rw-r--r-- | remoting/host/setup/native_messaging_host_unittest.cc | 8 | ||||
-rw-r--r-- | remoting/protocol/negotiating_authenticator_unittest.cc | 15 | ||||
-rw-r--r-- | remoting/protocol/pairing_registry.cc | 295 | ||||
-rw-r--r-- | remoting/protocol/pairing_registry.h | 111 | ||||
-rw-r--r-- | remoting/protocol/pairing_registry_unittest.cc | 157 | ||||
-rw-r--r-- | remoting/protocol/protocol_mock_objects.cc | 83 | ||||
-rw-r--r-- | remoting/protocol/protocol_mock_objects.h | 59 |
15 files changed, 571 insertions, 537 deletions
diff --git a/remoting/host/pairing_registry_delegate.cc b/remoting/host/pairing_registry_delegate.cc index ada8cfc..c7e49a0 100644 --- a/remoting/host/pairing_registry_delegate.cc +++ b/remoting/host/pairing_registry_delegate.cc @@ -4,19 +4,19 @@ #include "remoting/host/pairing_registry_delegate.h" -#include "base/task_runner.h" +#include "base/single_thread_task_runner.h" namespace remoting { using protocol::PairingRegistry; scoped_refptr<PairingRegistry> CreatePairingRegistry( - scoped_refptr<base::TaskRunner> task_runner) { + scoped_refptr<base::SingleThreadTaskRunner> task_runner) { scoped_refptr<PairingRegistry> pairing_registry; scoped_ptr<PairingRegistry::Delegate> delegate( - CreatePairingRegistryDelegate(task_runner)); + CreatePairingRegistryDelegate()); if (delegate) { - pairing_registry = new PairingRegistry(delegate.Pass()); + pairing_registry = new PairingRegistry(task_runner, delegate.Pass()); } return pairing_registry; } diff --git a/remoting/host/pairing_registry_delegate.h b/remoting/host/pairing_registry_delegate.h index 82ecbde..b25836e 100644 --- a/remoting/host/pairing_registry_delegate.h +++ b/remoting/host/pairing_registry_delegate.h @@ -10,20 +10,20 @@ #include "remoting/protocol/pairing_registry.h" namespace base { -class TaskRunner; +class SingleThreadTaskRunner; } // namespace base namespace remoting { // Returns a platform-specific pairing registry delegate that will save to -// permanent storage using the specified TaskRunner. Returns NULL on platforms -// that don't support pairing. +// permanent storage. Returns NULL on platforms that don't support pairing. scoped_ptr<protocol::PairingRegistry::Delegate> -CreatePairingRegistryDelegate(scoped_refptr<base::TaskRunner> task_runner); +CreatePairingRegistryDelegate(); // Convenience function which returns a new PairingRegistry, using the delegate -// returned by CreatePairingRegistryDelegate(). +// returned by CreatePairingRegistryDelegate(). The passed |task_runner| is used +// to run the delegate's methods asynchronously. scoped_refptr<protocol::PairingRegistry> CreatePairingRegistry( - scoped_refptr<base::TaskRunner> task_runner); + scoped_refptr<base::SingleThreadTaskRunner> task_runner); } // namespace remoting diff --git a/remoting/host/pairing_registry_delegate_linux.cc b/remoting/host/pairing_registry_delegate_linux.cc index bca9b4f..0f1517a 100644 --- a/remoting/host/pairing_registry_delegate_linux.cc +++ b/remoting/host/pairing_registry_delegate_linux.cc @@ -6,127 +6,153 @@ #include "base/bind.h" #include "base/file_util.h" +#include "base/files/file_enumerator.h" #include "base/files/important_file_writer.h" +#include "base/json/json_file_value_serializer.h" +#include "base/json/json_string_value_serializer.h" #include "base/location.h" -#include "base/single_thread_task_runner.h" -#include "base/thread_task_runner_handle.h" +#include "base/strings/stringprintf.h" +#include "base/values.h" #include "remoting/host/branding.h" namespace { -const char kRegistryFilename[] = "paired-clients.json"; + +// The pairing registry path relative to the configuration directory. +const char kRegistryDirectory[] = "paired-clients"; + +const char kPairingFilenameFormat[] = "%s.json"; +const char kPairingFilenamePattern[] = "*.json"; + } // namespace namespace remoting { using protocol::PairingRegistry; -PairingRegistryDelegateLinux::PairingRegistryDelegateLinux( - scoped_refptr<base::TaskRunner> task_runner) - : task_runner_(task_runner), - weak_factory_(this) { +PairingRegistryDelegateLinux::PairingRegistryDelegateLinux() { } PairingRegistryDelegateLinux::~PairingRegistryDelegateLinux() { } -void PairingRegistryDelegateLinux::Save( - const std::string& pairings_json, - const PairingRegistry::SaveCallback& callback) { - // If a callback was supplied, wrap it in a helper function that will - // run it on this thread. - PairingRegistry::SaveCallback run_callback_on_this_thread; - if (!callback.is_null()) { - run_callback_on_this_thread = - base::Bind(&PairingRegistryDelegateLinux::RunSaveCallbackOnThread, - base::ThreadTaskRunnerHandle::Get(), - callback); +scoped_ptr<base::ListValue> PairingRegistryDelegateLinux::LoadAll() { + scoped_ptr<base::ListValue> pairings(new base::ListValue()); + + // Enumerate all pairing files in the pairing registry. + base::FilePath registry_path = GetRegistryPath(); + base::FileEnumerator enumerator(registry_path, false, + base::FileEnumerator::FILES, + kPairingFilenamePattern); + for (base::FilePath pairing_file = enumerator.Next(); !pairing_file.empty(); + pairing_file = enumerator.Next()) { + // Read the JSON containing pairing data. + JSONFileValueSerializer serializer(pairing_file); + int error_code; + std::string error_message; + scoped_ptr<base::Value> pairing_json( + serializer.Deserialize(&error_code, &error_message)); + if (!pairing_json) { + LOG(WARNING) << "Failed to load '" << pairing_file.value() << "' (" + << error_code << ")."; + continue; + } + + pairings->Append(pairing_json.release()); } - task_runner_->PostTask( - FROM_HERE, - base::Bind(&PairingRegistryDelegateLinux::DoSave, - weak_factory_.GetWeakPtr(), - pairings_json, - run_callback_on_this_thread)); -} -void PairingRegistryDelegateLinux::Load( - const PairingRegistry::LoadCallback& callback) { - // Wrap the callback in a helper function that will run it on this thread. - // Note that, unlike AddPairing, the GetPairing callback is mandatory. - PairingRegistry::LoadCallback run_callback_on_this_thread = - base::Bind(&PairingRegistryDelegateLinux::RunLoadCallbackOnThread, - base::ThreadTaskRunnerHandle::Get(), - callback); - task_runner_->PostTask( - FROM_HERE, - base::Bind(&PairingRegistryDelegateLinux::DoLoad, - weak_factory_.GetWeakPtr(), - run_callback_on_this_thread)); + return pairings.Pass(); } -void PairingRegistryDelegateLinux::RunSaveCallbackOnThread( - scoped_refptr<base::TaskRunner> task_runner, - const PairingRegistry::SaveCallback& callback, - bool success) { - task_runner->PostTask(FROM_HERE, base::Bind(callback, success)); +bool PairingRegistryDelegateLinux::DeleteAll() { + // Delete all pairing files in the pairing registry. + base::FilePath registry_path = GetRegistryPath(); + base::FileEnumerator enumerator(registry_path, false, + base::FileEnumerator::FILES, + kPairingFilenamePattern); + + bool success = true; + for (base::FilePath pairing_file = enumerator.Next(); !pairing_file.empty(); + pairing_file = enumerator.Next()) { + success = success && base::DeleteFile(pairing_file, false); + } + + return success; } -void PairingRegistryDelegateLinux::RunLoadCallbackOnThread( - scoped_refptr<base::TaskRunner> task_runner, - const PairingRegistry::LoadCallback& callback, - const std::string& pairings_json) { - task_runner->PostTask(FROM_HERE, base::Bind(callback, pairings_json)); +PairingRegistry::Pairing PairingRegistryDelegateLinux::Load( + const std::string& client_id) { + base::FilePath registry_path = GetRegistryPath(); + base::FilePath pairing_file = registry_path.Append( + base::StringPrintf(kPairingFilenameFormat, client_id.c_str())); + + JSONFileValueSerializer serializer(pairing_file); + int error_code; + std::string error_message; + scoped_ptr<base::Value> pairing( + serializer.Deserialize(&error_code, &error_message)); + if (!pairing) { + LOG(WARNING) << "Failed to load pairing information: " << error_message + << " (" << error_code << ")."; + return PairingRegistry::Pairing(); + } + + return PairingRegistry::Pairing::CreateFromValue(*pairing); } -void PairingRegistryDelegateLinux::DoSave( - const std::string& pairings_json, - const PairingRegistry::SaveCallback& callback) { - base::FilePath registry_path = GetRegistryFilePath(); - base::FilePath parent_directory = registry_path.DirName(); +bool PairingRegistryDelegateLinux::Save( + const PairingRegistry::Pairing& pairing) { + base::FilePath registry_path = GetRegistryPath(); base::PlatformFileError error; - if (!file_util::CreateDirectoryAndGetError(parent_directory, &error)) { + if (!file_util::CreateDirectoryAndGetError(registry_path, &error)) { LOG(ERROR) << "Could not create pairing registry directory: " << error; - return; + return false; } - if (!base::ImportantFileWriter::WriteFileAtomically(registry_path, - pairings_json)) { - LOG(ERROR) << "Could not save pairing registry."; + + std::string pairing_json; + JSONStringValueSerializer serializer(&pairing_json); + if (!serializer.Serialize(*pairing.ToValue())) { + LOG(ERROR) << "Failed to serialize pairing data for " + << pairing.client_id(); + return false; } - if (!callback.is_null()) { - callback.Run(true); + base::FilePath pairing_file = registry_path.Append( + base::StringPrintf(kPairingFilenameFormat, pairing.client_id().c_str())); + if (!base::ImportantFileWriter::WriteFileAtomically(pairing_file, + pairing_json)) { + LOG(ERROR) << "Could not save pairing data for " << pairing.client_id(); + return false; } + + return true; } -void PairingRegistryDelegateLinux::DoLoad( - const PairingRegistry::LoadCallback& callback) { - base::FilePath registry_path = GetRegistryFilePath(); - std::string result; - if (!file_util::ReadFileToString(registry_path, &result)) { - LOG(ERROR) << "Load failed."; - } - callback.Run(result); +bool PairingRegistryDelegateLinux::Delete(const std::string& client_id) { + base::FilePath registry_path = GetRegistryPath(); + base::FilePath pairing_file = registry_path.Append( + base::StringPrintf(kPairingFilenameFormat, client_id.c_str())); + + return base::DeleteFile(pairing_file, false); } -base::FilePath PairingRegistryDelegateLinux::GetRegistryFilePath() { - if (!filename_for_testing_.empty()) { - return filename_for_testing_; +base::FilePath PairingRegistryDelegateLinux::GetRegistryPath() { + if (!registry_path_for_testing_.empty()) { + return registry_path_for_testing_; } base::FilePath config_dir = remoting::GetConfigDir(); - return config_dir.Append(kRegistryFilename); + return config_dir.Append(kRegistryDirectory); } -void PairingRegistryDelegateLinux::SetFilenameForTesting( - const base::FilePath &filename) { - filename_for_testing_ = filename; +void PairingRegistryDelegateLinux::SetRegistryPathForTesting( + const base::FilePath& registry_path) { + registry_path_for_testing_ = registry_path; } -scoped_ptr<PairingRegistry::Delegate> CreatePairingRegistryDelegate( - scoped_refptr<base::TaskRunner> task_runner) { +scoped_ptr<PairingRegistry::Delegate> CreatePairingRegistryDelegate() { return scoped_ptr<PairingRegistry::Delegate>( - new PairingRegistryDelegateLinux(task_runner)); + new PairingRegistryDelegateLinux()); } } // namespace remoting diff --git a/remoting/host/pairing_registry_delegate_linux.h b/remoting/host/pairing_registry_delegate_linux.h index 08c073b..491f310 100644 --- a/remoting/host/pairing_registry_delegate_linux.h +++ b/remoting/host/pairing_registry_delegate_linux.h @@ -8,11 +8,9 @@ #include "remoting/protocol/pairing_registry.h" #include "base/files/file_path.h" -#include "base/memory/weak_ptr.h" namespace base { class ListValue; -class TaskRunner; } // namespace base namespace remoting { @@ -20,50 +18,29 @@ namespace remoting { class PairingRegistryDelegateLinux : public protocol::PairingRegistry::Delegate { public: - explicit PairingRegistryDelegateLinux( - scoped_refptr<base::TaskRunner> task_runner); + PairingRegistryDelegateLinux(); virtual ~PairingRegistryDelegateLinux(); // PairingRegistry::Delegate interface - virtual void Save( - const std::string& pairings_json, - const protocol::PairingRegistry::SaveCallback& callback) OVERRIDE; - virtual void Load( - const protocol::PairingRegistry::LoadCallback& callback) OVERRIDE; + virtual scoped_ptr<base::ListValue> LoadAll() OVERRIDE; + virtual bool DeleteAll() OVERRIDE; + virtual protocol::PairingRegistry::Pairing Load( + const std::string& client_id) OVERRIDE; + virtual bool Save(const protocol::PairingRegistry::Pairing& pairing) OVERRIDE; + virtual bool Delete(const std::string& client_id) OVERRIDE; private: FRIEND_TEST_ALL_PREFIXES(PairingRegistryDelegateLinuxTest, SaveAndLoad); + FRIEND_TEST_ALL_PREFIXES(PairingRegistryDelegateLinuxTest, Stateless); - // Blocking helper methods run using the TaskRunner passed to the ctor. - void DoSave(const std::string& pairings_json, - const protocol::PairingRegistry::SaveCallback& callback); - void DoLoad(const protocol::PairingRegistry::LoadCallback& callback); + // Return the path to the directory to use for loading and saving paired + // clients. + base::FilePath GetRegistryPath(); - // Run the delegate callbacks on their original thread. - static void RunSaveCallbackOnThread( - scoped_refptr<base::TaskRunner> task_runner, - const protocol::PairingRegistry::SaveCallback& callback, - bool success); - static void RunLoadCallbackOnThread( - scoped_refptr<base::TaskRunner> task_runner, - const protocol::PairingRegistry::LoadCallback& callback, - const std::string& pairings_json); + // For testing purposes, set the path returned by |GetRegistryPath()|. + void SetRegistryPathForTesting(const base::FilePath& registry_path); - // Helper methods to load and save the pairing registry. - protocol::PairingRegistry::PairedClients LoadPairings(); - void SavePairings( - const protocol::PairingRegistry::PairedClients& paired_clients); - - // Return the path to the file to use for loading and saving paired clients. - base::FilePath GetRegistryFilePath(); - - // For testing purposes, set the path returned by |GetRegistryFilePath|. - void SetFilenameForTesting(const base::FilePath &filename); - - scoped_refptr<base::TaskRunner> task_runner_; - base::FilePath filename_for_testing_; - - base::WeakPtrFactory<PairingRegistryDelegateLinux> weak_factory_; + base::FilePath registry_path_for_testing_; DISALLOW_COPY_AND_ASSIGN(PairingRegistryDelegateLinux); }; diff --git a/remoting/host/pairing_registry_delegate_linux_unittest.cc b/remoting/host/pairing_registry_delegate_linux_unittest.cc index 0f679eb..d044a04 100644 --- a/remoting/host/pairing_registry_delegate_linux_unittest.cc +++ b/remoting/host/pairing_registry_delegate_linux_unittest.cc @@ -5,11 +5,8 @@ #include "remoting/host/pairing_registry_delegate_linux.h" #include "base/file_util.h" -#include "base/message_loop/message_loop.h" -#include "base/run_loop.h" -#include "base/task_runner.h" -#include "base/thread_task_runner_handle.h" #include "base/timer/timer.h" +#include "base/values.h" #include "testing/gtest/include/gtest/gtest.h" namespace remoting { @@ -18,59 +15,75 @@ using protocol::PairingRegistry; class PairingRegistryDelegateLinuxTest : public testing::Test { public: - void SaveComplete(PairingRegistry::Delegate* delegate, - const std::string& expected_json, - bool success) { - EXPECT_TRUE(success); - // Load the pairings again to make sure we get what we've just written. - delegate->Load( - base::Bind(&PairingRegistryDelegateLinuxTest::VerifyLoad, - base::Unretained(this), - expected_json)); + virtual void SetUp() OVERRIDE { + // Create a temporary directory in order to get a unique name and use a + // subdirectory to ensure that PairingRegistryDelegateLinux::Save() creates + // the parent directory if it doesn't exist. + file_util::CreateNewTempDirectory("chromoting-test", &temp_dir_); + temp_registry_ = temp_dir_.Append("paired-clients"); } - void VerifyLoad(const std::string& expected, - const std::string& actual) { - EXPECT_EQ(actual, expected); - base::MessageLoop::current()->Quit(); + virtual void TearDown() OVERRIDE { + base::DeleteFile(temp_dir_, true); } + + protected: + base::FilePath temp_dir_; + base::FilePath temp_registry_; }; TEST_F(PairingRegistryDelegateLinuxTest, SaveAndLoad) { - base::MessageLoop message_loop; - base::RunLoop run_loop; - - // Create a temporary directory in order to get a unique name and use a - // subdirectory to ensure that the AddPairing method creates the parent - // directory if it doesn't exist. - base::FilePath temp_dir; - file_util::CreateNewTempDirectory("chromoting-test", &temp_dir); - base::FilePath temp_file = temp_dir.Append("dir").Append("registry.json"); - - scoped_refptr<base::TaskRunner> task_runner = - base::ThreadTaskRunnerHandle::Get(); + scoped_ptr<PairingRegistryDelegateLinux> delegate( + new PairingRegistryDelegateLinux()); + delegate->SetRegistryPathForTesting(temp_registry_); + + // Check that registry is initially empty. + EXPECT_TRUE(delegate->LoadAll()->empty()); + + // Add a couple of pairings. + PairingRegistry::Pairing pairing1(base::Time::Now(), "xxx", "xxx", "xxx"); + PairingRegistry::Pairing pairing2(base::Time::Now(), "yyy", "yyy", "yyy"); + EXPECT_TRUE(delegate->Save(pairing1)); + EXPECT_TRUE(delegate->Save(pairing2)); + + // Verify that there are two pairings in the store now. +EXPECT_EQ(delegate->LoadAll()->GetSize(), 2u); + + // Verify that they can be retrieved. + EXPECT_EQ(delegate->Load(pairing1.client_id()), pairing1); + EXPECT_EQ(delegate->Load(pairing2.client_id()), pairing2); + + // Delete the first pairing. + EXPECT_TRUE(delegate->Delete(pairing1.client_id())); + + // Verify that there is only one pairing left. + EXPECT_EQ(delegate->Load(pairing1.client_id()), PairingRegistry::Pairing()); + EXPECT_EQ(delegate->Load(pairing2.client_id()), pairing2); + + // Verify that the only value that left is |pairing2|. + EXPECT_EQ(delegate->LoadAll()->GetSize(), 1u); + scoped_ptr<base::ListValue> pairings = delegate->LoadAll(); + base::Value* json; + EXPECT_TRUE(pairings->Get(0, &json)); + EXPECT_EQ(PairingRegistry::Pairing::CreateFromValue(*json), pairing2); + + // Delete the rest and verify. + EXPECT_TRUE(delegate->DeleteAll()); + EXPECT_TRUE(delegate->LoadAll()->empty()); +} + +// Verifies that the delegate is stateless by using two different instances. +TEST_F(PairingRegistryDelegateLinuxTest, Stateless) { scoped_ptr<PairingRegistryDelegateLinux> save_delegate( - new PairingRegistryDelegateLinux(task_runner)); + new PairingRegistryDelegateLinux()); scoped_ptr<PairingRegistryDelegateLinux> load_delegate( - new PairingRegistryDelegateLinux(task_runner)); - save_delegate->SetFilenameForTesting(temp_file); - load_delegate->SetFilenameForTesting(temp_file); - - // Save the pairings, then load them using a different delegate to ensure - // that the test isn't passing due to cached values. Note that the delegate - // doesn't require that the strings it loads and saves are valid JSON, so - // we can simplify the test a bit. - std::string test_data = "test data"; - save_delegate->Save( - test_data, - base::Bind(&PairingRegistryDelegateLinuxTest::SaveComplete, - base::Unretained(this), - load_delegate.get(), - test_data)); - - run_loop.Run(); - - base::DeleteFile(temp_dir, true); -}; + new PairingRegistryDelegateLinux()); + save_delegate->SetRegistryPathForTesting(temp_registry_); + load_delegate->SetRegistryPathForTesting(temp_registry_); + + PairingRegistry::Pairing pairing(base::Time::Now(), "xxx", "xxx", "xxx"); + EXPECT_TRUE(save_delegate->Save(pairing)); + EXPECT_EQ(load_delegate->Load(pairing.client_id()), pairing); +} } // namespace remoting diff --git a/remoting/host/pairing_registry_delegate_mac.cc b/remoting/host/pairing_registry_delegate_mac.cc index 602a068..dba3ec3 100644 --- a/remoting/host/pairing_registry_delegate_mac.cc +++ b/remoting/host/pairing_registry_delegate_mac.cc @@ -10,8 +10,7 @@ namespace remoting { using protocol::PairingRegistry; -scoped_ptr<PairingRegistry::Delegate> CreatePairingRegistryDelegate( - scoped_refptr<base::TaskRunner> task_runner) { +scoped_ptr<PairingRegistry::Delegate> CreatePairingRegistryDelegate() { return scoped_ptr<PairingRegistry::Delegate>(); } diff --git a/remoting/host/pairing_registry_delegate_win.cc b/remoting/host/pairing_registry_delegate_win.cc index 602a068..dba3ec3 100644 --- a/remoting/host/pairing_registry_delegate_win.cc +++ b/remoting/host/pairing_registry_delegate_win.cc @@ -10,8 +10,7 @@ namespace remoting { using protocol::PairingRegistry; -scoped_ptr<PairingRegistry::Delegate> CreatePairingRegistryDelegate( - scoped_refptr<base::TaskRunner> task_runner) { +scoped_ptr<PairingRegistry::Delegate> CreatePairingRegistryDelegate() { return scoped_ptr<PairingRegistry::Delegate>(); } diff --git a/remoting/host/remoting_me2me_host.cc b/remoting/host/remoting_me2me_host.cc index 3af3123..af8861b 100644 --- a/remoting/host/remoting_me2me_host.cc +++ b/remoting/host/remoting_me2me_host.cc @@ -482,12 +482,8 @@ void HostProcess::CreateAuthenticatorFactory() { return; } - scoped_refptr<protocol::PairingRegistry> pairing_registry = NULL; - scoped_ptr<protocol::PairingRegistry::Delegate> delegate( - CreatePairingRegistryDelegate(context_->file_task_runner())); - if (delegate) { - pairing_registry = new protocol::PairingRegistry(delegate.Pass()); - } + scoped_refptr<protocol::PairingRegistry> pairing_registry = + CreatePairingRegistry(context_->file_task_runner()); scoped_ptr<protocol::AuthenticatorFactory> factory; diff --git a/remoting/host/setup/native_messaging_host_unittest.cc b/remoting/host/setup/native_messaging_host_unittest.cc index 1003b5e..09cf26e 100644 --- a/remoting/host/setup/native_messaging_host_unittest.cc +++ b/remoting/host/setup/native_messaging_host_unittest.cc @@ -20,8 +20,9 @@ #include "remoting/protocol/protocol_mock_objects.h" #include "testing/gtest/include/gtest/gtest.h" -using remoting::protocol::PairingRegistry; using remoting::protocol::MockPairingRegistryDelegate; +using remoting::protocol::PairingRegistry; +using remoting::protocol::SynchronousPairingRegistry; namespace { @@ -266,8 +267,9 @@ void NativeMessagingHostTest::SetUp() { daemon_controller_ = new MockDaemonController(); scoped_ptr<DaemonController> daemon_controller(daemon_controller_); - scoped_refptr<PairingRegistry> pairing_registry = new PairingRegistry( - scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate)); + scoped_refptr<PairingRegistry> pairing_registry = + new SynchronousPairingRegistry(scoped_ptr<PairingRegistry::Delegate>( + new MockPairingRegistryDelegate())); host_.reset(new NativeMessagingHost(daemon_controller.Pass(), pairing_registry, diff --git a/remoting/protocol/negotiating_authenticator_unittest.cc b/remoting/protocol/negotiating_authenticator_unittest.cc index b746432..bd68b70 100644 --- a/remoting/protocol/negotiating_authenticator_unittest.cc +++ b/remoting/protocol/negotiating_authenticator_unittest.cc @@ -84,14 +84,13 @@ class NegotiatingAuthenticatorTest : public AuthenticatorTestBase { } void CreatePairingRegistry(bool with_paired_client) { - mock_delegate_ = new MockPairingRegistryDelegate; - pairing_registry_ = new PairingRegistry( - scoped_ptr<PairingRegistry::Delegate>(mock_delegate_)); + pairing_registry_ = new SynchronousPairingRegistry( + scoped_ptr<PairingRegistry::Delegate>( + new MockPairingRegistryDelegate())); if (with_paired_client) { PairingRegistry::Pairing pairing( base::Time(), kTestClientName, kTestClientId, kTestPairedSecret); pairing_registry_->AddPairing(pairing); - mock_delegate_->RunCallback(); } } @@ -142,9 +141,6 @@ class NegotiatingAuthenticatorTest : public AuthenticatorTestBase { // Use a bare pointer because the storage is managed by the base class. NegotiatingClientAuthenticator* client_as_negotiating_authenticator_; - protected: - MockPairingRegistryDelegate* mock_delegate_; - private: scoped_refptr<PairingRegistry> pairing_registry_; @@ -217,7 +213,6 @@ TEST_F(NegotiatingAuthenticatorTest, PairingRevokedPinOkay) { kTestClientId, kTestPairedSecret, kTestPin, kTestPin, AuthenticationMethod::HMAC_SHA256, false)); ASSERT_NO_FATAL_FAILURE(RunAuthExchange()); - mock_delegate_->RunCallback(); VerifyAccepted(AuthenticationMethod::Spake2Pair()); } @@ -227,7 +222,6 @@ TEST_F(NegotiatingAuthenticatorTest, PairingRevokedPinBad) { kTestClientId, kTestPairedSecret, kTestPinBad, kTestPin, AuthenticationMethod::HMAC_SHA256, false)); ASSERT_NO_FATAL_FAILURE(RunAuthExchange()); - mock_delegate_->RunCallback(); VerifyRejected(Authenticator::INVALID_CREDENTIALS); } @@ -237,7 +231,6 @@ TEST_F(NegotiatingAuthenticatorTest, PairingSucceeded) { kTestClientId, kTestPairedSecret, kTestPinBad, kTestPin, AuthenticationMethod::HMAC_SHA256, false)); ASSERT_NO_FATAL_FAILURE(RunAuthExchange()); - mock_delegate_->RunCallback(); VerifyAccepted(AuthenticationMethod::Spake2Pair()); } @@ -247,7 +240,6 @@ TEST_F(NegotiatingAuthenticatorTest, PairingSucceededInvalidSecretButPinOkay) { kTestClientId, kTestPairedSecretBad, kTestPin, kTestPin, AuthenticationMethod::HMAC_SHA256, false)); ASSERT_NO_FATAL_FAILURE(RunAuthExchange()); - mock_delegate_->RunCallback(); VerifyAccepted(AuthenticationMethod::Spake2Pair()); } @@ -257,7 +249,6 @@ TEST_F(NegotiatingAuthenticatorTest, PairingFailedInvalidSecretAndPin) { kTestClientId, kTestPairedSecretBad, kTestPinBad, kTestPin, AuthenticationMethod::HMAC_SHA256, false)); ASSERT_NO_FATAL_FAILURE(RunAuthExchange()); - mock_delegate_->RunCallback(); VerifyRejected(Authenticator::INVALID_CREDENTIALS); } diff --git a/remoting/protocol/pairing_registry.cc b/remoting/protocol/pairing_registry.cc index 6c6b1a2..a6e7327 100644 --- a/remoting/protocol/pairing_registry.cc +++ b/remoting/protocol/pairing_registry.cc @@ -8,7 +8,10 @@ #include "base/bind.h" #include "base/guid.h" #include "base/json/json_string_value_serializer.h" +#include "base/location.h" +#include "base/single_thread_task_runner.h" #include "base/strings/string_number_conversions.h" +#include "base/thread_task_runner_handle.h" #include "base/values.h" #include "crypto/random.h" @@ -36,6 +39,9 @@ PairingRegistry::Pairing::Pairing(const base::Time& created_time, shared_secret_(shared_secret) { } +PairingRegistry::Pairing::~Pairing() { +} + PairingRegistry::Pairing PairingRegistry::Pairing::Create( const std::string& client_name) { base::Time created_time = base::Time::Now(); @@ -50,7 +56,38 @@ PairingRegistry::Pairing PairingRegistry::Pairing::Create( return Pairing(created_time, client_name, client_id, shared_secret); } -PairingRegistry::Pairing::~Pairing() { +PairingRegistry::Pairing PairingRegistry::Pairing::CreateFromValue( + const base::Value& pairing_json) { + const base::DictionaryValue* pairing = NULL; + if (!pairing_json.GetAsDictionary(&pairing)) { + LOG(ERROR) << "Failed to load pairing information: not a dictionary."; + return Pairing(); + } + + std::string client_name, client_id; + double created_time_value; + if (pairing->GetDouble(kCreatedTimeKey, &created_time_value) && + pairing->GetString(kClientNameKey, &client_name) && + pairing->GetString(kClientIdKey, &client_id)) { + // The shared secret is optional. + std::string shared_secret; + pairing->GetString(kSharedSecretKey, &shared_secret); + base::Time created_time = base::Time::FromJsTime(created_time_value); + return Pairing(created_time, client_name, client_id, shared_secret); + } + + LOG(ERROR) << "Failed to load pairing information: unexpected format."; + return Pairing(); +} + +scoped_ptr<base::Value> PairingRegistry::Pairing::ToValue() const { + scoped_ptr<base::DictionaryValue> pairing(new base::DictionaryValue()); + pairing->SetDouble(kCreatedTimeKey, created_time().ToJsTime()); + pairing->SetString(kClientNameKey, client_name()); + pairing->SetString(kClientIdKey, client_id()); + if (!shared_secret().empty()) + pairing->SetString(kSharedSecretKey, shared_secret()); + return pairing.PassAs<base::Value>(); } bool PairingRegistry::Pairing::operator==(const Pairing& other) const { @@ -64,17 +101,19 @@ bool PairingRegistry::Pairing::is_valid() const { return !client_id_.empty() && !shared_secret_.empty(); } -PairingRegistry::PairingRegistry(scoped_ptr<Delegate> delegate) - : delegate_(delegate.Pass()) { +PairingRegistry::PairingRegistry( + scoped_refptr<base::SingleThreadTaskRunner> delegate_task_runner, + scoped_ptr<Delegate> delegate) + : caller_task_runner_(base::ThreadTaskRunnerHandle::Get()), + delegate_task_runner_(delegate_task_runner), + delegate_(delegate.Pass()) { DCHECK(delegate_); } -PairingRegistry::~PairingRegistry() { -} - PairingRegistry::Pairing PairingRegistry::CreatePairing( const std::string& client_name) { - DCHECK(CalledOnValidThread()); + DCHECK(caller_task_runner_->BelongsToCurrentThread()); + Pairing result = Pairing::Create(client_name); AddPairing(result); return result; @@ -82,121 +121,124 @@ PairingRegistry::Pairing PairingRegistry::CreatePairing( void PairingRegistry::GetPairing(const std::string& client_id, const GetPairingCallback& callback) { - DCHECK(CalledOnValidThread()); + DCHECK(caller_task_runner_->BelongsToCurrentThread()); + GetPairingCallback wrapped_callback = base::Bind( &PairingRegistry::InvokeGetPairingCallbackAndScheduleNext, this, callback); - LoadCallback load_callback = base::Bind( - &PairingRegistry::DoGetPairing, this, client_id, wrapped_callback); - // |Unretained| and |get| are both safe here because the delegate is owned - // by the pairing registry and so is guaranteed to exist when the request - // is serviced. base::Closure request = base::Bind( - &PairingRegistry::Delegate::Load, - base::Unretained(delegate_.get()), load_callback); + &PairingRegistry::DoLoad, this, client_id, wrapped_callback); ServiceOrQueueRequest(request); } void PairingRegistry::GetAllPairings( const GetAllPairingsCallback& callback) { - DCHECK(CalledOnValidThread()); + DCHECK(caller_task_runner_->BelongsToCurrentThread()); + GetAllPairingsCallback wrapped_callback = base::Bind( &PairingRegistry::InvokeGetAllPairingsCallbackAndScheduleNext, this, callback); - LoadCallback load_callback = base::Bind( - &PairingRegistry::SanitizePairings, this, wrapped_callback); + GetAllPairingsCallback sanitize_callback = base::Bind( + &PairingRegistry::SanitizePairings, + this, wrapped_callback); base::Closure request = base::Bind( - &PairingRegistry::Delegate::Load, - base::Unretained(delegate_.get()), load_callback); + &PairingRegistry::DoLoadAll, this, sanitize_callback); ServiceOrQueueRequest(request); } void PairingRegistry::DeletePairing( - const std::string& client_id, const SaveCallback& callback) { - DCHECK(CalledOnValidThread()); - SaveCallback wrapped_callback = base::Bind( - &PairingRegistry::InvokeSaveCallbackAndScheduleNext, + const std::string& client_id, const DoneCallback& callback) { + DCHECK(caller_task_runner_->BelongsToCurrentThread()); + + DoneCallback wrapped_callback = base::Bind( + &PairingRegistry::InvokeDoneCallbackAndScheduleNext, this, callback); - LoadCallback load_callback = base::Bind( - &PairingRegistry::DoDeletePairing, this, client_id, wrapped_callback); base::Closure request = base::Bind( - &PairingRegistry::Delegate::Load, - base::Unretained(delegate_.get()), load_callback); + &PairingRegistry::DoDelete, this, client_id, wrapped_callback); ServiceOrQueueRequest(request); } void PairingRegistry::ClearAllPairings( - const SaveCallback& callback) { - DCHECK(CalledOnValidThread()); - SaveCallback wrapped_callback = base::Bind( - &PairingRegistry::InvokeSaveCallbackAndScheduleNext, + const DoneCallback& callback) { + DCHECK(caller_task_runner_->BelongsToCurrentThread()); + + DoneCallback wrapped_callback = base::Bind( + &PairingRegistry::InvokeDoneCallbackAndScheduleNext, this, callback); base::Closure request = base::Bind( - &PairingRegistry::Delegate::Save, - base::Unretained(delegate_.get()), - EncodeJson(PairedClients()), - wrapped_callback); + &PairingRegistry::DoDeleteAll, this, wrapped_callback); ServiceOrQueueRequest(request); } +PairingRegistry::~PairingRegistry() { +} + +void PairingRegistry::PostTask( + const scoped_refptr<base::SingleThreadTaskRunner>& task_runner, + const tracked_objects::Location& from_here, + const base::Closure& task) { + task_runner->PostTask(from_here, task); +} + void PairingRegistry::AddPairing(const Pairing& pairing) { - SaveCallback callback = base::Bind( - &PairingRegistry::InvokeSaveCallbackAndScheduleNext, - this, SaveCallback()); - LoadCallback load_callback = base::Bind( - &PairingRegistry::MergePairingAndSave, this, pairing, callback); + DoneCallback wrapped_callback = base::Bind( + &PairingRegistry::InvokeDoneCallbackAndScheduleNext, + this, DoneCallback()); base::Closure request = base::Bind( - &PairingRegistry::Delegate::Load, - base::Unretained(delegate_.get()), load_callback); + &PairingRegistry::DoSave, this, pairing, wrapped_callback); ServiceOrQueueRequest(request); } -void PairingRegistry::MergePairingAndSave(const Pairing& pairing, - const SaveCallback& callback, - const std::string& pairings_json) { - DCHECK(CalledOnValidThread()); - PairedClients clients = DecodeJson(pairings_json); - clients[pairing.client_id()] = pairing; - std::string new_pairings_json = EncodeJson(clients); - delegate_->Save(new_pairings_json, callback); +void PairingRegistry::DoLoadAll( + const protocol::PairingRegistry::GetAllPairingsCallback& callback) { + DCHECK(delegate_task_runner_->BelongsToCurrentThread()); + + scoped_ptr<base::ListValue> pairings = delegate_->LoadAll(); + PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, + base::Passed(&pairings))); } -void PairingRegistry::DoGetPairing(const std::string& client_id, - const GetPairingCallback& callback, - const std::string& pairings_json) { - PairedClients clients = DecodeJson(pairings_json); - Pairing result = clients[client_id]; - callback.Run(result); +void PairingRegistry::DoDeleteAll( + const protocol::PairingRegistry::DoneCallback& callback) { + DCHECK(delegate_task_runner_->BelongsToCurrentThread()); + + bool success = delegate_->DeleteAll(); + PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, success)); } -void PairingRegistry::SanitizePairings(const GetAllPairingsCallback& callback, - const std::string& pairings_json) { - PairedClients clients = DecodeJson(pairings_json); - callback.Run(ConvertToListValue(clients, false)); +void PairingRegistry::DoLoad( + const std::string& client_id, + const protocol::PairingRegistry::GetPairingCallback& callback) { + DCHECK(delegate_task_runner_->BelongsToCurrentThread()); + + Pairing pairing = delegate_->Load(client_id); + PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, pairing)); } -void PairingRegistry::DoDeletePairing(const std::string& client_id, - const SaveCallback& callback, - const std::string& pairings_json) { - PairedClients clients = DecodeJson(pairings_json); - clients.erase(client_id); - std::string new_pairings_json = EncodeJson(clients); - delegate_->Save(new_pairings_json, callback); +void PairingRegistry::DoSave( + const protocol::PairingRegistry::Pairing& pairing, + const protocol::PairingRegistry::DoneCallback& callback) { + DCHECK(delegate_task_runner_->BelongsToCurrentThread()); + + bool success = delegate_->Save(pairing); + PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, success)); } -void PairingRegistry::InvokeLoadCallbackAndScheduleNext( - const LoadCallback& callback, const std::string& pairings_json) { - callback.Run(pairings_json); - pending_requests_.pop(); - ServiceNextRequest(); +void PairingRegistry::DoDelete( + const std::string& client_id, + const protocol::PairingRegistry::DoneCallback& callback) { + DCHECK(delegate_task_runner_->BelongsToCurrentThread()); + + bool success = delegate_->Delete(client_id); + PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, success)); } -void PairingRegistry::InvokeSaveCallbackAndScheduleNext( - const SaveCallback& callback, bool success) { +void PairingRegistry::InvokeDoneCallbackAndScheduleNext( + const DoneCallback& callback, bool success) { // CreatePairing doesn't have a callback, so the callback can be null. - if (!callback.is_null()) { + if (!callback.is_null()) callback.Run(success); - } + pending_requests_.pop(); ServiceNextRequest(); } @@ -216,50 +258,35 @@ void PairingRegistry::InvokeGetAllPairingsCallbackAndScheduleNext( ServiceNextRequest(); } -// static -PairingRegistry::PairedClients PairingRegistry::DecodeJson( - const std::string& pairings_json) { - PairedClients result; - - if (pairings_json.empty()) { - return result; - } - - JSONStringValueSerializer registry(pairings_json); - int error_code; - std::string error_message; - scoped_ptr<base::Value> root( - registry.Deserialize(&error_code, &error_message)); - if (!root) { - LOG(ERROR) << "Failed to load paired clients: " << error_message - << " (" << error_code << ")."; - return result; - } - - base::ListValue* root_list = NULL; - if (!root->GetAsList(&root_list)) { - LOG(ERROR) << "Failed to load paired clients: root node is not a list."; - return result; - } +void PairingRegistry::SanitizePairings(const GetAllPairingsCallback& callback, + scoped_ptr<base::ListValue> pairings) { + DCHECK(caller_task_runner_->BelongsToCurrentThread()); + + scoped_ptr<base::ListValue> sanitized_pairings(new base::ListValue()); + for (size_t i = 0; i < pairings->GetSize(); ++i) { + DictionaryValue* pairing_json; + if (!pairings->GetDictionary(i, &pairing_json)) { + LOG(WARNING) << "A pairing entry is not a dictionary."; + continue; + } - for (size_t i = 0; i < root_list->GetSize(); ++i) { - base::DictionaryValue* pairing = NULL; - std::string client_name, client_id, shared_secret; - double created_time_value; - if (root_list->GetDictionary(i, &pairing) && - pairing->GetDouble(kCreatedTimeKey, &created_time_value) && - pairing->GetString(kClientNameKey, &client_name) && - pairing->GetString(kClientIdKey, &client_id) && - pairing->GetString(kSharedSecretKey, &shared_secret)) { - base::Time created_time = base::Time::FromJsTime(created_time_value); - result[client_id] = Pairing( - created_time, client_name, client_id, shared_secret); - } else { - LOG(ERROR) << "Paired client " << i << " has unexpected format."; + // Parse the pairing data. + Pairing pairing = Pairing::CreateFromValue(*pairing_json); + if (!pairing.is_valid()) { + LOG(WARNING) << "Could not parse a pairing entry."; + continue; } + + // Clear the shared secrect and append the pairing data to the list. + Pairing sanitized_pairing( + pairing.created_time(), + pairing.client_name(), + pairing.client_id(), + ""); + sanitized_pairings->Append(sanitized_pairing.ToValue().release()); } - return result; + callback.Run(sanitized_pairings.Pass()); } void PairingRegistry::ServiceOrQueueRequest(const base::Closure& request) { @@ -271,40 +298,10 @@ void PairingRegistry::ServiceOrQueueRequest(const base::Closure& request) { } void PairingRegistry::ServiceNextRequest() { - if (pending_requests_.empty()) { + if (pending_requests_.empty()) return; - } - base::Closure request = pending_requests_.front(); - request.Run(); -} -// static -std::string PairingRegistry::EncodeJson(const PairedClients& clients) { - scoped_ptr<base::ListValue> root = ConvertToListValue(clients, true); - std::string result; - JSONStringValueSerializer serializer(&result); - serializer.Serialize(*root); - - return result; -} - -// static -scoped_ptr<base::ListValue> PairingRegistry::ConvertToListValue( - const PairedClients& clients, - bool include_shared_secrets) { - scoped_ptr<base::ListValue> root(new base::ListValue()); - for (PairedClients::const_iterator i = clients.begin(); - i != clients.end(); ++i) { - base::DictionaryValue* pairing = new base::DictionaryValue(); - pairing->SetDouble(kCreatedTimeKey, i->second.created_time().ToJsTime()); - pairing->SetString(kClientNameKey, i->second.client_name()); - pairing->SetString(kClientIdKey, i->second.client_id()); - if (include_shared_secrets) { - pairing->SetString(kSharedSecretKey, i->second.shared_secret()); - } - root->Append(pairing); - } - return root.Pass(); + PostTask(delegate_task_runner_, FROM_HERE, pending_requests_.front()); } } // namespace protocol diff --git a/remoting/protocol/pairing_registry.h b/remoting/protocol/pairing_registry.h index fc63e84..ddcd736 100644 --- a/remoting/protocol/pairing_registry.h +++ b/remoting/protocol/pairing_registry.h @@ -14,13 +14,18 @@ #include "base/gtest_prod_util.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" -#include "base/threading/non_thread_safe.h" #include "base/time/time.h" namespace base { class ListValue; +class Value; +class SingleThreadTaskRunner; } // namespace base +namespace tracked_objects { +class Location; +} // namespace tracked_objects + namespace remoting { namespace protocol { @@ -33,8 +38,7 @@ namespace protocol { // class and sent in plain-text by the client during authentication. // * The shared secret for the client. This is generated on-demand by this // class and used in the SPAKE2 exchange to mutually verify identity. -class PairingRegistry : public base::RefCountedThreadSafe<PairingRegistry>, - public base::NonThreadSafe { +class PairingRegistry : public base::RefCountedThreadSafe<PairingRegistry> { public: struct Pairing { Pairing(); @@ -45,6 +49,9 @@ class PairingRegistry : public base::RefCountedThreadSafe<PairingRegistry>, ~Pairing(); static Pairing Create(const std::string& client_name); + static Pairing CreateFromValue(const base::Value& pairing_json); + + scoped_ptr<base::Value> ToValue() const; bool operator==(const Pairing& other) const; @@ -66,11 +73,10 @@ class PairingRegistry : public base::RefCountedThreadSafe<PairingRegistry>, typedef std::map<std::string, Pairing> PairedClients; // Delegate callbacks. - typedef base::Callback<void(const std::string& pairings_json)> LoadCallback; - typedef base::Callback<void(bool success)> SaveCallback; - typedef base::Callback<void(Pairing pairing)> GetPairingCallback; + typedef base::Callback<void(bool success)> DoneCallback; typedef base::Callback<void(scoped_ptr<base::ListValue> pairings)> GetAllPairingsCallback; + typedef base::Callback<void(Pairing pairing)> GetPairingCallback; static const char kCreatedTimeKey[]; static const char kClientIdKey[]; @@ -82,18 +88,25 @@ class PairingRegistry : public base::RefCountedThreadSafe<PairingRegistry>, public: virtual ~Delegate() {} - // Save JSON-encoded pairing information to persistent storage. If - // a non-NULL callback is provided, invoke it on completion to - // indicate success or failure. Must not block. - virtual void Save(const std::string& pairings_json, - const SaveCallback& callback) = 0; + // Retrieves all JSON-encoded pairings from persistent storage. + virtual scoped_ptr<base::ListValue> LoadAll() = 0; + + // Deletes all pairings in persistent storage. + virtual bool DeleteAll() = 0; - // Retrieve the JSON-encoded pairing information from persistent - // storage. Must not block. - virtual void Load(const LoadCallback& callback) = 0; + // Retrieves the pairing identified by |client_id|. + virtual Pairing Load(const std::string& client_id) = 0; + + // Saves |pairing| to persistent storage. + virtual bool Save(const Pairing& pairing) = 0; + + // Deletes the pairing identified by |client_id|. + virtual bool Delete(const std::string& client_id) = 0; }; - explicit PairingRegistry(scoped_ptr<Delegate> delegate); + PairingRegistry( + scoped_refptr<base::SingleThreadTaskRunner> delegate_task_runner, + scoped_ptr<Delegate> delegate); // Creates a pairing for a new client and saves it to disk. // @@ -115,58 +128,68 @@ class PairingRegistry : public base::RefCountedThreadSafe<PairingRegistry>, // the result of saving the new config, which occurs even if the client ID // did not match any pairing. void DeletePairing(const std::string& client_id, - const SaveCallback& callback); + const DoneCallback& callback); // Clear all pairings from the registry. - void ClearAllPairings(const SaveCallback& callback); + void ClearAllPairings(const DoneCallback& callback); + + protected: + friend class base::RefCountedThreadSafe<PairingRegistry>; + virtual ~PairingRegistry(); + + // Lets the tests override task posting to make all callbacks synchronous. + virtual void PostTask( + const scoped_refptr<base::SingleThreadTaskRunner>& task_runner, + const tracked_objects::Location& from_here, + const base::Closure& task); private: FRIEND_TEST_ALL_PREFIXES(PairingRegistryTest, AddPairing); - FRIEND_TEST_ALL_PREFIXES(PairingRegistryTest, GetAllPairingsJSON); friend class NegotiatingAuthenticatorTest; - friend class base::RefCountedThreadSafe<PairingRegistry>; - - virtual ~PairingRegistry(); // Helper method for unit tests. void AddPairing(const Pairing& pairing); - // Worker functions for each of the public methods, passed as a callback to - // the delegate. - void MergePairingAndSave(const Pairing& pairing, - const SaveCallback& callback, - const std::string& pairings_json); - void DoGetPairing(const std::string& client_id, - const GetPairingCallback& callback, - const std::string& pairings_json); - void SanitizePairings(const GetAllPairingsCallback& callback, - const std::string& pairings_json); - void DoDeletePairing(const std::string& client_id, - const SaveCallback& callback, - const std::string& pairings_json); + // Blocking helper methods used to call the delegate. + void DoLoadAll( + const protocol::PairingRegistry::GetAllPairingsCallback& callback); + void DoDeleteAll( + const protocol::PairingRegistry::DoneCallback& callback); + void DoLoad( + const std::string& client_id, + const protocol::PairingRegistry::GetPairingCallback& callback); + void DoSave( + const protocol::PairingRegistry::Pairing& pairing, + const protocol::PairingRegistry::DoneCallback& callback); + void DoDelete( + const std::string& client_id, + const protocol::PairingRegistry::DoneCallback& callback); // "Trampoline" callbacks that schedule the next pending request and then // invoke the original caller-supplied callback. - void InvokeLoadCallbackAndScheduleNext( - const LoadCallback& callback, const std::string& pairings_json); - void InvokeSaveCallbackAndScheduleNext( - const SaveCallback& callback, bool success); + void InvokeDoneCallbackAndScheduleNext( + const DoneCallback& callback, bool success); void InvokeGetPairingCallbackAndScheduleNext( const GetPairingCallback& callback, Pairing pairing); void InvokeGetAllPairingsCallbackAndScheduleNext( const GetAllPairingsCallback& callback, scoped_ptr<base::ListValue> pairings); + // Sanitize |pairings| by parsing each entry and removing the secret from it. + void SanitizePairings(const GetAllPairingsCallback& callback, + scoped_ptr<base::ListValue> pairings); + // Queue management methods. void ServiceOrQueueRequest(const base::Closure& request); void ServiceNextRequest(); - // Translate between the structured and serialized forms of the pairing data. - static PairedClients DecodeJson(const std::string& pairings_json); - static std::string EncodeJson(const PairedClients& clients); - static scoped_ptr<base::ListValue> ConvertToListValue( - const PairedClients& clients, - bool include_shared_secrets); + // Task runner on which all public methods of this class should be called. + scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner_; + + // Task runner used to run blocking calls to the delegate. A single thread + // task runner is used to guarantee that one one method of the delegate is + // called at a time. + scoped_refptr<base::SingleThreadTaskRunner> delegate_task_runner_; scoped_ptr<Delegate> delegate_; diff --git a/remoting/protocol/pairing_registry_unittest.cc b/remoting/protocol/pairing_registry_unittest.cc index f564256..eefea2a 100644 --- a/remoting/protocol/pairing_registry_unittest.cc +++ b/remoting/protocol/pairing_registry_unittest.cc @@ -11,15 +11,37 @@ #include "base/bind.h" #include "base/compiler_specific.h" #include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/run_loop.h" +#include "base/thread_task_runner_handle.h" #include "base/values.h" #include "remoting/protocol/protocol_mock_objects.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" +using testing::Sequence; + namespace { using remoting::protocol::PairingRegistry; +class MockPairingRegistryCallbacks { + public: + MockPairingRegistryCallbacks() {} + virtual ~MockPairingRegistryCallbacks() {} + + MOCK_METHOD1(DoneCallback, void(bool)); + MOCK_METHOD1(GetAllPairingsCallbackPtr, void(base::ListValue*)); + MOCK_METHOD1(GetPairingCallback, void(PairingRegistry::Pairing)); + + void GetAllPairingsCallback(scoped_ptr<base::ListValue> pairings) { + GetAllPairingsCallbackPtr(pairings.get()); + } + + private: + DISALLOW_COPY_AND_ASSIGN(MockPairingRegistryCallbacks); +}; + // Verify that a pairing Dictionary has correct entries, but doesn't include // any shared secret. void VerifyPairing(PairingRegistry::Pairing expected, @@ -59,32 +81,19 @@ class PairingRegistryTest : public testing::Test { ++callback_count_; } - void ExpectClientName(const std::string& expected, - PairingRegistry::Pairing actual) { - EXPECT_EQ(expected, actual.client_name()); - ++callback_count_; - } - - void ExpectNoPairings(scoped_ptr<base::ListValue> pairings) { - EXPECT_TRUE(pairings->empty()); - ++callback_count_; - } - protected: + base::MessageLoop message_loop_; + base::RunLoop run_loop_; + int callback_count_; scoped_ptr<base::ListValue> pairings_; }; TEST_F(PairingRegistryTest, CreateAndGetPairings) { - MockPairingRegistryDelegate* mock_delegate = - new MockPairingRegistryDelegate(); - scoped_ptr<PairingRegistry::Delegate> delegate(mock_delegate); - - scoped_refptr<PairingRegistry> registry(new PairingRegistry(delegate.Pass())); + scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry( + scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); PairingRegistry::Pairing pairing_1 = registry->CreatePairing("my_client"); - mock_delegate->RunCallback(); PairingRegistry::Pairing pairing_2 = registry->CreatePairing("my_client"); - mock_delegate->RunCallback(); EXPECT_NE(pairing_1.shared_secret(), pairing_2.shared_secret()); @@ -92,7 +101,6 @@ TEST_F(PairingRegistryTest, CreateAndGetPairings) { base::Bind(&PairingRegistryTest::ExpectSecret, base::Unretained(this), pairing_1.shared_secret())); - mock_delegate->RunCallback(); EXPECT_EQ(1, callback_count_); // Check that the second client is paired with a different shared secret. @@ -100,25 +108,18 @@ TEST_F(PairingRegistryTest, CreateAndGetPairings) { base::Bind(&PairingRegistryTest::ExpectSecret, base::Unretained(this), pairing_2.shared_secret())); - mock_delegate->RunCallback(); EXPECT_EQ(2, callback_count_); } TEST_F(PairingRegistryTest, GetAllPairings) { - MockPairingRegistryDelegate* mock_delegate = - new MockPairingRegistryDelegate(); - scoped_ptr<PairingRegistry::Delegate> delegate(mock_delegate); - - scoped_refptr<PairingRegistry> registry(new PairingRegistry(delegate.Pass())); + scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry( + scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1"); - mock_delegate->RunCallback(); PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2"); - mock_delegate->RunCallback(); registry->GetAllPairings( base::Bind(&PairingRegistryTest::set_pairings, base::Unretained(this))); - mock_delegate->RunCallback(); ASSERT_EQ(2u, pairings_->GetSize()); const base::DictionaryValue* actual_pairing_1; @@ -139,27 +140,20 @@ TEST_F(PairingRegistryTest, GetAllPairings) { } TEST_F(PairingRegistryTest, DeletePairing) { - MockPairingRegistryDelegate* mock_delegate = - new MockPairingRegistryDelegate(); - scoped_ptr<PairingRegistry::Delegate> delegate(mock_delegate); - - scoped_refptr<PairingRegistry> registry(new PairingRegistry(delegate.Pass())); + scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry( + scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1"); - mock_delegate->RunCallback(); PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2"); - mock_delegate->RunCallback(); registry->DeletePairing( pairing_1.client_id(), base::Bind(&PairingRegistryTest::ExpectSaveSuccess, base::Unretained(this))); - mock_delegate->RunCallback(); // Re-read the list, and verify it only has the pairing_2 client. registry->GetAllPairings( base::Bind(&PairingRegistryTest::set_pairings, base::Unretained(this))); - mock_delegate->RunCallback(); ASSERT_EQ(1u, pairings_->GetSize()); const base::DictionaryValue* actual_pairing_2; @@ -171,15 +165,10 @@ TEST_F(PairingRegistryTest, DeletePairing) { } TEST_F(PairingRegistryTest, ClearAllPairings) { - MockPairingRegistryDelegate* mock_delegate = - new MockPairingRegistryDelegate(); - scoped_ptr<PairingRegistry::Delegate> delegate(mock_delegate); - - scoped_refptr<PairingRegistry> registry(new PairingRegistry(delegate.Pass())); + scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry( + scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1"); - mock_delegate->RunCallback(); PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2"); - mock_delegate->RunCallback(); registry->ClearAllPairings( base::Bind(&PairingRegistryTest::ExpectSaveSuccess, @@ -189,57 +178,81 @@ TEST_F(PairingRegistryTest, ClearAllPairings) { registry->GetAllPairings( base::Bind(&PairingRegistryTest::set_pairings, base::Unretained(this))); - mock_delegate->RunCallback(); EXPECT_TRUE(pairings_->empty()); } -TEST_F(PairingRegistryTest, SerializedRequests) { - MockPairingRegistryDelegate* mock_delegate = - new MockPairingRegistryDelegate(); - scoped_ptr<PairingRegistry::Delegate> delegate(mock_delegate); - mock_delegate->set_run_save_callback_automatically(false); +ACTION_P(QuitMessageLoop, callback) { + callback.Run(); +} + +MATCHER_P(EqualsClientName, client_name, "") { + return arg.client_name() == client_name; +} + +MATCHER(NoPairings, "") { + return arg->empty(); +} - scoped_refptr<PairingRegistry> registry(new PairingRegistry(delegate.Pass())); +TEST_F(PairingRegistryTest, SerializedRequests) { + MockPairingRegistryCallbacks callbacks; + Sequence s; + EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1"))) + .InSequence(s); + EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client2"))) + .InSequence(s); + EXPECT_CALL(callbacks, DoneCallback(true)) + .InSequence(s); + EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1"))) + .InSequence(s); + EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName(""))) + .InSequence(s); + EXPECT_CALL(callbacks, DoneCallback(true)) + .InSequence(s); + EXPECT_CALL(callbacks, GetAllPairingsCallbackPtr(NoPairings())) + .InSequence(s); + EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client3"))) + .InSequence(s) + .WillOnce(QuitMessageLoop(run_loop_.QuitClosure())); + + scoped_refptr<PairingRegistry> registry = new PairingRegistry( + base::ThreadTaskRunnerHandle::Get(), + scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1"); PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2"); registry->GetPairing( pairing_1.client_id(), - base::Bind(&PairingRegistryTest::ExpectClientName, - base::Unretained(this), "client1")); + base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, + base::Unretained(&callbacks))); registry->GetPairing( pairing_2.client_id(), - base::Bind(&PairingRegistryTest::ExpectClientName, - base::Unretained(this), "client2")); + base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, + base::Unretained(&callbacks))); registry->DeletePairing( pairing_2.client_id(), - base::Bind(&PairingRegistryTest::ExpectSaveSuccess, - base::Unretained(this))); + base::Bind(&MockPairingRegistryCallbacks::DoneCallback, + base::Unretained(&callbacks))); registry->GetPairing( pairing_1.client_id(), - base::Bind(&PairingRegistryTest::ExpectClientName, - base::Unretained(this), "client1")); + base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, + base::Unretained(&callbacks))); registry->GetPairing( pairing_2.client_id(), - base::Bind(&PairingRegistryTest::ExpectClientName, - base::Unretained(this), "")); + base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, + base::Unretained(&callbacks))); registry->ClearAllPairings( - base::Bind(&PairingRegistryTest::ExpectSaveSuccess, - base::Unretained(this))); + base::Bind(&MockPairingRegistryCallbacks::DoneCallback, + base::Unretained(&callbacks))); registry->GetAllPairings( - base::Bind(&PairingRegistryTest::ExpectNoPairings, - base::Unretained(this))); + base::Bind(&MockPairingRegistryCallbacks::GetAllPairingsCallback, + base::Unretained(&callbacks))); PairingRegistry::Pairing pairing_3 = registry->CreatePairing("client3"); registry->GetPairing( pairing_3.client_id(), - base::Bind(&PairingRegistryTest::ExpectClientName, - base::Unretained(this), "client3")); - - while (mock_delegate->HasCallback()) { - mock_delegate->RunCallback(); - } + base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, + base::Unretained(&callbacks))); - EXPECT_EQ(8, callback_count_); + run_loop_.Run(); } } // namespace protocol diff --git a/remoting/protocol/protocol_mock_objects.cc b/remoting/protocol/protocol_mock_objects.cc index efde603d..93f6e34 100644 --- a/remoting/protocol/protocol_mock_objects.cc +++ b/remoting/protocol/protocol_mock_objects.cc @@ -4,6 +4,9 @@ #include "remoting/protocol/protocol_mock_objects.h" +#include "base/logging.h" +#include "base/thread_task_runner_handle.h" + namespace remoting { namespace protocol { @@ -48,57 +51,61 @@ MockSessionManager::MockSessionManager() {} MockSessionManager::~MockSessionManager() {} -MockPairingRegistryDelegate::MockPairingRegistryDelegate() - : run_save_callback_automatically_(true) { +MockPairingRegistryDelegate::MockPairingRegistryDelegate() { } MockPairingRegistryDelegate::~MockPairingRegistryDelegate() { } -void MockPairingRegistryDelegate::Save( - const std::string& pairings_json, - const PairingRegistry::SaveCallback& callback) { - EXPECT_TRUE(load_callback_.is_null()); - EXPECT_TRUE(save_callback_.is_null()); - if (run_save_callback_automatically_) { - SetPairingsJsonAndRunCallback(pairings_json, callback); - } else { - save_callback_ = base::Bind( - &MockPairingRegistryDelegate::SetPairingsJsonAndRunCallback, - base::Unretained(this), pairings_json, callback); +scoped_ptr<base::ListValue> MockPairingRegistryDelegate::LoadAll() { + scoped_ptr<base::ListValue> result(new base::ListValue()); + for (Pairings::const_iterator i = pairings_.begin(); i != pairings_.end(); + ++i) { + result->Append(i->second.ToValue().release()); } + return result.Pass(); +} + +bool MockPairingRegistryDelegate::DeleteAll() { + pairings_.clear(); + return true; } -void MockPairingRegistryDelegate::SetPairingsJsonAndRunCallback( - const std::string& pairings_json, - const PairingRegistry::SaveCallback& callback) { - pairings_json_ = pairings_json; - if (!callback.is_null()) { - callback.Run(true); +protocol::PairingRegistry::Pairing MockPairingRegistryDelegate::Load( + const std::string& client_id) { + Pairings::const_iterator i = pairings_.find(client_id); + if (i != pairings_.end()) { + return i->second; + } else { + return protocol::PairingRegistry::Pairing(); } } -void MockPairingRegistryDelegate::Load( - const PairingRegistry::LoadCallback& callback) { - EXPECT_TRUE(load_callback_.is_null()); - EXPECT_TRUE(save_callback_.is_null()); - load_callback_ = base::Bind(callback, pairings_json_); +bool MockPairingRegistryDelegate::Save( + const protocol::PairingRegistry::Pairing& pairing) { + pairings_[pairing.client_id()] = pairing; + return true; } -void MockPairingRegistryDelegate::RunCallback() { - if (!load_callback_.is_null()) { - EXPECT_TRUE(save_callback_.is_null()); - base::Closure load_callback = load_callback_; - load_callback_.Reset(); - load_callback.Run(); - } else if (!save_callback_.is_null()) { - EXPECT_TRUE(load_callback_.is_null()); - base::Closure save_callback = save_callback_; - save_callback_.Reset(); - save_callback.Run(); - } else { - ADD_FAILURE() << "RunCallback called without any callbacks set."; - } +bool MockPairingRegistryDelegate::Delete(const std::string& client_id) { + pairings_.erase(client_id); + return true; +} + +SynchronousPairingRegistry::SynchronousPairingRegistry( + scoped_ptr<Delegate> delegate) + : PairingRegistry(base::ThreadTaskRunnerHandle::Get(), delegate.Pass()) { +} + +SynchronousPairingRegistry::~SynchronousPairingRegistry() { +} + +void SynchronousPairingRegistry::PostTask( + const scoped_refptr<base::SingleThreadTaskRunner>& task_runner, + const tracked_objects::Location& from_here, + const base::Closure& task) { + DCHECK(task_runner->BelongsToCurrentThread()); + task.Run(); } } // namespace protocol diff --git a/remoting/protocol/protocol_mock_objects.h b/remoting/protocol/protocol_mock_objects.h index cfb912b..74435db 100644 --- a/remoting/protocol/protocol_mock_objects.h +++ b/remoting/protocol/protocol_mock_objects.h @@ -5,8 +5,12 @@ #ifndef REMOTING_PROTOCOL_PROTOCOL_MOCK_OBJECTS_H_ #define REMOTING_PROTOCOL_PROTOCOL_MOCK_OBJECTS_H_ +#include <map> #include <string> +#include "base/location.h" +#include "base/single_thread_task_runner.h" +#include "base/values.h" #include "net/base/ip_endpoint.h" #include "remoting/proto/internal.pb.h" #include "remoting/proto/video.pb.h" @@ -211,44 +215,31 @@ class MockPairingRegistryDelegate : public PairingRegistry::Delegate { MockPairingRegistryDelegate(); virtual ~MockPairingRegistryDelegate(); - const std::string& pairings_json() const { - return pairings_json_; - } - // PairingRegistry::Delegate implementation. - virtual void Save( - const std::string& pairings_json, - const PairingRegistry::SaveCallback& callback) OVERRIDE; - virtual void Load( - const PairingRegistry::LoadCallback& callback) OVERRIDE; - - // By default, the Save method runs its callback automatically because the - // negotiating authenticator unit test does not provide any hooks to do it - // manually. For unit tests that need to verify correct behaviour under - // asynchronous conditions, use this method to disable this feature and call - // RunCallback as appropriate. - void set_run_save_callback_automatically( - bool run_save_callback_automatically) { - run_save_callback_automatically_ = run_save_callback_automatically; - } + virtual scoped_ptr<base::ListValue> LoadAll() OVERRIDE; + virtual bool DeleteAll() OVERRIDE; + virtual protocol::PairingRegistry::Pairing Load( + const std::string& client_id) OVERRIDE; + virtual bool Save(const protocol::PairingRegistry::Pairing& pairing) OVERRIDE; + virtual bool Delete(const std::string& client_id) OVERRIDE; - bool HasCallback() const { - return !load_callback_.is_null() || !save_callback_.is_null(); - } + private: + typedef std::map<std::string, protocol::PairingRegistry::Pairing> Pairings; + Pairings pairings_; +}; + +class SynchronousPairingRegistry : public PairingRegistry { + public: + explicit SynchronousPairingRegistry(scoped_ptr<Delegate> delegate); - // Run either the save or the load callback (whichever was set most recently; - // it is an error for both of these to be set at the same time). - void RunCallback(); + protected: + virtual ~SynchronousPairingRegistry(); - private: - void SetPairingsJsonAndRunCallback( - const std::string& pairings_json, - const PairingRegistry::SaveCallback& callback); - - base::Closure load_callback_; - base::Closure save_callback_; - std::string pairings_json_; - bool run_save_callback_automatically_; + // Runs tasks synchronously instead of posting them to |task_runner|. + virtual void PostTask( + const scoped_refptr<base::SingleThreadTaskRunner>& task_runner, + const tracked_objects::Location& from_here, + const base::Closure& task) OVERRIDE; }; } // namespace protocol |