#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-04-27 11:41:37

"""
Implement Bon+16
"""

import os
import sys

import numpy as np
from scipy.integrate import quad

from util import ROOT

try:
    from smf import StellarMassFunctionBongiorno2016
    from smf import StellarMassFunctionCOSMOS2020
except ImportError:
    pass


# Default integration / grid limits shared by the functions below.
N_MIN, N_MAX = 20.0, 24.0   # log10 N_H
L_MIN, L_MAX = 32.0, 36.0   # log10 lambda_SAR
M_MIN, M_MAX = 8.5, 13.5    # log10 M_star
Z_MIN, Z_MAX = 0.01, 6.99   # presumably redshift limits (not used in this chunk) -- confirm


def get_p(M_star, L_X, N_H, z, func_Psi, func_I, func_f, func_dV_div_dz, flux=None):

    """Return Bongiorno eq. 2 (and eq. 4) for a sample of AGN.

    Parameters
    ----------
    M_star, L_X, N_H, z : ndarray
        Per-object stellar mass, X-ray luminosity, column density and
        redshift, all indexed the same way (linear M_star and L_X, since
        lambda_SAR = L_X / M_star). Objects with ``M_star <= 0`` have no
        stellar-mass value; their probability is integrated over M_star in
        [10**M_MIN, 10**M_MAX] instead.
    func_Psi, func_I, func_f, func_dV_div_dz : callable
        Model ingredients, called as in ``func`` below.
    flux : ndarray or None
        Per-object flux forwarded to ``func_I``. With ``None`` (the default)
        ``flux=None`` is forwarded for every object.

    Returns
    -------
    ndarray
        One probability per object, divided by the normalization of get_N.
    """

    def func(M_star, L_X, N_H, z, flux):
        lambda_SAR = L_X / M_star
        ret = (
            func_Psi(M_star, lambda_SAR, z)
            * func_I(L_X, z, N_H, flux=flux)
            * func_f(N_H, L_X, z)
            * func_dV_div_dz(z)
        )
        return ret

    def as_float(value):
        # Inputs are promoted to ndim=4 arrays, so the model yields a size-1
        # array; quad and the result array need a plain float.
        return np.asarray(value, dtype=float).item()

    def get_p_i(i):
        """Eq. 2 and 4 for individual object i/j"""
        # Forward a missing flux as None (indexing None would raise).
        flux_i = None if flux is None else np.array(flux[i], ndmin=4)
        M_i, L_i, N_i, z_i = (
            np.array(a, ndmin=4) for a in (M_star[i], L_X[i], N_H[i], z[i])
        )
        if M_star[i] > 0:
            return as_float(func(M_i, L_i, N_i, z_i, flux_i))

        # NOTE(review): the mass is integrated in linear M_star over five
        # decades; if func_Psi is per dex a Jacobian may be missing -- confirm.
        def integrand(m):
            return as_float(func(m, L_i, N_i, z_i, flux_i))

        return quad(integrand, 10 ** M_MIN, 10 ** M_MAX)[0]

    N = get_N(func_Psi, func_I, func_f, func_dV_div_dz)
    # A list comprehension (unlike np.vectorize) also handles an empty sample.
    p = np.array([get_p_i(i) for i in range(M_star.size)], dtype=float)
    return p / N


# Regular grid on which get_N evaluates the Bon+16 eq. 3 integrand as a
# Riemann sum; every axis uses the same step DD. Built once at import time.
DD = 0.10
LOG_M, LOG_L, LOG_N, Z = np.meshgrid(
    *(
        np.arange(start, stop, DD)
        for start, stop in (
            (M_MIN, M_MAX),  # log10 M_star
            (L_MIN, L_MAX),  # log10 lambda_SAR
            (N_MIN, N_MAX),  # log10 N_H
            (0.30, 2.50),    # redshift, fixed range [0.30, 2.50)
        )
    )
)
def get_N(func_Psi, func_I, func_f, func_dV_div_dz):
    """Return Bongiorno eq. 3, the normalization of the per-object probability.

    The integrand is evaluated on the module-level grid (LOG_M, LOG_L, LOG_N, Z)
    and summed as a Riemann sum with cell volume DD ** 4. The LOG_L axis is
    log10(lambda_SAR); the X-ray luminosity follows from
    L_X = lambda_SAR * M_star, the inverse of the relation used in get_p.
    """
    def func(log_M_star, log_lambda_SAR, log_N_H, z):
        # Previously log_L_X was referenced without ever being defined.
        log_L_X = log_lambda_SAR + log_M_star
        return (
            func_Psi(10 ** log_M_star, 10 ** log_lambda_SAR, z)
            * func_I(10 ** log_L_X, z, 10 ** log_N_H)
            * func_f(10 ** log_N_H, 10 ** log_L_X, z)
            * func_dV_div_dz(z)
        )
    return np.sum(func(LOG_M, LOG_L, LOG_N, Z)) * DD ** 4


def get_likelihood(*args, **kwargs):
    """Return -2 ln(L) for a sample of AGN (Bongiorno eqs. 2-4).

    All arguments are forwarded to get_p. get_p already returns the
    probability of every object, both those with a stellar mass (M_star > 0,
    "i") and those whose mass is integrated out ("j"). Summing ln(p) over the
    whole sample therefore equals sum_i ln(p_i) + sum_j ln(p_j). The previous
    code called get_p_i / get_p_j, which do not exist at module level.
    """
    ln_p = np.log(get_p(*args, **kwargs))
    return -2 * np.sum(ln_p)


def get_f_star(M, z, log_M_star_star=10.99, k_log_M_star=-0.25, alpha=0.24, z1=2.59):
    """Bon+16 Eq. (10): stellar-mass term (M / M*)^alpha * exp(-M / M*).

    The characteristic mass log M* equals log_M_star_star below z1 and
    evolves linearly with slope k_log_M_star above it. M is linear mass.
    """
    evolved = log_M_star_star + k_log_M_star * (z - z1)
    log_M_char = np.where(z < z1, log_M_star_star, evolved)
    ratio = M / 10 ** log_M_char
    return np.power(ratio, alpha) * np.exp(-ratio)


def get_f_lambda_SAR(
        lambda_SAR,
        M,
        z,
        log_lambda_SAR_star0=33.8,
        k_lambda=-0.48,
        log_M_star0=11.0,
        gamma10=-1.01,
        k_gamma=0.58,
        k_gamma1=-0.06,
        gamma2=-3.72,
        z0=1.10,
        z1=2.59
):

    """Implements Eq. (11): double power law in lambda_SAR / lambda_SAR*.

    The break lambda_SAR* depends on log10(M) through k_lambda. The
    low-lambda slope gamma1 evolves linearly in z with slope k_gamma. When
    z1 is given, it changes to slope k_gamma1 above z1 (high-redshift
    extension). The result is at least 1-D because z is promoted with
    np.atleast_1d.
    """

    z = np.atleast_1d(z)

    gamma1_low_z = gamma10 + k_gamma * (z - z0)
    if z1 is None:
        gamma1 = gamma1_low_z
    else:
        gamma1_high_z = gamma10 + k_gamma * (z1 - z0) + k_gamma1 * (z - z1)
        gamma1 = np.where(z < z1, gamma1_low_z, gamma1_high_z)

    log_lambda_break = log_lambda_SAR_star0 + k_lambda * (np.log10(M) - log_M_star0)
    ratio = lambda_SAR / 10 ** log_lambda_break

    return 1 / (ratio ** -gamma1 + ratio ** -gamma2)


def get_f_z(z, p1=5.82, p2=2.36, p3=-4.64, z0=1.10, z1=2.59):

    """Bon+16 Eq. (12): redshift evolution as a continuous broken power law in (1 + z).

    Pieces:
        0  <= z <  z0  : (1 + z) ** p1                       (Bon+16)
        z0 <= z <  z1  : continued with slope p2             (Bon+16)
        z1 <= z        : continued with slope p3             (Vii extension hi-z)
    With z1=None the third piece is dropped and slope p2 applies for all
    z >= z0. Each piece is scaled so the function is continuous at z0 and z1.

    The previous implementation filled z >= 99 from an uninitialized
    ``np.empty_like`` buffer; every z now gets a defined value.

    Returns an ndarray with the shape of z (0-d for scalar input).
    """

    z = np.asarray(z, dtype=float)
    x = 1 + z
    x0 = 1 + z0

    def f1(x):
        return x ** p1

    def f2(x):
        return f1(x0) * (x / x0) ** p2

    if z1 is None:
        ret = f2(x)
    else:
        x1 = 1 + z1
        ret = f2(x1) * (x / x1) ** p3
        ret = np.where(z < z1, f2(x), ret)

    return np.where(z < z0, f1(x), ret)


def get_Psi(
    M,
    lambda_SAR,
    z,
    log_Psi_star=-6.86,
    kwargs_f_lambda_SAR=None,
    kwargs_f_star=None,
    kwargs_f_z=None
):
    """Bon+16 Eq. (7): bivariate distribution of stellar mass and lambda_SAR.

    Psi = 10**log_Psi_star * f_lambda_SAR(lambda_SAR, M, z) * f_star(M, z) * f_z(z)

    M and lambda_SAR are linear quantities. All three inputs are promoted to
    at least 1-D arrays, so the result is at least 1-D. The ``kwargs_f_*``
    dicts (default: none) override the parameters of the corresponding
    factor. ``None`` defaults replace the former shared mutable ``{}``
    defaults.
    """
    M = np.atleast_1d(M)
    lambda_SAR = np.atleast_1d(lambda_SAR)
    z = np.atleast_1d(z)
    return (
        10 ** log_Psi_star
        * get_f_lambda_SAR(lambda_SAR, M, z, **(kwargs_f_lambda_SAR or {}))
        * get_f_star(M, z, **(kwargs_f_star or {}))
        * get_f_z(z, **(kwargs_f_z or {}))
    )


def get_Phi_lambda_SAR(
    lambda_SAR,
    z,
    mmin=M_MIN,
    mmax=M_MAX,
    kwargs_f_lambda_SAR=None,
    kwargs_f_star=None,
    kwargs_f_z=None
):
    """Return Psi integrated over log10(M_star) in [mmin, mmax] for each lambda_SAR.

    Parameters
    ----------
    lambda_SAR : float or array_like
        Linear specific accretion rate(s), passed to get_Psi unchanged.
    z : float
        Redshift. Must be a single value: the integrand handed to quad has to
        be a scalar.
    mmin, mmax : float
        Integration limits in log10(M_star).
    kwargs_f_lambda_SAR, kwargs_f_star, kwargs_f_z : dict or None
        Forwarded to get_Psi (``None`` means no overrides).

    Returns
    -------
    ndarray
        One integral per element of lambda_SAR (same shape).
    """
    kwargs_f_lambda_SAR = kwargs_f_lambda_SAR or {}
    kwargs_f_star = kwargs_f_star or {}
    kwargs_f_z = kwargs_f_z or {}

    def fun(l):
        def integrand(logM):
            # get_Psi returns a size-1 array; quad needs a plain float.
            return get_Psi(
                10 ** logM,
                l,
                z,
                kwargs_f_lambda_SAR=kwargs_f_lambda_SAR,
                kwargs_f_star=kwargs_f_star,
                kwargs_f_z=kwargs_f_z,
            ).item()
        return quad(integrand, a=mmin, b=mmax)[0]
    return np.vectorize(fun)(lambda_SAR)


def get_Phi_star(
    m,
    z,
    lmin=L_MIN,
    lmax=L_MAX,
    kwargs_f_lambda_SAR=None,
    kwargs_f_star=None,
    kwargs_f_z=None
):
    """Return Psi integrated over log10(lambda_SAR) in [lmin, lmax] for each m.

    Parameters
    ----------
    m : float or array_like
        Stellar mass(es), passed to get_Psi unchanged as M (i.e. linear mass).
    z : float
        Redshift. Must be a single value: the integrand handed to quad has to
        be a scalar.
    lmin, lmax : float
        Integration limits in log10(lambda_SAR).
    kwargs_f_lambda_SAR, kwargs_f_star, kwargs_f_z : dict or None
        Forwarded to get_Psi (``None`` means no overrides). kwargs_f_star
        used to be accepted but silently dropped.

    Returns
    -------
    ndarray
        One integral per element of m (same shape).
    """
    kwargs_f_lambda_SAR = kwargs_f_lambda_SAR or {}
    kwargs_f_star = kwargs_f_star or {}
    kwargs_f_z = kwargs_f_z or {}

    def fun(mass):
        def integrand(logl):
            # get_Psi returns a size-1 array; quad needs a plain float.
            return get_Psi(
                mass,
                10 ** logl,
                z,
                kwargs_f_lambda_SAR=kwargs_f_lambda_SAR,
                kwargs_f_star=kwargs_f_star,
                kwargs_f_z=kwargs_f_z,
            ).item()
        return quad(integrand, a=lmin, b=lmax)[0]
    return np.vectorize(fun)(m)


from scipy.integrate import dblquad
def get_Phi_lx(lx, z, lmin=L_MIN, mmin=M_MIN, dlx=1e-9, *args, **kwargs):
    """
    Return the X-ray luminosity function by integrating two times over the
    joint distribution function.

    In detail, given log_LX, with log_lambda = log_LX - log_Mstar the total
    number of AGN is given by integral over the bivariate distribution:

        int_lmin^lmax(m) int_mmin^mmax Psi(lambda, Mstar) dl dm

    where dl = dlog_lambda and dm = dlog_Mstar. The lower limits for both l and
    m are defined by Bon+16 lmin=32, mmin=M_MIN. log_lambda = log_LX - log_Mstar
    defines a curve with slope -1 in l-m space. Thus integrating over a small
    wedge defined via log_LX, log_LX + dlog_LX yields the total number of AGN.
    Differentiating the total number of AGN yields the X-ray luminosity
    function.

    In the code, log_Mstar runs from mmin up to lx - lmin, so log_lambda
    stays >= lmin; there is no M_MAX cap. Extra *args/**kwargs are forwarded
    to get_Psi after (M, lambda_SAR, z).

    Results for the default settings are cached on disk under
    ROOT/data/xlf/bongiorno2016/, keyed by (lx, z). The cache key does not
    encode lmin, mmin, dlx, args or kwargs, so any non-default call bypasses
    the cache rather than returning a stale result.
    """
    use_cache = not args and not kwargs and (lmin, mmin, dlx) == (L_MIN, M_MIN, 1e-9)

    def get(lx, z):

        filename = f"{ROOT}/data/xlf/bongiorno2016/{lx}_{z}.npy"
        if use_cache and os.path.exists(filename):
            return np.load(filename)

        def psi(l, m):
            # dblquad integrates l (inner) for each m (outer); get_Psi returns
            # a size-1 array and dblquad needs a plain float.
            return get_Psi(10 ** m, 10 ** l, z, *args, **kwargs).item()

        n = dblquad(psi, mmin, lx - lmin, lambda m: lx - m, lambda m: lx + dlx - m)[0]
        ret = n / dlx

        if use_cache:
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            np.save(filename, ret)

        return ret

    return np.vectorize(get)(lx, z)


def get_duty_cycle(
    m,
    z,
    t,
    smf_agn=None,
    smf_gal=None,
    frac_ctk_agn=0.0,
    *args,
    **kwargs
):
    """Return the AGN duty cycle phi_AGN / phi_gal at log10 stellar mass m, capped at 1.

    Parameters
    ----------
    m : float or ndarray
        log10 stellar mass; the mass functions are evaluated at 10 ** m.
    z, t
        Forwarded to both mass functions' get_stellar_mass_function
        (t presumably selects a galaxy/AGN type -- confirm).
    smf_agn, smf_gal : object or None
        Objects providing get_stellar_mass_function(M, z, t, ...). ``None``
        creates a StellarMassFunctionBongiorno2016 / StellarMassFunctionCOSMOS2020
        instance at call time. The former defaults built these at import
        time, which raised NameError whenever the optional ``smf`` import
        failed.
    frac_ctk_agn : float
        Fraction of Compton-thick AGN in [0, 1); the duty cycle is divided by
        (1 - frac_ctk_agn).
    *args, **kwargs
        Forwarded to smf_agn.get_stellar_mass_function only.

    Raises
    ------
    ValueError
        If frac_ctk_agn lies outside [0, 1).
    """
    frac = np.asarray(frac_ctk_agn)
    if np.any((frac < 0) | (frac >= 1)):
        raise ValueError(f"frac_ctk_agn must be in [0, 1), got {frac_ctk_agn}")

    if smf_agn is None:
        smf_agn = StellarMassFunctionBongiorno2016()
    if smf_gal is None:
        smf_gal = StellarMassFunctionCOSMOS2020()

    phi_agn = smf_agn.get_stellar_mass_function(10 ** m, z, t, *args, **kwargs)
    phi_gal = smf_gal.get_stellar_mass_function(10 ** m, z, t)

    # Estimate the initial duty cycle
    duty_cycle = phi_agn / phi_gal
    print("NOTE: duty_cycle_old is", duty_cycle.sum())

    # NOTE: account for CTK AGN
    print("NOTE: CTK AGN Fraction is", frac_ctk_agn)
    duty_cycle /= (1 - frac_ctk_agn)
    print("NOTE: duty_cycle_new is", duty_cycle.sum())

    # NOTE: ceil duty cycle to 1.00 at max (non-finite values also become 1.0)
    duty_cycle = np.where(duty_cycle <= 1.0, duty_cycle, 1.0)

    return duty_cycle


def get_plambda(mlo, mhi, z, lvec, *args, **kwargs):
    """Return p(log10 lambda_SAR) in units of 1/dex for log10 M_star in [mlo, mhi] at redshift z.

    The distribution is get_Phi_lambda_SAR evaluated on the uniform grid
    ``lvec`` (log10 lambda_SAR), normalized to unit sum over lvec >= L_MIN,
    then divided by the grid step. The result is cached as a text file under
    ROOT/data/plambda/bongiorno2016/.

    *args and **kwargs are accepted for call compatibility but are not used.

    Raises
    ------
    ValueError
        If lvec has fewer than two points or is not uniformly spaced.
    """
    lvec = np.asarray(lvec)
    dl = np.diff(lvec)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if dl.size == 0 or not np.allclose(dl[0], dl):
        raise ValueError("lvec must contain at least two uniformly spaced values")

    filename = f"{ROOT}/data/plambda/bongiorno2016/m_{mlo}_{mhi}_z_{z}_dl_{dl[0]}.txt"
    if os.path.exists(filename):
        return np.loadtxt(filename)

    # Float base so that an integer lvec cannot overflow (10 ** 36 > int64 max).
    phi_SAR = get_Phi_lambda_SAR(10.0 ** lvec, z, mmin=mlo, mmax=mhi)

    # Normalize to unit sum over log_lambda >= L_MIN, then convert to 1/dex.
    p = phi_SAR / np.sum(phi_SAR[lvec >= L_MIN])
    p /= dl[0]

    os.makedirs(os.path.dirname(filename), exist_ok=True)
    np.savetxt(filename, p)
    return p


def get_log_lambda_SAR(mlo, mhi, z, t="all", size=1, *args, **kwargs):

    print(f"Getting log_lambda_SAR for (mlo, mhi, z, t) = ({mlo:6.2f}, {mhi:6.2f}, {z:6.2f}, {t})")

    # Lambda vector used for resampling
    lvec = np.linspace(32, 36, 401)
    dl = np.diff(lvec)[0]

    filename = f"{ROOT}/data/bongiorno2016/plambda/m_{mlo}_{mhi}_z_{z}_t_{t}.txt"
    if not os.path.exists(filename):
        p = get_plambda(mlo, mhi, z, t, *args, **kwargs)
        np.savetxt(filename, get_plambda(mlo, mhi, z, t, *args, **kwargs))
    else:
        p = np.loadtxt(filename)

    # NOTE: normalize p(lambda) to [0, 1]
    c = np.cumsum(p * dl)
    c = c / c[-1]

    lvec = np.append([32.0], lvec)
    c = np.append([0.0], c)

    if c.max() > 1 and not np.isclose(c.max(), 1.0):
        raise ValueError("Probability distribution not properly normalized")

    l_sample = np.interp(
        np.random.rand(size),
        c,
        lvec,
        left=-np.inf,
        right=-np.inf,
    )

    return l_sample


if __name__ == "__main__":
    import sys
    from itertools import product
    idx = int(sys.argv[1])
    for i, (z, m) in enumerate(product(np.arange(0.0, 10.0, 0.10), np.arange(8.0, 14.0, 0.10))):
        if i != idx:
            continue
        get_log_lambda_SAR(m, m + 0.10, z)
