#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2024-06-14

"""
Lightcurve models for the different classes of simulated objects: AGN,
Cepheids, long-period variables (LPVs) and eclipsing binaries
"""

import astropy.units as u
import numpy as np
import os

from my_lsst import filter_lam_fwhm
import util


def get_lightcurve():
    """Placeholder generic lightcurve getter; always yields zero."""
    placeholder = 0
    return placeholder


from AGN_lightcurves_simulations.AGN_sim.DRW_sim import LC2
def get_lightcurve_agn(mjd0, mjd, flux, band, z, mag_i, logMbh, type2, T=3653, deltatc=1, seed=None):
    """
    Simulate AGN lightcurves with LC2 and sample them at the observed epochs

    Input:

        mjd0:     reference mjd, i.e. the start of the simulated time grid
        mjd:      mjd of the observations; mjd - mjd0 must lie within the
                  span of np.arange(0, T, deltatc)
        flux:     1-D array of target mean fluxes, one per simulated
                  lightcurve row (it is broadcast as flux[:, None])
        band:     filter name, an optional "lsst-" prefix is stripped before
                  looking up the filter wavelength
        z:        redshift, also used to convert the filter wavelength to the
                  rest frame
        mag_i:    passed through to LC2
        logMbh:   passed through to LC2
        type2:    passed through to LC2
        T:        length of the simulated grid in units of mjd
                  (default 3653, i.e. about 10 yr)
        deltatc:  step of the simulated grid in units of mjd
        seed:     seed for numpy's global random state, set right before LC2
                  is called

    Output:

        2-D array with one row per simulated lightcurve, each linearly
        interpolated at mjd - mjd0
    """

    # Requested epochs relative to the start of the simulated grid
    dt = mjd - mjd0
    grid = np.arange(0, T, deltatc)
    assert np.all((grid.min() <= dt) * (dt <= grid.max()))

    # Filter wavelength converted by the um -> Angstrom factor (assumes the
    # table stores um); divided by (1 + z) below to get the rest frame
    filter_name = band.replace("lsst-", '')
    lam_obs = filter_lam_fwhm[filter_name][0] * u.um.to(u.angstrom)

    # Seed numpy's global state; presumably LC2 draws from it (TODO confirm)
    np.random.seed(seed)

    t_sim, mag_sim = LC2(
        T=T,
        deltatc=deltatc,
        z=z,
        mag_i=mag_i,
        logMbh=logMbh,
        type2=type2,
        lambda_rest=lam_obs / (z + 1),
        frame="observed",
        oscillations=False,
        A=0.14,
        noise=0.00005,
    )[:2]

    # Rescale every row so that its mean over the simulated span equals the
    # corresponding entry of flux
    lc_flux = util.mag_to_flux(mag_sim)
    lc_flux /= np.mean(lc_flux, axis=1)[:, None]
    lc_flux *= flux[:, None]

    sampled = [np.interp(dt, t_sim, row) for row in lc_flux]
    return np.array(sampled)


def get_lightcurve_cepheid(mjd0, mjd, flux, band, z, y, logte, logl, mass, pmode, period, seed=None):
    """
    Return the lightcurve of a Cepheid at the requested epochs

    Input:

        mjd0:     reference mjd
        mjd:      mjd of the observation
        flux:     flux of the star, stored as the "<band>mag" magnitude of the
                  source description passed to the template code
        band:     name of the filter, its last character selects the
                  magnitude key e.g. "lsst-r" -> "rmag"
        z, y, logte, logl, mass:
                  parameters passed through to the template code
        pmode:    pulsation mode, selects the "period<pmode>" key
        period:   pulsation period in units of mjd
        seed:     random number seed for the phase offset

    Output:

        flux at mjd, periodically interpolated from the template lightcurve
    """

    source = {
        "z": z,
        "y": y,
        "logte": logte,
        "logl": logl,
        "mass": mass,
        "pmode": pmode,
        "period%d" % pmode: period,
        f"{band[-1]}mag": util.flux_to_mag(flux),
    }

    from CEP_lsst_quasar_project.get_lightcurve import get_lightcurve_normalized_cepheid
    time, lc_flux = get_lightcurve_normalized_cepheid(source, band)

    # NOTE: add a random phase to the lightcurve. The template is first
    # shifted to start at zero and only then offset by the phase; offsetting
    # before subtracting time.min() would cancel the phase exactly. A new
    # array is built so the template returned above is not mutated.
    np.random.seed(seed)
    phase = np.random.uniform(0, period)
    time = np.asarray(time, dtype=float)
    time = time - time.min() + phase

    # With period given, np.interp wraps xp modulo period and sorts it, so
    # the phase-shifted template is handled correctly
    return np.interp(mjd - mjd0, time, lc_flux, period=period)


def get_lightcurve_lpv(mjd0, mjd, flux, band, pmode, period, is_c_rich, seed=None):
    """
    Return the lightcurve of a long-period variable (LPV) at the requested epochs

    The lightcurve is a sine in magnitude space,

        mag(t) = A * sin(2 * pi * (t + phase) / period) + B,

    where t = mjd - mjd0, phase is drawn uniformly from [0, period), A is the
    band amplitude (see get_amplitude_lpv) and B is the magnitude of flux.

    Input:

        mjd0:       reference mjd
        mjd:        mjd of the observation (scalar or array)
        flux:       mean flux of the star
        band:       name of the filter, its last character selects the filter
                    wavelength e.g. "lsst-r" -> "r"
        pmode:      pulsation mode (not used by this function)
        period:     pulsation period in units of mjd
        is_c_rich:  whether the star is carbon-rich, selects the amplitude
                    ratio column of data/lpv/table1.dat
        seed:       random number seed for the phase offset

    Output:

        flux at mjd. Stars with period > 700 get a constant flux.
    """

    # Initialize the return variable for easy short-circuiting. The dtype is
    # forced to float: with an integer mjd array, full_like would otherwise
    # truncate the flux.
    ret = np.full_like(mjd, flux, dtype=float)

    # NOTE: Refer to Michele email
    if period > 700:
        return ret

    def get_amplitude_lpv():
        """Return the sine amplitude A, added to a magnitude below"""
        lam = filter_lam_fwhm[band[-1]][0]
        # NOTE: path is relative to the current working directory and the
        # table is re-read on every call
        t = np.loadtxt("data/lpv/table1.dat")

        # R = amplitude_band / amplitude_iband
        std_iband = np.power(10, 1.5 * np.log10(period) - 4)

        idx_r = 3 if is_c_rich else 1
        R = np.interp(lam, t[:, 0], t[:, idx_r])
        # NOTE(review): a sine of amplitude A has a std of A / sqrt(2); confirm
        # whether std_iband is meant to be used directly as the amplitude
        return R * std_iband

    # NOTE:
    #   Implement a sine curve with
    #       1) period given by period + random phase
    #       2) amplitude given by the std (see Michele email)
    #       3) normalization given by the magnitude of the star
    #
    #       lc(phi) = A * np.sin(phi) + B,
    #
    #   where phi = (mjd - mjd0 + phase) / period * 2 * np.pi
    #
    #   The sine is evaluated directly at the requested epochs. Evaluating it
    #   on a phase-shifted grid and interpolating on that same grid would
    #   cancel the random phase.
    np.random.seed(seed)
    phase = np.random.uniform(0, period)
    A = get_amplitude_lpv()
    B = util.flux_to_mag(flux)
    phi = 2 * np.pi * (np.asarray(mjd, dtype=float) - mjd0 + phase) / period
    return util.mag_to_flux(A * np.sin(phi) + B)


import batman
def get_lightcurve_binary(mjd0, mjd, flux1, flux2, radius1, radius2, p, a, i, e, t0, t1, seed=None):

    """
    Return the primary and secondary lightcurves for the binary system

    Input:

        mjd0:     reference mjd
        mjd:      mjd of the observation
        flux1:    flux of the primary star
        flux2:    flux of the secondary star in units of flux1
        radius1:  radius of the primary star in Rsun
        radius2:  radius of the secondary star in Rsun
        p:        period of the system in units of mjd
        a:        orbital semi-major axis of the system in Rsun
        i:        orbital inclination of the system in degrees e.g. 90 = edge-on
        e:        orbital eccentricity of the system
        t0:       time of the primary transit in units of mjd e.g. period * 0.25
        t1:       time of the secondary in units of mjd e.g. period * 0.75
        seed:     random number seed to offset the phase curve by a random
                  time drawn from [0, p); if None, no offset is applied

    Output:

        lc1, lc2: lightcurves of the primary and the secondary. The total
                  lightcurve of the system is then lc1 + lc2. In the same units
                  of flux1 and flux2.
    """

    assert flux1 >= flux2
    assert radius1 >= radius2

    # NOTE: there is some weird feature in batman with radius1 == radius2 which
    # is prevented by adding a very minor offset to the secondary radius
    if np.isclose(radius1, radius2):
        radius2 -= radius2 / 1e6

    param = batman.TransitParams() # object to store transit parameters
    param.t0 = t0                  # time of inferior conjunction
    param.per = p                  # orbital period in days
    param.rp = radius2 / radius1   # planet radius (in units of stellar radii)
    param.a = a / radius1          # semi-major axis (in units of stellar radii)
    param.inc = i                  # orbital inclination (in degrees)
    param.ecc = e                  # eccentricity
    param.w = 90.                  # longitude of periastron (in degrees)
    param.limb_dark = "quadratic"  # limb darkening model
    param.u = [0.10, 0.30]         # quadratic limb darkening coefficients [u1, u2]
    param.fp = flux2 / flux1       # planet to star flux ratio
    param.t_secondary = t1         # central eclipse time

    # Float times relative to the reference epoch (a fresh array, so the
    # caller's mjd is never modified)
    t = np.asarray(mjd, dtype=float) - mjd0
    if seed is not None:
        np.random.seed(seed)
        # The offset is drawn from [0, p); np.random.uniform(p) alone would
        # mean low=p, high=1.0
        t = t + np.random.uniform(0, p)

    model1 = batman.TransitModel(param, t, transittype="primary")
    model2 = batman.TransitModel(param, t, transittype="secondary")

    lc1 = model1.light_curve(param) * (flux1 + flux2) - flux2
    lc2 = model2.light_curve(param) * flux1 - flux1

    return lc1, lc2
