From db45b2ab3f10082838d973973992fdad48051604 Mon Sep 17 00:00:00 2001
From: "andrea.giannetti" <andrea.giannetti@inaf.it>
Date: Fri, 13 Dec 2024 09:57:39 +0100
Subject: [PATCH] Latest version update

---
 etl/assets/commons/__init__.py                |  10 ++
 etl/assets/data/radex_data.csv                |  73 ++++++++
 etl/prs/prs_analytical_representations.py     |  70 +++++---
 etl/prs/prs_density_inference.py              | 164 ++++++++++++++++--
 etl/prs/prs_make_comparison_figures.py        | 150 +++++++++-------
 .../prs_make_posterior_improvement_figure.py  |  48 +++++
 6 files changed, 411 insertions(+), 104 deletions(-)
 create mode 100644 etl/assets/data/radex_data.csv
 create mode 100644 etl/prs/prs_make_posterior_improvement_figure.py

diff --git a/etl/assets/commons/__init__.py b/etl/assets/commons/__init__.py
index 3e5f5a6..f803550 100644
--- a/etl/assets/commons/__init__.py
+++ b/etl/assets/commons/__init__.py
@@ -83,6 +83,16 @@ def load_config_file(config_file_path: str,
     return config
 
 
+def moving_average(x: np.ndarray, window_size: int) -> np.ndarray:
+    """Compute the moving average of a 1D array over complete windows only (len(x) - window + 1 values).
+    :param x: the 1D array to compute the moving average on
+    :param window_size: the size of the moving average window; capped at len(x)
+    :return: the moving average of the array
+    """
+    window = min(window_size, len(x))  # np.convolve swaps its operands when the kernel is longer than the data
+    return np.convolve(x, np.ones(window), 'valid') / window
+
+
 def get_moldata(species_names: list,
                 logger: logging.Logger,
                 path: Union[str, None] = None,
diff --git a/etl/assets/data/radex_data.csv b/etl/assets/data/radex_data.csv
new file mode 100644
index 0000000..c0bc3e5
--- /dev/null
+++ b/etl/assets/data/radex_data.csv
@@ -0,0 +1,73 @@
+rot_trans,log_col_dens,log_density,temperature,line1,line2,line3
+2-1,15.4,3,10,0.78,0.17,0.01
+2-1,15.4,3,30,1.8,0.35,0.018
+2-1,15.4,4,10,3.37,0.78,0.04
+2-1,15.4,4,30,7.96,1.28,0.07
+2-1,15.4,5,10,5.64,2.43,0.34
+2-1,15.4,5,30,12.5,2.98,0.46
+2-1,15.4,6,10,6.11,4.11,1.41
+2-1,15.4,6,30,8.49,3.77,1.34
+2-1,15.4,7,10,5.96,4.62,2.12
+2-1,15.4,7,30,5.68,3.63,2.06
+2-1,15.4,8,10,5.84,4.73,2.30
+2-1,15.4,8,30,4.13,3.63,2.24
+2-1,14,3,10,0.069,0.0042,0.0001
+2-1,14,3,30,0.16,0.0091,0.00026
+2-1,14,4,10,0.47,0.028,0.00083
+2-1,14,4,30,0.85,0.046,0.0019
+2-1,14,5,10,0.91,0.085,0.0076
+2-1,14,5,30,0.86,0.09,0.014
+2-1,14,6,10,0.75,0.20,0.05
+2-1,14,6,30,0.46,0.14,0.05
+2-1,14,7,10,0.59,0.28,0.09
+2-1,14,7,30,0.28,0.15,0.08
+2-1,14,8,10,0.52,0.31,0.1
+2-1,14,8,30,0.19,0.15,0.09
+5-4,15.4,3,10,1.1e-3,2.3e-2,
+5-4,15.4,3,30,1.3e-2,0.26,
+5-4,15.4,4,10,1.7e-2,0.32,
+5-4,15.4,4,30,0.22,2.33,
+5-4,15.4,5,10,0.25,2.14,
+5-4,15.4,5,30,2.72,10.44,
+5-4,15.4,6,10,1.43,3.85,
+5-4,15.4,6,30,6.79,16.5,
+5-4,15.4,7,10,2.29,3.97,
+5-4,15.4,7,30,7.76,13.87,
+5-4,15.4,8,10,2.50,3.88,
+5-4,15.4,8,30,8.07,11.0,
+5-4,14,3,10,3.6e-5,8.2e-4,
+5-4,14,3,30,4.5e-4,1.0e-2,
+5-4,14,4,10,4.1e-4,0.010,
+5-4,14,4,30,4.7e-3,0.12,
+5-4,14,5,10,5.4e-3,0.13,
+5-4,14,5,30,5.2e-2,0.92,
+5-4,14,6,10,4.8e-2,0.32,
+5-4,14,6,30,0.24,1.30,
+5-4,14,7,10,0.11,0.31,
+5-4,14,7,30,0.36,0.87,
+5-4,14,8,10,0.13,0.29,
+5-4,14,8,30,0.39,0.60,
+7-6,15.4,3,10,1.4e-5,3.3e-4,
+7-6,15.4,3,30,1.4e-3,2.6e-2,
+7-6,15.4,4,10,2.0e-4,5.7e-3,
+7-6,15.4,4,30,2.0e-2,3.8e-1,
+7-6,15.4,5,10,4.8e-3,0.1,
+7-6,15.4,5,30,0.47,4.1,
+7-6,15.4,6,10,8.6e-2,0.52,
+7-6,15.4,6,30,3.54,11.37,
+7-6,15.4,7,10,0.26,0.67,
+7-6,15.4,7,30,5.46,10.82,
+7-6,15.4,8,10,0.32,0.66,
+7-6,15.4,8,30,5.97,8.53,
+7-6,14,3,10,5.3e-7,1.1e-5,
+7-6,14,3,30,5.4e-5,8.6e-4,
+7-6,14,4,10,6.2e-6,1.6e-4,
+7-6,14,4,30,6.2e-4,1.2e-2,
+7-6,14,5,10,1.2e-4,3.3e-3,
+7-6,14,5,30,1.1e-2,0.19,
+7-6,14,6,10,2.5e-3,2.2e-2,
+7-6,14,6,30,0.11,0.74,
+7-6,14,7,10,9.7e-3,3.0e-2,
+7-6,14,7,30,0.24,0.62,
+7-6,14,8,10,1.3e-2,2.9e-2,
+7-6,14,8,30,0.27,0.44,
\ No newline at end of file
diff --git a/etl/prs/prs_analytical_representations.py b/etl/prs/prs_analytical_representations.py
index 09ce4d9..bec3f56 100644
--- a/etl/prs/prs_analytical_representations.py
+++ b/etl/prs/prs_analytical_representations.py
@@ -2,17 +2,19 @@ import numpy as np
 import os
 import matplotlib.pyplot as plt
 import pickle
-from scipy.interpolate import splrep, BSpline, RegularGridInterpolator
 from scipy.optimize import curve_fit
 from assets.commons import (get_data,
                             load_config_file,
                             setup_logger,
                             validate_parameter)
+import yaml
 from typing import Union, List
 
 
-def approx(x, x_shift, a, b, exponent, scale=1):
-    return (np.abs(np.arctanh(2 * (x - x_shift) / scale) + a)) ** exponent + b
+def approx(x, x_shift, b, exponent, scale):
+    # Shifted/scaled arctanh model; the argument is kept strictly inside (-1, 1) so that arctanh stays finite
+    bounded = np.clip(2 * (x - x_shift) / scale, -0.99999, 0.99999)
+    return (np.arctanh(bounded) + b) ** exponent
 
 
 def main(ratios_to_fit: Union[List[str], None] = None):
@@ -25,19 +27,12 @@ def main(ratios_to_fit: Union[List[str], None] = None):
     logger.info(f'Using {config["run_type"]} to fit data...')
     logger.info(f'Data contains {data.shape[0]} rows')
 
-    bw = {
-        '87-86': 0.1,
-        '88-87': 0.1,
-        '88-86': 0.1,
-        '257-256': 0.4,
-        '381-380': 0.2}
-
     x0 = {
-        '87-86': ((0.55, 2.3, 3.75, 0.9),),
-        '88-87': ((0.59, 2.3, 3.9, 1.05, 0.96),),
-        '88-86': ((0.55, 2.5, 4.3, 0.83, 1.08),),
-        '257-256': ((5, 2.5, 2.45, 0.75, 8), (5, 2.5, 3.33, 0.75, -8)),
-        '381-380': ((2.04, 2.5, 3.1, 0.6, 2.1), (2.18, 2.5, 3.7, 0.79, -2.4)),
+        '87-86': ((0.55, 6.75, 0.9, 1),),
+        '88-87': ((0.59, 6.9, 1.05, 0.96),),
+        '88-86': ((0.55, 6.3, 0.83, 1.08),),
+        '257-256': ((5, 4.45, 0.75, 8), (4.5, 8.5, 0.79, -7)),
+        '381-380': ((2.04, 5.1, 0.6, 2.1), (2.18, 5.7, 0.79, -2.4)),
     }
     min_ratio = {
         '87-86': 0.08,
@@ -51,31 +46,52 @@ def main(ratios_to_fit: Union[List[str], None] = None):
         '88-86': 1.07,
         '257-256': 7.5,
         '381-380': 2.9}
+    density_limits = {
+        '87-86': ((5e4, 2e7),),
+        '88-87': ((1e5, 3e7),),
+        '88-86': ((1e5, 3e7),),
+        '257-256': ((2e3, 7.5e4), (7.5e4, 4e6)),
+        '381-380': ((1e4, 1.9e5), (1.9e5, 1e7))
+    }
+
+    best_fit_params = {}
 
     for ratio_string in _ratio_to_fit:
         _column_to_fit = f'ratio_{ratio_string}'
         data.sort_values(by=_column_to_fit, inplace=True)
         plt.clf()
-        plt.scatter(data['avg_nh2'], data[_column_to_fit], marker='+', alpha=0.1)
-        x_reg = np.linspace(min_ratio[ratio_string], max_ratio[ratio_string], 100)
+        x_reg = np.linspace(min_ratio[ratio_string], max_ratio[ratio_string], 1000)
+
+        kde_max = pickle.load(
+            open(f'prs/output/run_type/{config["run_type"]}/kde_smoothed_{ratio_string}_max_locus.pickle', 'rb'))
 
-        data_clean = data[[_column_to_fit, 'avg_nh2']].loc[data[_column_to_fit] <= max_ratio[ratio_string]].copy()
-        kde = pickle.load(
-            open(f'prs/output/run_type/{config["run_type"]}/trained_model/ratio_density_kde_{ratio_string}.pickle',
-                 'rb'))
-        rgi = RegularGridInterpolator((kde['y'] * bw[ratio_string], kde['x'] * 0.2), kde['values'].reshape(200, 200).T)
-        density = rgi(data_clean[[_column_to_fit, 'avg_nh2']])
-        data_clean = data_clean.loc[(density > density.mean() - density.std())]
-        plt.scatter(data_clean['avg_nh2'], data_clean[_column_to_fit], marker='+', alpha=0.1, color='red')
+        plt.plot(kde_max[1], kde_max[0], color='red')
+        plt.semilogy()
         plt.xlabel('log(<n(H2)>)')
         plt.ylabel('Ratio')
-        for components in x0[ratio_string]:
-            plt.plot(approx(x_reg, *components),x_reg , color='cyan')
+        best_fit_params[ratio_string] = []
+        for guesses, limits in zip(x0[ratio_string], density_limits[ratio_string]):
+            clean_mask = (~np.isnan(kde_max[0]) & ~np.isnan(kde_max[1]) &
+                          (kde_max[0] >= limits[0])
+                          & (kde_max[0] <= limits[1]))
+            best_fit, par_cov = curve_fit(approx,
+                                          kde_max[1][clean_mask],
+                                          np.log10(kde_max[0][clean_mask]),
+                                          p0=guesses,
+                                          bounds=((0.1, 2, 0.3, -9), (7, 10, 1.5, 9)))
+            approx_density = 10 ** approx(x_reg, *best_fit)
+            print(best_fit)
+            best_fit_params[ratio_string].append({param: float(element) for element, param in zip(best_fit, ['a', 'b', 'nu', 's'])})
+            approx_density = np.where((approx_density >= limits[0]) & (approx_density <= limits[1]), approx_density, np.nan)
+            plt.plot(x_reg, approx_density)
         plt.savefig(os.path.join('..',
         'publications',
         '6373bb408e4040043398e495',
         'referee',
         f'analytical_expressions_comparison_{ratio_string}.png'))
+        with (open(os.path.join('prs', 'output', 'run_type', 'constant_abundance_p15_q05', 'best_fit_params.yml'), 'w')
+              as outfile):
+            yaml.dump(best_fit_params, outfile)
 
 
 if __name__ == '__main__':
diff --git a/etl/prs/prs_density_inference.py b/etl/prs/prs_density_inference.py
index d464d61..24b6f7e 100644
--- a/etl/prs/prs_density_inference.py
+++ b/etl/prs/prs_density_inference.py
@@ -11,13 +11,14 @@ from sklearn.model_selection import GridSearchCV
 from assets.commons import (load_config_file,
                             validate_parameter,
                             setup_logger,
-                            get_postprocessed_data)
+                            get_postprocessed_data,
+                            moving_average)
 from assets.constants import line_ratio_mapping
 from multiprocessing import Pool
 from typing import Tuple, Union, List
 from scipy.integrate import cumtrapz
 from scipy.stats import truncnorm
-from scipy.interpolate import RegularGridInterpolator
+from scipy.interpolate import RegularGridInterpolator, InterpolatedUnivariateSpline
 from functools import reduce
 
 
@@ -35,6 +36,7 @@ def train_kde_model(ratio: List[str],
                     best_bandwidth: Union[None, dict],
                     model_root_folder: Union[str, None] = None,
                     ratio_limits: Union[None, dict] = None,
+                    plot_radex_data: Union[None, bool] = False,
                     rt_adjustment_factor: Union[int, float] = 2) -> Tuple[str, KernelDensity]:
     """
     Create the KDE from the modelled datapoints.
@@ -48,6 +50,7 @@ def train_kde_model(ratio: List[str],
     :param rt_adjustment_factor: a scaling factor to apply to the specified bandwidth for the computation of the KDE
         for the sparser RT-only data
     :param ratio_limits: fixed plotting limits
+    :param plot_radex_data: whether to plot the radex data
     :return: the string representing the ratio modelled and the KDE itself
     """
     _model_root_folder = validate_parameter(
@@ -62,11 +65,13 @@ def train_kde_model(ratio: List[str],
     grid = scaled_grid.copy()
     grid[0] = scaled_grid[0] * 0.2
     grid[1] = scaled_grid[1] * best_bandwidth[ratio_string]
+    radex_data = get_radex_data() if plot_radex_data else None
     plot_kde_ratio_nh2(grid=grid,
                        values_on_grid=z.reshape(points_per_axis, points_per_axis),
                        ratio_string=ratio_string,
                        model_root_folder=_model_root_folder,
                        training_data=training_data,
+                       additional_data=radex_data,
                        suffix_outfile='_ml',
                        ratio_limits=ratio_limits[ratio_string])
 
@@ -83,6 +88,7 @@ def train_kde_model(ratio: List[str],
                        ratio_string=ratio_string,
                        model_root_folder=_model_root_folder,
                        training_data=training_data[training_data['source'] == 'RT'],
+                       additional_data=radex_data,
                        ratio_limits=ratio_limits[ratio_string])
     with open(
             os.path.join(_model_root_folder, 'trained_model', f'ratio_density_kde_{ratio_string}.pickle'), 'wb'
@@ -140,6 +146,7 @@ def plot_kde_ratio_nh2(grid: np.array,
                        ratio_string: str,
                        model_root_folder: str,
                        training_data: pd.DataFrame,
+                       additional_data: pd.DataFrame = None,
                        suffix_outfile: str = None,
                        ratio_limits: Union[None, list] = None):
     """
@@ -153,24 +160,152 @@ def plot_kde_ratio_nh2(grid: np.array,
         :param training_data: The DataFrame containing the training data.
         :param suffix_outfile: Optional. The suffix to append to the output file name. Defaults to an empty string if None.
         :param ratio_limits: Optional. The limits for the ratio axis. Defaults to None, which auto-scales the axis.
+        :param additional_data: Optional. The DataFrame containing the additional data to plot.
         :return: None. Saves the plot as a PNG file in the specified folder.
     """
     plt.rcParams.update({'font.size': 20})
     _suffix_outfile = validate_parameter(suffix_outfile, default='')
     plt.clf()
     plt.figure(figsize=(8, 6))
-    plt.scatter(training_data['avg_nh2'], training_data[f'ratio_{ratio_string}'], marker='+', alpha=0.1,
-                facecolor='grey')
-    plt.contour(10 ** grid[0], grid[1], values_on_grid, levels=np.arange(0.05, 0.95, 0.15))
+    smoothed_max_locus_density, smoothed_max_locus_ratio = plot_kde_results(grid=grid,
+                                                                            ratio_string=ratio_string,
+                                                                            training_data=training_data,
+                                                                            values_on_grid=values_on_grid)
+    if (additional_data is not None) and (ratio_string in additional_data.columns):
+        plt.scatter(10 ** additional_data['log_density'], additional_data[ratio_string], marker='*', alpha=1,
+                    facecolor='red', s=100)
     plt.semilogx()
     plt.xlabel(r'<$n$(H$_2$)> [cm$^{-3}$]')
     plt.ylabel(f'Ratio {line_ratio_mapping[ratio_string]}')
     plt.ylim(ratio_limits)
     plt.tight_layout()
+    plt.legend(loc='best')
     plt.savefig(os.path.join(
         model_root_folder,
         'figures',
         f'ratio_vs_avg_density_los_kde_{ratio_string}{_suffix_outfile}.png'))
+    with open(os.path.join(model_root_folder, f'kde_smoothed_{ratio_string}_max_locus.pickle'), 'wb') as outfile:
+        pickle.dump([smoothed_max_locus_density, smoothed_max_locus_ratio], outfile)
+
+
+def plot_kde_results(grid: np.array,
+                     ratio_string: str,
+                     training_data: pd.DataFrame,
+                     values_on_grid: np.array,
+                     plot_colours: Union[None, Tuple[str, str]] = None,
+                     plot_training_data: bool = True,
+                     colour_scatter: Union[None, str] = None) -> Tuple[np.array, np.array]:
+    """
+    Draw the smoothed 90%/50% HPD bands and the smoothed locus of maximum KDE with the pyplot interface.
+    :return: the smoothed max-locus densities and ratios, set to NaN outside the range covered by the last HPD band
+    """
+    window = 20  # moving-average window, in grid points, shared by every smoothed curve
+    band_colours = validate_parameter(plot_colours, default=('gold', 'cornflowerblue'))
+    scatter_colour = validate_parameter(colour_scatter, default='grey')
+    if plot_training_data:
+        plt.scatter(training_data['avg_nh2'], training_data[f'ratio_{ratio_string}'], marker='+', alpha=0.1,
+                    facecolor=scatter_colour)
+    # Divide every row by its own maximum, so that each row peaks at 1
+    peak_normalised = values_on_grid / np.max(values_on_grid, axis=1, keepdims=True)
+    for hpd_threshold, colour, label in zip((0.9, 0.5), band_colours, ('90% HPD', '50% HPD')):
+        densities, limits_hpd = get_hpd_relation(grid, peak_normalised, hpd_threshold=hpd_threshold)
+        band_density = 10 ** moving_average(densities, window)
+        band_lower, band_upper = (moving_average(limits_hpd[:, side], window) for side in (0, 1))
+        plt.fill_between(band_density, band_lower, band_upper, alpha=0.3, facecolor=colour, label=label)
+        for band_edge in (band_lower, band_upper):
+            plt.plot(band_density, band_edge, color=colour)
+    locus_density, locus_ratio = (moving_average(values, window) for values in get_max_locus(grid, values_on_grid))
+    locus_density = 10 ** locus_density
+    # Keep the locus only within the density range spanned by the last (50%) band
+    inside_band = (locus_density >= band_density[0]) & (locus_density <= band_density[-1])
+    locus_density = np.where(inside_band, locus_density, np.nan)
+    locus_ratio = np.where(inside_band, locus_ratio, np.nan)
+    plt.plot(locus_density, locus_ratio, label='Max PDF', color=band_colours[1])
+    return locus_density, locus_ratio
+
+
+def get_max_locus(grid, values_on_grid):
+    """
+    Locate, for each row of the KDE grid (one row per density value), the ratio at which the KDE peaks.
+
+    Parameters
+    ----------
+    grid : tuple
+        The (density, ratio) coordinate arrays of the KDE grid; densities vary along axis 0, ratios along axis 1.
+    values_on_grid : array_like
+        The KDE values on the grid, with the same 2D layout as the coordinate arrays.
+
+    Returns
+    -------
+    max_density : array_like
+        The density coordinate of each row (as stored in the grid), or NaN where the peak ratio is not positive.
+    max_ratio : array_like
+        The ratio at which each row peaks, or NaN where that ratio is not positive.
+    """
+    ratio_axis = grid[1][0, :]
+    density_axis = grid[0][:, 0]
+    peak_ratio = ratio_axis[np.argmax(values_on_grid, axis=1)]
+    # A non-positive peak ratio marks a missing solution: blank both coordinates of that row
+    has_solution = peak_ratio > 0
+    max_ratio = np.where(has_solution, peak_ratio, np.nan)
+    max_density = np.where(has_solution, density_axis, np.nan)
+    return max_density, max_ratio
+
+
+def get_hpd_relation(grid: Union[tuple, list], normalized_kde_values: np.array, hpd_threshold: float) -> tuple:
+    """
+    Find, for each density, the ratio interval over which the normalized KDE stays above (1 - hpd_threshold).
+
+    Parameters
+    ----------
+    grid : tuple
+        A tuple containing two arrays: the density grid and the ratio grid.
+    normalized_kde_values : np.array
+        The normalized KDE values on the grid. Normalization is performed by dividing by the maximum value for each density.
+    hpd_threshold : float
+        Defines the crossing level 1 - hpd_threshold, i.e. a height relative to the peak (not an integrated mass).
+    Returns
+    -------
+    densities : np.array
+        Array of the density values with at least one crossing; rows without crossings or with NaN/inf are skipped.
+    limits_hpd : np.array
+        Array of shape (n, 2) holding the lower and upper bounds of each interval; (0, 2) if no row has a crossing.
+    """
+    ratio_axis = grid[1][0, :]
+    hpd_limits, density_values = [], []
+    for index, density in enumerate(grid[0][:, 0]):
+        excess = normalized_kde_values[index, :] - (1 - hpd_threshold)
+        # The spline needs finite input: a row containing NaN/inf (e.g. from a 0/0 normalization) gets no roots
+        roots = InterpolatedUnivariateSpline(ratio_axis, excess).roots() if np.all(np.isfinite(excess)) else []
+        if len(roots) == 0:
+            continue
+        # With a single crossing the interval is closed by the grid edge on the side where the KDE is above the level
+        if len(roots) > 1:
+            hpd_limits.append((roots[0], roots[-1]))
+        elif excess[0] > 0:
+            hpd_limits.append((np.min(ratio_axis), roots[0]))
+        else:
+            hpd_limits.append((roots[0], np.max(ratio_axis)))
+        density_values.append(density)
+    return np.array(density_values), np.array(hpd_limits, dtype=float).reshape(-1, 2)
+
+
+def get_radex_data():
+    """
+    Read assets/data/radex_data.csv and add one column per line ratio, named after the ratio string.
+    :return: the DataFrame read from the CSV, with the ratio columns blanked for the listed rot_trans values
+    """
+    df_radex = pd.read_csv(os.path.join('assets', 'data', 'radex_data.csv'))
+    # ratio string -> (numerator column, denominator column, rot_trans values for which the ratio is blanked)
+    ratio_definitions = {'87-86': ('line2', 'line1', ('5-4', '7-6')),
+                         '88-87': ('line3', 'line2', ()),
+                         '88-86': ('line3', 'line1', ('5-4', '7-6')),
+                         '257-256': ('line2', 'line1', ('2-1', '7-6')),
+                         '381-380': ('line2', 'line1', ('2-1', '5-4'))}
+    for ratio_string, (numerator, denominator, blanked) in ratio_definitions.items():
+        ratio = df_radex[numerator] / df_radex[denominator]
+        df_radex[ratio_string] = ratio.mask(df_radex['rot_trans'].isin(blanked))
+    return df_radex
 
 
 def get_results(x_array: np.array,
@@ -323,7 +458,8 @@ def recompute_and_save_kdes(data: pd.DataFrame,
                             pickle_models_dict_filename: Union[None, str] = None,
                             model_root_folder: Union[None, str] = None,
                             best_bandwidths: Union[None, dict] = None,
-                            ratio_limits: Union[None, dict] = None):
+                            ratio_limits: Union[None, dict] = None,
+                            plot_radex_data: Union[None, bool] = False) -> Tuple[dict, KernelDensity]:
     """
     Retrieve the dictionary of the KDE models, either by computing it or unpickling it from previous runs
     :param points_per_axis: number of points for the KDE grid evaluation
@@ -335,6 +471,8 @@ def recompute_and_save_kdes(data: pd.DataFrame,
         defaults to fiducial model (constant_abundance_p15_q05)
     :param best_bandwidths: kernel bandwidths to use for each ratio
     :param ratio_limits: fixed plotting limits
+    :param plot_radex_data: whether to plot the radex data
+    :return: the dictionary of the KDE models and the kernel density model
     """
     _model_root_folder = validate_parameter(
         model_root_folder,
@@ -352,7 +490,7 @@ def recompute_and_save_kdes(data: pd.DataFrame,
     }
     _best_bandwidths = validate_parameter(best_bandwidths, default=default_bandwidths)
     parallel_args = product(line_pairs, [data], [points_per_axis],
-                            [_best_bandwidths], [_model_root_folder], [ratio_limits])
+                            [_best_bandwidths], [_model_root_folder], [ratio_limits], [plot_radex_data])
     with Pool(nthreads) as pool:
         results = pool.starmap(train_kde_model, parallel_args)
 
@@ -474,7 +612,8 @@ def main(measured_integrated_intensity_dict: dict,
          limit_rows: Union[None, int] = None,
          use_model_for_inference: Union[None, str] = None,
          best_bandwidths: Union[None, dict] = None,
-         ratio_limits: Union[None, dict] = None):
+         ratio_limits: Union[None, dict] = None,
+         plot_radex_data: Union[None, bool] = True) -> None:
     _use_model_for_inference = validate_parameter(
         use_model_for_inference,
         default='constant_abundance_p15_q05'
@@ -491,7 +630,8 @@ def main(measured_integrated_intensity_dict: dict,
                                 points_per_axis=points_per_axis,
                                 model_root_folder=_model_root_folder,
                                 best_bandwidths=best_bandwidths,
-                                ratio_limits=ratio_limits)
+                                ratio_limits=ratio_limits,
+                                plot_radex_data=plot_radex_data)
 
     x_grid = np.linspace(np.log10(0.7 * np.nanmin(data['avg_nh2'])),
                          np.log10(1.3 * np.nanmax(data['avg_nh2'])), points_per_axis)
@@ -508,7 +648,8 @@ def main(measured_integrated_intensity_dict: dict,
             line_ids = validate_line_ids(line_pairs, ratio_string)
             if np.isnan(measured_integrated_intensity_dict[line_ids[0]][source_name]) or \
                     np.isnan(measured_integrated_intensity_dict[line_ids[1]][source_name]):
-                logger.warning(f'The ratio {ratio_string} is not available for source {source_name}. Proceeding with the remaining ratios...')
+                logger.warning(
+                    f'The ratio {ratio_string} is not available for source {source_name}. Proceeding with the remaining ratios...')
             else:
                 measured_integrated_intensity_coupled[ratio_string] = [
                     measured_integrated_intensity_dict[line_ids[0]][source_name],
@@ -650,4 +791,5 @@ if __name__ == '__main__':
          limit_rows=limit_rows,
          use_model_for_inference=use_model_for_inference,
          best_bandwidths=external_input['best_bandwidths'],
-         ratio_limits=external_input['ratio_limits'])
+         ratio_limits=external_input['ratio_limits'],
+         plot_radex_data=True)
diff --git a/etl/prs/prs_make_comparison_figures.py b/etl/prs/prs_make_comparison_figures.py
index 9732327..66861e8 100644
--- a/etl/prs/prs_make_comparison_figures.py
+++ b/etl/prs/prs_make_comparison_figures.py
@@ -5,64 +5,76 @@ import matplotlib.pyplot as plt
 import numpy as np
 from assets.commons import (load_config_file,
                             validate_parameter,
-                            setup_logger,
-                            get_postprocessed_data)
+                            setup_logger)
 from assets.constants import line_ratio_mapping
-from prs.prs_density_inference import get_inference_data
+from prs.prs_density_inference import (get_inference_data,
+                                       plot_kde_results)
 from typing import Tuple, Union, List
 
-
 filename_root_map = {
-    'isothermal_p15': 'comparison_isothermal_fiducial',
-    # 'constant_abundance_p15_q05_x01': 'comparison_lowabundance_fiducial',
-    # 'constant_abundance_p15_q05_x10': 'comparison_highabundance_fiducial',
-    'hot_core_p15_q05': 'comparison_hotcore_fiducial',
+    'abundance_comparison': ['constant_abundance_p15_q05_x01', 'constant_abundance_p15_q05_x10'],
+    'double_microturbulence': ['double_microturbulence', ],
+    'density_distribution': ['constant_abundance_p12_q05', 'constant_abundance_p18_q05']
 }
 
 
-def plot_kde_ratio_nh2(grid: np.array,
-                       values_on_grid: np.array,
-                       comparison_grid: np.array,
-                       comparison_values_on_grid: np.array,
-                       ratio_string: str,
-                       data: pd.DataFrame,
-                       comparison_data: pd.DataFrame,
-                       root_outfile: str = None,
-                       ratio_limits: Union[None, list] = None):
+def plot_kde_ratio_nh2_comparison(grid: List[np.array],
+                                  values_on_grid: List[np.array],
+                                  points_per_axis: int,
+                                  ratio_string: str,
+                                  root_outfile: str,
+                                  training_data: List[pd.DataFrame],
+                                  additional_data: pd.DataFrame = None,
+                                  suffix_outfile: str = None,
+                                  ratio_limits: Union[None, list] = None,
+                                  plot_training_data: Union[int, bool] = False):
     """
         Plot the Kernel Density Estimate (KDE) of a ratio against average H2 density along the line-of-sight and save
          the plot as a PNG file.
 
         :param grid: The grid of x and y values used for the KDE.
-        :param comparison_grid: The grid of x and y values used for the KDE of the comparison model.
         :param values_on_grid: The computed KDE values on the grid.
-        :param comparison_values_on_grid: The computed KDE values on the comparison model grid.
+        :param points_per_axis: The number of points in each axis for the KDE grid.
         :param ratio_string: The ratio string indicating which ratio of the training data to plot.
-        :param data: The DataFrame containing the data for the scatterplot.
-        :param data: The DataFrame containing the comparison model data for the scatterplot.
         :param root_outfile: The root of the filename used to save the figures.
+        :param training_data: The DataFrame containing the training data.
+        :param suffix_outfile: Optional. The suffix to append to the output file name. Defaults to an empty string if None.
         :param ratio_limits: Optional. The limits for the ratio axis. Defaults to None, which auto-scales the axis.
+        :param additional_data: Optional. The DataFrame containing the additional data to plot.
+        :param plot_training_data: Optional. Whether to plot the training data; can be the index of the training data matrix to plot.
         :return: None. Saves the plot as a PNG file in the specified folder.
     """
     plt.rcParams.update({'font.size': 20})
+    _suffix_outfile = validate_parameter(suffix_outfile, default='')
     plt.clf()
     plt.figure(figsize=(8, 6))
-    plt.scatter(data['avg_nh2'], data[f'ratio_{ratio_string}'], marker='+', alpha=0.1,
-                facecolor='grey')
-    plt.scatter(comparison_data['avg_nh2'], comparison_data[f'ratio_{ratio_string}'], marker='x', alpha=0.01,
-                facecolor='green')
-
-    plt.contour(10 ** grid[0], grid[1], values_on_grid, levels=np.arange(0.05, 0.95, 0.15),
-                colors='black')
-    plt.contour(10 ** comparison_grid[0], comparison_grid[1],
-                comparison_values_on_grid,
-                levels=np.arange(0.05, 0.95, 0.15),
-                colors='lightgreen')
+    if (additional_data is not None) and (ratio_string in additional_data.columns):
+        plt.scatter(10 ** additional_data['log_density'], additional_data[ratio_string], marker='*', alpha=1,
+                    facecolor='red', s=100)
+
+    colours = (('yellow', 'gold'), ('lightgreen', 'green'), ('lightblue', 'blue'))
+    for individual_grid, individual_values_on_grid, individual_training_data, plot_cols, idx \
+            in zip(grid, values_on_grid, training_data, colours, range(3)):
+        if plot_training_data is True:
+            _plot_training_data = True
+        elif (plot_training_data is not False):
+            _plot_training_data = True if idx == plot_training_data else False
+        else:
+            _plot_training_data = False
+        plot_kde_results(grid=individual_grid,
+                         values_on_grid=individual_values_on_grid['values'].reshape(points_per_axis,
+                                                                                    points_per_axis),
+                         training_data=individual_training_data,
+                         ratio_string=ratio_string,
+                         plot_colours=plot_cols,
+                         plot_training_data=_plot_training_data,
+                         colour_scatter='grey')
     plt.semilogx()
     plt.xlabel(r'<$n$(H$_2$)> [cm$^{-3}$]')
     plt.ylabel(f'Ratio {line_ratio_mapping[ratio_string]}')
     plt.ylim(ratio_limits)
     plt.tight_layout()
+    # plt.legend(loc='best')
     plt.savefig(os.path.join(
         'prs',
         'output',
@@ -71,42 +83,45 @@ def plot_kde_ratio_nh2(grid: np.array,
 
 
 def main(ratio_list: list,
-         comparison_model: str,
+         comparison_model: List[str],  # models appended after the fiducial one in the list built below
+         root_outfile: str,  # forwarded unchanged to plot_kde_ratio_nh2_comparison
          points_per_axis: int = 200,
          limit_rows: Union[None, int] = None,
          best_bandwidths: Union[None, dict] = None,
-         ratio_limits: Union[None, dict] = None):
-    _use_model_for_inference = 'constant_abundance_p15_q05'
-    _model_root_folder = os.path.join('prs', 'output', 'run_type', _use_model_for_inference)
-    data, _ = get_inference_data(use_model_for_inference=_use_model_for_inference,
-                                          limit_rows=limit_rows)
-    comparison_data, _ = get_inference_data(use_model_for_inference=comparison_model,
-                                          limit_rows=limit_rows)
+         ratio_limits: Union[None, dict] = None,
+         plot_training_data: Union[int, bool] = False):  # True/False, or the index of the only model to scatter
+    _use_model_for_inference = ['constant_abundance_p15_q05'] + comparison_model  # index 0 is always this model
+    data = []
+    logger.info('Getting data...')
+    for model in _use_model_for_inference:
+        df_data, _ = get_inference_data(use_model_for_inference=model,
+                                        limit_rows=limit_rows)
+        # All rows are kept; use df_data[df_data['source'] == 'RT'] here to keep only the 'RT' rows
+        data.append(df_data)
 
     for ratio_string in ratio_list:
-        with open(
-            os.path.join(_model_root_folder, 'trained_model', f'ratio_density_kde_{ratio_string}.pickle'), 'rb'
-        ) as infile:
-            kde_dict = pickle.load(infile)
-        with open(
-            os.path.join(os.path.join('prs', 'output', 'run_type', comparison_model),
-                         'trained_model', f'ratio_density_kde_{ratio_string}.pickle'), 'rb'
-        ) as infile:
-            comparison_kde_dict = pickle.load(infile)
-
+        logger.info(f'Processing ratio {ratio_string}')
+        kdes = []
         grids = []
-        for kde in [kde_dict, comparison_kde_dict]:
-            grids.append(get_grid(best_bandwidths, kde, ratio_string))
-
-        plot_kde_ratio_nh2(grid=grids[0],
-                           values_on_grid=kde_dict['values_rt_only'].reshape(points_per_axis, points_per_axis),
-                           ratio_string=ratio_string,
-                           root_outfile=filename_root_map[comparison_model],
-                           data=data[data['source'] == 'RT'],
-                           comparison_data=comparison_data[comparison_data['source'] == 'RT'],
-                           comparison_grid=grids[1],
-                           comparison_values_on_grid=comparison_kde_dict['values_rt_only'].reshape(points_per_axis, points_per_axis),
-                           ratio_limits=ratio_limits[ratio_string])
+        for model in _use_model_for_inference:
+            _model_root_folder = os.path.join('prs', 'output', 'run_type', model)
+
+            with open(
+                    os.path.join(_model_root_folder, 'trained_model', f'ratio_density_kde_{ratio_string}.pickle'), 'rb'
+            ) as infile:
+                kde_dict = pickle.load(infile)  # NOTE(review): pickle can run code on load; read trusted files only
+            kdes.append(kde_dict)  # kdes, grids and data all follow the order of _use_model_for_inference
+            grids.append(get_grid(best_bandwidths, kde_dict, ratio_string))
+
+        logger.info(f'Plotting ratio {ratio_string}')
+        plot_kde_ratio_nh2_comparison(grid=grids,
+                                      values_on_grid=kdes,  # whole KDE dicts; the plotting routine reads 'values'
+                                      ratio_string=ratio_string,
+                                      root_outfile=root_outfile,
+                                      points_per_axis=points_per_axis,
+                                      training_data=data,
+                                      ratio_limits=ratio_limits[ratio_string],  # NOTE(review): TypeError if None
+                                      plot_training_data=plot_training_data)
 
 
 def get_grid(best_bandwidths: dict,
@@ -139,11 +154,14 @@ if __name__ == '__main__':
     except KeyError:
         points_per_axis = 200
 
-    for comparison_model in filename_root_map.keys():
-        logger.info(f'Producing figure for {comparison_model}.')
+    for comparison_figure in filename_root_map.keys():  # key -> root_outfile, value -> list of comparison models
+        logger.info(f'Producing figure for {comparison_figure}.')
+        # 1 is a model index: main() puts the fiducial model at index 0, so 1 is the first comparison model
+        _plot_training_data = 1 if comparison_figure == 'double_microturbulence' else False
         main(ratio_list=external_input['ratios_to_include'],
              points_per_axis=points_per_axis,
              limit_rows=limit_rows,
-             comparison_model=comparison_model,
+             comparison_model=filename_root_map[comparison_figure],
+             root_outfile=comparison_figure,
              best_bandwidths=external_input['best_bandwidths'],
-             ratio_limits=external_input['ratio_limits'])
+             ratio_limits=external_input['ratio_limits'],
+             plot_training_data=_plot_training_data)
diff --git a/etl/prs/prs_make_posterior_improvement_figure.py b/etl/prs/prs_make_posterior_improvement_figure.py
new file mode 100644
index 0000000..a959847
--- /dev/null
+++ b/etl/prs/prs_make_posterior_improvement_figure.py
@@ -0,0 +1,72 @@
+import os
+import matplotlib.pyplot as plt
+import pickle
+from typing import Union
+from assets.commons import (load_config_file,
+                            setup_logger,
+                            validate_parameter)
+
+logger = setup_logger(name='PRS - FIGURE POSTERIOR IMPROVEMENT')
+
+# Model used when the configuration does not name one
+_FIDUCIAL_MODEL = 'constant_abundance_p15_q05'
+
+
+def _load_posterior(model: str, pickle_name: str, source_name: str) -> tuple:
+    """
+    Read the posterior of one source from a pickle stored in the model output folder
+    :param model: the name of the model folder under prs/output/run_type
+    :param pickle_name: the name of the pickle file inside the model folder
+    :param source_name: the key of the source in the 'PDF' and 'density_grid' entries of the pickle
+    :return: the tuple (density grid, PDF values) stored for the source
+    """
+    # NOTE(review): pickle can run arbitrary code while loading; only open files written by this pipeline
+    with open(os.path.join('prs', 'output', 'run_type', model, pickle_name), 'rb') as infile:
+        posteriors_dict = pickle.load(infile)
+    pdf_values = posteriors_dict['PDF'][source_name]
+    density_values = posteriors_dict['density_grid'][source_name]
+    return density_values, pdf_values
+
+
+def main(use_model_for_inference: Union[None, str], source_name: str):
+    """
+    Plot the posteriors.pickle ('ML') and posteriors_rt_only.pickle ('RT only') posteriors of a source together
+    :param use_model_for_inference: the model folder to read the posteriors from; it is passed to
+        validate_parameter with the fiducial model as default (presumably used when None is given)
+    :param source_name: the source to plot; also used in the title and in the output file name
+    """
+    _use_model_for_inference = validate_parameter(use_model_for_inference, default=_FIDUCIAL_MODEL)
+    density_values, pdf_values = _load_posterior(model=_use_model_for_inference,
+                                                 pickle_name='posteriors.pickle',
+                                                 source_name=source_name)
+    density_values_rt_only, pdf_values_rt_only = _load_posterior(model=_use_model_for_inference,
+                                                                 pickle_name='posteriors_rt_only.pickle',
+                                                                 source_name=source_name)
+    plt.clf()
+    plt.plot(density_values_rt_only, pdf_values_rt_only, label='RT only')
+    plt.plot(density_values, pdf_values, label='ML')
+    plt.title(f'Posterior improvement for {source_name}')
+    plt.legend(loc='best')
+    plt.xlabel(r'<$n$(H$_{2}$)> [cm$^{-3}$]')
+    plt.ylabel('Probability density')
+    plt.semilogx()
+    plt.xlim(2e2, 2e6)
+    plt.tight_layout()
+    output_folder = os.path.join('prs', 'output', 'comparison_figures')
+    # savefig does not create missing folders, so make sure the destination exists
+    os.makedirs(output_folder, exist_ok=True)
+    plt.savefig(os.path.join(output_folder, f'posterior_improvement_{source_name}.png'))
+
+
+if __name__ == '__main__':
+    external_input = load_config_file(config_file_path=os.path.join('config', 'density_inference_input.yml'))
+    try:
+        use_model_for_inference = external_input['use_model_for_inference']
+        if use_model_for_inference == 'PLACEHOLDER':
+            logger.warning('No model specified for inference in density_inference_input.yml. Using fiducial.')
+            use_model_for_inference = _FIDUCIAL_MODEL
+    except KeyError:
+        # A missing key is left to main(), which hands None to validate_parameter
+        use_model_for_inference = None
+
+    main(use_model_for_inference=use_model_for_inference, source_name='G08.71-0.41')
-- 
GitLab