import os
import matplotlib.pyplot as plt
import pickle
from assets.commons import (load_config_file,
                            setup_logger,
                            validate_parameter)

# Module-level logger for this figure script; used by the __main__ entry point
# to report configuration fallbacks.
logger = setup_logger(
    name='PRS - FIGURE POSTERIOR IMPROVEMENT',
)


def _load_posterior(model_name, file_name, source_name):
    """Load the ``(density_grid, PDF)`` posterior of one source from a pickled results file.

    Args:
        model_name: model sub-directory under ``prs/output/run_type``.
        file_name: name of the pickle file inside that directory.
        source_name: key identifying the source inside the pickled ``'density_grid'``
            and ``'PDF'`` dictionaries.

    Returns:
        A ``(density_values, pdf_values)`` tuple for ``source_name``.

    Raises:
        KeyError: if the file lacks the ``'density_grid'``/``'PDF'`` entries or the source.
    """
    file_path = os.path.join('prs', 'output', 'run_type', model_name, file_name)
    # NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    # only point this at pipeline-produced files you trust.
    with open(file_path, 'rb') as infile:
        posteriors_dict = pickle.load(infile)
    try:
        return posteriors_dict['density_grid'][source_name], posteriors_dict['PDF'][source_name]
    except KeyError as err:
        raise KeyError(f'Key {err} not found in {file_path} while reading source {source_name!r}') from err


def main(use_model_for_inference, source_name):
    """Plot the ML posterior against the RT-only posterior for a single source.

    Both posteriors are read from ``prs/output/run_type/<model>/posteriors.pickle`` and
    ``posteriors_rt_only.pickle``. The comparison figure is written to
    ``prs/output/comparison_figures/posterior_improvement_<source_name>.png``.

    Args:
        use_model_for_inference: name of the model run to read. It is passed through
            ``validate_parameter`` with default ``'constant_abundance_p15_q05'``
            (presumably used when this is None — see the ``__main__`` block).
        source_name: key of the source inside the pickled posterior dictionaries.

    Raises:
        KeyError: if either pickle lacks the expected entries for ``source_name``.
    """
    _use_model_for_inference = validate_parameter(use_model_for_inference, default='constant_abundance_p15_q05')
    density_values, pdf_values = _load_posterior(_use_model_for_inference, 'posteriors.pickle', source_name)
    density_values_rt_only, pdf_values_rt_only = _load_posterior(
        _use_model_for_inference, 'posteriors_rt_only.pickle', source_name
    )

    output_dir = os.path.join('prs', 'output', 'comparison_figures')
    # Create the destination on first use so savefig does not fail on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)

    # Draw on a dedicated figure and close it afterwards, so repeated calls (e.g. a loop
    # over sources) neither reuse stale global state nor accumulate open figures.
    fig, ax = plt.subplots()
    try:
        ax.plot(density_values_rt_only, pdf_values_rt_only, label='RT only')
        ax.plot(density_values, pdf_values, label='ML')
        ax.set_title(f'Posterior improvement for {source_name}')
        ax.legend(loc='best')
        ax.set_xlabel('<$n$(H$_{2}$)> [cm$^{-3}$]')
        ax.set_ylabel('Probability density')
        ax.set_xscale('log')
        ax.set_xlim(2e2, 2e6)
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, f'posterior_improvement_{source_name}.png'))
    finally:
        plt.close(fig)


if __name__ == '__main__':
    # Script entry point: read the model choice from the YAML config and plot one source.
    external_input = load_config_file(config_file_path=os.path.join('config', 'density_inference_input.yml'))
    try:
        use_model_for_inference = external_input['use_model_for_inference']
    except KeyError:
        # Key absent: pass None and let main() apply the default it hands to validate_parameter.
        logger.warning("No 'use_model_for_inference' key in density_inference_input.yml; "
                       "deferring to the default in main().")
        use_model_for_inference = None
    else:
        # 'PLACEHOLDER' means the config was never filled in; use the fiducial model.
        if use_model_for_inference == 'PLACEHOLDER':
            logger.warning('No model specified for inference in density_inference_input.yml. Using fiducial.')
            use_model_for_inference = 'constant_abundance_p15_q05'

    main(use_model_for_inference=use_model_for_inference, source_name='G08.71-0.41')
