#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-04-12 22:44:16

"""
Wrapper for running EGG (``egg-gencat``, ``egg-getsed``) and reading its output catalogues.
"""

import os
from util import ROOT

# See if EGG has been installed properly
import shutil
# Egg.run() shells out to egg-gencat, so fail fast at import time when the
# executable is not reachable on PATH.
if shutil.which("egg-gencat") is None:
    raise FileNotFoundError("Could not find EGG binary. Check the EGG installation.")

class Egg:
    """Thin wrapper around a single EGG catalogue run.

    Holds the ``egg-gencat`` keyword arguments, derives an output filename
    from them, runs the binary, and reads back the resulting FITS catalogue
    and per-galaxy SEDs.
    """

    # Reference source of the ``write_fits_column`` helper used by ``read`` for
    # catalogues too large for fitsio. It is not compiled automatically; build
    # it with e.g. ``gcc main.c -o write_fits_column -lcfitsio``.
    # Usage: write_fits_column <fits_in> <binary_out> <column> <n_elements>
    _WRITE_FITS_COLUMN_SRC = r"""
#include <stdio.h>
#include <stdlib.h>
#include "fitsio.h"

#define MIN(x,y) (((x)<(y))?(x):(y))
#define BUFSIZE (1024*1024)

int main(int argc, char **argv)
{
    if (argc != 5) {
        fprintf(stderr, "usage: %s fits_in binary_out column ntotal\n", argv[0]);
        return 1;
    }

    /* the user interface */
    char *filename_in = argv[1];
    char *filename_out = argv[2];
    char *template = argv[3];
    size_t ntotal = atol(argv[4]);

    fitsfile *fptr;
    /* CFITSIO routines do nothing unless status starts at 0 */
    int status = 0, hdutype;
    /* static: a 4 MB buffer is too large to place on the stack safely */
    static float array[BUFSIZE];
    FILE *write_ptr;

    /* Open the files */
    fits_open_file(&fptr, filename_in, READONLY, &status);
    fits_movabs_hdu(fptr, 2, &hdutype, &status);
    write_ptr = fopen(filename_out, "wb");

    /* Find out the number of the column */
    int colnum;
    fits_get_colnum(fptr, 0, template, &colnum, &status);

    /* Read BUFSIZE worth of data from fits at a time and write to the binary file */
    for (size_t start=0, nelem; start<ntotal; start+=nelem) {
        nelem = MIN(BUFSIZE, ntotal - start);
        fits_read_col(fptr, TFLOAT, colnum, 1, start + 1, nelem, NULL, array, NULL, &status);
        fwrite(array, nelem * sizeof(float), 1, write_ptr);
        printf("%10zu %10zu %10zu (%6.2f%%)\r", start, nelem, ntotal, (float) start / ntotal * 100);
    }
    printf("\n");

    fclose(write_ptr);
    fits_close_file(fptr, &status);

    if (status) fits_report_error(stderr, status);
    return status;
}
"""

    def __init__(self, egg_kwargs):
        """Store a copy of the EGG arguments and resolve their paths.

        Parameters
        ----------
        egg_kwargs : dict
            Keyword arguments passed to ``egg-gencat`` as ``key=value``.
            If ``"out"`` is missing, a filename is derived from the other
            arguments (see ``get_filename``). ``"out"`` and, when present,
            ``"mass_func"`` are prefixed with ``ROOT + "/"``.
            The caller's dict is not modified.
        """
        # Copy so the caller's dict is left untouched; mutating it in place
        # made a second Egg built from the same dict prefix ROOT twice.
        self.egg_kwargs = dict(egg_kwargs)
        if "out" not in self.egg_kwargs:
            self.egg_kwargs["out"] = self.get_filename()

        self.egg_kwargs["out"] = ROOT + "/" + self.egg_kwargs["out"]
        if "mass_func" in self.egg_kwargs:
            self.egg_kwargs["mass_func"] = ROOT + "/" + self.egg_kwargs["mass_func"]

    def get_argument_line(self, exclude=()):
        """Return the ``key=value`` arguments as a sorted list of strings.

        Parameters
        ----------
        exclude : iterable of str
            Keys to leave out.
        """
        # Sorting the formatted strings (not the keys) keeps the historical order.
        return sorted(
            f"{k}={v!s}" for k, v in self.egg_kwargs.items() if k not in exclude
        )

    def get_filename(self, exclude=("verbose", "out", "bands")):
        """Return the output catalogue path.

        If ``"out"`` is set it is returned unchanged. Otherwise the path is
        ``data/egg/<k1=v1>/<k2=v2>/.../egg.fits`` built from the sorted
        arguments not in ``exclude``. Brackets are dropped and commas become
        directory separators, so list-like values such as ``"[0.1,0.5]"``
        map to ``0.1/0.5``.
        """
        if "out" in self.egg_kwargs:
            return self.egg_kwargs["out"]

        parts = "/".join(self.get_argument_line(exclude=exclude))
        filename = f"data/egg/{parts}/egg.fits"
        return filename.translate(str.maketrans({"[": None, "]": None, ",": "/"}))

    def get_area(self):
        """Return the ``area`` argument of the run."""
        return self.egg_kwargs["area"]

    def run(self, overwrite=False):
        """Run ``egg-gencat`` unless its output already exists.

        Parameters
        ----------
        overwrite : bool
            Re-run even if the output file exists.

        Returns
        -------
        int
            0 when the run was skipped, otherwise the ``os.system`` status.
        """
        filename = self.get_filename()
        if not overwrite and os.path.exists(filename):
            return 0

        dirname = os.path.dirname(filename)
        if dirname:
            # exist_ok avoids a race with concurrent runs creating the same dir.
            os.makedirs(dirname, exist_ok=True)

        cmd = " ".join(["egg-gencat"] + self.get_argument_line())
        print("Running...", cmd)
        return os.system(cmd)

    def get_sed(self, i):
        """Return the EGG SED of galaxy ``i``, or None if it cannot be read.

        The SED is extracted with ``egg-getsed`` into
        ``<dir of save_sed>/egg-seds-{i}.fits`` unless that file already
        exists, and is then read with fitsio.
        """
        import fitsio
        dirname = os.path.dirname(self.egg_kwargs["save_sed"])
        fname = f"{dirname}/egg-seds-{i}.fits"
        try:
            if not os.path.exists(fname):
                os.system(f"egg-getsed seds={dirname}/egg-seds.dat id={i}")
            return fitsio.read(fname)
        except Exception:
            # Best effort, but let KeyboardInterrupt/SystemExit propagate.
            print("Could not find galaxy sed")
            return None

    @staticmethod
    def read(filename, write_fits_column="/home/viitanen/.local/bin/write_fits_column"):
        """
        Read an EGG catalogue into a dict of column name -> array.

        Note that usually and for small files, loading EGG through
        astropy/fitsio works fine. For large files (n_bands * n_galaxies >= 2
        ** 28) these routines fail due to the dtype not fitting into a C
        integer.

        The workaround is to write these columns as binary files using a
        separate routine (``write_fits_column``, whose source is kept in
        ``Egg._WRITE_FITS_COLUMN_SRC``), which can then be read using numpy
        and shaped to the correct shape. The binary files are cached next to
        ``filename`` as ``<name>_<COLUMN>.dat``.

        This is a dirty hack, but the alternative is to change completely how
        EGG writes the FITS files.

        Parameters
        ----------
        filename : str
            Path to the EGG FITS catalogue.
        write_fits_column : str
            Path to the compiled helper binary used in the fallback path.
        """
        import fitsio
        import numpy as np

        columns = [
                "ID", "RA", "DEC", "Z", "D", "M", "SFR", "CMD",
                "DISK_ANGLE", "DISK_RADIUS", "DISK_RATIO",
                "BULGE_ANGLE", "BULGE_RADIUS", "BULGE_RATIO",
                "BANDS", "LAMBDA", "RFBANDS", "RFLAMBDA",
                "AVLINES_BULGE", "AVLINES_DISK",
                "CLUSTERED", "PASSIVE",
                "FLUX", "FLUX_BULGE", "FLUX_DISK",
                "RFMAG", "RFMAG_BULGE", "RFMAG_DISK",
        ]

        try:
            return {k: fitsio.read(filename, columns=k) for k in columns}
        except Exception:
            # Expected for very large catalogues; fall through to the
            # column-by-column binary dump below.
            pass

        print("Reading (and writing) the EGG flux files one-by-one")

        # Everything except the six FLUX*/RFMAG* columns is read directly.
        ret = {k: fitsio.read(filename, columns=k) for k in columns[:-6]}

        ngal = ret["ID"][0].size
        # FLUX* hold one value per observed band (BANDS), RFMAG* one per
        # rest-frame band (RFBANDS); the two counts need not be equal.
        band_counts = {
            "FLUX": ret["BANDS"][0].size,
            "RFMAG": ret["RFBANDS"][0].size,
        }
        for k1, nbands in band_counts.items():
            for k2 in ("", "_BULGE", "_DISK"):
                k = k1 + k2
                fout = filename.replace(".fits", f"_{k}.dat")
                if not os.path.exists(fout):
                    print(f"  Writing {k} from {filename} to {fout}")
                    os.system(f"{write_fits_column} {filename} {fout} {k} {ngal * nbands}")
                ret[k] = np.fromfile(fout, dtype=np.float32).reshape((1, ngal, nbands))
        return ret

    @staticmethod
    def get_smf(z, key, filename):
        """Return the mass function ``key`` in the redshift bin containing ``z``.

        Parameters
        ----------
        z : float
            Redshift; the bin with ``zlo <= z < zhi`` from the ``ZB`` column
            (first row) is used.
        key : str
            Column of ``filename`` indexed as ``[zbin, mass_bin]``.
        filename : str
            FITS file with ``ZB``, ``MB`` and ``key`` columns.

        Returns
        -------
        x, y
            ``x`` is the mean of ``MB`` over its first axis (presumably the
            mass-bin centres), ``y`` the ``key`` values in the selected bin.

        Raises
        ------
        ValueError
            If ``z`` lies outside every redshift bin.
        """
        import fitsio
        import numpy as np
        smf = fitsio.read(filename)
        for i, (zlo, zhi) in enumerate(smf["ZB"][0].T):
            if zlo <= z < zhi:
                print(zlo, z, zhi)
                break
        else:
            # Previously this silently used the last bin (or hit NameError).
            raise ValueError(f"z={z} is outside the redshift bins of {filename}")
        x = np.mean(smf["MB"][0], axis=0)
        y = smf[key][0][i, :]
        return x, y
