/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright 2025- The GROMACS Authors
 * and the project initiators Erik Lindahl, Berk Hess and David van der Spoel.
 * Consult the AUTHORS/COPYING files and https://www.gromacs.org for details.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * https://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at https://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out https://www.gromacs.org.
 */
/*! \internal \file
 *
 * \brief Implements LINCS kernels using HIP
 *
 * This file contains HIP kernels of LINCS constraints algorithm.
 *
 * \author Artem Zhmurov <zhmurov@gmail.com>
 * \author Alan Gray <alang@nvidia.com>
 * \author Paul Bauer <paul.bauer.q@gmail.com>
 *
 * \ingroup module_mdlib
 */
#include "gromacs/gpu_utils/devicebuffer.h"
#include "gromacs/gpu_utils/gputraits.h"
#include "gromacs/gpu_utils/hip_sycl_kernel_utils.h"
#include "gromacs/gpu_utils/hiputils.h"
#include "gromacs/gpu_utils/typecasts_cuda_hip.h"
#include "gromacs/gpu_utils/vectype_ops_hip.h"
#include "gromacs/mdlib/lincs_gpu.h"
#include "gromacs/pbcutil/pbc_aiuc_hip.h"
#include "gromacs/utility/template_mp.h"

#include "lincs_gpu_internal.h"

#ifndef DOXYGEN

namespace gmx
{

/*! \brief Upper bound on the block size, passed to __launch_bounds__ of lincsKernel.
 *
 * Deliberately identical to c_threadsPerBlock, which is the block size used at launch.
 */
static constexpr int c_maxThreadsPerBlock = c_threadsPerBlock;

/*! \brief Main kernel for LINCS constraints.
 *
 * See Hess et al., J. Comput. Chem. 18: 1463-1472 (1997) for the description of the algorithm.
 *
 * In the GPU version, one thread is responsible for all computations for one constraint. The
 * blocks are filled in a way that no constraint is coupled to a constraint from the next block.
 * This is achieved by moving active threads to the next block if the corresponding group of
 * coupled constraints is too big to fit into the current thread block. This may leave some 'dummy'
 * threads at the end of the thread block, i.e. threads that are not required to do actual work.
 * Since constraints from different blocks are not coupled, there is no need to synchronize across
 * the device. However, extensive communication within a thread block is still needed.
 *
 * Consecutive constraints sharing their first atom i may form a "constraint group"
 * (see d_constraintGroupsSizes). For such a group, the corrections to atom i are gathered in
 * shared memory by the first thread of the group instead of via global atomics, and the final
 * position of atom i is written back to global memory once, at the end of the kernel.
 *
 * Launch requirements (as set up by launchLincsGpuKernel()):
 *  - 1D blocks of exactly c_threadsPerBlock threads, c_threadsPerBlock being a power of two;
 *  - numConstraintsThreads equal to (number of blocks) * c_threadsPerBlock;
 *  - dynamic shared memory of 3 floats per thread, or 6 floats per thread when computeVirial.
 *
 * All __syncthreads() calls must be reached by every thread of the block, dummy threads
 * included, so no barrier is placed inside a thread-dependent branch.
 *
 * \tparam updateVelocities  Whether the velocities in \p gm_v are corrected as well.
 * \tparam computeVirial     Whether the constraint virial is accumulated into d_virialScaled.
 *
 * \param[in,out] kernelParams  All parameters and pointers for the kernel condensed in a single struct.
 * \param[in]     gm_x          Coordinates used to build the constraint directions (before the
 *                              constraining step).
 * \param[in,out] gm_xp         Coordinates to be constrained; corrected in place.
 * \param[in,out] gm_v          Velocities; only modified when \p updateVelocities is true.
 * \param[in]     invdt         Inverse timestep (needed to update velocities).
 */
template<bool updateVelocities, bool computeVirial>
__launch_bounds__(c_maxThreadsPerBlock) __global__ void lincsKernel(LincsGpuKernelParameters kernelParams,
                                                                    const float3* __restrict__ gm_x,
                                                                    float3*     gm_xp,
                                                                    float3*     gm_v,
                                                                    const float invdt)
{
    // The tree reduction of the virial below halves the active range at every step.
    static_assert(c_threadsPerBlock > 0 && (c_threadsPerBlock & (c_threadsPerBlock - 1)) == 0,
                  "Number of threads per block should be a power of two in order for reduction to "
                  "work.");

    const PbcAiuc                 pbcAiuc               = kernelParams.pbcAiuc;
    const int                     numConstraintsThreads = kernelParams.numConstraintsThreads;
    const int                     numIterations         = kernelParams.numIterations;
    const int                     expansionOrder        = kernelParams.expansionOrder;
    AmdFastBuffer<const AtomPair> gm_constraints{ kernelParams.d_constraints };
    AmdFastBuffer<const float> gm_constraintsTargetLengths{ kernelParams.d_constraintsTargetLengths };
    AmdFastBuffer<const int> gm_coupledConstraintsCounts{ kernelParams.d_coupledConstraintsCounts };
    AmdFastBuffer<const int> gm_coupledConstraintsIndices{ kernelParams.d_coupledConstraintsIndices };
    AmdFastBuffer<const float> gm_massFactors{ kernelParams.d_massFactors };
    AmdFastBuffer<float>       gm_matrixA{ kernelParams.d_matrixA };
    AmdFastBuffer<const float> gm_inverseMasses{ kernelParams.d_inverseMasses };
    AmdFastBuffer<float>       gm_virialScaled{ kernelParams.d_virialScaled };
    AmdFastBuffer<const int>   gm_constraintGroupSize{
        kernelParams.d_constraintGroupsSizes
    }; // # of consecutive constraints sharing i-th atom

    const int threadIndex = blockIdx.x * c_threadsPerBlock + threadIdx.x;
    const int tid         = threadIdx.x;

    // numConstraintsThreads should be an integer multiple of blockSize (numConstraintsThreads = numBlocks*blockSize).
    // This is to ensure proper synchronizations and reduction. All arrays are padded to the required size.
    GMX_DEVICE_ASSERT(threadIndex < numConstraintsThreads);

    // Vectors connecting constrained atoms before algorithm was applied.
    // Needed to construct constrain matrix A
    extern __shared__ float3 sm_r[];

    const AtomPair& pair = gm_constraints[threadIndex];
    const int       i    = pair.i;
    const int       j    = pair.j;

    // Mass-scaled Lagrange multiplier
    float lagrangeScaled = 0.0F;

    float targetLength;
    float inverseMassi;
    float inverseMassj;
    float sqrtReducedMass;

    float3 xi;
    float3 xj;
    float3 rc;

    // i == -1 indicates dummy constraint at the end of the thread block.
    const bool isDummyThread = (i == -1);

    // Everything computed for these dummies will be equal to zero
    if (isDummyThread)
    {
        targetLength    = 0.0F;
        inverseMassi    = 0.0F;
        inverseMassj    = 0.0F;
        sqrtReducedMass = 0.0F;

        xi = make_float3(0.0F, 0.0F, 0.0F);
        xj = make_float3(0.0F, 0.0F, 0.0F);
        rc = make_float3(0.0F, 0.0F, 0.0F);
    }
    else
    {
        // Collecting data
        targetLength    = gm_constraintsTargetLengths[threadIndex];
        inverseMassi    = gm_inverseMasses[i];
        inverseMassj    = gm_inverseMasses[j];
        sqrtReducedMass = __frsqrt_rn(inverseMassi + inverseMassj);

        xi = gm_x[i];
        xj = gm_x[j];

        const float3 dx = pbcDxAiuc(pbcAiuc, xi, xj);

        const float rlen = __frsqrt_rn(dx.x * dx.x + dx.y * dx.y + dx.z * dx.z);
        rc               = rlen * dx;
    }

    sm_r[tid] = rc;
    // Make sure that all r's are saved into shared memory
    // before they are accessed in the loop below
    __syncthreads();

    /*
     * Constructing LINCS matrix (A)
     */
    // float4 transactions for higher bw
    __shared__ float4 sm_corr[c_threadsPerBlock];
    __shared__ float4 sm_xpi[c_threadsPerBlock];

    // Only non-zero values are saved (for coupled constraints)
    int coupledConstraintsCount = gm_coupledConstraintsCounts[threadIndex];
    for (int n = 0; n < coupledConstraintsCount; n++)
    {
        const int     index = n * numConstraintsThreads + threadIndex;
        const int     c1    = gm_coupledConstraintsIndices[index];
        const float3& rc1   = sm_r[c1];
        gm_matrixA[index]   = gm_massFactors[index] * (rc.x * rc1.x + rc.y * rc1.y + rc.z * rc1.z);
    }

    // Keep track if we are operating on a constraint group
    const int constraintGroupSize = gm_constraintGroupSize[threadIndex];
    // This is only valid if the value has been changed away from -1
    const bool haveValidConstraintGroup = constraintGroupSize != -1;
    // Any number greater than 0 means we have grouped constraints
    const bool haveGroupedConstraints = constraintGroupSize > 0;

    // Skipping in dummy threads
    if (!isDummyThread)
    {
        xi = gm_xp[i];
        xj = gm_xp[j];
    }

    const float3 dx = pbcDxAiuc(pbcAiuc, xi, xj);

    float sol = sqrtReducedMass * ((rc.x * dx.x + rc.y * dx.y + rc.z * dx.z) - targetLength);

    /*
     *  Inverse matrix using a set of expansionOrder matrix multiplications
     */

    // Make sure that we don't overwrite the sm_r[..] array.
    __syncthreads();

    // This will use the same memory space as sm_r, which is no longer needed.
    extern __shared__ float sm_rhs[];
    // Save current right-hand-side vector in the shared memory
    sm_rhs[tid] = sol;

    for (int rec = 0; rec < expansionOrder; rec++)
    {
        // Making sure that all sm_rhs are saved before they are accessed in a loop below
        __syncthreads();
        float mvb = 0.0F;

        for (int n = 0; n < coupledConstraintsCount; n++)
        {
            const int index = n * numConstraintsThreads + threadIndex;
            const int c1    = gm_coupledConstraintsIndices[index];
            // Convolute current right-hand-side with A
            // Different, non overlapping parts of sm_rhs[..] are read during odd and even iterations
            mvb = mvb + gm_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
        }
        // 'Switch' rhs vectors, save current result
        // These values will be accessed in the loop above during the next iteration.
        sm_rhs[tid + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
        sol                                               = sol + mvb;
    }
    // Current mass-scaled Lagrange multipliers
    lagrangeScaled = sqrtReducedMass * sol;

    // Save updated coordinates before correction for the rotational lengthening
    {
        const float3 tmp = rc * lagrangeScaled;

        if (haveValidConstraintGroup)
        {
            const float3 corr = -tmp * inverseMassi;
            sm_xpi[tid]       = make_float4(xi.x, xi.y, xi.z, 0.f);
            sm_corr[tid]      = make_float4(corr.x, corr.y, corr.z, 0.f);
        }

        __syncthreads();

        // Writing for all but dummy constraints
        if (!isDummyThread)
        {
            if (haveGroupedConstraints)
            {
                // The first thread of a group (constraintGroupSize > 0) gathers the corrections of
                // all constraintGroupSize + 1 consecutive constraints sharing atom i from shared
                // memory, so atom i is updated without atomics.
                for (int gc = 0; gc <= constraintGroupSize; gc++)
                {
                    const float4 r_corr = sm_corr[tid + gc];
                    xi += make_float3(r_corr.x, r_corr.y, r_corr.z);
                }
                // Publish the updated position of atom i to every thread of the group via LDS
                for (int gc = 0; gc <= constraintGroupSize; gc++)
                {
                    sm_xpi[tid + gc] = make_float4(xi.x, xi.y, xi.z, 0.f);
                }
            }
            else if (!haveValidConstraintGroup)
            {
                atomicAdd(&gm_xp[i], -tmp * inverseMassi);
            }
            atomicAdd(&gm_xp[j], tmp * inverseMassj);
        }
    }

    /*
     *  Correction for centripetal effects
     */
    for (int iter = 0; iter < numIterations; iter++)
    {
        // Make sure that all xp's are saved: atomic operation calls before are
        // communicating current xp[..] values across thread block.
        __syncthreads();

        if (!isDummyThread)
        {
            if (haveValidConstraintGroup)
            {
                xi = make_float3(sm_xpi[tid].x, sm_xpi[tid].y, sm_xpi[tid].z);
            }
            else
            {
                xi = gm_xp[i];
            }
            xj = gm_xp[j];
        }

        const float3 dx2   = pbcDxAiuc(pbcAiuc, xi, xj);
        const float  len2  = targetLength * targetLength;
        const float  dlen2 = 2.0F * len2 - gmxDeviceNorm2(dx2);

        // TODO A little bit more effective but slightly less readable version of the below would be:
        //      float proj = sqrtReducedMass*(targetLength - (dlen2 > 0.0f ? 1.0f : 0.0f)*dlen2*__frsqrt_rn(dlen2));
        float proj;
        if (dlen2 > 0.0F)
        {
            proj = sqrtReducedMass * (targetLength - dlen2 * __frsqrt_rn(dlen2));
        }
        else
        {
            proj = sqrtReducedMass * targetLength;
        }

        sm_rhs[tid] = proj;
        float sol2  = proj;

        /*
         * Same matrix inversion as above is used for updated data
         */
        for (int rec = 0; rec < expansionOrder; rec++)
        {
            // Make sure that all elements of rhs are saved into shared memory
            __syncthreads();
            float mvb = 0;

            for (int n = 0; n < coupledConstraintsCount; n++)
            {
                int index = n * numConstraintsThreads + threadIndex;
                int c1    = gm_coupledConstraintsIndices[index];

                mvb = mvb + gm_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
            }
            sm_rhs[tid + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
            sol2                                              = sol2 + mvb;
        }

        // Add corrections to Lagrange multipliers
        const float sqrtmu_sol2 = sqrtReducedMass * sol2;
        lagrangeScaled += sqrtmu_sol2;

        // Save updated coordinates for the next iteration.
        // For dummy threads rc is zero, so tmp is zero as well.
        const float3 tmp = rc * sqrtmu_sol2;
        if (!isDummyThread && haveValidConstraintGroup)
        {
            const float3 corr = -tmp * inverseMassi;
            sm_corr[tid]      = make_float4(corr.x, corr.y, corr.z, 0.f);
        }
        // This barrier must be reached by all threads of the block, dummies included, so it is
        // kept outside of any thread-dependent branch. It makes the corrections of the whole
        // group visible before the first thread of the group gathers them below.
        __syncthreads();

        // Dummy constraints are skipped
        if (!isDummyThread)
        {
            if (haveGroupedConstraints)
            {
                for (int gc = 0; gc <= constraintGroupSize; gc++)
                {
                    const float4 r_corr = sm_corr[tid + gc];
                    xi += make_float3(r_corr.x, r_corr.y, r_corr.z);
                }
                // update the coordinates to lds
                for (int gc = 0; gc <= constraintGroupSize; gc++)
                {
                    sm_xpi[tid + gc] = make_float4(xi.x, xi.y, xi.z, 0.f);
                }
            }
            else if (!haveValidConstraintGroup)
            {
                atomicAdd(&gm_xp[i], -tmp * inverseMassi);
            }
            atomicAdd(&gm_xp[j], tmp * inverseMassj);
        }
    }

    // Flush the position of atom i, kept in registers/LDS for grouped constraints, to global memory.
    // Only the first thread of each group writes it.
    if (!isDummyThread)
    {
        if (haveGroupedConstraints)
        {
            gm_xp[i] = xi;
        }
    }

    // Updating particle velocities for all but dummy threads
    if constexpr (updateVelocities)
    {
        if (!isDummyThread)
        {
            const float3 tmp = rc * invdt * lagrangeScaled;
            // we don't stall on these, so just leave it like that
            atomicAdd(&gm_v[i], -tmp * inverseMassi);
            atomicAdd(&gm_v[j], tmp * inverseMassj);
        }
    }


    if constexpr (computeVirial)
    {
        // Virial is computed from Lagrange multiplier (lagrangeScaled), target constrain length
        // (targetLength) and the normalized vector connecting constrained atoms before
        // the algorithm was applied (rc). The evaluation of virial in each thread is
        // followed by basic reduction for the values inside single thread block.
        // Then, the values are reduced across grid by atomicAdd(...).
        //

        // Save virial for each thread into the shared memory. Tensor is symmetrical, hence only
        // 6 values are saved. Dummy threads will have zeroes in their virial: targetLength,
        // lagrangeScaled and rc are all set to zero for them in the beginning of the kernel.
        // The sm_threadVirial[..] will overlap with the sm_r[..] and sm_rhs[..], but the latter
        // two are no longer in use, which we make sure by waiting for all threads in block.
        __syncthreads();
        extern __shared__ float sm_threadVirial[];
        const float             mult                 = targetLength * lagrangeScaled;
        sm_threadVirial[0 * c_threadsPerBlock + tid] = mult * rc.x * rc.x;
        sm_threadVirial[1 * c_threadsPerBlock + tid] = mult * rc.x * rc.y;
        sm_threadVirial[2 * c_threadsPerBlock + tid] = mult * rc.x * rc.z;
        sm_threadVirial[3 * c_threadsPerBlock + tid] = mult * rc.y * rc.y;
        sm_threadVirial[4 * c_threadsPerBlock + tid] = mult * rc.y * rc.z;
        sm_threadVirial[5 * c_threadsPerBlock + tid] = mult * rc.z * rc.z;

        __syncthreads();

        // Reduce up to one virial per thread block. All blocks are divided by half, the first
        // half of threads sums two virials. Then the first half is divided by two and the first
        // half of it sums two values. This procedure is repeated until only one thread is left.
        // Only works if the threads per blocks is a power of two (hence static_assert
        // in the beginning of the kernel).
        for (int divideBy = 2; divideBy <= static_cast<int>(c_threadsPerBlock); divideBy *= 2)
        {
            const int dividedAt = c_threadsPerBlock / divideBy;
            if (tid < dividedAt)
            {
                for (int d = 0; d < 6; d++)
                {
                    sm_threadVirial[d * c_threadsPerBlock + tid] +=
                            sm_threadVirial[d * c_threadsPerBlock + (tid + dividedAt)];
                }
            }
            // Synchronize the whole block if not within one wavefront (dividedAt is uniform,
            // so all threads take the same branch)
            if (dividedAt > warpSize / 2)
            {
                __syncthreads();
            }
            else
            {
                __builtin_amdgcn_wave_barrier();
            }
        }
        // First 6 threads in the block add the results of 6 tensor components to the global memory address.
        if (tid < 6)
        {
            atomicAdd(&(gm_virialScaled[tid]), sm_threadVirial[tid * c_threadsPerBlock]);
        }
    }
}

/*! \brief Launches the LINCS kernel variant matching the requested runtime options.
 *
 * Uses one thread per (padded) constraint, grouped into 1D blocks of c_threadsPerBlock threads.
 * The two runtime booleans are turned into template arguments of lincsKernel, and the dynamic
 * shared-memory size is chosen to fit the largest of the overlapping shared arrays used by the
 * selected variant.
 */
void launchLincsGpuKernel(LincsGpuKernelParameters*   kernelParams,
                          const DeviceBuffer<Float3>& d_x,
                          DeviceBuffer<Float3>        d_xp,
                          const bool                  updateVelocities,
                          DeviceBuffer<Float3>        d_v,
                          const real                  invdt,
                          const bool                  computeVirial,
                          const DeviceStream&         deviceStream)
{
    KernelLaunchConfig config;
    // 1D grid of 1D blocks covering all constraint threads.
    config.gridSize[0]  = divideRoundUp(kernelParams->numConstraintsThreads, c_threadsPerBlock);
    config.blockSize[0] = c_threadsPerBlock;
    config.gridSize[1]  = 1;
    config.blockSize[1] = 1;
    config.gridSize[2]  = 1;
    config.blockSize[2] = 1;

    const auto launchVariant = [&](auto updateVelocitiesTag, auto computeVirialTag)
    {
        auto kernelPtr = lincsKernel<updateVelocitiesTag, computeVirialTag>;

        // The kernel overlaps, in one dynamic shared-memory region:
        //   - the constraint direction vectors (3 floats per thread),
        //   - the two right-hand-side buffers of the matrix expansion (2 floats per thread),
        //   - the per-thread virial components (6 floats per thread, virial variant only).
        // Hence the region needs max{3, 2, 6} = 6 floats per thread with the virial and
        // max{3, 2} = 3 floats per thread without it.
        const int sharedFloatsPerThread = computeVirialTag ? 6 : 3;
        config.sharedMemorySize         = c_threadsPerBlock * sharedFloatsPerThread * sizeof(float);

        const auto kernelArgs = prepareGpuKernelArguments(kernelPtr,
                                                          config,
                                                          kernelParams,
                                                          asFloat3Pointer(&d_x),
                                                          asFloat3Pointer(&d_xp),
                                                          asFloat3Pointer(&d_v),
                                                          &invdt);

        launchGpuKernel(kernelPtr,
                        config,
                        deviceStream,
                        nullptr,
                        "lincs_kernel<updateVelocities, computeVirial>",
                        kernelArgs);
    };

    gmx::dispatchTemplatedFunction(launchVariant, updateVelocities, computeVirial);
}

} // namespace gmx

#endif // DOXYGEN
