import logging
import uuid

import numpy as np
import os
import shutil
from datetime import datetime
from typing import Union, List, Tuple
from sqlalchemy import engine as sqla_engine
from assets.commons import (load_config_file,
                            get_moldata,
                            compute_unique_hash_filename,
                            make_archive,
                            validate_parameter,
                            setup_logger,
                            cleanup_directory,
                            get_value_if_specified)
from assets.commons.grid_utils import (compute_molecular_number_density_hot_core,
                                       compute_power_law_radial_profile,
                                       extract_grid_metadata)
from assets.commons.parsing import read_abundance_variation_schema
from assets.commons.db_utils import upsert, get_pg_engine
from assets.constants import (mean_molecular_mass,
                              radmc_input_headers,
                              radmc_lines_mode_mapping)
from stg.stg_build_db_structure import GridPars, GridFiles, StarsPars
from astropy import units as u
from astropy import constants as cst
from astropy.io import fits

# Module-level logger (named 'STG') shared by the functions of this file
logger = setup_logger(name='STG')


def write_radmc_input(filename: str,
                      quantity: np.array,
                      grid_metadata: dict,
                      path: Union[None, str] = None,
                      override_defaults: Union[None, dict] = None,
                      flatten_style: Union[None, str] = None):
    """
    Write a RADMC-3D input file made of header lines followed by the flattened ``quantity``.

    The header keys come from ``radmc_input_headers``; every ``numberdens_<species>.inp`` file shares the header
    layout registered under ``numberdens_mol.inp``. Each header value is taken from ``override_defaults`` when that
    dictionary provides the key, and from defaults built out of ``grid_metadata`` otherwise.
    :param filename: the name of the file to create inside ``path``
    :param quantity: the array written after the header, one value per line
    :param grid_metadata: the grid metadata used to build the default header values
    :param path: the output directory (defaults to the current directory)
    :param override_defaults: optional header values that take precedence over the defaults
    :param flatten_style: the ``order`` argument passed to ``ndarray.flatten`` (defaults to 'F')
    """
    header_defaults = {
        'iformat': 1,
        'grid_type': grid_metadata['grid_type'],
        'coordinate_system': grid_metadata['coordinate_system'],
        'grid_info': 0,
        'grid_shape': grid_metadata['grid_shape'],
        'ncells': np.prod(grid_metadata['grid_shape']),
        'ncells_per_axis': ' '.join(map(str, grid_metadata['grid_shape'])),
        'dust_species': 1,
        'active_axes': ' '.join('1' for _ in grid_metadata['grid_shape']),
        'continuum_lambdas': grid_metadata['continuum_lambdas'],
    }
    output_dir = validate_parameter(path, default='.')
    header_name = 'numberdens_mol.inp' if filename.startswith('numberdens_') else filename
    order = validate_parameter(flatten_style, default='F')

    def _header_value(key):
        # override_defaults may be None (TypeError on indexing) or simply lack the key (KeyError);
        # some keys (e.g. for stars.inp) exist only in the overrides, so the default is looked up lazily
        try:
            return override_defaults[key]
        except (KeyError, TypeError):
            return header_defaults[key]

    with open(os.path.join(output_dir, filename), "w") as outfile:
        outfile.writelines(f'{_header_value(key)}\n' for key in radmc_input_headers[header_name])
        outfile.write('\n'.join(quantity.flatten(order).astype(str)))


def write_radmc_lines_input(line_config: dict,
                            logger: logging.Logger,
                            path: Union[None, str] = None):
    """
    Write the RADMC-3D ``lines.inp`` file, listing the species to model and their collision partners.

    Non-LTE line modes require at least one collision partner. In LTE mode collision partners are not used: any
    configured partner is ignored (with a warning) and ``line_config['collision_partners']`` is reset to an empty
    list *in place*, so that later steps reading the same configuration (e.g. the molecular number density
    profiles) skip them as well.
    :param line_config: the line configuration; reads 'lines_mode', 'species_to_include' and 'collision_partners'
    :param logger: the logger used to report ignored collision partners
    :param path: the output directory (defaults to the current directory)
    :raises ValueError: if a non-LTE line mode is requested without collision partners
    """
    _path = validate_parameter(path, default='.')

    if line_config['lines_mode'] != 'lte':
        if len(line_config['collision_partners']) == 0:
            raise ValueError(f'Lines mode "{line_config["lines_mode"]}" requires at least one collision partner')
    else:
        # Only warn when there is actually something to ignore
        if line_config.get('collision_partners'):
            logger.warning('Some collision partners were specified, but LTE populations were requested. '
                           'Ignoring collision partner data...')
        # Deliberate in-place update: downstream code reads the collision partners from this same dictionary
        line_config['collision_partners'] = []

    species_to_include = line_config['species_to_include']
    collision_partners = line_config['collision_partners']
    with open(os.path.join(_path, 'lines.inp'), "w") as outfile:
        # iformat
        outfile.write('2\n')
        outfile.write(f'{len(species_to_include)}\n')
        for species in species_to_include:
            outfile.write(f'{species} leiden 0 0 {len(collision_partners)}\n')
            for coll_partner in collision_partners:
                outfile.write(f'{coll_partner}\n')


def copy_additional_files(files_to_transfer: Union[None, List[str]] = None,
                          dst_path: Union[str, None] = None):
    """
    Copy constant input files from the mdl/data cache to the RADMC-3D input directory.
    :param files_to_transfer: names of the files to copy from mdl/data (defaults to dustkappa_silicate.inp and
        dustopac.inp)
    :param dst_path: destination directory, created (with any missing parents) if needed; defaults to
        mdl/radmc_files
    """
    _dst_path = validate_parameter(dst_path, default=os.path.join('mdl', 'radmc_files'))
    _files_to_transfer = validate_parameter(files_to_transfer, default=['dustkappa_silicate.inp',
                                                                        'dustopac.inp'])
    source_dir = os.path.join('mdl', 'data')
    # makedirs, unlike mkdir, also creates missing parent directories and does not fail if the directory exists
    os.makedirs(_dst_path, exist_ok=True)
    for filename in _files_to_transfer:
        shutil.copy2(os.path.join(source_dir, filename), _dst_path)


def write_radmc_main_input_file(config_mdl: dict,
                                config_lines: dict,
                                path: Union[None, str] = None):
    """
    Create the main RADMC-3D input file (radmc3d.inp), which can be used to reprocess the input.

    Every entry of the model's 'radmc_postprocessing' section becomes a ``name = value`` line, followed by the
    RADMC-3D lines_mode matching the configured line mode.
    :param config_mdl: the model configuration
    :param config_lines: the line configuration, from the STG layer
    :param path: the directory where the input file is written (defaults to the current directory)
    """
    output_dir = validate_parameter(path, default='.')
    with open(os.path.join(output_dir, 'radmc3d.inp'), "w") as outfile:
        postprocessing_settings = config_mdl["radmc_postprocessing"]
        outfile.writelines(f'{name} = {value}\n' for name, value in postprocessing_settings.items())
        outfile.write(f'lines_mode = {radmc_lines_mode_mapping[config_lines["lines_mode"]]}\n')


def get_solid_body_rotation_y(grid_metadata: dict) -> Tuple[np.array, np.array]:
    """
    Generate the velocity field of a solid body rotating around the y-axis, according to the configuration file.

    The speed is ``velocity_gradient`` times the distance read from the x-z slice of the distance matrix at the
    reference pixel along y (the distances are treated as cm). The velocity lies in the x-z plane, perpendicular
    to the in-plane position. Cells where the direction is undefined (centered indices 0 and 2 both zero) get
    zero velocity.
    :param grid_metadata: the grid metadata; reads 'distance_matrix', 'grid_refpix', 'velocity_gradient',
        'velocity_gradient_unit' and 'centered_indices'
    :return: a tuple with the velocity components along the x- and z-axes, in cm/s
    """
    centered_indices = grid_metadata['centered_indices']
    distance_xz = grid_metadata['distance_matrix'][:, grid_metadata['grid_refpix'][1], :]
    gradient = float(grid_metadata['velocity_gradient']) * u.Unit(grid_metadata['velocity_gradient_unit'])
    rotation_speed = (distance_xz * u.cm * gradient).to(u.cm / u.second)
    # The ratio is undefined where centered index 0 is zero: inf values are fine (arctan(inf) = pi/2) and nan
    # values are zeroed below, so the floating-point warnings they raise are just noise
    with np.errstate(divide='ignore', invalid='ignore'):
        angle_in_radians = np.arctan(abs(centered_indices[2, ...] / centered_indices[0, ...]))
    velocity_x = rotation_speed[:, np.newaxis, :] * -np.sin(angle_in_radians) * np.sign(centered_indices[2, ...])
    velocity_z = rotation_speed[:, np.newaxis, :] * np.cos(angle_in_radians) * np.sign(centered_indices[0, ...])
    velocity_x = np.where(np.isnan(velocity_x), 0, velocity_x)
    velocity_z = np.where(np.isnan(velocity_z), 0, velocity_z)
    return velocity_x, velocity_z


def get_profiles(grid_metadata: dict) -> dict:
    """
    Compute the profiles of the physical quantities needed for the model.

    Every quantity is evaluated with ``compute_power_law_radial_profile`` from a central value and a power-law
    index; quantities without a radial dependence use an index of 0. If requested, the velocity field is then
    replaced by solid-body rotation around the y-axis, and the dust temperature is clipped from below at
    'floor_temperature'.
    :param grid_metadata: the grid metadata obtained parsing the config file
    :return: a dictionary of arrays with the profiles of gas_number_density (cm^-3), dust_temperature,
        microturbulence (cm/s), escprob_lengthscale (cm), velocity_x/y/z and the stacked velocity_field
    """
    density_ref = get_reference_value(grid_metadata=grid_metadata,
                                      quantity_name='density',
                                      desired_unit='cm^-3')
    tdust_ref = get_reference_value(grid_metadata=grid_metadata,
                                    quantity_name='dust_temperature',
                                    desired_unit='K')
    maximum_radius_cm = (float(grid_metadata['maximum_radius'])
                         * u.Unit(grid_metadata['maximum_radius_unit'])).to(u.cm).value
    # Convert every axis size with its own unit before comparing: comparing the raw numbers would pick the
    # wrong axis whenever the axes are expressed in different units
    largest_grid_size_cm = max(
        (float(size) * u.Unit(unit)).to(u.cm).value
        for size, unit in zip(grid_metadata['grid_size'], grid_metadata['grid_size_units'])
    )

    def _uniform(value) -> dict:
        # Specification of a profile with no radial dependence
        return {'central_value': value, 'power_law_index': 0, 'value_at_reference': None, 'distance_reference': 1.0}

    profiles_mapping = {
        'gas_number_density': {
            'central_value': (float(grid_metadata['central_density']) * u.Unit(grid_metadata['density_unit']))
            .to(u.cm ** -3).value,
            'power_law_index': float(grid_metadata['density_powerlaw_idx']),
            'value_at_reference': density_ref['value_at_reference'],
            'distance_reference': density_ref['distance_reference'],
        },
        'dust_temperature': {
            'central_value': float(grid_metadata['dust_temperature']),
            'power_law_index': float(grid_metadata['dust_temperature_powerlaw_idx']),
            'value_at_reference': tdust_ref['value_at_reference'],
            'distance_reference': tdust_ref['distance_reference'],
        },
        'microturbulence': _uniform(
            (float(grid_metadata['microturbulence']) * u.Unit(grid_metadata['microturbulence_unit']))
            .to(u.cm / u.second).value),
        # The escape-probability length scale is the smaller of the largest grid size and the model diameter
        'escprob_lengthscale': _uniform(min(largest_grid_size_cm, 2 * maximum_radius_cm)),
        'velocity_x': _uniform(0),
        'velocity_y': _uniform(0),
        'velocity_z': _uniform(0),
    }

    profiles = {
        name: compute_power_law_radial_profile(
            central_value=spec['central_value'],
            power_law_index=spec['power_law_index'],
            distance_matrix=grid_metadata['distance_matrix'],
            value_at_reference=spec['value_at_reference'],
            distance_reference=spec['distance_reference'],
            maximum_radius=maximum_radius_cm,
        )
        for name, spec in profiles_mapping.items()
    }

    if grid_metadata['velocity_field'] == 'solid':
        profiles['velocity_x'], profiles['velocity_z'] = get_solid_body_rotation_y(grid_metadata=grid_metadata)
    profiles['velocity_field'] = np.array([profiles['velocity_x'], profiles['velocity_y'], profiles['velocity_z']])

    if 'floor_temperature' in grid_metadata:
        floor_temperature = float(grid_metadata['floor_temperature'])
        profiles['dust_temperature'] = np.where(profiles['dust_temperature'] < floor_temperature,
                                                floor_temperature,
                                                profiles['dust_temperature'])
    return profiles


def get_reference_value(grid_metadata: dict,
                        quantity_name: str,
                        desired_unit: str) -> dict:
    """
    Compile a dictionary of reference values for the grid, applying a unit conversion if necessary.

    If the reference value or the reference distance is not specified (missing key, or a key holding None),
    no reference is used: the value at reference is None and the distance reference is 1.0.
    :param grid_metadata: the grid metadata; reads '<quantity_name>_at_reference', '<quantity_name>_unit',
        'distance_reference' and 'distance_reference_unit'
    :param quantity_name: the quantity to process
    :param desired_unit: the output unit to use for the value at reference
    :return: the dictionary with the value at reference in the desired units and the distance reference (in cm)
        used to scale the power law
    """
    try:
        value = float(grid_metadata[f'{quantity_name}_at_reference'])
        distance = float(grid_metadata['distance_reference'])
        return {
            'value_at_reference': (value * u.Unit(grid_metadata[f'{quantity_name}_unit']))
            .to(u.Unit(desired_unit)).value,
            'distance_reference': (distance * u.Unit(grid_metadata['distance_reference_unit'])).to(u.cm).value,
        }
    except (KeyError, AttributeError, TypeError):
        # TypeError covers parameters present in the metadata with a None value (e.g. float(None))
        return {
            'value_at_reference': None,
            'distance_reference': 1.0,
        }


def write_grid_input_files(profiles: dict,
                           grid_metadata: dict,
                           wavelengths_micron: np.array,
                           path: Union[str, None] = None):
    """
    Write the grid input files needed by RADMC-3D.

    Files written: amr_grid.inp (grid edges), dust_density.inp, dust_temperature.dat, microturbulence.inp,
    gas_velocity.inp, wavelength_micron.inp and escprob_lengthscale.inp. All quantities are computed before
    the first file is written.
    :param profiles: the dictionary of profiles for physical quantities that define the model
    :param grid_metadata: the grid metadata
    :param wavelengths_micron: the array of wavelengths to model
    :param path: the output directory (defaults to mdl/radmc_files)
    """
    output_dir = validate_parameter(path, default=os.path.join('mdl', 'radmc_files'))

    # Dust density from the gas number density: n * m_p * mean_molecular_mass / 100
    # (the factor 100 is presumably the gas-to-dust mass ratio — confirm)
    dust_density = (profiles['gas_number_density'] * u.cm ** -3 / 100. * cst.m_p.to(
        u.gram) * mean_molecular_mass).value
    quantities_by_file = {
        'amr_grid.inp': grid_metadata['grid_edges'],
        'dust_density.inp': dust_density,
        'dust_temperature.dat': profiles['dust_temperature'],
        'microturbulence.inp': profiles['microturbulence'],
        'gas_velocity.inp': profiles['velocity_field'],
        'wavelength_micron.inp': wavelengths_micron,
        'escprob_lengthscale.inp': profiles['escprob_lengthscale'],
    }
    for radmc_filename, values in quantities_by_file.items():
        write_radmc_input(filename=radmc_filename,
                          quantity=values,
                          path=output_dir,
                          grid_metadata=grid_metadata)


def write_molecular_number_density_profiles(profiles: dict,
                                            line_config: dict,
                                            grid_metadata: dict,
                                            path: Union[str, None] = None) -> dict:
    """
    Compute and write the molecular number density profiles.

    These are handled separately from the other profiles because they depend on them (gas number density and
    dust temperature) and on the abundances, including the abundance-jump schema. One numberdens_<species>.inp
    file is written per species to include and per collision partner.
    :param profiles: the dictionary of profiles for physical quantities that define the model
    :param line_config: the dictionary of the line configurations, for the abundances
    :param grid_metadata: the grid metadata
    :param path: the output directory (defaults to mdl/radmc_files)
    :return: a dictionary mapping each species name to its number density profile
    """
    output_dir = validate_parameter(path, default=os.path.join('mdl', 'radmc_files'))
    hot_core_specs = read_abundance_variation_schema(line_config=line_config)
    species_profiles = {}
    all_species = line_config['species_to_include'] + line_config['collision_partners']
    for species in all_species:
        number_density = compute_molecular_number_density_hot_core(
            gas_number_density_profile=profiles['gas_number_density'],
            abundance=float(line_config['molecular_abundances'][species]),
            temperature_profile=profiles['dust_temperature'],
            threshold=float(hot_core_specs[species]['threshold']),
            abundance_jump=float(hot_core_specs[species]['abundance_jump']))
        species_profiles[species] = number_density
        write_radmc_input(filename=f'numberdens_{species}.inp',
                          quantity=number_density,
                          path=output_dir,
                          grid_metadata=grid_metadata)
    return species_profiles


def save_fits_grid_profile(quantity: np.array,
                           filename: str,
                           path: str = None):
    """
    Save the model grid to a FITS file, unless a file with that name already exists.

    Saving is best-effort: if the file cannot be written, a warning is logged and execution continues.
    :param quantity: the array to be stored in the fits file
    :param filename: the name of the file to be created
    :param path: the directory where the file will be stored (defaults to prs/fits/grids)
    """
    _path = validate_parameter(path, default=os.path.join('prs', 'fits', 'grids'))
    fits_path = os.path.join(_path, filename)
    if os.path.isfile(fits_path):
        logger.info('Skipping saving of fits grid. File already present!')
        return
    try:
        fits.writeto(fits_path, quantity)
    except OSError as error:
        # Keep the best-effort behaviour, but report the failure instead of hiding it
        logger.warning('Could not save fits grid %s: %s', fits_path, error)


def convert_dimensional_unit(value: Union[float, None, str],
                             current_unit: str,
                             desired_unit: str) -> Union[float, None]:
    """
    Convert a quantity specified in the configuration file to the standard unit.
    :param value: the value to convert (anything accepted by float()), or None if not specified
    :param current_unit: the unit specified in the config file
    :param desired_unit: the output unit
    :return: the converted value, or None when the value is not specified
    """
    if value is None:
        return None
    quantity = float(value) * u.Unit(current_unit)
    return quantity.to(u.Unit(desired_unit)).value


def remap_metadata_to_grid_row(zipped_grid_name: str,
                               grid_metadata: dict,
                               config: dict) -> dict:
    """
    Create a dictionary that can be persisted in the database, according to the GridPars table structure.

    Dimensional quantities are converted to standard units (cm^-3, km/s, km/(s pc), pc); those whose unit is
    missing from the metadata are stored as None. All three grid sizes are stored in pc.
    :param zipped_grid_name: the name of the compressed archive that contains grid files
    :param grid_metadata: the grid metadata
    :param config: the configuration dictionary, to store human-readable grid- and coordinate types
    :return: the dictionary to be inserted in the database
    """
    _config = config['grid']
    dimensional_quantities = {
        'central_density': {'unit_key': 'density_unit', 'desired_unit': 'cm^-3'},
        'density_at_reference': {'unit_key': 'density_unit', 'desired_unit': 'cm^-3'},
        'microturbulence': {'unit_key': 'microturbulence_unit', 'desired_unit': 'km/s'},
        'velocity_gradient': {'unit_key': 'velocity_gradient_unit', 'desired_unit': 'km/(s pc)'},
        'velocity_at_reference': {'unit_key': 'velocity_unit', 'desired_unit': 'km/s'},
        'distance_reference': {'unit_key': 'distance_reference_unit', 'desired_unit': 'pc'},
        'maximum_radius': {'unit_key': 'maximum_radius_unit', 'desired_unit': 'pc'},
    }
    grid_row = {
        'zipped_grid_name': zipped_grid_name,
        'grid_type': _config['grid_type'],
        'coordinate_system': _config['coordinate_system'],
        'central_density': get_value_if_specified(parameters_dict=grid_metadata, key='central_density'),
        'density_powerlaw_index': get_value_if_specified(parameters_dict=grid_metadata, key='density_powerlaw_idx'),
        'density_at_reference': get_value_if_specified(parameters_dict=grid_metadata, key='density_at_reference'),
        'dust_temperature': get_value_if_specified(parameters_dict=grid_metadata, key='dust_temperature'),
        'dust_temperature_powerlaw_index': get_value_if_specified(parameters_dict=grid_metadata,
                                                                  key='dust_temperature_powerlaw_idx'),
        'dust_temperature_at_reference': get_value_if_specified(parameters_dict=grid_metadata,
                                                                key='dust_temperature_at_reference'),
        'microturbulence': get_value_if_specified(parameters_dict=grid_metadata, key='microturbulence'),
        'velocity_field': get_value_if_specified(parameters_dict=grid_metadata, key='velocity_field'),
        'velocity_gradient': get_value_if_specified(parameters_dict=grid_metadata, key='velocity_gradient'),
        'velocity_powerlaw_index': get_value_if_specified(parameters_dict=grid_metadata, key='velocity_powerlaw_index'),
        'velocity_at_reference': get_value_if_specified(parameters_dict=grid_metadata, key='velocity_at_reference'),
        'distance_reference': get_value_if_specified(parameters_dict=grid_metadata, key='distance_reference'),
        'maximum_radius': get_value_if_specified(parameters_dict=grid_metadata, key='maximum_radius'),
        'grid_size_1': grid_metadata['grid_size'][0],
        'grid_shape_1': grid_metadata['grid_shape'][0],
        'grid_refpix_1': grid_metadata['grid_refpix'][0],
        'grid_size_2': grid_metadata['grid_size'][1],
        'grid_shape_2': grid_metadata['grid_shape'][1],
        'grid_refpix_2': grid_metadata['grid_refpix'][1],
        'grid_size_3': grid_metadata['grid_size'][2],
        'grid_shape_3': grid_metadata['grid_shape'][2],
        'grid_refpix_3': grid_metadata['grid_refpix'][2],
    }
    for qty, unit_spec in dimensional_quantities.items():
        try:
            grid_row[qty] = convert_dimensional_unit(value=grid_row[qty],
                                                     current_unit=u.Unit(grid_metadata[unit_spec['unit_key']]),
                                                     desired_unit=u.Unit(unit_spec['desired_unit']))
        except KeyError:
            # The unit of this quantity is not specified: the value cannot be interpreted
            grid_row[qty] = None
    # Convert all three axes (not only the first two), each with its own unit, so every stored size is in pc
    for dim_idx in range(3):
        size_key = f'grid_size_{dim_idx + 1}'
        grid_row[size_key] = (grid_row[size_key] * u.Unit(grid_metadata['grid_size_units'][dim_idx])).to(u.pc).value
    return grid_row


def populate_grid_table(config: dict,
                        engine: sqla_engine,
                        grid_metadata: dict,
                        run_id: str) -> str:
    """
    Upsert the grid parameters to the postgres DB (GridPars table).

    The archive name is derived from a hash of the configuration, and the row is upserted with
    (zipped_grid_name, run_id) as conflict keys.
    :param config: the configuration dictionary
    :param engine: the SQLAlchemy engine to use
    :param grid_metadata: the grid metadata
    :param run_id: the run unique identifier
    :return: the compressed archive filename
    """
    archive_name = f'{compute_unique_hash_filename(config=config)}.zip'
    grid_row = remap_metadata_to_grid_row(zipped_grid_name=archive_name,
                                          grid_metadata=grid_metadata,
                                          config=config)
    grid_row.update(run_id=run_id, created_on=datetime.now())
    upsert(table_object=GridPars,
           row_dict=grid_row,
           conflict_keys=[GridPars.zipped_grid_name, GridPars.run_id],
           engine=engine)
    return archive_name


def populate_grid_files(quantity_name: str,
                        engine: sqla_engine,
                        zip_filename: str,
                        filename: str,
                        run_id: str):
    """
    Store in the DB (GridFiles table) the FITS filename holding one quantity of a grid, for later reference.

    The row is upserted with (zipped_grid_name, quantity, run_id) as conflict keys.
    :param quantity_name: the name of the quantity to store
    :param engine: the SQLAlchemy engine
    :param zip_filename: the name of the grid archive containing the radmc3d input files
    :param filename: the fits filename where the grid is stored
    :param run_id: the ID of the run
    """
    grid_file_row = dict(
        zipped_grid_name=zip_filename,
        quantity=quantity_name,
        fits_grid_name=filename,
        created_on=datetime.now(),
        run_id=run_id,
    )
    conflict_columns = [GridFiles.zipped_grid_name, GridFiles.quantity, GridFiles.run_id]
    upsert(table_object=GridFiles, row_dict=grid_file_row, conflict_keys=conflict_columns, engine=engine)


def populate_stars_table(config_stars: dict,
                         engine: sqla_engine,
                         grid_zipfile: str,
                         run_id: str):
    """
    Populate the stars table (StarsPars) in the DB.

    The stars configuration is copied (the input dictionary is left untouched), extended with the grid archive
    name, the creation timestamp and the run ID, and upserted with (zipped_grid_name, run_id) as conflict keys.
    :param config_stars: the dictionary containing the stars configurations
    :param engine: the SQLAlchemy engine to use
    :param grid_zipfile: the grid archive name, to be used as key
    :param run_id: the run unique identifier
    """
    stars_row = {
        **config_stars,
        'zipped_grid_name': grid_zipfile,
        'created_on': datetime.now(),
        'run_id': run_id,
    }
    upsert(table_object=StarsPars,
           row_dict=stars_row,
           conflict_keys=[StarsPars.zipped_grid_name, StarsPars.run_id],
           engine=engine)


def write_stellar_input_file(stars_metadata: dict,
                             grid_metadata: dict,
                             path: str,
                             wavelengths_micron: np.array):
    """
    Create the input file (stars.inp) that inserts the stars in the computation.

    Each star contributes one header line with rstar, mstar and the three components of its position. The header
    also receives the number of stars, the number of wavelengths and the wavelength list; the stellar fluxes are
    written after it, flattened in C order.
    :param stars_metadata: the metadata about the stars, as specified in the configuration file
    :param grid_metadata: the grid metadata, used for the remaining header defaults
    :param path: the path where the file will be created
    :param wavelengths_micron: the wavelengths to consider
    """
    star_lines = []
    for rstar, mstar, position in zip(stars_metadata['rstars'],
                                      stars_metadata['mstars'],
                                      stars_metadata['star_positions']):
        fields = (rstar, mstar, position[0], position[1], position[2])
        star_lines.append(' '.join(str(field) for field in fields))
    header_overrides = dict(
        iformat=2,
        nstars=stars_metadata['nstars'],
        continuum_lambdas=stars_metadata['nlambdas'],
        stars_properties='\n'.join(star_lines),
        lambdas=' \n'.join(wavelengths_micron.astype(str)),
    )
    write_radmc_input(filename='stars.inp',
                      path=path,
                      quantity=np.array(stars_metadata['star_fluxes']),
                      grid_metadata=grid_metadata,
                      override_defaults=header_overrides,
                      flatten_style='C')


def get_grid_name(method: Union[str, None] = None,
                  zip_filename: Union[str, None] = None,
                  quantity_name: Union[str, None] = None):
    """
    Compute the name of the grid file.
    :param method: method used to compute the file name (defaults to 'uuid'); 'uuid' generates a random unique
        identifier, 'composite_grid' appends the quantity name to the zipped archive name (without extension)
    :param zip_filename: the zipped archive name, required by 'composite_grid'
    :param quantity_name: the name of the quantity that should be processed, required by 'composite_grid'
    :return: the name of the file
    :raises ValueError: if 'composite_grid' is requested without zip_filename or quantity_name
    :raises NotImplementedError: if the method is not one of the allowed options
    """
    _method = validate_parameter(method, default='uuid')
    allowed_methods = ('uuid', 'composite_grid')
    # Dispatch on the validated value, so that method=None really falls back to the 'uuid' default
    if _method == 'uuid':
        return f'{uuid.uuid4()}.fits'
    if _method == 'composite_grid':
        if zip_filename is None or quantity_name is None:
            raise ValueError('The composite_grid method requires both zip_filename and quantity_name')
        # Strip only the archive extension, e.g. 'abc.zip' -> 'abc'
        return f'{os.path.splitext(zip_filename)[0]}_{quantity_name}.fits'
    raise NotImplementedError(
        f'The chosen method is not available. Allowed options are: {" ".join(allowed_methods)}')


def main(run_id: str,
         override_config: Union[dict, None] = None,
         path_radmc_files: Union[str, None] = None,
         compute_dust_temperature: bool = True,
         engine: sqla_engine = None) -> str:
    """
    Run the grid-generation pipeline and return the name of the grid archive.

    The pipeline builds the grid and its profiles and writes all RADMC-3D input files. When stars are
    configured, it either runs ``radmc3d mctherm`` or copies a cached dust temperature file. It then persists
    the grid metadata and FITS grids to the DB, and archives the input directory into stg/archive.
    :param run_id: the unique identifier of the run, stored with every DB record
    :param override_config: optional overrides, with keys 'grid_lines' (for stg/config/config.yml) and 'model'
        (for mdl/config/config.yml); a missing key means no override for that file
    :param path_radmc_files: directory for the RADMC-3D input files (defaults to mdl/radmc_files); it is passed
        to cleanup_directory before any file is written
    :param compute_dust_temperature: when stars are configured, run mctherm (True) or reuse the cached dust
        temperature file (False)
    :param engine: the SQLAlchemy engine; a new one is created when None
    :return: the name of the grid archive
    """
    _override_config = validate_parameter(override_config, default={'grid_lines': {}, 'model': {}})
    config = load_config_file(os.path.join('stg', 'config', 'config.yml'),
                              override_config=_override_config.get('grid_lines', {}))
    config_lines = config['lines']
    config_mdl = load_config_file(os.path.join('mdl', 'config', 'config.yml'),
                                  override_config=_override_config.get('model', {}))
    input_files_dir = validate_parameter(path_radmc_files, default=os.path.join('mdl', 'radmc_files'))
    cleanup_directory(directory=input_files_dir,
                      logger=logger)
    copy_additional_files(dst_path=input_files_dir)

    grid_metadata = extract_grid_metadata(config=config)
    if 'stars' in config:
        wavelengths_micron = manage_wavelength_grid_with_stars(config, grid_metadata, input_files_dir)
    else:
        # Default continuum wavelength grid: 100 log-spaced points between 5 and 1300 micron
        wavelengths_micron = np.logspace(np.log10(5),
                                         np.log10(1300),
                                         100)
    grid_metadata['continuum_lambdas'] = len(wavelengths_micron)

    profiles = get_profiles(grid_metadata=grid_metadata)
    write_grid_input_files(grid_metadata=grid_metadata,
                           profiles=profiles,
                           path=input_files_dir,
                           wavelengths_micron=wavelengths_micron)
    get_moldata(species_names=config_lines['species_to_include'],
                logger=logger,
                path=input_files_dir,
                use_cache=True)
    # write_radmc_lines_input may reset config_lines['collision_partners'] in place (LTE mode), so it must run
    # before the molecular number densities are computed from the same configuration
    write_radmc_lines_input(line_config=config_lines,
                            path=input_files_dir,
                            logger=logger)

    molecular_number_density_profiles = write_molecular_number_density_profiles(profiles=profiles,
                                                                                grid_metadata=grid_metadata,
                                                                                line_config=config_lines,
                                                                                path=input_files_dir)
    if engine is None:
        engine = get_pg_engine(logger=logger)
    zip_filename = populate_grid_table(config=config,
                                       engine=engine,
                                       grid_metadata=grid_metadata,
                                       run_id=run_id)
    # Release the pooled DB connections before running RADMC-3D; the same engine is used again afterwards
    engine.dispose()

    write_radmc_main_input_file(config_mdl=config_mdl,
                                config_lines=config_lines,
                                path=input_files_dir)
    # RADMC-3D runs from inside the input directory: always restore the working directory, even on failure
    execution_dir = os.getcwd()
    os.chdir(input_files_dir)
    try:
        # Recompute dust temperature distribution if needed, based on star positions and properties
        if 'stars' in config:
            if compute_dust_temperature is True:
                logger.info('Computing dust temperature distribution using the stars in the configuration file')
                try:
                    _threads = config_mdl['radmc_observation']['threads']
                except KeyError:
                    _threads = 4
                exit_status = os.system(f'radmc3d mctherm setthreads {str(_threads)}')
                if exit_status != 0:
                    logger.error('radmc3d mctherm exited with status %s: dust_temperature.dat may not have '
                                 'been recomputed', exit_status)
                populate_stars_table(config_stars=config['stars'],
                                     engine=engine,
                                     grid_zipfile=zip_filename,
                                     run_id=run_id)
            else:
                logger.info('Using cached dust temperature distribution')
                # NOTE(review): other paths in this module use the 'mdl' directory (e.g. mdl/data); confirm that
                # 'model/data' is the intended cache location
                shutil.copy(os.path.join(execution_dir, 'model', 'data', 'dust_temperature.dat'),
                            os.path.join('.', 'dust_temperature.dat'))
                # TODO: Should I apply the floor temperature here as well?
    finally:
        os.chdir(execution_dir)

    # NOTE(review): the dust_temperature FITS grid saved below is the analytic profile from get_profiles, not the
    # distribution possibly recomputed by mctherm — confirm this is intended
    for quantity_name in ('gas_number_density', 'dust_temperature'):
        save_and_persist_grid(engine=engine,
                              profiles=profiles,
                              run_id=run_id,
                              zip_filename=zip_filename,
                              quantity_name=quantity_name)
    for species in molecular_number_density_profiles:
        save_and_persist_grid(engine=engine,
                              profiles=molecular_number_density_profiles,
                              run_id=run_id,
                              zip_filename=zip_filename,
                              quantity_name=species)

    engine.dispose()

    make_archive(output_filename=zip_filename,
                 source_dir=input_files_dir,
                 archive_path=os.path.join('stg', 'archive'))
    return zip_filename


def manage_wavelength_grid_with_stars(config: dict,
                                      grid_metadata: dict,
                                      input_files_dir: str) -> np.array:
    """
    Build the wavelength grid from the stars configuration and write the stellar input file.
    :param config: the STG configuration; its 'stars' section provides 'spacing' ('log' or 'linear'),
        'lambdas_micron_limits' and 'nlambdas', plus the stellar properties used by write_stellar_input_file
    :param grid_metadata: the grid metadata
    :param input_files_dir: the directory where stars.inp is written
    :return: the array of wavelengths, in micron
    :raises NotImplementedError: if the spacing is neither 'log' nor 'linear'
    """
    stars_metadata = config['stars']
    spacing = stars_metadata['spacing']
    # NotImplemented is a constant, not an exception: raising NotImplemented(...) would fail with a TypeError
    if spacing not in ('log', 'linear'):
        raise NotImplementedError('Spacing not defined. Choose between {linear, log}')
    lambda_min = stars_metadata['lambdas_micron_limits'][0]
    lambda_max = stars_metadata['lambdas_micron_limits'][1]
    n_lambdas = stars_metadata['nlambdas']
    if spacing == 'log':
        wavelengths_micron = np.logspace(np.log10(lambda_min), np.log10(lambda_max), n_lambdas)
    else:
        wavelengths_micron = np.linspace(lambda_min, lambda_max, n_lambdas)
    write_stellar_input_file(stars_metadata=stars_metadata,
                             grid_metadata=grid_metadata,
                             path=input_files_dir,
                             wavelengths_micron=wavelengths_micron)
    return wavelengths_micron


def save_and_persist_grid(engine: sqla_engine,
                          profiles: dict,
                          run_id: str,
                          zip_filename: str,
                          quantity_name: str,
                          method: Union[str, None] = None):
    """
    Save one profile to a FITS file and register that file in the GridFiles table.
    :param engine: the SQLAlchemy engine
    :param profiles: dictionary of profiles; profiles[quantity_name] is the array that gets saved
    :param run_id: the run unique identifier
    :param zip_filename: the name of the grid archive the profile belongs to
    :param quantity_name: the profile to save, also used to build the FITS file name
    :param method: naming method forwarded to get_grid_name (defaults to 'composite_grid')
    """
    fits_name = get_grid_name(method=validate_parameter(method, default='composite_grid'),
                              zip_filename=zip_filename,
                              quantity_name=quantity_name)
    save_fits_grid_profile(quantity=profiles[quantity_name], filename=fits_name)
    populate_grid_files(quantity_name=quantity_name,
                        engine=engine,
                        zip_filename=zip_filename,
                        filename=fits_name,
                        run_id=run_id)


if __name__ == '__main__':
    # Script entry point: run the full pipeline with a fixed run identifier
    main(run_id='test_run')
