#include <iostream>
#include <vector>
#include <omp.h>
#include <cassert>
#include <new>
#include <algorithm>

#include "allvars.hpp"

int main()
{
  // allocate memory on the host and set the global pointer
  global_ptr = new (std::nothrow) MyData** [X];
  assert(global_ptr != nullptr);
  for (std::size_t x=0 ; x<X ; x++)
    {
      global_ptr[x] = new (std::nothrow) MyData* [Y];
      assert(global_ptr[x] != nullptr);

      for (std::size_t y=0 ; y<Y ; y++)
	{
	  global_ptr[x][y] = new (std::nothrow) MyData [Z];
	  assert(global_ptr[x][y] != nullptr);

	  for (std::size_t z=0 ; z<Z ; z++)
	    {
	      global_ptr[x][y][z] = static_cast<MyData>(z + 1);
	    } // loop over Z
	} // loop over Y
    } // loop over X

  std::cout << "\n\t global_ptr allocated on the host \n" << std::endl;
  
  // Allocate memory on the device and set the global pointer
  #pragma omp target enter data map(to: global_ptr[0:X][0:Y][0:Z])

  #pragma omp target teams distribute parallel for
  for (std::size_t index=0 ; index<Z ; index++)
    {
      MyData diff = 0;
      for (std::size_t x=0 ; x<X ; x++)
	for (std::size_t y=0 ; y<Y ; y++)
	  diff += global_ptr[x][y][index];

      for (std::size_t x=0 ; x<X ; x++)
	for (std::size_t y=0 ; y<Y ; y++)
	  global_ptr[x][y][index] = (diff / (X * Y * (index + 1)));
    } // kernel
  
  // Device-host synchronization
  #pragma omp target update from(global_ptr[0:X][0:Y][0:Z])

  // Check if any element along Z is equal to 1
  for (std::size_t x=0 ; x<X ; x++)
    for (std::size_t y=0 ; y<Y ; y++)
      {
	const bool One = std::all_of(&global_ptr[x][y][0], &global_ptr[x][y][Z], [](const MyData x) {return (x == 1);});
	if (One == false)
	  {
	    std::cout << "\n\t Test failed \n" << std::endl;
	    return -1;
	  }
      }

  // Deallocate memory on the device
  #pragma omp target exit data map(delete: global_ptr[0:X][0:Y][0:Z])

  // deallocate host memory
  for (std::size_t x=0 ; x<X ; x++)
    {
      for (std::size_t y=0 ; y<Y ; y++)
	{
	  delete[] global_ptr[x][y];
	}
      delete[] global_ptr[x];
    }
  delete[] global_ptr;

  std::cout << "\n\t Test OK! \n" << std::endl;
  
  return 0;
}
