From 0c307b8c1bc86d2f480c80d5e435c12e3d97e38c Mon Sep 17 00:00:00 2001 From: Giovanni La Mura Date: Sun, 4 Aug 2024 22:33:29 -0500 Subject: [PATCH] Enable possibility to send matrix inversion to different GPUs --- src/include/algebraic.h | 3 ++- src/include/magma_calls.h | 3 ++- src/libnptm/algebraic.cpp | 4 ++-- src/libnptm/magma_calls.cpp | 16 ++++++++-------- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/include/algebraic.h b/src/include/algebraic.h index accdf818..f052bb2d 100644 --- a/src/include/algebraic.h +++ b/src/include/algebraic.h @@ -36,7 +36,8 @@ * \param ier: `int &` Reference to an integer variable for returning a result flag. * \param max_size: `np_int` The maximum expected size (required by some call-backs, * optional, defaults to 0). + * \param target_device: `int` ID of target GPU, if available (defaults to 0). */ -void invert_matrix(dcomplex **mat, np_int size, int &ier, np_int max_size=0); +void invert_matrix(dcomplex **mat, np_int size, int &ier, np_int max_size=0, int target_device=0); #endif diff --git a/src/include/magma_calls.h b/src/include/magma_calls.h index bf39c3a5..1002d351 100644 --- a/src/include/magma_calls.h +++ b/src/include/magma_calls.h @@ -31,7 +31,8 @@ * \param mat: Matrix of complex. The matrix to be inverted. * \param n: `np_int` The number of rows and columns of the [n x n] matrix. * \param jer: `int &` Reference to an integer return flag. + * \param device_id: `int` ID of the device for matrix inversion offloading. */ -void magma_zinvert(dcomplex **mat, np_int n, int &jer); +void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id=0); #endif diff --git a/src/libnptm/algebraic.cpp b/src/libnptm/algebraic.cpp index c942a1fd..c25ea0aa 100644 --- a/src/libnptm/algebraic.cpp +++ b/src/libnptm/algebraic.cpp @@ -44,10 +44,10 @@ extern void lucin(dcomplex **mat, np_int max_size, np_int size, int &ier); using namespace std; -void invert_matrix(dcomplex **mat, np_int size, int &ier, np_int max_size) { +void invert_matrix(dcomplex **mat, np_int size, int &ier, np_int max_size, int target_device) { ier = 0; #ifdef USE_MAGMA - magma_zinvert(mat, size, ier); + magma_zinvert(mat, size, ier, target_device); #elif defined USE_LAPACK zinvert(mat, size, ier); #else diff --git a/src/libnptm/magma_calls.cpp b/src/libnptm/magma_calls.cpp index d9875d9b..4cbae602 100644 --- a/src/libnptm/magma_calls.cpp +++ b/src/libnptm/magma_calls.cpp @@ -27,19 +27,19 @@ #include "../include/magma_calls.h" #endif -void magma_zinvert(dcomplex **mat, np_int n, int &jer) { +void magma_zinvert(dcomplex **mat, np_int n, int &jer, int device_id) { // magma_int_t result = magma_init(); magma_int_t err = MAGMA_SUCCESS; magma_queue_t queue = NULL; - magma_device_t dev = 0; + magma_device_t dev = (magma_device_t)device_id; magma_queue_create(dev, &queue); - magmaDoubleComplex *dwork; // dwork - workspace + magmaDoubleComplex *dwork; // workspace magma_int_t ldwork; // size of dwork - magma_int_t *piv , info; // piv - array of indices of inter - + magma_int_t *piv , info; // array of pivot indices magma_int_t m = (magma_int_t)n; // changed rows; a - mxm matrix - magma_int_t mm = m * m; // size of a, r, c - magmaDoubleComplex *a = (magmaDoubleComplex *)&(mat[0][0]); // a - mxm matrix on the host - magmaDoubleComplex *d_a; // d_a - mxm matrix a on the device + magma_int_t mm = m * m; // size of a + magmaDoubleComplex *a = (magmaDoubleComplex *)&(mat[0][0]); // pointer to first element on host + magmaDoubleComplex *d_a; // pointer to first element on device ldwork = m * magma_get_zgetri_nb(m); // optimal block size // allocate matrices err = magma_zmalloc(&d_a, mm); // device memory for a @@ -50,7 +50,7 @@ void magma_zinvert(dcomplex **mat, np_int n, int &jer) { magma_zgetrf_gpu(m, m, d_a, m, piv, &info); magma_zgetri_gpu(m, d_a, m, piv, dwork, ldwork, &info); - magma_zgetmatrix( m, m, d_a , m, a, m, queue); // copy d_a -> a + magma_zgetmatrix(m, m, d_a , m, a, m, queue); // copy d_a -> a delete[] piv; // free host memory magma_free(d_a); // free device memory magma_queue_destroy(queue); // destroy queue -- GitLab