import os
import numpy as np
from stg.stg_radmc_input_generator import (get_solid_body_rotation_y,
                                           get_grid_name)
from assets.commons.parsing import read_abundance_variation_schema
from astropy import units as u
from unittest import TestCase
from assets.commons import load_config_file
from assets.commons.grid_utils import (compute_molecular_number_density_hot_core,
                                       extract_grid_metadata,
                                       compute_power_law_radial_profile)


def create_test_config(config_dict: dict,
                       config_filename: str,
                       section: str = 'grid'):
    """Write a minimal YAML-style configuration file for use in tests.

    The file contains a single ``section:`` header line followed by one
    ``    key: value`` line per entry of ``config_dict``, in iteration order.
    Values are rendered with ``str()`` (via an f-string). Nested dicts are
    therefore written in Python repr form, e.g. ``{'size': 3}``.
    NOTE(review): this relies on the YAML loader accepting that as a flow
    mapping; confirm for non-trivial values.

    :param config_dict: mapping of configuration keys to values.
    :param config_filename: path of the file to create or overwrite.
    :param section: top-level key the entries are nested under
        (defaults to ``'grid'``, the only section the original helper wrote).
    """
    lines = [f'{section}:\n']
    lines.extend(f'    {key}: {value}\n' for key, value in config_dict.items())
    # Explicit encoding so the file contents do not depend on the platform locale.
    with open(config_filename, 'w', encoding='utf-8') as config_file:
        config_file.writelines(lines)


class Test(TestCase):
    """Unit tests for grid utilities, abundance-schema parsing and RADMC input helpers."""

    def setUp(self):
        # Base name of the configuration file written by tests that need one.
        self.config_filename = 'config.yml'

    def test_compute_power_law_radial_profile(self):
        """A zero power-law index is expected to give a flat profile equal to ``value_at_reference``."""
        profile = compute_power_law_radial_profile(central_value=2000.0,
                                                   power_law_index=0,
                                                   distance_matrix=np.array([
                                                       [np.sqrt(2), 1, np.sqrt(2)],
                                                       [1, 0, 1],
                                                       [np.sqrt(2), 1, np.sqrt(2)],
                                                   ]),
                                                   maximum_radius=np.sqrt(2),
                                                   value_at_reference=15.0,
                                                   distance_reference=3e18)
        expected_result = np.full((3, 3), 15.0)
        np.testing.assert_allclose(profile, expected_result, rtol=1.0e-5)

    def test_read_abundance_variation_schema(self):
        """Species without a ``hot_core_specs`` entry are expected to get threshold 20 and jump 1."""
        line_config = {
            'species_to_include': ['e-ch3oh'],
            'molecular_abundances': {
                "e-ch3oh": 1e-9,
                "p-h2": 1,
            },
            'hot_core_specs': {
                "e-ch3oh": {
                    'threshold': 90,
                    'abundance_jump': 100
                }
            },
            'lines_mode': 'lvg',
            'collision_partners': ['p-h2']
        }
        result_dict = read_abundance_variation_schema(line_config=line_config)
        expected_dict = {
            'e-ch3oh': {
                'threshold': 90,
                'abundance_jump': 100
            },
            'p-h2': {
                'threshold': 20,
                'abundance_jump': 1
            },
        }
        self.assertDictEqual(result_dict, expected_dict)

    def test_compute_molecular_number_density_hot_core(self):
        """Cells above the temperature threshold get the abundance multiplied by the jump."""
        gas_number_density_profile = np.array([1, 1, 2, 3])
        temperature_profile = np.array([10, 100, 200, 20])
        molecular_density = compute_molecular_number_density_hot_core(
            gas_number_density_profile=gas_number_density_profile,
            abundance=1e-9,
            temperature_profile=temperature_profile,
            abundance_jump=100,
            threshold=90)
        expected_results = np.array([1e-9, 1e-7, 2e-7, 3e-9])
        # atol=0 is essential here. numpy's default absolute tolerance (1e-8)
        # is larger than the 1e-9-scale expected values. With that default,
        # almost any result would pass for the cold cells.
        np.testing.assert_allclose(molecular_density, expected_results,
                                   rtol=1e-7, atol=0)

    def test_get_grid_filename_composite(self):
        """The composite-grid name is the zip stem plus ``_<quantity>.fits``."""
        grid_name = get_grid_name(method='composite_grid',
                                  zip_filename='abc.def.zip',
                                  quantity_name='h2_density')
        self.assertEqual(grid_name, 'abc.def_h2_density.fits')

    def test_get_grid_filename_missing_info(self):
        """``composite_grid`` requires both ``zip_filename`` and ``quantity_name``."""
        with self.assertRaises(AssertionError):
            _ = get_grid_name(method='composite_grid',
                              quantity_name='h2_density')
        with self.assertRaises(AssertionError):
            _ = get_grid_name(method='composite_grid',
                              zip_filename='abc.def.zip')

    def test_get_grid_filename_undefined_method(self):
        """An unknown grid-naming method raises ``NotImplementedError``."""
        with self.assertRaises(NotImplementedError):
            get_grid_name(method='puzzidilontano')

    def test_get_solid_body_rotation_y(self):
        """Solid-body rotation velocities on a 3x3x3 grid match the expected x and z components."""
        _config_filename = os.path.join('test_files', self.config_filename)
        grid_size = 3
        npix = 3
        config_dict = {
            'grid_type': 'regular',
            'coordinate_system': 'cartesian',
            'central_density': '1e6',
            'density_unit': 'cm^-3',
            'dim1': {"size": grid_size, "size_units": "cm", "shape": npix, "refpix": 1},
            'velocity_field': 'solid',
            'velocity_gradient': 1,
            'velocity_gradient_unit': "cm/s cm",
        }
        corner_x = np.sqrt(2) * np.sin(np.pi / 4)
        corner_z = np.sqrt(2) * np.cos(np.pi / 4)
        expected_result_x = [
            [
                [corner_x, 0, -corner_x],
                [corner_x, 0, -corner_x],
                [corner_x, 0, -corner_x],
            ],
            [
                [1, 0, -1],
                [1, 0, -1],
                [1, 0, -1],
            ],
            [
                [corner_x, 0, -corner_x],
                [corner_x, 0, -corner_x],
                [corner_x, 0, -corner_x],
            ],
        ]
        expected_result_z = [
            [
                [-corner_z, -1.0, -corner_z],
                [-corner_z, -1.0, -corner_z],
                [-corner_z, -1.0, -corner_z],
            ],
            [
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
            ],
            [
                [corner_z, 1.0, corner_z],
                [corner_z, 1.0, corner_z],
                [corner_z, 1.0, corner_z],
            ],
        ]
        # Make sure the output directory exists so the test does not depend on
        # a pre-existing 'test_files' folder in the working directory.
        os.makedirs(os.path.dirname(_config_filename), exist_ok=True)
        create_test_config(config_dict=config_dict,
                           config_filename=_config_filename)
        config = load_config_file(_config_filename)
        grid_metadata = extract_grid_metadata(config=config)
        velocity_x, velocity_z = get_solid_body_rotation_y(grid_metadata=grid_metadata)
        # The expectations come from floating-point trig expressions, so an
        # exact-equality check is fragile. For example, sqrt(2)*sin(pi/4) and
        # sqrt(2)*cos(pi/4) already differ in the last ulp. Compare with a
        # tight tolerance instead; the small atol absorbs round-off near the
        # zero entries.
        np.testing.assert_allclose(velocity_x.value, np.array(expected_result_x),
                                   rtol=1e-12, atol=1e-12)
        np.testing.assert_allclose(velocity_z.value, np.array(expected_result_z),
                                   rtol=1e-12, atol=1e-12)
