/* Copyright (C) 2024   INAF - Osservatorio Astronomico di Cagliari

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.
   
   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.
   
   A copy of the GNU General Public License is distributed along with
   this program in the COPYING file. If not, see: <https://www.gnu.org/licenses/>.
 */

/*! \file magma_calls.cpp
 *
 * \brief Implementation of the interface with MAGMA libraries.
 */
#ifndef INCLUDE_TYPES_H_
#include "../include/types.h"
#endif

#ifdef USE_MAGMA
#ifndef INCLUDE_MAGMA_CALLS_H_
#include "../include/magma_calls.h"
#endif

/*! \brief Invert a complex matrix in place on the GPU through MAGMA.
 *
 * The matrix is copied to device memory, LU-factorised with
 * `magma_zgetrf_gpu()`, inverted with `magma_zgetri_gpu()` and copied back
 * over the host data. The host matrix is overwritten only if every step
 * succeeded; on failure it is left untouched and the error is reported
 * through `jer`.
 *
 * \param mat: `dcomplex **` Matrix to invert. The elements are read from and
 * written to the `n * n` values starting at `mat[0][0]`, so the rows are
 * assumed to be stored contiguously (TODO: confirm against the allocator used
 * by the callers).
 * \param n: `np_int` Order of the matrix.
 * \param jer: `int &` Result flag: 0 on success, otherwise the MAGMA error
 * code of the failed device allocation or the `info` value returned by the
 * failed factorisation / inversion routine.
 * \param device_id: `int` ID of the device on which the MAGMA queue is created.
 */
void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
  magma_int_t err = MAGMA_SUCCESS;
  magma_int_t info = 0; // status reported by the factorisation / inversion
  magma_int_t m = (magma_int_t)n; // a is an m x m matrix
  magma_int_t mm = m * m; // number of elements of a
  magma_int_t ldwork = m * magma_get_zgetri_nb(m); // workspace size for zgetri
  magmaDoubleComplex *a = (magmaDoubleComplex *)&(mat[0][0]); // first element on host
  magmaDoubleComplex *d_a = NULL; // matrix on device
  magmaDoubleComplex *dwork = NULL; // workspace on device
  // Host pivot array, allocated before any GPU resource is acquired so that
  // an exception thrown by a failed host allocation cannot leak them.
  magma_int_t *piv = new magma_int_t[m];
  magma_queue_t queue = NULL;
  magma_device_t dev = (magma_device_t)device_id;
  magma_queue_create(dev, &queue);

  // Allocate device memory, stopping at the first failure so that its error
  // code is not overwritten by a later call.
  err = magma_zmalloc(&d_a, mm);
  if (err == MAGMA_SUCCESS) {
    err = magma_zmalloc(&dwork, ldwork);
  }

  if (err == MAGMA_SUCCESS) {
    magma_zsetmatrix(m, m, a, m, d_a, m, queue); // copy a -> d_a
    magma_zgetrf_gpu(m, m, d_a, m, piv, &info); // LU factorisation
    if (info == 0) {
      // Invert only a successfully factorised matrix.
      magma_zgetri_gpu(m, d_a, m, piv, dwork, ldwork, &info);
    }
    if (info == 0) {
      magma_zgetmatrix(m, m, d_a, m, a, m, queue); // copy d_a -> a
    }
    err = info; // non-zero info means the inversion did not complete
  }

  // Release every resource that was actually acquired.
  if (dwork != NULL) {
    magma_free(dwork);
  }
  if (d_a != NULL) {
    magma_free(d_a);
  }
  magma_queue_destroy(queue);
  delete[] piv;
  jer = (int)err;
}
#endif
