#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-04-28 17:08:03

"""
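Implement equations from Shen et al. (2020): band-to-bolometric luminosity
corrections (Eq. 5) and the bolometric quasar luminosity function (Eqs. 11 and 14).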
"""

import numpy as np
import astropy.units as u
from scipy.optimize import minimize_scalar

# Bolometric-correction parameters for Eq. 5 of Shen+ 2020, one tuple per band.
# Each tuple is (c1, k1, c2, k2), i.e. the positional-argument order expected
# by get_log_ratio, so a row can be splatted directly: get_log_ratio(L, *row).
BAND_PARAMETERS_BC = {
    #  band          c1       k1       c2       k2
    "B band":     (3.759, -0.361,   9.830, -0.0063),
    "UV":         (1.862, -0.361,   4.870, -0.0063),
    "Soft X-ray": (5.712, -0.026,  17.67,   0.2780),
    "Hard X-ray": (4.073, -0.026,  12.60,   0.2780),
    "Mid-IR":     (4.361, -0.361,  11.40,  -0.0063),
}

def get_log_ratio(log_L_bol, c1, k1, c2, k2):
    """Return log10(L_bol / L_band) from the double power law of Shen+ 2020 Eq. 5.

    The bolometric correction is modelled as

        L_bol / L_band = c1 * (L_bol / 1e10 Lsun)**k1 + c2 * (L_bol / 1e10 Lsun)**k2

    Parameters
    ----------
    log_L_bol : float or array_like
        log10 of the bolometric luminosity in erg/s.
    c1, k1, c2, k2 : float
        Normalisations and slopes of the two power-law terms, e.g. one row of
        BAND_PARAMETERS_BC.

    Returns
    -------
    float or ndarray
        log10 of the ratio L_bol / L_band (NOT the linear ratio).
    """
    # log10 of the solar luminosity in erg/s (10**33.58297... ~= 3.828e33).
    log_L_sun = 33.58297192910481
    # Bolometric luminosity in units of 1e10 Lsun; shared by both terms.
    x = np.power(10, log_L_bol - 10 - log_L_sun)
    return np.log10(c1 * x ** k1 + c2 * x ** k2)


def get_luminosity_bolometric(log_L_band, *args, **kwargs):
    """Invert Eq. 5 of Shen+ 2020 to obtain the bolometric luminosity.

    Eq. 5 expresses L_bol / L_band as a function of L_bol, so for every band
    luminosity this solves

        log_L_bol - log_L_band = get_log_ratio(log_L_bol, *args, **kwargs)

    numerically for log_L_bol, by minimising the squared residual of that
    equation with scipy's scalar minimiser.

    Parameters
    ----------
    log_L_band : float or array_like
        log10 of the band luminosity, in the same units (erg/s) that
        get_log_ratio expects for the bolometric luminosity.
    *args, **kwargs
        Forwarded to get_log_ratio, e.g. ``*BAND_PARAMETERS_BC["B band"]``.

    Returns
    -------
    ndarray
        log10 of the bolometric luminosity, with the same shape as
        log_L_band (0-d for scalar input, empty for empty input).

    Raises
    ------
    RuntimeError
        If the minimiser reports failure for any element.
    """

    def solve_one(log_L_band_i):
        # Squared residual of the log-space equation; its minimum is the root.
        def residual_sq(log_L_bol):
            lhs = log_L_bol - log_L_band_i
            rhs = get_log_ratio(log_L_bol, *args, **kwargs)
            return (lhs - rhs) ** 2

        result = minimize_scalar(residual_sq)
        if not result.success:
            raise RuntimeError(
                f"minimize_scalar did not converge for log_L_band={log_L_band_i}: "
                f"{result.message}"
            )
        return result.x

    # otypes avoids an extra probe evaluation and lets empty input work
    # (np.vectorize refuses size-0 input when it has to infer the dtype).
    return np.vectorize(solve_one, otypes=[float])(log_L_band)


def get_qlf(
        l,
        z,
        a0=0.8569,
        a1=-0.2614,
        a2=0.0200,
        b0=2.5375,
        b1=-1.0425,
        b2=1.1201,
        c0=13.0088,
        c1=-0.5759,
        c2=0.4554,
        d0=-3.5426,
        d1=-0.3936,
        z_ref=2.0
):
    """Evaluate the luminosity function of Shen+ 2020 (Eqs. 11 and 14).

    Eq. 11 is a double power law in L / L_*,

        phi(L) = phi_* / [(L / L_*)**gamma1 + (L / L_*)**gamma2],

    and Eq. 14 gives the redshift evolution of its parameters: gamma1 and
    log10(phi_*) are Chebyshev series in (1 + z), while gamma2 and
    log10(L_* / Lsun) are double power laws in (1 + z) / (1 + z_ref).

    Parameters
    ----------
    l : float or array_like
        log10 of the bolometric luminosity in erg/s.
    z : float or array_like
        Redshift; broadcast against ``l``.
    a0, a1, a2 : float
        Chebyshev coefficients of gamma1.
    b0, b1, b2 : float
        Double-power-law parameters of gamma2.
    c0, c1, c2 : float
        Double-power-law parameters of log10(L_* / Lsun).
    d0, d1 : float
        Chebyshev coefficients of log10(phi_*).
    z_ref : float
        Pivot redshift of the double power laws.

    Returns
    -------
    float or ndarray
        phi evaluated at (l, z).
    """
    one_plus_z = 1 + z
    scaled = one_plus_z / (1 + z_ref)

    # Chebyshev polynomials of the first kind T0, T1, T2 evaluated at 1 + z.
    cheb = (1, one_plus_z, 2 * one_plus_z ** 2 - 1)

    def double_power_law(p0, p1, p2):
        # 2 * p0 / (s**p1 + s**p2), with s = (1 + z) / (1 + z_ref).
        return 2 * p0 / (scaled ** p1 + scaled ** p2)

    gamma1 = a0 * cheb[0] + a1 * cheb[1] + a2 * cheb[2]
    gamma2 = double_power_law(b0, b1, b2)
    # Eq. 14 gives L_* in solar units; 33.58297192910481 is log10(Lsun) in erg/s.
    log_L_star = double_power_law(c0, c1, c2) + 33.58297192910481
    log_phi_star = d0 * cheb[0] + d1 * cheb[1]

    # Eq. 11: double power law in L / L_*.
    l_over_l_star = 10 ** (l - log_L_star)
    return 10 ** log_phi_star / (l_over_l_star ** gamma1 + l_over_l_star ** gamma2)


if __name__ == "__main__":
    # Plot the bolometric correction of every band as a function of L_bol.
    # (astropy.units is already imported at module level; no re-import needed.)
    import matplotlib.pyplot as plt

    # Grid of log10 band luminosities (erg/s) to convert.
    log_L_band = np.linspace(40, 48, 801)
    for band, params in BAND_PARAMETERS_BC.items():
        log_L_bol = get_luminosity_bolometric(log_L_band, *params)
        # y-axis: log10(L_bol / L_band), the bolometric correction.
        plt.plot(log_L_bol, log_L_bol - log_L_band, label=band)
    plt.xlabel(r"$\log_{10} L_\mathrm{bol}$ [erg/s]")
    plt.ylabel(r"$\log_{10} (L_\mathrm{bol} / L_\mathrm{band})$")
    plt.legend()
    plt.show()
