import os
import sqlalchemy
import numpy as np
import astropy.units as u
from datetime import datetime
from astropy.io import fits
from astropy import constants
from typing import Union, List
from assets.commons import (validate_parameter,
                            get_value_if_specified,
                            load_config_file,
                            setup_logger,
                            compute_unique_hash_filename)
from assets.commons.db_utils import upsert, get_pg_engine
from assets.constants import aggregation_function_mapping
from stg.stg_build_db_structure import (MomentZeroMaps,
                                        RatioMaps)

# Module-level logger, created under the name 'PRS' and used by the functions of this module.
logger = setup_logger(name='PRS')


def populate_mom_zero_table(config_prs: dict,
                            fits_cube_name: str,
                            moment_zero_map_fits: str,
                            aggregated_moment_zero: float,
                            executed_on: datetime,
                            engine: sqlalchemy.engine,
                            run_id: str):
    """
    Insert (or update) the record describing one moment zero map in the DB.

    The row is upserted on the (mom_zero_name, run_id) pair.
    :param config_prs: configuration of the presentation layer
    :param fits_cube_name: name of the fits cube
    :param moment_zero_map_fits: filename of the moment zero map
    :param aggregated_moment_zero: the aggregated value of the moment zero
    :param executed_on: timestamp of execution
    :param engine: SQLAlchemy engine to use
    :param run_id: the run unique identifier
    :raises KeyError: if neither 'moment_zero_aggregation_function' holds a known function name nor
        'aggregation_function' is present in the configuration
    """
    # Look the limits up once; the value 'all' means that no explicit limits are recorded.
    integration_limits = get_value_if_specified(config_prs, 'integration_limits')
    if integration_limits != 'all':
        _integration_limit_low = integration_limits[0]
        _integration_limit_high = integration_limits[1]
    else:
        _integration_limit_low = None
        _integration_limit_high = None

    # Record the aggregation function with the same fallback rule used by compute_moment_zero:
    # when the dedicated setting is missing or is not a known function name, the generic
    # 'aggregation_function' is the one that was used to aggregate the map.
    aggregation_function = config_prs.get('moment_zero_aggregation_function')
    if aggregation_function not in aggregation_function_mapping:
        aggregation_function = config_prs['aggregation_function']

    moment_zero_dict = {
        'mom_zero_name': f'{moment_zero_map_fits}',
        'fits_cube_name': f'{fits_cube_name}',
        'integration_limit_low': _integration_limit_low,
        'integration_limit_high': _integration_limit_high,
        'aggregated_moment_zero': aggregated_moment_zero,
        'aggregation_function': aggregation_function,
        'created_on': executed_on,
        'run_id': run_id,
    }
    upsert(
        table_object=MomentZeroMaps,
        row_dict=moment_zero_dict,
        conflict_keys=[MomentZeroMaps.mom_zero_name,
                       MomentZeroMaps.run_id],
        engine=engine
    )


def populate_line_ratios_table(config_prs: dict,
                               moment_zero_map_fits_list: List[str],
                               ratio_map_name: str,
                               aggregated_ratio: float,
                               engine: sqlalchemy.engine,
                               executed_on: datetime,
                               run_id: str):
    """
    Insert (or update) the record describing one ratio map in the DB.

    The row is upserted on the (ratio_map_name, run_id) pair.
    :param config_prs: configuration of the presentation layer
    :param moment_zero_map_fits_list: list of the two moment 0 maps used to compute the ratio,
        in the order numerator, denominator
    :param ratio_map_name: filename of the ratio map
    :param aggregated_ratio: aggregated value of the ratio
    :param engine: SQLAlchemy engine to use
    :param executed_on: timestamp of execution
    :param run_id: the run unique identifier
    :raises AssertionError: if `moment_zero_map_fits_list` does not contain exactly two elements
    :raises KeyError: if 'aggregation_function' is missing from the configuration
    """
    # Validate first, so that a malformed list is reported before anything else is evaluated.
    assert len(moment_zero_map_fits_list) == 2, (
        f'Exactly two moment zero maps are needed to describe a ratio, '
        f'got {len(moment_zero_map_fits_list)}')
    # The integration limits are not part of this record, so they are not read from the configuration.
    line_ratio_dict = {
        'ratio_map_name': f'{ratio_map_name}',
        'mom_zero_name_1': f'{moment_zero_map_fits_list[0]}',
        'mom_zero_name_2': f'{moment_zero_map_fits_list[1]}',
        'aggregated_ratio': aggregated_ratio,
        'aggregation_function': config_prs['aggregation_function'],
        'created_on': executed_on,
        'run_id': run_id,
    }
    upsert(
        table_object=RatioMaps,
        row_dict=line_ratio_dict,
        conflict_keys=[RatioMaps.ratio_map_name,
                       RatioMaps.run_id],
        engine=engine
    )


def compute_moment_zero(cube: str,
                        config: dict,
                        cube_path: Union[str, None] = None,
                        moment_zero_path: Union[str, None] = None,
                        moment_zero_fits_name: Union[str, None] = None,
                        hdu_idx: int = 0) -> float:
    """
    Compute the moment zero map of a cube, save it to fits and return its aggregated value.

    The cube is summed along its first array axis. When the header keyword CUNIT3 is 'HZ' the sum is
    multiplied by the absolute value of c * CDELT3 / CRVAL3 expressed in km/s; otherwise the sum is
    left unscaled.
    :param cube: filename of the fits cube (an already opened FITS object is also accepted)
    :param config: configuration of the presentation layer
    :param cube_path: path where the fits cube is to be found
    :param moment_zero_path: path where the moment 0 map is to be saved
    :param moment_zero_fits_name: filename for the moment zero map
    :param hdu_idx: index of the HDU to extract the data and the header
    :return: the aggregated value of the moment zero, according to the function specified in the configuration
    :raises KeyError: if a required header keyword is missing, or if neither aggregation function setting
        resolves to a known function
    """
    _cube_path = validate_parameter(cube_path, default=os.path.join('prs', 'fits', 'cubes'))
    _moment_zero_path = validate_parameter(moment_zero_path, default=os.path.join('prs', 'fits', 'moments'))
    _moment_zero_fits_name = validate_parameter(moment_zero_fits_name, default='test_mom0.fits')

    fitsfile = open_fits_file_duck_typing(fitsfile=cube, fits_path=_cube_path)
    # The helper hands back its own argument when it was given an already opened object; only a file
    # opened on behalf of this function is closed here, the caller keeps ownership of its own objects.
    opened_here = fitsfile is not cube
    try:
        header = fitsfile[hdu_idx].header.copy()
        if header['CUNIT3'].strip() == 'HZ':
            spectral_unit = 'km/s'
            conversion_factor = (-constants.c * header['CDELT3'] / header['CRVAL3']).to('km/s').value
        else:
            # NOTE(review): in this branch the sum is not scaled by the channel width and BUNIT is still
            # suffixed with 'Hz', whatever CUNIT3 contains - confirm that this is intended.
            spectral_unit = 'Hz'
            conversion_factor = 1
        # The product is a new array, so nothing computed here depends on the file staying open.
        mom0 = fitsfile[hdu_idx].data.sum(axis=0) * abs(conversion_factor)
    finally:
        if opened_here:
            fitsfile.close()

    # Turn the (copied) cube header into the header of a two-dimensional image.
    header['BUNIT'] += f' {spectral_unit}'
    header['BTYPE'] = 'MOMENT0'
    header['NAXIS'] = 2
    for key in ('NAXIS3', 'CRPIX3', 'CDELT3', 'CRVAL3', 'CUNIT3', 'CTYPE3'):
        del header[key]
    fits.writeto(os.path.join(_moment_zero_path, _moment_zero_fits_name), data=mom0, header=header, overwrite=True)

    # Only the lookup is guarded: an error raised while aggregating must not be mistaken for a
    # configuration problem.
    try:
        aggregate = aggregation_function_mapping[config['moment_zero_aggregation_function']]
    except KeyError:
        logger.warning(
            'Moment zero aggregation function not set or misconfigured. Trying to use the one for the ratio...')
        aggregate = aggregation_function_mapping[config['aggregation_function']]
    return aggregate(mom0)


def open_fits_file_duck_typing(fitsfile: Union[str, fits.HDUList],
                               fits_path: Union[str, None] = None) -> fits.HDUList:
    """
    Return an opened FITS object, given either a filename or an object that is already opened
    :param fitsfile: filename of the fits file, or an already opened FITS object
    :param fits_path: path where the fits file is to be found; '.' is passed to validate_parameter as default
    :return: the result of fits.open when `fitsfile` can be joined to the path, otherwise `fitsfile` itself
    """
    _fits_path = validate_parameter(fits_path, default='.')
    # Only the path construction is used as the duck-typing probe: a TypeError raised while actually
    # opening the file is a real error and must reach the caller.
    try:
        full_path = os.path.join(_fits_path, fitsfile)
    except TypeError:
        return fitsfile
    return fits.open(full_path)


def compute_image_ratios(fits1: str,
                         fits2: str,
                         config: dict,
                         fits_path: Union[str, None] = None,
                         ratio_fits_path: Union[str, None] = None,
                         ratio_fits_name: Union[str, None] = None,
                         hdu1_idx: int = 0,
                         hdu2_idx: int = 0) -> float:
    """
    Compute the ratio of the images and save it to fits. The ratio is computed as fits1/fits2
    :param fits1: filename of the first moment 0 map (an already opened FITS object is also accepted)
    :param fits2: filename of the second moment 0 map (an already opened FITS object is also accepted)
    :param config: configuration of the presentation layer; 'aggregation_function' must be one of
        'mean', 'sum' or 'weighted_mean'
    :param fits_path: path where the fits files are to be found
    :param ratio_fits_path: path where the ratio map is to be saved
    :param ratio_fits_name: filename for the ratio map
    :param hdu1_idx: index of the HDU to extract the data and the header, for the first fits file
    :param hdu2_idx: index of the HDU to extract the data and the header, for the second fits file
    :return: the aggregated value of the ratio; for 'weighted_mean' this is the ratio of the NaN-ignoring
        sums of the two images
    :raises NotImplementedError: if the configured aggregation function is not one of the supported ones
    """
    _fits_path = validate_parameter(fits_path, default=os.path.join('prs', 'fits', 'moments'))
    _ratio_fits_path = validate_parameter(ratio_fits_path, default=os.path.join('prs', 'fits', 'ratios'))
    _ratio_fits_name = validate_parameter(ratio_fits_name, default='ratio.fits')

    # Files opened on behalf of this function; objects that were passed in already opened are recognised
    # because the helper returns them unchanged, and they are left for the caller to close.
    opened_here = []
    try:
        hdu1 = open_fits_file_duck_typing(fitsfile=fits1,
                                          fits_path=_fits_path)
        if hdu1 is not fits1:
            opened_here.append(hdu1)
        hdu2 = open_fits_file_duck_typing(fitsfile=fits2,
                                          fits_path=_fits_path)
        if hdu2 is not fits2:
            opened_here.append(hdu2)

        ratio_image_data = hdu1[hdu1_idx].data / hdu2[hdu2_idx].data
        aggregation_name = config['aggregation_function']
        if aggregation_name in ('mean', 'sum'):
            aggregated_ratio = aggregation_function_mapping[aggregation_name](ratio_image_data)
        elif aggregation_name == 'weighted_mean':
            aggregated_ratio = np.nansum(hdu1[hdu1_idx].data) / np.nansum(hdu2[hdu2_idx].data)
        else:
            raise NotImplementedError('Aggregation function not configured')
        # Work on a copy, so that the header of the input map is not modified.
        output_header = hdu1[hdu1_idx].header.copy()
    finally:
        for opened_file in opened_here:
            opened_file.close()

    output_header['BUNIT'] = ''
    output_header['BTYPE'] = 'INT.RATIO'
    fits.writeto(os.path.join(_ratio_fits_path, _ratio_fits_name),
                 data=ratio_image_data,
                 header=output_header,
                 overwrite=True)
    return aggregated_ratio


def main(cube_fits_list: List[str],
         run_id: str,
         mom0_out_cube1: Union[str, None] = None,
         mom0_out_cube2: Union[str, None] = None,
         engine: sqlalchemy.engine = None) -> str:
    """
    Compute the moment zero maps of two cubes and their ratio map, and record the results in the DB
    :param cube_fits_list: list with the filenames of the two fits cubes
    :param run_id: the run unique identifier
    :param mom0_out_cube1: filename for the moment zero map of the first cube; the default is derived from
        the cube filename, replacing '.fits' with '_mom0.fits'
    :param mom0_out_cube2: filename for the moment zero map of the second cube; same default rule
    :param engine: SQLAlchemy engine to use; when None, one is obtained from get_pg_engine
    :return: the filename of the ratio map
    """
    assert len(cube_fits_list) == 2
    mom0_names = [
        validate_parameter(mom0_out_cube1, default=cube_fits_list[0].replace('.fits', '_mom0.fits')),
        validate_parameter(mom0_out_cube2, default=cube_fits_list[1].replace('.fits', '_mom0.fits')),
    ]
    config_file = os.path.join('prs', 'config', 'config.yml')
    config_prs = load_config_file(config_file)['flux_computation']
    executed_on = datetime.now()
    if engine is None:
        engine = get_pg_engine(logger=logger)

    # The inputs are added to the configuration before the ratio map filename is computed from it.
    config_prs.update(cube_fits_list=cube_fits_list,
                      mom0_image_list=list(mom0_names))
    ratio_fits_name = '{}.fits'.format(compute_unique_hash_filename(config=config_prs))

    # Moment zero maps first, then the ratio of the two maps.
    aggregated_mom0 = [
        compute_moment_zero(cube=cube_name,
                            config=config_prs,
                            moment_zero_fits_name=mom0_name)
        for cube_name, mom0_name in zip(cube_fits_list, mom0_names)
    ]
    aggregated_image_ratio = compute_image_ratios(fits1=mom0_names[0],
                                                  fits2=mom0_names[1],
                                                  config=config_prs,
                                                  ratio_fits_name=ratio_fits_name)

    # Records are written only after all the maps have been computed.
    for cube_name, mom0_name, aggregated_value in zip(cube_fits_list, mom0_names, aggregated_mom0):
        populate_mom_zero_table(config_prs=config_prs,
                                fits_cube_name=cube_name,
                                moment_zero_map_fits=mom0_name,
                                aggregated_moment_zero=aggregated_value,
                                executed_on=executed_on,
                                engine=engine,
                                run_id=run_id)
    populate_line_ratios_table(config_prs=config_prs,
                               moment_zero_map_fits_list=list(mom0_names),
                               ratio_map_name=ratio_fits_name,
                               aggregated_ratio=aggregated_image_ratio,
                               engine=engine,
                               executed_on=executed_on,
                               run_id=run_id)
    return ratio_fits_name


if __name__ == '__main__':
    # Example run: the same test cube is used twice, with two distinct moment zero outputs.
    example_arguments = {
        'cube_fits_list': ['test_cube.fits', 'test_cube.fits'],
        'mom0_out_cube1': 'test_cube_mom0.fits',
        'mom0_out_cube2': 'test_cube_mom0_1.fits',
        'run_id': 'test_run',
    }
    main(**example_arguments)
