summaryrefslogtreecommitdiffstats
path: root/net/websockets
diff options
context:
space:
mode:
authorukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2009-10-21 05:36:17 +0000
committerukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2009-10-21 05:36:17 +0000
commitdb6ffa59f24996886615578c5791d17fc51b423e (patch)
treead81b60465fb69dbd72de17a503faea8c1665a51 /net/websockets
parent710572ce32772f31ebee4d7a1bc48734b0e69163 (diff)
downloadchromium_src-db6ffa59f24996886615578c5791d17fc51b423e.zip
chromium_src-db6ffa59f24996886615578c5791d17fc51b423e.tar.gz
chromium_src-db6ffa59f24996886615578c5791d17fc51b423e.tar.bz2
WebSocket protocol handler for live experiment.
This is in-browser-process WebSocket protocol handler, which will be used in WebSocket live experiment. BUG=none TEST=net_unittests passes Review URL: http://codereview.chromium.org/304014 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@29614 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net/websockets')
-rw-r--r--net/websockets/websocket.cc441
-rw-r--r--net/websockets/websocket.h213
-rw-r--r--net/websockets/websocket_unittest.cc214
3 files changed, 868 insertions, 0 deletions
diff --git a/net/websockets/websocket.cc b/net/websockets/websocket.cc
new file mode 100644
index 0000000..f95a9cc
--- /dev/null
+++ b/net/websockets/websocket.cc
@@ -0,0 +1,441 @@
+// Copyright (c) 2009 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <algorithm>
+#include <limits>
+
+#include "net/websockets/websocket.h"
+
+#include "base/message_loop.h"
+#include "net/http/http_response_headers.h"
+#include "net/http/http_util.h"
+
+namespace net {
+
+static const int kWebSocketPort = 80;
+static const int kSecureWebSocketPort = 443;
+
+static const char kServerHandshakeHeader[] =
+ "HTTP/1.1 101 Web Socket Protocol\r\n";
+static const size_t kServerHandshakeHeaderLength =
+ sizeof(kServerHandshakeHeader) - 1;
+
+static const char kUpgradeHeader[] = "Upgrade: WebSocket\r\n";
+static const size_t kUpgradeHeaderLength = sizeof(kUpgradeHeader) - 1;
+
+static const char kConnectionHeader[] = "Connection: Upgrade\r\n";
+static const size_t kConnectionHeaderLength = sizeof(kConnectionHeader) - 1;
+
+bool WebSocket::Request::is_secure() const {
+ return url_.SchemeIs("wss");
+}
+
+WebSocket::WebSocket(Request* request, Delegate* delegate)
+ : ready_state_(INITIALIZED),
+ mode_(MODE_INCOMPLETE),
+ request_(request),
+ delegate_(delegate),
+ origin_loop_(MessageLoop::current()),
+ socket_stream_(NULL),
+ max_pending_send_allowed_(0),
+ current_read_buf_(NULL),
+ read_consumed_len_(0),
+ current_write_buf_(NULL) {
+ DCHECK(request_.get());
+ DCHECK(delegate_);
+ DCHECK(origin_loop_);
+}
+
+WebSocket::~WebSocket() {
+ DCHECK(!delegate_);
+ DCHECK(!socket_stream_);
+}
+
+void WebSocket::Connect() {
+ DCHECK(ready_state_ == INITIALIZED);
+ DCHECK(request_.get());
+ DCHECK(delegate_);
+ DCHECK(!socket_stream_);
+ DCHECK(MessageLoop::current() == origin_loop_);
+
+ socket_stream_ = new SocketStream(request_->url(), this);
+ socket_stream_->set_context(request_->context());
+
+ if (request_->host_resolver())
+ socket_stream_->SetHostResolver(request_->host_resolver());
+ if (request_->client_socket_factory())
+ socket_stream_->SetClientSocketFactory(request_->client_socket_factory());
+
+ ready_state_ = CONNECTING;
+ socket_stream_->Connect();
+}
+
+void WebSocket::Send(const std::string& msg) {
+ DCHECK(ready_state_ == OPEN);
+ DCHECK(MessageLoop::current() == origin_loop_);
+
+ IOBufferWithSize* buf = new IOBufferWithSize(msg.size() + 2);
+ char* p = buf->data();
+ *p = '\0';
+ memcpy(p + 1, msg.data(), msg.size());
+ *(p + 1 + msg.size()) = '\xff';
+ pending_write_bufs_.push_back(buf);
+ SendPending();
+}
+
+void WebSocket::Close() {
+ DCHECK(MessageLoop::current() == origin_loop_);
+
+ if (ready_state_ == INITIALIZED) {
+ DCHECK(!socket_stream_);
+ ready_state_ = CLOSED;
+ return;
+ }
+ if (ready_state_ != CLOSED) {
+ DCHECK(socket_stream_);
+ socket_stream_->Close();
+ return;
+ }
+}
+
+void WebSocket::OnConnected(SocketStream* socket_stream,
+ int max_pending_send_allowed) {
+ DCHECK(socket_stream == socket_stream_);
+ max_pending_send_allowed_ = max_pending_send_allowed;
+
+ // Use |max_pending_send_allowed| as hint for initial size of read buffer.
+ current_read_buf_ = new GrowableIOBuffer();
+ current_read_buf_->set_capacity(max_pending_send_allowed_);
+ read_consumed_len_ = 0;
+
+ DCHECK(!current_write_buf_);
+ pending_write_bufs_.push_back(CreateClientHandshakeMessage());
+ origin_loop_->PostTask(FROM_HERE,
+ NewRunnableMethod(this, &WebSocket::SendPending));
+}
+
+void WebSocket::OnSentData(SocketStream* socket_stream, int amount_sent) {
+ DCHECK(socket_stream == socket_stream_);
+ DCHECK(current_write_buf_);
+ current_write_buf_->DidConsume(amount_sent);
+ DCHECK_GE(current_write_buf_->BytesRemaining(), 0);
+ if (current_write_buf_->BytesRemaining() == 0) {
+ current_write_buf_ = NULL;
+ pending_write_bufs_.pop_front();
+ }
+ origin_loop_->PostTask(FROM_HERE,
+ NewRunnableMethod(this, &WebSocket::SendPending));
+}
+
+void WebSocket::OnReceivedData(SocketStream* socket_stream,
+ const char* data, int len) {
+ DCHECK(socket_stream == socket_stream_);
+ DCHECK(current_read_buf_);
+ // Check if |current_read_buf_| has enough space to store |len| of |data|.
+ if (len >= current_read_buf_->RemainingCapacity()) {
+ current_read_buf_->set_capacity(
+ current_read_buf_->offset() + len);
+ }
+
+ DCHECK(current_read_buf_->RemainingCapacity() >= len);
+ memcpy(current_read_buf_->data(), data, len);
+ current_read_buf_->set_offset(current_read_buf_->offset() + len);
+
+ origin_loop_->PostTask(FROM_HERE,
+ NewRunnableMethod(this, &WebSocket::DoReceivedData));
+}
+
+void WebSocket::OnClose(SocketStream* socket_stream) {
+ origin_loop_->PostTask(FROM_HERE,
+ NewRunnableMethod(this, &WebSocket::DoClose));
+}
+
+IOBufferWithSize* WebSocket::CreateClientHandshakeMessage() const {
+ std::string msg;
+ msg = "GET ";
+ msg += request_->url().path();
+ if (request_->url().has_query()) {
+ msg += "?";
+ msg += request_->url().query();
+ }
+ msg += " HTTP/1.1\r\n";
+ msg += kUpgradeHeader;
+ msg += kConnectionHeader;
+ msg += "Host: ";
+ msg += StringToLowerASCII(request_->url().host());
+ if (request_->url().has_port()) {
+ bool secure = request_->is_secure();
+ int port = request_->url().EffectiveIntPort();
+ if ((!secure &&
+ port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) ||
+ (secure &&
+ port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) {
+ msg += ":";
+ msg += IntToString(port);
+ }
+ }
+ msg += "\r\n";
+ msg += "Origin: ";
+ msg += StringToLowerASCII(request_->origin());
+ msg += "\r\n";
+ if (!request_->protocol().empty()) {
+ msg += "WebSocket-Protocol: ";
+ msg += request_->protocol();
+ msg += "\r\n";
+ }
+ // TODO(ukai): Add cookie if necessary.
+ msg += "\r\n";
+ IOBufferWithSize* buf = new IOBufferWithSize(msg.size());
+ memcpy(buf->data(), msg.data(), msg.size());
+ return buf;
+}
+
+int WebSocket::CheckHandshake() {
+ DCHECK(current_read_buf_);
+ DCHECK(ready_state_ == CONNECTING);
+ mode_ = MODE_INCOMPLETE;
+ const char *start = current_read_buf_->StartOfBuffer() + read_consumed_len_;
+ const char *p = start;
+ size_t len = current_read_buf_->offset() - read_consumed_len_;
+ if (len < kServerHandshakeHeaderLength) {
+ return -1;
+ }
+ if (!memcmp(p, kServerHandshakeHeader, kServerHandshakeHeaderLength)) {
+ mode_ = MODE_NORMAL;
+ } else {
+ int eoh = HttpUtil::LocateEndOfHeaders(p, len);
+ if (eoh < 0)
+ return -1;
+ scoped_refptr<HttpResponseHeaders> headers(
+ new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(p, eoh)));
+ if (headers->response_code() == 401) {
+ mode_ = MODE_AUTHENTICATE;
+ // TODO(ukai): Implement authentication handlers.
+ }
+ // Invalid response code.
+ ready_state_ = CLOSED;
+ return eoh;
+ }
+ const char* end = p + len + 1;
+ p += kServerHandshakeHeaderLength;
+
+ if (mode_ == MODE_NORMAL) {
+ size_t header_size = end - p;
+ if (header_size < kUpgradeHeaderLength)
+ return -1;
+ if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) {
+ ready_state_ = CLOSED;
+ return p - start;
+ }
+ p += kUpgradeHeaderLength;
+
+ header_size = end - p;
+ if (header_size < kConnectionHeaderLength)
+ return -1;
+ if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) {
+ ready_state_ = CLOSED;
+ return p - start;
+ }
+ p += kConnectionHeaderLength;
+ }
+ int eoh = HttpUtil::LocateEndOfHeaders(start, len);
+ if (eoh == -1)
+ return eoh;
+ scoped_refptr<HttpResponseHeaders> headers(
+ new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(start, eoh)));
+ if (!ProcessHeaders(*headers)) {
+ ready_state_ = CLOSED;
+ return eoh;
+ }
+ switch (mode_) {
+ case MODE_NORMAL:
+ if (CheckResponseHeaders()) {
+ ready_state_ = OPEN;
+ } else {
+ ready_state_ = CLOSED;
+ }
+ break;
+ default:
+ ready_state_ = CLOSED;
+ break;
+ }
+ return eoh;
+}
+
+// Gets the value of the specified header.
+// It assures only one header of |name| in |headers|.
+// Returns true iff single header of |name| is found in |headers|
+// and |value| is filled with the value.
+// Returns false otherwise.
+static bool GetSingleHeader(const HttpResponseHeaders& headers,
+ const std::string& name,
+ std::string* value) {
+ std::string first_value;
+ void* iter = NULL;
+ if (!headers.EnumerateHeader(&iter, name, &first_value))
+ return false;
+
+ // Checks no more |name| found in |headers|.
+ // Second call of EnumerateHeader() must return false.
+ std::string second_value;
+ if (headers.EnumerateHeader(&iter, name, &second_value))
+ return false;
+ *value = first_value;
+ return true;
+}
+
+bool WebSocket::ProcessHeaders(const HttpResponseHeaders& headers) {
+ if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_))
+ return false;
+
+ if (!GetSingleHeader(headers, "websocket-location", &ws_location_))
+ return false;
+
+ if (!request_->protocol().empty()
+ && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_))
+ return false;
+ return true;
+}
+
+bool WebSocket::CheckResponseHeaders() const {
+ DCHECK(mode_ == MODE_NORMAL);
+ if (!LowerCaseEqualsASCII(request_->origin(), ws_origin_.c_str()))
+ return false;
+ if (request_->location() != ws_location_)
+ return false;
+ if (request_->protocol() != ws_protocol_)
+ return false;
+ return true;
+}
+
+void WebSocket::SendPending() {
+ DCHECK(MessageLoop::current() == origin_loop_);
+ DCHECK(socket_stream_);
+ if (!current_write_buf_) {
+ if (pending_write_bufs_.empty())
+ return;
+ current_write_buf_ = new DrainableIOBuffer(
+ pending_write_bufs_.front(), pending_write_bufs_.front()->size());
+ }
+ DCHECK_GT(current_write_buf_->BytesRemaining(), 0);
+ bool sent = socket_stream_->SendData(
+ current_write_buf_->data(),
+ std::min(current_write_buf_->BytesRemaining(),
+ max_pending_send_allowed_));
+ DCHECK(sent);
+}
+
+void WebSocket::DoReceivedData() {
+ DCHECK(MessageLoop::current() == origin_loop_);
+ switch (ready_state_) {
+ case CONNECTING:
+ {
+ int eoh = CheckHandshake();
+ if (eoh < 0) {
+ // Not enough data, Retry when more data is available.
+ return;
+ }
+ SkipReadBuffer(eoh);
+ }
+ if (ready_state_ != OPEN) {
+ // Handshake failed.
+ socket_stream_->Close();
+ return;
+ }
+ delegate_->OnOpen(this);
+ if (current_read_buf_->offset() == read_consumed_len_) {
+ // No remaining data after handshake message.
+ break;
+ }
+ // FALL THROUGH
+ case OPEN:
+ ProcessFrameData();
+ break;
+
+ case CLOSED:
+ // Closed just after DoReceivedData is queued on |origin_loop_|.
+ break;
+ default:
+ NOTREACHED();
+ break;
+ }
+}
+
+void WebSocket::ProcessFrameData() {
+ DCHECK(current_read_buf_);
+ const char* start_frame =
+ current_read_buf_->StartOfBuffer() + read_consumed_len_;
+ const char* next_frame = start_frame;
+ const char* p = next_frame;
+ const char* end =
+ current_read_buf_->StartOfBuffer() + current_read_buf_->offset();
+ while (p < end) {
+ unsigned char frame_byte = static_cast<unsigned char>(*p++);
+ if ((frame_byte & 0x80) == 0x80) {
+ int length = 0;
+ while (p < end && (*p & 0x80) == 0x80) {
+ if (length > std::numeric_limits<int>::max() / 128) {
+ // frame length overflow.
+ socket_stream_->Close();
+ return;
+ }
+ length = length * 128 + *p & 0x7f;
+ ++p;
+ }
+ // Checks if the frame body hasn't been completely received yet.
+ // It also checks the case the frame length bytes haven't been completely
+ // received yet, because p == end and length > 0 in such case.
+ if (p + length < end) {
+ p += length;
+ next_frame = p;
+ }
+ } else {
+ const char* msg_start = p;
+ while (p < end && *p != '\xff')
+ ++p;
+ if (p < end && *p == '\xff') {
+ if (frame_byte == 0x00)
+ delegate_->OnMessage(this, std::string(msg_start, p - msg_start));
+ ++p;
+ next_frame = p;
+ }
+ }
+ }
+ SkipReadBuffer(next_frame - start_frame);
+}
+
+void WebSocket::SkipReadBuffer(int len) {
+ read_consumed_len_ += len;
+ int remaining = current_read_buf_->offset() - read_consumed_len_;
+ DCHECK_GE(remaining, 0);
+ if (remaining < read_consumed_len_ &&
+ current_read_buf_->RemainingCapacity() < read_consumed_len_) {
+ // Pre compaction:
+ // 0 v-read_consumed_len_ v-offset v- capacity
+ // |..processed..| .. remaining .. | .. RemainingCapacity |
+ //
+ memmove(current_read_buf_->StartOfBuffer(),
+ current_read_buf_->StartOfBuffer() + read_consumed_len_,
+ remaining);
+ read_consumed_len_ = 0;
+ current_read_buf_->set_offset(remaining);
+ // Post compaction:
+ // 0read_consumed_len_ v- offset v- capacity
+ // |.. remaining .. | .. RemainingCapacity ... |
+ //
+ }
+}
+
+void WebSocket::DoClose() {
+ DCHECK(MessageLoop::current() == origin_loop_);
+ Delegate* delegate = delegate_;
+ delegate_ = NULL;
+ ready_state_ = CLOSED;
+ if (!socket_stream_)
+ return;
+ socket_stream_ = NULL;
+ delegate->OnClose(this);
+}
+
+} // namespace net
diff --git a/net/websockets/websocket.h b/net/websockets/websocket.h
new file mode 100644
index 0000000..5393a28
--- /dev/null
+++ b/net/websockets/websocket.h
@@ -0,0 +1,213 @@
+// Copyright (c) 2009 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// WebSocket protocol implementation in chromium.
+// It is intended to be used for live experiment of WebSocket connectivity
+// metrics.
+// Note that it is not used for WebKit's WebSocket communication.
+// See third_party/WebKit/WebCore/websockets/ instead.
+
+#ifndef NET_WEBSOCKETS_WEBSOCKET_H_
+#define NET_WEBSOCKETS_WEBSOCKET_H_
+
+#include <deque>
+#include <string>
+
+#include "base/ref_counted.h"
+#include "googleurl/src/gurl.h"
+#include "net/base/io_buffer.h"
+#include "net/socket_stream/socket_stream.h"
+#include "net/url_request/url_request_context.h"
+
+class MessageLoop;
+
+namespace net {
+
+class ClientSocketFactory;
+class HostResolver;
+class HttpResponseHeaders;
+
+class WebSocket : public base::RefCountedThreadSafe<WebSocket>,
+ public SocketStream::Delegate {
+ public:
+ enum State {
+ INITIALIZED = -1,
+ CONNECTING = 0,
+ OPEN = 1,
+ CLOSED = 2,
+ };
+ class Request {
+ public:
+ Request(const GURL& url, const std::string protocol,
+ const std::string origin, const std::string location,
+ URLRequestContext* context)
+ : url_(url),
+ protocol_(protocol),
+ origin_(origin),
+ location_(location),
+ context_(context),
+ host_resolver_(NULL),
+ client_socket_factory_(NULL) {}
+ ~Request() {}
+
+ const GURL& url() const { return url_; }
+ bool is_secure() const;
+ const std::string& protocol() const { return protocol_; }
+ const std::string& origin() const { return origin_; }
+ const std::string& location() const { return location_; }
+ URLRequestContext* context() const { return context_; }
+
+ // Sets an alternative HostResolver. For testing purposes only.
+ void SetHostResolver(HostResolver* host_resolver) {
+ host_resolver_ = host_resolver;
+ }
+ HostResolver* host_resolver() const { return host_resolver_; }
+
+ // Sets an alternative ClientSocketFactory. Doesn't take ownership of
+ // |factory|. For testing purposes only.
+ void SetClientSocketFactory(ClientSocketFactory* factory) {
+ client_socket_factory_ = factory;
+ }
+ ClientSocketFactory* client_socket_factory() const {
+ return client_socket_factory_;
+ }
+
+ private:
+ GURL url_;
+ std::string protocol_;
+ std::string origin_;
+ std::string location_;
+ scoped_refptr<URLRequestContext> context_;
+
+ scoped_refptr<HostResolver> host_resolver_;
+ ClientSocketFactory* client_socket_factory_;
+
+ DISALLOW_COPY_AND_ASSIGN(Request);
+ };
+ // Delegate methods will be called on the same message loop as
+ // WebSocket is constructed.
+ class Delegate {
+ public:
+ virtual ~Delegate() {}
+
+ // Called when WebSocket connection has been established.
+ virtual void OnOpen(WebSocket* socket) = 0;
+
+ // Called when |msg| is received at |socket|.
+ // |msg| should be in UTF-8.
+ virtual void OnMessage(WebSocket* socket, const std::string& msg) = 0;
+
+ // Called when |socket| is closed.
+ virtual void OnClose(WebSocket* socket) = 0;
+ };
+
+ // Constructs new WebSocket.
+ // It takes ownership of |req|.
+ // |delegate| must be alive while this object is alive.
+ WebSocket(Request* req, Delegate* delegate);
+
+ Delegate* delegate() const { return delegate_; }
+
+ State ready_state() const { return ready_state_; }
+
+ // Connects new WebSocket.
+ void Connect();
+
+ // Sends |msg| on the WebSocket connection.
+ // |msg| should be in UTF-8.
+ void Send(const std::string& msg);
+
+ // Closes the WebSocket connection.
+ void Close();
+
+ // SocketStream::Delegate methods.
+ // Called on IO thread.
+ virtual void OnConnected(SocketStream* socket_stream,
+ int max_pending_send_allowed);
+ virtual void OnSentData(SocketStream* socket_stream, int amount_sent);
+ virtual void OnReceivedData(SocketStream* socket_stream,
+ const char* data, int len);
+ virtual void OnClose(SocketStream* socket);
+
+ private:
+ enum Mode {
+ MODE_INCOMPLETE, MODE_NORMAL, MODE_AUTHENTICATE,
+ };
+ typedef std::deque< scoped_refptr<IOBufferWithSize> > PendingDataQueue;
+
+ friend class base::RefCountedThreadSafe<WebSocket>;
+ virtual ~WebSocket();
+
+ // Creates client handshake mssage based on |request_|.
+ IOBufferWithSize* CreateClientHandshakeMessage() const;
+
+ // Checks handshake.
+ // Prerequisite: Server handshake message is received in |current_read_buf_|.
+ // Returns number of bytes for server handshake message,
+ // or negative if server handshake message is not received fully yet.
+ int CheckHandshake();
+
+ // Processes server handshake message, parsed as |headers|, and updates
+ // |ws_origin_|, |ws_location_| and |ws_protocol_|.
+ // Returns true if it's ok.
+ // Returns false otherwise (e.g. duplicate WebSocket-Origin: header, etc.)
+ bool ProcessHeaders(const HttpResponseHeaders& headers);
+
+ // Checks |ws_origin_|, |ws_location_| and |ws_protocol_| are valid
+ // against |request_|.
+ // Returns true if it's ok.
+ // Returns false otherwise (e.g. origin mismatch, etc.)
+ bool CheckResponseHeaders() const;
+
+ // Sends pending data in |current_write_buf_| and/or |pending_write_bufs_|.
+ void SendPending();
+
+ // Handles received data.
+ void DoReceivedData();
+
+ // Processes frame data in |current_read_buf_|.
+ void ProcessFrameData();
+
+ // Skips |len| bytes in |current_read_buf_|.
+ void SkipReadBuffer(int len);
+
+ // Handles closed connection.
+ void DoClose();
+
+ State ready_state_;
+ Mode mode_;
+ scoped_ptr<Request> request_;
+ Delegate* delegate_;
+ MessageLoop* origin_loop_;
+
+ // Handshake messages that server sent.
+ std::string ws_origin_;
+ std::string ws_location_;
+ std::string ws_protocol_;
+
+ scoped_refptr<SocketStream> socket_stream_;
+ int max_pending_send_allowed_;
+
+ // [0..offset) is received data from |socket_stream_|.
+ // [0..read_consumed_len_) is already processed.
+ // [read_consumed_len_..offset) is unprocessed data.
+ // [offset..capacity) is free space.
+ scoped_refptr<GrowableIOBuffer> current_read_buf_;
+ int read_consumed_len_;
+
+ // Drainable IOBuffer on the front of |pending_write_bufs_|.
+ // [0..offset) is already sent to |socket_stream_|.
+ // [offset..size) is being sent to |socket_stream_|, waiting OnSentData.
+ scoped_refptr<DrainableIOBuffer> current_write_buf_;
+
+ // Deque of IOBuffers in pending.
+ // Front IOBuffer is being sent via |current_write_buf_|.
+ PendingDataQueue pending_write_bufs_;
+
+ DISALLOW_COPY_AND_ASSIGN(WebSocket);
+};
+
+} // namespace net
+
+#endif // NET_WEBSOCKETS_WEBSOCKET_H_
diff --git a/net/websockets/websocket_unittest.cc b/net/websockets/websocket_unittest.cc
new file mode 100644
index 0000000..d0e3fd9
--- /dev/null
+++ b/net/websockets/websocket_unittest.cc
@@ -0,0 +1,214 @@
+// Copyright (c) 2009 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <string>
+#include <vector>
+
+#include "base/task.h"
+#include "net/base/completion_callback.h"
+#include "net/base/mock_host_resolver.h"
+#include "net/base/test_completion_callback.h"
+#include "net/socket/socket_test_util.h"
+#include "net/url_request/url_request_unittest.h"
+#include "net/websockets/websocket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "testing/platform_test.h"
+
+struct WebSocketEvent {
+ enum EventType {
+ EVENT_OPEN, EVENT_MESSAGE, EVENT_CLOSE,
+ };
+
+ WebSocketEvent(EventType type, net::WebSocket* websocket,
+ const std::string& websocket_msg)
+ : event_type(type), socket(websocket), msg(websocket_msg) {}
+
+ EventType event_type;
+ net::WebSocket* socket;
+ std::string msg;
+};
+
+class WebSocketEventRecorder : public net::WebSocket::Delegate {
+ public:
+ explicit WebSocketEventRecorder(net::CompletionCallback* callback)
+ : onopen_(NULL),
+ onmessage_(NULL),
+ onclose_(NULL),
+ callback_(callback) {}
+ virtual ~WebSocketEventRecorder() {
+ delete onopen_;
+ delete onmessage_;
+ delete onclose_;
+ }
+
+ void SetOnOpen(Callback1<WebSocketEvent*>::Type* callback) {
+ onopen_ = callback;
+ }
+ void SetOnMessage(Callback1<WebSocketEvent*>::Type* callback) {
+ onmessage_ = callback;
+ }
+ void SetOnClose(Callback1<WebSocketEvent*>::Type* callback) {
+ onclose_ = callback;
+ }
+
+ virtual void OnOpen(net::WebSocket* socket) {
+ events_.push_back(
+ WebSocketEvent(WebSocketEvent::EVENT_OPEN, socket, std::string()));
+ if (onopen_)
+ onopen_->Run(&events_.back());
+ }
+
+ virtual void OnMessage(net::WebSocket* socket, const std::string& msg) {
+ events_.push_back(
+ WebSocketEvent(WebSocketEvent::EVENT_MESSAGE, socket, msg));
+ if (onmessage_)
+ onmessage_->Run(&events_.back());
+ }
+ virtual void OnClose(net::WebSocket* socket) {
+ events_.push_back(
+ WebSocketEvent(WebSocketEvent::EVENT_CLOSE, socket, std::string()));
+ if (onclose_)
+ onclose_->Run(&events_.back());
+ if (callback_)
+ callback_->Run(net::OK);
+ }
+
+ void DoClose(WebSocketEvent* event) {
+ event->socket->Close();
+ }
+
+ const std::vector<WebSocketEvent>& GetSeenEvents() const {
+ return events_;
+ }
+
+ private:
+ std::vector<WebSocketEvent> events_;
+ Callback1<WebSocketEvent*>::Type* onopen_;
+ Callback1<WebSocketEvent*>::Type* onmessage_;
+ Callback1<WebSocketEvent*>::Type* onclose_;
+ net::CompletionCallback* callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(WebSocketEventRecorder);
+};
+
+class WebSocketTest : public PlatformTest {
+};
+
+TEST_F(WebSocketTest, Connect) {
+ net::MockClientSocketFactory mock_socket_factory;
+ net::MockRead data_reads[] = {
+ net::MockRead("HTTP/1.1 101 Web Socket Protocol\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "WebSocket-Origin: http://example.com\r\n"
+ "WebSocket-Location: ws://example.com/demo\r\n"
+ "WebSocket-Protocol: sample\r\n"
+ "\r\n"),
+ // Server doesn't close the connection after handshake.
+ net::MockRead(true, net::ERR_IO_PENDING),
+ };
+ net::MockWrite data_writes[] = {
+ net::MockWrite("GET /demo HTTP/1.1\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Host: example.com\r\n"
+ "Origin: http://example.com\r\n"
+ "WebSocket-Protocol: sample\r\n"
+ "\r\n"),
+ };
+ net::StaticMockSocket data(data_reads, data_writes);
+ mock_socket_factory.AddMockSocket(&data);
+
+ net::WebSocket::Request* request(
+ new net::WebSocket::Request(GURL("ws://example.com/demo"),
+ "sample",
+ "http://example.com",
+ "ws://example.com/demo",
+ new TestURLRequestContext()));
+ request->SetHostResolver(new net::MockHostResolver());
+ request->SetClientSocketFactory(&mock_socket_factory);
+
+ TestCompletionCallback callback;
+
+ scoped_ptr<WebSocketEventRecorder> delegate(
+ new WebSocketEventRecorder(&callback));
+ delegate->SetOnOpen(NewCallback(delegate.get(),
+ &WebSocketEventRecorder::DoClose));
+
+ scoped_refptr<net::WebSocket> websocket(
+ new net::WebSocket(request, delegate.get()));
+
+ EXPECT_EQ(net::WebSocket::INITIALIZED, websocket->ready_state());
+ websocket->Connect();
+
+ callback.WaitForResult();
+
+ const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
+ EXPECT_EQ(2U, events.size());
+
+ EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type);
+ EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[1].event_type);
+}
+
+TEST_F(WebSocketTest, ServerSentData) {
+ net::MockClientSocketFactory mock_socket_factory;
+ static const char kMessage[] = "Hello";
+ static const char kFrame[] = "\x00Hello\xff";
+ static const int kFrameLen = sizeof(kFrame) - 1;
+ net::MockRead data_reads[] = {
+ net::MockRead("HTTP/1.1 101 Web Socket Protocol\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "WebSocket-Origin: http://example.com\r\n"
+ "WebSocket-Location: ws://example.com/demo\r\n"
+ "WebSocket-Protocol: sample\r\n"
+ "\r\n"),
+ net::MockRead(true, kFrame, kFrameLen),
+ // Server doesn't close the connection after handshake.
+ net::MockRead(true, net::ERR_IO_PENDING),
+ };
+ net::MockWrite data_writes[] = {
+ net::MockWrite("GET /demo HTTP/1.1\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Host: example.com\r\n"
+ "Origin: http://example.com\r\n"
+ "WebSocket-Protocol: sample\r\n"
+ "\r\n"),
+ };
+ net::StaticMockSocket data(data_reads, data_writes);
+ mock_socket_factory.AddMockSocket(&data);
+
+ net::WebSocket::Request* request(
+ new net::WebSocket::Request(GURL("ws://example.com/demo"),
+ "sample",
+ "http://example.com",
+ "ws://example.com/demo",
+ new TestURLRequestContext()));
+ request->SetHostResolver(new net::MockHostResolver());
+ request->SetClientSocketFactory(&mock_socket_factory);
+
+ TestCompletionCallback callback;
+
+ scoped_ptr<WebSocketEventRecorder> delegate(
+ new WebSocketEventRecorder(&callback));
+ delegate->SetOnMessage(NewCallback(delegate.get(),
+ &WebSocketEventRecorder::DoClose));
+
+ scoped_refptr<net::WebSocket> websocket(
+ new net::WebSocket(request, delegate.get()));
+
+ EXPECT_EQ(net::WebSocket::INITIALIZED, websocket->ready_state());
+ websocket->Connect();
+
+ callback.WaitForResult();
+
+ const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
+ EXPECT_EQ(3U, events.size());
+
+ EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type);
+ EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[1].event_type);
+ EXPECT_EQ(kMessage, events[1].msg);
+ EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[2].event_type);
+}