/* Copyright (C) 2024   INAF - Osservatorio Astronomico di Cagliari

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.
   
   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.
   
   A copy of the GNU General Public License is distributed along with
   this program in the COPYING file. If not, see: <https://www.gnu.org/licenses/>.
 */

/*! \file magma_calls.cpp
 *
 * \brief Implementation of the interface with MAGMA libraries.
 */
#ifndef INCLUDE_TYPES_H_
#include "../include/types.h"
#endif

#ifdef USE_MAGMA
#ifndef INCLUDE_MAGMA_CALLS_H_
#include "../include/magma_calls.h"
#endif

#ifdef USE_MAGMA_SVD
#include "magma_operators.h"
#endif

/*! \brief Invert a complex matrix in place through MAGMA LU factorisation on GPU.
 *
 * The matrix is copied to the device, factorised with `magma_zgetrf_gpu()`,
 * inverted with `magma_zgetri_gpu()` (only if the factorisation succeeded) and
 * copied back over the host storage. The copy back happens whenever the device
 * buffers could be allocated, also when the factorisation reported an error.
 *
 * \param mat: `dcomplex **` Matrix to be inverted. The function reads and writes
 * `n * n` consecutive elements starting at `mat[0][0]`, so the storage must be
 * contiguous.
 * \param n: `np_int` Order of the matrix.
 * \param jer: `int &` Receives the last MAGMA status code (`MAGMA_SUCCESS` when
 * every step succeeded).
 * \param device_id: `int` ID of the device on which the MAGMA queue is created.
 */
void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
  magma_int_t err = MAGMA_SUCCESS;
  magma_queue_t queue = NULL;
  magma_device_t dev = (magma_device_t)device_id;
  magma_queue_create(dev, &queue);
  magma_int_t m = (magma_int_t)n; // order of the m x m matrix
  magmaDoubleComplex *a = (magmaDoubleComplex *)&(mat[0][0]); // first element on host
  magmaDoubleComplex *d_a = NULL; // matrix on device
  magmaDoubleComplex *dwork = NULL; // inversion workspace on device
  magma_int_t ldwork = m * magma_get_zgetri_nb(m); // workspace size from optimal block size

  // The element count is computed as size_t, so that it cannot overflow
  // when magma_int_t is a 32-bit integer.
  err = magma_zmalloc(&d_a, (size_t)m * (size_t)m);
  if (err == MAGMA_SUCCESS) {
    err = magma_zmalloc(&dwork, ldwork);
  }

  if (err == MAGMA_SUCCESS) {
    magma_int_t *piv = new magma_int_t[m]; // pivot indices (host memory)
    magma_zsetmatrix(m, m, a, m, d_a, m, queue); // copy a -> d_a

    magma_zgetrf_gpu(m, m, d_a, m, piv, &err);
    if (err == MAGMA_SUCCESS) {
      magma_zgetri_gpu(m, d_a, m, piv, dwork, ldwork, &err);
    }

    magma_zgetmatrix(m, m, d_a, m, a, m, queue); // copy d_a -> a
    delete[] piv;
  }

  // Release every device buffer that was obtained, including the workspace.
  if (dwork != NULL) {
    magma_free(dwork);
  }
  if (d_a != NULL) {
    magma_free(d_a);
  }
  magma_queue_destroy(queue);
  jer = (int)err;
}

#ifdef USE_MAGMA_SVD
/*! \brief Invert a complex matrix in place through MAGMA singular value decomposition.
 *
 * The matrix is decomposed with `magma_zgesvd()`. The singular values are
 * replaced by their reciprocals (exact zeros are left as zeros), the factors
 * are conjugated and multiplied on the host, and the product is finally
 * transposed in place on the device before being copied back to the host.
 *
 * The contents of the matrix are overwritten by the decomposition, therefore
 * they are not preserved when an error is reported.
 *
 * \param mat: `dcomplex **` Matrix to be inverted. The function reads and writes
 * `n * n` consecutive elements starting at `mat[0][0]`, so the storage must be
 * contiguous.
 * \param n: `np_int` Order of the matrix.
 * \param jer: `int &` Receives the last MAGMA status code (`MAGMA_SUCCESS` when
 * every step succeeded).
 * \param device_id: `int` ID of the device on which the MAGMA queue is created.
 */
void magma_svd_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
  magma_int_t err = MAGMA_SUCCESS;
  magma_queue_t queue = NULL;
  magma_device_t dev = (magma_device_t)device_id;
  magma_queue_create(dev, &queue);
  const magma_int_t m = (magma_int_t)n; // order of the m x m matrix
  const size_t mm = (size_t)m * (size_t)m; // number of matrix elements
  magmaDoubleComplex *a = reinterpret_cast<magmaDoubleComplex *>(&mat[0][0]);
  double *s = NULL; // singular values
  magmaDoubleComplex *u = NULL; // left singular vectors
  magmaDoubleComplex *vt = NULL; // right singular vectors, as returned by zgesvd
  double *rwork = NULL; // real workspace of zgesvd
  magmaDoubleComplex *work = NULL; // complex workspace of zgesvd
  magmaDoubleComplex *d_a = NULL; // device buffer for the final transposition
  magmaDoubleComplex work_query[1]; // receives the optimal workspace size
  magma_int_t lwork = -1;

  // Host buffers; every allocation is checked before the next step is taken.
  err = magma_dmalloc_cpu(&s, m);
  if (err == MAGMA_SUCCESS) {
    err = magma_zmalloc_cpu(&u, mm);
  }
  if (err == MAGMA_SUCCESS) {
    err = magma_zmalloc_cpu(&vt, mm);
  }
  if (err == MAGMA_SUCCESS) {
    err = magma_dmalloc_cpu(&rwork, 5 * m);
  }

  // Workspace query: with lwork = -1 only the first element of the work
  // array is written, hence a single element is sufficient.
  if (err == MAGMA_SUCCESS) {
    magma_zgesvd(
      MagmaAllVec, MagmaAllVec, m, m, a, m, s, u, m, vt, m, work_query, -1, rwork, &err
    );
  }
  if (err == MAGMA_SUCCESS) {
    lwork = (magma_int_t)real(work_query[0]);
    err = magma_zmalloc_cpu(&work, lwork);
  }

  // Actual decomposition.
  if (err == MAGMA_SUCCESS) {
    magma_zgesvd(
      MagmaAllVec, MagmaAllVec, m, m, a, m, s, u, m, vt, m, work, lwork, rwork, &err
    );
  }

  // The inverse is built only from a successful decomposition.
  if (err == MAGMA_SUCCESS) {
    // Reciprocal singular values; exact zeros are kept as zeros.
    for (magma_int_t si = 0; si < m; si++) {
      s[si] = (s[si] == 0.0) ? 0.0 : 1.0 / s[si];
    }

    // Conjugate both factors and scale each block of m elements of u by
    // the matching reciprocal singular value.
    for (magma_int_t ri = 0; ri < m; ri++) {
      for (magma_int_t rj = 0; rj < m; rj++) {
        const magma_int_t idx = m * ri + rj;
        u[idx] = MAGMA_Z_MAKE(s[ri] * real(u[idx]), -s[ri] * imag(u[idx]));
        vt[idx] = MAGMA_Z_MAKE(real(vt[idx]), -imag(vt[idx]));
      }
    }

    // Product of the two modified factors, written over the input matrix.
    for (magma_int_t mi = 0; mi < m; mi++) {
      for (magma_int_t mj = 0; mj < m; mj++) {
        magmaDoubleComplex value = MAGMA_Z_MAKE(0.0, 0.0);
        for (magma_int_t mk = 0; mk < m; mk++) {
          value = MAGMA_Z_ADD(value, MAGMA_Z_MUL(vt[m * mi + mk], u[m * mk + mj]));
        }
        a[m * mi + mj] = value;
      }
    }

    // In-place transposition of the product, executed on the device.
    err = magma_zmalloc(&d_a, mm);
    if (err == MAGMA_SUCCESS) {
      magma_zsetmatrix(m, m, a, m, d_a, m, queue);
      magmablas_ztranspose_inplace(m, d_a, m, queue);
      magma_zgetmatrix(m, m, d_a, m, a, m, queue);
      magma_free(d_a);
    }
  }

  magma_free_cpu(work);
  magma_free_cpu(rwork);
  magma_free_cpu(vt);
  magma_free_cpu(u);
  magma_free_cpu(s);
  magma_queue_destroy(queue);
  jer = (int)err;
}
#endif // USE_MAGMA_SVD
#endif // USE_MAGMA
