import os
from itertools import groupby
from typing import Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astropy import constants
from astropy import units as u
from sklearn.linear_model import LinearRegression

from assets.commons import (get_data)
plt.rcParams.update({'font.size': 16})


def get_moldata(source_file: str = 'mdl/data/molecule_e-ch3oh.inp') -> pd.DataFrame:
    """Read the energy-level table of a molecular data file.

    The file is split into groups of consecutive lines, using every line that
    starts with ``!`` as a separator. The fourth non-``!`` group is parsed as
    the level table, one level per line, with whitespace-separated columns:
    level index, energy, statistical weight ``g`` and a J_K label.

    NOTE(review): callers divide ``energy`` directly by a temperature in K.
    LAMDA-style files list level energies in cm^-1, so confirm the units of
    this file.

    Args:
        source_file: Path of the molecular data file. The default is the
            e-CH3OH file this script was written for.

    Returns:
        DataFrame with columns ``original_index``, ``energy``, ``g``,
        ``weight`` and ``J_K``, sorted by ascending energy.
    """
    with open(source_file, 'r') as infile:
        file_content = infile.readlines()
    sections = [list(group)
                for is_comment, group in groupby(file_content, lambda line: line.startswith('!'))
                if not is_comment]
    levels = sections[3]

    rows_list = []
    for line in levels:
        split_line = line.split()
        if not split_line:
            # Blank lines inside the section would otherwise raise IndexError.
            continue
        g = float(split_line[2])
        rows_list.append({
            'original_index': int(split_line[0]),
            'energy': float(split_line[1]),
            'g': g,
            # NOTE(review): if column 2 already is the statistical weight,
            # 2*g + 1 has no obvious meaning; kept for backward compatibility.
            'weight': 2 * g + 1,
            'J_K': split_line[3],
        })

    return pd.DataFrame(rows_list).sort_values(by='energy')


def compute_partition_function(temperature: Union[float, np.ndarray],
                               energies_and_gs: pd.DataFrame) -> Union[float, np.ndarray]:
    """Return the partition function Z(T) = sum_i g_i * exp(-E_i / T).

    The sum is vectorised over both levels and temperatures. ``temperature``
    may be a scalar or an array of any shape, for example one fitted
    temperature per pixel. The result has the same shape as ``temperature``.

    Args:
        temperature: Temperature(s) at which to evaluate Z.
            NOTE(review): ``energy / temperature`` is used directly, so the
            ``energy`` column must be in the same units as ``temperature``
            (presumably K). Confirm against the molecular data file.
        energies_and_gs: Table with at least ``'energy'`` and ``'g'`` columns,
            one row per level.

    Returns:
        Z, with the shape of ``temperature`` (a NumPy scalar for scalar
        input; 0 when the table is empty).
    """
    g = energies_and_gs['g'].to_numpy(dtype=float)
    energy = energies_and_gs['energy'].to_numpy(dtype=float)
    temperature = np.asarray(temperature, dtype=float)
    # Shape (n_levels, *temperature.shape): one Boltzmann factor per level and temperature.
    boltzmann = np.exp(-np.divide.outer(energy, temperature))
    # Broadcast g over the temperature axes, then sum over the level axis.
    weights = g.reshape(g.shape + (1,) * temperature.ndim)
    return np.sum(weights * boltzmann, axis=0)


def get_coldens_z_temperature(energies: np.ndarray,
                              y_rotational_diagram: np.ndarray,
                              max_temperature: float = 200.0) -> Tuple[np.ndarray, np.ndarray]:
    """Fit a rotational diagram independently for every pixel.

    For each pixel (row of ``y_rotational_diagram``), an ordinary least-squares
    line is fitted to ``log10(y)`` against ``log10(e) * energies``::

        log10(y) = intercept - log10(e) * E / T

    The slope is therefore ``-1 / T``, and the intercept is ``log10(N / Z)``.
    All pixels are fitted at once with the closed-form OLS solution.

    Args:
        energies: Upper-level energies, one per column of
            ``y_rotational_diagram`` (presumably in K, so that ``E / T`` is
            dimensionless).
        y_rotational_diagram: Array, or astropy Quantity whose ``.value`` is
            used, of shape ``(n_pixels, n_transitions)``. A 1-D input is
            treated as a single pixel.
        max_temperature: Fitted temperatures above this value, and
            non-positive ones, are replaced by it.

    Returns:
        ``(coldens_z, temperature)``: per-pixel intercepts ``log10(N / Z)``
        and rotational temperatures. Pixels with non-positive ``y`` produce
        NaN or inf values instead of raising.
    """
    raw_values = getattr(y_rotational_diagram, 'value', y_rotational_diagram)
    log_y = np.log10(np.atleast_2d(np.asarray(raw_values, dtype=float)))
    x = np.log10(np.e) * np.asarray(energies, dtype=float).ravel()

    x_centered = x - x.mean()
    slope = (log_y - log_y.mean(axis=1, keepdims=True)) @ x_centered / np.dot(x_centered, x_centered)
    coldens_z = log_y.mean(axis=1) - slope * x.mean()

    # A zero slope gives -inf, which the clipping below maps to max_temperature.
    with np.errstate(divide='ignore'):
        temperature = -1.0 / slope
    temperature = np.where((temperature > max_temperature) | (temperature <= 0),
                           max_temperature, temperature)
    return coldens_z, temperature


def get_y_rotational_diagram(mom_zero_data: pd.DataFrame,
                             frequencies: np.array,
                             einstein_a: np.array,
                             gs: np.array):
    """Convert the three moment-zero maps into rotational-diagram ordinates.

    Steps:
      1. Take the ``mom_zero_86/87/88`` columns as Jy km/s.
      2. Scale them by the Rayleigh-Jeans factor c^2 / (2 k nu^2).
      3. Divide by ``arcsin(2 / 101)**2``.
      4. Turn the result into 8 pi k nu^2 W / (h c^3 A), divided by ``gs``.

    The final value is returned in decomposed CGS units.

    Args:
        mom_zero_data: Table holding the ``mom_zero_86``, ``mom_zero_87`` and
            ``mom_zero_88`` columns, one row per pixel.
        frequencies: Line frequencies in GHz, one per moment-zero column.
        einstein_a: Einstein A coefficients in s^-1, one per line.
        gs: Per-line divisor (presumably upper-level degeneracies).

    Returns:
        astropy Quantity of shape ``(n_rows, 3)``.
    """
    moment_columns = ['mom_zero_86', 'mom_zero_87', 'mom_zero_88']
    nu_hz = (frequencies * u.GHz).to(u.Hz)

    # Rayleigh-Jeans conversion factor c^2 / (2 k nu^2).
    rj_factor = constants.c.cgs ** 2 / (2 * constants.k_B.decompose() * nu_hz ** 2).decompose()
    # NOTE(review): normalisation by arcsin(2/101)**2; its origin is not
    # documented here - TODO confirm (pixel or beam solid angle?).
    solid_angle = np.arcsin(2 / 101) ** 2

    flux_maps = np.array(mom_zero_data[moment_columns]) * u.Jansky.decompose()
    intensities = (flux_maps * u.km / u.s * rj_factor / solid_angle).decompose().cgs

    numerator = 8 * np.pi * constants.k_B.cgs * nu_hz ** 2
    denominator = constants.h.cgs * constants.c.cgs ** 3 * einstein_a * (u.s ** -1)
    # Presumably the optically thin upper-level column density N_u.
    upper_coldens = numerator / denominator * intensities
    return (upper_coldens / gs).decompose().cgs


# Transitions data: one entry per moment-zero map, in the same order as the
# 'mom_zero_86', 'mom_zero_87', 'mom_zero_88' columns.
gs = np.array([5, 5, 5])  # per-line divisor (presumably upper-level degeneracy)
energies = np.array([12.5, 20.1, 28.0])  # upper-level energies (presumably K)
einstein_a = np.array([2.557794E-06, 3.407341E-06, 2.624407E-06])  # s^-1
frequencies = np.array([96.739358, 96.744545, 96.75550])  # GHz

data = get_data(limit_rows=None)
molecular_data = get_moldata()

y_rot_dia = get_y_rotational_diagram(mom_zero_data=data,
                                     frequencies=frequencies,
                                     einstein_a=einstein_a,
                                     gs=gs)

coldens_z, temperature = get_coldens_z_temperature(energies=energies,
                                                   y_rotational_diagram=y_rot_dia)
Z = compute_partition_function(energies_and_gs=molecular_data, temperature=temperature)
# LTE total column density: log10(N) = log10(N / Z) + log10(Z).
data['lte_coldens'] = coldens_z + np.log10(Z)

# savefig does not create missing directories, so make sure the target exists.
output_dir = os.path.join('prs', 'output', 'lte_approx')
os.makedirs(output_dir, exist_ok=True)

plt.clf()
plt.scatter(10**data.lte_coldens, 10**data.molecule_column_density, alpha=0.3)
plt.loglog()
# One-to-one reference line.
plt.plot([1e11, 1e17], [1e11, 1e17], color='red')
plt.xlabel(r'N(CH$_3$OH) LTE [cm$^{-2}$]')
plt.ylabel(r'N(CH$_3$OH) Model [cm$^{-2}$]')
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'lte_approximation_coldens.png'))
