From f3264886f8dcc1217f25eb8aea34c960a28f1445 Mon Sep 17 00:00:00 2001
From: Giovanni La Mura <giovanni.lamura@inaf.it>
Date: Mon, 7 Oct 2024 17:26:14 +0200
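Subject: [PATCH] Use GPU matrix multiplication in SVD recombination

In magma_svd_zinvert, replace the explicit triple loop that multiplied
the conjugated vt factor by the scaled, conjugated u factor, together
with the in-place device transpose that followed it, with a single
magmablas_zgemm call (both operands transposed) on device copies of
u and vt.

The inversion of the singular values is merged into the loop that
scales and conjugates u and conjugates vt.
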
---
 src/libnptm/magma_calls.cpp | 43 ++++++++++++++++---------------------
 1 file changed, 19 insertions(+), 24 deletions(-)

diff --git a/src/libnptm/magma_calls.cpp b/src/libnptm/magma_calls.cpp
index c566a649..5801f849 100644
--- a/src/libnptm/magma_calls.cpp
+++ b/src/libnptm/magma_calls.cpp
@@ -97,35 +97,30 @@ void magma_svd_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
     );
 
     magma_free_cpu(work);
-    for (magma_int_t si = 0; si < n; si++)
-      s[si] = (s[si] == 0.0) ? 0.0 : 1.0 / s[si];
-
-    for (magma_int_t ri = 0; ri < n; ri++) {
-      for (magma_int_t rj = 0; rj < n; rj++) {
-    	u[n * ri + rj] = MAGMA_Z_MAKE(s[ri] * real(u[n * ri + rj]), -s[ri] * imag(u[n * ri + rj]));
-    	vt[n * ri + rj] = MAGMA_Z_MAKE(real(vt[n * ri + rj]), -imag(vt[n * ri + rj]));
-      }
-    }
 
-    magmaDoubleComplex value;
-    for (magma_int_t mi = 0; mi < n; mi++) {
-      for (magma_int_t mj = 0; mj < n; mj++) {
-	value = MAGMA_Z_MAKE(0.0, 0.0);
-	for (magma_int_t mk = 0; mk < n; mk++) {
-	  magmaDoubleComplex elem1 = vt[n * mi + mk];
-	  magmaDoubleComplex elem2 = u[n * mk + mj];
-	  value = MAGMA_Z_ADD(value, MAGMA_Z_MUL(elem1, elem2));
-	}
-	a[n * mi + mj] = value;
+    double rpart, ipart;
+    for (magma_int_t si = 0; si < n; si++) {
+      s[si] = (s[si] == 0.0) ? 0.0 : 1.0 / s[si];
+      for (magma_int_t sj = 0; sj < n; sj++) {
+	rpart = s[si] * real(u[n * si + sj] );
+	ipart = s[si] * imag(u[n * si + sj] );
+	u[n * si + sj] = MAGMA_Z_MAKE(rpart, -ipart);
+	rpart = real(vt[n * si + sj] );
+	ipart = imag(vt[n * si + sj] );
+	vt[n * si + sj] = MAGMA_Z_MAKE(rpart, -ipart);
       }
     }
-
-    magmaDoubleComplex *d_a;
+    magmaDoubleComplex *d_a, *d_u, *d_vt;
     magma_zmalloc(&d_a, n * n);
-    magma_zsetmatrix(n, n, a, n, d_a , n, queue);
-    magmablas_ztranspose_inplace(n, d_a, n, queue);
-    magma_zgetmatrix(n, n, d_a, n, a , n, queue);
+    magma_zmalloc(&d_u, n * n);
+    magma_zmalloc(&d_vt, n * n);
+    magma_zsetmatrix(n, n, u, n, d_u, n, queue);
+    magma_zsetmatrix(n, n, vt, n, d_vt, n, queue);
+    magmablas_zgemm(MagmaTrans, MagmaTrans, n, n, n, cc1, d_vt, n, d_u, n, cc0, d_a, n, queue);
+    magma_zgetmatrix(n, n, d_a, n, a, n, queue);
     magma_free(d_a);
+    magma_free(d_u);
+    magma_free(d_vt);
   } else {
     jer = (int)err;
   }
-- 
GitLab