import logging
import os
from typing import Union, List
import numpy as np
from astropy import units as u
from assets.commons import validate_parameter, load_config_file, setup_logger
from assets.commons.parsing import get_grid_properties
from assets.constants import mean_molecular_mass, radmc_grid_map, radmc_coord_system_map


def compute_power_law_radial_profile(
        central_value: float,
        power_law_index: float,
        distance_matrix: np.array,
        maximum_radius: Union[float, None] = None,
        value_at_reference: Union[float, None] = None,
        distance_reference: Union[float, u.Quantity] = 1.0,
        fill_reference_pixel: bool = True) -> np.array:
    """
    Compute a power-law radial profile over the specified grid, i.e.
    value_at_reference * (distance / distance_reference) ** power_law_index
    :param central_value: value in the center of the grid; when fill_reference_pixel is set (and the index is non-zero)
        it is also used as an upper cap for the whole profile
    :param power_law_index: index of the power law used to scale the profile
    :param distance_matrix: the matrix of distances from the reference point
    :param maximum_radius: the maximum radius to populate with gas within the grid in the same unit as the distance
                           matrix and distance reference; pixels beyond it are set to 0
    :param value_at_reference: value of the profile at the reference distance; defaults to central value if not provided
    :param distance_reference: value of the reference distance; astropy Quantities are converted to cm
    :param fill_reference_pixel: whether to fill the reference point with central_value and set this to the maximum in
        the grid
    :return: the array of profile values evaluated on distance_matrix
    """
    _value_at_reference = validate_parameter(value_at_reference, default=central_value)
    _distance_matrix = distance_matrix
    if fill_reference_pixel:
        positive_distances = distance_matrix[distance_matrix > 0]
        # Guard against grids without any positive distance (e.g. a single pixel), where nanmin would raise
        if positive_distances.size > 0:
            # Move the 0-distance (reference) pixel to the smallest positive distance, so that negative indices do
            # not diverge there; its value is then capped to central_value below
            _distance_matrix = np.where(distance_matrix == 0, np.nanmin(positive_distances), distance_matrix)
    try:
        _distance_reference = distance_reference.to(u.cm).value
    except AttributeError:
        # Plain number: assumed to be in the same unit as distance_matrix
        _distance_reference = distance_reference
    profile = _value_at_reference * (_distance_matrix / _distance_reference) ** power_law_index
    # If the routine fills the 0-distance point (the centre), it fixes the profile making the central value the maximum
    if fill_reference_pixel and (power_law_index != 0):
        profile = np.where(profile > central_value, central_value, profile)
    if maximum_radius is not None:
        # Small relative tolerance so that pixels lying exactly on the maximum radius are kept
        profile = np.where(distance_matrix > maximum_radius + maximum_radius / 1e5, 0, profile)
    return profile


def compute_cartesian_coordinate_grid(indices: np.array,
                                      physical_px_size: np.array) -> np.array:
    """
    Convert pixel indices into physical coordinates in cm, assuming a regular grid.
    :param indices: array of pixel indices; when a per-axis pixel size is given, its first axis runs over the grid axes
    :param physical_px_size: either an iterable of per-axis pixel sizes or a single pixel size, as astropy Quantities
    :return: numpy array with the physical coordinates (in cm) of the given indices
    """
    try:
        # One pixel size per grid axis
        px_size_cm = [size.to(u.cm).value for size in physical_px_size]
    except TypeError:
        # A scalar Quantity cannot be iterated: use the single pixel size for every index
        px_size_cm = physical_px_size.to(u.cm).value
    scaled = indices.T * px_size_cm
    return scaled.T


def get_centered_indices(grid_metadata: dict) -> np.array:
    """
    Build the pixel index grid and shift it so that the reference pixel sits at index 0 on every axis.
    :param grid_metadata: grid metadata dictionary; reads 'grid_shape' and 'grid_refpix'
    :return: numpy array of shape (n_axes, *grid_shape) with the indices relative to the reference pixel
    """
    grid_shape = grid_metadata['grid_shape']
    raw_indices = np.indices(grid_shape)
    # Reshape the reference pixel to (n_axes, 1, 1, ...) so it broadcasts along the first (axis) dimension
    reference = np.reshape(grid_metadata['grid_refpix'], (-1,) + (1,) * len(grid_shape))
    return raw_indices - reference


def get_grid_edges(grid_metadata: dict) -> np.array:
    """
    Compute the physical coordinates (cm) of the pixel boundaries along each axis of a regular grid.
    :param grid_metadata: grid metadata dictionary; reads 'grid_shape', 'grid_refpix' and 'physical_px_size'
    :return: numpy array with the edge coordinates per axis, transposed; when all axes have the same number of pixels
        its shape is (n_pixels + 1, n_axes)
    """
    edges = []
    for axis, n_pixels in enumerate(grid_metadata['grid_shape']):
        # n pixels have n + 1 edges, half a pixel away from the reference-centred pixel indices
        edge_indices = np.arange(n_pixels + 1) - grid_metadata['grid_refpix'][axis] - 0.5
        # valid for regular grid
        edges.append(compute_cartesian_coordinate_grid(indices=edge_indices,
                                                       physical_px_size=grid_metadata['physical_px_size'][axis]))
    # Transposed for consistent flattening
    return np.array(edges).T


def get_distance_matrix(grid_metadata: dict,
                        indices: np.array) -> np.array:
    """
    Compute the Euclidean distance of every grid pixel from the reference pixel, in cm.
    :param grid_metadata: grid metadata dictionary; reads 'grid_shape' and 'physical_px_size' (astropy Quantities)
    :param indices: array of pixel indices relative to the reference pixel, indexed as indices[axis, ...]
    :return: numpy array of shape grid_shape with the distances in cm
    """
    grid_shape = grid_metadata['grid_shape']
    squared_distance = np.zeros(grid_shape)
    for axis, _ in enumerate(grid_shape):
        px_size_cm = grid_metadata['physical_px_size'][axis].to(u.cm).value
        squared_distance += (indices[axis, ...] * px_size_cm) ** 2
    return np.sqrt(squared_distance)


def get_physical_px_size(grid_metadata: dict) -> List[u.Quantity]:
    """
    Compute the physical size of a pixel along each grid axis.
    :param grid_metadata: grid metadata dictionary; reads 'grid_shape', 'grid_size' and 'grid_size_units'
    :return: list with one astropy Quantity in cm per axis, equal to the axis extent divided by its number of pixels
    """
    return [(grid_metadata['grid_size'][axis] * u.Unit(grid_metadata['grid_size_units'][axis])).to(u.cm)
            / grid_metadata['grid_shape'][axis]
            for axis in range(len(grid_metadata['grid_shape']))]


def compute_analytic_profile(central_value: float,
                             power_law_index: float,
                             maximum_radius: Union[float, None],
                             value_at_reference: Union[float, None],
                             distance_reference: Union[float, u.Quantity],
                             config_file_path: str = os.path.join('stg', 'config', 'config.yml')) -> np.array:
    """
    Compute the analytic radial profile based on a power-law model, on the grid described by the configuration file.

    :param central_value: Central value of the profile, where the power-law could be undefined.
    :param power_law_index: Power-law index determining the shape of the profile.
    :param maximum_radius: Maximum radius within which to compute the profile, in cm (the unit of the grid distance
        matrix); values outside this radius are set to zero. None disables the cut.
    :param value_at_reference: Value of the profile at the reference distance; None falls back to central_value.
    :param distance_reference: Reference distance at which the profile value is specified (cm, or an astropy Quantity).
    :param config_file_path: path of the configuration file holding the grid definition; defaults to the previously
        hard-coded stg/config/config.yml.

    :return: np.array: Computed radial profile.

    Example:
        # Compute an analytic profile with given parameters
        profile = compute_analytic_profile(central_value=10.0,
                                            power_law_index=-1.5,
                                            maximum_radius=5.0,
                                            value_at_reference=5.0,
                                            distance_reference=1.0)
    """
    config = load_config_file(config_file_path=config_file_path)
    grid_metadata = extract_grid_metadata(config=config)
    # extract_grid_metadata already computes the distance matrix from the centred indices: reuse it instead of
    # recomputing the same array
    return compute_power_law_radial_profile(central_value=central_value,
                                            power_law_index=power_law_index,
                                            distance_matrix=grid_metadata['distance_matrix'],
                                            maximum_radius=maximum_radius,
                                            value_at_reference=value_at_reference,
                                            distance_reference=distance_reference)


def compute_los_average_weighted_profile(profile: np.array,
                                         weights: Union[float, np.array],
                                         axis: int = 2) -> np.array:
    """
    Compute the line-of-sight (LOS) average weighted profile.
    Along the given axis, the result is sum(profile * weights) / sum(weights). Both sums run over the valid samples only.

    Note:
    - If a profile value is 0 or NaN, it is treated as missing data. It is excluded from both the numerator and the
      normalisation, so it does not bias the average.
    - If a weight is 0, the corresponding profile is effectively excluded from the computation.
    - A LOS without any valid sample yields NaN (0 / 0).

    :param profile: (np.array) Array containing the profiles along different lines of sight (LOS).
    :param weights: (Union[float, np.array]) Array containing the weights corresponding to each profile in the input
        array, or a single weight value if uniform weighting is desired.
    :param axis: (int) axis along which the average is computed; defaults to 2, the LOS axis of the 3D grids.

    :return: (np.array) LOS average weighted profile.
    """
    missing = (profile == 0) | np.isnan(profile)
    weighted_sum = np.nansum(np.where(missing, np.nan, profile * weights), axis=axis)
    weight_sum = np.nansum(np.where(missing, np.nan, weights), axis=axis)
    return weighted_sum / weight_sum


def compute_density_and_temperature_avg_profiles(density_profile: np.array,
                                                 temperature_profile: np.array,
                                                 molecular_abundance: float,
                                                 threshold: float,
                                                 abundance_jump: Union[float, int],
                                                 logger: Union[logging.Logger, None] = None,
                                                 config_file_path: str = os.path.join('stg', 'config',
                                                                                      'config.yml')) -> tuple:
    """
    Compute the average density and temperature profiles along the line-of-sight, along with related quantities.

    The function first computes the molecular number density grid from the density, temperature, abundance,
    temperature threshold and abundance jump. The jump simulates hot core-like abundance profiles. It then averages
    the density and temperature along the LOS (axis 2), using the molecular number density as weights. Finally, it
    derives the molecular column density and the standard deviation of the molecular number density along the LOS.
    The total gas mass is logged at debug level.

    :param density_profile (np.array): Array containing the gas number density profile, in cm^-3.
    :param temperature_profile (np.array): Array containing the temperature profile.
    :param molecular_abundance (float): Molecular abundance for computing molecular number density.
    :param threshold (float): Threshold for adjusting the abundance in hot cores, where the molecule is potentially
        evaporated from dust grains.
    :param abundance_jump (Union[float, int]): Abundance jump in hot cores.
    :param logger (Union[logging.Logger, None], optional): Logger object for logging messages. If not specified, creates
        a generic logger for reporting the total mass.
    :param config_file_path (str, optional): path of the configuration file holding the grid definition; defaults to
        the previously hard-coded stg/config/config.yml.

    :return: Tuple[np.array, np.array, np.array, np.array]
        - Average density profile.
        - Average temperature profile.
        - Molecular column density (sum along axis 2 times the pixel size along axis 2, in cm).
        - Standard deviation of molecular number density along axis 2.
    """
    # The proton mass was used below without ever being imported, so every call failed with a NameError
    from astropy.constants import m_p

    _logger = validate_parameter(logger, setup_logger('GENERIC'))
    config = load_config_file(config_file_path=config_file_path)
    grid_metadata = extract_grid_metadata(config=config)
    molecule_grid = compute_molecular_number_density_hot_core(gas_number_density_profile=density_profile,
                                                              abundance=molecular_abundance,
                                                              temperature_profile=temperature_profile,
                                                              threshold=threshold,
                                                              abundance_jump=abundance_jump)
    avg_density_profile = compute_los_average_weighted_profile(profile=density_profile,
                                                               weights=molecule_grid)
    physical_px_size = grid_metadata['physical_px_size']
    # Volume of a single (3D) grid cell, as an astropy Quantity
    pixel_volume = physical_px_size[0] * physical_px_size[1] * physical_px_size[2]
    mass = (np.nansum(density_profile * u.cm ** -3 * pixel_volume) * mean_molecular_mass * m_p).to(u.M_sun)
    _logger.debug(f'Total mass: {mass}')
    avg_temperature_profile = compute_los_average_weighted_profile(profile=temperature_profile,
                                                                   weights=molecule_grid)
    return (avg_density_profile,
            avg_temperature_profile,
            np.nansum(molecule_grid, axis=2) * physical_px_size[2].to(u.cm).value,
            np.nanstd(molecule_grid, axis=2))


def compute_molecular_number_density_hot_core(gas_number_density_profile: np.array,
                                              abundance: float,
                                              temperature_profile: np.array,
                                              threshold: float,
                                              abundance_jump: Union[float, int]) -> np.array:
    """
    Compute the molecular number density with a step-function abundance. Where the temperature is not below the
    threshold, the abundance is multiplied by abundance_jump, to simulate evaporation.
    :param gas_number_density_profile: the gas number density profile array
    :param abundance: the base gas-phase abundance of the species
    :param temperature_profile: the temperature profile of the source
    :param threshold: the temperature at which the species evaporates; the jump applies wherever the temperature is
        not strictly below it
    :param abundance_jump: the factor describing the abundance variation with respect to the base level
    :return: the array of molecular number density computed with the step-function abundance
    """
    molecular_density = gas_number_density_profile * abundance
    is_cold = temperature_profile < threshold
    return np.where(is_cold, molecular_density, molecular_density * abundance_jump)


def extract_grid_metadata(config: dict) -> dict:
    """
    Build the enriched grid metadata from the 'grid' section of the configuration.
    It starts from a shallow copy of config['grid'] and adds or overwrites the following keys:
    - 'grid_type' and 'coordinate_system', mapped through radmc_grid_map / radmc_coord_system_map;
    - 'grid_shape', 'grid_refpix', 'grid_size' and 'grid_size_units', parsed with get_grid_properties;
    - 'physical_px_size', 'centered_indices', 'coordinate_grid', 'distance_matrix' and 'grid_edges', derived from the
      above;
    - 'continuum_lambdas', hard-coded to 100 (presumably the number of continuum wavelengths; confirm).
    :param config: configuration dictionary
    :return: a dictionary with the metadata
    """
    grid_config = config['grid']

    metadata = grid_config.copy()
    metadata['grid_type'] = radmc_grid_map[grid_config['grid_type']]
    metadata['coordinate_system'] = radmc_coord_system_map[grid_config['coordinate_system']]
    metadata.update({f'grid_{keyword}': get_grid_properties(grid_config=grid_config, keyword=keyword)
                     for keyword in ('shape', 'refpix', 'size', 'size_units')})

    metadata['physical_px_size'] = get_physical_px_size(metadata)

    centered = get_centered_indices(metadata)
    metadata['centered_indices'] = centered
    metadata['coordinate_grid'] = compute_cartesian_coordinate_grid(indices=centered,
                                                                    physical_px_size=metadata['physical_px_size'])
    metadata['distance_matrix'] = get_distance_matrix(metadata, centered)

    metadata['grid_edges'] = get_grid_edges(metadata)
    metadata['continuum_lambdas'] = 100
    return metadata
