#include "iteration_functions.h"
#include <gsl/gsl_interp2d.h>
#include <math.h>

/* Seed (k = 0) intensity profiles; solve_rte picks among them according to */
/* its seed_distribution argument. Each one takes the optical depth tau and */
/* the direction cosine u. */
double I0_center_updown(double tau, double u);
double I0_upward(double tau, double u);
double I0_uniform_updown(double tau, double u);
double I0_exp_upward(double tau, double u);
double I0_exp_downward(double tau, double u);


/* Helpers defined at the bottom of this file: the first copies a tau x mu */
/* table into the flat buffer of a 2D spline, the second maps one */
/* integration step [tau_prev, tau] onto an index of the tau grid. */
void set_spline2d_obj(gsl_spline2d* Spline_I2d, double* za, double** I_funct);
int set_array_idx(double tau_prev, double tau, double* array, int nstep);

/* Half of disktau. Assigned by optimize_taustep and solve_rte; read by */
/* RTE_Equations, I0_center_updown and I0_exp_downward. The tau grid built */
/* in solve_rte is requested over [0, 2 * tau0]. */
double tau0;

/*================================================================================*/

/*
 * Run solve_rte with an increasing number of tau steps until it succeeds.
 *
 * The search starts from 100 steps; every time solve_rte returns non-zero
 * the step count is scaled by 1.2 and the solver is run again. There is no
 * upper bound on the number of attempts.
 *
 * k_iter and seed_distribution are forwarded unchanged to solve_rte.
 */
void optimize_taustep(int k_iter, int seed_distribution)
{
  int failed;

  tau0 = disktau / 2.;
  Nstep_tau = 100;

  do
  {
    failed = solve_rte(k_iter, seed_distribution, Nstep_tau);

    if (failed)
    {
      /* The cast applies to Nstep_tau alone; the scaled product is */
      /* converted back to the type of Nstep_tau when it is stored. */
      Nstep_tau = (int)Nstep_tau * 1.2;
    }
  } while (failed);
}

/*================================================================================*/

int solve_rte(int k_iter, int seed_distribution, int Nstep_tau)
{

  printf("\n\nReceveing Nstep_tau %d\n", Nstep_tau);
  int ii, jj, kk, status;

  Nstep_mu = 30;
  /* Nstep_tau = 100;*/

  tau0 = disktau / 2.;

  // printf("tau0=%4.3f\n", tau0);

  static double y[4];
  void* jac = NULL;
  double step_tau;
  double* array_tau;
  double* array_mu;
  double* array_weights;

  gsl_spline2d* Spline_I2d_l_upstream;
  gsl_spline2d* Spline_I2d_l_downstream;
  gsl_spline2d* Spline_I2d_r_upstream;
  gsl_spline2d* Spline_I2d_r_downstream;

  /*==============================================================================================*/
  float* roots = malloc((Nstep_mu + 1) * sizeof(float));
  float* weights = malloc((Nstep_mu + 1) * sizeof(float));

  array_mu = malloc(Nstep_mu * sizeof(double));
  array_weights = malloc(Nstep_mu * sizeof(double));

  gauleg(0, 1, roots, weights, Nstep_mu);

  for (ii = 1; ii <= Nstep_mu; ii++)
  {
    array_mu[ii - 1] = roots[ii];
    array_weights[ii - 1] = weights[ii];
    // printf("%lf 1 \n", array_mu[ii - 1]);
  }

  array_tau = malloc(Nstep_tau * sizeof(double));
  make_linarray(0, 2. * tau0, Nstep_tau, array_tau);

  Iklr_intensity* ptr_Iklr = malloc(k_iter * sizeof(Iklr_intensity));

  for (ii = 0; ii < k_iter; ii++)
  {
    ptr_Iklr[ii].Il_matrix_upstream = dmatrix(0, Nstep_tau - 1, 0, Nstep_mu - 1);
    ptr_Iklr[ii].Ir_matrix_upstream = dmatrix(0, Nstep_tau - 1, 0, Nstep_mu - 1);

    ptr_Iklr[ii].Il_matrix_downstream = dmatrix(0, Nstep_tau - 1, 0, Nstep_mu - 1);
    ptr_Iklr[ii].Ir_matrix_downstream = dmatrix(0, Nstep_tau - 1, 0, Nstep_mu - 1);
  }

  /*==============================================================================*/

  for (ii = 0; ii < Nstep_tau; ii++)
  {
    for (jj = 0; jj < Nstep_mu; jj++)
    {
      if (seed_distribution == SEED_CENTER)
      {
        ptr_Iklr[0].Il_matrix_upstream[ii][jj] = I0_center_updown(array_tau[ii], array_mu[jj]);
        ptr_Iklr[0].Ir_matrix_upstream[ii][jj] = I0_center_updown(array_tau[ii], array_mu[jj]);

        ptr_Iklr[0].Il_matrix_downstream[ii][jj] = I0_center_updown(array_tau[ii], array_mu[jj]);
        ptr_Iklr[0].Ir_matrix_downstream[ii][jj] = I0_center_updown(array_tau[ii], array_mu[jj]);

        /*if (jj==10)
        {
            printf("tau=%5.4f I=%5.4e\n", array_tau[ii], ptr_Iklr[0].Il_matrix_upstream[ii][jj]);
        }*/
      }
      else if (seed_distribution == SEED_UNIFORM)
      {
        ptr_Iklr[0].Il_matrix_upstream[ii][jj] = I0_uniform_updown(array_tau[ii], array_mu[jj]);
        ptr_Iklr[0].Ir_matrix_upstream[ii][jj] = I0_uniform_updown(array_tau[ii], array_mu[jj]);

        ptr_Iklr[0].Il_matrix_downstream[ii][jj] = I0_uniform_updown(array_tau[ii], array_mu[jj]);
        ptr_Iklr[0].Ir_matrix_downstream[ii][jj] = I0_uniform_updown(array_tau[ii], array_mu[jj]);
      }

      else if (seed_distribution == SEED_BASE)
      {
        ptr_Iklr[0].Il_matrix_upstream[ii][jj] = I0_upward(array_tau[ii], array_mu[jj]);
        ptr_Iklr[0].Ir_matrix_upstream[ii][jj] = I0_upward(array_tau[ii], array_mu[jj]);

        ptr_Iklr[0].Il_matrix_downstream[ii][jj] = 0;
        ptr_Iklr[0].Ir_matrix_downstream[ii][jj] = 0;
      }
      
      else if (seed_distribution == SEED_EXP)
      {
        ptr_Iklr[0].Il_matrix_upstream[ii][jj] = I0_exp_upward(array_tau[ii], array_mu[jj]);
        ptr_Iklr[0].Ir_matrix_upstream[ii][jj] = I0_exp_upward(array_tau[ii], array_mu[jj]);

        ptr_Iklr[0].Il_matrix_downstream[ii][jj] = I0_exp_downward(array_tau[ii], array_mu[jj]);
        ptr_Iklr[0].Ir_matrix_downstream[ii][jj] = I0_exp_downward(array_tau[ii], array_mu[jj]);
        
        
      }
 
      
      else
      {
        printf(
            "Error: please select seed=4 (photons at the base) or seed=2 (uniform distribution)\n");
        exit(1);
      }
    }
  }



  /*=================================================================*/
  /*Setup for 2D interpolation*/
  /*=================================================================*/

  double* Il_upstream_za = malloc(Nstep_tau * Nstep_mu * sizeof(double));
  double* Il_downstream_za = malloc(Nstep_tau * Nstep_mu * sizeof(double));
  double* Ir_upstream_za = malloc(Nstep_tau * Nstep_mu * sizeof(double));
  double* Ir_downstream_za = malloc(Nstep_tau * Nstep_mu * sizeof(double));

  const gsl_interp2d_type* T_interp2d = gsl_interp2d_bilinear;

  Spline_I2d_l_upstream = gsl_spline2d_alloc(T_interp2d, Nstep_tau, Nstep_mu);
  Spline_I2d_l_downstream = gsl_spline2d_alloc(T_interp2d, Nstep_tau, Nstep_mu);

  Spline_I2d_r_upstream = gsl_spline2d_alloc(T_interp2d, Nstep_tau, Nstep_mu);
  Spline_I2d_r_downstream = gsl_spline2d_alloc(T_interp2d, Nstep_tau, Nstep_mu);

  xacc_2d = gsl_interp_accel_alloc();
  yacc_2d = gsl_interp_accel_alloc();

  step_tau = (2 * tau0) / Nstep_tau;


  Iklr_intensity* ptr_data = malloc(sizeof(Iklr_intensity));
  Iklr_intensity* params;

  // const gsl_odeiv_step_type* T = gsl_odeiv_step_rk4;

  /*================================================================================*/
  /*Start the loop over k-iterations*/
  /*================================================================================*/

  double tau, tau_prev;
  // step_tau = 0.01;

  for (kk = 1; kk < k_iter; kk++)
  {
    printf("Iteration k=%d....\n\n", kk);

    /*==========================================================================================*/
    /*Initialize the spline for 2D interpolation for Il(tau,u)^{k-1} and Ir(tau,u)^{k-1}*/
    /*both for upward and downward intensities*/
    /*==========================================================================================*/

    set_spline2d_obj(Spline_I2d_l_upstream, Il_upstream_za, ptr_Iklr[kk - 1].Il_matrix_upstream);
    set_spline2d_obj(Spline_I2d_l_downstream, Il_downstream_za, ptr_Iklr[kk - 1].Il_matrix_downstream);

    set_spline2d_obj(Spline_I2d_r_upstream, Ir_upstream_za, ptr_Iklr[kk - 1].Ir_matrix_upstream);
    set_spline2d_obj(Spline_I2d_r_downstream, Ir_downstream_za, ptr_Iklr[kk - 1].Ir_matrix_downstream);

    gsl_spline2d_init(Spline_I2d_l_upstream, array_tau, array_mu, Il_upstream_za, Nstep_tau, Nstep_mu);
    gsl_spline2d_init(Spline_I2d_l_downstream, array_tau, array_mu, Il_downstream_za, Nstep_tau, Nstep_mu);

    gsl_spline2d_init(Spline_I2d_r_upstream, array_tau, array_mu, Ir_upstream_za, Nstep_tau, Nstep_mu);
    gsl_spline2d_init(Spline_I2d_r_downstream, array_tau, array_mu, Ir_downstream_za, Nstep_tau, Nstep_mu);

    ptr_data->Spline_I2d_l_upstream = Spline_I2d_l_upstream;
    ptr_data->Spline_I2d_l_downstream = Spline_I2d_l_downstream;

    ptr_data->Spline_I2d_r_upstream = Spline_I2d_r_upstream;
    ptr_data->Spline_I2d_r_downstream = Spline_I2d_r_downstream;


    /*=======================================================================*/
    /*Solve the RTE at the quadrature points*/
    /*=======================================================================*/

    status = GSL_SUCCESS + 1;

    //gsl_set_error_handler_off();

    for (jj = 0; jj < Nstep_mu; jj++)
    {
      // printf("\nSolve system for u=%4.3f\n", array_mu[jj]);

      ptr_data->u = array_mu[jj];
      ptr_data->array_u = array_mu;
      ptr_data->weights_u = array_weights;

      params = ptr_data;

      /*======================================*/
      /*Initial boundary conditions*/
      /*======================================*/

      tau = 0;
      tau_prev = 0;
      int idx;

      y[0] = 0;
      y[1] = 0;
      y[2] = 0;
      y[3] = 0;

      gsl_odeiv2_system sys = {RTE_Equations, jac, 4, params};
      gsl_odeiv2_driver* d = gsl_odeiv2_driver_alloc_y_new(&sys, gsl_odeiv2_step_rk8pd, step_tau, 1e-4, 1e-4);

      while (tau < 2 * tau0)
      {
        // printf("tau prima %lf\n", tau);

        status = gsl_odeiv2_driver_apply_fixed_step(d, &tau, step_tau, 1, y);

        if (status)
        {
          printf("Error during integration: %s\n", gsl_strerror(status));
          // printf("Status %d\n", status);
          return (1);
        }

        idx = set_array_idx(tau_prev, tau, array_tau, Nstep_tau);

        if (idx > 0)
        {
          ptr_Iklr[kk].Il_matrix_upstream[idx][jj] = y[0];
          ptr_Iklr[kk].Ir_matrix_upstream[idx][jj] = y[1];

          ptr_Iklr[kk].Il_matrix_downstream[idx][jj] = y[2];
          ptr_Iklr[kk].Ir_matrix_downstream[idx][jj] = y[3];
        }

        tau_prev = tau;
      }

      gsl_odeiv2_driver_free(d);

    } // End of loop over array_mu

  } // End of loop over k_iter

  Ikl_intensity* ptr_a = 0;
  Ikr_intensity* ptr_b = 0;

  compute_results(1, k_iter, Nstep_tau, Nstep_mu, array_mu, array_weights, ptr_Iklr, ptr_a, ptr_b);

  free(Il_upstream_za);
  free(Il_downstream_za);
  free(Ir_upstream_za);
  free(Ir_downstream_za);
  free(array_tau);
  free(ptr_Iklr);

  gsl_spline2d_free(Spline_I2d_l_upstream);
  gsl_spline2d_free(Spline_I2d_l_downstream);
  gsl_spline2d_free(Spline_I2d_r_upstream);
  gsl_spline2d_free(Spline_I2d_r_downstream);


  return 0;
}

/*================================================================================*/
/*Solution of equations (4.206) and (4.207) of Pomraning 1973*/
/*================================================================================*/

int RTE_Equations(double s, const double y[], double f[], void* params)
{
  double* array_u = ((Iklr_intensity*)params)->array_u;
  double* weights_u = (double*)((Iklr_intensity*)params)->weights_u;
  double u = ((Iklr_intensity*)params)->u;

  gsl_spline2d* Spline_I2d_l_upstream = ((Iklr_intensity*)params)->Spline_I2d_l_upstream;
  gsl_spline2d* Spline_I2d_l_downstream = ((Iklr_intensity*)params)->Spline_I2d_l_downstream;

  gsl_spline2d* Spline_I2d_r_upstream = ((Iklr_intensity*)params)->Spline_I2d_r_upstream;
  gsl_spline2d* Spline_I2d_r_downstream = ((Iklr_intensity*)params)->Spline_I2d_r_downstream;


  /*==============================================*/
  /*Equation for I_l  upstream*/
  /*==============================================*/

  //printf("UNO  s=%lf\n");


  f[0] = -1 / u * y[0] +
         3 / (8 * u) *
             (legendre_integration_A(Spline_I2d_l_upstream, s, u, array_u, weights_u) +
              legendre_integration_A(Spline_I2d_l_downstream, 2 * tau0 - s, u, array_u, weights_u) +

              legendre_integration_B(Spline_I2d_r_upstream, s, u, array_u, weights_u) +
              legendre_integration_B(Spline_I2d_r_downstream, 2 * tau0 - s, u, array_u, weights_u));



  //printf("UNO DOPO  s=%10.7f\n");

  /*==============================================*/
  /*Equation for I_r upstream*/
  /*==============================================*/



  f[1] = -1 / u * y[1] +
         3 / (8 * u) *
             (legendre_integration_C(Spline_I2d_l_upstream, s, u, array_u, weights_u) +
              legendre_integration_C(Spline_I2d_l_downstream, 2 * tau0 - s, u, array_u, weights_u) +
              legendre_integration_D(Spline_I2d_r_upstream, s, u, array_u, weights_u) +
              legendre_integration_D(Spline_I2d_r_downstream, 2 * tau0 - s, u, array_u, weights_u));

  /*==============================================*/
  /*Equation for I_l downstream*/
  /*==============================================*/



  f[2] = -1 / u * y[2] +
         3 / (8 * u) *
             (legendre_integration_A(Spline_I2d_l_downstream, s, u, array_u, weights_u) +
              legendre_integration_A(Spline_I2d_l_upstream, 2 * tau0 - s, u, array_u, weights_u) +

              legendre_integration_B(Spline_I2d_r_downstream, s, u, array_u, weights_u) +
              legendre_integration_B(Spline_I2d_r_upstream, 2 * tau0 - s, u, array_u, weights_u));

  /*==============================================*/
  /*Equation for I_r downstream*/
  /*==============================================*/



  f[3] = -1 / u * y[3] +
         3 / (8 * u) *
             (legendre_integration_C(Spline_I2d_l_downstream, s, u, array_u, weights_u) +
              legendre_integration_C(Spline_I2d_l_upstream, 2 * tau0 - s, u, array_u, weights_u) +

              legendre_integration_D(Spline_I2d_r_downstream, s, u, array_u, weights_u) +
              legendre_integration_D(Spline_I2d_r_upstream, 2 * tau0 - s, u, array_u, weights_u));

  // printf("SLURM tau=%lf u=%5.4f f[0]=%5.4e f[1]=%5.4e f[2]=%5.4e f[3]=%5.4e\n", s, u, f[0], f[1], f[2], f[3]);

  return GSL_SUCCESS;
}


/*====================================================================*/
/*Value of the I^k(tau,u) function for k=0 and seed photons*/
/*with exponential distribution for u > 0 and u < 0*/
/*====================================================================*/

double I0_exp_upward(double tau, double u)
{
	double value;
	
	value= ((-1 + exp(((-1 + u) * tau) / u))) / 
           (2.0 * exp(tau) * (-1 + u));
	
	return value;
	
}


/*
 * Seed (k = 0) downward intensity for the exponential seed distribution:
 *
 *   I0(tau, u) = exp(-2 * tau0) * (exp(tau) - exp(-tau / u)) / (2 * (1 + u))
 *
 * tau  optical depth
 * u    direction cosine
 *
 * Reads the file-level tau0.
 */
double I0_exp_downward(double tau, double u)
{
  const double numerator = exp(-2 * tau0) * (exp(tau) - exp(-tau / u));
  const double denominator = 2.0 * (1 + u);

  return numerator / denominator;
}



/*====================================================================*/
/*Value of the I^k(tau,u) function for k=0 and seed photons*/
/*at the base of the slab (tau=0)*/
/*====================================================================*/

double I0_upward(double tau, double u)
{
  double value;

  if (tau == 0)
  {
    value = 0;
  }
  else
  {
    value = 1 / 2. * 1 / u * exp(-tau / u);
  }

  return value;
}

/*====================================================================*/
/*Value of the I^k(tau,u) function for k=0 and seed photons*/
/*uniformly distributed*/
/*====================================================================*/

double I0_uniform_updown(double tau, double u)
{
  double value;

  value = 1 / 2. * (1 - exp(-tau / u));

  return value;
}

/*====================================================================*/
/*Value of the I^k(tau,u) function for k=0 and seed photons*/
/*located at the centre of the slab (tau = tau0)*/
/*====================================================================*/

/*
 * Seed (k = 0) intensity used for the SEED_CENTER case in solve_rte, for
 * both the upstream and the downstream tables.
 *
 * tau  optical depth
 * u    direction cosine
 *
 * Returns exp(-(tau - tau0) / u) / (2 * u) for tau strictly greater than the
 * file-level tau0, and 0 otherwise.
 */
double I0_center_updown(double tau, double u)
{
  return (tau > tau0) ? 1 / 2. * 1 / u * exp(-(tau - tau0) / u) : 0;
}

/*==============================================================================*/
/*
 * Copy the table I_funct into the flat buffer za of a 2D spline.
 *
 * Spline_I2d  spline object whose layout gsl_spline2d_set uses to place
 *             each value in za
 * za          destination buffer
 * I_funct     source table, indexed as I_funct[tau index][mu index]
 *
 * The extents are the file-level Nstep_tau and Nstep_mu, not arguments.
 */
void set_spline2d_obj(gsl_spline2d* Spline_I2d, double* za, double** I_funct)
{
  for (int i_tau = 0; i_tau < Nstep_tau; i_tau++)
  {
    double* row = I_funct[i_tau];

    for (int i_mu = 0; i_mu < Nstep_mu; i_mu++)
    {
      gsl_spline2d_set(Spline_I2d, za, i_tau, i_mu, row[i_mu]);
    }
  }
}

/*
 * Map one integration step [tau_prev, tau] onto an index of the grid `array`.
 *
 * tau_prev  value of tau before the step
 * tau       value of tau after the step
 * array     grid nodes (nstep entries)
 * nstep     number of grid nodes
 *
 * Returns
 *   0          when the step starts exactly on the first node,
 *   nstep - 1  when the step ends exactly on the last node,
 *   the first index k >= 1 with tau_prev <= array[k] <= tau otherwise,
 *   -1         when no node lies inside the step.
 */
int set_array_idx(double tau_prev, double tau, double* array, int nstep)
{
  if (tau_prev == array[0])
  {
    return 0;
  }

  if (tau == array[nstep - 1])
  {
    return nstep - 1;
  }

  for (int node = 1; node < nstep; node++)
  {
    const double grid_value = array[node];

    if (grid_value >= tau_prev && grid_value <= tau)
    {
      return node;
    }
  }

  return -1;
}