From 52e208145ad759dc4669103b300c24c1e3a17d1e Mon Sep 17 00:00:00 2001
From: Kelvin Rodriguez <krodriguez@usgs.gov>
Date: Mon, 3 Apr 2023 11:25:50 -0700
Subject: [PATCH] added projection code (#524)

* added projection code

* fixed test

* added geotransform

* addded indentity matrix

* added to generic ISD

* adding getters

* addressed comment

* more comment
---
 ale/base/base.py                              | 61 +++++++++++++++++++
 ale/formatters/formatter.py                   |  2 +
 ale/formatters/usgscsm_formatter.py           |  3 +
 environment.yml                               |  1 +
 include/ale/Util.h                            |  3 +
 recipe/meta.yaml                              |  4 +-
 src/Util.cpp                                  | 25 ++++++++
 tests/ctests/IsdTests.cpp                     | 31 ++++++++++
 .../B10_013341_1010_XN_79S172W_isis3.lbl      | 19 ++++++
 tests/pytests/test_load.py                    |  2 -
 tests/pytests/test_mex_drivers.py             |  4 +-
 tests/pytests/test_usgscsm_formatter.py       | 19 ++++++
 12 files changed, 170 insertions(+), 4 deletions(-)

diff --git a/ale/base/base.py b/ale/base/base.py
index bf6ad36..1df9c7e 100644
--- a/ale/base/base.py
+++ b/ale/base/base.py
@@ -1,6 +1,9 @@
 import pvl
 import json
 
+import tempfile
+import os 
+
 class Driver():
     """
     Base class for all Drivers.
@@ -323,3 +326,61 @@ class Driver():
     @property
     def short_mission_name(self):
         return self.__module__.split('.')[-1].split('_')[0]
+
+    @property 
+    def projection(self):
+        if not hasattr(self, "_projection"): 
+            try: 
+              from osgeo import gdal 
+            except: 
+                self._projection = ""
+                return self._projection
+
+            if isinstance(self._file, pvl.PVLModule):
+                # save it to a temp folder
+                with tempfile.NamedTemporaryFile() as tmp:
+                    tmp.write(pvl.dumps(self._file)) 
+
+                    geodata = gdal.Open(tempfile.name)
+                    self._projection = geodata.GetSpatialRef()
+            else: 
+                # should be a path
+                if not os.path.exists(self._file): 
+                    self._projection = "" 
+                else: 
+                    geodata = gdal.Open(self._file)
+                    self._projection = geodata.GetSpatialRef()
+
+            # is None if not projected
+            if self._projection: 
+                self._projection = self._projection.ExportToProj4()
+            else: 
+                self._projection = "" 
+        return self._projection
+    
+
+    @property 
+    def geotransform(self):
+        if not hasattr(self, "_geotransform"): 
+            try: 
+              from osgeo import gdal 
+            except: 
+                self._geotransform = (0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
+                return self._geotransform
+
+            if isinstance(self._file, pvl.PVLModule):
+                # save it to a temp folder
+                with tempfile.NamedTemporaryFile() as tmp:
+                    tmp.write(pvl.dumps(self._file)) 
+
+                    geodata = gdal.Open(tempfile.name)
+                    self._geotransform = geodata.GetGeoTransform()
+            else: 
+                # should be a path
+                if not os.path.exists(self._file): 
+                    self._geotransform = (0.0, 1.0, 0.0, 0.0, 0.0, 1.0) 
+                else: 
+                    geodata = gdal.Open(self._file)
+                    self._geotransform = geodata.GetGeoTransform()
+                
+        return self._geotransform
\ No newline at end of file
diff --git a/ale/formatters/formatter.py b/ale/formatters/formatter.py
index b33287b..23968d6 100644
--- a/ale/formatters/formatter.py
+++ b/ale/formatters/formatter.py
@@ -196,6 +196,8 @@ def to_isd(driver):
 
     meta_data['sun_position'] = sun_position
 
+    meta_data["projection"] = driver.projection 
+    meta_data["geotransform"] = driver.geotransform 
 
     # check that there is a valid sensor model name
     if 'name_model' not in meta_data:
diff --git a/ale/formatters/usgscsm_formatter.py b/ale/formatters/usgscsm_formatter.py
index d7fd7db..44438fd 100644
--- a/ale/formatters/usgscsm_formatter.py
+++ b/ale/formatters/usgscsm_formatter.py
@@ -51,6 +51,9 @@ def to_usgscsm(driver):
         'unit' : 'm'
     }
 
+    isd_data["projection"] = driver.projection
+    isd_data["geotransform"] = driver.geotransform
+
     # shared isd keywords for Framer and Linescanner
     if isinstance(driver, LineScanner) or isinstance(driver, Framer):
         # exterior orientation for just Framer and LineScanner
diff --git a/environment.yml b/environment.yml
index 7b84927..b4c1110 100644
--- a/environment.yml
+++ b/environment.yml
@@ -7,6 +7,7 @@ dependencies:
   - cmake>=3.15
   - pytest
   - eigen
+  - gdal 
   - jupyter
   - nlohmann_json
   - numpy
diff --git a/include/ale/Util.h b/include/ale/Util.h
index d27244f..dbc50be 100644
--- a/include/ale/Util.h
+++ b/include/ale/Util.h
@@ -34,6 +34,8 @@ namespace ale {
   std::string getPlatformName(nlohmann::json isd);
   std::string getLogFile(nlohmann::json isd);
   std::string getIsisCameraVersion(nlohmann::json isd);
+  std::string getProjection(nlohmann::json isd);
+  
   int getTotalLines(nlohmann::json isd);
   int getTotalSamples(nlohmann::json isd);
   double getStartingTime(nlohmann::json isd);
@@ -45,6 +47,7 @@ namespace ale {
   double getFocalLengthUncertainty(nlohmann::json isd);
   std::vector<double> getFocal2PixelLines(nlohmann::json isd);
   std::vector<double> getFocal2PixelSamples(nlohmann::json isd);
+  std::vector<double> getGeoTransform(nlohmann::json isd);
   double getDetectorCenterLine(nlohmann::json isd);
   double getDetectorCenterSample(nlohmann::json isd);
   double getDetectorStartingLine(nlohmann::json isd);
diff --git a/recipe/meta.yaml b/recipe/meta.yaml
index c96b750..7ca81bb 100644
--- a/recipe/meta.yaml
+++ b/recipe/meta.yaml
@@ -31,7 +31,9 @@ requirements:
     - scipy >=1.4.0
     - spiceypy >=4.0.1
     - pyyaml
-
+  run_contrained: 
+    - gdal 
+    
 test:
   imports:
     - ale
diff --git a/src/Util.cpp b/src/Util.cpp
index 297c46c..25fbd40 100644
--- a/src/Util.cpp
+++ b/src/Util.cpp
@@ -37,6 +37,31 @@ std::string getImageId(json isd) {
   return id;
 }
 
+
+std::vector<double> getGeoTransform(json isd) {
+  std::vector<double> transform = {};
+  try {
+    transform = isd.at("geotransform").get<std::vector<double>>();
+  } catch (std::exception &e) {
+    std::string originalError = e.what();
+    std::string msg = "Could not parse the geo_transform. ERROR: \n" + originalError;
+    throw std::runtime_error(msg);
+  }
+  return transform;
+}
+
+
+std::string getProjection(json isd) {
+  std::string projection_string = "";
+  try {
+    projection_string = isd.at("projection");
+  } catch (...) {
+    throw std::runtime_error("Could not parse the projection string.");
+  }
+  return projection_string;
+}
+
+
 std::string getSensorName(json isd) {
   std::string name = "";
   try {
diff --git a/tests/ctests/IsdTests.cpp b/tests/ctests/IsdTests.cpp
index 23fe477..fb6a34d 100644
--- a/tests/ctests/IsdTests.cpp
+++ b/tests/ctests/IsdTests.cpp
@@ -77,6 +77,22 @@ TEST(Isd, LogFile) {
   EXPECT_STREQ(ale::getLogFile(j).c_str(), "fake/path");
 }
 
+
+TEST(Isd, Projection) {
+  nlohmann::json proj;
+  proj["projection"] = "totally a proj4 string";
+  proj["geotransform"] = {0.0, 1.0, 0.0, 0.0, 0.0, 1.0};
+
+  std::vector<double> coeffs = ale::getGeoTransform(proj);
+  std::string proj4str = ale::getProjection(proj);
+  EXPECT_EQ(proj4str, "totally a proj4 string");
+  ASSERT_EQ(coeffs.size(), 6);
+  EXPECT_DOUBLE_EQ(coeffs[1], 1.0);
+  EXPECT_DOUBLE_EQ(coeffs[5], 1.0);
+  EXPECT_DOUBLE_EQ(coeffs[0], 0.0);
+}
+
+
 TEST(Isd, TransverseDistortion) {
   nlohmann::json trans;
   trans["optical_distortion"]["transverse"]["x"] = {1};
@@ -571,6 +587,21 @@ TEST(Isd, BadStartingDetector) {
   }
 }
 
+
+TEST(Isd, BadProjection) {
+  nlohmann::json proj;
+  proj["projection"] = NULL;
+
+  try {
+    std::string proj4str = ale::getProjection(proj);
+    FAIL() << "Expected exception to be thrown"; 
+  }
+  catch(std::exception &e) { 
+     EXPECT_EQ(std::string(e.what()), "Could not parse the projection string.");
+  }
+}
+
+
 TEST(Isd, BadFocal2Pixel) {
   std::string bad_json_str("{}");
   try {
diff --git a/tests/pytests/data/B10_013341_1010_XN_79S172W/B10_013341_1010_XN_79S172W_isis3.lbl b/tests/pytests/data/B10_013341_1010_XN_79S172W/B10_013341_1010_XN_79S172W_isis3.lbl
index 1115475..504171f 100644
--- a/tests/pytests/data/B10_013341_1010_XN_79S172W/B10_013341_1010_XN_79S172W_isis3.lbl
+++ b/tests/pytests/data/B10_013341_1010_XN_79S172W/B10_013341_1010_XN_79S172W_isis3.lbl
@@ -69,6 +69,25 @@ Object = IsisCube
     Source                    = isis
   End_Group
 
+  Group = Mapping
+    ProjectionName     = Sinusoidal
+    CenterLongitude    = 148.36859083039
+    TargetName         = MARS
+    EquatorialRadius   = 3396190.0 <meters>
+    PolarRadius        = 3376200.0 <meters>
+    LatitudeType       = Planetocentric
+    LongitudeDirection = PositiveEast
+    LongitudeDomain    = 360
+    MinimumLatitude    = 63.636322793577
+    MaximumLatitude    = 87.296295823424
+    MinimumLongitude   = 139.6658284858
+    MaximumLongitude   = 157.07135317498
+    UpperLeftCornerX   = -219771.1526456 <meters>
+    UpperLeftCornerY   = 5175537.8728989 <meters>
+    PixelResolution    = 1455.4380969907 <meters/pixel>
+    Scale              = 40.726361118253 <pixels/degree>
+  End_Group
+  
   Group = AlphaCube
     AlphaSamples        = 5000
     AlphaLines          = 24576
diff --git a/tests/pytests/test_load.py b/tests/pytests/test_load.py
index 20776a9..8c7e291 100644
--- a/tests/pytests/test_load.py
+++ b/tests/pytests/test_load.py
@@ -73,12 +73,10 @@ def test_load_mes_from_metakernels(tmpdir, monkeypatch, mess_kernels):
         mk_file.write(mk_str)
 
     usgscsm_isd_obj = ale.load(label_file, verbose=True)
-
     assert usgscsm_isd_obj['name_platform'] == 'MESSENGER'
     assert usgscsm_isd_obj['name_sensor'] == 'MERCURY DUAL IMAGING SYSTEM NARROW ANGLE CAMERA'
     assert usgscsm_isd_obj['name_model'] == 'USGS_ASTRO_FRAME_SENSOR_MODEL'
 
-
 def test_load_mes_with_no_metakernels(tmpdir, monkeypatch, mess_kernels):
     monkeypatch.setenv('ALESPICEROOT', str(tmpdir))
 
diff --git a/tests/pytests/test_mex_drivers.py b/tests/pytests/test_mex_drivers.py
index 650d9b4..ec5035c 100644
--- a/tests/pytests/test_mex_drivers.py
+++ b/tests/pytests/test_mex_drivers.py
@@ -202,7 +202,9 @@ def usgscsm_compare_dict():
               "t0_ephemeris": -101.83713859319687,
               "dt_ephemeris": 40.734855437278746,
               "t0_quaternion": -101.83713859319687,
-              "dt_quaternion": 40.734855437278746
+              "dt_quaternion": 40.734855437278746,
+              "projection" : "",
+              "geotransform" : (0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
               },
 
         "isis" :
diff --git a/tests/pytests/test_usgscsm_formatter.py b/tests/pytests/test_usgscsm_formatter.py
index 9a9ae6d..70ecb64 100644
--- a/tests/pytests/test_usgscsm_formatter.py
+++ b/tests/pytests/test_usgscsm_formatter.py
@@ -9,6 +9,8 @@ from ale.transformation import FrameChain
 from ale.base.data_naif import NaifSpice
 from ale.rotation import ConstantRotation, TimeDependentRotation
 
+from conftest import get_image_label
+
 class TestDriver(Driver, NaifSpice):
     """
     Test Driver implementation with dummy values
@@ -342,3 +344,20 @@ def test_line_scan_sun_position(test_line_scan_driver):
     assert sun_position_obj['positions'] == [[0, 1, 2], [3, 4, 5]]
     assert sun_position_obj['velocities'] == [[0, -1, -2], [-3, -4, -5]]
     assert sun_position_obj['unit'] == 'm'
+
+def test_no_projection(test_frame_driver):
+    isd = usgscsm_formatter.to_usgscsm(test_frame_driver)
+    # isn't using real projection so it should be None
+    assert isd['projection'] == None
+
+def test_isis_projection():
+    isd = usgscsm_formatter.to_usgscsm(TestLineScanner(get_image_label('B10_013341_1010_XN_79S172W', "isis3")))
+    assert isd["projection"] == "+proj=sinu +lon_0=148.36859083039 +x_0=0 +y_0=0 +R=3396190 +units=m +no_defs"
+
+
+def test_isis_geotransform():
+    isd = usgscsm_formatter.to_usgscsm(TestLineScanner(get_image_label('B10_013341_1010_XN_79S172W', "isis3")))
+    expected = (-219771.1526456, 1455.4380969907, 0.0, 5175537.8728989, 0.0, -1455.4380969907)
+    for value, truth in zip(isd["geotransform"], expected):
+        pytest.approx(value, truth)
+
-- 
GitLab