#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef USE_MPI
#include <mpi.h>
#ifdef USE_FFTW
#include <fftw3-mpi.h>
#endif
#endif
#include <omp.h>
#include <math.h>
#include <time.h>
#include <unistd.h>
#ifdef ACCOMP
#include "w-stacking_omp.h"
#else
#include "w-stacking.h"
#endif
#define PI 3.14159265359
#define NUM_OF_SECTORS -1
#define MIN(X, Y) (((X) < (Y)) ? (X) : (Y))
#define MAX(X, Y) (((X) > (Y)) ? (X) : (Y))
#define NOVERBOSE
#define NFILES 100

// Linked List set-up
/*
 * One node of a singly linked list of measurement indices.
 *
 * main() keeps one list per sector: Push() prepends nodes, and the list is
 * terminated by a sentinel node whose index is -1 and whose next is NULL.
 */
struct sectorlist
{
     long index;                /* measurement index, or -1 in the sentinel */
     struct sectorlist *next;   /* following node; NULL after the sentinel  */
};

/*
 * Push - prepend a new node holding `data` to the list rooted at *headRef.
 *
 * headRef  address of the list head pointer; updated to point at the new node.
 * data     value stored in the new node's `index` field.
 *
 * The node is heap allocated and owned by the list. If the allocation fails
 * the program prints a message and exits, instead of dereferencing NULL.
 */
void Push(struct sectorlist** headRef, long data) {
     struct sectorlist* newNode = malloc(sizeof *newNode);
     if (newNode == NULL) {
          fprintf(stderr, "Push: out of memory while adding index %ld\n", data);
          exit(EXIT_FAILURE);
     }
     newNode->index = data;
     newNode->next = *headRef;
     *headRef = newNode;
}

// Main code and its private helpers

/* Stop the whole run with exit status `code`; under MPI all ranks are aborted. */
static void abort_run(int code)
{
#ifdef USE_MPI
	MPI_Abort(MPI_COMM_WORLD, code);
#endif
	exit(code);
}

/*
 * Allocate `count` elements of `elem_size` bytes (zeroed when `zero_fill` is
 * non-zero) or abort the run. A zero count still yields a valid pointer so
 * that empty sectors need no special casing.
 */
static void *checked_alloc(long count, size_t elem_size, int zero_fill, const char *what)
{
	if (count < 0) {
		fprintf(stderr, "Negative element count (%ld) requested for %s\n", count, what);
		abort_run(1);
	}
	size_t n = (count > 0) ? (size_t)count : 1;
	void *mem = NULL;
	if (zero_fill)
		mem = calloc(n, elem_size);
	else if (n <= (size_t)-1 / elem_size)
		mem = malloc(n * elem_size);
	if (mem == NULL) {
		fprintf(stderr, "Out of memory allocating %s (%zu elements of %zu bytes)\n", what, n, elem_size);
		abort_run(1);
	}
	return mem;
}

/* fopen() that aborts the run with a message instead of returning NULL. */
static FILE *open_or_die(const char *path, const char *mode)
{
	FILE *fp = fopen(path, mode);
	if (fp == NULL) {
		fprintf(stderr, "Cannot open %s (mode \"%s\")\n", path, mode);
		abort_run(1);
	}
	return fp;
}

/* Read `nbytes` bytes starting at byte `offset` of file `dir``name` into `dst`; abort on a short read. */
static void read_slice(const char *dir, const char *name, long offset, void *dst, size_t nbytes)
{
	char path[1000];
	snprintf(path, sizeof path, "%s%s", dir, name);
	FILE *fp = open_or_die(path, "rb");
	if (fseek(fp, offset, SEEK_SET) != 0 || (nbytes > 0 && fread(dst, nbytes, 1, fp) != 1)) {
		fprintf(stderr, "Cannot read %zu bytes at offset %ld from %s\n", nbytes, offset, path);
		fclose(fp);
		abort_run(1);
	}
	fclose(fp);
}

/* CPU seconds between two clock() samples. */
static double elapsed_cpu(clock_t t0, clock_t t1)
{
	return ((double)(t1 - t0)) / CLOCKS_PER_SEC;
}

/* Wall-clock seconds between two clock_gettime() samples. */
static double elapsed_wall(const struct timespec *t0, const struct timespec *t1)
{
	return (double)(t1->tv_sec - t0->tv_sec) + (t1->tv_nsec - t0->tv_nsec) / 1000000000.0;
}

#if defined(WRITE_DATA) && defined(USE_MPI)
/*
 * Write one sector slab (num_w_planes planes of yaxis rows of xaxis doubles)
 * into its place in a file laid out as [w-plane][global row][column].
 * Rows are contiguous both in memory and in the file, so each row is moved
 * with a single fseek/fwrite pair.
 */
static void write_slab_rows(FILE *fp, const double *slab, long sector,
                            int xaxis, int yaxis, int num_w_planes,
                            int grid_size_x, int grid_size_y)
{
	for (int iw = 0; iw < num_w_planes; iw++)
	{
		for (int iv = 0; iv < yaxis; iv++)
		{
			long file_cell = (iv + sector*yaxis)*xaxis + (long)iw*grid_size_x*grid_size_y;
			long slab_cell = (long)iv*xaxis + (long)iw*xaxis*yaxis;
			fseek(fp, file_cell*(long)sizeof(double), SEEK_SET);
			fwrite(&slab[slab_cell], sizeof(double), (size_t)xaxis, fp);
		}
	}
}
#endif

/*
 * main - driver of the w-stacking gridding run.
 *
 * Usage: <program> number_of_OMP_Threads
 *
 * Steps, in order:
 *   1. parse the thread count, initialise MPI (USE_MPI) and FFTW-MPI (USE_FFTW);
 *   2. read meta.txt and this rank's contiguous share of the u, v, w
 *      coordinates from the hard-coded dataset directory;
 *   3. bin the measurements along v into `nsectors` slabs, recording for each
 *      sector a count (histo_send) and a linked list of measurement indices;
 *      a measurement close to a slab edge is also added to the neighbour slab;
 *   4. for every dataset and every sector, gather the sector's data, call
 *      wstack() and combine the result: MPI_Reduce, MPI_Accumulate (ONE_SIDE)
 *      or a plain copy into gridtot when built without MPI;
 *   5. optionally write the grid (WRITE_DATA), FFT every w-plane and call
 *      phase_correction() (USE_FFTW), optionally write the image (WRITE_IMAGE);
 *   6. print the timings and store them in timings.dat (rank 0 only).
 *
 * Returns 0 on success. Exits with status 1 on bad arguments, I/O errors or
 * allocation failures, and 255 when an ACCOMP build finds no accelerator.
 */
int main(int argc, char * argv[])
{
	int rank;
	int size;

	FILE * pFile;
#ifndef USE_MPI
	FILE * pFile1;
#endif
	char filename[1000];
	// The dataset locations are hard-coded further down (see "INPUT FILES")
	char datapath[900];
	char datapath_multi[NFILES][900];

	const char ufile[] = "ucoord.bin";
	const char vfile[] = "vcoord.bin";
	const char wfile[] = "wcoord.bin";
	const char weightsfile[] = "weights.bin";
	const char visrealfile[] = "visibilities_real.bin";
	const char visimgfile[] = "visibilities_img.bin";
	const char metafile[] = "meta.txt";
	const char outfile1[] = "coords.txt";
	const char outfile2[] = "grid_real.bin";
	const char outfile3[] = "grid_img.bin";
	const char fftfile2[] = "fft_real.bin";
	const char fftfile3[] = "fft_img.bin";
	const char timingfile[] = "timings.dat";

	double * uu;
	double * vv;
	double * ww;
	float * weights;
	float * visreal;
	float * visimg;

	// Variables with the 0 suffix receive the per-dataset metadata re-read in the file loop
	long Nmeasures,Nmeasures0;
	long Nvis,Nvis0;
	long Nweights;
	long freq_per_chan,freq_per_chan0;
	long polarisations,polarisations0;
	long Ntimes,Ntimes0;
	double dt,dt0;
	double thours,thours0;
	long baselines,baselines0;
	double uvmin;
	double uvmax;
	double wmin,wmin0;
	double wmax,wmax0;
	double resolution = 0.0;

	// MESH SIZE
	int grid_size_x = 2048;
	int grid_size_y = 2048;
	int local_grid_size_x;
	int local_grid_size_y;
	int xaxis;
	int yaxis;
	int num_w_planes = 8;

	// DAV: the corresponding KernelLen is calculated within the wstack function. It can be anyway hardcoded for optimization
	int w_support = 7;
	int num_threads;
	double dx = 1.0/(double)grid_size_x;
	double dw = 1.0/(double)num_w_planes;
	// half width of the kernel support, in grid units
	double w_supporth = (double)((w_support-1)/2)*dx;

	// Timers: plain names hold CPU time (clock), names ending in 1 hold wall-clock time.
	// The FFT and phase timers start at zero so that timings.dat is well defined
	// also when the code is built without USE_FFTW.
	clock_t start, end, start0, startk, endk;
	double setup_time, process_time, tot_time;
	double kernel_time = 0.0, reduce_time = 0.0, compose_time = 0.0;
	double fftw_time = 0.0, phase_time = 0.0;
	double setup_time1, process_time1, tot_time1;
	double kernel_time1 = 0.0, reduce_time1 = 0.0, compose_time1 = 0.0;
	double fftw_time1 = 0.0, phase_time1 = 0.0;

	struct timespec begin, finish, begin0, begink, finishk;
	long nsectors;

	/* GT get number of threads, exit if not given */
	if (argc < 2) {
		fprintf(stderr, "Usage: %s number_of_OMP_Threads \n", argv[0]);
		exit(1);
	}
	// Set the number of OpenMP threads; zero, negative or non-numeric values are rejected
	num_threads = atoi(argv[1]);
	if (num_threads <= 0)
	{
		fprintf(stderr, "Wrong parameter: %s\n\n", argv[1]);
		fprintf(stderr, "Usage: %s number_of_OMP_Threads \n", argv[0]);
		exit(1);
	}

	clock_gettime(CLOCK_MONOTONIC, &begin0);
	start0 = clock();

	// Initialize MPI environment
#ifdef USE_MPI
	MPI_Init(&argc,&argv);
	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
	MPI_Comm_size(MPI_COMM_WORLD, &size);
	if(rank == 0)printf("Running with %d MPI tasks\n",size);
#ifdef USE_FFTW
	fftw_mpi_init();
#endif
#else
	rank = 0;
	size = 1;
#endif

	if(rank == 0)printf("Running with %d threads\n",num_threads);

#ifdef ACCOMP
	if (rank == 0)
	{
		if (0 == omp_get_num_devices()) {
			printf("No accelerator found ... exit\n");
			abort_run(255);
		}
		printf("Number of available GPUs %d\n", omp_get_num_devices());
#ifdef NVIDIA
		prtAccelInfo();
#endif
	}
#endif

	// set the local size of the image: the grid is cut in nsectors slabs along y
	local_grid_size_x = grid_size_x;
	nsectors = NUM_OF_SECTORS;
	if (nsectors < 0) nsectors = size;
	local_grid_size_y = grid_size_y/nsectors;

	// LOCAL grid size
	xaxis = local_grid_size_x;
	yaxis = local_grid_size_y;

	clock_gettime(CLOCK_MONOTONIC, &begin);
	start = clock();

	// INPUT FILES (only the first ndatasets entries are used)
	int ndatasets = 1;
	strcpy(datapath_multi[0],"/m100_scratch/userexternal/cgheller/gridding/Lofar/L798046_SB244_uv.uncorr_130B27932t_146MHz.pre-cal.binMS/");

	// The coordinates are taken from the first dataset only
	strcpy(datapath,datapath_multi[0]);

	// Read metadata
	snprintf(filename, sizeof filename, "%s%s", datapath, metafile);
	pFile = open_or_die(filename,"r");
	if (fscanf(pFile, "%ld %ld %ld %ld %ld %lf %lf %ld %lf %lf %lf %lf",
	           &Nmeasures, &Nvis, &freq_per_chan, &polarisations, &Ntimes,
	           &dt, &thours, &baselines, &uvmin, &uvmax, &wmin, &wmax) != 12)
	{
		fprintf(stderr, "Malformed metadata file %s\n", filename);
		fclose(pFile);
		abort_run(1);
	}
	fclose(pFile);

	// WATCH THIS: the last nsub measurements of the dataset are deliberately dropped
	int nsub = 1000;
	printf("Subtracting last %d measurements\n",nsub);
	Nmeasures = Nmeasures-nsub;
	if (Nmeasures <= 0)
	{
		fprintf(stderr, "No measurements left after dropping the last %d\n", nsub);
		abort_run(1);
	}
	Nvis = Nmeasures*freq_per_chan*polarisations;

	if (rank == 0)
	{
		printf("N. measurements %ld\n",Nmeasures);
		printf("N. visibilities %ld\n",Nvis);
	}

	// Split the measurements in contiguous chunks, one per rank; the last rank takes the remainder
	long nm_pe = Nmeasures/size;
	long remaining = Nmeasures%size;
	long startrow = rank*nm_pe;
	if (rank == size-1) nm_pe = nm_pe+remaining;

	Nmeasures = nm_pe;
	Nvis = Nmeasures*freq_per_chan*polarisations;
	Nweights = Nmeasures*polarisations;

#ifdef VERBOSE
	printf("N. measurements on %d %ld\n",rank,Nmeasures);
	printf("N. visibilities on %d %ld\n",rank,Nvis);
#endif

	// DAV: all these arrays can be allocated statically for the sake of optimization. However be careful that if MPI is used
	// all the sizes are rescaled by the number of MPI tasks
	uu = checked_alloc(Nmeasures, sizeof *uu, 1, "u coordinates");
	vv = checked_alloc(Nmeasures, sizeof *vv, 1, "v coordinates");
	ww = checked_alloc(Nmeasures, sizeof *ww, 1, "w coordinates");
	weights = checked_alloc(Nweights, sizeof *weights, 1, "weights");
	visreal = checked_alloc(Nvis, sizeof *visreal, 1, "real visibilities");
	visimg = checked_alloc(Nvis, sizeof *visimg, 1, "imaginary visibilities");

	if(rank == 0)printf("READING DATA\n");
	// Read this rank's share of the coordinates
	long coord_offset = startrow*(long)sizeof(double);
	size_t coord_bytes = (size_t)Nmeasures*sizeof(double);
	read_slice(datapath, ufile, coord_offset, uu, coord_bytes);
	read_slice(datapath, vfile, coord_offset, vv, coord_bytes);
	read_slice(datapath, wfile, coord_offset, ww, coord_bytes);

#ifdef USE_MPI
	MPI_Barrier(MPI_COMM_WORLD);
#endif

	clock_gettime(CLOCK_MONOTONIC, &finish);
	end = clock();
	setup_time = elapsed_cpu(start, end);
	setup_time1 = elapsed_wall(&begin, &finish);

	if(rank == 0)printf("GRIDDING DATA\n");

	// Create histograms and linked lists
	clock_gettime(CLOCK_MONOTONIC, &begin);
	start = clock();

	// One list per sector plus one spare slot; every list starts as a lone sentinel (index -1)
	struct sectorlist ** sectorhead = checked_alloc(nsectors+1, sizeof *sectorhead, 0, "sector list heads");
	for (long isec=0; isec<=nsectors; isec++)
	{
		sectorhead[isec] = checked_alloc(1, sizeof **sectorhead, 0, "sector list sentinel");
		sectorhead[isec]->index = -1;
		sectorhead[isec]->next = NULL;
	}

	long * histo_send = checked_alloc(nsectors+1, sizeof *histo_send, 1, "sector histogram");
	for (long iphi = 0; iphi < Nmeasures; iphi++)
	{
		double vvh = vv[iphi];
		double scaled = vvh*nsectors;
		// Only bins 0..nsectors exist: refuse coordinates that would index outside them
		if (!(scaled > -1.0 && scaled < (double)(nsectors+1)))
		{
			fprintf(stderr, "v coordinate %f of local measurement %ld is outside the grid\n", vvh, iphi);
			abort_run(1);
		}
		int binphi = (int)scaled;
		// check if the point influences also neighbouring slabs
		double updist = (double)((binphi+1)*yaxis)*dx - vvh;
		double downdist = vvh - (double)(binphi*yaxis)*dx;

		histo_send[binphi]++;
		Push(&sectorhead[binphi],iphi);
		if (updist < w_supporth && updist >= 0.0 && binphi < nsectors)
		{
			histo_send[binphi+1]++;
			Push(&sectorhead[binphi+1],iphi);
		}
		if (downdist < w_supporth && binphi > 0 && downdist >= 0.0)
		{
			histo_send[binphi-1]++;
			Push(&sectorhead[binphi-1],iphi);
		}
	}

#ifdef PIPPO
	// Debug dump of every sector list
	for (int j=0; j<nsectors; j++)
	{
		struct sectorlist * node = sectorhead[j];
		long iiii = 0;
		while (node->index != -1)
		{
			printf("%d %d %ld %ld %ld\n",rank,j,iiii,histo_send[j],node->index);
			node = node->next;
			iiii++;
		}
	}
#endif

#ifdef VERBOSE
	for (int iii=0; iii<nsectors+1; iii++)printf("HISTO %d %d %ld\n",rank, iii, histo_send[iii]);
#endif

	// Create sector grid: interleaved (real, imaginary) values for every cell of every w-plane.
	// The size is computed in long to avoid int overflow on larger meshes.
	long size_of_grid = 2L*num_w_planes*xaxis*yaxis;
	double * gridss = checked_alloc(size_of_grid, sizeof *gridss, 1, "sector grid");
	double * gridss_w = checked_alloc(size_of_grid, sizeof *gridss_w, 1, "write buffer");
	double * gridss_real = checked_alloc(size_of_grid/2, sizeof *gridss_real, 1, "real write buffer");
	double * gridss_img = checked_alloc(size_of_grid/2, sizeof *gridss_img, 1, "imaginary write buffer");
	// Create destination slab
	double * grid = checked_alloc(size_of_grid, sizeof *grid, 1, "destination slab");
	// Create temporary global grid
#ifndef USE_MPI
	double * gridtot = checked_alloc(2L*grid_size_x*grid_size_y*num_w_planes, sizeof *gridtot, 1, "global grid");
#endif
	// v extent of one slab
	double shift = (double)(dx*yaxis);

	// Open the MPI Memory Window for the slab
#ifdef USE_MPI
	MPI_Win slabwin;
	MPI_Win_create(grid, size_of_grid*sizeof(double), sizeof(double), MPI_INFO_NULL, MPI_COMM_WORLD, &slabwin);
	MPI_Win_fence(0,slabwin);
#endif
#ifndef USE_MPI
	pFile1 = open_or_die(outfile1,"w");
#endif

	// MAIN LOOP OVER FILES
	// NOTE(review): MPI_Reduce and the gridtot copy below overwrite their destination,
	// so with ndatasets > 1 only the ONE_SIDE path appears to accumulate across datasets - confirm.
	for (int ifiles=0; ifiles<ndatasets; ifiles++)
	{
		const char * dataset = datapath_multi[ifiles];
		printf("Processing %s, %d of %d\n",dataset,ifiles+1,ndatasets);

		// Read the metadata of this dataset (uvmin and uvmax are refreshed for the resolution)
		snprintf(filename, sizeof filename, "%s%s", dataset, metafile);
		pFile = open_or_die(filename,"r");
		if (fscanf(pFile, "%ld %ld %ld %ld %ld %lf %lf %ld %lf %lf %lf %lf",
		           &Nmeasures0, &Nvis0, &freq_per_chan0, &polarisations0, &Ntimes0,
		           &dt0, &thours0, &baselines0, &uvmin, &uvmax, &wmin0, &wmax0) != 12)
		{
			fprintf(stderr, "Malformed metadata file %s\n", filename);
			fclose(pFile);
			abort_run(1);
		}
		fclose(pFile);

		// calculate the resolution in radians; fabs keeps the fractional part that abs() truncated
		double uv_extent = MAX(fabs(uvmin),fabs(uvmax));
		resolution = 1.0/uv_extent;
		// calculate the resolution in arcsec
		double resolution_asec = (3600.0*180.0)/uv_extent/PI;
		printf("RESOLUTION = %f rad, %f arcsec\n", resolution, resolution_asec);

		// Read this rank's share of weights and visibilities
		read_slice(dataset, weightsfile, startrow*polarisations*(long)sizeof(float),
		           weights, (size_t)Nweights*sizeof(float));
		long vis_offset = startrow*freq_per_chan*polarisations*(long)sizeof(float);
		size_t vis_bytes = (size_t)Nvis*sizeof(float);
		read_slice(dataset, visrealfile, vis_offset, visreal, vis_bytes);
#ifdef VERBOSE
		printf("Reading %s%s\n",dataset,visimgfile);
#endif
		read_slice(dataset, visimgfile, vis_offset, visimg, vis_bytes);

#ifdef USE_MPI
		MPI_Barrier(MPI_COMM_WORLD);
#endif

		for (long isector_count=0; isector_count<nsectors; isector_count++)
		{
			clock_gettime(CLOCK_MONOTONIC, &begink);
			startk = clock();

			// define local destination sector
			long isector = isector_count;

			// allocate sector arrays (fully overwritten below, so no zero fill is needed)
			long Nsec = histo_send[isector];
			long Nweightss = Nsec*polarisations;
			long Nvissec = Nweightss*freq_per_chan;
			double * uus = checked_alloc(Nsec, sizeof *uus, 0, "sector u");
			double * vvs = checked_alloc(Nsec, sizeof *vvs, 0, "sector v");
			double * wws = checked_alloc(Nsec, sizeof *wws, 0, "sector w");
			float * weightss = checked_alloc(Nweightss, sizeof *weightss, 0, "sector weights");
			float * visreals = checked_alloc(Nvissec, sizeof *visreals, 0, "sector real visibilities");
			float * visimgs = checked_alloc(Nvissec, sizeof *visimgs, 0, "sector imaginary visibilities");

			// select data for this sector by walking its list up to the sentinel
			long icount = 0;
			long ip = 0;
			long inu = 0;
			for (struct sectorlist * current = sectorhead[isector]; current->index != -1; current = current->next)
			{
				long ilocal = current->index;
				uus[icount] = uu[ilocal];
				// v is expressed relative to the lower edge of the slab
				vvs[icount] = vv[ilocal]-isector*shift;
				wws[icount] = ww[ilocal];
				for (long ipol=0; ipol<polarisations; ipol++)
				{
					weightss[ip] = weights[ilocal*polarisations+ipol];
					ip++;
				}
				for (long ifreq=0; ifreq<polarisations*freq_per_chan; ifreq++)
				{
					visreals[inu] = visreal[ilocal*polarisations*freq_per_chan+ifreq];
					visimgs[inu] = visimg[ilocal*polarisations*freq_per_chan+ifreq];
					inu++;
				}
				icount++;
			}

			clock_gettime(CLOCK_MONOTONIC, &finishk);
			endk = clock();
			compose_time += elapsed_cpu(startk, endk);
			compose_time1 += elapsed_wall(&begink, &finishk);

#ifndef USE_MPI
			double uumin = 1e20;
			double vvmin = 1e20;
			double uumax = -1e20;
			double vvmax = -1e20;

			for (long ipart=0; ipart<Nsec; ipart++)
			{
				uumin = MIN(uumin,uus[ipart]);
				uumax = MAX(uumax,uus[ipart]);
				vvmin = MIN(vvmin,vvs[ipart]);
				vvmax = MAX(vvmax,vvs[ipart]);

				// Dump one point out of ten to coords.txt (pFile1; pFile is closed at this point)
				if(ipart%10 == 0)fprintf (pFile1, "%ld %f %f %f\n",isector,uus[ipart],vvs[ipart]+isector*shift,wws[ipart]);
			}

			printf("UU, VV, min, max = %f %f %f %f\n", uumin, uumax, vvmin, vvmax);
#endif

			// Make convolution on the grid
#ifdef VERBOSE
			printf("Processing sector %ld\n",isector);
#endif
			clock_gettime(CLOCK_MONOTONIC, &begink);
			startk = clock();

			wstack(num_w_planes,
			       Nsec,
			       freq_per_chan,
			       polarisations,
			       uus,
			       vvs,
			       wws,
			       visreals,
			       visimgs,
			       weightss,
			       dx,
			       dw,
			       w_support,
			       xaxis,
			       yaxis,
			       gridss,
			       num_threads);

			clock_gettime(CLOCK_MONOTONIC, &finishk);
			endk = clock();
			kernel_time += elapsed_cpu(startk, endk);
			kernel_time1 += elapsed_wall(&begink, &finishk);
#ifdef VERBOSE
			printf("Processed sector %ld\n",isector);
#endif
			clock_gettime(CLOCK_MONOTONIC, &begink);
			startk = clock();

#ifndef USE_MPI
			// Copy the sector grid into its place in the global grid
			long stride = isector*size_of_grid;
			for (long iii=0; iii<size_of_grid; iii++)gridtot[stride+iii] = gridss[iii];
#endif

			// Write grid in the corresponding remote slab
#ifdef USE_MPI
			int target_rank = (int)isector;
#ifdef ONE_SIDE
			printf("One Side communication active\n");
			MPI_Win_lock(MPI_LOCK_SHARED,target_rank,0,slabwin);
			MPI_Accumulate(gridss,size_of_grid,MPI_DOUBLE,target_rank,0,size_of_grid,MPI_DOUBLE,MPI_SUM,slabwin);
			MPI_Win_unlock(target_rank,slabwin);
#else
			MPI_Reduce(gridss,grid,size_of_grid,MPI_DOUBLE,MPI_SUM,target_rank,MPI_COMM_WORLD);
#endif //ONE_SIDE
#endif //USE_MPI

			clock_gettime(CLOCK_MONOTONIC, &finishk);
			endk = clock();
			reduce_time += elapsed_cpu(startk, endk);
			reduce_time1 += elapsed_wall(&begink, &finishk);

			// Go to next sector: clear the sector grid
			memset(gridss, 0, (size_t)size_of_grid*sizeof(double));

			// Deallocate all sector arrays
			free(uus);
			free(vvs);
			free(wws);
			free(weightss);
			free(visreals);
			free(visimgs);
		// End of loop over sectors
		}
	// End of loop over input files
	}

	// Finalize MPI communication
#ifdef USE_MPI
	MPI_Win_fence(0,slabwin);
#endif

#ifndef USE_MPI
	fclose(pFile1);
#endif
#ifdef USE_MPI
	MPI_Barrier(MPI_COMM_WORLD);
#endif

	end = clock();
	clock_gettime(CLOCK_MONOTONIC, &finish);
	process_time = elapsed_cpu(start, end);
	process_time1 = elapsed_wall(&begin, &finish);
	clock_gettime(CLOCK_MONOTONIC, &begin);

#ifdef WRITE_DATA
	// Write results: rank 0 collects every slab and stores real and imaginary parts in separate files
	if (rank == 0)
	{
		printf("WRITING GRIDDED DATA\n");
		FILE * pFilereal = open_or_die(outfile2,"wb");
		FILE * pFileimg = open_or_die(outfile3,"wb");
#ifdef USE_MPI
		for (long isector=0; isector<nsectors; isector++)
		{
			MPI_Win_lock(MPI_LOCK_SHARED,(int)isector,0,slabwin);
			MPI_Get(gridss,size_of_grid,MPI_DOUBLE,(int)isector,0,size_of_grid,MPI_DOUBLE,slabwin);
			MPI_Win_unlock((int)isector,slabwin);
			// de-interleave real and imaginary parts
			for (long i=0; i<size_of_grid/2; i++)
			{
				gridss_real[i] = gridss[2*i];
				gridss_img[i] = gridss[2*i+1];
			}
			if (num_w_planes > 1)
			{
				write_slab_rows(pFilereal, gridss_real, isector, xaxis, yaxis, num_w_planes, grid_size_x, grid_size_y);
				write_slab_rows(pFileimg, gridss_img, isector, xaxis, yaxis, num_w_planes, grid_size_x, grid_size_y);
			} else {
				for (int iw=0; iw<num_w_planes; iw++)
				{
					long file_cell = xaxis*isector*yaxis + (long)iw*grid_size_x*grid_size_y;
					long index = (long)iw*xaxis*yaxis;
					fseek(pFilereal, file_cell*(long)sizeof(double), SEEK_SET);
					fwrite(&gridss_real[index], sizeof(double), (size_t)xaxis*yaxis, pFilereal);
					fseek(pFileimg, file_cell*(long)sizeof(double), SEEK_SET);
					fwrite(&gridss_img[index], sizeof(double), (size_t)xaxis*yaxis, pFileimg);
				}
			}
		}
#else
		// gridtot is walked in storage order, one (real, imaginary) pair per cell
		long ncells = (long)num_w_planes*grid_size_x*grid_size_y;
		for (long cell=0; cell<ncells; cell++)
		{
			fwrite(&gridtot[2*cell], sizeof(double), 1, pFilereal);
			fwrite(&gridtot[2*cell+1], sizeof(double), 1, pFileimg);
		}
#endif
		fclose(pFilereal);
		fclose(pFileimg);
	}

#ifdef USE_MPI
	MPI_Win_fence(0,slabwin);
#endif
#endif //WRITE_DATA

#ifdef USE_FFTW
	// FFT transform the data (using distributed FFTW)
	if(rank == 0)printf("PERFORMING FFT\n");
	clock_gettime(CLOCK_MONOTONIC, &begin);
	start = clock();
	fftw_plan plan;
	fftw_complex *fftwgrid;
	ptrdiff_t alloc_local, local_n0, local_0_start;
	double norm = 1.0/(double)(grid_size_x*grid_size_y);

	// map the 1D array of complex visibilities to a 2D array required by FFTW (complex[*][2])
	// x is the direction of contiguous data and maps to the second parameter
	// y is the parallelized direction and corresponds to the first parameter (--> n0)
	// and perform the FFT per w plane
	alloc_local = fftw_mpi_local_size_2d(grid_size_y, grid_size_x, MPI_COMM_WORLD,&local_n0, &local_0_start);
	fftwgrid = fftw_alloc_complex(alloc_local);
	plan = fftw_mpi_plan_dft_2d(grid_size_y, grid_size_x, fftwgrid, fftwgrid, MPI_COMM_WORLD, FFTW_BACKWARD, FFTW_ESTIMATE);

	long plane_cells = (long)xaxis*yaxis;
	for (int iw=0; iw<num_w_planes; iw++)
	{
		long plane_base = 2*iw*plane_cells;

		// select the w-plane to transform
		for (long cell=0; cell<plane_cells; cell++)
		{
			fftwgrid[cell][0] = grid[plane_base+2*cell];
			fftwgrid[cell][1] = grid[plane_base+2*cell+1];
		}

		// do the transform for each w-plane
		fftw_execute(plan);

		// save the normalised transformed w-plane
		for (long cell=0; cell<plane_cells; cell++)
		{
			gridss[plane_base+2*cell] = norm*fftwgrid[cell][0];
			gridss[plane_base+2*cell+1] = norm*fftwgrid[cell][1];
		}
	}

	fftw_destroy_plan(plan);

#ifdef USE_MPI
	MPI_Win_fence(0,slabwin);
	MPI_Barrier(MPI_COMM_WORLD);
#endif

	end = clock();
	clock_gettime(CLOCK_MONOTONIC, &finish);
	fftw_time = elapsed_cpu(start, end);
	fftw_time1 = elapsed_wall(&begin, &finish);
	clock_gettime(CLOCK_MONOTONIC, &begin);

#ifdef WRITE_DATA
	// Write results: the transformed slabs are exposed through a second window
#ifdef USE_MPI
	MPI_Win writewin;
	MPI_Win_create(gridss, size_of_grid*sizeof(double), sizeof(double), MPI_INFO_NULL, MPI_COMM_WORLD, &writewin);
	MPI_Win_fence(0,writewin);
#endif
	if (rank == 0)
	{
		printf("WRITING FFT TRANSFORMED DATA\n");
		FILE * pFilereal = open_or_die(fftfile2,"wb");
		FILE * pFileimg = open_or_die(fftfile3,"wb");
#ifdef USE_MPI
		for (long isector=0; isector<nsectors; isector++)
		{
			MPI_Win_lock(MPI_LOCK_SHARED,(int)isector,0,writewin);
			MPI_Get(gridss_w,size_of_grid,MPI_DOUBLE,(int)isector,0,size_of_grid,MPI_DOUBLE,writewin);
			MPI_Win_unlock((int)isector,writewin);
			for (long i=0; i<size_of_grid/2; i++)
			{
				gridss_real[i] = gridss_w[2*i];
				gridss_img[i] = gridss_w[2*i+1];
			}
			if (num_w_planes > 1)
			{
				write_slab_rows(pFilereal, gridss_real, isector, xaxis, yaxis, num_w_planes, grid_size_x, grid_size_y);
				write_slab_rows(pFileimg, gridss_img, isector, xaxis, yaxis, num_w_planes, grid_size_x, grid_size_y);
			} else {
				// single plane: the slabs are simply appended in sector order
				fwrite(gridss_real, sizeof(double), (size_t)(size_of_grid/2), pFilereal);
				fwrite(gridss_img, sizeof(double), (size_t)(size_of_grid/2), pFileimg);
			}
		}
#endif
		fclose(pFilereal);
		fclose(pFileimg);
	}

#ifdef USE_MPI
	MPI_Win_fence(0,writewin);
	MPI_Win_free(&writewin);
	MPI_Barrier(MPI_COMM_WORLD);
#endif
#endif //WRITE_DATA

	fftw_free(fftwgrid);

	// Phase correction
	clock_gettime(CLOCK_MONOTONIC, &begin);
	start = clock();
	if(rank == 0)printf("PHASE CORRECTION\n");
	double* image_real = checked_alloc((long)xaxis*yaxis, sizeof *image_real, 1, "real image");
	double* image_imag = checked_alloc((long)xaxis*yaxis, sizeof *image_imag, 1, "imaginary image");

	phase_correction(gridss,image_real,image_imag,xaxis,yaxis,num_w_planes,grid_size_x,grid_size_y,resolution,wmin,wmax,num_threads);

	end = clock();
	clock_gettime(CLOCK_MONOTONIC, &finish);
	phase_time = elapsed_cpu(start, end);
	phase_time1 = elapsed_wall(&begin, &finish);

#ifdef WRITE_IMAGE
	// Rank 0 truncates the output files, then the ranks append their slab one after the other
	if(rank == 0)
	{
		FILE * pFilereal = open_or_die(fftfile2,"wb");
		FILE * pFileimg = open_or_die(fftfile3,"wb");
		fclose(pFilereal);
		fclose(pFileimg);
	}
#ifdef USE_MPI
	MPI_Barrier(MPI_COMM_WORLD);
#endif
	if(rank == 0)printf("WRITING IMAGE\n");
	for (int isector=0; isector<size; isector++)
	{
#ifdef USE_MPI
		MPI_Barrier(MPI_COMM_WORLD);
#endif
		if(isector == rank)
		{
			printf("%d writing\n",isector);
			FILE * pFilereal = open_or_die(fftfile2,"ab");
			FILE * pFileimg = open_or_die(fftfile3,"ab");

			long global_index = (long)isector*xaxis*yaxis*(long)sizeof(double);

			fseek(pFilereal, global_index, SEEK_SET);
			fwrite(image_real, sizeof(double), (size_t)xaxis*yaxis, pFilereal);
			fseek(pFileimg, global_index, SEEK_SET);
			fwrite(image_imag, sizeof(double), (size_t)xaxis*yaxis, pFileimg);

			fclose(pFilereal);
			fclose(pFileimg);
		}
	}
#ifdef USE_MPI
	MPI_Barrier(MPI_COMM_WORLD);
#endif
#endif //WRITE_IMAGE

	free(image_real);
	free(image_imag);
#endif //USE_FFTW

	end = clock();
	clock_gettime(CLOCK_MONOTONIC, &finish);
	tot_time = elapsed_cpu(start0, end);
	tot_time1 = elapsed_wall(&begin0, &finish);

	if (rank == 0)
	{
		printf("Setup time:    %f sec\n",setup_time);
		printf("Process time:  %f sec\n",process_time);
		printf("Kernel time = %f, Array Composition time %f, Reduce time: %f sec\n",kernel_time,compose_time,reduce_time);
#ifdef USE_FFTW
		printf("FFTW time:     %f sec\n",fftw_time);
		printf("Phase time:    %f sec\n",phase_time);
#endif
		printf("TOT time:      %f sec\n",tot_time);
		if(num_threads > 1)
		{
			printf("PSetup time:   %f sec\n",setup_time1);
			printf("PProcess time: %f sec\n",process_time1);
			printf("PKernel time = %f, PArray Composition time %f, PReduce time: %f sec\n",kernel_time1,compose_time1,reduce_time1);
#ifdef USE_FFTW
			printf("PFFTW time:    %f sec\n",fftw_time1);
			printf("PPhase time:   %f sec\n",phase_time1);
#endif
			printf("PTOT time:     %f sec\n",tot_time1);
		}
	}

	// Store the timings: CPU times for a single thread, wall-clock times otherwise
	if (rank == 0)
	{
		pFile = fopen (timingfile,"w");
		if (pFile == NULL)
		{
			fprintf(stderr, "Cannot open %s, timings not saved\n", timingfile);
		}
		else
		{
			if (num_threads == 1)
			{
				fprintf(pFile, "%f %f %f %f %f %f %f\n",setup_time,kernel_time,compose_time,reduce_time,fftw_time,phase_time,tot_time);
			} else {
				fprintf(pFile, "%f %f %f %f %f %f %f\n",setup_time1,kernel_time1,compose_time1,reduce_time1,fftw_time1,phase_time1,tot_time1);
			}
			fclose(pFile);
		}
	}

	// Close MPI environment
#ifdef USE_MPI
	MPI_Win_fence(0,slabwin);
	MPI_Win_free(&slabwin);
	MPI_Finalize();
#endif

	// Release everything still owned by main (grid only after its window is freed)
	for (long isec=0; isec<=nsectors; isec++)
	{
		struct sectorlist * node = sectorhead[isec];
		while (node != NULL)
		{
			struct sectorlist * following = node->next;
			free(node);
			node = following;
		}
	}
	free(sectorhead);
	free(histo_send);
	free(gridss);
	free(gridss_w);
	free(gridss_real);
	free(gridss_img);
	free(grid);
#ifndef USE_MPI
	free(gridtot);
#endif
	free(uu);
	free(vv);
	free(ww);
	free(weights);
	free(visreal);
	free(visimg);

	return 0;
}