import logging
import os
import pickle
from typing import Union, List, Tuple

import numpy as np
import pandas as pd
from astropy import units as u
from matplotlib import pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.experimental import enable_halving_search_cv
from sklearn.model_selection import HalvingGridSearchCV
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import QuantileTransformer
from xgboost import XGBRegressor

from assets.commons import (validate_parameter,
                            get_data)
from assets.commons.grid_utils import (compute_analytic_profile,
                                       compute_density_and_temperature_avg_profiles)


def compute_and_add_similarity_cols(average_features_per_target_bin: pd.DataFrame,
                                    input_df: pd.DataFrame,
                                    similarity_bins: int,
                                    columns_to_drop: Union[list, None] = None) -> pd.DataFrame:
    """
    Append cosine-similarity features to a dataframe.

    Each row of ``input_df`` (restricted to the columns of ``average_features_per_target_bin``, in that order, after
    removing ``columns_to_drop``) is compared with every row of ``average_features_per_target_bin``. The scores are
    stored in new columns named ``sim_00``, ``sim_01``, ... and joined back to ``input_df`` on the index.
    :param average_features_per_target_bin: the dataframe with the average values of the input per target bin; its
        columns define which input columns enter the comparison and their order
    :param input_df: the input dataframe (left untouched; a new dataframe is returned)
    :param similarity_bins: number of bins used to group the target variable, i.e. number of similarity columns
    :param columns_to_drop: columns removed before the comparison; they are still present in the returned dataframe
    :return: the input dataframe with the similarity columns appended
    """
    feature_columns = average_features_per_target_bin.columns
    reduced_df = input_df.drop(columns=validate_parameter(columns_to_drop, default=[]))
    scores = cosine_similarity(reduced_df[feature_columns], average_features_per_target_bin[feature_columns])
    similarity_names = [f'sim_{bin_number:02d}' for bin_number in range(similarity_bins)]
    similarity_df = pd.DataFrame(scores, columns=similarity_names, index=reduced_df.index)
    return input_df.merge(similarity_df, how='inner', left_index=True, right_index=True, validate='1:1')


def plot_results(inferred_data: pd.DataFrame,
                 use_model_for_inference: str = None,
                 ratios_to_process: Union[List[List[str]], None] = None):
    """
    Compare line ratios of the inferred data with those of the post-processed grid and save the figure.

    One panel is drawn per line pair, three panels per row; panels left over in the last row are hidden. Each panel
    shows the ``ratio_<a>-<b>`` column of both datasets against ``10 ** avg_nh2`` on a logarithmic x axis
    (``avg_nh2`` is presumably stored as log10 — it is exponentiated before plotting). The figure is saved once, after
    all panels are drawn, as ``training_results_<model>.png`` in the current working directory.

    :param inferred_data (pd.DataFrame): data with an 'avg_nh2' column and one 'ratio_<a>-<b>' column per pair.
    :param use_model_for_inference (str, optional): name of the model whose post-processed data are loaded with
        ``get_data``. Defaults to 'constant_abundance_p15_q05'.
    :param ratios_to_process (Union[List[List[str]], None], optional): line pairs to plot; only the first two
        entries of each pair are used. Defaults to the five pairs 87/86, 88/87, 88/86, 257/256, 381/380.
    """
    _use_model_for_inference = validate_parameter(
        use_model_for_inference,
        default='constant_abundance_p15_q05'
    )
    _ratios_to_process = validate_parameter(
        ratios_to_process,
        default=[['87', '86'], ['88', '87'], ['88', '86'], ['257', '256'], ['381', '380']]
    )

    postprocessed_data = get_data(limit_rows=None,
                                  use_model_for_inference=_use_model_for_inference)
    n_columns = 3
    # Ceil-divide so every ratio gets a panel; always keep at least one row so subplots() is valid.
    n_rows = max(1, (len(_ratios_to_process) + n_columns - 1) // n_columns)
    plt.clf()
    # squeeze=False keeps a 2-D axes array even when only one row is needed.
    figure, axes = plt.subplots(nrows=n_rows, ncols=n_columns, squeeze=False)
    flat_axes = axes.ravel()
    for unused_axis in flat_axes[len(_ratios_to_process):]:
        unused_axis.axis('off')
    for axis, line_pairs in zip(flat_axes, _ratios_to_process):
        ratio_column = f'ratio_{line_pairs[0]}-{line_pairs[1]}'
        axis.semilogx()
        axis.scatter(10 ** inferred_data['avg_nh2'],
                     inferred_data[ratio_column],
                     alpha=0.1)
        axis.scatter(10 ** postprocessed_data['avg_nh2'],
                     postprocessed_data[ratio_column],
                     alpha=0.01)
    # Save once, after every panel is drawn (previously this ran on each loop iteration).
    figure.savefig(f'training_results_{_use_model_for_inference}.png')


def get_avg_profiles(
        value_at_reference_density: float,
        value_at_reference_temperature: float,
        central_value_density: float = 1e8,
        power_law_density: float = -1.5,
        central_value_temperature: float = 2000,
        power_law_temperature: float = -0.5,
        maximum_radius: u.Quantity = 0.9 * u.pc,
        distance_reference: u.Quantity = 0.5 * u.pc,
        molecular_abundance: float = 1e-9,
        threshold: float = 90,
        abundance_jump: float = 1) -> Tuple[np.array, np.array, np.array, np.array]:
    """
    Compute line-of-sight averaged density and temperature profiles, the molecular column density, and the standard
    deviation of the molecular number density.

    Analytic density and temperature profiles are first built with ``compute_analytic_profile`` (radii converted to
    cm); both are then handed to ``compute_density_and_temperature_avg_profiles`` together with the abundance
    parameters. The four resulting arrays are returned flattened in column-major ('F') order.

    :param value_at_reference_density (float): Reference value of the density profile.
    :param value_at_reference_temperature (float): Reference value of the temperature profile.
    :param central_value_density (float, optional): Central value of the density profile, and ceiling for this
        parameter. Defaults to 1e8 cm^-3.
    :param power_law_density (float, optional): Power-law index for the density profile. Defaults to -1.5.
    :param central_value_temperature (float, optional): Central value of the temperature profile, and ceiling for
        this parameter. Defaults to 2000 K.
    :param power_law_temperature (float, optional): Power-law index for the temperature profile. Defaults to -0.5.
    :param maximum_radius (u.Quantity, optional): Maximum radius within which to compute the profiles, larger radii
        are assumed to be devoid of material. Defaults to 0.9 pc.
    :param distance_reference (u.Quantity, optional): Reference distance at which the profiles values are specified.
        Defaults to 0.5 pc.
    :param molecular_abundance (float, optional): Molecular abundance for computing molecular number density.
        Defaults to 1e-9 for E-CH3OH.
    :param threshold (float, optional): Threshold in temperature above which E-CH3OH is sublimated from grains in
        hot cores. Defaults to 90 K.
    :param abundance_jump (float, optional): Abundance jump in hot cores. Defaults to 1, i.e. no jump.

    :return: Tuple[np.array, np.array, np.array, np.array]
        - Average density profile.
        - Average temperature profile.
        - Molecular column density.
        - Standard deviation of molecular number density profile.
    """
    # Both analytic profiles share the same geometry, expressed as plain numbers in cm.
    radius_in_cm = maximum_radius.to(u.cm).value
    reference_in_cm = distance_reference.to(u.cm).value

    density_profile = compute_analytic_profile(
        central_value=central_value_density,
        power_law_index=power_law_density,
        maximum_radius=radius_in_cm,
        value_at_reference=value_at_reference_density,
        distance_reference=reference_in_cm,
    )
    temperature_profile = compute_analytic_profile(
        central_value=central_value_temperature,
        power_law_index=power_law_temperature,
        maximum_radius=radius_in_cm,
        value_at_reference=value_at_reference_temperature,
        distance_reference=reference_in_cm,
    )
    mean_density, mean_temperature, column_density, density_spread = compute_density_and_temperature_avg_profiles(
        density_profile=density_profile,
        temperature_profile=temperature_profile,
        molecular_abundance=molecular_abundance,
        threshold=threshold,
        abundance_jump=abundance_jump,
    )
    return tuple(profile.flatten('F')
                 for profile in (mean_density, mean_temperature, column_density, density_spread))


def get_weights(data: pd.Series,
                show_weights: bool = False,
                bandwidth: float = 0.1,
                subsample_fraction: float = 0.5,
                random_state: int = 3) -> np.ndarray:
    """
    Get weights for each data point using kernel density estimation (KDE).

    An Epanechnikov KDE is fitted on a random subsample of ``data`` and evaluated at every data point; the weight of a
    point is the inverse of that density, so points in sparsely populated regions weigh more. Because the
    Epanechnikov kernel has compact support, points farther than ``bandwidth`` from every subsampled point get a zero
    density (infinite inverse); those weights are replaced by the largest finite weight.

    :param data (pd.Series): Input data series for which weights are computed.
    :param show_weights (bool, optional): Whether to interactively visualize the KDE and weights. Defaults to False.
    :param bandwidth (float, optional): Bandwidth parameter for the KDE. Defaults to 0.1.
    :param subsample_fraction (float, optional): Fraction of data to use for subsampling. Defaults to 0.5 to speed up
        computation.
    :param random_state (int, optional): Seed of the subsampling, for reproducibility. Defaults to 3.

    :return: np.ndarray
        Array of weights corresponding to each data point, in the order of ``data``.
    """
    kde = KernelDensity(bandwidth=bandwidth, kernel='epanechnikov')
    subsample = data.sample(frac=subsample_fraction, random_state=random_state)
    kde.fit(subsample.to_numpy().reshape(-1, 1))
    # score_samples returns log-densities.
    density_at_points = np.exp(kde.score_samples(data.to_numpy().reshape(-1, 1)))
    # Zero densities are expected (compact kernel support); they become inf and are capped just below.
    with np.errstate(divide='ignore'):
        weights = 1 / density_at_points
    finite_weights_mask = np.isfinite(weights)
    padding_value = np.max(weights[finite_weights_mask])
    weights = np.where(finite_weights_mask, weights, padding_value)
    if show_weights:
        # The evaluation grid is only needed for the diagnostic plot.
        grid = np.linspace(data.min() - 1, data.max() + 1, 100)
        density_on_grid = np.exp(kde.score_samples(grid.reshape(-1, 1)))
        plt.clf()
        plt.plot(grid, density_on_grid)
        plt.scatter(data, density_at_points, marker='+')
        plt.scatter(data, weights / weights.max(), marker='+')
        plt.show()
    return weights


def split_data(merged: pd.DataFrame,
               target_column: str,
               predictor_columns: Union[list, None] = None,
               test_models: Union[list, None] = None,
               validation_models: Union[list, None] = None,
               train_models: Union[list, None] = None) -> Tuple[
    pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.Series, pd.Series, pd.Series]:
    """
    Split the data into training, validation, and test sets and log-transform the relevant quantities.

    Rows are assigned to a set according to their characteristic number density (the 'nh2' column). Test and
    validation rows are those whose 'nh2' is in ``test_models`` / ``validation_models``; training rows are those whose
    'nh2' is in ``train_models`` but in neither held-out list, so the training set never contains held-out models.
    The target and the 'nh2' / 'tdust' predictors (when selected) are converted with log10. Every returned object has
    a fresh 0..n-1 index.
    :param merged (pd.DataFrame): The DataFrame containing both predictor and target columns.
    :param target_column (str): The name of the target column in the DataFrame.
    :param predictor_columns (Union[list, None], optional): A list of predictor column names to be used in training.
        Defaults to ['nh2', 'tdust', 'avg_tdust', 'avg_nh2', 'molecule_column_density', 'std_nh2'].
    :param test_models (Union[list, None], optional): A list of values for the characteristic number density,
        identifying the test data. If not specified, it uses 27000 cm^-3.
    :param validation_models (Union[list, None], optional):  A list of values for the characteristic number density,
        identifying the validation data. If not specified, it uses 2187000 cm^-3
    :param train_models (Union[list, None], optional): A list of values for the characteristic number density
        eligible for training. If not specified, it uses the grid 1e3 * 3**k cm^-3 for k = 0..8.

    :return: Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.Series, pd.Series, pd.Series]
        - x_test: Predictor variables for the test dataset.
        - x_train: Predictor variables for the training dataset.
        - x_validation: Predictor variables for the validation dataset.
        - y_test: Target variable (log10) for the test dataset.
        - y_train: Target variable (log10) for the training dataset.
        - y_validation: Target variable (log10) for the validation dataset.
    """
    _test_models = validate_parameter(test_models, default=[27000.0])
    _validation_models = validate_parameter(validation_models, default=[2187000.0])
    _predictor_columns = validate_parameter(
        predictor_columns,
        default=['nh2', 'tdust', 'avg_tdust', 'avg_nh2', 'molecule_column_density', 'std_nh2']
    )
    _train_models = validate_parameter(
        train_models,
        default=[
            1.000e+03,
            3.000e+03,
            9.000e+03,
            2.700e+04,
            8.100e+04,
            2.430e+05,
            7.290e+05,
            2.187e+06,
            6.561e+06
        ]
    )
    is_test = merged['nh2'].isin(_test_models)
    is_validation = merged['nh2'].isin(_validation_models)
    # Held-out models are masked out here, which is what keeps the three sets disjoint.
    is_train = merged['nh2'].isin(_train_models) & ~is_test & ~is_validation

    y = np.log10(merged[target_column])
    X = merged[_predictor_columns].copy()
    # All masks share merged's index, so no boolean-index realignment is needed.
    x_train, x_test, x_validation = (X[mask].reset_index(drop=True)
                                     for mask in (is_train, is_test, is_validation))
    y_train, y_test, y_validation = (y[mask].reset_index(drop=True)
                                     for mask in (is_train, is_test, is_validation))
    for x_df in (x_train, x_test, x_validation):
        for column in ('nh2', 'tdust'):
            # Only transform the columns that were actually selected as predictors.
            if column in x_df.columns:
                x_df[column] = np.log10(x_df[column])
    return x_test, x_train, x_validation, y_test, y_train, y_validation


def train_ml_model(model_type: str,
                   model_kwargs: Union[dict, None],
                   x_train: pd.DataFrame,
                   y_train: pd.Series,
                   use_validation: bool = False,
                   x_validation: Union[None, pd.DataFrame] = None,
                   y_validation: Union[None, pd.Series] = None) -> Union[XGBRegressor, RandomForestRegressor]:
    """
    Train a machine learning model of the requested type.

    Every supported model is fitted with per-sample weights from ``get_weights`` (inverse KDE density of the
    target). For the XGBoost-based models, ``use_validation`` additionally passes the validation data as
    ``eval_set``; RandomForestRegressor.fit has no such parameter, so the flag is ignored for it.
    :param model_type (str): Type of the machine learning model to be trained. Allowed values are 'XGBoost',
        'XGBoost_gridsearch', 'auto_skl', 'RandomForest'; 'auto_skl' is accepted by the check but raises a
        RuntimeError because it is not implemented yet.
    :param model_kwargs (Union[dict, None]): Keyword arguments for initializing the model (None means no arguments).
        For 'XGBoost_gridsearch' it must contain 'param_grid', and may contain 'param_gridsearch' with extra
        HalvingGridSearchCV arguments.
    :param x_train (pd.DataFrame): Predictor variables for the training dataset.
    :param y_train (pd.Series): Target variable for the training dataset.
    :param use_validation (bool, optional): Whether to pass the validation dataset to XGBoost as eval_set.
        Defaults to False.
    :param x_validation (Union[None, pd.DataFrame], optional): Predictor variables for the validation dataset.
        Defaults to None.
    :param y_validation (Union[None, pd.Series], optional): Target variable for the validation dataset.
        Defaults to None.

    :return: Union[XGBRegressor, RandomForestRegressor]
        Trained machine learning model (a fitted HalvingGridSearchCV for 'XGBoost_gridsearch').
    :raises RuntimeError: if the model type is allowed but not implemented.
    """
    allowed_model_types = ('XGBoost', 'XGBoost_gridsearch', 'auto_skl', 'RandomForest')
    assert model_type in allowed_model_types, \
        f'Unknown model_type {model_type!r}; allowed values are {allowed_model_types}'
    _model_kwargs = validate_parameter(model_kwargs, default={})
    if model_type == 'XGBoost':
        model = XGBRegressor(**_model_kwargs)
    elif model_type == 'XGBoost_gridsearch':
        model = HalvingGridSearchCV(XGBRegressor(),
                                    param_grid=_model_kwargs['param_grid'],
                                    **_model_kwargs.get('param_gridsearch', {}),
                                    refit=True,
                                    verbose=0)
    elif model_type == 'RandomForest':
        model = RandomForestRegressor(**_model_kwargs)
    else:
        raise RuntimeError('Model type requested not yet implemented')

    fit_kwargs = {'sample_weight': get_weights(data=y_train, show_weights=False)}
    # eval_set/verbose are XGBoost fit parameters (HalvingGridSearchCV forwards fit parameters to its estimator);
    # RandomForestRegressor.fit would raise a TypeError on them.
    if use_validation and model_type != 'RandomForest':
        fit_kwargs.update(eval_set=[(x_validation, y_validation)], verbose=0)
    model.fit(x_train, y_train, **fit_kwargs)
    return model


def create_similarity_columns(model_root_folder: str,
                              target_id: str,
                              similarity_bins: int,
                              x_dataframes: dict,
                              y_train: pd.Series) -> dict:
    """
    Add cosine-similarity columns computed against per-target-bin feature averages.

    The training target is cut into ``similarity_bins`` equal-width bins (``pd.cut``). Within each bin, the training
    features are averaged, excluding the 'mom_zero_*' and 'ratio_*' columns. Every dataframe in ``x_dataframes`` is
    then enriched with its cosine similarity to each bin average. Empty bins get an all-zero average, so their
    similarity column is 0. The table of averages is pickled to
    ``<model_root_folder>/trained_model/average_features_per_target_bin_<target_id>.pickle``, creating the folder if
    needed and overwriting any previous file.

    :param model_root_folder (str): The root folder where the trained model and other related files will be stored.
    :param target_id (str): Identifier for the target variable.
    :param similarity_bins (int): Number of bins for similarity calculation.
    :param x_dataframes (dict): Dictionary containing DataFrame objects for training, validation, and test sets; must
        contain 'x_train'.
    :param y_train (pd.Series): Target variable for the training dataset, index-aligned with x_dataframes['x_train'].

    :return: dict
        New dictionary (same keys) containing updated DataFrame objects with added similarity columns; the input
        dictionary is not modified.
    """
    output_folder = os.path.join(model_root_folder, 'trained_model')
    pickle_outfile_name = os.path.join(output_folder, f'average_features_per_target_bin_{target_id}.pickle')
    x_train = x_dataframes['x_train']
    columns_to_drop = [column for column in x_train.columns if column.startswith(('mom_zero_', 'ratio_'))]
    bins = pd.cut(y_train, similarity_bins)
    # observed=False keeps one row per bin, including empty ones: the similarity columns are named after
    # range(similarity_bins). Passing it explicitly avoids relying on pandas' deprecated implicit default.
    average_features_per_target_bin = (x_train.drop(columns=columns_to_drop)
                                       .groupby(bins, observed=False)
                                       .mean())
    # Empty bins average to NaN, which cosine_similarity rejects; a zero vector yields a similarity of 0 instead.
    average_features_per_target_bin = average_features_per_target_bin.fillna(0.0)
    updated_dataframes = {
        key: compute_and_add_similarity_cols(
            average_features_per_target_bin=average_features_per_target_bin,
            columns_to_drop=columns_to_drop,
            input_df=frame,
            similarity_bins=similarity_bins)
        for key, frame in x_dataframes.items()
    }
    os.makedirs(output_folder, exist_ok=True)
    # 'wb' truncates an existing file, so no separate removal step is needed.
    with open(pickle_outfile_name, 'wb') as outfile:
        pickle.dump(average_features_per_target_bin, outfile)
    return updated_dataframes


def train_models_wrapper(target_name: str,
                         model_type: str,
                         model_kwargs: dict,
                         logger: logging.Logger,
                         similarity_bins: Union[None, int] = 50,
                         limit_rows: int = None,
                         model_root_folder: Union[str, None] = None,
                         use_validation: bool = False):
    """
    Run the full training pipeline for one target variable.

    Steps: preprocess the data (``preprocess_data``), train the model (``train_ml_model``), pickle it to
    ``<model_root_folder>/trained_model/ml_model_<target_id>.pickle``, where ``target_id`` is the last
    '_'-separated token of ``target_name``, and finally report metrics and plots (``inspect_training``).

    :param target_name (str): Name of the target variable in the training set.
    :param model_type (str): Type of the machine learning model to be trained.
    :param model_kwargs (dict): Dictionary containing keyword arguments for initializing the model.
    :param logger (logging.Logger): Logger object for logging messages.
    :param similarity_bins (Union[None, int], optional): Number of bins for similarity calculation; None disables
        the similarity columns. Defaults to 50.
    :param limit_rows (int, optional): Limit the number of rows for preprocessing. Defaults to None.
    :param model_root_folder (Union[str, None], optional): Root folder where the trained model and other related files
        will be stored. If not specified, the fiducial model is assumed, i.e.
        'prs/output/run_type/constant_abundance_p15_q05'.
    :param use_validation (bool, optional): Whether to use an evaluation dataset during training. Defaults to False.
    """
    root_folder = validate_parameter(
        model_root_folder,
        default=os.path.join('prs', 'output', 'run_type', 'constant_abundance_p15_q05')
    )
    logger.info(f"Training for target {target_name}")
    target_id = target_name.split('_')[-1]
    features, targets = preprocess_data(
        model_root_folder=root_folder,
        similarity_bins=similarity_bins,
        target_name=target_name,
        limit_rows=limit_rows,
        logger=logger,
    )

    # NOTE(review): the *test* split (not 'x_validation'/'y_validation') is what is passed as the evaluation set;
    # confirm this is intended, since evaluating/early-stopping on it would bias the test MSE reported below.
    trained_model = train_ml_model(
        model_type=model_type,
        model_kwargs=model_kwargs,
        x_train=features['x_train'],
        y_train=targets['y_train'],
        x_validation=features['x_test'],
        y_validation=targets['y_test'],
        use_validation=use_validation,
    )
    model_path = os.path.join(root_folder, 'trained_model', f'ml_model_{target_id}.pickle')
    with open(model_path, 'wb') as model_file:
        pickle.dump(trained_model, model_file)
    inspect_training(model=trained_model, x_dataframes=features, y_series=targets, logger=logger)


def inspect_training(model: Union[XGBRegressor, RandomForestRegressor],
                     x_dataframes: dict,
                     y_series: dict,
                     logger: logging.Logger):
    """
    Inspect the training process results and evaluate the model.

    The mean squared error on the test, validation, and train splits is logged. A predicted-vs-true scatter plot is
    then shown, with one labelled series per split. Values are exponentiated (``10 **``) and drawn on log-log axes,
    so targets and predictions are presumably log10 quantities. The red identity line spans the full range of true
    values across all three splits.

    :param model (Union[XGBRegressor, RandomForestRegressor]): Trained machine learning model.
    :param x_dataframes (dict): Dictionary with 'x_train', 'x_validation', and 'x_test' DataFrames.
    :param y_series (dict): Dictionary with 'y_train', 'y_validation', and 'y_test' target Series.
    :param logger (logging.Logger): Logger object for logging messages.
    """
    splits = ('test', 'validation', 'train')
    predictions = {split: model.predict(x_dataframes[f'x_{split}']) for split in splits}
    errors = {split: mean_squared_error(y_series[f'y_{split}'], predictions[split]) for split in splits}
    logger.info(f"MSE {errors['test']}, MSE validation {errors['validation']}, MSE train {errors['train']}")
    plt.clf()
    for split in splits:
        plt.scatter(10 ** y_series[f'y_{split}'], 10 ** predictions[split], label=split)
    plt.loglog()
    # Span the identity line over every plotted split, not only the training one.
    true_values = np.concatenate([np.asarray(y_series[f'y_{split}'], dtype=float) for split in splits])
    y_low = 10 ** true_values.min()
    y_high = 10 ** true_values.max()
    plt.plot([y_low, y_high], [y_low, y_high], color='r')
    plt.legend()
    plt.show()


def preprocess_data(model_root_folder: str,
                    similarity_bins: Union[None, int],
                    target_name: str,
                    logger: logging.Logger,
                    limit_rows: int = None) -> Tuple[dict, dict]:
    """
    Preprocess the data for training machine learning models.

    The data set is loaded with ``get_data``, using the last component of ``model_root_folder`` as the model name,
    and split with ``split_data``. Cosine-similarity columns are added when ``similarity_bins`` is not None. All
    feature frames are then scaled with a QuantileTransformer fitted on the training set. The fitted scaler is pickled
    to ``<model_root_folder>/trained_model/ml_scaler_<target_id>.pickle``, where ``target_id`` is the last
    '_'-separated token of ``target_name``; the folder is created if missing.

    :param model_root_folder (str): Root folder where the trained model and other related files will be stored.
    :param similarity_bins (Union[None, int]): Number of bins for similarity calculation; None skips it.
    :param target_name (str): Name of the target variable.
    :param logger (logging.Logger): Logger object for logging messages.
    :param limit_rows (int, optional): Limit the number of rows for preprocessing. Defaults to None.
    :return: Tuple[dict, dict]
        - Dictionary containing DataFrame objects for training, validation, and test sets.
        - Dictionary containing target variable Series for training, validation, and test sets.
    """
    target_id = target_name.split('_')[-1]
    output_folder = os.path.join(model_root_folder, 'trained_model')
    # normpath drops trailing/redundant separators so the model name (last path component) is never empty.
    model_name = os.path.basename(os.path.normpath(model_root_folder))
    merged = get_data(limit_rows,
                      use_model_for_inference=model_name)
    x_test, x_train, x_validation, y_test, y_train, y_validation = split_data(merged=merged,
                                                                              target_column=target_name)
    x_dataframes = {
        'x_train': x_train,
        'x_test': x_test,
        'x_validation': x_validation
    }
    # Artifacts (similarity averages, scaler) are pickled into this folder below.
    os.makedirs(output_folder, exist_ok=True)
    if similarity_bins is not None:
        x_dataframes = create_similarity_columns(model_root_folder=model_root_folder,
                                                 target_id=target_id,
                                                 similarity_bins=similarity_bins,
                                                 x_dataframes=x_dataframes,
                                                 y_train=y_train)
    columns = x_dataframes['x_train'].columns
    logger.info(f'Feature used for training {columns}')
    # Fit on the training set only, then apply the same transformation to every split.
    scaler = QuantileTransformer(output_distribution='uniform')
    scaler.fit(X=x_dataframes['x_train'])
    x_dataframes = {key: pd.DataFrame(scaler.transform(X=frame), columns=columns)
                    for key, frame in x_dataframes.items()}
    logger.info(f'Train {len(x_train)}, test {len(x_test)}, validation {len(x_validation)}')

    with open(os.path.join(output_folder, f'ml_scaler_{target_id}.pickle'), 'wb') as outfile:
        pickle.dump(scaler, outfile)
    y_series = {
        'y_train': y_train,
        'y_test': y_test,
        'y_validation': y_validation
    }
    return x_dataframes, y_series
