From 075568dc87d987e249afdef0324a1dace4e8a81f Mon Sep 17 00:00:00 2001 From: Alyson Deives Pereira Date: Thu, 28 Jul 2022 09:49:48 -0300 Subject: [PATCH] Add option for custom msgpack encoder/decoder Signed-off-by: Alyson Deives Pereira --- zerorpc/core.py | 8 ++++---- zerorpc/events.py | 20 +++++++++++--------- zerorpc/socket.py | 4 ++-- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/zerorpc/core.py b/zerorpc/core.py index ea89f36..90b1aae 100644 --- a/zerorpc/core.py +++ b/zerorpc/core.py @@ -278,8 +278,8 @@ class ClientBase(object): class Server(SocketBase, ServerBase): def __init__(self, methods=None, name=None, context=None, pool_size=1000, - heartbeat=5): - SocketBase.__init__(self, zmq.ROUTER, context) + heartbeat=5, encoder=None, decoder=None): + SocketBase.__init__(self, zmq.ROUTER, context, encoder, decoder) if methods is None: methods = self @@ -296,8 +296,8 @@ class Server(SocketBase, ServerBase): class Client(SocketBase, ClientBase): def __init__(self, connect_to=None, context=None, timeout=30, heartbeat=5, - passive_heartbeat=False): - SocketBase.__init__(self, zmq.DEALER, context=context) + passive_heartbeat=False, encoder=None, decoder=None): + SocketBase.__init__(self, zmq.DEALER, context=context, encoder=encoder, decoder=decoder) ClientBase.__init__(self, self._events, context, timeout, heartbeat, passive_heartbeat) if connect_to: diff --git a/zerorpc/events.py b/zerorpc/events.py index ce97ad6..295220e 100644 --- a/zerorpc/events.py +++ b/zerorpc/events.py @@ -201,14 +201,14 @@ class Event(object): def identity(self, v): self._identity = v - def pack(self): + def pack(self, encoder=None): payload = (self._header, self._name, self._args) - r = msgpack.Packer(use_bin_type=True).pack(payload) + r = msgpack.Packer(use_bin_type=True, default=encoder).pack(payload) return r @staticmethod - def unpack(blob): - unpacker = msgpack.Unpacker(raw=False) + def unpack(blob, decoder=None): + unpacker = msgpack.Unpacker(raw=False, object_hook=decoder) unpacker.feed(blob) unpacked_msg = unpacker.unpack() @@ -241,11 +241,13 @@ class Event(object): class Events(ChannelBase): - def __init__(self, zmq_socket_type, context=None): + def __init__(self, zmq_socket_type, context=None, encoder=None, decoder=None): self._debug = False self._zmq_socket_type = zmq_socket_type self._context = context or Context.get_instance() self._socket = self._context.socket(zmq_socket_type) + self._encoder = encoder + self._decoder = decoder if zmq_socket_type in (zmq.PUSH, zmq.PUB, zmq.DEALER, zmq.ROUTER): self._send = Sender(self._socket) @@ -342,11 +344,11 @@ class Events(ChannelBase): logger.debug('--> %s', event) if event.identity: parts = list(event.identity or list()) - parts.extend([b'', event.pack()]) + parts.extend([b'', event.pack(encoder=self._encoder)]) elif self._zmq_socket_type in (zmq.DEALER, zmq.ROUTER): - parts = (b'', event.pack()) + parts = (b'', event.pack(encoder=self._encoder)) else: - parts = (event.pack(),) + parts = (event.pack(encoder=self._encoder),) self._send(parts, timeout) def recv(self, timeout=None): @@ -360,7 +362,7 @@ class Events(ChannelBase): else: identity = None blob = parts[0] - event = Event.unpack(get_pyzmq_frame_buffer(blob)) + event = Event.unpack(get_pyzmq_frame_buffer(blob), decoder=self._decoder) event.identity = identity if self._debug: logger.debug('<-- %s', event) diff --git a/zerorpc/socket.py b/zerorpc/socket.py index 35cb7e4..274a6d4 100644 --- a/zerorpc/socket.py +++ b/zerorpc/socket.py @@ -29,9 +29,9 @@ from .events import Events class SocketBase(object): - def __init__(self, zmq_socket_type, context=None): + def __init__(self, zmq_socket_type, context=None, encoder=None, decoder=None): self._context = context or Context.get_instance() - self._events = Events(zmq_socket_type, context) + self._events = Events(zmq_socket_type, context, encoder, decoder) def close(self): self._events.close() -- 2.25.1