From 30c98dd590a1093ecebab23c111b66f7fbc4d426 Mon Sep 17 00:00:00 2001
From: Giovanni La Mura <giovanni.lamura@inaf.it>
Date: Fri, 11 Apr 2025 13:06:46 +0200
Subject: [PATCH] Enable OBJ model export with group-based color scaling

---
 src/scripts/model_maker.py | 68 ++++++++++++++++++++++++--------------
 1 file changed, 43 insertions(+), 25 deletions(-)

diff --git a/src/scripts/model_maker.py b/src/scripts/model_maker.py
index 555e322..2354daf 100755
--- a/src/scripts/model_maker.py
+++ b/src/scripts/model_maker.py
@@ -36,7 +36,7 @@ except ModuleNotFoundError as ex:
     print("WARNING: pyvista not found!")
     allow_3d = False
 
-from pathlib import PurePath
+from pathlib import Path
 from sys import argv
 
 ## \brief Main execution code
@@ -69,7 +69,7 @@ def interpolate_constants(sconf):
     for i in range(sconf['configurations']):
         for j in range(sconf['nshl'][i]):
             file_idx = sconf['dielec_id'][i][j]
-            dielec_path = PurePath(sconf['dielec_path'], sconf['dielec_file'][int(file_idx) - 1])
+            dielec_path = Path(sconf['dielec_path'], sconf['dielec_file'][int(file_idx) - 1])
             file_name = str(dielec_path)
             dielec_file = open(file_name, 'r')
             wavelengths = []
@@ -149,7 +149,7 @@ def load_model(model_file):
             make_3d = False
         # Create the sconf dict
         sconf = {
-            'out_file': PurePath(
+            'out_file': Path(
                 model['input_settings']['input_folder'],
                 model['input_settings']['spheres_file']
             )
@@ -316,7 +316,7 @@ def load_model(model_file):
             print("ERROR: %s is not a recognized polarization state."%str_polar)
             return (None, None)
         gconf = {
-            'out_file': PurePath(
+            'out_file': Path(
                 model['input_settings']['input_folder'],
                 model['input_settings']['geometry_file']
             )
@@ -404,7 +404,7 @@ def match_grid(sconf):
             layers += 1
         for j in range(layers):
             file_idx = sconf['dielec_id'][i][j]
-            dielec_path = PurePath(sconf['dielec_path'], sconf['dielec_file'][int(file_idx) - 1])
+            dielec_path = Path(sconf['dielec_path'], sconf['dielec_file'][int(file_idx) - 1])
             file_name = str(dielec_path)
             dielec_file = open(file_name, 'r')
             wavelengths = []
@@ -784,6 +784,9 @@ def write_legacy_sconf(conf):
 #  \param geometry: `dict` Geometry configuration dictionary (gets modified)
 #  \param max_rad: `float` Maximum allowed radial extension of the aggregate
 def write_obj(scatterer, geometry, max_rad):
+    out_dir = scatterer['out_file'].absolute().parent
+    out_model_path = Path(out_dir, "model.obj")
+    out_material_path = Path(out_dir, "model.mtl")
     color_strings = [
         "1.0 1.0 1.0\n", # white
         "1.0 0.0 0.0\n", # red
@@ -793,9 +796,9 @@ def write_obj(scatterer, geometry, max_rad):
     color_names = [
         "white", "red", "blue", "green"
     ]
-    mtl_file = open("model.mtl", "w")
+    mtl_file = open(str(out_material_path), "w")
     for mi in range(len(color_strings)):
-        mtl_line = "newmtl mtl{0:d}\n".format(mi)
+        mtl_line = "newmtl "  + color_names[mi] + "\n"
         mtl_file.write(mtl_line)
         color_line = color_strings[mi]
         mtl_file.write("   Ka " + color_line)
@@ -808,29 +811,44 @@ def write_obj(scatterer, geometry, max_rad):
     pl = pv.Plotter()
     for si in range(scatterer['nsph']):
         sph_type_index = scatterer['vec_types'][si]
-        color_by_name = color_names[sph_type_index]
+        # color_index = 1 + (sph_type_index % (len(color_strings) - 1))
+        # color_by_name = color_names[sph_type_index]
         radius = scatterer['ros'][sph_type_index - 1] / max_rad
         x = geometry['vec_sph_x'][si] / max_rad
         y = geometry['vec_sph_y'][si] / max_rad
         z = geometry['vec_sph_z'][si] / max_rad
         mesh = pv.Sphere(radius, (x, y, z))
-        mesh.save("tmp_mesh.obj")
-        pl.add_mesh(mesh) #, color=color_by_name)
-        mesh_name = "sphere_{0:04d}.obj".format(si)
-        in_obj_file = open("tmp_mesh.obj", "r")
-        out_obj_file = open(mesh_name, "w")
-        in_line = in_obj_file.readline()
-        out_obj_file.write(in_line)
-        out_obj_file.write("mtllib model.mtl\n")
-        out_obj_file.write("usemtl mtl{0:d}\n".format(sph_type_index))
-        while (in_line != ""):
-            in_line = in_obj_file.readline()
-            out_obj_file.write(in_line)
-        in_obj_file.close()
-        out_obj_file.close()
-    pl.export_obj("model.obj")
-    os.remove("tmp_mesh.obj")
-    
+        pl.add_mesh(mesh, color=None)
+    pl.export_obj(str(Path(str(out_dir), "TMP_MODEL.obj")))
+    tmp_model_file = open(str(Path(str(out_dir), "TMP_MODEL.obj")), "r")
+    out_model_file = open(str(Path(str(out_dir), "model.obj")), "w")
+    sph_index = 0
+    sph_type_index = 0
+    old_sph_type_index = 0
+    str_line = tmp_model_file.readline()
+    while (str_line != ""):
+        if (str_line.startswith("mtllib")):
+            str_line = "mtllib model.mtl\n"
+        elif (str_line.startswith("g ")):
+            sph_index += 1
+            sph_type_index = scatterer['vec_types'][sph_index - 1]
+            if (sph_type_index == old_sph_type_index):
+                str_line = tmp_model_file.readline()
+                str_line = tmp_model_file.readline()
+            else:
+                old_sph_type_index = sph_type_index
+                color_index = sph_type_index % (len(color_names) - 1)
+                str_line = "g grp{0:04d}\n".format(sph_type_index)
+                out_model_file.write(str_line)
+                str_line = tmp_model_file.readline()
+                str_line = "usemtl {0:s}\n".format(color_names[color_index])
+        out_model_file.write(str_line)
+        str_line = tmp_model_file.readline()
+    out_model_file.close()
+    tmp_model_file.close()
+    os.remove(str(Path(str(out_dir), "TMP_MODEL.obj")))
+    os.remove(str(Path(str(out_dir), "TMP_MODEL.mtl")))
+
 ## \brief Exit code (0 for success)
 exit_code = main()
 exit(exit_code)
-- 
GitLab