// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/server/web_socket.h" #include "base/base64.h" #include "base/logging.h" #include "base/md5.h" #include "base/sha1.h" #include "base/strings/string_number_conversions.h" #include "base/strings/stringprintf.h" #include "base/sys_byteorder.h" #include "net/server/http_connection.h" #include "net/server/http_server.h" #include "net/server/http_server_request_info.h" #include "net/server/http_server_response_info.h" #include "net/server/web_socket_encoder.h" namespace net { namespace { static uint32 WebSocketKeyFingerprint(const std::string& str) { std::string result; const char* p_char = str.c_str(); int length = str.length(); int spaces = 0; for (int i = 0; i < length; ++i) { if (p_char[i] >= '0' && p_char[i] <= '9') result.append(&p_char[i], 1); else if (p_char[i] == ' ') spaces++; } if (spaces == 0) return 0; int64 number = 0; if (!base::StringToInt64(result, &number)) return 0; return base::HostToNet32(static_cast(number / spaces)); } class WebSocketHixie76 : public net::WebSocket { public: static net::WebSocket* Create(HttpServer* server, HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos) { if (connection->read_buf()->GetSize() < static_cast(*pos + kWebSocketHandshakeBodyLen)) return NULL; return new WebSocketHixie76(server, connection, request, pos); } void Accept(const HttpServerRequestInfo& request) override { std::string key1 = request.GetHeaderValue("sec-websocket-key1"); std::string key2 = request.GetHeaderValue("sec-websocket-key2"); uint32 fp1 = WebSocketKeyFingerprint(key1); uint32 fp2 = WebSocketKeyFingerprint(key2); char data[16]; memcpy(data, &fp1, 4); memcpy(data + 4, &fp2, 4); memcpy(data + 8, &key3_[0], 8); base::MD5Digest digest; base::MD5Sum(data, 16, &digest); std::string origin = request.GetHeaderValue("origin"); std::string host = request.GetHeaderValue("host"); std::string location = "ws://" + host + request.path; server_->SendRaw( connection_->id(), base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Origin: %s\r\n" "Sec-WebSocket-Location: %s\r\n" "\r\n", origin.c_str(), location.c_str())); server_->SendRaw(connection_->id(), std::string(reinterpret_cast(digest.a), 16)); } ParseResult Read(std::string* message) override { DCHECK(message); HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); if (read_buf->StartOfBuffer()[0]) return FRAME_ERROR; base::StringPiece data(read_buf->StartOfBuffer(), read_buf->GetSize()); size_t pos = data.find('\377', 1); if (pos == base::StringPiece::npos) return FRAME_INCOMPLETE; message->assign(data.data() + 1, pos - 1); read_buf->DidConsume(pos + 1); return FRAME_OK; } void Send(const std::string& message) override { char message_start = 0; char message_end = -1; server_->SendRaw(connection_->id(), std::string(1, message_start)); server_->SendRaw(connection_->id(), message); server_->SendRaw(connection_->id(), std::string(1, message_end)); } private: static const int kWebSocketHandshakeBodyLen; WebSocketHixie76(HttpServer* server, HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos) : WebSocket(server, connection) { std::string key1 = request.GetHeaderValue("sec-websocket-key1"); std::string key2 = request.GetHeaderValue("sec-websocket-key2"); if (key1.empty()) { server->SendResponse( connection->id(), HttpServerResponseInfo::CreateFor500( "Invalid request format. Sec-WebSocket-Key1 is empty or isn't " "specified.")); return; } if (key2.empty()) { server->SendResponse( connection->id(), HttpServerResponseInfo::CreateFor500( "Invalid request format. Sec-WebSocket-Key2 is empty or isn't " "specified.")); return; } key3_.assign(connection->read_buf()->StartOfBuffer() + *pos, kWebSocketHandshakeBodyLen); *pos += kWebSocketHandshakeBodyLen; } std::string key3_; DISALLOW_COPY_AND_ASSIGN(WebSocketHixie76); }; const int WebSocketHixie76::kWebSocketHandshakeBodyLen = 8; class WebSocketHybi17 : public WebSocket { public: static WebSocket* Create(HttpServer* server, HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos) { std::string version = request.GetHeaderValue("sec-websocket-version"); if (version != "8" && version != "13") return NULL; std::string key = request.GetHeaderValue("sec-websocket-key"); if (key.empty()) { server->SendResponse( connection->id(), HttpServerResponseInfo::CreateFor500( "Invalid request format. Sec-WebSocket-Key is empty or isn't " "specified.")); return NULL; } return new WebSocketHybi17(server, connection, request, pos); } void Accept(const HttpServerRequestInfo& request) override { static const char* const kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; std::string key = request.GetHeaderValue("sec-websocket-key"); std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); std::string encoded_hash; base::Base64Encode(base::SHA1HashString(data), &encoded_hash); server_->SendRaw(connection_->id(), base::StringPrintf( "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Accept: %s\r\n" "%s" "\r\n", encoded_hash.c_str(), response_extensions_.c_str())); } ParseResult Read(std::string* message) override { HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); int bytes_consumed = 0; ParseResult result = encoder_->DecodeFrame(frame, &bytes_consumed, message); if (result == FRAME_OK) read_buf->DidConsume(bytes_consumed); if (result == FRAME_CLOSE) closed_ = true; return result; } void Send(const std::string& message) override { if (closed_) return; std::string encoded; encoder_->EncodeFrame(message, 0, &encoded); server_->SendRaw(connection_->id(), encoded); } private: WebSocketHybi17(HttpServer* server, HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos) : WebSocket(server, connection), closed_(false) { std::string request_extensions = request.GetHeaderValue("sec-websocket-extensions"); encoder_.reset(WebSocketEncoder::CreateServer(request_extensions, &response_extensions_)); if (!response_extensions_.empty()) { response_extensions_ = "Sec-WebSocket-Extensions: " + response_extensions_ + "\r\n"; } } scoped_ptr encoder_; std::string response_extensions_; bool closed_; DISALLOW_COPY_AND_ASSIGN(WebSocketHybi17); }; } // anonymous namespace WebSocket* WebSocket::CreateWebSocket(HttpServer* server, HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos) { WebSocket* socket = WebSocketHybi17::Create(server, connection, request, pos); if (socket) return socket; return WebSocketHixie76::Create(server, connection, request, pos); } WebSocket::WebSocket(HttpServer* server, HttpConnection* connection) : server_(server), connection_(connection) { } WebSocket::~WebSocket() { } } // namespace net