// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/policy/cloud/cloud_policy_invalidator.h"

#include "base/bind.h"
#include "base/hash.h"
#include "base/location.h"
#include "base/metrics/histogram.h"
#include "base/rand_util.h"
#include "base/sequenced_task_runner.h"
#include "base/time/clock.h"
#include "base/time/time.h"
#include "base/values.h"
#include "components/invalidation/invalidation_service.h"
#include "components/policy/core/common/cloud/cloud_policy_client.h"
#include "components/policy/core/common/cloud/cloud_policy_refresh_scheduler.h"
#include "components/policy/core/common/cloud/enterprise_metrics.h"
#include "policy/policy_constants.h"
#include "sync/notifier/object_id_invalidation_map.h"

namespace policy {

// Extra delay, in minutes, added before refreshing policy when an invalidation
// arrives without a payload.
const int CloudPolicyInvalidator::kMissingPayloadDelay = 5;
// Default, lower and upper bounds, in milliseconds, for the maximum random
// delay applied before fetching policy after an invalidation. Values coming
// from the MaxInvalidationFetchDelay policy are clamped to [Min, Max].
const int CloudPolicyInvalidator::kMaxFetchDelayDefault = 10000;
const int CloudPolicyInvalidator::kMaxFetchDelayMin = 1000;
const int CloudPolicyInvalidator::kMaxFetchDelayMax = 300000;
// Seconds invalidations must have been continuously enabled before refresh
// metrics report them as enabled.
const int CloudPolicyInvalidator::kInvalidationGracePeriod = 10;
// Seconds after the last policy fetch during which an invalidation with an
// unknown version is treated as expired.
const int CloudPolicyInvalidator::kUnknownVersionIgnorePeriod = 30;
// Seconds of slack added to an invalidation's timestamp before comparing it
// with the last fetch time, so borderline invalidations count as not expired.
const int CloudPolicyInvalidator::kMaxInvalidationTimeDelta = 300;

// Creates an invalidator for |core| (must be non-null) in the UNINITIALIZED
// state. Delayed policy refreshes are posted to |task_runner| (must be
// non-null), and |clock| (ownership transferred) supplies the current time for
// expiry checks and metrics. Nothing happens until Initialize() is called.
// NOTE: the initializer list below presumably mirrors the member declaration
// order in the header; keep the two in sync to avoid -Wreorder warnings.
CloudPolicyInvalidator::CloudPolicyInvalidator(
    CloudPolicyCore* core,
    const scoped_refptr<base::SequencedTaskRunner>& task_runner,
    scoped_ptr<base::Clock> clock)
    : state_(UNINITIALIZED),
      core_(core),
      task_runner_(task_runner),
      clock_(clock.Pass()),
      invalidation_service_(NULL),
      invalidations_enabled_(false),
      invalidation_service_enabled_(false),
      is_registered_(false),
      invalid_(false),
      invalidation_version_(0),
      unknown_version_invalidation_count_(0),
      weak_factory_(this),
      max_fetch_delay_(kMaxFetchDelayDefault),
      policy_hash_value_(0) {
  DCHECK(core);
  DCHECK(task_runner.get());
}

// Shutdown() must have been called before the invalidator is destroyed, so
// that all observer and handler registrations have already been removed.
CloudPolicyInvalidator::~CloudPolicyInvalidator() {
  DCHECK(state_ == SHUT_DOWN);
}

// Binds the invalidator to |invalidation_service| (must be non-null; it is
// stored as a raw pointer and never deleted here, so presumably not owned) and
// begins observing |core_|. Moves UNINITIALIZED -> STOPPED, and on to STARTED
// straight away if the core's refresh scheduler already exists.
void CloudPolicyInvalidator::Initialize(
    invalidation::InvalidationService* invalidation_service) {
  DCHECK(state_ == UNINITIALIZED);
  DCHECK(thread_checker_.CalledOnValidThread());
  DCHECK(invalidation_service);
  invalidation_service_ = invalidation_service;
  state_ = STOPPED;
  core_->AddObserver(this);
  // If the scheduler started before we began observing, we missed the
  // notification, so run the start-up path manually.
  if (core_->refresh_scheduler())
    OnRefreshSchedulerStarted(core_);
}

// Removes every registration made so far (invalidation handler, store and core
// observers), cancels pending refreshes and moves to SHUT_DOWN. Must be called
// exactly once, before destruction.
// Unlike Unregister(), this does not clear the registered object ids or
// acknowledge a pending invalidation.
void CloudPolicyInvalidator::Shutdown() {
  DCHECK(state_ != SHUT_DOWN);
  DCHECK(thread_checker_.CalledOnValidThread());
  if (state_ == STARTED) {
    if (is_registered_)
      invalidation_service_->UnregisterInvalidationHandler(this);
    core_->store()->RemoveObserver(this);
    // Cancel any RefreshPolicy() task still posted by HandleInvalidation().
    weak_factory_.InvalidateWeakPtrs();
  }
  // The core observer was only added by Initialize().
  if (state_ != UNINITIALIZED)
    core_->RemoveObserver(this);
  state_ = SHUT_DOWN;
}

// Invalidation-service callback. Records whether the service currently reports
// INVALIDATIONS_ENABLED, then recomputes the combined "invalidations enabled"
// state (which also depends on whether this handler is registered).
void CloudPolicyInvalidator::OnInvalidatorStateChange(
    syncer::InvalidatorState state) {
  DCHECK(state_ == STARTED);
  DCHECK(thread_checker_.CalledOnValidThread());
  const bool service_enabled = (state == syncer::INVALIDATIONS_ENABLED);
  invalidation_service_enabled_ = service_enabled;
  UpdateInvalidationsEnabled();
}

// Invalidation-service callback for the registered object id. Only the newest
// (highest-version) invalidation is acted on; every older one in the batch is
// acknowledged immediately, and the newest is passed to HandleInvalidation().
void CloudPolicyInvalidator::OnIncomingInvalidation(
    const syncer::ObjectIdInvalidationMap& invalidation_map) {
  DCHECK(state_ == STARTED);
  DCHECK(thread_checker_.CalledOnValidThread());
  const syncer::SingleObjectInvalidationSet& invalidations =
      invalidation_map.ForObject(object_id_);
  if (invalidations.IsEmpty()) {
    // A delivery with nothing for our object id is unexpected.
    NOTREACHED();
    return;
  }

  // Walk from newest to oldest, skipping the newest entry itself, and
  // acknowledge each superseded invalidation.
  syncer::SingleObjectInvalidationSet::const_reverse_iterator older =
      invalidations.rbegin();
  for (++older; older != invalidations.rend(); ++older)
    older->Acknowledge();

  // The newest invalidation is the one that drives a policy refresh.
  HandleInvalidation(invalidations.back());
}

// Owner name that identifies this invalidation handler.
std::string CloudPolicyInvalidator::GetOwnerName() const {
  return "Cloud";
}

// Core observer callback. Deliberately empty: the invalidator only starts
// working once the refresh scheduler has started (OnRefreshSchedulerStarted).
void CloudPolicyInvalidator::OnCoreConnected(CloudPolicyCore* core) {
}

// Core observer callback (also invoked from Initialize() when the scheduler
// already exists). Moves STOPPED -> STARTED, processes the store's current
// contents, then observes the store for future loads.
void CloudPolicyInvalidator::OnRefreshSchedulerStarted(CloudPolicyCore* core) {
  DCHECK(state_ == STOPPED);
  DCHECK(thread_checker_.CalledOnValidThread());
  state_ = STARTED;
  // Handle whatever policy is already loaded; this may register for
  // invalidations via UpdateRegistration().
  OnStoreLoaded(core_->store());
  core_->store()->AddObserver(this);
}

// Core observer callback for when |core| is disconnecting. If the invalidator
// has started, it unregisters from invalidations, stops observing the store and
// returns to STOPPED, so a later OnRefreshSchedulerStarted() can restart it.
void CloudPolicyInvalidator::OnCoreDisconnecting(CloudPolicyCore* core) {
  DCHECK(state_ == STARTED || state_ == STOPPED);
  DCHECK(thread_checker_.CalledOnValidThread());
  if (state_ != STARTED)
    return;  // Not started: nothing was set up, so nothing to undo.

  Unregister();
  core_->store()->RemoveObserver(this);
  state_ = STOPPED;
}

// Store observer callback (also invoked from OnRefreshSchedulerStarted()).
// Records the policy-refresh metric, acknowledges the pending invalidation once
// the stored policy reflects its version, then updates the invalidation
// registration and the max fetch delay from the loaded policy.
void CloudPolicyInvalidator::OnStoreLoaded(CloudPolicyStore* store) {
  DCHECK(state_ == STARTED);
  DCHECK(thread_checker_.CalledOnValidThread());
  // Always evaluated, even when unregistered, so the cached policy hash stays
  // in step with the store.
  bool policy_changed = IsPolicyChanged(store->policy());

  if (is_registered_) {
    // Update the kMetricPolicyRefresh histogram.
    // This must happen before AcknowledgeInvalidation() below, which clears
    // invalid_ and would change the reported bucket.
    UMA_HISTOGRAM_ENUMERATION(
        kMetricPolicyRefresh,
        GetPolicyRefreshMetric(policy_changed),
        METRIC_POLICY_REFRESH_SIZE);

    // If the policy was invalid and the version stored matches the latest
    // invalidation version, acknowledge the latest invalidation.
    if (invalid_ && store->invalidation_version() == invalidation_version_)
      AcknowledgeInvalidation();
  }

  UpdateRegistration(store->policy());
  UpdateMaxFetchDelay(store->policy_map());
}

// Store observer callback. Deliberately empty: a store error leaves the
// invalidation registration and any pending invalidation untouched.
void CloudPolicyInvalidator::OnStoreError(CloudPolicyStore* store) {
}

// Processes the newest invalidation for the registered object id:
//  - drops it if it is not newer than the currently pending invalidation;
//  - acknowledges any pending invalidation it supersedes;
//  - acknowledges and drops it if it is expired (see IsInvalidationExpired);
//  - otherwise makes it the pending invalidation and posts a RefreshPolicy()
//    task to |task_runner_| after a randomized delay.
void CloudPolicyInvalidator::HandleInvalidation(
    const syncer::Invalidation& invalidation) {
  // Ignore old invalidations.
  // NOTE(review): invalidations dropped here are never acknowledged, unlike
  // the expired ones below. Confirm this is intentional (for example, to avoid
  // acknowledging a redelivered copy of the pending invalidation early).
  if (invalid_ &&
      !invalidation.is_unknown_version() &&
      invalidation.version() <= invalidation_version_) {
    return;
  }

  // If there is still a pending invalidation, acknowledge it, since we only
  // care about the latest invalidation.
  if (invalid_)
    AcknowledgeInvalidation();

  // Get the version and payload from the invalidation.
  // When an invalidation with unknown version is received, use negative
  // numbers based on the number of such invalidations received. This
  // ensures that the version numbers do not collide with "real" versions
  // (which are positive) or previous invalidations with unknown version.
  int64 version;
  std::string payload;
  if (invalidation.is_unknown_version()) {
    version = -(++unknown_version_invalidation_count_);
  } else {
    version = invalidation.version();
    payload = invalidation.payload();
  }

  // Ignore the invalidation if it is expired.
  // The histogram is recorded for every invalidation that reaches this point,
  // expired or not.
  bool is_expired = IsInvalidationExpired(version);
  UMA_HISTOGRAM_ENUMERATION(
      kMetricPolicyInvalidations,
      GetInvalidationMetric(payload.empty(), is_expired),
      POLICY_INVALIDATION_TYPE_SIZE);
  if (is_expired) {
    invalidation.Acknowledge();
    return;
  }

  // Update invalidation state.
  // A copy is kept so it can be acknowledged later by
  // AcknowledgeInvalidation().
  invalid_ = true;
  invalidation_.reset(new syncer::Invalidation(invalidation));
  invalidation_version_ = version;

  // In order to prevent the cloud policy server from becoming overwhelmed when
  // a policy with many users is modified, delay for a random period of time
  // before fetching the policy. Delay for at least 20ms so that if multiple
  // invalidations are received in quick succession, only one fetch will be
  // performed.
  base::TimeDelta delay = base::TimeDelta::FromMilliseconds(
      base::RandInt(20, max_fetch_delay_));

  // If there is a payload, the policy can be refreshed at any time, so set
  // the version and payload on the client immediately. Otherwise, the refresh
  // must only run after at least kMissingPayloadDelay minutes.
  if (!payload.empty())
    core_->client()->SetInvalidationInfo(version, payload);
  else
    delay += base::TimeDelta::FromMinutes(kMissingPayloadDelay);

  // Schedule the policy to be refreshed.
  // The task is bound to a weak pointer so that AcknowledgeInvalidation() and
  // Shutdown() can cancel it via InvalidateWeakPtrs().
  task_runner_->PostDelayedTask(
      FROM_HERE,
      base::Bind(
          &CloudPolicyInvalidator::RefreshPolicy,
          weak_factory_.GetWeakPtr(),
          payload.empty() /* is_missing_payload */),
      delay);
}

// Registers for invalidations of the object id named by |policy|. Unregisters
// when |policy| is null or lacks an invalidation source or name. Register() is
// called only if not registered yet or the object id has changed.
void CloudPolicyInvalidator::UpdateRegistration(
    const enterprise_management::PolicyData* policy) {
  const bool specifies_object_id = policy &&
                                   policy->has_invalidation_source() &&
                                   policy->has_invalidation_name();
  if (!specifies_object_id) {
    Unregister();
    return;
  }

  invalidation::ObjectId wanted_id(policy->invalidation_source(),
                                   policy->invalidation_name());

  // Re-register only when something actually changed.
  const bool needs_registration =
      !is_registered_ || !(wanted_id == object_id_);
  if (needs_registration)
    Register(wanted_id);
}

// Registers this handler with the invalidation service (if not registered yet)
// and switches the registered object id to |object_id|, acknowledging any
// pending invalidation first.
void CloudPolicyInvalidator::Register(const invalidation::ObjectId& object_id) {
  // Register this handler with the invalidation service if needed.
  // The service's current state is read first. Because is_registered_ is still
  // false at this point, the UpdateInvalidationsEnabled() call inside
  // OnInvalidatorStateChange() cannot enable invalidations yet.
  if (!is_registered_) {
    OnInvalidatorStateChange(invalidation_service_->GetInvalidatorState());
    invalidation_service_->RegisterInvalidationHandler(this);
  }

  // Update internal state.
  // Any pending invalidation presumably belongs to the previously registered
  // object id, so it is acknowledged before switching.
  if (invalid_)
    AcknowledgeInvalidation();
  is_registered_ = true;
  object_id_ = object_id;
  UpdateInvalidationsEnabled();

  // Update registration with the invalidation service.
  syncer::ObjectIdSet ids;
  ids.insert(object_id);
  invalidation_service_->UpdateRegisteredInvalidationIds(this, ids);
}

void CloudPolicyInvalidator::Unregister() {
  if (is_registered_) {
    if (invalid_)
      AcknowledgeInvalidation();
    invalidation_service_->UpdateRegisteredInvalidationIds(
        this,
        syncer::ObjectIdSet());
    invalidation_service_->UnregisterInvalidationHandler(this);
    is_registered_ = false;
    UpdateInvalidationsEnabled();
  }
}

void CloudPolicyInvalidator::UpdateMaxFetchDelay(const PolicyMap& policy_map) {
  int delay;

  // Try reading the delay from the policy.
  const base::Value* delay_policy_value =
      policy_map.GetValue(key::kMaxInvalidationFetchDelay);
  if (delay_policy_value && delay_policy_value->GetAsInteger(&delay)) {
    set_max_fetch_delay(delay);
    return;
  }

  set_max_fetch_delay(kMaxFetchDelayDefault);
}

// Stores |delay| (milliseconds) as the maximum fetch delay, clamped to
// [kMaxFetchDelayMin, kMaxFetchDelayMax].
void CloudPolicyInvalidator::set_max_fetch_delay(int delay) {
  int clamped = delay;
  if (clamped > kMaxFetchDelayMax)
    clamped = kMaxFetchDelayMax;
  if (clamped < kMaxFetchDelayMin)
    clamped = kMaxFetchDelayMin;
  max_fetch_delay_ = clamped;
}

void CloudPolicyInvalidator::UpdateInvalidationsEnabled() {
  bool invalidations_enabled = invalidation_service_enabled_ && is_registered_;
  if (invalidations_enabled_ != invalidations_enabled) {
    invalidations_enabled_ = invalidations_enabled;
    if (invalidations_enabled)
      invalidations_enabled_time_ = clock_->Now();
    core_->refresh_scheduler()->SetInvalidationServiceAvailability(
        invalidations_enabled);
  }
}

// Delayed task posted by HandleInvalidation(). Asks the refresh scheduler to
// fetch policy soon. For payload-less invalidations, the invalidation version
// was withheld from the client until now, so it is set first.
void CloudPolicyInvalidator::RefreshPolicy(bool is_missing_payload) {
  DCHECK(thread_checker_.CalledOnValidThread());
  if (is_missing_payload) {
    // The required kMissingPayloadDelay has elapsed, so the version can now
    // be handed to the client (with an empty payload).
    core_->client()->SetInvalidationInfo(invalidation_version_, std::string());
  }
  core_->refresh_scheduler()->RefreshSoon();
}

// Acknowledges the pending invalidation, clears the invalidation info on the
// client and cancels any policy refresh scheduled for it. Must only be called
// while an invalidation is pending (invalid_ is true).
void CloudPolicyInvalidator::AcknowledgeInvalidation() {
  DCHECK(invalid_);
  invalid_ = false;
  // Version 0 with an empty payload resets the client's invalidation info.
  core_->client()->SetInvalidationInfo(0, std::string());
  invalidation_->Acknowledge();
  invalidation_.reset();
  // Cancel any scheduled policy refreshes.
  // This invalidates every weak pointer vended by weak_factory_, including the
  // one bound to RefreshPolicy() in HandleInvalidation().
  weak_factory_.InvalidateWeakPtrs();
}

// Returns true if the hash of |policy|'s policy_value differs from the hash
// seen on the previous call, and remembers the new hash. A null policy, or one
// without a policy_value, hashes to 0.
bool CloudPolicyInvalidator::IsPolicyChanged(
    const enterprise_management::PolicyData* policy) {
  const bool has_value = policy && policy->has_policy_value();
  const uint32 previous_hash = policy_hash_value_;
  policy_hash_value_ = has_value ? base::Hash(policy->policy_value()) : 0u;
  return policy_hash_value_ != previous_hash;
}

// Returns true if an invalidation with |version| should be ignored as stale
// relative to the last policy fetch (the stored policy's timestamp, in ms since
// the Unix epoch).
//  - Known versions (>= 0) are microsecond timestamps. They are expired if
//    they predate the last fetch, even after adding kMaxInvalidationTimeDelta
//    seconds of slack.
//  - Unknown versions (< 0) are expired if the last fetch happened less than
//    kUnknownVersionIgnorePeriod seconds ago.
// NOTE(review): assumes core_->store()->policy() is non-null. This presumably
// holds because invalidations are only registered while the store has a policy
// (see UpdateRegistration()).
bool CloudPolicyInvalidator::IsInvalidationExpired(int64 version) {
  const base::Time last_fetch_time =
      base::Time::UnixEpoch() +
      base::TimeDelta::FromMilliseconds(core_->store()->policy()->timestamp());

  if (version >= 0) {
    // Pad the invalidation time so borderline cases count as not expired.
    const base::Time padded_invalidation_time =
        base::Time::UnixEpoch() +
        base::TimeDelta::FromMicroseconds(version) +
        base::TimeDelta::FromSeconds(kMaxInvalidationTimeDelta);
    return padded_invalidation_time < last_fetch_time;
  }

  // Unknown version: treat it as expired if policy was fetched very recently.
  const base::TimeDelta since_fetch = clock_->Now() - last_fetch_time;
  return since_fetch.InSeconds() < kUnknownVersionIgnorePeriod;
}

// Chooses the kMetricPolicyRefresh bucket for a store load, based on whether
// the policy changed, whether an invalidation was pending, and (for changed,
// non-invalidated policy) whether invalidations are considered enabled.
int CloudPolicyInvalidator::GetPolicyRefreshMetric(bool policy_changed) {
  if (!policy_changed) {
    return invalid_ ? METRIC_POLICY_REFRESH_INVALIDATED_UNCHANGED
                    : METRIC_POLICY_REFRESH_UNCHANGED;
  }
  if (invalid_)
    return METRIC_POLICY_REFRESH_INVALIDATED_CHANGED;
  return GetInvalidationsEnabled()
             ? METRIC_POLICY_REFRESH_CHANGED
             : METRIC_POLICY_REFRESH_CHANGED_NO_INVALIDATIONS;
}

// Chooses the kMetricPolicyInvalidations bucket from whether the invalidation
// lacked a payload and whether it was expired.
int CloudPolicyInvalidator::GetInvalidationMetric(bool is_missing_payload,
                                                  bool is_expired) {
  if (is_missing_payload) {
    return is_expired ? POLICY_INVALIDATION_TYPE_NO_PAYLOAD_EXPIRED
                      : POLICY_INVALIDATION_TYPE_NO_PAYLOAD;
  }
  return is_expired ? POLICY_INVALIDATION_TYPE_EXPIRED
                    : POLICY_INVALIDATION_TYPE_NORMAL;
}

bool CloudPolicyInvalidator::GetInvalidationsEnabled() {
  if (!invalidations_enabled_)
    return false;
  // If invalidations have been enabled for less than the grace period, then
  // consider invalidations to be disabled for metrics reporting.
  base::TimeDelta elapsed = clock_->Now() - invalidations_enabled_time_;
  return elapsed.InSeconds() >= kInvalidationGracePeriod;
}

}  // namespace policy