// Copyright (c) 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/socket/tcp_client_socket.h" #include "base/callback_helpers.h" #include "base/logging.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" namespace net { TCPClientSocket::TCPClientSocket(const AddressList& addresses, net::NetLog* net_log, const net::NetLog::Source& source) : socket_(new TCPSocket(net_log, source)), addresses_(addresses), current_address_index_(-1), next_connect_state_(CONNECT_STATE_NONE), previously_disconnected_(false) { } TCPClientSocket::TCPClientSocket(scoped_ptr connected_socket, const IPEndPoint& peer_address) : socket_(connected_socket.Pass()), addresses_(AddressList(peer_address)), current_address_index_(0), next_connect_state_(CONNECT_STATE_NONE), previously_disconnected_(false) { DCHECK(socket_); socket_->SetDefaultOptionsForClient(); use_history_.set_was_ever_connected(); } TCPClientSocket::~TCPClientSocket() { } int TCPClientSocket::Bind(const IPEndPoint& address) { if (current_address_index_ >= 0 || bind_address_) { // Cannot bind the socket if we are already connected or connecting. NOTREACHED(); return ERR_UNEXPECTED; } int result = OK; if (!socket_->IsValid()) { result = OpenSocket(address.GetFamily()); if (result != OK) return result; } result = socket_->Bind(address); if (result != OK) return result; bind_address_.reset(new IPEndPoint(address)); return OK; } int TCPClientSocket::Connect(const CompletionCallback& callback) { DCHECK(!callback.is_null()); // If connecting or already connected, then just return OK. if (socket_->IsValid() && current_address_index_ >= 0) return OK; socket_->StartLoggingMultipleConnectAttempts(addresses_); // We will try to connect to each address in addresses_. Start with the // first one in the list. next_connect_state_ = CONNECT_STATE_CONNECT; current_address_index_ = 0; int rv = DoConnectLoop(OK); if (rv == ERR_IO_PENDING) { connect_callback_ = callback; } else { socket_->EndLoggingMultipleConnectAttempts(rv); } return rv; } int TCPClientSocket::DoConnectLoop(int result) { DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE); int rv = result; do { ConnectState state = next_connect_state_; next_connect_state_ = CONNECT_STATE_NONE; switch (state) { case CONNECT_STATE_CONNECT: DCHECK_EQ(OK, rv); rv = DoConnect(); break; case CONNECT_STATE_CONNECT_COMPLETE: rv = DoConnectComplete(rv); break; default: NOTREACHED() << "bad state " << state; rv = ERR_UNEXPECTED; break; } } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE); return rv; } int TCPClientSocket::DoConnect() { DCHECK_GE(current_address_index_, 0); DCHECK_LT(current_address_index_, static_cast(addresses_.size())); const IPEndPoint& endpoint = addresses_[current_address_index_]; if (previously_disconnected_) { use_history_.Reset(); previously_disconnected_ = false; } next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE; if (socket_->IsValid()) { DCHECK(bind_address_); } else { int result = OpenSocket(endpoint.GetFamily()); if (result != OK) return result; if (bind_address_) { result = socket_->Bind(*bind_address_); if (result != OK) { socket_->Close(); return result; } } } // |socket_| is owned by this class and the callback won't be run once // |socket_| is gone. Therefore, it is safe to use base::Unretained() here. return socket_->Connect(endpoint, base::Bind(&TCPClientSocket::DidCompleteConnect, base::Unretained(this))); } int TCPClientSocket::DoConnectComplete(int result) { if (result == OK) { use_history_.set_was_ever_connected(); return OK; // Done! } // Close whatever partially connected socket we currently have. DoDisconnect(); // Try to fall back to the next address in the list. if (current_address_index_ + 1 < static_cast(addresses_.size())) { next_connect_state_ = CONNECT_STATE_CONNECT; ++current_address_index_; return OK; } // Otherwise there is nothing to fall back to, so give up. return result; } void TCPClientSocket::Disconnect() { DoDisconnect(); current_address_index_ = -1; bind_address_.reset(); } void TCPClientSocket::DoDisconnect() { // If connecting or already connected, record that the socket has been // disconnected. previously_disconnected_ = socket_->IsValid() && current_address_index_ >= 0; socket_->Close(); } bool TCPClientSocket::IsConnected() const { return socket_->IsConnected(); } bool TCPClientSocket::IsConnectedAndIdle() const { return socket_->IsConnectedAndIdle(); } int TCPClientSocket::GetPeerAddress(IPEndPoint* address) const { return socket_->GetPeerAddress(address); } int TCPClientSocket::GetLocalAddress(IPEndPoint* address) const { DCHECK(address); if (!socket_->IsValid()) { if (bind_address_) { *address = *bind_address_; return OK; } return ERR_SOCKET_NOT_CONNECTED; } return socket_->GetLocalAddress(address); } const BoundNetLog& TCPClientSocket::NetLog() const { return socket_->net_log(); } void TCPClientSocket::SetSubresourceSpeculation() { use_history_.set_subresource_speculation(); } void TCPClientSocket::SetOmniboxSpeculation() { use_history_.set_omnibox_speculation(); } bool TCPClientSocket::WasEverUsed() const { return use_history_.was_used_to_convey_data(); } bool TCPClientSocket::UsingTCPFastOpen() const { return socket_->UsingTCPFastOpen(); } bool TCPClientSocket::WasNpnNegotiated() const { return false; } NextProto TCPClientSocket::GetNegotiatedProtocol() const { return kProtoUnknown; } bool TCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { return false; } int TCPClientSocket::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { DCHECK(!callback.is_null()); // |socket_| is owned by this class and the callback won't be run once // |socket_| is gone. Therefore, it is safe to use base::Unretained() here. CompletionCallback read_callback = base::Bind( &TCPClientSocket::DidCompleteReadWrite, base::Unretained(this), callback); int result = socket_->Read(buf, buf_len, read_callback); if (result > 0) use_history_.set_was_used_to_convey_data(); return result; } int TCPClientSocket::Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { DCHECK(!callback.is_null()); // |socket_| is owned by this class and the callback won't be run once // |socket_| is gone. Therefore, it is safe to use base::Unretained() here. CompletionCallback write_callback = base::Bind( &TCPClientSocket::DidCompleteReadWrite, base::Unretained(this), callback); int result = socket_->Write(buf, buf_len, write_callback); if (result > 0) use_history_.set_was_used_to_convey_data(); return result; } int TCPClientSocket::SetReceiveBufferSize(int32 size) { return socket_->SetReceiveBufferSize(size); } int TCPClientSocket::SetSendBufferSize(int32 size) { return socket_->SetSendBufferSize(size); } bool TCPClientSocket::SetKeepAlive(bool enable, int delay) { return socket_->SetKeepAlive(enable, delay); } bool TCPClientSocket::SetNoDelay(bool no_delay) { return socket_->SetNoDelay(no_delay); } void TCPClientSocket::DidCompleteConnect(int result) { DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE); DCHECK_NE(result, ERR_IO_PENDING); DCHECK(!connect_callback_.is_null()); result = DoConnectLoop(result); if (result != ERR_IO_PENDING) { socket_->EndLoggingMultipleConnectAttempts(result); base::ResetAndReturn(&connect_callback_).Run(result); } } void TCPClientSocket::DidCompleteReadWrite(const CompletionCallback& callback, int result) { if (result > 0) use_history_.set_was_used_to_convey_data(); callback.Run(result); } int TCPClientSocket::OpenSocket(AddressFamily family) { DCHECK(!socket_->IsValid()); int result = socket_->Open(family); if (result != OK) return result; socket_->SetDefaultOptionsForClient(); return OK; } } // namespace net