import os
import shlex
import shutil
import subprocess
import sys
from datetime import datetime
from itertools import product
from typing import Union

import sqlalchemy
from astropy import units as u
from radmc3dPy import image

from stg.stg_build_db_structure import (LinePars,
                                        SpeciesAndPartners,
                                        ModelPars,
                                        StarsPars)
from assets.commons import (load_config_file,
                            validate_parameter,
                            setup_logger,
                            compute_unique_hash_filename,
                            get_value_if_specified,
                            cleanup_directory)
from assets.commons.parsing import read_abundance_variation_schema
from assets.commons.db_utils import upsert, get_pg_engine
from assets.constants import (radmc_options_mapping,
                              radmc_lines_mode_mapping)

# Module-level logger named 'MDL'; used by main() for progress/debug messages and handed to get_pg_engine
logger = setup_logger(name='MDL')


def save_cube_as_fits(cube_out_name: Union[str, None] = None,
                      cube_out_path: Union[str, None] = None,
                      path_radmc_files: Union[str, None] = None):
    """
    Read the RADMC-3D 'image.out' file and write its content out as a FITS cube.
    :param cube_out_name: name of the FITS file to produce (default: 'test_cube.fits')
    :param cube_out_path: directory where the FITS file is written (default: prs/fits/cubes)
    :param path_radmc_files: directory holding the RADMC-3D input/output files, including
        image.out (default: mdl/radmc_files)
    """
    fits_name = validate_parameter(cube_out_name, default='test_cube.fits')
    fits_dir = validate_parameter(cube_out_path, default=os.path.join('prs', 'fits', 'cubes'))
    radmc_dir = validate_parameter(path_radmc_files, default=os.path.join('mdl', 'radmc_files'))

    radmc_image = image.readImage(fname=os.path.join(radmc_dir, 'image.out'))

    fits_file = os.path.join(fits_dir, fits_name)
    # Drop any previous cube with the same name, so the freshly read image replaces it
    if os.path.isfile(fits_file):
        os.remove(fits_file)
    radmc_image.writeFits(fname=fits_file)


def check_config(config: dict) -> bool:
    """
    Check if all mandatory keys are present in the configuration dictionary.
    :param config: the configuration dictionary (e.g. the 'radmc_observation' section of the model config)
    :return: True if all mandatory keys are present, False otherwise
    """
    mandatory_keys = frozenset({
        'inclination',
        'position_angle',
    })
    # Iterating a dict yields its keys, so this tests key presence only
    return mandatory_keys.issubset(config)


def get_command_options(config: dict) -> set:
    """
    Translate the 'radmc_observation' configuration into radmc3d image command-line options.

    Only keys that have an entry in radmc_options_mapping are translated; any other key
    of the 'radmc_observation' section is silently ignored.
    :param config: the configuration dictionary
    :return: a set of '<option> <value>' strings to append to the radmc3d image command
    """
    observation = config["radmc_observation"]
    return {
        f'{radmc_options_mapping[option]} {observation[option]}'
        for option in observation
        if option in radmc_options_mapping
    }


def populate_model_table(config_mdl: dict,
                         grid_zipfile: str,
                         cube_filename: str,
                         engine: sqlalchemy.engine.Engine,
                         executed_on: datetime,
                         run_id: str):
    """
    Populate the model_parameters table in the DB.

    Upserts a single row built from the 'radmc_postprocessing' and 'radmc_observation'
    sections of the model configuration; conflicts are resolved on (fits_cube_name, run_id).
    :param config_mdl: the model configuration
    :param grid_zipfile: the name of the grid tarfile
    :param cube_filename: the name of the fits cube
    :param engine: the SQLAlchemy engine
    :param executed_on: the datetime of execution
    :param run_id: the run unique identifier
    """
    postprocessing = config_mdl['radmc_postprocessing']
    observation = config_mdl['radmc_observation']
    model_pars_dict = {
        'zipped_grid_name': grid_zipfile,
        'fits_cube_name': cube_filename,
        'nphotons': postprocessing['nphot'],
        'scattering_mode_max': int(postprocessing['scattering_mode_max']),
        'iranfreqmode': postprocessing['iranfreqmode'],
        'tgas_eq_tdust': postprocessing['tgas_eq_tdust'],
        # inclination and position_angle are mandatory (see check_config); the rest are optional
        'inclination': observation['inclination'],
        'position_angle': observation['position_angle'],
        'imolspec': get_value_if_specified(observation, 'imolspec'),
        'iline': get_value_if_specified(observation, 'iline'),
        'width_kms': get_value_if_specified(observation, 'width_kms'),
        'nchannels': get_value_if_specified(observation, 'nchannels'),
        'npix': get_value_if_specified(observation, 'npix'),
        'created_on': executed_on,
        'run_id': run_id
    }
    upsert(
        table_object=ModelPars,
        row_dict=model_pars_dict,
        conflict_keys=[ModelPars.fits_cube_name, ModelPars.run_id],
        engine=engine
    )


def populate_species_and_partner_table(config_lines: dict,
                                       engine: sqlalchemy.engine.Engine,
                                       executed_on: datetime,
                                       grid_zipfile: str,
                                       run_id: str):
    """
    Populate the species_and_partners table in the DB.

    Upserts one row for every (species, collision partner) combination of
    config_lines['species_to_include'] x config_lines['collision_partners'].
    :param config_lines: the line configuration
    :param engine: the SQLAlchemy engine
    :param executed_on: the datetime of execution
    :param grid_zipfile: the name of the grid tarfile
    :param run_id: the run unique identifier
    """
    hot_core_specs = read_abundance_variation_schema(line_config=config_lines)
    abundances = config_lines['molecular_abundances']
    for species, collision_partner in product(config_lines['species_to_include'],
                                              config_lines['collision_partners']):
        species_specs = hot_core_specs[species]
        species_partner_dict = {
            'zipped_grid_name': f'{grid_zipfile}',
            'species_to_include': species,
            'molecular_abundance': abundances[species],
            'threshold': species_specs['threshold'],
            'abundance_jump': species_specs['abundance_jump'],
            'collision_partner': collision_partner,
            'molecular_abundance_collision_partner': abundances[collision_partner],
            'created_on': executed_on,
            'run_id': run_id
        }
        upsert(
            table_object=SpeciesAndPartners,
            row_dict=species_partner_dict,
            conflict_keys=[SpeciesAndPartners.zipped_grid_name,
                           SpeciesAndPartners.species_to_include,
                           SpeciesAndPartners.collision_partner,
                           SpeciesAndPartners.run_id],
            engine=engine
        )


def populate_line_table(config_lines: dict,
                        engine: sqlalchemy.engine.Engine,
                        executed_on: datetime,
                        grid_zipfile: str,
                        run_id: str):
    """
    Populate the lines table in the DB.

    Upserts a single row keyed on (zipped_grid_name, run_id) recording the lines_mode.
    :param config_lines: the dictionary containing the line configuration
    :param engine: the SQLAlchemy engine to use
    :param executed_on: the datetime of execution, to add to the record
    :param grid_zipfile: the grid tarfile name, to be used as key
    :param run_id: the run unique identifier
    """
    line_pars_dict = {
        'zipped_grid_name': f'{grid_zipfile}',
        'lines_mode': config_lines['lines_mode'],
        'created_on': executed_on,
        'run_id': run_id
    }
    upsert(
        table_object=LinePars,
        row_dict=line_pars_dict,
        conflict_keys=[LinePars.zipped_grid_name, LinePars.run_id],
        engine=engine
    )


def _write_postprocessing_script(radmc_input_path: str, radmc_command: str):
    """Write mdl/radmc3d_postprocessing.sh, which cds into radmc_input_path and runs radmc_command."""
    with open(os.path.join('mdl', 'radmc3d_postprocessing.sh'), 'w') as outfile:
        outfile.write(f'cd {radmc_input_path}\n')
        outfile.write(f'{radmc_command}\n')


def _run_radmc(radmc_command: str, radmc_input_path: str):
    """Run radmc_command inside radmc_input_path and check that it produced image.out there."""
    image_file = os.path.join(radmc_input_path, 'image.out')
    # Remove an image.out left over from a previous run: otherwise a failed radmc3d execution
    # would still pass the presence check below and stale data would be saved as the new cube
    if os.path.isfile(image_file):
        os.remove(image_file)
    logger.debug(f'Executing command: {radmc_command}')
    # Run with cwd= rather than os.chdir, so the working directory of this process is never
    # changed (and cannot be left changed if the call raises); no shell is involved
    completed = subprocess.run(shlex.split(radmc_command), cwd=radmc_input_path, check=False)
    if completed.returncode != 0:
        logger.warning(f'radmc3d exited with return code {completed.returncode}')
    logger.debug(f'Checking presence of file: {image_file}')
    assert os.path.isfile(image_file), f'radmc3d did not produce {image_file}'


def main(grid_zipfile: str,
         run_id: str,
         override_config: Union[dict, None] = None,
         radmc_input_path: Union[str, None] = None,
         engine: Union[sqlalchemy.engine.Engine, None] = None) -> str:
    """
    Run the RADMC-3D imaging for the current configuration and record the run in the DB.

    The imaging step is skipped when a FITS cube with the same configuration hash already
    exists in prs/fits/cubes. The line, species/partner and model parameters are upserted
    in any case.
    :param grid_zipfile: the grid tarfile name, used as key in the DB tables
    :param run_id: the run unique identifier
    :param override_config: dict with 'grid_lines' and 'model' entries, passed as override_config
        when loading stg/config/config.yml and mdl/config/config.yml respectively
    :param radmc_input_path: directory with the RADMC-3D input/output files (default: mdl/radmc_files)
    :param engine: the SQLAlchemy engine; if None, one is created with get_pg_engine.
        The engine is disposed before returning.
    :return: the file name of the FITS cube
    """
    # This is necessary, because the lines_mode is needed both in the lines.inp and radmc3d.inp files
    # The reason for splitting the main input file from the rest is that some parameters can be changed
    # independently of the grid for the modeling. The mdl hash should depend on all the mdl parameters, not a subset
    _override_config = validate_parameter(override_config, default={'grid_lines': {}, 'model': {}})
    executed_on = datetime.now()
    _radmc_input_path = validate_parameter(radmc_input_path, default=os.path.join('mdl', 'radmc_files'))
    config_stg = load_config_file(os.path.join('stg', 'config', 'config.yml'),
                                  override_config=_override_config['grid_lines'])
    config_lines = config_stg['lines']
    config_mdl = load_config_file(os.path.join('mdl', 'config', 'config.yml'),
                                  override_config=_override_config['model'])
    assert check_config(config=config_mdl['radmc_observation']), \
        'Mandatory keys are missing from the radmc_observation configuration'

    # Sort the options: iteration order of a set of strings depends on per-process hash
    # randomization, so an unsorted join would yield a different command on every run
    radmc_command = f'radmc3d image {" ".join(sorted(get_command_options(config_mdl)))}'
    _write_postprocessing_script(radmc_input_path=_radmc_input_path, radmc_command=radmc_command)

    config_full = config_mdl.copy()
    config_full.update(config_stg)
    cube_filename = f'{compute_unique_hash_filename(config=config_full)}.fits'
    cubes_path = os.path.join('prs', 'fits', 'cubes')

    # Execute radmc if not done already
    if not os.path.isfile(os.path.join(cubes_path, cube_filename)):
        _run_radmc(radmc_command=radmc_command, radmc_input_path=_radmc_input_path)
        save_cube_as_fits(cube_out_name=cube_filename,
                          cube_out_path=cubes_path,
                          path_radmc_files=_radmc_input_path)
    else:
        logger.info('Computation performed already! Skipping...')

    if engine is None:
        engine = get_pg_engine(logger=logger, engine_kwargs={'pool_size': 2})

    try:
        populate_line_table(config_lines=config_lines,
                            engine=engine,
                            executed_on=executed_on,
                            grid_zipfile=grid_zipfile,
                            run_id=run_id)

        populate_species_and_partner_table(config_lines=config_lines,
                                           engine=engine,
                                           executed_on=executed_on,
                                           grid_zipfile=grid_zipfile,
                                           run_id=run_id)

        populate_model_table(config_mdl=config_mdl,
                             grid_zipfile=grid_zipfile,
                             cube_filename=cube_filename,
                             engine=engine,
                             executed_on=executed_on,
                             run_id=run_id)
    finally:
        # Release pooled connections even if one of the DB writes fails
        engine.dispose()
    return cube_filename


if __name__ == '__main__':
    # Manual run against a fixed grid archive, tagged with a test run identifier
    main(
        run_id='test_run',
        grid_zipfile='459295aa894dffa8c521e606d14dbb6927638a2c.zip',
    )
