diff --git a/src/libnptm/magma_calls.cpp b/src/libnptm/magma_calls.cpp index c566a6496002f6318509534d4bd9892a5dc06574..5801f849130e6a3b0829998f602a48f6f4e44a32 100644 --- a/src/libnptm/magma_calls.cpp +++ b/src/libnptm/magma_calls.cpp @@ -97,35 +97,30 @@ void magma_svd_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) { ); magma_free_cpu(work); - for (magma_int_t si = 0; si < n; si++) - s[si] = (s[si] == 0.0) ? 0.0 : 1.0 / s[si]; - - for (magma_int_t ri = 0; ri < n; ri++) { - for (magma_int_t rj = 0; rj < n; rj++) { - u[n * ri + rj] = MAGMA_Z_MAKE(s[ri] * real(u[n * ri + rj]), -s[ri] * imag(u[n * ri + rj])); - vt[n * ri + rj] = MAGMA_Z_MAKE(real(vt[n * ri + rj]), -imag(vt[n * ri + rj])); - } - } - magmaDoubleComplex value; - for (magma_int_t mi = 0; mi < n; mi++) { - for (magma_int_t mj = 0; mj < n; mj++) { - value = MAGMA_Z_MAKE(0.0, 0.0); - for (magma_int_t mk = 0; mk < n; mk++) { - magmaDoubleComplex elem1 = vt[n * mi + mk]; - magmaDoubleComplex elem2 = u[n * mk + mj]; - value = MAGMA_Z_ADD(value, MAGMA_Z_MUL(elem1, elem2)); - } - a[n * mi + mj] = value; + double rpart, ipart; + for (magma_int_t si = 0; si < n; si++) { + s[si] = (s[si] == 0.0) ? 0.0 : 1.0 / s[si]; + for (magma_int_t sj = 0; sj < n; sj++) { + rpart = s[si] * real(u[n * si + sj] ); + ipart = s[si] * imag(u[n * si + sj] ); + u[n * si + sj] = MAGMA_Z_MAKE(rpart, -ipart); + rpart = real(vt[n * si + sj] ); + ipart = imag(vt[n * si + sj] ); + vt[n * si + sj] = MAGMA_Z_MAKE(rpart, -ipart); } } - - magmaDoubleComplex *d_a; + magmaDoubleComplex *d_a, *d_u, *d_vt; magma_zmalloc(&d_a, n * n); - magma_zsetmatrix(n, n, a, n, d_a , n, queue); - magmablas_ztranspose_inplace(n, d_a, n, queue); - magma_zgetmatrix(n, n, d_a, n, a , n, queue); + magma_zmalloc(&d_u, n * n); + magma_zmalloc(&d_vt, n * n); + magma_zsetmatrix(n, n, u, n, d_u, n, queue); + magma_zsetmatrix(n, n, vt, n, d_vt, n, queue); + magmablas_zgemm(MagmaTrans, MagmaTrans, n, n, n, cc1, d_vt, n, d_u, n, cc0, d_a, n, queue); + magma_zgetmatrix(n, n, d_a, n, a, n, queue); magma_free(d_a); + magma_free(d_u); + magma_free(d_vt); } else { jer = (int)err; }