diff options
Diffstat (limited to 'net/websockets/websocket_handshake.cc')
-rw-r--r-- | net/websockets/websocket_handshake.cc | 209 |
1 files changed, 209 insertions, 0 deletions
diff --git a/net/websockets/websocket_handshake.cc b/net/websockets/websocket_handshake.cc new file mode 100644 index 0000000..c17ea34 --- /dev/null +++ b/net/websockets/websocket_handshake.cc @@ -0,0 +1,209 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_handshake.h" + +#include "base/ref_counted.h" +#include "base/string_util.h" +#include "net/http/http_response_headers.h" +#include "net/http/http_util.h" + +namespace net { + +const int WebSocketHandshake::kWebSocketPort = 80; +const int WebSocketHandshake::kSecureWebSocketPort = 443; + +const char WebSocketHandshake::kServerHandshakeHeader[] = + "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"; +const size_t WebSocketHandshake::kServerHandshakeHeaderLength = + sizeof(kServerHandshakeHeader) - 1; + +const char WebSocketHandshake::kUpgradeHeader[] = "Upgrade: WebSocket\r\n"; +const size_t WebSocketHandshake::kUpgradeHeaderLength = + sizeof(kUpgradeHeader) - 1; + +const char WebSocketHandshake::kConnectionHeader[] = "Connection: Upgrade\r\n"; +const size_t WebSocketHandshake::kConnectionHeaderLength = + sizeof(kConnectionHeader) - 1; + +WebSocketHandshake::WebSocketHandshake( + const GURL& url, + const std::string& origin, + const std::string& location, + const std::string& protocol) + : url_(url), + origin_(origin), + location_(location), + protocol_(protocol), + mode_(MODE_INCOMPLETE) { +} + +WebSocketHandshake::~WebSocketHandshake() { +} + +bool WebSocketHandshake::is_secure() const { + return url_.SchemeIs("wss"); +} + +std::string WebSocketHandshake::CreateClientHandshakeMessage() const { + std::string msg; + msg = "GET "; + msg += url_.path(); + if (url_.has_query()) { + msg += "?"; + msg += url_.query(); + } + msg += " HTTP/1.1\r\n"; + msg += kUpgradeHeader; + msg += kConnectionHeader; + msg += "Host: "; + msg += StringToLowerASCII(url_.host()); + if (url_.has_port()) { + bool secure = is_secure(); + int port = url_.EffectiveIntPort(); + if ((!secure && + port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) || + (secure && + port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) { + msg += ":"; + msg += IntToString(port); + } + } + msg += "\r\n"; + msg += "Origin: "; + // It's OK to lowercase the origin as the Origin header does not contain + // the path or query portions, as per + // http://tools.ietf.org/html/draft-abarth-origin-00. + // + // TODO(satorux): Should we trim the port portion here if it's 80 for + // http:// or 443 for https:// ? Or can we assume it's done by the + // client of the library? + msg += StringToLowerASCII(origin_); + msg += "\r\n"; + if (!protocol_.empty()) { + msg += "WebSocket-Protocol: "; + msg += protocol_; + msg += "\r\n"; + } + // TODO(ukai): Add cookie if necessary. + msg += "\r\n"; + return msg; +} + +int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) { + mode_ = MODE_INCOMPLETE; + if (len < kServerHandshakeHeaderLength) { + return -1; + } + if (!memcmp(data, kServerHandshakeHeader, kServerHandshakeHeaderLength)) { + mode_ = MODE_NORMAL; + } else { + int eoh = HttpUtil::LocateEndOfHeaders(data, len); + if (eoh < 0) + return -1; + return eoh; + } + const char* p = data + kServerHandshakeHeaderLength; + const char* end = data + len + 1; + + if (mode_ == MODE_NORMAL) { + size_t header_size = end - p; + if (header_size < kUpgradeHeaderLength) + return -1; + if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) { + mode_ = MODE_FAILED; + DLOG(INFO) << "Bad Upgrade Header " + << std::string(p, kUpgradeHeaderLength); + return p - data; + } + p += kUpgradeHeaderLength; + header_size = end - p; + if (header_size < kConnectionHeaderLength) + return -1; + if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) { + mode_ = MODE_FAILED; + DLOG(INFO) << "Bad Connection Header " + << std::string(p, kConnectionHeaderLength); + return p - data; + } + p += kConnectionHeaderLength; + } + + int eoh = HttpUtil::LocateEndOfHeaders(data, len); + if (eoh == -1) + return eoh; + + scoped_refptr<HttpResponseHeaders> headers( + new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh))); + if (!ProcessHeaders(*headers)) { + DLOG(INFO) << "Process Headers failed: " + << std::string(data, eoh); + mode_ = MODE_FAILED; + } + switch (mode_) { + case MODE_NORMAL: + if (CheckResponseHeaders()) { + mode_ = MODE_CONNECTED; + } else { + mode_ = MODE_FAILED; + } + break; + default: + mode_ = MODE_FAILED; + break; + } + return eoh; +} + +// Gets the value of the specified header. +// It assures only one header of |name| in |headers|. +// Returns true iff single header of |name| is found in |headers| +// and |value| is filled with the value. +// Returns false otherwise. +static bool GetSingleHeader(const HttpResponseHeaders& headers, + const std::string& name, + std::string* value) { + std::string first_value; + void* iter = NULL; + if (!headers.EnumerateHeader(&iter, name, &first_value)) + return false; + + // Checks no more |name| found in |headers|. + // Second call of EnumerateHeader() must return false. + std::string second_value; + if (headers.EnumerateHeader(&iter, name, &second_value)) + return false; + *value = first_value; + return true; +} + +bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) { + if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_)) + return false; + + if (!GetSingleHeader(headers, "websocket-location", &ws_location_)) + return false; + + // If |protocol_| is not specified by client, we don't care if there's + // protocol field or not as specified in the spec. + if (!protocol_.empty() + && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_)) + return false; + return true; +} + +bool WebSocketHandshake::CheckResponseHeaders() const { + DCHECK(mode_ == MODE_NORMAL); + if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str())) + return false; + if (location_ != ws_location_) + return false; + if (!protocol_.empty() && protocol_ != ws_protocol_) + return false; + return true; +} + + + +} // namespace net |