import pandas as pd
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
from assets.commons import (load_config_file,
                            validate_parameter,
                            setup_logger)
from assets.constants import line_ratio_mapping
from prs.prs_density_inference import (get_inference_data,
                                       plot_kde_results)
from typing import Tuple, Union, List

# Maps the name of each comparison figure (also used as the root of the output
# file name) to the run types that are compared against the reference model.
filename_root_map = dict(
    abundance_comparison=[
        'constant_abundance_p15_q05_x01',
        'constant_abundance_p15_q05_x10',
    ],
    double_microturbulence=[
        'double_microturbulence',
    ],
    density_distribution=[
        'constant_abundance_p12_q05',
        'constant_abundance_p18_q05',
    ],
)


def plot_kde_ratio_nh2_comparison(grid: List[np.array],
                                  values_on_grid: List[dict],
                                  points_per_axis: int,
                                  ratio_string: str,
                                  root_outfile: str,
                                  training_data: List[pd.DataFrame],
                                  additional_data: pd.DataFrame = None,
                                  suffix_outfile: str = None,
                                  ratio_limits: Union[None, list] = None,
                                  plot_training_data: Union[int, bool] = False):
    """
        Plot the Kernel Density Estimates (KDE) of a line ratio against the average H2 density along the
         line-of-sight for several models on the same axes, and save the plot as a PNG file.

        The inputs `grid`, `values_on_grid` and `training_data` are parallel lists with one entry per model; at most
         three models are drawn, because only three colour pairs are defined.

        :param grid: One grid per model, as expected by `plot_kde_results`.
        :param values_on_grid: One dictionary per model; its 'values' entry holds the KDE values, which are reshaped to
         (points_per_axis, points_per_axis) before plotting.
        :param points_per_axis: The number of points in each axis for the KDE grid.
        :param ratio_string: The ratio string indicating which ratio of the training data to plot.
        :param root_outfile: The root of the filename used to save the figures.
        :param training_data: One DataFrame of training data per model.
        :param additional_data: Optional. The DataFrame containing the additional data to plot; it is only used if it
         has a column named as `ratio_string`, and its 'log_density' column is then used for the x coordinate.
        :param suffix_outfile: Optional. The suffix to append to the output file name, just before the extension.
         Defaults to an empty string if None.
        :param ratio_limits: Optional. The limits for the ratio axis. Defaults to None, which auto-scales the axis.
        :param plot_training_data: Optional. Whether to plot the training data; True plots it for all models, False
         for none, while an integer plots it only for the model with that index in the input lists.
        :return: None. Saves the plot as '<root_outfile>_<ratio_string><suffix_outfile>.png' in
         prs/output/comparison_figures.
    """
    plt.rcParams.update({'font.size': 20})
    _suffix_outfile = validate_parameter(suffix_outfile, default='')
    output_folder = os.path.join('prs', 'output', 'comparison_figures')

    # A dedicated figure is created for every call and closed in the `finally` clause below, so that repeated calls
    # do not accumulate open figures.
    figure = plt.figure(figsize=(8, 6))
    try:
        if (additional_data is not None) and (ratio_string in additional_data.columns):
            plt.scatter(10 ** additional_data['log_density'], additional_data[ratio_string], marker='*', alpha=1,
                        facecolor='red', s=100)

        colours = (('yellow', 'gold'), ('lightgreen', 'green'), ('lightblue', 'blue'))
        per_model_inputs = zip(grid, values_on_grid, training_data, colours)
        for idx, (individual_grid, individual_values_on_grid, individual_training_data, plot_cols) \
                in enumerate(per_model_inputs):
            if isinstance(plot_training_data, bool):
                # A boolean switches the training data on or off for every model.
                _plot_training_data = plot_training_data
            else:
                # Otherwise the value selects the single model whose training data is shown.
                _plot_training_data = bool(idx == plot_training_data)
            plot_kde_results(grid=individual_grid,
                             values_on_grid=individual_values_on_grid['values'].reshape(points_per_axis,
                                                                                        points_per_axis),
                             training_data=individual_training_data,
                             ratio_string=ratio_string,
                             plot_colours=plot_cols,
                             plot_training_data=_plot_training_data,
                             colour_scatter='grey')
        plt.semilogx()
        plt.xlabel(r'<$n$(H$_2$)> [cm$^{-3}$]')
        plt.ylabel(f'Ratio {line_ratio_mapping[ratio_string]}')
        plt.ylim(ratio_limits)
        plt.tight_layout()

        # Make sure the destination exists, so that saving does not fail on a fresh checkout.
        os.makedirs(output_folder, exist_ok=True)
        # The suffix is empty by default, which leaves the historical file name unchanged.
        plt.savefig(os.path.join(output_folder, f'{root_outfile}_{ratio_string}{_suffix_outfile}.png'))
    finally:
        plt.close(figure)


def main(ratio_list: list,
         comparison_model: List[str],
         root_outfile: str,
         points_per_axis: int = 200,
         limit_rows: Union[None, int] = None,
         best_bandwidths: Union[None, dict] = None,
         ratio_limits: Union[None, dict] = None,
         plot_training_data: Union[int, bool] = False):
    """
    Produce one comparison figure per line ratio, overlaying the KDE of the reference model
     ('constant_abundance_p15_q05') and those of the comparison models.

    :param ratio_list: The line ratios to process; one figure is saved for each of them.
    :param comparison_model: The run types to compare with the reference model.
    :param root_outfile: The root of the filename used to save the figures.
    :param points_per_axis: The number of points in each axis for the KDE grid.
    :param limit_rows: Optional. Passed to `get_inference_data` when loading the data of each model.
    :param best_bandwidths: The ratio bandwidths, indexed by ratio string; passed to `get_grid`, so it must be
     provided whenever `ratio_list` is not empty.
    :param ratio_limits: Optional. The limits of the ratio axis, indexed by ratio string. If None, or if a ratio is
     missing, the axis is auto-scaled.
    :param plot_training_data: Optional. Whether to plot the training data; can be the index of the model for which
     the training data must be plotted (0 is the reference model).
    :return: None. The figures are saved by `plot_kde_ratio_nh2_comparison`.
    """
    import logging

    # The module-level `logger` only exists when this file is run as a script; fall back to a standard logger so
    # that the function can also be called after importing the module.
    _logger = globals().get('logger')
    if _logger is None:
        _logger = logging.getLogger(__name__)

    _use_model_for_inference = ['constant_abundance_p15_q05'] + list(comparison_model)

    _logger.info('Getting data...')
    # Only the first element returned by get_inference_data is used, without any further filtering.
    data = [
        get_inference_data(use_model_for_inference=model, limit_rows=limit_rows)[0]
        for model in _use_model_for_inference
    ]

    for ratio_string in ratio_list:
        _logger.info(f'Processing ratio {ratio_string}')
        kdes = []
        grids = []
        for model in _use_model_for_inference:
            _model_root_folder = os.path.join('prs', 'output', 'run_type', model)
            kde_path = os.path.join(_model_root_folder, 'trained_model', f'ratio_density_kde_{ratio_string}.pickle')

            # NOTE(review): unpickling can execute arbitrary code; these files are presumably produced by this
            # project's own training step and must never come from an untrusted source.
            with open(kde_path, 'rb') as infile:
                kde_dict = pickle.load(infile)
            kdes.append(kde_dict)
            grids.append(get_grid(best_bandwidths, kde_dict, ratio_string))

        # `ratio_limits` defaults to None, so it cannot be indexed unconditionally.
        _ratio_limits = None if ratio_limits is None else ratio_limits.get(ratio_string)

        _logger.info(f'Plotting ratio {ratio_string}')
        plot_kde_ratio_nh2_comparison(grid=grids,
                                      values_on_grid=kdes,
                                      ratio_string=ratio_string,
                                      root_outfile=root_outfile,
                                      points_per_axis=points_per_axis,
                                      training_data=data,
                                      ratio_limits=_ratio_limits,
                                      plot_training_data=plot_training_data)


def get_grid(best_bandwidths: dict,
             kde_dict: dict,
             ratio_string: str) -> List[np.ndarray]:
    """
    Computes the KDE grid for plotting.

    The grid is built explicitly as a new list, because `np.meshgrid` returns an immutable tuple in NumPy >= 2.0
     (it returned a list in older versions), which supports neither `.copy()` nor item assignment.

    :param best_bandwidths: The values of the ratio bandwidth used for the KDE computation, indexed by ratio string.
    :param kde_dict: The dictionary used to persist the KDE results, containing the grid point (x, y) and the values of
     the PDF at those points.
    :param ratio_string: The line ratio to be used for plotting.
    :return: A list with the two scaled coordinate arrays of the grid (x first, then y), using 'ij' indexing.
    """
    # NOTE(review): 0.2 is presumably the bandwidth used along the x axis of the KDE - confirm against the
    # training step.
    x_scale = 0.2
    y_scale = best_bandwidths[ratio_string]

    x_mesh, y_mesh = np.meshgrid(kde_dict['x'], kde_dict['y'], indexing='ij')
    # The products are new arrays, so the content of kde_dict is left untouched.
    return [x_mesh * x_scale, y_mesh * y_scale]


if __name__ == '__main__':
    # Script entry point: one comparison figure per line ratio is produced for every entry of filename_root_map.
    external_input = load_config_file(config_file_path=os.path.join('config', 'density_inference_input.yml'))
    logger = setup_logger(name='PRS - FIGURES')

    # Optional settings: fall back to the defaults when they are missing from the configuration.
    limit_rows = external_input.get('limit_rows', None)
    points_per_axis = external_input.get('points_per_axis', 200)

    for comparison_figure, models_to_compare in filename_root_map.items():
        logger.info(f'Producing figure for {comparison_figure}.')
        # Training data are overlaid only for the microturbulence comparison, and only for the model with index 1,
        # i.e. the first comparison model (index 0 is the reference model added by main).
        _plot_training_data = False
        if comparison_figure == 'double_microturbulence':
            _plot_training_data = 1
        main(ratio_list=external_input['ratios_to_include'],
             comparison_model=models_to_compare,
             root_outfile=comparison_figure,
             points_per_axis=points_per_axis,
             limit_rows=limit_rows,
             best_bandwidths=external_input['best_bandwidths'],
             ratio_limits=external_input['ratio_limits'],
             plot_training_data=_plot_training_data)
