From 547678ca38138cc761055e1bc0343a7c9b3672a5 Mon Sep 17 00:00:00 2001
From: Giovanni La Mura <giovanni.lamura@inaf.it>
Date: Tue, 7 May 2024 15:03:03 +0200
Subject: [PATCH] Create an interface to MAGMA library

---
 src/include/magma_calls.h    | 23 ++++++++++++++++
 src/include/types.h          | 12 ++++++---
 src/libnptm/Makefile         |  6 ++---
 src/libnptm/algebraic.cpp    | 10 ++++++-
 src/libnptm/lapack_calls.cpp |  5 ++--
 src/libnptm/magma_calls.cpp  | 51 ++++++++++++++++++++++++++++++++++++
 src/make.inc                 | 18 +++++++++++++
 7 files changed, 115 insertions(+), 10 deletions(-)
 create mode 100644 src/include/magma_calls.h
 create mode 100644 src/libnptm/magma_calls.cpp

diff --git a/src/include/magma_calls.h b/src/include/magma_calls.h
new file mode 100644
index 00000000..e9bc4e04
--- /dev/null
+++ b/src/include/magma_calls.h
@@ -0,0 +1,23 @@
+/* Distributed under the terms of GPLv3 or later. See COPYING for details. */
+
+/*! \file magma_calls.h
+ *
+ * \brief C++ interface to MAGMA calls.
+ *
+ */
+
+#ifndef INCLUDE_MAGMA_CALLS_H_
+#define INCLUDE_MAGMA_CALLS_H_
+
+/*! \brief Invert a complex matrix with double precision elements.
+ *
+ * Use LAPACKE64 to perform an in-place matrix inversion for a complex
+ * matrix with double precision elements.
+ *
+ * \param mat: Matrix of complex. The matrix to be inverted.
+ * \param n: `np_int` The number of rows and columns of the [n x n] matrix.
+ * \param jer: `int &` Reference to an integer return flag.
+ */
+void magma_zinvert(dcomplex **mat, np_int n, int &jer);
+
+#endif
diff --git a/src/include/types.h b/src/include/types.h
index 8c6ab4b2..a7f9ac70 100644
--- a/src/include/types.h
+++ b/src/include/types.h
@@ -16,11 +16,15 @@ typedef __complex__ double dcomplex;
 #ifdef USE_MKL
 #ifndef MKL_INT 
 #define MKL_INT int64_t
-#endif
+#endif // MKL_INT
 #include <mkl_lapacke.h>
 #else
 #include <lapacke.h>
-#endif
+#endif // USE_MKL
+#endif // USE_LAPACK
+
+#ifdef USE_MAGMA
+#include "magma_v2.h"
 #endif
 
 #ifndef np_int
@@ -28,8 +32,8 @@ typedef __complex__ double dcomplex;
 #define np_int lapack_int
 #else
 #define np_int int64_t
-#endif
-#endif
+#endif // lapack_int
+#endif // np_int
 
 /*! \brief Get the real part of a complex number.
  *
diff --git a/src/libnptm/Makefile b/src/libnptm/Makefile
index de9085d0..b02e5f2e 100644
--- a/src/libnptm/Makefile
+++ b/src/libnptm/Makefile
@@ -19,11 +19,11 @@ endif
 include ../make.inc
 
 
-CXX_NPTM_OBJS=$(OBJDIR)/Commons.o $(OBJDIR)/Configuration.o $(OBJDIR)/file_io.o $(OBJDIR)/Parsers.o $(OBJDIR)/sph_subs.o $(OBJDIR)/clu_subs.o $(OBJDIR)/tfrfme.o $(OBJDIR)/tra_subs.o $(OBJDIR)/TransitionMatrix.o $(OBJDIR)/lapack_calls.o $(OBJDIR)/algebraic.o $(OBJDIR)/types.o $(OBJDIR)/logging.o
+CXX_NPTM_OBJS=$(OBJDIR)/Commons.o $(OBJDIR)/Configuration.o $(OBJDIR)/file_io.o $(OBJDIR)/Parsers.o $(OBJDIR)/sph_subs.o $(OBJDIR)/clu_subs.o $(OBJDIR)/tfrfme.o $(OBJDIR)/tra_subs.o $(OBJDIR)/TransitionMatrix.o $(OBJDIR)/lapack_calls.o $(OBJDIR)/magma_calls.o $(OBJDIR)/algebraic.o $(OBJDIR)/types.o $(OBJDIR)/logging.o
 
-CXX_NPTM_DYNOBJS=$(DYNOBJDIR)/Commons.o $(DYNOBJDIR)/Configuration.o $(DYNOBJDIR)/file_io.o $(DYNOBJDIR)/Parsers.o $(DYNOBJDIR)/sph_subs.o $(DYNOBJDIR)/clu_subs.o $(DYNOBJDIR)/tfrfme.o $(DYNOBJDIR)/tra_subs.o $(DYNOBJDIR)/TransitionMatrix.o $(DYNOBJDIR)/lapack_calls.o $(DYNOBJDIR)/algebraic.o $(DYNOBJDIR)/types.o $(DYNOBJDIR)/logging.o
+CXX_NPTM_DYNOBJS=$(DYNOBJDIR)/Commons.o $(DYNOBJDIR)/Configuration.o $(DYNOBJDIR)/file_io.o $(DYNOBJDIR)/Parsers.o $(DYNOBJDIR)/sph_subs.o $(DYNOBJDIR)/clu_subs.o $(DYNOBJDIR)/tfrfme.o $(DYNOBJDIR)/tra_subs.o $(DYNOBJDIR)/TransitionMatrix.o $(DYNOBJDIR)/lapack_calls.o $(DYNOBJDIR)/magma_calls.o $(DYNOBJDIR)/algebraic.o $(DYNOBJDIR)/types.o $(DYNOBJDIR)/logging.o
 
-CXX_NPTM_DEBUG=$(OBJDIR)/Commons.g* $(OBJDIR)/Configuration.g* $(OBJDIR)/file_io.g* $(OBJDIR)/Parsers.g* $(OBJDIR)/sph_subs.g* $(OBJDIR)/clu_subs.g* $(OBJDIR)/tfrfme.g* $(OBJDIR)/tra_subs.g* $(OBJDIR)/TransitionMatrix.g* $(OBJDIR)/lapack_calls.g* $(OBJDIR)/algebraic.g* $(OBJDIR)/types.g* $(OBJDIR)/logging.g* $(DYNOBJDIR)/Commons.g* $(DYNOBJDIR)/Configuration.g* $(DYNOBJDIR)/file_io.g* $(DYNOBJDIR)/Parsers.g* $(DYNOBJDIR)/sph_subs.g* $(DYNOBJDIR)/clu_subs.g* $(DYNOBJDIR)/tfrfme.g* $(DYNOBJDIR)/tra_subs.g* $(DYNOBJDIR)/TransitionMatrix.g* $(DYNOBJDIR)/lapack_calls.g* $(DYNOBJDIR)/algebraic.g* $(DYNOBJDIR)/types.g* $(DYNOBJDIR)/logging.g*
+CXX_NPTM_DEBUG=$(OBJDIR)/Commons.g* $(OBJDIR)/Configuration.g* $(OBJDIR)/file_io.g* $(OBJDIR)/Parsers.g* $(OBJDIR)/sph_subs.g* $(OBJDIR)/clu_subs.g* $(OBJDIR)/tfrfme.g* $(OBJDIR)/tra_subs.g* $(OBJDIR)/TransitionMatrix.g* $(OBJDIR)/lapack_calls.g* $(OBJDIR)/magma_calls.g* $(OBJDIR)/algebraic.g* $(OBJDIR)/types.g* $(OBJDIR)/logging.g* $(DYNOBJDIR)/Commons.g* $(DYNOBJDIR)/Configuration.g* $(DYNOBJDIR)/file_io.g* $(DYNOBJDIR)/Parsers.g* $(DYNOBJDIR)/sph_subs.g* $(DYNOBJDIR)/clu_subs.g* $(DYNOBJDIR)/tfrfme.g* $(DYNOBJDIR)/tra_subs.g* $(DYNOBJDIR)/TransitionMatrix.g* $(DYNOBJDIR)/lapack_calls.g* $(DYNOBJDIR)/magma_calls.g* $(DYNOBJDIR)/algebraic.g* $(DYNOBJDIR)/types.g* $(DYNOBJDIR)/logging.g*
 
 all: $(BUILDDIR_NPTM)/libnptm.a $(BUILDDIR_NPTM)/libnptm.so
 
diff --git a/src/libnptm/algebraic.cpp b/src/libnptm/algebraic.cpp
index 8811abca..477b3373 100644
--- a/src/libnptm/algebraic.cpp
+++ b/src/libnptm/algebraic.cpp
@@ -14,6 +14,12 @@
 #endif
 #endif
 
+#ifdef USE_MAGMA
+#ifndef INCLUDE_MAGMA_CALLS_H_
+#include "../include/magma_calls.h"
+#endif
+#endif
+
 #ifndef INCLUDE_ALGEBRAIC_H_
 #include "../include/algebraic.h"
 #endif
@@ -26,7 +32,9 @@ using namespace std;
 
 void invert_matrix(dcomplex **mat, np_int size, int &ier, np_int max_size) {
   ier = 0;
-#ifdef USE_LAPACK
+#ifdef USE_MAGMA
+  magma_zinvert(mat, size, ier);
+#elif defined USE_LAPACK
   zinvert(mat, size, ier);
 #else
   lucin(mat, max_size, size, ier);
diff --git a/src/libnptm/lapack_calls.cpp b/src/libnptm/lapack_calls.cpp
index 363ba2af..1c68aa5f 100644
--- a/src/libnptm/lapack_calls.cpp
+++ b/src/libnptm/lapack_calls.cpp
@@ -8,19 +8,20 @@
 #include "../include/types.h"
 #endif
 
+/*
 #ifdef USE_LAPACK
 #ifdef USE_MKL
 #include <mkl_lapacke.h>
 #else
 #include <lapacke.h>
 #endif
+*/
 
+#ifdef USE_LAPACK
 #ifndef INCLUDE_LAPACK_CALLS_H_
 #include "../include/lapack_calls.h"
 #endif
-#endif
 
-#ifdef USE_LAPACK
 void zinvert(dcomplex **mat, np_int n, int &jer) {
   jer = 0;
   dcomplex *arr = &(mat[0][0]);
diff --git a/src/libnptm/magma_calls.cpp b/src/libnptm/magma_calls.cpp
new file mode 100644
index 00000000..da32c805
--- /dev/null
+++ b/src/libnptm/magma_calls.cpp
@@ -0,0 +1,51 @@
+/* Distributed under the terms of GPLv3 or later. See COPYING for details. */
+
+/*! \file magma_calls.cpp
+ *
+ * \brief Implementation of the interface with MAGMA libraries.
+ */
+#ifndef INCLUDE_TYPES_H_
+#include "../include/types.h"
+#endif
+
+#ifdef USE_MAGMA
+#ifndef INCLUDE_MAGMA_CALLS_H_
+#include "../include/magma_calls.h"
+#endif
+
+void magma_zinvert(dcomplex **mat, np_int n, int &jer) {
+  // magma_int_t result = magma_init();
+  magma_int_t result = MAGMA_SUCCESS;
+  magma_queue_t queue = NULL;
+  magma_int_t dev = 0;
+  magma_queue_create(dev, &queue);
+  magmaDoubleComplex *dwork; // dwork - workspace
+  magma_int_t ldwork; // size of dwork
+  magma_int_t *piv , info; // piv - array of indices of inter -
+  magma_int_t m = (magma_int_t)n; // changed rows; a - mxm matrix
+  magma_int_t mm = m * m; // size of a, r, c
+  magmaDoubleComplex *a = (magmaDoubleComplex *)&(mat[0][0]); // a- mxm matrix on the host
+  magmaDoubleComplex *d_a; // d_a - mxm matrix a on the device
+  magma_int_t err;
+  ldwork = m * magma_get_zgetri_nb(m); // optimal block size
+  // allocate matrices
+  err = magma_zmalloc(&d_a, mm); // device memory for a
+  err = magma_zmalloc(&dwork, ldwork); // dev. mem. for ldwork
+  piv = (magma_int_t *)malloc(m*sizeof(magma_int_t )); // host mem.
+  magma_zsetmatrix(m, m, a, m, d_a , m, queue); // copy a -> d_a
+  // find the inverse matrix: d_a*X=I using the LU factorization
+  // with partial pivoting and row interchanges computed by
+  // magma_zgetrf_gpu; row i is interchanged with row piv(i);
+  // d_a - mxm matrix; d_a is overwritten by the inverse
+  
+  magma_zgetrf_gpu(m, m, d_a, m, piv, &info);
+  magma_zgetri_gpu(m, d_a, m, piv, dwork, ldwork, &info);
+  
+  magma_zgetmatrix( m, m, d_a , m, a, m, queue); // copy d_a -> a
+  free(piv); // free host memory
+  magma_free(d_a); // free device memory
+  magma_queue_destroy(queue); // destroy queue
+  // result = magma_finalize();
+  jer = (int)result;
+}
+#endif
diff --git a/src/make.inc b/src/make.inc
index 88c109cc..60089ea8 100644
--- a/src/make.inc
+++ b/src/make.inc
@@ -71,21 +71,36 @@ endif
 #the next endif is for USE_LAPACK
 endif
 
+# define (outside) USE_MAGMA for magma support
+ifdef USE_MAGMA
+MAGMA_LDFLAGS= -lmagma
+#the next endif is for USE_MAGMA
+endif
+
 # CXXFLAGS defines the default compilation options for the C++ compiler
 ifndef CXXFLAGS
 override CXXFLAGS=-O3 -ggdb -pg -coverage -I$(HDF5_INCLUDE)
 ifdef USE_OPENMP
 override CXXFLAGS+= -fopenmp
+# closes USE_OPENMP
 endif
 ifdef USE_LAPACK
 override CXXFLAGS+= -DUSE_LAPACK -DLAPACK_ILP64 
 ifdef USE_MKL
 override CXXFLAGS+= -DMKL_ILP64 -DUSE_MKL -I$(MKLROOT)/include
+# closes USE_MKL
 endif
 ifdef USE_OPENMP
 override CXXFLAGS+= -fopenmp
+# closes USE_OPENMP
 endif
+# closes USE_LAPACK
 endif
+ifdef USE_MAGMA
+override CXXFLAGS+= -DUSE_MAGMA
+# closes USE_MAGMA
+endif
+# closes CXXFLAGS
 endif
 
 # HDF5_LIB defines the default path to the HDF5 libraries to use
@@ -98,6 +113,9 @@ override CXXLDFLAGS=-L/usr/lib64 -L$(HDF5_LIB) -lhdf5 $(STATICFLAG)
 ifdef USE_LAPACK
 override CXXLDFLAGS+= $(LAPACK_LDFLAGS)
 endif
+ifdef USE_MAGMA
+override CXXLDFLAGS+= $(MAGMA_LDFLAGS)
+endif
 override CXXLDFLAGS+= $(LDFLAGS)
 endif
 
-- 
GitLab