From 1640c3c5668c563353f3f9a93fb5aa51ba2510f8 Mon Sep 17 00:00:00 2001
From: Giovanni La Mura <giovanni.lamura@inaf.it>
Date: Thu, 25 Jan 2024 19:00:50 +0100
Subject: [PATCH] Implement HDF5 I/O of transition matrix for SPHERE

---
 src/include/TransitionMatrix.h   |  44 ++++++-
 src/include/file_io.h            |  17 +++
 src/libnptm/TransitionMatrix.cpp | 205 ++++++++++++++++++++++++++++++-
 src/libnptm/file_io.cpp          |  44 ++++++-
 src/sphere/sphere.cpp            |   6 +-
 5 files changed, 307 insertions(+), 9 deletions(-)

diff --git a/src/include/TransitionMatrix.h b/src/include/TransitionMatrix.h
index 3e08852e..155b627f 100644
--- a/src/include/TransitionMatrix.h
+++ b/src/include/TransitionMatrix.h
@@ -47,6 +47,22 @@ class TransitionMatrix {
   //! Matrix shape
   int *shape;
 
+  /*! \brief Build transition matrix from a HDF5 binary input file.
+   *
+   * \param file_name: `string` Name of the binary configuration data file.
+   * \return config: `TransitionMatrix *` Pointer to object containing the
+   * transition matrix data.
+   */
+  static TransitionMatrix *from_hdf5(std::string file_name);
+  
+  /*! \brief Build transition matrix from a legacy binary input file.
+   *
+   * \param file_name: `string` Name of the binary configuration data file.
+   * \return config: `TransitionMatrix *` Pointer to object containing the
+   * transition matrix data.
+   */
+  static TransitionMatrix *from_legacy(std::string file_name);
+
   /*! \brief Write the Transition Matrix to HDF5 binary output.
    *
    * \param file_name: `string` Name of the binary configuration data file.
@@ -59,6 +75,20 @@ class TransitionMatrix {
    */
   void write_legacy(std::string file_name);
  public:
+  /*! \brief Default Transition Matrix instance constructor.
+   *
+   * \param _is: `int` Matrix type identifier
+   * \param _lm: `int` Maximum field expansion order.
+   * \param _vk: `double`
+   * \param _exri: `double`
+   * \param _elements: `complex<double> *` Vectorized elements of the matrix.
+   * \param _radius: `double` Radius for the single sphere case (defaults to 0.0).
+   */
+  TransitionMatrix(
+		   int _is, int _lm, double _vk, double _exri, std::complex<double> *_elements,
+		   double _radius=0.0
+  );
+
   /*! \brief Transition Matrix instance constructor for single sphere.
    *
    * \param _lm: `int` Maximum field expansion order.
@@ -71,7 +101,7 @@ class TransitionMatrix {
   TransitionMatrix(
 		   int _lm, double _vk, double _exri, std::complex<double> **_rmi,
 		   std::complex<double> **_rei, double _sphere_radius
-		   );
+  );
 
   /*! \brief Transition Matrix instance constructor for a cluster of spheres.
    *
@@ -81,12 +111,22 @@ class TransitionMatrix {
    * \param _exri: `double`
    * \param _am0m: Matrix of complex.
    */
-  TransitionMatrix(int _nlemt, int _lm, double _vk, double _exri, std::complex<double> **am0m);
+  TransitionMatrix(int _nlemt, int _lm, double _vk, double _exri, std::complex<double> **_am0m);
 
   /*! \brief Transition Matrix instance destroyer.
    */
   ~TransitionMatrix();
   
+  /*! \brief Build transition matrix from binary input file.
+   *
+   * \param file_name: `string` Name of the binary configuration data file.
+   * \param mode: `string` Binary encoding. Can be one of ["LEGACY", "HDF5"]. Optional
+   * (default is "LEGACY").
+   * \return config: `TransitionMatrix *` Pointer to object containing the transition
+   * matrix data.
+   */
+  static TransitionMatrix* from_binary(std::string file_name, std::string mode="LEGACY");
+
   /*! \brief Write the Transition Matrix to a binary file.
    *
    * \param file_name: `string` Name of the file to be written.
diff --git a/src/include/file_io.h b/src/include/file_io.h
index 0a0aad2a..27ef8129 100644
--- a/src/include/file_io.h
+++ b/src/include/file_io.h
@@ -120,6 +120,23 @@ class HDFFile {
    */
   bool is_open() { return file_open_flag; }
   
+  /*! \brief Read data from attached file.
+   *
+   * \param dataset_name: `string` Name of the dataset to read from.
+   * \param data_type: `string` Memory data type identifier.
+   * \param buffer: `hid_t` Starting address of the memory sector to store the data.
+   * \param mem_space_id: `hid_t` Memory data space identifier (defaults to `H5S_ALL`).
+   * \param file_space_id: `hid_t` File space identifier (defaults to `H5S_ALL`).
+   * \param dapl_id: `hid_t` Data access property list identifier (defaults to `H5P_DEFAULT`).
+   * \param dxpl_id: `hid_t` Data transfer property list identifier (defaults to `H5P_DEFAULT`).
+   * \return status: `herr_t` Exit status of the operation.
+   */
+  herr_t read(
+	       std::string dataset_name, std::string data_type, void *buffer,
+	       hid_t mem_space_id=H5S_ALL, hid_t file_space_id=H5S_ALL,
+	       hid_t dapl_id=H5P_DEFAULT, hid_t dxpl_id=H5P_DEFAULT
+  );
+
   /*! \brief Write data to attached file.
    *
    * \param dataset_name: `string` Name of the dataset to write to.
diff --git a/src/libnptm/TransitionMatrix.cpp b/src/libnptm/TransitionMatrix.cpp
index a9a08cf7..f66daa5d 100644
--- a/src/libnptm/TransitionMatrix.cpp
+++ b/src/libnptm/TransitionMatrix.cpp
@@ -5,11 +5,20 @@
 #include <complex>
 #include <exception>
 #include <fstream>
+#include <hdf5.h>
+
+#ifndef INCLUDE_LIST_H_
+#include "../include/List.h"
+#endif
 
 #ifndef INCLUDE_TRANSITIONMATRIX_H_
 #include "../include/TransitionMatrix.h"
 #endif
 
+#ifndef INCLUDE_FILE_IO_H_
+#include "../include/file_io.h"
+#endif
+
 using namespace std;
 
 TransitionMatrix::~TransitionMatrix() {
@@ -17,6 +26,27 @@ TransitionMatrix::~TransitionMatrix() {
   if (shape != NULL) delete[] shape;
 }
 
+TransitionMatrix::TransitionMatrix(
+				   int _is, int _lm, double _vk, double _exri, std::complex<double> *_elements,
+				   double _radius
+) {
+  is = _is;
+  l_max = _lm;
+  vk = _vk;
+  exri = _exri;
+  elements = _elements;
+  sphere_radius = _radius;
+  shape = new int[2]();
+  if (is == 1111) {
+    shape[0] = l_max;
+    shape[1] = 2;
+  } else if (is == 1) {
+    const int nlemt = 2 * l_max * (l_max + 2);
+    shape[0] = nlemt;
+    shape[1] = nlemt;
+  }
+}
+
 TransitionMatrix::TransitionMatrix(
 				   int _lm, double _vk, double _exri, complex<double> **_rmi,
 				   complex<double> **_rei, double _sphere_radius
@@ -38,7 +68,7 @@ TransitionMatrix::TransitionMatrix(
 
 TransitionMatrix::TransitionMatrix(
 				   int _nlemt, int _lm, double _vk, double _exri,
-				   std::complex<double> **am0m
+				   std::complex<double> **_am0m
 ) {
   is = 1;
   shape = new int[2];
@@ -50,8 +80,122 @@ TransitionMatrix::TransitionMatrix(
   sphere_radius = 0.0;
   elements = new complex<double>[_nlemt * _nlemt]();
   for (int ei = 0; ei < _nlemt; ei++) {
-    for (int ej = 0; ej < _nlemt; ej++) elements[_nlemt * ei + ej] = am0m[ei][ej];
+    for (int ej = 0; ej < _nlemt; ej++) elements[_nlemt * ei + ej] = _am0m[ei][ej];
+  }
+}
+
+TransitionMatrix* TransitionMatrix::from_binary(string file_name, string mode) {
+  TransitionMatrix *tm = NULL;
+  if (mode.compare("LEGACY") == 0) {
+    tm = TransitionMatrix::from_legacy(file_name);
+  } else if (mode.compare("HDF5") == 0) {
+    tm = TransitionMatrix::from_hdf5(file_name);
+  } else {
+    string message = "Unknown format mode: \"" + mode + "\"";
+    throw UnrecognizedFormatException(message);
   }
+  return tm;
+}
+
+TransitionMatrix* TransitionMatrix::from_hdf5(string file_name) {
+  TransitionMatrix *tm = NULL;
+  unsigned int flags = H5F_ACC_RDONLY;
+  HDFFile *hdf_file = new HDFFile(file_name, flags);
+  herr_t status = hdf_file->get_status();
+  if (status == 0) {
+    int _is;
+    int _lm;
+    double _vk;
+    double _exri;
+    // This vector will be passed to the new object. DO NOT DELETE HERE!
+    complex<double> *_elements;
+    double _radius = 0.0;
+    status = hdf_file->read("/IS", "INT32", &_is);
+    status = hdf_file->read("/L_MAX", "INT32", &_lm);
+    status = hdf_file->read("/VK", "FLOAT64", &_vk);
+    status = hdf_file->read("/EXRI", "FLOAT64", &_exri);
+    if (_is == 1111) {
+      int num_elements = 2 * _lm;
+      double *file_vector = new double[2 * num_elements]();
+      hid_t file_id = hdf_file->get_file_id();
+      hid_t dset_id = H5Dopen2(file_id, "/ELEMENTS", H5P_DEFAULT);
+      status = H5Dread(dset_id, H5T_NATIVE_DOUBLE, H5S_ALL, H5S_ALL, H5P_DEFAULT, file_vector);
+      _elements = new complex<double>[num_elements]();
+      for (int ei = 0; ei < num_elements; ei++) {
+	_elements[ei] = complex<double>(file_vector[2 * ei], file_vector[2 * ei + 1]);
+      }
+      status = H5Dclose(dset_id);
+      status = hdf_file->read("/RADIUS", "FLOAT64", &_radius);
+      tm = new TransitionMatrix(_is, _lm, _vk, _exri, _elements, _radius);
+      delete[] file_vector;
+    } else if (_is == 1) {
+      int nlemt = 2 * _lm * (_lm + 2);
+      int num_elements = nlemt * nlemt;
+      double *file_vector = new double[2 * num_elements]();
+      hid_t file_id = hdf_file->get_file_id();
+      hid_t dset_id = H5Dopen2(file_id, "/ELEMENTS", H5P_DEFAULT);
+      status = H5Dread(dset_id, H5T_NATIVE_DOUBLE, H5S_ALL, H5S_ALL, H5P_DEFAULT, file_vector);
+      _elements = new complex<double>[num_elements]();
+      for (int ei = 0; ei < num_elements; ei++) {
+	_elements[ei] = complex<double>(file_vector[2 * ei], file_vector[2 * ei + 1]);
+      }
+      status = H5Dclose(dset_id);
+      status = hdf_file->read("/RADIUS", "FLOAT64", &_radius);
+      tm = new TransitionMatrix(_is, _lm, _vk, _exri, _elements, _radius);
+      delete[] file_vector;
+    }
+    status = hdf_file->close();
+  } else {
+    printf("ERROR: could not open file \"%s\"\n", file_name.c_str());
+  }
+  return tm;
+}
+
+TransitionMatrix* TransitionMatrix::from_legacy(string file_name) {
+  fstream ttms;
+  TransitionMatrix *tm = NULL;
+  ttms.open(file_name, ios::binary | ios::in);
+  if (ttms.is_open()) {
+    int num_elements = 0;
+    int _is;
+    int _lm;
+    double _vk;
+    double _exri;
+    // This vector will be passed to the new object. DO NOT DELETE HERE!
+    complex<double> *_elements;
+    double _radius = 0.0;
+    ttms.read(reinterpret_cast<char *>(&_is), sizeof(int));
+    ttms.read(reinterpret_cast<char *>(&_lm), sizeof(int));
+    ttms.read(reinterpret_cast<char *>(&_vk), sizeof(double));
+    ttms.read(reinterpret_cast<char *>(&_exri), sizeof(double));
+    if (_is == 1111) {
+      num_elements = _lm * 2;
+      _elements = new complex<double>[num_elements]();
+      for (int ei = 0; ei < num_elements; ei++) {
+	double vreal, vimag;
+	ttms.read(reinterpret_cast<char *>(&vreal), sizeof(double));
+	ttms.read(reinterpret_cast<char *>(&vimag), sizeof(double));
+	_elements[ei] = complex<double>(vreal, vimag);
+      }
+      double _radius;
+      ttms.read(reinterpret_cast<char *>(&_radius), sizeof(double));
+      tm = new TransitionMatrix(_is, _lm, _vk, _exri, _elements, _radius);
+    } else if (_is == 1) {
+      int nlemt = 2 * _lm * (_lm + 2);
+      num_elements = nlemt * nlemt;
+      _elements = new complex<double>[num_elements]();
+      for (int ei = 0; ei < num_elements; ei++) {
+	double vreal, vimag;
+	ttms.read(reinterpret_cast<char *>(&vreal), sizeof(double));
+	ttms.read(reinterpret_cast<char *>(&vimag), sizeof(double));
+	_elements[ei] = complex<double>(vreal, vimag);
+      }
+      tm = new TransitionMatrix(_is, _lm, _vk, _exri, _elements);
+    }
+  } else {
+    printf("ERROR: could not open file \"%s\"\n", file_name.c_str());
+  }
+  return tm;
 }
 
 void TransitionMatrix::write_binary(string file_name, string mode) {
@@ -66,8 +210,61 @@ void TransitionMatrix::write_binary(string file_name, string mode) {
 }
 
 void TransitionMatrix::write_hdf5(string file_name) {
-  // TODO: needs implementation.
-  return;
+  if (is == 1 || is == 1111) {
+    List<string> rec_name_list(1);
+    List<string> rec_type_list(1);
+    List<void *> rec_ptr_list(1);
+    string str_type, str_name;
+    rec_name_list.set(0, "/IS");
+    rec_type_list.set(0, "INT32_(1)");
+    rec_ptr_list.set(0, &is);
+    rec_name_list.append("/L_MAX");
+    rec_type_list.append("INT32_(1)");
+    rec_ptr_list.append(&l_max);
+    rec_name_list.append("/VK");
+    rec_type_list.append("FLOAT64_(1)");
+    rec_ptr_list.append(&vk);
+    rec_name_list.append("/EXRI");
+    rec_type_list.append("FLOAT64_(1)");
+    rec_ptr_list.append(&exri);
+    rec_name_list.append("/ELEMENTS");
+    str_type = "FLOAT64_(" + to_string(shape[0]) + "," + to_string(2 * shape[1]) + ")";
+    rec_type_list.append(str_type);
+    // The (N x M) matrix of complex is converted to a (N x 2M) matrix of double,
+    // where REAL(E_i,j) -> E_i,(2 j) and IMAG(E_i,j) -> E_i,(2 j + 1)
+    double *ptr_elements = new double[shape[0] * 2 * shape[1]]();
+    for (int ei = 0; ei < shape[0]; ei++) {
+      for (int ej = 0; ej < shape[1]; ej++) {
+	ptr_elements[ei * (2 * ej)] = elements[shape[1] * ei + ej].real();
+	ptr_elements[ei * (2 * ej) + 1] = elements[shape[1] * ei  + ej].imag();
+      }
+    }
+    rec_ptr_list.append(ptr_elements);
+    if (is == 1111) {
+      rec_name_list.append("/RADIUS");
+      rec_type_list.append("FLOAT64_(1)");
+      rec_ptr_list.append(&sphere_radius);
+    }
+
+    string *rec_names = rec_name_list.to_array();
+    string *rec_types = rec_type_list.to_array();
+    void **rec_pointers = rec_ptr_list.to_array();
+    const int rec_num = rec_name_list.length();
+    FileSchema schema(rec_num, rec_types, rec_names);
+    HDFFile *hdf_file = HDFFile::from_schema(schema, file_name, H5F_ACC_TRUNC);
+    for (int ri = 0; ri < rec_num; ri++)
+      hdf_file->write(rec_names[ri], rec_types[ri], rec_pointers[ri]);
+    hdf_file->close();
+    
+    delete[] ptr_elements;
+    delete[] rec_names;
+    delete[] rec_types;
+    delete[] rec_pointers;
+    delete hdf_file;
+  } else {
+    string message = "Unrecognized matrix data.";
+    throw UnrecognizedFormatException(message);
+  }
 }
 
 void TransitionMatrix::write_legacy(string file_name) {
diff --git a/src/libnptm/file_io.cpp b/src/libnptm/file_io.cpp
index a5df8a6a..0c0b0c17 100644
--- a/src/libnptm/file_io.cpp
+++ b/src/libnptm/file_io.cpp
@@ -47,7 +47,10 @@ string* FileSchema::get_record_types() {
 
 HDFFile::HDFFile(string name, unsigned int flags, hid_t fcpl_id, hid_t fapl_id) {
   file_name = name;
-  file_id = H5Fcreate(name.c_str(), flags, fcpl_id, fapl_id);
+  if (flags == H5F_ACC_EXCL || flags == H5F_ACC_TRUNC)
+    file_id = H5Fcreate(name.c_str(), flags, fcpl_id, fapl_id);
+  else if (flags == H5F_ACC_RDONLY || flags == H5F_ACC_RDWR)
+    file_id = H5Fopen(name.c_str(), flags, fapl_id);
   id_list = new List<hid_t>(1);
   id_list->set(0, file_id);
   if (file_id != H5I_INVALID_HID) file_open_flag = true;
@@ -132,6 +135,45 @@ HDFFile* HDFFile::from_schema(
   return hdf_file;
 }
 
+herr_t HDFFile::read(
+		      string dataset_name, string data_type, void *buffer,
+		      hid_t mem_space_id, hid_t file_space_id, hid_t dapl_id,
+		      hid_t dxpl_id
+) {
+  string known_types[] = {"INT32", "FLOAT64"};
+  regex re;
+  smatch m;
+  bool found_type = false;
+  int type_index = 0;
+  while (!found_type) {
+    re = regex(known_types[type_index++]);
+    found_type = regex_search(data_type, m, re);
+    if (type_index == 2) break;
+  }
+  if (found_type) {
+    hid_t dataset_id = H5Dopen2(file_id, dataset_name.c_str(), dapl_id);
+    hid_t mem_type_id;
+    switch (type_index) {
+    case 1:
+      mem_type_id = H5T_NATIVE_INT; break;
+    case 2:
+      mem_type_id = H5T_NATIVE_DOUBLE; break;
+    default:
+      throw runtime_error("Unrecognized data type \"" + data_type + "\"");
+    }
+    if (dataset_id != H5I_INVALID_HID) {
+      status = H5Dread(dataset_id, mem_type_id, mem_space_id, file_space_id, dxpl_id, buffer);
+      if (status == 0) status = H5Dclose(dataset_id);
+      else status = (herr_t)-2;
+    } else {
+      status = (herr_t)-1;
+    }
+  } else {
+    throw runtime_error("Unrecognized data type \"" + data_type + "\"");
+  }
+  return status;
+}
+
 herr_t HDFFile::write(
 		      string dataset_name, string data_type, const void *buffer,
 		      hid_t mem_space_id, hid_t file_space_id, hid_t dapl_id,
diff --git a/src/sphere/sphere.cpp b/src/sphere/sphere.cpp
index 129b2f37..8c6ac483 100644
--- a/src/sphere/sphere.cpp
+++ b/src/sphere/sphere.cpp
@@ -292,8 +292,10 @@ void sphere(string config_file, string data_file, string output_path) {
 	if (sconf->idfc >= 0 and nsph == 1 and jxi == gconf->jwtm) {
 	  // This is the condition that writes the transition matrix to output.
 	  TransitionMatrix ttms(gconf->l_max, vk, exri, c1->rmi, c1->rei, sconf->radii_of_spheres[0]);
-	  string ttms_name = output_path + "/c_TTMS";
-	  ttms.write_binary(ttms_name, "LEGACY");
+	  string ttms_name = output_path + "/c_TTMS.hd5";
+	  ttms.write_binary(ttms_name, "HDF5");
+	  ttms_name = output_path + "/c_TTMS";
+	  ttms.write_binary(ttms_name);
 	}
 	double cs0 = 0.25 * vk * vk * vk / half_pi;
 	//printf("DEBUG: cs0 = %lE\n", cs0);
-- 
GitLab