#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-06-26 10:51:39

"""
Implement the N_H function of Ueda+2014 (Eqs. 3-6) and compare it with Merloni+2014.
"""


import argparse
import glob
import math
import os
import random
import re
import subprocess
import sys
import time

import astropy as ap
import astropy.coordinates as c
import astropy.units as u
import fitsio
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp


def get_Psi(log_L_X, z, beta=0.24, Psi_min=0.20, Psi_max=0.84, *args, **kwargs):
    """Psi(L_X, z) of Ueda+2014, Eq. 3.

    Psi = min(Psi_max, max(Psi_4375(z) - beta * (log_L_X - 43.75), Psi_min))

    Parameters
    ----------
    log_L_X : float or array_like
        log10 of the X-ray luminosity [erg/s].
    z : float or array_like
        Redshift; broadcast against ``log_L_X``.
    beta : float, optional
        Slope of Psi with respect to log_L_X.
    Psi_min, Psi_max : float, optional
        Lower and upper clipping bounds of Psi.
    *args, **kwargs
        Forwarded to get_Psi_4375 (e.g. ``Psi_4375_0``, ``a1``).

    Returns
    -------
    numpy.ndarray
        Psi with at least one dimension, shaped as the broadcast of the inputs.
    """
    log_L_X = np.atleast_1d(log_L_X)
    z = np.atleast_1d(z)

    Psi_4375 = get_Psi_4375(z, *args, **kwargs)

    # Eq. 3 is a linear relation in log_L_X clipped to [Psi_min, Psi_max].
    return np.clip(Psi_4375 - beta * (log_L_X - 43.75), Psi_min, Psi_max)


def get_Psi_4375(z, Psi_4375_0=0.43, a1=0.48):
    """Psi at log L_X = 43.75 as a function of redshift (Ueda+2014, Eq. 4).

    Grows as Psi_4375_0 * (1 + z) ** a1 below z = 2 and is frozen at its
    z = 2 value from there on.

    Parameters
    ----------
    z : float or numpy.ndarray
        Redshift.
    Psi_4375_0 : float, optional
        Normalisation at z = 0.
    a1 : float, optional
        Power-law index of the (1 + z) evolution.

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``z`` (0-d for scalar input).
    """
    is_evolving = z < 2.0
    evolving_value = Psi_4375_0 * (1 + z) ** a1
    frozen_value = Psi_4375_0 * (1 + 2) ** a1
    return np.where(is_evolving, evolving_value, frozen_value)


def get_f(log_L_X, z, log_N_H, epsilon=1.7, f_CTK=1.0, *args, **kwargs):
    """N_H function f(L_X, z; N_H) of Ueda+2014, Eqs. 5 & 6.

    Piecewise-linear function of Psi (from get_Psi), evaluated in the 1-dex
    log_N_H bin that contains ``log_N_H``.

    Parameters
    ----------
    log_L_X : float or array_like
        log10 of the X-ray luminosity [erg/s].
    z : float or array_like
        Redshift.
    log_N_H : float or array_like
        log10 of the column density (presumably cm^-2 -- confirm). Bins used
        here: [0, 21), [21, 22), [22, 23), [23, 24), [24, 99); values outside
        [0, 99) give 0.
    epsilon : float, optional
        The epsilon parameter of Eqs. 5 & 6.
    f_CTK : float, optional
        Normalisation of the [24, 99) bin, where f = f_CTK / 2 * Psi.
    *args, **kwargs
        Forwarded to get_Psi (e.g. ``beta``, ``Psi_min``, ``Psi_max``).

    Returns
    -------
    numpy.ndarray
        f on the common broadcast shape of the three inputs, with length-1
        axes removed by ``np.squeeze`` (0-d array for all-scalar input).
    """
    # Broadcast every input to one common shape, so masks built from
    # log_N_H can index arrays built from Psi (e.g. scalar L_X with a vector
    # of N_H values).
    log_L_X, z, log_N_H = np.broadcast_arrays(
        np.atleast_2d(log_L_X), np.atleast_2d(z), np.atleast_2d(log_N_H)
    )

    Psi = get_Psi(log_L_X, z, *args, **kwargs)

    Psi_break = (1 + epsilon) / (3 + epsilon)
    # Two explicit comparisons (rather than ~) so a NaN Psi matches neither
    # case and is left at 0.
    select_Psi_1 = Psi < Psi_break
    select_Psi_2 = Psi >= Psi_break

    select_NH_00_21 = (0 <= log_N_H) & (log_N_H < 21)
    select_NH_21_22 = (21 <= log_N_H) & (log_N_H < 22)
    select_NH_22_23 = (22 <= log_N_H) & (log_N_H < 23)
    select_NH_23_24 = (23 <= log_N_H) & (log_N_H < 24)
    select_NH_24_99 = (24 <= log_N_H) & (log_N_H < 99)

    # NOTE: the following contains all the conditions and their piecewise
    # definitions from Ueda+2014
    select_value = [
        # Case Psi < (1 + epsilon) / (3 + epsilon)
        (select_Psi_1 & select_NH_00_21, 1 - (2 + epsilon) / (1 + epsilon) * Psi),
        (select_Psi_1 & select_NH_21_22, 1 / (1 + epsilon) * Psi),
        (select_Psi_1 & select_NH_22_23, 1 / (1 + epsilon) * Psi),
        (select_Psi_1 & select_NH_23_24, epsilon / (1 + epsilon) * Psi),
        (select_Psi_1 & select_NH_24_99, f_CTK / 2 * Psi),
        # Case Psi >= (1 + epsilon) / (3 + epsilon)
        (select_Psi_2 & select_NH_00_21, 2 / 3. - (3 + 2 * epsilon) / (3 + 3 * epsilon) * Psi),
        (select_Psi_2 & select_NH_21_22, 1 / 3. - epsilon / (3 + 3 * epsilon) * Psi),
        (select_Psi_2 & select_NH_22_23, 1 / (1 + epsilon) * Psi),
        (select_Psi_2 & select_NH_23_24, epsilon / (1 + epsilon) * Psi),
        (select_Psi_2 & select_NH_24_99, f_CTK / 2 * Psi),
    ]

    ret = np.zeros_like(Psi)
    for select, value in select_value:
        ret[select] = value[select]

    return np.squeeze(ret)


def plot_merloni2014():
    """Compare the Ueda+14 obscured fraction with Merloni+14 (Fig. 7) data.

    For each redshift bin, shade 100 * [f(22 <= log N_H < 23) +
    f(23 <= log N_H < 24)] between its values at the bin's lower and upper
    redshift. Then overplot the Merloni+14 points read from
    ``data/merloni2014/fig7/{zmin:.1f}_z_{zmax:.1f}.dat``, whose four columns
    are unpacked as x, ylo, ymi, yhi. The figure is saved to
    ``fig/merloni2014.pdf`` and closed.
    """
    lxs = np.linspace(42.0, 46.0, 41)
    redshift_bins = [
        (0.30, 0.80, 'o', 'purple'),
        (0.80, 1.10, '^', 'blue'),
        (1.10, 1.50, 'v', 'cyan'),
        (1.50, 2.10, 's', 'green'),
        (2.10, 3.50, '*', 'red'),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(1.8 * 6.4, 1.8 * 4.8))
    axes = axes.flatten()

    # zip stops after the 5 bins; the sixth (unused) panel is removed below.
    for ax, (zmin, zmax, marker, color) in zip(axes, redshift_bins):
        f_obs_1 = 100 * (get_f(lxs, zmin, 22) + get_f(lxs, zmin, 23))
        f_obs_2 = 100 * (get_f(lxs, zmax, 22) + get_f(lxs, zmax, 23))
        ax.fill_between(lxs, f_obs_1, f_obs_2, color=color, alpha=0.2, label="Ueda+ 14")

        filename = f"data/merloni2014/fig7/{zmin:.1f}_z_{zmax:.1f}.dat"
        x, ylo, ymi, yhi = np.loadtxt(filename).T
        ax.errorbar(
            x,
            ymi,
            yerr=(ymi - ylo, yhi - ymi),
            linestyle="none",
            marker=marker,
            color=color,
            label="Merloni+ 14",
        )
        ax.set_title(f"{zmin} < z < {zmax}")
        ax.legend(loc="upper right")

        ax.set_xlim(42, 46)
        ax.set_ylim(0, 100)
        ax.set_xlabel(r"$\log L_\mathrm{X}$ [erg/s]")
        ax.set_ylabel(r"Obscured AGN fraction [%]")

    axes[-1].remove()
    fig.savefig("fig/merloni2014.pdf")
    # Release the figure; pyplot otherwise keeps it alive for the process.
    plt.close(fig)


def plot_frac_ctk():
    """Map the Ueda+14 f(L_X, z; log_N_H = 24) over a (z, log L_X) grid.

    Evaluates get_f on 34 redshifts in [0.2, 5.5] and 41 luminosities in
    [42, 46]. Saves the image to
    ``fig/redshift_luminosity_frac_ctk_ueda2014.pdf`` and closes the figure.
    """
    NH = 24
    z_grid = np.linspace(0.20, 5.50, 34)
    lx_grid = np.linspace(42, 46, 41)
    # Default 'xy' indexing gives shape (41, 34): rows follow log L_X,
    # columns follow z.
    Z, LX = np.meshgrid(z_grid, lx_grid)
    f_nh = get_f(LX, Z, NH)

    # Put the pixel edges half a grid step outside the first/last sample, so
    # each pixel is centred on the point at which it was evaluated.
    dz = z_grid[1] - z_grid[0]
    dlx = lx_grid[1] - lx_grid[0]
    extent = (
        z_grid[0] - dz / 2, z_grid[-1] + dz / 2,
        lx_grid[0] - dlx / 2, lx_grid[-1] + dlx / 2,
    )

    fig = plt.figure()
    # origin="lower": row 0 (log L_X = 42) must sit at the bottom to match
    # the increasing y-range of `extent`. The default origin="upper" would
    # draw the map upside down relative to its axis labels.
    plt.imshow(f_nh, extent=extent, origin="lower")
    plt.xlabel(r"$z$")
    plt.ylabel(r"$\log L_\mathrm{X}$ [erg/s]")
    plt.colorbar()
    plt.title(r"Ueda+14 $\mathrm{frac}_\mathrm{CTK}(L_\mathrm{X}, z)$")
    fig.savefig("fig/redshift_luminosity_frac_ctk_ueda2014.pdf")
    plt.close(fig)


def main():
    """Produce every figure of this script in turn; return exit status 0."""
    figure_makers = (plot_merloni2014, plot_frac_ctk)
    for make_figure in figure_makers:
        make_figure()
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
