import numpy as np
import os
import matplotlib.pyplot as plt
import pickle
from scipy.optimize import curve_fit
from assets.commons import (get_data,
                            load_config_file,
                            setup_logger,
                            validate_parameter)
import yaml
from typing import Union, List


def approx(x, x_shift, b, exponent, scale):
    """Evaluate ``(arctanh(2 * (x - x_shift) / scale) + b) ** exponent``.

    The arctanh argument is clipped to [-0.99999, 0.99999] so that points at or
    beyond the edge of the arctanh domain map to a large but finite value
    instead of +/-inf. NaN inputs propagate unchanged.

    :param x: abscissa (scalar or array).
    :param x_shift: horizontal offset subtracted from ``x``.
    :param b: constant added to the arctanh term before exponentiation.
    :param exponent: power applied to the shifted arctanh term.
    :param scale: width of the arctanh domain in units of ``x``.
    :return: the evaluated expression, with the same shape as ``x``.
    """
    limit = 0.99999
    scaled = 2 * (x - x_shift) / scale
    bounded = np.clip(scaled, -limit, limit)
    return (np.arctanh(bounded) + b) ** exponent


def _load_kde_max_locus(run_type: str, ratio_string: str):
    """Load the pickled KDE maximum locus for one ratio, closing the file afterwards."""
    path = f'prs/output/run_type/{run_type}/kde_smoothed_{ratio_string}_max_locus.pickle'
    with open(path, 'rb') as infile:
        return pickle.load(infile)


def _fit_branch(kde_max, guesses, limits, bounds):
    """Fit ``approx`` (in log10 density) to the finite locus points whose density lies in ``limits``.

    :param kde_max: pair indexed as ``kde_max[0]`` (densities) and ``kde_max[1]`` (ratios).
    :param guesses: initial (x_shift, b, exponent, scale) passed to ``curve_fit``.
    :param limits: inclusive (low, high) density interval selecting the branch.
    :param bounds: ``curve_fit`` parameter bounds.
    :return: best-fit parameters in (x_shift, b, exponent, scale) order.
    """
    densities, ratios = kde_max[0], kde_max[1]
    clean_mask = (~np.isnan(densities) & ~np.isnan(ratios)
                  & (densities >= limits[0]) & (densities <= limits[1]))
    best_fit, _ = curve_fit(approx,
                            ratios[clean_mask],
                            np.log10(densities[clean_mask]),
                            p0=guesses,
                            bounds=bounds)
    return best_fit


def main(ratios_to_fit: Union[List[str], None] = None):
    """Fit the analytical ``approx`` relation to the KDE maximum locus of each line ratio.

    For every requested ratio this function does the following:

    1. Load the pickled KDE maximum locus from
       ``prs/output/run_type/<run_type>/kde_smoothed_<ratio>_max_locus.pickle``.
    2. Fit ``approx`` (ratio -> log10 density) separately on each density branch
       listed in ``density_limits``.
    3. Save a comparison plot under ``../publications/.../referee/``.
    4. Write the accumulated best-fit parameters to a YAML file.

    :param ratios_to_fit: ratio identifiers such as ``'87-86'``. The value is
        resolved through ``validate_parameter``, whose default is built from
        ``config['overrides']['lines_to_process']``.
    """
    # `logger` is bound only by the `__main__` guard of this file; create it
    # here when main() is called after importing the module instead.
    log = globals().get('logger')
    if log is None:
        log = setup_logger(name='POLYFIT')

    config = load_config_file(config_file_path=os.path.join('config', 'config.yml'))
    _ratio_to_fit = validate_parameter(
        param_to_validate=ratios_to_fit,
        default=['-'.join(lines_ratio) for lines_ratio in config['overrides']['lines_to_process']])
    data = get_data(limit_rows=None,
                    use_model_for_inference=config['run_type'])
    log.info(f'Using {config["run_type"]} to fit data...')
    log.info(f'Data contains {data.shape[0]} rows')

    # Initial guesses (x_shift, b, exponent, scale) for `approx`, one tuple per
    # density branch; ratios with two tuples are fitted piecewise.
    x0 = {
        '87-86': ((0.55, 6.75, 0.9, 1),),
        '88-87': ((0.59, 6.9, 1.05, 0.96),),
        '88-86': ((0.55, 6.3, 0.83, 1.08),),
        '257-256': ((5, 4.45, 0.75, 8), (4.5, 8.5, 0.79, -7)),
        '381-380': ((2.04, 5.1, 0.6, 2.1), (2.18, 5.7, 0.79, -2.4)),
    }
    # Ratio interval over which the fitted curves are evaluated for plotting.
    min_ratio = {
        '87-86': 0.08,
        '88-87': 0.08,
        '88-86': 0.03,
        '257-256': 1,
        '381-380': 1}
    max_ratio = {
        '87-86': 1.04,
        '88-87': 1.02,
        '88-86': 1.07,
        '257-256': 7.5,
        '381-380': 2.9}
    # Density interval covered by each branch (same order as x0).
    density_limits = {
        '87-86': ((5e4, 2e7),),
        '88-87': ((1e5, 3e7),),
        '88-86': ((1e5, 3e7),),
        '257-256': ((2e3, 7.5e4), (7.5e4, 4e6)),
        '381-380': ((1e4, 1.9e5), (1.9e5, 1e7))
    }
    # Bounds shared by every fit, in (x_shift, b, exponent, scale) order.
    fit_bounds = ((0.1, 2, 0.3, -9), (7, 10, 1.5, 9))
    # Output names for the fitted (x_shift, b, exponent, scale).
    param_names = ('a', 'b', 'nu', 's')

    best_fit_params = {}

    for ratio_string in _ratio_to_fit:
        _column_to_fit = f'ratio_{ratio_string}'
        # NOTE(review): the sorted frame is not used below; kept because it
        # also raises if the ratio column is missing from the data.
        data.sort_values(by=_column_to_fit, inplace=True)
        plt.clf()
        x_reg = np.linspace(min_ratio[ratio_string], max_ratio[ratio_string], 1000)

        kde_max = _load_kde_max_locus(run_type=config['run_type'], ratio_string=ratio_string)

        # kde_max[1] (ratio) is the x axis; kde_max[0] (density) goes on the log y axis.
        plt.plot(kde_max[1], kde_max[0], color='red')
        plt.semilogy()
        plt.xlabel('Ratio')
        plt.ylabel('<n(H2)>')
        best_fit_params[ratio_string] = []
        for guesses, limits in zip(x0[ratio_string], density_limits[ratio_string]):
            best_fit = _fit_branch(kde_max=kde_max, guesses=guesses, limits=limits, bounds=fit_bounds)
            log.info(f'Best fit for {ratio_string}, density range {limits}: {best_fit}')
            best_fit_params[ratio_string].append(
                {name: float(value) for name, value in zip(param_names, best_fit)})
            approx_density = 10 ** approx(x_reg, *best_fit)
            # Only draw the curve where it stays inside this branch's density range.
            approx_density = np.where((approx_density >= limits[0]) & (approx_density <= limits[1]),
                                      approx_density, np.nan)
            plt.plot(x_reg, approx_density)
        plt.savefig(os.path.join('..',
                                 'publications',
                                 '6373bb408e4040043398e495',
                                 'referee',
                                 f'analytical_expressions_comparison_{ratio_string}.png'))
        # Rewritten after every ratio, so completed fits persist if a later ratio fails.
        # NOTE(review): unlike the pickle input path, this output directory is not
        # derived from config['run_type'] - confirm the hard-coded run is intended.
        with open(os.path.join('prs', 'output', 'run_type', 'constant_abundance_p15_q05',
                               'best_fit_params.yml'), 'w') as outfile:
            yaml.dump(best_fit_params, outfile)


if __name__ == '__main__':
    # Bind the module-level `logger` that `main` logs through, then run the fits.
    logger = setup_logger(name='POLYFIT')
    # ratios_to_fit=None lets validate_parameter fall back to the config-derived
    # default list (presumably lines_to_process - verify in assets.commons).
    main(ratios_to_fit=None)
