import numpy as np
from unittest import TestCase
from prs.prs_density_inference import (get_probability_density_threshold,
                                       get_probability_distribution,
                                       get_results,
                                       get_hpd_interval)


class Test(TestCase):
    """Unit tests for the helpers in ``prs.prs_density_inference``.

    Every test builds a small hand-written discrete probability density,
    written in percent and divided by 100, sampled on an 8-point grid.
    """

    # Interpolation resolution passed to get_hpd_interval / get_results.
    INTERP_POINTS = 1000

    def _assert_close(self, expected, computed, rtol=1e-2):
        """Assert that ``computed`` matches ``expected`` element-wise.

        Unlike ``assertTrue(np.allclose(...))`` this prints both arrays on
        failure and does not broadcast. A result with the wrong number of
        endpoints therefore fails instead of passing by accident.
        """
        # atol mirrors np.allclose's default so that endpoints expected to be
        # exactly 0 still get a small absolute tolerance (rtol * 0 == 0).
        np.testing.assert_allclose(computed, expected, rtol=rtol, atol=1e-8)

    def _check_hpd_interval(self, density_percent, hpd_threshold, expected_hpd_interval):
        """Run get_hpd_interval on the grid 0..7 and compare to the expectation.

        ``density_percent`` is the density at each grid point in percent.
        ``expected_hpd_interval`` is the flat list of endpoints the call
        should return.
        """
        hpd_interval = get_hpd_interval(x_array=np.arange(8),
                                        probability_density=np.array(density_percent) / 100.,
                                        hpd_threshold=hpd_threshold,
                                        interp_points=self.INTERP_POINTS)
        self._assert_close(expected_hpd_interval, hpd_interval)

    def test_get_probability_density_threshold(self):
        x_array = np.arange(8)
        probability_density = np.array([0, 5, 15, 30, 30, 15, 5, 0]) / 100.
        expected_result = 0.025
        probability, centered_probability_density, ordered_idxs = get_probability_distribution(
            probability_density=probability_density,
            x_array=x_array)
        # Both arrays are reordered with the indices returned by
        # get_probability_distribution before the threshold search.
        computed_threshold = get_probability_density_threshold(
            ordered_probability=probability[ordered_idxs],
            ordered_probability_density=centered_probability_density[ordered_idxs],
            probability_threshold=0.05
        )

        self.assertAlmostEqual(expected_result, computed_threshold, places=5)

    def test_get_hpd_interval_last(self):
        # The expected region extends to the last grid point (x = 7).
        self._check_hpd_interval(density_percent=[0, 5, 15, 30, 10, 15, 15, 15],
                                 hpd_threshold=0.05,
                                 expected_hpd_interval=[1, 7])

    def test_get_hpd_interval_first(self):
        # Mirror of the previous case: the region starts at the first grid point (x = 0).
        self._check_hpd_interval(density_percent=[15, 15, 15, 10, 30, 15, 5, 0],
                                 hpd_threshold=0.05,
                                 expected_hpd_interval=[0, 6])

    def test_get_hpd_interval_double(self):
        # Four endpoints expected: two disjoint regions, flattened as
        # consecutive (start, end) pairs.
        self._check_hpd_interval(density_percent=[10, 20, 15, 10, 30, 15, 5, 0],
                                 hpd_threshold=0.15,
                                 expected_hpd_interval=[0.5, 2, 3.25, 5])

    def test_get_hpd_interval_edge_case_below(self):
        # The second expected region collapses to the single point x = 5.
        self._check_hpd_interval(density_percent=[30, 20, 15, 10, 10, 15, 5, 0],
                                 hpd_threshold=0.15,
                                 expected_hpd_interval=[0, 2, 5, 5])

    def test_get_hpd_interval_edge_case_constant(self):
        # The density has plateaus (x = 2..3 and x = 4..5); a single region is expected.
        self._check_hpd_interval(density_percent=[30, 20, 15, 15, 10, 10, 5, 0],
                                 hpd_threshold=0.1,
                                 expected_hpd_interval=[0, 5])

    def test_get_hpd_interval_edge_case_constant2(self):
        # Plateaus at x = 0..1 and x = 4..5 on either side of the peak; one region expected.
        self._check_hpd_interval(density_percent=[10, 10, 35, 25, 10, 10, 5, 0],
                                 hpd_threshold=0.1,
                                 expected_hpd_interval=[0, 5])

    def test_get_hpd_interval_edge_case_above(self):
        # The dip at x = 3..4 (10 then 11) does not split the expected region.
        self._check_hpd_interval(density_percent=[30, 20, 15, 10, 11, 15, 5, 1],
                                 hpd_threshold=0.1,
                                 expected_hpd_interval=[0, 5.5])

    def test_get_results(self):
        x_array = np.arange(8)
        probability_density = np.array([0, 5, 15, 30, 30, 15, 5, 0]) / 100.
        computed_threshold, best_fit, hpd_interval = get_results(x_array=x_array,
                                                                 probability_density=probability_density,
                                                                 probability_threshold=0.05,
                                                                 interp_points=self.INTERP_POINTS)
        expected_threshold = 0.025
        # The density peaks at both x = 3 and x = 4; the first one is expected.
        expected_best_fit = 3
        expected_hpd_interval = [0.5, 6.5]
        self.assertAlmostEqual(expected_threshold, computed_threshold, places=5)
        self.assertEqual(expected_best_fit, best_fit)
        self._assert_close(expected_hpd_interval, hpd_interval)

    def test_get_results_asymmetric(self):
        # Non-uniform (powers of two) grid.
        x_array = np.array([1, 2, 4, 8, 16, 32, 64, 128])
        probability_density = np.array([0, 5, 15, 34, 40, 5, 1, 0]) / 100.
        computed_threshold, best_fit, hpd_interval = get_results(x_array=x_array,
                                                                 probability_density=probability_density,
                                                                 probability_threshold=0.05,
                                                                 interp_points=self.INTERP_POINTS)
        expected_threshold = 0.02557788944723618
        # NOTE(review): the sampled density peaks at x = 16, yet x = 8 is
        # expected; presumably get_results accounts for the uneven grid
        # spacing when picking the best fit. Confirm against its implementation.
        expected_best_fit = 8
        expected_hpd_interval = [1.5115, 51.5395]
        self.assertAlmostEqual(expected_threshold, computed_threshold, places=5)
        self.assertEqual(expected_best_fit, best_fit)
        self._assert_close(expected_hpd_interval, hpd_interval, rtol=1e-4)
