#!/usr/bin/env python3
# encoding: utf-8
# author: Ivano Saccheo
# date: 2023-06-07

"""
Intergalactic-medium (IGM) absorption following the analytic model of
Inoue et al. (2014, MNRAS, 442, 1805), "Inoue+14".
"""

import numpy as np

def lyman_continuum_LAF(redshift, lambda_obs):
    """Lyman-continuum optical depth due to Lyman-alpha-forest (LAF) absorbers.

    Piecewise analytic approximation of Inoue et al. (2014). The functional
    form is selected by the source redshift: z < 1.2, 1.2 <= z < 4.7, or
    z >= 4.7.

    Parameters
    ----------
    redshift : float
        Source redshift.
    lambda_obs : array_like
        One-dimensional sequence of observed-frame wavelengths, in the same
        unit as the Lyman-limit constant below (911.8, i.e. Angstrom).
        Plain lists/tuples are accepted as well as NumPy arrays.

    Returns
    -------
    numpy.ndarray
        Optical depth at each wavelength, same length as ``lambda_obs``.
        It is zero redward of ``911.8 * (1 + redshift)``.
    """
    ll = 911.8  # Lyman limit
    # asarray lets callers pass plain sequences; ndarrays are used as-is.
    lam = np.asarray(lambda_obs)
    wav = lam / ll
    tau = np.zeros((len(lam),))
    zp1 = 1 + redshift

    if redshift < 1.2:
        idx = wav < zp1
        tau[idx] = 0.325 * (wav[idx] ** 1.2 - (zp1 ** (-0.9)) * (wav[idx] ** 2.1))

    elif redshift < 4.7:
        # Here 1 + z >= 2.2, so the first range needs no redshift cut.
        idx1 = wav < 2.2
        idx2 = (wav >= 2.2) & (wav < zp1)

        tau[idx1] = (0.0255 * (zp1 ** 1.6) * (wav[idx1] ** 2.1)
                     + 0.325 * (wav[idx1] ** 1.2) - 0.250 * (wav[idx1] ** 2.1))

        tau[idx2] = 0.0255 * ((zp1 ** 1.6) * (wav[idx2] ** 2.1) - (wav[idx2] ** 3.7))

    else:
        # Here 1 + z >= 5.7, so only the last range needs the redshift cut.
        idx1 = wav < 2.2
        idx2 = (wav >= 2.2) & (wav < 5.7)
        idx3 = (wav >= 5.7) & (wav < zp1)

        tau[idx1] = (0.000522 * (zp1 ** 3.4) * (wav[idx1] ** 2.1)
                     + 0.325 * (wav[idx1] ** 1.2) - 0.0314 * (wav[idx1] ** 2.1))

        tau[idx2] = (0.000522 * (zp1 ** 3.4) * (wav[idx2] ** 2.1)
                     + 0.218 * (wav[idx2] ** 2.1) - 0.0255 * (wav[idx2] ** 3.7))

        tau[idx3] = 0.000522 * ((zp1 ** 3.4) * (wav[idx3] ** 2.1) - (wav[idx3] ** 5.5))

    return tau


def lyman_continuum_DLA(redshift, lambda_obs):
    """Lyman-continuum optical depth due to damped Lyman-alpha (DLA) systems.

    Piecewise analytic approximation of Inoue et al. (2014), with separate
    forms for source redshift z < 2 and z >= 2.

    Parameters
    ----------
    redshift : float
        Source redshift.
    lambda_obs : array_like
        One-dimensional sequence of observed-frame wavelengths, in the same
        unit as the Lyman-limit constant below (911.8, i.e. Angstrom).
        Plain lists/tuples are accepted as well as NumPy arrays.

    Returns
    -------
    numpy.ndarray
        Optical depth at each wavelength, same length as ``lambda_obs``.
        It is zero redward of ``911.8 * (1 + redshift)``.
    """
    ll = 911.8  # Lyman limit
    # asarray lets callers pass plain sequences; ndarrays are used as-is.
    lam = np.asarray(lambda_obs)
    wav = lam / ll
    tau = np.zeros((len(lam),))
    zp1 = 1 + redshift

    if redshift < 2:
        idx = wav < zp1
        tau[idx] = (0.211 * (zp1 ** 2) - 0.0766 * (zp1 ** 2.3) * (wav[idx] ** (-0.3))
                    - 0.135 * (wav[idx] ** 2))

    else:
        # Here 1 + z >= 3, so the first range needs no redshift cut.
        idx1 = wav < 3
        idx2 = (wav >= 3) & (wav < zp1)

        tau[idx1] = (0.634 + 0.047 * (zp1 ** 3) - 0.0178 * (zp1 ** 3.3) * (wav[idx1] ** (-0.3))
                     - 0.135 * (wav[idx1] ** 2) - 0.291 * (wav[idx1] ** (-0.3)))

        tau[idx2] = (0.047 * (zp1 ** 3) - 0.0178 * (zp1 ** 3.3) * (wav[idx2] ** (-0.3))
                     - 0.0292 * (wav[idx2] ** 3))

    return tau


def lyman_series_LAF(redshift, lambda_obs, coefficients):
    """Lyman-series optical depth due to Lyman-alpha-forest (LAF) absorbers.

    Sums the Inoue et al. (2014) piecewise power laws over every line j of
    the coefficient table. Line j contributes only where
    ``lambda_j < lambda_obs < lambda_j * (1 + redshift)``. Within that window
    the exponent is 1.2, 3.7 or 5.5, depending on whether
    ``lambda_obs / lambda_j`` is below 2.2, in [2.2, 5.7), or at least 5.7.

    Parameters
    ----------
    redshift : float
        Source redshift.
    lambda_obs : array_like
        One-dimensional observed-frame wavelengths (same unit as column 1 of
        ``coefficients``). Plain lists/tuples are accepted.
    coefficients : array_like
        Table with one row per line. Column 1 is the line wavelength
        lambda_j, and columns 2, 3 and 4 are the LAF amplitudes for the three
        ranges above. Column 0 is not used here. A single row may be given
        as a 1-D array.

    Returns
    -------
    numpy.ndarray
        Total optical depth at each wavelength, same length as ``lambda_obs``.
    """
    lam = np.asarray(lambda_obs)
    # atleast_2d also accepts a one-line table (1-D after np.loadtxt).
    table = np.atleast_2d(coefficients)
    zp1 = redshift + 1
    tau = np.zeros((len(lam), table.shape[0]))

    for j, row in enumerate(table):
        lam_j = row[1]
        # Line j contributes only for lambda_j < lambda_obs < lambda_j * (1 + z).
        window = (lam > lam_j) & (lam < lam_j * zp1)
        near = window & (lam < lam_j * 2.2)
        mid = window & (lam >= lam_j * 2.2) & (lam < lam_j * 5.7)
        far = window & (lam >= lam_j * 5.7)

        tau[near, j] = row[2] * ((lam[near] / lam_j) ** 1.2)
        tau[mid, j] = row[3] * ((lam[mid] / lam_j) ** 3.7)
        tau[far, j] = row[4] * ((lam[far] / lam_j) ** 5.5)

    return np.sum(tau, axis=1)


def lyman_series_DLA(redshift, lambda_obs, coefficients):
    """Lyman-series optical depth due to damped Lyman-alpha (DLA) systems.

    Sums the Inoue et al. (2014) piecewise power laws over every line j of
    the coefficient table. Line j contributes only where
    ``lambda_j < lambda_obs < lambda_j * (1 + redshift)``. Within that window
    the exponent is 2 below ``3 * lambda_j`` and 3 from there on.

    Parameters
    ----------
    redshift : float
        Source redshift.
    lambda_obs : array_like
        One-dimensional observed-frame wavelengths (same unit as column 1 of
        ``coefficients``). Plain lists/tuples are accepted.
    coefficients : array_like
        Table with one row per line. Column 1 is the line wavelength
        lambda_j, and columns 5 and 6 are the DLA amplitudes for the two
        ranges above. A single row may be given as a 1-D array.

    Returns
    -------
    numpy.ndarray
        Total optical depth at each wavelength, same length as ``lambda_obs``.
    """
    lam = np.asarray(lambda_obs)
    # atleast_2d also accepts a one-line table (1-D after np.loadtxt).
    table = np.atleast_2d(coefficients)
    zp1 = redshift + 1
    tau = np.zeros((len(lam), table.shape[0]))

    for j, row in enumerate(table):
        lam_j = row[1]
        # Line j contributes only for lambda_j < lambda_obs < lambda_j * (1 + z).
        window = (lam > lam_j) & (lam < lam_j * zp1)
        near = window & (lam < lam_j * 3)
        far = window & (lam >= lam_j * 3)

        tau[near, j] = row[5] * ((lam[near] / lam_j) ** 2)
        tau[far, j] = row[6] * ((lam[far] / lam_j) ** 3)

    return np.sum(tau, axis=1)


from util import ROOT
def get_IGM_absorption(
    redshift,
    lambda_obs,
    coefficients=np.loadtxt(f"{ROOT}/data/lyman_series_coefficients.dat")
):
    """Return the IGM transmission ``exp(-tau)`` at ``lambda_obs``.

    ``tau`` is the sum of the four Inoue et al. (2014) components: Lyman
    continuum and Lyman series, each from Lyman-alpha-forest (LAF) and
    damped Lyman-alpha (DLA) absorbers.

    Parameters
    ----------
    redshift : float
        Source redshift.
    lambda_obs : array_like
        One-dimensional observed-frame wavelengths (Angstrom, matching the
        911.8 Lyman-limit constant used by the continuum helpers). Plain
        lists/tuples are accepted.
    coefficients : numpy.ndarray, optional
        Lyman-series coefficient table. The series helpers read column 1 as
        the line wavelength and columns 2-6 as LAF/DLA amplitudes. The
        default is read from ``{ROOT}/data/lyman_series_coefficients.dat``
        once, when this function is defined at import time, because default
        arguments are evaluated only once.

    Returns
    -------
    numpy.ndarray
        Transmission at each wavelength, same length as ``lambda_obs``.
    """
    # Convert once so every component receives an ndarray, even when the
    # caller passes a plain sequence.
    lambda_obs = np.asarray(lambda_obs)

    tau = (
        lyman_continuum_LAF(redshift, lambda_obs)
        + lyman_continuum_DLA(redshift, lambda_obs)
        + lyman_series_LAF(redshift, lambda_obs, coefficients)
        + lyman_series_DLA(redshift, lambda_obs, coefficients)
    )
    return np.exp(-tau)


# Wavelength grid for the precomputed IGM transmission curves. It is
# log-uniform with a step of 7.65e-4 dex, from 900 up to (but excluding)
# 300000 * (6 + 1) ~ 2.1e6, in the unit of lambda_obs (Angstrom, given the
# 911.8 Lyman-limit constant above).
# NOTE(review): the upper bound looks like 30 micron redshifted to z = 6;
# confirm. Outside this grid, np.interp in my_get_IGM_absorption holds the
# edge values.
LAM = 10 ** np.arange(np.log10(900), np.log10(300000 * (6 + 1)), 7.65e-4)
# Cache of transmission curves sampled on LAM, keyed by rounded redshift and
# filled lazily by my_get_IGM_absorption.
IGM = {}
def my_get_IGM_absorption(redshift, lambda_obs, decimals=2, *args, **kwargs):
    """Return cached IGM transmission linearly interpolated onto ``lambda_obs``.

    The redshift is rounded to ``decimals`` places and used as a cache key.
    On the first request for a key, the transmission is evaluated on the
    module-level grid ``LAM`` with ``get_IGM_absorption`` and stored in
    ``IGM``. Every call then interpolates that curve with ``np.interp``, which
    holds the edge values outside the range of ``LAM``.

    Parameters
    ----------
    redshift : float
        Source redshift. If it rounds to a value <= 0, no IGM absorption is
        applied.
    lambda_obs : float or array_like
        Observed-frame wavelength(s), in the unit of ``LAM``.
    decimals : int, optional
        Number of decimal places used to round ``redshift`` for caching.
    *args, **kwargs
        Accepted and ignored.

    Returns
    -------
    float or numpy.ndarray
        Transmission, scalar for scalar ``lambda_obs``, otherwise an array
        shaped like ``lambda_obs``.
    """
    key = np.round(redshift, decimals)

    # Decide on the rounded key so that every redshift mapping to the same
    # cache entry is treated alike. Return unit transmission with the same
    # shape the np.interp path would produce.
    if key <= 0:
        if np.ndim(lambda_obs) == 0:
            return 1.
        return np.ones(np.shape(lambda_obs))

    if key not in IGM:
        IGM[key] = get_IGM_absorption(key, LAM)
    return np.interp(lambda_obs, LAM, IGM[key])
