Skip to content
Snippets Groups Projects
Commit 71a6e5ef authored by Kelvin's avatar Kelvin
Browse files

fixed loads test

parent 90bf0233
No related branches found
No related tags found
No related merge requests found
from . import drivers from . import drivers
from .drivers import load from .drivers import load, loads
...@@ -13,6 +13,8 @@ from datetime import datetime, date ...@@ -13,6 +13,8 @@ from datetime import datetime, date
from abc import ABC from abc import ABC
import datetime
# dynamically load drivers # dynamically load drivers
__all__ = [os.path.splitext(os.path.basename(d))[0] for d in glob(os.path.join(os.path.dirname(__file__), '*_driver.py'))] __all__ = [os.path.splitext(os.path.basename(d))[0] for d in glob(os.path.join(os.path.dirname(__file__), '*_driver.py'))]
__driver_modules__ = [importlib.import_module('.'+m, package='ale.drivers') for m in __all__] __driver_modules__ = [importlib.import_module('.'+m, package='ale.drivers') for m in __all__]
...@@ -41,3 +43,18 @@ def load(label): ...@@ -41,3 +43,18 @@ def load(label):
print("Driver Failed:", e) print("Driver Failed:", e)
traceback.print_exc() traceback.print_exc()
raise Exception('No Such Driver for Label') raise Exception('No Such Driver for Label')
def loads(label):
class JsonEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, np.int64):
return int(obj)
if isinstance(obj, datetime.datetime):
return obj.__str__()
return json.JSONEncoder.default(self, obj)
res = load(label).to_dict()
return json.dumps(res, cls=JsonEncoder)
...@@ -272,9 +272,9 @@ class Driver(): ...@@ -272,9 +272,9 @@ class Driver():
if isinstance(self._file, pvl.PVLModule): if isinstance(self._file, pvl.PVLModule):
self._label = self._file self._label = self._file
try: try:
self._label = pvl.loads(self._file, strict=False) self._label = pvl.loads(self._file)
except AttributeError: except Exception:
self._label = pvl.load(self._file, strict=False) self._label = pvl.load(self._file)
except: except:
raise ValueError("{} is not a valid label".format(self._file)) raise ValueError("{} is not a valid label".format(self._file))
return self._label return self._label
......
...@@ -355,7 +355,7 @@ namespace ale { ...@@ -355,7 +355,7 @@ namespace ale {
PyObject *pDict = PyModule_GetDict(pModule); PyObject *pDict = PyModule_GetDict(pModule);
// Get the add method from the dictionary. // Get the add method from the dictionary.
PyObject *pFunc = PyDict_GetItemString(pDict, "load"); PyObject *pFunc = PyDict_GetItemString(pDict, "loads");
if(!pFunc) { if(!pFunc) {
// import errors do not set a PyError flag, need to use a custom // import errors do not set a PyError flag, need to use a custom
// error message instead. // error message instead.
...@@ -382,15 +382,13 @@ namespace ale { ...@@ -382,15 +382,13 @@ namespace ale {
throw invalid_argument(getPyTraceback()); throw invalid_argument(getPyTraceback());
} }
std::string cResult;
// use PyObject_Str to ensure return is always a string
PyObject *pResultStr = PyObject_Str(pResult); PyObject *pResultStr = PyObject_Str(pResult);
PyObject *temp_bytes = PyUnicode_AsUTF8String(pResultStr); // Owned reference PyObject *temp_bytes = PyUnicode_AsUTF8String(pResultStr); // Owned reference
if(!temp_bytes){ if(!temp_bytes){
throw invalid_argument(getPyTraceback()); throw invalid_argument(getPyTraceback());
} }
std::string cResult;
char *temp_str = PyBytes_AS_STRING(temp_bytes); // Borrowed pointer char *temp_str = PyBytes_AS_STRING(temp_bytes); // Borrowed pointer
cResult = temp_str; // copy into std::string cResult = temp_str; // copy into std::string
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment