#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-05-08 18:15:21

"""
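Functions and data for the Aird et al. (2018) specific black-hole accretion
rate distributions p(lambda), tabulated in redshift and mass bins for all,
quiescent and star-forming galaxies.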
"""


import argparse
import glob
import math
import os
import random
import re
import subprocess
import sys
import time

import astropy as ap
import astropy.coordinates as c
import astropy.units as u
import fitsio
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp

from util import ROOT


# Aird+2018 p(lambda) tables, one per galaxy population (TO = all,
# QU = quiescent, SF = star-forming), loaded once at import time.
# Column layout as used by the functions below: 0-1 redshift bin edges,
# 2-3 mass bin edges, 4 lambda grid, 5 p(lambda).
AIRD2018_TO, AIRD2018_QU, AIRD2018_SF = (
    np.loadtxt(f"{ROOT}/data/aird2018/pledd_all.dat"),
    np.loadtxt(f"{ROOT}/data/aird2018/pledd_qu.dat"),
    np.loadtxt(f"{ROOT}/data/aird2018/pledd_sf.dat"),
)


def get_aird2018(type):
    """Return the tabulated Aird+2018 p(lambda) table for one population.

    Parameters
    ----------
    type : str
        One of ``"all"``, ``"quiescent"`` or ``"star-forming"``.

    Returns
    -------
    numpy.ndarray
        The table loaded at import time for that population.

    Raises
    ------
    ValueError
        If `type` is not one of the supported populations. (Previously an
        unknown key returned ``None`` and failed later with an opaque
        subscripting error in the caller.)
    """
    tables = {
        "all": AIRD2018_TO,
        "quiescent": AIRD2018_QU,
        "star-forming": AIRD2018_SF,
    }
    try:
        return tables[type]
    except KeyError:
        raise ValueError(
            f"unknown Aird+2018 population {type!r}; "
            f"expected one of {sorted(tables)}"
        ) from None


def get_zbins(type="all"):
    """Return the distinct redshift bins of a table as ``(z_lo, z_hi)`` pairs.

    The bin edges are read row-wise from columns 0 and 1 and de-duplicated
    as pairs, so each lower edge stays matched with its own upper edge.
    Zipping independently de-duplicated lower and upper edges only pairs
    correctly when the bins tile perfectly; otherwise it mis-pairs them or
    silently truncates. For a perfectly tiled table the result is identical.

    Parameters
    ----------
    type : str
        Population key accepted by ``get_aird2018``.

    Returns
    -------
    list of tuple
        ``(z_lo, z_hi)`` pairs sorted by lower edge, then upper edge.
    """
    d = get_aird2018(type)
    edges = np.unique(d[:, 0:2], axis=0)
    return [(z1, z2) for z1, z2 in edges]


def get_mbins(type="all"):
    """Return the distinct mass bins of a table as ``(m_lo, m_hi)`` pairs.

    The bin edges are read row-wise from columns 2 and 3 and de-duplicated
    as pairs, so each lower edge stays matched with its own upper edge.
    Zipping independently de-duplicated edge lists only pairs correctly
    when the bins tile perfectly. For a perfectly tiled table the result is
    identical.

    Parameters
    ----------
    type : str
        Population key accepted by ``get_aird2018``.

    Returns
    -------
    list of tuple
        ``(m_lo, m_hi)`` pairs sorted by lower edge, then upper edge.
    """
    d = get_aird2018(type)
    edges = np.unique(d[:, 2:4], axis=0)
    return [(m1, m2) for m1, m2 in edges]


def get_l(type="all"):
    """Return the sorted, de-duplicated lambda grid (column 4) of a table.

    Parameters
    ----------
    type : str
        Population key accepted by ``get_aird2018``.
    """
    lambdas = get_aird2018(type)[:, 4]
    return np.unique(lambdas)


def get_plambda(m, z, type="all"):
    """Return the tabulated ``[lambda, p(lambda)]`` for the bin containing (m, z).

    Out-of-range inputs are clamped onto the nearest tabulated bin. The
    redshift is clamped first, then the mass is clamped to the mass range
    available within the selected redshift bin. Scalar `m` is assumed,
    because the mass clamp uses plain ``if`` comparisons.

    Parameters
    ----------
    m : float
        Mass value, compared against the bin edges in columns 2-3.
    z : float
        Redshift, compared against the bin edges in columns 0-1.
    type : str
        Population key accepted by ``get_aird2018``.

    Returns
    -------
    numpy.ndarray
        A (2, N) array: row 0 is lambda (column 4), row 1 is p(lambda)
        (column 5), for every row matching both bins.
    """
    table = get_aird2018(type)
    z_lo, z_hi = table[:, 0], table[:, 1]
    m_lo, m_hi = table[:, 2], table[:, 3]

    # Half-open bins [lo, hi): nudge the top edge down so the clamped value
    # still falls inside the last bin.
    z = np.clip(z, z_lo.min(), z_hi.max() - 1e-9)
    in_z = (z_lo <= z) & (z < z_hi)

    m_floor = m_lo[in_z].min()
    m_ceil = m_hi[in_z].max()
    if m <= m_floor:
        m = m_floor
    elif m >= m_ceil:
        m = m_ceil - 1e-9
    in_m = (m_lo <= m) & (m < m_hi)

    return table[in_z & in_m, 4:6].T


def get_duty_cycle(m, z, type="all"):
    """Return the integral of p(lambda) over the lambda grid for bin (m, z).

    Computes ``sum(p(lambda) * dlambda)`` for the bin selected (and clamped)
    by ``get_plambda``. ``dlambda`` is the uniform spacing of the lambda grid
    returned by ``get_l``.

    Parameters
    ----------
    m, z : float
        Mass and redshift used to select the bin.
    type : str
        Population key accepted by ``get_aird2018``.

    Returns
    -------
    numpy.float64
        The integrated probability.

    Raises
    ------
    ValueError
        If the lambda grid is not uniformly spaced.
    """
    dl = np.diff(get_l(type))
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not np.allclose(dl[0], dl[1:]):
        raise ValueError("Aird+2018 lambda grid is not uniformly spaced")
    # get_plambda returns a (2, N) array [lambda, p(lambda)]. Only row 1 is
    # the probability density; summing the whole array would add the
    # lambda values themselves into the result.
    _, p = get_plambda(m, z, type)
    return np.sum(p * dl[0])


def get_log_lambda_sBHAR(m, z, t, *, rng=None):
    """Draw one lambda value per object by inverse-CDF sampling of p(lambda).

    Parameters
    ----------
    m, z, t : iterable
        Per-object mass, redshift and population type (a key accepted by
        ``get_aird2018``), iterated in lockstep with ``zip``.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of uniform deviates, for reproducible draws. Defaults to the
        global ``np.random`` state.

    Returns
    -------
    numpy.ndarray
        One sampled value per object, interpolated on the bin's tabulated
        lambda grid (presumably log10 lambda_sBHAR, per the function name).
        Draws outside the CDF range are returned as ``-inf``.
    """
    draw = np.random.rand if rng is None else rng.random
    l_all = []
    for _m, _z, _t in zip(m, z, t):
        # get_plambda returns a (2, N) array: row 0 is the lambda grid and
        # row 1 is p(lambda). Unpack it so the CDF is built from p alone and
        # has the same length as the grid passed to np.interp.
        l_vec, p = get_plambda(_m, _z, _t)
        # The cumulative sum must run in ascending lambda order for the CDF
        # to map onto the grid; sort in case table rows are not ordered.
        order = np.argsort(l_vec)
        l_vec, p = l_vec[order], p[order]
        cdf = np.cumsum(p) * np.diff(l_vec)[0]
        # Draws above cdf[-1] (mass not covered by the table) map to -inf.
        # NOTE(review): draws below cdf[0] also map to -inf, which drops the
        # probability mass of the lowest lambda bin -- confirm this is intended.
        l_all.append(
            np.interp(draw(), cdf, l_vec, left=-np.inf, right=-np.inf)
        )
    return np.array(l_all)
