From 71a6e5ef31a6ec0e83caeb97953afee1e6f6ba40 Mon Sep 17 00:00:00 2001 From: Kelvin Date: Fri, 15 Mar 2019 14:24:24 -0700 Subject: [PATCH] fixed loads test --- ale/__init__.py | 2 +- ale/drivers/__init__.py | 17 +++++++++++++++++ ale/drivers/base.py | 6 +++--- src/ale.cpp | 8 +++----- 4 files changed, 24 insertions(+), 9 deletions(-) diff --git a/ale/__init__.py b/ale/__init__.py index f2d46ca..b649ee8 100644 --- a/ale/__init__.py +++ b/ale/__init__.py @@ -1,2 +1,2 @@ from . import drivers -from .drivers import load +from .drivers import load, loads diff --git a/ale/drivers/__init__.py b/ale/drivers/__init__.py index 8448214..56a9e09 100644 --- a/ale/drivers/__init__.py +++ b/ale/drivers/__init__.py @@ -13,6 +13,8 @@ from datetime import datetime, date from abc import ABC +import datetime + # dynamically load drivers __all__ = [os.path.splitext(os.path.basename(d))[0] for d in glob(os.path.join(os.path.dirname(__file__), '*_driver.py'))] __driver_modules__ = [importlib.import_module('.'+m, package='ale.drivers') for m in __all__] @@ -41,3 +43,18 @@ def load(label): print("Driver Failed:", e) traceback.print_exc() raise Exception('No Such Driver for Label') + + +def loads(label): + class JsonEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, np.int64): + return int(obj) + if isinstance(obj, datetime.datetime): + return obj.__str__() + return json.JSONEncoder.default(self, obj) + + res = load(label).to_dict() + return json.dumps(res, cls=JsonEncoder) diff --git a/ale/drivers/base.py b/ale/drivers/base.py index bff6c6d..6afb1fb 100644 --- a/ale/drivers/base.py +++ b/ale/drivers/base.py @@ -272,9 +272,9 @@ class Driver(): if isinstance(self._file, pvl.PVLModule): self._label = self._file try: - self._label = pvl.loads(self._file, strict=False) - except AttributeError: - self._label = pvl.load(self._file, strict=False) + self._label = pvl.loads(self._file) + except Exception: + self._label = pvl.load(self._file) except: raise ValueError("{} is not a valid label".format(self._file)) return self._label diff --git a/src/ale.cpp b/src/ale.cpp index 5cf73cd..e63db09 100644 --- a/src/ale.cpp +++ b/src/ale.cpp @@ -355,7 +355,7 @@ namespace ale { PyObject *pDict = PyModule_GetDict(pModule); // Get the add method from the dictionary. - PyObject *pFunc = PyDict_GetItemString(pDict, "load"); + PyObject *pFunc = PyDict_GetItemString(pDict, "loads"); if(!pFunc) { // import errors do not set a PyError flag, need to use a custom // error message instead. @@ -382,15 +382,13 @@ namespace ale { throw invalid_argument(getPyTraceback()); } - std::string cResult; - - // use PyObject_Str to ensure return is always a string PyObject *pResultStr = PyObject_Str(pResult); PyObject *temp_bytes = PyUnicode_AsUTF8String(pResultStr); // Owned reference + if(!temp_bytes){ throw invalid_argument(getPyTraceback()); } - + std::string cResult; char *temp_str = PyBytes_AS_STRING(temp_bytes); // Borrowed pointer cResult = temp_str; // copy into std::string -- GitLab