import numpy as np
import os
import matplotlib.pyplot as plt
import pickle
from scipy.interpolate import splrep, BSpline, RegularGridInterpolator
from scipy.optimize import curve_fit
from assets.commons import (get_data,
                            load_config_file,
                            setup_logger,
                            validate_parameter)
from typing import Union, List


def approx(x, x_shift, a, b, exponent, scale=1):
    """Evaluate ``|arctanh(2 * (x - x_shift) / scale) + a| ** exponent + b``.

    Works on scalars and NumPy arrays alike (element-wise for arrays).

    :param x: point(s) at which the curve is evaluated.
    :param x_shift: value subtracted from ``x`` before scaling.
    :param a: offset added to the arctanh term before taking the absolute value.
    :param b: offset added to the final result.
    :param exponent: power applied to the absolute value.
    :param scale: divisor of the shifted ``x``; defaults to 1.
    :return: the evaluated curve, same shape as ``x``.
    """
    # For real input np.arctanh is only finite when its argument lies strictly
    # inside (-1, 1), i.e. for |x - x_shift| < scale / 2.
    arctanh_argument = 2 * (x - x_shift) / scale
    magnitude = np.absolute(np.arctanh(arctanh_argument) + a)
    return magnitude ** exponent + b


def main(ratios_to_fit: Union[List[str], None] = None):
    """Show, for each ratio, the data scatter with the hand-tuned ``approx`` curves.

    For every requested ratio one figure is displayed containing all data
    points, the points that survive a density-based cut (red) and the curve(s)
    obtained by evaluating ``approx`` with the hard-coded parameter sets in
    ``x0`` (cyan). Nothing is fitted, saved or returned.

    :param ratios_to_fit: ratio identifiers such as ``'87-86'``. The value is
        passed through ``validate_parameter`` together with a default built
        from ``config['overrides']['lines_to_process']`` (presumably used when
        this argument is None -- confirm against ``validate_parameter``).
    """
    # ``logger`` is only bound at module level when the file runs as a script;
    # create one on demand so that main() also works after a plain import.
    try:
        log = logger
    except NameError:
        log = setup_logger(name='POLYFIT')

    config = load_config_file(config_file_path=os.path.join('config', 'config.yml'))
    _ratio_to_fit = validate_parameter(
        param_to_validate=ratios_to_fit,
        default=['-'.join(lines_ratio) for lines_ratio in config['overrides']['lines_to_process']])
    data = get_data(limit_rows=None,
                    use_model_for_inference=config['run_type'])
    log.info(f'Using {config["run_type"]} to fit data...')
    log.info(f'Data contains {data.shape[0]} rows')

    # Per-ratio factor applied to the 'y' axis stored in the KDE pickle.
    bw = {
        '87-86': 0.1,
        '88-87': 0.1,
        '88-86': 0.1,
        '257-256': 0.4,
        '381-380': 0.2}

    # Per-ratio parameter sets for ``approx``; each inner tuple is one curve,
    # ordered as (x_shift, a, b, exponent[, scale]).
    x0 = {
        '87-86': ((0.55, 2.3, 3.75, 0.9),),
        '88-87': ((0.59, 2.3, 3.9, 1.05, 0.96),),
        '88-86': ((0.55, 2.5, 4.3, 0.83, 1.08),),
        '257-256': ((5, 2.5, 2.45, 0.75, 8), (5, 2.5, 3.33, 0.75, -8)),
        '381-380': ((2.04, 2.5, 3.1, 0.6, 2.1), (2.18, 2.5, 3.7, 0.79, -2.4)),
    }
    # Lower / upper end of the ratio interval on which the curves are drawn;
    # ``max_ratio`` is also the cut applied to the data before the density filter.
    min_ratio = {
        '87-86': 0.08,
        '88-87': 0.08,
        '88-86': 0.03,
        '257-256': 1,
        '381-380': 1}
    max_ratio = {
        '87-86': 1.04,
        '88-87': 1.02,
        '88-86': 1.07,
        '257-256': 7.5,
        '381-380': 2.9}

    for ratio_string in _ratio_to_fit:
        _column_to_fit = f'ratio_{ratio_string}'
        data.sort_values(by=_column_to_fit, inplace=True)
        plt.clf()
        plt.scatter(data[_column_to_fit], data['avg_nh2'], marker='+', alpha=0.1)
        x_reg = np.linspace(min_ratio[ratio_string], max_ratio[ratio_string], 100)

        data_clean = data[[_column_to_fit, 'avg_nh2']].loc[data[_column_to_fit] <= max_ratio[ratio_string]].copy()

        kde_path = f'prs/output/run_type/{config["run_type"]}/trained_model/ratio_density_kde_{ratio_string}.pickle'
        # NOTE: unpickling can execute arbitrary code, so this file must be a
        # trusted, locally produced artefact. The context manager guarantees the
        # handle is closed even if loading fails.
        with open(kde_path, 'rb') as kde_file:
            kde = pickle.load(kde_file)

        # The stored values are interpreted as a 200 x 200 grid; after the
        # transpose the first axis pairs with kde['y'] and the second with kde['x'].
        rgi = RegularGridInterpolator((kde['y'] * bw[ratio_string], kde['x'] * 0.2), kde['values'].reshape(200, 200).T)
        density = rgi(data_clean[[_column_to_fit, 'avg_nh2']])
        # Keep only the points whose interpolated density is above
        # (mean - one standard deviation) of the sample.
        data_clean = data_clean.loc[(density > density.mean() - density.std())]
        plt.scatter(data_clean[_column_to_fit], data_clean['avg_nh2'], marker='+', alpha=0.1, color='red')
        for components in x0[ratio_string]:
            plt.plot(x_reg, approx(x_reg, *components), color='cyan')
        plt.show()


if __name__ == '__main__':
    # Script entry point. The logger is bound at module level here because
    # main() looks it up as a global named ``logger``.
    logger = setup_logger(name='POLYFIT')
    # No explicit ratio list is supplied, so main() receives None.
    main(ratios_to_fit=None)
