From d6087ee451c90f501a33f84f03434f9fed7effa6 Mon Sep 17 00:00:00 2001
From: acpaquette <acp263@nau.edu>
Date: Sat, 1 Dec 2018 13:40:50 -0700
Subject: [PATCH] PyHAT Updates (#87)

* Small updates to get the scripts working.

* Fixed pradius calculations, and made some small changes.

* Removed coord transforms and made body_fix func more generic.

* Updated reproj doc string.

* General refactor and cleanup of the socet scripts. Simplified input for both scripts.

* Updated conf to use conda prefix

* Fixed up __getattr__ func

* Corrected attribute error text

* Uploaded notebooks for remote access

* Updated notebooks with complete footprint function

* More or less final notebook

* Forgot to tab this in under the first except block

* Updated version.

* Removed unnecessary notebooks

* Removed previously removed module

* Removed old import

* Fixed an issue with indexing on the hcube

* Updated hcube

* Updated io_m3 and io_crism to use the libpyhat dependency

* Updated hcube and indexing to handle clipping and other operations (usage sketch below)
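
A minimal sketch of how the new HCube pieces fit together, as referenced above.
The file name, ROI bounds, and wavelength values are placeholders, and the
Crism reader is assumed to take a file path like any other GeoDataset:

    import numpy as np
    from plio.io.io_crism import Crism

    img = Crism('example_crism_image.img')  # placeholder file path

    # Clip to a spatial window and a wavelength range; the band bounds are
    # given in the same units as img.wavelengths.
    roi = img.clip_roi(x=(100, 200), y=(100, 200), band=(1000, 2500))

    # Apply a linear continuum correction anchored at two wavelength nodes.
    corrected = roi.continuum_correct(nodes=[1000, 2500])

    # Collapse the spatial axes to a single mean spectrum.
    spectrum = corrected.reduce(how=np.mean, axis=(1, 2))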
---
 plio/io/hcube.py                     | 125 ++++++++++++++++++++++++++-
 plio/io/io_crism.py                  |  12 +--
 plio/io/io_moon_minerology_mapper.py |  12 +--
 plio/utils/indexing.py               |  28 ++++--
 4 files changed, 155 insertions(+), 22 deletions(-)

diff --git a/plio/io/hcube.py b/plio/io/hcube.py
index 5c50699..2e45197 100644
--- a/plio/io/hcube.py
+++ b/plio/io/hcube.py
@@ -2,6 +2,8 @@ import numpy as np
 import gdal
 
 from ..utils.indexing import _LocIndexer, _iLocIndexer
+from libpyhat.transform.continuum import continuum_correction
+from libpyhat.transform.continuum import polynomial, linear, regression
 
 
 class HCube(object):
@@ -10,6 +12,11 @@ class HCube(object):
     to optionally add support for spectral labels, label
     based indexing, and lazy loading for reads.
     """
+    def __init__(self, data = [], wavelengths = []):
+        if len(data) != 0:
+            self._data = data
+        if len(wavelengths) != 0:
+            self._wavelengths = wavelengths
 
     @property
     def wavelengths(self):
@@ -24,6 +31,21 @@ class HCube(object):
                 self._wavelengths = []
         return self._wavelengths
 
+    @property
+    def data(self):
+        if not hasattr(self, '_data'):
+            try:
+                key = (slice(None, None, None),
+                       slice(None, None, None),
+                       slice(None, None, None))
+                data = self._read(key)
+            except Exception as e:
+                print(e)
+                data = []
+            self._data = data
+
+        return self._data
+
     @property
     def tolerance(self):
         return getattr(self, '_tolerance', 2)
@@ -52,6 +74,104 @@ class HCube(object):
     def iloc(self):
         return _iLocIndexer(self)
 
+    def reduce(self, how = np.mean, axis = (1, 2)):
+        """
+        Parameters
+        ----------
+        how : function
+              Function to apply along the axes of the hcube
+
+        axis : tuple
+               Tuple of axes to apply the given function along
+
+        Returns
+        -------
+        new_hcube : Object
+                    A new hcube object with the reduced data set
+        """
+        res = how(self.data, axis = axis)
+
+        new_hcube = HCube(res, self.wavelengths)
+        return new_hcube
+
+    def continuum_correct(self, nodes, correction_nodes = np.array([]), correction = linear,
+                          axis=0, adaptive=False, window=3, **kwargs):
+        """
+        Parameters
+        ----------
+
+        nodes : list
+                A list of wavelength nodes along which the continuum is corrected
+
+        correction_nodes : list
+                           A list of nodes to limit the correction between
+
+        correction : function
+                     Function specifying the type of correction to perform
+                     along the continuum
+
+        axis : int
+               Axis to apply the continuum correction on
+
+        adaptive : boolean
+                   ?
+
+        window : int
+                 ?
+
+        Returns
+        -------
+
+        new_hcube : Object
+                    A new hcube object with the corrected dataset
+        """
+
+        continuum_data = continuum_correction(self.data, self.wavelengths, nodes = nodes,
+                                              correction_nodes = correction_nodes, correction = correction,
+                                              axis = axis, adaptive = adaptive,
+                                              window = window, **kwargs)
+
+        new_hcube = HCube(continuum_data[0], self.wavelengths)
+        return new_hcube
+
+
+    def clip_roi(self, x, y, band, tolerance=2):
+        """
+        Parameters
+        ----------
+
+        x : tuple
+            Lower and upper bound along the x axis for clipping
+
+        y : tuple
+            Lower and upper bound along the y axis for clipping
+
+        band : tuple
+               Lower and upper bound along the z axis (wavelengths) for clipping
+
+        tolerance : int
+                    Tolerance given for trying to find wavelengths
+                    between the upper and lower bound
+
+        Returns
+        -------
+
+        new_hcube : Object
+                    A new hcube object with the clipped dataset
+        """
+        wavelength_clip = []
+        for wavelength in self.wavelengths:
+            wavelength_upper = wavelength + tolerance
+            wavelength_lower = wavelength - tolerance
+            if wavelength_upper > band[0] and wavelength_lower < band[1]:
+                wavelength_clip.append(wavelength)
+
+        key = (wavelength_clip, slice(*x), slice(*y))
+        data_clip = _LocIndexer(self)[key]
+
+        new_hcube = HCube(np.copy(data_clip), np.array(wavelength_clip))
+        return new_hcube
+
     def _read(self, key):
         ifnone = lambda a, b: b if a is None else a
 
@@ -76,7 +196,10 @@ class HCube(object):
 
         elif isinstance(key[0], slice):
             # Given some slice iterate over the bands and get the bands and pixel space requested
-            return [self.read_array(i, pixels = pixels) for i in list(range(1, self.nbands + 1))[key[0]]]
+            arrs = []
+            for band in list(list(range(1, self.nbands + 1))[key[0]]):
+                arrs.append(self.read_array(band, pixels = pixels))
+            return np.stack(arrs)
 
         else:
             arrs = []
diff --git a/plio/io/io_crism.py b/plio/io/io_crism.py
index a65657f..fb84107 100644
--- a/plio/io/io_crism.py
+++ b/plio/io/io_crism.py
@@ -4,12 +4,12 @@ from .io_gdal import GeoDataset
 from .hcube import HCube
 
 try:
-    from libpysat.derived import crism
-    from libpysat.derived.utils import get_derived_funcs
-    libpysat_enabled = True
+    from libpyhat.derived import crism
+    from libpyhat.derived.utils import get_derived_funcs
+    libpyhat_enabled = True
 except:
-    print('No libpysat module. Unable to attach derived product functions')
-    libpysat_enabled = False
+    print('No libpyhat module. Unable to attach derived product functions')
+    libpyhat_enabled = False
 
 import gdal
 
@@ -25,7 +25,7 @@ class Crism(GeoDataset, HCube):
 
         self.derived_funcs = {}
 
-        if libpysat_enabled:
+        if libpyhat_enabled:
             self.derived_funcs = get_derived_funcs(crism)
 
     def __getattr__(self, name):
diff --git a/plio/io/io_moon_minerology_mapper.py b/plio/io/io_moon_minerology_mapper.py
index d078f85..3a4a7a4 100644
--- a/plio/io/io_moon_minerology_mapper.py
+++ b/plio/io/io_moon_minerology_mapper.py
@@ -4,12 +4,12 @@ from .io_gdal import GeoDataset
 from .hcube import HCube
 
 try:
-    from libpysat.derived import m3
-    from libpysat.derived.utils import get_derived_funcs
-    libpysat_enabled = True
+    from libpyhat.derived import m3
+    from libpyhat.derived.utils import get_derived_funcs
+    libpyhat_enabled = True
 except:
-    print('No libpysat module. Unable to attach derived product functions')
-    libpysat_enabled = False
+    print('No libpyhat module. Unable to attach derived product functions')
+    libpyhat_enabled = False
 
 import gdal
 
@@ -25,7 +25,7 @@ class M3(GeoDataset, HCube):
 
         self.derived_funcs = {}
 
-        if libpysat_enabled:
+        if libpyhat_enabled:
             self.derived_funcs = get_derived_funcs(m3)
 
     def __getattr__(self, name):
diff --git a/plio/utils/indexing.py b/plio/utils/indexing.py
index 9dc7c93..1011d74 100644
--- a/plio/utils/indexing.py
+++ b/plio/utils/indexing.py
@@ -35,29 +35,35 @@ def expanded_indexer(key, ndim):
 class _LocIndexer(object):
     def __init__(self, data_array):
         self.data_array = data_array
-    
+
     def __getitem__(self, key):
         # expand the indexer so we can handle Ellipsis
         key = expanded_indexer(key, 3)
         sl = key[0]
         ifnone = lambda a, b: b if a is None else a
         if isinstance(sl, slice):
-            sl = list(range(ifnone(sl.start, 0), self.data_array.nbands, ifnone(sl.step, 1)))
+            sl = list(range(ifnone(sl.start, 0),
+                            ifnone(sl.stop, len(self.data_array.wavelengths)),
+                            ifnone(sl.step, 1)))
 
         if isinstance(sl, (int, float)):
             idx = self._get_idx(sl)
-        else:    
+        else:
             idx = [self._get_idx(s) for s in sl]
         key = (idx, key[1], key[2])
-        return self.data_array._read(key)   
-    
+
+        if len(self.data_array.data) != 0:
+            return self.data_array.data[key]
+
+        return self.data_array._read(key)
+
     def _get_idx(self, value, tolerance=2):
         vals = np.abs(self.data_array.wavelengths-value)
         minidx = np.argmin(vals)
         if vals[minidx] >= tolerance:
             warnings.warn("Absolute difference between requested value and found values is {}".format(vals[minidx]))
         return minidx
-    
+
 class _iLocIndexer(object):
     def __init__(self, data_array):
         self.data_array = data_array
@@ -69,8 +75,12 @@ class _iLocIndexer(object):
         ifnone = lambda a, b: b if a is None else a
         if isinstance(sl, slice):
             sl = list(range(ifnone(sl.start, 0),
-                            ifnone(sl.stop, self.data_array.nbands),
+                            ifnone(sl.stop, len(self.data_array.wavelengths)),
                             ifnone(sl.step, 1)))
-        
+
         key = (key[0], key[1], key[2])
-        return self.data_array._read(key)
\ No newline at end of file
+
+        if len(self.data_array.data) != 0:
+            return self.data_array.data[key]
+
+        return self.data_array._read(key)
-- 
GitLab