import os
import pandas as pd
import seaborn as sns
import pickle
import numpy as np
import matplotlib.pyplot as plt
from astropy.constants import G, m_p
import astropy.units as u
from assets.utils import get_poc_results
# Raise matplotlib's global default font size; rcParams is process-wide, so this
# applies to every figure created after this module has been imported.
plt.rcParams.update({'font.size': 16})


def make_ecdf_figure(df_full: pd.DataFrame):
    """Plot the per-class ECDF of the best-fit volume density and save it as a PNG.

    One empirical cumulative distribution curve of the ``best_fit`` column is
    drawn for each value of the ``Class`` column, using the fixed class order
    Quiescent, Protostellar, MYSOs, HII. The x axis is logarithmic and limited
    to [1e5, 5e6]. The figure is written to
    ``prs/output/poc_figures/best_fit_nh2.png``, relative to the current
    working directory.

    :param df_full: table providing at least the ``best_fit`` and ``Class`` columns.
    """
    output_path = os.path.join('prs', 'output', 'poc_figures', 'best_fit_nh2.png')

    # Draw on a dedicated figure instead of clearing whatever figure happens to
    # be current; seaborn's axes-level functions plot on the current axes.
    fig = plt.figure(figsize=(8, 6))
    try:
        sns.ecdfplot(data=df_full, x='best_fit', hue='Class',
                     hue_order=['Quiescent', 'Protostellar', 'MYSOs', 'HII'])
        plt.semilogx()
        plt.xlim([1e5, 5e6])
        plt.xlabel('<$n$(H$_{2}$)> [cm$^{-3}$]')
        plt.ylabel('Cumulative probability')
        plt.title('Volume density ECDF per class')
        plt.tight_layout()
        fig.savefig(output_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def make_violin_figure(df_full: pd.DataFrame, nsamples: int = 10000):
    """Plot per-class violins of densities sampled from each source's posterior.

    The posteriors are read from
    ``prs/output/run_type/constant_abundance_p15_q05/posteriors.pickle``. For
    every source listed under its ``'PDF'`` key that also appears in
    ``df_full['source_name']``, ``nsamples`` densities are drawn with
    :func:`sampler` from that source's ``'ECDF'`` and ``'density_grid'``
    entries. All draws are pooled by the source's ``Class`` and shown as a
    violin plot, saved to ``prs/output/poc_figures/nh2_violin.png``.

    :param df_full: table providing at least the ``source_name`` and ``Class`` columns.
    :param nsamples: number of densities drawn per source.
    :raises ValueError: if no source in the posteriors file is present in ``df_full``.
    """
    posteriors_path = os.path.join('prs', 'output', 'run_type', 'constant_abundance_p15_q05', 'posteriors.pickle')
    # NOTE: unpickling can execute arbitrary code; only load files produced by
    # this project's own pipeline.
    with open(posteriors_path, 'rb') as infile:
        posteriors = pickle.load(infile)

    df_list = []
    for source in posteriors['PDF']:
        class_matches = df_full.loc[df_full['source_name'] == source, 'Class']
        if class_matches.empty:
            # The source has a posterior but is not in df_full (for example
            # because the caller filtered it out): skip it. Checking this
            # explicitly keeps real errors raised by the sampler visible.
            continue
        samples = sampler(ecdf=posteriors['ECDF'][source],
                          density_grid=posteriors['density_grid'][source],
                          nsamples=nsamples)
        df_list.append(pd.DataFrame({'source_name': source,
                                     'Class': class_matches.iloc[0],
                                     'density': samples}))

    if not df_list:
        raise ValueError('No source in the posteriors file is present in df_full; nothing to plot')

    # ignore_index avoids repeating the 0..nsamples-1 index once per source.
    df_generated_densities = pd.concat(df_list, ignore_index=True)

    fig = plt.figure(figsize=(8, 6))
    try:
        # NOTE(review): `bw` is reported as deprecated in recent seaborn
        # releases (replaced by `bw_method`); kept as-is to match the version
        # this script was written for - confirm against the installed seaborn.
        sns.violinplot(data=df_generated_densities,
                       x='Class',
                       y='density',
                       order=['Quiescent', 'Protostellar', 'MYSOs', 'HII'],
                       bw=0.2)
        plt.semilogy()
        plt.ylim([5e4, 5e6])
        plt.ylabel('<$n$(H$_{2}$)> [cm$^{-3}$]')
        plt.title('Bootstrapped volume density distributions')
        plt.tight_layout()
        fig.savefig(os.path.join('prs', 'output', 'poc_figures', 'nh2_violin.png'))
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def sampler(ecdf, density_grid, nsamples, rng=None):
    """Draw random densities from a tabulated distribution by inverse-transform sampling.

    For each uniform draw ``a`` in [0, 1) the first position ``i`` with
    ``ecdf[i] >= a`` is located, and the grid value one step below it,
    ``density_grid[i - 1]``, is returned. The index is clamped to the valid
    range of the grid, so a draw that is already reached by ``ecdf[0]`` maps
    to the first grid point instead of wrapping around to the last one, and a
    draw above every ECDF value maps to the last grid point.

    :param ecdf: 1-D array-like of cumulative probabilities; must be
        non-decreasing (required by the binary search).
    :param density_grid: 1-D array-like of densities matching ``ecdf``.
    :param nsamples: number of values to draw.
    :param rng: optional ``numpy.random.Generator``; pass a seeded generator
        for reproducible output. A fresh unseeded generator is used when omitted.
    :return: float array of shape ``(nsamples,)`` with the sampled densities.
    """
    if rng is None:
        rng = np.random.default_rng()
    ecdf = np.asarray(ecdf)
    density_grid = np.asarray(density_grid)

    uniform_draws = rng.uniform(0, 1, size=nsamples)
    # First index whose ECDF value reaches the draw; len(ecdf) if none does.
    first_reached = np.searchsorted(ecdf, uniform_draws, side='left')
    # Step back one grid point, but never below 0 (a negative index would wrap
    # to the end of the grid) nor beyond the last grid point.
    grid_indices = np.clip(first_reached - 1, 0, len(density_grid) - 1)
    # In principle this can be interpolated for a smoother density distribution
    return density_grid[grid_indices].astype(float)


def make_volume_densities_comparison_figure(df_full: pd.DataFrame):
    """Compare the best-fit volume densities with those listed in the SED table.

    ``assets/data/top100_density.csv`` is read (skipping its first four lines)
    and inner-joined to ``df_full`` on ``AGAL_name`` == ``agal_name``. For each
    matched source the CSV's ``volume_density`` (x) is plotted against
    ``best_fit`` (y) on log-log axes, with asymmetric vertical error bars taken
    from the last two entries of ``hpd_interval`` and a red 1:1 line from 2e4
    to 2e6. The figure is saved to
    ``prs/output/poc_figures/volume_density_comparison.png``.

    :param df_full: table providing at least the ``AGAL_name``, ``best_fit``
        and ``hpd_interval`` columns.
    """
    df_volume_density_sed = pd.read_csv(os.path.join('assets', 'data', 'top100_density.csv'),
                                        skiprows=4)
    df_merge = df_full.merge(df_volume_density_sed, left_on='AGAL_name', right_on='agal_name')

    # The error bars must come from the merged table: the inner join can drop
    # (or duplicate) rows of df_full, and yerr has to line up row by row with
    # the points that are actually plotted.
    # NOTE(review): the last two entries of each interval are presumably the
    # (lower, upper) bounds, which is the row order errorbar expects - confirm.
    hpd_bounds = np.array([interval[-2:] for interval in df_merge['hpd_interval']])
    best_fit = np.asarray(df_merge['best_fit'])
    uncertainty_array = np.abs(hpd_bounds - best_fit[:, np.newaxis]).T

    fig = plt.figure(figsize=(8, 6))
    try:
        plt.scatter(df_merge['volume_density'],
                    df_merge['best_fit'])
        plt.errorbar(df_merge['volume_density'],
                     df_merge['best_fit'],
                     yerr=uncertainty_array,
                     fmt='none')
        plt.plot([2e4, 2e6], [2e4, 2e6], color='red')
        plt.loglog()
        plt.xlabel('<$n$(H$_{2, dust}$)> [cm$^{-3}$]')
        plt.ylabel('<$n$(H$_{2, SAK}$)> [cm$^{-3}$]')
        plt.title('Comparison of the inferred volume density')
        plt.tight_layout()
        fig.savefig(os.path.join('prs', 'output', 'poc_figures', 'volume_density_comparison.png'))
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def main():
    """Load the proof-of-concept results and produce the three density figures.

    The comparison figure is made from the full table; the ECDF and violin
    figures only use the rows whose ``mass`` lies strictly between 300 and 10000.
    """
    results = get_poc_results(line_fit_filename='ch3oh_data_top35.csv')
    # The plotting helpers expect the class column to be called 'Class'.
    results.rename(columns={'class_phys': 'Class'}, inplace=True)
    make_volume_densities_comparison_figure(df_full=results)

    mass = results['mass']
    in_mass_range = (mass > 300) & (mass < 10000)
    selected = results[in_mass_range]
    make_ecdf_figure(df_full=selected)
    make_violin_figure(df_full=selected)


# Generate the figures only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
