// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "extensions/browser/api/socket/tcp_socket.h"

#include "base/macros.h"
#include "base/memory/scoped_ptr.h"
#include "net/base/address_list.h"
#include "net/base/completion_callback.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/base/rand_callback.h"
#include "net/socket/tcp_client_socket.h"
#include "net/socket/tcp_server_socket.h"
#include "testing/gmock/include/gmock/gmock.h"

using testing::_;
using testing::DoAll;
using testing::Return;
using testing::SaveArg;

namespace extensions {

// Test double for net::TCPClientSocket. Read, Write, SetKeepAlive and
// SetNoDelay are gmock methods, so each test can script their return values
// and capture their arguments. IsConnected() is hard-wired to return true.
class MockTCPSocket : public net::TCPClientSocket {
 public:
  // Forwards |address_list| to the base class together with a null second
  // argument and a default-constructed net::NetLog::Source.
  explicit MockTCPSocket(const net::AddressList& address_list)
      : net::TCPClientSocket(address_list, nullptr, net::NetLog::Source()) {}

  MOCK_METHOD3(Read, int(net::IOBuffer* buf, int buf_len,
                         const net::CompletionCallback& callback));
  MOCK_METHOD3(Write, int(net::IOBuffer* buf, int buf_len,
                          const net::CompletionCallback& callback));
  MOCK_METHOD2(SetKeepAlive, bool(bool enable, int delay));
  MOCK_METHOD1(SetNoDelay, bool(bool no_delay));

  // Always reports a connected socket; no real connection is ever made.
  bool IsConnected() const override {
    return true;
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(MockTCPSocket);
};

// Test double for net::TCPServerSocket whose Listen() and Accept() are gmock
// methods. With no action specified, gmock returns 0 from both.
class MockTCPServerSocket : public net::TCPServerSocket {
 public:
  // Constructs the base class with a null first argument and a
  // default-constructed net::NetLog::Source.
  MockTCPServerSocket()
      : net::TCPServerSocket(nullptr, net::NetLog::Source()) {}

  MOCK_METHOD2(Listen, int(const net::IPEndPoint& address, int backlog));
  MOCK_METHOD2(Accept, int(scoped_ptr<net::StreamSocket>* socket,
                            const net::CompletionCallback& callback));

 private:
  DISALLOW_COPY_AND_ASSIGN(MockTCPServerSocket);
};

class CompleteHandler {
 public:
  CompleteHandler() {}
  MOCK_METHOD1(OnComplete, void(int result_code));
  MOCK_METHOD2(OnReadComplete, void(int result_code,
      scoped_refptr<net::IOBuffer> io_buffer));

  // MOCK_METHOD cannot mock a scoped_ptr argument.
  MOCK_METHOD2(OnAcceptMock, void(int, net::TCPClientSocket*));
  void OnAccept(int count, scoped_ptr<net::TCPClientSocket> socket) {
    OnAcceptMock(count, socket.get());
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(CompleteHandler);
};

// Placeholder owner-extension id handed to every socket created by these
// tests. Kept as a character array rather than a std::string so that no
// static initializer or destructor runs for it; it converts implicitly
// wherever a std::string argument is expected.
const char FAKE_ID[] = "abcdefghijklmnopqrst";

// A single TCPSocket::Read() request must produce exactly one Read() on the
// mock client socket and exactly one OnReadComplete() notification.
TEST(SocketTest, TestTCPSocketRead) {
  net::AddressList addresses;
  scoped_ptr<MockTCPSocket> mock_socket(new MockTCPSocket(addresses));
  CompleteHandler completion;

  // Expectations are set on the mock before ownership moves into the
  // TCPSocket below.
  EXPECT_CALL(*mock_socket, Read(_, _, _)).Times(1);
  EXPECT_CALL(completion, OnReadComplete(_, _)).Times(1);

  scoped_ptr<TCPSocket> socket(
      TCPSocket::CreateSocketForTesting(std::move(mock_socket), FAKE_ID, true));

  const int kBytesToRead = 512;
  socket->Read(kBytesToRead,
               base::Bind(&CompleteHandler::OnReadComplete,
                          base::Unretained(&completion)));
}

// The mock accepts 128 bytes per Write() call. For a 256-byte buffer the test
// expects the underlying Write() to be invoked twice and the completion
// handler to be notified once.
TEST(SocketTest, TestTCPSocketWrite) {
  net::AddressList addresses;
  scoped_ptr<MockTCPSocket> mock_socket(new MockTCPSocket(addresses));
  CompleteHandler completion;

  // Receives the callback handed to the mock; this test never runs it.
  net::CompletionCallback last_write_callback;
  EXPECT_CALL(*mock_socket, Write(_, _, _))
      .Times(2)
      .WillRepeatedly(DoAll(SaveArg<2>(&last_write_callback), Return(128)));
  EXPECT_CALL(completion, OnComplete(_)).Times(1);

  scoped_ptr<TCPSocket> socket(
      TCPSocket::CreateSocketForTesting(std::move(mock_socket), FAKE_ID, true));

  scoped_refptr<net::IOBufferWithSize> payload(new net::IOBufferWithSize(256));
  socket->Write(payload.get(),
                payload->size(),
                base::Bind(&CompleteHandler::OnComplete,
                           base::Unretained(&completion)));
}

// Every Write() on the mock reports ERR_IO_PENDING, so the write only makes
// progress when the test runs the captured completion callback itself.
TEST(SocketTest, TestTCPSocketBlockedWrite) {
  net::AddressList addresses;
  scoped_ptr<MockTCPSocket> mock_socket(new MockTCPSocket(addresses));
  CompleteHandler completion;

  // Holds the most recent callback that was passed to the mock's Write().
  net::CompletionCallback pending_write;
  EXPECT_CALL(*mock_socket, Write(_, _, _))
      .Times(2)
      .WillRepeatedly(
          DoAll(SaveArg<2>(&pending_write), Return(net::ERR_IO_PENDING)));

  scoped_ptr<TCPSocket> socket(
      TCPSocket::CreateSocketForTesting(std::move(mock_socket), FAKE_ID, true));

  scoped_refptr<net::IOBufferWithSize> payload(new net::IOBufferWithSize(42));
  socket->Write(payload.get(),
                payload->size(),
                base::Bind(&CompleteHandler::OnComplete,
                           base::Unretained(&completion)));

  // The initial call came back pending. Simulate the socket finishing the
  // 42-byte buffer in two steps (40 bytes, then 2) and verify the handler is
  // notified exactly once with the total of 42.
  EXPECT_CALL(completion, OnComplete(42)).Times(1);
  pending_write.Run(40);
  pending_write.Run(2);
}

// Issues several writes back to back while every underlying Write() reports
// ERR_IO_PENDING, then completes them one at a time through the captured
// callback. Each handler must be notified exactly once with the size of its
// own buffer.
TEST(SocketTest, TestTCPSocketBlockedWriteReentry) {
  const int kNumWrites = 5;

  net::AddressList address_list;
  scoped_ptr<MockTCPSocket> tcp_client_socket(new MockTCPSocket(address_list));
  CompleteHandler handlers[kNumWrites];

  // Holds the most recent callback that was passed to the mock's Write().
  net::CompletionCallback callback;
  EXPECT_CALL(*tcp_client_socket, Write(_, _, _))
      .Times(kNumWrites)
      .WillRepeatedly(DoAll(SaveArg<2>(&callback),
                            Return(net::ERR_IO_PENDING)));

  scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting(
      std::move(tcp_client_socket), FAKE_ID, true));

  scoped_refptr<net::IOBufferWithSize> io_buffers[kNumWrites];
  for (int i = 0; i < kNumWrites; i++) {
    // Give every request a distinct size so a completion delivered to the
    // wrong handler would fail its expectation.
    io_buffers[i] = new net::IOBufferWithSize(128 + i * 50);

    // Set the expectation before issuing the write that can trigger it.
    EXPECT_CALL(handlers[i], OnComplete(io_buffers[i]->size()))
        .Times(1);

    socket->Write(io_buffers[i].get(), io_buffers[i]->size(),
        base::Bind(&CompleteHandler::OnComplete,
            base::Unretained(&handlers[i])));
  }

  // Complete each queued write in full, in the order it was issued.
  for (int i = 0; i < kNumWrites; i++) {
    callback.Run(io_buffers[i]->size());
  }
}

// TCPSocket::SetNoDelay() must pass its argument through to the underlying
// socket and hand back that socket's result unchanged.
TEST(SocketTest, TestTCPSocketSetNoDelay) {
  net::AddressList address_list;
  scoped_ptr<MockTCPSocket> tcp_client_socket(new MockTCPSocket(address_list));

  // Receives the argument of the most recent SetNoDelay() call on the mock.
  bool no_delay = false;
  {
    // First call succeeds, second call fails, strictly in that order.
    testing::InSequence dummy;
    EXPECT_CALL(*tcp_client_socket, SetNoDelay(_))
        .WillOnce(DoAll(SaveArg<0>(&no_delay), Return(true)));
    EXPECT_CALL(*tcp_client_socket, SetNoDelay(_))
        .WillOnce(DoAll(SaveArg<0>(&no_delay), Return(false)));
  }

  scoped_ptr<TCPSocket> socket(
      TCPSocket::CreateSocketForTesting(std::move(tcp_client_socket), FAKE_ID));

  EXPECT_FALSE(no_delay);
  // The result is a success flag, so keep it as a bool rather than an int.
  bool result = socket->SetNoDelay(true);
  EXPECT_TRUE(result);
  EXPECT_TRUE(no_delay);

  result = socket->SetNoDelay(false);
  EXPECT_FALSE(result);
  EXPECT_FALSE(no_delay);
}

// TCPSocket::SetKeepAlive() must pass both arguments through to the
// underlying socket and hand back that socket's result unchanged.
TEST(SocketTest, TestTCPSocketSetKeepAlive) {
  net::AddressList address_list;
  scoped_ptr<MockTCPSocket> tcp_client_socket(new MockTCPSocket(address_list));

  // Receive the arguments of the most recent SetKeepAlive() call on the mock.
  bool enable = false;
  int delay = 0;
  {
    // First call succeeds, second call fails, strictly in that order.
    testing::InSequence dummy;
    EXPECT_CALL(*tcp_client_socket, SetKeepAlive(_, _))
        .WillOnce(DoAll(SaveArg<0>(&enable), SaveArg<1>(&delay),
                        Return(true)));
    EXPECT_CALL(*tcp_client_socket, SetKeepAlive(_, _))
        .WillOnce(DoAll(SaveArg<0>(&enable), SaveArg<1>(&delay),
                        Return(false)));
  }

  scoped_ptr<TCPSocket> socket(
      TCPSocket::CreateSocketForTesting(std::move(tcp_client_socket), FAKE_ID));

  EXPECT_FALSE(enable);
  // The result is a success flag, so keep it as a bool rather than an int.
  bool result = socket->SetKeepAlive(true, 4500);
  EXPECT_TRUE(result);
  EXPECT_TRUE(enable);
  EXPECT_EQ(4500, delay);

  result = socket->SetKeepAlive(false, 0);
  EXPECT_FALSE(result);
  EXPECT_FALSE(enable);
  EXPECT_EQ(0, delay);
}

// Drives Listen() followed by Accept() on a TCPSocket that wraps the mock
// server socket. Each underlying call must happen once, and the accept
// handler must be notified.
TEST(SocketTest, TestTCPServerSocketListenAccept) {
  scoped_ptr<MockTCPServerSocket> mock_server(new MockTCPServerSocket());
  CompleteHandler completion;

  // No actions are specified, so gmock returns 0 from both mocked calls.
  EXPECT_CALL(*mock_server, Listen(_, _)).Times(1);
  EXPECT_CALL(*mock_server, Accept(_, _)).Times(1);

  scoped_ptr<TCPSocket> socket(
      TCPSocket::CreateServerSocketForTesting(std::move(mock_server), FAKE_ID));

  EXPECT_CALL(completion, OnAcceptMock(_, _));

  std::string error_message;
  const int listen_result =
      socket->Listen("127.0.0.1", 9999, 10, &error_message);
  EXPECT_EQ(net::OK, listen_result);

  socket->Accept(base::Bind(&CompleteHandler::OnAccept,
                            base::Unretained(&completion)));
}

}  // namespace extensions