Cleanup token hashes generated by cache
We hash a token under multiple configurable algorithms for revocations and caching and so must check all of them. This hash generation was being done and returned by the cache check and reused in revocation check. This is an unusual pattern and requires the cache object to have knowledge of token types and how to hash them. Change validate so that we generate the hash values in the main function and pass that to the cache and revocation functions. This moves the function to get the first available token id from the list to the main file. This means the cache interface is much more normal with get and set id functions and encapsulates the hash list function in the one file. Change-Id: I9194dbd052674f64122ff74329ce292a342512d3
This commit is contained in:
parent
de7a54efd3
commit
530b5cbe5c
|
@ -507,6 +507,7 @@ class AuthProtocol(object):
|
|||
self._delay_auth_decision = self._conf_get('delay_auth_decision')
|
||||
self._include_service_catalog = self._conf_get(
|
||||
'include_service_catalog')
|
||||
self._hash_algorithms = self._conf_get('hash_algorithms')
|
||||
|
||||
self._identity_server = self._create_identity_server()
|
||||
|
||||
|
@ -721,6 +722,41 @@ class AuthProtocol(object):
|
|||
start_response('401 Unauthorized', resp.headers)
|
||||
return resp.body
|
||||
|
||||
def _token_hashes(self, token):
|
||||
"""Generate a list of hashes that the current token may be cached as.
|
||||
|
||||
With PKI tokens we have multiple hashing algorithms that we test with
|
||||
revocations. This generates that whole list.
|
||||
|
||||
The first element of this list is the preferred algorithm and is what
|
||||
new cache values should be saved as.
|
||||
|
||||
:param str token: The token being presented by a user.
|
||||
|
||||
:returns: list of str token hashes.
|
||||
"""
|
||||
if cms.is_asn1_token(token) or cms.is_pkiz(token):
|
||||
return list(cms.cms_hash_token(token, mode=algo)
|
||||
for algo in self._hash_algorithms)
|
||||
else:
|
||||
return [token]
|
||||
|
||||
def _cache_get_hashes(self, token_hashes):
|
||||
"""Check if the token is cached already.
|
||||
|
||||
Functions takes a list of hashes that might be in the cache and matches
|
||||
the first one that is present. If nothing is found in the cache it
|
||||
returns None.
|
||||
|
||||
:returns: token data if found else None.
|
||||
"""
|
||||
|
||||
for token in token_hashes:
|
||||
cached = self._token_cache.get(token)
|
||||
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
def _validate_token(self, token, env):
|
||||
"""Authenticate user token
|
||||
|
||||
|
@ -730,11 +766,12 @@ class AuthProtocol(object):
|
|||
:raises exc.InvalidToken: if token is rejected
|
||||
|
||||
"""
|
||||
token_id = None
|
||||
token_hashes = None
|
||||
|
||||
try:
|
||||
token_ids, cached = self._token_cache.get(token)
|
||||
token_id = token_ids[0]
|
||||
token_hashes = self._token_hashes(token)
|
||||
cached = self._cache_get_hashes(token_hashes)
|
||||
|
||||
if cached:
|
||||
# Token was retrieved from the cache. In this case, there's no
|
||||
# need to check that the token is expired because the cache
|
||||
|
@ -747,7 +784,7 @@ class AuthProtocol(object):
|
|||
# A token stored in Memcached might have been revoked
|
||||
# regardless of initial mechanism used to validate it,
|
||||
# and needs to be checked.
|
||||
self._revocations.check(token_ids)
|
||||
self._revocations.check(token_hashes)
|
||||
self._confirm_token_bind(data, env)
|
||||
else:
|
||||
verified = None
|
||||
|
@ -755,9 +792,10 @@ class AuthProtocol(object):
|
|||
# checked that it's not expired, and also put in the cache.
|
||||
try:
|
||||
if cms.is_pkiz(token):
|
||||
verified = self._verify_pkiz_token(token, token_ids)
|
||||
verified = self._verify_pkiz_token(token, token_hashes)
|
||||
elif cms.is_asn1_token(token):
|
||||
verified = self._verify_signed_token(token, token_ids)
|
||||
verified = self._verify_signed_token(token,
|
||||
token_hashes)
|
||||
except exceptions.CertificateConfigError:
|
||||
self._LOG.warn(_LW('Fetch certificate config failed, '
|
||||
'fallback to online validation.'))
|
||||
|
@ -775,7 +813,7 @@ class AuthProtocol(object):
|
|||
# verify_token fails for expired tokens.
|
||||
expires = _get_token_expiration(data)
|
||||
self._confirm_token_bind(data, env)
|
||||
self._token_cache.store(token_id, data, expires)
|
||||
self._token_cache.store(token_hashes[0], data, expires)
|
||||
return data
|
||||
except (exceptions.ConnectionRefused, exceptions.RequestTimeout):
|
||||
self._LOG.debug('Token validation failure.', exc_info=True)
|
||||
|
@ -785,8 +823,8 @@ class AuthProtocol(object):
|
|||
raise
|
||||
except Exception:
|
||||
self._LOG.debug('Token validation failure.', exc_info=True)
|
||||
if token_id:
|
||||
self._token_cache.store_invalid(token_id)
|
||||
if token_hashes:
|
||||
self._token_cache.store_invalid(token_hashes[0])
|
||||
self._LOG.warn(_LW('Authorization failed for token'))
|
||||
raise exc.InvalidToken(_('Token authorization failed'))
|
||||
|
||||
|
@ -1090,7 +1128,6 @@ class AuthProtocol(object):
|
|||
|
||||
cache_kwargs = dict(
|
||||
cache_time=int(self._conf_get('token_cache_time')),
|
||||
hash_algorithms=self._conf_get('hash_algorithms'),
|
||||
env_cache_name=self._conf_get('cache'),
|
||||
memcached_servers=self._conf_get('memcached_servers'),
|
||||
use_advanced_pool=self._conf_get('memcache_use_advanced_pool'),
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
|
||||
import contextlib
|
||||
|
||||
from keystoneclient.common import cms
|
||||
from oslo_serialization import jsonutils
|
||||
from oslo_utils import timeutils
|
||||
import six
|
||||
|
@ -97,7 +96,7 @@ class TokenCache(object):
|
|||
_CACHE_KEY_TEMPLATE = 'tokens/%s'
|
||||
_INVALID_INDICATOR = 'invalid'
|
||||
|
||||
def __init__(self, log, cache_time=None, hash_algorithms=None,
|
||||
def __init__(self, log, cache_time=None,
|
||||
env_cache_name=None, memcached_servers=None,
|
||||
use_advanced_pool=False, memcache_pool_dead_retry=None,
|
||||
memcache_pool_maxsize=None, memcache_pool_unused_timeout=None,
|
||||
|
@ -105,7 +104,6 @@ class TokenCache(object):
|
|||
memcache_pool_socket_timeout=None):
|
||||
self._LOG = log
|
||||
self._cache_time = cache_time
|
||||
self._hash_algorithms = hash_algorithms
|
||||
self._env_cache_name = env_cache_name
|
||||
self._memcached_servers = memcached_servers
|
||||
self._use_advanced_pool = use_advanced_pool
|
||||
|
@ -150,38 +148,6 @@ class TokenCache(object):
|
|||
|
||||
self._initialized = True
|
||||
|
||||
def get(self, user_token):
|
||||
"""Check if the token is cached already.
|
||||
|
||||
Returns a tuple. The first element is a list of token IDs, where the
|
||||
first one is the preferred hash.
|
||||
|
||||
The second element is the token data from the cache if the token was
|
||||
cached, otherwise ``None``.
|
||||
|
||||
:raises exc.InvalidToken: if the token is invalid
|
||||
|
||||
"""
|
||||
|
||||
if cms.is_asn1_token(user_token) or cms.is_pkiz(user_token):
|
||||
# user_token is a PKI token that's not hashed.
|
||||
|
||||
token_hashes = list(cms.cms_hash_token(user_token, mode=algo)
|
||||
for algo in self._hash_algorithms)
|
||||
|
||||
for token_hash in token_hashes:
|
||||
cached = self._cache_get(token_hash)
|
||||
if cached:
|
||||
return (token_hashes, cached)
|
||||
|
||||
# The token wasn't found using any hash algorithm.
|
||||
return (token_hashes, None)
|
||||
|
||||
# user_token is either a UUID token or a hashed PKI token.
|
||||
token_id = user_token
|
||||
cached = self._cache_get(token_id)
|
||||
return ([token_id], cached)
|
||||
|
||||
def store(self, token_id, data, expires):
|
||||
"""Put token data into the cache.
|
||||
|
||||
|
@ -249,7 +215,7 @@ class TokenCache(object):
|
|||
# memory cache will handle serialization for us
|
||||
return data
|
||||
|
||||
def _cache_get(self, token_id):
|
||||
def get(self, token_id):
|
||||
"""Return token information from cache.
|
||||
|
||||
If token is invalid raise exc.InvalidToken
|
||||
|
|
|
@ -537,7 +537,7 @@ class GeneralAuthTokenMiddlewareTest(BaseAuthTokenMiddlewareTest,
|
|||
token_cache = self.middleware._token_cache
|
||||
token_cache.initialize({})
|
||||
token_cache._cache_store(token, data)
|
||||
self.assertEqual(token_cache._cache_get(token), data[0])
|
||||
self.assertEqual(token_cache.get(token), data[0])
|
||||
|
||||
@testtools.skipUnless(memcached_available(), 'memcached not available')
|
||||
def test_sign_cache_data(self):
|
||||
|
@ -554,7 +554,7 @@ class GeneralAuthTokenMiddlewareTest(BaseAuthTokenMiddlewareTest,
|
|||
token_cache = self.middleware._token_cache
|
||||
token_cache.initialize({})
|
||||
token_cache._cache_store(token, data)
|
||||
self.assertEqual(token_cache._cache_get(token), data[0])
|
||||
self.assertEqual(token_cache.get(token), data[0])
|
||||
|
||||
@testtools.skipUnless(memcached_available(), 'memcached not available')
|
||||
def test_no_memcache_protection(self):
|
||||
|
@ -570,7 +570,7 @@ class GeneralAuthTokenMiddlewareTest(BaseAuthTokenMiddlewareTest,
|
|||
token_cache = self.middleware._token_cache
|
||||
token_cache.initialize({})
|
||||
token_cache._cache_store(token, data)
|
||||
self.assertEqual(token_cache._cache_get(token), data[0])
|
||||
self.assertEqual(token_cache.get(token), data[0])
|
||||
|
||||
def test_assert_valid_memcache_protection_config(self):
|
||||
# test missing memcache_secret_key
|
||||
|
@ -1039,7 +1039,7 @@ class CommonAuthTokenMiddlewareTest(object):
|
|||
|
||||
def _get_cached_token(self, token, mode='md5'):
|
||||
token_id = cms.cms_hash_token(token, mode=mode)
|
||||
return self.middleware._token_cache._cache_get(token_id)
|
||||
return self.middleware._token_cache.get(token_id)
|
||||
|
||||
def test_memcache(self):
|
||||
req = webob.Request.blank('/')
|
||||
|
@ -2054,7 +2054,7 @@ class TokenExpirationTest(BaseAuthTokenMiddlewareTest):
|
|||
some_time_later = timeutils.strtime(at=(self.now + self.delta))
|
||||
expires = some_time_later
|
||||
self.middleware._token_cache.store(token, data, expires)
|
||||
self.assertEqual(self.middleware._token_cache._cache_get(token), data)
|
||||
self.assertEqual(self.middleware._token_cache.get(token), data)
|
||||
|
||||
def test_cached_token_not_expired_with_old_style_nix_timestamp(self):
|
||||
"""Ensure we cannot retrieve a token from the cache.
|
||||
|
@ -2072,7 +2072,7 @@ class TokenExpirationTest(BaseAuthTokenMiddlewareTest):
|
|||
# Store a unix timestamp in the cache.
|
||||
expires = calendar.timegm(some_time_later.timetuple())
|
||||
token_cache.store(token, data, expires)
|
||||
self.assertIsNone(token_cache._cache_get(token))
|
||||
self.assertIsNone(token_cache.get(token))
|
||||
|
||||
def test_cached_token_expired(self):
|
||||
token = 'mytoken'
|
||||
|
@ -2082,7 +2082,7 @@ class TokenExpirationTest(BaseAuthTokenMiddlewareTest):
|
|||
some_time_earlier = timeutils.strtime(at=(self.now - self.delta))
|
||||
expires = some_time_earlier
|
||||
self.middleware._token_cache.store(token, data, expires)
|
||||
self.assertThat(lambda: self.middleware._token_cache._cache_get(token),
|
||||
self.assertThat(lambda: self.middleware._token_cache.get(token),
|
||||
matchers.raises(exc.InvalidToken))
|
||||
|
||||
def test_cached_token_with_timezone_offset_not_expired(self):
|
||||
|
@ -2094,7 +2094,7 @@ class TokenExpirationTest(BaseAuthTokenMiddlewareTest):
|
|||
some_time_later = self.now - timezone_offset + self.delta
|
||||
expires = timeutils.strtime(some_time_later) + '-02:00'
|
||||
self.middleware._token_cache.store(token, data, expires)
|
||||
self.assertEqual(self.middleware._token_cache._cache_get(token), data)
|
||||
self.assertEqual(self.middleware._token_cache.get(token), data)
|
||||
|
||||
def test_cached_token_with_timezone_offset_expired(self):
|
||||
token = 'mytoken'
|
||||
|
@ -2105,7 +2105,7 @@ class TokenExpirationTest(BaseAuthTokenMiddlewareTest):
|
|||
some_time_earlier = self.now - timezone_offset - self.delta
|
||||
expires = timeutils.strtime(some_time_earlier) + '-02:00'
|
||||
self.middleware._token_cache.store(token, data, expires)
|
||||
self.assertThat(lambda: self.middleware._token_cache._cache_get(token),
|
||||
self.assertThat(lambda: self.middleware._token_cache.get(token),
|
||||
matchers.raises(exc.InvalidToken))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue