Skip to content
Snippets Groups Projects
Commit 098daf96 authored by Giovanni La Mura's avatar Giovanni La Mura
Browse files

Implement SVD based matrix inversion

parent 4b27c93a
No related branches found
No related tags found
No related merge requests found
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
#ifndef INCLUDE_MAGMA_CALLS_H_ #ifndef INCLUDE_MAGMA_CALLS_H_
#define INCLUDE_MAGMA_CALLS_H_ #define INCLUDE_MAGMA_CALLS_H_
#define MIN(a,b) ((a) > (b) ? (b) : (a))
/*! \brief Invert a complex matrix with double precision elements. /*! \brief Invert a complex matrix with double precision elements.
* *
* Use LAPACKE64 to perform an in-place matrix inversion for a complex * Use LAPACKE64 to perform an in-place matrix inversion for a complex
...@@ -35,4 +37,16 @@ ...@@ -35,4 +37,16 @@
*/ */
void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id=0); void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id=0);
#endif #ifdef USE_MAGMA_SVD
/*! \brief Invert a complex matrix using the MAGMA zgesvd implementation.
 *
 * Use the singular value decomposition computed by MAGMA (`magma_zgesvd`)
 * to perform an in-place matrix inversion for a complex matrix with double
 * precision elements. This declaration is only available when
 * `USE_MAGMA_SVD` is defined.
 *
 * \param mat: Matrix of complex. The matrix to be inverted.
 * \param n: `np_int` The number of rows and columns of the [n x n] matrix.
 * \param jer: `int &` Reference to an integer return flag.
 * \param device_id: `int` ID of the MAGMA device to be used (optional, default is 0).
 */
void magma_svd_zinvert(dcomplex **mat, np_int n, int &jer, int device_id=0);
#endif //USE_MAGMA_SVD
#endif // INCLUDE_MAGMA_CALLS_H_
...@@ -47,10 +47,14 @@ using namespace std; ...@@ -47,10 +47,14 @@ using namespace std;
void invert_matrix(dcomplex **mat, np_int size, int &ier, np_int max_size, int target_device) { void invert_matrix(dcomplex **mat, np_int size, int &ier, np_int max_size, int target_device) {
ier = 0; ier = 0;
#ifdef USE_MAGMA #ifdef USE_MAGMA
#ifdef USE_MAGMA_SVD
magma_svd_zinvert(mat, size, ier, target_device);
#else
magma_zinvert(mat, size, ier, target_device); magma_zinvert(mat, size, ier, target_device);
#endif // USE_MAGMA_SVD
#elif defined USE_LAPACK #elif defined USE_LAPACK
zinvert(mat, size, ier); zinvert(mat, size, ier);
#else #else
lucin(mat, max_size, size, ier); lucin(mat, max_size, size, ier);
#endif #endif // USE_MAGMA
} }
...@@ -27,15 +27,18 @@ ...@@ -27,15 +27,18 @@
#include "../include/magma_calls.h" #include "../include/magma_calls.h"
#endif #endif
#ifdef USE_MAGMA_SVD
#include "magma_operators.h"
#endif
void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) { void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
// magma_int_t result = magma_init();
magma_int_t err = MAGMA_SUCCESS; magma_int_t err = MAGMA_SUCCESS;
magma_queue_t queue = NULL; magma_queue_t queue = NULL;
magma_device_t dev = (magma_device_t)device_id; magma_device_t dev = (magma_device_t)device_id;
magma_queue_create(dev, &queue); magma_queue_create(dev, &queue);
magmaDoubleComplex *dwork; // workspace magmaDoubleComplex *dwork; // workspace
magma_int_t ldwork; // size of dwork magma_int_t ldwork; // size of dwork
magma_int_t *piv , info; // array of pivot indices magma_int_t *piv; // array of pivot indices
magma_int_t m = (magma_int_t)n; // changed rows; a - mxm matrix magma_int_t m = (magma_int_t)n; // changed rows; a - mxm matrix
magma_int_t mm = m * m; // size of a magma_int_t mm = m * m; // size of a
magmaDoubleComplex *a = (magmaDoubleComplex *)&(mat[0][0]); // pointer to first element on host magmaDoubleComplex *a = (magmaDoubleComplex *)&(mat[0][0]); // pointer to first element on host
...@@ -47,8 +50,10 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) { ...@@ -47,8 +50,10 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
piv = new magma_int_t[m]; // host mem. piv = new magma_int_t[m]; // host mem.
magma_zsetmatrix(m, m, a, m, d_a , m, queue); // copy a -> d_a magma_zsetmatrix(m, m, a, m, d_a , m, queue); // copy a -> d_a
magma_zgetrf_gpu(m, m, d_a, m, piv, &info); magma_zgetrf_gpu(m, m, d_a, m, piv, &err);
magma_zgetri_gpu(m, d_a, m, piv, dwork, ldwork, &info); if (err == MAGMA_SUCCESS) {
magma_zgetri_gpu(m, d_a, m, piv, dwork, ldwork, &err);
}
magma_zgetmatrix(m, m, d_a , m, a, m, queue); // copy d_a -> a magma_zgetmatrix(m, m, d_a , m, a, m, queue); // copy d_a -> a
delete[] piv; // free host memory delete[] piv; // free host memory
...@@ -57,4 +62,81 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) { ...@@ -57,4 +62,81 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
// result = magma_finalize(); // result = magma_finalize();
jer = (int)err; jer = (int)err;
} }
#endif
#ifdef USE_MAGMA_SVD
/*! \brief Invert a complex matrix in place through its singular value decomposition.
 *
 * The matrix is decomposed with `magma_zgesvd`, the singular values are
 * replaced by their reciprocals (exact zeros are kept as zeros, so a
 * singular input yields a pseudo-inverse rather than a division by zero)
 * and the inverse is rebuilt from the conjugated singular vectors.
 *
 * The storage behind `mat[0]` is handed to MAGMA as one contiguous block of
 * n * n elements. NOTE(review): this presumes `mat` was allocated as a
 * single contiguous buffer - confirm against the callers.
 *
 * All buffers are allocated before the decomposition starts, so an
 * allocation failure is reported without touching `mat`. If the
 * decomposition itself fails, the contents of `mat` are not meaningful and
 * `jer` carries the MAGMA status.
 *
 * \param mat: Matrix of complex. The matrix to be inverted (overwritten).
 * \param n: `np_int` The number of rows and columns of the [n x n] matrix.
 * \param jer: `int &` Reference to an integer return flag. Receives the MAGMA
 * status code (`MAGMA_SUCCESS` on success).
 * \param device_id: `int` ID of the device on which the MAGMA queue is created.
 */
void magma_svd_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) {
  magma_int_t err = MAGMA_SUCCESS;
  const magma_int_t m = (magma_int_t)n; // matrix order, as in magma_zinvert()
  const magma_int_t mm = m * m;         // number of matrix elements
  magma_queue_t queue = NULL;
  magma_device_t dev = (magma_device_t)device_id;
  magma_queue_create(dev, &queue);
  magmaDoubleComplex *a = reinterpret_cast<magmaDoubleComplex *>(&mat[0][0]);

  double *s = NULL;              // singular values
  magmaDoubleComplex *u = NULL;  // left singular vectors
  magmaDoubleComplex *vt = NULL; // right singular vectors (as returned by zgesvd)
  double *rwork = NULL;          // real workspace of zgesvd
  magmaDoubleComplex *work = NULL; // complex workspace of zgesvd
  magmaDoubleComplex *d_a = NULL;  // device buffer used for the final transposition
  magma_int_t lwork = 0;

  // Allocate everything that does not depend on the workspace query; stop at
  // the first failure so that the status of that failure is the one reported.
  err = magma_dmalloc_cpu(&s, m);
  if (err == MAGMA_SUCCESS) err = magma_zmalloc_cpu(&u, mm);
  if (err == MAGMA_SUCCESS) err = magma_zmalloc_cpu(&vt, mm);
  if (err == MAGMA_SUCCESS) err = magma_dmalloc_cpu(&rwork, 5 * m);
  if (err == MAGMA_SUCCESS) err = magma_zmalloc(&d_a, mm);

  if (err == MAGMA_SUCCESS) {
    // Workspace query: with lwork = -1 only the first element of the work
    // array is written, so a single stack element is sufficient.
    magmaDoubleComplex work_query = MAGMA_Z_MAKE(0.0, 0.0);
    magma_zgesvd(
      MagmaAllVec, MagmaAllVec, m, m, a, m, s, u, m, vt, m, &work_query, -1, rwork, &err
    );
    lwork = (magma_int_t)real(work_query);
  }
  if (err == MAGMA_SUCCESS) err = magma_zmalloc_cpu(&work, lwork);
  if (err == MAGMA_SUCCESS) {
    magma_zgesvd(
      MagmaAllVec, MagmaAllVec, m, m, a, m, s, u, m, vt, m, work, lwork, rwork, &err
    );
  }

  // Rebuild the inverse only if the decomposition actually succeeded.
  if (err == MAGMA_SUCCESS) {
    // Reciprocal singular values; exact zeros stay zero.
    for (magma_int_t si = 0; si < m; si++)
      s[si] = (s[si] == 0.0) ? 0.0 : 1.0 / s[si];
    // Conjugate both vector sets and scale the left vectors by 1/sigma.
    for (magma_int_t ri = 0; ri < m; ri++) {
      for (magma_int_t rj = 0; rj < m; rj++) {
        const magma_int_t idx = m * ri + rj;
        u[idx] = MAGMA_Z_MAKE(s[ri] * real(u[idx]), -s[ri] * imag(u[idx]));
        vt[idx] = MAGMA_Z_MAKE(real(vt[idx]), -imag(vt[idx]));
      }
    }
    // Product of the conjugated right vectors with the scaled left vectors.
    for (magma_int_t mi = 0; mi < m; mi++) {
      for (magma_int_t mj = 0; mj < m; mj++) {
        magmaDoubleComplex value = MAGMA_Z_MAKE(0.0, 0.0);
        for (magma_int_t mk = 0; mk < m; mk++) {
          value = MAGMA_Z_ADD(value, MAGMA_Z_MUL(vt[m * mi + mk], u[m * mk + mj]));
        }
        a[m * mi + mj] = value;
      }
    }
    // The product above is the transpose of the wanted result: transpose it
    // on the device and copy it back over the input storage.
    magma_zsetmatrix(m, m, a, m, d_a, m, queue);
    magmablas_ztranspose_inplace(m, d_a, m, queue);
    magma_zgetmatrix(m, m, d_a, m, a, m, queue);
  }

  // Release whatever was successfully allocated.
  if (d_a != NULL) magma_free(d_a);
  if (work != NULL) magma_free_cpu(work);
  if (rwork != NULL) magma_free_cpu(rwork);
  if (vt != NULL) magma_free_cpu(vt);
  if (u != NULL) magma_free_cpu(u);
  if (s != NULL) magma_free_cpu(s);
  magma_queue_destroy(queue);
  jer = (int)err;
}
#endif // USE_MAGMA_SVD
#endif // USE_MAGMA
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment