diff options
-rw-r--r-- | net/net.gyp | 1 | ||||
-rw-r--r-- | net/websockets/websocket.cc | 38 | ||||
-rw-r--r-- | net/websockets/websocket.h | 6 | ||||
-rw-r--r-- | net/websockets/websocket_unittest.cc | 126 |
4 files changed, 153 insertions, 18 deletions
diff --git a/net/net.gyp b/net/net.gyp index 4f50133..db853bf 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -559,6 +559,7 @@ 'net_test_support', '../base/base.gyp:base', '../base/base.gyp:base_i18n', + '../testing/gmock.gyp:gmock', '../testing/gtest.gyp:gtest', '../third_party/zlib/zlib.gyp:zlib', ], diff --git a/net/websockets/websocket.cc b/net/websockets/websocket.cc index 719a870..ad7acaf 100644 --- a/net/websockets/websocket.cc +++ b/net/websockets/websocket.cc @@ -119,7 +119,10 @@ void WebSocket::OnConnected(SocketStream* socket_stream, read_consumed_len_ = 0; DCHECK(!current_write_buf_); - pending_write_bufs_.push_back(CreateClientHandshakeMessage()); + const std::string msg = request_->CreateClientHandshakeMessage(); + IOBufferWithSize* buf = new IOBufferWithSize(msg.size()); + memcpy(buf->data(), msg.data(), msg.size()); + pending_write_bufs_.push_back(buf); origin_loop_->PostTask(FROM_HERE, NewRunnableMethod(this, &WebSocket::SendPending)); } @@ -155,22 +158,22 @@ void WebSocket::OnError(const SocketStream* socket_stream, int error) { NewRunnableMethod(this, &WebSocket::DoError, error)); } -IOBufferWithSize* WebSocket::CreateClientHandshakeMessage() const { +std::string WebSocket::Request::CreateClientHandshakeMessage() const { std::string msg; msg = "GET "; - msg += request_->url().path(); - if (request_->url().has_query()) { + msg += url_.path(); + if (url_.has_query()) { msg += "?"; - msg += request_->url().query(); + msg += url_.query(); } msg += " HTTP/1.1\r\n"; msg += kUpgradeHeader; msg += kConnectionHeader; msg += "Host: "; - msg += StringToLowerASCII(request_->url().host()); - if (request_->url().has_port()) { - bool secure = request_->is_secure(); - int port = request_->url().EffectiveIntPort(); + msg += StringToLowerASCII(url_.host()); + if (url_.has_port()) { + bool secure = is_secure(); + int port = url_.EffectiveIntPort(); if ((!secure && port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) || (secure && @@ -181,18 +184,23 @@ IOBufferWithSize* WebSocket::CreateClientHandshakeMessage() const { } msg += "\r\n"; msg += "Origin: "; - msg += StringToLowerASCII(request_->origin()); + // It's OK to lowercase the origin as the Origin header does not contain + // the path or query portions, as per + // http://tools.ietf.org/html/draft-abarth-origin-00. + // + // TODO(satorux): Should we trim the port portion here if it's 80 for + // http:// or 443 for https:// ? Or can we assume it's done by the + // client of the library? + msg += StringToLowerASCII(origin_); msg += "\r\n"; - if (!request_->protocol().empty()) { + if (!protocol_.empty()) { msg += "WebSocket-Protocol: "; - msg += request_->protocol(); + msg += protocol_; msg += "\r\n"; } // TODO(ukai): Add cookie if necessary. msg += "\r\n"; - IOBufferWithSize* buf = new IOBufferWithSize(msg.size()); - memcpy(buf->data(), msg.data(), msg.size()); - return buf; + return msg; } int WebSocket::CheckHandshake() { diff --git a/net/websockets/websocket.h b/net/websockets/websocket.h index 566ee6f..0cf95db 100644 --- a/net/websockets/websocket.h +++ b/net/websockets/websocket.h @@ -95,6 +95,9 @@ class WebSocket : public base::RefCountedThreadSafe<WebSocket>, return client_socket_factory_; } + // Creates the client handshake message from |this|. + std::string CreateClientHandshakeMessage() const; + private: GURL url_; std::string protocol_; @@ -154,9 +157,6 @@ class WebSocket : public base::RefCountedThreadSafe<WebSocket>, friend class base::RefCountedThreadSafe<WebSocket>; virtual ~WebSocket(); - // Creates client handshake mssage based on |request_|. - IOBufferWithSize* CreateClientHandshakeMessage() const; - // Checks handshake. // Prerequisite: Server handshake message is received in |current_read_buf_|. // Returns number of bytes for server handshake message, diff --git a/net/websockets/websocket_unittest.cc b/net/websockets/websocket_unittest.cc index 08b7bed..e3c5725 100644 --- a/net/websockets/websocket_unittest.cc +++ b/net/websockets/websocket_unittest.cc @@ -14,6 +14,7 @@ #include "net/url_request/url_request_unittest.h" #include "net/websockets/websocket.h" #include "testing/gtest/include/gtest/gtest.h" +#include "testing/gmock/include/gmock/gmock.h" #include "testing/platform_test.h" struct WebSocketEvent { @@ -326,4 +327,129 @@ TEST_F(WebSocketTest, ProcessFrameDataForUnterminatedString) { websocket->DetachDelegate(); } +TEST(WebSocketRequestTest, is_secure_false) { + WebSocket::Request request(GURL("ws://example.com/demo"), + "sample", + "http://example.com", + "ws://example.com/demo", + NULL); + EXPECT_FALSE(request.is_secure()); +} + +TEST(WebSocketRequestTest, is_secure_true) { + // wss:// is secure. + WebSocket::Request request(GURL("wss://example.com/demo"), + "sample", + "http://example.com", + "wss://example.com/demo", + NULL); + EXPECT_TRUE(request.is_secure()); +} + +TEST(WebSocketRequestTest, CreateClientHandshakeMessage_Simple) { + WebSocket::Request request(GURL("ws://example.com/demo"), + "sample", + "http://example.com", + "ws://example.com/demo", + NULL); + EXPECT_EQ("GET /demo HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Host: example.com\r\n" + "Origin: http://example.com\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n", + request.CreateClientHandshakeMessage()); +} + +TEST(WebSocketRequestTest, CreateClientHandshakeMessage_PathAndQuery) { + WebSocket::Request request(GURL("ws://example.com/Test?q=xxx&p=%20"), + "sample", + "http://example.com", + "ws://example.com/demo", + NULL); + // Path and query should be preserved as-is. + EXPECT_THAT(request.CreateClientHandshakeMessage(), + testing::HasSubstr("GET /Test?q=xxx&p=%20 HTTP/1.1\r\n")); +} + +TEST(WebSocketRequestTest, CreateClientHandshakeMessage_Host) { + WebSocket::Request request(GURL("ws://Example.Com/demo"), + "sample", + "http://Example.Com", + "ws://Example.Com/demo", + NULL); + // Host should be lowercased + EXPECT_THAT(request.CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com\r\n")); + EXPECT_THAT(request.CreateClientHandshakeMessage(), + testing::HasSubstr("Origin: http://example.com\r\n")); +} + +TEST(WebSocketRequestTest, CreateClientHandshakeMessage_TrimPort80) { + WebSocket::Request request(GURL("ws://example.com:80/demo"), + "sample", + "http://example.com", + "ws://example.com/demo", + NULL); + // :80 should be trimmed as it's the default port for ws://. + EXPECT_THAT(request.CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com\r\n")); +} + +TEST(WebSocketRequestTest, CreateClientHandshakeMessage_TrimPort443) { + WebSocket::Request request(GURL("wss://example.com:443/demo"), + "sample", + "http://example.com", + "wss://example.com/demo", + NULL); + // :443 should be trimmed as it's the default port for wss://. + EXPECT_THAT(request.CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com\r\n")); +} + +TEST(WebSocketRequestTest, CreateClientHandshakeMessage_NonDefaultPortForWs) { + WebSocket::Request request(GURL("ws://example.com:8080/demo"), + "sample", + "http://example.com", + "wss://example.com/demo", + NULL); + // :8080 should be preserved as it's not the default port for ws://. + EXPECT_THAT(request.CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com:8080\r\n")); +} + +TEST(WebSocketRequestTest, CreateClientHandshakeMessage_NonDefaultPortForWss) { + WebSocket::Request request(GURL("wss://example.com:4443/demo"), + "sample", + "http://example.com", + "wss://example.com/demo", + NULL); + // :4443 should be preserved as it's not the default port for wss://. + EXPECT_THAT(request.CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com:4443\r\n")); +} + +TEST(WebSocketRequestTest, CreateClientHandshakeMessage_WsBut443) { + WebSocket::Request request(GURL("ws://example.com:443/demo"), + "sample", + "http://example.com", + "ws://example.com/demo", + NULL); + // :443 should be preserved as it's not the default port for ws://. + EXPECT_THAT(request.CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com:443\r\n")); +} + +TEST(WebSocketRequestTest, CreateClientHandshakeMessage_WssBut80) { + WebSocket::Request request(GURL("wss://example.com:80/demo"), + "sample", + "http://example.com", + "wss://example.com/demo", + NULL); + // :80 should be preserved as it's not the default port for wss://. + EXPECT_THAT(request.CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com:80\r\n")); +} + } // namespace net |