#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-02-13 16:20:51

"""
Create Galaxy+AGN mocks for the LSST Italian AGN in-kind contribution
"""

import argparse
from copy import deepcopy
from itertools import product
import glob
from multiprocessing import Pool, cpu_count
import os
import subprocess
import sys
import time
import re

from astropy import constants
from astropy.cosmology import FlatLambdaCDM
from astropy.time import Time
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
from scipy.stats import binned_statistic
import astropy.units as u
import fitsio
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import interp1d
import pandas as pd
import emcee
import sqlite3
import pandas as pd
import pyvo as vo

from plambda import Plambda
import util
from AGN_lightcurves_simulations.AGN_sim.DRW_sim import LC2
from my_lsst import filter_lam_fwhm
import my_lsst
import ueda2014
import bongiorno2016
from merloni2014 import Merloni2014
import sed
import smf
from mock_catalog_SED.qsogen_4_catalog import qsosed
import egg
import lightcurve

from util import ROOT

# Build a lookup table band name -> (float(items[4]), float(items[8])) by
# parsing the text printed by `egg-gencat list_bands`.
# NOTE: this runs at import time and requires the `egg-gencat` executable;
# subprocess.check_output raises if the command cannot be started or exits
# with a non-zero status.
# NOTE(review): judging by the dictionary name, fields 4 and 8 are presumably
# the reference wavelength and the FWHM of each band -- confirm against the
# actual `egg-gencat list_bands` output format.
BAND_LAM_FWHM = {}
ret = str(subprocess.check_output(["egg-gencat", "list_bands"]), encoding="utf-8")
for line in ret.split('\n'):
    line = line.strip()
    # Only lines starting with '-' are parsed; all other lines are skipped
    if not line.startswith('-'):
        continue
    # Whitespace-separated fields: items[1] is used as the band name (dict key)
    items = line.strip().split()
    BAND_LAM_FWHM[items[1]] = float(items[4]), float(items[8])

# Optimization: store the pre-computed AGN SEDs and lightcurves
# NOTE(review): these module-level dictionaries start empty and are not
# filled anywhere in this part of the file; presumably they are populated by
# code further down -- confirm their key/value layout there.
#SEDS = {}
FLUXES = {}
LBOL = {}
LIGHTCURVES = {}

# Pre-load the posterior distribution for the AGN SED
# NOTE: the file is read at import time (space-separated table); importing
# this module fails if `{ROOT}/data/posterior_2024_05_20.dat` is missing.
#print("Reading in the posterior distribution...")
#POSTERIOR_DISTRIBUTION = pd.read_csv(f"{ROOT}/posterior.dat", sep=' ')
POSTERIOR_DISTRIBUTION = pd.read_csv(f"{ROOT}/data/posterior_2024_05_20.dat", sep=' ')
# Column labels of the posterior table (a pandas Index)
PARAMETER_NAMES = POSTERIOR_DISTRIBUTION.columns

# LSST magnitude limits for 30s exposure and stacked 10yr
from my_lsst import limiting_magnitude_30s, limiting_magnitude_10yr
from igm_absorption import my_get_IGM_absorption

# TODO: something more fancy
#ROOT = os.getcwd()

def get_band_egg(band):
    """
    Read the EGG filter file registered for a band.

    The index file ``{ROOT}/egg/share/filter-db/db.dat`` is scanned line by
    line for an entry of the form ``name=filename``. The first entry whose
    name equals ``band`` is read with ``fitsio.read`` and returned.

    Parameters
    ----------
    band : str
        Band name. Every occurrence of the substring "magabs_" is removed
        before the lookup, so absolute-magnitude column names can be passed.

    Returns
    -------
    The result of ``fitsio.read`` for the matching filter file, or ``None``
    when no entry matches.
    """
    band = band.replace("magabs_", "")
    db_dir = f"{ROOT}/egg/share/filter-db"
    with open(f"{db_dir}/db.dat") as f:
        for line in f:
            # partition() never raises: blank or malformed lines (no '=')
            # are skipped instead of aborting the lookup with a ValueError
            name, separator, filename = line.strip().partition('=')
            if not separator:
                continue
            if name == band:
                return fitsio.read(f"{db_dir}/{filename}")
    return None


class CatalogGalaxyAGN:

    def __init__(
        self,
        filename_egg="egg.fits",
        filename_agn="agn.fits",
        filename_stars="star.fits",
        filename_binaries="binary.fits",
        type_plambda="aird+2018",
        no_classification=False,
        seed=20230413,
        filename_catalog=None,
        fetch_stars=True,
        save_sed=False,
        seds_file=None,
        dm=0.10,
        dz=0.10,
        dl=0.10,
        merloni2014_interpolate=0,
        merloni2014_extrapolate=0,
        merloni2014_f_obs_minimum=0.05,
        merloni2014_f_obs_maximum=0.95
    ):
        """
        Set up the catalog builder and load or create the catalog.

        Parameters
        ----------
        filename_egg : path of the EGG FITS file (read as given).
        filename_agn, filename_stars, filename_binaries : file names that are
            resolved relative to ROOT by the methods using them.
        type_plambda : name of the accretion-rate distribution model; the
            methods in this class handle "bongiorno+2016" and "zou+2024".
        no_classification, fetch_stars, save_sed : flags, stored as int.
        seed : if non-zero, used to seed the global NumPy random generator.
        filename_catalog : optional path of an existing catalog. When it is
            given and the file exists, the catalog is read from it and no new
            catalog is built.
        seds_file : stored as given.
        dm, dz, dl : bin widths used for the selections in M, Z and
            log_LX_2_10 (see _get_select).
        merloni2014_* : forwarded positionally to Merloni2014.
        """

        self.filename_egg = filename_egg
        self.filename_agn = filename_agn
        self.filename_stars = filename_stars
        self.filename_binaries = filename_binaries
        self.type_plambda = type_plambda
        self.no_classification = int(no_classification)
        self.seed = int(seed)
        self.filename_catalog = filename_catalog
        self.fetch_stars = int(fetch_stars)
        self.save_sed = int(save_sed)
        self.seds_file = seds_file
        self.dm = float(dm)
        self.dz = float(dz)
        self.dl = float(dl)

        # Optimization: save the selections in Z, M, L
        self.SZ = {}
        self.SM = {}
        self.SL = {}

        self.dirname = ROOT + "/" + os.path.dirname(self.filename_agn)

        self.merloni2014 = Merloni2014(
            int(merloni2014_interpolate),
            int(merloni2014_extrapolate),
            float(merloni2014_f_obs_minimum),
            float(merloni2014_f_obs_maximum),
        )

        self.star = None
        self.star_binary = None

        # The survey area is parsed from the "area=<number>" token found in
        # the first entry of the CMD column of the EGG file
        tmp = fitsio.read(self.filename_egg, columns="CMD")
        self.area = float(re.findall("area=([0-9.]+)", tmp[0])[0])
        if self.seed:
            np.random.seed(self.seed)

        # NOTE: short-circuit for an existing catalog. The None check is
        # required because filename_catalog defaults to None and
        # os.path.exists(None) raises a TypeError.
        if self.filename_catalog is not None and os.path.exists(self.filename_catalog):
            self.egg = {
                k: fitsio.read(self.filename_egg, columns=k) for k in
                ["BANDS", "LAMBDA", "RFBANDS", "RFLAMBDA"]
            }
            self.catalog = fitsio.read(self.filename_catalog)
            return

        self.egg = egg.Egg.read(self.filename_egg)
        self.catalog = self.get_catalog()

    def _get_select(self, z=None, m=None, l=None, t="all"):

        select = np.ones_like(self.catalog["Z"], dtype=bool)

        # Perform the selection in redshift and Mstar
        if z is not None:
            if z not in self.SZ:
                self.SZ[z] =  (z <= self.catalog["Z"]) * (self.catalog["Z"] < z + self.dz)
            select *= self.SZ[z]

        if m is not None:
            if m not in self.SM:
                self.SM[m] = (m <= self.catalog["M"]) * (self.catalog["M"] < m + self.dm)
            select *= self.SM[m]

        # Perform the selection in lx if requested
        if l is not None:
            if l not in self.SL:
                self.SL[l] = (l <= self.catalog["log_LX_2_10"]) * (self.catalog["log_LX_2_10"] < l + self.dl)
            select *= self.SL[l]

        # Perform the selection in type
        if t != "all":
            select *= self.catalog["PASSIVE"] == (t == "quiescent")

        return select

    def get_star(self, selection="rmag", maglim=28, fbin=1.0):
        """
        Return the star table covering the footprint of the EGG catalog.

        If ``ROOT/<filename_stars>`` already exists it is read with fitsio
        and returned. Otherwise the table ``lsst_sim.simdr2`` is queried
        through the NOIRLab Data Lab TAP service for all rows with
        ``selection < maglim`` inside the RA/Dec bounding box of the EGG
        positions, the result is written to that file and returned as the
        table object produced by the query.

        ``fbin`` is accepted but not used.
        """
        path = ROOT + "/" + self.filename_stars
        if os.path.exists(path):
            return fitsio.read(path)

        # Bounding box of the EGG positions
        ra, dec = self.egg["RA"][0], self.egg["DEC"][0]
        ra_min, ra_max = ra.min(), ra.max()
        dec_min, dec_max = dec.min(), dec.max()

        query = f"""
            SELECT * FROM lsst_sim.simdr2
            WHERE {selection} < {maglim}
            AND {ra_min} < ra AND ra < {ra_max}
            AND {dec_min} < dec AND dec < {dec_max}
            """
        print(query)

        # Synchronous TAP query
        service = vo.dal.TAPService("https://datalab.noirlab.edu/tap")
        table = service.search(query).to_table()

        # Every column whose name contains "mag" is converted from AB
        # magnitude to flux density in microjansky
        for name in table.columns:
            if "mag" not in name:
                continue
            table[name] = 10 ** ((table[name] - 8.90) / (-2.50)) * 1e6

        table.write(path)
        return table

    def get_star_binary(self, selection="c3_rmag", maglim=28, fbin=1.00):
        """
        Return the binary-star table covering the footprint of the EGG catalog.

        If ``ROOT/<filename_binaries>`` already exists it is read with fitsio
        and returned. Otherwise the table ``lsst_sim.simdr2_binary`` is
        queried through the NOIRLab Data Lab TAP service for all rows with
        ``selection < maglim`` inside the RA/Dec bounding box of the EGG
        positions, the result is written to that file and returned as the
        table object produced by the query.

        ``fbin`` is accepted but not used.
        """
        path = ROOT + "/" + self.filename_binaries
        if os.path.exists(path):
            return fitsio.read(path)

        # Bounding box of the EGG positions
        ra, dec = self.egg["RA"][0], self.egg["DEC"][0]
        ra_min, ra_max = ra.min(), ra.max()
        dec_min, dec_max = dec.min(), dec.max()

        query = f"""
            SELECT * FROM lsst_sim.simdr2_binary
            WHERE {selection} < {maglim}
            AND {ra_min} < ra AND ra < {ra_max}
            AND {dec_min} < dec AND dec < {dec_max}
            """
        print(query)

        # Synchronous TAP query
        service = vo.dal.TAPService("https://datalab.noirlab.edu/tap")
        table = service.search(query).to_table()

        # Every column whose name contains "mag" is converted from AB
        # magnitude to flux density in microjansky
        for name in table.columns:
            if "mag" not in name:
                continue
            table[name] = 10 ** ((table[name] - 8.90) / (-2.50)) * 1e6

        table.write(path)
        return table


    def get_catalog(self):
        """
        Build (or load) the galaxy+AGN catalog and optionally append stars.

        The galaxy+AGN part is cached in ``ROOT/<filename_agn>``: when that
        file is missing every column returned by get_columns() is filled, in
        that order, and the table is written to disk. The table is then read
        back from the file. Unless self.fetch_stars is false, single and
        binary stars are appended as extra rows and the total fluxes and
        LSST flags are recomputed for the enlarged table.

        Returns the catalog, which is also stored in self.catalog.
        """

        # Create the galaxy+AGN catalog
        self.catalog = np.empty_like(self.egg["RA"][0], dtype=self.get_dtype())

        filename = ROOT + "/" + self.filename_agn

        if not os.path.exists(filename):

            # column, type, description, unit
            # NOTE: the columns are filled in the order of get_columns();
            # getters read previously filled columns from self.catalog, so the
            # order matters. The loop variable `u` shadows the module-level
            # `u` (astropy.units) inside this method only.
            for c, t, d, u in self.get_columns():
                if (
                    (c in self.egg.keys()) or
                    (c.replace("magabs_", "") in [_.strip() for _ in self.egg["BANDS"][0]]) or
                    ("_disk" in c) or ("_bulge" in c)
                ):
                    # Quantities taken from the EGG file
                    self.catalog[c] = self.get_galaxy(c).astype(t)
                elif "_total" in c:
                    # Total flux / absolute magnitude of the object
                    self.catalog[c] = self.get_flux_total(
                        c.replace("magabs_", "").replace("_total", ""),
                        rest_frame="magabs" in c
                    )
                else:
                    # Everything else is dispatched through get_agn()
                    self.catalog[c] = self.get_agn(c)

            if not os.path.exists(os.path.dirname(filename)):
                os.makedirs(os.path.dirname(filename))

            fitsio.write(filename, self.catalog, clobber=True)

        # Always continue from the on-disk version of the table
        self.catalog = fitsio.read(filename)

        if not self.fetch_stars:
            print("Not fetching stars... returning catalog")
            return self.catalog

        # Add the stars
        # New rows are initialised with a NaN fill value; only the columns
        # listed below are copied from the star table (c1 -> catalog column c2)
        self.star = self.get_star()
        catalog_stars = np.full_like(self.catalog, np.nan, shape=len(self.star))
        for c1, c2 in [
            ("ra",       "RA"),
            ("dec",      "DEC"),
            ("mu0",      "D"),
            ("umag",     "lsst-u_point"),
            ("gmag",     "lsst-g_point"),
            ("rmag",     "lsst-r_point"),
            ("imag",     "lsst-i_point"),
            ("zmag",     "lsst-z_point"),
            ("ymag",     "lsst-y_point"),
            ("pmracosd", "pmracosd"),
            ("pmdec",    "pmdec"),
        ]:
            catalog_stars[c2] = self.star[c1]

        # 20231205 AV: this line was missing?? why?? stars were never appended correctly?!
        self.catalog = np.append(self.catalog, catalog_stars)

        # Add the binary stars
        # Same procedure as for the single stars, using the "c3_*" magnitude
        # columns of the binary table
        self.star_binary = self.get_star_binary()
        catalog_stars_binary = np.full_like(self.catalog, np.nan, shape=len(self.star_binary))
        for c1, c2 in [
            ("ra",       "RA"),
            ("dec",      "DEC"),
            ("mu0",      "D"),
            ("c3_umag",  "lsst-u_point"),
            ("c3_gmag",  "lsst-g_point"),
            ("c3_rmag",  "lsst-r_point"),
            ("c3_imag",  "lsst-i_point"),
            ("c3_zmag",  "lsst-z_point"),
            ("c3_ymag",  "lsst-y_point"),
            ("pmracosd", "pmracosd"),
            ("pmdec",    "pmdec"),
        ]:
            catalog_stars_binary[c2] = self.star_binary[c1]
        self.catalog = np.append(self.catalog, catalog_stars_binary)

        # NOTE: dirty post-processing
        # Star rows get their row index as ID and all AGN flags cleared.
        # NOTE(review): get_is_star() is defined outside this part of the
        # file; presumably it returns a boolean mask of the star rows.
        is_star = self.get_is_star()
        self.catalog["ID"][is_star] = np.arange(self.catalog.size)[is_star]
        self.catalog["is_agn_ctn"][is_star] = False
        self.catalog["is_agn_ctk"][is_star] = False
        self.catalog["is_agn"][is_star] = False
        self.catalog["is_optical_type2"][is_star] = False

        # Recompute the total fluxes and absolute magnitudes for all rows,
        # now including the stars
        for band in map(str.strip, self.egg["BANDS"][0]):
            self.catalog[band + "_total"] = self.get_flux_total(band, rest_frame=False)

        for band in map(str.strip, self.egg["BANDS"][0]):
            self.catalog["magabs_" + band + "_total"] = self.get_flux_total(band, rest_frame=True)

        # Reprocess the lsst flags
        self.catalog["is_lsst_any_30s"] = self.get_is_lsst("30s", False)
        self.catalog["is_lsst_all_30s"] = self.get_is_lsst("30s", True)
        self.catalog["is_lsst_any_10yr"] = self.get_is_lsst("10yr", False)
        self.catalog["is_lsst_all_10yr"] = self.get_is_lsst("10yr", True)

        return self.catalog

    def get_columns(self):

        return [

            # Columns from EGG
            ("ID", np.int64, "unique EGG ID", ""),
            ("RA", np.float64, "Right ascenscion", "deg"),
            ("DEC", np.float64, "Declination", "deg"),
            ("Z", np.float64, "Cosmological redshift", ""),
            ("D", np.float64, "Luminosity distance OR distance modulus for stars", "Mpc / dimensionless"),
            ("M", np.float64, "log10 of stellar mass", "Msun"),
            ("SFR", np.float64, "Star-formation rate", "Msun/yr"),
            ("PASSIVE", bool, "Is passive (non-star-forming)", ""),
            ("CLUSTERED", bool, "Is clustered", ""),

        ] + [

            # Galaxy morphological parameters
            ("DISK_ANGLE",   np.float64, "Rotation angle of the galaxy disk", "deg"),
            ("DISK_RADIUS",  np.float64, "Half-light radius of the galaxy disk", "arcsec"),
            ("DISK_RATIO",   np.float64, "Ratio of minor to major axis of the disk", ""),
            ("BULGE_ANGLE",  np.float64, "Rotation angle of the galaxy bulge", "deg"),
            ("BULGE_RADIUS", np.float64, "Half-light radius of the galaxy bulge", "arcsec"),
            ("BULGE_RATIO",  np.float64, "Ratio of minor to major axis of the bulge", ""),

        ] + [

            # Emission line extinction
            ("AVLINES_BULGE",  np.float64, "Emission line extinction in the bulge", "mag"),
            ("AVLINES_DISK",   np.float64, "Emission line extinction in the disk",  "mag"),

        ] + [

            # Disk fluxes
            (b.strip() + "_disk",  np.float64, f"{b} flux disk", "uJy") for b in self.egg["BANDS"][0]

        ] + [

            # Bulge fluxes
            (b.strip() + "_bulge", np.float64, f"{b} flux bulge", "uJy") for b in self.egg["BANDS"][0]

        ] + [

            # EGG absolute magnitudes
            ("magabs_" + b.strip(), np.float64, f"{b} absolute magnitude", "AB mag") for b in self.egg["RFBANDS"][0]

        ] + [

            ## Columns derived for AGN
            ##("classification",   np.int16,   "Star-forming classification according to Aird+", ""),
            #("log_NH",           np.float64, "Obscuration according to Ueda+2014", ""),
            ("log_lambda_SAR",   np.float64, "log10 of specific accretion ratio", ""),
            ("log_LX_2_10",      np.float64, "log10 of 2-10 keV intrinsic X-ray luminosity", "erg/s"),
            ("log_FX_2_10",      np.float64, "log10 of 2-10 keV intrinsic X-ray flux at redshift", "erg/cm2/s"),
            ("log_FX_2_7",       np.float64, "log10 of 2-7 keV intrinsic X-ray flux at redshift", "erg/cm2/s"),
            ("duty_cycle_ctn",   np.float64, "Probability of BH being active and CTN", ""),
            ("duty_cycle_ctk",   np.float64, "Probability of BH being active and CTK", ""),
            ("is_agn_ctn",       bool,       "CTN AGN classification based on duty cycle", ""),
            ("is_agn_ctk",       bool,       "CTK AGN classification", ""),
            ("is_agn",           bool,       "is CTN or CTK AGN", ""),
            ("is_optical_type2", bool,       "Optical AGN type according to Merloni+14", ""),
            ("E_BV",             np.float64, "", ""),
            ("log_L_2_keV",      np.float64, "", "erg/s/Hz"),
            ("log_L_2500",       np.float64, "", "erg/s/Hz"),
            ("MBH",              np.float64, "Black hole mass",     "log(Msun)"),
            #("Mhalo",            np.float64, "DM Halo mass",        "log(Msun)"),
            #("sigma",            np.float64, "Velocity dispersion", "km/s"),
            ("log_L_FIR",        np.float64, "", "erg/s"),
            ("log_L_GAL_5GHz",   np.float64, "", "W/Hz"),
            ("log_L_AGN_5GHz",   np.float64, "", "erg/s"),
            #("type",             "U2",       "4-value Obscuration type according to Merloni+14", ""),
            ("log_L_bol",        np.float64, "", "bolometric AGN luminosity"),
            # TODO:
            #("tau",              np.float64, "AGN lightcurve parameter tau (r-band)",    ""),
            #("sf_inf",           np.float64, "AGN lightcurve parameter sf_inf (r-band)", ""),
        ] + [

            # Point fluxes
            (b.strip() + "_point", np.float64, f"{b} flux", "uJy") for b in self.egg["BANDS"][0]

        ] + [

            # AGN absolute magnitudes parameters
            ("magabs_" + b.strip() + "_point", np.float64, f"{b} absolute magnitude", "AB mag") for b in self.egg["RFBANDS"][0]

        ] + [

            # Total fluxes
            (b.strip() + "_total",  np.float64, f"{b} total flux of the object", "uJy") for b in self.egg["BANDS"][0]

        ] + [

            # EGG absolute magnitudes
            ("magabs_" + b.strip() + "_total", np.float64, f"{b} total absolute magnitude", "AB mag") for b in self.egg["RFBANDS"][0]

        ] + [

            # Proper motions
            ("pmracosd", np.float64, "mas/yr", "Proper motion in RA*cos(Dec)"),
            ("pmdec",    np.float64, "mas/yr", "Proper motion in dec"),

        ] + [

            # LSST selection
            ("is_lsst_any_30s",  bool, "Object is brighter than the LSST 30s limiting magnitude in any of the 6 bands", ""),
            ("is_lsst_any_10yr", bool, "Object is brighter than the LSST 10yr limit magnitude in any of the 6 bands", ""),

            # LSST selection
            ("is_lsst_all_30s",  bool, "Object is brighter than the LSST 30s limiting magnitude in all of the 6 bands", ""),
            ("is_lsst_all_10yr", bool, "Object is brighter than the LSST 10yr limit magnitude in all of the 6 bands", ""),

        ]

    def get_dtype(self):
        i = []
        for n, t, _, _ in self.get_columns():
            i += [(n, t)]
        return np.dtype(i)

    def get_galaxy(self, key):

        print("Getting galaxy attribute", key)

        _bands = [b.strip() for b in self.egg["BANDS"][0]]
        key_without_suffix = key.replace("_disk", "").replace("_bulge", "")

        try:
            if key_without_suffix in _bands:
                idx = _bands.index(key_without_suffix)
                key_egg = "FLUX"
                if "bulge" in key:
                    key_egg += "_BULGE"
                if "disk" in key:
                    key_egg += "_DISK"
                return self.egg[key_egg][0, :, idx]

            elif "magabs_" in key:
                idx = _bands.index(key.replace("magabs_", ""))
                return self.egg["RFMAG"][0, :, idx]
        except IndexError:
            # Fluxes not generated... return zerovalues
            return np.full_like(self.egg["Z"][0], np.inf if "magabs_" in key else 0.0)

        ret = self.egg[key][0]
        if key == "M":
            # NOTE: Salpeter to Chabrier IMF conversion see e.g. Grylls+20
            ret -= 0.24

        if key == "SFR":
            # Where necessary to convert SFRs from the literature from Chabrier
            # or Kroupa IMFs to the Salpeter IMF, we divide by constant factors
            # of 0.63 (Chabrier) or 0.67 (Kroupa).

            #     SFR_salpeter = SFR_chabrier / 0.63
            # ==> SFR_chabrier = 0.63 * SFR_salpeter
            # (https://ned.ipac.caltech.edu/level5/March14/Madau/Madau3.html)
            ret *= 0.63

        return ret

    def get_agn(self, key):
        """
        Compute an AGN-related catalog column.

        Dispatch order:
        1. If the instance has an attribute ``get_<key>``, it is called
           without arguments and its result is additionally stored as the
           instance attribute ``<key>``.
        2. Keys containing "_point" are forwarded to get_flux_agn(); a
           "magabs_" prefix selects the rest-frame quantity.
        3. Keys containing "is_lsst" are forwarded to get_is_lsst() with the
           last two "_"-separated tokens of the key (e.g. "all" / "30s").

        Returns the computed values, or None when no rule matches.
        """

        print("Getting AGN attribute", key)

        ret = None
        func = getattr(self, f"get_{key}", None)

        if func:
            ret = func()
            # setattr instead of exec(): same effect without evaluating a
            # string built from the column name
            setattr(self, key, ret)
        elif "_point" in key:
            ret = self.get_flux_agn(
                band=key.replace("magabs_", "").replace("_point", ""),
                rest_frame="magabs_" in key
            )
        elif "is_lsst" in key:
            ret = self.get_is_lsst(key.split("_")[-1], key.split("_")[-2] == "all")

        return ret

    def get_random(self, select=None, fact=5):

        """Return unclustered positions within the area"""

        # Initialize the random
        catalog = self.catalog
        if select is not None:
            catalog = catalog[select]

        random = np.full_like(self.catalog, -np.inf, shape=self.catalog.size * fact)

        # Assign RA
        ra_min, ra_max = self.catalog["RA"].min(), self.catalog["RA"].max()
        random["RA"] = np.random.rand(random.size) * (ra_max - ra_min) + ra_min

        # Assign dec
        sindec_min, sindec_max = map(lambda a: np.sin(a * np.pi / 180.), [self.catalog["DEC"].min(), self.catalog["DEC"].max()])
        random["DEC"] = 180 / np.pi * np.arcsin(np.random.rand(random.size) * (sindec_max - sindec_min) + sindec_min)
        return random

    def get_log_lambda_SAR(self):

        log_lambda_SAR = np.full_like(self.catalog["Z"], -np.inf)
        zs = np.arange(np.floor(self.catalog["Z"].min()), np.ceil(self.catalog["Z"].max()), self.dz)
        ms = np.arange(np.floor(self.catalog["M"].min()), np.ceil(self.catalog["M"].max()), self.dm)

        if self.type_plambda == "bongiorno+2016":
            from plambda import get_plambda_bon16_air18
            lvec = np.linspace(32, 36, 401)
            for z, m in product(zs, ms):

                print(f"Assigning log_lambda_SAR {z:.2f} {m:.2f}", end='\r')
                select = self._get_select(z, m)
                if select.sum() == 0:
                    continue

                lvec, pq, ps = get_plambda_bon16_air18(m, m + self.dm, z + self.dz / 2, lvec)
                sq = select * (self.catalog["PASSIVE"] == 1)
                ss = select * (self.catalog["PASSIVE"] == 0)
                log_lambda_SAR[sq] = np.interp(np.random.rand(sq.sum()), np.cumsum(pq), lvec, left=32, right=36)
                log_lambda_SAR[ss] = np.interp(np.random.rand(ss.sum()), np.cumsum(ps), lvec, left=32, right=36)

        elif self.type_plambda == "zou+2024":
            from plambda import get_plambda_zou2024
            lvec = np.linspace(32, 40, 801)
            for i in self.catalog["ID"]:

                print(i, end='\r')

                t = ["starforming", "quiescent"][self.catalog["PASSIVE"][i]]
                log_mstar = self.catalog["M"][i]
                z = self.catalog["Z"][i]

                # NOTE: normalize plambda to 1 so that every objects has
                # log_lambda_SAR The final active fraction is then dictated by
                # the duty cycle
                plambda = get_plambda_zou2024(lvec, t, log_mstar, z)
                Plambda = np.cumsum(plambda)
                Plambda /= Plambda[-1]
                log_lambda_SAR[i] = np.interp(np.random.rand(), Plambda, lvec)

        else:

            raise ValueError

        ###########################################################
        # Old methodology of using simply Bongiorno+2016
        #log_lambda_SAR[select] = bongiorno2016.get_log_lambda_SAR(
        #    mlo=m,
        #    mhi=m + self.dm,
        #    z=z + self.dz / 2.0,  # NOTE: mean redshift of the bin
        #    size=select.sum()
        #)
        ###########################################################
        return log_lambda_SAR

    def get_log_LX_2_10(self):
        return self.catalog["log_lambda_SAR"] + self.catalog["M"]

    def get_log_FX_2_10(self, Gamma=1.9):
        """
        Return log10 of the 2-10 keV flux computed from the catalog columns
        "log_LX_2_10", "D" and "Z".

        The distance column "D" is converted from Mpc to cm and the flux is
        log_LX - log10(4 pi d^2) + (2 - Gamma) * log10(1 + z).

        The luminosity is read from the catalog column (like the other
        getters do) instead of the attribute self.log_LX_2_10, which only
        exists after get_agn("log_LX_2_10") has run in the same session.
        """
        # NOTE: 1 + z dependence from K-correction
        z = self.catalog["Z"]
        d = self.catalog["D"] * u.Mpc.to(u.cm)
        log_lx = self.catalog["log_LX_2_10"]
        return log_lx - np.log10(4 * np.pi * d ** 2) + (2 - Gamma) * np.log10(1 + z)

    def get_log_FX_2_7(self, Gamma=1.9):
        """
        Return log10 of the flux obtained by passing the 2-10 keV flux
        through util.convert_flux(flux, 2.0, 10.0, 2.0, 7.0, Gamma=Gamma).

        NOTE(review): presumably this converts the flux from the 2-10 keV
        band to the 2-7 keV band for a power law with photon index Gamma;
        confirm against util.convert_flux.

        The input flux is read from the catalog column "log_FX_2_10" (like
        the other getters do) instead of the attribute self.log_FX_2_10,
        which only exists after get_agn("log_FX_2_10") has run in the same
        session.
        """
        flux_2_10 = 10 ** self.catalog["log_FX_2_10"]
        return np.log10(util.convert_flux(flux_2_10, 2.0, 10.0, 2.0, 7.0, Gamma=Gamma))

    def get_log_L_2_keV(self, Gamma=1.9, wavelength=6.2):
        """
        Returns monochromatic X-ray luminosity at lambda = wavelength  in erg/s Hz^-1 to be used for the alpha_ox
        Lx = restframe 2-10 kev luminosity.
        maybe we can work directly with frequencies and have it in one line
        """
        Lx = 10 ** self.catalog["log_LX_2_10"]
        K = (Lx / (6.2 ** (Gamma - 2) - 1.24 ** (Gamma - 2))) * (Gamma - 2)  #6.2, 1.24 = 2kev, 10kev in A°
        return np.log10((K * wavelength ** (Gamma - 1)) / 2.998e18)

    def get_log_L_2500(self, alpha=0.952, beta=2.138, scatter=0.40):

        """
        Returns the 2500 ang° monochromatic luminosity (in erg/s). It uses
        Lusso+10 eq. 5 (inverted) Lx = alpha L_opt - beta
        """

        log_L_2_keV = self.catalog["log_L_2_keV"]
        log_L_2500 = (log_L_2_keV + beta) / alpha
        assert np.allclose(alpha * log_L_2500 - beta, log_L_2_keV)

        # TODO: implement realistic scatter?
        log_L_2500 += np.random.normal(loc=0, scale=scatter, size=self.catalog.size)

        return log_L_2500

    def get_is_optical_type2(self, func_get_f_obs=None):

        # Get obscured AGN fraction from Merloni+2014 as a function of z, LX
        select_ctn = self.catalog["is_agn_ctn"]
        select_ctk = self.catalog["is_agn_ctk"]

        if func_get_f_obs is None:
            func_get_f_obs = self.merloni2014.get_f_obs

        f_obs = func_get_f_obs(
            self.catalog["Z"][select_ctn],
            self.catalog["log_LX_2_10"][select_ctn]
        )

        # Randomize type2 based on obscured fraction
        ret = np.full(self.catalog.size, False)
        ret[select_ctn] = np.random.rand(select_ctn.sum()) < f_obs
        ret[select_ctk] = True
        return ret

    def get_E_BV(
        self,
        #alpha_1=15.19373112, # NOTE: These type1 AGN parameters are from COSMOS (Bon+).
        #n_1=1.58310706,
        alpha_1=7.93483055,   # NOTE: these parameters are derived from Zou+ catalog from LSST DDFs.
        n_1=2.97565676,
        alpha_2=11.6133635,
        n_2=1.42972,
        type_1_ebv=np.linspace(0,1, 101),
        type_2_ebv=np.linspace(0,3, 301),
        mu_type_2=0.3,
    ):

        def sample_ebv(N_AGN, probability_distribution, ebv_range, *args):
            cumulative = np.cumsum(probability_distribution(ebv_range, *args))
            cumulative /= np.max(cumulative)
            return np.interp(np.random.rand(N_AGN), cumulative, ebv_range)

        def hopkins04(x, alpha, n):
            """p(E_BV)"""
            y = 1 / (1 + (x * alpha) ** n)
            return y / np.trapz(y, x)

        ebv = np.empty(self.catalog.size)

        type_1_optical = ~self.catalog["is_optical_type2"]
        type_2_optical =  self.catalog["is_optical_type2"]

        N_type_1 = np.sum(type_1_optical)
        N_type_2 = np.sum(type_2_optical)

        ebv[type_1_optical] = sample_ebv(type_1_optical.sum(), hopkins04, type_1_ebv, alpha_1, n_1)
        ebv[type_2_optical] = sample_ebv(type_2_optical.sum(), hopkins04, type_2_ebv, alpha_2, n_2) + mu_type_2

        ## NOTE: set CTK AGN E(B-V) to 9 (arbitrarily high number) manually...
        #is_ctk = self.catalog["is_agn_ctk"]
        #ebv[is_ctk] = 9.00
        return ebv

    def get_duty_cycle_ctn(self):

        duty_cycle = np.zeros(self.catalog.size)

        if self.type_plambda == "bongiorno+2016":

            smf_agn = smf.StellarMassFunctionBongiorno2016()
            smf_gal = smf.StellarMassFunctionCOSMOS2020AGN()

            mvec = np.arange(np.floor(self.catalog["M"].min()), np.ceil(self.catalog["M"].max()), 0.01)

            for t in "quiescent", "star-forming":

                for z in np.arange(np.floor(self.catalog["Z"].min()), np.ceil(self.catalog["Z"].max()), self.dz):

                    for m in np.arange(np.floor(self.catalog["M"].min()), np.ceil(self.catalog["M"].max()), self.dm):

                        select = self._get_select(z=z, m=m, t=t)
                        if select.sum() == 0:
                            continue

                        ## NOTE: optimization to use pre-saved SMFs
                        ret = []
                        for prefix, func in [
                            ("gal", smf_gal.get_stellar_mass_function),
                            ("agn", smf_agn.get_stellar_mass_function),
                        ]:
                                filename = f"{ROOT}/data/smf/{prefix}_m_{m + self.dm / 2}_z_{z + self.dz / 2}_{t}.npy"
                                if not os.path.exists(filename):
                                    np.save(filename, func(10 ** (m + self.dm / 2), z + self.dz / 2, type=t))
                                ret += [np.load(filename)]
                        duty_cycle[select] = ret[1] / ret[0]
                        print(f"Duty cycle t={t} z={z:6.2f}, m={m:6.2f} N={select.sum()}", end='\r')

        elif self.type_plambda == "zou+2024":

            from plambda import get_plambda_zou2024
            lvec = np.linspace(32, 40, 801)

            for i in self.catalog["ID"]:

                t = ["starforming", "quiescent"][self.catalog["PASSIVE"][i]]
                log_mstar = self.catalog["M"][i]
                z = self.catalog["Z"][i]

                plambda = get_plambda_zou2024(lvec, t, log_mstar, z)
                select = lvec > 32
                f_agn = np.trapz(plambda[select], lvec[select])
                duty_cycle[i] = f_agn
                print(i, f_agn, end='\r')

        else:

            raise ValueError

        # NOTE: maximum possible duty cycle is 1.0
        return np.clip(duty_cycle, 0.0, 1.0)

    def get_duty_cycle_ctk(self):

        """Get the CTK AGN duty cycle"""

        ret = np.zeros_like(self.catalog["Z"])

        for z in np.arange(np.floor(self.catalog["Z"].min()), np.ceil(self.catalog["Z"].max()), self.dz):

            for l in np.arange(np.floor(self.catalog["log_LX_2_10"].min()), np.ceil(self.catalog["log_LX_2_10"].max()), self.dl):

                select = self._get_select(z=z, l=l)
                if select.sum() == 0:
                    continue

                # From Ueda+2014 we get that the fraction of CTK AGN should be some number X
                # i.e. that
                #
                #   N_CTK / (N_CTN + N_CTK) = X
                #
                # Solving for N_CTK yields
                #
                #   N_CTK = X * N_CTN / (1 - X)
                frac = [ueda2014.get_f(l + self.dl / 2, z + self.dz / 2, nh) for nh in [20, 21, 22, 23, 24]]
                norm = np.sum(frac)
                frac_agn_ctk = frac[-1] / norm

                n_ctn = self.catalog["duty_cycle_ctn"][select].sum()
                n_ctk = frac_agn_ctk * n_ctn / (1 - frac_agn_ctk)
                n_agn = n_ctn + n_ctk
                n_gal = select.sum() - n_ctn
                assert np.isclose(n_ctk / n_agn, frac_agn_ctk)

                ret[select] = n_ctk / n_gal

        return ret

    def get_is_agn_ctn(self, seed=1337):
        """Render a random sample of AGN "active" according to the duty cycle"""
        np.random.seed(seed)
        return np.random.rand(self.catalog.size) < self.catalog["duty_cycle_ctn"]

    def get_is_agn_ctk(self, seed=7331):
        """Assign CTK AGN fraction randomly"""
        np.random.seed(seed)

        # I believe the ctk agn duty cycle should be modified here somehow...
        #   p(is_ctk) = p(U < duty_cycle_ctk)
        # How do deal with objects which are rendered both ctn and ctk? Maybe just by random?
        is_agn_ctn = self.catalog["is_agn_ctn"]
        is_agn_ctk = np.random.rand(self.catalog["Z"].size) < self.catalog["duty_cycle_ctk"]
        while (is_agn_ctn * is_agn_ctk).sum() != 0:
            select = is_agn_ctn * is_agn_ctk
            print(select.sum())
            is_agn_ctk[select] = np.random.rand(select.sum()) < self.catalog["duty_cycle_ctk"][select]
        return is_agn_ctk

    def get_is_agn(self):
        return self.catalog["is_agn_ctn"] + self.catalog["is_agn_ctk"]


    #def get_log_NH(self):

    #    # Get NH distribution function from Ueda+2014 and normalize to unity
    #    nhs = np.array([20, 21, 22, 23, 24])
    #    p = ueda2014.get_f(
    #        self.catalog["log_LX_2_10"][:, None],
    #        self.catalog["Z"][:, None],
    #        nhs[None, :]
    #    )
    #    p /= p.sum(axis=1)

    #    ret = np.zeros_like(self.catalog["Z"])

    #    for i in np.arange(self.catalog.size):
    #        print(i, self.catalog.size, i / self.catalog.size * 100, end='\r')
    #        p = np.array([ueda2014.get_f(l, z, nh) for nh in nhs])
    #        nh_lo = np.random.choice(nhs, p=p[i, :])
    #        nh_hi = nh_lo + 1 if nh_lo < 24 else nh_lo + 2
    #        ret[i] = np.random.uniform(nh_lo, nh_hi)

    #    return ret


    #def get_type(self):

    #    """
    #    Get AGN type based on Merloni+2014. The type is a two-byte string e.g.
    #    11 where the first (second) byte refers to optical (X-ray) obscuration,
    #    where 1 means unobscured and 2 obscured.
    #        22  opt + X-ray obscured
    #        21  opt obscured, X-ray unobscured
    #        12  opt unobscured, X-ray obscured
    #        11  opt + X-ray unobscured
    #    """

    #    tmp = []
    #    for fname in (
    #        "data/merloni2014/22_21.csv",
    #        "data/merloni2014/21_12.csv",
    #        "data/merloni2014/12_11.csv",
    #    ):
    #        x, y = np.loadtxt(fname).T
    #        tmp.append(interp1d(x, y, bounds_error=False, fill_value=(y[0], y[-1])))
    #    type22_21, type21_12, type12_11 = tmp

    #    # Assign type for CTK AGN
    #    # Define some selections for CTN/CTK AGN
    #    #select_agn = self.catalog["is_agn"]
    #    select_ctn = self.catalog["is_agn_ctn"]
    #    select_ctk = self.catalog["is_agn_ctk"]
    #    #select = select_agn * select_ctn
    #    select = select_ctn

    #    loglx = np.arange(
    #        self.catalog[select]["log_LX_2_10"].min(),
    #        self.catalog[select]["log_LX_2_10"].max() + 0.15,
    #        0.10,
    #    )

    #    ret = np.empty_like(self.log_LX_2_10, dtype="U2")

    #    for l_min, l_max in zip(loglx[:-1], loglx[1:]):

    #        # Perform the selection
    #        select = (l_min <= self.log_LX_2_10) * (self.log_LX_2_10 < l_max)

    #        l = (l_min + l_max) / 2.

    #        p22 = type22_21(l)
    #        p21 = type21_12(l) - type22_21(l)
    #        p12 = type12_11(l) - type21_12(l)
    #        p11 = 100 - type12_11(l)

    #        ret[select] = np.random.choice(
    #            ["22", "21", "12", "11"],
    #            size=select.sum(),
    #            replace=True,
    #            p=np.array([p22, p21, p12, p11]) / 100.
    #        )

    #    # NOTE: Assign CTK_AGN type
    #    ret[select_ctk] = "22"

    #    return ret

    #def get_sed_single_object(self, i):

    #    """
    #    This is a helper function in order to do multiprocessing... See the function below for the logic
    #    """

    #    if i in SEDS:
    #        return SEDS[i]

    #    # Initialize the kwargs
    #    kwargs = {
    #        "LogL2500": self.catalog[i]["log_L_2500"],
    #        "AGN_type": 1 + 1 * self.catalog[i]["is_optical_type2"],
    #        "ebv": self.catalog[i]["E_BV"],
    #        "redshift": self.catalog[i]["Z"],
    #        "distance_cm": self.catalog[i]["D"] * u.Mpc.to(u.cm),
    #        "LogL2kev": self.catalog[i]["log_L_2500"],
    #        "flux_rf_1000_4000_gal": util.mag_to_flux(self.catalog[i]["magabs_mock-1000-4000"]),
    #        "seed": self.catalog[i]["ID"],
    #    }

    #    # Get the SED
    #    ebv, (lam_rf, flux_rf), (lam_of, flux_of) = sed.get_sed(**kwargs)

    #    # Update the E_BV value
    #    self.catalog[i]["E_BV"] = ebv

    #    SEDS[i] = (lam_rf, flux_rf), (lam_of, flux_of)

    #    return SEDS[i]

    def get_sed(self, i, component):
        """
        Read the SED of object 'i' from the "seds" directory next to the catalog.

        For the "bulge" and "disk" components the file is
        "egg-seds-<component>-<i>.fits"; when it is missing, "egg-getsed" is
        run through the shell to create it. Any other component is looked up
        as "agn-seds-<i>.fits" and None is returned when that file does not
        exist.

        Returns the tuple (LAMBDA[0], FLUX[0]) read from the FITS file.
        """
        sed_root = os.path.dirname(self.filename_catalog)
        is_host_component = component in ["bulge", "disk"]

        if is_host_component:
            filename = f"{sed_root}/seds/egg-seds-{component}-{i}.fits"
        else:
            filename = f"{sed_root}/seds/agn-seds-{i}.fits"

        if not os.path.exists(filename):
            # AGN SEDs are never generated here
            if not is_host_component:
                return None

            # Host galaxy SEDs are extracted from the EGG SED file on demand
            cmd = f"egg-getsed seds={sed_root}/egg-seds.dat id={i} component={component} out={filename}"
            print("Running", cmd)
            os.system(cmd)

        table = fitsio.read(filename)
        return table["LAMBDA"][0], table["FLUX"][0]

    def _get_agn_sed(self, i, ratio_max=0.90):

        """
        Build the AGN SED of catalog row 'i' with qsosed.Quasar_sed and cache
        its band fluxes in the module-level FLUXES dictionary.

        For optical type 2 AGN with E(B-V) <= 9.0 the AGN flux in the
        "mock-1000-4000" band (evaluated at redshift 0 and a distance of
        10 pc) is compared against the host galaxy flux derived from the
        "magabs_mock-1000-4000" column. While the AGN fraction of the total
        exceeds 'ratio_max', the catalog "E_BV" value of the row is increased
        by 0.10 and the SED is rebuilt.

        If self.save_sed is set, the SED converted to flux at the redshift and
        distance of the row is written to
        "<self.dirname>/seds/agn-seds-<ID>.fits".

        Returns the qsosed.Quasar_sed object.
        """

        print("Getting AGN SED", i, self.catalog.size, end='\r')

        # Get the wavelength grid in angstrom: logarithmically spaced between
        # 500 and 250000 with a step of dlog_wav in log10
        # NOTE: Add some interesting wavelengths for greater accuracy
        #dlog_wav = 7.65e-5
        dlog_wav = 7.65e-4
        wavlen = 10 ** np.arange(np.log10(500), np.log10(250000) + dlog_wav, dlog_wav)
        wavlen = np.append(wavlen, [1450, 4400, 5007, 150000])
        wavlen = np.sort(wavlen)

        # Populate the SED. The loop repeats only when E(B-V) had to be
        # increased for a type 2 AGN (see below)
        while True:

            # Initialize the rng seed from the object ID on every pass
            # NOTE(review): presumably this makes the posterior sample drawn
            # below identical between passes, so that only E(B-V) changes --
            # confirm that the sample() call uses the global numpy RNG
            np.random.seed(self.catalog[i]["ID"])

            agn_sed = qsosed.Quasar_sed(
                LogL2500=self.catalog[i]["log_L_2500"],
                AGN_type=1 + self.catalog[i]["is_optical_type2"],
                ebv=self.catalog[i]["E_BV"],
                physical_units=True,
                wavlen=wavlen,
                LogL2kev=self.catalog[i]["log_L_2_keV"],
                add_NL=self.catalog[i]["is_optical_type2"],
                NL_normalization="lamastra",
                Av_lines=self.catalog[i]["AVLINES_BULGE"] + self.catalog[i]["AVLINES_DISK"],
                # One random row of the posterior, passed as keyword arguments
                **dict(zip(PARAMETER_NAMES, *POSTERIOR_DISTRIBUTION.sample().values))
            )

            # Check for type2 AGN flux NOT exceeding the host galaxy flux by
            # some limit. The check is skipped once E(B-V) is above 9.0, which
            # also guarantees that the loop terminates
            if self.catalog[i]["is_optical_type2"] and self.catalog[i]["E_BV"] <= 9.0:
                # AGN flux at redshift 0 and a distance of 10 pc, to be
                # compared with the flux from the host absolute magnitude
                lam, flux_agn = util.luminosity_to_flux(
                    agn_sed.wavlen.value,
                    agn_sed.lum.value,
                    redshift=0.0,
                    distance_in_cm=10 * u.pc.to(u.cm),
                )
                flux_agn = sed.get_flux_band(lam, flux_agn, band="mock-1000-4000")
                flux_gal = util.mag_to_flux(self.catalog[i]["magabs_mock-1000-4000"])
                ratio = flux_agn / (flux_gal + flux_agn)

                if ratio > ratio_max:
                    # NOTE: the message hard-codes ">90%" while the actual
                    # threshold is 'ratio_max'
                    print(
                        "AGN 1000-4000 angstrom >90% of the total... Incrementing E(B-V) by 0.10...",
                        "%6d"  % self.catalog[i]["ID"],
                        "%.2f" % np.log10(flux_agn),
                        "%.2f" % np.log10(flux_gal),
                        "%.2f" % ratio,
                    )
                    # The catalog row is updated, so the next pass picks up
                    # the higher E(B-V)
                    self.catalog[i]["E_BV"] += 0.10
                    continue
            break

        # Write the file if requested
        if self.save_sed:
            filename = self.dirname + f"/seds/agn-seds-{self.catalog[i]['ID']}.fits"
            lam, flux = util.luminosity_to_flux(
                agn_sed.wavlen.value,
                agn_sed.lum.value,
                redshift=self.catalog[i]["Z"],
                distance_in_cm=self.catalog[i]["D"] * u.Mpc.to(u.cm),
            )
            sed.write(filename, lam, flux)

        # NOTE: optimization: save values from the sed for future use. One
        # entry is stored per (row, band, rest_frame) combination for every
        # band listed in self.egg["BANDS"][0]
        for b, r in product(map(str.strip, self.egg["BANDS"][0]), [False, True]):
            FLUXES[i, b, r] = self._init_flux_single(agn_sed, b, r, self.catalog["Z"][i], self.catalog["D"][i])

        #SEDS[i] = agn_sed
        #return SEDS[i]

        return agn_sed

    def _init_flux_single(self, my_sed, band, rest_frame, redshift, distance):
        """
        Integrate 'my_sed' over 'band'.

        With 'rest_frame' unset the SED is placed at the given 'redshift' and
        'distance' (in Mpc) and the band flux is returned. With 'rest_frame'
        set the SED is instead placed at redshift 0 and a distance of 10 pc,
        and the band flux is converted to a magnitude before it is returned.
        """
        if rest_frame:
            # NOTE: 10pc expressed in Mpc
            redshift, distance = 0, 1e-5

        wavelength, flux_density = util.luminosity_to_flux(
            my_sed.wavlen.value,
            my_sed.lum.value,
            redshift=redshift,
            distance_in_cm=distance * u.Mpc.to(u.cm),
        )

        band_flux = sed.get_flux_band(wavelength, flux_density, band)

        # NOTE: rest-frame values are reported as magnitudes
        return util.flux_to_mag(band_flux) if rest_frame else band_flux

    def get_flux_agn(self, band, rest_frame, idxs=None, mjd=None, mjd0=None):
        """
        Return the AGN contribution in 'band' for every catalog row.

        The output has one entry per catalog row. Rows that are not processed
        keep the fill value: 0.0 in the observed frame, or inf when
        'rest_frame' is set. Processed rows take their value from the
        module-level FLUXES cache, which is filled through _get_agn_sed() on a
        cache miss.

        'idxs' are the rows to process; by default the "ID" values of all rows
        with "is_agn" set. If either 'mjd' or 'mjd0' is given, the cached
        value is replaced by the output of lightcurve.get_lightcurve_agn()
        seeded with the row index.
        """
        flux = np.full(self.catalog.size, np.inf if rest_frame else 0.0)

        if idxs is None:
            idxs = self.catalog[self.catalog["is_agn"]]["ID"]

        do_lightcurve = not (mjd is None and mjd0 is None)

        for idx in idxs:

            # NOTE: building the SED populates FLUXES for all bands/frames
            if (idx, band, rest_frame) not in FLUXES:
                self._get_agn_sed(idx)
            flux[idx] = FLUXES[idx, band, rest_frame]

            if not do_lightcurve:
                continue

            # 'c' is a single catalog row, so its fields are already the
            # per-object values and must not be indexed with idx again
            c = self.catalog[idx]
            flux[idx] = lightcurve.get_lightcurve_agn(
                mjd0=mjd0, mjd=mjd, flux=flux[idx], band=band,
                z=c["Z"], mag_i=c["magabs_lsst-i"],
                logMbh=c["MBH"], type2=c["is_optical_type2"],
                T=3653, deltatc=1, seed=idx
            )

        return flux

    def _get_flux_star(self, star, band, mjd0, mjd, seed=None):
        """
        Return the flux in 'band' of one star or one binary component.

        'star' is a mapping which provides "<b>mag" (where <b> is 'band'
        without the "lsst-" prefix), "label", "pmode", "period", "z", "y",
        "logte", "logl", "mass" and "c_o". The magnitude is converted to a
        flux, which is replaced by a Cepheid or a long-period variable (LPV)
        lightcurve value when the star is classified as such. 'seed' is
        forwarded to the lightcurve generator.
        """
        b = band.replace("lsst-", "")
        flux = util.mag_to_flux(star[f"{b}mag"])

        if self._get_is_star_cepheid(star["label"], star["pmode"]):
            # Handle cepheid
            flux = lightcurve.get_lightcurve_cepheid(mjd0, mjd, flux,
                band, star["z"], star["y"], star["logte"], star["logl"],
                star["mass"], star["pmode"], star["period"], seed=seed)

        elif self._get_is_star_lpv(star["label"], star["pmode"]):
            # Handle LPV
            # NOTE(review): assumes get_lightcurve_lpv lives in the lightcurve
            # module like the other lightcurve generators -- confirm
            flux = lightcurve.get_lightcurve_lpv(mjd0, mjd, flux, band,
                star["pmode"], star["period"], star["c_o"] > 1, seed=seed)

        return flux


    def get_flux_star(self, band, idxs=None, mjd=None, mjd0=None):
        """
        Return the stellar flux in 'band' for every catalog row.

        The output has one entry per catalog row and is 0.0 for rows that are
        not processed. 'idxs' are the rows to process (by default the "ID"
        values of all stars); rows which are not stars are skipped.

        With both 'mjd' and 'mjd0' unset (static universe) the flux is read
        from the "<band>_point" column. Otherwise the flux of a single star is
        taken from _get_flux_star(), and for a binary the two component
        lightcurves returned by lightcurve.get_lightcurve_binary() are added.
        """
        # NOTE: only observed-frame fluxes are handled here
        flux = np.full(self.catalog.size, 0.0)

        is_star = self.get_is_star()
        if idxs is None:
            idxs = self.catalog["ID"][is_star]

        is_static = mjd is None and mjd0 is None
        is_single = None if is_static else self.get_is_star_single()

        # Fields needed by _get_flux_star for each star / binary component
        b = band.replace("lsst-", "")
        fields = (
            f"{b}mag", "label", "y", "z", "logte", "logl", "mass", "pmode",
            "period", "c_o",
        )

        for idx in idxs:

            if not is_star[idx]:
                continue

            flux[idx] = self.catalog[idx][f"{band}_point"]

            # Short-circuit for static universe case
            if is_static:
                continue

            # NOTE: stars come in single and binary stars. Any single star may
            # vary according to its own physics, while binary stars have an
            # additional variability due to possible eclipses between the two.
            # The next piece of code tries to account for both variabilities
            # NOTE(review): the same offset is used to index self.star and
            # self.binary -- confirm against how the catalog is assembled
            idx_star = idx - self.egg["ID"].size
            if is_single[idx]:
                s = self.star[idx_star]
                prefixes = ("",)
            else:
                s = self.binary[idx_star]
                prefixes = ("c1_", "c2_")

            # Get the instant flux of the star / of each binary component
            # NOTE(review): assumes that every field, including "label", is
            # available as "<prefix><name>" for the components -- confirm
            fluxes = []
            for prefix in prefixes:
                star = {name: s[prefix + name] for name in fields}
                fluxes.append(self._get_flux_star(star, band, mjd0, mjd, seed=idx))

            # Return single star flux
            if len(fluxes) == 1:
                flux[idx] = fluxes[0]
                continue

            # Return binary star flux
            lc1, lc2 = lightcurve.get_lightcurve_binary(
                mjd0=mjd0,
                mjd=mjd,
                flux1=fluxes[0],
                flux2=fluxes[1],
                radius1=self._get_radius_star(s["c1_mass"], s["c1_logg"]),
                radius2=self._get_radius_star(s["c2_mass"], s["c2_logg"]),
                p=s["p"],
                a=s["a"],
                i=s["i"],
                e=s["e"],
                t0=s["p"] * 0.25,
                t1=s["p"] * 0.75,
                seed=idx + 1000
            )
            flux[idx] = lc1 + lc2

        return flux

    def get_log_L_bol(self):
        """
        Return log10 of the SED "Lbol" attribute for every catalog row.

        Rows without "is_agn" set are left at -inf. The SED of each AGN is
        built through _get_agn_sed().
        """
        # NOTE: this function initializes the SEDs...
        agn_ids = self.catalog[self.catalog["is_agn"]]["ID"]
        result = np.full(self.catalog.size, -np.inf)

        for agn_id in agn_ids:
            agn_sed = self._get_agn_sed(agn_id)
            result[agn_id] = np.log10(agn_sed.Lbol)

        return result

    def get_is_lsst(self, key, all):

        """
        Return whether an object is within the expected 5sigma magnitude
        limits of LSST. 'key' controls whether it is single visit (must
        contain "30s") or full 10yr (must contain "10yr"). 'all' controls
        whether the object should be detected in _all_ of the bands or in any
        of the bands.

        Raises ValueError if 'key' selects neither depth.
        """

        if "30s" in key:
            limits = limiting_magnitude_30s
        elif "10yr" in key:
            limits = limiting_magnitude_10yr
        else:
            raise ValueError(
                f"Unknown LSST depth {key!r}: expected a key containing '30s' or '10yr'"
            )

        # Start from all-True for the "all bands" case and from all-False for
        # the "any band" case
        if all:
            select = np.ones_like(self.catalog["Z"], dtype=bool)
        else:
            select = np.zeros_like(self.catalog["Z"], dtype=bool)

        for b in 'ugrizy':
            mag = util.flux_to_mag(self.get_flux_total("lsst-" + b, rest_frame=False))
            detected = mag < limits[b]
            if all:
                select &= detected
            else:
                select |= detected

        return select

    def get_flux_total(self, band, rest_frame=False, keys=("disk", "bulge", "point")):
        """
        Return the total of the requested components in 'band' per catalog row.

        In the observed frame the "<band>_<component>" columns of 'keys' are
        added as fluxes and the total flux is returned. With 'rest_frame' set,
        'keys' is ignored: the "magabs_<band>" and "magabs_<band>_point"
        columns are converted from magnitudes to fluxes, added, and the total
        is returned as a magnitude. Non-finite component values count as 0.
        """
        # NOTE: override rest-frame behavior as EGG does not have absmags for disk/bulge separately
        if rest_frame:
            keys = ("", "point")

        prefix = "magabs_" if rest_frame else ""
        total = np.zeros(self.catalog.size)

        for component in keys:
            suffix = f"_{component}" if component != "" else ""
            values = self.catalog[f"{prefix}{band}{suffix}"]
            if rest_frame:
                values = util.mag_to_flux(values)
            total += np.where(np.isfinite(values), values, 0.0)

        if rest_frame:
            total = util.flux_to_mag(total)

        return total

    def get_lightcurve_agn(self, idxs=None, band="lsst-r", T=3653, deltatc=1):

        """
        Return the time vector and the AGN lightcurve(s) in 'band'.

        'idxs' are the catalog rows to process; by default the "ID" values of
        all rows with "is_agn" set. Each lightcurve is generated with LC2
        after seeding the global numpy RNG with the row index, converted from
        magnitudes to fluxes, and rescaled so that its mean equals the
        "<band>_point" value of the row. Results are cached in the
        module-level LIGHTCURVES dictionary under (index, band), and cached
        entries are not recomputed.

        AV NOTE:

        This function behaves a little bit wildly with the seed since we really
        do not want to estimate lightcurves to non-AGN.
        """

        # Default time vector; replaced by the LC2 output whenever at least
        # one lightcurve is actually simulated
        tt = np.arange(0, T, deltatc)

        # Figure out the observed-frame wavelength of the filter in angstrom
        band_short = band.replace("lsst-", '')
        lambda_of = filter_lam_fwhm[band_short][0] * u.um.to(u.angstrom)

        # NOTE: select all AGN if idxs is None
        if idxs is None:
            idxs = self.catalog[self.catalog["is_agn"]]["ID"]
        idxs = np.atleast_1d(idxs)
        n_total = idxs.size
        print(f"Getting AGN lightcurves {band}", n_total)

        for counter, idx in enumerate(idxs):

            print(f"{counter} / {n_total} ({counter / n_total * 100:.2f}%)", end='\r')

            # Only simulate lightcurves that are not yet in LIGHTCURVES
            key = (idx, band)
            if key not in LIGHTCURVES:

                np.random.seed(idx)

                row = self.catalog[idx]
                simulated = LC2(
                    T=T,
                    deltatc=deltatc,
                    z=row["Z"],
                    mag_i=row["magabs_lsst-i_point"],
                    logMbh=row["MBH"],
                    lambda_rest=lambda_of / (row["Z"] + 1),
                    oscillations=False,
                    A=0.14,
                    noise=0.00005,
                    frame="observed",
                    type2=row["is_optical_type2"],
                )
                tt, yy = simulated[:2]

                # Normalize to the mean over the simulated time span
                yy = util.mag_to_flux(yy)
                yy *= self.catalog[f"{band}_point"][idx] / yy.mean()

                # Update the dictionary
                LIGHTCURVES[key] = yy

        stacked = np.array([LIGHTCURVES[i, band] for i in idxs])
        return tt, stacked.squeeze()

    def write(self, filename=None):
        """
        Write the catalog to a FITS file unless the file already exists.

        'filename' is interpreted relative to ROOT. When it is None the
        catalog is written to self.filename_catalog instead. The column units
        are taken from the fourth element of each entry of get_columns().

        Returns 1 if the target file already exists (nothing is written) and
        0 after the file has been written.
        """
        # Resolve the target first: ROOT can only be prepended to an actual name
        if filename is None:
            target = self.filename_catalog
        else:
            target = ROOT + "/" + filename

        if os.path.exists(target):
            return 1

        fitsio.write(
            target,
            self.catalog,
            units=[u for _, _, _, u in self.get_columns()],
            clobber=True
        )

        return 0

    def write_instance_catalog(
        self,
        filename,
        mjd0=60218,
        maglim=None,
        selection_band=None,
        do_lightcurve=True,
        write_agn_host_galaxy=True,
        **kwargs
    ):
        """
        Write an instance catalog for a single visit to 'filename'.

        Nothing is done if 'filename' already exists. The keyword arguments in
        'kwargs' are written as "key value" header lines and must contain at
        least "mjd", "filter" (index into "ugrizy"), "rightascension" and
        "declination". Sources are kept when they are brighter than 'maglim'
        in 'selection_band' (if both are given) and lie within the radius of
        a 10 deg2 circle around the pointing. Stars are moved to their
        position at 'mjd' using their proper motions.

        With 'do_lightcurve' set, the point-source flux of the AGN is replaced
        by the value returned by lightcurve.get_lightcurve_agn(). With
        'write_agn_host_galaxy' unset, the bulge and disk entries of AGN are
        left out. The elapsed time is written to "time.log" next to the
        instance catalog.
        """

        if os.path.exists(filename):
            print("Instance catalog", filename, "exists. Will not overwrite")
            return

        t_start = time.time()
        mjd = kwargs["mjd"]
        print("Writing the instance catalog...")

        # Build the header
        header = '\n'.join(f"{k} {v}" for k, v in kwargs.items())
        band = "lsst-" + "ugrizy"[int(kwargs.get("filter"))]

        # Get the selection
        select = np.ones_like(self.catalog, dtype=bool)
        print("NOTE: original number of sources", select.sum())

        if (maglim is not None) and (selection_band is not None):
            print(f"NOTE: Selecting only sources with {selection_band} < {maglim}")
            select *= util.flux_to_mag(self.get_flux_total(selection_band)) < float(maglim)
        print(f"NOTE: culled number of sources after magnitude cut", select.sum())

        ###############################################################################
        # NOTE: de-select sources that are not within the LSST limit. 10deg2 is
        # conservative compared to the 9.6deg2
        ###############################################################################
        # NOTE: work on copies of the coordinates. Writing the proper motion
        # corrected positions into the catalog columns themselves would make
        # the shifts accumulate from one visit to the next
        ra = self.catalog["RA"].copy()
        dec = self.catalog["DEC"].copy()
        select_star = select * self.get_is_star()
        if select_star.sum() > 0:
            _ra, _dec = util.get_ra_dec(
                self.catalog["RA"][select_star],
                self.catalog["DEC"][select_star],
                self.catalog["pmracosd"][select_star],
                self.catalog["pmdec"][select_star],
                mjd
            )
            ra[select_star] = _ra
            dec[select_star] = _dec

        coord1 = SkyCoord(ra, dec, unit="deg")
        coord2 = SkyCoord(kwargs["rightascension"], kwargs["declination"], unit="deg")
        select *= coord1.separation(coord2) < np.sqrt(10 * u.deg ** 2 / np.pi)

        c = self.catalog[select]
        ra = ra[select]
        dec = dec[select]
        print("NOTE: culled number of sources after (maglim, ra, dec) cut", select.sum())

        #######################################################################
        # Estimate the fluxes
        #######################################################################
        flux_point = deepcopy(c[band + "_point"])

        if do_lightcurve:
            ###################################################################
            # Get the AGN lightcurves
            ###################################################################

            select_agn = c["is_agn"]
            print(f"Estimating {select_agn.sum()} AGN lightcurves in the {band} band...")
            # NOTE: separate timer so that the total time below is not reset
            t_lightcurve = time.time()

            lc = lightcurve.get_lightcurve_agn(
                mjd0=mjd0,
                mjd=mjd,
                flux=c[band + "_point"][select_agn],
                band=band,
                z=c["Z"][select_agn],
                mag_i=c["magabs_lsst-i"][select_agn],
                logMbh=c["MBH"][select_agn],
                type2=c["is_optical_type2"][select_agn],
                T=3653,
                deltatc=1,
                seed=self.seed,
            )
            flux_point[select_agn] = lc

            print(f"Done estimating AGN lightcurves in {time.time() - t_lightcurve} seconds")

        # Some preliminary mag calculations
        mag_disk  = util.flux_to_mag(c[band + "_disk"])
        mag_bulge = util.flux_to_mag(c[band + "_bulge"])
        mag_point = util.flux_to_mag(flux_point)

        ###############################################################################
        # Get the required fields for each entry here
        uid = c["ID"]

        angle_disk = 90 - c["DISK_ANGLE"]
        d_disk1    = c["DISK_RADIUS"]
        d_disk2    = c["DISK_RADIUS"] * c["DISK_RATIO"]

        # NOTE(review): the bulge position angle is taken from DISK_ANGLE --
        # confirm that this is intended
        angle_bulge = 90 - c["DISK_ANGLE"]
        d_bulge1    = c["BULGE_RADIUS"]
        d_bulge2    = c["BULGE_RADIUS"] * c["BULGE_RATIO"]

        catalog = []

        # Object flags
        is_galaxy = self.get_is_galaxy()[select]
        is_star = self.get_is_star()[select]
        is_agn = self.get_is_agn()[select]
        assert is_galaxy.sum() + is_star.sum() + is_agn.sum() == c.size

        for i in np.arange(c.size):
            # Case star
            if is_star[i]:
                catalog += [f"object {uid[i]} {ra[i]} {dec[i]} {mag_point[i]} data/imsim/Const.Template.spec.gz 0 0 0 0 0 0 point none CCM 0.0635117705 3.1"]
            # Case AGN
            elif is_agn[i]:
                catalog += int(write_agn_host_galaxy) * [f"object {uid[i]} {ra[i]} {dec[i]} {mag_bulge[i]} data/imsim/Const.Template.spec.gz 0 0 0 0 0 0 sersic2D {d_bulge1[i]} {d_bulge2[i]} {angle_bulge[i]} 4 none  CCM 0.0635117705 3.1"]
                catalog += int(write_agn_host_galaxy) * [f"object {uid[i]} {ra[i]} {dec[i]} {mag_disk[i]}  data/imsim/Const.Template.spec.gz 0 0 0 0 0 0 sersic2D {d_disk1[i]}  {d_disk2[i]}  {angle_disk[i]}  1 none  CCM 0.0635117705 3.1"]
                catalog += [f"object {uid[i]} {ra[i]} {dec[i]} {mag_point[i]} data/imsim/Const.Template.spec.gz 0 0 0 0 0 0 point none CCM 0.0635117705 3.1"]
            # Case galaxy
            elif is_galaxy[i]:
                catalog += [f"object {uid[i]} {ra[i]} {dec[i]} {mag_bulge[i]} data/imsim/Const.Template.spec.gz 0 0 0 0 0 0 sersic2D {d_bulge1[i]} {d_bulge2[i]} {angle_bulge[i]} 4 none  CCM 0.0635117705 3.1"]
                catalog += [f"object {uid[i]} {ra[i]} {dec[i]} {mag_disk[i]}  data/imsim/Const.Template.spec.gz 0 0 0 0 0 0 sersic2D {d_disk1[i]}  {d_disk2[i]}  {angle_disk[i]}  1 none  CCM 0.0635117705 3.1"]
            else:
                raise ValueError("Unknown type of object")

        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with open(filename, 'w') as f:
            print('\n'.join([header] + catalog), file=f)

        # NOTE: os.path.join keeps the log next to the catalog also when
        # 'filename' has no directory part
        elapsed = time.time() - t_start
        with open(os.path.join(dirname, "time.log"), 'w') as f:
            print(f"Wrote instance catalog in {elapsed} seconds")
            print(f"Wrote instance catalog in {elapsed} seconds", file=f)

    @staticmethod
    def get_seqnum(observation_id):
        return observation_id % 32768

    def simulate_image(
            self,
            observation_id,
            detector=94,
            maglim=None,
            selection_band=None,
            ra=None,
            dec=None,
            filename_visits=f"{ROOT}/data/baseline_v3.0_10yrs.db",
            exptime="30 30 30 30 30 30",   # string of exptimes for ugrizy
            do_lightcurve=True,
            write_agn_host_galaxy=True,
            run_galsim=True,
            dirname=None,
        ):

        """Simulate a single image corresponding to an observation_id in the cadence file"""

        # Find the observation
        df = None
        query = f"""SELECT * FROM observations WHERE observationId={observation_id}"""
        with sqlite3.connect(filename_visits) as con:
            df = pd.read_sql_query(query, con)

        # Iterate over the observations
        for idx, d in df.iterrows():

            # Figure out the directory/filenames
            observation_id = d["observationId"]
            dirname_id = ROOT + "/" + dirname + f"/{observation_id}"
            if not os.path.exists(dirname_id):
                os.makedirs(dirname_id)

            directory_output = f"{dirname_id}/output"
            directory_checkpoint = f"{dirname_id}/checkpoint"

            ###########################################################
            # Create the instance catalog
            ###########################################################
            filename_instance_catalog = f"{dirname_id}/instance_catalog.txt"

            idx_filter = "ugrizy".index(d["filter"])
            kwargs_instance_catalog = {
                "mjd": d["observationStartMJD"],
                "rightascension": d["fieldRA"],
                "declination": d["fieldDec"],
                "altitude": d["altitude"],
                "azimuth": d["azimuth"],
                "filter": idx_filter,
                "rotskypos": d["rotSkyPos"],
                "dist2moon": d["moonDistance"],
                "moonalt": d["moonAlt"],
                "moondec": d["moonDec"],
                "moonphase": d["moonPhase"],
                "moonra": d["moonRA"],
                "nsnap": 1,
                #"obshistid": observationId,
                "obshistid": 0,
                "rottelpos": d["rotTelPos"],
                "seed": observation_id,
                #"seeing": d["seeingFwhm500"],
                "seeing": d["seeingFwhmEff"],
                "sunalt": d["sunAlt"],
                "vistime": d["visitExposureTime"],
                "numexp": 0,
                "seqnum": self.get_seqnum(observation_id),
            }

            # NOTE: override ra, dec?
            if (ra is not None) and (dec is not None):
                kwargs_instance_catalog["rightascension"] = float(ra)
                kwargs_instance_catalog["declination"] = float(dec)

            # Create the instance catalog
            self.write_instance_catalog(
                filename_instance_catalog,
                maglim=maglim,
                selection_band=selection_band,
                do_lightcurve=int(do_lightcurve),
                write_agn_host_galaxy=int(write_agn_host_galaxy),
                **kwargs_instance_catalog
            )

            ###############################################################################
            # Create the imsim yaml file
            ###############################################################################
            filename_imsim_yaml = f"{dirname_id}/imsim-user-instcat_{detector}.yaml"

            texp = list(map(float, exptime.split()))[idx_filter]
            filter_name = d["filter"]
            imsim_yaml = f"""\
modules:
- imsim
template: imsim-config-instcat
input.instance_catalog.file_name: {filename_instance_catalog}
input.instance_catalog.sort_mag: False
input.checkpoint.dir: {directory_checkpoint}
input.tree_rings: ""
image.sensor: ""
output.dir: {directory_output}
output.det_num.first: {detector}
output.nfiles: 1
output.nproc: 1
output.timeout: 36000
eval_variables.fexptime: {texp}
eval_variables.sband: '{filter_name}'"""

            # Write imSim yaml to a file
            print(imsim_yaml, file=open(filename_imsim_yaml, 'w'))

            ###############################################################################
            # Run galsim
            ###############################################################################
            if not int(run_galsim):
                return

            print("Simulating image with observationId", observation_id)
            filename_imsim_log = f"{dirname_id}/galsim_{detector}.log"
            t0 = time.time()
            os.system(f"OMP_NUM_THREADS=1 galsim {filename_imsim_yaml} 2>&1 | tee {filename_imsim_log}")
            dt = time.time() - t0
            print(f"Simulated image with observationId {observationId} in {dt} seconds")


    def simulate_images(self, query, det_list, filename_visits=f"{ROOT}/data/baseline_v3.0_10yrs.db", *args, **kwargs):

        """
        Simulate one image for every (observation, detector) combination

        query: SQL query run against the visits database, the result must
            contain the column "observationId"
        det_list: whitespace-separated string of detector numbers
        filename_visits: path to the sqlite database of visits
        *args, **kwargs: forwarded unchanged to simulate_image
        """

        from contextlib import closing

        # Find the observation
        # NOTE: "with connection" alone only wraps a transaction, closing()
        # makes sure the database connection is released as well
        with closing(sqlite3.connect(filename_visits)) as con:
            df = pd.read_sql_query(query, con)

        detectors = [int(det) for det in det_list.split()]

        # Simulate the images one by one
        # NOTE: iterate the column itself so that the id keeps its integer
        # dtype (a row from iterrows may be upcast to float)
        for observation_id in df["observationId"]:
            for det in detectors:
                self.simulate_image(observation_id, det, *args, **kwargs)

    def write_reference_catalog(self, filename, maglim, selection_band):

        """
        Write a mock LSST reference catalog in csv format to the given filename

        Only non-AGN sources brighter than maglim in selection_band are
        written. Afterwards convertReferenceCatalog is run on the csv file
        and its output goes to the directory "reference_catalog" next to the
        csv file. Nothing is done if the csv file exists already.

        filename: path of the csv file relative to ROOT
        maglim: magnitude limit, converted with float()
        selection_band: band name passed to get_flux_total for the magnitude cut
        """

        import shutil

        filename = ROOT + "/" + filename

        if os.path.exists(filename):
            return

        print("Writing the reference catalog...")

        # Cut at some magnitude limit e.g. r < 24
        select = util.flux_to_mag(self.get_flux_total(selection_band)) < float(maglim)

        # Select non-variable sources
        select *= ~self.catalog["is_agn"]

        # Perform the selection
        c = self.catalog[select]

        data = {
            "source_id": c["ID"],
            "ra": c["RA"],
            "dec": c["DEC"],
        }

        # One magnitude column per LSST band, in the order u, g, r, i, z, y
        for band in "ugrizy":
            data[band] = util.flux_to_mag(self.get_flux_total(f"lsst-{band}"))[select]

        # NOTE: for pmra see here:
        #   https://gea.esac.esa.int/archive/documentation/GDR2/Gaia_archive/chap_datamodel/sec_dm_main_tables/ssec_dm_gaia_source.html#:~:text=pmra%20%3A%20Proper%20motion%20in%20right,direction%20of%20increasing%20right%20ascension.
        #   https://pipelines.lsst.io/modules/lsst.meas.algorithms/creating-a-reference-catalog.html
        # Non-finite proper motions are written as 0
        data["pmra"] = np.where(np.isfinite(c["pmracosd"]), c["pmracosd"], 0)
        data["pmdec"] = np.where(np.isfinite(c["pmdec"]), c["pmdec"], 0)
        data["epoch"] = c["RA"].size * ["2000-01-01T12:00:00"]

        # Header line followed by one line per selected source
        with open(filename, 'w') as f:
            print(','.join(data.keys()), file=f)
            for row in zip(*data.values()):
                print(','.join(f"{v}" for v in row), file=f)

        # Run the LSST tools ontop
        dirname = f"{os.path.dirname(filename)}/reference_catalog"
        if os.path.exists(dirname):
            # NOTE: the directory has to go before os.mkdir, which raises
            # when the directory exists
            shutil.rmtree(dirname)
        os.mkdir(dirname)
        os.system(f"convertReferenceCatalog {dirname} {ROOT}/python/config_reference_catalog.py {filename}")

    def match_to_file(self, filename, radius_arcsec=1.0):
        """
        Match the catalog to an LSST Science Pipelines file

        The table in filename is first copied to a FITS file (".parq" in the
        name replaced by ".fits") and then matched on the sky against
        self.filename_catalog with "stilts tmatch2". The result is written
        to "match.fits" in the current working directory.

        filename: path to the table to match against
        radius_arcsec: maximum separation in arcsec, passed to stilts as the
            "params" of the sky matcher
        """
        from astropy.table import Table

        filename_fits = filename.replace(".parq", ".fits")
        if filename_fits != filename:
            # NOTE: overwrite so that a repeated call does not fail on the
            # FITS copy left behind by the previous one
            Table.read(filename).write(filename_fits, overwrite=True)

        os.system(f"stilts tmatch2 in1={self.filename_catalog} in2={filename_fits} matcher=sky params={float(radius_arcsec)} values1='RA DEC' values2='coord_ra coord_dec' find='all' out='match.fits'")

###############################################################################
# NOTE: proper motions are only defined for stellar objects which we get from
# the LSST SIM tables directly
###############################################################################
    def get_pmracosd(self):
        """Return an array shaped like the "RA" column and filled with NaN"""
        ra = self.catalog["RA"]
        return np.full_like(ra, fill_value=np.nan)

    def get_pmdec(self):
        """Return an array shaped like the "RA" column and filled with NaN"""
        ra = self.catalog["RA"]
        return np.full_like(ra, fill_value=np.nan)

    def get_MBH(self):

        """
        Return one value per catalog row: get_log_mbh_continuity(M, Z) plus a
        zero-mean normal deviate whose scale is
        get_delta_log_mbh_shankar2019(M, low=0.00, high=np.inf)
        """

        # NOTE: use external library to assign MBH
        from mbh import get_log_mbh_continuity, get_delta_log_mbh_shankar2019

        # Mean relation, evaluated row by row
        log_mbh_mean = np.array([
            get_log_mbh_continuity(row["M"], row["Z"])
            for row in self.catalog
        ])

        # Width of the scatter, evaluated row by row
        sigma = np.array([
            get_delta_log_mbh_shankar2019(row["M"], low=0.00, high=np.inf)
            for row in self.catalog
        ])

        # Draw the scatter itself
        n_rows = self.catalog["M"].size
        scatter = np.random.normal(scale=sigma, size=n_rows)

        return log_mbh_mean + scatter

    #def get_MBH(self):
    #    # AV NOTE: this is an EXTREMELY simplified derivation of the black hole
    #    # mass, assuming linear correlation with the bulge and NO scatter. Use
    #    # with EXTREME caution.
    #    A = 500
    #    return self.catalog["M"] - np.log10(A)

    def get_log_L_FIR(self):
        """
        log10 of the far-infrared luminosity derived from the star formation rate

        Inverts the relation

            SFR(M⊙ year−1) = 4.5 × 10−44 LFIR (ergs s−1) (starbursts)

        i.e. log10(LFIR) = log10(SFR / 4.5e-44). Rows flagged "PASSIVE" get 0.

        NOTE(review): assumes the "SFR" column is linear (not log10) -- confirm
        """
        # NOTE: a non-positive SFR gives -inf/nan here, silence the warning
        # instead of failing since passive rows are replaced below anyway
        with np.errstate(divide="ignore", invalid="ignore"):
            log_L_FIR = np.log10(self.catalog["SFR"] / 4.5e-44)

        return np.where(self.catalog["PASSIVE"], 0, log_L_FIR)

    def get_log_L_GAL_5_GHz(self):

        """
        Galaxy Radio luminosity according to

            Kennicutt+98 (eq. 4) and
            Delvecchio+21 (abstract + eq. 1)

        The infrared-to-radio ratio is drawn per object as

            q_IR = A * (1 + z) ** B - C * (M - 10)

        and the radio luminosity follows from its definition

            q_IR = log10(L_IR / 3.75e12 Hz) - log10(L_radio)

        i.e. log10(L_radio) = log_L_FIR - log10(3.75e12) - q_IR

        NOTE(review): the function name says 5 GHz but no frequency
        conversion is applied to the q_IR relation -- confirm
        """

        N = self.catalog["M"].size

        # Coefficients of the q_IR relation, drawn independently per object
        A = np.random.normal(loc=2.646, scale=0.024, size=N)
        B = np.random.normal(loc=-0.023, scale=0.008, size=N)
        C = np.random.normal(loc=0.148, scale=0.013, size=N)
        q_IR = A * (1 + self.catalog["Z"]) ** B - C * (self.catalog["M"] - 10)

        # NOTE: everything here is in log10, so the 3.75e12 Hz normalization
        # has to be subtracted as a logarithm too
        return (self.catalog["log_L_FIR"] - np.log10(3.75e12)) - q_IR


    def get_log_L_AGN_5_GHz(self, scatter=True):

        """
        AGN Radio luminosity according to the fundamental plane Gultekin+19

        Evaluates

            R = A + B * X + C * mu

        with

            mu = log10(M / 1e8), taken from the "MBH" column
            R  = log10(L_R / 1e38)
            X  = log10(L_X / 1e40), taken from the "log_LX_2_10" column

        and returns R + 38.

        scatter: if True, the coefficients A, B, C are drawn per object from
            normal distributions. Otherwise their central values -0.62, 0.70
            and 0.74 are used
        """

        # The forward relation of the same paper (Eq. 5) reads
        #   mu = mu0 + xi_mu_R * R + xi_mu_X * X
        # with mu0 = 0.55 +- 0.22, xi_mu_R = 1.09 +- 0.10, xi_mu_X = -0.59 +- 0.155

        # NOTE: check gultekin sec 4.7 eq. 19
        N = self.catalog.size
        X = self.catalog["log_LX_2_10"] - 40
        mu = self.catalog["MBH"] - 8

        # NOTE: a scale of 0.0 makes every draw equal to the central value
        if scatter:
            sigma_A, sigma_B, sigma_C = 0.16, 0.085, 0.06
        else:
            sigma_A, sigma_B, sigma_C = 0.0, 0.0, 0.0

        A = np.random.normal(-0.62, sigma_A, N)
        B = np.random.normal(0.70, sigma_B, N)
        C = np.random.normal(0.74, sigma_C, N)

        # NOTE: C is the coefficient of the mass term
        R = A + B * X + C * mu
        return R + 38


    # TODO:
    #def get_tau(self):
    #    return self.get_lightcurve_agn(band="lsst-r", return_tau_sf_inf=True)[0]

    #def get_sf_inf(self):
    #    return self.get_lightcurve_agn(band="lsst-r", return_tau_sf_inf=True)[1]

    def get_is_star(self):
        """Return True for every catalog row whose "Z" value is not finite"""
        has_finite_z = np.isfinite(self.catalog["Z"])
        return np.logical_not(has_finite_z)

    def get_is_star_single(self):
        """
        Return the np.where index tuple of the rows whose "ID" lies in
        [N - N_star - N_binary, N - N_binary) with N the catalog size
        """
        n_total = self.catalog.size
        n_binary = self.binary.size
        first = n_total - self.star.size - n_binary
        last = n_total - n_binary

        ids = self.catalog["ID"]
        in_range = (first <= ids) * (ids < last)
        return np.where(in_range)

    def get_is_star_binary(self):
        """
        Return the np.where index tuple of the rows whose "ID" lies in
        [N - N_binary, N) with N the catalog size
        """
        last = self.catalog.size
        first = last - self.binary.size

        ids = self.catalog["ID"]
        in_range = (first <= ids) * (ids < last)
        return np.where(in_range)

    @staticmethod
    def _get_is_star_cepheid(label, pmode):
        return (4 <= label) * (label <= 6) * (0 <= pmode) * (pmode <= 1)

    @staticmethod
    def _get_is_star_lpv(label, pmode):
        return (7 <= label) * (label <= 8) * (0 <= pmode) * (pmode <= 4)

    def get_is_star_cepheid(self):
        """
        Return a boolean array over the catalog which is True for the single
        stars that pass _get_is_star_cepheid

        The row of self.star that belongs to a single star is its "ID" minus
        the smallest "ID" among the single stars. The result is indexed by
        "ID" as well.
        """
        ret = np.zeros_like(self.catalog, dtype=bool)
        uids = self.catalog["ID"][self.get_is_star_single()]

        # No single stars at all: nothing to flag
        if uids.size == 0:
            return ret

        idx_star = uids - np.min(uids)
        label = self.star["label"][idx_star]
        pmode = self.star["pmode"][idx_star]

        # NOTE: assign into the mask, a numpy element has no append()
        ret[uids] = self._get_is_star_cepheid(label, pmode)
        return ret

    def get_is_star_lpv(self, orich=True, crich=True):
        """
        Return a boolean array over the catalog which is True for the single
        stars that pass _get_is_star_lpv

        orich, crich: when exactly one of the two is True the selection is
            restricted further on the "c_o" column of self.star: c_o <= 1
            for orich and c_o > 1 for crich. When both are equal no such
            restriction is applied

        The row of self.star that belongs to a single star is its "ID" minus
        the smallest "ID" among the single stars. The result is indexed by
        "ID" as well.
        """
        ret = np.zeros_like(self.catalog, dtype=bool)
        uids = self.catalog["ID"][self.get_is_star_single()]

        # No single stars at all: nothing to flag
        if uids.size == 0:
            return ret

        idx_star = uids - np.min(uids)
        label = self.star["label"][idx_star]
        pmode = self.star["pmode"][idx_star]
        select = np.array(self._get_is_star_lpv(label, pmode), dtype=bool)

        if crich ^ orich:
            c_o = self.star["c_o"][idx_star]
            if orich:
                select &= c_o <= 1
            if crich:
                select &= c_o > 1

        # NOTE: assign into the mask, a numpy element has no append()
        ret[uids] = select
        return ret

    def get_is_galaxy(self):
        """Return the complement of get_is_star() + get_is_agn()"""
        is_star = self.get_is_star()
        is_agn = self.get_is_agn()
        return ~(is_star + is_agn)

    def __getitem__(self, key):
        return self.catalog[key]

    def get_observation(self, query, filename_visits=f"{ROOT}/data/baseline_v3.0_10yrs.db"):
        """
        Run an SQL query against the visits database and return the result
        as a pandas DataFrame

        query: SQL query string
        filename_visits: path to the sqlite database of visits
        """
        from contextlib import closing

        # NOTE: "with connection" alone only wraps a transaction, closing()
        # makes sure the database connection is released as well
        with closing(sqlite3.connect(filename_visits)) as con:
            return pd.read_sql_query(query, con)

    def _find_file(self, globstr):
        ret = sorted(glob.glob(globstr, recursive=True))[::-1]
        if not ret:
            raise FileNotFoundError(f"File not found with pattern: {globstr}")
        print('\n'.join(ret))
        return ret

    def get_instance_catalog(self, observation_id):
        """Look up instance_catalog.txt in the imsim directory of the observation"""
        pattern = f"{self.dirname}/imsim/{observation_id}/instance_catalog.txt"
        return self._find_file(pattern)

    def get_eimage(self, observation_id, detector):
        """Look up the eimage FITS file(s) of the observation and the zero-padded 3-digit detector"""
        pattern = f"{self.dirname}/imsim/{observation_id}/**/eimage*det{detector:03d}*.fits"
        return self._find_file(pattern)

    def get_amp(self, observation_id, detector):
        """Look up the amp .fits.fz file(s) of the observation and the zero-padded 3-digit detector"""
        pattern = f"{self.dirname}/imsim/{observation_id}/**/amp*det{detector:03d}*.fits.fz"
        return self._find_file(pattern)

    def get_visit_from_observation_id(self, observation_id):

        """
        Build the visit id of an observation. The visit id can be used to
        query butler and/or shows up in some filenames. It is the string

            50YYMMDDNNNNN

        where

            YYMMDD is the observation start date
            NNNNN is a zero-padded 5-digit sequence number
        """

        query = f"""
        SELECT * FROM observations WHERE observationId={observation_id}
        """
        observation = self.get_observation(query)
        mjd = observation["observationStartMJD"][0]

        # NOTE: there seems to be no YYMMDD subformat, so start from
        # YYYY-MM-DD and strip the dashes and the first two digits
        date_iso = Time(mjd, format="mjd").to_value("iso", subfmt="date")
        yymmdd = date_iso.replace("-", '')[2:]

        seqnum = "%05d" % self.get_seqnum(observation_id)

        # NOTE: the origin of the 50 prefix is not known, maybe it comes from
        # the instrument?
        return "50" + yymmdd + seqnum

    def get_pipeline_product(self, product, observation_id=None, detector=None, suffix="fits"):

        """
        Find the files of a pipeline product below "{self.dirname}/repo"

        product: name of the product e.g. "calexp"
        observation_id: if given, restrict to the visit of this observation
        detector: if given, restrict to this detector number
        suffix: filename extension without the dot

        Returns whatever _find_file returns for the glob pattern
        "{self.dirname}/repo/**/*{product}_*{visit}*{detector}*.{suffix}",
        with the visit and/or detector parts left out when not given
        """

        # NOTE: only look up the visit when there is an observation, the
        # default observation_id=None means "any visit"
        visit = None
        if observation_id is not None:
            visit = self.get_visit_from_observation_id(observation_id)

        # Convert detector to its name i.e. xx == Rxx_Syy
        if detector is not None:
            detector = my_lsst.get_detector(detector)

        # NOTE: the product is always there so the pattern is never empty
        items = [str(i) for i in (product + "_", visit, detector) if i is not None]
        globstr = '*'.join(items)

        return self._find_file(f"{self.dirname}/repo/**/*{globstr}*.{suffix}")

    def get_calexp(self, observation_id, detector):
        """Look up the "calexp" pipeline product of the observation and detector"""
        product = "calexp"
        return self.get_pipeline_product(product, observation_id, detector)

    def get_source_table(self, observation_id, detector):
        """Look up the "sourceTable" pipeline product (suffix "parq") of the observation and detector"""
        product = "sourceTable"
        return self.get_pipeline_product(product, observation_id, detector, suffix="parq")

    def get_ra_dec_is_in_image(self, ra, dec, observation_id, detector):

        """
        Test whether the position(s) ra, dec (in degrees) fall within the
        eimage of the given observation and detector

        Returns None if get_eimage gives nothing, otherwise the product of
        the four boundary tests (truthy where the position is in the image)
        """

        fname = self.get_eimage(observation_id, detector)
        if not fname:
            return None

        # NOTE: the first entry of the file list is used
        fits, header = fitsio.read(fname[0], header=True)
        wcs = WCS(header)

        # NOTE: world_to_pixel returns the pixel position as (x, y) i.e.
        # (column, row), whereas the image array is indexed as [row, column]
        col, row = wcs.world_to_pixel(SkyCoord(ra, dec, unit="deg"))

        return (0 <= row) * (row < fits.shape[0]) * (0 <= col) * (col < fits.shape[1])

    def get_truth_objects_in_image(self, observation_id, detector):
        """Run get_ra_dec_is_in_image on the "RA" and "DEC" columns of the catalog"""
        catalog = self.catalog
        return self.get_ra_dec_is_in_image(catalog["RA"], catalog["DEC"], observation_id, detector)

    def get_ra_dec_observation_id_detector(self, ra, dec, query=None):

        """
        Find the (observationId, detector) pairs whose image contains ra, dec

        ra, dec: position(s) in degrees. For arrays a detector counts as
            soon as any of the positions falls within its image
        query: SQL query that selects the observations to test, by default
            all of them

        At most one detector is returned per observation (the first match).
        Detectors for which get_ra_dec_is_in_image raises FileNotFoundError
        are skipped.
        """

        ret = []
        if query is None:
            query = """SELECT * FROM observations"""

        for observation_id in self.get_observation(query)["observationId"]:
            # Detector numbers 0..188
            for detector in np.arange(189):
                try:
                    is_in_image = self.get_ra_dec_is_in_image(ra, dec, observation_id, detector)
                except FileNotFoundError:
                    continue

                # NOTE: having the file is not enough, the position itself
                # has to be within the image
                if is_in_image is not None and np.any(is_in_image):
                    ret.append((observation_id, detector))
                    break

        return ret

    def get_row_col_ra_dec(self, ra, dec, pm_ra_cosdec, pm_dec, observation_id, detector):

        """
        Return the row and column of a given ra, dec taking into account proper motion

        The position is first propagated with util.get_ra_dec to the
        "observationStartMJD" of the observation and then converted to
        pixels with the WCS of the calexp of the observation and detector.

        Returns the tuple (row, col)
        """

        observation = self.get_observation(f"""SELECT * FROM observations WHERE observationId={observation_id}""")
        # NOTE(review): mjd is passed on as a pandas Series -- confirm that
        # util.get_ra_dec handles this
        mjd = observation["observationStartMJD"]
        ra, dec = util.get_ra_dec(ra, dec, pm_ra_cosdec, pm_dec, mjd=mjd)

        # NOTE: get_calexp returns a list of files, read the first one
        filename_calexp = self.get_calexp(observation_id, detector)[0]
        _, header = fitsio.read(filename_calexp, header=True)
        wcs = WCS(header)

        # NOTE: world_to_pixel returns the pixel position as (x, y) i.e.
        # (column, row)
        coord = SkyCoord(ra, dec, unit="deg")
        col, row = wcs.world_to_pixel(coord)
        return row, col


    def _match(self, observation_id, detector, order12=True, catalog=None):

        """
        Nearest-neighbour sky match between a truth catalog and the source
        table of the given observation and detector

        order12: if True match the catalog to the source table, otherwise
            the source table to the catalog
        catalog: catalog with "RA" and "DEC", self.catalog if None

        Returns the index and the on-sky separation from match_coordinates_sky
        """

        from astropy.table import Table
        from astropy.coordinates import SkyCoord
        from astropy.coordinates import match_coordinates_sky

        truth = self.catalog if catalog is None else catalog
        coord_truth = SkyCoord(truth["RA"], truth["DEC"], unit="deg")

        # Read the first source table that was found
        source = Table.read(self.get_source_table(observation_id, detector)[0])
        coord_source = SkyCoord(source["coord_ra"], source["coord_dec"], unit="deg")

        # Pick the direction of the match
        if order12:
            pair = (coord_truth, coord_source)
        else:
            pair = (coord_source, coord_truth)

        idx, sep2d, _ = match_coordinates_sky(*pair)
        return idx, sep2d

    def get_match_truth_to_source(self, observation_id, detector, catalog=None):
        """Call _match with order12=True and the given catalog"""
        options = {"order12": True, "catalog": catalog}
        return self._match(observation_id, detector, **options)

    def get_match_source_to_truth(self, observation_id, detector, catalog=None):
        """
        Call _match with order12=False and the given catalog

        catalog: forwarded to _match, None by default
        """
        # NOTE: catalog has to be a parameter, it is not defined otherwise
        return self._match(observation_id, detector, order12=False, catalog=catalog)

    def get_match_to_truth(self, observation_id, detector):

        """
        Match the source table of the given observation and detector to
        self.catalog

        Returns the index and the on-sky separation from _match
        """

        # NOTE: this is _match with the source table as the first catalog
        return self._match(observation_id, detector, order12=False)

    def get_lightcurve_cepheid(self, idxs, band):
        """
        Collect the normalized cepheid lightcurves of the given objects

        idxs: iterable of object ids
        band: band name passed to get_lightcurve_normalized_cepheid

        Returns two arrays: the first and the second output of
        get_lightcurve_normalized_cepheid for every id
        """
        from CEP_lsst_quasar_project import get_lightcurve_normalized_cepheid

        # Row of self.star = id minus the largest galaxy "ID"
        # NOTE(review): compare with get_is_star_cepheid, which offsets by
        # the smallest single-star ID instead -- confirm
        id_galaxy_max = self.galaxy["ID"].max()

        tt = []
        lc = []
        for idx in idxs:
            uid_star = idx - id_galaxy_max
            _tt, _lc = get_lightcurve_normalized_cepheid(self.star[uid_star], band)
            tt.append(_tt)
            lc.append(_lc)
        return np.array(tt), np.array(lc)

    @staticmethod
    def _get_radius_star(mass, log_gravity_surface):
        """
        Radius of a star from its mass and surface gravity

            r = sqrt(GM / g)

        mass: mass in units of Msun
        log_gravity_surface: log10 of the surface gravity in cm / s^2

        Returns the radius in units of Rsun
        """
        g = 10 ** log_gravity_surface * u.cm / u.s ** 2
        m = mass * u.M_sun
        # NOTE: G comes from astropy.constants
        return np.sqrt(constants.G * m / g).to(u.R_sun)


def get_catalog_from_config(filename):
    """
    Create a CatalogGalaxyAGN from a configuration file

    The section "egg" of the file configures the Egg run (which is started
    with overwrite=0) and the section "combined" gives the keyword arguments
    of CatalogGalaxyAGN. The file is parsed with extended interpolation.
    """
    import configparser
    from egg import Egg

    parser = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation())
    parser.read(filename)

    generator = Egg(parser["egg"])
    generator.run(overwrite=0)

    filename_catalog = generator.get_filename()
    return CatalogGalaxyAGN(filename_catalog, **parser["combined"])
