Merge "Full support of multiple hosts in transport url"
This commit is contained in:
commit
17375f41ce
|
@ -59,16 +59,17 @@ LOG = logging.getLogger(__name__)
|
|||
|
||||
class ConnectionPool(pool.Pool):
|
||||
"""Class that implements a Pool of Connections."""
|
||||
def __init__(self, conf, connection_cls):
|
||||
def __init__(self, conf, url, connection_cls):
|
||||
self.connection_cls = connection_cls
|
||||
self.conf = conf
|
||||
self.url = url
|
||||
super(ConnectionPool, self).__init__(self.conf.rpc_conn_pool_size)
|
||||
self.reply_proxy = None
|
||||
|
||||
# TODO(comstud): Timeout connections not used in a while
|
||||
def create(self):
|
||||
LOG.debug(_('Pool creating new connection'))
|
||||
return self.connection_cls(self.conf)
|
||||
return self.connection_cls(self.conf, self.url)
|
||||
|
||||
def empty(self):
|
||||
for item in self.iter_free():
|
||||
|
@ -82,18 +83,19 @@ class ConnectionPool(pool.Pool):
|
|||
# time code, it gets here via cleanup() and only appears in service.py
|
||||
# just before doing a sys.exit(), so cleanup() only happens once and
|
||||
# the leakage is not a problem.
|
||||
self.connection_cls.pool = None
|
||||
del self.connection_cls.pools[self.url]
|
||||
|
||||
|
||||
_pool_create_sem = threading.Lock()
|
||||
|
||||
|
||||
def get_connection_pool(conf, connection_cls):
|
||||
def get_connection_pool(conf, url, connection_cls):
|
||||
with _pool_create_sem:
|
||||
# Make sure only one thread tries to create the connection pool.
|
||||
if not connection_cls.pool:
|
||||
connection_cls.pool = ConnectionPool(conf, connection_cls)
|
||||
return connection_cls.pool
|
||||
if url not in connection_cls.pools:
|
||||
connection_cls.pools[url] = ConnectionPool(conf, url,
|
||||
connection_cls)
|
||||
return connection_cls.pools[url]
|
||||
|
||||
|
||||
class ConnectionContext(rpc_common.Connection):
|
||||
|
@ -108,17 +110,16 @@ class ConnectionContext(rpc_common.Connection):
|
|||
If possible the function makes sure to return a connection to the pool.
|
||||
"""
|
||||
|
||||
def __init__(self, conf, connection_pool, pooled=True, server_params=None):
|
||||
def __init__(self, conf, url, connection_pool, pooled=True):
|
||||
"""Create a new connection, or get one from the pool."""
|
||||
self.connection = None
|
||||
self.conf = conf
|
||||
self.url = url
|
||||
self.connection_pool = connection_pool
|
||||
if pooled:
|
||||
self.connection = connection_pool.get()
|
||||
else:
|
||||
self.connection = connection_pool.connection_cls(
|
||||
conf,
|
||||
server_params=server_params)
|
||||
self.connection = connection_pool.connection_cls(conf, url)
|
||||
self.pooled = pooled
|
||||
|
||||
def __enter__(self):
|
||||
|
|
|
@ -295,8 +295,6 @@ class AMQPDriverBase(base.BaseDriver):
|
|||
super(AMQPDriverBase, self).__init__(conf, url, default_exchange,
|
||||
allowed_remote_exmods)
|
||||
|
||||
self._server_params = self._server_params_from_url(self._url)
|
||||
|
||||
self._default_exchange = default_exchange
|
||||
|
||||
# FIXME(markmc): temp hack
|
||||
|
@ -310,35 +308,11 @@ class AMQPDriverBase(base.BaseDriver):
|
|||
self._reply_q_conn = None
|
||||
self._waiter = None
|
||||
|
||||
def _server_params_from_url(self, url):
|
||||
sp = {}
|
||||
|
||||
if url.virtual_host is not None:
|
||||
sp['virtual_host'] = url.virtual_host
|
||||
|
||||
if url.hosts:
|
||||
# FIXME(markmc): support multiple hosts
|
||||
host = url.hosts[0]
|
||||
|
||||
sp['hostname'] = host.hostname
|
||||
if host.port is not None:
|
||||
sp['port'] = host.port
|
||||
sp['username'] = host.username or ''
|
||||
sp['password'] = host.password or ''
|
||||
|
||||
return sp
|
||||
|
||||
def _get_connection(self, pooled=True):
|
||||
# FIXME(markmc): we don't yet have a connection pool for each
|
||||
# Transport instance, so we'll only use the pool with the
|
||||
# transport configuration from the config file
|
||||
server_params = self._server_params or None
|
||||
if server_params:
|
||||
pooled = False
|
||||
return rpc_amqp.ConnectionContext(self.conf,
|
||||
self._url,
|
||||
self._connection_pool,
|
||||
pooled=pooled,
|
||||
server_params=server_params)
|
||||
pooled=pooled)
|
||||
|
||||
def _get_reply_q(self):
|
||||
with self._reply_q_lock:
|
||||
|
|
|
@ -27,6 +27,7 @@ from oslo.messaging._drivers import amqpdriver
|
|||
from oslo.messaging._drivers import common as rpc_common
|
||||
from oslo.messaging.openstack.common import importutils
|
||||
from oslo.messaging.openstack.common import jsonutils
|
||||
from oslo.messaging.openstack.common import network_utils
|
||||
|
||||
# FIXME(markmc): remove this
|
||||
_ = lambda s: s
|
||||
|
@ -449,9 +450,9 @@ class NotifyPublisher(Publisher):
|
|||
class Connection(object):
|
||||
"""Connection object."""
|
||||
|
||||
pool = None
|
||||
pools = {}
|
||||
|
||||
def __init__(self, conf, server_params=None):
|
||||
def __init__(self, conf, url):
|
||||
if not qpid_messaging:
|
||||
raise ImportError("Failed to import qpid.messaging")
|
||||
|
||||
|
@ -460,35 +461,44 @@ class Connection(object):
|
|||
self.consumers = {}
|
||||
self.conf = conf
|
||||
|
||||
if server_params and 'hostname' in server_params:
|
||||
# NOTE(russellb) This enables support for cast_to_server.
|
||||
server_params['qpid_hosts'] = [
|
||||
'%s:%d' % (server_params['hostname'],
|
||||
server_params.get('port', 5672))
|
||||
]
|
||||
self.brokers_params = []
|
||||
if url.hosts:
|
||||
for host in url.hosts:
|
||||
params = {
|
||||
'username': host.username or '',
|
||||
'password': host.password or '',
|
||||
}
|
||||
if host.port is not None:
|
||||
params['host'] = '%s:%d' % (host.hostname, host.port)
|
||||
else:
|
||||
params['host'] = host.hostname
|
||||
self.brokers_params.append(params)
|
||||
else:
|
||||
# Old configuration format
|
||||
for adr in self.conf.qpid_hosts:
|
||||
hostname, port = network_utils.parse_host_port(
|
||||
adr, default_port=5672)
|
||||
|
||||
params = {
|
||||
'qpid_hosts': self.conf.qpid_hosts[:],
|
||||
'username': self.conf.qpid_username,
|
||||
'password': self.conf.qpid_password,
|
||||
}
|
||||
params.update(server_params or {})
|
||||
params = {
|
||||
'host': '%s:%d' % (hostname, port),
|
||||
'username': self.conf.qpid_username,
|
||||
'password': self.conf.qpid_password,
|
||||
}
|
||||
self.brokers_params.append(params)
|
||||
|
||||
random.shuffle(params['qpid_hosts'])
|
||||
self.brokers = itertools.cycle(params['qpid_hosts'])
|
||||
random.shuffle(self.brokers_params)
|
||||
self.brokers = itertools.cycle(self.brokers_params)
|
||||
|
||||
self.username = params['username']
|
||||
self.password = params['password']
|
||||
self.reconnect()
|
||||
|
||||
def connection_create(self, broker):
|
||||
# Create the connection - this does not open the connection
|
||||
self.connection = qpid_messaging.Connection(broker)
|
||||
self.connection = qpid_messaging.Connection(broker['host'])
|
||||
|
||||
# Check if flags are set and if so set them for the connection
|
||||
# before we call open
|
||||
self.connection.username = self.username
|
||||
self.connection.password = self.password
|
||||
self.connection.username = broker['username']
|
||||
self.connection.password = broker['password']
|
||||
|
||||
self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms
|
||||
# Reconnection is done by self.reconnect()
|
||||
|
@ -520,14 +530,14 @@ class Connection(object):
|
|||
self.connection_create(broker)
|
||||
self.connection.open()
|
||||
except qpid_exceptions.MessagingError as e:
|
||||
msg_dict = dict(e=e, delay=delay)
|
||||
msg = _("Unable to connect to AMQP server: %(e)s. "
|
||||
"Sleeping %(delay)s seconds") % msg_dict
|
||||
msg_dict = dict(e=e, delay=delay, broker=broker['host'])
|
||||
msg = _("Unable to connect to AMQP server on %(broker)s: "
|
||||
"%(e)s. Sleeping %(delay)s seconds") % msg_dict
|
||||
LOG.error(msg)
|
||||
time.sleep(delay)
|
||||
delay = min(delay + 1, 5)
|
||||
else:
|
||||
LOG.info(_('Connected to AMQP server on %s'), broker)
|
||||
LOG.info(_('Connected to AMQP server on %s'), broker['host'])
|
||||
break
|
||||
|
||||
self.session = self.connection.session()
|
||||
|
@ -686,7 +696,7 @@ class QpidDriver(amqpdriver.AMQPDriverBase):
|
|||
conf.register_opts(qpid_opts)
|
||||
conf.register_opts(rpc_amqp.amqp_opts)
|
||||
|
||||
connection_pool = rpc_amqp.get_connection_pool(conf, Connection)
|
||||
connection_pool = rpc_amqp.get_connection_pool(conf, url, Connection)
|
||||
|
||||
super(QpidDriver, self).__init__(conf, url,
|
||||
connection_pool,
|
||||
|
|
|
@ -421,9 +421,9 @@ class NotifyPublisher(TopicPublisher):
|
|||
class Connection(object):
|
||||
"""Connection object."""
|
||||
|
||||
pool = None
|
||||
pools = {}
|
||||
|
||||
def __init__(self, conf, server_params=None):
|
||||
def __init__(self, conf, url):
|
||||
self.consumers = []
|
||||
self.conf = conf
|
||||
self.max_retries = self.conf.rabbit_max_retries
|
||||
|
@ -436,39 +436,54 @@ class Connection(object):
|
|||
self.interval_max = 30
|
||||
self.memory_transport = False
|
||||
|
||||
if server_params is None:
|
||||
server_params = {}
|
||||
# Keys to translate from server_params to kombu params
|
||||
server_params_to_kombu_params = {'username': 'userid'}
|
||||
|
||||
ssl_params = self._fetch_ssl_params()
|
||||
params_list = []
|
||||
for adr in self.conf.rabbit_hosts:
|
||||
hostname, port = network_utils.parse_host_port(
|
||||
adr, default_port=self.conf.rabbit_port)
|
||||
|
||||
params = {
|
||||
'hostname': hostname,
|
||||
'port': port,
|
||||
'userid': self.conf.rabbit_userid,
|
||||
'password': self.conf.rabbit_password,
|
||||
'login_method': self.conf.rabbit_login_method,
|
||||
'virtual_host': self.conf.rabbit_virtual_host,
|
||||
}
|
||||
if url.virtual_host is not None:
|
||||
virtual_host = url.virtual_host
|
||||
else:
|
||||
virtual_host = self.conf.rabbit_virtual_host
|
||||
|
||||
for sp_key, value in six.iteritems(server_params):
|
||||
p_key = server_params_to_kombu_params.get(sp_key, sp_key)
|
||||
params[p_key] = value
|
||||
self.brokers_params = []
|
||||
if url.hosts:
|
||||
for host in url.hosts:
|
||||
params = {
|
||||
'hostname': host.hostname,
|
||||
'port': host.port or 5672,
|
||||
'userid': host.username or '',
|
||||
'password': host.password or '',
|
||||
'login_method': self.conf.rabbit_login_method,
|
||||
'virtual_host': virtual_host
|
||||
}
|
||||
if self.conf.fake_rabbit:
|
||||
params['transport'] = 'memory'
|
||||
if self.conf.rabbit_use_ssl:
|
||||
params['ssl'] = ssl_params
|
||||
|
||||
if self.conf.fake_rabbit:
|
||||
params['transport'] = 'memory'
|
||||
if self.conf.rabbit_use_ssl:
|
||||
params['ssl'] = ssl_params
|
||||
self.brokers_params.append(params)
|
||||
else:
|
||||
# Old configuration format
|
||||
for adr in self.conf.rabbit_hosts:
|
||||
hostname, port = network_utils.parse_host_port(
|
||||
adr, default_port=self.conf.rabbit_port)
|
||||
|
||||
params_list.append(params)
|
||||
params = {
|
||||
'hostname': hostname,
|
||||
'port': port,
|
||||
'userid': self.conf.rabbit_userid,
|
||||
'password': self.conf.rabbit_password,
|
||||
'login_method': self.conf.rabbit_login_method,
|
||||
'virtual_host': virtual_host
|
||||
}
|
||||
|
||||
random.shuffle(params_list)
|
||||
self.params_list = itertools.cycle(params_list)
|
||||
if self.conf.fake_rabbit:
|
||||
params['transport'] = 'memory'
|
||||
if self.conf.rabbit_use_ssl:
|
||||
params['ssl'] = ssl_params
|
||||
|
||||
self.brokers_params.append(params)
|
||||
|
||||
random.shuffle(self.brokers_params)
|
||||
self.brokers = itertools.cycle(self.brokers_params)
|
||||
|
||||
self.memory_transport = self.conf.fake_rabbit
|
||||
|
||||
|
@ -519,14 +534,14 @@ class Connection(object):
|
|||
# Return the extended behavior or just have the default behavior
|
||||
return ssl_params or True
|
||||
|
||||
def _connect(self, params):
|
||||
def _connect(self, broker):
|
||||
"""Connect to rabbit. Re-establish any queues that may have
|
||||
been declared before if we are reconnecting. Exceptions should
|
||||
be handled by the caller.
|
||||
"""
|
||||
if self.connection:
|
||||
LOG.info(_("Reconnecting to AMQP server on "
|
||||
"%(hostname)s:%(port)d") % params)
|
||||
"%(hostname)s:%(port)d") % broker)
|
||||
try:
|
||||
# XXX(nic): when reconnecting to a RabbitMQ cluster
|
||||
# with mirrored queues in use, the attempt to release the
|
||||
|
@ -545,7 +560,7 @@ class Connection(object):
|
|||
# Setting this in case the next statement fails, though
|
||||
# it shouldn't be doing any network operations, yet.
|
||||
self.connection = None
|
||||
self.connection = kombu.connection.BrokerConnection(**params)
|
||||
self.connection = kombu.connection.BrokerConnection(**broker)
|
||||
self.connection_errors = self.connection.connection_errors
|
||||
self.channel_errors = self.connection.channel_errors
|
||||
if self.memory_transport:
|
||||
|
@ -561,7 +576,7 @@ class Connection(object):
|
|||
for consumer in self.consumers:
|
||||
consumer.reconnect(self.channel)
|
||||
LOG.info(_('Connected to AMQP server on %(hostname)s:%(port)d') %
|
||||
params)
|
||||
broker)
|
||||
|
||||
def reconnect(self):
|
||||
"""Handles reconnecting and re-establishing queues.
|
||||
|
@ -574,10 +589,10 @@ class Connection(object):
|
|||
|
||||
attempt = 0
|
||||
while True:
|
||||
params = six.next(self.params_list)
|
||||
broker = six.next(self.brokers)
|
||||
attempt += 1
|
||||
try:
|
||||
self._connect(params)
|
||||
self._connect(broker)
|
||||
return
|
||||
except IOError as e:
|
||||
pass
|
||||
|
@ -596,7 +611,7 @@ class Connection(object):
|
|||
log_info = {}
|
||||
log_info['err_str'] = e
|
||||
log_info['max_retries'] = self.max_retries
|
||||
log_info.update(params)
|
||||
log_info.update(broker)
|
||||
|
||||
if self.max_retries and attempt == self.max_retries:
|
||||
msg = _('Unable to connect to AMQP server on '
|
||||
|
@ -774,7 +789,7 @@ class RabbitDriver(amqpdriver.AMQPDriverBase):
|
|||
conf.register_opts(rabbit_opts)
|
||||
conf.register_opts(rpc_amqp.amqp_opts)
|
||||
|
||||
connection_pool = rpc_amqp.get_connection_pool(conf, Connection)
|
||||
connection_pool = rpc_amqp.get_connection_pool(conf, url, Connection)
|
||||
|
||||
super(RabbitDriver, self).__init__(conf, url,
|
||||
connection_pool,
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import operator
|
||||
import random
|
||||
import thread
|
||||
import threading
|
||||
|
@ -102,6 +103,65 @@ class _QpidBaseTestCase(test_utils.BaseTestCase):
|
|||
self.con_send.close()
|
||||
|
||||
|
||||
class TestQpidTransportURL(_QpidBaseTestCase):
|
||||
|
||||
scenarios = [
|
||||
('none', dict(url=None,
|
||||
expected=[dict(host='localhost:5672',
|
||||
username='',
|
||||
password='')])),
|
||||
('empty',
|
||||
dict(url='qpid:///',
|
||||
expected=[dict(host='localhost:5672',
|
||||
username='',
|
||||
password='')])),
|
||||
('localhost',
|
||||
dict(url='qpid://localhost/',
|
||||
expected=[dict(host='localhost',
|
||||
username='',
|
||||
password='')])),
|
||||
('no_creds',
|
||||
dict(url='qpid://host/',
|
||||
expected=[dict(host='host',
|
||||
username='',
|
||||
password='')])),
|
||||
('no_port',
|
||||
dict(url='qpid://user:password@host/',
|
||||
expected=[dict(host='host',
|
||||
username='user',
|
||||
password='password')])),
|
||||
('full_url',
|
||||
dict(url='qpid://user:password@host:10/',
|
||||
expected=[dict(host='host:10',
|
||||
username='user',
|
||||
password='password')])),
|
||||
('full_two_url',
|
||||
dict(url='qpid://user:password@host:10,'
|
||||
'user2:password2@host2:12/',
|
||||
expected=[dict(host='host:10',
|
||||
username='user',
|
||||
password='password'),
|
||||
dict(host='host2:12',
|
||||
username='user2',
|
||||
password='password2')
|
||||
]
|
||||
)),
|
||||
|
||||
]
|
||||
|
||||
@mock.patch.object(qpid_driver.Connection, 'reconnect')
|
||||
def test_transport_url(self, *args):
|
||||
transport = messaging.get_transport(self.conf, self.url)
|
||||
self.addCleanup(transport.cleanup)
|
||||
driver = transport._driver
|
||||
|
||||
brokers_params = driver._get_connection().brokers_params
|
||||
self.assertEqual(sorted(self.expected,
|
||||
key=operator.itemgetter('host')),
|
||||
sorted(brokers_params,
|
||||
key=operator.itemgetter('host')))
|
||||
|
||||
|
||||
class TestQpidInvalidTopologyVersion(_QpidBaseTestCase):
|
||||
"""Unit test cases to test invalid qpid topology version."""
|
||||
|
||||
|
@ -398,11 +458,12 @@ class TestQpidReconnectOrder(test_utils.BaseTestCase):
|
|||
brokers = ['host1', 'host2', 'host3', 'host4', 'host5']
|
||||
brokers_count = len(brokers)
|
||||
|
||||
self.messaging_conf.conf.qpid_hosts = brokers
|
||||
self.config(qpid_hosts=brokers)
|
||||
|
||||
with mock.patch('qpid.messaging.Connection') as conn_mock:
|
||||
# starting from the first broker in the list
|
||||
connection = qpid_driver.Connection(self.messaging_conf.conf)
|
||||
url = messaging.TransportURL.parse(self.conf, None)
|
||||
connection = qpid_driver.Connection(self.conf, url)
|
||||
|
||||
# reconnect will advance to the next broker, one broker per
|
||||
# attempt, and then wrap to the start of the list once the end is
|
||||
|
@ -412,7 +473,7 @@ class TestQpidReconnectOrder(test_utils.BaseTestCase):
|
|||
|
||||
expected = []
|
||||
for broker in brokers:
|
||||
expected.extend([mock.call(broker),
|
||||
expected.extend([mock.call("%s:5672" % broker),
|
||||
mock.call().open(),
|
||||
mock.call().session(),
|
||||
mock.call().opened(),
|
||||
|
@ -601,6 +662,9 @@ class FakeQpidSession(object):
|
|||
key = slash_split[-1]
|
||||
return key.strip()
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
_fake_session = FakeQpidSession()
|
||||
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# under the License.
|
||||
|
||||
import datetime
|
||||
import operator
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
|
@ -46,74 +47,88 @@ class TestRabbitDriverLoad(test_utils.BaseTestCase):
|
|||
class TestRabbitTransportURL(test_utils.BaseTestCase):
|
||||
|
||||
scenarios = [
|
||||
('none', dict(url=None, expected=None)),
|
||||
('none', dict(url=None,
|
||||
expected=[dict(hostname='localhost',
|
||||
port=5672,
|
||||
userid='guest',
|
||||
password='guest',
|
||||
virtual_host='/')])),
|
||||
('empty',
|
||||
dict(url='rabbit:///',
|
||||
expected=dict(virtual_host=''))),
|
||||
expected=[dict(hostname='localhost',
|
||||
port=5672,
|
||||
userid='guest',
|
||||
password='guest',
|
||||
virtual_host='')])),
|
||||
('localhost',
|
||||
dict(url='rabbit://localhost/',
|
||||
expected=dict(hostname='localhost',
|
||||
username='',
|
||||
password='',
|
||||
virtual_host=''))),
|
||||
expected=[dict(hostname='localhost',
|
||||
port=5672,
|
||||
userid='',
|
||||
password='',
|
||||
virtual_host='')])),
|
||||
('virtual_host',
|
||||
dict(url='rabbit:///vhost',
|
||||
expected=dict(virtual_host='vhost'))),
|
||||
expected=[dict(hostname='localhost',
|
||||
port=5672,
|
||||
userid='guest',
|
||||
password='guest',
|
||||
virtual_host='vhost')])),
|
||||
('no_creds',
|
||||
dict(url='rabbit://host/virtual_host',
|
||||
expected=dict(hostname='host',
|
||||
username='',
|
||||
password='',
|
||||
virtual_host='virtual_host'))),
|
||||
expected=[dict(hostname='host',
|
||||
port=5672,
|
||||
userid='',
|
||||
password='',
|
||||
virtual_host='virtual_host')])),
|
||||
('no_port',
|
||||
dict(url='rabbit://user:password@host/virtual_host',
|
||||
expected=dict(hostname='host',
|
||||
username='user',
|
||||
password='password',
|
||||
virtual_host='virtual_host'))),
|
||||
expected=[dict(hostname='host',
|
||||
port=5672,
|
||||
userid='user',
|
||||
password='password',
|
||||
virtual_host='virtual_host')])),
|
||||
('full_url',
|
||||
dict(url='rabbit://user:password@host:10/virtual_host',
|
||||
expected=dict(hostname='host',
|
||||
port=10,
|
||||
username='user',
|
||||
password='password',
|
||||
virtual_host='virtual_host'))),
|
||||
expected=[dict(hostname='host',
|
||||
port=10,
|
||||
userid='user',
|
||||
password='password',
|
||||
virtual_host='virtual_host')])),
|
||||
('full_two_url',
|
||||
dict(url='rabbit://user:password@host:10,'
|
||||
'user2:password2@host2:12/virtual_host',
|
||||
expected=[dict(hostname='host',
|
||||
port=10,
|
||||
userid='user',
|
||||
password='password',
|
||||
virtual_host='virtual_host'),
|
||||
dict(hostname='host2',
|
||||
port=12,
|
||||
userid='user2',
|
||||
password='password2',
|
||||
virtual_host='virtual_host')
|
||||
]
|
||||
)),
|
||||
|
||||
]
|
||||
|
||||
def setUp(self):
|
||||
super(TestRabbitTransportURL, self).setUp()
|
||||
|
||||
self.messaging_conf.transport_driver = 'rabbit'
|
||||
def test_transport_url(self):
|
||||
self.messaging_conf.in_memory = True
|
||||
|
||||
self._server_params = []
|
||||
cnx_init = rabbit_driver.Connection.__init__
|
||||
transport = messaging.get_transport(self.conf, self.url)
|
||||
self.addCleanup(transport.cleanup)
|
||||
driver = transport._driver
|
||||
|
||||
def record_params(cnx, conf, server_params=None):
|
||||
self._server_params.append(server_params)
|
||||
return cnx_init(cnx, conf, server_params)
|
||||
brokers_params = driver._get_connection().brokers_params[:]
|
||||
brokers_params = [dict((k, v) for k, v in broker.items()
|
||||
if k not in ['transport', 'login_method'])
|
||||
for broker in brokers_params]
|
||||
|
||||
def dummy_send(cnx, topic, msg, timeout=None):
|
||||
pass
|
||||
|
||||
self.stubs.Set(rabbit_driver.Connection, '__init__', record_params)
|
||||
self.stubs.Set(rabbit_driver.Connection, 'topic_send', dummy_send)
|
||||
|
||||
self._driver = messaging.get_transport(self.conf, self.url)._driver
|
||||
self._target = messaging.Target(topic='testtopic')
|
||||
|
||||
def test_transport_url_listen(self):
|
||||
self._driver.listen(self._target)
|
||||
self.assertEqual(self.expected, self._server_params[0])
|
||||
|
||||
def test_transport_url_listen_for_notification(self):
|
||||
self._driver.listen_for_notifications(
|
||||
[(messaging.Target(topic='topic'), 'info')])
|
||||
self.assertEqual(self.expected, self._server_params[0])
|
||||
|
||||
def test_transport_url_send(self):
|
||||
self._driver.send(self._target, {}, {})
|
||||
self.assertEqual(self.expected, self._server_params[0])
|
||||
self.assertEqual(sorted(self.expected,
|
||||
key=operator.itemgetter('hostname')),
|
||||
sorted(brokers_params,
|
||||
key=operator.itemgetter('hostname')))
|
||||
|
||||
|
||||
class TestSendReceive(test_utils.BaseTestCase):
|
||||
|
@ -619,8 +634,8 @@ class RpcKombuHATestCase(test_utils.BaseTestCase):
|
|||
brokers = ['host1', 'host2', 'host3', 'host4', 'host5']
|
||||
brokers_count = len(brokers)
|
||||
|
||||
self.conf.rabbit_hosts = brokers
|
||||
self.conf.rabbit_max_retries = 1
|
||||
self.config(rabbit_hosts=brokers,
|
||||
rabbit_max_retries=1)
|
||||
|
||||
hostname_sets = set()
|
||||
|
||||
|
@ -639,7 +654,8 @@ class RpcKombuHATestCase(test_utils.BaseTestCase):
|
|||
self.stubs.Set(rabbit_driver.Connection, '_connect', _connect)
|
||||
|
||||
# starting from the first broker in the list
|
||||
connection = rabbit_driver.Connection(self.conf)
|
||||
url = messaging.TransportURL.parse(self.conf, None)
|
||||
connection = rabbit_driver.Connection(self.conf, url)
|
||||
|
||||
# now that we have connection object, revert to the real 'connect'
|
||||
# implementation
|
||||
|
|
Loading…
Reference in New Issue