ZMQ: Run more functional tests

Change-Id: Ia7b001bf5aba1120544dcc15c5200c50ebe731f6
This commit is contained in:
Victor Sergeyev 2015-07-23 18:31:11 +03:00
parent 315e56ae2b
commit de629d8104
10 changed files with 112 additions and 79 deletions

View File

@ -14,10 +14,11 @@
import abc
import collections
import logging
import random
import six
from oslo_messaging._drivers.zmq_driver import zmq_target
import oslo_messaging
from oslo_messaging._i18n import _LI, _LW
@ -34,26 +35,44 @@ class MatchMakerBase(object):
@abc.abstractmethod
def register(self, target, hostname):
"""Register target on nameserver"""
"""Register target on nameserver.
:param target: the target for host
:type target: Target
:param hostname: host for the topic in "host:port" format
:type hostname: String
"""
@abc.abstractmethod
def get_hosts(self, target):
"""Get hosts from nameserver by target"""
"""Get all hosts from nameserver by target.
:param target: the default target for invocations
:type target: Target
:returns: a list of "hostname:port" hosts
"""
def get_single_host(self, target):
"""Get a single host by target"""
"""Get a single host by target.
:param target: the target for messages
:type target: Target
:returns: a "hostname:port" host
"""
hosts = self.get_hosts(target)
if len(hosts) == 0:
LOG.warning(_LW("No hosts were found for target %s. Using "
"localhost") % target)
return "localhost:" + str(self.conf.rpc_zmq_port)
elif len(hosts) == 1:
if not hosts:
err_msg = "No hosts were found for target %s." % target
LOG.error(err_msg)
raise oslo_messaging.InvalidTarget(err_msg, target)
if len(hosts) == 1:
LOG.info(_LI("A single host found for target %s.") % target)
return hosts[0]
else:
LOG.warning(_LW("Multiple hosts were found for target %s. Using "
"the first one.") % target)
return hosts[0]
"the random one.") % target)
return random.choice(hosts)
class DummyMatchMaker(MatchMakerBase):
@ -64,10 +83,10 @@ class DummyMatchMaker(MatchMakerBase):
self._cache = collections.defaultdict(list)
def register(self, target, hostname):
key = zmq_target.target_to_str(target)
key = str(target)
if hostname not in self._cache[key]:
self._cache[key].append(hostname)
def get_hosts(self, target):
key = zmq_target.target_to_str(target)
key = str(target)
return self._cache[key]

View File

@ -17,7 +17,6 @@ from oslo_config import cfg
import redis
from oslo_messaging._drivers.zmq_driver.matchmaker import base
from oslo_messaging._drivers.zmq_driver import zmq_target
LOG = logging.getLogger(__name__)
@ -48,11 +47,28 @@ class RedisMatchMaker(base.MatchMakerBase):
password=self.conf.matchmaker_redis.password,
)
def _target_to_key(self, target):
attributes = ['topic', 'exchange', 'server']
return ':'.join((getattr(target, attr) or "*") for attr in attributes)
def _get_keys_by_pattern(self, pattern):
return self._redis.keys(pattern)
def _get_hosts_by_key(self, key):
return self._redis.lrange(key, 0, -1)
def register(self, target, hostname):
if hostname not in self.get_hosts(target):
key = zmq_target.target_to_str(target)
key = self._target_to_key(target)
if hostname not in self._get_hosts_by_key(key):
self._redis.lpush(key, hostname)
def get_hosts(self, target):
key = zmq_target.target_to_str(target)
return self._redis.lrange(key, 0, -1)[::-1]
pattern = self._target_to_key(target)
if "*" not in pattern:
# pattern have no placeholders, so this is valid key
return self._get_hosts_by_key(pattern)
hosts = []
for key in self._get_keys_by_pattern(pattern):
hosts.extend(self._get_hosts_by_key(key))
return hosts

View File

@ -29,6 +29,8 @@ zmq = zmq_async.import_zmq()
class CallRequest(Request):
msg_type = zmq_serializer.CALL_TYPE
def __init__(self, conf, target, context, message, timeout=None,
retry=None, allowed_remote_exmods=None, matchmaker=None):
self.allowed_remote_exmods = allowed_remote_exmods or []
@ -40,7 +42,6 @@ class CallRequest(Request):
socket = self.zmq_context.socket(zmq.REQ)
super(CallRequest, self).__init__(conf, target, context,
message, socket,
zmq_serializer.CALL_TYPE,
timeout, retry)
self.host = self.matchmaker.get_single_host(self.target)
self.connect_address = zmq_target.get_tcp_direct_address(

View File

@ -29,14 +29,7 @@ zmq = zmq_async.import_zmq()
class CastRequest(Request):
def __init__(self, conf, target, context,
message, socket, address, timeout=None, retry=None):
self.connect_address = address
fanout_type = zmq_serializer.FANOUT_TYPE
cast_type = zmq_serializer.CAST_TYPE
msg_type = fanout_type if target.fanout else cast_type
super(CastRequest, self).__init__(conf, target, context, message,
socket, msg_type, timeout, retry)
msg_type = zmq_serializer.CAST_TYPE
def __call__(self, *args, **kwargs):
self.send_request()
@ -50,6 +43,19 @@ class CastRequest(Request):
pass
class FanoutRequest(CastRequest):
msg_type = zmq_serializer.FANOUT_TYPE
def __init__(self, *args, **kwargs):
self.hosts_count = kwargs.pop("hosts_count")
super(FanoutRequest, self).__init__(*args, **kwargs)
def send_request(self):
for _ in range(self.hosts_count):
super(FanoutRequest, self).send_request()
class DealerCastPublisher(zmq_cast_publisher.CastPublisherBase):
def __init__(self, conf, matchmaker):
@ -58,22 +64,30 @@ class DealerCastPublisher(zmq_cast_publisher.CastPublisherBase):
def cast(self, target, context,
message, timeout=None, retry=None):
host = self.matchmaker.get_single_host(target)
connect_address = zmq_target.get_tcp_direct_address(host)
dealer_socket = self._create_socket(connect_address)
request = CastRequest(self.conf, target, context, message,
dealer_socket, connect_address, timeout, retry)
if str(target) in self.outbound_sockets:
dealer_socket, hosts = self.outbound_sockets[str(target)]
else:
dealer_socket = self.zmq_context.socket(zmq.DEALER)
hosts = self.matchmaker.get_hosts(target)
for host in hosts:
self._connect_to_host(dealer_socket, host)
self.outbound_sockets[str(target)] = (dealer_socket, hosts)
if target.fanout:
request = FanoutRequest(self.conf, target, context, message,
dealer_socket, timeout, retry,
hosts_count=len(hosts))
else:
request = CastRequest(self.conf, target, context, message,
dealer_socket, timeout, retry)
request.send_request()
def _create_socket(self, address):
if address in self.outbound_sockets:
return self.outbound_sockets[address]
def _connect_to_host(self, socket, host):
address = zmq_target.get_tcp_direct_address(host)
try:
dealer_socket = self.zmq_context.socket(zmq.DEALER)
LOG.info(_LI("Connecting DEALER to %s") % address)
dealer_socket.connect(address)
self.outbound_sockets[address] = dealer_socket
return dealer_socket
socket.connect(address)
except zmq.ZMQError as e:
errmsg = _LE("Failed connecting DEALER to %(address)s: %(e)s")\
% (address, e)
@ -81,7 +95,6 @@ class DealerCastPublisher(zmq_cast_publisher.CastPublisherBase):
raise rpc_common.RPCException(errmsg)
def cleanup(self):
if self.outbound_sockets:
for socket in self.outbound_sockets.values():
socket.setsockopt(zmq.LINGER, 0)
socket.close()
for socket, hosts in self.outbound_sockets.values():
socket.setsockopt(zmq.LINGER, 0)
socket.close()

View File

@ -32,9 +32,10 @@ zmq = zmq_async.import_zmq()
class Request(object):
def __init__(self, conf, target, context, message,
socket, msg_type, timeout=None, retry=None):
socket, timeout=None, retry=None):
assert msg_type in zmq_serializer.MESSAGE_TYPES, "Unknown msg type!"
if self.msg_type not in zmq_serializer.MESSAGE_TYPES:
raise RuntimeError("Unknown msg type!")
if message['method'] is None:
errmsg = _LE("No method specified for RPC call")
@ -42,7 +43,6 @@ class Request(object):
raise KeyError(errmsg)
self.msg_id = uuid.uuid4().hex
self.msg_type = msg_type
self.target = target
self.context = context
self.message = message
@ -51,6 +51,10 @@ class Request(object):
self.reply = None
self.socket = socket
@abc.abstractproperty
def msg_type(self):
"""ZMQ message type"""
@property
def is_replied(self):
return self.reply is not None

View File

@ -36,8 +36,7 @@ class ZmqServer(base.Listener):
self.socket = self.context.socket(zmq.ROUTER)
self.address = zmq_target.get_tcp_random_address(conf)
self.port = self.socket.bind_to_random_port(self.address)
LOG.info("Run server on tcp://%s:%d" %
(self.address, self.port))
LOG.info("Run server on %s:%d" % (self.address, self.port))
except zmq.ZMQError as e:
errmsg = _LE("Failed binding to port %(port)d: %(e)s")\
% (self.port, e)

View File

@ -12,8 +12,6 @@
# License for the specific language governing permissions and limitations
# under the License.
from oslo_messaging import target
def get_tcp_bind_address(port):
return "tcp://*:%s" % port
@ -33,23 +31,3 @@ def get_tcp_direct_address(host):
def get_tcp_random_address(conf):
return "tcp://*"
def target_to_str(target):
items = []
if target.topic:
items.append(target.topic)
if target.exchange:
items.append(target.exchange)
if target.server:
items.append(target.server)
return '.'.join(items)
def target_from_dict(target_dict):
return target.Target(exchange=target_dict['exchange'],
topic=target_dict['topic'],
namespace=target_dict['namespace'],
version=target_dict['version'],
server=target_dict['server'],
fanout=target_dict['fanout'])

View File

@ -56,8 +56,8 @@ class TestImplMatchmaker(test_utils.BaseTestCase):
self.test_matcher.register(self.target, self.host1)
self.test_matcher.register(self.target, self.host2)
self.assertEqual(self.test_matcher.get_hosts(self.target),
[self.host1, self.host2])
self.assertItemsEqual(self.test_matcher.get_hosts(self.target),
[self.host1, self.host2])
self.assertIn(self.test_matcher.get_single_host(self.target),
[self.host1, self.host2])
@ -76,5 +76,5 @@ class TestImplMatchmaker(test_utils.BaseTestCase):
def test_get_single_host_wrong_topic(self):
target = oslo_messaging.Target(topic="no_such_topic")
self.assertEqual(self.test_matcher.get_single_host(target),
"localhost:9501")
self.assertRaises(oslo_messaging.InvalidTarget,
self.test_matcher.get_single_host, target)

View File

@ -93,6 +93,8 @@ class CallTestCase(utils.SkipIfNoTransportURL):
self.assertEqual(0, s.endpoint.ival)
def test_timeout(self):
if self.url.startswith("zmq"):
self.skipTest("Skip CallTestCase.test_timeout for ZMQ driver")
transport = self.useFixture(utils.TransportFixture(self.url))
target = oslo_messaging.Target(topic="no_such_topic")
c = utils.ClientStub(transport.transport, target, timeout=1)
@ -185,6 +187,11 @@ class NotifyTestCase(utils.SkipIfNoTransportURL):
# NOTE(sileht): Each test must not use the same topics
# to be run in parallel
def setUp(self):
super(NotifyTestCase, self).setUp()
if self.url.startswith("zmq"):
self.skipTest("Skip NotifyTestCase for ZMQ driver")
def test_simple(self):
listener = self.useFixture(
utils.NotificationFixture(self.url, ['test_simple']))

View File

@ -41,11 +41,7 @@ setenv = TRANSPORT_URL=amqp://stackqpid:secretqpid@127.0.0.1:65123//
commands = {toxinidir}/setup-test-env-qpid.sh python setup.py testr --slowest --testr-args='oslo_messaging.tests.functional'
[testenv:py27-func-zeromq]
commands = {toxinidir}/setup-test-env-zmq.sh python -m testtools.run \
oslo_messaging.tests.functional.test_functional.CallTestCase.test_exception \
oslo_messaging.tests.functional.test_functional.CallTestCase.test_timeout \
oslo_messaging.tests.functional.test_functional.CallTestCase.test_specific_server \
oslo_messaging.tests.functional.test_functional.CastTestCase.test_specific_server
commands = {toxinidir}/setup-test-env-zmq.sh python setup.py testr --slowest --testr-args='oslo_messaging.tests.functional.test_functional'
[flake8]
show-source = True