Skip to content
Snippets Groups Projects
Commit 6f0a25e3 authored by jlaura's avatar jlaura Committed by Jason R Laura
Browse files

Test updates for windows.

parent 08e6d5ad
No related branches found
No related tags found
No related merge requests found
...@@ -468,30 +468,39 @@ class GeoDataset(object): ...@@ -468,30 +468,39 @@ class GeoDataset(object):
def array_to_raster(array, file_name, projection=None, def array_to_raster(array, file_name, projection=None,
geotransform=None, outformat='GTiff', geotransform=None, outformat='GTiff',
ndv=None): ndv=None, bittype='GDT_Float64'):
""" """
Converts the given NumPy array to a raster format using the GeoDataset class. Converts the given NumPy array to a raster format using the GeoDataset class.
Parameters Parameters
---------- ----------
array : ndarray array : ndarray
The data to be written via GDAL
file_name : str file_name : str
The output file PATH (relative or absolute)
projection : projection : object
Default projection=None. A GDAL readable projection object, WKT string, PROJ4 string, etc.
Default: None
geotransform : object geotransform : object
Default geotransform=None. A six parameter geotransformation
Default:None.
outformat : str outformat : str
Default outformat='GTiff'. A GDAL supported output format
Default: 'GTiff'.
ndv : float ndv : float
The no data value for the given band. See no_data_value(). Default ndv=None. The no data value for the given band.
Default: None.
bittype : str
A GDAL supported bittype, e.g. GDT_Int32
Default: GDT_Float64
""" """
driver = gdal.GetDriverByName(outformat) driver = gdal.GetDriverByName(outformat)
try: try:
y, x, bands = array.shape y, x, bands = array.shape
...@@ -501,8 +510,7 @@ def array_to_raster(array, file_name, projection=None, ...@@ -501,8 +510,7 @@ def array_to_raster(array, file_name, projection=None,
y, x = array.shape y, x = array.shape
single = True single = True
#This is a crappy hard code to 32bit. dataset = driver.Create(file_name, x, y, bands, getattr(gdal, bittype))
dataset = driver.Create(file_name, x, y, bands, gdal.GDT_Float64)
if geotransform: if geotransform:
dataset.SetGeoTransform(geotransform) dataset.SetGeoTransform(geotransform)
......
...@@ -179,14 +179,14 @@ class TestWriter(unittest.TestCase): ...@@ -179,14 +179,14 @@ class TestWriter(unittest.TestCase):
dataset = io_gdal.GeoDataset('test.tif') dataset = io_gdal.GeoDataset('test.tif')
self.assertEqual(gt, dataset.geotransform) self.assertEqual(gt, dataset.geotransform)
def test_with_no_data_value(self): def test_with_no_data_value_nd(self):
no_data_value = 0.0 no_data_value = 0.0
#nd array
io_gdal.array_to_raster(self.ndarr, 'test.tif', ndv=no_data_value) io_gdal.array_to_raster(self.ndarr, 'test.tif', ndv=no_data_value)
dataset = io_gdal.GeoDataset('test.tif') dataset = io_gdal.GeoDataset('test.tif')
self.assertEqual(dataset.no_data_value, no_data_value) self.assertEqual(dataset.no_data_value, no_data_value)
#array def test_with_no_data_value(self):
no_data_value = 0.0
io_gdal.array_to_raster(self.arr, 'test.tif', ndv=no_data_value) io_gdal.array_to_raster(self.arr, 'test.tif', ndv=no_data_value)
dataset = io_gdal.GeoDataset('test.tif') dataset = io_gdal.GeoDataset('test.tif')
self.assertEqual(dataset.no_data_value, no_data_value) self.assertEqual(dataset.no_data_value, no_data_value)
......
...@@ -17,10 +17,6 @@ class TestHDF(unittest.TestCase): ...@@ -17,10 +17,6 @@ class TestHDF(unittest.TestCase):
cls.df = pd.DataFrame(cls.x[['bar', 'baz']], index=cls.x['index'], cls.df = pd.DataFrame(cls.x[['bar', 'baz']], index=cls.x['index'],
columns=['bar', 'baz']) columns=['bar', 'baz'])
@classmethod
def tearDownClass(cls):
os.remove('test_io_hdf.hdf')
def test_df_sarray(self): def test_df_sarray(self):
converted = self.hdf.df_to_sarray(self.df.reset_index()) converted = self.hdf.df_to_sarray(self.df.reset_index())
np.testing.assert_array_equal(converted, self.x) np.testing.assert_array_equal(converted, self.x)
...@@ -29,4 +25,8 @@ class TestHDF(unittest.TestCase): ...@@ -29,4 +25,8 @@ class TestHDF(unittest.TestCase):
converted = self.hdf.sarray_to_df(self.x) converted = self.hdf.sarray_to_df(self.x)
self.assertTrue((self.df == converted).all().all()) self.assertTrue((self.df == converted).all().all())
@classmethod
def tearDownClass(cls):
try:
os.remove('test_io_hdf.hdf')
except: pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment