Full support of multiple hosts in transport url

This patch add the support of multiple hosts in transport url for rabbit and
qpid drivers. And also fix the amqp connection pool management to allow
to have one pool by transport.

Implements blueprint multiple-hosts-support-in-url

Co-Authored-By: Ala Rezmerita <ala.rezmerita@cloudwatt.com>
Change-Id: I5aff24d292b67a7b65e33e7083e245efbbe82024
This commit is contained in:
Mehdi Abaakouk 2014-03-07 10:46:17 +01:00
parent 06ab616d8f
commit 53b9d741a8
6 changed files with 238 additions and 158 deletions

View File

@ -59,16 +59,17 @@ LOG = logging.getLogger(__name__)
class ConnectionPool(pool.Pool):
"""Class that implements a Pool of Connections."""
def __init__(self, conf, connection_cls):
def __init__(self, conf, url, connection_cls):
self.connection_cls = connection_cls
self.conf = conf
self.url = url
super(ConnectionPool, self).__init__(self.conf.rpc_conn_pool_size)
self.reply_proxy = None
# TODO(comstud): Timeout connections not used in a while
def create(self):
LOG.debug(_('Pool creating new connection'))
return self.connection_cls(self.conf)
return self.connection_cls(self.conf, self.url)
def empty(self):
for item in self.iter_free():
@ -82,18 +83,19 @@ class ConnectionPool(pool.Pool):
# time code, it gets here via cleanup() and only appears in service.py
# just before doing a sys.exit(), so cleanup() only happens once and
# the leakage is not a problem.
self.connection_cls.pool = None
del self.connection_cls.pools[self.url]
_pool_create_sem = threading.Lock()
def get_connection_pool(conf, connection_cls):
def get_connection_pool(conf, url, connection_cls):
with _pool_create_sem:
# Make sure only one thread tries to create the connection pool.
if not connection_cls.pool:
connection_cls.pool = ConnectionPool(conf, connection_cls)
return connection_cls.pool
if url not in connection_cls.pools:
connection_cls.pools[url] = ConnectionPool(conf, url,
connection_cls)
return connection_cls.pools[url]
class ConnectionContext(rpc_common.Connection):
@ -108,17 +110,16 @@ class ConnectionContext(rpc_common.Connection):
If possible the function makes sure to return a connection to the pool.
"""
def __init__(self, conf, connection_pool, pooled=True, server_params=None):
def __init__(self, conf, url, connection_pool, pooled=True):
"""Create a new connection, or get one from the pool."""
self.connection = None
self.conf = conf
self.url = url
self.connection_pool = connection_pool
if pooled:
self.connection = connection_pool.get()
else:
self.connection = connection_pool.connection_cls(
conf,
server_params=server_params)
self.connection = connection_pool.connection_cls(conf, url)
self.pooled = pooled
def __enter__(self):

View File

@ -295,8 +295,6 @@ class AMQPDriverBase(base.BaseDriver):
super(AMQPDriverBase, self).__init__(conf, url, default_exchange,
allowed_remote_exmods)
self._server_params = self._server_params_from_url(self._url)
self._default_exchange = default_exchange
# FIXME(markmc): temp hack
@ -310,35 +308,11 @@ class AMQPDriverBase(base.BaseDriver):
self._reply_q_conn = None
self._waiter = None
def _server_params_from_url(self, url):
sp = {}
if url.virtual_host is not None:
sp['virtual_host'] = url.virtual_host
if url.hosts:
# FIXME(markmc): support multiple hosts
host = url.hosts[0]
sp['hostname'] = host.hostname
if host.port is not None:
sp['port'] = host.port
sp['username'] = host.username or ''
sp['password'] = host.password or ''
return sp
def _get_connection(self, pooled=True):
# FIXME(markmc): we don't yet have a connection pool for each
# Transport instance, so we'll only use the pool with the
# transport configuration from the config file
server_params = self._server_params or None
if server_params:
pooled = False
return rpc_amqp.ConnectionContext(self.conf,
self._url,
self._connection_pool,
pooled=pooled,
server_params=server_params)
pooled=pooled)
def _get_reply_q(self):
with self._reply_q_lock:

View File

@ -27,6 +27,7 @@ from oslo.messaging._drivers import amqpdriver
from oslo.messaging._drivers import common as rpc_common
from oslo.messaging.openstack.common import importutils
from oslo.messaging.openstack.common import jsonutils
from oslo.messaging.openstack.common import network_utils
# FIXME(markmc): remove this
_ = lambda s: s
@ -449,9 +450,9 @@ class NotifyPublisher(Publisher):
class Connection(object):
"""Connection object."""
pool = None
pools = {}
def __init__(self, conf, server_params=None):
def __init__(self, conf, url):
if not qpid_messaging:
raise ImportError("Failed to import qpid.messaging")
@ -460,35 +461,44 @@ class Connection(object):
self.consumers = {}
self.conf = conf
if server_params and 'hostname' in server_params:
# NOTE(russellb) This enables support for cast_to_server.
server_params['qpid_hosts'] = [
'%s:%d' % (server_params['hostname'],
server_params.get('port', 5672))
]
self.brokers_params = []
if url.hosts:
for host in url.hosts:
params = {
'username': host.username or '',
'password': host.password or '',
}
if host.port is not None:
params['host'] = '%s:%d' % (host.hostname, host.port)
else:
params['host'] = host.hostname
self.brokers_params.append(params)
else:
# Old configuration format
for adr in self.conf.qpid_hosts:
hostname, port = network_utils.parse_host_port(
adr, default_port=5672)
params = {
'qpid_hosts': self.conf.qpid_hosts[:],
'username': self.conf.qpid_username,
'password': self.conf.qpid_password,
}
params.update(server_params or {})
params = {
'host': '%s:%d' % (hostname, port),
'username': self.conf.qpid_username,
'password': self.conf.qpid_password,
}
self.brokers_params.append(params)
random.shuffle(params['qpid_hosts'])
self.brokers = itertools.cycle(params['qpid_hosts'])
random.shuffle(self.brokers_params)
self.brokers = itertools.cycle(self.brokers_params)
self.username = params['username']
self.password = params['password']
self.reconnect()
def connection_create(self, broker):
# Create the connection - this does not open the connection
self.connection = qpid_messaging.Connection(broker)
self.connection = qpid_messaging.Connection(broker['host'])
# Check if flags are set and if so set them for the connection
# before we call open
self.connection.username = self.username
self.connection.password = self.password
self.connection.username = broker['username']
self.connection.password = broker['password']
self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms
# Reconnection is done by self.reconnect()
@ -520,14 +530,14 @@ class Connection(object):
self.connection_create(broker)
self.connection.open()
except qpid_exceptions.MessagingError as e:
msg_dict = dict(e=e, delay=delay)
msg = _("Unable to connect to AMQP server: %(e)s. "
"Sleeping %(delay)s seconds") % msg_dict
msg_dict = dict(e=e, delay=delay, broker=broker['host'])
msg = _("Unable to connect to AMQP server on %(broker)s: "
"%(e)s. Sleeping %(delay)s seconds") % msg_dict
LOG.error(msg)
time.sleep(delay)
delay = min(delay + 1, 5)
else:
LOG.info(_('Connected to AMQP server on %s'), broker)
LOG.info(_('Connected to AMQP server on %s'), broker['host'])
break
self.session = self.connection.session()
@ -687,7 +697,7 @@ class QpidDriver(amqpdriver.AMQPDriverBase):
conf.register_opts(qpid_opts)
conf.register_opts(rpc_amqp.amqp_opts)
connection_pool = rpc_amqp.get_connection_pool(conf, Connection)
connection_pool = rpc_amqp.get_connection_pool(conf, url, Connection)
super(QpidDriver, self).__init__(conf, url,
connection_pool,

View File

@ -421,9 +421,9 @@ class NotifyPublisher(TopicPublisher):
class Connection(object):
"""Connection object."""
pool = None
pools = {}
def __init__(self, conf, server_params=None):
def __init__(self, conf, url):
self.consumers = []
self.conf = conf
self.max_retries = self.conf.rabbit_max_retries
@ -436,39 +436,54 @@ class Connection(object):
self.interval_max = 30
self.memory_transport = False
if server_params is None:
server_params = {}
# Keys to translate from server_params to kombu params
server_params_to_kombu_params = {'username': 'userid'}
ssl_params = self._fetch_ssl_params()
params_list = []
for adr in self.conf.rabbit_hosts:
hostname, port = network_utils.parse_host_port(
adr, default_port=self.conf.rabbit_port)
params = {
'hostname': hostname,
'port': port,
'userid': self.conf.rabbit_userid,
'password': self.conf.rabbit_password,
'login_method': self.conf.rabbit_login_method,
'virtual_host': self.conf.rabbit_virtual_host,
}
if url.virtual_host is not None:
virtual_host = url.virtual_host
else:
virtual_host = self.conf.rabbit_virtual_host
for sp_key, value in six.iteritems(server_params):
p_key = server_params_to_kombu_params.get(sp_key, sp_key)
params[p_key] = value
self.brokers_params = []
if url.hosts:
for host in url.hosts:
params = {
'hostname': host.hostname,
'port': host.port or 5672,
'userid': host.username or '',
'password': host.password or '',
'login_method': self.conf.rabbit_login_method,
'virtual_host': virtual_host
}
if self.conf.fake_rabbit:
params['transport'] = 'memory'
if self.conf.rabbit_use_ssl:
params['ssl'] = ssl_params
if self.conf.fake_rabbit:
params['transport'] = 'memory'
if self.conf.rabbit_use_ssl:
params['ssl'] = ssl_params
self.brokers_params.append(params)
else:
# Old configuration format
for adr in self.conf.rabbit_hosts:
hostname, port = network_utils.parse_host_port(
adr, default_port=self.conf.rabbit_port)
params_list.append(params)
params = {
'hostname': hostname,
'port': port,
'userid': self.conf.rabbit_userid,
'password': self.conf.rabbit_password,
'login_method': self.conf.rabbit_login_method,
'virtual_host': virtual_host
}
random.shuffle(params_list)
self.params_list = itertools.cycle(params_list)
if self.conf.fake_rabbit:
params['transport'] = 'memory'
if self.conf.rabbit_use_ssl:
params['ssl'] = ssl_params
self.brokers_params.append(params)
random.shuffle(self.brokers_params)
self.brokers = itertools.cycle(self.brokers_params)
self.memory_transport = self.conf.fake_rabbit
@ -519,14 +534,14 @@ class Connection(object):
# Return the extended behavior or just have the default behavior
return ssl_params or True
def _connect(self, params):
def _connect(self, broker):
"""Connect to rabbit. Re-establish any queues that may have
been declared before if we are reconnecting. Exceptions should
be handled by the caller.
"""
if self.connection:
LOG.info(_("Reconnecting to AMQP server on "
"%(hostname)s:%(port)d") % params)
"%(hostname)s:%(port)d") % broker)
try:
# XXX(nic): when reconnecting to a RabbitMQ cluster
# with mirrored queues in use, the attempt to release the
@ -545,7 +560,7 @@ class Connection(object):
# Setting this in case the next statement fails, though
# it shouldn't be doing any network operations, yet.
self.connection = None
self.connection = kombu.connection.BrokerConnection(**params)
self.connection = kombu.connection.BrokerConnection(**broker)
self.connection_errors = self.connection.connection_errors
self.channel_errors = self.connection.channel_errors
if self.memory_transport:
@ -561,7 +576,7 @@ class Connection(object):
for consumer in self.consumers:
consumer.reconnect(self.channel)
LOG.info(_('Connected to AMQP server on %(hostname)s:%(port)d') %
params)
broker)
def reconnect(self):
"""Handles reconnecting and re-establishing queues.
@ -574,10 +589,10 @@ class Connection(object):
attempt = 0
while True:
params = six.next(self.params_list)
broker = six.next(self.brokers)
attempt += 1
try:
self._connect(params)
self._connect(broker)
return
except IOError as e:
pass
@ -596,7 +611,7 @@ class Connection(object):
log_info = {}
log_info['err_str'] = str(e)
log_info['max_retries'] = self.max_retries
log_info.update(params)
log_info.update(broker)
if self.max_retries and attempt == self.max_retries:
msg = _('Unable to connect to AMQP server on '
@ -775,7 +790,7 @@ class RabbitDriver(amqpdriver.AMQPDriverBase):
conf.register_opts(rabbit_opts)
conf.register_opts(rpc_amqp.amqp_opts)
connection_pool = rpc_amqp.get_connection_pool(conf, Connection)
connection_pool = rpc_amqp.get_connection_pool(conf, url, Connection)
super(RabbitDriver, self).__init__(conf, url,
connection_pool,

View File

@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import operator
import random
import thread
import threading
@ -102,6 +103,65 @@ class _QpidBaseTestCase(test_utils.BaseTestCase):
self.con_send.close()
class TestQpidTransportURL(_QpidBaseTestCase):
scenarios = [
('none', dict(url=None,
expected=[dict(host='localhost:5672',
username='',
password='')])),
('empty',
dict(url='qpid:///',
expected=[dict(host='localhost:5672',
username='',
password='')])),
('localhost',
dict(url='qpid://localhost/',
expected=[dict(host='localhost',
username='',
password='')])),
('no_creds',
dict(url='qpid://host/',
expected=[dict(host='host',
username='',
password='')])),
('no_port',
dict(url='qpid://user:password@host/',
expected=[dict(host='host',
username='user',
password='password')])),
('full_url',
dict(url='qpid://user:password@host:10/',
expected=[dict(host='host:10',
username='user',
password='password')])),
('full_two_url',
dict(url='qpid://user:password@host:10,'
'user2:password2@host2:12/',
expected=[dict(host='host:10',
username='user',
password='password'),
dict(host='host2:12',
username='user2',
password='password2')
]
)),
]
@mock.patch.object(qpid_driver.Connection, 'reconnect')
def test_transport_url(self, *args):
transport = messaging.get_transport(self.conf, self.url)
self.addCleanup(transport.cleanup)
driver = transport._driver
brokers_params = driver._get_connection().brokers_params
self.assertEqual(sorted(self.expected,
key=operator.itemgetter('host')),
sorted(brokers_params,
key=operator.itemgetter('host')))
class TestQpidInvalidTopologyVersion(_QpidBaseTestCase):
"""Unit test cases to test invalid qpid topology version."""
@ -398,11 +458,12 @@ class TestQpidReconnectOrder(test_utils.BaseTestCase):
brokers = ['host1', 'host2', 'host3', 'host4', 'host5']
brokers_count = len(brokers)
self.messaging_conf.conf.qpid_hosts = brokers
self.config(qpid_hosts=brokers)
with mock.patch('qpid.messaging.Connection') as conn_mock:
# starting from the first broker in the list
connection = qpid_driver.Connection(self.messaging_conf.conf)
url = messaging.TransportURL.parse(self.conf, None)
connection = qpid_driver.Connection(self.conf, url)
# reconnect will advance to the next broker, one broker per
# attempt, and then wrap to the start of the list once the end is
@ -412,7 +473,7 @@ class TestQpidReconnectOrder(test_utils.BaseTestCase):
expected = []
for broker in brokers:
expected.extend([mock.call(broker),
expected.extend([mock.call("%s:5672" % broker),
mock.call().open(),
mock.call().session(),
mock.call().opened(),
@ -601,6 +662,9 @@ class FakeQpidSession(object):
key = slash_split[-1]
return key.strip()
def close(self):
pass
_fake_session = FakeQpidSession()

View File

@ -13,6 +13,7 @@
# under the License.
import datetime
import operator
import sys
import threading
import uuid
@ -46,74 +47,88 @@ class TestRabbitDriverLoad(test_utils.BaseTestCase):
class TestRabbitTransportURL(test_utils.BaseTestCase):
scenarios = [
('none', dict(url=None, expected=None)),
('none', dict(url=None,
expected=[dict(hostname='localhost',
port=5672,
userid='guest',
password='guest',
virtual_host='/')])),
('empty',
dict(url='rabbit:///',
expected=dict(virtual_host=''))),
expected=[dict(hostname='localhost',
port=5672,
userid='guest',
password='guest',
virtual_host='')])),
('localhost',
dict(url='rabbit://localhost/',
expected=dict(hostname='localhost',
username='',
password='',
virtual_host=''))),
expected=[dict(hostname='localhost',
port=5672,
userid='',
password='',
virtual_host='')])),
('virtual_host',
dict(url='rabbit:///vhost',
expected=dict(virtual_host='vhost'))),
expected=[dict(hostname='localhost',
port=5672,
userid='guest',
password='guest',
virtual_host='vhost')])),
('no_creds',
dict(url='rabbit://host/virtual_host',
expected=dict(hostname='host',
username='',
password='',
virtual_host='virtual_host'))),
expected=[dict(hostname='host',
port=5672,
userid='',
password='',
virtual_host='virtual_host')])),
('no_port',
dict(url='rabbit://user:password@host/virtual_host',
expected=dict(hostname='host',
username='user',
password='password',
virtual_host='virtual_host'))),
expected=[dict(hostname='host',
port=5672,
userid='user',
password='password',
virtual_host='virtual_host')])),
('full_url',
dict(url='rabbit://user:password@host:10/virtual_host',
expected=dict(hostname='host',
port=10,
username='user',
password='password',
virtual_host='virtual_host'))),
expected=[dict(hostname='host',
port=10,
userid='user',
password='password',
virtual_host='virtual_host')])),
('full_two_url',
dict(url='rabbit://user:password@host:10,'
'user2:password2@host2:12/virtual_host',
expected=[dict(hostname='host',
port=10,
userid='user',
password='password',
virtual_host='virtual_host'),
dict(hostname='host2',
port=12,
userid='user2',
password='password2',
virtual_host='virtual_host')
]
)),
]
def setUp(self):
super(TestRabbitTransportURL, self).setUp()
self.messaging_conf.transport_driver = 'rabbit'
def test_transport_url(self):
self.messaging_conf.in_memory = True
self._server_params = []
cnx_init = rabbit_driver.Connection.__init__
transport = messaging.get_transport(self.conf, self.url)
self.addCleanup(transport.cleanup)
driver = transport._driver
def record_params(cnx, conf, server_params=None):
self._server_params.append(server_params)
return cnx_init(cnx, conf, server_params)
brokers_params = driver._get_connection().brokers_params[:]
brokers_params = [dict((k, v) for k, v in broker.items()
if k not in ['transport', 'login_method'])
for broker in brokers_params]
def dummy_send(cnx, topic, msg, timeout=None):
pass
self.stubs.Set(rabbit_driver.Connection, '__init__', record_params)
self.stubs.Set(rabbit_driver.Connection, 'topic_send', dummy_send)
self._driver = messaging.get_transport(self.conf, self.url)._driver
self._target = messaging.Target(topic='testtopic')
def test_transport_url_listen(self):
self._driver.listen(self._target)
self.assertEqual(self.expected, self._server_params[0])
def test_transport_url_listen_for_notification(self):
self._driver.listen_for_notifications(
[(messaging.Target(topic='topic'), 'info')])
self.assertEqual(self.expected, self._server_params[0])
def test_transport_url_send(self):
self._driver.send(self._target, {}, {})
self.assertEqual(self.expected, self._server_params[0])
self.assertEqual(sorted(self.expected,
key=operator.itemgetter('hostname')),
sorted(brokers_params,
key=operator.itemgetter('hostname')))
class TestSendReceive(test_utils.BaseTestCase):
@ -619,8 +634,8 @@ class RpcKombuHATestCase(test_utils.BaseTestCase):
brokers = ['host1', 'host2', 'host3', 'host4', 'host5']
brokers_count = len(brokers)
self.conf.rabbit_hosts = brokers
self.conf.rabbit_max_retries = 1
self.config(rabbit_hosts=brokers,
rabbit_max_retries=1)
hostname_sets = set()
@ -639,7 +654,8 @@ class RpcKombuHATestCase(test_utils.BaseTestCase):
self.stubs.Set(rabbit_driver.Connection, '_connect', _connect)
# starting from the first broker in the list
connection = rabbit_driver.Connection(self.conf)
url = messaging.TransportURL.parse(self.conf, None)
connection = rabbit_driver.Connection(self.conf, url)
# now that we have connection object, revert to the real 'connect'
# implementation