From d4b8259b82369b0ec9374457eb2c5e4d74c8e5a7 Mon Sep 17 00:00:00 2001
From: Giovanni La Mura <giovanni.lamura@inaf.it>
Date: Wed, 15 May 2024 12:02:08 +0200
Subject: [PATCH] Use CUDA calls to test for devices

---
 src/cluster/cluster.cpp | 36 +++++++++++++++++++++++++++++++++---
 src/make.inc            |  2 +-
 2 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/src/cluster/cluster.cpp b/src/cluster/cluster.cpp
index a5664ef7..e31bcbd4 100644
--- a/src/cluster/cluster.cpp
+++ b/src/cluster/cluster.cpp
@@ -30,6 +30,9 @@
 #include <mpi.h>
 #endif
 #endif
+#ifdef USE_MAGMA
+#include <cuda_runtime.h>
+#endif
 
 #ifndef INCLUDE_TYPES_H_
 #include "../include/types.h"
@@ -88,8 +91,25 @@ void cluster(const string& config_file, const string& data_file, const string& o
   Logger *time_logger = new Logger(LOG_DEBG, timing_file);
   Logger *logger = new Logger(LOG_DEBG);
 #ifdef USE_MAGMA
+  int device_count = 0; // starts at 0 so the allocation below never sees an indeterminate size
+  if (cudaGetDeviceCount(&device_count) != cudaSuccess) device_count = 0; // treat a failed query as "no devices"
+  logger->log("DEBUG: Proc-" + to_string(mpidata->rank) + " found " + to_string(device_count) + " CUDA devices.\n", LOG_DEBG);
   logger->log("INFO: Process " + to_string(mpidata->rank) + " initializes MAGMA.\n");
-  magma_init();
+  magma_device_t *devices = new magma_device_t[device_count]; // released on every exit path with delete[]
+  cudaSetValidDevices(devices, device_count); // NOTE(review): devices is only filled by magma_getdevices() below - confirm this call is intended
+  magma_int_t num_devices;
+  magma_getdevices(devices, device_count, &num_devices); // fills devices and reports how many entries were written
+  logger->log("DEBUG: Proc-" + to_string(mpidata->rank) + " found " + to_string(num_devices) + " MAGMA devices.\n", LOG_DEBG);
+  magma_int_t magma_result = magma_init(); // the status is checked so a failed start-up ends this call instead of continuing
+  if (magma_result != MAGMA_SUCCESS) {
+    logger->err("ERROR: Process " + to_string(mpidata->rank) + " failed to initialize MAGMA.\n");
+    logger->err("PROC-" + to_string(mpidata->rank) + ": MAGMA error code " + to_string(magma_result) + "\n");
+    fclose(timing_file); // undo everything acquired so far before returning
+    delete[] devices;
+    delete time_logger;
+    delete logger;
+    return;
+  }
 #endif
   // the following only happens on MPI process 0
   if (mpidata->rank == 0) {
@@ -101,7 +121,11 @@ void cluster(const string& config_file, const string& data_file, const string& o
       logger->err("\nERROR: failed to open scatterer configuration file.\n");
       string message = "FILE: " + string(ex.what()) + "\n";
       logger->err(message);
-      exit(1);
+      fclose(timing_file);
+      delete[] devices;
+      delete time_logger;
+      delete logger;
+      return;
     }
     sconf->write_formatted(output_path + "/c_OEDFB");
     sconf->write_binary(output_path + "/c_TEDF");
@@ -114,7 +138,11 @@ void cluster(const string& config_file, const string& data_file, const string& o
       string message = "FILE: " + string(ex.what()) + "\n";
       logger->err(message);
       if (sconf) delete sconf;
-      exit(1);
+      fclose(timing_file);
+      delete[] devices;
+      delete time_logger;
+      delete logger;
+      return;
     }
     logger->log(" done.\n", LOG_INFO);
     int s_nsph = sconf->number_of_spheres;
@@ -217,6 +245,7 @@ void cluster(const string& config_file, const string& data_file, const string& o
 	  tppoan.close();
 	  fclose(timing_file);
 	  fclose(output);
+	  delete[] devices;
 	  delete p_scattering_angles;
 	  delete cid;
 	  delete logger;
@@ -526,6 +555,7 @@ void cluster(const string& config_file, const string& data_file, const string& o
       }
     }
     // Clean memory
+    delete[] devices;
     delete cid;
     delete p_scattering_angles;
     delete sconf;
diff --git a/src/make.inc b/src/make.inc
index d7c5dd4d..3fadc9a8 100644
--- a/src/make.inc
+++ b/src/make.inc
@@ -81,7 +81,7 @@ endif
 
 # define (outside) USE_MAGMA for magma support
 ifdef USE_MAGMA
-MAGMA_LDFLAGS= -lmagma
+MAGMA_LDFLAGS= -lmagma -lcudart
 #the next endif is for USE_MAGMA
 endif
 
-- 
GitLab