// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include #include #include #include "base/bind.h" #include "base/bind_helpers.h" #include "base/callback_helpers.h" #include "base/compiler_specific.h" #include "base/format_macros.h" #include "base/location.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/memory/weak_ptr.h" #include "base/run_loop.h" #include "base/single_thread_task_runner.h" #include "base/strings/string_split.h" #include "base/strings/string_util.h" #include "base/strings/stringprintf.h" #include "base/thread_task_runner_handle.h" #include "base/time/time.h" #include "net/base/address_list.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" #include "net/base/test_completion_callback.h" #include "net/http/http_response_headers.h" #include "net/http/http_util.h" #include "net/log/net_log.h" #include "net/server/http_server.h" #include "net/server/http_server_request_info.h" #include "net/socket/tcp_client_socket.h" #include "net/socket/tcp_server_socket.h" #include "net/url_request/url_fetcher.h" #include "net/url_request/url_fetcher_delegate.h" #include "net/url_request/url_request_context.h" #include "net/url_request/url_request_context_getter.h" #include "net/url_request/url_request_test_util.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { namespace { const int kMaxExpectedResponseLength = 2048; void SetTimedOutAndQuitLoop(const base::WeakPtr timed_out, const base::Closure& quit_loop_func) { if (timed_out) { *timed_out = true; quit_loop_func.Run(); } } bool RunLoopWithTimeout(base::RunLoop* run_loop) { bool timed_out = false; base::WeakPtrFactory timed_out_weak_factory(&timed_out); base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( FROM_HERE, base::Bind(&SetTimedOutAndQuitLoop, timed_out_weak_factory.GetWeakPtr(), run_loop->QuitClosure()), base::TimeDelta::FromSeconds(1)); run_loop->Run(); return !timed_out; } class TestHttpClient { public: TestHttpClient() : connect_result_(OK) {} int ConnectAndWait(const IPEndPoint& address) { AddressList addresses(address); NetLog::Source source; socket_.reset(new TCPClientSocket(addresses, NULL, source)); base::RunLoop run_loop; connect_result_ = socket_->Connect(base::Bind(&TestHttpClient::OnConnect, base::Unretained(this), run_loop.QuitClosure())); if (connect_result_ != OK && connect_result_ != ERR_IO_PENDING) return connect_result_; if (!RunLoopWithTimeout(&run_loop)) return ERR_TIMED_OUT; return connect_result_; } void Send(const std::string& data) { write_buffer_ = new DrainableIOBuffer(new StringIOBuffer(data), data.length()); Write(); } bool Read(std::string* message, int expected_bytes) { int total_bytes_received = 0; message->clear(); while (total_bytes_received < expected_bytes) { TestCompletionCallback callback; ReadInternal(callback.callback()); int bytes_received = callback.WaitForResult(); if (bytes_received <= 0) return false; total_bytes_received += bytes_received; message->append(read_buffer_->data(), bytes_received); } return true; } bool ReadResponse(std::string* message) { if (!Read(message, 1)) return false; while (!IsCompleteResponse(*message)) { std::string chunk; if (!Read(&chunk, 1)) return false; message->append(chunk); } return true; } private: void OnConnect(const base::Closure& quit_loop, int result) { connect_result_ = result; quit_loop.Run(); } void Write() { int result = socket_->Write( write_buffer_.get(), write_buffer_->BytesRemaining(), base::Bind(&TestHttpClient::OnWrite, base::Unretained(this))); if (result != ERR_IO_PENDING) OnWrite(result); } void OnWrite(int result) { ASSERT_GT(result, 0); write_buffer_->DidConsume(result); if (write_buffer_->BytesRemaining()) Write(); } void ReadInternal(const CompletionCallback& callback) { read_buffer_ = new IOBufferWithSize(kMaxExpectedResponseLength); int result = socket_->Read(read_buffer_.get(), kMaxExpectedResponseLength, callback); if (result != ERR_IO_PENDING) callback.Run(result); } bool IsCompleteResponse(const std::string& response) { // Check end of headers first. int end_of_headers = HttpUtil::LocateEndOfHeaders(response.data(), response.size()); if (end_of_headers < 0) return false; // Return true if response has data equal to or more than content length. int64 body_size = static_cast(response.size()) - end_of_headers; DCHECK_LE(0, body_size); scoped_refptr headers(new HttpResponseHeaders( HttpUtil::AssembleRawHeaders(response.data(), end_of_headers))); return body_size >= headers->GetContentLength(); } scoped_refptr read_buffer_; scoped_refptr write_buffer_; scoped_ptr socket_; int connect_result_; }; } // namespace class HttpServerTest : public testing::Test, public HttpServer::Delegate { public: HttpServerTest() : quit_after_request_count_(0) {} void SetUp() override { scoped_ptr server_socket( new TCPServerSocket(NULL, NetLog::Source())); server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1); server_.reset(new HttpServer(server_socket.Pass(), this)); ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_)); } void OnConnect(int connection_id) override {} void OnHttpRequest(int connection_id, const HttpServerRequestInfo& info) override { requests_.push_back(std::make_pair(info, connection_id)); if (requests_.size() == quit_after_request_count_) run_loop_quit_func_.Run(); } void OnWebSocketRequest(int connection_id, const HttpServerRequestInfo& info) override { NOTREACHED(); } void OnWebSocketMessage(int connection_id, const std::string& data) override { NOTREACHED(); } void OnClose(int connection_id) override {} bool RunUntilRequestsReceived(size_t count) { quit_after_request_count_ = count; if (requests_.size() == count) return true; base::RunLoop run_loop; run_loop_quit_func_ = run_loop.QuitClosure(); bool success = RunLoopWithTimeout(&run_loop); run_loop_quit_func_.Reset(); return success; } HttpServerRequestInfo GetRequest(size_t request_index) { return requests_[request_index].first; } int GetConnectionId(size_t request_index) { return requests_[request_index].second; } void HandleAcceptResult(scoped_ptr socket) { server_->accepted_socket_.reset(socket.release()); server_->HandleAcceptResult(OK); } protected: scoped_ptr server_; IPEndPoint server_address_; base::Closure run_loop_quit_func_; std::vector > requests_; private: size_t quit_after_request_count_; }; namespace { class WebSocketTest : public HttpServerTest { void OnHttpRequest(int connection_id, const HttpServerRequestInfo& info) override { NOTREACHED(); } void OnWebSocketRequest(int connection_id, const HttpServerRequestInfo& info) override { HttpServerTest::OnHttpRequest(connection_id, info); } void OnWebSocketMessage(int connection_id, const std::string& data) override { } }; TEST_F(HttpServerTest, Request) { TestHttpClient client; ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); client.Send("GET /test HTTP/1.1\r\n\r\n"); ASSERT_TRUE(RunUntilRequestsReceived(1)); ASSERT_EQ("GET", GetRequest(0).method); ASSERT_EQ("/test", GetRequest(0).path); ASSERT_EQ("", GetRequest(0).data); ASSERT_EQ(0u, GetRequest(0).headers.size()); ASSERT_TRUE( base::StartsWithASCII(GetRequest(0).peer.ToString(), "127.0.0.1", true)); } TEST_F(HttpServerTest, RequestWithHeaders) { TestHttpClient client; ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); const char* const kHeaders[][3] = { {"Header", ": ", "1"}, {"HeaderWithNoWhitespace", ":", "1"}, {"HeaderWithWhitespace", " : \t ", "1 1 1 \t "}, {"HeaderWithColon", ": ", "1:1"}, {"EmptyHeader", ":", ""}, {"EmptyHeaderWithWhitespace", ": \t ", ""}, {"HeaderWithNonASCII", ": ", "\xf7"}, }; std::string headers; for (size_t i = 0; i < arraysize(kHeaders); ++i) { headers += std::string(kHeaders[i][0]) + kHeaders[i][1] + kHeaders[i][2] + "\r\n"; } client.Send("GET /test HTTP/1.1\r\n" + headers + "\r\n"); ASSERT_TRUE(RunUntilRequestsReceived(1)); ASSERT_EQ("", GetRequest(0).data); for (size_t i = 0; i < arraysize(kHeaders); ++i) { std::string field = base::StringToLowerASCII(std::string(kHeaders[i][0])); std::string value = kHeaders[i][2]; ASSERT_EQ(1u, GetRequest(0).headers.count(field)) << field; ASSERT_EQ(value, GetRequest(0).headers[field]) << kHeaders[i][0]; } } TEST_F(HttpServerTest, RequestWithDuplicateHeaders) { TestHttpClient client; ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); const char* const kHeaders[][3] = { {"FirstHeader", ": ", "1"}, {"DuplicateHeader", ": ", "2"}, {"MiddleHeader", ": ", "3"}, {"DuplicateHeader", ": ", "4"}, {"LastHeader", ": ", "5"}, }; std::string headers; for (size_t i = 0; i < arraysize(kHeaders); ++i) { headers += std::string(kHeaders[i][0]) + kHeaders[i][1] + kHeaders[i][2] + "\r\n"; } client.Send("GET /test HTTP/1.1\r\n" + headers + "\r\n"); ASSERT_TRUE(RunUntilRequestsReceived(1)); ASSERT_EQ("", GetRequest(0).data); for (size_t i = 0; i < arraysize(kHeaders); ++i) { std::string field = base::StringToLowerASCII(std::string(kHeaders[i][0])); std::string value = (field == "duplicateheader") ? "2,4" : kHeaders[i][2]; ASSERT_EQ(1u, GetRequest(0).headers.count(field)) << field; ASSERT_EQ(value, GetRequest(0).headers[field]) << kHeaders[i][0]; } } TEST_F(HttpServerTest, HasHeaderValueTest) { TestHttpClient client; ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); const char* const kHeaders[] = { "Header: Abcd", "HeaderWithNoWhitespace:E", "HeaderWithWhitespace : \t f \t ", "DuplicateHeader: g", "HeaderWithComma: h, i ,j", "DuplicateHeader: k", "EmptyHeader:", "EmptyHeaderWithWhitespace: \t ", "HeaderWithNonASCII: \xf7", }; std::string headers; for (size_t i = 0; i < arraysize(kHeaders); ++i) { headers += std::string(kHeaders[i]) + "\r\n"; } client.Send("GET /test HTTP/1.1\r\n" + headers + "\r\n"); ASSERT_TRUE(RunUntilRequestsReceived(1)); ASSERT_EQ("", GetRequest(0).data); ASSERT_TRUE(GetRequest(0).HasHeaderValue("header", "abcd")); ASSERT_FALSE(GetRequest(0).HasHeaderValue("header", "bc")); ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithnowhitespace", "e")); ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithwhitespace", "f")); ASSERT_TRUE(GetRequest(0).HasHeaderValue("duplicateheader", "g")); ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithcomma", "h")); ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithcomma", "i")); ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithcomma", "j")); ASSERT_TRUE(GetRequest(0).HasHeaderValue("duplicateheader", "k")); ASSERT_FALSE(GetRequest(0).HasHeaderValue("emptyheader", "x")); ASSERT_FALSE(GetRequest(0).HasHeaderValue("emptyheaderwithwhitespace", "x")); ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithnonascii", "\xf7")); } TEST_F(HttpServerTest, RequestWithBody) { TestHttpClient client; ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); std::string body = "a" + std::string(1 << 10, 'b') + "c"; client.Send(base::StringPrintf( "GET /test HTTP/1.1\r\n" "SomeHeader: 1\r\n" "Content-Length: %" PRIuS "\r\n\r\n%s", body.length(), body.c_str())); ASSERT_TRUE(RunUntilRequestsReceived(1)); ASSERT_EQ(2u, GetRequest(0).headers.size()); ASSERT_EQ(body.length(), GetRequest(0).data.length()); ASSERT_EQ('a', body[0]); ASSERT_EQ('c', *body.rbegin()); } TEST_F(WebSocketTest, RequestWebSocket) { TestHttpClient client; ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); client.Send( "GET /test HTTP/1.1\r\n" "Upgrade: WebSocket\r\n" "Connection: SomethingElse, Upgrade\r\n" "Sec-WebSocket-Version: 8\r\n" "Sec-WebSocket-Key: key\r\n" "\r\n"); ASSERT_TRUE(RunUntilRequestsReceived(1)); } TEST_F(HttpServerTest, RequestWithTooLargeBody) { class TestURLFetcherDelegate : public URLFetcherDelegate { public: TestURLFetcherDelegate(const base::Closure& quit_loop_func) : quit_loop_func_(quit_loop_func) {} ~TestURLFetcherDelegate() override {} void OnURLFetchComplete(const URLFetcher* source) override { EXPECT_EQ(HTTP_INTERNAL_SERVER_ERROR, source->GetResponseCode()); quit_loop_func_.Run(); } private: base::Closure quit_loop_func_; }; base::RunLoop run_loop; TestURLFetcherDelegate delegate(run_loop.QuitClosure()); scoped_refptr request_context_getter( new TestURLRequestContextGetter(base::ThreadTaskRunnerHandle::Get())); scoped_ptr fetcher = URLFetcher::Create(GURL(base::StringPrintf("http://127.0.0.1:%d/test", server_address_.port())), URLFetcher::GET, &delegate); fetcher->SetRequestContext(request_context_getter.get()); fetcher->AddExtraRequestHeader( base::StringPrintf("content-length:%d", 1 << 30)); fetcher->Start(); ASSERT_TRUE(RunLoopWithTimeout(&run_loop)); ASSERT_EQ(0u, requests_.size()); } TEST_F(HttpServerTest, Send200) { TestHttpClient client; ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); client.Send("GET /test HTTP/1.1\r\n\r\n"); ASSERT_TRUE(RunUntilRequestsReceived(1)); server_->Send200(GetConnectionId(0), "Response!", "text/plain"); std::string response; ASSERT_TRUE(client.ReadResponse(&response)); ASSERT_TRUE(base::StartsWithASCII(response, "HTTP/1.1 200 OK", true)); ASSERT_TRUE(base::EndsWith(response, "Response!", true)); } TEST_F(HttpServerTest, SendRaw) { TestHttpClient client; ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); client.Send("GET /test HTTP/1.1\r\n\r\n"); ASSERT_TRUE(RunUntilRequestsReceived(1)); server_->SendRaw(GetConnectionId(0), "Raw Data "); server_->SendRaw(GetConnectionId(0), "More Data"); server_->SendRaw(GetConnectionId(0), "Third Piece of Data"); const std::string expected_response("Raw Data More DataThird Piece of Data"); std::string response; ASSERT_TRUE(client.Read(&response, expected_response.length())); ASSERT_EQ(expected_response, response); } class MockStreamSocket : public StreamSocket { public: MockStreamSocket() : connected_(true), read_buf_(NULL), read_buf_len_(0) {} // StreamSocket int Connect(const CompletionCallback& callback) override { return ERR_NOT_IMPLEMENTED; } void Disconnect() override { connected_ = false; if (!read_callback_.is_null()) { read_buf_ = NULL; read_buf_len_ = 0; base::ResetAndReturn(&read_callback_).Run(ERR_CONNECTION_CLOSED); } } bool IsConnected() const override { return connected_; } bool IsConnectedAndIdle() const override { return IsConnected(); } int GetPeerAddress(IPEndPoint* address) const override { return ERR_NOT_IMPLEMENTED; } int GetLocalAddress(IPEndPoint* address) const override { return ERR_NOT_IMPLEMENTED; } const BoundNetLog& NetLog() const override { return net_log_; } void SetSubresourceSpeculation() override {} void SetOmniboxSpeculation() override {} bool WasEverUsed() const override { return true; } bool UsingTCPFastOpen() const override { return false; } bool WasNpnNegotiated() const override { return false; } NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } void GetConnectionAttempts(ConnectionAttempts* out) const override { out->clear(); } void ClearConnectionAttempts() override {} void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} // Socket int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) override { if (!connected_) { return ERR_SOCKET_NOT_CONNECTED; } if (pending_read_data_.empty()) { read_buf_ = buf; read_buf_len_ = buf_len; read_callback_ = callback; return ERR_IO_PENDING; } DCHECK_GT(buf_len, 0); int read_len = std::min(static_cast(pending_read_data_.size()), buf_len); memcpy(buf->data(), pending_read_data_.data(), read_len); pending_read_data_.erase(0, read_len); return read_len; } int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) override { return ERR_NOT_IMPLEMENTED; } int SetReceiveBufferSize(int32 size) override { return ERR_NOT_IMPLEMENTED; } int SetSendBufferSize(int32 size) override { return ERR_NOT_IMPLEMENTED; } void DidRead(const char* data, int data_len) { if (!read_buf_.get()) { pending_read_data_.append(data, data_len); return; } int read_len = std::min(data_len, read_buf_len_); memcpy(read_buf_->data(), data, read_len); pending_read_data_.assign(data + read_len, data_len - read_len); read_buf_ = NULL; read_buf_len_ = 0; base::ResetAndReturn(&read_callback_).Run(read_len); } private: ~MockStreamSocket() override {} bool connected_; scoped_refptr read_buf_; int read_buf_len_; CompletionCallback read_callback_; std::string pending_read_data_; BoundNetLog net_log_; DISALLOW_COPY_AND_ASSIGN(MockStreamSocket); }; TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) { MockStreamSocket* socket = new MockStreamSocket(); HandleAcceptResult(make_scoped_ptr(socket)); std::string body("body"); std::string request_text = base::StringPrintf( "GET /test HTTP/1.1\r\n" "SomeHeader: 1\r\n" "Content-Length: %" PRIuS "\r\n\r\n%s", body.length(), body.c_str()); socket->DidRead(request_text.c_str(), request_text.length() - 2); ASSERT_EQ(0u, requests_.size()); socket->DidRead(request_text.c_str() + request_text.length() - 2, 2); ASSERT_EQ(1u, requests_.size()); ASSERT_EQ(body, GetRequest(0).data); } TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) { // The idea behind this test is that requests with or without bodies should // not break parsing of the next request. TestHttpClient client; ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); std::string body = "body"; client.Send(base::StringPrintf( "GET /test HTTP/1.1\r\n" "Content-Length: %" PRIuS "\r\n\r\n%s", body.length(), body.c_str())); ASSERT_TRUE(RunUntilRequestsReceived(1)); ASSERT_EQ(body, GetRequest(0).data); int client_connection_id = GetConnectionId(0); server_->Send200(client_connection_id, "Content for /test", "text/plain"); std::string response1; ASSERT_TRUE(client.ReadResponse(&response1)); ASSERT_TRUE(base::StartsWithASCII(response1, "HTTP/1.1 200 OK", true)); ASSERT_TRUE(base::EndsWith(response1, "Content for /test", true)); client.Send("GET /test2 HTTP/1.1\r\n\r\n"); ASSERT_TRUE(RunUntilRequestsReceived(2)); ASSERT_EQ("/test2", GetRequest(1).path); ASSERT_EQ(client_connection_id, GetConnectionId(1)); server_->Send404(client_connection_id); std::string response2; ASSERT_TRUE(client.ReadResponse(&response2)); ASSERT_TRUE(base::StartsWithASCII(response2, "HTTP/1.1 404 Not Found", true)); client.Send("GET /test3 HTTP/1.1\r\n\r\n"); ASSERT_TRUE(RunUntilRequestsReceived(3)); ASSERT_EQ("/test3", GetRequest(2).path); ASSERT_EQ(client_connection_id, GetConnectionId(2)); server_->Send200(client_connection_id, "Content for /test3", "text/plain"); std::string response3; ASSERT_TRUE(client.ReadResponse(&response3)); ASSERT_TRUE(base::StartsWithASCII(response3, "HTTP/1.1 200 OK", true)); ASSERT_TRUE(base::EndsWith(response3, "Content for /test3", true)); } class CloseOnConnectHttpServerTest : public HttpServerTest { public: void OnConnect(int connection_id) override { connection_ids_.push_back(connection_id); server_->Close(connection_id); } protected: std::vector connection_ids_; }; TEST_F(CloseOnConnectHttpServerTest, ServerImmediatelyClosesConnection) { TestHttpClient client; ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); client.Send("GET / HTTP/1.1\r\n\r\n"); ASSERT_FALSE(RunUntilRequestsReceived(1)); ASSERT_EQ(1ul, connection_ids_.size()); ASSERT_EQ(0ul, requests_.size()); } } // namespace } // namespace net