import operator
import pandas as pd
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
from itertools import product
import yaml
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV
from assets.commons import (load_config_file,
                            validate_parameter,
                            setup_logger,
                            get_postprocessed_data)
from assets.constants import line_ratio_mapping
from multiprocessing import Pool
from typing import Tuple, Union, List
from scipy.integrate import cumtrapz
from scipy.stats import truncnorm
from scipy.interpolate import RegularGridInterpolator
from functools import reduce


def get_truncated_normal(mean: float = 0,
                         sd: float = 1,
                         lower_bound: float = 0,
                         upper_bound: float = 10):
    """
    Build a frozen truncated normal distribution.
    :param mean: the mean of the underlying (untruncated) normal distribution
    :param sd: the standard deviation of the underlying normal distribution
    :param lower_bound: the lower truncation limit, in the same units as mean
    :param upper_bound: the upper truncation limit, in the same units as mean
    :return: a frozen scipy.stats.truncnorm distribution (e.g. to draw samples with .rvs())
    """
    # truncnorm expects the truncation limits expressed in standard deviations from the mean
    standardized_lower = (lower_bound - mean) / sd
    standardized_upper = (upper_bound - mean) / sd
    return truncnorm(standardized_lower, standardized_upper, loc=mean, scale=sd)


def train_kde_model(ratio: List[str],
                    training_data: pd.DataFrame,
                    points_per_axis: int,
                    best_bandwidth: Union[None, dict],
                    model_root_folder: Union[str, None] = None,
                    ratio_limits: Union[None, dict] = None,
                    rt_adjustment_factor: Union[int, float] = 2) -> Tuple[str, KernelDensity]:
    """
    Create the KDE from the modelled datapoints, plot it and pickle its evaluation on the grid.

    Two KDEs are evaluated on the same (log10 <n(H2)>, ratio) grid: one fitted on all the training data (plotted with
    the '_ml' suffix) and one fitted on the RT-generated rows only (source == 'RT'), whose bandwidths are widened by
    rt_adjustment_factor. Both evaluations are pickled to
    <model_root_folder>/trained_model/ratio_density_kde_<ratio>.pickle.
    :param ratio: the line ratio to use in the model, in the form of a list of 2 line indices, according to the
        collisional coefficients file
    :param training_data: the dataframe containing the data for the KDE modelling; it must contain the 'avg_nh2',
        'source' and 'ratio_<ratio>' columns
    :param points_per_axis: number of points per axis of the KDE evaluation grid
    :param best_bandwidth: mapping from ratio string to the bandwidth of the ratio axis; when None, the get_kde
        default (0.2) is used
    :param model_root_folder: the folder referring to the model used for computation;
        defaults to fiducial model (constant_abundance_p15_q05)
    :param ratio_limits: mapping from ratio string to fixed plotting limits of the ratio axis; ratios missing from the
        mapping (or a None mapping) are auto-scaled
    :param rt_adjustment_factor: a scaling factor to apply to the specified bandwidth for the computation of the KDE
        for the sparser RT-only data
    :return: the string representing the ratio modelled and the KDE fitted on all the training data
    """
    _model_root_folder = validate_parameter(
        model_root_folder,
        default=os.path.join('prs', 'output', 'run_type', 'constant_abundance_p15_q05')
    )
    ratio_string = '-'.join(ratio)
    # Resolve the bandwidth with the same fallback used by get_kde, because the grid is rescaled by this very value
    _bandwidth = validate_parameter(None if best_bandwidth is None else best_bandwidth[ratio_string], 0.2)
    _ratio_limits = None if ratio_limits is None else ratio_limits.get(ratio_string)

    scaled_grid, model, positions, x, y, z = get_kde(points_per_axis=points_per_axis,
                                                     ratio_string=ratio_string,
                                                     training_data=training_data,
                                                     best_bandwidth=_bandwidth)
    # get_kde works in bandwidth-scaled coordinates: convert back to log10 <n(H2)> (x bandwidth 0.2) and ratio.
    # A new list is built, so this works whether meshgrid returned a list or a tuple.
    grid = [scaled_grid[0] * 0.2, scaled_grid[1] * _bandwidth]
    plot_kde_ratio_nh2(grid=grid,
                       values_on_grid=z.reshape(points_per_axis, points_per_axis),
                       ratio_string=ratio_string,
                       model_root_folder=_model_root_folder,
                       training_data=training_data,
                       suffix_outfile='_ml',
                       ratio_limits=_ratio_limits)

    rt_only_data = training_data[training_data['source'] == 'RT']
    # Dividing the scaled axes by the factor while get_kde multiplies both bandwidths by the same factor evaluates the
    # RT-only KDE on exactly the same physical points, so the physical grid above is reused unchanged.
    _, _, _, _, _, z_rt_only = get_kde(points_per_axis=points_per_axis,
                                       ratio_string=ratio_string,
                                       training_data=rt_only_data,
                                       best_bandwidth=_bandwidth,
                                       x=x / rt_adjustment_factor,
                                       y=y / rt_adjustment_factor,
                                       bw_adjustment_factor=rt_adjustment_factor)
    plot_kde_ratio_nh2(grid=grid,
                       values_on_grid=z_rt_only.reshape(points_per_axis, points_per_axis),
                       ratio_string=ratio_string,
                       model_root_folder=_model_root_folder,
                       training_data=rt_only_data,
                       ratio_limits=_ratio_limits)
    with open(
            os.path.join(_model_root_folder, 'trained_model', f'ratio_density_kde_{ratio_string}.pickle'), 'wb'
    ) as outfile:
        # x, y and positions are in the scaled coordinates of the all-data KDE (x / 0.2, y / bandwidth)
        pickle.dump({'x': x, 'y': y, 'positions': positions, 'values': z, 'values_rt_only': z_rt_only}, outfile)
    return ratio_string, model


def get_kde(points_per_axis: int,
            ratio_string: str,
            training_data: pd.DataFrame,
            x: np.array = None,
            y: np.array = None,
            best_bandwidth: float = None,
            bw_adjustment_factor: Union[float, int] = 1) -> tuple:
    """
        Compute the Kernel Density Estimate (KDE) of (log10 avg_nh2, ratio) for a given ratio and training data.

        Each axis is divided by its own bandwidth before fitting (0.2 for log10 avg_nh2, best_bandwidth for the ratio,
        both multiplied by bw_adjustment_factor), and an Epanechnikov KDE with unit bandwidth is fitted in these
        scaled coordinates. The grid axes x and y are therefore expressed in scaled coordinates too.

        :param points_per_axis: Number of points to use along each axis for the KDE grid.
        :param ratio_string: The ratio string indicating which ratio of the training data to use.
        :param training_data: The DataFrame containing the training data ('avg_nh2' and 'ratio_<ratio_string>').
        :param x: Optional. The scaled x-axis values for the KDE grid. Computed from the data range if None.
        :param y: Optional. The scaled y-axis values for the KDE grid. Computed from the data range if None.
        :param best_bandwidth: The best bandwidth to use for KDE. Defaults to 0.2 if None.
        :param bw_adjustment_factor: The adjustment factor to apply to the bandwidth. Defaults to 1.
        :return: A tuple containing the grid, the KDE model, the positions, x-axis values, y-axis values, and the
            computed KDE values, normalized to a maximum of 1.
    """
    _best_bandwidth = bw_adjustment_factor * validate_parameter(best_bandwidth, 0.2)
    _x_bandwidth = bw_adjustment_factor * 0.2
    log_nh2 = np.log10(training_data['avg_nh2'])
    ratio_values = training_data[f'ratio_{ratio_string}']
    xy_train = np.array([log_nh2, ratio_values]).T
    xy_train_scaled = xy_train / np.array([_x_bandwidth, _best_bandwidth])
    # The grid search suggests a bandwidth of 0.1, but I need to make it larger due to the relatively large spacing
    # between models. The bandwidths are applied via the scaling above, hence the unit bandwidth here.
    model = KernelDensity(bandwidth=1, kernel='epanechnikov')
    model.fit(xy_train_scaled)
    # Each axis defaults independently, so callers may fix one axis and let the other follow the data range
    if x is None:
        xmin = log_nh2.min() - 0.5
        xmax = log_nh2.max() + 0.5
        x = np.linspace(xmin, xmax, points_per_axis) / _x_bandwidth
    if y is None:
        ymin = np.max([ratio_values.min() - 0.5, 0])
        ymax = ratio_values.max() * 1.15
        y = np.linspace(ymin, ymax, points_per_axis) / _best_bandwidth
    grid = np.meshgrid(x, y, indexing='ij')
    positions = np.vstack([grid[0].ravel(), grid[1].ravel()])
    # Small floor before normalizing, so the density is strictly positive everywhere
    z = np.exp(model.score_samples(positions.T)) + 1e-5
    z /= z.max()
    return grid, model, positions, x, y, z


def plot_kde_ratio_nh2(grid: np.array,
                       values_on_grid: np.array,
                       ratio_string: str,
                       model_root_folder: str,
                       training_data: pd.DataFrame,
                       suffix_outfile: str = None,
                       ratio_limits: Union[None, list] = None):
    """
        Plot the Kernel Density Estimate (KDE) of a ratio against average H2 density along the line-of-sight and save
         the plot as a PNG file.

        :param grid: The grid of x (log10 density) and y (ratio) values used for the KDE.
        :param values_on_grid: The computed KDE values on the grid.
        :param ratio_string: The ratio string indicating which ratio of the training data to plot.
        :param model_root_folder: The root folder where the model and figures are stored.
        :param training_data: The DataFrame containing the training data.
        :param suffix_outfile: Optional. The suffix to append to the output file name. Defaults to an empty string if None.
        :param ratio_limits: Optional. The limits for the ratio axis. Defaults to None, which auto-scales the axis.
        :return: None. Saves the plot as a PNG file in the specified folder.
    """
    plt.rcParams.update({'font.size': 20})
    _suffix_outfile = validate_parameter(suffix_outfile, default='')
    fig = plt.figure(figsize=(8, 6))
    try:
        plt.scatter(training_data['avg_nh2'], training_data[f'ratio_{ratio_string}'], marker='+', alpha=0.1,
                    facecolor='grey')
        plt.contour(10 ** grid[0], grid[1], values_on_grid, levels=np.arange(0.05, 0.95, 0.15))
        plt.semilogx()
        plt.xlabel(r'<$n$(H$_2$)> [cm$^{-3}$]')
        plt.ylabel(f'Ratio {line_ratio_mapping[ratio_string]}')
        plt.ylim(ratio_limits)
        plt.tight_layout()
        plt.savefig(os.path.join(
            model_root_folder,
            'figures',
            f'ratio_vs_avg_density_los_kde_{ratio_string}{_suffix_outfile}.png'))
    finally:
        # pyplot keeps every figure alive until it is closed; this function runs several times per ratio
        plt.close(fig)


def get_results(x_array: np.array,
                probability_density: np.array,
                probability_threshold: float = 0.05,
                interp_points: int = 100) -> Tuple[float, float, List[float]]:
    """
    Summarize a posterior: compute the best-fit value and the highest probability density (HPD) interval.
    :param x_array: the density grid on which probability_density is evaluated
    :param probability_density: the probability density values
    :param probability_threshold: the probability mass left out in the wings of the HPD
    :param interp_points: the number of points used to interpolate the probability density
    :return: the probability density threshold, the best-fit value for density, and the HPD interval. The best fit
        is the left edge of the grid cell with the highest average probability density.
    """
    bin_mass, bin_density, ascending_order = get_probability_distribution(probability_density=probability_density,
                                                                          x_array=x_array)
    # Accumulate the mass of the least probable cells first, to find the density level bounding the tails
    density_cut = get_probability_density_threshold(
        ordered_probability=bin_mass[ascending_order],
        ordered_probability_density=bin_density[ascending_order],
        probability_threshold=probability_threshold
    )
    interval = get_hpd_interval(x_array=x_array,
                                probability_density=probability_density,
                                hpd_threshold=density_cut,
                                interp_points=interp_points)
    best_fit = x_array[ascending_order][-1]
    return density_cut, best_fit, interval


def get_probability_density_threshold(ordered_probability: np.array,
                                      ordered_probability_density: np.array,
                                      probability_threshold=0.05) -> float:
    """
    Compute the threshold in probability density that leaves out the requested probability mass.

    The cumulative mass of the cells (sorted by increasing density) is matched to probability_threshold. The match is
    refined by linear interpolation over the closest index and its two neighbours (200 sub-steps), and the density
    is interpolated at the refined position.
    :param ordered_probability: the probability array, ordered according to increasing probability density
    :param ordered_probability_density: the ordered probability density array, in ascending order
    :param probability_threshold: the probability mass in the tails, outside the HPD
    :return: the threshold in HPD that leaves the specified probability mass in the tails
    """
    cumulative_mass = np.cumsum(ordered_probability)
    centre = int(np.argmin(np.abs(cumulative_mass - probability_threshold)))
    # Neighbours are clipped to the array, so at the edges an index may repeat
    neighbours = [max(centre - 1, 0), centre, min(centre + 1, len(cumulative_mass) - 1)]
    fractional_index = np.linspace(neighbours[0], neighbours[-1], 200)
    refined_mass = np.interp(fractional_index, neighbours, np.take(cumulative_mass, neighbours))
    refined_density = np.interp(fractional_index, neighbours, np.take(ordered_probability_density, neighbours))
    best_match = np.argmin(np.abs(refined_mass - probability_threshold))
    return refined_density[best_match]


def get_hpd_interval(x_array: np.array,
                     probability_density: np.array,
                     hpd_threshold: float,
                     interp_points: int = 100) -> List[float]:
    """
    Compute the highest probability density (HPD) region, i.e. where probability_density >= hpd_threshold.

    The region may consist of several disjoint intervals. Their boundaries are returned flattened, as
    [start_1, end_1, start_2, end_2, ...]. Each boundary is refined by linear interpolation between the two grid
    points bracketing the threshold crossing. A region touching an edge of the grid is bounded by that edge.
    :param x_array: the density array
    :param probability_density: the probability density array, evaluated on x_array
    :param hpd_threshold: the probability density threshold
    :param interp_points: the number of points used to interpolate the probability density between grid points
    :return: a flat list of the HPD boundaries (can be composed of multiple intervals); empty if no point reaches
        the threshold
    """
    above_threshold = np.flatnonzero(probability_density - hpd_threshold >= 0)
    if above_threshold.size == 0:
        return []
    # Split the indices into runs of consecutive grid points, one run per disjoint interval
    runs = np.split(above_threshold, np.flatnonzero(np.diff(above_threshold) != 1) + 1)
    # The lower crossing of a run lies in the cell (first - 1, first), the upper one in the cell (last, last + 1)
    cell_starts = [idx for run in runs for idx in (run[0] - 1, run[-1])]

    hpd_interval = []
    for idx in cell_starts:
        if idx < 0:
            hpd_interval.append(x_array[0])
        elif idx == len(x_array) - 1:
            hpd_interval.append(x_array[-1])
        else:
            x_interp = np.linspace(x_array[idx], x_array[idx + 1], interp_points)
            interpolated_probability_density = np.interp(x_interp, x_array[idx:idx + 2],
                                                         probability_density[idx:idx + 2])
            interp_idx = np.argmin(np.abs(interpolated_probability_density - hpd_threshold))
            hpd_interval.append(x_interp[interp_idx])
    return hpd_interval


def get_probability_distribution(probability_density: np.array,
                                 x_array: np.array) -> Tuple[np.array, np.array, np.array]:
    """
    Compute the probability mass of each grid cell with the trapezoidal rule.
    :param probability_density: the probability density array
    :param x_array: the density array
    :return: the normalized probability mass of each cell (summing to 1), the cell-averaged probability density
        (one element fewer than x_array), and the indices sorting the cell-averaged density in ascending order
    """
    cell_widths = np.diff(x_array)
    # Mean of the density at the two edges of each cell
    cell_density = (probability_density[:-1] + probability_density[1:]) * 0.5
    cell_mass = cell_density * cell_widths
    cell_mass = cell_mass / cell_mass.sum()
    return cell_mass, cell_density, np.argsort(cell_density)


def compute_best_bandwidth(data: pd.DataFrame,
                           bandwidths: List[float],
                           line_pairs: List[List[str]],
                           nthreads: int = 15,
                           crossvalidation_folds: int = 5,
                           training_thinning: int = 10,
                           kernel: str = None) -> List[float]:
    """
    Compute the best bandwidth for the KDE of each ratio via cross-validation. Consider that, due to the rough
    spacing in the characteristic density of the clumps, the best bandwidth is too small. We suggest scaling it up by
    a factor of 2.
    :param data: the data to be used for the cross validation; it must contain a 'ratio_<ratio>' column per ratio
    :param bandwidths: the list of bandwidths to be tested
    :param line_pairs: the list of ratios for which the best bandwidth should be determined
    :param nthreads: the number of threads to use
    :param crossvalidation_folds: the number of folds for cross validation
    :param training_thinning: thinning factor for training data set (reduces computational times); a random fraction
        1 / training_thinning of the rows is used
    :param kernel: the kernel type, as supported by sklearn; defaults to epanechnikov
    :return: the list of best bandwidths, one for each ratio, in the order of line_pairs
    """
    _kernel = validate_parameter(kernel, 'epanechnikov')
    best_bandwidths = []
    for ratio in line_pairs:
        ratio_string = '-'.join(ratio)
        y_train = data[[f'ratio_{ratio_string}']]
        model_y = KernelDensity(kernel=_kernel)
        cv_grid_y = GridSearchCV(model_y, {'bandwidth': bandwidths}, cv=crossvalidation_folds, n_jobs=nthreads,
                                 verbose=4)
        cv_grid_y.fit(y_train.sample(frac=1 / training_thinning))
        # best_params_ holds the selected bandwidth value; best_estimator_ would be the refitted KernelDensity object
        best_bandwidths.append(cv_grid_y.best_params_['bandwidth'])
    return best_bandwidths


def recompute_and_save_kdes(data: pd.DataFrame,
                            line_pairs: List[List[str]],
                            points_per_axis: int,
                            nthreads: int = 10,
                            pickle_models_dict_filename: Union[None, str] = None,
                            model_root_folder: Union[None, str] = None,
                            best_bandwidths: Union[None, dict] = None,
                            ratio_limits: Union[None, dict] = None):
    """
    Train one KDE per line pair in parallel (via train_kde_model) and pickle the dictionary of fitted models.
    :param data: the data to be used for retraining
    :param line_pairs: all the lines pairs that constitute the ratios
    :param points_per_axis: number of points for the KDE grid evaluation
    :param nthreads: the number of worker processes to use
    :param pickle_models_dict_filename: the filename of the pickle file containing the model dictionary, written to
        <model_root_folder>/trained_model; defaults to models_dict.pickle
    :param model_root_folder: the folder referring to the model used for computation;
        defaults to fiducial model (constant_abundance_p15_q05)
    :param best_bandwidths: kernel bandwidths to use for each ratio; defaults to a built-in table of five ratios
    :param ratio_limits: fixed plotting limits
    """
    _model_root_folder = validate_parameter(
        model_root_folder,
        default=os.path.join('prs', 'output', 'run_type', 'constant_abundance_p15_q05')
    )
    _pickle_models_dict_filename = validate_parameter(pickle_models_dict_filename, default='models_dict.pickle')
    fallback_bandwidths = {
        '87-86': 0.1,
        '88-86': 0.1,
        '88-87': 0.1,
        '257-256': 0.4,
        '381-380': 0.2,
    }
    _best_bandwidths = validate_parameter(best_bandwidths, default=fallback_bandwidths)

    # One task per line pair; every other positional argument of train_kde_model is shared
    tasks = [(line_pair, data, points_per_axis, _best_bandwidths, _model_root_folder, ratio_limits)
             for line_pair in line_pairs]
    with Pool(nthreads) as pool:
        trained = pool.starmap(train_kde_model, tasks)

    # Each task returns (ratio_string, model)
    models_dict = dict(trained)
    output_path = os.path.join(_model_root_folder, 'trained_model', _pickle_models_dict_filename)
    with open(output_path, 'wb') as outfile:
        pickle.dump(models_dict, outfile)


def plot_results(log_x_grid: np.array,
                 probability_density: np.array,
                 probability_density_threshold: float,
                 hpd_interval: List[float],
                 source_name: Union[str, None] = None,
                 model_root_folder: Union[None, str] = None):
    """
    Plot the results of the fitting, the posterior of the number density of H2, and save it as a PNG file in
    <model_root_folder>/figures.
    :param log_x_grid: the logarithmic grid in number density
    :param probability_density: the array with the probability density values
    :param probability_density_threshold: the threshold where the probability density is to be cut to get the HPD
        interval
    :param hpd_interval: the list of HPD interval boundaries; when empty, the x-axis limits are auto-scaled
    :param source_name: the source name string; None or 'default_source' are omitted from title and file name
    :param model_root_folder: the folder referring to the model used for computation;
        defaults to fiducial model (constant_abundance_p15_q05)
    """
    _model_root_folder = validate_parameter(
        model_root_folder,
        default=os.path.join('prs', 'output', 'run_type', 'constant_abundance_p15_q05')
    )
    if (source_name is not None) and (source_name != 'default_source'):
        _source_name_string = source_name
    else:
        _source_name_string = ''
    fig = plt.figure(figsize=(8, 6))
    try:
        plt.rcParams.update({'font.size': 16})
        plt.plot(10 ** log_x_grid, probability_density, marker='x')
        plt.axhline(y=probability_density_threshold, color='r', linestyle='--')
        # Zoom around the HPD interval when there is one; np.min/np.max would fail on an empty list
        if len(hpd_interval) > 0:
            plt.xlim(0.5 * np.min(np.array(hpd_interval)),
                     2 * np.max(np.array(hpd_interval)))
        plt.semilogx()
        plt.xlabel(r'Number density [cm$^{-3}$]')
        plt.ylabel(r'Probability density')
        plt.title('$n$(H$_{2}$)' + f' posterior for {_source_name_string}')
        plt.tight_layout()
        plt.savefig(os.path.join(_model_root_folder, 'figures', f'density_pdf{_source_name_string}.png'))
    finally:
        # Close (not just clear) the figure, otherwise one figure per source stays alive in pyplot
        plt.close(fig)


def normalize_probability_density(log_x_grid: np.array,
                                  probability_density: np.array) -> Tuple[np.array, np.array]:
    """
    Normalize the probability density so that its integral over the linear number-density grid (10 ** log_x_grid)
    is 1, and compute the corresponding empirical cumulative distribution.
    :param log_x_grid: the logarithmic grid in number density
    :param probability_density: the array with the probability density values; it is not modified
    :return: the normalized probability density function and the normalized cumulative integral (ECDF). The ECDF has
        one element fewer than the grid, and its last value is 1.
    """
    # cumulative_trapezoid replaces cumtrapz, which is deprecated in SciPy 1.12 and removed in 1.14
    from scipy.integrate import cumulative_trapezoid

    density = np.asarray(probability_density)
    ecdf = cumulative_trapezoid(density, 10 ** log_x_grid)
    renorm_constant = ecdf[-1]
    # New arrays are returned: the caller's array is left untouched (and integer input is supported)
    return density / renorm_constant, ecdf / renorm_constant


def compute_probability_density(x_grid: np.array,
                                measured_integrated_intensities: Tuple[float, float],
                                integrated_intensities_uncertainty: Tuple[float, float],
                                model: dict,
                                ratio_realizations: int = 1000,
                                best_bandwidth: Union[None, float] = None,
                                ratio_string: Union[None, str] = None) -> dict:
    """
    Compute the PDF from the model and the realizations of the ratio, given its measurement and the uncertainties.
    For now, the intensity measurements are assumed to have a Gaussian uncertainty, truncated to positive values.
    :param x_grid: the grid in log10 number density
    :param measured_integrated_intensities: the measured value of the integrated intensities (numerator, denominator)
    :param integrated_intensities_uncertainty: the rms uncertainty on the integrated intensity
    :param model: the pickled KDE evaluation for the ratio, with keys 'x', 'y' (scaled grid axes) and 'values'
    :param ratio_realizations: the number of realizations of the ratio, to compute the PDF
    :param best_bandwidth: the bandwidth used for the ratio axis when the KDE was trained; defaults to 0.2
    :param ratio_string: the ratio identifier used as key of the output; defaults to 'not_specified'
    :return: a dictionary mapping the ratio string to the PDF of the density, normalized over x_grid (NaN if the
        ratio realizations fall entirely outside the modelled grid)
    """
    _best_bandwidth = validate_parameter(best_bandwidth, 0.2)
    _ratio_string = validate_parameter(ratio_string, default='not_specified')
    # Draw the two intensities in order (numerator first) from normals truncated to [1e-5, mean + 10 sd]
    simulated_intensities = [
        get_truncated_normal(mean=measured_integrated_intensities[line_idx],
                             sd=integrated_intensities_uncertainty[line_idx],
                             lower_bound=1e-5,
                             upper_bound=measured_integrated_intensities[line_idx] + 10 *
                                         integrated_intensities_uncertainty[line_idx]).rvs(size=ratio_realizations)
        for line_idx in (0, 1)
    ]
    simulated_ratios = simulated_intensities[0] / simulated_intensities[1]
    interp_function = RegularGridInterpolator((model['x'], model['y']),
                                              model['values'].reshape(model['x'].shape[0], model['y'].shape[0]),
                                              bounds_error=False,
                                              fill_value=0)
    # Query every (density, ratio) pair in the scaled coordinates of the trained KDE (x bandwidth fixed at 0.2).
    # 'ij' indexing yields the same ordering as itertools.product(x_grid, ratios): the ratio varies fastest.
    density_mesh, ratio_mesh = np.meshgrid(x_grid / 0.2, simulated_ratios / _best_bandwidth, indexing='ij')
    z = interp_function(np.column_stack([density_mesh.ravel(), ratio_mesh.ravel()]))
    # Marginalize over the ratio realizations
    probability_density = np.nansum(z.reshape((len(x_grid), ratio_realizations)).T, axis=0)
    probability_density /= np.trapz(probability_density, x_grid)
    return {_ratio_string: probability_density}


def _collect_source_ratios(source_name: str,
                           ratio_list: list,
                           line_pairs: list,
                           measured_integrated_intensity_dict: dict,
                           integrated_intensity_uncertainty_dict: dict) -> Tuple[list, dict, dict]:
    """
    Select, for a single source, the requested ratios whose two lines both have a non-NaN measured intensity.
    :param source_name: the source to process
    :param ratio_list: the ratio strings requested for inference
    :param line_pairs: the line pairs available in the model, used to validate each ratio
    :param measured_integrated_intensity_dict: integrated intensities, indexed as [line_id][source_name]
    :param integrated_intensity_uncertainty_dict: uncertainties, indexed as [line_id][source_name]
    :return: the usable ratio strings, and two dictionaries mapping each of them to the [numerator, denominator]
        intensities and uncertainties, respectively
    :raises ValueError: if a requested ratio is not part of the model (raised by validate_line_ids)
    """
    available_ratios = []
    intensities = {}
    uncertainties = {}
    for ratio_string in ratio_list:
        line_ids = validate_line_ids(line_pairs, ratio_string)
        if np.isnan(measured_integrated_intensity_dict[line_ids[0]][source_name]) or \
                np.isnan(measured_integrated_intensity_dict[line_ids[1]][source_name]):
            logger.warning(f'The ratio {ratio_string} is not available for source {source_name}. Proceeding with the remaining ratios...')
            continue
        intensities[ratio_string] = [
            measured_integrated_intensity_dict[line_ids[0]][source_name],
            measured_integrated_intensity_dict[line_ids[1]][source_name]
        ]
        uncertainties[ratio_string] = [
            integrated_intensity_uncertainty_dict[line_ids[0]][source_name],
            integrated_intensity_uncertainty_dict[line_ids[1]][source_name]
        ]
        available_ratios.append(ratio_string)
    return available_ratios, intensities, uncertainties


def main(measured_integrated_intensity_dict: dict,
         integrated_intensity_uncertainty_dict: dict,
         ratio_list: list,
         probability_threshold: float,
         source_names: list,
         recompute_kde: bool = False,
         points_per_axis: int = 200,
         ratio_realizations: int = 1000,
         nthreads: int = 10,
         limit_rows: Union[None, int] = None,
         use_model_for_inference: Union[None, str] = None,
         best_bandwidths: Union[None, dict] = None,
         ratio_limits: Union[None, dict] = None):
    """
    Infer the posterior of the H2 volume density for each source, combining all its available line ratios.

    For every source, the per-ratio PDFs are multiplied, normalized and summarized (best fit and HPD interval). The
    posterior is plotted, and the results are written to volume_density_results.yml and posteriors.pickle in the
    model folder. Both files are rewritten after every source.
    :param measured_integrated_intensity_dict: integrated intensities, indexed as [line_id][source_name]
    :param integrated_intensity_uncertainty_dict: uncertainties, indexed as [line_id][source_name]
    :param ratio_list: the ratio strings to use for inference
    :param probability_threshold: the probability mass left out of the HPD interval
    :param source_names: the sources to process
    :param recompute_kde: if True, retrain and save the KDEs before inference
    :param points_per_axis: number of points of the density grid (and of the KDE grid when retraining)
    :param ratio_realizations: number of simulated ratio realizations per ratio
    :param nthreads: number of worker processes
    :param limit_rows: optional limit on the rows of post-processed data to load
    :param use_model_for_inference: the model folder name; defaults to constant_abundance_p15_q05
    :param best_bandwidths: bandwidth per ratio; must match those used to train the KDEs
    :param ratio_limits: fixed plotting limits per ratio, used when retraining
    """
    _use_model_for_inference = validate_parameter(
        use_model_for_inference,
        default='constant_abundance_p15_q05'
    )
    _model_root_folder = os.path.join('prs', 'output', 'run_type', _use_model_for_inference)
    data, line_pairs = get_inference_data(use_model_for_inference=_use_model_for_inference,
                                          limit_rows=limit_rows)
    if recompute_kde is True:
        recompute_and_save_kdes(data=data,
                                line_pairs=line_pairs,
                                nthreads=nthreads,
                                points_per_axis=points_per_axis,
                                model_root_folder=_model_root_folder,
                                best_bandwidths=best_bandwidths,
                                ratio_limits=ratio_limits)

    x_grid = np.linspace(np.log10(0.7 * np.nanmin(data['avg_nh2'])),
                         np.log10(1.3 * np.nanmax(data['avg_nh2'])), points_per_axis)

    results_dict = {}
    posteriors = {
        'PDF': {},
        'ECDF': {},
        'density_grid': {}
    }
    # A single worker pool is shared by all sources rather than spawning one per source
    with Pool(nthreads) as pool:
        for source_name in source_names:
            _ratio_list_per_source, measured_integrated_intensity_coupled, integrated_intensity_uncertainty_coupled = \
                _collect_source_ratios(source_name=source_name,
                                       ratio_list=ratio_list,
                                       line_pairs=line_pairs,
                                       measured_integrated_intensity_dict=measured_integrated_intensity_dict,
                                       integrated_intensity_uncertainty_dict=integrated_intensity_uncertainty_dict)

            if len(_ratio_list_per_source) == 0:
                # Without any ratio the product below degenerates to the scalar 1.0 and the normalization fails
                logger.warning(f'No ratio is available for source {source_name}. Skipping it...')
                results_dict[source_name] = {
                    'best_fit': np.nan,
                    'hpd_interval': [np.nan, np.nan]
                }
            else:
                # Compute combined probability density as the product of the individual ratio probability densities
                parallel_args = []
                for ratio_string in _ratio_list_per_source:
                    with open(
                            os.path.join(_model_root_folder, 'trained_model',
                                         f'ratio_density_kde_{ratio_string}.pickle'), 'rb'
                    ) as infile:
                        model_dict = pickle.load(infile)
                    # NOTE(review): best_bandwidths must be a dict here (None raises TypeError). Its values must match
                    # the bandwidths used in training, since the ratios are rescaled by them to query the KDE grid.
                    parallel_args.append([x_grid,
                                          measured_integrated_intensity_coupled[ratio_string],
                                          integrated_intensity_uncertainty_coupled[ratio_string],
                                          model_dict,
                                          ratio_realizations,
                                          best_bandwidths[ratio_string],
                                          ratio_string])
                results = reduce(operator.ior, pool.starmap(compute_probability_density, parallel_args), {})
                combined_probability_density = np.array(list(results.values())).prod(axis=0)
                if np.isnan(combined_probability_density.sum()):
                    logger.warning(f'Source {source_name} does not have ratios compatible with modelling with this grid')
                    problematic_ratios = [key for key in results.keys() if np.isnan(results[key].sum())]
                    logger.warning(f'The ratios outside the simulated boundaries are {problematic_ratios}')
                    results_dict[source_name] = {
                        'best_fit': np.nan,
                        'hpd_interval': [np.nan, np.nan]
                    }
                else:
                    combined_probability_density, combined_ecdf = normalize_probability_density(
                        x_grid, combined_probability_density)

                    probability_density_threshold, best_fit, hpd_interval = \
                        get_results(x_array=10 ** x_grid,
                                    probability_density=combined_probability_density,
                                    probability_threshold=probability_threshold)
                    plot_results(log_x_grid=x_grid,
                                 probability_density=combined_probability_density,
                                 probability_density_threshold=probability_density_threshold,
                                 hpd_interval=hpd_interval,
                                 source_name=source_name,
                                 model_root_folder=_model_root_folder)
                    logger.info(f'Processed source {source_name}')
                    logger.info(f'The best fit value is {best_fit}, with an HPD interval of {hpd_interval}')
                    logger.info(f'The threshold is {probability_density_threshold}')
                    results_dict[source_name] = {
                        'best_fit': float(best_fit),
                        'hpd_interval': [float(element) for element in hpd_interval]
                    }
                    posteriors['PDF'][source_name] = combined_probability_density
                    posteriors['ECDF'][source_name] = combined_ecdf
                    posteriors['density_grid'][source_name] = 10 ** x_grid
            # Outputs are rewritten after every source, so they always include all sources processed so far
            with open(os.path.join(_model_root_folder, 'volume_density_results.yml'), 'w') as outfile:
                yaml.dump(results_dict, outfile)
            with open(os.path.join(_model_root_folder, 'posteriors.pickle'), 'wb') as outfile:
                pickle.dump(posteriors, outfile)


def validate_line_ids(line_pairs, ratio_string):
    """
    Split a ratio string such as '87-86' into its line ids and check that this pair is part of the model.
    :param line_pairs: the line pairs available in the model, as lists of line id strings
    :param ratio_string: the ratio, as line ids joined by '-'
    :return: the list of line ids composing the ratio
    :raises ValueError: if the pair is not among line_pairs
    """
    requested_ids = ratio_string.split('-')
    if requested_ids in line_pairs:
        return requested_ids
    raise ValueError(f'Specified ratio {ratio_string} not present in the model ({line_pairs})')


def get_inference_data(use_model_for_inference: str, limit_rows: int):
    """
    Load the post-processed RT data, appending the ML-inferred data when available.
    :param use_model_for_inference: the model folder name under prs/output/run_type
    :param limit_rows: optional limit on the rows of post-processed data to load
    :return: the combined dataframe, with a 'source' column set to 'RT' or 'ML', and the line pairs of the model
    """
    line_pairs, data = get_postprocessed_data(limit_rows=limit_rows,
                                              use_model_for_inference=use_model_for_inference)
    data['source'] = 'RT'
    inferred_data_path = os.path.join('prs', 'output', 'run_type', use_model_for_inference,
                                      'data', 'inferred_data.csv')
    try:
        inferred_data = pd.read_csv(inferred_data_path, index_col=0)
    except IOError:
        logger.warning('Inferred data not found. Proceeding with RT-generated data only...')
    else:
        inferred_data['source'] = 'ML'
        data = pd.concat([data, inferred_data], axis=0, ignore_index=True)
    return data, line_pairs


if __name__ == '__main__':
    external_input = load_config_file(config_file_path='config/density_inference_input.yml')
    logger = setup_logger(name='PRS - DENSITY INFERENCE')

    def _optional_setting(key, default=None):
        """Return the configuration value stored under key, or default when the key is missing."""
        try:
            return external_input[key]
        except KeyError:
            return default

    limit_rows = _optional_setting('limit_rows')
    use_model_for_inference = _optional_setting('use_model_for_inference')
    if use_model_for_inference == 'PLACEHOLDER':
        logger.warning('No model specified for inference in density_inference_input.yml. Using fiducial.')
        use_model_for_inference = 'constant_abundance_p15_q05'
    points_per_axis = _optional_setting('points_per_axis', default=200)
    nthreads = _optional_setting('nthreads', default=1)

    measured_intensities = external_input['measured_integrated_intensities']
    try:
        # Per-source values: each line maps source names to intensities
        sources = set()
        for line_id in measured_intensities:
            sources = sources | set(measured_intensities[line_id].keys())
            _intensities_dict = measured_intensities
            _uncertainty_dict = external_input['integrated_intensities_uncertainties']
    except AttributeError:
        # Scalar values per line: treat them as a single, unnamed source
        sources = ['default_source']
        _intensities_dict = {line_id: {'default_source': measured_intensities[line_id]}
                             for line_id in measured_intensities}
        uncertainties = external_input['integrated_intensities_uncertainties']
        _uncertainty_dict = {line_id: {'default_source': uncertainties[line_id]} for line_id in uncertainties}
    sources = sorted(sources)

    logger.info(f'Using {use_model_for_inference} for inference')
    main(measured_integrated_intensity_dict=_intensities_dict,
         integrated_intensity_uncertainty_dict=_uncertainty_dict,
         source_names=sources,
         ratio_list=external_input['ratios_to_include'],
         probability_threshold=external_input['probability_threshold'],
         recompute_kde=external_input['recompute_kde'],
         points_per_axis=points_per_axis,
         ratio_realizations=external_input['simulated_ratio_realizations'],
         nthreads=nthreads,
         limit_rows=limit_rows,
         use_model_for_inference=use_model_for_inference,
         best_bandwidths=external_input['best_bandwidths'],
         ratio_limits=external_input['ratio_limits'])
