diff --git a/Makefile b/Makefile
index 0dcde72606c8d3260e7fbde2e13d0b5e9c0ee53e..99d13891512a600974a7a33c5eb43683eaf22e3f 100644
--- a/Makefile
+++ b/Makefile
@@ -10,7 +10,7 @@ OPT += -DONE_SIDE
 # write the final image
 OPT += -DWRITE_IMAGE
 # perform w-stacking phase correction
-OPT += -DPHASE_ON
+#OPT += -DPHASE_ON
 
 CC = gcc
 CXX = g++
@@ -46,7 +46,7 @@ serial: $(COBJ)
 	$(CC) $(OMP) -o w-stackingCfftw_serial $(CFLAGS) $^ -lm
 
 serial_cuda:
-	$(NVCC) $(NVFLAGS) -c w-stacking.cu phase_correction.cu $(NVLIB)
+	$(NVCC) $(OPT) $(NVFLAGS) -c w-stacking.cu phase_correction.cu $(NVLIB)
 	$(CC) $(CFLAGS) $(OPT) -c w-stacking-fftw.c
 	$(CXX) $(CFLAGS) $(OPT) -o w-stackingfftw_serial w-stacking-fftw.o w-stacking.o phase_correction.o $(NVLIB) -lm
 
@@ -54,7 +54,7 @@ mpi: $(COBJ)
 	$(CC) $(OMP) -o w-stackingCfftw $(CFLAGS) $^ $(LIBS)
 
 mpi_cuda:
-	$(NVCC) $(NVFLAGS) -c w-stacking.cu phase_correction.cu $(NVLIB)
+	$(NVCC) $(NVFLAGS) $(OPT) -c w-stacking.cu phase_correction.cu $(NVLIB)
 	$(CC) $(CFLAGS) $(OPT) -c w-stacking-fftw.c
 	$(CXX) $(CFLAGS) $(OPT) -o w-stackingfftw w-stacking-fftw.o w-stacking.o phase_correction.o $(NVLIB) $(LIBS) -lm
 
diff --git a/phase_correction.cu b/phase_correction.cu
index 77f5ad25a9197dca27a2ae3e6a567f3000fec583..b5814718a480f7375b0e5d1b71e3a6b4b69403b7 100644
--- a/phase_correction.cu
+++ b/phase_correction.cu
@@ -7,19 +7,126 @@
 #include <stdio.h>
 
 #ifdef __CUDACC__
+
+__global__ void phase_g(int xaxis, 
+		        int yaxis,
+			int num_w_planes,
+			double * gridss,
+			double * image_real,
+			double * image_imag,
+			double wmin,
+			double dw,
+			double dwnorm,
+			int xaxistot,
+			int yaxistot,
+			double resolution)
+{
+	long gid = blockIdx.x*blockDim.x + threadIdx.x;
+	double add_term_real;
+	double add_term_img;
+	double wterm;
+	long arraysize = xaxis*yaxis*num_w_planes;
+
+	if(gid < arraysize)
+	{
+		int iw = (int)(gid/(xaxis*yaxis));
+		int iv = (int)((gid%(xaxis*yaxis))/xaxis);
+		int iu = (iv%yaxis);
+		long index = 2*gid;
+		long img_index = iu+iv*xaxis;
+
+                wterm = wmin + iw*dw;
+
+#ifdef PHASE_ON
+                if (num_w_planes > 1)
+                {
+                    double xcoord = (double)(iu-xaxistot/2);
+                    if(xcoord < 0.0)xcoord = (double)(iu+xaxistot/2);
+                    xcoord = sin(xcoord*resolution);
+                    double ycoord = (double)(iv-yaxistot/2);
+                    if(ycoord < 0.0)ycoord = (double)(iv+yaxistot/2);
+                    ycoord = sin(ycoord*resolution);
+
+                    double preal, pimag;
+                    double radius2 = (xcoord*xcoord+ycoord*ycoord);
+
+                    preal = cos(2.0*PI*wterm*(sqrt(1-radius2)-1.0));
+                    pimag = sin(2.0*PI*wterm*(sqrt(1-radius2)-1.0));
+
+                    double p,q,r,s;
+                    p = gridss[index];
+                    q = gridss[index+1];
+                    r = preal;
+                    s = pimag;
+
+                    //printf("%d %d %d %ld %ld\n",iu,iv,iw,index,img_index);
+
+		    add_term_real = (p*r-q*s)*dwnorm*sqrt(1-radius2);
+		    add_term_img = (p*s+q*r)*dwnorm*sqrt(1-radius2);
+		    atomicAdd(&(image_real[img_index]),add_term_real);
+		    atomicAdd(&(image_imag[img_index]),add_term_img);
+                } else {
+		    atomicAdd(&(image_real[img_index]),gridss[index]);
+		    atomicAdd(&(image_imag[img_index]),gridss[index+1]);
+                }
+#else
+		atomicAdd(&(image_real[img_index]),gridss[index]);
+		atomicAdd(&(image_imag[img_index]),gridss[index+1]);
+#endif // end of PHASE_ON
+
+	}
+
+}
+
 #endif
 
 void phase_correction(double* gridss, double* image_real, double* image_imag, int xaxis, int yaxis, int num_w_planes, int xaxistot, int yaxistot,
 		      double resolution, double wmin, double wmax)
 {
-	double dnum_w_planes = (double)num_w_planes;
-	double dxaxistot = (double)xaxistot;
-	double dyaxistot = (double)yaxistot;
-	double diagonal;
         double dw = (wmax-wmin)/(double)num_w_planes;
 	double wterm = wmin+0.5*dw;
 	double dwnorm = dw/(wmax-wmin);
 
+#ifdef __CUDACC__
+
+        int Nth = NTHREADS;
+        long Nbl = (long)((num_w_planes*xaxis*yaxis)/Nth) + 1;
+        if(NWORKERS == 1) {Nbl = 1; Nth = 1;};
+        printf("Running on GPU with %d threads and %d blocks\n",Nth,Nbl);
+	
+
+	cudaError_t mmm;
+	double * image_real_g;
+	double * image_imag_g;
+	double * gridss_g;
+
+        mmm=cudaMalloc(&gridss_g, 2*num_w_planes*xaxis*yaxis*sizeof(double));
+	mmm=cudaMalloc(&image_real_g, xaxis*yaxis*sizeof(double));
+	mmm=cudaMalloc(&image_imag_g, xaxis*yaxis*sizeof(double));
+
+	mmm=cudaMemcpy(gridss_g, gridss, 2*num_w_planes*xaxis*yaxis*sizeof(double), cudaMemcpyHostToDevice);
+	mmm=cudaMemset(image_real_g, 0.0, xaxis*yaxis*sizeof(double));
+	mmm=cudaMemset(image_imag_g, 0.0, xaxis*yaxis*sizeof(double));
+
+	// call the phase correction kernel
+	phase_g <<<Nbl,Nth>>> (xaxis,
+                               yaxis,
+			       num_w_planes,
+                               gridss_g,
+                               image_real_g,
+                               image_imag_g,
+                               wmin,
+                               dw,
+                               dwnorm,
+                               xaxistot,
+                               yaxistot,
+                               resolution);
+
+	mmm = cudaMemcpy(image_real, image_real_g, xaxis*yaxis*sizeof(double), cudaMemcpyDeviceToHost);
+	mmm = cudaMemcpy(image_imag, image_imag_g, xaxis*yaxis*sizeof(double), cudaMemcpyDeviceToHost);
+
+#else
+
 	for (int iw=0; iw<num_w_planes; iw++)
 	{
 	    for (int iv=0; iv<yaxis; iv++)
@@ -38,7 +145,6 @@ void phase_correction(double* gridss, double* image_real, double* image_imag, in
                     if(ycoord < 0.0)ycoord = (double)(iv+yaxistot/2);
 		    ycoord = sin(ycoord*resolution);
 
-		    double efact;
 		    double preal, pimag;
 		    double radius2 = (xcoord*xcoord+ycoord*ycoord);
 
@@ -61,12 +167,13 @@ void phase_correction(double* gridss, double* image_real, double* image_imag, in
 #else
   	        image_real[img_index] += gridss[index];
 		image_imag[img_index] += gridss[index+1];
-#endif
+#endif // end of PHASE_ON
 
             }  
 	    wterm += dw;
 	}
 
+#endif // end of __CUDACC__
 
 
 }
diff --git a/w-stacking-fftw.c b/w-stacking-fftw.c
index 55872f61fc2935d8fbcba3e607a696d7b7397042..07f25becd5e607419fcb767ca57d4b4c9e97f714 100644
--- a/w-stacking-fftw.c
+++ b/w-stacking-fftw.c
@@ -78,24 +78,24 @@ int main(int argc, char * argv[])
 	float * visreal;
 	float * visimg;
 
-	long Nmeasures;
-	long Nvis;
-	long Nweights;
-	long freq_per_chan;
-	long polarisations;
-        long Ntimes;
-	double dt;
-	double thours;
-	long baselines;
-	double uvmin;
-	double uvmax;
-	double wmin;
-	double wmax;
+	long Nmeasures,Nmeasures0;
+	long Nvis,Nvis0;
+	long Nweights,Nweights0;
+	long freq_per_chan,freq_per_chan0;
+	long polarisations,polarisations0;
+        long Ntimes,Ntimes0;
+	double dt,dt0;
+	double thours,thours0;
+	long baselines,baselines0;
+	double uvmin,uvmin0;
+	double uvmax,uvmax0;
+	double wmin,wmin0;
+	double wmax,wmax0;
 	double resolution;
 
         // MESH SIZE
-	int grid_size_x = 2048;
-	int grid_size_y = 2048;
+	int grid_size_x = 256;
+	int grid_size_y = 256;
 	int local_grid_size_x;// = 8;
 	int local_grid_size_y;// = 8;
 	int xaxis;
@@ -372,18 +372,18 @@ int main(int argc, char * argv[])
         strcpy(filename,datapath);
         strcat(filename,metafile);
         pFile = fopen (filename,"r");
-        fscanf(pFile,"%ld",&Nmeasures);
-        fscanf(pFile,"%ld",&Nvis);
-        fscanf(pFile,"%ld",&freq_per_chan);
-        fscanf(pFile,"%ld",&polarisations);
-        fscanf(pFile,"%ld",&Ntimes);
-        fscanf(pFile,"%lf",&dt);
-        fscanf(pFile,"%lf",&thours);
-        fscanf(pFile,"%ld",&baselines);
+        fscanf(pFile,"%ld",&Nmeasures0);
+        fscanf(pFile,"%ld",&Nvis0);
+        fscanf(pFile,"%ld",&freq_per_chan0);
+        fscanf(pFile,"%ld",&polarisations0);
+        fscanf(pFile,"%ld",&Ntimes0);
+        fscanf(pFile,"%lf",&dt0);
+        fscanf(pFile,"%lf",&thours0);
+        fscanf(pFile,"%ld",&baselines0);
         fscanf(pFile,"%lf",&uvmin);
         fscanf(pFile,"%lf",&uvmax);
-        fscanf(pFile,"%lf",&wmin);
-        fscanf(pFile,"%lf",&wmax);
+        fscanf(pFile,"%lf",&wmin0);
+        fscanf(pFile,"%lf",&wmax0);
         fclose(pFile);
 
         // calculate the resolution in radians
diff --git a/w-stacking.cu b/w-stacking.cu
index 9c76f82008b0742f88ff01dfef36843807ac777d..499189619802f8e94c7a41cd49caa46c8afb7093 100644
--- a/w-stacking.cu
+++ b/w-stacking.cu
@@ -147,7 +147,7 @@ void wstack(
 #ifdef __CUDACC__
     // Define the CUDA set up
     int Nth = NTHREADS;
-    int Nbl = num_points/Nth + 1;
+    long Nbl = (long)(num_points/Nth) + 1;
     if(NWORKERS == 1) {Nbl = 1; Nth = 1;};
     long Nvis = num_points*freq_per_chan*polarizations;
     printf("Running on GPU with %d threads and %d blocks\n",Nth,Nbl);