#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-07-12 09:32:38

"""
General utilities
"""


import argparse
import glob
import math
import os
import random
import re
import subprocess
import sys
import time

import astropy as ap
import astropy.coordinates as c
import astropy.units as u
import fitsio
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp

from astropy.cosmology import FlatLambdaCDM
from scipy.stats import binned_statistic
from astropy import constants

# Absolute, machine-specific project root directory. It is not referenced by
# the code visible in this file; presumably used by importing scripts. A
# hard-coded path like this will generally not exist on other machines.
ROOT = "/staff2/viitanen/agile/"

def flux_to_mag(flux):
    """
    Convert a flux density in microjanskys (uJy) to an AB magnitude.

    Non-positive fluxes map to NaN. The result is squeezed, so a scalar input
    comes back as a 0-d array.
    """
    is_positive = np.atleast_1d(flux) > 0
    # 1e-6 converts uJy -> Jy and 3631 Jy is the AB zero point; np.ma.log10
    # masks the non-positive entries, which np.where then replaces with NaN
    mag_if_positive = -2.5 * np.ma.log10(flux * 1e-6 / 3631)
    mag = np.where(is_positive, mag_if_positive, np.nan)
    return mag.squeeze()

def mag_to_flux(mag):
    """
    Convert an AB magnitude to a flux density in microjanskys (uJy).

    This is the inverse of flux_to_mag: 3631 Jy is the AB zero point and the
    factor 1e6 converts Jy to uJy (so AB mag 23.9 corresponds to ~1 uJy).
    """
    return 3631 * 1e6 * 10 ** (mag / -2.5)

def get_volume(zmin, zmax, area_deg2, H0=70., Om0=0.30, Tcmb0=2.73):
    """
    Return the comoving volume in Mpc^3 of the redshift shell zmin < z < zmax
    covering ``area_deg2`` square degrees of sky, in a flat LambdaCDM
    cosmology with the given H0, Om0 and Tcmb0.
    """
    cosmo = FlatLambdaCDM(H0=H0, Om0=Om0, Tcmb0=Tcmb0)

    # Full-sky comoving volume of the shell (astropy returns it in Mpc^3)
    shell_volume = (cosmo.comoving_volume(zmax) - cosmo.comoving_volume(zmin)).value

    # Scale by the sky fraction: the full sky is 4 pi sr (~41253 deg^2)
    full_sky_deg2 = 4 * np.pi * u.sr.to(u.deg ** 2)
    return shell_volume * area_deg2 / full_sky_deg2


def get_chisq_nu(y, y_model, sigma, n_params=1):
    """
    Return the reduced chi-square of the measurements ``y`` against ``y_model``:

        chi2_nu = sum(((y - y_model) / sigma) ** 2) / (N - n_params)

    where N = y.size. ``n_params`` is the number of fitted model parameters;
    the default of 1 reproduces the previous hard-coded N - 1 degrees of
    freedom.
    """
    residuals = (y - y_model) / sigma
    dof = y.size - n_params
    return np.sum(residuals ** 2) / dof


def get_key_function(bins, x, values=None, *args, **kwargs):
    """
    Return the "key function", where the key is e.g. the stellar mass or the
    X-ray luminosity: the (weighted) number of objects per unit comoving
    volume per unit of ``x``.

    The function is returned in units of 1/Mpc3/dex for the given cosmology
    (when ``bins`` are in dex).

    Parameters
    ----------
    bins : array_like
        Equally spaced bin edges (checked with an assert).
    x : ndarray
        Key value of each object.
    values : ndarray, optional
        Per-object weights summed in each bin; defaults to 1 per object.
        Note it precedes ``*args``, so positional volume arguments require
        ``values`` to be given first.
    *args, **kwargs
        Forwarded to get_volume (zmin, zmax, area_deg2, H0, Om0, Tcmb0).

    Returns
    -------
    centers, half_widths, function, dfunction
        Bin centers, half bin widths (x-errors), the key function and its
        Poisson error. An empty ``x`` yields zeros for the last two.
    """

    # The binning
    dbins = np.diff(bins)
    assert np.allclose(dbins[0], dbins[1:])
    centers = bins[:-1] + dbins / 2
    half_widths = dbins / 2

    # Short-circuit for an empty array; return the same half widths as the
    # regular path below so both branches have the same return contract
    if not x.size:
        return centers, half_widths, np.zeros_like(centers), np.zeros_like(centers)

    # NOTE: assume poissonian errors on the counts, no cosmological variance etc.
    if values is None:
        values = np.ones_like(x)

    # Sum of the weights falling in each bin
    counts = binned_statistic(x=x, values=values, statistic="sum", bins=bins)[0]

    # Comoving volume times the (constant) bin width
    div = get_volume(*args, **kwargs) * dbins[0]

    function = counts / div
    # NOTE(review): for non-unit ``values`` the usual weighted Poisson error
    # per bin is sqrt(sum(values**2)), not sqrt(sum(values)) -- confirm intent.
    dfunction = counts ** .5 / div
    return centers, half_widths, function, dfunction

def egg_band_to_index(egg, band):
    """
    Return the position of ``band`` among the band names in ``egg["BANDS"][0]``.

    Names are compared after stripping surrounding whitespace; an unknown
    band raises ValueError (from list.index).
    """
    raw_names = egg["BANDS"][0]
    stripped_names = [name.strip() for name in raw_names]
    position = stripped_names.index(band)
    return position


import astropy.units as u
from astropy.time import Time
from astropy.coordinates import SkyCoord
def get_ra_dec(ra0, dec0, pm_ra_cosdec, pm_dec, mjd, mjd0=51544.5):
    """
    Propagate sky positions from epoch ``mjd0`` to epoch ``mjd`` using proper
    motions (SkyCoord.apply_space_motion).

    ra0, dec0 are in degrees; pm_ra_cosdec and pm_dec are in mas/yr, and any
    non-finite proper motion is treated as zero. mjd and mjd0 are Modified
    Julian Dates (the default mjd0 = 51544.5 is J2000.0). They are formatted
    to strings before being handed to astropy Time, so scalar epochs are
    expected.

    Returns the propagated (ra, dec) as plain values in degrees.
    """
    mas_per_yr = u.mas / u.yr

    # Missing proper motions -> no motion
    pm_ra_clean = np.where(np.isfinite(pm_ra_cosdec), pm_ra_cosdec, 0.0)
    pm_dec_clean = np.where(np.isfinite(pm_dec), pm_dec, 0.0)

    epoch_start = Time(f"{mjd0}", format="mjd")
    epoch_end = Time(f"{mjd}", format="mjd")

    start = SkyCoord(
        ra=ra0,
        dec=dec0,
        unit=u.deg,
        pm_ra_cosdec=pm_ra_clean * mas_per_yr,
        pm_dec=pm_dec_clean * mas_per_yr,
        obstime=epoch_start,
    )
    moved = start.apply_space_motion(new_obstime=epoch_end)
    return moved.ra.value, moved.dec.value


def convert_flux(S1, E1_min=2, E1_max=10, E2_min=2, E2_max=7, Gamma=1.9):
    """
    Convert flux S1 from bandpass [E1_min, E1_max] to bandpass
    [E2_min, E2_max], assuming a power-law spectrum with photon index Gamma
    (N(E) ~ E**-Gamma).

    The energy flux in a band scales as (E_max**(2-Gamma) - E_min**(2-Gamma))
    / (2-Gamma), which tends to ln(E_max / E_min) as Gamma -> 2. That limit is
    handled explicitly; the generic ratio would be 0/0 = NaN there. The band
    energies only need to share one consistent unit.
    """
    idx = 2 - np.asarray(Gamma, dtype=float)
    # Both branches are evaluated by np.where; silence the 0/0 of the
    # power-law branch at Gamma == 2, which is discarded anyway
    with np.errstate(divide="ignore", invalid="ignore"):
        power_law_ratio = np.true_divide(
            E2_max ** idx - E2_min ** idx,
            E1_max ** idx - E1_min ** idx,
        )
        log_ratio = np.log(E2_max / E2_min) / np.log(E1_max / E1_min)
    return S1 * np.where(idx == 0, log_ratio, power_law_ratio)


def get_log_L_2_keV(log_LX_2_10, Gamma=1.9, wavelength=6.2):
    """
    Return log10 of the monochromatic X-ray luminosity L_nu (erg/s/Hz) at
    ``wavelength`` (Angstrom; the default 6.2 A is 2 keV), to be used for
    alpha_ox.

    The rest-frame 2-10 keV luminosity 10**log_LX_2_10 (erg/s) is distributed
    as a power law with photon index Gamma, L_lambda = K * lambda**(Gamma - 3),
    with K fixed by integrating L_lambda between 1.24 A (10 keV) and 6.2 A
    (2 keV). Then L_nu = L_lambda * lambda**2 / c = K * lambda**(Gamma - 1) / c.

    Gamma == 2 is handled through its logarithmic limit; the generic
    normalization is 0/0 there.
    """
    lam_2kev, lam_10kev = 6.2, 1.24  # band edges in Angstrom
    c_ang_per_s = 2.998e18           # speed of light in Angstrom/s

    # 10.0 rather than 10 so integer array inputs cannot overflow
    Lx = 10.0 ** log_LX_2_10

    idx = np.asarray(Gamma, dtype=float) - 2
    # integral of lambda**(Gamma - 3) over the band; np.where evaluates both
    # branches, so silence the discarded 0/0 at Gamma == 2
    with np.errstate(divide="ignore", invalid="ignore"):
        band_integral = np.where(
            idx == 0,
            np.log(lam_2kev / lam_10kev),
            (lam_2kev ** idx - lam_10kev ** idx) / idx,
        )
    K = Lx / band_integral
    return np.log10(K * wavelength ** (Gamma - 1) / c_ang_per_s)


def get_log_L_2500(log_L_2_keV, alpha=0.952, beta=2.138, scatter=True, scatter_dex=0.4):
    """
    Return log10 of the 2500 Angstrom monochromatic luminosity.

    Uses Lusso+10 eq. 5 (inverted), as implemented here:
        log_L_2_keV = alpha * log_L_2500 - beta
    so the result is in the same per-Hz units as ``log_L_2_keV`` (erg/s/Hz
    when it comes from get_log_L_2_keV).

    Parameters
    ----------
    log_L_2_keV : float or array_like
        log10 monochromatic luminosity at 2 keV.
    alpha, beta : float
        Coefficients of the relation above.
    scatter : bool
        If True, add Gaussian scatter drawn from the global ``np.random`` state.
    scatter_dex : float
        Standard deviation of that scatter in dex.

    Returns
    -------
    The log luminosity. With scatter, the result has the input's shape, and
    scalar inputs come back as a length-1 array.
    """

    log_L_2500 = (log_L_2_keV + beta) / alpha
    # Sanity check of the inversion; NaN inputs (e.g. missing values) pass through
    assert np.allclose(alpha * log_L_2500 - beta, log_L_2_keV, equal_nan=True)

    # TODO: implement realistic scatter
    if scatter:
        # Match the input shape; np.shape(scalar) == () -> one draw, as before
        size = np.shape(log_L_2500) or 1
        log_L_2500 += np.random.normal(loc=0, scale=scatter_dex, size=size)

    # NOTE: (10 ** log_L_2500) * 2.998e18 / 2500 would give nu * L_nu at
    # 2500 A, in erg/s.
    return log_L_2500


import sed
from astropy.cosmology import FlatLambdaCDM
# Module-level flat LambdaCDM cosmology with H0=70 (astropy reads a bare
# number as km/s/Mpc) and Om0=0.30. Not referenced by the code visible in
# this file; presumably used by importers.
# NOTE(review): Tcmb0 is not passed, so astropy's default applies (0 K, i.e.
# no radiation component, to my knowledge), whereas get_volume defaults to
# Tcmb0=2.73 -- confirm the two cosmologies are meant to differ.
COSMO = FlatLambdaCDM(Om0=0.30, H0=70)
def mock_lx_to_M_1450(lx, z, distance_cm, ebv, scatter, seed):
    """
    Mock the 1450 Angstrom magnitude of a type 1 AGN from its X-ray luminosity.

    ``lx`` (passed to get_log_L_2_keV as log10 of the 2-10 keV luminosity) is
    converted to log L(2 keV) and then to log L(2500 A), with optional random
    scatter. Both are given to sed.get_sed on a narrow 1449-1451 A grid. The
    resulting SED is interpolated at 1450 and converted with flux_to_mag.

    NOTE(review): the scatter in get_log_L_2500 uses the global np.random
    state, not ``seed``. Whether the result is an absolute magnitude depends
    on the ``distance_cm`` the caller passes. The interpolation at 1450
    assumes get_sed returns an increasing wavelength grid in Angstrom --
    confirm against sed.get_sed.
    """
    log_l_2kev = get_log_L_2_keV(lx)
    log_l_2500 = get_log_L_2500(log_l_2kev, scatter=scatter)

    # Only a narrow window around 1450 A is needed for the interpolation
    sed_result = sed.get_sed(
        LogL2500=log_l_2500,
        LogL2kev=log_l_2kev,
        AGN_type=1,
        ebv=ebv,
        redshift=z,
        distance_cm=distance_cm,
        flux_rf_1000_4000_gal=np.inf,
        seed=seed,
        wav_min=1449 * u.angstrom,
        wav_max=1451 * u.angstrom,
        dlog_wav=1e-4,
    )
    wavelengths, fluxes = sed_result[1]
    return flux_to_mag(np.interp(1450, wavelengths, fluxes))

def mock_lx_to_sed(lx, z, distance_cm, ebv, scatter, seed):
    """
    Mock the SED of a type 1 AGN from its X-ray luminosity.

    Same luminosity chain as mock_lx_to_M_1450 (get_log_L_2_keV, then
    get_log_L_2500 with optional scatter), but sed.get_sed is called with its
    default wavelength grid. The (wav, flux) pair stored in element [1] of
    its result is returned.

    NOTE(review): the scatter uses the global np.random state, not ``seed``.
    """
    log_l_2kev = get_log_L_2_keV(lx)
    log_l_2500 = get_log_L_2500(log_l_2kev, scatter=scatter)

    sed_result = sed.get_sed(
        LogL2500=log_l_2500,
        LogL2kev=log_l_2kev,
        AGN_type=1,
        ebv=ebv,
        redshift=z,
        distance_cm=distance_cm,
        flux_rf_1000_4000_gal=np.inf,
        seed=seed,
    )
    wavelengths, fluxes = sed_result[1]
    return wavelengths, fluxes

def get_fraction_obscured(lx, data_dir="data/merloni2014"):

    """
    Return the Merloni+2014 "21_12" curve evaluated at X-ray luminosity ``lx``.

    The curve is read from ``<data_dir>/21_12.csv`` (two whitespace-separated
    columns: lx, value). It is linearly interpolated with np.interp and
    clamped to 95 below / 15 above the tabulated range. These clamp values
    suggest the tabulated quantity is a percentage -- confirm against the file.
    The default ``data_dir`` is relative, so it resolves against the current
    working directory.

    Merloni+2014 AGN types are two-character codes where the first (second)
    character refers to optical (X-ray) obscuration, 1 = unobscured and
    2 = obscured:
        22  opt + X-ray obscured
        21  opt obscured, X-ray unobscured
        12  opt unobscured, X-ray obscured
        11  opt + X-ray unobscured
    The 21_12 file is presumably the boundary between the 21 and 12
    classes -- TODO confirm. The neighbouring 22_21 and 12_11 boundary files
    were previously loaded as well but never used, so they are no longer read.
    """

    # ndmin=2 keeps the (rows, columns) layout even for a single-row file
    table = np.loadtxt(os.path.join(data_dir, "21_12.csv"), ndmin=2)
    lx_grid, value_grid = table.T
    return np.interp(lx, lx_grid, value_grid, left=95, right=15)


def get_E_BV(
    type2=False,
    alpha_1=7.93483055,
    n_1=2.97565676,
    alpha_2=11.6133635,
    n_2=1.42972,
    type_1_ebv=np.linspace(0, 1, 101),
    type_2_ebv=np.linspace(0, 3, 301),
    mu_type_2=0.3,
    n_samples=1,
):
    """
    Draw E(B-V) reddening value(s) for type 1 (default) or type 2 AGN.

    Values are drawn by inverse-transform sampling, using the global
    ``np.random`` state, from the Hopkins+04-style distribution
    p(E) ~ 1 / (1 + (alpha * E)**n) tabulated on ``type_1_ebv`` (with
    alpha_1, n_1) or ``type_2_ebv`` (with alpha_2, n_2). Type 2 draws are
    then shifted by ``mu_type_2``.

    The default grids are shared module-level arrays; they are only read here.

    Returns
    -------
    A 0-d array for n_samples=1 (the previous behaviour), otherwise an array
    of shape (n_samples,).
    """

    def hopkins04(x, alpha, n):
        """Unnormalized p(E_BV); the sampler normalizes its CDF."""
        return 1 / (1 + (x * alpha) ** n)

    def sample_ebv(N_AGN, probability_distribution, ebv_range, *args):
        """Inverse-transform sample N_AGN values from a pdf tabulated on ebv_range."""
        pdf = probability_distribution(ebv_range, *args)
        # Trapezoid CDF anchored at 0 on the first grid point. A plain cumsum
        # starts at pdf[0] > 0, so every uniform draw below that value would
        # land exactly on ebv_range[0] (a spurious point mass of several %).
        cdf = np.concatenate(
            ([0.0], np.cumsum(0.5 * (pdf[1:] + pdf[:-1]) * np.diff(ebv_range)))
        )
        cdf /= cdf[-1]
        return np.interp(np.random.rand(N_AGN), cdf, ebv_range)

    if type2:
        ebv = sample_ebv(n_samples, hopkins04, type_2_ebv, alpha_2, n_2) + mu_type_2
    else:
        ebv = sample_ebv(n_samples, hopkins04, type_1_ebv, alpha_1, n_1)
    return np.squeeze(ebv)


from astropy import constants
from igm_absorption import my_get_IGM_absorption
def luminosity_to_flux(wavlen, luminosity, redshift, distance_in_cm):

    """
    Convert a rest-frame SED to the observed frame: returns the observed
    wavelength in um and the flux density in uJy.

    Parameters
    ----------
    wavlen : rest-frame wavelength in Angstrom (converted with
        u.angstrom.to(u.um) on output).
    luminosity : documented by the author as "erg/s/ang" -- see the
        NOTE(review) below.
    redshift : scalar redshift (it is used directly in ``==`` / ``>`` tests).
    distance_in_cm : distance in cm. There is no default value; when
        redshift == 0 it is asserted to equal 10 pc.

    Returns
    -------
    (wavlen_observed_um, flux_uJy). For redshift > 0 the flux includes the
    IGM transmission from my_get_IGM_absorption.

    NOTE(review): the formula below evaluates
        luminosity * wavlen_observed / (4 pi d^2 c)
    converted to uJy. That is a flux density only if ``luminosity`` is
    lambda*L_lambda (erg/s) in the rest frame; for L_lambda in erg/s/Angstrom
    an extra factor of the rest-frame wavelength would be needed. It is also
    presumably meant to receive a luminosity distance. Confirm which
    convention the callers supply.
    """

    # redshift 0 is only allowed with a 10 pc distance (presumably the
    # absolute-magnitude convention)
    if redshift == 0:
        assert np.isclose(distance_in_cm, (10 * u.pc).to(u.cm).value)

    # Wavlen in angstrom and to observed frame
    wavlen_observed = wavlen * (1 + redshift)

    # To uJy in observed frame:
    #   log10[luminosity * wavlen_observed / (c * 4 pi * d^2)], with c in
    #   Angstrom/s plus the erg/s/cm^2/Hz -> uJy conversion factor
    log_flux = (
        np.log10(luminosity)
        + np.log10(wavlen_observed)
        + np.log10((u.erg / u.s / u.cm ** 2 / u.Hz).to(u.uJy))
        - np.log10(constants.c.to(u.angstrom / u.s).value)
        - np.log10(4 * np.pi)
        - 2 * np.log10(distance_in_cm)
    )

    # Add IGM
    if redshift > 0.0:
        # multiplicative transmission evaluated on the observed wavelengths
        t_igm = my_get_IGM_absorption(redshift, lambda_obs=wavlen_observed)
        log_flux += np.log10(t_igm)

    # NOTE: everything is now in EGG units. Wavlen in um, flux in uJy
    return wavlen_observed * (u.angstrom.to(u.um)), 10 ** log_flux


def get_flux_band(lam, flux, band):
    """
    Return the filter-weighted mean flux of an SED in ``band``.

    ``lam`` (microns) and ``flux`` (uJy) sample the SED. The filter response
    from get_band_egg(band) (first row of its "LAM" / "RES" columns) is
    linearly interpolated onto ``lam``, set to zero outside the filter's
    range, and normalized to unit area with the trapezoid rule. The result is

        flux_band = integral(flux * R dlam) / integral(R dlam)

    The integer 0 is returned when the filter does not overlap ``lam``.

    NOTE(review): get_band_egg is not defined or imported in the code shown
    here; confirm it is provided elsewhere, otherwise this raises NameError.
    """
    filter_table = get_band_egg(band)
    response = np.interp(
        lam, filter_table["LAM"][0], filter_table["RES"][0], left=0, right=0
    )
    area = np.trapz(response, lam)

    if area == 0:
        # no overlap between the SED wavelengths and the filter
        return 0

    return np.trapz(flux * (response / area), x=lam)


def get_log_y_lo_hi(y, dy, null=99):
    """
    Convert values with symmetric errors into log10 values with asymmetric
    errors.

    Parameters
    ----------
    y, dy : array_like or scalar
        Values and their symmetric errors.
    null : float
        Placeholder stored as the lower error wherever ``y - dy <= 0``, since
        the lower bound has no logarithm there.

    Returns
    -------
    (log10(y), log10(y / (y - dy)), log10((y + dy) / y))
        The value and its lower / upper errors in dex.
    """
    y = np.asarray(y)
    dy = np.asarray(dy)

    y0 = np.log10(y)
    # The lower error is undefined where y - dy <= 0 and is replaced by
    # ``null`` below, so the divide/log warnings there are expected
    with np.errstate(divide="ignore", invalid="ignore"):
        y1 = np.log10(y / (y - dy))
    y2 = np.log10((y + dy) / y)

    y1 = np.where(y - dy <= 0.0, null, y1)
    return y0, y1, y2
