summaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
authordilmah@chromium.org <dilmah@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2011-07-19 20:33:43 +0000
committerdilmah@chromium.org <dilmah@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2011-07-19 20:33:43 +0000
commit38dc684dae9a03cd0a1642581d618e39f813013c (patch)
treeb00197013d79cfebe690d77d5dfb46a55c123a3f /net
parent3b9448e6acb5fd762afc0c3b3957a65e9d3ef092 (diff)
downloadchromium_src-38dc684dae9a03cd0a1642581d618e39f813013c.zip
chromium_src-38dc684dae9a03cd0a1642581d618e39f813013c.tar.gz
chromium_src-38dc684dae9a03cd0a1642581d618e39f813013c.tar.bz2
Implementation of server socket for websockets protocol.
It is prerequisite to refactoring websocket-to-tcp proxy (existing on ChromeOS). BUG=chromium-os:15533 TEST=unittest Committed: http://src.chromium.org/viewvc/chrome?view=rev&revision=93010 Review URL: http://codereview.chromium.org/7131009 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@93089 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
-rw-r--r--net/base/io_buffer.h2
-rw-r--r--net/base/net_error_list.h9
-rw-r--r--net/net.gyp3
-rw-r--r--net/socket/nss_ssl_util.h2
-rw-r--r--net/socket/socket.h12
-rw-r--r--net/socket/web_socket_server_socket.cc905
-rw-r--r--net/socket/web_socket_server_socket.h63
-rw-r--r--net/socket/web_socket_server_socket_unittest.cc596
8 files changed, 1581 insertions, 11 deletions
diff --git a/net/base/io_buffer.h b/net/base/io_buffer.h
index 1ec7869..e117cce 100644
--- a/net/base/io_buffer.h
+++ b/net/base/io_buffer.h
@@ -128,7 +128,7 @@ class NET_API PickledIOBuffer : public IOBuffer {
Pickle* pickle() { return &pickle_; }
- // Signals that we are done writing to the picke and we can use it for a
+ // Signals that we are done writing to the pickle and we can use it for a
// write-style IO operation.
void Done();
diff --git a/net/base/net_error_list.h b/net/base/net_error_list.h
index cff5a79..9596a9a 100644
--- a/net/base/net_error_list.h
+++ b/net/base/net_error_list.h
@@ -239,6 +239,12 @@ NET_ERROR(MSG_TOO_BIG, -142)
// See also: ESET_ANTI_VIRUS_SSL_INTERCEPTION
NET_ERROR(KASPERSKY_ANTI_VIRUS_SSL_INTERCEPTION, -143)
+// Violation of limits (e.g. imposed to prevent DoS).
+NET_ERROR(LIMIT_VIOLATION, -144)
+
+// WebSocket protocol error occurred.
+NET_ERROR(WS_PROTOCOL_ERROR, -145)
+
// Connection was aborted for switching to another ptotocol.
// WebSocket abort SocketStream connection when alternate protocol is found.
NET_ERROR(PROTOCOL_SWITCHED, -146)
@@ -246,9 +252,6 @@ NET_ERROR(PROTOCOL_SWITCHED, -146)
// Returned when attempting to bind an address that is already in use.
NET_ERROR(ADDRESS_IN_USE, -147)
-// NOTE: error codes 144-145 are available, please use those before adding
-// 148.
-
// Certificate error codes
//
// The values of certificate error codes must be consecutive.
diff --git a/net/net.gyp b/net/net.gyp
index f9e62fa..724e24b 100644
--- a/net/net.gyp
+++ b/net/net.gyp
@@ -544,6 +544,8 @@
'socket/tcp_server_socket_win.h',
'socket/transport_client_socket_pool.cc',
'socket/transport_client_socket_pool.h',
+ 'socket/web_socket_server_socket.cc',
+ 'socket/web_socket_server_socket.h',
'socket_stream/socket_stream.cc',
'socket_stream/socket_stream.h',
'socket_stream/socket_stream_job.cc',
@@ -1005,6 +1007,7 @@
'socket/tcp_server_socket_unittest.cc',
'socket/transport_client_socket_pool_unittest.cc',
'socket/transport_client_socket_unittest.cc',
+ 'socket/web_socket_server_socket_unittest.cc',
'socket_stream/socket_stream_metrics_unittest.cc',
'socket_stream/socket_stream_unittest.cc',
'spdy/spdy_framer_test.cc',
diff --git a/net/socket/nss_ssl_util.h b/net/socket/nss_ssl_util.h
index 2a53fc7..614ab5f 100644
--- a/net/socket/nss_ssl_util.h
+++ b/net/socket/nss_ssl_util.h
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-// This file is only inclued in ssl_client_socket_nss.cc and
+// This file is only included in ssl_client_socket_nss.cc and
// ssl_server_socket_nss.cc to share common functions of NSS.
#ifndef NET_SOCKET_NSS_SSL_UTIL_H_
diff --git a/net/socket/socket.h b/net/socket/socket.h
index 17d9f12..450f42d 100644
--- a/net/socket/socket.h
+++ b/net/socket/socket.h
@@ -18,8 +18,8 @@ class NET_API Socket {
public:
virtual ~Socket() {}
- // Reads data, up to buf_len bytes, from the socket. The number of bytes read
- // is returned, or an error is returned upon failure.
+ // Reads data, up to |buf_len| bytes, from the socket. The number of bytes
+ // read is returned, or an error is returned upon failure.
// ERR_SOCKET_NOT_CONNECTED should be returned if the socket is not currently
// connected. Zero is returned once to indicate end-of-file; the return value
// of subsequent calls is undefined, and may be OS dependent. ERR_IO_PENDING
@@ -32,8 +32,8 @@ class NET_API Socket {
virtual int Read(IOBuffer* buf, int buf_len,
CompletionCallback* callback) = 0;
- // Writes data, up to buf_len bytes, to the socket. Note: only part of the
- // data may be written! The number of bytes written is returned, or an error
+ // Writes data, up to |buf_len| bytes, to the socket. Note: data may be
+ // written partially. The number of bytes written is returned, or an error
// is returned upon failure. ERR_SOCKET_NOT_CONNECTED should be returned if
// the socket is not currently connected. The return value when the
// connection is closed is undefined, and may be OS dependent. ERR_IO_PENDING
@@ -48,12 +48,12 @@ class NET_API Socket {
CompletionCallback* callback) = 0;
// Set the receive buffer size (in bytes) for the socket.
- // Note: changing this value can effect the TCP window size on some platforms.
+ // Note: changing this value can affect the TCP window size on some platforms.
// Returns true on success, or false on failure.
virtual bool SetReceiveBufferSize(int32 size) = 0;
// Set the send buffer size (in bytes) for the socket.
- // Note: changing this value can effect the TCP window size on some platforms.
+ // Note: changing this value can affect the TCP window size on some platforms.
// Returns true on success, or false on failure.
virtual bool SetSendBufferSize(int32 size) = 0;
};
diff --git a/net/socket/web_socket_server_socket.cc b/net/socket/web_socket_server_socket.cc
new file mode 100644
index 0000000..a793097
--- /dev/null
+++ b/net/socket/web_socket_server_socket.cc
@@ -0,0 +1,905 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/web_socket_server_socket.h"
+
+#include <algorithm>
+#include <deque>
+#include <limits>
+#include <map>
+#include <vector>
+
+#if defined(OS_WIN)
+#include <winsock2.h> // for htonl
+#else
+#include <arpa/inet.h>
+#endif
+
+#include "base/basictypes.h"
+#include "base/logging.h"
+#include "base/md5.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/message_loop.h"
+#include "base/string_util.h"
+#include "base/task.h"
+#include "googleurl/src/gurl.h"
+#include "net/base/completion_callback.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_errors.h"
+
+namespace {
+
+const size_t kHandshakeLimitBytes = 1 << 14;
+
+const char kCrOctet = '\r';
+COMPILE_ASSERT(kCrOctet == '\x0d', ASCII);
+const char kLfOctet = '\n';
+COMPILE_ASSERT(kLfOctet == '\x0a', ASCII);
+const char kSpaceOctet = ' ';
+COMPILE_ASSERT(kSpaceOctet == '\x20', ASCII);
+const char kCommaOctet = ',';
+COMPILE_ASSERT(kCommaOctet == '\x2c', ASCII);
+
+const char kCRLF[] = { kCrOctet, kLfOctet, 0 };
+const char kCRLFCRLF[] = { kCrOctet, kLfOctet, kCrOctet, kLfOctet, 0 };
+
+const char kPlainHostFieldName[] = "Host";
+const char kPlainOriginFieldName[] = "Origin";
+const char kOriginFieldName[] = "Sec-WebSocket-Origin";
+const char kProtocolFieldName[] = "Sec-WebSocket-Protocol";
+const char kVersionFieldName[] = "Sec-WebSocket-Version";
+const char kLocationFieldName[] = "Sec-WebSocket-Location";
+const char kKey1FieldName[] = "Sec-WebSocket-Key1";
+const char kKey2FieldName[] = "Sec-WebSocket-Key2";
+
+int CountSpaces(const std::string& s) {
+ return std::count(s.begin(), s.end(), kSpaceOctet);
+}
+
+// Returns true on success.
+bool FetchDecimalDigits(const std::string& s, uint32* result) {
+ *result = 0;
+ bool got_something = false;
+ for (size_t i = 0; i < s.size(); ++i) {
+ if (IsAsciiDigit(s[i])) {
+ got_something = true;
+ if (*result > std::numeric_limits<uint32>::max() / 10)
+ return false;
+ *result *= 10;
+ int digit = s[i] - '0';
+ if (*result > std::numeric_limits<uint32>::max() - digit)
+ return false;
+ *result += digit;
+ }
+ }
+ return got_something;
+}
+
+// Returns number of fetched subprotocols or negative error code.
+int FetchSubprotocolList(
+ const std::string& s, std::vector<std::string>* subprotocol_list) {
+ subprotocol_list->clear();
+ subprotocol_list->push_back(std::string());
+ for (size_t i = 0; i < s.size(); ++i) {
+ if (s[i] > '\x20' && s[i] < '\x7f' && s[i] != kCommaOctet)
+ subprotocol_list->back() += s[i];
+ else if (!subprotocol_list->back().empty()) {
+ if (subprotocol_list->size() < 16)
+ subprotocol_list->push_back(std::string());
+ else
+ return net::ERR_LIMIT_VIOLATION;
+ }
+ }
+ if (subprotocol_list->back().empty())
+ subprotocol_list->pop_back();
+ if (subprotocol_list->empty())
+ return net::ERR_WS_PROTOCOL_ERROR;
+
+ {
+ std::vector<std::string> tmp(*subprotocol_list);
+ std::sort(tmp.begin(), tmp.end());
+ if (tmp.end() != std::unique(tmp.begin(), tmp.end()))
+ return net::ERR_WS_PROTOCOL_ERROR;
+ }
+ return subprotocol_list->size();
+}
+
+class WebSocketServerSocketImpl : public net::WebSocketServerSocket {
+ public:
+ WebSocketServerSocketImpl(net::Socket* transport_socket, Delegate* delegate)
+ : phase_(PHASE_NYMPH),
+ frame_bytes_remaining_(0),
+ transport_socket_(transport_socket),
+ delegate_(delegate),
+ handshake_buf_(new net::IOBuffer(kHandshakeLimitBytes)),
+ fill_handshake_buf_(new net::DrainableIOBuffer(
+ handshake_buf_, kHandshakeLimitBytes)),
+ process_handshake_buf_(new net::DrainableIOBuffer(
+ handshake_buf_, kHandshakeLimitBytes)),
+ transport_read_callback_(NewCallback(
+ this, &WebSocketServerSocketImpl::OnRead)),
+ transport_write_callback_(NewCallback(
+ this, &WebSocketServerSocketImpl::OnWrite)),
+ is_transport_read_pending_(false),
+ is_transport_write_pending_(false),
+ method_factory_(this) {
+ DCHECK(transport_socket);
+ DCHECK(delegate);
+ }
+
+ virtual ~WebSocketServerSocketImpl() {
+ std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ);
+ if (it != pending_reqs_.end() &&
+ it->type == PendingReq::TYPE_READ &&
+ it->io_buf != NULL &&
+ it->io_buf->data() != NULL &&
+ it->callback != 0) {
+ it->callback->Run(0); // Report EOF.
+ }
+ }
+
+ private:
+ enum Phase {
+ // Before Accept() is called.
+ PHASE_NYMPH,
+
+ // After Accept() is called and until handshake success/fail.
+ PHASE_HANDSHAKE,
+
+ // Processing data stream.
+ PHASE_FRAME_OUTSIDE, // Outside data frame.
+ PHASE_FRAME_INSIDE, // Inside text frame.
+ PHASE_FRAME_LENGTH, // Reading length of binary frame.
+ PHASE_FRAME_SKIP, // Skipping binary frame.
+
+ // After termination.
+ PHASE_SHUT
+ };
+
+ struct PendingReq {
+ enum Type {
+ // Frame delimiters or handshake (as opposed to user data).
+ TYPE_METADATA = 1 << 0,
+ // Read request.
+ TYPE_READ = 1 << 1,
+ // Write request.
+ TYPE_WRITE = 1 << 2,
+
+ TYPE_READ_METADATA = TYPE_READ | TYPE_METADATA,
+ TYPE_WRITE_METADATA = TYPE_WRITE | TYPE_METADATA
+ };
+
+ PendingReq(Type type, net::DrainableIOBuffer* io_buf,
+ net::CompletionCallback* callback)
+ : type(type),
+ io_buf(io_buf),
+ callback(callback) {
+ switch (type) {
+ case PendingReq::TYPE_READ:
+ case PendingReq::TYPE_WRITE:
+ case PendingReq::TYPE_READ_METADATA:
+ case PendingReq::TYPE_WRITE_METADATA: {
+ DCHECK(io_buf);
+ break;
+ }
+ default: {
+ NOTREACHED();
+ break;
+ }
+ }
+ }
+
+ Type type;
+ scoped_refptr<net::DrainableIOBuffer> io_buf;
+ net::CompletionCallback* callback;
+ };
+
+ // Socket implementation.
+ virtual int Read(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback) OVERRIDE {
+ if (buf_len == 0)
+ return 0;
+ if (buf == NULL || buf_len < 0) {
+ NOTREACHED();
+ return net::ERR_INVALID_ARGUMENT;
+ }
+ while (int bytes_remaining = fill_handshake_buf_->BytesConsumed() -
+ process_handshake_buf_->BytesConsumed()) {
+ DCHECK(!is_transport_read_pending_);
+ DCHECK(GetPendingReq(PendingReq::TYPE_READ) == pending_reqs_.end());
+ switch (phase_) {
+ case PHASE_FRAME_OUTSIDE:
+ case PHASE_FRAME_INSIDE:
+ case PHASE_FRAME_LENGTH:
+ case PHASE_FRAME_SKIP: {
+ int n = std::min(bytes_remaining, buf_len);
+ int rv = ProcessDataFrames(
+ process_handshake_buf_->data(), n, buf->data(), buf_len);
+ process_handshake_buf_->DidConsume(n);
+ if (rv == 0) {
+ // ProcessDataFrames may return zero for non-empty buffer if it
+ // contains only frame delimiters without real data. In this case:
+ // try again and do not just return zero (zero stands for EOF).
+ continue;
+ }
+ return rv;
+ }
+ case PHASE_SHUT: {
+ return 0;
+ }
+ case PHASE_NYMPH:
+ case PHASE_HANDSHAKE:
+ default: {
+ NOTREACHED();
+ return net::ERR_UNEXPECTED;
+ }
+ }
+ }
+ switch (phase_) {
+ case PHASE_FRAME_OUTSIDE:
+ case PHASE_FRAME_INSIDE:
+ case PHASE_FRAME_LENGTH:
+ case PHASE_FRAME_SKIP: {
+ pending_reqs_.push_back(PendingReq(
+ PendingReq::TYPE_READ,
+ new net::DrainableIOBuffer(buf, buf_len),
+ callback));
+ ConsiderTransportRead();
+ break;
+ }
+ case PHASE_SHUT: {
+ return 0;
+ }
+ case PHASE_NYMPH:
+ case PHASE_HANDSHAKE:
+ default: {
+ NOTREACHED();
+ return net::ERR_UNEXPECTED;
+ }
+ }
+ return net::ERR_IO_PENDING;
+ }
+
+ virtual int Write(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback) OVERRIDE {
+ if (buf_len == 0)
+ return 0;
+ if (buf == NULL || buf_len < 0) {
+ NOTREACHED();
+ return net::ERR_INVALID_ARGUMENT;
+ }
+ DCHECK_EQ(std::find(buf->data(), buf->data() + buf_len, '\xff'),
+ buf->data() + buf_len);
+ switch (phase_) {
+ case PHASE_FRAME_OUTSIDE:
+ case PHASE_FRAME_INSIDE:
+ case PHASE_FRAME_LENGTH:
+ case PHASE_FRAME_SKIP: {
+ break;
+ }
+ case PHASE_SHUT: {
+ return net::ERR_SOCKET_NOT_CONNECTED;
+ }
+ case PHASE_NYMPH:
+ case PHASE_HANDSHAKE:
+ default: {
+ NOTREACHED();
+ return net::ERR_UNEXPECTED;
+ }
+ }
+
+ net::IOBuffer* frame_start = new net::IOBuffer(1);
+ frame_start->data()[0] = '\x00';
+ pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA,
+ new net::DrainableIOBuffer(frame_start, 1),
+ NULL));
+
+ pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE,
+ new net::DrainableIOBuffer(buf, buf_len),
+ callback));
+
+ net::IOBuffer* frame_end = new net::IOBuffer(1);
+ frame_end->data()[0] = '\xff';
+ pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA,
+ new net::DrainableIOBuffer(frame_end, 1),
+ NULL));
+
+ ConsiderTransportWrite();
+ return net::ERR_IO_PENDING;
+ }
+
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
+ return transport_socket_->SetReceiveBufferSize(size);
+ }
+
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE {
+ return transport_socket_->SetSendBufferSize(size);
+ }
+
+ // WebSocketServerSocket implementation.
+ virtual int Accept(net::CompletionCallback* callback) {
+ if (phase_ != PHASE_NYMPH)
+ return net::ERR_UNEXPECTED;
+ phase_ = PHASE_HANDSHAKE;
+ pending_reqs_.push_front(PendingReq(
+ PendingReq::TYPE_READ_METADATA, fill_handshake_buf_.get(), callback));
+ ConsiderTransportRead();
+ return net::ERR_IO_PENDING;
+ }
+
+ std::deque<PendingReq>::iterator GetPendingReq(PendingReq::Type type) {
+ for (std::deque<PendingReq>::iterator it = pending_reqs_.begin();
+ it != pending_reqs_.end(); ++it) {
+ if (it->type & type)
+ return it;
+ }
+ return pending_reqs_.end();
+ }
+
+ void ConsiderTransportRead() {
+ if (pending_reqs_.empty())
+ return;
+ if (is_transport_read_pending_)
+ return;
+ std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ);
+ if (it == pending_reqs_.end())
+ return;
+ if (it->io_buf == NULL || it->io_buf->BytesRemaining() == 0) {
+ NOTREACHED();
+ return;
+ }
+ is_transport_read_pending_ = true;
+ int rv = transport_socket_->Read(
+ it->io_buf.get(), it->io_buf->BytesRemaining(),
+ transport_read_callback_.get());
+ if (rv != net::ERR_IO_PENDING) {
+ // PostTask rather than direct call in order to:
+ // (1) guarantee calling callback after returning from Read();
+ // (2) avoid potential stack overflow;
+ MessageLoop::current()->PostTask(FROM_HERE,
+ method_factory_.NewRunnableMethod(
+ &WebSocketServerSocketImpl::OnRead, rv));
+ }
+ }
+
+ void ConsiderTransportWrite() {
+ if (is_transport_write_pending_)
+ return;
+ if (pending_reqs_.empty())
+ return;
+ std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_WRITE);
+ if (it == pending_reqs_.end())
+ return;
+ if (it->io_buf == NULL || it->io_buf->BytesRemaining() == 0) {
+ NOTREACHED();
+ Shut(net::ERR_UNEXPECTED);
+ return;
+ }
+ is_transport_write_pending_ = true;
+ int rv = transport_socket_->Write(
+ it->io_buf.get(), it->io_buf->BytesRemaining(),
+ transport_write_callback_.get());
+ if (rv != net::ERR_IO_PENDING) {
+ // PostTask rather than direct call in order to:
+ // (1) guarantee calling callback after returning from Read();
+ // (2) avoid potential stack overflow;
+ MessageLoop::current()->PostTask(FROM_HERE,
+ method_factory_.NewRunnableMethod(
+ &WebSocketServerSocketImpl::OnWrite, rv));
+ }
+ }
+
+ void Shut(int result) {
+ if (result > 0 || result == net::ERR_IO_PENDING)
+ result = net::ERR_UNEXPECTED;
+ if (result != 0) {
+ while (!pending_reqs_.empty()) {
+ PendingReq& req = pending_reqs_.front();
+ if (req.callback)
+ req.callback->Run(result);
+ pending_reqs_.pop_front();
+ }
+ transport_socket_.reset(); // terminate underlying connection.
+ }
+ phase_ = PHASE_SHUT;
+ }
+
+ // Callbacks for transport socket.
+ void OnRead(int result) {
+ if (!is_transport_read_pending_) {
+ NOTREACHED();
+ Shut(net::ERR_UNEXPECTED);
+ return;
+ }
+ is_transport_read_pending_ = false;
+
+ if (result <= 0) {
+ Shut(result);
+ return;
+ }
+
+ std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ);
+ if (it == pending_reqs_.end() ||
+ it->io_buf == NULL ||
+ it->io_buf->data() == NULL) {
+ NOTREACHED();
+ Shut(net::ERR_UNEXPECTED);
+ return;
+ }
+ if ((phase_ == PHASE_HANDSHAKE) == (it->type == PendingReq::TYPE_READ)) {
+ NOTREACHED();
+ Shut(net::ERR_UNEXPECTED);
+ return;
+ }
+
+ switch (phase_) {
+ case PHASE_HANDSHAKE: {
+ if (it != pending_reqs_.begin() || it->io_buf != fill_handshake_buf_) {
+ NOTREACHED();
+ Shut(net::ERR_UNEXPECTED);
+ return;
+ }
+ fill_handshake_buf_->DidConsume(result);
+ // ProcessHandshake invalidates iterators for |pending_reqs_|
+ int rv = ProcessHandshake();
+ if (rv > 0) {
+ process_handshake_buf_->DidConsume(rv);
+ phase_ = PHASE_FRAME_OUTSIDE;
+ net::CompletionCallback* cb = pending_reqs_.front().callback;
+ pending_reqs_.pop_front();
+ ConsiderTransportWrite(); // Schedule answer handshake.
+ if (cb)
+ cb->Run(0);
+ } else if (rv == net::ERR_IO_PENDING) {
+ if (fill_handshake_buf_->BytesRemaining() < 1)
+ Shut(net::ERR_LIMIT_VIOLATION);
+ } else if (rv < 0) {
+ Shut(rv);
+ } else {
+ Shut(net::ERR_UNEXPECTED);
+ }
+ break;
+ }
+ case PHASE_FRAME_OUTSIDE:
+ case PHASE_FRAME_INSIDE:
+ case PHASE_FRAME_LENGTH:
+ case PHASE_FRAME_SKIP: {
+ int rv = ProcessDataFrames(
+ it->io_buf->data(), result,
+ it->io_buf->data(), it->io_buf->BytesRemaining());
+ if (rv < 0) {
+ Shut(rv);
+ return;
+ }
+ if (rv > 0 || phase_ == PHASE_SHUT) {
+ net::CompletionCallback* cb = it->callback;
+ pending_reqs_.erase(it);
+ if (cb)
+ cb->Run(rv);
+ }
+ break;
+ }
+ case PHASE_NYMPH:
+ default: {
+ NOTREACHED();
+ Shut(net::ERR_UNEXPECTED);
+ break;
+ }
+ }
+ ConsiderTransportRead();
+ }
+
+ void OnWrite(int result) {
+ if (!is_transport_write_pending_) {
+ NOTREACHED();
+ Shut(net::ERR_UNEXPECTED);
+ return;
+ }
+ is_transport_write_pending_ = false;
+
+ if (result < 0) {
+ Shut(result);
+ return;
+ }
+
+ std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_WRITE);
+ if (it == pending_reqs_.end() ||
+ it->io_buf == NULL ||
+ it->io_buf->data() == NULL) {
+ NOTREACHED();
+ Shut(net::ERR_UNEXPECTED);
+ return;
+ }
+ DCHECK_LE(result, it->io_buf->BytesRemaining());
+ it->io_buf->DidConsume(result);
+ if (it->io_buf->BytesRemaining() == 0) {
+ net::CompletionCallback* cb = it->callback;
+ int bytes_written = it->io_buf->BytesConsumed();
+ DCHECK_GT(bytes_written, 0);
+ pending_reqs_.erase(it);
+ if (cb)
+ cb->Run(bytes_written);
+ }
+ ConsiderTransportWrite();
+ }
+
+ // Returns (positive) number of consumed bytes on success.
+ // Returns ERR_IO_PENDING in case of incomplete input.
+ // Returns ERR_WS_PROTOCOL_ERROR or ERR_LIMIT_VIOLATION in case of failure to
+ // reasonably parse input.
+ int ProcessHandshake() {
+ static const char kGetPrefix[] = "GET ";
+ static const char kKeyValueDelimiter[] = ": ";
+
+ class Fields {
+ public:
+ bool Has(const std::string& name) {
+ return map_.find(StringToLowerASCII(name)) != map_.end();
+ }
+
+ std::string Get(const std::string& name) {
+ return Has(name) ? map_[StringToLowerASCII(name)] : std::string();
+ }
+
+ void Set(const std::string& name, const std::string& value) {
+ map_[StringToLowerASCII(name)] = StringToLowerASCII(value);
+ }
+
+ private:
+ std::map<std::string, std::string> map_;
+ } fields;
+
+ char* buf = process_handshake_buf_->data();
+ size_t buf_size = fill_handshake_buf_->BytesConsumed();
+
+ if (buf_size < 1)
+ return net::ERR_IO_PENDING;
+ if (!std::equal(buf, buf + std::min(buf_size, strlen(kGetPrefix)),
+ kGetPrefix)) {
+ // Data head does not match what is expected.
+ return net::ERR_WS_PROTOCOL_ERROR;
+ }
+ if (buf_size >= kHandshakeLimitBytes)
+ return net::ERR_LIMIT_VIOLATION;
+ char* buf_end = buf + buf_size;
+
+ if (buf_size < strlen(kGetPrefix))
+ return net::ERR_IO_PENDING;
+ char* resource_begin = buf + strlen(kGetPrefix);
+ char* resource_end = std::find(resource_begin, buf_end, kSpaceOctet);
+ if (resource_end == buf_end)
+ return net::ERR_IO_PENDING;
+ std::string resource(resource_begin, resource_end);
+ if (!IsStringUTF8(resource) ||
+ resource.find_first_of(kCRLF) != std::string::npos) {
+ return net::ERR_WS_PROTOCOL_ERROR;
+ }
+ char* term_pos = std::search(
+ buf, buf_end, kCRLFCRLF, kCRLFCRLF + strlen(kCRLFCRLF));
+ char key3[8]; // Notation (key3) matches websocket RFC.
+ size_t message_len = buf_end - term_pos;
+ if (message_len < sizeof(key3) + strlen(kCRLFCRLF))
+ return net::ERR_IO_PENDING;
+ term_pos += strlen(kCRLFCRLF);
+ memcpy(key3, term_pos, sizeof(key3));
+ term_pos += sizeof(key3);
+ // First line is "GET resource" line, so skip it.
+ char* pos = std::search(buf, term_pos, kCRLF, kCRLF + strlen(kCRLF));
+ if (pos == term_pos)
+ return net::ERR_WS_PROTOCOL_ERROR;
+ for (;;) {
+ pos += strlen(kCRLF);
+ if (term_pos - pos <
+ static_cast<ptrdiff_t>(sizeof(key3) + strlen(kCRLF))) {
+ return net::ERR_WS_PROTOCOL_ERROR;
+ }
+ if (term_pos - pos ==
+ static_cast<ptrdiff_t>(sizeof(key3) + strlen(kCRLF))) {
+ break;
+ }
+ char* next_pos = std::search(
+ pos, term_pos, kKeyValueDelimiter,
+ kKeyValueDelimiter + strlen(kKeyValueDelimiter));
+ if (next_pos == term_pos)
+ return net::ERR_WS_PROTOCOL_ERROR;
+ std::string key(pos, next_pos);
+ if (!IsStringASCII(key) ||
+ key.find_first_of(kCRLF) != std::string::npos) {
+ return net::ERR_WS_PROTOCOL_ERROR;
+ }
+ pos = std::search(next_pos += strlen(kKeyValueDelimiter), term_pos,
+ kCRLF, kCRLF + strlen(kCRLF));
+ if (pos == term_pos)
+ return net::ERR_WS_PROTOCOL_ERROR;
+ if (!key.empty()) {
+ std::string value(next_pos, pos);
+ if (!IsStringASCII(value) ||
+ value.find_first_of(kCRLF) != std::string::npos) {
+ return net::ERR_WS_PROTOCOL_ERROR;
+ }
+ fields.Set(key, value);
+ }
+ }
+
+ // Values of Upgrade and Connection fields are hardcoded in the protocol.
+ if (fields.Get("Upgrade") != "websocket" ||
+ fields.Get("Connection") != "upgrade") {
+ return net::ERR_WS_PROTOCOL_ERROR;
+ }
+ if (fields.Has(kVersionFieldName)) {
+ NOTIMPLEMENTED(); // new protocol.
+ return net::ERR_NOT_IMPLEMENTED;
+ }
+
+ if (!fields.Has(kPlainOriginFieldName))
+ return net::ERR_CONNECTION_REFUSED;
+ // Normalize (e.g. w.r.t. leading slashes) origin.
+ GURL origin = GURL(fields.Get(kPlainOriginFieldName)).GetOrigin();
+ if (!origin.is_valid())
+ return net::ERR_WS_PROTOCOL_ERROR;
+ std::string normalized_origin = origin.spec();
+
+ if (!fields.Has(kPlainHostFieldName))
+ return net::ERR_CONNECTION_REFUSED;
+
+ std::vector<std::string> subprotocol_list;
+ if (fields.Has(kProtocolFieldName)) {
+ int rv = FetchSubprotocolList(
+ fields.Get(kProtocolFieldName), &subprotocol_list);
+ if (rv < 0)
+ return rv;
+ DCHECK(subprotocol_list.end() == std::find(
+ subprotocol_list.begin(), subprotocol_list.end(), ""));
+ }
+
+ std::string location;
+ std::string subprotocol;
+ if (!delegate_->ValidateWebSocket(resource,
+ normalized_origin,
+ fields.Get(kPlainHostFieldName),
+ subprotocol_list,
+ &location,
+ &subprotocol)) {
+ return net::ERR_CONNECTION_REFUSED;
+ }
+ if (subprotocol_list.empty()) {
+ DCHECK(subprotocol.empty());
+ } else {
+ if (!subprotocol.empty()) {
+ if (subprotocol_list.end() == std::find(
+ subprotocol_list.begin(), subprotocol_list.end(), subprotocol)) {
+ NOTREACHED() << "delegate must pick subprotocol from given list";
+ return net::ERR_UNEXPECTED;
+ }
+ }
+ }
+
+ uint32 key_number1 = 0;
+ uint32 key_number2 = 0;
+ if (!FetchDecimalDigits(fields.Get(kKey1FieldName), &key_number1) ||
+ !FetchDecimalDigits(fields.Get(kKey2FieldName), &key_number2)) {
+ return net::ERR_WS_PROTOCOL_ERROR;
+ }
+
+ // We limit incoming header size so following numbers shall not be too high.
+ int spaces1 = CountSpaces(fields.Get(kKey1FieldName));
+ int spaces2 = CountSpaces(fields.Get(kKey2FieldName));
+ if (spaces1 == 0 ||
+ spaces2 == 0 ||
+ key_number1 % spaces1 != 0 ||
+ key_number2 % spaces2 != 0) {
+ return net::ERR_WS_PROTOCOL_ERROR;
+ }
+
+ char challenge[4 + 4 + sizeof(key3)];
+ int32 part1 = htonl(key_number1 / spaces1);
+ int32 part2 = htonl(key_number2 / spaces2);
+ memcpy(challenge, &part1, 4);
+ memcpy(challenge + 4, &part2, 4);
+ memcpy(challenge + 4 + 4, key3, sizeof(key3));
+ MD5Digest challenge_response;
+ MD5Sum(challenge, sizeof(challenge), &challenge_response);
+
+ // Concocting response handshake.
+ class Buffer {
+ public:
+ Buffer()
+ : io_buf_(new net::IOBuffer(kHandshakeLimitBytes)),
+ bytes_written_(0),
+ is_ok_(true) {
+ }
+
+ bool Write(const void* p, int len) {
+ DCHECK(p);
+ DCHECK_GE(len, 0);
+ if (!is_ok_)
+ return false;
+ if (bytes_written_ + len > kHandshakeLimitBytes) {
+ NOTREACHED();
+ is_ok_ = false;
+ return false;
+ }
+ memcpy(io_buf_->data() + bytes_written_, p, len);
+ bytes_written_ += len;
+ return true;
+ }
+
+ bool WriteLine(const char* p) {
+ return Write(p, strlen(p)) && Write(kCRLF, strlen(kCRLF));
+ }
+
+ operator net::DrainableIOBuffer*() {
+ return new net::DrainableIOBuffer(io_buf_, bytes_written_);
+ }
+
+ bool is_ok() { return is_ok_; }
+
+ private:
+ net::IOBuffer* io_buf_;
+ size_t bytes_written_;
+ bool is_ok_;
+ } buffer;
+
+ buffer.WriteLine("HTTP/1.1 101 WebSocket Protocol Handshake");
+ buffer.WriteLine("Upgrade: WebSocket");
+ buffer.WriteLine("Connection: Upgrade");
+
+ {
+ // Take care of Location field.
+ char tmp[2048];
+ int rv = base::snprintf(tmp, sizeof(tmp),
+ "%s: %s",
+ kLocationFieldName,
+ location.c_str());
+ if (rv <= 0 || rv + 0u >= sizeof(tmp))
+ return net::ERR_LIMIT_VIOLATION;
+ buffer.WriteLine(tmp);
+ }
+ {
+ // Take care of Origin field.
+ char tmp[2048];
+ int rv = base::snprintf(tmp, sizeof(tmp),
+ "%s: %s",
+ kOriginFieldName,
+ fields.Get(kPlainOriginFieldName).c_str());
+ if (rv <= 0 || rv + 0u >= sizeof(tmp))
+ return net::ERR_LIMIT_VIOLATION;
+ buffer.WriteLine(tmp);
+ }
+ if (!subprotocol.empty()) {
+ char tmp[2048];
+ int rv = base::snprintf(tmp, sizeof(tmp),
+ "%s: %s",
+ kProtocolFieldName,
+ subprotocol.c_str());
+ if (rv <= 0 || rv + 0u >= sizeof(tmp))
+ return net::ERR_LIMIT_VIOLATION;
+ buffer.WriteLine(tmp);
+ }
+ buffer.WriteLine("");
+ buffer.Write(&challenge_response, sizeof(challenge_response));
+
+ if (!buffer.is_ok())
+ return net::ERR_LIMIT_VIOLATION;
+
+ pending_reqs_.push_back(PendingReq(
+ PendingReq::TYPE_WRITE_METADATA, buffer, NULL));
+ DCHECK_GT(term_pos - buf, 0);
+ return term_pos - buf;
+ }
+
+ // Removes frame delimiters and returns net number of data bytes (or error).
+ // |out| may be equal to |buf|, in that case it is in-place operation.
+ int ProcessDataFrames(char* buf, int buf_len, char* out, int out_len) {
+ if (out_len < buf_len) {
+ NOTREACHED();
+ return net::ERR_UNEXPECTED;
+ }
+ int out_pos = 0;
+ for (char* p = buf; p < buf + buf_len; ++p) {
+ switch (phase_) {
+ case PHASE_FRAME_INSIDE: {
+ if (*p == '\x00')
+ return net::ERR_WS_PROTOCOL_ERROR;
+ if (*p == '\xff')
+ phase_ = PHASE_FRAME_OUTSIDE;
+ else
+ out[out_pos++] = *p;
+ break;
+ }
+ case PHASE_FRAME_OUTSIDE: {
+ if (*p == '\x00') {
+ phase_ = PHASE_FRAME_INSIDE;
+ } else if (*p == '\xff') {
+ phase_ = PHASE_FRAME_LENGTH;
+ frame_bytes_remaining_ = 0;
+ }
+ else {
+ return net::ERR_WS_PROTOCOL_ERROR;
+ }
+ break;
+ }
+ case PHASE_FRAME_LENGTH: {
+ static const int kValueBits = 7;
+ static const char kValueMask = (1 << kValueBits) - 1;
+ frame_bytes_remaining_ <<= kValueBits;
+ frame_bytes_remaining_ += (*p & kValueMask);
+ if (*p & ~kValueMask) {
+ // Check that next byte would not overflow.
+ if (frame_bytes_remaining_ >
+ (std::numeric_limits<int>::max() - ((1 << 7) - 1)) >> 7) {
+ return net::ERR_LIMIT_VIOLATION;
+ }
+ } else {
+ if (frame_bytes_remaining_ == 0) {
+ phase_ = PHASE_SHUT;
+ return out_pos;
+ } else {
+ phase_ = PHASE_FRAME_SKIP;
+ }
+ }
+ break;
+ }
+ case PHASE_FRAME_SKIP: {
+ DCHECK_GE(frame_bytes_remaining_, 1);
+ frame_bytes_remaining_ -= 1;
+ if (frame_bytes_remaining_ < 1)
+ phase_ = PHASE_FRAME_OUTSIDE;
+ break;
+ }
+ default: {
+ NOTREACHED();
+ }
+ }
+ }
+ return out_pos;
+ }
+
+ // State machinery.
+ Phase phase_;
+
+ // Counts frame length for PHASE_FRAME_LENGTH and PHASE_FRAME_SKIP.
+ int frame_bytes_remaining_;
+
+ // Underlying socket.
+ scoped_ptr<net::Socket> transport_socket_;
+
+ // Validation is performed via delegate.
+ Delegate* delegate_;
+
+ // IOBuffer used to communicate with transport at initial stage.
+ scoped_refptr<net::IOBuffer> handshake_buf_;
+ scoped_refptr<net::DrainableIOBuffer> fill_handshake_buf_;
+ scoped_refptr<net::DrainableIOBuffer> process_handshake_buf_;
+
+ // Pending io requests we need to complete.
+ std::deque<PendingReq> pending_reqs_;
+
+ // Callbacks from transport to us.
+ scoped_ptr<net::CompletionCallback> transport_read_callback_;
+ scoped_ptr<net::CompletionCallback> transport_write_callback_;
+
+ // Whether transport requests are pending.
+ bool is_transport_read_pending_;
+ bool is_transport_write_pending_;
+
+ ScopedRunnableMethodFactory<WebSocketServerSocketImpl> method_factory_;
+
+ DISALLOW_COPY_AND_ASSIGN(WebSocketServerSocketImpl);
+};
+
+} // namespace
+
+namespace net {
+
+WebSocketServerSocket* CreateWebSocketServerSocket(
+ Socket* transport_socket, WebSocketServerSocket::Delegate* delegate) {
+ return new WebSocketServerSocketImpl(transport_socket, delegate);
+}
+
+WebSocketServerSocket::~WebSocketServerSocket() {
+}
+
+} // namespace net;
diff --git a/net/socket/web_socket_server_socket.h b/net/socket/web_socket_server_socket.h
new file mode 100644
index 0000000..c28c58d
--- /dev/null
+++ b/net/socket/web_socket_server_socket.h
@@ -0,0 +1,63 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_WEB_SOCKET_SERVER_SOCKET_H_
+#define NET_SOCKET_WEB_SOCKET_SERVER_SOCKET_H_
+
+#include <string>
+#include <vector>
+
+#include "net/base/net_api.h"
+#include "net/socket/socket.h"
+
+namespace net {
+
+// WebSocketServerSocket takes an (already connected) underlying transport
+// socket and speaks server-side websocket protocol atop of it.
+// This class implements Socket interface: notice that Read() returns (or calls
+// back) as soon as some amount of data is available, even if message/frame is
+// incomplete.
+class WebSocketServerSocket : public Socket {
+ public:
+ class Delegate {
+ public:
+ // Validates websocket handshake: return false to reject handshake.
+ // |resource| is name of resource requested in GET stanza of handshake;
+ // |origin| is origin as reported in handshake;
+ // |host| is Host field from handshake;
+ // |subprotocol_list| is derived from Sec-WebSocket-Protocol field.
+ // Output parameters are:
+ // |location_out| is location of websocket server;
+ // |subprotocol_out| is selected subprotocol (or empty string if subprotocol
+ // list is empty.
+ virtual bool ValidateWebSocket(
+ const std::string& resource,
+ const std::string& origin,
+ const std::string& host,
+ const std::vector<std::string>& subprotocol_list,
+ std::string* location_out,
+ std::string* subprotocol_out) = 0;
+
+ virtual ~Delegate() {}
+ };
+
+ virtual ~WebSocketServerSocket();
+
+ // Performs websocket server handshake on transport socket. Underlying socket
+ // must have already been connected/accepted.
+ //
+ // Returns either ERR_IO_PENDING, in which case the given callback will be
+ // called in the future with the real result, or it completes synchronously,
+ // returning the result immediately.
+ virtual int Accept(CompletionCallback* callback) = 0;
+};
+
+// Creates websocket server socket atop of already connected socket. This
+// created server socket will take ownership of |transport_socket|.
+NET_API WebSocketServerSocket* CreateWebSocketServerSocket(
+ Socket* transport_socket, WebSocketServerSocket::Delegate* delegate);
+
+} // namespace net
+
+#endif // NET_SOCKET_WEB_SOCKET_SERVER_SOCKET_H_
diff --git a/net/socket/web_socket_server_socket_unittest.cc b/net/socket/web_socket_server_socket_unittest.cc
new file mode 100644
index 0000000..20e3dcf
--- /dev/null
+++ b/net/socket/web_socket_server_socket_unittest.cc
@@ -0,0 +1,596 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/web_socket_server_socket.h"
+
+#include <stdlib.h>
+#include <algorithm>
+
+#include "base/callback_old.h"
+#include "base/memory/ref_counted.h"
+#include "base/message_loop.h"
+#include "base/string_util.h"
+#include "base/task.h"
+#include "base/time.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_errors.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace {
+
+const char* kSampleHandshakeRequest[] = {
+ "GET /demo HTTP/1.1",
+ "Upgrade: WebSocket",
+ "Connection: Upgrade",
+ "Host: example.com",
+ "Origin: http://example.com",
+ "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5",
+ "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00",
+ "",
+ "^n:ds[4U"
+};
+
+const char kSampleHandshakeAnswer[] =
+ "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Location: ws://example.com/demo\r\n"
+ "Sec-WebSocket-Origin: http://example.com\r\n"
+ "\r\n"
+ "8jKS'y:G*Co,Wxa-";
+
+const int kHandshakeBufBytes = 1 << 12;
+
+const char kCRLF[] = "\r\n";
+const char kCRLFCRLF[] = "\r\n\r\n";
+const char kSpaceOctet = '\x20';
+
+const int kReadSalt = 7;
+const int kWriteSalt = 5;
+
+int GetRand(int min, int max) {
+ CHECK(max >= min);
+ CHECK(max - min < RAND_MAX);
+ return rand() % (max - min + 1) + min;
+}
+
+class RandIntClass {
+ public:
+ int operator() (int range) {
+ return GetRand(0, range - 1);
+ }
+} g_rand;
+
+net::DrainableIOBuffer* ResizeIOBuffer(net::DrainableIOBuffer* buf, int len) {
+ net::DrainableIOBuffer* rv = new net::DrainableIOBuffer(
+ new net::IOBuffer(len), len);
+ std::copy(buf->data(), buf->data() + std::min(len, buf->BytesRemaining()),
+ rv->data());
+ return rv;
+}
+
+// TODO(dilmah): consider switching to socket_test_util.h
+// Simulates reading from |sample| stream; data supplied in Write() calls are
+// stored in |answer| buffer.
+class TestingTransportSocket : public net::Socket {
+ public:
+ TestingTransportSocket(
+ net::DrainableIOBuffer* sample, net::DrainableIOBuffer* answer)
+ : sample_(sample),
+ answer_(answer),
+ final_read_callback_(NULL),
+ method_factory_(this) {
+ }
+
+ ~TestingTransportSocket() {
+ if (final_read_callback_) {
+ MessageLoop::current()->PostTask(FROM_HERE,
+ method_factory_.NewRunnableMethod(
+ &TestingTransportSocket::DoReadCallback,
+ final_read_callback_, 0));
+ }
+ }
+
+ // Socket implementation.
+ virtual int Read(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback) {
+ CHECK_GT(buf_len, 0);
+ int remaining = sample_->BytesRemaining();
+ if (remaining < 1) {
+ if (final_read_callback_)
+ return 0;
+ final_read_callback_ = callback;
+ return net::ERR_IO_PENDING;
+ }
+ int lot = GetRand(1, std::min(remaining, buf_len));
+ std::copy(sample_->data(), sample_->data() + lot, buf->data());
+ sample_->DidConsume(lot);
+ if (GetRand(0, 1)) {
+ return lot;
+ }
+ MessageLoop::current()->PostTask(FROM_HERE,
+ method_factory_.NewRunnableMethod(
+ &TestingTransportSocket::DoReadCallback, callback, lot));
+ return net::ERR_IO_PENDING;
+ }
+
+ virtual int Write(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback) {
+ CHECK_GT(buf_len, 0);
+ int remaining = answer_->BytesRemaining();
+ CHECK_GE(remaining, buf_len);
+ int lot = std::min(remaining, buf_len);
+ if (GetRand(0, 1))
+ lot = GetRand(1, lot);
+ std::copy(buf->data(), buf->data() + lot, answer_->data());
+ answer_->DidConsume(lot);
+ if (GetRand(0, 1)) {
+ return lot;
+ }
+ MessageLoop::current()->PostTask(FROM_HERE,
+ method_factory_.NewRunnableMethod(
+ &TestingTransportSocket::DoWriteCallback, callback, lot));
+ return net::ERR_IO_PENDING;
+ }
+
+ virtual bool SetReceiveBufferSize(int32 size) {
+ return true;
+ }
+
+ virtual bool SetSendBufferSize(int32 size) {
+ return true;
+ }
+
+ net::DrainableIOBuffer* answer() { return answer_.get(); }
+
+ void DoReadCallback(net::CompletionCallback* callback, int result) {
+ if (result == 0 && !is_closed_) {
+ MessageLoop::current()->PostTask(FROM_HERE,
+ method_factory_.NewRunnableMethod(
+ &TestingTransportSocket::DoReadCallback, callback, 0));
+ } else {
+ if (callback)
+ callback->Run(result);
+ }
+ }
+
+ void DoWriteCallback(net::CompletionCallback* callback, int result) {
+ if (callback)
+ callback->Run(result);
+ }
+
+ bool is_closed_;
+
+ // Data to return for Read requests.
+ scoped_refptr<net::DrainableIOBuffer> sample_;
+
+ // Data pushed to us by server socket (using Write calls).
+ scoped_refptr<net::DrainableIOBuffer> answer_;
+
+ // Final read callback to report zero (zero stands for EOF).
+ net::CompletionCallback* final_read_callback_;
+
+ ScopedRunnableMethodFactory<TestingTransportSocket> method_factory_;
+};
+
+class Validator : public net::WebSocketServerSocket::Delegate {
+ public:
+ Validator(const std::string& resource,
+ const std::string& origin,
+ const std::string& host)
+ : resource_(resource), origin_(origin), host_(host) {
+ }
+
+ // WebSocketServerSocket::Delegate implementation.
+ virtual bool ValidateWebSocket(
+ const std::string& resource,
+ const std::string& origin,
+ const std::string& host,
+ const std::vector<std::string>& subprotocol_list,
+ std::string* location_out,
+ std::string* subprotocol_out) {
+ if (resource != resource_ || origin != origin_ || host != host_)
+ return false;
+ if (!subprotocol_list.empty())
+ *subprotocol_out = subprotocol_list.front();
+
+ char tmp[2048];
+ base::snprintf(
+ tmp, sizeof(tmp), "ws://%s%s", host.c_str(), resource.c_str());
+ location_out->assign(tmp);
+ return true;
+ }
+
+ private:
+ std::string resource_;
+ std::string origin_;
+ std::string host_;
+};
+
+char ReferenceSeq(unsigned n, unsigned salt) {
+ return (salt * 2 + n * 3) % ('z' - 'a') + 'a';
+}
+
+class ReadWriteTracker {
+ public:
+ ReadWriteTracker(
+ net::WebSocketServerSocket* ws, int bytes_to_read, int bytes_to_write)
+ : ws_(ws),
+ buf_size_(1 << 14),
+ accept_callback_(NewCallback(this, &ReadWriteTracker::OnAccept)),
+ read_callback_(NewCallback(this, &ReadWriteTracker::OnRead)),
+ write_callback_(NewCallback(this, &ReadWriteTracker::OnWrite)),
+ read_buf_(new net::IOBuffer(buf_size_)),
+ write_buf_(new net::IOBuffer(buf_size_)),
+ bytes_remaining_to_read_(bytes_to_read),
+ bytes_remaining_to_write_(bytes_to_write),
+ read_initiated_(false),
+ write_initiated_(false),
+ got_final_zero_(false) {
+ int rv = ws_->Accept(accept_callback_.get());
+ if (rv != net::ERR_IO_PENDING)
+ OnAccept(rv);
+ }
+
+ ~ReadWriteTracker() {
+ CHECK_EQ(bytes_remaining_to_write_, 0);
+ CHECK_EQ(bytes_remaining_to_read_, 0);
+ }
+
+ void OnAccept(int result) {
+ ASSERT_EQ(result, 0);
+ if (GetRand(0, 1)) {
+ DoRead();
+ DoWrite();
+ } else {
+ DoWrite();
+ DoRead();
+ }
+ }
+
+ void DoWrite() {
+ if (bytes_remaining_to_write_ < 1)
+ return;
+ int lot = GetRand(1, bytes_remaining_to_write_);
+ lot = std::min(lot, buf_size_);
+ for (int i = 0; i < lot; ++i)
+ write_buf_->data()[i] = ReferenceSeq(
+ bytes_remaining_to_write_ - i - 1, kWriteSalt);
+ int rv = ws_->Write(write_buf_, lot, write_callback_.get());
+ if (rv != net::ERR_IO_PENDING)
+ OnWrite(rv);
+ }
+
+ void DoRead() {
+ int lot = GetRand(1, buf_size_);
+ if (bytes_remaining_to_read_ < 1) {
+ if (got_final_zero_)
+ return;
+ } else {
+ lot = GetRand(1, bytes_remaining_to_read_);
+ lot = std::min(lot, buf_size_);
+ }
+ int rv = ws_->Read(read_buf_, lot, read_callback_.get());
+ if (rv != net::ERR_IO_PENDING)
+ OnRead(rv);
+ }
+
+ void OnWrite(int result) {
+ ASSERT_GT(result, 0);
+ ASSERT_LE(result, bytes_remaining_to_write_);
+ bytes_remaining_to_write_ -= result;
+ DoWrite();
+ }
+
+ void OnRead(int result) {
+ ASSERT_LE(result, bytes_remaining_to_read_);
+ if (bytes_remaining_to_read_ < 1) {
+ ASSERT_FALSE(got_final_zero_);
+ ASSERT_EQ(result, 0);
+ got_final_zero_ = true;
+ return;
+ }
+ for (int i = 0; i < result; ++i) {
+ ASSERT_EQ(read_buf_->data()[i], ReferenceSeq(
+ bytes_remaining_to_read_ - i - 1, kReadSalt));
+ }
+ bytes_remaining_to_read_ -= result;
+ DoRead();
+ }
+
+ private:
+ net::WebSocketServerSocket* const ws_;
+ int const buf_size_;
+ scoped_ptr<net::CompletionCallback> accept_callback_;
+ scoped_ptr<net::CompletionCallback> read_callback_;
+ scoped_ptr<net::CompletionCallback> write_callback_;
+ scoped_refptr<net::IOBuffer> read_buf_;
+ scoped_refptr<net::IOBuffer> write_buf_;
+ int bytes_remaining_to_read_;
+ int bytes_remaining_to_write_;
+ bool read_initiated_;
+ bool write_initiated_;
+ bool got_final_zero_;
+};
+
+} // namespace
+
+namespace net {
+
+class WebSocketServerSocketTest : public testing::Test {
+ public:
+ virtual ~WebSocketServerSocketTest() {
+ }
+
+ virtual void SetUp() {
+ count_ = 0;
+ accept_callback_[0].reset(NewCallback<WebSocketServerSocketTest, int>(
+ this, &WebSocketServerSocketTest::OnAccept0));
+ accept_callback_[1].reset(NewCallback<WebSocketServerSocketTest, int>(
+ this, &WebSocketServerSocketTest::OnAccept1));
+ }
+
+ virtual void TearDown() {
+ }
+
+ void OnAccept0(int result) {
+ ASSERT_EQ(result, 0);
+ ASSERT_LT(count_, 99999);
+ count_ += 1;
+ }
+
+ void OnAccept1(int result) {
+ ASSERT_TRUE(result == ERR_CONNECTION_REFUSED ||
+ result == ERR_ACCESS_DENIED);
+ ASSERT_LT(count_, 99999);
+ count_ += 1;
+ }
+
+ int count_;
+ scoped_ptr<net::CompletionCallback> accept_callback_[2];
+};
+
+TEST_F(WebSocketServerSocketTest, Handshake) {
+ srand(2523456);
+ std::vector<Socket*> kill_list;
+ std::vector< scoped_refptr<DrainableIOBuffer> > answer_list;
+ Validator validator("/demo", "http://example.com/", "example.com");
+ count_ = 0;
+ const int kNumTests = 300;
+ for (int run = kNumTests; run--;) {
+ scoped_refptr<DrainableIOBuffer> sample = new DrainableIOBuffer(
+ new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes);
+ for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) {
+ std::copy(kSampleHandshakeRequest[i],
+ kSampleHandshakeRequest[i] + strlen(kSampleHandshakeRequest[i]),
+ sample->data());
+ sample->DidConsume(strlen(kSampleHandshakeRequest[i]));
+ if (i != arraysize(kSampleHandshakeRequest) - 1) {
+ std::copy(kCRLF, kCRLF + strlen(kCRLF), sample->data());
+ sample->DidConsume(strlen(kCRLF));
+ }
+ }
+ int sample_len = sample->BytesConsumed();
+ sample->SetOffset(0);
+ DrainableIOBuffer* answer = new DrainableIOBuffer(
+ new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes);
+ answer_list.push_back(answer);
+ TestingTransportSocket* transport = new TestingTransportSocket(
+ ResizeIOBuffer(sample.get(), sample_len), answer);
+ WebSocketServerSocket* ws = CreateWebSocketServerSocket(
+ transport, &validator);
+ ASSERT_TRUE(ws != NULL);
+ kill_list.push_back(ws);
+
+ int rv = ws->Accept(accept_callback_[0].get());
+ if (rv != ERR_IO_PENDING)
+ OnAccept0(rv);
+ }
+ MessageLoop::current()->RunAllPending();
+ ASSERT_EQ(count_, kNumTests);
+ for (size_t i = answer_list.size(); i--;) {
+ ASSERT_EQ(answer_list[i]->BytesConsumed() + 0u,
+ strlen(kSampleHandshakeAnswer));
+ ASSERT_TRUE(std::equal(
+ answer_list[i]->data() - answer_list[i]->BytesConsumed(),
+ answer_list[i]->data(), kSampleHandshakeAnswer));
+ }
+ for (size_t i = kill_list.size(); i--;)
+ delete kill_list[i];
+ MessageLoop::current()->RunAllPending();
+}
+
+TEST_F(WebSocketServerSocketTest, BadCred) {
+ srand(9034958);
+ std::vector<Socket*> kill_list;
+ std::vector< scoped_refptr<DrainableIOBuffer> > answer_list;
+ Validator *validator[] = {
+ new Validator("/demo", "http://gooogle.com/", "example.com"),
+ new Validator("/tcpproxy", "http://example.com/", "example.com"),
+ new Validator("/tcpproxy", "http://gooogle.com/", "example.com"),
+ new Validator("/demo", "http://example.com/", "exmple.com"),
+ new Validator("/demo", "http://gooogle.com/", "gooogle.com")
+ };
+ count_ = 0;
+ for (int run = arraysize(validator); run--;) {
+ scoped_refptr<DrainableIOBuffer> sample = new DrainableIOBuffer(
+ new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes);
+ for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) {
+ std::copy(kSampleHandshakeRequest[i],
+ kSampleHandshakeRequest[i] + strlen(kSampleHandshakeRequest[i]),
+ sample->data());
+ sample->DidConsume(strlen(kSampleHandshakeRequest[i]));
+ if (i != arraysize(kSampleHandshakeRequest) - 1) {
+ std::copy(kCRLF, kCRLF + strlen(kCRLF), sample->data());
+ sample->DidConsume(strlen(kCRLF));
+ }
+ }
+ int sample_len = sample->BytesConsumed();
+ sample->SetOffset(0);
+ DrainableIOBuffer* answer = new DrainableIOBuffer(
+ new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes);
+ answer_list.push_back(answer);
+ TestingTransportSocket* transport = new TestingTransportSocket(
+ ResizeIOBuffer(sample.get(), sample_len), answer);
+ WebSocketServerSocket* ws = CreateWebSocketServerSocket(
+ transport, validator[run]);
+ ASSERT_TRUE(ws != NULL);
+ kill_list.push_back(ws);
+
+ int rv = ws->Accept(accept_callback_[1].get());
+ if (rv != ERR_IO_PENDING)
+ OnAccept1(rv);
+ }
+ MessageLoop::current()->RunAllPending();
+ ASSERT_EQ(count_ + 0u, arraysize(validator));
+ for (size_t i = answer_list.size(); i--;)
+ ASSERT_EQ(answer_list[i]->BytesConsumed(), 0);
+ for (size_t i = kill_list.size(); i--;)
+ delete kill_list[i];
+ for (size_t i = arraysize(validator); i--;)
+ delete validator[i];
+ MessageLoop::current()->RunAllPending();
+}
+
+TEST_F(WebSocketServerSocketTest, ReorderedHandshake) {
+ srand(205643459);
+ std::vector<Socket*> kill_list;
+ std::vector< scoped_refptr<DrainableIOBuffer> > answer_list;
+ Validator validator("/demo", "http://example.com/", "example.com");
+ count_ = 0;
+ const int kNumTests = 200;
+ for (int run = kNumTests; run--;) {
+ scoped_refptr<DrainableIOBuffer> sample = new DrainableIOBuffer(
+ new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes);
+
+ std::vector<size_t> fields_order;
+ for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i)
+ fields_order.push_back(i);
+ // One leading and two trailing lines of request are special, leave them.
+ std::random_shuffle(fields_order.begin() + 1,
+ fields_order.begin() + fields_order.size() - 3,
+ g_rand);
+
+ for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) {
+ size_t j = fields_order[i];
+ std::copy(kSampleHandshakeRequest[j],
+ kSampleHandshakeRequest[j] + strlen(kSampleHandshakeRequest[j]),
+ sample->data());
+ sample->DidConsume(strlen(kSampleHandshakeRequest[j]));
+ if (i != arraysize(kSampleHandshakeRequest) - 1) {
+ std::copy(kCRLF, kCRLF + strlen(kCRLF), sample->data());
+ sample->DidConsume(strlen(kCRLF));
+ }
+ }
+ int sample_len = sample->BytesConsumed();
+ sample->SetOffset(0);
+ DrainableIOBuffer* answer = new DrainableIOBuffer(
+ new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes);
+ answer_list.push_back(answer);
+ TestingTransportSocket* transport = new TestingTransportSocket(
+ ResizeIOBuffer(sample.get(), sample_len), answer);
+ WebSocketServerSocket* ws = CreateWebSocketServerSocket(
+ transport, &validator);
+ ASSERT_TRUE(ws != NULL);
+ kill_list.push_back(ws);
+
+ int rv = ws->Accept(accept_callback_[0].get());
+ if (rv != ERR_IO_PENDING)
+ OnAccept0(rv);
+ }
+ MessageLoop::current()->RunAllPending();
+ ASSERT_EQ(count_, kNumTests);
+ for (size_t i = answer_list.size(); i--;) {
+ ASSERT_EQ(answer_list[i]->BytesConsumed() + 0u,
+ strlen(kSampleHandshakeAnswer));
+ ASSERT_TRUE(std::equal(
+ answer_list[i]->data() - answer_list[i]->BytesConsumed(),
+ answer_list[i]->data(), kSampleHandshakeAnswer));
+ }
+ for (size_t i = kill_list.size(); i--;)
+ delete kill_list[i];
+ MessageLoop::current()->RunAllPending();
+}
+
+TEST_F(WebSocketServerSocketTest, ConveyData) {
+ srand(8234523);
+ std::vector<Socket*> kill_list;
+ std::vector<ReadWriteTracker*> tracker_list;
+ Validator validator("/demo", "http://example.com/", "example.com");
+ count_ = 0;
+ const int kNumTests = 150;
+ for (int run = kNumTests; run--;) {
+ int bytes_to_read = GetRand(1, 1 << 14);
+ int bytes_to_write = GetRand(1, 1 << 14);
+ int frames_limit = GetRand(1, 1 << 10);
+ int sample_limit = kHandshakeBufBytes + bytes_to_write + frames_limit * 2;
+ scoped_refptr<DrainableIOBuffer> sample = new DrainableIOBuffer(
+ new IOBuffer(sample_limit), sample_limit);
+
+ std::vector<size_t> fields_order;
+ for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i)
+ fields_order.push_back(i);
+ // One leading and two trailing lines of request are special, leave them.
+ std::random_shuffle(fields_order.begin() + 1,
+ fields_order.begin() + fields_order.size() - 3,
+ g_rand);
+
+ for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) {
+ size_t j = fields_order[i];
+ std::copy(kSampleHandshakeRequest[j],
+ kSampleHandshakeRequest[j] + strlen(kSampleHandshakeRequest[j]),
+ sample->data());
+ sample->DidConsume(strlen(kSampleHandshakeRequest[j]));
+ if (i != arraysize(kSampleHandshakeRequest) - 1) {
+ std::copy(kCRLF, kCRLF + strlen(kCRLF), sample->data());
+ sample->DidConsume(strlen(kCRLF));
+ }
+ }
+ {
+ bool outside_frame = true;
+ int pos = 0;
+ for (int i = 0; i < bytes_to_write; ++i) {
+ if (outside_frame) {
+ sample->data()[pos++] = '\x00';
+ outside_frame = false;
+ CHECK_GE(frames_limit, 1);
+ frames_limit -= 1;
+ }
+ sample->data()[pos++] = ReferenceSeq(bytes_to_write - i - 1, kReadSalt);
+ if ((frames_limit > 1 &&
+ GetRand(0, 1 + (bytes_to_write - i) / frames_limit) == 0) ||
+ i == bytes_to_write - 1) {
+ sample->data()[pos++] = '\xff';
+ outside_frame = true;
+ }
+ }
+ sample->DidConsume(pos);
+ }
+
+ int sample_len = sample->BytesConsumed();
+ sample->SetOffset(0);
+ int answer_limit = kHandshakeBufBytes + bytes_to_read * 3;
+ DrainableIOBuffer* answer = new DrainableIOBuffer(
+ new IOBuffer(answer_limit), answer_limit);
+ TestingTransportSocket* transport = new TestingTransportSocket(
+ ResizeIOBuffer(sample.get(), sample_len), answer);
+ WebSocketServerSocket* ws = CreateWebSocketServerSocket(
+ transport, &validator);
+ ASSERT_TRUE(ws != NULL);
+ kill_list.push_back(ws);
+
+ ReadWriteTracker* tracker = new ReadWriteTracker(
+ ws, bytes_to_write, bytes_to_read);
+ tracker_list.push_back(tracker);
+ }
+ MessageLoop::current()->RunAllPending();
+
+ for (size_t i = kill_list.size(); i--;)
+ delete kill_list[i];
+ for (size_t i = tracker_list.size(); i--;)
+ delete tracker_list[i];
+ MessageLoop::current()->RunAllPending();
+}
+
+} // namespace net