# NOTE(review): only part of this module was visible during review. The names
# ``msgpack``, ``six``, ``netaddr``, ``timezone``, ``xmlrpclib`` and ``_PY26``
# are set up by module-level code outside the reviewed section and are used
# below unchanged -- confirm they are still in scope when applying this.
import datetime
import functools
import itertools
import sys
import uuid


class HandlerRegistry(object):
    """Registry of type specific msgpack extension handlers.

    A handler is any object exposing an integer ``identity`` (the msgpack
    extension code), a ``handles`` tuple of the types it supports and
    ``serialize``/``deserialize`` callables converting to/from bytes.
    """

    # Applications can assign 0 to 127 to store
    # application-specific type information...
    min_value = 0
    max_value = 127

    def __init__(self):
        self._handlers = {}
        # Once set to true, register() refuses any further modification.
        self.frozen = False

    def __iter__(self):
        """Iterate over the registered handlers."""
        return six.itervalues(self._handlers)

    def __len__(self):
        """Return how many handlers are registered."""
        return len(self._handlers)

    def register(self, handler):
        """Register ``handler`` under its ``identity``.

        :raises ValueError: if the registry is frozen, if the identity is
                            outside ``[min_value, max_value]`` or if another
                            handler already uses that identity.
        """
        if self.frozen:
            raise ValueError("Frozen handler registry can't be modified")
        ident = handler.identity
        if ident < self.min_value:
            raise ValueError("Handler '%s' identity must be greater"
                             " or equal to %s" % (handler, self.min_value))
        if ident > self.max_value:
            raise ValueError("Handler '%s' identity must be less than"
                             " or equal to %s" % (handler, self.max_value))
        if ident in self._handlers:
            raise ValueError("Already registered handler with"
                             " identity %s: %s" % (ident,
                                                   self._handlers[ident]))
        self._handlers[ident] = handler

    def get(self, identity):
        """Return the handler registered for ``identity`` (or ``None``)."""
        return self._handlers.get(identity, None)

    def match(self, obj):
        """Return the most specific handler able to handle ``obj``.

        Returns ``None`` when no registered handler matches.
        """
        # Walk handlers in ascending identity order so that the outcome
        # never depends on dictionary iteration order.
        handlers = [self._handlers[ident] for ident in sorted(self._handlers)]
        # Try the classes of ``obj`` from most to least specific: since
        # ``datetime.datetime`` subclasses ``datetime.date`` a plain
        # isinstance() scan may pick the date handler for a datetime and
        # silently drop its time information.
        for klass in type(obj).__mro__:
            for handler in handlers:
                handles = handler.handles
                if isinstance(handles, tuple) and klass in handles:
                    return handler
        # Fall back to isinstance() so that anything the original check
        # accepted (virtual subclasses, a bare class used as ``handles``)
        # still matches.
        for handler in handlers:
            if isinstance(obj, handler.handles):
                return handler
        return None


class UUIDHandler(object):
    """Handler storing ``uuid.UUID`` objects as their ascii hex digits."""

    identity = 0
    handles = (uuid.UUID,)

    @staticmethod
    def serialize(obj):
        return six.text_type(obj.hex).encode('ascii')

    @staticmethod
    def deserialize(data):
        return uuid.UUID(hex=six.text_type(data, encoding='ascii'))


class DateTimeHandler(object):
    """Handler storing ``datetime.datetime`` objects as a packed dict."""

    identity = 1
    handles = (datetime.datetime,)

    def __init__(self, registry):
        # Registry used for the nested dumps()/loads() of the field dict.
        self._registry = registry

    def serialize(self, dt):
        dct = {
            'day': dt.day,
            'month': dt.month,
            'year': dt.year,
            'hour': dt.hour,
            'minute': dt.minute,
            'second': dt.second,
            'microsecond': dt.microsecond,
        }
        if dt.tzinfo is not None:
            dct['tz'] = dt.tzinfo.tzname(None)
        return dumps(dct, registry=self._registry)

    def deserialize(self, blob):
        dct = loads(blob, registry=self._registry)
        dt = datetime.datetime(day=dct['day'],
                               month=dct['month'],
                               year=dct['year'],
                               hour=dct['hour'],
                               minute=dct['minute'],
                               second=dct['second'],
                               microsecond=dct['microsecond'])
        if 'tz' in dct:
            # Only present when the serialized datetime carried a tzinfo.
            tzinfo = timezone(dct['tz'])
            dt = tzinfo.localize(dt)
        return dt


class CountHandler(object):
    """Handler storing ``itertools.count`` objects as ``[start, step]``."""

    identity = 2
    handles = (itertools.count,)

    @staticmethod
    def serialize(obj):
        # FIXME(harlowja): figure out a better way to avoid hacking into
        # the string representation of count to get at the right numbers...
        obj = six.text_type(obj)
        # NOTE(review): the statements from here up to the ``else`` branch
        # were not visible during review; they are reconstructed from the
        # visible tail (which reads ``pieces[0]`` and ``pieces[1]``) and
        # must be confirmed against the real file before merging.
        start = obj.find("(") + 1
        end = obj.rfind(")")
        pieces = obj[start:end].split(",")
        if len(pieces) == 1:
            start = int(pieces[0])
            step = 1
        else:
            start = int(pieces[0])
            step = int(pieces[1])
        return msgpack.packb([start, step])

    @staticmethod
    def deserialize(data):
        value = msgpack.unpackb(data)
        start, step = value
        if not _PY26:
            return itertools.count(start, step)
        else:
            if step != 1:
                raise ValueError("Python 2.6.x does not support steps"
                                 " that are not equal to one")
            return itertools.count(start)


if netaddr is not None:
    class NetAddrIPHandler(object):
        """Handler storing ``netaddr.IPAddress`` objects by their value."""

        identity = 3
        handles = (netaddr.IPAddress,)

        @staticmethod
        def serialize(obj):
            return msgpack.packb(obj.value)

        @staticmethod
        def deserialize(data):
            return netaddr.IPAddress(msgpack.unpackb(data))
else:
    # netaddr is unavailable, so no handler class can be defined.
    NetAddrIPHandler = None


class SetHandler(object):
    """Handler storing ``set`` objects as a packed list of their items."""

    identity = 4
    handles = (set,)

    def __init__(self, registry):
        # Registry used for the nested dumps()/loads() of the items.
        self._registry = registry

    def serialize(self, obj):
        return dumps(list(obj), registry=self._registry)

    def deserialize(self, data):
        # ``handles[0]`` is the concrete type to rebuild, which lets
        # subclasses reuse this method by only overriding ``handles``.
        return self.handles[0](loads(data, registry=self._registry))


class FrozenSetHandler(SetHandler):
    """Handler storing ``frozenset`` objects (see ``SetHandler``)."""

    identity = 5
    handles = (frozenset,)


class XMLRPCDateTimeHandler(object):
    """Handler storing ``xmlrpclib.DateTime`` objects via a datetime."""

    handles = (xmlrpclib.DateTime,)
    identity = 6

    def __init__(self, registry):
        self._handler = DateTimeHandler(registry)

    def serialize(self, obj):
        # Only the first six time tuple fields (year..second) are kept.
        dt = datetime.datetime(*tuple(obj.timetuple())[:6])
        return self._handler.serialize(dt)

    def deserialize(self, blob):
        dt = self._handler.deserialize(blob)
        return xmlrpclib.DateTime(dt.timetuple())


class DateHandler(object):
    """Handler storing ``datetime.date`` objects as a packed dict."""

    identity = 7
    handles = (datetime.date,)

    def __init__(self, registry):
        # Registry used for the nested dumps()/loads() of the field dict.
        self._registry = registry

    def serialize(self, d):
        dct = {
            'year': d.year,
            'month': d.month,
            'day': d.day,
        }
        return dumps(dct, registry=self._registry)

    def deserialize(self, blob):
        dct = loads(blob, registry=self._registry)
        return datetime.date(year=dct['year'],
                             month=dct['month'],
                             day=dct['day'])


def _serializer(registry, obj):
    """Wrap ``obj`` into a ``msgpack.ExtType`` using a matching handler.

    :raises TypeError: if ``registry`` has no handler matching ``obj``.
    """
    handler = registry.match(obj)
    if handler is None:
        raise TypeError("No serialization handler registered"
                        " for type '%s'" % (type(obj).__name__))
    return msgpack.ExtType(handler.identity, handler.serialize(obj))


def _unserializer(registry, code, data):
    """Rebuild the object stored under extension ``code``.

    Codes without a registered handler are returned unchanged, wrapped in
    a ``msgpack.ExtType``.
    """
    handler = registry.get(code)
    if handler is None:
        return msgpack.ExtType(code, data)
    else:
        return handler.deserialize(data)


def _create_default_registry():
    """Build the frozen registry holding all the handlers defined here."""
    registry = HandlerRegistry()
    registry.register(DateTimeHandler(registry))
    registry.register(DateHandler(registry))
    registry.register(UUIDHandler())
    registry.register(CountHandler())
    registry.register(SetHandler(registry))
    registry.register(FrozenSetHandler(registry))
    if netaddr is not None:
        registry.register(NetAddrIPHandler())
    registry.register(XMLRPCDateTimeHandler(registry))
    registry.frozen = True
    return registry


#: Default, read-only/frozen registry that will be used when none is provided.
default_registry = _create_default_registry()


def load(fp, registry=None):
    """Deserialize ``fp`` into a Python object.

    ``default_registry`` is used when ``registry`` is ``None``.
    """
    if registry is None:
        registry = default_registry
    # NOTE(harlowja): the reason we can't use the more native msgpack functions
    # here is that the unpack() function (oddly) doesn't seem to take a
    # 'ext_hook' parameter..
    ext_hook = functools.partial(_unserializer, registry)
    return msgpack.Unpacker(fp, ext_hook=ext_hook, encoding='utf-8').unpack()


def dump(obj, fp, registry=None):
    """Serialize ``obj`` as a messagepack formatted stream to ``fp``.

    ``default_registry`` is used when ``registry`` is ``None``.
    """
    if registry is None:
        registry = default_registry
    return msgpack.pack(obj, fp,
                        default=functools.partial(_serializer, registry),
                        use_bin_type=True)


def dumps(obj, registry=None):
    """Serialize ``obj`` to a messagepack formatted ``str``.

    ``default_registry`` is used when ``registry`` is ``None``.
    """
    if registry is None:
        registry = default_registry
    return msgpack.packb(obj,
                         default=functools.partial(_serializer, registry),
                         use_bin_type=True)


def loads(s, registry=None):
    """Deserialize ``s`` messagepack ``str`` into a Python object.

    ``default_registry`` is used when ``registry`` is ``None``.
    """
    if registry is None:
        registry = default_registry
    ext_hook = functools.partial(_unserializer, registry)
    return msgpack.unpackb(s, ext_hook=ext_hook, encoding='utf-8')