diff options
author | dilmah@chromium.org <dilmah@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-10-14 17:48:07 +0000 |
---|---|---|
committer | dilmah@chromium.org <dilmah@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-10-14 17:48:07 +0000 |
commit | 5ddffb822144de64aa36bd1aa4fc0aad634a3454 (patch) | |
tree | 1319b6250421e165159e82288df605ae7d245b5c /chrome/browser/chromeos/web_socket_proxy.cc | |
parent | 3b99e2ef219b43cdc51a3313268e6c9a27aa1be9 (diff) | |
download | chromium_src-5ddffb822144de64aa36bd1aa4fc0aad634a3454.zip chromium_src-5ddffb822144de64aa36bd1aa4fc0aad634a3454.tar.gz chromium_src-5ddffb822144de64aa36bd1aa4fc0aad634a3454.tar.bz2 |
Support SSL connections in websocket-to-TCP proxy.
For historical reasons current implementation of WS-to-TCP proxy was implemented as standalone libevent-based server.
Then it was integrated into chromium as is.
In order to support SSL we need to connect libevent-based proxy with MessageLoopForIO-based chromium network stack.
We do it using pipes.
It is intended as temporary solution until we will have
new shiny implementation of WS-to-TCP proxy integrated into network stack.
BUG=chromium-os:15533
TEST=Manual
Review URL: http://codereview.chromium.org/8087001
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@105515 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'chrome/browser/chromeos/web_socket_proxy.cc')
-rw-r--r-- | chrome/browser/chromeos/web_socket_proxy.cc | 771 |
1 files changed, 651 insertions, 120 deletions
diff --git a/chrome/browser/chromeos/web_socket_proxy.cc b/chrome/browser/chromeos/web_socket_proxy.cc index 6dce729..03c2d60 100644 --- a/chrome/browser/chromeos/web_socket_proxy.cc +++ b/chrome/browser/chromeos/web_socket_proxy.cc @@ -26,6 +26,7 @@ #include "base/base64.h" #include "base/basictypes.h" #include "base/logging.h" +#include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/sha1.h" #include "base/stl_util.h" @@ -36,9 +37,21 @@ #include "chrome/common/url_constants.h" #include "content/browser/browser_thread.h" #include "content/common/content_notification_types.h" +#include "content/common/notification_details.h" #include "content/common/notification_service.h" #include "content/public/common/url_constants.h" #include "googleurl/src/gurl.h" +#include "googleurl/src/url_parse.h" +#include "net/base/address_list.h" +#include "net/base/cert_verifier.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/ssl_config_service.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/stream_socket.h" #include "third_party/libevent/evdns.h" #include "third_party/libevent/event.h" @@ -131,29 +144,10 @@ std::string FetchAsciiSnippet(uint8* begin, uint8* end, AsciiFilter filter) { return rv; } -// Returns true on success. -bool FetchDecimalDigits(const std::string& s, uint32* result) { - *result = 0; - bool got_something = false; - for (size_t i = 0; i < s.size(); ++i) { - if (IsAsciiDigit(s[i])) { - got_something = true; - if (*result > std::numeric_limits<uint32>::max() / 10) - return false; - *result *= 10; - int digit = s[i] - '0'; - if (*result > std::numeric_limits<uint32>::max() - digit) - return false; - *result += digit; - } - } - return got_something; -} - // Parses "passport:hostname:port:" string. Returns true on success. bool FetchPassportNamePort( uint8* begin, uint8* end, - std::string* passport, std::string* name, uint32* port) { + std::string* passport, std::string* name, int* port) { std::string input(begin, end); if (input[input.size() - 1] != ':') return false; @@ -169,8 +163,8 @@ bool FetchPassportNamePort( COMPILE_ASSERT(sizeof(kAsciiDigits) == 10 + 1, mess_with_digits); if (port_str.find_first_not_of(kAsciiDigits) != std::string::npos) return false; - if (!FetchDecimalDigits(port_str, port) || - *port <= 0 || + if (!base::StringToInt(port_str, port) || + *port < 0 || *port >= (1 << 16)) { return false; } @@ -196,11 +190,11 @@ inline size_t strlen(const void* s) { return ::strlen(static_cast<const char*>(s)); } -void SendNotification() { +void SendNotification(int port) { DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); NotificationService::current()->Notify( chrome::NOTIFICATION_WEB_SOCKET_PROXY_STARTED, - NotificationService::AllSources(), NotificationService::NoDetails()); + NotificationService::AllSources(), Details<int>(&port)); } class Conn; @@ -208,8 +202,7 @@ class Conn; // Websocket to TCP proxy server. class Serv { public: - Serv(const std::vector<std::string>& allowed_origins, - struct sockaddr* addr, int addr_len); + explicit Serv(const std::vector<std::string>& allowed_origins); ~Serv(); // Do not call it twice. @@ -236,16 +229,17 @@ class Serv { // in a client websocket handshake. std::vector<std::string> allowed_origins_; - // Address to listen incoming websocket connections. - struct sockaddr* addr_; - int addr_len_; - // Libevent base. struct event_base* evbase_; // Socket to listen incoming websocket connections. int listening_sock_; + // TODO(dilmah): remove this extra socket as soon as obsolete + // getPassportForTCP function is removed from webSocketProxyPrivate API. + // Additional socket to listen incoming connections on fixed port 10101. + int extra_listening_sock_; + // Used to communicate control requests: either shutdown request or network // change notification. int control_descriptor_[2]; @@ -264,6 +258,9 @@ class Serv { scoped_ptr<struct event> connection_event_; scoped_ptr<struct event> control_event_; + // TODO(dilmah): remove this extra event as soon as obsolete + // getPassportForTCP function is removed from webSocketProxyPrivate API. + scoped_ptr<struct event> extra_connection_event_; DISALLOW_COPY_AND_ASSIGN(Serv); }; @@ -288,11 +285,16 @@ class Conn { PHASE_DEFUNCT // Connection was nuked. }; - // Channel structure (either proxy<->javascript or proxy<->destination). + // Channel structure (either proxy<->browser or proxy<->destination). class Chan { public: explicit Chan(Conn* master) - : master_(master), sock_(-1), bev_(NULL), write_pending_(false) { + : master_(master), + write_pending_(false), + read_bev_(NULL), + write_bev_(NULL), + read_fd_(-1), + write_fd_(-1) { } ~Chan() { @@ -301,23 +303,33 @@ class Conn { // Returns true on success. bool Write(const void* data, size_t size) { - if (bev_ == NULL || sock_ < 0) + if (write_bev_ == NULL) return false; write_pending_ = true; - return (0 == bufferevent_write(bev_, data, size)); + return (0 == bufferevent_write(write_bev_, data, size)); } void Zap() { - if (bev_) { - bufferevent_disable(bev_, EV_READ | EV_WRITE); - bufferevent_free(bev_); - bev_ = NULL; + if (read_bev_) { + bufferevent_disable(read_bev_, EV_READ); + bufferevent_free(read_bev_); } - if (sock_ >= 0) { - shutdown(sock_, SHUT_RDWR); - close(sock_); - sock_ = -1; + if (write_bev_ && write_bev_ != read_bev_) { + bufferevent_disable(write_bev_, EV_READ); + bufferevent_free(write_bev_); } + read_bev_ = NULL; + write_bev_ = NULL; + if (write_fd_ && read_fd_ == write_fd_) + shutdown(write_fd_, SHUT_RDWR); + if (write_fd_ >= 0) { + close(write_fd_); + DCHECK_GE(read_fd_, 0); + } + if (read_fd_ && read_fd_ != write_fd_) + close(read_fd_); + read_fd_ = -1; + write_fd_ = -1; write_pending_ = false; master_->ConsiderSuicide(); } @@ -327,15 +339,27 @@ class Conn { Zap(); } - int& sock() { return sock_; } - bool& write_pending() { return write_pending_; } - struct bufferevent*& bev() { return bev_; } + int read_fd() const { return read_fd_; } + void set_read_fd(int fd) { read_fd_ = fd; } + int write_fd() const { return write_fd_; } + void set_write_fd(int fd) { write_fd_ = fd; } + bool write_pending() const { return write_pending_; } + void set_write_pending(bool pending) { write_pending_ = pending; } + struct bufferevent* read_bev() const { return read_bev_; } + void set_read_bev(struct bufferevent* bev) { read_bev_ = bev; } + struct bufferevent* write_bev() const { return write_bev_; } + void set_write_bev(struct bufferevent* bev) { write_bev_ = bev; } private: Conn* master_; - int sock_; // UNIX descriptor. - struct bufferevent* bev_; bool write_pending_; // Whether write buffer is not flushed yet. + struct bufferevent* read_bev_; + struct bufferevent* write_bev_; + // UNIX descriptors. + int read_fd_; + int write_fd_; + + DISALLOW_COPY_AND_ASSIGN(Chan); }; // Status of processing incoming data. @@ -415,10 +439,16 @@ class Conn { // Header fields supplied by client at initial websocket handshake. std::map<std::string, std::string> header_fields_; + // Parameters requested via query component of GET resource. + std::map<std::string, std::string> requested_parameters_; + // Hostname and port of destination socket. // Websocket client supplies them in first data frame (destframe). std::string destname_; - uint32 destport_; + int destport_; + + // Whether TLS over TCP requested. + bool do_tls_; // We try to DNS resolve hostname in both IPv4 and IPv6 domains. // Track resolution failures here. @@ -434,14 +464,371 @@ class Conn { DISALLOW_COPY_AND_ASSIGN(Conn); }; -Serv::Serv( - const std::vector<std::string>& allowed_origins, - struct sockaddr* addr, int addr_len) +class SSLChan : public MessageLoopForIO::Watcher { + public: + static void Start(const net::AddressList& address_list, + const net::HostPortPair& host_port_pair, + int read_pipe, + int write_pipe) { + DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); + SSLChan* ALLOW_UNUSED chan = new SSLChan( + address_list, host_port_pair, read_pipe, write_pipe); + } + + private: + enum Phase { + PHASE_CONNECTING, + PHASE_RUNNING, + PHASE_CLOSING, + PHASE_CLOSED + }; + + class DerivedIOBufferWithSize : public net::IOBufferWithSize { + public: + DerivedIOBufferWithSize(net::IOBuffer* host, int size) + : IOBufferWithSize(host->data(), size), host_(host) { + DCHECK(host_); + DCHECK(host_->data()); + } + + virtual ~DerivedIOBufferWithSize() { + data_ = NULL; // We do not own memory, bypass base class destructor. + } + + protected: + scoped_refptr<net::IOBuffer> host_; + }; + + // Provides queue of data represented as IOBuffers. + class IOBufferQueue { + public: + // We do not allocate all capacity at once but lazily in |buf_size_| chunks. + explicit IOBufferQueue(int capacity) + : buf_size_(1 + capacity / kNumBuffersLimit) { + } + + // Obtains IOBuffer to add new data to back. + net::IOBufferWithSize* GetIOBufferToFill() { + if (back_ == NULL) { + if (storage_.size() >= kNumBuffersLimit) + return NULL; + storage_.push_back(new net::IOBufferWithSize(buf_size_)); + back_ = new net::DrainableIOBuffer(storage_.back(), buf_size_); + } + return new DerivedIOBufferWithSize( + back_.get(), back_->BytesRemaining()); + } + + // Obtains IOBuffer with some data from front. + net::IOBufferWithSize* GetIOBufferToProcess() { + if (front_ == NULL) { + if (storage_.empty()) + return NULL; + front_ = new net::DrainableIOBuffer(storage_.front(), buf_size_); + } + int front_capacity = (storage_.size() == 1 && back_) ? + back_->BytesConsumed() : buf_size_; + return new DerivedIOBufferWithSize( + front_.get(), front_capacity - front_->BytesConsumed()); + } + + // Records number of bytes as added to back. + void DidFill(int bytes) { + DCHECK(back_); + back_->DidConsume(bytes); + if (back_->BytesRemaining() == 0) + back_ = NULL; + } + + // Pops number of bytes from front. + void DidProcess(int bytes) { + DCHECK(front_); + front_->DidConsume(bytes); + if (front_->BytesRemaining() == 0) { + storage_.pop_front(); + front_ = NULL; + } + } + + void Clear() { + front_ = NULL; + back_ = NULL; + storage_.clear(); + } + + private: + static const unsigned kNumBuffersLimit = 12; + const int buf_size_; + std::list< scoped_refptr<net::IOBufferWithSize> > storage_; + scoped_refptr<net::DrainableIOBuffer> front_; + scoped_refptr<net::DrainableIOBuffer> back_; + + DISALLOW_COPY_AND_ASSIGN(IOBufferQueue); + }; + + SSLChan(const net::AddressList address_list, + const net::HostPortPair host_port_pair, + int read_pipe, + int write_pipe) + : phase_(PHASE_CONNECTING), + host_port_pair_(host_port_pair), + inbound_stream_(WebSocketProxy::kBufferLimit), + outbound_stream_(WebSocketProxy::kBufferLimit), + read_pipe_(read_pipe), + write_pipe_(write_pipe), + method_factory_(this), + socket_connect_callback_(NewCallback(this, &SSLChan::OnSocketConnect)), + ssl_handshake_callback_( + NewCallback(this, &SSLChan::OnSSLHandshakeCompleted)), + socket_read_callback_(NewCallback(this, &SSLChan::OnSocketRead)), + socket_write_callback_(NewCallback(this, &SSLChan::OnSocketWrite)) { + if (!SetNonBlock(read_pipe_) || !SetNonBlock(write_pipe_)) { + Shut(net::ERR_UNEXPECTED); + return; + } + net::ClientSocketFactory* factory = + net::ClientSocketFactory::GetDefaultFactory(); + socket_.reset(factory->CreateTransportClientSocket( + address_list, NULL, net::NetLog::Source())); + if (socket_ == NULL) { + Shut(net::ERR_FAILED); + return; + } + int result = socket_->Connect(socket_connect_callback_.get()); + if (result != net::ERR_IO_PENDING) + OnSocketConnect(result); + } + + ~SSLChan() { + phase_ = PHASE_CLOSED; + write_pipe_controller_.StopWatchingFileDescriptor(); + read_pipe_controller_.StopWatchingFileDescriptor(); + close(write_pipe_); + close(read_pipe_); + } + + void Shut(int ALLOW_UNUSED net_error_code) { + if (phase_ != PHASE_CLOSED) { + phase_ = PHASE_CLOSING; + scoped_refptr<net::IOBufferWithSize> buf[] = { + outbound_stream_.GetIOBufferToProcess(), + inbound_stream_.GetIOBufferToProcess() + }; + for (int i = arraysize(buf); i--;) { + if (buf[i] && buf[i]->size() > 0) { + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod(&SSLChan::Proceed)); + return; + } + } + phase_ = PHASE_CLOSED; + if (socket_ != NULL) { + socket_->Disconnect(); + socket_.reset(); + } + delete this; + } + } + + void OnSocketConnect(int result) { + if (phase_ != PHASE_CONNECTING) { + NOTREACHED(); + return; + } + if (result) { + Shut(result); + return; + } + net::ClientSocketHandle* handle = new net::ClientSocketHandle(); + handle->set_socket(socket_.release()); + net::ClientSocketFactory* factory = + net::ClientSocketFactory::GetDefaultFactory(); + net::SSLClientSocketContext ssl_context; + if (!cert_verifier_.get()) + cert_verifier_.reset(new net::CertVerifier()); + ssl_context.cert_verifier = cert_verifier_.get(); + socket_.reset(factory->CreateSSLClientSocket( + handle, host_port_pair_, ssl_config_, NULL, ssl_context)); + if (!socket_.get()) { + LOG(WARNING) << "Failed to create an SSL client socket."; + OnSSLHandshakeCompleted(net::ERR_UNEXPECTED); + return; + } + result = socket_->Connect(ssl_handshake_callback_.get()); + if (result != net::ERR_IO_PENDING) + OnSSLHandshakeCompleted(result); + } + + void OnSSLHandshakeCompleted(int result) { + if (result) + Shut(result); + is_socket_read_pending_ = false; + is_socket_write_pending_ = false; + is_read_pipe_blocked_ = false; + is_write_pipe_blocked_ = false; + MessageLoopForIO::current()->WatchFileDescriptor( + read_pipe_, false, MessageLoopForIO::WATCH_READ, + &read_pipe_controller_, this); + MessageLoopForIO::current()->WatchFileDescriptor( + write_pipe_, false, MessageLoopForIO::WATCH_WRITE, + &write_pipe_controller_, this); + phase_ = PHASE_RUNNING; + Proceed(); + } + + void OnSocketRead(int result) { + DCHECK(is_socket_read_pending_); + is_socket_read_pending_ = false; + if (result <= 0) { + Shut(result); + return; + } + inbound_stream_.DidFill(result); + Proceed(); + } + + void OnSocketWrite(int result) { + DCHECK(is_socket_write_pending_); + is_socket_write_pending_ = false; + if (result < 0) { + outbound_stream_.Clear(); + Shut(result); + return; + } + outbound_stream_.DidProcess(result); + Proceed(); + } + + // MessageLoopForIO::Watcher overrides. + virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE { + if (fd != read_pipe_) { + NOTREACHED(); + return; + } + is_read_pipe_blocked_ = false; + Proceed(); + } + + virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE { + if (fd != write_pipe_) { + NOTREACHED(); + return; + } + is_write_pipe_blocked_ = false; + Proceed(); + } + + private: + void Proceed() { + if (phase_ != PHASE_RUNNING && phase_ != PHASE_CLOSING) + return; + for (bool proceed = true; proceed;) { + proceed = false; + if (!is_read_pipe_blocked_ && phase_ == PHASE_RUNNING) { + scoped_refptr<net::IOBufferWithSize> buf = + outbound_stream_.GetIOBufferToFill(); + if (buf && buf->size() > 0) { + int rv = read(read_pipe_, buf->data(), buf->size()); + if (rv > 0) { + outbound_stream_.DidFill(rv); + proceed = true; + } else if (rv == -1 && errno == EAGAIN) { + is_read_pipe_blocked_ = true; + MessageLoopForIO::current()->WatchFileDescriptor( + read_pipe_, false, MessageLoopForIO::WATCH_READ, + &read_pipe_controller_, this); + } else if (rv == 0) { + Shut(0); + } else { + DCHECK_LT(rv, 0); + Shut(net::ERR_UNEXPECTED); + return; + } + } + } + if (!is_socket_read_pending_ && phase_ == PHASE_RUNNING) { + scoped_refptr<net::IOBufferWithSize> buf = + inbound_stream_.GetIOBufferToFill(); + if (buf && buf->size() > 0) { + int rv = socket_->Read(buf, buf->size(), socket_read_callback_.get()); + is_socket_read_pending_ = true; + if (rv != net::ERR_IO_PENDING) { + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod(&SSLChan::OnSocketRead, rv)); + } + } + } + if (!is_socket_write_pending_) { + scoped_refptr<net::IOBufferWithSize> buf = + outbound_stream_.GetIOBufferToProcess(); + if (buf && buf->size() > 0) { + int rv = socket_->Write( + buf, buf->size(), socket_write_callback_.get()); + is_socket_write_pending_ = true; + if (rv != net::ERR_IO_PENDING) { + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod(&SSLChan::OnSocketWrite, rv)); + } + } else if (phase_ == PHASE_CLOSING) { + Shut(0); + } + } + if (!is_write_pipe_blocked_) { + scoped_refptr<net::IOBufferWithSize> buf = + inbound_stream_.GetIOBufferToProcess(); + if (buf && buf->size() > 0) { + int rv = write(write_pipe_, buf->data(), buf->size()); + if (rv > 0) { + inbound_stream_.DidProcess(rv); + proceed = true; + } else if (rv == -1 && errno == EAGAIN) { + is_write_pipe_blocked_ = true; + MessageLoopForIO::current()->WatchFileDescriptor( + write_pipe_, false, MessageLoopForIO::WATCH_WRITE, + &write_pipe_controller_, this); + } else { + DCHECK_LE(rv, 0); + inbound_stream_.Clear(); + Shut(net::ERR_UNEXPECTED); + return; + } + } else if (phase_ == PHASE_CLOSING) { + Shut(0); + } + } + } + } + + Phase phase_; + scoped_ptr<net::StreamSocket> socket_; + net::HostPortPair host_port_pair_; + scoped_ptr<net::CertVerifier> cert_verifier_; + net::SSLConfig ssl_config_; + IOBufferQueue inbound_stream_; + IOBufferQueue outbound_stream_; + int read_pipe_; + int write_pipe_; + bool is_socket_read_pending_; + bool is_socket_write_pending_; + bool is_read_pipe_blocked_; + bool is_write_pipe_blocked_; + ScopedRunnableMethodFactory<SSLChan> method_factory_; + scoped_ptr<net::OldCompletionCallback> socket_connect_callback_; + scoped_ptr<net::OldCompletionCallback> ssl_handshake_callback_; + scoped_ptr<net::OldCompletionCallback> socket_read_callback_; + scoped_ptr<net::OldCompletionCallback> socket_write_callback_; + MessageLoopForIO::FileDescriptorWatcher read_pipe_controller_; + MessageLoopForIO::FileDescriptorWatcher write_pipe_controller_; + + friend class base::RefCountedThreadSafe<SSLChan>; + DISALLOW_COPY_AND_ASSIGN(SSLChan); +}; + +Serv::Serv(const std::vector<std::string>& allowed_origins) : allowed_origins_(allowed_origins), - addr_(addr), - addr_len_(addr_len), evbase_(NULL), listening_sock_(-1), + extra_listening_sock_(-1), shutdown_requested_(false) { std::sort(allowed_origins_.begin(), allowed_origins_.end()); control_descriptor_[0] = -1; @@ -474,7 +861,19 @@ void Serv::Run() { LOG(ERROR) << "WebSocketProxy: Failed to create socket"; return; } - if (bind(listening_sock_, addr_, addr_len_)) { + { + int on = 1; + setsockopt(listening_sock_, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); + } + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(0); // let OS allocatate ephemeral port number. + addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + if (bind(listening_sock_, + reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr))) { LOG(ERROR) << "WebSocketProxy: Failed to bind server socket"; return; } @@ -482,10 +881,6 @@ void Serv::Run() { LOG(ERROR) << "WebSocketProxy: Failed to listen server socket"; return; } - { - int on = 1; - setsockopt(listening_sock_, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); - } if (!SetNonBlock(listening_sock_)) { LOG(ERROR) << "WebSocketProxy: Failed to go non block"; return; @@ -500,6 +895,48 @@ void Serv::Run() { return; } + { + // TODO(dilmah): remove this control block as soon as obsolete + // getPassportForTCP function is removed from webSocketProxyPrivate API. + // Following block adds extra listening socket on fixed port 10101. + extra_listening_sock_ = socket(AF_INET, SOCK_STREAM, 0); + if (extra_listening_sock_ < 0) { + LOG(ERROR) << "WebSocketProxy: Failed to create socket"; + return; + } + { + int on = 1; + setsockopt(listening_sock_, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); + } + const int kPort = 10101; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(kPort); + addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + if (bind(extra_listening_sock_, + reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr))) { + LOG(ERROR) << "WebSocketProxy: Failed to bind server socket"; + return; + } + if (listen(extra_listening_sock_, 12)) { + LOG(ERROR) << "WebSocketProxy: Failed to listen server socket"; + return; + } + if (!SetNonBlock(extra_listening_sock_)) { + LOG(ERROR) << "WebSocketProxy: Failed to go non block"; + return; + } + extra_connection_event_.reset(new struct event); + event_set(extra_connection_event_.get(), extra_listening_sock_, + EV_READ | EV_PERSIST, &OnConnect, this); + event_base_set(evbase_, extra_connection_event_.get()); + if (event_add(extra_connection_event_.get(), NULL)) { + LOG(ERROR) << "WebSocketProxy: Failed to add listening event"; + return; + } + } + control_event_.reset(new struct event); event_set(control_event_.get(), control_descriptor_[0], EV_READ | EV_PERSIST, &OnControlRequest, this); @@ -516,9 +953,16 @@ void Serv::Run() { return; } + memset(&addr, 0, sizeof(addr)); + socklen_t addr_len = sizeof(addr); + if (getsockname( + listening_sock_, reinterpret_cast<struct sockaddr*>(&addr), &addr_len)) { + LOG(ERROR) << "Failed to determine listening port"; + return; + } BrowserThread::PostTask( BrowserThread::UI, FROM_HERE, - NewRunnableFunction(&SendNotification)); + NewRunnableFunction(&SendNotification, ntohs(addr.sin_port))); LOG(INFO) << "WebSocketProxy: Starting event dispatch loop."; event_base_dispatch(evbase_); @@ -556,6 +1000,10 @@ void Serv::CloseAll() { event_del(control_event_.get()); control_event_.reset(); } + if (extra_connection_event_.get()) { + event_del(extra_connection_event_.get()); + extra_connection_event_.reset(); + } if (connection_event_.get()) { event_del(connection_event_.get()); connection_event_.reset(); @@ -634,28 +1082,30 @@ bool Serv::IsOriginAllowed(const std::string& origin) { void Serv::OnConnect(int listening_sock, short event, void* ctx) { Serv* self = static_cast<Serv*>(ctx); Conn* cs = self->GetFreshConn(); - cs->primchan().sock() = accept(listening_sock, NULL, NULL); - if (cs->primchan().sock() < 0 - || !SetNonBlock(cs->primchan().sock())) { + int sock = accept(listening_sock, NULL, NULL); + if (sock < 0 || !SetNonBlock(sock)) { // Read readiness was triggered on listening socket // yet we failed to accept a connection; definitely weird. NOTREACHED(); self->ZapConn(cs); return; } + cs->primchan().set_read_fd(sock); + cs->primchan().set_write_fd(sock); - cs->primchan().bev() = bufferevent_new( - cs->primchan().sock(), + struct bufferevent* bev = bufferevent_new( + sock, &Conn::OnPrimchanRead, &Conn::OnPrimchanWrite, &Conn::OnPrimchanError, cs->evkey()); - if (cs->primchan().bev() == NULL) { + if (bev == NULL) { self->ZapConn(cs); return; } - bufferevent_base_set(self->evbase_, cs->primchan().bev()); - bufferevent_setwatermark( - cs->primchan().bev(), EV_READ, 0, WebSocketProxy::kReadBufferLimit); - if (bufferevent_enable(cs->primchan().bev(), EV_READ | EV_WRITE)) { + cs->primchan().set_read_bev(bev); + cs->primchan().set_write_bev(bev); + bufferevent_base_set(self->evbase_, bev); + bufferevent_setwatermark(bev, EV_READ, 0, WebSocketProxy::kBufferLimit); + if (bufferevent_enable(bev, EV_READ | EV_WRITE)) { self->ZapConn(cs); return; } @@ -684,12 +1134,14 @@ Conn::Conn(Serv* master) frame_mask_index_(0), primchan_(this), destchan_(this), + do_tls_(false), destresolution_ipv4_failed_(false), destresolution_ipv6_failed_(false) { while (evkey_map_.find(last_evkey_) != evkey_map_.end()) { - evkey_ = last_evkey_ = - reinterpret_cast<EventKey>(reinterpret_cast<size_t>(last_evkey_) + 1); + last_evkey_ = reinterpret_cast<EventKey>(reinterpret_cast<size_t>( + last_evkey_) + 1); } + evkey_ = last_evkey_; evkey_map_[evkey_] = this; // Schedule timeout for initial phase of connection. destconnect_timeout_event_.reset(new struct event); @@ -754,15 +1206,15 @@ Conn::Status Conn::ConsumeHeader(struct evbuffer* evb) { uint8* buf = EVBUFFER_DATA(evb); size_t buf_size = EVBUFFER_LENGTH(evb); - static const uint8 kGetMagic[] = "GET " kProxyPath " "; + static const uint8 kGetPrefix[] = "GET " kProxyPath; static const uint8 kKeyValueDelimiter[] = ": "; if (buf_size <= 0) return STATUS_INCOMPLETE; if (!buf) return STATUS_ABORT; - if (!std::equal(buf, buf + std::min(buf_size, strlen(kGetMagic)), - kGetMagic)) { + if (!std::equal(buf, buf + std::min(buf_size, strlen(kGetPrefix)), + kGetPrefix)) { // Data head does not match what is expected. return STATUS_ABORT; } @@ -770,14 +1222,36 @@ Conn::Status Conn::ConsumeHeader(struct evbuffer* evb) { if (buf_size >= WebSocketProxy::kHeaderLimit) return STATUS_ABORT; uint8* buf_end = buf + buf_size; + // Handshake request must end with double CRLF. uint8* term_pos = std::search(buf, buf_end, kCRLFCRLF, - kCRLFCRLF + strlen(kCRLFCRLF)); + kCRLFCRLF + strlen(kCRLFCRLF)); + if (term_pos == buf_end) + return STATUS_INCOMPLETE; term_pos += strlen(kCRLFCRLF); - // First line is "GET /tcpproxy" line, so we skip it. - uint8* pos = std::search(buf, term_pos, kCRLF, kCRLF + strlen(kCRLF)); - if (pos == term_pos) + // First line is "GET path?query protocol" line. If query is empty then we + // fall back to (obsolete) way of obtaining parameters from first websocket + // frame. Otherwise query contains all required parameters (host, port etc). + uint8* get_request_end = std::search( + buf, term_pos, kCRLF, kCRLF + strlen(kCRLF)); + DCHECK(get_request_end != term_pos); + uint8* resource_end = std::find( + buf + strlen(kGetPrefix), get_request_end, ' '); + if (*resource_end != ' ') return STATUS_ABORT; - for (;;) { + if (resource_end != buf + strlen(kGetPrefix)) { + char* piece = reinterpret_cast<char*>(buf) + strlen(kGetPrefix) + 1; + url_parse::Component query( + 0, resource_end - reinterpret_cast<uint8*>(piece)); + for (url_parse::Component key, value; + url_parse::ExtractQueryKeyValue(piece, &query, &key, &value);) { + if (key.len > 0) { + requested_parameters_[std::string(piece + key.begin, key.len)] = + net::UnescapeURLComponent(std::string(piece + value.begin, + value.len), UnescapeRule::URL_SPECIAL_CHARS); + } + } + } + for (uint8* pos = get_request_end;;) { pos += strlen(kCRLF); if (term_pos - pos < static_cast<ptrdiff_t>(strlen(kCRLF))) return STATUS_ABORT; @@ -812,13 +1286,29 @@ Conn::Status Conn::ConsumeHeader(struct evbuffer* evb) { GURL origin = GURL(GetOrigin()).GetOrigin(); if (!origin.is_valid()) return STATUS_ABORT; - // Here we check origin. This check may seem redundant because we verify - // passport token later. However the earlier we can reject connection the - // better. We receive origin field in websocket header way before receiving - // passport string. if (!master_->IsOriginAllowed(origin.spec())) return STATUS_ABORT; + if (!requested_parameters_.empty()) { + destname_ = requested_parameters_["hostname"]; + int port; + if (!base::StringToInt(requested_parameters_["port"], &port) || + port < 0 || port >= 1 << 16) { + return STATUS_ABORT; + } + destport_ = port; + do_tls_ = (requested_parameters_["tls"] == "true"); + + requested_parameters_["extension_id"] = + FetchExtensionIdFromOrigin(GetOrigin()); + std::string passport(requested_parameters_["passport"]); + requested_parameters_.erase("passport"); + if (!browser::InternalAuthVerification::VerifyPassport( + passport, "web_socket_proxy", requested_parameters_)) { + return STATUS_ABORT; + } + } + evbuffer_drain(evb, term_pos - buf); return STATUS_OK; } @@ -870,6 +1360,11 @@ bool Conn::EmitFrame( } Conn::Status Conn::ConsumeDestframe(struct evbuffer* evb) { + if (!requested_parameters_.empty()) { + // Parameters were already provided (and verified) in query component of + // websocket URL. + return STATUS_OK; + } if (frame_bytes_remaining_ == 0) { Conn::Status rv = ConsumeFrameHeader(evb); if (rv != STATUS_OK) @@ -926,8 +1421,10 @@ Conn::Status Conn::ConsumeFrameHeader(struct evbuffer* evb) { } int opcode = buf[0] & 0x0f; switch (opcode) { - case 1: // Text frame. + case WS_OPCODE_TEXT: break; + case WS_OPCODE_CLOSE: + return STATUS_ABORT; default: NOTIMPLEMENTED(); return STATUS_ABORT; @@ -976,7 +1473,7 @@ Conn::Status Conn::ProcessFrameData(struct evbuffer* evb) { std::string out_bytes; base::Base64Decode(std::string(buf, buf + buf_size), &out_bytes); evbuffer_drain(evb, buf_size); - DCHECK(destchan_.bev() != NULL); + DCHECK(destchan_.write_bev()); if (!destchan_.Write(out_bytes.c_str(), out_bytes.size())) return STATUS_ABORT; break; @@ -994,28 +1491,63 @@ Conn::Status Conn::ProcessFrameData(struct evbuffer* evb) { } bool Conn::TryConnectDest(const struct sockaddr* addr, socklen_t addrlen) { - if (destchan_.sock() >= 0 || destchan_.bev() != NULL) - return false; - destchan_.sock() = socket(addr->sa_family, SOCK_STREAM, 0); - if (destchan_.sock() < 0) - return false; - if (!SetNonBlock(destchan_.sock())) + if (destchan_.read_fd() >= 0 || destchan_.read_bev() != NULL) return false; - if (connect(destchan_.sock(), addr, addrlen)) { - if (errno != EINPROGRESS) + if (do_tls_) { + int fd[4]; + if (pipe(fd) || pipe(fd + 2)) + return false; + destchan_.set_read_fd(fd[0]); + destchan_.set_write_fd(fd[3]); + for (int i = arraysize(fd); i--;) { + if (!SetNonBlock(fd[i])) + return false; + } + destchan_.set_read_bev(bufferevent_new( + destchan_.read_fd(), + &OnDestchanRead, NULL, &OnDestchanError, + evkey_)); + destchan_.set_write_bev(bufferevent_new( + destchan_.write_fd(), + NULL, &OnDestchanWrite, &OnDestchanError, + evkey_)); + net::AddressList addrlist = net::AddressList::CreateFromSockaddr( + addr, addrlen, SOCK_STREAM, IPPROTO_TCP); + net::HostPortPair host_port_pair(destname_, destport_); + BrowserThread::PostTask( + BrowserThread::IO, FROM_HERE, NewRunnableFunction( + &SSLChan::Start, addrlist, host_port_pair, fd[2], fd[1])); + } else { + int sock = socket(addr->sa_family, SOCK_STREAM, 0); + if (sock < 0) + return false; + destchan_.set_read_fd(sock); + destchan_.set_write_fd(sock); + if (!SetNonBlock(sock)) return false; + if (connect(sock, addr, addrlen)) { + if (errno != EINPROGRESS) + return false; + } + destchan_.set_read_bev(bufferevent_new( + sock, + &OnDestchanRead, &OnDestchanWrite, &OnDestchanError, + evkey_)); + destchan_.set_write_bev(destchan_.read_bev()); } - destchan_.bev() = bufferevent_new( - destchan_.sock(), - &OnDestchanRead, &OnDestchanWrite, &OnDestchanError, - evkey_); - if (destchan_.bev() == NULL) + if (destchan_.read_bev() == NULL || destchan_.write_bev() == NULL) return false; - if (bufferevent_base_set(master_->evbase(), destchan_.bev())) + if (bufferevent_base_set(master_->evbase(), destchan_.read_bev()) || + bufferevent_base_set(master_->evbase(), destchan_.write_bev())) { return false; + } bufferevent_setwatermark( - destchan_.bev(), EV_READ, 0, WebSocketProxy::kReadBufferLimit); - return !bufferevent_enable(destchan_.bev(), EV_READ | EV_WRITE); + destchan_.read_bev(), EV_READ, 0, WebSocketProxy::kBufferLimit); + if (bufferevent_enable(destchan_.read_bev(), EV_READ) || + bufferevent_enable(destchan_.write_bev(), EV_WRITE)) { + return false; + } + return true; } const std::string& Conn::GetOrigin() { @@ -1028,7 +1560,7 @@ void Conn::OnPrimchanRead(struct bufferevent* bev, EventKey evkey) { Conn* cs = Conn::Get(evkey); if (bev == NULL || cs == NULL || - bev != cs->primchan_.bev()) { + bev != cs->primchan_.read_bev()) { NOTREACHED(); return; } @@ -1057,7 +1589,6 @@ void Conn::OnPrimchanRead(struct bufferevent* bev, EventKey evkey) { return; } cs->phase_ = PHASE_WAIT_DESTFRAME; - return; } case PHASE_WAIT_DESTFRAME: { switch (cs->ConsumeDestframe(EVBUFFER_INPUT(bev))) { @@ -1130,7 +1661,7 @@ void Conn::OnPrimchanRead(struct bufferevent* bev, EventKey evkey) { } case PHASE_WAIT_DESTCONNECT: { if (EVBUFFER_LENGTH(EVBUFFER_INPUT(bev)) >= - WebSocketProxy::kReadBufferLimit) { + WebSocketProxy::kBufferLimit) { cs->Shut(WS_CLOSE_LIMIT_VIOLATION, "Read buffer overflow"); } return; @@ -1203,17 +1734,18 @@ void Conn::OnPrimchanWrite(struct bufferevent* bev, EventKey evkey) { Conn* cs = Conn::Get(evkey); if (bev == NULL || cs == NULL || - bev != cs->primchan_.bev()) { + bev != cs->primchan_.write_bev()) { NOTREACHED(); return; } - cs->primchan_.write_pending() = false; + // Write callback is called when low watermark is reached, 0 by default. + cs->primchan_.set_write_pending(false); if (cs->phase_ >= PHASE_SHUT) { cs->master_->ZapConn(cs); return; } if (cs->phase_ > PHASE_WAIT_DESTCONNECT) - OnDestchanRead(cs->destchan_.bev(), evkey); + OnDestchanRead(cs->destchan_.read_bev(), evkey); if (cs->phase_ >= PHASE_SHUT) cs->primchan_.Zap(); } @@ -1224,10 +1756,10 @@ void Conn::OnPrimchanError(struct bufferevent* bev, Conn* cs = Conn::Get(evkey); if (bev == NULL || cs == NULL || - bev != cs->primchan_.bev()) { + (bev != cs->primchan_.read_bev() && bev != cs->primchan_.write_bev())) { return; } - cs->primchan_.write_pending() = false; + cs->primchan_.set_write_pending(false); if (cs->phase_ >= PHASE_SHUT) cs->master_->ZapConn(cs); else @@ -1311,13 +1843,13 @@ void Conn::OnDestchanRead(struct bufferevent* bev, EventKey evkey) { Conn* cs = Conn::Get(evkey); if (bev == NULL || cs == NULL || - bev != cs->destchan_.bev()) { + bev != cs->destchan_.read_bev()) { NOTREACHED(); return; } if (EVBUFFER_LENGTH(EVBUFFER_INPUT(bev)) <= 0) return; - if (cs->primchan_.bev() == NULL) { + if (cs->primchan_.write_bev() == NULL) { cs->master_->ZapConn(cs); return; } @@ -1341,15 +1873,16 @@ void Conn::OnDestchanWrite(struct bufferevent* bev, EventKey evkey) { Conn* cs = Conn::Get(evkey); if (bev == NULL || cs == NULL || - bev != cs->destchan_.bev()) { + bev != cs->destchan_.write_bev()) { NOTREACHED(); return; } - cs->destchan_.write_pending() = false; + // Write callback is called when low watermark is reached, 0 by default. + cs->destchan_.set_write_pending(false); if (cs->phase_ == PHASE_WAIT_DESTCONNECT) cs->phase_ = PHASE_OUTSIDE_FRAME; if (cs->phase_ < PHASE_SHUT) - OnPrimchanRead(cs->primchan_.bev(), evkey); + OnPrimchanRead(cs->primchan_.read_bev(), evkey); else cs->destchan_.Zap(); } @@ -1360,10 +1893,10 @@ void Conn::OnDestchanError(struct bufferevent* bev, Conn* cs = Conn::Get(evkey); if (bev == NULL || cs == NULL || - bev != cs->destchan_.bev()) { + (bev != cs->destchan_.read_bev() && bev != cs->destchan_.write_bev())) { return; } - cs->destchan_.write_pending() = false; + cs->destchan_.set_write_pending(false); if (cs->phase_ >= PHASE_SHUT) cs->master_->ZapConn(cs); else @@ -1376,10 +1909,8 @@ Conn::EventKeyMap Conn::evkey_map_; } // namespace -WebSocketProxy::WebSocketProxy( - const std::vector<std::string>& allowed_origins, - struct sockaddr* addr, int addr_len) - : impl_(new Serv(allowed_origins, addr, addr_len)) { +WebSocketProxy::WebSocketProxy(const std::vector<std::string>& allowed_origins) + : impl_(new Serv(allowed_origins)) { } WebSocketProxy::~WebSocketProxy() { |