diff --git a/plio/io/hcube.py b/plio/io/hcube.py index 5c50699cae7f924a5a4532ba3f35328d6dd84d1d..2e45197a532dce96b7df749cbdf848d71b07531b 100644 --- a/plio/io/hcube.py +++ b/plio/io/hcube.py @@ -2,6 +2,8 @@ import numpy as np import gdal from ..utils.indexing import _LocIndexer, _iLocIndexer +from libpyhat.transform.continuum import continuum_correction +from libpyhat.transform.continuum import polynomial, linear, regression class HCube(object): @@ -10,6 +12,11 @@ class HCube(object): to optionally add support for spectral labels, label based indexing, and lazy loading for reads. """ + def __init__(self, data = [], wavelengths = []): + if len(data) != 0: + self._data = data + if len(wavelengths) != 0: + self._wavelengths = wavelengths @property def wavelengths(self): @@ -24,6 +31,21 @@ class HCube(object): self._wavelengths = [] return self._wavelengths + @property + def data(self): + if not hasattr(self, '_data'): + try: + key = (slice(None, None, None), + slice(None, None, None), + slice(None, None, None)) + data = self._read(key) + except Exception as e: + print(e) + data = [] + self._data = data + + return self._data + @property def tolerance(self): return getattr(self, '_tolerance', 2) @@ -52,6 +74,104 @@ class HCube(object): def iloc(self): return _iLocIndexer(self) + def reduce(self, how = np.mean, axis = (1, 2)): + """ + Parameters + ---------- + how : function + Function to apply across along axises of the hcube + + axis : tuple + List of axis to apply a given function along + + Returns + ------- + new_hcube : Object + A new hcube object with the reduced data set + """ + res = how(self.data, axis = axis) + + new_hcube = HCube(res, self.wavelengths) + return new_hcube + + def continuum_correct(self, nodes, correction_nodes = np.array([]), correction = linear, + axis=0, adaptive=False, window=3, **kwargs): + """ + Parameters + ---------- + + nodes : list + A list of wavelengths for the continuum to be corrected along + + correction_nodes : list + A list of nodes to limit the correction between + + correction : function + Function specifying the type of correction to perform + along the continuum + + axis : int + Axis to apply the continuum correction on + + adaptive : boolean + ? + + window : int + ? + + Returns + ------- + + new_hcube : Object + A new hcube object with the corrected dataset + """ + + continuum_data = continuum_correction(self.data, self.wavelengths, nodes = nodes, + correction_nodes = correction_nodes, correction = correction, + axis = axis, adaptive = adaptive, + window = window, **kwargs) + + new_hcube = HCube(continuum_data[0], self.wavelengths) + return new_hcube + + + def clip_roi(self, x, y, band, tolerance=2): + """ + Parameters + ---------- + + x : tuple + Lower and upper bound along the x axis for clipping + + y : tuple + Lower and upper bound along the y axis for clipping + + band : tuple + Lower and upper band along the z axis for clipping + + tolerance : int + Tolerance given for trying to find wavelengths + between the upper and lower bound + + Returns + ------- + + new_hcube : Object + A new hcube object with the clipped dataset + """ + wavelength_clip = [] + for wavelength in self.wavelengths: + wavelength_upper = wavelength + tolerance + wavelength_lower = wavelength - tolerance + if wavelength_upper > band[0] and wavelength_lower < band[1]: + wavelength_clip.append(wavelength) + + key = (wavelength_clip, slice(*x), slice(*y)) + data_clip = _LocIndexer(self)[key] + + new_hcube = HCube(np.copy(data_clip), np.array(wavelength_clip)) + return new_hcube + def _read(self, key): ifnone = lambda a, b: b if a is None else a @@ -76,7 +196,10 @@ class HCube(object): elif isinstance(key[0], slice): # Given some slice iterate over the bands and get the bands and pixel space requested - return [self.read_array(i, pixels = pixels) for i in list(range(1, self.nbands + 1))[key[0]]] + arrs = [] + for band in list(list(range(1, self.nbands + 1))[key[0]]): + arrs.append(self.read_array(band, pixels = pixels)) + return np.stack(arrs) else: arrs = [] diff --git a/plio/io/io_crism.py b/plio/io/io_crism.py index a65657f02b526da01ec77d4d2ce550f02fe5c575..fb841077a02416f400cb9e0460c65f6aa57cbbe5 100644 --- a/plio/io/io_crism.py +++ b/plio/io/io_crism.py @@ -4,12 +4,12 @@ from .io_gdal import GeoDataset from .hcube import HCube try: - from libpysat.derived import crism - from libpysat.derived.utils import get_derived_funcs - libpysat_enabled = True + from libpyhat.derived import crism + from libpyhat.derived.utils import get_derived_funcs + libpyhat_enabled = True except: - print('No libpysat module. Unable to attach derived product functions') - libpysat_enabled = False + print('No libpyhat module. Unable to attach derived product functions') + libpyhat_enabled = False import gdal @@ -25,7 +25,7 @@ class Crism(GeoDataset, HCube): self.derived_funcs = {} - if libpysat_enabled: + if libpyhat_enabled: self.derived_funcs = get_derived_funcs(crism) def __getattr__(self, name): diff --git a/plio/io/io_moon_minerology_mapper.py b/plio/io/io_moon_minerology_mapper.py index d078f854bd52900030ace43b5c9507a65f34f076..3a4a7a439176f65ff95e175eedae46e10c050d17 100644 --- a/plio/io/io_moon_minerology_mapper.py +++ b/plio/io/io_moon_minerology_mapper.py @@ -4,12 +4,12 @@ from .io_gdal import GeoDataset from .hcube import HCube try: - from libpysat.derived import m3 - from libpysat.derived.utils import get_derived_funcs - libpysat_enabled = True + from libpyhat.derived import m3 + from libpyhat.derived.utils import get_derived_funcs + libpyhat_enabled = True except: - print('No libpysat module. Unable to attach derived product functions') - libpysat_enabled = False + print('No libpyhat module. Unable to attach derived product functions') + libpyhat_enabled = False import gdal @@ -25,7 +25,7 @@ class M3(GeoDataset, HCube): self.derived_funcs = {} - if libpysat_enabled: + if libpyhat_enabled: self.derived_funcs = get_derived_funcs(m3) def __getattr__(self, name): diff --git a/plio/utils/indexing.py b/plio/utils/indexing.py index 9dc7c93cea8d933f6dd7ba183a7cc843fce8a96e..1011d7482319582ff6cf3590efb6cc7e55934bc8 100644 --- a/plio/utils/indexing.py +++ b/plio/utils/indexing.py @@ -35,29 +35,35 @@ def expanded_indexer(key, ndim): class _LocIndexer(object): def __init__(self, data_array): self.data_array = data_array - + def __getitem__(self, key): # expand the indexer so we can handle Ellipsis key = expanded_indexer(key, 3) sl = key[0] ifnone = lambda a, b: b if a is None else a if isinstance(sl, slice): - sl = list(range(ifnone(sl.start, 0), self.data_array.nbands, ifnone(sl.step, 1))) + sl = list(range(ifnone(sl.start, 0), + ifnone(sl.stop, len(self.data_array.wavelengths)), + ifnone(sl.step, 1))) if isinstance(sl, (int, float)): idx = self._get_idx(sl) - else: + else: idx = [self._get_idx(s) for s in sl] key = (idx, key[1], key[2]) - return self.data_array._read(key) - + + if len(self.data_array.data) != 0: + return self.data_array.data[key] + + return self.data_array._read(key) + def _get_idx(self, value, tolerance=2): vals = np.abs(self.data_array.wavelengths-value) minidx = np.argmin(vals) if vals[minidx] >= tolerance: warnings.warn("Absolute difference between requested value and found values is {}".format(vals[minidx])) return minidx - + class _iLocIndexer(object): def __init__(self, data_array): self.data_array = data_array @@ -69,8 +75,12 @@ class _iLocIndexer(object): ifnone = lambda a, b: b if a is None else a if isinstance(sl, slice): sl = list(range(ifnone(sl.start, 0), - ifnone(sl.stop, self.data_array.nbands), + ifnone(sl.stop, len(self.data_array.wavelengths)), ifnone(sl.step, 1))) - + key = (key[0], key[1], key[2]) - return self.data_array._read(key) \ No newline at end of file + + if len(self.data_array.data) != 0: + return self.data_array.data[key] + + return self.data_array._read(key)