diff options
author | ukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2010-03-26 07:35:55 +0000 |
---|---|---|
committer | ukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2010-03-26 07:35:55 +0000 |
commit | 511d0a0a31a54e0cc0f15cb1b977dc9f9b20f0d3 (patch) | |
tree | 8a4b3eb672a1a3f4efd172de181892cc72ec7275 /net/websockets | |
parent | 57771511d6e0f0bfea8aafa6bd6aca1294b08f87 (diff) | |
download | chromium_src-511d0a0a31a54e0cc0f15cb1b977dc9f9b20f0d3.zip chromium_src-511d0a0a31a54e0cc0f15cb1b977dc9f9b20f0d3.tar.gz chromium_src-511d0a0a31a54e0cc0f15cb1b977dc9f9b20f0d3.tar.bz2 |
Implement new websocket handshake based on draft-hixie-thewebsocketprotocol-76
BUG=none
TEST=net_unittests passes
Review URL: http://codereview.chromium.org/1108002
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@42736 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net/websockets')
-rw-r--r-- | net/websockets/websocket.cc | 18 | ||||
-rw-r--r-- | net/websockets/websocket.h | 8 | ||||
-rw-r--r-- | net/websockets/websocket_handshake.cc | 318 | ||||
-rw-r--r-- | net/websockets/websocket_handshake.h | 79 | ||||
-rw-r--r-- | net/websockets/websocket_handshake_draft75.cc | 156 | ||||
-rw-r--r-- | net/websockets/websocket_handshake_draft75.h | 63 | ||||
-rw-r--r-- | net/websockets/websocket_handshake_draft75_unittest.cc | 217 | ||||
-rw-r--r-- | net/websockets/websocket_handshake_unittest.cc | 253 | ||||
-rw-r--r-- | net/websockets/websocket_unittest.cc | 4 |
9 files changed, 900 insertions, 216 deletions
diff --git a/net/websockets/websocket.cc b/net/websockets/websocket.cc index dc4f568..1cc3cf9 100644 --- a/net/websockets/websocket.cc +++ b/net/websockets/websocket.cc @@ -9,6 +9,7 @@ #include "base/message_loop.h" #include "net/websockets/websocket_handshake.h" +#include "net/websockets/websocket_handshake_draft75.h" namespace net { @@ -101,9 +102,20 @@ void WebSocket::OnConnected(SocketStream* socket_stream, DCHECK(!current_write_buf_); DCHECK(!handshake_.get()); - handshake_.reset(new WebSocketHandshake( - request_->url(), request_->origin(), request_->location(), - request_->protocol())); + switch (request_->version()) { + case DEFAULT_VERSION: + handshake_.reset(new WebSocketHandshake( + request_->url(), request_->origin(), request_->location(), + request_->protocol())); + break; + case DRAFT75: + handshake_.reset(new WebSocketHandshakeDraft75( + request_->url(), request_->origin(), request_->location(), + request_->protocol())); + break; + default: + NOTREACHED() << "Unexpected protocol version:" << request_->version(); + } const std::string msg = handshake_->CreateClientHandshakeMessage(); IOBufferWithSize* buf = new IOBufferWithSize(msg.size()); diff --git a/net/websockets/websocket.h b/net/websockets/websocket.h index 5e58391..427af56 100644 --- a/net/websockets/websocket.h +++ b/net/websockets/websocket.h @@ -60,15 +60,21 @@ class WebSocket : public base::RefCountedThreadSafe<WebSocket>, OPEN = 1, CLOSED = 2, }; + enum ProtocolVersion { + DEFAULT_VERSION = 0, + DRAFT75 = 1, + }; class Request { public: Request(const GURL& url, const std::string protocol, const std::string origin, const std::string location, + ProtocolVersion version, URLRequestContext* context) : url_(url), protocol_(protocol), origin_(origin), location_(location), + version_(version), context_(context), host_resolver_(NULL), client_socket_factory_(NULL) {} @@ -78,6 +84,7 @@ class WebSocket : public base::RefCountedThreadSafe<WebSocket>, const std::string& protocol() const { return protocol_; } const std::string& origin() const { return origin_; } const std::string& location() const { return location_; } + ProtocolVersion version() const { return version_; } URLRequestContext* context() const { return context_; } // Sets an alternative HostResolver. For testing purposes only. @@ -100,6 +107,7 @@ class WebSocket : public base::RefCountedThreadSafe<WebSocket>, std::string protocol_; std::string origin_; std::string location_; + ProtocolVersion version_; scoped_refptr<URLRequestContext> context_; scoped_refptr<HostResolver> host_resolver_; diff --git a/net/websockets/websocket_handshake.cc b/net/websockets/websocket_handshake.cc index c17ea34..6f660bc 100644 --- a/net/websockets/websocket_handshake.cc +++ b/net/websockets/websocket_handshake.cc @@ -4,6 +4,11 @@ #include "net/websockets/websocket_handshake.h" +#include <algorithm> +#include <vector> + +#include "base/md5.h" +#include "base/rand_util.h" #include "base/ref_counted.h" #include "base/string_util.h" #include "net/http/http_response_headers.h" @@ -14,19 +19,6 @@ namespace net { const int WebSocketHandshake::kWebSocketPort = 80; const int WebSocketHandshake::kSecureWebSocketPort = 443; -const char WebSocketHandshake::kServerHandshakeHeader[] = - "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"; -const size_t WebSocketHandshake::kServerHandshakeHeaderLength = - sizeof(kServerHandshakeHeader) - 1; - -const char WebSocketHandshake::kUpgradeHeader[] = "Upgrade: WebSocket\r\n"; -const size_t WebSocketHandshake::kUpgradeHeaderLength = - sizeof(kUpgradeHeader) - 1; - -const char WebSocketHandshake::kConnectionHeader[] = "Connection: Upgrade\r\n"; -const size_t WebSocketHandshake::kConnectionHeaderLength = - sizeof(kConnectionHeader) - 1; - WebSocketHandshake::WebSocketHandshake( const GURL& url, const std::string& origin, @@ -46,19 +38,94 @@ bool WebSocketHandshake::is_secure() const { return url_.SchemeIs("wss"); } -std::string WebSocketHandshake::CreateClientHandshakeMessage() const { +std::string WebSocketHandshake::CreateClientHandshakeMessage() { + if (!parameter_.get()) { + parameter_.reset(new Parameter); + parameter_->GenerateKeys(); + } std::string msg; + + // WebSocket protocol 4.1 Opening handshake. + msg = "GET "; - msg += url_.path(); + msg += GetResourceName(); + msg += " HTTP/1.1\r\n"; + + std::vector<std::string> fields; + + fields.push_back("Upgrade: WebSocket"); + fields.push_back("Connection: Upgrade"); + + fields.push_back("Host: " + GetHostFieldValue()); + + fields.push_back("Origin: " + GetOriginFieldValue()); + + if (!protocol_.empty()) + fields.push_back("Sec-WebSocket-Protocol: " + protocol_); + + // TODO(ukai): Add cookie if necessary. + + fields.push_back("Sec-WebSocket-Key1: " + parameter_->GetSecWebSocketKey1()); + fields.push_back("Sec-WebSocket-Key2: " + parameter_->GetSecWebSocketKey2()); + + std::random_shuffle(fields.begin(), fields.end()); + + for (size_t i = 0; i < fields.size(); i++) { + msg += fields[i] + "\r\n"; + } + msg += "\r\n"; + + msg.append(parameter_->GetKey3()); + return msg; +} + +int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) { + mode_ = MODE_INCOMPLETE; + int eoh = HttpUtil::LocateEndOfHeaders(data, len); + if (eoh < 0) + return -1; + + scoped_refptr<HttpResponseHeaders> headers( + new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh))); + + if (headers->response_code() != 101) { + mode_ = MODE_FAILED; + DLOG(INFO) << "Bad response code: " << headers->response_code(); + return eoh; + } + mode_ = MODE_NORMAL; + if (!ProcessHeaders(*headers) || !CheckResponseHeaders()) { + DLOG(INFO) << "Process Headers failed: " + << std::string(data, eoh); + mode_ = MODE_FAILED; + return eoh; + } + if (len < static_cast<size_t>(eoh + Parameter::kExpectedResponseSize)) { + mode_ = MODE_INCOMPLETE; + return -1; + } + uint8 expected[Parameter::kExpectedResponseSize]; + parameter_->GetExpectedResponse(expected); + if (memcmp(&data[eoh], expected, Parameter::kExpectedResponseSize)) { + mode_ = MODE_FAILED; + return eoh + Parameter::kExpectedResponseSize; + } + mode_ = MODE_CONNECTED; + return eoh + Parameter::kExpectedResponseSize; +} + +std::string WebSocketHandshake::GetResourceName() const { + std::string resource_name = url_.path(); if (url_.has_query()) { - msg += "?"; - msg += url_.query(); + resource_name += "?"; + resource_name += url_.query(); } - msg += " HTTP/1.1\r\n"; - msg += kUpgradeHeader; - msg += kConnectionHeader; - msg += "Host: "; - msg += StringToLowerASCII(url_.host()); + return resource_name; +} + +std::string WebSocketHandshake::GetHostFieldValue() const { + // url_.host() is expected to be encoded in punnycode here. + std::string host = StringToLowerASCII(url_.host()); if (url_.has_port()) { bool secure = is_secure(); int port = url_.EffectiveIntPort(); @@ -66,12 +133,14 @@ std::string WebSocketHandshake::CreateClientHandshakeMessage() const { port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) || (secure && port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) { - msg += ":"; - msg += IntToString(port); + host += ":"; + host += IntToString(port); } } - msg += "\r\n"; - msg += "Origin: "; + return host; +} + +std::string WebSocketHandshake::GetOriginFieldValue() const { // It's OK to lowercase the origin as the Origin header does not contain // the path or query portions, as per // http://tools.ietf.org/html/draft-abarth-origin-00. @@ -79,91 +148,13 @@ std::string WebSocketHandshake::CreateClientHandshakeMessage() const { // TODO(satorux): Should we trim the port portion here if it's 80 for // http:// or 443 for https:// ? Or can we assume it's done by the // client of the library? - msg += StringToLowerASCII(origin_); - msg += "\r\n"; - if (!protocol_.empty()) { - msg += "WebSocket-Protocol: "; - msg += protocol_; - msg += "\r\n"; - } - // TODO(ukai): Add cookie if necessary. - msg += "\r\n"; - return msg; + return StringToLowerASCII(origin_); } -int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) { - mode_ = MODE_INCOMPLETE; - if (len < kServerHandshakeHeaderLength) { - return -1; - } - if (!memcmp(data, kServerHandshakeHeader, kServerHandshakeHeaderLength)) { - mode_ = MODE_NORMAL; - } else { - int eoh = HttpUtil::LocateEndOfHeaders(data, len); - if (eoh < 0) - return -1; - return eoh; - } - const char* p = data + kServerHandshakeHeaderLength; - const char* end = data + len + 1; - - if (mode_ == MODE_NORMAL) { - size_t header_size = end - p; - if (header_size < kUpgradeHeaderLength) - return -1; - if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) { - mode_ = MODE_FAILED; - DLOG(INFO) << "Bad Upgrade Header " - << std::string(p, kUpgradeHeaderLength); - return p - data; - } - p += kUpgradeHeaderLength; - header_size = end - p; - if (header_size < kConnectionHeaderLength) - return -1; - if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) { - mode_ = MODE_FAILED; - DLOG(INFO) << "Bad Connection Header " - << std::string(p, kConnectionHeaderLength); - return p - data; - } - p += kConnectionHeaderLength; - } - - int eoh = HttpUtil::LocateEndOfHeaders(data, len); - if (eoh == -1) - return eoh; - - scoped_refptr<HttpResponseHeaders> headers( - new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh))); - if (!ProcessHeaders(*headers)) { - DLOG(INFO) << "Process Headers failed: " - << std::string(data, eoh); - mode_ = MODE_FAILED; - } - switch (mode_) { - case MODE_NORMAL: - if (CheckResponseHeaders()) { - mode_ = MODE_CONNECTED; - } else { - mode_ = MODE_FAILED; - } - break; - default: - mode_ = MODE_FAILED; - break; - } - return eoh; -} - -// Gets the value of the specified header. -// It assures only one header of |name| in |headers|. -// Returns true iff single header of |name| is found in |headers| -// and |value| is filled with the value. -// Returns false otherwise. -static bool GetSingleHeader(const HttpResponseHeaders& headers, - const std::string& name, - std::string* value) { +/* static */ +bool WebSocketHandshake::GetSingleHeader(const HttpResponseHeaders& headers, + const std::string& name, + std::string* value) { std::string first_value; void* iter = NULL; if (!headers.EnumerateHeader(&iter, name, &first_value)) @@ -179,16 +170,25 @@ static bool GetSingleHeader(const HttpResponseHeaders& headers, } bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) { - if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_)) + std::string value; + if (!GetSingleHeader(headers, "upgrade", &value) || + value != "WebSocket") + return false; + + if (!GetSingleHeader(headers, "connection", &value) || + !LowerCaseEqualsASCII(value, "upgrade")) + return false; + + if (!GetSingleHeader(headers, "sec-websocket-origin", &ws_origin_)) return false; - if (!GetSingleHeader(headers, "websocket-location", &ws_location_)) + if (!GetSingleHeader(headers, "sec-websocket-location", &ws_location_)) return false; // If |protocol_| is not specified by client, we don't care if there's // protocol field or not as specified in the spec. if (!protocol_.empty() - && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_)) + && !GetSingleHeader(headers, "sec-websocket-protocol", &ws_protocol_)) return false; return true; } @@ -204,6 +204,100 @@ bool WebSocketHandshake::CheckResponseHeaders() const { return true; } +namespace { + +// unsigned int version of base::RandInt(). +// we can't use base::RandInt(), because max would be negative if it is +// represented as int, so DCHECK(min <= max) fails. +uint32 RandUint32(uint32 min, uint32 max) { + DCHECK(min <= max); + + uint64 range = static_cast<int64>(max) - min + 1; + uint64 number = base::RandUint64(); + // TODO(ukai): fix to be uniform. + // the distribution of the result of modulo will be biased. + uint32 result = min + static_cast<uint32>(number % range); + DCHECK(result >= min && result <= max); + return result; +} + +} + +uint32 (*WebSocketHandshake::Parameter::rand_)(uint32 min, uint32 max) = + RandUint32; +uint8 randomCharacterInSecWebSocketKey[0x2F - 0x20 + 0x7E - 0x39]; +WebSocketHandshake::Parameter::Parameter() + : number_1_(0), number_2_(0) { + if (randomCharacterInSecWebSocketKey[0] == '\0') { + int i = 0; + for (int ch = 0x21; ch <= 0x2F; ch++, i++) + randomCharacterInSecWebSocketKey[i] = ch; + for (int ch = 0x3A; ch <= 0x7E; ch++, i++) + randomCharacterInSecWebSocketKey[i] = ch; + } +} + +WebSocketHandshake::Parameter::~Parameter() {} + +void WebSocketHandshake::Parameter::GenerateKeys() { + GenerateSecWebSocketKey(&number_1_, &key_1_); + GenerateSecWebSocketKey(&number_2_, &key_2_); + GenerateKey3(); +} + +static void SetChallengeNumber(uint8* buf, uint32 number) { + uint8* p = buf + 3; + for (int i = 0; i < 4; i++) { + *p = (uint8)(number & 0xFF); + --p; + number >>= 8; + } +} + +void WebSocketHandshake::Parameter::GetExpectedResponse(uint8 *expected) const { + uint8 challenge[kExpectedResponseSize]; + SetChallengeNumber(&challenge[0], number_1_); + SetChallengeNumber(&challenge[4], number_2_); + memcpy(&challenge[8], key_3_.data(), kKey3Size); + MD5Digest digest; + MD5Sum(challenge, kExpectedResponseSize, &digest); + memcpy(expected, digest.a, kExpectedResponseSize); +} + +/* static */ +void WebSocketHandshake::Parameter::SetRandomNumberGenerator( + uint32 (*rand)(uint32 min, uint32 max)) { + rand_ = rand; +} + +void WebSocketHandshake::Parameter::GenerateSecWebSocketKey( + uint32* number, std::string* key) { + uint32 space = rand_(1, 12); + uint32 max = 4294967295U / space; + *number = rand_(0, max); + uint32 product = *number * space; + + std::string s = StringPrintf("%010u", product); + for (uint32 i = 0; i < space; i++) { + int pos = rand_(1, s.length() - 1); + s = s.substr(0, pos) + " " + s.substr(pos); + } + int n = rand_(1, 12); + for (int i = 0; i < n; i++) { + int pos = rand_(0, s.length()); + int chpos = rand_(0, sizeof(randomCharacterInSecWebSocketKey) - 1); + s = s.substr(0, pos).append(1, randomCharacterInSecWebSocketKey[chpos]) + + s.substr(pos); + } + *key = s; +} + +void WebSocketHandshake::Parameter::GenerateKey3() { + key_3_.clear(); + for (int i = 0; i < 8; i++) { + key_3_.append(1, rand_(0, 255)); + } +} } // namespace net diff --git a/net/websockets/websocket_handshake.h b/net/websockets/websocket_handshake.h index 1e94eff..3f64b8b 100644 --- a/net/websockets/websocket_handshake.h +++ b/net/websockets/websocket_handshake.h @@ -8,6 +8,7 @@ #include <string> #include "base/basictypes.h" +#include "base/scoped_ptr.h" #include "googleurl/src/gurl.h" namespace net { @@ -18,12 +19,6 @@ class WebSocketHandshake { public: static const int kWebSocketPort; static const int kSecureWebSocketPort; - static const char kServerHandshakeHeader[]; - static const size_t kServerHandshakeHeaderLength; - static const char kUpgradeHeader[]; - static const size_t kUpgradeHeaderLength; - static const char kConnectionHeader[]; - static const size_t kConnectionHeaderLength; enum Mode { MODE_INCOMPLETE, MODE_NORMAL, MODE_FAILED, MODE_CONNECTED @@ -32,32 +27,33 @@ class WebSocketHandshake { const std::string& origin, const std::string& location, const std::string& protocol); - ~WebSocketHandshake(); + virtual ~WebSocketHandshake(); bool is_secure() const; // Creates the client handshake message from |this|. - std::string CreateClientHandshakeMessage() const; + virtual std::string CreateClientHandshakeMessage(); // Reads server handshake message in |len| of |data|, updates |mode_| and // returns number of bytes of the server handshake message. // Once connection is established, |mode_| will be MODE_CONNECTED. // If connection establishment failed, |mode_| will be MODE_FAILED. // Returns negative if the server handshake message is incomplete. - int ReadServerHandshake(const char* data, size_t len); + virtual int ReadServerHandshake(const char* data, size_t len); Mode mode() const { return mode_; } - private: - // Processes server handshake message, parsed as |headers|, and updates - // |ws_origin_|, |ws_location_| and |ws_protocol_|. - // Returns true if it's ok. - // Returns false otherwise (e.g. duplicate WebSocket-Origin: header, etc.) - bool ProcessHeaders(const HttpResponseHeaders& headers); - - // Checks |ws_origin_|, |ws_location_| and |ws_protocol_| are valid - // against |origin_|, |location_| and |protocol_|. - // Returns true if it's ok. - // Returns false otherwise (e.g. origin mismatch, etc.) - bool CheckResponseHeaders() const; + protected: + std::string GetResourceName() const; + std::string GetHostFieldValue() const; + std::string GetOriginFieldValue() const; + + // Gets the value of the specified header. + // It assures only one header of |name| in |headers|. + // Returns true iff single header of |name| is found in |headers| + // and |value| is filled with the value. + // Returns false otherwise. + static bool GetSingleHeader(const HttpResponseHeaders& headers, + const std::string& name, + std::string* value); GURL url_; // Handshake messages that the client is going to send out. @@ -72,6 +68,47 @@ class WebSocketHandshake { std::string ws_location_; std::string ws_protocol_; + private: + friend class WebSocketHandshakeTest; + + class Parameter { + public: + static const int kKey3Size = 8; + static const int kExpectedResponseSize = 16; + Parameter(); + ~Parameter(); + + void GenerateKeys(); + const std::string& GetSecWebSocketKey1() const { return key_1_; } + const std::string& GetSecWebSocketKey2() const { return key_2_; } + const std::string& GetKey3() const { return key_3_; } + + void GetExpectedResponse(uint8* expected) const; + + private: + friend class WebSocketHandshakeTest; + + // Set random number generator. |rand| should return a random number + // between min and max (inclusive). + static void SetRandomNumberGenerator( + uint32 (*rand)(uint32 min, uint32 max)); + void GenerateSecWebSocketKey(uint32* number, std::string* key); + void GenerateKey3(); + + uint32 number_1_; + uint32 number_2_; + std::string key_1_; + std::string key_2_; + std::string key_3_; + + static uint32 (*rand_)(uint32 min, uint32 max); + }; + + virtual bool ProcessHeaders(const HttpResponseHeaders& headers); + virtual bool CheckResponseHeaders() const; + + scoped_ptr<Parameter> parameter_; + DISALLOW_COPY_AND_ASSIGN(WebSocketHandshake); }; diff --git a/net/websockets/websocket_handshake_draft75.cc b/net/websockets/websocket_handshake_draft75.cc new file mode 100644 index 0000000..78805fb --- /dev/null +++ b/net/websockets/websocket_handshake_draft75.cc @@ -0,0 +1,156 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_handshake_draft75.h" + +#include "base/ref_counted.h" +#include "base/string_util.h" +#include "net/http/http_response_headers.h" +#include "net/http/http_util.h" + +namespace net { + +const char WebSocketHandshakeDraft75::kServerHandshakeHeader[] = + "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"; +const size_t WebSocketHandshakeDraft75::kServerHandshakeHeaderLength = + sizeof(kServerHandshakeHeader) - 1; + +const char WebSocketHandshakeDraft75::kUpgradeHeader[] = + "Upgrade: WebSocket\r\n"; +const size_t WebSocketHandshakeDraft75::kUpgradeHeaderLength = + sizeof(kUpgradeHeader) - 1; + +const char WebSocketHandshakeDraft75::kConnectionHeader[] = + "Connection: Upgrade\r\n"; +const size_t WebSocketHandshakeDraft75::kConnectionHeaderLength = + sizeof(kConnectionHeader) - 1; + +WebSocketHandshakeDraft75::WebSocketHandshakeDraft75( + const GURL& url, + const std::string& origin, + const std::string& location, + const std::string& protocol) + : WebSocketHandshake(url, origin, location, protocol) { +} + +WebSocketHandshakeDraft75::~WebSocketHandshakeDraft75() { +} + +std::string WebSocketHandshakeDraft75::CreateClientHandshakeMessage() { + std::string msg; + msg = "GET "; + msg += GetResourceName(); + msg += " HTTP/1.1\r\n"; + msg += kUpgradeHeader; + msg += kConnectionHeader; + msg += "Host: "; + msg += GetHostFieldValue(); + msg += "\r\n"; + msg += "Origin: "; + msg += GetOriginFieldValue(); + msg += "\r\n"; + if (!protocol_.empty()) { + msg += "WebSocket-Protocol: "; + msg += protocol_; + msg += "\r\n"; + } + // TODO(ukai): Add cookie if necessary. + msg += "\r\n"; + return msg; +} + +int WebSocketHandshakeDraft75::ReadServerHandshake( + const char* data, size_t len) { + mode_ = MODE_INCOMPLETE; + if (len < kServerHandshakeHeaderLength) { + return -1; + } + if (!memcmp(data, kServerHandshakeHeader, kServerHandshakeHeaderLength)) { + mode_ = MODE_NORMAL; + } else { + int eoh = HttpUtil::LocateEndOfHeaders(data, len); + if (eoh < 0) + return -1; + return eoh; + } + const char* p = data + kServerHandshakeHeaderLength; + const char* end = data + len; + + if (mode_ == MODE_NORMAL) { + size_t header_size = end - p; + if (header_size < kUpgradeHeaderLength) + return -1; + if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) { + mode_ = MODE_FAILED; + DLOG(INFO) << "Bad Upgrade Header " + << std::string(p, kUpgradeHeaderLength); + return p - data; + } + p += kUpgradeHeaderLength; + header_size = end - p; + if (header_size < kConnectionHeaderLength) + return -1; + if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) { + mode_ = MODE_FAILED; + DLOG(INFO) << "Bad Connection Header " + << std::string(p, kConnectionHeaderLength); + return p - data; + } + p += kConnectionHeaderLength; + } + + int eoh = HttpUtil::LocateEndOfHeaders(data, len); + if (eoh == -1) + return eoh; + + scoped_refptr<HttpResponseHeaders> headers( + new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh))); + if (!ProcessHeaders(*headers)) { + DLOG(INFO) << "Process Headers failed: " + << std::string(data, eoh); + mode_ = MODE_FAILED; + } + switch (mode_) { + case MODE_NORMAL: + if (CheckResponseHeaders()) { + mode_ = MODE_CONNECTED; + } else { + mode_ = MODE_FAILED; + } + break; + default: + mode_ = MODE_FAILED; + break; + } + return eoh; +} + +bool WebSocketHandshakeDraft75::ProcessHeaders( + const HttpResponseHeaders& headers) { + if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_)) + return false; + + if (!GetSingleHeader(headers, "websocket-location", &ws_location_)) + return false; + + // If |protocol_| is not specified by client, we don't care if there's + // protocol field or not as specified in the spec. + if (!protocol_.empty() + && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_)) + return false; + return true; +} + +bool WebSocketHandshakeDraft75::CheckResponseHeaders() const { + DCHECK(mode_ == MODE_NORMAL); + if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str())) + return false; + if (location_ != ws_location_) + return false; + if (!protocol_.empty() && protocol_ != ws_protocol_) + return false; + return true; +} + +} // namespace net diff --git a/net/websockets/websocket_handshake_draft75.h b/net/websockets/websocket_handshake_draft75.h new file mode 100644 index 0000000..6cc0506 --- /dev/null +++ b/net/websockets/websocket_handshake_draft75.h @@ -0,0 +1,63 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_DRAFT75_H_ +#define NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_DRAFT75_H_ + +#include <string> + +#include "base/basictypes.h" +#include "googleurl/src/gurl.h" +#include "net/websockets/websocket_handshake.h" + +namespace net { + +class HttpResponseHeaders; + +class WebSocketHandshakeDraft75 : public WebSocketHandshake { + public: + static const int kWebSocketPort; + static const int kSecureWebSocketPort; + static const char kServerHandshakeHeader[]; + static const size_t kServerHandshakeHeaderLength; + static const char kUpgradeHeader[]; + static const size_t kUpgradeHeaderLength; + static const char kConnectionHeader[]; + static const size_t kConnectionHeaderLength; + + WebSocketHandshakeDraft75(const GURL& url, + const std::string& origin, + const std::string& location, + const std::string& protocol); + virtual ~WebSocketHandshakeDraft75(); + + // Creates the client handshake message from |this|. + virtual std::string CreateClientHandshakeMessage(); + + // Reads server handshake message in |len| of |data|, updates |mode_| and + // returns number of bytes of the server handshake message. + // Once connection is established, |mode_| will be MODE_CONNECTED. + // If connection establishment failed, |mode_| will be MODE_FAILED. + // Returns negative if the server handshake message is incomplete. + virtual int ReadServerHandshake(const char* data, size_t len); + + private: + // Processes server handshake message, parsed as |headers|, and updates + // |ws_origin_|, |ws_location_| and |ws_protocol_|. + // Returns true if it's ok. + // Returns false otherwise (e.g. duplicate WebSocket-Origin: header, etc.) + virtual bool ProcessHeaders(const HttpResponseHeaders& headers); + + // Checks |ws_origin_|, |ws_location_| and |ws_protocol_| are valid + // against |origin_|, |location_| and |protocol_|. + // Returns true if it's ok. + // Returns false otherwise (e.g. origin mismatch, etc.) + virtual bool CheckResponseHeaders() const; + + DISALLOW_COPY_AND_ASSIGN(WebSocketHandshakeDraft75); +}; + +} // namespace net + +#endif // NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_DRAFT75_H_ diff --git a/net/websockets/websocket_handshake_draft75_unittest.cc b/net/websockets/websocket_handshake_draft75_unittest.cc new file mode 100644 index 0000000..aff75ad --- /dev/null +++ b/net/websockets/websocket_handshake_draft75_unittest.cc @@ -0,0 +1,217 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <string> +#include <vector> + +#include "base/scoped_ptr.h" +#include "net/websockets/websocket_handshake_draft75.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/platform_test.h" + +namespace net { + +TEST(WebSocketHandshakeDraft75Test, Connect) { + const std::string kExpectedClientHandshakeMessage = + "GET /demo HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Host: example.com\r\n" + "Origin: http://example.com\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"; + + scoped_ptr<WebSocketHandshakeDraft75> handshake( + new WebSocketHandshakeDraft75(GURL("ws://example.com/demo"), + "http://example.com", + "ws://example.com/demo", + "sample")); + EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode()); + EXPECT_EQ(kExpectedClientHandshakeMessage, + handshake->CreateClientHandshakeMessage()); + + const char kResponse[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: http://example.com\r\n" + "WebSocket-Location: ws://example.com/demo\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"; + + EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode()); + // too short + EXPECT_EQ(-1, handshake->ReadServerHandshake(kResponse, 16)); + EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode()); + // only status line + EXPECT_EQ(-1, handshake->ReadServerHandshake( + kResponse, + WebSocketHandshakeDraft75::kServerHandshakeHeaderLength)); + EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode()); + // by upgrade header + EXPECT_EQ(-1, handshake->ReadServerHandshake( + kResponse, + WebSocketHandshakeDraft75::kServerHandshakeHeaderLength + + WebSocketHandshakeDraft75::kUpgradeHeaderLength)); + EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode()); + // by connection header + EXPECT_EQ(-1, handshake->ReadServerHandshake( + kResponse, + WebSocketHandshakeDraft75::kServerHandshakeHeaderLength + + WebSocketHandshakeDraft75::kUpgradeHeaderLength + + WebSocketHandshakeDraft75::kConnectionHeaderLength)); + EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode()); + + EXPECT_EQ(-1, handshake->ReadServerHandshake( + kResponse, sizeof(kResponse) - 2)); + EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode()); + + int handshake_length = strlen(kResponse); + EXPECT_EQ(handshake_length, handshake->ReadServerHandshake( + kResponse, sizeof(kResponse) - 1)); // -1 for terminating \0 + EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode()); +} + +TEST(WebSocketHandshakeDraft75Test, ServerSentData) { + const std::string kExpectedClientHandshakeMessage = + "GET /demo HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Host: example.com\r\n" + "Origin: http://example.com\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"; + scoped_ptr<WebSocketHandshakeDraft75> handshake( + new WebSocketHandshakeDraft75(GURL("ws://example.com/demo"), + "http://example.com", + "ws://example.com/demo", + "sample")); + EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode()); + EXPECT_EQ(kExpectedClientHandshakeMessage, + handshake->CreateClientHandshakeMessage()); + + const char kResponse[] ="HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: http://example.com\r\n" + "WebSocket-Location: ws://example.com/demo\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n" + "\0Hello\xff"; + + int handshake_length = strlen(kResponse); + EXPECT_EQ(handshake_length, handshake->ReadServerHandshake( + kResponse, sizeof(kResponse) - 1)); // -1 for terminating \0 + EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode()); +} + +TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_Simple) { + scoped_ptr<WebSocketHandshakeDraft75> handshake( + new WebSocketHandshakeDraft75(GURL("ws://example.com/demo"), + "http://example.com", + "ws://example.com/demo", + "sample")); + EXPECT_EQ("GET /demo HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Host: example.com\r\n" + "Origin: http://example.com\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n", + handshake->CreateClientHandshakeMessage()); +} + +TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_PathAndQuery) { + scoped_ptr<WebSocketHandshakeDraft75> handshake( + new WebSocketHandshakeDraft75(GURL("ws://example.com/Test?q=xxx&p=%20"), + "http://example.com", + "ws://example.com/demo", + "sample")); + // Path and query should be preserved as-is. + EXPECT_THAT(handshake->CreateClientHandshakeMessage(), + testing::HasSubstr("GET /Test?q=xxx&p=%20 HTTP/1.1\r\n")); +} + +TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_Host) { + scoped_ptr<WebSocketHandshakeDraft75> handshake( + new WebSocketHandshakeDraft75(GURL("ws://Example.Com/demo"), + "http://Example.Com", + "ws://Example.Com/demo", + "sample")); + // Host should be lowercased + EXPECT_THAT(handshake->CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com\r\n")); + EXPECT_THAT(handshake->CreateClientHandshakeMessage(), + testing::HasSubstr("Origin: http://example.com\r\n")); +} + +TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_TrimPort80) { + scoped_ptr<WebSocketHandshakeDraft75> handshake( + new WebSocketHandshakeDraft75(GURL("ws://example.com:80/demo"), + "http://example.com", + "ws://example.com/demo", + "sample")); + // :80 should be trimmed as it's the default port for ws://. + EXPECT_THAT(handshake->CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com\r\n")); +} + +TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_TrimPort443) { + scoped_ptr<WebSocketHandshakeDraft75> handshake( + new WebSocketHandshakeDraft75(GURL("wss://example.com:443/demo"), + "http://example.com", + "wss://example.com/demo", + "sample")); + // :443 should be trimmed as it's the default port for wss://. + EXPECT_THAT(handshake->CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com\r\n")); +} + +TEST(WebSocketHandshakeDraft75Test, + CreateClientHandshakeMessage_NonDefaultPortForWs) { + scoped_ptr<WebSocketHandshakeDraft75> handshake( + new WebSocketHandshakeDraft75(GURL("ws://example.com:8080/demo"), + "http://example.com", + "wss://example.com/demo", + "sample")); + // :8080 should be preserved as it's not the default port for ws://. + EXPECT_THAT(handshake->CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com:8080\r\n")); +} + +TEST(WebSocketHandshakeDraft75Test, + CreateClientHandshakeMessage_NonDefaultPortForWss) { + scoped_ptr<WebSocketHandshakeDraft75> handshake( + new WebSocketHandshakeDraft75(GURL("wss://example.com:4443/demo"), + "http://example.com", + "wss://example.com/demo", + "sample")); + // :4443 should be preserved as it's not the default port for wss://. + EXPECT_THAT(handshake->CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com:4443\r\n")); +} + +TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_WsBut443) { + scoped_ptr<WebSocketHandshakeDraft75> handshake( + new WebSocketHandshakeDraft75(GURL("ws://example.com:443/demo"), + "http://example.com", + "ws://example.com/demo", + "sample")); + // :443 should be preserved as it's not the default port for ws://. + EXPECT_THAT(handshake->CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com:443\r\n")); +} + +TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_WssBut80) { + scoped_ptr<WebSocketHandshakeDraft75> handshake( + new WebSocketHandshakeDraft75(GURL("wss://example.com:80/demo"), + "http://example.com", + "wss://example.com/demo", + "sample")); + // :80 should be preserved as it's not the default port for wss://. + EXPECT_THAT(handshake->CreateClientHandshakeMessage(), + testing::HasSubstr("Host: example.com:80\r\n")); +} + +} // namespace net diff --git a/net/websockets/websocket_handshake_unittest.cc b/net/websockets/websocket_handshake_unittest.cc index beae805..f688554 100644 --- a/net/websockets/websocket_handshake_unittest.cc +++ b/net/websockets/websocket_handshake_unittest.cc @@ -6,6 +6,7 @@ #include <vector> #include "base/scoped_ptr.h" +#include "base/string_util.h" #include "net/websockets/websocket_handshake.h" #include "testing/gtest/include/gtest/gtest.h" #include "testing/gmock/include/gmock/gmock.h" @@ -13,100 +14,216 @@ namespace net { -TEST(WebSocketHandshakeTest, Connect) { +class WebSocketHandshakeTest : public testing::Test { + public: + static void SetUpParameter(WebSocketHandshake* handshake, + uint32 number_1, uint32 number_2, + const std::string& key_1, const std::string& key_2, + const std::string& key_3) { + WebSocketHandshake::Parameter* parameter = + new WebSocketHandshake::Parameter; + parameter->number_1_ = number_1; + parameter->number_2_ = number_2; + parameter->key_1_ = key_1; + parameter->key_2_ = key_2; + parameter->key_3_ = key_3; + handshake->parameter_.reset(parameter); + } + + static void ExpectHeaderEquals(const std::string& expected, + const std::string& actual) { + std::vector<std::string> expected_lines; + Tokenize(expected, "\r\n", &expected_lines); + std::vector<std::string> actual_lines; + Tokenize(actual, "\r\n", &actual_lines); + // Request lines. + EXPECT_EQ(expected_lines[0], actual_lines[0]); + + std::vector<std::string> expected_headers; + for (size_t i = 1; i < expected_lines.size(); i++) { + // Finish at first CRLF CRLF. Note that /key_3/ might include CRLF. + if (expected_lines[i] == "") + break; + expected_headers.push_back(expected_lines[i]); + } + sort(expected_headers.begin(), expected_headers.end()); + + std::vector<std::string> actual_headers; + for (size_t i = 1; i < actual_lines.size(); i++) { + // Finish at first CRLF CRLF. Note that /key_3/ might include CRLF. + if (actual_lines[i] == "") + break; + actual_headers.push_back(actual_lines[i]); + } + sort(actual_headers.begin(), actual_headers.end()); + + EXPECT_EQ(expected_headers.size(), actual_headers.size()) + << "expected:" << expected + << "\nactual:" << actual; + for (size_t i = 0; i < expected_headers.size(); i++) { + EXPECT_EQ(expected_headers[i], actual_headers[i]); + } + } + + static void ExpectHandshakeMessageEquals(const std::string& expected, + const std::string& actual) { + // Headers. + ExpectHeaderEquals(expected, actual); + // Compare tailing \r\n\r\n<key3> (4 + 8 bytes). + ASSERT_GT(expected.size(), 12U); + const char* expected_key3 = expected.data() + expected.size() - 12; + EXPECT_GT(actual.size(), 12U); + if (actual.size() <= 12U) + return; + const char* actual_key3 = actual.data() + actual.size() - 12; + EXPECT_TRUE(memcmp(expected_key3, actual_key3, 12) == 0) + << "expected_key3:" << DumpKey(expected_key3, 12) + << ", actual_key3:" << DumpKey(actual_key3, 12); + } + + static std::string DumpKey(const char* buf, int len) { + std::string s; + for (int i = 0; i < len; i++) { + if (isprint(buf[i])) + s += StringPrintf("%c", buf[i]); + else + s += StringPrintf("\\x%02x", buf[i]); + } + return s; + } + + static std::string GetResourceName(WebSocketHandshake* handshake) { + return handshake->GetResourceName(); + } + static std::string GetHostFieldValue(WebSocketHandshake* handshake) { + return handshake->GetHostFieldValue(); + } + static std::string GetOriginFieldValue(WebSocketHandshake* handshake) { + return handshake->GetOriginFieldValue(); + } +}; + + +TEST_F(WebSocketHandshakeTest, Connect) { const std::string kExpectedClientHandshakeMessage = "GET /demo HTTP/1.1\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Host: example.com\r\n" "Origin: http://example.com\r\n" - "WebSocket-Protocol: sample\r\n" - "\r\n"; + "Sec-WebSocket-Protocol: sample\r\n" + "Sec-WebSocket-Key1: 388P O503D&ul7 {K%gX( %7 15\r\n" + "Sec-WebSocket-Key2: 1 N ?|k UT0or 3o 4 I97N 5-S3O 31\r\n" + "\r\n" + "\x47\x30\x22\x2D\x5A\x3F\x47\x58"; scoped_ptr<WebSocketHandshake> handshake( new WebSocketHandshake(GURL("ws://example.com/demo"), "http://example.com", "ws://example.com/demo", "sample")); + SetUpParameter(handshake.get(), 777007543U, 114997259U, + "388P O503D&ul7 {K%gX( %7 15", + "1 N ?|k UT0or 3o 4 I97N 5-S3O 31", + std::string("\x47\x30\x22\x2D\x5A\x3F\x47\x58", 8)); EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode()); - EXPECT_EQ(kExpectedClientHandshakeMessage, - handshake->CreateClientHandshakeMessage()); + ExpectHandshakeMessageEquals( + kExpectedClientHandshakeMessage, + handshake->CreateClientHandshakeMessage()); - const char kResponse[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + const char kResponse[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" - "WebSocket-Origin: http://example.com\r\n" - "WebSocket-Location: ws://example.com/demo\r\n" - "WebSocket-Protocol: sample\r\n" - "\r\n"; + "Sec-WebSocket-Origin: http://example.com\r\n" + "Sec-WebSocket-Location: ws://example.com/demo\r\n" + "Sec-WebSocket-Protocol: sample\r\n" + "\r\n" + "\x30\x73\x74\x33\x52\x6C\x26\x71\x2D\x32\x5A\x55\x5E\x77\x65\x75"; + std::vector<std::string> response_lines; + SplitStringDontTrim(kResponse, '\n', &response_lines); EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode()); // too short EXPECT_EQ(-1, handshake->ReadServerHandshake(kResponse, 16)); EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode()); + // only status line + std::string response = response_lines[0]; EXPECT_EQ(-1, handshake->ReadServerHandshake( - kResponse, - WebSocketHandshake::kServerHandshakeHeaderLength)); - EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode()); + response.data(), response.size())); + EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode()); // by upgrade header + response += response_lines[1]; EXPECT_EQ(-1, handshake->ReadServerHandshake( - kResponse, - WebSocketHandshake::kServerHandshakeHeaderLength + - WebSocketHandshake::kUpgradeHeaderLength)); - EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode()); + response.data(), response.size())); + EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode()); // by connection header + response += response_lines[2]; EXPECT_EQ(-1, handshake->ReadServerHandshake( - kResponse, - WebSocketHandshake::kServerHandshakeHeaderLength + - WebSocketHandshake::kUpgradeHeaderLength + - WebSocketHandshake::kConnectionHeaderLength)); - EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode()); + response.data(), response.size())); + EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode()); + response += response_lines[3]; // Sec-WebSocket-Origin + response += response_lines[4]; // Sec-WebSocket-Location + response += response_lines[5]; // Sec-WebSocket-Protocol EXPECT_EQ(-1, handshake->ReadServerHandshake( - kResponse, sizeof(kResponse) - 2)); - EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode()); + response.data(), response.size())); + EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode()); - int handshake_length = strlen(kResponse); + response += response_lines[6]; // \r\n + EXPECT_EQ(-1, handshake->ReadServerHandshake( + response.data(), response.size())); + EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode()); + + int handshake_length = sizeof(kResponse) - 1; // -1 for terminating \0 EXPECT_EQ(handshake_length, handshake->ReadServerHandshake( - kResponse, sizeof(kResponse) - 1)); // -1 for terminating \0 + kResponse, handshake_length)); // -1 for terminating \0 EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode()); } -TEST(WebSocketHandshakeTest, ServerSentData) { +TEST_F(WebSocketHandshakeTest, ServerSentData) { const std::string kExpectedClientHandshakeMessage = "GET /demo HTTP/1.1\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Host: example.com\r\n" "Origin: http://example.com\r\n" - "WebSocket-Protocol: sample\r\n" - "\r\n"; + "Sec-WebSocket-Protocol: sample\r\n" + "Sec-WebSocket-Key1: 388P O503D&ul7 {K%gX( %7 15\r\n" + "Sec-WebSocket-Key2: 1 N ?|k UT0or 3o 4 I97N 5-S3O 31\r\n" + "\r\n" + "\x47\x30\x22\x2D\x5A\x3F\x47\x58"; scoped_ptr<WebSocketHandshake> handshake( new WebSocketHandshake(GURL("ws://example.com/demo"), "http://example.com", "ws://example.com/demo", "sample")); + SetUpParameter(handshake.get(), 777007543U, 114997259U, + "388P O503D&ul7 {K%gX( %7 15", + "1 N ?|k UT0or 3o 4 I97N 5-S3O 31", + std::string("\x47\x30\x22\x2D\x5A\x3F\x47\x58", 8)); EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode()); - EXPECT_EQ(kExpectedClientHandshakeMessage, - handshake->CreateClientHandshakeMessage()); + ExpectHandshakeMessageEquals( + kExpectedClientHandshakeMessage, + handshake->CreateClientHandshakeMessage()); - const char kResponse[] ="HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + const char kResponse[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" - "WebSocket-Origin: http://example.com\r\n" - "WebSocket-Location: ws://example.com/demo\r\n" - "WebSocket-Protocol: sample\r\n" + "Sec-WebSocket-Origin: http://example.com\r\n" + "Sec-WebSocket-Location: ws://example.com/demo\r\n" + "Sec-WebSocket-Protocol: sample\r\n" "\r\n" + "\x30\x73\x74\x33\x52\x6C\x26\x71\x2D\x32\x5A\x55\x5E\x77\x65\x75" "\0Hello\xff"; - int handshake_length = strlen(kResponse); + int handshake_length = strlen(kResponse); // key3 doesn't contain \0. EXPECT_EQ(handshake_length, handshake->ReadServerHandshake( kResponse, sizeof(kResponse) - 1)); // -1 for terminating \0 EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode()); } -TEST(WebSocketHandshakeTest, is_secure_false) { +TEST_F(WebSocketHandshakeTest, is_secure_false) { scoped_ptr<WebSocketHandshake> handshake( new WebSocketHandshake(GURL("ws://example.com/demo"), "http://example.com", @@ -115,7 +232,7 @@ TEST(WebSocketHandshakeTest, is_secure_false) { EXPECT_FALSE(handshake->is_secure()); } -TEST(WebSocketHandshakeTest, is_secure_true) { +TEST_F(WebSocketHandshakeTest, is_secure_true) { // wss:// is secure. scoped_ptr<WebSocketHandshake> handshake( new WebSocketHandshake(GURL("wss://example.com/demo"), @@ -125,80 +242,59 @@ TEST(WebSocketHandshakeTest, is_secure_true) { EXPECT_TRUE(handshake->is_secure()); } -TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_Simple) { - scoped_ptr<WebSocketHandshake> handshake( - new WebSocketHandshake(GURL("ws://example.com/demo"), - "http://example.com", - "ws://example.com/demo", - "sample")); - EXPECT_EQ("GET /demo HTTP/1.1\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Host: example.com\r\n" - "Origin: http://example.com\r\n" - "WebSocket-Protocol: sample\r\n" - "\r\n", - handshake->CreateClientHandshakeMessage()); -} - -TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_PathAndQuery) { +TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_ResourceName) { scoped_ptr<WebSocketHandshake> handshake( new WebSocketHandshake(GURL("ws://example.com/Test?q=xxx&p=%20"), "http://example.com", "ws://example.com/demo", "sample")); // Path and query should be preserved as-is. - EXPECT_THAT(handshake->CreateClientHandshakeMessage(), - testing::HasSubstr("GET /Test?q=xxx&p=%20 HTTP/1.1\r\n")); + EXPECT_EQ("/Test?q=xxx&p=%20", GetResourceName(handshake.get())); } -TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_Host) { +TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_Host) { scoped_ptr<WebSocketHandshake> handshake( new WebSocketHandshake(GURL("ws://Example.Com/demo"), "http://Example.Com", "ws://Example.Com/demo", "sample")); // Host should be lowercased - EXPECT_THAT(handshake->CreateClientHandshakeMessage(), - testing::HasSubstr("Host: example.com\r\n")); - EXPECT_THAT(handshake->CreateClientHandshakeMessage(), - testing::HasSubstr("Origin: http://example.com\r\n")); + EXPECT_EQ("example.com", GetHostFieldValue(handshake.get())); + EXPECT_EQ("http://example.com", GetOriginFieldValue(handshake.get())); } -TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort80) { +TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort80) { scoped_ptr<WebSocketHandshake> handshake( new WebSocketHandshake(GURL("ws://example.com:80/demo"), "http://example.com", "ws://example.com/demo", "sample")); // :80 should be trimmed as it's the default port for ws://. - EXPECT_THAT(handshake->CreateClientHandshakeMessage(), - testing::HasSubstr("Host: example.com\r\n")); + EXPECT_EQ("example.com", GetHostFieldValue(handshake.get())); } -TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort443) { +TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort443) { scoped_ptr<WebSocketHandshake> handshake( new WebSocketHandshake(GURL("wss://example.com:443/demo"), "http://example.com", "wss://example.com/demo", "sample")); // :443 should be trimmed as it's the default port for wss://. - EXPECT_THAT(handshake->CreateClientHandshakeMessage(), - testing::HasSubstr("Host: example.com\r\n")); + EXPECT_EQ("example.com", GetHostFieldValue(handshake.get())); } -TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_NonDefaultPortForWs) { +TEST_F(WebSocketHandshakeTest, + CreateClientHandshakeMessage_NonDefaultPortForWs) { scoped_ptr<WebSocketHandshake> handshake( new WebSocketHandshake(GURL("ws://example.com:8080/demo"), "http://example.com", "wss://example.com/demo", "sample")); // :8080 should be preserved as it's not the default port for ws://. - EXPECT_THAT(handshake->CreateClientHandshakeMessage(), - testing::HasSubstr("Host: example.com:8080\r\n")); + EXPECT_EQ("example.com:8080", GetHostFieldValue(handshake.get())); } -TEST(WebSocketHandshakeTest, +TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_NonDefaultPortForWss) { scoped_ptr<WebSocketHandshake> handshake( new WebSocketHandshake(GURL("wss://example.com:4443/demo"), @@ -206,30 +302,27 @@ TEST(WebSocketHandshakeTest, "wss://example.com/demo", "sample")); // :4443 should be preserved as it's not the default port for wss://. - EXPECT_THAT(handshake->CreateClientHandshakeMessage(), - testing::HasSubstr("Host: example.com:4443\r\n")); + EXPECT_EQ("example.com:4443", GetHostFieldValue(handshake.get())); } -TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_WsBut443) { +TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_WsBut443) { scoped_ptr<WebSocketHandshake> handshake( new WebSocketHandshake(GURL("ws://example.com:443/demo"), "http://example.com", "ws://example.com/demo", "sample")); // :443 should be preserved as it's not the default port for ws://. - EXPECT_THAT(handshake->CreateClientHandshakeMessage(), - testing::HasSubstr("Host: example.com:443\r\n")); + EXPECT_EQ("example.com:443", GetHostFieldValue(handshake.get())); } -TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_WssBut80) { +TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_WssBut80) { scoped_ptr<WebSocketHandshake> handshake( new WebSocketHandshake(GURL("wss://example.com:80/demo"), "http://example.com", "wss://example.com/demo", "sample")); // :80 should be preserved as it's not the default port for wss://. - EXPECT_THAT(handshake->CreateClientHandshakeMessage(), - testing::HasSubstr("Host: example.com:80\r\n")); + EXPECT_EQ("example.com:80", GetHostFieldValue(handshake.get())); } } // namespace net diff --git a/net/websockets/websocket_unittest.cc b/net/websockets/websocket_unittest.cc index e18f712..5f6e8e9 100644 --- a/net/websockets/websocket_unittest.cc +++ b/net/websockets/websocket_unittest.cc @@ -157,6 +157,7 @@ TEST_F(WebSocketTest, Connect) { "sample", "http://example.com", "ws://example.com/demo", + WebSocket::DRAFT75, new TestURLRequestContext())); request->SetHostResolver(new MockHostResolver()); request->SetClientSocketFactory(&mock_socket_factory); @@ -218,6 +219,7 @@ TEST_F(WebSocketTest, ServerSentData) { "sample", "http://example.com", "ws://example.com/demo", + WebSocket::DRAFT75, new TestURLRequestContext())); request->SetHostResolver(new MockHostResolver()); request->SetClientSocketFactory(&mock_socket_factory); @@ -252,6 +254,7 @@ TEST_F(WebSocketTest, ProcessFrameDataForLengthCalculation) { "sample", "http://example.com", "ws://example.com/demo", + WebSocket::DRAFT75, new TestURLRequestContext())); TestCompletionCallback callback; scoped_ptr<WebSocketEventRecorder> delegate( @@ -287,6 +290,7 @@ TEST_F(WebSocketTest, ProcessFrameDataForUnterminatedString) { "sample", "http://example.com", "ws://example.com/demo", + WebSocket::DRAFT75, new TestURLRequestContext())); TestCompletionCallback callback; scoped_ptr<WebSocketEventRecorder> delegate( |