// Copyright 2015 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "remoting/base/buffered_socket_writer.h" #include "base/bind.h" #include "base/location.h" #include "base/single_thread_task_runner.h" #include "base/stl_util.h" #include "base/thread_task_runner_handle.h" #include "net/base/net_errors.h" namespace remoting { struct BufferedSocketWriterBase::PendingPacket { PendingPacket(scoped_refptr data, const base::Closure& done_task) : data(data), done_task(done_task) { } scoped_refptr data; base::Closure done_task; }; BufferedSocketWriterBase::BufferedSocketWriterBase() : buffer_size_(0), socket_(nullptr), write_pending_(false), closed_(false), destroyed_flag_(nullptr) { } void BufferedSocketWriterBase::Init(net::Socket* socket, const WriteFailedCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK(socket); socket_ = socket; write_failed_callback_ = callback; } bool BufferedSocketWriterBase::Write( scoped_refptr data, const base::Closure& done_task) { DCHECK(CalledOnValidThread()); DCHECK(socket_); DCHECK(data.get()); // Don't write after Close(). if (closed_) return false; queue_.push_back(new PendingPacket(data, done_task)); buffer_size_ += data->size(); DoWrite(); // DoWrite() may trigger OnWriteError() to be called. return !closed_; } void BufferedSocketWriterBase::DoWrite() { DCHECK(CalledOnValidThread()); DCHECK(socket_); // Don't try to write if there is another write pending. if (write_pending_) return; // Don't write after Close(). if (closed_) return; while (true) { net::IOBuffer* current_packet; int current_packet_size; GetNextPacket(¤t_packet, ¤t_packet_size); // Return if the queue is empty. if (!current_packet) return; int result = socket_->Write( current_packet, current_packet_size, base::Bind(&BufferedSocketWriterBase::OnWritten, base::Unretained(this))); bool write_again = false; HandleWriteResult(result, &write_again); if (!write_again) return; } } void BufferedSocketWriterBase::HandleWriteResult(int result, bool* write_again) { *write_again = false; if (result < 0) { if (result == net::ERR_IO_PENDING) { write_pending_ = true; } else { HandleError(result); if (!write_failed_callback_.is_null()) write_failed_callback_.Run(result); } return; } base::Closure done_task = AdvanceBufferPosition(result); if (!done_task.is_null()) { bool destroyed = false; destroyed_flag_ = &destroyed; done_task.Run(); if (destroyed) { // Stop doing anything if we've been destroyed by the callback. return; } destroyed_flag_ = nullptr; } *write_again = true; } void BufferedSocketWriterBase::OnWritten(int result) { DCHECK(CalledOnValidThread()); DCHECK(write_pending_); write_pending_ = false; bool write_again; HandleWriteResult(result, &write_again); if (write_again) DoWrite(); } void BufferedSocketWriterBase::HandleError(int result) { DCHECK(CalledOnValidThread()); closed_ = true; STLDeleteElements(&queue_); // Notify subclass that an error is received. OnError(result); } int BufferedSocketWriterBase::GetBufferSize() { return buffer_size_; } int BufferedSocketWriterBase::GetBufferChunks() { return queue_.size(); } void BufferedSocketWriterBase::Close() { DCHECK(CalledOnValidThread()); closed_ = true; } BufferedSocketWriterBase::~BufferedSocketWriterBase() { if (destroyed_flag_) *destroyed_flag_ = true; STLDeleteElements(&queue_); } base::Closure BufferedSocketWriterBase::PopQueue() { base::Closure result = queue_.front()->done_task; delete queue_.front(); queue_.pop_front(); return result; } BufferedSocketWriter::BufferedSocketWriter() { } void BufferedSocketWriter::GetNextPacket( net::IOBuffer** buffer, int* size) { if (!current_buf_.get()) { if (queue_.empty()) { *buffer = nullptr; return; // Nothing to write. } current_buf_ = new net::DrainableIOBuffer(queue_.front()->data.get(), queue_.front()->data->size()); } *buffer = current_buf_.get(); *size = current_buf_->BytesRemaining(); } base::Closure BufferedSocketWriter::AdvanceBufferPosition(int written) { buffer_size_ -= written; current_buf_->DidConsume(written); if (current_buf_->BytesRemaining() == 0) { current_buf_ = nullptr; return PopQueue(); } return base::Closure(); } void BufferedSocketWriter::OnError(int result) { current_buf_ = nullptr; } BufferedSocketWriter::~BufferedSocketWriter() { } BufferedDatagramWriter::BufferedDatagramWriter() { } void BufferedDatagramWriter::GetNextPacket( net::IOBuffer** buffer, int* size) { if (queue_.empty()) { *buffer = nullptr; return; // Nothing to write. } *buffer = queue_.front()->data.get(); *size = queue_.front()->data->size(); } base::Closure BufferedDatagramWriter::AdvanceBufferPosition(int written) { DCHECK_EQ(written, queue_.front()->data->size()); buffer_size_ -= queue_.front()->data->size(); return PopQueue(); } void BufferedDatagramWriter::OnError(int result) { // Nothing to do here. } BufferedDatagramWriter::~BufferedDatagramWriter() { } } // namespace remoting