import astropy.units as u
import glob
import fitsio
import re
import matplotlib.pyplot as plt
import numpy as np

# Collect the AGN mock catalogs, one per value of the two-decimal parameter
# encoded in the "_m22_0.NN" directory suffix, in sorted path order.
filenames = sorted(
    glob.glob("data/catalog/100.00deg2_0.21_z_4.00_ra_+150.11916667_dec_+2.20583333_m22_0.??/agn.fits")
)
# The "0.NN" parameter of each catalog (as a string), parsed back out of its
# path; raises IndexError for a path whose directory does not end in 0.<digit><digit>.
fobss = list(map(lambda path: re.findall("(0[.][0-9][0-9])/", path)[0], filenames))
# Header row of the output table: one column per catalog, labelled by its parameter.
print("bmag", " ".join(format(float(value), "+5.2f") for value in fobss))

# --- Analysis switches -------------------------------------------------------
# These replace the original inline `if False:` toggles and the unconditional
# debugging `break`, so that every exploratory branch is named and reachable.
# The defaults reproduce what the script did before.

# Print B-band counts with and without randomly dropping luminous AGN at
# 0.40 <= z < 0.60, using only the first catalog, then skip the remaining
# per-catalog work (this is what the original unconditional `break` did).
RUN_SUPPRESSION_CHECK = True
# Fraction of the selected luminous AGN that survive the random drop.
SUPPRESSION_KEEP_FRACTION = 0.01
# Compare the mock XLF with Miyaji+15 and Bongiorno+16, write
# xlf_bon_miy_mock.pdf and exit (needs the project-local `bongiorno2016`).
RUN_XLF_COMPARISON = False
# Overplot log_L_2_keV vs log_L_2500 for the first catalog at bbin == 15.
RUN_LX_LUV_PLOT = False
# Print one table row of log10(N(B < bbin) / area) per magnitude limit,
# matching the "bmag" header row printed above.
PRINT_TABLE = False

# Presumably the mock survey area in deg^2 (matches "100.00deg2" in the
# catalog path); counts are divided by it to get per-deg^2 densities.
SURVEY_AREA_DEG2 = 100.0
# B-band magnitude limits: sources with B < limit are counted.
B_MAG_LIMITS = [15, 16, 17, 18]


def _log_surface_density(count, area=SURVEY_AREA_DEG2):
    """Return log10(count / area), or -inf for a zero count.

    Guarding the zero case gives the same -inf that ``np.log10(0)`` would,
    without numpy's divide-by-zero RuntimeWarning.
    """
    return -np.inf if count == 0 else np.log10(count / area)


def _suppression_check(f, B, bbin):
    """Print how the B < bbin count changes when luminous AGN are thinned.

    Sources with 0.40 <= Z < 0.60 and log_LX_2_10 > 44.41 keep their B
    magnitude with probability SUPPRESSION_KEEP_FRACTION; the rest are moved
    to B = 99 so they never pass the cut.  Uses the global, unseeded
    ``np.random`` state, so the thinned count varies from run to run.
    Prints the raw count, the thinned count, then bbin and both log10 surface
    densities.
    """
    select = (
        (0.40 <= f["Z"]) & (f["Z"] < 0.60)
        & (f["log_LX_2_10"] > 44.41)
    )
    B_thinned = B.copy()
    keep = np.random.rand(select.sum()) < SUPPRESSION_KEEP_FRACTION
    B_thinned[select] = np.where(keep, B_thinned[select], 99)

    n_raw = np.sum(B < bbin)
    n_thinned = np.sum(B_thinned < bbin)
    print(n_raw)
    print(n_thinned)
    print(bbin, _log_surface_density(n_raw), _log_surface_density(n_thinned))


def _compare_xlf(f, B):
    """Compare the mock hard X-ray luminosity function with Miyaji+15 and Bongiorno+16.

    Prints properties of the B < 16 sources, quantiles of the log10
    difference between the Bongiorno+16 model and the Miyaji+15 table, and
    the three XLFs above log LX = 44.5 in the z = 0.497 table bin.  It then
    plots them into the current figure, saves ``xlf_bon_miy_mock.pdf`` and
    exits the script.
    """
    import bongiorno2016  # project-local model; only needed by this diagnostic
    from astropy.cosmology import FlatLambdaCDM

    bright = B < 16
    print("N (B < 16)", bright.sum())
    for label, column in [
        ("z", "Z"),
        ("log_LX", "log_LX_2_10"),
        ("M", "M"),
        ("log_lambda_SAR", "log_lambda_SAR"),
        ("log_L_2_keV", "log_L_2_keV"),
        ("log_L_2500", "log_L_2500"),
    ]:
        print(label, f[bright][column])

    # Miyaji+15 Table 4.  Columns as used here: 2 = z, 3/4 = log LX bin
    # edges, 5 = log LX, -2 = Phi (linear), -1 = its error.  Inferred from
    # usage -- verify against the table's ReadMe.
    miy15 = np.loadtxt("data/miyaji2015/xlf_miyaji2015_table4.dat")
    z_tab, lx_lo, lx_hi, lx_tab = miy15[:, 2], miy15[:, 3], miy15[:, 4], miy15[:, 5]
    phi_miy, dphi_miy = miy15[:, -2], miy15[:, -1]

    # log10(Phi_Bongiorno) - log10(Phi_Miyaji), over the whole table and in
    # the z = 0.497 bin alone.
    delta = np.log10(bongiorno2016.get_Phi_lx(lx_tab, z_tab)) - np.log10(phi_miy)
    print(np.quantile(delta, [0.16, 0.50, 0.84]))
    print(np.quantile(delta[z_tab == 0.497], [0.16, 0.50, 0.84]))

    sz = z_tab == 0.497
    plt.errorbar(lx_tab[sz], phi_miy[sz], dphi_miy[sz], marker='.', linestyle="none", label="miy")
    log_phi_bon = np.log10(bongiorno2016.get_Phi_lx(lx_tab[sz], z_tab[sz]))
    plt.plot(lx_tab[sz], 10 ** log_phi_bon, '.', label="bon")

    # Mock XLF: histogram of log LX for 0.40 <= z < 0.60, divided by the
    # comoving volume of that shell over the survey area and by the bin width.
    cosmo = FlatLambdaCDM(Om0=0.30, H0=70)
    sky_fraction = SURVEY_AREA_DEG2 / (4 * np.pi * u.sr.to(u.deg ** 2))
    shell_volume = (
        (cosmo.comoving_volume(0.60) - cosmo.comoving_volume(0.40)) * sky_fraction
    ).value
    in_shell = (0.40 <= f["Z"]) & (f["Z"] < 0.60)
    bins = np.unique(np.append(lx_lo[sz], lx_hi[sz]))
    counts = np.histogram(f["log_LX_2_10"][f["is_agn"] & in_shell], bins=bins)[0].astype(np.float64)
    # Assumes the table's bins are contiguous, so len(bins) - 1 == sz.sum().
    norm = shell_volume * (lx_hi[sz] - lx_lo[sz])
    mock = counts / norm
    dmock = np.sqrt(counts) / norm  # Poisson errors
    plt.errorbar(lx_tab[sz], mock, dmock, label="mock")

    luminous = lx_tab[sz] > 44.5
    print()
    print(phi_miy[sz][luminous])
    print(10 ** log_phi_bon[luminous])
    print(mock[luminous])
    print()

    # The original called quit() before this point, leaving the plot unsaved.
    plt.semilogy()
    plt.legend()
    plt.savefig("xlf_bon_miy_mock.pdf")
    # log10 difference, consistent with `delta` above (the original
    # subtracted a linear Phi from a log10 Phi here).
    print(log_phi_bon - np.log10(phi_miy[sz]))
    # SystemExit instead of the site-provided quit(), which is meant for
    # interactive use only.
    raise SystemExit


for bbin in B_MAG_LIMITS:
    vals = []

    for ifile, filename in enumerate(filenames):
        f = fitsio.read(filename)
        # Keep AGN that are not optical type 2 (assumes both columns are boolean).
        f = f[f["is_agn"] & ~f["is_optical_type2"]]
        # AB magnitude from the point-source B flux; the 1e-6 / 3631 factors
        # suggest the column is in microJansky (AB zero point 3631 Jy) -- confirm.
        B = -2.5 * np.log10(f["johnson-B_point"] * 1e-6 / 3631)

        if RUN_SUPPRESSION_CHECK:
            _suppression_check(f, B, bbin)
            break  # only the first catalog is examined in this mode

        if RUN_XLF_COMPARISON:
            _compare_xlf(f, B)  # exits the script

        if RUN_LX_LUV_PLOT and bbin == 15 and ifile == 0:
            # log L_2keV vs log L_2500, with the B < 18 subset overplotted.
            x, y = f["log_L_2500"], f["log_L_2_keV"]
            bright = B < 18
            plt.plot(x, y, '.')
            plt.plot(x[bright], y[bright], '.')

        logN = _log_surface_density(np.sum(B < bbin))
        vals.append(logN)
        # One colour per catalog, in the same order as the header columns.
        plt.plot(bbin, logN, '.', color=f"C{ifile}")

    if PRINT_TABLE:
        print("%4d" % bbin, ' '.join(f"{v:+5.2f}" for v in vals))

plt.savefig("xy.pdf")
