// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "content/browser/renderer_host/websocket_host.h" #include <inttypes.h> #include <utility> #include "base/bind.h" #include "base/bind_helpers.h" #include "base/location.h" #include "base/logging.h" #include "base/macros.h" #include "base/single_thread_task_runner.h" #include "base/strings/string_util.h" #include "base/strings/stringprintf.h" #include "base/thread_task_runner_handle.h" #include "content/browser/bad_message.h" #include "content/browser/renderer_host/websocket_blob_sender.h" #include "content/browser/renderer_host/websocket_dispatcher_host.h" #include "content/browser/ssl/ssl_error_handler.h" #include "content/browser/ssl/ssl_manager.h" #include "content/common/websocket_messages.h" #include "content/public/browser/browser_thread.h" #include "content/public/browser/render_frame_host.h" #include "content/public/browser/storage_partition.h" #include "ipc/ipc_message_macros.h" #include "net/base/net_errors.h" #include "net/http/http_request_headers.h" #include "net/http/http_response_headers.h" #include "net/http/http_util.h" #include "net/ssl/ssl_info.h" #include "net/websockets/websocket_channel.h" #include "net/websockets/websocket_errors.h" #include "net/websockets/websocket_event_interface.h" #include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode #include "net/websockets/websocket_handshake_request_info.h" #include "net/websockets/websocket_handshake_response_info.h" #include "url/origin.h" namespace content { namespace { typedef net::WebSocketEventInterface::ChannelState ChannelState; // Convert a content::WebSocketMessageType to a // net::WebSocketFrameHeader::OpCode net::WebSocketFrameHeader::OpCode MessageTypeToOpCode( WebSocketMessageType type) { DCHECK(type == WEB_SOCKET_MESSAGE_TYPE_CONTINUATION || type == WEB_SOCKET_MESSAGE_TYPE_TEXT || type == WEB_SOCKET_MESSAGE_TYPE_BINARY); typedef net::WebSocketFrameHeader::OpCode OpCode; // These compile asserts verify that the same underlying values are used for // both types, so we can simply cast between them. static_assert(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_CONTINUATION) == net::WebSocketFrameHeader::kOpCodeContinuation, "enum values must match for opcode continuation"); static_assert(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_TEXT) == net::WebSocketFrameHeader::kOpCodeText, "enum values must match for opcode text"); static_assert(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_BINARY) == net::WebSocketFrameHeader::kOpCodeBinary, "enum values must match for opcode binary"); return static_cast<OpCode>(type); } WebSocketMessageType OpCodeToMessageType( net::WebSocketFrameHeader::OpCode opCode) { DCHECK(opCode == net::WebSocketFrameHeader::kOpCodeContinuation || opCode == net::WebSocketFrameHeader::kOpCodeText || opCode == net::WebSocketFrameHeader::kOpCodeBinary); // This cast is guaranteed valid by the static_assert() statements above. return static_cast<WebSocketMessageType>(opCode); } ChannelState StateCast(WebSocketDispatcherHost::WebSocketHostState host_state) { const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_ALIVE = WebSocketDispatcherHost::WEBSOCKET_HOST_ALIVE; const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_DELETED = WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED; DCHECK(host_state == WEBSOCKET_HOST_ALIVE || host_state == WEBSOCKET_HOST_DELETED); // These compile asserts verify that we can get away with using static_cast<> // for the conversion. static_assert(static_cast<ChannelState>(WEBSOCKET_HOST_ALIVE) == net::WebSocketEventInterface::CHANNEL_ALIVE, "enum values must match for state_alive"); static_assert(static_cast<ChannelState>(WEBSOCKET_HOST_DELETED) == net::WebSocketEventInterface::CHANNEL_DELETED, "enum values must match for state_deleted"); return static_cast<ChannelState>(host_state); } // Implementation of WebSocketBlobSender::Channel class SendChannelImpl final : public WebSocketBlobSender::Channel { public: explicit SendChannelImpl(net::WebSocketChannel* channel) : channel_(channel) {} // Implementation of WebSocketBlobSender::Channel size_t GetSendQuota() const override { return static_cast<size_t>(channel_->current_send_quota()); } ChannelState SendFrame(bool fin, const std::vector<char>& data) override { int opcode = first_frame_ ? net::WebSocketFrameHeader::kOpCodeBinary : net::WebSocketFrameHeader::kOpCodeContinuation; first_frame_ = false; return channel_->SendFrame(fin, opcode, data); } private: net::WebSocketChannel* channel_; bool first_frame_ = true; DISALLOW_COPY_AND_ASSIGN(SendChannelImpl); }; } // namespace // Implementation of net::WebSocketEventInterface. Receives events from our // WebSocketChannel object. Each event is translated to an IPC and sent to the // renderer or child process via WebSocketDispatcherHost. class WebSocketHost::WebSocketEventHandler final : public net::WebSocketEventInterface { public: WebSocketEventHandler(WebSocketDispatcherHost* dispatcher, WebSocketHost* host, int routing_id, int render_frame_id); ~WebSocketEventHandler() override; // net::WebSocketEventInterface implementation ChannelState OnAddChannelResponse(const std::string& selected_subprotocol, const std::string& extensions) override; ChannelState OnDataFrame(bool fin, WebSocketMessageType type, const std::vector<char>& data) override; ChannelState OnClosingHandshake() override; ChannelState OnFlowControl(int64_t quota) override; ChannelState OnDropChannel(bool was_clean, uint16_t code, const std::string& reason) override; ChannelState OnFailChannel(const std::string& message) override; ChannelState OnStartOpeningHandshake( scoped_ptr<net::WebSocketHandshakeRequestInfo> request) override; ChannelState OnFinishOpeningHandshake( scoped_ptr<net::WebSocketHandshakeResponseInfo> response) override; ChannelState OnSSLCertificateError( scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks, const GURL& url, const net::SSLInfo& ssl_info, bool fatal) override; private: class SSLErrorHandlerDelegate final : public SSLErrorHandler::Delegate { public: SSLErrorHandlerDelegate( scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks); ~SSLErrorHandlerDelegate() override; base::WeakPtr<SSLErrorHandler::Delegate> GetWeakPtr(); // SSLErrorHandler::Delegate methods void CancelSSLRequest(int error, const net::SSLInfo* ssl_info) override; void ContinueSSLRequest() override; private: scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks_; base::WeakPtrFactory<SSLErrorHandlerDelegate> weak_ptr_factory_; DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate); }; WebSocketDispatcherHost* const dispatcher_; WebSocketHost* const host_; const int routing_id_; const int render_frame_id_; scoped_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_; DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler); }; WebSocketHost::WebSocketEventHandler::WebSocketEventHandler( WebSocketDispatcherHost* dispatcher, WebSocketHost* host, int routing_id, int render_frame_id) : dispatcher_(dispatcher), host_(host), routing_id_(routing_id), render_frame_id_(render_frame_id) {} WebSocketHost::WebSocketEventHandler::~WebSocketEventHandler() { DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_; } ChannelState WebSocketHost::WebSocketEventHandler::OnAddChannelResponse( const std::string& selected_protocol, const std::string& extensions) { DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse" << " routing_id=" << routing_id_ << " selected_protocol=\"" << selected_protocol << "\"" << " extensions=\"" << extensions << "\""; return StateCast(dispatcher_->SendAddChannelResponse( routing_id_, selected_protocol, extensions)); } ChannelState WebSocketHost::WebSocketEventHandler::OnDataFrame( bool fin, net::WebSocketFrameHeader::OpCode type, const std::vector<char>& data) { DVLOG(3) << "WebSocketEventHandler::OnDataFrame" << " routing_id=" << routing_id_ << " fin=" << fin << " type=" << type << " data is " << data.size() << " bytes"; return StateCast(dispatcher_->SendFrame(routing_id_, fin, OpCodeToMessageType(type), data)); } ChannelState WebSocketHost::WebSocketEventHandler::OnClosingHandshake() { DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake" << " routing_id=" << routing_id_; return StateCast(dispatcher_->NotifyClosingHandshake(routing_id_)); } ChannelState WebSocketHost::WebSocketEventHandler::OnFlowControl( int64_t quota) { DVLOG(3) << "WebSocketEventHandler::OnFlowControl" << " routing_id=" << routing_id_ << " quota=" << quota; if (host_->blob_sender_) host_->blob_sender_->OnNewSendQuota(); return StateCast(dispatcher_->SendFlowControl(routing_id_, quota)); } ChannelState WebSocketHost::WebSocketEventHandler::OnDropChannel( bool was_clean, uint16_t code, const std::string& reason) { DVLOG(3) << "WebSocketEventHandler::OnDropChannel" << " routing_id=" << routing_id_ << " was_clean=" << was_clean << " code=" << code << " reason=\"" << reason << "\""; return StateCast( dispatcher_->DoDropChannel(routing_id_, was_clean, code, reason)); } ChannelState WebSocketHost::WebSocketEventHandler::OnFailChannel( const std::string& message) { DVLOG(3) << "WebSocketEventHandler::OnFailChannel" << " routing_id=" << routing_id_ << " message=\"" << message << "\""; return StateCast(dispatcher_->NotifyFailure(routing_id_, message)); } ChannelState WebSocketHost::WebSocketEventHandler::OnStartOpeningHandshake( scoped_ptr<net::WebSocketHandshakeRequestInfo> request) { bool should_send = dispatcher_->CanReadRawCookies(); DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake " << "should_send=" << should_send; if (!should_send) return WebSocketEventInterface::CHANNEL_ALIVE; WebSocketHandshakeRequest request_to_pass; request_to_pass.url.Swap(&request->url); net::HttpRequestHeaders::Iterator it(request->headers); while (it.GetNext()) request_to_pass.headers.push_back(std::make_pair(it.name(), it.value())); request_to_pass.headers_text = base::StringPrintf("GET %s HTTP/1.1\r\n", request_to_pass.url.spec().c_str()) + request->headers.ToString(); request_to_pass.request_time = request->request_time; return StateCast( dispatcher_->NotifyStartOpeningHandshake(routing_id_, request_to_pass)); } ChannelState WebSocketHost::WebSocketEventHandler::OnFinishOpeningHandshake( scoped_ptr<net::WebSocketHandshakeResponseInfo> response) { bool should_send = dispatcher_->CanReadRawCookies(); DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake " << "should_send=" << should_send; if (!should_send) return WebSocketEventInterface::CHANNEL_ALIVE; WebSocketHandshakeResponse response_to_pass; response_to_pass.url.Swap(&response->url); response_to_pass.status_code = response->status_code; response_to_pass.status_text.swap(response->status_text); size_t iter = 0; std::string name, value; while (response->headers->EnumerateHeaderLines(&iter, &name, &value)) response_to_pass.headers.push_back(std::make_pair(name, value)); response_to_pass.headers_text = net::HttpUtil::ConvertHeadersBackToHTTPResponse( response->headers->raw_headers()); response_to_pass.response_time = response->response_time; return StateCast( dispatcher_->NotifyFinishOpeningHandshake(routing_id_, response_to_pass)); } ChannelState WebSocketHost::WebSocketEventHandler::OnSSLCertificateError( scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks, const GURL& url, const net::SSLInfo& ssl_info, bool fatal) { DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError" << " routing_id=" << routing_id_ << " url=" << url.spec() << " cert_status=" << ssl_info.cert_status << " fatal=" << fatal; ssl_error_handler_delegate_.reset( new SSLErrorHandlerDelegate(std::move(callbacks))); SSLManager::OnSSLCertificateSubresourceError( ssl_error_handler_delegate_->GetWeakPtr(), url, dispatcher_->render_process_id(), render_frame_id_, ssl_info, fatal); // The above method is always asynchronous. return WebSocketEventInterface::CHANNEL_ALIVE; } WebSocketHost::WebSocketEventHandler::SSLErrorHandlerDelegate:: SSLErrorHandlerDelegate( scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks) : callbacks_(std::move(callbacks)), weak_ptr_factory_(this) {} WebSocketHost::WebSocketEventHandler::SSLErrorHandlerDelegate:: ~SSLErrorHandlerDelegate() {} base::WeakPtr<SSLErrorHandler::Delegate> WebSocketHost::WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() { return weak_ptr_factory_.GetWeakPtr(); } void WebSocketHost::WebSocketEventHandler::SSLErrorHandlerDelegate:: CancelSSLRequest(int error, const net::SSLInfo* ssl_info) { DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest" << " error=" << error << " cert_status=" << (ssl_info ? ssl_info->cert_status : static_cast<net::CertStatus>(-1)); callbacks_->CancelSSLRequest(error, ssl_info); } void WebSocketHost::WebSocketEventHandler::SSLErrorHandlerDelegate:: ContinueSSLRequest() { DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest"; callbacks_->ContinueSSLRequest(); } WebSocketHost::WebSocketHost(int routing_id, WebSocketDispatcherHost* dispatcher, net::URLRequestContext* url_request_context, base::TimeDelta delay) : dispatcher_(dispatcher), url_request_context_(url_request_context), routing_id_(routing_id), delay_(delay), pending_flow_control_quota_(0), handshake_succeeded_(false), weak_ptr_factory_(this) { DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id; } WebSocketHost::~WebSocketHost() {} void WebSocketHost::GoAway() { OnDropChannel(false, static_cast<uint16_t>(net::kWebSocketErrorGoingAway), ""); } bool WebSocketHost::OnMessageReceived(const IPC::Message& message) { bool handled = true; IPC_BEGIN_MESSAGE_MAP(WebSocketHost, message) IPC_MESSAGE_HANDLER(WebSocketHostMsg_AddChannelRequest, OnAddChannelRequest) IPC_MESSAGE_HANDLER(WebSocketHostMsg_SendBlob, OnSendBlob) IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame, OnSendFrame) IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl, OnFlowControl) IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel, OnDropChannel) IPC_MESSAGE_UNHANDLED(handled = false) IPC_END_MESSAGE_MAP() return handled; } void WebSocketHost::OnAddChannelRequest( const GURL& socket_url, const std::vector<std::string>& requested_protocols, const url::Origin& origin, int render_frame_id) { DVLOG(3) << "WebSocketHost::OnAddChannelRequest" << " routing_id=" << routing_id_ << " socket_url=\"" << socket_url << "\" requested_protocols=\"" << base::JoinString(requested_protocols, ", ") << "\" origin=\"" << origin << "\""; DCHECK(!channel_); if (delay_ > base::TimeDelta()) { base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( FROM_HERE, base::Bind(&WebSocketHost::AddChannel, weak_ptr_factory_.GetWeakPtr(), socket_url, requested_protocols, origin, render_frame_id), delay_); } else { AddChannel(socket_url, requested_protocols, origin, render_frame_id); } // |this| may have been deleted here. } void WebSocketHost::AddChannel( const GURL& socket_url, const std::vector<std::string>& requested_protocols, const url::Origin& origin, int render_frame_id) { DVLOG(3) << "WebSocketHost::AddChannel" << " routing_id=" << routing_id_ << " socket_url=\"" << socket_url << "\" requested_protocols=\"" << base::JoinString(requested_protocols, ", ") << "\" origin=\"" << origin << "\""; DCHECK(!channel_); scoped_ptr<net::WebSocketEventInterface> event_interface( new WebSocketEventHandler(dispatcher_, this, routing_id_, render_frame_id)); channel_.reset(new net::WebSocketChannel(std::move(event_interface), url_request_context_)); if (pending_flow_control_quota_ > 0) { // channel_->SendFlowControl(pending_flow_control_quota_) must be called // after channel_->SendAddChannelRequest() below. // We post OnFlowControl() here using |weak_ptr_factory_| instead of // calling SendFlowControl directly, because |this| may have been deleted // after channel_->SendAddChannelRequest(). base::ThreadTaskRunnerHandle::Get()->PostTask( FROM_HERE, base::Bind(&WebSocketHost::OnFlowControl, weak_ptr_factory_.GetWeakPtr(), pending_flow_control_quota_)); pending_flow_control_quota_ = 0; } channel_->SendAddChannelRequest(socket_url, requested_protocols, origin); // |this| may have been deleted here. } void WebSocketHost::OnSendBlob(const std::string& uuid, uint64_t expected_size) { DVLOG(3) << "WebSocketHost::OnSendBlob" << " routing_id=" << routing_id_ << " uuid=" << uuid << " expected_size=" << expected_size; DCHECK(channel_); if (blob_sender_) { bad_message::ReceivedBadMessage( dispatcher_, bad_message::WSH_SEND_BLOB_DURING_BLOB_SEND); return; } blob_sender_.reset(new WebSocketBlobSender( make_scoped_ptr(new SendChannelImpl(channel_.get())))); StoragePartition* partition = dispatcher_->storage_partition(); storage::FileSystemContext* file_system_context = partition->GetFileSystemContext(); net::WebSocketEventInterface::ChannelState channel_state = net::WebSocketEventInterface::CHANNEL_ALIVE; // This use of base::Unretained is safe because the WebSocketBlobSender object // is owned by this object and will not call it back after destruction. int rv = blob_sender_->Start( uuid, expected_size, dispatcher_->blob_storage_context(), file_system_context, BrowserThread::GetMessageLoopProxyForThread(BrowserThread::FILE).get(), &channel_state, base::Bind(&WebSocketHost::BlobSendComplete, base::Unretained(this))); if (channel_state == net::WebSocketEventInterface::CHANNEL_ALIVE && rv != net::ERR_IO_PENDING) BlobSendComplete(rv); // |this| may be destroyed here. } void WebSocketHost::OnSendFrame(bool fin, WebSocketMessageType type, const std::vector<char>& data) { DVLOG(3) << "WebSocketHost::OnSendFrame" << " routing_id=" << routing_id_ << " fin=" << fin << " type=" << type << " data is " << data.size() << " bytes"; DCHECK(channel_); if (blob_sender_) { bad_message::ReceivedBadMessage( dispatcher_, bad_message::WSH_SEND_FRAME_DURING_BLOB_SEND); return; } channel_->SendFrame(fin, MessageTypeToOpCode(type), data); } void WebSocketHost::OnFlowControl(int64_t quota) { DVLOG(3) << "WebSocketHost::OnFlowControl" << " routing_id=" << routing_id_ << " quota=" << quota; if (!channel_) { // WebSocketChannel is not yet created due to the delay introduced by // per-renderer WebSocket throttling. // SendFlowControl() is called after WebSocketChannel is created. pending_flow_control_quota_ += quota; return; } channel_->SendFlowControl(quota); } void WebSocketHost::OnDropChannel(bool was_clean, uint16_t code, const std::string& reason) { DVLOG(3) << "WebSocketHost::OnDropChannel" << " routing_id=" << routing_id_ << " was_clean=" << was_clean << " code=" << code << " reason=\"" << reason << "\""; if (!channel_) { // WebSocketChannel is not yet created due to the delay introduced by // per-renderer WebSocket throttling. WebSocketDispatcherHost::WebSocketHostState result = dispatcher_->DoDropChannel(routing_id_, false, net::kWebSocketErrorAbnormalClosure, ""); DCHECK_EQ(WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED, result); return; } blob_sender_.reset(); // TODO(yhirano): Handle |was_clean| appropriately. channel_->StartClosingHandshake(code, reason); } void WebSocketHost::BlobSendComplete(int result) { DVLOG(3) << "WebSocketHost::BlobSendComplete" << " routing_id=" << routing_id_ << " result=" << net::ErrorToString(result); // All paths through this method must reset blob_sender_, so take ownership // at the beginning. scoped_ptr<WebSocketBlobSender> blob_sender(std::move(blob_sender_)); switch (result) { case net::OK: ignore_result(dispatcher_->BlobSendComplete(routing_id_)); // |this| may be destroyed here. return; case net::ERR_UPLOAD_FILE_CHANGED: { uint64_t expected_size = blob_sender->expected_size(); uint64_t actual_size = blob_sender->ActualSize(); if (expected_size != actual_size) { ignore_result(dispatcher_->NotifyFailure( routing_id_, base::StringPrintf("Blob size mismatch; renderer size = %" PRIu64 ", browser size = %" PRIu64, expected_size, actual_size))); // |this| is destroyed here. return; } // else fallthrough } default: ignore_result(dispatcher_->NotifyFailure( routing_id_, "Failed to load Blob: error code = " + net::ErrorToString(result))); // |this| is destroyed here. return; } } } // namespace content