/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// clang-format off
#include <map>
#include <tuple>
#include <vector>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>

#include "fbgemm_gpu/embedding_forward_split_cpu.h"
#include "fbgemm/FbgemmEmbedding.h"
#include "fbgemm_gpu/utils/cpu_utils.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"

// FBGEMM_MEM_CHECK_ONLY is meant to be used inside an attribute list, i.e.
// `[[FBGEMM_MEM_CHECK_ONLY]]`.
//  - With FBGEMM_GPU_MEMCHECK enabled it expands to nothing, leaving an empty
//    attribute list `[[]]`.
//  - Otherwise it expands to `maybe_unused`, which silences unused-variable
//    warnings on the annotated declaration.
// NOTE(review): presumably the annotated names are only consumed by
// MAKE_TA_WITH_NAME in memcheck builds -- confirm against that macro.
#if FBGEMM_GPU_MEMCHECK
#define FBGEMM_MEM_CHECK_ONLY
#else
#define FBGEMM_MEM_CHECK_ONLY maybe_unused
#endif

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

namespace {
// Generic (accessor-based) CPU backward kernel for the approximate optimizers.
//
// For each table `t` in [0, T), each bag `b` in [0, B) and each pooled
// position `p` in [offsets[t * B + b], offsets[t * B + b + 1]), the kernel
// fills `grad_buffer[0..D)` with the (optionally scaled / weighted) slice
// `grad_output[b][D_begin .. D_begin + D)` and then executes the
// optimizer-specific update that the code generator splices in through
// `split_weight_update_cpu`.
//
// Template parameters:
//   index_t   element type of `indices` and `offsets`
//   scalar_t  element type of `host_weights` (and of `grad_buffer`)
//   grad_t    element type of `grad_output`
//
// Parameters (only what the visible code relies on):
//   grad_output           2-D tensor, read as [b][D_begin + d]
//   host_weights          1-D tensor of weights
//   weights_offsets_data  per-table start offset into `host_weights`
//   D_offsets_data        per-table column offsets; D = D_offsets[t+1] - D_offsets[t]
//   indices, offsets      1-D tensors; `offsets` is indexed with t * B + b
//   pooling_mode          cast to PoolingMode; MEAN scales the gradient by 1 / L
//                         when no per-sample weights are given and L > 0
//   indice_weights        optional 1-D per-sample weights, indexed by p
//   T, B                  number of tables / bags per table
//
// Names made available to the injected update code include `idx`,
// `embedding_begin`, `feature_begin`, `table_begin`, `D` and `grad_buffer`.
// The bag loop is run through at::parallel_for, so the injected update code
// may execute concurrently for different bags of the same table.
template <typename index_t, typename scalar_t, typename grad_t>
void split_embedding_backward_approx_cpu_kernel(
    Tensor grad_output,
    Tensor host_weights,
    const at::TensorAccessor<int64_t, 1> weights_offsets_data,
    const at::TensorAccessor<int, 1> D_offsets_data,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    Tensor indice_weights,
    int T,
    int B,
    {% if "momentum1_offsets" in args.split_function_arg_names %}
    const at::TensorAccessor<int64_t, 1> momentum1_offsets_data,
    {% endif %}
    {% if "momentum2_offsets" in args.split_function_arg_names %}
    const at::TensorAccessor<int64_t, 1> momentum2_offsets_data,
    {% endif %}
    {{ args.split_cpu_kernel_args | join(", ") }}) {
  auto grad_output_data = grad_output.accessor<grad_t, 2>();
  auto host_weights_data = host_weights.accessor<scalar_t, 1>();

  [[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "split_embedding_backward_approx_cpu_kernel";
  const auto indices_data = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1);
  const auto offsets_data = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1);

  // If indice_weights are not defined, then this accessor won't be used
  auto indice_weights_data = indice_weights.defined()
      ? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
      : at::TensorAccessor<at::acc_type<scalar_t, true>, 1>(nullptr, nullptr, nullptr);

  for (const auto t : c10::irange(T)) {
    int feature_begin = t; // to conform interface with exact
    const auto D_begin = D_offsets_data[t];
    const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
    const auto table_begin = weights_offsets_data[t];
    at::parallel_for(0, B, 0, [&](int64_t b_begin, int64_t b_end) {
      // Scratch row for the current gradient. A variable-length array
      // (`scalar_t grad_buffer[D]`) is not standard C++, so the storage lives
      // in a vector that is allocated once per parallel chunk and fully
      // overwritten for every pooled index. `grad_buffer` stays a plain
      // pointer so the injected update code can keep indexing it.
      std::vector<scalar_t> grad_storage(static_cast<size_t>(D));
      scalar_t* const grad_buffer = grad_storage.data();

      for (const auto b : c10::irange(b_begin, b_end)) {
        const auto pool_begin = offsets_data[t * B + b];
        const auto pool_end = offsets_data[t * B + b + 1];
        const auto L = pool_end - pool_begin;
        const double scale_factor =
            // NOTE: MEAN pooling will not work with indice_weights!
            (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN && !indice_weights.defined() && L > 0)
            ? 1.0 / L
            : 1.0;
        for (const auto p : c10::irange(pool_begin, pool_end)) {
          auto idx = indices_data[p];
          // Start of row `idx` of table `t` inside the flat weights buffer.
          [[maybe_unused]] const int64_t embedding_begin = table_begin + idx * D;
          for (const auto d : c10::irange(D)) {
            grad_buffer[d] = scale_factor *
                (indice_weights.defined()
                     ? static_cast<scalar_t>(grad_output_data[b][D_begin + d] * indice_weights_data[p])
                     : static_cast<scalar_t>(grad_output_data[b][D_begin + d]));
          }
          {{ split_weight_update_cpu }};
        } // for each p
      } // for each b
    }); // parallel for B
  } // for each t
}
} // namespace

// The template for approximate optimizers
//
// Host-side entry point of the CPU backward pass for one approximate
// optimizer. It derives the table count T from `D_offsets` and the bag count
// B from `offsets` (laid out as [T x B + 1]), then updates `host_weights`
// (and any optimizer state passed through the generated argument list) in
// place.
//
// Two execution paths exist:
//  * approx_rowwise_adagrad only: a JIT-generated FBGEMM kernel
//    (GenerateRowWiseSparseAdaGradFused) is used when weights and gradients
//    are fp32, no per-sample weights are given, pooling is SUM, and every
//    tensor touched through a raw pointer is contiguous.
//  * otherwise: the accessor-based
//    split_embedding_backward_approx_cpu_kernel, dispatched on the index
//    type, the grad_output dtype and the host_weights dtype.
//
// `max_D`, `total_hash_size_bits`, `output_dtype` (and, for non-dense,
// `weights_placements` / `stochastic_rounding`) are not referenced by the code
// visible in this template.
{{ "void" if not dense else "Tensor" }}
split_embedding_backward_codegen_{{ optimizer }}_cpu(
    Tensor grad_output,
    Tensor host_weights,
    {% if not dense %}
    Tensor weights_placements,
    {% endif %}
    Tensor weights_offsets,
    Tensor D_offsets,
    int64_t max_D,
    Tensor hash_size_cumsum,
    int64_t total_hash_size_bits,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    Tensor indice_weights,
    {% if not dense %}
    bool stochastic_rounding,
    {% endif %}
    {{args.split_function_args | join(", ")}},
    int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)
) {
  int64_t T = D_offsets.numel() - 1;
  TORCH_CHECK_GT(T, 0);
  // offsets = [T x B  + 1]
  int64_t B = (offsets.size(0) - 1) / T;
  TORCH_CHECK_GE(B, 0);

  const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
  const auto D_offsets_data = D_offsets.accessor<int, 1>();
  {%if "momentum1_offsets" in args.split_function_arg_names %}
  const auto momentum1_offsets_data = momentum1_offsets.accessor<int64_t, 1>();
  {% endif %}
  {%if "momentum2_offsets" in args.split_function_arg_names %}
  const auto momentum2_offsets_data = momentum2_offsets.accessor<int64_t, 1>();
  {% endif %}

  TORCH_CHECK_EQ(host_weights.dim(), 1);

  {% if optimizer == "approx_rowwise_adagrad" %}

  // TODO: fp16 and weighted
  //
  // The fast path below walks grad_output, host_weights, momentum1_host,
  // indices and offsets through raw data pointers and uses
  // grad_output.size(1) as the row stride. That is only valid for contiguous
  // tensors, so strided inputs are routed to the accessor-based generic
  // kernel instead of being read or written at the wrong addresses.
  bool use_fbgemm =
      (host_weights.scalar_type() == at::ScalarType::Float/* ||
       host_weights.scalar_type() == at::ScalarType::Half*/) &&
      grad_output.scalar_type() == at::ScalarType::Float &&
      !indice_weights.defined() && static_cast<PoolingMode>(pooling_mode) == PoolingMode::SUM &&
      grad_output.is_contiguous() && host_weights.is_contiguous() &&
      momentum1_host.is_contiguous() && indices.is_contiguous() &&
      offsets.is_contiguous();

  if (use_fbgemm) {
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] {

      // Row stride of grad_output in elements (valid because it is contiguous).
      auto grad_stride = grad_output.size(1);
      const float* grad_output_data = grad_output.data_ptr<float>();
      float* host_weights_data = host_weights.data_ptr<float>();

      const auto* indices_data = indices.data_ptr<index_t>();
      const auto* offsets_data = offsets.data_ptr<index_t>();

      const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
      float* momentum1_data = momentum1_host.data_ptr<float>();

      // Parallelize over the flattened (table, bag) space; each chunk
      // [tb_begin, tb_end) is cut back into per-table bag ranges below.
      at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
        int t_begin = tb_begin / B;
        int t_end = (tb_end + B - 1) / B;

        for (const auto t : c10::irange(t_begin,t_end)) {
          auto D_begin = D_offsets_data[t];
          auto D = D_offsets_data[t + 1] - D_offsets_data[t];
          auto table_begin = weights_offsets_data[t];
          auto momentum_begin = momentum1_offsets_data[t];

          // Number of rows of the table used by `t`: distance from
          // hash_size_cumsum[t] to the next entry that differs from it.
          // NOTE(review): presumably equal consecutive entries mark features
          // sharing one table; the loop has no explicit upper bound on
          // t_temp, so it relies on a later entry being larger -- confirm.
          int64_t hash_size;
          int t_temp = t + 1;
          do {
            hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
            ++t_temp;
          } while (hash_size == 0);

          // First/last table of the chunk may only cover part of the bags.
          int b_begin = (t == t_begin) ? tb_begin % B : 0;
          int b_end = (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;

          auto kernel =
              fbgemm::GenerateRowWiseSparseAdaGradFused<index_t, index_t, float>(
                  D,
                  /*prefetch=*/16,
                  /*use_offsets=*/true,
                  /*use_stochastic_round=*/true,
                  /*grad_stride=*/grad_stride);
          auto offsets_begin_ptr = offsets_data + t * B + b_begin;
          auto index_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
          // All data pointers are already float*, so plain pointer arithmetic
          // is enough (no casts needed).
          bool success = kernel(
              b_end - b_begin,
              index_size,
              hash_size,
              host_weights_data + table_begin,
              grad_output_data + b_begin * grad_stride + D_begin,
              momentum1_data + momentum_begin,
              indices_data + *offsets_begin_ptr,
              offsets_begin_ptr,
              eps,
              // fbgemm follows caffe2 convention of negative learning rate
              -learning_rate);

          if (!success) {
            fbgemm_gpu::report_embedding_error(
              t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
          }
        }
      }); // parallel_for
    }); // dispatch indices.scalar_type()

    return;
  } // use_fbgemm

  {% endif %}

  // Generic path: inside the innermost lambda `scalar_t` is the host_weights
  // dtype, while `grad_t` captures the grad_output dtype from the enclosing
  // dispatch.
  AT_DISPATCH_INDEX_TYPES(
    indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] {
  
      FBGEMM_DISPATCH_FLOAT_AND_HALF(
        grad_output.scalar_type(), "split_embedding_backward_approx_cpu_kernel_2", [&] {
        using grad_t = scalar_t;

          FBGEMM_DISPATCH_FLOAT_AND_HALF(
            host_weights.scalar_type(), "split_embedding_backward_approx_cpu_kernel_3", [&] {
              split_embedding_backward_approx_cpu_kernel<index_t, scalar_t, grad_t>(
                  grad_output,
                  host_weights,
                  weights_offsets_data,
                  D_offsets_data,
                  indices,
                  offsets,
                  pooling_mode,
                  indice_weights,
                  T,
                  B,
                  {% if "momentum1_offsets" in args.split_function_arg_names %}
                  momentum1_offsets_data,
                  {% endif %}
                  {% if "momentum2_offsets" in args.split_function_arg_names %}
                  momentum2_offsets_data,
                  {% endif %}
                  {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
            }); // dispatch host_weights.scalar_type()
        }); // dispatch grad_output.scalar_type()
    }); // dispatch indices.scalar_type()

  return;
}
// clang-format on
