summaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
authorbyungchul <byungchul@chromium.org>2014-08-25 16:27:46 -0700
committerCommit bot <commit-bot@chromium.org>2014-08-25 23:39:36 +0000
commit38c3ae72c9743dbe172779477917bf24bc25ab97 (patch)
treedb9494d10bbe65ca83e265f63d82fbef970bc0ac /net
parentb71e30c90d27b6af56e80db25236019d007b8055 (diff)
downloadchromium_src-38c3ae72c9743dbe172779477917bf24bc25ab97.zip
chromium_src-38c3ae72c9743dbe172779477917bf24bc25ab97.tar.gz
chromium_src-38c3ae72c9743dbe172779477917bf24bc25ab97.tar.bz2
Revert "Revert of Replace StreamListenSocket with StreamSocket in HttpServer. (patchset #29 of https://codereview.chromium.org/296053012/)"
This reverts commit 0b2f33f4a88efbd203b0623324ad4114e3bb9d23. This is relanding CL of https://codereview.chromium.org/296053012/, which broke http server unittests because http server doesn't send response synchronously any more. This CL fixes unittests by reading responses completely. Patch set #1 is same to the original CL. Patch set #2 is the diff. BUG=371906 TBR=pfeldman@chromium.org,darin@chromium.org,gunsch@chromium.org,mnaganov@chromium.org Review URL: https://codereview.chromium.org/487013003 Cr-Commit-Position: refs/heads/master@{#291784}
Diffstat (limited to 'net')
-rw-r--r--net/net.gypi1
-rw-r--r--net/server/http_connection.cc161
-rw-r--r--net/server/http_connection.h123
-rw-r--r--net/server/http_connection_unittest.cc331
-rw-r--r--net/server/http_server.cc257
-rw-r--r--net/server/http_server.h70
-rw-r--r--net/server/http_server_response_info.cc8
-rw-r--r--net/server/http_server_response_info.h3
-rw-r--r--net/server/http_server_unittest.cc164
-rw-r--r--net/server/web_socket.cc163
-rw-r--r--net/server/web_socket.h13
-rw-r--r--net/socket/server_socket.h2
12 files changed, 1052 insertions, 244 deletions
diff --git a/net/net.gypi b/net/net.gypi
index f240eba..b57d853 100644
--- a/net/net.gypi
+++ b/net/net.gypi
@@ -1555,6 +1555,7 @@
'quic/quic_utils_test.cc',
'quic/quic_write_blocked_list_test.cc',
'quic/reliable_quic_stream_test.cc',
+ 'server/http_connection_unittest.cc',
'server/http_server_response_info_unittest.cc',
'server/http_server_unittest.cc',
'socket/client_socket_pool_base_unittest.cc',
diff --git a/net/server/http_connection.cc b/net/server/http_connection.cc
index d433012..3401f81 100644
--- a/net/server/http_connection.cc
+++ b/net/server/http_connection.cc
@@ -4,44 +4,163 @@
#include "net/server/http_connection.h"
-#include "net/server/http_server.h"
-#include "net/server/http_server_response_info.h"
+#include "base/logging.h"
#include "net/server/web_socket.h"
-#include "net/socket/stream_listen_socket.h"
+#include "net/socket/stream_socket.h"
namespace net {
-int HttpConnection::last_id_ = 0;
+HttpConnection::ReadIOBuffer::ReadIOBuffer()
+ : base_(new GrowableIOBuffer()),
+ max_buffer_size_(kDefaultMaxBufferSize) {
+ SetCapacity(kInitialBufSize);
+}
-void HttpConnection::Send(const std::string& data) {
- if (!socket_.get())
- return;
- socket_->Send(data);
+HttpConnection::ReadIOBuffer::~ReadIOBuffer() {
+ data_ = NULL; // base_ owns data_.
+}
+
+int HttpConnection::ReadIOBuffer::GetCapacity() const {
+ return base_->capacity();
+}
+
+void HttpConnection::ReadIOBuffer::SetCapacity(int capacity) {
+ DCHECK_LE(GetSize(), capacity);
+ base_->SetCapacity(capacity);
+ data_ = base_->data();
+}
+
+bool HttpConnection::ReadIOBuffer::IncreaseCapacity() {
+ if (GetCapacity() >= max_buffer_size_) {
+ LOG(ERROR) << "Too large read data is pending: capacity=" << GetCapacity()
+ << ", max_buffer_size=" << max_buffer_size_
+ << ", read=" << GetSize();
+ return false;
+ }
+
+ int new_capacity = GetCapacity() * kCapacityIncreaseFactor;
+ if (new_capacity > max_buffer_size_)
+ new_capacity = max_buffer_size_;
+ SetCapacity(new_capacity);
+ return true;
+}
+
+char* HttpConnection::ReadIOBuffer::StartOfBuffer() const {
+ return base_->StartOfBuffer();
+}
+
+int HttpConnection::ReadIOBuffer::GetSize() const {
+ return base_->offset();
+}
+
+void HttpConnection::ReadIOBuffer::DidRead(int bytes) {
+ DCHECK_GE(RemainingCapacity(), bytes);
+ base_->set_offset(base_->offset() + bytes);
+ data_ = base_->data();
+}
+
+int HttpConnection::ReadIOBuffer::RemainingCapacity() const {
+ return base_->RemainingCapacity();
+}
+
+void HttpConnection::ReadIOBuffer::DidConsume(int bytes) {
+ int previous_size = GetSize();
+ int unconsumed_size = previous_size - bytes;
+ DCHECK_LE(0, unconsumed_size);
+ if (unconsumed_size > 0) {
+ // Move unconsumed data to the start of buffer.
+ memmove(StartOfBuffer(), StartOfBuffer() + bytes, unconsumed_size);
+ }
+ base_->set_offset(unconsumed_size);
+ data_ = base_->data();
+
+ // If capacity is too big, reduce it.
+ if (GetCapacity() > kMinimumBufSize &&
+ GetCapacity() > previous_size * kCapacityIncreaseFactor) {
+ int new_capacity = GetCapacity() / kCapacityIncreaseFactor;
+ if (new_capacity < kMinimumBufSize)
+ new_capacity = kMinimumBufSize;
+ // realloc() within GrowableIOBuffer::SetCapacity() could move data even
+ // when size is reduced. If unconsumed_size == 0, i.e. no data exists in
+ // the buffer, free internal buffer first to guarantee no data move.
+ if (!unconsumed_size)
+ base_->SetCapacity(0);
+ SetCapacity(new_capacity);
+ }
+}
+
+HttpConnection::QueuedWriteIOBuffer::QueuedWriteIOBuffer()
+ : total_size_(0),
+ max_buffer_size_(kDefaultMaxBufferSize) {
+}
+
+HttpConnection::QueuedWriteIOBuffer::~QueuedWriteIOBuffer() {
+ data_ = NULL; // pending_data_ owns data_.
}
-void HttpConnection::Send(const char* bytes, int len) {
- if (!socket_.get())
+bool HttpConnection::QueuedWriteIOBuffer::IsEmpty() const {
+ return pending_data_.empty();
+}
+
+bool HttpConnection::QueuedWriteIOBuffer::Append(const std::string& data) {
+ if (data.empty())
+ return true;
+
+ if (total_size_ + static_cast<int>(data.size()) > max_buffer_size_) {
+ LOG(ERROR) << "Too large write data is pending: size="
+ << total_size_ + data.size()
+ << ", max_buffer_size=" << max_buffer_size_;
+ return false;
+ }
+
+ pending_data_.push(data);
+ total_size_ += data.size();
+
+ // If new data is the first pending data, updates data_.
+ if (pending_data_.size() == 1)
+ data_ = const_cast<char*>(pending_data_.front().data());
+ return true;
+}
+
+void HttpConnection::QueuedWriteIOBuffer::DidConsume(int size) {
+ DCHECK_GE(total_size_, size);
+ DCHECK_GE(GetSizeToWrite(), size);
+ if (size == 0)
return;
- socket_->Send(bytes, len);
+
+ if (size < GetSizeToWrite()) {
+ data_ += size;
+ } else { // size == GetSizeToWrite(). Updates data_ to next pending data.
+ pending_data_.pop();
+ data_ = IsEmpty() ? NULL : const_cast<char*>(pending_data_.front().data());
+ }
+ total_size_ -= size;
}
-void HttpConnection::Send(const HttpServerResponseInfo& response) {
- Send(response.Serialize());
+int HttpConnection::QueuedWriteIOBuffer::GetSizeToWrite() const {
+ if (IsEmpty()) {
+ DCHECK_EQ(0, total_size_);
+ return 0;
+ }
+ DCHECK_GE(data_, pending_data_.front().data());
+ int consumed = static_cast<int>(data_ - pending_data_.front().data());
+ DCHECK_GT(static_cast<int>(pending_data_.front().size()), consumed);
+ return pending_data_.front().size() - consumed;
}
-HttpConnection::HttpConnection(HttpServer* server,
- scoped_ptr<StreamListenSocket> sock)
- : server_(server),
- socket_(sock.Pass()) {
- id_ = last_id_++;
+HttpConnection::HttpConnection(int id, scoped_ptr<StreamSocket> socket)
+ : id_(id),
+ socket_(socket.Pass()),
+ read_buf_(new ReadIOBuffer()),
+ write_buf_(new QueuedWriteIOBuffer()) {
}
HttpConnection::~HttpConnection() {
- server_->delegate_->OnClose(id_);
}
-void HttpConnection::Shift(int num_bytes) {
- recv_data_ = recv_data_.substr(num_bytes);
+void HttpConnection::SetWebSocket(scoped_ptr<WebSocket> web_socket) {
+ DCHECK(!web_socket_);
+ web_socket_ = web_socket.Pass();
}
} // namespace net
diff --git a/net/server/http_connection.h b/net/server/http_connection.h
index 17faa46..c7225e1 100644
--- a/net/server/http_connection.h
+++ b/net/server/http_connection.h
@@ -5,43 +5,130 @@
#ifndef NET_SERVER_HTTP_CONNECTION_H_
#define NET_SERVER_HTTP_CONNECTION_H_
+#include <queue>
#include <string>
#include "base/basictypes.h"
+#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
-#include "net/http/http_status_code.h"
+#include "net/base/io_buffer.h"
namespace net {
-class HttpServer;
-class HttpServerResponseInfo;
-class StreamListenSocket;
+class StreamSocket;
class WebSocket;
+// A container which has all information of an http connection. It includes
+// id, underlying socket, and pending read/write data.
class HttpConnection {
public:
- ~HttpConnection();
+ // IOBuffer for data read. It's a wrapper around GrowableIOBuffer, with more
+ // functions for buffer management. It moves unconsumed data to the start of
+ // buffer.
+ class ReadIOBuffer : public IOBuffer {
+ public:
+ static const int kInitialBufSize = 1024;
+ static const int kMinimumBufSize = 128;
+ static const int kCapacityIncreaseFactor = 2;
+ static const int kDefaultMaxBufferSize = 1 * 1024 * 1024; // 1 Mbytes.
+
+ ReadIOBuffer();
+
+ // Capacity.
+ int GetCapacity() const;
+ void SetCapacity(int capacity);
+ // Increases capacity and returns true if capacity is not beyond the limit.
+ bool IncreaseCapacity();
+
+ // Start of read data.
+ char* StartOfBuffer() const;
+ // Returns the bytes of read data.
+ int GetSize() const;
+ // More read data was appended.
+ void DidRead(int bytes);
+ // Capacity for which more read data can be appended.
+ int RemainingCapacity() const;
+
+ // Removes consumed data and moves unconsumed data to the start of buffer.
+ void DidConsume(int bytes);
+
+ // Limit of how much internal capacity can increase.
+ int max_buffer_size() const { return max_buffer_size_; }
+ void set_max_buffer_size(int max_buffer_size) {
+ max_buffer_size_ = max_buffer_size;
+ }
+
+ private:
+ virtual ~ReadIOBuffer();
+
+ scoped_refptr<GrowableIOBuffer> base_;
+ int max_buffer_size_;
+
+ DISALLOW_COPY_AND_ASSIGN(ReadIOBuffer);
+ };
+
+ // IOBuffer of pending data to write which has a queue of pending data. Each
+ // pending data is stored in std::string. data() is the data of first
+ // std::string stored.
+ class QueuedWriteIOBuffer : public IOBuffer {
+ public:
+ static const int kDefaultMaxBufferSize = 1 * 1024 * 1024; // 1 Mbytes.
+
+ QueuedWriteIOBuffer();
+
+ // Whether or not pending data exists.
+ bool IsEmpty() const;
- void Send(const std::string& data);
- void Send(const char* bytes, int len);
- void Send(const HttpServerResponseInfo& response);
+ // Appends new pending data and returns true if total size doesn't exceed
+ // the limit, |total_size_limit_|. It would change data() if new data is
+ // the first pending data.
+ bool Append(const std::string& data);
- void Shift(int num_bytes);
+ // Consumes data and changes data() accordingly. It cannot be more than
+ // GetSizeToWrite().
+ void DidConsume(int size);
+
+ // Gets size of data to write this time. It is NOT total data size.
+ int GetSizeToWrite() const;
+
+ // Total size of all pending data.
+ int total_size() const { return total_size_; }
+
+ // Limit of how much data can be pending.
+ int max_buffer_size() const { return max_buffer_size_; }
+ void set_max_buffer_size(int max_buffer_size) {
+ max_buffer_size_ = max_buffer_size;
+ }
+
+ private:
+ virtual ~QueuedWriteIOBuffer();
+
+ std::queue<std::string> pending_data_;
+ int total_size_;
+ int max_buffer_size_;
+
+ DISALLOW_COPY_AND_ASSIGN(QueuedWriteIOBuffer);
+ };
+
+ HttpConnection(int id, scoped_ptr<StreamSocket> socket);
+ ~HttpConnection();
- const std::string& recv_data() const { return recv_data_; }
int id() const { return id_; }
+ StreamSocket* socket() const { return socket_.get(); }
+ ReadIOBuffer* read_buf() const { return read_buf_.get(); }
+ QueuedWriteIOBuffer* write_buf() const { return write_buf_.get(); }
- private:
- friend class HttpServer;
- static int last_id_;
+ WebSocket* web_socket() const { return web_socket_.get(); }
+ void SetWebSocket(scoped_ptr<WebSocket> web_socket);
- HttpConnection(HttpServer* server, scoped_ptr<StreamListenSocket> sock);
+ private:
+ const int id_;
+ const scoped_ptr<StreamSocket> socket_;
+ const scoped_refptr<ReadIOBuffer> read_buf_;
+ const scoped_refptr<QueuedWriteIOBuffer> write_buf_;
- HttpServer* server_;
- scoped_ptr<StreamListenSocket> socket_;
scoped_ptr<WebSocket> web_socket_;
- std::string recv_data_;
- int id_;
+
DISALLOW_COPY_AND_ASSIGN(HttpConnection);
};
diff --git a/net/server/http_connection_unittest.cc b/net/server/http_connection_unittest.cc
new file mode 100644
index 0000000..488fd6f
--- /dev/null
+++ b/net/server/http_connection_unittest.cc
@@ -0,0 +1,331 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/server/http_connection.h"
+
+#include <string>
+
+#include "base/memory/ref_counted.h"
+#include "base/strings/string_piece.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+namespace {
+
+std::string GetTestString(int size) {
+ std::string test_string;
+ for (int i = 0; i < size; ++i) {
+ test_string.push_back('A' + (i % 26));
+ }
+ return test_string;
+}
+
+TEST(HttpConnectionTest, ReadIOBuffer_SetCapacity) {
+ scoped_refptr<HttpConnection::ReadIOBuffer> buffer(
+ new HttpConnection::ReadIOBuffer);
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0,
+ buffer->GetCapacity());
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0,
+ buffer->RemainingCapacity());
+ EXPECT_EQ(0, buffer->GetSize());
+
+ const int kNewCapacity = HttpConnection::ReadIOBuffer::kInitialBufSize + 128;
+ buffer->SetCapacity(kNewCapacity);
+ EXPECT_EQ(kNewCapacity, buffer->GetCapacity());
+ EXPECT_EQ(kNewCapacity, buffer->RemainingCapacity());
+ EXPECT_EQ(0, buffer->GetSize());
+}
+
+TEST(HttpConnectionTest, ReadIOBuffer_SetCapacity_WithData) {
+ scoped_refptr<HttpConnection::ReadIOBuffer> buffer(
+ new HttpConnection::ReadIOBuffer);
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0,
+ buffer->GetCapacity());
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0,
+ buffer->RemainingCapacity());
+
+ // Write arbitrary data up to kInitialBufSize.
+ const std::string kReadData(
+ GetTestString(HttpConnection::ReadIOBuffer::kInitialBufSize));
+ memcpy(buffer->data(), kReadData.data(), kReadData.size());
+ buffer->DidRead(kReadData.size());
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0,
+ buffer->GetCapacity());
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize -
+ static_cast<int>(kReadData.size()),
+ buffer->RemainingCapacity());
+ EXPECT_EQ(static_cast<int>(kReadData.size()), buffer->GetSize());
+ EXPECT_EQ(kReadData,
+ base::StringPiece(buffer->StartOfBuffer(), buffer->GetSize()));
+
+ // Check if read data in the buffer is same after SetCapacity().
+ const int kNewCapacity = HttpConnection::ReadIOBuffer::kInitialBufSize + 128;
+ buffer->SetCapacity(kNewCapacity);
+ EXPECT_EQ(kNewCapacity, buffer->GetCapacity());
+ EXPECT_EQ(kNewCapacity - static_cast<int>(kReadData.size()),
+ buffer->RemainingCapacity());
+ EXPECT_EQ(static_cast<int>(kReadData.size()), buffer->GetSize());
+ EXPECT_EQ(kReadData,
+ base::StringPiece(buffer->StartOfBuffer(), buffer->GetSize()));
+}
+
+TEST(HttpConnectionTest, ReadIOBuffer_IncreaseCapacity) {
+ scoped_refptr<HttpConnection::ReadIOBuffer> buffer(
+ new HttpConnection::ReadIOBuffer);
+ EXPECT_TRUE(buffer->IncreaseCapacity());
+ const int kExpectedInitialBufSize =
+ HttpConnection::ReadIOBuffer::kInitialBufSize *
+ HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor;
+ EXPECT_EQ(kExpectedInitialBufSize, buffer->GetCapacity());
+ EXPECT_EQ(kExpectedInitialBufSize, buffer->RemainingCapacity());
+ EXPECT_EQ(0, buffer->GetSize());
+
+ // Increase capacity until it fails.
+ while (buffer->IncreaseCapacity());
+ EXPECT_FALSE(buffer->IncreaseCapacity());
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize + 0,
+ buffer->max_buffer_size());
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize + 0,
+ buffer->GetCapacity());
+
+ // Enlarge capacity limit.
+ buffer->set_max_buffer_size(buffer->max_buffer_size() * 2);
+ EXPECT_TRUE(buffer->IncreaseCapacity());
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize *
+ HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor,
+ buffer->GetCapacity());
+
+ // Shrink capacity limit. It doesn't change capacity itself.
+ buffer->set_max_buffer_size(
+ HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize / 2);
+ EXPECT_FALSE(buffer->IncreaseCapacity());
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize *
+ HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor,
+ buffer->GetCapacity());
+}
+
+TEST(HttpConnectionTest, ReadIOBuffer_IncreaseCapacity_WithData) {
+ scoped_refptr<HttpConnection::ReadIOBuffer> buffer(
+ new HttpConnection::ReadIOBuffer);
+ EXPECT_TRUE(buffer->IncreaseCapacity());
+ const int kExpectedInitialBufSize =
+ HttpConnection::ReadIOBuffer::kInitialBufSize *
+ HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor;
+ EXPECT_EQ(kExpectedInitialBufSize, buffer->GetCapacity());
+ EXPECT_EQ(kExpectedInitialBufSize, buffer->RemainingCapacity());
+ EXPECT_EQ(0, buffer->GetSize());
+
+ // Write arbitrary data up to kExpectedInitialBufSize.
+ std::string kReadData(GetTestString(kExpectedInitialBufSize));
+ memcpy(buffer->data(), kReadData.data(), kReadData.size());
+ buffer->DidRead(kReadData.size());
+ EXPECT_EQ(kExpectedInitialBufSize, buffer->GetCapacity());
+ EXPECT_EQ(kExpectedInitialBufSize - static_cast<int>(kReadData.size()),
+ buffer->RemainingCapacity());
+ EXPECT_EQ(static_cast<int>(kReadData.size()), buffer->GetSize());
+ EXPECT_EQ(kReadData,
+ base::StringPiece(buffer->StartOfBuffer(), buffer->GetSize()));
+
+ // Increase capacity until it fails and check if read data in the buffer is
+ // same.
+ while (buffer->IncreaseCapacity());
+ EXPECT_FALSE(buffer->IncreaseCapacity());
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize + 0,
+ buffer->max_buffer_size());
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize + 0,
+ buffer->GetCapacity());
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize -
+ static_cast<int>(kReadData.size()),
+ buffer->RemainingCapacity());
+ EXPECT_EQ(static_cast<int>(kReadData.size()), buffer->GetSize());
+ EXPECT_EQ(kReadData,
+ base::StringPiece(buffer->StartOfBuffer(), buffer->GetSize()));
+}
+
+TEST(HttpConnectionTest, ReadIOBuffer_DidRead_DidConsume) {
+ scoped_refptr<HttpConnection::ReadIOBuffer> buffer(
+ new HttpConnection::ReadIOBuffer);
+ const char* start_of_buffer = buffer->StartOfBuffer();
+ EXPECT_EQ(start_of_buffer, buffer->data());
+
+ // Read data.
+ const int kReadLength = 128;
+ const std::string kReadData(GetTestString(kReadLength));
+ memcpy(buffer->data(), kReadData.data(), kReadLength);
+ buffer->DidRead(kReadLength);
+ // No change in total capacity.
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0,
+ buffer->GetCapacity());
+ // Change in unused capacity because of read data.
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize - kReadLength,
+ buffer->RemainingCapacity());
+ EXPECT_EQ(kReadLength, buffer->GetSize());
+ // No change in start pointers of read data.
+ EXPECT_EQ(start_of_buffer, buffer->StartOfBuffer());
+ // Change in start pointer of unused buffer.
+ EXPECT_EQ(start_of_buffer + kReadLength, buffer->data());
+ // Test read data.
+ EXPECT_EQ(kReadData, std::string(buffer->StartOfBuffer(), buffer->GetSize()));
+
+ // Consume data partially.
+ const int kConsumedLength = 32;
+ ASSERT_LT(kConsumedLength, kReadLength);
+ buffer->DidConsume(kConsumedLength);
+ // Capacity reduced because read data was too small comparing to capacity.
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize /
+ HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor,
+ buffer->GetCapacity());
+ // Change in unused capacity because of read data.
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize /
+ HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor -
+ kReadLength + kConsumedLength,
+ buffer->RemainingCapacity());
+ // Change in read size.
+ EXPECT_EQ(kReadLength - kConsumedLength, buffer->GetSize());
+ // Start data could be changed even when capacity is reduced.
+ start_of_buffer = buffer->StartOfBuffer();
+ // Change in start pointer of unused buffer.
+ EXPECT_EQ(start_of_buffer + kReadLength - kConsumedLength, buffer->data());
+ // Change in read data.
+ EXPECT_EQ(kReadData.substr(kConsumedLength),
+ std::string(buffer->StartOfBuffer(), buffer->GetSize()));
+
+ // Read more data.
+ const int kReadLength2 = 64;
+ buffer->DidRead(kReadLength2);
+ // No change in total capacity.
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize /
+ HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor,
+ buffer->GetCapacity());
+ // Change in unused capacity because of read data.
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize /
+ HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor -
+ kReadLength + kConsumedLength - kReadLength2,
+ buffer->RemainingCapacity());
+ // Change in read size
+ EXPECT_EQ(kReadLength - kConsumedLength + kReadLength2, buffer->GetSize());
+ // No change in start pointer of read part.
+ EXPECT_EQ(start_of_buffer, buffer->StartOfBuffer());
+ // Change in start pointer of unused buffer.
+ EXPECT_EQ(start_of_buffer + kReadLength - kConsumedLength + kReadLength2,
+ buffer->data());
+
+ // Consume data fully.
+ buffer->DidConsume(kReadLength - kConsumedLength + kReadLength2);
+ // Capacity reduced again because read data was too small.
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize /
+ HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor /
+ HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor,
+ buffer->GetCapacity());
+ EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize /
+ HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor /
+ HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor,
+ buffer->RemainingCapacity());
+ // All reverts to initial because no data is left.
+ EXPECT_EQ(0, buffer->GetSize());
+ // Start data could be changed even when capacity is reduced.
+ start_of_buffer = buffer->StartOfBuffer();
+ EXPECT_EQ(start_of_buffer, buffer->data());
+}
+
+TEST(HttpConnectionTest, QueuedWriteIOBuffer_Append_DidConsume) {
+ scoped_refptr<HttpConnection::QueuedWriteIOBuffer> buffer(
+ new HttpConnection::QueuedWriteIOBuffer());
+ EXPECT_TRUE(buffer->IsEmpty());
+ EXPECT_EQ(0, buffer->GetSizeToWrite());
+ EXPECT_EQ(0, buffer->total_size());
+
+ const std::string kData("data to write");
+ EXPECT_TRUE(buffer->Append(kData));
+ EXPECT_FALSE(buffer->IsEmpty());
+ EXPECT_EQ(static_cast<int>(kData.size()), buffer->GetSizeToWrite());
+ EXPECT_EQ(static_cast<int>(kData.size()), buffer->total_size());
+ // First data to write is same to kData.
+ EXPECT_EQ(kData, base::StringPiece(buffer->data(), buffer->GetSizeToWrite()));
+
+ const std::string kData2("more data to write");
+ EXPECT_TRUE(buffer->Append(kData2));
+ EXPECT_FALSE(buffer->IsEmpty());
+ // No change in size to write.
+ EXPECT_EQ(static_cast<int>(kData.size()), buffer->GetSizeToWrite());
+ // Change in total size.
+ EXPECT_EQ(static_cast<int>(kData.size() + kData2.size()),
+ buffer->total_size());
+ // First data to write has not been changed. Same to kData.
+ EXPECT_EQ(kData, base::StringPiece(buffer->data(), buffer->GetSizeToWrite()));
+
+ // Consume data partially.
+ const int kConsumedLength = kData.length() - 1;
+ buffer->DidConsume(kConsumedLength);
+ EXPECT_FALSE(buffer->IsEmpty());
+ // Change in size to write.
+ EXPECT_EQ(static_cast<int>(kData.size()) - kConsumedLength,
+ buffer->GetSizeToWrite());
+ // Change in total size.
+ EXPECT_EQ(static_cast<int>(kData.size() + kData2.size()) - kConsumedLength,
+ buffer->total_size());
+ // First data to write has shrinked.
+ EXPECT_EQ(kData.substr(kConsumedLength),
+ base::StringPiece(buffer->data(), buffer->GetSizeToWrite()));
+
+ // Consume first data fully.
+ buffer->DidConsume(kData.size() - kConsumedLength);
+ EXPECT_FALSE(buffer->IsEmpty());
+ // Now, size to write is size of data added second.
+ EXPECT_EQ(static_cast<int>(kData2.size()), buffer->GetSizeToWrite());
+ // Change in total size.
+ EXPECT_EQ(static_cast<int>(kData2.size()), buffer->total_size());
+ // First data to write has changed to kData2.
+ EXPECT_EQ(kData2,
+ base::StringPiece(buffer->data(), buffer->GetSizeToWrite()));
+
+ // Consume second data fully.
+ buffer->DidConsume(kData2.size());
+ EXPECT_TRUE(buffer->IsEmpty());
+ EXPECT_EQ(0, buffer->GetSizeToWrite());
+ EXPECT_EQ(0, buffer->total_size());
+}
+
+TEST(HttpConnectionTest, QueuedWriteIOBuffer_TotalSizeLimit) {
+ scoped_refptr<HttpConnection::QueuedWriteIOBuffer> buffer(
+ new HttpConnection::QueuedWriteIOBuffer());
+ EXPECT_EQ(HttpConnection::QueuedWriteIOBuffer::kDefaultMaxBufferSize + 0,
+ buffer->max_buffer_size());
+
+ // Set total size limit very small.
+ buffer->set_max_buffer_size(10);
+
+ const int kDataLength = 4;
+ const std::string kData(kDataLength, 'd');
+ EXPECT_TRUE(buffer->Append(kData));
+ EXPECT_EQ(kDataLength, buffer->total_size());
+ EXPECT_TRUE(buffer->Append(kData));
+ EXPECT_EQ(kDataLength * 2, buffer->total_size());
+
+ // Cannot append more data because it exceeds the limit.
+ EXPECT_FALSE(buffer->Append(kData));
+ EXPECT_EQ(kDataLength * 2, buffer->total_size());
+
+ // Consume data partially.
+ const int kConsumedLength = 2;
+ buffer->DidConsume(kConsumedLength);
+ EXPECT_EQ(kDataLength * 2 - kConsumedLength, buffer->total_size());
+
+ // Can add more data.
+ EXPECT_TRUE(buffer->Append(kData));
+ EXPECT_EQ(kDataLength * 3 - kConsumedLength, buffer->total_size());
+
+ // Cannot append more data because it exceeds the limit.
+ EXPECT_FALSE(buffer->Append(kData));
+ EXPECT_EQ(kDataLength * 3 - kConsumedLength, buffer->total_size());
+
+ // Enlarge limit.
+ buffer->set_max_buffer_size(20);
+ // Can add more data.
+ EXPECT_TRUE(buffer->Append(kData));
+ EXPECT_EQ(kDataLength * 4 - kConsumedLength, buffer->total_size());
+}
+
+} // namespace
+} // namespace net
diff --git a/net/server/http_server.cc b/net/server/http_server.cc
index 043e625..fb0dab3 100644
--- a/net/server/http_server.cc
+++ b/net/server/http_server.cc
@@ -17,14 +17,25 @@
#include "net/server/http_server_request_info.h"
#include "net/server/http_server_response_info.h"
#include "net/server/web_socket.h"
-#include "net/socket/tcp_listen_socket.h"
+#include "net/socket/server_socket.h"
+#include "net/socket/stream_socket.h"
+#include "net/socket/tcp_server_socket.h"
namespace net {
-HttpServer::HttpServer(const StreamListenSocketFactory& factory,
+HttpServer::HttpServer(scoped_ptr<ServerSocket> server_socket,
HttpServer::Delegate* delegate)
- : delegate_(delegate),
- server_(factory.CreateAndListen(this)) {
+ : server_socket_(server_socket.Pass()),
+ delegate_(delegate),
+ last_id_(0),
+ weak_ptr_factory_(this) {
+ DCHECK(server_socket_);
+ DoAcceptLoop();
+}
+
+HttpServer::~HttpServer() {
+ STLDeleteContainerPairSecondPointers(
+ id_to_connection_.begin(), id_to_connection_.end());
}
void HttpServer::AcceptWebSocket(
@@ -33,9 +44,8 @@ void HttpServer::AcceptWebSocket(
HttpConnection* connection = FindConnection(connection_id);
if (connection == NULL)
return;
-
- DCHECK(connection->web_socket_.get());
- connection->web_socket_->Accept(request);
+ DCHECK(connection->web_socket());
+ connection->web_socket()->Accept(request);
}
void HttpServer::SendOverWebSocket(int connection_id,
@@ -43,23 +53,23 @@ void HttpServer::SendOverWebSocket(int connection_id,
HttpConnection* connection = FindConnection(connection_id);
if (connection == NULL)
return;
- DCHECK(connection->web_socket_.get());
- connection->web_socket_->Send(data);
+ DCHECK(connection->web_socket());
+ connection->web_socket()->Send(data);
}
void HttpServer::SendRaw(int connection_id, const std::string& data) {
HttpConnection* connection = FindConnection(connection_id);
if (connection == NULL)
return;
- connection->Send(data);
+
+ bool writing_in_progress = !connection->write_buf()->IsEmpty();
+ if (connection->write_buf()->Append(data) && !writing_in_progress)
+ DoWriteLoop(connection);
}
void HttpServer::SendResponse(int connection_id,
const HttpServerResponseInfo& response) {
- HttpConnection* connection = FindConnection(connection_id);
- if (connection == NULL)
- return;
- connection->Send(response);
+ SendRaw(connection_id, response.Serialize());
}
void HttpServer::Send(int connection_id,
@@ -67,8 +77,9 @@ void HttpServer::Send(int connection_id,
const std::string& data,
const std::string& content_type) {
HttpServerResponseInfo response(status_code);
- response.SetBody(data, content_type);
+ response.SetContentHeaders(data.size(), content_type);
SendResponse(connection_id, response);
+ SendRaw(connection_id, data);
}
void HttpServer::Send200(int connection_id,
@@ -90,108 +101,209 @@ void HttpServer::Close(int connection_id) {
if (connection == NULL)
return;
- // Initiating close from server-side does not lead to the DidClose call.
- // Do it manually here.
- DidClose(connection->socket_.get());
+ id_to_connection_.erase(connection_id);
+ delegate_->OnClose(connection_id);
+
+ // The call stack might have callbacks which still have the pointer of
+ // connection. Instead of referencing connection with ID all the time,
+ // destroys the connection in next run loop to make sure any pending
+ // callbacks in the call stack return.
+ base::MessageLoopProxy::current()->DeleteSoon(FROM_HERE, connection);
}
int HttpServer::GetLocalAddress(IPEndPoint* address) {
- if (!server_)
- return ERR_SOCKET_NOT_CONNECTED;
- return server_->GetLocalAddress(address);
+ return server_socket_->GetLocalAddress(address);
+}
+
+void HttpServer::SetReceiveBufferSize(int connection_id, int32 size) {
+ HttpConnection* connection = FindConnection(connection_id);
+ DCHECK(connection);
+ connection->read_buf()->set_max_buffer_size(size);
}
-void HttpServer::DidAccept(StreamListenSocket* server,
- scoped_ptr<StreamListenSocket> socket) {
- HttpConnection* connection = new HttpConnection(this, socket.Pass());
+void HttpServer::SetSendBufferSize(int connection_id, int32 size) {
+ HttpConnection* connection = FindConnection(connection_id);
+ DCHECK(connection);
+ connection->write_buf()->set_max_buffer_size(size);
+}
+
+void HttpServer::DoAcceptLoop() {
+ int rv;
+ do {
+ rv = server_socket_->Accept(&accepted_socket_,
+ base::Bind(&HttpServer::OnAcceptCompleted,
+ weak_ptr_factory_.GetWeakPtr()));
+ if (rv == ERR_IO_PENDING)
+ return;
+ rv = HandleAcceptResult(rv);
+ } while (rv == OK);
+}
+
+void HttpServer::OnAcceptCompleted(int rv) {
+ if (HandleAcceptResult(rv) == OK)
+ DoAcceptLoop();
+}
+
+int HttpServer::HandleAcceptResult(int rv) {
+ if (rv < 0) {
+ LOG(ERROR) << "Accept error: rv=" << rv;
+ return rv;
+ }
+
+ HttpConnection* connection =
+ new HttpConnection(++last_id_, accepted_socket_.Pass());
id_to_connection_[connection->id()] = connection;
- // TODO(szym): Fix socket access. Make HttpConnection the Delegate.
- socket_to_connection_[connection->socket_.get()] = connection;
+ DoReadLoop(connection);
+ return OK;
}
-void HttpServer::DidRead(StreamListenSocket* socket,
- const char* data,
- int len) {
- HttpConnection* connection = FindConnection(socket);
- DCHECK(connection != NULL);
- if (connection == NULL)
+void HttpServer::DoReadLoop(HttpConnection* connection) {
+ int rv;
+ do {
+ HttpConnection::ReadIOBuffer* read_buf = connection->read_buf();
+ // Increases read buffer size if necessary.
+ if (read_buf->RemainingCapacity() == 0 && !read_buf->IncreaseCapacity()) {
+ Close(connection->id());
+ return;
+ }
+
+ rv = connection->socket()->Read(
+ read_buf,
+ read_buf->RemainingCapacity(),
+ base::Bind(&HttpServer::OnReadCompleted,
+ weak_ptr_factory_.GetWeakPtr(), connection->id()));
+ if (rv == ERR_IO_PENDING)
+ return;
+ rv = HandleReadResult(connection, rv);
+ } while (rv == OK);
+}
+
+void HttpServer::OnReadCompleted(int connection_id, int rv) {
+ HttpConnection* connection = FindConnection(connection_id);
+ if (!connection) // It might be closed right before by write error.
return;
- connection->recv_data_.append(data, len);
- while (connection->recv_data_.length()) {
- if (connection->web_socket_.get()) {
+ if (HandleReadResult(connection, rv) == OK)
+ DoReadLoop(connection);
+}
+
+int HttpServer::HandleReadResult(HttpConnection* connection, int rv) {
+ if (rv <= 0) {
+ Close(connection->id());
+ return rv == 0 ? ERR_CONNECTION_CLOSED : rv;
+ }
+
+ HttpConnection::ReadIOBuffer* read_buf = connection->read_buf();
+ read_buf->DidRead(rv);
+
+ // Handles http requests or websocket messages.
+ while (read_buf->GetSize() > 0) {
+ if (connection->web_socket()) {
std::string message;
- WebSocket::ParseResult result = connection->web_socket_->Read(&message);
+ WebSocket::ParseResult result = connection->web_socket()->Read(&message);
if (result == WebSocket::FRAME_INCOMPLETE)
break;
if (result == WebSocket::FRAME_CLOSE ||
result == WebSocket::FRAME_ERROR) {
Close(connection->id());
- break;
+ return ERR_CONNECTION_CLOSED;
}
delegate_->OnWebSocketMessage(connection->id(), message);
+ if (HasClosedConnection(connection))
+ return ERR_CONNECTION_CLOSED;
continue;
}
HttpServerRequestInfo request;
size_t pos = 0;
- if (!ParseHeaders(connection, &request, &pos))
+ if (!ParseHeaders(read_buf->StartOfBuffer(), read_buf->GetSize(),
+ &request, &pos)) {
break;
+ }
// Sets peer address if exists.
- socket->GetPeerAddress(&request.peer);
+ connection->socket()->GetPeerAddress(&request.peer);
if (request.HasHeaderValue("connection", "upgrade")) {
- connection->web_socket_.reset(WebSocket::CreateWebSocket(connection,
- request,
- &pos));
-
- if (!connection->web_socket_.get()) // Not enough data was received.
+ scoped_ptr<WebSocket> websocket(
+ WebSocket::CreateWebSocket(this, connection, request, &pos));
+ if (!websocket) // Not enough data was received.
break;
+ connection->SetWebSocket(websocket.Pass());
+ read_buf->DidConsume(pos);
delegate_->OnWebSocketRequest(connection->id(), request);
- connection->Shift(pos);
+ if (HasClosedConnection(connection))
+ return ERR_CONNECTION_CLOSED;
continue;
}
const char kContentLength[] = "content-length";
- if (request.headers.count(kContentLength)) {
+ if (request.headers.count(kContentLength) > 0) {
size_t content_length = 0;
const size_t kMaxBodySize = 100 << 20;
if (!base::StringToSizeT(request.GetHeaderValue(kContentLength),
&content_length) ||
content_length > kMaxBodySize) {
- connection->Send(HttpServerResponseInfo::CreateFor500(
- "request content-length too big or unknown: " +
- request.GetHeaderValue(kContentLength)));
- DidClose(socket);
- break;
+ SendResponse(connection->id(),
+ HttpServerResponseInfo::CreateFor500(
+ "request content-length too big or unknown: " +
+ request.GetHeaderValue(kContentLength)));
+ Close(connection->id());
+ return ERR_CONNECTION_CLOSED;
}
- if (connection->recv_data_.length() - pos < content_length)
+ if (read_buf->GetSize() - pos < content_length)
break; // Not enough data was received yet.
- request.data = connection->recv_data_.substr(pos, content_length);
+ request.data.assign(read_buf->StartOfBuffer() + pos, content_length);
pos += content_length;
}
+ read_buf->DidConsume(pos);
delegate_->OnHttpRequest(connection->id(), request);
- connection->Shift(pos);
+ if (HasClosedConnection(connection))
+ return ERR_CONNECTION_CLOSED;
}
+
+ return OK;
}
-void HttpServer::DidClose(StreamListenSocket* socket) {
- HttpConnection* connection = FindConnection(socket);
- DCHECK(connection != NULL);
- id_to_connection_.erase(connection->id());
- socket_to_connection_.erase(connection->socket_.get());
- delete connection;
+void HttpServer::DoWriteLoop(HttpConnection* connection) {
+ int rv = OK;
+ HttpConnection::QueuedWriteIOBuffer* write_buf = connection->write_buf();
+ while (rv == OK && write_buf->GetSizeToWrite() > 0) {
+ rv = connection->socket()->Write(
+ write_buf,
+ write_buf->GetSizeToWrite(),
+ base::Bind(&HttpServer::OnWriteCompleted,
+ weak_ptr_factory_.GetWeakPtr(), connection->id()));
+ if (rv == ERR_IO_PENDING || rv == OK)
+ return;
+ rv = HandleWriteResult(connection, rv);
+ }
}
-HttpServer::~HttpServer() {
- STLDeleteContainerPairSecondPointers(
- id_to_connection_.begin(), id_to_connection_.end());
+void HttpServer::OnWriteCompleted(int connection_id, int rv) {
+ HttpConnection* connection = FindConnection(connection_id);
+ if (!connection) // It might be closed right before by read error.
+ return;
+
+ if (HandleWriteResult(connection, rv) == OK)
+ DoWriteLoop(connection);
}
+int HttpServer::HandleWriteResult(HttpConnection* connection, int rv) {
+ if (rv < 0) {
+ Close(connection->id());
+ return rv;
+ }
+
+ connection->write_buf()->DidConsume(rv);
+ return OK;
+}
+
+namespace {
+
//
// HTTP Request Parser
// This HTTP request parser uses a simple state machine to quickly parse
@@ -255,17 +367,19 @@ int charToInput(char ch) {
return INPUT_DEFAULT;
}
-bool HttpServer::ParseHeaders(HttpConnection* connection,
+} // namespace
+
+bool HttpServer::ParseHeaders(const char* data,
+ size_t data_len,
HttpServerRequestInfo* info,
size_t* ppos) {
size_t& pos = *ppos;
- size_t data_len = connection->recv_data_.length();
int state = ST_METHOD;
std::string buffer;
std::string header_name;
std::string header_value;
while (pos < data_len) {
- char ch = connection->recv_data_[pos++];
+ char ch = data[pos++];
int input = charToInput(ch);
int next_state = parser_state[state][input];
@@ -337,11 +451,12 @@ HttpConnection* HttpServer::FindConnection(int connection_id) {
return it->second;
}
-HttpConnection* HttpServer::FindConnection(StreamListenSocket* socket) {
- SocketToConnectionMap::iterator it = socket_to_connection_.find(socket);
- if (it == socket_to_connection_.end())
- return NULL;
- return it->second;
+// This is called after any delegate callbacks are called to check if Close()
+// has been called during callback processing. Using the pointer of connection,
+// |connection| is safe here because Close() deletes the connection in next run
+// loop.
+bool HttpServer::HasClosedConnection(HttpConnection* connection) {
+ return FindConnection(connection->id()) != connection;
}
} // namespace net
diff --git a/net/server/http_server.h b/net/server/http_server.h
index 4309d122..2ae698b 100644
--- a/net/server/http_server.h
+++ b/net/server/http_server.h
@@ -5,13 +5,14 @@
#ifndef NET_SERVER_HTTP_SERVER_H_
#define NET_SERVER_HTTP_SERVER_H_
-#include <list>
#include <map>
+#include <string>
#include "base/basictypes.h"
+#include "base/macros.h"
#include "base/memory/scoped_ptr.h"
+#include "base/memory/weak_ptr.h"
#include "net/http/http_status_code.h"
-#include "net/socket/stream_listen_socket.h"
namespace net {
@@ -19,30 +20,28 @@ class HttpConnection;
class HttpServerRequestInfo;
class HttpServerResponseInfo;
class IPEndPoint;
+class ServerSocket;
+class StreamSocket;
class WebSocket;
-class HttpServer : public StreamListenSocket::Delegate,
- public base::RefCountedThreadSafe<HttpServer> {
+class HttpServer {
public:
+ // Delegate to handle http/websocket events. Beware that it is not safe to
+ // destroy the HttpServer in any of these callbacks.
class Delegate {
public:
virtual void OnHttpRequest(int connection_id,
const HttpServerRequestInfo& info) = 0;
-
virtual void OnWebSocketRequest(int connection_id,
const HttpServerRequestInfo& info) = 0;
-
virtual void OnWebSocketMessage(int connection_id,
const std::string& data) = 0;
-
virtual void OnClose(int connection_id) = 0;
-
- protected:
- virtual ~Delegate() {}
};
- HttpServer(const StreamListenSocketFactory& socket_factory,
+ HttpServer(scoped_ptr<ServerSocket> server_socket,
HttpServer::Delegate* delegate);
+ ~HttpServer();
void AcceptWebSocket(int connection_id,
const HttpServerRequestInfo& request);
@@ -51,6 +50,7 @@ class HttpServer : public StreamListenSocket::Delegate,
// performed that data constitutes a valid HTTP response. A valid HTTP
// response may be split across multiple calls to SendRaw.
void SendRaw(int connection_id, const std::string& data);
+ // TODO(byungchul): Consider replacing function name with SendResponseInfo
void SendResponse(int connection_id, const HttpServerResponseInfo& response);
void Send(int connection_id,
HttpStatusCode status_code,
@@ -64,40 +64,50 @@ class HttpServer : public StreamListenSocket::Delegate,
void Close(int connection_id);
+ void SetReceiveBufferSize(int connection_id, int32 size);
+ void SetSendBufferSize(int connection_id, int32 size);
+
// Copies the local address to |address|. Returns a network error code.
int GetLocalAddress(IPEndPoint* address);
- // ListenSocketDelegate
- virtual void DidAccept(StreamListenSocket* server,
- scoped_ptr<StreamListenSocket> socket) OVERRIDE;
- virtual void DidRead(StreamListenSocket* socket,
- const char* data,
- int len) OVERRIDE;
- virtual void DidClose(StreamListenSocket* socket) OVERRIDE;
+ private:
+ friend class HttpServerTest;
+
+ typedef std::map<int, HttpConnection*> IdToConnectionMap;
- protected:
- virtual ~HttpServer();
+ void DoAcceptLoop();
+ void OnAcceptCompleted(int rv);
+ int HandleAcceptResult(int rv);
- private:
- friend class base::RefCountedThreadSafe<HttpServer>;
- friend class HttpConnection;
+ void DoReadLoop(HttpConnection* connection);
+ void OnReadCompleted(int connection_id, int rv);
+ int HandleReadResult(HttpConnection* connection, int rv);
+
+ void DoWriteLoop(HttpConnection* connection);
+ void OnWriteCompleted(int connection_id, int rv);
+ int HandleWriteResult(HttpConnection* connection, int rv);
// Expects the raw data to be stored in recv_data_. If parsing is successful,
// will remove the data parsed from recv_data_, leaving only the unused
// recv data.
- bool ParseHeaders(HttpConnection* connection,
+ bool ParseHeaders(const char* data,
+ size_t data_len,
HttpServerRequestInfo* info,
size_t* pos);
HttpConnection* FindConnection(int connection_id);
- HttpConnection* FindConnection(StreamListenSocket* socket);
- HttpServer::Delegate* delegate_;
- scoped_ptr<StreamListenSocket> server_;
- typedef std::map<int, HttpConnection*> IdToConnectionMap;
+ // Whether or not Close() has been called during delegate callback processing.
+ bool HasClosedConnection(HttpConnection* connection);
+
+ const scoped_ptr<ServerSocket> server_socket_;
+ scoped_ptr<StreamSocket> accepted_socket_;
+ HttpServer::Delegate* const delegate_;
+
+ int last_id_;
IdToConnectionMap id_to_connection_;
- typedef std::map<StreamListenSocket*, HttpConnection*> SocketToConnectionMap;
- SocketToConnectionMap socket_to_connection_;
+
+ base::WeakPtrFactory<HttpServer> weak_ptr_factory_;
DISALLOW_COPY_AND_ASSIGN(HttpServer);
};
diff --git a/net/server/http_server_response_info.cc b/net/server/http_server_response_info.cc
index e4c6043..2d0a32e 100644
--- a/net/server/http_server_response_info.cc
+++ b/net/server/http_server_response_info.cc
@@ -41,8 +41,14 @@ void HttpServerResponseInfo::SetBody(const std::string& body,
const std::string& content_type) {
DCHECK(body_.empty());
body_ = body;
+ SetContentHeaders(body.length(), content_type);
+}
+
+void HttpServerResponseInfo::SetContentHeaders(
+ size_t content_length,
+ const std::string& content_type) {
AddHeader(HttpRequestHeaders::kContentLength,
- base::StringPrintf("%" PRIuS, body.length()));
+ base::StringPrintf("%" PRIuS, content_length));
AddHeader(HttpRequestHeaders::kContentType, content_type);
}
diff --git a/net/server/http_server_response_info.h b/net/server/http_server_response_info.h
index d6cedaa..bbb76d8 100644
--- a/net/server/http_server_response_info.h
+++ b/net/server/http_server_response_info.h
@@ -27,6 +27,9 @@ class HttpServerResponseInfo {
// This also adds an appropriate Content-Length header.
void SetBody(const std::string& body, const std::string& content_type);
+ // Sets content-length and content-type. Body should be sent separately.
+ void SetContentHeaders(size_t content_length,
+ const std::string& content_type);
std::string Serialize() const;
diff --git a/net/server/http_server_unittest.cc b/net/server/http_server_unittest.cc
index a492cf1..5783ca2 100644
--- a/net/server/http_server_unittest.cc
+++ b/net/server/http_server_unittest.cc
@@ -2,11 +2,13 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
+#include <algorithm>
#include <utility>
#include <vector>
#include "base/bind.h"
#include "base/bind_helpers.h"
+#include "base/callback_helpers.h"
#include "base/compiler_specific.h"
#include "base/format_macros.h"
#include "base/memory/ref_counted.h"
@@ -24,11 +26,14 @@
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/net_log.h"
+#include "net/base/net_util.h"
#include "net/base/test_completion_callback.h"
+#include "net/http/http_response_headers.h"
+#include "net/http/http_util.h"
#include "net/server/http_server.h"
#include "net/server/http_server_request_info.h"
#include "net/socket/tcp_client_socket.h"
-#include "net/socket/tcp_listen_socket.h"
+#include "net/socket/tcp_server_socket.h"
#include "net/url_request/url_fetcher.h"
#include "net/url_request/url_fetcher_delegate.h"
#include "net/url_request/url_request_context.h"
@@ -90,10 +95,6 @@ class TestHttpClient {
Write();
}
- bool Read(std::string* message) {
- return Read(message, 1);
- }
-
bool Read(std::string* message, int expected_bytes) {
int total_bytes_received = 0;
message->clear();
@@ -110,6 +111,18 @@ class TestHttpClient {
return true;
}
+ bool ReadResponse(std::string* message) {
+ if (!Read(message, 1))
+ return false;
+ while (!IsCompleteResponse(*message)) {
+ std::string chunk;
+ if (!Read(&chunk, 1))
+ return false;
+ message->append(chunk);
+ }
+ return true;
+ }
+
private:
void OnConnect(const base::Closure& quit_loop, int result) {
connect_result_ = result;
@@ -141,6 +154,21 @@ class TestHttpClient {
callback.Run(result);
}
+ bool IsCompleteResponse(const std::string& response) {
+ // Check end of headers first.
+ int end_of_headers = HttpUtil::LocateEndOfHeaders(response.data(),
+ response.size());
+ if (end_of_headers < 0)
+ return false;
+
+ // Return true if response has data equal to or more than content length.
+ int64 body_size = static_cast<int64>(response.size()) - end_of_headers;
+ DCHECK_LE(0, body_size);
+ scoped_refptr<HttpResponseHeaders> headers(new HttpResponseHeaders(
+ HttpUtil::AssembleRawHeaders(response.data(), end_of_headers)));
+ return body_size >= headers->GetContentLength();
+ }
+
scoped_refptr<IOBufferWithSize> read_buffer_;
scoped_refptr<DrainableIOBuffer> write_buffer_;
scoped_ptr<TCPClientSocket> socket_;
@@ -155,8 +183,10 @@ class HttpServerTest : public testing::Test,
HttpServerTest() : quit_after_request_count_(0) {}
virtual void SetUp() OVERRIDE {
- TCPListenSocketFactory socket_factory("127.0.0.1", 0);
- server_ = new HttpServer(socket_factory, this);
+ scoped_ptr<ServerSocket> server_socket(
+ new TCPServerSocket(NULL, net::NetLog::Source()));
+ server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1);
+ server_.reset(new HttpServer(server_socket.Pass(), this));
ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_));
}
@@ -199,8 +229,13 @@ class HttpServerTest : public testing::Test,
return requests_[request_index].second;
}
+ void HandleAcceptResult(scoped_ptr<StreamSocket> socket) {
+ server_->accepted_socket_.reset(socket.release());
+ server_->HandleAcceptResult(OK);
+ }
+
protected:
- scoped_refptr<HttpServer> server_;
+ scoped_ptr<HttpServer> server_;
IPEndPoint server_address_;
base::Closure run_loop_quit_func_;
std::vector<std::pair<HttpServerRequestInfo, int> > requests_;
@@ -407,7 +442,7 @@ TEST_F(HttpServerTest, Send200) {
server_->Send200(GetConnectionId(0), "Response!", "text/plain");
std::string response;
- ASSERT_TRUE(client.Read(&response));
+ ASSERT_TRUE(client.ReadResponse(&response));
ASSERT_TRUE(StartsWithASCII(response, "HTTP/1.1 200 OK", true));
ASSERT_TRUE(EndsWith(response, "Response!", true));
}
@@ -429,23 +464,105 @@ TEST_F(HttpServerTest, SendRaw) {
namespace {
-class MockStreamListenSocket : public StreamListenSocket {
+class MockStreamSocket : public StreamSocket {
public:
- MockStreamListenSocket(StreamListenSocket::Delegate* delegate)
- : StreamListenSocket(kInvalidSocket, delegate) {}
+ MockStreamSocket()
+ : connected_(true),
+ read_buf_(NULL),
+ read_buf_len_(0) {}
+
+ // StreamSocket
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE {
+ return ERR_NOT_IMPLEMENTED;
+ }
+ virtual void Disconnect() OVERRIDE {
+ connected_ = false;
+ if (!read_callback_.is_null()) {
+ read_buf_ = NULL;
+ read_buf_len_ = 0;
+ base::ResetAndReturn(&read_callback_).Run(ERR_CONNECTION_CLOSED);
+ }
+ }
+ virtual bool IsConnected() const OVERRIDE { return connected_; }
+ virtual bool IsConnectedAndIdle() const OVERRIDE { return IsConnected(); }
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
+ return ERR_NOT_IMPLEMENTED;
+ }
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
+ return ERR_NOT_IMPLEMENTED;
+ }
+ virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; }
+ virtual void SetSubresourceSpeculation() OVERRIDE {}
+ virtual void SetOmniboxSpeculation() OVERRIDE {}
+ virtual bool WasEverUsed() const OVERRIDE { return true; }
+ virtual bool UsingTCPFastOpen() const OVERRIDE { return false; }
+ virtual bool WasNpnNegotiated() const OVERRIDE { return false; }
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
+ return kProtoUnknown;
+ }
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; }
+
+ // Socket
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ if (!connected_) {
+ return ERR_SOCKET_NOT_CONNECTED;
+ }
+ if (pending_read_data_.empty()) {
+ read_buf_ = buf;
+ read_buf_len_ = buf_len;
+ read_callback_ = callback;
+ return ERR_IO_PENDING;
+ }
+ DCHECK_GT(buf_len, 0);
+ int read_len = std::min(static_cast<int>(pending_read_data_.size()),
+ buf_len);
+ memcpy(buf->data(), pending_read_data_.data(), read_len);
+ pending_read_data_.erase(0, read_len);
+ return read_len;
+ }
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ return ERR_NOT_IMPLEMENTED;
+ }
+ virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
+ return ERR_NOT_IMPLEMENTED;
+ }
+ virtual int SetSendBufferSize(int32 size) OVERRIDE {
+ return ERR_NOT_IMPLEMENTED;
+ }
- virtual void Accept() OVERRIDE { NOTREACHED(); }
+ void DidRead(const char* data, int data_len) {
+ if (!read_buf_) {
+ pending_read_data_.append(data, data_len);
+ return;
+ }
+ int read_len = std::min(data_len, read_buf_len_);
+ memcpy(read_buf_->data(), data, read_len);
+ pending_read_data_.assign(data + read_len, data_len - read_len);
+ read_buf_ = NULL;
+ read_buf_len_ = 0;
+ base::ResetAndReturn(&read_callback_).Run(read_len);
+ }
private:
- virtual ~MockStreamListenSocket() {}
+ virtual ~MockStreamSocket() {}
+
+ bool connected_;
+ scoped_refptr<IOBuffer> read_buf_;
+ int read_buf_len_;
+ CompletionCallback read_callback_;
+ std::string pending_read_data_;
+ BoundNetLog net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockStreamSocket);
};
} // namespace
TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) {
- StreamListenSocket* socket =
- new MockStreamListenSocket(server_.get());
- server_->DidAccept(NULL, make_scoped_ptr(socket));
+ MockStreamSocket* socket = new MockStreamSocket();
+ HandleAcceptResult(make_scoped_ptr<StreamSocket>(socket));
std::string body("body");
std::string request_text = base::StringPrintf(
"GET /test HTTP/1.1\r\n"
@@ -453,9 +570,9 @@ TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) {
"Content-Length: %" PRIuS "\r\n\r\n%s",
body.length(),
body.c_str());
- server_->DidRead(socket, request_text.c_str(), request_text.length() - 2);
+ socket->DidRead(request_text.c_str(), request_text.length() - 2);
ASSERT_EQ(0u, requests_.size());
- server_->DidRead(socket, request_text.c_str() + request_text.length() - 2, 2);
+ socket->DidRead(request_text.c_str() + request_text.length() - 2, 2);
ASSERT_EQ(1u, requests_.size());
ASSERT_EQ(body, GetRequest(0).data);
}
@@ -477,7 +594,7 @@ TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
int client_connection_id = GetConnectionId(0);
server_->Send200(client_connection_id, "Content for /test", "text/plain");
std::string response1;
- ASSERT_TRUE(client.Read(&response1));
+ ASSERT_TRUE(client.ReadResponse(&response1));
ASSERT_TRUE(StartsWithASCII(response1, "HTTP/1.1 200 OK", true));
ASSERT_TRUE(EndsWith(response1, "Content for /test", true));
@@ -488,7 +605,7 @@ TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
ASSERT_EQ(client_connection_id, GetConnectionId(1));
server_->Send404(client_connection_id);
std::string response2;
- ASSERT_TRUE(client.Read(&response2));
+ ASSERT_TRUE(client.ReadResponse(&response2));
ASSERT_TRUE(StartsWithASCII(response2, "HTTP/1.1 404 Not Found", true));
client.Send("GET /test3 HTTP/1.1\r\n\r\n");
@@ -498,12 +615,9 @@ TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
ASSERT_EQ(client_connection_id, GetConnectionId(2));
server_->Send200(client_connection_id, "Content for /test3", "text/plain");
std::string response3;
- ASSERT_TRUE(client.Read(&response3));
+ ASSERT_TRUE(client.ReadResponse(&response3));
ASSERT_TRUE(StartsWithASCII(response3, "HTTP/1.1 200 OK", true));
-#if 0
- // TODO(byungchul): Figure out why it fails in windows build bot.
ASSERT_TRUE(EndsWith(response3, "Content for /test3", true));
-#endif
}
} // namespace net
diff --git a/net/server/web_socket.cc b/net/server/web_socket.cc
index f06b425..ec0fdac 100644
--- a/net/server/web_socket.cc
+++ b/net/server/web_socket.cc
@@ -15,6 +15,7 @@
#include "base/strings/stringprintf.h"
#include "base/sys_byteorder.h"
#include "net/server/http_connection.h"
+#include "net/server/http_server.h"
#include "net/server/http_server_request_info.h"
#include "net/server/http_server_response_info.h"
@@ -43,12 +44,14 @@ static uint32 WebSocketKeyFingerprint(const std::string& str) {
class WebSocketHixie76 : public net::WebSocket {
public:
- static net::WebSocket* Create(HttpConnection* connection,
+ static net::WebSocket* Create(HttpServer* server,
+ HttpConnection* connection,
const HttpServerRequestInfo& request,
size_t* pos) {
- if (connection->recv_data().length() < *pos + kWebSocketHandshakeBodyLen)
+ if (connection->read_buf()->GetSize() <
+ static_cast<int>(*pos + kWebSocketHandshakeBodyLen))
return NULL;
- return new WebSocketHixie76(connection, request, pos);
+ return new WebSocketHixie76(server, connection, request, pos);
}
virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE {
@@ -69,31 +72,33 @@ class WebSocketHixie76 : public net::WebSocket {
std::string origin = request.GetHeaderValue("origin");
std::string host = request.GetHeaderValue("host");
std::string location = "ws://" + host + request.path;
- connection_->Send(base::StringPrintf(
- "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
- "Upgrade: WebSocket\r\n"
- "Connection: Upgrade\r\n"
- "Sec-WebSocket-Origin: %s\r\n"
- "Sec-WebSocket-Location: %s\r\n"
- "\r\n",
- origin.c_str(),
- location.c_str()));
- connection_->Send(reinterpret_cast<char*>(digest.a), 16);
+ server_->SendRaw(
+ connection_->id(),
+ base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Origin: %s\r\n"
+ "Sec-WebSocket-Location: %s\r\n"
+ "\r\n",
+ origin.c_str(),
+ location.c_str()));
+ server_->SendRaw(connection_->id(),
+ std::string(reinterpret_cast<char*>(digest.a), 16));
}
virtual ParseResult Read(std::string* message) OVERRIDE {
DCHECK(message);
- const std::string& data = connection_->recv_data();
- if (data[0])
+ HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf();
+ if (read_buf->StartOfBuffer()[0])
return FRAME_ERROR;
+ base::StringPiece data(read_buf->StartOfBuffer(), read_buf->GetSize());
size_t pos = data.find('\377', 1);
- if (pos == std::string::npos)
+ if (pos == base::StringPiece::npos)
return FRAME_INCOMPLETE;
- std::string buffer(data.begin() + 1, data.begin() + pos);
- message->swap(buffer);
- connection_->Shift(pos + 1);
+ message->assign(data.data() + 1, pos - 1);
+ read_buf->DidConsume(pos + 1);
return FRAME_OK;
}
@@ -101,37 +106,42 @@ class WebSocketHixie76 : public net::WebSocket {
virtual void Send(const std::string& message) OVERRIDE {
char message_start = 0;
char message_end = -1;
- connection_->Send(&message_start, 1);
- connection_->Send(message);
- connection_->Send(&message_end, 1);
+ server_->SendRaw(connection_->id(), std::string(1, message_start));
+ server_->SendRaw(connection_->id(), message);
+ server_->SendRaw(connection_->id(), std::string(1, message_end));
}
private:
static const int kWebSocketHandshakeBodyLen;
- WebSocketHixie76(HttpConnection* connection,
+ WebSocketHixie76(HttpServer* server,
+ HttpConnection* connection,
const HttpServerRequestInfo& request,
- size_t* pos) : WebSocket(connection) {
+ size_t* pos)
+ : WebSocket(server, connection) {
std::string key1 = request.GetHeaderValue("sec-websocket-key1");
std::string key2 = request.GetHeaderValue("sec-websocket-key2");
if (key1.empty()) {
- connection->Send(HttpServerResponseInfo::CreateFor500(
- "Invalid request format. Sec-WebSocket-Key1 is empty or isn't "
- "specified."));
+ server->SendResponse(
+ connection->id(),
+ HttpServerResponseInfo::CreateFor500(
+ "Invalid request format. Sec-WebSocket-Key1 is empty or isn't "
+ "specified."));
return;
}
if (key2.empty()) {
- connection->Send(HttpServerResponseInfo::CreateFor500(
- "Invalid request format. Sec-WebSocket-Key2 is empty or isn't "
- "specified."));
+ server->SendResponse(
+ connection->id(),
+ HttpServerResponseInfo::CreateFor500(
+ "Invalid request format. Sec-WebSocket-Key2 is empty or isn't "
+ "specified."));
return;
}
- key3_ = connection->recv_data().substr(
- *pos,
- *pos + kWebSocketHandshakeBodyLen);
+ key3_.assign(connection->read_buf()->StartOfBuffer() + *pos,
+ kWebSocketHandshakeBodyLen);
*pos += kWebSocketHandshakeBodyLen;
}
@@ -169,7 +179,8 @@ const size_t kMaskingKeyWidthInBytes = 4;
class WebSocketHybi17 : public WebSocket {
public:
- static WebSocket* Create(HttpConnection* connection,
+ static WebSocket* Create(HttpServer* server,
+ HttpConnection* connection,
const HttpServerRequestInfo& request,
size_t* pos) {
std::string version = request.GetHeaderValue("sec-websocket-version");
@@ -178,12 +189,14 @@ class WebSocketHybi17 : public WebSocket {
std::string key = request.GetHeaderValue("sec-websocket-key");
if (key.empty()) {
- connection->Send(HttpServerResponseInfo::CreateFor500(
- "Invalid request format. Sec-WebSocket-Key is empty or isn't "
- "specified."));
+ server->SendResponse(
+ connection->id(),
+ HttpServerResponseInfo::CreateFor500(
+ "Invalid request format. Sec-WebSocket-Key is empty or isn't "
+ "specified."));
return NULL;
}
- return new WebSocketHybi17(connection, request, pos);
+ return new WebSocketHybi17(server, connection, request, pos);
}
virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE {
@@ -194,24 +207,24 @@ class WebSocketHybi17 : public WebSocket {
std::string encoded_hash;
base::Base64Encode(base::SHA1HashString(data), &encoded_hash);
- std::string response = base::StringPrintf(
- "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
- "Upgrade: WebSocket\r\n"
- "Connection: Upgrade\r\n"
- "Sec-WebSocket-Accept: %s\r\n"
- "\r\n",
- encoded_hash.c_str());
- connection_->Send(response);
+ server_->SendRaw(
+ connection_->id(),
+ base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Accept: %s\r\n"
+ "\r\n",
+ encoded_hash.c_str()));
}
virtual ParseResult Read(std::string* message) OVERRIDE {
- const std::string& frame = connection_->recv_data();
+ HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf();
+ base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize());
int bytes_consumed = 0;
-
ParseResult result =
WebSocket::DecodeFrameHybi17(frame, true, &bytes_consumed, message);
if (result == FRAME_OK)
- connection_->Shift(bytes_consumed);
+ read_buf->DidConsume(bytes_consumed);
if (result == FRAME_CLOSE)
closed_ = true;
return result;
@@ -220,25 +233,26 @@ class WebSocketHybi17 : public WebSocket {
virtual void Send(const std::string& message) OVERRIDE {
if (closed_)
return;
- std::string data = WebSocket::EncodeFrameHybi17(message, 0);
- connection_->Send(data);
+ server_->SendRaw(connection_->id(),
+ WebSocket::EncodeFrameHybi17(message, 0));
}
private:
- WebSocketHybi17(HttpConnection* connection,
+ WebSocketHybi17(HttpServer* server,
+ HttpConnection* connection,
const HttpServerRequestInfo& request,
size_t* pos)
- : WebSocket(connection),
- op_code_(0),
- final_(false),
- reserved1_(false),
- reserved2_(false),
- reserved3_(false),
- masked_(false),
- payload_(0),
- payload_length_(0),
- frame_end_(0),
- closed_(false) {
+ : WebSocket(server, connection),
+ op_code_(0),
+ final_(false),
+ reserved1_(false),
+ reserved2_(false),
+ reserved3_(false),
+ masked_(false),
+ payload_(0),
+ payload_length_(0),
+ frame_end_(0),
+ closed_(false) {
}
OpCode op_code_;
@@ -257,21 +271,23 @@ class WebSocketHybi17 : public WebSocket {
} // anonymous namespace
-WebSocket* WebSocket::CreateWebSocket(HttpConnection* connection,
+WebSocket* WebSocket::CreateWebSocket(HttpServer* server,
+ HttpConnection* connection,
const HttpServerRequestInfo& request,
size_t* pos) {
- WebSocket* socket = WebSocketHybi17::Create(connection, request, pos);
+ WebSocket* socket = WebSocketHybi17::Create(server, connection, request, pos);
if (socket)
return socket;
- return WebSocketHixie76::Create(connection, request, pos);
+ return WebSocketHixie76::Create(server, connection, request, pos);
}
// static
-WebSocket::ParseResult WebSocket::DecodeFrameHybi17(const std::string& frame,
- bool client_frame,
- int* bytes_consumed,
- std::string* output) {
+WebSocket::ParseResult WebSocket::DecodeFrameHybi17(
+ const base::StringPiece& frame,
+ bool client_frame,
+ int* bytes_consumed,
+ std::string* output) {
size_t data_length = frame.length();
if (data_length < 2)
return FRAME_INCOMPLETE;
@@ -349,8 +365,7 @@ WebSocket::ParseResult WebSocket::DecodeFrameHybi17(const std::string& frame,
for (size_t i = 0; i < payload_length; ++i) // Unmask the payload.
(*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes];
} else {
- std::string buffer(p, p + payload_length);
- output->swap(buffer);
+ output->assign(p, p + payload_length);
}
size_t pos = p + actual_masking_key_length + payload_length - buffer_begin;
@@ -400,7 +415,9 @@ std::string WebSocket::EncodeFrameHybi17(const std::string& message,
return std::string(&frame[0], frame.size());
}
-WebSocket::WebSocket(HttpConnection* connection) : connection_(connection) {
+WebSocket::WebSocket(HttpServer* server, HttpConnection* connection)
+ : server_(server),
+ connection_(connection) {
}
} // namespace net
diff --git a/net/server/web_socket.h b/net/server/web_socket.h
index 49ced84..9b3a794 100644
--- a/net/server/web_socket.h
+++ b/net/server/web_socket.h
@@ -8,10 +8,12 @@
#include <string>
#include "base/basictypes.h"
+#include "base/strings/string_piece.h"
namespace net {
class HttpConnection;
+class HttpServer;
class HttpServerRequestInfo;
class WebSocket {
@@ -23,11 +25,12 @@ class WebSocket {
FRAME_ERROR
};
- static WebSocket* CreateWebSocket(HttpConnection* connection,
+ static WebSocket* CreateWebSocket(HttpServer* server,
+ HttpConnection* connection,
const HttpServerRequestInfo& request,
size_t* pos);
- static ParseResult DecodeFrameHybi17(const std::string& frame,
+ static ParseResult DecodeFrameHybi17(const base::StringPiece& frame,
bool client_frame,
int* bytes_consumed,
std::string* output);
@@ -41,8 +44,10 @@ class WebSocket {
virtual ~WebSocket() {}
protected:
- explicit WebSocket(HttpConnection* connection);
- HttpConnection* connection_;
+ WebSocket(HttpServer* server, HttpConnection* connection);
+
+ HttpServer* const server_;
+ HttpConnection* const connection_;
private:
DISALLOW_COPY_AND_ASSIGN(WebSocket);
diff --git a/net/socket/server_socket.h b/net/socket/server_socket.h
index 528955b..4b9ca8e 100644
--- a/net/socket/server_socket.h
+++ b/net/socket/server_socket.h
@@ -21,7 +21,7 @@ class NET_EXPORT ServerSocket {
ServerSocket();
virtual ~ServerSocket();
- // Binds the socket and starts listening. Destroy the socket to stop
+ // Binds the socket and starts listening. Destroys the socket to stop
// listening.
virtual int Listen(const IPEndPoint& address, int backlog) = 0;