From d058d28e2d75055801bf91e66f0b705eb65a6b0d Mon Sep 17 00:00:00 2001
From: Giovanni La Mura <giovanni.lamura@inaf.it>
Date: Tue, 6 Feb 2024 12:34:08 +0100
Subject: [PATCH] Fix transition matrix element output to binary format

---
 src/libnptm/TransitionMatrix.cpp | 70 +++++++++++++++++++++++---------
 1 file changed, 50 insertions(+), 20 deletions(-)

diff --git a/src/libnptm/TransitionMatrix.cpp b/src/libnptm/TransitionMatrix.cpp
index 29efb578..9db85027 100644
--- a/src/libnptm/TransitionMatrix.cpp
+++ b/src/libnptm/TransitionMatrix.cpp
@@ -110,22 +110,22 @@ TransitionMatrix* TransitionMatrix::from_hdf5(string file_name) {
     // This vector will be passed to the new object. DO NOT DELETE HERE!
     complex<double> *_elements;
     double _radius = 0.0;
-    status = hdf_file->read("/IS", "INT32", &_is);
-    status = hdf_file->read("/L_MAX", "INT32", &_lm);
-    status = hdf_file->read("/VK", "FLOAT64", &_vk);
-    status = hdf_file->read("/EXRI", "FLOAT64", &_exri);
+    status = hdf_file->read("IS", "INT32", &_is);
+    status = hdf_file->read("L_MAX", "INT32", &_lm);
+    status = hdf_file->read("VK", "FLOAT64", &_vk);
+    status = hdf_file->read("EXRI", "FLOAT64", &_exri);
     if (_is == 1111) {
       int num_elements = 2 * _lm;
       double *file_vector = new double[2 * num_elements]();
       hid_t file_id = hdf_file->get_file_id();
-      hid_t dset_id = H5Dopen2(file_id, "/ELEMENTS", H5P_DEFAULT);
+      hid_t dset_id = H5Dopen2(file_id, "ELEMENTS", H5P_DEFAULT);
       status = H5Dread(dset_id, H5T_NATIVE_DOUBLE, H5S_ALL, H5S_ALL, H5P_DEFAULT, file_vector);
       _elements = new complex<double>[num_elements]();
       for (int ei = 0; ei < num_elements; ei++) {
 	_elements[ei] = complex<double>(file_vector[2 * ei], file_vector[2 * ei + 1]);
       }
       status = H5Dclose(dset_id);
-      status = hdf_file->read("/RADIUS", "FLOAT64", &_radius);
+      status = hdf_file->read("RADIUS", "FLOAT64", &_radius);
       tm = new TransitionMatrix(_is, _lm, _vk, _exri, _elements, _radius);
       delete[] file_vector;
     } else if (_is == 1) {
@@ -133,14 +133,14 @@ TransitionMatrix* TransitionMatrix::from_hdf5(string file_name) {
       int num_elements = nlemt * nlemt;
       double *file_vector = new double[2 * num_elements]();
       hid_t file_id = hdf_file->get_file_id();
-      hid_t dset_id = H5Dopen2(file_id, "/ELEMENTS", H5P_DEFAULT);
+      hid_t dset_id = H5Dopen2(file_id, "ELEMENTS", H5P_DEFAULT);
       status = H5Dread(dset_id, H5T_NATIVE_DOUBLE, H5S_ALL, H5S_ALL, H5P_DEFAULT, file_vector);
       _elements = new complex<double>[num_elements]();
       for (int ei = 0; ei < num_elements; ei++) {
 	_elements[ei] = complex<double>(file_vector[2 * ei], file_vector[2 * ei + 1]);
       }
       status = H5Dclose(dset_id);
-      status = hdf_file->read("/RADIUS", "FLOAT64", &_radius);
+      status = hdf_file->read("RADIUS", "FLOAT64", &_radius);
       tm = new TransitionMatrix(_is, _lm, _vk, _exri, _elements, _radius);
       delete[] file_vector;
     }
@@ -215,33 +215,32 @@ void TransitionMatrix::write_hdf5(string file_name) {
     List<string> rec_type_list(1);
     List<void *> rec_ptr_list(1);
     string str_type, str_name;
-    rec_name_list.set(0, "/IS");
+    rec_name_list.set(0, "IS");
     rec_type_list.set(0, "INT32_(1)");
     rec_ptr_list.set(0, &is);
-    rec_name_list.append("/L_MAX");
+    rec_name_list.append("L_MAX");
     rec_type_list.append("INT32_(1)");
     rec_ptr_list.append(&l_max);
-    rec_name_list.append("/VK");
+    rec_name_list.append("VK");
     rec_type_list.append("FLOAT64_(1)");
     rec_ptr_list.append(&vk);
-    rec_name_list.append("/EXRI");
+    rec_name_list.append("EXRI");
     rec_type_list.append("FLOAT64_(1)");
     rec_ptr_list.append(&exri);
-    rec_name_list.append("/ELEMENTS");
+    rec_name_list.append("ELEMENTS");
     str_type = "FLOAT64_(" + to_string(shape[0]) + "," + to_string(2 * shape[1]) + ")";
     rec_type_list.append(str_type);
     // The (N x M) matrix of complex is converted to a (N x 2M) matrix of double,
     // where REAL(E_i,j) -> E_i,(2 j) and IMAG(E_i,j) -> E_i,(2 j + 1)
-    double *ptr_elements = new double[shape[0] * 2 * shape[1]]();
-    for (int ei = 0; ei < shape[0]; ei++) {
-      for (int ej = 0; ej < shape[1]; ej++) {
-	ptr_elements[shape[1] * ei + 2 * ej] = elements[shape[1] * ei + ej].real();
-	ptr_elements[shape[1] * ei + 2 * ej + 1] = elements[shape[1] * ei  + ej].imag();
-      }
+    int num_elements = 2 * shape[0] * shape[1];
+    double *ptr_elements = new double[num_elements]();
+    for (int ei = 0; ei < num_elements / 2; ei++) {
+      ptr_elements[2 * ei] = elements[ei].real();
+      ptr_elements[2 * ei + 1] = elements[ei].imag();
     }
     rec_ptr_list.append(ptr_elements);
     if (is == 1111) {
-      rec_name_list.append("/RADIUS");
+      rec_name_list.append("RADIUS");
       rec_type_list.append("FLOAT64_(1)");
       rec_ptr_list.append(&sphere_radius);
     }
@@ -299,3 +298,34 @@ void TransitionMatrix::write_legacy(string file_name) {
     printf("ERROR: could not open Transition Matrix file for writing.\n");
   }
 }
+
+bool TransitionMatrix::operator ==(TransitionMatrix &other) {
+  if (is != other.is) {
+    return false;
+  }
+  if (l_max != other.l_max) {
+    return false;
+  }
+  if (vk != other.vk) {
+    return false;
+  }
+  if (exri != other.exri) {
+    return false;
+  }
+  if (sphere_radius != other.sphere_radius) {
+    return false;
+  }
+  if (shape[0] != other.shape[0]) {
+    return false;
+  }
+  if (shape[1] != other.shape[1]) {
+    return false;
+  }
+  int num_elements = shape[0] * shape[1];
+  for (int ei = 0; ei < num_elements; ei++) {
+    if (elements[ei] != other.elements[ei]) {
+      return false;
+    }
+  }
+  return true;
+}
-- 
GitLab