From 7dd24b46e521ffb49515d6f8d7c2b63aa500a417 Mon Sep 17 00:00:00 2001 From: Giovanni La Mura Date: Wed, 9 Jul 2025 17:22:58 +0200 Subject: [PATCH] Allocate global work-space directly on device, if using offload --- src/trapping/cfrfme.cpp | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/trapping/cfrfme.cpp b/src/trapping/cfrfme.cpp index 51ace8e1..fbc9a935 100644 --- a/src/trapping/cfrfme.cpp +++ b/src/trapping/cfrfme.cpp @@ -401,16 +401,25 @@ void frfme(string data_file, string output_path) { int size_vkzm = nkv * nkv; const dcomplex *vec_tt1_wk = tt1->wk; int size_tt1_wk = nkv * nkv * nlmmt; - dcomplex *global_vec_w = new dcomplex[nkv * nkv * (jlml - jlmf + 1)](); - dcomplex **global_w = new dcomplex*[nkv * (jlml - jlmf + 1)]; int size_global_vec_w = nkv * nkv * (jlml - jlmf + 1); int size_global_w = nkv * (jlml - jlmf + 1); + int device_id = 0; + // Work-space pointers for simultaneous threads + dcomplex *global_vec_w; + dcomplex **global_w; #ifdef USE_TARGET_OFFLOAD + // Device-only work-space allocation + device_id = omp_get_default_device(); + global_vec_w = (dcomplex *)omp_target_alloc(size_global_vec_w * sizeof(dcomplex), device_id); + global_w = (dcomplex **)omp_target_alloc(size_global_w * sizeof(dcomplex), device_id); #pragma omp target teams distribute parallel for simd map(tofrom: vec_wsum[0:size_wsum]) \ map(to:vec_vkzm[0:size_vkzm], vkv[0:nkv], vec_tt1_wk[0:size_tt1_wk], _xv[0:nxv], _yv[0:nyv], _zv[0:nzv]) \ - map(alloc:global_w[0:size_global_w], global_vec_w[0:size_global_vec_w]) \ + map(to: global_vec_w, global_w) \ firstprivate(jlmf, jlml, nkv, nlmmt, nrvc, nxv, nyv, nzv, frsh, uim, delks) #else + // Fall-back host work-space allocation + global_vec_w = = new dcomplex[size_global_vec_w](); + global_w = new dcomplex*[size_global_w]; #pragma omp parallel for simd #endif for (int j80 = jlmf-1; j80 < jlml; j80++) { @@ -457,8 +466,13 @@ void frfme(string data_file, string output_path) { vec_wsum[((j80) * nrvc) + ixyz] = sumy * delks; } // ixyz loop } // j80 loop +#ifdef USE_TARGET_OFFLOAD + omp_target_free(global_w, device_id); + omp_target_free(global_vec_w, device_id); +#else delete[] global_w; delete[] global_vec_w; +#endif // USE_TARGET_OFFLOAD #ifdef USE_NVTX nvtxRangePop(); nvtxRangePush("Closing operations"); -- GitLab