// Copyright 2014 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "remoting/host/gnubby_auth_handler_posix.h" #include #include #include "base/bind.h" #include "base/files/file_util.h" #include "base/json/json_reader.h" #include "base/json/json_writer.h" #include "base/lazy_instance.h" #include "base/stl_util.h" #include "base/values.h" #include "net/socket/unix_domain_listen_socket_posix.h" #include "remoting/base/logging.h" #include "remoting/host/gnubby_socket.h" #include "remoting/proto/control.pb.h" #include "remoting/protocol/client_stub.h" namespace remoting { namespace { const char kConnectionId[] = "connectionId"; const char kControlMessage[] = "control"; const char kControlOption[] = "option"; const char kDataMessage[] = "data"; const char kDataPayload[] = "data"; const char kErrorMessage[] = "error"; const char kGnubbyAuthMessage[] = "gnubby-auth"; const char kGnubbyAuthV1[] = "auth-v1"; const char kMessageType[] = "type"; // The name of the socket to listen for gnubby requests on. base::LazyInstance::Leaky g_gnubby_socket_name = LAZY_INSTANCE_INITIALIZER; // STL predicate to match by a StreamListenSocket pointer. class CompareSocket { public: explicit CompareSocket(net::StreamListenSocket* socket) : socket_(socket) {} bool operator()(const std::pair element) const { return element.second->IsSocket(socket_); } private: net::StreamListenSocket* socket_; }; // Socket authentication function that only allows connections from callers with // the current uid. bool MatchUid(const net::UnixDomainServerSocket::Credentials& credentials) { bool allowed = credentials.user_id == getuid(); if (!allowed) HOST_LOG << "Refused socket connection from uid " << credentials.user_id; return allowed; } // Returns the command code (the first byte of the data) if it exists, or -1 if // the data is empty. unsigned int GetCommandCode(const std::string& data) { return data.empty() ? -1 : static_cast(data[0]); } // Creates a string of byte data from a ListValue of numbers. Returns true if // all of the list elements are numbers. bool ConvertListValueToString(base::ListValue* bytes, std::string* out) { out->clear(); unsigned int byte_count = bytes->GetSize(); if (byte_count != 0) { out->reserve(byte_count); for (unsigned int i = 0; i < byte_count; i++) { int value; if (!bytes->GetInteger(i, &value)) return false; out->push_back(static_cast(value)); } } return true; } } // namespace GnubbyAuthHandlerPosix::GnubbyAuthHandlerPosix( protocol::ClientStub* client_stub) : client_stub_(client_stub), last_connection_id_(0) { DCHECK(client_stub_); } GnubbyAuthHandlerPosix::~GnubbyAuthHandlerPosix() { STLDeleteValues(&active_sockets_); } // static scoped_ptr GnubbyAuthHandler::Create( protocol::ClientStub* client_stub) { return make_scoped_ptr(new GnubbyAuthHandlerPosix(client_stub)); } // static void GnubbyAuthHandler::SetGnubbySocketName( const base::FilePath& gnubby_socket_name) { g_gnubby_socket_name.Get() = gnubby_socket_name; } void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) { DCHECK(CalledOnValidThread()); scoped_ptr value(base::JSONReader::Read(message)); base::DictionaryValue* client_message; if (value && value->GetAsDictionary(&client_message)) { std::string type; if (!client_message->GetString(kMessageType, &type)) { LOG(ERROR) << "Invalid gnubby-auth message"; return; } if (type == kControlMessage) { std::string option; if (client_message->GetString(kControlOption, &option) && option == kGnubbyAuthV1) { CreateAuthorizationSocket(); } else { LOG(ERROR) << "Invalid gnubby-auth control option"; } } else if (type == kDataMessage) { ActiveSockets::iterator iter = GetSocketForMessage(client_message); if (iter != active_sockets_.end()) { base::ListValue* bytes; std::string response; if (client_message->GetList(kDataPayload, &bytes) && ConvertListValueToString(bytes, &response)) { HOST_LOG << "Sending gnubby response: " << GetCommandCode(response); iter->second->SendResponse(response); } else { LOG(ERROR) << "Invalid gnubby data"; SendErrorAndCloseActiveSocket(iter); } } else { LOG(ERROR) << "Unknown gnubby-auth data connection"; } } else if (type == kErrorMessage) { ActiveSockets::iterator iter = GetSocketForMessage(client_message); if (iter != active_sockets_.end()) { HOST_LOG << "Sending gnubby error"; SendErrorAndCloseActiveSocket(iter); } else { LOG(ERROR) << "Unknown gnubby-auth error connection"; } } else { LOG(ERROR) << "Unknown gnubby-auth message type: " << type; } } } void GnubbyAuthHandlerPosix::DeliverHostDataMessage( int connection_id, const std::string& data) const { DCHECK(CalledOnValidThread()); base::DictionaryValue request; request.SetString(kMessageType, kDataMessage); request.SetInteger(kConnectionId, connection_id); base::ListValue* bytes = new base::ListValue(); for (std::string::const_iterator i = data.begin(); i != data.end(); ++i) { bytes->AppendInteger(static_cast(*i)); } request.Set(kDataPayload, bytes); std::string request_json; if (!base::JSONWriter::Write(&request, &request_json)) { LOG(ERROR) << "Failed to create request json"; return; } protocol::ExtensionMessage message; message.set_type(kGnubbyAuthMessage); message.set_data(request_json); client_stub_->DeliverHostMessage(message); } bool GnubbyAuthHandlerPosix::HasActiveSocketForTesting( net::StreamListenSocket* socket) const { return std::find_if(active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)) != active_sockets_.end(); } int GnubbyAuthHandlerPosix::GetConnectionIdForTesting( net::StreamListenSocket* socket) const { ActiveSockets::const_iterator iter = std::find_if( active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); return iter->first; } GnubbySocket* GnubbyAuthHandlerPosix::GetGnubbySocketForTesting( net::StreamListenSocket* socket) const { ActiveSockets::const_iterator iter = std::find_if( active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); return iter->second; } void GnubbyAuthHandlerPosix::DidAccept( net::StreamListenSocket* server, scoped_ptr socket) { DCHECK(CalledOnValidThread()); int connection_id = ++last_connection_id_; active_sockets_[connection_id] = new GnubbySocket(socket.Pass(), base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut, base::Unretained(this), connection_id)); } void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket, const char* data, int len) { DCHECK(CalledOnValidThread()); ActiveSockets::iterator iter = std::find_if( active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); if (iter != active_sockets_.end()) { GnubbySocket* gnubby_socket = iter->second; gnubby_socket->AddRequestData(data, len); if (gnubby_socket->IsRequestTooLarge()) { SendErrorAndCloseActiveSocket(iter); } else if (gnubby_socket->IsRequestComplete()) { std::string request_data; gnubby_socket->GetAndClearRequestData(&request_data); ProcessGnubbyRequest(iter->first, request_data); } } else { LOG(ERROR) << "Received data for unknown connection"; } } void GnubbyAuthHandlerPosix::DidClose(net::StreamListenSocket* socket) { DCHECK(CalledOnValidThread()); ActiveSockets::iterator iter = std::find_if( active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); if (iter != active_sockets_.end()) { delete iter->second; active_sockets_.erase(iter); } } void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() { DCHECK(CalledOnValidThread()); if (!g_gnubby_socket_name.Get().empty()) { // If the file already exists, a socket in use error is returned. base::DeleteFile(g_gnubby_socket_name.Get(), false); HOST_LOG << "Listening for gnubby requests on " << g_gnubby_socket_name.Get().value(); auth_socket_ = net::deprecated::UnixDomainListenSocket::CreateAndListen( g_gnubby_socket_name.Get().value(), this, base::Bind(MatchUid)); if (!auth_socket_.get()) { LOG(ERROR) << "Failed to open socket for gnubby requests"; } } else { HOST_LOG << "No gnubby socket name specified"; } } void GnubbyAuthHandlerPosix::ProcessGnubbyRequest( int connection_id, const std::string& request_data) { HOST_LOG << "Received gnubby request: " << GetCommandCode(request_data); DeliverHostDataMessage(connection_id, request_data); } GnubbyAuthHandlerPosix::ActiveSockets::iterator GnubbyAuthHandlerPosix::GetSocketForMessage(base::DictionaryValue* message) { int connection_id; if (message->GetInteger(kConnectionId, &connection_id)) { return active_sockets_.find(connection_id); } return active_sockets_.end(); } void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket( const ActiveSockets::iterator& iter) { iter->second->SendSshError(); delete iter->second; active_sockets_.erase(iter); } void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id) { HOST_LOG << "Gnubby request timed out"; ActiveSockets::iterator iter = active_sockets_.find(connection_id); if (iter != active_sockets_.end()) SendErrorAndCloseActiveSocket(iter); } } // namespace remoting