#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-04-12 22:41:33

"""Tools for assigning MBH"""
import numpy as np
from scipy.interpolate import CubicSpline
from util import ROOT


def get_log_mbh_from_bulge(log_mstar, A=500):
    """Return log10(MBH) assuming MBH = Mstar / A (default A = 500).

    Works element-wise on scalars or numpy arrays of log10 masses.
    """
    log_ratio = np.log10(A)
    return log_mstar - log_ratio


def get_log_mbh_shankar2019(log_mstar):
    """Return equation 5 from Shankar+2019.

    A cubic polynomial in (log_mstar - 11):
        log MBH = 7.574 + 1.946 d - 0.306 d^2 - 0.011 d^3,  d = log_mstar - 11
    """
    d = log_mstar - 11
    linear_term = 1.946 * d
    quadratic_term = 0.306 * d ** 2
    cubic_term = 0.011 * d ** 3
    # Same left-to-right summation order as the closed-form expression above.
    return 7.574 + linear_term - quadratic_term - cubic_term


def get_log_mbh_sg2016(log_mstar):
    """Return equation 3 from Shankar+2019.

    Linear relation log MBH = 8.54 + 1.18 (log_mstar - 11); this is the same
    formula used for the "S&G16" entry of get_log_mbh.
    """
    intercept, slope = 8.54, 1.18
    return intercept + slope * (log_mstar - 11)


def get_log_mbh_tanaka2024(log_mstar):
    """Return equation 10 from Tanaka+2024, see also Figure 12.

    log MBH = 0.67 log_mstar + 1.31
    """
    slope, intercept = 0.67, 1.31
    return slope * log_mstar + intercept


def get_log_mbh_pacucci2023(log_mstar):
    """Return the first equation from the Pacucci+2023 abstract.

    log MBH = -2.43 + 1.06 log_mstar
    """
    intercept, slope = -2.43, 1.06
    return intercept + slope * log_mstar


def get_delta_log_mbh_shankar2019(log_mstar, low=0.0, high=np.inf):
    """Return equation 6 from Shankar+2019, clipped to [low, high].

    delta = 0.32 - 0.1 (log_mstar - 12), presumably the scatter in log MBH
    at fixed stellar mass (confirm against the paper). The linear form goes
    negative at high masses, hence the clip with a default floor of 0.
    """
    return np.clip(0.32 - 0.1 * (log_mstar - 12), low, high)


def get_spline(Mst, Mbh, bin_width=0.25):

    '''
    Fit a natural cubic spline to the binned mean MBH as a function of Mstar
    and extend it linearly beyond its outermost knots.

    Mst is split into bins of width ``bin_width`` running from int(min(Mst))
    up to int(max(Mst)) + 1. In each non-empty bin the mean of 10**Mbh is
    taken and converted back to log10; a natural cubic spline is fitted
    through these values at the bin centres. Outside the knot range the
    spline extrapolates with dy/dx of the end knot and d2y/dx2 = 0.

    Parameters
    ----------
    Mst : array_like
        1-D array of stellar masses, binned directly (presumably log10
        Mstar, since the spline is evaluated at log_mstar by callers).
    Mbh : array_like
        1-D array of log10 black-hole masses matching ``Mst`` element-wise
        (they are exponentiated with 10** before averaging).
    bin_width : float, optional
        Width of the Mst bins (default 0.25).

    Returns
    -------
    scipy.interpolate.CubicSpline
        The extended spline; callable on scalars or arrays.

    Raises
    ------
    ValueError
        If fewer than two bins yield a finite mean MBH.
    '''

    Mst = np.asarray(Mst, dtype=float)
    Mbh = np.asarray(Mbh, dtype=float)

    # int() truncates, so the bin grid starts at the integer part of the
    # smallest Mst value.
    edges = np.arange(int(np.min(Mst)), int(np.max(Mst)) + 1 + bin_width, bin_width)
    centres = (edges[1:] + edges[:-1]) / 2

    # Mean of the *linear* MBH in each bin, expressed back in log10. Empty
    # bins stay NaN. They are skipped explicitly because np.mean of an empty
    # slice returns NaN but also emits RuntimeWarnings.
    mean_log_mbh = np.full(centres.size, np.nan)
    for i, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        in_bin = (Mst >= lo) & (Mst < hi)
        if in_bin.any():
            # A zero mean gives -inf, which is filtered out below like NaN.
            with np.errstate(divide="ignore"):
                mean_log_mbh[i] = np.log10(np.mean(10 ** Mbh[in_bin]))

    usable = np.isfinite(mean_log_mbh)
    if np.count_nonzero(usable) < 2:
        raise ValueError(
            "get_spline needs at least two Mstar bins with a finite mean Mbh"
        )
    spline = CubicSpline(centres[usable], mean_log_mbh[usable], bc_type='natural')

    # Extend the spline so the extrapolation is a C1-continuous straight line
    # on both sides. PPoly.extend with a breakpoint equal to the current end
    # knot adds a zero-width interval there. Because the spline extrapolates
    # with the polynomial of its outermost interval, everything beyond the
    # knot follows that interval's coefficients
    # [d^3y/dx^3, d^2y/dx^2, dy/dx, y] = [0, 0, slope, value].
    # (value is y shifted by slope * one ulp, i.e. effectively y.)

    # Left end
    x0 = spline.x[0]
    y0 = spline(x0)
    dydx0 = spline(x0, nu=1)
    newy = y0 + dydx0 * (np.nextafter(x0, x0 - 1) - x0)
    spline.extend(np.array([0, 0, dydx0, newy])[..., None], np.r_[x0])

    # Right end
    x1 = spline.x[-1]
    y1 = spline(x1)
    dydx1 = spline(x1, nu=1)
    newy = y1 + dydx1 * (np.nextafter(x1, x1 + 1) - x1)
    spline.extend(np.array([0, 0, dydx1, newy])[..., None], np.r_[x1])

    return spline


# Lazily-populated state for get_log_mbh_continuity: the GrowBHs Mstar, Mbh
# and redshift tables, and a cache of one spline per redshift.
X = None
Y = None
Z = None
SPLINE = {}


def get_log_mbh_continuity(log_mstar, z):

    """Get MBH using Daniel's continuity approach.

    Return log MBH at ``log_mstar`` from the spline whose redshift in the
    GrowBHs z grid is nearest to ``z``.

    On first use, Mstar.txt, Mbh.txt and z.txt are read from
    ``{ROOT}/opt/GrowBHs`` into the module globals X, Y and Z. One spline per
    redshift is then built and cached in SPLINE. Column j of X/Y is paired
    with entry j of Z, except that the last redshift reuses the previous
    column.
    """

    global X, Y, Z

    if X is None:
        X = np.loadtxt(f"{ROOT}/opt/GrowBHs/Mstar.txt")
        Y = np.loadtxt(f"{ROOT}/opt/GrowBHs/Mbh.txt")
        Z = np.loadtxt(f"{ROOT}/opt/GrowBHs/z.txt")

    if not SPLINE:
        # NOTE: Daniel's code ends at z=5? The final redshift is clipped to
        # the penultimate column, which assumes NO redshift evolution past
        # z=5.
        last_usable = len(Z) - 2
        for column, redshift in enumerate(Z):
            column = min(column, last_usable)
            mstar = X[:, column]
            mbh = Y[:, column]
            # get_spline is fed Mstar in ascending order.
            order = np.argsort(mstar)
            SPLINE[redshift] = get_spline(mstar[order], mbh[order])

    # Pick the tabulated redshift closest to the requested one.
    nearest_z = Z[np.argmin((Z - z) ** 2)]
    return SPLINE.get(nearest_z)(log_mstar)


# Reference names accepted by get_log_mbh(..., reference=...) that are looped
# over in the __main__ comparison plot. "continuity" is also accepted by
# get_log_mbh but is not listed here; it takes a redshift and is plotted
# separately in __main__ at several redshifts.
REFERENCES = (
    "Shankar+16",
    "K&H13",
    "R&V15",
    "Suh+19",
    "S&G16",
    "Tanaka+24",
    "Pacucci+23",
)

def get_log_mbh(log_mstar, reference="Shankar+16", z=0.0):
    """Return log10 MBH [Msun] for log10 Mstar [Msun] from a literature relation.

    Parameters
    ----------
    log_mstar : float or numpy.ndarray
        log10 stellar mass [Msun].
    reference : str, optional
        One of "Shankar+16", "K&H13", "R&V15", "Suh+19", "S&G16",
        "Tanaka+24", "Pacucci+23" or "continuity".
    z : float, optional
        Redshift; only used by the "continuity" relation.

    Returns
    -------
    float or numpy.ndarray
        log10 MBH evaluated element-wise on ``log_mstar``.

    Raises
    ------
    KeyError
        If ``reference`` is not one of the names above.

    Only the selected relation is evaluated. In particular, the "continuity"
    data files are read only when "continuity" is requested.

    With x = log_mstar, the analytic relations are:
        Shankar+16  7.574 + 1.946 (x-11) - 0.306 (x-11)^2 - 0.011 (x-11)^3
        K&H13       8.56 + 1.58 (x-11)
        R&V15       7.45 + 1.05 (x-11)
        Suh+19      1.47 x - 8.56
        S&G16       8.54 + 1.18 (x-11)
        Tanaka+24   0.67 x + 1.31
        Pacucci+23  -2.43 + 1.06 x
    Please double check them once you plot them!

    NOTE(review): the original SuperMongo notes for R&V15 also defined
    x1 = 1.793 + 0.845 x, which is not applied here. Confirm whether a
    stellar-mass conversion was intended.
    """

    x = log_mstar
    # Callables, so that only the requested relation is computed. Previously
    # every entry, including the file-reading "continuity" model, was
    # evaluated on each call.
    relations = {
        # Same polynomial as get_log_mbh_shankar2019 (eq. 5 of Shankar+2019).
        "Shankar+16": lambda: get_log_mbh_shankar2019(x),
        "K&H13":      lambda: 8.56 + 1.58 * (x - 11),
        "R&V15":      lambda: 7.45 + 1.05 * (x - 11),
        "Suh+19":     lambda: 1.47 * x - 8.56,
        # Same line as get_log_mbh_sg2016.
        "S&G16":      lambda: get_log_mbh_sg2016(x),
        "Tanaka+24":  lambda: get_log_mbh_tanaka2024(x),
        "Pacucci+23": lambda: get_log_mbh_pacucci2023(x),
        "continuity": lambda: get_log_mbh_continuity(x, z=z),
    }

    try:
        relation = relations[reference]
    except KeyError:
        raise KeyError(
            f"unknown reference {reference!r}; expected one of {list(relations)}"
        ) from None
    return relation()


if __name__ == "__main__":

    import matplotlib.pyplot as plt

    log_mstar = np.linspace(9, 12)

    # Solid curve for each analytic relation.
    for reference in REFERENCES:
        log_mbh = get_log_mbh(log_mstar, reference)
        plt.plot(log_mstar, log_mbh, label=reference, linewidth=2.0)

    # Dotted curves for the continuity model at a few redshifts.
    for redshift in (0.0, 2.0, 4.0):
        log_mbh = get_log_mbh(log_mstar, "continuity", z=redshift)
        plt.plot(
            log_mstar,
            log_mbh,
            label=f"continuity z={redshift}",
            linestyle='dotted',
            linewidth=2.0,
        )

    plt.legend(fontsize="x-small")
    plt.xlabel(r"$\log M_{\rm star}$ [Msun]")
    plt.ylabel(r"$\log M_{\rm BH}$ [Msun]")

    plt.savefig("mbh_mstar_all.pdf")
