#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-03-06 13:54:13

"""
Correlation function class for w(theta) and wp(rp)
"""

import argparse
from copy import deepcopy
from itertools import combinations, product

from astropy.cosmology import FlatLambdaCDM
from colossus.cosmology import cosmology
import Corrfunc
import fitsio
import numpy as np


class Correlation:
    """Two-point correlation function w(theta) / wp(rp) with region resampling.

    On construction the pair counts are computed for every pair of regions
    (DD from catalogue 1, RR from catalogue 2, DR as the cross count),
    combined into the correlation function with Poisson errors and, when at
    least two regions are present, resampled with jackknife and bootstrap to
    estimate covariance matrices.

    When fewer than two regions exist, the resampling attributes
    (``xi_jackknife``, ``covariance_jackknife``, ``dxi_jackknife`` and their
    bootstrap counterparts) are set to ``None``.
    """

    def __init__(
        self, binfile,
        ra1, dec1, dc1, region1,
        ra2, dec2, dc2, region2,
        n_bootstrap=100, n_bootstrap_oversample=1,
        pi_max=None, dpi=None,
        null_value=np.nan,
        max_distance_region=10 ** 5,
        split=1,
    ):
        """Count the pairs and evaluate the correlation function.

        Parameters
        ----------
        binfile : array of bin edges; bin centers are the geometric means of
            adjacent edges.
        ra1, dec1, dc1, region1 : coordinates, distance and region label of
            catalogue 1. ``dc1``/``region1`` may be None (replaced by zeros).
        ra2, dec2, dc2, region2 : same for catalogue 2.
        n_bootstrap : number of bootstrap realizations.
        n_bootstrap_oversample : each bootstrap realization draws
            ``n_bootstrap_oversample * n_regions`` regions.
        pi_max, dpi : line-of-sight limit and bin width; ``dpi=None`` selects
            the angular correlation function.
        null_value : value stored in bins where the estimator is invalid.
        max_distance_region : region pairs whose labels differ by more than
            this are not correlated.
        split : split factor forwarded to the RR pair counts and to the
            estimator normalization.
        """

        # Initialize the values
        self.binfile = binfile
        self.n_bootstrap = n_bootstrap
        self.n_bootstrap_oversample = n_bootstrap_oversample
        self.pi_max = pi_max
        self.dpi = dpi
        self.null_value = null_value  # Null value used for the correlation function with 0 pairs
        self.max_distance_region = max_distance_region
        self.split = split

        dc1 = dc1 if dc1 is not None else np.zeros_like(ra1)
        dc2 = dc2 if dc2 is not None else np.zeros_like(ra2)
        region1 = region1 if region1 is not None else np.zeros_like(ra1)
        region2 = region2 if region2 is not None else np.zeros_like(ra2)

        self.centers = (binfile[:-1] * binfile[1:]) ** .5
        # One row per object; the last column is the region label
        self.args1 = np.vstack([ra1, dec1, dc1, region1]).T
        self.args2 = np.vstack([ra2, dec2, dc2, region2]).T
        self.regions = np.unique(np.concatenate((region1, region2))).astype(np.int64)
        self.mask = self._make_mask(null_value)

        # Calculate the pairs per each region
        self.dd_region = self.get_pair(1, self.args1, self.args1)
        self.dr_region = self.get_pair(0, self.args1, self.args2)
        self.rr_region = self.get_pair(1, self.args2, self.args2, self.split)

        # Correlation function and poisson errors
        self.n1, self.n2, self.dd, self.dr, self.rr, self.xi, self.dxi_poisson = self.get_correlation(self.regions)
        # The covariance diagonal holds variances (the resampled errors below
        # are defined as sqrt(diag(cov))), hence the square.
        self.covariance_poisson = np.diag(self.dxi_poisson ** 2)

        # Resampling needs at least two regions; keep the attributes defined
        # so that write() and get_bias() can detect their absence.
        self.xi_jackknife = None
        self.covariance_jackknife = None
        self.dxi_jackknife = None
        self.xi_bootstrap = None
        self.covariance_bootstrap = None
        self.dxi_bootstrap = None
        if self.regions.size < 2:
            return

        # Jackknife
        self.xi_jackknife, self.covariance_jackknife = self.get_correlation_resample("jackknife")
        self.dxi_jackknife = np.diag(self.covariance_jackknife) ** .5

        # Bootstrap
        self.xi_bootstrap, self.covariance_bootstrap = self.get_correlation_resample("bootstrap")
        self.dxi_bootstrap = np.diag(self.covariance_bootstrap) ** .5

    @staticmethod
    def _is_nan(value):
        """Return True when *value* is a NaN, whichever NaN object it is."""
        try:
            return bool(np.isnan(value))
        except (TypeError, ValueError):
            return False

    @staticmethod
    def _make_mask(null_value):
        """Return a callable that masks the entries equal to *null_value*.

        NaN never compares equal to itself, so a NaN null value is masked
        with ``np.ma.masked_invalid`` instead of an equality test.
        """
        if Correlation._is_nan(null_value):
            return np.ma.masked_invalid

        def mask_equal(x):
            return np.ma.masked_equal(x, null_value)

        return mask_equal

    def get_pair(self, autocorr, args1, args2, split=1):
        """Calculate all the pairs within the regions

        Returns a dict mapping ``(region1, region2)`` to the pair counts
        between the objects of ``args1`` in ``region1`` and the objects of
        ``args2`` in ``region2``. For ``autocorr`` only ``region1 <= region2``
        is computed and the result is stored under both key orders.
        """

        ret = {}

        for r1, r2 in product(self.regions, self.regions):

            if (autocorr and r1 > r2) or np.abs(r2 - r1) > self.max_distance_region:
                continue

            s1 = args1[:, -1] == r1
            s2 = args2[:, -1] == r2
            n_s1 = s1.sum()
            n_s2 = s2.sum()

            # Skip empty regions and autocorr from regions with only 1 point
            if (
                (n_s1 == 0 or n_s2 == 0)
                or (autocorr and r1 == r2 and n_s1 == 1 and n_s2 == 1)
            ):
                continue

            print("Calculating pairs from regions (%5d, %5d) with (%7d, %7d) objects" % (r1, r2, n_s1, n_s2))
            ret[r1, r2] = get_pair(
                autocorr and r1 == r2,
                self.binfile,
                *args1[s1, :-1].T,
                *args2[s2, :-1].T,
                pi_max=self.pi_max,
                dpi=self.dpi,
                split=split,
            )

            if autocorr:
                ret[r2, r1] = ret[r1, r2]

        return ret

    def get_correlation(self, regions):
        """Calculate the correlation function for the given regions

        ``regions`` may contain repeated labels (bootstrap); every occurrence
        contributes to both the object numbers and the pair counts.

        Returns ``(n1, n2, dd, dr, rr, xi, dxi)``.
        """

        print("Calculating the correlation function for regions", regions)

        def get_pp(pp_region):
            """Sum the pairs that reside in the regions"""
            pp = None
            for r1, r2 in product(regions, regions):
                if (r1, r2) not in pp_region:
                    continue
                if pp is None:
                    pp = np.zeros_like(pp_region[r1, r2])
                pp += pp_region[r1, r2]
            return pp

        n1 = 0
        n2 = 0
        for region in regions:
            n1 += (self.args1[:, -1] == region).sum()
            n2 += (self.args2[:, -1] == region).sum()

        dd = get_pp(self.dd_region)
        dr = get_pp(self.dr_region)
        rr = get_pp(self.rr_region)

        ret = (n1, n2, dd, dr, rr) + get_correlation(n1, n2, dd, dr, rr, dpi=self.dpi, null_value=self.null_value, split=self.split)
        return ret

    def get_correlation_resample(self, method="jackknife"):
        """
        Get the resampled correlation function and the covariance matrix
        using either bootstrap or jackknife

        Jackknife leaves out one region at a time; bootstrap draws
        ``n_bootstrap`` samples of regions with replacement. Raises
        ``ValueError`` for any other method.
        """

        if method == "jackknife":
            regions_resample = combinations(self.regions, self.regions.size - 1)
        elif method == "bootstrap":
            regions_resample = [
                np.random.choice(
                    self.regions,
                    size=self.n_bootstrap_oversample * self.regions.size,
                    replace=True
                ) for _ in range(self.n_bootstrap)
            ]
        else:
            raise ValueError("Unknown method")

        # Element -2 of get_correlation() is xi
        xi = np.ma.array([self.get_correlation(regions)[-2] for regions in regions_resample])

        # Null-valued bins are masked so that they do not enter the covariance
        cov = np.ma.cov(self.mask(xi), bias=method == "jackknife", rowvar=False)
        if method == "jackknife":
            cov *= (self.regions.size - 1)

        return xi, cov

    def get_bias(self, wp_dm, error="poisson", rp_min=1.0, rp_max=30.0, *args, **kwargs):
        """Fit the bias b that minimizes the chi-square of xi - b**2 * wp_dm.

        Only bins with centers in ``[rp_min, rp_max]`` and a finite ``xi`` are
        used. ``error`` selects the covariance ("poisson", "jackknife" or
        "bootstrap"); extra arguments are forwarded to
        ``scipy.optimize.minimize_scalar``, whose result is returned.

        Raises ``ValueError`` for an unknown ``error`` or when the requested
        covariance was not computed.
        """
        from scipy.optimize import minimize_scalar

        def fun(b, wp, wp_dm, inv_cov):
            delta = wp - b ** 2 * wp_dm
            return delta @ (inv_cov @ delta.T)

        if error not in ("poisson", "jackknife", "bootstrap"):
            raise ValueError("Unknown error %r" % (error,))
        cov = getattr(self, "covariance_" + error, None)
        if cov is None:
            raise ValueError("The %s covariance is not available" % error)

        select = (rp_min <= self.centers) * (self.centers <= rp_max) * np.isfinite(self.xi)
        inv_cov = np.linalg.inv(cov[select, :][:, select])
        return minimize_scalar(fun, args=(self.xi[select], wp_dm[select], inv_cov), *args, **kwargs)

    def write(self, filename):
        """Write the results to the FITS file ``filename`` (overwritten).

        The first table holds the bin centers, the correlation function and
        its errors; further extensions hold the resamples, the pair counts,
        the 2D correlation (when ``dpi`` is set), the per-region pair counts
        and the covariance matrices. Jackknife/bootstrap products are skipped
        when they were not computed.
        """

        if self.dpi is None:
            xname = "theta"
            yname = "w"
            xunit = "[degrees]"
            yunit = ""
        else:
            xname = "rp"
            yname = "wp"
            xunit = "[Mpc/h]"
            yunit = "[Mpc/h]"

        has_resamples = self.xi_jackknife is not None and self.xi_bootstrap is not None

        columns = {
            xname + xunit: self.centers,
            yname + yunit: self.xi,
            f"d{yname}_poisson{yunit}": self.dxi_poisson,
        }
        if has_resamples:
            columns[f"d{yname}_jackknife{yunit}"] = self.dxi_jackknife
            columns[f"d{yname}_bootstrap{yunit}"] = self.dxi_bootstrap

        # Write the correlation function
        fitsio.write(
            filename,
            columns,
            header={
                "n_data": self.n1,
                "n_random": self.n2,
                # NOTE(review): "[xunit]" is written literally; possibly
                # meant "{xunit}" -- kept to preserve the output schema.
                f"{xname}_min[xunit]": self.binfile[0],
                f"{xname}_max[xunit]": self.binfile[-1],
                "pi_min": 0,
                "pi_max": self.pi_max,
                "dpi": self.dpi,
                "n_bootstrap": self.n_bootstrap,
                "n_bootstrap_oversample": self.n_bootstrap_oversample,
                "null_value": "nan" if self._is_nan(self.null_value) else self.null_value,
            },
            clobber=True
        )

        # Write the Jackknife/Bootstrap resamples
        if has_resamples:
            fitsio.write(filename, self.xi_jackknife, extname="jackknife")
            fitsio.write(filename, self.xi_bootstrap, extname="bootstrap")

        # Write the pairs
        fitsio.write(filename, self.dd.astype(np.int64), extname="dd")
        fitsio.write(filename, self.dr.astype(np.int64), extname="dr")
        fitsio.write(filename, self.rr.astype(np.int64), extname="rr")

        # Write the 2d correlation (no line-of-sight integration), using the
        # same split normalization as the integrated estimate
        if self.dpi is not None:
            xi, dxi = get_correlation(
                self.n1, self.n2, self.dd, self.dr, self.rr,
                null_value=self.null_value, split=self.split,
            )[-2:]
            fitsio.write(filename, xi,  extname="xi")
            fitsio.write(filename, dxi, extname="dxi_poisson")

        # Write the pairs regions
        empty = np.zeros_like(self.dd)
        to_write = {
            'region1': [],
            'region2': [],
            'dd_region': [],
            'dr_region': [],
            'rr_region': [],
        }
        for r1, r2 in product(np.unique(self.regions), np.unique(self.regions)):
            if r1 > r2:
                continue
            for k, v in [
                    ("region1", r1),
                    ("region2", r2),
                    ("dd_region", self.dd_region.get((r1, r2), empty).flatten(order='F')),
                    ("dr_region", self.dr_region.get((r1, r2), empty).flatten(order='F')),
                    ("rr_region", self.rr_region.get((r1, r2), empty).flatten(order='F')),
            ]:
                to_write[k].append(v)
        for k in to_write:
            to_write[k] = np.array(to_write[k]).astype(np.int64)
        fitsio.write(filename, to_write, extname="pairs_region")

        # Write the covariance
        fitsio.write(filename, self.covariance_poisson,   extname="covariance_poisson")
        if has_resamples:
            fitsio.write(filename, self.covariance_jackknife, extname="covariance_jackknife")
            fitsio.write(filename, self.covariance_bootstrap, extname="covariance_bootstrap")


def _get_pair(autocorr, binfile, ra1, dec1, dc1, ra2=None, dec2=None, dc2=None, pi_max=None, dpi=None):
    """Count pairs with Corrfunc, either in theta or in (rp, pi).

    ``Corrfunc.mocks.DDrppi_mocks`` is used when both ``dc1`` and ``dpi`` are
    given, ``Corrfunc.mocks.DDtheta_mocks`` otherwise. When ``pi_max`` and
    ``dpi`` are both given, the flat counts are summed in groups of ``dpi``
    consecutive line-of-sight bins and returned with shape
    ``(pi_max // dpi, len(binfile) - 1)``; otherwise the flat ``npairs``
    array is returned.

    Raises ``ValueError`` when ``dpi`` is not a positive divisor of
    ``pi_max``, since the rebinning would otherwise mix separation bins.
    """

    rebin = pi_max is not None and dpi is not None
    # Validate before the (expensive) pair counting
    if rebin and (dpi <= 0 or pi_max % dpi != 0):
        raise ValueError(
            "pi_max (%r) must be a positive multiple of dpi (%r)" % (pi_max, dpi)
        )

    kwargs = {
        "autocorr": autocorr,
        "RA1": ra1,
        "DEC1": dec1,
        "RA2": ra2,
        "DEC2": dec2,
        "nthreads": 4,
        "binfile": binfile,
        "verbose": False,
    }

    fun = Corrfunc.mocks.DDtheta_mocks
    if dc1 is not None and dpi is not None:
        fun = Corrfunc.mocks.DDrppi_mocks
        kwargs.update(
            {
                "CZ1": dc1,
                "CZ2": dc2,
                "pimax": pi_max,
                "cosmology": 1,
                "is_comoving_dist": True,
            }
        )

    # Calculate the pairs
    pairs = fun(**kwargs)["npairs"]

    # Rebin for dpi.
    # Assumes the flat counts hold pi_max line-of-sight bins for each
    # separation bin, separation-major -- the layout the reshape relies on.
    if rebin:
        n_pi = pi_max // dpi
        pairs = pairs.reshape(len(binfile) - 1, n_pi, dpi).sum(axis=2).T

    return pairs


def get_pair(autocorr, binfile, ra1, dec1, dc1, ra2=None, dec2=None, dc2=None, pi_max=None, dpi=None, split=1):
    """Count pairs, optionally splitting the catalogues into subsamples.

    The catalogues are divided into ``split`` interleaved subsamples
    (elements ``i::split``); the pairs are counted within each subsample
    with ``_get_pair`` and the counts are summed. Arguments that are None
    are passed through unchanged.

    Raises ``ValueError`` when ``split`` is smaller than 1.
    """

    if split < 1:
        raise ValueError("split must be a positive integer, got %r" % (split,))

    def subsample(values, offset):
        """Take every split-th element, leaving a missing catalogue as None"""
        return None if values is None else values[offset::split]

    # Calculate the pairs by splitting the catalogue
    ret = None
    for i in range(split):
        _ret = _get_pair(
            autocorr, binfile,
            subsample(ra1, i), subsample(dec1, i), subsample(dc1, i),
            subsample(ra2, i), subsample(dec2, i), subsample(dc2, i),
            pi_max, dpi
        )

        if ret is None:
            # Copy so that the in-place accumulation never aliases _ret
            ret = deepcopy(_ret)
        else:
            ret += _ret

    return ret


def get_correlation(n1, n2, dd, dr, rr, dpi=None, null_value=np.nan, split=1):
    """Estimate the correlation function and its Poisson error.

    ``n1`` and ``n2`` are the object numbers of the two catalogues and
    ``dd``, ``dr``, ``rr`` the pair counts. Bins with ``rr <= 0`` or with an
    estimate below -1 are set to ``null_value``. When ``dpi`` is given the
    result is summed along axis 0 and scaled by ``2 * dpi``.

    Returns ``(xi, dxi)``.
    """

    # Normalization factors of the pair counts
    n_rand_sub = n2 // split
    norm_dd = split * n_rand_sub * (n_rand_sub - 1) / (n1 * (n1 - 1))
    norm_dr = (n_rand_sub - 1) / n1

    # Estimate the correlation
    xi = np.ma.true_divide(norm_dd * dd - 2 * norm_dr * dr + rr, rr)

    # Error contributions of dd, dr and rr, added in quadrature
    # NOTE(review): the dr terms use norm_dr whereas xi uses 2 * norm_dr --
    # confirm that this is intended.
    term_dd = np.ma.true_divide(norm_dd * dd ** .5, rr)
    term_dr = np.ma.true_divide(norm_dr * dr ** .5, rr)
    term_rr = np.ma.true_divide(norm_dd * dd - norm_dr * dr, rr ** 1.5)
    dxi = np.sqrt(term_dd ** 2 + term_dr ** 2 + term_rr ** 2)

    # NOTE: Set invalid bins to null_value
    valid = (rr > 0)
    valid *= (xi >= -1)  # NOTE: xi >= -1
    invalid = ~valid
    xi[invalid] = null_value
    dxi[invalid] = null_value

    if dpi is None:
        return xi, dxi

    # Integrate along the line-of-sight
    wp = 2 * dpi * np.ma.sum(xi, axis=0)
    dwp = 2 * dpi * np.ma.sum(dxi ** 2, axis=0) ** .5
    return wp, dwp


def get_region_radec(ra1, dec1, ra2, dec2, n_ra, n_dec):
    """Label the objects of two catalogues on a common ra / sin(dec) grid.

    The grid has ``n_ra`` cells in ra and ``n_dec`` cells in sin(dec) and
    encloses the points of both catalogues. The label of a cell is
    ``idx_dec * n_ra + idx_ra``.

    Returns the labels of catalogue 1 and of catalogue 2.
    """

    def sin_dec(dec):
        return np.sin(dec * np.pi / 180)

    # Cell edges enclosing all the data points; the small offset keeps the
    # maximum inside the last cell
    all_ra = np.concatenate((ra1, ra2))
    edges_ra = np.linspace(all_ra.min(), all_ra.max() + 1e-6, n_ra + 1)
    all_sindec = sin_dec(np.concatenate((dec1, dec2)))
    edges_sindec = np.linspace(all_sindec.min(), all_sindec.max() + 1e-6, n_dec + 1)

    n_cells_ra = edges_ra.size - 1
    n_cells_dec = edges_sindec.size - 1

    labels = []
    for ra, dec in ((ra1, dec1), (ra2, dec2)):
        col = np.digitize(ra, bins=edges_ra) - 1
        row = np.digitize(sin_dec(dec), bins=edges_sindec) - 1
        # Every point must fall inside the grid
        assert np.all(0 <= col)
        assert np.all(0 <= row)
        assert np.all(col < n_cells_ra)
        assert np.all(row < n_cells_dec)
        labels.append(row * n_cells_ra + col)

    return tuple(labels)


def get_integral_constraint(rr, w):
    """Return the rr-weighted mean of w (see e.g. Krishnan+20)"""
    weighted_sum = np.sum(rr * w)
    total_weight = np.sum(rr)
    return weighted_sum / total_weight


def get_wp_dm(rp, z, pi_max, cosmology_colossus, r=np.logspace(-3, 2.5, 55001), ps_args={"model": "eisenstein98_zb"}):
    """Return wp dm for given rp centers (in Mpc/h) and redshift

    The correlation function ``cosmology_colossus.correlationFunction`` is
    tabulated on the grid ``r``, interpolated linearly and integrated as

        wp(rp) = int_{rp}^{sqrt(rp^2 + pi_max^2)} 2 r xi(r) / sqrt(r^2 - rp^2) dr

    The grid ``r`` must therefore cover ``[rp, sqrt(rp**2 + pi_max**2)]`` for
    every requested ``rp``; ``scipy.interpolate.interp1d`` raises otherwise.
    ``ps_args`` is forwarded to ``correlationFunction``.
    """
    from scipy.integrate import quad
    from scipy.interpolate import interp1d

    # TODO: figure out where to pass out the parameters
    # Tabulate xi on the whole grid with a single call, and hand over a copy
    # of ps_args so that the shared default dict can never be modified.
    r = np.asarray(r)
    xi_dm = np.asarray(
        cosmology_colossus.correlationFunction(r, z, ps_args=dict(ps_args)),
        dtype=float,
    )
    func_xi_dm = interp1d(r, xi_dm)

    def integral(rp):
        def integrand(r):
            return 2 * r * func_xi_dm(r) / np.sqrt(r ** 2 - rp ** 2)
        # The integrand has an integrable singularity at r = rp, hence the
        # generous subdivision limit
        return quad(integrand, rp, (rp ** 2 + pi_max ** 2) ** .5, epsrel=1e-4, limit=10000)[0]

    return np.vectorize(integral)(rp)


def _build_parser():
    """Create the command-line parser of the correlation function script"""

    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # I/O
    add("filename_data")
    add("filename_rand")
    add("--filename_out", default="corr.fits")
    add("--col_ra", default="ra")
    add("--col_dec", default="dec")

    # Binning/clustering analysis
    add("--scale_min", default=1e-2, type=float, help="minimum scale in degrees or Mpc/h")
    add("--scale_max", default=1e+2, type=float, help="maximum scale in degrees or Mpc/h")
    add("--n_bins", default=40, type=int, help="number of bins")
    add("--pi_max", default=40, type=int, help="pi_max integral in Mpc/h")
    add("--dpi", default=None, type=int, help="pi binsize in Mpc/h. Leave empty to calculate the angular correlation function")
    add("--n_bootstrap", default=100, type=int)
    add("--n_bootstrap_oversample", default=1, type=int)
    add("--n_region_ra",  type=int, help="Number of random ra regions the catalogue is split for jackknife/bootstrap")
    add("--n_region_dec", type=int, help="Number of random dec regions the catalogue is split for jackknife/bootstrap")
    add("--split", default=1, type=int, help="Random catalog split factor")

    # Extra options
    add("--null_value", type=float, default=np.nan)
    add("--max_distance_region", default=10 ** 5, type=int, help="Maximum region distance to exclude regions that should not be correlated")

    # Cosmology
    add("--H0", default=100., type=float)
    add("--Om0", default=0.30, type=float)
    add("--Tcmb0", default=2.72548, type=float)
    add("--Ob0", default=0.045, type=float)
    add("--Neff", default=3.046, type=float)
    add("--sigma8", default=0.80, type=float)
    add("--ns", default=0.96, type=float)

    # Debugging
    add("--debug", action="store_true", help="Enter debug mode")

    return parser


def main():
    """Command-line entry point

    Reads the data and random catalogues, computes the correlation function
    and writes it to the output file.
    """

    args = _build_parser().parse_args()

    # Read in the catalogues and define the logarithmic binning
    data = fitsio.read(args.filename_data, ext=1)
    rand = fitsio.read(args.filename_rand, ext=1)
    binf = np.geomspace(args.scale_min, args.scale_max, args.n_bins + 1)

    # Initialize the cosmology
    cosmo = FlatLambdaCDM(H0=args.H0, Om0=args.Om0, Ob0=args.Ob0, Tcmb0=args.Tcmb0, Neff=args.Neff)
    cosmo_colossus = cosmology.fromAstropy(cosmo, sigma8=args.sigma8, ns=args.ns, cosmo_name="LCDM")

    # Convert from z to comoving distance when the data carry a redshift
    if "z" in data.dtype.names:
        dc1, dc2 = (cosmo.comoving_distance(cat["z"]).value for cat in (data, rand))
    else:
        dc1 = dc2 = None

    ra1, dec1 = data[args.col_ra], data[args.col_dec]
    ra2, dec2 = rand[args.col_ra], rand[args.col_dec]

    # Assign the regions if requested, otherwise read them from the catalogues
    if args.n_region_ra and args.n_region_dec:
        region1, region2 = get_region_radec(
            ra1, dec1, ra2, dec2,
            args.n_region_ra,
            args.n_region_dec,
        )
    else:
        region1 = data["region"]
        region2 = rand["region"]

    # Initialize the correlation function
    corr = Correlation(
        binfile=binf,
        ra1=ra1, dec1=dec1, dc1=dc1, region1=region1,
        ra2=ra2, dec2=dec2, dc2=dc2, region2=region2,
        n_bootstrap=args.n_bootstrap,
        n_bootstrap_oversample=args.n_bootstrap_oversample,
        pi_max=args.pi_max,
        dpi=args.dpi,
        null_value=args.null_value,
        max_distance_region=args.max_distance_region,
        split=args.split
    )

    # Write the correlation function
    corr.write(args.filename_out)


# Run the command-line entry point only when executed as a script
if __name__ == "__main__":
    main()
