diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/net.gyp | 4 | ||||
-rw-r--r-- | net/server/http_connection.cc | 84 | ||||
-rw-r--r-- | net/server/http_connection.h | 54 | ||||
-rw-r--r-- | net/server/http_server.cc | 285 | ||||
-rw-r--r-- | net/server/http_server.h | 37 | ||||
-rw-r--r-- | net/server/http_server_request_info.cc | 9 | ||||
-rw-r--r-- | net/server/http_server_request_info.h | 3 | ||||
-rw-r--r-- | net/server/web_socket.cc | 365 | ||||
-rw-r--r-- | net/server/web_socket.h | 46 |
9 files changed, 653 insertions, 234 deletions
diff --git a/net/net.gyp b/net/net.gyp index de2b4df..d9958a1 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -1316,10 +1316,14 @@ '../testing/gtest.gyp:gtest', ], 'sources': [ + 'server/http_connection.cc', + 'server/http_connection.h', 'server/http_server.cc', 'server/http_server.h', 'server/http_server_request_info.cc', 'server/http_server_request_info.h', + 'server/web_socket.cc', + 'server/web_socket.h', ], }, { diff --git a/net/server/http_connection.cc b/net/server/http_connection.cc new file mode 100644 index 0000000..3404c68 --- /dev/null +++ b/net/server/http_connection.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/server/http_connection.h" + +#include "base/string_util.h" +#include "base/stringprintf.h" +#include "net/base/listen_socket.h" +#include "net/server/http_server.h" +#include "net/server/web_socket.h" + +namespace net { + +int HttpConnection::last_id_ = 0; + +void HttpConnection::Send(const std::string& data) { + if (!socket_) + return; + socket_->Send(data); +} + +void HttpConnection::Send(const char* bytes, int len) { + if (!socket_) + return; + socket_->Send(bytes, len); +} + +void HttpConnection::Send200(const std::string& data, + const std::string& content_type) { + if (!socket_) + return; + socket_->Send(base::StringPrintf( + "HTTP/1.1 200 OK\r\n" + "Content-Type:%s\r\n" + "Content-Length:%d\r\n" + "\r\n", + content_type.c_str(), + static_cast<int>(data.length()))); + socket_->Send(data); +} + +void HttpConnection::Send404() { + if (!socket_) + return; + socket_->Send( + "HTTP/1.1 404 Not Found\r\n" + "Content-Length: 0\r\n" + "\r\n"); +} + +void HttpConnection::Send500(const std::string& message) { + if (!socket_) + return; + socket_->Send(base::StringPrintf( + "HTTP/1.1 500 Internal Error\r\n" + "Content-Type:text/html\r\n" + "Content-Length:%d\r\n" + "\r\n" + "%s", + static_cast<int>(message.length()), + message.c_str())); +} + +HttpConnection::HttpConnection(HttpServer* server, ListenSocket* sock) + : server_(server), + socket_(sock) { + id_ = last_id_++; +} + +HttpConnection::~HttpConnection() { + DetachSocket(); + server_->delegate_->OnClose(id_); +} + +void HttpConnection::DetachSocket() { + socket_ = NULL; +} + +void HttpConnection::Shift(int num_bytes) { + recv_data_ = recv_data_.substr(num_bytes); +} + +} // namespace net diff --git a/net/server/http_connection.h b/net/server/http_connection.h new file mode 100644 index 0000000..f154bc3 --- /dev/null +++ b/net/server/http_connection.h @@ -0,0 +1,54 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SERVER_HTTP_CONNECTION_H_ +#define NET_SERVER_HTTP_CONNECTION_H_ +#pragma once + +#include <string> + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" + +namespace net { + +class HttpServer; +class HttpServerRequestInfo; +class ListenSocket; +class WebSocket; + +class HttpConnection { + public: + void Send(const std::string& data); + void Send(const char* bytes, int len); + void Send200(const std::string& data, const std::string& content_type); + void Send404(); + void Send500(const std::string& message); + + void Shift(int num_bytes); + + const std::string& recv_data() const { return recv_data_; } + const int id() const { return id_; } + + private: + friend class HttpServer; + static int last_id_; + + HttpConnection(HttpServer* server, ListenSocket* sock); + ~HttpConnection(); + + void DetachSocket(); + + HttpServer* server_; + scoped_refptr<ListenSocket> socket_; + scoped_ptr<WebSocket> web_socket_; + std::string recv_data_; + int id_; + DISALLOW_COPY_AND_ASSIGN(HttpConnection); +}; + +} // namespace net + +#endif // NET_SERVER_HTTP_CONNECTION_H_ diff --git a/net/server/http_server.cc b/net/server/http_server.cc index 061731a..42d5752 100644 --- a/net/server/http_server.cc +++ b/net/server/http_server.cc @@ -6,12 +6,12 @@ #include "base/compiler_specific.h" #include "base/logging.h" -#include "base/md5.h" -#include "base/string_number_conversions.h" #include "base/string_util.h" #include "base/stringprintf.h" #include "build/build_config.h" +#include "net/server/http_connection.h" #include "net/server/http_server_request_info.h" +#include "net/server/web_socket.h" #if defined(OS_WIN) #include <winsock2.h> @@ -21,8 +21,6 @@ namespace net { -int HttpServer::Connection::lastId_ = 0; - HttpServer::HttpServer(const std::string& host, int port, HttpServer::Delegate* del) @@ -38,174 +36,53 @@ HttpServer::~HttpServer() { server_ = NULL; } -std::string GetHeaderValue( - const HttpServerRequestInfo& request, - const std::string& header_name) { - HttpServerRequestInfo::HeadersMap::iterator it = - request.headers.find(header_name); - if (it != request.headers.end()) - return it->second; - return ""; -} - -uint32 WebSocketKeyFingerprint(const std::string& str) { - std::string result; - const char* pChar = str.c_str(); - int length = str.length(); - int spaces = 0; - for (int i = 0; i < length; ++i) { - if (pChar[i] >= '0' && pChar[i] <= '9') - result.append(&pChar[i], 1); - else if (pChar[i] == ' ') - spaces++; - } - if (spaces == 0) - return 0; - int64 number = 0; - if (!base::StringToInt64(result, &number)) - return 0; - return htonl(static_cast<uint32>(number / spaces)); -} - -void HttpServer::AcceptWebSocket( - int connection_id, - const HttpServerRequestInfo& request) { - Connection* connection = FindConnection(connection_id); - if (connection == NULL) - return; - - std::string key1 = GetHeaderValue(request, "Sec-WebSocket-Key1"); - std::string key2 = GetHeaderValue(request, "Sec-WebSocket-Key2"); - - uint32 fp1 = WebSocketKeyFingerprint(key1); - uint32 fp2 = WebSocketKeyFingerprint(key2); - - char data[16]; - memcpy(data, &fp1, 4); - memcpy(data + 4, &fp2, 4); - memcpy(data + 8, &request.data[0], 8); - - base::MD5Digest digest; - base::MD5Sum(data, 16, &digest); - - std::string origin = GetHeaderValue(request, "Origin"); - std::string host = GetHeaderValue(request, "Host"); - std::string location = "ws://" + host + request.path; - connection->is_web_socket_ = true; - connection->socket_->Send(base::StringPrintf( - "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Origin: %s\r\n" - "Sec-WebSocket-Location: %s\r\n" - "\r\n", - origin.c_str(), - location.c_str())); - connection->socket_->Send(reinterpret_cast<char*>(digest.a), 16); -} - -void HttpServer::SendOverWebSocket(int connection_id, - const std::string& data) { - Connection* connection = FindConnection(connection_id); - if (connection == NULL) - return; - - DCHECK(connection->is_web_socket_); - char message_start = 0; - char message_end = -1; - connection->socket_->Send(&message_start, 1); - connection->socket_->Send(data); - connection->socket_->Send(&message_end, 1); -} - void HttpServer::Send(int connection_id, const std::string& data) { - Connection* connection = FindConnection(connection_id); + HttpConnection* connection = FindConnection(connection_id); if (connection == NULL) return; - - connection->socket_->Send(data); + connection->Send(data); } void HttpServer::Send(int connection_id, const char* bytes, int len) { - Connection* connection = FindConnection(connection_id); + HttpConnection* connection = FindConnection(connection_id); if (connection == NULL) return; - connection->socket_->Send(bytes, len); + connection->Send(bytes, len); } void HttpServer::Send200(int connection_id, const std::string& data, const std::string& content_type) { - Connection* connection = FindConnection(connection_id); + HttpConnection* connection = FindConnection(connection_id); if (connection == NULL) return; - - connection->socket_->Send(base::StringPrintf( - "HTTP/1.1 200 OK\r\n" - "Content-Type:%s\r\n" - "Content-Length:%d\r\n" - "\r\n", - content_type.c_str(), - static_cast<int>(data.length()))); - connection->socket_->Send(data); + connection->Send200(data, content_type); } void HttpServer::Send404(int connection_id) { - Connection* connection = FindConnection(connection_id); + HttpConnection* connection = FindConnection(connection_id); if (connection == NULL) return; - - connection->socket_->Send( - "HTTP/1.1 404 Not Found\r\n" - "Content-Length: 0\r\n" - "\r\n"); + connection->Send404(); } void HttpServer::Send500(int connection_id, const std::string& message) { - Connection* connection = FindConnection(connection_id); + HttpConnection* connection = FindConnection(connection_id); if (connection == NULL) return; - - connection->socket_->Send(base::StringPrintf( - "HTTP/1.1 500 Internal Error\r\n" - "Content-Type:text/html\r\n" - "Content-Length:%d\r\n" - "\r\n" - "%s", - static_cast<int>(message.length()), - message.c_str())); + connection->Send500(message); } void HttpServer::Close(int connection_id) { - Connection* connection = FindConnection(connection_id); + HttpConnection* connection = FindConnection(connection_id); if (connection == NULL) return; connection->DetachSocket(); } -HttpServer::Connection::Connection(HttpServer* server, ListenSocket* sock) - : server_(server), - socket_(sock), - is_web_socket_(false) { - id_ = lastId_++; -} - -HttpServer::Connection::~Connection() { - DetachSocket(); - server_->delegate_->OnClose(id_); -} - -void HttpServer::Connection::DetachSocket() { - socket_ = NULL; -} - -void HttpServer::Connection::Shift(int num_bytes) { - recv_data_ = recv_data_.substr(num_bytes); -} - // // HTTP Request Parser // This HTTP request parser uses a simple state machine to quickly parse @@ -222,8 +99,6 @@ enum header_parse_inputs { INPUT_CR, INPUT_LF, INPUT_COLON, - INPUT_00, - INPUT_FF, INPUT_DEFAULT, MAX_INPUTS, }; @@ -237,9 +112,6 @@ enum header_parse_states { ST_NAME, // Receiving a request header name ST_SEPARATOR, // Receiving the separator between header name and value ST_VALUE, // Receiving a request header value - ST_WS_READY, // Ready to receive web socket frame - ST_WS_FRAME, // Receiving WebSocket frame - ST_WS_CLOSE, // Closing the connection WebSocket connection ST_DONE, // Parsing is complete and successful ST_ERR, // Parsing encountered invalid syntax. MAX_STATES @@ -247,18 +119,15 @@ enum header_parse_states { // State transition table int parser_state[MAX_STATES][MAX_INPUTS] = { -/* METHOD */ { ST_URL, ST_ERR, ST_ERR, ST_ERR, ST_ERR, ST_ERR, ST_METHOD }, -/* URL */ { ST_PROTO, ST_ERR, ST_ERR, ST_URL, ST_ERR, ST_ERR, ST_URL }, -/* PROTOCOL */ { ST_ERR, ST_HEADER, ST_NAME, ST_ERR, ST_ERR, ST_ERR, ST_PROTO }, -/* HEADER */ { ST_ERR, ST_ERR, ST_NAME, ST_ERR, ST_ERR, ST_ERR, ST_ERR }, -/* NAME */ { ST_SEPARATOR, ST_DONE, ST_ERR, ST_SEPARATOR, ST_ERR, ST_ERR, ST_NAME }, -/* SEPARATOR */ { ST_SEPARATOR, ST_ERR, ST_ERR, ST_SEPARATOR, ST_ERR, ST_ERR, ST_VALUE }, -/* VALUE */ { ST_VALUE, ST_HEADER, ST_NAME, ST_VALUE, ST_ERR, ST_ERR, ST_VALUE }, -/* WS_READY */ { ST_ERR, ST_ERR, ST_ERR, ST_ERR, ST_WS_FRAME, ST_WS_CLOSE, ST_ERR}, -/* WS_FRAME */ { ST_WS_FRAME, ST_WS_FRAME, ST_WS_FRAME, ST_WS_FRAME, ST_ERR, ST_WS_READY, ST_WS_FRAME }, -/* WS_CLOSE */ { ST_ERR, ST_ERR, ST_ERR, ST_ERR, ST_WS_CLOSE, ST_ERR, ST_ERR }, -/* DONE */ { ST_DONE, ST_DONE, ST_DONE, ST_DONE, ST_DONE, ST_DONE, ST_DONE }, -/* ERR */ { ST_ERR, ST_ERR, ST_ERR, ST_ERR, ST_ERR, ST_ERR, ST_ERR } +/* METHOD */ { ST_URL, ST_ERR, ST_ERR, ST_ERR, ST_METHOD }, +/* URL */ { ST_PROTO, ST_ERR, ST_ERR, ST_URL, ST_URL }, +/* PROTOCOL */ { ST_ERR, ST_HEADER, ST_NAME, ST_ERR, ST_PROTO }, +/* HEADER */ { ST_ERR, ST_ERR, ST_NAME, ST_ERR, ST_ERR }, +/* NAME */ { ST_SEPARATOR, ST_DONE, ST_ERR, ST_SEPARATOR, ST_NAME }, +/* SEPARATOR */ { ST_SEPARATOR, ST_ERR, ST_ERR, ST_SEPARATOR, ST_VALUE }, +/* VALUE */ { ST_VALUE, ST_HEADER, ST_NAME, ST_VALUE, ST_VALUE }, +/* DONE */ { ST_DONE, ST_DONE, ST_DONE, ST_DONE, ST_DONE }, +/* ERR */ { ST_ERR, ST_ERR, ST_ERR, ST_ERR, ST_ERR } }; // Convert an input character to the parser's input token. @@ -272,20 +141,16 @@ int charToInput(char ch) { return INPUT_LF; case ':': return INPUT_COLON; - case 0x0: - return INPUT_00; - case static_cast<char>(-1): - return INPUT_FF; } return INPUT_DEFAULT; } -bool HttpServer::ParseHeaders(Connection* connection, +bool HttpServer::ParseHeaders(HttpConnection* connection, HttpServerRequestInfo* info, - int* ppos) { - int& pos = *ppos; - int data_len = connection->recv_data_.length(); - int state = connection->is_web_socket_ ? ST_WS_READY : ST_METHOD; + size_t* ppos) { + size_t& pos = *ppos; + size_t data_len = connection->recv_data_.length(); + int state = ST_METHOD; std::string buffer; std::string header_name; std::string header_value; @@ -325,11 +190,6 @@ bool HttpServer::ParseHeaders(Connection* connection, case ST_SEPARATOR: buffer.append(&ch, 1); break; - case ST_WS_FRAME: - info->data = buffer; - buffer.clear(); - return true; - break; } state = next_state; } else { @@ -340,15 +200,11 @@ bool HttpServer::ParseHeaders(Connection* connection, case ST_PROTO: case ST_VALUE: case ST_NAME: - case ST_WS_FRAME: buffer.append(&ch, 1); break; case ST_DONE: DCHECK(input == INPUT_LF); return true; - case ST_WS_CLOSE: - connection->is_web_socket_ = false; - return false; case ST_ERR: return false; } @@ -360,81 +216,98 @@ bool HttpServer::ParseHeaders(Connection* connection, void HttpServer::DidAccept(ListenSocket* server, ListenSocket* socket) { - Connection* connection = new Connection(this, socket); - id_to_connection_[connection->id_] = connection; + HttpConnection* connection = new HttpConnection(this, socket); + id_to_connection_[connection->id()] = connection; socket_to_connection_[socket] = connection; } void HttpServer::DidRead(ListenSocket* socket, const char* data, int len) { - Connection* connection = FindConnection(socket); + HttpConnection* connection = FindConnection(socket); DCHECK(connection != NULL); if (connection == NULL) return; connection->recv_data_.append(data, len); while (connection->recv_data_.length()) { - int pos = 0; - HttpServerRequestInfo request; - if (!ParseHeaders(connection, &request, &pos)) - break; + if (connection->web_socket_.get()) { + std::string message; + WebSocket::ParseResult result = connection->web_socket_->Read(&message); + if (result == WebSocket::FRAME_INCOMPLETE) + break; - if (connection->is_web_socket_) { - delegate_->OnWebSocketMessage(connection->id_, request.data); - connection->Shift(pos); + if (result == WebSocket::FRAME_ERROR) { + Close(connection->id()); + break; + } + delegate_->OnWebSocketMessage(connection->id(), message); continue; } - std::string connection_header = GetHeaderValue(request, "Connection"); + HttpServerRequestInfo request; + size_t pos = 0; + if (!ParseHeaders(connection, &request, &pos)) + break; + + std::string connection_header = request.GetHeaderValue("Connection"); if (connection_header == "Upgrade") { - // Is this WebSocket and if yes, upgrade the connection. - std::string key1 = GetHeaderValue(request, "Sec-WebSocket-Key1"); - std::string key2 = GetHeaderValue(request, "Sec-WebSocket-Key2"); - - const int websocket_handshake_body_len = 8; - if (pos + websocket_handshake_body_len > - static_cast<int>(connection->recv_data_.length())) { - // We haven't received websocket handshake body yet. Wait. - break; - } + connection->web_socket_.reset(WebSocket::CreateWebSocket(connection, + request, + &pos)); - if (!key1.empty() && !key2.empty()) { - request.data = connection->recv_data_.substr( - pos, - pos + websocket_handshake_body_len); - pos += websocket_handshake_body_len; - delegate_->OnWebSocketRequest(connection->id_, request); - connection->Shift(pos); - continue; - } + if (!connection->web_socket_.get()) // Not enought data was received. + break; + delegate_->OnWebSocketRequest(connection->id(), request); + connection->Shift(pos); + continue; } // Request body is not supported. It is always empty. - delegate_->OnHttpRequest(connection->id_, request); + delegate_->OnHttpRequest(connection->id(), request); connection->Shift(pos); } } void HttpServer::DidClose(ListenSocket* socket) { - Connection* connection = FindConnection(socket); + HttpConnection* connection = FindConnection(socket); DCHECK(connection != NULL); - id_to_connection_.erase(connection->id_); + id_to_connection_.erase(connection->id()); socket_to_connection_.erase(connection->socket_); delete connection; } -HttpServer::Connection* HttpServer::FindConnection(int connection_id) { +HttpConnection* HttpServer::FindConnection(int connection_id) { IdToConnectionMap::iterator it = id_to_connection_.find(connection_id); if (it == id_to_connection_.end()) return NULL; return it->second; } -HttpServer::Connection* HttpServer::FindConnection(ListenSocket* socket) { +HttpConnection* HttpServer::FindConnection(ListenSocket* socket) { SocketToConnectionMap::iterator it = socket_to_connection_.find(socket); if (it == socket_to_connection_.end()) return NULL; return it->second; } +void HttpServer::AcceptWebSocket( + int connection_id, + const HttpServerRequestInfo& request) { + HttpConnection* connection = FindConnection(connection_id); + if (connection == NULL) + return; + + DCHECK(connection->web_socket_.get()); + connection->web_socket_->Accept(request); +} + +void HttpServer::SendOverWebSocket(int connection_id, + const std::string& data) { + HttpConnection* connection = FindConnection(connection_id); + if (connection == NULL) + return; + DCHECK(connection->web_socket_.get()); + connection->web_socket_->Send(data); +} + } // namespace net diff --git a/net/server/http_server.h b/net/server/http_server.h index 78f6e42..1bc24a4 100644 --- a/net/server/http_server.h +++ b/net/server/http_server.h @@ -15,7 +15,9 @@ namespace net { +class HttpConnection; class HttpServerRequestInfo; +class WebSocket; class HttpServer : public ListenSocket::ListenSocketDelegate, public base::RefCountedThreadSafe<HttpServer> { @@ -53,28 +55,7 @@ class HttpServer : public ListenSocket::ListenSocketDelegate, private: friend class base::RefCountedThreadSafe<HttpServer>; - class Connection { - private: - static int lastId_; - friend class HttpServer; - - Connection(HttpServer* server, ListenSocket* sock); - ~Connection(); - - void DetachSocket(); - - void Shift(int num_bytes); - - HttpServer* server_; - scoped_refptr<ListenSocket> socket_; - bool is_web_socket_; - std::string recv_data_; - int id_; - - DISALLOW_COPY_AND_ASSIGN(Connection); - }; - friend class Connection; - + friend class HttpConnection; // ListenSocketDelegate virtual void DidAccept(ListenSocket* server, ListenSocket* socket); @@ -84,18 +65,18 @@ private: // Expects the raw data to be stored in recv_data_. If parsing is successful, // will remove the data parsed from recv_data_, leaving only the unused // recv data. - bool ParseHeaders(Connection* connection, + bool ParseHeaders(HttpConnection* connection, HttpServerRequestInfo* info, - int* ppos); + size_t* pos); - Connection* FindConnection(int connection_id); - Connection* FindConnection(ListenSocket* socket); + HttpConnection* FindConnection(int connection_id); + HttpConnection* FindConnection(ListenSocket* socket); HttpServer::Delegate* delegate_; scoped_refptr<ListenSocket> server_; - typedef std::map<int, Connection*> IdToConnectionMap; + typedef std::map<int, HttpConnection*> IdToConnectionMap; IdToConnectionMap id_to_connection_; - typedef std::map<ListenSocket*, Connection*> SocketToConnectionMap; + typedef std::map<ListenSocket*, HttpConnection*> SocketToConnectionMap; SocketToConnectionMap socket_to_connection_; DISALLOW_COPY_AND_ASSIGN(HttpServer); diff --git a/net/server/http_server_request_info.cc b/net/server/http_server_request_info.cc index e53a2e2..eda597a 100644 --- a/net/server/http_server_request_info.cc +++ b/net/server/http_server_request_info.cc @@ -10,4 +10,13 @@ HttpServerRequestInfo::HttpServerRequestInfo() {} HttpServerRequestInfo::~HttpServerRequestInfo() {} +std::string HttpServerRequestInfo::GetHeaderValue( + const std::string& header_name) const { + HttpServerRequestInfo::HeadersMap::const_iterator it = + headers.find(header_name); + if (it != headers.end()) + return it->second; + return ""; +} + } // namespace net diff --git a/net/server/http_server_request_info.h b/net/server/http_server_request_info.h index 21f319a..4c08c1f 100644 --- a/net/server/http_server_request_info.h +++ b/net/server/http_server_request_info.h @@ -20,6 +20,9 @@ class HttpServerRequestInfo { HttpServerRequestInfo(); ~HttpServerRequestInfo(); + // Returns header value for given header name. + std::string GetHeaderValue(const std::string& header_name) const; + // Request method. std::string method; diff --git a/net/server/web_socket.cc b/net/server/web_socket.cc new file mode 100644 index 0000000..de5e33a --- /dev/null +++ b/net/server/web_socket.cc @@ -0,0 +1,365 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/server/web_socket.h" + +#include "base/base64.h" +#include "base/rand_util.h" +#include "base/logging.h" +#include "base/md5.h" +#include "base/sha1.h" +#include "base/string_number_conversions.h" +#include "base/stringprintf.h" +#include "net/server/http_connection.h" +#include "net/server/http_server_request_info.h" + +#if defined(OS_WIN) +#include <winsock2.h> +#else +#include <arpa/inet.h> +#endif + +#include <limits> + +namespace net { + +namespace { + +static uint32 WebSocketKeyFingerprint(const std::string& str) { + std::string result; + const char* p_char = str.c_str(); + int length = str.length(); + int spaces = 0; + for (int i = 0; i < length; ++i) { + if (p_char[i] >= '0' && p_char[i] <= '9') + result.append(&p_char[i], 1); + else if (p_char[i] == ' ') + spaces++; + } + if (spaces == 0) + return 0; + int64 number = 0; + if (!base::StringToInt64(result, &number)) + return 0; + return htonl(static_cast<uint32>(number / spaces)); +} + +class WebSocketHixie76 : public net::WebSocket { + public: + static net::WebSocket* Create(HttpConnection* connection, + const HttpServerRequestInfo& request, + size_t* pos) { + if (connection->recv_data().length() < *pos + kWebSocketHandshakeBodyLen) + return NULL; + return new WebSocketHixie76(connection, request, pos); + } + + virtual void Accept(const HttpServerRequestInfo& request) { + std::string key1 = request.GetHeaderValue("Sec-WebSocket-Key1"); + std::string key2 = request.GetHeaderValue("Sec-WebSocket-Key2"); + + uint32 fp1 = WebSocketKeyFingerprint(key1); + uint32 fp2 = WebSocketKeyFingerprint(key2); + + char data[16]; + memcpy(data, &fp1, 4); + memcpy(data + 4, &fp2, 4); + memcpy(data + 8, &key3_[0], 8); + + base::MD5Digest digest; + base::MD5Sum(data, 16, &digest); + + std::string origin = request.GetHeaderValue("Origin"); + std::string host = request.GetHeaderValue("Host"); + std::string location = "ws://" + host + request.path; + connection_->Send(base::StringPrintf( + "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Origin: %s\r\n" + "Sec-WebSocket-Location: %s\r\n" + "\r\n", + origin.c_str(), + location.c_str())); + connection_->Send(reinterpret_cast<char*>(digest.a), 16); + } + + virtual ParseResult Read(std::string* message) { + DCHECK(message); + const std::string& data = connection_->recv_data(); + if (data[0]) + return FRAME_ERROR; + + size_t pos = data.find('\377', 1); + if (pos == std::string::npos) + return FRAME_INCOMPLETE; + + std::string buffer(data.begin() + 1, data.begin() + pos); + message->swap(buffer); + connection_->Shift(pos + 1); + + return FRAME_OK; + } + + virtual void Send(const std::string& message) { + char message_start = 0; + char message_end = -1; + connection_->Send(&message_start, 1); + connection_->Send(message); + connection_->Send(&message_end, 1); + } + + private: + static const int kWebSocketHandshakeBodyLen; + + WebSocketHixie76(HttpConnection* connection, + const HttpServerRequestInfo& request, + size_t* pos) : WebSocket(connection) { + std::string key1 = request.GetHeaderValue("Sec-WebSocket-Key1"); + std::string key2 = request.GetHeaderValue("Sec-WebSocket-Key2"); + + if (key1.empty()) { + connection->Send500("Invalid request format. " + "Sec-WebSocket-Key1 is empty or isn't specified."); + return; + } + + if (key2.empty()) { + connection->Send500("Invalid request format. " + "Sec-WebSocket-Key2 is empty or isn't specified."); + return; + } + + key3_ = connection->recv_data().substr( + *pos, + *pos + kWebSocketHandshakeBodyLen); + *pos += kWebSocketHandshakeBodyLen; + } + + std::string key3_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketHixie76); +}; + +const int WebSocketHixie76::kWebSocketHandshakeBodyLen = 8; + + +// Constants for hybi-10 frame format. + +typedef int OpCode; + +const OpCode kOpCodeContinuation = 0x0; +const OpCode kOpCodeText = 0x1; +const OpCode kOpCodeBinary = 0x2; +const OpCode kOpCodeClose = 0x8; +const OpCode kOpCodePing = 0x9; +const OpCode kOpCodePong = 0xA; + +const unsigned char kFinalBit = 0x80; +const unsigned char kReserved1Bit = 0x40; +const unsigned char kReserved2Bit = 0x20; +const unsigned char kReserved3Bit = 0x10; +const unsigned char kOpCodeMask = 0xF; +const unsigned char kMaskBit = 0x80; +const unsigned char kPayloadLengthMask = 0x7F; + +const size_t kMaxSingleBytePayloadLength = 125; +const size_t kTwoBytePayloadLengthField = 126; +const size_t kEightBytePayloadLengthField = 127; +const size_t kMaskingKeyWidthInBytes = 4; + +class WebSocketHybi10 : public WebSocket { + public: + static WebSocket* Create(HttpConnection* connection, + const HttpServerRequestInfo& request, + size_t* pos) { + std::string version = request.GetHeaderValue("Sec-WebSocket-Version"); + if (version != "8") + return NULL; + + std::string key = request.GetHeaderValue("Sec-WebSocket-Key"); + if (key.empty()) { + connection->Send500("Invalid request format. " + "Sec-WebSocket-Key is empty or isn't specified."); + return NULL; + } + return new WebSocketHybi10(connection, request, pos); + } + + virtual void Accept(const HttpServerRequestInfo& request) { + static const char* const kWebSocketGuid = + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + std::string key = request.GetHeaderValue("Sec-WebSocket-Key"); + std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); + std::string encoded_hash; + base::Base64Encode(base::SHA1HashString(data), &encoded_hash); + + std::string response = base::StringPrintf( + "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: %s\r\n" + "\r\n", + encoded_hash.c_str()); + connection_->Send(response); + } + + virtual ParseResult Read(std::string* message) { + size_t data_length = connection_->recv_data().length(); + if (data_length < 2) + return FRAME_INCOMPLETE; + + const char* p = connection_->recv_data().c_str(); + const char* buffer_end = p + data_length; + + unsigned char first_byte = *p++; + unsigned char second_byte = *p++; + + final_ = first_byte & kFinalBit; + reserved1_ = first_byte & kReserved1Bit; + reserved2_ = first_byte & kReserved2Bit; + reserved3_ = first_byte & kReserved3Bit; + op_code_ = first_byte & kOpCodeMask; + masked_ = second_byte & kMaskBit; + + CHECK_EQ(kOpCodeText, op_code_); + + uint64 payload_length64 = second_byte & kPayloadLengthMask; + if (payload_length64 > kMaxSingleBytePayloadLength) { + int extended_payload_length_size; + if (payload_length64 == kTwoBytePayloadLengthField) + extended_payload_length_size = 2; + else { + DCHECK(payload_length64 == kEightBytePayloadLengthField); + extended_payload_length_size = 8; + } + if (buffer_end - p < extended_payload_length_size) + return FRAME_INCOMPLETE; + payload_length64 = 0; + for (int i = 0; i < extended_payload_length_size; ++i) { + payload_length64 <<= 8; + payload_length64 |= static_cast<unsigned char>(*p++); + } + } + + static const uint64 max_payload_length = 0x7FFFFFFFFFFFFFFFull; + size_t masking_key_length = masked_ ? kMaskingKeyWidthInBytes : 0; + static size_t max_length = std::numeric_limits<size_t>::max(); + if (payload_length64 > max_payload_length || + payload_length64 + masking_key_length > max_length) { + // WebSocket frame length too large. + return FRAME_ERROR; + } + payload_length_ = static_cast<size_t>(payload_length64); + + size_t total_length = masking_key_length + payload_length_; + if (static_cast<size_t>(buffer_end - p) < total_length) + return FRAME_INCOMPLETE; + + if (masked_) { + message->resize(payload_length_); + const char* masking_key = p; + char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes); + for (size_t i = 0; i < payload_length_; ++i) // Unmask the payload. + (*message)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes]; + } else { + std::string buffer(p, p + payload_length_); + message->swap(buffer); + } + + size_t pos = p + masking_key_length + payload_length_ - + connection_->recv_data().c_str(); + connection_->Shift(pos); + return FRAME_OK; + } + + virtual void Send(const std::string& message) { + std::vector<char> frame; + OpCode op_code = kOpCodeText; + size_t data_length = message.length(); + + frame.push_back(kFinalBit | op_code); + if (data_length <= kMaxSingleBytePayloadLength) + frame.push_back(kMaskBit | data_length); + else if (data_length <= 0xFFFF) { + frame.push_back(kMaskBit | kTwoBytePayloadLengthField); + frame.push_back((data_length & 0xFF00) >> 8); + frame.push_back(data_length & 0xFF); + } else { + frame.push_back(kMaskBit | kEightBytePayloadLengthField); + char extended_payload_length[8]; + size_t remaining = data_length; + // Fill the length into extended_payload_length in the network byte order. + for (int i = 0; i < 8; ++i) { + extended_payload_length[7 - i] = remaining & 0xFF; + remaining >>= 8; + } + frame.insert(frame.end(), + extended_payload_length, + extended_payload_length + 8); + DCHECK(!remaining); + } + + // Mask the frame. + size_t masking_key_start = frame.size(); + // Add placeholder for masking key. Will be overwritten. + frame.resize(frame.size() + kMaskingKeyWidthInBytes); + size_t payload_start = frame.size(); + const char* data = message.c_str(); + frame.insert(frame.end(), data, data + data_length); + + base::RandBytes(&frame[0] + masking_key_start, + kMaskingKeyWidthInBytes); + for (size_t i = 0; i < data_length; ++i) { + frame[payload_start + i] ^= + frame[masking_key_start + i % kMaskingKeyWidthInBytes]; + } + connection_->Send(&frame[0], frame.size()); + } + + private: + WebSocketHybi10(HttpConnection* connection, + const HttpServerRequestInfo& request, + size_t* pos) + : WebSocket(connection), + op_code_(0), + final_(false), + reserved1_(false), + reserved2_(false), + reserved3_(false), + masked_(false), + payload_(0), + payload_length_(0), + frame_end_(0) { + } + + OpCode op_code_; + bool final_; + bool reserved1_; + bool reserved2_; + bool reserved3_; + bool masked_; + const char* payload_; + size_t payload_length_; + const char* frame_end_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketHybi10); +}; + +} // anonymous namespace + +WebSocket* WebSocket::CreateWebSocket(HttpConnection* connection, + const HttpServerRequestInfo& request, + size_t* pos) { + WebSocket* socket = WebSocketHybi10::Create(connection, request, pos); + if (socket) + return socket; + + return WebSocketHixie76::Create(connection, request, pos); +} + +WebSocket::WebSocket(HttpConnection* connection) : connection_(connection) { +} + +} // namespace net diff --git a/net/server/web_socket.h b/net/server/web_socket.h new file mode 100644 index 0000000..baed07c --- /dev/null +++ b/net/server/web_socket.h @@ -0,0 +1,46 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SERVER_WEB_SOCKET_H_ +#define NET_SERVER_WEB_SOCKET_H_ +#pragma once + +#include <string> + +#include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" + +namespace net { + +class HttpConnection; +class HttpServerRequestInfo; + +class WebSocket { + public: + enum ParseResult { + FRAME_OK, + FRAME_INCOMPLETE, + FRAME_ERROR + }; + + static WebSocket* CreateWebSocket(HttpConnection* connection, + const HttpServerRequestInfo& request, + size_t* pos); + + virtual void Accept(const HttpServerRequestInfo& request) = 0; + virtual ParseResult Read(std::string* message) = 0; + virtual void Send(const std::string& message) = 0; + virtual ~WebSocket() {} + + protected: + explicit WebSocket(HttpConnection* connection); + HttpConnection* connection_; + + private: + DISALLOW_COPY_AND_ASSIGN(WebSocket); +}; + +} // namespace net + +#endif // NET_SERVER_WEB_SOCKET_H_ |