
#!/usr/bin/env python3
# encoding: utf-8
# Author: Akke Viitanen
# Email: akke.viitanen@helsinki.fi
# Date: 2023-02-13 16:20:51

"""
Provide a simple image class for imSim images
"""

import argparse
from copy import deepcopy
from itertools import product
from multiprocessing import Pool
import glob
import os
import subprocess
import sys
import time
import re

from astropy.coordinates import SkyCoord
import pandas as pd
import sqlite3
import pandas as pd
from astropy.io import fits
from astropy.wcs import WCS

import matplotlib as mpl
import matplotlib.pyplot as plt

class Image:
    """
    Simple wrapper around the outputs of a single imSim visit directory.

    On construction the first (in sorted order) ``output/eimage_*.fits`` and
    ``output/amp_*.fits.fz`` files under the visit directory are opened with
    ``astropy.io.fits`` and a WCS is built from the eimage primary header.
    The FITS files stay open until :meth:`close` is called (or the instance
    is used as a context manager).
    """

    def __init__(self, dirname):
        """
        Open the eimage/amp files of one visit.

        Parameters
        ----------
        dirname : str
            Visit directory whose last path component is the numeric
            observation id, e.g. ``.../imsim/235928`` (a trailing ``/`` is
            allowed).

        Raises
        ------
        ValueError
            If ``dirname`` does not end in a ``/<digits>`` path component.
        FileNotFoundError
            If no eimage or amp file exists under ``dirname/output``.
        """
        match = re.search(r"/([0-9]+)/?$", dirname)
        if match is None:
            raise ValueError(f"cannot parse an observation id from {dirname!r}")
        # Kept as a string, as before; converted to int where a number is needed.
        self.observation_id = match.group(1)

        # Resolve both paths before opening anything so a missing file does
        # not leave the other one open.
        output_dir = os.path.join(dirname, "output")
        eimage_path = self._glob_one(os.path.join(output_dir, "eimage_*.fits"))
        amp_path = self._glob_one(os.path.join(output_dir, "amp_*.fits.fz"))

        self.eimage = fits.open(eimage_path)
        self.amp = fits.open(amp_path)
        self.wcs = WCS(self.eimage[0].header)

    @staticmethod
    def _glob_one(pattern):
        """
        Return the first path matching ``pattern`` in sorted order.

        Sorting makes the choice deterministic; ``glob.glob`` order is
        arbitrary.

        Raises
        ------
        FileNotFoundError
            If nothing matches ``pattern``.
        """
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError(f"no file matches {pattern!r}")
        return matches[0]

    def close(self):
        """Close the eimage and amp FITS files opened in ``__init__``."""
        self.eimage.close()
        self.amp.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def get_baseline_entry(self, filename_baseline="data/baseline_v3.0_10yrs.db"):
        """
        Return the ``observations`` row(s) of the baseline DB for this visit.

        Parameters
        ----------
        filename_baseline : str, optional
            Path to the baseline SQLite database (defaults to the path that
            used to be hard-coded here).

        Returns
        -------
        pandas.DataFrame
            All columns of ``observations`` whose ``observationId`` equals
            ``int(self.observation_id)``; empty if there is no such row.

        Raises
        ------
        FileNotFoundError
            If ``filename_baseline`` does not exist. This is checked up front
            because ``sqlite3.connect`` would otherwise silently create an
            empty database file.
        """
        if not os.path.isfile(filename_baseline):
            raise FileNotFoundError(
                f"baseline database not found: {filename_baseline!r}"
            )
        # ``with sqlite3.connect(...)`` only scopes a transaction; it does not
        # close the connection, so close it explicitly.
        con = sqlite3.connect(filename_baseline)
        try:
            return pd.read_sql_query(
                "SELECT * FROM observations WHERE observationId = ?",
                con,
                params=(int(self.observation_id),),
            )
        finally:
            con.close()

    def is_in_image(self, ra, dec):
        """
        Return whether the sky position(s) ``(ra, dec)`` land on the eimage.

        Parameters
        ----------
        ra, dec : float or array-like
            Coordinates in degrees. Array-likes (e.g. pandas Series) are
            evaluated element-wise.

        Returns
        -------
        bool or numpy.ndarray of bool
            True where the projected pixel position lies inside the eimage
            data array.
        """
        sky = SkyCoord(ra, dec, unit="deg")
        col, row = self.wcs.world_to_pixel(sky)

        shape = self.eimage[0].data.shape
        nrow, ncol = shape[0], shape[1]
        # astropy pixel coordinates are 0-based with the *centre* of the
        # first pixel at 0, so the array covers [-0.5, n - 0.5) on each axis.
        # Element-wise ``&`` (instead of chained ``<``) keeps array inputs
        # working; NaN positions compare False and count as outside.
        return (
            (col >= -0.5) & (col < ncol - 0.5)
            & (row >= -0.5) & (row < nrow - 0.5)
        )

    def plot(self, fig, ax, *args, **kwargs):
        """
        Draw the eimage primary-HDU data on ``ax`` using ``ax.imshow``.

        Extra positional and keyword arguments are forwarded to ``imshow``.
        Returns the ``(fig, ax)`` pair that was passed in.
        """
        ax.imshow(self.eimage[0].data, *args, **kwargs)
        return fig, ax


if __name__ == "__main__":
    # Manual check on a single visit. Guarded so that importing this module
    # does not open FITS files, write image.pdf or query the baseline DB.
    im = Image(
        "data/catalog/0.10deg2_0.21_z_5.49_ra_+150.11916667_dec_+2.20583333.bak"
        "/imsim/235928"
    )
    im.plot(plt.gcf(), plt.gca())
    plt.savefig("image.pdf")
    df = im.get_baseline_entry()

    # Is the visit's field centre (fieldRA/fieldDec columns of the baseline
    # row) on this eimage?
    print(im.is_in_image(df["fieldRA"], df["fieldDec"]))
