#include <stdlib.h>
#include <stdio.h>
#include "/cineca/prod/opt/compilers/hpc-sdk/2023/binary/Linux_ppc64le/23.1/math_libs/11.8/include/cufftmp/cufftMp.h"
#include <mpi.h>
#include <cuda_runtime.h>
#include <complex.h>
#include "cuComplex.h"
#include "w-stacking.h"
#include <time.h>


#ifdef __CUDACC__
/* Print a diagnostic on stderr when a cuFFT call did not succeed.
 * Returns 1 on failure and 0 on success so callers can decide whether to go on. */
static int cuda_fft_cufft_failed(cufftResult_t status, const char *what)
{
	if (status == CUFFT_SUCCESS)
		return 0;
	fprintf(stderr, "!!! %s ERROR %d !!!\n", what, (int)status);
	return 1;
}

/* Print a diagnostic on stderr when a CUDA runtime call did not succeed.
 * Returns 1 on failure and 0 on success. */
static int cuda_fft_cuda_failed(cudaError_t status, const char *what)
{
	if (status == cudaSuccess)
		return 0;
	fprintf(stderr, "!!! %s ERROR %d !!!\n", what, (int)status);
	return 1;
}

/* Terminate the whole MPI job after an unrecoverable allocation failure.
 * Returning from a single rank is not an option here: the remaining ranks
 * would keep issuing cuFFTMp calls on the shared communicator. */
static void cuda_fft_fatal(MPI_Comm comm, const char *what)
{
	fprintf(stderr, "!!! cuda_fft: %s failed, aborting !!!\n", what);
	fflush(stderr);
	MPI_Abort(comm, EXIT_FAILURE);
	abort();	/* fallback in case MPI_Abort returns */
}
#endif // __CUDACC__

/*
 * cuda_fft - inverse 2D complex-to-complex FFT of every w-plane using cuFFTMp.
 *
 * For each of the num_w_planes planes, the plane is packed from `grid` into a
 * host staging buffer, copied to the device, transformed with CUFFT_INVERSE
 * through a grid_size_x by grid_size_y CUFFT_Z2Z plan attached to `comm`,
 * copied back and stored in `gridss` multiplied by 1/(grid_size_x*grid_size_y).
 *
 * Parameters
 *   num_w_planes  number of planes to transform
 *   grid_size_x   first dimension given to cufftMakePlan2d
 *   grid_size_y   second dimension given to cufftMakePlan2d
 *   xaxis, yaxis  extents of the plane stored in grid/gridss on this process
 *                 (presumably the local slab of the distributed grid - TODO
 *                 confirm against the caller)
 *   grid          input, interleaved (re, im) doubles; element (iu, iv, iw) is
 *                 at index 2*(iu + iv*xaxis + iw*xaxis*yaxis)
 *   gridss        output, same layout as grid
 *   comm          MPI communicator handed to cufftMpAttachComm
 *
 * Error handling: failed CUDA/cuFFT calls are reported on stderr and execution
 * continues, except for allocation failures (host buffer, cufftXtMalloc), which
 * abort the MPI job because the following calls would use invalid pointers.
 *
 * When the file is not compiled with a CUDA compiler (__CUDACC__ undefined)
 * the function body is empty and gridss is left untouched.
 */
void cuda_fft(
	int num_w_planes,
	int grid_size_x,
	int grid_size_y,
	int xaxis,
	int yaxis,
	double * grid,
	double * gridss,
	MPI_Comm comm)
{
#ifdef __CUDACC__

	// Host staging buffer for one plane at a time.
	// NOTE(review): the size formula is kept from the original and covers all
	// planes although only xaxis*yaxis elements are filled per iteration.
	// cufftXtMemcpy moves the amount of data dictated by the plan, so confirm
	// the local slab size before shrinking this allocation.
	const size_t fftwgrid_bytes =
		sizeof(cufftDoubleComplex)*2*num_w_planes*yaxis*grid_size_x;
	cufftDoubleComplex *fftwgrid = (cufftDoubleComplex*) malloc(fftwgrid_bytes);
	if (fftwgrid == NULL && fftwgrid_bytes != 0)
		cuda_fft_fatal(comm, "host buffer allocation");


	// Plan creation

	cufftHandle plan;
	cuda_fft_cufft_failed(cufftCreate(&plan), "cufftCreate");

	// If stream creation fails, `stream` keeps its zero-initialised value.
	cudaStream_t stream{};
	cuda_fft_cuda_failed(cudaStreamCreate(&stream), "cudaStreamCreate");

	cuda_fft_cufft_failed(cufftMpAttachComm(plan, CUFFT_COMM_MPI, &comm),
	                      "cufftMpAttachComm");

	cuda_fft_cufft_failed(cufftSetStream(plan, stream), "cufftSetStream");

	size_t workspace;
	cuda_fft_cufft_failed(
		cufftMakePlan2d(plan, grid_size_x, grid_size_y, CUFFT_Z2Z, &workspace),
		"cufftMakePlan2d");
	cudaDeviceSynchronize();


	// All index arithmetic is done in 64 bits: iw*xaxis*yaxis and
	// grid_size_x*grid_size_y overflow a 32-bit int for large grids.
	const long long plane_elems = (long long)xaxis * (long long)yaxis;
	const double norm = 1.0/((double)grid_size_x * (double)grid_size_y);


	// Grid composition

	for (int iw=0; iw<num_w_planes; iw++)
	{
		printf("select the %d w-plane to transform\n", iw);
		const long long plane_base = (long long)iw * plane_elems;

		// Pack plane iw from the interleaved double array into complex values.
		for (int iv=0; iv<yaxis; iv++)
		{
			for (int iu=0; iu<xaxis; iu++)
			{
				const long long fftwindex2D = iu + (long long)iv*xaxis;
				const long long fftwindex = 2*(fftwindex2D + plane_base);
				fftwgrid[fftwindex2D].x = grid[fftwindex];
				fftwgrid[fftwindex2D].y = grid[fftwindex+1];
			}
		}

		// Descriptors start as NULL so a failed allocation is never mistaken
		// for a valid pointer.
		cudaLibXtDesc *fftwgrid_g = NULL;
		cudaLibXtDesc *fftwgrid_g2 = NULL;

		if (cuda_fft_cufft_failed(
				cufftXtMalloc(plan, &fftwgrid_g, CUFFT_XT_FORMAT_INPLACE),
				"cufftXtMalloc"))
			cuda_fft_fatal(comm, "cufftXtMalloc");

		if (cuda_fft_cufft_failed(
				cufftXtMalloc(plan, &fftwgrid_g2, CUFFT_XT_FORMAT_INPLACE),
				"cufftXtMalloc 2"))
			cuda_fft_fatal(comm, "cufftXtMalloc 2");
		cudaDeviceSynchronize();

		cuda_fft_cuda_failed(cudaStreamSynchronize(stream), "cudaStreamSynchronize");

		cuda_fft_cufft_failed(
			cufftXtMemcpy(plan, fftwgrid_g, fftwgrid, CUFFT_COPY_HOST_TO_DEVICE),
			"cufftXtMemcpy htd fftwgrid_g");

		cudaDeviceSynchronize();

		// In-place inverse transform of the current plane.
		cuda_fft_cufft_failed(
			cufftXtExecDescriptor(plan, fftwgrid_g, fftwgrid_g, CUFFT_INVERSE),
			"cufftXtExecDescriptor");

		cuda_fft_cuda_failed(cudaStreamSynchronize(stream), "cudaStreamSynchronize 2");

		cudaDeviceSynchronize();

		// Copy the result into the second descriptor, which was allocated with
		// CUFFT_XT_FORMAT_INPLACE (presumably to bring the transformed data back
		// to the input ordering - confirm against the cuFFTMp documentation).
		cuda_fft_cufft_failed(
			cufftXtMemcpy(plan, fftwgrid_g2, fftwgrid_g, CUFFT_COPY_DEVICE_TO_DEVICE),
			"cufftXtMemcpy dtd fftwgrid");

		cuda_fft_cuda_failed(cudaStreamSynchronize(stream), "cudaStreamSynchronize 3");

		cudaDeviceSynchronize();

		cuda_fft_cufft_failed(
			cufftXtMemcpy(plan, fftwgrid, fftwgrid_g2, CUFFT_COPY_DEVICE_TO_HOST),
			"cufftXtMemcpy dth fftwgrid");

		// Make sure the device-to-host copy has finished before the host reads
		// fftwgrid and before the device buffers are released.
		cuda_fft_cuda_failed(cudaStreamSynchronize(stream), "cudaStreamSynchronize 4");

		cuda_fft_cufft_failed(cufftXtFree(fftwgrid_g), "cufftXtFree");
		cuda_fft_cufft_failed(cufftXtFree(fftwgrid_g2), "cufftXtFree 2");

		// Unpack the transformed plane into gridss, applying the normalisation.
		for (int iv=0; iv<yaxis; iv++)
		{
			for (int iu=0; iu<xaxis; iu++)
			{
				const long long fftwindex2D = iu + (long long)iv*xaxis;
				const long long fftwindex = 2*(fftwindex2D + plane_base);
				gridss[fftwindex] = norm*fftwgrid[fftwindex2D].x;
				gridss[fftwindex+1] = norm*fftwgrid[fftwindex2D].y;
			}
		}
	}


	cuda_fft_cufft_failed(cufftDestroy(plan), "cufftDestroy fftwgrid");

	cuda_fft_cuda_failed(cudaStreamDestroy(stream), "cudaStreamDestroy");
	cudaDeviceSynchronize();

	// Release the host staging buffer (it was leaked before).
	free(fftwgrid);

#endif // __CUDACC__
}
