import glob
import logging
import sys
import yaml
import numpy as np
import shutil
import os
import urllib.request
import json
import pandas as pd
from hashlib import sha1
from typing import (List,
                    Union,
                    Tuple)
from astropy import units as u
from assets.constants import (leiden_url_mapping)


def setup_logger(name: str,
                 log_level: str = None) -> logging.Logger:
    """
    Configure a logger that writes formatted records to standard output.

    The function can be called several times with the same name: the stdout handler installed by a
    previous call is reused (and its level/formatter refreshed), so records are not emitted more than once.
    :param name: Logger name
    :param log_level: Logging level, as defined by logging; defaults to 'DEBUG' if None
    :return: the logger object
    """
    _log_level = 'DEBUG' if log_level is None else log_level
    logger = logging.getLogger(name)
    logger.setLevel(_log_level)
    # logging.getLogger returns the same object for a given name, so a handler attached by an earlier
    # call is still there; look for it instead of stacking a new one on every call.
    handler = next(
        (existing for existing in logger.handlers
         if isinstance(existing, logging.StreamHandler) and getattr(existing, 'stream', None) is sys.stdout),
        None
    )
    if handler is None:
        handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(handler)
    handler.setLevel(_log_level)
    handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)-7s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    )
    return logger


def validate_parameter(param_to_validate,
                       default):
    """
    Fall back to a default when no value was supplied
    :param param_to_validate: the candidate value
    :param default: the replacement used when the candidate is None
    :return: the candidate unless it is None, in which case the default is returned; falsy values other
        than None (0, '', [], ...) are passed through unchanged
    """
    if param_to_validate is None:
        return default
    return param_to_validate


def load_config_file(config_file_path: str,
                     override_config: Union[dict, None] = None) -> dict:
    """
    Load the information in the YAML configuration file into a python dictionary
    :param config_file_path: path to the configuration file
    :param override_config: parameters of the input file to override (e.g. for grid creation). Mapping values
        are merged key-by-key into the section with the same name (the section is created, or replaced if it
        is not a mapping); any other value replaces the top-level entry as a whole.
    :return: a dictionary with the parsed information. When 'density_powerlaw_idx' (or
        'dust_temperature_powerlaw_idx') equals 0, the corresponding '*_at_reference' value is moved to
        'central_density' (or 'dust_temperature') and the '*_at_reference' entry is set to None.
    """
    _override_config = validate_parameter(override_config, default={})
    with open(config_file_path) as config_file:
        # NOTE(review): FullLoader is not intended for untrusted input; if the configuration files only
        # contain plain YAML, yaml.safe_load would be the safer choice.
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    # An empty YAML document is parsed as None; treat it as an empty configuration
    if config is None:
        config = {}

    for key, override in _override_config.items():
        if not isinstance(override, dict):
            # Scalars, lists, ... cannot be merged: they replace the entry directly
            config[key] = override
            continue
        if key not in config:
            config[key] = {}
        if not override:
            continue
        # A section that is not a mapping (e.g. "key:" parsed as None) cannot receive sub-keys
        if not isinstance(config[key], dict):
            config[key] = {}
        config[key].update(override)

    # A power-law index of 0 means a flat profile: the value at the reference radius is the value everywhere
    flat_profile_keys = (
        ('density_powerlaw_idx', 'central_density', 'density_at_reference'),
        ('dust_temperature_powerlaw_idx', 'dust_temperature', 'dust_temperature_at_reference'),
    )
    for index_key, flat_key, reference_key in flat_profile_keys:
        try:
            if config[index_key] == 0:
                config[flat_key] = config[reference_key]
                config[reference_key] = None
        except KeyError:
            # These entries are optional; leave the configuration untouched if any of them is missing
            pass
    return config


def get_moldata(species_names: list,
                logger: logging.Logger,
                path: Union[str, None] = None,
                use_cache: bool = False):
    """
    Make the molecular data files available in the target folder, downloading them from the Leiden molecular
        database when needed; the species must be mapped in leiden_url_mapping
    :param species_names: the names of the species for which to download data
    :param logger: logger for output printing
    :param path: path where the files must be saved; defaults to mdl/radmc_files
    :param use_cache: use molecular input files in cache (mdl/data), if possible; a species without a cached
        file falls back to the file already in the target folder or to the download
    :raises KeyError: if a file must be downloaded for a species that is not in leiden_url_mapping
    """
    _path = validate_parameter(path, default=os.path.join('mdl', 'radmc_files'))
    for species in species_names:
        molecular_file_name = f'molecule_{species}.inp'
        target_file = os.path.join(_path, molecular_file_name)
        if use_cache is True:
            cached_file = os.path.join('mdl', 'data', molecular_file_name)
            # Only copy what is actually cached, otherwise fall through to the download below
            if os.path.isfile(cached_file):
                shutil.copy(cached_file, target_file)
        if not os.path.isfile(target_file):
            logger.info(f'Downloading file molecule_{species}.inp...')
            # The context manager closes the HTTP response even if reading/decoding fails
            with urllib.request.urlopen(leiden_url_mapping[species]) as response:
                data = response.read().decode()
            with open(target_file, 'w') as outfile:
                outfile.write(data)
        else:
            logger.info(f'File molecule_{species}.inp found, skipping download...')


def compute_unique_hash_filename(config: dict) -> str:
    """
    Derive a reproducible identifier from a configuration dictionary
    :param config: the configuration dictionary to hash; it must be JSON-serializable
    :return: the hexadecimal SHA-1 digest of the dictionary serialized as JSON with sorted keys, so that
        the key order of the input does not affect the result
    """
    serialized_config = json.dumps(config, sort_keys=True)
    digest = sha1(serialized_config.encode()).hexdigest()
    return digest


def make_archive(output_filename: str,
                 source_dir: str,
                 archive_path: Union[str, None] = None):
    """
    Pack the content of a directory into a single archive, replacing any archive with the same name
    :param output_filename: the archive file name; the part after the last dot is used as the archive
        format passed to shutil.make_archive (e.g. 'zip'), the part before it as the archive base name
    :param source_dir: the directory to compress
    :param archive_path: folder where the archive is stored; defaults to stg/archive
    """
    destination_dir = validate_parameter(archive_path,
                                         default=os.path.join('stg', 'archive'))
    # Unpacking fails with ValueError if the file name has no extension
    archive_stem, archive_format = output_filename.rsplit('.', maxsplit=1)
    stale_archive = os.path.join(destination_dir, output_filename)
    try:
        os.remove(stale_archive)
    except FileNotFoundError:
        # Nothing to replace
        pass
    archive_base = os.path.join(destination_dir, archive_stem)
    shutil.make_archive(base_name=archive_base,
                        format=archive_format,
                        root_dir=source_dir)


def cleanup_directory(directory: str,
                      logger: logging.Logger):
    """
    Delete the regular files found directly inside a directory; sub-directories (and their content) are
    left in place. Nothing happens if the directory does not exist.
    :param directory: directory to clean up
    :param logger: logger used to report, at debug level, every removed or skipped entry
    """
    if not os.path.isdir(directory):
        return
    for entry in glob.glob(f'{directory}/*'):
        if not os.path.isfile(entry):
            logger.debug(f'Skipping {entry}')
            continue
        logger.debug(f'Removing file {entry}')
        os.remove(entry)


def convert_frequency_to_wavelength(frequency: u.Quantity,
                                    output_units: u.Unit):
    """
    Convert a quantity to the requested units using astropy's spectral equivalencies
    :param frequency: the quantity to convert
    :param output_units: the units of the result
    :return: the converted quantity, as returned by Quantity.to
    """
    spectral_equivalencies = u.spectral()
    return frequency.to(output_units, equivalencies=spectral_equivalencies)


def get_value_if_specified(parameters_dict: dict,
                           key: str):
    """
    Look up an optional entry of a dictionary
    :param parameters_dict: the dictionary to read from
    :param key: the key to look up
    :return: the value stored under the key, or None if the lookup raises KeyError
    """
    try:
        value = parameters_dict[key]
    except KeyError:
        value = None
    return value


def get_postprocessed_data(limit_rows: Union[None, int] = None,
                           use_model_for_inference: Union[None, str] = None) -> Tuple[List[List[str]], pd.DataFrame]:
    """
    Read the line pairs from the main config file (config/config.yml) and the full dataset of a model
    :param use_model_for_inference: the prs/output/run_type folder from which to get the data from inference;
        defaults to fiducial model (constant_abundance_p15_q05)
    :param limit_rows: the number of rows to use from the original dataset; useful to run tests and limit
        computation time. All rows are returned if None.
    :return: the line pairs list (the 'lines_to_process' entry of the 'overrides' section) and the dataset,
        including the px_index column
    """
    model_folder = validate_parameter(
        use_model_for_inference,
        default='constant_abundance_p15_q05'
    )
    dataset_path = os.path.join('prs', 'output', 'run_type', model_folder, 'data', 'full_dataset.csv')
    main_config = load_config_file(os.path.join('config', 'config.yml'))
    line_pairs = main_config['overrides']['lines_to_process']
    full_dataset = add_px_index_column(dataset_path)
    if limit_rows is None:
        return line_pairs, full_dataset
    # Keep only the first limit_rows rows
    return line_pairs, full_dataset.head(limit_rows)


def prepare_matrix(filename: str,
                   columns: list,
                   use_model_for_inference: Union[None, str] = None) -> pd.DataFrame:
    """
    Read a csv file from the data folder of a model and return a selection of its columns.

    :param filename: The name of the file to read, inside prs/output/run_type/<model>/data.
    :param columns: The list of columns to extract from the dataframe.
    :param use_model_for_inference: The folder within prs/output/run_type to get the data for inference;
        defaults to the fiducial model ('constant_abundance_p15_q05') if None is provided.
    :return: A pandas DataFrame restricted to the requested columns; the 'nh2' and 'tdust' columns are rounded
        to one decimal place and converted to string type before the selection.
    """
    model_folder = validate_parameter(
        use_model_for_inference,
        default='constant_abundance_p15_q05'
    )
    csv_path = os.path.join('prs', 'output', 'run_type', model_folder, 'data', filename)
    matrix = pd.read_csv(csv_path)
    # Both grid parameters get the same treatment: round first, then cast to string
    for grid_column in ('nh2', 'tdust'):
        rounded = matrix[grid_column].round(1)
        matrix[grid_column] = pd.Series(rounded, dtype='string')
    return matrix[columns]


def get_data(limit_rows: Union[int, None] = None,
             use_model_for_inference: Union[None, str] = None,
             log_columns: Union[None, List] = None):
    """
    Retrieve and preprocess dataset.

    The post-processed dataset is inner-joined, on ('px_index', 'nh2', 'tdust'), with the moment-zero maps of
    the line pairs 87-86, 88-86, 257-256 and 381-380; 'nh2' and 'tdust' are rounded to one decimal place and
    compared as strings for the join, then converted back to numbers.

    :param limit_rows: The number of rows to use from the original dataset; useful for running tests and limiting
        computation time. Defaults to None, which uses all rows.
    :param use_model_for_inference: The folder within prs/output/run_type to get the data for inference;
        defaults to the fiducial model ('constant_abundance_p15_q05') if None is provided.
    :param log_columns: The list of columns to apply a logarithmic transformation to. Defaults to
        ['log_nh2', 'log_tdust', 'avg_nh2', 'avg_tdust', 'molecule_column_density', 'std_nh2'] if None is provided.
        'log_nh2' and 'log_tdust' are created as copies of 'nh2' and 'tdust', and only hold logarithms if they
        are listed here.
    :return: A pandas DataFrame containing the merged and processed data from multiple sources, with specified
        columns transformed with log10, plus the 'px_distance' and 'core_radius_px' columns.
    """
    _use_model_for_inference = validate_parameter(
        use_model_for_inference,
        default='constant_abundance_p15_q05'
    )
    _log_columns = validate_parameter(
        log_columns,
        default=['log_nh2', 'log_tdust', 'avg_nh2', 'avg_tdust', 'molecule_column_density', 'std_nh2']
    )
    merge_keys = ['px_index', 'nh2', 'tdust']
    # File to read and columns to take from it, in addition to the merge keys
    line_datasets = (
        ('full_dataset_NH2_lines_87-86.csv', ['mom_zero_87', 'mom_zero_86', 'molecule_column_density']),
        ('full_dataset_NH2_lines_88-86.csv', ['mom_zero_88']),
        ('full_dataset_NH2_lines_257-256.csv', ['mom_zero_257', 'mom_zero_256']),
        ('full_dataset_NH2_lines_381-380.csv', ['mom_zero_381', 'mom_zero_380']),
    )

    _, data = get_postprocessed_data(limit_rows=limit_rows,
                                     use_model_for_inference=_use_model_for_inference)
    # With limit_rows the frame is a head() slice of the full dataset: work on an explicit copy, so that the
    # column assignments below do not trigger pandas' SettingWithCopyWarning
    data = data.copy()
    # Same key formatting as prepare_matrix, so that the keys of the two sides compare equal
    for key in ('nh2', 'tdust'):
        data[key] = pd.Series(data[key].round(1), dtype='string')

    merged = data
    for filename, value_columns in line_datasets:
        line_data = prepare_matrix(filename=filename,
                                   columns=merge_keys + value_columns,
                                   use_model_for_inference=_use_model_for_inference)
        merged = merged.merge(line_data, on=merge_keys)

    # The grid parameters were strings only for the merge; the log_* copies are transformed later on,
    # to allow splitting
    for key in ('nh2', 'tdust'):
        merged[key] = pd.to_numeric(merged[key])
        merged[f'log_{key}'] = merged[key].copy()
    for column in _log_columns:
        merged[column] = np.log10(merged[column])

    # px_index is decomposed into quotient and remainder by npixels, and the distance is measured from refpix
    # NOTE(review): presumably px_index is the flattened index of a 101x101 map centred on pixel (50, 50) -
    # confirm against the code that writes the dataset
    npixels = 101
    refpix = 50
    quotient_offset = (merged['px_index'] / npixels).astype(int) - refpix
    remainder_offset = (merged['px_index'] % npixels) - refpix
    merged['px_distance'] = np.sqrt(quotient_offset ** 2 + remainder_offset ** 2)
    merged['core_radius_px'] = (1e8 / merged['nh2']) ** -(2 / 3) * 0.5 * npixels / 2
    return merged


def add_px_index_column(filename: str) -> pd.DataFrame:
    """
    Read a csv file, using its first column as index, and make sure the data has a px_index column
    :param filename: The filename of the csv file that contains the data
    :return: a pandas dataframe with the px_index column; if the file already has a column with this name the
        data is returned as read, otherwise the index is turned into the px_index column
    """
    df = pd.read_csv(filename, index_col=0)
    if 'px_index' not in df.columns:
        # Name the index before resetting it: this works whatever the header of the first csv column is,
        # and never touches a data column that happens to be called 'index'
        df = df.rename_axis('px_index').reset_index()
    return df
