Refresh token support
The refresh token may now be used to obtain a new access token if needed. The refresh token can be used only once. Sample request: POST /v1/auth/token?grant_type=refresh_token&refresh_token=<the_token> The response will be similar with the one if authroization code is used. Introduced expiration for refresh tokens. The migration deletes existing refresh tokens as they do not have a valide expiration information set. Probably noone should notice that because client is not using refresh tokens currently. Change-Id: Ie0924888f66ca956caa43b04b8501e6fa8e9371e
This commit is contained in:
parent
f6fef5f15d
commit
136ce74ded
|
@ -37,7 +37,10 @@ lock_path = $state_path/lock
|
|||
# openid_url = https://login.launchpad.net/+openid
|
||||
|
||||
# Time in seconds before an access_token expires
|
||||
# token_ttl = 3600
|
||||
# access_token_ttl = 3600
|
||||
|
||||
# Time in seconds before an refresh_token expires
|
||||
# refresh_token_ttl = 604800
|
||||
|
||||
# List paging configuration options.
|
||||
# page_size_maximum = 500
|
||||
|
|
|
@ -27,9 +27,13 @@ CONF = cfg.CONF
|
|||
LOG = logging.getLogger(__name__)
|
||||
|
||||
TOKEN_OPTS = [
|
||||
cfg.IntOpt("token_ttl",
|
||||
default=3600,
|
||||
help="Time in seconds before an access_token expires")
|
||||
cfg.IntOpt("access_token_ttl",
|
||||
default=60 * 60, # One hour
|
||||
help="Time in seconds before an access_token expires"),
|
||||
|
||||
cfg.IntOpt("refresh_token_ttl",
|
||||
default=60 * 60 * 24 * 7, # One week
|
||||
help="Time in seconds before an refresh_token expires")
|
||||
]
|
||||
|
||||
CONF.register_opts(TOKEN_OPTS)
|
||||
|
@ -165,12 +169,31 @@ class SkeletonValidator(RequestValidator):
|
|||
return (grant_type == "authorization_code"
|
||||
or grant_type == "refresh_token")
|
||||
|
||||
def _resolve_user_id(self, request):
|
||||
|
||||
# Try authorization code
|
||||
code = request._params.get("code")
|
||||
if code:
|
||||
code_info = self.token_storage.get_authorization_code_info(code)
|
||||
return code_info.user_id
|
||||
|
||||
# Try refresh token
|
||||
refresh_token = request._params.get("refresh_token")
|
||||
refresh_token_entry = self.token_storage.get_refresh_token_info(
|
||||
refresh_token)
|
||||
if refresh_token_entry:
|
||||
return refresh_token_entry.user_id
|
||||
|
||||
return None
|
||||
|
||||
def save_bearer_token(self, token, request, *args, **kwargs):
|
||||
"""Save all token information to the storage."""
|
||||
|
||||
code = request._params["code"]
|
||||
code_info = self.token_storage.get_authorization_code_info(code)
|
||||
user_id = code_info.user_id
|
||||
user_id = self._resolve_user_id(request)
|
||||
|
||||
# If a refresh_token was used to obtain a new access_token, it should
|
||||
# be removed.
|
||||
self.invalidate_refresh_token(request)
|
||||
|
||||
self.token_storage.save_token(access_token=token["access_token"],
|
||||
expires_in=token["expires_in"],
|
||||
|
@ -202,13 +225,37 @@ class SkeletonValidator(RequestValidator):
|
|||
"""
|
||||
return ["user"]
|
||||
|
||||
def rotate_refresh_token(self, request):
|
||||
"""The refresh token should be single use."""
|
||||
|
||||
return True
|
||||
|
||||
def validate_refresh_token(self, refresh_token, client, request, *args,
|
||||
**kwargs):
|
||||
"""Check that the refresh token exists in the db."""
|
||||
|
||||
return self.token_storage.check_refresh_token(refresh_token)
|
||||
|
||||
def invalidate_refresh_token(self, request):
|
||||
"""Remove a used token from the storage."""
|
||||
|
||||
refresh_token = request._params.get("refresh_token")
|
||||
|
||||
# The request may have no token in parameters which means that the
|
||||
# authorization code was used.
|
||||
if not refresh_token:
|
||||
return
|
||||
|
||||
self.token_storage.invalidate_refresh_token(refresh_token)
|
||||
|
||||
|
||||
class OpenIdConnectServer(WebApplicationServer):
|
||||
|
||||
def __init__(self, request_validator):
|
||||
token_ttl = CONF.token_ttl
|
||||
super(OpenIdConnectServer, self).__init__(request_validator,
|
||||
token_expires_in=token_ttl)
|
||||
access_token_ttl = CONF.access_token_ttl
|
||||
super(OpenIdConnectServer, self).__init__(
|
||||
request_validator,
|
||||
token_expires_in=access_token_ttl)
|
||||
|
||||
validator = SkeletonValidator()
|
||||
SERVER = OpenIdConnectServer(validator)
|
||||
|
|
|
@ -15,10 +15,15 @@
|
|||
|
||||
import datetime
|
||||
|
||||
from oslo.config import cfg
|
||||
|
||||
from storyboard.api.auth.token_storage import storage
|
||||
from storyboard.db.api import auth as auth_api
|
||||
|
||||
|
||||
CONF = cfg.CONF
|
||||
|
||||
|
||||
class DBTokenStorage(storage.StorageBase):
|
||||
def save_authorization_code(self, authorization_code, user_id):
|
||||
values = {
|
||||
|
@ -47,9 +52,16 @@ class DBTokenStorage(storage.StorageBase):
|
|||
"user_id": user_id
|
||||
}
|
||||
|
||||
# Oauthlib does not provide a separate expiration time for a
|
||||
# refresh_token so taking it from config directly.
|
||||
refresh_expires_in = CONF.refresh_token_ttl
|
||||
|
||||
refresh_token_values = {
|
||||
"refresh_token": refresh_token,
|
||||
"user_id": user_id
|
||||
"user_id": user_id,
|
||||
"expires_in": refresh_expires_in,
|
||||
"expires_at": datetime.datetime.now() + datetime.timedelta(
|
||||
seconds=refresh_expires_in),
|
||||
}
|
||||
|
||||
auth_api.access_token_save(access_token_values)
|
||||
|
@ -72,3 +84,21 @@ class DBTokenStorage(storage.StorageBase):
|
|||
|
||||
def remove_token(self, access_token):
|
||||
auth_api.access_token_delete(access_token)
|
||||
|
||||
def check_refresh_token(self, refresh_token):
|
||||
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
|
||||
|
||||
if not refresh_token_entry:
|
||||
return False
|
||||
|
||||
if datetime.datetime.now() > refresh_token_entry.expires_at:
|
||||
auth_api.refresh_token_delete(refresh_token)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_refresh_token_info(self, refresh_token):
|
||||
return auth_api.refresh_token_get(refresh_token)
|
||||
|
||||
def invalidate_refresh_token(self, refresh_token):
|
||||
auth_api.refresh_token_delete(refresh_token)
|
||||
|
|
|
@ -75,6 +75,29 @@ class MemoryTokenStorage(storage.StorageBase):
|
|||
def remove_token(self, token):
|
||||
pass
|
||||
|
||||
def check_refresh_token(self, refresh_token):
|
||||
for token_info in self.token_set:
|
||||
if token_info.refresh_token == refresh_token:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_refresh_token_info(self, refresh_token):
|
||||
for token_info in self.token_set:
|
||||
if token_info.refresh_token == refresh_token:
|
||||
return token_info
|
||||
|
||||
return None
|
||||
|
||||
def invalidate_refresh_token(self, refresh_token):
|
||||
token_entry = None
|
||||
for entry in self.token_set:
|
||||
if entry.refresh_token == refresh_token:
|
||||
token_entry = entry
|
||||
break
|
||||
|
||||
self.token_set.remove(token_entry)
|
||||
|
||||
def save_authorization_code(self, authorization_code, user_id):
|
||||
self.auth_code_set.add(AuthorizationCode(authorization_code, user_id))
|
||||
|
||||
|
|
|
@ -111,6 +111,34 @@ class StorageBase(object):
|
|||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def check_refresh_token(self, refresh_token):
|
||||
"""This method should say if a given token exists in the storage and
|
||||
that it has not expired yet.
|
||||
|
||||
@param refresh_token: The token to be checked.
|
||||
@return bool
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_refresh_token_info(self, refresh_token):
|
||||
"""Get the Bearer token from the storage.
|
||||
|
||||
@param refresh_token: The token to get the information about.
|
||||
@return object: The object should contain all fields associated with
|
||||
the token (refresh_token, expires_in, user_id).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def invalidate_refresh_token(self, refresh_token):
|
||||
"""Remove a token from the storage.
|
||||
|
||||
@param refresh_token: A refresh token
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
STORAGE = None
|
||||
|
||||
|
|
|
@ -70,28 +70,63 @@ class AuthController(rest.RestController):
|
|||
|
||||
return response
|
||||
|
||||
@pecan.expose()
|
||||
def token(self):
|
||||
"""Access token endpoint."""
|
||||
|
||||
def _access_token_by_code(self):
|
||||
auth_code = request.params.get("code")
|
||||
code_info = storage.get_storage()\
|
||||
code_info = storage.get_storage() \
|
||||
.get_authorization_code_info(auth_code)
|
||||
headers, body, code = SERVER.create_token_response(
|
||||
uri=request.url,
|
||||
http_method=request.method,
|
||||
body=request.body,
|
||||
headers=request.headers)
|
||||
response.headers = dict((str(k), str(v))
|
||||
for k, v in headers.iteritems())
|
||||
response.status_code = code
|
||||
json_body = json.loads(body)
|
||||
|
||||
# Update a body with user_id only if a response is 2xx
|
||||
if code / 100 == 2:
|
||||
json_body.update({
|
||||
'id_token': code_info.user_id
|
||||
})
|
||||
|
||||
response.body = json.dumps(json_body)
|
||||
return response
|
||||
|
||||
def _access_token_by_refresh_token(self):
|
||||
refresh_token = request.params.get("refresh_token")
|
||||
refresh_token_info = storage.get_storage().get_refresh_token_info(
|
||||
refresh_token)
|
||||
|
||||
headers, body, code = SERVER.create_token_response(
|
||||
uri=request.url,
|
||||
http_method=request.method,
|
||||
body=request.body,
|
||||
headers=request.headers)
|
||||
|
||||
response.headers = dict((str(k), str(v))
|
||||
for k, v in headers.iteritems())
|
||||
response.status_code = code
|
||||
|
||||
json_body = json.loads(body)
|
||||
json_body.update({
|
||||
'id_token': code_info.user_id
|
||||
})
|
||||
|
||||
# Update a body with user_id only if a response is 2xx
|
||||
if code / 100 == 2:
|
||||
json_body.update({
|
||||
'id_token': refresh_token_info.user_id
|
||||
})
|
||||
|
||||
response.body = json.dumps(json_body)
|
||||
|
||||
return response
|
||||
|
||||
@pecan.expose()
|
||||
def token(self):
|
||||
"""Token endpoint."""
|
||||
|
||||
grant_type = request.params.get("grant_type")
|
||||
if grant_type == "authorization_code":
|
||||
# Serve an access token having an authorization code
|
||||
return self._access_token_by_code()
|
||||
|
||||
if grant_type == "refresh_token":
|
||||
# Serve an access token having a refresh token
|
||||
return self._access_token_by_refresh_token()
|
||||
|
|
|
@ -57,3 +57,10 @@ def refresh_token_get(refresh_token):
|
|||
|
||||
def refresh_token_save(values):
|
||||
return api_base.entity_create(models.RefreshToken, values)
|
||||
|
||||
|
||||
def refresh_token_delete(refresh_token):
|
||||
del_token = refresh_token_get(refresh_token)
|
||||
|
||||
if del_token:
|
||||
api_base.entity_hard_delete(models.RefreshToken, del_token.id)
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
"""The refresh token should also have expiration fields.
|
||||
|
||||
Revision ID: 019
|
||||
Revises: 018
|
||||
Create Date: 2014-05-21 11:17:16.360987
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '019'
|
||||
down_revision = '018'
|
||||
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
def upgrade(active_plugins=None, options=None):
|
||||
|
||||
# Deleting old tokens because they don't have a valid expiration
|
||||
# information.
|
||||
bind = op.get_bind()
|
||||
bind.execute(sa.delete(table='refreshtokens'))
|
||||
|
||||
op.add_column('refreshtokens', sa.Column('expires_at', sa.DateTime(),
|
||||
nullable=False))
|
||||
op.add_column('refreshtokens', sa.Column('expires_in', sa.Integer(),
|
||||
nullable=False))
|
||||
### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade(active_plugins=None, options=None):
|
||||
|
||||
op.drop_column('refreshtokens', 'expires_in')
|
||||
op.drop_column('refreshtokens', 'expires_at')
|
||||
### end Alembic commands ###
|
|
@ -223,6 +223,8 @@ class RefreshToken(Base):
|
|||
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
||||
refresh_token = Column(Unicode(100), nullable=False)
|
||||
expires_in = Column(Integer, nullable=False)
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
|
||||
|
||||
def _story_build_summary_query():
|
||||
|
|
Loading…
Reference in New Issue