#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-12-11 12:38:41

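"""
Implement the obscured AGN fraction model of Merloni et al. (2014):
f_obs(z, L_X) = b(L_X) * (1 + z) ** d(L_X), with b and d taken from Table 2
of that paper.
"""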

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import interp1d

from util import ROOT

# Table 2 from Merloni+14, read once at import time. Each row describes one
# luminosity bin. As the code below uses it:
#   columns 0-1: bin limits (same scale as `lx`; their mean is the bin centre)
#   columns 2-3: parameter b and its uncertainty
#   columns 4-5: parameter d and its uncertainty
TAB2 = np.loadtxt(f"{ROOT}/data/merloni2014/table2.dat")

class Merloni2014:
    """Obscured AGN fraction f_obs(z, lx) following Merloni et al. (2014).

    The fraction is modelled as ``f_obs = b * (1 + z) ** d`` (see ``fun``).
    The parameters ``b`` and ``d`` depend on luminosity and are derived from
    the tabulated luminosity bins.

    Parameters
    ----------
    interpolate : bool
        If true, interpolate ``b`` and ``d`` linearly between bin centres.
        Otherwise use the value of the nearest bin.
    extrapolate : bool
        If true, extrapolate ``b`` and ``d`` beyond the outermost bin
        centres. Otherwise hold them at the values of the outermost bins.
    f_obs_minimum, f_obs_maximum : float
        The returned obscured fraction is clipped to this range.
    table : array-like, optional
        One row per luminosity bin with columns
        ``(lx_lo, lx_hi, b, b_err, d, d_err)``. Defaults to the module-level
        ``TAB2`` (Table 2 of Merloni+14).
    """

    def __init__(self, interpolate, extrapolate, f_obs_minimum=0.05,
                 f_obs_maximum=0.95, table=None):

        self.interpolate = interpolate
        self.extrapolate = extrapolate
        self.f_obs_minimum = f_obs_minimum
        self.f_obs_maximum = f_obs_maximum
        # Only fall back to the module-level table when none is supplied.
        self.table = TAB2 if table is None else np.asarray(table, dtype=float)

    def _interpolator(self, column):
        """Return an interp1d of table column ``column`` versus bin centre."""
        x = np.mean(self.table[:, :2], axis=1)
        y = self.table[:, column]
        kind = "linear" if self.interpolate else "nearest"
        if self.extrapolate:
            fill_value = "extrapolate"
        else:
            # interp1d sorts x internally (assume_sorted=False). The
            # (below, above) fill values must therefore come from the rows
            # with the smallest and largest bin centres, not from fixed row
            # positions.
            order = np.argsort(x)
            fill_value = (y[order[0]], y[order[-1]])
        return interp1d(x, y, kind=kind, fill_value=fill_value, bounds_error=False)

    def get_parameters(self, lx):
        """Return the model parameters ``(b, d)`` evaluated at ``lx``.

        ``lx`` may be a scalar or an array. Each returned value is an ndarray
        with the shape of ``lx``.
        """
        return self._interpolator(2)(lx), self._interpolator(4)(lx)

    def get_f_obs(self, z, lx):
        """Return the obscured fraction at redshift ``z`` and luminosity ``lx``.

        ``z`` and ``lx`` broadcast against each other. The result is clipped
        to ``[f_obs_minimum, f_obs_maximum]``.
        """
        b, d = self.get_parameters(lx)
        return np.clip(self.fun(z, b, d), self.f_obs_minimum, self.f_obs_maximum)

    @staticmethod
    def fun(z, b, d):
        """Redshift evolution of b and d parameters from Merloni+ 2014"""
        return b * (1 + z) ** d


if __name__ == "__main__":

    z = np.linspace(0.0, 5.5, 56)
    lx = np.linspace(42, 46, 41)

    norm = mpl.colors.Normalize(vmin=42, vmax=46)
    cmap = mpl.cm.viridis
    fig, axes = plt.subplots(1, 2, figsize=(2 * 6.4, 4.8))

    # Left panel: the tabulated Mer14 parameters b (blue) and d (red) at the
    # bin centres. The two series are shifted by -/+0.01 in lx so their error
    # bars do not overlap.
    for t in TAB2:
        lx_center = 0.5 * (t[0] + t[1])
        axes[0].errorbar(lx_center - 0.01, t[2], yerr=t[3], color='b', marker='.', linestyle="none", zorder=99)
        axes[0].errorbar(lx_center + 0.01, t[4], yerr=t[5], color='r', marker='.', linestyle="none", zorder=99)

    # Model configurations to draw: (interpolate, extrapolate, color, linewidth).
    # Uncomment entries to compare the alternative options.
    for interpolate, extrapolate, color, linewidth in [
        #(False, False, "C0", 4),
        #(False, True,  "C1", 3),
        (True,  False, "C0", 1),
        #(True,  True,  "C3", 1),
    ]:

        m = Merloni2014(interpolate, extrapolate)
        b, d = m.get_parameters(lx)
        # b is drawn solid and d dashed, so the two model curves can be told
        # apart.
        axes[0].plot(lx, b, color=color, linewidth=linewidth, linestyle='solid')
        axes[0].plot(lx, d, color=color, linewidth=linewidth, linestyle='dashed')

        # Right panel: f_obs(z) for each luminosity, colored by lx.
        for lx_i in lx:
            f_obs = m.get_f_obs(z, lx_i)
            axes[1].plot(z, f_obs, color=cmap(norm(lx_i)), linewidth=linewidth, linestyle='solid')

    axes[0].set_xlabel("log L_X")
    axes[0].set_ylabel("b (solid), d (dashed)")
    axes[1].set_xlabel("z")
    axes[1].set_ylabel("f_obs")
    fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=axes[1], label="log L_X")

    plt.savefig("merloni2014.pdf")
    plt.show()
