diff --git a/keystone/tests/test_backend_memcache.py b/keystone/tests/test_backend_memcache.py index fb94d7a637..964d5b42a3 100644 --- a/keystone/tests/test_backend_memcache.py +++ b/keystone/tests/test_backend_memcache.py @@ -163,30 +163,51 @@ class MemcacheToken(tests.TestCase, test_backend.TokenTests): # get expired tokens as well as valid tokens. token_api.list_tokens() # will not return any expired tokens in the list. user_key = self.token_api.driver._prefix_user_id(user_id) - user_record = self.token_api.driver.client.get(user_key) - user_token_list = jsonutils.loads('[%s]' % user_record) - self.assertEquals(len(user_token_list), 2) - expired_token_ptk = self.token_api.driver._prefix_token_id( - expired_token_id) - expired_token = self.token_api.driver.client.get(expired_token_ptk) - expired_token['expires'] = (timeutils.utcnow() - expire_delta) - self.token_api.driver.client.set(expired_token_ptk, expired_token) + user_token_list = self.token_api.driver.client.get(user_key) + self.assertEqual(len(user_token_list), 2) + # user_token_list is a list of (token, expiry) tuples + expired_idx = [i[0] for i in user_token_list].index(expired_token_id) + # set the token as expired. + user_token_list[expired_idx] = (user_token_list[expired_idx][0], + timeutils.utcnow() - expire_delta) + self.token_api.driver.client.set(user_key, user_token_list) self.token_api.create_token(second_valid_token_id, second_valid_data) - user_record = self.token_api.driver.client.get(user_key) - user_token_list = jsonutils.loads('[%s]' % user_record) - self.assertEquals(len(user_token_list), 2) + user_token_list = self.token_api.driver.client.get(user_key) + self.assertEqual(len(user_token_list), 2) + + def test_convert_token_list_from_json(self): + token_list = ','.join(['"%s"' % uuid.uuid4().hex for x in xrange(5)]) + token_list_loaded = jsonutils.loads('[%s]' % token_list) + converted_list = self.token_api.driver._convert_user_index_from_json( + token_list, 'test-key') + for idx, item in enumerate(converted_list): + token_id, expiry = item + self.assertEqual(token_id, token_list_loaded[idx]) + self.assertIsInstance(expiry, datetime.datetime) + + def test_convert_token_list_from_json_non_string(self): + token_list = self.token_api.driver._convert_user_index_from_json( + None, 'test-key') + self.assertEqual([], token_list) + + def test_convert_token_list_from_json_invalid_json(self): + token_list = self.token_api.driver._convert_user_index_from_json( + 'invalid_json_list', 'test-key') + self.assertEqual([], token_list) def test_cas_failure(self): + expire_delta = datetime.timedelta(seconds=86400) self.token_api.driver.client.reject_cas = True token_id = uuid.uuid4().hex user_id = unicode(uuid.uuid4().hex) + token_data = {'expires': timeutils.utcnow() + expire_delta, + 'id': token_id} user_key = self.token_api.driver._prefix_user_id(user_id) - token_data = jsonutils.dumps(token_id) self.assertRaises( exception.UnexpectedError, self.token_api.driver._update_user_list_with_cas, - user_key, token_data) + user_key, token_id, token_data) def test_token_expire_timezone(self): diff --git a/keystone/token/backends/memcache.py b/keystone/token/backends/memcache.py index 2582c49c6b..a6fe82694e 100644 --- a/keystone/token/backends/memcache.py +++ b/keystone/token/backends/memcache.py @@ -16,6 +16,7 @@ from __future__ import absolute_import import copy +import datetime import memcache @@ -83,19 +84,50 @@ class Token(token.Driver): kwargs['time'] = expires_ts self.client.set(ptk, data_copy, **kwargs) if 'id' in data['user']: - token_data = jsonutils.dumps(token_id) user_id = data['user']['id'] user_key = self._prefix_user_id(user_id) # Append the new token_id to the token-index-list stored in the # user-key within memcache. - self._update_user_list_with_cas(user_key, token_data) + self._update_user_list_with_cas(user_key, token_id, data_copy) return copy.deepcopy(data_copy) - def _update_user_list_with_cas(self, user_key, token_id): + def _convert_user_index_from_json(self, token_list, user_key): + try: + # NOTE(morganfainberg): Try loading in the old format + # of the list. + token_list = jsonutils.loads('[%s]' % token_list) + + # NOTE(morganfainberg): Build a delta based upon the + # token TTL configured. Since we are using the old + # format index-list, we will create a "fake" expiration + # that should be further in the future than the actual + # expiry. To avoid locking up keystone trying to + # communicate to memcached, it is better to use a fake + # value. The logic that utilizes this list already + # knows how to handle the case of tokens that are + # no longer valid being included. + delta = datetime.timedelta( + seconds=CONF.token.expiration) + new_expiry = timeutils.normalize_time( + timeutils.utcnow()) + delta + + for idx, token_id in enumerate(token_list): + token_list[idx] = (token_id, new_expiry) + + except Exception: + # NOTE(morganfainberg): Catch any errors thrown here. There is + # nothing the admin or operator needs to do in this case, but + # it should be logged that there was an error and some action was + # taken to correct it + LOG.info(_('Unable to convert user-token-index to new format; ' + 'clearing user token index record "%s".'), user_key) + token_list = [] + return token_list + + def _update_user_list_with_cas(self, user_key, token_id, token_data): cas_retry = 0 max_cas_retry = CONF.memcache.max_compare_and_set_retry - current_time = timeutils.normalize_time( - timeutils.parse_isotime(timeutils.isotime())) + current_time = timeutils.normalize_time(timeutils.utcnow()) self.client.reset_cas() @@ -110,35 +142,30 @@ class Token(token.Driver): # case memcache is down or something horrible happens we don't # iterate forever trying to compare and set the new value. cas_retry += 1 - record = self.client.gets(user_key) + token_list = self.client.gets(user_key) filtered_list = [] - if record is not None: - token_list = jsonutils.loads('[%s]' % record) - for token_i in token_list: - ptk = self._prefix_token_id(token_i) - token_ref = self.client.get(ptk) - if not token_ref: - # skip tokens that do not exist in memcache + if token_list is not None: + if not isinstance(token_list, list): + token_list = self._convert_user_index_from_json(token_list, + user_key) + for token_i, expiry in token_list: + expires_at = timeutils.normalize_time(expiry) + if expires_at < current_time: + # skip tokens that are expired. continue - if 'expires' in token_ref: - expires_at = timeutils.normalize_time( - token_ref['expires']) - if expires_at < current_time: - # skip tokens that are expired. - continue - # Add the still valid token_id to the list. - filtered_list.append(jsonutils.dumps(token_i)) - # Add the new token_id to the list. - filtered_list.append(token_id) + filtered_list.append((token_i, expiry)) + # Add the new token_id and expiry. + filtered_list.append( + (token_id, timeutils.normalize_time(token_data['expires']))) # Use compare-and-set (cas) to set the new value for the # token-index-list for the user-key. Cas is used to prevent race # conditions from causing the loss of valid token ids from this # list. - if self.client.cas(user_key, ','.join(filtered_list)): + if self.client.cas(user_key, filtered_list): msg = _('Successful set of token-index-list for user-key ' '"%(user_key)s", #%(count)d records') LOG.debug(msg, {'user_key': user_key, @@ -182,9 +209,17 @@ class Token(token.Driver): consumer_id=None): tokens = [] user_key = self._prefix_user_id(user_id) - user_record = self.client.get(user_key) or "" - token_list = jsonutils.loads('[%s]' % user_record) - for token_id in token_list: + current_time = timeutils.normalize_time(timeutils.utcnow()) + token_list = self.client.get(user_key) or [] + if not isinstance(token_list, list): + # NOTE(morganfainberg): This is for compatibility for old-format + # token-lists that were a JSON string of just token_ids. This code + # will reference the underlying expires directly from the + # token_ref vs in this list, so setting to none just ensures the + # loop works as expected. + token_list = [(i, None) for i in + jsonutils.loads('[%s]' % token_list)] + for token_id, expiry in token_list: ptk = self._prefix_token_id(token_id) token_ref = self.client.get(ptk) if token_ref: @@ -208,6 +243,11 @@ class Token(token.Driver): except KeyError: continue + if (timeutils.normalize_time(token_ref['expires']) < + current_time): + # Skip expired tokens. + continue + tokens.append(token_id) return tokens