diff --git a/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher_base.py b/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher_base.py index 3c232e365..6b8a70bd9 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher_base.py +++ b/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher_base.py @@ -12,11 +12,9 @@ # License for the specific language governing permissions and limitations # under the License. -import abc from concurrent import futures import logging -import oslo_messaging from oslo_messaging._drivers import common as rpc_common from oslo_messaging._drivers.zmq_driver.client.publishers \ import zmq_publisher_base @@ -41,27 +39,17 @@ class DealerPublisherBase(zmq_publisher_base.PublisherBase): super(DealerPublisherBase, self).__init__(sockets_manager, sender, receiver) - @staticmethod - def _check_pattern(request, supported_pattern): - if request.msg_type != supported_pattern: - raise zmq_publisher_base.UnsupportedSendPattern( - zmq_names.message_type_str(request.msg_type) - ) + def _check_received_data(self, reply_id, reply, request): + assert isinstance(reply, zmq_response.Reply), "Reply expected!" - @staticmethod - def _raise_timeout(request): - raise oslo_messaging.MessagingTimeout( - "Timeout %(tout)s seconds was reached for message %(msg_id)s" % - {"tout": request.timeout, "msg_id": request.message_id} - ) - - def _recv_reply(self, request): + def _recv_reply(self, request, socket): + self.receiver.register_socket(socket) reply_future = \ self.receiver.track_request(request)[zmq_names.REPLY_TYPE] try: - _, reply = reply_future.result(timeout=request.timeout) - assert isinstance(reply, zmq_response.Reply), "Reply expected!" + reply_id, reply = reply_future.result(timeout=request.timeout) + self._check_received_data(reply_id, reply, request) except AssertionError: LOG.error(_LE("Message format error in reply for %s"), request.message_id) @@ -77,30 +65,3 @@ class DealerPublisherBase(zmq_publisher_base.PublisherBase): ) else: return reply.reply_body - - def send_call(self, request): - self._check_pattern(request, zmq_names.CALL_TYPE) - - socket = self.connect_socket(request) - if not socket: - self._raise_timeout(request) - - self.sender.send(socket, request) - self.receiver.register_socket(socket) - return self._recv_reply(request) - - @abc.abstractmethod - def _send_non_blocking(self, request): - pass - - def send_cast(self, request): - self._check_pattern(request, zmq_names.CAST_TYPE) - self._send_non_blocking(request) - - def send_fanout(self, request): - self._check_pattern(request, zmq_names.CAST_FANOUT_TYPE) - self._send_non_blocking(request) - - def send_notify(self, request): - self._check_pattern(request, zmq_names.NOTIFY_TYPE) - self._send_non_blocking(request) diff --git a/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher_direct.py b/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher_direct.py index f6d30401c..9356ac623 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher_direct.py +++ b/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher_direct.py @@ -40,19 +40,19 @@ class DealerPublisherDirect(zmq_dealer_publisher_base.DealerPublisherBase): super(DealerPublisherDirect, self).__init__(conf, matchmaker, sender, receiver) - def connect_socket(self, request): + def _connect_socket(self, request): try: return self.sockets_manager.get_socket(request.target) except retrying.RetryError: return None - def _send_non_blocking(self, request): - socket = self.connect_socket(request) + def _send_request(self, request): + socket = self._connect_socket(request) if not socket: - return - + return None if request.msg_type in zmq_names.MULTISEND_TYPES: for _ in range(socket.connections_count()): self.sender.send(socket, request) else: self.sender.send(socket, request) + return socket diff --git a/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher_proxy.py b/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher_proxy.py index 1c29e5d7f..dce8c2ff0 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher_proxy.py +++ b/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher_proxy.py @@ -55,15 +55,11 @@ class DealerPublisherProxy(zmq_dealer_publisher_base.DealerPublisherBase): return six.b(self.conf.oslo_messaging_zmq.rpc_zmq_host + "/" + str(uuid.uuid4())) - def connect_socket(self, request): - return self.socket - - def send_call(self, request): - request.routing_key = \ - self.routing_table.get_routable_host(request.target) - if request.routing_key is None: - self._raise_timeout(request) - return super(DealerPublisherProxy, self).send_call(request) + def _check_received_data(self, reply_id, reply, request): + super(DealerPublisherProxy, self)._check_received_data(reply_id, reply, + request) + assert reply_id == request.routing_key, \ + "Reply from recipient expected!" def _get_routing_keys(self, request): if request.msg_type in zmq_names.DIRECT_TYPES: @@ -74,16 +70,20 @@ class DealerPublisherProxy(zmq_dealer_publisher_base.DealerPublisherBase): if self.conf.oslo_messaging_zmq.use_pub_sub else \ self.routing_table.get_all_hosts(request.target) - def _send_non_blocking(self, request): - for routing_key in self._get_routing_keys(request): - if routing_key is None: - LOG.warning(_LW("Matchmaker contains no record for specified " - "target %(target)s. Dropping message %(id)s.") - % {"target": request.target, - "id": request.message_id}) - continue + def _send_request(self, request): + routing_keys = [routing_key + for routing_key in self._get_routing_keys(request) + if routing_key is not None] + if not routing_keys: + LOG.warning(_LW("Matchmaker contains no records for specified " + "target %(target)s. Dropping message %(msg_id)s.") + % {"target": request.target, + "msg_id": request.message_id}) + return None + for routing_key in routing_keys: request.routing_key = routing_key self.sender.send(self.socket, request) + return self.socket def cleanup(self): super(DealerPublisherProxy, self).cleanup() diff --git a/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_publisher_base.py b/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_publisher_base.py index 8c6c100ed..c7c4cc8d4 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_publisher_base.py +++ b/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_publisher_base.py @@ -17,8 +17,10 @@ import logging import six +import oslo_messaging from oslo_messaging._drivers import common as rpc_common from oslo_messaging._drivers.zmq_driver import zmq_async +from oslo_messaging._drivers.zmq_driver import zmq_names from oslo_messaging._i18n import _LE LOG = logging.getLogger(__name__) @@ -72,27 +74,48 @@ class PublisherBase(object): self.sender = sender self.receiver = receiver + @staticmethod + def _check_message_pattern(expected, actual): + if expected != actual: + raise UnsupportedSendPattern(zmq_names.message_type_str(actual)) + + @staticmethod + def _raise_timeout(request): + raise oslo_messaging.MessagingTimeout( + "Timeout %(tout)s seconds was reached for message %(msg_id)s" % + {"tout": request.timeout, "msg_id": request.message_id} + ) + @abc.abstractmethod - def connect_socket(self, request): - """Get connected socket ready for sending given request - or None otherwise (i.e. if connection can't be established). + def _send_request(self, request): + """Send the request and return a socket used for that. + Return value of None means some failure (e.g. connection + can't be established, etc). """ @abc.abstractmethod + def _recv_reply(self, request, socket): + """Wait for a reply via the socket used for sending the request.""" + def send_call(self, request): - pass + self._check_message_pattern(zmq_names.CALL_TYPE, request.msg_type) + socket = self._send_request(request) + if not socket: + raise self._raise_timeout(request) + return self._recv_reply(request, socket) - @abc.abstractmethod def send_cast(self, request): - pass + self._check_message_pattern(zmq_names.CAST_TYPE, request.msg_type) + self._send_request(request) - @abc.abstractmethod def send_fanout(self, request): - pass + self._check_message_pattern(zmq_names.CAST_FANOUT_TYPE, + request.msg_type) + self._send_request(request) - @abc.abstractmethod def send_notify(self, request): - pass + self._check_message_pattern(zmq_names.NOTIFY_TYPE, request.msg_type) + self._send_request(request) def cleanup(self): """Cleanup publisher. Close allocated connections.""" diff --git a/oslo_messaging/_drivers/zmq_driver/client/zmq_ack_manager.py b/oslo_messaging/_drivers/zmq_driver/client/zmq_ack_manager.py index 01bbc35a9..e1d9e8897 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/zmq_ack_manager.py +++ b/oslo_messaging/_drivers/zmq_driver/client/zmq_ack_manager.py @@ -61,9 +61,11 @@ class AckManagerProxy(AckManagerBase): ) def _wait_for_ack(self, ack_future): - request, socket = ack_future.args + request = ack_future.request retries = \ request.retry or self.conf.oslo_messaging_zmq.rpc_retry_attempts + if retries is None: + retries = -1 timeout = self.conf.oslo_messaging_zmq.rpc_ack_timeout_base done = False @@ -72,8 +74,9 @@ class AckManagerProxy(AckManagerBase): reply_id, response = ack_future.result(timeout=timeout) done = True assert response is None, "Ack expected!" - assert reply_id == request.routing_key, \ - "Ack from recipient expected!" + if reply_id is not None: + assert reply_id == request.routing_key, \ + "Ack from recipient expected!" except AssertionError: LOG.error(_LE("Message format error in ack for %s"), request.message_id) @@ -82,10 +85,10 @@ class AckManagerProxy(AckManagerBase): "for %(msg_id)s"), {"tout": timeout, "msg_id": request.message_id}) - if retries is None or retries != 0: - if retries is not None and retries > 0: + if retries != 0: + if retries > 0: retries -= 1 - self.sender.send(socket, request) + self.sender.send(ack_future.socket, request) timeout *= \ self.conf.oslo_messaging_zmq.rpc_ack_timeout_multiplier else: @@ -93,18 +96,35 @@ class AckManagerProxy(AckManagerBase): request.message_id) done = True - self.receiver.untrack_request(request) + if request.msg_type != zmq_names.CALL_TYPE: + self.receiver.untrack_request(request) - def _get_ack_future(self, request): - socket = self.publisher.connect_socket(request) + def _send_request_and_get_ack_future(self, request): + socket = self.publisher._send_request(request) + if not socket: + return None self.receiver.register_socket(socket) ack_future = self.receiver.track_request(request)[zmq_names.ACK_TYPE] - ack_future.args = request, socket + ack_future.request = request + ack_future.socket = socket return ack_future + def send_call(self, request): + ack_future = self._send_request_and_get_ack_future(request) + if not ack_future: + self.publisher._raise_timeout(request) + self._pool.submit(self._wait_for_ack, ack_future) + try: + return self.publisher._recv_reply(request, ack_future.socket) + finally: + if not ack_future.done(): + ack_future.set_result((None, None)) + def send_cast(self, request): - self.publisher.send_cast(request) - self._pool.submit(self._wait_for_ack, self._get_ack_future(request)) + ack_future = self._send_request_and_get_ack_future(request) + if not ack_future: + return + self._pool.submit(self._wait_for_ack, ack_future) def cleanup(self): self._pool.shutdown(wait=True) diff --git a/oslo_messaging/_drivers/zmq_driver/client/zmq_receivers.py b/oslo_messaging/_drivers/zmq_driver/client/zmq_receivers.py index 9f7aeecf7..40c824e53 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/zmq_receivers.py +++ b/oslo_messaging/_drivers/zmq_driver/client/zmq_receivers.py @@ -63,9 +63,12 @@ class ReceiverBase(object): a dict of futures for monitoring all types of responses. """ futures = {} + message_id = request.message_id for message_type in self.message_types: - future = futurist.Future() - self._set_future(request.message_id, message_type, future) + future = self._get_future(message_id, message_type) + if future is None: + future = futurist.Future() + self._set_future(message_id, message_type, future) futures[message_type] = future return futures @@ -92,7 +95,8 @@ class ReceiverBase(object): def _run_loop(self): data, socket = self._poller.poll( - timeout=self.conf.oslo_messaging_zmq.rpc_poll_timeout) + timeout=self.conf.oslo_messaging_zmq.rpc_poll_timeout + ) if data is None: return reply_id, message_type, message_id, response = data diff --git a/oslo_messaging/_drivers/zmq_driver/client/zmq_routing_table.py b/oslo_messaging/_drivers/zmq_driver/client/zmq_routing_table.py index d6f9b94ee..569826dd1 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/zmq_routing_table.py +++ b/oslo_messaging/_drivers/zmq_driver/client/zmq_routing_table.py @@ -76,9 +76,9 @@ class RoutingTable(object): key = str(target) if key not in self.routing_table: try: - self.routing_table[key] = (get_hosts_retry( - target, zmq_names.socket_type_str(zmq.DEALER)), - time.time()) + hosts = get_hosts_retry( + target, zmq_names.socket_type_str(zmq.DEALER)) + self.routing_table[key] = (hosts, time.time()) except retrying.RetryError: LOG.warning(_LW("Matchmaker contains no hosts for target %s") % key) diff --git a/oslo_messaging/_drivers/zmq_driver/matchmaker/zmq_matchmaker_redis.py b/oslo_messaging/_drivers/zmq_driver/matchmaker/zmq_matchmaker_redis.py index ff2036800..22ad912fc 100644 --- a/oslo_messaging/_drivers/zmq_driver/matchmaker/zmq_matchmaker_redis.py +++ b/oslo_messaging/_drivers/zmq_driver/matchmaker/zmq_matchmaker_redis.py @@ -239,8 +239,6 @@ class MatchmakerRedis(zmq_matchmaker_base.MatchmakerBase): @redis_connection_warn def get_hosts_fanout(self, target, listener_type): - LOG.debug("[Redis] get_hosts for target %s", target) - hosts = [] if target.topic and target.server: @@ -250,18 +248,19 @@ class MatchmakerRedis(zmq_matchmaker_base.MatchmakerBase): key = zmq_address.prefix_str(target.topic, listener_type) hosts.extend(self._get_hosts_by_key(key)) + LOG.debug("[Redis] get_hosts_fanout for target %(target)s: %(hosts)s", + {"target": target, "hosts": hosts}) + return hosts def get_hosts_fanout_retry(self, target, listener_type): return self._retry_method(target, listener_type, self.get_hosts_fanout) def _retry_method(self, target, listener_type, method): - conf = self.conf - @retry(retry_on_result=retry_if_empty, wrap_exception=True, - wait_fixed=conf.matchmaker_redis.wait_timeout, - stop_max_delay=conf.matchmaker_redis.check_timeout) + wait_fixed=self.conf.matchmaker_redis.wait_timeout, + stop_max_delay=self.conf.matchmaker_redis.check_timeout) def _get_hosts_retry(target, listener_type): return method(target, listener_type) return _get_hosts_retry(target, listener_type) diff --git a/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_dealer_consumer.py b/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_dealer_consumer.py index 3e3e1107d..0ec03ceb3 100644 --- a/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_dealer_consumer.py +++ b/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_dealer_consumer.py @@ -40,7 +40,7 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer): def __init__(self, conf, poller, server): self.ack_sender = zmq_senders.AckSenderProxy(conf) self.reply_sender = zmq_senders.ReplySenderProxy(conf) - self.received_messages = zmq_ttl_cache.TTLCache( + self.messages_cache = zmq_ttl_cache.TTLCache( ttl=conf.oslo_messaging_zmq.rpc_message_ttl ) self.sockets_manager = zmq_sockets_manager.SocketsManager( @@ -92,11 +92,11 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer): message = zmq_incoming_message.ZmqIncomingMessage( context, message, reply_id, message_id, socket, - ack_sender, reply_sender + ack_sender, reply_sender, self.messages_cache ) - # drop duplicate message - if message_id in self.received_messages: + # drop a duplicate message + if message_id in self.messages_cache: LOG.warning( _LW("[%(host)s] Dropping duplicate %(msg_type)s " "message %(msg_id)s"), @@ -104,10 +104,16 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer): "msg_type": zmq_names.message_type_str(msg_type), "msg_id": message_id} ) + # NOTE(gdavoian): send yet another ack for the non-CALL + # message, since the old one might be lost; + # for the CALL message also try to resend its reply + # (of course, if it was already obtained and cached). message.acknowledge() + if msg_type == zmq_names.CALL_TYPE: + message.reply_from_cache() return None - self.received_messages.add(message_id) + self.messages_cache.add(message_id) LOG.debug( "[%(host)s] Received %(msg_type)s message %(msg_id)s", {"host": self.host, @@ -124,7 +130,7 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer): def cleanup(self): LOG.info(_LI("[%s] Destroy DEALER consumer"), self.host) - self.received_messages.cleanup() + self.messages_cache.cleanup() self.connection_updater.cleanup() super(DealerConsumer, self).cleanup() diff --git a/oslo_messaging/_drivers/zmq_driver/server/zmq_incoming_message.py b/oslo_messaging/_drivers/zmq_driver/server/zmq_incoming_message.py index d6ab57328..493e50940 100644 --- a/oslo_messaging/_drivers/zmq_driver/server/zmq_incoming_message.py +++ b/oslo_messaging/_drivers/zmq_driver/server/zmq_incoming_message.py @@ -25,9 +25,14 @@ zmq = zmq_async.import_zmq() class ZmqIncomingMessage(base.RpcIncomingMessage): + """Base class for RPC-messages via ZMQ-driver. + Each message may send either acks/replies or just nothing + (if acks are disabled and replies are not supported). + """ def __init__(self, context, message, reply_id=None, message_id=None, - socket=None, ack_sender=None, reply_sender=None): + socket=None, ack_sender=None, reply_sender=None, + replies_cache=None): if ack_sender is not None or reply_sender is not None: assert socket is not None, "Valid socket expected!" @@ -41,6 +46,7 @@ class ZmqIncomingMessage(base.RpcIncomingMessage): self.socket = socket self.ack_sender = ack_sender self.reply_sender = reply_sender + self.replies_cache = replies_cache def acknowledge(self): if self.ack_sender is not None: @@ -57,6 +63,14 @@ class ZmqIncomingMessage(base.RpcIncomingMessage): reply_body=reply, failure=failure) self.reply_sender.send(self.socket, reply) + if self.replies_cache is not None: + self.replies_cache.add(self.message_id, reply) + + def reply_from_cache(self): + if self.reply_sender is not None and self.replies_cache is not None: + reply = self.replies_cache.get(self.message_id) + if reply is not None: + self.reply_sender.send(self.socket, reply) def requeue(self): """Requeue is not supported""" diff --git a/oslo_messaging/_drivers/zmq_driver/server/zmq_ttl_cache.py b/oslo_messaging/_drivers/zmq_driver/server/zmq_ttl_cache.py index 963d2d912..49edfbc37 100644 --- a/oslo_messaging/_drivers/zmq_driver/server/zmq_ttl_cache.py +++ b/oslo_messaging/_drivers/zmq_driver/server/zmq_ttl_cache.py @@ -24,9 +24,11 @@ zmq = zmq_async.import_zmq() class TTLCache(object): + _UNDEFINED = object() + def __init__(self, ttl=None): self._lock = threading.Lock() - self._expiration_times = {} + self._cache = {} self._executor = None if not (ttl is None or isinstance(ttl, (int, float))): @@ -47,30 +49,31 @@ class TTLCache(object): def _is_expired(expiration_time, current_time): return expiration_time <= current_time - def add(self, item): + def add(self, key, value=None): with self._lock: - self._expiration_times[item] = time.time() + self._ttl + expiration_time = time.time() + self._ttl + self._cache[key] = (value, expiration_time) - def discard(self, item): + def get(self, key, default=None): with self._lock: - self._expiration_times.pop(item, None) - - def __contains__(self, item): - with self._lock: - expiration_time = self._expiration_times.get(item) - if expiration_time is None: - return False + data = self._cache.get(key) + if data is None: + return default + value, expiration_time = data if self._is_expired(expiration_time, time.time()): - self._expiration_times.pop(item) - return False - return True + del self._cache[key] + return default + return value + + def __contains__(self, key): + return self.get(key, self._UNDEFINED) is not self._UNDEFINED def _update_cache(self): with self._lock: current_time = time.time() - self._expiration_times = \ - {item: expiration_time for - item, expiration_time in six.iteritems(self._expiration_times) + self._cache = \ + {key: (value, expiration_time) for + key, (value, expiration_time) in six.iteritems(self._cache) if not self._is_expired(expiration_time, current_time)} time.sleep(self._ttl) diff --git a/oslo_messaging/_drivers/zmq_driver/zmq_options.py b/oslo_messaging/_drivers/zmq_driver/zmq_options.py index d59c9b0cd..f7150b156 100644 --- a/oslo_messaging/_drivers/zmq_driver/zmq_options.py +++ b/oslo_messaging/_drivers/zmq_driver/zmq_options.py @@ -117,10 +117,8 @@ zmq_opts = [ 'True means not keeping a queue when server side ' 'disconnects. False means to keep queue and messages ' 'even if server is disconnected, when the server ' - 'appears we send all accumulated messages to it.') -] + 'appears we send all accumulated messages to it.'), -zmq_ack_retry_opts = [ cfg.IntOpt('rpc_thread_pool_size', default=100, help='Maximum number of (green) threads to work concurrently.'), @@ -133,7 +131,7 @@ zmq_ack_retry_opts = [ help='Wait for message acknowledgements from receivers. ' 'This mechanism works only via proxy without PUB/SUB.'), - cfg.IntOpt('rpc_ack_timeout_base', default=10, + cfg.IntOpt('rpc_ack_timeout_base', default=15, help='Number of seconds to wait for an ack from a cast/call. ' 'After each retry attempt this timeout is multiplied by ' 'some specified multiplier.'), @@ -155,6 +153,5 @@ def register_opts(conf): opt_group = cfg.OptGroup(name='oslo_messaging_zmq', title='ZeroMQ driver options') conf.register_opts(zmq_opts, group=opt_group) - conf.register_opts(zmq_ack_retry_opts, group=opt_group) conf.register_opts(server._pool_opts) conf.register_opts(base.base_opts) diff --git a/oslo_messaging/tests/drivers/zmq/test_zmq_ack_manager.py b/oslo_messaging/tests/drivers/zmq/test_zmq_ack_manager.py index 05a230198..744b9ba2b 100644 --- a/oslo_messaging/tests/drivers/zmq/test_zmq_ack_manager.py +++ b/oslo_messaging/tests/drivers/zmq/test_zmq_ack_manager.py @@ -43,7 +43,7 @@ class TestZmqAckManager(test_utils.BaseTestCase): 'use_router_proxy': True, 'rpc_thread_pool_size': 1, 'rpc_use_acks': True, - 'rpc_ack_timeout_base': 3, + 'rpc_ack_timeout_base': 5, 'rpc_ack_timeout_multiplier': 1, 'rpc_retry_attempts': 2} self.config(group='oslo_messaging_zmq', **kwargs) @@ -104,8 +104,8 @@ class TestZmqAckManager(test_utils.BaseTestCase): result = self.driver.send( self.target, {}, self.message, wait_for_reply=False ) - self.ack_manager._pool.shutdown(wait=True) self.assertIsNone(result) + self.ack_manager._pool.shutdown(wait=True) self.assertTrue(self.listener._received.isSet()) self.assertEqual(self.message, self.listener.message.message) self.assertEqual(1, self.send.call_count) @@ -119,21 +119,20 @@ class TestZmqAckManager(test_utils.BaseTestCase): result = self.driver.send( self.target, {}, self.message, wait_for_reply=False ) - self.listener._received.wait(3) self.assertIsNone(result) + self.listener._received.wait(5) self.assertTrue(self.listener._received.isSet()) self.assertEqual(self.message, self.listener.message.message) self.assertEqual(1, self.send.call_count) self.assertEqual(1, lost_ack_mock.call_count) self.assertEqual(0, self.set_result.call_count) + self.listener._received.clear() with mock.patch.object( zmq_incoming_message.ZmqIncomingMessage, 'acknowledge', side_effect=zmq_incoming_message.ZmqIncomingMessage.acknowledge, autospec=True ) as received_ack_mock: - self.listener._received.clear() self.ack_manager._pool.shutdown(wait=True) - self.listener._received.wait(3) self.assertFalse(self.listener._received.isSet()) self.assertEqual(2, self.send.call_count) self.assertEqual(1, received_ack_mock.call_count) @@ -146,15 +145,15 @@ class TestZmqAckManager(test_utils.BaseTestCase): result = self.driver.send( self.target, {}, self.message, wait_for_reply=False ) - self.listener._received.wait(3) self.assertIsNone(result) + self.listener._received.wait(5) self.assertTrue(self.listener._received.isSet()) self.assertEqual(self.message, self.listener.message.message) self.assertEqual(1, self.send.call_count) self.assertEqual(1, lost_ack_mock.call_count) self.assertEqual(0, self.set_result.call_count) self.listener._received.clear() - self.listener._received.wait(4.5) + self.listener._received.wait(7.5) self.assertFalse(self.listener._received.isSet()) self.assertEqual(2, self.send.call_count) self.assertEqual(2, lost_ack_mock.call_count) @@ -176,10 +175,53 @@ class TestZmqAckManager(test_utils.BaseTestCase): result = self.driver.send( self.target, {}, self.message, wait_for_reply=False ) - self.ack_manager._pool.shutdown(wait=True) self.assertIsNone(result) + self.ack_manager._pool.shutdown(wait=True) self.assertTrue(self.listener._received.isSet()) self.assertEqual(self.message, self.listener.message.message) self.assertEqual(3, self.send.call_count) self.assertEqual(3, lost_ack_mock.call_count) self.assertEqual(1, self.set_result.call_count) + + @mock.patch.object( + zmq_incoming_message.ZmqIncomingMessage, 'acknowledge', + side_effect=zmq_incoming_message.ZmqIncomingMessage.acknowledge, + autospec=True + ) + @mock.patch.object( + zmq_incoming_message.ZmqIncomingMessage, 'reply', + side_effect=zmq_incoming_message.ZmqIncomingMessage.reply, + autospec=True + ) + def test_call_success_without_retries(self, received_reply_mock, + received_ack_mock): + self.listener.listen(self.target) + result = self.driver.send( + self.target, {}, self.message, wait_for_reply=True, timeout=10 + ) + self.assertIsNotNone(result) + self.ack_manager._pool.shutdown(wait=True) + self.assertTrue(self.listener._received.isSet()) + self.assertEqual(self.message, self.listener.message.message) + self.assertEqual(1, self.send.call_count) + self.assertEqual(1, received_ack_mock.call_count) + self.assertEqual(3, self.set_result.call_count) + received_reply_mock.assert_called_once_with(mock.ANY, reply=True) + + @mock.patch.object(zmq_incoming_message.ZmqIncomingMessage, 'acknowledge') + @mock.patch.object(zmq_incoming_message.ZmqIncomingMessage, 'reply') + def test_call_failure_exhausted_retries_and_timeout_error(self, + lost_reply_mock, + lost_ack_mock): + self.listener.listen(self.target) + self.assertRaises(oslo_messaging.MessagingTimeout, + self.driver.send, + self.target, {}, self.message, + wait_for_reply=True, timeout=20) + self.ack_manager._pool.shutdown(wait=True) + self.assertTrue(self.listener._received.isSet()) + self.assertEqual(self.message, self.listener.message.message) + self.assertEqual(3, self.send.call_count) + self.assertEqual(3, lost_ack_mock.call_count) + self.assertEqual(2, self.set_result.call_count) + lost_reply_mock.assert_called_once_with(reply=True) diff --git a/oslo_messaging/tests/drivers/zmq/test_zmq_ttl_cache.py b/oslo_messaging/tests/drivers/zmq/test_zmq_ttl_cache.py index fa2e2408e..60a5af0ac 100644 --- a/oslo_messaging/tests/drivers/zmq/test_zmq_ttl_cache.py +++ b/oslo_messaging/tests/drivers/zmq/test_zmq_ttl_cache.py @@ -35,6 +35,28 @@ class TestZmqTTLCache(test_utils.BaseTestCase): self.cache = zmq_ttl_cache.TTLCache(ttl=1) + self.addCleanup(lambda: self.cache.cleanup()) + + def _test_add_get(self): + self.cache.add('x', 'a') + + self.assertEqual(self.cache.get('x'), 'a') + self.assertEqual(self.cache.get('x', 'b'), 'a') + self.assertEqual(self.cache.get('y'), None) + self.assertEqual(self.cache.get('y', 'b'), 'b') + + time.sleep(1) + + self.assertEqual(self.cache.get('x'), None) + self.assertEqual(self.cache.get('x', 'b'), 'b') + + def test_add_get_with_executor(self): + self._test_add_get() + + def test_add_get_without_executor(self): + self.cache._executor.stop() + self._test_add_get() + def _test_in_operator(self): self.cache.add(1) @@ -67,15 +89,15 @@ class TestZmqTTLCache(test_utils.BaseTestCase): self.cache._executor.stop() self._test_in_operator() - def _is_expired(self, item): + def _is_expired(self, key): with self.cache._lock: - return self.cache._is_expired(self.cache._expiration_times[item], - time.time()) + _, expiration_time = self.cache._cache[key] + return self.cache._is_expired(expiration_time, time.time()) def test_executor(self): self.cache.add(1) - self.assertEqual([1], sorted(self.cache._expiration_times.keys())) + self.assertEqual([1], sorted(self.cache._cache.keys())) self.assertFalse(self._is_expired(1)) time.sleep(0.75) @@ -84,7 +106,7 @@ class TestZmqTTLCache(test_utils.BaseTestCase): self.cache.add(2) - self.assertEqual([1, 2], sorted(self.cache._expiration_times.keys())) + self.assertEqual([1, 2], sorted(self.cache._cache.keys())) self.assertFalse(self._is_expired(1)) self.assertFalse(self._is_expired(2)) @@ -95,12 +117,10 @@ class TestZmqTTLCache(test_utils.BaseTestCase): self.cache.add(3) if 1 in self.cache: - self.assertEqual([1, 2, 3], - sorted(self.cache._expiration_times.keys())) + self.assertEqual([1, 2, 3], sorted(self.cache._cache.keys())) self.assertTrue(self._is_expired(1)) else: - self.assertEqual([2, 3], - sorted(self.cache._expiration_times.keys())) + self.assertEqual([2, 3], sorted(self.cache._cache.keys())) self.assertFalse(self._is_expired(2)) self.assertFalse(self._is_expired(3)) @@ -108,9 +128,5 @@ class TestZmqTTLCache(test_utils.BaseTestCase): self.assertEqual(3, self.cache._update_cache.call_count) - self.assertEqual([3], sorted(self.cache._expiration_times.keys())) + self.assertEqual([3], sorted(self.cache._cache.keys())) self.assertFalse(self._is_expired(3)) - - def cleanUp(self): - self.cache.cleanup() - super(TestZmqTTLCache, self).cleanUp()