summaryrefslogtreecommitdiffstats
path: root/remoting
diff options
context:
space:
mode:
authoralexeypa@chromium.org <alexeypa@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2013-10-08 08:11:34 +0000
committeralexeypa@chromium.org <alexeypa@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2013-10-08 08:11:34 +0000
commit7459b5a58bce5d23fc79f6308ca5698b069466bf (patch)
tree7120bd82a74117af47bc865b9cd14d3598a9f405 /remoting
parenta13bb0bdb4c12914232f67a6e039ee1ab4b97f86 (diff)
downloadchromium_src-7459b5a58bce5d23fc79f6308ca5698b069466bf.zip
chromium_src-7459b5a58bce5d23fc79f6308ca5698b069466bf.tar.gz
chromium_src-7459b5a58bce5d23fc79f6308ca5698b069466bf.tar.bz2
Moved all channel-handling logic into separate NativeMessagingChannel class.
Changes in this CL: - NativeMessagingHost is used only for actual processing the messages. - NativeMessagingChannel implements receiving and sending messages. - NativeMessagingChannel takes ownership of the passed inout and output handles. - Both NativeMessagingChannel and NativeMessagingHost are explicitly marked as not thread-safe classes. BUG=173509 Review URL: https://codereview.chromium.org/23903021 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@227492 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'remoting')
-rw-r--r--remoting/host/setup/native_messaging_channel.cc129
-rw-r--r--remoting/host/setup/native_messaging_channel.h94
-rw-r--r--remoting/host/setup/native_messaging_host.cc293
-rw-r--r--remoting/host/setup/native_messaging_host.h164
-rw-r--r--remoting/host/setup/native_messaging_host_unittest.cc43
-rw-r--r--remoting/remoting.gyp2
6 files changed, 469 insertions, 256 deletions
diff --git a/remoting/host/setup/native_messaging_channel.cc b/remoting/host/setup/native_messaging_channel.cc
new file mode 100644
index 0000000..b93fd94
--- /dev/null
+++ b/remoting/host/setup/native_messaging_channel.cc
@@ -0,0 +1,129 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "remoting/host/setup/native_messaging_channel.h"
+
+#include "base/basictypes.h"
+#include "base/bind.h"
+#include "base/callback.h"
+#include "base/location.h"
+#include "base/values.h"
+
+#if defined(OS_POSIX)
+#include <unistd.h>
+#endif
+
+namespace {
+
+base::PlatformFile DuplicatePlatformFile(base::PlatformFile handle) {
+ base::PlatformFile result;
+#if defined(OS_WIN)
+ if (!DuplicateHandle(GetCurrentProcess(),
+ handle,
+ GetCurrentProcess(),
+ &result,
+ 0,
+ FALSE,
+ DUPLICATE_CLOSE_SOURCE | DUPLICATE_SAME_ACCESS)) {
+ PLOG(ERROR) << "Failed to duplicate handle " << handle;
+ return base::kInvalidPlatformFileValue;
+ }
+ return result;
+#elif defined(OS_POSIX)
+ result = dup(handle);
+ base::ClosePlatformFile(handle);
+ return result;
+#else
+#error Not implemented.
+#endif
+}
+
+} // namespace
+
+namespace remoting {
+
+NativeMessagingChannel::NativeMessagingChannel(
+ scoped_ptr<Delegate> delegate,
+ base::PlatformFile input,
+ base::PlatformFile output)
+ : native_messaging_reader_(DuplicatePlatformFile(input)),
+ native_messaging_writer_(new NativeMessagingWriter(
+ DuplicatePlatformFile(output))),
+ delegate_(delegate.Pass()),
+ pending_requests_(0),
+ shutdown_(false),
+ weak_factory_(this) {
+ weak_ptr_ = weak_factory_.GetWeakPtr();
+}
+
+NativeMessagingChannel::~NativeMessagingChannel() {
+}
+
+void NativeMessagingChannel::Start(const base::Closure& quit_closure) {
+ DCHECK(CalledOnValidThread());
+ DCHECK(quit_closure_.is_null());
+ DCHECK(!quit_closure.is_null());
+
+ quit_closure_ = quit_closure;
+ native_messaging_reader_.Start(
+ base::Bind(&NativeMessagingChannel::ProcessMessage, weak_ptr_),
+ base::Bind(&NativeMessagingChannel::Shutdown, weak_ptr_));
+}
+
+void NativeMessagingChannel::ProcessMessage(scoped_ptr<base::Value> message) {
+ DCHECK(CalledOnValidThread());
+
+ // Don't process any more messages if Shutdown() has been called.
+ if (shutdown_)
+ return;
+
+ if (message->GetType() != base::Value::TYPE_DICTIONARY) {
+ LOG(ERROR) << "Expected DictionaryValue";
+ Shutdown();
+ return;
+ }
+
+ DCHECK_GE(pending_requests_, 0);
+ pending_requests_++;
+
+ scoped_ptr<base::DictionaryValue> message_dict(
+ static_cast<base::DictionaryValue*>(message.release()));
+ delegate_->ProcessMessage(
+ message_dict.Pass(),
+ base::Bind(&NativeMessagingChannel::SendResponse, weak_ptr_));
+}
+
+void NativeMessagingChannel::SendResponse(
+ scoped_ptr<base::DictionaryValue> response) {
+ DCHECK(CalledOnValidThread());
+
+ bool success = response && native_messaging_writer_;
+ if (success)
+ success = native_messaging_writer_->WriteMessage(*response);
+
+ if (!success) {
+ // Close the write pipe so no more responses will be sent.
+ native_messaging_writer_.reset();
+ Shutdown();
+ }
+
+ pending_requests_--;
+ DCHECK_GE(pending_requests_, 0);
+
+ if (shutdown_ && !pending_requests_)
+ quit_closure_.Run();
+}
+
+void NativeMessagingChannel::Shutdown() {
+ DCHECK(CalledOnValidThread());
+
+ if (shutdown_)
+ return;
+
+ shutdown_ = true;
+ if (!pending_requests_)
+ quit_closure_.Run();
+}
+
+} // namespace remoting
diff --git a/remoting/host/setup/native_messaging_channel.h b/remoting/host/setup/native_messaging_channel.h
new file mode 100644
index 0000000..9b0c229
--- /dev/null
+++ b/remoting/host/setup/native_messaging_channel.h
@@ -0,0 +1,94 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef REMOTING_HOST_SETUP_NATIVE_MESSAGING_CHANNEL_H_
+#define REMOTING_HOST_SETUP_NATIVE_MESSAGING_CHANNEL_H_
+
+#include "base/callback.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/memory/weak_ptr.h"
+#include "base/platform_file.h"
+#include "base/threading/non_thread_safe.h"
+#include "remoting/host/setup/native_messaging_reader.h"
+#include "remoting/host/setup/native_messaging_writer.h"
+
+namespace base {
+class DictionaryValue;
+class Value;
+} // namespace base
+
+namespace remoting {
+
+// Implements reading messages and sending responses across the native messaging
+// host pipe. Delegates processing of received messages to Delegate.
+//
+// TODO(alexeypa): Add ability to switch between different |delegate_| pointers
+// on the fly. This is useful for implementing UAC-style elevation on Windows -
+// an unprivileged delegate could be replaced with another delegate that
+// forwards messages to the elevated instance of the native messaging host.
+class NativeMessagingChannel : public base::NonThreadSafe {
+ public:
+ // Used to sends a response back to the client app.
+ typedef base::Callback<void(scoped_ptr<base::DictionaryValue> response)>
+ SendResponseCallback;
+
+ class Delegate {
+ public:
+ virtual ~Delegate() {}
+
+ // Processes a message received from the client app. Invokes |done| to send
+ // a response back to the client app.
+ virtual void ProcessMessage(scoped_ptr<base::DictionaryValue> message,
+ const SendResponseCallback& done) = 0;
+ };
+
+ // Constructs an object taking the ownership of |input| and |output|. Closes
+ // |input| and |output| to prevent the caller frmo using them.
+ NativeMessagingChannel(
+ scoped_ptr<Delegate> delegate,
+ base::PlatformFile input,
+ base::PlatformFile output);
+ ~NativeMessagingChannel();
+
+ // Starts reading and processing messages.
+ void Start(const base::Closure& quit_closure);
+
+ private:
+ // Processes a message received from the client app.
+ void ProcessMessage(scoped_ptr<base::Value> message);
+
+ // Sends a response back to the client app.
+ void SendResponse(scoped_ptr<base::DictionaryValue> response);
+
+ // Initiates shutdown and runs |quit_closure| if there are no pending requests
+ // left.
+ void Shutdown();
+
+ base::Closure quit_closure_;
+
+ NativeMessagingReader native_messaging_reader_;
+ scoped_ptr<NativeMessagingWriter> native_messaging_writer_;
+
+ // |delegate_| may post tasks to this object during destruction (but not
+ // afterwards), so it needs to be destroyed before other members of this class
+ // (except for |weak_factory_|).
+ scoped_ptr<Delegate> delegate_;
+
+ // Keeps track of pending requests. Used to delay shutdown until all responses
+ // have been sent.
+ int pending_requests_;
+
+ // True if Shutdown() has been called.
+ bool shutdown_;
+
+ base::WeakPtr<NativeMessagingChannel> weak_ptr_;
+ base::WeakPtrFactory<NativeMessagingChannel> weak_factory_;
+
+ DISALLOW_COPY_AND_ASSIGN(NativeMessagingChannel);
+};
+
+} // namespace remoting
+
+#endif // REMOTING_HOST_SETUP_NATIVE_MESSAGING_CHANNEL_H_
diff --git a/remoting/host/setup/native_messaging_host.cc b/remoting/host/setup/native_messaging_host.cc
index 275e9fe..fd2c0c6 100644
--- a/remoting/host/setup/native_messaging_host.cc
+++ b/remoting/host/setup/native_messaging_host.cc
@@ -9,10 +9,11 @@
#include "base/basictypes.h"
#include "base/bind.h"
#include "base/callback.h"
-#include "base/json/json_string_value_serializer.h"
-#include "base/location.h"
+#include "base/command_line.h"
+#include "base/logging.h"
#include "base/message_loop/message_loop.h"
#include "base/run_loop.h"
+#include "base/strings/string_number_conversions.h"
#include "base/strings/stringize_macros.h"
#include "base/values.h"
#include "google_apis/gaia/gaia_oauth_client.h"
@@ -27,22 +28,20 @@
#include "remoting/host/setup/oauth_client.h"
#include "remoting/protocol/pairing_registry.h"
-#if defined(OS_POSIX)
-#include <unistd.h>
-#endif
-
namespace {
+const char kParentWindowSwitchName[] = "parent-window";
+
+// redirect_uri to use when authenticating service accounts (service account
+// codes are obtained "out-of-band", i.e., not through an OAuth redirect).
+const char* kServiceAccountRedirectUri = "oob";
+
// Features supported in addition to the base protocol.
const char* kSupportedFeatures[] = {
"pairingRegistry",
"oauthClient"
};
-// redirect_uri to use when authenticating service accounts (service account
-// codes are obtained "out-of-band", i.e., not through an OAuth redirect).
-const char* kServiceAccountRedirectUri = "oob";
-
// Helper to extract the "config" part of a message as a DictionaryValue.
// Returns NULL on failure, and logs an error message.
scoped_ptr<base::DictionaryValue> ConfigDictionaryFromMessage(
@@ -64,151 +63,108 @@ namespace remoting {
NativeMessagingHost::NativeMessagingHost(
scoped_refptr<DaemonController> daemon_controller,
scoped_refptr<protocol::PairingRegistry> pairing_registry,
- scoped_ptr<OAuthClient> oauth_client,
- base::PlatformFile input,
- base::PlatformFile output,
- scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner,
- const base::Closure& quit_closure)
- : caller_task_runner_(caller_task_runner),
- quit_closure_(quit_closure),
- native_messaging_reader_(input),
- native_messaging_writer_(output),
- daemon_controller_(daemon_controller),
+ scoped_ptr<OAuthClient> oauth_client)
+ : daemon_controller_(daemon_controller),
pairing_registry_(pairing_registry),
oauth_client_(oauth_client.Pass()),
- pending_requests_(0),
- shutdown_(false),
weak_factory_(this) {
weak_ptr_ = weak_factory_.GetWeakPtr();
}
-NativeMessagingHost::~NativeMessagingHost() {}
-
-void NativeMessagingHost::Start() {
- DCHECK(caller_task_runner_->BelongsToCurrentThread());
-
- native_messaging_reader_.Start(
- base::Bind(&NativeMessagingHost::ProcessMessage, weak_ptr_),
- base::Bind(&NativeMessagingHost::Shutdown, weak_ptr_));
+NativeMessagingHost::~NativeMessagingHost() {
}
-void NativeMessagingHost::Shutdown() {
- DCHECK(caller_task_runner_->BelongsToCurrentThread());
-
- if (shutdown_)
- return;
-
- shutdown_ = true;
- if (!pending_requests_)
- caller_task_runner_->PostTask(FROM_HERE, quit_closure_);
-}
-
-void NativeMessagingHost::ProcessMessage(scoped_ptr<base::Value> message) {
- DCHECK(caller_task_runner_->BelongsToCurrentThread());
-
- // Don't process any more messages if Shutdown() has been called.
- if (shutdown_)
- return;
-
- const base::DictionaryValue* message_dict;
- if (!message->GetAsDictionary(&message_dict)) {
- LOG(ERROR) << "Expected DictionaryValue";
- Shutdown();
- return;
- }
-
- scoped_ptr<base::DictionaryValue> response_dict(new base::DictionaryValue());
+void NativeMessagingHost::ProcessMessage(
+ scoped_ptr<base::DictionaryValue> message,
+ const SendResponseCallback& done) {
+ scoped_ptr<base::DictionaryValue> response(new base::DictionaryValue());
// If the client supplies an ID, it will expect it in the response. This
// might be a string or a number, so cope with both.
const base::Value* id;
- if (message_dict->Get("id", &id))
- response_dict->Set("id", id->DeepCopy());
+ if (message->Get("id", &id))
+ response->Set("id", id->DeepCopy());
std::string type;
- if (!message_dict->GetString("type", &type)) {
+ if (!message->GetString("type", &type)) {
LOG(ERROR) << "'type' not found";
- Shutdown();
+ done.Run(scoped_ptr<base::DictionaryValue>());
return;
}
- response_dict->SetString("type", type + "Response");
-
- DCHECK_GE(pending_requests_, 0);
- pending_requests_++;
+ response->SetString("type", type + "Response");
bool success = false;
if (type == "hello") {
- success = ProcessHello(*message_dict, response_dict.Pass());
+ success = ProcessHello(*message, response.Pass(), done);
} else if (type == "clearPairedClients") {
- success = ProcessClearPairedClients(*message_dict, response_dict.Pass());
+ success = ProcessClearPairedClients(*message, response.Pass(), done);
} else if (type == "deletePairedClient") {
- success = ProcessDeletePairedClient(*message_dict, response_dict.Pass());
+ success = ProcessDeletePairedClient(*message, response.Pass(), done);
} else if (type == "getHostName") {
- success = ProcessGetHostName(*message_dict, response_dict.Pass());
+ success = ProcessGetHostName(*message, response.Pass(), done);
} else if (type == "getPinHash") {
- success = ProcessGetPinHash(*message_dict, response_dict.Pass());
+ success = ProcessGetPinHash(*message, response.Pass(), done);
} else if (type == "generateKeyPair") {
- success = ProcessGenerateKeyPair(*message_dict, response_dict.Pass());
+ success = ProcessGenerateKeyPair(*message, response.Pass(), done);
} else if (type == "updateDaemonConfig") {
- success = ProcessUpdateDaemonConfig(*message_dict, response_dict.Pass());
+ success = ProcessUpdateDaemonConfig(*message, response.Pass(), done);
} else if (type == "getDaemonConfig") {
- success = ProcessGetDaemonConfig(*message_dict, response_dict.Pass());
+ success = ProcessGetDaemonConfig(*message, response.Pass(), done);
} else if (type == "getPairedClients") {
- success = ProcessGetPairedClients(*message_dict, response_dict.Pass());
+ success = ProcessGetPairedClients(*message, response.Pass(), done);
} else if (type == "getUsageStatsConsent") {
- success = ProcessGetUsageStatsConsent(*message_dict, response_dict.Pass());
+ success = ProcessGetUsageStatsConsent(*message, response.Pass(), done);
} else if (type == "startDaemon") {
- success = ProcessStartDaemon(*message_dict, response_dict.Pass());
+ success = ProcessStartDaemon(*message, response.Pass(), done);
} else if (type == "stopDaemon") {
- success = ProcessStopDaemon(*message_dict, response_dict.Pass());
+ success = ProcessStopDaemon(*message, response.Pass(), done);
} else if (type == "getDaemonState") {
- success = ProcessGetDaemonState(*message_dict, response_dict.Pass());
+ success = ProcessGetDaemonState(*message, response.Pass(), done);
} else if (type == "getHostClientId") {
- success = ProcessGetHostClientId(*message_dict, response_dict.Pass());
+ success = ProcessGetHostClientId(*message, response.Pass(), done);
} else if (type == "getCredentialsFromAuthCode") {
- success = ProcessGetCredentialsFromAuthCode(
- *message_dict, response_dict.Pass());
+ success = ProcessGetCredentialsFromAuthCode(*message, response.Pass(),
+ done);
} else {
LOG(ERROR) << "Unsupported request type: " << type;
}
- if (!success) {
- pending_requests_--;
- DCHECK_GE(pending_requests_, 0);
-
- Shutdown();
- }
+ if (!success)
+ done.Run(scoped_ptr<base::DictionaryValue>());
}
bool NativeMessagingHost::ProcessHello(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
response->SetString("version", STRINGIZE(VERSION));
scoped_ptr<base::ListValue> supported_features_list(new base::ListValue());
supported_features_list->AppendStrings(std::vector<std::string>(
kSupportedFeatures, kSupportedFeatures + arraysize(kSupportedFeatures)));
response->Set("supportedFeatures", supported_features_list.release());
- SendResponse(response.Pass());
+ done.Run(response.Pass());
return true;
}
bool NativeMessagingHost::ProcessClearPairedClients(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
if (pairing_registry_) {
pairing_registry_->ClearAllPairings(
base::Bind(&NativeMessagingHost::SendBooleanResult, weak_ptr_,
- base::Passed(&response)));
+ done, base::Passed(&response)));
} else {
- SendBooleanResult(response.Pass(), false);
+ SendBooleanResult(done, response.Pass(), false);
}
return true;
}
bool NativeMessagingHost::ProcessDeletePairedClient(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
std::string client_id;
if (!message.GetString(protocol::PairingRegistry::kClientIdKey, &client_id)) {
LOG(ERROR) << "'" << protocol::PairingRegistry::kClientIdKey
@@ -219,24 +175,26 @@ bool NativeMessagingHost::ProcessDeletePairedClient(
if (pairing_registry_) {
pairing_registry_->DeletePairing(
client_id, base::Bind(&NativeMessagingHost::SendBooleanResult,
- weak_ptr_, base::Passed(&response)));
+ weak_ptr_, done, base::Passed(&response)));
} else {
- SendBooleanResult(response.Pass(), false);
+ SendBooleanResult(done, response.Pass(), false);
}
return true;
}
bool NativeMessagingHost::ProcessGetHostName(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
response->SetString("hostname", net::GetHostName());
- SendResponse(response.Pass());
+ done.Run(response.Pass());
return true;
}
bool NativeMessagingHost::ProcessGetPinHash(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
std::string host_id;
if (!message.GetString("hostId", &host_id)) {
LOG(ERROR) << "'hostId' not found: " << message;
@@ -247,24 +205,26 @@ bool NativeMessagingHost::ProcessGetPinHash(
LOG(ERROR) << "'pin' not found: " << message;
return false;
}
- response->SetString("hash", remoting::MakeHostPinHash(host_id, pin));
- SendResponse(response.Pass());
+ response->SetString("hash", MakeHostPinHash(host_id, pin));
+ done.Run(response.Pass());
return true;
}
bool NativeMessagingHost::ProcessGenerateKeyPair(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
scoped_refptr<RsaKeyPair> key_pair = RsaKeyPair::Generate();
response->SetString("privateKey", key_pair->ToString());
response->SetString("publicKey", key_pair->GetPublicKey());
- SendResponse(response.Pass());
+ done.Run(response.Pass());
return true;
}
bool NativeMessagingHost::ProcessUpdateDaemonConfig(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
scoped_ptr<base::DictionaryValue> config_dict =
ConfigDictionaryFromMessage(message);
if (!config_dict)
@@ -273,45 +233,49 @@ bool NativeMessagingHost::ProcessUpdateDaemonConfig(
daemon_controller_->UpdateConfig(
config_dict.Pass(),
base::Bind(&NativeMessagingHost::SendAsyncResult, weak_ptr_,
- base::Passed(&response)));
+ done, base::Passed(&response)));
return true;
}
bool NativeMessagingHost::ProcessGetDaemonConfig(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
daemon_controller_->GetConfig(
base::Bind(&NativeMessagingHost::SendConfigResponse, weak_ptr_,
- base::Passed(&response)));
+ done, base::Passed(&response)));
return true;
}
bool NativeMessagingHost::ProcessGetPairedClients(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
if (pairing_registry_) {
pairing_registry_->GetAllPairings(
base::Bind(&NativeMessagingHost::SendPairedClientsResponse, weak_ptr_,
- base::Passed(&response)));
+ done, base::Passed(&response)));
} else {
scoped_ptr<base::ListValue> no_paired_clients(new base::ListValue);
- SendPairedClientsResponse(response.Pass(), no_paired_clients.Pass());
+ SendPairedClientsResponse(done, response.Pass(), no_paired_clients.Pass());
}
return true;
}
bool NativeMessagingHost::ProcessGetUsageStatsConsent(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
daemon_controller_->GetUsageStatsConsent(
base::Bind(&NativeMessagingHost::SendUsageStatsConsentResponse,
- weak_ptr_, base::Passed(&response)));
+ weak_ptr_, done, base::Passed(&response)));
return true;
}
bool NativeMessagingHost::ProcessStartDaemon(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
bool consent;
if (!message.GetBoolean("consent", &consent)) {
LOG(ERROR) << "'consent' not found.";
@@ -326,22 +290,24 @@ bool NativeMessagingHost::ProcessStartDaemon(
daemon_controller_->SetConfigAndStart(
config_dict.Pass(), consent,
base::Bind(&NativeMessagingHost::SendAsyncResult, weak_ptr_,
- base::Passed(&response)));
+ done, base::Passed(&response)));
return true;
}
bool NativeMessagingHost::ProcessStopDaemon(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
daemon_controller_->Stop(
base::Bind(&NativeMessagingHost::SendAsyncResult, weak_ptr_,
- base::Passed(&response)));
+ done, base::Passed(&response)));
return true;
}
bool NativeMessagingHost::ProcessGetDaemonState(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
DaemonController::State state = daemon_controller_->GetState();
switch (state) {
case DaemonController::STATE_NOT_IMPLEMENTED:
@@ -369,22 +335,24 @@ bool NativeMessagingHost::ProcessGetDaemonState(
response->SetString("state", "UNKNOWN");
break;
}
- SendResponse(response.Pass());
+ done.Run(response.Pass());
return true;
}
bool NativeMessagingHost::ProcessGetHostClientId(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
response->SetString("clientId", google_apis::GetOAuth2ClientID(
google_apis::CLIENT_REMOTING_HOST));
- SendResponse(response.Pass());
+ done.Run(response.Pass());
return true;
}
bool NativeMessagingHost::ProcessGetCredentialsFromAuthCode(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response) {
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done) {
std::string auth_code;
if (!message.GetString("authorizationCode", &auth_code)) {
LOG(ERROR) << "'authorizationCode' string not found.";
@@ -400,31 +368,13 @@ bool NativeMessagingHost::ProcessGetCredentialsFromAuthCode(
oauth_client_->GetCredentialsFromAuthCode(
oauth_client_info, auth_code, base::Bind(
&NativeMessagingHost::SendCredentialsResponse, weak_ptr_,
- base::Passed(&response)));
+ done, base::Passed(&response)));
return true;
}
-void NativeMessagingHost::SendResponse(
- scoped_ptr<base::DictionaryValue> response) {
- if (!caller_task_runner_->BelongsToCurrentThread()) {
- caller_task_runner_->PostTask(
- FROM_HERE, base::Bind(&NativeMessagingHost::SendResponse, weak_ptr_,
- base::Passed(&response)));
- return;
- }
-
- if (!native_messaging_writer_.WriteMessage(*response))
- Shutdown();
-
- pending_requests_--;
- DCHECK_GE(pending_requests_, 0);
-
- if (shutdown_ && !pending_requests_)
- caller_task_runner_->PostTask(FROM_HERE, quit_closure_);
-}
-
void NativeMessagingHost::SendConfigResponse(
+ const SendResponseCallback& done,
scoped_ptr<base::DictionaryValue> response,
scoped_ptr<base::DictionaryValue> config) {
if (config) {
@@ -432,26 +382,29 @@ void NativeMessagingHost::SendConfigResponse(
} else {
response->Set("config", Value::CreateNullValue());
}
- SendResponse(response.Pass());
+ done.Run(response.Pass());
}
void NativeMessagingHost::SendPairedClientsResponse(
+ const SendResponseCallback& done,
scoped_ptr<base::DictionaryValue> response,
scoped_ptr<base::ListValue> pairings) {
response->Set("pairedClients", pairings.release());
- SendResponse(response.Pass());
+ done.Run(response.Pass());
}
void NativeMessagingHost::SendUsageStatsConsentResponse(
+ const SendResponseCallback& done,
scoped_ptr<base::DictionaryValue> response,
const DaemonController::UsageStatsConsent& consent) {
response->SetBoolean("supported", consent.supported);
response->SetBoolean("allowed", consent.allowed);
response->SetBoolean("setByPolicy", consent.set_by_policy);
- SendResponse(response.Pass());
+ done.Run(response.Pass());
}
void NativeMessagingHost::SendAsyncResult(
+ const SendResponseCallback& done,
scoped_ptr<base::DictionaryValue> response,
DaemonController::AsyncResult result) {
switch (result) {
@@ -468,26 +421,32 @@ void NativeMessagingHost::SendAsyncResult(
response->SetString("result", "FAILED_DIRECTORY");
break;
}
- SendResponse(response.Pass());
+ done.Run(response.Pass());
}
void NativeMessagingHost::SendBooleanResult(
+ const SendResponseCallback& done,
scoped_ptr<base::DictionaryValue> response, bool result) {
response->SetBoolean("result", result);
- SendResponse(response.Pass());
+ done.Run(response.Pass());
}
void NativeMessagingHost::SendCredentialsResponse(
+ const SendResponseCallback& done,
scoped_ptr<base::DictionaryValue> response,
const std::string& user_email,
const std::string& refresh_token) {
response->SetString("userEmail", user_email);
response->SetString("refreshToken", refresh_token);
- SendResponse(response.Pass());
+ done.Run(response.Pass());
}
int NativeMessagingHostMain() {
#if defined(OS_WIN)
+ // GetStdHandle() returns pseudo-handles for stdin and stdout even if
+ // the hosting executable specifies "Windows" subsystem. However the returned
+ // handles are invalid in that case unless standard input and output are
+ // redirected to a pipe or file.
base::PlatformFile read_file = GetStdHandle(STD_INPUT_HANDLE);
base::PlatformFile write_file = GetStdHandle(STD_OUTPUT_HANDLE);
#elif defined(OS_POSIX)
@@ -499,23 +458,47 @@ int NativeMessagingHostMain() {
base::MessageLoop message_loop(base::MessageLoop::TYPE_IO);
base::RunLoop run_loop;
+
+ scoped_refptr<DaemonController> daemon_controller =
+ DaemonController::Create();
+
+ // Pass handle of the native view to the controller so that the UAC prompts
+ // are focused properly.
+ const CommandLine* command_line = CommandLine::ForCurrentProcess();
+ if (command_line->HasSwitch(kParentWindowSwitchName)) {
+ std::string native_view =
+ command_line->GetSwitchValueASCII(kParentWindowSwitchName);
+ int64 native_view_handle = 0;
+ if (base::StringToInt64(native_view, &native_view_handle)) {
+ daemon_controller->SetWindow(reinterpret_cast<void*>(native_view_handle));
+ } else {
+ LOG(WARNING) << "Invalid parameter value --" << kParentWindowSwitchName
+ << "=" << native_view;
+ }
+ }
+
// OAuth client (for credential requests).
scoped_refptr<net::URLRequestContextGetter> url_request_context_getter(
- new remoting::URLRequestContextGetter(message_loop.message_loop_proxy()));
- scoped_ptr<remoting::OAuthClient> oauth_client(
- new remoting::OAuthClient(url_request_context_getter));
+ new URLRequestContextGetter(message_loop.message_loop_proxy()));
+ scoped_ptr<OAuthClient> oauth_client(
+ new OAuthClient(url_request_context_getter));
net::URLFetcher::SetIgnoreCertificateRequests(true);
+ // Create the pairing registry and native messaging host.
scoped_refptr<protocol::PairingRegistry> pairing_registry =
CreatePairingRegistry(message_loop.message_loop_proxy());
- remoting::NativeMessagingHost host(remoting::DaemonController::Create(),
- pairing_registry,
- oauth_client.Pass(),
- read_file, write_file,
- message_loop.message_loop_proxy(),
- run_loop.QuitClosure());
- host.Start();
+ scoped_ptr<NativeMessagingChannel::Delegate> host(
+ new NativeMessagingHost(daemon_controller,
+ pairing_registry,
+ oauth_client.Pass()));
+
+ // Set up the native messaging channel.
+ scoped_ptr<NativeMessagingChannel> channel(
+ new NativeMessagingChannel(host.Pass(), read_file, write_file));
+ channel->Start(run_loop.QuitClosure());
+
+ // Run the loop until channel is alive.
run_loop.Run();
return kSuccessExitCode;
}
diff --git a/remoting/host/setup/native_messaging_host.h b/remoting/host/setup/native_messaging_host.h
index eb7250a..4a359d7 100644
--- a/remoting/host/setup/native_messaging_host.h
+++ b/remoting/host/setup/native_messaging_host.h
@@ -5,21 +5,17 @@
#ifndef REMOTING_HOST_SETUP_NATIVE_MESSAGING_HOST_H_
#define REMOTING_HOST_SETUP_NATIVE_MESSAGING_HOST_H_
-#include "base/callback_forward.h"
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "base/memory/weak_ptr.h"
-#include "base/platform_file.h"
+#include "base/threading/thread_checker.h"
#include "remoting/host/setup/daemon_controller.h"
-#include "remoting/host/setup/native_messaging_reader.h"
-#include "remoting/host/setup/native_messaging_writer.h"
+#include "remoting/host/setup/native_messaging_channel.h"
#include "remoting/host/setup/oauth_client.h"
namespace base {
class DictionaryValue;
class ListValue;
-class SingleThreadTaskRunner;
-class Value;
} // namespace base
namespace gaia {
@@ -33,103 +29,111 @@ class PairingRegistry;
} // namespace protocol
// Implementation of the native messaging host process.
-class NativeMessagingHost {
+class NativeMessagingHost : public NativeMessagingChannel::Delegate {
public:
+ typedef NativeMessagingChannel::SendResponseCallback SendResponseCallback;
+
NativeMessagingHost(
scoped_refptr<DaemonController> daemon_controller,
scoped_refptr<protocol::PairingRegistry> pairing_registry,
- scoped_ptr<OAuthClient> oauth_client,
- base::PlatformFile input,
- base::PlatformFile output,
- scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner,
- const base::Closure& quit_closure);
- ~NativeMessagingHost();
-
- // Starts reading and processing messages.
- void Start();
+ scoped_ptr<OAuthClient> oauth_client);
+ virtual ~NativeMessagingHost();
- // Posts |quit_closure| to |caller_task_runner|. This gets called whenever an
- // error is encountered during reading and processing messages.
- void Shutdown();
+ // NativeMessagingChannel::Delegate interface.
+ virtual void ProcessMessage(scoped_ptr<base::DictionaryValue> message,
+ const SendResponseCallback& done) OVERRIDE;
private:
- // Processes a message received from the client app.
- void ProcessMessage(scoped_ptr<base::Value> message);
-
// These "Process.." methods handle specific request types. The |response|
// dictionary is pre-filled by ProcessMessage() with the parts of the
// response already known ("id" and "type" fields).
- bool ProcessHello(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
- bool ProcessClearPairedClients(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
- bool ProcessDeletePairedClient(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
- bool ProcessGetHostName(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
- bool ProcessGetPinHash(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
- bool ProcessGenerateKeyPair(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
- bool ProcessUpdateDaemonConfig(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
- bool ProcessGetDaemonConfig(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
- bool ProcessGetPairedClients(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
- bool ProcessGetUsageStatsConsent(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
- bool ProcessStartDaemon(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
- bool ProcessStopDaemon(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
- bool ProcessGetDaemonState(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
- bool ProcessGetHostClientId(const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
+ bool ProcessHello(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
+ bool ProcessClearPairedClients(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
+ bool ProcessDeletePairedClient(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
+ bool ProcessGetHostName(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
+ bool ProcessGetPinHash(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
+ bool ProcessGenerateKeyPair(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
+ bool ProcessUpdateDaemonConfig(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
+ bool ProcessGetDaemonConfig(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
+ bool ProcessGetPairedClients(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
+ bool ProcessGetUsageStatsConsent(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
+ bool ProcessStartDaemon(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
+ bool ProcessStopDaemon(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
+ bool ProcessGetDaemonState(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
+ bool ProcessGetHostClientId(
+ const base::DictionaryValue& message,
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
bool ProcessGetCredentialsFromAuthCode(
const base::DictionaryValue& message,
- scoped_ptr<base::DictionaryValue> response);
-
- // Sends a response back to the client app. This can be called on either the
- // main message loop or the DaemonController's internal thread, so it
- // PostTask()s to the main thread if necessary.
- void SendResponse(scoped_ptr<base::DictionaryValue> response);
+ scoped_ptr<base::DictionaryValue> response,
+ const SendResponseCallback& done);
// These Send... methods get called on the DaemonController's internal thread,
// or on the calling thread if called by the PairingRegistry.
// These methods fill in the |response| dictionary from the other parameters,
// and pass it to SendResponse().
- void SendConfigResponse(scoped_ptr<base::DictionaryValue> response,
+ void SendConfigResponse(const SendResponseCallback& done,
+ scoped_ptr<base::DictionaryValue> response,
scoped_ptr<base::DictionaryValue> config);
- void SendPairedClientsResponse(scoped_ptr<base::DictionaryValue> response,
+ void SendPairedClientsResponse(const SendResponseCallback& done,
+ scoped_ptr<base::DictionaryValue> response,
scoped_ptr<base::ListValue> pairings);
void SendUsageStatsConsentResponse(
+ const SendResponseCallback& done,
scoped_ptr<base::DictionaryValue> response,
const DaemonController::UsageStatsConsent& consent);
- void SendAsyncResult(scoped_ptr<base::DictionaryValue> response,
+ void SendAsyncResult(const SendResponseCallback& done,
+ scoped_ptr<base::DictionaryValue> response,
DaemonController::AsyncResult result);
- void SendBooleanResult(scoped_ptr<base::DictionaryValue> response,
+ void SendBooleanResult(const SendResponseCallback& done,
+ scoped_ptr<base::DictionaryValue> response,
bool result);
- void SendCredentialsResponse(scoped_ptr<base::DictionaryValue> response,
+ void SendCredentialsResponse(const SendResponseCallback& done,
+ scoped_ptr<base::DictionaryValue> response,
const std::string& user_email,
const std::string& refresh_token);
- // Callbacks may be invoked by e.g. DaemonController during destruction,
- // which use |weak_ptr_|, so it's important that it be the last member to be
- // destroyed.
- base::WeakPtr<NativeMessagingHost> weak_ptr_;
-
- scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner_;
- base::Closure quit_closure_;
-
- NativeMessagingReader native_messaging_reader_;
- NativeMessagingWriter native_messaging_writer_;
-
- // The DaemonController may post tasks to this object during destruction (but
- // not afterwards), so it needs to be destroyed before other members of this
- // class (except for |weak_factory_|).
- scoped_refptr<remoting::DaemonController> daemon_controller_;
+ scoped_refptr<DaemonController> daemon_controller_;
// Used to load and update the paired clients for this host.
scoped_refptr<protocol::PairingRegistry> pairing_registry_;
@@ -137,13 +141,9 @@ class NativeMessagingHost {
// Used to exchange the service account authorization code for credentials.
scoped_ptr<OAuthClient> oauth_client_;
- // Keeps track of pending requests. Used to delay shutdown until all responses
- // have been sent.
- int pending_requests_;
-
- // True if Shutdown() has been called.
- bool shutdown_;
+ base::ThreadChecker thread_checker_;
+ base::WeakPtr<NativeMessagingHost> weak_ptr_;
base::WeakPtrFactory<NativeMessagingHost> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(NativeMessagingHost);
diff --git a/remoting/host/setup/native_messaging_host_unittest.cc b/remoting/host/setup/native_messaging_host_unittest.cc
index 8ca2966..fbba357 100644
--- a/remoting/host/setup/native_messaging_host_unittest.cc
+++ b/remoting/host/setup/native_messaging_host_unittest.cc
@@ -18,6 +18,7 @@
#include "net/base/net_util.h"
#include "remoting/base/auto_thread_task_runner.h"
#include "remoting/host/pin_hash.h"
+#include "remoting/host/setup/native_messaging_channel.h"
#include "remoting/host/setup/test_util.h"
#include "remoting/protocol/pairing_registry.h"
#include "remoting/protocol/protocol_mock_objects.h"
@@ -237,7 +238,8 @@ class NativeMessagingHostTest : public testing::Test {
void TestBadRequest(const base::Value& message);
protected:
- // Reference to the MockDaemonControllerDelegate, which is owned by |host_|.
+ // Reference to the MockDaemonControllerDelegate, which is owned by
+ // |channel_|.
MockDaemonControllerDelegate* daemon_controller_delegate_;
private:
@@ -247,14 +249,13 @@ class NativeMessagingHostTest : public testing::Test {
// verifies output from output_read_handle.
//
// unittest -> [input] -> NativeMessagingHost -> [output] -> unittest
- base::PlatformFile input_read_handle_;
base::PlatformFile input_write_handle_;
base::PlatformFile output_read_handle_;
- base::PlatformFile output_write_handle_;
base::MessageLoop message_loop_;
base::RunLoop run_loop_;
- scoped_ptr<remoting::NativeMessagingHost> host_;
+ scoped_refptr<AutoThreadTaskRunner> task_runner_;
+ scoped_ptr<remoting::NativeMessagingChannel> channel_;
DISALLOW_COPY_AND_ASSIGN(NativeMessagingHostTest);
};
@@ -265,11 +266,14 @@ NativeMessagingHostTest::NativeMessagingHostTest()
NativeMessagingHostTest::~NativeMessagingHostTest() {}
void NativeMessagingHostTest::SetUp() {
- ASSERT_TRUE(MakePipe(&input_read_handle_, &input_write_handle_));
- ASSERT_TRUE(MakePipe(&output_read_handle_, &output_write_handle_));
+ base::PlatformFile input_read_handle;
+ base::PlatformFile output_write_handle;
+
+ ASSERT_TRUE(MakePipe(&input_read_handle, &input_write_handle_));
+ ASSERT_TRUE(MakePipe(&output_read_handle_, &output_write_handle));
// Arrange to run |message_loop_| until no components depend on it.
- scoped_refptr<AutoThreadTaskRunner> task_runner = new AutoThreadTaskRunner(
+ task_runner_ = new AutoThreadTaskRunner(
message_loop_.message_loop_proxy(), run_loop_.QuitClosure());
daemon_controller_delegate_ = new MockDaemonControllerDelegate();
@@ -280,15 +284,14 @@ void NativeMessagingHostTest::SetUp() {
scoped_refptr<PairingRegistry> pairing_registry =
new SynchronousPairingRegistry(scoped_ptr<PairingRegistry::Delegate>(
new MockPairingRegistryDelegate()));
-
- host_.reset(new NativeMessagingHost(
- daemon_controller,
- pairing_registry,
- scoped_ptr<remoting::OAuthClient>(),
- input_read_handle_, output_write_handle_,
- task_runner,
- base::Bind(&NativeMessagingHostTest::DeleteHost,
- base::Unretained(this))));
+ scoped_ptr<NativeMessagingChannel::Delegate> host(
+ new NativeMessagingHost(daemon_controller,
+ pairing_registry,
+ scoped_ptr<remoting::OAuthClient>()));
+ channel_.reset(
+ new NativeMessagingChannel(host.Pass(),
+ input_read_handle,
+ output_write_handle));
}
void NativeMessagingHostTest::TearDown() {
@@ -302,14 +305,16 @@ void NativeMessagingHostTest::Run() {
// Close the write-end of input, so that the host sees EOF after reading
// messages and won't block waiting for more input.
base::ClosePlatformFile(input_write_handle_);
- host_->Start();
+ channel_->Start(base::Bind(&NativeMessagingHostTest::DeleteHost,
+ base::Unretained(this)));
run_loop_.Run();
}
void NativeMessagingHostTest::DeleteHost() {
- // Destroy |host_| so that it closes its end of the output pipe, so that
+ // Destroy |channel_| so that it closes its end of the output pipe, so that
// TestBadRequest() will see EOF and won't block waiting for more data.
- host_.reset();
+ channel_.reset();
+ task_runner_ = NULL;
}
scoped_ptr<base::DictionaryValue>
diff --git a/remoting/remoting.gyp b/remoting/remoting.gyp
index ed450af..49abcf7 100644
--- a/remoting/remoting.gyp
+++ b/remoting/remoting.gyp
@@ -589,6 +589,8 @@
'host/setup/daemon_installer_win.h',
'host/setup/host_starter.cc',
'host/setup/host_starter.h',
+ 'host/setup/native_messaging_channel.cc',
+ 'host/setup/native_messaging_channel.h',
'host/setup/native_messaging_host.cc',
'host/setup/native_messaging_host.h',
'host/setup/native_messaging_reader.cc',