diff options
author | alexeypa@chromium.org <alexeypa@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-10-08 08:11:34 +0000 |
---|---|---|
committer | alexeypa@chromium.org <alexeypa@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-10-08 08:11:34 +0000 |
commit | 7459b5a58bce5d23fc79f6308ca5698b069466bf (patch) | |
tree | 7120bd82a74117af47bc865b9cd14d3598a9f405 /remoting | |
parent | a13bb0bdb4c12914232f67a6e039ee1ab4b97f86 (diff) | |
download | chromium_src-7459b5a58bce5d23fc79f6308ca5698b069466bf.zip chromium_src-7459b5a58bce5d23fc79f6308ca5698b069466bf.tar.gz chromium_src-7459b5a58bce5d23fc79f6308ca5698b069466bf.tar.bz2 |
Moved all channel-handling logic into separate NativeMessagingChannel class.
Changes in this CL:
- NativeMessagingHost is used only for actual processing the messages.
- NativeMessagingChannel implements receiving and sending messages.
- NativeMessagingChannel takes ownership of the passed inout and output handles.
- Both NativeMessagingChannel and NativeMessagingHost are explicitly marked as not thread-safe classes.
BUG=173509
Review URL: https://codereview.chromium.org/23903021
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@227492 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'remoting')
-rw-r--r-- | remoting/host/setup/native_messaging_channel.cc | 129 | ||||
-rw-r--r-- | remoting/host/setup/native_messaging_channel.h | 94 | ||||
-rw-r--r-- | remoting/host/setup/native_messaging_host.cc | 293 | ||||
-rw-r--r-- | remoting/host/setup/native_messaging_host.h | 164 | ||||
-rw-r--r-- | remoting/host/setup/native_messaging_host_unittest.cc | 43 | ||||
-rw-r--r-- | remoting/remoting.gyp | 2 |
6 files changed, 469 insertions, 256 deletions
diff --git a/remoting/host/setup/native_messaging_channel.cc b/remoting/host/setup/native_messaging_channel.cc new file mode 100644 index 0000000..b93fd94 --- /dev/null +++ b/remoting/host/setup/native_messaging_channel.cc @@ -0,0 +1,129 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "remoting/host/setup/native_messaging_channel.h" + +#include "base/basictypes.h" +#include "base/bind.h" +#include "base/callback.h" +#include "base/location.h" +#include "base/values.h" + +#if defined(OS_POSIX) +#include <unistd.h> +#endif + +namespace { + +base::PlatformFile DuplicatePlatformFile(base::PlatformFile handle) { + base::PlatformFile result; +#if defined(OS_WIN) + if (!DuplicateHandle(GetCurrentProcess(), + handle, + GetCurrentProcess(), + &result, + 0, + FALSE, + DUPLICATE_CLOSE_SOURCE | DUPLICATE_SAME_ACCESS)) { + PLOG(ERROR) << "Failed to duplicate handle " << handle; + return base::kInvalidPlatformFileValue; + } + return result; +#elif defined(OS_POSIX) + result = dup(handle); + base::ClosePlatformFile(handle); + return result; +#else +#error Not implemented. +#endif +} + +} // namespace + +namespace remoting { + +NativeMessagingChannel::NativeMessagingChannel( + scoped_ptr<Delegate> delegate, + base::PlatformFile input, + base::PlatformFile output) + : native_messaging_reader_(DuplicatePlatformFile(input)), + native_messaging_writer_(new NativeMessagingWriter( + DuplicatePlatformFile(output))), + delegate_(delegate.Pass()), + pending_requests_(0), + shutdown_(false), + weak_factory_(this) { + weak_ptr_ = weak_factory_.GetWeakPtr(); +} + +NativeMessagingChannel::~NativeMessagingChannel() { +} + +void NativeMessagingChannel::Start(const base::Closure& quit_closure) { + DCHECK(CalledOnValidThread()); + DCHECK(quit_closure_.is_null()); + DCHECK(!quit_closure.is_null()); + + quit_closure_ = quit_closure; + native_messaging_reader_.Start( + base::Bind(&NativeMessagingChannel::ProcessMessage, weak_ptr_), + base::Bind(&NativeMessagingChannel::Shutdown, weak_ptr_)); +} + +void NativeMessagingChannel::ProcessMessage(scoped_ptr<base::Value> message) { + DCHECK(CalledOnValidThread()); + + // Don't process any more messages if Shutdown() has been called. + if (shutdown_) + return; + + if (message->GetType() != base::Value::TYPE_DICTIONARY) { + LOG(ERROR) << "Expected DictionaryValue"; + Shutdown(); + return; + } + + DCHECK_GE(pending_requests_, 0); + pending_requests_++; + + scoped_ptr<base::DictionaryValue> message_dict( + static_cast<base::DictionaryValue*>(message.release())); + delegate_->ProcessMessage( + message_dict.Pass(), + base::Bind(&NativeMessagingChannel::SendResponse, weak_ptr_)); +} + +void NativeMessagingChannel::SendResponse( + scoped_ptr<base::DictionaryValue> response) { + DCHECK(CalledOnValidThread()); + + bool success = response && native_messaging_writer_; + if (success) + success = native_messaging_writer_->WriteMessage(*response); + + if (!success) { + // Close the write pipe so no more responses will be sent. + native_messaging_writer_.reset(); + Shutdown(); + } + + pending_requests_--; + DCHECK_GE(pending_requests_, 0); + + if (shutdown_ && !pending_requests_) + quit_closure_.Run(); +} + +void NativeMessagingChannel::Shutdown() { + DCHECK(CalledOnValidThread()); + + if (shutdown_) + return; + + shutdown_ = true; + if (!pending_requests_) + quit_closure_.Run(); +} + +} // namespace remoting diff --git a/remoting/host/setup/native_messaging_channel.h b/remoting/host/setup/native_messaging_channel.h new file mode 100644 index 0000000..9b0c229 --- /dev/null +++ b/remoting/host/setup/native_messaging_channel.h @@ -0,0 +1,94 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef REMOTING_HOST_SETUP_NATIVE_MESSAGING_CHANNEL_H_ +#define REMOTING_HOST_SETUP_NATIVE_MESSAGING_CHANNEL_H_ + +#include "base/callback.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/weak_ptr.h" +#include "base/platform_file.h" +#include "base/threading/non_thread_safe.h" +#include "remoting/host/setup/native_messaging_reader.h" +#include "remoting/host/setup/native_messaging_writer.h" + +namespace base { +class DictionaryValue; +class Value; +} // namespace base + +namespace remoting { + +// Implements reading messages and sending responses across the native messaging +// host pipe. Delegates processing of received messages to Delegate. +// +// TODO(alexeypa): Add ability to switch between different |delegate_| pointers +// on the fly. This is useful for implementing UAC-style elevation on Windows - +// an unprivileged delegate could be replaced with another delegate that +// forwards messages to the elevated instance of the native messaging host. +class NativeMessagingChannel : public base::NonThreadSafe { + public: + // Used to sends a response back to the client app. + typedef base::Callback<void(scoped_ptr<base::DictionaryValue> response)> + SendResponseCallback; + + class Delegate { + public: + virtual ~Delegate() {} + + // Processes a message received from the client app. Invokes |done| to send + // a response back to the client app. + virtual void ProcessMessage(scoped_ptr<base::DictionaryValue> message, + const SendResponseCallback& done) = 0; + }; + + // Constructs an object taking the ownership of |input| and |output|. Closes + // |input| and |output| to prevent the caller frmo using them. + NativeMessagingChannel( + scoped_ptr<Delegate> delegate, + base::PlatformFile input, + base::PlatformFile output); + ~NativeMessagingChannel(); + + // Starts reading and processing messages. + void Start(const base::Closure& quit_closure); + + private: + // Processes a message received from the client app. + void ProcessMessage(scoped_ptr<base::Value> message); + + // Sends a response back to the client app. + void SendResponse(scoped_ptr<base::DictionaryValue> response); + + // Initiates shutdown and runs |quit_closure| if there are no pending requests + // left. + void Shutdown(); + + base::Closure quit_closure_; + + NativeMessagingReader native_messaging_reader_; + scoped_ptr<NativeMessagingWriter> native_messaging_writer_; + + // |delegate_| may post tasks to this object during destruction (but not + // afterwards), so it needs to be destroyed before other members of this class + // (except for |weak_factory_|). + scoped_ptr<Delegate> delegate_; + + // Keeps track of pending requests. Used to delay shutdown until all responses + // have been sent. + int pending_requests_; + + // True if Shutdown() has been called. + bool shutdown_; + + base::WeakPtr<NativeMessagingChannel> weak_ptr_; + base::WeakPtrFactory<NativeMessagingChannel> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(NativeMessagingChannel); +}; + +} // namespace remoting + +#endif // REMOTING_HOST_SETUP_NATIVE_MESSAGING_CHANNEL_H_ diff --git a/remoting/host/setup/native_messaging_host.cc b/remoting/host/setup/native_messaging_host.cc index 275e9fe..fd2c0c6 100644 --- a/remoting/host/setup/native_messaging_host.cc +++ b/remoting/host/setup/native_messaging_host.cc @@ -9,10 +9,11 @@ #include "base/basictypes.h" #include "base/bind.h" #include "base/callback.h" -#include "base/json/json_string_value_serializer.h" -#include "base/location.h" +#include "base/command_line.h" +#include "base/logging.h" #include "base/message_loop/message_loop.h" #include "base/run_loop.h" +#include "base/strings/string_number_conversions.h" #include "base/strings/stringize_macros.h" #include "base/values.h" #include "google_apis/gaia/gaia_oauth_client.h" @@ -27,22 +28,20 @@ #include "remoting/host/setup/oauth_client.h" #include "remoting/protocol/pairing_registry.h" -#if defined(OS_POSIX) -#include <unistd.h> -#endif - namespace { +const char kParentWindowSwitchName[] = "parent-window"; + +// redirect_uri to use when authenticating service accounts (service account +// codes are obtained "out-of-band", i.e., not through an OAuth redirect). +const char* kServiceAccountRedirectUri = "oob"; + // Features supported in addition to the base protocol. const char* kSupportedFeatures[] = { "pairingRegistry", "oauthClient" }; -// redirect_uri to use when authenticating service accounts (service account -// codes are obtained "out-of-band", i.e., not through an OAuth redirect). -const char* kServiceAccountRedirectUri = "oob"; - // Helper to extract the "config" part of a message as a DictionaryValue. // Returns NULL on failure, and logs an error message. scoped_ptr<base::DictionaryValue> ConfigDictionaryFromMessage( @@ -64,151 +63,108 @@ namespace remoting { NativeMessagingHost::NativeMessagingHost( scoped_refptr<DaemonController> daemon_controller, scoped_refptr<protocol::PairingRegistry> pairing_registry, - scoped_ptr<OAuthClient> oauth_client, - base::PlatformFile input, - base::PlatformFile output, - scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner, - const base::Closure& quit_closure) - : caller_task_runner_(caller_task_runner), - quit_closure_(quit_closure), - native_messaging_reader_(input), - native_messaging_writer_(output), - daemon_controller_(daemon_controller), + scoped_ptr<OAuthClient> oauth_client) + : daemon_controller_(daemon_controller), pairing_registry_(pairing_registry), oauth_client_(oauth_client.Pass()), - pending_requests_(0), - shutdown_(false), weak_factory_(this) { weak_ptr_ = weak_factory_.GetWeakPtr(); } -NativeMessagingHost::~NativeMessagingHost() {} - -void NativeMessagingHost::Start() { - DCHECK(caller_task_runner_->BelongsToCurrentThread()); - - native_messaging_reader_.Start( - base::Bind(&NativeMessagingHost::ProcessMessage, weak_ptr_), - base::Bind(&NativeMessagingHost::Shutdown, weak_ptr_)); +NativeMessagingHost::~NativeMessagingHost() { } -void NativeMessagingHost::Shutdown() { - DCHECK(caller_task_runner_->BelongsToCurrentThread()); - - if (shutdown_) - return; - - shutdown_ = true; - if (!pending_requests_) - caller_task_runner_->PostTask(FROM_HERE, quit_closure_); -} - -void NativeMessagingHost::ProcessMessage(scoped_ptr<base::Value> message) { - DCHECK(caller_task_runner_->BelongsToCurrentThread()); - - // Don't process any more messages if Shutdown() has been called. - if (shutdown_) - return; - - const base::DictionaryValue* message_dict; - if (!message->GetAsDictionary(&message_dict)) { - LOG(ERROR) << "Expected DictionaryValue"; - Shutdown(); - return; - } - - scoped_ptr<base::DictionaryValue> response_dict(new base::DictionaryValue()); +void NativeMessagingHost::ProcessMessage( + scoped_ptr<base::DictionaryValue> message, + const SendResponseCallback& done) { + scoped_ptr<base::DictionaryValue> response(new base::DictionaryValue()); // If the client supplies an ID, it will expect it in the response. This // might be a string or a number, so cope with both. const base::Value* id; - if (message_dict->Get("id", &id)) - response_dict->Set("id", id->DeepCopy()); + if (message->Get("id", &id)) + response->Set("id", id->DeepCopy()); std::string type; - if (!message_dict->GetString("type", &type)) { + if (!message->GetString("type", &type)) { LOG(ERROR) << "'type' not found"; - Shutdown(); + done.Run(scoped_ptr<base::DictionaryValue>()); return; } - response_dict->SetString("type", type + "Response"); - - DCHECK_GE(pending_requests_, 0); - pending_requests_++; + response->SetString("type", type + "Response"); bool success = false; if (type == "hello") { - success = ProcessHello(*message_dict, response_dict.Pass()); + success = ProcessHello(*message, response.Pass(), done); } else if (type == "clearPairedClients") { - success = ProcessClearPairedClients(*message_dict, response_dict.Pass()); + success = ProcessClearPairedClients(*message, response.Pass(), done); } else if (type == "deletePairedClient") { - success = ProcessDeletePairedClient(*message_dict, response_dict.Pass()); + success = ProcessDeletePairedClient(*message, response.Pass(), done); } else if (type == "getHostName") { - success = ProcessGetHostName(*message_dict, response_dict.Pass()); + success = ProcessGetHostName(*message, response.Pass(), done); } else if (type == "getPinHash") { - success = ProcessGetPinHash(*message_dict, response_dict.Pass()); + success = ProcessGetPinHash(*message, response.Pass(), done); } else if (type == "generateKeyPair") { - success = ProcessGenerateKeyPair(*message_dict, response_dict.Pass()); + success = ProcessGenerateKeyPair(*message, response.Pass(), done); } else if (type == "updateDaemonConfig") { - success = ProcessUpdateDaemonConfig(*message_dict, response_dict.Pass()); + success = ProcessUpdateDaemonConfig(*message, response.Pass(), done); } else if (type == "getDaemonConfig") { - success = ProcessGetDaemonConfig(*message_dict, response_dict.Pass()); + success = ProcessGetDaemonConfig(*message, response.Pass(), done); } else if (type == "getPairedClients") { - success = ProcessGetPairedClients(*message_dict, response_dict.Pass()); + success = ProcessGetPairedClients(*message, response.Pass(), done); } else if (type == "getUsageStatsConsent") { - success = ProcessGetUsageStatsConsent(*message_dict, response_dict.Pass()); + success = ProcessGetUsageStatsConsent(*message, response.Pass(), done); } else if (type == "startDaemon") { - success = ProcessStartDaemon(*message_dict, response_dict.Pass()); + success = ProcessStartDaemon(*message, response.Pass(), done); } else if (type == "stopDaemon") { - success = ProcessStopDaemon(*message_dict, response_dict.Pass()); + success = ProcessStopDaemon(*message, response.Pass(), done); } else if (type == "getDaemonState") { - success = ProcessGetDaemonState(*message_dict, response_dict.Pass()); + success = ProcessGetDaemonState(*message, response.Pass(), done); } else if (type == "getHostClientId") { - success = ProcessGetHostClientId(*message_dict, response_dict.Pass()); + success = ProcessGetHostClientId(*message, response.Pass(), done); } else if (type == "getCredentialsFromAuthCode") { - success = ProcessGetCredentialsFromAuthCode( - *message_dict, response_dict.Pass()); + success = ProcessGetCredentialsFromAuthCode(*message, response.Pass(), + done); } else { LOG(ERROR) << "Unsupported request type: " << type; } - if (!success) { - pending_requests_--; - DCHECK_GE(pending_requests_, 0); - - Shutdown(); - } + if (!success) + done.Run(scoped_ptr<base::DictionaryValue>()); } bool NativeMessagingHost::ProcessHello( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { response->SetString("version", STRINGIZE(VERSION)); scoped_ptr<base::ListValue> supported_features_list(new base::ListValue()); supported_features_list->AppendStrings(std::vector<std::string>( kSupportedFeatures, kSupportedFeatures + arraysize(kSupportedFeatures))); response->Set("supportedFeatures", supported_features_list.release()); - SendResponse(response.Pass()); + done.Run(response.Pass()); return true; } bool NativeMessagingHost::ProcessClearPairedClients( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { if (pairing_registry_) { pairing_registry_->ClearAllPairings( base::Bind(&NativeMessagingHost::SendBooleanResult, weak_ptr_, - base::Passed(&response))); + done, base::Passed(&response))); } else { - SendBooleanResult(response.Pass(), false); + SendBooleanResult(done, response.Pass(), false); } return true; } bool NativeMessagingHost::ProcessDeletePairedClient( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { std::string client_id; if (!message.GetString(protocol::PairingRegistry::kClientIdKey, &client_id)) { LOG(ERROR) << "'" << protocol::PairingRegistry::kClientIdKey @@ -219,24 +175,26 @@ bool NativeMessagingHost::ProcessDeletePairedClient( if (pairing_registry_) { pairing_registry_->DeletePairing( client_id, base::Bind(&NativeMessagingHost::SendBooleanResult, - weak_ptr_, base::Passed(&response))); + weak_ptr_, done, base::Passed(&response))); } else { - SendBooleanResult(response.Pass(), false); + SendBooleanResult(done, response.Pass(), false); } return true; } bool NativeMessagingHost::ProcessGetHostName( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { response->SetString("hostname", net::GetHostName()); - SendResponse(response.Pass()); + done.Run(response.Pass()); return true; } bool NativeMessagingHost::ProcessGetPinHash( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { std::string host_id; if (!message.GetString("hostId", &host_id)) { LOG(ERROR) << "'hostId' not found: " << message; @@ -247,24 +205,26 @@ bool NativeMessagingHost::ProcessGetPinHash( LOG(ERROR) << "'pin' not found: " << message; return false; } - response->SetString("hash", remoting::MakeHostPinHash(host_id, pin)); - SendResponse(response.Pass()); + response->SetString("hash", MakeHostPinHash(host_id, pin)); + done.Run(response.Pass()); return true; } bool NativeMessagingHost::ProcessGenerateKeyPair( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { scoped_refptr<RsaKeyPair> key_pair = RsaKeyPair::Generate(); response->SetString("privateKey", key_pair->ToString()); response->SetString("publicKey", key_pair->GetPublicKey()); - SendResponse(response.Pass()); + done.Run(response.Pass()); return true; } bool NativeMessagingHost::ProcessUpdateDaemonConfig( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { scoped_ptr<base::DictionaryValue> config_dict = ConfigDictionaryFromMessage(message); if (!config_dict) @@ -273,45 +233,49 @@ bool NativeMessagingHost::ProcessUpdateDaemonConfig( daemon_controller_->UpdateConfig( config_dict.Pass(), base::Bind(&NativeMessagingHost::SendAsyncResult, weak_ptr_, - base::Passed(&response))); + done, base::Passed(&response))); return true; } bool NativeMessagingHost::ProcessGetDaemonConfig( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { daemon_controller_->GetConfig( base::Bind(&NativeMessagingHost::SendConfigResponse, weak_ptr_, - base::Passed(&response))); + done, base::Passed(&response))); return true; } bool NativeMessagingHost::ProcessGetPairedClients( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { if (pairing_registry_) { pairing_registry_->GetAllPairings( base::Bind(&NativeMessagingHost::SendPairedClientsResponse, weak_ptr_, - base::Passed(&response))); + done, base::Passed(&response))); } else { scoped_ptr<base::ListValue> no_paired_clients(new base::ListValue); - SendPairedClientsResponse(response.Pass(), no_paired_clients.Pass()); + SendPairedClientsResponse(done, response.Pass(), no_paired_clients.Pass()); } return true; } bool NativeMessagingHost::ProcessGetUsageStatsConsent( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { daemon_controller_->GetUsageStatsConsent( base::Bind(&NativeMessagingHost::SendUsageStatsConsentResponse, - weak_ptr_, base::Passed(&response))); + weak_ptr_, done, base::Passed(&response))); return true; } bool NativeMessagingHost::ProcessStartDaemon( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { bool consent; if (!message.GetBoolean("consent", &consent)) { LOG(ERROR) << "'consent' not found."; @@ -326,22 +290,24 @@ bool NativeMessagingHost::ProcessStartDaemon( daemon_controller_->SetConfigAndStart( config_dict.Pass(), consent, base::Bind(&NativeMessagingHost::SendAsyncResult, weak_ptr_, - base::Passed(&response))); + done, base::Passed(&response))); return true; } bool NativeMessagingHost::ProcessStopDaemon( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { daemon_controller_->Stop( base::Bind(&NativeMessagingHost::SendAsyncResult, weak_ptr_, - base::Passed(&response))); + done, base::Passed(&response))); return true; } bool NativeMessagingHost::ProcessGetDaemonState( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { DaemonController::State state = daemon_controller_->GetState(); switch (state) { case DaemonController::STATE_NOT_IMPLEMENTED: @@ -369,22 +335,24 @@ bool NativeMessagingHost::ProcessGetDaemonState( response->SetString("state", "UNKNOWN"); break; } - SendResponse(response.Pass()); + done.Run(response.Pass()); return true; } bool NativeMessagingHost::ProcessGetHostClientId( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { response->SetString("clientId", google_apis::GetOAuth2ClientID( google_apis::CLIENT_REMOTING_HOST)); - SendResponse(response.Pass()); + done.Run(response.Pass()); return true; } bool NativeMessagingHost::ProcessGetCredentialsFromAuthCode( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response) { + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done) { std::string auth_code; if (!message.GetString("authorizationCode", &auth_code)) { LOG(ERROR) << "'authorizationCode' string not found."; @@ -400,31 +368,13 @@ bool NativeMessagingHost::ProcessGetCredentialsFromAuthCode( oauth_client_->GetCredentialsFromAuthCode( oauth_client_info, auth_code, base::Bind( &NativeMessagingHost::SendCredentialsResponse, weak_ptr_, - base::Passed(&response))); + done, base::Passed(&response))); return true; } -void NativeMessagingHost::SendResponse( - scoped_ptr<base::DictionaryValue> response) { - if (!caller_task_runner_->BelongsToCurrentThread()) { - caller_task_runner_->PostTask( - FROM_HERE, base::Bind(&NativeMessagingHost::SendResponse, weak_ptr_, - base::Passed(&response))); - return; - } - - if (!native_messaging_writer_.WriteMessage(*response)) - Shutdown(); - - pending_requests_--; - DCHECK_GE(pending_requests_, 0); - - if (shutdown_ && !pending_requests_) - caller_task_runner_->PostTask(FROM_HERE, quit_closure_); -} - void NativeMessagingHost::SendConfigResponse( + const SendResponseCallback& done, scoped_ptr<base::DictionaryValue> response, scoped_ptr<base::DictionaryValue> config) { if (config) { @@ -432,26 +382,29 @@ void NativeMessagingHost::SendConfigResponse( } else { response->Set("config", Value::CreateNullValue()); } - SendResponse(response.Pass()); + done.Run(response.Pass()); } void NativeMessagingHost::SendPairedClientsResponse( + const SendResponseCallback& done, scoped_ptr<base::DictionaryValue> response, scoped_ptr<base::ListValue> pairings) { response->Set("pairedClients", pairings.release()); - SendResponse(response.Pass()); + done.Run(response.Pass()); } void NativeMessagingHost::SendUsageStatsConsentResponse( + const SendResponseCallback& done, scoped_ptr<base::DictionaryValue> response, const DaemonController::UsageStatsConsent& consent) { response->SetBoolean("supported", consent.supported); response->SetBoolean("allowed", consent.allowed); response->SetBoolean("setByPolicy", consent.set_by_policy); - SendResponse(response.Pass()); + done.Run(response.Pass()); } void NativeMessagingHost::SendAsyncResult( + const SendResponseCallback& done, scoped_ptr<base::DictionaryValue> response, DaemonController::AsyncResult result) { switch (result) { @@ -468,26 +421,32 @@ void NativeMessagingHost::SendAsyncResult( response->SetString("result", "FAILED_DIRECTORY"); break; } - SendResponse(response.Pass()); + done.Run(response.Pass()); } void NativeMessagingHost::SendBooleanResult( + const SendResponseCallback& done, scoped_ptr<base::DictionaryValue> response, bool result) { response->SetBoolean("result", result); - SendResponse(response.Pass()); + done.Run(response.Pass()); } void NativeMessagingHost::SendCredentialsResponse( + const SendResponseCallback& done, scoped_ptr<base::DictionaryValue> response, const std::string& user_email, const std::string& refresh_token) { response->SetString("userEmail", user_email); response->SetString("refreshToken", refresh_token); - SendResponse(response.Pass()); + done.Run(response.Pass()); } int NativeMessagingHostMain() { #if defined(OS_WIN) + // GetStdHandle() returns pseudo-handles for stdin and stdout even if + // the hosting executable specifies "Windows" subsystem. However the returned + // handles are invalid in that case unless standard input and output are + // redirected to a pipe or file. base::PlatformFile read_file = GetStdHandle(STD_INPUT_HANDLE); base::PlatformFile write_file = GetStdHandle(STD_OUTPUT_HANDLE); #elif defined(OS_POSIX) @@ -499,23 +458,47 @@ int NativeMessagingHostMain() { base::MessageLoop message_loop(base::MessageLoop::TYPE_IO); base::RunLoop run_loop; + + scoped_refptr<DaemonController> daemon_controller = + DaemonController::Create(); + + // Pass handle of the native view to the controller so that the UAC prompts + // are focused properly. + const CommandLine* command_line = CommandLine::ForCurrentProcess(); + if (command_line->HasSwitch(kParentWindowSwitchName)) { + std::string native_view = + command_line->GetSwitchValueASCII(kParentWindowSwitchName); + int64 native_view_handle = 0; + if (base::StringToInt64(native_view, &native_view_handle)) { + daemon_controller->SetWindow(reinterpret_cast<void*>(native_view_handle)); + } else { + LOG(WARNING) << "Invalid parameter value --" << kParentWindowSwitchName + << "=" << native_view; + } + } + // OAuth client (for credential requests). scoped_refptr<net::URLRequestContextGetter> url_request_context_getter( - new remoting::URLRequestContextGetter(message_loop.message_loop_proxy())); - scoped_ptr<remoting::OAuthClient> oauth_client( - new remoting::OAuthClient(url_request_context_getter)); + new URLRequestContextGetter(message_loop.message_loop_proxy())); + scoped_ptr<OAuthClient> oauth_client( + new OAuthClient(url_request_context_getter)); net::URLFetcher::SetIgnoreCertificateRequests(true); + // Create the pairing registry and native messaging host. scoped_refptr<protocol::PairingRegistry> pairing_registry = CreatePairingRegistry(message_loop.message_loop_proxy()); - remoting::NativeMessagingHost host(remoting::DaemonController::Create(), - pairing_registry, - oauth_client.Pass(), - read_file, write_file, - message_loop.message_loop_proxy(), - run_loop.QuitClosure()); - host.Start(); + scoped_ptr<NativeMessagingChannel::Delegate> host( + new NativeMessagingHost(daemon_controller, + pairing_registry, + oauth_client.Pass())); + + // Set up the native messaging channel. + scoped_ptr<NativeMessagingChannel> channel( + new NativeMessagingChannel(host.Pass(), read_file, write_file)); + channel->Start(run_loop.QuitClosure()); + + // Run the loop until channel is alive. run_loop.Run(); return kSuccessExitCode; } diff --git a/remoting/host/setup/native_messaging_host.h b/remoting/host/setup/native_messaging_host.h index eb7250a..4a359d7 100644 --- a/remoting/host/setup/native_messaging_host.h +++ b/remoting/host/setup/native_messaging_host.h @@ -5,21 +5,17 @@ #ifndef REMOTING_HOST_SETUP_NATIVE_MESSAGING_HOST_H_ #define REMOTING_HOST_SETUP_NATIVE_MESSAGING_HOST_H_ -#include "base/callback_forward.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/memory/weak_ptr.h" -#include "base/platform_file.h" +#include "base/threading/thread_checker.h" #include "remoting/host/setup/daemon_controller.h" -#include "remoting/host/setup/native_messaging_reader.h" -#include "remoting/host/setup/native_messaging_writer.h" +#include "remoting/host/setup/native_messaging_channel.h" #include "remoting/host/setup/oauth_client.h" namespace base { class DictionaryValue; class ListValue; -class SingleThreadTaskRunner; -class Value; } // namespace base namespace gaia { @@ -33,103 +29,111 @@ class PairingRegistry; } // namespace protocol // Implementation of the native messaging host process. -class NativeMessagingHost { +class NativeMessagingHost : public NativeMessagingChannel::Delegate { public: + typedef NativeMessagingChannel::SendResponseCallback SendResponseCallback; + NativeMessagingHost( scoped_refptr<DaemonController> daemon_controller, scoped_refptr<protocol::PairingRegistry> pairing_registry, - scoped_ptr<OAuthClient> oauth_client, - base::PlatformFile input, - base::PlatformFile output, - scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner, - const base::Closure& quit_closure); - ~NativeMessagingHost(); - - // Starts reading and processing messages. - void Start(); + scoped_ptr<OAuthClient> oauth_client); + virtual ~NativeMessagingHost(); - // Posts |quit_closure| to |caller_task_runner|. This gets called whenever an - // error is encountered during reading and processing messages. - void Shutdown(); + // NativeMessagingChannel::Delegate interface. + virtual void ProcessMessage(scoped_ptr<base::DictionaryValue> message, + const SendResponseCallback& done) OVERRIDE; private: - // Processes a message received from the client app. - void ProcessMessage(scoped_ptr<base::Value> message); - // These "Process.." methods handle specific request types. The |response| // dictionary is pre-filled by ProcessMessage() with the parts of the // response already known ("id" and "type" fields). - bool ProcessHello(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - bool ProcessClearPairedClients(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - bool ProcessDeletePairedClient(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - bool ProcessGetHostName(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - bool ProcessGetPinHash(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - bool ProcessGenerateKeyPair(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - bool ProcessUpdateDaemonConfig(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - bool ProcessGetDaemonConfig(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - bool ProcessGetPairedClients(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - bool ProcessGetUsageStatsConsent(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - bool ProcessStartDaemon(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - bool ProcessStopDaemon(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - bool ProcessGetDaemonState(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - bool ProcessGetHostClientId(const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); + bool ProcessHello( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); + bool ProcessClearPairedClients( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); + bool ProcessDeletePairedClient( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); + bool ProcessGetHostName( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); + bool ProcessGetPinHash( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); + bool ProcessGenerateKeyPair( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); + bool ProcessUpdateDaemonConfig( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); + bool ProcessGetDaemonConfig( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); + bool ProcessGetPairedClients( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); + bool ProcessGetUsageStatsConsent( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); + bool ProcessStartDaemon( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); + bool ProcessStopDaemon( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); + bool ProcessGetDaemonState( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); + bool ProcessGetHostClientId( + const base::DictionaryValue& message, + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); bool ProcessGetCredentialsFromAuthCode( const base::DictionaryValue& message, - scoped_ptr<base::DictionaryValue> response); - - // Sends a response back to the client app. This can be called on either the - // main message loop or the DaemonController's internal thread, so it - // PostTask()s to the main thread if necessary. - void SendResponse(scoped_ptr<base::DictionaryValue> response); + scoped_ptr<base::DictionaryValue> response, + const SendResponseCallback& done); // These Send... methods get called on the DaemonController's internal thread, // or on the calling thread if called by the PairingRegistry. // These methods fill in the |response| dictionary from the other parameters, // and pass it to SendResponse(). - void SendConfigResponse(scoped_ptr<base::DictionaryValue> response, + void SendConfigResponse(const SendResponseCallback& done, + scoped_ptr<base::DictionaryValue> response, scoped_ptr<base::DictionaryValue> config); - void SendPairedClientsResponse(scoped_ptr<base::DictionaryValue> response, + void SendPairedClientsResponse(const SendResponseCallback& done, + scoped_ptr<base::DictionaryValue> response, scoped_ptr<base::ListValue> pairings); void SendUsageStatsConsentResponse( + const SendResponseCallback& done, scoped_ptr<base::DictionaryValue> response, const DaemonController::UsageStatsConsent& consent); - void SendAsyncResult(scoped_ptr<base::DictionaryValue> response, + void SendAsyncResult(const SendResponseCallback& done, + scoped_ptr<base::DictionaryValue> response, DaemonController::AsyncResult result); - void SendBooleanResult(scoped_ptr<base::DictionaryValue> response, + void SendBooleanResult(const SendResponseCallback& done, + scoped_ptr<base::DictionaryValue> response, bool result); - void SendCredentialsResponse(scoped_ptr<base::DictionaryValue> response, + void SendCredentialsResponse(const SendResponseCallback& done, + scoped_ptr<base::DictionaryValue> response, const std::string& user_email, const std::string& refresh_token); - // Callbacks may be invoked by e.g. DaemonController during destruction, - // which use |weak_ptr_|, so it's important that it be the last member to be - // destroyed. - base::WeakPtr<NativeMessagingHost> weak_ptr_; - - scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner_; - base::Closure quit_closure_; - - NativeMessagingReader native_messaging_reader_; - NativeMessagingWriter native_messaging_writer_; - - // The DaemonController may post tasks to this object during destruction (but - // not afterwards), so it needs to be destroyed before other members of this - // class (except for |weak_factory_|). - scoped_refptr<remoting::DaemonController> daemon_controller_; + scoped_refptr<DaemonController> daemon_controller_; // Used to load and update the paired clients for this host. scoped_refptr<protocol::PairingRegistry> pairing_registry_; @@ -137,13 +141,9 @@ class NativeMessagingHost { // Used to exchange the service account authorization code for credentials. scoped_ptr<OAuthClient> oauth_client_; - // Keeps track of pending requests. Used to delay shutdown until all responses - // have been sent. - int pending_requests_; - - // True if Shutdown() has been called. - bool shutdown_; + base::ThreadChecker thread_checker_; + base::WeakPtr<NativeMessagingHost> weak_ptr_; base::WeakPtrFactory<NativeMessagingHost> weak_factory_; DISALLOW_COPY_AND_ASSIGN(NativeMessagingHost); diff --git a/remoting/host/setup/native_messaging_host_unittest.cc b/remoting/host/setup/native_messaging_host_unittest.cc index 8ca2966..fbba357 100644 --- a/remoting/host/setup/native_messaging_host_unittest.cc +++ b/remoting/host/setup/native_messaging_host_unittest.cc @@ -18,6 +18,7 @@ #include "net/base/net_util.h" #include "remoting/base/auto_thread_task_runner.h" #include "remoting/host/pin_hash.h" +#include "remoting/host/setup/native_messaging_channel.h" #include "remoting/host/setup/test_util.h" #include "remoting/protocol/pairing_registry.h" #include "remoting/protocol/protocol_mock_objects.h" @@ -237,7 +238,8 @@ class NativeMessagingHostTest : public testing::Test { void TestBadRequest(const base::Value& message); protected: - // Reference to the MockDaemonControllerDelegate, which is owned by |host_|. + // Reference to the MockDaemonControllerDelegate, which is owned by + // |channel_|. MockDaemonControllerDelegate* daemon_controller_delegate_; private: @@ -247,14 +249,13 @@ class NativeMessagingHostTest : public testing::Test { // verifies output from output_read_handle. // // unittest -> [input] -> NativeMessagingHost -> [output] -> unittest - base::PlatformFile input_read_handle_; base::PlatformFile input_write_handle_; base::PlatformFile output_read_handle_; - base::PlatformFile output_write_handle_; base::MessageLoop message_loop_; base::RunLoop run_loop_; - scoped_ptr<remoting::NativeMessagingHost> host_; + scoped_refptr<AutoThreadTaskRunner> task_runner_; + scoped_ptr<remoting::NativeMessagingChannel> channel_; DISALLOW_COPY_AND_ASSIGN(NativeMessagingHostTest); }; @@ -265,11 +266,14 @@ NativeMessagingHostTest::NativeMessagingHostTest() NativeMessagingHostTest::~NativeMessagingHostTest() {} void NativeMessagingHostTest::SetUp() { - ASSERT_TRUE(MakePipe(&input_read_handle_, &input_write_handle_)); - ASSERT_TRUE(MakePipe(&output_read_handle_, &output_write_handle_)); + base::PlatformFile input_read_handle; + base::PlatformFile output_write_handle; + + ASSERT_TRUE(MakePipe(&input_read_handle, &input_write_handle_)); + ASSERT_TRUE(MakePipe(&output_read_handle_, &output_write_handle)); // Arrange to run |message_loop_| until no components depend on it. - scoped_refptr<AutoThreadTaskRunner> task_runner = new AutoThreadTaskRunner( + task_runner_ = new AutoThreadTaskRunner( message_loop_.message_loop_proxy(), run_loop_.QuitClosure()); daemon_controller_delegate_ = new MockDaemonControllerDelegate(); @@ -280,15 +284,14 @@ void NativeMessagingHostTest::SetUp() { scoped_refptr<PairingRegistry> pairing_registry = new SynchronousPairingRegistry(scoped_ptr<PairingRegistry::Delegate>( new MockPairingRegistryDelegate())); - - host_.reset(new NativeMessagingHost( - daemon_controller, - pairing_registry, - scoped_ptr<remoting::OAuthClient>(), - input_read_handle_, output_write_handle_, - task_runner, - base::Bind(&NativeMessagingHostTest::DeleteHost, - base::Unretained(this)))); + scoped_ptr<NativeMessagingChannel::Delegate> host( + new NativeMessagingHost(daemon_controller, + pairing_registry, + scoped_ptr<remoting::OAuthClient>())); + channel_.reset( + new NativeMessagingChannel(host.Pass(), + input_read_handle, + output_write_handle)); } void NativeMessagingHostTest::TearDown() { @@ -302,14 +305,16 @@ void NativeMessagingHostTest::Run() { // Close the write-end of input, so that the host sees EOF after reading // messages and won't block waiting for more input. base::ClosePlatformFile(input_write_handle_); - host_->Start(); + channel_->Start(base::Bind(&NativeMessagingHostTest::DeleteHost, + base::Unretained(this))); run_loop_.Run(); } void NativeMessagingHostTest::DeleteHost() { - // Destroy |host_| so that it closes its end of the output pipe, so that + // Destroy |channel_| so that it closes its end of the output pipe, so that // TestBadRequest() will see EOF and won't block waiting for more data. - host_.reset(); + channel_.reset(); + task_runner_ = NULL; } scoped_ptr<base::DictionaryValue> diff --git a/remoting/remoting.gyp b/remoting/remoting.gyp index ed450af..49abcf7 100644 --- a/remoting/remoting.gyp +++ b/remoting/remoting.gyp @@ -589,6 +589,8 @@ 'host/setup/daemon_installer_win.h', 'host/setup/host_starter.cc', 'host/setup/host_starter.h', + 'host/setup/native_messaging_channel.cc', + 'host/setup/native_messaging_channel.h', 'host/setup/native_messaging_host.cc', 'host/setup/native_messaging_host.h', 'host/setup/native_messaging_reader.cc', |