Rework how queues get created in fake driver

Currently, if there are no servers listening on a topic then a message
to that topic just gets dropped by the fake driver.

This makes the tests intermittently fail if the server takes longer to
start.

Turn things on their head so that the client always creates the queues
on the exchange so that messages can get queued up even if there is no
server listening.

Now we also need to delete the "duplicate server on topic" test - it's
actually fine to have multiple servers listening on the one topic.
This commit is contained in:
Mark McLoughlin 2013-06-15 13:57:04 +01:00
parent 8bf3c862b3
commit 978d19c256
2 changed files with 73 additions and 82 deletions

View File

@ -17,6 +17,7 @@
import json
import Queue
import threading
import time
from oslo import messaging
@ -43,33 +44,58 @@ class FakeIncomingMessage(base.IncomingMessage):
class FakeListener(base.Listener):
def __init__(self, driver, target):
def __init__(self, driver, target, exchange):
super(FakeListener, self).__init__(driver, target)
self._queue = Queue.Queue()
self._reply_queue = Queue.Queue()
def _deliver_message(self, ctxt, message,
wait_for_reply=None, timeout=None):
self._queue.put((ctxt, message))
if wait_for_reply:
try:
return self._reply_queue.get(timeout=timeout)
except Queue.Empty:
# FIXME(markmc): timeout exception
return None
self._exchange = exchange
def _deliver_reply(self, reply=None, failure=None):
# FIXME: handle failure
self._reply_queue.put(reply)
if self._reply_q:
self._reply_q.put(reply)
def poll(self):
self._reply_q = None
while True:
# sleeping allows keyboard interrupts
try:
(ctxt, message) = self._queue.get(block=False)
(ctxt, message, reply_q) = self._exchange.poll(self.target)
if message is not None:
self._reply_q = reply_q
return FakeIncomingMessage(self, ctxt, message)
except Queue.Empty:
time.sleep(.05)
time.sleep(.05)
class FakeExchange(object):
def __init__(self, name):
self.name = name
self._queues_lock = threading.Lock()
self._topic_queues = {}
self._server_queues = {}
def _get_topic_queue(self, topic):
return self._topic_queues.setdefault(topic, [])
def _get_server_queue(self, topic, server):
return self._server_queues.setdefault((topic, server), [])
def deliver_message(self, topic, ctxt, message,
server=None, fanout=False, reply_q=None):
with self._queues_lock:
if fanout:
queues = [q for t, q in self._server_queues.items()
if t[0] == topic]
elif server is not None:
queues = [self._get_server_queue(topic, server)]
else:
queues = [self._get_topic_queue(topic)]
for queue in queues:
queue.append((ctxt, message, reply_q))
def poll(self, target):
with self._queues_lock:
queue = self._get_server_queue(target.topic, target.server)
if not queue:
queue = self._get_topic_queue(target.topic)
return queue.pop(0) if queue else (None, None, None)
class FakeDriver(base.BaseDriver):
@ -79,6 +105,7 @@ class FakeDriver(base.BaseDriver):
self._default_exchange = utils.exchange_from_url(url, default_exchange)
self._exchanges_lock = threading.Lock()
self._exchanges = {}
@staticmethod
@ -93,6 +120,10 @@ class FakeDriver(base.BaseDriver):
"""
json.dumps(message)
def _get_exchange(self, name):
while self._exchanges_lock:
return self._exchanges.setdefault(name, FakeExchange(name))
def send(self, target, ctxt, message,
wait_for_reply=None, timeout=None, envelope=False):
if not target.topic:
@ -104,55 +135,35 @@ class FakeDriver(base.BaseDriver):
self._check_serialize(message)
exchange = target.exchange or self._default_exchange
exchange = self._get_exchange(target.exchange or
self._default_exchange)
start_time = time.time()
while True:
topics = self._exchanges.get(exchange, {})
listeners = topics.get(target.topic, [])
if target.server:
listeners = [l for l in listeners
if l.target.server == target.server]
# FIXME(markmc): Need to create and pass a reply_queue
if listeners or not wait_for_reply:
break
reply_q = None
if wait_for_reply:
reply_q = Queue.Queue()
if timeout and (time.time() - start_time > timeout):
exchange.deliver_message(target.topic, ctxt, message,
server=target.server,
fanout=target.fanout,
reply_q=reply_q)
if wait_for_reply:
try:
return reply_q.get(timeout=timeout)
except Queue.Empty:
raise messaging.MessagingTimeout(
'No listeners found for topic %s' % target.topic)
'No reply on topic %s' % target.topic)
time.sleep(.05)
if target.fanout:
for listener in listeners:
ret = listener._deliver_message(ctxt, message)
if ret:
return ret
return
if not listeners:
# FIXME(markmc): timeout exception
return None
# FIXME(markmc): implement round-robin delivery
listener = listeners[0]
return listener._deliver_message(ctxt, message,
wait_for_reply, timeout)
return None
def listen(self, target):
if not (target.topic and target.server):
raise InvalidTarget('Topic and server are required to listen',
target)
exchange = target.exchange or self._default_exchange
topics = self._exchanges.setdefault(exchange, {})
exchange = self._get_exchange(target.exchange or
self._default_exchange)
if target.topic in topics:
raise InvalidTarget('Already listening on this topic', target)
listener = FakeListener(self, target)
listeners = topics.setdefault(target.topic, [])
listeners.append(listener)
return listener
return FakeListener(self, target, exchange)

View File

@ -67,7 +67,6 @@ class TestRPCServer(test_utils.BaseTestCase):
client = client.prepare(topic=topic)
client.cast({}, 'stop')
server_thread.join(timeout=30)
self.assertFalse(server_thread.isAlive())
def _setup_client(self, transport):
return messaging.RPCClient(transport,
@ -117,27 +116,6 @@ class TestRPCServer(test_utils.BaseTestCase):
else:
self.assertTrue(False)
def test_duplicate_target_topic(self):
transport = messaging.get_transport(self.conf, url='fake:')
server_thread = self._setup_server(transport, None, topic='testtopic')
server = messaging.get_rpc_server(transport,
messaging.Target(server='testserver',
topic='testtopic'),
[])
try:
server.start()
except Exception as ex:
self.assertTrue(isinstance(ex, messaging.ServerListenError), ex)
self.assertEquals(ex.target.server, 'testserver')
self.assertEquals(ex.target.topic, 'testtopic')
else:
self.assertTrue(False)
finally:
client = self._setup_client(transport)
self._stop_server(client, server_thread)
def test_unknown_executor(self):
transport = messaging.get_transport(self.conf, url='fake:')
@ -241,4 +219,6 @@ class TestRPCServer(test_utils.BaseTestCase):
self.assertTrue(thread2.isAlive())
self._stop_server(client, thread2, topic='topic2')
self.assertEquals(endpoint.pings, ['ds1', 'ds2'])
self.assertEquals(len(endpoint.pings), 2)
self.assertTrue('ds1' in endpoint.pings)
self.assertTrue('ds2' in endpoint.pings)