/* Copyright (C) 2024   INAF - Osservatorio Astronomico di Cagliari

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.
   
   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.
   
   A copy of the GNU General Public License is distributed along with
   this program in the COPYING file. If not, see: <https://www.gnu.org/licenses/>.
 */

/*! \file test_zinvert.cpp
 *
 * \brief Test of matrix inversion process.
 */
#include <cstdio>

#ifndef INCLUDE_TYPES_H_
#include "../include/types.h"
#endif

#ifdef USE_MAGMA
#include "../include/magma_calls.h"
#define SUCCESS MAGMA_SUCCESS
#else
#define SUCCESS 0
#endif

#ifdef USE_MAGMA_SVD
#include "magma_operators.h"
#endif

#include "../include/algebraic.h"

/*! \brief Multiply two dense complex matrices stored in row-major order.
 *
 * Computes the product a1 * a2. Element (i, j) of a matrix with C columns
 * is stored at index i * C + j. The inner sum for each output element is
 * accumulated in increasing order of the shared index.
 *
 * \param a1: `dcomplex *` Left operand, r1 rows by c1 columns.
 * \param r1: `magma_int_t` Number of rows of a1.
 * \param c1: `magma_int_t` Number of columns of a1.
 * \param a2: `dcomplex *` Right operand, r2 rows by c2 columns.
 * \param r2: `magma_int_t` Number of rows of a2.
 * \param c2: `magma_int_t` Number of columns of a2.
 * \return result: `dcomplex *` A newly allocated r1 x c2 array holding the
 * product (the caller releases it with `delete[]`), or NULL when the inner
 * dimensions disagree (c1 != r2).
 */
dcomplex* matrix_mult(
  dcomplex *a1, const magma_int_t r1, const magma_int_t c1,
  dcomplex *a2, const magma_int_t r2, const magma_int_t c2
) {
  // The product is only defined when the inner dimensions agree.
  if (c1 != r2) return NULL;

  dcomplex *product = new dcomplex[r1 * c2]();
  for (magma_int_t row = 0; row < r1; row++) {
    const dcomplex *left_row = a1 + row * c1;
    dcomplex *out_row = product + row * c2;
    for (magma_int_t col = 0; col < c2; col++) {
      dcomplex acc = 0.0 + I * 0.0;
      for (magma_int_t inner = 0; inner < c1; inner++)
        acc += left_row[inner] * a2[inner * c2 + col];
      out_row[col] = acc;
    }
  }
  return product;
}

/*! \brief Print a row-major complex matrix to standard output.
 *
 * Every element is written as a tab-prefixed "(real,imag)" pair using the
 * "%.5le" conversion, and each row is terminated by a newline.
 *
 * \param a: `dcomplex *` Matrix elements, row-major, rows * columns entries.
 * \param rows: `magma_int_t` Number of rows to print.
 * \param columns: `magma_int_t` Number of columns per row.
 */
void print_matrix(dcomplex *a, const magma_int_t rows, const magma_int_t columns) {
  // Row-major storage: walking rows then columns visits a[] sequentially.
  magma_int_t index = 0;
  for (magma_int_t r = 0; r < rows; r++) {
    for (magma_int_t c = 0; c < columns; c++, index++) {
      const dcomplex element = a[index];
      printf("\t(%.5le,%.5le)", real(element), imag(element));
    }
    printf("\n");
  }
}

/*! \brief Self-test of the complex matrix inversion routine.
 *
 * Builds a 2x2 complex matrix, keeps a copy of it, calls invert_matrix() on
 * it and multiplies the resulting matrix by the saved original. The test
 * passes when every element of the product differs from the identity by at
 * most `tolerance`, checked separately on real and imaginary parts. The
 * original, inverted and product matrices are printed for diagnostics.
 *
 * \return result: `int` 0 on success; 1 when the product is not the
 * identity within tolerance (or could not be computed); 2 when, in a MAGMA
 * build, np_int and magma_int_t have different sizes. NOTE(review): `result`
 * is also passed as the third argument of invert_matrix() and may be
 * updated there (presumably an error status) - TODO confirm.
 */
int main() {
  int result = (int)SUCCESS;
#ifdef USE_MAGMA
  magma_init();
#endif
  const np_int n = 2;
  const np_int nn = n * n;
  // Maximum absolute deviation from the identity accepted on both the real
  // and the imaginary part of every element of the product.
  const double tolerance = 1.0e-6;
  dcomplex *vec_matrix = new dcomplex[nn];
  dcomplex *old_matrix = new dcomplex[nn];
  // Element ivi is (2 + ivi) * (1 - I / 2). A copy is kept because the code
  // below treats vec_matrix as holding the inverse after invert_matrix().
  for (int ivi = 0; ivi < nn; ivi++) {
    vec_matrix[ivi] = 2.0 + ivi - I * (2 + ivi) / 2.0;
    old_matrix[ivi] = vec_matrix[ivi];
  }
  // invert_matrix() receives the address of the matrix pointer.
  dcomplex **test_matrix = &vec_matrix;
#ifdef USE_MAGMA
  if (sizeof(np_int) != sizeof(magma_int_t)) {
    // sizeof yields size_t, whose printf conversion is %zu (%ld is a type
    // mismatch wherever long and size_t differ in width).
    printf("ERROR: sizeof(np_int) = %zu; sizeof(magma_int_t) = %zu\n",
           sizeof(np_int), sizeof(magma_int_t));
    result = 2;
  }
#endif
  if (result == 0) {
    invert_matrix(test_matrix, n, result, n);
    dcomplex *prod_matrix = matrix_mult(vec_matrix, n, n, old_matrix, n, n);
    if (prod_matrix == NULL) {
      // matrix_mult() returns NULL when the inner dimensions disagree;
      // never dereference it in that case.
      printf("ERROR: could not compute the product matrix!\n");
      result = 1;
    } else {
      for (int tvi = 0; tvi < n; tvi++) {
        for (int tvj = 0; tvj < n; tvj++) {
          // Deviation from the identity: subtract 1 on the diagonal only.
          dcomplex difference = (tvi == tvj) ? prod_matrix[n * tvi + tvj] - 1.0 : prod_matrix[n * tvi + tvj];
          if (real(difference) < -tolerance || real(difference) > tolerance) result = 1;
          if (imag(difference) < -tolerance || imag(difference) > tolerance) result = 1;
        }
      }
      if (result != 0) {
        printf("ERROR: failed matrix inversion test!\n");
      }
      printf("INFO: original matrix:\n");
      print_matrix(old_matrix, n, n);
      printf("INFO: inverted matrix:\n");
      print_matrix(vec_matrix, n, n);
      printf("INFO: product matrix:\n");
      print_matrix(prod_matrix, n, n);
      delete[] prod_matrix;
    }
  }
  delete[] vec_matrix;
  delete[] old_matrix;
#ifdef USE_MAGMA
  magma_finalize();
#endif
  return result;
}
