// Copyright 2015 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <stddef.h>
#include <stdint.h>

#include <utility>

#include "base/logging.h"
#include "base/macros.h"
#include "base/memory/linked_ptr.h"
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "base/run_loop.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "mojo/common/data_pipe_utils.h"
#include "mojo/services/network/net_address_type_converters.h"
#include "mojo/services/network/public/cpp/web_socket_read_queue.h"
#include "mojo/services/network/public/cpp/web_socket_write_queue.h"
#include "mojo/services/network/public/interfaces/http_connection.mojom.h"
#include "mojo/services/network/public/interfaces/http_message.mojom.h"
#include "mojo/services/network/public/interfaces/http_server.mojom.h"
#include "mojo/services/network/public/interfaces/net_address.mojom.h"
#include "mojo/services/network/public/interfaces/network_service.mojom.h"
#include "mojo/services/network/public/interfaces/web_socket.mojom.h"
#include "mojo/services/network/public/interfaces/web_socket_factory.mojom.h"
#include "mojo/shell/public/cpp/shell_test.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/base/test_completion_callback.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"
#include "net/socket/tcp_client_socket.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace mojo {
namespace {

// Size, in bytes, of the buffer TestHttpClient allocates for each individual
// socket read. Despite the name it is not a hard limit on the response:
// TestHttpClient::ReadResponse() keeps issuing reads and appending chunks.
const int kMaxExpectedResponseLength = 2048;

// Builds an IPv4 NetAddress for 127.0.0.1 with the port field set to 0.
NetAddressPtr GetLocalHostWithAnyPort() {
  const uint8_t kLoopbackOctets[] = {127, 0, 0, 1};
  const size_t kOctetCount =
      sizeof(kLoopbackOctets) / sizeof(kLoopbackOctets[0]);

  NetAddressPtr address(NetAddress::New());
  address->family = NetAddressFamily::IPV4;
  address->ipv4 = NetAddressIPv4::New();
  address->ipv4->port = 0;

  // Copy the octets over in network (most significant first) order.
  address->ipv4->addr.resize(kOctetCount);
  for (size_t i = 0; i < kOctetCount; ++i)
    address->ipv4->addr[i] = kLoopbackOctets[i];

  return address;
}

// Ordered list of (header name, header value) pairs describing the headers of
// a test request or response.
using TestHeaders = std::vector<std::pair<std::string, std::string>>;

// Description of an HTTP request used by the tests: it is serialized to raw
// text by MakeRequestMessage() and compared against the request the server
// delegate receives by CheckRequest().
//
// The tests initialize this struct positionally with a braced list, so the
// member order is part of its contract.
struct TestRequest {
  // Request method exactly as it should appear on the request line.
  std::string method;
  // Request target exactly as it should appear on the request line.
  std::string url;
  // Header fields to send / expect.
  TestHeaders headers;
  // Request body; null means the request carries no body.
  scoped_ptr<std::string> body;
};

// Description of an HTTP response used by the tests: it is converted to a
// mojom HttpResponse by MakeResponseStruct() and compared against the raw text
// the client reads back by CheckResponse().
//
// The tests initialize this struct positionally with a braced list, so the
// member order is part of its contract.
struct TestResponse {
  // HTTP status code of the response.
  uint32_t status_code;
  // Header fields to send / expect.
  TestHeaders headers;
  // Response body; null means the response carries no body.
  scoped_ptr<std::string> body;
};

// Serializes |data| into the raw text of an HTTP/1.1 request: the request
// line, one "Name: value" line per header, a blank line, and then the body if
// |data| has one.
std::string MakeRequestMessage(const TestRequest& data) {
  // Request line.
  std::string message(data.method);
  message.append(" ");
  message.append(data.url);
  message.append(" HTTP/1.1\r\n");

  // Header fields, in the order given.
  for (const auto& header : data.headers) {
    message.append(header.first);
    message.append(": ");
    message.append(header.second);
    message.append("\r\n");
  }

  // Blank line terminating the header section.
  message.append("\r\n");

  if (data.body)
    message.append(*data.body);

  return message;
}

// Converts |data| into a mojom HttpResponse. When |data| has a body, the body
// bytes are written into a newly created data pipe and the pipe's consumer end
// becomes the response body; otherwise the response body is left unset.
HttpResponsePtr MakeResponseStruct(const TestResponse& data) {
  HttpResponsePtr response(HttpResponse::New());
  response->status_code = data.status_code;

  // Copy the headers over one by one, preserving their order.
  const size_t header_count = data.headers.size();
  response->headers.resize(header_count);
  for (size_t i = 0; i < header_count; ++i) {
    HttpHeaderPtr header(HttpHeader::New());
    header->name = data.headers[i].first;
    header->value = data.headers[i].second;
    response->headers[i] = std::move(header);
  }

  if (!data.body)
    return response;

  // The pipe capacity is set to the body size and the whole body is written
  // with a single all-or-none write.
  uint32_t body_size = static_cast<uint32_t>(data.body->size());
  MojoCreateDataPipeOptions pipe_options;
  pipe_options.struct_size = sizeof(MojoCreateDataPipeOptions);
  pipe_options.flags = MOJO_CREATE_DATA_PIPE_OPTIONS_FLAG_NONE;
  pipe_options.element_num_bytes = 1;
  pipe_options.capacity_num_bytes = body_size;
  DataPipe pipe(pipe_options);

  response->body = std::move(pipe.consumer_handle);
  MojoResult result =
      WriteDataRaw(pipe.producer_handle.get(), data.body->data(), &body_size,
                   MOJO_WRITE_DATA_FLAG_ALL_OR_NONE);
  EXPECT_EQ(MOJO_RESULT_OK, result);

  return response;
}

void CheckHeaders(const TestHeaders& expected,
                  const Array<HttpHeaderPtr>& headers) {
  // The server impl fiddles with Content-Length and Content-Type. So we don't
  // do a strict check here.
  std::map<std::string, std::string> header_map;
  for (size_t i = 0; i < headers.size(); ++i) {
    std::string lower_name =
        base::ToLowerASCII(headers[i]->name.To<std::string>());
    header_map[lower_name] = headers[i]->value;
  }

  for (const auto& item : expected) {
    std::string lower_name = base::ToLowerASCII(item.first);
    EXPECT_NE(header_map.end(), header_map.find(lower_name));
    EXPECT_EQ(item.second, header_map[lower_name]);
  }
}

void CheckRequest(const TestRequest& expected, HttpRequestPtr request) {
  EXPECT_EQ(expected.method, request->method);
  EXPECT_EQ(expected.url, request->url);
  CheckHeaders(expected.headers, request->headers);
  if (expected.body) {
    EXPECT_TRUE(request->body.is_valid());
    std::string body;
    common::BlockingCopyToString(std::move(request->body), &body);
    EXPECT_EQ(*expected.body, body);
  } else {
    EXPECT_FALSE(request->body.is_valid());
  }
}

void CheckResponse(const TestResponse& expected, const std::string& response) {
  int header_end =
      net::HttpUtil::LocateEndOfHeaders(response.c_str(), response.size());
  std::string assembled_headers =
      net::HttpUtil::AssembleRawHeaders(response.c_str(), header_end);
  scoped_refptr<net::HttpResponseHeaders> parsed_headers(
      new net::HttpResponseHeaders(assembled_headers));
  EXPECT_EQ(expected.status_code,
            static_cast<uint32_t>(parsed_headers->response_code()));
  for (const auto& item : expected.headers)
    EXPECT_TRUE(parsed_headers->HasHeaderValue(item.first, item.second));

  if (expected.body) {
    EXPECT_NE(-1, header_end);
    std::string body(response, static_cast<size_t>(header_end));
    EXPECT_EQ(*expected.body, body);
  } else {
    EXPECT_EQ(response.size(), static_cast<size_t>(header_end));
  }
}

class TestHttpClient {
 public:
  TestHttpClient() : connect_result_(net::OK) {}

  void Connect(const net::IPEndPoint& address) {
    net::AddressList addresses(address);
    net::NetLog::Source source;
    socket_.reset(new net::TCPClientSocket(addresses, NULL, source));

    base::RunLoop run_loop;
    connect_result_ = socket_->Connect(base::Bind(&TestHttpClient::OnConnect,
                                                  base::Unretained(this),
                                                  run_loop.QuitClosure()));
    if (connect_result_ == net::ERR_IO_PENDING)
      run_loop.Run();

    ASSERT_EQ(net::OK, connect_result_);
  }

  void Send(const std::string& data) {
    write_buffer_ = new net::DrainableIOBuffer(new net::StringIOBuffer(data),
                                               data.length());
    Write();
  }

  // Note: This method determines the end of the response only by Content-Length
  // and connection termination. Besides, it doesn't truncate at the end of the
  // response, so |message| may return more data (e.g., part of the next
  // response).
  void ReadResponse(std::string* message) {
    if (!Read(message, 1))
      return;
    while (!IsCompleteResponse(*message)) {
      std::string chunk;
      if (!Read(&chunk, 1))
        return;
      message->append(chunk);
    }
    return;
  }

 private:
  void OnConnect(const base::Closure& quit_loop, int result) {
    connect_result_ = result;
    quit_loop.Run();
  }

  void Write() {
    int result = socket_->Write(
        write_buffer_.get(), write_buffer_->BytesRemaining(),
        base::Bind(&TestHttpClient::OnWrite, base::Unretained(this)));
    if (result != net::ERR_IO_PENDING)
      OnWrite(result);
  }

  void OnWrite(int result) {
    ASSERT_GT(result, 0);
    write_buffer_->DidConsume(result);
    if (write_buffer_->BytesRemaining())
      Write();
  }

  bool Read(std::string* message, int expected_bytes) {
    int total_bytes_received = 0;
    message->clear();
    while (total_bytes_received < expected_bytes) {
      net::TestCompletionCallback callback;
      ReadInternal(callback.callback());
      int bytes_received = callback.WaitForResult();
      if (bytes_received <= 0)
        return false;

      total_bytes_received += bytes_received;
      message->append(read_buffer_->data(), bytes_received);
    }
    return true;
  }

  void ReadInternal(const net::CompletionCallback& callback) {
    read_buffer_ = new net::IOBufferWithSize(kMaxExpectedResponseLength);
    int result =
        socket_->Read(read_buffer_.get(), kMaxExpectedResponseLength, callback);
    if (result != net::ERR_IO_PENDING)
      callback.Run(result);
  }

  bool IsCompleteResponse(const std::string& response) {
    // Check end of headers first.
    int end_of_headers =
        net::HttpUtil::LocateEndOfHeaders(response.data(), response.size());
    if (end_of_headers < 0)
      return false;

    // Return true if response has data equal to or more than content length.
    int64_t body_size = static_cast<int64_t>(response.size()) - end_of_headers;
    DCHECK_LE(0, body_size);
    scoped_refptr<net::HttpResponseHeaders> headers(
        new net::HttpResponseHeaders(net::HttpUtil::AssembleRawHeaders(
            response.data(), end_of_headers)));
    return body_size >= headers->GetContentLength();
  }

  scoped_refptr<net::IOBufferWithSize> read_buffer_;
  scoped_refptr<net::DrainableIOBuffer> write_buffer_;
  scoped_ptr<net::TCPClientSocket> socket_;
  int connect_result_;

  DISALLOW_COPY_AND_ASSIGN(TestHttpClient);
};

// WebSocketClient implementation used for both ends of the WebSocket test: the
// connecting side calls Connect(), the accepting side calls
// AcceptConnectRequest(). Received text is collected in received_messages().
//
// The Wait*() methods block by running a stack-allocated RunLoop that the
// client callbacks quit through |run_loop_|. Because every callback quits
// whichever loop is currently running, each Wait*() re-checks its own
// condition after the loop returns and keeps waiting until it really holds.
class WebSocketClientImpl : public WebSocketClient {
 public:
  explicit WebSocketClientImpl()
      : binding_(this, &client_ptr_),
        wait_for_message_count_(0),
        run_loop_(nullptr) {}
  ~WebSocketClientImpl() override {}

  // Establishes a connection from the client side: takes ownership of
  // |web_socket|, creates the data pipe used for outgoing data and asks the
  // socket to connect to |url|.
  void Connect(WebSocketPtr web_socket, const std::string& url) {
    web_socket_ = std::move(web_socket);

    DataPipe data_pipe;
    send_stream_ = std::move(data_pipe.producer_handle);
    write_send_stream_.reset(new WebSocketWriteQueue(send_stream_.get()));

    web_socket_->Connect(url, Array<String>(), "http://example.com",
                         std::move(data_pipe.consumer_handle),
                         std::move(client_ptr_));
  }

  // Establishes a connection from the server side: creates the WebSocket
  // proxy and the data pipe used for outgoing data, then hands the other ends
  // (and this client) to |callback|.
  void AcceptConnectRequest(
      const HttpConnectionDelegate::OnReceivedWebSocketRequestCallback&
          callback) {
    InterfaceRequest<WebSocket> web_socket_request = GetProxy(&web_socket_);

    DataPipe data_pipe;
    send_stream_ = std::move(data_pipe.producer_handle);
    write_send_stream_.reset(new WebSocketWriteQueue(send_stream_.get()));

    callback.Run(std::move(web_socket_request),
                 std::move(data_pipe.consumer_handle), std::move(client_ptr_));
  }

  // Blocks until DidConnect() has delivered the receive stream. Returns
  // immediately if that has already happened.
  void WaitForConnectCompletion() {
    DCHECK(!run_loop_);

    // Loop rather than run once: the run loop may also be quit by a message
    // arriving, which does not mean the connection has completed.
    while (!receive_stream_.is_valid())
      RunUntilQuit();
  }

  // Queues |message| (which must not be empty) on the send stream; once the
  // bytes are written the socket is told to send them as one text message.
  void Send(const std::string& message) {
    DCHECK(!message.empty());

    uint32_t size = static_cast<uint32_t>(message.size());
    write_send_stream_->Write(
        &message[0], size,
        base::Bind(&WebSocketClientImpl::OnFinishedWritingSendStream,
                   base::Unretained(this), size));
  }

  // Blocks until at least |count| messages have been received in total.
  // Returns immediately if that many are already available.
  void WaitForMessage(size_t count) {
    DCHECK(!run_loop_);

    // Loop rather than run once: the run loop may also be quit by
    // DidConnect(), in which case the messages have not arrived yet.
    while (received_messages_.size() < count) {
      wait_for_message_count_ = count;
      RunUntilQuit();
    }
  }

  // All messages received so far, in arrival order.
  std::vector<std::string>& received_messages() { return received_messages_; }

 private:
  // Runs a stack-allocated RunLoop, published through |run_loop_| for the
  // duration of the call, until one of the client callbacks quits it.
  void RunUntilQuit() {
    base::RunLoop run_loop;
    run_loop_ = &run_loop;
    run_loop.Run();
    run_loop_ = nullptr;
  }

  // WebSocketClient implementation.

  // Stores the receive stream, calls FlowControl(2048) on the socket and wakes
  // up a pending wait, if any.
  void DidConnect(const String& selected_subprotocol,
                  const String& extensions,
                  ScopedDataPipeConsumerHandle receive_stream) override {
    receive_stream_ = std::move(receive_stream);
    read_receive_stream_.reset(new WebSocketReadQueue(receive_stream_.get()));

    web_socket_->FlowControl(2048);
    if (run_loop_)
      run_loop_->Quit();
  }

  // Schedules a read of the |num_bytes| announced bytes from the receive
  // stream. |fin| and |type| are ignored.
  void DidReceiveData(bool fin,
                      WebSocket::MessageType type,
                      uint32_t num_bytes) override {
    DCHECK(num_bytes > 0);

    read_receive_stream_->Read(
        num_bytes,
        base::Bind(&WebSocketClientImpl::OnFinishedReadingReceiveStream,
                   base::Unretained(this), num_bytes));
  }

  void DidReceiveFlowControl(int64_t quota) override {}

  void DidFail(const String& message) override {}

  void DidClose(bool was_clean, uint16_t code, const String& reason) override {}

  // Called once |num_bytes| bytes were written to the send stream; tells the
  // socket to send them as a single, final text message.
  void OnFinishedWritingSendStream(uint32_t num_bytes, const char* buffer) {
    EXPECT_TRUE(buffer);

    web_socket_->Send(true, WebSocket::MessageType::TEXT, num_bytes);
  }

  // Called once |num_bytes| bytes were read from the receive stream; records
  // them as one message and wakes up a pending wait when enough messages have
  // been collected.
  void OnFinishedReadingReceiveStream(uint32_t num_bytes, const char* data) {
    EXPECT_TRUE(data);

    received_messages_.push_back(std::string(data, num_bytes));
    if (run_loop_ && received_messages_.size() >= wait_for_message_count_) {
      wait_for_message_count_ = 0;
      run_loop_->Quit();
    }
  }

  // |client_ptr_| is declared before |binding_| because the constructor binds
  // |binding_| to it.
  WebSocketClientPtr client_ptr_;
  Binding<WebSocketClient> binding_;
  WebSocketPtr web_socket_;

  ScopedDataPipeProducerHandle send_stream_;
  scoped_ptr<WebSocketWriteQueue> write_send_stream_;

  ScopedDataPipeConsumerHandle receive_stream_;
  scoped_ptr<WebSocketReadQueue> read_receive_stream_;

  std::vector<std::string> received_messages_;
  // Number of messages the current WaitForMessage() call is waiting for; 0
  // when no such wait is in progress.
  size_t wait_for_message_count_;

  // Pointing to a stack-allocated RunLoop instance.
  base::RunLoop* run_loop_;

  DISALLOW_COPY_AND_ASSIGN(WebSocketClientImpl);
};

class HttpConnectionDelegateImpl : public HttpConnectionDelegate {
 public:
  struct PendingRequest {
    HttpRequestPtr request;
    OnReceivedRequestCallback callback;
  };

  HttpConnectionDelegateImpl(HttpConnectionPtr connection,
                             InterfaceRequest<HttpConnectionDelegate> request)
      : connection_(std::move(connection)),
        binding_(this, std::move(request)),
        wait_for_request_count_(0),
        run_loop_(nullptr) {}
  ~HttpConnectionDelegateImpl() override {}

  // HttpConnectionDelegate implementation:
  void OnReceivedRequest(HttpRequestPtr request,
                         const OnReceivedRequestCallback& callback) override {
    linked_ptr<PendingRequest> pending_request(new PendingRequest);
    pending_request->request = std::move(request);
    pending_request->callback = callback;
    pending_requests_.push_back(pending_request);
    if (run_loop_ && pending_requests_.size() >= wait_for_request_count_) {
      wait_for_request_count_ = 0;
      run_loop_->Quit();
    }
  }

  void OnReceivedWebSocketRequest(
      HttpRequestPtr request,
      const OnReceivedWebSocketRequestCallback& callback) override {
    web_socket_.reset(new WebSocketClientImpl());

    web_socket_->AcceptConnectRequest(callback);

    if (run_loop_)
      run_loop_->Quit();
  }

  void SendResponse(HttpResponsePtr response) {
    ASSERT_FALSE(pending_requests_.empty());
    linked_ptr<PendingRequest> request = pending_requests_[0];
    pending_requests_.erase(pending_requests_.begin());
    request->callback.Run(std::move(response));
  }

  void WaitForRequest(size_t count) {
    DCHECK(!run_loop_);

    if (pending_requests_.size() >= count)
      return;

    wait_for_request_count_ = count;
    base::RunLoop run_loop;
    run_loop_ = &run_loop;
    run_loop.Run();
    run_loop_ = nullptr;
  }

  void WaitForWebSocketRequest() {
    DCHECK(!run_loop_);

    if (web_socket_)
      return;

    base::RunLoop run_loop;
    run_loop_ = &run_loop;
    run_loop.Run();
    run_loop_ = nullptr;
  }

  std::vector<linked_ptr<PendingRequest>>& pending_requests() {
    return pending_requests_;
  }

  WebSocketClientImpl* web_socket() { return web_socket_.get(); }

 private:
  HttpConnectionPtr connection_;
  Binding<HttpConnectionDelegate> binding_;
  std::vector<linked_ptr<PendingRequest>> pending_requests_;
  size_t wait_for_request_count_;
  scoped_ptr<WebSocketClientImpl> web_socket_;

  // Pointing to a stack-allocated RunLoop instance.
  base::RunLoop* run_loop_;

  DISALLOW_COPY_AND_ASSIGN(HttpConnectionDelegateImpl);
};

class HttpServerDelegateImpl : public HttpServerDelegate {
 public:
  explicit HttpServerDelegateImpl(HttpServerDelegatePtr* delegate_ptr)
      : binding_(this, delegate_ptr),
        wait_for_connection_count_(0),
        run_loop_(nullptr) {}
  ~HttpServerDelegateImpl() override {}

  // HttpServerDelegate implementation.
  void OnConnected(HttpConnectionPtr connection,
                   InterfaceRequest<HttpConnectionDelegate> delegate) override {
    connections_.push_back(make_linked_ptr(new HttpConnectionDelegateImpl(
        std::move(connection), std::move(delegate))));
    if (run_loop_ && connections_.size() >= wait_for_connection_count_) {
      wait_for_connection_count_ = 0;
      run_loop_->Quit();
    }
  }

  void WaitForConnection(size_t count) {
    DCHECK(!run_loop_);

    if (connections_.size() >= count)
      return;

    wait_for_connection_count_ = count;
    base::RunLoop run_loop;
    run_loop_ = &run_loop;
    run_loop.Run();
    run_loop_ = nullptr;
  }

  std::vector<linked_ptr<HttpConnectionDelegateImpl>>& connections() {
    return connections_;
  }

 private:
  Binding<HttpServerDelegate> binding_;
  std::vector<linked_ptr<HttpConnectionDelegateImpl>> connections_;
  size_t wait_for_connection_count_;
  // Pointing to a stack-allocated RunLoop instance.
  base::RunLoop* run_loop_;

  DISALLOW_COPY_AND_ASSIGN(HttpServerDelegateImpl);
};

// Test fixture that connects to "mojo:network_service" and exposes its
// NetworkService and WebSocketFactory interfaces to the tests.
class HttpServerTest : public test::ShellTest {
 public:
  HttpServerTest()
      : ShellTest("exe:network_service_unittests") {}
  ~HttpServerTest() override {}

 protected:
  // Connects to the network service and requests the interfaces the tests
  // use.
  void SetUp() override {
    ShellTest::SetUp();

    scoped_ptr<Connection> connection =
        connector()->Connect("mojo:network_service");
    connection->GetInterface(&network_service_);
    connection->GetInterface(&web_socket_factory_);
  }

  // The tests use a TYPE_IO message loop.
  scoped_ptr<base::MessageLoop> CreateMessageLoop() override {
    return make_scoped_ptr(new base::MessageLoop(base::MessageLoop::TYPE_IO));
  }

  // Asks the network service to create an HTTP server on 127.0.0.1 with port
  // 0 (see GetLocalHostWithAnyPort()), reporting to |delegate|, and blocks
  // until the service has answered. On success the address the server is
  // bound to is stored in |*out_bound_to|. On failure a fatal test failure is
  // recorded and |*out_bound_to| is left untouched.
  void CreateHttpServer(HttpServerDelegatePtr delegate,
                        NetAddressPtr* out_bound_to) {
    base::RunLoop loop;
    network_service_->CreateHttpServer(
        GetLocalHostWithAnyPort(), std::move(delegate),
        [out_bound_to, &loop](NetworkErrorPtr result, NetAddressPtr bound_to) {
          // Quit before asserting: a failed ASSERT returns from this callback
          // immediately, and if the loop had not been told to quit by then,
          // the test would hang in loop.Run() instead of failing.
          loop.Quit();
          ASSERT_EQ(net::OK, result->code);
          EXPECT_NE(0u, bound_to->ipv4->port);
          *out_bound_to = std::move(bound_to);
        });
    loop.Run();
  }

  NetworkServicePtr network_service_;
  WebSocketFactoryPtr web_socket_factory_;

 private:
  DISALLOW_COPY_AND_ASSIGN(HttpServerTest);
};

}  // namespace

// Sends a body-less HEAD request through a raw TCP client, checks that the
// server delegate receives it intact, answers with a body-less response and
// checks what the client reads back.
TEST_F(HttpServerTest, BasicHttpRequestResponse) {
  // Bring up the server.
  NetAddressPtr server_address;
  HttpServerDelegatePtr delegate_proxy;
  HttpServerDelegateImpl server_delegate(&delegate_proxy);
  CreateHttpServer(std::move(delegate_proxy), &server_address);

  // Connect a client and pick up the matching server-side connection.
  TestHttpClient http_client;
  http_client.Connect(server_address.To<net::IPEndPoint>());

  server_delegate.WaitForConnection(1);
  HttpConnectionDelegateImpl* const connection =
      server_delegate.connections().front().get();

  // Client -> server.
  TestRequest request_data = {"HEAD", "/test", {{"Hello", "World"}}, nullptr};
  http_client.Send(MakeRequestMessage(request_data));

  connection->WaitForRequest(1);

  CheckRequest(request_data,
               std::move(connection->pending_requests().front()->request));

  // Server -> client.
  TestResponse response_data = {200, {{"Content-Length", "4"}}, nullptr};
  connection->SendResponse(MakeResponseStruct(response_data));
  // This causes the underlying TCP connection to be closed. The client can
  // determine the end of the response based on that. |connection| must not be
  // used past this point.
  server_delegate.connections().clear();

  std::string raw_response;
  http_client.ReadResponse(&raw_response);

  CheckResponse(response_data, raw_response);
}

// Same round trip as the basic test, but both the request and the response
// carry a body whose length is announced through Content-Length.
TEST_F(HttpServerTest, HttpRequestResponseWithBody) {
  // Bring up the server.
  NetAddressPtr server_address;
  HttpServerDelegatePtr delegate_proxy;
  HttpServerDelegateImpl server_delegate(&delegate_proxy);
  CreateHttpServer(std::move(delegate_proxy), &server_address);

  // Connect a client and pick up the matching server-side connection.
  TestHttpClient http_client;
  http_client.Connect(server_address.To<net::IPEndPoint>());

  server_delegate.WaitForConnection(1);
  HttpConnectionDelegateImpl* const connection =
      server_delegate.connections().front().get();

  // Client -> server.
  TestRequest request_data = {
      "Post",
      "/test",
      {{"Hello", "World"},
       {"Content-Length", "23"},
       {"Content-Type", "text/plain"}},
      make_scoped_ptr(new std::string("This is a test request!"))};
  http_client.Send(MakeRequestMessage(request_data));

  connection->WaitForRequest(1);

  CheckRequest(request_data,
               std::move(connection->pending_requests().front()->request));

  // Server -> client. The connection stays open here; the client detects the
  // end of the response from Content-Length.
  TestResponse response_data = {
      200,
      {{"Content-Length", "26"}},
      make_scoped_ptr(new std::string("This is a test response..."))};
  connection->SendResponse(MakeResponseStruct(response_data));

  std::string raw_response;
  http_client.ReadResponse(&raw_response);

  CheckResponse(response_data, raw_response);
}

// Opens a WebSocket connection to the server and exchanges two text messages
// in each direction. |socket_0| is the connecting end, |socket_1| the end the
// server delegate creates when it accepts the upgrade request.
TEST_F(HttpServerTest, WebSocket) {
  // Bring up the server.
  NetAddressPtr server_address;
  HttpServerDelegatePtr delegate_proxy;
  HttpServerDelegateImpl server_delegate(&delegate_proxy);
  CreateHttpServer(std::move(delegate_proxy), &server_address);

  // Connecting end.
  WebSocketPtr client_web_socket;
  web_socket_factory_->CreateWebSocket(GetProxy(&client_web_socket));
  WebSocketClientImpl socket_0;
  const std::string url = base::StringPrintf("ws://127.0.0.1:%d/hello",
                                             server_address->ipv4->port);
  socket_0.Connect(std::move(client_web_socket), url);

  // Accepting end.
  server_delegate.WaitForConnection(1);
  HttpConnectionDelegateImpl* const connection =
      server_delegate.connections().front().get();

  connection->WaitForWebSocketRequest();
  WebSocketClientImpl& socket_1 = *connection->web_socket();

  // Both ends must have finished connecting before any data is exchanged.
  socket_1.WaitForConnectCompletion();
  socket_0.WaitForConnectCompletion();

  // Connecting end -> accepting end.
  socket_0.Send("Hello");
  socket_0.Send("world!");

  socket_1.WaitForMessage(2);
  EXPECT_EQ("Hello", socket_1.received_messages()[0]);
  EXPECT_EQ("world!", socket_1.received_messages()[1]);

  // Accepting end -> connecting end.
  socket_1.Send("How do");
  socket_1.Send("you do?");

  socket_0.WaitForMessage(2);
  EXPECT_EQ("How do", socket_0.received_messages()[0]);
  EXPECT_EQ("you do?", socket_0.received_messages()[1]);
}

}  // namespace mojo