From 53b9d741a8f5a2fa6cbaea9e23a3679a6af235a4 Mon Sep 17 00:00:00 2001 From: Mehdi Abaakouk Date: Fri, 7 Mar 2014 10:46:17 +0100 Subject: [PATCH] Full support of multiple hosts in transport url This patch add the support of multiple hosts in transport url for rabbit and qpid drivers. And also fix the amqp connection pool management to allow to have one pool by transport. Implements blueprint multiple-hosts-support-in-url Co-Authored-By: Ala Rezmerita Change-Id: I5aff24d292b67a7b65e33e7083e245efbbe82024 --- oslo/messaging/_drivers/amqp.py | 23 ++--- oslo/messaging/_drivers/amqpdriver.py | 30 +----- oslo/messaging/_drivers/impl_qpid.py | 62 +++++++------ oslo/messaging/_drivers/impl_rabbit.py | 89 ++++++++++-------- tests/test_qpid.py | 70 +++++++++++++- tests/test_rabbit.py | 122 ++++++++++++++----------- 6 files changed, 238 insertions(+), 158 deletions(-) diff --git a/oslo/messaging/_drivers/amqp.py b/oslo/messaging/_drivers/amqp.py index 74b671baf..b7ec5945a 100644 --- a/oslo/messaging/_drivers/amqp.py +++ b/oslo/messaging/_drivers/amqp.py @@ -59,16 +59,17 @@ LOG = logging.getLogger(__name__) class ConnectionPool(pool.Pool): """Class that implements a Pool of Connections.""" - def __init__(self, conf, connection_cls): + def __init__(self, conf, url, connection_cls): self.connection_cls = connection_cls self.conf = conf + self.url = url super(ConnectionPool, self).__init__(self.conf.rpc_conn_pool_size) self.reply_proxy = None # TODO(comstud): Timeout connections not used in a while def create(self): LOG.debug(_('Pool creating new connection')) - return self.connection_cls(self.conf) + return self.connection_cls(self.conf, self.url) def empty(self): for item in self.iter_free(): @@ -82,18 +83,19 @@ class ConnectionPool(pool.Pool): # time code, it gets here via cleanup() and only appears in service.py # just before doing a sys.exit(), so cleanup() only happens once and # the leakage is not a problem. - self.connection_cls.pool = None + del self.connection_cls.pools[self.url] _pool_create_sem = threading.Lock() -def get_connection_pool(conf, connection_cls): +def get_connection_pool(conf, url, connection_cls): with _pool_create_sem: # Make sure only one thread tries to create the connection pool. - if not connection_cls.pool: - connection_cls.pool = ConnectionPool(conf, connection_cls) - return connection_cls.pool + if url not in connection_cls.pools: + connection_cls.pools[url] = ConnectionPool(conf, url, + connection_cls) + return connection_cls.pools[url] class ConnectionContext(rpc_common.Connection): @@ -108,17 +110,16 @@ class ConnectionContext(rpc_common.Connection): If possible the function makes sure to return a connection to the pool. """ - def __init__(self, conf, connection_pool, pooled=True, server_params=None): + def __init__(self, conf, url, connection_pool, pooled=True): """Create a new connection, or get one from the pool.""" self.connection = None self.conf = conf + self.url = url self.connection_pool = connection_pool if pooled: self.connection = connection_pool.get() else: - self.connection = connection_pool.connection_cls( - conf, - server_params=server_params) + self.connection = connection_pool.connection_cls(conf, url) self.pooled = pooled def __enter__(self): diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py index 3a2303d39..fedbb7c4c 100644 --- a/oslo/messaging/_drivers/amqpdriver.py +++ b/oslo/messaging/_drivers/amqpdriver.py @@ -295,8 +295,6 @@ class AMQPDriverBase(base.BaseDriver): super(AMQPDriverBase, self).__init__(conf, url, default_exchange, allowed_remote_exmods) - self._server_params = self._server_params_from_url(self._url) - self._default_exchange = default_exchange # FIXME(markmc): temp hack @@ -310,35 +308,11 @@ class AMQPDriverBase(base.BaseDriver): self._reply_q_conn = None self._waiter = None - def _server_params_from_url(self, url): - sp = {} - - if url.virtual_host is not None: - sp['virtual_host'] = url.virtual_host - - if url.hosts: - # FIXME(markmc): support multiple hosts - host = url.hosts[0] - - sp['hostname'] = host.hostname - if host.port is not None: - sp['port'] = host.port - sp['username'] = host.username or '' - sp['password'] = host.password or '' - - return sp - def _get_connection(self, pooled=True): - # FIXME(markmc): we don't yet have a connection pool for each - # Transport instance, so we'll only use the pool with the - # transport configuration from the config file - server_params = self._server_params or None - if server_params: - pooled = False return rpc_amqp.ConnectionContext(self.conf, + self._url, self._connection_pool, - pooled=pooled, - server_params=server_params) + pooled=pooled) def _get_reply_q(self): with self._reply_q_lock: diff --git a/oslo/messaging/_drivers/impl_qpid.py b/oslo/messaging/_drivers/impl_qpid.py index 39f8136c1..0c169ae77 100644 --- a/oslo/messaging/_drivers/impl_qpid.py +++ b/oslo/messaging/_drivers/impl_qpid.py @@ -27,6 +27,7 @@ from oslo.messaging._drivers import amqpdriver from oslo.messaging._drivers import common as rpc_common from oslo.messaging.openstack.common import importutils from oslo.messaging.openstack.common import jsonutils +from oslo.messaging.openstack.common import network_utils # FIXME(markmc): remove this _ = lambda s: s @@ -449,9 +450,9 @@ class NotifyPublisher(Publisher): class Connection(object): """Connection object.""" - pool = None + pools = {} - def __init__(self, conf, server_params=None): + def __init__(self, conf, url): if not qpid_messaging: raise ImportError("Failed to import qpid.messaging") @@ -460,35 +461,44 @@ class Connection(object): self.consumers = {} self.conf = conf - if server_params and 'hostname' in server_params: - # NOTE(russellb) This enables support for cast_to_server. - server_params['qpid_hosts'] = [ - '%s:%d' % (server_params['hostname'], - server_params.get('port', 5672)) - ] + self.brokers_params = [] + if url.hosts: + for host in url.hosts: + params = { + 'username': host.username or '', + 'password': host.password or '', + } + if host.port is not None: + params['host'] = '%s:%d' % (host.hostname, host.port) + else: + params['host'] = host.hostname + self.brokers_params.append(params) + else: + # Old configuration format + for adr in self.conf.qpid_hosts: + hostname, port = network_utils.parse_host_port( + adr, default_port=5672) - params = { - 'qpid_hosts': self.conf.qpid_hosts[:], - 'username': self.conf.qpid_username, - 'password': self.conf.qpid_password, - } - params.update(server_params or {}) + params = { + 'host': '%s:%d' % (hostname, port), + 'username': self.conf.qpid_username, + 'password': self.conf.qpid_password, + } + self.brokers_params.append(params) - random.shuffle(params['qpid_hosts']) - self.brokers = itertools.cycle(params['qpid_hosts']) + random.shuffle(self.brokers_params) + self.brokers = itertools.cycle(self.brokers_params) - self.username = params['username'] - self.password = params['password'] self.reconnect() def connection_create(self, broker): # Create the connection - this does not open the connection - self.connection = qpid_messaging.Connection(broker) + self.connection = qpid_messaging.Connection(broker['host']) # Check if flags are set and if so set them for the connection # before we call open - self.connection.username = self.username - self.connection.password = self.password + self.connection.username = broker['username'] + self.connection.password = broker['password'] self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms # Reconnection is done by self.reconnect() @@ -520,14 +530,14 @@ class Connection(object): self.connection_create(broker) self.connection.open() except qpid_exceptions.MessagingError as e: - msg_dict = dict(e=e, delay=delay) - msg = _("Unable to connect to AMQP server: %(e)s. " - "Sleeping %(delay)s seconds") % msg_dict + msg_dict = dict(e=e, delay=delay, broker=broker['host']) + msg = _("Unable to connect to AMQP server on %(broker)s: " + "%(e)s. Sleeping %(delay)s seconds") % msg_dict LOG.error(msg) time.sleep(delay) delay = min(delay + 1, 5) else: - LOG.info(_('Connected to AMQP server on %s'), broker) + LOG.info(_('Connected to AMQP server on %s'), broker['host']) break self.session = self.connection.session() @@ -687,7 +697,7 @@ class QpidDriver(amqpdriver.AMQPDriverBase): conf.register_opts(qpid_opts) conf.register_opts(rpc_amqp.amqp_opts) - connection_pool = rpc_amqp.get_connection_pool(conf, Connection) + connection_pool = rpc_amqp.get_connection_pool(conf, url, Connection) super(QpidDriver, self).__init__(conf, url, connection_pool, diff --git a/oslo/messaging/_drivers/impl_rabbit.py b/oslo/messaging/_drivers/impl_rabbit.py index b2cf460b9..7c5d1b3fa 100644 --- a/oslo/messaging/_drivers/impl_rabbit.py +++ b/oslo/messaging/_drivers/impl_rabbit.py @@ -421,9 +421,9 @@ class NotifyPublisher(TopicPublisher): class Connection(object): """Connection object.""" - pool = None + pools = {} - def __init__(self, conf, server_params=None): + def __init__(self, conf, url): self.consumers = [] self.conf = conf self.max_retries = self.conf.rabbit_max_retries @@ -436,39 +436,54 @@ class Connection(object): self.interval_max = 30 self.memory_transport = False - if server_params is None: - server_params = {} - # Keys to translate from server_params to kombu params - server_params_to_kombu_params = {'username': 'userid'} - ssl_params = self._fetch_ssl_params() - params_list = [] - for adr in self.conf.rabbit_hosts: - hostname, port = network_utils.parse_host_port( - adr, default_port=self.conf.rabbit_port) - params = { - 'hostname': hostname, - 'port': port, - 'userid': self.conf.rabbit_userid, - 'password': self.conf.rabbit_password, - 'login_method': self.conf.rabbit_login_method, - 'virtual_host': self.conf.rabbit_virtual_host, - } + if url.virtual_host is not None: + virtual_host = url.virtual_host + else: + virtual_host = self.conf.rabbit_virtual_host - for sp_key, value in six.iteritems(server_params): - p_key = server_params_to_kombu_params.get(sp_key, sp_key) - params[p_key] = value + self.brokers_params = [] + if url.hosts: + for host in url.hosts: + params = { + 'hostname': host.hostname, + 'port': host.port or 5672, + 'userid': host.username or '', + 'password': host.password or '', + 'login_method': self.conf.rabbit_login_method, + 'virtual_host': virtual_host + } + if self.conf.fake_rabbit: + params['transport'] = 'memory' + if self.conf.rabbit_use_ssl: + params['ssl'] = ssl_params - if self.conf.fake_rabbit: - params['transport'] = 'memory' - if self.conf.rabbit_use_ssl: - params['ssl'] = ssl_params + self.brokers_params.append(params) + else: + # Old configuration format + for adr in self.conf.rabbit_hosts: + hostname, port = network_utils.parse_host_port( + adr, default_port=self.conf.rabbit_port) - params_list.append(params) + params = { + 'hostname': hostname, + 'port': port, + 'userid': self.conf.rabbit_userid, + 'password': self.conf.rabbit_password, + 'login_method': self.conf.rabbit_login_method, + 'virtual_host': virtual_host + } - random.shuffle(params_list) - self.params_list = itertools.cycle(params_list) + if self.conf.fake_rabbit: + params['transport'] = 'memory' + if self.conf.rabbit_use_ssl: + params['ssl'] = ssl_params + + self.brokers_params.append(params) + + random.shuffle(self.brokers_params) + self.brokers = itertools.cycle(self.brokers_params) self.memory_transport = self.conf.fake_rabbit @@ -519,14 +534,14 @@ class Connection(object): # Return the extended behavior or just have the default behavior return ssl_params or True - def _connect(self, params): + def _connect(self, broker): """Connect to rabbit. Re-establish any queues that may have been declared before if we are reconnecting. Exceptions should be handled by the caller. """ if self.connection: LOG.info(_("Reconnecting to AMQP server on " - "%(hostname)s:%(port)d") % params) + "%(hostname)s:%(port)d") % broker) try: # XXX(nic): when reconnecting to a RabbitMQ cluster # with mirrored queues in use, the attempt to release the @@ -545,7 +560,7 @@ class Connection(object): # Setting this in case the next statement fails, though # it shouldn't be doing any network operations, yet. self.connection = None - self.connection = kombu.connection.BrokerConnection(**params) + self.connection = kombu.connection.BrokerConnection(**broker) self.connection_errors = self.connection.connection_errors self.channel_errors = self.connection.channel_errors if self.memory_transport: @@ -561,7 +576,7 @@ class Connection(object): for consumer in self.consumers: consumer.reconnect(self.channel) LOG.info(_('Connected to AMQP server on %(hostname)s:%(port)d') % - params) + broker) def reconnect(self): """Handles reconnecting and re-establishing queues. @@ -574,10 +589,10 @@ class Connection(object): attempt = 0 while True: - params = six.next(self.params_list) + broker = six.next(self.brokers) attempt += 1 try: - self._connect(params) + self._connect(broker) return except IOError as e: pass @@ -596,7 +611,7 @@ class Connection(object): log_info = {} log_info['err_str'] = str(e) log_info['max_retries'] = self.max_retries - log_info.update(params) + log_info.update(broker) if self.max_retries and attempt == self.max_retries: msg = _('Unable to connect to AMQP server on ' @@ -775,7 +790,7 @@ class RabbitDriver(amqpdriver.AMQPDriverBase): conf.register_opts(rabbit_opts) conf.register_opts(rpc_amqp.amqp_opts) - connection_pool = rpc_amqp.get_connection_pool(conf, Connection) + connection_pool = rpc_amqp.get_connection_pool(conf, url, Connection) super(RabbitDriver, self).__init__(conf, url, connection_pool, diff --git a/tests/test_qpid.py b/tests/test_qpid.py index 419d9dd15..23145518a 100644 --- a/tests/test_qpid.py +++ b/tests/test_qpid.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. +import operator import random import thread import threading @@ -102,6 +103,65 @@ class _QpidBaseTestCase(test_utils.BaseTestCase): self.con_send.close() +class TestQpidTransportURL(_QpidBaseTestCase): + + scenarios = [ + ('none', dict(url=None, + expected=[dict(host='localhost:5672', + username='', + password='')])), + ('empty', + dict(url='qpid:///', + expected=[dict(host='localhost:5672', + username='', + password='')])), + ('localhost', + dict(url='qpid://localhost/', + expected=[dict(host='localhost', + username='', + password='')])), + ('no_creds', + dict(url='qpid://host/', + expected=[dict(host='host', + username='', + password='')])), + ('no_port', + dict(url='qpid://user:password@host/', + expected=[dict(host='host', + username='user', + password='password')])), + ('full_url', + dict(url='qpid://user:password@host:10/', + expected=[dict(host='host:10', + username='user', + password='password')])), + ('full_two_url', + dict(url='qpid://user:password@host:10,' + 'user2:password2@host2:12/', + expected=[dict(host='host:10', + username='user', + password='password'), + dict(host='host2:12', + username='user2', + password='password2') + ] + )), + + ] + + @mock.patch.object(qpid_driver.Connection, 'reconnect') + def test_transport_url(self, *args): + transport = messaging.get_transport(self.conf, self.url) + self.addCleanup(transport.cleanup) + driver = transport._driver + + brokers_params = driver._get_connection().brokers_params + self.assertEqual(sorted(self.expected, + key=operator.itemgetter('host')), + sorted(brokers_params, + key=operator.itemgetter('host'))) + + class TestQpidInvalidTopologyVersion(_QpidBaseTestCase): """Unit test cases to test invalid qpid topology version.""" @@ -398,11 +458,12 @@ class TestQpidReconnectOrder(test_utils.BaseTestCase): brokers = ['host1', 'host2', 'host3', 'host4', 'host5'] brokers_count = len(brokers) - self.messaging_conf.conf.qpid_hosts = brokers + self.config(qpid_hosts=brokers) with mock.patch('qpid.messaging.Connection') as conn_mock: # starting from the first broker in the list - connection = qpid_driver.Connection(self.messaging_conf.conf) + url = messaging.TransportURL.parse(self.conf, None) + connection = qpid_driver.Connection(self.conf, url) # reconnect will advance to the next broker, one broker per # attempt, and then wrap to the start of the list once the end is @@ -412,7 +473,7 @@ class TestQpidReconnectOrder(test_utils.BaseTestCase): expected = [] for broker in brokers: - expected.extend([mock.call(broker), + expected.extend([mock.call("%s:5672" % broker), mock.call().open(), mock.call().session(), mock.call().opened(), @@ -601,6 +662,9 @@ class FakeQpidSession(object): key = slash_split[-1] return key.strip() + def close(self): + pass + _fake_session = FakeQpidSession() diff --git a/tests/test_rabbit.py b/tests/test_rabbit.py index f08d57e38..d42a0f507 100644 --- a/tests/test_rabbit.py +++ b/tests/test_rabbit.py @@ -13,6 +13,7 @@ # under the License. import datetime +import operator import sys import threading import uuid @@ -46,74 +47,88 @@ class TestRabbitDriverLoad(test_utils.BaseTestCase): class TestRabbitTransportURL(test_utils.BaseTestCase): scenarios = [ - ('none', dict(url=None, expected=None)), + ('none', dict(url=None, + expected=[dict(hostname='localhost', + port=5672, + userid='guest', + password='guest', + virtual_host='/')])), ('empty', dict(url='rabbit:///', - expected=dict(virtual_host=''))), + expected=[dict(hostname='localhost', + port=5672, + userid='guest', + password='guest', + virtual_host='')])), ('localhost', dict(url='rabbit://localhost/', - expected=dict(hostname='localhost', - username='', - password='', - virtual_host=''))), + expected=[dict(hostname='localhost', + port=5672, + userid='', + password='', + virtual_host='')])), ('virtual_host', dict(url='rabbit:///vhost', - expected=dict(virtual_host='vhost'))), + expected=[dict(hostname='localhost', + port=5672, + userid='guest', + password='guest', + virtual_host='vhost')])), ('no_creds', dict(url='rabbit://host/virtual_host', - expected=dict(hostname='host', - username='', - password='', - virtual_host='virtual_host'))), + expected=[dict(hostname='host', + port=5672, + userid='', + password='', + virtual_host='virtual_host')])), ('no_port', dict(url='rabbit://user:password@host/virtual_host', - expected=dict(hostname='host', - username='user', - password='password', - virtual_host='virtual_host'))), + expected=[dict(hostname='host', + port=5672, + userid='user', + password='password', + virtual_host='virtual_host')])), ('full_url', dict(url='rabbit://user:password@host:10/virtual_host', - expected=dict(hostname='host', - port=10, - username='user', - password='password', - virtual_host='virtual_host'))), + expected=[dict(hostname='host', + port=10, + userid='user', + password='password', + virtual_host='virtual_host')])), + ('full_two_url', + dict(url='rabbit://user:password@host:10,' + 'user2:password2@host2:12/virtual_host', + expected=[dict(hostname='host', + port=10, + userid='user', + password='password', + virtual_host='virtual_host'), + dict(hostname='host2', + port=12, + userid='user2', + password='password2', + virtual_host='virtual_host') + ] + )), + ] - def setUp(self): - super(TestRabbitTransportURL, self).setUp() - - self.messaging_conf.transport_driver = 'rabbit' + def test_transport_url(self): self.messaging_conf.in_memory = True - self._server_params = [] - cnx_init = rabbit_driver.Connection.__init__ + transport = messaging.get_transport(self.conf, self.url) + self.addCleanup(transport.cleanup) + driver = transport._driver - def record_params(cnx, conf, server_params=None): - self._server_params.append(server_params) - return cnx_init(cnx, conf, server_params) + brokers_params = driver._get_connection().brokers_params[:] + brokers_params = [dict((k, v) for k, v in broker.items() + if k not in ['transport', 'login_method']) + for broker in brokers_params] - def dummy_send(cnx, topic, msg, timeout=None): - pass - - self.stubs.Set(rabbit_driver.Connection, '__init__', record_params) - self.stubs.Set(rabbit_driver.Connection, 'topic_send', dummy_send) - - self._driver = messaging.get_transport(self.conf, self.url)._driver - self._target = messaging.Target(topic='testtopic') - - def test_transport_url_listen(self): - self._driver.listen(self._target) - self.assertEqual(self.expected, self._server_params[0]) - - def test_transport_url_listen_for_notification(self): - self._driver.listen_for_notifications( - [(messaging.Target(topic='topic'), 'info')]) - self.assertEqual(self.expected, self._server_params[0]) - - def test_transport_url_send(self): - self._driver.send(self._target, {}, {}) - self.assertEqual(self.expected, self._server_params[0]) + self.assertEqual(sorted(self.expected, + key=operator.itemgetter('hostname')), + sorted(brokers_params, + key=operator.itemgetter('hostname'))) class TestSendReceive(test_utils.BaseTestCase): @@ -619,8 +634,8 @@ class RpcKombuHATestCase(test_utils.BaseTestCase): brokers = ['host1', 'host2', 'host3', 'host4', 'host5'] brokers_count = len(brokers) - self.conf.rabbit_hosts = brokers - self.conf.rabbit_max_retries = 1 + self.config(rabbit_hosts=brokers, + rabbit_max_retries=1) hostname_sets = set() @@ -639,7 +654,8 @@ class RpcKombuHATestCase(test_utils.BaseTestCase): self.stubs.Set(rabbit_driver.Connection, '_connect', _connect) # starting from the first broker in the list - connection = rabbit_driver.Connection(self.conf) + url = messaging.TransportURL.parse(self.conf, None) + connection = rabbit_driver.Connection(self.conf, url) # now that we have connection object, revert to the real 'connect' # implementation