/* Copyright (C) 2024   INAF - Osservatorio Astronomico di Cagliari

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.
   
   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.
   
   A copy of the GNU General Public License is distributed along with
   this program in the COPYING file. If not, see: <https://www.gnu.org/licenses/>.
 */

/*! \file magma_calls.cpp
 *
 * \brief Implementation of the interface with MAGMA libraries.
 */
#ifndef INCLUDE_TYPES_H_
#include "../include/types.h"
#endif

#ifdef USE_MAGMA

#ifndef INCLUDE_MAGMA_CALLS_H_
#include "../include/magma_calls.h"
#endif

#include <limits>

/*! \brief Invert a complex square matrix in place on a GPU through MAGMA.
 *
 * The matrix is uploaded to the selected device, LU-factorised with
 * `magma_zgetrf_gpu()`, inverted with `magma_zgetri_gpu()` and, if both steps
 * succeed, downloaded back over the host storage.
 *
 * The host data is addressed through `&(mat[0][0])`, i.e. the function assumes
 * that the `n x n` elements are stored contiguously starting at `mat[0]`.
 *
 * A failure to allocate device memory prints a message and terminates the
 * process with `exit(1)`.
 *
 * \param mat: `dcomplex **` Matrix to be inverted; overwritten with its inverse
 * on success, left untouched if the factorisation or the inversion fails.
 * \param n: `np_int` Number of rows (and columns) of the matrix.
 * \param jer: `int &` Result flag: 0 on success, otherwise the non-zero `info`
 * value reported by the failing LAPACK-style routine.
 * \param device_id: `int` ID of the MAGMA device used for the computation.
 */
void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
  magma_int_t err = MAGMA_SUCCESS;
  magma_int_t info = 0; // status reported by the factorisation / inversion
  magma_queue_t queue = NULL;
  magma_device_t dev = (magma_device_t)device_id;
  magma_queue_create(dev, &queue);

  const magma_int_t m = (magma_int_t)n; // matrix order
  const magma_int_t mm = m * m; // number of matrix elements
  const magma_int_t ldwork = m * magma_get_zgetri_nb(m); // workspace size for zgetri
  magmaDoubleComplex *a = (magmaDoubleComplex *)&(mat[0][0]); // first element on host
  magmaDoubleComplex *d_a = NULL; // matrix on device
  magmaDoubleComplex *dwork = NULL; // zgetri workspace on device

  // Allocate device buffers.
  err = magma_zmalloc(&d_a, mm);
  if (err != MAGMA_SUCCESS) {
    printf("Error allocating d_a\n");
    exit(1);
  }
  err = magma_zmalloc(&dwork, ldwork);
  if (err != MAGMA_SUCCESS) {
    printf("Error allocating dwork\n");
    exit(1);
  }
  magma_int_t *piv = new magma_int_t[m]; // pivot indices (host memory)

  magma_zsetmatrix(m, m, a, m, d_a, m, queue); // copy a -> d_a

  // LU factorisation, followed by the in-place inversion only if the
  // factorisation succeeded (the factors are unusable otherwise).
  magma_zgetrf_gpu(m, m, d_a, m, piv, &info);
  if (info == 0) {
    magma_zgetri_gpu(m, d_a, m, piv, dwork, ldwork, &info);
  }

  if (info == 0) {
    magma_zgetmatrix(m, m, d_a, m, a, m, queue); // copy d_a -> a
  } else {
    // Report the failure to the caller instead of silently returning success.
    err = info;
  }

  delete[] piv; // free host memory
  magma_free(d_a); // free device memory
  magma_free(dwork);
  magma_queue_destroy(queue);
  jer = (int)err;
}

void magma_zinvert_and_refine(dcomplex **mat, np_int n, int &jer, int &maxiters, double &accuracygoal, int refinemode, int device_id) {
  // magma_int_t result = magma_init();
  magma_int_t err = MAGMA_SUCCESS;
  magma_queue_t queue = NULL;
  magma_device_t dev = (magma_device_t)device_id;
  magma_queue_create(dev, &queue);
  magmaDoubleComplex *dwork; // workspace
  magma_int_t ldwork; // size of dwork
  magma_int_t *piv , info; // array of pivot indices
  magma_int_t m = (magma_int_t)n; // changed rows; a - mxm matrix
  magma_int_t mm = m * m; // size of a
  magmaDoubleComplex *a = (magmaDoubleComplex *)&(mat[0][0]); // pointer to first element on host
  magmaDoubleComplex *d_a; // pointer to first element on device
  magmaDoubleComplex *d_a_orig; // pointer to original array on device
  magmaDoubleComplex *d_a_residual; // pointer to residual array on device
  magmaDoubleComplex *d_a_refine; // pointer to residual array on device
  magmaDoubleComplex *d_id; // pointer to the diagonal of identity matrix
  ldwork = m * magma_get_zgetri_nb(m); // optimal block size
  // allocate matrices
  // magmaDoubleComplex *a_unref = new magmaDoubleComplex[mm]; 
  err = magma_zmalloc(&d_a, mm); // device memory for a, will contain the inverse after call to zgetri
  if (err != MAGMA_SUCCESS) {
    printf("Error allocating d_a\n");
    exit(1);
  }
  if (maxiters>0) {
    err = magma_zmalloc(&d_a_orig, mm); // device memory for copy of a
    if (err != MAGMA_SUCCESS) {
      printf("Error allocating d_a_orig\n");
      exit(1);
    }
  }
  err = magma_zmalloc(&dwork, ldwork); // dev. mem. for ldwork
  if (err != MAGMA_SUCCESS) {
    printf("Error allocating dwork\n");
    exit(1);
  }
  piv = new magma_int_t[m]; // host mem.
  magma_zsetmatrix(m, m, a, m, d_a , m, queue); // copy a -> d_a
  if (maxiters>0) {
    magma_zcopy(mm, d_a, 1, d_a_orig, 1, queue); // copy d_a -> d_a_orig on gpu
  }
  // do the LU factorisation
  magma_zgetrf_gpu(m, m, d_a, m, piv, &info);
  // do the in-place inversion, after which d_a contains the (first approx) inverse
  magma_zgetri_gpu(m, d_a, m, piv, dwork, ldwork, &info);
  // magma_zgetmatrix(m, m, d_a , m, a_unref, m, queue); // copy unrefined d_a -> a_unref
  magma_free(dwork); // free dwork, it was only needed by zgetri
  if (maxiters>0) {
    // allocate memory for the temporary matrix products
    err = magma_zmalloc(&d_a_residual, mm); // device memory for iterative correction of inverse of a
    if (err != MAGMA_SUCCESS) {
      printf("Error allocating d_a_residual\n");
      exit(1);
    }
    err = magma_zmalloc(&d_a_refine, mm); // device memory for iterative correction of inverse of a
    if (err != MAGMA_SUCCESS) {
      printf("Error allocating d_a_refine\n");
      exit(1);
    }
    // allocate memory for the identity vector on the host
    {
      dcomplex *native_id = new dcomplex[1];
      native_id[0] = 1;
      magmaDoubleComplex *id = (magmaDoubleComplex *) &(native_id[0]);
      // fill it with 1
      err = magma_zmalloc(&d_id, 1);
      if (err != MAGMA_SUCCESS) {
	printf("Error allocating d_id\n");
	exit(1);
      }
      magma_zsetvector(1, id, 1, d_id, 1, queue); // copy identity to device vector
      delete[] native_id; // free identity vector on host
    }
  }
  bool iteraterefine = true;
  if (maxiters>0) {
    magmaDoubleComplex magma_mone;
    magma_mone.x = -1;
    magma_mone.y = 0;
    magmaDoubleComplex magma_one;
    magma_one.x = 1;
    magma_one.y = 0;
    magmaDoubleComplex magma_zero;
    magma_zero.x = 0;
    magma_zero.y = 0;
    // multiply minus the original matrix times the inverse matrix
    // NOTE: factors in zgemm are swapped because zgemm is designed for column-major
    // Fortran-style arrays, whereas our arrays are C-style row-major.
    magma_zgemm(MagmaNoTrans, MagmaNoTrans, m, m, m,  magma_mone, d_a, m, d_a_orig, m, magma_zero, d_a_residual, m, queue);
    // add the identity to the product
    magma_zaxpy (m, magma_one, d_id, 0, d_a_residual, m+1, queue);
    double oldmax=0;
    if (refinemode >0) {
      // find the maximum absolute value of the residual
      magma_int_t maxindex = magma_izamax(mm, d_a_residual, 1, queue);
      magmaDoubleComplex magmamax;
      // transfer the maximum value to the host
      magma_zgetvector(1, d_a_residual+maxindex, 1, &magmamax, 1, queue);
      // take the module
      oldmax = cabs(magmamax.x + I*magmamax.y);
      printf("Initial max residue = %g\n", oldmax);
      if (oldmax < accuracygoal) iteraterefine = false;
    }
    // begin correction loop (should iterate maxiters times)
    int iter;
    for (iter=0; (iter<maxiters) && iteraterefine; iter++) {
      // multiply the inverse times the residual, add to the initial inverse
      magma_zgemm(MagmaNoTrans, MagmaNoTrans, m, m, m, magma_one, d_a_residual, m, d_a, m, magma_zero, d_a_refine, m, queue);
      // add to the initial inverse
      magma_zaxpy (mm, magma_one, d_a_refine, 1, d_a, 1, queue);
      // multiply minus the original matrix times the new inverse matrix
      magma_zgemm(MagmaNoTrans, MagmaNoTrans, m, m, m, magma_mone, d_a, m, d_a_orig, m, magma_zero, d_a_residual, m, queue);
      // add the identity to the product
      magma_zaxpy (m, magma_one, d_id, 0, d_a_residual, m+1, queue);
      if ((refinemode==2) || ((refinemode==1) && (iter == (maxiters-1)))) {
	// find the maximum absolute value of the residual
	magma_int_t maxindex = magma_izamax(mm, d_a_residual, 1, queue);
	// transfer the maximum value to the host
	magmaDoubleComplex magmamax;
	magma_zgetvector(1, d_a_residual+maxindex, 1, &magmamax, 1, queue);
	// take the module
	double newmax = cabs(magmamax.x + I*magmamax.y);
	printf("Max residue after %d iterations = %g\n", iter+1, newmax);
	// if the maximum in the residual decreased from the previous iteration,
	// update oldmax and go on, otherwise no point further iterating refinements
	if ((refinemode==2) && ((newmax > oldmax)||(newmax < accuracygoal))) iteraterefine = false;
	oldmax = newmax;
      }
    }
    // if we are being called with refinemode=2, then on exit we set maxiters to the actual number of iters we performed to achieve the required accuracy
    if (refinemode==2) maxiters = iter;
    accuracygoal = oldmax;
    // end correction loop
  }
  // free temporary device arrays 
  magma_zgetmatrix(m, m, d_a , m, a, m, queue); // copy final refined d_a -> a
  delete[] piv; // free host memory
  magma_free(d_a); // free device memory
  // delete[] a_unref;
  if (maxiters>0) {
    magma_free(d_id);
    magma_free(d_a_orig); // free device memory
    magma_free(d_a_residual); // free device memory
    magma_free(d_a_refine); // free device memory
  }
  magma_queue_destroy(queue); // destroy queue
  // result = magma_finalize();
  jer = (int)err;
}


#endif
