diff options
author | byungchul@chromium.org <byungchul@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2014-08-22 18:10:13 +0000 |
---|---|---|
committer | byungchul@chromium.org <byungchul@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2014-08-22 18:12:05 +0000 |
commit | 3626bc39d629245a782c61ba0bc668583778e392 (patch) | |
tree | 7156999fdfb9044dd14d5463568b8f6ce01543fe /net | |
parent | 1c2f071e2a14a39b215a897434d30834ae959560 (diff) | |
download | chromium_src-3626bc39d629245a782c61ba0bc668583778e392.zip chromium_src-3626bc39d629245a782c61ba0bc668583778e392.tar.gz chromium_src-3626bc39d629245a782c61ba0bc668583778e392.tar.bz2 |
Replace StreamListenSocket with StreamSocket in HttpServer.
1) HttpServer gets ServerSocket instead of StreamListenSocket.
2) HttpConnection is just a container for socket, websocket, and pending read/write buffers.
3) HttpServer handles data buffering and asynchronous read/write.
4) HttpConnection has limit in data buffering, up to 1Mbytes by default.
5) For devtools, send buffer limit is 100Mbytes.
6) Unittests for buffer handling in HttpConnection.
BUG=371906
Review URL: https://codereview.chromium.org/296053012
Cr-Commit-Position: refs/heads/master@{#291447}
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@291447 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
-rw-r--r-- | net/net.gypi | 1 | ||||
-rw-r--r-- | net/server/http_connection.cc | 161 | ||||
-rw-r--r-- | net/server/http_connection.h | 123 | ||||
-rw-r--r-- | net/server/http_connection_unittest.cc | 331 | ||||
-rw-r--r-- | net/server/http_server.cc | 257 | ||||
-rw-r--r-- | net/server/http_server.h | 70 | ||||
-rw-r--r-- | net/server/http_server_response_info.cc | 8 | ||||
-rw-r--r-- | net/server/http_server_response_info.h | 3 | ||||
-rw-r--r-- | net/server/http_server_unittest.cc | 120 | ||||
-rw-r--r-- | net/server/web_socket.cc | 163 | ||||
-rw-r--r-- | net/server/web_socket.h | 13 | ||||
-rw-r--r-- | net/socket/server_socket.h | 2 |
12 files changed, 1019 insertions, 233 deletions
diff --git a/net/net.gypi b/net/net.gypi index f240eba..b57d853 100644 --- a/net/net.gypi +++ b/net/net.gypi @@ -1555,6 +1555,7 @@ 'quic/quic_utils_test.cc', 'quic/quic_write_blocked_list_test.cc', 'quic/reliable_quic_stream_test.cc', + 'server/http_connection_unittest.cc', 'server/http_server_response_info_unittest.cc', 'server/http_server_unittest.cc', 'socket/client_socket_pool_base_unittest.cc', diff --git a/net/server/http_connection.cc b/net/server/http_connection.cc index d433012..3401f81 100644 --- a/net/server/http_connection.cc +++ b/net/server/http_connection.cc @@ -4,44 +4,163 @@ #include "net/server/http_connection.h" -#include "net/server/http_server.h" -#include "net/server/http_server_response_info.h" +#include "base/logging.h" #include "net/server/web_socket.h" -#include "net/socket/stream_listen_socket.h" +#include "net/socket/stream_socket.h" namespace net { -int HttpConnection::last_id_ = 0; +HttpConnection::ReadIOBuffer::ReadIOBuffer() + : base_(new GrowableIOBuffer()), + max_buffer_size_(kDefaultMaxBufferSize) { + SetCapacity(kInitialBufSize); +} -void HttpConnection::Send(const std::string& data) { - if (!socket_.get()) - return; - socket_->Send(data); +HttpConnection::ReadIOBuffer::~ReadIOBuffer() { + data_ = NULL; // base_ owns data_. +} + +int HttpConnection::ReadIOBuffer::GetCapacity() const { + return base_->capacity(); +} + +void HttpConnection::ReadIOBuffer::SetCapacity(int capacity) { + DCHECK_LE(GetSize(), capacity); + base_->SetCapacity(capacity); + data_ = base_->data(); +} + +bool HttpConnection::ReadIOBuffer::IncreaseCapacity() { + if (GetCapacity() >= max_buffer_size_) { + LOG(ERROR) << "Too large read data is pending: capacity=" << GetCapacity() + << ", max_buffer_size=" << max_buffer_size_ + << ", read=" << GetSize(); + return false; + } + + int new_capacity = GetCapacity() * kCapacityIncreaseFactor; + if (new_capacity > max_buffer_size_) + new_capacity = max_buffer_size_; + SetCapacity(new_capacity); + return true; +} + +char* HttpConnection::ReadIOBuffer::StartOfBuffer() const { + return base_->StartOfBuffer(); +} + +int HttpConnection::ReadIOBuffer::GetSize() const { + return base_->offset(); +} + +void HttpConnection::ReadIOBuffer::DidRead(int bytes) { + DCHECK_GE(RemainingCapacity(), bytes); + base_->set_offset(base_->offset() + bytes); + data_ = base_->data(); +} + +int HttpConnection::ReadIOBuffer::RemainingCapacity() const { + return base_->RemainingCapacity(); +} + +void HttpConnection::ReadIOBuffer::DidConsume(int bytes) { + int previous_size = GetSize(); + int unconsumed_size = previous_size - bytes; + DCHECK_LE(0, unconsumed_size); + if (unconsumed_size > 0) { + // Move unconsumed data to the start of buffer. + memmove(StartOfBuffer(), StartOfBuffer() + bytes, unconsumed_size); + } + base_->set_offset(unconsumed_size); + data_ = base_->data(); + + // If capacity is too big, reduce it. + if (GetCapacity() > kMinimumBufSize && + GetCapacity() > previous_size * kCapacityIncreaseFactor) { + int new_capacity = GetCapacity() / kCapacityIncreaseFactor; + if (new_capacity < kMinimumBufSize) + new_capacity = kMinimumBufSize; + // realloc() within GrowableIOBuffer::SetCapacity() could move data even + // when size is reduced. If unconsumed_size == 0, i.e. no data exists in + // the buffer, free internal buffer first to guarantee no data move. + if (!unconsumed_size) + base_->SetCapacity(0); + SetCapacity(new_capacity); + } +} + +HttpConnection::QueuedWriteIOBuffer::QueuedWriteIOBuffer() + : total_size_(0), + max_buffer_size_(kDefaultMaxBufferSize) { +} + +HttpConnection::QueuedWriteIOBuffer::~QueuedWriteIOBuffer() { + data_ = NULL; // pending_data_ owns data_. } -void HttpConnection::Send(const char* bytes, int len) { - if (!socket_.get()) +bool HttpConnection::QueuedWriteIOBuffer::IsEmpty() const { + return pending_data_.empty(); +} + +bool HttpConnection::QueuedWriteIOBuffer::Append(const std::string& data) { + if (data.empty()) + return true; + + if (total_size_ + static_cast<int>(data.size()) > max_buffer_size_) { + LOG(ERROR) << "Too large write data is pending: size=" + << total_size_ + data.size() + << ", max_buffer_size=" << max_buffer_size_; + return false; + } + + pending_data_.push(data); + total_size_ += data.size(); + + // If new data is the first pending data, updates data_. + if (pending_data_.size() == 1) + data_ = const_cast<char*>(pending_data_.front().data()); + return true; +} + +void HttpConnection::QueuedWriteIOBuffer::DidConsume(int size) { + DCHECK_GE(total_size_, size); + DCHECK_GE(GetSizeToWrite(), size); + if (size == 0) return; - socket_->Send(bytes, len); + + if (size < GetSizeToWrite()) { + data_ += size; + } else { // size == GetSizeToWrite(). Updates data_ to next pending data. + pending_data_.pop(); + data_ = IsEmpty() ? NULL : const_cast<char*>(pending_data_.front().data()); + } + total_size_ -= size; } -void HttpConnection::Send(const HttpServerResponseInfo& response) { - Send(response.Serialize()); +int HttpConnection::QueuedWriteIOBuffer::GetSizeToWrite() const { + if (IsEmpty()) { + DCHECK_EQ(0, total_size_); + return 0; + } + DCHECK_GE(data_, pending_data_.front().data()); + int consumed = static_cast<int>(data_ - pending_data_.front().data()); + DCHECK_GT(static_cast<int>(pending_data_.front().size()), consumed); + return pending_data_.front().size() - consumed; } -HttpConnection::HttpConnection(HttpServer* server, - scoped_ptr<StreamListenSocket> sock) - : server_(server), - socket_(sock.Pass()) { - id_ = last_id_++; +HttpConnection::HttpConnection(int id, scoped_ptr<StreamSocket> socket) + : id_(id), + socket_(socket.Pass()), + read_buf_(new ReadIOBuffer()), + write_buf_(new QueuedWriteIOBuffer()) { } HttpConnection::~HttpConnection() { - server_->delegate_->OnClose(id_); } -void HttpConnection::Shift(int num_bytes) { - recv_data_ = recv_data_.substr(num_bytes); +void HttpConnection::SetWebSocket(scoped_ptr<WebSocket> web_socket) { + DCHECK(!web_socket_); + web_socket_ = web_socket.Pass(); } } // namespace net diff --git a/net/server/http_connection.h b/net/server/http_connection.h index 17faa46..c7225e1 100644 --- a/net/server/http_connection.h +++ b/net/server/http_connection.h @@ -5,43 +5,130 @@ #ifndef NET_SERVER_HTTP_CONNECTION_H_ #define NET_SERVER_HTTP_CONNECTION_H_ +#include <queue> #include <string> #include "base/basictypes.h" +#include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" -#include "net/http/http_status_code.h" +#include "net/base/io_buffer.h" namespace net { -class HttpServer; -class HttpServerResponseInfo; -class StreamListenSocket; +class StreamSocket; class WebSocket; +// A container which has all information of an http connection. It includes +// id, underlying socket, and pending read/write data. class HttpConnection { public: - ~HttpConnection(); + // IOBuffer for data read. It's a wrapper around GrowableIOBuffer, with more + // functions for buffer management. It moves unconsumed data to the start of + // buffer. + class ReadIOBuffer : public IOBuffer { + public: + static const int kInitialBufSize = 1024; + static const int kMinimumBufSize = 128; + static const int kCapacityIncreaseFactor = 2; + static const int kDefaultMaxBufferSize = 1 * 1024 * 1024; // 1 Mbytes. + + ReadIOBuffer(); + + // Capacity. + int GetCapacity() const; + void SetCapacity(int capacity); + // Increases capacity and returns true if capacity is not beyond the limit. + bool IncreaseCapacity(); + + // Start of read data. + char* StartOfBuffer() const; + // Returns the bytes of read data. + int GetSize() const; + // More read data was appended. + void DidRead(int bytes); + // Capacity for which more read data can be appended. + int RemainingCapacity() const; + + // Removes consumed data and moves unconsumed data to the start of buffer. + void DidConsume(int bytes); + + // Limit of how much internal capacity can increase. + int max_buffer_size() const { return max_buffer_size_; } + void set_max_buffer_size(int max_buffer_size) { + max_buffer_size_ = max_buffer_size; + } + + private: + virtual ~ReadIOBuffer(); + + scoped_refptr<GrowableIOBuffer> base_; + int max_buffer_size_; + + DISALLOW_COPY_AND_ASSIGN(ReadIOBuffer); + }; + + // IOBuffer of pending data to write which has a queue of pending data. Each + // pending data is stored in std::string. data() is the data of first + // std::string stored. + class QueuedWriteIOBuffer : public IOBuffer { + public: + static const int kDefaultMaxBufferSize = 1 * 1024 * 1024; // 1 Mbytes. + + QueuedWriteIOBuffer(); + + // Whether or not pending data exists. + bool IsEmpty() const; - void Send(const std::string& data); - void Send(const char* bytes, int len); - void Send(const HttpServerResponseInfo& response); + // Appends new pending data and returns true if total size doesn't exceed + // the limit, |total_size_limit_|. It would change data() if new data is + // the first pending data. + bool Append(const std::string& data); - void Shift(int num_bytes); + // Consumes data and changes data() accordingly. It cannot be more than + // GetSizeToWrite(). + void DidConsume(int size); + + // Gets size of data to write this time. It is NOT total data size. + int GetSizeToWrite() const; + + // Total size of all pending data. + int total_size() const { return total_size_; } + + // Limit of how much data can be pending. + int max_buffer_size() const { return max_buffer_size_; } + void set_max_buffer_size(int max_buffer_size) { + max_buffer_size_ = max_buffer_size; + } + + private: + virtual ~QueuedWriteIOBuffer(); + + std::queue<std::string> pending_data_; + int total_size_; + int max_buffer_size_; + + DISALLOW_COPY_AND_ASSIGN(QueuedWriteIOBuffer); + }; + + HttpConnection(int id, scoped_ptr<StreamSocket> socket); + ~HttpConnection(); - const std::string& recv_data() const { return recv_data_; } int id() const { return id_; } + StreamSocket* socket() const { return socket_.get(); } + ReadIOBuffer* read_buf() const { return read_buf_.get(); } + QueuedWriteIOBuffer* write_buf() const { return write_buf_.get(); } - private: - friend class HttpServer; - static int last_id_; + WebSocket* web_socket() const { return web_socket_.get(); } + void SetWebSocket(scoped_ptr<WebSocket> web_socket); - HttpConnection(HttpServer* server, scoped_ptr<StreamListenSocket> sock); + private: + const int id_; + const scoped_ptr<StreamSocket> socket_; + const scoped_refptr<ReadIOBuffer> read_buf_; + const scoped_refptr<QueuedWriteIOBuffer> write_buf_; - HttpServer* server_; - scoped_ptr<StreamListenSocket> socket_; scoped_ptr<WebSocket> web_socket_; - std::string recv_data_; - int id_; + DISALLOW_COPY_AND_ASSIGN(HttpConnection); }; diff --git a/net/server/http_connection_unittest.cc b/net/server/http_connection_unittest.cc new file mode 100644 index 0000000..488fd6f --- /dev/null +++ b/net/server/http_connection_unittest.cc @@ -0,0 +1,331 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/server/http_connection.h" + +#include <string> + +#include "base/memory/ref_counted.h" +#include "base/strings/string_piece.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { +namespace { + +std::string GetTestString(int size) { + std::string test_string; + for (int i = 0; i < size; ++i) { + test_string.push_back('A' + (i % 26)); + } + return test_string; +} + +TEST(HttpConnectionTest, ReadIOBuffer_SetCapacity) { + scoped_refptr<HttpConnection::ReadIOBuffer> buffer( + new HttpConnection::ReadIOBuffer); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0, + buffer->GetCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0, + buffer->RemainingCapacity()); + EXPECT_EQ(0, buffer->GetSize()); + + const int kNewCapacity = HttpConnection::ReadIOBuffer::kInitialBufSize + 128; + buffer->SetCapacity(kNewCapacity); + EXPECT_EQ(kNewCapacity, buffer->GetCapacity()); + EXPECT_EQ(kNewCapacity, buffer->RemainingCapacity()); + EXPECT_EQ(0, buffer->GetSize()); +} + +TEST(HttpConnectionTest, ReadIOBuffer_SetCapacity_WithData) { + scoped_refptr<HttpConnection::ReadIOBuffer> buffer( + new HttpConnection::ReadIOBuffer); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0, + buffer->GetCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0, + buffer->RemainingCapacity()); + + // Write arbitrary data up to kInitialBufSize. + const std::string kReadData( + GetTestString(HttpConnection::ReadIOBuffer::kInitialBufSize)); + memcpy(buffer->data(), kReadData.data(), kReadData.size()); + buffer->DidRead(kReadData.size()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0, + buffer->GetCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize - + static_cast<int>(kReadData.size()), + buffer->RemainingCapacity()); + EXPECT_EQ(static_cast<int>(kReadData.size()), buffer->GetSize()); + EXPECT_EQ(kReadData, + base::StringPiece(buffer->StartOfBuffer(), buffer->GetSize())); + + // Check if read data in the buffer is same after SetCapacity(). + const int kNewCapacity = HttpConnection::ReadIOBuffer::kInitialBufSize + 128; + buffer->SetCapacity(kNewCapacity); + EXPECT_EQ(kNewCapacity, buffer->GetCapacity()); + EXPECT_EQ(kNewCapacity - static_cast<int>(kReadData.size()), + buffer->RemainingCapacity()); + EXPECT_EQ(static_cast<int>(kReadData.size()), buffer->GetSize()); + EXPECT_EQ(kReadData, + base::StringPiece(buffer->StartOfBuffer(), buffer->GetSize())); +} + +TEST(HttpConnectionTest, ReadIOBuffer_IncreaseCapacity) { + scoped_refptr<HttpConnection::ReadIOBuffer> buffer( + new HttpConnection::ReadIOBuffer); + EXPECT_TRUE(buffer->IncreaseCapacity()); + const int kExpectedInitialBufSize = + HttpConnection::ReadIOBuffer::kInitialBufSize * + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor; + EXPECT_EQ(kExpectedInitialBufSize, buffer->GetCapacity()); + EXPECT_EQ(kExpectedInitialBufSize, buffer->RemainingCapacity()); + EXPECT_EQ(0, buffer->GetSize()); + + // Increase capacity until it fails. + while (buffer->IncreaseCapacity()); + EXPECT_FALSE(buffer->IncreaseCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize + 0, + buffer->max_buffer_size()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize + 0, + buffer->GetCapacity()); + + // Enlarge capacity limit. + buffer->set_max_buffer_size(buffer->max_buffer_size() * 2); + EXPECT_TRUE(buffer->IncreaseCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize * + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor, + buffer->GetCapacity()); + + // Shrink capacity limit. It doesn't change capacity itself. + buffer->set_max_buffer_size( + HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize / 2); + EXPECT_FALSE(buffer->IncreaseCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize * + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor, + buffer->GetCapacity()); +} + +TEST(HttpConnectionTest, ReadIOBuffer_IncreaseCapacity_WithData) { + scoped_refptr<HttpConnection::ReadIOBuffer> buffer( + new HttpConnection::ReadIOBuffer); + EXPECT_TRUE(buffer->IncreaseCapacity()); + const int kExpectedInitialBufSize = + HttpConnection::ReadIOBuffer::kInitialBufSize * + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor; + EXPECT_EQ(kExpectedInitialBufSize, buffer->GetCapacity()); + EXPECT_EQ(kExpectedInitialBufSize, buffer->RemainingCapacity()); + EXPECT_EQ(0, buffer->GetSize()); + + // Write arbitrary data up to kExpectedInitialBufSize. + std::string kReadData(GetTestString(kExpectedInitialBufSize)); + memcpy(buffer->data(), kReadData.data(), kReadData.size()); + buffer->DidRead(kReadData.size()); + EXPECT_EQ(kExpectedInitialBufSize, buffer->GetCapacity()); + EXPECT_EQ(kExpectedInitialBufSize - static_cast<int>(kReadData.size()), + buffer->RemainingCapacity()); + EXPECT_EQ(static_cast<int>(kReadData.size()), buffer->GetSize()); + EXPECT_EQ(kReadData, + base::StringPiece(buffer->StartOfBuffer(), buffer->GetSize())); + + // Increase capacity until it fails and check if read data in the buffer is + // same. + while (buffer->IncreaseCapacity()); + EXPECT_FALSE(buffer->IncreaseCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize + 0, + buffer->max_buffer_size()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize + 0, + buffer->GetCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize - + static_cast<int>(kReadData.size()), + buffer->RemainingCapacity()); + EXPECT_EQ(static_cast<int>(kReadData.size()), buffer->GetSize()); + EXPECT_EQ(kReadData, + base::StringPiece(buffer->StartOfBuffer(), buffer->GetSize())); +} + +TEST(HttpConnectionTest, ReadIOBuffer_DidRead_DidConsume) { + scoped_refptr<HttpConnection::ReadIOBuffer> buffer( + new HttpConnection::ReadIOBuffer); + const char* start_of_buffer = buffer->StartOfBuffer(); + EXPECT_EQ(start_of_buffer, buffer->data()); + + // Read data. + const int kReadLength = 128; + const std::string kReadData(GetTestString(kReadLength)); + memcpy(buffer->data(), kReadData.data(), kReadLength); + buffer->DidRead(kReadLength); + // No change in total capacity. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0, + buffer->GetCapacity()); + // Change in unused capacity because of read data. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize - kReadLength, + buffer->RemainingCapacity()); + EXPECT_EQ(kReadLength, buffer->GetSize()); + // No change in start pointers of read data. + EXPECT_EQ(start_of_buffer, buffer->StartOfBuffer()); + // Change in start pointer of unused buffer. + EXPECT_EQ(start_of_buffer + kReadLength, buffer->data()); + // Test read data. + EXPECT_EQ(kReadData, std::string(buffer->StartOfBuffer(), buffer->GetSize())); + + // Consume data partially. + const int kConsumedLength = 32; + ASSERT_LT(kConsumedLength, kReadLength); + buffer->DidConsume(kConsumedLength); + // Capacity reduced because read data was too small comparing to capacity. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor, + buffer->GetCapacity()); + // Change in unused capacity because of read data. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor - + kReadLength + kConsumedLength, + buffer->RemainingCapacity()); + // Change in read size. + EXPECT_EQ(kReadLength - kConsumedLength, buffer->GetSize()); + // Start data could be changed even when capacity is reduced. + start_of_buffer = buffer->StartOfBuffer(); + // Change in start pointer of unused buffer. + EXPECT_EQ(start_of_buffer + kReadLength - kConsumedLength, buffer->data()); + // Change in read data. + EXPECT_EQ(kReadData.substr(kConsumedLength), + std::string(buffer->StartOfBuffer(), buffer->GetSize())); + + // Read more data. + const int kReadLength2 = 64; + buffer->DidRead(kReadLength2); + // No change in total capacity. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor, + buffer->GetCapacity()); + // Change in unused capacity because of read data. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor - + kReadLength + kConsumedLength - kReadLength2, + buffer->RemainingCapacity()); + // Change in read size + EXPECT_EQ(kReadLength - kConsumedLength + kReadLength2, buffer->GetSize()); + // No change in start pointer of read part. + EXPECT_EQ(start_of_buffer, buffer->StartOfBuffer()); + // Change in start pointer of unused buffer. + EXPECT_EQ(start_of_buffer + kReadLength - kConsumedLength + kReadLength2, + buffer->data()); + + // Consume data fully. + buffer->DidConsume(kReadLength - kConsumedLength + kReadLength2); + // Capacity reduced again because read data was too small. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor, + buffer->GetCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor, + buffer->RemainingCapacity()); + // All reverts to initial because no data is left. + EXPECT_EQ(0, buffer->GetSize()); + // Start data could be changed even when capacity is reduced. + start_of_buffer = buffer->StartOfBuffer(); + EXPECT_EQ(start_of_buffer, buffer->data()); +} + +TEST(HttpConnectionTest, QueuedWriteIOBuffer_Append_DidConsume) { + scoped_refptr<HttpConnection::QueuedWriteIOBuffer> buffer( + new HttpConnection::QueuedWriteIOBuffer()); + EXPECT_TRUE(buffer->IsEmpty()); + EXPECT_EQ(0, buffer->GetSizeToWrite()); + EXPECT_EQ(0, buffer->total_size()); + + const std::string kData("data to write"); + EXPECT_TRUE(buffer->Append(kData)); + EXPECT_FALSE(buffer->IsEmpty()); + EXPECT_EQ(static_cast<int>(kData.size()), buffer->GetSizeToWrite()); + EXPECT_EQ(static_cast<int>(kData.size()), buffer->total_size()); + // First data to write is same to kData. + EXPECT_EQ(kData, base::StringPiece(buffer->data(), buffer->GetSizeToWrite())); + + const std::string kData2("more data to write"); + EXPECT_TRUE(buffer->Append(kData2)); + EXPECT_FALSE(buffer->IsEmpty()); + // No change in size to write. + EXPECT_EQ(static_cast<int>(kData.size()), buffer->GetSizeToWrite()); + // Change in total size. + EXPECT_EQ(static_cast<int>(kData.size() + kData2.size()), + buffer->total_size()); + // First data to write has not been changed. Same to kData. + EXPECT_EQ(kData, base::StringPiece(buffer->data(), buffer->GetSizeToWrite())); + + // Consume data partially. + const int kConsumedLength = kData.length() - 1; + buffer->DidConsume(kConsumedLength); + EXPECT_FALSE(buffer->IsEmpty()); + // Change in size to write. + EXPECT_EQ(static_cast<int>(kData.size()) - kConsumedLength, + buffer->GetSizeToWrite()); + // Change in total size. + EXPECT_EQ(static_cast<int>(kData.size() + kData2.size()) - kConsumedLength, + buffer->total_size()); + // First data to write has shrinked. + EXPECT_EQ(kData.substr(kConsumedLength), + base::StringPiece(buffer->data(), buffer->GetSizeToWrite())); + + // Consume first data fully. + buffer->DidConsume(kData.size() - kConsumedLength); + EXPECT_FALSE(buffer->IsEmpty()); + // Now, size to write is size of data added second. + EXPECT_EQ(static_cast<int>(kData2.size()), buffer->GetSizeToWrite()); + // Change in total size. + EXPECT_EQ(static_cast<int>(kData2.size()), buffer->total_size()); + // First data to write has changed to kData2. + EXPECT_EQ(kData2, + base::StringPiece(buffer->data(), buffer->GetSizeToWrite())); + + // Consume second data fully. + buffer->DidConsume(kData2.size()); + EXPECT_TRUE(buffer->IsEmpty()); + EXPECT_EQ(0, buffer->GetSizeToWrite()); + EXPECT_EQ(0, buffer->total_size()); +} + +TEST(HttpConnectionTest, QueuedWriteIOBuffer_TotalSizeLimit) { + scoped_refptr<HttpConnection::QueuedWriteIOBuffer> buffer( + new HttpConnection::QueuedWriteIOBuffer()); + EXPECT_EQ(HttpConnection::QueuedWriteIOBuffer::kDefaultMaxBufferSize + 0, + buffer->max_buffer_size()); + + // Set total size limit very small. + buffer->set_max_buffer_size(10); + + const int kDataLength = 4; + const std::string kData(kDataLength, 'd'); + EXPECT_TRUE(buffer->Append(kData)); + EXPECT_EQ(kDataLength, buffer->total_size()); + EXPECT_TRUE(buffer->Append(kData)); + EXPECT_EQ(kDataLength * 2, buffer->total_size()); + + // Cannot append more data because it exceeds the limit. + EXPECT_FALSE(buffer->Append(kData)); + EXPECT_EQ(kDataLength * 2, buffer->total_size()); + + // Consume data partially. + const int kConsumedLength = 2; + buffer->DidConsume(kConsumedLength); + EXPECT_EQ(kDataLength * 2 - kConsumedLength, buffer->total_size()); + + // Can add more data. + EXPECT_TRUE(buffer->Append(kData)); + EXPECT_EQ(kDataLength * 3 - kConsumedLength, buffer->total_size()); + + // Cannot append more data because it exceeds the limit. + EXPECT_FALSE(buffer->Append(kData)); + EXPECT_EQ(kDataLength * 3 - kConsumedLength, buffer->total_size()); + + // Enlarge limit. + buffer->set_max_buffer_size(20); + // Can add more data. + EXPECT_TRUE(buffer->Append(kData)); + EXPECT_EQ(kDataLength * 4 - kConsumedLength, buffer->total_size()); +} + +} // namespace +} // namespace net diff --git a/net/server/http_server.cc b/net/server/http_server.cc index 043e625..fb0dab3 100644 --- a/net/server/http_server.cc +++ b/net/server/http_server.cc @@ -17,14 +17,25 @@ #include "net/server/http_server_request_info.h" #include "net/server/http_server_response_info.h" #include "net/server/web_socket.h" -#include "net/socket/tcp_listen_socket.h" +#include "net/socket/server_socket.h" +#include "net/socket/stream_socket.h" +#include "net/socket/tcp_server_socket.h" namespace net { -HttpServer::HttpServer(const StreamListenSocketFactory& factory, +HttpServer::HttpServer(scoped_ptr<ServerSocket> server_socket, HttpServer::Delegate* delegate) - : delegate_(delegate), - server_(factory.CreateAndListen(this)) { + : server_socket_(server_socket.Pass()), + delegate_(delegate), + last_id_(0), + weak_ptr_factory_(this) { + DCHECK(server_socket_); + DoAcceptLoop(); +} + +HttpServer::~HttpServer() { + STLDeleteContainerPairSecondPointers( + id_to_connection_.begin(), id_to_connection_.end()); } void HttpServer::AcceptWebSocket( @@ -33,9 +44,8 @@ void HttpServer::AcceptWebSocket( HttpConnection* connection = FindConnection(connection_id); if (connection == NULL) return; - - DCHECK(connection->web_socket_.get()); - connection->web_socket_->Accept(request); + DCHECK(connection->web_socket()); + connection->web_socket()->Accept(request); } void HttpServer::SendOverWebSocket(int connection_id, @@ -43,23 +53,23 @@ void HttpServer::SendOverWebSocket(int connection_id, HttpConnection* connection = FindConnection(connection_id); if (connection == NULL) return; - DCHECK(connection->web_socket_.get()); - connection->web_socket_->Send(data); + DCHECK(connection->web_socket()); + connection->web_socket()->Send(data); } void HttpServer::SendRaw(int connection_id, const std::string& data) { HttpConnection* connection = FindConnection(connection_id); if (connection == NULL) return; - connection->Send(data); + + bool writing_in_progress = !connection->write_buf()->IsEmpty(); + if (connection->write_buf()->Append(data) && !writing_in_progress) + DoWriteLoop(connection); } void HttpServer::SendResponse(int connection_id, const HttpServerResponseInfo& response) { - HttpConnection* connection = FindConnection(connection_id); - if (connection == NULL) - return; - connection->Send(response); + SendRaw(connection_id, response.Serialize()); } void HttpServer::Send(int connection_id, @@ -67,8 +77,9 @@ void HttpServer::Send(int connection_id, const std::string& data, const std::string& content_type) { HttpServerResponseInfo response(status_code); - response.SetBody(data, content_type); + response.SetContentHeaders(data.size(), content_type); SendResponse(connection_id, response); + SendRaw(connection_id, data); } void HttpServer::Send200(int connection_id, @@ -90,108 +101,209 @@ void HttpServer::Close(int connection_id) { if (connection == NULL) return; - // Initiating close from server-side does not lead to the DidClose call. - // Do it manually here. - DidClose(connection->socket_.get()); + id_to_connection_.erase(connection_id); + delegate_->OnClose(connection_id); + + // The call stack might have callbacks which still have the pointer of + // connection. Instead of referencing connection with ID all the time, + // destroys the connection in next run loop to make sure any pending + // callbacks in the call stack return. + base::MessageLoopProxy::current()->DeleteSoon(FROM_HERE, connection); } int HttpServer::GetLocalAddress(IPEndPoint* address) { - if (!server_) - return ERR_SOCKET_NOT_CONNECTED; - return server_->GetLocalAddress(address); + return server_socket_->GetLocalAddress(address); +} + +void HttpServer::SetReceiveBufferSize(int connection_id, int32 size) { + HttpConnection* connection = FindConnection(connection_id); + DCHECK(connection); + connection->read_buf()->set_max_buffer_size(size); } -void HttpServer::DidAccept(StreamListenSocket* server, - scoped_ptr<StreamListenSocket> socket) { - HttpConnection* connection = new HttpConnection(this, socket.Pass()); +void HttpServer::SetSendBufferSize(int connection_id, int32 size) { + HttpConnection* connection = FindConnection(connection_id); + DCHECK(connection); + connection->write_buf()->set_max_buffer_size(size); +} + +void HttpServer::DoAcceptLoop() { + int rv; + do { + rv = server_socket_->Accept(&accepted_socket_, + base::Bind(&HttpServer::OnAcceptCompleted, + weak_ptr_factory_.GetWeakPtr())); + if (rv == ERR_IO_PENDING) + return; + rv = HandleAcceptResult(rv); + } while (rv == OK); +} + +void HttpServer::OnAcceptCompleted(int rv) { + if (HandleAcceptResult(rv) == OK) + DoAcceptLoop(); +} + +int HttpServer::HandleAcceptResult(int rv) { + if (rv < 0) { + LOG(ERROR) << "Accept error: rv=" << rv; + return rv; + } + + HttpConnection* connection = + new HttpConnection(++last_id_, accepted_socket_.Pass()); id_to_connection_[connection->id()] = connection; - // TODO(szym): Fix socket access. Make HttpConnection the Delegate. - socket_to_connection_[connection->socket_.get()] = connection; + DoReadLoop(connection); + return OK; } -void HttpServer::DidRead(StreamListenSocket* socket, - const char* data, - int len) { - HttpConnection* connection = FindConnection(socket); - DCHECK(connection != NULL); - if (connection == NULL) +void HttpServer::DoReadLoop(HttpConnection* connection) { + int rv; + do { + HttpConnection::ReadIOBuffer* read_buf = connection->read_buf(); + // Increases read buffer size if necessary. + if (read_buf->RemainingCapacity() == 0 && !read_buf->IncreaseCapacity()) { + Close(connection->id()); + return; + } + + rv = connection->socket()->Read( + read_buf, + read_buf->RemainingCapacity(), + base::Bind(&HttpServer::OnReadCompleted, + weak_ptr_factory_.GetWeakPtr(), connection->id())); + if (rv == ERR_IO_PENDING) + return; + rv = HandleReadResult(connection, rv); + } while (rv == OK); +} + +void HttpServer::OnReadCompleted(int connection_id, int rv) { + HttpConnection* connection = FindConnection(connection_id); + if (!connection) // It might be closed right before by write error. return; - connection->recv_data_.append(data, len); - while (connection->recv_data_.length()) { - if (connection->web_socket_.get()) { + if (HandleReadResult(connection, rv) == OK) + DoReadLoop(connection); +} + +int HttpServer::HandleReadResult(HttpConnection* connection, int rv) { + if (rv <= 0) { + Close(connection->id()); + return rv == 0 ? ERR_CONNECTION_CLOSED : rv; + } + + HttpConnection::ReadIOBuffer* read_buf = connection->read_buf(); + read_buf->DidRead(rv); + + // Handles http requests or websocket messages. + while (read_buf->GetSize() > 0) { + if (connection->web_socket()) { std::string message; - WebSocket::ParseResult result = connection->web_socket_->Read(&message); + WebSocket::ParseResult result = connection->web_socket()->Read(&message); if (result == WebSocket::FRAME_INCOMPLETE) break; if (result == WebSocket::FRAME_CLOSE || result == WebSocket::FRAME_ERROR) { Close(connection->id()); - break; + return ERR_CONNECTION_CLOSED; } delegate_->OnWebSocketMessage(connection->id(), message); + if (HasClosedConnection(connection)) + return ERR_CONNECTION_CLOSED; continue; } HttpServerRequestInfo request; size_t pos = 0; - if (!ParseHeaders(connection, &request, &pos)) + if (!ParseHeaders(read_buf->StartOfBuffer(), read_buf->GetSize(), + &request, &pos)) { break; + } // Sets peer address if exists. - socket->GetPeerAddress(&request.peer); + connection->socket()->GetPeerAddress(&request.peer); if (request.HasHeaderValue("connection", "upgrade")) { - connection->web_socket_.reset(WebSocket::CreateWebSocket(connection, - request, - &pos)); - - if (!connection->web_socket_.get()) // Not enough data was received. + scoped_ptr<WebSocket> websocket( + WebSocket::CreateWebSocket(this, connection, request, &pos)); + if (!websocket) // Not enough data was received. break; + connection->SetWebSocket(websocket.Pass()); + read_buf->DidConsume(pos); delegate_->OnWebSocketRequest(connection->id(), request); - connection->Shift(pos); + if (HasClosedConnection(connection)) + return ERR_CONNECTION_CLOSED; continue; } const char kContentLength[] = "content-length"; - if (request.headers.count(kContentLength)) { + if (request.headers.count(kContentLength) > 0) { size_t content_length = 0; const size_t kMaxBodySize = 100 << 20; if (!base::StringToSizeT(request.GetHeaderValue(kContentLength), &content_length) || content_length > kMaxBodySize) { - connection->Send(HttpServerResponseInfo::CreateFor500( - "request content-length too big or unknown: " + - request.GetHeaderValue(kContentLength))); - DidClose(socket); - break; + SendResponse(connection->id(), + HttpServerResponseInfo::CreateFor500( + "request content-length too big or unknown: " + + request.GetHeaderValue(kContentLength))); + Close(connection->id()); + return ERR_CONNECTION_CLOSED; } - if (connection->recv_data_.length() - pos < content_length) + if (read_buf->GetSize() - pos < content_length) break; // Not enough data was received yet. - request.data = connection->recv_data_.substr(pos, content_length); + request.data.assign(read_buf->StartOfBuffer() + pos, content_length); pos += content_length; } + read_buf->DidConsume(pos); delegate_->OnHttpRequest(connection->id(), request); - connection->Shift(pos); + if (HasClosedConnection(connection)) + return ERR_CONNECTION_CLOSED; } + + return OK; } -void HttpServer::DidClose(StreamListenSocket* socket) { - HttpConnection* connection = FindConnection(socket); - DCHECK(connection != NULL); - id_to_connection_.erase(connection->id()); - socket_to_connection_.erase(connection->socket_.get()); - delete connection; +void HttpServer::DoWriteLoop(HttpConnection* connection) { + int rv = OK; + HttpConnection::QueuedWriteIOBuffer* write_buf = connection->write_buf(); + while (rv == OK && write_buf->GetSizeToWrite() > 0) { + rv = connection->socket()->Write( + write_buf, + write_buf->GetSizeToWrite(), + base::Bind(&HttpServer::OnWriteCompleted, + weak_ptr_factory_.GetWeakPtr(), connection->id())); + if (rv == ERR_IO_PENDING || rv == OK) + return; + rv = HandleWriteResult(connection, rv); + } } -HttpServer::~HttpServer() { - STLDeleteContainerPairSecondPointers( - id_to_connection_.begin(), id_to_connection_.end()); +void HttpServer::OnWriteCompleted(int connection_id, int rv) { + HttpConnection* connection = FindConnection(connection_id); + if (!connection) // It might be closed right before by read error. + return; + + if (HandleWriteResult(connection, rv) == OK) + DoWriteLoop(connection); } +int HttpServer::HandleWriteResult(HttpConnection* connection, int rv) { + if (rv < 0) { + Close(connection->id()); + return rv; + } + + connection->write_buf()->DidConsume(rv); + return OK; +} + +namespace { + // // HTTP Request Parser // This HTTP request parser uses a simple state machine to quickly parse @@ -255,17 +367,19 @@ int charToInput(char ch) { return INPUT_DEFAULT; } -bool HttpServer::ParseHeaders(HttpConnection* connection, +} // namespace + +bool HttpServer::ParseHeaders(const char* data, + size_t data_len, HttpServerRequestInfo* info, size_t* ppos) { size_t& pos = *ppos; - size_t data_len = connection->recv_data_.length(); int state = ST_METHOD; std::string buffer; std::string header_name; std::string header_value; while (pos < data_len) { - char ch = connection->recv_data_[pos++]; + char ch = data[pos++]; int input = charToInput(ch); int next_state = parser_state[state][input]; @@ -337,11 +451,12 @@ HttpConnection* HttpServer::FindConnection(int connection_id) { return it->second; } -HttpConnection* HttpServer::FindConnection(StreamListenSocket* socket) { - SocketToConnectionMap::iterator it = socket_to_connection_.find(socket); - if (it == socket_to_connection_.end()) - return NULL; - return it->second; +// This is called after any delegate callbacks are called to check if Close() +// has been called during callback processing. Using the pointer of connection, +// |connection| is safe here because Close() deletes the connection in next run +// loop. +bool HttpServer::HasClosedConnection(HttpConnection* connection) { + return FindConnection(connection->id()) != connection; } } // namespace net diff --git a/net/server/http_server.h b/net/server/http_server.h index 4309d122..2ae698b 100644 --- a/net/server/http_server.h +++ b/net/server/http_server.h @@ -5,13 +5,14 @@ #ifndef NET_SERVER_HTTP_SERVER_H_ #define NET_SERVER_HTTP_SERVER_H_ -#include <list> #include <map> +#include <string> #include "base/basictypes.h" +#include "base/macros.h" #include "base/memory/scoped_ptr.h" +#include "base/memory/weak_ptr.h" #include "net/http/http_status_code.h" -#include "net/socket/stream_listen_socket.h" namespace net { @@ -19,30 +20,28 @@ class HttpConnection; class HttpServerRequestInfo; class HttpServerResponseInfo; class IPEndPoint; +class ServerSocket; +class StreamSocket; class WebSocket; -class HttpServer : public StreamListenSocket::Delegate, - public base::RefCountedThreadSafe<HttpServer> { +class HttpServer { public: + // Delegate to handle http/websocket events. Beware that it is not safe to + // destroy the HttpServer in any of these callbacks. class Delegate { public: virtual void OnHttpRequest(int connection_id, const HttpServerRequestInfo& info) = 0; - virtual void OnWebSocketRequest(int connection_id, const HttpServerRequestInfo& info) = 0; - virtual void OnWebSocketMessage(int connection_id, const std::string& data) = 0; - virtual void OnClose(int connection_id) = 0; - - protected: - virtual ~Delegate() {} }; - HttpServer(const StreamListenSocketFactory& socket_factory, + HttpServer(scoped_ptr<ServerSocket> server_socket, HttpServer::Delegate* delegate); + ~HttpServer(); void AcceptWebSocket(int connection_id, const HttpServerRequestInfo& request); @@ -51,6 +50,7 @@ class HttpServer : public StreamListenSocket::Delegate, // performed that data constitutes a valid HTTP response. A valid HTTP // response may be split across multiple calls to SendRaw. void SendRaw(int connection_id, const std::string& data); + // TODO(byungchul): Consider replacing function name with SendResponseInfo void SendResponse(int connection_id, const HttpServerResponseInfo& response); void Send(int connection_id, HttpStatusCode status_code, @@ -64,40 +64,50 @@ class HttpServer : public StreamListenSocket::Delegate, void Close(int connection_id); + void SetReceiveBufferSize(int connection_id, int32 size); + void SetSendBufferSize(int connection_id, int32 size); + // Copies the local address to |address|. Returns a network error code. int GetLocalAddress(IPEndPoint* address); - // ListenSocketDelegate - virtual void DidAccept(StreamListenSocket* server, - scoped_ptr<StreamListenSocket> socket) OVERRIDE; - virtual void DidRead(StreamListenSocket* socket, - const char* data, - int len) OVERRIDE; - virtual void DidClose(StreamListenSocket* socket) OVERRIDE; + private: + friend class HttpServerTest; + + typedef std::map<int, HttpConnection*> IdToConnectionMap; - protected: - virtual ~HttpServer(); + void DoAcceptLoop(); + void OnAcceptCompleted(int rv); + int HandleAcceptResult(int rv); - private: - friend class base::RefCountedThreadSafe<HttpServer>; - friend class HttpConnection; + void DoReadLoop(HttpConnection* connection); + void OnReadCompleted(int connection_id, int rv); + int HandleReadResult(HttpConnection* connection, int rv); + + void DoWriteLoop(HttpConnection* connection); + void OnWriteCompleted(int connection_id, int rv); + int HandleWriteResult(HttpConnection* connection, int rv); // Expects the raw data to be stored in recv_data_. If parsing is successful, // will remove the data parsed from recv_data_, leaving only the unused // recv data. - bool ParseHeaders(HttpConnection* connection, + bool ParseHeaders(const char* data, + size_t data_len, HttpServerRequestInfo* info, size_t* pos); HttpConnection* FindConnection(int connection_id); - HttpConnection* FindConnection(StreamListenSocket* socket); - HttpServer::Delegate* delegate_; - scoped_ptr<StreamListenSocket> server_; - typedef std::map<int, HttpConnection*> IdToConnectionMap; + // Whether or not Close() has been called during delegate callback processing. + bool HasClosedConnection(HttpConnection* connection); + + const scoped_ptr<ServerSocket> server_socket_; + scoped_ptr<StreamSocket> accepted_socket_; + HttpServer::Delegate* const delegate_; + + int last_id_; IdToConnectionMap id_to_connection_; - typedef std::map<StreamListenSocket*, HttpConnection*> SocketToConnectionMap; - SocketToConnectionMap socket_to_connection_; + + base::WeakPtrFactory<HttpServer> weak_ptr_factory_; DISALLOW_COPY_AND_ASSIGN(HttpServer); }; diff --git a/net/server/http_server_response_info.cc b/net/server/http_server_response_info.cc index e4c6043..2d0a32e 100644 --- a/net/server/http_server_response_info.cc +++ b/net/server/http_server_response_info.cc @@ -41,8 +41,14 @@ void HttpServerResponseInfo::SetBody(const std::string& body, const std::string& content_type) { DCHECK(body_.empty()); body_ = body; + SetContentHeaders(body.length(), content_type); +} + +void HttpServerResponseInfo::SetContentHeaders( + size_t content_length, + const std::string& content_type) { AddHeader(HttpRequestHeaders::kContentLength, - base::StringPrintf("%" PRIuS, body.length())); + base::StringPrintf("%" PRIuS, content_length)); AddHeader(HttpRequestHeaders::kContentType, content_type); } diff --git a/net/server/http_server_response_info.h b/net/server/http_server_response_info.h index d6cedaa..bbb76d8 100644 --- a/net/server/http_server_response_info.h +++ b/net/server/http_server_response_info.h @@ -27,6 +27,9 @@ class HttpServerResponseInfo { // This also adds an appropriate Content-Length header. void SetBody(const std::string& body, const std::string& content_type); + // Sets content-length and content-type. Body should be sent separately. + void SetContentHeaders(size_t content_length, + const std::string& content_type); std::string Serialize() const; diff --git a/net/server/http_server_unittest.cc b/net/server/http_server_unittest.cc index 467bde4..4b67040 100644 --- a/net/server/http_server_unittest.cc +++ b/net/server/http_server_unittest.cc @@ -2,11 +2,13 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <algorithm> #include <utility> #include <vector> #include "base/bind.h" #include "base/bind_helpers.h" +#include "base/callback_helpers.h" #include "base/compiler_specific.h" #include "base/format_macros.h" #include "base/memory/ref_counted.h" @@ -24,11 +26,12 @@ #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/net_log.h" +#include "net/base/net_util.h" #include "net/base/test_completion_callback.h" #include "net/server/http_server.h" #include "net/server/http_server_request_info.h" #include "net/socket/tcp_client_socket.h" -#include "net/socket/tcp_listen_socket.h" +#include "net/socket/tcp_server_socket.h" #include "net/url_request/url_fetcher.h" #include "net/url_request/url_fetcher_delegate.h" #include "net/url_request/url_request_context.h" @@ -155,8 +158,10 @@ class HttpServerTest : public testing::Test, HttpServerTest() : quit_after_request_count_(0) {} virtual void SetUp() OVERRIDE { - TCPListenSocketFactory socket_factory("127.0.0.1", 0); - server_ = new HttpServer(socket_factory, this); + scoped_ptr<ServerSocket> server_socket( + new TCPServerSocket(NULL, net::NetLog::Source())); + server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1); + server_.reset(new HttpServer(server_socket.Pass(), this)); ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_)); } @@ -199,8 +204,13 @@ class HttpServerTest : public testing::Test, return requests_[request_index].second; } + void HandleAcceptResult(scoped_ptr<StreamSocket> socket) { + server_->accepted_socket_.reset(socket.release()); + server_->HandleAcceptResult(OK); + } + protected: - scoped_refptr<HttpServer> server_; + scoped_ptr<HttpServer> server_; IPEndPoint server_address_; base::Closure run_loop_quit_func_; std::vector<std::pair<HttpServerRequestInfo, int> > requests_; @@ -429,23 +439,105 @@ TEST_F(HttpServerTest, SendRaw) { namespace { -class MockStreamListenSocket : public StreamListenSocket { +class MockStreamSocket : public StreamSocket { public: - MockStreamListenSocket(StreamListenSocket::Delegate* delegate) - : StreamListenSocket(kInvalidSocket, delegate) {} + MockStreamSocket() + : connected_(true), + read_buf_(NULL), + read_buf_len_(0) {} + + // StreamSocket + virtual int Connect(const CompletionCallback& callback) OVERRIDE { + return ERR_NOT_IMPLEMENTED; + } + virtual void Disconnect() OVERRIDE { + connected_ = false; + if (!read_callback_.is_null()) { + read_buf_ = NULL; + read_buf_len_ = 0; + base::ResetAndReturn(&read_callback_).Run(ERR_CONNECTION_CLOSED); + } + } + virtual bool IsConnected() const OVERRIDE { return connected_; } + virtual bool IsConnectedAndIdle() const OVERRIDE { return IsConnected(); } + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + return ERR_NOT_IMPLEMENTED; + } + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + return ERR_NOT_IMPLEMENTED; + } + virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; } + virtual void SetSubresourceSpeculation() OVERRIDE {} + virtual void SetOmniboxSpeculation() OVERRIDE {} + virtual bool WasEverUsed() const OVERRIDE { return true; } + virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } + virtual bool WasNpnNegotiated() const OVERRIDE { return false; } + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { + return kProtoUnknown; + } + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; } - virtual void Accept() OVERRIDE { NOTREACHED(); } + // Socket + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE { + if (!connected_) { + return ERR_SOCKET_NOT_CONNECTED; + } + if (pending_read_data_.empty()) { + read_buf_ = buf; + read_buf_len_ = buf_len; + read_callback_ = callback; + return ERR_IO_PENDING; + } + DCHECK_GT(buf_len, 0); + int read_len = std::min(static_cast<int>(pending_read_data_.size()), + buf_len); + memcpy(buf->data(), pending_read_data_.data(), read_len); + pending_read_data_.erase(0, read_len); + return read_len; + } + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_NOT_IMPLEMENTED; + } + virtual int SetReceiveBufferSize(int32 size) OVERRIDE { + return ERR_NOT_IMPLEMENTED; + } + virtual int SetSendBufferSize(int32 size) OVERRIDE { + return ERR_NOT_IMPLEMENTED; + } + + void DidRead(const char* data, int data_len) { + if (!read_buf_) { + pending_read_data_.append(data, data_len); + return; + } + int read_len = std::min(data_len, read_buf_len_); + memcpy(read_buf_->data(), data, read_len); + pending_read_data_.assign(data + read_len, data_len - read_len); + read_buf_ = NULL; + read_buf_len_ = 0; + base::ResetAndReturn(&read_callback_).Run(read_len); + } private: - virtual ~MockStreamListenSocket() {} + virtual ~MockStreamSocket() {} + + bool connected_; + scoped_refptr<IOBuffer> read_buf_; + int read_buf_len_; + CompletionCallback read_callback_; + std::string pending_read_data_; + BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(MockStreamSocket); }; } // namespace TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) { - StreamListenSocket* socket = - new MockStreamListenSocket(server_.get()); - server_->DidAccept(NULL, make_scoped_ptr(socket)); + MockStreamSocket* socket = new MockStreamSocket(); + HandleAcceptResult(make_scoped_ptr<StreamSocket>(socket)); std::string body("body"); std::string request_text = base::StringPrintf( "GET /test HTTP/1.1\r\n" @@ -453,9 +545,9 @@ TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) { "Content-Length: %" PRIuS "\r\n\r\n%s", body.length(), body.c_str()); - server_->DidRead(socket, request_text.c_str(), request_text.length() - 2); + socket->DidRead(request_text.c_str(), request_text.length() - 2); ASSERT_EQ(0u, requests_.size()); - server_->DidRead(socket, request_text.c_str() + request_text.length() - 2, 2); + socket->DidRead(request_text.c_str() + request_text.length() - 2, 2); ASSERT_EQ(1u, requests_.size()); ASSERT_EQ(body, GetRequest(0).data); } diff --git a/net/server/web_socket.cc b/net/server/web_socket.cc index f06b425..ec0fdac 100644 --- a/net/server/web_socket.cc +++ b/net/server/web_socket.cc @@ -15,6 +15,7 @@ #include "base/strings/stringprintf.h" #include "base/sys_byteorder.h" #include "net/server/http_connection.h" +#include "net/server/http_server.h" #include "net/server/http_server_request_info.h" #include "net/server/http_server_response_info.h" @@ -43,12 +44,14 @@ static uint32 WebSocketKeyFingerprint(const std::string& str) { class WebSocketHixie76 : public net::WebSocket { public: - static net::WebSocket* Create(HttpConnection* connection, + static net::WebSocket* Create(HttpServer* server, + HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos) { - if (connection->recv_data().length() < *pos + kWebSocketHandshakeBodyLen) + if (connection->read_buf()->GetSize() < + static_cast<int>(*pos + kWebSocketHandshakeBodyLen)) return NULL; - return new WebSocketHixie76(connection, request, pos); + return new WebSocketHixie76(server, connection, request, pos); } virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE { @@ -69,31 +72,33 @@ class WebSocketHixie76 : public net::WebSocket { std::string origin = request.GetHeaderValue("origin"); std::string host = request.GetHeaderValue("host"); std::string location = "ws://" + host + request.path; - connection_->Send(base::StringPrintf( - "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Origin: %s\r\n" - "Sec-WebSocket-Location: %s\r\n" - "\r\n", - origin.c_str(), - location.c_str())); - connection_->Send(reinterpret_cast<char*>(digest.a), 16); + server_->SendRaw( + connection_->id(), + base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Origin: %s\r\n" + "Sec-WebSocket-Location: %s\r\n" + "\r\n", + origin.c_str(), + location.c_str())); + server_->SendRaw(connection_->id(), + std::string(reinterpret_cast<char*>(digest.a), 16)); } virtual ParseResult Read(std::string* message) OVERRIDE { DCHECK(message); - const std::string& data = connection_->recv_data(); - if (data[0]) + HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); + if (read_buf->StartOfBuffer()[0]) return FRAME_ERROR; + base::StringPiece data(read_buf->StartOfBuffer(), read_buf->GetSize()); size_t pos = data.find('\377', 1); - if (pos == std::string::npos) + if (pos == base::StringPiece::npos) return FRAME_INCOMPLETE; - std::string buffer(data.begin() + 1, data.begin() + pos); - message->swap(buffer); - connection_->Shift(pos + 1); + message->assign(data.data() + 1, pos - 1); + read_buf->DidConsume(pos + 1); return FRAME_OK; } @@ -101,37 +106,42 @@ class WebSocketHixie76 : public net::WebSocket { virtual void Send(const std::string& message) OVERRIDE { char message_start = 0; char message_end = -1; - connection_->Send(&message_start, 1); - connection_->Send(message); - connection_->Send(&message_end, 1); + server_->SendRaw(connection_->id(), std::string(1, message_start)); + server_->SendRaw(connection_->id(), message); + server_->SendRaw(connection_->id(), std::string(1, message_end)); } private: static const int kWebSocketHandshakeBodyLen; - WebSocketHixie76(HttpConnection* connection, + WebSocketHixie76(HttpServer* server, + HttpConnection* connection, const HttpServerRequestInfo& request, - size_t* pos) : WebSocket(connection) { + size_t* pos) + : WebSocket(server, connection) { std::string key1 = request.GetHeaderValue("sec-websocket-key1"); std::string key2 = request.GetHeaderValue("sec-websocket-key2"); if (key1.empty()) { - connection->Send(HttpServerResponseInfo::CreateFor500( - "Invalid request format. Sec-WebSocket-Key1 is empty or isn't " - "specified.")); + server->SendResponse( + connection->id(), + HttpServerResponseInfo::CreateFor500( + "Invalid request format. Sec-WebSocket-Key1 is empty or isn't " + "specified.")); return; } if (key2.empty()) { - connection->Send(HttpServerResponseInfo::CreateFor500( - "Invalid request format. Sec-WebSocket-Key2 is empty or isn't " - "specified.")); + server->SendResponse( + connection->id(), + HttpServerResponseInfo::CreateFor500( + "Invalid request format. Sec-WebSocket-Key2 is empty or isn't " + "specified.")); return; } - key3_ = connection->recv_data().substr( - *pos, - *pos + kWebSocketHandshakeBodyLen); + key3_.assign(connection->read_buf()->StartOfBuffer() + *pos, + kWebSocketHandshakeBodyLen); *pos += kWebSocketHandshakeBodyLen; } @@ -169,7 +179,8 @@ const size_t kMaskingKeyWidthInBytes = 4; class WebSocketHybi17 : public WebSocket { public: - static WebSocket* Create(HttpConnection* connection, + static WebSocket* Create(HttpServer* server, + HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos) { std::string version = request.GetHeaderValue("sec-websocket-version"); @@ -178,12 +189,14 @@ class WebSocketHybi17 : public WebSocket { std::string key = request.GetHeaderValue("sec-websocket-key"); if (key.empty()) { - connection->Send(HttpServerResponseInfo::CreateFor500( - "Invalid request format. Sec-WebSocket-Key is empty or isn't " - "specified.")); + server->SendResponse( + connection->id(), + HttpServerResponseInfo::CreateFor500( + "Invalid request format. Sec-WebSocket-Key is empty or isn't " + "specified.")); return NULL; } - return new WebSocketHybi17(connection, request, pos); + return new WebSocketHybi17(server, connection, request, pos); } virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE { @@ -194,24 +207,24 @@ class WebSocketHybi17 : public WebSocket { std::string encoded_hash; base::Base64Encode(base::SHA1HashString(data), &encoded_hash); - std::string response = base::StringPrintf( - "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept: %s\r\n" - "\r\n", - encoded_hash.c_str()); - connection_->Send(response); + server_->SendRaw( + connection_->id(), + base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: %s\r\n" + "\r\n", + encoded_hash.c_str())); } virtual ParseResult Read(std::string* message) OVERRIDE { - const std::string& frame = connection_->recv_data(); + HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); + base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); int bytes_consumed = 0; - ParseResult result = WebSocket::DecodeFrameHybi17(frame, true, &bytes_consumed, message); if (result == FRAME_OK) - connection_->Shift(bytes_consumed); + read_buf->DidConsume(bytes_consumed); if (result == FRAME_CLOSE) closed_ = true; return result; @@ -220,25 +233,26 @@ class WebSocketHybi17 : public WebSocket { virtual void Send(const std::string& message) OVERRIDE { if (closed_) return; - std::string data = WebSocket::EncodeFrameHybi17(message, 0); - connection_->Send(data); + server_->SendRaw(connection_->id(), + WebSocket::EncodeFrameHybi17(message, 0)); } private: - WebSocketHybi17(HttpConnection* connection, + WebSocketHybi17(HttpServer* server, + HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos) - : WebSocket(connection), - op_code_(0), - final_(false), - reserved1_(false), - reserved2_(false), - reserved3_(false), - masked_(false), - payload_(0), - payload_length_(0), - frame_end_(0), - closed_(false) { + : WebSocket(server, connection), + op_code_(0), + final_(false), + reserved1_(false), + reserved2_(false), + reserved3_(false), + masked_(false), + payload_(0), + payload_length_(0), + frame_end_(0), + closed_(false) { } OpCode op_code_; @@ -257,21 +271,23 @@ class WebSocketHybi17 : public WebSocket { } // anonymous namespace -WebSocket* WebSocket::CreateWebSocket(HttpConnection* connection, +WebSocket* WebSocket::CreateWebSocket(HttpServer* server, + HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos) { - WebSocket* socket = WebSocketHybi17::Create(connection, request, pos); + WebSocket* socket = WebSocketHybi17::Create(server, connection, request, pos); if (socket) return socket; - return WebSocketHixie76::Create(connection, request, pos); + return WebSocketHixie76::Create(server, connection, request, pos); } // static -WebSocket::ParseResult WebSocket::DecodeFrameHybi17(const std::string& frame, - bool client_frame, - int* bytes_consumed, - std::string* output) { +WebSocket::ParseResult WebSocket::DecodeFrameHybi17( + const base::StringPiece& frame, + bool client_frame, + int* bytes_consumed, + std::string* output) { size_t data_length = frame.length(); if (data_length < 2) return FRAME_INCOMPLETE; @@ -349,8 +365,7 @@ WebSocket::ParseResult WebSocket::DecodeFrameHybi17(const std::string& frame, for (size_t i = 0; i < payload_length; ++i) // Unmask the payload. (*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes]; } else { - std::string buffer(p, p + payload_length); - output->swap(buffer); + output->assign(p, p + payload_length); } size_t pos = p + actual_masking_key_length + payload_length - buffer_begin; @@ -400,7 +415,9 @@ std::string WebSocket::EncodeFrameHybi17(const std::string& message, return std::string(&frame[0], frame.size()); } -WebSocket::WebSocket(HttpConnection* connection) : connection_(connection) { +WebSocket::WebSocket(HttpServer* server, HttpConnection* connection) + : server_(server), + connection_(connection) { } } // namespace net diff --git a/net/server/web_socket.h b/net/server/web_socket.h index 49ced84..9b3a794 100644 --- a/net/server/web_socket.h +++ b/net/server/web_socket.h @@ -8,10 +8,12 @@ #include <string> #include "base/basictypes.h" +#include "base/strings/string_piece.h" namespace net { class HttpConnection; +class HttpServer; class HttpServerRequestInfo; class WebSocket { @@ -23,11 +25,12 @@ class WebSocket { FRAME_ERROR }; - static WebSocket* CreateWebSocket(HttpConnection* connection, + static WebSocket* CreateWebSocket(HttpServer* server, + HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos); - static ParseResult DecodeFrameHybi17(const std::string& frame, + static ParseResult DecodeFrameHybi17(const base::StringPiece& frame, bool client_frame, int* bytes_consumed, std::string* output); @@ -41,8 +44,10 @@ class WebSocket { virtual ~WebSocket() {} protected: - explicit WebSocket(HttpConnection* connection); - HttpConnection* connection_; + WebSocket(HttpServer* server, HttpConnection* connection); + + HttpServer* const server_; + HttpConnection* const connection_; private: DISALLOW_COPY_AND_ASSIGN(WebSocket); diff --git a/net/socket/server_socket.h b/net/socket/server_socket.h index 528955b..4b9ca8e 100644 --- a/net/socket/server_socket.h +++ b/net/socket/server_socket.h @@ -21,7 +21,7 @@ class NET_EXPORT ServerSocket { ServerSocket(); virtual ~ServerSocket(); - // Binds the socket and starts listening. Destroy the socket to stop + // Binds the socket and starts listening. Destroys the socket to stop // listening. virtual int Listen(const IPEndPoint& address, int backlog) = 0; |