Skip to content
Snippets Groups Projects
Commit f3264886 authored by Giovanni La Mura
Browse files

Use GPU matrix multiplication in SVD recombination

parent d7d7dc8c
No related branches found
No related tags found
No related merge requests found
...@@ -97,35 +97,30 @@ void magma_svd_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) { ...@@ -97,35 +97,30 @@ void magma_svd_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
); );
magma_free_cpu(work); magma_free_cpu(work);
for (magma_int_t si = 0; si < n; si++)
s[si] = (s[si] == 0.0) ? 0.0 : 1.0 / s[si];
for (magma_int_t ri = 0; ri < n; ri++) {
for (magma_int_t rj = 0; rj < n; rj++) {
u[n * ri + rj] = MAGMA_Z_MAKE(s[ri] * real(u[n * ri + rj]), -s[ri] * imag(u[n * ri + rj]));
vt[n * ri + rj] = MAGMA_Z_MAKE(real(vt[n * ri + rj]), -imag(vt[n * ri + rj]));
}
}
magmaDoubleComplex value; double rpart, ipart;
for (magma_int_t mi = 0; mi < n; mi++) { for (magma_int_t si = 0; si < n; si++) {
for (magma_int_t mj = 0; mj < n; mj++) { s[si] = (s[si] == 0.0) ? 0.0 : 1.0 / s[si];
value = MAGMA_Z_MAKE(0.0, 0.0); for (magma_int_t sj = 0; sj < n; sj++) {
for (magma_int_t mk = 0; mk < n; mk++) { rpart = s[si] * real(u[n * si + sj] );
magmaDoubleComplex elem1 = vt[n * mi + mk]; ipart = s[si] * imag(u[n * si + sj] );
magmaDoubleComplex elem2 = u[n * mk + mj]; u[n * si + sj] = MAGMA_Z_MAKE(rpart, -ipart);
value = MAGMA_Z_ADD(value, MAGMA_Z_MUL(elem1, elem2)); rpart = real(vt[n * si + sj] );
} ipart = imag(vt[n * si + sj] );
a[n * mi + mj] = value; vt[n * si + sj] = MAGMA_Z_MAKE(rpart, -ipart);
} }
} }
magmaDoubleComplex *d_a, *d_u, *d_vt;
magmaDoubleComplex *d_a;
magma_zmalloc(&d_a, n * n); magma_zmalloc(&d_a, n * n);
magma_zsetmatrix(n, n, a, n, d_a , n, queue); magma_zmalloc(&d_u, n * n);
magmablas_ztranspose_inplace(n, d_a, n, queue); magma_zmalloc(&d_vt, n * n);
magma_zsetmatrix(n, n, u, n, d_u, n, queue);
magma_zsetmatrix(n, n, vt, n, d_vt, n, queue);
magmablas_zgemm(MagmaTrans, MagmaTrans, n, n, n, cc1, d_vt, n, d_u, n, cc0, d_a, n, queue);
magma_zgetmatrix(n, n, d_a, n, a, n, queue); magma_zgetmatrix(n, n, d_a, n, a, n, queue);
magma_free(d_a); magma_free(d_a);
magma_free(d_u);
magma_free(d_vt);
} else { } else {
jer = (int)err; jer = (int)err;
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment