#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-02-06 22:12:06

"""
Provide the AGN logN-logS relation, i.e. the cumulative number of sources
N(>S) above some flux limit S.
"""

import argparse
import glob
import math
import os
import random
import re
import subprocess
import sys
import time

import astropy as ap
import astropy.coordinates as c
import astropy.units as u
import fitsio
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy.interpolate import interp1d
from scipy.stats import poisson
from gdpyc import GasMap
try:
    import xspec
    xspec.Xset.chatter = 0
    xspec.Xset.logChatter = 0
except ImportError:
    pass

# Flux unit used throughout this module: erg cm^-2 s^-1.
UFX = u.erg / (u.cm ** 2 * u.s)


def _get_dN_per_dS_AGN(
    S,
    K_AGN,
    beta1_AGN,
    beta2_AGN,
    f_break_AGN,
    S_ref=1e-14 * UFX
):
    """
    Evaluate a broken power-law differential number count dN/dS.

    Implements Luo+17 Eq. (3)::

        dN/dS = K (S/S_ref)^-beta1                                  S <= f_break
        dN/dS = K (f_break/S_ref)^(beta2 - beta1) (S/S_ref)^-beta2  S >  f_break

    Parameters
    ----------
    S : astropy Quantity
        Flux(es) at which dN/dS is evaluated (flux units, e.g. ``UFX``).
    K_AGN : float
        Normalization in units of 1e14 deg^-2 (erg cm^-2 s^-1)^-1.
    beta1_AGN, beta2_AGN : float
        Power-law slopes below and above the break.
    f_break_AGN : float
        Break flux in units of 1e-15 erg cm^-2 s^-1.
    S_ref : astropy Quantity
        Reference flux of the power law.

    Returns
    -------
    dN/dS in deg^-2 (erg cm^-2 s^-1)^-1, broadcast to the shape of ``S``.
    """
    # Attach units to the tabulated numbers in new locals so the caller's
    # arguments are never rebound or modified in place.
    K = K_AGN * (1e14 * u.deg ** -2 / UFX)
    f_break = f_break_AGN * (1e-15 * UFX)
    return np.where(
        S <= f_break,
        K * (S / S_ref) ** -beta1_AGN,
        K * (f_break / S_ref) ** (beta2_AGN - beta1_AGN) * (S / S_ref) ** -beta2_AGN
    )


def get_dN_per_dS_AGN(
    S,
    Emin=0.5 * u.keV,
    Emax=2.0 * u.keV,
    reference="cdfs7ms_luo2017"
):
    """
    Differential AGN number counts dN/dS with the Luo+17 (7 Ms CDF-S) fit.

    See Luo+17 Eq. (3) and Table 8.

    Parameters
    ----------
    S : astropy Quantity
        Flux(es) in the band [Emin, Emax].
    Emin, Emax : astropy Quantity or float
        Band edges. Quantities may use any energy unit; plain numbers are
        taken as keV. Only the 0.5-2.0 keV and 2.0-7.0 keV bands are
        tabulated.
    reference : str
        Not used by the implementation; only the Luo+17 parameters exist.

    Returns
    -------
    dN/dS evaluated at ``S``.

    Raises
    ------
    ValueError
        If [Emin, Emax] is not a tabulated band.
    """
    # (Emin, Emax) [keV] -> (K, beta1, beta2, f_break) from Luo+17 Table 8.
    table = (
        ((0.5, 2.0), (161.96, 1.52, 2.45, 7.1)),
        ((2.0, 7.0), (453.70, 1.46, 2.72, 8.9)),
    )
    # Convert to keV so equivalent inputs (e.g. 500 eV) match the table.
    emin_kev = u.Quantity(Emin, u.keV).value
    emax_kev = u.Quantity(Emax, u.keV).value
    for (lo, hi), args in table:
        if math.isclose(emin_kev, lo) and math.isclose(emax_kev, hi):
            return _get_dN_per_dS_AGN(S, *args)
    raise ValueError("unknown band")


def get_N_above_S_AGN(S, Emin=0.5 * u.keV, Emax=2.0 * u.keV, log=True):
    """
    Cumulative AGN number counts N(>S) on a flux grid.

    Parameters
    ----------
    S : astropy Quantity
        Flux bin edges, assumed increasing (the counts are accumulated from
        the last element backwards).
    Emin, Emax : astropy Quantity
        Band passed to ``get_dN_per_dS_AGN``. Defaults to 0.5-2.0 keV, the
        same default as ``get_dN_per_dS_AGN``.
    log : bool
        Use geometric bin centres if True, arithmetic centres otherwise.

    Returns
    -------
    S_center : astropy Quantity
        Bin centres, length ``len(S) - 1``.
    N_above : astropy Quantity
        For bin i, the sum of dN/dS(S_center) * dS over bin i and all
        brighter bins, i.e. the counts above the lower edge S[i].
    """
    if log:
        S_center = (S[:-1] * S[1:]) ** 0.5
    else:
        S_center = (S[:-1] + S[1:]) / 2.0

    dN = get_dN_per_dS_AGN(S_center, Emin, Emax) * np.diff(S)
    # Reverse cumulative sum: bin i collects itself plus every brighter bin.
    return S_center, np.cumsum(dN[::-1])[::-1]


def get_sample_N_above_S(
    N_sample,
    S_min=1e-17 * UFX,
    S_max=1e-13 * UFX,
    *args,
    **kwargs
):
    """
    Draw random fluxes distributed according to dN/dS in [S_min, S_max].

    Uses inverse-transform sampling. dN/dS is integrated over 1000
    logarithmic bins, and uniform deviates from ``np.random.rand`` are
    mapped through the piecewise-linear CDF.

    Parameters
    ----------
    N_sample : int
        Number of fluxes to draw.
    S_min, S_max : astropy Quantity
        Flux range to sample.
    *args, **kwargs
        Forwarded to ``get_dN_per_dS_AGN`` (e.g. ``Emin``, ``Emax``).

    Returns
    -------
    Sampled fluxes, ``N_sample`` values between S_min and S_max. The result
    relies on astropy's ``np.interp`` support to keep the unit of the flux
    grid (``UFX``).
    """
    S_edge = np.geomspace(S_min.to_value(UFX), S_max.to_value(UFX), 1001) * UFX
    S_center = (S_edge[1:] * S_edge[:-1]) ** .5
    dN = get_dN_per_dS_AGN(S_center, *args, **kwargs) * np.diff(S_edge)

    # CDF evaluated at every bin edge: 0 at S_min and 1 at S_max. The leading
    # zero keeps the faintest bin sampleable instead of clamping all deviates
    # below the first bin's mass onto S_edge[1].
    cdf = np.concatenate(([0.0], np.cumsum(dN.value)))
    cdf /= cdf[-1]
    return np.interp(
        np.random.rand(N_sample),
        cdf,
        S_edge
    )


def get_galactic_absorption(coord):
    """
    Look up the Galactic hydrogen column density towards ``coord``.

    Queries gdpyc's ``GasMap`` using the LAB survey map.

    Parameters
    ----------
    coord : astropy SkyCoord
        Sky position(s) to query.

    Returns
    -------
    The value returned by ``GasMap.nh``. This is presumably an astropy
    Quantity in cm^-2; confirm against the gdpyc documentation.
    """
    survey = 'LAB'
    column_density = GasMap.nh(coord, nhmap=survey)
    return column_density


# Cache of absorbed-to-intrinsic flux ratios tabulated on an nH grid. It is
# keyed by every input of the XSPEC calculation: both bands, the photon index,
# the nH grid and the reference flux.
RATIO = {}
def get_flux_absorbed(
    S,
    coord=None,
    nH=None,
    E_min1=0.2,
    E_max1=12.0,
    E_min2=0.5,
    E_max2=10.0,
    PhoIndex=1.4,
    nH_min=1e17,
    nH_max=1e23,
    N_nH=100,
    S_ref=1e-15 * UFX
):
    """
    Convert intrinsic fluxes to absorbed fluxes, possibly in another band.

    An XSPEC ``tbabs*cflux(powerlaw)`` model is normalized to ``S_ref`` in
    [E_min1, E_max1] keV. Its flux is then computed in [E_min2, E_max2] keV on
    a logarithmic nH grid, giving the ratio S_abs(band 2) / S_int(band 1).
    The ratio is linearly interpolated at the source column density and
    multiplied by ``S``.

    Parameters
    ----------
    S : astropy Quantity
        Intrinsic flux(es) in [E_min1, E_max1] keV.
    coord : astropy SkyCoord, optional
        If given, nH is looked up with ``get_galactic_absorption``. This
        takes precedence over ``nH``.
    nH : float or astropy Quantity, optional
        Column density. Plain numbers are interpreted as cm^-2. Defaults to 0.
    E_min1, E_max1 : float
        Band [keV] of the intrinsic input flux.
    E_min2, E_max2 : float
        Band [keV] of the returned absorbed flux.
    PhoIndex : float
        Photon index of the power law.
    nH_min, nH_max : float
        Range [cm^-2] of the logarithmic nH grid. Columns below ``nH_min``
        (including the default 0) are evaluated at ``nH_min``, the lowest
        tabulated column. Columns above ``nH_max`` raise ValueError from the
        interpolator.
    N_nH : int
        Number of grid intervals (the grid has N_nH + 1 points).
    S_ref : astropy Quantity
        Reference flux used to normalize the XSPEC model.

    Returns
    -------
    Absorbed flux(es) with the units of ``S``.
    """
    # (1) Tabulate ratio == S_absorbed(band 2) / S_intrinsic(band 1) versus nH.
    nH_grid = np.geomspace(nH_min, nH_max, N_nH + 1)  # [cm^-2]
    S_ref_value = S_ref.to_value(UFX)
    key = (E_min1, E_max1, E_min2, E_max2, PhoIndex, nH_min, nH_max, N_nH, S_ref_value)
    if key not in RATIO:
        model = xspec.Model("tbabs*cflux(powerlaw)")
        model.cflux.Emin = E_min1
        model.cflux.Emax = E_max1
        model.cflux.lg10Flux = np.log10(S_ref_value)
        model.powerlaw.PhoIndex = PhoIndex
        ratio = []
        for _nH in nH_grid:
            # The tbabs nH parameter is set in units of 1e22 cm^-2.
            model.TBabs.nH = float(_nH / 1e22)
            xspec.AllModels.calcFlux(f"{E_min2} {E_max2}")
            ratio.append(model.flux[0] / S_ref_value)
        RATIO[key] = np.array(ratio)
    ratio = RATIO[key]

    # (2) Resolve the column density of each source.
    if coord is not None:
        nH = get_galactic_absorption(coord)
    elif nH is None:
        nH = np.zeros(np.shape(S))
    nH_value = u.Quantity(nH, u.cm ** -2).value
    # nH = 0 lies below the tabulated grid; clip it to the lowest tabulated
    # column rather than failing the interpolation.
    nH_value = np.maximum(nH_value, nH_min)

    Sabs_to_Sint = interp1d(nH_grid, ratio)
    return Sabs_to_Sint(nH_value) * S


def convert_flux_to_countrate_liu16(flux_band, Emin=0.2, Emax=12.0, PhoIndex=1.4, nH=0.0):

    """
    Convert a flux in [Emin, Emax] keV (default 0.2-12.0 keV) to a
    0.2-12.0 keV count rate.

    First, the input flux is converted to 0.5-10 keV with an XSPEC
    ``tbabs*cflux(powerlaw)`` model. Then the linear relation between the
    Liu+16 0.5-10 keV flux and the 0.2-12.0 keV net count rate is used to
    estimate the count rate.

    NOTE: Liu+2016 seems to report observed (absorbed) fluxes. Thus every
    quantity here refers to the observed fluxes/counts.

    Parameters
    ----------
    flux_band : astropy Quantity
        Flux in [Emin, Emax] keV. A single value is expected, because the
        XSPEC model is normalized with one ``lg10Flux``.
    Emin, Emax : float
        Band of ``flux_band`` in keV.
    PhoIndex : float
        Photon index of the power law.
    nH : float
        tbabs column density in XSPEC units of 1e22 cm^-2.

    Returns
    -------
    astropy Quantity
        Net count rate in counts/s.
    """

    log_flux_band = np.log10(flux_band.to_value(UFX))

    def flux_to_count(log_flux_0p5_10p0):
        """parameters from topcat linear fit 20230214 using FluxF vs. (ctTotF70
        - bkgTotF70) / expTotF70 reported correlation is 0.98."""
        return 0.944325 * log_flux_0p5_10p0 + 10.211855

    model = xspec.Model("tbabs*cflux(powerlaw)")
    model.cflux.Emin = Emin
    model.cflux.Emax = Emax
    model.TBabs.nH = nH
    model.cflux.lg10Flux = log_flux_band
    model.powerlaw.PhoIndex = PhoIndex
    xspec.AllModels.calcFlux("0.5 10.0")
    # The fit gives (net counts) / exposure, so the result is a count rate, not
    # a flux. NOTE(review): exposure is assumed to be in seconds.
    return 10 ** flux_to_count(np.log10(model.flux[0])) * (u.count / u.s)


# Per-catalog conventions of the maps used by convert_flux_to_map:
#   band      -- (Emin, Emax) in keV, or None for catalogs without a band here
#   is_counts -- map is in counts, so fluxes are multiplied by ECF[band]
#   is_log    -- map stores log10 values
#   factor    -- extra multiplicative factor, or None
CATALOG_BAND_COUNTS_LOG = dict(
    aegis=((0.5, 7.0), True, False, 0.70 * 1e11 * 1e-3),
    bootes=((0.5, 7.0), False, False, None),
    ccl=((0.5, 2.0), False, True, None),
    cdfs4ms=((0.5, 7.0), True, False, 0.70 * 1e11 * 1e-3),
    cdfs7ms=(None, None, None, None),
    cdfn=((0.5, 7.0), False, False, None),
    xmmxxl=((0.5, 8.0), False, False, 0.70 * 1e11),  # 70% EEF + ECF
    xuds=(None, None, None, None),
)

# Energy conversion factors [counts cm^2 erg^-1] keyed by (Emin, Emax) in keV.
_ECF_UNIT = u.count * u.erg ** -1 * u.cm ** 2
ECF = {
    # From Civano+16, Chandra cycle 14.
    (0.5, 2.0): (1 / 7.40e-11) * _ECF_UNIT,
    (0.5, 7.0): (1 / 1.71e-11) * _ECF_UNIT,
    # Estimated from a linear fit to the Pierre+2016 catalog using 70% EEF.
    (0.5, 8.0): (1 / 1.00e-11) * _ECF_UNIT,
}

def convert_flux_to_map(catalog, ra, dec, flux, exp=None, bkg=None, draw_poisson=False):
    """
    Convert intrinsic 0.5-2.0 keV fluxes to the quantity stored in a
    catalog's map.

    The flux is first converted to the Galactic-absorbed flux in the
    catalog band, using ``get_flux_absorbed`` with nH looked up at
    (ra, dec). It is then scaled according to
    ``CATALOG_BAND_COUNTS_LOG[catalog]``: multiplied by the band ECF if the
    map is in counts, multiplied by ``factor`` if one is given, and
    converted to log10 if the map is logarithmic.

    Parameters
    ----------
    catalog : str
        Key of ``CATALOG_BAND_COUNTS_LOG``.
    ra, dec
        Source positions, passed to ``astropy.coordinates.SkyCoord`` as-is.
        They must carry units or otherwise be parseable by SkyCoord.
    flux : astropy Quantity
        Intrinsic 0.5-2.0 keV flux(es).
    exp : array-like, optional
        Multiplied into the result, e.g. to turn a rate into counts.
    bkg : array-like, optional
        Added to the result after ``exp`` is applied.
    draw_poisson : bool
        If True, return a Poisson realization whose mean is the computed
        values.

    Returns
    -------
    Plain (unitless) values, with the same shape as the computed map values.

    Raises
    ------
    KeyError
        If ``catalog`` is not in ``CATALOG_BAND_COUNTS_LOG``.
    ValueError
        If the catalog has no band information.
    """
    band, is_counts, is_log, factor = CATALOG_BAND_COUNTS_LOG[catalog]
    if band is None:
        raise ValueError(f"catalog {catalog!r} has no band information")

    flux_absorbed = get_flux_absorbed(
        flux,
        coord=c.SkyCoord(ra, dec),
        E_min1=0.5,
        E_max1=2.0,
        E_min2=band[0],
        E_max2=band[1],
    )
    assert flux_absorbed.unit == u.erg / u.cm ** 2 / u.s

    if is_counts:
        flux_absorbed *= ECF[band]
    if factor is not None:
        flux_absorbed *= factor

    flux_absorbed = flux_absorbed.value
    if is_log:
        flux_absorbed = np.log10(flux_absorbed)

    # Count-rate to counts conversion
    if exp is not None:
        flux_absorbed *= exp
    if bkg is not None:
        flux_absorbed += bkg
    if draw_poisson:
        mu = np.asarray(flux_absorbed)
        # Draw with the full shape of mu, so 2-D maps work as well as 1-D lists.
        flux_absorbed = poisson.rvs(mu=mu, size=mu.shape)

    return flux_absorbed


def main(histogram_only=True):
    """
    Draw mock AGN fluxes from the Luo+17 dN/dS and compare them with a
    sensitivity map.

    Fluxes are sampled with ``get_sample_N_above_S`` and placed uniformly in
    a 4 x 4 deg box centred on (RA, Dec) = (150.08, 2.20) deg. The cumulative
    histogram of the fluxes is shown first.

    Parameters
    ----------
    histogram_only : bool
        If True (default), return right after showing the histogram. If
        False, also do the following: drop sources outside the sensitivity
        map or below its flux limit, plot N(>S) of the surviving intrinsic
        and absorbed fluxes, and overplot the renormalized Luo+17 N(>S).
    """
    from astropy.coordinates import SkyCoord
    from astropy.wcs import WCS
    sens, header = fitsio.read("data/original/maps/specz/catalog_data_brightman14_cosmos.sens", header=True)
    wcs = WCS(header)

    N = 10000000
    S = get_sample_N_above_S(N)
    coords = SkyCoord(
        4.0 * (np.random.rand(S.size) - 0.5) + 150.08,
        4.0 * (np.random.rand(S.size) - 0.5) + 2.20,
        unit="deg"
    )

    # Cumulative (N(>S)) histogram of the drawn fluxes between 1e-18 and 1e-13.
    n, _, _ = plt.hist(
        S.value,
        bins=np.logspace(-18, -13),
        histtype="step",
        cumulative=-1,
        label="fluxes drawn from N(>S)"
    )
    plt.loglog()
    plt.show()
    if histogram_only:
        return

    col, row = np.round(wcs.world_to_pixel(coords)).astype(np.int64)

    # Discard sources that fall outside the sensitivity map.
    keep = (0 <= row) & (row < sens.shape[0]) & (0 <= col) & (col < sens.shape[1])
    S, coords, row, col = S[keep], coords[keep], row[keep], col[keep]

    # Discard non-positive map pixels and sources below the local flux limit.
    # NOTE(review): assumes the map is in the same flux units as UFX.
    Sabs = get_flux_absorbed(S, coords)
    limit = sens[row, col]
    keep = (limit > 0.0) & (Sabs.value > limit)
    S, Sabs, coords, row, col = S[keep], Sabs[keep], coords[keep], row[keep], col[keep]

    S_grid = np.logspace(-18, -13) * UFX
    for i, _S in enumerate(S_grid):
        # Label only the first point of each series so the legend has one entry each.
        first = i == 0
        plt.plot(_S, (S > _S).sum(), 'k.', label="intrinsic flux after sensmap" if first else None)
        plt.plot(_S, (Sabs > _S).sum(), 'r.', label="(extragal) absorbed flux after sensmap" if first else None)

    x, y = get_N_above_S_AGN(S_grid, 0.5 * u.keV, 2.0 * u.keV)
    plt.plot(x, y * n.mean() / y.mean(), label="7 Ms CDF-S N(>S) [deg-2] (Luo+17, renormalized to histogram)")
    plt.xlabel("flux 0.5-2.0 keV [erg/cm2/s]")
    plt.ylabel("counts")
    plt.legend()
    plt.show()


if __name__ == "__main__":
    # Script entry point: run main() only when executed directly, not on import.
    main()
