/*******************************************************************************
* Copyright (C) 2023 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

// Loop-unroll hint placed in front of the loops in this file. The _Pragma operator
// form is used because, unlike a plain #pragma directive, it can be produced by a
// macro expansion.
#define ESIMD_UNROLL _Pragma("unroll")


//
// ESIMD kernel for SpMV using the ESB4 format -- used here for a triangular matrix-vector
// product (trmv) driven by blockptr_st, blockptr_en, colind and values.
//
// Note that in ESB format, blockptr_st points to the vector of the block where the lower (or
// upper) part starts, and blockptr_en points to the vector where it ends. Elements belonging
// to the other side may still appear within this range of vectors, so we must load them and
// then filter them with a mask while accumulating.
//
// The UB / LB variants below are selected with #if / #elif on TRMV_UB and TRMV_LB.
// If both macros were defined, the TRMV_UB mask would be used in the first loop
// (whose gather already includes the B columns, colind >= nrows) AND the
// TRMV_LB-only Bx loop would also run, silently subtracting Bx twice from r.
// Fail the build instead of producing a wrong result.
#if defined(TRMV_UB) && defined(TRMV_LB)
#error "mv_esb4_esimd_kernel: define exactly one of TRMV_UB or TRMV_LB, not both"
#endif

// Each work-item processes `unroll` ESB blocks of BLOCK_SIZE rows; block index
// unroll * id + i is optionally permuted through `reorder` when applyReorder is set.
// Lane k of every BLOCK_SIZE-wide vector of colind / values belongs to row
// start_row + k of that block.
//
// "B" below denotes columns with colind >= nrows (NOTE(review): presumably the
// non-local / halo columns -- confirm against the matrix setup). Negative colind
// entries are always masked out (NOTE(review): presumably ESB padding).
//
//   TRMV_UB:  y1[rows] = U x                  (U: strictly upper, local columns)
//             y [rows] = r[rows] - (U + B) x
//   TRMV_LB:  y [rows] += (L + B) x           (L: strictly lower, 0 <= col < row)
//
// Masked lanes are gathered as 0.0 (zero_vec pass-through), so their products
// with vals add zero to the accumulators.
//
// NOTE(review): there is no bounds check on unroll * id + i; this assumes the
// launch range matches the number of blocks / unroll -- confirm at the call site.
auto mv_esb4_esimd_kernel = [=](sycl::nd_item<1> item) SYCL_ESIMD_KERNEL
{
    local_int_t id = item.get_global_id(0);

    std::array<local_int_t, unroll> block, start_row, st_vec, en_vec;

    esimd::simd<double, BLOCK_SIZE> x_vec(1.0), vals(0.0), zero_vec(0.0);
    std::array<esimd::simd<double, BLOCK_SIZE>, unroll> y_vec, z_vec;

    // offset[i] holds start_row[i] + {0, 1, ..., BLOCK_SIZE-1}: the row index of each
    // lane. Comparing colind against it lane-wise splits entries into L / diagonal / U.
    esimd::simd<local_int_t, BLOCK_SIZE> offset0(0,1), indices;
    std::array<esimd::simd<local_int_t, BLOCK_SIZE>, unroll> offset;

    // Resolve the blocks owned by this work-item and clear the accumulators.
    ESIMD_UNROLL
    for (local_int_t i = 0; i < unroll; i++) {
        block[i] = unroll * id + i;
        if (applyReorder) block[i] = reorder[block[i]];
        start_row[i]  = block[i] * BLOCK_SIZE;
        st_vec[i] = blockptr_st[block[i]];
        en_vec[i] = blockptr_en[block[i]];
        offset[i] = offset0 + start_row[i];
        y_vec[i] = z_vec[i] = zero_vec;
    }

    ESIMD_UNROLL
    for (local_int_t i = 0; i < unroll; i++) {

        // Vectors [blockptr_st, blockptr_en) of this block:
        //   TRMV_UB: accumulates (U + B)x into y_vec and Ux into z_vec.
        //   TRMV_LB: accumulates only Lx here; the Bx part is added in the loop below.
        ESIMD_UNROLL
        for (local_int_t j = st_vec[i]; j < en_vec[i]; j++) {
            
            indices = esimd_lsc_block_load<local_int_t, local_int_t, BLOCK_SIZE, st, uc>(colind, j * BLOCK_SIZE);
#if defined(TRMV_UB)
            // Mask out the diagonal and everything left of it (col <= row); negative
            // indices fall in this range too. What remains is strictly upper: U + B.
            esimd::simd_mask<BLOCK_SIZE> mask = indices <= offset[i];
            // Lanes whose column is in B (col >= nrows); excluded from z_vec below.
            esimd::simd_mask<BLOCK_SIZE> mask1 = indices >= nrows;
#elif defined(TRMV_LB)
            // Mask out the diagonal and everything right of it (col >= row, which
            // covers U and B) as well as negative indices, keeping strictly lower L.
            esimd::simd_mask<BLOCK_SIZE> mask = (indices >= offset[i]) || (indices < 0);
#endif
            vals = esimd_lsc_block_load<double, local_int_t, BLOCK_SIZE, st, uc>(values, j * BLOCK_SIZE);
            
            // Gather x only for unmasked lanes; masked lanes receive 0.0.
            x_vec = esimd_exp::lsc_gather<double, 1, ds::default_size, uc, ca, BLOCK_SIZE>(
                x, indices * sizeof(double), !mask, zero_vec);
            y_vec[i] += x_vec * vals;
            
#if defined(TRMV_UB)
            // Zero the B lanes so z_vec accumulates only the local strictly-upper part.
            x_vec.merge(zero_vec, mask1);
            z_vec[i] += x_vec * vals;
#endif
        }
#if defined(TRMV_LB)
        // Bx part: vectors [nonloc_st, nonloc_en) of this block.
        ESIMD_UNROLL
        for (local_int_t j = nonloc_st[block[i]]; j < nonloc_en[block[i]]; j++) {
            
            indices = esimd_lsc_block_load<local_int_t, local_int_t, BLOCK_SIZE, st, uc>(colind, j * BLOCK_SIZE);
            // Mask out all local columns (col < nrows, which includes negative
            // indices), keeping only B columns (col >= nrows).
            esimd::simd_mask<BLOCK_SIZE> mask = indices < nrows;
            
            vals = esimd_lsc_block_load<double, local_int_t, BLOCK_SIZE, st, uc>(values, j * BLOCK_SIZE);
            
            x_vec = esimd_exp::lsc_gather<double, 1, ds::default_size, uc, ca, BLOCK_SIZE>(
                x, indices * sizeof(double), !mask, zero_vec);
            y_vec[i] += x_vec * vals;
        }
#endif
    }

    // Write back the BLOCK_SIZE rows of every block owned by this work-item.
    ESIMD_UNROLL
    for (local_int_t i = 0; i < unroll; i++) {
#if defined(TRMV_UB)
        // y1 <- Ux (local strictly-upper part only).
        esimd_lsc_block_store<double, local_int_t, BLOCK_SIZE, uc, uc>(y1, start_row[i], z_vec[i]);
        auto r_vec = esimd_lsc_block_load<double, local_int_t, BLOCK_SIZE, uc, uc>(r, start_row[i]);
        y_vec[i] = r_vec - y_vec[i]; // Computes r - (U + B)x
#else
        // y <- y + (L + B)x
        auto y_vec_old = esimd_lsc_block_load<double, local_int_t, BLOCK_SIZE, uc, uc>(y, start_row[i]);
        y_vec[i] += y_vec_old;
#endif
        esimd_lsc_block_store<double, local_int_t, BLOCK_SIZE, uc, uc>(y, start_row[i], y_vec[i]);
    }
};
