import astropy.io.fits as pyfits
import os
import sys
import numpy as np

    
def create_ebounds_header(extension):
    extension.header['EXTNAME'] = 'EBOUNDS'
    extension.header['TELESCOP'] = 'THESEUS'
    extension.header['INSTRUME'] = 'XGIS'
    extension.header['FILTER'] = 'NONE'
    extension.header['CHANTYPE'] = 'PHA' # PI 
    extension.header['DETCHANS'] = 'NONE' # the total number of raw detector PHA channels in the full (uncompressed) matrix
    extension.header['HDUCLASS'] = 'OGIP'
    extension.header['HDUCLAS1'] = 'RESPONSE'
    extension.header['HDUCLAS2'] = 'EBOUNDS'
    extension.header['HDUVERS'] = '1.2.0'
    # optional
    extension.header['PHAFILE'] = 'NONE' # name of PHA file for which this file was produced
    # obsolete (for old software)
    extension.header['RMFVERSN'] = '1992A'
    extension.header['HDUVERS1'] = '1.0.0'
    extension.header['HDUVERS2'] = '1.1.0'  


def create_matrix_header(extension):
    extension.header['EXTNAME'] = 'SPECRESP MATRIX'
    extension.header['TELESCOP'] = 'THESEUS'
    extension.header['INSTRUME'] = 'XGIS'
    extension.header['FILTER'] = 'NONE'
    extension.header['CHANTYPE'] = 'PHA' # PI 
    extension.header['DETCHANS'] = 'NONE' # the total number of raw detector PHA channels in the full (uncompressed) matrix.
    extension.header['HDUCLASS'] = 'OGIP'
    extension.header['HDUCLAS1'] = 'RESPONSE'
    extension.header['HDUCLAS2'] = 'RSP_MATRIX'
    extension.header['HDUVERS'] = '1.3.0'
    #extension.header['TLMIN#'] = 'NONE'# the first channel in the response. # is the column number for the F_CHAN column (see below).
    # optional
    extension.header['NUMGRP'] = 'NONE' # the total number of channel subsets. The sum of the N_GRP column.
    extension.header['NUMELT'] = 'NONE' # the total number of response elements. The sum of the N_CHAN column
    extension.header['PHAFILE'] = 'NONE' # name of PHA file for which this file was produced
    extension.header['LO_THRES '] = '0' # minimum probability threshold used to construct the matrix (matrix elements below this value are considered to zero and are not stored)
    extension.header['HDUCLAS3'] = 'REDIST' # giving further details of the stored matrix Allowed values are: 'REDIST' for a matrix whose elements represent probabilities associated with the photon redistribution process only 'DETECTOR' for a matrix whose elements have been multiplied by all energy-dependent effects associated with detector (eg detector efficiency, window transmission etc). 'FULL' for a matrix whose elements have been multiplied by all energy-dependent effects associated with detector, optics, collimator, filters etc.
    # obsolete (for old software)
    extension.header['RMFVERSN'] = '1992A'
    extension.header['HDUVERS1'] = '1.0.0'
    extension.header['HDUVERS2'] = '1.3.0'
    
    # altre mandatory che non ho capito

def estract_row(en_dep_file, bins):
    
    en_dep = np.load(en_dep_file)
    lenght = en_dep.size
    
    counts, bin_edges = np.histogram(en_dep, bins=bins)
    counts_norm = counts/lenght
    
    print(counts)
    print(counts_norm)
    
    n_grp = 0
    f_chan = []
    n_chan = []

    # Controlla i gruppi di valori consecutivi non nulli
    in_group = False  # Stato per controllare se siamo all'interno di un gruppo, booleano
    group_start = None  # Per memorizzare l'inizio del gruppo, intero, noen= senza valore

    # Il ciclo scorre ogni valore dell'array counts_norm e controlla se è maggiore di zero (cioè non nullo).
    for i in range(len(counts_norm)):
        if counts_norm[i] > 0:
            if not in_group:  # se non sono gia in un gruppo svolgi i seguenti # ! questi comandi li esegue solo per i valori iniziali di ogni gruppo
                in_group = True # apro un nuovo gruppo
                group_start = i  # Memorizza l'indice di inizio del gruppo
                f_chan.append(i)  # Aggiungi l'indice di inizio al vettore f_chan
                n_grp += 1  # Incrementa il numero di gruppi
        else:
            if in_group:  # Termina il gruppo corrente # ! qui esegue solo al primo zero che trova, se ne trova uno consecutivo non esegue perche trova in_group falso
                in_group = False
                group_length = i - group_start  # Calcola la lunghezza del gruppo
                n_chan.append(group_length)  # Aggiungi la lunghezza del gruppo a n_chan

    # Controlla se l'ultimo gruppo arriva fino alla fine dell'array
    if in_group:  # se è true esegue
        n_chan.append(len(counts_norm) - group_start)
    
    return n_grp, f_chan, n_chan, counts_norm # row_matrix = counts_norm

def rmf_file(en_dep_file_filt, theta, phi, energy):
    try:
        os.chdir("/home/alfonso/Scrivania/THESEUS/xgis_m7-main/python/npy")
        
        n_grp_col = []
        f_chan_col = []
        n_chan_col = []
        row_matrix_col = []
        
        bins = np.arange(0, 156, 5)  # se tra min e max voglio definire una spaziatura # ! per ora questo ma meglio linspace
        #bins = np.linspace(0, 156, 5) # se tra min e max voglio definire quanti bin voglio
        
        # se voglio definire i channels
        bin_centers = 0.5 * (bins[1:] + bins[:-1])
        channels = np.arange(1, len(bin_centers) + 1)
        
        for en_dep_file in en_dep_file_filt:
            n_grp, f_chan, n_chan, row_matrix = estract_row(en_dep_file, bins)
            
            n_grp_col.append(n_grp)
            f_chan_col.append(f_chan)
            n_chan_col.append(n_chan)
            row_matrix_col.append(row_matrix)
            
        print(f"n_grp: {n_grp_col}")
        print(f"f_chan: {f_chan_col}")
        print(f"n_chan: {n_chan_col}")
        print(f"row_matrix: {row_matrix_col}")        
        
        
        # Creazione di "Null" primary array
        primary_hdu = pyfits.PrimaryHDU()

        
        # Creazione di ebounds BinTableHDU per EVENTI X
        ebounds_bin_tableX_hdu = pyfits.BinTableHDU.from_columns([
            pyfits.Column(name='CHANNEL', format='1I', array = channels),
            pyfits.Column(name='E_MIN', format='1E', unit='keV', array = bins[:-1]),           
            pyfits.Column(name='E_MAX', format='1E', unit='keV', array = bins[1:]),      
            ])
        create_ebounds_header(ebounds_bin_tableX_hdu)
        
        # Creazione di matrice BinTableHDU per EVENTI X
        matrix_bin_tableX_hdu = pyfits.BinTableHDU.from_columns([
            pyfits.Column(name='ENERG_LO', format='1E', unit='keV', array = energy[:-1]),          
            pyfits.Column(name='ENERG_HI', format='1E', unit='keV', array = energy[1:]),
            pyfits.Column(name='N_GRP', format='1I', array = n_grp_col),
            pyfits.Column(name='F_CHAN', format='PJ', array = f_chan_col),
            pyfits.Column(name='N_CHAN', format='PJ', array = n_chan_col),
            pyfits.Column(name='MATRIX', format='PE', array = row_matrix_col),
        ])
    
        create_matrix_header(matrix_bin_tableX_hdu)  # Aggiunta header comune per gli eventi X



    
    # !!!!      da fare anche per eventi S     !!!!!!!!!!!!!!



        # Creazione di HDUList e scrittura del file FITS
        hdulX = pyfits.HDUList([primary_hdu, ebounds_bin_tableX_hdu, matrix_bin_tableX_hdu])
        #hdulS = pyfits.HDUList([primary_hdu, ebounds_bin_tableS_hdu, matrix_bin_tableS_hdu])

        output_dir = "/home/alfonso/Scrivania/THESEUS/xgis_m7-main/rmf"
        
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
            
        output_X = os.path.join(output_dir, f"Xrmf_{theta}_{phi}.rmf")
        #output_S = os.path.join(output_dir, f"Srmf_{theta}_{phi}.rmf")

        hdulX.writeto(output_X, overwrite=True)
        #hdulS.writeto(output_S, overwrite=True)

        print("File .rmf creati correttamente.")
    except Exception as e:
        print(f"Errore durante la creazione del file .rmf: {e}")