diff options
author | dilmah@chromium.org <dilmah@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-07-19 20:33:43 +0000 |
---|---|---|
committer | dilmah@chromium.org <dilmah@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-07-19 20:33:43 +0000 |
commit | 38dc684dae9a03cd0a1642581d618e39f813013c (patch) | |
tree | b00197013d79cfebe690d77d5dfb46a55c123a3f /net | |
parent | 3b9448e6acb5fd762afc0c3b3957a65e9d3ef092 (diff) | |
download | chromium_src-38dc684dae9a03cd0a1642581d618e39f813013c.zip chromium_src-38dc684dae9a03cd0a1642581d618e39f813013c.tar.gz chromium_src-38dc684dae9a03cd0a1642581d618e39f813013c.tar.bz2 |
Implementation of server socket for websockets protocol.
It is prerequisite to refactoring websocket-to-tcp proxy (existing on ChromeOS).
BUG=chromium-os:15533
TEST=unittest
Committed: http://src.chromium.org/viewvc/chrome?view=rev&revision=93010
Review URL: http://codereview.chromium.org/7131009
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@93089 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net')
-rw-r--r-- | net/base/io_buffer.h | 2 | ||||
-rw-r--r-- | net/base/net_error_list.h | 9 | ||||
-rw-r--r-- | net/net.gyp | 3 | ||||
-rw-r--r-- | net/socket/nss_ssl_util.h | 2 | ||||
-rw-r--r-- | net/socket/socket.h | 12 | ||||
-rw-r--r-- | net/socket/web_socket_server_socket.cc | 905 | ||||
-rw-r--r-- | net/socket/web_socket_server_socket.h | 63 | ||||
-rw-r--r-- | net/socket/web_socket_server_socket_unittest.cc | 596 |
8 files changed, 1581 insertions, 11 deletions
diff --git a/net/base/io_buffer.h b/net/base/io_buffer.h index 1ec7869..e117cce 100644 --- a/net/base/io_buffer.h +++ b/net/base/io_buffer.h @@ -128,7 +128,7 @@ class NET_API PickledIOBuffer : public IOBuffer { Pickle* pickle() { return &pickle_; } - // Signals that we are done writing to the picke and we can use it for a + // Signals that we are done writing to the pickle and we can use it for a // write-style IO operation. void Done(); diff --git a/net/base/net_error_list.h b/net/base/net_error_list.h index cff5a79..9596a9a 100644 --- a/net/base/net_error_list.h +++ b/net/base/net_error_list.h @@ -239,6 +239,12 @@ NET_ERROR(MSG_TOO_BIG, -142) // See also: ESET_ANTI_VIRUS_SSL_INTERCEPTION NET_ERROR(KASPERSKY_ANTI_VIRUS_SSL_INTERCEPTION, -143) +// Violation of limits (e.g. imposed to prevent DoS). +NET_ERROR(LIMIT_VIOLATION, -144) + +// WebSocket protocol error occurred. +NET_ERROR(WS_PROTOCOL_ERROR, -145) + // Connection was aborted for switching to another ptotocol. // WebSocket abort SocketStream connection when alternate protocol is found. NET_ERROR(PROTOCOL_SWITCHED, -146) @@ -246,9 +252,6 @@ NET_ERROR(PROTOCOL_SWITCHED, -146) // Returned when attempting to bind an address that is already in use. NET_ERROR(ADDRESS_IN_USE, -147) -// NOTE: error codes 144-145 are available, please use those before adding -// 148. - // Certificate error codes // // The values of certificate error codes must be consecutive. diff --git a/net/net.gyp b/net/net.gyp index f9e62fa..724e24b 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -544,6 +544,8 @@ 'socket/tcp_server_socket_win.h', 'socket/transport_client_socket_pool.cc', 'socket/transport_client_socket_pool.h', + 'socket/web_socket_server_socket.cc', + 'socket/web_socket_server_socket.h', 'socket_stream/socket_stream.cc', 'socket_stream/socket_stream.h', 'socket_stream/socket_stream_job.cc', @@ -1005,6 +1007,7 @@ 'socket/tcp_server_socket_unittest.cc', 'socket/transport_client_socket_pool_unittest.cc', 'socket/transport_client_socket_unittest.cc', + 'socket/web_socket_server_socket_unittest.cc', 'socket_stream/socket_stream_metrics_unittest.cc', 'socket_stream/socket_stream_unittest.cc', 'spdy/spdy_framer_test.cc', diff --git a/net/socket/nss_ssl_util.h b/net/socket/nss_ssl_util.h index 2a53fc7..614ab5f 100644 --- a/net/socket/nss_ssl_util.h +++ b/net/socket/nss_ssl_util.h @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -// This file is only inclued in ssl_client_socket_nss.cc and +// This file is only included in ssl_client_socket_nss.cc and // ssl_server_socket_nss.cc to share common functions of NSS. #ifndef NET_SOCKET_NSS_SSL_UTIL_H_ diff --git a/net/socket/socket.h b/net/socket/socket.h index 17d9f12..450f42d 100644 --- a/net/socket/socket.h +++ b/net/socket/socket.h @@ -18,8 +18,8 @@ class NET_API Socket { public: virtual ~Socket() {} - // Reads data, up to buf_len bytes, from the socket. The number of bytes read - // is returned, or an error is returned upon failure. + // Reads data, up to |buf_len| bytes, from the socket. The number of bytes + // read is returned, or an error is returned upon failure. // ERR_SOCKET_NOT_CONNECTED should be returned if the socket is not currently // connected. Zero is returned once to indicate end-of-file; the return value // of subsequent calls is undefined, and may be OS dependent. ERR_IO_PENDING @@ -32,8 +32,8 @@ class NET_API Socket { virtual int Read(IOBuffer* buf, int buf_len, CompletionCallback* callback) = 0; - // Writes data, up to buf_len bytes, to the socket. Note: only part of the - // data may be written! The number of bytes written is returned, or an error + // Writes data, up to |buf_len| bytes, to the socket. Note: data may be + // written partially. The number of bytes written is returned, or an error // is returned upon failure. ERR_SOCKET_NOT_CONNECTED should be returned if // the socket is not currently connected. The return value when the // connection is closed is undefined, and may be OS dependent. ERR_IO_PENDING @@ -48,12 +48,12 @@ class NET_API Socket { CompletionCallback* callback) = 0; // Set the receive buffer size (in bytes) for the socket. - // Note: changing this value can effect the TCP window size on some platforms. + // Note: changing this value can affect the TCP window size on some platforms. // Returns true on success, or false on failure. virtual bool SetReceiveBufferSize(int32 size) = 0; // Set the send buffer size (in bytes) for the socket. - // Note: changing this value can effect the TCP window size on some platforms. + // Note: changing this value can affect the TCP window size on some platforms. // Returns true on success, or false on failure. virtual bool SetSendBufferSize(int32 size) = 0; }; diff --git a/net/socket/web_socket_server_socket.cc b/net/socket/web_socket_server_socket.cc new file mode 100644 index 0000000..a793097 --- /dev/null +++ b/net/socket/web_socket_server_socket.cc @@ -0,0 +1,905 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/web_socket_server_socket.h" + +#include <algorithm> +#include <deque> +#include <limits> +#include <map> +#include <vector> + +#if defined(OS_WIN) +#include <winsock2.h> // for htonl +#else +#include <arpa/inet.h> +#endif + +#include "base/basictypes.h" +#include "base/logging.h" +#include "base/md5.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop.h" +#include "base/string_util.h" +#include "base/task.h" +#include "googleurl/src/gurl.h" +#include "net/base/completion_callback.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" + +namespace { + +const size_t kHandshakeLimitBytes = 1 << 14; + +const char kCrOctet = '\r'; +COMPILE_ASSERT(kCrOctet == '\x0d', ASCII); +const char kLfOctet = '\n'; +COMPILE_ASSERT(kLfOctet == '\x0a', ASCII); +const char kSpaceOctet = ' '; +COMPILE_ASSERT(kSpaceOctet == '\x20', ASCII); +const char kCommaOctet = ','; +COMPILE_ASSERT(kCommaOctet == '\x2c', ASCII); + +const char kCRLF[] = { kCrOctet, kLfOctet, 0 }; +const char kCRLFCRLF[] = { kCrOctet, kLfOctet, kCrOctet, kLfOctet, 0 }; + +const char kPlainHostFieldName[] = "Host"; +const char kPlainOriginFieldName[] = "Origin"; +const char kOriginFieldName[] = "Sec-WebSocket-Origin"; +const char kProtocolFieldName[] = "Sec-WebSocket-Protocol"; +const char kVersionFieldName[] = "Sec-WebSocket-Version"; +const char kLocationFieldName[] = "Sec-WebSocket-Location"; +const char kKey1FieldName[] = "Sec-WebSocket-Key1"; +const char kKey2FieldName[] = "Sec-WebSocket-Key2"; + +int CountSpaces(const std::string& s) { + return std::count(s.begin(), s.end(), kSpaceOctet); +} + +// Returns true on success. +bool FetchDecimalDigits(const std::string& s, uint32* result) { + *result = 0; + bool got_something = false; + for (size_t i = 0; i < s.size(); ++i) { + if (IsAsciiDigit(s[i])) { + got_something = true; + if (*result > std::numeric_limits<uint32>::max() / 10) + return false; + *result *= 10; + int digit = s[i] - '0'; + if (*result > std::numeric_limits<uint32>::max() - digit) + return false; + *result += digit; + } + } + return got_something; +} + +// Returns number of fetched subprotocols or negative error code. +int FetchSubprotocolList( + const std::string& s, std::vector<std::string>* subprotocol_list) { + subprotocol_list->clear(); + subprotocol_list->push_back(std::string()); + for (size_t i = 0; i < s.size(); ++i) { + if (s[i] > '\x20' && s[i] < '\x7f' && s[i] != kCommaOctet) + subprotocol_list->back() += s[i]; + else if (!subprotocol_list->back().empty()) { + if (subprotocol_list->size() < 16) + subprotocol_list->push_back(std::string()); + else + return net::ERR_LIMIT_VIOLATION; + } + } + if (subprotocol_list->back().empty()) + subprotocol_list->pop_back(); + if (subprotocol_list->empty()) + return net::ERR_WS_PROTOCOL_ERROR; + + { + std::vector<std::string> tmp(*subprotocol_list); + std::sort(tmp.begin(), tmp.end()); + if (tmp.end() != std::unique(tmp.begin(), tmp.end())) + return net::ERR_WS_PROTOCOL_ERROR; + } + return subprotocol_list->size(); +} + +class WebSocketServerSocketImpl : public net::WebSocketServerSocket { + public: + WebSocketServerSocketImpl(net::Socket* transport_socket, Delegate* delegate) + : phase_(PHASE_NYMPH), + frame_bytes_remaining_(0), + transport_socket_(transport_socket), + delegate_(delegate), + handshake_buf_(new net::IOBuffer(kHandshakeLimitBytes)), + fill_handshake_buf_(new net::DrainableIOBuffer( + handshake_buf_, kHandshakeLimitBytes)), + process_handshake_buf_(new net::DrainableIOBuffer( + handshake_buf_, kHandshakeLimitBytes)), + transport_read_callback_(NewCallback( + this, &WebSocketServerSocketImpl::OnRead)), + transport_write_callback_(NewCallback( + this, &WebSocketServerSocketImpl::OnWrite)), + is_transport_read_pending_(false), + is_transport_write_pending_(false), + method_factory_(this) { + DCHECK(transport_socket); + DCHECK(delegate); + } + + virtual ~WebSocketServerSocketImpl() { + std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ); + if (it != pending_reqs_.end() && + it->type == PendingReq::TYPE_READ && + it->io_buf != NULL && + it->io_buf->data() != NULL && + it->callback != 0) { + it->callback->Run(0); // Report EOF. + } + } + + private: + enum Phase { + // Before Accept() is called. + PHASE_NYMPH, + + // After Accept() is called and until handshake success/fail. + PHASE_HANDSHAKE, + + // Processing data stream. + PHASE_FRAME_OUTSIDE, // Outside data frame. + PHASE_FRAME_INSIDE, // Inside text frame. + PHASE_FRAME_LENGTH, // Reading length of binary frame. + PHASE_FRAME_SKIP, // Skipping binary frame. + + // After termination. + PHASE_SHUT + }; + + struct PendingReq { + enum Type { + // Frame delimiters or handshake (as opposed to user data). + TYPE_METADATA = 1 << 0, + // Read request. + TYPE_READ = 1 << 1, + // Write request. + TYPE_WRITE = 1 << 2, + + TYPE_READ_METADATA = TYPE_READ | TYPE_METADATA, + TYPE_WRITE_METADATA = TYPE_WRITE | TYPE_METADATA + }; + + PendingReq(Type type, net::DrainableIOBuffer* io_buf, + net::CompletionCallback* callback) + : type(type), + io_buf(io_buf), + callback(callback) { + switch (type) { + case PendingReq::TYPE_READ: + case PendingReq::TYPE_WRITE: + case PendingReq::TYPE_READ_METADATA: + case PendingReq::TYPE_WRITE_METADATA: { + DCHECK(io_buf); + break; + } + default: { + NOTREACHED(); + break; + } + } + } + + Type type; + scoped_refptr<net::DrainableIOBuffer> io_buf; + net::CompletionCallback* callback; + }; + + // Socket implementation. + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) OVERRIDE { + if (buf_len == 0) + return 0; + if (buf == NULL || buf_len < 0) { + NOTREACHED(); + return net::ERR_INVALID_ARGUMENT; + } + while (int bytes_remaining = fill_handshake_buf_->BytesConsumed() - + process_handshake_buf_->BytesConsumed()) { + DCHECK(!is_transport_read_pending_); + DCHECK(GetPendingReq(PendingReq::TYPE_READ) == pending_reqs_.end()); + switch (phase_) { + case PHASE_FRAME_OUTSIDE: + case PHASE_FRAME_INSIDE: + case PHASE_FRAME_LENGTH: + case PHASE_FRAME_SKIP: { + int n = std::min(bytes_remaining, buf_len); + int rv = ProcessDataFrames( + process_handshake_buf_->data(), n, buf->data(), buf_len); + process_handshake_buf_->DidConsume(n); + if (rv == 0) { + // ProcessDataFrames may return zero for non-empty buffer if it + // contains only frame delimiters without real data. In this case: + // try again and do not just return zero (zero stands for EOF). + continue; + } + return rv; + } + case PHASE_SHUT: { + return 0; + } + case PHASE_NYMPH: + case PHASE_HANDSHAKE: + default: { + NOTREACHED(); + return net::ERR_UNEXPECTED; + } + } + } + switch (phase_) { + case PHASE_FRAME_OUTSIDE: + case PHASE_FRAME_INSIDE: + case PHASE_FRAME_LENGTH: + case PHASE_FRAME_SKIP: { + pending_reqs_.push_back(PendingReq( + PendingReq::TYPE_READ, + new net::DrainableIOBuffer(buf, buf_len), + callback)); + ConsiderTransportRead(); + break; + } + case PHASE_SHUT: { + return 0; + } + case PHASE_NYMPH: + case PHASE_HANDSHAKE: + default: { + NOTREACHED(); + return net::ERR_UNEXPECTED; + } + } + return net::ERR_IO_PENDING; + } + + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) OVERRIDE { + if (buf_len == 0) + return 0; + if (buf == NULL || buf_len < 0) { + NOTREACHED(); + return net::ERR_INVALID_ARGUMENT; + } + DCHECK_EQ(std::find(buf->data(), buf->data() + buf_len, '\xff'), + buf->data() + buf_len); + switch (phase_) { + case PHASE_FRAME_OUTSIDE: + case PHASE_FRAME_INSIDE: + case PHASE_FRAME_LENGTH: + case PHASE_FRAME_SKIP: { + break; + } + case PHASE_SHUT: { + return net::ERR_SOCKET_NOT_CONNECTED; + } + case PHASE_NYMPH: + case PHASE_HANDSHAKE: + default: { + NOTREACHED(); + return net::ERR_UNEXPECTED; + } + } + + net::IOBuffer* frame_start = new net::IOBuffer(1); + frame_start->data()[0] = '\x00'; + pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, + new net::DrainableIOBuffer(frame_start, 1), + NULL)); + + pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE, + new net::DrainableIOBuffer(buf, buf_len), + callback)); + + net::IOBuffer* frame_end = new net::IOBuffer(1); + frame_end->data()[0] = '\xff'; + pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, + new net::DrainableIOBuffer(frame_end, 1), + NULL)); + + ConsiderTransportWrite(); + return net::ERR_IO_PENDING; + } + + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { + return transport_socket_->SetReceiveBufferSize(size); + } + + virtual bool SetSendBufferSize(int32 size) OVERRIDE { + return transport_socket_->SetSendBufferSize(size); + } + + // WebSocketServerSocket implementation. + virtual int Accept(net::CompletionCallback* callback) { + if (phase_ != PHASE_NYMPH) + return net::ERR_UNEXPECTED; + phase_ = PHASE_HANDSHAKE; + pending_reqs_.push_front(PendingReq( + PendingReq::TYPE_READ_METADATA, fill_handshake_buf_.get(), callback)); + ConsiderTransportRead(); + return net::ERR_IO_PENDING; + } + + std::deque<PendingReq>::iterator GetPendingReq(PendingReq::Type type) { + for (std::deque<PendingReq>::iterator it = pending_reqs_.begin(); + it != pending_reqs_.end(); ++it) { + if (it->type & type) + return it; + } + return pending_reqs_.end(); + } + + void ConsiderTransportRead() { + if (pending_reqs_.empty()) + return; + if (is_transport_read_pending_) + return; + std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ); + if (it == pending_reqs_.end()) + return; + if (it->io_buf == NULL || it->io_buf->BytesRemaining() == 0) { + NOTREACHED(); + return; + } + is_transport_read_pending_ = true; + int rv = transport_socket_->Read( + it->io_buf.get(), it->io_buf->BytesRemaining(), + transport_read_callback_.get()); + if (rv != net::ERR_IO_PENDING) { + // PostTask rather than direct call in order to: + // (1) guarantee calling callback after returning from Read(); + // (2) avoid potential stack overflow; + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( + &WebSocketServerSocketImpl::OnRead, rv)); + } + } + + void ConsiderTransportWrite() { + if (is_transport_write_pending_) + return; + if (pending_reqs_.empty()) + return; + std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_WRITE); + if (it == pending_reqs_.end()) + return; + if (it->io_buf == NULL || it->io_buf->BytesRemaining() == 0) { + NOTREACHED(); + Shut(net::ERR_UNEXPECTED); + return; + } + is_transport_write_pending_ = true; + int rv = transport_socket_->Write( + it->io_buf.get(), it->io_buf->BytesRemaining(), + transport_write_callback_.get()); + if (rv != net::ERR_IO_PENDING) { + // PostTask rather than direct call in order to: + // (1) guarantee calling callback after returning from Read(); + // (2) avoid potential stack overflow; + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( + &WebSocketServerSocketImpl::OnWrite, rv)); + } + } + + void Shut(int result) { + if (result > 0 || result == net::ERR_IO_PENDING) + result = net::ERR_UNEXPECTED; + if (result != 0) { + while (!pending_reqs_.empty()) { + PendingReq& req = pending_reqs_.front(); + if (req.callback) + req.callback->Run(result); + pending_reqs_.pop_front(); + } + transport_socket_.reset(); // terminate underlying connection. + } + phase_ = PHASE_SHUT; + } + + // Callbacks for transport socket. + void OnRead(int result) { + if (!is_transport_read_pending_) { + NOTREACHED(); + Shut(net::ERR_UNEXPECTED); + return; + } + is_transport_read_pending_ = false; + + if (result <= 0) { + Shut(result); + return; + } + + std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ); + if (it == pending_reqs_.end() || + it->io_buf == NULL || + it->io_buf->data() == NULL) { + NOTREACHED(); + Shut(net::ERR_UNEXPECTED); + return; + } + if ((phase_ == PHASE_HANDSHAKE) == (it->type == PendingReq::TYPE_READ)) { + NOTREACHED(); + Shut(net::ERR_UNEXPECTED); + return; + } + + switch (phase_) { + case PHASE_HANDSHAKE: { + if (it != pending_reqs_.begin() || it->io_buf != fill_handshake_buf_) { + NOTREACHED(); + Shut(net::ERR_UNEXPECTED); + return; + } + fill_handshake_buf_->DidConsume(result); + // ProcessHandshake invalidates iterators for |pending_reqs_| + int rv = ProcessHandshake(); + if (rv > 0) { + process_handshake_buf_->DidConsume(rv); + phase_ = PHASE_FRAME_OUTSIDE; + net::CompletionCallback* cb = pending_reqs_.front().callback; + pending_reqs_.pop_front(); + ConsiderTransportWrite(); // Schedule answer handshake. + if (cb) + cb->Run(0); + } else if (rv == net::ERR_IO_PENDING) { + if (fill_handshake_buf_->BytesRemaining() < 1) + Shut(net::ERR_LIMIT_VIOLATION); + } else if (rv < 0) { + Shut(rv); + } else { + Shut(net::ERR_UNEXPECTED); + } + break; + } + case PHASE_FRAME_OUTSIDE: + case PHASE_FRAME_INSIDE: + case PHASE_FRAME_LENGTH: + case PHASE_FRAME_SKIP: { + int rv = ProcessDataFrames( + it->io_buf->data(), result, + it->io_buf->data(), it->io_buf->BytesRemaining()); + if (rv < 0) { + Shut(rv); + return; + } + if (rv > 0 || phase_ == PHASE_SHUT) { + net::CompletionCallback* cb = it->callback; + pending_reqs_.erase(it); + if (cb) + cb->Run(rv); + } + break; + } + case PHASE_NYMPH: + default: { + NOTREACHED(); + Shut(net::ERR_UNEXPECTED); + break; + } + } + ConsiderTransportRead(); + } + + void OnWrite(int result) { + if (!is_transport_write_pending_) { + NOTREACHED(); + Shut(net::ERR_UNEXPECTED); + return; + } + is_transport_write_pending_ = false; + + if (result < 0) { + Shut(result); + return; + } + + std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_WRITE); + if (it == pending_reqs_.end() || + it->io_buf == NULL || + it->io_buf->data() == NULL) { + NOTREACHED(); + Shut(net::ERR_UNEXPECTED); + return; + } + DCHECK_LE(result, it->io_buf->BytesRemaining()); + it->io_buf->DidConsume(result); + if (it->io_buf->BytesRemaining() == 0) { + net::CompletionCallback* cb = it->callback; + int bytes_written = it->io_buf->BytesConsumed(); + DCHECK_GT(bytes_written, 0); + pending_reqs_.erase(it); + if (cb) + cb->Run(bytes_written); + } + ConsiderTransportWrite(); + } + + // Returns (positive) number of consumed bytes on success. + // Returns ERR_IO_PENDING in case of incomplete input. + // Returns ERR_WS_PROTOCOL_ERROR or ERR_LIMIT_VIOLATION in case of failure to + // reasonably parse input. + int ProcessHandshake() { + static const char kGetPrefix[] = "GET "; + static const char kKeyValueDelimiter[] = ": "; + + class Fields { + public: + bool Has(const std::string& name) { + return map_.find(StringToLowerASCII(name)) != map_.end(); + } + + std::string Get(const std::string& name) { + return Has(name) ? map_[StringToLowerASCII(name)] : std::string(); + } + + void Set(const std::string& name, const std::string& value) { + map_[StringToLowerASCII(name)] = StringToLowerASCII(value); + } + + private: + std::map<std::string, std::string> map_; + } fields; + + char* buf = process_handshake_buf_->data(); + size_t buf_size = fill_handshake_buf_->BytesConsumed(); + + if (buf_size < 1) + return net::ERR_IO_PENDING; + if (!std::equal(buf, buf + std::min(buf_size, strlen(kGetPrefix)), + kGetPrefix)) { + // Data head does not match what is expected. + return net::ERR_WS_PROTOCOL_ERROR; + } + if (buf_size >= kHandshakeLimitBytes) + return net::ERR_LIMIT_VIOLATION; + char* buf_end = buf + buf_size; + + if (buf_size < strlen(kGetPrefix)) + return net::ERR_IO_PENDING; + char* resource_begin = buf + strlen(kGetPrefix); + char* resource_end = std::find(resource_begin, buf_end, kSpaceOctet); + if (resource_end == buf_end) + return net::ERR_IO_PENDING; + std::string resource(resource_begin, resource_end); + if (!IsStringUTF8(resource) || + resource.find_first_of(kCRLF) != std::string::npos) { + return net::ERR_WS_PROTOCOL_ERROR; + } + char* term_pos = std::search( + buf, buf_end, kCRLFCRLF, kCRLFCRLF + strlen(kCRLFCRLF)); + char key3[8]; // Notation (key3) matches websocket RFC. + size_t message_len = buf_end - term_pos; + if (message_len < sizeof(key3) + strlen(kCRLFCRLF)) + return net::ERR_IO_PENDING; + term_pos += strlen(kCRLFCRLF); + memcpy(key3, term_pos, sizeof(key3)); + term_pos += sizeof(key3); + // First line is "GET resource" line, so skip it. + char* pos = std::search(buf, term_pos, kCRLF, kCRLF + strlen(kCRLF)); + if (pos == term_pos) + return net::ERR_WS_PROTOCOL_ERROR; + for (;;) { + pos += strlen(kCRLF); + if (term_pos - pos < + static_cast<ptrdiff_t>(sizeof(key3) + strlen(kCRLF))) { + return net::ERR_WS_PROTOCOL_ERROR; + } + if (term_pos - pos == + static_cast<ptrdiff_t>(sizeof(key3) + strlen(kCRLF))) { + break; + } + char* next_pos = std::search( + pos, term_pos, kKeyValueDelimiter, + kKeyValueDelimiter + strlen(kKeyValueDelimiter)); + if (next_pos == term_pos) + return net::ERR_WS_PROTOCOL_ERROR; + std::string key(pos, next_pos); + if (!IsStringASCII(key) || + key.find_first_of(kCRLF) != std::string::npos) { + return net::ERR_WS_PROTOCOL_ERROR; + } + pos = std::search(next_pos += strlen(kKeyValueDelimiter), term_pos, + kCRLF, kCRLF + strlen(kCRLF)); + if (pos == term_pos) + return net::ERR_WS_PROTOCOL_ERROR; + if (!key.empty()) { + std::string value(next_pos, pos); + if (!IsStringASCII(value) || + value.find_first_of(kCRLF) != std::string::npos) { + return net::ERR_WS_PROTOCOL_ERROR; + } + fields.Set(key, value); + } + } + + // Values of Upgrade and Connection fields are hardcoded in the protocol. + if (fields.Get("Upgrade") != "websocket" || + fields.Get("Connection") != "upgrade") { + return net::ERR_WS_PROTOCOL_ERROR; + } + if (fields.Has(kVersionFieldName)) { + NOTIMPLEMENTED(); // new protocol. + return net::ERR_NOT_IMPLEMENTED; + } + + if (!fields.Has(kPlainOriginFieldName)) + return net::ERR_CONNECTION_REFUSED; + // Normalize (e.g. w.r.t. leading slashes) origin. + GURL origin = GURL(fields.Get(kPlainOriginFieldName)).GetOrigin(); + if (!origin.is_valid()) + return net::ERR_WS_PROTOCOL_ERROR; + std::string normalized_origin = origin.spec(); + + if (!fields.Has(kPlainHostFieldName)) + return net::ERR_CONNECTION_REFUSED; + + std::vector<std::string> subprotocol_list; + if (fields.Has(kProtocolFieldName)) { + int rv = FetchSubprotocolList( + fields.Get(kProtocolFieldName), &subprotocol_list); + if (rv < 0) + return rv; + DCHECK(subprotocol_list.end() == std::find( + subprotocol_list.begin(), subprotocol_list.end(), "")); + } + + std::string location; + std::string subprotocol; + if (!delegate_->ValidateWebSocket(resource, + normalized_origin, + fields.Get(kPlainHostFieldName), + subprotocol_list, + &location, + &subprotocol)) { + return net::ERR_CONNECTION_REFUSED; + } + if (subprotocol_list.empty()) { + DCHECK(subprotocol.empty()); + } else { + if (!subprotocol.empty()) { + if (subprotocol_list.end() == std::find( + subprotocol_list.begin(), subprotocol_list.end(), subprotocol)) { + NOTREACHED() << "delegate must pick subprotocol from given list"; + return net::ERR_UNEXPECTED; + } + } + } + + uint32 key_number1 = 0; + uint32 key_number2 = 0; + if (!FetchDecimalDigits(fields.Get(kKey1FieldName), &key_number1) || + !FetchDecimalDigits(fields.Get(kKey2FieldName), &key_number2)) { + return net::ERR_WS_PROTOCOL_ERROR; + } + + // We limit incoming header size so following numbers shall not be too high. + int spaces1 = CountSpaces(fields.Get(kKey1FieldName)); + int spaces2 = CountSpaces(fields.Get(kKey2FieldName)); + if (spaces1 == 0 || + spaces2 == 0 || + key_number1 % spaces1 != 0 || + key_number2 % spaces2 != 0) { + return net::ERR_WS_PROTOCOL_ERROR; + } + + char challenge[4 + 4 + sizeof(key3)]; + int32 part1 = htonl(key_number1 / spaces1); + int32 part2 = htonl(key_number2 / spaces2); + memcpy(challenge, &part1, 4); + memcpy(challenge + 4, &part2, 4); + memcpy(challenge + 4 + 4, key3, sizeof(key3)); + MD5Digest challenge_response; + MD5Sum(challenge, sizeof(challenge), &challenge_response); + + // Concocting response handshake. + class Buffer { + public: + Buffer() + : io_buf_(new net::IOBuffer(kHandshakeLimitBytes)), + bytes_written_(0), + is_ok_(true) { + } + + bool Write(const void* p, int len) { + DCHECK(p); + DCHECK_GE(len, 0); + if (!is_ok_) + return false; + if (bytes_written_ + len > kHandshakeLimitBytes) { + NOTREACHED(); + is_ok_ = false; + return false; + } + memcpy(io_buf_->data() + bytes_written_, p, len); + bytes_written_ += len; + return true; + } + + bool WriteLine(const char* p) { + return Write(p, strlen(p)) && Write(kCRLF, strlen(kCRLF)); + } + + operator net::DrainableIOBuffer*() { + return new net::DrainableIOBuffer(io_buf_, bytes_written_); + } + + bool is_ok() { return is_ok_; } + + private: + net::IOBuffer* io_buf_; + size_t bytes_written_; + bool is_ok_; + } buffer; + + buffer.WriteLine("HTTP/1.1 101 WebSocket Protocol Handshake"); + buffer.WriteLine("Upgrade: WebSocket"); + buffer.WriteLine("Connection: Upgrade"); + + { + // Take care of Location field. + char tmp[2048]; + int rv = base::snprintf(tmp, sizeof(tmp), + "%s: %s", + kLocationFieldName, + location.c_str()); + if (rv <= 0 || rv + 0u >= sizeof(tmp)) + return net::ERR_LIMIT_VIOLATION; + buffer.WriteLine(tmp); + } + { + // Take care of Origin field. + char tmp[2048]; + int rv = base::snprintf(tmp, sizeof(tmp), + "%s: %s", + kOriginFieldName, + fields.Get(kPlainOriginFieldName).c_str()); + if (rv <= 0 || rv + 0u >= sizeof(tmp)) + return net::ERR_LIMIT_VIOLATION; + buffer.WriteLine(tmp); + } + if (!subprotocol.empty()) { + char tmp[2048]; + int rv = base::snprintf(tmp, sizeof(tmp), + "%s: %s", + kProtocolFieldName, + subprotocol.c_str()); + if (rv <= 0 || rv + 0u >= sizeof(tmp)) + return net::ERR_LIMIT_VIOLATION; + buffer.WriteLine(tmp); + } + buffer.WriteLine(""); + buffer.Write(&challenge_response, sizeof(challenge_response)); + + if (!buffer.is_ok()) + return net::ERR_LIMIT_VIOLATION; + + pending_reqs_.push_back(PendingReq( + PendingReq::TYPE_WRITE_METADATA, buffer, NULL)); + DCHECK_GT(term_pos - buf, 0); + return term_pos - buf; + } + + // Removes frame delimiters and returns net number of data bytes (or error). + // |out| may be equal to |buf|, in that case it is in-place operation. + int ProcessDataFrames(char* buf, int buf_len, char* out, int out_len) { + if (out_len < buf_len) { + NOTREACHED(); + return net::ERR_UNEXPECTED; + } + int out_pos = 0; + for (char* p = buf; p < buf + buf_len; ++p) { + switch (phase_) { + case PHASE_FRAME_INSIDE: { + if (*p == '\x00') + return net::ERR_WS_PROTOCOL_ERROR; + if (*p == '\xff') + phase_ = PHASE_FRAME_OUTSIDE; + else + out[out_pos++] = *p; + break; + } + case PHASE_FRAME_OUTSIDE: { + if (*p == '\x00') { + phase_ = PHASE_FRAME_INSIDE; + } else if (*p == '\xff') { + phase_ = PHASE_FRAME_LENGTH; + frame_bytes_remaining_ = 0; + } + else { + return net::ERR_WS_PROTOCOL_ERROR; + } + break; + } + case PHASE_FRAME_LENGTH: { + static const int kValueBits = 7; + static const char kValueMask = (1 << kValueBits) - 1; + frame_bytes_remaining_ <<= kValueBits; + frame_bytes_remaining_ += (*p & kValueMask); + if (*p & ~kValueMask) { + // Check that next byte would not overflow. + if (frame_bytes_remaining_ > + (std::numeric_limits<int>::max() - ((1 << 7) - 1)) >> 7) { + return net::ERR_LIMIT_VIOLATION; + } + } else { + if (frame_bytes_remaining_ == 0) { + phase_ = PHASE_SHUT; + return out_pos; + } else { + phase_ = PHASE_FRAME_SKIP; + } + } + break; + } + case PHASE_FRAME_SKIP: { + DCHECK_GE(frame_bytes_remaining_, 1); + frame_bytes_remaining_ -= 1; + if (frame_bytes_remaining_ < 1) + phase_ = PHASE_FRAME_OUTSIDE; + break; + } + default: { + NOTREACHED(); + } + } + } + return out_pos; + } + + // State machinery. + Phase phase_; + + // Counts frame length for PHASE_FRAME_LENGTH and PHASE_FRAME_SKIP. + int frame_bytes_remaining_; + + // Underlying socket. + scoped_ptr<net::Socket> transport_socket_; + + // Validation is performed via delegate. + Delegate* delegate_; + + // IOBuffer used to communicate with transport at initial stage. + scoped_refptr<net::IOBuffer> handshake_buf_; + scoped_refptr<net::DrainableIOBuffer> fill_handshake_buf_; + scoped_refptr<net::DrainableIOBuffer> process_handshake_buf_; + + // Pending io requests we need to complete. + std::deque<PendingReq> pending_reqs_; + + // Callbacks from transport to us. + scoped_ptr<net::CompletionCallback> transport_read_callback_; + scoped_ptr<net::CompletionCallback> transport_write_callback_; + + // Whether transport requests are pending. + bool is_transport_read_pending_; + bool is_transport_write_pending_; + + ScopedRunnableMethodFactory<WebSocketServerSocketImpl> method_factory_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketServerSocketImpl); +}; + +} // namespace + +namespace net { + +WebSocketServerSocket* CreateWebSocketServerSocket( + Socket* transport_socket, WebSocketServerSocket::Delegate* delegate) { + return new WebSocketServerSocketImpl(transport_socket, delegate); +} + +WebSocketServerSocket::~WebSocketServerSocket() { +} + +} // namespace net; diff --git a/net/socket/web_socket_server_socket.h b/net/socket/web_socket_server_socket.h new file mode 100644 index 0000000..c28c58d --- /dev/null +++ b/net/socket/web_socket_server_socket.h @@ -0,0 +1,63 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_WEB_SOCKET_SERVER_SOCKET_H_ +#define NET_SOCKET_WEB_SOCKET_SERVER_SOCKET_H_ + +#include <string> +#include <vector> + +#include "net/base/net_api.h" +#include "net/socket/socket.h" + +namespace net { + +// WebSocketServerSocket takes an (already connected) underlying transport +// socket and speaks server-side websocket protocol atop of it. +// This class implements Socket interface: notice that Read() returns (or calls +// back) as soon as some amount of data is available, even if message/frame is +// incomplete. +class WebSocketServerSocket : public Socket { + public: + class Delegate { + public: + // Validates websocket handshake: return false to reject handshake. + // |resource| is name of resource requested in GET stanza of handshake; + // |origin| is origin as reported in handshake; + // |host| is Host field from handshake; + // |subprotocol_list| is derived from Sec-WebSocket-Protocol field. + // Output parameters are: + // |location_out| is location of websocket server; + // |subprotocol_out| is selected subprotocol (or empty string if subprotocol + // list is empty. + virtual bool ValidateWebSocket( + const std::string& resource, + const std::string& origin, + const std::string& host, + const std::vector<std::string>& subprotocol_list, + std::string* location_out, + std::string* subprotocol_out) = 0; + + virtual ~Delegate() {} + }; + + virtual ~WebSocketServerSocket(); + + // Performs websocket server handshake on transport socket. Underlying socket + // must have already been connected/accepted. + // + // Returns either ERR_IO_PENDING, in which case the given callback will be + // called in the future with the real result, or it completes synchronously, + // returning the result immediately. + virtual int Accept(CompletionCallback* callback) = 0; +}; + +// Creates websocket server socket atop of already connected socket. This +// created server socket will take ownership of |transport_socket|. +NET_API WebSocketServerSocket* CreateWebSocketServerSocket( + Socket* transport_socket, WebSocketServerSocket::Delegate* delegate); + +} // namespace net + +#endif // NET_SOCKET_WEB_SOCKET_SERVER_SOCKET_H_ diff --git a/net/socket/web_socket_server_socket_unittest.cc b/net/socket/web_socket_server_socket_unittest.cc new file mode 100644 index 0000000..20e3dcf --- /dev/null +++ b/net/socket/web_socket_server_socket_unittest.cc @@ -0,0 +1,596 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/web_socket_server_socket.h" + +#include <stdlib.h> +#include <algorithm> + +#include "base/callback_old.h" +#include "base/memory/ref_counted.h" +#include "base/message_loop.h" +#include "base/string_util.h" +#include "base/task.h" +#include "base/time.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace { + +const char* kSampleHandshakeRequest[] = { + "GET /demo HTTP/1.1", + "Upgrade: WebSocket", + "Connection: Upgrade", + "Host: example.com", + "Origin: http://example.com", + "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5", + "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00", + "", + "^n:ds[4U" +}; + +const char kSampleHandshakeAnswer[] = + "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Location: ws://example.com/demo\r\n" + "Sec-WebSocket-Origin: http://example.com\r\n" + "\r\n" + "8jKS'y:G*Co,Wxa-"; + +const int kHandshakeBufBytes = 1 << 12; + +const char kCRLF[] = "\r\n"; +const char kCRLFCRLF[] = "\r\n\r\n"; +const char kSpaceOctet = '\x20'; + +const int kReadSalt = 7; +const int kWriteSalt = 5; + +int GetRand(int min, int max) { + CHECK(max >= min); + CHECK(max - min < RAND_MAX); + return rand() % (max - min + 1) + min; +} + +class RandIntClass { + public: + int operator() (int range) { + return GetRand(0, range - 1); + } +} g_rand; + +net::DrainableIOBuffer* ResizeIOBuffer(net::DrainableIOBuffer* buf, int len) { + net::DrainableIOBuffer* rv = new net::DrainableIOBuffer( + new net::IOBuffer(len), len); + std::copy(buf->data(), buf->data() + std::min(len, buf->BytesRemaining()), + rv->data()); + return rv; +} + +// TODO(dilmah): consider switching to socket_test_util.h +// Simulates reading from |sample| stream; data supplied in Write() calls are +// stored in |answer| buffer. +class TestingTransportSocket : public net::Socket { + public: + TestingTransportSocket( + net::DrainableIOBuffer* sample, net::DrainableIOBuffer* answer) + : sample_(sample), + answer_(answer), + final_read_callback_(NULL), + method_factory_(this) { + } + + ~TestingTransportSocket() { + if (final_read_callback_) { + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( + &TestingTransportSocket::DoReadCallback, + final_read_callback_, 0)); + } + } + + // Socket implementation. + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) { + CHECK_GT(buf_len, 0); + int remaining = sample_->BytesRemaining(); + if (remaining < 1) { + if (final_read_callback_) + return 0; + final_read_callback_ = callback; + return net::ERR_IO_PENDING; + } + int lot = GetRand(1, std::min(remaining, buf_len)); + std::copy(sample_->data(), sample_->data() + lot, buf->data()); + sample_->DidConsume(lot); + if (GetRand(0, 1)) { + return lot; + } + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( + &TestingTransportSocket::DoReadCallback, callback, lot)); + return net::ERR_IO_PENDING; + } + + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) { + CHECK_GT(buf_len, 0); + int remaining = answer_->BytesRemaining(); + CHECK_GE(remaining, buf_len); + int lot = std::min(remaining, buf_len); + if (GetRand(0, 1)) + lot = GetRand(1, lot); + std::copy(buf->data(), buf->data() + lot, answer_->data()); + answer_->DidConsume(lot); + if (GetRand(0, 1)) { + return lot; + } + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( + &TestingTransportSocket::DoWriteCallback, callback, lot)); + return net::ERR_IO_PENDING; + } + + virtual bool SetReceiveBufferSize(int32 size) { + return true; + } + + virtual bool SetSendBufferSize(int32 size) { + return true; + } + + net::DrainableIOBuffer* answer() { return answer_.get(); } + + void DoReadCallback(net::CompletionCallback* callback, int result) { + if (result == 0 && !is_closed_) { + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( + &TestingTransportSocket::DoReadCallback, callback, 0)); + } else { + if (callback) + callback->Run(result); + } + } + + void DoWriteCallback(net::CompletionCallback* callback, int result) { + if (callback) + callback->Run(result); + } + + bool is_closed_; + + // Data to return for Read requests. + scoped_refptr<net::DrainableIOBuffer> sample_; + + // Data pushed to us by server socket (using Write calls). + scoped_refptr<net::DrainableIOBuffer> answer_; + + // Final read callback to report zero (zero stands for EOF). + net::CompletionCallback* final_read_callback_; + + ScopedRunnableMethodFactory<TestingTransportSocket> method_factory_; +}; + +class Validator : public net::WebSocketServerSocket::Delegate { + public: + Validator(const std::string& resource, + const std::string& origin, + const std::string& host) + : resource_(resource), origin_(origin), host_(host) { + } + + // WebSocketServerSocket::Delegate implementation. + virtual bool ValidateWebSocket( + const std::string& resource, + const std::string& origin, + const std::string& host, + const std::vector<std::string>& subprotocol_list, + std::string* location_out, + std::string* subprotocol_out) { + if (resource != resource_ || origin != origin_ || host != host_) + return false; + if (!subprotocol_list.empty()) + *subprotocol_out = subprotocol_list.front(); + + char tmp[2048]; + base::snprintf( + tmp, sizeof(tmp), "ws://%s%s", host.c_str(), resource.c_str()); + location_out->assign(tmp); + return true; + } + + private: + std::string resource_; + std::string origin_; + std::string host_; +}; + +char ReferenceSeq(unsigned n, unsigned salt) { + return (salt * 2 + n * 3) % ('z' - 'a') + 'a'; +} + +class ReadWriteTracker { + public: + ReadWriteTracker( + net::WebSocketServerSocket* ws, int bytes_to_read, int bytes_to_write) + : ws_(ws), + buf_size_(1 << 14), + accept_callback_(NewCallback(this, &ReadWriteTracker::OnAccept)), + read_callback_(NewCallback(this, &ReadWriteTracker::OnRead)), + write_callback_(NewCallback(this, &ReadWriteTracker::OnWrite)), + read_buf_(new net::IOBuffer(buf_size_)), + write_buf_(new net::IOBuffer(buf_size_)), + bytes_remaining_to_read_(bytes_to_read), + bytes_remaining_to_write_(bytes_to_write), + read_initiated_(false), + write_initiated_(false), + got_final_zero_(false) { + int rv = ws_->Accept(accept_callback_.get()); + if (rv != net::ERR_IO_PENDING) + OnAccept(rv); + } + + ~ReadWriteTracker() { + CHECK_EQ(bytes_remaining_to_write_, 0); + CHECK_EQ(bytes_remaining_to_read_, 0); + } + + void OnAccept(int result) { + ASSERT_EQ(result, 0); + if (GetRand(0, 1)) { + DoRead(); + DoWrite(); + } else { + DoWrite(); + DoRead(); + } + } + + void DoWrite() { + if (bytes_remaining_to_write_ < 1) + return; + int lot = GetRand(1, bytes_remaining_to_write_); + lot = std::min(lot, buf_size_); + for (int i = 0; i < lot; ++i) + write_buf_->data()[i] = ReferenceSeq( + bytes_remaining_to_write_ - i - 1, kWriteSalt); + int rv = ws_->Write(write_buf_, lot, write_callback_.get()); + if (rv != net::ERR_IO_PENDING) + OnWrite(rv); + } + + void DoRead() { + int lot = GetRand(1, buf_size_); + if (bytes_remaining_to_read_ < 1) { + if (got_final_zero_) + return; + } else { + lot = GetRand(1, bytes_remaining_to_read_); + lot = std::min(lot, buf_size_); + } + int rv = ws_->Read(read_buf_, lot, read_callback_.get()); + if (rv != net::ERR_IO_PENDING) + OnRead(rv); + } + + void OnWrite(int result) { + ASSERT_GT(result, 0); + ASSERT_LE(result, bytes_remaining_to_write_); + bytes_remaining_to_write_ -= result; + DoWrite(); + } + + void OnRead(int result) { + ASSERT_LE(result, bytes_remaining_to_read_); + if (bytes_remaining_to_read_ < 1) { + ASSERT_FALSE(got_final_zero_); + ASSERT_EQ(result, 0); + got_final_zero_ = true; + return; + } + for (int i = 0; i < result; ++i) { + ASSERT_EQ(read_buf_->data()[i], ReferenceSeq( + bytes_remaining_to_read_ - i - 1, kReadSalt)); + } + bytes_remaining_to_read_ -= result; + DoRead(); + } + + private: + net::WebSocketServerSocket* const ws_; + int const buf_size_; + scoped_ptr<net::CompletionCallback> accept_callback_; + scoped_ptr<net::CompletionCallback> read_callback_; + scoped_ptr<net::CompletionCallback> write_callback_; + scoped_refptr<net::IOBuffer> read_buf_; + scoped_refptr<net::IOBuffer> write_buf_; + int bytes_remaining_to_read_; + int bytes_remaining_to_write_; + bool read_initiated_; + bool write_initiated_; + bool got_final_zero_; +}; + +} // namespace + +namespace net { + +class WebSocketServerSocketTest : public testing::Test { + public: + virtual ~WebSocketServerSocketTest() { + } + + virtual void SetUp() { + count_ = 0; + accept_callback_[0].reset(NewCallback<WebSocketServerSocketTest, int>( + this, &WebSocketServerSocketTest::OnAccept0)); + accept_callback_[1].reset(NewCallback<WebSocketServerSocketTest, int>( + this, &WebSocketServerSocketTest::OnAccept1)); + } + + virtual void TearDown() { + } + + void OnAccept0(int result) { + ASSERT_EQ(result, 0); + ASSERT_LT(count_, 99999); + count_ += 1; + } + + void OnAccept1(int result) { + ASSERT_TRUE(result == ERR_CONNECTION_REFUSED || + result == ERR_ACCESS_DENIED); + ASSERT_LT(count_, 99999); + count_ += 1; + } + + int count_; + scoped_ptr<net::CompletionCallback> accept_callback_[2]; +}; + +TEST_F(WebSocketServerSocketTest, Handshake) { + srand(2523456); + std::vector<Socket*> kill_list; + std::vector< scoped_refptr<DrainableIOBuffer> > answer_list; + Validator validator("/demo", "http://example.com/", "example.com"); + count_ = 0; + const int kNumTests = 300; + for (int run = kNumTests; run--;) { + scoped_refptr<DrainableIOBuffer> sample = new DrainableIOBuffer( + new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes); + for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) { + std::copy(kSampleHandshakeRequest[i], + kSampleHandshakeRequest[i] + strlen(kSampleHandshakeRequest[i]), + sample->data()); + sample->DidConsume(strlen(kSampleHandshakeRequest[i])); + if (i != arraysize(kSampleHandshakeRequest) - 1) { + std::copy(kCRLF, kCRLF + strlen(kCRLF), sample->data()); + sample->DidConsume(strlen(kCRLF)); + } + } + int sample_len = sample->BytesConsumed(); + sample->SetOffset(0); + DrainableIOBuffer* answer = new DrainableIOBuffer( + new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes); + answer_list.push_back(answer); + TestingTransportSocket* transport = new TestingTransportSocket( + ResizeIOBuffer(sample.get(), sample_len), answer); + WebSocketServerSocket* ws = CreateWebSocketServerSocket( + transport, &validator); + ASSERT_TRUE(ws != NULL); + kill_list.push_back(ws); + + int rv = ws->Accept(accept_callback_[0].get()); + if (rv != ERR_IO_PENDING) + OnAccept0(rv); + } + MessageLoop::current()->RunAllPending(); + ASSERT_EQ(count_, kNumTests); + for (size_t i = answer_list.size(); i--;) { + ASSERT_EQ(answer_list[i]->BytesConsumed() + 0u, + strlen(kSampleHandshakeAnswer)); + ASSERT_TRUE(std::equal( + answer_list[i]->data() - answer_list[i]->BytesConsumed(), + answer_list[i]->data(), kSampleHandshakeAnswer)); + } + for (size_t i = kill_list.size(); i--;) + delete kill_list[i]; + MessageLoop::current()->RunAllPending(); +} + +TEST_F(WebSocketServerSocketTest, BadCred) { + srand(9034958); + std::vector<Socket*> kill_list; + std::vector< scoped_refptr<DrainableIOBuffer> > answer_list; + Validator *validator[] = { + new Validator("/demo", "http://gooogle.com/", "example.com"), + new Validator("/tcpproxy", "http://example.com/", "example.com"), + new Validator("/tcpproxy", "http://gooogle.com/", "example.com"), + new Validator("/demo", "http://example.com/", "exmple.com"), + new Validator("/demo", "http://gooogle.com/", "gooogle.com") + }; + count_ = 0; + for (int run = arraysize(validator); run--;) { + scoped_refptr<DrainableIOBuffer> sample = new DrainableIOBuffer( + new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes); + for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) { + std::copy(kSampleHandshakeRequest[i], + kSampleHandshakeRequest[i] + strlen(kSampleHandshakeRequest[i]), + sample->data()); + sample->DidConsume(strlen(kSampleHandshakeRequest[i])); + if (i != arraysize(kSampleHandshakeRequest) - 1) { + std::copy(kCRLF, kCRLF + strlen(kCRLF), sample->data()); + sample->DidConsume(strlen(kCRLF)); + } + } + int sample_len = sample->BytesConsumed(); + sample->SetOffset(0); + DrainableIOBuffer* answer = new DrainableIOBuffer( + new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes); + answer_list.push_back(answer); + TestingTransportSocket* transport = new TestingTransportSocket( + ResizeIOBuffer(sample.get(), sample_len), answer); + WebSocketServerSocket* ws = CreateWebSocketServerSocket( + transport, validator[run]); + ASSERT_TRUE(ws != NULL); + kill_list.push_back(ws); + + int rv = ws->Accept(accept_callback_[1].get()); + if (rv != ERR_IO_PENDING) + OnAccept1(rv); + } + MessageLoop::current()->RunAllPending(); + ASSERT_EQ(count_ + 0u, arraysize(validator)); + for (size_t i = answer_list.size(); i--;) + ASSERT_EQ(answer_list[i]->BytesConsumed(), 0); + for (size_t i = kill_list.size(); i--;) + delete kill_list[i]; + for (size_t i = arraysize(validator); i--;) + delete validator[i]; + MessageLoop::current()->RunAllPending(); +} + +TEST_F(WebSocketServerSocketTest, ReorderedHandshake) { + srand(205643459); + std::vector<Socket*> kill_list; + std::vector< scoped_refptr<DrainableIOBuffer> > answer_list; + Validator validator("/demo", "http://example.com/", "example.com"); + count_ = 0; + const int kNumTests = 200; + for (int run = kNumTests; run--;) { + scoped_refptr<DrainableIOBuffer> sample = new DrainableIOBuffer( + new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes); + + std::vector<size_t> fields_order; + for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) + fields_order.push_back(i); + // One leading and two trailing lines of request are special, leave them. + std::random_shuffle(fields_order.begin() + 1, + fields_order.begin() + fields_order.size() - 3, + g_rand); + + for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) { + size_t j = fields_order[i]; + std::copy(kSampleHandshakeRequest[j], + kSampleHandshakeRequest[j] + strlen(kSampleHandshakeRequest[j]), + sample->data()); + sample->DidConsume(strlen(kSampleHandshakeRequest[j])); + if (i != arraysize(kSampleHandshakeRequest) - 1) { + std::copy(kCRLF, kCRLF + strlen(kCRLF), sample->data()); + sample->DidConsume(strlen(kCRLF)); + } + } + int sample_len = sample->BytesConsumed(); + sample->SetOffset(0); + DrainableIOBuffer* answer = new DrainableIOBuffer( + new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes); + answer_list.push_back(answer); + TestingTransportSocket* transport = new TestingTransportSocket( + ResizeIOBuffer(sample.get(), sample_len), answer); + WebSocketServerSocket* ws = CreateWebSocketServerSocket( + transport, &validator); + ASSERT_TRUE(ws != NULL); + kill_list.push_back(ws); + + int rv = ws->Accept(accept_callback_[0].get()); + if (rv != ERR_IO_PENDING) + OnAccept0(rv); + } + MessageLoop::current()->RunAllPending(); + ASSERT_EQ(count_, kNumTests); + for (size_t i = answer_list.size(); i--;) { + ASSERT_EQ(answer_list[i]->BytesConsumed() + 0u, + strlen(kSampleHandshakeAnswer)); + ASSERT_TRUE(std::equal( + answer_list[i]->data() - answer_list[i]->BytesConsumed(), + answer_list[i]->data(), kSampleHandshakeAnswer)); + } + for (size_t i = kill_list.size(); i--;) + delete kill_list[i]; + MessageLoop::current()->RunAllPending(); +} + +TEST_F(WebSocketServerSocketTest, ConveyData) { + srand(8234523); + std::vector<Socket*> kill_list; + std::vector<ReadWriteTracker*> tracker_list; + Validator validator("/demo", "http://example.com/", "example.com"); + count_ = 0; + const int kNumTests = 150; + for (int run = kNumTests; run--;) { + int bytes_to_read = GetRand(1, 1 << 14); + int bytes_to_write = GetRand(1, 1 << 14); + int frames_limit = GetRand(1, 1 << 10); + int sample_limit = kHandshakeBufBytes + bytes_to_write + frames_limit * 2; + scoped_refptr<DrainableIOBuffer> sample = new DrainableIOBuffer( + new IOBuffer(sample_limit), sample_limit); + + std::vector<size_t> fields_order; + for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) + fields_order.push_back(i); + // One leading and two trailing lines of request are special, leave them. + std::random_shuffle(fields_order.begin() + 1, + fields_order.begin() + fields_order.size() - 3, + g_rand); + + for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) { + size_t j = fields_order[i]; + std::copy(kSampleHandshakeRequest[j], + kSampleHandshakeRequest[j] + strlen(kSampleHandshakeRequest[j]), + sample->data()); + sample->DidConsume(strlen(kSampleHandshakeRequest[j])); + if (i != arraysize(kSampleHandshakeRequest) - 1) { + std::copy(kCRLF, kCRLF + strlen(kCRLF), sample->data()); + sample->DidConsume(strlen(kCRLF)); + } + } + { + bool outside_frame = true; + int pos = 0; + for (int i = 0; i < bytes_to_write; ++i) { + if (outside_frame) { + sample->data()[pos++] = '\x00'; + outside_frame = false; + CHECK_GE(frames_limit, 1); + frames_limit -= 1; + } + sample->data()[pos++] = ReferenceSeq(bytes_to_write - i - 1, kReadSalt); + if ((frames_limit > 1 && + GetRand(0, 1 + (bytes_to_write - i) / frames_limit) == 0) || + i == bytes_to_write - 1) { + sample->data()[pos++] = '\xff'; + outside_frame = true; + } + } + sample->DidConsume(pos); + } + + int sample_len = sample->BytesConsumed(); + sample->SetOffset(0); + int answer_limit = kHandshakeBufBytes + bytes_to_read * 3; + DrainableIOBuffer* answer = new DrainableIOBuffer( + new IOBuffer(answer_limit), answer_limit); + TestingTransportSocket* transport = new TestingTransportSocket( + ResizeIOBuffer(sample.get(), sample_len), answer); + WebSocketServerSocket* ws = CreateWebSocketServerSocket( + transport, &validator); + ASSERT_TRUE(ws != NULL); + kill_list.push_back(ws); + + ReadWriteTracker* tracker = new ReadWriteTracker( + ws, bytes_to_write, bytes_to_read); + tracker_list.push_back(tracker); + } + MessageLoop::current()->RunAllPending(); + + for (size_t i = kill_list.size(); i--;) + delete kill_list[i]; + for (size_t i = tracker_list.size(); i--;) + delete tracker_list[i]; + MessageLoop::current()->RunAllPending(); +} + +} // namespace net |