#include "iteration_functions.h"

/* Bilinear 2D interpolants over the (tau, mu) grid and their accelerators.
   Compute_Ak_array and Compute_Bk_array reload Spline_I2d_l from the previous
   order's Ikl matrix; Compute_Ck_array reloads Spline_I2d_r from the previous
   order's Ikr matrix. All four objects are allocated in slab_polarization. */
gsl_spline2d* Spline_I2d_l;
gsl_spline2d* Spline_I2d_r;
gsl_interp_accel* xacc_2d;
gsl_interp_accel* yacc_2d;

/* NOTE(review): not allocated or used in this section; presumably set up and
   used elsewhere in the program -- confirm. */
gsl_spline* Spline_sample_Iu;
gsl_spline* Spline_eval_limbdark;
gsl_spline* Spline_eval_pdeg;

/* Cubic splines in tau of A^{k-1}, B^{k-1} and C^{k-1}; allocated in
   slab_polarization and re-initialised there for every scattering order k. */
gsl_spline* Spline_tau_Ak;
gsl_spline* Spline_tau_Bk;
gsl_spline* Spline_tau_Ck;

/* 1D accelerators. xacc_1d is allocated in slab_polarization; yacc_1d is not
   allocated in this section. */
gsl_interp_accel* xacc_1d;
gsl_interp_accel* yacc_1d;

/* Grid sizes in tau and mu (set to 100 and 20 by slab_polarization). */
int Nstep_tau;
int Nstep_mu;
/* Absolute and relative tolerance passed to the gsl_integration_qags calls in
   this section (set to 1e-5 by slab_polarization). */
double epsilon;
/* Limits of the mu integrals in Compute_*_array: the first and last node
   produced by gauleg(0, 1, ...) in slab_polarization. */
double u_min;
double u_max;
/* Set to 2*tau0 - tau_c in slab_polarization and forwarded to the integrands
   through AngularFunctionParams. */
double tau_s;

/*========================================================================*/

void slab_polarization(int kiter, int seed_location)
{
  Nstep_mu = 20;
  Nstep_tau = 100;

  tau0 = disktau / 2.;

  int ii, jj, kk;

  double* array_mu;
  double* array_tau;
  double* array_weights;
  double* Ak_array;
  double* Bk_array;
  double* Ck_array;

  double error;

  printf("tau0=%4.3f\n", tau0);

  array_mu = malloc(Nstep_mu * sizeof(double));
  array_weights = malloc(Nstep_mu * sizeof(double));
  array_tau = malloc(Nstep_tau * sizeof(double));

  Ak_array = malloc(Nstep_tau * sizeof(double));
  Bk_array = malloc(Nstep_tau * sizeof(double));
  Ck_array = malloc(Nstep_tau * sizeof(double));

  epsilon = 1e-5;

  tau_s = 2 * tau0 - tau_c;

  /*=================================================================*/

  float* roots_u = malloc((Nstep_mu + 1) * sizeof(float));
  float* weights = malloc((Nstep_mu + 1) * sizeof(float));

  gauleg(0, 1, roots_u, weights, Nstep_mu);

  for (ii = 1; ii <= Nstep_mu; ii++)
  {
    array_mu[ii - 1] = roots_u[ii];
    array_weights[ii - 1] = weights[ii];
  }

  u_min = array_mu[0];
  u_max = array_mu[Nstep_mu - 1];

  /*make_linarray(u_min, 1, Nstep_mu, array_mu);*/

  /* exit(1);*/
  /*=================================================================*/

  make_linarray(0, 2. * tau0, Nstep_tau, array_tau);

  Ikl_intensity* ptr_Ikl;
  Ikr_intensity* ptr_Ikr;

  ptr_Ikl = malloc(kiter * sizeof(Ikl_intensity));
  ptr_Ikr = malloc(kiter * sizeof(Ikr_intensity));

  for (ii = 0; ii < kiter; ii++)
  {
    ptr_Ikl[ii].matrix = dmatrix(0, Nstep_tau - 1, 0, Nstep_mu - 1);
    ptr_Ikr[ii].matrix = dmatrix(0, Nstep_tau - 1, 0, Nstep_mu - 1);
  }

  /*=================================================================*/
  /*Set the function for photons undegoing no scattering*/
  /*=================================================================*/

  for (ii = 0; ii < Nstep_tau; ii++)
  {
    for (jj = 0; jj < Nstep_mu; jj++)
    {
      if (seed_location == SEED_CENTER)
      {
        ptr_Ikl[0].matrix[ii][jj] = I0_center(tau0, array_tau[ii], array_mu[jj]);
        ptr_Ikr[0].matrix[ii][jj] = I0_center(tau0, array_tau[ii], array_mu[jj]);
      }
      else if (seed_location == SEED_UNIFORM)
      {
        ptr_Ikl[0].matrix[ii][jj] = I0_uniform(tau0, array_tau[ii], array_mu[jj]);
        ptr_Ikr[0].matrix[ii][jj] = I0_uniform(tau0, array_tau[ii], array_mu[jj]);
      }
      else if (seed_location == SEED_DELTA)
      {
        ptr_Ikl[0].matrix[ii][jj] = I0_delta(array_tau[ii], tau_s, array_mu[jj]);
        ptr_Ikr[0].matrix[ii][jj] = I0_delta(array_tau[ii], tau_s, array_mu[jj]);
        printf("Do not use this configuration for seed photons fot ST85 algorithm\n");
        exit(1);
      }
      else if (seed_location == SEED_BASE)
      {
        ptr_Ikl[0].matrix[ii][jj] = I0_base(array_tau[ii], tau_s, array_mu[jj], U_POSITIVE);
        ptr_Ikr[0].matrix[ii][jj] = I0_base(array_tau[ii], tau_s, array_mu[jj], U_POSITIVE);
        printf("Do not use this configuration for seed photons fot ST85 algorithm\n");
        exit(1);
      }
      else if (seed_location == SEED_EXP)
      {
        ptr_Ikl[0].matrix[ii][jj] = I0_exp(array_tau[ii], tau0, array_mu[jj]);
        ptr_Ikr[0].matrix[ii][jj] = I0_exp(array_tau[ii], tau0, array_mu[jj]);
        printf("Do not use this configuration for seed photons fot ST85 algorithm\n");
        exit(1);
      }

      else
      {
        printf("Wait, do not use this!\n");
        exit(1);
      }
    }
  }

  /*=================================================================*/
  /*Setup for 2D interpolation*/
  /*=================================================================*/

  double* za = malloc(Nstep_tau * Nstep_mu * sizeof(double));
  const gsl_interp2d_type* T_interp2d = gsl_interp2d_bilinear;

  Spline_I2d_l = gsl_spline2d_alloc(T_interp2d, Nstep_tau, Nstep_mu);
  Spline_I2d_r = gsl_spline2d_alloc(T_interp2d, Nstep_tau, Nstep_mu);

  xacc_2d = gsl_interp_accel_alloc();
  yacc_2d = gsl_interp_accel_alloc();

  /*=================================================================*/
  /*Initialize objects for 1D interpolation*/
  /*=================================================================*/

  Spline_tau_Ak = gsl_spline_alloc(gsl_interp_cspline, Nstep_tau);
  Spline_tau_Bk = gsl_spline_alloc(gsl_interp_cspline, Nstep_tau);
  Spline_tau_Ck = gsl_spline_alloc(gsl_interp_cspline, Nstep_tau);

  xacc_1d = gsl_interp_accel_alloc();

  /*=================================================================*/

  AngularFunctionParams* ptr_data = malloc(sizeof(AngularFunctionParams));
  ptr_data->seed_location = seed_location;
  ptr_data->tau_s = tau_s;

  /*=============================================================*/
  /*Il and Ir functions*/
  /*=============================================================*/

  gsl_function I_l;
  I_l.function = &Il_integral;

  gsl_function I_r;
  I_r.function = &Ir_integral;

  /*=============================================================*/

  gsl_integration_workspace* w = gsl_integration_workspace_alloc(1000);

  double params_Ifunct[2];
  double result_Ir;
  double result_Il;

  /*=============================================================*/
  /*Here start the iteration over scattering*/
  /*=============================================================*/

  printf("Running iteration procedure using ST85 algorithm...\n");

  for (kk = 1; kk < kiter; kk++)
  {
    if (kk % 5 == 0)
      printf("iteration k=%d\n", kk);
    /*=============================================*/
    /*Loop over tau*/
    /*=============================================*/

    for (ii = 0; ii < Nstep_tau; ii++)
    {
      ptr_data->tau = array_tau[ii];
      ptr_data->k = kk;

      Compute_Ak_array(ptr_data, ptr_Ikl[kk - 1].matrix, za, array_tau, array_mu, kk, w, Ak_array);

      gsl_spline_init(Spline_tau_Ak, array_tau, Ak_array, Nstep_tau);

      Compute_Bk_array(ptr_data, ptr_Ikl[kk - 1].matrix, za, array_tau, array_mu, kk, w, Bk_array);

      gsl_spline_init(Spline_tau_Bk, array_tau, Bk_array, Nstep_tau);

      Compute_Ck_array(ptr_data, ptr_Ikr[kk - 1].matrix, za, array_tau, array_mu, kk, w, Ck_array);

      gsl_spline_init(Spline_tau_Ck, array_tau, Ck_array, Nstep_tau);

      /*=============================================*/
      /*Loop over mu*/
      /*=============================================*/

      for (jj = 0; jj < Nstep_mu; jj++)
      {
        params_Ifunct[0] = array_tau[ii];
        params_Ifunct[1] = array_mu[jj];

        I_l.params = params_Ifunct;
        I_r.params = params_Ifunct;

        if (array_tau[ii] == 2 * tau0)
        {
          ptr_Ikl[kk].matrix[ii][jj] = 0;
          ptr_Ikr[kk].matrix[ii][jj] = 0;
        }
        else
        {
          gsl_integration_qags(&I_l, array_tau[ii], 2 * tau0, epsilon, epsilon, 1000, w, &result_Il, &error);
          gsl_integration_qags(&I_r, array_tau[ii], 2 * tau0, epsilon, epsilon, 1000, w, &result_Ir, &error);

          if (isnan(result_Il))
          {
            printf("Warning: Il is nan for tau=%lf mu=%lf\n", array_tau[ii], array_mu[jj]);
            result_Il = 0;
          }

          if (isnan(result_Ir))
          {
            printf("Warning: Ir is nan for tau=%lf mu=%lf\n", array_tau[ii], array_mu[jj]);
            result_Ir = 0;
          }

          ptr_Ikl[kk].matrix[ii][jj] = result_Il;
          ptr_Ikr[kk].matrix[ii][jj] = result_Ir;
        }
      }

    } /*Loop over tau*/

  } /*Loop over kiter*/

  /*========================================================================*/
  /*Now sum results*/
  /*========================================================================*/

  Iklr_intensity* ptr_void = 0;

  compute_results(2, kiter, Nstep_tau, Nstep_mu, array_mu, array_weights, ptr_void, ptr_Ikl, ptr_Ikr);

} // End of function

/*==========================================================================*/
/*
 * Fill my_array[0 .. nstep-1] with nstep equally spaced values from a to b
 * inclusive.
 *
 * nstep <= 0 writes nothing; nstep == 1 writes just {a} (previously this
 * divided by zero and stored NaN). For nstep >= 2 the last element is set to
 * b exactly, because a + hstep*(nstep-1) can miss b by a rounding error.
 */
void make_linarray(double a, double b, int nstep, double* my_array)
{
  if (nstep <= 0)
    return;

  if (nstep == 1)
  {
    my_array[0] = a;
    return;
  }

  const double hstep = (b - a) / (nstep - 1);

  for (int ii = 0; ii < nstep - 1; ii++)
    my_array[ii] = a + hstep * ii;

  /* Pin the endpoint so the grid ends exactly at b. */
  my_array[nstep - 1] = b;
}

/*========================================================================*/
/*Distribution of photons with zero scattering*/
/*========================================================================*/

/*
 * Heaviside step used to cut off the zero-scattering intensities.
 * Returns 1.0 when tau_c - tau is strictly positive (tau lies below tau_c)
 * and 0.0 otherwise, including tau == tau_c and NaN inputs.
 */
double theta_funct(double tau, double tau_c)
{
  return (tau_c - tau > 0) ? 1.0 : 0.0;
}

double I0_center(double tau0, double tau, double mu)
{
  double arg_exp, value;

  arg_exp = (tau0 - tau) / mu;

  value = 1 / mu * theta_funct(tau, tau0) * exp(-arg_exp);

  return value;
}

double I0_delta(double tau, double tau_s, double mu)

{
  double value;

  value = 1 / mu * theta_funct(tau, tau_s) * exp(-(tau_s - tau) / mu);

  return value;
}

double I0_uniform(double tau0, double tau, double mu)
{
  double value;

  value = 1 - exp(-(2 * tau0 - tau) / mu);

  return value;
}

double I0_exp(double tau, double tau0, double mu)
{
  double value;
  double tau_fold = 0.2;

  value = tau_fold / (tau_fold - mu) * (exp(-(2 * tau0 - tau) / tau_fold) - exp(-(2 * tau0 - tau) / mu));

  return value;
}

/*
 * Zero-scattering intensity for the SEED_BASE configuration (called with
 * sign_mu = U_POSITIVE from slab_polarization). Reads the globals tau0 and
 * tau_c.
 *
 * sign_mu > 0 :  1 - exp(-(2*tau0 - tau)/mu)                       if tau > tau_s
 *                exp(-(tau_s - tau)/mu) - exp(-(2*tau0 - tau)/mu)  otherwise
 * sign_mu <= 0:  1 - exp(-(tau_c - tau)/mu)                        if tau < tau_c
 *                0                                                 otherwise
 */
double I0_base(double tau, double tau_s, double mu, int sign_mu)
{
  if (sign_mu <= 0)
    return (tau < tau_c) ? 1 - exp(-(tau_c - tau) / mu) : 0;

  const double top_term = exp(-(2 * tau0 - tau) / mu);

  if (tau > tau_s)
    return 1 - top_term;

  return exp(-(tau_s - tau) / mu) - top_term;
}

/*========================================================================*/
/*Functions A^{k-1}, B^{k-1}, C^{k-1}*/
/*========================================================================*/

/*
 * Tabulate A^{k-1} at every point of the tau grid.
 *
 * Il_funct (Nstep_tau x Nstep_mu, indexed [tau][mu]) is loaded once into the
 * global 2D interpolant Spline_I2d_l over (array_tau, array_mu). Then, for
 * each array_tau[ii], Ak_integral is integrated over mu in [u_min, u_max]
 * with gsl_integration_qags and the result is stored in Ak_array[ii].
 *
 * ptr_data : integrand parameters; its tau and k fields are overwritten.
 * za       : scratch buffer of Nstep_tau * Nstep_mu doubles used to stage
 *            the 2D grid.
 * k        : scattering order, copied into ptr_data->k.
 * w        : GSL workspace; must hold at least 1000 intervals (the limit
 *            passed to qags).
 * Ak_array : output, Nstep_tau values.
 */
void Compute_Ak_array(AngularFunctionParams* ptr_data,
                      double** Il_funct,
                      double* za,
                      double* array_tau,
                      double* array_mu,
                      int k,
                      gsl_integration_workspace* w,
                      double* Ak_array)
{
  int ii, tt, qq;
  double result, error;

  gsl_function F_Ak;
  F_Ak.function = &Ak_integral;
  F_Ak.params = ptr_data;

  /* The interpolated grid does not depend on the output tau, so build it
     once instead of once per grid point. */
  for (tt = 0; tt < Nstep_tau; tt++)
  {
    for (qq = 0; qq < Nstep_mu; qq++)
    {
      gsl_spline2d_set(Spline_I2d_l, za, tt, qq, Il_funct[tt][qq]);
    }
  }

  gsl_spline2d_init(Spline_I2d_l, array_tau, array_mu, za, Nstep_tau, Nstep_mu);

  /*==================================*/
  /*Loop over tau*/
  /*==================================*/

  for (ii = 0; ii < Nstep_tau; ii++)
  {
    ptr_data->tau = array_tau[ii];
    ptr_data->k = k;

    gsl_integration_qags(&F_Ak, u_min, u_max, epsilon, epsilon, 1000, w, &result, &error);
    Ak_array[ii] = result;
  }
}

/*=================================================================================*/

/*
 * Tabulate B^{k-1} at every point of the tau grid.
 *
 * Il_funct (Nstep_tau x Nstep_mu, indexed [tau][mu]) is loaded once into the
 * global 2D interpolant Spline_I2d_l over (array_tau, array_mu). Then, for
 * each array_tau[ii], Bk_integral is integrated over mu in [u_min, u_max]
 * with gsl_integration_qags and the result is stored in Bk_array[ii].
 *
 * ptr_data : integrand parameters; its tau and k fields are overwritten.
 * za       : scratch buffer of Nstep_tau * Nstep_mu doubles used to stage
 *            the 2D grid.
 * k        : scattering order, copied into ptr_data->k.
 * w        : GSL workspace; must hold at least 1000 intervals (the limit
 *            passed to qags).
 * Bk_array : output, Nstep_tau values.
 */
void Compute_Bk_array(AngularFunctionParams* ptr_data,
                      double** Il_funct,
                      double* za,
                      double* array_tau,
                      double* array_mu,
                      int k,
                      gsl_integration_workspace* w,
                      double* Bk_array)
{
  int ii, tt, qq;
  double result, error;

  gsl_function F_Bk;
  F_Bk.function = &Bk_integral;
  F_Bk.params = ptr_data;

  /* The interpolated grid does not depend on the output tau, so build it
     once instead of once per grid point. */
  for (tt = 0; tt < Nstep_tau; tt++)
  {
    for (qq = 0; qq < Nstep_mu; qq++)
    {
      gsl_spline2d_set(Spline_I2d_l, za, tt, qq, Il_funct[tt][qq]);
    }
  }

  gsl_spline2d_init(Spline_I2d_l, array_tau, array_mu, za, Nstep_tau, Nstep_mu);

  /*==================================*/
  /*Loop over tau*/
  /*==================================*/

  for (ii = 0; ii < Nstep_tau; ii++)
  {
    ptr_data->tau = array_tau[ii];
    ptr_data->k = k;

    gsl_integration_qags(&F_Bk, u_min, u_max, epsilon, epsilon, 1000, w, &result, &error);
    Bk_array[ii] = result;
  }
}

/*=========================================================================*/

/*
 * Tabulate C^{k-1} at every point of the tau grid.
 *
 * Ir_funct (Nstep_tau x Nstep_mu, indexed [tau][mu]) is loaded once into the
 * global 2D interpolant Spline_I2d_r over (array_tau, array_mu). Then, for
 * each array_tau[ii], Ck_integral is integrated over mu in [u_min, u_max]
 * with gsl_integration_qags and the result is stored in Ck_array[ii].
 *
 * ptr_data : integrand parameters; its tau and k fields are overwritten.
 * za       : scratch buffer of Nstep_tau * Nstep_mu doubles used to stage
 *            the 2D grid.
 * k        : scattering order, copied into ptr_data->k.
 * w        : GSL workspace; must hold at least 1000 intervals (the limit
 *            passed to qags).
 * Ck_array : output, Nstep_tau values.
 */
void Compute_Ck_array(AngularFunctionParams* ptr_data,
                      double** Ir_funct,
                      double* za,
                      double* array_tau,
                      double* array_mu,
                      int k,
                      gsl_integration_workspace* w,
                      double* Ck_array)
{
  int ii, tt, qq;
  double result, error;

  gsl_function F_Ck;
  F_Ck.function = &Ck_integral;
  F_Ck.params = ptr_data;

  /* The interpolated grid does not depend on the output tau, so build it
     once instead of once per grid point. */
  for (tt = 0; tt < Nstep_tau; tt++)
  {
    for (qq = 0; qq < Nstep_mu; qq++)
    {
      gsl_spline2d_set(Spline_I2d_r, za, tt, qq, Ir_funct[tt][qq]);
    }
  }

  gsl_spline2d_init(Spline_I2d_r, array_tau, array_mu, za, Nstep_tau, Nstep_mu);

  /*==================================*/
  /*Loop over tau*/
  /*==================================*/

  for (ii = 0; ii < Nstep_tau; ii++)
  {
    ptr_data->tau = array_tau[ii];
    ptr_data->k = k;

    gsl_integration_qags(&F_Ck, u_min, u_max, epsilon, epsilon, 1000, w, &result, &error);
    Ck_array[ii] = result;
  }
}
