diff options
-rw-r--r-- | net/socket/socks_client_socket.cc | 92 |
1 files changed, 47 insertions, 45 deletions
diff --git a/net/socket/socks_client_socket.cc b/net/socket/socks_client_socket.cc index ccedec0..1bcf80c 100644 --- a/net/socket/socks_client_socket.cc +++ b/net/socket/socks_client_socket.cc @@ -72,6 +72,9 @@ SOCKSClientSocket::SOCKSClientSocket(ClientSocket* transport_socket, next_state_(STATE_NONE), socks_version_(kSOCKS4Unresolved), user_callback_(NULL), + handshake_buf_len_(0), + buffer_(NULL), + buffer_len_(0), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), @@ -230,60 +233,65 @@ int SOCKSClientSocket::DoResolveHostComplete(int result) { // Builds the buffer that is to be sent to the server. // We check whether the SOCKS proxy is 4 or 4A. // In case it is 4A, the record size increases by size of the hostname. -const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const { +void SOCKSClientSocket::BuildHandshakeWriteBuffer() { DCHECK_NE(kSOCKS4Unresolved, socks_version_); - SOCKS4ServerRequest request; - request.version = kSOCKSVersion4; - request.command = kSOCKSStreamRequest; - request.nw_port = htons(host_request_info_.port()); + int record_size = kWriteHeaderSize + arraysize(kEmptyUserId); + if (socks_version_ == kSOCKS4a) { + record_size += host_request_info_.hostname().size() + 1; + } + + buffer_len_ = record_size; + buffer_.reset(new char[buffer_len_]); + + SOCKS4ServerRequest* request = + reinterpret_cast<SOCKS4ServerRequest*>(buffer_.get()); + + request->version = kSOCKSVersion4; + request->command = kSOCKSStreamRequest; + request->nw_port = htons(host_request_info_.port()); if (socks_version_ == kSOCKS4) { const struct addrinfo* ai = addresses_.head(); DCHECK(ai); // If the sockaddr is IPv6, we have already marked the version to socks4a // and so this step does not get hit. - struct sockaddr_in* ipv4_host = + struct sockaddr_in *ipv4_host = reinterpret_cast<struct sockaddr_in*>(ai->ai_addr); - memcpy(&request.ip, &(ipv4_host->sin_addr), sizeof(ipv4_host->sin_addr)); + memcpy(&request->ip, &(ipv4_host->sin_addr), sizeof(ipv4_host->sin_addr)); DLOG(INFO) << "Resolved Host is : " << NetAddressToString(ai); } else if (socks_version_ == kSOCKS4a) { // invalid IP of the form 0.0.0.127 - memcpy(&request.ip, kInvalidIp, arraysize(kInvalidIp)); + memcpy(&request->ip, kInvalidIp, arraysize(kInvalidIp)); } else { NOTREACHED(); } - std::string handshake_data(reinterpret_cast<char*>(&request), - sizeof(request)); - handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId)); + memcpy(&buffer_[kWriteHeaderSize], kEmptyUserId, arraysize(kEmptyUserId)); - // In case we are passing the domain also, pass the hostname - // terminated with a null character. if (socks_version_ == kSOCKS4a) { - handshake_data.append(host_request_info_.hostname()); - handshake_data.push_back('\0'); + memcpy(&buffer_[kWriteHeaderSize + arraysize(kEmptyUserId)], + host_request_info_.hostname().c_str(), + host_request_info_.hostname().size() + 1); } - - return handshake_data; } // Writes the SOCKS handshake data to the underlying socket connection. int SOCKSClientSocket::DoHandshakeWrite() { next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; - if (buffer_.empty()) { - buffer_ = BuildHandshakeWriteBuffer(); + if (!buffer_.get()) { + BuildHandshakeWriteBuffer(); bytes_sent_ = 0; } - int handshake_buf_len = buffer_.size() - bytes_sent_; - DCHECK_GT(handshake_buf_len, 0); - handshake_buf_ = new IOBuffer(handshake_buf_len); - memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], - handshake_buf_len); - return transport_->Write(handshake_buf_, handshake_buf_len, &io_callback_); + handshake_buf_len_ = buffer_len_ - bytes_sent_; + DCHECK_GT(handshake_buf_len_, 0); + handshake_buf_ = new IOBuffer(handshake_buf_len_); + memcpy(handshake_buf_.get()->data(), &buffer_[bytes_sent_], + handshake_buf_len_); + return transport_->Write(handshake_buf_, handshake_buf_len_, &io_callback_); } int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { @@ -292,14 +300,11 @@ int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { if (result < 0) return result; - // We ignore the case when result is 0, since the underlying Write - // may return spurious writes while waiting on the socket. - bytes_sent_ += result; - if (bytes_sent_ == buffer_.size()) { + if (bytes_sent_ == buffer_len_) { next_state_ = STATE_HANDSHAKE_READ; - buffer_.clear(); - } else if (bytes_sent_ < static_cast<int>(buffer_.size())) { + buffer_.reset(NULL); + } else if (bytes_sent_ < buffer_len_) { next_state_ = STATE_HANDSHAKE_WRITE; } else { return ERR_UNEXPECTED; @@ -313,13 +318,15 @@ int SOCKSClientSocket::DoHandshakeRead() { next_state_ = STATE_HANDSHAKE_READ_COMPLETE; - if (buffer_.empty()) { + if (!buffer_.get()) { + buffer_.reset(new char[kReadHeaderSize]); + buffer_len_ = kReadHeaderSize; bytes_received_ = 0; } - int handshake_buf_len = kReadHeaderSize - bytes_received_; - handshake_buf_ = new IOBuffer(handshake_buf_len); - return transport_->Read(handshake_buf_, handshake_buf_len, &io_callback_); + handshake_buf_len_ = buffer_len_ - bytes_received_; + handshake_buf_ = new IOBuffer(handshake_buf_len_); + return transport_->Read(handshake_buf_, handshake_buf_len_, &io_callback_); } int SOCKSClientSocket::DoHandshakeReadComplete(int result) { @@ -327,23 +334,18 @@ int SOCKSClientSocket::DoHandshakeReadComplete(int result) { if (result < 0) return result; - - // The underlying socket closed unexpectedly. - if (result == 0) - return ERR_CONNECTION_CLOSED; - - if (bytes_received_ + result > kReadHeaderSize) + if (bytes_received_ + result > buffer_len_) return ERR_INVALID_RESPONSE; - buffer_.append(handshake_buf_->data(), result); + memcpy(buffer_.get() + bytes_received_, handshake_buf_->data(), result); bytes_received_ += result; - if (bytes_received_ < kReadHeaderSize) { + if (bytes_received_ < buffer_len_) { next_state_ = STATE_HANDSHAKE_READ; return OK; } - const SOCKS4ServerResponse* response = - reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data()); + SOCKS4ServerResponse* response = + reinterpret_cast<SOCKS4ServerResponse*>(buffer_.get()); if (response->reserved_null != 0x00) { LOG(ERROR) << "Unknown response from SOCKS server."; |