diff options
author | ukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-11-04 10:02:28 +0000 |
---|---|---|
committer | ukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-11-04 10:02:28 +0000 |
commit | 4c4eac006bdbf1ad0f3e37db47c488f7c3ee4949 (patch) | |
tree | 16237134b9e037cf609ca4211b5c9efbcd8759ee | |
parent | 8ecd3aade85b825c8206735f16c23880023782db (diff) | |
download | chromium_src-4c4eac006bdbf1ad0f3e37db47c488f7c3ee4949.zip chromium_src-4c4eac006bdbf1ad0f3e37db47c488f7c3ee4949.tar.gz chromium_src-4c4eac006bdbf1ad0f3e37db47c488f7c3ee4949.tar.bz2 |
Implement websocket throttling.
Implement the client-side requirements in the spec.
4.1 Handshake
1. If the user agent already has a Web Socket connection to the
remote host (IP address) identified by /host/, even if known by
another name, wait until that connection has been established or
for that connection to have failed.
BUG=none
TEST=net_unittests passes
Review URL: http://codereview.chromium.org/342052
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@30949 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r-- | chrome/browser/renderer_host/socket_stream_dispatcher_host.cc | 2 | ||||
-rw-r--r-- | net/net.gyp | 5 | ||||
-rw-r--r-- | net/socket_stream/socket_stream.cc | 60 | ||||
-rw-r--r-- | net/socket_stream/socket_stream.h | 14 | ||||
-rw-r--r-- | net/socket_stream/socket_stream_throttle.cc | 80 | ||||
-rw-r--r-- | net/socket_stream/socket_stream_throttle.h | 81 | ||||
-rw-r--r-- | net/websockets/websocket_throttle.cc | 293 | ||||
-rw-r--r-- | net/websockets/websocket_throttle.h | 67 | ||||
-rw-r--r-- | net/websockets/websocket_throttle_unittest.cc | 166 |
9 files changed, 743 insertions, 25 deletions
diff --git a/chrome/browser/renderer_host/socket_stream_dispatcher_host.cc b/chrome/browser/renderer_host/socket_stream_dispatcher_host.cc index ff921ad..68f7d4f 100644 --- a/chrome/browser/renderer_host/socket_stream_dispatcher_host.cc +++ b/chrome/browser/renderer_host/socket_stream_dispatcher_host.cc @@ -9,9 +9,11 @@ #include "chrome/common/render_messages.h" #include "chrome/common/net/socket_stream.h" #include "ipc/ipc_message.h" +#include "net/websockets/websocket_throttle.h" SocketStreamDispatcherHost::SocketStreamDispatcherHost() : sender_(NULL) { + net::WebSocketThrottle::Init(); } SocketStreamDispatcherHost::~SocketStreamDispatcherHost() { diff --git a/net/net.gyp b/net/net.gyp index 23154cb..b80b2a9 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -431,6 +431,8 @@ 'socket/tcp_pinger.h', 'socket_stream/socket_stream.cc', 'socket_stream/socket_stream.h', + 'socket_stream/socket_stream_throttle.cc', + 'socket_stream/socket_stream_throttle.h', 'third_party/parseftp/ParseFTPList.cpp', 'third_party/parseftp/ParseFTPList.h', 'url_request/url_request.cc', @@ -479,6 +481,8 @@ 'url_request/view_cache_helper.h', 'websockets/websocket.cc', 'websockets/websocket.h', + 'websockets/websocket_throttle.cc', + 'websockets/websocket_throttle.h', ], 'export_dependent_settings': [ '../base/base.gyp:base', @@ -647,6 +651,7 @@ 'url_request/url_request_tracker_unittest.cc', 'url_request/url_request_unittest.cc', 'url_request/url_request_unittest.h', + 'websockets/websocket_throttle_unittest.cc', 'websockets/websocket_unittest.cc', ], 'conditions': [ diff --git a/net/socket_stream/socket_stream.cc b/net/socket_stream/socket_stream.cc index 62aad61..aa0795c 100644 --- a/net/socket_stream/socket_stream.cc +++ b/net/socket_stream/socket_stream.cc @@ -25,6 +25,7 @@ #include "net/socket/socks5_client_socket.h" #include "net/socket/socks_client_socket.h" #include "net/socket/tcp_client_socket.h" +#include "net/socket_stream/socket_stream_throttle.h" #include "net/url_request/url_request.h" static const int kMaxPendingSendAllowed = 32768; // 32 kilobytes. @@ -55,12 +56,16 @@ SocketStream::SocketStream(const GURL& url, Delegate* delegate) write_buf_(NULL), current_write_buf_(NULL), write_buf_offset_(0), - write_buf_size_(0) { + write_buf_size_(0), + throttle_( + SocketStreamThrottle::GetSocketStreamThrottleForScheme( + url.scheme())) { DCHECK(MessageLoop::current()) << "The current MessageLoop must exist"; DCHECK_EQ(MessageLoop::TYPE_IO, MessageLoop::current()->type()) << "The current MessageLoop must be TYPE_IO"; DCHECK(delegate_); + DCHECK(throttle_); } SocketStream::~SocketStream() { @@ -199,6 +204,7 @@ void SocketStream::Finish(int result) { if (delegate) { delegate->OnClose(this); } + throttle_->OnClose(this); Release(); } @@ -213,6 +219,10 @@ void SocketStream::SetClientSocketFactory( factory_ = factory; } +void SocketStream::CopyAddrInfo(struct addrinfo* head) { + addresses_.Copy(head); +} + int SocketStream::DidEstablishConnection() { if (!socket_.get() || !socket_->IsConnected()) { next_state_ = STATE_CLOSE; @@ -226,24 +236,29 @@ int SocketStream::DidEstablishConnection() { return OK; } -void SocketStream::DidReceiveData(int result) { +int SocketStream::DidReceiveData(int result) { DCHECK(read_buf_); DCHECK_GT(result, 0); - if (!delegate_) - return; - // Notify recevied data to delegate. - delegate_->OnReceivedData(this, read_buf_->data(), result); + int len = result; + result = throttle_->OnRead(this, read_buf_->data(), len, &io_callback_); + if (delegate_) { + // Notify recevied data to delegate. + delegate_->OnReceivedData(this, read_buf_->data(), len); + } read_buf_ = NULL; + return result; } -void SocketStream::DidSendData(int result) { - current_write_buf_ = NULL; +int SocketStream::DidSendData(int result) { DCHECK_GT(result, 0); - if (!delegate_) - return; + int len = result; + result = throttle_->OnWrite(this, current_write_buf_->data(), len, + &io_callback_); + current_write_buf_ = NULL; + if (delegate_) + delegate_->OnSentData(this, len); - delegate_->OnSentData(this, result); - int remaining_size = write_buf_size_ - write_buf_offset_ - result; + int remaining_size = write_buf_size_ - write_buf_offset_ - len; if (remaining_size == 0) { if (!pending_write_bufs_.empty()) { write_buf_size_ = pending_write_bufs_.front()->size(); @@ -255,8 +270,9 @@ void SocketStream::DidSendData(int result) { } write_buf_offset_ = 0; } else { - write_buf_offset_ += result; + write_buf_offset_ += len; } + return result; } void SocketStream::OnIOCompleted(int result) { @@ -268,16 +284,14 @@ void SocketStream::OnReadCompleted(int result) { // 0 indicates end-of-file, so socket was closed. next_state_ = STATE_CLOSE; } else if (result > 0 && read_buf_) { - DidReceiveData(result); - result = OK; + result = DidReceiveData(result); } DoLoop(result); } void SocketStream::OnWriteCompleted(int result) { if (result >= 0 && write_buf_) { - DidSendData(result); - result = OK; + result = DidSendData(result); } DoLoop(result); } @@ -407,10 +421,12 @@ int SocketStream::DoResolveHost() { } int SocketStream::DoResolveHostComplete(int result) { - if (result == OK) + if (result == OK) { next_state_ = STATE_TCP_CONNECT; - else + result = throttle_->OnStartOpenConnection(this, &io_callback_); + } else { next_state_ = STATE_CLOSE; + } // TODO(ukai): if error occured, reconsider proxy after error. return result; } @@ -680,8 +696,7 @@ int SocketStream::DoReadWrite(int result) { read_buf_ = new IOBuffer(kReadBufferSize); result = socket_->Read(read_buf_, kReadBufferSize, &read_callback_); if (result > 0) { - DidReceiveData(result); - return OK; + return DidReceiveData(result); } else if (result == 0) { // 0 indicates end-of-file, so socket was closed. next_state_ = STATE_CLOSE; @@ -705,8 +720,7 @@ int SocketStream::DoReadWrite(int result) { current_write_buf_->BytesRemaining(), &write_callback_); if (result > 0) { - DidSendData(result); - return OK; + return DidSendData(result); } // If write is not pending, return the result and do next loop (to close // the connection). diff --git a/net/socket_stream/socket_stream.h b/net/socket_stream/socket_stream.h index 37b723a..14dfd36 100644 --- a/net/socket_stream/socket_stream.h +++ b/net/socket_stream/socket_stream.h @@ -31,6 +31,7 @@ class ClientSocketFactory; class HostResolver; class SSLConfigService; class SingleRequestHostResolver; +class SocketStreamThrottle; // SocketStream is used to implement Web Sockets. // It provides plain full-duplex stream with proxy and SSL support. @@ -96,6 +97,7 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { void SetUserData(const void* key, UserData* data); const GURL& url() const { return url_; } + const AddressList& address_list() const { return addresses_; } Delegate* delegate() const { return delegate_; } int max_pending_send_allowed() const { return max_pending_send_allowed_; } @@ -191,14 +193,20 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { friend class base::RefCountedThreadSafe<SocketStream>; ~SocketStream(); + friend class WebSocketThrottleTest; + + // Copies the given addrinfo list in |addresses_|. + // Used for WebSocketThrottleTest. + void CopyAddrInfo(struct addrinfo* head); + // Finishes the job. // Calls OnError and OnClose of delegate, and no more // notifications will be sent to delegate. void Finish(int result); int DidEstablishConnection(); - void DidReceiveData(int result); - void DidSendData(int result); + int DidReceiveData(int result); + int DidSendData(int result); void OnIOCompleted(int result); void OnReadCompleted(int result); @@ -289,6 +297,8 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { int write_buf_size_; PendingDataQueue pending_write_bufs_; + SocketStreamThrottle* throttle_; + DISALLOW_COPY_AND_ASSIGN(SocketStream); }; diff --git a/net/socket_stream/socket_stream_throttle.cc b/net/socket_stream/socket_stream_throttle.cc new file mode 100644 index 0000000..6a1d20d --- /dev/null +++ b/net/socket_stream/socket_stream_throttle.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <string> + +#include "net/socket_stream/socket_stream_throttle.h" + +#include "base/hash_tables.h" +#include "base/singleton.h" +#include "net/base/completion_callback.h" +#include "net/socket_stream/socket_stream.h" + +namespace net { + +// Default SocketStreamThrottle. No throttling. Used for unknown URL scheme. +class DefaultSocketStreamThrottle : public SocketStreamThrottle { + private: + DefaultSocketStreamThrottle() {} + virtual ~DefaultSocketStreamThrottle() {} + friend struct DefaultSingletonTraits<DefaultSocketStreamThrottle>; + + DISALLOW_COPY_AND_ASSIGN(DefaultSocketStreamThrottle); +}; + +class SocketStreamThrottleRegistry { + public: + SocketStreamThrottle* GetSocketStreamThrottleForScheme( + const std::string& scheme); + + void RegisterSocketStreamThrottle( + const std::string& scheme, SocketStreamThrottle* throttle); + + private: + typedef base::hash_map<std::string, SocketStreamThrottle*> ThrottleMap; + + SocketStreamThrottleRegistry() {} + ~SocketStreamThrottleRegistry() {} + friend struct DefaultSingletonTraits<SocketStreamThrottleRegistry>; + + ThrottleMap throttles_; + + DISALLOW_COPY_AND_ASSIGN(SocketStreamThrottleRegistry); +}; + +SocketStreamThrottle* +SocketStreamThrottleRegistry::GetSocketStreamThrottleForScheme( + const std::string& scheme) { + ThrottleMap::const_iterator found = throttles_.find(scheme); + if (found == throttles_.end()) { + SocketStreamThrottle* throttle = + Singleton<DefaultSocketStreamThrottle>::get(); + throttles_[scheme] = throttle; + return throttle; + } + return found->second; +} + +void SocketStreamThrottleRegistry::RegisterSocketStreamThrottle( + const std::string& scheme, SocketStreamThrottle* throttle) { + throttles_[scheme] = throttle; +} + +/* static */ +SocketStreamThrottle* SocketStreamThrottle::GetSocketStreamThrottleForScheme( + const std::string& scheme) { + SocketStreamThrottleRegistry* registry = + Singleton<SocketStreamThrottleRegistry>::get(); + return registry->GetSocketStreamThrottleForScheme(scheme); +} + +/* static */ +void SocketStreamThrottle::RegisterSocketStreamThrottle( + const std::string& scheme, SocketStreamThrottle* throttle) { + SocketStreamThrottleRegistry* registry = + Singleton<SocketStreamThrottleRegistry>::get(); + registry->RegisterSocketStreamThrottle(scheme, throttle); +} + +} // namespace net diff --git a/net/socket_stream/socket_stream_throttle.h b/net/socket_stream/socket_stream_throttle.h new file mode 100644 index 0000000..7726cbe --- /dev/null +++ b/net/socket_stream/socket_stream_throttle.h @@ -0,0 +1,81 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_STREAM_SOCKET_STREAM_THROTTLE_H_ +#define NET_SOCKET_STREAM_SOCKET_STREAM_THROTTLE_H_ + +#include <string> + +#include "base/basictypes.h" +#include "net/base/completion_callback.h" +#include "net/base/net_errors.h" + +namespace net { + +class SocketStream; + +// Abstract interface to throttle SocketStream per URL scheme. +// Each URL scheme (protocol) could define own SocketStreamThrottle. +// These methods will be called on IO thread. +class SocketStreamThrottle { + public: + // Called when |socket| is about to open connection. + // Returns net::OK if the connection can open now. + // Returns net::ERR_IO_PENDING if the connection should wait. In this case, + // |callback| will be called when it's ready to open connection. + virtual int OnStartOpenConnection(SocketStream* socket, + CompletionCallback* callback) { + // No throttle by default. + return OK; + } + + // Called when |socket| read |len| bytes of |data|. + // May wake up another waiting socket. + // Returns net::OK if |socket| can continue to run. + // Returns net::ERR_IO_PENDING if |socket| should suspend to run. In this + // case, |callback| will be called when it's ready to resume running. + virtual int OnRead(SocketStream* socket, const char* data, int len, + CompletionCallback* callback) { + // No throttle by default. + return OK; + } + + // Called when |socket| wrote |len| bytes of |data|. + // May wake up another waiting socket. + // Returns net::OK if |socket| can continue to run. + // Returns net::ERR_IO_PENDING if |socket| should suspend to run. In this + // case, |callback| will be called when it's ready to resume running. + virtual int OnWrite(SocketStream* socket, const char* data, int len, + CompletionCallback* callback) { + // No throttle by default. + return OK; + } + + // Called when |socket| is closed. + // May wake up another waiting socket. + virtual void OnClose(SocketStream* socket) {} + + // Gets SocketStreamThrottle for URL |scheme|. + // Doesn't pass ownership of the SocketStreamThrottle. + static SocketStreamThrottle* GetSocketStreamThrottleForScheme( + const std::string& scheme); + + // Registers |throttle| for URL |scheme|. + // Doesn't take ownership of |throttle|. Typically |throttle| is + // singleton instance. + static void RegisterSocketStreamThrottle( + const std::string& scheme, + SocketStreamThrottle* throttle); + + protected: + SocketStreamThrottle() {} + virtual ~SocketStreamThrottle() {} + + private: + DISALLOW_COPY_AND_ASSIGN(SocketStreamThrottle); +}; + +} // namespace net + +#endif // NET_SOCKET_STREAM_SOCKET_STREAM_THROTTLE_H_ diff --git a/net/websockets/websocket_throttle.cc b/net/websockets/websocket_throttle.cc new file mode 100644 index 0000000..fb320b6 --- /dev/null +++ b/net/websockets/websocket_throttle.cc @@ -0,0 +1,293 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_throttle.h" + +#if defined(OS_WIN) +#include <ws2tcpip.h> +#else +#include <netdb.h> +#endif + +#include <string> + +#include "base/message_loop.h" +#include "base/ref_counted.h" +#include "base/singleton.h" +#include "base/string_util.h" +#include "net/base/io_buffer.h" +#include "net/socket_stream/socket_stream.h" + +namespace net { + +static std::string AddrinfoToHashkey(const struct addrinfo* addrinfo) { + switch (addrinfo->ai_family) { + case AF_INET: { + const struct sockaddr_in* const addr = + reinterpret_cast<const sockaddr_in*>(addrinfo->ai_addr); + return StringPrintf("%d:%s", + addrinfo->ai_family, + HexEncode(&addr->sin_addr, 4).c_str()); + } + case AF_INET6: { + const struct sockaddr_in6* const addr6 = + reinterpret_cast<const sockaddr_in6*>(addrinfo->ai_addr); + return StringPrintf("%d:%s", + addrinfo->ai_family, + HexEncode(&addr6->sin6_addr, + sizeof(addr6->sin6_addr)).c_str()); + } + default: + return StringPrintf("%d:%s", + addrinfo->ai_family, + HexEncode(addrinfo->ai_addr, + addrinfo->ai_addrlen).c_str()); + } +} + +// State for WebSocket protocol on each SocketStream. +// This is owned in SocketStream as UserData keyed by WebSocketState::kKeyName. +// This is alive between connection starts and handshake is finished. +// In this class, it doesn't check actual handshake finishes, but only checks +// end of header is found in read data. +class WebSocketThrottle::WebSocketState : public SocketStream::UserData { + public: + explicit WebSocketState(const AddressList& addrs) + : address_list_(addrs), + callback_(NULL), + waiting_(false), + handshake_finished_(false), + buffer_(NULL) { + } + ~WebSocketState() {} + + int OnStartOpenConnection(CompletionCallback* callback) { + DCHECK(!callback_); + if (!waiting_) + return OK; + callback_ = callback; + return ERR_IO_PENDING; + } + + int OnRead(const char* data, int len, CompletionCallback* callback) { + DCHECK(!waiting_); + DCHECK(!callback_); + DCHECK(!handshake_finished_); + static const int kBufferSize = 8129; + + if (!buffer_) { + // Fast path. + int eoh = HttpUtil::LocateEndOfHeaders(data, len, 0); + if (eoh > 0) { + handshake_finished_ = true; + return OK; + } + buffer_ = new GrowableIOBuffer(); + buffer_->SetCapacity(kBufferSize); + } else { + if (buffer_->RemainingCapacity() < len) { + if (!buffer_->SetCapacity(buffer_->capacity() + kBufferSize)) { + // TODO(ukai): Check more correctly. + // Seek to the last CR or LF and reduce memory usage. + LOG(ERROR) << "Too large headers? capacity=" << buffer_->capacity(); + handshake_finished_ = true; + return OK; + } + } + } + memcpy(buffer_->data(), data, len); + buffer_->set_offset(buffer_->offset() + len); + + int eoh = HttpUtil::LocateEndOfHeaders(buffer_->StartOfBuffer(), + buffer_->offset(), 0); + handshake_finished_ = (eoh > 0); + return OK; + } + + const AddressList& address_list() const { return address_list_; } + void SetWaiting() { waiting_ = true; } + bool IsWaiting() const { return waiting_; } + bool HandshakeFinished() const { return handshake_finished_; } + void Wakeup() { + waiting_ = false; + // We wrap |callback_| to keep this alive while this is released. + scoped_refptr<CompletionCallbackRunner> runner = + new CompletionCallbackRunner(callback_); + callback_ = NULL; + MessageLoopForIO::current()->PostTask( + FROM_HERE, + NewRunnableMethod(runner.get(), + &CompletionCallbackRunner::Run)); + } + + static const char* kKeyName; + + private: + class CompletionCallbackRunner + : public base::RefCountedThreadSafe<CompletionCallbackRunner> { + public: + explicit CompletionCallbackRunner(CompletionCallback* callback) + : callback_(callback) { + DCHECK(callback_); + } + virtual ~CompletionCallbackRunner() {} + void Run() { + callback_->Run(OK); + } + private: + CompletionCallback* callback_; + + DISALLOW_COPY_AND_ASSIGN(CompletionCallbackRunner); + }; + + const AddressList& address_list_; + + CompletionCallback* callback_; + // True if waiting another websocket connection is established. + // False if the websocket is performing handshaking. + bool waiting_; + + // True if the websocket handshake is completed. + // If true, it will be removed from queue and deleted from the SocketStream + // UserData soon. + bool handshake_finished_; + + // Buffer for read data to check handshake response message. + scoped_refptr<GrowableIOBuffer> buffer_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketState); +}; + +const char* WebSocketThrottle::WebSocketState::kKeyName = "WebSocketState"; + +WebSocketThrottle::WebSocketThrottle() { + SocketStreamThrottle::RegisterSocketStreamThrottle("ws", this); + SocketStreamThrottle::RegisterSocketStreamThrottle("wss", this); +} + +WebSocketThrottle::~WebSocketThrottle() { + DCHECK(queue_.empty()); + DCHECK(addr_map_.empty()); +} + +int WebSocketThrottle::OnStartOpenConnection( + SocketStream* socket, CompletionCallback* callback) { + WebSocketState* state = new WebSocketState(socket->address_list()); + PutInQueue(socket, state); + return state->OnStartOpenConnection(callback); +} + +int WebSocketThrottle::OnRead(SocketStream* socket, + const char* data, int len, + CompletionCallback* callback) { + WebSocketState* state = static_cast<WebSocketState*>( + socket->GetUserData(WebSocketState::kKeyName)); + // If no state, handshake was already completed. Do nothing. + if (!state) + return OK; + + int result = state->OnRead(data, len, callback); + if (state->HandshakeFinished()) { + RemoveFromQueue(socket, state); + WakeupSocketIfNecessary(); + } + return result; +} + +int WebSocketThrottle::OnWrite(SocketStream* socket, + const char* data, int len, + CompletionCallback* callback) { + // Do nothing. + return OK; +} + +void WebSocketThrottle::OnClose(SocketStream* socket) { + WebSocketState* state = static_cast<WebSocketState*>( + socket->GetUserData(WebSocketState::kKeyName)); + if (!state) + return; + RemoveFromQueue(socket, state); + WakeupSocketIfNecessary(); +} + +void WebSocketThrottle::PutInQueue(SocketStream* socket, + WebSocketState* state) { + socket->SetUserData(WebSocketState::kKeyName, state); + queue_.push_back(state); + const AddressList& address_list = socket->address_list(); + for (const struct addrinfo* addrinfo = address_list.head(); + addrinfo != NULL; + addrinfo = addrinfo->ai_next) { + std::string addrkey = AddrinfoToHashkey(addrinfo); + ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); + if (iter == addr_map_.end()) { + ConnectingQueue* queue = new ConnectingQueue(); + queue->push_back(state); + addr_map_[addrkey] = queue; + } else { + iter->second->push_back(state); + state->SetWaiting(); + } + } +} + +void WebSocketThrottle::RemoveFromQueue(SocketStream* socket, + WebSocketState* state) { + const AddressList& address_list = socket->address_list(); + for (const struct addrinfo* addrinfo = address_list.head(); + addrinfo != NULL; + addrinfo = addrinfo->ai_next) { + std::string addrkey = AddrinfoToHashkey(addrinfo); + ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); + DCHECK(iter != addr_map_.end()); + ConnectingQueue* queue = iter->second; + DCHECK(state == queue->front()); + queue->pop_front(); + if (queue->empty()) + addr_map_.erase(iter); + } + for (ConnectingQueue::iterator iter = queue_.begin(); + iter != queue_.end(); + ++iter) { + if (*iter == state) { + queue_.erase(iter); + break; + } + } + socket->SetUserData(WebSocketState::kKeyName, NULL); +} + +void WebSocketThrottle::WakeupSocketIfNecessary() { + for (ConnectingQueue::iterator iter = queue_.begin(); + iter != queue_.end(); + ++iter) { + WebSocketState* state = *iter; + if (!state->IsWaiting()) + continue; + + bool should_wakeup = true; + const AddressList& address_list = state->address_list(); + for (const struct addrinfo* addrinfo = address_list.head(); + addrinfo != NULL; + addrinfo = addrinfo->ai_next) { + std::string addrkey = AddrinfoToHashkey(addrinfo); + ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); + DCHECK(iter != addr_map_.end()); + ConnectingQueue* queue = iter->second; + if (state != queue->front()) { + should_wakeup = false; + break; + } + } + if (should_wakeup) + state->Wakeup(); + } +} + +/* static */ +void WebSocketThrottle::Init() { + Singleton<WebSocketThrottle>::get(); +} + +} // namespace net diff --git a/net/websockets/websocket_throttle.h b/net/websockets/websocket_throttle.h new file mode 100644 index 0000000..279aea2 --- /dev/null +++ b/net/websockets/websocket_throttle.h @@ -0,0 +1,67 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_WEBSOCKETS_WEBSOCKET_THROTTLE_H_ +#define NET_WEBSOCKETS_WEBSOCKET_THROTTLE_H_ + +#include "base/hash_tables.h" +#include "base/singleton.h" +#include "net/socket_stream/socket_stream_throttle.h" + +namespace net { + +// SocketStreamThrottle for WebSocket protocol. +// Implements the client-side requirements in the spec. +// http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol +// 4.1 Handshake +// 1. If the user agent already has a Web Socket connection to the +// remote host (IP address) identified by /host/, even if known by +// another name, wait until that connection has been established or +// for that connection to have failed. +class WebSocketThrottle : public SocketStreamThrottle { + public: + virtual int OnStartOpenConnection(SocketStream* socket, + CompletionCallback* callback); + virtual int OnRead(SocketStream* socket, const char* data, int len, + CompletionCallback* callback); + virtual int OnWrite(SocketStream* socket, const char* data, int len, + CompletionCallback* callback); + virtual void OnClose(SocketStream* socket); + + static void Init(); + + private: + class WebSocketState; + typedef std::deque<WebSocketState*> ConnectingQueue; + typedef base::hash_map<std::string, ConnectingQueue*> ConnectingAddressMap; + + WebSocketThrottle(); + virtual ~WebSocketThrottle(); + friend struct DefaultSingletonTraits<WebSocketThrottle>; + + // Puts |socket| in |queue_| and queues for the destination addresses + // of |socket|. Also sets |state| as UserData of |socket|. + // If other socket is using the same destination address, set |state| waiting. + void PutInQueue(SocketStream* socket, WebSocketState* state); + + // Removes |socket| from |queue_| and queues for the destination addresses + // of |socket|. Also releases |state| from UserData of |socket|. + void RemoveFromQueue(SocketStream* socket, WebSocketState* state); + + // Checks sockets waiting in |queue_| and check the socket is the front of + // every queue for the destination addresses of |socket|. + // If so, the socket can resume estabilshing connection, so wake up + // the socket. + void WakeupSocketIfNecessary(); + + // Key: string of host's address. Value: queue of sockets for the address. + ConnectingAddressMap addr_map_; + + // Queue of sockets for websockets in opening state. + ConnectingQueue queue_; +}; + +} // namespace net + +#endif // NET_WEBSOCKETS_WEBSOCKET_THROTTLE_H_ diff --git a/net/websockets/websocket_throttle_unittest.cc b/net/websockets/websocket_throttle_unittest.cc new file mode 100644 index 0000000..3757a0b --- /dev/null +++ b/net/websockets/websocket_throttle_unittest.cc @@ -0,0 +1,166 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "build/build_config.h" + +#if defined(OS_WIN) +#include <ws2tcpip.h> +#else +#include <netdb.h> +#endif + +#include <string> + +#include "base/message_loop.h" +#include "googleurl/src/gurl.h" +#include "net/base/address_list.h" +#include "net/base/test_completion_callback.h" +#include "net/socket_stream/socket_stream.h" +#include "net/websockets/websocket_throttle.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +class DummySocketStreamDelegate : public net::SocketStream::Delegate { + public: + DummySocketStreamDelegate() {} + virtual ~DummySocketStreamDelegate() {} + virtual void OnConnected( + net::SocketStream* socket, int max_pending_send_allowed) {} + virtual void OnSentData(net::SocketStream* socket, int amount_sent) {} + virtual void OnReceivedData(net::SocketStream* socket, + const char* data, int len) {} + virtual void OnClose(net::SocketStream* socket) {} +}; + +namespace net { + +class WebSocketThrottleTest : public PlatformTest { + protected: + struct addrinfo *AddAddr(int a1, int a2, int a3, int a4, + struct addrinfo* next) { + struct addrinfo* addrinfo = new struct addrinfo; + memset(addrinfo, 0, sizeof(struct addrinfo)); + addrinfo->ai_family = AF_INET; + int addrlen = sizeof(struct sockaddr_in); + addrinfo->ai_addrlen = addrlen; + addrinfo->ai_addr = reinterpret_cast<sockaddr*>(new char[addrlen]); + memset(addrinfo->ai_addr, 0, sizeof(addrlen)); + struct sockaddr_in* addr = + reinterpret_cast<sockaddr_in*>(addrinfo->ai_addr); + int addrint = ((a1 & 0xff) << 24) | + ((a2 & 0xff) << 16) | + ((a3 & 0xff) << 8) | + ((a4 & 0xff)); + memcpy(&addr->sin_addr, &addrint, sizeof(int)); + addrinfo->ai_next = next; + return addrinfo; + } + void DeleteAddrInfo(struct addrinfo* head) { + if (!head) + return; + struct addrinfo* next; + for (struct addrinfo* a = head; a != NULL; a = next) { + next = a->ai_next; + delete [] a->ai_addr; + delete a; + } + } + + static void SetAddressList(SocketStream* socket, struct addrinfo* head) { + socket->CopyAddrInfo(head); + } +}; + +TEST_F(WebSocketThrottleTest, Throttle) { + WebSocketThrottle::Init(); + DummySocketStreamDelegate delegate; + + WebSocketThrottle* throttle = Singleton<WebSocketThrottle>::get(); + + EXPECT_EQ(throttle, + SocketStreamThrottle::GetSocketStreamThrottleForScheme("ws")); + EXPECT_EQ(throttle, + SocketStreamThrottle::GetSocketStreamThrottleForScheme("wss")); + + // For host1: 1.2.3.4, 1.2.3.5, 1.2.3.6 + struct addrinfo* addr = AddAddr(1, 2, 3, 4, NULL); + addr = AddAddr(1, 2, 3, 5, addr); + addr = AddAddr(1, 2, 3, 6, addr); + scoped_refptr<SocketStream> s1 = + new SocketStream(GURL("ws://host1/"), &delegate); + WebSocketThrottleTest::SetAddressList(s1, addr); + DeleteAddrInfo(addr); + + TestCompletionCallback callback_s1; + EXPECT_EQ(OK, throttle->OnStartOpenConnection(s1, &callback_s1)); + + // For host2: 1.2.3.4 + addr = AddAddr(1, 2, 3, 4, NULL); + scoped_refptr<SocketStream> s2 = + new SocketStream(GURL("ws://host2/"), &delegate); + WebSocketThrottleTest::SetAddressList(s2, addr); + DeleteAddrInfo(addr); + + TestCompletionCallback callback_s2; + EXPECT_EQ(ERR_IO_PENDING, throttle->OnStartOpenConnection(s2, &callback_s2)); + + // For host3: 1.2.3.5 + addr = AddAddr(1, 2, 3, 5, NULL); + scoped_refptr<SocketStream> s3 = + new SocketStream(GURL("ws://host3/"), &delegate); + WebSocketThrottleTest::SetAddressList(s3, addr); + DeleteAddrInfo(addr); + + TestCompletionCallback callback_s3; + EXPECT_EQ(ERR_IO_PENDING, throttle->OnStartOpenConnection(s3, &callback_s3)); + + // For host4: 1.2.3.4, 1.2.3.6 + addr = AddAddr(1, 2, 3, 4, NULL); + addr = AddAddr(1, 2, 3, 6, addr); + scoped_refptr<SocketStream> s4 = + new SocketStream(GURL("ws://host4/"), &delegate); + WebSocketThrottleTest::SetAddressList(s4, addr); + DeleteAddrInfo(addr); + + TestCompletionCallback callback_s4; + EXPECT_EQ(ERR_IO_PENDING, throttle->OnStartOpenConnection(s4, &callback_s4)); + + static const char kHeader[] = "HTTP/1.1 101 Web Socket Protocol\r\n"; + EXPECT_EQ(OK, + throttle->OnRead(s1.get(), kHeader, sizeof(kHeader) - 1, NULL)); + EXPECT_FALSE(callback_s2.have_result()); + EXPECT_FALSE(callback_s3.have_result()); + EXPECT_FALSE(callback_s4.have_result()); + + static const char kHeader2[] = + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: http://www.google.com\r\n" + "WebSocket-Location: ws://websocket.chromium.org\r\n" + "\r\n"; + EXPECT_EQ(OK, + throttle->OnRead(s1.get(), kHeader2, sizeof(kHeader2) - 1, NULL)); + MessageLoopForIO::current()->RunAllPending(); + EXPECT_TRUE(callback_s2.have_result()); + EXPECT_TRUE(callback_s3.have_result()); + EXPECT_FALSE(callback_s4.have_result()); + + throttle->OnClose(s1.get()); + MessageLoopForIO::current()->RunAllPending(); + EXPECT_FALSE(callback_s4.have_result()); + s1->DetachDelegate(); + + throttle->OnClose(s2.get()); + MessageLoopForIO::current()->RunAllPending(); + EXPECT_TRUE(callback_s4.have_result()); + s2->DetachDelegate(); + + throttle->OnClose(s3.get()); + MessageLoopForIO::current()->RunAllPending(); + s3->DetachDelegate(); + throttle->OnClose(s4.get()); + s4->DetachDelegate(); +} + +} |