// Copyright (c) 2009 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/tools/flip_server/sm_connection.h" #include #include #include #include #include #include #include #include "net/tools/flip_server/constants.h" #include "net/tools/flip_server/flip_config.h" #include "net/tools/flip_server/http_interface.h" #include "net/tools/flip_server/spdy_interface.h" #include "net/tools/flip_server/spdy_ssl.h" #include "net/tools/flip_server/streamer_interface.h" namespace net { // static bool SMConnection::force_spdy_ = false; DataFrame::~DataFrame() { if (delete_when_done) delete[] data; } SMConnection::SMConnection(EpollServer* epoll_server, SSLState* ssl_state, MemoryCache* memory_cache, FlipAcceptor* acceptor, std::string log_prefix) : last_read_time_(0), fd_(-1), events_(0), registered_in_epoll_server_(false), initialized_(false), protocol_detected_(false), connection_complete_(false), connection_pool_(NULL), epoll_server_(epoll_server), ssl_state_(ssl_state), memory_cache_(memory_cache), acceptor_(acceptor), read_buffer_(kSpdySegmentSize * 40), sm_spdy_interface_(NULL), sm_http_interface_(NULL), sm_streamer_interface_(NULL), sm_interface_(NULL), log_prefix_(log_prefix), max_bytes_sent_per_dowrite_(4096), ssl_(NULL) {} SMConnection::~SMConnection() { if (initialized()) Reset(); } EpollServer* SMConnection::epoll_server() { return epoll_server_; } void SMConnection::ReadyToSend() { VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Setting ready to send: EPOLLIN | EPOLLOUT"; epoll_server_->SetFDReady(fd_, EPOLLIN | EPOLLOUT); } void SMConnection::EnqueueDataFrame(DataFrame* df) { output_list_.push_back(df); VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "EnqueueDataFrame: " << "size = " << df->size << ": Setting FD ready."; ReadyToSend(); } void SMConnection::InitSMConnection(SMConnectionPoolInterface* connection_pool, SMInterface* sm_interface, EpollServer* epoll_server, int fd, std::string server_ip, std::string server_port, std::string remote_ip, bool use_ssl) { if (initialized_) { LOG(FATAL) << "Attempted to initialize already initialized server"; return; } client_ip_ = remote_ip; if (fd == -1) { // If fd == -1, then we are initializing a new connection that will // connect to the backend. // // ret: -1 == error // 0 == connection in progress // 1 == connection complete // TODO(kelindsay): is_numeric_host_address value needs to be detected server_ip_ = server_ip; server_port_ = server_port; int ret = CreateConnectedSocket( &fd_, server_ip, server_port, true, acceptor_->disable_nagle_); if (ret < 0) { LOG(ERROR) << "-1 Could not create connected socket"; return; } else if (ret == 1) { DCHECK_NE(-1, fd_); connection_complete_ = true; VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Connection complete to: " << server_ip_ << ":" << server_port_ << " "; } VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Connecting to server: " << server_ip_ << ":" << server_port_ << " "; } else { // If fd != -1 then we are initializing a connection that has just been // accepted from the listen socket. connection_complete_ = true; if (epoll_server_ && registered_in_epoll_server_ && fd_ != -1) { epoll_server_->UnregisterFD(fd_); } if (fd_ != -1) { VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Closing pre-existing fd"; close(fd_); fd_ = -1; } fd_ = fd; } registered_in_epoll_server_ = false; // Set the last read time here as the idle checker will start from // now. last_read_time_ = time(NULL); initialized_ = true; connection_pool_ = connection_pool; epoll_server_ = epoll_server; if (sm_interface) { sm_interface_ = sm_interface; protocol_detected_ = true; } read_buffer_.Clear(); epoll_server_->RegisterFD(fd_, this, EPOLLIN | EPOLLOUT | EPOLLET); if (use_ssl) { ssl_ = CreateSSLContext(ssl_state_->ssl_ctx); SSL_set_fd(ssl_, fd_); PrintSslError(); } } void SMConnection::CorkSocket() { int state = 1; int rv = setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &state, sizeof(state)); if (rv < 0) VLOG(1) << "setsockopt(CORK): " << errno; } void SMConnection::UncorkSocket() { int state = 0; int rv = setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &state, sizeof(state)); if (rv < 0) VLOG(1) << "setsockopt(CORK): " << errno; } int SMConnection::Send(const char* data, int len, int flags) { int rv = 0; CorkSocket(); if (ssl_) { ssize_t bytes_written = 0; // Write smallish chunks to SSL so that we don't have large // multi-packet TLS records to receive before being able to handle // the data. We don't have to be too careful here, because our data // frames are already getting chunked appropriately, and those are // the most likely "big" frames. while (len > 0) { const int kMaxTLSRecordSize = 1500; const char* ptr = &(data[bytes_written]); int chunksize = std::min(len, kMaxTLSRecordSize); rv = SSL_write(ssl_, ptr, chunksize); VLOG(2) << "SSLWrite(" << chunksize << " bytes): " << rv; if (rv <= 0) { switch (SSL_get_error(ssl_, rv)) { case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_ACCEPT: case SSL_ERROR_WANT_CONNECT: rv = -2; break; default: PrintSslError(); break; } break; } bytes_written += rv; len -= rv; if (rv != chunksize) break; // If we couldn't write everything, we're implicitly stalled } // If we wrote some data, return that count. Otherwise // return the stall error. if (bytes_written > 0) rv = bytes_written; } else { rv = send(fd_, data, len, flags); } if (!(flags & MSG_MORE)) UncorkSocket(); return rv; } void SMConnection::OnRegistration(EpollServer* eps, int fd, int event_mask) { registered_in_epoll_server_ = true; } void SMConnection::OnEvent(int fd, EpollEvent* event) { events_ |= event->in_events; HandleEvents(); if (events_) { event->out_ready_mask = events_; events_ = 0; } } void SMConnection::OnUnregistration(int fd, bool replaced) { registered_in_epoll_server_ = false; } void SMConnection::OnShutdown(EpollServer* eps, int fd) { Cleanup("OnShutdown"); return; } void SMConnection::Cleanup(const char* cleanup) { VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Cleanup: " << cleanup; if (!initialized_) return; Reset(); if (connection_pool_) connection_pool_->SMConnectionDone(this); if (sm_interface_) sm_interface_->ResetForNewConnection(); last_read_time_ = 0; } void SMConnection::HandleEvents() { VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Received: " << EpollServer::EventMaskToString(events_).c_str(); if (events_ & EPOLLIN) { if (!DoRead()) goto handle_close_or_error; } if (events_ & EPOLLOUT) { // Check if we have connected or not if (connection_complete_ == false) { int sock_error; socklen_t sock_error_len = sizeof(sock_error); int ret = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &sock_error, &sock_error_len); if (ret != 0) { VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "getsockopt error: " << errno << ": " << strerror(errno); goto handle_close_or_error; } if (sock_error == 0) { connection_complete_ = true; VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Connection complete to " << server_ip_ << ":" << server_port_ << " "; } else if (sock_error == EINPROGRESS) { return; } else { VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "error connecting to server"; goto handle_close_or_error; } } if (!DoWrite()) goto handle_close_or_error; } if (events_ & (EPOLLHUP | EPOLLERR)) { VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "!!! Got HUP or ERR"; goto handle_close_or_error; } return; handle_close_or_error: Cleanup("HandleEvents"); } // Decide if SPDY was negotiated. bool SMConnection::WasSpdyNegotiated(SpdyMajorVersion* version_negotiated) { *version_negotiated = SPDY3; if (force_spdy()) return true; // If this is an SSL connection, check if NPN specifies SPDY. if (ssl_) { const unsigned char* npn_proto; unsigned int npn_proto_len; SSL_get0_next_proto_negotiated(ssl_, &npn_proto, &npn_proto_len); if (npn_proto_len > 0) { std::string npn_proto_str((const char*)npn_proto, npn_proto_len); VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "NPN protocol detected: " << npn_proto_str; if (!strncmp(reinterpret_cast(npn_proto), "spdy/2", npn_proto_len)) { *version_negotiated = SPDY2; return true; } if (!strncmp(reinterpret_cast(npn_proto), "spdy/3", npn_proto_len)) { *version_negotiated = SPDY3; return true; } if (!strncmp(reinterpret_cast(npn_proto), "spdy/4a2", npn_proto_len)) { *version_negotiated = SPDY4; return true; } } } return false; } bool SMConnection::SetupProtocolInterfaces() { DCHECK(!protocol_detected_); protocol_detected_ = true; SpdyMajorVersion version; bool spdy_negotiated = WasSpdyNegotiated(&version); bool using_ssl = ssl_ != NULL; if (using_ssl) VLOG(1) << (SSL_session_reused(ssl_) ? "Resumed" : "Renegotiated") << " SSL Session."; if (acceptor_->spdy_only_ && !spdy_negotiated) { VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "SPDY proxy only, closing HTTPS connection."; return false; } switch (acceptor_->flip_handler_type_) { case FLIP_HANDLER_HTTP_SERVER: { DCHECK(!spdy_negotiated); VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << (sm_http_interface_ ? "Creating" : "Reusing") << " HTTP interface."; if (!sm_http_interface_) sm_http_interface_ = new HttpSM(this, NULL, memory_cache_, acceptor_); sm_interface_ = sm_http_interface_; break; } case FLIP_HANDLER_PROXY: { VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << (sm_streamer_interface_ ? "Creating" : "Reusing") << " PROXY Streamer interface."; if (!sm_streamer_interface_) { sm_streamer_interface_ = new StreamerSM(this, NULL, epoll_server_, acceptor_); sm_streamer_interface_->set_is_request(); } sm_interface_ = sm_streamer_interface_; // If spdy is not negotiated, the streamer interface will proxy all // data to the origin server. if (!spdy_negotiated) break; } // Otherwise fall through into the case below. case FLIP_HANDLER_SPDY_SERVER: { DCHECK(spdy_negotiated); VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << (sm_spdy_interface_ ? "Creating" : "Reusing") << " SPDY interface."; if (!sm_spdy_interface_) sm_spdy_interface_ = new SpdySM( this, NULL, epoll_server_, memory_cache_, acceptor_, version); sm_interface_ = sm_spdy_interface_; break; } } CorkSocket(); if (!sm_interface_->PostAcceptHook()) return false; return true; } bool SMConnection::DoRead() { VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "DoRead()"; while (!read_buffer_.Full()) { char* bytes; int size; if (fd_ == -1) { VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "DoRead(): fd_ == -1. Invalid FD. Returning false"; return false; } read_buffer_.GetWritablePtr(&bytes, &size); ssize_t bytes_read = 0; if (ssl_) { bytes_read = SSL_read(ssl_, bytes, size); if (bytes_read < 0) { int err = SSL_get_error(ssl_, bytes_read); switch (err) { case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_ACCEPT: case SSL_ERROR_WANT_CONNECT: events_ &= ~EPOLLIN; VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "DoRead: SSL WANT_XXX: " << err; goto done; default: PrintSslError(); goto error_or_close; } } } else { bytes_read = recv(fd_, bytes, size, MSG_DONTWAIT); } int stored_errno = errno; if (bytes_read == -1) { switch (stored_errno) { case EAGAIN: events_ &= ~EPOLLIN; VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Got EAGAIN while reading"; goto done; case EINTR: VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Got EINTR while reading"; continue; default: VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "While calling recv, got error: " << (ssl_ ? "(ssl error)" : strerror(stored_errno)); goto error_or_close; } } else if (bytes_read > 0) { VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "read " << bytes_read << " bytes"; last_read_time_ = time(NULL); // If the protocol hasn't been detected yet, set up the handlers // we'll need. if (!protocol_detected_) { if (!SetupProtocolInterfaces()) { LOG(ERROR) << "Error setting up protocol interfaces."; goto error_or_close; } } read_buffer_.AdvanceWritablePtr(bytes_read); if (!DoConsumeReadData()) goto error_or_close; continue; } else { // bytes_read == 0 VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "0 bytes read with recv call."; } goto error_or_close; } done: VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "DoRead done!"; return true; error_or_close: VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "DoRead(): error_or_close. " << "Cleaning up, then returning false"; Cleanup("DoRead"); return false; } bool SMConnection::DoConsumeReadData() { char* bytes; int size; read_buffer_.GetReadablePtr(&bytes, &size); while (size != 0) { size_t bytes_consumed = sm_interface_->ProcessReadInput(bytes, size); VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "consumed " << bytes_consumed << " bytes"; if (bytes_consumed == 0) { break; } read_buffer_.AdvanceReadablePtr(bytes_consumed); if (sm_interface_->MessageFullyRead()) { VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "HandleRequestFullyRead: Setting EPOLLOUT"; HandleResponseFullyRead(); events_ |= EPOLLOUT; } else if (sm_interface_->Error()) { LOG(ERROR) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Framer error detected: Setting EPOLLOUT: " << sm_interface_->ErrorAsString(); // this causes everything to be closed/cleaned up. events_ |= EPOLLOUT; return false; } read_buffer_.GetReadablePtr(&bytes, &size); } return true; } void SMConnection::HandleResponseFullyRead() { sm_interface_->Cleanup(); } bool SMConnection::DoWrite() { size_t bytes_sent = 0; int flags = MSG_NOSIGNAL | MSG_DONTWAIT; if (fd_ == -1) { VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "DoWrite: fd == -1. Returning false."; return false; } if (output_list_.empty()) { VLOG(2) << log_prefix_ << "DoWrite: Output list empty."; if (sm_interface_) { sm_interface_->GetOutput(); } if (output_list_.empty()) { events_ &= ~EPOLLOUT; } } while (!output_list_.empty()) { VLOG(2) << log_prefix_ << "DoWrite: Items in output list: " << output_list_.size(); if (bytes_sent >= max_bytes_sent_per_dowrite_) { VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << " byte sent >= max bytes sent per write: Setting EPOLLOUT: " << bytes_sent; events_ |= EPOLLOUT; break; } if (sm_interface_ && output_list_.size() < 2) { sm_interface_->GetOutput(); } DataFrame* data_frame = output_list_.front(); const char* bytes = data_frame->data; int size = data_frame->size; bytes += data_frame->index; size -= data_frame->index; DCHECK_GE(size, 0); if (size <= 0) { output_list_.pop_front(); delete data_frame; continue; } flags = MSG_NOSIGNAL | MSG_DONTWAIT; // Look for a queue size > 1 because |this| frame is remains on the list // until it has finished sending. if (output_list_.size() > 1) { VLOG(2) << log_prefix_ << "Outlist size: " << output_list_.size() << ": Adding MSG_MORE flag"; flags |= MSG_MORE; } VLOG(2) << log_prefix_ << "Attempting to send " << size << " bytes."; ssize_t bytes_written = Send(bytes, size, flags); int stored_errno = errno; if (bytes_written == -1) { switch (stored_errno) { case EAGAIN: events_ &= ~EPOLLOUT; VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Got EAGAIN while writing"; goto done; case EINTR: VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Got EINTR while writing"; continue; default: VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "While calling send, got error: " << stored_errno << ": " << (ssl_ ? "" : strerror(stored_errno)); goto error_or_close; } } else if (bytes_written > 0) { VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Wrote: " << bytes_written << " bytes"; data_frame->index += bytes_written; bytes_sent += bytes_written; continue; } else if (bytes_written == -2) { // -2 handles SSL_ERROR_WANT_* errors events_ &= ~EPOLLOUT; goto done; } VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "0 bytes written with send call."; goto error_or_close; } done: UncorkSocket(); return true; error_or_close: VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "DoWrite: error_or_close. Returning false " << "after cleaning up"; Cleanup("DoWrite"); UncorkSocket(); return false; } void SMConnection::Reset() { VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Resetting"; if (ssl_) { SSL_shutdown(ssl_); PrintSslError(); SSL_free(ssl_); PrintSslError(); ssl_ = NULL; } if (registered_in_epoll_server_) { epoll_server_->UnregisterFD(fd_); registered_in_epoll_server_ = false; } if (fd_ >= 0) { VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Closing connection"; close(fd_); fd_ = -1; } read_buffer_.Clear(); initialized_ = false; protocol_detected_ = false; events_ = 0; for (std::list::iterator i = output_list_.begin(); i != output_list_.end(); ++i) { delete *i; } output_list_.clear(); } // static SMConnection* SMConnection::NewSMConnection(EpollServer* epoll_server, SSLState* ssl_state, MemoryCache* memory_cache, FlipAcceptor* acceptor, std::string log_prefix) { return new SMConnection( epoll_server, ssl_state, memory_cache, acceptor, log_prefix); } } // namespace net