// Copyright 2014 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "components/devtools_bridge/socket_tunnel_server.h" #include "base/bind.h" #include "base/location.h" #include "components/devtools_bridge/abstract_data_channel.h" #include "components/devtools_bridge/session_dependency_factory.h" #include "components/devtools_bridge/socket_tunnel_connection.h" #include "components/devtools_bridge/socket_tunnel_packet_handler.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/socket/unix_domain_client_socket_posix.h" namespace devtools_bridge { class SocketTunnelServer::Connection : public SocketTunnelConnection { public: class Delegate { public: virtual void RemoveConnection(int index) = 0; virtual void SendPacket( const void* data, size_t length) = 0; }; Connection(Delegate* delegate, int index, const std::string& socket_name) : SocketTunnelConnection(index), delegate_(delegate), socket_(socket_name, true) { } void Connect() { int result = socket()->Connect(base::Bind( &Connection::OnConnectionComplete, base::Unretained(this))); if (result != net::ERR_IO_PENDING) OnConnectionComplete(result); } void ClosedByClient() { if (socket()->IsConnected()) { socket()->Disconnect(); SendControlPacket(SERVER_CLOSE); } delegate_->RemoveConnection(index_); } protected: net::StreamSocket* socket() override { return &socket_; } void OnDataPacketRead(const void* data, size_t length) override { delegate_->SendPacket(data, length); ReadNextChunk(); } void OnReadError(int error) override { socket()->Disconnect(); SendControlPacket(SERVER_CLOSE); delegate_->RemoveConnection(index_); delegate_ = NULL; } private: void OnConnectionComplete(int result) { if (result == net::OK) { SendControlPacket(SERVER_OPEN_ACK); ReadNextChunk(); } else { SendControlPacket(SERVER_CLOSE); delegate_->RemoveConnection(index_); delegate_ = NULL; } } void SendControlPacket(ServerOpCode op_code) { char buffer[kControlPacketSizeBytes]; BuildControlPacket(buffer, op_code); delegate_->SendPacket(buffer, kControlPacketSizeBytes); } Delegate* delegate_; net::UnixDomainClientSocket socket_; }; /** * Lives on the IO thread. */ class SocketTunnelServer::ConnectionController : private Connection::Delegate { public: ConnectionController( scoped_refptr io_task_runner, scoped_refptr data_channel, const std::string& socket_name) : io_task_runner_(io_task_runner), data_channel_(data_channel), socket_name_(socket_name) { DCHECK(data_channel_.get()); } void HandleControlPacket(int connection_index, int op_code) { DCHECK(connection_index < kMaxConnectionCount); switch (op_code) { case SocketTunnelConnection::CLIENT_OPEN: if (connections_[connection_index].get() != NULL) { DLOG(ERROR) << "Opening connection which already open: " << connection_index; HandleProtocolError(); return; } connections_[connection_index].reset( new Connection(this, connection_index, socket_name_)); connections_[connection_index]->Connect(); break; case SocketTunnelConnection::CLIENT_CLOSE: if (connections_[connection_index].get() == NULL) { // Ignore. Client may close the connection before received // notification from the server. return; } connections_[connection_index]->ClosedByClient(); break; default: DLOG(ERROR) << "Invalid op_code: " << op_code; HandleProtocolError(); return; } } void HandleDataPacket(int connection_index, scoped_refptr packet) { Connection* connection = connections_[connection_index].get(); if (connection != NULL) connection->Write(packet); } void HandleProtocolError() { data_channel_->Close(); } void CloseAllConnections() { for (int i = 0; i < kMaxConnectionCount; i++) { connections_[i].reset(); } } private: static void DeleteConnectionImpl(Connection*) {} // Connection::Delegate implementation void RemoveConnection(int connection_index) override { // Remove immediately, delete later to preserve this of the caller. Connection* connection = connections_[connection_index].release(); io_task_runner_->PostTask( FROM_HERE, base::Bind(&ConnectionController::DeleteConnectionImpl, base::Owned(connection))); } void SendPacket(const void* data, size_t length) override { data_channel_->SendBinaryMessage(data, length); } static const int kMaxConnectionCount = SocketTunnelConnection::kMaxConnectionCount; scoped_refptr io_task_runner_; scoped_refptr data_channel_; scoped_ptr connections_[kMaxConnectionCount]; const std::string socket_name_; }; class SocketTunnelServer::DataChannelObserver : public AbstractDataChannel::Observer, private SocketTunnelPacketHandler { public: DataChannelObserver(scoped_refptr io_task_runner, scoped_ptr controller) : io_task_runner_(io_task_runner), controller_(controller.Pass()) { } ~DataChannelObserver() override { // Deleting on IO thread allows post tasks with base::Unretained // because all of them will be processed before deletion. io_task_runner_->PostTask( FROM_HERE, base::Bind(&DataChannelObserver::DeleteControllerOnIOThread, base::Passed(&controller_))); } void OnOpen() override { // Nothing to do. Activity could only be initiated by a control packet. } void OnClose() override { io_task_runner_->PostTask( FROM_HERE, base::Bind( &ConnectionController::CloseAllConnections, base::Unretained(controller_.get()))); } void OnMessage(const void* data, size_t length) override { DecodePacket(data, length); } private: static void DeleteControllerOnIOThread( scoped_ptr controller) {} // SocketTunnelPacketHandler implementation. void HandleControlPacket(int connection_index, int op_code) override { io_task_runner_->PostTask( FROM_HERE, base::Bind( &ConnectionController::HandleControlPacket, base::Unretained(controller_.get()), connection_index, op_code)); } void HandleDataPacket(int connection_index, scoped_refptr data) override { io_task_runner_->PostTask( FROM_HERE, base::Bind( &ConnectionController::HandleDataPacket, base::Unretained(controller_.get()), connection_index, data)); } void HandleProtocolError() override { io_task_runner_->PostTask( FROM_HERE, base::Bind( &ConnectionController::HandleProtocolError, base::Unretained(controller_.get()))); } const scoped_refptr io_task_runner_; scoped_ptr controller_; }; SocketTunnelServer::SocketTunnelServer(SessionDependencyFactory* factory, AbstractDataChannel* data_channel, const std::string& socket_name) : data_channel_(data_channel) { scoped_ptr controller( new ConnectionController(factory->io_thread_task_runner(), data_channel->proxy(), socket_name)); scoped_ptr data_channel_observer( new DataChannelObserver(factory->io_thread_task_runner(), controller.Pass())); data_channel_->RegisterObserver(data_channel_observer.Pass()); } SocketTunnelServer::~SocketTunnelServer() { data_channel_->UnregisterObserver(); } } // namespace devtools_bridge