Skip to content
Snippets Groups Projects
Commit d6087ee4 authored by acpaquette's avatar acpaquette Committed by jlaura
Browse files

PyHAT Updates (#87)

* Small updates to get the scripts working.

* Fixed pradius calculations, and made some small changes.

* Removed coord transforms and made body_fix func more generic.

* Updated reproj doc string.

* General refactor to socet scripts and clean up. Simplfied input for both scripts.

* Updated conf to use conda prefix

* Fixed up __getattr__ func

* Corrected attribute error text

* Uploaded notebooks for remote access

* Updated notebooks with complete footprint function

* More or less final notebook

* Forgot to tab this in under the first except block

* Updated version.

* Removed unnecessary notebooks

* Removed previously removed module

* Removed old import

* Fixes issue indexing on the hcube

* Updated hcube

* Updated io_m3 and io_crism pyhat dependency

* hcube and indexing update to handle clipping and other operations
parent d4f55f50
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,8 @@ import numpy as np
import gdal
from ..utils.indexing import _LocIndexer, _iLocIndexer
from libpyhat.transform.continuum import continuum_correction
from libpyhat.transform.continuum import polynomial, linear, regression
class HCube(object):
......@@ -10,6 +12,11 @@ class HCube(object):
to optionally add support for spectral labels, label
based indexing, and lazy loading for reads.
"""
def __init__(self, data = [], wavelengths = []):
if len(data) != 0:
self._data = data
if len(wavelengths) != 0:
self._wavelengths = wavelengths
@property
def wavelengths(self):
......@@ -24,6 +31,21 @@ class HCube(object):
self._wavelengths = []
return self._wavelengths
@property
def data(self):
if not hasattr(self, '_data'):
try:
key = (slice(None, None, None),
slice(None, None, None),
slice(None, None, None))
data = self._read(key)
except Exception as e:
print(e)
data = []
self._data = data
return self._data
@property
def tolerance(self):
return getattr(self, '_tolerance', 2)
......@@ -52,6 +74,104 @@ class HCube(object):
def iloc(self):
return _iLocIndexer(self)
def reduce(self, how = np.mean, axis = (1, 2)):
"""
Parameters
----------
how : function
Function to apply across along axises of the hcube
axis : tuple
List of axis to apply a given function along
Returns
-------
new_hcube : Object
A new hcube object with the reduced data set
"""
res = how(self.data, axis = axis)
new_hcube = HCube(res, self.wavelengths)
return new_hcube
def continuum_correct(self, nodes, correction_nodes = np.array([]), correction = linear,
axis=0, adaptive=False, window=3, **kwargs):
"""
Parameters
----------
nodes : list
A list of wavelengths for the continuum to be corrected along
correction_nodes : list
A list of nodes to limit the correction between
correction : function
Function specifying the type of correction to perform
along the continuum
axis : int
Axis to apply the continuum correction on
adaptive : boolean
?
window : int
?
Returns
-------
new_hcube : Object
A new hcube object with the corrected dataset
"""
continuum_data = continuum_correction(self.data, self.wavelengths, nodes = nodes,
correction_nodes = correction_nodes, correction = correction,
axis = axis, adaptive = adaptive,
window = window, **kwargs)
new_hcube = HCube(continuum_data[0], self.wavelengths)
return new_hcube
def clip_roi(self, x, y, band, tolerance=2):
"""
Parameters
----------
x : tuple
Lower and upper bound along the x axis for clipping
y : tuple
Lower and upper bound along the y axis for clipping
band : tuple
Lower and upper band along the z axis for clipping
tolerance : int
Tolerance given for trying to find wavelengths
between the upper and lower bound
Returns
-------
new_hcube : Object
A new hcube object with the clipped dataset
"""
wavelength_clip = []
for wavelength in self.wavelengths:
wavelength_upper = wavelength + tolerance
wavelength_lower = wavelength - tolerance
if wavelength_upper > band[0] and wavelength_lower < band[1]:
wavelength_clip.append(wavelength)
key = (wavelength_clip, slice(*x), slice(*y))
data_clip = _LocIndexer(self)[key]
new_hcube = HCube(np.copy(data_clip), np.array(wavelength_clip))
return new_hcube
def _read(self, key):
ifnone = lambda a, b: b if a is None else a
......@@ -76,7 +196,10 @@ class HCube(object):
elif isinstance(key[0], slice):
# Given some slice iterate over the bands and get the bands and pixel space requested
return [self.read_array(i, pixels = pixels) for i in list(range(1, self.nbands + 1))[key[0]]]
arrs = []
for band in list(list(range(1, self.nbands + 1))[key[0]]):
arrs.append(self.read_array(band, pixels = pixels))
return np.stack(arrs)
else:
arrs = []
......
......@@ -4,12 +4,12 @@ from .io_gdal import GeoDataset
from .hcube import HCube
try:
from libpysat.derived import crism
from libpysat.derived.utils import get_derived_funcs
libpysat_enabled = True
from libpyhat.derived import crism
from libpyhat.derived.utils import get_derived_funcs
libpyhat_enabled = True
except:
print('No libpysat module. Unable to attach derived product functions')
libpysat_enabled = False
print('No libpyhat module. Unable to attach derived product functions')
libpyhat_enabled = False
import gdal
......@@ -25,7 +25,7 @@ class Crism(GeoDataset, HCube):
self.derived_funcs = {}
if libpysat_enabled:
if libpyhat_enabled:
self.derived_funcs = get_derived_funcs(crism)
def __getattr__(self, name):
......
......@@ -4,12 +4,12 @@ from .io_gdal import GeoDataset
from .hcube import HCube
try:
from libpysat.derived import m3
from libpysat.derived.utils import get_derived_funcs
libpysat_enabled = True
from libpyhat.derived import m3
from libpyhat.derived.utils import get_derived_funcs
libpyhat_enabled = True
except:
print('No libpysat module. Unable to attach derived product functions')
libpysat_enabled = False
print('No libpyhat module. Unable to attach derived product functions')
libpyhat_enabled = False
import gdal
......@@ -25,7 +25,7 @@ class M3(GeoDataset, HCube):
self.derived_funcs = {}
if libpysat_enabled:
if libpyhat_enabled:
self.derived_funcs = get_derived_funcs(m3)
def __getattr__(self, name):
......
......@@ -35,29 +35,35 @@ def expanded_indexer(key, ndim):
class _LocIndexer(object):
def __init__(self, data_array):
self.data_array = data_array
def __getitem__(self, key):
# expand the indexer so we can handle Ellipsis
key = expanded_indexer(key, 3)
sl = key[0]
ifnone = lambda a, b: b if a is None else a
if isinstance(sl, slice):
sl = list(range(ifnone(sl.start, 0), self.data_array.nbands, ifnone(sl.step, 1)))
sl = list(range(ifnone(sl.start, 0),
ifnone(sl.stop, len(self.data_array.wavelengths)),
ifnone(sl.step, 1)))
if isinstance(sl, (int, float)):
idx = self._get_idx(sl)
else:
else:
idx = [self._get_idx(s) for s in sl]
key = (idx, key[1], key[2])
return self.data_array._read(key)
if len(self.data_array.data) != 0:
return self.data_array.data[key]
return self.data_array._read(key)
def _get_idx(self, value, tolerance=2):
vals = np.abs(self.data_array.wavelengths-value)
minidx = np.argmin(vals)
if vals[minidx] >= tolerance:
warnings.warn("Absolute difference between requested value and found values is {}".format(vals[minidx]))
return minidx
class _iLocIndexer(object):
def __init__(self, data_array):
self.data_array = data_array
......@@ -69,8 +75,12 @@ class _iLocIndexer(object):
ifnone = lambda a, b: b if a is None else a
if isinstance(sl, slice):
sl = list(range(ifnone(sl.start, 0),
ifnone(sl.stop, self.data_array.nbands),
ifnone(sl.stop, len(self.data_array.wavelengths)),
ifnone(sl.step, 1)))
key = (key[0], key[1], key[2])
return self.data_array._read(key)
\ No newline at end of file
if len(self.data_array.data) != 0:
return self.data_array.data[key]
return self.data_array._read(key)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment