import matplotlib.pyplot as plt
import numpy as np
import os
import seaborn as sns
import pandas as pd
import sqlalchemy
import xarray as xr
from typing import Union, Tuple, List
from contextlib import closing
from itertools import product
from astropy.io import fits
from astropy import units as u
from sqlalchemy.orm import Session, aliased
from sqlalchemy import and_, or_, func, cast, Numeric
from sqlalchemy import engine as sqla_engine
from stg.stg_build_db_structure import (GridPars,
                                        GridFiles,
                                        StarsPars,
                                        ModelPars,
                                        RatioMaps,
                                        MomentZeroMaps)
from assets.commons import (load_config_file,
                            setup_logger,
                            validate_parameter)
from assets.commons.parsing import parse_grid_overrides
from assets.commons.grid_utils import compute_los_average_weighted_profile
from assets.commons.db_utils import get_pg_engine
from assets.constants import line_ratio_mapping

# Default font size of every matplotlib figure produced by this module.
plt.rcParams['font.size'] = 20
# Module-level logger shared by the functions below.
logger = setup_logger(name='PRS - INSPECT')


def get_clump_avg_density(density_at_ref_value: np.array,
                          avg_densities_dict: dict) -> np.array:
    """
    Look up the clump-averaged density of every model in the grid.
    :param density_at_ref_value: the array with the characteristic values of the density identifying the models
    :param avg_densities_dict: the dictionary of the clump average densities, indexed by the characteristic density
        at reference value; every entry of density_at_ref_value must be one of its keys
    :return: a float array with the average number densities, in the same order as density_at_ref_value
    """
    looked_up = (avg_densities_dict[ref_density] for ref_density in density_at_ref_value)
    return np.fromiter(looked_up, dtype=float)


def get_aggregated_ratio_from_db(
        dust_temperature: float,
        gas_density: float,
        lines: Union[list, tuple],
        session: Session,
        run_id: str,
        is_isothermal: bool = False,
        is_homogeneous: bool = False) -> float:
    """
    Get the aggregated ratio of a pair of lines stored in the database for one model of the grid.
    The ratio map is selected by joining it to the moment-zero maps and model parameters of both lines, and by
    filtering on the run id and on the characteristic temperature and density of the grid; if several rows match,
    the one with the most recent GridPars.created_on is used.
    This function works for a homogeneous model. For more complex models the query must be revised.
    :param dust_temperature: the dust temperature of the model
    :param gas_density: the gas density of the model
    :param lines: the lines used to compute the ratio; the first two elements are the ids of the two lines
    :param session: the SQLAlchemy session to use
    :param run_id: the id of the run
    :param is_isothermal: whether the model is isothermal
    :param is_homogeneous: whether the model has a homogeneous density distribution
    :return: the aggregated ratio
    :raises LookupError: if no ratio map matches the requested parameters
    """
    # Both the stored values and the requested ones are rounded to two decimals before being compared.
    # The rounded values are converted to plain Python floats, because NumPy scalars are not reliably
    # adapted by database drivers.
    _dust_temperature = float(np.round(dust_temperature, 2))
    _gas_density = float(np.round(gas_density, 2))
    density_column, dust_temperature_column = get_filter_columns(is_homogeneous=is_homogeneous,
                                                                 is_isothermal=is_isothermal)
    # One alias per line, since the same tables are joined twice.
    mom_zero_1, mom_zero_2 = aliased(MomentZeroMaps), aliased(MomentZeroMaps)
    model_pars_1, model_pars_2 = aliased(ModelPars), aliased(ModelPars)
    query = (
        session.query(RatioMaps)
        .join(mom_zero_1,
              and_(RatioMaps.mom_zero_name_1 == mom_zero_1.mom_zero_name,
                   RatioMaps.run_id == mom_zero_1.run_id))
        .join(model_pars_1,
              and_(mom_zero_1.fits_cube_name == model_pars_1.fits_cube_name,
                   mom_zero_1.run_id == model_pars_1.run_id))
        .join(mom_zero_2,
              and_(RatioMaps.mom_zero_name_2 == mom_zero_2.mom_zero_name,
                   RatioMaps.run_id == mom_zero_2.run_id))
        .join(model_pars_2,
              and_(mom_zero_2.fits_cube_name == model_pars_2.fits_cube_name,
                   mom_zero_2.run_id == model_pars_2.run_id))
        .join(GridPars)
        .filter(
            and_(func.round(cast(dust_temperature_column, Numeric), 2) == _dust_temperature,
                 func.round(cast(density_column, Numeric), 2) == _gas_density,
                 GridPars.run_id == run_id,
                 model_pars_1.iline == lines[0],
                 model_pars_2.iline == lines[1]))
        .order_by(GridPars.created_on.desc())
    )
    ratio_map = query.first()
    if ratio_map is None:
        # .first() returns None when nothing matches; fail with the parameters that were searched for
        raise LookupError(
            f'No ratio map found for run_id={run_id}, lines={lines[0]}-{lines[1]}, '
            f'dust_temperature={_dust_temperature}, gas_density={_gas_density}'
        )
    return ratio_map.aggregated_ratio


def get_filter_columns(is_homogeneous: bool,
                       is_isothermal: bool) -> Tuple[sqlalchemy.Column, sqlalchemy.Column]:
    """
    Get the appropriate GridPars columns for filtering based on the input parameters.
    :param is_homogeneous: Flag indicating if the grid is homogeneous or not; when set, the density is filtered on
        GridPars.central_density, otherwise on GridPars.density_at_reference.
    :param is_isothermal: Flag indicating if the grid is isothermal or not; when set, the temperature is filtered
        on GridPars.dust_temperature, otherwise on GridPars.dust_temperature_at_reference.
    :return: A tuple containing the appropriate density and dust temperature
        columns (in this order) for the given filter conditions.
    """
    # The flags are evaluated by truthiness rather than by identity with True, so that boolean-like
    # values (e.g. numpy booleans) select the same column as the built-in True.
    if is_isothermal:
        dust_temperature_column = GridPars.dust_temperature
    else:
        dust_temperature_column = GridPars.dust_temperature_at_reference
    if is_homogeneous:
        density_column = GridPars.central_density
    else:
        density_column = GridPars.density_at_reference
    return density_column, dust_temperature_column


def get_results(
        dust_temperature: float,
        gas_density: float,
        lines: Union[list, tuple],
        run_id: str,
        session: Session,
        is_isothermal: bool = False,
        is_homogeneous: bool = False) -> Tuple[float, str, str, float, str, str, tuple, np.array]:
    """
    Get results from the database, given the parameters of the model. If several rows match, the one with the most
    recent GridPars.created_on is used. The ratio map is read from the 'prs/fits/ratios' folder, relative to the
    current working directory.
    :param dust_temperature: the characteristic dust temperature
    :param gas_density: the characteristic number density of molecular hydrogen
    :param lines: the lines used to compute the ratio; the first two elements are the ids of the two lines
    :param run_id: the run_id of the model
    :param session: a session to query the database
    :param is_isothermal: whether the model is isothermal
    :param is_homogeneous: whether the model has a homogeneous density distribution
    :return: a tuple containing the aggregated line ratio, the names of the two moment zero maps, the grid pixel
        size in cm (grid_size_3, taken as parsec, divided by grid_shape_3), the ratio map name, the fits grid name
        (for the number density), the scaling factors between the grid and image pixel along the first two axes
        (must be integer!), and the pixel-by-pixel measured line ratios
    :raises LookupError: if no model matches the requested parameters
    :raises ValueError: if the image size is not an integer multiple of the grid shape
    """
    density_column, dust_temperature_column = get_filter_columns(is_homogeneous=is_homogeneous,
                                                                 is_isothermal=is_isothermal)

    # One alias per line, since the same tables are joined twice.
    mom_zero_1, mom_zero_2 = aliased(MomentZeroMaps), aliased(MomentZeroMaps)
    model_pars_1, model_pars_2 = aliased(ModelPars), aliased(ModelPars)
    # Query database
    query = (
        session.query(GridPars, RatioMaps, GridFiles, model_pars_1)
        .join(mom_zero_1,
              and_(RatioMaps.mom_zero_name_1 == mom_zero_1.mom_zero_name,
                   RatioMaps.run_id == mom_zero_1.run_id))
        .join(model_pars_1,
              and_(mom_zero_1.fits_cube_name == model_pars_1.fits_cube_name,
                   mom_zero_1.run_id == model_pars_1.run_id))
        .join(mom_zero_2,
              and_(RatioMaps.mom_zero_name_2 == mom_zero_2.mom_zero_name,
                   RatioMaps.run_id == mom_zero_2.run_id))
        .join(model_pars_2,
              and_(mom_zero_2.fits_cube_name == model_pars_2.fits_cube_name,
                   mom_zero_2.run_id == model_pars_2.run_id))
        .join(GridPars)
        .join(GridFiles)
        # Outer join: StarsPars is not used anywhere else in this query
        .join(StarsPars, isouter=True)
        .filter(
            and_(dust_temperature_column == dust_temperature,
                 density_column == gas_density,
                 GridFiles.quantity == 'gas_number_density',
                 GridPars.run_id == run_id),
            model_pars_1.iline == lines[0],
            model_pars_2.iline == lines[1])
        .order_by(GridPars.created_on.desc())
    )
    row = query.first()
    if row is None:
        # .first() returns None when nothing matches; fail with the parameters that were searched for
        raise LookupError(
            f'No results found for run_id={run_id}, lines={lines[0]}-{lines[1]}, '
            f'dust_temperature={dust_temperature}, gas_density={gas_density}'
        )
    grid_pars, ratio_map, grid_file, model_pars = row

    # The image must be an integer multiple of the grid along both axes, otherwise the image cannot be
    # subsampled onto the grid. An explicit exception is used because assert statements are skipped
    # when Python runs with -O.
    npix = model_pars.npix
    if npix % grid_pars.grid_shape_1 != 0 or npix % grid_pars.grid_shape_2 != 0:
        raise ValueError(
            f'The image size ({npix} px) is not an integer multiple of the grid shape '
            f'({grid_pars.grid_shape_1}, {grid_pars.grid_shape_2})'
        )
    zoom_ratios = (int(npix // grid_pars.grid_shape_1), int(npix // grid_pars.grid_shape_2))

    with fits.open(os.path.join('prs', 'fits', 'ratios', ratio_map.ratio_map_name)) as ratio_fitsfile:
        ratio_values = ratio_fitsfile[0].data

    grid_pixel_size = (grid_pars.grid_size_3 * u.pc).to(u.cm).value / grid_pars.grid_shape_3
    return ratio_map.aggregated_ratio, \
        ratio_map.mom_zero_name_1, \
        ratio_map.mom_zero_name_2, \
        grid_pixel_size, \
        ratio_map.ratio_map_name, \
        grid_file.fits_grid_name, \
        zoom_ratios, \
        ratio_values


def get_grid_values(
        quantity_name: str,
        dust_temperature: float,
        gas_density: float,
        run_id: str,
        session: Session,
        is_isothermal: bool = False,
        is_homogeneous: bool = False) -> np.array:
    """
    Retrieve the grid values, given the model parameters. If several grids match, the one with the most recent
    GridPars.created_on is used. The grid is read from the 'prs/fits/grids' folder, relative to the current
    working directory.
    :param quantity_name: the quantity to extract, as stored in GridFiles.quantity
    :param dust_temperature: the characteristic dust temperature
    :param gas_density: the characteristic number density of molecular hydrogen
    :param run_id: the run_id to process
    :param session: a session to query the database
    :param is_isothermal: whether the model is isothermal
    :param is_homogeneous: whether the model has a homogeneous density distribution
    :return: the array of grid values, i.e. the data of the primary HDU of the grid FITS file
    :raises LookupError: if no grid file matches the requested parameters
    """
    density_column, dust_temperature_column = get_filter_columns(is_homogeneous=is_homogeneous,
                                                                 is_isothermal=is_isothermal)

    # Query database
    row = (
        session.query(GridPars, GridFiles)
        .join(GridFiles)
        .filter(
            and_(dust_temperature_column == dust_temperature,
                 density_column == gas_density,
                 GridFiles.quantity == quantity_name,
                 GridPars.run_id == run_id))
        .order_by(GridPars.created_on.desc())
        .first()
    )
    if row is None:
        # .first() returns None when nothing matches; fail with the parameters that were searched for
        raise LookupError(
            f'No grid of {quantity_name} found for run_id={run_id}, '
            f'dust_temperature={dust_temperature}, gas_density={gas_density}'
        )
    _, grid_file = row
    with fits.open(os.path.join('prs', 'fits', 'grids', grid_file.fits_grid_name)) as fitsfile:
        data = fitsfile[0].data
    return data


def _load_subsampled_moment_zero(mom_zero_name: str,
                                 zoom_ratios: Tuple[int, int]) -> np.ndarray:
    """Read a moment-zero map from 'prs/fits/moments', keep one pixel every zoom_ratios and flatten it."""
    with fits.open(os.path.join('prs', 'fits', 'moments', mom_zero_name)) as fitsfile:
        subsampled = fitsfile[0].data[::zoom_ratios[0], ::zoom_ratios[1]]
        # Convert to the native byte order before handing the data to pandas. astype converts the values
        # whatever the byte order on disk is, and does not rely on ndarray.newbyteorder(), which was
        # removed in NumPy 2.0. The copy also makes the array independent of the (closed) file.
        return subsampled.astype(subsampled.dtype.newbyteorder('=')).ravel()


def get_column_density_vs_mom0(molecule_column_density_grid: np.ndarray,
                               h2_volume_density_grid: np.ndarray,
                               mom_zero_name_1: str,
                               mom_zero_name_2: str,
                               grid_pixel_size: float,
                               lines: Tuple[str, str],
                               gas_density: float,
                               dust_temperature: float,
                               zoom_ratios: Tuple[int, int]) -> pd.DataFrame:
    """
    Get a DataFrame containing, pixel by pixel, the column densities and the moment zero of the two lines.

    :param molecule_column_density_grid: 3D grid for the molecule of interest; notwithstanding its name, it is
        multiplied by the pixel size and summed along the third axis, exactly like the H2 volume density grid.
    :param h2_volume_density_grid: 3D grid of the volume density of H2.
    :param mom_zero_name_1: Name of the first moment file, in the 'prs/fits/moments' folder.
    :param mom_zero_name_2: Name of the second moment file, in the 'prs/fits/moments' folder.
    :param grid_pixel_size: Size of the grid pixels, used as the path length of each cell along the third axis.
    :param lines: Tuple of the names of the lines id for which to compute the moment zero.
    :param gas_density: Gas density at the scaling radius, stored in the 'nh2' column.
    :param dust_temperature: Dust temperature at the scaling radius, stored in the 'tdust' column.
    :param zoom_ratios: Tuple of the subsampling steps applied to the first and second axis of the moment-zero
        maps, so that they match the shape of the grid.
    :return: A DataFrame containing the column density of H2 and the molecule of interest, the model parameters
        and the moment zero of both lines, restricted to the pixels with a positive H2 column density.
    """
    # Integrate the grids along the third axis to obtain the column density maps.
    h2_column_density_map = np.nansum(h2_volume_density_grid * grid_pixel_size, axis=2).flatten()
    molecule_column_density_map = np.nansum(molecule_column_density_grid * grid_pixel_size, axis=2).flatten()
    df_results = pd.DataFrame(data=h2_column_density_map, columns=['H2_column_density'])
    df_results['molecule_column_density'] = molecule_column_density_map
    df_results['nh2'] = gas_density
    df_results['tdust'] = dust_temperature
    for line_id, mom_zero_name in ((lines[0], mom_zero_name_1), (lines[1], mom_zero_name_2)):
        df_results[f'mom_zero_{line_id}'] = _load_subsampled_moment_zero(mom_zero_name=mom_zero_name,
                                                                         zoom_ratios=zoom_ratios)
    return df_results[df_results['H2_column_density'] > 0]


def plot_coldens_vs_integrated_intensity(df_coldens_mom0_list: List[pd.DataFrame],
                                         lines: Tuple[str, str],
                                         run_type: str) -> pd.DataFrame:
    """
    Scatter the moment zero of both lines against the H2 column density, on log-log axes, and save the figure.
    The plot is drawn on the current pyplot figure, which is cleared after saving.

    :param df_coldens_mom0_list: List of DataFrames with the 'H2_column_density' column and one
        'mom_zero_<line id>' column per line.
    :param lines: Tuple of the id of the lines.
    :param run_type: string identification of the model type, used as the name of the output sub-folder.
    :return: The concatenation of the input DataFrames.
    """
    merged = pd.concat(df_coldens_mom0_list)
    plt.xlabel('H_2 column density')
    plt.ylabel('Moment 0')
    plt.loglog()
    # One series per line, sharing the same abscissa
    for line_id in (lines[0], lines[1]):
        series_name = f'mom_zero_{line_id}'
        plt.scatter(merged['H2_column_density'], merged[series_name], label=series_name)
    plt.legend()
    figure_name = f'coldens_moments_lines_{"-".join(lines)}.png'
    figure_path = os.path.join('prs', 'output', 'run_type', run_type, 'figures', figure_name)
    plt.savefig(figure_path)
    plt.clf()
    return merged


def plot_ratio_vs_density(lines: Tuple[str, str],
                          results: xr.DataArray,
                          avg_density: np.array,
                          run_type: str):
    """
    Plots the line ratio vs. the average LOS volume density, and the integrated ratio as a function of the
        model-specific density and temperature at the scaling radius.

    The first figure is NOT drawn from the arguments: the function decorates and saves the current pyplot figure,
    on which the caller is expected to have already scattered the pixel-by-pixel ratios against the average
    density along the line of sight. That figure is closed after saving. The second figure is created here and
    is left cleared as the current figure.

    :param lines: Tuple of the id of the lines.
    :param results: DataArray with the aggregated ratios, with coordinates 'dust_temperature' and 'gas_density'.
    :param avg_density: numpy array with the overall average densities of the clumps.
    :param run_type: string identification of the model type, used as the name of the output sub-folder.
    """
    line_ids = "-".join(lines)
    figures_folder = os.path.join('prs', 'output', 'run_type', run_type, 'figures')

    # Reuse the current figure instead of opening a new one: a new figure would not contain the points
    # scattered by the caller, and the saved plot would show empty axes.
    scatter_figure = plt.gcf()
    scatter_figure.set_size_inches(8, 6)
    plt.semilogx()
    plt.xlabel(r'<n(H$_2$) [cm$^{-3}$]>')
    plt.ylabel('simulated ratio')
    plt.title(f'Ratio {line_ratio_mapping[line_ids]}')
    plt.tight_layout()
    plt.savefig(os.path.join(figures_folder, f'ratio_vs_avg_density_los_{line_ids}.png'))
    # Release the figure, so that figures do not pile up when several line pairs are processed
    plt.close(scatter_figure)

    plt.figure(figsize=(8, 6))
    results.plot(x='dust_temperature', y='gas_density', yscale='log', cbar_kwargs={'pad': 0.2})
    plt.xlabel('$T_{dust}(r_0)$ [K]')
    plt.ylabel('$n$(H$_2$, r$_0$) [cm$^{-3}$]')
    ref_density_limits = plt.gca().get_ylim()
    ref_density_values = results['gas_density'].data
    ax2 = plt.twinx()
    # Aligning the secondary axis values with the pixel centres: the limits of the primary axis are rescaled
    # by the ratio between the average density and the reference density of the first and last model
    ax2.set_ylim(np.array([ref_density_limits[0] / min(ref_density_values) * min(avg_density),
                           ref_density_limits[1] / max(ref_density_values) * max(avg_density)]))
    ax2.semilogy()
    ax2.set_ylabel(r'<n(H$_2$) [cm$^{-3}$]>')
    plt.tight_layout()
    plt.savefig(os.path.join(figures_folder, f'ratio_grid_lines_{line_ids}.png'))
    plt.clf()


def _output_path(run_type: str, category: str, filename: str) -> str:
    """Return the path of an output file of the given category ('figures' or 'data') for a run type."""
    return os.path.join('prs', 'output', 'run_type', run_type, category, filename)


def _collect_pixel_dataset(results_dict: dict,
                           line_pairs: list,
                           dust_temperatures,
                           central_densities) -> pd.DataFrame:
    """Merge the per-model, per-line-pair correlation arrays into a single pixel-level DataFrame."""
    model_frames = []
    for (tdust, nh2) in product(dust_temperatures, central_densities):
        model_key = "_".join([str(tdust), str(nh2)])
        line_frames = []
        for lines in line_pairs:
            line_key = "-".join(lines)
            df = pd.DataFrame(
                data=results_dict[line_key][model_key].transpose(),
                columns=['avg_nh2', 'avg_tdust', 'std_nh2', f'ratio_{line_key}'])
            line_frames.append(df.dropna())
        df_model = pd.concat(line_frames, axis=1)
        # The average and standard deviation columns are repeated for every line pair: keep the first copy
        df_model = df_model.loc[:, ~df_model.columns.duplicated()].copy()
        df_model['nh2'] = nh2
        df_model['tdust'] = tdust
        model_frames.append(df_model)
    return pd.concat(model_frames).reset_index().rename(columns={'index': 'px_index'})


def main(run_id: str,
         is_isothermal: bool,
         is_homogeneous: bool,
         molecule_used_to_weight_los_quantity: Union[str, None] = None,
         engine: Union[sqla_engine.Engine, None] = None,
         run_type: Union[str, None] = None):
    """
    Inspect the results of a run: for every pair of lines and every model of the grid, collect the aggregated and
    the pixel-by-pixel line ratios, together with the density and temperature averaged along the line of sight,
    then save the summary figures and the CSV datasets under 'prs/output/run_type/<run_type>'.
    :param run_id: the id of the run to inspect
    :param is_isothermal: whether the models are isothermal
    :param is_homogeneous: whether the models have a homogeneous density distribution
    :param molecule_used_to_weight_los_quantity: the name of the grid quantity used to weight the averages along
        the line of sight; the default passed to validate_parameter is the first of the species to include in
        the staging configuration
    :param engine: the SQLAlchemy engine to use; if None, it is obtained from get_pg_engine
    :param run_type: string identification of the model type, used as the name of the output sub-folder; the
        default passed to validate_parameter is 'constant_abundance_p15_q05'
    """
    if engine is None:
        engine = get_pg_engine(logger=logger)
    config = load_config_file(os.path.join('config', 'config.yml'))
    config_stg = load_config_file(os.path.join('stg', 'config', 'config.yml'))

    _run_type = validate_parameter(
        run_type,
        default='constant_abundance_p15_q05'
    )
    # Make sure that the output folders exist before starting the computation
    for category in ('figures', 'data'):
        os.makedirs(os.path.join('prs', 'output', 'run_type', _run_type, category), exist_ok=True)

    # grid definition
    dust_temperatures = parse_grid_overrides(par_name='dust_temperature',
                                             config=config)
    central_densities = parse_grid_overrides(par_name='gas_density',
                                             config=config)
    line_pairs = config['overrides']['lines_to_process']

    # If not specified, assumes that the first molecule should be used for weighting
    _molecule_used_to_weight_los_quantity = validate_parameter(molecule_used_to_weight_los_quantity,
                                                               default=config_stg['lines']['species_to_include'][0])

    # The same array is reused for every line pair: all its elements are overwritten in the inner loop
    results = xr.DataArray(np.empty(shape=[len(dust_temperatures), len(central_densities)]),
                           dims=('dust_temperature', 'gas_density'),
                           coords={
                               'dust_temperature': dust_temperatures,
                               'gas_density': central_densities
                           })
    results_dict = {}
    avg_densities_dict = {}
    with closing(Session(engine)) as session:
        for lines in line_pairs:
            line_key = "-".join(lines)
            df_coldens_mom0_list = []
            results_dict[line_key] = {}
            for (tdust, nh2) in product(dust_temperatures, central_densities):
                # Keyword arguments identifying the model, shared by all the queries below
                model_selection = dict(dust_temperature=tdust,
                                       gas_density=nh2,
                                       session=session,
                                       run_id=run_id,
                                       is_isothermal=is_isothermal,
                                       is_homogeneous=is_homogeneous)
                aggregated_ratio = get_aggregated_ratio_from_db(lines=lines, **model_selection)
                results.loc[tdust, nh2] = aggregated_ratio
                logger.debug(f'The aggregated ratio for lines {lines}, using {nh2}, {tdust} is: {aggregated_ratio}')
                (_,
                 mom0_name_1,
                 mom0_name_2,
                 grid_pixel_size,
                 _, _,
                 zoom_ratios,
                 ratio_values) = get_results(lines=lines, **model_selection)
                density_grid = get_grid_values(quantity_name='gas_number_density', **model_selection)
                temperature_grid = get_grid_values(quantity_name='dust_temperature', **model_selection)
                molecule_grid = get_grid_values(quantity_name=_molecule_used_to_weight_los_quantity,
                                                **model_selection)

                avg_density_map = compute_los_average_weighted_profile(profile=density_grid, weights=molecule_grid)
                # Cells with zero density are excluded from the statistics
                is_empty_cell = density_grid == 0
                avg_density = np.nansum(np.where(is_empty_cell, np.nan, density_grid * molecule_grid)) / \
                    np.nansum(np.where(is_empty_cell, np.nan, molecule_grid))
                avg_densities_dict[nh2] = avg_density
                # Standard deviation of the H2 density along the third axis. It is computed on the density
                # grid: the result is exported as 'std_nh2', so the molecule grid must not be used here.
                std_density_map = np.nanstd(np.where(is_empty_cell, np.nan, density_grid), axis=2)
                avg_temperature_map = compute_los_average_weighted_profile(profile=temperature_grid,
                                                                           weights=molecule_grid)
                # Rows: average density, average temperature, density standard deviation and line ratio;
                # the ratio map is subsampled to the grid resolution
                correlation_data = np.array([avg_density_map.flatten(),
                                             avg_temperature_map.flatten(),
                                             std_density_map.flatten(),
                                             ratio_values[::zoom_ratios[0], ::zoom_ratios[1]].flatten()])
                results_dict[line_key]["_".join([str(tdust), str(nh2)])] = correlation_data
                # Accumulate the pixel-by-pixel ratio vs. average density on the current pyplot figure
                plt.scatter(correlation_data[0], correlation_data[3])
                df_coldens_mom0_list.append(get_column_density_vs_mom0(molecule_column_density_grid=molecule_grid,
                                                                       h2_volume_density_grid=density_grid,
                                                                       mom_zero_name_1=mom0_name_1,
                                                                       mom_zero_name_2=mom0_name_2,
                                                                       lines=lines,
                                                                       gas_density=nh2,
                                                                       dust_temperature=tdust,
                                                                       zoom_ratios=zoom_ratios,
                                                                       grid_pixel_size=grid_pixel_size))
            avg_densities = get_clump_avg_density(density_at_ref_value=results['gas_density'].data,
                                                  avg_densities_dict=avg_densities_dict)
            plot_ratio_vs_density(lines=lines,
                                  results=results,
                                  avg_density=avg_densities,
                                  run_type=_run_type)

            df_coldens_mom0 = plot_coldens_vs_integrated_intensity(df_coldens_mom0_list=df_coldens_mom0_list,
                                                                   lines=lines,
                                                                   run_type=_run_type).reset_index(). \
                rename(columns={'index': 'px_index'})
            df_coldens_mom0.to_csv(_output_path(_run_type, 'data', f'full_dataset_NH2_lines_{line_key}.csv'))

    df_all = _collect_pixel_dataset(results_dict=results_dict,
                                    line_pairs=line_pairs,
                                    dust_temperatures=dust_temperatures,
                                    central_densities=central_densities)
    for lines in line_pairs:
        line_key = "-".join(lines)
        g = sns.jointplot(x=np.log10(df_all['avg_nh2']), y=df_all[f'ratio_{line_key}'],
                          kind='kde', joint_kws={'bw_adjust': 2})
        g.plot_joint(sns.scatterplot, s=100, alpha=.5, marker='+', color='orange')

        plt.savefig(_output_path(_run_type, 'figures', f'ratio_vs_avg_density_los_kde_{line_key}.png'))
        # jointplot opens a new figure at every iteration: close it, instead of just clearing it
        plt.close()
    df_all.to_csv(_output_path(_run_type, 'data', 'full_dataset.csv'))


if __name__ == '__main__':
    # Script entry point: inspect a hard-coded run, treated as isothermal and homogeneous.
    main(
        run_id='7dd5b365-875e-4857-ae11-2707820a33c1',
        is_isothermal=True,
        is_homogeneous=True,
        run_type='uniform',
    )
