diff --git a/plio/io/hcube.py b/plio/io/hcube.py new file mode 100644 index 0000000000000000000000000000000000000000..0d1ed9fea846067aaa16a3c1be62d7d3f14df75f --- /dev/null +++ b/plio/io/hcube.py @@ -0,0 +1,80 @@ +import numpy as np +import gdal + +from ..utils.indexing import _LocIndexer, _iLocIndexer + + +class HCube(object): + """ + A Mixin class for use with the io_gdal.GeoDataset class + to optionally add support for spectral labels, label + based indexing, and lazy loading for reads. + """ + + @property + def wavelengths(self): + if not hasattr(self, '_wavelengths'): + try: + info = gdal.Info(self.file_name, format='json') + wavelengths = [float(j) for i, j in sorted(info['metadata'][''].items(), + key=lambda x: float(x[0].split('_')[-1]))] + self._original_wavelengths = wavelengths + self._wavelengths = np.round(wavelengths, self.tolerance) + except: + self._wavelengths = [] + return self._wavelengths + + @property + def tolerance(self): + return getattr(self, '_tolerance', 2) + + @tolerance.setter + def tolerance(self, val): + if isinstance(val, int): + self._tolerance = val + self._reindex() + else: + raise TypeError + + def _reindex(self): + if self._original_wavelengths is not None: + self._wavelengths = np.round(self._original_wavelengths, decimals=self.tolerance) + + def __getitem__(self, key): + i = _iLocIndexer(self) + return i[key] + + @property + def loc(self): + return _LocIndexer(self) + + @property + def iloc(self): + return _iLocIndexer(self) + + def _read(self, key): + ifnone = lambda a, b: b if a is None else a + + y = key[1] + x = key[2] + if isinstance(x, slice): + xstart = ifnone(x.start,0) + xstop = ifnone(x.stop,self.raster_size[0]) + xstep = xstop - xstart + else: + raise TypeError("Loc style access elements must be slices, e.g., [:] or [10:100]") + if isinstance(y, slice): + ystart = ifnone(y.start, 0) + ystop = ifnone(y.stop, self.raster_size[1]) + ystep = ystop - ystart + else: + raise TypeError("Loc style access elements must be slices, e.g., [:] or [10:100]") + + pixels = (xstart, ystart, xstep, ystep) + if isinstance(key[0], (int, np.integer)): + return self.read_array(band=int(key[0]+1), pixels=pixels) + else: + arrs = [] + for b in key[0]: + arrs.append(self.read_array(band=int(b+1), pixels=pixels)) + return np.stack(arrs) \ No newline at end of file diff --git a/plio/io/io_crism.py b/plio/io/io_crism.py new file mode 100644 index 0000000000000000000000000000000000000000..415ec472a92feb0993ec9a2c6de99f591282b678 --- /dev/null +++ b/plio/io/io_crism.py @@ -0,0 +1,32 @@ +import os +import numpy as np +from .io_gdal import GeoDataset +from .hcube import HCube +import gdal + + +class Crism(GeoDataset, HCube): + """ + An M3 specific reader with the spectral mixin. + """ + @property + def wavelengths(self): + if not hasattr(self, '_wavelengths'): + try: + info = gdal.Info(self.file_name, format='json') + wv = dict((k,v) for (k,v) in info['metadata'][''].items() if 'Band' in k) # Only get the 'Band_###' keys + wavelengths = [float(j.split(" ")[0]) for i, j in sorted(wv.items(), + key=lambda x: int(x[0].split('_')[-1]))] + self._original_wavelengths = wavelengths + self._wavelengths = np.round(wavelengths, self.tolerance) + except: + self._wavelengths = [] + return self._wavelengths + +def open(input_data): + if os.path.splitext(input_data)[-1] == 'hdr': + # GDAL wants the img, but many users aim at the .hdr + input_data = os.path.splitext(input_data)[:-1] + '.img' + ds = Crism(input_data) + + return ds \ No newline at end of file diff --git a/plio/io/io_moon_minerology_mapper.py b/plio/io/io_moon_minerology_mapper.py index 6a8ad1e3de5b240b82ccc1a4b1fae01624184d53..c28a7dc99da72d14ecc562466906903365b7299e 100644 --- a/plio/io/io_moon_minerology_mapper.py +++ b/plio/io/io_moon_minerology_mapper.py @@ -1,24 +1,38 @@ +import os import numpy as np from .io_gdal import GeoDataset +from .hcube import HCube +import gdal +class M3(GeoDataset, HCube): + """ + An M3 specific reader with the spectral mixin. + """ + @property + def wavelengths(self): + if not hasattr(self, '_wavelengths'): + try: + info = gdal.Info(self.file_name, format='json') + if 'Resize' in info['metadata']['']['Band_1']: + wavelengths = [float(j.split(' ')[-1].replace('(','').replace(')', '')) for\ + i,j in sorted(info['metadata'][''].items(), + key=lambda x: float(x[0].split('_')[-1]))] + # This is a geotiff translated from the PDS IMG + else: + # This is a PDS IMG + wavelengths = [float(j) for i, j in sorted(info['metadata'][''].items(), + key=lambda x: float(x[0].split('_')[-1]))] + self._original_wavelengths = wavelengths + self._wavelengths = np.round(wavelengths, self.tolerance) + except: + self._wavelengths = [] + return self._wavelengths + def open(input_data): - if input_data.split('.')[-1] == 'hdr': + if os.path.splitext(input_data)[-1] == 'hdr': # GDAL wants the img, but many users aim at the .hdr - input_data = input_data.split('.')[0] + '.img' - ds = GeoDataSet(input_data) - ref_array = ds.read_array() - metadata = ds.metadata - wv_array = metadatatoband(metadata) - return wv_array, ref_array, ds + input_data = os.path.splitext(input_data)[:-1] + '.img' + ds = M3(input_data) -def metadatatoband(metadata): - wv2band = [] - for k, v in metadata.items(): - try: - wv2band.append(float(v)) - except: - v = v.split(" ")[-1].split("(")[1].split(")")[0] - wv2band.append(float(v)) - wv2band.sort(key=int) - return np.asarray(wv2band) + return ds diff --git a/plio/utils/indexing.py b/plio/utils/indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..9dc7c93cea8d933f6dd7ba183a7cc843fce8a96e --- /dev/null +++ b/plio/utils/indexing.py @@ -0,0 +1,76 @@ +import warnings +import numpy as np + +def is_dict_like(value): + return hasattr(value, 'keys') and hasattr(value, '__getitem__') + +def expanded_indexer(key, ndim): + """Given a key for indexing an ndarray, return an equivalent key which is a + tuple with length equal to the number of dimensions. + The expansion is done by replacing all `Ellipsis` items with the right + number of full slices and then padding the key with full slices so that it + reaches the appropriate dimensionality. + """ + if not isinstance(key, tuple): + # numpy treats non-tuple keys equivalent to tuples of length 1 + key = (key,) + new_key = [] + # handling Ellipsis right is a little tricky, see: + # http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing + found_ellipsis = False + for k in key: + if k is Ellipsis: + if not found_ellipsis: + new_key.extend((ndim + 1 - len(key)) * [slice(None)]) + found_ellipsis = True + else: + new_key.append(slice(None)) + else: + new_key.append(k) + if len(new_key) > ndim: + raise IndexError('too many indices') + new_key.extend((ndim - len(new_key)) * [slice(None)]) + return tuple(new_key) + +class _LocIndexer(object): + def __init__(self, data_array): + self.data_array = data_array + + def __getitem__(self, key): + # expand the indexer so we can handle Ellipsis + key = expanded_indexer(key, 3) + sl = key[0] + ifnone = lambda a, b: b if a is None else a + if isinstance(sl, slice): + sl = list(range(ifnone(sl.start, 0), self.data_array.nbands, ifnone(sl.step, 1))) + + if isinstance(sl, (int, float)): + idx = self._get_idx(sl) + else: + idx = [self._get_idx(s) for s in sl] + key = (idx, key[1], key[2]) + return self.data_array._read(key) + + def _get_idx(self, value, tolerance=2): + vals = np.abs(self.data_array.wavelengths-value) + minidx = np.argmin(vals) + if vals[minidx] >= tolerance: + warnings.warn("Absolute difference between requested value and found values is {}".format(vals[minidx])) + return minidx + +class _iLocIndexer(object): + def __init__(self, data_array): + self.data_array = data_array + + def __getitem__(self, key): + # expand the indexer so we can handle Ellipsis + key = expanded_indexer(key, 3) + sl = key[0] + ifnone = lambda a, b: b if a is None else a + if isinstance(sl, slice): + sl = list(range(ifnone(sl.start, 0), + ifnone(sl.stop, self.data_array.nbands), + ifnone(sl.step, 1))) + + key = (key[0], key[1], key[2]) + return self.data_array._read(key) \ No newline at end of file diff --git a/plio/utils/utils.py b/plio/utils/utils.py index 1cec0dfd11c353a1aeab1ef7a675f8aa66fb1814..2215f7558cafda549fd94ecc4013fcbfadec75e5 100644 --- a/plio/utils/utils.py +++ b/plio/utils/utils.py @@ -7,6 +7,18 @@ import shutil import tempfile import pandas as pd +import numpy as np + +def metadatatoband(metadata): + wv2band = [] + for k, v in metadata.items(): + try: + wv2band.append(float(v)) + except: + v = v.split(" ")[-1].split("(")[1].split(")")[0] + wv2band.append(float(v)) + wv2band.sort(key=int) + return np.asarray(wv2band) def create_dir(basedir=''): """