Merge "Full support of multiple hosts in transport url"

This commit is contained in:
Jenkins 2014-05-05 15:46:34 +00:00 committed by Gerrit Code Review
commit 17375f41ce
6 changed files with 238 additions and 158 deletions

View File

@ -59,16 +59,17 @@ LOG = logging.getLogger(__name__)
class ConnectionPool(pool.Pool):
"""Class that implements a Pool of Connections."""
def __init__(self, conf, connection_cls):
def __init__(self, conf, url, connection_cls):
self.connection_cls = connection_cls
self.conf = conf
self.url = url
super(ConnectionPool, self).__init__(self.conf.rpc_conn_pool_size)
self.reply_proxy = None
# TODO(comstud): Timeout connections not used in a while
def create(self):
LOG.debug(_('Pool creating new connection'))
return self.connection_cls(self.conf)
return self.connection_cls(self.conf, self.url)
def empty(self):
for item in self.iter_free():
@ -82,18 +83,19 @@ class ConnectionPool(pool.Pool):
# time code, it gets here via cleanup() and only appears in service.py
# just before doing a sys.exit(), so cleanup() only happens once and
# the leakage is not a problem.
self.connection_cls.pool = None
del self.connection_cls.pools[self.url]
_pool_create_sem = threading.Lock()
def get_connection_pool(conf, connection_cls):
def get_connection_pool(conf, url, connection_cls):
with _pool_create_sem:
# Make sure only one thread tries to create the connection pool.
if not connection_cls.pool:
connection_cls.pool = ConnectionPool(conf, connection_cls)
return connection_cls.pool
if url not in connection_cls.pools:
connection_cls.pools[url] = ConnectionPool(conf, url,
connection_cls)
return connection_cls.pools[url]
class ConnectionContext(rpc_common.Connection):
@ -108,17 +110,16 @@ class ConnectionContext(rpc_common.Connection):
If possible the function makes sure to return a connection to the pool.
"""
def __init__(self, conf, connection_pool, pooled=True, server_params=None):
def __init__(self, conf, url, connection_pool, pooled=True):
"""Create a new connection, or get one from the pool."""
self.connection = None
self.conf = conf
self.url = url
self.connection_pool = connection_pool
if pooled:
self.connection = connection_pool.get()
else:
self.connection = connection_pool.connection_cls(
conf,
server_params=server_params)
self.connection = connection_pool.connection_cls(conf, url)
self.pooled = pooled
def __enter__(self):

View File

@ -295,8 +295,6 @@ class AMQPDriverBase(base.BaseDriver):
super(AMQPDriverBase, self).__init__(conf, url, default_exchange,
allowed_remote_exmods)
self._server_params = self._server_params_from_url(self._url)
self._default_exchange = default_exchange
# FIXME(markmc): temp hack
@ -310,35 +308,11 @@ class AMQPDriverBase(base.BaseDriver):
self._reply_q_conn = None
self._waiter = None
def _server_params_from_url(self, url):
sp = {}
if url.virtual_host is not None:
sp['virtual_host'] = url.virtual_host
if url.hosts:
# FIXME(markmc): support multiple hosts
host = url.hosts[0]
sp['hostname'] = host.hostname
if host.port is not None:
sp['port'] = host.port
sp['username'] = host.username or ''
sp['password'] = host.password or ''
return sp
def _get_connection(self, pooled=True):
# FIXME(markmc): we don't yet have a connection pool for each
# Transport instance, so we'll only use the pool with the
# transport configuration from the config file
server_params = self._server_params or None
if server_params:
pooled = False
return rpc_amqp.ConnectionContext(self.conf,
self._url,
self._connection_pool,
pooled=pooled,
server_params=server_params)
pooled=pooled)
def _get_reply_q(self):
with self._reply_q_lock:

View File

@ -27,6 +27,7 @@ from oslo.messaging._drivers import amqpdriver
from oslo.messaging._drivers import common as rpc_common
from oslo.messaging.openstack.common import importutils
from oslo.messaging.openstack.common import jsonutils
from oslo.messaging.openstack.common import network_utils
# FIXME(markmc): remove this
_ = lambda s: s
@ -449,9 +450,9 @@ class NotifyPublisher(Publisher):
class Connection(object):
"""Connection object."""
pool = None
pools = {}
def __init__(self, conf, server_params=None):
def __init__(self, conf, url):
if not qpid_messaging:
raise ImportError("Failed to import qpid.messaging")
@ -460,35 +461,44 @@ class Connection(object):
self.consumers = {}
self.conf = conf
if server_params and 'hostname' in server_params:
# NOTE(russellb) This enables support for cast_to_server.
server_params['qpid_hosts'] = [
'%s:%d' % (server_params['hostname'],
server_params.get('port', 5672))
]
self.brokers_params = []
if url.hosts:
for host in url.hosts:
params = {
'username': host.username or '',
'password': host.password or '',
}
if host.port is not None:
params['host'] = '%s:%d' % (host.hostname, host.port)
else:
params['host'] = host.hostname
self.brokers_params.append(params)
else:
# Old configuration format
for adr in self.conf.qpid_hosts:
hostname, port = network_utils.parse_host_port(
adr, default_port=5672)
params = {
'qpid_hosts': self.conf.qpid_hosts[:],
'username': self.conf.qpid_username,
'password': self.conf.qpid_password,
}
params.update(server_params or {})
params = {
'host': '%s:%d' % (hostname, port),
'username': self.conf.qpid_username,
'password': self.conf.qpid_password,
}
self.brokers_params.append(params)
random.shuffle(params['qpid_hosts'])
self.brokers = itertools.cycle(params['qpid_hosts'])
random.shuffle(self.brokers_params)
self.brokers = itertools.cycle(self.brokers_params)
self.username = params['username']
self.password = params['password']
self.reconnect()
def connection_create(self, broker):
# Create the connection - this does not open the connection
self.connection = qpid_messaging.Connection(broker)
self.connection = qpid_messaging.Connection(broker['host'])
# Check if flags are set and if so set them for the connection
# before we call open
self.connection.username = self.username
self.connection.password = self.password
self.connection.username = broker['username']
self.connection.password = broker['password']
self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms
# Reconnection is done by self.reconnect()
@ -520,14 +530,14 @@ class Connection(object):
self.connection_create(broker)
self.connection.open()
except qpid_exceptions.MessagingError as e:
msg_dict = dict(e=e, delay=delay)
msg = _("Unable to connect to AMQP server: %(e)s. "
"Sleeping %(delay)s seconds") % msg_dict
msg_dict = dict(e=e, delay=delay, broker=broker['host'])
msg = _("Unable to connect to AMQP server on %(broker)s: "
"%(e)s. Sleeping %(delay)s seconds") % msg_dict
LOG.error(msg)
time.sleep(delay)
delay = min(delay + 1, 5)
else:
LOG.info(_('Connected to AMQP server on %s'), broker)
LOG.info(_('Connected to AMQP server on %s'), broker['host'])
break
self.session = self.connection.session()
@ -686,7 +696,7 @@ class QpidDriver(amqpdriver.AMQPDriverBase):
conf.register_opts(qpid_opts)
conf.register_opts(rpc_amqp.amqp_opts)
connection_pool = rpc_amqp.get_connection_pool(conf, Connection)
connection_pool = rpc_amqp.get_connection_pool(conf, url, Connection)
super(QpidDriver, self).__init__(conf, url,
connection_pool,

View File

@ -421,9 +421,9 @@ class NotifyPublisher(TopicPublisher):
class Connection(object):
"""Connection object."""
pool = None
pools = {}
def __init__(self, conf, server_params=None):
def __init__(self, conf, url):
self.consumers = []
self.conf = conf
self.max_retries = self.conf.rabbit_max_retries
@ -436,39 +436,54 @@ class Connection(object):
self.interval_max = 30
self.memory_transport = False
if server_params is None:
server_params = {}
# Keys to translate from server_params to kombu params
server_params_to_kombu_params = {'username': 'userid'}
ssl_params = self._fetch_ssl_params()
params_list = []
for adr in self.conf.rabbit_hosts:
hostname, port = network_utils.parse_host_port(
adr, default_port=self.conf.rabbit_port)
params = {
'hostname': hostname,
'port': port,
'userid': self.conf.rabbit_userid,
'password': self.conf.rabbit_password,
'login_method': self.conf.rabbit_login_method,
'virtual_host': self.conf.rabbit_virtual_host,
}
if url.virtual_host is not None:
virtual_host = url.virtual_host
else:
virtual_host = self.conf.rabbit_virtual_host
for sp_key, value in six.iteritems(server_params):
p_key = server_params_to_kombu_params.get(sp_key, sp_key)
params[p_key] = value
self.brokers_params = []
if url.hosts:
for host in url.hosts:
params = {
'hostname': host.hostname,
'port': host.port or 5672,
'userid': host.username or '',
'password': host.password or '',
'login_method': self.conf.rabbit_login_method,
'virtual_host': virtual_host
}
if self.conf.fake_rabbit:
params['transport'] = 'memory'
if self.conf.rabbit_use_ssl:
params['ssl'] = ssl_params
if self.conf.fake_rabbit:
params['transport'] = 'memory'
if self.conf.rabbit_use_ssl:
params['ssl'] = ssl_params
self.brokers_params.append(params)
else:
# Old configuration format
for adr in self.conf.rabbit_hosts:
hostname, port = network_utils.parse_host_port(
adr, default_port=self.conf.rabbit_port)
params_list.append(params)
params = {
'hostname': hostname,
'port': port,
'userid': self.conf.rabbit_userid,
'password': self.conf.rabbit_password,
'login_method': self.conf.rabbit_login_method,
'virtual_host': virtual_host
}
random.shuffle(params_list)
self.params_list = itertools.cycle(params_list)
if self.conf.fake_rabbit:
params['transport'] = 'memory'
if self.conf.rabbit_use_ssl:
params['ssl'] = ssl_params
self.brokers_params.append(params)
random.shuffle(self.brokers_params)
self.brokers = itertools.cycle(self.brokers_params)
self.memory_transport = self.conf.fake_rabbit
@ -519,14 +534,14 @@ class Connection(object):
# Return the extended behavior or just have the default behavior
return ssl_params or True
def _connect(self, params):
def _connect(self, broker):
"""Connect to rabbit. Re-establish any queues that may have
been declared before if we are reconnecting. Exceptions should
be handled by the caller.
"""
if self.connection:
LOG.info(_("Reconnecting to AMQP server on "
"%(hostname)s:%(port)d") % params)
"%(hostname)s:%(port)d") % broker)
try:
# XXX(nic): when reconnecting to a RabbitMQ cluster
# with mirrored queues in use, the attempt to release the
@ -545,7 +560,7 @@ class Connection(object):
# Setting this in case the next statement fails, though
# it shouldn't be doing any network operations, yet.
self.connection = None
self.connection = kombu.connection.BrokerConnection(**params)
self.connection = kombu.connection.BrokerConnection(**broker)
self.connection_errors = self.connection.connection_errors
self.channel_errors = self.connection.channel_errors
if self.memory_transport:
@ -561,7 +576,7 @@ class Connection(object):
for consumer in self.consumers:
consumer.reconnect(self.channel)
LOG.info(_('Connected to AMQP server on %(hostname)s:%(port)d') %
params)
broker)
def reconnect(self):
"""Handles reconnecting and re-establishing queues.
@ -574,10 +589,10 @@ class Connection(object):
attempt = 0
while True:
params = six.next(self.params_list)
broker = six.next(self.brokers)
attempt += 1
try:
self._connect(params)
self._connect(broker)
return
except IOError as e:
pass
@ -596,7 +611,7 @@ class Connection(object):
log_info = {}
log_info['err_str'] = e
log_info['max_retries'] = self.max_retries
log_info.update(params)
log_info.update(broker)
if self.max_retries and attempt == self.max_retries:
msg = _('Unable to connect to AMQP server on '
@ -774,7 +789,7 @@ class RabbitDriver(amqpdriver.AMQPDriverBase):
conf.register_opts(rabbit_opts)
conf.register_opts(rpc_amqp.amqp_opts)
connection_pool = rpc_amqp.get_connection_pool(conf, Connection)
connection_pool = rpc_amqp.get_connection_pool(conf, url, Connection)
super(RabbitDriver, self).__init__(conf, url,
connection_pool,

View File

@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import operator
import random
import thread
import threading
@ -102,6 +103,65 @@ class _QpidBaseTestCase(test_utils.BaseTestCase):
self.con_send.close()
class TestQpidTransportURL(_QpidBaseTestCase):
scenarios = [
('none', dict(url=None,
expected=[dict(host='localhost:5672',
username='',
password='')])),
('empty',
dict(url='qpid:///',
expected=[dict(host='localhost:5672',
username='',
password='')])),
('localhost',
dict(url='qpid://localhost/',
expected=[dict(host='localhost',
username='',
password='')])),
('no_creds',
dict(url='qpid://host/',
expected=[dict(host='host',
username='',
password='')])),
('no_port',
dict(url='qpid://user:password@host/',
expected=[dict(host='host',
username='user',
password='password')])),
('full_url',
dict(url='qpid://user:password@host:10/',
expected=[dict(host='host:10',
username='user',
password='password')])),
('full_two_url',
dict(url='qpid://user:password@host:10,'
'user2:password2@host2:12/',
expected=[dict(host='host:10',
username='user',
password='password'),
dict(host='host2:12',
username='user2',
password='password2')
]
)),
]
@mock.patch.object(qpid_driver.Connection, 'reconnect')
def test_transport_url(self, *args):
transport = messaging.get_transport(self.conf, self.url)
self.addCleanup(transport.cleanup)
driver = transport._driver
brokers_params = driver._get_connection().brokers_params
self.assertEqual(sorted(self.expected,
key=operator.itemgetter('host')),
sorted(brokers_params,
key=operator.itemgetter('host')))
class TestQpidInvalidTopologyVersion(_QpidBaseTestCase):
"""Unit test cases to test invalid qpid topology version."""
@ -398,11 +458,12 @@ class TestQpidReconnectOrder(test_utils.BaseTestCase):
brokers = ['host1', 'host2', 'host3', 'host4', 'host5']
brokers_count = len(brokers)
self.messaging_conf.conf.qpid_hosts = brokers
self.config(qpid_hosts=brokers)
with mock.patch('qpid.messaging.Connection') as conn_mock:
# starting from the first broker in the list
connection = qpid_driver.Connection(self.messaging_conf.conf)
url = messaging.TransportURL.parse(self.conf, None)
connection = qpid_driver.Connection(self.conf, url)
# reconnect will advance to the next broker, one broker per
# attempt, and then wrap to the start of the list once the end is
@ -412,7 +473,7 @@ class TestQpidReconnectOrder(test_utils.BaseTestCase):
expected = []
for broker in brokers:
expected.extend([mock.call(broker),
expected.extend([mock.call("%s:5672" % broker),
mock.call().open(),
mock.call().session(),
mock.call().opened(),
@ -601,6 +662,9 @@ class FakeQpidSession(object):
key = slash_split[-1]
return key.strip()
def close(self):
pass
_fake_session = FakeQpidSession()

View File

@ -13,6 +13,7 @@
# under the License.
import datetime
import operator
import sys
import threading
import uuid
@ -46,74 +47,88 @@ class TestRabbitDriverLoad(test_utils.BaseTestCase):
class TestRabbitTransportURL(test_utils.BaseTestCase):
scenarios = [
('none', dict(url=None, expected=None)),
('none', dict(url=None,
expected=[dict(hostname='localhost',
port=5672,
userid='guest',
password='guest',
virtual_host='/')])),
('empty',
dict(url='rabbit:///',
expected=dict(virtual_host=''))),
expected=[dict(hostname='localhost',
port=5672,
userid='guest',
password='guest',
virtual_host='')])),
('localhost',
dict(url='rabbit://localhost/',
expected=dict(hostname='localhost',
username='',
password='',
virtual_host=''))),
expected=[dict(hostname='localhost',
port=5672,
userid='',
password='',
virtual_host='')])),
('virtual_host',
dict(url='rabbit:///vhost',
expected=dict(virtual_host='vhost'))),
expected=[dict(hostname='localhost',
port=5672,
userid='guest',
password='guest',
virtual_host='vhost')])),
('no_creds',
dict(url='rabbit://host/virtual_host',
expected=dict(hostname='host',
username='',
password='',
virtual_host='virtual_host'))),
expected=[dict(hostname='host',
port=5672,
userid='',
password='',
virtual_host='virtual_host')])),
('no_port',
dict(url='rabbit://user:password@host/virtual_host',
expected=dict(hostname='host',
username='user',
password='password',
virtual_host='virtual_host'))),
expected=[dict(hostname='host',
port=5672,
userid='user',
password='password',
virtual_host='virtual_host')])),
('full_url',
dict(url='rabbit://user:password@host:10/virtual_host',
expected=dict(hostname='host',
port=10,
username='user',
password='password',
virtual_host='virtual_host'))),
expected=[dict(hostname='host',
port=10,
userid='user',
password='password',
virtual_host='virtual_host')])),
('full_two_url',
dict(url='rabbit://user:password@host:10,'
'user2:password2@host2:12/virtual_host',
expected=[dict(hostname='host',
port=10,
userid='user',
password='password',
virtual_host='virtual_host'),
dict(hostname='host2',
port=12,
userid='user2',
password='password2',
virtual_host='virtual_host')
]
)),
]
def setUp(self):
super(TestRabbitTransportURL, self).setUp()
self.messaging_conf.transport_driver = 'rabbit'
def test_transport_url(self):
self.messaging_conf.in_memory = True
self._server_params = []
cnx_init = rabbit_driver.Connection.__init__
transport = messaging.get_transport(self.conf, self.url)
self.addCleanup(transport.cleanup)
driver = transport._driver
def record_params(cnx, conf, server_params=None):
self._server_params.append(server_params)
return cnx_init(cnx, conf, server_params)
brokers_params = driver._get_connection().brokers_params[:]
brokers_params = [dict((k, v) for k, v in broker.items()
if k not in ['transport', 'login_method'])
for broker in brokers_params]
def dummy_send(cnx, topic, msg, timeout=None):
pass
self.stubs.Set(rabbit_driver.Connection, '__init__', record_params)
self.stubs.Set(rabbit_driver.Connection, 'topic_send', dummy_send)
self._driver = messaging.get_transport(self.conf, self.url)._driver
self._target = messaging.Target(topic='testtopic')
def test_transport_url_listen(self):
self._driver.listen(self._target)
self.assertEqual(self.expected, self._server_params[0])
def test_transport_url_listen_for_notification(self):
self._driver.listen_for_notifications(
[(messaging.Target(topic='topic'), 'info')])
self.assertEqual(self.expected, self._server_params[0])
def test_transport_url_send(self):
self._driver.send(self._target, {}, {})
self.assertEqual(self.expected, self._server_params[0])
self.assertEqual(sorted(self.expected,
key=operator.itemgetter('hostname')),
sorted(brokers_params,
key=operator.itemgetter('hostname')))
class TestSendReceive(test_utils.BaseTestCase):
@ -619,8 +634,8 @@ class RpcKombuHATestCase(test_utils.BaseTestCase):
brokers = ['host1', 'host2', 'host3', 'host4', 'host5']
brokers_count = len(brokers)
self.conf.rabbit_hosts = brokers
self.conf.rabbit_max_retries = 1
self.config(rabbit_hosts=brokers,
rabbit_max_retries=1)
hostname_sets = set()
@ -639,7 +654,8 @@ class RpcKombuHATestCase(test_utils.BaseTestCase):
self.stubs.Set(rabbit_driver.Connection, '_connect', _connect)
# starting from the first broker in the list
connection = rabbit_driver.Connection(self.conf)
url = messaging.TransportURL.parse(self.conf, None)
connection = rabbit_driver.Connection(self.conf, url)
# now that we have connection object, revert to the real 'connect'
# implementation