summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--components/components_tests.gyp1
-rw-r--r--components/proximity_auth.gypi2
-rw-r--r--components/proximity_auth/BUILD.gn3
-rw-r--r--components/proximity_auth/ble/fake_wire_message.cc3
-rw-r--r--components/proximity_auth/client_impl.cc151
-rw-r--r--components/proximity_auth/client_impl.h9
-rw-r--r--components/proximity_auth/client_impl_unittest.cc13
-rw-r--r--components/proximity_auth/connection_unittest.cc2
-rw-r--r--components/proximity_auth/device_to_device_secure_context.cc119
-rw-r--r--components/proximity_auth/device_to_device_secure_context.h77
-rw-r--r--components/proximity_auth/device_to_device_secure_context_unittest.cc113
-rw-r--r--components/proximity_auth/secure_context.h19
-rw-r--r--components/proximity_auth/wire_message.cc12
-rw-r--r--components/proximity_auth/wire_message.h20
-rw-r--r--components/proximity_auth/wire_message_unittest.cc6
15 files changed, 453 insertions, 97 deletions
diff --git a/components/components_tests.gyp b/components/components_tests.gyp
index 09d599f..f8bc6fb 100644
--- a/components/components_tests.gyp
+++ b/components/components_tests.gyp
@@ -472,6 +472,7 @@
'proximity_auth/cryptauth/fake_secure_message_delegate_unittest.cc',
'proximity_auth/cryptauth/sync_scheduler_impl_unittest.cc',
'proximity_auth/device_to_device_operations_unittest.cc',
+ 'proximity_auth/device_to_device_secure_context_unittest.cc',
'proximity_auth/logging/logging_unittest.cc',
'proximity_auth/proximity_auth_system_unittest.cc',
'proximity_auth/proximity_monitor_impl_unittest.cc',
diff --git a/components/proximity_auth.gypi b/components/proximity_auth.gypi
index 2f21b75..74b07d8 100644
--- a/components/proximity_auth.gypi
+++ b/components/proximity_auth.gypi
@@ -56,6 +56,8 @@
"proximity_auth/connection_observer.h",
"proximity_auth/device_to_device_initiator_operations.cc",
"proximity_auth/device_to_device_initiator_operations.h",
+ "proximity_auth/device_to_device_secure_context.cc",
+ "proximity_auth/device_to_device_secure_context.h",
"proximity_auth/metrics.cc",
"proximity_auth/metrics.h",
"proximity_auth/proximity_auth_client.h",
diff --git a/components/proximity_auth/BUILD.gn b/components/proximity_auth/BUILD.gn
index a85fb9a..d064756 100644
--- a/components/proximity_auth/BUILD.gn
+++ b/components/proximity_auth/BUILD.gn
@@ -26,6 +26,8 @@ source_set("proximity_auth") {
"connection_observer.h",
"device_to_device_initiator_operations.cc",
"device_to_device_initiator_operations.h",
+ "device_to_device_secure_context.cc",
+ "device_to_device_secure_context.h",
"metrics.cc",
"metrics.h",
"proximity_auth_client.h",
@@ -83,6 +85,7 @@ source_set("unit_tests") {
"client_impl_unittest.cc",
"connection_unittest.cc",
"device_to_device_operations_unittest.cc",
+ "device_to_device_secure_context_unittest.cc",
"proximity_auth_system_unittest.cc",
"proximity_monitor_impl_unittest.cc",
"remote_status_update_unittest.cc",
diff --git a/components/proximity_auth/ble/fake_wire_message.cc b/components/proximity_auth/ble/fake_wire_message.cc
index 193ac04..0997b3f 100644
--- a/components/proximity_auth/ble/fake_wire_message.cc
+++ b/components/proximity_auth/ble/fake_wire_message.cc
@@ -12,8 +12,7 @@
namespace proximity_auth {
FakeWireMessage::FakeWireMessage(const std::string& payload)
- : WireMessage("", payload) {
-}
+ : WireMessage(payload) {}
scoped_ptr<FakeWireMessage> FakeWireMessage::Deserialize(
const std::string& serialized_message,
diff --git a/components/proximity_auth/client_impl.cc b/components/proximity_auth/client_impl.cc
index fc7d6e4b..7a96f50 100644
--- a/components/proximity_auth/client_impl.cc
+++ b/components/proximity_auth/client_impl.cc
@@ -4,12 +4,14 @@
#include "components/proximity_auth/client_impl.h"
+#include "base/bind.h"
#include "base/json/json_reader.h"
#include "base/json/json_writer.h"
#include "base/values.h"
#include "components/proximity_auth/client_observer.h"
#include "components/proximity_auth/connection.h"
#include "components/proximity_auth/cryptauth/base64url.h"
+#include "components/proximity_auth/logging/logging.h"
#include "components/proximity_auth/remote_status_update.h"
#include "components/proximity_auth/secure_context.h"
#include "components/proximity_auth/wire_message.h"
@@ -54,7 +56,9 @@ std::string GetMessageType(const base::DictionaryValue& message) {
ClientImpl::ClientImpl(scoped_ptr<Connection> connection,
scoped_ptr<SecureContext> secure_context)
- : connection_(connection.Pass()), secure_context_(secure_context.Pass()) {
+ : connection_(connection.Pass()),
+ secure_context_(secure_context.Pass()),
+ weak_ptr_factory_(this) {
DCHECK(connection_->IsConnected());
connection_->AddObserver(this);
}
@@ -87,8 +91,8 @@ void ClientImpl::DispatchUnlockEvent() {
void ClientImpl::RequestDecryption(const std::string& challenge) {
if (!SupportsSignIn()) {
- VLOG(1) << "[Client] Dropping decryption request, as remote device "
- << "does not support protocol v3.1.";
+ PA_LOG(WARNING) << "Dropping decryption request, as remote device "
+ << "does not support protocol v3.1.";
FOR_EACH_OBSERVER(ClientObserver, observers_,
OnDecryptResponse(scoped_ptr<std::string>()));
return;
@@ -108,8 +112,8 @@ void ClientImpl::RequestDecryption(const std::string& challenge) {
void ClientImpl::RequestUnlock() {
if (!SupportsSignIn()) {
- VLOG(1) << "[Client] Dropping unlock request, as remote device does not "
- << "support protocol v3.1.";
+ PA_LOG(WARNING) << "Dropping unlock request, as remote device does not "
+ << "support protocol v3.1.";
FOR_EACH_OBSERVER(ClientObserver, observers_, OnUnlockResponse(false));
return;
}
@@ -139,65 +143,21 @@ void ClientImpl::ProcessMessageQueue() {
pending_message_.reset(new PendingMessage(queued_messages_.front()));
queued_messages_.pop_front();
- connection_->SendMessage(make_scoped_ptr(new WireMessage(
- std::string(), secure_context_->Encode(pending_message_->json_message))));
+ secure_context_->Encode(pending_message_->json_message,
+ base::Bind(&ClientImpl::OnMessageEncoded,
+ weak_ptr_factory_.GetWeakPtr()));
}
-void ClientImpl::HandleRemoteStatusUpdateMessage(
- const base::DictionaryValue& message) {
- scoped_ptr<RemoteStatusUpdate> status_update =
- RemoteStatusUpdate::Deserialize(message);
- if (!status_update) {
- VLOG(1) << "[Client] Unexpected remote status update: " << message;
- return;
- }
-
- FOR_EACH_OBSERVER(ClientObserver, observers_,
- OnRemoteStatusUpdate(*status_update));
+void ClientImpl::OnMessageEncoded(const std::string& encoded_message) {
+ connection_->SendMessage(make_scoped_ptr(new WireMessage(encoded_message)));
}
-void ClientImpl::HandleDecryptResponseMessage(
- const base::DictionaryValue& message) {
- std::string base64_data;
- std::string decrypted_data;
- scoped_ptr<std::string> response;
- if (!message.GetString(kDataKey, &base64_data) || base64_data.empty()) {
- VLOG(1) << "[Client] Decrypt response missing '" << kDataKey << "' value.";
- } else if (!Base64UrlDecode(base64_data, &decrypted_data)) {
- VLOG(1) << "[Client] Unable to base64-decode decrypt response.";
- } else {
- response.reset(new std::string(decrypted_data));
- }
- FOR_EACH_OBSERVER(ClientObserver, observers_,
- OnDecryptResponse(response.Pass()));
-}
-
-void ClientImpl::HandleUnlockResponseMessage(
- const base::DictionaryValue& message) {
- FOR_EACH_OBSERVER(ClientObserver, observers_, OnUnlockResponse(true));
-}
-
-void ClientImpl::OnConnectionStatusChanged(Connection* connection,
- Connection::Status old_status,
- Connection::Status new_status) {
- DCHECK_EQ(connection, connection_.get());
- if (new_status != Connection::CONNECTED) {
- VLOG(1) << "[Client] Secure channel disconnected...";
- connection_->RemoveObserver(this);
- connection_.reset();
- FOR_EACH_OBSERVER(ClientObserver, observers_, OnDisconnected());
- // TODO(isherman): Determine whether it's also necessary/appropriate to fire
- // this notification from the destructor.
- }
-}
-
-void ClientImpl::OnMessageReceived(const Connection& connection,
- const WireMessage& wire_message) {
- std::string json_message = secure_context_->Decode(wire_message.payload());
- scoped_ptr<base::Value> message_value = base::JSONReader::Read(json_message);
+void ClientImpl::OnMessageDecoded(const std::string& decoded_message) {
+ // The decoded message should be a JSON string.
+ scoped_ptr<base::Value> message_value =
+ base::JSONReader::Read(decoded_message);
if (!message_value || !message_value->IsType(base::Value::TYPE_DICTIONARY)) {
- VLOG(1) << "[Client] Unable to parse message as JSON: " << json_message
- << ".";
+ PA_LOG(ERROR) << "Unable to parse message as JSON:\n" << decoded_message;
return;
}
@@ -207,8 +167,8 @@ void ClientImpl::OnMessageReceived(const Connection& connection,
std::string type;
if (!message->GetString(kTypeKey, &type)) {
- VLOG(1) << "[Client] Missing '" << kTypeKey
- << "' key in message: " << json_message << ".";
+ PA_LOG(ERROR) << "Missing '" << kTypeKey << "' key in message:\n "
+ << decoded_message;
return;
}
@@ -221,7 +181,7 @@ void ClientImpl::OnMessageReceived(const Connection& connection,
// All other messages should only be received in response to a message that
// the client sent.
if (!pending_message_) {
- VLOG(1) << "[Client] Unexpected message received: " << json_message;
+ PA_LOG(WARNING) << "Unexpected message received:\n" << decoded_message;
return;
}
@@ -234,9 +194,9 @@ void ClientImpl::OnMessageReceived(const Connection& connection,
NOTREACHED(); // There are no other message types that expect a response.
if (type != expected_type) {
- VLOG(1) << "[Client] Unexpected '" << kTypeKey << "' value in message. "
- << "Expected '" << expected_type << "' but received '" << type
- << "'.";
+ PA_LOG(ERROR) << "Unexpected '" << kTypeKey << "' value in message. "
+ << "Expected '" << expected_type << "' but received '" << type
+ << "'.";
return;
}
@@ -251,11 +211,66 @@ void ClientImpl::OnMessageReceived(const Connection& connection,
ProcessMessageQueue();
}
+void ClientImpl::HandleRemoteStatusUpdateMessage(
+ const base::DictionaryValue& message) {
+ scoped_ptr<RemoteStatusUpdate> status_update =
+ RemoteStatusUpdate::Deserialize(message);
+ if (!status_update) {
+ PA_LOG(ERROR) << "Unexpected remote status update: " << message;
+ return;
+ }
+
+ FOR_EACH_OBSERVER(ClientObserver, observers_,
+ OnRemoteStatusUpdate(*status_update));
+}
+
+void ClientImpl::HandleDecryptResponseMessage(
+ const base::DictionaryValue& message) {
+ std::string base64_data;
+ std::string decrypted_data;
+ scoped_ptr<std::string> response;
+ if (!message.GetString(kDataKey, &base64_data) || base64_data.empty()) {
+ PA_LOG(ERROR) << "Decrypt response missing '" << kDataKey << "' value.";
+ } else if (!Base64UrlDecode(base64_data, &decrypted_data)) {
+ PA_LOG(ERROR) << "Unable to base64-decode decrypt response.";
+ } else {
+ response.reset(new std::string(decrypted_data));
+ }
+ FOR_EACH_OBSERVER(ClientObserver, observers_,
+ OnDecryptResponse(response.Pass()));
+}
+
+void ClientImpl::HandleUnlockResponseMessage(
+ const base::DictionaryValue& message) {
+ FOR_EACH_OBSERVER(ClientObserver, observers_, OnUnlockResponse(true));
+}
+
+void ClientImpl::OnConnectionStatusChanged(Connection* connection,
+ Connection::Status old_status,
+ Connection::Status new_status) {
+ DCHECK_EQ(connection, connection_.get());
+ if (new_status == Connection::DISCONNECTED) {
+ PA_LOG(INFO) << "Secure channel disconnected...";
+ connection_->RemoveObserver(this);
+ connection_.reset();
+ FOR_EACH_OBSERVER(ClientObserver, observers_, OnDisconnected());
+ // TODO(isherman): Determine whether it's also necessary/appropriate to fire
+ // this notification from the destructor.
+ }
+}
+
+void ClientImpl::OnMessageReceived(const Connection& connection,
+ const WireMessage& wire_message) {
+ secure_context_->Decode(wire_message.payload(),
+ base::Bind(&ClientImpl::OnMessageDecoded,
+ weak_ptr_factory_.GetWeakPtr()));
+}
+
void ClientImpl::OnSendCompleted(const Connection& connection,
const WireMessage& wire_message,
bool success) {
if (!pending_message_) {
- VLOG(1) << "[Client] Unexpected message sent.";
+ PA_LOG(ERROR) << "Unexpected message sent.";
return;
}
@@ -277,8 +292,8 @@ void ClientImpl::OnSendCompleted(const Connection& connection,
} else if (pending_message_->type == kMessageTypeLocalEvent) {
FOR_EACH_OBSERVER(ClientObserver, observers_, OnUnlockEventSent(success));
} else {
- VLOG(1) << "[Client] Message of unknown type '" << pending_message_->type
- << "sent.";
+ PA_LOG(ERROR) << "Message of unknown type '" << pending_message_->type
+ << "' sent.";
}
pending_message_.reset();
diff --git a/components/proximity_auth/client_impl.h b/components/proximity_auth/client_impl.h
index b4d8dcb..d046092 100644
--- a/components/proximity_auth/client_impl.h
+++ b/components/proximity_auth/client_impl.h
@@ -9,6 +9,7 @@
#include "base/macros.h"
#include "base/memory/scoped_ptr.h"
+#include "base/memory/weak_ptr.h"
#include "base/observer_list.h"
#include "components/proximity_auth/client.h"
#include "components/proximity_auth/connection_observer.h"
@@ -65,6 +66,12 @@ class ClientImpl : public Client, public ConnectionObserver {
// Pops the first of the |queued_messages_| and sends it to the remote device.
void ProcessMessageQueue();
+ // Called when the message is encoded so it can be sent over the connection.
+ void OnMessageEncoded(const std::string& encoded_message);
+
+ // Called when the message is decoded so it can be parsed.
+ void OnMessageDecoded(const std::string& decoded_message);
+
// Handles an incoming "status_update" |message|, parsing and notifying
// observers of the content.
void HandleRemoteStatusUpdateMessage(const base::DictionaryValue& message);
@@ -104,6 +111,8 @@ class ClientImpl : public Client, public ConnectionObserver {
// response. Null if there is no message currently in this state.
scoped_ptr<PendingMessage> pending_message_;
+ base::WeakPtrFactory<ClientImpl> weak_ptr_factory_;
+
DISALLOW_COPY_AND_ASSIGN(ClientImpl);
};
diff --git a/components/proximity_auth/client_impl_unittest.cc b/components/proximity_auth/client_impl_unittest.cc
index f519e01..fc75805 100644
--- a/components/proximity_auth/client_impl_unittest.cc
+++ b/components/proximity_auth/client_impl_unittest.cc
@@ -4,6 +4,7 @@
#include "components/proximity_auth/client_impl.h"
+#include "base/callback.h"
#include "base/macros.h"
#include "base/memory/scoped_ptr.h"
#include "components/proximity_auth/client_observer.h"
@@ -44,15 +45,17 @@ class MockSecureContext : public SecureContext {
MOCK_CONST_METHOD0(GetReceivedAuthMessage, std::string());
MOCK_CONST_METHOD0(GetProtocolVersion, ProtocolVersion());
- std::string Encode(const std::string& message) override {
- return message + kFakeEncodingSuffix;
+ void Encode(const std::string& message,
+ const MessageCallback& callback) override {
+ callback.Run(message + kFakeEncodingSuffix);
}
- std::string Decode(const std::string& encoded_message) override {
+ void Decode(const std::string& encoded_message,
+ const MessageCallback& callback) override {
EXPECT_THAT(encoded_message, EndsWith(kFakeEncodingSuffix));
std::string decoded_message = encoded_message;
decoded_message.erase(decoded_message.rfind(kFakeEncodingSuffix));
- return decoded_message;
+ callback.Run(decoded_message);
}
private:
@@ -94,7 +97,7 @@ class FakeConnection : public Connection {
scoped_ptr<WireMessage> DeserializeWireMessage(
bool* is_incomplete_message) override {
*is_incomplete_message = false;
- return make_scoped_ptr(new WireMessage(std::string(), pending_payload_));
+ return make_scoped_ptr(new WireMessage(pending_payload_));
}
WireMessage* current_message() { return current_message_.get(); }
diff --git a/components/proximity_auth/connection_unittest.cc b/components/proximity_auth/connection_unittest.cc
index 1e298cf..76cfe54 100644
--- a/components/proximity_auth/connection_unittest.cc
+++ b/components/proximity_auth/connection_unittest.cc
@@ -76,7 +76,7 @@ class MockConnectionObserver : public ConnectionObserver {
// Unlike WireMessage, offers a public constructor.
class TestWireMessage : public WireMessage {
public:
- TestWireMessage() : WireMessage(std::string(), std::string()) {}
+ TestWireMessage() : WireMessage(std::string()) {}
~TestWireMessage() override {}
private:
diff --git a/components/proximity_auth/device_to_device_secure_context.cc b/components/proximity_auth/device_to_device_secure_context.cc
new file mode 100644
index 0000000..eabc451
--- /dev/null
+++ b/components/proximity_auth/device_to_device_secure_context.cc
@@ -0,0 +1,119 @@
+// Copyright 2015 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/proximity_auth/device_to_device_secure_context.h"
+
+#include "base/bind.h"
+#include "base/callback.h"
+#include "components/proximity_auth/cryptauth/proto/cryptauth_api.pb.h"
+#include "components/proximity_auth/cryptauth/proto/securemessage.pb.h"
+#include "components/proximity_auth/cryptauth/secure_message_delegate.h"
+#include "components/proximity_auth/logging/logging.h"
+
+namespace proximity_auth {
+
+namespace {
+
+// The version to put in the GcmMetadata field.
+const int kGcmMetadataVersion = 1;
+
+// The sequence number of the last message used during authentication. These
+// messages are sent and received before the SecureContext is created.
+const int kAuthenticationSequenceNumber = 2;
+
+} // namespace
+
+DeviceToDeviceSecureContext::DeviceToDeviceSecureContext(
+ scoped_ptr<SecureMessageDelegate> secure_message_delegate,
+ const std::string& symmetric_key,
+ const std::string& responder_auth_message,
+ ProtocolVersion protocol_version)
+ : secure_message_delegate_(secure_message_delegate.Pass()),
+ symmetric_key_(symmetric_key),
+ responder_auth_message_(responder_auth_message),
+ protocol_version_(protocol_version),
+ last_sequence_number_(kAuthenticationSequenceNumber),
+ weak_ptr_factory_(this) {}
+
+DeviceToDeviceSecureContext::~DeviceToDeviceSecureContext() {}
+
+void DeviceToDeviceSecureContext::Decode(const std::string& encoded_message,
+ const MessageCallback& callback) {
+ SecureMessageDelegate::UnwrapOptions unwrap_options;
+ unwrap_options.encryption_scheme = securemessage::AES_256_CBC;
+ unwrap_options.signature_scheme = securemessage::HMAC_SHA256;
+
+ secure_message_delegate_->UnwrapSecureMessage(
+ encoded_message, symmetric_key_, unwrap_options,
+ base::Bind(&DeviceToDeviceSecureContext::HandleUnwrapResult,
+ weak_ptr_factory_.GetWeakPtr(), callback));
+}
+
+void DeviceToDeviceSecureContext::Encode(const std::string& message,
+ const MessageCallback& callback) {
+ // Create a GcmMetadata field to put in the header.
+ cryptauth::GcmMetadata gcm_metadata;
+ gcm_metadata.set_type(cryptauth::DEVICE_TO_DEVICE_MESSAGE);
+ gcm_metadata.set_version(kGcmMetadataVersion);
+
+ // Wrap |message| inside a DeviceToDeviceMessage proto.
+ securemessage::DeviceToDeviceMessage device_to_device_message;
+ device_to_device_message.set_sequence_number(++last_sequence_number_);
+ device_to_device_message.set_message(message);
+
+ SecureMessageDelegate::CreateOptions create_options;
+ create_options.encryption_scheme = securemessage::AES_256_CBC;
+ create_options.signature_scheme = securemessage::HMAC_SHA256;
+ gcm_metadata.SerializeToString(&create_options.public_metadata);
+
+ secure_message_delegate_->CreateSecureMessage(
+ device_to_device_message.SerializeAsString(), symmetric_key_,
+ create_options, callback);
+}
+
+std::string DeviceToDeviceSecureContext::GetReceivedAuthMessage() const {
+ return responder_auth_message_;
+}
+
+SecureContext::ProtocolVersion DeviceToDeviceSecureContext::GetProtocolVersion()
+ const {
+ return protocol_version_;
+}
+
+void DeviceToDeviceSecureContext::HandleUnwrapResult(
+ const DeviceToDeviceSecureContext::MessageCallback& callback,
+ bool verified,
+ const std::string& payload,
+ const securemessage::Header& header) {
+ // The payload should contain a DeviceToDeviceMessage proto.
+ securemessage::DeviceToDeviceMessage device_to_device_message;
+ if (!verified || !device_to_device_message.ParseFromString(payload)) {
+ PA_LOG(ERROR) << "Failed to unwrap secure message.";
+ callback.Run(std::string());
+ return;
+ }
+
+ // Check that the sequence number matches the expected sequence number.
+ if (device_to_device_message.sequence_number() != last_sequence_number_ + 1) {
+ PA_LOG(ERROR) << "Expected sequence_number=" << last_sequence_number_ + 1
+ << ", but got " << device_to_device_message.sequence_number();
+ callback.Run(std::string());
+ return;
+ }
+
+ // Validate the GcmMetadata proto in the header.
+ cryptauth::GcmMetadata gcm_metadata;
+ if (!gcm_metadata.ParseFromString(header.public_metadata()) ||
+ gcm_metadata.type() != cryptauth::DEVICE_TO_DEVICE_MESSAGE ||
+ gcm_metadata.version() != kGcmMetadataVersion) {
+ PA_LOG(ERROR) << "Failed to validate GcmMetadata.";
+ callback.Run(std::string());
+ return;
+ }
+
+ last_sequence_number_++;
+ callback.Run(device_to_device_message.message());
+}
+
+} // proximity_auth
diff --git a/components/proximity_auth/device_to_device_secure_context.h b/components/proximity_auth/device_to_device_secure_context.h
new file mode 100644
index 0000000..62e3296
--- /dev/null
+++ b/components/proximity_auth/device_to_device_secure_context.h
@@ -0,0 +1,77 @@
+// Copyright 2015 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_PROXIMITY_AUTH_DEVICE_TO_DEVICE_SECURE_CONTEXT_H
+#define COMPONENTS_PROXIMITY_AUTH_DEVICE_TO_DEVICE_SECURE_CONTEXT_H
+
+#include "base/macros.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/memory/weak_ptr.h"
+#include "components/proximity_auth/secure_context.h"
+
+namespace securemessage {
+class Header;
+}
+
+namespace proximity_auth {
+
+class SecureMessageDelegate;
+
+// SecureContext implementation for the DeviceToDevice protocol.
+class DeviceToDeviceSecureContext : public SecureContext {
+ public:
+ DeviceToDeviceSecureContext(
+ scoped_ptr<SecureMessageDelegate> secure_message_delegate,
+ const std::string& symmetric_key,
+ const std::string& responder_auth_message_,
+ ProtocolVersion protocol_version);
+
+ ~DeviceToDeviceSecureContext() override;
+
+ // SecureContext:
+ void Decode(const std::string& encoded_message,
+ const MessageCallback& callback) override;
+ void Encode(const std::string& message,
+ const MessageCallback& callback) override;
+ ProtocolVersion GetProtocolVersion() const override;
+
+ // Returns the message received from the remote device that authenticates it.
+ // This message should have been received during the handshake that
+ // establishes the secure channel.
+ std::string GetReceivedAuthMessage() const;
+
+ private:
+ // Callback for unwrapping a secure message. |callback| will be invoked with
+ // the decrypted payload if the message is unwrapped successfully; otherwise
+ // it will be invoked with an empty string.
+ void HandleUnwrapResult(
+ const DeviceToDeviceSecureContext::MessageCallback& callback,
+ bool verified,
+ const std::string& payload,
+ const securemessage::Header& header);
+
+ // Delegate for handling the creation and unwrapping of SecureMessages.
+ scoped_ptr<SecureMessageDelegate> secure_message_delegate_;
+
+ // The symmetric key used to create and unwrap messages.
+ const std::string symmetric_key_;
+
+ // The [Responder Auth] message received from the remote device during
+ // authentication.
+ const std::string responder_auth_message_;
+
+ // The protocol version supported by the remote device.
+ const ProtocolVersion protocol_version_;
+
+ // The last sequence number of the message sent or received.
+ int last_sequence_number_;
+
+ base::WeakPtrFactory<DeviceToDeviceSecureContext> weak_ptr_factory_;
+
+ DISALLOW_COPY_AND_ASSIGN(DeviceToDeviceSecureContext);
+};
+
+} // namespace proximity_auth
+
+#endif // COMPONENTS_PROXIMITY_AUTH_DEVICE_TO_DEVICE_SECURE_CONTEXT_H
diff --git a/components/proximity_auth/device_to_device_secure_context_unittest.cc b/components/proximity_auth/device_to_device_secure_context_unittest.cc
new file mode 100644
index 0000000..cd00fa2
--- /dev/null
+++ b/components/proximity_auth/device_to_device_secure_context_unittest.cc
@@ -0,0 +1,113 @@
+// Copyright 2015 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/proximity_auth/device_to_device_secure_context.h"
+
+#include "base/bind.h"
+#include "base/memory/scoped_ptr.h"
+#include "components/proximity_auth/cryptauth/fake_secure_message_delegate.h"
+#include "components/proximity_auth/cryptauth/proto/cryptauth_api.pb.h"
+#include "components/proximity_auth/cryptauth/proto/securemessage.pb.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace proximity_auth {
+
+namespace {
+
+const char kSymmetricKey[] = "symmetric key";
+const char kResponderAuthMessage[] = "responder_auth_message";
+const SecureContext::ProtocolVersion kProtocolVersion =
+ SecureContext::PROTOCOL_VERSION_THREE_ONE;
+
+// Callback saving |result| to |result_out|.
+void SaveResult(std::string* result_out, const std::string& result) {
+ *result_out = result;
+}
+
+} // namespace
+
+class ProximityAuthDeviceToDeviceSecureContextTest : public testing::Test {
+ protected:
+ ProximityAuthDeviceToDeviceSecureContextTest()
+ : secure_context_(make_scoped_ptr(new FakeSecureMessageDelegate()),
+ kSymmetricKey,
+ kResponderAuthMessage,
+ kProtocolVersion) {}
+
+ DeviceToDeviceSecureContext secure_context_;
+};
+
+TEST_F(ProximityAuthDeviceToDeviceSecureContextTest, GetProperties) {
+ EXPECT_EQ(kResponderAuthMessage, secure_context_.GetReceivedAuthMessage());
+ EXPECT_EQ(kProtocolVersion, secure_context_.GetProtocolVersion());
+}
+
+TEST_F(ProximityAuthDeviceToDeviceSecureContextTest, CheckEncodedHeader) {
+ std::string message = "encrypt this message";
+ std::string encoded_message;
+ secure_context_.Encode(message, base::Bind(&SaveResult, &encoded_message));
+
+ securemessage::SecureMessage secure_message;
+ ASSERT_TRUE(secure_message.ParseFromString(encoded_message));
+ securemessage::HeaderAndBody header_and_body;
+ ASSERT_TRUE(
+ header_and_body.ParseFromString(secure_message.header_and_body()));
+
+ cryptauth::GcmMetadata gcm_metadata;
+ ASSERT_TRUE(
+ gcm_metadata.ParseFromString(header_and_body.header().public_metadata()));
+ EXPECT_EQ(1, gcm_metadata.version());
+ EXPECT_EQ(cryptauth::DEVICE_TO_DEVICE_MESSAGE, gcm_metadata.type());
+}
+
+TEST_F(ProximityAuthDeviceToDeviceSecureContextTest, DecodeInvalidMessage) {
+ std::string encoded_message = "invalidly encoded message";
+ std::string decoded_message = "not empty";
+ secure_context_.Decode(encoded_message,
+ base::Bind(&SaveResult, &decoded_message));
+ EXPECT_TRUE(decoded_message.empty());
+}
+
+TEST_F(ProximityAuthDeviceToDeviceSecureContextTest, EncodeAndDecode) {
+ // Initialize second secure channel with the same parameters as the first.
+ DeviceToDeviceSecureContext secure_context2(
+ make_scoped_ptr(new FakeSecureMessageDelegate()), kSymmetricKey,
+ kResponderAuthMessage, kProtocolVersion);
+ std::string message = "encrypt this message";
+
+ // Pass some messages between the two secure contexts.
+ for (int i = 0; i < 3; ++i) {
+ std::string encoded_message;
+ secure_context_.Encode(message, base::Bind(&SaveResult, &encoded_message));
+ EXPECT_NE(message, encoded_message);
+
+ std::string decoded_message;
+ secure_context2.Decode(encoded_message,
+ base::Bind(&SaveResult, &decoded_message));
+ EXPECT_EQ(message, decoded_message);
+ }
+}
+
+TEST_F(ProximityAuthDeviceToDeviceSecureContextTest,
+ DecodeInvalidSequenceNumber) {
+ // Initialize second secure channel with the same parameters as the first.
+ DeviceToDeviceSecureContext secure_context2(
+ make_scoped_ptr(new FakeSecureMessageDelegate()), kSymmetricKey,
+ kResponderAuthMessage, kProtocolVersion);
+
+ // Send a few messages over the first secure context.
+ std::string message = "encrypt this message";
+ std::string encoded1;
+ for (int i = 0; i < 3; ++i) {
+ secure_context_.Encode(message, base::Bind(&SaveResult, &encoded1));
+ }
+
+ // Second secure channel should not decode the message with an invalid
+ // sequence number.
+ std::string decoded_message = "not empty";
+ secure_context_.Decode(encoded1, base::Bind(&SaveResult, &decoded_message));
+ EXPECT_TRUE(decoded_message.empty());
+}
+
+} // proximity_auth
diff --git a/components/proximity_auth/secure_context.h b/components/proximity_auth/secure_context.h
index 3b9f34a..b249720 100644
--- a/components/proximity_auth/secure_context.h
+++ b/components/proximity_auth/secure_context.h
@@ -5,11 +5,15 @@
#ifndef COMPONENTS_PROXIMITY_AUTH_SECURE_CONTEXT_H
#define COMPONENTS_PROXIMITY_AUTH_SECURE_CONTEXT_H
+#include "base/callback_forward.h"
+
namespace proximity_auth {
// An interface used to decode and encode messages.
class SecureContext {
public:
+ typedef base::Callback<void(const std::string& message)> MessageCallback;
+
// The protocol version used during authentication.
enum ProtocolVersion {
PROTOCOL_VERSION_THREE_ZERO, // 3.0
@@ -19,15 +23,16 @@ class SecureContext {
virtual ~SecureContext() {}
// Decodes the |encoded_message| and returns the result.
- virtual std::string Decode(const std::string& encoded_message) = 0;
+ // This function is asynchronous because the ChromeOS implementation requires
+ // a DBus call.
+ virtual void Decode(const std::string& encoded_message,
+ const MessageCallback& callback) = 0;
// Encodes the |message| and returns the result.
- virtual std::string Encode(const std::string& message) = 0;
-
- // Returns the message received from the remote device that authenticates it.
- // This message should have been received during the handshake that
- // establishes the secure channel.
- virtual std::string GetReceivedAuthMessage() const = 0;
+ // This function is asynchronous because the ChromeOS implementation requires
+ // a DBus call.
+ virtual void Encode(const std::string& message,
+ const MessageCallback& callback) = 0;
// Returns the protocol version that was used during authentication.
virtual ProtocolVersion GetProtocolVersion() const = 0;
diff --git a/components/proximity_auth/wire_message.cc b/components/proximity_auth/wire_message.cc
index 7906615..6d21116 100644
--- a/components/proximity_auth/wire_message.cc
+++ b/components/proximity_auth/wire_message.cc
@@ -111,7 +111,7 @@ scoped_ptr<WireMessage> WireMessage::Deserialize(
return scoped_ptr<WireMessage>();
}
- return scoped_ptr<WireMessage>(new WireMessage(permit_id, payload));
+ return make_scoped_ptr(new WireMessage(payload, permit_id));
}
std::string WireMessage::Serialize() const {
@@ -155,9 +155,11 @@ std::string WireMessage::Serialize() const {
return header_string + json_body;
}
-WireMessage::WireMessage(const std::string& permit_id,
- const std::string& payload)
- : permit_id_(permit_id), payload_(payload) {
-}
+WireMessage::WireMessage(const std::string& payload)
+ : WireMessage(payload, std::string()) {}
+
+WireMessage::WireMessage(const std::string& payload,
+ const std::string& permit_id)
+ : payload_(payload), permit_id_(permit_id) {}
} // namespace proximity_auth
diff --git a/components/proximity_auth/wire_message.h b/components/proximity_auth/wire_message.h
index 73b643f..c023c21 100644
--- a/components/proximity_auth/wire_message.h
+++ b/components/proximity_auth/wire_message.h
@@ -14,7 +14,12 @@ namespace proximity_auth {
class WireMessage {
public:
- WireMessage(const std::string& permit_id, const std::string& payload);
+ // Creates a WireMessage containing |payload|.
+ explicit WireMessage(const std::string& payload);
+
+ // Creates a WireMessage containing |payload| and |permit_id| in the metadata.
+ WireMessage(const std::string& payload, const std::string& permit_id);
+
virtual ~WireMessage();
// Returns the deserialized message from |serialized_message|, or NULL if the
@@ -28,17 +33,20 @@ class WireMessage {
// Returns a serialized representation of |this| message.
virtual std::string Serialize() const;
- const std::string& permit_id() const { return permit_id_; }
const std::string& payload() const { return payload_; }
+ const std::string& permit_id() const { return permit_id_; }
private:
- // Identifier of the permit being used.
- // TODO(isherman): Describe in a bit more detail.
- const std::string permit_id_;
-
// The message payload.
const std::string payload_;
+ // Identifier of the permit being used. A permit contains the credentials used
+ // to authenticate a device. For example, when sending a WireMessage to the
+ // remote device the |permit_id_| indexes a permit possibly containing the
+ // public key
+ // of the local device or a symmetric key shared between the devices.
+ const std::string permit_id_;
+
DISALLOW_COPY_AND_ASSIGN(WireMessage);
};
diff --git a/components/proximity_auth/wire_message_unittest.cc b/components/proximity_auth/wire_message_unittest.cc
index 4411ff8..512f480 100644
--- a/components/proximity_auth/wire_message_unittest.cc
+++ b/components/proximity_auth/wire_message_unittest.cc
@@ -200,7 +200,7 @@ TEST(ProximityAuthWireMessage, Deserialize_SizeEquals0x01FF) {
}
TEST(ProximityAuthWireMessage, Serialize_WithPermitId) {
- WireMessage message1("example id", "example payload");
+ WireMessage message1("example payload", "example id");
std::string bytes = message1.Serialize();
ASSERT_FALSE(bytes.empty());
@@ -214,7 +214,7 @@ TEST(ProximityAuthWireMessage, Serialize_WithPermitId) {
}
TEST(ProximityAuthWireMessage, Serialize_WithoutPermitId) {
- WireMessage message1(std::string(), "example payload");
+ WireMessage message1("example payload");
std::string bytes = message1.Serialize();
ASSERT_FALSE(bytes.empty());
@@ -228,7 +228,7 @@ TEST(ProximityAuthWireMessage, Serialize_WithoutPermitId) {
}
TEST(ProximityAuthWireMessage, Serialize_FailsWithoutPayload) {
- WireMessage message1("example id", std::string());
+ WireMessage message1(std::string(), "example id");
std::string bytes = message1.Serialize();
EXPECT_TRUE(bytes.empty());
}