summaryrefslogtreecommitdiffstats
path: root/net/websockets
diff options
context:
space:
mode:
authorukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2010-03-26 07:35:55 +0000
committerukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2010-03-26 07:35:55 +0000
commit511d0a0a31a54e0cc0f15cb1b977dc9f9b20f0d3 (patch)
tree8a4b3eb672a1a3f4efd172de181892cc72ec7275 /net/websockets
parent57771511d6e0f0bfea8aafa6bd6aca1294b08f87 (diff)
downloadchromium_src-511d0a0a31a54e0cc0f15cb1b977dc9f9b20f0d3.zip
chromium_src-511d0a0a31a54e0cc0f15cb1b977dc9f9b20f0d3.tar.gz
chromium_src-511d0a0a31a54e0cc0f15cb1b977dc9f9b20f0d3.tar.bz2
Implement new websocket handshake based on draft-hixie-thewebsocketprotocol-76
BUG=none TEST=net_unittests passes Review URL: http://codereview.chromium.org/1108002 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@42736 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net/websockets')
-rw-r--r--net/websockets/websocket.cc18
-rw-r--r--net/websockets/websocket.h8
-rw-r--r--net/websockets/websocket_handshake.cc318
-rw-r--r--net/websockets/websocket_handshake.h79
-rw-r--r--net/websockets/websocket_handshake_draft75.cc156
-rw-r--r--net/websockets/websocket_handshake_draft75.h63
-rw-r--r--net/websockets/websocket_handshake_draft75_unittest.cc217
-rw-r--r--net/websockets/websocket_handshake_unittest.cc253
-rw-r--r--net/websockets/websocket_unittest.cc4
9 files changed, 900 insertions, 216 deletions
diff --git a/net/websockets/websocket.cc b/net/websockets/websocket.cc
index dc4f568..1cc3cf9 100644
--- a/net/websockets/websocket.cc
+++ b/net/websockets/websocket.cc
@@ -9,6 +9,7 @@
#include "base/message_loop.h"
#include "net/websockets/websocket_handshake.h"
+#include "net/websockets/websocket_handshake_draft75.h"
namespace net {
@@ -101,9 +102,20 @@ void WebSocket::OnConnected(SocketStream* socket_stream,
DCHECK(!current_write_buf_);
DCHECK(!handshake_.get());
- handshake_.reset(new WebSocketHandshake(
- request_->url(), request_->origin(), request_->location(),
- request_->protocol()));
+ switch (request_->version()) {
+ case DEFAULT_VERSION:
+ handshake_.reset(new WebSocketHandshake(
+ request_->url(), request_->origin(), request_->location(),
+ request_->protocol()));
+ break;
+ case DRAFT75:
+ handshake_.reset(new WebSocketHandshakeDraft75(
+ request_->url(), request_->origin(), request_->location(),
+ request_->protocol()));
+ break;
+ default:
+ NOTREACHED() << "Unexpected protocol version:" << request_->version();
+ }
const std::string msg = handshake_->CreateClientHandshakeMessage();
IOBufferWithSize* buf = new IOBufferWithSize(msg.size());
diff --git a/net/websockets/websocket.h b/net/websockets/websocket.h
index 5e58391..427af56 100644
--- a/net/websockets/websocket.h
+++ b/net/websockets/websocket.h
@@ -60,15 +60,21 @@ class WebSocket : public base::RefCountedThreadSafe<WebSocket>,
OPEN = 1,
CLOSED = 2,
};
+ enum ProtocolVersion {
+ DEFAULT_VERSION = 0,
+ DRAFT75 = 1,
+ };
class Request {
public:
Request(const GURL& url, const std::string protocol,
const std::string origin, const std::string location,
+ ProtocolVersion version,
URLRequestContext* context)
: url_(url),
protocol_(protocol),
origin_(origin),
location_(location),
+ version_(version),
context_(context),
host_resolver_(NULL),
client_socket_factory_(NULL) {}
@@ -78,6 +84,7 @@ class WebSocket : public base::RefCountedThreadSafe<WebSocket>,
const std::string& protocol() const { return protocol_; }
const std::string& origin() const { return origin_; }
const std::string& location() const { return location_; }
+ ProtocolVersion version() const { return version_; }
URLRequestContext* context() const { return context_; }
// Sets an alternative HostResolver. For testing purposes only.
@@ -100,6 +107,7 @@ class WebSocket : public base::RefCountedThreadSafe<WebSocket>,
std::string protocol_;
std::string origin_;
std::string location_;
+ ProtocolVersion version_;
scoped_refptr<URLRequestContext> context_;
scoped_refptr<HostResolver> host_resolver_;
diff --git a/net/websockets/websocket_handshake.cc b/net/websockets/websocket_handshake.cc
index c17ea34..6f660bc 100644
--- a/net/websockets/websocket_handshake.cc
+++ b/net/websockets/websocket_handshake.cc
@@ -4,6 +4,11 @@
#include "net/websockets/websocket_handshake.h"
+#include <algorithm>
+#include <vector>
+
+#include "base/md5.h"
+#include "base/rand_util.h"
#include "base/ref_counted.h"
#include "base/string_util.h"
#include "net/http/http_response_headers.h"
@@ -14,19 +19,6 @@ namespace net {
const int WebSocketHandshake::kWebSocketPort = 80;
const int WebSocketHandshake::kSecureWebSocketPort = 443;
-const char WebSocketHandshake::kServerHandshakeHeader[] =
- "HTTP/1.1 101 Web Socket Protocol Handshake\r\n";
-const size_t WebSocketHandshake::kServerHandshakeHeaderLength =
- sizeof(kServerHandshakeHeader) - 1;
-
-const char WebSocketHandshake::kUpgradeHeader[] = "Upgrade: WebSocket\r\n";
-const size_t WebSocketHandshake::kUpgradeHeaderLength =
- sizeof(kUpgradeHeader) - 1;
-
-const char WebSocketHandshake::kConnectionHeader[] = "Connection: Upgrade\r\n";
-const size_t WebSocketHandshake::kConnectionHeaderLength =
- sizeof(kConnectionHeader) - 1;
-
WebSocketHandshake::WebSocketHandshake(
const GURL& url,
const std::string& origin,
@@ -46,19 +38,94 @@ bool WebSocketHandshake::is_secure() const {
return url_.SchemeIs("wss");
}
-std::string WebSocketHandshake::CreateClientHandshakeMessage() const {
+std::string WebSocketHandshake::CreateClientHandshakeMessage() {
+ if (!parameter_.get()) {
+ parameter_.reset(new Parameter);
+ parameter_->GenerateKeys();
+ }
std::string msg;
+
+ // WebSocket protocol 4.1 Opening handshake.
+
msg = "GET ";
- msg += url_.path();
+ msg += GetResourceName();
+ msg += " HTTP/1.1\r\n";
+
+ std::vector<std::string> fields;
+
+ fields.push_back("Upgrade: WebSocket");
+ fields.push_back("Connection: Upgrade");
+
+ fields.push_back("Host: " + GetHostFieldValue());
+
+ fields.push_back("Origin: " + GetOriginFieldValue());
+
+ if (!protocol_.empty())
+ fields.push_back("Sec-WebSocket-Protocol: " + protocol_);
+
+ // TODO(ukai): Add cookie if necessary.
+
+ fields.push_back("Sec-WebSocket-Key1: " + parameter_->GetSecWebSocketKey1());
+ fields.push_back("Sec-WebSocket-Key2: " + parameter_->GetSecWebSocketKey2());
+
+ std::random_shuffle(fields.begin(), fields.end());
+
+ for (size_t i = 0; i < fields.size(); i++) {
+ msg += fields[i] + "\r\n";
+ }
+ msg += "\r\n";
+
+ msg.append(parameter_->GetKey3());
+ return msg;
+}
+
+int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) {
+ mode_ = MODE_INCOMPLETE;
+ int eoh = HttpUtil::LocateEndOfHeaders(data, len);
+ if (eoh < 0)
+ return -1;
+
+ scoped_refptr<HttpResponseHeaders> headers(
+ new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh)));
+
+ if (headers->response_code() != 101) {
+ mode_ = MODE_FAILED;
+ DLOG(INFO) << "Bad response code: " << headers->response_code();
+ return eoh;
+ }
+ mode_ = MODE_NORMAL;
+ if (!ProcessHeaders(*headers) || !CheckResponseHeaders()) {
+ DLOG(INFO) << "Process Headers failed: "
+ << std::string(data, eoh);
+ mode_ = MODE_FAILED;
+ return eoh;
+ }
+ if (len < static_cast<size_t>(eoh + Parameter::kExpectedResponseSize)) {
+ mode_ = MODE_INCOMPLETE;
+ return -1;
+ }
+ uint8 expected[Parameter::kExpectedResponseSize];
+ parameter_->GetExpectedResponse(expected);
+ if (memcmp(&data[eoh], expected, Parameter::kExpectedResponseSize)) {
+ mode_ = MODE_FAILED;
+ return eoh + Parameter::kExpectedResponseSize;
+ }
+ mode_ = MODE_CONNECTED;
+ return eoh + Parameter::kExpectedResponseSize;
+}
+
+std::string WebSocketHandshake::GetResourceName() const {
+ std::string resource_name = url_.path();
if (url_.has_query()) {
- msg += "?";
- msg += url_.query();
+ resource_name += "?";
+ resource_name += url_.query();
}
- msg += " HTTP/1.1\r\n";
- msg += kUpgradeHeader;
- msg += kConnectionHeader;
- msg += "Host: ";
- msg += StringToLowerASCII(url_.host());
+ return resource_name;
+}
+
+std::string WebSocketHandshake::GetHostFieldValue() const {
+ // url_.host() is expected to be encoded in punnycode here.
+ std::string host = StringToLowerASCII(url_.host());
if (url_.has_port()) {
bool secure = is_secure();
int port = url_.EffectiveIntPort();
@@ -66,12 +133,14 @@ std::string WebSocketHandshake::CreateClientHandshakeMessage() const {
port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) ||
(secure &&
port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) {
- msg += ":";
- msg += IntToString(port);
+ host += ":";
+ host += IntToString(port);
}
}
- msg += "\r\n";
- msg += "Origin: ";
+ return host;
+}
+
+std::string WebSocketHandshake::GetOriginFieldValue() const {
// It's OK to lowercase the origin as the Origin header does not contain
// the path or query portions, as per
// http://tools.ietf.org/html/draft-abarth-origin-00.
@@ -79,91 +148,13 @@ std::string WebSocketHandshake::CreateClientHandshakeMessage() const {
// TODO(satorux): Should we trim the port portion here if it's 80 for
// http:// or 443 for https:// ? Or can we assume it's done by the
// client of the library?
- msg += StringToLowerASCII(origin_);
- msg += "\r\n";
- if (!protocol_.empty()) {
- msg += "WebSocket-Protocol: ";
- msg += protocol_;
- msg += "\r\n";
- }
- // TODO(ukai): Add cookie if necessary.
- msg += "\r\n";
- return msg;
+ return StringToLowerASCII(origin_);
}
-int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) {
- mode_ = MODE_INCOMPLETE;
- if (len < kServerHandshakeHeaderLength) {
- return -1;
- }
- if (!memcmp(data, kServerHandshakeHeader, kServerHandshakeHeaderLength)) {
- mode_ = MODE_NORMAL;
- } else {
- int eoh = HttpUtil::LocateEndOfHeaders(data, len);
- if (eoh < 0)
- return -1;
- return eoh;
- }
- const char* p = data + kServerHandshakeHeaderLength;
- const char* end = data + len + 1;
-
- if (mode_ == MODE_NORMAL) {
- size_t header_size = end - p;
- if (header_size < kUpgradeHeaderLength)
- return -1;
- if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) {
- mode_ = MODE_FAILED;
- DLOG(INFO) << "Bad Upgrade Header "
- << std::string(p, kUpgradeHeaderLength);
- return p - data;
- }
- p += kUpgradeHeaderLength;
- header_size = end - p;
- if (header_size < kConnectionHeaderLength)
- return -1;
- if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) {
- mode_ = MODE_FAILED;
- DLOG(INFO) << "Bad Connection Header "
- << std::string(p, kConnectionHeaderLength);
- return p - data;
- }
- p += kConnectionHeaderLength;
- }
-
- int eoh = HttpUtil::LocateEndOfHeaders(data, len);
- if (eoh == -1)
- return eoh;
-
- scoped_refptr<HttpResponseHeaders> headers(
- new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh)));
- if (!ProcessHeaders(*headers)) {
- DLOG(INFO) << "Process Headers failed: "
- << std::string(data, eoh);
- mode_ = MODE_FAILED;
- }
- switch (mode_) {
- case MODE_NORMAL:
- if (CheckResponseHeaders()) {
- mode_ = MODE_CONNECTED;
- } else {
- mode_ = MODE_FAILED;
- }
- break;
- default:
- mode_ = MODE_FAILED;
- break;
- }
- return eoh;
-}
-
-// Gets the value of the specified header.
-// It assures only one header of |name| in |headers|.
-// Returns true iff single header of |name| is found in |headers|
-// and |value| is filled with the value.
-// Returns false otherwise.
-static bool GetSingleHeader(const HttpResponseHeaders& headers,
- const std::string& name,
- std::string* value) {
+/* static */
+bool WebSocketHandshake::GetSingleHeader(const HttpResponseHeaders& headers,
+ const std::string& name,
+ std::string* value) {
std::string first_value;
void* iter = NULL;
if (!headers.EnumerateHeader(&iter, name, &first_value))
@@ -179,16 +170,25 @@ static bool GetSingleHeader(const HttpResponseHeaders& headers,
}
bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) {
- if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_))
+ std::string value;
+ if (!GetSingleHeader(headers, "upgrade", &value) ||
+ value != "WebSocket")
+ return false;
+
+ if (!GetSingleHeader(headers, "connection", &value) ||
+ !LowerCaseEqualsASCII(value, "upgrade"))
+ return false;
+
+ if (!GetSingleHeader(headers, "sec-websocket-origin", &ws_origin_))
return false;
- if (!GetSingleHeader(headers, "websocket-location", &ws_location_))
+ if (!GetSingleHeader(headers, "sec-websocket-location", &ws_location_))
return false;
// If |protocol_| is not specified by client, we don't care if there's
// protocol field or not as specified in the spec.
if (!protocol_.empty()
- && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_))
+ && !GetSingleHeader(headers, "sec-websocket-protocol", &ws_protocol_))
return false;
return true;
}
@@ -204,6 +204,100 @@ bool WebSocketHandshake::CheckResponseHeaders() const {
return true;
}
+namespace {
+
+// unsigned int version of base::RandInt().
+// we can't use base::RandInt(), because max would be negative if it is
+// represented as int, so DCHECK(min <= max) fails.
+uint32 RandUint32(uint32 min, uint32 max) {
+ DCHECK(min <= max);
+
+ uint64 range = static_cast<int64>(max) - min + 1;
+ uint64 number = base::RandUint64();
+ // TODO(ukai): fix to be uniform.
+ // the distribution of the result of modulo will be biased.
+ uint32 result = min + static_cast<uint32>(number % range);
+ DCHECK(result >= min && result <= max);
+ return result;
+}
+
+}
+
+uint32 (*WebSocketHandshake::Parameter::rand_)(uint32 min, uint32 max) =
+ RandUint32;
+uint8 randomCharacterInSecWebSocketKey[0x2F - 0x20 + 0x7E - 0x39];
+WebSocketHandshake::Parameter::Parameter()
+ : number_1_(0), number_2_(0) {
+ if (randomCharacterInSecWebSocketKey[0] == '\0') {
+ int i = 0;
+ for (int ch = 0x21; ch <= 0x2F; ch++, i++)
+ randomCharacterInSecWebSocketKey[i] = ch;
+ for (int ch = 0x3A; ch <= 0x7E; ch++, i++)
+ randomCharacterInSecWebSocketKey[i] = ch;
+ }
+}
+
+WebSocketHandshake::Parameter::~Parameter() {}
+
+void WebSocketHandshake::Parameter::GenerateKeys() {
+ GenerateSecWebSocketKey(&number_1_, &key_1_);
+ GenerateSecWebSocketKey(&number_2_, &key_2_);
+ GenerateKey3();
+}
+
+static void SetChallengeNumber(uint8* buf, uint32 number) {
+ uint8* p = buf + 3;
+ for (int i = 0; i < 4; i++) {
+ *p = (uint8)(number & 0xFF);
+ --p;
+ number >>= 8;
+ }
+}
+
+void WebSocketHandshake::Parameter::GetExpectedResponse(uint8 *expected) const {
+ uint8 challenge[kExpectedResponseSize];
+ SetChallengeNumber(&challenge[0], number_1_);
+ SetChallengeNumber(&challenge[4], number_2_);
+ memcpy(&challenge[8], key_3_.data(), kKey3Size);
+ MD5Digest digest;
+ MD5Sum(challenge, kExpectedResponseSize, &digest);
+ memcpy(expected, digest.a, kExpectedResponseSize);
+}
+
+/* static */
+void WebSocketHandshake::Parameter::SetRandomNumberGenerator(
+ uint32 (*rand)(uint32 min, uint32 max)) {
+ rand_ = rand;
+}
+
+void WebSocketHandshake::Parameter::GenerateSecWebSocketKey(
+ uint32* number, std::string* key) {
+ uint32 space = rand_(1, 12);
+ uint32 max = 4294967295U / space;
+ *number = rand_(0, max);
+ uint32 product = *number * space;
+
+ std::string s = StringPrintf("%010u", product);
+ for (uint32 i = 0; i < space; i++) {
+ int pos = rand_(1, s.length() - 1);
+ s = s.substr(0, pos) + " " + s.substr(pos);
+ }
+ int n = rand_(1, 12);
+ for (int i = 0; i < n; i++) {
+ int pos = rand_(0, s.length());
+ int chpos = rand_(0, sizeof(randomCharacterInSecWebSocketKey) - 1);
+ s = s.substr(0, pos).append(1, randomCharacterInSecWebSocketKey[chpos]) +
+ s.substr(pos);
+ }
+ *key = s;
+}
+
+void WebSocketHandshake::Parameter::GenerateKey3() {
+ key_3_.clear();
+ for (int i = 0; i < 8; i++) {
+ key_3_.append(1, rand_(0, 255));
+ }
+}
} // namespace net
diff --git a/net/websockets/websocket_handshake.h b/net/websockets/websocket_handshake.h
index 1e94eff..3f64b8b 100644
--- a/net/websockets/websocket_handshake.h
+++ b/net/websockets/websocket_handshake.h
@@ -8,6 +8,7 @@
#include <string>
#include "base/basictypes.h"
+#include "base/scoped_ptr.h"
#include "googleurl/src/gurl.h"
namespace net {
@@ -18,12 +19,6 @@ class WebSocketHandshake {
public:
static const int kWebSocketPort;
static const int kSecureWebSocketPort;
- static const char kServerHandshakeHeader[];
- static const size_t kServerHandshakeHeaderLength;
- static const char kUpgradeHeader[];
- static const size_t kUpgradeHeaderLength;
- static const char kConnectionHeader[];
- static const size_t kConnectionHeaderLength;
enum Mode {
MODE_INCOMPLETE, MODE_NORMAL, MODE_FAILED, MODE_CONNECTED
@@ -32,32 +27,33 @@ class WebSocketHandshake {
const std::string& origin,
const std::string& location,
const std::string& protocol);
- ~WebSocketHandshake();
+ virtual ~WebSocketHandshake();
bool is_secure() const;
// Creates the client handshake message from |this|.
- std::string CreateClientHandshakeMessage() const;
+ virtual std::string CreateClientHandshakeMessage();
// Reads server handshake message in |len| of |data|, updates |mode_| and
// returns number of bytes of the server handshake message.
// Once connection is established, |mode_| will be MODE_CONNECTED.
// If connection establishment failed, |mode_| will be MODE_FAILED.
// Returns negative if the server handshake message is incomplete.
- int ReadServerHandshake(const char* data, size_t len);
+ virtual int ReadServerHandshake(const char* data, size_t len);
Mode mode() const { return mode_; }
- private:
- // Processes server handshake message, parsed as |headers|, and updates
- // |ws_origin_|, |ws_location_| and |ws_protocol_|.
- // Returns true if it's ok.
- // Returns false otherwise (e.g. duplicate WebSocket-Origin: header, etc.)
- bool ProcessHeaders(const HttpResponseHeaders& headers);
-
- // Checks |ws_origin_|, |ws_location_| and |ws_protocol_| are valid
- // against |origin_|, |location_| and |protocol_|.
- // Returns true if it's ok.
- // Returns false otherwise (e.g. origin mismatch, etc.)
- bool CheckResponseHeaders() const;
+ protected:
+ std::string GetResourceName() const;
+ std::string GetHostFieldValue() const;
+ std::string GetOriginFieldValue() const;
+
+ // Gets the value of the specified header.
+ // It assures only one header of |name| in |headers|.
+ // Returns true iff single header of |name| is found in |headers|
+ // and |value| is filled with the value.
+ // Returns false otherwise.
+ static bool GetSingleHeader(const HttpResponseHeaders& headers,
+ const std::string& name,
+ std::string* value);
GURL url_;
// Handshake messages that the client is going to send out.
@@ -72,6 +68,47 @@ class WebSocketHandshake {
std::string ws_location_;
std::string ws_protocol_;
+ private:
+ friend class WebSocketHandshakeTest;
+
+ class Parameter {
+ public:
+ static const int kKey3Size = 8;
+ static const int kExpectedResponseSize = 16;
+ Parameter();
+ ~Parameter();
+
+ void GenerateKeys();
+ const std::string& GetSecWebSocketKey1() const { return key_1_; }
+ const std::string& GetSecWebSocketKey2() const { return key_2_; }
+ const std::string& GetKey3() const { return key_3_; }
+
+ void GetExpectedResponse(uint8* expected) const;
+
+ private:
+ friend class WebSocketHandshakeTest;
+
+ // Set random number generator. |rand| should return a random number
+ // between min and max (inclusive).
+ static void SetRandomNumberGenerator(
+ uint32 (*rand)(uint32 min, uint32 max));
+ void GenerateSecWebSocketKey(uint32* number, std::string* key);
+ void GenerateKey3();
+
+ uint32 number_1_;
+ uint32 number_2_;
+ std::string key_1_;
+ std::string key_2_;
+ std::string key_3_;
+
+ static uint32 (*rand_)(uint32 min, uint32 max);
+ };
+
+ virtual bool ProcessHeaders(const HttpResponseHeaders& headers);
+ virtual bool CheckResponseHeaders() const;
+
+ scoped_ptr<Parameter> parameter_;
+
DISALLOW_COPY_AND_ASSIGN(WebSocketHandshake);
};
diff --git a/net/websockets/websocket_handshake_draft75.cc b/net/websockets/websocket_handshake_draft75.cc
new file mode 100644
index 0000000..78805fb
--- /dev/null
+++ b/net/websockets/websocket_handshake_draft75.cc
@@ -0,0 +1,156 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/websockets/websocket_handshake_draft75.h"
+
+#include "base/ref_counted.h"
+#include "base/string_util.h"
+#include "net/http/http_response_headers.h"
+#include "net/http/http_util.h"
+
+namespace net {
+
+const char WebSocketHandshakeDraft75::kServerHandshakeHeader[] =
+ "HTTP/1.1 101 Web Socket Protocol Handshake\r\n";
+const size_t WebSocketHandshakeDraft75::kServerHandshakeHeaderLength =
+ sizeof(kServerHandshakeHeader) - 1;
+
+const char WebSocketHandshakeDraft75::kUpgradeHeader[] =
+ "Upgrade: WebSocket\r\n";
+const size_t WebSocketHandshakeDraft75::kUpgradeHeaderLength =
+ sizeof(kUpgradeHeader) - 1;
+
+const char WebSocketHandshakeDraft75::kConnectionHeader[] =
+ "Connection: Upgrade\r\n";
+const size_t WebSocketHandshakeDraft75::kConnectionHeaderLength =
+ sizeof(kConnectionHeader) - 1;
+
+WebSocketHandshakeDraft75::WebSocketHandshakeDraft75(
+ const GURL& url,
+ const std::string& origin,
+ const std::string& location,
+ const std::string& protocol)
+ : WebSocketHandshake(url, origin, location, protocol) {
+}
+
+WebSocketHandshakeDraft75::~WebSocketHandshakeDraft75() {
+}
+
+std::string WebSocketHandshakeDraft75::CreateClientHandshakeMessage() {
+ std::string msg;
+ msg = "GET ";
+ msg += GetResourceName();
+ msg += " HTTP/1.1\r\n";
+ msg += kUpgradeHeader;
+ msg += kConnectionHeader;
+ msg += "Host: ";
+ msg += GetHostFieldValue();
+ msg += "\r\n";
+ msg += "Origin: ";
+ msg += GetOriginFieldValue();
+ msg += "\r\n";
+ if (!protocol_.empty()) {
+ msg += "WebSocket-Protocol: ";
+ msg += protocol_;
+ msg += "\r\n";
+ }
+ // TODO(ukai): Add cookie if necessary.
+ msg += "\r\n";
+ return msg;
+}
+
+int WebSocketHandshakeDraft75::ReadServerHandshake(
+ const char* data, size_t len) {
+ mode_ = MODE_INCOMPLETE;
+ if (len < kServerHandshakeHeaderLength) {
+ return -1;
+ }
+ if (!memcmp(data, kServerHandshakeHeader, kServerHandshakeHeaderLength)) {
+ mode_ = MODE_NORMAL;
+ } else {
+ int eoh = HttpUtil::LocateEndOfHeaders(data, len);
+ if (eoh < 0)
+ return -1;
+ return eoh;
+ }
+ const char* p = data + kServerHandshakeHeaderLength;
+ const char* end = data + len;
+
+ if (mode_ == MODE_NORMAL) {
+ size_t header_size = end - p;
+ if (header_size < kUpgradeHeaderLength)
+ return -1;
+ if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) {
+ mode_ = MODE_FAILED;
+ DLOG(INFO) << "Bad Upgrade Header "
+ << std::string(p, kUpgradeHeaderLength);
+ return p - data;
+ }
+ p += kUpgradeHeaderLength;
+ header_size = end - p;
+ if (header_size < kConnectionHeaderLength)
+ return -1;
+ if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) {
+ mode_ = MODE_FAILED;
+ DLOG(INFO) << "Bad Connection Header "
+ << std::string(p, kConnectionHeaderLength);
+ return p - data;
+ }
+ p += kConnectionHeaderLength;
+ }
+
+ int eoh = HttpUtil::LocateEndOfHeaders(data, len);
+ if (eoh == -1)
+ return eoh;
+
+ scoped_refptr<HttpResponseHeaders> headers(
+ new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh)));
+ if (!ProcessHeaders(*headers)) {
+ DLOG(INFO) << "Process Headers failed: "
+ << std::string(data, eoh);
+ mode_ = MODE_FAILED;
+ }
+ switch (mode_) {
+ case MODE_NORMAL:
+ if (CheckResponseHeaders()) {
+ mode_ = MODE_CONNECTED;
+ } else {
+ mode_ = MODE_FAILED;
+ }
+ break;
+ default:
+ mode_ = MODE_FAILED;
+ break;
+ }
+ return eoh;
+}
+
+bool WebSocketHandshakeDraft75::ProcessHeaders(
+ const HttpResponseHeaders& headers) {
+ if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_))
+ return false;
+
+ if (!GetSingleHeader(headers, "websocket-location", &ws_location_))
+ return false;
+
+ // If |protocol_| is not specified by client, we don't care if there's
+ // protocol field or not as specified in the spec.
+ if (!protocol_.empty()
+ && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_))
+ return false;
+ return true;
+}
+
+bool WebSocketHandshakeDraft75::CheckResponseHeaders() const {
+ DCHECK(mode_ == MODE_NORMAL);
+ if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str()))
+ return false;
+ if (location_ != ws_location_)
+ return false;
+ if (!protocol_.empty() && protocol_ != ws_protocol_)
+ return false;
+ return true;
+}
+
+} // namespace net
diff --git a/net/websockets/websocket_handshake_draft75.h b/net/websockets/websocket_handshake_draft75.h
new file mode 100644
index 0000000..6cc0506
--- /dev/null
+++ b/net/websockets/websocket_handshake_draft75.h
@@ -0,0 +1,63 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_DRAFT75_H_
+#define NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_DRAFT75_H_
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "googleurl/src/gurl.h"
+#include "net/websockets/websocket_handshake.h"
+
+namespace net {
+
+class HttpResponseHeaders;
+
+class WebSocketHandshakeDraft75 : public WebSocketHandshake {
+ public:
+ static const int kWebSocketPort;
+ static const int kSecureWebSocketPort;
+ static const char kServerHandshakeHeader[];
+ static const size_t kServerHandshakeHeaderLength;
+ static const char kUpgradeHeader[];
+ static const size_t kUpgradeHeaderLength;
+ static const char kConnectionHeader[];
+ static const size_t kConnectionHeaderLength;
+
+ WebSocketHandshakeDraft75(const GURL& url,
+ const std::string& origin,
+ const std::string& location,
+ const std::string& protocol);
+ virtual ~WebSocketHandshakeDraft75();
+
+ // Creates the client handshake message from |this|.
+ virtual std::string CreateClientHandshakeMessage();
+
+ // Reads server handshake message in |len| of |data|, updates |mode_| and
+ // returns number of bytes of the server handshake message.
+ // Once connection is established, |mode_| will be MODE_CONNECTED.
+ // If connection establishment failed, |mode_| will be MODE_FAILED.
+ // Returns negative if the server handshake message is incomplete.
+ virtual int ReadServerHandshake(const char* data, size_t len);
+
+ private:
+ // Processes server handshake message, parsed as |headers|, and updates
+ // |ws_origin_|, |ws_location_| and |ws_protocol_|.
+ // Returns true if it's ok.
+ // Returns false otherwise (e.g. duplicate WebSocket-Origin: header, etc.)
+ virtual bool ProcessHeaders(const HttpResponseHeaders& headers);
+
+ // Checks |ws_origin_|, |ws_location_| and |ws_protocol_| are valid
+ // against |origin_|, |location_| and |protocol_|.
+ // Returns true if it's ok.
+ // Returns false otherwise (e.g. origin mismatch, etc.)
+ virtual bool CheckResponseHeaders() const;
+
+ DISALLOW_COPY_AND_ASSIGN(WebSocketHandshakeDraft75);
+};
+
+} // namespace net
+
+#endif // NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_DRAFT75_H_
diff --git a/net/websockets/websocket_handshake_draft75_unittest.cc b/net/websockets/websocket_handshake_draft75_unittest.cc
new file mode 100644
index 0000000..aff75ad
--- /dev/null
+++ b/net/websockets/websocket_handshake_draft75_unittest.cc
@@ -0,0 +1,217 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <string>
+#include <vector>
+
+#include "base/scoped_ptr.h"
+#include "net/websockets/websocket_handshake_draft75.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/platform_test.h"
+
+namespace net {
+
+TEST(WebSocketHandshakeDraft75Test, Connect) {
+ const std::string kExpectedClientHandshakeMessage =
+ "GET /demo HTTP/1.1\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Host: example.com\r\n"
+ "Origin: http://example.com\r\n"
+ "WebSocket-Protocol: sample\r\n"
+ "\r\n";
+
+ scoped_ptr<WebSocketHandshakeDraft75> handshake(
+ new WebSocketHandshakeDraft75(GURL("ws://example.com/demo"),
+ "http://example.com",
+ "ws://example.com/demo",
+ "sample"));
+ EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
+ EXPECT_EQ(kExpectedClientHandshakeMessage,
+ handshake->CreateClientHandshakeMessage());
+
+ const char kResponse[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "WebSocket-Origin: http://example.com\r\n"
+ "WebSocket-Location: ws://example.com/demo\r\n"
+ "WebSocket-Protocol: sample\r\n"
+ "\r\n";
+
+ EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
+ // too short
+ EXPECT_EQ(-1, handshake->ReadServerHandshake(kResponse, 16));
+ EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
+ // only status line
+ EXPECT_EQ(-1, handshake->ReadServerHandshake(
+ kResponse,
+ WebSocketHandshakeDraft75::kServerHandshakeHeaderLength));
+ EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode());
+ // by upgrade header
+ EXPECT_EQ(-1, handshake->ReadServerHandshake(
+ kResponse,
+ WebSocketHandshakeDraft75::kServerHandshakeHeaderLength +
+ WebSocketHandshakeDraft75::kUpgradeHeaderLength));
+ EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode());
+ // by connection header
+ EXPECT_EQ(-1, handshake->ReadServerHandshake(
+ kResponse,
+ WebSocketHandshakeDraft75::kServerHandshakeHeaderLength +
+ WebSocketHandshakeDraft75::kUpgradeHeaderLength +
+ WebSocketHandshakeDraft75::kConnectionHeaderLength));
+ EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode());
+
+ EXPECT_EQ(-1, handshake->ReadServerHandshake(
+ kResponse, sizeof(kResponse) - 2));
+ EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode());
+
+ int handshake_length = strlen(kResponse);
+ EXPECT_EQ(handshake_length, handshake->ReadServerHandshake(
+ kResponse, sizeof(kResponse) - 1)); // -1 for terminating \0
+ EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode());
+}
+
+TEST(WebSocketHandshakeDraft75Test, ServerSentData) {
+ const std::string kExpectedClientHandshakeMessage =
+ "GET /demo HTTP/1.1\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Host: example.com\r\n"
+ "Origin: http://example.com\r\n"
+ "WebSocket-Protocol: sample\r\n"
+ "\r\n";
+ scoped_ptr<WebSocketHandshakeDraft75> handshake(
+ new WebSocketHandshakeDraft75(GURL("ws://example.com/demo"),
+ "http://example.com",
+ "ws://example.com/demo",
+ "sample"));
+ EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
+ EXPECT_EQ(kExpectedClientHandshakeMessage,
+ handshake->CreateClientHandshakeMessage());
+
+ const char kResponse[] ="HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "WebSocket-Origin: http://example.com\r\n"
+ "WebSocket-Location: ws://example.com/demo\r\n"
+ "WebSocket-Protocol: sample\r\n"
+ "\r\n"
+ "\0Hello\xff";
+
+ int handshake_length = strlen(kResponse);
+ EXPECT_EQ(handshake_length, handshake->ReadServerHandshake(
+ kResponse, sizeof(kResponse) - 1)); // -1 for terminating \0
+ EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode());
+}
+
+TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_Simple) {
+ scoped_ptr<WebSocketHandshakeDraft75> handshake(
+ new WebSocketHandshakeDraft75(GURL("ws://example.com/demo"),
+ "http://example.com",
+ "ws://example.com/demo",
+ "sample"));
+ EXPECT_EQ("GET /demo HTTP/1.1\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Host: example.com\r\n"
+ "Origin: http://example.com\r\n"
+ "WebSocket-Protocol: sample\r\n"
+ "\r\n",
+ handshake->CreateClientHandshakeMessage());
+}
+
+TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_PathAndQuery) {
+ scoped_ptr<WebSocketHandshakeDraft75> handshake(
+ new WebSocketHandshakeDraft75(GURL("ws://example.com/Test?q=xxx&p=%20"),
+ "http://example.com",
+ "ws://example.com/demo",
+ "sample"));
+ // Path and query should be preserved as-is.
+ EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
+ testing::HasSubstr("GET /Test?q=xxx&p=%20 HTTP/1.1\r\n"));
+}
+
+TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_Host) {
+ scoped_ptr<WebSocketHandshakeDraft75> handshake(
+ new WebSocketHandshakeDraft75(GURL("ws://Example.Com/demo"),
+ "http://Example.Com",
+ "ws://Example.Com/demo",
+ "sample"));
+ // Host should be lowercased
+ EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
+ testing::HasSubstr("Host: example.com\r\n"));
+ EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
+ testing::HasSubstr("Origin: http://example.com\r\n"));
+}
+
+TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_TrimPort80) {
+ scoped_ptr<WebSocketHandshakeDraft75> handshake(
+ new WebSocketHandshakeDraft75(GURL("ws://example.com:80/demo"),
+ "http://example.com",
+ "ws://example.com/demo",
+ "sample"));
+ // :80 should be trimmed as it's the default port for ws://.
+ EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
+ testing::HasSubstr("Host: example.com\r\n"));
+}
+
+TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_TrimPort443) {
+ scoped_ptr<WebSocketHandshakeDraft75> handshake(
+ new WebSocketHandshakeDraft75(GURL("wss://example.com:443/demo"),
+ "http://example.com",
+ "wss://example.com/demo",
+ "sample"));
+ // :443 should be trimmed as it's the default port for wss://.
+ EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
+ testing::HasSubstr("Host: example.com\r\n"));
+}
+
+TEST(WebSocketHandshakeDraft75Test,
+ CreateClientHandshakeMessage_NonDefaultPortForWs) {
+ scoped_ptr<WebSocketHandshakeDraft75> handshake(
+ new WebSocketHandshakeDraft75(GURL("ws://example.com:8080/demo"),
+ "http://example.com",
+ "wss://example.com/demo",
+ "sample"));
+ // :8080 should be preserved as it's not the default port for ws://.
+ EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
+ testing::HasSubstr("Host: example.com:8080\r\n"));
+}
+
+TEST(WebSocketHandshakeDraft75Test,
+ CreateClientHandshakeMessage_NonDefaultPortForWss) {
+ scoped_ptr<WebSocketHandshakeDraft75> handshake(
+ new WebSocketHandshakeDraft75(GURL("wss://example.com:4443/demo"),
+ "http://example.com",
+ "wss://example.com/demo",
+ "sample"));
+ // :4443 should be preserved as it's not the default port for wss://.
+ EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
+ testing::HasSubstr("Host: example.com:4443\r\n"));
+}
+
+TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_WsBut443) {
+ scoped_ptr<WebSocketHandshakeDraft75> handshake(
+ new WebSocketHandshakeDraft75(GURL("ws://example.com:443/demo"),
+ "http://example.com",
+ "ws://example.com/demo",
+ "sample"));
+ // :443 should be preserved as it's not the default port for ws://.
+ EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
+ testing::HasSubstr("Host: example.com:443\r\n"));
+}
+
+TEST(WebSocketHandshakeDraft75Test, CreateClientHandshakeMessage_WssBut80) {
+ scoped_ptr<WebSocketHandshakeDraft75> handshake(
+ new WebSocketHandshakeDraft75(GURL("wss://example.com:80/demo"),
+ "http://example.com",
+ "wss://example.com/demo",
+ "sample"));
+ // :80 should be preserved as it's not the default port for wss://.
+ EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
+ testing::HasSubstr("Host: example.com:80\r\n"));
+}
+
+} // namespace net
diff --git a/net/websockets/websocket_handshake_unittest.cc b/net/websockets/websocket_handshake_unittest.cc
index beae805..f688554 100644
--- a/net/websockets/websocket_handshake_unittest.cc
+++ b/net/websockets/websocket_handshake_unittest.cc
@@ -6,6 +6,7 @@
#include <vector>
#include "base/scoped_ptr.h"
+#include "base/string_util.h"
#include "net/websockets/websocket_handshake.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/gmock/include/gmock/gmock.h"
@@ -13,100 +14,216 @@
namespace net {
-TEST(WebSocketHandshakeTest, Connect) {
+class WebSocketHandshakeTest : public testing::Test {
+ public:
+ static void SetUpParameter(WebSocketHandshake* handshake,
+ uint32 number_1, uint32 number_2,
+ const std::string& key_1, const std::string& key_2,
+ const std::string& key_3) {
+ WebSocketHandshake::Parameter* parameter =
+ new WebSocketHandshake::Parameter;
+ parameter->number_1_ = number_1;
+ parameter->number_2_ = number_2;
+ parameter->key_1_ = key_1;
+ parameter->key_2_ = key_2;
+ parameter->key_3_ = key_3;
+ handshake->parameter_.reset(parameter);
+ }
+
+ static void ExpectHeaderEquals(const std::string& expected,
+ const std::string& actual) {
+ std::vector<std::string> expected_lines;
+ Tokenize(expected, "\r\n", &expected_lines);
+ std::vector<std::string> actual_lines;
+ Tokenize(actual, "\r\n", &actual_lines);
+ // Request lines.
+ EXPECT_EQ(expected_lines[0], actual_lines[0]);
+
+ std::vector<std::string> expected_headers;
+ for (size_t i = 1; i < expected_lines.size(); i++) {
+ // Finish at first CRLF CRLF. Note that /key_3/ might include CRLF.
+ if (expected_lines[i] == "")
+ break;
+ expected_headers.push_back(expected_lines[i]);
+ }
+ sort(expected_headers.begin(), expected_headers.end());
+
+ std::vector<std::string> actual_headers;
+ for (size_t i = 1; i < actual_lines.size(); i++) {
+ // Finish at first CRLF CRLF. Note that /key_3/ might include CRLF.
+ if (actual_lines[i] == "")
+ break;
+ actual_headers.push_back(actual_lines[i]);
+ }
+ sort(actual_headers.begin(), actual_headers.end());
+
+ EXPECT_EQ(expected_headers.size(), actual_headers.size())
+ << "expected:" << expected
+ << "\nactual:" << actual;
+ for (size_t i = 0; i < expected_headers.size(); i++) {
+ EXPECT_EQ(expected_headers[i], actual_headers[i]);
+ }
+ }
+
+ static void ExpectHandshakeMessageEquals(const std::string& expected,
+ const std::string& actual) {
+ // Headers.
+ ExpectHeaderEquals(expected, actual);
+ // Compare tailing \r\n\r\n<key3> (4 + 8 bytes).
+ ASSERT_GT(expected.size(), 12U);
+ const char* expected_key3 = expected.data() + expected.size() - 12;
+ EXPECT_GT(actual.size(), 12U);
+ if (actual.size() <= 12U)
+ return;
+ const char* actual_key3 = actual.data() + actual.size() - 12;
+ EXPECT_TRUE(memcmp(expected_key3, actual_key3, 12) == 0)
+ << "expected_key3:" << DumpKey(expected_key3, 12)
+ << ", actual_key3:" << DumpKey(actual_key3, 12);
+ }
+
+ static std::string DumpKey(const char* buf, int len) {
+ std::string s;
+ for (int i = 0; i < len; i++) {
+ if (isprint(buf[i]))
+ s += StringPrintf("%c", buf[i]);
+ else
+ s += StringPrintf("\\x%02x", buf[i]);
+ }
+ return s;
+ }
+
+ static std::string GetResourceName(WebSocketHandshake* handshake) {
+ return handshake->GetResourceName();
+ }
+ static std::string GetHostFieldValue(WebSocketHandshake* handshake) {
+ return handshake->GetHostFieldValue();
+ }
+ static std::string GetOriginFieldValue(WebSocketHandshake* handshake) {
+ return handshake->GetOriginFieldValue();
+ }
+};
+
+
+TEST_F(WebSocketHandshakeTest, Connect) {
const std::string kExpectedClientHandshakeMessage =
"GET /demo HTTP/1.1\r\n"
"Upgrade: WebSocket\r\n"
"Connection: Upgrade\r\n"
"Host: example.com\r\n"
"Origin: http://example.com\r\n"
- "WebSocket-Protocol: sample\r\n"
- "\r\n";
+ "Sec-WebSocket-Protocol: sample\r\n"
+ "Sec-WebSocket-Key1: 388P O503D&ul7 {K%gX( %7 15\r\n"
+ "Sec-WebSocket-Key2: 1 N ?|k UT0or 3o 4 I97N 5-S3O 31\r\n"
+ "\r\n"
+ "\x47\x30\x22\x2D\x5A\x3F\x47\x58";
scoped_ptr<WebSocketHandshake> handshake(
new WebSocketHandshake(GURL("ws://example.com/demo"),
"http://example.com",
"ws://example.com/demo",
"sample"));
+ SetUpParameter(handshake.get(), 777007543U, 114997259U,
+ "388P O503D&ul7 {K%gX( %7 15",
+ "1 N ?|k UT0or 3o 4 I97N 5-S3O 31",
+ std::string("\x47\x30\x22\x2D\x5A\x3F\x47\x58", 8));
EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
- EXPECT_EQ(kExpectedClientHandshakeMessage,
- handshake->CreateClientHandshakeMessage());
+ ExpectHandshakeMessageEquals(
+ kExpectedClientHandshakeMessage,
+ handshake->CreateClientHandshakeMessage());
- const char kResponse[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
+ const char kResponse[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
"Upgrade: WebSocket\r\n"
"Connection: Upgrade\r\n"
- "WebSocket-Origin: http://example.com\r\n"
- "WebSocket-Location: ws://example.com/demo\r\n"
- "WebSocket-Protocol: sample\r\n"
- "\r\n";
+ "Sec-WebSocket-Origin: http://example.com\r\n"
+ "Sec-WebSocket-Location: ws://example.com/demo\r\n"
+ "Sec-WebSocket-Protocol: sample\r\n"
+ "\r\n"
+ "\x30\x73\x74\x33\x52\x6C\x26\x71\x2D\x32\x5A\x55\x5E\x77\x65\x75";
+ std::vector<std::string> response_lines;
+ SplitStringDontTrim(kResponse, '\n', &response_lines);
EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
// too short
EXPECT_EQ(-1, handshake->ReadServerHandshake(kResponse, 16));
EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
+
// only status line
+ std::string response = response_lines[0];
EXPECT_EQ(-1, handshake->ReadServerHandshake(
- kResponse,
- WebSocketHandshake::kServerHandshakeHeaderLength));
- EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode());
+ response.data(), response.size()));
+ EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
// by upgrade header
+ response += response_lines[1];
EXPECT_EQ(-1, handshake->ReadServerHandshake(
- kResponse,
- WebSocketHandshake::kServerHandshakeHeaderLength +
- WebSocketHandshake::kUpgradeHeaderLength));
- EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode());
+ response.data(), response.size()));
+ EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
// by connection header
+ response += response_lines[2];
EXPECT_EQ(-1, handshake->ReadServerHandshake(
- kResponse,
- WebSocketHandshake::kServerHandshakeHeaderLength +
- WebSocketHandshake::kUpgradeHeaderLength +
- WebSocketHandshake::kConnectionHeaderLength));
- EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode());
+ response.data(), response.size()));
+ EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
+ response += response_lines[3]; // Sec-WebSocket-Origin
+ response += response_lines[4]; // Sec-WebSocket-Location
+ response += response_lines[5]; // Sec-WebSocket-Protocol
EXPECT_EQ(-1, handshake->ReadServerHandshake(
- kResponse, sizeof(kResponse) - 2));
- EXPECT_EQ(WebSocketHandshake::MODE_NORMAL, handshake->mode());
+ response.data(), response.size()));
+ EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
- int handshake_length = strlen(kResponse);
+ response += response_lines[6]; // \r\n
+ EXPECT_EQ(-1, handshake->ReadServerHandshake(
+ response.data(), response.size()));
+ EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
+
+ int handshake_length = sizeof(kResponse) - 1; // -1 for terminating \0
EXPECT_EQ(handshake_length, handshake->ReadServerHandshake(
- kResponse, sizeof(kResponse) - 1)); // -1 for terminating \0
+ kResponse, handshake_length)); // -1 for terminating \0
EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode());
}
-TEST(WebSocketHandshakeTest, ServerSentData) {
+TEST_F(WebSocketHandshakeTest, ServerSentData) {
const std::string kExpectedClientHandshakeMessage =
"GET /demo HTTP/1.1\r\n"
"Upgrade: WebSocket\r\n"
"Connection: Upgrade\r\n"
"Host: example.com\r\n"
"Origin: http://example.com\r\n"
- "WebSocket-Protocol: sample\r\n"
- "\r\n";
+ "Sec-WebSocket-Protocol: sample\r\n"
+ "Sec-WebSocket-Key1: 388P O503D&ul7 {K%gX( %7 15\r\n"
+ "Sec-WebSocket-Key2: 1 N ?|k UT0or 3o 4 I97N 5-S3O 31\r\n"
+ "\r\n"
+ "\x47\x30\x22\x2D\x5A\x3F\x47\x58";
scoped_ptr<WebSocketHandshake> handshake(
new WebSocketHandshake(GURL("ws://example.com/demo"),
"http://example.com",
"ws://example.com/demo",
"sample"));
+ SetUpParameter(handshake.get(), 777007543U, 114997259U,
+ "388P O503D&ul7 {K%gX( %7 15",
+ "1 N ?|k UT0or 3o 4 I97N 5-S3O 31",
+ std::string("\x47\x30\x22\x2D\x5A\x3F\x47\x58", 8));
EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
- EXPECT_EQ(kExpectedClientHandshakeMessage,
- handshake->CreateClientHandshakeMessage());
+ ExpectHandshakeMessageEquals(
+ kExpectedClientHandshakeMessage,
+ handshake->CreateClientHandshakeMessage());
- const char kResponse[] ="HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
+ const char kResponse[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
"Upgrade: WebSocket\r\n"
"Connection: Upgrade\r\n"
- "WebSocket-Origin: http://example.com\r\n"
- "WebSocket-Location: ws://example.com/demo\r\n"
- "WebSocket-Protocol: sample\r\n"
+ "Sec-WebSocket-Origin: http://example.com\r\n"
+ "Sec-WebSocket-Location: ws://example.com/demo\r\n"
+ "Sec-WebSocket-Protocol: sample\r\n"
"\r\n"
+ "\x30\x73\x74\x33\x52\x6C\x26\x71\x2D\x32\x5A\x55\x5E\x77\x65\x75"
"\0Hello\xff";
- int handshake_length = strlen(kResponse);
+ int handshake_length = strlen(kResponse); // key3 doesn't contain \0.
EXPECT_EQ(handshake_length, handshake->ReadServerHandshake(
kResponse, sizeof(kResponse) - 1)); // -1 for terminating \0
EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode());
}
-TEST(WebSocketHandshakeTest, is_secure_false) {
+TEST_F(WebSocketHandshakeTest, is_secure_false) {
scoped_ptr<WebSocketHandshake> handshake(
new WebSocketHandshake(GURL("ws://example.com/demo"),
"http://example.com",
@@ -115,7 +232,7 @@ TEST(WebSocketHandshakeTest, is_secure_false) {
EXPECT_FALSE(handshake->is_secure());
}
-TEST(WebSocketHandshakeTest, is_secure_true) {
+TEST_F(WebSocketHandshakeTest, is_secure_true) {
// wss:// is secure.
scoped_ptr<WebSocketHandshake> handshake(
new WebSocketHandshake(GURL("wss://example.com/demo"),
@@ -125,80 +242,59 @@ TEST(WebSocketHandshakeTest, is_secure_true) {
EXPECT_TRUE(handshake->is_secure());
}
-TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_Simple) {
- scoped_ptr<WebSocketHandshake> handshake(
- new WebSocketHandshake(GURL("ws://example.com/demo"),
- "http://example.com",
- "ws://example.com/demo",
- "sample"));
- EXPECT_EQ("GET /demo HTTP/1.1\r\n"
- "Upgrade: WebSocket\r\n"
- "Connection: Upgrade\r\n"
- "Host: example.com\r\n"
- "Origin: http://example.com\r\n"
- "WebSocket-Protocol: sample\r\n"
- "\r\n",
- handshake->CreateClientHandshakeMessage());
-}
-
-TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_PathAndQuery) {
+TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_ResourceName) {
scoped_ptr<WebSocketHandshake> handshake(
new WebSocketHandshake(GURL("ws://example.com/Test?q=xxx&p=%20"),
"http://example.com",
"ws://example.com/demo",
"sample"));
// Path and query should be preserved as-is.
- EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
- testing::HasSubstr("GET /Test?q=xxx&p=%20 HTTP/1.1\r\n"));
+ EXPECT_EQ("/Test?q=xxx&p=%20", GetResourceName(handshake.get()));
}
-TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_Host) {
+TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_Host) {
scoped_ptr<WebSocketHandshake> handshake(
new WebSocketHandshake(GURL("ws://Example.Com/demo"),
"http://Example.Com",
"ws://Example.Com/demo",
"sample"));
// Host should be lowercased
- EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
- testing::HasSubstr("Host: example.com\r\n"));
- EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
- testing::HasSubstr("Origin: http://example.com\r\n"));
+ EXPECT_EQ("example.com", GetHostFieldValue(handshake.get()));
+ EXPECT_EQ("http://example.com", GetOriginFieldValue(handshake.get()));
}
-TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort80) {
+TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort80) {
scoped_ptr<WebSocketHandshake> handshake(
new WebSocketHandshake(GURL("ws://example.com:80/demo"),
"http://example.com",
"ws://example.com/demo",
"sample"));
// :80 should be trimmed as it's the default port for ws://.
- EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
- testing::HasSubstr("Host: example.com\r\n"));
+ EXPECT_EQ("example.com", GetHostFieldValue(handshake.get()));
}
-TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort443) {
+TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort443) {
scoped_ptr<WebSocketHandshake> handshake(
new WebSocketHandshake(GURL("wss://example.com:443/demo"),
"http://example.com",
"wss://example.com/demo",
"sample"));
// :443 should be trimmed as it's the default port for wss://.
- EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
- testing::HasSubstr("Host: example.com\r\n"));
+ EXPECT_EQ("example.com", GetHostFieldValue(handshake.get()));
}
-TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_NonDefaultPortForWs) {
+TEST_F(WebSocketHandshakeTest,
+ CreateClientHandshakeMessage_NonDefaultPortForWs) {
scoped_ptr<WebSocketHandshake> handshake(
new WebSocketHandshake(GURL("ws://example.com:8080/demo"),
"http://example.com",
"wss://example.com/demo",
"sample"));
// :8080 should be preserved as it's not the default port for ws://.
- EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
- testing::HasSubstr("Host: example.com:8080\r\n"));
+ EXPECT_EQ("example.com:8080", GetHostFieldValue(handshake.get()));
}
-TEST(WebSocketHandshakeTest,
+TEST_F(WebSocketHandshakeTest,
CreateClientHandshakeMessage_NonDefaultPortForWss) {
scoped_ptr<WebSocketHandshake> handshake(
new WebSocketHandshake(GURL("wss://example.com:4443/demo"),
@@ -206,30 +302,27 @@ TEST(WebSocketHandshakeTest,
"wss://example.com/demo",
"sample"));
// :4443 should be preserved as it's not the default port for wss://.
- EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
- testing::HasSubstr("Host: example.com:4443\r\n"));
+ EXPECT_EQ("example.com:4443", GetHostFieldValue(handshake.get()));
}
-TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_WsBut443) {
+TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_WsBut443) {
scoped_ptr<WebSocketHandshake> handshake(
new WebSocketHandshake(GURL("ws://example.com:443/demo"),
"http://example.com",
"ws://example.com/demo",
"sample"));
// :443 should be preserved as it's not the default port for ws://.
- EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
- testing::HasSubstr("Host: example.com:443\r\n"));
+ EXPECT_EQ("example.com:443", GetHostFieldValue(handshake.get()));
}
-TEST(WebSocketHandshakeTest, CreateClientHandshakeMessage_WssBut80) {
+TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_WssBut80) {
scoped_ptr<WebSocketHandshake> handshake(
new WebSocketHandshake(GURL("wss://example.com:80/demo"),
"http://example.com",
"wss://example.com/demo",
"sample"));
// :80 should be preserved as it's not the default port for wss://.
- EXPECT_THAT(handshake->CreateClientHandshakeMessage(),
- testing::HasSubstr("Host: example.com:80\r\n"));
+ EXPECT_EQ("example.com:80", GetHostFieldValue(handshake.get()));
}
} // namespace net
diff --git a/net/websockets/websocket_unittest.cc b/net/websockets/websocket_unittest.cc
index e18f712..5f6e8e9 100644
--- a/net/websockets/websocket_unittest.cc
+++ b/net/websockets/websocket_unittest.cc
@@ -157,6 +157,7 @@ TEST_F(WebSocketTest, Connect) {
"sample",
"http://example.com",
"ws://example.com/demo",
+ WebSocket::DRAFT75,
new TestURLRequestContext()));
request->SetHostResolver(new MockHostResolver());
request->SetClientSocketFactory(&mock_socket_factory);
@@ -218,6 +219,7 @@ TEST_F(WebSocketTest, ServerSentData) {
"sample",
"http://example.com",
"ws://example.com/demo",
+ WebSocket::DRAFT75,
new TestURLRequestContext()));
request->SetHostResolver(new MockHostResolver());
request->SetClientSocketFactory(&mock_socket_factory);
@@ -252,6 +254,7 @@ TEST_F(WebSocketTest, ProcessFrameDataForLengthCalculation) {
"sample",
"http://example.com",
"ws://example.com/demo",
+ WebSocket::DRAFT75,
new TestURLRequestContext()));
TestCompletionCallback callback;
scoped_ptr<WebSocketEventRecorder> delegate(
@@ -287,6 +290,7 @@ TEST_F(WebSocketTest, ProcessFrameDataForUnterminatedString) {
"sample",
"http://example.com",
"ws://example.com/demo",
+ WebSocket::DRAFT75,
new TestURLRequestContext()));
TestCompletionCallback callback;
scoped_ptr<WebSocketEventRecorder> delegate(