limit special handling to bytes in user_options

uploaded form data can be bytes, which we base64-encode

don't persist any other unsupported data types, persist None instead
This commit is contained in:
Min RK
2019-03-07 15:30:00 +01:00
parent 4183d45ab3
commit 82c889861d

View File

@@ -3,7 +3,6 @@
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import enum import enum
import json import json
import pickle
from base64 import decodebytes from base64 import decodebytes
from base64 import encodebytes from base64 import encodebytes
from datetime import datetime from datetime import datetime
@@ -59,30 +58,29 @@ class JSONDict(TypeDecorator):
impl = Text impl = Text
def _fallback_pickle(self, obj): def _json_default(self, obj):
"""encode unrecognized objects with pickle""" """encode non-jsonable objects as JSON
try:
pickle_bytes = pickle.dumps(obj, 3) Currently only bytes are supported
except Exception as e:
"""
if not isinstance(obj, bytes):
app_log.warning( app_log.warning(
"Failed to serialize unpickleable data (%s), will persist None.", e "Non-jsonable data in user_options: %r; will persist None.", type(obj)
) )
return None return None
return { return {"__jupyterhub_bytes__": True, "data": encodebytes(obj).decode('ascii')}
"__jupyterhub_pickle__": True,
"data": encodebytes(pickle_bytes).decode('ascii'),
}
def _object_hook(self, dct): def _object_hook(self, dct):
"""decode pickle-packed objects""" """decode non-json objects packed by _json_default"""
if dct.get("__jupyterhub_pickle__", False): if dct.get("__jupyterhub_bytes__", False):
return pickle.loads(decodebytes(dct['data'].encode('ascii'))) return decodebytes(dct['data'].encode('ascii'))
return dct return dct
def process_bind_param(self, value, dialect): def process_bind_param(self, value, dialect):
if value is not None: if value is not None:
value = json.dumps(value, default=self._fallback_pickle) value = json.dumps(value, default=self._json_default)
return value return value
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):