// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/websockets/websocket_handshake_stream_create_helper.h" #include #include #include "net/base/completion_callback.h" #include "net/base/net_errors.h" #include "net/http/http_request_headers.h" #include "net/http/http_request_info.h" #include "net/http/http_response_headers.h" #include "net/http/http_response_info.h" #include "net/socket/client_socket_handle.h" #include "net/socket/socket_test_util.h" #include "net/websockets/websocket_basic_handshake_stream.h" #include "net/websockets/websocket_stream.h" #include "net/websockets/websocket_test_util.h" #include "testing/gtest/include/gtest/gtest.h" #include "url/gurl.h" namespace net { namespace { // This class encapsulates the details of creating a mock ClientSocketHandle. class MockClientSocketHandleFactory { public: MockClientSocketHandleFactory() : histograms_("a"), pool_(1, 1, &histograms_, socket_factory_maker_.factory()) {} // The created socket expects |expect_written| to be written to the socket, // and will respond with |return_to_read|. The test will fail if the expected // text is not written, or if all the bytes are not read. scoped_ptr CreateClientSocketHandle( const std::string& expect_written, const std::string& return_to_read) { socket_factory_maker_.SetExpectations(expect_written, return_to_read); scoped_ptr socket_handle(new ClientSocketHandle); socket_handle->Init( "a", scoped_refptr(), MEDIUM, CompletionCallback(), &pool_, BoundNetLog()); return socket_handle.Pass(); } private: WebSocketDeterministicMockClientSocketFactoryMaker socket_factory_maker_; ClientSocketPoolHistograms histograms_; MockTransportClientSocketPool pool_; DISALLOW_COPY_AND_ASSIGN(MockClientSocketHandleFactory); }; class TestConnectDelegate : public WebSocketStream::ConnectDelegate { public: ~TestConnectDelegate() override {} void OnSuccess(scoped_ptr stream) override {} void OnFailure(const std::string& failure_message) override {} void OnStartOpeningHandshake( scoped_ptr request) override {} void OnFinishOpeningHandshake( scoped_ptr response) override {} void OnSSLCertificateError( scoped_ptr ssl_error_callbacks, const SSLInfo& ssl_info, bool fatal) override {} }; class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test { protected: scoped_ptr CreateAndInitializeStream( const std::string& socket_url, const std::string& socket_host, const std::string& socket_path, const std::vector& sub_protocols, const std::string& origin, const std::string& extra_request_headers, const std::string& extra_response_headers) { WebSocketHandshakeStreamCreateHelper create_helper(&connect_delegate_, sub_protocols); create_helper.set_failure_message(&failure_message_); scoped_ptr socket_handle = socket_handle_factory_.CreateClientSocketHandle( WebSocketStandardRequest(socket_path, socket_host, origin, extra_request_headers), WebSocketStandardResponse(extra_response_headers)); scoped_ptr handshake( create_helper.CreateBasicStream(socket_handle.Pass(), false)); // If in future the implementation type returned by CreateBasicStream() // changes, this static_cast will be wrong. However, in that case the test // will fail and AddressSanitizer should identify the issue. static_cast(handshake.get()) ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ=="); HttpRequestInfo request_info; request_info.url = GURL(socket_url); request_info.method = "GET"; request_info.load_flags = LOAD_DISABLE_CACHE; int rv = handshake->InitializeStream( &request_info, DEFAULT_PRIORITY, BoundNetLog(), CompletionCallback()); EXPECT_EQ(OK, rv); HttpRequestHeaders headers; headers.SetHeader("Host", "localhost"); headers.SetHeader("Connection", "Upgrade"); headers.SetHeader("Pragma", "no-cache"); headers.SetHeader("Cache-Control", "no-cache"); headers.SetHeader("Upgrade", "websocket"); headers.SetHeader("Origin", origin); headers.SetHeader("Sec-WebSocket-Version", "13"); headers.SetHeader("User-Agent", ""); headers.SetHeader("Accept-Encoding", "gzip, deflate"); headers.SetHeader("Accept-Language", "en-us,fr"); HttpResponseInfo response; TestCompletionCallback dummy; rv = handshake->SendRequest(headers, &response, dummy.callback()); EXPECT_EQ(OK, rv); rv = handshake->ReadResponseHeaders(dummy.callback()); EXPECT_EQ(OK, rv); EXPECT_EQ(101, response.headers->response_code()); EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade")); EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket")); return handshake->Upgrade(); } MockClientSocketHandleFactory socket_handle_factory_; TestConnectDelegate connect_delegate_; std::string failure_message_; }; // Confirm that the basic case works as expected. TEST_F(WebSocketHandshakeStreamCreateHelperTest, BasicStream) { scoped_ptr stream = CreateAndInitializeStream( "ws://localhost/", "localhost", "/", std::vector(), "http://localhost/", "", ""); EXPECT_EQ("", stream->GetExtensions()); EXPECT_EQ("", stream->GetSubProtocol()); } // Verify that the sub-protocols are passed through. TEST_F(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) { std::vector sub_protocols; sub_protocols.push_back("chat"); sub_protocols.push_back("superchat"); scoped_ptr stream = CreateAndInitializeStream( "ws://localhost/", "localhost", "/", sub_protocols, "http://localhost/", "Sec-WebSocket-Protocol: chat, superchat\r\n", "Sec-WebSocket-Protocol: superchat\r\n"); EXPECT_EQ("superchat", stream->GetSubProtocol()); } // Verify that extension name is available. Bad extension names are tested in // websocket_stream_test.cc. TEST_F(WebSocketHandshakeStreamCreateHelperTest, Extensions) { scoped_ptr stream = CreateAndInitializeStream( "ws://localhost/", "localhost", "/", std::vector(), "http://localhost/", "", "Sec-WebSocket-Extensions: permessage-deflate\r\n"); EXPECT_EQ("permessage-deflate", stream->GetExtensions()); } // Verify that extension parameters are available. Bad parameters are tested in // websocket_stream_test.cc. TEST_F(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) { scoped_ptr stream = CreateAndInitializeStream( "ws://localhost/", "localhost", "/", std::vector(), "http://localhost/", "", "Sec-WebSocket-Extensions: permessage-deflate;" " client_max_window_bits=14; server_max_window_bits=14;" " server_no_context_takeover; client_no_context_takeover\r\n"); EXPECT_EQ( "permessage-deflate;" " client_max_window_bits=14; server_max_window_bits=14;" " server_no_context_takeover; client_no_context_takeover", stream->GetExtensions()); } } // namespace } // namespace net