From 045cddcea2ecefccecbb40d4249b915c3f1faae3 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Sun, 21 Sep 2014 13:20:35 -0700 Subject: [PATCH] Add an optional advanced pool of memcached clients This patchset adds an advanced eventlet safe pool of memcache clients. This allows the deployer to configure auth_token middleware to utilize the new pool by simply setting 'memcache_use_advanced_pool' to true. Optional tunables for the memcache pool have also been added. Co-Authored-By: Morgan Fainberg Closes-bug: #1332058 Closes-bug: #1360446 Change-Id: I08082b46ce692cf4df449d48dac94718f1e98a6c --- keystonemiddleware/_memcache_pool.py | 182 ++++++++++++++++++ keystonemiddleware/auth_token.py | 105 +++++++++- .../tests/test_connection_pool.py | 118 ++++++++++++ keystonemiddleware/tests/test_opts.py | 6 + 4 files changed, 407 insertions(+), 4 deletions(-) create mode 100644 keystonemiddleware/_memcache_pool.py create mode 100644 keystonemiddleware/tests/test_connection_pool.py diff --git a/keystonemiddleware/_memcache_pool.py b/keystonemiddleware/_memcache_pool.py new file mode 100644 index 00000000..5d2de54e --- /dev/null +++ b/keystonemiddleware/_memcache_pool.py @@ -0,0 +1,182 @@ +# Copyright 2014 Mirantis Inc +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Thread-safe connection pool for python-memcached.""" + +# NOTE(yorik-sar): this file is copied between keystone and keystonemiddleware +# and should be kept in sync until we can use external library for this. + +import collections +import contextlib +import itertools +import logging +import time + +from six.moves import queue + + +_PoolItem = collections.namedtuple('_PoolItem', ['ttl', 'connection']) + + +class ConnectionGetTimeoutException(Exception): + pass + + +class ConnectionPool(queue.Queue): + """Base connection pool class + + This class implements the basic connection pool logic as an abstract base + class. + """ + def __init__(self, maxsize, unused_timeout, conn_get_timeout=None): + """Initialize the connection pool. + + :param maxsize: maximum number of client connections for the pool + :type maxsize: int + :param unused_timeout: idle time to live for unused clients (in + seconds). If a client connection object has been + in the pool and idle for longer than the + unused_timeout, it will be reaped. This is to + ensure resources are released as utilization + goes down. + :type unused_timeout: int + :param conn_get_timeout: maximum time in seconds to wait for a + connection. If set to `None` timeout is + indefinite. + :type conn_get_timeout: int + """ + queue.Queue.__init__(self, maxsize) + self._unused_timeout = unused_timeout + self._connection_get_timeout = conn_get_timeout + self._acquired = 0 + self._LOG = logging.getLogger(__name__) + + def _create_connection(self): + raise NotImplementedError + + def _destroy_connection(self, conn): + raise NotImplementedError + + @contextlib.contextmanager + def acquire(self): + try: + conn = self.get(timeout=self._connection_get_timeout) + except queue.Empty: + self._LOG.critical('Unable to get a connection from pool id ' + '%(id)s after %(seconds)s seconds.', + {'id': id(self), + 'seconds': self._connection_get_timeout}) + raise ConnectionGetTimeoutException() + try: + yield conn + finally: + self.put(conn) + + def _qsize(self): + return self.maxsize - self._acquired + + if not hasattr(queue.Queue, '_qsize'): + qsize = _qsize + + def _get(self): + if self.queue: + conn = self.queue.pop().connection + else: + conn = self._create_connection() + self._acquired += 1 + return conn + + def _put(self, conn): + self.queue.append(_PoolItem( + ttl=time.time() + self._unused_timeout, + connection=conn, + )) + self._acquired -= 1 + # Drop all expired connections from the right end of the queue + now = time.time() + while self.queue and self.queue[0].ttl < now: + conn = self.queue.popleft().connection + self._destroy_connection(conn) + + +class MemcacheClientPool(ConnectionPool): + def __init__(self, urls, arguments, **kwargs): + ConnectionPool.__init__(self, **kwargs) + self._urls = urls + self._arguments = arguments + # NOTE(morganfainberg): The host objects expect an int for the + # deaduntil value. Initialize this at 0 for each host with 0 indicating + # the host is not dead. + self._hosts_deaduntil = [0] * len(urls) + + # NOTE(morganfainberg): Lazy import to allow middleware to work with + # python 3k even if memcache will not due to python 3k + # incompatibilities within the python-memcache library. + global memcache + import memcache + + # This 'class' is taken from http://stackoverflow.com/a/22520633/238308 + # Don't inherit client from threading.local so that we can reuse + # clients in different threads + MemcacheClient = type('_MemcacheClient', (object,), + dict(memcache.Client.__dict__)) + + self._memcache_client_class = MemcacheClient + + def _create_connection(self): + return self._memcache_client_class(self._urls, **self._arguments) + + def _destroy_connection(self, conn): + conn.disconnect_all() + + def _get(self): + conn = ConnectionPool._get(self) + try: + # Propagate host state known to us to this client's list + now = time.time() + for deaduntil, host in zip(self._hosts_deaduntil, conn.servers): + if deaduntil > now and host.deaduntil <= now: + host.mark_dead('propagating death mark from the pool') + host.deaduntil = deaduntil + except Exception: + # We need to be sure that connection doesn't leak from the pool. + # This code runs before we enter context manager's try-finally + # block, so we need to explicitly release it here + ConnectionPool._put(self, conn) + raise + return conn + + def _put(self, conn): + try: + # If this client found that one of the hosts is dead, mark it as + # such in our internal list + now = time.time() + for i, deaduntil, host in zip(itertools.count(), + self._hosts_deaduntil, + conn.servers): + # Do nothing if we already know this host is dead + if deaduntil <= now: + if host.deaduntil > now: + self._hosts_deaduntil[i] = host.deaduntil + else: + self._hosts_deaduntil[i] = 0 + # If all hosts are dead we should forget that they're dead. This + # way we won't get completely shut off until dead_retry seconds + # pass, but will be checking servers as frequent as we can (over + # way smaller socket_timeout) + if all(deaduntil > now for deaduntil in self._hosts_deaduntil): + self._hosts_deaduntil[:] = [0] * len(self._hosts_deaduntil) + finally: + ConnectionPool._put(self, conn) diff --git a/keystonemiddleware/auth_token.py b/keystonemiddleware/auth_token.py index 07a35827..ba2f1d2a 100644 --- a/keystonemiddleware/auth_token.py +++ b/keystonemiddleware/auth_token.py @@ -322,6 +322,31 @@ _OPTS = [ secret=True, help='(optional, mandatory if memcache_security_strategy is' ' defined) this string is used for key derivation.'), + cfg.IntOpt('memcache_pool_dead_retry', + default=5 * 60, + help='(optional) number of seconds memcached server is' + ' considered dead before it is tried again.'), + cfg.IntOpt('memcache_pool_maxsize', + default=10, + help='(optional) max total number of open connections to' + ' every memcached server.'), + cfg.IntOpt('memcache_pool_socket_timeout', + default=3, + help='(optional) socket timeout in seconds for communicating ' + 'with a memcache server.'), + cfg.IntOpt('memcache_pool_unused_timeout', + default=60, + help='(optional) number of seconds a connection to memcached' + ' is held unused in the pool before it is closed.'), + cfg.IntOpt('memcache_pool_conn_get_timeout', + default=10, + help='(optional) number of seconds that an operation will wait ' + 'to get a memcache client connection from the pool.'), + cfg.BoolOpt('memcache_use_advanced_pool', + default=False, + help='(optional) use the advanced (eventlet safe) memcache ' + 'client pool. The advanced pool will only work under ' + 'python 2.x.'), cfg.BoolOpt('include_service_catalog', default=True, help='(optional) indicate whether to set the X-Service-Catalog' @@ -1245,7 +1270,17 @@ class AuthProtocol(object): env_cache_name=self._conf_get('cache'), memcached_servers=self._conf_get('memcached_servers'), memcache_security_strategy=self._memcache_security_strategy, - memcache_secret_key=self._conf_get('memcache_secret_key')) + memcache_secret_key=self._conf_get('memcache_secret_key'), + use_advanced_pool=self._conf_get('memcache_use_advanced_pool'), + memcache_pool_dead_retry=self._conf_get( + 'memcache_pool_dead_retry'), + memcache_pool_maxsize=self._conf_get('memcache_pool_maxsize'), + memcache_pool_unused_timeout=self._conf_get( + 'memcache_pool_unused_timeout'), + memcache_pool_conn_get_timeout=self._conf_get( + 'memcache_pool_conn_get_timeout'), + memcache_pool_socket_timeout=self._conf_get( + 'memcache_pool_socket_timeout')) return token_cache @@ -1276,6 +1311,34 @@ class _CachePool(list): self.append(c) +class _MemcacheClientPool(object): + """An advanced memcached client pool that is eventlet safe.""" + def __init__(self, memcache_servers, memcache_dead_retry=None, + memcache_pool_maxsize=None, memcache_pool_unused_timeout=None, + memcache_pool_conn_get_timeout=None, + memcache_pool_socket_timeout=None): + # NOTE(morganfainberg): import here to avoid hard dependency on + # python-memcache library. + global _memcache_pool + from keystonemiddleware import _memcache_pool + + self._pool = _memcache_pool.MemcacheClientPool( + memcache_servers, + arguments={ + 'dead_retry': memcache_dead_retry, + 'socket_timeout': memcache_pool_socket_timeout, + }, + maxsize=memcache_pool_maxsize, + unused_timeout=memcache_pool_unused_timeout, + conn_get_timeout=memcache_pool_conn_get_timeout, + ) + + @contextlib.contextmanager + def reserve(self): + with self._pool.get() as client: + yield client + + class _IdentityServer(object): """Operations on the Identity API server. @@ -1542,12 +1605,22 @@ class _TokenCache(object): def __init__(self, log, cache_time=None, hash_algorithms=None, env_cache_name=None, memcached_servers=None, - memcache_security_strategy=None, memcache_secret_key=None): + memcache_security_strategy=None, memcache_secret_key=None, + use_advanced_pool=False, memcache_pool_dead_retry=None, + memcache_pool_maxsize=None, memcache_pool_unused_timeout=None, + memcache_pool_conn_get_timeout=None, + memcache_pool_socket_timeout=None): self._LOG = log self._cache_time = cache_time self._hash_algorithms = hash_algorithms self._env_cache_name = env_cache_name self._memcached_servers = memcached_servers + self._use_advanced_pool = use_advanced_pool + self._memcache_pool_dead_retry = memcache_pool_dead_retry, + self._memcache_pool_maxsize = memcache_pool_maxsize, + self._memcache_pool_unused_timeout = memcache_pool_unused_timeout + self._memcache_pool_conn_get_timeout = memcache_pool_conn_get_timeout + self._memcache_pool_socket_timeout = memcache_pool_socket_timeout # memcache value treatment, ENCRYPT or MAC self._memcache_security_strategy = memcache_security_strategy @@ -1561,12 +1634,36 @@ class _TokenCache(object): self._assert_valid_memcache_protection_config() + def _get_cache_pool(self, cache, memcache_servers, use_advanced_pool=False, + memcache_dead_retry=None, memcache_pool_maxsize=None, + memcache_pool_unused_timeout=None, + memcache_pool_conn_get_timeout=None, + memcache_pool_socket_timeout=None): + if use_advanced_pool is True and memcache_servers and cache is None: + return _MemcacheClientPool( + memcache_servers, + memcache_dead_retry=memcache_dead_retry, + memcache_pool_maxsize=memcache_pool_maxsize, + memcache_pool_unused_timeout=memcache_pool_unused_timeout, + memcache_pool_conn_get_timeout=memcache_pool_conn_get_timeout, + memcache_pool_socket_timeout=memcache_pool_socket_timeout) + else: + return _CachePool(cache, memcache_servers) + def initialize(self, env): if self._initialized: return - self._cache_pool = _CachePool(env.get(self._env_cache_name), - self._memcached_servers) + self._cache_pool = self._get_cache_pool( + env.get(self._env_cache_name), + self._memcached_servers, + use_advanced_pool=self._use_advanced_pool, + memcache_dead_retry=self._memcache_pool_dead_retry, + memcache_pool_maxsize=self._memcache_pool_maxsize, + memcache_pool_unused_timeout=self._memcache_pool_unused_timeout, + memcache_pool_conn_get_timeout=self._memcache_pool_conn_get_timeout + ) + self._initialized = True def get(self, user_token): diff --git a/keystonemiddleware/tests/test_connection_pool.py b/keystonemiddleware/tests/test_connection_pool.py new file mode 100644 index 00000000..c5152764 --- /dev/null +++ b/keystonemiddleware/tests/test_connection_pool.py @@ -0,0 +1,118 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import time + +import mock +from six.moves import queue +import testtools +from testtools import matchers + +from keystonemiddleware import _memcache_pool +from keystonemiddleware.tests import utils + + +class _TestConnectionPool(_memcache_pool.ConnectionPool): + destroyed_value = 'destroyed' + + def _create_connection(self): + return mock.MagicMock() + + def _destroy_connection(self, conn): + conn(self.destroyed_value) + + +class TestConnectionPool(utils.TestCase): + def setUp(self): + super(TestConnectionPool, self).setUp() + self.unused_timeout = 10 + self.maxsize = 2 + self.connection_pool = _TestConnectionPool( + maxsize=self.maxsize, + unused_timeout=self.unused_timeout) + + def test_get_context_manager(self): + self.assertThat(self.connection_pool.queue, matchers.HasLength(0)) + with self.connection_pool.acquire() as conn: + self.assertEqual(1, self.connection_pool._acquired) + self.assertEqual(0, self.connection_pool._acquired) + self.assertThat(self.connection_pool.queue, matchers.HasLength(1)) + self.assertEqual(conn, self.connection_pool.queue[0].connection) + + def test_cleanup_pool(self): + self.test_get_context_manager() + newtime = time.time() + self.unused_timeout * 2 + non_expired_connection = _memcache_pool._PoolItem( + ttl=(newtime * 2), + connection=mock.MagicMock()) + self.connection_pool.queue.append(non_expired_connection) + self.assertThat(self.connection_pool.queue, matchers.HasLength(2)) + with mock.patch.object(time, 'time', return_value=newtime): + conn = self.connection_pool.queue[0].connection + with self.connection_pool.acquire(): + pass + conn.assert_has_calls( + [mock.call(self.connection_pool.destroyed_value)]) + self.assertThat(self.connection_pool.queue, matchers.HasLength(1)) + self.assertEqual(0, non_expired_connection.connection.call_count) + + def test_acquire_conn_exception_returns_acquired_count(self): + class TestException(Exception): + pass + + with mock.patch.object(_TestConnectionPool, '_create_connection', + side_effect=TestException): + with testtools.ExpectedException(TestException): + with self.connection_pool.acquire(): + pass + self.assertThat(self.connection_pool.queue, + matchers.HasLength(0)) + self.assertEqual(0, self.connection_pool._acquired) + + def test_connection_pool_limits_maximum_connections(self): + # NOTE(morganfainberg): To ensure we don't lockup tests until the + # job limit, explicitly call .get_nowait() and .put_nowait() in this + # case. + conn1 = self.connection_pool.get_nowait() + conn2 = self.connection_pool.get_nowait() + + # Use a nowait version to raise an Empty exception indicating we would + # not get another connection until one is placed back into the queue. + self.assertRaises(queue.Empty, self.connection_pool.get_nowait) + + # Place the connections back into the pool. + self.connection_pool.put_nowait(conn1) + self.connection_pool.put_nowait(conn2) + + # Make sure we can get a connection out of the pool again. + self.connection_pool.get_nowait() + + def test_connection_pool_maximum_connection_get_timeout(self): + connection_pool = _TestConnectionPool( + maxsize=1, + unused_timeout=self.unused_timeout, + conn_get_timeout=0) + + def _acquire_connection(): + with connection_pool.acquire(): + pass + + # Make sure we've consumed the only available connection from the pool + conn = connection_pool.get_nowait() + + self.assertRaises(_memcache_pool.ConnectionGetTimeoutException, + _acquire_connection) + + # Put the connection back and ensure we can acquire the connection + # after it is available. + connection_pool.put_nowait(conn) + _acquire_connection() diff --git a/keystonemiddleware/tests/test_opts.py b/keystonemiddleware/tests/test_opts.py index d6839b21..eeeb84fe 100644 --- a/keystonemiddleware/tests/test_opts.py +++ b/keystonemiddleware/tests/test_opts.py @@ -53,6 +53,12 @@ class OptsTestCase(utils.TestCase): 'revocation_cache_time', 'memcache_security_strategy', 'memcache_secret_key', + 'memcache_use_advanced_pool', + 'memcache_pool_dead_retry', + 'memcache_pool_maxsize', + 'memcache_pool_unused_timeout', + 'memcache_pool_conn_get_timeout', + 'memcache_pool_socket_timeout', 'include_service_catalog', 'enforce_token_bind', 'check_revocations_for_cached',