#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-04-11 17:02:54

"""
Aird+ implementation for p(lambda)
"""


import argparse
import glob
from itertools import product
import math
import os
import random
import re
import subprocess
import sys
import time

import astropy as ap
import astropy.units as u
import fitsio
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy.integrate import quad
from scipy.interpolate import interp1d

from util import ROOT


# Bongiorno+2016 double power-law parameters, keyed by redshift interval
# (zlo, zhi). Each value is (phi_star, gamma1, gamma2, log_lambda_star) with
# phi_star stored in LINEAR units (10 ** log value).
# NOTE(review): the tuple order appears to match the positional arguments of
# get_double_powerlaw below -- confirm against the caller.
REDSHIFT_DOUBLE_POWER_LAW_BONGIORNO2016 = {
    (zlo, zhi): (10 ** log_phi, gamma1, gamma2, log_lambda)
    for zlo, zhi, log_phi, gamma1, gamma2, log_lambda in (
        (0.30, 0.80, -6.04, -1.35, -3.64, 34.33),
        (0.80, 1.50, -5.22, -1.02, -3.61, 34.32),
        (1.50, 2.50, -4.85, -0.54, -3.58, 34.30),
    )
}

# Redshift intervals (zlo, zhi) associated with Aird+2018, as a list of
# consecutive edge pairs.
REDSHIFT_AIRD18 = [
    (lo, hi)
    for lo, hi in zip(
        (0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0),
        (0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 4.0),
    )
]


def get_double_powerlaw(lambda_sar, log_phi_star, gamma1, gamma2, log_lambda_star):
    """
    Evaluate a double power law in ``lambda_sar``.

    Computes ``log_phi_star / (x ** -gamma1 + x ** -gamma2)`` where
    ``x = lambda_sar / 10 ** log_lambda_star``.

    NOTE: despite its name, ``log_phi_star`` is used as a LINEAR
    normalization (it divides directly and is never exponentiated), whereas
    ``log_lambda_star`` is a log10 break value.
    """
    scaled = lambda_sar / 10 ** log_lambda_star
    return log_phi_star / (scaled ** -gamma1 + scaled ** -gamma2)


class Plambda:
    """
    p(lambda_sBHAR) following a chosen literature reference.

    Parameters
    ----------
    method : str
        ``"aird+2018"`` reads tabulated distributions from
        ``{ROOT}/data/aird2018``; ``"bongiorno+2016"`` builds a grid whose
        probability columns are currently all zero (placeholder). Any other
        value gives an empty table, and ``get_log_lambda_sBHAR`` then raises
        ``ValueError``.
    no_classification : bool
        If True, ``get_classification`` always returns -1 ("all").
    dm, dz : float
        Forwarded unchanged to the backend ``get_log_lambda_sBHAR``
        (presumably stellar-mass / redshift step sizes -- confirm there).

    Attributes
    ----------
    data : ndarray, shape (N, 10)
        Columns: classification zmin zmax mass_lo mass_hi logl p(median)
        p_lolim p_uplim flag.
    cs : ndarray of int64
        Unique classifications present in ``data``.
    zbins, mbins : ndarray
        Sorted unique redshift / stellar-mass bin edges found in ``data``.
    cbins : tuple
        Edges in (log SFR - log SFR_MS) used by ``get_classification``.
    """

    def __init__(
        self,
        method="aird+2018",
        no_classification=False,
        dm=0.10,
        dz=0.10,
    ):
        self.method = method
        self.no_classification = no_classification
        self.dm = dm
        self.dz = dz

        self.data = self.get_data()
        # Column 0 is the classification; columns 1-2 and 3-4 hold the
        # redshift and stellar-mass bin edges respectively.
        self.cs = np.unique(self.data[:, 0]).astype(np.int64)
        self.zbins = np.unique(np.concatenate([self.data[:, 1], self.data[:, 2]]))
        self.mbins = np.unique(np.concatenate([self.data[:, 3], self.data[:, 4]]))
        self.cbins = self.get_bins()

    def get_bins(self):
        """
        Return the SFR classification bin edges from Aird+2018.

        The edges apply to log SFR relative to the main sequence (see
        ``get_classification``): offsets <= -1.3 fall in class 0, larger
        offsets in class 1.
        """
        return -np.inf, -1.3, np.inf

    def get_data(self):
        """
        Build the p(lambda) table for ``self.method``.

        Returns
        -------
        ndarray, shape (N, 10)
            Columns: classification zmin zmax mass_lo mass_hi logl
            p(median) p_lolim p_uplim flag. Empty (N = 0) when the method is
            not recognized.
        """

        data = np.empty((0, 10))

        # NOTE: Aird+2018 is given in dex-1 while Aird+2019 is apparently NOT
        if self.method == "aird+2018":
            # One file per classification: -1 all, 0 quiescent, +1 star-forming.
            for i, f in [
                    (-1, f"{ROOT}/data/aird2018/pledd_all.dat"),
                    (+0, f"{ROOT}/data/aird2018/pledd_qu.dat"),
                    (+1, f"{ROOT}/data/aird2018/pledd_sf.dat"),
            ]:
                d = np.loadtxt(f)
                # Prepend the classification as column 0.
                v = np.full(d.shape[0], i)[:, None]
                d = np.hstack([v, d])
                data = np.append(data, d, axis=0)
        elif self.method == "bongiorno+2016":
            log_lambda_sar = np.linspace(32.0, 36.0, 41)
            # Two edges -> a single stellar-mass bin [9.5, 20.0].
            log_stellar_mass = np.linspace(9.5, 20.0, 2)

            # NOTE: zbins from Miyaji +15
            zbins = [
                (0.015, 0.2),
                (0.200, 0.4),
                (0.400, 0.6),
                (0.600, 0.8),
                (0.800, 1.0),
                (1.000, 1.2),
                (1.200, 1.6),
                (1.600, 2.0),
                (2.000, 2.4),
                (2.400, 3.0),
                (3.000, 5.8),
            ]
            for (zlo, zhi) in zbins:
                for (mlo, mhi) in zip(log_stellar_mass[:-1], log_stellar_mass[1:]):
                    # NOTE: integrate in small bins of stellar mass
                    # TODO(review): p is a zero placeholder; the intended call
                    # appears to be
                    # get_Psi_lambda_SAR(0.5 * (mhi + mlo), 10 ** log_lambda_sar, 0.5 * (zlo + zhi))
                    p = np.zeros_like(log_lambda_sar)
                    row = np.vstack([
                        np.full_like(log_lambda_sar, -1),
                        np.full_like(log_lambda_sar, zlo),
                        np.full_like(log_lambda_sar, zhi),
                        np.full_like(log_lambda_sar, mlo),
                        np.full_like(log_lambda_sar, mhi),
                        log_lambda_sar - 34, # NOTE: in "eddington" units
                        p,
                        p,
                        p,
                        np.full_like(log_lambda_sar, 1)
                    ]).T
                    data = np.append(data, row, axis=0)
        return data

    def get_main_sequence(self, log_stellar_mass, redshift):
        """Equation (8) of Aird+2019: log SFR of the star-forming main sequence."""
        return -7.6 + 0.76 * log_stellar_mass + 2.95 * np.log10(1 + redshift)

    def get_classification(
        self,
        redshift,
        log_stellar_mass,
        log_star_formation_rate,
    ):

        """
        Get the galaxy classification according to Aird. The classification
        depends on the number of bins given, which quantify the difference with
        respect to the main sequence of star-formation.

        Returns -1 when ``no_classification`` is set. Otherwise returns the
        index ``i`` of the ``cbins`` interval ``(cbins[i], cbins[i + 1]]``
        containing ``log_star_formation_rate - main_sequence``, so the result
        lies in ``[0, len(cbins) - 2]``.

        Raises
        ------
        AssertionError
            If any offset falls outside the bins (e.g. NaN input).
        """

        if self.no_classification:
            return -1

        log_sfr_ms = self.get_main_sequence(log_stellar_mass, redshift)
        diff = log_star_formation_rate - log_sfr_ms
        # right=True: digitize returns i with cbins[i-1] < diff <= cbins[i].
        ret = np.digitize(diff, bins=self.cbins, right=True) - 1
        # N edges define N - 1 classes. NaN offsets are placed past the last
        # edge by digitize and must be rejected here.
        n_classes = len(self.cbins) - 1
        assert np.all((0 <= ret) & (ret < n_classes))

        return ret

    def get_log_lambda_sBHAR(
        self,
        classification,
        redshift,
        log_stellar_mass,
        *args,
        **kwargs
    ):

        """
        Sample p(lambda_sBHAR) to derive lambda_sBHAR.

        ``classification`` is an iterable of -1 / 0 / +1 codes, which are
        translated to "all" / "quiescent" / "star-forming" (unknown codes
        become None) before being passed, together with ``self.dm``,
        ``self.dz`` and any extra arguments, to the backend module's
        ``get_log_lambda_sBHAR``.

        Raises
        ------
        ValueError
            If ``self.method`` is not "bongiorno+2016" or "aird+2018".
        """

        if self.method == "bongiorno+2016":
            import bongiorno2016 as backend
        elif self.method == "aird+2018":
            import aird2018 as backend
        else:
            raise ValueError(f"Unknown p(lambda) method: {self.method!r}")

        names = {
            -1: "all",
            +0: "quiescent",
            +1: "star-forming",
        }
        classification = np.array([names.get(c) for c in classification])

        return backend.get_log_lambda_sBHAR(
            log_stellar_mass,
            redshift,
            classification,
            self.dm,
            self.dz,
            *args,
            **kwargs,
        )


from smf import StellarMassFunctionCOSMOS2020AGN
import aird2018, bongiorno2016
# Cache of stellar mass functions evaluated on MVEC, keyed by
# (redshift, galaxy type); filled lazily by get_plambda_bon16_air18.
SMF = {}
# Stellar-mass grid: 801 points log-spaced from 1e8 to 1e16
# (presumably in Msun -- confirm against the SMF implementation).
MVEC = np.logspace(8, 16, 801)
def get_plambda_bon16_air18(mlo, mhi, z, lvec, smf=StellarMassFunctionCOSMOS2020AGN()):

    """
    Get a plambda which is consistent with Bon+16 total, but gets the SF/Q
    fractions from Aird+2018.

    The quiescent and star-forming p(lambda) shapes are taken from Aird+2018
    and rescaled at each lambda by a common factor A. After weighting by the
    stellar mass function, they then add up to the Bongiorno+2016 total.

    Parameters
    ----------
    mlo, mhi : float
        log10 stellar-mass bin edges. Only the centre
        ``10 ** ((mlo + mhi) / 2)`` enters the SMF and Aird+2018 lookups;
        both edges are passed to ``bongiorno2016.get_plambda``.
    z : float
        Redshift.
    lvec : array_like
        log lambda grid on which the result is returned. The Aird+2018 grid
        is shifted by +34 onto this scale.
    smf : object
        Provides ``get_stellar_mass_function(mvec, z, type)``. The default
        instance is built once at import time and shared by every call.

    Returns
    -------
    lvec, pq, ps
        ``lvec`` unchanged; ``pq`` / ``ps`` are the quiescent / star-forming
        p(lambda), each normalized to unit sum. They are numpy masked arrays
        (entries where the division defining A is undefined are masked).
    """

    ##################################################
    # 1: Estimate stellar mass function at z of sample
    mcen = 10 ** ((mlo + mhi) / 2)
    # NOTE(review): SMF is keyed on (z, type) only. A non-default `smf` passed
    # after the cache is populated for this z receives the cached values.
    for t in ("quiescent", "star-forming", "all"):
        if (z, t) not in SMF:
            SMF[z, t] = smf.get_stellar_mass_function(MVEC, z, t)
    # np.interp clamps to the edge values outside MVEC.
    phi_q = np.interp(mcen, MVEC, SMF[z, "quiescent"])
    phi_s = np.interp(mcen, MVEC, SMF[z, "star-forming"])
    phi_t = np.interp(mcen, MVEC, SMF[z, "all"])
    assert np.allclose(phi_q + phi_s, phi_t)
    ##################################################

    #################################
    # 2: Estimate p(lambda) from Aird
    def my_get_plambda(t):
        """Aird+2018 p(lambda) for type t, resampled onto lvec and normalized."""
        lam, p = aird2018.get_plambda(mcen, z, t)
        # Work on copies: these arrays belong to aird2018 and must not be
        # modified in place (they may be cached there).
        lam = lam + 34  # "eddington" units -> lvec scale
        p = p.copy()
        p[lam < 32] = 0
        p /= p.sum()
        p = np.interp(lvec, lam, p, left=0, right=0)
        return p / p.sum()
    p_lambda_q = my_get_plambda("quiescent")
    p_lambda_s = my_get_plambda("star-forming")
    #################################

    ######################################
    # 3: Estimate Bongiorno+2016 p(lambda) and normalize to 1
    p_lambda_t = bongiorno2016.get_plambda(mlo, mhi, z, lvec)
    # Out-of-place division so the array returned by bongiorno2016 is untouched.
    p_lambda_t = p_lambda_t / p_lambda_t.sum()
    ######################################

    ###########################################
    # 4: A factor to renormalize Aird p(lambda)
    #
    # The factor is derived from assuming
    #
    #   Phi_SAR_q' + Phi_SAR_s' = Phi_SAR_bon16,
    #
    # where 'q' refers to quiescent, 's' to star-forming and
    # 'bon16' to the total. Moreover, Phi_SAR may be written as
    #
    #   Phi_SAR = int p(lambda) Phi_star dlambda,
    #
    # where Phi_star is the stellar mass function and int p(lambda)
    # dlambda = 1.
    #
    # Finally, at some lambda, we may write the previous equation as
    #
    #       Phi_star_q' * p_lambda_q + Phi_star_s' * p_lambda_s = Phi_star * p_lambda_t
    #   =>  A * Phi_star_q * p_lambda_q + A * Phi_star_s * p_lambda_s = Phi_star * p_lambda_t
    #   =>  A = Phi_star * p_lambda_t / (Phi_star_q * p_lambda_q + Phi_star_s * p_lambda_s)
    A = np.ma.true_divide(p_lambda_t * phi_t, p_lambda_q * phi_q + p_lambda_s * phi_s)
    pq = A * p_lambda_q ; pq /= np.sum(pq)
    ps = A * p_lambda_s ; ps /= np.sum(ps)
    return lvec, pq, ps


# Cache of Zou+2024 interpolators (or None when the table read returned None),
# keyed by (type, mass-bin centre, redshift-bin centre, fill_value).
DATA_ZOU2024 = {}
def get_plambda_zou2024(log_lambda_sar, t, log_mstar, z, fill_value="extrapolate"):
    """
    Return p(lambda_SAR) from the Zou+2024 tables.

    Selects the stellar-mass / redshift bin containing (``log_mstar``, ``z``).
    Values outside the tabulated ranges go to the first / last bin. The
    matching table is read from ``data/plambda/zou2024/{t}/plambda.fits``
    (extension ``logm_{mass centre}_z_{redshift centre}``; the path is
    relative to the current working directory). The function interpolates
    its ``median`` column in ``loglambda`` and returns ``10 ** median``.

    Parameters
    ----------
    log_lambda_sar : array_like
        Points at which to evaluate p(lambda_SAR).
    t : str
        Galaxy-type sub-directory; "all" is mapped to "main".
    log_mstar, z : float
        Scalar log stellar mass and redshift used only to select the bin.
    fill_value : optional
        Passed to ``scipy.interpolate.interp1d``.

    Returns
    -------
    ndarray
        p(lambda_SAR), or zeros shaped like ``log_lambda_sar`` when the table
        read returned None.
    """

    # Fix the typename
    if t == "all":
        t = "main"

    mbins = 9.50, 10.00, 10.50, 11.00, 11.50, 12.00
    mcens = 9.75, 10.25, 10.75, 11.25, 11.75

    zbins = 0.10, 0.50, 1.00, 1.50, 2.00, 2.50, 3.00, 4.00
    zcens = 0.30, 0.75, 1.25, 1.75, 2.25, 2.75, 3.50

    def bin_centre(value, edges, centres):
        """Centre of the bin in `edges` containing `value`, clamped to the edge bins."""
        idx = np.digitize(value, edges) - 1
        return centres[int(np.clip(idx, 0, len(centres) - 1))]

    mm = bin_centre(log_mstar, mbins, mcens)
    zz = bin_centre(z, zbins, zcens)

    # Key on the resolved bin so every value inside a bin shares one read.
    key = t, mm, zz, fill_value
    if key not in DATA_ZOU2024:
        fits = fitsio.read(f"data/plambda/zou2024/{t}/plambda.fits", ext=f"logm_{mm}_z_{zz}")

        DATA_ZOU2024[key] = None
        if fits is not None:
            DATA_ZOU2024[key] = interp1d(fits["loglambda"], fits["median"], bounds_error=False, fill_value=fill_value)

    fun = DATA_ZOU2024[key]
    if fun is None:
        return np.full_like(log_lambda_sar, 0)

    # The tabulated median is exponentiated, i.e. it is stored as log10 p.
    return 10 ** fun(log_lambda_sar)
