summaryrefslogtreecommitdiffstats
path: root/extensions/browser/api/cast_channel/cast_framer.cc
blob: 6e3075d73994a9b80911c3afb36ce7f627035013 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "extensions/browser/api/cast_channel/cast_framer.h"

#include <stdlib.h>

#include <limits>

#include "base/numerics/safe_conversions.h"
#include "base/strings/string_number_conversions.h"
#include "base/sys_byteorder.h"
#include "extensions/common/api/cast_channel/cast_channel.pb.h"

namespace extensions {
namespace api {
namespace cast_channel {
// Keeps a reference to |input_buffer|, which Ingest() parses frames out of,
// starts with no error latched, and places the framer at the start of a
// header via Reset().
MessageFramer::MessageFramer(scoped_refptr<net::GrowableIOBuffer> input_buffer)
    : input_buffer_(input_buffer),
      error_(false) {
  Reset();
}

// Nothing to release explicitly; members clean up after themselves.
MessageFramer::~MessageFramer() {}

// A freshly constructed header describes an empty (zero-length) body.
MessageFramer::MessageHeader::MessageHeader()
    : message_size(0) {}

void MessageFramer::MessageHeader::SetMessageSize(size_t size) {
  DCHECK_LT(size, static_cast<size_t>(std::numeric_limits<uint32_t>::max()));
  DCHECK_GT(size, 0U);
  message_size = size;
}

// TODO(mfoltz): Investigate replacing header serialization with base::Pickle,
// if bit-for-bit compatible.
void MessageFramer::MessageHeader::PrependToString(std::string* str) {
  MessageHeader output = *this;
  output.message_size = base::HostToNet32(message_size);
  size_t header_size = MessageHeader::header_size();
  scoped_ptr<char, base::FreeDeleter> char_array(
      static_cast<char*>(malloc(header_size)));
  memcpy(char_array.get(), &output, header_size);
  str->insert(0, char_array.get(), header_size);
}

// TODO(mfoltz): Investigate replacing header deserialization with base::Pickle,
// if bit-for-bit compatible.
// Reads a header from |data| and stores the decoded body length in |header|.
// |data| must point at no fewer than header_size() readable bytes holding a
// 32-bit length in network byte order.
void MessageFramer::MessageHeader::Deserialize(char* data,
                                               MessageHeader* header) {
  uint32_t wire_length;
  memcpy(&wire_length, data, header_size());
  const uint32_t host_length = base::NetToHost32(wire_length);
  header->message_size = base::checked_cast<size_t>(host_length);
}

// static
// Size in bytes of the serialized header: a single 32-bit length prefix.
size_t MessageFramer::MessageHeader::header_size() {
  const size_t kLengthPrefixBytes = sizeof(uint32_t);
  return kLengthPrefixBytes;
}

// static
// Largest body length, in bytes, that Serialize() will frame and that Ingest()
// will accept from a received header.
size_t MessageFramer::MessageHeader::max_message_size() {
  const size_t kMaxBodyBytes = 65535;
  return kMaxBodyBytes;
}

// Returns a human-readable description such as "{message_size: 42}".
std::string MessageFramer::MessageHeader::ToString() {
  std::string description("{message_size: ");
  description += base::UintToString(static_cast<uint32_t>(message_size));
  description += "}";
  return description;
}

// Serializes |message_proto| into |message_data| and prefixes it with a
// MessageHeader carrying the body length. Returns true on success. On failure
// (serialization error, empty body, or body larger than
// MessageHeader::max_message_size()), |message_data| is cleared and false is
// returned.
// static
bool MessageFramer::Serialize(const CastMessage& message_proto,
                              std::string* message_data) {
  DCHECK(message_data);
  // SerializeToString() reports failure through its return value; a message
  // that could not be serialized must not be framed and sent.
  if (!message_proto.SerializeToString(message_data)) {
    message_data->clear();
    return false;
  }
  const size_t message_size = message_data->size();
  // SetMessageSize() requires a non-zero size, and bodies above
  // max_message_size() are not framed.
  if (message_size == 0 || message_size > MessageHeader::max_message_size()) {
    message_data->clear();
    return false;
  }
  MessageHeader header;
  header.SetMessageSize(message_size);
  header.PrependToString(message_data);
  return true;
}

// Returns how many more bytes must be supplied via Ingest() to complete the
// element (header or body) currently being read. Returns 0 once an error has
// been latched.
size_t MessageFramer::BytesRequested() {
  if (error_) {
    return 0;
  }

  switch (current_element_) {
    case HEADER: {
      const size_t bytes_left =
          MessageHeader::header_size() - message_bytes_received_;
      DCHECK_LE(bytes_left, MessageHeader::header_size());
      VLOG(2) << "Bytes needed for header: " << bytes_left;
      return bytes_left;
    }
    case BODY: {
      const size_t bytes_left =
          (body_size_ + MessageHeader::header_size()) - message_bytes_received_;
      // Ingest() accepts bodies of up to max_message_size() bytes, and the
      // header has already been consumed in this state, so at most that many
      // body bytes can still be outstanding. A bound of
      // max_message_size() - header_size() would fire for valid large bodies.
      DCHECK_LE(bytes_left, MessageHeader::max_message_size());
      VLOG(2) << "Bytes needed for body: " << bytes_left;
      return bytes_left;
    }
    default:
      NOTREACHED() << "Unhandled packet element type.";
      return 0;
  }
}

// Accounts for |num_bytes| new bytes that have been written into
// |input_buffer_| at its current offset, and advances the framing state.
//
// Returns the parsed CastMessage once a complete frame (header + body) has
// been received, setting |*message_length| to the body size and resetting the
// framer for the next frame. Otherwise returns an empty scoped_ptr with
// |*message_length| set to 0.
//
// On an invalid header or an unparseable body, sets |*error| to
// CHANNEL_ERROR_INVALID_MESSAGE and latches the error: every later call fails
// the same way.
scoped_ptr<CastMessage> MessageFramer::Ingest(size_t num_bytes,
                                              size_t* message_length,
                                              ChannelError* error) {
  DCHECK(error);
  DCHECK(message_length);
  if (error_) {
    *error = CHANNEL_ERROR_INVALID_MESSAGE;
    return scoped_ptr<CastMessage>();
  }

  DCHECK_EQ(base::checked_cast<int32_t>(message_bytes_received_),
            input_buffer_->offset());
  CHECK_LE(num_bytes, BytesRequested());
  message_bytes_received_ += num_bytes;
  *error = CHANNEL_ERROR_NONE;
  *message_length = 0;
  switch (current_element_) {
    case HEADER:
      if (BytesRequested() == 0) {
        MessageHeader header;
        MessageHeader::Deserialize(input_buffer_->StartOfBuffer(), &header);
        if (header.message_size > MessageHeader::max_message_size()) {
          VLOG(1) << "Error parsing header (message size too large).";
          *error = CHANNEL_ERROR_INVALID_MESSAGE;
          error_ = true;
          return scoped_ptr<CastMessage>();
        }
        // A zero-length body would put the framer in BODY with
        // BytesRequested() == 0, asking the caller to read zero bytes before
        // the frame can complete. Serialize()/SetMessageSize() never produce
        // such frames, so reject them as invalid.
        if (header.message_size == 0) {
          VLOG(1) << "Error parsing header (message size is zero).";
          *error = CHANNEL_ERROR_INVALID_MESSAGE;
          error_ = true;
          return scoped_ptr<CastMessage>();
        }
        current_element_ = BODY;
        body_size_ = header.message_size;
      }
      break;
    case BODY:
      if (BytesRequested() == 0) {
        scoped_ptr<CastMessage> parsed_message(new CastMessage);
        // body_size_ <= max_message_size(), so it always fits the int size
        // parameter of ParseFromArray().
        if (!parsed_message->ParseFromArray(
                input_buffer_->StartOfBuffer() + MessageHeader::header_size(),
                base::checked_cast<int>(body_size_))) {
          VLOG(1) << "Error parsing packet body.";
          *error = CHANNEL_ERROR_INVALID_MESSAGE;
          error_ = true;
          return scoped_ptr<CastMessage>();
        }
        *message_length = body_size_;
        Reset();
        return parsed_message;
      }
      break;
    default:
      NOTREACHED() << "Unhandled packet element type.";
      return scoped_ptr<CastMessage>();
  }

  // Subsequent reads land right after the bytes received so far.
  input_buffer_->set_offset(message_bytes_received_);
  return scoped_ptr<CastMessage>();
}

// Returns the framer to the start of a new frame: nothing received yet, the
// header is the next element to read, and the buffer offset is rewound to 0.
void MessageFramer::Reset() {
  input_buffer_->set_offset(0);
  body_size_ = 0;
  message_bytes_received_ = 0;
  current_element_ = HEADER;
}

}  // namespace cast_channel
}  // namespace api
}  // namespace extensions