/*******************************************************************************
* Copyright (C) 2025 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates usage of oneapi::mkl::rng::device::count_engine_adaptor
*       to produce random numbers using gamma distribution on a SYCL device (CPU, GPU).
*
*******************************************************************************/

// stl includes
#include <iostream>
#include <vector>

#include <sycl/sycl.hpp>

#include "oneapi/mkl/rng/device.hpp"

#include "common_for_examples.hpp"

// example parameters
// first constructor argument handed to the engine adaptor of every work-item
constexpr std::uint64_t seed = 777;
// amount of random numbers each work-item generates and averages
constexpr std::size_t n_per_item = 20;
// total amount of random numbers generated by all work-items together
// (the kernel is launched over n / n_per_item = 1024 work-items)
constexpr std::size_t n = 1024 * n_per_item;
// how many leading entries of each result array are printed on the host
constexpr int n_print = 10;

// shorthand for the oneMKL device-side RNG namespace used throughout this file
namespace rng_device = oneapi::mkl::rng::device;

//
// Runs the example for one floating-point type.
//
// Launches n / n_per_item work-items on `queue`. Every work-item wraps an
// mcg59 engine into a count_engine_adaptor, draws n_per_item gamma-distributed
// numbers of type `Type`, and stores
//   - the average of the numbers it drew, and
//   - the value reported by adaptor.get_count()
// into two shared-USM arrays. The first n_print entries of both arrays are
// then printed on the host.
//
// Returns 0 on success and 1 if a sycl::exception was caught (the message is
// printed before returning).
//
template <typename Type>
int run_example(sycl::queue& queue) {
    // one result slot (average + counter) per work-item
    constexpr std::size_t n_items = n / n_per_item;

    // the print loops below read n_print entries; make sure they exist
    static_assert(n_print >= 0 && static_cast<std::size_t>(n_print) <= n_items,
                  "n_print must not exceed the number of work-items (n / n_per_item)");

    using allocator_t = sycl::usm_allocator<Type, sycl::usm::alloc::shared>;

    // The allocations live inside the try block as well: a failing USM
    // allocation is then reported like any other SYCL error instead of
    // escaping this function as an unhandled exception.
    try {
        // prepare arrays for the per-item results
        allocator_t allocator(queue);
        std::vector<Type, allocator_t> average_vec(n_items, allocator);
        std::vector<Type, allocator_t> r_count_vec(n_items, allocator);
        Type* average = average_vec.data();
        Type* r_count = r_count_vec.data();

        // submit a kernel to generate on device and wait for its completion
        queue.parallel_for(sycl::range<1>(n_items), [=](sycl::item<1> item) {
            const std::size_t item_id = item.get_id(0);

            // the constructor arguments are forwarded as (seed, item_id * n_per_item),
            // so every work-item passes a different second argument to its engine
            rng_device::count_engine_adaptor<rng_device::mcg59<1>> adaptor
                (seed, item_id * n_per_item);
            rng_device::gamma<Type> distr(2.0f, 0.1f, 0.9f);

            Type res(0);
            for (std::size_t i = 0; i < n_per_item; i++) {
                res += rng_device::generate(distr, adaptor);
            }
            average[item_id] = res / static_cast<Type>(n_per_item);
            // the counter is kept in the same floating-point type as the averages
            r_count[item_id] = static_cast<Type>(adaptor.get_count());
        }).wait_and_throw();

        std::cout << "\t\tOutput of generator:" << std::endl;

        std::cout << "first " << n_print << " numbers of " << n << ": " << std::endl;
        for (int i = 0; i < n_print; i++) {
            std::cout << average[i] << " ";
        }
        std::cout << std::endl;

        std::cout << "first " << n_print << " engine calls of " << n << ": " << std::endl;
        for (int i = 0; i < n_print; i++) {
            std::cout << r_count[i] << " ";
        }
        std::cout << std::endl;
    }
    catch (sycl::exception const& e) {
        std::cout << "\t\tSYCL exception\n" << e.what() << std::endl;
        return 1;
    }

    return 0;
}

//
// Prints the example banner (purpose of the example and the APIs it uses)
// to standard output.
//
void print_example_banner() {
    // One entry per output line; every line is terminated with std::endl.
    static const char* const banner_lines[] = {
        "",
        "########################################################################",
        "# Example to use count_engine_adaptor class: ",
        "# ",
        "# Using APIs:",
        "# mcg59 gamma",
        "# ",
        "########################################################################",
        "",
    };

    for (const char* text : banner_lines) {
        std::cout << text << std::endl;
    }
}

//
// main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//

int main() {
    print_example_banner();

    // handler to catch asynchronous exceptions
    auto exception_handler = [](sycl::exception_list exceptions) {
        for (std::exception_ptr const& e : exceptions) {
            try {
                std::rethrow_exception(e);
            }
            catch (sycl::exception const& e) {
                std::cout << "Caught asynchronous SYCL exception:\n" << e.what() << std::endl;
            }
        }
    };

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);
        if (my_dev_is_found) {
            std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

            sycl::queue queue(my_dev, exception_handler);

            std::cout << "\n\tRunning with single precision real data type:" << std::endl;
            if (run_example<float>(queue)) {
                std::cout << "FAILED" << std::endl;
                return 1;
            }
            if (isDoubleSupported(my_dev)) {
                std::cout << "\n\tRunning with double precision real data type:" << std::endl;
                if (run_example<double>(queue)) {
                    std::cout << "FAILED" << std::endl;
                    return 1;
                }
            }
            else {
                std::cout << "Double precision is not supported for this device" << std::endl;
            }
        }
        else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it]
                      << " devices found; Fail on missing devices is enabled.\n";
            std::cout << "FAILED" << std::endl;
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                      << sycl_device_names[*it] << " tests.\n";
#endif
        }
    }
    std::cout << "PASSED" << std::endl;
    return 0;
}
