From 8840637c09d65485085bd085f5e7bdd7aa1b09f3 Mon Sep 17 00:00:00 2001
From: acpaquette <acpaquette@usgs.gov>
Date: Fri, 18 Nov 2022 18:18:46 -0800
Subject: [PATCH] ALE Load Driver Selection (#509)

* Initial stab at exclusively selecting drivers based on component classes

* Updated failing tests and added new test

* Expose new exclusion flags through the C++ load/loads functions

* Addresses PR feedback and adds a few more tests for the PyInterface

* Exposed driver exclusion in isd_generate.py
---
 ale/drivers/__init__.py            | 50 +++++++++++++++++++++++-------
 ale/isd_generate.py                | 24 ++++++++++----
 include/ale/Load.h                 | 12 +++++--
 src/Load.cpp                       | 23 +++++++++++---
 tests/ctests/LoadTests.cpp         | 21 +++++++++++++
 tests/pytests/test_isd_generate.py |  4 +--
 tests/pytests/test_load.py         | 22 ++++++++-----
 7 files changed, 122 insertions(+), 34 deletions(-)

diff --git a/ale/drivers/__init__.py b/ale/drivers/__init__.py
index 4c50d9b..e084b62 100644
--- a/ale/drivers/__init__.py
+++ b/ale/drivers/__init__.py
@@ -1,23 +1,24 @@
 import pvl
-import zlib
 
 import importlib
 import inspect
-import itertools
-from itertools import chain
+from itertools import chain, compress
 import os
 from glob import glob
 import json
 import numpy as np
 import datetime
-from datetime import datetime, date
+from datetime import datetime
 import traceback
-from collections import OrderedDict
 
 from ale.formatters.usgscsm_formatter import to_usgscsm
 from ale.formatters.isis_formatter import to_isis
 from ale.formatters.formatter import to_isd
 from ale.base.data_isis import IsisSpice
+from ale.base.label_isis import IsisLabel
+from ale.base.label_pds3 import Pds3Label
+from ale.base.data_naif import NaifSpice
+
 
 from abc import ABC
 
@@ -53,15 +54,25 @@ class AleJsonEncoder(json.JSONEncoder):
             return obj.isoformat()
         return json.JSONEncoder.default(self, obj)
 
-def load(label, props={}, formatter='ale', verbose=False):
+def load(label, props={}, formatter='ale', verbose=False, only_isis_spice=False, only_naif_spice=False):
     """
-    Attempt to load a given label from all possible drivers.
+    Attempt to load a given label from possible drivers.
 
     This function opens up the label file and attempts to produce an ISD in the
     format specified using the supplied properties. Drivers are tried sequentially
     until an ISD is successfully created. Drivers that use external ephemeris
     data are tested before drivers that use attached ephemeris data.
 
+    Using the only_* flags will limit the drivers used to construct ISDs. If you
+    are not sure what input data you have, just leave the only_* parameters as False.
+    Leaving/Setting all only_* parameters to False should satisfy most situations.
+
+    Only parameters explained and there uses:
+    * ``only_isis_spice=True`` Used for spiceinit'd ISIS cubes, used, for example, 
+    when one has updated the ephemeris information on an ISIS cube.
+    * ``only_naif_spice=True`` Used for example, when one has a data product or 
+    an ISIS cube, but not yet obtained ephemeris information.
+
     Parameters
     ----------
     label : str
@@ -82,15 +93,31 @@ def load(label, props={}, formatter='ale', verbose=False):
               If True, displays debug output specifying which drivers were
               attempted and why they failed.
 
+    only_isis_spice : bool
+                      Explicitly searches for drivers constructed from the IsisSpice
+                      component class
+
+    only_naif_spice : bool
+                      Explicitly searches for drivers constructed from the NaifSpice
+                      component class
+
     Returns
     -------
     dict
          The ISD as a dictionary
     """
+    print("Banana", only_isis_spice, only_naif_spice)
     if isinstance(formatter, str):
         formatter = __formatters__[formatter]
-
-    drivers = chain.from_iterable(inspect.getmembers(dmod, lambda x: inspect.isclass(x) and "_driver" in x.__module__) for dmod in __driver_modules__)
+    
+    driver_mask = [only_isis_spice, only_naif_spice]
+    class_list = [IsisSpice, NaifSpice]
+    class_list = list(compress(class_list, driver_mask))
+    # predicat logic: make sure x is a class, who contains the word "driver" (clipper_drivers) and 
+    # the componenet classes 
+    predicat = lambda x: inspect.isclass(x) and "_driver" in x.__module__ and [i for i in class_list if i in inspect.getmro(x)] == class_list
+    driver_list = [inspect.getmembers(dmod, predicat) for dmod in __driver_modules__]
+    drivers = chain.from_iterable(driver_list)
     drivers = sort_drivers([d[1] for d in drivers])
 
     if verbose:
@@ -137,7 +164,7 @@ def load(label, props={}, formatter='ale', verbose=False):
                 traceback.print_exc()
     raise Exception('No Such Driver for Label')
 
-def loads(label, props='', formatter='ale', indent = 2, verbose=False):
+def loads(label, props='', formatter='ale', indent = 2, verbose=False, only_isis_spice=False, only_naif_spice=False):
     """
     Attempt to load a given label from all possible drivers.
 
@@ -160,7 +187,8 @@ def loads(label, props='', formatter='ale', indent = 2, verbose=False):
     --------
     load
     """
-    res = load(label, props, formatter, verbose=verbose)
+    print(only_isis_spice, only_naif_spice)
+    res = load(label, props, formatter, verbose, only_isis_spice, only_naif_spice)
     return json.dumps(res, indent=indent, cls=AleJsonEncoder)
 
 def parse_label(label, grammar=pvl.grammar.PVLGrammar()):
diff --git a/ale/isd_generate.py b/ale/isd_generate.py
index 9e63273..b04307d 100755
--- a/ale/isd_generate.py
+++ b/ale/isd_generate.py
@@ -57,6 +57,16 @@ def main():
         action="store_true",
         help="Display information as program runs."
     )
+    parser.add_argument(
+        "-i", "--only_isis_spice",
+        action="store_true",
+        help="Only use drivers that read from spiceinit'd ISIS cubes"
+    )
+    parser.add_argument(
+        "-n", "--only_naif_spice",
+        action="store_true",
+        help="Only use drivers that generate fresh spice data"
+    )
     parser.add_argument(
         '--version',
         action='version',
@@ -87,7 +97,7 @@ def main():
 
     if len(args.input) == 1:
         try:
-            file_to_isd(args.input[0], args.out, kernels=k, log_level=log_level)
+            file_to_isd(args.input[0], args.out, kernels=k, log_level=log_level, only_isis_spice=args.only_isis_spice, only_naif_spice=args.only_naif_spice)
         except Exception as err:
             # Seriously, this just throws a generic Exception?
             sys.exit(f"File {args.input[0]}: {err}")
@@ -97,7 +107,7 @@ def main():
         ) as executor:
             futures = {
                 executor.submit(
-                    file_to_isd, f, **{"kernels": k, "log_level": log_level}
+                    file_to_isd, f, **{"kernels": k, "log_level": log_level, "only_isis_spice": args.only_isis_spice, "only_naif_spice": args.only_naif_spice}
                 ): f for f in args.input
             }
             for f in concurrent.futures.as_completed(futures):
@@ -115,7 +125,9 @@ def file_to_isd(
     file: os.PathLike,
     out: os.PathLike = None,
     kernels: list = None,
-    log_level=logging.WARNING
+    log_level=logging.WARNING,
+    only_isis_spice=False,
+    only_naif_spice=False
 ):
     """
     Returns nothing, but acts as a thin wrapper to take the *file* and generate
@@ -139,11 +151,11 @@ def file_to_isd(
     logger.setLevel(log_level)
 
     logger.info(f"Reading: {file}")
+    props = {}
     if kernels is not None:
         kernels = [str(PurePath(p)) for p in kernels]
-        usgscsm_str = ale.loads(file, props={'kernels': kernels}, verbose=log_level>=logging.INFO)
-    else:
-        usgscsm_str = ale.loads(file, verbose=log_level>=logging.INFO)
+        props["kernels"] = kernels
+    usgscsm_str = ale.loads(file, props=props, verbose=log_level>=logging.INFO, only_isis_spice=only_isis_spice, only_naif_spice=only_naif_spice)
 
     logger.info(f"Writing: {isd_file}")
     isd_file.write_text(usgscsm_str)
diff --git a/include/ale/Load.h b/include/ale/Load.h
index 7485740..b529b0d 100644
--- a/include/ale/Load.h
+++ b/include/ale/Load.h
@@ -23,10 +23,14 @@ namespace ale {
    * @param verbose A flag to output what the load function is attempting to do.
    *                If set to true, information about the drivers load attempts
    *                to use will be output to standard out.
+   * @param onlyIsisSpice A flag the forces the load function to only use IsisSpice
+   *                      drivers
+   * @param onlyNaifSpice A flag the forces the load function to only use NaifSpice
+   *                      drivers
    *
    * @returns A string containing a JSON formatted ISD for the image.
    */
-  std::string loads(std::string filename, std::string props="", std::string formatter="usgscsm", int indent = 2, bool verbose=true);
+  std::string loads(std::string filename, std::string props="", std::string formatter="ale", int indent=2, bool verbose=true, bool onlyIsisSpice=false, bool onlyNaifSpice=false);
 
   /**
    * Load all of the metadata for an image into a JSON ISD.
@@ -44,10 +48,14 @@ namespace ale {
    * @param verbose A flag to output what the load function is attempting to do.
    *                If set to true, information about the drivers load attempts
    *                to use will be output to standard out.
+   * @param onlyIsisSpice A flag the forces the load function to only use IsisSpice
+   *                      drivers
+   * @param onlyIsisSpice A flag the forces the load function to only use NaifSpice
+   *                      drivers
    *
    * @returns A string containing a JSON formatted ISD for the image.
    */
-  nlohmann::json load(std::string filename, std::string props="", std::string formatter="usgscsm", bool verbose=true);
+  nlohmann::json load(std::string filename, std::string props="", std::string formatter="ale", bool verbose=true, bool onlyIsisSpice=false, bool onlyNaifSpice=false);
 }
 
 #endif // ALE_H
diff --git a/src/Load.cpp b/src/Load.cpp
index c1dc71d..3aeb48d 100644
--- a/src/Load.cpp
+++ b/src/Load.cpp
@@ -59,7 +59,7 @@ namespace ale {
     return "";
   }
 
-  std::string loads(std::string filename, std::string props, std::string formatter, int indent, bool verbose) {
+  std::string loads(std::string filename, std::string props, std::string formatter, int indent, bool verbose, bool onlyIsisSpice, bool onlyNaifSpice) {
     static bool first_run = true;
 
     if(first_run) {
@@ -88,7 +88,7 @@ namespace ale {
 
 
     // Create a Python tuple to hold the arguments to the method.
-    PyObject *pArgs = PyTuple_New(5);
+    PyObject *pArgs = PyTuple_New(7);
     if(!pArgs) {
       throw runtime_error(getPyTraceback());
     }
@@ -111,12 +111,25 @@ namespace ale {
     Py_INCREF(pIntIndent); // take ownership of reference
 
     PyObject *pBoolVerbose = Py_False;
-    if (!verbose) {
+    if (verbose == true) {
       pBoolVerbose = Py_True;
     }
     PyTuple_SetItem(pArgs, 4, pBoolVerbose);
     Py_INCREF(pBoolVerbose); // take ownership of reference
 
+    PyObject *pBoolOnlyIsisSpice = Py_False;
+    if (onlyIsisSpice == true) {
+      pBoolOnlyIsisSpice = Py_True;
+    }
+    PyTuple_SetItem(pArgs, 5, pBoolOnlyIsisSpice);
+    Py_INCREF(pBoolOnlyIsisSpice); // take ownership of reference
+
+    PyObject *pBoolOnlyNaifSpice = Py_False;
+    if (onlyNaifSpice == true) {
+      pBoolOnlyNaifSpice = Py_True;
+    }
+    PyTuple_SetItem(pArgs, 6, pBoolOnlyNaifSpice);
+    Py_INCREF(pBoolOnlyNaifSpice); // take ownership of reference
 
     // Call the function with the arguments.
     PyObject* pResult = PyObject_CallObject(pFunc, pArgs);
@@ -149,8 +162,8 @@ namespace ale {
     return cResult;
   }
 
-  json load(std::string filename, std::string props, std::string formatter, bool verbose) {
-    std::string jsonstr = loads(filename, props, formatter, verbose);
+  json load(std::string filename, std::string props, std::string formatter, bool verbose, bool onlyIsisSpice, bool onlyNaifSpice) {
+    std::string jsonstr = loads(filename, props, formatter, 0, verbose, onlyIsisSpice, onlyNaifSpice);
     return json::parse(jsonstr);
   }
 }
diff --git a/tests/ctests/LoadTests.cpp b/tests/ctests/LoadTests.cpp
index a5f0c69..bcd556e 100644
--- a/tests/ctests/LoadTests.cpp
+++ b/tests/ctests/LoadTests.cpp
@@ -1,10 +1,15 @@
 #include "gtest/gtest.h"
+#include "gmock/gmock.h"
 
 #include "ale/Load.h"
 
+#include <nlohmann/json.hpp>
+
 #include <stdexcept>
 
+using json = nlohmann::json;
 using namespace std;
+using ::testing::HasSubstr;
 
 TEST(PyInterfaceTest, LoadInvalidLabel) {
   std::string label = "Not a Real Label";
@@ -16,3 +21,19 @@ TEST(PyInterfaceTest, LoadValidLabel) {
   std::string label = "../pytests/data/EN1072174528M/EN1072174528M_spiceinit.lbl";
   ale::load(label, "", "isis");
 }
+
+TEST(PyInterfaceTest, LoadValidLabelOnlyIsisSpice) {
+  std::string label = "../pytests/data/EN1072174528M/EN1072174528M_spiceinit.lbl";
+  ale::load(label, "", "isis", false, true, false);
+}
+
+TEST(PyInterfaceTest, LoadValidLabelOnlyNaifSpice) {
+  std::string label = "../pytests/data/EN1072174528M/EN1072174528M_spiceinit.lbl";
+  try {
+    ale::load(label, "", "isis", false, false, true);
+    FAIL() << "Should not have been able to generate an ISD" << endl;
+  }
+  catch (exception &e) {
+    EXPECT_THAT(e.what(), HasSubstr("No Valid instrument found for label."));
+  }
+}
\ No newline at end of file
diff --git a/tests/pytests/test_isd_generate.py b/tests/pytests/test_isd_generate.py
index fb1c8c4..2826eb2 100644
--- a/tests/pytests/test_isd_generate.py
+++ b/tests/pytests/test_isd_generate.py
@@ -28,7 +28,7 @@ class TestFile(unittest.TestCase):
             cube_str = "dummy.cub"
             isdg.file_to_isd(cube_str)
             self.assertEqual(
-                m_loads.call_args_list, [call(cube_str, verbose=True)]
+                m_loads.call_args_list, [call(cube_str, props={}, verbose=True, only_isis_spice=False, only_naif_spice=False)]
             )
             self.assertEqual(
                 m_path_wt.call_args_list, [call(json_text)]
@@ -41,7 +41,7 @@ class TestFile(unittest.TestCase):
             isdg.file_to_isd(cube_str, out=out_str, kernels=kernel_val)
             self.assertEqual(
                 m_loads.call_args_list,
-                [call(cube_str, props={'kernels': kernel_val}, verbose=True)]
+                [call(cube_str, props={'kernels': kernel_val}, verbose=True, only_isis_spice=False, only_naif_spice=False)]
             )
             self.assertEqual(
                 m_path_wt.call_args_list, [call(json_text)]
diff --git a/tests/pytests/test_load.py b/tests/pytests/test_load.py
index 431ae2f..20776a9 100644
--- a/tests/pytests/test_load.py
+++ b/tests/pytests/test_load.py
@@ -24,16 +24,22 @@ def test_priority(tmpdir, monkeypatch):
     sorted_drivers = sort_drivers(drivers)
     assert all([IsisSpice in klass.__bases__ for klass in sorted_drivers[2:]])
 
-def test_mess_load(mess_kernels):
-    updated_kernels = mess_kernels
+@pytest.mark.parametrize(("class_truth, return_val"), [({"only_isis_spice": False,  "only_naif_spice": False}, True), 
+                                                       ({"only_isis_spice": True,  "only_naif_spice": False}, False)])
+def test_mess_load(class_truth, return_val, mess_kernels):
     label_file = get_image_label('EN1072174528M')
 
-    usgscsm_isd_str = ale.loads(label_file, props={'kernels': updated_kernels}, formatter='usgscsm')
-    usgscsm_isd_obj = json.loads(usgscsm_isd_str)
-
-    assert usgscsm_isd_obj['name_platform'] == 'MESSENGER'
-    assert usgscsm_isd_obj['name_sensor'] == 'MERCURY DUAL IMAGING SYSTEM NARROW ANGLE CAMERA'
-    assert usgscsm_isd_obj['name_model'] == 'USGS_ASTRO_FRAME_SENSOR_MODEL'
+    try:
+        usgscsm_isd_str = ale.loads(label_file, {'kernels': mess_kernels}, 'usgscsm', False, **class_truth)
+        usgscsm_isd_obj = json.loads(usgscsm_isd_str)
+
+        assert return_val is True
+        assert usgscsm_isd_obj['name_platform'] == 'MESSENGER'
+        assert usgscsm_isd_obj['name_sensor'] == 'MERCURY DUAL IMAGING SYSTEM NARROW ANGLE CAMERA'
+        assert usgscsm_isd_obj['name_model'] == 'USGS_ASTRO_FRAME_SENSOR_MODEL'
+    except Exception as load_failure:
+        assert str(load_failure) == "No Such Driver for Label"
+        assert return_val is False
 
 def test_load_invalid_label():
     with pytest.raises(Exception):
-- 
GitLab