// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/socket/transport_client_socket_pool_test_util.h"

#include <stdint.h>
#include <string>
#include <utility>

#include "base/location.h"
#include "base/logging.h"
#include "base/macros.h"
#include "base/memory/weak_ptr.h"
#include "base/run_loop.h"
#include "base/single_thread_task_runner.h"
#include "base/thread_task_runner_handle.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/base/load_timing_info.h"
#include "net/base/load_timing_info_test_util.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/ssl_client_socket.h"
#include "net/udp/datagram_client_socket.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace net {

namespace {

// Converts the IP address literal |ip| into an IPAddress. CHECK-fails (in all
// build types) if |ip| cannot be parsed.
IPAddress ParseIP(const std::string& ip) {
  IPAddress parsed;
  const bool parsed_ok = parsed.AssignFromIPLiteral(ip);
  CHECK(parsed_ok);
  return parsed;
}

// StreamSocket whose Connect() always completes synchronously with OK.
// Read() and Write() are unsupported and always return ERR_FAILED.
class MockConnectClientSocket : public StreamSocket {
 public:
  MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log)
      : addrlist_(addrlist),
        net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {}

  // StreamSocket:
  int Connect(const CompletionCallback& callback) override {
    // The connect finishes inline, so |callback| is never invoked.
    connected_ = true;
    return OK;
  }
  void Disconnect() override { connected_ = false; }
  bool IsConnected() const override { return connected_; }
  bool IsConnectedAndIdle() const override { return connected_; }

  // Reports the first address this socket was constructed with.
  int GetPeerAddress(IPEndPoint* address) const override {
    *address = addrlist_.front();
    return OK;
  }

  // Once connected, reports a canned local address whose family matches the
  // first peer address.
  int GetLocalAddress(IPEndPoint* address) const override {
    if (connected_) {
      const bool peer_is_ipv4 =
          addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4;
      if (peer_is_ipv4)
        SetIPv4Address(address);
      else
        SetIPv6Address(address);
      return OK;
    }
    return ERR_SOCKET_NOT_CONNECTED;
  }

  const BoundNetLog& NetLog() const override { return net_log_; }

  void SetSubresourceSpeculation() override {}
  void SetOmniboxSpeculation() override {}
  bool WasEverUsed() const override { return false; }
  void EnableTCPFastOpenIfSupported() override {}
  bool WasNpnNegotiated() const override { return false; }
  NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
  bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }

  // A successful connect records no failed attempts.
  void GetConnectionAttempts(ConnectionAttempts* out) const override {
    out->clear();
  }
  void ClearConnectionAttempts() override {}
  void AddConnectionAttempts(const ConnectionAttempts& attempts) override {}

  int64_t GetTotalReceivedBytes() const override {
    NOTIMPLEMENTED();
    return 0;
  }

  // Socket:
  int Read(IOBuffer* buf,
           int buf_len,
           const CompletionCallback& callback) override {
    return ERR_FAILED;
  }
  int Write(IOBuffer* buf,
            int buf_len,
            const CompletionCallback& callback) override {
    return ERR_FAILED;
  }
  int SetReceiveBufferSize(int32_t size) override { return OK; }
  int SetSendBufferSize(int32_t size) override { return OK; }

 private:
  bool connected_ = false;
  const AddressList addrlist_;
  BoundNetLog net_log_;

  DISALLOW_COPY_AND_ASSIGN(MockConnectClientSocket);
};

// StreamSocket whose Connect() always fails synchronously with
// ERR_CONNECTION_FAILED. It reports one failed ConnectionAttempt for every
// address it was constructed with.
class MockFailingClientSocket : public StreamSocket {
 public:
  MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log)
      : addrlist_(addrlist),
        net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {}

  // StreamSocket:
  int Connect(const CompletionCallback& callback) override {
    return ERR_CONNECTION_FAILED;
  }
  void Disconnect() override {}
  bool IsConnected() const override { return false; }
  bool IsConnectedAndIdle() const override { return false; }

  // Neither endpoint is ever available because the socket never connects.
  int GetPeerAddress(IPEndPoint* address) const override {
    return ERR_UNEXPECTED;
  }
  int GetLocalAddress(IPEndPoint* address) const override {
    return ERR_UNEXPECTED;
  }

  const BoundNetLog& NetLog() const override { return net_log_; }

  void SetSubresourceSpeculation() override {}
  void SetOmniboxSpeculation() override {}
  bool WasEverUsed() const override { return false; }
  void EnableTCPFastOpenIfSupported() override {}
  bool WasNpnNegotiated() const override { return false; }
  NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
  bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }

  // Replaces |*out| with one ERR_CONNECTION_FAILED attempt per address.
  void GetConnectionAttempts(ConnectionAttempts* out) const override {
    ConnectionAttempts failed_attempts;
    for (const auto& endpoint : addrlist_)
      failed_attempts.push_back(
          ConnectionAttempt(endpoint, ERR_CONNECTION_FAILED));
    *out = failed_attempts;
  }
  void ClearConnectionAttempts() override {}
  void AddConnectionAttempts(const ConnectionAttempts& attempts) override {}

  int64_t GetTotalReceivedBytes() const override {
    NOTIMPLEMENTED();
    return 0;
  }

  // Socket:
  int Read(IOBuffer* buf,
           int buf_len,
           const CompletionCallback& callback) override {
    return ERR_FAILED;
  }
  int Write(IOBuffer* buf,
            int buf_len,
            const CompletionCallback& callback) override {
    return ERR_FAILED;
  }
  int SetReceiveBufferSize(int32_t size) override { return OK; }
  int SetSendBufferSize(int32_t size) override { return OK; }

 private:
  const AddressList addrlist_;
  BoundNetLog net_log_;

  DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket);
};

// StreamSocket whose Connect() always returns ERR_IO_PENDING. The pending
// connect completes only when the closure from GetConnectCallback() is run.
// The static Make*() helpers post that closure immediately
// (MakeMockPendingClientSocket), post it after a delay
// (MakeMockDelayedClientSocket), or never schedule it
// (MakeMockStalledClientSocket).
class MockTriggerableClientSocket : public StreamSocket {
 public:
  // |should_connect| indicates whether the socket should successfully complete
  // or fail.
  MockTriggerableClientSocket(const AddressList& addrlist,
                              bool should_connect,
                              net::NetLog* net_log)
      : should_connect_(should_connect),
        is_connected_(false),
        addrlist_(addrlist),
        net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)),
        weak_factory_(this) {}

  // Returns a closure that triggers the connect callback when run. The closure
  // holds only a weak pointer, so running it after the socket has been deleted
  // does nothing.
  base::Closure GetConnectCallback() {
    return base::Bind(&MockTriggerableClientSocket::DoCallback,
                      weak_factory_.GetWeakPtr());
  }

  // Creates a socket whose connect is triggered by a task posted to the
  // current thread's task runner.
  static scoped_ptr<StreamSocket> MakeMockPendingClientSocket(
      const AddressList& addrlist,
      bool should_connect,
      net::NetLog* net_log) {
    scoped_ptr<MockTriggerableClientSocket> socket(
        new MockTriggerableClientSocket(addrlist, should_connect, net_log));
    base::ThreadTaskRunnerHandle::Get()->PostTask(FROM_HERE,
                                                  socket->GetConnectCallback());
    return std::move(socket);
  }

  // Creates a socket whose connect is triggered by a task posted with |delay|.
  static scoped_ptr<StreamSocket> MakeMockDelayedClientSocket(
      const AddressList& addrlist,
      bool should_connect,
      const base::TimeDelta& delay,
      net::NetLog* net_log) {
    scoped_ptr<MockTriggerableClientSocket> socket(
        new MockTriggerableClientSocket(addrlist, should_connect, net_log));
    base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
        FROM_HERE, socket->GetConnectCallback(), delay);
    return std::move(socket);
  }

  // Creates a socket whose connect is never triggered. If |failing| is true,
  // one ERR_CONNECTION_FAILED attempt for the first address is pre-recorded.
  static scoped_ptr<StreamSocket> MakeMockStalledClientSocket(
      const AddressList& addrlist,
      net::NetLog* net_log,
      bool failing) {
    scoped_ptr<MockTriggerableClientSocket> socket(
        new MockTriggerableClientSocket(addrlist, true, net_log));
    if (failing) {
      DCHECK_LE(1u, addrlist.size());
      ConnectionAttempts attempts;
      attempts.push_back(ConnectionAttempt(addrlist[0], ERR_CONNECTION_FAILED));
      socket->AddConnectionAttempts(attempts);
    }
    return std::move(socket);
  }

  // StreamSocket implementation.
  int Connect(const CompletionCallback& callback) override {
    // Only one connect may be outstanding at a time.
    DCHECK(callback_.is_null());
    callback_ = callback;
    return ERR_IO_PENDING;
  }

  void Disconnect() override {}

  bool IsConnected() const override { return is_connected_; }
  bool IsConnectedAndIdle() const override { return is_connected_; }
  int GetPeerAddress(IPEndPoint* address) const override {
    *address = addrlist_.front();
    return OK;
  }
  int GetLocalAddress(IPEndPoint* address) const override {
    if (!is_connected_)
      return ERR_SOCKET_NOT_CONNECTED;
    if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
      SetIPv4Address(address);
    else
      SetIPv6Address(address);
    return OK;
  }
  const BoundNetLog& NetLog() const override { return net_log_; }

  void SetSubresourceSpeculation() override {}
  void SetOmniboxSpeculation() override {}
  bool WasEverUsed() const override { return false; }
  void EnableTCPFastOpenIfSupported() override {}
  bool WasNpnNegotiated() const override { return false; }
  NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
  bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
  void GetConnectionAttempts(ConnectionAttempts* out) const override {
    *out = connection_attempts_;
  }
  void ClearConnectionAttempts() override { connection_attempts_.clear(); }
  // Prepends |attempts| to the attempts already recorded.
  void AddConnectionAttempts(const ConnectionAttempts& attempts) override {
    connection_attempts_.insert(connection_attempts_.begin(), attempts.begin(),
                                attempts.end());
  }
  int64_t GetTotalReceivedBytes() const override {
    NOTIMPLEMENTED();
    return 0;
  }

  // Socket implementation.
  int Read(IOBuffer* buf,
           int buf_len,
           const CompletionCallback& callback) override {
    return ERR_FAILED;
  }

  int Write(IOBuffer* buf,
            int buf_len,
            const CompletionCallback& callback) override {
    return ERR_FAILED;
  }
  int SetReceiveBufferSize(int32_t size) override { return OK; }
  int SetSendBufferSize(int32_t size) override { return OK; }

 private:
  // Completes the pending Connect() with OK, or with ERR_CONNECTION_FAILED if
  // |should_connect_| is false.
  void DoCallback() {
    // The trigger must not fire before Connect() has stored a callback.
    DCHECK(!callback_.is_null());
    is_connected_ = should_connect_;
    const int result = is_connected_ ? OK : ERR_CONNECTION_FAILED;
    // Copy the callback and clear the member before running it. The callback
    // may destroy this socket, so nothing here may touch members once it runs.
    // Clearing |callback_| also lets a later Connect() pass its DCHECK.
    CompletionCallback callback = callback_;
    callback_.Reset();
    callback.Run(result);
  }

  bool should_connect_;
  bool is_connected_;
  const AddressList addrlist_;
  BoundNetLog net_log_;
  CompletionCallback callback_;
  ConnectionAttempts connection_attempts_;

  base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_;

  DISALLOW_COPY_AND_ASSIGN(MockTriggerableClientSocket);
};

}  // namespace

void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) {
  LoadTimingInfo load_timing_info;
  // Only pass true in as |is_reused|, as in general, HttpStream types should
  // have stricter concepts of reuse than socket pools.
  EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info));

  EXPECT_TRUE(load_timing_info.socket_reused);
  EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);

  ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing);
  ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
}

void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) {
  EXPECT_FALSE(handle.is_reused());

  LoadTimingInfo load_timing_info;
  EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));

  EXPECT_FALSE(load_timing_info.socket_reused);
  EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);

  ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
                              CONNECT_TIMING_HAS_DNS_TIMES);
  ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);

  TestLoadTimingInfoConnectedReused(handle);
}

// Overwrites |*address| with the fixed test endpoint 1.1.1.1:80.
void SetIPv4Address(IPEndPoint* address) {
  const uint16_t port = 80;
  const IPAddress ip = ParseIP("1.1.1.1");
  *address = IPEndPoint(ip, port);
}

// Overwrites |*address| with the fixed test endpoint [1:abcd::3:4:ff]:80.
void SetIPv6Address(IPEndPoint* address) {
  const uint16_t port = 80;
  const IPAddress ip = ParseIP("1:abcd::3:4:ff");
  *address = IPEndPoint(ip, port);
}

// By default every created socket is a MOCK_CLIENT_SOCKET, no per-call type
// list is installed, and delayed sockets wait
// ClientSocketPool::kMaxConnectRetryIntervalMs before completing.
MockTransportClientSocketFactory::MockTransportClientSocketFactory(
    NetLog* net_log)
    : net_log_(net_log),
      allocation_count_(0),
      client_socket_type_(MOCK_CLIENT_SOCKET),
      client_socket_types_(nullptr),
      client_socket_index_(0),
      client_socket_index_max_(0),
      delay_(base::TimeDelta::FromMilliseconds(
          ClientSocketPool::kMaxConnectRetryIntervalMs)) {}

MockTransportClientSocketFactory::~MockTransportClientSocketFactory() = default;

// This factory does not support datagram sockets: reaching this method is
// flagged with NOTREACHED(), and an empty pointer is returned.
scoped_ptr<DatagramClientSocket>
MockTransportClientSocketFactory::CreateDatagramClientSocket(
    DatagramSocket::BindType bind_type,
    const RandIntCallback& rand_int_cb,
    NetLog* net_log,
    const NetLog::Source& source) {
  NOTREACHED();
  return nullptr;
}

// Creates a mock stream socket for |addresses|. The kind of socket is taken
// from the next entry of the list set via set_client_socket_types() while
// entries remain, and from |client_socket_type_| otherwise. Every socket logs
// to the factory's |net_log_|; the |net_log| and |source| arguments are
// ignored.
scoped_ptr<StreamSocket>
MockTransportClientSocketFactory::CreateTransportClientSocket(
    const AddressList& addresses,
    NetLog* /* net_log */,
    const NetLog::Source& /* source */) {
  ++allocation_count_;

  const bool has_queued_type =
      client_socket_types_ && client_socket_index_ < client_socket_index_max_;
  const ClientSocketType type =
      has_queued_type ? client_socket_types_[client_socket_index_++]
                      : client_socket_type_;

  switch (type) {
    case MOCK_CLIENT_SOCKET:
      return scoped_ptr<StreamSocket>(
          new MockConnectClientSocket(addresses, net_log_));
    case MOCK_FAILING_CLIENT_SOCKET:
      return scoped_ptr<StreamSocket>(
          new MockFailingClientSocket(addresses, net_log_));

    // Connect is triggered by an immediately posted task.
    case MOCK_PENDING_CLIENT_SOCKET:
      return MockTriggerableClientSocket::MakeMockPendingClientSocket(
          addresses, true, net_log_);
    case MOCK_PENDING_FAILING_CLIENT_SOCKET:
      return MockTriggerableClientSocket::MakeMockPendingClientSocket(
          addresses, false, net_log_);

    // Connect is triggered by a task posted with |delay_|.
    case MOCK_DELAYED_CLIENT_SOCKET:
      return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
          addresses, true, delay_, net_log_);
    case MOCK_DELAYED_FAILING_CLIENT_SOCKET:
      return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
          addresses, false, delay_, net_log_);

    // Connect is never triggered.
    case MOCK_STALLED_CLIENT_SOCKET:
      return MockTriggerableClientSocket::MakeMockStalledClientSocket(
          addresses, net_log_, false);
    case MOCK_STALLED_FAILING_CLIENT_SOCKET:
      return MockTriggerableClientSocket::MakeMockStalledClientSocket(
          addresses, net_log_, true);

    // Connect is triggered by the test via WaitForTriggerableSocketCreation().
    case MOCK_TRIGGERABLE_CLIENT_SOCKET: {
      scoped_ptr<MockTriggerableClientSocket> socket(
          new MockTriggerableClientSocket(addresses, true, net_log_));
      triggerable_sockets_.push(socket->GetConnectCallback());
      // Wake WaitForTriggerableSocketCreation() if it is currently spinning a
      // run loop. All of this runs on a single thread, so no locking is
      // required; the quit closure acts like a condition-variable signal.
      if (!run_loop_quit_closure_.is_null())
        run_loop_quit_closure_.Run();
      return std::move(socket);
    }

    default:
      NOTREACHED();
      return scoped_ptr<StreamSocket>(
          new MockConnectClientSocket(addresses, net_log_));
  }
}

// SSL sockets are not supported by this factory. The call is reported via
// NOTIMPLEMENTED(), and an empty pointer is returned.
scoped_ptr<SSLClientSocket>
MockTransportClientSocketFactory::CreateSSLClientSocket(
    scoped_ptr<ClientSocketHandle> transport_socket,
    const HostPortPair& host_and_port,
    const SSLConfig& ssl_config,
    const SSLClientSocketContext& context) {
  NOTIMPLEMENTED();
  return nullptr;
}

// This factory keeps no SSL session cache; the call is only reported via
// NOTIMPLEMENTED().
void MockTransportClientSocketFactory::ClearSSLSessionCache() {
  NOTIMPLEMENTED();
}

// Installs a list of |num_types| socket types for the next sockets created by
// CreateTransportClientSocket(), starting from the first entry. Only the
// pointer is stored (the array is not copied), so |type_list| must stay valid
// for as long as the factory may read from it.
void MockTransportClientSocketFactory::set_client_socket_types(
    ClientSocketType* type_list,
    int num_types) {
  DCHECK_GT(num_types, 0);
  client_socket_index_max_ = num_types;
  client_socket_index_ = 0;
  client_socket_types_ = type_list;
}

// Blocks (by spinning nested run loops) until at least one
// MOCK_TRIGGERABLE_CLIENT_SOCKET has been created. Returns the oldest pending
// trigger closure and removes it from the queue. CreateTransportClientSocket()
// quits the active run loop through |run_loop_quit_closure_| whenever it
// queues a new trigger.
base::Closure
MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() {
  while (triggerable_sockets_.empty()) {
    base::RunLoop waiter;
    run_loop_quit_closure_ = waiter.QuitClosure();
    waiter.Run();
    run_loop_quit_closure_.Reset();
  }
  base::Closure next_trigger = triggerable_sockets_.front();
  triggerable_sockets_.pop();
  return next_trigger;
}

}  // namespace net