From 760fb8df97307648433570d004670baa56283389 Mon Sep 17 00:00:00 2001 From: jlaura <jlaura@usgs.gov> Date: Mon, 20 Apr 2020 09:57:21 -0700 Subject: [PATCH] Fixes 144 (#146) * Fixes 144 * Supports read/write, warn on pointLog * Adds tests and fixes bugs that the tests identified * Updates for comments * updates tests --- plio/io/io_controlnetwork.py | 83 +++++++++- plio/io/tests/test_io_controlnetwork.py | 200 +++++++++++++----------- 2 files changed, 187 insertions(+), 96 deletions(-) diff --git a/plio/io/io_controlnetwork.py b/plio/io/io_controlnetwork.py index f06a037..28a6cc2 100644 --- a/plio/io/io_controlnetwork.py +++ b/plio/io/io_controlnetwork.py @@ -1,4 +1,6 @@ +from enum import IntEnum from time import gmtime, strftime +import warnings import pandas as pd import numpy as np @@ -13,6 +15,7 @@ from plio.utils.utils import xstr, find_in_dict HEADERSTARTBYTE = 65536 DEFAULTUSERNAME = 'None' + def write_filelist(lst, path="fromlist.lis"): """ Writes a filelist to a file so it can be used in ISIS3. @@ -29,6 +32,73 @@ def write_filelist(lst, path="fromlist.lis"): handle.write('\n') return + +class MeasureMessageType(IntEnum): + """ + An enum to mirror the ISIS3 MeasureLogData enum. + """ + GoodnessOfFit = 2 + MinimumPixelZScore = 3 + MaximumPixelZScore = 4 + PixelShift = 5 + WholePixelCorrelation = 6 + SubPixelCorrelation = 7 + +class MeasureLog(): + + def __init__(self, messagetype, value): + """ + A protobuf compliant measure log object. + + Parameters + ---------- + messagetype : int or str + Either the integer or string representation from the MeasureMessageType enum + + value : int or float + The value to be stored in the message log + """ + if isinstance(messagetype, int): + # by value + self.messagetype = MeasureMessageType(messagetype) + else: + # by name + self.messagetype = MeasureMessageType[messagetype] + + if not isinstance(value, (float, int)): + raise TypeError(f'{value} is not a numeric type') + self.value = value + + def __repr__(self): + return f'{self.messagetype.name}: {self.value}' + + def to_protobuf(self, version=2): + """ + Return protobuf compliant measure log object representation + of this class. + + Returns + ------- + log_message : obj + MeasureLogData object suitable to append to a MeasureLog + repeated field. + """ + # I do not see a better way to get to the inner MeasureLogData obj than this + # imports were not working because it looks like these need to instantiate off + # an object + if version == 2: + log_message = cnf.ControlPointFileEntryV0002().Measure().MeasureLogData() + elif version == 5: + log_message = cnp5.ControlPointFileEntryV0005().Measure().MeasureLogData() + log_message.doubleDataValue = self.value + log_message.doubleDataType = self.messagetype + return log_message + + @classmethod + def from_protobuf(cls, protobuf): + return cls(protobuf.doubleDataType, protobuf.doubleDataValue) + + class IsisControlNetwork(pd.DataFrame): # normal properties @@ -171,7 +241,6 @@ class IsisStore(object): for s in pbuf_header.pointMessageSizes: cp.ParseFromString(self._handle.read(s)) pt = [getattr(cp, i) for i in self.point_attrs if i != 'measures'] - for measure in cp.measures: meas = pt + [getattr(measure, j) for j in self.measure_attrs] pts.append(meas) @@ -211,6 +280,10 @@ class IsisStore(object): if 'aprioriline' in df.columns: df['aprioriline'] -= 0.5 df['apriorisample'] -= 0.5 + + # Munge the MeasureLogData into Python objs + df['measureLog'] = df['measureLog'].apply(lambda x: [MeasureLog.from_protobuf(i) for i in x]) + df.header = pvl_header return df @@ -266,6 +339,10 @@ class IsisStore(object): # Un-mangle common attribute names between points and measures df_attr = self.point_field_map.get(attr, attr) if df_attr in g.columns: + if df_attr == 'pointLog': + # Currently pointLog is not supported. + warnings.warn('The pointLog field is currently unsupported. Any pointLog data will not be saved.') + continue # As per protobuf docs for assigning to a repeated field. if df_attr == 'aprioriCovar' or df_attr == 'adjustedCovar': arr = g.iloc[0][df_attr] @@ -290,8 +367,10 @@ class IsisStore(object): # Un-mangle common attribute names between points and measures df_attr = self.measure_field_map.get(attr, attr) if df_attr in g.columns: + if df_attr == 'measureLog': + [getattr(measure_spec, attr).extend([i.to_protobuf()]) for i in m[df_attr]] # If field is repeated you must extend instead of assign - if cnf._CONTROLPOINTFILEENTRYV0002_MEASURE.fields_by_name[attr].label == 3: + elif cnf._CONTROLPOINTFILEENTRYV0002_MEASURE.fields_by_name[attr].label == 3: getattr(measure_spec, attr).extend(m[df_attr]) else: setattr(measure_spec, attr, attrtype(m[df_attr])) diff --git a/plio/io/tests/test_io_controlnetwork.py b/plio/io/tests/test_io_controlnetwork.py index 8b628ce..791d308 100644 --- a/plio/io/tests/test_io_controlnetwork.py +++ b/plio/io/tests/test_io_controlnetwork.py @@ -31,97 +31,109 @@ def test_cnet_read(cnet_file): assert proto_field not in df.columns assert mangled_field in df.columns -class TestWriteIsisControlNetwork(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.npts = 5 - serial_times = {295: '1971-07-31T01:24:11.754', - 296: '1971-07-31T01:24:36.970'} - cls.serials = {i:'APOLLO15/METRIC/{}'.format(j) for i, j in enumerate(serial_times.values())} - columns = ['id', 'pointType', 'serialnumber', 'measureType', 'sample', 'line', 'image_index', 'pointLog', 'measureLog'] - - data = [] - for i in range(cls.npts): - data.append((i, 2, cls.serials[0], 2, 0, 0, 0, [], [])) - data.append((i, 2, cls.serials[1], 2, 0, 0, 1, [], [])) - - df = pd.DataFrame(data, columns=columns) - - cls.creation_date = strftime("%Y-%m-%d %H:%M:%S", gmtime()) - cls.modified_date = strftime("%Y-%m-%d %H:%M:%S", gmtime()) - io_controlnetwork.to_isis(df, 'test.net', mode='wb', targetname='Moon') - - cls.header_message_size = 78 - cls.point_start_byte = 65614 # 66949 - - def test_create_buffer_header(self): - npts = 5 - serial_times = {295: '1971-07-31T01:24:11.754', - 296: '1971-07-31T01:24:36.970'} - serials = {i:'APOLLO15/METRIC/{}'.format(j) for i, j in enumerate(serial_times.values())} - columns = ['id', 'pointType', 'serialnumber', 'measureType', 'sample', 'line', 'image_index'] - - data = [] - for i in range(self.npts): - data.append((i, 2, serials[0], 2, 0, 0, 0)) - data.append((i, 2, serials[1], 2, 0, 0, 1)) - - df = pd.DataFrame(data, columns=columns) - - self.creation_date = strftime("%Y-%m-%d %H:%M:%S", gmtime()) - self.modified_date = strftime("%Y-%m-%d %H:%M:%S", gmtime()) - io_controlnetwork.to_isis(df, 'test.net', mode='wb', targetname='Moon') - - self.header_message_size = 78 - self.point_start_byte = 65614 # 66949 - - with open('test.net', 'rb') as f: - f.seek(io_controlnetwork.HEADERSTARTBYTE) - raw_header_message = f.read(self.header_message_size) - header_protocol = cnf.ControlNetFileHeaderV0002() - header_protocol.ParseFromString(raw_header_message) - #Non-repeating - #self.assertEqual('None', header_protocol.networkId) - self.assertEqual('Moon', header_protocol.targetName) - self.assertEqual(io_controlnetwork.DEFAULTUSERNAME, - header_protocol.userName) - self.assertEqual(self.creation_date, - header_protocol.created) - self.assertEqual('None', header_protocol.description) - self.assertEqual(self.modified_date, header_protocol.lastModified) - #Repeating - self.assertEqual([135] * self.npts, header_protocol.pointMessageSizes) - - def test_create_point(self): - - with open('test.net', 'rb') as f: - f.seek(self.point_start_byte) - for i, length in enumerate([135] * self.npts): - point_protocol = cnf.ControlPointFileEntryV0002() - raw_point = f.read(length) - point_protocol.ParseFromString(raw_point) - self.assertEqual(str(i), point_protocol.id) - self.assertEqual(2, point_protocol.type) - for m in point_protocol.measures: - self.assertTrue(m.serialnumber in self.serials.values()) - self.assertEqual(2, m.type) - - def test_create_pvl_header(self): - pvl_header = pvl.load('test.net') - - npoints = find_in_dict(pvl_header, 'NumberOfPoints') - self.assertEqual(5, npoints) - - mpoints = find_in_dict(pvl_header, 'NumberOfMeasures') - self.assertEqual(10, mpoints) - - points_bytes = find_in_dict(pvl_header, 'PointsBytes') - self.assertEqual(675, points_bytes) - - points_start_byte = find_in_dict(pvl_header, 'PointsStartByte') - self.assertEqual(self.point_start_byte, points_start_byte) - - @classmethod - def tearDownClass(cls): - os.remove('test.net') +@pytest.mark.parametrize('messagetype, value', [ + (2, 0.5), + (3, 0.5), + (4, -0.25), + (5, 1e6), + (6, 1), + (7, -1e10), + ('GoodnessOfFit', 0.5), + ('MinimumPixelZScore', 0.25) +]) +def test_MeasureLog(messagetype, value): + l = io_controlnetwork.MeasureLog(messagetype, value) + if isinstance(messagetype, int): + assert l.messagetype == io_controlnetwork.MeasureMessageType(messagetype) + elif isinstance(messagetype, str): + assert l.messagetype == io_controlnetwork.MeasureMessageType[messagetype] + + assert l.value == value + assert isinstance(l.to_protobuf, object) + +def test_log_error(): + with pytest.raises(TypeError) as err: + io_controlnetwork.MeasureLog(2, 'foo') + +def test_to_protobuf(): + value = 1.25 + int_dtype = 2 + log = io_controlnetwork.MeasureLog(int_dtype, value) + proto = log.to_protobuf() + assert proto.doubleDataType == int_dtype + assert proto.doubleDataValue == value + +@pytest.fixture +def cnet_dataframe(tmpdir): + npts = 5 + serial_times = {295: '1971-07-31T01:24:11.754', + 296: '1971-07-31T01:24:36.970'} + serials = {i:'APOLLO15/METRIC/{}'.format(j) for i, j in enumerate(serial_times.values())} + columns = ['id', 'pointType', 'serialnumber', 'measureType', 'sample', 'line', 'image_index', 'pointLog', 'measureLog'] + + data = [] + for i in range(npts): + data.append((i, 2, serials[0], 2, 0, 0, 0, [], [])) + data.append((i, 2, serials[1], 2, 0, 0, 1, [], [io_controlnetwork.MeasureLog(2, 0.5)])) + + df = pd.DataFrame(data, columns=columns) + + df.creation_date = strftime("%Y-%m-%d %H:%M:%S", gmtime()) + df.modified_date = strftime("%Y-%m-%d %H:%M:%S", gmtime()) + io_controlnetwork.to_isis(df, tmpdir.join('test.net'), mode='wb', targetname='Moon') + + df.header_message_size = 78 + df.point_start_byte = 65614 # 66949 + df.npts = npts + df.measure_size = 149 # Size of each measure in bytes + df.serials = serials + return df + +def test_create_buffer_header(cnet_dataframe, tmpdir): + with open(tmpdir.join('test.net'), 'rb') as f: + + f.seek(io_controlnetwork.HEADERSTARTBYTE) + raw_header_message = f.read(cnet_dataframe.header_message_size) + header_protocol = cnf.ControlNetFileHeaderV0002() + header_protocol.ParseFromString(raw_header_message) + #Non-repeating + #self.assertEqual('None', header_protocol.networkId) + assert 'Moon' == header_protocol.targetName + assert io_controlnetwork.DEFAULTUSERNAME == header_protocol.userName + assert cnet_dataframe.creation_date == header_protocol.created + assert 'None' == header_protocol.description + assert cnet_dataframe.modified_date == header_protocol.lastModified + #Repeating + assert [cnet_dataframe.measure_size] * cnet_dataframe.npts == header_protocol.pointMessageSizes + +def test_create_point(cnet_dataframe, tmpdir): + with open(tmpdir.join('test.net'), 'rb') as f: + f.seek(cnet_dataframe.point_start_byte) + for i, length in enumerate([cnet_dataframe.measure_size] * cnet_dataframe.npts): + point_protocol = cnf.ControlPointFileEntryV0002() + raw_point = f.read(length) + point_protocol.ParseFromString(raw_point) + assert str(i) == point_protocol.id + assert 2 == point_protocol.type + print(len(point_protocol.measures)) + for i, m in enumerate(point_protocol.measures): + assert m.serialnumber in cnet_dataframe.serials.values() + assert 2 == m.type + assert len(m.log) == i + +def test_create_pvl_header(cnet_dataframe, tmpdir): + with open(tmpdir.join('test.net'), 'rb') as f: + pvl_header = pvl.load(f) + + npoints = find_in_dict(pvl_header, 'NumberOfPoints') + assert 5 == npoints + + mpoints = find_in_dict(pvl_header, 'NumberOfMeasures') + assert 10 == mpoints + + points_bytes = find_in_dict(pvl_header, 'PointsBytes') + assert 745 == points_bytes + + points_start_byte = find_in_dict(pvl_header, 'PointsStartByte') + assert cnet_dataframe.point_start_byte == points_start_byte + -- GitLab