diff options
Diffstat (limited to 'mojo/edk/system/message_pipe_dispatcher.cc')
-rw-r--r-- | mojo/edk/system/message_pipe_dispatcher.cc | 164 |
1 files changed, 128 insertions, 36 deletions
diff --git a/mojo/edk/system/message_pipe_dispatcher.cc b/mojo/edk/system/message_pipe_dispatcher.cc index 6856def..90b3774 100644 --- a/mojo/edk/system/message_pipe_dispatcher.cc +++ b/mojo/edk/system/message_pipe_dispatcher.cc @@ -8,6 +8,7 @@ #include "base/logging.h" #include "base/message_loop/message_loop.h" #include "mojo/edk/embedder/embedder_internal.h" +#include "mojo/edk/embedder/platform_handle_utils.h" #include "mojo/edk/embedder/platform_shared_buffer.h" #include "mojo/edk/embedder/platform_support.h" #include "mojo/edk/system/configuration.h" @@ -38,6 +39,13 @@ struct MOJO_ALIGNAS(8) SerializedMessagePipeHandleDispatcher { size_t serialized_read_buffer_size; size_t serialized_write_buffer_size; size_t serialized_messagage_queue_size; + + // These are the FDs required as part of serializing channel_ and + // message_queue_. This is only used on POSIX. + size_t serialized_fds_index; // (Or |kInvalidMessagePipeHandleIndex|.) + size_t serialized_read_fds_length; + size_t serialized_write_fds_length; + size_t serialized_message_fds_length; }; char* SerializeBuffer(char* start, std::vector<char>* buffer) { @@ -65,6 +73,13 @@ bool GetHandle(size_t index, return true; } +#if defined(OS_POSIX) +void ClosePlatformHandles(std::vector<int>* fds) { + for (size_t i = 0; i < fds->size(); ++i) + PlatformHandle((*fds)[i]).CloseIfNecessary(); +} +#endif + } // namespace // MessagePipeDispatcher ------------------------------------------------------- @@ -104,14 +119,17 @@ MojoResult MessagePipeDispatcher::ValidateCreateOptions( void MessagePipeDispatcher::Init( ScopedPlatformHandle message_pipe, char* serialized_read_buffer, size_t serialized_read_buffer_size, - char* serialized_write_buffer, size_t serialized_write_buffer_size) { + char* serialized_write_buffer, size_t serialized_write_buffer_size, + std::vector<int>* serialized_read_fds, + std::vector<int>* serialized_write_fds) { if (message_pipe.get().is_valid()) { channel_ = RawChannel::Create(message_pipe.Pass()); // TODO(jam): It's probably cleaner to pass this in Init call. channel_->SetSerializedData( serialized_read_buffer, serialized_read_buffer_size, - serialized_write_buffer, serialized_write_buffer_size); + serialized_write_buffer, serialized_write_buffer_size, + serialized_read_fds, serialized_write_fds); if (g_use_channel_on_io_thread_only) { internal::g_io_thread_task_runner->PostTask( FROM_HERE, base::Bind(&MessagePipeDispatcher::InitOnIO, this)); @@ -251,13 +269,38 @@ scoped_refptr<MessagePipeDispatcher> MessagePipeDispatcher::Deserialize( scoped_refptr<MessagePipeDispatcher> rv( Create(MessagePipeDispatcher::kDefaultCreateOptions)); - rv->Init(platform_handle.Pass(), - serialized_read_buffer, - serialized_read_buffer_size, - serialized_write_buffer, - serialized_write_buffer_size); rv->write_error_ = serialization->write_error; + std::vector<int> serialized_read_fds; + std::vector<int> serialized_write_fds; +#if defined(OS_POSIX) + std::vector<int> serialized_fds; + size_t serialized_fds_index = 0; + + size_t total_fd_count = serialization->serialized_read_fds_length + + serialization->serialized_write_fds_length + + serialization->serialized_message_fds_length; + for (size_t i = 0; i < total_fd_count; ++i) { + ScopedPlatformHandle handle; + if (!GetHandle(serialization->serialized_fds_index + i, platform_handles, + &handle)) { + ClosePlatformHandles(&serialized_fds); + return nullptr; + } + serialized_fds.push_back(handle.release().fd); + } + + serialized_read_fds.assign( + serialized_fds.begin(), + serialized_fds.begin() + serialization->serialized_read_fds_length); + serialized_fds_index += serialization->serialized_read_fds_length; + serialized_write_fds.assign( + serialized_fds.begin() + serialized_fds_index, + serialized_fds.begin() + serialized_fds_index + + serialization->serialized_write_fds_length); + serialized_fds_index += serialization->serialized_write_fds_length; +#endif + while (message_queue_size) { size_t message_size; CHECK(MessageInTransit::GetNextMessageSize( @@ -277,6 +320,7 @@ scoped_refptr<MessagePipeDispatcher> MessagePipeDispatcher::Deserialize( &platform_handle_table); if (num_platform_handles > 0) { +#if defined(OS_WIN) temp_platform_handles = GetReadPlatformHandles(num_platform_handles, platform_handle_table).Pass(); @@ -284,6 +328,12 @@ scoped_refptr<MessagePipeDispatcher> MessagePipeDispatcher::Deserialize( LOG(ERROR) << "Invalid number of platform handles received"; return nullptr; } +#else + temp_platform_handles.reset(new PlatformHandleVector()); + for (size_t i = 0; i < num_platform_handles; ++i) + temp_platform_handles->push_back( + PlatformHandle(serialized_fds[serialized_fds_index++])); +#endif } } @@ -301,6 +351,14 @@ scoped_refptr<MessagePipeDispatcher> MessagePipeDispatcher::Deserialize( rv->message_queue_.AddMessage(message.Pass()); } + rv->Init(platform_handle.Pass(), + serialized_read_buffer, + serialized_read_buffer_size, + serialized_write_buffer, + serialized_write_buffer_size, + &serialized_read_fds, + &serialized_write_fds); + if (message_queue_size) { // Should be empty by now. LOG(ERROR) << "Invalid queued messages"; return nullptr; @@ -312,6 +370,9 @@ scoped_refptr<MessagePipeDispatcher> MessagePipeDispatcher::Deserialize( MessagePipeDispatcher::MessagePipeDispatcher() : channel_(nullptr), serialized_(false), + serialized_read_fds_length_(0u), + serialized_write_fds_length_(0u), + serialized_message_fds_length_(0u), calling_init_(false), write_error_(false) { } @@ -319,6 +380,9 @@ MessagePipeDispatcher::MessagePipeDispatcher() MessagePipeDispatcher::~MessagePipeDispatcher() { // |Close()|/|CloseImplNoLock()| should have taken care of the channel. DCHECK(!channel_); +#if defined(OS_POSIX) + ClosePlatformHandles(&serialized_fds_); +#endif } void MessagePipeDispatcher::CancelAllAwakablesNoLock() { @@ -339,17 +403,24 @@ void MessagePipeDispatcher::CloseImplNoLock() { void MessagePipeDispatcher::SerializeInternal() { // We need to stop watching handle immediately, even though not on IO thread, // so that other messages aren't read after this. - { - if (channel_) { - bool write_error = false; - serialized_platform_handle_ = channel_->ReleaseHandle( - &serialized_read_buffer_, &serialized_write_buffer_, &write_error); - channel_ = nullptr; - if (write_error) - write_error = true; - } else { - // It's valid that the other side wrote some data and closed its end. - } + std::vector<int> serialized_read_fds, serialized_write_fds; + if (channel_) { + bool write_error = false; + + serialized_platform_handle_ = channel_->ReleaseHandle( + &serialized_read_buffer_, &serialized_write_buffer_, + &serialized_read_fds, &serialized_write_fds, &write_error); + serialized_fds_.insert(serialized_fds_.end(), serialized_read_fds.begin(), + serialized_read_fds.end()); + serialized_read_fds_length_ = serialized_read_fds.size(); + serialized_fds_.insert(serialized_fds_.end(), serialized_write_fds.begin(), + serialized_write_fds.end()); + serialized_write_fds_length_ = serialized_write_fds.size(); + channel_ = nullptr; + if (write_error) + write_error = true; + } else { + // It's valid that the other side wrote some data and closed its end. } DCHECK(serialized_message_queue_.empty()); @@ -383,35 +454,37 @@ void MessagePipeDispatcher::SerializeInternal() { // cont'd if (transport_data_buffer_size != 0) { -#if defined(OS_WIN) // TODO(jam): copied from RawChannelWin::WriteNoLock( - if (RawChannel::GetSerializedPlatformHandleSize()) { + PlatformHandleVector* all_platform_handles = + message->transport_data()->platform_handles(); + if (all_platform_handles) { +#if defined(OS_WIN) char* serialization_data = static_cast<char*>(message->transport_data()->buffer()) + message->transport_data()->platform_handle_table_offset(); - PlatformHandleVector* all_platform_handles = - message->transport_data()->platform_handles(); - if (all_platform_handles) { - DWORD current_process_id = base::GetCurrentProcId(); - for (size_t i = 0; i < all_platform_handles->size(); i++) { - *reinterpret_cast<DWORD*>(serialization_data) = current_process_id; - serialization_data += sizeof(DWORD); - *reinterpret_cast<HANDLE*>(serialization_data) = - all_platform_handles->at(i).handle; - serialization_data += sizeof(HANDLE); - all_platform_handles->at(i) = PlatformHandle(); - } + DWORD current_process_id = base::GetCurrentProcId(); + for (size_t i = 0; i < all_platform_handles->size(); i++) { + *reinterpret_cast<DWORD*>(serialization_data) = current_process_id; + serialization_data += sizeof(DWORD); + *reinterpret_cast<HANDLE*>(serialization_data) = + all_platform_handles->at(i).handle; + serialization_data += sizeof(HANDLE); + all_platform_handles->at(i) = PlatformHandle(); } - } +#else + for (size_t i = 0; i < all_platform_handles->size(); i++) { + serialized_fds_.push_back(all_platform_handles->at(i).fd); + serialized_message_fds_length_++; + all_platform_handles->at(i) = PlatformHandle(); + } +#endif serialized_message_queue_.insert( serialized_message_queue_.end(), static_cast<const char*>(message->transport_data()->buffer()), static_cast<const char*>(message->transport_data()->buffer()) + transport_data_buffer_size); -#else - NOTREACHED() << "TODO(jam) implement"; -#endif + } } for (size_t i = 0; i < dispatchers.size(); ++i) @@ -435,6 +508,10 @@ MessagePipeDispatcher::CreateEquivalentDispatcherAndCloseImplNoLock() { serialized_message_queue_.swap(rv->serialized_message_queue_); serialized_read_buffer_.swap(rv->serialized_read_buffer_); serialized_write_buffer_.swap(rv->serialized_write_buffer_); + serialized_fds_.swap(rv->serialized_fds_); + rv->serialized_read_fds_length_ = serialized_read_fds_length_; + rv->serialized_write_fds_length_ = serialized_write_fds_length_; + rv->serialized_message_fds_length_ = serialized_message_fds_length_; rv->serialized_ = true; rv->write_error_ = write_error_; return scoped_refptr<Dispatcher>(rv.get()); @@ -602,6 +679,7 @@ void MessagePipeDispatcher::StartSerializeImplNoLock( !serialized_write_buffer_.empty() || !serialized_message_queue_.empty()) (*max_platform_handles)++; + *max_platform_handles += serialized_fds_.size(); *max_size = sizeof(SerializedMessagePipeHandleDispatcher); } @@ -646,6 +724,20 @@ bool MessagePipeDispatcher::EndSerializeAndCloseImplNoLock( serialization->shared_memory_handle_index = kInvalidMessagePipeHandleIndex; } + serialization->serialized_read_fds_length = serialized_read_fds_length_; + serialization->serialized_write_fds_length = serialized_write_fds_length_; + serialization->serialized_message_fds_length = serialized_message_fds_length_; + if (serialized_fds_.empty()) { + serialization->serialized_fds_index = kInvalidMessagePipeHandleIndex; + } else { +#if defined(OS_POSIX) + serialization->serialized_fds_index = platform_handles->size(); + for (size_t i = 0; i < serialized_fds_.size(); ++i) + platform_handles->push_back(PlatformHandle(serialized_fds_[i])); + serialized_fds_.clear(); +#endif + } + *actual_size = sizeof(SerializedMessagePipeHandleDispatcher); return true; } |