diff options
-rw-r--r-- | chrome/chrome_tests.gypi | 32 | ||||
-rw-r--r-- | chrome/test/chromedriver/net/sync_websocket.cc | 118 | ||||
-rw-r--r-- | chrome/test/chromedriver/net/sync_websocket.h | 111 | ||||
-rw-r--r-- | chrome/test/chromedriver/net/sync_websocket_unittest.cc | 174 | ||||
-rw-r--r-- | chrome/test/chromedriver/net/url_request_context_getter.cc | 55 | ||||
-rw-r--r-- | chrome/test/chromedriver/net/url_request_context_getter.h | 43 | ||||
-rw-r--r-- | chrome/test/chromedriver/net/websocket.cc | 150 | ||||
-rw-r--r-- | chrome/test/chromedriver/net/websocket.h | 86 | ||||
-rw-r--r-- | chrome/test/chromedriver/net/websocket_unittest.cc | 280 | ||||
-rw-r--r-- | net/websockets/websocket_frame.h | 12 | ||||
-rw-r--r-- | net/websockets/websocket_frame_parser.h | 2 |
11 files changed, 1055 insertions, 8 deletions
diff --git a/chrome/chrome_tests.gypi b/chrome/chrome_tests.gypi index 107381a..0d9ab38 100644 --- a/chrome/chrome_tests.gypi +++ b/chrome/chrome_tests.gypi @@ -652,7 +652,9 @@ 'type': 'static_library', 'dependencies': [ '../base/base.gyp:base', - '../base/third_party/dynamic_annotations/dynamic_annotations.gyp:dynamic_annotations' + '../base/third_party/dynamic_annotations/dynamic_annotations.gyp:dynamic_annotations', + '../build/temp_gyp/googleurl.gyp:googleurl', + '../net/net.gyp:net', ], 'include_dirs': [ '..', @@ -675,6 +677,12 @@ 'test/chromedriver/command_executor_impl.h', 'test/chromedriver/commands.cc', 'test/chromedriver/commands.h', + 'test/chromedriver/net/sync_websocket.cc', + 'test/chromedriver/net/sync_websocket.h', + 'test/chromedriver/net/url_request_context_getter.cc', + 'test/chromedriver/net/url_request_context_getter.h', + 'test/chromedriver/net/websocket.cc', + 'test/chromedriver/net/websocket.h', 'test/chromedriver/session.cc', 'test/chromedriver/session.h', 'test/chromedriver/session_command.cc', @@ -710,6 +718,28 @@ 'test/chromedriver/synchronized_map_unittest.cc', ], }, + # ChromeDriver2 tests that aren't run on the main buildbots. + { + 'target_name': 'chromedriver2_tests', + 'type': 'executable', + 'dependencies': [ + 'chromedriver2_lib', + '../base/base.gyp:base', + '../base/base.gyp:run_all_unittests', + '../build/temp_gyp/googleurl.gyp:googleurl', + '../net/net.gyp:http_server', + '../net/net.gyp:net', + '../net/net.gyp:net_test_support', + '../testing/gtest.gyp:gtest', + ], + 'include_dirs': [ + '..,' + ], + 'sources': [ + 'test/chromedriver/net/sync_websocket_unittest.cc', + 'test/chromedriver/net/websocket_unittest.cc', + ], + }, # This is the new ChromeDriver based on DevTools. { 'target_name': 'chromedriver2', diff --git a/chrome/test/chromedriver/net/sync_websocket.cc b/chrome/test/chromedriver/net/sync_websocket.cc new file mode 100644 index 0000000..dadf16c --- /dev/null +++ b/chrome/test/chromedriver/net/sync_websocket.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "chrome/test/chromedriver/net/sync_websocket.h" + +#include "base/bind.h" +#include "base/callback.h" +#include "base/location.h" +#include "base/single_thread_task_runner.h" +#include "base/synchronization/waitable_event.h" +#include "googleurl/src/gurl.h" +#include "net/base/net_errors.h" +#include "net/url_request/url_request_context_getter.h" + +SyncWebSocket::SyncWebSocket( + net::URLRequestContextGetter* context_getter) + : core_(new Core(context_getter)) {} + +SyncWebSocket::~SyncWebSocket() {} + +bool SyncWebSocket::Connect(const GURL& url) { + return core_->Connect(url); +} + +bool SyncWebSocket::Send(const std::string& message) { + return core_->Send(message); +} + +bool SyncWebSocket::ReceiveNextMessage(std::string* message) { + return core_->ReceiveNextMessage(message); +} + +SyncWebSocket::Core::Core(net::URLRequestContextGetter* context_getter) + : context_getter_(context_getter), + closed_(false), + on_update_event_(&lock_) {} + +bool SyncWebSocket::Core::Connect(const GURL& url) { + bool success = false; + base::WaitableEvent event(false, false); + context_getter_->GetNetworkTaskRunner()->PostTask( + FROM_HERE, + base::Bind(&SyncWebSocket::Core::ConnectOnIO, + this, url, &success, &event)); + event.Wait(); + return success; +} + +bool SyncWebSocket::Core::Send(const std::string& message) { + bool success = false; + base::WaitableEvent event(false, false); + context_getter_->GetNetworkTaskRunner()->PostTask( + FROM_HERE, + base::Bind(&SyncWebSocket::Core::SendOnIO, + this, message, &success, &event)); + event.Wait(); + return success; +} + +bool SyncWebSocket::Core::ReceiveNextMessage(std::string* message) { + base::AutoLock lock(lock_); + while (received_queue_.empty() && !closed_) on_update_event_.Wait(); + if (closed_) + return false; + *message = received_queue_.front(); + received_queue_.pop_front(); + return true; +} + +void SyncWebSocket::Core::OnMessageReceived(const std::string& message) { + base::AutoLock lock(lock_); + received_queue_.push_back(message); + on_update_event_.Signal(); +} + +void SyncWebSocket::Core::OnClose() { + base::AutoLock lock(lock_); + closed_ = true; + on_update_event_.Signal(); +} + +SyncWebSocket::Core::~Core() { } + +void SyncWebSocket::Core::ConnectOnIO( + const GURL& url, + bool* success, + base::WaitableEvent* event) { + socket_.reset(new WebSocket(context_getter_, url, this)); + socket_->Connect(base::Bind( + &SyncWebSocket::Core::OnConnectCompletedOnIO, + this, success, event)); +} + +void SyncWebSocket::Core::OnConnectCompletedOnIO( + bool* success, + base::WaitableEvent* event, + int error) { + *success = (error == net::OK); + event->Signal(); +} + +void SyncWebSocket::Core::SendOnIO( + const std::string& message, + bool* success, + base::WaitableEvent* event) { + *success = socket_->Send(message); + event->Signal(); +} + +void SyncWebSocket::Core::OnDestruct() const { + scoped_refptr<base::SingleThreadTaskRunner> network_task_runner = + context_getter_->GetNetworkTaskRunner(); + if (network_task_runner->BelongsToCurrentThread()) + delete this; + else + network_task_runner->DeleteSoon(FROM_HERE, this); +} diff --git a/chrome/test/chromedriver/net/sync_websocket.h b/chrome/test/chromedriver/net/sync_websocket.h new file mode 100644 index 0000000..ad5cf5d --- /dev/null +++ b/chrome/test/chromedriver/net/sync_websocket.h @@ -0,0 +1,111 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CHROME_TEST_CHROMEDRIVER_NET_SYNC_WEBSOCKET_H_ +#define CHROME_TEST_CHROMEDRIVER_NET_SYNC_WEBSOCKET_H_ + +#include <list> +#include <string> + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/synchronization/condition_variable.h" +#include "base/synchronization/lock.h" +#include "chrome/test/chromedriver/net/websocket.h" +#include "net/base/completion_callback.h" +#include "net/socket_stream/socket_stream.h" + +namespace base { +class WaitableEvent; +} + +namespace net { +class URLRequestContextGetter; +} + +class GURL; + +// Proxy for using a WebSocket running on a background thread synchronously. +class SyncWebSocket { + public: + explicit SyncWebSocket(net::URLRequestContextGetter* context_getter); + virtual ~SyncWebSocket(); + + // Connects to the WebSocket server. Returns true on success. + bool Connect(const GURL& url); + + // Sends message. Returns true on success. + bool Send(const std::string& message); + + // Receives next message. Blocks until at least one message is received or + // the socket is closed. Returns true on success and modifies |message|. + bool ReceiveNextMessage(std::string* message); + + private: + struct CoreTraits; + class Core : public WebSocketListener, + public base::RefCountedThreadSafe<Core, CoreTraits> { + public: + explicit Core(net::URLRequestContextGetter* context_getter); + + bool Connect(const GURL& url); + + bool Send(const std::string& message); + + bool ReceiveNextMessage(std::string* message); + + // Overriden from WebSocketListener: + virtual void OnMessageReceived(const std::string& message) OVERRIDE; + virtual void OnClose() OVERRIDE; + + private: + friend class base::RefCountedThreadSafe<Core, CoreTraits>; + friend class base::DeleteHelper<Core>; + friend struct CoreTraits; + + virtual ~Core(); + + void ConnectOnIO(const GURL& url, + bool* success, + base::WaitableEvent* event); + void OnConnectCompletedOnIO(bool* connected, + base::WaitableEvent* event, + int error); + void SendOnIO(const std::string& message, + bool* result, + base::WaitableEvent* event); + + // OnDestruct is meant to ensure deletion on the IO thread. + void OnDestruct() const; + + scoped_refptr<net::URLRequestContextGetter> context_getter_; + + // Only accessed on IO thread. + scoped_ptr<WebSocket> socket_; + + base::Lock lock_; + + // Protected by |lock_|. + bool closed_; + + // Protected by |lock_|. + std::list<std::string> received_queue_; + + // Protected by |lock_|. + // Signaled when the socket closes or a message is received. + base::ConditionVariable on_update_event_; + }; + + scoped_refptr<Core> core_; +}; + +struct SyncWebSocket::CoreTraits { + static void Destruct(const SyncWebSocket::Core* core) { + core->OnDestruct(); + } +}; + +#endif // CHROME_TEST_CHROMEDRIVER_NET_SYNC_WEBSOCKET_H_ diff --git a/chrome/test/chromedriver/net/sync_websocket_unittest.cc b/chrome/test/chromedriver/net/sync_websocket_unittest.cc new file mode 100644 index 0000000..5bb2b53 --- /dev/null +++ b/chrome/test/chromedriver/net/sync_websocket_unittest.cc @@ -0,0 +1,174 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <string> + +#include "base/bind.h" +#include "base/compiler_specific.h" +#include "base/location.h" +#include "base/memory/ref_counted.h" +#include "base/message_loop.h" +#include "base/message_loop_proxy.h" +#include "base/single_thread_task_runner.h" +#include "base/stringprintf.h" +#include "base/synchronization/waitable_event.h" +#include "base/threading/thread.h" +#include "chrome/test/chromedriver/net/sync_websocket.h" +#include "chrome/test/chromedriver/net/url_request_context_getter.h" +#include "googleurl/src/gurl.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/tcp_listen_socket.h" +#include "net/server/http_server.h" +#include "net/server/http_server_request_info.h" +#include "net/url_request/url_request_context_getter.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace { + +class SyncWebSocketTest : public testing::Test, + public net::HttpServer::Delegate { + public: + SyncWebSocketTest() + : io_thread_("io"), + close_on_receive_(false) { + base::Thread::Options options(MessageLoop::TYPE_IO, 0); + CHECK(io_thread_.StartWithOptions(options)); + context_getter_ = new URLRequestContextGetter( + io_thread_.message_loop_proxy()); + base::WaitableEvent event(false, false); + io_thread_.message_loop_proxy()->PostTask( + FROM_HERE, + base::Bind(&SyncWebSocketTest::InitOnIO, + base::Unretained(this), &event)); + event.Wait(); + } + + virtual ~SyncWebSocketTest() { + base::WaitableEvent event(false, false); + io_thread_.message_loop_proxy()->PostTask( + FROM_HERE, + base::Bind(&SyncWebSocketTest::DestroyServerOnIO, + base::Unretained(this), &event)); + event.Wait(); + } + + void InitOnIO(base::WaitableEvent* event) { + net::TCPListenSocketFactory factory("127.0.0.1", 0); + server_ = new net::HttpServer(factory, this); + net::IPEndPoint address; + CHECK_EQ(net::OK, server_->GetLocalAddress(&address)); + server_url_ = GURL(base::StringPrintf("ws://127.0.0.1:%d", address.port())); + event->Signal(); + } + + void DestroyServerOnIO(base::WaitableEvent* event) { + server_ = NULL; + event->Signal(); + } + + // Overridden from net::HttpServer::Delegate: + virtual void OnHttpRequest(int connection_id, + const net::HttpServerRequestInfo& info) {} + + virtual void OnWebSocketRequest(int connection_id, + const net::HttpServerRequestInfo& info) { + server_->AcceptWebSocket(connection_id, info); + } + + virtual void OnWebSocketMessage(int connection_id, + const std::string& data) { + if (close_on_receive_) { + MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&net::HttpServer::Close, server_, connection_id)); + } else { + server_->SendOverWebSocket(connection_id, data); + } + } + + virtual void OnClose(int connection_id) {} + + protected: + base::Thread io_thread_; + scoped_refptr<net::HttpServer> server_; + scoped_refptr<URLRequestContextGetter> context_getter_; + GURL server_url_; + bool close_on_receive_; +}; + +} // namespace + +TEST_F(SyncWebSocketTest, CreateDestroy) { + SyncWebSocket sock(context_getter_); +} + +TEST_F(SyncWebSocketTest, Connect) { + SyncWebSocket sock(context_getter_); + ASSERT_TRUE(sock.Connect(server_url_)); +} + +TEST_F(SyncWebSocketTest, ConnectFail) { + SyncWebSocket sock(context_getter_); + ASSERT_FALSE(sock.Connect(GURL("ws://127.0.0.1:33333"))); +} + +TEST_F(SyncWebSocketTest, SendReceive) { + SyncWebSocket sock(context_getter_); + ASSERT_TRUE(sock.Connect(server_url_)); + ASSERT_TRUE(sock.Send("hi")); + std::string message; + ASSERT_TRUE(sock.ReceiveNextMessage(&message)); + ASSERT_STREQ("hi", message.c_str()); +} + +TEST_F(SyncWebSocketTest, SendReceiveLarge) { + SyncWebSocket sock(context_getter_); + ASSERT_TRUE(sock.Connect(server_url_)); + // Sends/receives 200kb. For some reason pushing this above 240kb on my + // machine results in receiving no data back from the http server. + std::string wrote_message(200 << 10, 'a'); + ASSERT_TRUE(sock.Send(wrote_message)); + std::string message; + ASSERT_TRUE(sock.ReceiveNextMessage(&message)); + ASSERT_EQ(wrote_message.length(), message.length()); + ASSERT_EQ(wrote_message, message); +} + +TEST_F(SyncWebSocketTest, SendReceiveMany) { + SyncWebSocket sock(context_getter_); + ASSERT_TRUE(sock.Connect(server_url_)); + ASSERT_TRUE(sock.Send("1")); + ASSERT_TRUE(sock.Send("2")); + std::string message; + ASSERT_TRUE(sock.ReceiveNextMessage(&message)); + ASSERT_STREQ("1", message.c_str()); + ASSERT_TRUE(sock.Send("3")); + ASSERT_TRUE(sock.ReceiveNextMessage(&message)); + ASSERT_STREQ("2", message.c_str()); + ASSERT_TRUE(sock.ReceiveNextMessage(&message)); + ASSERT_STREQ("3", message.c_str()); +} + +TEST_F(SyncWebSocketTest, CloseOnReceive) { + close_on_receive_ = true; + SyncWebSocket sock(context_getter_); + ASSERT_TRUE(sock.Connect(server_url_)); + ASSERT_TRUE(sock.Send("1")); + std::string message; + ASSERT_FALSE(sock.ReceiveNextMessage(&message)); + ASSERT_STREQ("", message.c_str()); +} + +TEST_F(SyncWebSocketTest, CloseOnSend) { + SyncWebSocket sock(context_getter_); + ASSERT_TRUE(sock.Connect(server_url_)); + base::WaitableEvent event(false, false); + io_thread_.message_loop_proxy()->PostTask( + FROM_HERE, + base::Bind(&SyncWebSocketTest::DestroyServerOnIO, + base::Unretained(this), &event)); + event.Wait(); + ASSERT_FALSE(sock.Send("1")); +} diff --git a/chrome/test/chromedriver/net/url_request_context_getter.cc b/chrome/test/chromedriver/net/url_request_context_getter.cc new file mode 100644 index 0000000..b5b4a45 --- /dev/null +++ b/chrome/test/chromedriver/net/url_request_context_getter.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "chrome/test/chromedriver/net/url_request_context_getter.h" + +#include <string> + +#include "net/proxy/proxy_config_service.h" +#include "net/url_request/url_request_context.h" +#include "net/url_request/url_request_context_builder.h" + +namespace { + +// Config getter that always returns direct settings. +class ProxyConfigServiceDirect : public net::ProxyConfigService { + public: + // Overridden from ProxyConfigService: + virtual void AddObserver(Observer* observer) OVERRIDE {} + virtual void RemoveObserver(Observer* observer) OVERRIDE {} + virtual ConfigAvailability GetLatestProxyConfig( + net::ProxyConfig* config) OVERRIDE { + *config = net::ProxyConfig::CreateDirect(); + return CONFIG_VALID; + } +}; + +} // namespace + +URLRequestContextGetter::URLRequestContextGetter( + scoped_refptr<base::SingleThreadTaskRunner> network_task_runner) + : network_task_runner_(network_task_runner) { +} + +net::URLRequestContext* URLRequestContextGetter::GetURLRequestContext() { + CHECK(network_task_runner_->BelongsToCurrentThread()); + if (!url_request_context_) { + net::URLRequestContextBuilder builder; + // net::HttpServer fails to parse headers if user-agent header is blank. + builder.set_user_agent("chromedriver"); + builder.DisableHttpCache(); +#if defined(OS_LINUX) || defined(OS_ANDROID) + builder.set_proxy_config_service(new ProxyConfigServiceDirect()); +#endif + url_request_context_.reset(builder.Build()); + } + return url_request_context_.get(); +} + +scoped_refptr<base::SingleThreadTaskRunner> + URLRequestContextGetter::GetNetworkTaskRunner() const { + return network_task_runner_; +} + +URLRequestContextGetter::~URLRequestContextGetter() {} diff --git a/chrome/test/chromedriver/net/url_request_context_getter.h b/chrome/test/chromedriver/net/url_request_context_getter.h new file mode 100644 index 0000000..23f7680 --- /dev/null +++ b/chrome/test/chromedriver/net/url_request_context_getter.h @@ -0,0 +1,43 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CHROME_TEST_CHROMEDRIVER_NET_URL_REQUEST_CONTEXT_GETTER_H_ +#define CHROME_TEST_CHROMEDRIVER_NET_URL_REQUEST_CONTEXT_GETTER_H_ + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "net/url_request/url_request_context_getter.h" + +namespace base { +class SingleThreadTaskRunner; +} + +namespace net { +class URLRequestContext; +} + +class URLRequestContextGetter : public net::URLRequestContextGetter { + public: + explicit URLRequestContextGetter( + scoped_refptr<base::SingleThreadTaskRunner> network_task_runner); + + // Overridden from net::URLRequestContextGetter: + virtual net::URLRequestContext* GetURLRequestContext() OVERRIDE; + virtual scoped_refptr<base::SingleThreadTaskRunner> + GetNetworkTaskRunner() const OVERRIDE; + + private: + virtual ~URLRequestContextGetter(); + + scoped_refptr<base::SingleThreadTaskRunner> network_task_runner_; + + // Only accessed on the IO thread. + scoped_ptr<net::URLRequestContext> url_request_context_; + + DISALLOW_COPY_AND_ASSIGN(URLRequestContextGetter); +}; + +#endif // CHROME_TEST_CHROMEDRIVER_NET_URL_REQUEST_CONTEXT_GETTER_H_ diff --git a/chrome/test/chromedriver/net/websocket.cc b/chrome/test/chromedriver/net/websocket.cc new file mode 100644 index 0000000..86c2b3a --- /dev/null +++ b/chrome/test/chromedriver/net/websocket.cc @@ -0,0 +1,150 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "chrome/test/chromedriver/net/websocket.h" + +#include "base/base64.h" +#include "base/memory/scoped_vector.h" +#include "base/rand_util.h" +#include "base/sha1.h" +#include "base/string_split.h" +#include "base/stringprintf.h" +#include "net/base/io_buffer.h" +#include "net/http/http_response_headers.h" +#include "net/http/http_util.h" +#include "net/url_request/url_request_context_getter.h" +#include "net/websockets/websocket_frame.h" +#include "net/websockets/websocket_job.h" + +WebSocket::WebSocket( + net::URLRequestContextGetter* context_getter, + const GURL& url, + WebSocketListener* listener) + : context_getter_(context_getter), + url_(url), + listener_(listener), + connected_(false) { + net::WebSocketJob::EnsureInit(); + web_socket_ = new net::WebSocketJob(this); +} + +WebSocket::~WebSocket() { + CHECK(thread_checker_.CalledOnValidThread()); + web_socket_->Close(); + web_socket_->DetachDelegate(); +} + +void WebSocket::Connect(const net::CompletionCallback& callback) { + CHECK(thread_checker_.CalledOnValidThread()); + CHECK_EQ(net::WebSocketJob::INITIALIZED, web_socket_->state()); + + connect_callback_ = callback; + + scoped_refptr<net::SocketStream> socket = new net::SocketStream( + url_, web_socket_); + socket->set_context(context_getter_->GetURLRequestContext()); + + web_socket_->InitSocketStream(socket); + web_socket_->Connect(); +} + +bool WebSocket::Send(const std::string& message) { + CHECK(thread_checker_.CalledOnValidThread()); + + net::WebSocketFrameHeader header; + header.final = true; + header.reserved1 = false; + header.reserved2 = false; + header.reserved3 = false; + header.opcode = net::WebSocketFrameHeader::kOpCodeText; + header.masked = true; + header.payload_length = message.length(); + int header_size = net::GetWebSocketFrameHeaderSize(header); + net::WebSocketMaskingKey masking_key = net::GenerateWebSocketMaskingKey(); + std::string header_str; + header_str.resize(header_size); + CHECK_EQ(header_size, net::WriteWebSocketFrameHeader( + header, &masking_key, &header_str[0], header_str.length())); + + std::string masked_message = message; + net::MaskWebSocketFramePayload( + masking_key, 0, &masked_message[0], masked_message.length()); + std::string data = header_str + masked_message; + return web_socket_->SendData(data.c_str(), data.length()); +} + +void WebSocket::OnConnected(net::SocketStream* socket, + int max_pending_send_allowed) { + CHECK(base::Base64Encode(base::RandBytesAsString(16), &sec_key_)); + std::string handshake = base::StringPrintf( + "GET %s HTTP/1.1\r\n" + "Host: %s\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: %s\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Pragma: no-cache\r\n" + "Cache-Control: no-cache\r\n" + "\r\n", + url_.path().c_str(), + url_.host().c_str(), + sec_key_.c_str()); + if (!web_socket_->SendData(handshake.c_str(), handshake.length())) + OnConnectFinished(net::ERR_FAILED); +} + +void WebSocket::OnSentData(net::SocketStream* socket, + int amount_sent) {} + +void WebSocket::OnReceivedData(net::SocketStream* socket, + const char* data, int len) { + net::WebSocketJob::State state = web_socket_->state(); + if (!connect_callback_.is_null()) { + // WebSocketJob guarantees the first OnReceivedData call contains all + // the response headers. + const char kMagicKey[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + std::string websocket_accept; + CHECK(base::Base64Encode(base::SHA1HashString(sec_key_ + kMagicKey), + &websocket_accept)); + scoped_refptr<net::HttpResponseHeaders> headers( + new net::HttpResponseHeaders( + net::HttpUtil::AssembleRawHeaders(data, len))); + if (headers->response_code() != 101 || + !headers->HasHeaderValue("Upgrade", "WebSocket") || + !headers->HasHeaderValue("Connection", "Upgrade") || + !headers->HasHeaderValue("Sec-WebSocket-Accept", websocket_accept)) { + OnConnectFinished(net::ERR_FAILED); + return; + } + OnConnectFinished( + state == net::WebSocketJob::OPEN ? net::OK : net::ERR_FAILED); + } else if (connected_) { + ScopedVector<net::WebSocketFrameChunk> frame_chunks; + CHECK(parser_.Decode(data, len, &frame_chunks)); + for (size_t i = 0; i < frame_chunks.size(); ++i) { + scoped_refptr<net::IOBufferWithSize> buffer = frame_chunks[i]->data; + if (buffer) + next_message_ += std::string(buffer->data(), buffer->size()); + if (frame_chunks[i]->final_chunk) { + listener_->OnMessageReceived(next_message_); + next_message_.clear(); + } + } + } +} + +void WebSocket::OnClose(net::SocketStream* socket) { + if (!connect_callback_.is_null()) + OnConnectFinished(net::ERR_CONNECTION_CLOSED); + else + listener_->OnClose(); +} + +void WebSocket::OnConnectFinished(net::Error error) { + if (error == net::OK) + connected_ = true; + net::CompletionCallback temp = connect_callback_; + connect_callback_.Reset(); + temp.Run(error); +} diff --git a/chrome/test/chromedriver/net/websocket.h b/chrome/test/chromedriver/net/websocket.h new file mode 100644 index 0000000..f2481c9 --- /dev/null +++ b/chrome/test/chromedriver/net/websocket.h @@ -0,0 +1,86 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CHROME_TEST_CHROMEDRIVER_NET_WEBSOCKET_H_ +#define CHROME_TEST_CHROMEDRIVER_NET_WEBSOCKET_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "base/threading/thread_checker.h" +#include "googleurl/src/gurl.h" +#include "net/base/completion_callback.h" +#include "net/base/net_errors.h" +#include "net/socket_stream/socket_stream.h" +#include "net/websockets/websocket_frame_parser.h" + +namespace net { +class URLRequestContextGetter; +class WebSocketJob; +} // namespace net + +class WebSocketListener; + +// A text-only, non-thread safe WebSocket. Must be created and used on a single +// thread. Intended particularly for use with net::HttpServer. +class WebSocket : public net::SocketStream::Delegate { + public: + WebSocket(net::URLRequestContextGetter* context_getter, + const GURL& url, + WebSocketListener* listener); + virtual ~WebSocket(); + + // Initializes the WebSocket connection. Invokes the given callback with + // a net::Error. May only be called once. + void Connect(const net::CompletionCallback& callback); + + // Sends the given message and returns true on success. + bool Send(const std::string& message); + + // Overridden from net::SocketStream::Delegate: + virtual void OnConnected(net::SocketStream* socket, + int max_pending_send_allowed) OVERRIDE; + virtual void OnSentData(net::SocketStream* socket, + int amount_sent) OVERRIDE; + virtual void OnReceivedData(net::SocketStream* socket, + const char* data, + int len) OVERRIDE; + virtual void OnClose(net::SocketStream* socket) OVERRIDE; + + private: + void OnConnectFinished(net::Error error); + + base::ThreadChecker thread_checker_; + scoped_refptr<net::URLRequestContextGetter> context_getter_; + GURL url_; + scoped_refptr<net::WebSocketJob> web_socket_; + WebSocketListener* listener_; + net::CompletionCallback connect_callback_; + std::string sec_key_; + net::WebSocketFrameParser parser_; + std::string next_message_; + bool connected_; + + DISALLOW_COPY_AND_ASSIGN(WebSocket); +}; + +// Listens for WebSocket messages and disconnects on the same thread as the +// WebSocket. +class WebSocketListener { + public: + virtual ~WebSocketListener() {} + + // Called when a WebSocket message is received. + virtual void OnMessageReceived(const std::string& message) = 0; + + // Called when the WebSocket connection closes. Will be called at most once. + // Will not be called if the connection was never established or if the close + // was initiated by the client. + virtual void OnClose() = 0; +}; + +#endif // CHROME_TEST_CHROMEDRIVER_NET_WEBSOCKET_H_ diff --git a/chrome/test/chromedriver/net/websocket_unittest.cc b/chrome/test/chromedriver/net/websocket_unittest.cc new file mode 100644 index 0000000..d47df88 --- /dev/null +++ b/chrome/test/chromedriver/net/websocket_unittest.cc @@ -0,0 +1,280 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <string> +#include <vector> + +#include "base/bind.h" +#include "base/compiler_specific.h" +#include "base/location.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop.h" +#include "base/message_loop_proxy.h" +#include "base/run_loop.h" +#include "base/single_thread_task_runner.h" +#include "base/stringprintf.h" +#include "base/time.h" +#include "chrome/test/chromedriver/net/websocket.h" +#include "googleurl/src/gurl.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/tcp_listen_socket.h" +#include "net/server/http_server.h" +#include "net/server/http_server_request_info.h" +#include "net/url_request/url_request_context_getter.h" +#include "net/url_request/url_request_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace { + +void OnConnectFinished(int* save_error, int error) { + MessageLoop::current()->Quit(); + *save_error = error; +} + +class Listener : public WebSocketListener { + public: + explicit Listener(const std::vector<std::string>& messages) + : messages_(messages) {} + + virtual ~Listener() { + EXPECT_TRUE(messages_.empty()); + } + + virtual void OnMessageReceived(const std::string& message) OVERRIDE { + ASSERT_TRUE(messages_.size()); + EXPECT_EQ(messages_[0], message); + messages_.erase(messages_.begin()); + if (messages_.empty()) + MessageLoop::current()->Quit(); + } + + virtual void OnClose() OVERRIDE { + EXPECT_TRUE(false); + } + + private: + std::vector<std::string> messages_; +}; + +class CloseListener : public WebSocketListener { + public: + explicit CloseListener(bool expect_close) : expect_close_(expect_close) {} + + virtual ~CloseListener() { + EXPECT_FALSE(expect_close_); + } + + virtual void OnMessageReceived(const std::string& message) OVERRIDE {} + + virtual void OnClose() OVERRIDE { + EXPECT_TRUE(expect_close_); + if (expect_close_) + MessageLoop::current()->Quit(); + expect_close_ = false; + } + + private: + bool expect_close_; +}; + +class WebSocketTest : public testing::Test, + public net::HttpServer::Delegate { + public: + enum WebSocketRequestResponse { + kAccept = 0, + kNotFound, + kClose + }; + + WebSocketTest() + : ALLOW_THIS_IN_INITIALIZER_LIST(server_(CreateServer())), + context_getter_( + new net::TestURLRequestContextGetter(loop_.message_loop_proxy())), + ws_request_response_(kAccept), + close_on_message_(false), + quit_on_close_(false) { + net::IPEndPoint address; + CHECK_EQ(net::OK, server_->GetLocalAddress(&address)); + server_url_ = GURL(base::StringPrintf("ws://127.0.0.1:%d", address.port())); + } + + // Overridden from net::HttpServer::Delegate: + virtual void OnHttpRequest(int connection_id, + const net::HttpServerRequestInfo& info) {} + + virtual void OnWebSocketRequest(int connection_id, + const net::HttpServerRequestInfo& info) { + switch (ws_request_response_) { + case kAccept: + server_->AcceptWebSocket(connection_id, info); + break; + case kNotFound: + server_->Send404(connection_id); + break; + case kClose: + // net::HttpServer doesn't allow us to close connection during callback. + MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&net::HttpServer::Close, server_, connection_id)); + break; + } + } + + virtual void OnWebSocketMessage(int connection_id, + const std::string& data) { + if (close_on_message_) { + // net::HttpServer doesn't allow us to close connection during callback. + MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&net::HttpServer::Close, server_, connection_id)); + } else { + server_->SendOverWebSocket(connection_id, data); + } + } + + virtual void OnClose(int connection_id) { + if (quit_on_close_) + MessageLoop::current()->Quit(); + } + + protected: + net::HttpServer* CreateServer() { + net::TCPListenSocketFactory factory("127.0.0.1", 0); + return new net::HttpServer(factory, this); + } + + scoped_ptr<WebSocket> CreateWebSocket(const GURL& url, + WebSocketListener* listener) { + int error; + scoped_ptr<WebSocket> sock(new WebSocket( + context_getter_, url, listener)); + sock->Connect(base::Bind(&OnConnectFinished, &error)); + loop_.PostDelayedTask( + FROM_HERE, MessageLoop::QuitWhenIdleClosure(), + base::TimeDelta::FromSeconds(10)); + base::RunLoop().Run(); + if (error == net::OK) + return sock.Pass(); + return scoped_ptr<WebSocket>(); + } + + scoped_ptr<WebSocket> CreateConnectedWebSocket(WebSocketListener* listener) { + return CreateWebSocket(server_url_, listener); + } + + void SendReceive(const std::vector<std::string>& messages) { + Listener listener(messages); + scoped_ptr<WebSocket> sock(CreateConnectedWebSocket(&listener)); + ASSERT_TRUE(sock); + for (size_t i = 0; i < messages.size(); ++i) { + ASSERT_TRUE(sock->Send(messages[i])); + } + base::RunLoop run_loop; + loop_.PostDelayedTask( + FROM_HERE, run_loop.QuitClosure(), + base::TimeDelta::FromSeconds(10)); + run_loop.Run(); + } + + MessageLoopForIO loop_; + scoped_refptr<net::HttpServer> server_; + scoped_refptr<net::URLRequestContextGetter> context_getter_; + GURL server_url_; + WebSocketRequestResponse ws_request_response_; + bool close_on_message_; + bool quit_on_close_; +}; + +} // namespace + +TEST_F(WebSocketTest, CreateDestroy) { + CloseListener listener(false); + WebSocket sock(context_getter_, GURL("http://ok"), &listener); +} + +TEST_F(WebSocketTest, Connect) { + CloseListener listener(false); + ASSERT_TRUE(CreateWebSocket(server_url_, &listener)); + quit_on_close_ = true; + base::RunLoop run_loop; + loop_.PostDelayedTask( + FROM_HERE, run_loop.QuitClosure(), + base::TimeDelta::FromSeconds(10)); + run_loop.Run(); +} + +TEST_F(WebSocketTest, ConnectNoServer) { + CloseListener listener(false); + ASSERT_FALSE(CreateWebSocket(GURL("ws://127.0.0.1:33333"), NULL)); +} + +TEST_F(WebSocketTest, Connect404) { + ws_request_response_ = kNotFound; + CloseListener listener(false); + ASSERT_FALSE(CreateWebSocket(server_url_, NULL)); + quit_on_close_ = true; + base::RunLoop run_loop; + loop_.PostDelayedTask( + FROM_HERE, run_loop.QuitClosure(), + base::TimeDelta::FromSeconds(10)); + run_loop.Run(); +} + +TEST_F(WebSocketTest, ConnectServerClosesConn) { + ws_request_response_ = kClose; + CloseListener listener(false); + ASSERT_FALSE(CreateWebSocket(server_url_, &listener)); +} + +TEST_F(WebSocketTest, CloseOnReceive) { + close_on_message_ = true; + CloseListener listener(true); + scoped_ptr<WebSocket> sock(CreateConnectedWebSocket(&listener)); + ASSERT_TRUE(sock); + ASSERT_TRUE(sock->Send("hi")); + base::RunLoop run_loop; + loop_.PostDelayedTask( + FROM_HERE, run_loop.QuitClosure(), + base::TimeDelta::FromSeconds(10)); + run_loop.Run(); +} + +TEST_F(WebSocketTest, CloseOnSend) { + CloseListener listener(true); + scoped_ptr<WebSocket> sock(CreateConnectedWebSocket(&listener)); + ASSERT_TRUE(sock); + server_ = NULL; + loop_.PostTask( + FROM_HERE, + base::Bind(base::IgnoreResult(&WebSocket::Send), + base::Unretained(sock.get()), "hi")); + base::RunLoop run_loop; + loop_.PostDelayedTask( + FROM_HERE, run_loop.QuitClosure(), + base::TimeDelta::FromSeconds(10)); + run_loop.Run(); +} + +TEST_F(WebSocketTest, SendReceive) { + std::vector<std::string> messages; + messages.push_back("hello"); + SendReceive(messages); +} + +TEST_F(WebSocketTest, SendReceiveLarge) { + std::vector<std::string> messages; + // Sends/receives 200kb. For some reason pushing this above 240kb on my + // machine results in receiving no data back from the http server. + messages.push_back(std::string(200 << 10, 'a')); + SendReceive(messages); +} + +TEST_F(WebSocketTest, SendReceiveMultiple) { + std::vector<std::string> messages; + messages.push_back("1"); + messages.push_back("2"); + messages.push_back("3"); + SendReceive(messages); +} diff --git a/net/websockets/websocket_frame.h b/net/websockets/websocket_frame.h index 94f4ede..46b011f 100644 --- a/net/websockets/websocket_frame.h +++ b/net/websockets/websocket_frame.h @@ -20,7 +20,7 @@ class IOBufferWithSize; // // Members of this class correspond to each element in WebSocket frame header // (see http://tools.ietf.org/html/rfc6455#section-5.2). -struct NET_EXPORT_PRIVATE WebSocketFrameHeader { +struct NET_EXPORT WebSocketFrameHeader { typedef int OpCode; static const OpCode kOpCodeContinuation; static const OpCode kOpCodeText; @@ -65,7 +65,7 @@ struct NET_EXPORT_PRIVATE WebSocketFrameHeader { // // This struct is used for reading WebSocket frame data (created by // WebSocketFrameParser). To construct WebSocket frames, use functions below. -struct NET_EXPORT_PRIVATE WebSocketFrameChunk { +struct NET_EXPORT WebSocketFrameChunk { WebSocketFrameChunk(); ~WebSocketFrameChunk(); @@ -89,7 +89,7 @@ struct WebSocketMaskingKey { // Returns the size of WebSocket frame header. The size of WebSocket frame // header varies from 2 bytes to 14 bytes depending on the payload length // and maskedness. -NET_EXPORT_PRIVATE int GetWebSocketFrameHeaderSize( +NET_EXPORT int GetWebSocketFrameHeaderSize( const WebSocketFrameHeader& header); // Writes wire format of a WebSocket frame header into |output|, and returns @@ -108,14 +108,14 @@ NET_EXPORT_PRIVATE int GetWebSocketFrameHeaderSize( // GetWebSocketFrameHeaderSize() can be used to know the size of header // beforehand. If the size of |buffer| is insufficient, this function returns // ERR_INVALID_ARGUMENT and does not write any data to |buffer|. -NET_EXPORT_PRIVATE int WriteWebSocketFrameHeader( +NET_EXPORT int WriteWebSocketFrameHeader( const WebSocketFrameHeader& header, const WebSocketMaskingKey* masking_key, char* buffer, int buffer_size); // Generates a masking key suitable for use in a new WebSocket frame. -NET_EXPORT_PRIVATE WebSocketMaskingKey GenerateWebSocketMaskingKey(); +NET_EXPORT WebSocketMaskingKey GenerateWebSocketMaskingKey(); // Masks WebSocket frame payload. // @@ -129,7 +129,7 @@ NET_EXPORT_PRIVATE WebSocketMaskingKey GenerateWebSocketMaskingKey(); // // Since frame masking is a reversible operation, this function can also be // used for unmasking a WebSocket frame. -NET_EXPORT_PRIVATE void MaskWebSocketFramePayload( +NET_EXPORT void MaskWebSocketFramePayload( const WebSocketMaskingKey& masking_key, uint64 frame_offset, char* data, diff --git a/net/websockets/websocket_frame_parser.h b/net/websockets/websocket_frame_parser.h index 8f5c3ff..c2517c9 100644 --- a/net/websockets/websocket_frame_parser.h +++ b/net/websockets/websocket_frame_parser.h @@ -22,7 +22,7 @@ namespace net { // Specification of WebSocket frame format is available at // <http://tools.ietf.org/html/rfc6455#section-5>. -class NET_EXPORT_PRIVATE WebSocketFrameParser { +class NET_EXPORT WebSocketFrameParser { public: WebSocketFrameParser(); ~WebSocketFrameParser(); |