diff options
Diffstat (limited to 'components/proximity_auth')
-rw-r--r-- | components/proximity_auth/wire_message.cc | 71 | ||||
-rw-r--r-- | components/proximity_auth/wire_message_unittest.cc | 57 |
2 files changed, 112 insertions, 16 deletions
diff --git a/components/proximity_auth/wire_message.cc b/components/proximity_auth/wire_message.cc index 0501359..7906615 100644 --- a/components/proximity_auth/wire_message.cc +++ b/components/proximity_auth/wire_message.cc @@ -4,11 +4,14 @@ #include "components/proximity_auth/wire_message.h" +#include <limits> + #include "base/json/json_reader.h" -#include "base/logging.h" +#include "base/json/json_writer.h" #include "base/macros.h" #include "base/values.h" #include "components/proximity_auth/cryptauth/base64url.h" +#include "components/proximity_auth/logging/logging.h" // The wire messages have a simple format: // [ message version ] [ body length ] [ JSON body ] @@ -23,7 +26,7 @@ namespace { const size_t kHeaderLength = 3; // The protocol version of the message format. -const int kExpectedMessageFormatVersion = 3; +const int kMessageFormatVersionThree = 3; const char kPayloadKey[] = "payload"; const char kPermitIdKey[] = "permit_id"; @@ -43,24 +46,24 @@ bool ParseHeader(const std::string& serialized_message, static_assert(kHeaderLength > 2, "kHeaderLength too small"); size_t version = serialized_message[0]; - if (version != kExpectedMessageFormatVersion) { - VLOG(1) << "Error: Invalid message version. Got " << version - << ", expected " << kExpectedMessageFormatVersion; + if (version != kMessageFormatVersionThree) { + PA_LOG(WARNING) << "Error: Invalid message version. Got " << version + << ", expected " << kMessageFormatVersionThree; return false; } - size_t expected_body_length = - (static_cast<size_t>(serialized_message[1]) << 8) | - (static_cast<size_t>(serialized_message[2]) << 0); + uint16_t expected_body_length = + (static_cast<uint8_t>(serialized_message[1]) << 8) | + (static_cast<uint8_t>(serialized_message[2]) << 0); size_t expected_message_length = kHeaderLength + expected_body_length; if (serialized_message.size() < expected_message_length) { *is_incomplete_message = true; return false; } if (serialized_message.size() != expected_message_length) { - VLOG(1) << "Error: Invalid message length. Got " - << serialized_message.size() << ", expected " - << expected_message_length; + PA_LOG(WARNING) << "Error: Invalid message length. Got " + << serialized_message.size() << ", expected " + << expected_message_length; return false; } @@ -82,7 +85,7 @@ scoped_ptr<WireMessage> WireMessage::Deserialize( scoped_ptr<base::Value> body_value(base::JSONReader::DeprecatedRead( serialized_message.substr(kHeaderLength))); if (!body_value || !body_value->IsType(base::Value::TYPE_DICTIONARY)) { - VLOG(1) << "Error: Unable to parse message as JSON."; + PA_LOG(WARNING) << "Error: Unable to parse message as JSON."; return scoped_ptr<WireMessage>(); } @@ -98,13 +101,13 @@ scoped_ptr<WireMessage> WireMessage::Deserialize( std::string payload_base64; if (!body->GetString(kPayloadKey, &payload_base64) || payload_base64.empty()) { - VLOG(1) << "Error: Missing payload."; + PA_LOG(WARNING) << "Error: Missing payload."; return scoped_ptr<WireMessage>(); } std::string payload; if (!Base64UrlDecode(payload_base64, &payload)) { - VLOG(1) << "Error: Invalid base64 encoding for payload."; + PA_LOG(WARNING) << "Error: Invalid base64 encoding for payload."; return scoped_ptr<WireMessage>(); } @@ -112,8 +115,44 @@ scoped_ptr<WireMessage> WireMessage::Deserialize( } std::string WireMessage::Serialize() const { - // TODO(isherman): Implement. - return "This method is not yet implemented."; + if (payload_.empty()) { + PA_LOG(ERROR) << "Failed to serialize empty wire message."; + return std::string(); + } + + // Create JSON body containing permit id and payload. + base::DictionaryValue body; + if (!permit_id_.empty()) + body.SetString(kPermitIdKey, permit_id_); + + std::string base64_payload; + Base64UrlEncode(payload_, &base64_payload); + body.SetString(kPayloadKey, base64_payload); + + std::string json_body; + if (!base::JSONWriter::Write(body, &json_body)) { + PA_LOG(ERROR) << "Failed to convert WireMessage body to JSON: " << body; + return std::string(); + } + + // Create header containing version and payload size. + size_t body_size = json_body.size(); + if (body_size > std::numeric_limits<uint16_t>::max()) { + PA_LOG(ERROR) << "Can not create WireMessage because body size exceeds " + << "16-bit unsigned integer: " << body_size; + return std::string(); + } + + uint8_t header[] = { + static_cast<uint8_t>(kMessageFormatVersionThree), + static_cast<uint8_t>((body_size >> 8) & 0xFF), + static_cast<uint8_t>(body_size & 0xFF), + }; + static_assert(sizeof(header) == kHeaderLength, "Malformed header."); + + std::string header_string(kHeaderLength, 0); + std::memcpy(&header_string[0], header, kHeaderLength); + return header_string + json_body; } WireMessage::WireMessage(const std::string& permit_id, diff --git a/components/proximity_auth/wire_message_unittest.cc b/components/proximity_auth/wire_message_unittest.cc index 68ce086..4411ff8 100644 --- a/components/proximity_auth/wire_message_unittest.cc +++ b/components/proximity_auth/wire_message_unittest.cc @@ -4,6 +4,7 @@ #include "components/proximity_auth/wire_message.h" +#include "base/strings/string_util.h" #include "testing/gtest/include/gtest/gtest.h" namespace proximity_auth { @@ -176,4 +177,60 @@ TEST(ProximityAuthWireMessage, Deserialize_ValidMessageWithExtraUnknownFields) { EXPECT_EQ("a", message->payload()); } +TEST(ProximityAuthWireMessage, Deserialize_SizeEquals0x01FF) { + // Create a message with a body of 0x01FF bytes to test the size contained in + // the header is parsed correctly. + std::string header("\3\x01\xff", 3); + char json_template[] = "{\"payload\":\"YQ==\", \"filler\":\"$1\"}"; + // Add 3 to the size to take into account the "$1" and NUL terminator ("\0") + // characters in |json_template|. + uint16_t filler_size = 0x01ff - sizeof(json_template) + 3; + std::string filler(filler_size, 'F'); + + std::string body = base::ReplaceStringPlaceholders( + json_template, std::vector<std::string>(1u, filler), nullptr); + std::string serialized_message = header + body; + + bool is_incomplete; + scoped_ptr<WireMessage> message = + WireMessage::Deserialize(serialized_message, &is_incomplete); + EXPECT_FALSE(is_incomplete); + ASSERT_TRUE(message); + EXPECT_EQ("a", message->payload()); +} + +TEST(ProximityAuthWireMessage, Serialize_WithPermitId) { + WireMessage message1("example id", "example payload"); + std::string bytes = message1.Serialize(); + ASSERT_FALSE(bytes.empty()); + + bool is_incomplete; + scoped_ptr<WireMessage> message2 = + WireMessage::Deserialize(bytes, &is_incomplete); + EXPECT_FALSE(is_incomplete); + ASSERT_TRUE(message2); + EXPECT_EQ("example id", message2->permit_id()); + EXPECT_EQ("example payload", message2->payload()); +} + +TEST(ProximityAuthWireMessage, Serialize_WithoutPermitId) { + WireMessage message1(std::string(), "example payload"); + std::string bytes = message1.Serialize(); + ASSERT_FALSE(bytes.empty()); + + bool is_incomplete; + scoped_ptr<WireMessage> message2 = + WireMessage::Deserialize(bytes, &is_incomplete); + EXPECT_FALSE(is_incomplete); + ASSERT_TRUE(message2); + EXPECT_EQ(std::string(), message2->permit_id()); + EXPECT_EQ("example payload", message2->payload()); +} + +TEST(ProximityAuthWireMessage, Serialize_FailsWithoutPayload) { + WireMessage message1("example id", std::string()); + std::string bytes = message1.Serialize(); + EXPECT_TRUE(bytes.empty()); +} + } // namespace proximity_auth |