// Copyright (c) 2006-2010 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "base/basictypes.h"
#include "sandbox/src/crosscall_client.h"
#include "sandbox/src/crosscall_server.h"
#include "sandbox/src/sharedmem_ipc_client.h"
#include "sandbox/src/sharedmem_ipc_server.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace sandbox {

// Builds a zero-filled fake shared-memory region of |total_shared_size| bytes
// laid out as an IPCControl header followed by as many ChannelControl entries
// as fit, where every channel needs sizeof(ChannelControl) + |channel_size|
// bytes. On return |*base_start| holds the offset (from the start of the
// region) of the first channel's data area, which sits right after the
// ChannelControl array. The caller owns the region and releases it with
// delete[] on a char pointer.
IPCControl* MakeChannels(size_t channel_size, size_t total_shared_size,
                         size_t* base_start) {
  const size_t header_size = offsetof(IPCControl, channels);

  char* raw = new char[total_shared_size];
  memset(raw, 0, total_shared_size);

  // Each channel costs one control entry plus its data area.
  const size_t per_channel = sizeof(ChannelControl) + channel_size;
  const size_t channel_count = (total_shared_size - header_size) / per_channel;

  *base_start = header_size + channel_count * sizeof(ChannelControl);

  IPCControl* control = reinterpret_cast<IPCControl*>(raw);
  control->channels_count = channel_count;
  return control;
}

// Selects how FixChannels() prepares each channel's ping/pong events.
enum TestFixMode {
  FIX_NO_EVENTS = 0,       // No events are created for the channels.
  FIX_PONG_READY = 1,      // Events are created; pong starts signaled.
  FIX_PONG_NOT_READY = 2   // Events are created; pong starts non-signaled.
};

// Initializes every channel of |client_control|. Channel data areas are laid
// out back to back starting at offset |base_start|, each |channel_size| bytes
// long, and every channel starts in kFreeChannel. Unless |mode| is
// FIX_NO_EVENTS, each channel also gets a new auto-reset ping event (initially
// non-signaled) and a new auto-reset pong event that starts signaled only for
// FIX_PONG_READY. The events are released with CloseChannelEvents().
void FixChannels(IPCControl* client_control, size_t base_start,
                 size_t channel_size, TestFixMode mode) {
  const bool make_events = (mode != FIX_NO_EVENTS);
  const BOOL pong_initially_set = (mode == FIX_PONG_READY) ? TRUE : FALSE;

  for (size_t i = 0; i < client_control->channels_count; ++i) {
    ChannelControl* ch = &client_control->channels[i];
    ch->channel_base = base_start + i * channel_size;
    ch->state = kFreeChannel;
    if (!make_events)
      continue;
    ch->ping_event = ::CreateEventW(NULL, FALSE, FALSE, NULL);
    ch->pong_event = ::CreateEventW(NULL, FALSE, pong_initially_set, NULL);
  }
}

// Closes the ping and pong event handles of every channel in
// |client_control|, i.e. the events created by FixChannels(). CloseHandle
// results are ignored.
void CloseChannelEvents(IPCControl* client_control) {
  ChannelControl* const first = client_control->channels;
  ChannelControl* const last = first + client_control->channels_count;
  for (ChannelControl* ch = first; ch != last; ++ch) {
    ::CloseHandle(ch->ping_event);
    ::CloseHandle(ch->pong_event);
  }
}

TEST(IPCTest, ChannelMaker) {
  // Verifies the test rig's layout math: 4096 bytes carved into 768-byte
  // channels gives 5 channels, and the first channel's data area starts at
  // offset 108 on 32-bit builds and 216 on 64-bit builds.
  const size_t kExpectedChannels = 5;
#if defined(_WIN64)
  const size_t kExpectedChannelStart = 216;
#else
  const size_t kExpectedChannelStart = 108;
#endif

  size_t channel_start = 0;
  IPCControl* client_control = MakeChannels(12 * 64, 4096, &channel_start);
  ASSERT_TRUE(NULL != client_control);
  EXPECT_EQ(kExpectedChannels, client_control->channels_count);
  EXPECT_EQ(kExpectedChannelStart, channel_start);
  delete[] reinterpret_cast<char*>(client_control);
}

TEST(IPCTest, ClientLockUnlock) {
  // Make 7 channels of kIPCChannelSize (1kb) each. Test that we lock and
  // unlock channels properly.
  size_t base_start = 0;
  IPCControl* client_control =
      MakeChannels(kIPCChannelSize, 4096 * 2, &base_start);
  FixChannels(client_control, base_start, kIPCChannelSize, FIX_NO_EVENTS);

  char* mem = reinterpret_cast<char*>(client_control);
  SharedMemIPCClient client(mem);

  // Test that we lock the first 3 channels in sequence.
  void* buff0 = client.GetBuffer();
  EXPECT_TRUE(mem + client_control->channels[0].channel_base == buff0);
  EXPECT_EQ(kBusyChannel, client_control->channels[0].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[1].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[2].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[3].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[4].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[5].state);

  void* buff1 = client.GetBuffer();
  EXPECT_TRUE(mem + client_control->channels[1].channel_base == buff1);
  EXPECT_EQ(kBusyChannel, client_control->channels[0].state);
  EXPECT_EQ(kBusyChannel, client_control->channels[1].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[2].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[3].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[4].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[5].state);

  void* buff2 = client.GetBuffer();
  EXPECT_TRUE(mem + client_control->channels[2].channel_base == buff2);
  EXPECT_EQ(kBusyChannel, client_control->channels[0].state);
  EXPECT_EQ(kBusyChannel, client_control->channels[1].state);
  EXPECT_EQ(kBusyChannel, client_control->channels[2].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[3].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[4].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[5].state);

  // Test that we unlock and re-lock the right channel.
  client.FreeBuffer(buff1);
  EXPECT_EQ(kBusyChannel, client_control->channels[0].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[1].state);
  EXPECT_EQ(kBusyChannel, client_control->channels[2].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[3].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[4].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[5].state);

  void* buff2b = client.GetBuffer();
  EXPECT_TRUE(mem + client_control->channels[1].channel_base == buff2b);
  EXPECT_EQ(kBusyChannel, client_control->channels[0].state);
  EXPECT_EQ(kBusyChannel, client_control->channels[1].state);
  EXPECT_EQ(kBusyChannel, client_control->channels[2].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[3].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[4].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[5].state);

  client.FreeBuffer(buff0);
  EXPECT_EQ(kFreeChannel, client_control->channels[0].state);
  EXPECT_EQ(kBusyChannel, client_control->channels[1].state);
  EXPECT_EQ(kBusyChannel, client_control->channels[2].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[3].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[4].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[5].state);

  delete[] reinterpret_cast<char*>(client_control);
}

TEST(IPCTest, CrossCallStrPacking) {
  // Exercises CrossCall with combinations of NULL and non-NULL wide-string
  // parameters and verifies that CrossCallParamsEx unpacks them correctly.
  //
  // No server thread runs. server_alive is a non-NULL placeholder handle and
  // FIX_PONG_READY creates every pong event already signaled, which
  // presumably lets each CrossCall complete without a real server.
  //
  // After each CrossCall the test calls GetBuffer(). GetBuffer() picks the
  // lowest free channel (see ClientLockUnlock). NOTE(review): this relies on
  // CrossCall having released the channel it just used, so GetBuffer()
  // returns the serialized request. The buffers locked here are never freed,
  // so each step moves on to the next channel.
  size_t base_start = 0;
  IPCControl* client_control =
      MakeChannels(kIPCChannelSize, 4096 * 4, &base_start);
  client_control->server_alive = HANDLE(1);
  FixChannels(client_control, base_start, kIPCChannelSize, FIX_PONG_READY);

  char* mem = reinterpret_cast<char*>(client_control);
  SharedMemIPCClient client(mem);

  CrossCallReturn answer;
  uint32 tag1 = 666;
  const wchar_t text[] = L"98765 - 43210";
  std::wstring copied_text;
  CrossCallParamsEx* actual_params;

  // A single non-NULL string round-trips unchanged.
  CrossCall(client, tag1, text, &answer);
  actual_params = reinterpret_cast<CrossCallParamsEx*>(client.GetBuffer());
  EXPECT_EQ(1, actual_params->GetParamsCount());
  EXPECT_EQ(tag1, actual_params->GetTag());
  EXPECT_TRUE(actual_params->GetParameterStr(0, &copied_text));
  EXPECT_STREQ(text, copied_text.c_str());

  // Check with a NULL string: it is packed as a zero-size WCHAR_TYPE
  // parameter that still has a non-NULL address.
  uint32 tag2 = 777;
  const wchar_t* null_text = NULL;
  CrossCall(client, tag2, null_text, &answer);
  actual_params = reinterpret_cast<CrossCallParamsEx*>(client.GetBuffer());
  EXPECT_EQ(1, actual_params->GetParamsCount());
  EXPECT_EQ(tag2, actual_params->GetTag());
  // Seed the out-params with non-matching values so the checks below prove
  // GetRawParameter actually wrote them.
  size_t param_size = 1;
  ArgType type = INVALID_TYPE;
  void* param_addr = actual_params->GetRawParameter(0, &param_size, &type);
  EXPECT_TRUE(NULL != param_addr);
  EXPECT_EQ(0, param_size);
  EXPECT_EQ(WCHAR_TYPE, type);
  EXPECT_TRUE(actual_params->GetParameterStr(0, &copied_text));

  uint32 tag3 = 888;
  param_size = 1;
  copied_text.clear();

  // Check with a NULL string followed by a non-NULL string.
  CrossCall(client, tag3, null_text, text, &answer);
  actual_params = reinterpret_cast<CrossCallParamsEx*>(client.GetBuffer());
  EXPECT_EQ(2, actual_params->GetParamsCount());
  EXPECT_EQ(tag3, actual_params->GetTag());
  type = INVALID_TYPE;
  param_addr = actual_params->GetRawParameter(0, &param_size, &type);
  EXPECT_TRUE(NULL != param_addr);
  EXPECT_EQ(0, param_size);
  EXPECT_EQ(WCHAR_TYPE, type);
  EXPECT_TRUE(actual_params->GetParameterStr(0, &copied_text));
  EXPECT_TRUE(actual_params->GetParameterStr(1, &copied_text));
  EXPECT_STREQ(text, copied_text.c_str());

  param_size = 1;
  std::wstring copied_text_p0, copied_text_p2;

  // Check a NULL string sandwiched between two non-NULL strings.
  const wchar_t text2[] = L"AeFG";
  CrossCall(client, tag1, text2, null_text, text, &answer);
  actual_params = reinterpret_cast<CrossCallParamsEx*>(client.GetBuffer());
  EXPECT_EQ(3, actual_params->GetParamsCount());
  EXPECT_EQ(tag1, actual_params->GetTag());
  EXPECT_TRUE(actual_params->GetParameterStr(0, &copied_text_p0));
  EXPECT_STREQ(text2, copied_text_p0.c_str());
  EXPECT_TRUE(actual_params->GetParameterStr(2, &copied_text_p2));
  EXPECT_STREQ(text, copied_text_p2.c_str());
  type = INVALID_TYPE;
  param_addr = actual_params->GetRawParameter(1, &param_size, &type);
  EXPECT_TRUE(NULL != param_addr);
  EXPECT_EQ(0, param_size);
  EXPECT_EQ(WCHAR_TYPE, type);

  CloseChannelEvents(client_control);
  delete[] reinterpret_cast<char*>(client_control);
}

TEST(IPCTest, CrossCallIntPacking) {
  // Checks how CrossCall packs DWORDs (ULONG_TYPE) and HANDLEs (VOIDPTR_TYPE),
  // alone and mixed, by reading the serialized request back with
  // GetRawParameter.
  //
  // As in CrossCallStrPacking, no server runs: server_alive is a placeholder
  // and the pong events start signaled (FIX_PONG_READY). Each GetBuffer()
  // below is expected to return the channel the preceding CrossCall just
  // used (lowest free channel). NOTE(review): this relies on CrossCall
  // releasing its channel when it returns.
  size_t base_start = 0;
  IPCControl* client_control =
      MakeChannels(kIPCChannelSize, 4096 * 4, &base_start);
  client_control->server_alive = HANDLE(1);
  FixChannels(client_control, base_start, kIPCChannelSize, FIX_PONG_READY);

  uint32 tag1 = 999;
  uint32 tag2 = 111;
  const wchar_t text[] = L"godzilla";
  CrossCallParamsEx* actual_params;

  char* mem = reinterpret_cast<char*>(client_control);
  SharedMemIPCClient client(mem);

  // A lone 32-bit DWORD is stored byte-for-byte as ULONG_TYPE.
  CrossCallReturn answer;
  DWORD dw = 0xE6578;
  CrossCall(client, tag2, dw, &answer);
  actual_params = reinterpret_cast<CrossCallParamsEx*>(client.GetBuffer());
  EXPECT_EQ(1, actual_params->GetParamsCount());
  EXPECT_EQ(tag2, actual_params->GetTag());
  ArgType type = INVALID_TYPE;
  size_t param_size = 1;
  void* param_addr = actual_params->GetRawParameter(0, &param_size, &type);
  ASSERT_EQ(sizeof(dw), param_size);
  EXPECT_EQ(ULONG_TYPE, type);
  ASSERT_TRUE(NULL != param_addr);
  EXPECT_EQ(0, memcmp(&dw, param_addr, param_size));

  // Check handling for windows HANDLES: stored pointer-sized as VOIDPTR_TYPE.
  HANDLE h = HANDLE(0x70000500);
  CrossCall(client, tag1, text, h, &answer);
  actual_params = reinterpret_cast<CrossCallParamsEx*>(client.GetBuffer());
  EXPECT_EQ(2, actual_params->GetParamsCount());
  EXPECT_EQ(tag1, actual_params->GetTag());
  type = INVALID_TYPE;
  param_addr = actual_params->GetRawParameter(1, &param_size, &type);
  ASSERT_EQ(sizeof(h), param_size);
  EXPECT_EQ(VOIDPTR_TYPE, type);
  ASSERT_TRUE(NULL != param_addr);
  EXPECT_EQ(0, memcmp(&h, param_addr, param_size));

  // Mix pointer-sized HANDLEs with a 32-bit DWORD (64/32/64 bits on Win64)
  // and check that every parameter keeps its own size and type.
  CrossCall(client, tag2, h, dw, h, &answer);
  actual_params = reinterpret_cast<CrossCallParamsEx*>(client.GetBuffer());
  EXPECT_EQ(3, actual_params->GetParamsCount());
  EXPECT_EQ(tag2, actual_params->GetTag());
  type = INVALID_TYPE;
  param_addr = actual_params->GetRawParameter(0, &param_size, &type);
  ASSERT_EQ(sizeof(h), param_size);
  EXPECT_EQ(VOIDPTR_TYPE, type);
  ASSERT_TRUE(NULL != param_addr);
  EXPECT_EQ(0, memcmp(&h, param_addr, param_size));
  type = INVALID_TYPE;
  param_addr = actual_params->GetRawParameter(1, &param_size, &type);
  ASSERT_EQ(sizeof(dw), param_size);
  EXPECT_EQ(ULONG_TYPE, type);
  ASSERT_TRUE(NULL != param_addr);
  EXPECT_EQ(0, memcmp(&dw, param_addr, param_size));
  type = INVALID_TYPE;
  param_addr = actual_params->GetRawParameter(2, &param_size, &type);
  ASSERT_EQ(sizeof(h), param_size);
  EXPECT_EQ(VOIDPTR_TYPE, type);
  ASSERT_TRUE(NULL != param_addr);
  EXPECT_EQ(0, memcmp(&h, param_addr, param_size));

  CloseChannelEvents(client_control);
  delete[] reinterpret_cast<char*>(client_control);
}

TEST(IPCTest, CrossCallValidation) {
  // Checks that CrossCallParamsEx::CreateFromBuffer accepts a well-formed
  // cross-call buffer and rejects malformed ones (overflowing parameter
  // counts, wrong declared sizes).
  //
  // A non-NULL result from CreateFromBuffer is a separate allocation (it
  // never aliases the input buffer). It is released with delete[] on a char
  // pointer. delete[] on NULL is a no-op, so every result is released
  // unconditionally and a failing expectation does not also leak memory.

  // First a sanity test with a well formed parameter object.
  unsigned long value = 124816;
  const uint32 kTag = 33;
  ActualCallParams<1, 256> params_1(kTag);
  params_1.CopyParamIn(0, &value, sizeof(value), false, ULONG_TYPE);
  void* buffer = const_cast<void*>(params_1.GetBuffer());

  size_t out_size = 0;
  CrossCallParamsEx* ccp = 0;
  ccp = CrossCallParamsEx::CreateFromBuffer(buffer, params_1.GetSize(),
                                            &out_size);
  ASSERT_TRUE(NULL != ccp);
  EXPECT_TRUE(ccp->GetBuffer() != buffer);
  EXPECT_EQ(kTag, ccp->GetTag());
  EXPECT_EQ(1, ccp->GetParamsCount());
  delete[] (reinterpret_cast<char*>(ccp));

#if defined(NDEBUG)
  // Test that we handle integer overflow on the number of params
  // correctly. We use a test-only ctor for ActualCallParams that
  // allows to create malformed cross-call buffers.
  const int32 kPtrDiffSz = sizeof(ptrdiff_t);
  for (int32 ix = -1; ix != 3; ++ix) {
    uint32 fake_num_params = (kuint32max / kPtrDiffSz) + ix;
    ActualCallParams<1, 256> params_2(kTag, fake_num_params);
    params_2.CopyParamIn(0, &value, sizeof(value), false, ULONG_TYPE);
    buffer = const_cast<void*>(params_2.GetBuffer());

    EXPECT_TRUE(NULL != buffer);
    ccp = CrossCallParamsEx::CreateFromBuffer(buffer, params_2.GetSize(),
                                              &out_size);
    // If the buffer is malformed the return is NULL.
    EXPECT_TRUE(NULL == ccp);
    delete[] (reinterpret_cast<char*>(ccp));
  }
#endif  // defined(NDEBUG)

  ActualCallParams<1, 256> params_3(kTag, 1);
  params_3.CopyParamIn(0, &value, sizeof(value), false, ULONG_TYPE);
  buffer = const_cast<void*>(params_3.GetBuffer());
  EXPECT_TRUE(NULL != buffer);

  // A declared size that is far too small must be rejected.
  size_t correct_size = params_3.OverrideSize(1);
  ccp = CrossCallParamsEx::CreateFromBuffer(buffer, 256, &out_size);
  EXPECT_TRUE(NULL == ccp);
  delete[] (reinterpret_cast<char*>(ccp));

  // The correct_size is 8 bytes aligned, so correct_size - 7 is a
  // misaligned/short size and must be rejected too.
  params_3.OverrideSize(correct_size - 7);
  ccp = CrossCallParamsEx::CreateFromBuffer(buffer, 256, &out_size);
  EXPECT_TRUE(NULL == ccp);
  delete[] (reinterpret_cast<char*>(ccp));

  // Restoring the real size makes the buffer valid again. The returned
  // object must be released as well; it was leaked before.
  params_3.OverrideSize(correct_size);
  ccp = CrossCallParamsEx::CreateFromBuffer(buffer, 256, &out_size);
  EXPECT_TRUE(NULL != ccp);
  delete[] (reinterpret_cast<char*>(ccp));
}

// Arguments handed (through the CreateThread parameter) to the mock server
// thread procedures below, so they can play the server side of a channel.
struct ServerEvents {
  // Event the server thread waits on before answering (a channel's
  // ping_event).
  HANDLE ping;
  // Event the server thread signals after answering (a channel's pong_event).
  HANDLE pong;
  // Channel state word the server thread sets to kAckChannel before
  // signaling |pong|.
  volatile LONG* state;
  // Mutex acquired and then held by MainServerThread (the tests pass the
  // server_alive mutex here).
  HANDLE mutex;
};

// Mock server thread that answers a single IPC right away: it blocks until
// the channel's ping event fires, marks the channel kAckChannel, signals the
// pong event and exits. Returns the result of the wait on the ping event.
DWORD WINAPI QuickResponseServer(PVOID param) {
  ServerEvents* const ev = static_cast<ServerEvents*>(param);
  const DWORD ping_result = ::WaitForSingleObject(ev->ping, INFINITE);
  ::InterlockedExchange(ev->state, kAckChannel);
  ::SetEvent(ev->pong);
  return ping_result;
}

// Minimal concrete CrossCallParams that the tests construct with placement
// new directly inside a channel buffer before calling
// SharedMemIPCClient::DoCall.
class CrossCallParamsMock : public CrossCallParams {
 public:
  // Forwards |tag| and |params_count| to the CrossCallParams base.
  CrossCallParamsMock(uint32 tag, size_t params_count)
      :  CrossCallParams(tag, params_count) {
  }
 private:
  // Never read in this file. NOTE(review): presumably reserves room after
  // the base header for parameter entries; confirm against CrossCallParams'
  // layout before removing it, since it changes the object's size.
  void* params[4];
};

void FakeOkAnswerInChannel(void* channel) {
  CrossCallReturn* answer = reinterpret_cast<CrossCallReturn*>(channel);
  answer->call_outcome = SBOX_ALL_OK;
}

// Creates two threads that quickly answer IPCs. The first one serves
// channel 1 (the test keeps channel 0 busy) and the second one serves
// channel 0. No time-out should occur.
TEST(IPCTest, ClientFastServer) {
  const size_t channel_size = kIPCChannelSize;
  size_t base_start = 0;
  IPCControl* client_control =
      MakeChannels(channel_size, 4096 * 2, &base_start);
  FixChannels(client_control, base_start, kIPCChannelSize, FIX_PONG_NOT_READY);
  client_control->server_alive = ::CreateMutex(NULL, FALSE, NULL);

  char* mem = reinterpret_cast<char*>(client_control);
  SharedMemIPCClient client(mem);

  // |events| is shared with the mock server threads, which read it as soon
  // as they start running. Only modify it when no thread can still be
  // reading it.
  ServerEvents events = {0};
  events.ping = client_control->channels[1].ping_event;
  events.pong = client_control->channels[1].pong_event;
  events.state = &client_control->channels[1].state;

  HANDLE t1 = ::CreateThread(NULL, 0, QuickResponseServer, &events, 0, NULL);
  ASSERT_TRUE(NULL != t1);
  ::CloseHandle(t1);

  void* buff0 = client.GetBuffer();
  EXPECT_TRUE(mem + client_control->channels[0].channel_base == buff0);
  EXPECT_EQ(kBusyChannel, client_control->channels[0].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[1].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[2].state);

  void* buff1 = client.GetBuffer();
  EXPECT_TRUE(mem + client_control->channels[1].channel_base == buff1);
  EXPECT_EQ(kBusyChannel, client_control->channels[0].state);
  EXPECT_EQ(kBusyChannel, client_control->channels[1].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[2].state);

  EXPECT_EQ(0, client_control->channels[1].ipc_tag);

  uint32 tag = 7654;
  CrossCallReturn answer;
  CrossCallParamsMock* params1 = new(buff1) CrossCallParamsMock(tag, 1);
  FakeOkAnswerInChannel(buff1);

  ResultCode result = client.DoCall(params1, &answer);
  if (SBOX_ERROR_CHANNEL_ERROR != result)
    client.FreeBuffer(buff1);

  EXPECT_TRUE(SBOX_ALL_OK == result);
  EXPECT_EQ(tag, client_control->channels[1].ipc_tag);
  EXPECT_EQ(kBusyChannel, client_control->channels[0].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[1].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[2].state);

  // Retarget |events| at channel 0 BEFORE starting the second server thread.
  // Updating it after CreateThread would race with the new thread, which
  // could pick up channel 1's ping event and never answer the call below.
  // The first thread is the only one that signals channel 1's pong event
  // (FIX_PONG_NOT_READY). Assuming DoCall only succeeds once that happens,
  // the first thread has already read everything it needs from |events|.
  events.ping = client_control->channels[0].ping_event;
  events.pong = client_control->channels[0].pong_event;
  events.state = &client_control->channels[0].state;

  HANDLE t2 = ::CreateThread(NULL, 0, QuickResponseServer, &events, 0, NULL);
  ASSERT_TRUE(NULL != t2);
  ::CloseHandle(t2);

  client.FreeBuffer(buff0);

  tag = 4567;
  CrossCallParamsMock* params2 = new(buff0) CrossCallParamsMock(tag, 1);
  FakeOkAnswerInChannel(buff0);

  result = client.DoCall(params2, &answer);
  if (SBOX_ERROR_CHANNEL_ERROR != result)
    client.FreeBuffer(buff0);

  EXPECT_TRUE(SBOX_ALL_OK == result);
  EXPECT_EQ(tag, client_control->channels[0].ipc_tag);
  EXPECT_EQ(kFreeChannel, client_control->channels[0].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[1].state);
  EXPECT_EQ(kFreeChannel, client_control->channels[2].state);

  CloseChannelEvents(client_control);
  ::CloseHandle(client_control->server_alive);

  delete[] reinterpret_cast<char*>(client_control);
}

// Mock server thread that answers late. Once the channel's ping event fires
// it sleeps for kIPCWaitTimeOut1 + kIPCWaitTimeOut2 + 200 ms. It then marks
// the channel kAckChannel and signals the pong event exactly once. The delay
// is presumably chosen to outlast the client's first IPC wait so that the
// client's time-out path runs. Returns the result of the wait on the ping
// event.
DWORD WINAPI SlowResponseServer(PVOID param) {
  ServerEvents* const ev = static_cast<ServerEvents*>(param);
  const DWORD ping_result = ::WaitForSingleObject(ev->ping, INFINITE);
  ::Sleep(kIPCWaitTimeOut1 + kIPCWaitTimeOut2 + 200);
  ::InterlockedExchange(ev->state, kAckChannel);
  ::SetEvent(ev->pong);
  return ping_result;
}

// Stand-in for the server's main thread. It acquires |mutex| (the
// server_alive mutex in the tests) and keeps owning it while sleeping for
// kIPCWaitTimeOut1 * 20 ms. The mutex is never released explicitly; Windows
// marks it abandoned when the thread exits. Returns the result of the wait
// that acquired the mutex.
DWORD WINAPI MainServerThread(PVOID param) {
  ServerEvents* const ev = static_cast<ServerEvents*>(param);
  const DWORD acquire_result = ::WaitForSingleObject(ev->mutex, INFINITE);
  ::Sleep(kIPCWaitTimeOut1 * 20);
  return acquire_result;
}

// Starts a mock server thread (SlowResponseServer) that answers so slowly
// that the client's wait is expected to time out. A second thread
// (MainServerThread) holds the server_alive mutex. The held mutex tells the
// client the server is still alive, so the client keeps waiting and the call
// must still succeed.
TEST(IPCTest, ClientSlowServer) {
  size_t base_start = 0;
  IPCControl* client_control =
      MakeChannels(kIPCChannelSize, 4096*2, &base_start);
  FixChannels(client_control, base_start, kIPCChannelSize, FIX_PONG_NOT_READY);
  client_control->server_alive = ::CreateMutex(NULL, FALSE, NULL);

  char* mem = reinterpret_cast<char*>(client_control);
  SharedMemIPCClient client(mem);

  // The slow server serves channel 0.
  ServerEvents events = {0};
  events.ping = client_control->channels[0].ping_event;
  events.pong = client_control->channels[0].pong_event;
  events.state = &client_control->channels[0].state;

  HANDLE t1 = ::CreateThread(NULL, 0, SlowResponseServer, &events, 0, NULL);
  ASSERT_TRUE(NULL != t1);
  ::CloseHandle(t1);

  // MainServerThread only reads |mutex|; |pong| is set but not used by it.
  ServerEvents events2 = {0};
  events2.pong = events.pong;
  events2.mutex = client_control->server_alive;

  HANDLE t2 = ::CreateThread(NULL, 0, MainServerThread, &events2, 0, NULL);
  ASSERT_TRUE(NULL != t2);
  ::CloseHandle(t2);

  // Presumably gives MainServerThread a chance to grab server_alive before
  // the call starts. NOTE(review): this is a scheduling hint, not a
  // guarantee.
  ::Sleep(1);

  void* buff0 = client.GetBuffer();
  uint32 tag = 4321;
  CrossCallReturn answer;
  CrossCallParamsMock* params1 = new(buff0) CrossCallParamsMock(tag, 1);
  FakeOkAnswerInChannel(buff0);

  // The channel is only released when DoCall did not report a channel error.
  ResultCode result = client.DoCall(params1, &answer);
  if (SBOX_ERROR_CHANNEL_ERROR != result)
    client.FreeBuffer(buff0);

  EXPECT_TRUE(SBOX_ALL_OK == result);
  EXPECT_EQ(tag, client_control->channels[0].ipc_tag);
  EXPECT_EQ(kFreeChannel, client_control->channels[0].state);

  CloseChannelEvents(client_control);
  ::CloseHandle(client_control->server_alive);
  delete[] reinterpret_cast<char*>(client_control);
}

// Test-only dispatcher that registers two IPC calls with identical
// (HANDLE, DWORD) signatures but different tags. The tests only send
// CALL_ONE_TAG, so CallOneHandler is the one expected to run.
class UnitTestIPCDispatcher : public Dispatcher {
 public:
  // Tags registered by the constructor.
  enum {
    CALL_ONE_TAG = 78,
    CALL_TWO_TAG = 87
  };

  UnitTestIPCDispatcher();
  ~UnitTestIPCDispatcher() {}

  // No interception setup is needed for these tests; always reports success.
  virtual bool SetupService(InterceptionManager* manager, int service) {
    return true;
  }

 private:
  // Echoes |handle| into extended[0] and |value| into extended[1] of the
  // IPC's return info, then reports success.
  bool CallOneHandler(IPCInfo* ipc, HANDLE handle, DWORD value) {
    ipc->return_info.extended[0].handle = handle;
    ipc->return_info.extended[1].unsigned_int = value;
    return true;
  }

  // Accepts the call and does nothing else.
  bool CallTwoHandler(IPCInfo* ipc, HANDLE handle, DWORD value) {
    return true;
  }
};

// Registers CALL_ONE_TAG and then CALL_TWO_TAG. Both take a VOIDPTR_TYPE
// followed by a ULONG_TYPE parameter and differ only in their handler.
UnitTestIPCDispatcher::UnitTestIPCDispatcher() {
  static const IPCCall kCallOne = {
    {CALL_ONE_TAG, VOIDPTR_TYPE, ULONG_TYPE},
    reinterpret_cast<CallbackGeneric>(&UnitTestIPCDispatcher::CallOneHandler)
  };
  ipc_calls_.push_back(kCallOne);

  static const IPCCall kCallTwo = {
    {CALL_TWO_TAG, VOIDPTR_TYPE, ULONG_TYPE},
    reinterpret_cast<CallbackGeneric>(&UnitTestIPCDispatcher::CallTwoHandler)
  };
  ipc_calls_.push_back(kCallTwo);
}

// Covers most of the shared-memory IPC round trip without real threads. The
// client serializes a CALL_ONE_TAG request carrying a HANDLE and a DWORD.
// That request buffer is passed straight to
// SharedMemIPCServer::InvokeCallback, and the dispatcher's CallOneHandler
// must echo both values back in the CrossCallReturn.
TEST(IPCTest, SharedMemServerTests) {
  size_t base_start = 0;
  IPCControl* client_control =
      MakeChannels(kIPCChannelSize, 4096, &base_start);
  client_control->server_alive = HANDLE(1);
  FixChannels(client_control, base_start, kIPCChannelSize, FIX_PONG_READY);

  char* const shared_mem = reinterpret_cast<char*>(client_control);
  SharedMemIPCClient client(shared_mem);

  CrossCallReturn answer;
  HANDLE handle_arg = HANDLE(191919);
  DWORD dword_arg = 6767676;
  CrossCall(client, UnitTestIPCDispatcher::CALL_ONE_TAG, handle_arg, dword_arg,
            &answer);
  // Re-acquire the channel the request was just serialized into.
  void* request = client.GetBuffer();
  ASSERT_TRUE(NULL != request);

  UnitTestIPCDispatcher dispatcher;
  // InvokeCallback is called directly, so only a few ServerControl fields
  // get real values and the rest are left NULL.
  SharedMemIPCServer::ServerControl srv_control = {
      NULL, NULL, kIPCChannelSize, NULL,
      reinterpret_cast<char*>(client_control),
      NULL, &dispatcher, {0} };

  CrossCallReturn call_return = {0};
  EXPECT_TRUE(SharedMemIPCServer::InvokeCallback(&srv_control, request,
                                                 &call_return));
  EXPECT_EQ(SBOX_ALL_OK, call_return.call_outcome);
  EXPECT_TRUE(handle_arg == call_return.extended[0].handle);
  EXPECT_EQ(dword_arg, call_return.extended[1].unsigned_int);

  CloseChannelEvents(client_control);
  delete[] reinterpret_cast<char*>(client_control);
}

}  // namespace sandbox