diff options
Diffstat (limited to 'net/tools/flip_server/sm_connection.cc')
-rw-r--r-- | net/tools/flip_server/sm_connection.cc | 648 |
1 files changed, 648 insertions, 0 deletions
diff --git a/net/tools/flip_server/sm_connection.cc b/net/tools/flip_server/sm_connection.cc new file mode 100644 index 0000000..be1e8db --- /dev/null +++ b/net/tools/flip_server/sm_connection.cc @@ -0,0 +1,648 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/flip_server/sm_connection.h" + +#include <errno.h> +#include <netinet/tcp.h> +#include <sys/socket.h> + +#include <list> +#include <string> + +#include "net/tools/flip_server/constants.h" +#include "net/tools/flip_server/flip_config.h" +#include "net/tools/flip_server/http_interface.h" +#include "net/tools/flip_server/spdy_interface.h" +#include "net/tools/flip_server/spdy_ssl.h" +#include "net/tools/flip_server/streamer_interface.h" + +namespace net { + +// static +bool SMConnection::force_spdy_ = false; + +SMConnection::SMConnection(EpollServer* epoll_server, + SSLState* ssl_state, + MemoryCache* memory_cache, + FlipAcceptor* acceptor, + std::string log_prefix) + : last_read_time_(0), + fd_(-1), + events_(0), + registered_in_epoll_server_(false), + initialized_(false), + protocol_detected_(false), + connection_complete_(false), + connection_pool_(NULL), + epoll_server_(epoll_server), + ssl_state_(ssl_state), + memory_cache_(memory_cache), + acceptor_(acceptor), + read_buffer_(kSpdySegmentSize * 40), + sm_spdy_interface_(NULL), + sm_http_interface_(NULL), + sm_streamer_interface_(NULL), + sm_interface_(NULL), + log_prefix_(log_prefix), + max_bytes_sent_per_dowrite_(4096), + ssl_(NULL) { +} + +SMConnection::~SMConnection() { + if (initialized()) + Reset(); +} + +void SMConnection::ReadyToSend() { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Setting ready to send: EPOLLIN | EPOLLOUT"; + epoll_server_->SetFDReady(fd_, EPOLLIN | EPOLLOUT); +} + +void SMConnection::EnqueueDataFrame(DataFrame* df) { + output_list_.push_back(df); + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "EnqueueDataFrame: " + << "size = " << df->size << ": Setting FD ready."; + ReadyToSend(); +} + +void SMConnection::InitSMConnection(SMConnectionPoolInterface* connection_pool, + SMInterface* sm_interface, + EpollServer* epoll_server, + int fd, + std::string server_ip, + std::string server_port, + std::string remote_ip, + bool use_ssl) { + if (initialized_) { + LOG(FATAL) << "Attempted to initialize already initialized server"; + return; + } + + client_ip_ = remote_ip; + + if (fd == -1) { + // If fd == -1, then we are initializing a new connection that will + // connect to the backend. + // + // ret: -1 == error + // 0 == connection in progress + // 1 == connection complete + // TODO(kelindsay): is_numeric_host_address value needs to be detected + server_ip_ = server_ip; + server_port_ = server_port; + int ret = CreateConnectedSocket(&fd_, + server_ip, + server_port, + true, + acceptor_->disable_nagle_); + + if (ret < 0) { + LOG(ERROR) << "-1 Could not create connected socket"; + return; + } else if (ret == 1) { + DCHECK_NE(-1, fd_); + connection_complete_ = true; + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Connection complete to: " << server_ip_ << ":" + << server_port_ << " "; + } + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Connecting to server: " << server_ip_ << ":" + << server_port_ << " "; + } else { + // If fd != -1 then we are initializing a connection that has just been + // accepted from the listen socket. + connection_complete_ = true; + if (epoll_server_ && registered_in_epoll_server_ && fd_ != -1) { + epoll_server_->UnregisterFD(fd_); + } + if (fd_ != -1) { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Closing pre-existing fd"; + close(fd_); + fd_ = -1; + } + + fd_ = fd; + } + + registered_in_epoll_server_ = false; + // Set the last read time here as the idle checker will start from + // now. + last_read_time_ = time(NULL); + initialized_ = true; + + connection_pool_ = connection_pool; + epoll_server_ = epoll_server; + + if (sm_interface) { + sm_interface_ = sm_interface; + protocol_detected_ = true; + } + + read_buffer_.Clear(); + + epoll_server_->RegisterFD(fd_, this, EPOLLIN | EPOLLOUT | EPOLLET); + + if (use_ssl) { + ssl_ = CreateSSLContext(ssl_state_->ssl_ctx); + SSL_set_fd(ssl_, fd_); + PrintSslError(); + } +} + +void SMConnection::CorkSocket() { + int state = 1; + int rv = setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &state, sizeof(state)); + if (rv < 0) + VLOG(1) << "setsockopt(CORK): " << errno; +} + +void SMConnection::UncorkSocket() { + int state = 0; + int rv = setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &state, sizeof(state)); + if (rv < 0) + VLOG(1) << "setsockopt(CORK): " << errno; +} + +int SMConnection::Send(const char* data, int len, int flags) { + int rv = 0; + CorkSocket(); + if (ssl_) { + ssize_t bytes_written = 0; + // Write smallish chunks to SSL so that we don't have large + // multi-packet TLS records to receive before being able to handle + // the data. We don't have to be too careful here, because our data + // frames are already getting chunked appropriately, and those are + // the most likely "big" frames. + while (len > 0) { + const int kMaxTLSRecordSize = 1500; + const char* ptr = &(data[bytes_written]); + int chunksize = std::min(len, kMaxTLSRecordSize); + rv = SSL_write(ssl_, ptr, chunksize); + VLOG(2) << "SSLWrite(" << chunksize << " bytes): " << rv; + if (rv <= 0) { + switch (SSL_get_error(ssl_, rv)) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_ACCEPT: + case SSL_ERROR_WANT_CONNECT: + rv = -2; + break; + default: + PrintSslError(); + break; + } + break; + } + bytes_written += rv; + len -= rv; + if (rv != chunksize) + break; // If we couldn't write everything, we're implicitly stalled + } + // If we wrote some data, return that count. Otherwise + // return the stall error. + if (bytes_written > 0) + rv = bytes_written; + } else { + rv = send(fd_, data, len, flags); + } + if (!(flags & MSG_MORE)) + UncorkSocket(); + return rv; +} + +void SMConnection::OnEvent(int fd, EpollEvent* event) { + events_ |= event->in_events; + HandleEvents(); + if (events_) { + event->out_ready_mask = events_; + events_ = 0; + } +} +void SMConnection::OnShutdown(EpollServer* eps, int fd) { + Cleanup("OnShutdown"); + return; +} + +void SMConnection::Cleanup(const char* cleanup) { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Cleanup: " << cleanup; + if (!initialized_) + return; + Reset(); + if (connection_pool_) + connection_pool_->SMConnectionDone(this); + if (sm_interface_) + sm_interface_->ResetForNewConnection(); + last_read_time_ = 0; +} + +void SMConnection::HandleEvents() { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Received: " + << EpollServer::EventMaskToString(events_).c_str(); + + if (events_ & EPOLLIN) { + if (!DoRead()) + goto handle_close_or_error; + } + + if (events_ & EPOLLOUT) { + // Check if we have connected or not + if (connection_complete_ == false) { + int sock_error; + socklen_t sock_error_len = sizeof(sock_error); + int ret = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &sock_error, + &sock_error_len); + if (ret != 0) { + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "getsockopt error: " << errno << ": " << strerror(errno); + goto handle_close_or_error; + } + if (sock_error == 0) { + connection_complete_ = true; + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Connection complete to " << server_ip_ << ":" + << server_port_ << " "; + } else if (sock_error == EINPROGRESS) { + return; + } else { + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "error connecting to server"; + goto handle_close_or_error; + } + } + if (!DoWrite()) + goto handle_close_or_error; + } + + if (events_ & (EPOLLHUP | EPOLLERR)) { + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "!!! Got HUP or ERR"; + goto handle_close_or_error; + } + return; + + handle_close_or_error: + Cleanup("HandleEvents"); +} + +// Decide if SPDY was negotiated. +bool SMConnection::WasSpdyNegotiated() { + if (force_spdy()) + return true; + + // If this is an SSL connection, check if NPN specifies SPDY. + if (ssl_) { + const unsigned char *npn_proto; + unsigned int npn_proto_len; + SSL_get0_next_proto_negotiated(ssl_, &npn_proto, &npn_proto_len); + if (npn_proto_len > 0) { + std::string npn_proto_str((const char *)npn_proto, npn_proto_len); + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "NPN protocol detected: " << npn_proto_str; + if (!strncmp(reinterpret_cast<const char*>(npn_proto), + "spdy/2", npn_proto_len)) + return true; + } + } + + return false; +} + +bool SMConnection::SetupProtocolInterfaces() { + DCHECK(!protocol_detected_); + protocol_detected_ = true; + + bool spdy_negotiated = WasSpdyNegotiated(); + bool using_ssl = ssl_ != NULL; + + if (using_ssl) + VLOG(1) << (SSL_session_reused(ssl_) ? "Resumed" : "Renegotiated") + << " SSL Session."; + + if (acceptor_->spdy_only_ && !spdy_negotiated) { + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "SPDY proxy only, closing HTTPS connection."; + return false; + } + + switch (acceptor_->flip_handler_type_) { + case FLIP_HANDLER_HTTP_SERVER: + { + DCHECK(!spdy_negotiated); + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << (sm_http_interface_ ? "Creating" : "Reusing") + << " HTTP interface."; + if (!sm_http_interface_) + sm_http_interface_ = new HttpSM(this, + NULL, + epoll_server_, + memory_cache_, + acceptor_); + sm_interface_ = sm_http_interface_; + } + break; + case FLIP_HANDLER_PROXY: + { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << (sm_streamer_interface_ ? "Creating" : "Reusing") + << " PROXY Streamer interface."; + if (!sm_streamer_interface_) + sm_streamer_interface_ = new StreamerSM(this, + NULL, + epoll_server_, + acceptor_); + sm_interface_ = sm_streamer_interface_; + // If spdy is not negotiated, the streamer interface will proxy all + // data to the origin server. + if (!spdy_negotiated) + break; + } + // Otherwise fall through into the case below. + case FLIP_HANDLER_SPDY_SERVER: + { + DCHECK(spdy_negotiated); + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << (sm_spdy_interface_ ? "Creating" : "Reusing") + << " SPDY interface."; + if (!sm_spdy_interface_) + sm_spdy_interface_ = new SpdySM(this, + NULL, + epoll_server_, + memory_cache_, + acceptor_); + sm_interface_ = sm_spdy_interface_; + } + break; + } + + CorkSocket(); + if (!sm_interface_->PostAcceptHook()) + return false; + + return true; +} + +bool SMConnection::DoRead() { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "DoRead()"; + while (!read_buffer_.Full()) { + char* bytes; + int size; + if (fd_ == -1) { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "DoRead(): fd_ == -1. Invalid FD. Returning false"; + return false; + } + read_buffer_.GetWritablePtr(&bytes, &size); + ssize_t bytes_read = 0; + if (ssl_) { + bytes_read = SSL_read(ssl_, bytes, size); + if (bytes_read < 0) { + int err = SSL_get_error(ssl_, bytes_read); + switch (err) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_ACCEPT: + case SSL_ERROR_WANT_CONNECT: + events_ &= ~EPOLLIN; + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "DoRead: SSL WANT_XXX: " << err; + goto done; + default: + PrintSslError(); + goto error_or_close; + } + } + } else { + bytes_read = recv(fd_, bytes, size, MSG_DONTWAIT); + } + int stored_errno = errno; + if (bytes_read == -1) { + switch (stored_errno) { + case EAGAIN: + events_ &= ~EPOLLIN; + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Got EAGAIN while reading"; + goto done; + case EINTR: + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Got EINTR while reading"; + continue; + default: + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "While calling recv, got error: " + << (ssl_?"(ssl error)":strerror(stored_errno)); + goto error_or_close; + } + } else if (bytes_read > 0) { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "read " << bytes_read + << " bytes"; + last_read_time_ = time(NULL); + // If the protocol hasn't been detected yet, set up the handlers + // we'll need. + if (!protocol_detected_) { + if (!SetupProtocolInterfaces()) { + LOG(ERROR) << "Error setting up protocol interfaces."; + goto error_or_close; + } + } + read_buffer_.AdvanceWritablePtr(bytes_read); + if (!DoConsumeReadData()) + goto error_or_close; + continue; + } else { // bytes_read == 0 + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "0 bytes read with recv call."; + } + goto error_or_close; + } + done: + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "DoRead done!"; + return true; + + error_or_close: + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "DoRead(): error_or_close. " + << "Cleaning up, then returning false"; + Cleanup("DoRead"); + return false; +} + +bool SMConnection::DoConsumeReadData() { + char* bytes; + int size; + read_buffer_.GetReadablePtr(&bytes, &size); + while (size != 0) { + size_t bytes_consumed = sm_interface_->ProcessReadInput(bytes, size); + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "consumed " + << bytes_consumed << " bytes"; + if (bytes_consumed == 0) { + break; + } + read_buffer_.AdvanceReadablePtr(bytes_consumed); + if (sm_interface_->MessageFullyRead()) { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "HandleRequestFullyRead: Setting EPOLLOUT"; + HandleResponseFullyRead(); + events_ |= EPOLLOUT; + } else if (sm_interface_->Error()) { + LOG(ERROR) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Framer error detected: Setting EPOLLOUT: " + << sm_interface_->ErrorAsString(); + // this causes everything to be closed/cleaned up. + events_ |= EPOLLOUT; + return false; + } + read_buffer_.GetReadablePtr(&bytes, &size); + } + return true; +} + +void SMConnection::HandleResponseFullyRead() { + sm_interface_->Cleanup(); +} + +bool SMConnection::DoWrite() { + size_t bytes_sent = 0; + int flags = MSG_NOSIGNAL | MSG_DONTWAIT; + if (fd_ == -1) { + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "DoWrite: fd == -1. Returning false."; + return false; + } + if (output_list_.empty()) { + VLOG(2) << log_prefix_ << "DoWrite: Output list empty."; + if (sm_interface_) { + sm_interface_->GetOutput(); + } + if (output_list_.empty()) { + events_ &= ~EPOLLOUT; + } + } + while (!output_list_.empty()) { + VLOG(2) << log_prefix_ << "DoWrite: Items in output list: " + << output_list_.size(); + if (bytes_sent >= max_bytes_sent_per_dowrite_) { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << " byte sent >= max bytes sent per write: Setting EPOLLOUT: " + << bytes_sent; + events_ |= EPOLLOUT; + break; + } + if (sm_interface_ && output_list_.size() < 2) { + sm_interface_->GetOutput(); + } + DataFrame* data_frame = output_list_.front(); + const char* bytes = data_frame->data; + int size = data_frame->size; + bytes += data_frame->index; + size -= data_frame->index; + DCHECK_GE(size, 0); + if (size <= 0) { + output_list_.pop_front(); + delete data_frame; + continue; + } + + flags = MSG_NOSIGNAL | MSG_DONTWAIT; + // Look for a queue size > 1 because |this| frame is remains on the list + // until it has finished sending. + if (output_list_.size() > 1) { + VLOG(2) << log_prefix_ << "Outlist size: " << output_list_.size() + << ": Adding MSG_MORE flag"; + flags |= MSG_MORE; + } + VLOG(2) << log_prefix_ << "Attempting to send " << size << " bytes."; + ssize_t bytes_written = Send(bytes, size, flags); + int stored_errno = errno; + if (bytes_written == -1) { + switch (stored_errno) { + case EAGAIN: + events_ &= ~EPOLLOUT; + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Got EAGAIN while writing"; + goto done; + case EINTR: + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "Got EINTR while writing"; + continue; + default: + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "While calling send, got error: " << stored_errno + << ": " << (ssl_?"":strerror(stored_errno)); + goto error_or_close; + } + } else if (bytes_written > 0) { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Wrote: " + << bytes_written << " bytes"; + data_frame->index += bytes_written; + bytes_sent += bytes_written; + continue; + } else if (bytes_written == -2) { + // -2 handles SSL_ERROR_WANT_* errors + events_ &= ~EPOLLOUT; + goto done; + } + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "0 bytes written with send call."; + goto error_or_close; + } + done: + UncorkSocket(); + return true; + + error_or_close: + VLOG(1) << log_prefix_ << ACCEPTOR_CLIENT_IDENT + << "DoWrite: error_or_close. Returning false " + << "after cleaning up"; + Cleanup("DoWrite"); + UncorkSocket(); + return false; +} + +void SMConnection::Reset() { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Resetting"; + if (ssl_) { + SSL_shutdown(ssl_); + PrintSslError(); + SSL_free(ssl_); + PrintSslError(); + ssl_ = NULL; + } + if (registered_in_epoll_server_) { + epoll_server_->UnregisterFD(fd_); + registered_in_epoll_server_ = false; + } + if (fd_ >= 0) { + VLOG(2) << log_prefix_ << ACCEPTOR_CLIENT_IDENT << "Closing connection"; + close(fd_); + fd_ = -1; + } + read_buffer_.Clear(); + initialized_ = false; + protocol_detected_ = false; + events_ = 0; + for (std::list<DataFrame*>::iterator i = + output_list_.begin(); + i != output_list_.end(); + ++i) { + delete *i; + } + output_list_.clear(); +} + +// static +SMConnection* SMConnection::NewSMConnection(EpollServer* epoll_server, + SSLState *ssl_state, + MemoryCache* memory_cache, + FlipAcceptor *acceptor, + std::string log_prefix) { + return new SMConnection(epoll_server, ssl_state, memory_cache, + acceptor, log_prefix); +} + +} // namespace net + + |