summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2009-11-04 10:02:28 +0000
committerukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2009-11-04 10:02:28 +0000
commit4c4eac006bdbf1ad0f3e37db47c488f7c3ee4949 (patch)
tree16237134b9e037cf609ca4211b5c9efbcd8759ee
parent8ecd3aade85b825c8206735f16c23880023782db (diff)
downloadchromium_src-4c4eac006bdbf1ad0f3e37db47c488f7c3ee4949.zip
chromium_src-4c4eac006bdbf1ad0f3e37db47c488f7c3ee4949.tar.gz
chromium_src-4c4eac006bdbf1ad0f3e37db47c488f7c3ee4949.tar.bz2
Implement websocket throttling.
Implement the client-side requirements in the spec. 4.1 Handshake 1. If the user agent already has a Web Socket connection to the remote host (IP address) identified by /host/, even if known by another name, wait until that connection has been established or for that connection to have failed. BUG=none TEST=net_unittests passes Review URL: http://codereview.chromium.org/342052 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@30949 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r--chrome/browser/renderer_host/socket_stream_dispatcher_host.cc2
-rw-r--r--net/net.gyp5
-rw-r--r--net/socket_stream/socket_stream.cc60
-rw-r--r--net/socket_stream/socket_stream.h14
-rw-r--r--net/socket_stream/socket_stream_throttle.cc80
-rw-r--r--net/socket_stream/socket_stream_throttle.h81
-rw-r--r--net/websockets/websocket_throttle.cc293
-rw-r--r--net/websockets/websocket_throttle.h67
-rw-r--r--net/websockets/websocket_throttle_unittest.cc166
9 files changed, 743 insertions, 25 deletions
diff --git a/chrome/browser/renderer_host/socket_stream_dispatcher_host.cc b/chrome/browser/renderer_host/socket_stream_dispatcher_host.cc
index ff921ad..68f7d4f 100644
--- a/chrome/browser/renderer_host/socket_stream_dispatcher_host.cc
+++ b/chrome/browser/renderer_host/socket_stream_dispatcher_host.cc
@@ -9,9 +9,11 @@
#include "chrome/common/render_messages.h"
#include "chrome/common/net/socket_stream.h"
#include "ipc/ipc_message.h"
+#include "net/websockets/websocket_throttle.h"
SocketStreamDispatcherHost::SocketStreamDispatcherHost()
: sender_(NULL) {
+ net::WebSocketThrottle::Init();
}
SocketStreamDispatcherHost::~SocketStreamDispatcherHost() {
diff --git a/net/net.gyp b/net/net.gyp
index 23154cb..b80b2a9 100644
--- a/net/net.gyp
+++ b/net/net.gyp
@@ -431,6 +431,8 @@
'socket/tcp_pinger.h',
'socket_stream/socket_stream.cc',
'socket_stream/socket_stream.h',
+ 'socket_stream/socket_stream_throttle.cc',
+ 'socket_stream/socket_stream_throttle.h',
'third_party/parseftp/ParseFTPList.cpp',
'third_party/parseftp/ParseFTPList.h',
'url_request/url_request.cc',
@@ -479,6 +481,8 @@
'url_request/view_cache_helper.h',
'websockets/websocket.cc',
'websockets/websocket.h',
+ 'websockets/websocket_throttle.cc',
+ 'websockets/websocket_throttle.h',
],
'export_dependent_settings': [
'../base/base.gyp:base',
@@ -647,6 +651,7 @@
'url_request/url_request_tracker_unittest.cc',
'url_request/url_request_unittest.cc',
'url_request/url_request_unittest.h',
+ 'websockets/websocket_throttle_unittest.cc',
'websockets/websocket_unittest.cc',
],
'conditions': [
diff --git a/net/socket_stream/socket_stream.cc b/net/socket_stream/socket_stream.cc
index 62aad61..aa0795c 100644
--- a/net/socket_stream/socket_stream.cc
+++ b/net/socket_stream/socket_stream.cc
@@ -25,6 +25,7 @@
#include "net/socket/socks5_client_socket.h"
#include "net/socket/socks_client_socket.h"
#include "net/socket/tcp_client_socket.h"
+#include "net/socket_stream/socket_stream_throttle.h"
#include "net/url_request/url_request.h"
static const int kMaxPendingSendAllowed = 32768; // 32 kilobytes.
@@ -55,12 +56,16 @@ SocketStream::SocketStream(const GURL& url, Delegate* delegate)
write_buf_(NULL),
current_write_buf_(NULL),
write_buf_offset_(0),
- write_buf_size_(0) {
+ write_buf_size_(0),
+ throttle_(
+ SocketStreamThrottle::GetSocketStreamThrottleForScheme(
+ url.scheme())) {
DCHECK(MessageLoop::current()) <<
"The current MessageLoop must exist";
DCHECK_EQ(MessageLoop::TYPE_IO, MessageLoop::current()->type()) <<
"The current MessageLoop must be TYPE_IO";
DCHECK(delegate_);
+ DCHECK(throttle_);
}
SocketStream::~SocketStream() {
@@ -199,6 +204,7 @@ void SocketStream::Finish(int result) {
if (delegate) {
delegate->OnClose(this);
}
+ throttle_->OnClose(this);
Release();
}
@@ -213,6 +219,10 @@ void SocketStream::SetClientSocketFactory(
factory_ = factory;
}
+void SocketStream::CopyAddrInfo(struct addrinfo* head) {
+ addresses_.Copy(head);
+}
+
int SocketStream::DidEstablishConnection() {
if (!socket_.get() || !socket_->IsConnected()) {
next_state_ = STATE_CLOSE;
@@ -226,24 +236,29 @@ int SocketStream::DidEstablishConnection() {
return OK;
}
-void SocketStream::DidReceiveData(int result) {
+int SocketStream::DidReceiveData(int result) {
DCHECK(read_buf_);
DCHECK_GT(result, 0);
- if (!delegate_)
- return;
- // Notify recevied data to delegate.
- delegate_->OnReceivedData(this, read_buf_->data(), result);
+ int len = result;
+ result = throttle_->OnRead(this, read_buf_->data(), len, &io_callback_);
+ if (delegate_) {
+ // Notify recevied data to delegate.
+ delegate_->OnReceivedData(this, read_buf_->data(), len);
+ }
read_buf_ = NULL;
+ return result;
}
-void SocketStream::DidSendData(int result) {
- current_write_buf_ = NULL;
+int SocketStream::DidSendData(int result) {
DCHECK_GT(result, 0);
- if (!delegate_)
- return;
+ int len = result;
+ result = throttle_->OnWrite(this, current_write_buf_->data(), len,
+ &io_callback_);
+ current_write_buf_ = NULL;
+ if (delegate_)
+ delegate_->OnSentData(this, len);
- delegate_->OnSentData(this, result);
- int remaining_size = write_buf_size_ - write_buf_offset_ - result;
+ int remaining_size = write_buf_size_ - write_buf_offset_ - len;
if (remaining_size == 0) {
if (!pending_write_bufs_.empty()) {
write_buf_size_ = pending_write_bufs_.front()->size();
@@ -255,8 +270,9 @@ void SocketStream::DidSendData(int result) {
}
write_buf_offset_ = 0;
} else {
- write_buf_offset_ += result;
+ write_buf_offset_ += len;
}
+ return result;
}
void SocketStream::OnIOCompleted(int result) {
@@ -268,16 +284,14 @@ void SocketStream::OnReadCompleted(int result) {
// 0 indicates end-of-file, so socket was closed.
next_state_ = STATE_CLOSE;
} else if (result > 0 && read_buf_) {
- DidReceiveData(result);
- result = OK;
+ result = DidReceiveData(result);
}
DoLoop(result);
}
void SocketStream::OnWriteCompleted(int result) {
if (result >= 0 && write_buf_) {
- DidSendData(result);
- result = OK;
+ result = DidSendData(result);
}
DoLoop(result);
}
@@ -407,10 +421,12 @@ int SocketStream::DoResolveHost() {
}
int SocketStream::DoResolveHostComplete(int result) {
- if (result == OK)
+ if (result == OK) {
next_state_ = STATE_TCP_CONNECT;
- else
+ result = throttle_->OnStartOpenConnection(this, &io_callback_);
+ } else {
next_state_ = STATE_CLOSE;
+ }
// TODO(ukai): if error occured, reconsider proxy after error.
return result;
}
@@ -680,8 +696,7 @@ int SocketStream::DoReadWrite(int result) {
read_buf_ = new IOBuffer(kReadBufferSize);
result = socket_->Read(read_buf_, kReadBufferSize, &read_callback_);
if (result > 0) {
- DidReceiveData(result);
- return OK;
+ return DidReceiveData(result);
} else if (result == 0) {
// 0 indicates end-of-file, so socket was closed.
next_state_ = STATE_CLOSE;
@@ -705,8 +720,7 @@ int SocketStream::DoReadWrite(int result) {
current_write_buf_->BytesRemaining(),
&write_callback_);
if (result > 0) {
- DidSendData(result);
- return OK;
+ return DidSendData(result);
}
// If write is not pending, return the result and do next loop (to close
// the connection).
diff --git a/net/socket_stream/socket_stream.h b/net/socket_stream/socket_stream.h
index 37b723a..14dfd36 100644
--- a/net/socket_stream/socket_stream.h
+++ b/net/socket_stream/socket_stream.h
@@ -31,6 +31,7 @@ class ClientSocketFactory;
class HostResolver;
class SSLConfigService;
class SingleRequestHostResolver;
+class SocketStreamThrottle;
// SocketStream is used to implement Web Sockets.
// It provides plain full-duplex stream with proxy and SSL support.
@@ -96,6 +97,7 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> {
void SetUserData(const void* key, UserData* data);
const GURL& url() const { return url_; }
+ const AddressList& address_list() const { return addresses_; }
Delegate* delegate() const { return delegate_; }
int max_pending_send_allowed() const { return max_pending_send_allowed_; }
@@ -191,14 +193,20 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> {
friend class base::RefCountedThreadSafe<SocketStream>;
~SocketStream();
+ friend class WebSocketThrottleTest;
+
+ // Copies the given addrinfo list in |addresses_|.
+ // Used for WebSocketThrottleTest.
+ void CopyAddrInfo(struct addrinfo* head);
+
// Finishes the job.
// Calls OnError and OnClose of delegate, and no more
// notifications will be sent to delegate.
void Finish(int result);
int DidEstablishConnection();
- void DidReceiveData(int result);
- void DidSendData(int result);
+ int DidReceiveData(int result);
+ int DidSendData(int result);
void OnIOCompleted(int result);
void OnReadCompleted(int result);
@@ -289,6 +297,8 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> {
int write_buf_size_;
PendingDataQueue pending_write_bufs_;
+ SocketStreamThrottle* throttle_;
+
DISALLOW_COPY_AND_ASSIGN(SocketStream);
};
diff --git a/net/socket_stream/socket_stream_throttle.cc b/net/socket_stream/socket_stream_throttle.cc
new file mode 100644
index 0000000..6a1d20d
--- /dev/null
+++ b/net/socket_stream/socket_stream_throttle.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2009 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <string>
+
+#include "net/socket_stream/socket_stream_throttle.h"
+
+#include "base/hash_tables.h"
+#include "base/singleton.h"
+#include "net/base/completion_callback.h"
+#include "net/socket_stream/socket_stream.h"
+
+namespace net {
+
+// Default SocketStreamThrottle. No throttling. Used for unknown URL scheme.
+class DefaultSocketStreamThrottle : public SocketStreamThrottle {
+ private:
+ DefaultSocketStreamThrottle() {}
+ virtual ~DefaultSocketStreamThrottle() {}
+ friend struct DefaultSingletonTraits<DefaultSocketStreamThrottle>;
+
+ DISALLOW_COPY_AND_ASSIGN(DefaultSocketStreamThrottle);
+};
+
+class SocketStreamThrottleRegistry {
+ public:
+ SocketStreamThrottle* GetSocketStreamThrottleForScheme(
+ const std::string& scheme);
+
+ void RegisterSocketStreamThrottle(
+ const std::string& scheme, SocketStreamThrottle* throttle);
+
+ private:
+ typedef base::hash_map<std::string, SocketStreamThrottle*> ThrottleMap;
+
+ SocketStreamThrottleRegistry() {}
+ ~SocketStreamThrottleRegistry() {}
+ friend struct DefaultSingletonTraits<SocketStreamThrottleRegistry>;
+
+ ThrottleMap throttles_;
+
+ DISALLOW_COPY_AND_ASSIGN(SocketStreamThrottleRegistry);
+};
+
+SocketStreamThrottle*
+SocketStreamThrottleRegistry::GetSocketStreamThrottleForScheme(
+ const std::string& scheme) {
+ ThrottleMap::const_iterator found = throttles_.find(scheme);
+ if (found == throttles_.end()) {
+ SocketStreamThrottle* throttle =
+ Singleton<DefaultSocketStreamThrottle>::get();
+ throttles_[scheme] = throttle;
+ return throttle;
+ }
+ return found->second;
+}
+
+void SocketStreamThrottleRegistry::RegisterSocketStreamThrottle(
+ const std::string& scheme, SocketStreamThrottle* throttle) {
+ throttles_[scheme] = throttle;
+}
+
+/* static */
+SocketStreamThrottle* SocketStreamThrottle::GetSocketStreamThrottleForScheme(
+ const std::string& scheme) {
+ SocketStreamThrottleRegistry* registry =
+ Singleton<SocketStreamThrottleRegistry>::get();
+ return registry->GetSocketStreamThrottleForScheme(scheme);
+}
+
+/* static */
+void SocketStreamThrottle::RegisterSocketStreamThrottle(
+ const std::string& scheme, SocketStreamThrottle* throttle) {
+ SocketStreamThrottleRegistry* registry =
+ Singleton<SocketStreamThrottleRegistry>::get();
+ registry->RegisterSocketStreamThrottle(scheme, throttle);
+}
+
+} // namespace net
diff --git a/net/socket_stream/socket_stream_throttle.h b/net/socket_stream/socket_stream_throttle.h
new file mode 100644
index 0000000..7726cbe
--- /dev/null
+++ b/net/socket_stream/socket_stream_throttle.h
@@ -0,0 +1,81 @@
+// Copyright (c) 2009 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_STREAM_SOCKET_STREAM_THROTTLE_H_
+#define NET_SOCKET_STREAM_SOCKET_STREAM_THROTTLE_H_
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "net/base/completion_callback.h"
+#include "net/base/net_errors.h"
+
+namespace net {
+
+class SocketStream;
+
+// Abstract interface to throttle SocketStream per URL scheme.
+// Each URL scheme (protocol) could define own SocketStreamThrottle.
+// These methods will be called on IO thread.
+class SocketStreamThrottle {
+ public:
+ // Called when |socket| is about to open connection.
+ // Returns net::OK if the connection can open now.
+ // Returns net::ERR_IO_PENDING if the connection should wait. In this case,
+ // |callback| will be called when it's ready to open connection.
+ virtual int OnStartOpenConnection(SocketStream* socket,
+ CompletionCallback* callback) {
+ // No throttle by default.
+ return OK;
+ }
+
+ // Called when |socket| read |len| bytes of |data|.
+ // May wake up another waiting socket.
+ // Returns net::OK if |socket| can continue to run.
+ // Returns net::ERR_IO_PENDING if |socket| should suspend to run. In this
+ // case, |callback| will be called when it's ready to resume running.
+ virtual int OnRead(SocketStream* socket, const char* data, int len,
+ CompletionCallback* callback) {
+ // No throttle by default.
+ return OK;
+ }
+
+ // Called when |socket| wrote |len| bytes of |data|.
+ // May wake up another waiting socket.
+ // Returns net::OK if |socket| can continue to run.
+ // Returns net::ERR_IO_PENDING if |socket| should suspend to run. In this
+ // case, |callback| will be called when it's ready to resume running.
+ virtual int OnWrite(SocketStream* socket, const char* data, int len,
+ CompletionCallback* callback) {
+ // No throttle by default.
+ return OK;
+ }
+
+ // Called when |socket| is closed.
+ // May wake up another waiting socket.
+ virtual void OnClose(SocketStream* socket) {}
+
+ // Gets SocketStreamThrottle for URL |scheme|.
+ // Doesn't pass ownership of the SocketStreamThrottle.
+ static SocketStreamThrottle* GetSocketStreamThrottleForScheme(
+ const std::string& scheme);
+
+ // Registers |throttle| for URL |scheme|.
+ // Doesn't take ownership of |throttle|. Typically |throttle| is
+ // singleton instance.
+ static void RegisterSocketStreamThrottle(
+ const std::string& scheme,
+ SocketStreamThrottle* throttle);
+
+ protected:
+ SocketStreamThrottle() {}
+ virtual ~SocketStreamThrottle() {}
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(SocketStreamThrottle);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_STREAM_SOCKET_STREAM_THROTTLE_H_
diff --git a/net/websockets/websocket_throttle.cc b/net/websockets/websocket_throttle.cc
new file mode 100644
index 0000000..fb320b6
--- /dev/null
+++ b/net/websockets/websocket_throttle.cc
@@ -0,0 +1,293 @@
+// Copyright (c) 2009 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/websockets/websocket_throttle.h"
+
+#if defined(OS_WIN)
+#include <ws2tcpip.h>
+#else
+#include <netdb.h>
+#endif
+
+#include <string>
+
+#include "base/message_loop.h"
+#include "base/ref_counted.h"
+#include "base/singleton.h"
+#include "base/string_util.h"
+#include "net/base/io_buffer.h"
+#include "net/socket_stream/socket_stream.h"
+
+namespace net {
+
+static std::string AddrinfoToHashkey(const struct addrinfo* addrinfo) {
+ switch (addrinfo->ai_family) {
+ case AF_INET: {
+ const struct sockaddr_in* const addr =
+ reinterpret_cast<const sockaddr_in*>(addrinfo->ai_addr);
+ return StringPrintf("%d:%s",
+ addrinfo->ai_family,
+ HexEncode(&addr->sin_addr, 4).c_str());
+ }
+ case AF_INET6: {
+ const struct sockaddr_in6* const addr6 =
+ reinterpret_cast<const sockaddr_in6*>(addrinfo->ai_addr);
+ return StringPrintf("%d:%s",
+ addrinfo->ai_family,
+ HexEncode(&addr6->sin6_addr,
+ sizeof(addr6->sin6_addr)).c_str());
+ }
+ default:
+ return StringPrintf("%d:%s",
+ addrinfo->ai_family,
+ HexEncode(addrinfo->ai_addr,
+ addrinfo->ai_addrlen).c_str());
+ }
+}
+
+// State for WebSocket protocol on each SocketStream.
+// This is owned in SocketStream as UserData keyed by WebSocketState::kKeyName.
+// This is alive between connection starts and handshake is finished.
+// In this class, it doesn't check actual handshake finishes, but only checks
+// end of header is found in read data.
+class WebSocketThrottle::WebSocketState : public SocketStream::UserData {
+ public:
+ explicit WebSocketState(const AddressList& addrs)
+ : address_list_(addrs),
+ callback_(NULL),
+ waiting_(false),
+ handshake_finished_(false),
+ buffer_(NULL) {
+ }
+ ~WebSocketState() {}
+
+ int OnStartOpenConnection(CompletionCallback* callback) {
+ DCHECK(!callback_);
+ if (!waiting_)
+ return OK;
+ callback_ = callback;
+ return ERR_IO_PENDING;
+ }
+
+ int OnRead(const char* data, int len, CompletionCallback* callback) {
+ DCHECK(!waiting_);
+ DCHECK(!callback_);
+ DCHECK(!handshake_finished_);
+ static const int kBufferSize = 8129;
+
+ if (!buffer_) {
+ // Fast path.
+ int eoh = HttpUtil::LocateEndOfHeaders(data, len, 0);
+ if (eoh > 0) {
+ handshake_finished_ = true;
+ return OK;
+ }
+ buffer_ = new GrowableIOBuffer();
+ buffer_->SetCapacity(kBufferSize);
+ } else {
+ if (buffer_->RemainingCapacity() < len) {
+ if (!buffer_->SetCapacity(buffer_->capacity() + kBufferSize)) {
+ // TODO(ukai): Check more correctly.
+ // Seek to the last CR or LF and reduce memory usage.
+ LOG(ERROR) << "Too large headers? capacity=" << buffer_->capacity();
+ handshake_finished_ = true;
+ return OK;
+ }
+ }
+ }
+ memcpy(buffer_->data(), data, len);
+ buffer_->set_offset(buffer_->offset() + len);
+
+ int eoh = HttpUtil::LocateEndOfHeaders(buffer_->StartOfBuffer(),
+ buffer_->offset(), 0);
+ handshake_finished_ = (eoh > 0);
+ return OK;
+ }
+
+ const AddressList& address_list() const { return address_list_; }
+ void SetWaiting() { waiting_ = true; }
+ bool IsWaiting() const { return waiting_; }
+ bool HandshakeFinished() const { return handshake_finished_; }
+ void Wakeup() {
+ waiting_ = false;
+ // We wrap |callback_| to keep this alive while this is released.
+ scoped_refptr<CompletionCallbackRunner> runner =
+ new CompletionCallbackRunner(callback_);
+ callback_ = NULL;
+ MessageLoopForIO::current()->PostTask(
+ FROM_HERE,
+ NewRunnableMethod(runner.get(),
+ &CompletionCallbackRunner::Run));
+ }
+
+ static const char* kKeyName;
+
+ private:
+ class CompletionCallbackRunner
+ : public base::RefCountedThreadSafe<CompletionCallbackRunner> {
+ public:
+ explicit CompletionCallbackRunner(CompletionCallback* callback)
+ : callback_(callback) {
+ DCHECK(callback_);
+ }
+ virtual ~CompletionCallbackRunner() {}
+ void Run() {
+ callback_->Run(OK);
+ }
+ private:
+ CompletionCallback* callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(CompletionCallbackRunner);
+ };
+
+ const AddressList& address_list_;
+
+ CompletionCallback* callback_;
+ // True if waiting another websocket connection is established.
+ // False if the websocket is performing handshaking.
+ bool waiting_;
+
+ // True if the websocket handshake is completed.
+ // If true, it will be removed from queue and deleted from the SocketStream
+ // UserData soon.
+ bool handshake_finished_;
+
+ // Buffer for read data to check handshake response message.
+ scoped_refptr<GrowableIOBuffer> buffer_;
+
+ DISALLOW_COPY_AND_ASSIGN(WebSocketState);
+};
+
+const char* WebSocketThrottle::WebSocketState::kKeyName = "WebSocketState";
+
+WebSocketThrottle::WebSocketThrottle() {
+ SocketStreamThrottle::RegisterSocketStreamThrottle("ws", this);
+ SocketStreamThrottle::RegisterSocketStreamThrottle("wss", this);
+}
+
+WebSocketThrottle::~WebSocketThrottle() {
+ DCHECK(queue_.empty());
+ DCHECK(addr_map_.empty());
+}
+
+int WebSocketThrottle::OnStartOpenConnection(
+ SocketStream* socket, CompletionCallback* callback) {
+ WebSocketState* state = new WebSocketState(socket->address_list());
+ PutInQueue(socket, state);
+ return state->OnStartOpenConnection(callback);
+}
+
+int WebSocketThrottle::OnRead(SocketStream* socket,
+ const char* data, int len,
+ CompletionCallback* callback) {
+ WebSocketState* state = static_cast<WebSocketState*>(
+ socket->GetUserData(WebSocketState::kKeyName));
+ // If no state, handshake was already completed. Do nothing.
+ if (!state)
+ return OK;
+
+ int result = state->OnRead(data, len, callback);
+ if (state->HandshakeFinished()) {
+ RemoveFromQueue(socket, state);
+ WakeupSocketIfNecessary();
+ }
+ return result;
+}
+
+int WebSocketThrottle::OnWrite(SocketStream* socket,
+ const char* data, int len,
+ CompletionCallback* callback) {
+ // Do nothing.
+ return OK;
+}
+
+void WebSocketThrottle::OnClose(SocketStream* socket) {
+ WebSocketState* state = static_cast<WebSocketState*>(
+ socket->GetUserData(WebSocketState::kKeyName));
+ if (!state)
+ return;
+ RemoveFromQueue(socket, state);
+ WakeupSocketIfNecessary();
+}
+
+void WebSocketThrottle::PutInQueue(SocketStream* socket,
+ WebSocketState* state) {
+ socket->SetUserData(WebSocketState::kKeyName, state);
+ queue_.push_back(state);
+ const AddressList& address_list = socket->address_list();
+ for (const struct addrinfo* addrinfo = address_list.head();
+ addrinfo != NULL;
+ addrinfo = addrinfo->ai_next) {
+ std::string addrkey = AddrinfoToHashkey(addrinfo);
+ ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
+ if (iter == addr_map_.end()) {
+ ConnectingQueue* queue = new ConnectingQueue();
+ queue->push_back(state);
+ addr_map_[addrkey] = queue;
+ } else {
+ iter->second->push_back(state);
+ state->SetWaiting();
+ }
+ }
+}
+
+void WebSocketThrottle::RemoveFromQueue(SocketStream* socket,
+ WebSocketState* state) {
+ const AddressList& address_list = socket->address_list();
+ for (const struct addrinfo* addrinfo = address_list.head();
+ addrinfo != NULL;
+ addrinfo = addrinfo->ai_next) {
+ std::string addrkey = AddrinfoToHashkey(addrinfo);
+ ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
+ DCHECK(iter != addr_map_.end());
+ ConnectingQueue* queue = iter->second;
+ DCHECK(state == queue->front());
+ queue->pop_front();
+ if (queue->empty())
+ addr_map_.erase(iter);
+ }
+ for (ConnectingQueue::iterator iter = queue_.begin();
+ iter != queue_.end();
+ ++iter) {
+ if (*iter == state) {
+ queue_.erase(iter);
+ break;
+ }
+ }
+ socket->SetUserData(WebSocketState::kKeyName, NULL);
+}
+
+void WebSocketThrottle::WakeupSocketIfNecessary() {
+ for (ConnectingQueue::iterator iter = queue_.begin();
+ iter != queue_.end();
+ ++iter) {
+ WebSocketState* state = *iter;
+ if (!state->IsWaiting())
+ continue;
+
+ bool should_wakeup = true;
+ const AddressList& address_list = state->address_list();
+ for (const struct addrinfo* addrinfo = address_list.head();
+ addrinfo != NULL;
+ addrinfo = addrinfo->ai_next) {
+ std::string addrkey = AddrinfoToHashkey(addrinfo);
+ ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
+ DCHECK(iter != addr_map_.end());
+ ConnectingQueue* queue = iter->second;
+ if (state != queue->front()) {
+ should_wakeup = false;
+ break;
+ }
+ }
+ if (should_wakeup)
+ state->Wakeup();
+ }
+}
+
+/* static */
+void WebSocketThrottle::Init() {
+ Singleton<WebSocketThrottle>::get();
+}
+
+} // namespace net
diff --git a/net/websockets/websocket_throttle.h b/net/websockets/websocket_throttle.h
new file mode 100644
index 0000000..279aea2
--- /dev/null
+++ b/net/websockets/websocket_throttle.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2009 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_WEBSOCKETS_WEBSOCKET_THROTTLE_H_
+#define NET_WEBSOCKETS_WEBSOCKET_THROTTLE_H_
+
+#include "base/hash_tables.h"
+#include "base/singleton.h"
+#include "net/socket_stream/socket_stream_throttle.h"
+
+namespace net {
+
+// SocketStreamThrottle for WebSocket protocol.
+// Implements the client-side requirements in the spec.
+// http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol
+// 4.1 Handshake
+// 1. If the user agent already has a Web Socket connection to the
+// remote host (IP address) identified by /host/, even if known by
+// another name, wait until that connection has been established or
+// for that connection to have failed.
+class WebSocketThrottle : public SocketStreamThrottle {
+ public:
+ virtual int OnStartOpenConnection(SocketStream* socket,
+ CompletionCallback* callback);
+ virtual int OnRead(SocketStream* socket, const char* data, int len,
+ CompletionCallback* callback);
+ virtual int OnWrite(SocketStream* socket, const char* data, int len,
+ CompletionCallback* callback);
+ virtual void OnClose(SocketStream* socket);
+
+ static void Init();
+
+ private:
+ class WebSocketState;
+ typedef std::deque<WebSocketState*> ConnectingQueue;
+ typedef base::hash_map<std::string, ConnectingQueue*> ConnectingAddressMap;
+
+ WebSocketThrottle();
+ virtual ~WebSocketThrottle();
+ friend struct DefaultSingletonTraits<WebSocketThrottle>;
+
+ // Puts |socket| in |queue_| and queues for the destination addresses
+ // of |socket|. Also sets |state| as UserData of |socket|.
+ // If other socket is using the same destination address, set |state| waiting.
+ void PutInQueue(SocketStream* socket, WebSocketState* state);
+
+ // Removes |socket| from |queue_| and queues for the destination addresses
+ // of |socket|. Also releases |state| from UserData of |socket|.
+ void RemoveFromQueue(SocketStream* socket, WebSocketState* state);
+
+ // Checks sockets waiting in |queue_| and check the socket is the front of
+ // every queue for the destination addresses of |socket|.
+ // If so, the socket can resume estabilshing connection, so wake up
+ // the socket.
+ void WakeupSocketIfNecessary();
+
+ // Key: string of host's address. Value: queue of sockets for the address.
+ ConnectingAddressMap addr_map_;
+
+ // Queue of sockets for websockets in opening state.
+ ConnectingQueue queue_;
+};
+
+} // namespace net
+
+#endif // NET_WEBSOCKETS_WEBSOCKET_THROTTLE_H_
diff --git a/net/websockets/websocket_throttle_unittest.cc b/net/websockets/websocket_throttle_unittest.cc
new file mode 100644
index 0000000..3757a0b
--- /dev/null
+++ b/net/websockets/websocket_throttle_unittest.cc
@@ -0,0 +1,166 @@
+// Copyright (c) 2009 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "build/build_config.h"
+
+#if defined(OS_WIN)
+#include <ws2tcpip.h>
+#else
+#include <netdb.h>
+#endif
+
+#include <string>
+
+#include "base/message_loop.h"
+#include "googleurl/src/gurl.h"
+#include "net/base/address_list.h"
+#include "net/base/test_completion_callback.h"
+#include "net/socket_stream/socket_stream.h"
+#include "net/websockets/websocket_throttle.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "testing/platform_test.h"
+
+class DummySocketStreamDelegate : public net::SocketStream::Delegate {
+ public:
+ DummySocketStreamDelegate() {}
+ virtual ~DummySocketStreamDelegate() {}
+ virtual void OnConnected(
+ net::SocketStream* socket, int max_pending_send_allowed) {}
+ virtual void OnSentData(net::SocketStream* socket, int amount_sent) {}
+ virtual void OnReceivedData(net::SocketStream* socket,
+ const char* data, int len) {}
+ virtual void OnClose(net::SocketStream* socket) {}
+};
+
+namespace net {
+
+class WebSocketThrottleTest : public PlatformTest {
+ protected:
+ struct addrinfo *AddAddr(int a1, int a2, int a3, int a4,
+ struct addrinfo* next) {
+ struct addrinfo* addrinfo = new struct addrinfo;
+ memset(addrinfo, 0, sizeof(struct addrinfo));
+ addrinfo->ai_family = AF_INET;
+ int addrlen = sizeof(struct sockaddr_in);
+ addrinfo->ai_addrlen = addrlen;
+ addrinfo->ai_addr = reinterpret_cast<sockaddr*>(new char[addrlen]);
+ memset(addrinfo->ai_addr, 0, sizeof(addrlen));
+ struct sockaddr_in* addr =
+ reinterpret_cast<sockaddr_in*>(addrinfo->ai_addr);
+ int addrint = ((a1 & 0xff) << 24) |
+ ((a2 & 0xff) << 16) |
+ ((a3 & 0xff) << 8) |
+ ((a4 & 0xff));
+ memcpy(&addr->sin_addr, &addrint, sizeof(int));
+ addrinfo->ai_next = next;
+ return addrinfo;
+ }
+ void DeleteAddrInfo(struct addrinfo* head) {
+ if (!head)
+ return;
+ struct addrinfo* next;
+ for (struct addrinfo* a = head; a != NULL; a = next) {
+ next = a->ai_next;
+ delete [] a->ai_addr;
+ delete a;
+ }
+ }
+
+ static void SetAddressList(SocketStream* socket, struct addrinfo* head) {
+ socket->CopyAddrInfo(head);
+ }
+};
+
+TEST_F(WebSocketThrottleTest, Throttle) {
+ WebSocketThrottle::Init();
+ DummySocketStreamDelegate delegate;
+
+ WebSocketThrottle* throttle = Singleton<WebSocketThrottle>::get();
+
+ EXPECT_EQ(throttle,
+ SocketStreamThrottle::GetSocketStreamThrottleForScheme("ws"));
+ EXPECT_EQ(throttle,
+ SocketStreamThrottle::GetSocketStreamThrottleForScheme("wss"));
+
+ // For host1: 1.2.3.4, 1.2.3.5, 1.2.3.6
+ struct addrinfo* addr = AddAddr(1, 2, 3, 4, NULL);
+ addr = AddAddr(1, 2, 3, 5, addr);
+ addr = AddAddr(1, 2, 3, 6, addr);
+ scoped_refptr<SocketStream> s1 =
+ new SocketStream(GURL("ws://host1/"), &delegate);
+ WebSocketThrottleTest::SetAddressList(s1, addr);
+ DeleteAddrInfo(addr);
+
+ TestCompletionCallback callback_s1;
+ EXPECT_EQ(OK, throttle->OnStartOpenConnection(s1, &callback_s1));
+
+ // For host2: 1.2.3.4
+ addr = AddAddr(1, 2, 3, 4, NULL);
+ scoped_refptr<SocketStream> s2 =
+ new SocketStream(GURL("ws://host2/"), &delegate);
+ WebSocketThrottleTest::SetAddressList(s2, addr);
+ DeleteAddrInfo(addr);
+
+ TestCompletionCallback callback_s2;
+ EXPECT_EQ(ERR_IO_PENDING, throttle->OnStartOpenConnection(s2, &callback_s2));
+
+ // For host3: 1.2.3.5
+ addr = AddAddr(1, 2, 3, 5, NULL);
+ scoped_refptr<SocketStream> s3 =
+ new SocketStream(GURL("ws://host3/"), &delegate);
+ WebSocketThrottleTest::SetAddressList(s3, addr);
+ DeleteAddrInfo(addr);
+
+ TestCompletionCallback callback_s3;
+ EXPECT_EQ(ERR_IO_PENDING, throttle->OnStartOpenConnection(s3, &callback_s3));
+
+ // For host4: 1.2.3.4, 1.2.3.6
+ addr = AddAddr(1, 2, 3, 4, NULL);
+ addr = AddAddr(1, 2, 3, 6, addr);
+ scoped_refptr<SocketStream> s4 =
+ new SocketStream(GURL("ws://host4/"), &delegate);
+ WebSocketThrottleTest::SetAddressList(s4, addr);
+ DeleteAddrInfo(addr);
+
+ TestCompletionCallback callback_s4;
+ EXPECT_EQ(ERR_IO_PENDING, throttle->OnStartOpenConnection(s4, &callback_s4));
+
+ static const char kHeader[] = "HTTP/1.1 101 Web Socket Protocol\r\n";
+ EXPECT_EQ(OK,
+ throttle->OnRead(s1.get(), kHeader, sizeof(kHeader) - 1, NULL));
+ EXPECT_FALSE(callback_s2.have_result());
+ EXPECT_FALSE(callback_s3.have_result());
+ EXPECT_FALSE(callback_s4.have_result());
+
+ static const char kHeader2[] =
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "WebSocket-Origin: http://www.google.com\r\n"
+ "WebSocket-Location: ws://websocket.chromium.org\r\n"
+ "\r\n";
+ EXPECT_EQ(OK,
+ throttle->OnRead(s1.get(), kHeader2, sizeof(kHeader2) - 1, NULL));
+ MessageLoopForIO::current()->RunAllPending();
+ EXPECT_TRUE(callback_s2.have_result());
+ EXPECT_TRUE(callback_s3.have_result());
+ EXPECT_FALSE(callback_s4.have_result());
+
+ throttle->OnClose(s1.get());
+ MessageLoopForIO::current()->RunAllPending();
+ EXPECT_FALSE(callback_s4.have_result());
+ s1->DetachDelegate();
+
+ throttle->OnClose(s2.get());
+ MessageLoopForIO::current()->RunAllPending();
+ EXPECT_TRUE(callback_s4.have_result());
+ s2->DetachDelegate();
+
+ throttle->OnClose(s3.get());
+ MessageLoopForIO::current()->RunAllPending();
+ s3->DetachDelegate();
+ throttle->OnClose(s4.get());
+ s4->DetachDelegate();
+}
+
+}