Merge "Make the msgpackutils handlers more extendable"
This commit is contained in:
commit
6216d7b610
@ -13,6 +13,7 @@
|
|||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
import functools
|
||||||
import itertools
|
import itertools
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
@ -33,60 +34,104 @@ else:
|
|||||||
_PY26 = False
|
_PY26 = False
|
||||||
|
|
||||||
|
|
||||||
def _serialize_datetime(dt):
|
class HandlerRegistry(object):
|
||||||
dct = {
|
|
||||||
'day': dt.day,
|
|
||||||
'month': dt.month,
|
|
||||||
'year': dt.year,
|
|
||||||
'hour': dt.hour,
|
|
||||||
'minute': dt.minute,
|
|
||||||
'second': dt.second,
|
|
||||||
'microsecond': dt.microsecond,
|
|
||||||
}
|
|
||||||
if dt.tzinfo:
|
|
||||||
dct['tz'] = dt.tzinfo.tzname(None)
|
|
||||||
return dumps(dct)
|
|
||||||
|
|
||||||
|
|
||||||
def _deserialize_datetime(blob):
|
|
||||||
dct = loads(blob)
|
|
||||||
dt = datetime.datetime(day=dct['day'],
|
|
||||||
month=dct['month'],
|
|
||||||
year=dct['year'],
|
|
||||||
hour=dct['hour'],
|
|
||||||
minute=dct['minute'],
|
|
||||||
second=dct['second'],
|
|
||||||
microsecond=dct['microsecond'])
|
|
||||||
if 'tz' in dct:
|
|
||||||
tzinfo = timezone(dct['tz'])
|
|
||||||
dt = tzinfo.localize(dt)
|
|
||||||
return dt
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_date(d):
|
|
||||||
dct = {
|
|
||||||
'year': d.year,
|
|
||||||
'month': d.month,
|
|
||||||
'day': d.day,
|
|
||||||
}
|
|
||||||
return dumps(dct)
|
|
||||||
|
|
||||||
|
|
||||||
def _deserialize_date(blob):
|
|
||||||
dct = loads(blob)
|
|
||||||
return datetime.date(year=dct['year'],
|
|
||||||
month=dct['month'],
|
|
||||||
day=dct['day'])
|
|
||||||
|
|
||||||
|
|
||||||
def _serializer(obj):
|
|
||||||
# Applications can assign 0 to 127 to store
|
# Applications can assign 0 to 127 to store
|
||||||
# application-specific type information...
|
# application-specific type information...
|
||||||
if isinstance(obj, uuid.UUID):
|
min_value = 0
|
||||||
return msgpack.ExtType(0, six.text_type(obj.hex).encode('ascii'))
|
max_value = 127
|
||||||
if isinstance(obj, datetime.datetime):
|
|
||||||
return msgpack.ExtType(1, _serialize_datetime(obj))
|
def __init__(self):
|
||||||
if type(obj) == itertools.count:
|
self._handlers = {}
|
||||||
|
self.frozen = False
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return six.itervalues(self._handlers)
|
||||||
|
|
||||||
|
def register(self, handler):
|
||||||
|
if self.frozen:
|
||||||
|
raise ValueError("Frozen handler registry can't be modified")
|
||||||
|
ident = handler.identity
|
||||||
|
if ident < self.min_value:
|
||||||
|
raise ValueError("Handler '%s' identity must be greater"
|
||||||
|
" or equal to %s" % (handler, self.min_value))
|
||||||
|
if ident > self.max_value:
|
||||||
|
raise ValueError("Handler '%s' identity must be less than"
|
||||||
|
" or equal to %s" % (handler, self.max_value))
|
||||||
|
if ident in self._handlers:
|
||||||
|
raise ValueError("Already registered handler with"
|
||||||
|
" identity %s: %s" % (ident,
|
||||||
|
self._handlers[ident]))
|
||||||
|
else:
|
||||||
|
self._handlers[ident] = handler
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._handlers)
|
||||||
|
|
||||||
|
def get(self, identity):
|
||||||
|
return self._handlers.get(identity, None)
|
||||||
|
|
||||||
|
def match(self, obj):
|
||||||
|
for handler in six.itervalues(self._handlers):
|
||||||
|
if isinstance(obj, handler.handles):
|
||||||
|
return handler
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class UUIDHandler(object):
|
||||||
|
identity = 0
|
||||||
|
handles = (uuid.UUID,)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def serialize(obj):
|
||||||
|
return six.text_type(obj.hex).encode('ascii')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def deserialize(data):
|
||||||
|
return uuid.UUID(hex=six.text_type(data, encoding='ascii'))
|
||||||
|
|
||||||
|
|
||||||
|
class DateTimeHandler(object):
|
||||||
|
identity = 1
|
||||||
|
handles = (datetime.datetime,)
|
||||||
|
|
||||||
|
def __init__(self, registry):
|
||||||
|
self._registry = registry
|
||||||
|
|
||||||
|
def serialize(self, dt):
|
||||||
|
dct = {
|
||||||
|
'day': dt.day,
|
||||||
|
'month': dt.month,
|
||||||
|
'year': dt.year,
|
||||||
|
'hour': dt.hour,
|
||||||
|
'minute': dt.minute,
|
||||||
|
'second': dt.second,
|
||||||
|
'microsecond': dt.microsecond,
|
||||||
|
}
|
||||||
|
if dt.tzinfo:
|
||||||
|
dct['tz'] = dt.tzinfo.tzname(None)
|
||||||
|
return dumps(dct, registry=self._registry)
|
||||||
|
|
||||||
|
def deserialize(self, blob):
|
||||||
|
dct = loads(blob, registry=self._registry)
|
||||||
|
dt = datetime.datetime(day=dct['day'],
|
||||||
|
month=dct['month'],
|
||||||
|
year=dct['year'],
|
||||||
|
hour=dct['hour'],
|
||||||
|
minute=dct['minute'],
|
||||||
|
second=dct['second'],
|
||||||
|
microsecond=dct['microsecond'])
|
||||||
|
if 'tz' in dct:
|
||||||
|
tzinfo = timezone(dct['tz'])
|
||||||
|
dt = tzinfo.localize(dt)
|
||||||
|
return dt
|
||||||
|
|
||||||
|
|
||||||
|
class CountHandler(object):
|
||||||
|
identity = 2
|
||||||
|
handles = (itertools.count,)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def serialize(obj):
|
||||||
# FIXME(harlowja): figure out a better way to avoid hacking into
|
# FIXME(harlowja): figure out a better way to avoid hacking into
|
||||||
# the string representation of count to get at the right numbers...
|
# the string representation of count to get at the right numbers...
|
||||||
obj = six.text_type(obj)
|
obj = six.text_type(obj)
|
||||||
@ -99,71 +144,161 @@ def _serializer(obj):
|
|||||||
else:
|
else:
|
||||||
start = int(pieces[0])
|
start = int(pieces[0])
|
||||||
step = int(pieces[1])
|
step = int(pieces[1])
|
||||||
return msgpack.ExtType(2, msgpack.packb([start, step]))
|
return msgpack.packb([start, step])
|
||||||
if netaddr and isinstance(obj, netaddr.IPAddress):
|
|
||||||
return msgpack.ExtType(3, msgpack.packb(obj.value))
|
|
||||||
if isinstance(obj, (set, frozenset)):
|
|
||||||
value = dumps(list(obj))
|
|
||||||
if isinstance(obj, set):
|
|
||||||
ident = 4
|
|
||||||
else:
|
|
||||||
ident = 5
|
|
||||||
return msgpack.ExtType(ident, value)
|
|
||||||
if isinstance(obj, xmlrpclib.DateTime):
|
|
||||||
dt = datetime.datetime(*tuple(obj.timetuple())[:6])
|
|
||||||
return msgpack.ExtType(6, _serialize_datetime(dt))
|
|
||||||
if isinstance(obj, datetime.date):
|
|
||||||
return msgpack.ExtType(7, _serialize_date(obj))
|
|
||||||
raise TypeError("Unknown type: %r" % (obj,))
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _unserializer(code, data):
|
def deserialize(data):
|
||||||
if code == 0:
|
|
||||||
return uuid.UUID(hex=six.text_type(data, encoding='ascii'))
|
|
||||||
if code == 1:
|
|
||||||
return _deserialize_datetime(data)
|
|
||||||
if code == 2:
|
|
||||||
value = msgpack.unpackb(data)
|
value = msgpack.unpackb(data)
|
||||||
|
start, step = value
|
||||||
if not _PY26:
|
if not _PY26:
|
||||||
return itertools.count(value[0], value[1])
|
return itertools.count(start, step)
|
||||||
else:
|
else:
|
||||||
return itertools.count(value[0])
|
if step != 1:
|
||||||
if netaddr and code == 3:
|
raise ValueError("Python 2.6.x does not support steps"
|
||||||
value = msgpack.unpackb(data)
|
" that are not equal to one")
|
||||||
return netaddr.IPAddress(value)
|
return itertools.count(start)
|
||||||
if code in (4, 5):
|
|
||||||
value = loads(data)
|
|
||||||
if code == 4:
|
if netaddr is not None:
|
||||||
return set(value)
|
class NetAddrIPHandler(object):
|
||||||
else:
|
identity = 3
|
||||||
return frozenset(value)
|
handles = (netaddr.IPAddress,)
|
||||||
if code == 6:
|
|
||||||
dt = _deserialize_datetime(data)
|
@staticmethod
|
||||||
|
def serialize(obj):
|
||||||
|
return msgpack.packb(obj.value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def deserialize(data):
|
||||||
|
return netaddr.IPAddress(msgpack.unpackb(data))
|
||||||
|
else:
|
||||||
|
NetAddrIPHandler = None
|
||||||
|
|
||||||
|
|
||||||
|
class SetHandler(object):
|
||||||
|
identity = 4
|
||||||
|
handles = (set,)
|
||||||
|
|
||||||
|
def __init__(self, registry):
|
||||||
|
self._registry = registry
|
||||||
|
|
||||||
|
def serialize(self, obj):
|
||||||
|
return dumps(list(obj), registry=self._registry)
|
||||||
|
|
||||||
|
def deserialize(self, data):
|
||||||
|
return self.handles[0](loads(data, registry=self._registry))
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenSetHandler(SetHandler):
|
||||||
|
identity = 5
|
||||||
|
handles = (frozenset,)
|
||||||
|
|
||||||
|
|
||||||
|
class XMLRPCDateTimeHandler(object):
|
||||||
|
handles = (xmlrpclib.DateTime,)
|
||||||
|
identity = 6
|
||||||
|
|
||||||
|
def __init__(self, registry):
|
||||||
|
self._handler = DateTimeHandler(registry)
|
||||||
|
|
||||||
|
def serialize(self, obj):
|
||||||
|
dt = datetime.datetime(*tuple(obj.timetuple())[:6])
|
||||||
|
return self._handler.serialize(dt)
|
||||||
|
|
||||||
|
def deserialize(self, blob):
|
||||||
|
dt = self._handler.deserialize(blob)
|
||||||
return xmlrpclib.DateTime(dt.timetuple())
|
return xmlrpclib.DateTime(dt.timetuple())
|
||||||
if code == 7:
|
|
||||||
return _deserialize_date(data)
|
|
||||||
return msgpack.ExtType(code, data)
|
|
||||||
|
|
||||||
|
|
||||||
def load(fp):
|
class DateHandler(object):
|
||||||
|
identity = 7
|
||||||
|
handles = (datetime.date,)
|
||||||
|
|
||||||
|
def __init__(self, registry):
|
||||||
|
self._registry = registry
|
||||||
|
|
||||||
|
def serialize(self, d):
|
||||||
|
dct = {
|
||||||
|
'year': d.year,
|
||||||
|
'month': d.month,
|
||||||
|
'day': d.day,
|
||||||
|
}
|
||||||
|
return dumps(dct, registry=self._registry)
|
||||||
|
|
||||||
|
def deserialize(self, blob):
|
||||||
|
dct = loads(blob, registry=self._registry)
|
||||||
|
return datetime.date(year=dct['year'],
|
||||||
|
month=dct['month'],
|
||||||
|
day=dct['day'])
|
||||||
|
|
||||||
|
|
||||||
|
def _serializer(registry, obj):
|
||||||
|
handler = registry.match(obj)
|
||||||
|
if handler is None:
|
||||||
|
raise TypeError("No serialization handler registered"
|
||||||
|
" for type '%s'" % (type(obj).__name__))
|
||||||
|
return msgpack.ExtType(handler.identity, handler.serialize(obj))
|
||||||
|
|
||||||
|
|
||||||
|
def _unserializer(registry, code, data):
|
||||||
|
handler = registry.get(code)
|
||||||
|
if handler is None:
|
||||||
|
return msgpack.ExtType(code, data)
|
||||||
|
else:
|
||||||
|
return handler.deserialize(data)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_default_registry():
|
||||||
|
registry = HandlerRegistry()
|
||||||
|
registry.register(DateTimeHandler(registry))
|
||||||
|
registry.register(DateHandler(registry))
|
||||||
|
registry.register(UUIDHandler())
|
||||||
|
registry.register(CountHandler())
|
||||||
|
registry.register(SetHandler(registry))
|
||||||
|
registry.register(FrozenSetHandler(registry))
|
||||||
|
if netaddr is not None:
|
||||||
|
registry.register(NetAddrIPHandler())
|
||||||
|
registry.register(XMLRPCDateTimeHandler(registry))
|
||||||
|
registry.frozen = True
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
#: Default, read-only/frozen registry that will be used when none is provided.
|
||||||
|
default_registry = _create_default_registry()
|
||||||
|
|
||||||
|
|
||||||
|
def load(fp, registry=None):
|
||||||
"""Deserialize ``fp`` into a Python object."""
|
"""Deserialize ``fp`` into a Python object."""
|
||||||
|
if registry is None:
|
||||||
|
registry = default_registry
|
||||||
# NOTE(harlowja): the reason we can't use the more native msgpack functions
|
# NOTE(harlowja): the reason we can't use the more native msgpack functions
|
||||||
# here is that the unpack() function (oddly) doesn't seem to take a
|
# here is that the unpack() function (oddly) doesn't seem to take a
|
||||||
# 'ext_hook' parameter..
|
# 'ext_hook' parameter..
|
||||||
return msgpack.Unpacker(fp, ext_hook=_unserializer,
|
ext_hook = functools.partial(_unserializer, registry)
|
||||||
encoding='utf-8').unpack()
|
return msgpack.Unpacker(fp, ext_hook=ext_hook, encoding='utf-8').unpack()
|
||||||
|
|
||||||
|
|
||||||
def dump(obj, fp):
|
def dump(obj, fp, registry=None):
|
||||||
"""Serialize ``obj`` as a messagepack formatted stream to ``fp``."""
|
"""Serialize ``obj`` as a messagepack formatted stream to ``fp``."""
|
||||||
return msgpack.pack(obj, fp, default=_serializer, use_bin_type=True)
|
if registry is None:
|
||||||
|
registry = default_registry
|
||||||
|
return msgpack.pack(obj, fp,
|
||||||
|
default=functools.partial(_serializer, registry),
|
||||||
|
use_bin_type=True)
|
||||||
|
|
||||||
|
|
||||||
def dumps(obj):
|
def dumps(obj, registry=None):
|
||||||
"""Serialize ``obj`` to a messagepack formatted ``str``."""
|
"""Serialize ``obj`` to a messagepack formatted ``str``."""
|
||||||
return msgpack.packb(obj, default=_serializer, use_bin_type=True)
|
if registry is None:
|
||||||
|
registry = default_registry
|
||||||
|
return msgpack.packb(obj,
|
||||||
|
default=functools.partial(_serializer, registry),
|
||||||
|
use_bin_type=True)
|
||||||
|
|
||||||
|
|
||||||
def loads(s):
|
def loads(s, registry=None):
|
||||||
"""Deserialize ``s`` messagepack ``str`` into a Python object."""
|
"""Deserialize ``s`` messagepack ``str`` into a Python object."""
|
||||||
return msgpack.unpackb(s, ext_hook=_unserializer, encoding='utf-8')
|
if registry is None:
|
||||||
|
registry = default_registry
|
||||||
|
ext_hook = functools.partial(_unserializer, registry)
|
||||||
|
return msgpack.unpackb(s, ext_hook=ext_hook, encoding='utf-8')
|
||||||
|
Loading…
Reference in New Issue
Block a user