From db6ffa59f24996886615578c5791d17fc51b423e Mon Sep 17 00:00:00 2001 From: "ukai@chromium.org" Date: Wed, 21 Oct 2009 05:36:17 +0000 Subject: WebSocket protocol handler for live experiment. This is in-browser-process WebSocket protocol handler, which will be used in WebSocket live experiment. BUG=none TEST=net_unittests passes Review URL: http://codereview.chromium.org/304014 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@29614 0039d316-1c4b-4281-b951-d872f2087c98 --- net/net.gyp | 3 + net/websockets/websocket.cc | 441 +++++++++++++++++++++++++++++++++++ net/websockets/websocket.h | 213 +++++++++++++++++ net/websockets/websocket_unittest.cc | 214 +++++++++++++++++ 4 files changed, 871 insertions(+) create mode 100644 net/websockets/websocket.cc create mode 100644 net/websockets/websocket.h create mode 100644 net/websockets/websocket_unittest.cc diff --git a/net/net.gyp b/net/net.gyp index 00c5661..f672009 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -384,6 +384,8 @@ 'url_request/url_request_view_net_internals_job.h', 'url_request/view_cache_helper.cc', 'url_request/view_cache_helper.h', + 'websockets/websocket.cc', + 'websockets/websocket.h', ], 'export_dependent_settings': [ '../base/base.gyp:base', @@ -553,6 +555,7 @@ 'socket/tcp_pinger_unittest.cc', 'url_request/url_request_unittest.cc', 'url_request/url_request_unittest.h', + 'websockets/websocket_unittest.cc', ], 'conditions': [ [ 'OS == "win"', { diff --git a/net/websockets/websocket.cc b/net/websockets/websocket.cc new file mode 100644 index 0000000..f95a9cc --- /dev/null +++ b/net/websockets/websocket.cc @@ -0,0 +1,441 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include + +#include "net/websockets/websocket.h" + +#include "base/message_loop.h" +#include "net/http/http_response_headers.h" +#include "net/http/http_util.h" + +namespace net { + +static const int kWebSocketPort = 80; +static const int kSecureWebSocketPort = 443; + +static const char kServerHandshakeHeader[] = + "HTTP/1.1 101 Web Socket Protocol\r\n"; +static const size_t kServerHandshakeHeaderLength = + sizeof(kServerHandshakeHeader) - 1; + +static const char kUpgradeHeader[] = "Upgrade: WebSocket\r\n"; +static const size_t kUpgradeHeaderLength = sizeof(kUpgradeHeader) - 1; + +static const char kConnectionHeader[] = "Connection: Upgrade\r\n"; +static const size_t kConnectionHeaderLength = sizeof(kConnectionHeader) - 1; + +bool WebSocket::Request::is_secure() const { + return url_.SchemeIs("wss"); +} + +WebSocket::WebSocket(Request* request, Delegate* delegate) + : ready_state_(INITIALIZED), + mode_(MODE_INCOMPLETE), + request_(request), + delegate_(delegate), + origin_loop_(MessageLoop::current()), + socket_stream_(NULL), + max_pending_send_allowed_(0), + current_read_buf_(NULL), + read_consumed_len_(0), + current_write_buf_(NULL) { + DCHECK(request_.get()); + DCHECK(delegate_); + DCHECK(origin_loop_); +} + +WebSocket::~WebSocket() { + DCHECK(!delegate_); + DCHECK(!socket_stream_); +} + +void WebSocket::Connect() { + DCHECK(ready_state_ == INITIALIZED); + DCHECK(request_.get()); + DCHECK(delegate_); + DCHECK(!socket_stream_); + DCHECK(MessageLoop::current() == origin_loop_); + + socket_stream_ = new SocketStream(request_->url(), this); + socket_stream_->set_context(request_->context()); + + if (request_->host_resolver()) + socket_stream_->SetHostResolver(request_->host_resolver()); + if (request_->client_socket_factory()) + socket_stream_->SetClientSocketFactory(request_->client_socket_factory()); + + ready_state_ = CONNECTING; + socket_stream_->Connect(); +} + +void WebSocket::Send(const std::string& msg) { + DCHECK(ready_state_ == OPEN); + DCHECK(MessageLoop::current() == origin_loop_); + + IOBufferWithSize* buf = new IOBufferWithSize(msg.size() + 2); + char* p = buf->data(); + *p = '\0'; + memcpy(p + 1, msg.data(), msg.size()); + *(p + 1 + msg.size()) = '\xff'; + pending_write_bufs_.push_back(buf); + SendPending(); +} + +void WebSocket::Close() { + DCHECK(MessageLoop::current() == origin_loop_); + + if (ready_state_ == INITIALIZED) { + DCHECK(!socket_stream_); + ready_state_ = CLOSED; + return; + } + if (ready_state_ != CLOSED) { + DCHECK(socket_stream_); + socket_stream_->Close(); + return; + } +} + +void WebSocket::OnConnected(SocketStream* socket_stream, + int max_pending_send_allowed) { + DCHECK(socket_stream == socket_stream_); + max_pending_send_allowed_ = max_pending_send_allowed; + + // Use |max_pending_send_allowed| as hint for initial size of read buffer. + current_read_buf_ = new GrowableIOBuffer(); + current_read_buf_->set_capacity(max_pending_send_allowed_); + read_consumed_len_ = 0; + + DCHECK(!current_write_buf_); + pending_write_bufs_.push_back(CreateClientHandshakeMessage()); + origin_loop_->PostTask(FROM_HERE, + NewRunnableMethod(this, &WebSocket::SendPending)); +} + +void WebSocket::OnSentData(SocketStream* socket_stream, int amount_sent) { + DCHECK(socket_stream == socket_stream_); + DCHECK(current_write_buf_); + current_write_buf_->DidConsume(amount_sent); + DCHECK_GE(current_write_buf_->BytesRemaining(), 0); + if (current_write_buf_->BytesRemaining() == 0) { + current_write_buf_ = NULL; + pending_write_bufs_.pop_front(); + } + origin_loop_->PostTask(FROM_HERE, + NewRunnableMethod(this, &WebSocket::SendPending)); +} + +void WebSocket::OnReceivedData(SocketStream* socket_stream, + const char* data, int len) { + DCHECK(socket_stream == socket_stream_); + DCHECK(current_read_buf_); + // Check if |current_read_buf_| has enough space to store |len| of |data|. + if (len >= current_read_buf_->RemainingCapacity()) { + current_read_buf_->set_capacity( + current_read_buf_->offset() + len); + } + + DCHECK(current_read_buf_->RemainingCapacity() >= len); + memcpy(current_read_buf_->data(), data, len); + current_read_buf_->set_offset(current_read_buf_->offset() + len); + + origin_loop_->PostTask(FROM_HERE, + NewRunnableMethod(this, &WebSocket::DoReceivedData)); +} + +void WebSocket::OnClose(SocketStream* socket_stream) { + origin_loop_->PostTask(FROM_HERE, + NewRunnableMethod(this, &WebSocket::DoClose)); +} + +IOBufferWithSize* WebSocket::CreateClientHandshakeMessage() const { + std::string msg; + msg = "GET "; + msg += request_->url().path(); + if (request_->url().has_query()) { + msg += "?"; + msg += request_->url().query(); + } + msg += " HTTP/1.1\r\n"; + msg += kUpgradeHeader; + msg += kConnectionHeader; + msg += "Host: "; + msg += StringToLowerASCII(request_->url().host()); + if (request_->url().has_port()) { + bool secure = request_->is_secure(); + int port = request_->url().EffectiveIntPort(); + if ((!secure && + port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) || + (secure && + port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) { + msg += ":"; + msg += IntToString(port); + } + } + msg += "\r\n"; + msg += "Origin: "; + msg += StringToLowerASCII(request_->origin()); + msg += "\r\n"; + if (!request_->protocol().empty()) { + msg += "WebSocket-Protocol: "; + msg += request_->protocol(); + msg += "\r\n"; + } + // TODO(ukai): Add cookie if necessary. + msg += "\r\n"; + IOBufferWithSize* buf = new IOBufferWithSize(msg.size()); + memcpy(buf->data(), msg.data(), msg.size()); + return buf; +} + +int WebSocket::CheckHandshake() { + DCHECK(current_read_buf_); + DCHECK(ready_state_ == CONNECTING); + mode_ = MODE_INCOMPLETE; + const char *start = current_read_buf_->StartOfBuffer() + read_consumed_len_; + const char *p = start; + size_t len = current_read_buf_->offset() - read_consumed_len_; + if (len < kServerHandshakeHeaderLength) { + return -1; + } + if (!memcmp(p, kServerHandshakeHeader, kServerHandshakeHeaderLength)) { + mode_ = MODE_NORMAL; + } else { + int eoh = HttpUtil::LocateEndOfHeaders(p, len); + if (eoh < 0) + return -1; + scoped_refptr headers( + new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(p, eoh))); + if (headers->response_code() == 401) { + mode_ = MODE_AUTHENTICATE; + // TODO(ukai): Implement authentication handlers. + } + // Invalid response code. + ready_state_ = CLOSED; + return eoh; + } + const char* end = p + len + 1; + p += kServerHandshakeHeaderLength; + + if (mode_ == MODE_NORMAL) { + size_t header_size = end - p; + if (header_size < kUpgradeHeaderLength) + return -1; + if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) { + ready_state_ = CLOSED; + return p - start; + } + p += kUpgradeHeaderLength; + + header_size = end - p; + if (header_size < kConnectionHeaderLength) + return -1; + if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) { + ready_state_ = CLOSED; + return p - start; + } + p += kConnectionHeaderLength; + } + int eoh = HttpUtil::LocateEndOfHeaders(start, len); + if (eoh == -1) + return eoh; + scoped_refptr headers( + new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(start, eoh))); + if (!ProcessHeaders(*headers)) { + ready_state_ = CLOSED; + return eoh; + } + switch (mode_) { + case MODE_NORMAL: + if (CheckResponseHeaders()) { + ready_state_ = OPEN; + } else { + ready_state_ = CLOSED; + } + break; + default: + ready_state_ = CLOSED; + break; + } + return eoh; +} + +// Gets the value of the specified header. +// It assures only one header of |name| in |headers|. +// Returns true iff single header of |name| is found in |headers| +// and |value| is filled with the value. +// Returns false otherwise. +static bool GetSingleHeader(const HttpResponseHeaders& headers, + const std::string& name, + std::string* value) { + std::string first_value; + void* iter = NULL; + if (!headers.EnumerateHeader(&iter, name, &first_value)) + return false; + + // Checks no more |name| found in |headers|. + // Second call of EnumerateHeader() must return false. + std::string second_value; + if (headers.EnumerateHeader(&iter, name, &second_value)) + return false; + *value = first_value; + return true; +} + +bool WebSocket::ProcessHeaders(const HttpResponseHeaders& headers) { + if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_)) + return false; + + if (!GetSingleHeader(headers, "websocket-location", &ws_location_)) + return false; + + if (!request_->protocol().empty() + && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_)) + return false; + return true; +} + +bool WebSocket::CheckResponseHeaders() const { + DCHECK(mode_ == MODE_NORMAL); + if (!LowerCaseEqualsASCII(request_->origin(), ws_origin_.c_str())) + return false; + if (request_->location() != ws_location_) + return false; + if (request_->protocol() != ws_protocol_) + return false; + return true; +} + +void WebSocket::SendPending() { + DCHECK(MessageLoop::current() == origin_loop_); + DCHECK(socket_stream_); + if (!current_write_buf_) { + if (pending_write_bufs_.empty()) + return; + current_write_buf_ = new DrainableIOBuffer( + pending_write_bufs_.front(), pending_write_bufs_.front()->size()); + } + DCHECK_GT(current_write_buf_->BytesRemaining(), 0); + bool sent = socket_stream_->SendData( + current_write_buf_->data(), + std::min(current_write_buf_->BytesRemaining(), + max_pending_send_allowed_)); + DCHECK(sent); +} + +void WebSocket::DoReceivedData() { + DCHECK(MessageLoop::current() == origin_loop_); + switch (ready_state_) { + case CONNECTING: + { + int eoh = CheckHandshake(); + if (eoh < 0) { + // Not enough data, Retry when more data is available. + return; + } + SkipReadBuffer(eoh); + } + if (ready_state_ != OPEN) { + // Handshake failed. + socket_stream_->Close(); + return; + } + delegate_->OnOpen(this); + if (current_read_buf_->offset() == read_consumed_len_) { + // No remaining data after handshake message. + break; + } + // FALL THROUGH + case OPEN: + ProcessFrameData(); + break; + + case CLOSED: + // Closed just after DoReceivedData is queued on |origin_loop_|. + break; + default: + NOTREACHED(); + break; + } +} + +void WebSocket::ProcessFrameData() { + DCHECK(current_read_buf_); + const char* start_frame = + current_read_buf_->StartOfBuffer() + read_consumed_len_; + const char* next_frame = start_frame; + const char* p = next_frame; + const char* end = + current_read_buf_->StartOfBuffer() + current_read_buf_->offset(); + while (p < end) { + unsigned char frame_byte = static_cast(*p++); + if ((frame_byte & 0x80) == 0x80) { + int length = 0; + while (p < end && (*p & 0x80) == 0x80) { + if (length > std::numeric_limits::max() / 128) { + // frame length overflow. + socket_stream_->Close(); + return; + } + length = length * 128 + *p & 0x7f; + ++p; + } + // Checks if the frame body hasn't been completely received yet. + // It also checks the case the frame length bytes haven't been completely + // received yet, because p == end and length > 0 in such case. + if (p + length < end) { + p += length; + next_frame = p; + } + } else { + const char* msg_start = p; + while (p < end && *p != '\xff') + ++p; + if (p < end && *p == '\xff') { + if (frame_byte == 0x00) + delegate_->OnMessage(this, std::string(msg_start, p - msg_start)); + ++p; + next_frame = p; + } + } + } + SkipReadBuffer(next_frame - start_frame); +} + +void WebSocket::SkipReadBuffer(int len) { + read_consumed_len_ += len; + int remaining = current_read_buf_->offset() - read_consumed_len_; + DCHECK_GE(remaining, 0); + if (remaining < read_consumed_len_ && + current_read_buf_->RemainingCapacity() < read_consumed_len_) { + // Pre compaction: + // 0 v-read_consumed_len_ v-offset v- capacity + // |..processed..| .. remaining .. | .. RemainingCapacity | + // + memmove(current_read_buf_->StartOfBuffer(), + current_read_buf_->StartOfBuffer() + read_consumed_len_, + remaining); + read_consumed_len_ = 0; + current_read_buf_->set_offset(remaining); + // Post compaction: + // 0read_consumed_len_ v- offset v- capacity + // |.. remaining .. | .. RemainingCapacity ... | + // + } +} + +void WebSocket::DoClose() { + DCHECK(MessageLoop::current() == origin_loop_); + Delegate* delegate = delegate_; + delegate_ = NULL; + ready_state_ = CLOSED; + if (!socket_stream_) + return; + socket_stream_ = NULL; + delegate->OnClose(this); +} + +} // namespace net diff --git a/net/websockets/websocket.h b/net/websockets/websocket.h new file mode 100644 index 0000000..5393a28 --- /dev/null +++ b/net/websockets/websocket.h @@ -0,0 +1,213 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// WebSocket protocol implementation in chromium. +// It is intended to be used for live experiment of WebSocket connectivity +// metrics. +// Note that it is not used for WebKit's WebSocket communication. +// See third_party/WebKit/WebCore/websockets/ instead. + +#ifndef NET_WEBSOCKETS_WEBSOCKET_H_ +#define NET_WEBSOCKETS_WEBSOCKET_H_ + +#include +#include + +#include "base/ref_counted.h" +#include "googleurl/src/gurl.h" +#include "net/base/io_buffer.h" +#include "net/socket_stream/socket_stream.h" +#include "net/url_request/url_request_context.h" + +class MessageLoop; + +namespace net { + +class ClientSocketFactory; +class HostResolver; +class HttpResponseHeaders; + +class WebSocket : public base::RefCountedThreadSafe, + public SocketStream::Delegate { + public: + enum State { + INITIALIZED = -1, + CONNECTING = 0, + OPEN = 1, + CLOSED = 2, + }; + class Request { + public: + Request(const GURL& url, const std::string protocol, + const std::string origin, const std::string location, + URLRequestContext* context) + : url_(url), + protocol_(protocol), + origin_(origin), + location_(location), + context_(context), + host_resolver_(NULL), + client_socket_factory_(NULL) {} + ~Request() {} + + const GURL& url() const { return url_; } + bool is_secure() const; + const std::string& protocol() const { return protocol_; } + const std::string& origin() const { return origin_; } + const std::string& location() const { return location_; } + URLRequestContext* context() const { return context_; } + + // Sets an alternative HostResolver. For testing purposes only. + void SetHostResolver(HostResolver* host_resolver) { + host_resolver_ = host_resolver; + } + HostResolver* host_resolver() const { return host_resolver_; } + + // Sets an alternative ClientSocketFactory. Doesn't take ownership of + // |factory|. For testing purposes only. + void SetClientSocketFactory(ClientSocketFactory* factory) { + client_socket_factory_ = factory; + } + ClientSocketFactory* client_socket_factory() const { + return client_socket_factory_; + } + + private: + GURL url_; + std::string protocol_; + std::string origin_; + std::string location_; + scoped_refptr context_; + + scoped_refptr host_resolver_; + ClientSocketFactory* client_socket_factory_; + + DISALLOW_COPY_AND_ASSIGN(Request); + }; + // Delegate methods will be called on the same message loop as + // WebSocket is constructed. + class Delegate { + public: + virtual ~Delegate() {} + + // Called when WebSocket connection has been established. + virtual void OnOpen(WebSocket* socket) = 0; + + // Called when |msg| is received at |socket|. + // |msg| should be in UTF-8. + virtual void OnMessage(WebSocket* socket, const std::string& msg) = 0; + + // Called when |socket| is closed. + virtual void OnClose(WebSocket* socket) = 0; + }; + + // Constructs new WebSocket. + // It takes ownership of |req|. + // |delegate| must be alive while this object is alive. + WebSocket(Request* req, Delegate* delegate); + + Delegate* delegate() const { return delegate_; } + + State ready_state() const { return ready_state_; } + + // Connects new WebSocket. + void Connect(); + + // Sends |msg| on the WebSocket connection. + // |msg| should be in UTF-8. + void Send(const std::string& msg); + + // Closes the WebSocket connection. + void Close(); + + // SocketStream::Delegate methods. + // Called on IO thread. + virtual void OnConnected(SocketStream* socket_stream, + int max_pending_send_allowed); + virtual void OnSentData(SocketStream* socket_stream, int amount_sent); + virtual void OnReceivedData(SocketStream* socket_stream, + const char* data, int len); + virtual void OnClose(SocketStream* socket); + + private: + enum Mode { + MODE_INCOMPLETE, MODE_NORMAL, MODE_AUTHENTICATE, + }; + typedef std::deque< scoped_refptr > PendingDataQueue; + + friend class base::RefCountedThreadSafe; + virtual ~WebSocket(); + + // Creates client handshake mssage based on |request_|. + IOBufferWithSize* CreateClientHandshakeMessage() const; + + // Checks handshake. + // Prerequisite: Server handshake message is received in |current_read_buf_|. + // Returns number of bytes for server handshake message, + // or negative if server handshake message is not received fully yet. + int CheckHandshake(); + + // Processes server handshake message, parsed as |headers|, and updates + // |ws_origin_|, |ws_location_| and |ws_protocol_|. + // Returns true if it's ok. + // Returns false otherwise (e.g. duplicate WebSocket-Origin: header, etc.) + bool ProcessHeaders(const HttpResponseHeaders& headers); + + // Checks |ws_origin_|, |ws_location_| and |ws_protocol_| are valid + // against |request_|. + // Returns true if it's ok. + // Returns false otherwise (e.g. origin mismatch, etc.) + bool CheckResponseHeaders() const; + + // Sends pending data in |current_write_buf_| and/or |pending_write_bufs_|. + void SendPending(); + + // Handles received data. + void DoReceivedData(); + + // Processes frame data in |current_read_buf_|. + void ProcessFrameData(); + + // Skips |len| bytes in |current_read_buf_|. + void SkipReadBuffer(int len); + + // Handles closed connection. + void DoClose(); + + State ready_state_; + Mode mode_; + scoped_ptr request_; + Delegate* delegate_; + MessageLoop* origin_loop_; + + // Handshake messages that server sent. + std::string ws_origin_; + std::string ws_location_; + std::string ws_protocol_; + + scoped_refptr socket_stream_; + int max_pending_send_allowed_; + + // [0..offset) is received data from |socket_stream_|. + // [0..read_consumed_len_) is already processed. + // [read_consumed_len_..offset) is unprocessed data. + // [offset..capacity) is free space. + scoped_refptr current_read_buf_; + int read_consumed_len_; + + // Drainable IOBuffer on the front of |pending_write_bufs_|. + // [0..offset) is already sent to |socket_stream_|. + // [offset..size) is being sent to |socket_stream_|, waiting OnSentData. + scoped_refptr current_write_buf_; + + // Deque of IOBuffers in pending. + // Front IOBuffer is being sent via |current_write_buf_|. + PendingDataQueue pending_write_bufs_; + + DISALLOW_COPY_AND_ASSIGN(WebSocket); +}; + +} // namespace net + +#endif // NET_WEBSOCKETS_WEBSOCKET_H_ diff --git a/net/websockets/websocket_unittest.cc b/net/websockets/websocket_unittest.cc new file mode 100644 index 0000000..d0e3fd9 --- /dev/null +++ b/net/websockets/websocket_unittest.cc @@ -0,0 +1,214 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include + +#include "base/task.h" +#include "net/base/completion_callback.h" +#include "net/base/mock_host_resolver.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/socket_test_util.h" +#include "net/url_request/url_request_unittest.h" +#include "net/websockets/websocket.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +struct WebSocketEvent { + enum EventType { + EVENT_OPEN, EVENT_MESSAGE, EVENT_CLOSE, + }; + + WebSocketEvent(EventType type, net::WebSocket* websocket, + const std::string& websocket_msg) + : event_type(type), socket(websocket), msg(websocket_msg) {} + + EventType event_type; + net::WebSocket* socket; + std::string msg; +}; + +class WebSocketEventRecorder : public net::WebSocket::Delegate { + public: + explicit WebSocketEventRecorder(net::CompletionCallback* callback) + : onopen_(NULL), + onmessage_(NULL), + onclose_(NULL), + callback_(callback) {} + virtual ~WebSocketEventRecorder() { + delete onopen_; + delete onmessage_; + delete onclose_; + } + + void SetOnOpen(Callback1::Type* callback) { + onopen_ = callback; + } + void SetOnMessage(Callback1::Type* callback) { + onmessage_ = callback; + } + void SetOnClose(Callback1::Type* callback) { + onclose_ = callback; + } + + virtual void OnOpen(net::WebSocket* socket) { + events_.push_back( + WebSocketEvent(WebSocketEvent::EVENT_OPEN, socket, std::string())); + if (onopen_) + onopen_->Run(&events_.back()); + } + + virtual void OnMessage(net::WebSocket* socket, const std::string& msg) { + events_.push_back( + WebSocketEvent(WebSocketEvent::EVENT_MESSAGE, socket, msg)); + if (onmessage_) + onmessage_->Run(&events_.back()); + } + virtual void OnClose(net::WebSocket* socket) { + events_.push_back( + WebSocketEvent(WebSocketEvent::EVENT_CLOSE, socket, std::string())); + if (onclose_) + onclose_->Run(&events_.back()); + if (callback_) + callback_->Run(net::OK); + } + + void DoClose(WebSocketEvent* event) { + event->socket->Close(); + } + + const std::vector& GetSeenEvents() const { + return events_; + } + + private: + std::vector events_; + Callback1::Type* onopen_; + Callback1::Type* onmessage_; + Callback1::Type* onclose_; + net::CompletionCallback* callback_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketEventRecorder); +}; + +class WebSocketTest : public PlatformTest { +}; + +TEST_F(WebSocketTest, Connect) { + net::MockClientSocketFactory mock_socket_factory; + net::MockRead data_reads[] = { + net::MockRead("HTTP/1.1 101 Web Socket Protocol\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: http://example.com\r\n" + "WebSocket-Location: ws://example.com/demo\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"), + // Server doesn't close the connection after handshake. + net::MockRead(true, net::ERR_IO_PENDING), + }; + net::MockWrite data_writes[] = { + net::MockWrite("GET /demo HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Host: example.com\r\n" + "Origin: http://example.com\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"), + }; + net::StaticMockSocket data(data_reads, data_writes); + mock_socket_factory.AddMockSocket(&data); + + net::WebSocket::Request* request( + new net::WebSocket::Request(GURL("ws://example.com/demo"), + "sample", + "http://example.com", + "ws://example.com/demo", + new TestURLRequestContext())); + request->SetHostResolver(new net::MockHostResolver()); + request->SetClientSocketFactory(&mock_socket_factory); + + TestCompletionCallback callback; + + scoped_ptr delegate( + new WebSocketEventRecorder(&callback)); + delegate->SetOnOpen(NewCallback(delegate.get(), + &WebSocketEventRecorder::DoClose)); + + scoped_refptr websocket( + new net::WebSocket(request, delegate.get())); + + EXPECT_EQ(net::WebSocket::INITIALIZED, websocket->ready_state()); + websocket->Connect(); + + callback.WaitForResult(); + + const std::vector& events = delegate->GetSeenEvents(); + EXPECT_EQ(2U, events.size()); + + EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type); + EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[1].event_type); +} + +TEST_F(WebSocketTest, ServerSentData) { + net::MockClientSocketFactory mock_socket_factory; + static const char kMessage[] = "Hello"; + static const char kFrame[] = "\x00Hello\xff"; + static const int kFrameLen = sizeof(kFrame) - 1; + net::MockRead data_reads[] = { + net::MockRead("HTTP/1.1 101 Web Socket Protocol\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: http://example.com\r\n" + "WebSocket-Location: ws://example.com/demo\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"), + net::MockRead(true, kFrame, kFrameLen), + // Server doesn't close the connection after handshake. + net::MockRead(true, net::ERR_IO_PENDING), + }; + net::MockWrite data_writes[] = { + net::MockWrite("GET /demo HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Host: example.com\r\n" + "Origin: http://example.com\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"), + }; + net::StaticMockSocket data(data_reads, data_writes); + mock_socket_factory.AddMockSocket(&data); + + net::WebSocket::Request* request( + new net::WebSocket::Request(GURL("ws://example.com/demo"), + "sample", + "http://example.com", + "ws://example.com/demo", + new TestURLRequestContext())); + request->SetHostResolver(new net::MockHostResolver()); + request->SetClientSocketFactory(&mock_socket_factory); + + TestCompletionCallback callback; + + scoped_ptr delegate( + new WebSocketEventRecorder(&callback)); + delegate->SetOnMessage(NewCallback(delegate.get(), + &WebSocketEventRecorder::DoClose)); + + scoped_refptr websocket( + new net::WebSocket(request, delegate.get())); + + EXPECT_EQ(net::WebSocket::INITIALIZED, websocket->ready_state()); + websocket->Connect(); + + callback.WaitForResult(); + + const std::vector& events = delegate->GetSeenEvents(); + EXPECT_EQ(3U, events.size()); + + EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type); + EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[1].event_type); + EXPECT_EQ(kMessage, events[1].msg); + EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[2].event_type); +} -- cgit v1.1