#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-03-13 14:22:48

"""
Stellar mass functions
"""


import argparse
import glob
import math
import os
import random
import re
import subprocess
import sys
import time

import astropy as ap
import astropy.coordinates as c
import astropy.units as u
import fitsio
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy.integrate import quad

from util import ROOT


class StellarMassFunction:
    """Base class for stellar mass functions (SMFs) defined per redshift bin.

    Subclasses provide ``get_zbins`` and ``_get_stellar_mass_function``. This
    class supplies the redshift-bin lookup, a (double) Schechter helper and an
    exporter that tabulates the SMF into a FITS file.
    """

    def __init__(self):
        # Human-readable label and IMF name; both are overridden by subclasses
        self.name = ""
        self.imf = None
        # NOTE(review): ``find_zbin`` is set here and in subclasses but is not
        # read anywhere in this class
        self.find_zbin = True
        self.types = "all", "star-forming", "quiescent"

    def get_zbins(self, type=None):
        """Return the list of ``(zlo, zhi)`` redshift bins for galaxy ``type``.

        ``type`` is accepted (and may be ignored by a subclass) because
        ``get_stellar_mass_function`` always passes it.
        """
        raise NotImplementedError

    def _get_stellar_mass_function(self, stellar_mass, zbin, type="all"):
        """Return the SMF at ``stellar_mass`` for one exact redshift bin."""
        raise NotImplementedError

    def get_stellar_mass_function(self, stellar_mass, z, type="all", *args, **kwargs):
        """Evaluate the SMF at redshift ``z``.

        Parameters
        ----------
        stellar_mass : scalar or array-like
            Converted to a 1-d float64 array before evaluation.
        z : scalar
            Redshift; a bin ``(zlo, zhi)`` is used when ``zlo <= z < zhi``.
        type : str
            Forwarded to ``get_zbins`` and ``_get_stellar_mass_function``.
        *args, **kwargs
            Forwarded to ``_get_stellar_mass_function``.

        Returns
        -------
        numpy.ndarray
            Sum of the per-bin SMFs over every bin containing ``z``; all
            zeros when no bin contains ``z``.
        """
        stellar_mass = np.atleast_1d(stellar_mass).astype(np.float64)
        smf = np.zeros_like(stellar_mass)

        # Query the bins once: subclasses may read them from disk
        for zlo, zhi in self.get_zbins(type):
            if zlo <= z < zhi:
                smf += self._get_stellar_mass_function(stellar_mass, (zlo, zhi), type, *args, **kwargs)

        return smf

    def get_zbin_nearest(self, zbin, geometric=False):
        """Return the ``(zlo, zhi)`` pair closest to ``zbin``.

        The distance is the sum of squared differences in lower edge, upper
        edge and bin width.

        NOTE(review): the search runs over every combination of a lower edge
        and an upper edge taken from ``get_zbins()``, so the returned pair is
        not guaranteed to be one of the tabulated bins -- confirm this is
        intended. ``geometric`` is currently unused.
        """
        zbins = self.get_zbins()

        zlo = np.array([z[0] for z in zbins])
        zhi = np.array([z[1] for z in zbins])

        # NOTE: find the closest zbin by 2-d minimization
        zlo, zhi = np.meshgrid(zlo, zhi)
        d1 = zlo - zbin[0]
        d2 = zhi - zbin[1]
        d3 = (zhi - zlo) - (zbin[1] - zbin[0])
        dt = d1 ** 2 + d2 ** 2 + d3 ** 2
        # First minimum in row-major order
        i1, i2 = np.argwhere(dt == np.min(dt))[0]
        return zlo[i1, i2], zhi[i1, i2]

    def write_egg(self, filename, mbins, zbins=None):
        """Tabulate the star-forming and quiescent SMFs and write them to FITS.

        Parameters
        ----------
        filename : str
            Output path; an existing file is overwritten (``clobber=True``).
        mbins : array-like of shape (n, 2)
            Mass bin edges. The SMF is evaluated at ``10 ** mean(edges)``,
            i.e. the edges are treated as log10 masses.
        zbins : sequence of (zlo, zhi), optional
            Redshift bins; defaults to ``self.get_zbins()``. The SMF is
            evaluated at each bin's arithmetic centre.

        Returns
        -------
        int
            Always 0.
        """
        if zbins is None:
            zbins = self.get_zbins()

        # Single-row table: bin edges, the two SMF grids and the IMF label
        dtype = np.dtype([
            ("ZB",      ">f4", (2, len(zbins))),
            ("MB",      ">f4", (2, len(mbins))),
            ("ACTIVE",  ">f8", (len(zbins), len(mbins))),
            ("PASSIVE", ">f8", (len(zbins), len(mbins))),
            ("IMF",     "<U8")
        ])
        ret = np.zeros(1, dtype=dtype)

        ret["ZB"][0][0] = np.array([z[0] for z in zbins])
        ret["ZB"][0][1] = np.array([z[1] for z in zbins])

        ret["MB"][0][0] = np.array([m[0] for m in mbins])
        ret["MB"][0][1] = np.array([m[1] for m in mbins])
        mcens = np.mean(mbins, axis=1)

        for row, zbin in enumerate(zbins):
            for k1, k2 in (
                ["ACTIVE",  "star-forming"],
                ["PASSIVE", "quiescent"]
            ):
                ret[k1][0][row] = self.get_stellar_mass_function(10 ** mcens, 0.5 * sum(zbin), type=k2)

        ret["IMF"] = self.imf

        fitsio.write(filename, ret, clobber=True)

        return 0

    @staticmethod
    def get_schechter(logM, logMstar, alpha1, Phi1, alpha2=None, Phi2=None, factor=np.log(10)):
        """Evaluate a single or double Schechter function (Eq. 8 from Weaver+2020).

        Computes, with ``x = 10 ** (logM - logMstar)``::

            factor * exp(-x) * (Phi1 * x ** (alpha1 + 1) + Phi2 * x ** (alpha2 + 1))

        A component whose ``alpha`` is ``None`` or non-finite contributes zero,
        so leaving ``alpha2``/``Phi2`` unset gives a single Schechter function.
        ``factor`` defaults to ln(10).
        """

        def fun(alpha, Phi):
            """Single Schechter component without the exponential cut-off."""
            if alpha is None or not np.isfinite(alpha):
                return np.zeros_like(logM)

            return Phi * (10 ** (logM - logMstar)) ** (alpha + 1)

        smf = factor * np.exp(-10 ** (logM - logMstar)) * (fun(alpha1, Phi1) + fun(alpha2, Phi2))
        return smf


class StellarMassFunctionCandels(StellarMassFunction):
    """Tabulated CANDELS stellar mass function read from a FITS table.

    The table provides redshift bin edges ("ZB"), mass bin edges ("MB") and
    the "ACTIVE" / "PASSIVE" mass functions on that grid.
    """

    def __init__(self):
        super().__init__()
        self.name = "CANDELS (Schreiber+ 2017)"
        self.imf = "salpeter"
        self.data = fitsio.read(f"{ROOT}/egg/share/mass_func_candels.fits")

    def get_zbins(self, type=None):
        """Return the tabulated ``(zlo, zhi)`` bins; ``type`` is ignored."""
        edges = self.data["ZB"][0]
        return [(lower, upper) for lower, upper in edges.T]

    def _get_stellar_mass_function(
            self,
            stellar_mass,
            zbin,
            type="all"
        ):
        """Interpolate the tabulated SMF of one redshift bin in log10 mass.

        ``zbin`` must equal one of the tabulated bins exactly. For
        ``type="all"`` the passive and active components are summed.
        """
        assert type in ("all", "quiescent", "star-forming")

        z_edges = self.data["ZB"][0]
        # Mass grid: centre of each tabulated mass bin
        mass_grid = self.data["MB"][0].mean(axis=0)
        log_mass = np.log10(stellar_mass)

        # Row of the table that matches both bin edges
        in_bin = (z_edges[0, :] == zbin[0]) & (z_edges[1, :] == zbin[1])

        # Accumulate into arrays shaped (and typed) like the input
        active = np.zeros_like(stellar_mass)
        passive = np.zeros_like(stellar_mass)
        active += np.interp(log_mass, mass_grid, self.data["ACTIVE"][0, in_bin, :].ravel())
        passive += np.interp(log_mass, mass_grid, self.data["PASSIVE"][0, in_bin, :].ravel())

        # Boolean switches select which components enter the result
        use_passive = type in ("all", "quiescent")
        use_active = type in ("all", "star-forming")
        total = use_passive * passive + use_active * active

        return np.atleast_1d(total)


class StellarMassFunctionCOSMOS2020(StellarMassFunction):
    """COSMOS2020 stellar mass function built from tabulated Schechter fits.

    Parameters are read from ``{ROOT}/data/weaver2022/<name>_ext.dat`` where
    the first two columns are the redshift bin edges.
    """

    def __init__(self):
        super().__init__()
        self.name = "COSMOS2020 (Weaver+ 2022)"
        self.imf = "chabrier"

    def get_zbins(self, type="quiescent"):
        """Return the ``(zlo, zhi)`` bins of the table that belongs to ``type``.

        ``"all_best_fit"`` uses the "total" table, any other type containing
        "all" uses the "quiescent" table, and everything else uses the table
        named after ``type``.
        """
        # NOTE: the two last bins are only available for the total stellar mass function
        if type == "all_best_fit":
            table = "total"
        elif "all" in type:
            table = "quiescent"
        else:
            table = type

        edges = np.loadtxt(f"{ROOT}/data/weaver2022/{table}_ext.dat", usecols=(0, 1))
        return [(lower, upper) for lower, upper in edges]

    def _get_stellar_mass_function(self, stellar_mass, zbin, type="all"):
        """Evaluate the Schechter fit(s) for one exact redshift bin.

        For ``"all"`` the quiescent and star-forming fits are summed, while
        ``"all_best_fit"`` evaluates the fit to the total mass function.
        """

        def schechter_from_table(table, bin_edges):
            # Pick the row matching both bin edges; drop the two leading edge
            # columns and the last column before passing on the parameters
            rows = np.loadtxt(f"{ROOT}/data/weaver2022/{table}_ext.dat")
            in_bin = (bin_edges[0] == rows[:, 0]) & (rows[:, 1] == bin_edges[1])
            return self.get_schechter(np.log10(stellar_mass), *rows[in_bin, 2:-1][0])

        total = np.zeros_like(stellar_mass)

        # Special case: all_best_fit returns the total stellar mass function based on the fit
        if type == "all_best_fit":
            # Bins starting at z=0.0 or ending at z=7.0 are remapped to the
            # second / second-to-last bin of the total table
            if zbin[0] == 0.0:
                zbin = self.get_zbins(type)[1]
            if zbin[1] == 7.0:
                zbin = self.get_zbins(type)[-2]
            return schechter_from_table("total", zbin)

        selected = [
            name for name in ("quiescent", "star-forming")
            if type == "all" or name == type
        ]
        for name in selected:
            total += schechter_from_table(name, zbin)

        return total


class StellarMassFunctionBongiorno2016(StellarMassFunctionCOSMOS2020):
    """AGN host galaxy stellar mass function following Bongiorno+ 2016.

    Redshift bins are inherited from ``StellarMassFunctionCOSMOS2020``.
    """

    def __init__(self):
        super().__init__()
        self.name = "AGN HGMF (Bongiorno+ 2016)"
        self.imf = "chabrier"
        self.find_zbin = False

    def _get_stellar_mass_function(
        self,
        stellar_mass,
        zbin,
        type="all",
        smf_reference=None,
        weight_aird=True,
        schechter=False,
        **kwargs
    ):
        """Return the AGN host SMF for one redshift bin.

        Parameters
        ----------
        stellar_mass : array-like
            Stellar masses (linear; ``log10`` is taken where needed).
        zbin : (zlo, zhi)
            Redshift bin.
        type : str
            ``"all"`` returns the total; ``"quiescent"`` / ``"star-forming"``
            scale the total by the corresponding galaxy fraction.
        smf_reference : StellarMassFunction, optional
            Galaxy SMF used for the quiescent/star-forming split. Defaults to
            a new ``StellarMassFunctionCOSMOS2020`` created on demand.
        weight_aird : bool
            Currently has no effect (see the note in the body).
        schechter : bool
            Use the tabulated Schechter fit instead of the integral form.
        **kwargs
            Forwarded to the selected total-SMF method.
        """
        # Get the total stellar mass function
        args = stellar_mass, zbin
        if schechter:
            smf = self._get_stellar_mass_function_schechter(*args, **kwargs)
        else:
            smf = self._get_stellar_mass_function_integral(*args, **kwargs)

        if type == "all":
            return smf

        # Create the default reference lazily, rather than in the signature
        # where it would be instantiated once at class-definition time
        if smf_reference is None:
            smf_reference = StellarMassFunctionCOSMOS2020()

        # If not asked to weight by Aird+, weight the output SMF only by the
        # relative fraction of qu/sf galaxies
        # NOTE: this branch is taken unconditionally (the original condition
        # ``if not weight_aird:`` is switched off), so the Aird+18 weighting
        # below is never reached and ``weight_aird`` is ignored.
        use_reference_fractions = True
        if use_reference_fractions:
            f = smf_reference._get_stellar_mass_function(stellar_mass, zbin, type=type)
            f_total = smf_reference._get_stellar_mass_function(stellar_mass, zbin, type="all")
            return f / f_total * smf

        # NOTE: to find the quiescent / star-forming fraction some assumptions
        # are needed since Bon+16 do not report these numbers. If we assume
        # AGN qu/sf host population from Air+18, we can work under the
        # assumption that the ratio should be the same. I.e. if B is the number
        # of AGN predicted by Bon+16, while Aq/As correspond to the ratio of
        # number of quiescent/star-forming AGN hosts predicted by Air+18, then
        # the proper weighting scheme is:
        #   Nq = B x (Aq / As) / (1 + Aq / As)
        #   Ns = B / (1 + Aq / As)
        from aird2018 import get_plambda, get_mbins
        zbin = tuple(zbin)
        def get_ratio(m):
            z = 0.5 * sum(zbin)
            smf_q = smf_reference.get_stellar_mass_function(10 ** m, z, type="quiescent")
            smf_s = smf_reference.get_stellar_mass_function(10 ** m, z, type="star-forming")
            d_q = np.sum(get_plambda(m, z, "quiescent") * 0.08)
            d_s = np.sum(get_plambda(m, z, "star-forming") * 0.08)
            return np.ma.true_divide(smf_q * d_q, smf_s * d_s)

        from scipy.interpolate import interp1d
        ms = np.unique(np.array(get_mbins()).flatten())
        rs = np.array([get_ratio(m) for m in ms]).flatten()
        # Outside the tabulated mass range the ratio is held at its end values
        fun = interp1d(ms, rs, bounds_error=False, fill_value=(rs[0], rs[-1]))

        ratio = np.array([fun(m) for m in np.atleast_1d(np.log10(stellar_mass))])
        factor_qu = ratio / (1 + ratio)
        factor_sf = 1 / (1 + ratio)
        if type == "quiescent":
            return factor_qu * smf
        elif type == "star-forming":
            return factor_sf * smf

    def _get_stellar_mass_function_integral(self, stellar_mass, zbin, zmean_arithmetic=True, *args, **kwargs):
        """Bon+16 Eq. (8)

        Evaluated via ``bongiorno2016.get_Phi_star`` at the mean redshift of
        ``zbin``: arithmetic by default, geometric when ``zmean_arithmetic``
        is false. Extra arguments are forwarded to ``get_Phi_star``.
        """
        import bongiorno2016
        zmean = sum(zbin) / 2 if zmean_arithmetic else (zbin[1] * zbin[0]) ** .5
        return bongiorno2016.get_Phi_star(stellar_mass, zmean, *args, **kwargs)

    def _get_stellar_mass_function_schechter(self, stellar_mass, zbin):
        """Evaluate the tabulated single-Schechter fit for ``zbin``.

        Raises ``KeyError`` when ``zbin`` is not one of the three tabulated
        bins. Each entry is ``(logMstar, alpha, Phi)``.
        """
        # NOTE: in the schechter function the definition of alpha is not
        # consistent with Bon+16 and Wea+22 so that
        #   alpha_weaver = alpha_bon16 - 1
        zbin = tuple(zbin)
        parameters = {
            (0.3, 0.8): (10.99, -0.41 - 1, 10 ** -3.83),
            (0.8, 1.5): (10.99, -0.24 - 1, 10 ** -3.54),
            (1.5, 2.5): (10.99, -0.03 - 1, 10 ** -3.84),
        }
        return self.get_schechter(np.log10(stellar_mass), *parameters[zbin], factor=1)


class StellarMassFunctionIlbert2013(StellarMassFunction):
    """Stellar mass function from the double-Schechter fits of Ilbert+ 2013."""

    def __init__(self):
        super().__init__()
        self.name = "zCOSMOS (Ilbert+ 2013)"
        self.imf = "chabrier"
        self.find_zbin = False

    def get_zbins(self, type=None):
        """Return the eight tabulated ``(zlo, zhi)`` bins; ``type`` is ignored."""
        return [
            (0.2, 0.5),
            (0.5, 0.8),
            (0.8, 1.1),
            (1.1, 1.5),
            (1.5, 2.0),
            (2.0, 2.5),
            (2.5, 3.0),
            (3.0, 4.0),
        ]

    def _get_stellar_mass_function(self, stellar_mass, zbin, type="all"):
        """Return the best-fit double-schechter stellar mass function

        Returns zeros when no fit is tabulated for the requested bin, i.e.
        the row is all zeros (quiescent galaxies at 3.0 < z < 4.0).
        """
        Mstar, Phi1, alpha1, Phi2, alpha2 = self._get_parameters(zbin, type)[1:]
        logM = np.log10(stellar_mass)

        # An all-zero row has Mstar == 0: log10(0) is -inf and the Schechter
        # expression would evaluate to 0 * inf = NaN instead of zero
        if Mstar <= 0:
            return np.zeros_like(logM)

        logMstar = np.log10(Mstar)
        return self.get_schechter(logM, logMstar, alpha1, Phi1, alpha2, Phi2)

    def _get_parameters(self, zbin, type):
        """Return the best-fit parameter tuple for ``type`` in ``zbin``.

        The tuple is ``(Mcomplete, Mstar, phi1, alpha1, phi2, alpha2)`` with
        both masses in linear units. Raises ``KeyError`` for an unknown
        ``type`` or a ``zbin`` that is not tabulated.
        """
        return {
            #    zmin zmax   Mcomplete    Mstar        phi1     alpha1 phi2     alpha2
            "all": {
                (0.2, 0.5): (10 **  7.93, 10 ** 10.88, 1.68e-3, -0.69, 0.77e-3, -1.42),
                (0.5, 0.8): (10 **  8.70, 10 ** 11.03, 1.22e-3, -1.00, 0.16e-3, -1.64),
                (0.8, 1.1): (10 **  9.13, 10 ** 10.87, 2.03e-3, -0.52, 0.29e-3, -1.62),
                (1.1, 1.5): (10 **  9.42, 10 ** 10.71, 1.35e-3, -0.08, 0.67e-3, -1.46),
                (1.5, 2.0): (10 **  9.67, 10 ** 10.74, 0.88e-3, -0.24, 0.33e-3, -1.60),
                (2.0, 2.5): (10 ** 10.04, 10 ** 10.74, 0.62e-3, -0.22, 0.15e-3, -1.60),
                (2.5, 3.0): (10 ** 10.24, 10 ** 10.76, 0.26e-3, -0.15, 0.14e-3, -1.60),
                (3.0, 4.0): (10 ** 10.27, 10 ** 10.74, 0.03e-3,  0.95, 0.09e-3, -1.60),
            },
            "quiescent": {
                (0.2, 0.5): (10 ** 08.24, 10 ** 10.91, 1.27e-3,  -0.68, 0.03e-3, -1.52),
                (0.5, 0.8): (10 ** 08.96, 10 ** 10.93, 1.11e-3,  -0.46, 0.00e-3, -0.00),
                (0.8, 1.1): (10 ** 09.37, 10 ** 10.81, 1.57e-3,  -0.11, 0.00e-3, -0.00),
                (1.1, 1.5): (10 ** 09.60, 10 ** 10.72, 0.70e-3,  +0.04, 0.00e-3, -0.00),
                (1.5, 2.0): (10 ** 09.87, 10 ** 10.73, 0.22e-3,  +0.10, 0.00e-3, -0.00),
                (2.0, 2.5): (10 ** 10.11, 10 ** 10.59, 0.10e-3,  +0.88, 0.00e-3, -0.00),
                (2.5, 3.0): (10 ** 10.39, 10 ** 10.27, 0.003e-3, +3.26, 0.00e-3, -0.00),
                # No fit tabulated for this bin: the all-zero row marks "no data"
                (3.0, 4.0): (0,           0,           0,            0,       0,     0),
            },
            "star-forming": {
                (0.2, 0.5): (10 ** 07.86, 10 ** 10.60, 1.16e-3, +0.17, 1.08e-3, -1.40),
                (0.5, 0.8): (10 ** 08.64, 10 ** 10.62, 0.77e-3, +0.03, 0.84e-3, -1.43),
                (0.8, 1.1): (10 ** 09.04, 10 ** 10.80, 0.50e-3, -0.67, 0.48e-3, -1.51),
                (1.1, 1.5): (10 ** 09.29, 10 ** 10.67, 0.53e-3, +0.11, 0.87e-3, -1.37),
                (1.5, 2.0): (10 ** 09.65, 10 ** 10.66, 0.75e-3, -0.08, 0.39e-3, -1.60),
                (2.0, 2.5): (10 ** 10.01, 10 ** 10.73, 0.50e-3, -0.33, 0.15e-3, -1.60),
                (2.5, 3.0): (10 ** 10.20, 10 ** 10.90, 0.15e-3, -0.62, 0.11e-3, -1.60),
                (3.0, 4.0): (10 ** 10.26, 10 ** 10.74, 0.02e-3, +1.31, 0.10e-3, -1.60),
            },
        }[type][zbin]


class StellarMassFunctionAird2018(StellarMassFunctionCOSMOS2020):
    """AGN host galaxy stellar mass function: duty cycle times a galaxy SMF.

    The duty cycle comes from the ``aird2018`` module; redshift bins are
    inherited from ``StellarMassFunctionCOSMOS2020``.
    """

    def __init__(self):
        super().__init__()
        self.name = "AGN HGMF (Aird+ 2018)"
        self.imf = "chabrier"
        self.find_zbin = False

    def _get_stellar_mass_function(self, stellar_mass, zbin, type="all", smf_reference=None):
        """Return the reference galaxy SMF scaled by the AGN duty cycle.

        Parameters
        ----------
        stellar_mass : array-like
            Stellar masses (linear; interpolation is done in log10 mass).
        zbin : (zlo, zhi)
            Redshift bin; both the duty cycle and the reference SMF are
            evaluated at its arithmetic centre.
        type : str
            Galaxy type, used for the duty cycle and the reference SMF alike.
        smf_reference : StellarMassFunction, optional
            Galaxy SMF to scale. Defaults to a new
            ``StellarMassFunctionCOSMOS2020`` created on demand.
        """
        from aird2018 import get_mbins, get_duty_cycle

        # Create the default reference lazily, rather than in the signature
        # where it would be instantiated once at class-definition time
        if smf_reference is None:
            smf_reference = StellarMassFunctionCOSMOS2020()

        zcen = np.mean(zbin)

        # Duty cycle tabulated at the centre of each mass bin
        ms = np.mean(get_mbins(type), axis=1)
        ds = [get_duty_cycle(m, zcen, type) for m in ms]
        ds = np.interp(np.log10(stellar_mass), ms, ds)
        assert np.all(0 <= ds) and np.all(ds <= 1.0)

        # The public accessor expects a scalar redshift, not a bin tuple, and
        # the reference SMF must be of the same galaxy type as the duty cycle
        return ds * smf_reference.get_stellar_mass_function(stellar_mass, zcen, type=type)


class StellarMassFunctionCOSMOS2020AGN(StellarMassFunctionCOSMOS2020):

    """
    NOTE: this stellar mass function tries to add back the AGN hosts that were
    initially removed in Weaver's work.

    Since most of the removed sources were X-ray detected, we assume that the
    Bon+16 AGN host galaxy stellar mass function corresponds roughly to the
    total number of removed sources.

    Bon+16 does not give the relative contribution of quiescent and
    star-forming hosts, so they are added to the stellar mass function in
    proportion to the original COSMOS2020.
    """

    def __init__(self):
        super().__init__()
        self.name = "COSMOS2020 + AGN (Weaver+ 2022)"
        self.bon16 = StellarMassFunctionBongiorno2016()

    def _get_stellar_mass_function(self, stellar_mass, zbin, type="all"):
        """Return the COSMOS2020 SMF plus the Bon+16 AGN host contribution.

        The Bon+16 curve is computed once per bin centre on a fixed log-mass
        grid, stored under ``{ROOT}/data/smf`` and re-read on later calls.
        """

        def schechter_from_table(name):
            # Row matching both bin edges; the two leading edge columns and
            # the last column are not passed on as Schechter parameters
            table = np.loadtxt(f"{ROOT}/data/weaver2022/{name}_ext.dat")
            in_bin = (zbin[0] == table[:, 0]) & (table[:, 1] == zbin[1])
            return self.get_schechter(np.log10(stellar_mass), *table[in_bin, 2:-1][0])

        from bongiorno2016 import get_Phi_star
        smf_qu = schechter_from_table("quiescent")
        smf_sf = schechter_from_table("star-forming")

        # Get Bongiorno+2016 AGN HGSMF, using the on-disk copy when present
        zcen = np.mean(zbin)
        log_mass_grid = np.linspace(8, 16, 801)
        cache_file = f"{ROOT}/data/smf/bon16_{zcen}.dat"
        if not os.path.exists(cache_file):
            phi_grid = get_Phi_star(10 ** log_mass_grid, zcen)
            np.savetxt(cache_file, np.vstack([log_mass_grid, phi_grid]).T)
        log_mass_grid, phi_grid = np.loadtxt(cache_file).T
        smf_bon = np.interp(np.log10(stellar_mass), log_mass_grid, phi_grid)

        # Galaxy SMF of the requested type(s)
        include_qu = type in ["all", "quiescent"]
        include_sf = type in ["all", "star-forming"]
        smf = include_qu * smf_qu + include_sf * smf_sf

        # NOTE: the factor is derived from two assumptions:
        #   1: Aq + As = A   -- total number of AGN is given by Q+SF
        #   2: Aq/As = Gq/Gs -- ratio of Q/SF AGN follows galaxy ratio
        fraction_qu = np.ma.true_divide(smf_qu, smf_qu + smf_sf)
        factor = {
            "all": 1,
            "quiescent": fraction_qu,
            "star-forming": 1 - fraction_qu,
        }.get(type)

        # Add in the AGN hosts
        smf += smf_bon * factor

        return smf


if __name__ == "__main__":

    # Shared mass grid: 450 bins between 8.0 and 12.5, given as (lo, hi) edge
    # pairs. write_egg evaluates the SMF at 10 ** (bin centre).
    m = np.linspace(8.0, 12.5, 451)
    mbins = np.array([m[:-1], m[1:]]).T

    # Export each stellar mass function on its own default redshift bins
    for smf_class, filename in (
        (StellarMassFunctionCOSMOS2020, "egg/share/mass_func_cosmos2020.fits"),
        (StellarMassFunctionCOSMOS2020AGN, "egg/share/mass_func_cosmos2020_agn.fits"),
    ):
        smf_ref = smf_class()
        zbins = smf_ref.get_zbins()
        smf_ref.write_egg(filename, mbins=mbins, zbins=zbins)

    # The Bongiorno+2016 export is not run:
    #smf = StellarMassFunctionBongiorno2016()
    #smf.write_egg("egg/share/mass_func_bongiorno2016.fits", mbins=np.array([m[:-1], m[1:]]).T, zbins=zbins)
