diff options
-rw-r--r-- | components/components_tests.gyp | 1 | ||||
-rw-r--r-- | components/proximity_auth.gypi | 2 | ||||
-rw-r--r-- | components/proximity_auth/BUILD.gn | 3 | ||||
-rw-r--r-- | components/proximity_auth/ble/fake_wire_message.cc | 3 | ||||
-rw-r--r-- | components/proximity_auth/client_impl.cc | 151 | ||||
-rw-r--r-- | components/proximity_auth/client_impl.h | 9 | ||||
-rw-r--r-- | components/proximity_auth/client_impl_unittest.cc | 13 | ||||
-rw-r--r-- | components/proximity_auth/connection_unittest.cc | 2 | ||||
-rw-r--r-- | components/proximity_auth/device_to_device_secure_context.cc | 119 | ||||
-rw-r--r-- | components/proximity_auth/device_to_device_secure_context.h | 77 | ||||
-rw-r--r-- | components/proximity_auth/device_to_device_secure_context_unittest.cc | 113 | ||||
-rw-r--r-- | components/proximity_auth/secure_context.h | 19 | ||||
-rw-r--r-- | components/proximity_auth/wire_message.cc | 12 | ||||
-rw-r--r-- | components/proximity_auth/wire_message.h | 20 | ||||
-rw-r--r-- | components/proximity_auth/wire_message_unittest.cc | 6 |
15 files changed, 453 insertions, 97 deletions
diff --git a/components/components_tests.gyp b/components/components_tests.gyp index 09d599f..f8bc6fb 100644 --- a/components/components_tests.gyp +++ b/components/components_tests.gyp @@ -472,6 +472,7 @@ 'proximity_auth/cryptauth/fake_secure_message_delegate_unittest.cc', 'proximity_auth/cryptauth/sync_scheduler_impl_unittest.cc', 'proximity_auth/device_to_device_operations_unittest.cc', + 'proximity_auth/device_to_device_secure_context_unittest.cc', 'proximity_auth/logging/logging_unittest.cc', 'proximity_auth/proximity_auth_system_unittest.cc', 'proximity_auth/proximity_monitor_impl_unittest.cc', diff --git a/components/proximity_auth.gypi b/components/proximity_auth.gypi index 2f21b75..74b07d8 100644 --- a/components/proximity_auth.gypi +++ b/components/proximity_auth.gypi @@ -56,6 +56,8 @@ "proximity_auth/connection_observer.h", "proximity_auth/device_to_device_initiator_operations.cc", "proximity_auth/device_to_device_initiator_operations.h", + "proximity_auth/device_to_device_secure_context.cc", + "proximity_auth/device_to_device_secure_context.h", "proximity_auth/metrics.cc", "proximity_auth/metrics.h", "proximity_auth/proximity_auth_client.h", diff --git a/components/proximity_auth/BUILD.gn b/components/proximity_auth/BUILD.gn index a85fb9a..d064756 100644 --- a/components/proximity_auth/BUILD.gn +++ b/components/proximity_auth/BUILD.gn @@ -26,6 +26,8 @@ source_set("proximity_auth") { "connection_observer.h", "device_to_device_initiator_operations.cc", "device_to_device_initiator_operations.h", + "device_to_device_secure_context.cc", + "device_to_device_secure_context.h", "metrics.cc", "metrics.h", "proximity_auth_client.h", @@ -83,6 +85,7 @@ source_set("unit_tests") { "client_impl_unittest.cc", "connection_unittest.cc", "device_to_device_operations_unittest.cc", + "device_to_device_secure_context_unittest.cc", "proximity_auth_system_unittest.cc", "proximity_monitor_impl_unittest.cc", "remote_status_update_unittest.cc", diff --git a/components/proximity_auth/ble/fake_wire_message.cc b/components/proximity_auth/ble/fake_wire_message.cc index 193ac04..0997b3f 100644 --- a/components/proximity_auth/ble/fake_wire_message.cc +++ b/components/proximity_auth/ble/fake_wire_message.cc @@ -12,8 +12,7 @@ namespace proximity_auth { FakeWireMessage::FakeWireMessage(const std::string& payload) - : WireMessage("", payload) { -} + : WireMessage(payload) {} scoped_ptr<FakeWireMessage> FakeWireMessage::Deserialize( const std::string& serialized_message, diff --git a/components/proximity_auth/client_impl.cc b/components/proximity_auth/client_impl.cc index fc7d6e4b..7a96f50 100644 --- a/components/proximity_auth/client_impl.cc +++ b/components/proximity_auth/client_impl.cc @@ -4,12 +4,14 @@ #include "components/proximity_auth/client_impl.h" +#include "base/bind.h" #include "base/json/json_reader.h" #include "base/json/json_writer.h" #include "base/values.h" #include "components/proximity_auth/client_observer.h" #include "components/proximity_auth/connection.h" #include "components/proximity_auth/cryptauth/base64url.h" +#include "components/proximity_auth/logging/logging.h" #include "components/proximity_auth/remote_status_update.h" #include "components/proximity_auth/secure_context.h" #include "components/proximity_auth/wire_message.h" @@ -54,7 +56,9 @@ std::string GetMessageType(const base::DictionaryValue& message) { ClientImpl::ClientImpl(scoped_ptr<Connection> connection, scoped_ptr<SecureContext> secure_context) - : connection_(connection.Pass()), secure_context_(secure_context.Pass()) { + : connection_(connection.Pass()), + secure_context_(secure_context.Pass()), + weak_ptr_factory_(this) { DCHECK(connection_->IsConnected()); connection_->AddObserver(this); } @@ -87,8 +91,8 @@ void ClientImpl::DispatchUnlockEvent() { void ClientImpl::RequestDecryption(const std::string& challenge) { if (!SupportsSignIn()) { - VLOG(1) << "[Client] Dropping decryption request, as remote device " - << "does not support protocol v3.1."; + PA_LOG(WARNING) << "Dropping decryption request, as remote device " + << "does not support protocol v3.1."; FOR_EACH_OBSERVER(ClientObserver, observers_, OnDecryptResponse(scoped_ptr<std::string>())); return; @@ -108,8 +112,8 @@ void ClientImpl::RequestDecryption(const std::string& challenge) { void ClientImpl::RequestUnlock() { if (!SupportsSignIn()) { - VLOG(1) << "[Client] Dropping unlock request, as remote device does not " - << "support protocol v3.1."; + PA_LOG(WARNING) << "Dropping unlock request, as remote device does not " + << "support protocol v3.1."; FOR_EACH_OBSERVER(ClientObserver, observers_, OnUnlockResponse(false)); return; } @@ -139,65 +143,21 @@ void ClientImpl::ProcessMessageQueue() { pending_message_.reset(new PendingMessage(queued_messages_.front())); queued_messages_.pop_front(); - connection_->SendMessage(make_scoped_ptr(new WireMessage( - std::string(), secure_context_->Encode(pending_message_->json_message)))); + secure_context_->Encode(pending_message_->json_message, + base::Bind(&ClientImpl::OnMessageEncoded, + weak_ptr_factory_.GetWeakPtr())); } -void ClientImpl::HandleRemoteStatusUpdateMessage( - const base::DictionaryValue& message) { - scoped_ptr<RemoteStatusUpdate> status_update = - RemoteStatusUpdate::Deserialize(message); - if (!status_update) { - VLOG(1) << "[Client] Unexpected remote status update: " << message; - return; - } - - FOR_EACH_OBSERVER(ClientObserver, observers_, - OnRemoteStatusUpdate(*status_update)); +void ClientImpl::OnMessageEncoded(const std::string& encoded_message) { + connection_->SendMessage(make_scoped_ptr(new WireMessage(encoded_message))); } -void ClientImpl::HandleDecryptResponseMessage( - const base::DictionaryValue& message) { - std::string base64_data; - std::string decrypted_data; - scoped_ptr<std::string> response; - if (!message.GetString(kDataKey, &base64_data) || base64_data.empty()) { - VLOG(1) << "[Client] Decrypt response missing '" << kDataKey << "' value."; - } else if (!Base64UrlDecode(base64_data, &decrypted_data)) { - VLOG(1) << "[Client] Unable to base64-decode decrypt response."; - } else { - response.reset(new std::string(decrypted_data)); - } - FOR_EACH_OBSERVER(ClientObserver, observers_, - OnDecryptResponse(response.Pass())); -} - -void ClientImpl::HandleUnlockResponseMessage( - const base::DictionaryValue& message) { - FOR_EACH_OBSERVER(ClientObserver, observers_, OnUnlockResponse(true)); -} - -void ClientImpl::OnConnectionStatusChanged(Connection* connection, - Connection::Status old_status, - Connection::Status new_status) { - DCHECK_EQ(connection, connection_.get()); - if (new_status != Connection::CONNECTED) { - VLOG(1) << "[Client] Secure channel disconnected..."; - connection_->RemoveObserver(this); - connection_.reset(); - FOR_EACH_OBSERVER(ClientObserver, observers_, OnDisconnected()); - // TODO(isherman): Determine whether it's also necessary/appropriate to fire - // this notification from the destructor. - } -} - -void ClientImpl::OnMessageReceived(const Connection& connection, - const WireMessage& wire_message) { - std::string json_message = secure_context_->Decode(wire_message.payload()); - scoped_ptr<base::Value> message_value = base::JSONReader::Read(json_message); +void ClientImpl::OnMessageDecoded(const std::string& decoded_message) { + // The decoded message should be a JSON string. + scoped_ptr<base::Value> message_value = + base::JSONReader::Read(decoded_message); if (!message_value || !message_value->IsType(base::Value::TYPE_DICTIONARY)) { - VLOG(1) << "[Client] Unable to parse message as JSON: " << json_message - << "."; + PA_LOG(ERROR) << "Unable to parse message as JSON:\n" << decoded_message; return; } @@ -207,8 +167,8 @@ void ClientImpl::OnMessageReceived(const Connection& connection, std::string type; if (!message->GetString(kTypeKey, &type)) { - VLOG(1) << "[Client] Missing '" << kTypeKey - << "' key in message: " << json_message << "."; + PA_LOG(ERROR) << "Missing '" << kTypeKey << "' key in message:\n " + << decoded_message; return; } @@ -221,7 +181,7 @@ void ClientImpl::OnMessageReceived(const Connection& connection, // All other messages should only be received in response to a message that // the client sent. if (!pending_message_) { - VLOG(1) << "[Client] Unexpected message received: " << json_message; + PA_LOG(WARNING) << "Unexpected message received:\n" << decoded_message; return; } @@ -234,9 +194,9 @@ void ClientImpl::OnMessageReceived(const Connection& connection, NOTREACHED(); // There are no other message types that expect a response. if (type != expected_type) { - VLOG(1) << "[Client] Unexpected '" << kTypeKey << "' value in message. " - << "Expected '" << expected_type << "' but received '" << type - << "'."; + PA_LOG(ERROR) << "Unexpected '" << kTypeKey << "' value in message. " + << "Expected '" << expected_type << "' but received '" << type + << "'."; return; } @@ -251,11 +211,66 @@ void ClientImpl::OnMessageReceived(const Connection& connection, ProcessMessageQueue(); } +void ClientImpl::HandleRemoteStatusUpdateMessage( + const base::DictionaryValue& message) { + scoped_ptr<RemoteStatusUpdate> status_update = + RemoteStatusUpdate::Deserialize(message); + if (!status_update) { + PA_LOG(ERROR) << "Unexpected remote status update: " << message; + return; + } + + FOR_EACH_OBSERVER(ClientObserver, observers_, + OnRemoteStatusUpdate(*status_update)); +} + +void ClientImpl::HandleDecryptResponseMessage( + const base::DictionaryValue& message) { + std::string base64_data; + std::string decrypted_data; + scoped_ptr<std::string> response; + if (!message.GetString(kDataKey, &base64_data) || base64_data.empty()) { + PA_LOG(ERROR) << "Decrypt response missing '" << kDataKey << "' value."; + } else if (!Base64UrlDecode(base64_data, &decrypted_data)) { + PA_LOG(ERROR) << "Unable to base64-decode decrypt response."; + } else { + response.reset(new std::string(decrypted_data)); + } + FOR_EACH_OBSERVER(ClientObserver, observers_, + OnDecryptResponse(response.Pass())); +} + +void ClientImpl::HandleUnlockResponseMessage( + const base::DictionaryValue& message) { + FOR_EACH_OBSERVER(ClientObserver, observers_, OnUnlockResponse(true)); +} + +void ClientImpl::OnConnectionStatusChanged(Connection* connection, + Connection::Status old_status, + Connection::Status new_status) { + DCHECK_EQ(connection, connection_.get()); + if (new_status == Connection::DISCONNECTED) { + PA_LOG(INFO) << "Secure channel disconnected..."; + connection_->RemoveObserver(this); + connection_.reset(); + FOR_EACH_OBSERVER(ClientObserver, observers_, OnDisconnected()); + // TODO(isherman): Determine whether it's also necessary/appropriate to fire + // this notification from the destructor. + } +} + +void ClientImpl::OnMessageReceived(const Connection& connection, + const WireMessage& wire_message) { + secure_context_->Decode(wire_message.payload(), + base::Bind(&ClientImpl::OnMessageDecoded, + weak_ptr_factory_.GetWeakPtr())); +} + void ClientImpl::OnSendCompleted(const Connection& connection, const WireMessage& wire_message, bool success) { if (!pending_message_) { - VLOG(1) << "[Client] Unexpected message sent."; + PA_LOG(ERROR) << "Unexpected message sent."; return; } @@ -277,8 +292,8 @@ void ClientImpl::OnSendCompleted(const Connection& connection, } else if (pending_message_->type == kMessageTypeLocalEvent) { FOR_EACH_OBSERVER(ClientObserver, observers_, OnUnlockEventSent(success)); } else { - VLOG(1) << "[Client] Message of unknown type '" << pending_message_->type - << "sent."; + PA_LOG(ERROR) << "Message of unknown type '" << pending_message_->type + << "' sent."; } pending_message_.reset(); diff --git a/components/proximity_auth/client_impl.h b/components/proximity_auth/client_impl.h index b4d8dcb..d046092 100644 --- a/components/proximity_auth/client_impl.h +++ b/components/proximity_auth/client_impl.h @@ -9,6 +9,7 @@ #include "base/macros.h" #include "base/memory/scoped_ptr.h" +#include "base/memory/weak_ptr.h" #include "base/observer_list.h" #include "components/proximity_auth/client.h" #include "components/proximity_auth/connection_observer.h" @@ -65,6 +66,12 @@ class ClientImpl : public Client, public ConnectionObserver { // Pops the first of the |queued_messages_| and sends it to the remote device. void ProcessMessageQueue(); + // Called when the message is encoded so it can be sent over the connection. + void OnMessageEncoded(const std::string& encoded_message); + + // Called when the message is decoded so it can be parsed. + void OnMessageDecoded(const std::string& decoded_message); + // Handles an incoming "status_update" |message|, parsing and notifying // observers of the content. void HandleRemoteStatusUpdateMessage(const base::DictionaryValue& message); @@ -104,6 +111,8 @@ class ClientImpl : public Client, public ConnectionObserver { // response. Null if there is no message currently in this state. scoped_ptr<PendingMessage> pending_message_; + base::WeakPtrFactory<ClientImpl> weak_ptr_factory_; + DISALLOW_COPY_AND_ASSIGN(ClientImpl); }; diff --git a/components/proximity_auth/client_impl_unittest.cc b/components/proximity_auth/client_impl_unittest.cc index f519e01..fc75805 100644 --- a/components/proximity_auth/client_impl_unittest.cc +++ b/components/proximity_auth/client_impl_unittest.cc @@ -4,6 +4,7 @@ #include "components/proximity_auth/client_impl.h" +#include "base/callback.h" #include "base/macros.h" #include "base/memory/scoped_ptr.h" #include "components/proximity_auth/client_observer.h" @@ -44,15 +45,17 @@ class MockSecureContext : public SecureContext { MOCK_CONST_METHOD0(GetReceivedAuthMessage, std::string()); MOCK_CONST_METHOD0(GetProtocolVersion, ProtocolVersion()); - std::string Encode(const std::string& message) override { - return message + kFakeEncodingSuffix; + void Encode(const std::string& message, + const MessageCallback& callback) override { + callback.Run(message + kFakeEncodingSuffix); } - std::string Decode(const std::string& encoded_message) override { + void Decode(const std::string& encoded_message, + const MessageCallback& callback) override { EXPECT_THAT(encoded_message, EndsWith(kFakeEncodingSuffix)); std::string decoded_message = encoded_message; decoded_message.erase(decoded_message.rfind(kFakeEncodingSuffix)); - return decoded_message; + callback.Run(decoded_message); } private: @@ -94,7 +97,7 @@ class FakeConnection : public Connection { scoped_ptr<WireMessage> DeserializeWireMessage( bool* is_incomplete_message) override { *is_incomplete_message = false; - return make_scoped_ptr(new WireMessage(std::string(), pending_payload_)); + return make_scoped_ptr(new WireMessage(pending_payload_)); } WireMessage* current_message() { return current_message_.get(); } diff --git a/components/proximity_auth/connection_unittest.cc b/components/proximity_auth/connection_unittest.cc index 1e298cf..76cfe54 100644 --- a/components/proximity_auth/connection_unittest.cc +++ b/components/proximity_auth/connection_unittest.cc @@ -76,7 +76,7 @@ class MockConnectionObserver : public ConnectionObserver { // Unlike WireMessage, offers a public constructor. class TestWireMessage : public WireMessage { public: - TestWireMessage() : WireMessage(std::string(), std::string()) {} + TestWireMessage() : WireMessage(std::string()) {} ~TestWireMessage() override {} private: diff --git a/components/proximity_auth/device_to_device_secure_context.cc b/components/proximity_auth/device_to_device_secure_context.cc new file mode 100644 index 0000000..eabc451 --- /dev/null +++ b/components/proximity_auth/device_to_device_secure_context.cc @@ -0,0 +1,119 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/proximity_auth/device_to_device_secure_context.h" + +#include "base/bind.h" +#include "base/callback.h" +#include "components/proximity_auth/cryptauth/proto/cryptauth_api.pb.h" +#include "components/proximity_auth/cryptauth/proto/securemessage.pb.h" +#include "components/proximity_auth/cryptauth/secure_message_delegate.h" +#include "components/proximity_auth/logging/logging.h" + +namespace proximity_auth { + +namespace { + +// The version to put in the GcmMetadata field. +const int kGcmMetadataVersion = 1; + +// The sequence number of the last message used during authentication. These +// messages are sent and received before the SecureContext is created. +const int kAuthenticationSequenceNumber = 2; + +} // namespace + +DeviceToDeviceSecureContext::DeviceToDeviceSecureContext( + scoped_ptr<SecureMessageDelegate> secure_message_delegate, + const std::string& symmetric_key, + const std::string& responder_auth_message, + ProtocolVersion protocol_version) + : secure_message_delegate_(secure_message_delegate.Pass()), + symmetric_key_(symmetric_key), + responder_auth_message_(responder_auth_message), + protocol_version_(protocol_version), + last_sequence_number_(kAuthenticationSequenceNumber), + weak_ptr_factory_(this) {} + +DeviceToDeviceSecureContext::~DeviceToDeviceSecureContext() {} + +void DeviceToDeviceSecureContext::Decode(const std::string& encoded_message, + const MessageCallback& callback) { + SecureMessageDelegate::UnwrapOptions unwrap_options; + unwrap_options.encryption_scheme = securemessage::AES_256_CBC; + unwrap_options.signature_scheme = securemessage::HMAC_SHA256; + + secure_message_delegate_->UnwrapSecureMessage( + encoded_message, symmetric_key_, unwrap_options, + base::Bind(&DeviceToDeviceSecureContext::HandleUnwrapResult, + weak_ptr_factory_.GetWeakPtr(), callback)); +} + +void DeviceToDeviceSecureContext::Encode(const std::string& message, + const MessageCallback& callback) { + // Create a GcmMetadata field to put in the header. + cryptauth::GcmMetadata gcm_metadata; + gcm_metadata.set_type(cryptauth::DEVICE_TO_DEVICE_MESSAGE); + gcm_metadata.set_version(kGcmMetadataVersion); + + // Wrap |message| inside a DeviceToDeviceMessage proto. + securemessage::DeviceToDeviceMessage device_to_device_message; + device_to_device_message.set_sequence_number(++last_sequence_number_); + device_to_device_message.set_message(message); + + SecureMessageDelegate::CreateOptions create_options; + create_options.encryption_scheme = securemessage::AES_256_CBC; + create_options.signature_scheme = securemessage::HMAC_SHA256; + gcm_metadata.SerializeToString(&create_options.public_metadata); + + secure_message_delegate_->CreateSecureMessage( + device_to_device_message.SerializeAsString(), symmetric_key_, + create_options, callback); +} + +std::string DeviceToDeviceSecureContext::GetReceivedAuthMessage() const { + return responder_auth_message_; +} + +SecureContext::ProtocolVersion DeviceToDeviceSecureContext::GetProtocolVersion() + const { + return protocol_version_; +} + +void DeviceToDeviceSecureContext::HandleUnwrapResult( + const DeviceToDeviceSecureContext::MessageCallback& callback, + bool verified, + const std::string& payload, + const securemessage::Header& header) { + // The payload should contain a DeviceToDeviceMessage proto. + securemessage::DeviceToDeviceMessage device_to_device_message; + if (!verified || !device_to_device_message.ParseFromString(payload)) { + PA_LOG(ERROR) << "Failed to unwrap secure message."; + callback.Run(std::string()); + return; + } + + // Check that the sequence number matches the expected sequence number. + if (device_to_device_message.sequence_number() != last_sequence_number_ + 1) { + PA_LOG(ERROR) << "Expected sequence_number=" << last_sequence_number_ + 1 + << ", but got " << device_to_device_message.sequence_number(); + callback.Run(std::string()); + return; + } + + // Validate the GcmMetadata proto in the header. + cryptauth::GcmMetadata gcm_metadata; + if (!gcm_metadata.ParseFromString(header.public_metadata()) || + gcm_metadata.type() != cryptauth::DEVICE_TO_DEVICE_MESSAGE || + gcm_metadata.version() != kGcmMetadataVersion) { + PA_LOG(ERROR) << "Failed to validate GcmMetadata."; + callback.Run(std::string()); + return; + } + + last_sequence_number_++; + callback.Run(device_to_device_message.message()); +} + +} // proximity_auth diff --git a/components/proximity_auth/device_to_device_secure_context.h b/components/proximity_auth/device_to_device_secure_context.h new file mode 100644 index 0000000..62e3296 --- /dev/null +++ b/components/proximity_auth/device_to_device_secure_context.h @@ -0,0 +1,77 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_PROXIMITY_AUTH_DEVICE_TO_DEVICE_SECURE_CONTEXT_H +#define COMPONENTS_PROXIMITY_AUTH_DEVICE_TO_DEVICE_SECURE_CONTEXT_H + +#include "base/macros.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/weak_ptr.h" +#include "components/proximity_auth/secure_context.h" + +namespace securemessage { +class Header; +} + +namespace proximity_auth { + +class SecureMessageDelegate; + +// SecureContext implementation for the DeviceToDevice protocol. +class DeviceToDeviceSecureContext : public SecureContext { + public: + DeviceToDeviceSecureContext( + scoped_ptr<SecureMessageDelegate> secure_message_delegate, + const std::string& symmetric_key, + const std::string& responder_auth_message_, + ProtocolVersion protocol_version); + + ~DeviceToDeviceSecureContext() override; + + // SecureContext: + void Decode(const std::string& encoded_message, + const MessageCallback& callback) override; + void Encode(const std::string& message, + const MessageCallback& callback) override; + ProtocolVersion GetProtocolVersion() const override; + + // Returns the message received from the remote device that authenticates it. + // This message should have been received during the handshake that + // establishes the secure channel. + std::string GetReceivedAuthMessage() const; + + private: + // Callback for unwrapping a secure message. |callback| will be invoked with + // the decrypted payload if the message is unwrapped successfully; otherwise + // it will be invoked with an empty string. + void HandleUnwrapResult( + const DeviceToDeviceSecureContext::MessageCallback& callback, + bool verified, + const std::string& payload, + const securemessage::Header& header); + + // Delegate for handling the creation and unwrapping of SecureMessages. + scoped_ptr<SecureMessageDelegate> secure_message_delegate_; + + // The symmetric key used to create and unwrap messages. + const std::string symmetric_key_; + + // The [Responder Auth] message received from the remote device during + // authentication. + const std::string responder_auth_message_; + + // The protocol version supported by the remote device. + const ProtocolVersion protocol_version_; + + // The last sequence number of the message sent or received. + int last_sequence_number_; + + base::WeakPtrFactory<DeviceToDeviceSecureContext> weak_ptr_factory_; + + DISALLOW_COPY_AND_ASSIGN(DeviceToDeviceSecureContext); +}; + +} // namespace proximity_auth + +#endif // COMPONENTS_PROXIMITY_AUTH_DEVICE_TO_DEVICE_SECURE_CONTEXT_H diff --git a/components/proximity_auth/device_to_device_secure_context_unittest.cc b/components/proximity_auth/device_to_device_secure_context_unittest.cc new file mode 100644 index 0000000..cd00fa2 --- /dev/null +++ b/components/proximity_auth/device_to_device_secure_context_unittest.cc @@ -0,0 +1,113 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/proximity_auth/device_to_device_secure_context.h" + +#include "base/bind.h" +#include "base/memory/scoped_ptr.h" +#include "components/proximity_auth/cryptauth/fake_secure_message_delegate.h" +#include "components/proximity_auth/cryptauth/proto/cryptauth_api.pb.h" +#include "components/proximity_auth/cryptauth/proto/securemessage.pb.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace proximity_auth { + +namespace { + +const char kSymmetricKey[] = "symmetric key"; +const char kResponderAuthMessage[] = "responder_auth_message"; +const SecureContext::ProtocolVersion kProtocolVersion = + SecureContext::PROTOCOL_VERSION_THREE_ONE; + +// Callback saving |result| to |result_out|. +void SaveResult(std::string* result_out, const std::string& result) { + *result_out = result; +} + +} // namespace + +class ProximityAuthDeviceToDeviceSecureContextTest : public testing::Test { + protected: + ProximityAuthDeviceToDeviceSecureContextTest() + : secure_context_(make_scoped_ptr(new FakeSecureMessageDelegate()), + kSymmetricKey, + kResponderAuthMessage, + kProtocolVersion) {} + + DeviceToDeviceSecureContext secure_context_; +}; + +TEST_F(ProximityAuthDeviceToDeviceSecureContextTest, GetProperties) { + EXPECT_EQ(kResponderAuthMessage, secure_context_.GetReceivedAuthMessage()); + EXPECT_EQ(kProtocolVersion, secure_context_.GetProtocolVersion()); +} + +TEST_F(ProximityAuthDeviceToDeviceSecureContextTest, CheckEncodedHeader) { + std::string message = "encrypt this message"; + std::string encoded_message; + secure_context_.Encode(message, base::Bind(&SaveResult, &encoded_message)); + + securemessage::SecureMessage secure_message; + ASSERT_TRUE(secure_message.ParseFromString(encoded_message)); + securemessage::HeaderAndBody header_and_body; + ASSERT_TRUE( + header_and_body.ParseFromString(secure_message.header_and_body())); + + cryptauth::GcmMetadata gcm_metadata; + ASSERT_TRUE( + gcm_metadata.ParseFromString(header_and_body.header().public_metadata())); + EXPECT_EQ(1, gcm_metadata.version()); + EXPECT_EQ(cryptauth::DEVICE_TO_DEVICE_MESSAGE, gcm_metadata.type()); +} + +TEST_F(ProximityAuthDeviceToDeviceSecureContextTest, DecodeInvalidMessage) { + std::string encoded_message = "invalidly encoded message"; + std::string decoded_message = "not empty"; + secure_context_.Decode(encoded_message, + base::Bind(&SaveResult, &decoded_message)); + EXPECT_TRUE(decoded_message.empty()); +} + +TEST_F(ProximityAuthDeviceToDeviceSecureContextTest, EncodeAndDecode) { + // Initialize second secure channel with the same parameters as the first. + DeviceToDeviceSecureContext secure_context2( + make_scoped_ptr(new FakeSecureMessageDelegate()), kSymmetricKey, + kResponderAuthMessage, kProtocolVersion); + std::string message = "encrypt this message"; + + // Pass some messages between the two secure contexts. + for (int i = 0; i < 3; ++i) { + std::string encoded_message; + secure_context_.Encode(message, base::Bind(&SaveResult, &encoded_message)); + EXPECT_NE(message, encoded_message); + + std::string decoded_message; + secure_context2.Decode(encoded_message, + base::Bind(&SaveResult, &decoded_message)); + EXPECT_EQ(message, decoded_message); + } +} + +TEST_F(ProximityAuthDeviceToDeviceSecureContextTest, + DecodeInvalidSequenceNumber) { + // Initialize second secure channel with the same parameters as the first. + DeviceToDeviceSecureContext secure_context2( + make_scoped_ptr(new FakeSecureMessageDelegate()), kSymmetricKey, + kResponderAuthMessage, kProtocolVersion); + + // Send a few messages over the first secure context. + std::string message = "encrypt this message"; + std::string encoded1; + for (int i = 0; i < 3; ++i) { + secure_context_.Encode(message, base::Bind(&SaveResult, &encoded1)); + } + + // Second secure channel should not decode the message with an invalid + // sequence number. + std::string decoded_message = "not empty"; + secure_context_.Decode(encoded1, base::Bind(&SaveResult, &decoded_message)); + EXPECT_TRUE(decoded_message.empty()); +} + +} // proximity_auth diff --git a/components/proximity_auth/secure_context.h b/components/proximity_auth/secure_context.h index 3b9f34a..b249720 100644 --- a/components/proximity_auth/secure_context.h +++ b/components/proximity_auth/secure_context.h @@ -5,11 +5,15 @@ #ifndef COMPONENTS_PROXIMITY_AUTH_SECURE_CONTEXT_H #define COMPONENTS_PROXIMITY_AUTH_SECURE_CONTEXT_H +#include "base/callback_forward.h" + namespace proximity_auth { // An interface used to decode and encode messages. class SecureContext { public: + typedef base::Callback<void(const std::string& message)> MessageCallback; + // The protocol version used during authentication. enum ProtocolVersion { PROTOCOL_VERSION_THREE_ZERO, // 3.0 @@ -19,15 +23,16 @@ class SecureContext { virtual ~SecureContext() {} // Decodes the |encoded_message| and returns the result. - virtual std::string Decode(const std::string& encoded_message) = 0; + // This function is asynchronous because the ChromeOS implementation requires + // a DBus call. + virtual void Decode(const std::string& encoded_message, + const MessageCallback& callback) = 0; // Encodes the |message| and returns the result. - virtual std::string Encode(const std::string& message) = 0; - - // Returns the message received from the remote device that authenticates it. - // This message should have been received during the handshake that - // establishes the secure channel. - virtual std::string GetReceivedAuthMessage() const = 0; + // This function is asynchronous because the ChromeOS implementation requires + // a DBus call. + virtual void Encode(const std::string& message, + const MessageCallback& callback) = 0; // Returns the protocol version that was used during authentication. virtual ProtocolVersion GetProtocolVersion() const = 0; diff --git a/components/proximity_auth/wire_message.cc b/components/proximity_auth/wire_message.cc index 7906615..6d21116 100644 --- a/components/proximity_auth/wire_message.cc +++ b/components/proximity_auth/wire_message.cc @@ -111,7 +111,7 @@ scoped_ptr<WireMessage> WireMessage::Deserialize( return scoped_ptr<WireMessage>(); } - return scoped_ptr<WireMessage>(new WireMessage(permit_id, payload)); + return make_scoped_ptr(new WireMessage(payload, permit_id)); } std::string WireMessage::Serialize() const { @@ -155,9 +155,11 @@ std::string WireMessage::Serialize() const { return header_string + json_body; } -WireMessage::WireMessage(const std::string& permit_id, - const std::string& payload) - : permit_id_(permit_id), payload_(payload) { -} +WireMessage::WireMessage(const std::string& payload) + : WireMessage(payload, std::string()) {} + +WireMessage::WireMessage(const std::string& payload, + const std::string& permit_id) + : payload_(payload), permit_id_(permit_id) {} } // namespace proximity_auth diff --git a/components/proximity_auth/wire_message.h b/components/proximity_auth/wire_message.h index 73b643f..c023c21 100644 --- a/components/proximity_auth/wire_message.h +++ b/components/proximity_auth/wire_message.h @@ -14,7 +14,12 @@ namespace proximity_auth { class WireMessage { public: - WireMessage(const std::string& permit_id, const std::string& payload); + // Creates a WireMessage containing |payload|. + explicit WireMessage(const std::string& payload); + + // Creates a WireMessage containing |payload| and |permit_id| in the metadata. + WireMessage(const std::string& payload, const std::string& permit_id); + virtual ~WireMessage(); // Returns the deserialized message from |serialized_message|, or NULL if the @@ -28,17 +33,20 @@ class WireMessage { // Returns a serialized representation of |this| message. virtual std::string Serialize() const; - const std::string& permit_id() const { return permit_id_; } const std::string& payload() const { return payload_; } + const std::string& permit_id() const { return permit_id_; } private: - // Identifier of the permit being used. - // TODO(isherman): Describe in a bit more detail. - const std::string permit_id_; - // The message payload. const std::string payload_; + // Identifier of the permit being used. A permit contains the credentials used + // to authenticate a device. For example, when sending a WireMessage to the + // remote device the |permit_id_| indexes a permit possibly containing the + // public key + // of the local device or a symmetric key shared between the devices. + const std::string permit_id_; + DISALLOW_COPY_AND_ASSIGN(WireMessage); }; diff --git a/components/proximity_auth/wire_message_unittest.cc b/components/proximity_auth/wire_message_unittest.cc index 4411ff8..512f480 100644 --- a/components/proximity_auth/wire_message_unittest.cc +++ b/components/proximity_auth/wire_message_unittest.cc @@ -200,7 +200,7 @@ TEST(ProximityAuthWireMessage, Deserialize_SizeEquals0x01FF) { } TEST(ProximityAuthWireMessage, Serialize_WithPermitId) { - WireMessage message1("example id", "example payload"); + WireMessage message1("example payload", "example id"); std::string bytes = message1.Serialize(); ASSERT_FALSE(bytes.empty()); @@ -214,7 +214,7 @@ TEST(ProximityAuthWireMessage, Serialize_WithPermitId) { } TEST(ProximityAuthWireMessage, Serialize_WithoutPermitId) { - WireMessage message1(std::string(), "example payload"); + WireMessage message1("example payload"); std::string bytes = message1.Serialize(); ASSERT_FALSE(bytes.empty()); @@ -228,7 +228,7 @@ TEST(ProximityAuthWireMessage, Serialize_WithoutPermitId) { } TEST(ProximityAuthWireMessage, Serialize_FailsWithoutPayload) { - WireMessage message1("example id", std::string()); + WireMessage message1(std::string(), "example id"); std::string bytes = message1.Serialize(); EXPECT_TRUE(bytes.empty()); } |