import pandas as pd
import os
import pickle
import numpy as np
from itertools import product
from assets.commons import (load_config_file,
                            validate_parameter,
                            setup_logger)
from assets.commons.parsing import parse_grid_overrides
from assets.commons.training_utils import compute_and_add_similarity_cols
from assets.commons.training_utils import plot_results, get_avg_profiles
from typing import Tuple, Union, List
from assets.utils import get_profile_input


def _load_pickle(path: str) -> object:
    """Deserialize and return the object stored in the pickle file at *path*."""
    # NOTE: unpickling can execute arbitrary code. Only point this at artefacts
    # written by the project's own training step, never at untrusted files.
    with open(path, 'rb') as infile:
        return pickle.load(infile)


def main(species: Union[str, None] = None,
         tdust: float = 21.0,
         nh2: float = 1e5,
         model_root_folder: Union[str, None] = None,
         ratios_to_process: Union[List[List[str]], None] = None) -> pd.DataFrame:
    """Predict line ratios for one (tdust, nh2) pair with the pickled models.

    For every requested ratio, the scaler, the model and the per-bin average
    features are read from ``<model_root_folder>/trained_model`` and applied
    to the features built from the averaged profiles.

    Parameters
    ----------
    species:
        Forwarded unchanged to ``get_profile_input``.
    tdust:
        Passed to ``get_avg_profiles`` as ``value_at_reference_temperature``;
        its log10 is stored in the ``tdust`` feature column.
    nh2:
        Passed to ``get_avg_profiles`` as ``value_at_reference_density``;
        its log10 is stored in the ``nh2`` feature column.
    model_root_folder:
        Folder containing the ``trained_model`` sub-folder. Handed to
        ``validate_parameter`` with the default
        ``prs/output/run_type/constant_abundance_p15_q05``.
    ratios_to_process:
        List of line pairs; each pair is joined with ``'-'`` to build the
        pickle file names and the ``ratio_<a>-<b>`` output column name.
        Handed to ``validate_parameter`` with a built-in default list.

    Returns
    -------
    pd.DataFrame
        Columns ``nh2``, ``tdust``, ``avg_tdust``, ``avg_nh2`` (all log10
        values), ``px_index`` and one ``ratio_<a>-<b>`` column per processed
        ratio holding ``10 ** model.predict(...)``.
    """
    _model_root_folder = validate_parameter(
        model_root_folder,
        default=os.path.join('prs', 'output', 'run_type', 'constant_abundance_p15_q05')
    )
    _ratios_to_process = validate_parameter(
        ratios_to_process,
        default=[['87', '86'], ['88', '87'], ['88', '86'], ['257', '256'], ['381', '380']]
    )
    overrides_grid, overrides_abundances = get_profile_input(
        species=species,
        model_root_folder=_model_root_folder
    )
    (avg_density_profile,
     avg_temperature_profile,
     molecule_coldens,
     std_density) = get_avg_profiles(
        value_at_reference_density=nh2,
        value_at_reference_temperature=tdust,
        **overrides_abundances,
        **overrides_grid
    )

    # All features are stored in log10 space. The column insertion order is
    # kept as in the original implementation in case a downstream helper
    # depends on it.
    x_interp = pd.DataFrame()
    x_interp['avg_nh2'] = np.log10(avg_density_profile)
    x_interp['avg_tdust'] = np.log10(avg_temperature_profile)
    x_interp['molecule_column_density'] = np.log10(molecule_coldens)
    # Remember the original row position before rows are dropped below.
    x_interp['px_index'] = x_interp.index.copy()
    x_interp['nh2'] = np.log10(nh2)
    x_interp['tdust'] = np.log10(tdust)
    x_interp['std_nh2'] = np.log10(std_density)

    # Rows without an average density cannot be used for the prediction.
    x_interp = x_interp.dropna(subset=['avg_nh2'])

    # The result frame gets a fresh 0..n-1 index so that the positional output
    # of model.predict() lines up with it row by row.
    results_df = x_interp[['px_index']].reset_index(drop=True)

    trained_model_dir = os.path.join(_model_root_folder, 'trained_model')
    ratio_string_list = []
    for line_ratio in _ratios_to_process:
        line_ratio_string = '-'.join(line_ratio)
        ratio_column = f'ratio_{line_ratio_string}'
        ratio_string_list.append(ratio_column)

        average_features_per_target_bin = _load_pickle(
            os.path.join(trained_model_dir,
                         f'average_features_per_target_bin_{line_ratio_string}.pickle')
        )
        x_interp = compute_and_add_similarity_cols(
            average_features_per_target_bin=average_features_per_target_bin,
            input_df=x_interp,
            columns_to_drop=['px_index'],
            similarity_bins=50
        )
        scaler = _load_pickle(
            os.path.join(trained_model_dir, f'ml_scaler_{line_ratio_string}.pickle')
        )
        model = _load_pickle(
            os.path.join(trained_model_dir, f'ml_model_{line_ratio_string}.pickle')
        )

        # Feed the model exactly the features (and order) it was fitted with.
        columns = model.feature_names_in_
        x_interp_transformed = pd.DataFrame(scaler.transform(x_interp[columns]),
                                            columns=columns)
        # The prediction is exponentiated with base 10 before being stored.
        results_df[ratio_column] = 10 ** model.predict(x_interp_transformed)

        # Remove the 'sim_*' columns in one go so the next ratio starts from
        # the base features again (presumably these are the columns added by
        # compute_and_add_similarity_cols -- confirm against that helper).
        similarity_columns = [column for column in x_interp.columns
                              if column.startswith('sim_')]
        x_interp = x_interp.drop(columns=similarity_columns)

    results_df = results_df.merge(x_interp, on=['px_index'], how='inner', validate='1:1')
    return results_df[['nh2', 'tdust', 'avg_tdust', 'avg_nh2', 'px_index'] + ratio_string_list]


if __name__ == '__main__':
    # Script entry point: run main() over the (tdust, nh2) grid read from
    # config/ml_modelling.yml, plot the emulated ratios and save them as CSV.
    config = load_config_file(config_file_path=os.path.join('config', 'config.yml'))
    external_input = load_config_file(config_file_path=os.path.join('config', 'ml_modelling.yml'))
    logger = setup_logger(name='PRS - ML modelling')
    # NOTE(review): limit_rows is read from the configuration but is not used
    # anywhere in this script -- confirm whether it can be removed.
    try:
        limit_rows = external_input['limit_rows']
    except KeyError:
        limit_rows = None
    try:
        use_model_for_training = external_input['use_model_for_training']
    except KeyError:
        use_model_for_training = None

    _use_model_for_training = validate_parameter(
        use_model_for_training,
        default='constant_abundance_p15_q05'
    )
    _model_root_folder = os.path.join('prs', 'output', 'run_type', _use_model_for_training)
    logger.info(f'Using {_use_model_for_training} for training in folder {_model_root_folder}')

    tdust_grid = parse_grid_overrides(par_name='tdust', config=external_input)
    nh2_grid = parse_grid_overrides(par_name='nh2', config=external_input)

    # Looked up once instead of on every iteration and again for plotting.
    lines_to_process = config['overrides']['lines_to_process']
    # Densities for which no emulated data are produced (presumably the values
    # already covered by the training grid -- confirm).
    excluded_nh2_values = (1e3, 3e3, 9e3, 27e3, 81e3, 243e3, 729e3, 2187e3)

    inferred_data_list = []
    for (tdust, nh2) in product(tdust_grid, nh2_grid):
        if nh2 in excluded_nh2_values:
            # Report the skip explicitly instead of announcing work that is
            # never performed.
            logger.info(f'Skipping emulation for {nh2}, {tdust}: density is in the excluded list')
            continue
        logger.info(f'Producing emulated data for {nh2}, {tdust}')
        inferred_data_list.append(
            main(nh2=nh2,
                 tdust=tdust,
                 model_root_folder=_model_root_folder,
                 ratios_to_process=lines_to_process)
        )

    if not inferred_data_list:
        # pd.concat() would fail with an opaque "No objects to concatenate".
        raise ValueError(
            'No emulated data were produced: every (tdust, nh2) combination of the '
            'configured grid was skipped or the grid is empty'
        )

    inferred_data = pd.concat(inferred_data_list, axis=0, ignore_index=True)
    plot_results(inferred_data=inferred_data,
                 use_model_for_inference=_use_model_for_training,
                 ratios_to_process=lines_to_process)
    # Only avg_nh2 is converted back from log10 before saving; this happens
    # after plotting, so plot_results still receives the log10 values.
    inferred_data['avg_nh2'] = 10 ** inferred_data['avg_nh2']
    output_folder = os.path.join(_model_root_folder, 'data')
    # Make sure the destination exists so the run does not fail at the very end.
    os.makedirs(output_folder, exist_ok=True)
    inferred_data.to_csv(os.path.join(output_folder, 'inferred_data.csv'))
    logger.info('Completed emulation')
