summaryrefslogtreecommitdiffstats
path: root/net/server
diff options
context:
space:
mode:
authoryhirano <yhirano@chromium.org>2015-09-28 02:06:34 -0700
committerCommit bot <commit-bot@chromium.org>2015-09-28 09:07:37 +0000
commita10dd4ef5e42a54bea6ce71ef3f3d9f974dbb37e (patch)
tree32c1e7e310b7564625f1966cc8d40564f9b4bca0 /net/server
parent59fb12ff0a75637b5c785b420653cf1289b7c29c (diff)
downloadchromium_src-a10dd4ef5e42a54bea6ce71ef3f3d9f974dbb37e.zip
chromium_src-a10dd4ef5e42a54bea6ce71ef3f3d9f974dbb37e.tar.gz
chromium_src-a10dd4ef5e42a54bea6ce71ef3f3d9f974dbb37e.tar.bz2
Fix WebSocketServer extension parser.
This CL makes the WebSocket server in net/server use the net/websockets parser for parsing Sec-WebSocket-Extensions in the extension negotiation. The new implementation validates the extension negotiation offer more strictly than before. Specifically, - Malformed Sec-WebSocket-Extensions header value causes connection failure. - Previously it was just ignored. - Malformed permessage-deflate parameters are declined. - Previously part of such params were accepted partially. BUG=523228 Review URL: https://codereview.chromium.org/1340523002 Cr-Commit-Position: refs/heads/master@{#351040}
Diffstat (limited to 'net/server')
-rw-r--r--net/server/http_server.cc7
-rw-r--r--net/server/web_socket.cc124
-rw-r--r--net/server/web_socket.h11
-rw-r--r--net/server/web_socket_encoder.cc208
-rw-r--r--net/server/web_socket_encoder.h61
-rw-r--r--net/server/web_socket_encoder_unittest.cc104
6 files changed, 275 insertions, 240 deletions
diff --git a/net/server/http_server.cc b/net/server/http_server.cc
index 3abd44d..f3560e8 100644
--- a/net/server/http_server.cc
+++ b/net/server/http_server.cc
@@ -236,11 +236,8 @@ int HttpServer::HandleReadResult(HttpConnection* connection, int rv) {
connection->socket()->GetPeerAddress(&request.peer);
if (request.HasHeaderValue("connection", "upgrade")) {
- scoped_ptr<WebSocket> websocket =
- WebSocket::CreateWebSocket(this, connection, request);
- if (!websocket) // Not enough data was received.
- break;
- connection->SetWebSocket(websocket.Pass());
+ connection->SetWebSocket(
+ make_scoped_ptr(new WebSocket(this, connection)));
read_buf->DidConsume(pos);
delegate_->OnWebSocketRequest(connection->id(), request);
if (HasClosedConnection(connection))
diff --git a/net/server/web_socket.cc b/net/server/web_socket.cc
index c963745..79ffcec 100644
--- a/net/server/web_socket.cc
+++ b/net/server/web_socket.cc
@@ -4,6 +4,8 @@
#include "net/server/web_socket.h"
+#include <vector>
+
#include "base/base64.h"
#include "base/logging.h"
#include "base/sha1.h"
@@ -15,70 +17,87 @@
#include "net/server/http_server_request_info.h"
#include "net/server/http_server_response_info.h"
#include "net/server/web_socket_encoder.h"
+#include "net/websockets/websocket_deflate_parameters.h"
+#include "net/websockets/websocket_extension.h"
+#include "net/websockets/websocket_handshake_constants.h"
namespace net {
-WebSocket::WebSocket(HttpServer* server,
- HttpConnection* connection,
- const HttpServerRequestInfo& request)
- : server_(server), connection_(connection), closed_(false) {
- std::string request_extensions =
- request.GetHeaderValue("sec-websocket-extensions");
- encoder_.reset(WebSocketEncoder::CreateServer(request_extensions,
- &response_extensions_));
- if (!response_extensions_.empty()) {
- response_extensions_ =
- "Sec-WebSocket-Extensions: " + response_extensions_ + "\r\n";
- }
+namespace {
+
+std::string ExtensionsHeaderString(
+ const std::vector<WebSocketExtension>& extensions) {
+ if (extensions.empty())
+ return std::string();
+
+ std::string result = "Sec-WebSocket-Extensions: " + extensions[0].ToString();
+ for (size_t i = 1; i < extensions.size(); ++i)
+ result += ", " + extensions[i].ToString();
+ return result + "\r\n";
}
+std::string ValidResponseString(
+ const std::string& accept_hash,
+ const std::vector<WebSocketExtension> extensions) {
+ return base::StringPrintf(
+ "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Accept: %s\r\n"
+ "%s"
+ "\r\n",
+ accept_hash.c_str(), ExtensionsHeaderString(extensions).c_str());
+}
+
+} // namespace
+
+WebSocket::WebSocket(HttpServer* server, HttpConnection* connection)
+ : server_(server), connection_(connection), closed_(false) {}
+
WebSocket::~WebSocket() {}
-scoped_ptr<WebSocket> WebSocket::CreateWebSocket(
- HttpServer* server,
- HttpConnection* connection,
- const HttpServerRequestInfo& request) {
+void WebSocket::Accept(const HttpServerRequestInfo& request) {
std::string version = request.GetHeaderValue("sec-websocket-version");
if (version != "8" && version != "13") {
- server->SendResponse(
- connection->id(),
- HttpServerResponseInfo::CreateFor500(
- "Invalid request format. The version is not valid."));
- return nullptr;
+ SendErrorResponse("Invalid request format. The version is not valid.");
+ return;
}
std::string key = request.GetHeaderValue("sec-websocket-key");
if (key.empty()) {
- server->SendResponse(
- connection->id(),
- HttpServerResponseInfo::CreateFor500(
- "Invalid request format. Sec-WebSocket-Key is empty or isn't "
- "specified."));
- return nullptr;
+ SendErrorResponse(
+ "Invalid request format. Sec-WebSocket-Key is empty or isn't "
+ "specified.");
+ return;
}
- return make_scoped_ptr(new WebSocket(server, connection, request));
-}
-
-void WebSocket::Accept(const HttpServerRequestInfo& request) {
- static const char* const kWebSocketGuid =
- "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
- std::string key = request.GetHeaderValue("sec-websocket-key");
- std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid);
std::string encoded_hash;
- base::Base64Encode(base::SHA1HashString(data), &encoded_hash);
-
- server_->SendRaw(
- connection_->id(),
- base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
- "Upgrade: WebSocket\r\n"
- "Connection: Upgrade\r\n"
- "Sec-WebSocket-Accept: %s\r\n"
- "%s"
- "\r\n",
- encoded_hash.c_str(), response_extensions_.c_str()));
+ base::Base64Encode(base::SHA1HashString(key + websockets::kWebSocketGuid),
+ &encoded_hash);
+
+ std::vector<WebSocketExtension> response_extensions;
+ auto i = request.headers.find("sec-websocket-extensions");
+ if (i == request.headers.end()) {
+ encoder_ = WebSocketEncoder::CreateServer();
+ } else {
+ WebSocketDeflateParameters params;
+ encoder_ = WebSocketEncoder::CreateServer(i->second, &params);
+ if (!encoder_) {
+ Fail();
+ return;
+ }
+ if (encoder_->deflate_enabled()) {
+ DCHECK(params.IsValidAsResponse());
+ response_extensions.push_back(params.AsExtension());
+ }
+ }
+ server_->SendRaw(connection_->id(),
+ ValidResponseString(encoded_hash, response_extensions));
}
WebSocket::ParseResult WebSocket::Read(std::string* message) {
+ if (closed_)
+ return FRAME_CLOSE;
+
HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf();
base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize());
int bytes_consumed = 0;
@@ -98,4 +117,17 @@ void WebSocket::Send(const std::string& message) {
server_->SendRaw(connection_->id(), encoded);
}
+void WebSocket::Fail() {
+ closed_ = true;
+ // TODO(yhirano): The server SHOULD log the problem.
+ server_->Close(connection_->id());
+}
+
+void WebSocket::SendErrorResponse(const std::string& message) {
+ if (closed_)
+ return;
+ closed_ = true;
+ server_->Send500(connection_->id(), message);
+}
+
} // namespace net
diff --git a/net/server/web_socket.h b/net/server/web_socket.h
index d9509d8..5309544 100644
--- a/net/server/web_socket.h
+++ b/net/server/web_socket.h
@@ -27,10 +27,7 @@ class WebSocket final {
FRAME_ERROR
};
- static scoped_ptr<WebSocket> CreateWebSocket(
- HttpServer* server,
- HttpConnection* connection,
- const HttpServerRequestInfo& request);
+ WebSocket(HttpServer* server, HttpConnection* connection);
void Accept(const HttpServerRequestInfo& request);
ParseResult Read(std::string* message);
@@ -38,14 +35,12 @@ class WebSocket final {
~WebSocket();
private:
- WebSocket(HttpServer* server,
- HttpConnection* connection,
- const HttpServerRequestInfo& request);
+ void Fail();
+ void SendErrorResponse(const std::string& message);
HttpServer* const server_;
HttpConnection* const connection_;
scoped_ptr<WebSocketEncoder> encoder_;
- std::string response_extensions_;
bool closed_;
DISALLOW_COPY_AND_ASSIGN(WebSocket);
diff --git a/net/server/web_socket_encoder.cc b/net/server/web_socket_encoder.cc
index 1a5431a..b1b93ee 100644
--- a/net/server/web_socket_encoder.cc
+++ b/net/server/web_socket_encoder.cc
@@ -4,10 +4,14 @@
#include "net/server/web_socket_encoder.h"
+#include <vector>
+
#include "base/logging.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "net/base/io_buffer.h"
+#include "net/websockets/websocket_deflate_parameters.h"
+#include "net/websockets/websocket_extension.h"
#include "net/websockets/websocket_extension_parser.h"
namespace net {
@@ -180,151 +184,111 @@ void EncodeFrameHybi17(const std::string& message,
} // anonymous namespace
// static
-WebSocketEncoder* WebSocketEncoder::CreateServer(
- const std::string& request_extensions,
- std::string* response_extensions) {
- bool deflate;
- bool has_client_window_bits;
- int client_window_bits;
- int server_window_bits;
- bool client_no_context_takeover;
- bool server_no_context_takeover;
- ParseExtensions(request_extensions, &deflate, &has_client_window_bits,
- &client_window_bits, &server_window_bits,
- &client_no_context_takeover, &server_no_context_takeover);
-
- if (deflate) {
- *response_extensions = base::StringPrintf(
- "permessage-deflate; server_max_window_bits=%d%s", server_window_bits,
- server_no_context_takeover ? "; server_no_context_takeover" : "");
- if (has_client_window_bits) {
- base::StringAppendF(response_extensions, "; client_max_window_bits=%d",
- client_window_bits);
- } else {
- DCHECK_EQ(client_window_bits, 15);
- }
- return new WebSocketEncoder(true /* is_server */, server_window_bits,
- client_window_bits, server_no_context_takeover);
- } else {
- *response_extensions = std::string();
- return new WebSocketEncoder(true /* is_server */);
- }
+scoped_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer() {
+ return make_scoped_ptr(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
}
// static
-WebSocketEncoder* WebSocketEncoder::CreateClient(
- const std::string& response_extensions) {
- bool deflate;
- bool has_client_window_bits;
- int client_window_bits;
- int server_window_bits;
- bool client_no_context_takeover;
- bool server_no_context_takeover;
- ParseExtensions(response_extensions, &deflate, &has_client_window_bits,
- &client_window_bits, &server_window_bits,
- &client_no_context_takeover, &server_no_context_takeover);
-
- if (deflate) {
- return new WebSocketEncoder(false /* is_server */, client_window_bits,
- server_window_bits, client_no_context_takeover);
- } else {
- return new WebSocketEncoder(false /* is_server */);
+scoped_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer(
+ const std::string& extensions,
+ WebSocketDeflateParameters* deflate_parameters) {
+ WebSocketExtensionParser parser;
+ if (!parser.Parse(extensions)) {
+ // Failed to parse Sec-WebSocket-Extensions header. We MUST fail the
+ // connection.
+ return nullptr;
}
-}
-
-// static
-void WebSocketEncoder::ParseExtensions(const std::string& header_value,
- bool* deflate,
- bool* has_client_window_bits,
- int* client_window_bits,
- int* server_window_bits,
- bool* client_no_context_takeover,
- bool* server_no_context_takeover) {
- *deflate = false;
- *has_client_window_bits = false;
- *client_window_bits = 15;
- *server_window_bits = 15;
- *client_no_context_takeover = false;
- *server_no_context_takeover = false;
-
- if (header_value.empty())
- return;
- WebSocketExtensionParser parser;
- if (!parser.Parse(header_value))
- return;
- const std::vector<WebSocketExtension>& extensions = parser.extensions();
- // TODO(tyoshino): Fail if this method is used for parsing a response and
- // there are multiple permessage-deflate extensions or there are any unknown
- // extensions.
- for (const auto& extension : extensions) {
- if (extension.name() != "permessage-deflate") {
+ for (const auto& extension : parser.extensions()) {
+ std::string failure_message;
+ WebSocketDeflateParameters offer;
+ if (!offer.Initialize(extension, &failure_message) ||
+ !offer.IsValidAsRequest(&failure_message)) {
+ // We decline unknown / malformed extensions.
continue;
}
- const std::vector<WebSocketExtension::Parameter>& parameters =
- extension.parameters();
- for (const auto& param : parameters) {
- const std::string& name = param.name();
- // TODO(tyoshino): Fail the connection when an invalid value is given.
- if (name == "client_max_window_bits") {
- *has_client_window_bits = true;
- if (param.HasValue()) {
- int bits = 0;
- if (base::StringToInt(param.value(), &bits) && bits >= 8 &&
- bits <= 15) {
- *client_window_bits = bits;
- }
- }
- }
- if (name == "server_max_window_bits" && param.HasValue()) {
- int bits = 0;
- if (base::StringToInt(param.value(), &bits) && bits >= 8 && bits <= 15)
- *server_window_bits = bits;
- }
- if (name == "client_no_context_takeover")
- *client_no_context_takeover = true;
- if (name == "server_no_context_takeover")
- *server_no_context_takeover = true;
+ WebSocketDeflateParameters response = offer;
+ if (offer.is_client_max_window_bits_specified() &&
+ !offer.has_client_max_window_bits_value()) {
+ // We need to choose one value for the response.
+ response.SetClientMaxWindowBits(15);
}
- *deflate = true;
-
- break;
+ DCHECK(response.IsValidAsResponse());
+ DCHECK(offer.IsCompatibleWith(response));
+ auto deflater = make_scoped_ptr(
+ new WebSocketDeflater(response.server_context_take_over_mode()));
+ auto inflater = make_scoped_ptr(
+ new WebSocketInflater(kInflaterChunkSize, kInflaterChunkSize));
+ if (!deflater->Initialize(response.PermissiveServerMaxWindowBits()) ||
+ !inflater->Initialize(response.PermissiveClientMaxWindowBits())) {
+ // For some reason we cannot accept the parameters.
+ continue;
+ }
+ *deflate_parameters = response;
+ return make_scoped_ptr(
+ new WebSocketEncoder(FOR_SERVER, deflater.Pass(), inflater.Pass()));
}
-}
-WebSocketEncoder::WebSocketEncoder(bool is_server) : is_server_(is_server) {
+ // We cannot find an acceptable offer.
+ return make_scoped_ptr(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
}
-WebSocketEncoder::WebSocketEncoder(bool is_server,
- int deflate_bits,
- int inflate_bits,
- bool no_context_takeover)
- : is_server_(is_server) {
- deflater_.reset(new WebSocketDeflater(
- no_context_takeover ? WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT
- : WebSocketDeflater::TAKE_OVER_CONTEXT));
- inflater_.reset(
- new WebSocketInflater(kInflaterChunkSize, kInflaterChunkSize));
+// static
+WebSocketEncoder* WebSocketEncoder::CreateClient(
+ const std::string& response_extensions) {
+ // TODO(yhirano): Add a way to return an error.
- if (!deflater_->Initialize(deflate_bits) ||
- !inflater_->Initialize(inflate_bits)) {
- // Disable deflate support.
- deflater_.reset();
- inflater_.reset();
+ WebSocketExtensionParser parser;
+ if (!parser.Parse(response_extensions)) {
+ // Parse error. Note that there are two cases here.
+ // 1) There is no Sec-WebSocket-Extensions header.
+ // 2) There is a malformed Sec-WebSocketExtensions header.
+ // We should return a deflate-disabled encoder for the former case and
+ // fail the connection for the latter case.
+ return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr);
+ }
+ if (parser.extensions().size() != 1) {
+ // Only permessage-deflate extension is supported.
+ // TODO (yhirano): Fail the connection.
+ return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr);
+ }
+ const auto& extension = parser.extensions()[0];
+ WebSocketDeflateParameters params;
+ std::string failure_message;
+ if (!params.Initialize(extension, &failure_message) ||
+ !params.IsValidAsResponse(&failure_message)) {
+ // TODO (yhirano): Fail the connection.
+ return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr);
}
-}
-WebSocketEncoder::~WebSocketEncoder() {
+ auto deflater = make_scoped_ptr(
+ new WebSocketDeflater(params.client_context_take_over_mode()));
+ auto inflater = make_scoped_ptr(
+ new WebSocketInflater(kInflaterChunkSize, kInflaterChunkSize));
+ if (!deflater->Initialize(params.PermissiveClientMaxWindowBits()) ||
+ !inflater->Initialize(params.PermissiveServerMaxWindowBits())) {
+ // TODO (yhirano): Fail the connection.
+ return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr);
+ }
+
+ return new WebSocketEncoder(FOR_CLIENT, deflater.Pass(), inflater.Pass());
}
+WebSocketEncoder::WebSocketEncoder(Type type,
+ scoped_ptr<WebSocketDeflater> deflater,
+ scoped_ptr<WebSocketInflater> inflater)
+ : type_(type), deflater_(deflater.Pass()), inflater_(inflater.Pass()) {}
+
+WebSocketEncoder::~WebSocketEncoder() {}
+
WebSocket::ParseResult WebSocketEncoder::DecodeFrame(
const base::StringPiece& frame,
int* bytes_consumed,
std::string* output) {
bool compressed;
- WebSocket::ParseResult result =
- DecodeFrameHybi17(frame, is_server_, bytes_consumed, output, &compressed);
+ WebSocket::ParseResult result = DecodeFrameHybi17(
+ frame, type_ == FOR_SERVER, bytes_consumed, output, &compressed);
if (result == WebSocket::FRAME_OK && compressed) {
if (!Inflate(output))
result = WebSocket::FRAME_ERROR;
diff --git a/net/server/web_socket_encoder.h b/net/server/web_socket_encoder.h
index 23f0d9c..1eb749f 100644
--- a/net/server/web_socket_encoder.h
+++ b/net/server/web_socket_encoder.h
@@ -16,61 +16,50 @@
namespace net {
-class WebSocketEncoder {
+class WebSocketDeflateParameters;
+
+class WebSocketEncoder final {
public:
- ~WebSocketEncoder();
+ static const char kClientExtensions[];
- static WebSocketEncoder* CreateServer(const std::string& request_extensions,
- std::string* response_extensions);
+ ~WebSocketEncoder();
- static const char kClientExtensions[];
+ // Creates and returns an encoder for a server without extensions.
+ static scoped_ptr<WebSocketEncoder> CreateServer();
+ // Creates and returns an encoder.
+ // |extensions| is the value of a Sec-WebSocket-Extensions header.
+ // Returns nullptr when there is an error.
+ static scoped_ptr<WebSocketEncoder> CreateServer(
+ const std::string& extensions,
+ WebSocketDeflateParameters* params);
+ // TODO(yhirano): Return a scoped_ptr instead of a raw pointer.
static WebSocketEncoder* CreateClient(const std::string& response_extensions);
WebSocket::ParseResult DecodeFrame(const base::StringPiece& frame,
int* bytes_consumed,
std::string* output);
-
void EncodeFrame(const std::string& frame,
int masking_key,
std::string* output);
+ bool deflate_enabled() const { return deflater_; }
+
private:
- explicit WebSocketEncoder(bool is_server);
- WebSocketEncoder(bool is_server,
- int deflate_bits,
- int inflate_bits,
- bool no_context_takeover);
-
- // Parses a value in the Sec-WebSocket-Extensions header. If it contains a
- // single element of the permessage-deflate extension, stores the result of
- // parsing the parameters of the extension into the given variables.
- // Otherwise, returns with *deflate set to false.
- //
- // - If the client_max_window_bits parameter is missing, *client_window_bits
- // defaults to 15.
- // - If the client_max_window_bits parameter has an invalid value,
- // *client_window_bits will be set to 0.
- // - If the server_max_window_bits parameter is missing, *server_window_bits
- // defaults to 15.
- // - If the server_max_window_bits parameter has an invalid value,
- // *client_window_bits will be set to 0.
- //
- // TODO(tyoshino): Consider using a struct than taking a lot of pointers for
- // output.
- static void ParseExtensions(const std::string& header_value,
- bool* deflate,
- bool* has_client_window_bits,
- int* client_window_bits,
- int* server_window_bits,
- bool* client_no_context_takeover,
- bool* server_no_context_takeover);
+ enum Type {
+ FOR_SERVER,
+ FOR_CLIENT,
+ };
+
+ WebSocketEncoder(Type type,
+ scoped_ptr<WebSocketDeflater> deflater,
+ scoped_ptr<WebSocketInflater> inflater);
bool Inflate(std::string* message);
bool Deflate(const std::string& message, std::string* output);
+ Type type_;
scoped_ptr<WebSocketDeflater> deflater_;
scoped_ptr<WebSocketInflater> inflater_;
- bool is_server_;
DISALLOW_COPY_AND_ASSIGN(WebSocketEncoder);
};
diff --git a/net/server/web_socket_encoder_unittest.cc b/net/server/web_socket_encoder_unittest.cc
index 7bca876..9991bd7 100644
--- a/net/server/web_socket_encoder_unittest.cc
+++ b/net/server/web_socket_encoder_unittest.cc
@@ -3,30 +3,76 @@
// found in the LICENSE file.
#include "net/server/web_socket_encoder.h"
+
+#include "net/websockets/websocket_deflate_parameters.h"
+#include "net/websockets/websocket_extension.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
+TEST(WebSocketEncoderHandshakeTest, EmptyRequestShouldBeRejected) {
+ WebSocketDeflateParameters params;
+ scoped_ptr<WebSocketEncoder> server =
+ WebSocketEncoder::CreateServer("", &params);
+
+ EXPECT_FALSE(server);
+}
+
TEST(WebSocketEncoderHandshakeTest,
CreateServerWithoutClientMaxWindowBitsParameter) {
- std::string response_extensions;
- scoped_ptr<WebSocketEncoder> server(WebSocketEncoder::CreateServer(
- "permessage-deflate", &response_extensions));
- // The response must not include client_max_window_bits if the client didn't
- // declare that it accepts the parameter.
- EXPECT_EQ("permessage-deflate; server_max_window_bits=15",
- response_extensions);
+ WebSocketDeflateParameters params;
+ scoped_ptr<WebSocketEncoder> server =
+ WebSocketEncoder::CreateServer("permessage-deflate", &params);
+
+ ASSERT_TRUE(server);
+ EXPECT_TRUE(server->deflate_enabled());
+ EXPECT_EQ("permessage-deflate", params.AsExtension().ToString());
}
TEST(WebSocketEncoderHandshakeTest,
CreateServerWithServerNoContextTakeoverParameter) {
- std::string response_extensions;
- scoped_ptr<WebSocketEncoder> server(WebSocketEncoder::CreateServer(
- "permessage-deflate; server_no_context_takeover", &response_extensions));
- EXPECT_EQ(
- "permessage-deflate; server_max_window_bits=15"
- "; server_no_context_takeover",
- response_extensions);
+ WebSocketDeflateParameters params;
+ scoped_ptr<WebSocketEncoder> server = WebSocketEncoder::CreateServer(
+ "permessage-deflate; server_no_context_takeover", &params);
+ ASSERT_TRUE(server);
+ EXPECT_TRUE(server->deflate_enabled());
+ EXPECT_EQ("permessage-deflate; server_no_context_takeover",
+ params.AsExtension().ToString());
+}
+
+TEST(WebSocketEncoderHandshakeTest, FirstExtensionShouldBeChosen) {
+ WebSocketDeflateParameters params;
+ scoped_ptr<WebSocketEncoder> server = WebSocketEncoder::CreateServer(
+ "permessage-deflate; server_no_context_takeover,"
+ "permessage-deflate; server_max_window_bits=15",
+ &params);
+
+ ASSERT_TRUE(server);
+ EXPECT_TRUE(server->deflate_enabled());
+ EXPECT_EQ("permessage-deflate; server_no_context_takeover",
+ params.AsExtension().ToString());
+}
+
+TEST(WebSocketEncoderHandshakeTest, FirstValidExtensionShouldBeChosen) {
+ WebSocketDeflateParameters params;
+ scoped_ptr<WebSocketEncoder> server = WebSocketEncoder::CreateServer(
+ "permessage-deflate; Xserver_no_context_takeover,"
+ "permessage-deflate; server_max_window_bits=15",
+ &params);
+
+ ASSERT_TRUE(server);
+ EXPECT_TRUE(server->deflate_enabled());
+ EXPECT_EQ("permessage-deflate; server_max_window_bits=15",
+ params.AsExtension().ToString());
+}
+
+TEST(WebSocketEncoderHandshakeTest, AllExtensionsAreUnknownOrMalformed) {
+ WebSocketDeflateParameters params;
+ scoped_ptr<WebSocketEncoder> server =
+ WebSocketEncoder::CreateServer("unknown, permessage-deflate; x", &params);
+
+ ASSERT_TRUE(server);
+ EXPECT_FALSE(server->deflate_enabled());
}
class WebSocketEncoderTest : public testing::Test {
@@ -35,7 +81,7 @@ class WebSocketEncoderTest : public testing::Test {
void SetUp() override {
std::string response_extensions;
- server_.reset(WebSocketEncoder::CreateServer("", &response_extensions));
+ server_ = WebSocketEncoder::CreateServer();
EXPECT_EQ(std::string(), response_extensions);
client_.reset(WebSocketEncoder::CreateClient(""));
}
@@ -50,17 +96,29 @@ class WebSocketEncoderCompressionTest : public WebSocketEncoderTest {
WebSocketEncoderCompressionTest() : WebSocketEncoderTest() {}
void SetUp() override {
- std::string response_extensions;
- server_.reset(WebSocketEncoder::CreateServer(
- "permessage-deflate; client_max_window_bits", &response_extensions));
- EXPECT_EQ(
- "permessage-deflate; server_max_window_bits=15; "
- "client_max_window_bits=15",
- response_extensions);
- client_.reset(WebSocketEncoder::CreateClient(response_extensions));
+ WebSocketDeflateParameters params;
+ server_ = WebSocketEncoder::CreateServer(
+ "permessage-deflate; client_max_window_bits", &params);
+ ASSERT_TRUE(server_);
+ EXPECT_TRUE(server_->deflate_enabled());
+ EXPECT_EQ("permessage-deflate; client_max_window_bits=15",
+ params.AsExtension().ToString());
+ client_.reset(
+ WebSocketEncoder::CreateClient(params.AsExtension().ToString()));
}
};
+TEST_F(WebSocketEncoderTest, DeflateDisabledEncoder) {
+ scoped_ptr<WebSocketEncoder> server(WebSocketEncoder::CreateServer());
+ scoped_ptr<WebSocketEncoder> client(WebSocketEncoder::CreateClient(""));
+
+ ASSERT_TRUE(server);
+ ASSERT_TRUE(client);
+
+ EXPECT_FALSE(server->deflate_enabled());
+ EXPECT_FALSE(client->deflate_enabled());
+}
+
TEST_F(WebSocketEncoderTest, ClientToServer) {
std::string frame("ClientToServer");
int mask = 123456;