[zmq] Implement retries for unacknowledged CALLs

This patch tries to implement a mechanism of acknowledgements and
retries via proxy for CALL messages.

Change-Id: I33f7a14045740b5486f18d456f7219a6ab59d910
Closes-Bug: #1497306
Closes-Bug: #1503295
Depends-On: I83919382262b9f169becd09f5db465a01a0ccb78
This commit is contained in:
Gevorg Davoian 2016-08-10 13:23:51 +03:00
parent 4f8fcb332d
commit fab75e78c1
14 changed files with 237 additions and 152 deletions

View File

@ -12,11 +12,9 @@
# License for the specific language governing permissions and limitations
# under the License.
import abc
from concurrent import futures
import logging
import oslo_messaging
from oslo_messaging._drivers import common as rpc_common
from oslo_messaging._drivers.zmq_driver.client.publishers \
import zmq_publisher_base
@ -41,27 +39,17 @@ class DealerPublisherBase(zmq_publisher_base.PublisherBase):
super(DealerPublisherBase, self).__init__(sockets_manager, sender,
receiver)
@staticmethod
def _check_pattern(request, supported_pattern):
if request.msg_type != supported_pattern:
raise zmq_publisher_base.UnsupportedSendPattern(
zmq_names.message_type_str(request.msg_type)
)
def _check_received_data(self, reply_id, reply, request):
assert isinstance(reply, zmq_response.Reply), "Reply expected!"
@staticmethod
def _raise_timeout(request):
raise oslo_messaging.MessagingTimeout(
"Timeout %(tout)s seconds was reached for message %(msg_id)s" %
{"tout": request.timeout, "msg_id": request.message_id}
)
def _recv_reply(self, request):
def _recv_reply(self, request, socket):
self.receiver.register_socket(socket)
reply_future = \
self.receiver.track_request(request)[zmq_names.REPLY_TYPE]
try:
_, reply = reply_future.result(timeout=request.timeout)
assert isinstance(reply, zmq_response.Reply), "Reply expected!"
reply_id, reply = reply_future.result(timeout=request.timeout)
self._check_received_data(reply_id, reply, request)
except AssertionError:
LOG.error(_LE("Message format error in reply for %s"),
request.message_id)
@ -77,30 +65,3 @@ class DealerPublisherBase(zmq_publisher_base.PublisherBase):
)
else:
return reply.reply_body
def send_call(self, request):
self._check_pattern(request, zmq_names.CALL_TYPE)
socket = self.connect_socket(request)
if not socket:
self._raise_timeout(request)
self.sender.send(socket, request)
self.receiver.register_socket(socket)
return self._recv_reply(request)
@abc.abstractmethod
def _send_non_blocking(self, request):
pass
def send_cast(self, request):
self._check_pattern(request, zmq_names.CAST_TYPE)
self._send_non_blocking(request)
def send_fanout(self, request):
self._check_pattern(request, zmq_names.CAST_FANOUT_TYPE)
self._send_non_blocking(request)
def send_notify(self, request):
self._check_pattern(request, zmq_names.NOTIFY_TYPE)
self._send_non_blocking(request)

View File

@ -40,19 +40,19 @@ class DealerPublisherDirect(zmq_dealer_publisher_base.DealerPublisherBase):
super(DealerPublisherDirect, self).__init__(conf, matchmaker, sender,
receiver)
def connect_socket(self, request):
def _connect_socket(self, request):
try:
return self.sockets_manager.get_socket(request.target)
except retrying.RetryError:
return None
def _send_non_blocking(self, request):
socket = self.connect_socket(request)
def _send_request(self, request):
socket = self._connect_socket(request)
if not socket:
return
return None
if request.msg_type in zmq_names.MULTISEND_TYPES:
for _ in range(socket.connections_count()):
self.sender.send(socket, request)
else:
self.sender.send(socket, request)
return socket

View File

@ -55,15 +55,11 @@ class DealerPublisherProxy(zmq_dealer_publisher_base.DealerPublisherBase):
return six.b(self.conf.oslo_messaging_zmq.rpc_zmq_host + "/" +
str(uuid.uuid4()))
def connect_socket(self, request):
return self.socket
def send_call(self, request):
request.routing_key = \
self.routing_table.get_routable_host(request.target)
if request.routing_key is None:
self._raise_timeout(request)
return super(DealerPublisherProxy, self).send_call(request)
def _check_received_data(self, reply_id, reply, request):
super(DealerPublisherProxy, self)._check_received_data(reply_id, reply,
request)
assert reply_id == request.routing_key, \
"Reply from recipient expected!"
def _get_routing_keys(self, request):
if request.msg_type in zmq_names.DIRECT_TYPES:
@ -74,16 +70,20 @@ class DealerPublisherProxy(zmq_dealer_publisher_base.DealerPublisherBase):
if self.conf.oslo_messaging_zmq.use_pub_sub else \
self.routing_table.get_all_hosts(request.target)
def _send_non_blocking(self, request):
for routing_key in self._get_routing_keys(request):
if routing_key is None:
LOG.warning(_LW("Matchmaker contains no record for specified "
"target %(target)s. Dropping message %(id)s.")
% {"target": request.target,
"id": request.message_id})
continue
def _send_request(self, request):
routing_keys = [routing_key
for routing_key in self._get_routing_keys(request)
if routing_key is not None]
if not routing_keys:
LOG.warning(_LW("Matchmaker contains no records for specified "
"target %(target)s. Dropping message %(msg_id)s.")
% {"target": request.target,
"msg_id": request.message_id})
return None
for routing_key in routing_keys:
request.routing_key = routing_key
self.sender.send(self.socket, request)
return self.socket
def cleanup(self):
super(DealerPublisherProxy, self).cleanup()

View File

@ -17,8 +17,10 @@ import logging
import six
import oslo_messaging
from oslo_messaging._drivers import common as rpc_common
from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_names
from oslo_messaging._i18n import _LE
LOG = logging.getLogger(__name__)
@ -72,27 +74,48 @@ class PublisherBase(object):
self.sender = sender
self.receiver = receiver
@staticmethod
def _check_message_pattern(expected, actual):
if expected != actual:
raise UnsupportedSendPattern(zmq_names.message_type_str(actual))
@staticmethod
def _raise_timeout(request):
raise oslo_messaging.MessagingTimeout(
"Timeout %(tout)s seconds was reached for message %(msg_id)s" %
{"tout": request.timeout, "msg_id": request.message_id}
)
@abc.abstractmethod
def connect_socket(self, request):
"""Get connected socket ready for sending given request
or None otherwise (i.e. if connection can't be established).
def _send_request(self, request):
"""Send the request and return a socket used for that.
Return value of None means some failure (e.g. connection
can't be established, etc).
"""
@abc.abstractmethod
def _recv_reply(self, request, socket):
"""Wait for a reply via the socket used for sending the request."""
def send_call(self, request):
pass
self._check_message_pattern(zmq_names.CALL_TYPE, request.msg_type)
socket = self._send_request(request)
if not socket:
raise self._raise_timeout(request)
return self._recv_reply(request, socket)
@abc.abstractmethod
def send_cast(self, request):
pass
self._check_message_pattern(zmq_names.CAST_TYPE, request.msg_type)
self._send_request(request)
@abc.abstractmethod
def send_fanout(self, request):
pass
self._check_message_pattern(zmq_names.CAST_FANOUT_TYPE,
request.msg_type)
self._send_request(request)
@abc.abstractmethod
def send_notify(self, request):
pass
self._check_message_pattern(zmq_names.NOTIFY_TYPE, request.msg_type)
self._send_request(request)
def cleanup(self):
"""Cleanup publisher. Close allocated connections."""

View File

@ -61,9 +61,11 @@ class AckManagerProxy(AckManagerBase):
)
def _wait_for_ack(self, ack_future):
request, socket = ack_future.args
request = ack_future.request
retries = \
request.retry or self.conf.oslo_messaging_zmq.rpc_retry_attempts
if retries is None:
retries = -1
timeout = self.conf.oslo_messaging_zmq.rpc_ack_timeout_base
done = False
@ -72,8 +74,9 @@ class AckManagerProxy(AckManagerBase):
reply_id, response = ack_future.result(timeout=timeout)
done = True
assert response is None, "Ack expected!"
assert reply_id == request.routing_key, \
"Ack from recipient expected!"
if reply_id is not None:
assert reply_id == request.routing_key, \
"Ack from recipient expected!"
except AssertionError:
LOG.error(_LE("Message format error in ack for %s"),
request.message_id)
@ -82,10 +85,10 @@ class AckManagerProxy(AckManagerBase):
"for %(msg_id)s"),
{"tout": timeout,
"msg_id": request.message_id})
if retries is None or retries != 0:
if retries is not None and retries > 0:
if retries != 0:
if retries > 0:
retries -= 1
self.sender.send(socket, request)
self.sender.send(ack_future.socket, request)
timeout *= \
self.conf.oslo_messaging_zmq.rpc_ack_timeout_multiplier
else:
@ -93,18 +96,35 @@ class AckManagerProxy(AckManagerBase):
request.message_id)
done = True
self.receiver.untrack_request(request)
if request.msg_type != zmq_names.CALL_TYPE:
self.receiver.untrack_request(request)
def _get_ack_future(self, request):
socket = self.publisher.connect_socket(request)
def _send_request_and_get_ack_future(self, request):
socket = self.publisher._send_request(request)
if not socket:
return None
self.receiver.register_socket(socket)
ack_future = self.receiver.track_request(request)[zmq_names.ACK_TYPE]
ack_future.args = request, socket
ack_future.request = request
ack_future.socket = socket
return ack_future
def send_call(self, request):
ack_future = self._send_request_and_get_ack_future(request)
if not ack_future:
self.publisher._raise_timeout(request)
self._pool.submit(self._wait_for_ack, ack_future)
try:
return self.publisher._recv_reply(request, ack_future.socket)
finally:
if not ack_future.done():
ack_future.set_result((None, None))
def send_cast(self, request):
self.publisher.send_cast(request)
self._pool.submit(self._wait_for_ack, self._get_ack_future(request))
ack_future = self._send_request_and_get_ack_future(request)
if not ack_future:
return
self._pool.submit(self._wait_for_ack, ack_future)
def cleanup(self):
self._pool.shutdown(wait=True)

View File

@ -63,9 +63,12 @@ class ReceiverBase(object):
a dict of futures for monitoring all types of responses.
"""
futures = {}
message_id = request.message_id
for message_type in self.message_types:
future = futurist.Future()
self._set_future(request.message_id, message_type, future)
future = self._get_future(message_id, message_type)
if future is None:
future = futurist.Future()
self._set_future(message_id, message_type, future)
futures[message_type] = future
return futures
@ -92,7 +95,8 @@ class ReceiverBase(object):
def _run_loop(self):
data, socket = self._poller.poll(
timeout=self.conf.oslo_messaging_zmq.rpc_poll_timeout)
timeout=self.conf.oslo_messaging_zmq.rpc_poll_timeout
)
if data is None:
return
reply_id, message_type, message_id, response = data

View File

@ -76,9 +76,9 @@ class RoutingTable(object):
key = str(target)
if key not in self.routing_table:
try:
self.routing_table[key] = (get_hosts_retry(
target, zmq_names.socket_type_str(zmq.DEALER)),
time.time())
hosts = get_hosts_retry(
target, zmq_names.socket_type_str(zmq.DEALER))
self.routing_table[key] = (hosts, time.time())
except retrying.RetryError:
LOG.warning(_LW("Matchmaker contains no hosts for target %s")
% key)

View File

@ -239,8 +239,6 @@ class MatchmakerRedis(zmq_matchmaker_base.MatchmakerBase):
@redis_connection_warn
def get_hosts_fanout(self, target, listener_type):
LOG.debug("[Redis] get_hosts for target %s", target)
hosts = []
if target.topic and target.server:
@ -250,18 +248,19 @@ class MatchmakerRedis(zmq_matchmaker_base.MatchmakerBase):
key = zmq_address.prefix_str(target.topic, listener_type)
hosts.extend(self._get_hosts_by_key(key))
LOG.debug("[Redis] get_hosts_fanout for target %(target)s: %(hosts)s",
{"target": target, "hosts": hosts})
return hosts
def get_hosts_fanout_retry(self, target, listener_type):
return self._retry_method(target, listener_type, self.get_hosts_fanout)
def _retry_method(self, target, listener_type, method):
conf = self.conf
@retry(retry_on_result=retry_if_empty,
wrap_exception=True,
wait_fixed=conf.matchmaker_redis.wait_timeout,
stop_max_delay=conf.matchmaker_redis.check_timeout)
wait_fixed=self.conf.matchmaker_redis.wait_timeout,
stop_max_delay=self.conf.matchmaker_redis.check_timeout)
def _get_hosts_retry(target, listener_type):
return method(target, listener_type)
return _get_hosts_retry(target, listener_type)

View File

@ -40,7 +40,7 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer):
def __init__(self, conf, poller, server):
self.ack_sender = zmq_senders.AckSenderProxy(conf)
self.reply_sender = zmq_senders.ReplySenderProxy(conf)
self.received_messages = zmq_ttl_cache.TTLCache(
self.messages_cache = zmq_ttl_cache.TTLCache(
ttl=conf.oslo_messaging_zmq.rpc_message_ttl
)
self.sockets_manager = zmq_sockets_manager.SocketsManager(
@ -92,11 +92,11 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer):
message = zmq_incoming_message.ZmqIncomingMessage(
context, message, reply_id, message_id, socket,
ack_sender, reply_sender
ack_sender, reply_sender, self.messages_cache
)
# drop duplicate message
if message_id in self.received_messages:
# drop a duplicate message
if message_id in self.messages_cache:
LOG.warning(
_LW("[%(host)s] Dropping duplicate %(msg_type)s "
"message %(msg_id)s"),
@ -104,10 +104,16 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer):
"msg_type": zmq_names.message_type_str(msg_type),
"msg_id": message_id}
)
# NOTE(gdavoian): send yet another ack for the non-CALL
# message, since the old one might be lost;
# for the CALL message also try to resend its reply
# (of course, if it was already obtained and cached).
message.acknowledge()
if msg_type == zmq_names.CALL_TYPE:
message.reply_from_cache()
return None
self.received_messages.add(message_id)
self.messages_cache.add(message_id)
LOG.debug(
"[%(host)s] Received %(msg_type)s message %(msg_id)s",
{"host": self.host,
@ -124,7 +130,7 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer):
def cleanup(self):
LOG.info(_LI("[%s] Destroy DEALER consumer"), self.host)
self.received_messages.cleanup()
self.messages_cache.cleanup()
self.connection_updater.cleanup()
super(DealerConsumer, self).cleanup()

View File

@ -25,9 +25,14 @@ zmq = zmq_async.import_zmq()
class ZmqIncomingMessage(base.RpcIncomingMessage):
"""Base class for RPC-messages via ZMQ-driver.
Each message may send either acks/replies or just nothing
(if acks are disabled and replies are not supported).
"""
def __init__(self, context, message, reply_id=None, message_id=None,
socket=None, ack_sender=None, reply_sender=None):
socket=None, ack_sender=None, reply_sender=None,
replies_cache=None):
if ack_sender is not None or reply_sender is not None:
assert socket is not None, "Valid socket expected!"
@ -41,6 +46,7 @@ class ZmqIncomingMessage(base.RpcIncomingMessage):
self.socket = socket
self.ack_sender = ack_sender
self.reply_sender = reply_sender
self.replies_cache = replies_cache
def acknowledge(self):
if self.ack_sender is not None:
@ -57,6 +63,14 @@ class ZmqIncomingMessage(base.RpcIncomingMessage):
reply_body=reply,
failure=failure)
self.reply_sender.send(self.socket, reply)
if self.replies_cache is not None:
self.replies_cache.add(self.message_id, reply)
def reply_from_cache(self):
if self.reply_sender is not None and self.replies_cache is not None:
reply = self.replies_cache.get(self.message_id)
if reply is not None:
self.reply_sender.send(self.socket, reply)
def requeue(self):
"""Requeue is not supported"""

View File

@ -24,9 +24,11 @@ zmq = zmq_async.import_zmq()
class TTLCache(object):
_UNDEFINED = object()
def __init__(self, ttl=None):
self._lock = threading.Lock()
self._expiration_times = {}
self._cache = {}
self._executor = None
if not (ttl is None or isinstance(ttl, (int, float))):
@ -47,30 +49,31 @@ class TTLCache(object):
def _is_expired(expiration_time, current_time):
return expiration_time <= current_time
def add(self, item):
def add(self, key, value=None):
with self._lock:
self._expiration_times[item] = time.time() + self._ttl
expiration_time = time.time() + self._ttl
self._cache[key] = (value, expiration_time)
def discard(self, item):
def get(self, key, default=None):
with self._lock:
self._expiration_times.pop(item, None)
def __contains__(self, item):
with self._lock:
expiration_time = self._expiration_times.get(item)
if expiration_time is None:
return False
data = self._cache.get(key)
if data is None:
return default
value, expiration_time = data
if self._is_expired(expiration_time, time.time()):
self._expiration_times.pop(item)
return False
return True
del self._cache[key]
return default
return value
def __contains__(self, key):
return self.get(key, self._UNDEFINED) is not self._UNDEFINED
def _update_cache(self):
with self._lock:
current_time = time.time()
self._expiration_times = \
{item: expiration_time for
item, expiration_time in six.iteritems(self._expiration_times)
self._cache = \
{key: (value, expiration_time) for
key, (value, expiration_time) in six.iteritems(self._cache)
if not self._is_expired(expiration_time, current_time)}
time.sleep(self._ttl)

View File

@ -117,10 +117,8 @@ zmq_opts = [
'True means not keeping a queue when server side '
'disconnects. False means to keep queue and messages '
'even if server is disconnected, when the server '
'appears we send all accumulated messages to it.')
]
'appears we send all accumulated messages to it.'),
zmq_ack_retry_opts = [
cfg.IntOpt('rpc_thread_pool_size', default=100,
help='Maximum number of (green) threads to work concurrently.'),
@ -133,7 +131,7 @@ zmq_ack_retry_opts = [
help='Wait for message acknowledgements from receivers. '
'This mechanism works only via proxy without PUB/SUB.'),
cfg.IntOpt('rpc_ack_timeout_base', default=10,
cfg.IntOpt('rpc_ack_timeout_base', default=15,
help='Number of seconds to wait for an ack from a cast/call. '
'After each retry attempt this timeout is multiplied by '
'some specified multiplier.'),
@ -155,6 +153,5 @@ def register_opts(conf):
opt_group = cfg.OptGroup(name='oslo_messaging_zmq',
title='ZeroMQ driver options')
conf.register_opts(zmq_opts, group=opt_group)
conf.register_opts(zmq_ack_retry_opts, group=opt_group)
conf.register_opts(server._pool_opts)
conf.register_opts(base.base_opts)

View File

@ -43,7 +43,7 @@ class TestZmqAckManager(test_utils.BaseTestCase):
'use_router_proxy': True,
'rpc_thread_pool_size': 1,
'rpc_use_acks': True,
'rpc_ack_timeout_base': 3,
'rpc_ack_timeout_base': 5,
'rpc_ack_timeout_multiplier': 1,
'rpc_retry_attempts': 2}
self.config(group='oslo_messaging_zmq', **kwargs)
@ -104,8 +104,8 @@ class TestZmqAckManager(test_utils.BaseTestCase):
result = self.driver.send(
self.target, {}, self.message, wait_for_reply=False
)
self.ack_manager._pool.shutdown(wait=True)
self.assertIsNone(result)
self.ack_manager._pool.shutdown(wait=True)
self.assertTrue(self.listener._received.isSet())
self.assertEqual(self.message, self.listener.message.message)
self.assertEqual(1, self.send.call_count)
@ -119,21 +119,20 @@ class TestZmqAckManager(test_utils.BaseTestCase):
result = self.driver.send(
self.target, {}, self.message, wait_for_reply=False
)
self.listener._received.wait(3)
self.assertIsNone(result)
self.listener._received.wait(5)
self.assertTrue(self.listener._received.isSet())
self.assertEqual(self.message, self.listener.message.message)
self.assertEqual(1, self.send.call_count)
self.assertEqual(1, lost_ack_mock.call_count)
self.assertEqual(0, self.set_result.call_count)
self.listener._received.clear()
with mock.patch.object(
zmq_incoming_message.ZmqIncomingMessage, 'acknowledge',
side_effect=zmq_incoming_message.ZmqIncomingMessage.acknowledge,
autospec=True
) as received_ack_mock:
self.listener._received.clear()
self.ack_manager._pool.shutdown(wait=True)
self.listener._received.wait(3)
self.assertFalse(self.listener._received.isSet())
self.assertEqual(2, self.send.call_count)
self.assertEqual(1, received_ack_mock.call_count)
@ -146,15 +145,15 @@ class TestZmqAckManager(test_utils.BaseTestCase):
result = self.driver.send(
self.target, {}, self.message, wait_for_reply=False
)
self.listener._received.wait(3)
self.assertIsNone(result)
self.listener._received.wait(5)
self.assertTrue(self.listener._received.isSet())
self.assertEqual(self.message, self.listener.message.message)
self.assertEqual(1, self.send.call_count)
self.assertEqual(1, lost_ack_mock.call_count)
self.assertEqual(0, self.set_result.call_count)
self.listener._received.clear()
self.listener._received.wait(4.5)
self.listener._received.wait(7.5)
self.assertFalse(self.listener._received.isSet())
self.assertEqual(2, self.send.call_count)
self.assertEqual(2, lost_ack_mock.call_count)
@ -176,10 +175,53 @@ class TestZmqAckManager(test_utils.BaseTestCase):
result = self.driver.send(
self.target, {}, self.message, wait_for_reply=False
)
self.ack_manager._pool.shutdown(wait=True)
self.assertIsNone(result)
self.ack_manager._pool.shutdown(wait=True)
self.assertTrue(self.listener._received.isSet())
self.assertEqual(self.message, self.listener.message.message)
self.assertEqual(3, self.send.call_count)
self.assertEqual(3, lost_ack_mock.call_count)
self.assertEqual(1, self.set_result.call_count)
@mock.patch.object(
zmq_incoming_message.ZmqIncomingMessage, 'acknowledge',
side_effect=zmq_incoming_message.ZmqIncomingMessage.acknowledge,
autospec=True
)
@mock.patch.object(
zmq_incoming_message.ZmqIncomingMessage, 'reply',
side_effect=zmq_incoming_message.ZmqIncomingMessage.reply,
autospec=True
)
def test_call_success_without_retries(self, received_reply_mock,
received_ack_mock):
self.listener.listen(self.target)
result = self.driver.send(
self.target, {}, self.message, wait_for_reply=True, timeout=10
)
self.assertIsNotNone(result)
self.ack_manager._pool.shutdown(wait=True)
self.assertTrue(self.listener._received.isSet())
self.assertEqual(self.message, self.listener.message.message)
self.assertEqual(1, self.send.call_count)
self.assertEqual(1, received_ack_mock.call_count)
self.assertEqual(3, self.set_result.call_count)
received_reply_mock.assert_called_once_with(mock.ANY, reply=True)
@mock.patch.object(zmq_incoming_message.ZmqIncomingMessage, 'acknowledge')
@mock.patch.object(zmq_incoming_message.ZmqIncomingMessage, 'reply')
def test_call_failure_exhausted_retries_and_timeout_error(self,
lost_reply_mock,
lost_ack_mock):
self.listener.listen(self.target)
self.assertRaises(oslo_messaging.MessagingTimeout,
self.driver.send,
self.target, {}, self.message,
wait_for_reply=True, timeout=20)
self.ack_manager._pool.shutdown(wait=True)
self.assertTrue(self.listener._received.isSet())
self.assertEqual(self.message, self.listener.message.message)
self.assertEqual(3, self.send.call_count)
self.assertEqual(3, lost_ack_mock.call_count)
self.assertEqual(2, self.set_result.call_count)
lost_reply_mock.assert_called_once_with(reply=True)

View File

@ -35,6 +35,28 @@ class TestZmqTTLCache(test_utils.BaseTestCase):
self.cache = zmq_ttl_cache.TTLCache(ttl=1)
self.addCleanup(lambda: self.cache.cleanup())
def _test_add_get(self):
self.cache.add('x', 'a')
self.assertEqual(self.cache.get('x'), 'a')
self.assertEqual(self.cache.get('x', 'b'), 'a')
self.assertEqual(self.cache.get('y'), None)
self.assertEqual(self.cache.get('y', 'b'), 'b')
time.sleep(1)
self.assertEqual(self.cache.get('x'), None)
self.assertEqual(self.cache.get('x', 'b'), 'b')
def test_add_get_with_executor(self):
self._test_add_get()
def test_add_get_without_executor(self):
self.cache._executor.stop()
self._test_add_get()
def _test_in_operator(self):
self.cache.add(1)
@ -67,15 +89,15 @@ class TestZmqTTLCache(test_utils.BaseTestCase):
self.cache._executor.stop()
self._test_in_operator()
def _is_expired(self, item):
def _is_expired(self, key):
with self.cache._lock:
return self.cache._is_expired(self.cache._expiration_times[item],
time.time())
_, expiration_time = self.cache._cache[key]
return self.cache._is_expired(expiration_time, time.time())
def test_executor(self):
self.cache.add(1)
self.assertEqual([1], sorted(self.cache._expiration_times.keys()))
self.assertEqual([1], sorted(self.cache._cache.keys()))
self.assertFalse(self._is_expired(1))
time.sleep(0.75)
@ -84,7 +106,7 @@ class TestZmqTTLCache(test_utils.BaseTestCase):
self.cache.add(2)
self.assertEqual([1, 2], sorted(self.cache._expiration_times.keys()))
self.assertEqual([1, 2], sorted(self.cache._cache.keys()))
self.assertFalse(self._is_expired(1))
self.assertFalse(self._is_expired(2))
@ -95,12 +117,10 @@ class TestZmqTTLCache(test_utils.BaseTestCase):
self.cache.add(3)
if 1 in self.cache:
self.assertEqual([1, 2, 3],
sorted(self.cache._expiration_times.keys()))
self.assertEqual([1, 2, 3], sorted(self.cache._cache.keys()))
self.assertTrue(self._is_expired(1))
else:
self.assertEqual([2, 3],
sorted(self.cache._expiration_times.keys()))
self.assertEqual([2, 3], sorted(self.cache._cache.keys()))
self.assertFalse(self._is_expired(2))
self.assertFalse(self._is_expired(3))
@ -108,9 +128,5 @@ class TestZmqTTLCache(test_utils.BaseTestCase):
self.assertEqual(3, self.cache._update_cache.call_count)
self.assertEqual([3], sorted(self.cache._expiration_times.keys()))
self.assertEqual([3], sorted(self.cache._cache.keys()))
self.assertFalse(self._is_expired(3))
def cleanUp(self):
self.cache.cleanup()
super(TestZmqTTLCache, self).cleanUp()