diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/net.gyp | 7 | ||||
-rw-r--r-- | net/socket_stream/socket_stream.cc | 4 | ||||
-rw-r--r-- | net/socket_stream/socket_stream.h | 21 | ||||
-rw-r--r-- | net/socket_stream/socket_stream_job.cc | 27 | ||||
-rw-r--r-- | net/socket_stream/socket_stream_job.h | 87 | ||||
-rw-r--r-- | net/socket_stream/socket_stream_job_manager.cc | 59 | ||||
-rw-r--r-- | net/socket_stream/socket_stream_job_manager.h | 40 | ||||
-rw-r--r-- | net/websockets/websocket_job.cc | 378 | ||||
-rw-r--r-- | net/websockets/websocket_job.h | 98 | ||||
-rw-r--r-- | net/websockets/websocket_job_unittest.cc | 495 |
10 files changed, 1205 insertions, 11 deletions
diff --git a/net/net.gyp b/net/net.gyp index b359f6a..dffd51f 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -441,6 +441,10 @@ 'socket/tcp_pinger.h', 'socket_stream/socket_stream.cc', 'socket_stream/socket_stream.h', + 'socket_stream/socket_stream_job.cc', + 'socket_stream/socket_stream_job.h', + 'socket_stream/socket_stream_job_manager.cc', + 'socket_stream/socket_stream_job_manager.h', 'socket_stream/socket_stream_metrics.cc', 'socket_stream/socket_stream_metrics.h', 'socket_stream/socket_stream_throttle.cc', @@ -503,6 +507,8 @@ 'url_request/view_cache_helper.h', 'websockets/websocket.cc', 'websockets/websocket.h', + 'websockets/websocket_job.cc', + 'websockets/websocket_job.h', 'websockets/websocket_throttle.cc', 'websockets/websocket_throttle.h', ], @@ -689,6 +695,7 @@ 'url_request/request_tracker_unittest.cc', 'url_request/url_request_unittest.cc', 'url_request/url_request_unittest.h', + 'websockets/websocket_job_unittest.cc', 'websockets/websocket_throttle_unittest.cc', 'websockets/websocket_unittest.cc', ], diff --git a/net/socket_stream/socket_stream.cc b/net/socket_stream/socket_stream.cc index 162e7f3..04c68b0 100644 --- a/net/socket_stream/socket_stream.cc +++ b/net/socket_stream/socket_stream.cc @@ -39,8 +39,8 @@ void SocketStream::ResponseHeaders::Realloc(size_t new_size) { } SocketStream::SocketStream(const GURL& url, Delegate* delegate) - : url_(url), - delegate_(delegate), + : delegate_(delegate), + url_(url), max_pending_send_allowed_(kMaxPendingSendAllowed), next_state_(STATE_NONE), http_auth_handler_factory_(NULL), diff --git a/net/socket_stream/socket_stream.h b/net/socket_stream/socket_stream.h index 5b1ae3e..1334c15 100644 --- a/net/socket_stream/socket_stream.h +++ b/net/socket_stream/socket_stream.h @@ -101,6 +101,7 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { void SetUserData(const void* key, UserData* data); const GURL& url() const { return url_; } + bool is_secure() const; const AddressList& address_list() const { return addresses_; } Delegate* delegate() const { return delegate_; } int max_pending_send_allowed() const { return max_pending_send_allowed_; } @@ -112,28 +113,28 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { // Opens the connection on the IO thread. // Once the connection is established, calls delegate's OnConnected. - void Connect(); + virtual void Connect(); // Requests to send |len| bytes of |data| on the connection. // Returns true if |data| is buffered in the job. // Returns false if size of buffered data would exceeds // |max_pending_send_allowed_| and |data| is not sent at all. - bool SendData(const char* data, int len); + virtual bool SendData(const char* data, int len); // Requests to close the connection. // Once the connection is closed, calls delegate's OnClose. - void Close(); + virtual void Close(); // Restarts with authentication info. // Should be used for response of OnAuthRequired. - void RestartWithAuth( + virtual void RestartWithAuth( const std::wstring& username, const std::wstring& password); // Detach delegate. Call before delegate is deleted. // Once delegate is detached, close the socket stream and never call delegate // back. - void DetachDelegate(); + virtual void DetachDelegate(); // Sets an alternative HostResolver. For testing purposes only. void SetHostResolver(HostResolver* host_resolver); @@ -142,6 +143,12 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { // |factory|. For testing purposes only. void SetClientSocketFactory(ClientSocketFactory* factory); + protected: + friend class base::RefCountedThreadSafe<SocketStream>; + ~SocketStream(); + + Delegate* delegate_; + private: class RequestHeaders : public IOBuffer { public: @@ -201,8 +208,6 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { typedef std::deque< scoped_refptr<IOBufferWithSize> > PendingDataQueue; friend class RequestTracker<SocketStream>; - friend class base::RefCountedThreadSafe<SocketStream>; - ~SocketStream(); friend class WebSocketThrottleTest; @@ -248,7 +253,6 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { int HandleCertificateError(int result); - bool is_secure() const; SSLConfigService* ssl_config_service() const; ProxyService* proxy_service() const; @@ -258,7 +262,6 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { scoped_refptr<LoadLog> load_log_; GURL url_; - Delegate* delegate_; int max_pending_send_allowed_; scoped_refptr<URLRequestContext> context_; diff --git a/net/socket_stream/socket_stream_job.cc b/net/socket_stream/socket_stream_job.cc new file mode 100644 index 0000000..c8849a5 --- /dev/null +++ b/net/socket_stream/socket_stream_job.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket_stream/socket_stream_job.h" + +#include "net/socket_stream/socket_stream_job_manager.h" + +namespace net { + +static SocketStreamJobManager* GetJobManager() { + return Singleton<SocketStreamJobManager>::get(); +} + +// static +SocketStreamJob::ProtocolFactory* SocketStreamJob::RegisterProtocolFactory( + const std::string& scheme, ProtocolFactory* factory) { + return GetJobManager()->RegisterProtocolFactory(scheme, factory); +} + +// static +SocketStreamJob* SocketStreamJob::CreateSocketStreamJob( + const GURL& url, SocketStream::Delegate* delegate) { + return GetJobManager()->CreateJob(url, delegate); +} + +} // namespace net diff --git a/net/socket_stream/socket_stream_job.h b/net/socket_stream/socket_stream_job.h new file mode 100644 index 0000000..618620c --- /dev/null +++ b/net/socket_stream/socket_stream_job.h @@ -0,0 +1,87 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_STREAM_SOCKET_STREAM_JOB_H_ +#define NET_SOCKET_STREAM_SOCKET_STREAM_JOB_H_ + +#include <string> + +#include "base/ref_counted.h" +#include "net/socket_stream/socket_stream.h" + +class GURL; + +namespace net { + +// SocketStreamJob represents full-duplex communication over SocketStream. +// If a protocol (e.g. WebSocket protocol) needs to inspect/modify data +// over SocketStream, you can implement protocol specific job (e.g. +// WebSocketJob) to do some work on data over SocketStream. +// Registers the protocol specific SocketStreamJob by RegisterProtocolFactory +// and call CreateSocketStreamJob to create SocketStreamJob for the URL. +class SocketStreamJob : public base::RefCountedThreadSafe<SocketStreamJob> { + public: + // Callback function implemented by protocol handlers to create new jobs. + typedef SocketStreamJob* (ProtocolFactory)(const GURL& url, + SocketStream::Delegate* delegate); + + static ProtocolFactory* RegisterProtocolFactory(const std::string& scheme, + ProtocolFactory* factory); + + static SocketStreamJob* CreateSocketStreamJob( + const GURL& url, SocketStream::Delegate* delegate); + + SocketStreamJob() {} + void InitSocketStream(SocketStream* socket) { + socket_ = socket; + } + + virtual SocketStream::UserData *GetUserData(const void* key) const { + return socket_->GetUserData(key); + } + virtual void SetUserData(const void* key, SocketStream::UserData* data) { + socket_->SetUserData(key, data); + } + + URLRequestContext* context() const { + return socket_->context(); + } + void set_context(URLRequestContext* context) { + socket_->set_context(context); + } + + virtual void Connect() { + socket_->Connect(); + } + + virtual bool SendData(const char* data, int len) { + return socket_->SendData(data, len); + } + + virtual void Close() { + socket_->Close(); + } + + virtual void RestartWithAuth( + const std::wstring& username, + const std::wstring& password) { + socket_->RestartWithAuth(username, password); + } + + virtual void DetachDelegate() { + socket_->DetachDelegate(); + } + + protected: + friend class base::RefCountedThreadSafe<SocketStreamJob>; + virtual ~SocketStreamJob() {} + + scoped_refptr<SocketStream> socket_; + + DISALLOW_COPY_AND_ASSIGN(SocketStreamJob); +}; + +} // namespace net + +#endif // NET_SOCKET_STREAM_SOCKET_STREAM_JOB_H_ diff --git a/net/socket_stream/socket_stream_job_manager.cc b/net/socket_stream/socket_stream_job_manager.cc new file mode 100644 index 0000000..7dd0d6b --- /dev/null +++ b/net/socket_stream/socket_stream_job_manager.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket_stream/socket_stream_job_manager.h" + +namespace net { + +SocketStreamJobManager::SocketStreamJobManager() { +} + +SocketStreamJobManager::~SocketStreamJobManager() { +} + +SocketStreamJob* SocketStreamJobManager::CreateJob( + const GURL& url, SocketStream::Delegate* delegate) const { + // If url is invalid, create plain SocketStreamJob, which will close + // the socket immediately. + if (!url.is_valid()) { + SocketStreamJob* job = new SocketStreamJob(); + job->InitSocketStream(new SocketStream(url, delegate)); + return job; + } + + const std::string& scheme = url.scheme(); // already lowercase + + AutoLock locked(lock_); + FactoryMap::const_iterator found = factories_.find(scheme); + if (found != factories_.end()) { + SocketStreamJob* job = found->second(url, delegate); + if (job) + return job; + } + SocketStreamJob* job = new SocketStreamJob(); + job->InitSocketStream(new SocketStream(url, delegate)); + return job; +} + +SocketStreamJob::ProtocolFactory* +SocketStreamJobManager::RegisterProtocolFactory( + const std::string& scheme, SocketStreamJob::ProtocolFactory* factory) { + AutoLock locked(lock_); + + SocketStreamJob::ProtocolFactory* old_factory; + FactoryMap::iterator found = factories_.find(scheme); + if (found != factories_.end()) { + old_factory = found->second; + } else { + old_factory = NULL; + } + if (factory) { + factories_[scheme] = factory; + } else if (found != factories_.end()) { + factories_.erase(found); + } + return old_factory; +} + +} // namespace net diff --git a/net/socket_stream/socket_stream_job_manager.h b/net/socket_stream/socket_stream_job_manager.h new file mode 100644 index 0000000..17ff833 --- /dev/null +++ b/net/socket_stream/socket_stream_job_manager.h @@ -0,0 +1,40 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_STREAM_SOCKET_STREAM_JOB_MANAGER_H_ +#define NET_SOCKET_STREAM_SOCKET_STREAM_JOB_MANAGER_H_ + +#include <map> +#include <string> + +#include "net/socket_stream/socket_stream.h" +#include "net/socket_stream/socket_stream_job.h" + +class GURL; + +namespace net { + +class SocketStreamJobManager { + public: + SocketStreamJobManager(); + ~SocketStreamJobManager(); + + SocketStreamJob* CreateJob( + const GURL& url, SocketStream::Delegate* delegate) const; + + SocketStreamJob::ProtocolFactory* RegisterProtocolFactory( + const std::string& scheme, SocketStreamJob::ProtocolFactory* factory); + + private: + typedef std::map<std::string, SocketStreamJob::ProtocolFactory*> FactoryMap; + + mutable Lock lock_; + FactoryMap factories_; + + DISALLOW_COPY_AND_ASSIGN(SocketStreamJobManager); +}; + +} // namespace net + +#endif // NET_SOCKET_STREAM_SOCKET_STREAM_JOB_MANAGER_H_ diff --git a/net/websockets/websocket_job.cc b/net/websockets/websocket_job.cc new file mode 100644 index 0000000..3ba0c36 --- /dev/null +++ b/net/websockets/websocket_job.cc @@ -0,0 +1,378 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_job.h" + +#include "googleurl/src/gurl.h" +#include "net/base/net_errors.h" +#include "net/base/cookie_policy.h" +#include "net/base/cookie_store.h" +#include "net/http/http_util.h" +#include "net/url_request/url_request_context.h" + +namespace net { + +// lower-case header names. +static const char* const kCookieHeaders[] = { + "cookie", "cookie2" +}; +static const char* const kSetCookieHeaders[] = { + "set-cookie", "set-cookie2" +}; + +static SocketStreamJob* WebSocketJobFactory( + const GURL& url, SocketStream::Delegate* delegate) { + WebSocketJob* job = new WebSocketJob(delegate); + job->InitSocketStream(new SocketStream(url, job)); + return job; +} + +class WebSocketJobInitSingleton { + private: + friend struct DefaultSingletonTraits<WebSocketJobInitSingleton>; + WebSocketJobInitSingleton() { + SocketStreamJob::RegisterProtocolFactory("ws", WebSocketJobFactory); + SocketStreamJob::RegisterProtocolFactory("wss", WebSocketJobFactory); + } +}; + +static void ParseHandshakeMessage( + const char* handshake_message, int len, + std::string* status_line, + std::string* header) { + size_t i = base::StringPiece(handshake_message, len).find_first_of("\r\n"); + if (i == base::StringPiece::npos) { + *status_line = std::string(handshake_message, len); + *header = ""; + return; + } + *status_line = std::string(handshake_message, i + 2); + *header = std::string(handshake_message + i + 2, len - i - 2); +} + +static void FetchResponseCookies( + const char* handshake_message, int len, + std::vector<std::string>* response_cookies) { + std::string handshake_response(handshake_message, len); + HttpUtil::HeadersIterator iter(handshake_response.begin(), + handshake_response.end(), "\r\n"); + while (iter.GetNext()) { + for (size_t i = 0; i < arraysize(kSetCookieHeaders); i++) { + if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), + kSetCookieHeaders[i])) { + response_cookies->push_back(iter.values()); + } + } + } +} + +// static +void WebSocketJob::EnsureInit() { + Singleton<WebSocketJobInitSingleton>::get(); +} + +WebSocketJob::WebSocketJob(SocketStream::Delegate* delegate) + : delegate_(delegate), + state_(INITIALIZED), + handshake_request_sent_(0), + handshake_response_header_length_(0), + response_cookies_save_index_(0), + ALLOW_THIS_IN_INITIALIZER_LIST(can_get_cookies_callback_( + this, &WebSocketJob::OnCanGetCookiesCompleted)), + ALLOW_THIS_IN_INITIALIZER_LIST(can_set_cookie_callback_( + this, &WebSocketJob::OnCanSetCookieCompleted)) { +} + +WebSocketJob::~WebSocketJob() { + DCHECK_EQ(CLOSED, state_); + DCHECK(!delegate_); + DCHECK(!socket_.get()); +} + +void WebSocketJob::Connect() { + DCHECK(socket_.get()); + DCHECK_EQ(state_, INITIALIZED); + state_ = CONNECTING; + socket_->Connect(); +} + +bool WebSocketJob::SendData(const char* data, int len) { + switch (state_) { + case INITIALIZED: + return false; + + case CONNECTING: + return SendHandshakeRequest(data, len); + + case OPEN: + return socket_->SendData(data, len); + + case CLOSED: + return false; + } + return false; +} + +void WebSocketJob::Close() { + state_ = CLOSED; + socket_->Close(); +} + +void WebSocketJob::RestartWithAuth( + const std::wstring& username, + const std::wstring& password) { + state_ = CONNECTING; + socket_->RestartWithAuth(username, password); +} + +void WebSocketJob::DetachDelegate() { + state_ = CLOSED; + delegate_ = NULL; + socket_->DetachDelegate(); + socket_ = NULL; +} + +void WebSocketJob::OnConnected( + SocketStream* socket, int max_pending_send_allowed) { + if (delegate_) + delegate_->OnConnected(socket, max_pending_send_allowed); +} + +void WebSocketJob::OnSentData(SocketStream* socket, int amount_sent) { + if (state_ == CONNECTING) { + OnSentHandshakeRequest(socket, amount_sent); + return; + } + if (delegate_) + delegate_->OnSentData(socket, amount_sent); +} + +void WebSocketJob::OnReceivedData( + SocketStream* socket, const char* data, int len) { + if (state_ == CONNECTING) { + OnReceivedHandshakeResponse(socket, data, len); + return; + } + if (delegate_) + delegate_->OnReceivedData(socket, data, len); +} + +void WebSocketJob::OnClose(SocketStream* socket) { + state_ = CLOSED; + SocketStream::Delegate* delegate = delegate_; + delegate_ = NULL; + socket_ = NULL; + if (delegate) + delegate->OnClose(socket); +} + +void WebSocketJob::OnAuthRequired( + SocketStream* socket, AuthChallengeInfo* auth_info) { + if (delegate_) + delegate_->OnAuthRequired(socket, auth_info); +} + +void WebSocketJob::OnError(const SocketStream* socket, int error) { + if (delegate_) + delegate_->OnError(socket, error); +} + +bool WebSocketJob::SendHandshakeRequest(const char* data, int len) { + DCHECK_EQ(state_, CONNECTING); + if (!handshake_request_.empty()) { + // if we're already sending handshake message, don't send any more data + // until handshake is completed. + return false; + } + original_handshake_request_.append(data, len); + original_handshake_request_header_length_ = + HttpUtil::LocateEndOfHeaders(original_handshake_request_.data(), + original_handshake_request_.size(), 0); + if (original_handshake_request_header_length_ > 0) { + // handshake message is completed. + AddCookieHeaderAndSend(); + } + // Just buffered in original_handshake_request_. + return true; +} + +void WebSocketJob::AddCookieHeaderAndSend() { + AddRef(); // Balanced in OnCanGetCookiesCompleted + + int policy = OK; + if (socket_->context()->cookie_policy()) { + GURL url_for_cookies = GetURLForCookies(); + policy = socket_->context()->cookie_policy()->CanGetCookies( + url_for_cookies, + url_for_cookies, + &can_get_cookies_callback_); + if (policy == ERR_IO_PENDING) + return; // Wait for completion callback + } + OnCanGetCookiesCompleted(policy); +} + +void WebSocketJob::OnCanGetCookiesCompleted(int policy) { + if (socket_ && delegate_ && state_ == CONNECTING) { + std::string handshake_request_status_line; + std::string handshake_request_header; + ParseHandshakeMessage(original_handshake_request_.data(), + original_handshake_request_header_length_, + &handshake_request_status_line, + &handshake_request_header); + + // Remove cookie headers. + handshake_request_header = HttpUtil::StripHeaders( + handshake_request_header, + kCookieHeaders, arraysize(kCookieHeaders)); + + if (policy == OK) { + // Add cookies, including HttpOnly cookies. + if (socket_->context()->cookie_store()) { + CookieOptions cookie_options; + cookie_options.set_include_httponly(); + std::string cookie = + socket_->context()->cookie_store()->GetCookiesWithOptions( + GetURLForCookies(), cookie_options); + if (!cookie.empty()) { + HttpUtil::AppendHeaderIfMissing("Cookie", cookie, + &handshake_request_header); + } + } + } + + // Simply ignore rest data in original request header after + // original_handshake_request_header_length_, because websocket protocol + // doesn't allow sending message before handshake is completed. + // TODO(ukai): report as error? + handshake_request_ = + handshake_request_status_line + handshake_request_header + "\r\n"; + + handshake_request_sent_ = 0; + socket_->SendData(handshake_request_.data(), + handshake_request_.size()); + } + Release(); // Balance AddRef taken in AddCookieHeaderAndSend +} + +void WebSocketJob::OnSentHandshakeRequest( + SocketStream* socket, int amount_sent) { + DCHECK_EQ(state_, CONNECTING); + handshake_request_sent_ += amount_sent; + if (handshake_request_sent_ >= handshake_request_.size()) { + // handshake request has been sent. + // notify original size of handshake request to delegate. + if (delegate_) + delegate_->OnSentData(socket, original_handshake_request_.size()); + } +} + +void WebSocketJob::OnReceivedHandshakeResponse( + SocketStream* socket, const char* data, int len) { + DCHECK_EQ(state_, CONNECTING); + handshake_response_.append(data, len); + handshake_response_header_length_ = HttpUtil::LocateEndOfHeaders( + handshake_response_.data(), + handshake_response_.size(), 0); + if (handshake_response_header_length_ > 0) { + // handshake message is completed. + SaveCookiesAndNotifyHeaderComplete(); + } +} + +void WebSocketJob::SaveCookiesAndNotifyHeaderComplete() { + // handshake message is completed. + DCHECK(handshake_response_.data()); + DCHECK_GT(handshake_response_header_length_, 0); + + response_cookies_.clear(); + response_cookies_save_index_ = 0; + + FetchResponseCookies(handshake_response_.data(), + handshake_response_header_length_, + &response_cookies_); + + // Now, loop over the response cookies, and attempt to persist each. + SaveNextCookie(); +} + +void WebSocketJob::SaveNextCookie() { + if (response_cookies_save_index_ == response_cookies_.size()) { + response_cookies_.clear(); + response_cookies_save_index_ = 0; + + std::string handshake_response_status_line; + std::string handshake_response_header; + ParseHandshakeMessage(handshake_response_.data(), + handshake_response_header_length_, + &handshake_response_status_line, + &handshake_response_header); + // Remove cookie headers. + std::string filtered_handshake_response_header = + HttpUtil::StripHeaders( + handshake_response_header, + kSetCookieHeaders, arraysize(kSetCookieHeaders)); + std::string remaining_data = + std::string(handshake_response_.data() + + handshake_response_header_length_, + handshake_response_.size() - + handshake_response_header_length_); + std::string received_data = + handshake_response_status_line + + filtered_handshake_response_header + + "\r\n" + + remaining_data; + state_ = OPEN; + if (delegate_) + delegate_->OnReceivedData(socket_, + received_data.data(), received_data.size()); + return; + } + + AddRef(); // Balanced in OnCanSetCookieCompleted + + int policy = OK; + if (socket_->context()->cookie_policy()) { + GURL url_for_cookies = GetURLForCookies(); + policy = socket_->context()->cookie_policy()->CanSetCookie( + url_for_cookies, + url_for_cookies, + response_cookies_[response_cookies_save_index_], + &can_set_cookie_callback_); + if (policy == ERR_IO_PENDING) + return; // Wait for completion callback + } + + OnCanSetCookieCompleted(policy); +} + +void WebSocketJob::OnCanSetCookieCompleted(int policy) { + if (socket_ && delegate_ && state_ == CONNECTING) { + if ((policy == OK || policy == OK_FOR_SESSION_ONLY) && + socket_->context()->cookie_store()) { + CookieOptions options; + options.set_include_httponly(); + if (policy == OK_FOR_SESSION_ONLY) + options.set_force_session(); + GURL url_for_cookies = GetURLForCookies(); + socket_->context()->cookie_store()->SetCookieWithOptions( + url_for_cookies, response_cookies_[response_cookies_save_index_], + options); + } + response_cookies_save_index_++; + SaveNextCookie(); + } + Release(); // Balance AddRef taken in SaveNextCookie +} + +GURL WebSocketJob::GetURLForCookies() const { + GURL url = socket_->url(); + std::string scheme = socket_->is_secure() ? "https" : "http"; + url_canon::Replacements<char> replacements; + replacements.SetScheme(scheme.c_str(), + url_parse::Component(0, scheme.length())); + return url.ReplaceComponents(replacements); +} + +} // namespace net diff --git a/net/websockets/websocket_job.h b/net/websockets/websocket_job.h new file mode 100644 index 0000000..31fa503 --- /dev/null +++ b/net/websockets/websocket_job.h @@ -0,0 +1,98 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_WEBSOCKETS_WEBSOCKET_JOB_H_ +#define NET_WEBSOCKETS_WEBSOCKET_JOB_H_ + +#include <string> +#include <vector> + +#include "base/ref_counted.h" +#include "net/base/completion_callback.h" +#include "net/socket_stream/socket_stream_job.h" + +class GURL; + +namespace net { + +// WebSocket protocol specific job on SocketStream. +// It captures WebSocket handshake message and handles cookie operations. +// Chome security policy doesn't allow renderer process (except dev tools) +// see HttpOnly cookies, so it injects cookie header in handshake request and +// strips set-cookie headers in handshake response. +// TODO(ukai): refactor to merge WebSocketThrottle functionality. +// TODO(ukai): refactor websocket.cc to use this. +class WebSocketJob : public SocketStreamJob, public SocketStream::Delegate { + public: + // This is state of WebSocket, not SocketStream. + enum State { + INITIALIZED = -1, + CONNECTING = 0, + OPEN = 1, + CLOSED = 2, + }; + static void EnsureInit(); + + explicit WebSocketJob(SocketStream::Delegate* delegate); + + virtual void Connect(); + virtual bool SendData(const char* data, int len); + virtual void Close(); + virtual void RestartWithAuth( + const std::wstring& username, + const std::wstring& password); + virtual void DetachDelegate(); + + // SocketStream::Delegate methods. + virtual void OnConnected( + SocketStream* socket, int max_pending_send_allowed); + virtual void OnSentData( + SocketStream* socket, int amount_sent); + virtual void OnReceivedData( + SocketStream* socket, const char* data, int len); + virtual void OnClose(SocketStream* socket); + virtual void OnAuthRequired( + SocketStream* socket, AuthChallengeInfo* auth_info); + virtual void OnError( + const SocketStream* socket, int error); + + private: + friend class WebSocketJobTest; + virtual ~WebSocketJob(); + + bool SendHandshakeRequest(const char* data, int len); + void AddCookieHeaderAndSend(); + void OnCanGetCookiesCompleted(int policy); + + void OnSentHandshakeRequest(SocketStream* socket, int amount_sent); + void OnReceivedHandshakeResponse( + SocketStream* socket, const char* data, int len); + void SaveCookiesAndNotifyHeaderComplete(); + void SaveNextCookie(); + void OnCanSetCookieCompleted(int policy); + + GURL GetURLForCookies() const; + + SocketStream::Delegate* delegate_; + State state_; + + std::string original_handshake_request_; + int original_handshake_request_header_length_; + std::string handshake_request_; + size_t handshake_request_sent_; + + std::string handshake_response_; + int handshake_response_header_length_; + std::vector<std::string> response_cookies_; + size_t response_cookies_save_index_; + + CompletionCallbackImpl<WebSocketJob> can_get_cookies_callback_; + CompletionCallbackImpl<WebSocketJob> can_set_cookie_callback_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketJob); +}; + +} // namespace + +#endif // NET_WEBSOCKETS_WEBSOCKET_JOB_H_ diff --git a/net/websockets/websocket_job_unittest.cc b/net/websockets/websocket_job_unittest.cc new file mode 100644 index 0000000..de96a32 --- /dev/null +++ b/net/websockets/websocket_job_unittest.cc @@ -0,0 +1,495 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <string> +#include <vector> + +#include "base/ref_counted.h" +#include "googleurl/src/gurl.h" +#include "net/base/cookie_policy.h" +#include "net/base/cookie_store.h" +#include "net/base/net_errors.h" +#include "net/socket_stream/socket_stream.h" +#include "net/url_request/url_request_context.h" +#include "net/websockets/websocket_job.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/platform_test.h" + +namespace net { + +class MockSocketStream : public SocketStream { + public: + MockSocketStream(const GURL& url, SocketStream::Delegate* delegate) + : SocketStream(url, delegate) {} + virtual ~MockSocketStream() {} + + virtual void Connect() {} + virtual bool SendData(const char* data, int len) { + sent_data_ += std::string(data, len); + return true; + } + + virtual void Close() {} + virtual void RestartWithAuth( + const std::wstring& username, std::wstring& password) {} + virtual void DetachDelegate() { + delegate_ = NULL; + } + + const std::string& sent_data() const { + return sent_data_; + } + + private: + std::string sent_data_; +}; + +class MockSocketStreamDelegate : public SocketStream::Delegate { + public: + MockSocketStreamDelegate() + : amount_sent_(0) {} + virtual ~MockSocketStreamDelegate() {} + + virtual void OnConnected(SocketStream* socket, int max_pending_send_allowed) { + } + virtual void OnSentData(SocketStream* socket, int amount_sent) { + amount_sent_ += amount_sent; + } + virtual void OnReceivedData(SocketStream* socket, + const char* data, int len) { + received_data_ += std::string(data, len); + } + virtual void OnClose(SocketStream* socket) { + } + + size_t amount_sent() const { return amount_sent_; } + const std::string& received_data() const { return received_data_; } + + private: + int amount_sent_; + std::string received_data_; +}; + +class MockCookieStore : public CookieStore { + public: + struct Entry { + GURL url; + std::string cookie_line; + CookieOptions options; + }; + MockCookieStore() {} + + virtual bool SetCookieWithOptions(const GURL& url, + const std::string& cookie_line, + const CookieOptions& options) { + Entry entry; + entry.url = url; + entry.cookie_line = cookie_line; + entry.options = options; + entries_.push_back(entry); + return true; + } + virtual std::string GetCookiesWithOptions(const GURL& url, + const CookieOptions& options) { + std::string result; + for (size_t i = 0; i < entries_.size(); i++) { + Entry &entry = entries_[i]; + if (url == entry.url) { + if (!result.empty()) { + result += "; "; + } + result += entry.cookie_line; + } + } + return result; + } + virtual void DeleteCookie(const GURL& url, + const std::string& cookie_name) {} + virtual CookieMonster* GetCookieMonster() { return NULL; } + + const std::vector<Entry>& entries() const { return entries_; } + + private: + friend class base::RefCountedThreadSafe<MockCookieStore>; + virtual ~MockCookieStore() {} + + std::vector<Entry> entries_; +}; + +class MockCookiePolicy : public CookiePolicy, + public base::RefCountedThreadSafe<MockCookiePolicy> { + public: + MockCookiePolicy() : allow_all_cookies_(true), callback_(NULL) {} + + void set_allow_all_cookies(bool allow_all_cookies) { + allow_all_cookies_ = allow_all_cookies; + } + + virtual int CanGetCookies(const GURL& url, + const GURL& first_party_for_cookies, + CompletionCallback* callback) { + DCHECK(!callback_); + callback_ = callback; + MessageLoop::current()->PostTask( + FROM_HERE, NewRunnableMethod(this, &MockCookiePolicy::OnCanGetCookies)); + return ERR_IO_PENDING; + } + + virtual int CanSetCookie(const GURL& url, + const GURL& first_party_for_cookies, + const std::string& cookie_line, + CompletionCallback* callback) { + DCHECK(!callback_); + callback_ = callback; + MessageLoop::current()->PostTask( + FROM_HERE, NewRunnableMethod(this, &MockCookiePolicy::OnCanSetCookie)); + return ERR_IO_PENDING; + } + + private: + friend class base::RefCountedThreadSafe<MockCookiePolicy>; + virtual ~MockCookiePolicy() {} + + void OnCanGetCookies() { + CompletionCallback* callback = callback_; + callback_ = NULL; + if (allow_all_cookies_) + callback->Run(OK); + else + callback->Run(ERR_ACCESS_DENIED); + } + void OnCanSetCookie() { + CompletionCallback* callback = callback_; + callback_ = NULL; + if (allow_all_cookies_) + callback->Run(OK); + else + callback->Run(ERR_ACCESS_DENIED); + } + + bool allow_all_cookies_; + CompletionCallback* callback_; +}; + +class MockURLRequestContext : public URLRequestContext { + public: + MockURLRequestContext(CookieStore* cookie_store, + CookiePolicy* cookie_policy) { + cookie_store_ = cookie_store; + cookie_policy_ = cookie_policy; + } + + private: + friend class base::RefCountedThreadSafe<MockURLRequestContext>; + virtual ~MockURLRequestContext() {} +}; + +class WebSocketJobTest : public PlatformTest { + public: + virtual void SetUp() { + cookie_store_ = new MockCookieStore; + cookie_policy_ = new MockCookiePolicy; + context_ = new MockURLRequestContext( + cookie_store_.get(), cookie_policy_.get()); + } + virtual void TearDown() { + cookie_store_ = NULL; + cookie_policy_ = NULL; + context_ = NULL; + websocket_ = NULL; + socket_ = NULL; + } + protected: + void InitWebSocketJob(const GURL& url, MockSocketStreamDelegate* delegate) { + websocket_ = new WebSocketJob(delegate); + socket_ = new MockSocketStream(url, websocket_.get()); + websocket_->InitSocketStream(socket_.get()); + websocket_->state_ = WebSocketJob::CONNECTING; + websocket_->set_context(context_.get()); + } + WebSocketJob::State GetWebSocketJobState() { + return websocket_->state_; + } + void CloseWebSocketJob() { + if (websocket_->socket_) + websocket_->socket_->DetachDelegate(); + websocket_->state_ = WebSocketJob::CLOSED; + websocket_->delegate_ = NULL; + websocket_->socket_ = NULL; + } + + scoped_refptr<MockCookieStore> cookie_store_; + scoped_refptr<MockCookiePolicy> cookie_policy_; + scoped_refptr<MockURLRequestContext> context_; + scoped_refptr<WebSocketJob> websocket_; + scoped_refptr<MockSocketStream> socket_; +}; + +TEST_F(WebSocketJobTest, SimpleHandshake) { + GURL url("ws://example.com/demo"); + MockSocketStreamDelegate delegate; + InitWebSocketJob(url, &delegate); + + static const char* kHandshakeRequestMessage = + "GET /demo HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Host: example.com\r\n" + "Origin: http://example.com\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"; + + bool sent = websocket_->SendData(kHandshakeRequestMessage, + strlen(kHandshakeRequestMessage)); + EXPECT_EQ(true, sent); + MessageLoop::current()->RunAllPending(); + EXPECT_EQ(kHandshakeRequestMessage, socket_->sent_data()); + EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState()); + websocket_->OnSentData(socket_.get(), strlen(kHandshakeRequestMessage)); + EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent()); + + static const char* kHandshakeResponseMessage = + "HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: http://example.com\r\n" + "WebSocket-Location: ws://example.com/demo\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"; + + websocket_->OnReceivedData(socket_.get(), + kHandshakeResponseMessage, + strlen(kHandshakeResponseMessage)); + MessageLoop::current()->RunAllPending(); + EXPECT_EQ(kHandshakeResponseMessage, delegate.received_data()); + EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState()); + CloseWebSocketJob(); +} + +TEST_F(WebSocketJobTest, SlowHandshake) { + GURL url("ws://example.com/demo"); + MockSocketStreamDelegate delegate; + InitWebSocketJob(url, &delegate); + + static const char* kHandshakeRequestMessage = + "GET /demo HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Host: example.com\r\n" + "Origin: http://example.com\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"; + std::vector<std::string> lines; + SplitString(kHandshakeRequestMessage, '\n', &lines); + for (size_t i = 0; i < lines.size() - 2; i++) { + std::string line = lines[i] + "\r\n"; + SCOPED_TRACE("Line: " + line); + bool sent = websocket_->SendData(line.c_str(), line.size()); + EXPECT_EQ(true, sent); + MessageLoop::current()->RunAllPending(); + EXPECT_TRUE(socket_->sent_data().empty()); + EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState()); + } + bool sent = websocket_->SendData("\r\n", 2); + EXPECT_EQ(true, sent); + MessageLoop::current()->RunAllPending(); + EXPECT_EQ(kHandshakeRequestMessage, socket_->sent_data()); + EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState()); + + for (size_t i = 0; i < lines.size() - 2; i++) { + std::string line = lines[i] + "\r\n"; + SCOPED_TRACE("Line: " + line); + websocket_->OnSentData(socket_.get(), line.size()); + EXPECT_EQ(0U, delegate.amount_sent()); + } + websocket_->OnSentData(socket_.get(), 2); // \r\n + EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent()); + EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState()); + + static const char* kHandshakeResponseMessage = + "HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: http://example.com\r\n" + "WebSocket-Location: ws://example.com/demo\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"; + + lines.clear(); + SplitString(kHandshakeResponseMessage, '\n', &lines); + for (size_t i = 0; i < lines.size() - 2; i++) { + std::string line = lines[i] + "\r\n"; + SCOPED_TRACE("Line: " + line); + websocket_->OnReceivedData(socket_, + line.c_str(), + line.size()); + MessageLoop::current()->RunAllPending(); + EXPECT_TRUE(delegate.received_data().empty()); + EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState()); + } + websocket_->OnReceivedData(socket_.get(), "\r\n", 2); + MessageLoop::current()->RunAllPending(); + EXPECT_EQ(kHandshakeResponseMessage, delegate.received_data()); + EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState()); + CloseWebSocketJob(); +} + +TEST_F(WebSocketJobTest, HandshakeWithCookie) { + GURL url("ws://example.com/demo"); + GURL cookieUrl("http://example.com/demo"); + CookieOptions cookie_options; + cookie_store_->SetCookieWithOptions( + cookieUrl, "CR-test=1", cookie_options); + cookie_options.set_include_httponly(); + cookie_store_->SetCookieWithOptions( + cookieUrl, "CR-test-httponly=1", cookie_options); + + MockSocketStreamDelegate delegate; + InitWebSocketJob(url, &delegate); + + static const char* kHandshakeRequestMessage = + "GET /demo HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Host: example.com\r\n" + "Origin: http://example.com\r\n" + "WebSocket-Protocol: sample\r\n" + "Cookie: WK-test=1\r\n" + "\r\n"; + + static const char* kHandshakeRequestExpected = + "GET /demo HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Host: example.com\r\n" + "Origin: http://example.com\r\n" + "WebSocket-Protocol: sample\r\n" + "Cookie: CR-test=1; CR-test-httponly=1\r\n" + "\r\n"; + + bool sent = websocket_->SendData(kHandshakeRequestMessage, + strlen(kHandshakeRequestMessage)); + EXPECT_EQ(true, sent); + MessageLoop::current()->RunAllPending(); + EXPECT_EQ(kHandshakeRequestExpected, socket_->sent_data()); + EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState()); + websocket_->OnSentData(socket_, strlen(kHandshakeRequestExpected)); + EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent()); + + static const char* kHandshakeResponseMessage = + "HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: http://example.com\r\n" + "WebSocket-Location: ws://example.com/demo\r\n" + "WebSocket-Protocol: sample\r\n" + "Set-Cookie: CR-set-test=1\r\n" + "\r\n"; + + static const char* kHandshakeResponseExpected = + "HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: http://example.com\r\n" + "WebSocket-Location: ws://example.com/demo\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"; + + websocket_->OnReceivedData(socket_.get(), + kHandshakeResponseMessage, + strlen(kHandshakeResponseMessage)); + MessageLoop::current()->RunAllPending(); + EXPECT_EQ(kHandshakeResponseExpected, delegate.received_data()); + EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState()); + + EXPECT_EQ(3U, cookie_store_->entries().size()); + EXPECT_EQ(cookieUrl, cookie_store_->entries()[0].url); + EXPECT_EQ("CR-test=1", cookie_store_->entries()[0].cookie_line); + EXPECT_EQ(cookieUrl, cookie_store_->entries()[1].url); + EXPECT_EQ("CR-test-httponly=1", cookie_store_->entries()[1].cookie_line); + EXPECT_EQ(cookieUrl, cookie_store_->entries()[2].url); + EXPECT_EQ("CR-set-test=1", cookie_store_->entries()[2].cookie_line); + + CloseWebSocketJob(); +} + +TEST_F(WebSocketJobTest, HandshakeWithCookieButNotAllowed) { + GURL url("ws://example.com/demo"); + GURL cookieUrl("http://example.com/demo"); + CookieOptions cookie_options; + cookie_store_->SetCookieWithOptions( + cookieUrl, "CR-test=1", cookie_options); + cookie_options.set_include_httponly(); + cookie_store_->SetCookieWithOptions( + cookieUrl, "CR-test-httponly=1", cookie_options); + cookie_policy_->set_allow_all_cookies(false); + + MockSocketStreamDelegate delegate; + InitWebSocketJob(url, &delegate); + + static const char* kHandshakeRequestMessage = + "GET /demo HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Host: example.com\r\n" + "Origin: http://example.com\r\n" + "WebSocket-Protocol: sample\r\n" + "Cookie: WK-test=1\r\n" + "\r\n"; + + static const char* kHandshakeRequestExpected = + "GET /demo HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Host: example.com\r\n" + "Origin: http://example.com\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"; + + bool sent = websocket_->SendData(kHandshakeRequestMessage, + strlen(kHandshakeRequestMessage)); + EXPECT_EQ(true, sent); + MessageLoop::current()->RunAllPending(); + EXPECT_EQ(kHandshakeRequestExpected, socket_->sent_data()); + EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState()); + websocket_->OnSentData(socket_, strlen(kHandshakeRequestExpected)); + EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent()); + + static const char* kHandshakeResponseMessage = + "HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: http://example.com\r\n" + "WebSocket-Location: ws://example.com/demo\r\n" + "WebSocket-Protocol: sample\r\n" + "Set-Cookie: CR-set-test=1\r\n" + "\r\n"; + + static const char* kHandshakeResponseExpected = + "HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "WebSocket-Origin: http://example.com\r\n" + "WebSocket-Location: ws://example.com/demo\r\n" + "WebSocket-Protocol: sample\r\n" + "\r\n"; + + websocket_->OnReceivedData(socket_.get(), + kHandshakeResponseMessage, + strlen(kHandshakeResponseMessage)); + MessageLoop::current()->RunAllPending(); + EXPECT_EQ(kHandshakeResponseExpected, delegate.received_data()); + EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState()); + + EXPECT_EQ(2U, cookie_store_->entries().size()); + EXPECT_EQ(cookieUrl, cookie_store_->entries()[0].url); + EXPECT_EQ("CR-test=1", cookie_store_->entries()[0].cookie_line); + EXPECT_EQ(cookieUrl, cookie_store_->entries()[1].url); + EXPECT_EQ("CR-test-httponly=1", cookie_store_->entries()[1].cookie_line); + + CloseWebSocketJob(); +} + +} // namespace net |