diff options
author | yhirano <yhirano@chromium.org> | 2015-09-28 02:06:34 -0700 |
---|---|---|
committer | Commit bot <commit-bot@chromium.org> | 2015-09-28 09:07:37 +0000 |
commit | a10dd4ef5e42a54bea6ce71ef3f3d9f974dbb37e (patch) | |
tree | 32c1e7e310b7564625f1966cc8d40564f9b4bca0 /net/server | |
parent | 59fb12ff0a75637b5c785b420653cf1289b7c29c (diff) | |
download | chromium_src-a10dd4ef5e42a54bea6ce71ef3f3d9f974dbb37e.zip chromium_src-a10dd4ef5e42a54bea6ce71ef3f3d9f974dbb37e.tar.gz chromium_src-a10dd4ef5e42a54bea6ce71ef3f3d9f974dbb37e.tar.bz2 |
Fix WebSocketServer extension parser.
This CL makes the WebSocket server in net/server use the net/websockets
parser for parsing Sec-WebSocket-Extensions in the extension negotiation.
The new implementation validates the extension negotiation offer more
strictly than before. Specifically,
- Malformed Sec-WebSocket-Extensions header value causes connection failure.
- Previously it was just ignored.
- Malformed permessage-deflate parameters are declined.
- Previously part of such params were accepted partially.
BUG=523228
Review URL: https://codereview.chromium.org/1340523002
Cr-Commit-Position: refs/heads/master@{#351040}
Diffstat (limited to 'net/server')
-rw-r--r-- | net/server/http_server.cc | 7 | ||||
-rw-r--r-- | net/server/web_socket.cc | 124 | ||||
-rw-r--r-- | net/server/web_socket.h | 11 | ||||
-rw-r--r-- | net/server/web_socket_encoder.cc | 208 | ||||
-rw-r--r-- | net/server/web_socket_encoder.h | 61 | ||||
-rw-r--r-- | net/server/web_socket_encoder_unittest.cc | 104 |
6 files changed, 275 insertions, 240 deletions
diff --git a/net/server/http_server.cc b/net/server/http_server.cc index 3abd44d..f3560e8 100644 --- a/net/server/http_server.cc +++ b/net/server/http_server.cc @@ -236,11 +236,8 @@ int HttpServer::HandleReadResult(HttpConnection* connection, int rv) { connection->socket()->GetPeerAddress(&request.peer); if (request.HasHeaderValue("connection", "upgrade")) { - scoped_ptr<WebSocket> websocket = - WebSocket::CreateWebSocket(this, connection, request); - if (!websocket) // Not enough data was received. - break; - connection->SetWebSocket(websocket.Pass()); + connection->SetWebSocket( + make_scoped_ptr(new WebSocket(this, connection))); read_buf->DidConsume(pos); delegate_->OnWebSocketRequest(connection->id(), request); if (HasClosedConnection(connection)) diff --git a/net/server/web_socket.cc b/net/server/web_socket.cc index c963745..79ffcec 100644 --- a/net/server/web_socket.cc +++ b/net/server/web_socket.cc @@ -4,6 +4,8 @@ #include "net/server/web_socket.h" +#include <vector> + #include "base/base64.h" #include "base/logging.h" #include "base/sha1.h" @@ -15,70 +17,87 @@ #include "net/server/http_server_request_info.h" #include "net/server/http_server_response_info.h" #include "net/server/web_socket_encoder.h" +#include "net/websockets/websocket_deflate_parameters.h" +#include "net/websockets/websocket_extension.h" +#include "net/websockets/websocket_handshake_constants.h" namespace net { -WebSocket::WebSocket(HttpServer* server, - HttpConnection* connection, - const HttpServerRequestInfo& request) - : server_(server), connection_(connection), closed_(false) { - std::string request_extensions = - request.GetHeaderValue("sec-websocket-extensions"); - encoder_.reset(WebSocketEncoder::CreateServer(request_extensions, - &response_extensions_)); - if (!response_extensions_.empty()) { - response_extensions_ = - "Sec-WebSocket-Extensions: " + response_extensions_ + "\r\n"; - } +namespace { + +std::string ExtensionsHeaderString( + const std::vector<WebSocketExtension>& extensions) { + if (extensions.empty()) + return std::string(); + + std::string result = "Sec-WebSocket-Extensions: " + extensions[0].ToString(); + for (size_t i = 1; i < extensions.size(); ++i) + result += ", " + extensions[i].ToString(); + return result + "\r\n"; } +std::string ValidResponseString( + const std::string& accept_hash, + const std::vector<WebSocketExtension> extensions) { + return base::StringPrintf( + "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: %s\r\n" + "%s" + "\r\n", + accept_hash.c_str(), ExtensionsHeaderString(extensions).c_str()); +} + +} // namespace + +WebSocket::WebSocket(HttpServer* server, HttpConnection* connection) + : server_(server), connection_(connection), closed_(false) {} + WebSocket::~WebSocket() {} -scoped_ptr<WebSocket> WebSocket::CreateWebSocket( - HttpServer* server, - HttpConnection* connection, - const HttpServerRequestInfo& request) { +void WebSocket::Accept(const HttpServerRequestInfo& request) { std::string version = request.GetHeaderValue("sec-websocket-version"); if (version != "8" && version != "13") { - server->SendResponse( - connection->id(), - HttpServerResponseInfo::CreateFor500( - "Invalid request format. The version is not valid.")); - return nullptr; + SendErrorResponse("Invalid request format. The version is not valid."); + return; } std::string key = request.GetHeaderValue("sec-websocket-key"); if (key.empty()) { - server->SendResponse( - connection->id(), - HttpServerResponseInfo::CreateFor500( - "Invalid request format. Sec-WebSocket-Key is empty or isn't " - "specified.")); - return nullptr; + SendErrorResponse( + "Invalid request format. Sec-WebSocket-Key is empty or isn't " + "specified."); + return; } - return make_scoped_ptr(new WebSocket(server, connection, request)); -} - -void WebSocket::Accept(const HttpServerRequestInfo& request) { - static const char* const kWebSocketGuid = - "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - std::string key = request.GetHeaderValue("sec-websocket-key"); - std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); std::string encoded_hash; - base::Base64Encode(base::SHA1HashString(data), &encoded_hash); - - server_->SendRaw( - connection_->id(), - base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept: %s\r\n" - "%s" - "\r\n", - encoded_hash.c_str(), response_extensions_.c_str())); + base::Base64Encode(base::SHA1HashString(key + websockets::kWebSocketGuid), + &encoded_hash); + + std::vector<WebSocketExtension> response_extensions; + auto i = request.headers.find("sec-websocket-extensions"); + if (i == request.headers.end()) { + encoder_ = WebSocketEncoder::CreateServer(); + } else { + WebSocketDeflateParameters params; + encoder_ = WebSocketEncoder::CreateServer(i->second, ¶ms); + if (!encoder_) { + Fail(); + return; + } + if (encoder_->deflate_enabled()) { + DCHECK(params.IsValidAsResponse()); + response_extensions.push_back(params.AsExtension()); + } + } + server_->SendRaw(connection_->id(), + ValidResponseString(encoded_hash, response_extensions)); } WebSocket::ParseResult WebSocket::Read(std::string* message) { + if (closed_) + return FRAME_CLOSE; + HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); int bytes_consumed = 0; @@ -98,4 +117,17 @@ void WebSocket::Send(const std::string& message) { server_->SendRaw(connection_->id(), encoded); } +void WebSocket::Fail() { + closed_ = true; + // TODO(yhirano): The server SHOULD log the problem. + server_->Close(connection_->id()); +} + +void WebSocket::SendErrorResponse(const std::string& message) { + if (closed_) + return; + closed_ = true; + server_->Send500(connection_->id(), message); +} + } // namespace net diff --git a/net/server/web_socket.h b/net/server/web_socket.h index d9509d8..5309544 100644 --- a/net/server/web_socket.h +++ b/net/server/web_socket.h @@ -27,10 +27,7 @@ class WebSocket final { FRAME_ERROR }; - static scoped_ptr<WebSocket> CreateWebSocket( - HttpServer* server, - HttpConnection* connection, - const HttpServerRequestInfo& request); + WebSocket(HttpServer* server, HttpConnection* connection); void Accept(const HttpServerRequestInfo& request); ParseResult Read(std::string* message); @@ -38,14 +35,12 @@ class WebSocket final { ~WebSocket(); private: - WebSocket(HttpServer* server, - HttpConnection* connection, - const HttpServerRequestInfo& request); + void Fail(); + void SendErrorResponse(const std::string& message); HttpServer* const server_; HttpConnection* const connection_; scoped_ptr<WebSocketEncoder> encoder_; - std::string response_extensions_; bool closed_; DISALLOW_COPY_AND_ASSIGN(WebSocket); diff --git a/net/server/web_socket_encoder.cc b/net/server/web_socket_encoder.cc index 1a5431a..b1b93ee 100644 --- a/net/server/web_socket_encoder.cc +++ b/net/server/web_socket_encoder.cc @@ -4,10 +4,14 @@ #include "net/server/web_socket_encoder.h" +#include <vector> + #include "base/logging.h" #include "base/strings/string_number_conversions.h" #include "base/strings/stringprintf.h" #include "net/base/io_buffer.h" +#include "net/websockets/websocket_deflate_parameters.h" +#include "net/websockets/websocket_extension.h" #include "net/websockets/websocket_extension_parser.h" namespace net { @@ -180,151 +184,111 @@ void EncodeFrameHybi17(const std::string& message, } // anonymous namespace // static -WebSocketEncoder* WebSocketEncoder::CreateServer( - const std::string& request_extensions, - std::string* response_extensions) { - bool deflate; - bool has_client_window_bits; - int client_window_bits; - int server_window_bits; - bool client_no_context_takeover; - bool server_no_context_takeover; - ParseExtensions(request_extensions, &deflate, &has_client_window_bits, - &client_window_bits, &server_window_bits, - &client_no_context_takeover, &server_no_context_takeover); - - if (deflate) { - *response_extensions = base::StringPrintf( - "permessage-deflate; server_max_window_bits=%d%s", server_window_bits, - server_no_context_takeover ? "; server_no_context_takeover" : ""); - if (has_client_window_bits) { - base::StringAppendF(response_extensions, "; client_max_window_bits=%d", - client_window_bits); - } else { - DCHECK_EQ(client_window_bits, 15); - } - return new WebSocketEncoder(true /* is_server */, server_window_bits, - client_window_bits, server_no_context_takeover); - } else { - *response_extensions = std::string(); - return new WebSocketEncoder(true /* is_server */); - } +scoped_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer() { + return make_scoped_ptr(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr)); } // static -WebSocketEncoder* WebSocketEncoder::CreateClient( - const std::string& response_extensions) { - bool deflate; - bool has_client_window_bits; - int client_window_bits; - int server_window_bits; - bool client_no_context_takeover; - bool server_no_context_takeover; - ParseExtensions(response_extensions, &deflate, &has_client_window_bits, - &client_window_bits, &server_window_bits, - &client_no_context_takeover, &server_no_context_takeover); - - if (deflate) { - return new WebSocketEncoder(false /* is_server */, client_window_bits, - server_window_bits, client_no_context_takeover); - } else { - return new WebSocketEncoder(false /* is_server */); +scoped_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer( + const std::string& extensions, + WebSocketDeflateParameters* deflate_parameters) { + WebSocketExtensionParser parser; + if (!parser.Parse(extensions)) { + // Failed to parse Sec-WebSocket-Extensions header. We MUST fail the + // connection. + return nullptr; } -} - -// static -void WebSocketEncoder::ParseExtensions(const std::string& header_value, - bool* deflate, - bool* has_client_window_bits, - int* client_window_bits, - int* server_window_bits, - bool* client_no_context_takeover, - bool* server_no_context_takeover) { - *deflate = false; - *has_client_window_bits = false; - *client_window_bits = 15; - *server_window_bits = 15; - *client_no_context_takeover = false; - *server_no_context_takeover = false; - - if (header_value.empty()) - return; - WebSocketExtensionParser parser; - if (!parser.Parse(header_value)) - return; - const std::vector<WebSocketExtension>& extensions = parser.extensions(); - // TODO(tyoshino): Fail if this method is used for parsing a response and - // there are multiple permessage-deflate extensions or there are any unknown - // extensions. - for (const auto& extension : extensions) { - if (extension.name() != "permessage-deflate") { + for (const auto& extension : parser.extensions()) { + std::string failure_message; + WebSocketDeflateParameters offer; + if (!offer.Initialize(extension, &failure_message) || + !offer.IsValidAsRequest(&failure_message)) { + // We decline unknown / malformed extensions. continue; } - const std::vector<WebSocketExtension::Parameter>& parameters = - extension.parameters(); - for (const auto& param : parameters) { - const std::string& name = param.name(); - // TODO(tyoshino): Fail the connection when an invalid value is given. - if (name == "client_max_window_bits") { - *has_client_window_bits = true; - if (param.HasValue()) { - int bits = 0; - if (base::StringToInt(param.value(), &bits) && bits >= 8 && - bits <= 15) { - *client_window_bits = bits; - } - } - } - if (name == "server_max_window_bits" && param.HasValue()) { - int bits = 0; - if (base::StringToInt(param.value(), &bits) && bits >= 8 && bits <= 15) - *server_window_bits = bits; - } - if (name == "client_no_context_takeover") - *client_no_context_takeover = true; - if (name == "server_no_context_takeover") - *server_no_context_takeover = true; + WebSocketDeflateParameters response = offer; + if (offer.is_client_max_window_bits_specified() && + !offer.has_client_max_window_bits_value()) { + // We need to choose one value for the response. + response.SetClientMaxWindowBits(15); } - *deflate = true; - - break; + DCHECK(response.IsValidAsResponse()); + DCHECK(offer.IsCompatibleWith(response)); + auto deflater = make_scoped_ptr( + new WebSocketDeflater(response.server_context_take_over_mode())); + auto inflater = make_scoped_ptr( + new WebSocketInflater(kInflaterChunkSize, kInflaterChunkSize)); + if (!deflater->Initialize(response.PermissiveServerMaxWindowBits()) || + !inflater->Initialize(response.PermissiveClientMaxWindowBits())) { + // For some reason we cannot accept the parameters. + continue; + } + *deflate_parameters = response; + return make_scoped_ptr( + new WebSocketEncoder(FOR_SERVER, deflater.Pass(), inflater.Pass())); } -} -WebSocketEncoder::WebSocketEncoder(bool is_server) : is_server_(is_server) { + // We cannot find an acceptable offer. + return make_scoped_ptr(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr)); } -WebSocketEncoder::WebSocketEncoder(bool is_server, - int deflate_bits, - int inflate_bits, - bool no_context_takeover) - : is_server_(is_server) { - deflater_.reset(new WebSocketDeflater( - no_context_takeover ? WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT - : WebSocketDeflater::TAKE_OVER_CONTEXT)); - inflater_.reset( - new WebSocketInflater(kInflaterChunkSize, kInflaterChunkSize)); +// static +WebSocketEncoder* WebSocketEncoder::CreateClient( + const std::string& response_extensions) { + // TODO(yhirano): Add a way to return an error. - if (!deflater_->Initialize(deflate_bits) || - !inflater_->Initialize(inflate_bits)) { - // Disable deflate support. - deflater_.reset(); - inflater_.reset(); + WebSocketExtensionParser parser; + if (!parser.Parse(response_extensions)) { + // Parse error. Note that there are two cases here. + // 1) There is no Sec-WebSocket-Extensions header. + // 2) There is a malformed Sec-WebSocketExtensions header. + // We should return a deflate-disabled encoder for the former case and + // fail the connection for the latter case. + return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr); + } + if (parser.extensions().size() != 1) { + // Only permessage-deflate extension is supported. + // TODO (yhirano): Fail the connection. + return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr); + } + const auto& extension = parser.extensions()[0]; + WebSocketDeflateParameters params; + std::string failure_message; + if (!params.Initialize(extension, &failure_message) || + !params.IsValidAsResponse(&failure_message)) { + // TODO (yhirano): Fail the connection. + return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr); } -} -WebSocketEncoder::~WebSocketEncoder() { + auto deflater = make_scoped_ptr( + new WebSocketDeflater(params.client_context_take_over_mode())); + auto inflater = make_scoped_ptr( + new WebSocketInflater(kInflaterChunkSize, kInflaterChunkSize)); + if (!deflater->Initialize(params.PermissiveClientMaxWindowBits()) || + !inflater->Initialize(params.PermissiveServerMaxWindowBits())) { + // TODO (yhirano): Fail the connection. + return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr); + } + + return new WebSocketEncoder(FOR_CLIENT, deflater.Pass(), inflater.Pass()); } +WebSocketEncoder::WebSocketEncoder(Type type, + scoped_ptr<WebSocketDeflater> deflater, + scoped_ptr<WebSocketInflater> inflater) + : type_(type), deflater_(deflater.Pass()), inflater_(inflater.Pass()) {} + +WebSocketEncoder::~WebSocketEncoder() {} + WebSocket::ParseResult WebSocketEncoder::DecodeFrame( const base::StringPiece& frame, int* bytes_consumed, std::string* output) { bool compressed; - WebSocket::ParseResult result = - DecodeFrameHybi17(frame, is_server_, bytes_consumed, output, &compressed); + WebSocket::ParseResult result = DecodeFrameHybi17( + frame, type_ == FOR_SERVER, bytes_consumed, output, &compressed); if (result == WebSocket::FRAME_OK && compressed) { if (!Inflate(output)) result = WebSocket::FRAME_ERROR; diff --git a/net/server/web_socket_encoder.h b/net/server/web_socket_encoder.h index 23f0d9c..1eb749f 100644 --- a/net/server/web_socket_encoder.h +++ b/net/server/web_socket_encoder.h @@ -16,61 +16,50 @@ namespace net { -class WebSocketEncoder { +class WebSocketDeflateParameters; + +class WebSocketEncoder final { public: - ~WebSocketEncoder(); + static const char kClientExtensions[]; - static WebSocketEncoder* CreateServer(const std::string& request_extensions, - std::string* response_extensions); + ~WebSocketEncoder(); - static const char kClientExtensions[]; + // Creates and returns an encoder for a server without extensions. + static scoped_ptr<WebSocketEncoder> CreateServer(); + // Creates and returns an encoder. + // |extensions| is the value of a Sec-WebSocket-Extensions header. + // Returns nullptr when there is an error. + static scoped_ptr<WebSocketEncoder> CreateServer( + const std::string& extensions, + WebSocketDeflateParameters* params); + // TODO(yhirano): Return a scoped_ptr instead of a raw pointer. static WebSocketEncoder* CreateClient(const std::string& response_extensions); WebSocket::ParseResult DecodeFrame(const base::StringPiece& frame, int* bytes_consumed, std::string* output); - void EncodeFrame(const std::string& frame, int masking_key, std::string* output); + bool deflate_enabled() const { return deflater_; } + private: - explicit WebSocketEncoder(bool is_server); - WebSocketEncoder(bool is_server, - int deflate_bits, - int inflate_bits, - bool no_context_takeover); - - // Parses a value in the Sec-WebSocket-Extensions header. If it contains a - // single element of the permessage-deflate extension, stores the result of - // parsing the parameters of the extension into the given variables. - // Otherwise, returns with *deflate set to false. - // - // - If the client_max_window_bits parameter is missing, *client_window_bits - // defaults to 15. - // - If the client_max_window_bits parameter has an invalid value, - // *client_window_bits will be set to 0. - // - If the server_max_window_bits parameter is missing, *server_window_bits - // defaults to 15. - // - If the server_max_window_bits parameter has an invalid value, - // *client_window_bits will be set to 0. - // - // TODO(tyoshino): Consider using a struct than taking a lot of pointers for - // output. - static void ParseExtensions(const std::string& header_value, - bool* deflate, - bool* has_client_window_bits, - int* client_window_bits, - int* server_window_bits, - bool* client_no_context_takeover, - bool* server_no_context_takeover); + enum Type { + FOR_SERVER, + FOR_CLIENT, + }; + + WebSocketEncoder(Type type, + scoped_ptr<WebSocketDeflater> deflater, + scoped_ptr<WebSocketInflater> inflater); bool Inflate(std::string* message); bool Deflate(const std::string& message, std::string* output); + Type type_; scoped_ptr<WebSocketDeflater> deflater_; scoped_ptr<WebSocketInflater> inflater_; - bool is_server_; DISALLOW_COPY_AND_ASSIGN(WebSocketEncoder); }; diff --git a/net/server/web_socket_encoder_unittest.cc b/net/server/web_socket_encoder_unittest.cc index 7bca876..9991bd7 100644 --- a/net/server/web_socket_encoder_unittest.cc +++ b/net/server/web_socket_encoder_unittest.cc @@ -3,30 +3,76 @@ // found in the LICENSE file. #include "net/server/web_socket_encoder.h" + +#include "net/websockets/websocket_deflate_parameters.h" +#include "net/websockets/websocket_extension.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { +TEST(WebSocketEncoderHandshakeTest, EmptyRequestShouldBeRejected) { + WebSocketDeflateParameters params; + scoped_ptr<WebSocketEncoder> server = + WebSocketEncoder::CreateServer("", ¶ms); + + EXPECT_FALSE(server); +} + TEST(WebSocketEncoderHandshakeTest, CreateServerWithoutClientMaxWindowBitsParameter) { - std::string response_extensions; - scoped_ptr<WebSocketEncoder> server(WebSocketEncoder::CreateServer( - "permessage-deflate", &response_extensions)); - // The response must not include client_max_window_bits if the client didn't - // declare that it accepts the parameter. - EXPECT_EQ("permessage-deflate; server_max_window_bits=15", - response_extensions); + WebSocketDeflateParameters params; + scoped_ptr<WebSocketEncoder> server = + WebSocketEncoder::CreateServer("permessage-deflate", ¶ms); + + ASSERT_TRUE(server); + EXPECT_TRUE(server->deflate_enabled()); + EXPECT_EQ("permessage-deflate", params.AsExtension().ToString()); } TEST(WebSocketEncoderHandshakeTest, CreateServerWithServerNoContextTakeoverParameter) { - std::string response_extensions; - scoped_ptr<WebSocketEncoder> server(WebSocketEncoder::CreateServer( - "permessage-deflate; server_no_context_takeover", &response_extensions)); - EXPECT_EQ( - "permessage-deflate; server_max_window_bits=15" - "; server_no_context_takeover", - response_extensions); + WebSocketDeflateParameters params; + scoped_ptr<WebSocketEncoder> server = WebSocketEncoder::CreateServer( + "permessage-deflate; server_no_context_takeover", ¶ms); + ASSERT_TRUE(server); + EXPECT_TRUE(server->deflate_enabled()); + EXPECT_EQ("permessage-deflate; server_no_context_takeover", + params.AsExtension().ToString()); +} + +TEST(WebSocketEncoderHandshakeTest, FirstExtensionShouldBeChosen) { + WebSocketDeflateParameters params; + scoped_ptr<WebSocketEncoder> server = WebSocketEncoder::CreateServer( + "permessage-deflate; server_no_context_takeover," + "permessage-deflate; server_max_window_bits=15", + ¶ms); + + ASSERT_TRUE(server); + EXPECT_TRUE(server->deflate_enabled()); + EXPECT_EQ("permessage-deflate; server_no_context_takeover", + params.AsExtension().ToString()); +} + +TEST(WebSocketEncoderHandshakeTest, FirstValidExtensionShouldBeChosen) { + WebSocketDeflateParameters params; + scoped_ptr<WebSocketEncoder> server = WebSocketEncoder::CreateServer( + "permessage-deflate; Xserver_no_context_takeover," + "permessage-deflate; server_max_window_bits=15", + ¶ms); + + ASSERT_TRUE(server); + EXPECT_TRUE(server->deflate_enabled()); + EXPECT_EQ("permessage-deflate; server_max_window_bits=15", + params.AsExtension().ToString()); +} + +TEST(WebSocketEncoderHandshakeTest, AllExtensionsAreUnknownOrMalformed) { + WebSocketDeflateParameters params; + scoped_ptr<WebSocketEncoder> server = + WebSocketEncoder::CreateServer("unknown, permessage-deflate; x", ¶ms); + + ASSERT_TRUE(server); + EXPECT_FALSE(server->deflate_enabled()); } class WebSocketEncoderTest : public testing::Test { @@ -35,7 +81,7 @@ class WebSocketEncoderTest : public testing::Test { void SetUp() override { std::string response_extensions; - server_.reset(WebSocketEncoder::CreateServer("", &response_extensions)); + server_ = WebSocketEncoder::CreateServer(); EXPECT_EQ(std::string(), response_extensions); client_.reset(WebSocketEncoder::CreateClient("")); } @@ -50,17 +96,29 @@ class WebSocketEncoderCompressionTest : public WebSocketEncoderTest { WebSocketEncoderCompressionTest() : WebSocketEncoderTest() {} void SetUp() override { - std::string response_extensions; - server_.reset(WebSocketEncoder::CreateServer( - "permessage-deflate; client_max_window_bits", &response_extensions)); - EXPECT_EQ( - "permessage-deflate; server_max_window_bits=15; " - "client_max_window_bits=15", - response_extensions); - client_.reset(WebSocketEncoder::CreateClient(response_extensions)); + WebSocketDeflateParameters params; + server_ = WebSocketEncoder::CreateServer( + "permessage-deflate; client_max_window_bits", ¶ms); + ASSERT_TRUE(server_); + EXPECT_TRUE(server_->deflate_enabled()); + EXPECT_EQ("permessage-deflate; client_max_window_bits=15", + params.AsExtension().ToString()); + client_.reset( + WebSocketEncoder::CreateClient(params.AsExtension().ToString())); } }; +TEST_F(WebSocketEncoderTest, DeflateDisabledEncoder) { + scoped_ptr<WebSocketEncoder> server(WebSocketEncoder::CreateServer()); + scoped_ptr<WebSocketEncoder> client(WebSocketEncoder::CreateClient("")); + + ASSERT_TRUE(server); + ASSERT_TRUE(client); + + EXPECT_FALSE(server->deflate_enabled()); + EXPECT_FALSE(client->deflate_enabled()); +} + TEST_F(WebSocketEncoderTest, ClientToServer) { std::string frame("ClientToServer"); int mask = 123456; |