[zmq] Implement retries for unacknowledged CASTs

This patch tries to implement a mechanism of acknowledgements and
retries via proxy for CAST messages.

Change-Id: I83919382262b9f169becd09f5db465a01a0ccb78
Partial-Bug: #1497306
Closes-Bug: #1515269
This commit is contained in:
Gevorg Davoian 2016-07-06 11:49:01 +03:00
parent a0336c8aa1
commit 20a07e7f48
28 changed files with 860 additions and 168 deletions

View File

@ -16,12 +16,11 @@ import abc
from concurrent import futures
import logging
import retrying
import oslo_messaging
from oslo_messaging._drivers import common as rpc_common
from oslo_messaging._drivers.zmq_driver.client.publishers \
import zmq_publisher_base
from oslo_messaging._drivers.zmq_driver.client import zmq_response
from oslo_messaging._drivers.zmq_driver.client import zmq_sockets_manager
from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_names
@ -56,15 +55,13 @@ class DealerPublisherBase(zmq_publisher_base.PublisherBase):
{"tout": request.timeout, "msg_id": request.message_id}
)
@abc.abstractmethod
def _connect_socket(self, request):
pass
def _recv_reply(self, request):
reply_future, = self.receiver.track_request(request)
reply_future = \
self.receiver.track_request(request)[zmq_names.REPLY_TYPE]
try:
_, reply = reply_future.result(timeout=request.timeout)
assert isinstance(reply, zmq_response.Reply), "Reply expected!"
except AssertionError:
LOG.error(_LE("Message format error in reply for %s"),
request.message_id)
@ -84,9 +81,8 @@ class DealerPublisherBase(zmq_publisher_base.PublisherBase):
def send_call(self, request):
self._check_pattern(request, zmq_names.CALL_TYPE)
try:
socket = self._connect_socket(request)
except retrying.RetryError:
socket = self.connect_socket(request)
if not socket:
self._raise_timeout(request)
self.sender.send(socket, request)

View File

@ -33,17 +33,22 @@ class DealerPublisherDirect(zmq_dealer_publisher_base.DealerPublisherBase):
def __init__(self, conf, matchmaker):
sender = zmq_senders.RequestSenderDirect(conf)
receiver = zmq_receivers.ReplyReceiverDirect(conf)
if conf.oslo_messaging_zmq.rpc_use_acks:
receiver = zmq_receivers.AckAndReplyReceiverDirect(conf)
else:
receiver = zmq_receivers.ReplyReceiverDirect(conf)
super(DealerPublisherDirect, self).__init__(conf, matchmaker, sender,
receiver)
def _connect_socket(self, request):
return self.sockets_manager.get_socket(request.target)
def connect_socket(self, request):
try:
return self.sockets_manager.get_socket(request.target)
except retrying.RetryError:
return None
def _send_non_blocking(self, request):
try:
socket = self._connect_socket(request)
except retrying.RetryError:
socket = self.connect_socket(request)
if not socket:
return
if request.msg_type in zmq_names.MULTISEND_TYPES:

View File

@ -36,7 +36,10 @@ class DealerPublisherProxy(zmq_dealer_publisher_base.DealerPublisherBase):
def __init__(self, conf, matchmaker):
sender = zmq_senders.RequestSenderProxy(conf)
receiver = zmq_receivers.ReplyReceiverProxy(conf)
if conf.oslo_messaging_zmq.rpc_use_acks:
receiver = zmq_receivers.AckAndReplyReceiverProxy(conf)
else:
receiver = zmq_receivers.ReplyReceiverProxy(conf)
super(DealerPublisherProxy, self).__init__(conf, matchmaker, sender,
receiver)
self.socket = self.sockets_manager.get_socket_to_publishers()
@ -45,7 +48,7 @@ class DealerPublisherProxy(zmq_dealer_publisher_base.DealerPublisherBase):
self.connection_updater = \
PublisherConnectionUpdater(self.conf, self.matchmaker, self.socket)
def _connect_socket(self, request):
def connect_socket(self, request):
return self.socket
def send_call(self, request):

View File

@ -72,6 +72,12 @@ class PublisherBase(object):
self.sender = sender
self.receiver = receiver
@abc.abstractmethod
def connect_socket(self, request):
"""Get connected socket ready for sending given request
or None otherwise (i.e. if connection can't be established).
"""
@abc.abstractmethod
def send_call(self, request):
pass

View File

@ -0,0 +1,111 @@
# Copyright 2016 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from concurrent import futures
import logging
from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_names
from oslo_messaging._i18n import _LE, _LW
LOG = logging.getLogger(__name__)
zmq = zmq_async.import_zmq()
class AckManagerBase(object):
def __init__(self, publisher):
self.publisher = publisher
self.conf = publisher.conf
self.sender = publisher.sender
self.receiver = publisher.receiver
def send_call(self, request):
return self.publisher.send_call(request)
def send_cast(self, request):
self.publisher.send_cast(request)
def send_fanout(self, request):
self.publisher.send_fanout(request)
def send_notify(self, request):
self.publisher.send_notify(request)
def cleanup(self):
self.publisher.cleanup()
class AckManagerDirect(AckManagerBase):
pass
class AckManagerProxy(AckManagerBase):
def __init__(self, publisher):
super(AckManagerProxy, self).__init__(publisher)
self._pool = zmq_async.get_pool(
size=self.conf.oslo_messaging_zmq.rpc_thread_pool_size
)
def _wait_for_ack(self, ack_future):
request, socket = ack_future.args
retries = \
request.retry or self.conf.oslo_messaging_zmq.rpc_retry_attempts
timeout = self.conf.oslo_messaging_zmq.rpc_ack_timeout_base
done = False
while not done:
try:
reply_id, response = ack_future.result(timeout=timeout)
done = True
assert response is None, "Ack expected!"
assert reply_id == request.routing_key, \
"Ack from recipient expected!"
except AssertionError:
LOG.error(_LE("Message format error in ack for %s"),
request.message_id)
except futures.TimeoutError:
LOG.warning(_LW("No ack received within %(tout)s seconds "
"for %(msg_id)s"),
{"tout": timeout,
"msg_id": request.message_id})
if retries is None or retries != 0:
if retries is not None and retries > 0:
retries -= 1
self.sender.send(socket, request)
timeout *= \
self.conf.oslo_messaging_zmq.rpc_ack_timeout_multiplier
else:
LOG.warning(_LW("Exhausted number of retries for %s"),
request.message_id)
done = True
self.receiver.untrack_request(request)
def _get_ack_future(self, request):
socket = self.publisher.connect_socket(request)
self.receiver.register_socket(socket)
ack_future = self.receiver.track_request(request)[zmq_names.ACK_TYPE]
ack_future.args = request, socket
return ack_future
def send_cast(self, request):
self.publisher.send_cast(request)
self._pool.submit(self._wait_for_ack, self._get_ack_future(request))
def cleanup(self):
self._pool.shutdown(wait=True)
super(AckManagerProxy, self).cleanup()

View File

@ -18,6 +18,7 @@ from oslo_messaging._drivers.zmq_driver.client.publishers.dealer \
import zmq_dealer_publisher_direct
from oslo_messaging._drivers.zmq_driver.client.publishers.dealer \
import zmq_dealer_publisher_proxy
from oslo_messaging._drivers.zmq_driver.client import zmq_ack_manager
from oslo_messaging._drivers.zmq_driver.client import zmq_client_base
from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_names
@ -43,11 +44,17 @@ class ZmqClientMixDirectPubSub(zmq_client_base.ZmqClientBase):
conf.oslo_messaging_zmq.use_pub_sub:
raise WrongClientException()
publisher_direct = \
zmq_dealer_publisher_direct.DealerPublisherDirect(conf, matchmaker)
publisher_direct = self.create_publisher(
conf, matchmaker,
zmq_dealer_publisher_direct.DealerPublisherDirect,
zmq_ack_manager.AckManagerDirect
)
publisher_proxy = \
zmq_dealer_publisher_proxy.DealerPublisherProxy(conf, matchmaker)
publisher_proxy = self.create_publisher(
conf, matchmaker,
zmq_dealer_publisher_proxy.DealerPublisherProxy,
zmq_ack_manager.AckManagerProxy
)
super(ZmqClientMixDirectPubSub, self).__init__(
conf, matchmaker, allowed_remote_exmods,
@ -73,8 +80,11 @@ class ZmqClientDirect(zmq_client_base.ZmqClientBase):
conf.oslo_messaging_zmq.use_router_proxy:
raise WrongClientException()
publisher = \
zmq_dealer_publisher_direct.DealerPublisherDirect(conf, matchmaker)
publisher = self.create_publisher(
conf, matchmaker,
zmq_dealer_publisher_direct.DealerPublisherDirect,
zmq_ack_manager.AckManagerDirect
)
super(ZmqClientDirect, self).__init__(
conf, matchmaker, allowed_remote_exmods,
@ -97,8 +107,11 @@ class ZmqClientProxy(zmq_client_base.ZmqClientBase):
if not conf.oslo_messaging_zmq.use_router_proxy:
raise WrongClientException()
publisher = \
zmq_dealer_publisher_proxy.DealerPublisherProxy(conf, matchmaker)
publisher = self.create_publisher(
conf, matchmaker,
zmq_dealer_publisher_proxy.DealerPublisherProxy,
zmq_ack_manager.AckManagerProxy
)
super(ZmqClientProxy, self).__init__(
conf, matchmaker, allowed_remote_exmods,

View File

@ -37,6 +37,13 @@ class ZmqClientBase(object):
self.notify_publisher = publishers.get(zmq_names.NOTIFY_TYPE,
publishers["default"])
@staticmethod
def create_publisher(conf, matchmaker, publisher_cls, ack_manager_cls):
publisher = publisher_cls(conf, matchmaker)
if conf.oslo_messaging_zmq.rpc_use_acks:
publisher = ack_manager_cls(publisher)
return publisher
def send_call(self, target, context, message, timeout=None, retry=None):
request = zmq_request.CallRequest(
target, context=context, message=message, retry=retry,

View File

@ -42,12 +42,16 @@ class ReceiverBase(object):
@abc.abstractproperty
def message_types(self):
"""A list of supported incoming response types."""
"""A set of supported incoming response types."""
def register_socket(self, socket):
"""Register a socket for receiving data."""
self._poller.register(socket, recv_method=self.recv_response)
def unregister_socket(self, socket):
"""Unregister a socket from receiving data."""
self._poller.unregister(socket)
@abc.abstractmethod
def recv_response(self, socket):
"""Receive a response and return a tuple of the form
@ -56,13 +60,13 @@ class ReceiverBase(object):
def track_request(self, request):
"""Track a request via already registered sockets and return
a list of futures for monitoring all types of responses.
a dict of futures for monitoring all types of responses.
"""
futures = []
futures = {}
for message_type in self.message_types:
future = futurist.Future()
self._set_future(request.message_id, message_type, future)
futures.append(future)
futures[message_type] = future
return futures
def untrack_request(self, request):
@ -102,14 +106,9 @@ class ReceiverBase(object):
future.set_result((reply_id, response))
class AckReceiver(ReceiverBase):
message_types = (zmq_names.ACK_TYPE,)
class ReplyReceiver(ReceiverBase):
message_types = (zmq_names.REPLY_TYPE,)
message_types = {zmq_names.REPLY_TYPE}
class ReplyReceiverProxy(ReplyReceiver):
@ -121,11 +120,12 @@ class ReplyReceiverProxy(ReplyReceiver):
assert reply_id is not None, "Reply ID expected!"
message_type = int(socket.recv())
assert message_type == zmq_names.REPLY_TYPE, "Reply expected!"
message_id = socket.recv()
raw_reply = socket.recv_loaded()
assert isinstance(raw_reply, dict), "Dict expected!"
reply = zmq_response.Response(**raw_reply)
LOG.debug("Received reply for %s", message_id)
message_id = socket.recv_string()
reply_body, failure = socket.recv_loaded()
reply = zmq_response.Reply(
message_id=message_id, reply_id=reply_id,
reply_body=reply_body, failure=failure
)
return reply_id, message_type, message_id, reply
@ -136,11 +136,45 @@ class ReplyReceiverDirect(ReplyReceiver):
assert empty == b'', "Empty expected!"
raw_reply = socket.recv_loaded()
assert isinstance(raw_reply, dict), "Dict expected!"
reply = zmq_response.Response(**raw_reply)
LOG.debug("Received reply for %s", reply.message_id)
reply = zmq_response.Reply(**raw_reply)
return reply.reply_id, reply.msg_type, reply.message_id, reply
class AckAndReplyReceiver(ReceiverBase):
message_types = (zmq_names.ACK_TYPE, zmq_names.REPLY_TYPE)
message_types = {zmq_names.ACK_TYPE, zmq_names.REPLY_TYPE}
class AckAndReplyReceiverProxy(AckAndReplyReceiver):
def recv_response(self, socket):
empty = socket.recv()
assert empty == b'', "Empty expected!"
reply_id = socket.recv()
assert reply_id is not None, "Reply ID expected!"
message_type = int(socket.recv())
assert message_type in (zmq_names.ACK_TYPE, zmq_names.REPLY_TYPE), \
"Ack or reply expected!"
message_id = socket.recv_string()
if message_type == zmq_names.REPLY_TYPE:
reply_body, failure = socket.recv_loaded()
reply = zmq_response.Reply(
message_id=message_id, reply_id=reply_id,
reply_body=reply_body, failure=failure
)
response = reply
else:
response = None
return reply_id, message_type, message_id, response
class AckAndReplyReceiverDirect(AckAndReplyReceiver):
def recv_response(self, socket):
# acks are not supported yet
empty = socket.recv()
assert empty == b'', "Empty expected!"
raw_reply = socket.recv_loaded()
assert isinstance(raw_reply, dict), "Dict expected!"
reply = zmq_response.Reply(**raw_reply)
return reply.reply_id, reply.msg_type, reply.message_id, reply

View File

@ -12,23 +12,24 @@
# License for the specific language governing permissions and limitations
# under the License.
import abc
import six
from oslo_messaging._drivers.zmq_driver import zmq_names
@six.add_metaclass(abc.ABCMeta)
class Response(object):
def __init__(self, msg_type=None, message_id=None,
reply_id=None, reply_body=None, failure=None):
def __init__(self, message_id=None, reply_id=None):
self._msg_type = msg_type
self._message_id = message_id
self._reply_id = reply_id
self._reply_body = reply_body
self._failure = failure
@property
@abc.abstractproperty
def msg_type(self):
return self._msg_type
pass
@property
def message_id(self):
@ -38,6 +39,29 @@ class Response(object):
def reply_id(self):
return self._reply_id
def to_dict(self):
return {zmq_names.FIELD_MSG_ID: self._message_id,
zmq_names.FIELD_REPLY_ID: self._reply_id}
def __str__(self):
return str(self.to_dict())
class Ack(Response):
msg_type = zmq_names.ACK_TYPE
class Reply(Response):
msg_type = zmq_names.REPLY_TYPE
def __init__(self, message_id=None, reply_id=None, reply_body=None,
failure=None):
super(Reply, self).__init__(message_id, reply_id)
self._reply_body = reply_body
self._failure = failure
@property
def reply_body(self):
return self._reply_body
@ -47,11 +71,7 @@ class Response(object):
return self._failure
def to_dict(self):
return {zmq_names.FIELD_MSG_TYPE: self._msg_type,
zmq_names.FIELD_MSG_ID: self._message_id,
zmq_names.FIELD_REPLY_ID: self._reply_id,
zmq_names.FIELD_REPLY_BODY: self._reply_body,
zmq_names.FIELD_FAILURE: self._failure}
def __str__(self):
return str(self.to_dict())
dict_ = super(Reply, self).to_dict()
dict_.update({zmq_names.FIELD_REPLY_BODY: self._reply_body,
zmq_names.FIELD_FAILURE: self._failure})
return dict_

View File

@ -35,12 +35,12 @@ class RoutingTable(object):
def get_all_hosts(self, target):
self._update_routing_table(target)
return list(self.routable_hosts.get(str(target)) or [])
return list(self.routable_hosts.get(str(target), []))
def get_routable_host(self, target):
self._update_routing_table(target)
hosts_for_target = self.routable_hosts[str(target)]
host = hosts_for_target.pop(0)
host = hosts_for_target.pop()
if not hosts_for_target:
self._renew_routable_hosts(target)
return host

View File

@ -41,19 +41,22 @@ class RequestSender(SenderBase):
pass
class AckSender(SenderBase):
pass
class ReplySender(SenderBase):
pass
class RequestSenderProxy(RequestSender):
class RequestSenderProxy(SenderBase):
def send(self, socket, request):
socket.send(b'', zmq.SNDMORE)
socket.send(six.b(str(request.msg_type)), zmq.SNDMORE)
socket.send(six.b(request.routing_key), zmq.SNDMORE)
socket.send(six.b(request.message_id), zmq.SNDMORE)
socket.send_dumped(request.context, zmq.SNDMORE)
socket.send_dumped(request.message)
socket.send(request.routing_key, zmq.SNDMORE)
socket.send_string(request.message_id, zmq.SNDMORE)
socket.send_dumped([request.context, request.message])
LOG.debug("->[proxy:%(addr)s] Sending %(msg_type)s message "
"%(msg_id)s to target %(target)s",
@ -63,28 +66,46 @@ class RequestSenderProxy(RequestSender):
"target": request.target})
class ReplySenderProxy(ReplySender):
class AckSenderProxy(AckSender):
def send(self, socket, ack):
assert ack.msg_type == zmq_names.ACK_TYPE, "Ack expected!"
socket.send(b'', zmq.SNDMORE)
socket.send(six.b(str(ack.msg_type)), zmq.SNDMORE)
socket.send(ack.reply_id, zmq.SNDMORE)
socket.send_string(ack.message_id)
LOG.debug("->[proxy:%(addr)s] Sending %(msg_type)s for %(msg_id)s",
{"addr": list(socket.connections),
"msg_type": zmq_names.message_type_str(ack.msg_type),
"msg_id": ack.message_id})
class ReplySenderProxy(SenderBase):
def send(self, socket, reply):
LOG.debug("Replying to %s", reply.message_id)
assert reply.msg_type == zmq_names.REPLY_TYPE, "Reply expected!"
socket.send(b'', zmq.SNDMORE)
socket.send(six.b(str(reply.msg_type)), zmq.SNDMORE)
socket.send(reply.reply_id, zmq.SNDMORE)
socket.send(reply.message_id, zmq.SNDMORE)
socket.send_dumped(reply.to_dict())
socket.send_string(reply.message_id, zmq.SNDMORE)
socket.send_dumped([reply.reply_body, reply.failure])
LOG.debug("->[proxy:%(addr)s] Sending %(msg_type)s for %(msg_id)s",
{"addr": list(socket.connections),
"msg_type": zmq_names.message_type_str(reply.msg_type),
"msg_id": reply.message_id})
class RequestSenderDirect(RequestSender):
class RequestSenderDirect(SenderBase):
def send(self, socket, request):
socket.send(b'', zmq.SNDMORE)
socket.send(six.b(str(request.msg_type)), zmq.SNDMORE)
socket.send_string(request.message_id, zmq.SNDMORE)
socket.send_dumped(request.context, zmq.SNDMORE)
socket.send_dumped(request.message)
socket.send_dumped([request.context, request.message])
LOG.debug("Sending %(msg_type)s message %(msg_id)s to "
"target %(target)s",
@ -93,13 +114,27 @@ class RequestSenderDirect(RequestSender):
"target": request.target})
class ReplySenderDirect(ReplySender):
class AckSenderDirect(AckSender):
def send(self, socket, ack):
assert ack.msg_type == zmq_names.ACK_TYPE, "Ack expected!"
# not implemented yet
LOG.debug("Sending %(msg_type)s for %(msg_id)s",
{"msg_type": zmq_names.message_type_str(ack.msg_type),
"msg_id": ack.message_id})
class ReplySenderDirect(SenderBase):
def send(self, socket, reply):
LOG.debug("Replying to %s", reply.message_id)
assert reply.msg_type == zmq_names.REPLY_TYPE, "Reply expected!"
socket.send(reply.reply_id, zmq.SNDMORE)
socket.send(b'', zmq.SNDMORE)
socket.send_dumped(reply.to_dict())
LOG.debug("Sending %(msg_type)s for %(msg_id)s",
{"msg_type": zmq_names.message_type_str(reply.msg_type),
"msg_id": reply.message_id})

View File

@ -28,7 +28,7 @@ class MatchMakerBase(object):
self.url = kwargs.get('url')
@abc.abstractmethod
def register_publisher(self, hostname):
def register_publisher(self, hostname, expire=-1):
"""Register publisher on nameserver.
This works for PUB-SUB only
@ -36,6 +36,8 @@ class MatchMakerBase(object):
:param hostname: host for the topic in "host:port" format
host for back-chatter in "host:port" format
:type hostname: tuple
:param expire: record expiration timeout
:type expire: int
"""
@abc.abstractmethod
@ -57,13 +59,15 @@ class MatchMakerBase(object):
"""
@abc.abstractmethod
def register_router(self, hostname):
def register_router(self, hostname, expire=-1):
"""Register router on the nameserver.
This works for ROUTER proxy only
:param hostname: host for the topic in "host:port" format
:type hostname: string
:type hostname: str
:param expire: record expiration timeout
:type expire: int
"""
@abc.abstractmethod
@ -73,7 +77,7 @@ class MatchMakerBase(object):
This works for ROUTER proxy only
:param hostname: host for the topic in "host:port" format
:type hostname: string
:type hostname: str
"""
@abc.abstractmethod
@ -92,10 +96,10 @@ class MatchMakerBase(object):
:param target: the target for host
:type target: Target
:param hostname: host for the topic in "host:port" format
:type hostname: String
:param listener_type: Listener socket type ROUTER, SUB etc.
:type listener_type: String
:param expire: Record expiration timeout
:type hostname: str
:param listener_type: listener socket type ROUTER, SUB etc.
:type listener_type: str
:param expire: record expiration timeout
:type expire: int
"""
@ -106,9 +110,9 @@ class MatchMakerBase(object):
:param target: the target for host
:type target: Target
:param hostname: host for the topic in "host:port" format
:type hostname: String
:param listener_type: Listener socket type ROUTER, SUB etc.
:type listener_type: String
:type hostname: str
:param listener_type: listener socket type ROUTER, SUB etc.
:type listener_type: str
"""
@abc.abstractmethod
@ -117,6 +121,8 @@ class MatchMakerBase(object):
:param target: the default target for invocations
:type target: Target
:param listener_type: listener socket type ROUTER, SUB etc.
:type listener_type: str
:returns: a list of "hostname:port" hosts
"""
@ -130,7 +136,7 @@ class DummyMatchMaker(MatchMakerBase):
self._publishers = set()
self._routers = set()
def register_publisher(self, hostname):
def register_publisher(self, hostname, expire=-1):
if hostname not in self._publishers:
self._publishers.add(hostname)
@ -141,7 +147,7 @@ class DummyMatchMaker(MatchMakerBase):
def get_publishers(self):
return list(self._publishers)
def register_router(self, hostname):
def register_router(self, hostname, expire=-1):
if hostname not in self._routers:
self._routers.add(hostname)

View File

@ -31,6 +31,11 @@ class GreenPoller(zmq_poller.ZmqPoller):
self.thread_by_socket[socket] = self.green_pool.spawn(
self._socket_receive, socket, recv_method)
def unregister(self, socket):
thread = self.thread_by_socket.pop(socket, None)
if thread:
thread.kill()
def _socket_receive(self, socket, recv_method=None):
while True:
if recv_method:

View File

@ -37,6 +37,10 @@ class ThreadingPoller(zmq_poller.ZmqPoller):
self.recv_methods[socket] = recv_method
self.poller.register(socket, zmq.POLLIN)
def unregister(self, socket):
self.recv_methods.pop(socket, None)
self.poller.unregister(socket)
def poll(self, timeout=None):
if timeout is not None and timeout > 0:
timeout *= 1000 # convert seconds to milliseconds

View File

@ -20,10 +20,11 @@ from oslo_messaging._drivers.zmq_driver.client import zmq_sockets_manager
from oslo_messaging._drivers.zmq_driver.server.consumers \
import zmq_consumer_base
from oslo_messaging._drivers.zmq_driver.server import zmq_incoming_message
from oslo_messaging._drivers.zmq_driver.server import zmq_ttl_cache
from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_names
from oslo_messaging._drivers.zmq_driver import zmq_updater
from oslo_messaging._i18n import _LE, _LI
from oslo_messaging._i18n import _LE, _LI, _LW
LOG = logging.getLogger(__name__)
@ -33,7 +34,11 @@ zmq = zmq_async.import_zmq()
class DealerConsumer(zmq_consumer_base.SingleSocketConsumer):
def __init__(self, conf, poller, server):
self.sender = zmq_senders.ReplySenderProxy(conf)
self.ack_sender = zmq_senders.AckSenderProxy(conf)
self.reply_sender = zmq_senders.ReplySenderProxy(conf)
self.received_messages = zmq_ttl_cache.TTLCache(
ttl=conf.oslo_messaging_zmq.rpc_message_ttl
)
self.sockets_manager = zmq_sockets_manager.SocketsManager(
conf, server.matchmaker, zmq.ROUTER, zmq.DEALER)
self.host = None
@ -53,34 +58,63 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer):
LOG.error(_LE("Failed connecting to ROUTER socket %(e)s") % e)
raise rpc_common.RPCException(str(e))
def _receive_request(self, socket):
empty = socket.recv()
assert empty == b'', 'Bad format: empty delimiter expected'
reply_id = socket.recv()
msg_type = int(socket.recv())
message_id = socket.recv_string()
context, message = socket.recv_loaded()
return reply_id, msg_type, message_id, context, message
def receive_message(self, socket):
try:
empty = socket.recv()
assert empty == b'', 'Bad format: empty delimiter expected'
reply_id = socket.recv()
message_type = int(socket.recv())
message_id = socket.recv()
context = socket.recv_loaded()
message = socket.recv_loaded()
LOG.debug("[%(host)s] Received %(msg_type)s message %(msg_id)s",
{"host": self.host,
"msg_type": zmq_names.message_type_str(message_type),
"msg_id": message_id})
if message_type == zmq_names.CALL_TYPE:
return zmq_incoming_message.ZmqIncomingMessage(
context, message, reply_id, message_id, socket, self.sender
reply_id, msg_type, message_id, context, message = \
self._receive_request(socket)
if msg_type == zmq_names.CALL_TYPE or \
msg_type in zmq_names.NON_BLOCKING_TYPES:
ack_sender = self.ack_sender \
if self.conf.oslo_messaging_zmq.rpc_use_acks else None
reply_sender = self.reply_sender \
if msg_type == zmq_names.CALL_TYPE else None
message = zmq_incoming_message.ZmqIncomingMessage(
context, message, reply_id, message_id, socket,
ack_sender, reply_sender
)
elif message_type in zmq_names.NON_BLOCKING_TYPES:
return zmq_incoming_message.ZmqIncomingMessage(context,
message)
# drop duplicate message
if message_id in self.received_messages:
LOG.warning(
_LW("[%(host)s] Dropping duplicate %(msg_type)s "
"message %(msg_id)s"),
{"host": self.host,
"msg_type": zmq_names.message_type_str(msg_type),
"msg_id": message_id}
)
message.acknowledge()
return None
self.received_messages.add(message_id)
LOG.debug(
"[%(host)s] Received %(msg_type)s message %(msg_id)s",
{"host": self.host,
"msg_type": zmq_names.message_type_str(msg_type),
"msg_id": message_id}
)
return message
else:
LOG.error(_LE("Unknown message type: %s"),
zmq_names.message_type_str(message_type))
zmq_names.message_type_str(msg_type))
except (zmq.ZMQError, AssertionError, ValueError) as e:
LOG.error(_LE("Receiving message failure: %s"), str(e))
def cleanup(self):
LOG.info(_LI("[%s] Destroy DEALER consumer"), self.host)
self.received_messages.cleanup()
self.connection_updater.cleanup()
super(DealerConsumer, self).cleanup()

View File

@ -30,7 +30,8 @@ zmq = zmq_async.import_zmq()
class RouterConsumer(zmq_consumer_base.SingleSocketConsumer):
def __init__(self, conf, poller, server):
self.sender = zmq_senders.ReplySenderDirect(conf)
self.ack_sender = zmq_senders.AckSenderDirect(conf)
self.reply_sender = zmq_senders.ReplySenderDirect(conf)
super(RouterConsumer, self).__init__(conf, poller, server, zmq.ROUTER)
LOG.info(_LI("[%s] Run ROUTER consumer"), self.host)
@ -40,26 +41,29 @@ class RouterConsumer(zmq_consumer_base.SingleSocketConsumer):
assert empty == b'', 'Bad format: empty delimiter expected'
msg_type = int(socket.recv())
message_id = socket.recv_string()
context = socket.recv_loaded()
message = socket.recv_loaded()
context, message = socket.recv_loaded()
return reply_id, msg_type, message_id, context, message
def receive_message(self, socket):
try:
reply_id, msg_type, message_id, context, message = \
self._receive_request(socket)
LOG.debug("[%(host)s] Received %(msg_type)s message %(msg_id)s",
{"host": self.host,
"msg_type": zmq_names.message_type_str(msg_type),
"msg_id": message_id})
if msg_type == zmq_names.CALL_TYPE:
if msg_type == zmq_names.CALL_TYPE or \
msg_type in zmq_names.NON_BLOCKING_TYPES:
ack_sender = self.ack_sender \
if self.conf.oslo_messaging_zmq.rpc_use_acks else None
reply_sender = self.reply_sender \
if msg_type == zmq_names.CALL_TYPE else None
return zmq_incoming_message.ZmqIncomingMessage(
context, message, reply_id, message_id, socket, self.sender
context, message, reply_id, message_id, socket,
ack_sender, reply_sender
)
elif msg_type in zmq_names.NON_BLOCKING_TYPES:
return zmq_incoming_message.ZmqIncomingMessage(context,
message)
else:
LOG.error(_LE("Unknown message type: %s"),
zmq_names.message_type_str(msg_type))

View File

@ -63,8 +63,7 @@ class SubConsumer(zmq_consumer_base.ConsumerBase):
def _receive_request(socket):
topic_filter = socket.recv()
message_id = socket.recv()
context = socket.recv_loaded()
message = socket.recv_loaded()
context, message = socket.recv_loaded()
LOG.debug("Received %(topic_filter)s topic message %(id)s",
{'id': message_id, 'topic_filter': topic_filter})
return context, message

View File

@ -12,14 +12,12 @@
# License for the specific language governing permissions and limitations
# under the License.
import logging
from oslo_messaging._drivers import base
from oslo_messaging._drivers import common as rpc_common
from oslo_messaging._drivers.zmq_driver.client import zmq_response
from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_names
LOG = logging.getLogger(__name__)
@ -29,9 +27,9 @@ zmq = zmq_async.import_zmq()
class ZmqIncomingMessage(base.RpcIncomingMessage):
def __init__(self, context, message, reply_id=None, message_id=None,
socket=None, sender=None):
socket=None, ack_sender=None, reply_sender=None):
if sender is not None:
if ack_sender is not None or reply_sender is not None:
assert socket is not None, "Valid socket expected!"
assert message_id is not None, "Valid message ID expected!"
assert reply_id is not None, "Valid reply ID expected!"
@ -41,21 +39,24 @@ class ZmqIncomingMessage(base.RpcIncomingMessage):
self.reply_id = reply_id
self.message_id = message_id
self.socket = socket
self.sender = sender
self.ack_sender = ack_sender
self.reply_sender = reply_sender
def acknowledge(self):
"""Not sending acknowledge"""
if self.ack_sender is not None:
ack = zmq_response.Ack(message_id=self.message_id,
reply_id=self.reply_id)
self.ack_sender.send(self.socket, ack)
def reply(self, reply=None, failure=None):
if self.sender is not None:
if self.reply_sender is not None:
if failure is not None:
failure = rpc_common.serialize_remote_exception(failure)
reply = zmq_response.Response(msg_type=zmq_names.REPLY_TYPE,
message_id=self.message_id,
reply_id=self.reply_id,
reply_body=reply,
failure=failure)
self.sender.send(self.socket, reply)
reply = zmq_response.Reply(message_id=self.message_id,
reply_id=self.reply_id,
reply_body=reply,
failure=failure)
self.reply_sender.send(self.socket, reply)
def requeue(self):
"""Requeue is not supported"""

View File

@ -0,0 +1,79 @@
# Copyright 2016 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import threading
import time
import six
from oslo_messaging._drivers.zmq_driver import zmq_async
zmq = zmq_async.import_zmq()
class TTLCache(object):
def __init__(self, ttl=None):
self._lock = threading.Lock()
self._expiration_times = {}
self._executor = None
if not (ttl is None or isinstance(ttl, (int, float))):
raise ValueError('ttl must be None or a number')
# no (i.e. infinite) ttl
if ttl is None or ttl <= 0:
ttl = float('inf')
else:
self._executor = zmq_async.get_executor(self._update_cache)
self._ttl = ttl
if self._executor:
self._executor.execute()
@staticmethod
def _is_expired(expiration_time, current_time):
return expiration_time <= current_time
def add(self, item):
with self._lock:
self._expiration_times[item] = time.time() + self._ttl
def discard(self, item):
with self._lock:
self._expiration_times.pop(item, None)
def __contains__(self, item):
with self._lock:
expiration_time = self._expiration_times.get(item)
if expiration_time is None:
return False
if self._is_expired(expiration_time, time.time()):
self._expiration_times.pop(item)
return False
return True
def _update_cache(self):
with self._lock:
current_time = time.time()
self._expiration_times = \
{item: expiration_time for
item, expiration_time in six.iteritems(self._expiration_times)
if not self._is_expired(expiration_time, current_time)}
time.sleep(self._ttl)
def cleanup(self):
if self._executor:
self._executor.stop()

View File

@ -42,6 +42,15 @@ def get_executor(method):
return threading_poller.ThreadingExecutor(method)
def get_pool(size):
import futurist
if eventletutils.is_monkey_patched('thread'):
return futurist.GreenThreadPoolExecutor(size)
return futurist.ThreadPoolExecutor(size)
def get_queue():
if eventletutils.is_monkey_patched('thread'):
import eventlet

View File

@ -17,7 +17,6 @@ from oslo_messaging._drivers.zmq_driver import zmq_async
zmq = zmq_async.import_zmq()
FIELD_MSG_TYPE = 'msg_type'
FIELD_MSG_ID = 'message_id'
FIELD_REPLY_ID = 'reply_id'
FIELD_REPLY_BODY = 'reply_body'

View File

@ -113,10 +113,41 @@ zmq_opts = [
'serializing/deserializing outgoing/incoming messages')
]
zmq_ack_retry_opts = [
cfg.IntOpt('rpc_thread_pool_size', default=100,
help='Maximum number of (green) threads to work concurrently.'),
cfg.IntOpt('rpc_message_ttl', default=300,
help='Expiration timeout in seconds of a sent/received message '
'after which it is not tracked anymore by a '
'client/server.'),
cfg.BoolOpt('rpc_use_acks', default=True,
help='Wait for message acknowledgements from receivers. '
'This mechanism works only via proxy without PUB/SUB.'),
cfg.IntOpt('rpc_ack_timeout_base', default=10,
help='Number of seconds to wait for an ack from a cast/call. '
'After each retry attempt this timeout is multiplied by '
'some specified multiplier.'),
cfg.IntOpt('rpc_ack_timeout_multiplier', default=2,
help='Number to multiply base ack timeout by after each retry '
'attempt.'),
cfg.IntOpt('rpc_retry_attempts', default=3,
help='Default number of message sending attempts in case '
'of any problems occurred: positive value N means '
'at most N retries, 0 means no retries, None or -1 '
'(or any other negative values) mean to retry forever. '
'This option is used only if acknowledgments are enabled.')
]
def register_opts(conf):
opt_group = cfg.OptGroup(name='oslo_messaging_zmq',
title='ZeroMQ driver options')
conf.register_opts(zmq_opts, group=opt_group)
conf.register_opts(zmq_ack_retry_opts, group=opt_group)
conf.register_opts(server._pool_opts)
conf.register_opts(base.base_opts)

View File

@ -62,6 +62,13 @@ class ZmqPoller(object):
Should return received message object
:type recv_method: callable
"""
@abc.abstractmethod
def unregister(self, socket):
"""Unregister socket from poll
:param socket: Socket to unsubscribe from polling
:type socket: zmq.Socket
"""
@abc.abstractmethod
def poll(self, timeout=None):

View File

@ -86,8 +86,7 @@ class TestPubSub(zmq_common.ZmqBaseTestCase):
zmq_address.target_to_subscribe_filter(target),
b"message",
b"0000-0000",
self.dumps(context),
self.dumps(message)])
self.dumps([context, message])])
def _check_listener(self, listener):
listener._received.wait(timeout=5)

View File

@ -0,0 +1,185 @@
# Copyright 2016 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
import testtools
import oslo_messaging
from oslo_messaging._drivers.zmq_driver.client import zmq_receivers
from oslo_messaging._drivers.zmq_driver.client import zmq_senders
from oslo_messaging._drivers.zmq_driver.proxy import zmq_proxy
from oslo_messaging._drivers.zmq_driver.proxy import zmq_queue_proxy
from oslo_messaging._drivers.zmq_driver.server import zmq_incoming_message
from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_options
from oslo_messaging.tests.drivers.zmq import zmq_common
from oslo_messaging.tests import utils as test_utils
zmq = zmq_async.import_zmq()
class TestZmqAckManager(test_utils.BaseTestCase):
@testtools.skipIf(zmq is None, "zmq not available")
def setUp(self):
super(TestZmqAckManager, self).setUp()
self.messaging_conf.transport_driver = 'zmq'
zmq_options.register_opts(self.conf)
# set config opts
kwargs = {'rpc_zmq_matchmaker': 'dummy',
'use_pub_sub': False,
'use_router_proxy': True,
'rpc_thread_pool_size': 1,
'rpc_use_acks': True,
'rpc_ack_timeout_base': 3,
'rpc_ack_timeout_multiplier': 1,
'rpc_retry_attempts': 2}
self.config(group='oslo_messaging_zmq', **kwargs)
self.conf.register_opts(zmq_proxy.zmq_proxy_opts,
group='zmq_proxy_opts')
# mock set_result method of futures
self.set_result_patcher = mock.patch.object(
zmq_receivers.futurist.Future, 'set_result',
side_effect=zmq_receivers.futurist.Future.set_result, autospec=True
)
self.set_result = self.set_result_patcher.start()
# mock send method of senders
self.send_patcher = mock.patch.object(
zmq_senders.RequestSenderProxy, 'send',
side_effect=zmq_senders.RequestSenderProxy.send, autospec=True
)
self.send = self.send_patcher.start()
# get driver
transport = oslo_messaging.get_transport(self.conf)
self.driver = transport._driver
# get ack manager
self.ack_manager = self.driver.client.get().publishers['default']
# prepare and launch proxy
self.proxy = zmq_proxy.ZmqProxy(self.conf,
zmq_queue_proxy.UniversalQueueProxy)
vars(self.driver.matchmaker).update(vars(self.proxy.matchmaker))
self.executor = zmq_async.get_executor(self.proxy.run)
self.executor.execute()
# create listener
self.listener = zmq_common.TestServerListener(self.driver)
# create target and message
self.target = oslo_messaging.Target(topic='topic', server='server')
self.message = {'method': 'xyz', 'args': {'x': 1, 'y': 2, 'z': 3}}
self.addCleanup(
zmq_common.StopRpc(
self, [('listener', 'stop'), ('executor', 'stop'),
('proxy', 'close'), ('driver', 'cleanup'),
('send_patcher', 'stop'),
('set_result_patcher', 'stop')]
)
)
@mock.patch.object(
zmq_incoming_message.ZmqIncomingMessage, 'acknowledge',
side_effect=zmq_incoming_message.ZmqIncomingMessage.acknowledge,
autospec=True
)
def test_cast_success_without_retries(self, received_ack_mock):
self.listener.listen(self.target)
result = self.driver.send(
self.target, {}, self.message, wait_for_reply=False
)
self.ack_manager._pool.shutdown(wait=True)
self.assertIsNone(result)
self.assertTrue(self.listener._received.isSet())
self.assertEqual(self.message, self.listener.message.message)
self.assertEqual(1, self.send.call_count)
self.assertEqual(1, received_ack_mock.call_count)
self.assertEqual(2, self.set_result.call_count)
def test_cast_success_with_one_retry(self):
self.listener.listen(self.target)
with mock.patch.object(zmq_incoming_message.ZmqIncomingMessage,
'acknowledge') as lost_ack_mock:
result = self.driver.send(
self.target, {}, self.message, wait_for_reply=False
)
self.listener._received.wait(3)
self.assertIsNone(result)
self.assertTrue(self.listener._received.isSet())
self.assertEqual(self.message, self.listener.message.message)
self.assertEqual(1, self.send.call_count)
self.assertEqual(1, lost_ack_mock.call_count)
self.assertEqual(0, self.set_result.call_count)
with mock.patch.object(
zmq_incoming_message.ZmqIncomingMessage, 'acknowledge',
side_effect=zmq_incoming_message.ZmqIncomingMessage.acknowledge,
autospec=True
) as received_ack_mock:
self.listener._received.clear()
self.ack_manager._pool.shutdown(wait=True)
self.listener._received.wait(3)
self.assertFalse(self.listener._received.isSet())
self.assertEqual(2, self.send.call_count)
self.assertEqual(1, received_ack_mock.call_count)
self.assertEqual(2, self.set_result.call_count)
def test_cast_success_with_two_retries(self):
self.listener.listen(self.target)
with mock.patch.object(zmq_incoming_message.ZmqIncomingMessage,
'acknowledge') as lost_ack_mock:
result = self.driver.send(
self.target, {}, self.message, wait_for_reply=False
)
self.listener._received.wait(3)
self.assertIsNone(result)
self.assertTrue(self.listener._received.isSet())
self.assertEqual(self.message, self.listener.message.message)
self.assertEqual(1, self.send.call_count)
self.assertEqual(1, lost_ack_mock.call_count)
self.assertEqual(0, self.set_result.call_count)
self.listener._received.clear()
self.listener._received.wait(4.5)
self.assertFalse(self.listener._received.isSet())
self.assertEqual(2, self.send.call_count)
self.assertEqual(2, lost_ack_mock.call_count)
self.assertEqual(0, self.set_result.call_count)
with mock.patch.object(
zmq_incoming_message.ZmqIncomingMessage, 'acknowledge',
side_effect=zmq_incoming_message.ZmqIncomingMessage.acknowledge,
autospec=True
) as received_ack_mock:
self.ack_manager._pool.shutdown(wait=True)
self.assertFalse(self.listener._received.isSet())
self.assertEqual(3, self.send.call_count)
self.assertEqual(1, received_ack_mock.call_count)
self.assertEqual(2, self.set_result.call_count)
@mock.patch.object(zmq_incoming_message.ZmqIncomingMessage, 'acknowledge')
def test_cast_failure_exhausted_retries(self, lost_ack_mock):
self.listener.listen(self.target)
result = self.driver.send(
self.target, {}, self.message, wait_for_reply=False
)
self.ack_manager._pool.shutdown(wait=True)
self.assertIsNone(result)
self.assertTrue(self.listener._received.isSet())
self.assertEqual(self.message, self.listener.message.message)
self.assertEqual(3, self.send.call_count)
self.assertEqual(3, lost_ack_mock.call_count)
self.assertEqual(1, self.set_result.call_count)

View File

@ -57,37 +57,16 @@ class TestGetPoller(test_utils.BaseTestCase):
def test_when_eventlet_is_available_then_return_GreenPoller(self):
zmq_async.eventletutils.is_monkey_patched = lambda _: True
actual = zmq_async.get_poller()
poller = zmq_async.get_poller()
self.assertTrue(isinstance(actual, green_poller.GreenPoller))
self.assertTrue(isinstance(poller, green_poller.GreenPoller))
def test_when_eventlet_is_unavailable_then_return_ThreadingPoller(self):
zmq_async.eventletutils.is_monkey_patched = lambda _: False
actual = zmq_async.get_poller()
poller = zmq_async.get_poller()
self.assertTrue(isinstance(actual, threading_poller.ThreadingPoller))
class TestGetReplyPoller(test_utils.BaseTestCase):
@testtools.skipIf(zmq is None, "zmq not available")
def setUp(self):
super(TestGetReplyPoller, self).setUp()
def test_when_eventlet_is_available_then_return_HoldReplyPoller(self):
zmq_async.eventletutils.is_monkey_patched = lambda _: True
actual = zmq_async.get_poller()
self.assertTrue(isinstance(actual, green_poller.GreenPoller))
def test_when_eventlet_is_unavailable_then_return_ThreadingPoller(self):
zmq_async.eventletutils.is_monkey_patched = lambda _: False
actual = zmq_async.get_poller()
self.assertTrue(isinstance(actual, threading_poller.ThreadingPoller))
self.assertTrue(isinstance(poller, threading_poller.ThreadingPoller))
class TestGetExecutor(test_utils.BaseTestCase):

View File

@ -0,0 +1,116 @@
# Copyright 2016 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import time
from oslo_messaging._drivers.zmq_driver.server import zmq_ttl_cache
from oslo_messaging.tests import utils as test_utils
class TestZmqTTLCache(test_utils.BaseTestCase):
def setUp(self):
super(TestZmqTTLCache, self).setUp()
def call_count_decorator(unbound_method):
def wrapper(self, *args, **kwargs):
wrapper.call_count += 1
return unbound_method(self, *args, **kwargs)
wrapper.call_count = 0
return wrapper
zmq_ttl_cache.TTLCache._update_cache = \
call_count_decorator(zmq_ttl_cache.TTLCache._update_cache)
self.cache = zmq_ttl_cache.TTLCache(ttl=1)
def _test_in_operator(self):
self.cache.add(1)
self.assertTrue(1 in self.cache)
time.sleep(0.5)
self.cache.add(2)
self.assertTrue(1 in self.cache)
self.assertTrue(2 in self.cache)
time.sleep(0.75)
self.cache.add(3)
self.assertFalse(1 in self.cache)
self.assertTrue(2 in self.cache)
self.assertTrue(3 in self.cache)
time.sleep(0.5)
self.assertFalse(2 in self.cache)
self.assertTrue(3 in self.cache)
def test_in_operator_with_executor(self):
self._test_in_operator()
def test_in_operator_without_executor(self):
self.cache._executor.stop()
self._test_in_operator()
def _is_expired(self, item):
with self.cache._lock:
return self.cache._is_expired(self.cache._expiration_times[item],
time.time())
def test_executor(self):
self.cache.add(1)
self.assertEqual([1], sorted(self.cache._expiration_times.keys()))
self.assertFalse(self._is_expired(1))
time.sleep(0.75)
self.assertEqual(1, self.cache._update_cache.call_count)
self.cache.add(2)
self.assertEqual([1, 2], sorted(self.cache._expiration_times.keys()))
self.assertFalse(self._is_expired(1))
self.assertFalse(self._is_expired(2))
time.sleep(0.75)
self.assertEqual(2, self.cache._update_cache.call_count)
self.cache.add(3)
if 1 in self.cache:
self.assertEqual([1, 2, 3],
sorted(self.cache._expiration_times.keys()))
self.assertTrue(self._is_expired(1))
else:
self.assertEqual([2, 3],
sorted(self.cache._expiration_times.keys()))
self.assertFalse(self._is_expired(2))
self.assertFalse(self._is_expired(3))
time.sleep(0.75)
self.assertEqual(3, self.cache._update_cache.call_count)
self.assertEqual([3], sorted(self.cache._expiration_times.keys()))
self.assertFalse(self._is_expired(3))
def cleanUp(self):
self.cache.cleanup()
super(TestZmqTTLCache, self).cleanUp()

View File

@ -91,15 +91,20 @@ class ZmqBaseTestCase(test_utils.BaseTestCase):
self.listener = TestServerListener(self.driver)
self.addCleanup(StopRpc(self.__dict__))
self.addCleanup(
StopRpc(self, [('listener', 'stop'), ('driver', 'cleanup')])
)
class StopRpc(object):
def __init__(self, attrs):
self.attrs = attrs
def __init__(self, obj, attrs_and_stops):
self.obj = obj
self.attrs_and_stops = attrs_and_stops
def __call__(self):
if self.attrs['driver']:
self.attrs['driver'].cleanup()
if self.attrs['listener']:
self.attrs['listener'].stop()
for attr, stop in self.attrs_and_stops:
if hasattr(self.obj, attr):
obj_attr = getattr(self.obj, attr)
if hasattr(obj_attr, stop):
obj_attr_stop = getattr(obj_attr, stop)
obj_attr_stop()