diff options
author | byungchul@chromium.org <byungchul@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2014-08-22 18:10:13 +0000 |
---|---|---|
committer | byungchul@chromium.org <byungchul@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2014-08-22 18:12:05 +0000 |
commit | 3626bc39d629245a782c61ba0bc668583778e392 (patch) | |
tree | 7156999fdfb9044dd14d5463568b8f6ce01543fe | |
parent | 1c2f071e2a14a39b215a897434d30834ae959560 (diff) | |
download | chromium_src-3626bc39d629245a782c61ba0bc668583778e392.zip chromium_src-3626bc39d629245a782c61ba0bc668583778e392.tar.gz chromium_src-3626bc39d629245a782c61ba0bc668583778e392.tar.bz2 |
Replace StreamListenSocket with StreamSocket in HttpServer.
1) HttpServer gets ServerSocket instead of StreamListenSocket.
2) HttpConnection is just a container for socket, websocket, and pending read/write buffers.
3) HttpServer handles data buffering and asynchronous read/write.
4) HttpConnection has limit in data buffering, up to 1Mbytes by default.
5) For devtools, send buffer limit is 100Mbytes.
6) Unittests for buffer handling in HttpConnection.
BUG=371906
Review URL: https://codereview.chromium.org/296053012
Cr-Commit-Position: refs/heads/master@{#291447}
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@291447 0039d316-1c4b-4281-b951-d872f2087c98
32 files changed, 1407 insertions, 377 deletions
diff --git a/android_webview/native/aw_dev_tools_server.cc b/android_webview/native/aw_dev_tools_server.cc index 90b72d1..10f3243c 100644 --- a/android_webview/native/aw_dev_tools_server.cc +++ b/android_webview/native/aw_dev_tools_server.cc @@ -19,7 +19,7 @@ #include "content/public/browser/web_contents.h" #include "content/public/common/user_agent.h" #include "jni/AwDevToolsServer_jni.h" -#include "net/socket/unix_domain_listen_socket_posix.h" +#include "net/socket/unix_domain_server_socket_posix.h" using content::DevToolsAgentHost; using content::RenderViewHost; @@ -157,6 +157,25 @@ std::string GetViewDescription(WebContents* web_contents) { return json; } +// Factory for UnixDomainServerSocket. +class UnixDomainServerSocketFactory + : public content::DevToolsHttpHandler::ServerSocketFactory { + public: + explicit UnixDomainServerSocketFactory(const std::string& socket_name) + : content::DevToolsHttpHandler::ServerSocketFactory(socket_name, 0, 1) {} + + private: + // content::DevToolsHttpHandler::ServerSocketFactory. + virtual scoped_ptr<net::ServerSocket> Create() const OVERRIDE { + return scoped_ptr<net::ServerSocket>( + new net::UnixDomainServerSocket( + base::Bind(&content::CanUserConnectToDevTools), + true /* use_abstract_namespace */)); + } + + DISALLOW_COPY_AND_ASSIGN(UnixDomainServerSocketFactory); +}; + } // namespace namespace android_webview { @@ -173,11 +192,11 @@ void AwDevToolsServer::Start() { if (protocol_handler_) return; + scoped_ptr<content::DevToolsHttpHandler::ServerSocketFactory> factory( + new UnixDomainServerSocketFactory( + base::StringPrintf(kSocketNameFormat, getpid()))); protocol_handler_ = content::DevToolsHttpHandler::Start( - new net::deprecated::UnixDomainListenSocketWithAbstractNamespaceFactory( - base::StringPrintf(kSocketNameFormat, getpid()), - "", - base::Bind(&content::CanUserConnectToDevTools)), + factory.Pass(), base::StringPrintf(kFrontEndURL, content::GetWebKitRevision().c_str()), new AwDevToolsServerDelegate(), base::FilePath()); diff --git a/chrome/browser/android/dev_tools_server.cc b/chrome/browser/android/dev_tools_server.cc index 158ef74..83be244 100644 --- a/chrome/browser/android/dev_tools_server.cc +++ b/chrome/browser/android/dev_tools_server.cc @@ -40,7 +40,9 @@ #include "content/public/common/user_agent.h" #include "grit/browser_resources.h" #include "jni/DevToolsServer_jni.h" +#include "net/base/net_errors.h" #include "net/socket/unix_domain_listen_socket_posix.h" +#include "net/socket/unix_domain_server_socket_posix.h" #include "net/url_request/url_request_context_getter.h" #include "ui/base/resource/resource_bundle.h" @@ -393,6 +395,49 @@ class DevToolsServerDelegate : public content::DevToolsHttpHandlerDelegate { DISALLOW_COPY_AND_ASSIGN(DevToolsServerDelegate); }; +// Factory for UnixDomainServerSocket. It tries a fallback socket when +// original socket doesn't work. +class UnixDomainServerSocketFactory + : public content::DevToolsHttpHandler::ServerSocketFactory { + public: + UnixDomainServerSocketFactory( + const std::string& socket_name, + const net::UnixDomainServerSocket::AuthCallback& auth_callback) + : content::DevToolsHttpHandler::ServerSocketFactory(socket_name, 0, 1), + auth_callback_(auth_callback) { + } + + private: + // content::DevToolsHttpHandler::ServerSocketFactory. + virtual scoped_ptr<net::ServerSocket> Create() const OVERRIDE { + return scoped_ptr<net::ServerSocket>( + new net::UnixDomainServerSocket(auth_callback_, + true /* use_abstract_namespace */)); + } + + virtual scoped_ptr<net::ServerSocket> CreateAndListen() const OVERRIDE { + scoped_ptr<net::ServerSocket> socket = Create(); + if (!socket) + return scoped_ptr<net::ServerSocket>(); + + if (socket->ListenWithAddressAndPort(address_, port_, backlog_) == net::OK) + return socket.Pass(); + + // Try a fallback socket name. + const std::string fallback_address( + base::StringPrintf("%s_%d", address_.c_str(), getpid())); + if (socket->ListenWithAddressAndPort(fallback_address, port_, backlog_) + == net::OK) + return socket.Pass(); + + return scoped_ptr<net::ServerSocket>(); + } + + const net::UnixDomainServerSocket::AuthCallback auth_callback_; + + DISALLOW_COPY_AND_ASSIGN(UnixDomainServerSocketFactory); +}; + } // namespace DevToolsServer::DevToolsServer(const std::string& socket_name_prefix) @@ -419,12 +464,10 @@ void DevToolsServer::Start(bool allow_debug_permission) { allow_debug_permission ? base::Bind(&AuthorizeSocketAccessWithDebugPermission) : base::Bind(&content::CanUserConnectToDevTools); - + scoped_ptr<content::DevToolsHttpHandler::ServerSocketFactory> factory( + new UnixDomainServerSocketFactory(socket_name_, auth_callback)); protocol_handler_ = content::DevToolsHttpHandler::Start( - new net::deprecated::UnixDomainListenSocketWithAbstractNamespaceFactory( - socket_name_, - base::StringPrintf("%s_%d", socket_name_.c_str(), getpid()), - auth_callback), + factory.Pass(), base::StringPrintf(kFrontEndURL, content::GetWebKitRevision().c_str()), new DevToolsServerDelegate(auth_callback), base::FilePath()); diff --git a/chrome/browser/devtools/device/android_web_socket.cc b/chrome/browser/devtools/device/android_web_socket.cc index a0054da..2a3886b 100644 --- a/chrome/browser/devtools/device/android_web_socket.cc +++ b/chrome/browser/devtools/device/android_web_socket.cc @@ -199,8 +199,7 @@ void WebSocketImpl::OnBytesRead(scoped_refptr<net::IOBuffer> response_buffer, return; } - std::string data = std::string(response_buffer->data(), result); - response_buffer_ += data; + response_buffer_.append(response_buffer->data(), result); int bytes_consumed; std::string output; diff --git a/chrome/browser/devtools/remote_debugging_server.cc b/chrome/browser/devtools/remote_debugging_server.cc index 211ff1a..c971a4a 100644 --- a/chrome/browser/devtools/remote_debugging_server.cc +++ b/chrome/browser/devtools/remote_debugging_server.cc @@ -9,7 +9,28 @@ #include "chrome/browser/ui/webui/devtools_ui.h" #include "chrome/common/chrome_paths.h" #include "content/public/browser/devtools_http_handler.h" -#include "net/socket/tcp_listen_socket.h" +#include "net/socket/tcp_server_socket.h" + +namespace { + +class TCPServerSocketFactory + : public content::DevToolsHttpHandler::ServerSocketFactory { + public: + TCPServerSocketFactory(const std::string& address, int port, int backlog) + : content::DevToolsHttpHandler::ServerSocketFactory( + address, port, backlog) {} + + private: + // content::DevToolsHttpHandler::ServerSocketFactory. + virtual scoped_ptr<net::ServerSocket> Create() const OVERRIDE { + return scoped_ptr<net::ServerSocket>( + new net::TCPServerSocket(NULL, net::NetLog::Source())); + } + + DISALLOW_COPY_AND_ASSIGN(TCPServerSocketFactory); +}; + +} // namespace RemoteDebuggingServer::RemoteDebuggingServer( chrome::HostDesktopType host_desktop_type, @@ -24,8 +45,10 @@ RemoteDebuggingServer::RemoteDebuggingServer( DCHECK(result); } + scoped_ptr<content::DevToolsHttpHandler::ServerSocketFactory> factory( + new TCPServerSocketFactory(ip, port, 1)); devtools_http_handler_ = content::DevToolsHttpHandler::Start( - new net::TCPListenSocketFactory(ip, port), + factory.Pass(), "", new BrowserListTabContentsProvider(host_desktop_type), output_dir); diff --git a/chrome/test/chromedriver/net/net_util_unittest.cc b/chrome/test/chromedriver/net/net_util_unittest.cc index e4d8b14..dbe41d0 100644 --- a/chrome/test/chromedriver/net/net_util_unittest.cc +++ b/chrome/test/chromedriver/net/net_util_unittest.cc @@ -20,7 +20,7 @@ #include "net/base/net_errors.h" #include "net/server/http_server.h" #include "net/server/http_server_request_info.h" -#include "net/socket/tcp_listen_socket.h" +#include "net/socket/tcp_server_socket.h" #include "net/url_request/url_request_context_getter.h" #include "testing/gtest/include/gtest/gtest.h" @@ -54,8 +54,10 @@ class FetchUrlTest : public testing::Test, } void InitOnIO(base::WaitableEvent* event) { - net::TCPListenSocketFactory factory("127.0.0.1", 0); - server_ = new net::HttpServer(factory, this); + scoped_ptr<net::ServerSocket> server_socket( + new net::TCPServerSocket(NULL, net::NetLog::Source())); + server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1); + server_.reset(new net::HttpServer(server_socket.Pass(), this)); net::IPEndPoint address; CHECK_EQ(net::OK, server_->GetLocalAddress(&address)); server_url_ = base::StringPrintf("http://127.0.0.1:%d", address.port()); @@ -63,7 +65,7 @@ class FetchUrlTest : public testing::Test, } void DestroyServerOnIO(base::WaitableEvent* event) { - server_ = NULL; + server_.reset(NULL); event->Signal(); } @@ -78,10 +80,7 @@ class FetchUrlTest : public testing::Test, server_->Send404(connection_id); break; case kClose: - // net::HttpServer doesn't allow us to close connection during callback. - base::MessageLoop::current()->PostTask( - FROM_HERE, - base::Bind(&net::HttpServer::Close, server_, connection_id)); + server_->Close(connection_id); break; default: break; @@ -104,7 +103,7 @@ class FetchUrlTest : public testing::Test, base::Thread io_thread_; ServerResponse response_; - scoped_refptr<net::HttpServer> server_; + scoped_ptr<net::HttpServer> server_; scoped_refptr<URLRequestContextGetter> context_getter_; std::string server_url_; }; diff --git a/chrome/test/chromedriver/net/test_http_server.cc b/chrome/test/chromedriver/net/test_http_server.cc index 740a33d..1d19524 100644 --- a/chrome/test/chromedriver/net/test_http_server.cc +++ b/chrome/test/chromedriver/net/test_http_server.cc @@ -13,7 +13,7 @@ #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/server/http_server_request_info.h" -#include "net/socket/tcp_listen_socket.h" +#include "net/socket/tcp_server_socket.h" #include "testing/gtest/include/gtest/gtest.h" TestHttpServer::TestHttpServer() @@ -92,10 +92,7 @@ void TestHttpServer::OnWebSocketRequest( server_->Send404(connection_id); break; case kClose: - // net::HttpServer doesn't allow us to close connection during callback. - base::MessageLoop::current()->PostTask( - FROM_HERE, - base::Bind(&net::HttpServer::Close, server_, connection_id)); + server_->Close(connection_id); break; } } @@ -112,10 +109,7 @@ void TestHttpServer::OnWebSocketMessage(int connection_id, server_->SendOverWebSocket(connection_id, data); break; case kCloseOnMessage: - // net::HttpServer doesn't allow us to close connection during callback. - base::MessageLoop::current()->PostTask( - FROM_HERE, - base::Bind(&net::HttpServer::Close, server_, connection_id)); + server_->Close(connection_id); break; } } @@ -128,8 +122,10 @@ void TestHttpServer::OnClose(int connection_id) { void TestHttpServer::StartOnServerThread(bool* success, base::WaitableEvent* event) { - net::TCPListenSocketFactory factory("127.0.0.1", 0); - server_ = new net::HttpServer(factory, this); + scoped_ptr<net::ServerSocket> server_socket( + new net::TCPServerSocket(NULL, net::NetLog::Source())); + server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1); + server_.reset(new net::HttpServer(server_socket.Pass(), this)); net::IPEndPoint address; int error = server_->GetLocalAddress(&address); @@ -139,14 +135,13 @@ void TestHttpServer::StartOnServerThread(bool* success, web_socket_url_ = GURL(base::StringPrintf("ws://127.0.0.1:%d", address.port())); } else { - server_ = NULL; + server_.reset(NULL); } *success = server_.get(); event->Signal(); } void TestHttpServer::StopOnServerThread(base::WaitableEvent* event) { - if (server_.get()) - server_ = NULL; + server_.reset(NULL); event->Signal(); } diff --git a/chrome/test/chromedriver/net/test_http_server.h b/chrome/test/chromedriver/net/test_http_server.h index 697434d..40e6ac4 100644 --- a/chrome/test/chromedriver/net/test_http_server.h +++ b/chrome/test/chromedriver/net/test_http_server.h @@ -77,7 +77,7 @@ class TestHttpServer : public net::HttpServer::Delegate { base::Thread thread_; // Access only on the server thread. - scoped_refptr<net::HttpServer> server_; + scoped_ptr<net::HttpServer> server_; // Access only on the server thread. std::set<int> connections_; diff --git a/chrome/test/chromedriver/server/chromedriver_server.cc b/chrome/test/chromedriver/server/chromedriver_server.cc index e93091e..2508c59 100644 --- a/chrome/test/chromedriver/server/chromedriver_server.cc +++ b/chrome/test/chromedriver/server/chromedriver_server.cc @@ -33,7 +33,7 @@ #include "net/server/http_server.h" #include "net/server/http_server_request_info.h" #include "net/server/http_server_response_info.h" -#include "net/socket/tcp_listen_socket.h" +#include "net/socket/tcp_server_socket.h" namespace { @@ -55,8 +55,10 @@ class HttpServer : public net::HttpServer::Delegate { std::string binding_ip = kLocalHostAddress; if (allow_remote) binding_ip = "0.0.0.0"; - server_ = new net::HttpServer( - net::TCPListenSocketFactory(binding_ip, port), this); + scoped_ptr<net::ServerSocket> server_socket( + new net::TCPServerSocket(NULL, net::NetLog::Source())); + server_socket->ListenWithAddressAndPort(binding_ip, port, 1); + server_.reset(new net::HttpServer(server_socket.Pass(), this)); net::IPEndPoint address; return server_->GetLocalAddress(&address) == net::OK; } @@ -89,7 +91,7 @@ class HttpServer : public net::HttpServer::Delegate { } HttpRequestHandlerFunc handle_request_func_; - scoped_refptr<net::HttpServer> server_; + scoped_ptr<net::HttpServer> server_; base::WeakPtrFactory<HttpServer> weak_factory_; // Should be last. }; diff --git a/chromecast/shell/browser/devtools/remote_debugging_server.cc b/chromecast/shell/browser/devtools/remote_debugging_server.cc index 076b066..57214f1 100644 --- a/chromecast/shell/browser/devtools/remote_debugging_server.cc +++ b/chromecast/shell/browser/devtools/remote_debugging_server.cc @@ -17,11 +17,11 @@ #include "content/public/browser/devtools_http_handler.h" #include "content/public/common/content_switches.h" #include "content/public/common/user_agent.h" -#include "net/socket/tcp_listen_socket.h" +#include "net/socket/tcp_server_socket.h" #if defined(OS_ANDROID) #include "content/public/browser/android/devtools_auth.h" -#include "net/socket/unix_domain_socket_posix.h" +#include "net/socket/unix_domain_server_socket_posix.h" #endif // defined(OS_ANDROID) namespace chromecast { @@ -35,7 +35,45 @@ const char kFrontEndURL[] = #endif // defined(OS_ANDROID) const int kDefaultRemoteDebuggingPort = 9222; -net::StreamListenSocketFactory* CreateSocketFactory(int port) { +#if defined(OS_ANDROID) +class UnixDomainServerSocketFactory + : public content::DevToolsHttpHandler::ServerSocketFactory { + public: + explicit UnixDomainServerSocketFactory(const std::string& socket_name) + : content::DevToolsHttpHandler::ServerSocketFactory(socket_name, 0, 1) {} + + private: + // content::DevToolsHttpHandler::ServerSocketFactory. + virtual scoped_ptr<net::ServerSocket> Create() const OVERRIDE { + return scoped_ptr<net::ServerSocket>( + new net::UnixDomainServerSocket( + base::Bind(&content::CanUserConnectToDevTools), + true /* use_abstract_namespace */)); + } + + DISALLOW_COPY_AND_ASSIGN(UnixDomainServerSocketFactory); +}; +#else +class TCPServerSocketFactory + : public content::DevToolsHttpHandler::ServerSocketFactory { + public: + TCPServerSocketFactory(const std::string& address, int port, int backlog) + : content::DevToolsHttpHandler::ServerSocketFactory( + address, port, backlog) {} + + private: + // content::DevToolsHttpHandler::ServerSocketFactory. + virtual scoped_ptr<net::ServerSocket> Create() const OVERRIDE { + return scoped_ptr<net::ServerSocket>( + new net::TCPServerSocket(NULL, net::NetLog::Source())); + } + + DISALLOW_COPY_AND_ASSIGN(TCPServerSocketFactory); +}; +#endif + +scoped_ptr<content::DevToolsHttpHandler::ServerSocketFactory> +CreateSocketFactory(int port) { #if defined(OS_ANDROID) base::CommandLine* command_line = base::CommandLine::ForCurrentProcess(); std::string socket_name = "content_shell_devtools_remote"; @@ -43,11 +81,12 @@ net::StreamListenSocketFactory* CreateSocketFactory(int port) { socket_name = command_line->GetSwitchValueASCII( switches::kRemoteDebuggingSocketName); } - return new net::UnixDomainSocketWithAbstractNamespaceFactory( - socket_name, "", base::Bind(&content::CanUserConnectToDevTools)); + return scoped_ptr<content::DevToolsHttpHandler::ServerSocketFactory>( + new UnixDomainServerSocketFactory(socket_name)); #else - return new net::TCPListenSocketFactory("0.0.0.0", port); -#endif // defined(OS_ANDROID) + return scoped_ptr<content::DevToolsHttpHandler::ServerSocketFactory>( + new TCPServerSocketFactory("0.0.0.0", port, 1)); +#endif } std::string GetFrontendUrl() { diff --git a/cloud_print/gcp20/prototype/privet_http_server.cc b/cloud_print/gcp20/prototype/privet_http_server.cc index 9aa2835..41daa81 100644 --- a/cloud_print/gcp20/prototype/privet_http_server.cc +++ b/cloud_print/gcp20/prototype/privet_http_server.cc @@ -10,7 +10,7 @@ #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/url_util.h" -#include "net/socket/tcp_listen_socket.h" +#include "net/socket/tcp_server_socket.h" #include "url/gurl.h" namespace { @@ -105,10 +105,12 @@ bool PrivetHttpServer::Start(uint16 port) { if (server_) return true; - net::TCPListenSocketFactory factory("0.0.0.0", port); - server_ = new net::HttpServer(factory, this); - net::IPEndPoint address; + scoped_ptr<net::ServerSocket> server_socket( + new net::TCPServerSocket(NULL, net::NetLog::Source())); + server_socket->ListenWithAddressAndPort("0.0.0.0", port, 1); + server_.reset(new net::HttpServer(server_socket.Pass(), this)); + net::IPEndPoint address; if (server_->GetLocalAddress(&address) != net::OK) { NOTREACHED() << "Cannot start HTTP server"; return false; @@ -122,7 +124,7 @@ void PrivetHttpServer::Shutdown() { if (!server_) return; - server_ = NULL; + server_.reset(NULL); } void PrivetHttpServer::OnHttpRequest(int connection_id, diff --git a/cloud_print/gcp20/prototype/privet_http_server.h b/cloud_print/gcp20/prototype/privet_http_server.h index 4b22e35..6cb14d10 100644 --- a/cloud_print/gcp20/prototype/privet_http_server.h +++ b/cloud_print/gcp20/prototype/privet_http_server.h @@ -204,7 +204,7 @@ class PrivetHttpServer: public net::HttpServer::Delegate { uint16 port_; // Contains encapsulated object for listening for requests. - scoped_refptr<net::HttpServer> server_; + scoped_ptr<net::HttpServer> server_; Delegate* delegate_; @@ -212,4 +212,3 @@ class PrivetHttpServer: public net::HttpServer::Delegate { }; #endif // CLOUD_PRINT_GCP20_PROTOTYPE_PRIVET_HTTP_SERVER_H_ - diff --git a/content/browser/devtools/devtools_browser_target.cc b/content/browser/devtools/devtools_browser_target.cc index b050152..e959946 100644 --- a/content/browser/devtools/devtools_browser_target.cc +++ b/content/browser/devtools/devtools_browser_target.cc @@ -17,9 +17,8 @@ namespace content { -DevToolsBrowserTarget::DevToolsBrowserTarget( - net::HttpServer* http_server, - int connection_id) +DevToolsBrowserTarget::DevToolsBrowserTarget(net::HttpServer* http_server, + int connection_id) : message_loop_proxy_(base::MessageLoopProxy::current()), http_server_(http_server), connection_id_(connection_id), diff --git a/content/browser/devtools/devtools_http_handler_impl.cc b/content/browser/devtools/devtools_http_handler_impl.cc index 72fe254..18d30e3 100644 --- a/content/browser/devtools/devtools_http_handler_impl.cc +++ b/content/browser/devtools/devtools_http_handler_impl.cc @@ -39,6 +39,7 @@ #include "net/base/net_errors.h" #include "net/server/http_server_request_info.h" #include "net/server/http_server_response_info.h" +#include "net/socket/server_socket.h" #if defined(OS_ANDROID) #include "base/android/build_info.h" @@ -67,6 +68,9 @@ const char kTargetFaviconUrlField[] = "faviconUrl"; const char kTargetWebSocketDebuggerUrlField[] = "webSocketDebuggerUrl"; const char kTargetDevtoolsFrontendUrlField[] = "devtoolsFrontendUrl"; +// Maximum write buffer size of devtools http/websocket connectinos. +const int32 kSendBufferSizeForDevTools = 100 * 1024 * 1024; // 100Mb + // An internal implementation of DevToolsAgentHostClient that delegates // messages sent to a DebuggerShell instance. class DevToolsAgentHostClientImpl : public DevToolsAgentHostClient { @@ -104,13 +108,15 @@ class DevToolsAgentHostClientImpl : public DevToolsAgentHostClient { message_loop_->PostTask( FROM_HERE, base::Bind(&net::HttpServer::SendOverWebSocket, - server_, + base::Unretained(server_), connection_id_, response)); message_loop_->PostTask( FROM_HERE, - base::Bind(&net::HttpServer::Close, server_, connection_id_)); + base::Bind(&net::HttpServer::Close, + base::Unretained(server_), + connection_id_)); } virtual void DispatchProtocolMessage( @@ -119,7 +125,7 @@ class DevToolsAgentHostClientImpl : public DevToolsAgentHostClient { message_loop_->PostTask( FROM_HERE, base::Bind(&net::HttpServer::SendOverWebSocket, - server_, + base::Unretained(server_), connection_id_, message)); } @@ -130,9 +136,9 @@ class DevToolsAgentHostClientImpl : public DevToolsAgentHostClient { } private: - base::MessageLoop* message_loop_; - net::HttpServer* server_; - int connection_id_; + base::MessageLoop* const message_loop_; + net::HttpServer* const server_; + const int connection_id_; scoped_refptr<DevToolsAgentHost> agent_host_; }; @@ -160,12 +166,12 @@ int DevToolsHttpHandler::GetFrontendResourceId(const std::string& name) { // static DevToolsHttpHandler* DevToolsHttpHandler::Start( - const net::StreamListenSocketFactory* socket_factory, + scoped_ptr<ServerSocketFactory> server_socket_factory, const std::string& frontend_url, DevToolsHttpHandlerDelegate* delegate, const base::FilePath& active_port_output_directory) { DevToolsHttpHandlerImpl* http_handler = - new DevToolsHttpHandlerImpl(socket_factory, + new DevToolsHttpHandlerImpl(server_socket_factory.Pass(), frontend_url, delegate, active_port_output_directory); @@ -173,6 +179,28 @@ DevToolsHttpHandler* DevToolsHttpHandler::Start( return http_handler; } +DevToolsHttpHandler::ServerSocketFactory::ServerSocketFactory( + const std::string& address, + int port, + int backlog) + : address_(address), + port_(port), + backlog_(backlog) { +} + +DevToolsHttpHandler::ServerSocketFactory::~ServerSocketFactory() { +} + +scoped_ptr<net::ServerSocket> +DevToolsHttpHandler::ServerSocketFactory::CreateAndListen() const { + scoped_ptr<net::ServerSocket> socket = Create(); + if (socket && + socket->ListenWithAddressAndPort(address_, port_, backlog_) == net::OK) { + return socket.Pass(); + } + return scoped_ptr<net::ServerSocket>(); +} + DevToolsHttpHandlerImpl::~DevToolsHttpHandlerImpl() { DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); // Stop() must be called prior to destruction. @@ -262,6 +290,8 @@ static std::string GetMimeType(const std::string& filename) { void DevToolsHttpHandlerImpl::OnHttpRequest( int connection_id, const net::HttpServerRequestInfo& info) { + server_->SetSendBufferSize(connection_id, kSendBufferSizeForDevTools); + if (info.path.find("/json") == 0) { BrowserThread::PostTask( BrowserThread::UI, @@ -351,6 +381,7 @@ void DevToolsHttpHandlerImpl::OnWebSocketRequest( true /* handle on UI thread */); browser_targets_[connection_id] = browser_target; + server_->SetSendBufferSize(connection_id, kSendBufferSizeForDevTools); server_->AcceptWebSocket(connection_id, request); return; } @@ -651,16 +682,16 @@ void DevToolsHttpHandlerImpl::OnCloseUI(int connection_id) { } DevToolsHttpHandlerImpl::DevToolsHttpHandlerImpl( - const net::StreamListenSocketFactory* socket_factory, + scoped_ptr<ServerSocketFactory> server_socket_factory, const std::string& frontend_url, DevToolsHttpHandlerDelegate* delegate, const base::FilePath& active_port_output_directory) : frontend_url_(frontend_url), - socket_factory_(socket_factory), + server_socket_factory_(server_socket_factory.Pass()), delegate_(delegate), active_port_output_directory_(active_port_output_directory) { if (frontend_url_.empty()) - frontend_url_ = "/devtools/devtools.html"; + frontend_url_ = "/devtools/devtools.html"; // Balanced in ResetHandlerThreadAndRelease(). AddRef(); @@ -668,14 +699,15 @@ DevToolsHttpHandlerImpl::DevToolsHttpHandlerImpl( // Runs on the handler thread void DevToolsHttpHandlerImpl::Init() { - server_ = new net::HttpServer(*socket_factory_.get(), this); + server_.reset(new net::HttpServer(server_socket_factory_->CreateAndListen(), + this)); if (!active_port_output_directory_.empty()) WriteActivePortToUserProfile(); } // Runs on the handler thread void DevToolsHttpHandlerImpl::Teardown() { - server_ = NULL; + server_.reset(NULL); } // Runs on FILE thread to make sure that it is serialized against @@ -733,7 +765,7 @@ void DevToolsHttpHandlerImpl::SendJson(int connection_id, thread_->message_loop()->PostTask( FROM_HERE, base::Bind(&net::HttpServer::SendResponse, - server_.get(), + base::Unretained(server_.get()), connection_id, response)); } @@ -746,7 +778,7 @@ void DevToolsHttpHandlerImpl::Send200(int connection_id, thread_->message_loop()->PostTask( FROM_HERE, base::Bind(&net::HttpServer::Send200, - server_.get(), + base::Unretained(server_.get()), connection_id, data, mime_type)); @@ -757,7 +789,9 @@ void DevToolsHttpHandlerImpl::Send404(int connection_id) { return; thread_->message_loop()->PostTask( FROM_HERE, - base::Bind(&net::HttpServer::Send404, server_.get(), connection_id)); + base::Bind(&net::HttpServer::Send404, + base::Unretained(server_.get()), + connection_id)); } void DevToolsHttpHandlerImpl::Send500(int connection_id, @@ -766,7 +800,9 @@ void DevToolsHttpHandlerImpl::Send500(int connection_id, return; thread_->message_loop()->PostTask( FROM_HERE, - base::Bind(&net::HttpServer::Send500, server_.get(), connection_id, + base::Bind(&net::HttpServer::Send500, + base::Unretained(server_.get()), + connection_id, message)); } @@ -777,8 +813,16 @@ void DevToolsHttpHandlerImpl::AcceptWebSocket( return; thread_->message_loop()->PostTask( FROM_HERE, - base::Bind(&net::HttpServer::AcceptWebSocket, server_.get(), - connection_id, request)); + base::Bind(&net::HttpServer::SetSendBufferSize, + base::Unretained(server_.get()), + connection_id, + kSendBufferSizeForDevTools)); + thread_->message_loop()->PostTask( + FROM_HERE, + base::Bind(&net::HttpServer::AcceptWebSocket, + base::Unretained(server_.get()), + connection_id, + request)); } base::DictionaryValue* DevToolsHttpHandlerImpl::SerializeTarget( diff --git a/content/browser/devtools/devtools_http_handler_impl.h b/content/browser/devtools/devtools_http_handler_impl.h index e9cee0b..3220e05 100644 --- a/content/browser/devtools/devtools_http_handler_impl.h +++ b/content/browser/devtools/devtools_http_handler_impl.h @@ -27,7 +27,7 @@ class Value; } namespace net { -class StreamListenSocketFactory; +class ServerSocketFactory; class URLRequestContextGetter; } @@ -43,8 +43,7 @@ class DevToolsHttpHandlerImpl friend class base::RefCountedThreadSafe<DevToolsHttpHandlerImpl>; friend class DevToolsHttpHandler; - // Takes ownership over |socket_factory|. - DevToolsHttpHandlerImpl(const net::StreamListenSocketFactory* socket_factory, + DevToolsHttpHandlerImpl(scoped_ptr<ServerSocketFactory> server_socket_factory, const std::string& frontend_url, DevToolsHttpHandlerDelegate* delegate, const base::FilePath& active_port_output_directory); @@ -117,12 +116,12 @@ class DevToolsHttpHandlerImpl scoped_ptr<base::Thread> thread_; std::string frontend_url_; - scoped_ptr<const net::StreamListenSocketFactory> socket_factory_; - scoped_refptr<net::HttpServer> server_; + const scoped_ptr<ServerSocketFactory> server_socket_factory_; + scoped_ptr<net::HttpServer> server_; typedef std::map<int, DevToolsAgentHostClient*> ConnectionToClientMap; ConnectionToClientMap connection_to_client_ui_; - scoped_ptr<DevToolsHttpHandlerDelegate> delegate_; - base::FilePath active_port_output_directory_; + const scoped_ptr<DevToolsHttpHandlerDelegate> delegate_; + const base::FilePath active_port_output_directory_; typedef std::map<std::string, DevToolsTarget*> TargetMap; TargetMap target_map_; typedef std::map<int, scoped_refptr<DevToolsBrowserTarget> > BrowserTargets; diff --git a/content/browser/devtools/devtools_http_handler_unittest.cc b/content/browser/devtools/devtools_http_handler_unittest.cc index 8871092..987ea3e3 100644 --- a/content/browser/devtools/devtools_http_handler_unittest.cc +++ b/content/browser/devtools/devtools_http_handler_unittest.cc @@ -13,7 +13,7 @@ #include "content/public/browser/devtools_target.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" -#include "net/socket/stream_listen_socket.h" +#include "net/socket/server_socket.h" #include "testing/gtest/include/gtest/gtest.h" namespace content { @@ -23,49 +23,55 @@ const int kDummyPort = 4321; const base::FilePath::CharType kDevToolsActivePortFileName[] = FILE_PATH_LITERAL("DevToolsActivePort"); -using net::StreamListenSocket; - -class DummyListenSocket : public StreamListenSocket, - public StreamListenSocket::Delegate { +class DummyServerSocket : public net::ServerSocket { public: - DummyListenSocket() - : StreamListenSocket(net::kInvalidSocket, this) {} - - // StreamListenSocket::Delegate "implementation" - virtual void DidAccept(StreamListenSocket* server, - scoped_ptr<StreamListenSocket> connection) OVERRIDE {} - virtual void DidRead(StreamListenSocket* connection, - const char* data, - int len) OVERRIDE {} - virtual void DidClose(StreamListenSocket* sock) OVERRIDE {} - protected: - virtual ~DummyListenSocket() {} - virtual void Accept() OVERRIDE {} - virtual int GetLocalAddress(net::IPEndPoint* address) OVERRIDE { + DummyServerSocket() {} + + // net::ServerSocket "implementation" + virtual int Listen(const net::IPEndPoint& address, int backlog) OVERRIDE { + return net::OK; + } + + virtual int ListenWithAddressAndPort(const std::string& ip_address, + int port, + int backlog) OVERRIDE { + return net::OK; + } + + virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { net::IPAddressNumber number; EXPECT_TRUE(net::ParseIPLiteralToNumber("127.0.0.1", &number)); *address = net::IPEndPoint(number, kDummyPort); return net::OK; } + + virtual int Accept(scoped_ptr<net::StreamSocket>* socket, + const net::CompletionCallback& callback) OVERRIDE { + return net::ERR_IO_PENDING; + } }; -class DummyListenSocketFactory : public net::StreamListenSocketFactory { +class DummyServerSocketFactory + : public DevToolsHttpHandler::ServerSocketFactory { public: - DummyListenSocketFactory( - base::Closure quit_closure_1, base::Closure quit_closure_2) - : quit_closure_1_(quit_closure_1), quit_closure_2_(quit_closure_2) {} - virtual ~DummyListenSocketFactory() { + DummyServerSocketFactory(base::Closure quit_closure_1, + base::Closure quit_closure_2) + : DevToolsHttpHandler::ServerSocketFactory("", 0, 0), + quit_closure_1_(quit_closure_1), + quit_closure_2_(quit_closure_2) {} + + virtual ~DummyServerSocketFactory() { BrowserThread::PostTask( BrowserThread::UI, FROM_HERE, quit_closure_2_); } - virtual scoped_ptr<StreamListenSocket> CreateAndListen( - StreamListenSocket::Delegate* delegate) const OVERRIDE { + private: + virtual scoped_ptr<net::ServerSocket> Create() const OVERRIDE { BrowserThread::PostTask( BrowserThread::UI, FROM_HERE, quit_closure_1_); - return scoped_ptr<net::StreamListenSocket>(new DummyListenSocket()); + return scoped_ptr<net::ServerSocket>(new DummyServerSocket()); } - private: + base::Closure quit_closure_1_; base::Closure quit_closure_2_; }; @@ -73,22 +79,28 @@ class DummyListenSocketFactory : public net::StreamListenSocketFactory { class DummyDelegate : public DevToolsHttpHandlerDelegate { public: virtual std::string GetDiscoveryPageHTML() OVERRIDE { return std::string(); } + virtual bool BundlesFrontendResources() OVERRIDE { return true; } + virtual base::FilePath GetDebugFrontendDir() OVERRIDE { return base::FilePath(); } + virtual std::string GetPageThumbnailData(const GURL& url) OVERRIDE { return std::string(); } + virtual scoped_ptr<DevToolsTarget> CreateNewTarget(const GURL& url) OVERRIDE { return scoped_ptr<DevToolsTarget>(); } + virtual void EnumerateTargets(TargetCallback callback) OVERRIDE { callback.Run(TargetList()); } + virtual scoped_ptr<net::StreamListenSocket> CreateSocketForTethering( - net::StreamListenSocket::Delegate* delegate, - std::string* name) OVERRIDE { + net::StreamListenSocket::Delegate* delegate, + std::string* name) OVERRIDE { return scoped_ptr<net::StreamListenSocket>(); } }; @@ -100,14 +112,17 @@ class DevToolsHttpHandlerTest : public testing::Test { DevToolsHttpHandlerTest() : ui_thread_(BrowserThread::UI, &message_loop_) { } + protected: virtual void SetUp() { file_thread_.reset(new BrowserThreadImpl(BrowserThread::FILE)); file_thread_->Start(); } + virtual void TearDown() { file_thread_->Stop(); } + private: base::MessageLoopForIO message_loop_; BrowserThreadImpl ui_thread_; @@ -116,13 +131,14 @@ class DevToolsHttpHandlerTest : public testing::Test { TEST_F(DevToolsHttpHandlerTest, TestStartStop) { base::RunLoop run_loop, run_loop_2; + scoped_ptr<DevToolsHttpHandler::ServerSocketFactory> factory( + new DummyServerSocketFactory(run_loop.QuitClosure(), + run_loop_2.QuitClosure())); content::DevToolsHttpHandler* devtools_http_handler_ = - content::DevToolsHttpHandler::Start( - new DummyListenSocketFactory(run_loop.QuitClosure(), - run_loop_2.QuitClosure()), - std::string(), - new DummyDelegate(), - base::FilePath()); + content::DevToolsHttpHandler::Start(factory.Pass(), + std::string(), + new DummyDelegate(), + base::FilePath()); // Our dummy socket factory will post a quit message once the server will // become ready. run_loop.Run(); @@ -135,13 +151,14 @@ TEST_F(DevToolsHttpHandlerTest, TestDevToolsActivePort) { base::RunLoop run_loop, run_loop_2; base::ScopedTempDir temp_dir; EXPECT_TRUE(temp_dir.CreateUniqueTempDir()); + scoped_ptr<DevToolsHttpHandler::ServerSocketFactory> factory( + new DummyServerSocketFactory(run_loop.QuitClosure(), + run_loop_2.QuitClosure())); content::DevToolsHttpHandler* devtools_http_handler_ = - content::DevToolsHttpHandler::Start( - new DummyListenSocketFactory(run_loop.QuitClosure(), - run_loop_2.QuitClosure()), - std::string(), - new DummyDelegate(), - temp_dir.path()); + content::DevToolsHttpHandler::Start(factory.Pass(), + std::string(), + new DummyDelegate(), + temp_dir.path()); // Our dummy socket factory will post a quit message once the server will // become ready. run_loop.Run(); diff --git a/content/public/browser/devtools_http_handler.h b/content/public/browser/devtools_http_handler.h index cdceff8..2b656b6 100644 --- a/content/public/browser/devtools_http_handler.h +++ b/content/public/browser/devtools_http_handler.h @@ -8,12 +8,13 @@ #include <string> #include "base/files/file_path.h" +#include "base/memory/scoped_ptr.h" #include "content/common/content_export.h" class GURL; namespace net { -class StreamListenSocketFactory; +class ServerSocket; class URLRequestContextGetter; } @@ -26,6 +27,32 @@ class DevToolsHttpHandlerDelegate; // this browser. class DevToolsHttpHandler { public: + + // Factory of net::ServerSocket. This is to separate instantiating dev tools + // and instantiating server socket. + class CONTENT_EXPORT ServerSocketFactory { + public: + ServerSocketFactory(const std::string& address, int port, int backlog); + virtual ~ServerSocketFactory(); + + // Returns a new instance of ServerSocket or NULL if an error occurred. + // It calls ServerSocket::ListenWithAddressAndPort() with address, port and + // backlog passed to constructor. + virtual scoped_ptr<net::ServerSocket> CreateAndListen() const; + + protected: + // Creates a server socket. ServerSocket::Listen() will be called soon + // unless it returns NULL. + virtual scoped_ptr<net::ServerSocket> Create() const = 0; + + const std::string address_; + const int port_; + const int backlog_; + + private: + DISALLOW_COPY_AND_ASSIGN(ServerSocketFactory); + }; + // Returns true if the given protocol version is supported. CONTENT_EXPORT static bool IsSupportedProtocolVersion( const std::string& version); @@ -40,7 +67,7 @@ class DevToolsHttpHandler { // port selected by the OS will be written to a well-known file in // the output directory. CONTENT_EXPORT static DevToolsHttpHandler* Start( - const net::StreamListenSocketFactory* socket_factory, + scoped_ptr<ServerSocketFactory> server_socket_factory, const std::string& frontend_url, DevToolsHttpHandlerDelegate* delegate, const base::FilePath& active_port_output_directory); diff --git a/content/shell/browser/shell_devtools_delegate.cc b/content/shell/browser/shell_devtools_delegate.cc index 0f4214f..c740b46 100644 --- a/content/shell/browser/shell_devtools_delegate.cc +++ b/content/shell/browser/shell_devtools_delegate.cc @@ -24,12 +24,12 @@ #include "content/public/common/user_agent.h" #include "content/shell/browser/shell.h" #include "grit/shell_resources.h" -#include "net/socket/tcp_listen_socket.h" +#include "net/socket/tcp_server_socket.h" #include "ui/base/resource/resource_bundle.h" #if defined(OS_ANDROID) #include "content/public/browser/android/devtools_auth.h" -#include "net/socket/unix_domain_listen_socket_posix.h" +#include "net/socket/unix_domain_server_socket_posix.h" #endif using content::DevToolsAgentHost; @@ -44,7 +44,45 @@ const char kFrontEndURL[] = #endif const char kTargetTypePage[] = "page"; -net::StreamListenSocketFactory* CreateSocketFactory() { +#if defined(OS_ANDROID) +class UnixDomainServerSocketFactory + : public content::DevToolsHttpHandler::ServerSocketFactory { + public: + explicit UnixDomainServerSocketFactory(const std::string& socket_name) + : content::DevToolsHttpHandler::ServerSocketFactory(socket_name, 0, 1) {} + + private: + // content::DevToolsHttpHandler::ServerSocketFactory. + virtual scoped_ptr<net::ServerSocket> Create() const OVERRIDE { + return scoped_ptr<net::ServerSocket>( + new net::UnixDomainServerSocket( + base::Bind(&content::CanUserConnectToDevTools), + true /* use_abstract_namespace */)); + } + + DISALLOW_COPY_AND_ASSIGN(UnixDomainServerSocketFactory); +}; +#else +class TCPServerSocketFactory + : public content::DevToolsHttpHandler::ServerSocketFactory { + public: + TCPServerSocketFactory(const std::string& address, int port, int backlog) + : content::DevToolsHttpHandler::ServerSocketFactory( + address, port, backlog) {} + + private: + // content::DevToolsHttpHandler::ServerSocketFactory. + virtual scoped_ptr<net::ServerSocket> Create() const OVERRIDE { + return scoped_ptr<net::ServerSocket>( + new net::TCPServerSocket(NULL, net::NetLog::Source())); + } + + DISALLOW_COPY_AND_ASSIGN(TCPServerSocketFactory); +}; +#endif + +scoped_ptr<content::DevToolsHttpHandler::ServerSocketFactory> +CreateSocketFactory() { const CommandLine& command_line = *CommandLine::ForCurrentProcess(); #if defined(OS_ANDROID) std::string socket_name = "content_shell_devtools_remote"; @@ -52,9 +90,8 @@ net::StreamListenSocketFactory* CreateSocketFactory() { socket_name = command_line.GetSwitchValueASCII( switches::kRemoteDebuggingSocketName); } - return new net::deprecated:: - UnixDomainListenSocketWithAbstractNamespaceFactory( - socket_name, "", base::Bind(&content::CanUserConnectToDevTools)); + return scoped_ptr<content::DevToolsHttpHandler::ServerSocketFactory>( + new UnixDomainServerSocketFactory(socket_name)); #else // See if the user specified a port on the command line (useful for // automation). If not, use an ephemeral port by specifying 0. @@ -70,7 +107,8 @@ net::StreamListenSocketFactory* CreateSocketFactory() { DLOG(WARNING) << "Invalid http debugger port number " << temp_port; } } - return new net::TCPListenSocketFactory("127.0.0.1", port); + return scoped_ptr<content::DevToolsHttpHandler::ServerSocketFactory>( + new TCPServerSocketFactory("127.0.0.1", port, 1)); #endif } diff --git a/extensions/browser/api/socket/tcp_socket.cc b/extensions/browser/api/socket/tcp_socket.cc index 321b631..d71ebd2 100644 --- a/extensions/browser/api/socket/tcp_socket.cc +++ b/extensions/browser/api/socket/tcp_socket.cc @@ -201,6 +201,7 @@ int TCPSocket::Listen(const std::string& address, if (!server_socket_.get()) { server_socket_.reset(new net::TCPServerSocket(NULL, net::NetLog::Source())); } + int result = server_socket_->ListenWithAddressAndPort(address, port, backlog); if (result) *error_msg = kSocketListenError; diff --git a/mojo/spy/websocket_server.cc b/mojo/spy/websocket_server.cc index bf4c96a..20e2da6 100644 --- a/mojo/spy/websocket_server.cc +++ b/mojo/spy/websocket_server.cc @@ -14,7 +14,7 @@ #include "net/base/net_errors.h" #include "net/server/http_server_request_info.h" #include "net/server/http_server_response_info.h" -#include "net/socket/tcp_listen_socket.h" +#include "net/socket/tcp_server_socket.h" #include "url/gurl.h" namespace mojo { @@ -42,8 +42,10 @@ WebSocketServer::~WebSocketServer() { } bool WebSocketServer::Start() { - net::TCPListenSocketFactory factory("0.0.0.0", port_); - web_server_ = new net::HttpServer(factory, this); + scoped_ptr<net::ServerSocket> server_socket( + new net::TCPServerSocket(NULL, net::NetLog::Source())); + server_socket->ListenWithAddressAndPort("0.0.0.0", port_, 1); + web_server_.reset(new net::HttpServer(server_socket.Pass(), this)); net::IPEndPoint address; int error = web_server_->GetLocalAddress(&address); port_ = address.port(); @@ -91,9 +93,7 @@ void WebSocketServer::OnWebSocketRequest( const net::HttpServerRequestInfo& info) { if (connection_id_ != kNotConnected) { // Reject connection since we already have our client. - base::MessageLoop::current()->PostTask( - FROM_HERE, - base::Bind(&net::HttpServer::Close, web_server_, connection_id)); + web_server_->Close(connection_id); return; } // Accept the connection. @@ -157,4 +157,3 @@ bool WebSocketServer::Connected() const { } } // namespace mojo - diff --git a/mojo/spy/websocket_server.h b/mojo/spy/websocket_server.h index 1811c0c..eb685c7 100644 --- a/mojo/spy/websocket_server.h +++ b/mojo/spy/websocket_server.h @@ -66,7 +66,7 @@ class WebSocketServer : public net::HttpServer::Delegate, private: int port_; int connection_id_; - scoped_refptr<net::HttpServer> web_server_; + scoped_ptr<net::HttpServer> web_server_; spy_api::SpyServerPtr spy_server_; DISALLOW_COPY_AND_ASSIGN(WebSocketServer); diff --git a/net/net.gypi b/net/net.gypi index f240eba..b57d853 100644 --- a/net/net.gypi +++ b/net/net.gypi @@ -1555,6 +1555,7 @@ 'quic/quic_utils_test.cc', 'quic/quic_write_blocked_list_test.cc', 'quic/reliable_quic_stream_test.cc', + 'server/http_connection_unittest.cc', 'server/http_server_response_info_unittest.cc', 'server/http_server_unittest.cc', 'socket/client_socket_pool_base_unittest.cc', diff --git a/net/server/http_connection.cc b/net/server/http_connection.cc index d433012..3401f81 100644 --- a/net/server/http_connection.cc +++ b/net/server/http_connection.cc @@ -4,44 +4,163 @@ #include "net/server/http_connection.h" -#include "net/server/http_server.h" -#include "net/server/http_server_response_info.h" +#include "base/logging.h" #include "net/server/web_socket.h" -#include "net/socket/stream_listen_socket.h" +#include "net/socket/stream_socket.h" namespace net { -int HttpConnection::last_id_ = 0; +HttpConnection::ReadIOBuffer::ReadIOBuffer() + : base_(new GrowableIOBuffer()), + max_buffer_size_(kDefaultMaxBufferSize) { + SetCapacity(kInitialBufSize); +} -void HttpConnection::Send(const std::string& data) { - if (!socket_.get()) - return; - socket_->Send(data); +HttpConnection::ReadIOBuffer::~ReadIOBuffer() { + data_ = NULL; // base_ owns data_. +} + +int HttpConnection::ReadIOBuffer::GetCapacity() const { + return base_->capacity(); +} + +void HttpConnection::ReadIOBuffer::SetCapacity(int capacity) { + DCHECK_LE(GetSize(), capacity); + base_->SetCapacity(capacity); + data_ = base_->data(); +} + +bool HttpConnection::ReadIOBuffer::IncreaseCapacity() { + if (GetCapacity() >= max_buffer_size_) { + LOG(ERROR) << "Too large read data is pending: capacity=" << GetCapacity() + << ", max_buffer_size=" << max_buffer_size_ + << ", read=" << GetSize(); + return false; + } + + int new_capacity = GetCapacity() * kCapacityIncreaseFactor; + if (new_capacity > max_buffer_size_) + new_capacity = max_buffer_size_; + SetCapacity(new_capacity); + return true; +} + +char* HttpConnection::ReadIOBuffer::StartOfBuffer() const { + return base_->StartOfBuffer(); +} + +int HttpConnection::ReadIOBuffer::GetSize() const { + return base_->offset(); +} + +void HttpConnection::ReadIOBuffer::DidRead(int bytes) { + DCHECK_GE(RemainingCapacity(), bytes); + base_->set_offset(base_->offset() + bytes); + data_ = base_->data(); +} + +int HttpConnection::ReadIOBuffer::RemainingCapacity() const { + return base_->RemainingCapacity(); +} + +void HttpConnection::ReadIOBuffer::DidConsume(int bytes) { + int previous_size = GetSize(); + int unconsumed_size = previous_size - bytes; + DCHECK_LE(0, unconsumed_size); + if (unconsumed_size > 0) { + // Move unconsumed data to the start of buffer. + memmove(StartOfBuffer(), StartOfBuffer() + bytes, unconsumed_size); + } + base_->set_offset(unconsumed_size); + data_ = base_->data(); + + // If capacity is too big, reduce it. + if (GetCapacity() > kMinimumBufSize && + GetCapacity() > previous_size * kCapacityIncreaseFactor) { + int new_capacity = GetCapacity() / kCapacityIncreaseFactor; + if (new_capacity < kMinimumBufSize) + new_capacity = kMinimumBufSize; + // realloc() within GrowableIOBuffer::SetCapacity() could move data even + // when size is reduced. If unconsumed_size == 0, i.e. no data exists in + // the buffer, free internal buffer first to guarantee no data move. + if (!unconsumed_size) + base_->SetCapacity(0); + SetCapacity(new_capacity); + } +} + +HttpConnection::QueuedWriteIOBuffer::QueuedWriteIOBuffer() + : total_size_(0), + max_buffer_size_(kDefaultMaxBufferSize) { +} + +HttpConnection::QueuedWriteIOBuffer::~QueuedWriteIOBuffer() { + data_ = NULL; // pending_data_ owns data_. } -void HttpConnection::Send(const char* bytes, int len) { - if (!socket_.get()) +bool HttpConnection::QueuedWriteIOBuffer::IsEmpty() const { + return pending_data_.empty(); +} + +bool HttpConnection::QueuedWriteIOBuffer::Append(const std::string& data) { + if (data.empty()) + return true; + + if (total_size_ + static_cast<int>(data.size()) > max_buffer_size_) { + LOG(ERROR) << "Too large write data is pending: size=" + << total_size_ + data.size() + << ", max_buffer_size=" << max_buffer_size_; + return false; + } + + pending_data_.push(data); + total_size_ += data.size(); + + // If new data is the first pending data, updates data_. + if (pending_data_.size() == 1) + data_ = const_cast<char*>(pending_data_.front().data()); + return true; +} + +void HttpConnection::QueuedWriteIOBuffer::DidConsume(int size) { + DCHECK_GE(total_size_, size); + DCHECK_GE(GetSizeToWrite(), size); + if (size == 0) return; - socket_->Send(bytes, len); + + if (size < GetSizeToWrite()) { + data_ += size; + } else { // size == GetSizeToWrite(). Updates data_ to next pending data. + pending_data_.pop(); + data_ = IsEmpty() ? NULL : const_cast<char*>(pending_data_.front().data()); + } + total_size_ -= size; } -void HttpConnection::Send(const HttpServerResponseInfo& response) { - Send(response.Serialize()); +int HttpConnection::QueuedWriteIOBuffer::GetSizeToWrite() const { + if (IsEmpty()) { + DCHECK_EQ(0, total_size_); + return 0; + } + DCHECK_GE(data_, pending_data_.front().data()); + int consumed = static_cast<int>(data_ - pending_data_.front().data()); + DCHECK_GT(static_cast<int>(pending_data_.front().size()), consumed); + return pending_data_.front().size() - consumed; } -HttpConnection::HttpConnection(HttpServer* server, - scoped_ptr<StreamListenSocket> sock) - : server_(server), - socket_(sock.Pass()) { - id_ = last_id_++; +HttpConnection::HttpConnection(int id, scoped_ptr<StreamSocket> socket) + : id_(id), + socket_(socket.Pass()), + read_buf_(new ReadIOBuffer()), + write_buf_(new QueuedWriteIOBuffer()) { } HttpConnection::~HttpConnection() { - server_->delegate_->OnClose(id_); } -void HttpConnection::Shift(int num_bytes) { - recv_data_ = recv_data_.substr(num_bytes); +void HttpConnection::SetWebSocket(scoped_ptr<WebSocket> web_socket) { + DCHECK(!web_socket_); + web_socket_ = web_socket.Pass(); } } // namespace net diff --git a/net/server/http_connection.h b/net/server/http_connection.h index 17faa46..c7225e1 100644 --- a/net/server/http_connection.h +++ b/net/server/http_connection.h @@ -5,43 +5,130 @@ #ifndef NET_SERVER_HTTP_CONNECTION_H_ #define NET_SERVER_HTTP_CONNECTION_H_ +#include <queue> #include <string> #include "base/basictypes.h" +#include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" -#include "net/http/http_status_code.h" +#include "net/base/io_buffer.h" namespace net { -class HttpServer; -class HttpServerResponseInfo; -class StreamListenSocket; +class StreamSocket; class WebSocket; +// A container which has all information of an http connection. It includes +// id, underlying socket, and pending read/write data. class HttpConnection { public: - ~HttpConnection(); + // IOBuffer for data read. It's a wrapper around GrowableIOBuffer, with more + // functions for buffer management. It moves unconsumed data to the start of + // buffer. + class ReadIOBuffer : public IOBuffer { + public: + static const int kInitialBufSize = 1024; + static const int kMinimumBufSize = 128; + static const int kCapacityIncreaseFactor = 2; + static const int kDefaultMaxBufferSize = 1 * 1024 * 1024; // 1 Mbytes. + + ReadIOBuffer(); + + // Capacity. + int GetCapacity() const; + void SetCapacity(int capacity); + // Increases capacity and returns true if capacity is not beyond the limit. + bool IncreaseCapacity(); + + // Start of read data. + char* StartOfBuffer() const; + // Returns the bytes of read data. + int GetSize() const; + // More read data was appended. + void DidRead(int bytes); + // Capacity for which more read data can be appended. + int RemainingCapacity() const; + + // Removes consumed data and moves unconsumed data to the start of buffer. + void DidConsume(int bytes); + + // Limit of how much internal capacity can increase. + int max_buffer_size() const { return max_buffer_size_; } + void set_max_buffer_size(int max_buffer_size) { + max_buffer_size_ = max_buffer_size; + } + + private: + virtual ~ReadIOBuffer(); + + scoped_refptr<GrowableIOBuffer> base_; + int max_buffer_size_; + + DISALLOW_COPY_AND_ASSIGN(ReadIOBuffer); + }; + + // IOBuffer of pending data to write which has a queue of pending data. Each + // pending data is stored in std::string. data() is the data of first + // std::string stored. + class QueuedWriteIOBuffer : public IOBuffer { + public: + static const int kDefaultMaxBufferSize = 1 * 1024 * 1024; // 1 Mbytes. + + QueuedWriteIOBuffer(); + + // Whether or not pending data exists. + bool IsEmpty() const; - void Send(const std::string& data); - void Send(const char* bytes, int len); - void Send(const HttpServerResponseInfo& response); + // Appends new pending data and returns true if total size doesn't exceed + // the limit, |total_size_limit_|. It would change data() if new data is + // the first pending data. + bool Append(const std::string& data); - void Shift(int num_bytes); + // Consumes data and changes data() accordingly. It cannot be more than + // GetSizeToWrite(). + void DidConsume(int size); + + // Gets size of data to write this time. It is NOT total data size. + int GetSizeToWrite() const; + + // Total size of all pending data. + int total_size() const { return total_size_; } + + // Limit of how much data can be pending. + int max_buffer_size() const { return max_buffer_size_; } + void set_max_buffer_size(int max_buffer_size) { + max_buffer_size_ = max_buffer_size; + } + + private: + virtual ~QueuedWriteIOBuffer(); + + std::queue<std::string> pending_data_; + int total_size_; + int max_buffer_size_; + + DISALLOW_COPY_AND_ASSIGN(QueuedWriteIOBuffer); + }; + + HttpConnection(int id, scoped_ptr<StreamSocket> socket); + ~HttpConnection(); - const std::string& recv_data() const { return recv_data_; } int id() const { return id_; } + StreamSocket* socket() const { return socket_.get(); } + ReadIOBuffer* read_buf() const { return read_buf_.get(); } + QueuedWriteIOBuffer* write_buf() const { return write_buf_.get(); } - private: - friend class HttpServer; - static int last_id_; + WebSocket* web_socket() const { return web_socket_.get(); } + void SetWebSocket(scoped_ptr<WebSocket> web_socket); - HttpConnection(HttpServer* server, scoped_ptr<StreamListenSocket> sock); + private: + const int id_; + const scoped_ptr<StreamSocket> socket_; + const scoped_refptr<ReadIOBuffer> read_buf_; + const scoped_refptr<QueuedWriteIOBuffer> write_buf_; - HttpServer* server_; - scoped_ptr<StreamListenSocket> socket_; scoped_ptr<WebSocket> web_socket_; - std::string recv_data_; - int id_; + DISALLOW_COPY_AND_ASSIGN(HttpConnection); }; diff --git a/net/server/http_connection_unittest.cc b/net/server/http_connection_unittest.cc new file mode 100644 index 0000000..488fd6f --- /dev/null +++ b/net/server/http_connection_unittest.cc @@ -0,0 +1,331 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/server/http_connection.h" + +#include <string> + +#include "base/memory/ref_counted.h" +#include "base/strings/string_piece.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { +namespace { + +std::string GetTestString(int size) { + std::string test_string; + for (int i = 0; i < size; ++i) { + test_string.push_back('A' + (i % 26)); + } + return test_string; +} + +TEST(HttpConnectionTest, ReadIOBuffer_SetCapacity) { + scoped_refptr<HttpConnection::ReadIOBuffer> buffer( + new HttpConnection::ReadIOBuffer); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0, + buffer->GetCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0, + buffer->RemainingCapacity()); + EXPECT_EQ(0, buffer->GetSize()); + + const int kNewCapacity = HttpConnection::ReadIOBuffer::kInitialBufSize + 128; + buffer->SetCapacity(kNewCapacity); + EXPECT_EQ(kNewCapacity, buffer->GetCapacity()); + EXPECT_EQ(kNewCapacity, buffer->RemainingCapacity()); + EXPECT_EQ(0, buffer->GetSize()); +} + +TEST(HttpConnectionTest, ReadIOBuffer_SetCapacity_WithData) { + scoped_refptr<HttpConnection::ReadIOBuffer> buffer( + new HttpConnection::ReadIOBuffer); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0, + buffer->GetCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0, + buffer->RemainingCapacity()); + + // Write arbitrary data up to kInitialBufSize. + const std::string kReadData( + GetTestString(HttpConnection::ReadIOBuffer::kInitialBufSize)); + memcpy(buffer->data(), kReadData.data(), kReadData.size()); + buffer->DidRead(kReadData.size()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0, + buffer->GetCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize - + static_cast<int>(kReadData.size()), + buffer->RemainingCapacity()); + EXPECT_EQ(static_cast<int>(kReadData.size()), buffer->GetSize()); + EXPECT_EQ(kReadData, + base::StringPiece(buffer->StartOfBuffer(), buffer->GetSize())); + + // Check if read data in the buffer is same after SetCapacity(). + const int kNewCapacity = HttpConnection::ReadIOBuffer::kInitialBufSize + 128; + buffer->SetCapacity(kNewCapacity); + EXPECT_EQ(kNewCapacity, buffer->GetCapacity()); + EXPECT_EQ(kNewCapacity - static_cast<int>(kReadData.size()), + buffer->RemainingCapacity()); + EXPECT_EQ(static_cast<int>(kReadData.size()), buffer->GetSize()); + EXPECT_EQ(kReadData, + base::StringPiece(buffer->StartOfBuffer(), buffer->GetSize())); +} + +TEST(HttpConnectionTest, ReadIOBuffer_IncreaseCapacity) { + scoped_refptr<HttpConnection::ReadIOBuffer> buffer( + new HttpConnection::ReadIOBuffer); + EXPECT_TRUE(buffer->IncreaseCapacity()); + const int kExpectedInitialBufSize = + HttpConnection::ReadIOBuffer::kInitialBufSize * + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor; + EXPECT_EQ(kExpectedInitialBufSize, buffer->GetCapacity()); + EXPECT_EQ(kExpectedInitialBufSize, buffer->RemainingCapacity()); + EXPECT_EQ(0, buffer->GetSize()); + + // Increase capacity until it fails. + while (buffer->IncreaseCapacity()); + EXPECT_FALSE(buffer->IncreaseCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize + 0, + buffer->max_buffer_size()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize + 0, + buffer->GetCapacity()); + + // Enlarge capacity limit. + buffer->set_max_buffer_size(buffer->max_buffer_size() * 2); + EXPECT_TRUE(buffer->IncreaseCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize * + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor, + buffer->GetCapacity()); + + // Shrink capacity limit. It doesn't change capacity itself. + buffer->set_max_buffer_size( + HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize / 2); + EXPECT_FALSE(buffer->IncreaseCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize * + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor, + buffer->GetCapacity()); +} + +TEST(HttpConnectionTest, ReadIOBuffer_IncreaseCapacity_WithData) { + scoped_refptr<HttpConnection::ReadIOBuffer> buffer( + new HttpConnection::ReadIOBuffer); + EXPECT_TRUE(buffer->IncreaseCapacity()); + const int kExpectedInitialBufSize = + HttpConnection::ReadIOBuffer::kInitialBufSize * + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor; + EXPECT_EQ(kExpectedInitialBufSize, buffer->GetCapacity()); + EXPECT_EQ(kExpectedInitialBufSize, buffer->RemainingCapacity()); + EXPECT_EQ(0, buffer->GetSize()); + + // Write arbitrary data up to kExpectedInitialBufSize. + std::string kReadData(GetTestString(kExpectedInitialBufSize)); + memcpy(buffer->data(), kReadData.data(), kReadData.size()); + buffer->DidRead(kReadData.size()); + EXPECT_EQ(kExpectedInitialBufSize, buffer->GetCapacity()); + EXPECT_EQ(kExpectedInitialBufSize - static_cast<int>(kReadData.size()), + buffer->RemainingCapacity()); + EXPECT_EQ(static_cast<int>(kReadData.size()), buffer->GetSize()); + EXPECT_EQ(kReadData, + base::StringPiece(buffer->StartOfBuffer(), buffer->GetSize())); + + // Increase capacity until it fails and check if read data in the buffer is + // same. + while (buffer->IncreaseCapacity()); + EXPECT_FALSE(buffer->IncreaseCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize + 0, + buffer->max_buffer_size()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize + 0, + buffer->GetCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kDefaultMaxBufferSize - + static_cast<int>(kReadData.size()), + buffer->RemainingCapacity()); + EXPECT_EQ(static_cast<int>(kReadData.size()), buffer->GetSize()); + EXPECT_EQ(kReadData, + base::StringPiece(buffer->StartOfBuffer(), buffer->GetSize())); +} + +TEST(HttpConnectionTest, ReadIOBuffer_DidRead_DidConsume) { + scoped_refptr<HttpConnection::ReadIOBuffer> buffer( + new HttpConnection::ReadIOBuffer); + const char* start_of_buffer = buffer->StartOfBuffer(); + EXPECT_EQ(start_of_buffer, buffer->data()); + + // Read data. + const int kReadLength = 128; + const std::string kReadData(GetTestString(kReadLength)); + memcpy(buffer->data(), kReadData.data(), kReadLength); + buffer->DidRead(kReadLength); + // No change in total capacity. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize + 0, + buffer->GetCapacity()); + // Change in unused capacity because of read data. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize - kReadLength, + buffer->RemainingCapacity()); + EXPECT_EQ(kReadLength, buffer->GetSize()); + // No change in start pointers of read data. + EXPECT_EQ(start_of_buffer, buffer->StartOfBuffer()); + // Change in start pointer of unused buffer. + EXPECT_EQ(start_of_buffer + kReadLength, buffer->data()); + // Test read data. + EXPECT_EQ(kReadData, std::string(buffer->StartOfBuffer(), buffer->GetSize())); + + // Consume data partially. + const int kConsumedLength = 32; + ASSERT_LT(kConsumedLength, kReadLength); + buffer->DidConsume(kConsumedLength); + // Capacity reduced because read data was too small comparing to capacity. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor, + buffer->GetCapacity()); + // Change in unused capacity because of read data. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor - + kReadLength + kConsumedLength, + buffer->RemainingCapacity()); + // Change in read size. + EXPECT_EQ(kReadLength - kConsumedLength, buffer->GetSize()); + // Start data could be changed even when capacity is reduced. + start_of_buffer = buffer->StartOfBuffer(); + // Change in start pointer of unused buffer. + EXPECT_EQ(start_of_buffer + kReadLength - kConsumedLength, buffer->data()); + // Change in read data. + EXPECT_EQ(kReadData.substr(kConsumedLength), + std::string(buffer->StartOfBuffer(), buffer->GetSize())); + + // Read more data. + const int kReadLength2 = 64; + buffer->DidRead(kReadLength2); + // No change in total capacity. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor, + buffer->GetCapacity()); + // Change in unused capacity because of read data. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor - + kReadLength + kConsumedLength - kReadLength2, + buffer->RemainingCapacity()); + // Change in read size + EXPECT_EQ(kReadLength - kConsumedLength + kReadLength2, buffer->GetSize()); + // No change in start pointer of read part. + EXPECT_EQ(start_of_buffer, buffer->StartOfBuffer()); + // Change in start pointer of unused buffer. + EXPECT_EQ(start_of_buffer + kReadLength - kConsumedLength + kReadLength2, + buffer->data()); + + // Consume data fully. + buffer->DidConsume(kReadLength - kConsumedLength + kReadLength2); + // Capacity reduced again because read data was too small. + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor, + buffer->GetCapacity()); + EXPECT_EQ(HttpConnection::ReadIOBuffer::kInitialBufSize / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor / + HttpConnection::ReadIOBuffer::kCapacityIncreaseFactor, + buffer->RemainingCapacity()); + // All reverts to initial because no data is left. + EXPECT_EQ(0, buffer->GetSize()); + // Start data could be changed even when capacity is reduced. + start_of_buffer = buffer->StartOfBuffer(); + EXPECT_EQ(start_of_buffer, buffer->data()); +} + +TEST(HttpConnectionTest, QueuedWriteIOBuffer_Append_DidConsume) { + scoped_refptr<HttpConnection::QueuedWriteIOBuffer> buffer( + new HttpConnection::QueuedWriteIOBuffer()); + EXPECT_TRUE(buffer->IsEmpty()); + EXPECT_EQ(0, buffer->GetSizeToWrite()); + EXPECT_EQ(0, buffer->total_size()); + + const std::string kData("data to write"); + EXPECT_TRUE(buffer->Append(kData)); + EXPECT_FALSE(buffer->IsEmpty()); + EXPECT_EQ(static_cast<int>(kData.size()), buffer->GetSizeToWrite()); + EXPECT_EQ(static_cast<int>(kData.size()), buffer->total_size()); + // First data to write is same to kData. + EXPECT_EQ(kData, base::StringPiece(buffer->data(), buffer->GetSizeToWrite())); + + const std::string kData2("more data to write"); + EXPECT_TRUE(buffer->Append(kData2)); + EXPECT_FALSE(buffer->IsEmpty()); + // No change in size to write. + EXPECT_EQ(static_cast<int>(kData.size()), buffer->GetSizeToWrite()); + // Change in total size. + EXPECT_EQ(static_cast<int>(kData.size() + kData2.size()), + buffer->total_size()); + // First data to write has not been changed. Same to kData. + EXPECT_EQ(kData, base::StringPiece(buffer->data(), buffer->GetSizeToWrite())); + + // Consume data partially. + const int kConsumedLength = kData.length() - 1; + buffer->DidConsume(kConsumedLength); + EXPECT_FALSE(buffer->IsEmpty()); + // Change in size to write. + EXPECT_EQ(static_cast<int>(kData.size()) - kConsumedLength, + buffer->GetSizeToWrite()); + // Change in total size. + EXPECT_EQ(static_cast<int>(kData.size() + kData2.size()) - kConsumedLength, + buffer->total_size()); + // First data to write has shrinked. + EXPECT_EQ(kData.substr(kConsumedLength), + base::StringPiece(buffer->data(), buffer->GetSizeToWrite())); + + // Consume first data fully. + buffer->DidConsume(kData.size() - kConsumedLength); + EXPECT_FALSE(buffer->IsEmpty()); + // Now, size to write is size of data added second. + EXPECT_EQ(static_cast<int>(kData2.size()), buffer->GetSizeToWrite()); + // Change in total size. + EXPECT_EQ(static_cast<int>(kData2.size()), buffer->total_size()); + // First data to write has changed to kData2. + EXPECT_EQ(kData2, + base::StringPiece(buffer->data(), buffer->GetSizeToWrite())); + + // Consume second data fully. + buffer->DidConsume(kData2.size()); + EXPECT_TRUE(buffer->IsEmpty()); + EXPECT_EQ(0, buffer->GetSizeToWrite()); + EXPECT_EQ(0, buffer->total_size()); +} + +TEST(HttpConnectionTest, QueuedWriteIOBuffer_TotalSizeLimit) { + scoped_refptr<HttpConnection::QueuedWriteIOBuffer> buffer( + new HttpConnection::QueuedWriteIOBuffer()); + EXPECT_EQ(HttpConnection::QueuedWriteIOBuffer::kDefaultMaxBufferSize + 0, + buffer->max_buffer_size()); + + // Set total size limit very small. + buffer->set_max_buffer_size(10); + + const int kDataLength = 4; + const std::string kData(kDataLength, 'd'); + EXPECT_TRUE(buffer->Append(kData)); + EXPECT_EQ(kDataLength, buffer->total_size()); + EXPECT_TRUE(buffer->Append(kData)); + EXPECT_EQ(kDataLength * 2, buffer->total_size()); + + // Cannot append more data because it exceeds the limit. + EXPECT_FALSE(buffer->Append(kData)); + EXPECT_EQ(kDataLength * 2, buffer->total_size()); + + // Consume data partially. + const int kConsumedLength = 2; + buffer->DidConsume(kConsumedLength); + EXPECT_EQ(kDataLength * 2 - kConsumedLength, buffer->total_size()); + + // Can add more data. + EXPECT_TRUE(buffer->Append(kData)); + EXPECT_EQ(kDataLength * 3 - kConsumedLength, buffer->total_size()); + + // Cannot append more data because it exceeds the limit. + EXPECT_FALSE(buffer->Append(kData)); + EXPECT_EQ(kDataLength * 3 - kConsumedLength, buffer->total_size()); + + // Enlarge limit. + buffer->set_max_buffer_size(20); + // Can add more data. + EXPECT_TRUE(buffer->Append(kData)); + EXPECT_EQ(kDataLength * 4 - kConsumedLength, buffer->total_size()); +} + +} // namespace +} // namespace net diff --git a/net/server/http_server.cc b/net/server/http_server.cc index 043e625..fb0dab3 100644 --- a/net/server/http_server.cc +++ b/net/server/http_server.cc @@ -17,14 +17,25 @@ #include "net/server/http_server_request_info.h" #include "net/server/http_server_response_info.h" #include "net/server/web_socket.h" -#include "net/socket/tcp_listen_socket.h" +#include "net/socket/server_socket.h" +#include "net/socket/stream_socket.h" +#include "net/socket/tcp_server_socket.h" namespace net { -HttpServer::HttpServer(const StreamListenSocketFactory& factory, +HttpServer::HttpServer(scoped_ptr<ServerSocket> server_socket, HttpServer::Delegate* delegate) - : delegate_(delegate), - server_(factory.CreateAndListen(this)) { + : server_socket_(server_socket.Pass()), + delegate_(delegate), + last_id_(0), + weak_ptr_factory_(this) { + DCHECK(server_socket_); + DoAcceptLoop(); +} + +HttpServer::~HttpServer() { + STLDeleteContainerPairSecondPointers( + id_to_connection_.begin(), id_to_connection_.end()); } void HttpServer::AcceptWebSocket( @@ -33,9 +44,8 @@ void HttpServer::AcceptWebSocket( HttpConnection* connection = FindConnection(connection_id); if (connection == NULL) return; - - DCHECK(connection->web_socket_.get()); - connection->web_socket_->Accept(request); + DCHECK(connection->web_socket()); + connection->web_socket()->Accept(request); } void HttpServer::SendOverWebSocket(int connection_id, @@ -43,23 +53,23 @@ void HttpServer::SendOverWebSocket(int connection_id, HttpConnection* connection = FindConnection(connection_id); if (connection == NULL) return; - DCHECK(connection->web_socket_.get()); - connection->web_socket_->Send(data); + DCHECK(connection->web_socket()); + connection->web_socket()->Send(data); } void HttpServer::SendRaw(int connection_id, const std::string& data) { HttpConnection* connection = FindConnection(connection_id); if (connection == NULL) return; - connection->Send(data); + + bool writing_in_progress = !connection->write_buf()->IsEmpty(); + if (connection->write_buf()->Append(data) && !writing_in_progress) + DoWriteLoop(connection); } void HttpServer::SendResponse(int connection_id, const HttpServerResponseInfo& response) { - HttpConnection* connection = FindConnection(connection_id); - if (connection == NULL) - return; - connection->Send(response); + SendRaw(connection_id, response.Serialize()); } void HttpServer::Send(int connection_id, @@ -67,8 +77,9 @@ void HttpServer::Send(int connection_id, const std::string& data, const std::string& content_type) { HttpServerResponseInfo response(status_code); - response.SetBody(data, content_type); + response.SetContentHeaders(data.size(), content_type); SendResponse(connection_id, response); + SendRaw(connection_id, data); } void HttpServer::Send200(int connection_id, @@ -90,108 +101,209 @@ void HttpServer::Close(int connection_id) { if (connection == NULL) return; - // Initiating close from server-side does not lead to the DidClose call. - // Do it manually here. - DidClose(connection->socket_.get()); + id_to_connection_.erase(connection_id); + delegate_->OnClose(connection_id); + + // The call stack might have callbacks which still have the pointer of + // connection. Instead of referencing connection with ID all the time, + // destroys the connection in next run loop to make sure any pending + // callbacks in the call stack return. + base::MessageLoopProxy::current()->DeleteSoon(FROM_HERE, connection); } int HttpServer::GetLocalAddress(IPEndPoint* address) { - if (!server_) - return ERR_SOCKET_NOT_CONNECTED; - return server_->GetLocalAddress(address); + return server_socket_->GetLocalAddress(address); +} + +void HttpServer::SetReceiveBufferSize(int connection_id, int32 size) { + HttpConnection* connection = FindConnection(connection_id); + DCHECK(connection); + connection->read_buf()->set_max_buffer_size(size); } -void HttpServer::DidAccept(StreamListenSocket* server, - scoped_ptr<StreamListenSocket> socket) { - HttpConnection* connection = new HttpConnection(this, socket.Pass()); +void HttpServer::SetSendBufferSize(int connection_id, int32 size) { + HttpConnection* connection = FindConnection(connection_id); + DCHECK(connection); + connection->write_buf()->set_max_buffer_size(size); +} + +void HttpServer::DoAcceptLoop() { + int rv; + do { + rv = server_socket_->Accept(&accepted_socket_, + base::Bind(&HttpServer::OnAcceptCompleted, + weak_ptr_factory_.GetWeakPtr())); + if (rv == ERR_IO_PENDING) + return; + rv = HandleAcceptResult(rv); + } while (rv == OK); +} + +void HttpServer::OnAcceptCompleted(int rv) { + if (HandleAcceptResult(rv) == OK) + DoAcceptLoop(); +} + +int HttpServer::HandleAcceptResult(int rv) { + if (rv < 0) { + LOG(ERROR) << "Accept error: rv=" << rv; + return rv; + } + + HttpConnection* connection = + new HttpConnection(++last_id_, accepted_socket_.Pass()); id_to_connection_[connection->id()] = connection; - // TODO(szym): Fix socket access. Make HttpConnection the Delegate. - socket_to_connection_[connection->socket_.get()] = connection; + DoReadLoop(connection); + return OK; } -void HttpServer::DidRead(StreamListenSocket* socket, - const char* data, - int len) { - HttpConnection* connection = FindConnection(socket); - DCHECK(connection != NULL); - if (connection == NULL) +void HttpServer::DoReadLoop(HttpConnection* connection) { + int rv; + do { + HttpConnection::ReadIOBuffer* read_buf = connection->read_buf(); + // Increases read buffer size if necessary. + if (read_buf->RemainingCapacity() == 0 && !read_buf->IncreaseCapacity()) { + Close(connection->id()); + return; + } + + rv = connection->socket()->Read( + read_buf, + read_buf->RemainingCapacity(), + base::Bind(&HttpServer::OnReadCompleted, + weak_ptr_factory_.GetWeakPtr(), connection->id())); + if (rv == ERR_IO_PENDING) + return; + rv = HandleReadResult(connection, rv); + } while (rv == OK); +} + +void HttpServer::OnReadCompleted(int connection_id, int rv) { + HttpConnection* connection = FindConnection(connection_id); + if (!connection) // It might be closed right before by write error. return; - connection->recv_data_.append(data, len); - while (connection->recv_data_.length()) { - if (connection->web_socket_.get()) { + if (HandleReadResult(connection, rv) == OK) + DoReadLoop(connection); +} + +int HttpServer::HandleReadResult(HttpConnection* connection, int rv) { + if (rv <= 0) { + Close(connection->id()); + return rv == 0 ? ERR_CONNECTION_CLOSED : rv; + } + + HttpConnection::ReadIOBuffer* read_buf = connection->read_buf(); + read_buf->DidRead(rv); + + // Handles http requests or websocket messages. + while (read_buf->GetSize() > 0) { + if (connection->web_socket()) { std::string message; - WebSocket::ParseResult result = connection->web_socket_->Read(&message); + WebSocket::ParseResult result = connection->web_socket()->Read(&message); if (result == WebSocket::FRAME_INCOMPLETE) break; if (result == WebSocket::FRAME_CLOSE || result == WebSocket::FRAME_ERROR) { Close(connection->id()); - break; + return ERR_CONNECTION_CLOSED; } delegate_->OnWebSocketMessage(connection->id(), message); + if (HasClosedConnection(connection)) + return ERR_CONNECTION_CLOSED; continue; } HttpServerRequestInfo request; size_t pos = 0; - if (!ParseHeaders(connection, &request, &pos)) + if (!ParseHeaders(read_buf->StartOfBuffer(), read_buf->GetSize(), + &request, &pos)) { break; + } // Sets peer address if exists. - socket->GetPeerAddress(&request.peer); + connection->socket()->GetPeerAddress(&request.peer); if (request.HasHeaderValue("connection", "upgrade")) { - connection->web_socket_.reset(WebSocket::CreateWebSocket(connection, - request, - &pos)); - - if (!connection->web_socket_.get()) // Not enough data was received. + scoped_ptr<WebSocket> websocket( + WebSocket::CreateWebSocket(this, connection, request, &pos)); + if (!websocket) // Not enough data was received. break; + connection->SetWebSocket(websocket.Pass()); + read_buf->DidConsume(pos); delegate_->OnWebSocketRequest(connection->id(), request); - connection->Shift(pos); + if (HasClosedConnection(connection)) + return ERR_CONNECTION_CLOSED; continue; } const char kContentLength[] = "content-length"; - if (request.headers.count(kContentLength)) { + if (request.headers.count(kContentLength) > 0) { size_t content_length = 0; const size_t kMaxBodySize = 100 << 20; if (!base::StringToSizeT(request.GetHeaderValue(kContentLength), &content_length) || content_length > kMaxBodySize) { - connection->Send(HttpServerResponseInfo::CreateFor500( - "request content-length too big or unknown: " + - request.GetHeaderValue(kContentLength))); - DidClose(socket); - break; + SendResponse(connection->id(), + HttpServerResponseInfo::CreateFor500( + "request content-length too big or unknown: " + + request.GetHeaderValue(kContentLength))); + Close(connection->id()); + return ERR_CONNECTION_CLOSED; } - if (connection->recv_data_.length() - pos < content_length) + if (read_buf->GetSize() - pos < content_length) break; // Not enough data was received yet. - request.data = connection->recv_data_.substr(pos, content_length); + request.data.assign(read_buf->StartOfBuffer() + pos, content_length); pos += content_length; } + read_buf->DidConsume(pos); delegate_->OnHttpRequest(connection->id(), request); - connection->Shift(pos); + if (HasClosedConnection(connection)) + return ERR_CONNECTION_CLOSED; } + + return OK; } -void HttpServer::DidClose(StreamListenSocket* socket) { - HttpConnection* connection = FindConnection(socket); - DCHECK(connection != NULL); - id_to_connection_.erase(connection->id()); - socket_to_connection_.erase(connection->socket_.get()); - delete connection; +void HttpServer::DoWriteLoop(HttpConnection* connection) { + int rv = OK; + HttpConnection::QueuedWriteIOBuffer* write_buf = connection->write_buf(); + while (rv == OK && write_buf->GetSizeToWrite() > 0) { + rv = connection->socket()->Write( + write_buf, + write_buf->GetSizeToWrite(), + base::Bind(&HttpServer::OnWriteCompleted, + weak_ptr_factory_.GetWeakPtr(), connection->id())); + if (rv == ERR_IO_PENDING || rv == OK) + return; + rv = HandleWriteResult(connection, rv); + } } -HttpServer::~HttpServer() { - STLDeleteContainerPairSecondPointers( - id_to_connection_.begin(), id_to_connection_.end()); +void HttpServer::OnWriteCompleted(int connection_id, int rv) { + HttpConnection* connection = FindConnection(connection_id); + if (!connection) // It might be closed right before by read error. + return; + + if (HandleWriteResult(connection, rv) == OK) + DoWriteLoop(connection); } +int HttpServer::HandleWriteResult(HttpConnection* connection, int rv) { + if (rv < 0) { + Close(connection->id()); + return rv; + } + + connection->write_buf()->DidConsume(rv); + return OK; +} + +namespace { + // // HTTP Request Parser // This HTTP request parser uses a simple state machine to quickly parse @@ -255,17 +367,19 @@ int charToInput(char ch) { return INPUT_DEFAULT; } -bool HttpServer::ParseHeaders(HttpConnection* connection, +} // namespace + +bool HttpServer::ParseHeaders(const char* data, + size_t data_len, HttpServerRequestInfo* info, size_t* ppos) { size_t& pos = *ppos; - size_t data_len = connection->recv_data_.length(); int state = ST_METHOD; std::string buffer; std::string header_name; std::string header_value; while (pos < data_len) { - char ch = connection->recv_data_[pos++]; + char ch = data[pos++]; int input = charToInput(ch); int next_state = parser_state[state][input]; @@ -337,11 +451,12 @@ HttpConnection* HttpServer::FindConnection(int connection_id) { return it->second; } -HttpConnection* HttpServer::FindConnection(StreamListenSocket* socket) { - SocketToConnectionMap::iterator it = socket_to_connection_.find(socket); - if (it == socket_to_connection_.end()) - return NULL; - return it->second; +// This is called after any delegate callbacks are called to check if Close() +// has been called during callback processing. Using the pointer of connection, +// |connection| is safe here because Close() deletes the connection in next run +// loop. +bool HttpServer::HasClosedConnection(HttpConnection* connection) { + return FindConnection(connection->id()) != connection; } } // namespace net diff --git a/net/server/http_server.h b/net/server/http_server.h index 4309d122..2ae698b 100644 --- a/net/server/http_server.h +++ b/net/server/http_server.h @@ -5,13 +5,14 @@ #ifndef NET_SERVER_HTTP_SERVER_H_ #define NET_SERVER_HTTP_SERVER_H_ -#include <list> #include <map> +#include <string> #include "base/basictypes.h" +#include "base/macros.h" #include "base/memory/scoped_ptr.h" +#include "base/memory/weak_ptr.h" #include "net/http/http_status_code.h" -#include "net/socket/stream_listen_socket.h" namespace net { @@ -19,30 +20,28 @@ class HttpConnection; class HttpServerRequestInfo; class HttpServerResponseInfo; class IPEndPoint; +class ServerSocket; +class StreamSocket; class WebSocket; -class HttpServer : public StreamListenSocket::Delegate, - public base::RefCountedThreadSafe<HttpServer> { +class HttpServer { public: + // Delegate to handle http/websocket events. Beware that it is not safe to + // destroy the HttpServer in any of these callbacks. class Delegate { public: virtual void OnHttpRequest(int connection_id, const HttpServerRequestInfo& info) = 0; - virtual void OnWebSocketRequest(int connection_id, const HttpServerRequestInfo& info) = 0; - virtual void OnWebSocketMessage(int connection_id, const std::string& data) = 0; - virtual void OnClose(int connection_id) = 0; - - protected: - virtual ~Delegate() {} }; - HttpServer(const StreamListenSocketFactory& socket_factory, + HttpServer(scoped_ptr<ServerSocket> server_socket, HttpServer::Delegate* delegate); + ~HttpServer(); void AcceptWebSocket(int connection_id, const HttpServerRequestInfo& request); @@ -51,6 +50,7 @@ class HttpServer : public StreamListenSocket::Delegate, // performed that data constitutes a valid HTTP response. A valid HTTP // response may be split across multiple calls to SendRaw. void SendRaw(int connection_id, const std::string& data); + // TODO(byungchul): Consider replacing function name with SendResponseInfo void SendResponse(int connection_id, const HttpServerResponseInfo& response); void Send(int connection_id, HttpStatusCode status_code, @@ -64,40 +64,50 @@ class HttpServer : public StreamListenSocket::Delegate, void Close(int connection_id); + void SetReceiveBufferSize(int connection_id, int32 size); + void SetSendBufferSize(int connection_id, int32 size); + // Copies the local address to |address|. Returns a network error code. int GetLocalAddress(IPEndPoint* address); - // ListenSocketDelegate - virtual void DidAccept(StreamListenSocket* server, - scoped_ptr<StreamListenSocket> socket) OVERRIDE; - virtual void DidRead(StreamListenSocket* socket, - const char* data, - int len) OVERRIDE; - virtual void DidClose(StreamListenSocket* socket) OVERRIDE; + private: + friend class HttpServerTest; + + typedef std::map<int, HttpConnection*> IdToConnectionMap; - protected: - virtual ~HttpServer(); + void DoAcceptLoop(); + void OnAcceptCompleted(int rv); + int HandleAcceptResult(int rv); - private: - friend class base::RefCountedThreadSafe<HttpServer>; - friend class HttpConnection; + void DoReadLoop(HttpConnection* connection); + void OnReadCompleted(int connection_id, int rv); + int HandleReadResult(HttpConnection* connection, int rv); + + void DoWriteLoop(HttpConnection* connection); + void OnWriteCompleted(int connection_id, int rv); + int HandleWriteResult(HttpConnection* connection, int rv); // Expects the raw data to be stored in recv_data_. If parsing is successful, // will remove the data parsed from recv_data_, leaving only the unused // recv data. - bool ParseHeaders(HttpConnection* connection, + bool ParseHeaders(const char* data, + size_t data_len, HttpServerRequestInfo* info, size_t* pos); HttpConnection* FindConnection(int connection_id); - HttpConnection* FindConnection(StreamListenSocket* socket); - HttpServer::Delegate* delegate_; - scoped_ptr<StreamListenSocket> server_; - typedef std::map<int, HttpConnection*> IdToConnectionMap; + // Whether or not Close() has been called during delegate callback processing. + bool HasClosedConnection(HttpConnection* connection); + + const scoped_ptr<ServerSocket> server_socket_; + scoped_ptr<StreamSocket> accepted_socket_; + HttpServer::Delegate* const delegate_; + + int last_id_; IdToConnectionMap id_to_connection_; - typedef std::map<StreamListenSocket*, HttpConnection*> SocketToConnectionMap; - SocketToConnectionMap socket_to_connection_; + + base::WeakPtrFactory<HttpServer> weak_ptr_factory_; DISALLOW_COPY_AND_ASSIGN(HttpServer); }; diff --git a/net/server/http_server_response_info.cc b/net/server/http_server_response_info.cc index e4c6043..2d0a32e 100644 --- a/net/server/http_server_response_info.cc +++ b/net/server/http_server_response_info.cc @@ -41,8 +41,14 @@ void HttpServerResponseInfo::SetBody(const std::string& body, const std::string& content_type) { DCHECK(body_.empty()); body_ = body; + SetContentHeaders(body.length(), content_type); +} + +void HttpServerResponseInfo::SetContentHeaders( + size_t content_length, + const std::string& content_type) { AddHeader(HttpRequestHeaders::kContentLength, - base::StringPrintf("%" PRIuS, body.length())); + base::StringPrintf("%" PRIuS, content_length)); AddHeader(HttpRequestHeaders::kContentType, content_type); } diff --git a/net/server/http_server_response_info.h b/net/server/http_server_response_info.h index d6cedaa..bbb76d8 100644 --- a/net/server/http_server_response_info.h +++ b/net/server/http_server_response_info.h @@ -27,6 +27,9 @@ class HttpServerResponseInfo { // This also adds an appropriate Content-Length header. void SetBody(const std::string& body, const std::string& content_type); + // Sets content-length and content-type. Body should be sent separately. + void SetContentHeaders(size_t content_length, + const std::string& content_type); std::string Serialize() const; diff --git a/net/server/http_server_unittest.cc b/net/server/http_server_unittest.cc index 467bde4..4b67040 100644 --- a/net/server/http_server_unittest.cc +++ b/net/server/http_server_unittest.cc @@ -2,11 +2,13 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <algorithm> #include <utility> #include <vector> #include "base/bind.h" #include "base/bind_helpers.h" +#include "base/callback_helpers.h" #include "base/compiler_specific.h" #include "base/format_macros.h" #include "base/memory/ref_counted.h" @@ -24,11 +26,12 @@ #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/net_log.h" +#include "net/base/net_util.h" #include "net/base/test_completion_callback.h" #include "net/server/http_server.h" #include "net/server/http_server_request_info.h" #include "net/socket/tcp_client_socket.h" -#include "net/socket/tcp_listen_socket.h" +#include "net/socket/tcp_server_socket.h" #include "net/url_request/url_fetcher.h" #include "net/url_request/url_fetcher_delegate.h" #include "net/url_request/url_request_context.h" @@ -155,8 +158,10 @@ class HttpServerTest : public testing::Test, HttpServerTest() : quit_after_request_count_(0) {} virtual void SetUp() OVERRIDE { - TCPListenSocketFactory socket_factory("127.0.0.1", 0); - server_ = new HttpServer(socket_factory, this); + scoped_ptr<ServerSocket> server_socket( + new TCPServerSocket(NULL, net::NetLog::Source())); + server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1); + server_.reset(new HttpServer(server_socket.Pass(), this)); ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_)); } @@ -199,8 +204,13 @@ class HttpServerTest : public testing::Test, return requests_[request_index].second; } + void HandleAcceptResult(scoped_ptr<StreamSocket> socket) { + server_->accepted_socket_.reset(socket.release()); + server_->HandleAcceptResult(OK); + } + protected: - scoped_refptr<HttpServer> server_; + scoped_ptr<HttpServer> server_; IPEndPoint server_address_; base::Closure run_loop_quit_func_; std::vector<std::pair<HttpServerRequestInfo, int> > requests_; @@ -429,23 +439,105 @@ TEST_F(HttpServerTest, SendRaw) { namespace { -class MockStreamListenSocket : public StreamListenSocket { +class MockStreamSocket : public StreamSocket { public: - MockStreamListenSocket(StreamListenSocket::Delegate* delegate) - : StreamListenSocket(kInvalidSocket, delegate) {} + MockStreamSocket() + : connected_(true), + read_buf_(NULL), + read_buf_len_(0) {} + + // StreamSocket + virtual int Connect(const CompletionCallback& callback) OVERRIDE { + return ERR_NOT_IMPLEMENTED; + } + virtual void Disconnect() OVERRIDE { + connected_ = false; + if (!read_callback_.is_null()) { + read_buf_ = NULL; + read_buf_len_ = 0; + base::ResetAndReturn(&read_callback_).Run(ERR_CONNECTION_CLOSED); + } + } + virtual bool IsConnected() const OVERRIDE { return connected_; } + virtual bool IsConnectedAndIdle() const OVERRIDE { return IsConnected(); } + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + return ERR_NOT_IMPLEMENTED; + } + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + return ERR_NOT_IMPLEMENTED; + } + virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; } + virtual void SetSubresourceSpeculation() OVERRIDE {} + virtual void SetOmniboxSpeculation() OVERRIDE {} + virtual bool WasEverUsed() const OVERRIDE { return true; } + virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } + virtual bool WasNpnNegotiated() const OVERRIDE { return false; } + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { + return kProtoUnknown; + } + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; } - virtual void Accept() OVERRIDE { NOTREACHED(); } + // Socket + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE { + if (!connected_) { + return ERR_SOCKET_NOT_CONNECTED; + } + if (pending_read_data_.empty()) { + read_buf_ = buf; + read_buf_len_ = buf_len; + read_callback_ = callback; + return ERR_IO_PENDING; + } + DCHECK_GT(buf_len, 0); + int read_len = std::min(static_cast<int>(pending_read_data_.size()), + buf_len); + memcpy(buf->data(), pending_read_data_.data(), read_len); + pending_read_data_.erase(0, read_len); + return read_len; + } + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_NOT_IMPLEMENTED; + } + virtual int SetReceiveBufferSize(int32 size) OVERRIDE { + return ERR_NOT_IMPLEMENTED; + } + virtual int SetSendBufferSize(int32 size) OVERRIDE { + return ERR_NOT_IMPLEMENTED; + } + + void DidRead(const char* data, int data_len) { + if (!read_buf_) { + pending_read_data_.append(data, data_len); + return; + } + int read_len = std::min(data_len, read_buf_len_); + memcpy(read_buf_->data(), data, read_len); + pending_read_data_.assign(data + read_len, data_len - read_len); + read_buf_ = NULL; + read_buf_len_ = 0; + base::ResetAndReturn(&read_callback_).Run(read_len); + } private: - virtual ~MockStreamListenSocket() {} + virtual ~MockStreamSocket() {} + + bool connected_; + scoped_refptr<IOBuffer> read_buf_; + int read_buf_len_; + CompletionCallback read_callback_; + std::string pending_read_data_; + BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(MockStreamSocket); }; } // namespace TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) { - StreamListenSocket* socket = - new MockStreamListenSocket(server_.get()); - server_->DidAccept(NULL, make_scoped_ptr(socket)); + MockStreamSocket* socket = new MockStreamSocket(); + HandleAcceptResult(make_scoped_ptr<StreamSocket>(socket)); std::string body("body"); std::string request_text = base::StringPrintf( "GET /test HTTP/1.1\r\n" @@ -453,9 +545,9 @@ TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) { "Content-Length: %" PRIuS "\r\n\r\n%s", body.length(), body.c_str()); - server_->DidRead(socket, request_text.c_str(), request_text.length() - 2); + socket->DidRead(request_text.c_str(), request_text.length() - 2); ASSERT_EQ(0u, requests_.size()); - server_->DidRead(socket, request_text.c_str() + request_text.length() - 2, 2); + socket->DidRead(request_text.c_str() + request_text.length() - 2, 2); ASSERT_EQ(1u, requests_.size()); ASSERT_EQ(body, GetRequest(0).data); } diff --git a/net/server/web_socket.cc b/net/server/web_socket.cc index f06b425..ec0fdac 100644 --- a/net/server/web_socket.cc +++ b/net/server/web_socket.cc @@ -15,6 +15,7 @@ #include "base/strings/stringprintf.h" #include "base/sys_byteorder.h" #include "net/server/http_connection.h" +#include "net/server/http_server.h" #include "net/server/http_server_request_info.h" #include "net/server/http_server_response_info.h" @@ -43,12 +44,14 @@ static uint32 WebSocketKeyFingerprint(const std::string& str) { class WebSocketHixie76 : public net::WebSocket { public: - static net::WebSocket* Create(HttpConnection* connection, + static net::WebSocket* Create(HttpServer* server, + HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos) { - if (connection->recv_data().length() < *pos + kWebSocketHandshakeBodyLen) + if (connection->read_buf()->GetSize() < + static_cast<int>(*pos + kWebSocketHandshakeBodyLen)) return NULL; - return new WebSocketHixie76(connection, request, pos); + return new WebSocketHixie76(server, connection, request, pos); } virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE { @@ -69,31 +72,33 @@ class WebSocketHixie76 : public net::WebSocket { std::string origin = request.GetHeaderValue("origin"); std::string host = request.GetHeaderValue("host"); std::string location = "ws://" + host + request.path; - connection_->Send(base::StringPrintf( - "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Origin: %s\r\n" - "Sec-WebSocket-Location: %s\r\n" - "\r\n", - origin.c_str(), - location.c_str())); - connection_->Send(reinterpret_cast<char*>(digest.a), 16); + server_->SendRaw( + connection_->id(), + base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Origin: %s\r\n" + "Sec-WebSocket-Location: %s\r\n" + "\r\n", + origin.c_str(), + location.c_str())); + server_->SendRaw(connection_->id(), + std::string(reinterpret_cast<char*>(digest.a), 16)); } virtual ParseResult Read(std::string* message) OVERRIDE { DCHECK(message); - const std::string& data = connection_->recv_data(); - if (data[0]) + HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); + if (read_buf->StartOfBuffer()[0]) return FRAME_ERROR; + base::StringPiece data(read_buf->StartOfBuffer(), read_buf->GetSize()); size_t pos = data.find('\377', 1); - if (pos == std::string::npos) + if (pos == base::StringPiece::npos) return FRAME_INCOMPLETE; - std::string buffer(data.begin() + 1, data.begin() + pos); - message->swap(buffer); - connection_->Shift(pos + 1); + message->assign(data.data() + 1, pos - 1); + read_buf->DidConsume(pos + 1); return FRAME_OK; } @@ -101,37 +106,42 @@ class WebSocketHixie76 : public net::WebSocket { virtual void Send(const std::string& message) OVERRIDE { char message_start = 0; char message_end = -1; - connection_->Send(&message_start, 1); - connection_->Send(message); - connection_->Send(&message_end, 1); + server_->SendRaw(connection_->id(), std::string(1, message_start)); + server_->SendRaw(connection_->id(), message); + server_->SendRaw(connection_->id(), std::string(1, message_end)); } private: static const int kWebSocketHandshakeBodyLen; - WebSocketHixie76(HttpConnection* connection, + WebSocketHixie76(HttpServer* server, + HttpConnection* connection, const HttpServerRequestInfo& request, - size_t* pos) : WebSocket(connection) { + size_t* pos) + : WebSocket(server, connection) { std::string key1 = request.GetHeaderValue("sec-websocket-key1"); std::string key2 = request.GetHeaderValue("sec-websocket-key2"); if (key1.empty()) { - connection->Send(HttpServerResponseInfo::CreateFor500( - "Invalid request format. Sec-WebSocket-Key1 is empty or isn't " - "specified.")); + server->SendResponse( + connection->id(), + HttpServerResponseInfo::CreateFor500( + "Invalid request format. Sec-WebSocket-Key1 is empty or isn't " + "specified.")); return; } if (key2.empty()) { - connection->Send(HttpServerResponseInfo::CreateFor500( - "Invalid request format. Sec-WebSocket-Key2 is empty or isn't " - "specified.")); + server->SendResponse( + connection->id(), + HttpServerResponseInfo::CreateFor500( + "Invalid request format. Sec-WebSocket-Key2 is empty or isn't " + "specified.")); return; } - key3_ = connection->recv_data().substr( - *pos, - *pos + kWebSocketHandshakeBodyLen); + key3_.assign(connection->read_buf()->StartOfBuffer() + *pos, + kWebSocketHandshakeBodyLen); *pos += kWebSocketHandshakeBodyLen; } @@ -169,7 +179,8 @@ const size_t kMaskingKeyWidthInBytes = 4; class WebSocketHybi17 : public WebSocket { public: - static WebSocket* Create(HttpConnection* connection, + static WebSocket* Create(HttpServer* server, + HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos) { std::string version = request.GetHeaderValue("sec-websocket-version"); @@ -178,12 +189,14 @@ class WebSocketHybi17 : public WebSocket { std::string key = request.GetHeaderValue("sec-websocket-key"); if (key.empty()) { - connection->Send(HttpServerResponseInfo::CreateFor500( - "Invalid request format. Sec-WebSocket-Key is empty or isn't " - "specified.")); + server->SendResponse( + connection->id(), + HttpServerResponseInfo::CreateFor500( + "Invalid request format. Sec-WebSocket-Key is empty or isn't " + "specified.")); return NULL; } - return new WebSocketHybi17(connection, request, pos); + return new WebSocketHybi17(server, connection, request, pos); } virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE { @@ -194,24 +207,24 @@ class WebSocketHybi17 : public WebSocket { std::string encoded_hash; base::Base64Encode(base::SHA1HashString(data), &encoded_hash); - std::string response = base::StringPrintf( - "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept: %s\r\n" - "\r\n", - encoded_hash.c_str()); - connection_->Send(response); + server_->SendRaw( + connection_->id(), + base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: %s\r\n" + "\r\n", + encoded_hash.c_str())); } virtual ParseResult Read(std::string* message) OVERRIDE { - const std::string& frame = connection_->recv_data(); + HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); + base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); int bytes_consumed = 0; - ParseResult result = WebSocket::DecodeFrameHybi17(frame, true, &bytes_consumed, message); if (result == FRAME_OK) - connection_->Shift(bytes_consumed); + read_buf->DidConsume(bytes_consumed); if (result == FRAME_CLOSE) closed_ = true; return result; @@ -220,25 +233,26 @@ class WebSocketHybi17 : public WebSocket { virtual void Send(const std::string& message) OVERRIDE { if (closed_) return; - std::string data = WebSocket::EncodeFrameHybi17(message, 0); - connection_->Send(data); + server_->SendRaw(connection_->id(), + WebSocket::EncodeFrameHybi17(message, 0)); } private: - WebSocketHybi17(HttpConnection* connection, + WebSocketHybi17(HttpServer* server, + HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos) - : WebSocket(connection), - op_code_(0), - final_(false), - reserved1_(false), - reserved2_(false), - reserved3_(false), - masked_(false), - payload_(0), - payload_length_(0), - frame_end_(0), - closed_(false) { + : WebSocket(server, connection), + op_code_(0), + final_(false), + reserved1_(false), + reserved2_(false), + reserved3_(false), + masked_(false), + payload_(0), + payload_length_(0), + frame_end_(0), + closed_(false) { } OpCode op_code_; @@ -257,21 +271,23 @@ class WebSocketHybi17 : public WebSocket { } // anonymous namespace -WebSocket* WebSocket::CreateWebSocket(HttpConnection* connection, +WebSocket* WebSocket::CreateWebSocket(HttpServer* server, + HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos) { - WebSocket* socket = WebSocketHybi17::Create(connection, request, pos); + WebSocket* socket = WebSocketHybi17::Create(server, connection, request, pos); if (socket) return socket; - return WebSocketHixie76::Create(connection, request, pos); + return WebSocketHixie76::Create(server, connection, request, pos); } // static -WebSocket::ParseResult WebSocket::DecodeFrameHybi17(const std::string& frame, - bool client_frame, - int* bytes_consumed, - std::string* output) { +WebSocket::ParseResult WebSocket::DecodeFrameHybi17( + const base::StringPiece& frame, + bool client_frame, + int* bytes_consumed, + std::string* output) { size_t data_length = frame.length(); if (data_length < 2) return FRAME_INCOMPLETE; @@ -349,8 +365,7 @@ WebSocket::ParseResult WebSocket::DecodeFrameHybi17(const std::string& frame, for (size_t i = 0; i < payload_length; ++i) // Unmask the payload. (*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes]; } else { - std::string buffer(p, p + payload_length); - output->swap(buffer); + output->assign(p, p + payload_length); } size_t pos = p + actual_masking_key_length + payload_length - buffer_begin; @@ -400,7 +415,9 @@ std::string WebSocket::EncodeFrameHybi17(const std::string& message, return std::string(&frame[0], frame.size()); } -WebSocket::WebSocket(HttpConnection* connection) : connection_(connection) { +WebSocket::WebSocket(HttpServer* server, HttpConnection* connection) + : server_(server), + connection_(connection) { } } // namespace net diff --git a/net/server/web_socket.h b/net/server/web_socket.h index 49ced84..9b3a794 100644 --- a/net/server/web_socket.h +++ b/net/server/web_socket.h @@ -8,10 +8,12 @@ #include <string> #include "base/basictypes.h" +#include "base/strings/string_piece.h" namespace net { class HttpConnection; +class HttpServer; class HttpServerRequestInfo; class WebSocket { @@ -23,11 +25,12 @@ class WebSocket { FRAME_ERROR }; - static WebSocket* CreateWebSocket(HttpConnection* connection, + static WebSocket* CreateWebSocket(HttpServer* server, + HttpConnection* connection, const HttpServerRequestInfo& request, size_t* pos); - static ParseResult DecodeFrameHybi17(const std::string& frame, + static ParseResult DecodeFrameHybi17(const base::StringPiece& frame, bool client_frame, int* bytes_consumed, std::string* output); @@ -41,8 +44,10 @@ class WebSocket { virtual ~WebSocket() {} protected: - explicit WebSocket(HttpConnection* connection); - HttpConnection* connection_; + WebSocket(HttpServer* server, HttpConnection* connection); + + HttpServer* const server_; + HttpConnection* const connection_; private: DISALLOW_COPY_AND_ASSIGN(WebSocket); diff --git a/net/socket/server_socket.h b/net/socket/server_socket.h index 528955b..4b9ca8e 100644 --- a/net/socket/server_socket.h +++ b/net/socket/server_socket.h @@ -21,7 +21,7 @@ class NET_EXPORT ServerSocket { ServerSocket(); virtual ~ServerSocket(); - // Binds the socket and starts listening. Destroy the socket to stop + // Binds the socket and starts listening. Destroys the socket to stop // listening. virtual int Listen(const IPEndPoint& address, int backlog) = 0; |