diff options
5 files changed, 166 insertions, 84 deletions
diff --git a/extensions/browser/api/cast_channel/cast_channel_api.cc b/extensions/browser/api/cast_channel/cast_channel_api.cc index 8878a0c5..bf71053 100644 --- a/extensions/browser/api/cast_channel/cast_channel_api.cc +++ b/extensions/browser/api/cast_channel/cast_channel_api.cc @@ -65,7 +65,7 @@ void FillChannelInfo(const CastSocket& socket, ChannelInfo* channel_info) { bool IsValidConnectInfoPort(const ConnectInfo& connect_info) { return connect_info.port > 0 && connect_info.port < - std::numeric_limits<unsigned short>::max(); + std::numeric_limits<uint16_t>::max(); } bool IsValidConnectInfoAuth(const ConnectInfo& connect_info) { @@ -162,7 +162,8 @@ CastSocket* CastChannelAsyncApiFunction::GetSocketOrCompleteWithError( int channel_id) { CastSocket* socket = GetSocket(channel_id); if (!socket) { - SetResultFromError(cast_channel::CHANNEL_ERROR_INVALID_CHANNEL_ID); + SetResultFromError(channel_id, + cast_channel::CHANNEL_ERROR_INVALID_CHANNEL_ID); AsyncWorkCompleted(); } return socket; @@ -183,21 +184,24 @@ void CastChannelAsyncApiFunction::RemoveSocket(int channel_id) { manager_->Remove(extension_->id(), channel_id); } -void CastChannelAsyncApiFunction::SetResultFromSocket(int channel_id) { - CastSocket* socket = GetSocket(channel_id); - DCHECK(socket); +void CastChannelAsyncApiFunction::SetResultFromSocket( + const CastSocket& socket) { ChannelInfo channel_info; - FillChannelInfo(*socket, &channel_info); - error_ = socket->error_state(); + FillChannelInfo(socket, &channel_info); + error_ = socket.error_state(); SetResultFromChannelInfo(channel_info); } -void CastChannelAsyncApiFunction::SetResultFromError(ChannelError error) { +void CastChannelAsyncApiFunction::SetResultFromError(int channel_id, + ChannelError error) { ChannelInfo channel_info; - channel_info.channel_id = -1; + channel_info.channel_id = channel_id; channel_info.url = ""; channel_info.ready_state = cast_channel::READY_STATE_CLOSED; channel_info.error_state = error; + channel_info.connect_info.ip_address = ""; + channel_info.connect_info.port = 0; + channel_info.connect_info.auth = cast_channel::CHANNEL_AUTH_TYPE_SSL; SetResultFromChannelInfo(channel_info); error_ = error; } @@ -338,7 +342,13 @@ void CastChannelOpenFunction::AsyncWorkStart() { void CastChannelOpenFunction::OnOpen(int result) { DCHECK_CURRENTLY_ON(BrowserThread::IO); VLOG(1) << "Connect finished, OnOpen invoked."; - SetResultFromSocket(new_channel_id_); + CastSocket* socket = GetSocket(new_channel_id_); + if (!socket) { + SetResultFromError(new_channel_id_, + cast_channel::CHANNEL_ERROR_CONNECT_ERROR); + } else { + SetResultFromSocket(*socket); + } AsyncWorkCompleted(); } @@ -382,10 +392,13 @@ void CastChannelSendFunction::AsyncWorkStart() { void CastChannelSendFunction::OnSend(int result) { DCHECK_CURRENTLY_ON(BrowserThread::IO); - if (result < 0) { - SetResultFromError(cast_channel::CHANNEL_ERROR_SOCKET_ERROR); + int channel_id = params_->channel.channel_id; + CastSocket* socket = GetSocket(channel_id); + if (result < 0 || !socket) { + SetResultFromError(channel_id, + cast_channel::CHANNEL_ERROR_SOCKET_ERROR); } else { - SetResultFromSocket(params_->channel.channel_id); + SetResultFromSocket(*socket); } AsyncWorkCompleted(); } @@ -410,12 +423,16 @@ void CastChannelCloseFunction::AsyncWorkStart() { void CastChannelCloseFunction::OnClose(int result) { DCHECK_CURRENTLY_ON(BrowserThread::IO); VLOG(1) << "CastChannelCloseFunction::OnClose result = " << result; - if (result < 0) { - SetResultFromError(cast_channel::CHANNEL_ERROR_SOCKET_ERROR); + int channel_id = params_->channel.channel_id; + CastSocket* socket = GetSocket(channel_id); + if (result < 0 || !socket) { + SetResultFromError(channel_id, + cast_channel::CHANNEL_ERROR_SOCKET_ERROR); } else { - int channel_id = params_->channel.channel_id; - SetResultFromSocket(channel_id); + SetResultFromSocket(*socket); + // This will delete |socket|. RemoveSocket(channel_id); + socket = NULL; } AsyncWorkCompleted(); } diff --git a/extensions/browser/api/cast_channel/cast_channel_api.h b/extensions/browser/api/cast_channel/cast_channel_api.h index 99bbb02..3e07711 100644 --- a/extensions/browser/api/cast_channel/cast_channel_api.h +++ b/extensions/browser/api/cast_channel/cast_channel_api.h @@ -98,12 +98,13 @@ class CastChannelAsyncApiFunction : public AsyncApiFunction { // manager. void RemoveSocket(int channel_id); - // Sets the function result to a ChannelInfo obtained from the state of the - // CastSocket corresponding to |channel_id|. - void SetResultFromSocket(int channel_id); + // Sets the function result to a ChannelInfo obtained from the state of + // |socket|. + void SetResultFromSocket(const cast_channel::CastSocket& socket); - // Sets the function result to a ChannelInfo with |error|. - void SetResultFromError(cast_channel::ChannelError error); + // Sets the function result to a ChannelInfo populated with |channel_id| and + // |error|. + void SetResultFromError(int channel_id, cast_channel::ChannelError error); // Returns the socket corresponding to |channel_id| if one exists, or null // otherwise. diff --git a/extensions/browser/api/cast_channel/cast_channel_apitest.cc b/extensions/browser/api/cast_channel/cast_channel_apitest.cc index 14f1fbd..9f56d2d 100644 --- a/extensions/browser/api/cast_channel/cast_channel_apitest.cc +++ b/extensions/browser/api/cast_channel/cast_channel_apitest.cc @@ -20,6 +20,9 @@ #include "testing/gmock/include/gmock/gmock.h" #include "testing/gmock_mutant.h" +// TODO(mfoltz): Mock out the ApiResourceManager to resolve threading issues +// (crbug.com/398242) and simulate unloading of the extension. + namespace cast_channel = extensions::core_api::cast_channel; using cast_channel::CastSocket; using cast_channel::ChannelError; @@ -69,11 +72,6 @@ class MockCastSocket : public CastSocket { base::TimeDelta::FromMilliseconds(kTimeoutMs)) {} virtual ~MockCastSocket() {} - virtual bool CalledOnValidThread() const OVERRIDE { - // Always return true in testing. - return true; - } - MOCK_METHOD1(Connect, void(const net::CompletionCallback& callback)); MOCK_METHOD2(SendMessage, void(const MessageInfo& message, const net::CompletionCallback& callback)); diff --git a/extensions/browser/api/cast_channel/cast_socket.cc b/extensions/browser/api/cast_channel/cast_socket.cc index f446bff..80c9e85 100644 --- a/extensions/browser/api/cast_channel/cast_socket.cc +++ b/extensions/browser/api/cast_channel/cast_socket.cc @@ -99,7 +99,11 @@ CastSocket::CastSocket(const std::string& owner_extension_id, current_read_buffer_ = header_read_buffer_; } -CastSocket::~CastSocket() { } +CastSocket::~CastSocket() { + // Ensure that resources are freed but do not run pending callbacks to avoid + // any re-entrancy. + CloseInternal(); +} ReadyState CastSocket::ready_state() const { return ready_state_; @@ -176,19 +180,24 @@ void CastSocket::Connect(const net::CompletionCallback& callback) { connect_callback_ = callback; connect_state_ = CONN_STATE_TCP_CONNECT; if (connect_timeout_.InMicroseconds() > 0) { - GetTimer()->Start( - FROM_HERE, - connect_timeout_, - base::Bind(&CastSocket::CancelConnect, AsWeakPtr())); + DCHECK(connect_timeout_callback_.IsCancelled()); + connect_timeout_callback_.Reset(base::Bind(&CastSocket::CancelConnect, + base::Unretained(this))); + GetTimer()->Start(FROM_HERE, + connect_timeout_, + connect_timeout_callback_.callback()); } DoConnectLoop(net::OK); } void CastSocket::PostTaskToStartConnectLoop(int result) { DCHECK(CalledOnValidThread()); - base::MessageLoop::current()->PostTask( - FROM_HERE, - base::Bind(&CastSocket::DoConnectLoop, AsWeakPtr(), result)); + DCHECK(connect_loop_callback_.IsCancelled()); + connect_loop_callback_.Reset(base::Bind(&CastSocket::DoConnectLoop, + base::Unretained(this), + result)); + base::MessageLoop::current()->PostTask(FROM_HERE, + connect_loop_callback_.callback()); } void CastSocket::CancelConnect() { @@ -204,6 +213,7 @@ void CastSocket::CancelConnect() { // 1. Connect method: this starts the flow // 2. Callback from network operations that finish asynchronously void CastSocket::DoConnectLoop(int result) { + connect_loop_callback_.Cancel(); if (is_canceled_) { LOG(ERROR) << "CANCELLED - Aborting DoConnectLoop."; return; @@ -258,11 +268,12 @@ void CastSocket::DoConnectLoop(int result) { } int CastSocket::DoTcpConnect() { + DCHECK(connect_loop_callback_.IsCancelled()); VLOG_WITH_CONNECTION(1) << "DoTcpConnect"; connect_state_ = CONN_STATE_TCP_CONNECT_COMPLETE; tcp_socket_ = CreateTcpSocket(); return tcp_socket_->Connect( - base::Bind(&CastSocket::DoConnectLoop, AsWeakPtr())); + base::Bind(&CastSocket::DoConnectLoop, base::Unretained(this))); } int CastSocket::DoTcpConnectComplete(int result) { @@ -277,11 +288,12 @@ int CastSocket::DoTcpConnectComplete(int result) { } int CastSocket::DoSslConnect() { + DCHECK(connect_loop_callback_.IsCancelled()); VLOG_WITH_CONNECTION(1) << "DoSslConnect"; connect_state_ = CONN_STATE_SSL_CONNECT_COMPLETE; socket_ = CreateSslSocket(tcp_socket_.PassAs<net::StreamSocket>()); return socket_->Connect( - base::Bind(&CastSocket::DoConnectLoop, AsWeakPtr())); + base::Bind(&CastSocket::DoConnectLoop, base::Unretained(this))); } int CastSocket::DoSslConnectComplete(int result) { @@ -306,16 +318,28 @@ int CastSocket::DoAuthChallengeSend() { // Post a task to send auth challenge so that DoWriteLoop is not nested inside // DoConnectLoop. This is not strictly necessary but keeps the write loop // code decoupled from connect loop code. - base::MessageLoop::current()->PostTask( - FROM_HERE, + DCHECK(send_auth_challenge_callback_.IsCancelled()); + send_auth_challenge_callback_.Reset( base::Bind(&CastSocket::SendCastMessageInternal, - AsWeakPtr(), + base::Unretained(this), challenge_message, - base::Bind(&CastSocket::DoConnectLoop, AsWeakPtr()))); + base::Bind(&CastSocket::DoAuthChallengeSendWriteComplete, + base::Unretained(this)))); + base::MessageLoop::current()->PostTask( + FROM_HERE, + send_auth_challenge_callback_.callback()); // Always return IO_PENDING since the result is always asynchronous. return net::ERR_IO_PENDING; } +void CastSocket::DoAuthChallengeSendWriteComplete(int result) { + send_auth_challenge_callback_.Cancel(); + VLOG_WITH_CONNECTION(2) << "DoAuthChallengeSendWriteComplete: " << result; + DCHECK_GT(result, 0); + DCHECK_EQ(write_queue_.size(), 1UL); + PostTaskToStartConnectLoop(result); +} + int CastSocket::DoAuthChallengeSendComplete(int result) { VLOG_WITH_CONNECTION(1) << "DoAuthChallengeSendComplete: " << result; if (result < 0) @@ -354,15 +378,46 @@ void CastSocket::DoConnectCallback(int result) { } void CastSocket::Close(const net::CompletionCallback& callback) { - DCHECK(CalledOnValidThread()); + CloseInternal(); + RunPendingCallbacksOnClose(); + // Run this callback last. It may delete the socket. + callback.Run(net::OK); +} + +void CastSocket::CloseInternal() { + // TODO(mfoltz): Enforce this when CastChannelAPITest is rewritten to create + // and free sockets on the same thread. crbug.com/398242 + // DCHECK(CalledOnValidThread()); + if (ready_state_ == READY_STATE_CLOSED) { + return; + } VLOG_WITH_CONNECTION(1) << "Close ReadyState = " << ready_state_; tcp_socket_.reset(); socket_.reset(); cert_verifier_.reset(); transport_security_state_.reset(); + GetTimer()->Stop(); + + // Cancel callbacks that we queued ourselves to re-enter the connect or read + // loops. + connect_loop_callback_.Cancel(); + send_auth_challenge_callback_.Cancel(); + read_loop_callback_.Cancel(); + connect_timeout_callback_.Cancel(); ready_state_ = READY_STATE_CLOSED; - callback.Run(net::OK); - // |callback| can delete |this| +} + +void CastSocket::RunPendingCallbacksOnClose() { + DCHECK_EQ(ready_state_, READY_STATE_CLOSED); + if (!connect_callback_.is_null()) { + connect_callback_.Run(net::ERR_CONNECTION_FAILED); + connect_callback_.Reset(); + } + for (; !write_queue_.empty(); write_queue_.pop()) { + net::CompletionCallback& callback = write_queue_.front().callback; + callback.Run(net::ERR_FAILED); + callback.Reset(); + } } void CastSocket::SendMessage(const MessageInfo& message, @@ -377,7 +432,6 @@ void CastSocket::SendMessage(const MessageInfo& message, callback.Run(net::ERR_FAILED); return; } - SendCastMessageInternal(message_proto, callback); } @@ -454,11 +508,10 @@ int CastSocket::DoWrite() { << request.io_buffer->BytesConsumed(); write_state_ = WRITE_STATE_WRITE_COMPLETE; - return socket_->Write( request.io_buffer.get(), request.io_buffer->BytesRemaining(), - base::Bind(&CastSocket::DoWriteLoop, AsWeakPtr())); + base::Bind(&CastSocket::DoWriteLoop, base::Unretained(this))); } int CastSocket::DoWriteComplete(int result) { @@ -483,21 +536,11 @@ int CastSocket::DoWriteComplete(int result) { int CastSocket::DoWriteCallback() { DCHECK(!write_queue_.empty()); + write_state_ = WRITE_STATE_WRITE; WriteRequest& request = write_queue_.front(); int bytes_consumed = request.io_buffer->BytesConsumed(); - - // If inside connection flow, then there should be exaclty one item in - // the write queue. - if (ready_state_ == READY_STATE_CONNECTING) { - write_queue_.pop(); - DCHECK(write_queue_.empty()); - PostTaskToStartConnectLoop(bytes_consumed); - } else { - WriteRequest& request = write_queue_.front(); - request.callback.Run(bytes_consumed); - write_queue_.pop(); - } - write_state_ = WRITE_STATE_WRITE; + request.callback.Run(bytes_consumed); + write_queue_.pop(); return net::OK; } @@ -526,12 +569,15 @@ int CastSocket::DoWriteError(int result) { void CastSocket::PostTaskToStartReadLoop() { DCHECK(CalledOnValidThread()); - base::MessageLoop::current()->PostTask( - FROM_HERE, - base::Bind(&CastSocket::StartReadLoop, AsWeakPtr())); + DCHECK(read_loop_callback_.IsCancelled()); + read_loop_callback_.Reset(base::Bind(&CastSocket::StartReadLoop, + base::Unretained(this))); + base::MessageLoop::current()->PostTask(FROM_HERE, + read_loop_callback_.callback()); } void CastSocket::StartReadLoop() { + read_loop_callback_.Cancel(); // Read loop would have already been started if read state is not NONE if (read_state_ == READ_STATE_NONE) { read_state_ = READ_STATE_READ; @@ -603,7 +649,7 @@ int CastSocket::DoRead() { return socket_->Read( current_read_buffer_.get(), num_bytes_to_read, - base::Bind(&CastSocket::DoReadLoop, AsWeakPtr())); + base::Bind(&CastSocket::DoReadLoop, base::Unretained(this))); } int CastSocket::DoReadComplete(int result) { @@ -723,9 +769,9 @@ bool CastSocket::Serialize(const CastMessage& message_proto, void CastSocket::CloseWithError(ChannelError error) { DCHECK(CalledOnValidThread()); - socket_.reset(NULL); - ready_state_ = READY_STATE_CLOSED; + CloseInternal(); error_state_ = error; + RunPendingCallbacksOnClose(); if (delegate_) delegate_->OnError(this, error); } @@ -756,7 +802,7 @@ void CastSocket::MessageHeader::SetMessageSize(size_t size) { void CastSocket::MessageHeader::PrependToString(std::string* str) { MessageHeader output = *this; output.message_size = base::HostToNet32(message_size); - size_t header_size = base::checked_cast<size_t,uint32>( + size_t header_size = base::checked_cast<size_t, uint32>( MessageHeader::header_size()); scoped_ptr<char, base::FreeDeleter> char_array( static_cast<char*>(malloc(header_size))); @@ -769,7 +815,7 @@ void CastSocket::MessageHeader::PrependToString(std::string* str) { void CastSocket::MessageHeader::ReadFromIOBuffer( net::GrowableIOBuffer* buffer, MessageHeader* header) { uint32 message_size; - size_t header_size = base::checked_cast<size_t,uint32>( + size_t header_size = base::checked_cast<size_t, uint32>( MessageHeader::header_size()); memcpy(&message_size, buffer->StartOfBuffer(), header_size); header->message_size = base::NetToHost32(message_size); diff --git a/extensions/browser/api/cast_channel/cast_socket.h b/extensions/browser/api/cast_channel/cast_socket.h index 7bcd513..5ff1714 100644 --- a/extensions/browser/api/cast_channel/cast_socket.h +++ b/extensions/browser/api/cast_channel/cast_socket.h @@ -9,11 +9,9 @@ #include <string> #include "base/basictypes.h" -#include "base/callback.h" #include "base/cancelable_callback.h" #include "base/gtest_prod_util.h" #include "base/memory/ref_counted.h" -#include "base/memory/weak_ptr.h" #include "base/threading/thread_checker.h" #include "base/timer/timer.h" #include "extensions/browser/api/api_resource.h" @@ -45,17 +43,16 @@ class CastMessage; // // NOTE: Not called "CastChannel" to reduce confusion with the generated API // code. -class CastSocket : public ApiResource, - public base::SupportsWeakPtr<CastSocket> { +class CastSocket : public ApiResource { public: - // Object to be informed of incoming messages and errors. + // Object to be informed of incoming messages and errors. The CastSocket that + // owns the delegate must not be deleted by it, only by the ApiResourceManager + // or in the callback to Close(). class Delegate { public: // An error occurred on the channel. - // It is fine to delete the socket in this callback. virtual void OnError(const CastSocket* socket, ChannelError error) = 0; // A message was received on the channel. - // Do NOT delete the socket in this callback. virtual void OnMessage(const CastSocket* socket, const MessageInfo& message) = 0; @@ -72,6 +69,8 @@ class CastSocket : public ApiResource, CastSocket::Delegate* delegate, net::NetLog* net_log, const base::TimeDelta& connect_timeout); + + // Ensures that the socket is closed. virtual ~CastSocket(); // The IP endpoint for the destination of the channel. @@ -98,8 +97,8 @@ class CastSocket : public ApiResource, virtual ChannelError error_state() const; // Connects the channel to the peer. If successful, the channel will be in - // READY_STATE_OPEN. - // It is fine to delete the CastSocket object in |callback|. + // READY_STATE_OPEN. DO NOT delete the CastSocket object in |callback|. + // Instead use Close(). virtual void Connect(const net::CompletionCallback& callback); // Sends a message over a connected channel. The channel must be in @@ -108,15 +107,15 @@ class CastSocket : public ApiResource, // Note that if an error occurs the following happens: // 1. Completion callbacks for all pending writes are invoked with error. // 2. Delegate::OnError is called once. - // 3. Castsocket is closed. + // 3. CastSocket is closed. // - // DO NOT delete the CastSocket object in write completion callback. - // But it is fine to delete the socket in Delegate::OnError + // DO NOT delete the CastSocket object in |callback|. Instead use Close(). virtual void SendMessage(const MessageInfo& message, const net::CompletionCallback& callback); - // Closes the channel. On completion, the channel will be in - // READY_STATE_CLOSED. + // Closes the channel if not already closed. On completion, the channel will + // be in READY_STATE_CLOSED. + // // It is fine to delete the CastSocket object in |callback|. virtual void Close(const net::CompletionCallback& callback); @@ -221,6 +220,7 @@ class CastSocket : public ApiResource, int DoSslConnectComplete(int result); int DoAuthChallengeSend(); int DoAuthChallengeSendComplete(int result); + void DoAuthChallengeSendWriteComplete(int result); int DoAuthChallengeReplyComplete(int result); ///////////////////////////////////////////////////////////////////////////// @@ -266,9 +266,17 @@ class CastSocket : public ApiResource, // Parses the contents of body_read_buffer_ and sets current_message_ to // the message received. bool ProcessBody(); - // Closes socket, updating the error state and signaling the delegate that - // |error| has occurred. + // Closes the socket, sets |error_state_| to |error| and signals the + // delegate that |error| has occurred. void CloseWithError(ChannelError error); + // Frees resources and cancels pending callbacks. |ready_state_| will be set + // READY_STATE_CLOSED on completion. A no-op if |ready_state_| is already + // READY_STATE_CLOSED. + void CloseInternal(); + // Runs pending callbacks that are passed into us to notify API clients that + // pending operations will fail because the socket has been closed. + void RunPendingCallbacksOnClose(); + // Serializes the content of message_proto (with a header) to |message_data|. static bool Serialize(const CastMessage& message_proto, std::string* message_data); @@ -324,6 +332,8 @@ class CastSocket : public ApiResource, // Callback invoked when the socket is connected or fails to connect. net::CompletionCallback connect_callback_; + // Callback invoked by |connect_timeout_timer_| to cancel the connection. + base::CancelableClosure connect_timeout_callback_; // Duration to wait before timing out. base::TimeDelta connect_timeout_; // Timer invoked when the connection has timed out. @@ -343,6 +353,16 @@ class CastSocket : public ApiResource, // The current status of the channel. ReadyState ready_state_; + // Task invoked to (re)start the connect loop. Canceled on entry to the + // connect loop. + base::CancelableClosure connect_loop_callback_; + // Task invoked to send the auth challenge. Canceled when the auth challenge + // has been sent. + base::CancelableClosure send_auth_challenge_callback_; + // Callback invoked to (re)start the read loop. Canceled on entry to the read + // loop. + base::CancelableClosure read_loop_callback_; + // Holds a message to be written to the socket. |callback| is invoked when the // message is fully written or an error occurrs. struct WriteRequest { |