import os
import pandas as pd
import numpy as np
from assets.utils import get_poc_results
from collections import OrderedDict


def divide_chunks(l, n):
    """Yield consecutive pieces of ``l`` as lists of at most ``n`` items.

    Pieces are taken in order with a stride of ``n``; the last one is shorter
    when ``len(l)`` is not a multiple of ``n``.
    """
    starts = range(0, len(l), n)
    for lower in starts:
        upper = lower + n
        piece = l[lower:upper]
        yield list(piece)


def mask_non_detections(data: pd.DataFrame,
                        transition_id: str) -> pd.DataFrame:
    """Return a copy of ``data`` with weak peaks of one transition set to -1.

    A row is treated as a non-detection when its ``tpeak_<transition_id>``
    value is strictly below twice its ``rms_noise_<transition_id>`` value.
    The input frame is not modified.
    """
    peak_column = f'tpeak_{transition_id}'
    noise_column = f'rms_noise_{transition_id}'

    masked = data.copy()
    below_threshold = masked[peak_column] < 2 * masked[noise_column]
    masked.loc[below_threshold, peak_column] = -1
    return masked


df_full = get_poc_results(line_fit_filename='ch3oh_data_top35.csv')

# Make POC LaTeX table
# Beam efficiencies used to rescale the 256/257 and 380/381 columns below
# (presumably converting to the main-beam scale -- TODO confirm).
pi230_beam_eff = 0.7
flash_beam_eff = 0.73
df_table = df_full.copy()
df_table[['tpeak_256', 'tpeak_257', 'rms_noise_256', 'rms_noise_257']] /= pi230_beam_eff
df_table[['area_256', 'area_257', 'area_unc_256', 'area_unc_257']] /= pi230_beam_eff
df_table[['tpeak_380', 'tpeak_381', 'rms_noise_380', 'rms_noise_381']] /= flash_beam_eff
df_table[['area_380', 'area_381', 'area_unc_380', 'area_unc_381']] /= flash_beam_eff

# Replace peaks below twice the RMS noise with the -1 sentinel, one transition at a time.
for transition_id in ['86', '87', '88', '256', '257', '380', '381']:
    df_table = mask_non_detections(data=df_table, transition_id=transition_id)

# Densities are reported in units of 1e5 with one decimal.
df_table['best_fit'] = (df_table['best_fit'] / 1e5).round(1)
# Scale and round the interval bounds, then group them in pairs.
# ``tolist()`` converts the NumPy scalars to built-in floats: NumPy >= 2
# renders its scalars as ``np.float64(...)`` inside a list, which would
# otherwise end up verbatim in the rendered table.
df_table['hpd_interval'] = df_table['hpd_interval'].apply(
    lambda row: list(divide_chunks((np.array(row) / 1e5).round(1).tolist(), 2)))
# Mass is divided by 100; mass and distance are stored as strings with one decimal.
df_table['mass'] = (df_table['mass'] / 100).round(1).astype(str)
df_table['distance'] = (df_table['distance']).round(1).astype(str)
# Raw string: ``\,`` and ``\m`` are LaTeX commands, not Python escapes.
density_format = r'$[10^{5}\,\mathrm{cm^{-3}}]$'
fwhm_units = '[km s$^{-1}$]'

# Columns shown in the fit-results table.
df_fit_results = df_table[['AGAL_name', 'class_phys', 'best_fit', 'hpd_interval', 'mass', 'distance']].copy()

# Per-transition line properties: peak, noise, line width and its uncertainty.
transition_ids = ['86', '87', '88', '256', '257', '380', '381']
line_columns = [f'{quantity}_{tid}'
                for quantity in ('tpeak', 'rms_noise', 'linewidth', 'linewidth_unc')
                for tid in transition_ids]
df_source_properties = df_table[['AGAL_name'] + line_columns].copy()

for transition_id in transition_ids:
    tpeak_col = f'tpeak_{transition_id}'
    noise_col = f'rms_noise_{transition_id}'
    width_col = f'linewidth_{transition_id}'
    width_unc_col = f'linewidth_unc_{transition_id}'

    # Decide on the numeric values, before they become strings: matching the
    # formatted text by prefix ('$-1.0') would also catch genuine values such
    # as -1.04.
    peak = df_source_properties[tpeak_col]
    non_detection = peak == -1  # sentinel written by mask_non_detections
    missing = peak.isna()

    # Render "value \pm uncertainty" cells in LaTeX math mode.
    df_source_properties[tpeak_col] = (
        '$' + peak.round(2).astype(str)
        + r' \pm ' + df_source_properties[noise_col].round(2).astype(str) + '$')
    df_source_properties[width_col] = (
        '$' + df_source_properties[width_col].round(2).astype(str)
        + r' \pm ' + df_source_properties[width_unc_col].round(2).astype(str) + '$')

    # '-1.00' is the textual placeholder later swapped for dots when the
    # tables are written; missing data are shown as 'N/A'.
    df_source_properties.loc[non_detection, [tpeak_col, width_col]] = '-1.00'
    df_source_properties.loc[missing, [tpeak_col, width_col]] = 'N/A'
    # The uncertainties are now embedded in the formatted cells.
    df_source_properties.drop(columns=[noise_col, width_unc_col], inplace=True)


# Maps each data column to its two-level table header: (title, unit).
# Raw strings are used wherever the header contains LaTeX backslash commands,
# which are not valid Python escape sequences.
remap_columns = OrderedDict([
    ('AGAL_name', ('Source', '')),
    ('class_phys', ('Classification', '')),
    ('best_fit', (r'best fit($n\mathrm{H_2}$)', density_format)),
    ('hpd_interval', (r'67\%  HPD', density_format)),
    ('mass', ('Mass', r'$[10^2 \mathrm{M_\odot}]$')),
    ('distance', ('Distance', '[kpc]')),
    ('tpeak_86', ('$T_{MB,(2_{-1}-1_{-1})}$', '[K]')),
    ('linewidth_86', ('$FWHM_{(2_K-1_K)}$', fwhm_units)),
    ('tpeak_87', ('$T_{MB,(2_{0}-1_{0})}$', '[K]')),
    # Uses the shared unit string; the previous hand-typed copy closed the
    # bracket inside math mode ('[km s$^{-1}]$').
    ('linewidth_87', ('$FWHM_{87}$', fwhm_units)),
    ('tpeak_88', ('$T_{MB,(2_{1}-1_{1})}$', '[K]')),
    ('linewidth_88', ('$FWHM_{88}$', fwhm_units)),
    # ('rms_noise_86', ('$\sigma_{96.7GHz}$', '[K]')),
    ('tpeak_256', ('$T_{MB,(5_{0}-4_{0})}$', '[K]')),
    ('linewidth_256', ('$FWHM_{256}$', fwhm_units)),
    ('tpeak_257', ('$T_{MB,(5_{-1}-4_{-1})}$', '[K]')),
    ('linewidth_257', ('$FWHM_{(5_K-4_K)}$', fwhm_units)),
    # ('rms_noise_256', ('$\sigma_{241.7GHz}$', '[K]')),
    ('tpeak_380', ('$T_{MB,(7_{0}-6_{0})}$', '[K]')),
    ('linewidth_380', ('$FWHM_{380}$', fwhm_units)),
    ('tpeak_381', ('$T_{MB,(7_{-1}-6_{-1})}$', '[K]')),
    ('linewidth_381', ('$FWHM_{(7_K-6_K)}$', fwhm_units)),
    # ('rms_noise_380', ('$\sigma_{338.3GHz}$', '[K]')),
])

# Keep only the mapped columns, in mapping order (explicit list rather than a
# dict view as the indexer).
df_table = df_table[list(remap_columns)]

# Apply the (title, unit) headers to every frame and promote them to a
# two-level column index; mapping keys absent from a frame are ignored.
for frame in (df_table, df_source_properties, df_fit_results):
    frame.rename(columns=remap_columns, inplace=True)
    frame.columns = pd.MultiIndex.from_tuples(frame.columns)


# Fit-results table. Raw strings keep the LaTeX ``\%`` intact without relying
# on Python passing unknown escape sequences through.
caption = (r'Sources included in the proof-of-concept, classification, distance, mass, '
           r'and number density results. We indicate the highest-probability density '
           r'interval containing the 67\% of the probability mass as HPD 67\%.')
label = 'tab:poc_results'
# Look the header up in the mapping instead of re-typing it, so the format
# key cannot drift from the actual column name.
best_fit_header = remap_columns['best_fit']
latex_table = (df_fit_results.style.hide(axis='index')
               .format({best_fit_header: '{:,.1f}'})
               .to_latex(caption=caption, label=label, hrules=True, environment='table*', position_float='centering'))
# Interval cells are nested lists; drop the outer pair of brackets.
latex_table = latex_table.replace('[[', '[').replace(']]', ']')
with open(os.path.join('prs', 'output', 'poc_tables', 'fit_results.tex'), 'w') as outfile:
    outfile.write(latex_table)

# Line-property tables: one for the (2_K-1_K) band, one for the (5_K-4_K) and
# (7_K-6_K) bands. Headers are looked up in ``remap_columns`` so they always
# match the renamed frame instead of being re-typed by hand.
table_cols = [
    [remap_columns[key] for key in ('AGAL_name', 'tpeak_86', 'tpeak_87', 'tpeak_88', 'linewidth_86')],
    [remap_columns[key] for key in ('AGAL_name', 'tpeak_256', 'tpeak_257', 'linewidth_257',
                                    'tpeak_380', 'tpeak_381', 'linewidth_381')],
]
captions = [
    "Line properties of the lines in the $(2_K-1_K)$ methanol band. Only one FWHM is listed because the fit is performed forcing all lines to have the same width. The main-beam temperature of the lines is indicated as $T_{MB,J_K-J',_K'}$. A non-detection is indicated with three dots.",
    "Line properties of the lines in the $(5_K-4_K)$ and $(7_K-6_K)$ methanol bands. Only one FWHM is listed per band because the fit is performed forcing all lines to have the same width. The main-beam temperature of the lines is indicated as $T_{MB,J_K-J',_K'}$. A non-detection is indicated with three dots, while missing data are indicated with 'N/A'.",
]
labels = [
    'tab:poc_lines_3mm',
    'tab:poc_lines_hf',
]
outfiles = [
    'line_table_3mm.tex',
    'line_table_hf.tex',
]
for cols, caption, label, filename in zip(table_cols, captions, labels, outfiles):
    latex_table = (df_source_properties[cols].style.hide(axis='index')
                   .to_latex(caption=caption, label=label, hrules=True, environment='table*', position_float='centering'))
    latex_table = latex_table.replace('[[', '[').replace(']]', ']')
    # Swap the non-detection placeholder for LaTeX dots (a single pass is
    # enough; raw string because ``\d`` is not a Python escape).
    latex_table = latex_table.replace('-1.00', r'\dots')
    with open(os.path.join('prs', 'output', 'poc_tables', filename), 'w') as outfile:
        outfile.write(latex_table)

