diff options
author | alexeypa@chromium.org <alexeypa@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-08-02 01:35:16 +0000 |
---|---|---|
committer | alexeypa@chromium.org <alexeypa@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-08-02 01:35:16 +0000 |
commit | 3784660319748aed36a0d243969feed846aad380 (patch) | |
tree | d6280de3ff37d1c5b3dd40a86a135f94c49785af /remoting/protocol | |
parent | da12d4cbd28bb214212db5bac896699f77efae9c (diff) | |
download | chromium_src-3784660319748aed36a0d243969feed846aad380.zip chromium_src-3784660319748aed36a0d243969feed846aad380.tar.gz chromium_src-3784660319748aed36a0d243969feed846aad380.tar.bz2 |
Refactored PairingRegistry::Delegate such that it can retrieve/modify for a single client.
The delegate implementation for Linux now keeps a separate JSON file per each pairing under "paired-clients" in the config directory.
This CL also makes the delegate completely synchronous moving jumping between task runners to the parent PairingRegistry.
BUG=156182
Review URL: https://chromiumcodereview.appspot.com/21128006
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@215187 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'remoting/protocol')
-rw-r--r-- | remoting/protocol/negotiating_authenticator_unittest.cc | 15 | ||||
-rw-r--r-- | remoting/protocol/pairing_registry.cc | 295 | ||||
-rw-r--r-- | remoting/protocol/pairing_registry.h | 111 | ||||
-rw-r--r-- | remoting/protocol/pairing_registry_unittest.cc | 157 | ||||
-rw-r--r-- | remoting/protocol/protocol_mock_objects.cc | 83 | ||||
-rw-r--r-- | remoting/protocol/protocol_mock_objects.h | 59 |
6 files changed, 371 insertions, 349 deletions
diff --git a/remoting/protocol/negotiating_authenticator_unittest.cc b/remoting/protocol/negotiating_authenticator_unittest.cc index b746432..bd68b70 100644 --- a/remoting/protocol/negotiating_authenticator_unittest.cc +++ b/remoting/protocol/negotiating_authenticator_unittest.cc @@ -84,14 +84,13 @@ class NegotiatingAuthenticatorTest : public AuthenticatorTestBase { } void CreatePairingRegistry(bool with_paired_client) { - mock_delegate_ = new MockPairingRegistryDelegate; - pairing_registry_ = new PairingRegistry( - scoped_ptr<PairingRegistry::Delegate>(mock_delegate_)); + pairing_registry_ = new SynchronousPairingRegistry( + scoped_ptr<PairingRegistry::Delegate>( + new MockPairingRegistryDelegate())); if (with_paired_client) { PairingRegistry::Pairing pairing( base::Time(), kTestClientName, kTestClientId, kTestPairedSecret); pairing_registry_->AddPairing(pairing); - mock_delegate_->RunCallback(); } } @@ -142,9 +141,6 @@ class NegotiatingAuthenticatorTest : public AuthenticatorTestBase { // Use a bare pointer because the storage is managed by the base class. NegotiatingClientAuthenticator* client_as_negotiating_authenticator_; - protected: - MockPairingRegistryDelegate* mock_delegate_; - private: scoped_refptr<PairingRegistry> pairing_registry_; @@ -217,7 +213,6 @@ TEST_F(NegotiatingAuthenticatorTest, PairingRevokedPinOkay) { kTestClientId, kTestPairedSecret, kTestPin, kTestPin, AuthenticationMethod::HMAC_SHA256, false)); ASSERT_NO_FATAL_FAILURE(RunAuthExchange()); - mock_delegate_->RunCallback(); VerifyAccepted(AuthenticationMethod::Spake2Pair()); } @@ -227,7 +222,6 @@ TEST_F(NegotiatingAuthenticatorTest, PairingRevokedPinBad) { kTestClientId, kTestPairedSecret, kTestPinBad, kTestPin, AuthenticationMethod::HMAC_SHA256, false)); ASSERT_NO_FATAL_FAILURE(RunAuthExchange()); - mock_delegate_->RunCallback(); VerifyRejected(Authenticator::INVALID_CREDENTIALS); } @@ -237,7 +231,6 @@ TEST_F(NegotiatingAuthenticatorTest, PairingSucceeded) { kTestClientId, kTestPairedSecret, kTestPinBad, kTestPin, AuthenticationMethod::HMAC_SHA256, false)); ASSERT_NO_FATAL_FAILURE(RunAuthExchange()); - mock_delegate_->RunCallback(); VerifyAccepted(AuthenticationMethod::Spake2Pair()); } @@ -247,7 +240,6 @@ TEST_F(NegotiatingAuthenticatorTest, PairingSucceededInvalidSecretButPinOkay) { kTestClientId, kTestPairedSecretBad, kTestPin, kTestPin, AuthenticationMethod::HMAC_SHA256, false)); ASSERT_NO_FATAL_FAILURE(RunAuthExchange()); - mock_delegate_->RunCallback(); VerifyAccepted(AuthenticationMethod::Spake2Pair()); } @@ -257,7 +249,6 @@ TEST_F(NegotiatingAuthenticatorTest, PairingFailedInvalidSecretAndPin) { kTestClientId, kTestPairedSecretBad, kTestPinBad, kTestPin, AuthenticationMethod::HMAC_SHA256, false)); ASSERT_NO_FATAL_FAILURE(RunAuthExchange()); - mock_delegate_->RunCallback(); VerifyRejected(Authenticator::INVALID_CREDENTIALS); } diff --git a/remoting/protocol/pairing_registry.cc b/remoting/protocol/pairing_registry.cc index 6c6b1a2..a6e7327 100644 --- a/remoting/protocol/pairing_registry.cc +++ b/remoting/protocol/pairing_registry.cc @@ -8,7 +8,10 @@ #include "base/bind.h" #include "base/guid.h" #include "base/json/json_string_value_serializer.h" +#include "base/location.h" +#include "base/single_thread_task_runner.h" #include "base/strings/string_number_conversions.h" +#include "base/thread_task_runner_handle.h" #include "base/values.h" #include "crypto/random.h" @@ -36,6 +39,9 @@ PairingRegistry::Pairing::Pairing(const base::Time& created_time, shared_secret_(shared_secret) { } +PairingRegistry::Pairing::~Pairing() { +} + PairingRegistry::Pairing PairingRegistry::Pairing::Create( const std::string& client_name) { base::Time created_time = base::Time::Now(); @@ -50,7 +56,38 @@ PairingRegistry::Pairing PairingRegistry::Pairing::Create( return Pairing(created_time, client_name, client_id, shared_secret); } -PairingRegistry::Pairing::~Pairing() { +PairingRegistry::Pairing PairingRegistry::Pairing::CreateFromValue( + const base::Value& pairing_json) { + const base::DictionaryValue* pairing = NULL; + if (!pairing_json.GetAsDictionary(&pairing)) { + LOG(ERROR) << "Failed to load pairing information: not a dictionary."; + return Pairing(); + } + + std::string client_name, client_id; + double created_time_value; + if (pairing->GetDouble(kCreatedTimeKey, &created_time_value) && + pairing->GetString(kClientNameKey, &client_name) && + pairing->GetString(kClientIdKey, &client_id)) { + // The shared secret is optional. + std::string shared_secret; + pairing->GetString(kSharedSecretKey, &shared_secret); + base::Time created_time = base::Time::FromJsTime(created_time_value); + return Pairing(created_time, client_name, client_id, shared_secret); + } + + LOG(ERROR) << "Failed to load pairing information: unexpected format."; + return Pairing(); +} + +scoped_ptr<base::Value> PairingRegistry::Pairing::ToValue() const { + scoped_ptr<base::DictionaryValue> pairing(new base::DictionaryValue()); + pairing->SetDouble(kCreatedTimeKey, created_time().ToJsTime()); + pairing->SetString(kClientNameKey, client_name()); + pairing->SetString(kClientIdKey, client_id()); + if (!shared_secret().empty()) + pairing->SetString(kSharedSecretKey, shared_secret()); + return pairing.PassAs<base::Value>(); } bool PairingRegistry::Pairing::operator==(const Pairing& other) const { @@ -64,17 +101,19 @@ bool PairingRegistry::Pairing::is_valid() const { return !client_id_.empty() && !shared_secret_.empty(); } -PairingRegistry::PairingRegistry(scoped_ptr<Delegate> delegate) - : delegate_(delegate.Pass()) { +PairingRegistry::PairingRegistry( + scoped_refptr<base::SingleThreadTaskRunner> delegate_task_runner, + scoped_ptr<Delegate> delegate) + : caller_task_runner_(base::ThreadTaskRunnerHandle::Get()), + delegate_task_runner_(delegate_task_runner), + delegate_(delegate.Pass()) { DCHECK(delegate_); } -PairingRegistry::~PairingRegistry() { -} - PairingRegistry::Pairing PairingRegistry::CreatePairing( const std::string& client_name) { - DCHECK(CalledOnValidThread()); + DCHECK(caller_task_runner_->BelongsToCurrentThread()); + Pairing result = Pairing::Create(client_name); AddPairing(result); return result; @@ -82,121 +121,124 @@ PairingRegistry::Pairing PairingRegistry::CreatePairing( void PairingRegistry::GetPairing(const std::string& client_id, const GetPairingCallback& callback) { - DCHECK(CalledOnValidThread()); + DCHECK(caller_task_runner_->BelongsToCurrentThread()); + GetPairingCallback wrapped_callback = base::Bind( &PairingRegistry::InvokeGetPairingCallbackAndScheduleNext, this, callback); - LoadCallback load_callback = base::Bind( - &PairingRegistry::DoGetPairing, this, client_id, wrapped_callback); - // |Unretained| and |get| are both safe here because the delegate is owned - // by the pairing registry and so is guaranteed to exist when the request - // is serviced. base::Closure request = base::Bind( - &PairingRegistry::Delegate::Load, - base::Unretained(delegate_.get()), load_callback); + &PairingRegistry::DoLoad, this, client_id, wrapped_callback); ServiceOrQueueRequest(request); } void PairingRegistry::GetAllPairings( const GetAllPairingsCallback& callback) { - DCHECK(CalledOnValidThread()); + DCHECK(caller_task_runner_->BelongsToCurrentThread()); + GetAllPairingsCallback wrapped_callback = base::Bind( &PairingRegistry::InvokeGetAllPairingsCallbackAndScheduleNext, this, callback); - LoadCallback load_callback = base::Bind( - &PairingRegistry::SanitizePairings, this, wrapped_callback); + GetAllPairingsCallback sanitize_callback = base::Bind( + &PairingRegistry::SanitizePairings, + this, wrapped_callback); base::Closure request = base::Bind( - &PairingRegistry::Delegate::Load, - base::Unretained(delegate_.get()), load_callback); + &PairingRegistry::DoLoadAll, this, sanitize_callback); ServiceOrQueueRequest(request); } void PairingRegistry::DeletePairing( - const std::string& client_id, const SaveCallback& callback) { - DCHECK(CalledOnValidThread()); - SaveCallback wrapped_callback = base::Bind( - &PairingRegistry::InvokeSaveCallbackAndScheduleNext, + const std::string& client_id, const DoneCallback& callback) { + DCHECK(caller_task_runner_->BelongsToCurrentThread()); + + DoneCallback wrapped_callback = base::Bind( + &PairingRegistry::InvokeDoneCallbackAndScheduleNext, this, callback); - LoadCallback load_callback = base::Bind( - &PairingRegistry::DoDeletePairing, this, client_id, wrapped_callback); base::Closure request = base::Bind( - &PairingRegistry::Delegate::Load, - base::Unretained(delegate_.get()), load_callback); + &PairingRegistry::DoDelete, this, client_id, wrapped_callback); ServiceOrQueueRequest(request); } void PairingRegistry::ClearAllPairings( - const SaveCallback& callback) { - DCHECK(CalledOnValidThread()); - SaveCallback wrapped_callback = base::Bind( - &PairingRegistry::InvokeSaveCallbackAndScheduleNext, + const DoneCallback& callback) { + DCHECK(caller_task_runner_->BelongsToCurrentThread()); + + DoneCallback wrapped_callback = base::Bind( + &PairingRegistry::InvokeDoneCallbackAndScheduleNext, this, callback); base::Closure request = base::Bind( - &PairingRegistry::Delegate::Save, - base::Unretained(delegate_.get()), - EncodeJson(PairedClients()), - wrapped_callback); + &PairingRegistry::DoDeleteAll, this, wrapped_callback); ServiceOrQueueRequest(request); } +PairingRegistry::~PairingRegistry() { +} + +void PairingRegistry::PostTask( + const scoped_refptr<base::SingleThreadTaskRunner>& task_runner, + const tracked_objects::Location& from_here, + const base::Closure& task) { + task_runner->PostTask(from_here, task); +} + void PairingRegistry::AddPairing(const Pairing& pairing) { - SaveCallback callback = base::Bind( - &PairingRegistry::InvokeSaveCallbackAndScheduleNext, - this, SaveCallback()); - LoadCallback load_callback = base::Bind( - &PairingRegistry::MergePairingAndSave, this, pairing, callback); + DoneCallback wrapped_callback = base::Bind( + &PairingRegistry::InvokeDoneCallbackAndScheduleNext, + this, DoneCallback()); base::Closure request = base::Bind( - &PairingRegistry::Delegate::Load, - base::Unretained(delegate_.get()), load_callback); + &PairingRegistry::DoSave, this, pairing, wrapped_callback); ServiceOrQueueRequest(request); } -void PairingRegistry::MergePairingAndSave(const Pairing& pairing, - const SaveCallback& callback, - const std::string& pairings_json) { - DCHECK(CalledOnValidThread()); - PairedClients clients = DecodeJson(pairings_json); - clients[pairing.client_id()] = pairing; - std::string new_pairings_json = EncodeJson(clients); - delegate_->Save(new_pairings_json, callback); +void PairingRegistry::DoLoadAll( + const protocol::PairingRegistry::GetAllPairingsCallback& callback) { + DCHECK(delegate_task_runner_->BelongsToCurrentThread()); + + scoped_ptr<base::ListValue> pairings = delegate_->LoadAll(); + PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, + base::Passed(&pairings))); } -void PairingRegistry::DoGetPairing(const std::string& client_id, - const GetPairingCallback& callback, - const std::string& pairings_json) { - PairedClients clients = DecodeJson(pairings_json); - Pairing result = clients[client_id]; - callback.Run(result); +void PairingRegistry::DoDeleteAll( + const protocol::PairingRegistry::DoneCallback& callback) { + DCHECK(delegate_task_runner_->BelongsToCurrentThread()); + + bool success = delegate_->DeleteAll(); + PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, success)); } -void PairingRegistry::SanitizePairings(const GetAllPairingsCallback& callback, - const std::string& pairings_json) { - PairedClients clients = DecodeJson(pairings_json); - callback.Run(ConvertToListValue(clients, false)); +void PairingRegistry::DoLoad( + const std::string& client_id, + const protocol::PairingRegistry::GetPairingCallback& callback) { + DCHECK(delegate_task_runner_->BelongsToCurrentThread()); + + Pairing pairing = delegate_->Load(client_id); + PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, pairing)); } -void PairingRegistry::DoDeletePairing(const std::string& client_id, - const SaveCallback& callback, - const std::string& pairings_json) { - PairedClients clients = DecodeJson(pairings_json); - clients.erase(client_id); - std::string new_pairings_json = EncodeJson(clients); - delegate_->Save(new_pairings_json, callback); +void PairingRegistry::DoSave( + const protocol::PairingRegistry::Pairing& pairing, + const protocol::PairingRegistry::DoneCallback& callback) { + DCHECK(delegate_task_runner_->BelongsToCurrentThread()); + + bool success = delegate_->Save(pairing); + PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, success)); } -void PairingRegistry::InvokeLoadCallbackAndScheduleNext( - const LoadCallback& callback, const std::string& pairings_json) { - callback.Run(pairings_json); - pending_requests_.pop(); - ServiceNextRequest(); +void PairingRegistry::DoDelete( + const std::string& client_id, + const protocol::PairingRegistry::DoneCallback& callback) { + DCHECK(delegate_task_runner_->BelongsToCurrentThread()); + + bool success = delegate_->Delete(client_id); + PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, success)); } -void PairingRegistry::InvokeSaveCallbackAndScheduleNext( - const SaveCallback& callback, bool success) { +void PairingRegistry::InvokeDoneCallbackAndScheduleNext( + const DoneCallback& callback, bool success) { // CreatePairing doesn't have a callback, so the callback can be null. - if (!callback.is_null()) { + if (!callback.is_null()) callback.Run(success); - } + pending_requests_.pop(); ServiceNextRequest(); } @@ -216,50 +258,35 @@ void PairingRegistry::InvokeGetAllPairingsCallbackAndScheduleNext( ServiceNextRequest(); } -// static -PairingRegistry::PairedClients PairingRegistry::DecodeJson( - const std::string& pairings_json) { - PairedClients result; - - if (pairings_json.empty()) { - return result; - } - - JSONStringValueSerializer registry(pairings_json); - int error_code; - std::string error_message; - scoped_ptr<base::Value> root( - registry.Deserialize(&error_code, &error_message)); - if (!root) { - LOG(ERROR) << "Failed to load paired clients: " << error_message - << " (" << error_code << ")."; - return result; - } - - base::ListValue* root_list = NULL; - if (!root->GetAsList(&root_list)) { - LOG(ERROR) << "Failed to load paired clients: root node is not a list."; - return result; - } +void PairingRegistry::SanitizePairings(const GetAllPairingsCallback& callback, + scoped_ptr<base::ListValue> pairings) { + DCHECK(caller_task_runner_->BelongsToCurrentThread()); + + scoped_ptr<base::ListValue> sanitized_pairings(new base::ListValue()); + for (size_t i = 0; i < pairings->GetSize(); ++i) { + DictionaryValue* pairing_json; + if (!pairings->GetDictionary(i, &pairing_json)) { + LOG(WARNING) << "A pairing entry is not a dictionary."; + continue; + } - for (size_t i = 0; i < root_list->GetSize(); ++i) { - base::DictionaryValue* pairing = NULL; - std::string client_name, client_id, shared_secret; - double created_time_value; - if (root_list->GetDictionary(i, &pairing) && - pairing->GetDouble(kCreatedTimeKey, &created_time_value) && - pairing->GetString(kClientNameKey, &client_name) && - pairing->GetString(kClientIdKey, &client_id) && - pairing->GetString(kSharedSecretKey, &shared_secret)) { - base::Time created_time = base::Time::FromJsTime(created_time_value); - result[client_id] = Pairing( - created_time, client_name, client_id, shared_secret); - } else { - LOG(ERROR) << "Paired client " << i << " has unexpected format."; + // Parse the pairing data. + Pairing pairing = Pairing::CreateFromValue(*pairing_json); + if (!pairing.is_valid()) { + LOG(WARNING) << "Could not parse a pairing entry."; + continue; } + + // Clear the shared secrect and append the pairing data to the list. + Pairing sanitized_pairing( + pairing.created_time(), + pairing.client_name(), + pairing.client_id(), + ""); + sanitized_pairings->Append(sanitized_pairing.ToValue().release()); } - return result; + callback.Run(sanitized_pairings.Pass()); } void PairingRegistry::ServiceOrQueueRequest(const base::Closure& request) { @@ -271,40 +298,10 @@ void PairingRegistry::ServiceOrQueueRequest(const base::Closure& request) { } void PairingRegistry::ServiceNextRequest() { - if (pending_requests_.empty()) { + if (pending_requests_.empty()) return; - } - base::Closure request = pending_requests_.front(); - request.Run(); -} -// static -std::string PairingRegistry::EncodeJson(const PairedClients& clients) { - scoped_ptr<base::ListValue> root = ConvertToListValue(clients, true); - std::string result; - JSONStringValueSerializer serializer(&result); - serializer.Serialize(*root); - - return result; -} - -// static -scoped_ptr<base::ListValue> PairingRegistry::ConvertToListValue( - const PairedClients& clients, - bool include_shared_secrets) { - scoped_ptr<base::ListValue> root(new base::ListValue()); - for (PairedClients::const_iterator i = clients.begin(); - i != clients.end(); ++i) { - base::DictionaryValue* pairing = new base::DictionaryValue(); - pairing->SetDouble(kCreatedTimeKey, i->second.created_time().ToJsTime()); - pairing->SetString(kClientNameKey, i->second.client_name()); - pairing->SetString(kClientIdKey, i->second.client_id()); - if (include_shared_secrets) { - pairing->SetString(kSharedSecretKey, i->second.shared_secret()); - } - root->Append(pairing); - } - return root.Pass(); + PostTask(delegate_task_runner_, FROM_HERE, pending_requests_.front()); } } // namespace protocol diff --git a/remoting/protocol/pairing_registry.h b/remoting/protocol/pairing_registry.h index fc63e84..ddcd736 100644 --- a/remoting/protocol/pairing_registry.h +++ b/remoting/protocol/pairing_registry.h @@ -14,13 +14,18 @@ #include "base/gtest_prod_util.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" -#include "base/threading/non_thread_safe.h" #include "base/time/time.h" namespace base { class ListValue; +class Value; +class SingleThreadTaskRunner; } // namespace base +namespace tracked_objects { +class Location; +} // namespace tracked_objects + namespace remoting { namespace protocol { @@ -33,8 +38,7 @@ namespace protocol { // class and sent in plain-text by the client during authentication. // * The shared secret for the client. This is generated on-demand by this // class and used in the SPAKE2 exchange to mutually verify identity. -class PairingRegistry : public base::RefCountedThreadSafe<PairingRegistry>, - public base::NonThreadSafe { +class PairingRegistry : public base::RefCountedThreadSafe<PairingRegistry> { public: struct Pairing { Pairing(); @@ -45,6 +49,9 @@ class PairingRegistry : public base::RefCountedThreadSafe<PairingRegistry>, ~Pairing(); static Pairing Create(const std::string& client_name); + static Pairing CreateFromValue(const base::Value& pairing_json); + + scoped_ptr<base::Value> ToValue() const; bool operator==(const Pairing& other) const; @@ -66,11 +73,10 @@ class PairingRegistry : public base::RefCountedThreadSafe<PairingRegistry>, typedef std::map<std::string, Pairing> PairedClients; // Delegate callbacks. - typedef base::Callback<void(const std::string& pairings_json)> LoadCallback; - typedef base::Callback<void(bool success)> SaveCallback; - typedef base::Callback<void(Pairing pairing)> GetPairingCallback; + typedef base::Callback<void(bool success)> DoneCallback; typedef base::Callback<void(scoped_ptr<base::ListValue> pairings)> GetAllPairingsCallback; + typedef base::Callback<void(Pairing pairing)> GetPairingCallback; static const char kCreatedTimeKey[]; static const char kClientIdKey[]; @@ -82,18 +88,25 @@ class PairingRegistry : public base::RefCountedThreadSafe<PairingRegistry>, public: virtual ~Delegate() {} - // Save JSON-encoded pairing information to persistent storage. If - // a non-NULL callback is provided, invoke it on completion to - // indicate success or failure. Must not block. - virtual void Save(const std::string& pairings_json, - const SaveCallback& callback) = 0; + // Retrieves all JSON-encoded pairings from persistent storage. + virtual scoped_ptr<base::ListValue> LoadAll() = 0; + + // Deletes all pairings in persistent storage. + virtual bool DeleteAll() = 0; - // Retrieve the JSON-encoded pairing information from persistent - // storage. Must not block. - virtual void Load(const LoadCallback& callback) = 0; + // Retrieves the pairing identified by |client_id|. + virtual Pairing Load(const std::string& client_id) = 0; + + // Saves |pairing| to persistent storage. + virtual bool Save(const Pairing& pairing) = 0; + + // Deletes the pairing identified by |client_id|. + virtual bool Delete(const std::string& client_id) = 0; }; - explicit PairingRegistry(scoped_ptr<Delegate> delegate); + PairingRegistry( + scoped_refptr<base::SingleThreadTaskRunner> delegate_task_runner, + scoped_ptr<Delegate> delegate); // Creates a pairing for a new client and saves it to disk. // @@ -115,58 +128,68 @@ class PairingRegistry : public base::RefCountedThreadSafe<PairingRegistry>, // the result of saving the new config, which occurs even if the client ID // did not match any pairing. void DeletePairing(const std::string& client_id, - const SaveCallback& callback); + const DoneCallback& callback); // Clear all pairings from the registry. - void ClearAllPairings(const SaveCallback& callback); + void ClearAllPairings(const DoneCallback& callback); + + protected: + friend class base::RefCountedThreadSafe<PairingRegistry>; + virtual ~PairingRegistry(); + + // Lets the tests override task posting to make all callbacks synchronous. + virtual void PostTask( + const scoped_refptr<base::SingleThreadTaskRunner>& task_runner, + const tracked_objects::Location& from_here, + const base::Closure& task); private: FRIEND_TEST_ALL_PREFIXES(PairingRegistryTest, AddPairing); - FRIEND_TEST_ALL_PREFIXES(PairingRegistryTest, GetAllPairingsJSON); friend class NegotiatingAuthenticatorTest; - friend class base::RefCountedThreadSafe<PairingRegistry>; - - virtual ~PairingRegistry(); // Helper method for unit tests. void AddPairing(const Pairing& pairing); - // Worker functions for each of the public methods, passed as a callback to - // the delegate. - void MergePairingAndSave(const Pairing& pairing, - const SaveCallback& callback, - const std::string& pairings_json); - void DoGetPairing(const std::string& client_id, - const GetPairingCallback& callback, - const std::string& pairings_json); - void SanitizePairings(const GetAllPairingsCallback& callback, - const std::string& pairings_json); - void DoDeletePairing(const std::string& client_id, - const SaveCallback& callback, - const std::string& pairings_json); + // Blocking helper methods used to call the delegate. + void DoLoadAll( + const protocol::PairingRegistry::GetAllPairingsCallback& callback); + void DoDeleteAll( + const protocol::PairingRegistry::DoneCallback& callback); + void DoLoad( + const std::string& client_id, + const protocol::PairingRegistry::GetPairingCallback& callback); + void DoSave( + const protocol::PairingRegistry::Pairing& pairing, + const protocol::PairingRegistry::DoneCallback& callback); + void DoDelete( + const std::string& client_id, + const protocol::PairingRegistry::DoneCallback& callback); // "Trampoline" callbacks that schedule the next pending request and then // invoke the original caller-supplied callback. - void InvokeLoadCallbackAndScheduleNext( - const LoadCallback& callback, const std::string& pairings_json); - void InvokeSaveCallbackAndScheduleNext( - const SaveCallback& callback, bool success); + void InvokeDoneCallbackAndScheduleNext( + const DoneCallback& callback, bool success); void InvokeGetPairingCallbackAndScheduleNext( const GetPairingCallback& callback, Pairing pairing); void InvokeGetAllPairingsCallbackAndScheduleNext( const GetAllPairingsCallback& callback, scoped_ptr<base::ListValue> pairings); + // Sanitize |pairings| by parsing each entry and removing the secret from it. + void SanitizePairings(const GetAllPairingsCallback& callback, + scoped_ptr<base::ListValue> pairings); + // Queue management methods. void ServiceOrQueueRequest(const base::Closure& request); void ServiceNextRequest(); - // Translate between the structured and serialized forms of the pairing data. - static PairedClients DecodeJson(const std::string& pairings_json); - static std::string EncodeJson(const PairedClients& clients); - static scoped_ptr<base::ListValue> ConvertToListValue( - const PairedClients& clients, - bool include_shared_secrets); + // Task runner on which all public methods of this class should be called. + scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner_; + + // Task runner used to run blocking calls to the delegate. A single thread + // task runner is used to guarantee that one one method of the delegate is + // called at a time. + scoped_refptr<base::SingleThreadTaskRunner> delegate_task_runner_; scoped_ptr<Delegate> delegate_; diff --git a/remoting/protocol/pairing_registry_unittest.cc b/remoting/protocol/pairing_registry_unittest.cc index f564256..eefea2a 100644 --- a/remoting/protocol/pairing_registry_unittest.cc +++ b/remoting/protocol/pairing_registry_unittest.cc @@ -11,15 +11,37 @@ #include "base/bind.h" #include "base/compiler_specific.h" #include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/run_loop.h" +#include "base/thread_task_runner_handle.h" #include "base/values.h" #include "remoting/protocol/protocol_mock_objects.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" +using testing::Sequence; + namespace { using remoting::protocol::PairingRegistry; +class MockPairingRegistryCallbacks { + public: + MockPairingRegistryCallbacks() {} + virtual ~MockPairingRegistryCallbacks() {} + + MOCK_METHOD1(DoneCallback, void(bool)); + MOCK_METHOD1(GetAllPairingsCallbackPtr, void(base::ListValue*)); + MOCK_METHOD1(GetPairingCallback, void(PairingRegistry::Pairing)); + + void GetAllPairingsCallback(scoped_ptr<base::ListValue> pairings) { + GetAllPairingsCallbackPtr(pairings.get()); + } + + private: + DISALLOW_COPY_AND_ASSIGN(MockPairingRegistryCallbacks); +}; + // Verify that a pairing Dictionary has correct entries, but doesn't include // any shared secret. void VerifyPairing(PairingRegistry::Pairing expected, @@ -59,32 +81,19 @@ class PairingRegistryTest : public testing::Test { ++callback_count_; } - void ExpectClientName(const std::string& expected, - PairingRegistry::Pairing actual) { - EXPECT_EQ(expected, actual.client_name()); - ++callback_count_; - } - - void ExpectNoPairings(scoped_ptr<base::ListValue> pairings) { - EXPECT_TRUE(pairings->empty()); - ++callback_count_; - } - protected: + base::MessageLoop message_loop_; + base::RunLoop run_loop_; + int callback_count_; scoped_ptr<base::ListValue> pairings_; }; TEST_F(PairingRegistryTest, CreateAndGetPairings) { - MockPairingRegistryDelegate* mock_delegate = - new MockPairingRegistryDelegate(); - scoped_ptr<PairingRegistry::Delegate> delegate(mock_delegate); - - scoped_refptr<PairingRegistry> registry(new PairingRegistry(delegate.Pass())); + scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry( + scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); PairingRegistry::Pairing pairing_1 = registry->CreatePairing("my_client"); - mock_delegate->RunCallback(); PairingRegistry::Pairing pairing_2 = registry->CreatePairing("my_client"); - mock_delegate->RunCallback(); EXPECT_NE(pairing_1.shared_secret(), pairing_2.shared_secret()); @@ -92,7 +101,6 @@ TEST_F(PairingRegistryTest, CreateAndGetPairings) { base::Bind(&PairingRegistryTest::ExpectSecret, base::Unretained(this), pairing_1.shared_secret())); - mock_delegate->RunCallback(); EXPECT_EQ(1, callback_count_); // Check that the second client is paired with a different shared secret. @@ -100,25 +108,18 @@ TEST_F(PairingRegistryTest, CreateAndGetPairings) { base::Bind(&PairingRegistryTest::ExpectSecret, base::Unretained(this), pairing_2.shared_secret())); - mock_delegate->RunCallback(); EXPECT_EQ(2, callback_count_); } TEST_F(PairingRegistryTest, GetAllPairings) { - MockPairingRegistryDelegate* mock_delegate = - new MockPairingRegistryDelegate(); - scoped_ptr<PairingRegistry::Delegate> delegate(mock_delegate); - - scoped_refptr<PairingRegistry> registry(new PairingRegistry(delegate.Pass())); + scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry( + scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1"); - mock_delegate->RunCallback(); PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2"); - mock_delegate->RunCallback(); registry->GetAllPairings( base::Bind(&PairingRegistryTest::set_pairings, base::Unretained(this))); - mock_delegate->RunCallback(); ASSERT_EQ(2u, pairings_->GetSize()); const base::DictionaryValue* actual_pairing_1; @@ -139,27 +140,20 @@ TEST_F(PairingRegistryTest, GetAllPairings) { } TEST_F(PairingRegistryTest, DeletePairing) { - MockPairingRegistryDelegate* mock_delegate = - new MockPairingRegistryDelegate(); - scoped_ptr<PairingRegistry::Delegate> delegate(mock_delegate); - - scoped_refptr<PairingRegistry> registry(new PairingRegistry(delegate.Pass())); + scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry( + scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1"); - mock_delegate->RunCallback(); PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2"); - mock_delegate->RunCallback(); registry->DeletePairing( pairing_1.client_id(), base::Bind(&PairingRegistryTest::ExpectSaveSuccess, base::Unretained(this))); - mock_delegate->RunCallback(); // Re-read the list, and verify it only has the pairing_2 client. registry->GetAllPairings( base::Bind(&PairingRegistryTest::set_pairings, base::Unretained(this))); - mock_delegate->RunCallback(); ASSERT_EQ(1u, pairings_->GetSize()); const base::DictionaryValue* actual_pairing_2; @@ -171,15 +165,10 @@ TEST_F(PairingRegistryTest, DeletePairing) { } TEST_F(PairingRegistryTest, ClearAllPairings) { - MockPairingRegistryDelegate* mock_delegate = - new MockPairingRegistryDelegate(); - scoped_ptr<PairingRegistry::Delegate> delegate(mock_delegate); - - scoped_refptr<PairingRegistry> registry(new PairingRegistry(delegate.Pass())); + scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry( + scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1"); - mock_delegate->RunCallback(); PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2"); - mock_delegate->RunCallback(); registry->ClearAllPairings( base::Bind(&PairingRegistryTest::ExpectSaveSuccess, @@ -189,57 +178,81 @@ TEST_F(PairingRegistryTest, ClearAllPairings) { registry->GetAllPairings( base::Bind(&PairingRegistryTest::set_pairings, base::Unretained(this))); - mock_delegate->RunCallback(); EXPECT_TRUE(pairings_->empty()); } -TEST_F(PairingRegistryTest, SerializedRequests) { - MockPairingRegistryDelegate* mock_delegate = - new MockPairingRegistryDelegate(); - scoped_ptr<PairingRegistry::Delegate> delegate(mock_delegate); - mock_delegate->set_run_save_callback_automatically(false); +ACTION_P(QuitMessageLoop, callback) { + callback.Run(); +} + +MATCHER_P(EqualsClientName, client_name, "") { + return arg.client_name() == client_name; +} + +MATCHER(NoPairings, "") { + return arg->empty(); +} - scoped_refptr<PairingRegistry> registry(new PairingRegistry(delegate.Pass())); +TEST_F(PairingRegistryTest, SerializedRequests) { + MockPairingRegistryCallbacks callbacks; + Sequence s; + EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1"))) + .InSequence(s); + EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client2"))) + .InSequence(s); + EXPECT_CALL(callbacks, DoneCallback(true)) + .InSequence(s); + EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1"))) + .InSequence(s); + EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName(""))) + .InSequence(s); + EXPECT_CALL(callbacks, DoneCallback(true)) + .InSequence(s); + EXPECT_CALL(callbacks, GetAllPairingsCallbackPtr(NoPairings())) + .InSequence(s); + EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client3"))) + .InSequence(s) + .WillOnce(QuitMessageLoop(run_loop_.QuitClosure())); + + scoped_refptr<PairingRegistry> registry = new PairingRegistry( + base::ThreadTaskRunnerHandle::Get(), + scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1"); PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2"); registry->GetPairing( pairing_1.client_id(), - base::Bind(&PairingRegistryTest::ExpectClientName, - base::Unretained(this), "client1")); + base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, + base::Unretained(&callbacks))); registry->GetPairing( pairing_2.client_id(), - base::Bind(&PairingRegistryTest::ExpectClientName, - base::Unretained(this), "client2")); + base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, + base::Unretained(&callbacks))); registry->DeletePairing( pairing_2.client_id(), - base::Bind(&PairingRegistryTest::ExpectSaveSuccess, - base::Unretained(this))); + base::Bind(&MockPairingRegistryCallbacks::DoneCallback, + base::Unretained(&callbacks))); registry->GetPairing( pairing_1.client_id(), - base::Bind(&PairingRegistryTest::ExpectClientName, - base::Unretained(this), "client1")); + base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, + base::Unretained(&callbacks))); registry->GetPairing( pairing_2.client_id(), - base::Bind(&PairingRegistryTest::ExpectClientName, - base::Unretained(this), "")); + base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, + base::Unretained(&callbacks))); registry->ClearAllPairings( - base::Bind(&PairingRegistryTest::ExpectSaveSuccess, - base::Unretained(this))); + base::Bind(&MockPairingRegistryCallbacks::DoneCallback, + base::Unretained(&callbacks))); registry->GetAllPairings( - base::Bind(&PairingRegistryTest::ExpectNoPairings, - base::Unretained(this))); + base::Bind(&MockPairingRegistryCallbacks::GetAllPairingsCallback, + base::Unretained(&callbacks))); PairingRegistry::Pairing pairing_3 = registry->CreatePairing("client3"); registry->GetPairing( pairing_3.client_id(), - base::Bind(&PairingRegistryTest::ExpectClientName, - base::Unretained(this), "client3")); - - while (mock_delegate->HasCallback()) { - mock_delegate->RunCallback(); - } + base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, + base::Unretained(&callbacks))); - EXPECT_EQ(8, callback_count_); + run_loop_.Run(); } } // namespace protocol diff --git a/remoting/protocol/protocol_mock_objects.cc b/remoting/protocol/protocol_mock_objects.cc index efde603d..93f6e34 100644 --- a/remoting/protocol/protocol_mock_objects.cc +++ b/remoting/protocol/protocol_mock_objects.cc @@ -4,6 +4,9 @@ #include "remoting/protocol/protocol_mock_objects.h" +#include "base/logging.h" +#include "base/thread_task_runner_handle.h" + namespace remoting { namespace protocol { @@ -48,57 +51,61 @@ MockSessionManager::MockSessionManager() {} MockSessionManager::~MockSessionManager() {} -MockPairingRegistryDelegate::MockPairingRegistryDelegate() - : run_save_callback_automatically_(true) { +MockPairingRegistryDelegate::MockPairingRegistryDelegate() { } MockPairingRegistryDelegate::~MockPairingRegistryDelegate() { } -void MockPairingRegistryDelegate::Save( - const std::string& pairings_json, - const PairingRegistry::SaveCallback& callback) { - EXPECT_TRUE(load_callback_.is_null()); - EXPECT_TRUE(save_callback_.is_null()); - if (run_save_callback_automatically_) { - SetPairingsJsonAndRunCallback(pairings_json, callback); - } else { - save_callback_ = base::Bind( - &MockPairingRegistryDelegate::SetPairingsJsonAndRunCallback, - base::Unretained(this), pairings_json, callback); +scoped_ptr<base::ListValue> MockPairingRegistryDelegate::LoadAll() { + scoped_ptr<base::ListValue> result(new base::ListValue()); + for (Pairings::const_iterator i = pairings_.begin(); i != pairings_.end(); + ++i) { + result->Append(i->second.ToValue().release()); } + return result.Pass(); +} + +bool MockPairingRegistryDelegate::DeleteAll() { + pairings_.clear(); + return true; } -void MockPairingRegistryDelegate::SetPairingsJsonAndRunCallback( - const std::string& pairings_json, - const PairingRegistry::SaveCallback& callback) { - pairings_json_ = pairings_json; - if (!callback.is_null()) { - callback.Run(true); +protocol::PairingRegistry::Pairing MockPairingRegistryDelegate::Load( + const std::string& client_id) { + Pairings::const_iterator i = pairings_.find(client_id); + if (i != pairings_.end()) { + return i->second; + } else { + return protocol::PairingRegistry::Pairing(); } } -void MockPairingRegistryDelegate::Load( - const PairingRegistry::LoadCallback& callback) { - EXPECT_TRUE(load_callback_.is_null()); - EXPECT_TRUE(save_callback_.is_null()); - load_callback_ = base::Bind(callback, pairings_json_); +bool MockPairingRegistryDelegate::Save( + const protocol::PairingRegistry::Pairing& pairing) { + pairings_[pairing.client_id()] = pairing; + return true; } -void MockPairingRegistryDelegate::RunCallback() { - if (!load_callback_.is_null()) { - EXPECT_TRUE(save_callback_.is_null()); - base::Closure load_callback = load_callback_; - load_callback_.Reset(); - load_callback.Run(); - } else if (!save_callback_.is_null()) { - EXPECT_TRUE(load_callback_.is_null()); - base::Closure save_callback = save_callback_; - save_callback_.Reset(); - save_callback.Run(); - } else { - ADD_FAILURE() << "RunCallback called without any callbacks set."; - } +bool MockPairingRegistryDelegate::Delete(const std::string& client_id) { + pairings_.erase(client_id); + return true; +} + +SynchronousPairingRegistry::SynchronousPairingRegistry( + scoped_ptr<Delegate> delegate) + : PairingRegistry(base::ThreadTaskRunnerHandle::Get(), delegate.Pass()) { +} + +SynchronousPairingRegistry::~SynchronousPairingRegistry() { +} + +void SynchronousPairingRegistry::PostTask( + const scoped_refptr<base::SingleThreadTaskRunner>& task_runner, + const tracked_objects::Location& from_here, + const base::Closure& task) { + DCHECK(task_runner->BelongsToCurrentThread()); + task.Run(); } } // namespace protocol diff --git a/remoting/protocol/protocol_mock_objects.h b/remoting/protocol/protocol_mock_objects.h index cfb912b..74435db 100644 --- a/remoting/protocol/protocol_mock_objects.h +++ b/remoting/protocol/protocol_mock_objects.h @@ -5,8 +5,12 @@ #ifndef REMOTING_PROTOCOL_PROTOCOL_MOCK_OBJECTS_H_ #define REMOTING_PROTOCOL_PROTOCOL_MOCK_OBJECTS_H_ +#include <map> #include <string> +#include "base/location.h" +#include "base/single_thread_task_runner.h" +#include "base/values.h" #include "net/base/ip_endpoint.h" #include "remoting/proto/internal.pb.h" #include "remoting/proto/video.pb.h" @@ -211,44 +215,31 @@ class MockPairingRegistryDelegate : public PairingRegistry::Delegate { MockPairingRegistryDelegate(); virtual ~MockPairingRegistryDelegate(); - const std::string& pairings_json() const { - return pairings_json_; - } - // PairingRegistry::Delegate implementation. - virtual void Save( - const std::string& pairings_json, - const PairingRegistry::SaveCallback& callback) OVERRIDE; - virtual void Load( - const PairingRegistry::LoadCallback& callback) OVERRIDE; - - // By default, the Save method runs its callback automatically because the - // negotiating authenticator unit test does not provide any hooks to do it - // manually. For unit tests that need to verify correct behaviour under - // asynchronous conditions, use this method to disable this feature and call - // RunCallback as appropriate. - void set_run_save_callback_automatically( - bool run_save_callback_automatically) { - run_save_callback_automatically_ = run_save_callback_automatically; - } + virtual scoped_ptr<base::ListValue> LoadAll() OVERRIDE; + virtual bool DeleteAll() OVERRIDE; + virtual protocol::PairingRegistry::Pairing Load( + const std::string& client_id) OVERRIDE; + virtual bool Save(const protocol::PairingRegistry::Pairing& pairing) OVERRIDE; + virtual bool Delete(const std::string& client_id) OVERRIDE; - bool HasCallback() const { - return !load_callback_.is_null() || !save_callback_.is_null(); - } + private: + typedef std::map<std::string, protocol::PairingRegistry::Pairing> Pairings; + Pairings pairings_; +}; + +class SynchronousPairingRegistry : public PairingRegistry { + public: + explicit SynchronousPairingRegistry(scoped_ptr<Delegate> delegate); - // Run either the save or the load callback (whichever was set most recently; - // it is an error for both of these to be set at the same time). - void RunCallback(); + protected: + virtual ~SynchronousPairingRegistry(); - private: - void SetPairingsJsonAndRunCallback( - const std::string& pairings_json, - const PairingRegistry::SaveCallback& callback); - - base::Closure load_callback_; - base::Closure save_callback_; - std::string pairings_json_; - bool run_save_callback_automatically_; + // Runs tasks synchronously instead of posting them to |task_runner|. + virtual void PostTask( + const scoped_refptr<base::SingleThreadTaskRunner>& task_runner, + const tracked_objects::Location& from_here, + const base::Closure& task) OVERRIDE; }; } // namespace protocol |