Refresh token support

The refresh token may now be used to obtain a new access token if
needed. The refresh token can be used only once.

Sample request:
POST /v1/auth/token?grant_type=refresh_token&refresh_token=<the_token>

The response will be similar with the one if authroization code is used.

Introduced expiration for refresh tokens.
The migration deletes existing refresh tokens as they do not have a
valide expiration information set. Probably noone should notice that
because client is not using refresh tokens currently.

Change-Id: Ie0924888f66ca956caa43b04b8501e6fa8e9371e
This commit is contained in:
Nikita Konovalov 2014-05-20 17:31:29 +04:00
parent f6fef5f15d
commit 136ce74ded
9 changed files with 245 additions and 21 deletions

View File

@ -37,7 +37,10 @@ lock_path = $state_path/lock
# openid_url = https://login.launchpad.net/+openid
# Time in seconds before an access_token expires
# token_ttl = 3600
# access_token_ttl = 3600
# Time in seconds before an refresh_token expires
# refresh_token_ttl = 604800
# List paging configuration options.
# page_size_maximum = 500

View File

@ -27,9 +27,13 @@ CONF = cfg.CONF
LOG = logging.getLogger(__name__)
TOKEN_OPTS = [
cfg.IntOpt("token_ttl",
default=3600,
help="Time in seconds before an access_token expires")
cfg.IntOpt("access_token_ttl",
default=60 * 60, # One hour
help="Time in seconds before an access_token expires"),
cfg.IntOpt("refresh_token_ttl",
default=60 * 60 * 24 * 7, # One week
help="Time in seconds before an refresh_token expires")
]
CONF.register_opts(TOKEN_OPTS)
@ -165,12 +169,31 @@ class SkeletonValidator(RequestValidator):
return (grant_type == "authorization_code"
or grant_type == "refresh_token")
def _resolve_user_id(self, request):
# Try authorization code
code = request._params.get("code")
if code:
code_info = self.token_storage.get_authorization_code_info(code)
return code_info.user_id
# Try refresh token
refresh_token = request._params.get("refresh_token")
refresh_token_entry = self.token_storage.get_refresh_token_info(
refresh_token)
if refresh_token_entry:
return refresh_token_entry.user_id
return None
def save_bearer_token(self, token, request, *args, **kwargs):
"""Save all token information to the storage."""
code = request._params["code"]
code_info = self.token_storage.get_authorization_code_info(code)
user_id = code_info.user_id
user_id = self._resolve_user_id(request)
# If a refresh_token was used to obtain a new access_token, it should
# be removed.
self.invalidate_refresh_token(request)
self.token_storage.save_token(access_token=token["access_token"],
expires_in=token["expires_in"],
@ -202,13 +225,37 @@ class SkeletonValidator(RequestValidator):
"""
return ["user"]
def rotate_refresh_token(self, request):
"""The refresh token should be single use."""
return True
def validate_refresh_token(self, refresh_token, client, request, *args,
**kwargs):
"""Check that the refresh token exists in the db."""
return self.token_storage.check_refresh_token(refresh_token)
def invalidate_refresh_token(self, request):
"""Remove a used token from the storage."""
refresh_token = request._params.get("refresh_token")
# The request may have no token in parameters which means that the
# authorization code was used.
if not refresh_token:
return
self.token_storage.invalidate_refresh_token(refresh_token)
class OpenIdConnectServer(WebApplicationServer):
def __init__(self, request_validator):
token_ttl = CONF.token_ttl
super(OpenIdConnectServer, self).__init__(request_validator,
token_expires_in=token_ttl)
access_token_ttl = CONF.access_token_ttl
super(OpenIdConnectServer, self).__init__(
request_validator,
token_expires_in=access_token_ttl)
validator = SkeletonValidator()
SERVER = OpenIdConnectServer(validator)

View File

@ -15,10 +15,15 @@
import datetime
from oslo.config import cfg
from storyboard.api.auth.token_storage import storage
from storyboard.db.api import auth as auth_api
CONF = cfg.CONF
class DBTokenStorage(storage.StorageBase):
def save_authorization_code(self, authorization_code, user_id):
values = {
@ -47,9 +52,16 @@ class DBTokenStorage(storage.StorageBase):
"user_id": user_id
}
# Oauthlib does not provide a separate expiration time for a
# refresh_token so taking it from config directly.
refresh_expires_in = CONF.refresh_token_ttl
refresh_token_values = {
"refresh_token": refresh_token,
"user_id": user_id
"user_id": user_id,
"expires_in": refresh_expires_in,
"expires_at": datetime.datetime.now() + datetime.timedelta(
seconds=refresh_expires_in),
}
auth_api.access_token_save(access_token_values)
@ -72,3 +84,21 @@ class DBTokenStorage(storage.StorageBase):
def remove_token(self, access_token):
auth_api.access_token_delete(access_token)
def check_refresh_token(self, refresh_token):
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
if not refresh_token_entry:
return False
if datetime.datetime.now() > refresh_token_entry.expires_at:
auth_api.refresh_token_delete(refresh_token)
return False
return True
def get_refresh_token_info(self, refresh_token):
return auth_api.refresh_token_get(refresh_token)
def invalidate_refresh_token(self, refresh_token):
auth_api.refresh_token_delete(refresh_token)

View File

@ -75,6 +75,29 @@ class MemoryTokenStorage(storage.StorageBase):
def remove_token(self, token):
pass
def check_refresh_token(self, refresh_token):
for token_info in self.token_set:
if token_info.refresh_token == refresh_token:
return True
return False
def get_refresh_token_info(self, refresh_token):
for token_info in self.token_set:
if token_info.refresh_token == refresh_token:
return token_info
return None
def invalidate_refresh_token(self, refresh_token):
token_entry = None
for entry in self.token_set:
if entry.refresh_token == refresh_token:
token_entry = entry
break
self.token_set.remove(token_entry)
def save_authorization_code(self, authorization_code, user_id):
self.auth_code_set.add(AuthorizationCode(authorization_code, user_id))

View File

@ -111,6 +111,34 @@ class StorageBase(object):
"""
pass
@abc.abstractmethod
def check_refresh_token(self, refresh_token):
"""This method should say if a given token exists in the storage and
that it has not expired yet.
@param refresh_token: The token to be checked.
@return bool
"""
pass
@abc.abstractmethod
def get_refresh_token_info(self, refresh_token):
"""Get the Bearer token from the storage.
@param refresh_token: The token to get the information about.
@return object: The object should contain all fields associated with
the token (refresh_token, expires_in, user_id).
"""
pass
@abc.abstractmethod
def invalidate_refresh_token(self, refresh_token):
"""Remove a token from the storage.
@param refresh_token: A refresh token
"""
pass
STORAGE = None

View File

@ -70,28 +70,63 @@ class AuthController(rest.RestController):
return response
@pecan.expose()
def token(self):
"""Access token endpoint."""
def _access_token_by_code(self):
auth_code = request.params.get("code")
code_info = storage.get_storage()\
code_info = storage.get_storage() \
.get_authorization_code_info(auth_code)
headers, body, code = SERVER.create_token_response(
uri=request.url,
http_method=request.method,
body=request.body,
headers=request.headers)
response.headers = dict((str(k), str(v))
for k, v in headers.iteritems())
response.status_code = code
json_body = json.loads(body)
# Update a body with user_id only if a response is 2xx
if code / 100 == 2:
json_body.update({
'id_token': code_info.user_id
})
response.body = json.dumps(json_body)
return response
def _access_token_by_refresh_token(self):
refresh_token = request.params.get("refresh_token")
refresh_token_info = storage.get_storage().get_refresh_token_info(
refresh_token)
headers, body, code = SERVER.create_token_response(
uri=request.url,
http_method=request.method,
body=request.body,
headers=request.headers)
response.headers = dict((str(k), str(v))
for k, v in headers.iteritems())
response.status_code = code
json_body = json.loads(body)
json_body.update({
'id_token': code_info.user_id
})
# Update a body with user_id only if a response is 2xx
if code / 100 == 2:
json_body.update({
'id_token': refresh_token_info.user_id
})
response.body = json.dumps(json_body)
return response
@pecan.expose()
def token(self):
"""Token endpoint."""
grant_type = request.params.get("grant_type")
if grant_type == "authorization_code":
# Serve an access token having an authorization code
return self._access_token_by_code()
if grant_type == "refresh_token":
# Serve an access token having a refresh token
return self._access_token_by_refresh_token()

View File

@ -57,3 +57,10 @@ def refresh_token_get(refresh_token):
def refresh_token_save(values):
return api_base.entity_create(models.RefreshToken, values)
def refresh_token_delete(refresh_token):
del_token = refresh_token_get(refresh_token)
if del_token:
api_base.entity_hard_delete(models.RefreshToken, del_token.id)

View File

@ -0,0 +1,49 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
"""The refresh token should also have expiration fields.
Revision ID: 019
Revises: 018
Create Date: 2014-05-21 11:17:16.360987
"""
# revision identifiers, used by Alembic.
revision = '019'
down_revision = '018'
from alembic import op
import sqlalchemy as sa
def upgrade(active_plugins=None, options=None):
# Deleting old tokens because they don't have a valid expiration
# information.
bind = op.get_bind()
bind.execute(sa.delete(table='refreshtokens'))
op.add_column('refreshtokens', sa.Column('expires_at', sa.DateTime(),
nullable=False))
op.add_column('refreshtokens', sa.Column('expires_in', sa.Integer(),
nullable=False))
### end Alembic commands ###
def downgrade(active_plugins=None, options=None):
op.drop_column('refreshtokens', 'expires_in')
op.drop_column('refreshtokens', 'expires_at')
### end Alembic commands ###

View File

@ -223,6 +223,8 @@ class RefreshToken(Base):
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
refresh_token = Column(Unicode(100), nullable=False)
expires_in = Column(Integer, nullable=False)
expires_at = Column(DateTime, nullable=False)
def _story_build_summary_query():