Cleanup token hashes generated by cache

We hash a token under multiple configurable algorithms for revocations
and caching and so must check all of them. This hash generation was
being done and returned by the cache check and reused in revocation
check. This is an unusual pattern and requires the cache object to have
knowledge of token types and how to hash them.

Change validate so that we generate the hash values in the main function
and pass that to the cache and revocation functions. This moves the
function to get the first available token id from the list to the main
file. This means the cache interface is much more normal with get and
set id functions and encapsulates the hash list function in the one
file.

Change-Id: I9194dbd052674f64122ff74329ce292a342512d3
This commit is contained in:
Jamie Lennox 2015-04-16 09:30:28 +10:00
parent de7a54efd3
commit 530b5cbe5c
3 changed files with 58 additions and 55 deletions

View File

@ -507,6 +507,7 @@ class AuthProtocol(object):
self._delay_auth_decision = self._conf_get('delay_auth_decision')
self._include_service_catalog = self._conf_get(
'include_service_catalog')
self._hash_algorithms = self._conf_get('hash_algorithms')
self._identity_server = self._create_identity_server()
@ -721,6 +722,41 @@ class AuthProtocol(object):
start_response('401 Unauthorized', resp.headers)
return resp.body
def _token_hashes(self, token):
"""Generate a list of hashes that the current token may be cached as.
With PKI tokens we have multiple hashing algorithms that we test with
revocations. This generates that whole list.
The first element of this list is the preferred algorithm and is what
new cache values should be saved as.
:param str token: The token being presented by a user.
:returns: list of str token hashes.
"""
if cms.is_asn1_token(token) or cms.is_pkiz(token):
return list(cms.cms_hash_token(token, mode=algo)
for algo in self._hash_algorithms)
else:
return [token]
def _cache_get_hashes(self, token_hashes):
"""Check if the token is cached already.
Functions takes a list of hashes that might be in the cache and matches
the first one that is present. If nothing is found in the cache it
returns None.
:returns: token data if found else None.
"""
for token in token_hashes:
cached = self._token_cache.get(token)
if cached:
return cached
def _validate_token(self, token, env):
"""Authenticate user token
@ -730,11 +766,12 @@ class AuthProtocol(object):
:raises exc.InvalidToken: if token is rejected
"""
token_id = None
token_hashes = None
try:
token_ids, cached = self._token_cache.get(token)
token_id = token_ids[0]
token_hashes = self._token_hashes(token)
cached = self._cache_get_hashes(token_hashes)
if cached:
# Token was retrieved from the cache. In this case, there's no
# need to check that the token is expired because the cache
@ -747,7 +784,7 @@ class AuthProtocol(object):
# A token stored in Memcached might have been revoked
# regardless of initial mechanism used to validate it,
# and needs to be checked.
self._revocations.check(token_ids)
self._revocations.check(token_hashes)
self._confirm_token_bind(data, env)
else:
verified = None
@ -755,9 +792,10 @@ class AuthProtocol(object):
# checked that it's not expired, and also put in the cache.
try:
if cms.is_pkiz(token):
verified = self._verify_pkiz_token(token, token_ids)
verified = self._verify_pkiz_token(token, token_hashes)
elif cms.is_asn1_token(token):
verified = self._verify_signed_token(token, token_ids)
verified = self._verify_signed_token(token,
token_hashes)
except exceptions.CertificateConfigError:
self._LOG.warn(_LW('Fetch certificate config failed, '
'fallback to online validation.'))
@ -775,7 +813,7 @@ class AuthProtocol(object):
# verify_token fails for expired tokens.
expires = _get_token_expiration(data)
self._confirm_token_bind(data, env)
self._token_cache.store(token_id, data, expires)
self._token_cache.store(token_hashes[0], data, expires)
return data
except (exceptions.ConnectionRefused, exceptions.RequestTimeout):
self._LOG.debug('Token validation failure.', exc_info=True)
@ -785,8 +823,8 @@ class AuthProtocol(object):
raise
except Exception:
self._LOG.debug('Token validation failure.', exc_info=True)
if token_id:
self._token_cache.store_invalid(token_id)
if token_hashes:
self._token_cache.store_invalid(token_hashes[0])
self._LOG.warn(_LW('Authorization failed for token'))
raise exc.InvalidToken(_('Token authorization failed'))
@ -1090,7 +1128,6 @@ class AuthProtocol(object):
cache_kwargs = dict(
cache_time=int(self._conf_get('token_cache_time')),
hash_algorithms=self._conf_get('hash_algorithms'),
env_cache_name=self._conf_get('cache'),
memcached_servers=self._conf_get('memcached_servers'),
use_advanced_pool=self._conf_get('memcache_use_advanced_pool'),

View File

@ -12,7 +12,6 @@
import contextlib
from keystoneclient.common import cms
from oslo_serialization import jsonutils
from oslo_utils import timeutils
import six
@ -97,7 +96,7 @@ class TokenCache(object):
_CACHE_KEY_TEMPLATE = 'tokens/%s'
_INVALID_INDICATOR = 'invalid'
def __init__(self, log, cache_time=None, hash_algorithms=None,
def __init__(self, log, cache_time=None,
env_cache_name=None, memcached_servers=None,
use_advanced_pool=False, memcache_pool_dead_retry=None,
memcache_pool_maxsize=None, memcache_pool_unused_timeout=None,
@ -105,7 +104,6 @@ class TokenCache(object):
memcache_pool_socket_timeout=None):
self._LOG = log
self._cache_time = cache_time
self._hash_algorithms = hash_algorithms
self._env_cache_name = env_cache_name
self._memcached_servers = memcached_servers
self._use_advanced_pool = use_advanced_pool
@ -150,38 +148,6 @@ class TokenCache(object):
self._initialized = True
def get(self, user_token):
"""Check if the token is cached already.
Returns a tuple. The first element is a list of token IDs, where the
first one is the preferred hash.
The second element is the token data from the cache if the token was
cached, otherwise ``None``.
:raises exc.InvalidToken: if the token is invalid
"""
if cms.is_asn1_token(user_token) or cms.is_pkiz(user_token):
# user_token is a PKI token that's not hashed.
token_hashes = list(cms.cms_hash_token(user_token, mode=algo)
for algo in self._hash_algorithms)
for token_hash in token_hashes:
cached = self._cache_get(token_hash)
if cached:
return (token_hashes, cached)
# The token wasn't found using any hash algorithm.
return (token_hashes, None)
# user_token is either a UUID token or a hashed PKI token.
token_id = user_token
cached = self._cache_get(token_id)
return ([token_id], cached)
def store(self, token_id, data, expires):
"""Put token data into the cache.
@ -249,7 +215,7 @@ class TokenCache(object):
# memory cache will handle serialization for us
return data
def _cache_get(self, token_id):
def get(self, token_id):
"""Return token information from cache.
If token is invalid raise exc.InvalidToken

View File

@ -537,7 +537,7 @@ class GeneralAuthTokenMiddlewareTest(BaseAuthTokenMiddlewareTest,
token_cache = self.middleware._token_cache
token_cache.initialize({})
token_cache._cache_store(token, data)
self.assertEqual(token_cache._cache_get(token), data[0])
self.assertEqual(token_cache.get(token), data[0])
@testtools.skipUnless(memcached_available(), 'memcached not available')
def test_sign_cache_data(self):
@ -554,7 +554,7 @@ class GeneralAuthTokenMiddlewareTest(BaseAuthTokenMiddlewareTest,
token_cache = self.middleware._token_cache
token_cache.initialize({})
token_cache._cache_store(token, data)
self.assertEqual(token_cache._cache_get(token), data[0])
self.assertEqual(token_cache.get(token), data[0])
@testtools.skipUnless(memcached_available(), 'memcached not available')
def test_no_memcache_protection(self):
@ -570,7 +570,7 @@ class GeneralAuthTokenMiddlewareTest(BaseAuthTokenMiddlewareTest,
token_cache = self.middleware._token_cache
token_cache.initialize({})
token_cache._cache_store(token, data)
self.assertEqual(token_cache._cache_get(token), data[0])
self.assertEqual(token_cache.get(token), data[0])
def test_assert_valid_memcache_protection_config(self):
# test missing memcache_secret_key
@ -1039,7 +1039,7 @@ class CommonAuthTokenMiddlewareTest(object):
def _get_cached_token(self, token, mode='md5'):
token_id = cms.cms_hash_token(token, mode=mode)
return self.middleware._token_cache._cache_get(token_id)
return self.middleware._token_cache.get(token_id)
def test_memcache(self):
req = webob.Request.blank('/')
@ -2054,7 +2054,7 @@ class TokenExpirationTest(BaseAuthTokenMiddlewareTest):
some_time_later = timeutils.strtime(at=(self.now + self.delta))
expires = some_time_later
self.middleware._token_cache.store(token, data, expires)
self.assertEqual(self.middleware._token_cache._cache_get(token), data)
self.assertEqual(self.middleware._token_cache.get(token), data)
def test_cached_token_not_expired_with_old_style_nix_timestamp(self):
"""Ensure we cannot retrieve a token from the cache.
@ -2072,7 +2072,7 @@ class TokenExpirationTest(BaseAuthTokenMiddlewareTest):
# Store a unix timestamp in the cache.
expires = calendar.timegm(some_time_later.timetuple())
token_cache.store(token, data, expires)
self.assertIsNone(token_cache._cache_get(token))
self.assertIsNone(token_cache.get(token))
def test_cached_token_expired(self):
token = 'mytoken'
@ -2082,7 +2082,7 @@ class TokenExpirationTest(BaseAuthTokenMiddlewareTest):
some_time_earlier = timeutils.strtime(at=(self.now - self.delta))
expires = some_time_earlier
self.middleware._token_cache.store(token, data, expires)
self.assertThat(lambda: self.middleware._token_cache._cache_get(token),
self.assertThat(lambda: self.middleware._token_cache.get(token),
matchers.raises(exc.InvalidToken))
def test_cached_token_with_timezone_offset_not_expired(self):
@ -2094,7 +2094,7 @@ class TokenExpirationTest(BaseAuthTokenMiddlewareTest):
some_time_later = self.now - timezone_offset + self.delta
expires = timeutils.strtime(some_time_later) + '-02:00'
self.middleware._token_cache.store(token, data, expires)
self.assertEqual(self.middleware._token_cache._cache_get(token), data)
self.assertEqual(self.middleware._token_cache.get(token), data)
def test_cached_token_with_timezone_offset_expired(self):
token = 'mytoken'
@ -2105,7 +2105,7 @@ class TokenExpirationTest(BaseAuthTokenMiddlewareTest):
some_time_earlier = self.now - timezone_offset - self.delta
expires = timeutils.strtime(some_time_earlier) + '-02:00'
self.middleware._token_cache.store(token, data, expires)
self.assertThat(lambda: self.middleware._token_cache._cache_get(token),
self.assertThat(lambda: self.middleware._token_cache.get(token),
matchers.raises(exc.InvalidToken))