import os
import numpy as np
from multiprocessing import Pool
import astropy.io.fits as fits

# Funzione che prende in input un singolo file FITS da processare.
def get_data(files):
    
    # Estrazione di energia, theta e phi dal nome del file
    base_name = os.path.basename(files)  # Ottiene il nome del file senza il percorso
    name_parts = base_name.replace('.fits', '').split('_')  # Rimuove l'estensione e suddivide il nome del file in una lista di stringhe usando _ come delimitatore.
    # Energia, theta e phi sono in posizioni fisse nel nome
    #energy = float(name_parts[1])
    theta = float(name_parts[2])
    phi = float(name_parts[3])
    
    # Apertura del file fits e lettura dei dati
    with fits.open(files) as hdul:
        data = hdul[1].data
        tot_events = len(data)
        scint_id_col = data['Scint_ID']
        energy = data['En_Primary'][0]
        en_dep = data['En_dep']
        eventi_X = (scint_id_col == -1000).sum()
        eventi_S = tot_events - eventi_X
        ratio_X = eventi_X/tot_events
        ratio_S = eventi_S/tot_events
        
    return energy, ratio_X, ratio_S, theta, phi, en_dep

def write_results(file_dat, results):
    npy_dir = "/home/alfonso/Scrivania/THESEUS/xgis_m7-main/python/npy"
    if not os.path.exists(npy_dir):
        os.makedirs(npy_dir)
    
    with open(file_dat, "w") as f:
        for energy, ratio_X, ratio_S, theta, phi, en_dep in results:
            # Salva l'array en_dep in un file separato, e salva tutto in una cartella
            npy_file = f"en_dep_{energy:.1f}_{theta}_{phi}.npy"
            en_dep_file = os.path.join(npy_dir, npy_file)
            np.save(en_dep_file, en_dep)
            
            f.write(f"{energy:.1f}\t{ratio_X:.3f}\t{ratio_S:.3f}\t{theta}\t{phi}\t{npy_file}\n")

# Estrae i dati, processa in parallelo tutti i file .fits presenti nella folder, usa le due funzioni già definite
def data_estraction(file_dat):
    # Crea una lista di file FITS trovati nella folder. Sorted per leggerli in ordine
    folder = "/home/alfonso/Scrivania/THESEUS/xgis_m7-main/fits"

    files = sorted([os.path.join(folder, file) for file in os.listdir(folder) if file.endswith(".fits")], key=os.path.getmtime)
    
    # Crea un pool di processi che permette di eseguire funzioni in parallelo.
    # Usa il metodo map() del pool per eseguire la funzione get_data() su ciascun file FITS in parallelo.
    # I risultati vengono raccolti in una lista chiamata results.
    with Pool() as pool:
        results = pool.map(get_data, files) # qui files è l'argomento che passi a get_data
    
    write_results(file_dat, results)