From d6087ee451c90f501a33f84f03434f9fed7effa6 Mon Sep 17 00:00:00 2001 From: acpaquette <acp263@nau.edu> Date: Sat, 1 Dec 2018 13:40:50 -0700 Subject: [PATCH] PyHAT Updates (#87) * Small updates to get the scripts working. * Fixed pradius calculations, and made some small changes. * Removed coord transforms and made body_fix func more generic. * Updated reproj doc string. * General refactor to socet scripts and clean up. Simplfied input for both scripts. * Updated conf to use conda prefix * Fixed up __getattr__ func * Corrected attribute error text * Uploaded notebooks for remote access * Updated notebooks with complete footprint function * More or less final notebook * Forgot to tab this in under the first except block * Updated version. * Removed unnecessary notebooks * Removed previously removed module * Removed old import * Fixes issue indexing on the hcube * Updated hcube * Updated io_m3 and io_crism pyhat dependency * hcube and indexing update to handle clipping and other operations --- plio/io/hcube.py | 125 ++++++++++++++++++++++++++- plio/io/io_crism.py | 12 +-- plio/io/io_moon_minerology_mapper.py | 12 +-- plio/utils/indexing.py | 28 ++++-- 4 files changed, 155 insertions(+), 22 deletions(-) diff --git a/plio/io/hcube.py b/plio/io/hcube.py index 5c50699..2e45197 100644 --- a/plio/io/hcube.py +++ b/plio/io/hcube.py @@ -2,6 +2,8 @@ import numpy as np import gdal from ..utils.indexing import _LocIndexer, _iLocIndexer +from libpyhat.transform.continuum import continuum_correction +from libpyhat.transform.continuum import polynomial, linear, regression class HCube(object): @@ -10,6 +12,11 @@ class HCube(object): to optionally add support for spectral labels, label based indexing, and lazy loading for reads. """ + def __init__(self, data = [], wavelengths = []): + if len(data) != 0: + self._data = data + if len(wavelengths) != 0: + self._wavelengths = wavelengths @property def wavelengths(self): @@ -24,6 +31,21 @@ class HCube(object): self._wavelengths = [] return self._wavelengths + @property + def data(self): + if not hasattr(self, '_data'): + try: + key = (slice(None, None, None), + slice(None, None, None), + slice(None, None, None)) + data = self._read(key) + except Exception as e: + print(e) + data = [] + self._data = data + + return self._data + @property def tolerance(self): return getattr(self, '_tolerance', 2) @@ -52,6 +74,104 @@ class HCube(object): def iloc(self): return _iLocIndexer(self) + def reduce(self, how = np.mean, axis = (1, 2)): + """ + Parameters + ---------- + how : function + Function to apply across along axises of the hcube + + axis : tuple + List of axis to apply a given function along + + Returns + ------- + new_hcube : Object + A new hcube object with the reduced data set + """ + res = how(self.data, axis = axis) + + new_hcube = HCube(res, self.wavelengths) + return new_hcube + + def continuum_correct(self, nodes, correction_nodes = np.array([]), correction = linear, + axis=0, adaptive=False, window=3, **kwargs): + """ + Parameters + ---------- + + nodes : list + A list of wavelengths for the continuum to be corrected along + + correction_nodes : list + A list of nodes to limit the correction between + + correction : function + Function specifying the type of correction to perform + along the continuum + + axis : int + Axis to apply the continuum correction on + + adaptive : boolean + ? + + window : int + ? + + Returns + ------- + + new_hcube : Object + A new hcube object with the corrected dataset + """ + + continuum_data = continuum_correction(self.data, self.wavelengths, nodes = nodes, + correction_nodes = correction_nodes, correction = correction, + axis = axis, adaptive = adaptive, + window = window, **kwargs) + + new_hcube = HCube(continuum_data[0], self.wavelengths) + return new_hcube + + + def clip_roi(self, x, y, band, tolerance=2): + """ + Parameters + ---------- + + x : tuple + Lower and upper bound along the x axis for clipping + + y : tuple + Lower and upper bound along the y axis for clipping + + band : tuple + Lower and upper band along the z axis for clipping + + tolerance : int + Tolerance given for trying to find wavelengths + between the upper and lower bound + + Returns + ------- + + new_hcube : Object + A new hcube object with the clipped dataset + """ + wavelength_clip = [] + for wavelength in self.wavelengths: + wavelength_upper = wavelength + tolerance + wavelength_lower = wavelength - tolerance + if wavelength_upper > band[0] and wavelength_lower < band[1]: + wavelength_clip.append(wavelength) + + key = (wavelength_clip, slice(*x), slice(*y)) + data_clip = _LocIndexer(self)[key] + + new_hcube = HCube(np.copy(data_clip), np.array(wavelength_clip)) + return new_hcube + def _read(self, key): ifnone = lambda a, b: b if a is None else a @@ -76,7 +196,10 @@ class HCube(object): elif isinstance(key[0], slice): # Given some slice iterate over the bands and get the bands and pixel space requested - return [self.read_array(i, pixels = pixels) for i in list(range(1, self.nbands + 1))[key[0]]] + arrs = [] + for band in list(list(range(1, self.nbands + 1))[key[0]]): + arrs.append(self.read_array(band, pixels = pixels)) + return np.stack(arrs) else: arrs = [] diff --git a/plio/io/io_crism.py b/plio/io/io_crism.py index a65657f..fb84107 100644 --- a/plio/io/io_crism.py +++ b/plio/io/io_crism.py @@ -4,12 +4,12 @@ from .io_gdal import GeoDataset from .hcube import HCube try: - from libpysat.derived import crism - from libpysat.derived.utils import get_derived_funcs - libpysat_enabled = True + from libpyhat.derived import crism + from libpyhat.derived.utils import get_derived_funcs + libpyhat_enabled = True except: - print('No libpysat module. Unable to attach derived product functions') - libpysat_enabled = False + print('No libpyhat module. Unable to attach derived product functions') + libpyhat_enabled = False import gdal @@ -25,7 +25,7 @@ class Crism(GeoDataset, HCube): self.derived_funcs = {} - if libpysat_enabled: + if libpyhat_enabled: self.derived_funcs = get_derived_funcs(crism) def __getattr__(self, name): diff --git a/plio/io/io_moon_minerology_mapper.py b/plio/io/io_moon_minerology_mapper.py index d078f85..3a4a7a4 100644 --- a/plio/io/io_moon_minerology_mapper.py +++ b/plio/io/io_moon_minerology_mapper.py @@ -4,12 +4,12 @@ from .io_gdal import GeoDataset from .hcube import HCube try: - from libpysat.derived import m3 - from libpysat.derived.utils import get_derived_funcs - libpysat_enabled = True + from libpyhat.derived import m3 + from libpyhat.derived.utils import get_derived_funcs + libpyhat_enabled = True except: - print('No libpysat module. Unable to attach derived product functions') - libpysat_enabled = False + print('No libpyhat module. Unable to attach derived product functions') + libpyhat_enabled = False import gdal @@ -25,7 +25,7 @@ class M3(GeoDataset, HCube): self.derived_funcs = {} - if libpysat_enabled: + if libpyhat_enabled: self.derived_funcs = get_derived_funcs(m3) def __getattr__(self, name): diff --git a/plio/utils/indexing.py b/plio/utils/indexing.py index 9dc7c93..1011d74 100644 --- a/plio/utils/indexing.py +++ b/plio/utils/indexing.py @@ -35,29 +35,35 @@ def expanded_indexer(key, ndim): class _LocIndexer(object): def __init__(self, data_array): self.data_array = data_array - + def __getitem__(self, key): # expand the indexer so we can handle Ellipsis key = expanded_indexer(key, 3) sl = key[0] ifnone = lambda a, b: b if a is None else a if isinstance(sl, slice): - sl = list(range(ifnone(sl.start, 0), self.data_array.nbands, ifnone(sl.step, 1))) + sl = list(range(ifnone(sl.start, 0), + ifnone(sl.stop, len(self.data_array.wavelengths)), + ifnone(sl.step, 1))) if isinstance(sl, (int, float)): idx = self._get_idx(sl) - else: + else: idx = [self._get_idx(s) for s in sl] key = (idx, key[1], key[2]) - return self.data_array._read(key) - + + if len(self.data_array.data) != 0: + return self.data_array.data[key] + + return self.data_array._read(key) + def _get_idx(self, value, tolerance=2): vals = np.abs(self.data_array.wavelengths-value) minidx = np.argmin(vals) if vals[minidx] >= tolerance: warnings.warn("Absolute difference between requested value and found values is {}".format(vals[minidx])) return minidx - + class _iLocIndexer(object): def __init__(self, data_array): self.data_array = data_array @@ -69,8 +75,12 @@ class _iLocIndexer(object): ifnone = lambda a, b: b if a is None else a if isinstance(sl, slice): sl = list(range(ifnone(sl.start, 0), - ifnone(sl.stop, self.data_array.nbands), + ifnone(sl.stop, len(self.data_array.wavelengths)), ifnone(sl.step, 1))) - + key = (key[0], key[1], key[2]) - return self.data_array._read(key) \ No newline at end of file + + if len(self.data_array.data) != 0: + return self.data_array.data[key] + + return self.data_array._read(key) -- GitLab