////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Passing OpenMP-allocated device data to a foreign runtime (here a hand-written CUDA kernel, not the cuBLAS library).
//
// Author: David Goz
// mail  : david.goz@inaf.it
// date  : 03.09.2024
// code tested using nvhpc
//
// - Compile the code:
//      - using nvc
//          $ nvc++ -O3 -mp=gpu -gpu=ccnative,debug,lineinfo -target=gpu -Minfo=all -v
//                  hybrid_omp_cuda.cu -o hybrid_omp_cuda -lm && ./hybrid_omp_cuda
//
//      - using clang
//          $ clang -O3 -v -fopenmp -fopenmp-targets=nvptx64-nvidia-cuda 
//                 hybrid_omp_cuda.c -o hybrid_omp_cuda -lm
//
// - Run the code:
//   $ export OMP_TARGET_OFFLOAD=mandatory
//   $ ./hybrid_omp_cuda
////////////////////////////////////////////////////////////////////////////////////////////////////

#include <cuda.h>
#include <math.h>
#include <assert.h>
#include <float.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <omp.h>

// Matrix dimension: every matrix is N x N, stored row-major in a flat array
// (element (i, j) lives at index i * N + j).
#define N         1024
// Number of elements in one matrix.
#define SIZE      ((N) * (N))
// First index of _time: which side (host CPU or device) a timing belongs to.
#define HOST      0
#define DEV       1
// Number of repetitions; the reported times are the accumulated times divided by LOOP.
#define LOOP      10
// Second index of _time: which phase is being timed.
#define INIT      0
#define KERNEL    1
#define DATA      2
// Threads per CUDA block used when launching DevMM.
#define BLOCKSIZE 1024

// Element type of all matrices.
typedef double MyData;

// Accumulated thread CPU times in seconds, indexed as _time[HOST|DEV][INIT|KERNEL|DATA].
static double _time[2][3];
// Host team sizes recorded by the host kernels: thr[0] by InitHost, thr[1] by HostMM.
static int thr[2];

// Returns the CPU time consumed so far by the whole process (all threads),
// in seconds, as reported by CLOCK_PROCESS_CPUTIME_ID.
double process_time()
{
  struct timespec now;
  clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &now);

  const double whole_seconds = (double) now.tv_sec;
  const double nanoseconds   = (double) now.tv_nsec;

  return whole_seconds + (nanoseconds * 1.0e-9);
}

// Returns the CPU time consumed so far by the calling thread only, in
// seconds, as reported by CLOCK_THREAD_CPUTIME_ID. Time the thread spends
// blocked or sleeping is not counted.
double thread_time()
{
  struct timespec now;
  clock_gettime(CLOCK_THREAD_CPUTIME_ID, &now);

  const double whole_seconds = (double) now.tv_sec;
  const double nanoseconds   = (double) now.tv_nsec;

  return whole_seconds + (nanoseconds * 1.0e-9);
}

// InitHost: fills the N x N host matrices with A = 1, B = 1/pi, C = 0 using a
// (possibly nested) OpenMP team.
//
// A, B, C : host arrays of SIZE elements each, written in full.
// thr     : receives the size of the team that performed the fill.
//
// The master thread's CPU time spent between the two leading barriers and the
// end of the work-shared loop is added to _time[HOST][INIT].
void InitHost(MyData *const restrict A,
              MyData *const restrict B,
              MyData *const restrict C,
              int    *const restrict thr)
{
  // Declared outside the parallel region, hence shared; only the master
  // thread ever writes or reads it.
  double t0;

  #pragma omp parallel
  {
    // Make sure the whole team exists before the master starts its clock.
    #pragma omp barrier
    #pragma omp master
    {
      *thr = omp_get_num_threads();
      t0   = thread_time();
    }
    #pragma omp barrier

    #pragma omp for collapse(2)
    for (int row = 0; row < N; row++)
      for (int col = 0; col < N; col++)
        {
          const int idx = (row * N) + col;
          A[idx] = 1.0;
          B[idx] = (1.0 / M_PI);
          C[idx] = 0.0;
        }
    // The implicit barrier of the `for` construct guarantees the whole team
    // has finished before the master stops its clock.

    #pragma omp master
    {
      _time[HOST][INIT] += (thread_time() - t0);
    }
  } // omp parallel
}

// InitDev: fills the N x N device matrices with A = 1, B = 1/pi, C = 0 using
// an OpenMP target region.
//
// A, B, C : device addresses (main passes pointers obtained from
//           omp_target_alloc), each spanning SIZE elements. Because they are
//           already device pointers they are declared is_device_ptr rather
//           than mapped.
//
// The target construct has no nowait clause, so it completes before control
// returns here; the calling thread's CPU time is added to _time[DEV][INIT].
void InitDev(MyData *const restrict A,
             MyData *const restrict B,
             MyData *const restrict C)
{
  const double t0 = thread_time();

  #pragma omp target teams loop collapse(2) is_device_ptr(A, B, C)
  for (int row = 0; row < N; row++)
    for (int col = 0; col < N; col++)
      {
        const int idx = (row * N) + col;
        A[idx] = 1.0;
        B[idx] = (1.0 / M_PI);
        C[idx] = 0.0;
      }

  _time[DEV][INIT] += (thread_time() - t0);
}

// HostMM: naive O(N^3) host matrix product C = A * B for row-major N x N
// matrices (no alpha/beta scaling: C is simply overwritten), computed by a
// (possibly nested) OpenMP team.
//
// A, B : host input arrays of SIZE elements, only read.
// C    : host output array of SIZE elements, fully overwritten.
// thr  : receives the size of the team that performed the product.
//
// The master thread's CPU time for the work-shared product is added to
// _time[HOST][KERNEL].
void HostMM(MyData *const restrict A,
            MyData *const restrict B,
            MyData *const restrict C,
            int    *const restrict thr)
{
  // Shared across the team; written and read by the master thread only.
  double t0;

  #pragma omp parallel
  {
    // Make sure the whole team exists before the master starts its clock.
    #pragma omp barrier
    #pragma omp master
    {
      *thr = omp_get_num_threads();
      t0   = thread_time();
    }
    #pragma omp barrier

    #pragma omp for collapse(2)
    for (int row = 0; row < N; row++)
      for (int col = 0; col < N; col++)
        {
          const MyData *const a_row = A + (row * N);
          MyData acc = 0.0;
          for (int k = 0; k < N; k++)
            acc += a_row[k] * B[(k * N) + col];

          C[(row * N) + col] = acc;
        }
    // Implicit barrier of the `for` construct: all rows are done here.

    #pragma omp master
    {
      _time[HOST][KERNEL] += (thread_time() - t0);
    }
  } // omp parallel
}

// DevMM: naive dense product C = A * B for row-major n x n matrices.
//
// Launch layout: 1-D grid of 1-D blocks; each thread computes whole elements
// of C (element (i, j) at flat index i * n + j). A grid-stride loop is used,
// so any grid size is correct; main launches ceil(n*n / BLOCKSIZE) blocks of
// BLOCKSIZE threads, i.e. one thread per element. No shared memory is used.
//
// A, B : device-accessible inputs of n*n elements, only read.
// C    : device-accessible output of n*n elements, fully overwritten.
// n    : matrix dimension; every index below is derived from n (not from the
//        N macro), so the kernel is correct for any n > 0. n <= 0 is a no-op.
//
// Indices are computed in size_t so that n*n cannot overflow int.
// Adjacent threads get adjacent j, so the loads of B and the stores to C are
// contiguous across a warp, while A[i * n + k] is shared by the threads of a row.
__global__ void DevMM(const MyData *const restrict A,
                      const MyData *const restrict B,
                      MyData       *const restrict C,
                      const int                    n)
{
  if (n <= 0)
    return;

  const size_t dim    = (size_t) n;
  const size_t size   = dim * dim;
  const size_t stride = (size_t) gridDim.x * blockDim.x;

  for (size_t gid = ((size_t) blockIdx.x * blockDim.x) + threadIdx.x ; gid < size ; gid += stride)
    {
      const size_t i = (gid / dim);
      const size_t j = (gid % dim);

      MyData sum = 0.0;
      for (size_t k=0 ; k<dim ; k++)
        sum += (A[(i * dim) + k] * B[(k * dim) + j]);

      C[(i * dim) + j] = sum;
    }
}

void check(MyData *const restrict host_array,
	   MyData *const restrict dev_array)
{
  int flag = 0;
  for (size_t i=0 ; i<SIZE ; i++)
    flag = ((fabs(host_array[i] - dev_array[i]) > FLT_EPSILON) ? 1 : flag);

  if (!flag)
    printf("\n\t Result OK \n");
  else
    printf("\n\t Result wrong \n");
  
  return;
}

// Reports a failed CUDA runtime call and terminates the program.
// Kernel launches return no status, so callers pass cudaGetLastError() right
// after a launch and the result of the following synchronising call.
static void CheckCuda(const cudaError_t  err,
                      const char *const what)
{
  if (err != cudaSuccess)
    {
      fprintf(stderr, "\n\t CUDA error (%s): %s\n", what, cudaGetErrorString(err));
      exit(EXIT_FAILURE);
    }
}

// Computes the same N x N product twice, concurrently:
//  - one host thread drives a nested OpenMP team that computes C_HOST = A * B
//    on the CPU (InitHost + HostMM);
//  - a second host thread allocates device memory with omp_target_alloc,
//    initialises it with an OpenMP target region (InitDev), passes the raw
//    device pointers to the CUDA kernel DevMM, and copies the result back
//    into C_DEV with omp_target_memcpy.
// Finally the two results are compared and the average timings are printed.
//
// NOTE(review): all timings use thread_time() (CLOCK_THREAD_CPUTIME_ID), i.e.
// the CPU time of the calling thread, not elapsed wall-clock time. If the
// thread yields while waiting in cudaDeviceSynchronize() or for the target
// region, the device timings undercount. omp_get_wtime() would measure
// elapsed time instead.
int main()
{
  // Host allocation: C_HOST receives the CPU result, C_DEV the copy of the GPU result
  MyData *h_buffer = (MyData *)malloc(2 * SIZE * sizeof(MyData));
  assert(h_buffer != NULL);
  MyData *const restrict C_HOST = h_buffer;
  MyData *const restrict C_DEV  = C_HOST + SIZE;

  // Spawning 2 host threads. Because both singles are `nowait`, the two
  // blocks are typically executed by different threads and therefore overlap.
  #pragma omp parallel num_threads(2)
  {
    // Evaluate the Dgemm on the host
    #pragma omp single nowait
    {
      // allowing nested parallelism (InitHost/HostMM open their own team)
      omp_set_max_active_levels(2);

      MyData *tmp = (MyData *)malloc(2 * SIZE * sizeof(MyData));
      assert(tmp != NULL);
      MyData *const restrict A = tmp;
      MyData *const restrict B = A + SIZE;

      for (int loop=0 ; loop<LOOP ; loop++)
        {
          InitHost(A, B, C_HOST, &thr[0]);
          HostMM(A, B, C_HOST, &thr[1]);
        }

      free(tmp);
    } // omp single

    // Evaluate the Dgemm on the device
    #pragma omp single nowait
    {
      // Device allocation
      const int dev  = omp_get_default_device();
      const int host = omp_get_initial_device();
      MyData *d_buffer = (MyData *)omp_target_alloc((3 * SIZE * sizeof(MyData)), dev);
      assert(d_buffer != NULL);
      MyData *const restrict d_A = d_buffer;
      MyData *const restrict d_B = d_A + SIZE;
      MyData *const restrict d_C = d_B + SIZE;

      // one CUDA thread per element of C
      const dim3 nblock = {((SIZE + BLOCKSIZE - 1) / BLOCKSIZE), 1, 1};
      const dim3 block  = {BLOCKSIZE, 1, 1};

      for (int loop=0 ; loop<LOOP ; loop++)
        {
          // Init device with blocking omp target directive
          InitDev(d_A, d_B, d_C);

          double start = thread_time();
          DevMM<<< nblock, block >>>(d_A, d_B, d_C, N);
          // launch-time errors (e.g. invalid configuration, no usable device)
          CheckCuda(cudaGetLastError(), "DevMM launch");
          // CUDA synchronization point; also surfaces faults raised while the kernel ran
          CheckCuda(cudaDeviceSynchronize(), "DevMM execution");
          _time[DEV][KERNEL] += (thread_time() - start);

          // Fetch data from the device (omp_target_memcpy returns 0 on success)
          start = thread_time();
          const int ret = omp_target_memcpy(C_DEV, d_C, (SIZE * sizeof(MyData)), 0, 0, host, dev);
          _time[DEV][DATA] += (thread_time() - start);
          if (ret != 0)
            {
              fprintf(stderr, "\n\t omp_target_memcpy failed (device %d -> host %d)\n", dev, host);
              exit(EXIT_FAILURE);
            }
        } // LOOP

      // deallocate device's memory
      omp_target_free(d_buffer, dev);
    } // omp single
  } // synchronization point (implicit barrier at the end of the parallel region)

  check(C_HOST, C_DEV);

  free(h_buffer);

  printf("\n\t Matrix size: %d x %d\n", N, N);

  printf("\n\t Host execution time:");
  printf("\n\t\t Init      : %lg [s] - threads: %d", _time[HOST][INIT]/LOOP, thr[0]);
  printf("\n\t\t Dgemm     : %lg [s] - threads: %d\n", _time[HOST][KERNEL]/LOOP, thr[1]);

  printf("\n\t Device execution time:");
  printf("\n\t\t Init      : %lg [s]", _time[DEV][INIT]/LOOP);
  printf("\n\t\t Dgemm     : %lg [s]", _time[DEV][KERNEL]/LOOP);
  printf("\n\t\t Fetch data: %lg [s]\n\n", _time[DEV][DATA]/LOOP);

  return 0;
}
