#!/usr/bin/env python
#
# Copyright (C) 2023 Intel Corporation
#
# This software and the related documents are Intel copyrighted materials, and your use of them
# is governed by the express license under which they were provided to you ("License"). Unless
# the License provides otherwise, you may not use, modify, copy, publish, distribute, disclose
# or transmit this software or the related documents without Intel's prior written permission.
#
# This software and the related documents are provided as is, with no express or implied
# warranties, other than those that are expressly stated in the License.
#


import os
import sys
from csv import DictReader, reader
import subprocess
import logging

try:
    from defusedxml import ElementTree
except ImportError:
    # Report the missing dependency on stderr and exit with a failure status;
    # a bare sys.exit() would report success (status 0) to the caller.
    print('Required defusedxml package version v0.7.1', file=sys.stderr)
    sys.exit(1)

# Module-level logger; its level and stdout handler are set up in main().
_logger = logging.getLogger(__name__)


def log_important_message(msg):
    """Print *msg* to stdout inside a box drawn with asterisks.

    The banner starts with a line separator so it is visually detached from
    whatever was printed before it.
    """
    border = '*' * (4 + len(msg))
    banner_parts = ['', border, '* ' + msg + ' *', border]
    print(os.linesep.join(banner_parts))


def exit_with_message(msg):
    """Log *msg* as an error and terminate the process with exit status 1."""
    _logger.error(msg)
    # sys.exit is always available; the bare exit() builtin is injected by the
    # site module and is missing under `python -S` or in frozen interpreters.
    sys.exit(1)


def exit_path_exists(file_path):
    """Abort the program (via exit_with_message) when *file_path* is missing.

    Returns None without side effects when the path exists.
    """
    if os.path.exists(file_path):
        return
    exit_with_message(file_path + ' does not exist!')


def get_report(params, vtune_path, vtune_result):
    """Run a VTune report against a result directory and return its output lines.

    Args:
        params: report arguments placed between the binary and '-r'
            (e.g. ['-R', 'hotspots', ...]).
        vtune_path: path of the vtune executable.
        vtune_result: result directory passed to vtune with '-r'.

    Returns:
        The command's standard output, split into lines.

    Terminates the program through exit_with_message() when the command
    returns a non-zero exit code.
    """
    # Build a new list instead of rebinding the caller-facing parameter name.
    command = [vtune_path] + list(params) + ['-r', vtune_result]
    command_str = ' '.join(command)

    _logger.debug('Calling {}'.format(command_str))
    completed = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='utf-8')
    # Same layout for the error and debug dumps (the debug one used to lack the colon after STDOUT).
    output_details = ('STDERR:' + os.linesep + completed.stderr + os.linesep +
                      'STDOUT:' + os.linesep + completed.stdout)
    if completed.returncode != 0:
        _logger.error('Report {} finished with error code {}.'.format(command_str, completed.returncode))
        _logger.error(output_details)
        exit_with_message('Unable to obtain report.')

    _logger.debug(output_details)
    return completed.stdout.splitlines()


def convert_bdf_to_context_bdf(bdf):
    """Return *bdf* with every ':' and '.' turned into '_'.

    This is the spelling used for a GPU's BDF inside context value ids.
    """
    separators_to_underscore = str.maketrans(':.', '__')
    return bdf.translate(separators_to_underscore)


class CombinedMetric:
    """Collects the values of one metric from several results and reduces them to one number.

    Metrics whose name mentions 'Total Time' or 'Instance Count' are summed;
    every other metric is averaged.
    """

    def __init__(self, name, value):
        self.name = name
        # Always starts with one value, so averaging never divides by zero.
        self.result_values = [value]

    def add_result_value(self, value):
        """Record the value of this metric from one more result."""
        self.result_values.append(value)

    def get_value(self):
        """Return the sum for summed metrics, otherwise the arithmetic mean."""
        combined = sum(self.result_values)
        if not self.is_sum_metric():
            # Currently most of the metrics are averaged. Might be changed in the future.
            combined /= len(self.result_values)
        return combined

    def is_sum_metric(self):
        """True when this metric is combined by summation rather than averaging."""
        return any(marker in self.name for marker in ('Total Time', 'Instance Count'))


class ComputeKernel:
    """Identity of one GPU computing task row of the VTune hotspots grid.

    Two kernels are equal (and hash equally) when all key fields (adapter,
    stack, task name, global/local work size, SIMD width) and ``unique_id``
    match. ``unique_id`` separates rows whose key fields are all identical.
    Because the hash depends on ``unique_id``, it must not be changed once the
    kernel is stored in a dict or set.
    """
    # Columns containing kernel fields
    adapter_column = 'GPU Adapter'
    tile_column = 'GPU Stack'
    name_column = 'Computing Task'
    global_size_column = 'Work Size:Global'
    local_size_column = 'Work Size:Local'
    simd_column = 'Computing Task:SIMD Width'
    kernel_field_columns = (adapter_column, tile_column, name_column, global_size_column, local_size_column, simd_column)

    def __init__(self, adapter, tile, name, global_size, local_size, simd):
        self.key_fields = dict()
        self.key_fields[ComputeKernel.adapter_column] = adapter
        self.key_fields[ComputeKernel.tile_column] = tile
        self.key_fields[ComputeKernel.name_column] = name
        self.key_fields[ComputeKernel.global_size_column] = global_size
        self.key_fields[ComputeKernel.local_size_column] = local_size
        self.key_fields[ComputeKernel.simd_column] = simd
        self.unique_id = 0

    def _key(self):
        """Return the tuple that defines equality and hashing, in a fixed column order."""
        return tuple(self.key_fields[column] for column in ComputeKernel.kernel_field_columns) + (self.unique_id,)

    def __str__(self):
        return 'Task Name: {}; Adapter: {}; {}; Global Size: {}; Local Size: {}; Simd Width: {}; '.format(
            self.key_fields[ComputeKernel.name_column],
            self.key_fields[ComputeKernel.adapter_column],
            self.key_fields[ComputeKernel.tile_column],
            self.key_fields[ComputeKernel.global_size_column],
            self.key_fields[ComputeKernel.local_size_column],
            self.key_fields[ComputeKernel.simd_column])

    def __eq__(self, other):
        # Let Python fall back to its default comparison for unrelated types
        # instead of raising AttributeError on other.key_fields.
        if not isinstance(other, ComputeKernel):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # Hash the same tuple __eq__ compares. The old version hashed the None
        # returned by list.append(), so every kernel landed in one hash bucket.
        return hash(self._key())


class VTuneReporter:
    """Reads GPU adapter information and per-kernel GPU metrics from VTune result directories."""

    def __init__(self):
        # Context values cached per result directory: {vtune_result: {context value id: value}}.
        self._context_values = {}
        # Path of the vtune binary, resolved lazily by get_vtune_path().
        self._vtune_path = None

    def compare_result_gpus(self, master_result, results):
        """Exit with an error unless every result lists the same GPU adapter BDFs as master_result."""
        master_bdfs = set(self.get_gpu_info(master_result).keys())
        for result in results:
            result_bdfs = set(self.get_gpu_info(result).keys())
            if result_bdfs != master_bdfs:
                exit_with_message(
                    'GPU adapters in {} do not match GPU adapters in {}'.format(result, master_result))

    def combine_metrics(self, master_result, results):
        """Merge the per-kernel metrics of all results into CombinedMetric objects.

        Kernels come from master_result. A master kernel missing from another
        result is reported with a warning, and that result adds no values for
        it. Kernels found only in the other results are not added.

        Returns:
            (merged_kernel_metrics, all_metric_columns): merged_kernel_metrics
            maps each master ComputeKernel to {metric name: CombinedMetric};
            all_metric_columns holds the master's grid columns followed by new
            columns from the other results, in order of first appearance.
        """
        master_kernel_metrics, all_metric_columns = self.get_kernels(master_result)
        merged_kernel_metrics = {
            kernel: {name: CombinedMetric(name, value) for name, value in metrics.items()}
            for kernel, metrics in master_kernel_metrics.items()
        }

        for result in results:
            result_kernel_metrics, result_metric_columns = self.get_kernels(result)
            # Combine all unique metrics in the order as they go in results for the Grid header
            for column in result_metric_columns:
                if column not in all_metric_columns:
                    all_metric_columns.append(column)

            # Combine metrics
            log_important_message('Merging metrics')
            for kernel in master_kernel_metrics:
                if kernel not in result_kernel_metrics:
                    _logger.warning('Unable to locate kernel "{}" in result {}.'.format(kernel, result))
                    continue

                kernel_metrics = merged_kernel_metrics[kernel]
                for metric, value in result_kernel_metrics[kernel].items():
                    if metric in kernel_metrics:
                        kernel_metrics[metric].add_result_value(value)
                    else:
                        kernel_metrics[metric] = CombinedMetric(metric, value)

        return (merged_kernel_metrics, all_metric_columns)

    def get_vtune_path(self):
        """Return the path of the vtune binary next to this script, caching it after the first call.

        Exits the program if the binary does not exist.
        """
        if self._vtune_path:
            return self._vtune_path

        vtune_binary_name = 'vtune.exe' if sys.platform == 'win32' else 'vtune'
        vtune_dir = os.path.realpath(os.path.dirname(__file__))
        self._vtune_path = os.path.join(vtune_dir, vtune_binary_name)
        if not os.path.exists(self._vtune_path):
            exit_with_message('Unable to locate vtune binary.')

        return self._vtune_path

    def fill_context_values(self, vtune_result):
        """Parse <vtune_result>/config/context_values.cfg into the per-result cache.

        The ``id`` attribute of each ``contextValue`` element becomes the key.
        The value comes from the first attribute listed in ``value_types`` that
        the element has, converted with the paired function. Exits if the file
        is missing or an element has no recognised value attribute. Each result
        directory is parsed only once; the cache used to hold only the first
        result, so later results silently reused its values.
        """
        if vtune_result in self._context_values:
            return None  # Already parsed for this result

        value_types = (
            ('value', str),
            ('{http://www.intel.com/2001/XMLSchema#double}value', float),
            ('{http://www.w3.org/2001/XMLSchema#boolean}value', lambda x: x == 'true'),
            ('{http://www.w3.org/2001/XMLSchema#unsignedShort}value', int),
            ('{http://www.w3.org/2001/XMLSchema#int}value', int),
            ('{http://www.w3.org/2001/XMLSchema#unsignedInt}value', int),
            ('{http://www.w3.org/2001/XMLSchema#unsignedByte}value', int),
            ('{http://www.w3.org/2001/XMLSchema#long}value', int),
            ('{http://www.w3.org/2001/XMLSchema#unsignedLong}value', int),
            ('{http://www.intel.com/2009/BagSchema#null}value', lambda x: None),
        )

        context_values_file_path = os.path.join(vtune_result, 'config', 'context_values.cfg')
        exit_path_exists(context_values_file_path)
        tree = ElementTree.parse(context_values_file_path)
        context_values = {}
        for elem in tree.iterfind('contextValue'):
            key = elem.attrib['id']
            for value_key, type_func in value_types:
                if value_key in elem.attrib:
                    context_values[key] = type_func(elem.attrib[value_key])
                    break
            else:
                exit_with_message('{} context value is not parsed.'.format(elem.attrib))

        self._context_values[vtune_result] = context_values

    def get_context_value(self, value_name, vtune_result):
        """Return context value ``value_name`` of vtune_result; exit if it is absent or null."""
        self.fill_context_values(vtune_result)
        value = self._context_values[vtune_result].get(value_name)
        # Test for None rather than truthiness: 0, False and '' are legitimate values.
        if value is None:
            exit_with_message('Unable to locate {} in context values.'.format(value_name))

        return value

    def get_context_value_with_bdf(self, bdf_str, value_name, vtune_result, num_gpus):
        """Return a per-GPU context value.

        With exactly one GPU the id is 'gpu' + value_name. Otherwise it is
        'gpu_' + bdf_str + '_' + value_name. get_gpu_info passes bdf_str as
        produced by convert_bdf_to_context_bdf.
        """
        gpu_context_value_prefix = 'gpu'

        if num_gpus == 1:
            value_name_without_bdf = gpu_context_value_prefix + value_name
            return self.get_context_value(value_name_without_bdf, vtune_result)

        value_name_with_bdf = gpu_context_value_prefix + '_' + bdf_str + '_' + value_name
        return self.get_context_value(value_name_with_bdf, vtune_result)

    def get_gpu_info(self, vtune_result):
        """Return {bdf: {'name', 'EuCount', 'EuThreadsCount', 'Max frequency'}} for the GPUs of vtune_result.

        The adapter list is read from the 'gpuAdapterNameList' context value,
        split on ';' into 'bdf|name' entries. Exits on malformed or duplicate entries.
        """
        # Context value names for GPU information
        gpu_adapter_name_list_str = 'gpuAdapterNameList'
        eu_count_str = 'EuCount'
        eu_threads_count_str = 'EuThreadsCount'
        gpu_max_frequency_str = 'AdapterMaxCoreFreq'

        gpu_adapter_name_list = self.get_context_value(gpu_adapter_name_list_str, vtune_result)
        gpus = {}
        for gpu_info in gpu_adapter_name_list.split(';'):
            if not gpu_info:  # ignore empty string
                continue

            # partition() keeps any further '|' inside the adapter name instead of raising.
            bdf, separator, name = gpu_info.partition('|')
            if not separator:
                exit_with_message('Unable to parse GPU adapter "{}" in {}.'.format(gpu_info, vtune_result))
            # Explicit check rather than assert, which is stripped when Python runs with -O.
            if bdf in gpus:
                exit_with_message('GPU BDF {} duplicated in {}.'.format(bdf, vtune_result))
            gpus[bdf] = {'name': name}

        num_gpus = len(gpus)
        for bdf, gpu in gpus.items():
            bdf_str = convert_bdf_to_context_bdf(bdf)
            gpu['EuCount'] = self.get_context_value_with_bdf(bdf_str, eu_count_str, vtune_result, num_gpus)
            gpu['EuThreadsCount'] = \
                self.get_context_value_with_bdf(bdf_str, eu_threads_count_str, vtune_result, num_gpus)

            gpu['Max frequency'] = \
                self.get_context_value_with_bdf(bdf_str, gpu_max_frequency_str, vtune_result, num_gpus)

        _logger.debug('GPU information')
        _logger.debug('Found {} GPUs in {}:'.format(len(gpus), vtune_result))
        for gpu in gpus:
            _logger.debug(gpu)
            _logger.debug(gpus[gpu])

        return gpus

    def get_kernels(self, vtune_result):
        """Return the GPU computing tasks of vtune_result together with their metric values.

        Runs the hotspots report grouped by GPU adapter, GPU stack and computing
        task, skips the '[Outside any task]' row, and converts every non-key
        column to float (empty cells become 0.0).

        Returns:
            (kernels, columns): kernels maps ComputeKernel -> {column: float};
            columns is the report header in order (empty when the report
            printed nothing).
        """
        gpu_hotspots_request = ['-R', 'hotspots', '-group-by=gpu-adapter,gpu-stack,computing-task',
                                '-csv-delimiter=semicolon']

        stdout = get_report(gpu_hotspots_request, self.get_vtune_path(), vtune_result)
        log_important_message('GPU kernels in {}'.format(vtune_result))
        grid_kernels = DictReader(stdout, delimiter=';', quotechar='|')
        # Reading fieldnames consumes the header row; DictReader leaves it None for empty input.
        columns = list(grid_kernels.fieldnames or [])
        # Computed once instead of per row; keeps the report's column order.
        metric_columns = [column for column in columns if column not in ComputeKernel.kernel_field_columns]

        kernels = dict()
        for row in grid_kernels:
            _logger.debug(row)
            name = row[ComputeKernel.name_column]
            if '[Outside any task]' == name:  # Ignore '[Outside any task]'
                continue

            # Optional key columns default to '' when the report does not include them.
            kernel = ComputeKernel(row.get(ComputeKernel.adapter_column, ''),
                                   row.get(ComputeKernel.tile_column, ''),
                                   name,
                                   row.get(ComputeKernel.global_size_column, ''),
                                   row.get(ComputeKernel.local_size_column, ''),
                                   row.get(ComputeKernel.simd_column, ''))
            print(kernel)
            while kernel in kernels:  # Tasks with the same names and params, e.g. 'zeCommandListAppendMemoryCopy'
                kernel.unique_id += 1

            if kernel.unique_id > 0:
                _logger.debug('unique id: {}'.format(kernel.unique_id))

            # Save metrics
            kernels[kernel] = {column: float(row[column]) if row[column] else 0.0 for column in metric_columns}

        return (kernels, columns)


def _grid_cell(kernel, kernel_combined_metrics, metric):
    """Return the text shown in *kernel*'s row under column *metric* ('' if there is nothing to show)."""
    if metric in ComputeKernel.kernel_field_columns:
        value = kernel.key_fields[metric]
        # A key field may be None (csv.DictReader's default restval for short rows);
        # render it as an empty cell instead of crashing on len(None).
        return '' if value is None else value
    if metric in kernel_combined_metrics:
        return str(kernel_combined_metrics[metric].get_value())
    return ''


def _render_grid_lines(grid_lines, delimiter, is_text):
    """Join each row's cells into one string, writing *delimiter* after every cell.

    In text mode every cell is left-justified to the widest cell of its column.
    """
    if is_text:
        column_sizes = [max(len(cell) for cell in column) for column in zip(*grid_lines)]
        return [''.join(cell.ljust(size) + delimiter for cell, size in zip(row, column_sizes))
                for row in grid_lines]
    # NOTE(review): cells are not quoted, so a value containing the delimiter
    # (e.g. a kernel name with ',') will shift columns in CSV output.
    return [''.join(cell + delimiter for cell in row) for row in grid_lines]


def print_grid(combined_metrics, all_metrics, delimiter, is_text, grid_filename=None):
    """Print the combined kernel metrics as a grid and optionally save the same grid to a file.

    Args:
        combined_metrics: mapping of ComputeKernel -> {metric name: CombinedMetric}.
        all_metrics: ordered column names used for the header row. Kernel key
            columns are filled from the kernel itself; other columns from its
            combined metrics (blank when the kernel has no such metric).
        delimiter: text written after every cell; replaced by two spaces when is_text.
        is_text: when true, pad cells so that columns line up.
        grid_filename: optional path the grid is also written to.
    """
    if is_text:
        delimiter = '  '

    # Header followed by one row per kernel.
    grid_lines = [list(all_metrics)]
    for kernel, kernel_combined_metrics in combined_metrics.items():
        grid_lines.append([_grid_cell(kernel, kernel_combined_metrics, metric) for metric in all_metrics])

    # Render once so the console and file outputs are guaranteed to be identical.
    rendered_lines = _render_grid_lines(grid_lines, delimiter, is_text)

    log_important_message('Generating the table of results')
    for line in rendered_lines:
        print(line)

    if grid_filename:
        log_important_message('Saving report to {}'.format(grid_filename))
        with open(grid_filename, 'w') as grid_file:
            grid_file.writelines(line + '\n' for line in rendered_lines)


def main(parameters):
    """Combine GPU kernel metrics from several VTune results and print (optionally save) the grid.

    Args:
        parameters: parsed command-line arguments with ``first_result_dir``,
            ``result_dir``, ``csv_delimiter``, ``format`` and ``report_output``.

    Returns:
        0 on success. Failures terminate the process via exit_with_message().
    """
    # Configure logging before validating inputs, so that errors raised by the
    # path checks below go through the formatted stdout handler too.
    _logger.setLevel(logging.DEBUG if os.environ.get('AMPLXE_LOG_LEVEL') else logging.INFO)
    if not _logger.handlers:  # avoid duplicate output if main() is called more than once
        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s:%(message)s'))
        _logger.addHandler(stream_handler)

    master_result = parameters.first_result_dir
    exit_path_exists(master_result)

    results = parameters.result_dir
    for result in results:
        exit_path_exists(result)

    vtune_reporter = VTuneReporter()
    vtune_reporter.compare_result_gpus(master_result, results)
    merged_kernel_metrics, all_metric_columns = vtune_reporter.combine_metrics(master_result, results)

    # 'semicolon' (the default) and any unlisted choice map to ';'.
    csv_delimiters = {'comma': ',', 'colon': ':', 'tab': '\t'}
    csv_delimiter = csv_delimiters.get(parameters.csv_delimiter, ';')

    is_txt = parameters.format == 'text'
    print_grid(merged_kernel_metrics, all_metric_columns, csv_delimiter, is_txt, parameters.report_output)

    return 0


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Combine GPU metrics from multiple analyses run on the application')

    parser.add_argument('-format', type=str, choices=['csv', 'text'], default='csv',
                        help='Specify an output format (.CSV | .TXT) for the report.')

    parser.add_argument('-csv-delimiter', type=str, choices=['comma', 'semicolon', 'colon', 'tab'],
                        default='semicolon',
                        help='Specify a delimiter character for CSV output.\
                              Select from semicolon, comma, colon, or tab.')

    parser.add_argument('-report-output', type=str, help='Write report output to a file')
    parser.add_argument('first_result_dir', type=str, help='Directory containing first analysis result')
    parser.add_argument('result_dir', type=str, help='Directory containing analysis result', nargs='+')

    args = parser.parse_args()
    # Use main()'s return value as the process exit status.
    sys.exit(main(args))
