// Copyright (c) 2011 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/websockets/websocket.h" #include #include #include "base/bind.h" #include "base/bind_helpers.h" #include "base/callback.h" #include "net/base/completion_callback.h" #include "net/base/io_buffer.h" #include "net/base/mock_host_resolver.h" #include "net/base/test_completion_callback.h" #include "net/socket/socket_test_util.h" #include "net/url_request/url_request_test_util.h" #include "testing/gtest/include/gtest/gtest.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/platform_test.h" struct WebSocketEvent { enum EventType { EVENT_OPEN, EVENT_MESSAGE, EVENT_ERROR, EVENT_CLOSE, }; WebSocketEvent(EventType type, net::WebSocket* websocket, const std::string& websocket_msg, bool websocket_flag) : event_type(type), socket(websocket), msg(websocket_msg), flag(websocket_flag) {} EventType event_type; net::WebSocket* socket; std::string msg; bool flag; }; class WebSocketEventRecorder : public net::WebSocketDelegate { public: explicit WebSocketEventRecorder(net::CompletionCallback* callback) : callback_(callback) {} virtual ~WebSocketEventRecorder() {} void SetOnOpen(const base::Callback& callback) { onopen_ = callback; } void SetOnMessage(const base::Callback& callback) { onmessage_ = callback; } void SetOnClose(const base::Callback& callback) { onclose_ = callback; } virtual void OnOpen(net::WebSocket* socket) { events_.push_back( WebSocketEvent(WebSocketEvent::EVENT_OPEN, socket, std::string(), false)); if (!onopen_.is_null()) onopen_.Run(&events_.back()); } virtual void OnMessage(net::WebSocket* socket, const std::string& msg) { events_.push_back( WebSocketEvent(WebSocketEvent::EVENT_MESSAGE, socket, msg, false)); if (!onmessage_.is_null()) onmessage_.Run(&events_.back()); } virtual void OnError(net::WebSocket* socket) { events_.push_back( WebSocketEvent(WebSocketEvent::EVENT_ERROR, socket, std::string(), false)); if (!onerror_.is_null()) onerror_.Run(&events_.back()); } virtual void OnClose(net::WebSocket* socket, bool was_clean) { events_.push_back( WebSocketEvent(WebSocketEvent::EVENT_CLOSE, socket, std::string(), was_clean)); if (!onclose_.is_null()) onclose_.Run(&events_.back()); if (callback_) callback_->Run(net::OK); } void DoClose(WebSocketEvent* event) { event->socket->Close(); } const std::vector& GetSeenEvents() const { return events_; } private: std::vector events_; base::Callback onopen_; base::Callback onmessage_; base::Callback onerror_; base::Callback onclose_; net::CompletionCallback* callback_; DISALLOW_COPY_AND_ASSIGN(WebSocketEventRecorder); }; namespace net { class WebSocketTest : public PlatformTest { protected: void InitReadBuf(WebSocket* websocket) { // Set up |current_read_buf_|. websocket->current_read_buf_ = new GrowableIOBuffer(); } void SetReadConsumed(WebSocket* websocket, int consumed) { websocket->read_consumed_len_ = consumed; } void AddToReadBuf(WebSocket* websocket, const char* data, int len) { websocket->AddToReadBuffer(data, len); } void TestProcessFrameData(WebSocket* websocket, const char* expected_remaining_data, int expected_remaining_len) { websocket->ProcessFrameData(); const char* actual_remaining_data = websocket->current_read_buf_->StartOfBuffer() + websocket->read_consumed_len_; int actual_remaining_len = websocket->current_read_buf_->offset() - websocket->read_consumed_len_; EXPECT_EQ(expected_remaining_len, actual_remaining_len); EXPECT_TRUE(!memcmp(expected_remaining_data, actual_remaining_data, expected_remaining_len)); } }; TEST_F(WebSocketTest, Connect) { MockClientSocketFactory mock_socket_factory; MockRead data_reads[] = { MockRead("HTTP/1.1 101 Web Socket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "WebSocket-Origin: http://example.com\r\n" "WebSocket-Location: ws://example.com/demo\r\n" "WebSocket-Protocol: sample\r\n" "\r\n"), // Server doesn't close the connection after handshake. MockRead(true, ERR_IO_PENDING), }; MockWrite data_writes[] = { MockWrite("GET /demo HTTP/1.1\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Host: example.com\r\n" "Origin: http://example.com\r\n" "WebSocket-Protocol: sample\r\n" "\r\n"), }; StaticSocketDataProvider data(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes)); mock_socket_factory.AddSocketDataProvider(&data); MockHostResolver host_resolver; WebSocket::Request* request( new WebSocket::Request(GURL("ws://example.com/demo"), "sample", "http://example.com", "ws://example.com/demo", WebSocket::DRAFT75, new TestURLRequestContext())); request->SetHostResolver(&host_resolver); request->SetClientSocketFactory(&mock_socket_factory); TestCompletionCallback callback; scoped_ptr delegate( new WebSocketEventRecorder(&callback)); delegate->SetOnOpen(base::Bind(&WebSocketEventRecorder::DoClose, base::Unretained(delegate.get()))); scoped_refptr websocket( new WebSocket(request, delegate.get())); EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state()); websocket->Connect(); callback.WaitForResult(); const std::vector& events = delegate->GetSeenEvents(); EXPECT_EQ(2U, events.size()); EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type); EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[1].event_type); } TEST_F(WebSocketTest, ServerSentData) { MockClientSocketFactory mock_socket_factory; static const char kMessage[] = "Hello"; static const char kFrame[] = "\x00Hello\xff"; static const int kFrameLen = sizeof(kFrame) - 1; MockRead data_reads[] = { MockRead("HTTP/1.1 101 Web Socket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "WebSocket-Origin: http://example.com\r\n" "WebSocket-Location: ws://example.com/demo\r\n" "WebSocket-Protocol: sample\r\n" "\r\n"), MockRead(true, kFrame, kFrameLen), // Server doesn't close the connection after handshake. MockRead(true, ERR_IO_PENDING), }; MockWrite data_writes[] = { MockWrite("GET /demo HTTP/1.1\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Host: example.com\r\n" "Origin: http://example.com\r\n" "WebSocket-Protocol: sample\r\n" "\r\n"), }; StaticSocketDataProvider data(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes)); mock_socket_factory.AddSocketDataProvider(&data); MockHostResolver host_resolver; WebSocket::Request* request( new WebSocket::Request(GURL("ws://example.com/demo"), "sample", "http://example.com", "ws://example.com/demo", WebSocket::DRAFT75, new TestURLRequestContext())); request->SetHostResolver(&host_resolver); request->SetClientSocketFactory(&mock_socket_factory); TestCompletionCallback callback; scoped_ptr delegate( new WebSocketEventRecorder(&callback)); delegate->SetOnMessage(base::Bind(&WebSocketEventRecorder::DoClose, base::Unretained(delegate.get()))); scoped_refptr websocket( new WebSocket(request, delegate.get())); EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state()); websocket->Connect(); callback.WaitForResult(); const std::vector& events = delegate->GetSeenEvents(); EXPECT_EQ(3U, events.size()); EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type); EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[1].event_type); EXPECT_EQ(kMessage, events[1].msg); EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[2].event_type); } TEST_F(WebSocketTest, ProcessFrameDataForLengthCalculation) { WebSocket::Request* request( new WebSocket::Request(GURL("ws://example.com/demo"), "sample", "http://example.com", "ws://example.com/demo", WebSocket::DRAFT75, new TestURLRequestContext())); TestCompletionCallback callback; scoped_ptr delegate( new WebSocketEventRecorder(&callback)); scoped_refptr websocket( new WebSocket(request, delegate.get())); // Frame data: skip length 1 ('x'), and try to skip length 129 // (1 * 128 + 1) bytes after \x81\x01, but buffer is too short to skip. static const char kTestLengthFrame[] = "\x80\x01x\x80\x81\x01\x01\x00unexpected data\xFF"; const int kTestLengthFrameLength = sizeof(kTestLengthFrame) - 1; InitReadBuf(websocket.get()); AddToReadBuf(websocket.get(), kTestLengthFrame, kTestLengthFrameLength); SetReadConsumed(websocket.get(), 0); static const char kExpectedRemainingFrame[] = "\x80\x81\x01\x01\x00unexpected data\xFF"; const int kExpectedRemainingLength = sizeof(kExpectedRemainingFrame) - 1; TestProcessFrameData(websocket.get(), kExpectedRemainingFrame, kExpectedRemainingLength); // No onmessage event expected. const std::vector& events = delegate->GetSeenEvents(); EXPECT_EQ(1U, events.size()); EXPECT_EQ(WebSocketEvent::EVENT_ERROR, events[0].event_type); websocket->DetachDelegate(); } TEST_F(WebSocketTest, ProcessFrameDataForUnterminatedString) { WebSocket::Request* request( new WebSocket::Request(GURL("ws://example.com/demo"), "sample", "http://example.com", "ws://example.com/demo", WebSocket::DRAFT75, new TestURLRequestContext())); TestCompletionCallback callback; scoped_ptr delegate( new WebSocketEventRecorder(&callback)); scoped_refptr websocket( new WebSocket(request, delegate.get())); static const char kTestUnterminatedFrame[] = "\x00unterminated frame"; const int kTestUnterminatedFrameLength = sizeof(kTestUnterminatedFrame) - 1; InitReadBuf(websocket.get()); AddToReadBuf(websocket.get(), kTestUnterminatedFrame, kTestUnterminatedFrameLength); SetReadConsumed(websocket.get(), 0); TestProcessFrameData(websocket.get(), kTestUnterminatedFrame, kTestUnterminatedFrameLength); { // No onmessage event expected. const std::vector& events = delegate->GetSeenEvents(); EXPECT_EQ(0U, events.size()); } static const char kTestTerminateFrame[] = " is terminated in next read\xff"; const int kTestTerminateFrameLength = sizeof(kTestTerminateFrame) - 1; AddToReadBuf(websocket.get(), kTestTerminateFrame, kTestTerminateFrameLength); TestProcessFrameData(websocket.get(), "", 0); static const char kExpectedMsg[] = "unterminated frame is terminated in next read"; { const std::vector& events = delegate->GetSeenEvents(); EXPECT_EQ(1U, events.size()); EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[0].event_type); EXPECT_EQ(kExpectedMsg, events[0].msg); } websocket->DetachDelegate(); } } // namespace net