diff options
Diffstat (limited to 'net/websockets')
-rw-r--r-- | net/websockets/websocket_throttle.cc | 293 | ||||
-rw-r--r-- | net/websockets/websocket_throttle.h | 67 | ||||
-rw-r--r-- | net/websockets/websocket_throttle_unittest.cc | 166 |
3 files changed, 526 insertions, 0 deletions
diff --git a/net/websockets/websocket_throttle.cc b/net/websockets/websocket_throttle.cc new file mode 100644 index 0000000..fb320b6 --- /dev/null +++ b/net/websockets/websocket_throttle.cc @@ -0,0 +1,293 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_throttle.h" + +#if defined(OS_WIN) +#include <ws2tcpip.h> +#else +#include <netdb.h> +#endif + +#include <string> + +#include "base/message_loop.h" +#include "base/ref_counted.h" +#include "base/singleton.h" +#include "base/string_util.h" +#include "net/base/io_buffer.h" +#include "net/socket_stream/socket_stream.h" + +namespace net { + +static std::string AddrinfoToHashkey(const struct addrinfo* addrinfo) { + switch (addrinfo->ai_family) { + case AF_INET: { + const struct sockaddr_in* const addr = + reinterpret_cast<const sockaddr_in*>(addrinfo->ai_addr); + return StringPrintf("%d:%s", + addrinfo->ai_family, + HexEncode(&addr->sin_addr, 4).c_str()); + } + case AF_INET6: { + const struct sockaddr_in6* const addr6 = + reinterpret_cast<const sockaddr_in6*>(addrinfo->ai_addr); + return StringPrintf("%d:%s", + addrinfo->ai_family, + HexEncode(&addr6->sin6_addr, + sizeof(addr6->sin6_addr)).c_str()); + } + default: + return StringPrintf("%d:%s", + addrinfo->ai_family, + HexEncode(addrinfo->ai_addr, + addrinfo->ai_addrlen).c_str()); + } +} + +// State for WebSocket protocol on each SocketStream. +// This is owned in SocketStream as UserData keyed by WebSocketState::kKeyName. +// This is alive between connection starts and handshake is finished. +// In this class, it doesn't check actual handshake finishes, but only checks +// end of header is found in read data. +class WebSocketThrottle::WebSocketState : public SocketStream::UserData { + public: + explicit WebSocketState(const AddressList& addrs) + : address_list_(addrs), + callback_(NULL), + waiting_(false), + handshake_finished_(false), + buffer_(NULL) { + } + ~WebSocketState() {} + + int OnStartOpenConnection(CompletionCallback* callback) { + DCHECK(!callback_); + if (!waiting_) + return OK; + callback_ = callback; + return ERR_IO_PENDING; + } + + int OnRead(const char* data, int len, CompletionCallback* callback) { + DCHECK(!waiting_); + DCHECK(!callback_); + DCHECK(!handshake_finished_); + static const int kBufferSize = 8129; + + if (!buffer_) { + // Fast path. + int eoh = HttpUtil::LocateEndOfHeaders(data, len, 0); + if (eoh > 0) { + handshake_finished_ = true; + return OK; + } + buffer_ = new GrowableIOBuffer(); + buffer_->SetCapacity(kBufferSize); + } else { + if (buffer_->RemainingCapacity() < len) { + if (!buffer_->SetCapacity(buffer_->capacity() + kBufferSize)) { + // TODO(ukai): Check more correctly. + // Seek to the last CR or LF and reduce memory usage. + LOG(ERROR) << "Too large headers? capacity=" << buffer_->capacity(); + handshake_finished_ = true; + return OK; + } + } + } + memcpy(buffer_->data(), data, len); + buffer_->set_offset(buffer_->offset() + len); + + int eoh = HttpUtil::LocateEndOfHeaders(buffer_->StartOfBuffer(), + buffer_->offset(), 0); + handshake_finished_ = (eoh > 0); + return OK; + } + + const AddressList& address_list() const { return address_list_; } + void SetWaiting() { waiting_ = true; } + bool IsWaiting() const { return waiting_; } + bool HandshakeFinished() const { return handshake_finished_; } + void Wakeup() { + waiting_ = false; + // We wrap |callback_| to keep this alive while this is released. + scoped_refptr<CompletionCallbackRunner> runner = + new CompletionCallbackRunner(callback_); + callback_ = NULL; + MessageLoopForIO::current()->PostTask( + FROM_HERE, + NewRunnableMethod(runner.get(), + &CompletionCallbackRunner::Run)); + } + + static const char* kKeyName; + + private: + class CompletionCallbackRunner + : public base::RefCountedThreadSafe<CompletionCallbackRunner> { + public: + explicit CompletionCallbackRunner(CompletionCallback* callback) + : callback_(callback) { + DCHECK(callback_); + } + virtual ~CompletionCallbackRunner() {} + void Run() { + callback_->Run(OK); + } + private: + CompletionCallback* callback_; + + DISALLOW_COPY_AND_ASSIGN(CompletionCallbackRunner); + }; + + const AddressList& address_list_; + + CompletionCallback* callback_; + // True if waiting another websocket connection is established. + // False if the websocket is performing handshaking. + bool waiting_; + + // True if the websocket handshake is completed. + // If true, it will be removed from queue and deleted from the SocketStream + // UserData soon. + bool handshake_finished_; + + // Buffer for read data to check handshake response message. + scoped_refptr<GrowableIOBuffer> buffer_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketState); +}; + +const char* WebSocketThrottle::WebSocketState::kKeyName = "WebSocketState"; + +WebSocketThrottle::WebSocketThrottle() { + SocketStreamThrottle::RegisterSocketStreamThrottle("ws", this); + SocketStreamThrottle::RegisterSocketStreamThrottle("wss", this); +} + +WebSocketThrottle::~WebSocketThrottle() { + DCHECK(queue_.empty()); + DCHECK(addr_map_.empty()); +} + +int WebSocketThrottle::OnStartOpenConnection( + SocketStream* socket, CompletionCallback* callback) { + WebSocketState* state = new WebSocketState(socket->address_list()); + PutInQueue(socket, state); + return state->OnStartOpenConnection(callback); +} + +int WebSocketThrottle::OnRead(SocketStream* socket, + const char* data, int len, + CompletionCallback* callback) { + WebSocketState* state = static_cast<WebSocketState*>( + socket->GetUserData(WebSocketState::kKeyName)); + // If no state, handshake was already completed. Do nothing. + if (!state) + return OK; + + int result = state->OnRead(data, len, callback); + if (state->HandshakeFinished()) { + RemoveFromQueue(socket, state); + WakeupSocketIfNecessary(); + } + return result; +} + +int WebSocketThrottle::OnWrite(SocketStream* socket, + const char* data, int len, + CompletionCallback* callback) { + // Do nothing. + return OK; +} + +void WebSocketThrottle::OnClose(SocketStream* socket) { + WebSocketState* state = static_cast<WebSocketState*>( + socket->GetUserData(WebSocketState::kKeyName)); + if (!state) + return; + RemoveFromQueue(socket, state); + WakeupSocketIfNecessary(); +} + +void WebSocketThrottle::PutInQueue(SocketStream* socket, + WebSocketState* state) { + socket->SetUserData(WebSocketState::kKeyName, state); + queue_.push_back(state); + const AddressList& address_list = socket->address_list(); + for (const struct addrinfo* addrinfo = address_list.head(); + addrinfo != NULL; + addrinfo = addrinfo->ai_next) { + std::string addrkey = AddrinfoToHashkey(addrinfo); + ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); + if (iter == addr_map_.end()) { + ConnectingQueue* queue = new ConnectingQueue(); + queue->push_back(state); + addr_map_[addrkey] = queue; + } else { + iter->second->push_back(state); + state->SetWaiting(); + } + } +} + +void WebSocketThrottle::RemoveFromQueue(SocketStream* socket, + WebSocketState* state) { + const AddressList& address_list = socket->address_list(); + for (const struct addrinfo* addrinfo = address_list.head(); + addrinfo != NULL; + addrinfo = addrinfo->ai_next) { + std::string addrkey = AddrinfoToHashkey(addrinfo); + ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); + DCHECK(iter != addr_map_.end()); + ConnectingQueue* queue = iter->second; + DCHECK(state == queue->front()); + queue->pop_front(); + if (queue->empty()) + addr_map_.erase(iter); + } + for (ConnectingQueue::iterator iter = queue_.begin(); + iter != queue_.end(); + ++iter) { + if (*iter == state) { + queue_.erase(iter); + break; + } + } + socket->SetUserData(WebSocketState::kKeyName, NULL); +} + +void WebSocketThrottle::WakeupSocketIfNecessary() { + for (ConnectingQueue::iterator iter = queue_.begin(); + iter != queue_.end(); + ++iter) { + WebSocketState* state = *iter; + if (!state->IsWaiting()) + continue; + + bool should_wakeup = true; + const AddressList& address_list = state->address_list(); + for (const struct addrinfo* addrinfo = address_list.head(); + addrinfo != NULL; + addrinfo = addrinfo->ai_next) { + std::string addrkey = AddrinfoToHashkey(addrinfo); + ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); + DCHECK(iter != addr_map_.end()); + ConnectingQueue* queue = iter->second; + if (state != queue->front()) { + should_wakeup = false; + break; + } + } + if (should_wakeup) + state->Wakeup(); + } +} + +/* static */ +void WebSocketThrottle::Init() { + Singleton<WebSocketThrottle>::get(); +} + +} // namespace net diff --git a/net/websockets/websocket_throttle.h b/net/websockets/websocket_throttle.h new file mode 100644 index 0000000..279aea2 --- /dev/null +++ b/net/websockets/websocket_throttle.h @@ -0,0 +1,67 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_WEBSOCKETS_WEBSOCKET_THROTTLE_H_ +#define NET_WEBSOCKETS_WEBSOCKET_THROTTLE_H_ + +#include "base/hash_tables.h" +#include "base/singleton.h" +#include "net/socket_stream/socket_stream_throttle.h" + +namespace net { + +// SocketStreamThrottle for WebSocket protocol. +// Implements the client-side requirements in the spec. +// http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol +// 4.1 Handshake +// 1. If the user agent already has a Web Socket connection to the +// remote host (IP address) identified by /host/, even if known by +// another name, wait until that connection has been established or +// for that connection to have failed. +class WebSocketThrottle : public SocketStreamThrottle { + public: + virtual int OnStartOpenConnection(SocketStream* socket, + CompletionCallback* callback); + virtual int OnRead(SocketStream* socket, const char* data, int len, + CompletionCallback* callback); + virtual int OnWrite(SocketStream* socket, const char* data, int len, + CompletionCallback* callback); + virtual void OnClose(SocketStream* socket); + + static void Init(); + + private: + class WebSocketState; + typedef std::deque<WebSocketState*> ConnectingQueue; + typedef base::hash_map<std::string, ConnectingQueue*> ConnectingAddressMap; + + WebSocketThrottle(); + virtual ~WebSocketThrottle(); + friend struct DefaultSingletonTraits<WebSocketThrottle>; + + // Puts |socket| in |queue_| and queues for the destination addresses + // of |socket|. Also sets |state| as UserData of |socket|. + // If other socket is using the same destination address, set |state| waiting. + void PutInQueue(SocketStream* socket, WebSocketState* state); + + // Removes |socket| from |queue_| and queues for the destination addresses + // of |socket|. Also releases |state| from UserData of |socket|. + void RemoveFromQueue(SocketStream* socket, WebSocketState* state); + + // Checks sockets waiting in |queue_| and check the socket is the front of + // every queue for the destination addresses of |socket|. + // If so, the socket can resume estabilshing connection, so wake up + // the socket. + void WakeupSocketIfNecessary(); + + // Key: string of host's address. Value: queue of sockets for the address. + ConnectingAddressMap addr_map_; + + // Queue of sockets for websockets in opening state. + ConnectingQueue queue_; +}; + +} // namespace net + +#endif // NET_WEBSOCKETS_WEBSOCKET_THROTTLE_H_ diff --git a/net/websockets/websocket_throttle_unittest.cc b/net/websockets/websocket_throttle_unittest.cc new file mode 100644 index 0000000..3757a0b --- /dev/null +++ b/net/websockets/websocket_throttle_unittest.cc @@ -0,0 +1,166 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "build/build_config.h" + +#if defined(OS_WIN) +#include <ws2tcpip.h> +#else +#include <netdb.h> +#endif + +#include <string> + +#include "base/message_loop.h" +#include "googleurl/src/gurl.h" +#include "net/base/address_list.h" +#include "net/base/test_completion_callback.h" +#include "net/socket_stream/socket_stream.h" +#include "net/websockets/websocket_throttle.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +class DummySocketStreamDelegate : public net::SocketStream::Delegate { + public: + DummySocketStreamDelegate() {} + virtual ~DummySocketStreamDelegate() {} + virtual void OnConnected( + net::SocketStream* socket, int max_pending_send_allowed) {} + virtual void OnSentData(net::SocketStream* socket, int amount_sent) {} + virtual void OnReceivedData(net::SocketStream* socket, + const char* data, int len) {} + virtual void OnClose(net::SocketStream* socket) {} +}; + +namespace net { + +class WebSocketThrottleTest : public PlatformTest { + protected: + struct addrinfo *AddAddr(int a1, int a2, int a3, int a4, + struct addrinfo* next) { + struct addrinfo* addrinfo = new struct addrinfo; + memset(addrinfo, 0, sizeof(struct addrinfo)); + addrinfo->ai_family = AF_INET; + int addrlen = sizeof(struct sockaddr_in); + addrinfo->ai_addrlen = addrlen; + addrinfo->ai_addr = reinterpret_cast<sockaddr*>(new char[addrlen]); + memset(addrinfo->ai_addr, 0, sizeof(addrlen)); + struct sockaddr_in* addr = + reinterpret_cast<sockaddr_in*>(addrinfo->ai_addr); + int addrint = ((a1 & 0xff) << 24) | + ((a2 & 0xff) << 16) | + ((a3 & 0xff) << 8) | + ((a4 & 0xff)); + memcpy(&addr->sin_addr, &addrint, sizeof(int)); + addrinfo->ai_next = next; + return addrinfo; + } + void DeleteAddrInfo(struct addrinfo* head) { + if (!head) + return; + struct addrinfo* next; + for (struct addrinfo* a = head; a != NULL; a = next) { + next = a->ai_next; + delete [] a->ai_addr; + delete a; + } + } + + static void SetAddressList(SocketStream* socket, struct addrinfo* head) { + socket->CopyAddrInfo(head); + } +}; + +TEST_F(WebSocketThrottleTest, Throttle) { + WebSocketThrottle::Init(); + DummySocketStreamDelegate delegate; + + WebSocketThrottle* throttle = Singleton<WebSocketThrottle>::get(); + + EXPECT_EQ(throttle, + SocketStreamThrottle::GetSocketStreamThrottleForScheme("ws")); + EXPECT_EQ(throttle, + SocketStreamThrottle::GetSocketStreamThrottleForScheme("wss")); + + // For host1: 1.2.3.4, 1.2.3.5, 1.2.3.6 + struct addrinfo* addr = AddAddr(1, 2, 3, 4, NULL); + addr = AddAddr(1, 2, 3, 5, addr); + addr = AddAddr(1, 2, 3, 6, addr); + scoped_refptr<SocketStream> s1 = + new SocketStream(GURL("ws://host1/"), &delegate); + WebSocketThrottleTest::SetAddressList(s1, addr); + DeleteAddrInfo(addr); + + TestCompletionCallback callback_s1; + EXPECT_EQ(OK, throttle->OnStartOpenConnection(s1, &callback_s1)); + + // For host2: 1.2.3.4 + addr = AddAddr(1, 2, 3, 4, NULL); + scoped_refptr<SocketStream> s2 = + new SocketStream(GURL("ws://host2/"), &delegate); + WebSocketThrottleTest::SetAddressList(s2, addr); + DeleteAddrInfo(addr); + + TestCompletionCallback callback_s2; + EXPECT_EQ(ERR_IO_PENDING, throttle->OnStartOpenConnection(s2, &callback_s2)); + + // For host3: 1.2.3.5 + addr = AddAddr(1, 2, 3, 5, NULL); + scoped_refptr<SocketStream> s3 = + new SocketStream(GURL("ws://host3/"), &delegate); + WebSocketThrottleTest::SetAddressList(s3, addr); + DeleteAddrInfo(addr); + + TestCompletionCallback callback_s3; + EXPECT_EQ(ERR_IO_PENDING, throttle->OnStartOpenConnection(s3, &callback_s3)); + + // For host4: 1.2.3.4, 1.2.3.6 + addr = AddAddr(1, 2, 3, 4, NULL); + addr = AddAddr(1, 2, 3, 6, addr); + scoped_refptr<SocketStream> s4 = + new SocketStream(GURL("ws://host4/"), &delegate); + WebSocketThrottleTest::SetAddressList(s4, addr); + DeleteAddrInfo(addr); + + TestCompletionCallback callback_s4; + EXPECT_EQ(ERR_IO_PENDING, throttle->OnStartOpenConnection(s4, &callback_s4)); + + static const char kHeader[] = "HTTP/1.1 101 Web Socket Protocol\r\n"; + EXPECT_EQ(OK, + throttle->OnRead(s1.get(), kHeader, sizeof(kHeader) - 1, NULL)); + EXPECT_FALSE(callback_s2.have_result()); + EXPECT_FALSE(callback_s3.have_result()); + EXPECT_FALSE(callback_s4.have_result()); + + static const char kHeader2[] = + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: http://www.google.com\r\n" + "WebSocket-Location: ws://websocket.chromium.org\r\n" + "\r\n"; + EXPECT_EQ(OK, + throttle->OnRead(s1.get(), kHeader2, sizeof(kHeader2) - 1, NULL)); + MessageLoopForIO::current()->RunAllPending(); + EXPECT_TRUE(callback_s2.have_result()); + EXPECT_TRUE(callback_s3.have_result()); + EXPECT_FALSE(callback_s4.have_result()); + + throttle->OnClose(s1.get()); + MessageLoopForIO::current()->RunAllPending(); + EXPECT_FALSE(callback_s4.have_result()); + s1->DetachDelegate(); + + throttle->OnClose(s2.get()); + MessageLoopForIO::current()->RunAllPending(); + EXPECT_TRUE(callback_s4.have_result()); + s2->DetachDelegate(); + + throttle->OnClose(s3.get()); + MessageLoopForIO::current()->RunAllPending(); + s3->DetachDelegate(); + throttle->OnClose(s4.get()); + s4->DetachDelegate(); +} + +} |