#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-04-20 15:57:25

"""
X-ray/optical luminosity function
"""


import argparse
import glob
import math
import os
import random
import re
import subprocess
import sys
import time

import astropy as ap
import astropy.coordinates as c
import astropy.units as u
import fitsio
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp

from astropy.io.ascii import read


# Shen+ 2020 bolometric-correction parameters per band, stored as
# (c_1, k_1, c_2, k_2) -- the same argument order taken by
# get_luminosity_bolometric and get_bc.
BAND_PARAMETERS_BC = dict(
    [
        ("B band", (3.759, -0.361, 9.830, -0.0063)),
        ("UV", (1.862, -0.361, 4.870, -0.0063)),
        ("Soft X-ray", (5.712, -0.026, 17.67, 0.2780)),
        ("Hard X-ray", (4.073, -0.026, 12.60, 0.2780)),
        ("Mid-IR", (4.361, -0.361, 11.40, -0.0063)),
    ]
)

def get_luminosity_bolometric(L_band, c_1, k_1, c_2, k_2, lmin=38, lmax=52, nl=14001):
    """Invert Eq. 5 of Shen+ 2020 to get L_bol from a band luminosity.

    Eq. 5 gives the bolometric correction

        L_bol / L_band = c_1 * x ** k_1 + c_2 * x ** k_2,  x = L_bol / 1e10

    with L_bol in solar luminosities. The forward relation L_band(L_bol) is
    tabulated on a log-spaced L_bol grid and linearly interpolated to recover
    L_bol at the requested L_band.

    L_band      band luminosity in L_sun (scalar or array)
    c_1, k_1,
    c_2, k_2    bolometric-correction parameters (cf. BAND_PARAMETERS_BC)
    lmin, lmax  log10 of the L_bol grid limits, in erg/s
    nl          number of grid points

    Returns L_bol in L_sun.

    NOTE: np.interp clamps L_band outside the tabulated range to the grid
    edges instead of extrapolating. The inversion also assumes L_band grows
    with L_bol on the grid, which holds for positive c's and k_1, k_2 < 1.
    """
    # L_bol grid: logspace in erg/s, converted to solar luminosities
    grid_bol = (np.logspace(lmin, lmax, nl) * u.erg / u.s).to(u.L_sun).value

    # Forward relation: L_band = L_bol / BC(L_bol)
    scaled = grid_bol / 1e10
    grid_band = grid_bol / (c_1 * scaled ** k_1 + c_2 * scaled ** k_2)

    return np.interp(L_band, grid_band, grid_bol)


def get_bc(L_bol, c_1, k_1, c_2, k_2):
    """Shen+ 2020 bolometric correction L_bol / L_band (Eq. 5).

        BC = c_1 * (L_bol / 1e10) ** k_1 + c_2 * (L_bol / 1e10) ** k_2

    This is the same relation that get_luminosity_bolometric inverts.

    L_bol       bolometric luminosity in L_sun (scalar, list or array)
    c_1, k_1,
    c_2, k_2    bolometric-correction parameters (cf. BAND_PARAMETERS_BC)

    Returns the bolometric correction (dimensionless), with the shape of L_bol.
    """
    # The two power laws are ADDED. The previous version summed their
    # logarithms, which multiplied the terms instead.
    x = np.asarray(L_bol, dtype=float) / 1e10
    return c_1 * x ** k_1 + c_2 * x ** k_2


def get_qlf(L, z, a0=0.8569, a1=-0.2614, a2=0.02, b0=2.5375, b1=-1.0425,
            b2=1.1201, c0=13.0088, c1=-0.5759, c2=0.4554, d0=-3.5426,
            d1=-0.3936, z_ref=2.0):
    """Shen+ 2020 bolometric quasar luminosity function (double power law).

        phi(L, z) = phi_* / ((L / L_*) ** gamma_1 + (L / L_*) ** gamma_2)

    The redshift evolution, with x = 1 + z and r = x / (1 + z_ref), is:
        gamma_1    = a0 T0(x) + a1 T1(x) + a2 T2(x)   (Chebyshev polynomials)
        gamma_2    = 2 b0 / (r ** b1 + r ** b2)
        log L_*    = 2 c0 / (r ** c1 + r ** c2)
        log phi_*  = d0 T0(x) + d1 T1(x)

    L      bolometric luminosity, in the same units as 10 ** log L_*
           (callers in this file pass L_sun); scalar or array
    z      redshift
    a*..d* fit parameters (defaults as given in the signature)
    z_ref  reference redshift of the double-power-law evolution terms
    """
    x = 1 + z
    ratio = x / (1 + z_ref)

    # Chebyshev polynomials of the first kind evaluated at x: T0, T1, T2
    cheb = (1, x, 2 * x ** 2 - 1)

    gamma1 = a0 * cheb[0] + a1 * cheb[1] + a2 * cheb[2]
    gamma2 = 2 * b0 / (ratio ** b1 + ratio ** b2)
    log_L_star = 2 * c0 / (ratio ** c1 + ratio ** c2)
    log_phi_star = d0 * cheb[0] + d1 * cheb[1]

    scaled = L / 10 ** log_L_star
    return 10 ** log_phi_star / (scaled ** gamma1 + scaled ** gamma2)


def get_2PL(L_X, A_44_z, L_X_star, gamma_1, gamma_2):
    """Two-power-law X-ray luminosity function, eq. (22) of Miyaji+ 15.

        A_44_z * [(1e44 / L_*) ** g1 + (1e44 / L_*) ** g2]
        ---------------------------------------------------
               (L_X / L_*) ** g1 + (L_X / L_*) ** g2

    The normalization is chosen so the function equals A_44_z at L_X = 1e44.

    L_X       X-ray luminosity (linear, same units as L_X_star)
    A_44_z    value of the XLF at L_X = 1e44
    L_X_star  break luminosity (linear)
    gamma_1,
    gamma_2   faint- and bright-end slopes

    Returns a numpy masked array; np.ma.true_divide masks entries whose
    division it deems unsafe (e.g. a zero denominator).
    """
    def _shape(lum):
        # Double power law in lum / L_X_star (no normalization)
        scaled = lum / L_X_star
        return scaled ** gamma_1 + scaled ** gamma_2

    return np.ma.true_divide(A_44_z * _shape(1e44), _shape(L_X))


# Miyaji+ 2015 two-power-law (2PL) XLF fit parameters, one row per redshift
# bin, read once at import time. Columns used in this file: "z_min", "z_c",
# "z_max", "A_44^z=z_c", "log_L_*^z=z_c", "gamma_1", "gamma_2".
# NOTE(review): relative path -- resolved against the current working
# directory, so importing this module from elsewhere fails.
miyaji2015 = read("data/miyaji2015/2PL_miyaji2015_table3.dat")
def get_xlf(L_X, z):
    """Miyaji+ 15 two-power-law X-ray luminosity function at redshift z.

    L_X  log10 of the X-ray luminosity (raised to 10 ** L_X before being
         passed to get_2PL); scalar or array
    z    redshift used to pick the miyaji2015 row with z_min <= z < z_max.
         Presumably a scalar: an array z is compared element-wise against
         the table columns, which only broadcasts for matching lengths.

    Returns get_2PL evaluated with the selected row's parameters (a masked
    array).

    NOTE: no extrapolation -- z is first clipped to [0.015, 5.8 - 1e-6], so
    redshifts outside that range reuse the parameters of the bin containing
    the clip limit.
    """
    z_clipped = np.clip(z, 0.015, 5.8 - 1e-6)
    in_bin = (miyaji2015["z_min"] <= z_clipped) & (z_clipped < miyaji2015["z_max"])
    row = miyaji2015[in_bin]

    A_44 = row["A_44^z=z_c"]
    L_star = 10 ** row["log_L_*^z=z_c"]
    return get_2PL(10 ** L_X, A_44, L_star, row["gamma_1"], row["gamma_2"])


from astropy.cosmology import FlatLambdaCDM
# Flat LambdaCDM cosmology shared by get_volume and the luminosity-distance
# calculation in main().
cosmo = FlatLambdaCDM(H0=70, Om0=0.30, Tcmb0=2.73)
def get_volume(area, z_min, z_max):
    """Return the cosmological volume in Mpc ** 3

    Comoving volume between z_min and z_max (under the module-level `cosmo`)
    covered by a patch of `area` square degrees.

    area   sky area in deg2 (scalar or array); never modified in place
    z_min  lower redshift of the shell
    z_max  upper redshift of the shell

    Returns the volume as a plain float (or ndarray) in Mpc ** 3.

    Raises astropy.units.UnitConversionError if the result is not a volume.
    """
    shell = cosmo.comoving_volume(z_max) - cosmo.comoving_volume(z_min)

    # Fraction of the full sky (4 pi sr) covered by `area` deg2; .si reduces
    # deg2 / sr to a plain dimensionless number. This is a new object rather
    # than `area *= ...`, which would write into a caller's ndarray (and fail
    # for integer arrays).
    sky_fraction = area * (u.deg ** 2 / (4 * np.pi * u.sr)).si

    # to_value both checks the unit and strips it. Unlike the previous
    # `assert`, the check is not skipped under `python -O`.
    return (shell * sky_fraction).to_value(u.Mpc ** 3)


def get_lx_redshift(area, z_min, z_max, dlogLX=0.01):
    """
    Draw X-ray luminosities from the Miyaji+15 XLF in one redshift bin.

    area    area to be simulated in deg2
    z_min   minimum redshift
    z_max   maximum redshift
    dlogLX  binning parameter for logLX, should match the binsize of whichever histogram is used for validation

    The XLF is evaluated on a logLX grid covering [42, 46) in steps of
    dlogLX. It is integrated and multiplied by the comoving volume of the
    shell to get the expected number of AGN, rounded to an integer. That many
    luminosities are then drawn by inverse-transform sampling of the binned
    distribution (np.random global state).

    Assumes [z_min, z_max) is a single Miyaji+15 redshift bin, as passed by
    get_sample.

    Returns a 1-D array of linear X-ray luminosities (10 ** logLX). The array
    is empty when the expected number rounds to zero.
    """
    # Left bin edges of the logLX grid
    x = np.arange(42.0, 46.0, step=dlogLX)

    # get_xlf(L_X, z) takes log10(L_X) and a single redshift. The bin
    # midpoint lies inside [z_min, z_max), so it selects this bin's fit.
    # Entries masked by get_2PL's safe division contribute nothing.
    y = np.ma.filled(get_xlf(x, 0.5 * (z_min + z_max)), 0.0)
    total = np.sum(y)

    # Integrate the luminosity function to find the number of AGN
    volume = get_volume(area, z_min, z_max)
    N_agn = int(np.round(total * dlogLX * volume))
    if N_agn <= 0:
        return np.empty(0)

    # Inverse-transform sampling: the CDF is 0 at the lower edge of the first
    # bin and reaches 1 at the upper edge of the last bin. Pairing the
    # cumulative sums with upper edges avoids a one-bin low bias.
    cdf = np.concatenate(([0.0], np.cumsum(y))) / total
    edges = np.append(x, x[-1] + dlogLX)
    return 10 ** np.interp(np.random.rand(N_agn), cdf, edges)


#F = None
#def get_xlf(z, l, *args, **kwargs):
#    from scipy.interpolate import SmoothBivariateSpline
#    zz, ll, pp, dd = np.loadtxt("data/miyaji2015/xlf_miyaji2015_table4.dat", usecols=(2, 5, -2, -1)).T
#
#    pp_l = np.log10(pp - dd)
#    pp_u = np.log10(pp + dd)
#    dd = 0.5 * (pp_u - pp_l)
#    pp = np.log10(pp)
#
#    F = SmoothBivariateSpline(zz, ll, pp, w=dd, *args, **kwargs)
#
#    return 10 ** F(z, l, grid=False)


def get_sample(area, z_min, z_max, *args, **kwargs):
    """
    Draw a mock AGN sample (redshift, L_X) from the Miyaji+15 XLF.

    The sample is built one Miyaji+15 redshift bin at a time. A bin is used
    when its central redshift z_c satisfies z_min <= z_c < z_max, and every
    source drawn in that bin is assigned redshift z_c.

    area             simulated sky area in deg2
    z_min, z_max     redshift range used to select bins (by their z_c)
    *args, **kwargs  forwarded to get_lx_redshift (e.g. dlogLX)

    Returns (redshift, lx): two 1-D float arrays of equal length. lx holds
    the luminosities returned by get_lx_redshift.
    """
    redshift_chunks = []
    lx_chunks = []

    # Pairs the i-th smallest unique z_min, z_c and z_max. This presumably
    # relies on the table's bins being non-overlapping, so the three sorted
    # columns line up -- confirm if the table changes.
    for z_lo, z_mi, z_hi in zip(
        np.sort(np.unique(miyaji2015["z_min"])),
        np.sort(np.unique(miyaji2015["z_c"])),
        np.sort(np.unique(miyaji2015["z_max"])),
    ):
        if not (z_min <= z_mi < z_max):
            continue

        lx_current = get_lx_redshift(area, z_lo, z_hi, *args, **kwargs)
        lx_chunks.append(lx_current)
        redshift_chunks.append(np.full_like(lx_current, z_mi))

    # Concatenate once, instead of np.append inside the loop, which re-copies
    # the growing arrays on every iteration.
    if not lx_chunks:
        return np.empty(0), np.empty(0)
    return np.concatenate(redshift_chunks), np.concatenate(lx_chunks)



def main():
    """Exploratory comparison plots for the mock AGN sample.

    Figure 1: cumulative counts N(> flux) of the simulated sample, overplotted
    with digitized literature curves (data/luo2017 and the "ueda" files) and
    a mock catalog read from test_mock/catalog.fits.

    Figure 2: one panel per Miyaji+ 15 redshift bin. Each compares the Shen+
    2020 bolometric QLF (shaded between its values at z_min and z_max) with
    the Miyaji+ 15 XLF plotted against bolometric luminosity.

    NOTE(review): as written this function cannot run to completion. See the
    NOTE(review) comments below: wrong call signatures, an undefined name and
    machine-specific absolute paths.
    """

    dlogLX = 0.001
    area = 100
    logLX = np.arange(42, 46 + 1e-6, dlogLX)

    # Single-element tuple: only zlim = 9.99 is used.
    for zlim in 9.99,:
        for i, dist in enumerate([None, 'uniform', 'gaussian']):
            # NOTE(review): get_sample(area, z_min, z_max, ...) requires z_min
            # and z_max. It also forwards extra kwargs to get_lx_redshift,
            # which accepts neither `dist` nor `zlim`. This call raises
            # TypeError as written.
            z, lx = get_sample(area, dlogLX=dlogLX, dist=dist, zlim=zlim)
            # Luminosity distance in cm
            d = cosmo.luminosity_distance(z).cgs.value
            Gamma = 1.9
            # log10 flux: inverse-square law plus the power-law K-correction
            # (1 + z) ** (2 - Gamma)
            fx = np.log10(lx) - np.log10(4 * np.pi * d ** 2) + (2 - Gamma) * np.log10(1 + z)
            # NOTE(review): convert_flux is neither defined nor imported in
            # this file (NameError). It presumably converts the 2-10 keV flux
            # to 2-7 keV. It uses Gamma=1.7, while the K-correction above uses
            # Gamma = 1.9 -- confirm this is intended.
            fx = np.log10(convert_flux(10 ** fx, 2, 10, 2, 7, Gamma=1.7))
            # Cumulative counts above each log-flux threshold. The / 100
            # presumably normalizes by the simulated area (area = 100 deg2)
            # -- confirm.
            for fx_lo in np.linspace(-17, -12, 51):
                plt.plot(fx_lo, np.sum(fx > fx_lo) / 100, '.', color="C%d" % i)

    print("luo")
    # x is a linear flux here (it is plotted as log10(x)).
    # NOTE(review): np.loadtxt splits on whitespace by default. If this CSV
    # is comma-separated it needs delimiter=','.
    x, y = np.loadtxt("data/luo2017/Default Dataset.csv").T
    plt.plot(np.log10(x), y)

    print("mock")
    fits = fitsio.read("test_mock/catalog.fits")
    # The / 10 presumably normalizes by the mock catalog's area in deg2 --
    # TODO confirm.
    for fx_lo in np.linspace(-17, -12, 51):
        plt.plot(fx_lo, np.sum(fits["log_FX_2_7"] > fx_lo) / 10, 'kx')

    print("ueda")
    # NOTE(review): machine-specific absolute path; not portable.
    x, y = np.loadtxt("/Users/akke.viitanen/Downloads/Default Dataset-2.csv").T
    # x is log10 flux here. Dividing by (S / 1e-14) ** 1.5 presumably undoes
    # a Euclidean-normalized N(>S) -- confirm against the source figure.
    f = (10 ** x / 1e-14) ** 1.5
    print(y / f)
    plt.plot(x, y / f)

    print("ueda CTN")
    # Sum of two digitized curves interpolated onto a common log-flux grid.
    flux = np.linspace(-16.8, -12, 51)
    x1, y1 = np.loadtxt("/Users/akke.viitanen/Downloads/Default Dataset-4.csv").T
    x2, y2 = np.loadtxt("/Users/akke.viitanen/Downloads/Default Dataset-5.csv").T
    y = np.interp(flux, x1, y1) + np.interp(flux, x2, y2)
    f = (10 ** flux / 1e-14) ** 1.5
    plt.plot(flux, y / f, linestyle='dashed')

    plt.show()

    #fig, axes = plt.subplots(4, 3, figsize=(4 * 6.4, 3 * 4.8))
    #for ax, z_min, z_max in zip(
    #        axes.flatten(),
    #        np.sort(np.unique(miyaji2015["z_min"])),
    #        np.sort(np.unique(miyaji2015["z_max"])),
    #):

    #    # Plot the sample
    #    ax.hist(lx[(z_min <= z) * (z < z_max)], bins=10 ** logLX, histtype="step")

    #    # Plot expected number from Miyaji+ 15
    #    ax.plot(
    #        10 ** logLX,
    #        get_xlf(10 ** logLX, z_min, z_max) * get_volume(area, z_min, z_max) * dlogLX
    #    )
    #    ax.loglog()
    #plt.show()

    # NOTE(review): this import shadows the module-level
    # get_luminosity_bolometric for the rest of main().
    from catalog_galaxy_agn import get_luminosity_bolometric
    # z and lx used below are whatever the last get_sample call in the loop
    # above returned. (4.073, -0.026, 12.60, 0.278) is the "Hard X-ray" entry
    # of BAND_PARAMETERS_BC. Luminosities are converted erg/s -> L_sun and back.
    # lbol is only consumed by the commented-out histogram below.
    lbol = get_luminosity_bolometric(
        lx * (u.erg/u.s).to(u.L_sun),
        4.073, -0.026, 12.60, 0.278
    ) * (u.L_sun).to(u.erg/u.s)
    bins = np.logspace(42, 48, 61)

    # One panel per Miyaji+ 15 redshift bin
    fig, axes = plt.subplots(4, 3, figsize=(4 * 6.4, 3 * 4.8))
    for ax, z_min, z_max in zip(
            axes.flatten(),
            np.sort(np.unique(miyaji2015["z_min"])),
            np.sort(np.unique(miyaji2015["z_max"])),
    ):

        # volume and select are only used by the commented-out histogram.
        volume = get_volume(area, z_min, z_max)
        select = (z_min <= z) * (z < z_max)

        ## Plot the sample
        #ax.hist(
        #    lbol[select],
        #    bins=bins,
        #    weights=np.full_like(lbol[select], 1 / (volume * 0.10)),
        #    histtype="step"
        #)

        # Plot QLF: Shen+ 2020 QLF evaluated at the bin's edge redshifts
        # (get_qlf takes L in L_sun), shaded between the two curves.
        qlf1 = get_qlf(bins * (u.erg/u.s).to(u.L_sun), z_min)
        qlf2 = get_qlf(bins * (u.erg/u.s).to(u.L_sun), z_max)
        mi = np.min([qlf1, qlf2], axis=0)
        ma = np.max([qlf1, qlf2], axis=0)
        ax.fill_between(bins, mi, ma, alpha=0.2, color='red')
        ax.loglog()

        # Plot Miyaji+15 points
        # Loop-invariant: recomputes the same lbol as before the loop.
        lbol = get_luminosity_bolometric(
            lx * (u.erg/u.s).to(u.L_sun),
            4.073, -0.026, 12.60, 0.278
        ) * (u.L_sun).to(u.erg/u.s)

        # NOTE(review): get_xlf(L_X, z) takes two arguments and expects
        # log10(L_X). This three-argument call with linear luminosities
        # raises TypeError. Something like
        # get_xlf(logLX, 0.5 * (z_min + z_max)) looks intended -- confirm.
        ax.plot(
            get_luminosity_bolometric(
                10 ** logLX * (u.erg/u.s).to(u.L_sun),
                4.073, -0.026, 12.60, 0.278
            ) * (u.L_sun).to(u.erg/u.s),
            get_xlf(10 ** logLX, z_min, z_max),
        )
        ax.loglog()

    plt.show()




# Script entry point: run main() only when this file is executed directly.
if __name__ == "__main__":
    main()
