import os
from unittest import TestCase
import numpy as np
import pandas as pd
from astropy import units as u
from assets.commons import (load_config_file,
                            validate_parameter,
                            setup_logger,
                            get_moldata)
from assets.commons.db_utils import get_credentials
from assets.commons.grid_utils import (get_grid_edges,
                                       get_physical_px_size,
                                       compute_cartesian_coordinate_grid,
                                       compute_power_law_radial_profile,
                                       get_distance_matrix,
                                       get_centered_indices,
                                       extract_grid_metadata,
                                       compute_los_average_weighted_profile)
from assets.commons.parsing import (get_grid_properties,
                                    parse_grid_overrides)
from assets.commons.training_utils import (compute_and_add_similarity_cols,
                                           split_data)


def create_test_config(config_dict: dict,
                       config_filename: str) -> None:
    """Write ``config_dict`` as the ``grid`` section of a small config file.

    The file starts with a top-level ``grid:`` line, followed by one line per
    entry of ``config_dict``, indented by four spaces and rendered as
    ``key: value`` using f-string formatting of both key and value.

    :param config_dict: entries to write under ``grid``
    :param config_filename: path of the file to write; it is opened in ``'w'``
        mode, so any existing content is replaced
    """
    lines = ['grid:\n']
    lines.extend(f'    {key}: {value}\n' for key, value in config_dict.items())
    # Explicit encoding keeps the written file independent of the platform locale.
    with open(config_filename, 'w', encoding='utf-8') as config_file:
        config_file.writelines(lines)


class TestCommons(TestCase):
    """Tests for the helpers imported from ``assets.commons`` and its submodules.

    Tests that need a configuration file write it to ``test_files/<config_filename>``,
    relative to the current working directory.
    """

    def setUp(self):
        """Set the config file name and create the logger used by the tests."""
        self.config_filename = 'config.yml'
        self.logger = setup_logger(name='TEST_COMMONS')

    @staticmethod
    def _centered_distance_matrix(ndim: int) -> np.ndarray:
        """Return the Euclidean distance to the central pixel of a grid with 3 pixels per axis."""
        offsets = np.indices([3] * ndim) - 1
        return np.sqrt(np.sum(offsets ** 2, axis=0))

    def _write_and_load_grid_config(self, config_dict: dict):
        """Write ``config_dict`` with ``create_test_config`` and load the file back."""
        config_path = os.path.join('test_files', self.config_filename)
        create_test_config(config_dict=config_dict,
                           config_filename=config_path)
        return load_config_file(config_path)

    def _assert_grid_properties(self, config_dict: dict, expected_result: dict):
        """Check ``get_grid_properties`` for every keyword listed in ``expected_result``."""
        config = self._write_and_load_grid_config(config_dict)
        for keyword, expected in expected_result.items():
            with self.subTest(keyword=keyword):
                grid_properties = get_grid_properties(grid_config=config['grid'],
                                                      keyword=keyword)
                self.assertListEqual(grid_properties, expected)

    def _assert_moldata_file_created(self, use_cache: bool):
        """Run ``get_moldata`` from the parent directory and check that the molecule file appears.

        The working directory is restored even when the check fails, so that
        other tests relying on relative paths are not affected.
        """
        species = 'e-ch3oh'
        target_dir = os.path.join('tests', 'test_files')
        expected_path = os.path.join(target_dir, f'molecule_{species}.inp')
        starting_dir = os.getcwd()
        os.chdir('..')
        try:
            try:
                os.remove(expected_path)
            except OSError:
                # Best effort: the file is only removed if a previous run left it behind.
                pass
            get_moldata(species_names=[species],
                        logger=self.logger,
                        path=target_dir,
                        use_cache=use_cache)
            self.assertTrue(os.path.isfile(expected_path))
            os.remove(expected_path)
        finally:
            os.chdir(starting_dir)

    def test_compute_power_law_radial_profile_flat(self):
        """A power-law index of 0 with central value 1 gives a profile of ones."""
        radial_profile = compute_power_law_radial_profile(
            central_value=1,
            power_law_index=0,
            distance_matrix=self._centered_distance_matrix(ndim=2)
        )
        self.assertTrue(np.array_equal(radial_profile, np.ones([3, 3])))

    def test_compute_power_law_radial_profile_pl(self):
        """A power-law index of -1 with central value 1 gives the inverse of the distance."""
        distance_matrix = self._centered_distance_matrix(ndim=2)
        # Replace the zero distance of the central pixel so that 1 / distance is finite.
        distance_matrix[1, 1] = 1.0
        radial_profile = compute_power_law_radial_profile(
            central_value=1,
            power_law_index=-1,
            distance_matrix=distance_matrix
        )
        self.assertTrue(np.array_equal(radial_profile, 1.0 / distance_matrix))

    def test_compute_power_law_radial_profile_pl_with_reference(self):
        """With a reference value and distance, the central pixel equals the maximum of the other pixels."""
        radial_profile = compute_power_law_radial_profile(
            central_value=2,
            value_at_reference=1,
            distance_reference=1,
            power_law_index=-1,
            distance_matrix=self._centered_distance_matrix(ndim=2)
        )
        # Mask the central pixel so that nanmax only considers the surrounding ones.
        off_centre_profile = radial_profile.copy()
        off_centre_profile[1, 1] = np.nan
        self.assertEqual(radial_profile[1, 1], np.nanmax(off_centre_profile))

    def test_get_grid_properties_uniform(self):
        """A config defining only ``dim1`` yields each property repeated three times."""
        config_dict = {
            'dim1': {"size": 0.1, "unit": "pc", "shape": 5, "refpix": 2},
        }
        expected_result = {
            'size': [0.1] * 3,
            'unit': ['pc'] * 3,
            'shape': [5] * 3,
            'refpix': [2] * 3,
        }
        self._assert_grid_properties(config_dict=config_dict,
                                     expected_result=expected_result)

    def test_get_grid_properties(self):
        """A config defining three dimensions yields one property value per dimension."""
        config_dict = {
            'dim1': {"size": 0.1, "unit": "pc", "shape": 5, "refpix": 2},
            'dim2': {"size": 0.2, "unit": "pc", "shape": 10, "refpix": 4.5},
            'dim3': {"size": 0.3, "unit": "pc", "shape": 15, "refpix": 7},
        }
        expected_result = {
            'size': [0.1, 0.2, 0.3],
            'unit': ['pc'] * 3,
            'shape': [5, 10, 15],
            'refpix': [2, 4.5, 7],
        }
        self._assert_grid_properties(config_dict=config_dict,
                                     expected_result=expected_result)

    def test_get_grid_edges(self):
        """A 2x2 grid of 1 cm pixels with refpix 0.5 has edges -1, 0, 1 along each axis."""
        grid_metadata = {
            'grid_shape': [2, 2],
            'physical_px_size': [1 * u.cm, 1 * u.cm],
            'grid_refpix': [0.5, 0.5]
        }
        expected_edges = np.array([-1, 0, 1])
        grid_edges = get_grid_edges(grid_metadata=grid_metadata)
        for axis_idx in range(len(grid_metadata['grid_shape'])):
            with self.subTest(axis=axis_idx):
                self.assertTrue(np.array_equal(expected_edges, grid_edges[:, axis_idx]))

    def test_get_physical_px_size(self):
        """The pixel size is the grid size converted to cm and divided by the number of pixels."""
        grid_size = 0.2
        npix = 5
        config_dict = {
            'grid_type': 'regular',
            'coordinate_system': 'cartesian',
            'central_density': '1e6',
            'density_unit': 'cm^-3',
            'dim1': {"size": grid_size, "size_units": "pc", "shape": npix, "refpix": 2},
        }
        expected_result = [(grid_size * u.pc).to(u.cm) / npix] * 3
        config = self._write_and_load_grid_config(config_dict)
        grid_metadata = extract_grid_metadata(config=config)
        physical_px_size = get_physical_px_size(grid_metadata)
        self.assertListEqual(expected_result, physical_px_size)

    def test_compute_cartesian_coordinate_grid(self):
        """1D indices are scaled by the pixel size expressed in cm."""
        indices = np.array([0, 1, 2])
        physical_px_size = [0.1 * u.pc]
        expected_results = indices * physical_px_size[0].to(u.cm).value
        computed_grid = compute_cartesian_coordinate_grid(indices=indices,
                                                          physical_px_size=physical_px_size)
        self.assertTrue(np.array_equal(expected_results, computed_grid))

    def test_compute_cartesian_coordinate_grid_2d(self):
        """With 1 cm pixels the 2D coordinate grid equals the index grid."""
        indices = np.indices([3, 4])
        computed_grid = compute_cartesian_coordinate_grid(indices=indices,
                                                          physical_px_size=[1 * u.cm, 1 * u.cm])
        self.assertTrue(np.array_equal(indices, computed_grid))

    def test_get_distance_matrix_2d(self):
        """Distances on a 2x2 grid of 1 cm pixels are measured from index (0, 0)."""
        grid_metadata = {
            'grid_shape': [2, 2],
            'physical_px_size': [1 * u.cm, 1 * u.cm]
        }
        expected_results = np.array([[0, 1], [1, np.sqrt(2)]])
        distance_matrix = get_distance_matrix(grid_metadata=grid_metadata,
                                              indices=np.indices(grid_metadata['grid_shape']))
        self.assertTrue(np.array_equal(expected_results, distance_matrix))

    def test_get_distance_matrix(self):
        """Distances on a 2x2x2 grid of 1 cm pixels are measured from index (0, 0, 0)."""
        grid_metadata = {
            'grid_shape': [2, 2, 2],
            'physical_px_size': [1 * u.cm, 1 * u.cm, 1 * u.cm]
        }
        expected_results = np.array([[[0, 1], [1, np.sqrt(2)]],
                                     [[1, np.sqrt(2)], [np.sqrt(2), np.sqrt(3)]]])
        distance_matrix = get_distance_matrix(grid_metadata=grid_metadata,
                                              indices=np.indices(grid_metadata['grid_shape']))
        self.assertTrue(np.array_equal(expected_results, distance_matrix))

    def test_get_distance_matrix_symmetric(self):
        """The distance matrix of a cubic grid is unchanged when its axes are reversed."""
        grid_metadata = {
            'grid_shape': [5, 5, 5],
            'physical_px_size': [1 * u.cm, 1 * u.cm, 1 * u.cm]
        }
        # NOTE(review): an offset of 3 does not centre a 5-pixel axis (the centre index is 2);
        # the same offset is applied to every axis, so the transpose check is unaffected.
        indices = np.indices(grid_metadata['grid_shape']) - 3
        distance_matrix = get_distance_matrix(grid_metadata=grid_metadata,
                                              indices=indices)
        self.assertTrue(np.array_equal(distance_matrix, distance_matrix.T))

    def test_get_centered_indices(self):
        """Indices of a 5x3 grid are expressed relative to the reference pixel (2, 1)."""
        grid_metadata = {
            'grid_shape': [5, 3],
            'grid_refpix': [2, 1]
        }
        expected_indices = np.array([[[-2, -2, -2],
                                      [-1, -1, -1],
                                      [0, 0, 0],
                                      [1, 1, 1],
                                      [2, 2, 2]],
                                     [[-1, 0, 1],
                                      [-1, 0, 1],
                                      [-1, 0, 1],
                                      [-1, 0, 1],
                                      [-1, 0, 1]]])
        indices = get_centered_indices(grid_metadata=grid_metadata)
        self.assertTrue(np.array_equal(expected_indices, indices))

    def test_validate_parameter_none(self):
        """A ``None`` parameter is replaced by the default."""
        default = 10
        result = validate_parameter(param_to_validate=None,
                                    default=default)
        self.assertEqual(default, result)

    def test_validate_parameter(self):
        """A parameter that is not ``None`` is returned instead of the default."""
        parameter = 100
        result = validate_parameter(param_to_validate=parameter,
                                    default=10)
        self.assertEqual(parameter, result)

    def test_parse_grid_overrides_linear(self):
        """A linear grid matches ``np.arange`` over the configured limits and step."""
        limits = [10, 30]
        step = 1
        config = {
            'overrides': {
                'dust_temperature_grid_type': 'linear',
                'dust_temperature_limits': limits,
                'dust_temperature_step': step,
            }
        }
        expected_results = np.arange(limits[0], limits[1], step)
        grid_values = parse_grid_overrides(par_name='dust_temperature',
                                           config=config)
        self.assertTrue(np.array_equal(expected_results, grid_values))

    def test_parse_grid_overrides_log(self):
        """A log grid with step 2 between 10 and 200 doubles at every point."""
        config = {
            'overrides': {
                'dust_temperature_grid_type': 'log',
                'dust_temperature_limits': [10, 200],
                'dust_temperature_step': 2,
            }
        }
        expected_results = np.array([10, 20, 40, 80, 160])
        grid_values = parse_grid_overrides(par_name='dust_temperature',
                                           config=config)
        self.assertTrue(np.array_equal(expected_results, grid_values.astype(int)))

    def test_parse_grid_overrides_log_exponent(self):
        """A log grid with step 0.25 between 10 and 100 advances by a factor ``10 ** 0.25`` per point."""
        exponent = 0.25
        config = {
            'overrides': {
                'dust_temperature_grid_type': 'log',
                'dust_temperature_limits': [10, 100],
                'dust_temperature_step': exponent,
            }
        }
        expected_results = np.array([10 * 10 ** (multiple * exponent) for multiple in range(5)])
        grid_values = parse_grid_overrides(par_name='dust_temperature',
                                           config=config)
        # The tolerance must be passed as a small relative value: a positional 5 would be
        # interpreted as rtol=5 and accept values that are off by several hundred percent.
        self.assertTrue(np.allclose(expected_results, grid_values, rtol=1e-5))

    def test_get_credentials(self):
        """Credentials are read from the test file and exported as environment variables."""
        expected_results = {
            'DB_USER': 'pippo',
            'DB_PASS': 'pluto',
            'DB_HOST': 'localhost',
            'DB_NAME': 'postgres',
            'DB_PORT': 5432,
        }
        credentials = get_credentials(logger=self.logger,
                                      credentials_filename=os.path.join('test_files', 'credentials.yaml'),
                                      set_envs=True)
        self.assertDictEqual(credentials, expected_results)
        for key, value in expected_results.items():
            # Environment variables are always strings, hence the conversion.
            self.assertEqual(str(value), os.getenv(key))

    def test_get_moldata_from_cache(self):
        """``get_moldata`` creates the molecule file when ``use_cache`` is True."""
        self._assert_moldata_file_created(use_cache=True)

    def test_get_moldata_negate_cache(self):
        """``get_moldata`` creates the molecule file when ``use_cache`` is False."""
        self._assert_moldata_file_created(use_cache=False)

    def test_compute_los_average_weighted_profile(self):
        """The weighted average of a profile of ones is one for any weights."""
        radial_profile = compute_power_law_radial_profile(
            central_value=1,
            power_law_index=0,
            distance_matrix=self._centered_distance_matrix(ndim=3)
        )
        rng = np.random.default_rng(seed=3)
        weights = rng.random(radial_profile.shape)
        los_average = compute_los_average_weighted_profile(profile=radial_profile,
                                                           weights=weights)
        self.assertTrue(np.array_equal(np.ones_like(los_average), los_average))

    def test_compute_los_average_weighted_profile_pl(self):
        """With unit weights the average is the mean of 1 / distance along the last axis."""
        distance_matrix = self._centered_distance_matrix(ndim=3)
        radial_profile = compute_power_law_radial_profile(
            central_value=1,
            power_law_index=-1,
            distance_matrix=distance_matrix
        )
        los_average = compute_los_average_weighted_profile(profile=radial_profile,
                                                           weights=np.ones_like(radial_profile))
        # The central pixel has zero distance; the resulting infinity is expected.
        with np.errstate(divide='ignore'):
            mean_inverse_distance = np.mean(1.0 / distance_matrix, axis=2)
        # Non-finite means (the line through the central pixel) are expected to be 1.
        expected_result = np.where(np.isfinite(mean_inverse_distance),
                                   mean_inverse_distance,
                                   1)
        self.assertTrue(np.array_equal(expected_result, los_average))

    def test_compute_los_average_weighted_profile_pl_weights(self):
        """With random weights the average is the normalised weighted sum along the last axis."""
        distance_matrix = self._centered_distance_matrix(ndim=3)
        radial_profile = compute_power_law_radial_profile(
            central_value=1,
            power_law_index=-1,
            distance_matrix=distance_matrix
        )
        rng = np.random.default_rng(seed=3)
        weights = rng.random(radial_profile.shape)
        los_average = compute_los_average_weighted_profile(profile=radial_profile,
                                                           weights=weights)
        # Normalise the weights along the last axis only after the function has been called.
        weights /= np.sum(weights, axis=2, keepdims=True)
        # The central pixel has zero distance; the resulting infinity is expected.
        with np.errstate(divide='ignore'):
            inverse_distance = 1.0 / distance_matrix
        is_singular = ~np.isfinite(inverse_distance)
        expected_radial_profile = np.where(is_singular, 1, inverse_distance)
        self.assertTrue(np.array_equal(radial_profile, expected_radial_profile))
        # NOTE(review): the mask is 3D while the weighted sum is 2D, so np.where broadcasts
        # the sum along the first axis; the comparison below then broadcasts again.
        expected_result = np.where(is_singular,
                                   1, np.sum(radial_profile * weights, axis=2))
        self.assertTrue(np.allclose(expected_result, los_average, rtol=1e-5))


class TestTraining(TestCase):
    """Tests for ``compute_and_add_similarity_cols`` and ``split_data``."""

    def setUp(self):
        """Build a two-feature data set and the per-bin average features used by the tests."""
        feature_names = ['feature_01', 'feature_02']

        self.df_data = pd.DataFrame(
            data=[
                [1, -1],
                [1, 1],
                [-1, -1],
                [-1, 0],
                [2, 1],
                [0, 1],
                [0, 0]
            ],
            columns=feature_names
        )
        self.average_features_per_target_bin = pd.DataFrame(
            data=[
                [1, 1],
                [-1, -1]
            ],
            columns=feature_names
        )

    @staticmethod
    def _similarity_column_names(similarity_bins: int) -> list:
        """Return the names ``sim_00``, ``sim_01``, ... expected for the similarity columns."""
        return [f'sim_{str(idx).zfill(2)}' for idx in range(similarity_bins)]

    def test_compute_and_add_similarity_cols_output_features(self):
        """The output keeps the input columns and appends one ``sim_XX`` column per bin."""
        test_df = compute_and_add_similarity_cols(average_features_per_target_bin=self.average_features_per_target_bin,
                                                  input_df=self.df_data,
                                                  similarity_bins=2)
        self.assertListEqual(
            list(test_df.columns),
            list(self.df_data.columns) + self._similarity_column_names(2)
        )

    def test_compute_and_add_similarity_cols_drop_features_output(self):
        """A column missing from the average features is still present in the output."""
        df_data_with_additional_column = self.df_data.copy()
        df_data_with_additional_column['additional_column'] = 1
        # Use a unittest assertion: a bare ``assert`` is removed when Python runs with -O.
        self.assertIn('additional_column', df_data_with_additional_column.columns)
        test_df = compute_and_add_similarity_cols(average_features_per_target_bin=self.average_features_per_target_bin,
                                                  input_df=df_data_with_additional_column,
                                                  similarity_bins=2)
        self.assertListEqual(
            list(test_df.columns),
            list(df_data_with_additional_column.columns) + self._similarity_column_names(2)
        )

    def test_compute_and_add_similarity_cols(self):
        """The similarity columns equal the cosine similarity with each bin average, with NaN mapped to 0."""
        test_df = compute_and_add_similarity_cols(average_features_per_target_bin=self.average_features_per_target_bin,
                                                  input_df=self.df_data,
                                                  similarity_bins=2)
        # Column vector of row norms and row vector of bin norms, so that their product
        # holds the norm product for every (row, bin) pair.
        data_norms = np.sqrt(np.sum(self.df_data.values ** 2, axis=1, keepdims=True))
        bin_norms = np.sqrt(np.sum(self.average_features_per_target_bin.values ** 2, axis=1))[np.newaxis, :]
        dot_products = np.dot(
            self.df_data,
            self.average_features_per_target_bin.T
        )
        # The all-zero row has zero norm, so 0 / 0 yields NaN here; this is expected.
        with np.errstate(invalid='ignore'):
            expected_result = dot_products / (data_norms * bin_norms)
        self.assertTrue(
            np.array_equal(
                test_df[['sim_00', 'sim_01']].values.round(5),
                np.nan_to_num(expected_result.round(5), nan=0)
            )
        )

    def test_split_data(self):
        """The test and validation sets each contain a single, specific ``nh2`` value.

        The ``nh2`` values are raised to the power of ten before comparison;
        presumably ``split_data`` stores them as log10 — confirm against its implementation.
        """
        merged = pd.DataFrame(
            data=[
                [2.7e4, 1, 1, 2],
                [2.187e6, 2, 2, 4],
                [1e3, 3, 3, 6],
                [1e4, 4, 4, 8],
                [1e5, 5, 5, 10],
                [1e6, 6, 6, 12],
                [1e7, 7, 7, 14]
            ],
            columns=['nh2', 'tdust', 'predictor', 'target'])
        # Only the predictor frames of the test and validation sets are inspected;
        # the full unpacking still checks that six objects are returned.
        x_test, _x_train, x_validation, _y_test, _y_train, _y_validation = split_data(
            merged=merged,
            target_column='target',
            predictor_columns=['nh2', 'tdust', 'predictor']
        )
        self.assertListEqual(list((10**x_test['nh2'].unique()).round(1)), [2.7e4])
        self.assertListEqual(list((10**x_validation['nh2'].unique()).round(1)), [2.187e6])
