#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-02-13 16:20:51

"""
Plot Galaxy+AGN mocks for the LSST Italian AGN in-kind contribution
"""

import argparse
from itertools import product
import os
import subprocess
import sys
import re
import glob

from astropy.cosmology import FlatLambdaCDM
from scipy.stats import binned_statistic
import astropy.units as u
import fitsio
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import interp1d
import pandas as pd

import plambda
import util
from smf import StellarMassFunctionCOSMOS2020AGN
from my_lsst import limiting_magnitude_10yr, seeing
import bongiorno2016
import sed

# Redshift bins shared by the per-redshift panels below, taken from the
# "quiescent" binning of the COSMOS2020 AGN stellar mass function helper.
# The plotting code iterates over them as (zlo, zhi) pairs.
ZBINS = StellarMassFunctionCOSMOS2020AGN().get_zbins("quiescent")

# Flat LambdaCDM cosmologies with Om0 = 0.30, for H0 = 70 and H0 = 100
COSMO70  = FlatLambdaCDM(Om0=0.30, H0=70)
COSMO100 = FlatLambdaCDM(Om0=0.30, H0=100)


def get_figure(rows, cols, *args, **kwargs):
    """
    Create a figure holding a ``rows`` x ``cols`` grid of subplots.

    Unless the caller passes its own ``figsize`` keyword, the figure is sized
    so that every panel gets 6.4 x 4.8 inches.  All other positional and
    keyword arguments are forwarded to ``plt.subplots`` unchanged.

    Returns whatever ``plt.subplots`` returns (the figure and its axes).
    """
    # setdefault instead of a hard-coded keyword: an explicit figsize from the
    # caller used to fail with "got multiple values for keyword argument".
    kwargs.setdefault("figsize", (cols * 6.4, rows * 4.8))
    return plt.subplots(rows, cols, *args, **kwargs)


class Plot:

    def __init__(self, dirname, catalog):

        print("Initializing the plotter...")
        self.dirname = dirname
        self.catalog = catalog

        if not os.path.exists(dirname):
            os.makedirs(dirname)

    def plot(self):

        print("Plotting everything...")

        # Call all the plot_ functions in sequence
        for p in sorted(dir(self)):

            if not p.startswith("plot_"):
                continue

            fun = getattr(self, f"{p}")
            if not callable(fun):
                continue

            savefig = f"{self.dirname}/{p.replace('plot_', '')}.pdf"
            if os.path.exists(savefig):
                continue

            print("Calling...", p)
            fun()
            plt.savefig(savefig)
            plt.close("all")

    def _get_c(self, zlo, zhi, t="all", mlo=0.0, mhi=np.inf, llo=0.0, lhi=np.inf, is_agn=None, is_optical_type1=None, is_optical_type2=None):
        select = (
            np.isfinite(self.catalog.catalog["Z"]) *
            (zlo <= self.catalog.catalog["Z"]) * (self.catalog.catalog["Z"] < zhi) *
            (mlo <= self.catalog.catalog["M"]) * (self.catalog.catalog["M"] < mhi) *
            (llo <= self.catalog.catalog["log_LX_2_10"]) * (self.catalog.catalog["log_LX_2_10"] < lhi)
        )
        if t != "all":
            select *= self.catalog.catalog["PASSIVE"] == (t == "quiescent")

        if is_agn is True: select *= self.catalog.catalog["is_agn"]
        if is_optical_type2 is True: select *= self.catalog.catalog["is_optical_type2"]
        if is_optical_type1 is True: select *= ~self.catalog.catalog["is_optical_type2"]

        return self.catalog.catalog[select]

    def plot_stellar_mass_function_quiescent_starforming(self):
        """
        Plot the stellar mass function per redshift bin, split by galaxy type.

        One panel is drawn per entry of ZBINS.  For each combination of
        star-forming/quiescent and galaxy/AGN the panel shows

        * a line with the EGG input mass function evaluated at the bin centre
          (for AGN: the curve from the "_agn" file minus the galaxy curve),
        * the mock measurement from ``util.get_key_function`` with error bars;
          for AGN the catalog column "is_agn_ctn" is passed as third argument.

        Both are plotted as log10 of the number density.
        """

        # Imported once here rather than on every pass of the innermost loop
        from egg import Egg

        dmbin = 0.10
        mbins = np.arange(8.5, 12.5 + dmbin / 2, dmbin)
        mcens = 0.5 * (mbins[:-1] + mbins[1:])

        types = "starforming", "quiescent"

        # Legend label for every (is_agn, type) combination
        labels = {
            (False, "starforming"): "Starforming galaxy",
            (False, "quiescent"): "Quiescent galaxy",
            (True, "starforming"): "Starforming AGN",
            (True, "quiescent"): "Quiescent AGN",
        }

        fig, axes = get_figure(4, 3)
        # NOTE(review): axes[0, 0] is removed here, yet the zip() below still
        # starts with that axes, so the first ZBINS entry appears to be drawn
        # on an axes that is no longer part of the figure -- confirm intended.
        axes[0, 0].remove()
        axes[-1, -1].remove()

        for ax, (zlo, zhi) in zip(axes.flatten(), ZBINS):

            zmid = 0.5 * (zlo + zhi)

            # The catalog selection depends on the type only, so it is done
            # once per type instead of once per (is_agn, type) combination
            selected = {t: self._get_c(zlo=zlo, zhi=zhi, t=t) for t in types}

            for is_agn in False, True:

                for t in types:

                    # Colours C0..C3: +1 for quiescent, +2 for AGN
                    color = "C%d" % ((t == "quiescent") + 2 * is_agn)
                    label = labels[is_agn, t]

                    # Plot the EGG SMF
                    key = "ACTIVE" if t == "starforming" else "PASSIVE"
                    x, y = Egg.get_smf(zmid, key, "opt/egg/share/mass_func_cosmos2020.fits")
                    if is_agn:
                        __, y2 = Egg.get_smf(zmid, key, "opt/egg/share/mass_func_cosmos2020_agn.fits")
                        y = y2 - y
                    ax.plot(x, np.log10(y), color=color)

                    # Plot the Mock SMF
                    c = selected[t]
                    w = c["is_agn_ctn"] if is_agn else None
                    x, dx, y, dy = util.get_key_function(mbins, c["M"], w, zmin=zlo, zmax=zhi, area_deg2=self.catalog.area)
                    y, dy1, dy2 = util.get_log_y_lo_hi(y, dy)
                    ax.errorbar(mcens, y, (dy1, dy2), marker='.', linestyle='none', color=color, label=label)

            ax.set_xlim(9.0, 12.5)
            ax.set_ylim(-8, -1)
            ax.set_xlabel(r"$\log M_{\rm star}$ [Msun]")
            ax.set_ylabel(r"$\log \Phi_{\rm star}$ [1/Mpc$^3$/dex]")
            ax.text(0.90, 0.90, r"$z = %.1f - %.1f$" % (zlo, zhi), transform=ax.transAxes, horizontalalignment="right")
            ax.legend(loc="lower left")


#    def plot_lambda_sar_function_quiescent_starforming(self):
#
#        fig, axes = get_figure(3, 4)
#
#        for ax, (zlo, zhi) in zip(axes.flatten(), ZBINS):
#
#            for t, color in [
#                ("all", 'k'),
#                ("quiescent", 'red'),
#                ("star-forming", 'blue'),
#            ]:
#
#                c = self._get_c(zlo=zlo, zhi=zhi, t=t)
#
#                # Galaxy SMF
#                kwargs = {
#                    "bins": np.linspace(32, 36, 41),
#                    "x": c["log_lambda_SAR"],
#                    "zmin": zlo,
#                    "zmax": zhi,
#                    "area_deg2": self.catalog.area
#                }
#                x, dx, y, dy = util.get_key_function(**kwargs)
#                ax.errorbar(x, y, xerr=dx, yerr=dy, marker='.', linestyle="none", color=color, label="Mock galaxy " + t)
#
#                # AGN SMF
#                kwargs = {
#                    "bins": np.linspace(32, 36, 41),
#                    "x": c["log_lambda_SAR"][c["is_agn"]],
#                    "zmin": zlo,
#                    "zmax": zhi,
#                    "area_deg2": self.catalog.area
#                }
#                x, dx, y, dy = util.get_key_function(**kwargs)
#                ax.errorbar(x, y, xerr=dx, yerr=dy, marker='x', linestyle="none", color=color, label="Mock AGN " + t)
#
#                ax.semilogy()
#                ax.set_xlabel(r"$\log M_\star \ [M_\odot]$")
#                ax.set_xlabel(r"$\Phi_{M_\star} \ [\mathrm{Mpc}^{-3}\,\mathrm{dex}^{-1}$")
#                ax.set_ylim(1e-8, 1e-1)
#
#
#
#    #def _plot_function(
#    #        self,
#    #        zbins,
#    #        key,
#    #        bins,
#    #        types,
#    #        xlim,
#    #        ylim,
#    #        xlabel,
#    #        ylabel,
#    #        plot_galaxy=False,
#    #        plot_agn=False,
#    #        plot_agn_ctn=False,
#    #        plot_agn_ctk=False,
#    #        func_ref=None,
#    #):
#
#    #    fig, axes = get_figure(len(zbins), 1)
#
#    #    for ax, zbin in zip(axes.flatten(), zbins):
#
#    #        # Plot the quiescent/star-forming functions
#    #        for t in types:
#
#    #            c = self._get_c(*zbin, t=t)
#
#    #            color = {
#    #                "all": "black",
#    #                "quiescent": "orange",
#    #                "star-forming": "purple",
#    #            }[t]
#
#    #            # Plot the reference function
#    #            if func_ref is not None:
#    #                y1 = func_ref(bins, zbin[0], t)
#    #                y2 = func_ref(bins, zbin[1], t)
#    #                ax.fill_between(bins, y1, y2, alpha=0.20, color=color)
#
#    #            kwargs = {
#    #                "bins": bins,
#    #                "x": c[key],
#    #                "zmin": zbin[0],
#    #                "zmax": zbin[1],
#    #                "area_deg2": self.catalog.area
#    #            }
#
#    #            # Plot the galaxy/AGN functions
#    #            for marker, flag, values, label in [
#    #                (".", plot_galaxy, None, "Galaxy"),
#    #                ("x", plot_agn, c["duty_cycle"], "AGN"),
#    #                ("v", plot_agn_ctn, c["duty_cycle"] * ~c["is_agn_ctn"], "CTN AGN"),
#    #                ("^", plot_agn_ctk, c["duty_cycle"] *  c["is_agn_ctk"], "CTK AGN"),
#    #            ]:
#    #                    if not flag: continue
#    #                    x, dx, y, dy = util.get_key_function(values=values, **kwargs)
#    #                    ax.errorbar(x, y, xerr=dx, yerr=dy, marker=marker, linestyle="none", color=color, label=label)
#
#    #        ax.text(0.90, 0.90, r"$%.2f \leq z < %.2f$" % zbin, transform=ax.transAxes, horizontalalignment="right")
#    #        ax.set_xlabel(xlabel)
#    #        ax.set_ylabel(ylabel)
#    #        ax.set_xlim(xlim)
#    #        ax.set_ylim(ylim)
#    #        ax.semilogy()
#    #        ax.legend()
#
#
#    ##def plot_stellar_mass_function_galaxy_agn(self):
#
#    ##    smf = StellarMassFunctionCOSMOS2020AGN()
#    ##    func_ref = lambda m, z, t: smf.get_stellar_mass_function(10 ** m, z, t)
#
#    ##    self._plot_function(
#    ##        zbins=ZBINS,
#    ##        key="M",
#    ##        bins=np.arange(8.5, 13.0 + 1e-6, 0.10),
#    ##        types=["all", "quiescent", "star-forming"],
#    ##        xlim=(8.5, 13.0),
#    ##        ylim=(1e-8, 1e-1),
#    ##        xlabel=r"$\log M_*\ [M_\odot]$",
#    ##        ylabel=r"$\Phi(M_*)\ [1/\mathrm{Mpc}^3/\mathrm{dex}]$",
#    ##        func_ref=func_ref,
#    ##        plot_galaxy=True,
#    ##        plot_agn=True,
#    ##    )
#
#    #def plot_lambda_sar_function(self):
#    #    from bongiorno2016 import get_Phi_lambda_SAR
#    #    func_ref = lambda l, z, t: get_Phi_lambda_SAR(10 ** l, z)
#    #    self._plot_function(
#    #        zbins=ZBINS,
#    #        key="log_lambda_SAR",
#    #        bins=np.arange(32.0, 36.0 + 1e-6, 0.10),
#    #        types=["all"],
#    #        xlim=(32, 36),
#    #        ylim=(1e-8, 1e-1),
#    #        xlabel=r"$\log \lambda_\mathrm{SAR}$",
#    #        ylabel=r"$\Phi(\lambda_\mathrm{SAR})\ [1/\mathrm{Mpc}^3/\mathrm{dex}]$",
#    #        func_ref=func_ref,
#    #        plot_agn_ctn=True,
#    #        plot_agn_ctk=False,
#    #    )

    def plot_xray_luminosity_function(self):

        import bongiorno2016
        import util
        import xlf

        dbin = 0.10
        bins = np.arange(42, 46 + dbin / 2, dbin)
        cens = 0.5 * (bins[:-1] + bins[1:])

        miy15 = np.loadtxt("data/miyaji2015/xlf_miyaji2015_table4.dat")

        fig, axes = plt.subplots(4, 3, figsize=(3 * 6.4, 4 * 4.8))
        axes[ 0,  0].remove()
        axes[-1, -1].remove()

        for ax, (zlo, zhi, zcen) in zip(axes.flatten(), np.unique(miy15[:, :3], axis=0)):

            print(zlo, zhi, zcen)

            # Plot Bon+16
            xlf_bon16_0 = bongiorno2016.get_Phi_lx(cens, zcen)
            xlf_bon16_1 = bongiorno2016.get_Phi_lx(cens, zlo)
            xlf_bon16_2 = bongiorno2016.get_Phi_lx(cens, zhi)
            _y1 = np.log10(xlf_bon16_1)
            _y2 = np.log10(xlf_bon16_2)
            y0 = np.log10(xlf_bon16_0)
            y1 = np.minimum(_y1, _y2)
            y2 = np.maximum(_y1, _y2)
            label = "Bon+16" if zcen < 2.50 else "Bon+16 + Hi-z"
            ax.plot(cens, y0, label=label)
            ax.fill_between(cens, y1, y2, alpha=0.20)

            # Plot the mock 
            c = self._get_c(zlo=zlo, zhi=zhi)
            x, dx, y, dy = util.get_key_function(bins, c["log_LX_2_10"], c["is_agn_ctn"], zmin=zlo, zmax=zhi, area_deg2=self.catalog.area)
            y, dy1, dy2 = util.get_log_y_lo_hi(y, dy)
            ax.errorbar(cens, y, (dy1, dy2), linestyle="none", marker=".", label="Mock CTN AGN")

            # Plot the mock with a B band cut
            for BB in 17, 18:
                B = util.flux_to_mag(c["lsst-g_total"])
                select = B >= BB
                x, dx, y, dy = util.get_key_function(bins, c["log_LX_2_10"][select], c["is_agn_ctn"][select], zmin=zlo, zmax=zhi, area_deg2=self.catalog.area)
                y, dy1, dy2 = util.get_log_y_lo_hi(y, dy)
                ax.errorbar(cens, y, (dy1, dy2), linestyle="none", marker=".", label=r"Mock CTN AGN $B \geq %d$" % BB)

            # Plot Miy+15 data points
            select = miy15[:, 2] == zcen
            y, dy1, dy2 = util.get_log_y_lo_hi(miy15[select, -2], miy15[select, -1])
            ax.errorbar(miy15[select, 5], y, (dy1, dy2), linestyle="none", marker='s', fillstyle="none", label="Miy+15")

            # Plot Miy+15 2PL model
            miy15_2pl = np.loadtxt("data/miyaji2015/2PL_miyaji2015_table3.dat")
            select = miy15_2pl[:, 0] == zlo
            A_44 = miy15_2pl[select, 3]
            log_L_star  = miy15_2pl[select, 7]
            gamma_1  = miy15_2pl[select, 10]
            gamma_2  = miy15_2pl[select, 13]
            div1 = 1e44 / 10 ** log_L_star
            div2 = 10 ** cens / 10 ** log_L_star
            xlf_miy15_2pl = A_44 * np.ma.true_divide(div1 ** gamma_1 + div1 ** gamma_2, div2 ** gamma_1 + div2 ** gamma_2)
            ax.plot(cens, np.log10(xlf_miy15_2pl), label="Miy+15 2PL", linestyle='dotted')

            # Plot Ueda+2014
            filename_ueda = sorted(glob.glob("opt/quasarlf/data/ueda2014/hard*.dat"))
            zmin = np.array([float(re.findall("z(.*)-(.*).dat", f)[0][0]) for f in filename_ueda])
            zmax = np.array([float(re.findall("z(.*)-(.*).dat", f)[0][1]) for f in filename_ueda])
            zcen_ueda = (zmin * zmax) ** .5
            idx = np.argmin(np.abs(zcen - zcen_ueda))
            x, _y, _y1, _y2 = np.loadtxt(filename_ueda[idx]).T
            y = np.log10(_y)
            y_lo = np.log10(_y1)
            y_hi = np.log10(_y2)
            ax.errorbar(x, y, (y - y_lo, y_hi - y), label="Ued+14", marker='.', linestyle='none')

            # Set the text etc..
            ax.set_xlim(42, 46)
            ax.set_ylim(-10, -3)
            ax.text(0.9, 0.9, r"$z = %.1f - %.1f$" % (zlo, zhi), transform=ax.transAxes, horizontalalignment="right")
            ax.legend(loc="lower left")
            ax.set_xlabel(r"$\log L_{\rm X}$ [erg/s]")
            ax.set_ylabel(r"$\log \Phi_{\rm X}$ [1/Mpc$^3$/dex]")

#    def plot_xray_luminosity_function(self):
#
#        miyaji15 = np.loadtxt("data/miyaji2015/xlf_miyaji2015_table4.dat", usecols=(0, 1, 5, -2, -1))
#        zbins = np.sort(np.unique(np.concatenate([miyaji15[:, 0], miyaji15[:, 1]])))
#
#        fig, axes = get_figure(4, 3, tight_layout=True)
#        axes[-1, -1].remove()
#
#        for ax, zmin, zmax in zip(axes.flatten(), zbins[:-1], zbins[1:]):
#
#            # Plot the mock
#            c = self._get_c(zlo=zmin, zhi=zmax, t="all")
#            x, dx, y, dy = util.get_key_function(
#                np.linspace(42, 48, 61),
#                c["log_LX_2_10"][c["is_agn_ctn"]],
#                zmin=zmin,
#                zmax=zmax,
#                area_deg2=self.catalog.area
#            )
#            ax.errorbar(x, y, xerr=dx, yerr=dy, linestyle="none", marker='.', label="Mock CTN AGN")
#
#            # Plot Miyiaji+15
#            select_miyaji = (miyaji15[:, 0] == zmin) * (miyaji15[:, 1] == zmax)
#            ax.errorbar(
#                miyaji15[select_miyaji][:, 2],
#                miyaji15[select_miyaji][:, 3],
#                yerr=miyaji15[select_miyaji][:, 4],
#                marker='.',
#                linestyle="none",
#                label="Miyaji+ 15",
#            )
#            ax.semilogy()
#            ax.set_title(r"$%.3f < z < %.3f$" % (zmin, zmax))
#
#            ## Plot Bon+16 estimate
#            #from bongiorno2016 import get_Phi_lx
#            #phi_bon16_lo = get_Phi_lx(x, zmin)
#            #phi_bon16_hi = get_Phi_lx(x, zmax)
#            #ax.fill_between(x, phi_bon16_lo, phi_bon16_hi, label="AGN XLF Bon+16", alpha=0.10)
#
#        for ax in axes[-1, :]:
#            ax.set_xlabel(r"$\log L_\mathrm{X}\ [\mathrm{erg}\,\mathrm{s}^{-1}]$")
#        for ax in axes[:, 0]:
#            ax.set_ylabel(r"$\Phi_{L_\mathrm{X}}\ [\mathrm{Mpc}^{-3}\,\mathrm{dex}^{-1}]$")
#        for ax in axes.flatten():
#            ax.legend()
#            ax.set_xlim(42, 48)
#            ax.set_ylim(1e-10, 1e-2)
#
#    def plot_bivariate_distribution_function(self):
#
#        import bongiorno2016
#
#        fig, axes = get_figure(len(ZBINS), 3)
#        dz = 0.20
#        dm = 0.10
#        dl = 0.10
#
#        mvec = np.arange(8.5, 12.5 + 1e-6, dm)
#        lvec = np.arange(32.0, 36.0 + 1e-6, dl)
#        m, l = np.meshgrid(mvec[:-1] + dm / 2, lvec[:-1] + dl /2)
#        norm = mpl.colors.LogNorm(vmin=1e-10, vmax=1e-2)
#        cmap = mpl.cm.Oranges
#
#        for i, (zlo, zhi) in enumerate(ZBINS):
#            c = self._get_c(zlo, zhi, "all")
#            volume = util.get_volume(zmin=zlo, zmax=zhi, area_deg2=self.catalog.area)
#            Psi1 = axes[i, 0].hist2d(
#                    c["M"],
#                    c["log_lambda_SAR"],
#                    bins=(mvec, lvec),
#                    weights=c["is_agn_ctn"] / volume / dm / dl,
#                    norm=norm,
#                    cmap=cmap,
#            )[0].T
#
#            Psi2 = bongiorno2016.get_Psi(10 ** m, 10 ** l, np.mean([zlo, zhi]))
#            im = axes[i, 1].imshow(
#                    Psi2,
#                    extent=(8.5, 12.5, 32, 36),
#                    aspect="auto",
#                    norm=norm,
#                    cmap=cmap,
#            )
#            plt.colorbar(im, ax=axes[i, 1])
#            counts = np.histogram2d(c["M"], c["log_lambda_SAR"], bins=(mvec, lvec))[0].T
#            relerr = np.ma.true_divide((Psi1 - Psi2), Psi2)
#            relerr = np.where(counts < 10, np.nan, relerr)
#            axes[i, 2].imshow(
#                    relerr,
#                    vmin=-1.00,
#                    vmax=1.00,
#                    cmap=mpl.cm.coolwarm,
#                    extent=(8.5, 12.5, 32, 36),
#            )
#
#            axes[i, 1].set_title(f"{zlo} < z < {zhi}")
#
#    def plot_plambda(self):
#
#        from bongiorno2016 import get_plambda
#        from plambda import get_plambda_bon16_air18
#
#        fig, axes = get_figure(len(ZBINS), 4)
#        lbins = np.linspace(32, 36, 41)
#        lvec = np.linspace(32, 36, 401)
#
#        for row, (zlo, zhi) in enumerate(ZBINS):
#
#            for col, mlo, mhi in [
#                (0, 8.50, 9.50),
#                (1, 9.50, 10.5),
#                (2, 10.5, 11.5),
#                (3, 11.5, 12.5),
#            ]:
#
#                pt = bongiorno2016.get_plambda(mlo, mhi, (zlo + zhi) / 2, lvec)
#                _, pq, ps = get_plambda_bon16_air18(mlo, mhi, (zlo + zhi) / 2, lvec)
#
#                print(row, zlo, zhi, mlo, mhi)
#
#                for color, t, p in [
#                    ("black", "all",          pt),
#                    ("red",   "quiescent",    pq),
#                    ("blue",  "star-forming", ps),
#                ]:
#
#                    fact = 1
#                    if t != "all":
#                        fact = 0.01
#                    axes[row, col].plot(lvec, p / fact, color=color, linestyle="dotted", linewidth=2)
#
#                    c = self._get_c(zlo, zhi, t, mlo, mhi)
#                    if c.size == 0:
#                        continue
#
#                    weights = np.full(c.size, 1 / c.size / 0.1)
#                    ret = axes[row, col].hist(
#                        c["log_lambda_SAR"],
#                        bins=lbins,
#                        weights=weights,
#                        color=color,
#                        histtype="step",
#                        linestyle="solid",
#                    )[0]
#
#                    axes[row, col].semilogy()
#                    axes[row, col].set_xlim(32, 36)
#                    axes[row, col].set_ylim(1e-4, 1e1)
#                    axes[row, col].set_xlabel("log_lambda_SAR")
#                    axes[row, col].set_ylabel("p(lambda_SAR | Mstar, z) [1/dex]")
#
#    def plot_logn_logs(self):
#
#        logSbins = np.linspace(-20, -10, 101)
#        Sbins = 10 ** logSbins
#        Scens = (Sbins[:-1] + Sbins[1:]) / 2
#        logScens = np.log10(Scens)
#
#        fig = plt.figure(figsize=(8 * 0.6, 9 * 0.6))
#        ax = fig.gca()
#
#        # Plot Luo+17 CDF-S 7 Ms
#        from logn_logs import get_N_above_S_AGN
#        Scens, logN_logS = get_N_above_S_AGN(
#            Sbins * u.erg / u.cm ** 2 / u.s,
#            Emin=2.0 * u.keV,
#            Emax=7.0 * u.keV
#        )
#
#        # Plot Luo+17 model
#        ax.plot(
#            np.log10(Scens.value),
#            logN_logS.value,
#            label="7 Ms CDF-S Model (Luo+ 17)",
#            zorder=10,
#            linestyle='dotted'
#        )
#
#        # Plot Luo+17 points
#        x, y = np.loadtxt("data/luo2017/Default Dataset.csv").T
#        ax.plot(np.log10(x), y, '.', label="7 Ms CDF-S Measured (Luo+ 17)", zorder=10)
#
#        def N_above_S(log_fx, area_deg2=self.catalog.area):
#            counts = []
#            error = []
#            c = self.catalog.catalog
#            counts = np.array([(log_fx > S).sum() for S in logSbins[:-1]])
#            error = counts ** .5
#            return logScens, .5 * np.diff(logSbins), counts / area_deg2, error / area_deg2
#
#        def get_N_above_S_aird2015(log_fx):
#            # Scale aird+2015 sensitivity curvesto the area of the mock
#            aird2015 = np.loadtxt("data/aird2015/aird2015_fig1.dat")
#            log_flux_aird2015 = aird2015[:, 0]
#            area_aird2015 = np.sum(aird2015[:, 1:], axis=1)
#            area_aird2015 *= self.catalog.area / area_aird2015.max()
#
#            area = np.interp(log_fx, log_flux_aird2015, area_aird2015)
#            logn_logs = []
#            dlogn_logs = []
#
#            for logS in logSbins[:-1]:
#                select = log_fx > logS
#                logn_logs.append(
#                    (np.ones_like(log_fx)[select] / area[select]).sum()
#                )
#                dlogn_logs.append(1)
#            return logScens, .5 * np.diff(logSbins), logn_logs, dlogn_logs
#
#        if False:
#            # Plot Gilli+07 mock
#            m = fitsio.read("data/g07_a10deg2.fits", ext=1)
#            x, dx, y, dy = N_above_S(np.log10(m["fh"]), area_deg2=10)
#            ax.errorbar(x, y, xerr=dx, yerr=dy, marker='s', linestyle="none", label="Gilli+ 07", fillstyle='none')
#
#        if False:
#            # Plot Luo+17 catalog
#            # NOTE: for some reason fitsio fails to open the file
#            from astropy.io import fits
#            m = fits.open("data/CDFS_Xraycatalog_Luo2017.fit")[1].data
#            x, dx, y, dy = N_above_S(
#                np.log10(m["FHB"]),
#                area_deg2=484.2 * u.arcmin.to(u.deg) ** 2
#            )
#            ax.errorbar(x, y, xerr=dx, yerr=dy, marker='d', linestyle="none", label="Luo+ 17 catalog", fillstyle='none')
#
#        # NOTE: luo+17 assumes Gamma=1.4 while Aird+ assume Gamma=1.9. This
#        # does not affect the logN-logS drastically
#        log_fx_2_7 = self.catalog.catalog["log_FX_2_7"]
#
#        select0 = self.catalog.catalog["is_agn_ctn"]
#        select1 = self.catalog.catalog["is_agn_ctn"] * (self.catalog.catalog["log_LX_2_10"] > 42)
#        for s, k in [
#            (select0, "Mock CTN AGN"),
#            (select1, "Mock CTN AGN logLX>42"),
#        ]:
#            x, dx, y, dy = N_above_S(log_fx_2_7[s])
#            ax.fill_between(x, y - dy, y + dy, alpha=0.3, label=k, color='red')
#
#        # Plot bon+16 theory
#        from util import convert_flux
#        x, y1, y2 = np.loadtxt("src/bon16/logn_logs.txt", usecols=(0, 1, -1)).T
#        x = np.log10(convert_flux(x, 2, 10, 2, 7, Gamma=1.4))
#        #ax.plot(x, y1, linestyle="dashed", label="Bon+16 theory logM > 9.5")
#        #ax.plot(x, y1, linestyle="dashed", label="Bivariate distribution (Bon+16)")
#        ax.plot(x, y2, '.', label="XMM-COSMOS (Cap+09)")
#
#        # Plot flux limits for XMM-COSMOS
#        flim_cosmos = convert_flux(9.3e-15, 2, 10, 2, 7)
#        ax.axvline(np.log10(flim_cosmos), linestyle='dotted')
#        ax.text(np.log10(flim_cosmos), 1e3, "XMM-COSMOS limit", rotation=90, fontsize="small")
#
#        # Plot flux limits for CDF-S 7 Ms
#        flim_cdfs = 2.7e-17
#        ax.axvline(np.log10(flim_cdfs), linestyle='dotted')
#        ax.text(np.log10(flim_cdfs), 5e2, "CDF-S 7 Ms limit", rotation=90, fontsize="small")
#
#        ax.set_xlim(-17.4, -13)
#        ax.set_ylim(1e0, 1e5)
#        ax.set_ylim(5, 60000)
#        ax.set_xlabel("log S [2-7 keV, erg/cm2/s]")
#        ax.set_ylabel("N(>S) [1/deg2]")
#        ax.legend(loc="lower left")
#        ax.semilogy()
#
#        return fig, ax
#
#
##    def plot_stellar_mass_function_quiescent_starforming(self):
##
##        from smf import (
##            StellarMassFunctionCandels,
##            StellarMassFunctionCOSMOS2020,
##            StellarMassFunctionBongiorno2016,
##            StellarMassFunctionAird2018
##        )
##
##        mvec = np.linspace(8, 12, 41)
##        smf_ref = StellarMassFunctionCOSMOS2020()
##
##        types = ["all", "star-forming", "quiescent"]
##
##        nrows = len(smf_ref.get_zbins())
##        ncols = len(types)
##
##        fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 6.4, nrows * 4.8), sharey=False)
##
##        for row, zbin in enumerate(smf_ref.get_zbins()):
##
##            z = np.mean(zbin)
##
##            # Redshift selection
##            select_z = np.ones_like(self.catalog.catalog["Z"], dtype=bool)
##            select_z *= zbin[0] <= self.catalog.catalog["Z"]
##            select_z *= self.catalog.catalog["Z"] < zbin[1]
##
##            for col, t in enumerate(types):
##
##                ax = axes[row, col]
##
##                # Plot all the reference stellar mass functions
##                for s in (
##                        StellarMassFunctionCandels(),
##                        StellarMassFunctionCOSMOS2020(),
##                        StellarMassFunctionBongiorno2016(),
##                        #StellarMassFunctionAird2018(),
##                ):
##
##                    offset = 0
##                    if s.imf == "salpeter":
##                        offset -= 0.24
##
##                    ax.plot(mvec + offset, s.get_stellar_mass_function(10 ** mvec, z, t), label=s.name)
##
##                # SF/Quiescent selection
##                if t == "star-forming":
##                    select_type = ~self.catalog.catalog["PASSIVE"]
##                elif t == "quiescent":
##                    select_type = self.catalog.catalog["PASSIVE"]
##                else:
##                    select_type = np.ones_like(self.catalog.catalog["Z"], dtype=bool)
##
##                # Plot the mock galaxy stellar mass function
##                for is_agn in False, True:
##
##                    # AGN
##                    select_agn = np.ones_like(self.catalog.catalog["Z"], dtype=bool)
##                    if is_agn:
##                        select_agn = self.catalog.catalog["is_agn_lambda"]
##                    c = self.catalog.catalog[select_z * select_type * select_agn]
##
##                    x, dx, y, dy = util.get_key_function(mvec, c["M"], zmin=zbin[0], zmax=zbin[1], area_deg2=self.catalog.area)
##                    label = "mock " + ["galaxy", "agn"][is_agn]
##                    #markersize = 10 - 2 * is_agn
##                    ax.errorbar(x, y, xerr=dx, yerr=dy, linestyle="none", marker='.', label=label)
##
##                    ax.semilogy()
##                    ax.set_title(t)
##                    ax.set_ylim(1e-8, 1e-1)
##                    ax.text(0.9, 0.9, "%.2f < z < %.2f" % zbin, transform=ax.transAxes, horizontalalignment="right")
##                    ax.legend(loc="lower left", fontsize="x-small")
##                    ax.set_xlabel("logM [Msun]")
##                    ax.set_ylabel("Phi [1/Mpc3/dex]")
##
##    def plot_xray_luminosity_function2(self):
##
##        miyaji15 = np.loadtxt("data/miyaji2015/xlf_miyaji2015_table4.dat", usecols=(0, 1, 5, -2, -1))
##        zbins = np.sort(np.unique(np.concatenate([miyaji15[:, 0], miyaji15[:, 1]])))
##
##        figsize = 8 * 0.8, 9 * 0.8
##        fig = plt.figure(figsize=figsize)
##        ax = fig.gca()
##
##        for i, (zmin, zmax) in enumerate(zip(zbins[:-1], zbins[1:])):
##
##            print(zmin, zmax)
##            offset = len(zbins) - 1 - i
##            #color = mpl.cm.gist_heat_r(mpl.colors.Normalize(-1, 6.00)(0.5 * (zmin + zmax)))
##            color = mpl.cm.gist_rainbow_r(mpl.colors.Normalize(0, 5.0)(0.5 * (zmin + zmax)))
##
##            # Plot Miy+15
##            select_miyaji = (miyaji15[:, 0] == zmin) * (miyaji15[:, 1] == zmax)
##            y = np.log10(miyaji15[select_miyaji][:, 3]) + offset
##            ylo = np.log10(miyaji15[select_miyaji][:, 3] - miyaji15[select_miyaji][:, 4]) + offset
##            yhi = np.log10(miyaji15[select_miyaji][:, 3] + miyaji15[select_miyaji][:, 4]) + offset
##
##            ax.errorbar(
##                miyaji15[select_miyaji][:, 2],
##                y,
##                yerr=(y - ylo, yhi - y),
##                marker='.',
##                linestyle="none",
##                #label="Miyaji+ 15" if i == len(zbins) - 2 else "",
##                color=color
##            )
##
##            # Plot Bon+16 estimate
##            from bongiorno2016 import get_Phi_lx
##            x = np.linspace(42, 46, 41)
##            phi_bon16_lo = np.log10(get_Phi_lx(x, zmin))
##            phi_bon16_hi = np.log10(get_Phi_lx(x, zmax))
##            ax.fill_between(
##                x,
##                phi_bon16_lo + offset,
##                phi_bon16_hi + offset,
##                #label="AGN XLF Bon+16" if i == len(zbins) - 2 else "",
##                label=r"$%.3f < z < %.3f$" % (zmin, zmax),
##                alpha=0.30,
##                color=color
##            )
##
##        ax.fill_between([], [], [], color='k', alpha=0.20, label="AGN XLF Bongiorno+16")
##        ax.errorbar([], [], [], marker='.', linestyle='none', color='k', label="AGN XLF Miyaji+15")
##
##        ax.legend(fontsize="x-small")
##        ax.set_ylim(-6.0, 12)
##        ax.set_xlabel("log LX 2-10 keV [cgs]")
##        ax.set_ylabel("log Phi_LX [1/Mpc3/dex] + offset")
##
##    def plot_xray_luminosity_function(self):
##
##        miyaji15 = np.loadtxt("data/miyaji2015/xlf_miyaji2015_table4.dat", usecols=(0, 1, 5, -2, -1))
##        zbins = np.sort(np.unique(np.concatenate([miyaji15[:, 0], miyaji15[:, 1]])))
##
##        figsize = 3 * 6.4, 4 * 4.8
##        fig, axes = plt.subplots(4, 3, figsize=figsize, sharex=False, sharey=False, tight_layout=True)
##        axes[-1, -1].remove()
##
##        for ax, zmin, zmax in zip(axes.flatten(), zbins[:-1], zbins[1:]):
##            select = (zmin < self.catalog.catalog["Z"]) * (self.catalog.catalog["Z"] < zmax)
##
##            # Plot the mock
##            x, dx, y, dy = util.get_key_function(
##                np.linspace(42, 48, 31),
##                self.catalog.catalog["log_LX_2_10"][select],
##                zmin=zmin,
##                zmax=zmax,
##                area_deg2=self.catalog.area
##            )
##            ax.errorbar(x, y, xerr=dx, yerr=dy, linestyle="none", marker='.', label="AGN logN-logS (Bon+16)")
##            #ax.fill_between(x, y - dy, y + dy, alpha=0.5, labal="mock AGN")
##
##            # Plot Miy+15
##            select_miyaji = (miyaji15[:, 0] == zmin) * (miyaji15[:, 1] == zmax)
##            ax.errorbar(
##                miyaji15[select_miyaji][:, 2],
##                miyaji15[select_miyaji][:, 3],
##                yerr=miyaji15[select_miyaji][:, 4],
##                marker='.',
##                linestyle="none",
##                label="Miyaji+ 15",
##            )
##            ax.semilogy()
##            ax.set_title("%.3f < z < %.3f" % (zmin, zmax))
##
##            # Plot Bon+16 estimate
##            from bongiorno2016 import get_Phi_lx
##            phi_bon16_lo = get_Phi_lx(x, zmin)
##            phi_bon16_hi = get_Phi_lx(x, zmax)
##            ax.fill_between(x, phi_bon16_lo, phi_bon16_hi, label="AGN XLF Bon+16", alpha=0.10)
##
##            # chisq test
##            y = np.interp(miyaji15[select_miyaji][:, 2], x, y)
##            print("xlf", zmin, zmax, util.get_chisq_nu(y, miyaji15[select_miyaji, 3], miyaji15[select_miyaji, 4]))
##
##            for ax in axes[-1, :]:
##                ax.set_xlabel("log LX [h70^-2 / erg / s]")
##            for ax in axes[:, 0]:
##                ax.set_ylabel("dphi / dlogLX [h70^3 / Mpc3 / dex]")
##            for ax in axes.flatten():
##                ax.legend()
##                ax.set_xlim(42, 48)
##                ax.set_ylim(1e-10, 1e-2)

    def plot_logn_B(self):
        """
        Plot the cumulative number counts log N(<B) of mock type-1 AGN.

        The mock sample consists of the catalog rows flagged "is_agn_ctn"
        and not "is_optical_type2", with "magabs_lsst-i_point" < -22 and
        "log_LX_2_10" > 42. For each of the bands "johnson-B", "mock-4400"
        and "lsst-g", the total (disk + bulge + point) flux is converted to
        a magnitude and the band N +- sqrt(N), divided by the catalog area,
        is drawn as a shaded region in log10. A band for which the flux
        lookup fails is reported on stdout and skipped.

        Reference number counts read from
        data/quasar_number_counts/lsst_science_book/ are overplotted.

        Nothing is returned: the plot is drawn on a newly created figure.
        """

        Bbins = np.linspace(10, 30, 201)

        fig = plt.figure()
        ax = fig.gca()

        catalog = self.catalog.catalog
        # NOTE(review): presumably in deg2, cf. the y-axis label -- confirm
        area = self.catalog.area

        # Type-1 AGN passing the absolute magnitude and X-ray luminosity cuts
        select = (
            catalog["is_agn_ctn"]
            * ~catalog["is_optical_type2"]
            * (catalog["magabs_lsst-i_point"] < -22)
            * (catalog["log_LX_2_10"] > 42)
        )
        label = "Mock AGN type1 $M_i < -22$ $log L_X > 42$"

        for band in "johnson-B", "mock-4400", "lsst-g":

            # Best effort: a band whose flux cannot be retrieved is skipped,
            # but the failure is reported instead of being swallowed silently
            try:
                flux = self.catalog.get_flux_total(band, keys=["disk", "bulge", "point"])
                mag = util.flux_to_mag(flux)[select]
            except Exception as e:
                print(f"Skipping band {band}: {e!r}")
                continue

            # N(<B) and its sqrt(N) uncertainty for every magnitude threshold
            N = np.array([np.sum(mag < B) for B in Bbins])
            dN = N ** .5

            # np.ma.log10 masks the non-positive values, e.g. where N <= 1
            # makes the lower edge of the band vanish
            ax.fill_between(
                Bbins,
                np.ma.log10(N / area - dN / area),
                np.ma.log10(N / area + dN / area),
                alpha=0.30,
                label=label + f" ({band})"
            )

        # The reference data are digitized from the LSST Science Book:
        #
        # Figure 10.7: A summary of our current understanding of the numbers
        # of AGNs per square degree of sky brighter than a given apparent
        # magnitude, adapted from Beck-Winchatz & Anderson (2007). The
        # ultrafaint points are from the COMBO-17 survey (purple stars; Wolf
        # et al. 2003) and HST based surveys (pink circles and green squares;
        # Beck-Winchatz & Anderson 2007). Shown for broad comparison are:
        # brighter 2SLAQ points (blue, upside-down triangles; Richards et al.
        # 2005); a simple extrapolation of 2SLAQ points to ultrafaint
        # magnitudes (solid line); and the Hartwick & Schade (1990)
        # compilation (small, red triangles), which incorporates many earlier
        # quasar surveys. The data show ~500 AGNs deg^-2 to m < 24.5 and
        # z < 2.1. The LSST AGN surveys will extend both fainter and across a
        # much wider redshift range, suggesting a sample of at least ~10^7
        # AGNs.
        dirname = "data/quasar_number_counts/lsst_science_book"

        # Data sets with asymmetric errors: columns are x, y, dy_lo, dy_hi
        # NOTE(review): the column meaning is inferred from the errorbar call
        for basename, reference in (
            ("red_fixed.dat",   "Hartwick & Schade 1990"),
            ("green_fixed.dat", "HST surveys (Beck-Winchatz & Anderson 2007)"),
            ("pink_fixed.dat",  "HST surveys (Beck-Winchatz & Anderson 2007)"),
        ):
            x, y, dy1, dy2 = np.loadtxt(f"{dirname}/{basename}").T
            ax.errorbar(x, y, yerr=(dy1, dy2), linestyle="none", marker='.', label=reference)

        # Data sets without errors: columns are x, y
        # NOTE(review): the caption above cites Richards et al. 2005 for the
        # 2SLAQ points, while the legend label says 2006 -- confirm
        for basename, reference in (
            ("blue_fixed.dat",   "2SLAQ (Richards+ 2006)"),
            ("purple_fixed.dat", "COMBO-17 (Wolf+ 2003)"),
        ):
            x, y = np.loadtxt(f"{dirname}/{basename}").T
            ax.plot(x, y, '.', label=reference)

        ax.set_ylabel(r"$\log N(<B)$ [1/deg2]")
        ax.set_xlabel(r"$B$")
        ax.legend(loc="upper left", fontsize="small")
        ax.set_xlim(14, 28)
        ax.set_ylim(-2.5, 3.5)

##    def plot_mstar_sfr_classification(self):
##
##        return
##
##        c = self.catalog.catalog
##        zbins = self.catalog.plambda.zbins
##
##        # Get one figure per redshift bin
##        zs = {(zbins[i], zbins[i+1]): True for i in range(zbins.size - 1)}
##        N = np.round(np.sqrt(len(zs))).astype(np.int64)
##        figsize = np.array([6.4, 4.8]) * N
##        fig, axes = plt.subplots(N, N, figsize=figsize, sharex=False, sharey=False, tight_layout=True)
##        axes = {z: ax for z, ax in zip(zs, axes.flatten())}
##        labeled = {}
##
##        first = True
##        for (i, (zmin, zmax), (mmin, mmax)), (fun, fun_d, fun_u) in self.catalog.plambda.get_funs().items():
##
##            select_z = (zmin < c["Z"]) * (c["Z"] < zmax)
##            select_m = (mmin < c["M"]) * (c["M"] < mmax)
##            select_c = c["classification"] == i
##            select = select_z * select_m * select_c
##            if not select.sum():
##                continue
##
##            ax = axes[zmin, zmax]
##
##            for j, select_egg in enumerate([
##                c["PASSIVE"],
##                ~c["PASSIVE"]
##            ]):
##
##                color_markersize_labels = [
##                    ('red',    16, "EGG quiescent / Aird+ quiescent"),
##                    ('orange', 14, "EGG star-forming / Aird+ quiescent"),
##                    ('purple', 12, "EGG quiescent / Aird+ star-forming"),
##                    ('blue',   10, "EGG star-forming / Aird+ star-forming"),
##                ]
##                index = j + 2 * i
##
##                # Downsample the catalog
##                idx = np.random.choice(
##                    (select * select_egg).sum(),
##                    size=min((select * select_egg).sum(), 1000),
##                    replace=False
##                )
##
##                ax.plot(
##                    c[select * select_egg]["M"][idx],
##                    np.log10(c[select * select_egg]["SFR"][idx]),
##                    '.',
##                    color=color_markersize_labels[index][0],
##                    label=None if (zmin, zmax, index) in labeled else color_markersize_labels[index][2],
##                    markersize=color_markersize_labels[index][1],
##                    zorder=-color_markersize_labels[index][1],
##                )
##
##                labeled[zmin, zmax, index] = True
##
##            ax.set_title("%.1f < z < %.1f" % (zmin, zmax))
##            ax.set_xlabel("log (M / Msun)")
##            ax.set_ylabel("log (SFR / Msun / yr)")
##            ax.set_xlim(9, 12)
##            ax.set_ylim(-4, 6)
##            ax.legend()
##
##
##    def plot_lambda_SAR_histogram(self):
##
##        import bongiorno2016
##        from smf import StellarMassFunctionCOSMOS2020AGN
##        zbins = StellarMassFunctionCOSMOS2020AGN().get_zbins()
##
##        fig, axes = plt.subplots(1, 3, figsize=(3 * 6.4, 4.8))
##
##        bins = np.linspace(32, 38, 31)
##        kwargs = {"marker": '.', "linestyle": "none"}
##
##        zs = [0.55, 1.20, 2.0]
##        ms = [9.5, 10.0, 10.5, 11.0, 11.5, 12.0]
##
##        c = self.catalog.catalog
##        for ax, z1 in zip(axes, zs):
##            z2 = z1 + 0.10
##            for idx, (m1, m2) in enumerate(zip(ms[:-1], ms[1:])):
##                color = "C%d" % idx
##                select = ((z1 <= c["Z"]) * (c["Z"] < z2) * (m1 <= c["M"]) * (c["M"] < m2))
##                for a, w in (
##                    (c["log_lambda_SAR"], c["duty_cycle"]),
##                ):
##                    x, dx, y, dy = util.get_key_function(
##                        bins,
##                        a[select],
##                        w[select],
##                        zmin=z1,
##                        zmax=z2,
##                        area_deg2=self.catalog.area,
##                    )
##                    ax.errorbar(x, y, xerr=dx, yerr=dy, color=color, **kwargs)
##                ax.plot(
##                    bins,
##                    bongiorno2016.get_Phi_lambda_SAR(10 ** bins, 0.5 * (z1 + z2), m1, m2),
##                    color=color
##                )
##                ax.semilogy()
##        plt.xlim(32, 38)
##        plt.ylim(1e-7, 1e-2)
##
##        #plt.plot(bongiorno2016.get_Phi_lambda_SAR(10 ** bins, z, mmin=m, mmax=m + 0.10))
##
##        return fig, ax
##
##        #figsize = Ncol * 6.4, Nrow * 4.8
##        figsize = 16 * 0.6, 9 * 0.6
##        fig, axes = plt.subplots(
##            1,
##            2,
##            figsize=figsize,
##            tight_layout=True,
##            sharex=False,
##            sharey=False
##        )
##
##        from bongiorno2016 import get_Phi_lambda_SAR
##        lvec = np.arange(32, 36, 0.10)
##        for zbin in smf_ref.get_zbins():
##            if zbin[0] == 0 or zbin[1] == 7:
##                continue
##            # Plot Bongiorno+2016
##            color = mpl.cm.viridis(
##                mpl.colors.Normalize(
##                    0.20,
##                    5.50,
##                )(np.mean(zbin))
##            )
##            phi = get_Phi_lambda_SAR(10 ** lvec, np.mean(zbin))
##            axes[0].semilogy(lvec, phi, label="%.1f < z < %.1f" % zbin, color=color)
##        axes[0].set_ylim(1e-6, 1e-2)
##        axes[0].legend()
##        axes[0].set_xlabel("log_lambda_SAR")
##        axes[0].set_ylabel("Phi_SAR")
##
##        mbin = [(8, 9), (9, 10), (10, 11), (11, 12)]
##        for idx, z in enumerate([1.0, 2.0, 3.0]):
##            linestyle = ["solid", "dashed", "dotted"][idx]
##            for i, (mmin, mmax) in enumerate(mbin):
##                phi = get_Phi_lambda_SAR(10 ** lvec, z, mmin=mmin, mmax=mmax)
##                axes[1].semilogy(lvec, phi, linestyle=linestyle, label="%2d < log M < %2d" % (mmin, mmax) if z == 1 else "", color='C%d' % i)
##        axes[1].plot([], [], color='k', linestyle='solid', label='z = 1')
##        axes[1].plot([], [], color='k', linestyle='dotted', label='z = 2')
##        axes[1].set_ylim(1e-6, 1e-2)
##        axes[1].legend()
##        axes[1].set_xlabel("log_lambda_SAR")
##        axes[1].set_ylabel("Phi_SAR")
##
##        #for (zmin, zmax), ax in zip(smf_ref.get_zbins(), axes.flatten()):
##
##        #    if False:
##        #        select = (zmin <= self.catalog.catalog["Z"]) * (self.catalog.catalog["Z"] < zmax)
##        #        l = self.catalog.catalog[select]["log_lambda_SAR"]
##        #        x, dx, y, dy = util.get_key_function(
##        #            np.linspace(32, 36, 41),
##        #            l,
##        #            zmin=zmin,
##        #            zmax=zmax,
##        #            area_deg2=self.catalog.area,
##        #        )
##        #        ax.errorbar(x, y, xerr=dx, yerr=dy, marker='.', linestyle="none")
##
##        #    # Plot Bongiorno+2016
##        #    from bongiorno2016 import get_Phi_lambda_SAR
##        #    phi = get_Phi_lambda_SAR(10 ** (lvec + 34.0), 0.5 * (zmin + zmax))
##        #    #phi = get_Phi_lambda_SAR(10 ** lvec, 0.5 * (zmin + zmax))
##        #    ax.plot(lvec + 34, phi, color='k', label="Bon+16", linestyle="dashed")
##
##        #    if False:
##        #        # Plot Bon+16 curves from paper for debugging
##        #        zmean = 0.5 * (zmin + zmax)
##        #        if False: pass
##        #        elif zmean < 0.80: filename = "z0.55"
##        #        elif zmean < 1.50: filename = "z1.15"
##        #        elif zmean < 2.50: filename = "z2.00"
##        #        x, y = np.loadtxt(f"data/bongiorno2016/fig4/{filename}.csv").T
##        #        #ax.plot(x - 34, 10 ** y, label="nearest z Bon+16 curve", linestyle="dotted")
##        #        #ax.plot(x, 10 ** y, label="nearest z Bon+16 curve", linestyle="dotted")
##
##        #    ax.set_title("%.1f < z < %.1f" % (zmin, zmax))
##        #    ax.set_xlabel("log lambda_SAR")
##        #    ax.set_ylabel("p(lambda_SAR | z, Mstar) [dex-1]")
##        #    #ax.set_xlim(-3, 2)
##        #    ax.set_xlim(32, 36)
##        #    ax.set_ylim(1e-8, 1e-2)
##        #    ax.semilogy()
##        #    ax.legend()
##
##
##    #def plot_w_theta_redshift(self, fact=1, split=1):
##
##    #    from corr import Correlation, get_region_radec
##
##    #    ## Get the H-band cut
##    #    #egg = fitsio.read("egg.fits")
##    #    #hst_f160w = egg["FLUX"][0, :, 9]
##    #    #H = 23.9 - 2.5 * np.log10(hst_f160w)
##
##    #    # Perform the selection
##    #    data = self.catalog.catalog
##    #    #data = data[(H < 23) * (9 < data["M"]) * (data["M"] < 10.3)]
##    #    #data = data[(9 < data["M"]) * (data["M"] < 10.3)]
##    #    #data = data[data["CLUSTERED"] == 1]
##
##    #    fig, axes = plt.subplots(
##    #        3, 2,
##    #        tight_layout=True, sharex=False, sharey=False,
##    #        figsize=(2 * 6.4, 3 * 4.8)
##    #    )
##
##    #    for ax, (zmin, zmax) in zip(
##    #        axes.flatten(),
##    #        [
##    #            (0.10, 0.70),
##    #            (0.70, 1.20),
##    #            (1.20, 1.70),
##    #            (1.70, 2.40),
##    #            (2.40, 3.40),
##    #            (3.40, 4.90),
##    #        ]
##    #    ):
##
##    #        for m_min, m_max in [(0.0, 10.48), (10.48, np.inf)]:
##
##    #            select= (
##    #                (zmin < data["Z"]) * (data["Z"] < zmax) *
##    #                (m_min < data["M"]) * (data["M"] < m_max)
##    #            )
##    #            _data = data[select]
##    #            if _data.size == 0:
##    #                continue
##
##    #            # Shuffle the unclustered positions to get the random catalog
##    #            _rand = self.catalog.get_random(select, fact=fact)
##
##    #            region1, region2 = get_region_radec(
##    #                _data["RA"], _data["DEC"],
##    #                _rand["RA"], _rand["DEC"],
##    #                3, 3
##    #            )
##
##    #            # Estimate w_theta
##    #            w_theta = Correlation(
##    #                np.logspace(0, 2, 11) * u.arcsec.to(u.deg),
##    #                _data["RA"], _data["DEC"], None, None,
##    #                _rand["RA"], _rand["DEC"], None, None,
##    #                split=split
##    #            )
##    #            ax.errorbar(
##    #                w_theta.centers * u.deg.to(u.arcsec),
##    #                w_theta.xi,
##    #                yerr=w_theta.dxi_poisson,
##    #                marker='.',
##    #                linestyle="none",
##    #                label="%.2f < logM < %.2f" % (m_min, m_max)
##    #            )
##    #            ax.set_ylim(1e-4, 1e1)
##
##    #        # Plot Schreiber+ 17 for reference
##    #        x, y = np.loadtxt("data/schreiber2017/%.1f_z_%.1f.dat" % (zmin, zmax)).T
##    #        idx = np.argsort(x)
##    #        ax.loglog(x[idx], y[idx], label="Schreiber+ 17")
##
##    #        ax.set_title("%.1f < z < %.1f" % (zmin, zmax))
##    #        ax.set_xlabel("theta [arcsec]")
##    #        ax.set_ylabel(r"w(theta) ($\pm 1 \sigma$ poisson error)")
##    #        ax.legend()
##
##    #    return fig, ax
##
##    def plot_L_2500_2keV(self):
##
##        fig = plt.figure()
##        ax = fig.gca()
##
##        x = self.catalog.catalog["log_L_2500"]
##        y = self.catalog.catalog["log_L_2_keV"]
##        ax.plot(x, y, '.', label="mock AGN")
##
##        alpha = 0.952
##        beta = 2.138
##        logx = np.linspace(25, 32, 71)
##        logy = alpha * logx - beta
##        ax.plot(logx, logy, label="Lusso+ 2010 Eq. (5)")
##
##        lusso2010 = fitsio.read("data/lusso2010/lusso2010.fits", ext=1)
##        ax.plot(lusso2010["lgLUV"], lusso2010["lgL2"], 'x', label="Lusso+ 2010 type 1 AGN")
##
##        ax.set_xlabel(r"$\log(L_{2500\,\mathrm{ang}})$")
##        ax.set_ylabel(r"$\log(L_{2\,\mathrm{keV}})$")
##        ax.legend()
##
##
##    def plot_E_BV(self):
##        fig = plt.figure()
##        ax = fig.gca()
##
##        bins = np.arange(0, 1, 0.01)
##
##        for label, select in (
##            ["AGN type1", np.in1d(self.catalog.catalog["type"], ["11", "12"])],
##            ["AGN type2", np.in1d(self.catalog.catalog["type"], ["22", "21"])],
##        ):
##            ax.hist(
##                self.catalog.catalog["E_BV"][select],
##                bins=bins,
##                histtype="step",
##                label=label,
##                density=True,
##                weights=np.full(select.sum(), 1 / 0.01)
##            )
##
##        ax.set_xlabel("E[B-V] [mag]")
##        ax.set_ylabel("p(E[B-V]) [1/dex]")
##        ax.legend()
##
##
##    def plot_g_ug(self, smf_ref=StellarMassFunctionCOSMOS2020()):
##
##        fig, axes = plt.subplots(3, 3, figsize=(3 * 6.4, 3 * 4.8), sharex=False, sharey=False)
##
##        select_galaxies = self.catalog.catalog["Z"] > 0
##        select_agn = self.catalog.catalog["is_agn_lambda"]
##        select_stars = ~select_galaxies
##
##        for ax, (zmin, zmax) in zip(axes.flatten(), smf_ref.get_zbins(type="all")):
##            for select, label in [
##                    (select_stars,    "star"),
##                    (select_galaxies, "galaxy"),
##                    (select_agn,      "galaxy+AGN"),
##            ]:
##                _select = select * \
##                    (zmin <= self.catalog.catalog["Z"]) * \
##                    (self.catalog.catalog["Z"] < zmax)
##
##                u = self.catalog.catalog["lsst-u"][_select]
##                g = self.catalog.catalog["lsst-g"][_select]
##                if np.all(select == select_agn):
##                    u += self.catalog.catalog["lsst-u_point"][_select]
##                    g += self.catalog.catalog["lsst-g_point"][_select]
##
##                u = util.flux_to_mag(u)
##                g = util.flux_to_mag(g)
##                ax.plot(g, (u - g), '.', label=label)
##
##            ax.set_xlabel("lsst g [AB mag]")
##            ax.set_ylabel("lsst-u $-$ lsst-g [AB mag]")
##            ax.set_title("$%.2f \leq z < %.2f$" % (zmin, zmax))
##            ax.axvline(27.5, linestyle='dotted', label="LSST SIM TRILEGAL r-band limit")
##            ax.set_xlim(35, 5)
##            ax.set_ylim(-1, 6)
##            ax.legend()
##
##
##    def _plot_quasar_luminosity_function(self, fig, ax, zmin, zmax, key="log_LX_2_10"):
##
##        # Plot the mock QLF with XLF conversion
##        from xlf import get_luminosity_bolometric, BAND_PARAMETERS_BC
##        bins = np.logspace(42.5, 50.5, 81)
##        log_bins = np.log10(bins)
##        select = (zmin < self.catalog.catalog["Z"]) * (self.catalog.catalog["Z"] < zmax)
##        x, dx, y, dy = util.get_key_function(
##            log_bins,
##            self.catalog.catalog[key][select],
##            values=self.catalog["duty_cycle"][select],
##            zmin=zmin,
##            zmax=zmax,
##            area_deg2=self.catalog.area
##        )
##        lbol = get_luminosity_bolometric(
##            10 ** x * (u.erg / u.s).to(u.L_sun),
##            *BAND_PARAMETERS_BC["Hard X-ray"]
##        )
##        ax.errorbar(
##            lbol * u.L_sun.to(u.erg / u.s),
##            y,
##            yerr=dy,
##            linestyle="none",
##            marker='.',
##            label="mock XLF, Shen+ 20 BC"
##        )
##
##        # Plot Shen+20
##        from xlf import get_qlf
##        qlf1 = get_qlf(lbol, zmin)
##        qlf2 = get_qlf(lbol, zmax)
##        qlf_lo = np.min([qlf1, qlf2], axis=0)
##        qlf_hi = np.max([qlf1, qlf2], axis=0)
##        ax.fill_between(lbol * u.L_sun.to(u.erg / u.s), qlf_lo, qlf_hi, alpha=0.2, label="Shen+ 20 2PL", color="red")
##
##        # Plot Miyaji+ 15 with Shen+ 20 bolometric correction
##        from xlf import get_xlf, get_luminosity_bolometric, BAND_PARAMETERS_BC
##        lx = np.logspace(42, 46, 41)
##        xlf = get_xlf(lx, zmin, zmax)
##        lbol = get_luminosity_bolometric(
##            lx * (u.erg / u.s).to(u.L_sun),
##            *BAND_PARAMETERS_BC["Hard X-ray"]
##        )
##        ax.plot(lbol * u.L_sun.to(u.erg / u.s), xlf, label="Miyaji+ 15 XLF 2PL, Shen+ 20 BC")
##
##        # Set the text
##        ax.set_title("%.3f < z < %.3f" % (zmin, zmax))
##        ax.set_xlabel("L_bol [erg/s]")
##        ax.set_ylabel("phi [1 / Mpc3 / dex]")
##        ax.loglog()
##
##        # Limits
##        ax.set_ylim(1e-11, 1e-3)
##
##        return fig, ax
##
##
##    def plot_quasar_luminosity_function(self):
##
##        miyaji15 = np.loadtxt("data/miyaji2015/xlf_miyaji2015_table4.dat", usecols=(0, 1, 5, -2, -1))
##        zbins = np.sort(np.unique(np.concatenate([miyaji15[:, 0], miyaji15[:, 1]])))
##        figsize = 3 * 6.4, 4 * 4.8
##        fig, axes = plt.subplots(4, 3, figsize=figsize, sharex=False, sharey=False)
##        for ax, zmin, zmax in zip(axes.flatten(), zbins[:-1], zbins[1:]):
##            self._plot_quasar_luminosity_function(fig, ax, zmin, zmax)
##        axes[0, 0].legend()
##
##
#    def plot_quasar_Mg_luminosity_function(self):
#
#        pal16 = np.loadtxt("data/palanque-delabrouille2016/table_a1.dat")
#        M_g_bins = np.arange(-26.00, -15.00, 0.40)
#
#        names = open("data/palanque-delabrouille2016/table_a1.dat").readline()[2:].split()
#        nz = (len(names) - 1) // 2
#
#        # Get the K correction
#        from scipy.interpolate import interp1d
#        kz = np.loadtxt("data/palanque-delabrouille2016/kcorrection/Default Dataset.csv")
#        kz = interp1d(*kz.T, bounds_error=False, fill_value="extrapolate")
#        ztest = np.linspace(0.5, 4.0)
#        assert np.all(kz(ztest) < 2) * np.all(kz(ztest) > -0.5)
#
#        # Start the plots
#        fig, axes = plt.subplots(2, 4, figsize=(4 * 6.4, 2 * 4.8), sharey=False)
#
#        for i, ax in enumerate(axes.flatten()):
#
#            # Plot pal+16 qlf
#            log_x  = pal16[:, 0]
#            log_y  = pal16[:, 1 + 2 * i]
#            dlog_y = pal16[:, 2 + 2 * i]
#
#            y = 10 ** log_y
#            y_lo = 10 ** (log_y - dlog_y)
#            y_hi = 10 ** (log_y + dlog_y)
#
#            ax.errorbar(
#                pal16[:, 0],
#                y,
#                yerr=[y - y_lo, y_hi - y],
#                color='k',
#                marker='x',
#                linestyle="none",
#                label="SDSS-IV/eBOSS (Pal+16)",
#                zorder=99
#            )
#
#            # Find the redshift bin
#            zstr = re.findall("...._z_....", names[1 + 2 * i])[0].split('_')
#            zmin, zmax = map(float, zstr[::2])
#
#            # Plot the LSST limits
#            cosmo = FlatLambdaCDM(H0=67.9, Om0=0.3065, Tcmb0=2.73)
#            z = (zmin + zmax) / 2
#            ax.axvline(26.9 - cosmo.distmod(z).value - (kz(z) - kz(2.0)), linestyle='dotted')
#            ax.axvline(24.5 - cosmo.distmod(z).value - (kz(z) - kz(2.0)), linestyle='dotted')
#
#            # Plot the catalog
#            c = self._get_c(zmin, zmax)
#            c = c[c["is_agn"] * ~c["is_optical_type2"]]
#            if c.size != 0:
#                # Get the dereddened g-band flux/magnitude
#                #gflux = self.catalog.get_agn_flux("lsst-g", apply_ebv=False, redo=True)
#                #gflux = self.catalog.catalog["agn_lsst-r"]
#                gmag = util.flux_to_mag(c["lsst-g_point"])
#                M_g_z2 = gmag - cosmo.distmod(c["Z"]).value - (kz(c["Z"]) - kz(2))
#
#                x, dx, y, dy = util.get_key_function(
#                    M_g_bins,
#                    M_g_z2,
#                    zmin=zmin,
#                    zmax=zmax,
#                    area_deg2=self.catalog.area,
#                    H0=67.9,
#                    Om0=0.3065,
#                    Tcmb0=2.73,
#                )
#                ax.fill_between(x, y - dy, y + dy, alpha=0.8, label="Mock type1 AGN")
#
#            ax.set_xlabel("Mg(z=2)")
#            ax.set_ylabel("Phi [1/Mpc3/mag]")
#            ax.semilogy()
#            ax.set_title(' < '.join(zstr), transform=ax.transAxes)
#            ax.legend(loc="lower left")
#            #ax.set_xlim(ax.get_xlim()[::-1])
#            #ax.set_xlim(-15, -40)
#
#        axes[0, 0].legend(loc="lower left")
#
#    def plot_quasar_M_1450_luminosity_function(self):
#
#        import qlf
#        M_1450_min = -40
#        M_1450_max = 0
#        dM_1450 = 1.00
#        mbins = np.arange(M_1450_min, M_1450_max + 1e-6, dM_1450)
#        mvec = 0.5 * (mbins[:-1] + mbins[1:])
#
#        fig, axes = get_figure(2, 4)
#        dz = 0.10
#
#        for ax, z in zip(axes.flatten(), qlf.TABLE_MANTI17[:, 0]):
#
#            # Plot Manti UV LF
#            ax.plot(mvec, qlf.get_quasar_UV_luminosity_function(mvec, z), label="Manti+ 17")
#
#            # Plot mock UV LF
#            c = self._get_c(zlo=z - dz / 2, zhi=z + dz / 2)
#
#            # Plot galaxy + AGN UV LF
#            x, dx, y, dy = util.get_key_function(
#                bins=mbins,
#                x=c["magabs_mock-1450_total"],
#                zmin=z - dz / 2,
#                zmax=z + dz / 2,
#                area_deg2=self.catalog.area,
#                H0=67.3,
#                Om0=0.315,
#            )
#            ax.errorbar(x, y, xerr=dx, yerr=dy, marker='.', linestyle="none", label="mock galaxy + AGN")
#
#            for select in (
#                #c["is_agn"] * ~c["is_optical_type2"],
#                c["is_agn"] * ~c["is_optical_type2"] * (c["log_LX_2_10"] > 42),
#                #c["is_agn"] * ~c["is_optical_type2"] * (c["log_LX_2_10"] > 43),
#                #c["is_agn"] * ~c["is_optical_type2"] * (c["log_LX_2_10"] > 44),
#            ):
#
#                x, dx, y, dy = util.get_key_function(
#                    bins=mbins,
#                    x=c["magabs_mock-1450_point"][select],
#                    zmin=z - dz / 2,
#                    zmax=z + dz / 2,
#                    area_deg2=self.catalog.area,
#                    H0=67.3,
#                    Om0=0.315,
#                )
#                ax.errorbar(x, y, xerr=dx, yerr=dy, marker='.', linestyle="none", label="mock AGN type1")
#
#                x, dx, y, dy = util.get_key_function(
#                    bins=mbins,
#                    x=c["magabs_mock-1450_total"][select],
#                    zmin=z - dz / 2,
#                    zmax=z + dz / 2,
#                    area_deg2=self.catalog.area,
#                    H0=67.3,
#                    Om0=0.315,
#                )
#                ax.errorbar(x, y, xerr=dx, yerr=dy, marker='.', linestyle="none", label="mock galaxy+AGN type1")
#
#
################################################################################
## Miyaji+ 2015 exercise
################################################################################
#            from merloni2014 import Merloni2014
#            mer14 = Merloni2014(True, False, 0.05, 0.95)
#            dlogLX = 0.10
#            dM_1450 = 1.00
#            lx = np.arange(42, 46 + 1e-6, dlogLX)
#
#            Phi_M_1450_total = np.zeros(mbins.size - 1)
#            Phi_M_1450_type1 = np.zeros(mbins.size - 1)
#
#            # NOTE: topcat fit between LX vs. M_1450 20240229
#            m = -0.3696373
#            c = 35.76567
#
#            # Calculate M_1450 through LX
#            for l in lx:
#
#                M_1450_sample = []
#                for i in range(100):
#                    M_1450_sample.append(
#                        util.mock_lx_to_M_1450(
#                            lx=l + dlogLX / 2.0,
#                            z=z,
#                            distance_cm=COSMO70.comoving_distance(z).to(u.cm),
#                            ebv=util.get_E_BV(type2=False),
#                            scatter=True,
#                            seed=i,
#                        )
#                    )
#                    print(i, end='\r')
#
#                M_1450_sample = np.array(M_1450_sample)
#                print(z, l)
#
#                # NOTE: M_1450_sample is in 1/dex so that M_1450 * dM sums to unity
#                Phi_M_1450_sample = np.histogram(M_1450_sample, bins=mbins, density=True)[0]
#
#                # Normalize to Phi_LX
#                # Phi_LX in 1/Mpc3
#                frac_unobs = 1 - mer14.get_f_obs(z, l + dlogLX / 2.0)
#                Phi_lx = qlf.get_xray_luminosity_function_bon16(l + dlogLX / 2.0, z)
#                Phi_M_1450_total += Phi_M_1450_sample * Phi_lx * dlogLX
#                Phi_M_1450_type1 += Phi_M_1450_sample * Phi_lx * dlogLX * frac_unobs
#
#            ax.plot(mvec, Phi_M_1450_total)
#            ax.plot(mvec, Phi_M_1450_type1)
#
#            ax.semilogy()
#            ax.text(0.10, 0.90, f"z = {z:.2f}", transform=ax.transAxes)
#            ax.set_xlabel(r"$M_\mathrm{AB}^{1450}$")
#            ax.set_ylabel(r"$\Phi\ [\mathrm{Mpc}^{-3}\,\mathrm{mag}^{-1}]$")
#            ax.set_xlim(M_1450_min, M_1450_max)
#            ax.set_ylim(1e-8, 1e-2)
#            if ax == axes[0, 0]:
#                ax.legend(frameon=True)
#
#        return fig, axes
#
##    def plot_quasar_Mg_luminosity_function2(self):
##
##        pal16 = np.loadtxt("data/palanque-delabrouille2016/table_a1.dat")
##        M_g_bins = np.arange(-35.00, -16.00, 0.40)
##
##        names = open("data/palanque-delabrouille2016/table_a1.dat").readline()[2:].split()
##        nz = 8
##
##        c = self.catalog.catalog
##
##        # Get the K correction
##        from scipy.interpolate import interp1d
##        kz = np.loadtxt("data/palanque-delabrouille2016/kcorrection/Default Dataset.csv")
##        kz = interp1d(*kz.T, bounds_error=False, fill_value="extrapolate")
##
##        # Get the dereddened g-band flux/magnitude
##        cosmo = FlatLambdaCDM(H0=67.9, Om0=0.3065, Tcmb0=2.73)
##        #gflux = self.catalog.get_agn_flux("lsst-g", apply_ebv=False, redo=True)
##        #gflux = self.catalog.catalog["agn_lsst-r"]
##        gflux = self.catalog.catalog["lsst-g_point"]
##        gmag = util.flux_to_mag(gflux)
##        M_g_z2 = gmag - cosmo.distmod(c["Z"]).value - (kz(c["Z"]) - kz(2))
##
##        # Start the plots
##        fig, axes = plt.subplots(2, 4, figsize=(16, 9), sharex=False, sharey=False, tight_layout=True)
##        for i, ax in enumerate(axes.flatten()):
##
##            # Plot pal+16 qlf
##            log_x  = pal16[:, 0]
##            log_y  = pal16[:, 1 + 2 * i]
##            dlog_y = pal16[:, 2 + 2 * i]
##
##            ax.errorbar(
##                log_x,
##                log_y,
##                yerr=dlog_y,
##                marker='.',
##                color='k',
##                linestyle="none",
##                label="SDSS-IV/eBOSS (Pal+16)" if i == 0 else None,
##                zorder=99
##            )
##
##            # Find the redshift bin
##            zstr = re.findall("...._z_....", names[1 + 2 * i])[0].split('_')
##            zmin, zmax = map(float, zstr[::2])
##
##            # Perform the selection in redshift and AGN
##            select_z = (zmin < c["Z"]) * (c["Z"] < zmax)
##            if select_z.sum() == 0:
##                continue
##
##            label = "mock AGN type1" if i == 0 else None
##            #select_type = np.in1d(c["type"], ["11", "12"])
##            select_type = c["is_optical_type2"]
##
##            log_x, log_dx, y, dy = util.get_key_function(
##                M_g_bins,
##                M_g_z2[select_z * select_type],
##                values=self.catalog.catalog["duty_cycle"][select_z * select_type],
##                zmin=zmin,
##                zmax=zmax,
##                area_deg2=self.catalog.area,
##                H0=67.9,
##                Om0=0.3065,
##                Tcmb0=2.73,
##            )
##            select = y - dy > 0
##            log_y_lo = np.where(select, np.log10(y - dy), -99)
##            log_y_hi = np.log10(y + dy)
##            ax.fill_between(
##                log_x,
##                log_y_lo,
##                log_y_hi,
##                alpha=0.6,
##                label=label
##            )
##
##            ax.set_xlabel(r"$M_g(z=2)$")
##            if i == 0 or i == 4:
##                ax.set_ylabel(r"$log \Phi$  [1/Mpc3/mag]")
##            ax.set_title(' < '.join(zstr), transform=ax.transAxes)
##            ax.set_xlim(ax.get_xlim()[::-1])
##            ax.set_xlim(-21, -30)
##            ax.set_ylim(-9, -4)
##
##        axes[0, 0].legend()
##
##    def plot_mstar_sfr(self):
##        fig = plt.figure()
##        ax = fig.gca()
##
##        c = self.catalog.catalog
##        imag = -2.5 * np.log10(self.catalog.catalog["agn_lsst-r"]) + 23.90
##        for select, label in [
##                (np.ones(c.size, dtype=bool), "mock galaxies"),
##                (c["is_agn_lambda"], "mock AGN lambda_SAR > 32"),
##                (c["is_agn_LX"],     "mock AGN LX > 1e42"),
##        ]:
##            ax.plot(c["M"][select], c["SFR"][select], '.', label=label)
##
##        from bongiorno2016 import M_MIN, M_MAX
##        ax.set_xlim(M_MIN, M_MAX)
##
##        ax.legend()
##        ax.set_xlabel(r"$\log M_\mathrm{star}$ [Msun]")
##        ax.set_ylabel(r"$\mathrm{SFR}$ [Msun/yr]")
##        ax.semilogy()
##
##
##    def plot_cmd_agn_host(self):
##
##        from astropy.io import ascii
##        bon = ascii.read("data/bongiorno2012/master_xmm_sed_all_best_noDUP_NEW_FIN_Lbol.dat_FROZEN")
##
##        fig, axes = plt.subplots(2, 2, figsize=(2 * 6.4, 2 * 4.8))
##        fig.suptitle("0.30 < z < 2.50")
##
##        bins_x = np.linspace(-26.0, -16.0, 11)
##        bins_y = np.linspace(0.5, 2.0, 16)
##
##        from copy import deepcopy
##        c = deepcopy(self.catalog.catalog)
##
##        axes[0, 0].hist(bon["MBjkc_gal"], bins=bins_x, histtype="step", density=True)
##        axes[0, 1].hist(bon["MUjkc_gal"], bins=bins_x, histtype="step", density=True)
##        axes[1, 1].hist(bon["MUjkc_gal"] - bon["MBjkc_gal"], bins=bins_y, histtype="step", density=True)
##        axes[1, 0].plot(
##            bon["MBjkc_gal"],
##            bon["MUjkc_gal"] - bon["MBjkc_gal"],
##            '.',
##            label="Bon+12"
##        )
##
##        imag = -2.5 * np.log10(self.catalog.catalog["agn_lsst-r"]) + 23.90
##        for select, label in (
##            (np.ones_like(c, dtype=bool), "all"),
##            (c["is_agn_lambda"], "AGN"),
##            (c["is_agn_lambda"] * (imag < 27), "AGN, i < 27"),
##        ):
##
##            select *= (0.3 <= c["Z"]) * (c["Z"] < 2.5)
##
##            axes[0, 0].hist(c[select]["magabs_johnson-B"], bins=bins_x, histtype="step", density=True)
##            axes[0, 1].hist(c[select]["magabs_johnson-U"], bins=bins_x, histtype="step", density=True)
##            axes[1, 1].hist(c[select]["magabs_johnson-U"] - self.catalog.catalog[select]["magabs_johnson-B"], bins=bins_y, histtype="step", density=True)
##            axes[1, 0].plot(
##                c[select]["magabs_johnson-B"],
##                (c[select]["magabs_johnson-U"] - c[select]["magabs_johnson-B"]),
##                '.',
##                label=label
##            )
##
##        axes[1, 0].legend(frameon=True)
##
##        axes[1, 0].set_xlim(-14, -27)
##        axes[0, 1].set_xlim(-14, -27)
##        axes[0, 0].set_xlim(-14, -27)
##
##        axes[1, 0].set_ylim(0.0, 1.7)
##        axes[1, 1].set_xlim(0.0, 1.7)
##
##        #axes[0, 0].set_xlim(axes[0, 0].get_xlim()[::-1])
##        #axes[1, 0].set_xlim(axes[1, 0].get_xlim()[::-1])
##
##        axes[1, 0].set_xlabel("$B$")
##        axes[1, 0].set_ylabel("$U-B$")
##
##        axes[0, 0].set_xlabel("$B$")   ; axes[0, 0].set_ylabel("frequency")
##        axes[0, 1].set_xlabel("$U$")   ; axes[0, 1].set_ylabel("frequency")
##        axes[1, 1].set_xlabel("$U-B$") ; axes[1, 1].set_ylabel("frequency")
##
##
##    #def plot_classification_mer14(self):
##    #    """Plot Merloni+14 distribution for three different redshift bins"""
##
##    #    fig, axes = plt.subplots(1, 3, figsize=(3 * 6.4, 4.8))
##    #    c = self.catalog.catalog
##    #    for col, (zlo, zhi) in enumerate([
##    #        (0.3, 3.5),
##    #        (0.3, 1.7),
##    #        (1.7, 3.5)
##    #    ]):
##    #        ax = axes[col]
##    #        loglx = np.arange(42.0, 46.1, 0.2)
##    #        first = True
##    #        for l_min, l_max in zip(loglx[:-1], loglx[1:]):
##    #            print(zlo, zhi, l_min, l_max)
##    #            select = (
##    #                (zlo < c["Z"]) * (c["Z"] < zhi) *
##    #                (l_min < c["log_LX_2_10"]) * (c["log_LX_2_10"] < l_max)
##    #            )
##    #            p = 0
##    #            for i, (t1, t2) in enumerate([
##    #                ("22", "21"),
##    #                ("21", "12"),
##    #                ("12", "11")
##    #            ]):
##    #                select_type = select * (c["type"] == t1)
##    #                p += 100 * np.ma.true_divide(select_type.sum(), select.sum())
##    #                ax.plot(0.5 * (l_min + l_max), p, '.', color="C%d" % i, label="type %s %s" % (t1, t2) if first else None)
##
##    #                mer14 = np.loadtxt("data/merloni2014/%s_%s.csv" % (t1, t2))
##    #                idx = np.argsort(mer14[:, 0])
##    #                ax.plot(*mer14[idx, :].T, color="C%d" % i)
##
##    #            first = False
##
##    #        ax.set_title("%.2f < z < %.2f" % (zlo, zhi))
##    #        ax.set_xlabel("log LX [erg/s]")
##    #        ax.set_ylabel("Fraction [%]")
##    #        ax.legend()
##
##    def plot_redshift_histogram(self):
##        fig = plt.figure()
##        ax = fig.gca()
##
##        c = self.catalog.catalog
##        zbins = np.linspace(0, 10, 101)
##
##        from smf import StellarMassFunctionCOSMOS2020
##        smf = StellarMassFunctionCOSMOS2020()
##        for mlo, mhi in [
##                (0.00, 99.0),
##                (9.50, 10.0),
##                (10.0, 10.5),
##                (10.5, 11.0),
##                (11.5, 12.0)
##        ]:
##            select = (mlo <= c["M"]) ^ (c["M"] < mhi)
##            ax.hist(c["Z"][select], bins=zbins, histtype="step", label=f"{mlo} <= M < {mhi}")
##
##        ax.set_xlabel("redshift")
##        ax.set_ylabel("frequency binsize $\Delta z = 0.10$")
##        ax.legend()
##
##    def plot_redshift_xray_luminosity(self):
##
##        fig = plt.figure()
##        ax = fig.gca()
##
##        for select, label in (
##            (np.ones_like(self.catalog.catalog, dtype=bool), "all"),
##            (self.catalog.catalog["is_lsst_all_10yr"], "single visit"),
##            (self.catalog.catalog["is_lsst_all_30s"], "10 yr"),
##        ):
##            ax.plot(
##                self.catalog.catalog["Z"][select],
##                self.catalog.catalog["log_LX_2_10"][select],
##                '.',
##                label=label,
##            )
##        ax.set_xlabel("redshift")
##        ax.set_ylabel("log LX 2-10 keV")
##        ax.legend()
##
##    def plot_redshift_Mi_luminosity(self):
##
##        fig = plt.figure()
##        ax = fig.gca()
##
##        from util import COSMO_DEFAULT
##        i = util.flux_to_mag(self.catalog.catalog["total_lsst-i"])
##        M_i = i - COSMO_DEFAULT.distmod(self.catalog.catalog["Z"]).value
##
##        for select in (
##            np.ones_like(self.catalog.catalog, dtype=bool),
##            self.catalog.catalog["is_lsst_all_10yr"],
##            self.catalog.catalog["is_lsst_all_30s"],
##        ):
##            ax.plot(self.catalog.catalog[select]["Z"], M_i[select], '.')
##        ax.set_xlabel("redshift")
##        ax.set_ylabel("M_i")
##        ax.set_ylim(ax.get_ylim()[::-1])
##
##    def plot_bivariate_distribution_bon16(self, lim=5):
##
##        # Calculate the weight by using the area and the volume
##
##        from smf import StellarMassFunctionCOSMOS2020
##        smf_ref = StellarMassFunctionCOSMOS2020()
##
##        from bongiorno2016 import M_MIN, M_MAX
##        binsize = 0.20
##        xb = np.arange(M_MIN, M_MAX + binsize / 2., binsize)
##        yb = np.arange(32.0, 36.0 + binsize / 2., binsize)
##
##        xc = xb[:-1] + binsize / 2
##        yc = yb[:-1] + binsize / 2
##        xc, yc = np.meshgrid(xc, yc, indexing="ij")
##        from bongiorno2016 import get_Psi
##
##        c = self.catalog.catalog
##        x = c["M"]
##        #y = c["log_LX_2_10"] - c["M"]
##        y = c["log_lambda_SAR"]
##
##        fig, axes = plt.subplots(10, 5, figsize=(5 * 6.4, 12 * 4.8))
##        for ax, (zmin, zmax) in zip(axes, smf_ref.get_zbins(type="all")):
##
##            select = (zmin <= c["Z"]) * (c["Z"] < zmax)
##            volume = util.get_volume(zmin, zmax, area_deg2=self.catalog.area)
##            h, xedges, yedges = np.histogram2d(
##                x[select],
##                y[select],
##                (xb, yb)
##            )
##
##            # NOTE: Mask out noisy bins
##            h[h < lim] = 0
##
##            h /= (volume * binsize ** 2)
##            h_bon16 = get_Psi(10 ** xc, 10 ** yc, 0.5 * (zmin + zmax))
##
##            norm1 = mpl.colors.LogNorm(vmin=1e-10, vmax=1e-3)
##            norm2 = mpl.colors.Normalize(vmin=-0.5, vmax=0.5)
##            ratio = np.ma.masked_equal((h / h_bon16).T, 0)
##
##            im = ax[0].pcolormesh(xb, yb, h_bon16.T, norm=norm1)
##            im = ax[1].pcolormesh(xb, yb, h.T, norm=norm1)
##            im = ax[2].pcolormesh(xb, yb, np.ma.log10(ratio), norm=norm2, cmap=mpl.cm.coolwarm)
##            ax[3].hist(np.ma.log10(ratio).flatten(), bins=np.linspace(-.5, .5, 21))
##
##            for _ in ax:
##                _.set_title("%.2f < z < %.2f" % (zmin, zmax))
##                _.set_xlabel("log M")
##                _.set_ylabel("log lambda_SAR")
##
##            ax[3].set_xlabel("log ( Psi_mock / Psi_bon16 )")
##            ax[3].set_ylabel("frequency")
##
##            ## Highlight some FX
##            #Gamma = 1.9
##            #cosmo = FlatLambdaCDM(H0=70, Om0=0.30, Tcmb0=2.73)
##            #z = 0.5 * (zmin + zmax)
##            #fx = xc + yc \
##            #    - np.log10(4 * np.pi * cosmo.luminosity_distance(z).cgs.value ** 2) \
##            #    + (2 - Gamma) * np.log10(1 + z)
##            #fx = np.ma.masked_less(fx, -16)
##            #im = ax[4].pcolormesh(xb, yb, fx.T)
##
##
#
##    def plot_sed_full(self):
##        self._plot_sed(0.1, 6.0)
##
##    def _plot_rgb(self, coadd=True, extent=(0, 1024, 0, 1024)):
##        from astropy.visualization import make_lupton_rgb
##        name = "coadd" if coadd else "calexp/0.0"
##        rng = np.random.default_rng()
##        image_b = fitsio.read(f"{self.dirname}/../sky/{name}/lsst-g/combined-lsst-g-sci.fits")
##        image_g = fitsio.read(f"{self.dirname}/../sky/{name}/lsst-r/combined-lsst-r-sci.fits")
##        image_r = fitsio.read(f"{self.dirname}/../sky/{name}/lsst-i/combined-lsst-i-sci.fits")
##
##        def get_image(i):
##            i = i[extent[0]:extent[1], extent[2]:extent[3]]
##            return ((i - i.mean()) / i.std())
##
##        image = make_lupton_rgb(
##            1.0 * get_image(image_r),
##            1.0 * get_image(image_g),
##            0.8 * get_image(image_b),
##            #minimum=1.5,
##            #Q=10,
##            #stretch=0.3
##        )
##        plt.imshow(image, interpolation="none")
##    def plot_rgb_calexp(self):
##        self._plot_rgb(coadd=False)
##    def plot_rgb_coadd(self):
##        self._plot_rgb(coadd=True)
##    def plot_rgb_coadd_169(self):
##        self._plot_rgb(coadd=True, extent=(0, 900, 0, 1600))
##
##    def _plot_ugrizy(self, coadd=True, extent=None, overplot_objects=False):
##        """Create an image with the LSST bands"""
##        import re
##
##        name = "coadd" if coadd else "calexp/0.0"
##        fnames = (
##            f"{self.dirname}/../sky/{name}/lsst-u/combined-lsst-u-sci.fits",
##            f"{self.dirname}/../sky/{name}/lsst-g/combined-lsst-g-sci.fits",
##            f"{self.dirname}/../sky/{name}/lsst-r/combined-lsst-r-sci.fits",
##            f"{self.dirname}/../sky/{name}/lsst-i/combined-lsst-i-sci.fits",
##            f"{self.dirname}/../sky/{name}/lsst-z/combined-lsst-z-sci.fits",
##            f"{self.dirname}/../sky/{name}/lsst-y/combined-lsst-y-sci.fits",
##        )
##        fname_ref = f"{self.dirname}/../sky/{name}/lsst-r/combined-lsst-r-sci.fits"
##
##        def load(filename):
##            fits, header = fitsio.read(filename, header=True)
##
##            from astropy.wcs import WCS
##            wcs = WCS(header)
##
##            if extent is None:
##                return fits
##
##            y1, y2, x1, x2 = extent
##            return fits[y1:y2, x1:x2]
##
##        qlo, qhi = np.quantile(
##            load(fname_ref),
##            [0.10, 0.90]
##        )
##
##        #fig, axes = plt.subplots(2, 3, figsize=(3 * 6.4, 2 * 4.8), sharex=True, sharey=True)
##        fig, axes = plt.subplots(2, 3, figsize=(16, 9), sharex=True, sharey=True)
##        for b, fname, ax in zip('ugrizy', fnames, axes.flatten()):
##
##            if not os.path.exists(fname):
##                print("plot_ugrizy: filename %s not found. Continuing without" % fnames[0])
##                return
##
##            fits = load(fname)
##            im = ax.imshow(
##                fits,
##                norm=mpl.colors.Normalize(vmin=qlo, vmax=qhi),
##                cmap=mpl.cm.gist_heat,
##                interpolation="none"
##            )
##            ax.set_title(f"lsst-{b}", fontsize="x-large")
##            #plt.colorbar(im, ax=ax)
##        plt.tight_layout()
##
##    def plot_ugrizy_calexp(self):      self._plot_ugrizy(False)
##    def plot_ugrizy_coadd(self):       self._plot_ugrizy(True)
##    def plot_ugrizy_calexp_zoom(self): self._plot_ugrizy(False, extent=[1000, 1500, 1000, 1500])
##    def plot_ugrizy_coadd_zoom(self):  self._plot_ugrizy(True,  extent=[1000, 1500, 1000, 1500])
##
##    def plot_diff_image(self):
##
##        #fig, axes = plt.subplots(2, 3, figsize=(3 * 6.4, 2 * 4.8), tight_layout=True, sharex=True, sharey=True)
##        fig, axes = plt.subplots(2, 3, figsize=(16, 9), tight_layout=True, sharex=True, sharey=True)
##        plt.tight_layout()
##
##        filename0 = f"{self.dirname}/imsim/53.09004683168595_-33.97759584632161_0/lsst-r/output/eimage_00000001-0-r-R22_S11-det094.fits"
##        fits0, header = fitsio.read(filename0, header=True)
##
##        from astropy.wcs import WCS
##        wcs = WCS(header)
##
##        c = self.catalog.catalog[
##            self.catalog.catalog["is_agn_lambda"] *
##            np.isfinite(self.catalog.catalog["agn_lsst-r"])
##        ]
##        idx = np.argsort(c["agn_lsst-r"])[::-1]
##        c = c[idx]
##
##        i = 0
##
##        while True:
##
##            col, row = map(int, wcs.all_world2pix(c[i]["RA"], c[i]["DEC"], 0))
##            row1 = row - 40
##            row2 = row + 40
##            col1 = col - 40
##            col2 = col + 40
##
##            if not (
##                (0 <= row1 < fits0.shape[0]) and
##                (0 <= row2 < fits0.shape[0]) and
##                (0 <= col1 < fits0.shape[1]) and
##                (0 <= col2 < fits0.shape[1])
##            ):
##                i += 1
##                continue
##            break
##
##        fits0 = fits0[row1:row2, col1:col2]
##        qlo, qhi = np.quantile(fits0, [0.998, 1.00])
##
##        for i, ax in enumerate(axes.flatten()):
##            ax.set_title("t = %d days" % (i * 365))
##            filename = f"{self.dirname}/imsim/53.09004683168595_-33.97759584632161_{i * 365}/lsst-r/output/eimage_00000001-0-r-R22_S11-det094.fits"
##            fits = fitsio.read(filename)
##            ax.imshow(
##                fits[row1:row2, col1:col2],
##                cmap=mpl.cm.gist_heat,
##                norm=mpl.colors.LogNorm(vmin=qlo, vmax=qhi)
##            )
##
##
##    def plot_lightcurve(self):
##
##        """Plot some AGN lightcurve examples from the mock"""
##
##        plt.figure()
##        #plt.title("Example DRW lightcurves with $\Delta Lbol = \pm 1\,\mathrm{dex}$")
##        plt.title("Example DRW lightcurves")
##
##        t = np.arange(7300)
##        idx = 0
##
##        for i, c in enumerate(self.catalog.catalog):
##
##            if not c["is_agn_lambda"]:
##                continue
##
##            #lc0 = self.catalog.get_lightcurve(i, lum_offset=-1.00)
##            lc1 = self.catalog.get_lightcurve(i, lum_offset= 0.00, debug=True)
##            #lc2 = self.catalog.get_lightcurve(i, lum_offset=+1.00)
##            #if np.all(lc0 == 1.0):
##            #    continue
##            #assert np.any(lc0 != lc1)
##
##            #y0 = -2.5 * np.log10(c["agn_lsst-r"] * lc0 * 1e-6 / 3631)
##            y1 = -2.5 * np.log10(c["agn_lsst-r"] * lc1 * 1e-6 / 3631)
##            #y2 = -2.5 * np.log10(c["agn_lsst-r"] * lc2 * 1e-6 / 3631)
##
##            color = "C%d" % idx
##            plt.plot(t[:3650], y1, color=color)
##            #plt.fill_between(t[:3650], y0, y2, alpha=0.4, color=color)
##            #plt.axhline(c["agn_lsst-r"], linestyle='dashed', color=color)
##            #plt.axhline(10 ** np.log10(y1).mean(), linestyle='dotted', color=color)
##            plt.axhline(y1.mean(), linestyle='dotted', color=color)
##
##            idx += 1
##            if idx >= 3:
##                break
##
##        plt.xlabel("time [days]")
##        plt.ylabel("lsst-r [mag]")
##        plt.ylim(plt.ylim()[::-1])
##
##
##    #def plot_agn_diff_lsst_r(self):
##    #    """Plot some AGN lightcurve examples from the mock"""
##    #    from astropy.wcs import WCS
##    #    fig, axes = plt.subplots(5, 3, figsize=(3 * 6.4, 5 * 4.8))
##    #    for c in self.catalog.catalog[self.catalog.catalog["is_agn_lambda"]]:
##    #        ra = c["RA"]
##    #        dec = c["DEC"]
##    #        try:
##    #            fname0 = f"{self.dirname}/../sky/calexp/0/lsst-r/combined-lsst-r-sci.fits"
##    #            fits0, header = fitsio.read(fname0, header=True)
##    #            wcs = WCS(header)
##    #            col, row = map(int, wcs.all_world2pix(ra, dec, 0))
##    #            qlo, qhi = np.quantile(
##    #                fits0[
##    #                    row-50:row+50,
##    #                    col-50:col+50
##    #                ],
##    #                [0.999, 0.9999]
##    #            )
##    #        except (IndexError, OSError):
##    #            pass
##    #        for i in range(5):
##    #            time = i * 365
##    #            try:
##    #                fname = f"{self.dirname}/../sky/calexp/{time}/lsst-r/combined-lsst-r-sci.fits"
##    #                fits = fitsio.read(fname)
##    #            except OSError:
##    #                continue
##    #            print(row, col)
##    #            if (
##    #                (row - 50 < 0) or
##    #                (row + 50 >= fits.shape[0]) or
##    #                (col - 50 < 0) or
##    #                (col + 50 >= fits.shape[1])
##    #            ):
##    #                continue
##    #            axes[i, 1].set_title(f"time = {time} [days]")
##    #            lc = self.catalog.get_lightcurve(c["ID"])
##    #            axes[i, 0].plot(np.arange(3650), lc)
##    #            axes[i, 0].plot(time, lc[time], 'o', color='r', markersize=20, fillstyle='none', linewidth=5)
##    #            axes[i, 1].imshow(fits[row-50:row+50, col-50:col+50], vmin=qlo, vmax=qhi,)
##    #            delta = fits - fits0
##    #            sigma = np.std(delta, ddof=1)
##    #            axes[i, 2].imshow(delta[row-50:row+50, col-50:col+50], vmin=-3 * sigma, vmax=3 * sigma, cmap=mpl.cm.coolwarm)
##    #    plt.tight_layout()
##
##    def get_select_mag(self):
##        c = self.catalog.catalog_photometric
##        return (
##            c["is_detected_all"] *
##            ~(c["GroupSize_lsst_u"] > 0) *
##            ~(c["GroupSize_lsst_g"] > 0) *
##            ~(c["GroupSize_lsst_r"] > 0) *
##            ~(c["GroupSize_lsst_i"] > 0) *
##            ~(c["GroupSize_lsst_z"] > 0) *
##            ~(c["GroupSize_lsst_y"] > 0)
##        )
##
##    def plot_mag_true_measured(self):
##
##        fig, axes = plt.subplots(2, 3, figsize=(3 * 6.4, 2 * 4.8))
##        select = self.get_select_mag()
##
##        c = self.catalog.catalog_photometric[select]
##        is_gal  = c["is_gal"]
##        is_agn  = c["is_agn"]
##        is_star = c["is_star"]
##
##        for b, ax in zip('ugrizy', axes.flatten()):
##
##            ax.set_title("lsst-%s" % b)
##
##            for s, l in zip(
##                [is_gal, is_agn, is_star],
##                ["galaxy", "AGN", "star"]
##            ):
##                ax.errorbar(
##                    c["MAG_lsst_%s" % b][s],
##                    c["MAG_BEST_lsst_%s" % b][s],
##                    c["MAGERR_BEST_lsst_%s" % b][s],
##                    linestyle="none",
##                    marker=".",
##                    label=l
##                )
##            ax.plot([27, 15], [27, 15], color='k', linestyle='dotted')
##            ax.set_xlabel("mag true")
##            ax.set_ylabel("mag measured")
##            ax.set_xlim(ax.get_xlim()[::-1])
##            ax.set_ylim(ax.get_ylim()[::-1])
##            ax.axhline(limiting_magnitude_10yr[b], linestyle='dotted', label="LSST 10yr")
##            ax.axvline(limiting_magnitude_10yr[b], linestyle='dotted')
##            ax.legend()
##
##    def plot_mag_radius(self):
##        fig, axes = plt.subplots(2, 3, figsize=(3 * 6.4, 2 * 4.8), sharex=True, sharey=True)
##        select = self.get_select_mag()
##        c = self.catalog.catalog_photometric[select]
##        is_gal  = c["is_gal"]
##        is_agn  = c["is_agn"]
##        is_star = c["is_star"]
##
##        for b, ax in zip('ugrizy', axes.flatten()):
##            print(b)
##            ax.set_title("lsst-%s" % b)
##            for s, l in zip(
##                [is_gal, is_agn, is_star],
##                ["galaxy", "AGN", "star"]
##            ):
##                ax.errorbar(
##                    c["MAG_BEST_lsst_%s" % b][s],
##                    c["FLUX_RADIUS_lsst_%s" % b][s],
##                    xerr=c["MAGERR_BEST_lsst_%s" % b][s],
##                    linestyle="none",
##                    marker=".",
##                    label=l
##                )
##            ax.set_xlabel("mag measured")
##            ax.set_ylabel("radius half-light measured")
##            if ax == axes.flatten()[0]:
##                ax.set_xlim(ax.get_xlim()[::-1])
##            ax.semilogy()
##            ax.axvline(limiting_magnitude_10yr[b], linestyle='dotted', label="LSST 10yr")
##            #ax.axhline(seeing[b], linestyle='dashed', label="seeing")
##            ax.legend()
##
##    def plot_mag_color_true_measured(self):
##
##        select = self.get_select_mag()
##        c = self.catalog.catalog_photometric[select]
##
##        is_gal  = c["is_gal"]
##        is_agn  = c["is_agn"]
##        is_star = c["is_star"]
##
##        fig, axes = plt.subplots(2, 2, figsize=(2 * 6.4, 2 * 4.8))
##
##        for (c1, c2), ax in zip(
##            ["ug", "gr", "ri", "iz"],
##            axes.flatten()
##        ):
##
##            color1 = c[f"MAG_lsst_{c1}"] - c[f"MAG_lsst_{c2}"]
##            color2 = c[f"MAG_BEST_lsst_{c1}"] - c[f"MAG_BEST_lsst_{c2}"]
##            ax.plot(color1[is_gal],  color2[is_gal],  'd', label='galaxy')
##            ax.plot(color1[is_agn],  color2[is_agn],  'o', label='AGN')
##            ax.plot(color1[is_star], color2[is_star], '*', label='star')
##            ax.set_xlabel(f"({c1} - {c2}) true")
##            ax.set_ylabel(f"({c1} - {c2}) measured")
##            ax.plot([-0.5, 1.5], [-0.5, 1.5], color='k', linestyle='dashed')
##            ax.legend()
##
##    def plot_mag_color_color(self):
##
##        select = self.get_select_mag()
##        c = self.catalog.catalog_photometric[select]
##
##        is_gal  = c["is_gal"]
##        is_agn  = c["is_agn"]
##        is_star = c["is_star"]
##
##        fig, axes = plt.subplots(2, 3, figsize=(3 * 6.4, 2 * 4.8))
##
##        for col, (c1, c2, c3, c4) in enumerate(
##            ["uggr", "grri", "izri"],
##        ):
##
##            color1 = c[f"MAG_lsst_{c1}"] - c[f"MAG_lsst_{c2}"]
##            color2 = c[f"MAG_lsst_{c3}"] - c[f"MAG_lsst_{c4}"]
##            axes[0, col].plot(color1[is_gal],  color2[is_gal],  'd', label='galaxy')
##            axes[0, col].plot(color1[is_agn],  color2[is_agn],  'o', label='AGN')
##            axes[0, col].plot(color1[is_star], color2[is_star], '*', label='star')
##            axes[0, col].set_xlabel(f"({c1} - {c2}) true")
##            axes[0, col].set_ylabel(f"({c3} - {c4}) true")
##            axes[0, col].legend()
##
##            color1 = c[f"MAG_BEST_lsst_{c1}"] - c[f"MAG_BEST_lsst_{c2}"]
##            color2 = c[f"MAG_BEST_lsst_{c3}"] - c[f"MAG_BEST_lsst_{c4}"]
##            axes[1, col].plot(color1[is_gal],  color2[is_gal],  'd', label='galaxy')
##            axes[1, col].plot(color1[is_agn],  color2[is_agn],  'o', label='AGN')
##            axes[1, col].plot(color1[is_star], color2[is_star], '*', label='star')
##            axes[1, col].set_xlabel(f"({c1} - {c2}) measured")
##            axes[1, col].set_ylabel(f"({c3} - {c4}) measured")
##            axes[1, col].legend()
##
##            axes[0, col].set_xlim(-0.5, 1.5)
##            axes[0, col].set_ylim(-0.5, 1.5)
##            axes[1, col].set_xlim(-0.5, 1.5)
##            axes[1, col].set_ylim(-0.5, 1.5)
##
##
##    def plot_redshift_obscured_fraction_mer14_fig7(self):
##
##        cmap = mpl.cm.Blues
##        norm = mpl.colors.BoundaryNorm([42, 43, 44, 45, 46], cmap.N)
##
##        zbins = np.linspace(0.0, 6.0, 61)
##
##        plt.figure()
##
##        # NOTE: luminosity bins close to Mer+14
##        lbins = 42.0, 43.2, 43.5, 43.8, 44.1, 44.3, 44.7, 46.0
##
##        for i, (l_lo, l_hi) in enumerate(zip(lbins[:-1], lbins[1:])):
##            color = "C%d" % i
##
##            dz = 0.40
##            for z in np.arange(0, 6, dz):
##                z_lo = z
##                z_hi = z + dz
##
##                select = (
##                    (l_lo <= self.catalog.catalog["log_LX_2_10"]) * (self.catalog.catalog["log_LX_2_10"] < l_hi) *
##                    (z_lo <= self.catalog.catalog["Z"]) * (self.catalog.catalog["Z"] < z_hi)
##                )
##
##                if select.sum() == 0:
##                    continue
##
##                c = self.catalog.catalog[select]
##                obs = c["is_optical_type2"].sum()
##                tot = c.size
##                plt.plot(z + dz / 2, 100 * obs / tot, '.', color=color)
##
##
##
##    def plot_nh_fraction_ueda2014_fig4(self):
##        fig = plt.figure()
##        ax = fig.gca()
##        import ueda2014
##
##        fig, axes = plt.subplots(6, 1, figsize=(6.4, 3 * 4.8))
##        zs = 0.5, 1.0, 2.0
##        ls = 43, 44, 45
##
##        for ax, (z, l) in zip(axes, product(zs, ls)):
##
##            select = self.catalog._get_select(z=z, l=l)
##            tot = (self.catalog.catalog["log_NH"][select] < 24.0).sum()
##            print(select.sum(), tot)
##
##            for nh_lo in 20, 21, 22, 23, 24:
##
##                nh_hi = nh_lo + 1 + (nh_lo == 24)
##                select_nh = (nh_lo <= self.catalog.catalog["log_NH"]) * (self.catalog.catalog["log_NH"] < nh_hi)
##
##                frac1 = (select * select_nh).sum() / tot
##                frac2 = ueda2014.get_f(l, z, nh_lo)
##
##                ax.plot(0.5 * (nh_lo + nh_hi), frac1, '.')
##                ax.plot(0.5 * (nh_lo + nh_hi), ueda2014.get_f(l, z, nh_lo), 'x')
##
##        return fig, axes
#
#
#    #def plot_flux_flux(self):
#    #    from astropy.coordinates import SkyCoord
#    #    from astropy.table import Table
#    #    sourcetable = Table.read(glob.glob("data/catalog/0.10deg2/imsim_no_qso_host/repo/**/*sourceTable_LSST*.parq", recursive=True)[0])
#
#    #    # Match to the catalog
#    #    c1 = SkyCoord(self.catalog.catalog["RA"], self.catalog.catalog["DEC"], unit="deg")
#    #    c2 = SkyCoord(sourcetable["ra"], sourcetable["dec"], unit="deg")
#    #    idx, d2d, d3d = c2.match_to_catalog_sky(c1)
#
#    #    select_sep = d2d < 1.0 * u.arcsec
#    #    c1 = self.catalog.catalog[idx[select_sep]]
#    #    c2 = sourcetable[select_sep]
#
#    #    columns = ["ap03Flux", "ap06Flux", "ap09Flux", "ap12Flux", "ap35Flux", "ap70Flux", "gaussianFlux", "psfFlux"]
#    #    fig, axes = get_figure(len(columns), 1)
#
#    #    for ax, column_flux in zip(axes.flatten(), columns):
#
#    #        for label, select in [
#    #            ("star",   self.catalog.get_is_star()[idx[select_sep]]),
#    #            ("galaxy", self.catalog.get_is_galaxy()[idx[select_sep]]),
#    #            ("agn",    self.catalog.get_is_agn()[idx[select_sep]]),
#    #        ]:
#    #            ax.errorbar(
#    #                c1["lsst-r_total"][select],
#    #                c2[column_flux][select] / 1e3,
#    #                yerr=c2[column_flux + "Err"][select] / 1e3,
#    #                marker='.',
#    #                linestyle="none",
#    #                label=label,
#    #            )
#
#    #        ax.set_xlabel(f"lsst-r_total [uJy]")
#    #        ax.set_ylabel(f"{column_flux} [uJy]")
#    #        ax.legend()
#    #        ax.loglog()
#    #        ax.set_xlim(0.5, None)
#    #        ax.set_ylim(0.5, None)
#
#    def plot_obscured_fraction_optical(self):
#
#        fig = plt.figure()
#        ax = fig.gca()
#
#        # Binning parameters
#        dz = 0.50
#        dl = 0.50
#        zvec = np.linspace(0.00, 6.00, 601)
#
#        # Cmap/normalization
#        cmap = mpl.cm.viridis
#        norm = mpl.colors.Normalize(vmin=42, vmax=46)
#
#        for llo, lhi in [
#            (42.0, 43.2),
#            (43.2, 43.5),
#            (43.5, 43.8),
#            (43.8, 44.1),
#            (44.1, 44.3),
#            (44.3, 44.7),
#            (44.7, 46.0)
#        ]:
#
#            x = []
#            y1 = []
#            y2 = []
#
#            for zlo in np.arange(0.20, 5.50, dz):
#
#                zhi = zlo + dz
#                print(zlo, zhi, llo, lhi)
#
#                c = self._get_c(zlo=zlo, zhi=zhi, llo=llo, lhi=lhi)
#                if not c.size:
#                    continue
#
#                N_obs = np.sum(c["is_agn_ctn"] * c["is_optical_type2"])
#                N_tot = np.sum(c["is_agn_ctn"])
#                f_obs  = 100 * N_obs / N_tot
#
#                # NOTE: assuming Poissonian errors for both counts
#                df_obs = np.sqrt(
#                    np.power(100 * N_obs ** .5 / N_tot, 2) +
#                    np.power(100 * N_obs / N_tot ** 1.5, 2)
#                )
#
#                x  += [zlo + dz / 2]
#                y1 += [f_obs - df_obs]
#                y2 += [f_obs + df_obs]
#
#            # Plot the Mock
#            ax.fill_between(
#                x,
#                np.clip(y1, 0, 100),
#                np.clip(y2, 0, 100),
#                color=cmap(norm(0.5 * (llo + lhi))),
#                alpha=0.20,
#                label=f"{llo:.2f} < log LX < {lhi:.2f}"
#            )
#
#            # Plot Merloni+2014
#            ax.plot(
#                zvec,
#                100 * self.catalog.merloni2014.get_f_obs(zvec, 0.5 * (llo + lhi)),
#                color=cmap(norm(llo + dl / 2)),
#                linestyle="dotted"
#            )
#
#        ax.legend()
#        ax.set_xlabel("redshift")
#        ax.set_ylabel("Obscured AGN fraction [%]")
#
#
#    #def plot_lightcurve_bands(self):
#
#    #    """Plot some AGN lightcurve examples from the mock"""
#
#    #    plt.figure()
#    #    plt.title("Example DRW lightcurves")
#
#    #    cs = sorted(self.catalog.catalog, key=lambda c: c["lsst-r_point"])[::-1]
#    #    for c in cs:
#
#    #        if not c["is_agn_ctn"]:
#    #            continue
#
#    #        for b in "ugrizy":
#    #            tt, lc1 = self.catalog.get_lightcurve(band=f"lsst-{b}", idxs=c["ID"])
#    #            plt.plot(tt, lc1)
#
#    #        break
#
#    #    plt.xlabel("time [days]")
#    #    plt.ylabel("lsst-r [uJy]")
#    #    plt.ylim(plt.ylim()[::-1])
#
#
#    def plot_quasar_bolometric_luminosity_function(self):
#
#        import qlf
#        import shen2020
#
#        lx = np.linspace(42, 46, 401)
#        fig, axes = get_figure(4, 3)
#
#        for ax, zlo, z, zhi in zip(
#                axes.flatten(),
#                np.unique(qlf.TABLE_MIYAJI15[:, 0]),
#                np.unique(qlf.TABLE_MIYAJI15[:, 2]),
#                np.unique(qlf.TABLE_MIYAJI15[:, 1]),
#        ):
#
#            # Get the X-ray luminosity function
#            phi_lx = qlf.get_xray_luminosity_function(lx, z)
#
#            # Get bolometric luminosity corresponding to the XLF
#            lbol = shen2020.get_luminosity_bolometric(lx, *shen2020.BAND_PARAMETERS_BC["Hard X-ray"])
#            ax.plot(lbol, phi_lx)
#
#            # Get the bolometric luminosity function from Shen
#            lbol = np.linspace(43, 50, 701)
#            qlf1 = shen2020.get_qlf(lbol, zlo)
#            qlf2 = shen2020.get_qlf(lbol, z)
#            qlf3 = shen2020.get_qlf(lbol, zhi)
#            ax.plot(lbol, qlf2)
#            ax.fill_between(lbol, np.minimum(qlf1, qlf3), np.maximum(qlf1, qlf3), alpha=0.20)
#
#            ax.semilogy()
#            ax.set_ylim(1e-8, 1e-2)
#            ax.set_title(f"{zlo} < z < {zhi}")
#            ax.set_xlabel(r"$\log L_\mathrm{bol}$ [erg/s]")
#            ax.set_ylabel(r"$\Phi_{L_\mathrm{bol}}$ [1/Mpc3/dex]")
#
#        axes[-1, -1].remove()
#        return fig, axes
#
#    def plot_color_magnitude(self):
#
#        cmds = [
#                # magabs            color1              color2
#                ("lsst-g",          "lsst-u",           "lsst-g"),
#                ("lsst-r",          "lsst-g",           "lsst-r"),
#                ("lsst-i",          "lsst-r",           "lsst-i"),
#                ("lsst-z",          "lsst-i",           "lsst-z"),
#                ("lsst-y",          "lsst-z",           "lsst-y"),
#                ("euclid-nisp-J",   "euclid-nisp-Y",    "euclid-nisp-J"),
#                ("euclid-nisp-H",   "euclid-nisp-J",    "euclid-nisp-H"),
#        ]
#
#        fig, axes = get_figure(3, 3)
#        for ax, (b, c1, c2) in zip(axes.flatten(), cmds):
#
#            print(b, c1, c2)
#
#            c = self.catalog.catalog[self.catalog.catalog["Z"] > 0]
#
#            x = c[f"magabs_{b}_total"]
#            y1 = util.flux_to_mag(c[f"{c1}_total"])
#            y2 = util.flux_to_mag(c[f"{c2}_total"])
#            y = y1 - y2
#
#            select = np.isfinite(x) * np.isfinite(y)
#            x = x[select]
#            y = y[select]
#
#            # Returns the central ~98% interval (1st to 99th percentile) of the data
#            def get_limit(a):
#                lo, hi = np.quantile(a, [0.01, 0.99])
#                return lo, hi + 1e-6
#            print(get_limit(x))
#            print(get_limit(y))
#
#            bins = (
#                np.arange(*get_limit(x), 0.01),
#                np.arange(*get_limit(y), 0.01),
#            )
#            im = ax.hist2d(x, y, bins=bins, norm=mpl.colors.LogNorm())[-1]
#
#            ax.set_xlabel(f"{b}")
#            ax.set_ylabel(f"{c1} - {c2}")
#            plt.colorbar(im, ax=ax)
#
#    def plot_color_color(self):
#
#        from corner import corner
#
#
#        # The colors to plot
#        color_color = [
#            ("lsst-u", "lsst-g", "lsst-g", "lsst-r"),
#            ("lsst-g", "lsst-r", "lsst-r", "lsst-i"),
#            ("lsst-r", "lsst-i", "lsst-i", "lsst-z"),
#            ("lsst-i", "lsst-z", "lsst-z", "lsst-y"),
#            ("euclid-nisp-Y", "euclid-nisp-J", "euclid-nisp-J", "euclid-nisp-H"),
#        ]
#
#        fig, axes = get_figure(3, 3)
#
#        for ax, (c1, c2, c3, c4) in zip(axes.flatten(), color_color):
#            x = self.catalog.catalog[f"{c1}_total"] - self.catalog.catalog[f"{c2}_total"]
#            y = self.catalog.catalog[f"{c3}_total"] - self.catalog.catalog[f"{c4}_total"]
#
#            ax.hist2d(x, y)
#
#        #labels = [f"{c1} - {c2}" for c1, c2 in colors]
#        #mag = np.array([self.catalog.catalog[f"{c1}_total"] - self.catalog.catalog[f"{c2}_total"] for c1, c2 in colors]).T
#        #corner(mag, labels=labels)
#
#    def _plot_luminosity_function(self, band, redshift):
#
#        dataid, key = {
#            "Mid-IR":     (-2, "magabs-mock-15um_total"),
#            "B band":     (-1, "magabs-mock-4400_total"),
#            "UV":         (-5, "magabs-mock-1450_total"),
#            "Hard X-ray": (-4, "log_LX_2_10"),
#        }[band]
#
#        # Plot Shen+ 2020
#        from quasarlf.pubtools.load_observations import my_get_data
#        x, dx, y, dy = my_get_data(dataid, redshift)
#        plt.plot(x, y, '.', label="Shen+ 2020")
#        plt.legend()
#    #def plot_quasar_luminosity_function_15um(self): self._plot_quasar_luminosity_function("Mid-IR",     1.0)
#    #def plot_quasar_luminosity_function_4400(self): self._plot_quasar_luminosity_function("B band",     1.0)
#    #def plot_quasar_luminosity_function_1450(self): self._plot_quasar_luminosity_function("UV",         1.0)
#    #def plot_quasar_luminosity_function_X(self):    self._plot_quasar_luminosity_function("Hard X-ray", 1.0)
#
#    def plot_quasar_luminosity_function_all(self):
#
#        from quasarlf.pubtools.load_observations import my_get_data
#        zs = 0.5, 1.0, 2.0, 3.0, 4.0, 5.0
#        bands = "Mid-IR", "B band", "UV", "Hard X-ray"
#        dz = 0.20
#        dbin = 0.20
#        dlx = 0.20
#
#        # Optimization dictionary
#        import qlf
#        from merloni2014 import Merloni2014
#        seds1 = {}
#        seds2 = {}
#        Phi_x = {}
#        frac_obs = {}
#        mer14 = Merloni2014(True, False, 0.05, 0.95)
#        lxs = np.arange(40, 46, dlx)
#        Nrandom = 100
#
#        if False:
#            for z in zs:
#                distance_cm = COSMO70.comoving_distance(z).to(u.cm)
#                for lx in lxs:
#                    Phi_x[z, lx] = qlf.get_xray_luminosity_function_bon16(lx + dlx / 2, z)
#                    frac_obs[z, lx] = mer14.get_f_obs(z, lx + dlx / 2)
#                    for i in range(Nrandom):
#                        seds1[z, lx, i] = util.mock_lx_to_sed(lx=lx + dlx / 2, z=z, distance_cm=distance_cm, ebv=util.get_E_BV(type2=0), scatter=1, seed=i)
#                        seds2[z, lx, i] = util.mock_lx_to_sed(lx=lx + dlx / 2, z=z, distance_cm=distance_cm, ebv=util.get_E_BV(type2=1), scatter=1, seed=i + Nrandom)
#                        print("%.1f" % z, "%.2f" % lx, "%3d" % i, end='\r')
#
#        # Initialize the figure
#        fig, axes = get_figure(len(zs), 4)
#
#        # Plot Shen+ 2020 compilation
#        for col, band in enumerate(bands):
#
#            dataid, key, bins, xlabel, ylabel, lam = {
#                "Mid-IR":     (-2, "magabs_mock-15um", np.arange(-32, -20 + 1e-6, dbin), r"$M_{\rm J}$ [Lsun]", r"$\Phi_{\rm J}$ [1/Mpc3/mag]", 15 * u.um),
#                "B band":     (-1, "magabs_mock-4400", np.arange(-32, -20 + 1e-6, dbin), r"$M_{\rm B}$ [Lsun]", r"$\Phi_{\rm B}$ [1/Mpc3/mag]", 4400 * u.angstrom),
#                "UV":         (-5, "magabs_mock-1450", np.arange(-32, -20 + 1e-6, dbin), r"$M_{1450}$",         r"$\Phi_{1450}$  [1/Mpc3/mag]", 1450 * u.angstrom),
#                "Hard X-ray": (-4, "log_LX_2_10",      np.arange(+40, +46 + 1e-6, dbin), r"$\log L_{\rm X}$",   r"$\Phi_{\rm X}$ [1/Mpc3/dex]", None),
#            }[band]
#
#            for row, z in enumerate(zs):
#
#                print(band, z)
#
#                # Plot Shen+ 2020
#                x, dx, y, dy = my_get_data(dataid, z)
#
#                # NOTE: convert from Lsun to AB magnitude
#                if band in ("Mid-IR", "B band"):
#                    x = 10 ** x * u.L_sun / lam.to(u.Hz, equivalencies=u.spectral())
#                    L_AB = 3631 * u.Jy * 4 * np.pi * (10 * u.pc) ** 2
#                    x = -2.5 * np.log10((x / L_AB).cgs)
#                    y -= np.log10(2.5)
#
#                axes[row, col].errorbar(x, 10 ** y, yerr=None, marker='o', linestyle="none", label="Observed (Shen+ 2020")
#
#                # Plot the mock
#                c = self._get_c(z - dz / 2, z + dz / 2, is_agn=True)
#                if c.size:
#
#                    if band == "Hard X-ray":
#
#                        # AGN only
#                        x, dx, y, dy = util.get_key_function(
#                                bins=bins,
#                                x=c[key],
#                                zmin=z - dz / 2,
#                                zmax=z + dz / 2,
#                                area_deg2=self.catalog.area
#                        )
#                        axes[row, col].errorbar(x, y, xerr=dx, yerr=dy, marker='d', linestyle="none", label="Mock AGN")
#
#                    else:
#
#                        # PLOT AGN and AGN + GALAXY LF
#                        # NOTE: convert to logLsun........
#                        for k, marker, label in [
#                            ("_point", 'd', "Mock AGN"),
#                            ("_total", 's', "Mock AGN + galaxy"),
#                        ]:
#
#                            log_L_band = c[key + k]
#
#                            if False:
#                                if band != "UV":
#                                    nu = lam.to(u.Hz, equivalencies=u.spectral())
#                                    L_ab = (4 * np.pi * (10 * u.pc) ** 2 * 3631 * u.Jy).cgs
#                                    L_band = nu * L_ab * 10 ** (-0.4 * c[key + k])
#                                    log_L_band = np.log10(L_band.to(u.L_sun).value)
#
#                            x, dx, y, dy = util.get_key_function(
#                                    bins=bins,
#                                    x=log_L_band,
#                                    zmin=z - dz / 2,
#                                    zmax=z + dz / 2,
#                                    area_deg2=self.catalog.area
#                            )
#                            axes[row, col].errorbar(x, y, xerr=dx, yerr=dy, marker=marker, linestyle="none", label=label)
#
#
#                ###############################################################################
#                # Plot the "theoretical" type1 estimate
#                if False:
#                    if band == "Hard X-ray":
#                        Phi1 = np.array([Phi_x[z, lx] * (1 - frac_obs[z, lx]) for lx in lxs])
#                        Phi2 = np.array([Phi_x[z, lx] * frac_obs[z, lx] for lx in lxs])
#                        cens = lxs + dlx / 2.
#                    else:
#                        bins_all = np.arange(-100, 100, dbin)
#                        cens = 0.5 * (bins_all[1:] + bins_all[:-1])
#                        Phi1 = np.zeros(cens.size)
#                        Phi2 = np.zeros(cens.size)
#                        for lx in lxs:
#                            M1 = []
#                            M2 = []
#                            for i in range(Nrandom):
#                                flux_band1 = sed.get_flux_band(*seds1[z, lx, i], key.replace("magabs_", ""))
#                                flux_band2 = sed.get_flux_band(*seds2[z, lx, i], key.replace("magabs_", ""))
#
#                                M1.append(util.flux_to_mag(flux_band1))
#                                M2.append(util.flux_to_mag(flux_band2))
#
#                            Phi_sample1 = np.histogram(M1, bins=bins_all, density=True)[0]
#                            Phi_sample2 = np.histogram(M2, bins=bins_all, density=True)[0]
#                            Phi1 += Phi_sample1 * Phi_x[z, lx] * (1 - frac_obs[z, lx]) * dlx
#                            Phi2 += Phi_sample2 * Phi_x[z, lx] * frac_obs[z, lx] * dlx
#
#                    axes[row, col].plot(cens, Phi1,        color='blue', linestyle="dashed", label="BD+SED type1")
#                    axes[row, col].plot(cens, Phi2,        color='red',  linestyle="dotted", label="BD+SED type2")
#                    axes[row, col].plot(cens, Phi1 + Phi2, color='k',    linestyle="solid",  label="BD+SED all")
#                ###############################################################################
#
#                axes[row, col].semilogy()
#                axes[row, col].set_xlim(np.min(bins), np.max(bins))
#                axes[row, col].set_ylim(1e-9, 1e-2)
#                axes[row, col].set_xlabel(xlabel)
#                axes[row, col].set_ylabel(ylabel)
#                axes[row, col].text(
#                    0.90,
#                    [0.10, 0.90][band == "Hard X-ray"],
#                    f"$z \sim {z}$",
#                    horizontalalignment="right",
#                    transform=axes[row, col].transAxes
#                )
#                axes[row, col].legend()
#
#        return fig, axes
#
#
#    def plot_mstar_mbh(self):
#
#        fig = plt.figure()
#        ax = fig.gca()
#
#        mvec = np.linspace(8.5, 13.5, 51)
#        dm = np.diff(mvec)[0]
#
#        from mbh import (
#            get_log_mbh_from_bulge,
#            get_log_mbh_shankar2019,
#            get_delta_log_mbh_shankar2019,
#            get_log_mbh_sg2016,
#            get_log_mbh_continuity
#        )
#
#        # Plot the mock
#        every_z = 1
#        for zlo, zhi in ZBINS[::every_z]:
#            xs = []
#            ys = []
#            print('\t', zlo, zhi)
#            for mlo, mhi in zip(mvec, mvec + dm):
#                c = self._get_c(zlo=zlo, zhi=zhi, mlo=mlo, mhi=mhi)
#                if c.size < 30:
#                    continue
#                xs.append(mlo + dm / 2)
#                ys.append(np.quantile(c["MBH"], [0.16, 0.68]))
#            xs, ys = map(np.array, [xs, ys])
#            if ys.size == 0:
#                continue
#
#            # Plot the spline
#            log_mbh_continuity = [get_log_mbh_continuity(x, (zlo + zhi) / 2) for x in xs ]
#            ax.plot(xs, log_mbh_continuity, linestyle="dashed")
#
#            # Plot the mock
#            ax.fill_between(xs, ys[:, 0], ys[:, 1], alpha=0.20, label=f"$%.1f \leq z < %.1f$" % (zlo, zhi))
#
#
#        ## Plot y = x
#        #ax.plot(mvec, mvec, color='k', linestyle='dotted', label="y = x")
#
#        # Plot naive mbulge-mbh
#        ax.plot(
#            mvec,
#            get_log_mbh_from_bulge(mvec),
#            color="black",
#            label=r"$M_{\rm BH} = M_{\rm star} / 500$"
#        )
#
#        # Plot Shankar +2019
#        log_mbh = get_log_mbh_shankar2019(mvec)
#        ax.plot(mvec, log_mbh, label="Shankar+ 2019", color='red')
#
#        # Plot Shankar errors?
#        delta_log_mbh = get_delta_log_mbh_shankar2019(mvec)
#        #ax.fill_between(
#        #    mvec,
#        #    log_mbh - delta_log_mbh,
#        #    log_mbh + delta_log_mbh,
#        #    alpha=0.20,
#        #    color='red'
#        #)
#
#        # Plot S&G 2016
#        log_mbh = get_log_mbh_sg2016(mvec)
#        ax.plot(mvec, log_mbh, label="S&G 2016", color='purple')
#
#        ax.legend(fontsize="x-small")
#        ax.set_xlabel(r"$\log M_\mathrm{star}$ [Msun]")
#        ax.set_ylabel(r"$\log M_\mathrm{BH}$ [Msun]")
#        ax.set_xlim(8.5, 12.5)
#        ax.set_ylim(5.0, 9.0)
#        ax.set_aspect(1.0)
#
#        return fig, ax
#
##    def plot_w_theta_redshift(self, fact=50, split=50):
##
##        from corr import Correlation, get_region_radec
##
##        # Perform the selection
##        data = self.catalog.catalog
##
##        fig, axes = plt.subplots(
##            3,
##            2,
##            tight_layout=True,
##            sharex=False,
##            sharey=False,
##            figsize=(2 * 6.4, 3 * 4.8)
##        )
##
##        for ax, (zmin, zmax) in zip(
##            axes.flatten(),
##            [
##                (0.10, 0.70),
##                (0.70, 1.20),
##                (1.20, 1.70),
##                (1.70, 2.40),
##                (2.40, 3.40),
##                (3.40, 4.90),
##            ]
##        ):
##
##            for m_min, m_max in [(0.0, 10.48), (10.48, np.inf)]:
##
##                select= (
##                    (zmin < data["Z"]) * (data["Z"] < zmax) *
##                    (m_min < data["M"]) * (data["M"] < m_max)
##                )
##                _data = data[select]
##                if _data.size == 0:
##                    continue
##
##                # Shuffle the unclustered positions to get the random catalog
##                _rand = self.catalog.get_random(select, fact=fact)
##
##                region1, region2 = get_region_radec(
##                    _data["RA"], _data["DEC"],
##                    _rand["RA"], _rand["DEC"],
##                    3, 3
##                )
##
##                # Estimate w_theta
##                w_theta = Correlation(
##                    np.logspace(0, 2, 11) * u.arcsec.to(u.deg),
##                    _data["RA"], _data["DEC"], None, None,
##                    _rand["RA"], _rand["DEC"], None, None,
##                    split=split
##                )
##                ax.errorbar(
##                    w_theta.centers * u.deg.to(u.arcsec),
##                    w_theta.xi,
##                    yerr=w_theta.dxi_poisson,
##                    marker='.',
##                    linestyle="none",
##                    label="%.2f < logM < %.2f" % (m_min, m_max)
##                )
##                ax.set_ylim(1e-4, 1e1)
##
##            # Plot Schreiber+ 17 for reference
##            x, y = np.loadtxt("data/schreiber2017/%.1f_z_%.1f.dat" % (zmin, zmax)).T
##            idx = np.argsort(x)
##            ax.loglog(x[idx], y[idx], label="Schreiber+ 17")
##
##            ax.set_title("%.1f < z < %.1f" % (zmin, zmax))
##            ax.set_xlabel("theta [arcsec]")
##            ax.set_ylabel(r"w(theta) ($\pm 1 \sigma$ poisson error)")
##            ax.legend()
##
##        return fig, ax
#
#    def _plot_sed(self, ax, i, wav_min_micron=0.1, wav_max_micron=2.1):
#
#        """Plot a single SED"""
#
#        select = self.catalog.catalog["ID"] == i
#        c = self.catalog.catalog[select][0]
#        my_id = c["ID"]
#
#        dirname = os.path.dirname(self.catalog.filename_egg)
#
#        # Load the bulge SED
#        os.system(f"egg-getsed seds={dirname}/egg-seds.dat id={my_id} component=bulge")
#        filename = dirname + f"/egg-seds-bulge-{my_id}.fits"
#        bulge = fitsio.read(filename)
#
#        # Load the disk SED
#        os.system(f"egg-getsed seds={dirname}/egg-seds.dat id={my_id} component=disk")
#        filename = dirname + f"/egg-seds-disk-{my_id}.fits"
#        disk = fitsio.read(filename)
#
#        # Load the AGN SED
#        agn = None
#        try:
#            dirname = os.path.dirname(self.catalog.filename_agn)
#            agn = fitsio.read(f"{dirname}/seds/agn-seds-{my_id}.fits")
#        except:
#            pass
#
#        # Build the total SED (disk + bulge + AGN, interpolated onto the bulge wavelength grid), then plot all components
#        total = {
#            "LAMBDA": bulge["LAMBDA"],
#            "FLUX": np.zeros_like(bulge["LAMBDA"]),
#        }
#        for f in disk, bulge, agn:
#            if f is None:
#                continue
#            total["FLUX"][0] += np.interp(total["LAMBDA"][0], f["LAMBDA"][0], f["FLUX"][0], left=0., right=0.)
#
#        offset = 1
#        import astropy.constants
#        for color, my_xy in [
#            ("red", disk),
#            ("purple", bulge),
#            ("blue", agn),
#            ("black", total),
#        ]:
#
#            if my_xy is None:
#                continue
#
#            my_x = (my_xy["LAMBDA"][0] * u.um).to(u.angstrom)
#
#            # Convert to 1e-17 erg/s/cm2/angstrom
#            my_y = (my_xy["FLUX"][0] * u.uJy).to(
#                u.erg / u.cm ** 2 / u.s / u.angstrom,
#                equivalencies=u.spectral_density(my_x)
#            ) / 1e-17
#
#            ax.plot(my_x, my_y, color=color)
#
#        # Plot the broad-band magnitudes
#        from catalog_galaxy_agn import BAND_LAM_FWHM
#        for band in map(lambda a: a.strip(), self.catalog.egg["BANDS"][0]):
#
#            if True:
#                continue
#
#            if not (("euclid-" in band) or ("lsst-" in band)):
#                continue
#
#            lam = BAND_LAM_FWHM[band][0]
#            color = get_color(lam)
#            ax.plot(
#                BAND_LAM_FWHM[band][0] * u.um.to(u.angstrom),
#                offset * (c[band + "_bulge"] + c[band + "_disk"] + c[band + "_point"]),
#                'o',
#                color='k',
#                markersize=10,
#                fillstyle="none"
#            )
#
#            # Plot the filters
#            if False:
#                for filename_in in glob.glob("egg/share/filter-db/inaf/*.dat"):
#
#                    test = filename_in.replace("_", "-").replace(".", "-").lower()
#                    if band.lower() not in test:
#                        continue
#
#                    lam, res = np.loadtxt(filename_in).T
#                    ax.fill_between(lam / 10000, 1e-10, res, alpha=0.2, color=color, linewidth=0)
#
#        text = '\n'.join([
#            r"${\rm ID} = %d$" % my_id,
#            r"$z = %.2f$" % c["Z"],
#            r"$\log M = %.2f$" % c["M"],
#            r"$\log \lambda_{\rm SAR} = %.2f$" % c["log_lambda_SAR"],
#            r"type = %s" % (1 + c["is_optical_type2"])
#        ])
#        ax.text(0.05, 0.95, text, transform=ax.transAxes, verticalalignment="top")
#
#        ax.set_xlabel("observed wavelength [angstrom]")
#        ax.set_ylabel("observed flux [1e-17 erg/s/cm2/angstrom]")
#        #ax.set_xlim(wav_min_micron * 0.90, wav_max_micron * 1.10)
#        #ax.set_ylim(1e-2, 1e6)
#        ax.set_xlim(900, 500000)
#        ax.set_ylim(1e-4, None)
#        #ax.semilogy()
#        ax.loglog()
#
#    def plot_sed_peculiar(self):
#
#        select = (
#            self.catalog.catalog["is_agn"] *
#            self.catalog.catalog["is_optical_type2"] *
#            (self.catalog.catalog["E_BV"] == 9.0) *
#            (self.catalog.catalog["lsst-g_point"] / self.catalog.catalog["lsst-g_total"] > 0.50)
#        )
#
#        if select.sum() == 0:
#            return
#
#        f1 = self.catalog.catalog[select]["lsst-g_bulge"]
#        f2 = self.catalog.catalog[select]["lsst-g_disk"]
#        f3 = self.catalog.catalog[select]["lsst-g_point"]
#        ftot = f1 + f2 + f3
#
#        N = np.ceil(select.sum() ** .5)
#
#        fig, axes = get_figure(int(N), int(N))
#        for i, ax in zip(self.catalog.catalog["ID"][select], axes.flatten()):
#            self._plot_sed(ax, i)
#
#        return fig, axes
#
#    def plot_logn_logflux(self):
#        fig, axes = plt.subplots(2, 3, figsize=(3 * 6.4, 2 * 4.8))
#
#        import my_lsst
#
#        for ax, b in zip(axes.flatten(), "ugrizy"):
#
#            select_lx = self.catalog.catalog["log_LX_2_10"] > 42
#
#            for select, color, label in [
#                (np.ones_like(self.catalog.catalog, dtype=bool),                                    'black',  "all"),
#                (self.catalog.get_is_agn() * select_lx,                                             'purple', "AGN"),
#                (self.catalog.get_is_agn() * ~self.catalog.catalog["is_optical_type2"] * select_lx, 'blue',   "AGN type1"),
#                (self.catalog.get_is_galaxy(),                                                      'green',  "galaxy"),
#                (self.catalog.get_is_star(),                                                        'red',    "star"),
#            ]:
#
#                for m in np.arange(15, 30, 1.00):
#                    print(b, m)
#                    N = np.sum(select * (util.flux_to_mag(self.catalog.catalog[f"lsst-{b}_total"]) < m))
#                    ax.plot(m, N / self.catalog.area, '.', color=color, label=label if m == 15 else None)
#
#            # Plot limiting mag and QSO number counts
#            lsst_fov = 9.6
#            ax.plot(my_lsst.limiting_magnitude_cosmos[b], my_lsst.qso_number_counts_cosmos[b] / lsst_fov, 'x', color="blue", markersize=12)
#
#            ax.set_xlabel(f"magnitude lsst-{b} [ABmag]")
#            ax.set_ylabel(f"N(<magnitude) [1/deg2]")
#            ax.semilogy()
#            ax.legend()
#
#        return fig, axes
#
#
