From 760fb8df97307648433570d004670baa56283389 Mon Sep 17 00:00:00 2001
From: jlaura <jlaura@usgs.gov>
Date: Mon, 20 Apr 2020 09:57:21 -0700
Subject: [PATCH] Fixes 144 (#146)

* Fixes 144

* Supports read/write, warn on pointLog

* Adds tests and fixes bugs that the tests identified

* Updates for comments

* updates tests
---
 plio/io/io_controlnetwork.py            |  83 +++++++++-
 plio/io/tests/test_io_controlnetwork.py | 200 +++++++++++++-----------
 2 files changed, 187 insertions(+), 96 deletions(-)

diff --git a/plio/io/io_controlnetwork.py b/plio/io/io_controlnetwork.py
index f06a037..28a6cc2 100644
--- a/plio/io/io_controlnetwork.py
+++ b/plio/io/io_controlnetwork.py
@@ -1,4 +1,6 @@
+from enum import IntEnum
 from time import gmtime, strftime
+import warnings
 
 import pandas as pd
 import numpy as np
@@ -13,6 +15,7 @@ from plio.utils.utils import xstr, find_in_dict
 HEADERSTARTBYTE = 65536
 DEFAULTUSERNAME = 'None'
 
+
 def write_filelist(lst, path="fromlist.lis"):
     """
     Writes a filelist to a file so it can be used in ISIS3.
@@ -29,6 +32,73 @@ def write_filelist(lst, path="fromlist.lis"):
         handle.write('\n')
     return
 
+
+class MeasureMessageType(IntEnum):
+    """
+    An enum to mirror the ISIS3 MeasureLogData enum.
+    """
+    GoodnessOfFit = 2
+    MinimumPixelZScore = 3
+    MaximumPixelZScore = 4
+    PixelShift = 5
+    WholePixelCorrelation = 6
+    SubPixelCorrelation = 7 
+
+class MeasureLog():
+    
+    def __init__(self, messagetype, value):
+        """
+        A protobuf compliant measure log object.
+        
+        Parameters
+        ----------
+        messagetype : int or str
+                      Either the integer or string representation from the MeasureMessageType enum
+                      
+        value : int or float
+                The value to be stored in the message log
+        """
+        if isinstance(messagetype, int):
+            # by value
+            self.messagetype = MeasureMessageType(messagetype)
+        else:
+            # by name
+            self.messagetype = MeasureMessageType[messagetype]
+        
+        if not isinstance(value, (float, int)):
+            raise TypeError(f'{value} is not a numeric type')
+        self.value = value
+        
+    def __repr__(self):
+        return f'{self.messagetype.name}: {self.value}'
+        
+    def to_protobuf(self, version=2):
+        """
+        Return protobuf compliant measure log object representation
+        of this class.
+        
+        Returns
+        -------
+        log_message : obj
+                      MeasureLogData object suitable to append to a MeasureLog
+                      repeated field.
+        """
+        # I do not see a better way to get to the inner MeasureLogData obj than this
+        # imports were not working because it looks like these need to instantiate off
+        # an object
+        if version == 2:
+            log_message = cnf.ControlPointFileEntryV0002().Measure().MeasureLogData()
+        elif version == 5:
+            log_message = cnp5.ControlPointFileEntryV0005().Measure().MeasureLogData()
+        log_message.doubleDataValue = self.value
+        log_message.doubleDataType = self.messagetype
+        return log_message
+
+    @classmethod
+    def from_protobuf(cls, protobuf):
+        return cls(protobuf.doubleDataType, protobuf.doubleDataValue)
+
+
 class IsisControlNetwork(pd.DataFrame):
 
     # normal properties
@@ -171,7 +241,6 @@ class IsisStore(object):
             for s in pbuf_header.pointMessageSizes:
                 cp.ParseFromString(self._handle.read(s))
                 pt = [getattr(cp, i) for i in self.point_attrs if i != 'measures']
-
                 for measure in cp.measures:
                     meas = pt + [getattr(measure, j) for j in self.measure_attrs]
                     pts.append(meas)
@@ -211,6 +280,10 @@ class IsisStore(object):
         if 'aprioriline' in df.columns:
             df['aprioriline'] -= 0.5
             df['apriorisample'] -= 0.5
+
+        # Munge the MeasureLogData into Python objs
+        df['measureLog'] = df['measureLog'].apply(lambda x: [MeasureLog.from_protobuf(i) for i in x])
+        
         df.header = pvl_header
         return df
 
@@ -266,6 +339,10 @@ class IsisStore(object):
                 # Un-mangle common attribute names between points and measures
                 df_attr = self.point_field_map.get(attr, attr)
                 if df_attr in g.columns:
+                    if df_attr == 'pointLog':
+                        # Currently pointLog is not supported.
+                        warnings.warn('The pointLog field is currently unsupported. Any pointLog data will not be saved.')
+                        continue
                     # As per protobuf docs for assigning to a repeated field.
                     if df_attr == 'aprioriCovar' or df_attr == 'adjustedCovar':
                         arr = g.iloc[0][df_attr]
@@ -290,8 +367,10 @@ class IsisStore(object):
                     # Un-mangle common attribute names between points and measures
                     df_attr = self.measure_field_map.get(attr, attr)
                     if df_attr in g.columns:
+                        if df_attr == 'measureLog':
+                            [getattr(measure_spec, attr).extend([i.to_protobuf()]) for i in m[df_attr]]
                         # If field is repeated you must extend instead of assign
-                        if cnf._CONTROLPOINTFILEENTRYV0002_MEASURE.fields_by_name[attr].label == 3:
+                        elif cnf._CONTROLPOINTFILEENTRYV0002_MEASURE.fields_by_name[attr].label == 3:
                             getattr(measure_spec, attr).extend(m[df_attr])
                         else:
                             setattr(measure_spec, attr, attrtype(m[df_attr]))
diff --git a/plio/io/tests/test_io_controlnetwork.py b/plio/io/tests/test_io_controlnetwork.py
index 8b628ce..791d308 100644
--- a/plio/io/tests/test_io_controlnetwork.py
+++ b/plio/io/tests/test_io_controlnetwork.py
@@ -31,97 +31,109 @@ def test_cnet_read(cnet_file):
         assert proto_field not in df.columns
         assert mangled_field in df.columns
 
-class TestWriteIsisControlNetwork(unittest.TestCase):
-
-    @classmethod
-    def setUpClass(cls):
-        cls.npts = 5
-        serial_times = {295: '1971-07-31T01:24:11.754',
-                        296: '1971-07-31T01:24:36.970'}
-        cls.serials = {i:'APOLLO15/METRIC/{}'.format(j) for i, j in enumerate(serial_times.values())}
-        columns = ['id', 'pointType', 'serialnumber', 'measureType', 'sample', 'line', 'image_index', 'pointLog', 'measureLog']
-
-        data = []
-        for i in range(cls.npts):
-            data.append((i, 2, cls.serials[0], 2, 0, 0, 0, [], []))
-            data.append((i, 2, cls.serials[1], 2, 0, 0, 1, [], []))
-
-        df = pd.DataFrame(data, columns=columns)
-
-        cls.creation_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
-        cls.modified_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
-        io_controlnetwork.to_isis(df, 'test.net', mode='wb', targetname='Moon')
-
-        cls.header_message_size = 78
-        cls.point_start_byte = 65614 # 66949
-
-    def test_create_buffer_header(self):
-        npts = 5
-        serial_times = {295: '1971-07-31T01:24:11.754',
-                        296: '1971-07-31T01:24:36.970'}
-        serials = {i:'APOLLO15/METRIC/{}'.format(j) for i, j in enumerate(serial_times.values())}
-        columns = ['id', 'pointType', 'serialnumber', 'measureType', 'sample', 'line', 'image_index']
-
-        data = []
-        for i in range(self.npts):
-            data.append((i, 2, serials[0], 2, 0, 0, 0))
-            data.append((i, 2, serials[1], 2, 0, 0, 1))
-
-        df = pd.DataFrame(data, columns=columns)
-
-        self.creation_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
-        self.modified_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
-        io_controlnetwork.to_isis(df, 'test.net', mode='wb', targetname='Moon')
-
-        self.header_message_size = 78
-        self.point_start_byte = 65614 # 66949
-
-        with open('test.net', 'rb') as f:
-            f.seek(io_controlnetwork.HEADERSTARTBYTE)
-            raw_header_message = f.read(self.header_message_size)
-            header_protocol = cnf.ControlNetFileHeaderV0002()
-            header_protocol.ParseFromString(raw_header_message)
-            #Non-repeating
-            #self.assertEqual('None', header_protocol.networkId)
-            self.assertEqual('Moon', header_protocol.targetName)
-            self.assertEqual(io_controlnetwork.DEFAULTUSERNAME,
-                             header_protocol.userName)
-            self.assertEqual(self.creation_date,
-                             header_protocol.created)
-            self.assertEqual('None', header_protocol.description)
-            self.assertEqual(self.modified_date, header_protocol.lastModified)
-            #Repeating
-            self.assertEqual([135] * self.npts, header_protocol.pointMessageSizes)
-
-    def test_create_point(self):
-
-        with open('test.net', 'rb') as f:
-            f.seek(self.point_start_byte)
-            for i, length in enumerate([135] * self.npts):
-                point_protocol = cnf.ControlPointFileEntryV0002()
-                raw_point = f.read(length)
-                point_protocol.ParseFromString(raw_point)
-                self.assertEqual(str(i), point_protocol.id)
-                self.assertEqual(2, point_protocol.type)
-                for m in point_protocol.measures:
-                    self.assertTrue(m.serialnumber in self.serials.values())
-                    self.assertEqual(2, m.type)
-
-    def test_create_pvl_header(self):
-        pvl_header = pvl.load('test.net')
-
-        npoints = find_in_dict(pvl_header, 'NumberOfPoints')
-        self.assertEqual(5, npoints)
-
-        mpoints = find_in_dict(pvl_header, 'NumberOfMeasures')
-        self.assertEqual(10, mpoints)
-
-        points_bytes = find_in_dict(pvl_header, 'PointsBytes')
-        self.assertEqual(675, points_bytes)
-
-        points_start_byte = find_in_dict(pvl_header, 'PointsStartByte')
-        self.assertEqual(self.point_start_byte, points_start_byte)
-
-    @classmethod
-    def tearDownClass(cls):
-        os.remove('test.net')
+@pytest.mark.parametrize('messagetype, value', [
+                         (2, 0.5),
+                         (3, 0.5),
+                         (4, -0.25),
+                         (5, 1e6),
+                         (6, 1),
+                         (7, -1e10),
+                         ('GoodnessOfFit', 0.5),
+                         ('MinimumPixelZScore', 0.25)
+])
+def test_MeasureLog(messagetype, value):
+    l = io_controlnetwork.MeasureLog(messagetype, value)
+    if isinstance(messagetype, int):
+        assert l.messagetype == io_controlnetwork.MeasureMessageType(messagetype)
+    elif isinstance(messagetype, str):
+        assert l.messagetype == io_controlnetwork.MeasureMessageType[messagetype]
+        
+    assert l.value == value
+    assert isinstance(l.to_protobuf, object)
+
+def test_log_error():
+    with pytest.raises(TypeError) as err:
+        io_controlnetwork.MeasureLog(2, 'foo')
+
+def test_to_protobuf():
+    value = 1.25
+    int_dtype = 2
+    log = io_controlnetwork.MeasureLog(int_dtype, value)
+    proto = log.to_protobuf()
+    assert proto.doubleDataType == int_dtype
+    assert proto.doubleDataValue == value
+
+@pytest.fixture
+def cnet_dataframe(tmpdir):
+    npts = 5
+    serial_times = {295: '1971-07-31T01:24:11.754',
+                    296: '1971-07-31T01:24:36.970'}
+    serials = {i:'APOLLO15/METRIC/{}'.format(j) for i, j in enumerate(serial_times.values())}
+    columns = ['id', 'pointType', 'serialnumber', 'measureType', 'sample', 'line', 'image_index', 'pointLog', 'measureLog']
+
+    data = []
+    for i in range(npts):
+        data.append((i, 2, serials[0], 2, 0, 0, 0, [], []))
+        data.append((i, 2, serials[1], 2, 0, 0, 1, [], [io_controlnetwork.MeasureLog(2, 0.5)]))
+
+    df = pd.DataFrame(data, columns=columns)
+
+    df.creation_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
+    df.modified_date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
+    io_controlnetwork.to_isis(df, tmpdir.join('test.net'), mode='wb', targetname='Moon')
+
+    df.header_message_size = 78
+    df.point_start_byte = 65614 # 66949
+    df.npts = npts
+    df.measure_size = 149  # Size of each measure in bytes
+    df.serials = serials
+    return df
+
+def test_create_buffer_header(cnet_dataframe, tmpdir):
+    with open(tmpdir.join('test.net'), 'rb') as f:
+        
+        f.seek(io_controlnetwork.HEADERSTARTBYTE)
+        raw_header_message = f.read(cnet_dataframe.header_message_size)
+        header_protocol = cnf.ControlNetFileHeaderV0002()
+        header_protocol.ParseFromString(raw_header_message)
+        #Non-repeating
+        #self.assertEqual('None', header_protocol.networkId)
+        assert 'Moon' == header_protocol.targetName
+        assert io_controlnetwork.DEFAULTUSERNAME == header_protocol.userName
+        assert cnet_dataframe.creation_date == header_protocol.created
+        assert 'None' == header_protocol.description
+        assert cnet_dataframe.modified_date == header_protocol.lastModified
+        #Repeating
+        assert [cnet_dataframe.measure_size] * cnet_dataframe.npts == header_protocol.pointMessageSizes
+
+def test_create_point(cnet_dataframe, tmpdir):
+    with open(tmpdir.join('test.net'), 'rb') as f:
+        f.seek(cnet_dataframe.point_start_byte)
+        for i, length in enumerate([cnet_dataframe.measure_size] * cnet_dataframe.npts):
+            point_protocol = cnf.ControlPointFileEntryV0002()
+            raw_point = f.read(length)
+            point_protocol.ParseFromString(raw_point)
+            assert str(i) == point_protocol.id
+            assert 2 == point_protocol.type
+            print(len(point_protocol.measures))
+            for i, m in enumerate(point_protocol.measures):
+                assert m.serialnumber in cnet_dataframe.serials.values()
+                assert 2 == m.type
+                assert len(m.log) == i
+
+def test_create_pvl_header(cnet_dataframe, tmpdir):
+    with open(tmpdir.join('test.net'), 'rb') as f:
+        pvl_header = pvl.load(f)
+
+    npoints = find_in_dict(pvl_header, 'NumberOfPoints')
+    assert 5 == npoints
+
+    mpoints = find_in_dict(pvl_header, 'NumberOfMeasures')
+    assert 10 == mpoints
+
+    points_bytes = find_in_dict(pvl_header, 'PointsBytes')
+    assert 745 == points_bytes
+
+    points_start_byte = find_in_dict(pvl_header, 'PointsStartByte')
+    assert cnet_dataframe.point_start_byte == points_start_byte
+
-- 
GitLab