diff --git a/src/cluster/cluster.cpp b/src/cluster/cluster.cpp index 75e053cd61fd0b94303a6a4247ebabc26ede805e..cc309294071ec31e26f8ad89f2dce648a01197d9 100644 --- a/src/cluster/cluster.cpp +++ b/src/cluster/cluster.cpp @@ -33,8 +33,12 @@ #endif #ifdef USE_LAPACK +#ifdef USE_MKL +#include <mkl_lapacke.h> +#else +#include <lapacke.h> +#endif #ifndef INCLUDE_LAPACK_CALLS_H_ -#include "lapacke.h" #include "../include/lapack_calls.h" #endif #endif diff --git a/src/include/algebraic.h b/src/include/algebraic.h index 8c0cf1242a06aaa3d95885bd05cb02ef357ccca7..2ecc680f576093d84183cb005e6a8658737a450c 100644 --- a/src/include/algebraic.h +++ b/src/include/algebraic.h @@ -10,7 +10,13 @@ * legacy serial function implementation is used as a fall-back. */ -#ifndef lapack_int +#ifdef USE_LAPACK +#ifdef USE_MKL +#include <mkl_lapacke.h> +#else +#include <lapacke.h> +#endif +#else #define lapack_int int64_t #endif diff --git a/src/libnptm/algebraic.cpp b/src/libnptm/algebraic.cpp index a1c420315ee5535df368f4c9a6736e6f068ecd9d..9d62cccd58df9e49738df630d325ed3b87592d2b 100644 --- a/src/libnptm/algebraic.cpp +++ b/src/libnptm/algebraic.cpp @@ -9,10 +9,16 @@ #include "../include/algebraic.h" #endif +#ifdef USE_LAPACK +#ifdef USE_MKL +#include <mkl_lapacke.h> +#else +#include <lapacke.h> +#endif #ifndef INCLUDE_LAPACK_CALLS_H_ -#include "lapacke.h" #include "../include/lapack_calls.h" #endif +#endif // >>> FALL-BACK FUNCTIONS DECLARATION <<< // extern void lucin(std::complex<double> **mat, int64_t max_size, int64_t size, int &ier); diff --git a/src/libnptm/lapack_calls.cpp b/src/libnptm/lapack_calls.cpp index 8ec06e9899f7486488b21ac57f6e5ae46477f49d..f6bedff11f251c732f60f5c4d292a3b49cb2668c 100644 --- a/src/libnptm/lapack_calls.cpp +++ b/src/libnptm/lapack_calls.cpp @@ -1,16 +1,25 @@ #include <complex> #include <complex.h> +#ifdef USE_LAPACK +#ifdef USE_MKL +#include <mkl_lapacke.h> +#else +#include <lapacke.h> +#endif #ifndef INCLUDE_LAPACK_CALLS_H_ -#include "lapacke.h" #include "../include/lapack_calls.h" #endif +#endif -#ifdef USE_LAPACK void zinvert(std::complex<double> **mat, lapack_int n, int &jer) { +#ifdef USE_LAPACK jer = 0; __complex__ double *arr = new __complex__ double[n * n]; const __complex__ double uim = 1.0*I; +#ifdef USE_MKL + MKL_Complex16 *arr2 = (MKL_Complex16 *) arr; +#endif for (lapack_int i = 0; i < n; i++) { for (lapack_int j = 0; j < n; j++) { lapack_int idx = i + n * j; @@ -20,8 +29,13 @@ void zinvert(std::complex<double> **mat, lapack_int n, int &jer) { lapack_int* IPIV = new lapack_int[n](); +#ifdef USE_MKL + LAPACKE_zgetrf(LAPACK_ROW_MAJOR, n, n, arr2, n, IPIV); + LAPACKE_zgetri(LAPACK_ROW_MAJOR, n, arr2, n, IPIV); +#else LAPACKE_zgetrf(LAPACK_ROW_MAJOR, n, n, arr, n, IPIV); LAPACKE_zgetri(LAPACK_ROW_MAJOR, n, arr, n, IPIV); +#endif for (lapack_int i = 0; i < n; i++) { for (lapack_int j = 0; j < n; j++) { lapack_int idx = i + n * j; @@ -30,5 +44,5 @@ void zinvert(std::complex<double> **mat, lapack_int n, int &jer) { } delete[] IPIV; delete[] arr; -} #endif +} diff --git a/src/make.inc b/src/make.inc index 712b27f398f0d680638d07fe62c12ba24dab4c6d..2ba3d6b135da8c1627da9b0e5504c6c266cc8777 100644 --- a/src/make.inc +++ b/src/make.inc @@ -77,7 +77,7 @@ override CXXFLAGS=-O3 -ggdb -pg -coverage -I$(HDF5_INCLUDE) ifdef USE_LAPACK override CXXFLAGS+= -DUSE_LAPACK -DLAPACK_ILP64 ifdef USE_MKL -override CXXFLAGS+= -DMKL_ILP64 -I$(MKLROOT)/include +override CXXFLAGS+= -DMKL_ILP64 -DUSE_MKL -I$(MKLROOT)/include endif endif endif