diff --git a/keystone/contrib/oauth1/backends/sql.py b/keystone/contrib/oauth1/backends/sql.py index 7713e6e0b7..653c32858d 100644 --- a/keystone/contrib/oauth1/backends/sql.py +++ b/keystone/contrib/oauth1/backends/sql.py @@ -22,6 +22,7 @@ from keystone.common import sql from keystone.common.sql import migration from keystone.contrib.oauth1 import core from keystone import exception +from keystone.openstack.common import jsonutils from keystone.openstack.common import timeutils @@ -38,13 +39,13 @@ class RequestToken(sql.ModelBase, sql.DictBase): __tablename__ = 'request_token' attributes = ['id', 'request_secret', 'verifier', 'authorizing_user_id', 'requested_project_id', - 'requested_roles', 'consumer_id', 'expires_at'] + 'role_ids', 'consumer_id', 'expires_at'] id = sql.Column(sql.String(64), primary_key=True, nullable=False) request_secret = sql.Column(sql.String(64), nullable=False) verifier = sql.Column(sql.String(64), nullable=True) authorizing_user_id = sql.Column(sql.String(64), nullable=True) requested_project_id = sql.Column(sql.String(64), nullable=False) - requested_roles = sql.Column(sql.Text(), nullable=False) + role_ids = sql.Column(sql.Text(), nullable=True) consumer_id = sql.Column(sql.String(64), sql.ForeignKey('consumer.id'), nullable=False, index=True) expires_at = sql.Column(sql.String(64), nullable=True) @@ -60,14 +61,14 @@ class RequestToken(sql.ModelBase, sql.DictBase): class AccessToken(sql.ModelBase, sql.DictBase): __tablename__ = 'access_token' attributes = ['id', 'access_secret', 'authorizing_user_id', - 'project_id', 'requested_roles', 'consumer_id', + 'project_id', 'role_ids', 'consumer_id', 'expires_at'] id = sql.Column(sql.String(64), primary_key=True, nullable=False) access_secret = sql.Column(sql.String(64), nullable=False) authorizing_user_id = sql.Column(sql.String(64), nullable=False, index=True) project_id = sql.Column(sql.String(64), nullable=False) - requested_roles = sql.Column(sql.Text(), nullable=False) + role_ids = sql.Column(sql.Text(), nullable=False) consumer_id = sql.Column(sql.String(64), sql.ForeignKey('consumer.id'), nullable=False) expires_at = sql.Column(sql.String(64), nullable=True) @@ -164,8 +165,7 @@ class OAuth1(sql.Base): session.flush() return core.filter_consumer(consumer_ref.to_dict()) - def create_request_token(self, consumer_id, roles, - project_id, token_duration): + def create_request_token(self, consumer_id, project_id, token_duration): expiry_date = None if token_duration: now = timeutils.utcnow() @@ -179,7 +179,7 @@ class OAuth1(sql.Base): ref['verifier'] = None ref['authorizing_user_id'] = None ref['requested_project_id'] = project_id - ref['requested_roles'] = roles + ref['role_ids'] = None ref['consumer_id'] = consumer_id ref['expires_at'] = expiry_date session = self.get_session() @@ -200,17 +200,20 @@ class OAuth1(sql.Base): token_ref = self._get_request_token(session, request_token_id) return token_ref.to_dict() - def authorize_request_token(self, request_token_id, user_id): + def authorize_request_token(self, request_token_id, user_id, + role_ids): session = self.get_session() with session.begin(): token_ref = self._get_request_token(session, request_token_id) token_dict = token_ref.to_dict() token_dict['authorizing_user_id'] = user_id token_dict['verifier'] = str(random.randint(1000, 9999)) + token_dict['role_ids'] = jsonutils.dumps(role_ids) new_token = RequestToken.from_dict(token_dict) for attr in RequestToken.attributes: - if (attr == 'authorizing_user_id' or attr == 'verifier'): + if (attr == 'authorizing_user_id' or attr == 'verifier' + or attr == 'role_ids'): setattr(token_ref, attr, getattr(new_token, attr)) session.flush() @@ -235,7 +238,7 @@ class OAuth1(sql.Base): ref['access_secret'] = uuid.uuid4().hex ref['authorizing_user_id'] = token_dict['authorizing_user_id'] ref['project_id'] = token_dict['requested_project_id'] - ref['requested_roles'] = token_dict['requested_roles'] + ref['role_ids'] = token_dict['role_ids'] ref['consumer_id'] = token_dict['consumer_id'] ref['expires_at'] = expiry_date token_ref = AccessToken.from_dict(ref) diff --git a/keystone/contrib/oauth1/controllers.py b/keystone/contrib/oauth1/controllers.py index 6ab9b870e1..ae1ccd4b2e 100644 --- a/keystone/contrib/oauth1/controllers.py +++ b/keystone/contrib/oauth1/controllers.py @@ -95,8 +95,8 @@ class AccessTokenCrudV3(controller.V3Controller): formatted_entity = entity.copy() access_token_id = formatted_entity['id'] user_id = "" - if 'requested_roles' in entity: - formatted_entity.pop('requested_roles') + if 'role_ids' in entity: + formatted_entity.pop('role_ids') if 'access_secret' in entity: formatted_entity.pop('access_secret') if 'authorizing_user_id' in entity: @@ -112,7 +112,7 @@ class AccessTokenCrudV3(controller.V3Controller): return formatted_entity -@dependency.requires('oauth_api') +@dependency.requires('oauth_api', 'assignment_api') class AccessTokenRolesV3(controller.V3Controller): collection_name = 'roles' member_name = 'role' @@ -121,30 +121,30 @@ class AccessTokenRolesV3(controller.V3Controller): access_token = self.oauth_api.get_access_token(access_token_id) if access_token['authorizing_user_id'] != user_id: raise exception.NotFound() - roles = access_token['requested_roles'] - roles_refs = jsonutils.loads(roles) - formatted_refs = ([self._format_role_entity(x) for x in roles_refs]) - return AccessTokenRolesV3.wrap_collection(context, formatted_refs) + authed_role_ids = access_token['role_ids'] + authed_role_ids = jsonutils.loads(authed_role_ids) + refs = ([self._format_role_entity(x) for x in authed_role_ids]) + return AccessTokenRolesV3.wrap_collection(context, refs) def get_access_token_role(self, context, user_id, access_token_id, role_id): access_token = self.oauth_api.get_access_token(access_token_id) if access_token['authorizing_user_id'] != user_id: raise exception.Unauthorized(_('User IDs do not match')) - roles = access_token['requested_roles'] - roles_dict = jsonutils.loads(roles) - for role in roles_dict: - if role['id'] == role_id: - role = self._format_role_entity(role) + authed_role_ids = access_token['role_ids'] + authed_role_ids = jsonutils.loads(authed_role_ids) + for authed_role_id in authed_role_ids: + if authed_role_id == role_id: + role = self._format_role_entity(role_id) return AccessTokenRolesV3.wrap_member(context, role) raise exception.RoleNotFound(_('Could not find role')) - def _format_role_entity(self, entity): - - formatted_entity = entity.copy() - if 'description' in entity: + def _format_role_entity(self, role_id): + role = self.assignment_api.get_role(role_id) + formatted_entity = role.copy() + if 'description' in role: formatted_entity.pop('description') - if 'enabled' in entity: + if 'enabled' in role: formatted_entity.pop('enabled') return formatted_entity @@ -159,19 +159,14 @@ class OAuthControllerV3(controller.V3Controller): headers = context['headers'] oauth_headers = oauth1.get_oauth_headers(headers) consumer_id = oauth_headers.get('oauth_consumer_key') - requested_role_ids = headers.get('Requested-Role-Ids') requested_project_id = headers.get('Requested-Project-Id') if not consumer_id: raise exception.ValidationError( attribute='oauth_consumer_key', target='request') - if not requested_role_ids: - raise exception.ValidationError( - attribute='requested_role_ids', target='request') if not requested_project_id: raise exception.ValidationError( attribute='requested_project_id', target='request') - req_role_ids = requested_role_ids.split(',') consumer_ref = self.oauth_api.get_consumer_with_secret(consumer_id) consumer = oauth1.Consumer(key=consumer_ref['id'], secret=consumer_ref['secret']) @@ -182,8 +177,7 @@ class OAuthControllerV3(controller.V3Controller): http_url=url, headers=context['headers'], query_string=context['query_string'], - parameters={'requested_role_ids': requested_role_ids, - 'requested_project_id': requested_project_id}) + parameters={'requested_project_id': requested_project_id}) oauth_server = oauth1.Server() oauth_server.add_signature_method(oauth1.SignatureMethod_HMAC_SHA1()) params = oauth_server.verify_request(oauth_request, @@ -195,27 +189,8 @@ class OAuthControllerV3(controller.V3Controller): msg = _('Non-oauth parameter - project, do not match') raise exception.Unauthorized(message=msg) - roles_params = params['requested_role_ids'] - roles_params_list = roles_params.split(',') - if roles_params_list != req_role_ids: - msg = _('Non-oauth parameter - roles, do not match') - raise exception.Unauthorized(message=msg) - - req_role_list = list() - all_roles = self.identity_api.list_roles() - for role in all_roles: - for req_role in req_role_ids: - if role['id'] == req_role: - req_role_list.append(role) - - if len(req_role_list) == 0: - msg = _('could not find matching roles for provided role ids') - raise exception.Unauthorized(message=msg) - - json_roles = jsonutils.dumps(req_role_list) request_token_duration = CONF.oauth1.request_token_duration token_ref = self.oauth_api.create_request_token(consumer_id, - json_roles, requested_project_id, request_token_duration) @@ -320,7 +295,7 @@ class OAuthControllerV3(controller.V3Controller): return response - def authorize(self, context, request_token_id): + def authorize(self, context, request_token_id, roles): """An authenticated user is going to authorize a request token. As a security precaution, the requested roles must match those in @@ -339,24 +314,26 @@ class OAuthControllerV3(controller.V3Controller): if now > expires: raise exception.Unauthorized(_('Request token is expired')) - req_roles = req_token['requested_roles'] - req_roles_list = jsonutils.loads(req_roles) - - req_set = set() - for x in req_roles_list: - req_set.add(x['id']) + # put the roles in a set for easy comparison + authed_roles = set() + for role in roles: + authed_roles.add(role['id']) # verify the authorizing user has the roles user_token = self.token_api.get_token(context['token_id']) - credentials = user_token['metadata'].copy() - user_roles = credentials.get('roles') user_id = user_token['user'].get('id') + project_id = req_token['requested_project_id'] + user_roles = self.assignment_api.get_roles_for_user_and_project( + user_id, project_id) cred_set = set(user_roles) - if not cred_set.issuperset(req_set): + if not cred_set.issuperset(authed_roles): msg = _('authorizing user does not have role required') raise exception.Unauthorized(message=msg) + # create list of just the id's for the backend + role_list = list(authed_roles) + # verify the user has the project too req_project_id = req_token['requested_project_id'] user_projects = self.assignment_api.list_user_projects(user_id) @@ -371,7 +348,7 @@ class OAuthControllerV3(controller.V3Controller): # finally authorize the token authed_token = self.oauth_api.authorize_request_token( - request_token_id, user_id) + request_token_id, user_id, role_list) to_return = {'token': {'oauth_verifier': authed_token['verifier']}} return to_return diff --git a/keystone/contrib/oauth1/core.py b/keystone/contrib/oauth1/core.py index 03e8707c01..32dfc32713 100644 --- a/keystone/contrib/oauth1/core.py +++ b/keystone/contrib/oauth1/core.py @@ -225,14 +225,12 @@ class Driver(object): """ raise exception.NotImplemented() - def create_request_token(self, consumer_id, requested_roles, - requested_project, request_token_duration): + def create_request_token(self, consumer_id, requested_project, + request_token_duration): """Create request token. :param consumer_id: the id of the consumer :type consumer_id: string - :param requested_roles: requested roles - :type requested_roles: string :param requested_project_id: requested project id :type requested_project_id: string :param request_token_duration: duration of request token @@ -262,13 +260,15 @@ class Driver(object): """ raise exception.NotImplemented() - def authorize_request_token(self, request_id, user_id): + def authorize_request_token(self, request_id, user_id, role_ids): """Authorize request token. :param request_id: the id of the request token, to be authorized :type request_id: string :param user_id: the id of the authorizing user :type user_id: string + :param role_ids: list of role ids to authorize + :type role_ids: list returns: verifier """ diff --git a/keystone/contrib/oauth1/migrate_repo/versions/004_request_token_roles_nullable.py b/keystone/contrib/oauth1/migrate_repo/versions/004_request_token_roles_nullable.py new file mode 100644 index 0000000000..aec13b8d1f --- /dev/null +++ b/keystone/contrib/oauth1/migrate_repo/versions/004_request_token_roles_nullable.py @@ -0,0 +1,36 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 OpenStack Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import sqlalchemy as sql + + +def upgrade(migrate_engine): + meta = sql.MetaData() + meta.bind = migrate_engine + request_token_table = sql.Table('request_token', meta, autoload=True) + request_token_table.c.requested_roles.alter(name="role_ids", nullable=True) + access_token_table = sql.Table('access_token', meta, autoload=True) + access_token_table.c.requested_roles.alter(name="role_ids") + + +def downgrade(migrate_engine): + meta = sql.MetaData() + meta.bind = migrate_engine + request_token_table = sql.Table('request_token', meta, autoload=True) + request_token_table.c.role_ids.alter(name="requested_roles", + nullable=False) + access_token_table = sql.Table('access_token', meta, autoload=True) + access_token_table.c.role_ids.alter(name="requested_roles") diff --git a/keystone/tests/test_v3_oauth1.py b/keystone/tests/test_v3_oauth1.py index 6404f92e29..8d0c2a4fe1 100644 --- a/keystone/tests/test_v3_oauth1.py +++ b/keystone/tests/test_v3_oauth1.py @@ -66,9 +66,8 @@ class OAuth1Tests(test_v3.RestfulTestCase): token=token, **kw) - def _create_request_token(self, consumer, role, project_id): - params = {'requested_role_ids': role, - 'requested_project_id': project_id} + def _create_request_token(self, consumer, project_id): + params = {'requested_project_id': project_id} headers = {'Content-Type': 'application/json'} url = '/OS-OAUTH1/request_token' oreq = self._oauth_request( @@ -220,7 +219,6 @@ class OAuthFlowTests(OAuth1Tests): self.assertIsNotNone(self.consumer.key) url, headers = self._create_request_token(self.consumer, - self.role_id, self.project_id) content = self.post(url, headers=headers) credentials = urlparse.parse_qs(content.result) @@ -230,7 +228,8 @@ class OAuthFlowTests(OAuth1Tests): self.assertIsNotNone(self.request_token.key) url = self._authorize_request_token(request_key) - resp = self.put(url, expected_status=200) + body = {'roles': [{'id': self.role_id}]} + resp = self.put(url, body=body, expected_status=200) self.verifier = resp.result['token']['oauth_verifier'] self.request_token.set_verifier(self.verifier) @@ -446,7 +445,6 @@ class MaliciousOAuth1Tests(OAuth1Tests): consumer_id = consumer.get('id') consumer = oauth1.Consumer(consumer_id, "bad_secret") url, headers = self._create_request_token(consumer, - self.role_id, self.project_id) self.post(url, headers=headers, expected_status=500) @@ -456,11 +454,11 @@ class MaliciousOAuth1Tests(OAuth1Tests): consumer_secret = consumer.get('secret') consumer = oauth1.Consumer(consumer_id, consumer_secret) url, headers = self._create_request_token(consumer, - self.role_id, self.project_id) self.post(url, headers=headers) url = self._authorize_request_token("bad_key") - self.put(url, expected_status=404) + body = {'roles': [{'id': self.role_id}]} + self.put(url, body=body, expected_status=404) def test_bad_verifier(self): consumer = self._create_single_consumer() @@ -469,7 +467,6 @@ class MaliciousOAuth1Tests(OAuth1Tests): consumer = oauth1.Consumer(consumer_id, consumer_secret) url, headers = self._create_request_token(consumer, - self.role_id, self.project_id) content = self.post(url, headers=headers) credentials = urlparse.parse_qs(content.result) @@ -478,7 +475,8 @@ class MaliciousOAuth1Tests(OAuth1Tests): request_token = oauth1.Token(request_key, request_secret) url = self._authorize_request_token(request_key) - resp = self.put(url, expected_status=200) + body = {'roles': [{'id': self.role_id}]} + resp = self.put(url, body=body, expected_status=200) verifier = resp.result['token']['oauth_verifier'] self.assertIsNotNone(verifier) @@ -487,17 +485,6 @@ class MaliciousOAuth1Tests(OAuth1Tests): request_token) self.post(url, headers=headers, expected_status=401) - def test_bad_requested_roles(self): - consumer = self._create_single_consumer() - consumer_id = consumer.get('id') - consumer_secret = consumer.get('secret') - consumer = oauth1.Consumer(consumer_id, consumer_secret) - - url, headers = self._create_request_token(consumer, - "bad_role", - self.project_id) - self.post(url, headers=headers, expected_status=401) - def test_bad_authorizing_roles(self): consumer = self._create_single_consumer() consumer_id = consumer.get('id') @@ -505,7 +492,6 @@ class MaliciousOAuth1Tests(OAuth1Tests): consumer = oauth1.Consumer(consumer_id, consumer_secret) url, headers = self._create_request_token(consumer, - self.role_id, self.project_id) content = self.post(url, headers=headers) credentials = urlparse.parse_qs(content.result) @@ -515,7 +501,9 @@ class MaliciousOAuth1Tests(OAuth1Tests): self.project_id, self.role_id) url = self._authorize_request_token(request_key) - self.admin_request(path=url, method='PUT', expected_status=404) + body = {'roles': [{'id': self.role_id}]} + self.admin_request(path=url, method='PUT', + body=body, expected_status=404) def test_expired_authorizing_request_token(self): CONF.oauth1.request_token_duration = -1 @@ -527,7 +515,6 @@ class MaliciousOAuth1Tests(OAuth1Tests): self.assertIsNotNone(self.consumer.key) url, headers = self._create_request_token(self.consumer, - self.role_id, self.project_id) content = self.post(url, headers=headers) credentials = urlparse.parse_qs(content.result) @@ -537,7 +524,8 @@ class MaliciousOAuth1Tests(OAuth1Tests): self.assertIsNotNone(self.request_token.key) url = self._authorize_request_token(request_key) - self.put(url, expected_status=401) + body = {'roles': [{'id': self.role_id}]} + self.put(url, body=body, expected_status=401) def test_expired_creating_keystone_token(self): CONF.oauth1.access_token_duration = -1 @@ -548,7 +536,6 @@ class MaliciousOAuth1Tests(OAuth1Tests): self.assertIsNotNone(self.consumer.key) url, headers = self._create_request_token(self.consumer, - self.role_id, self.project_id) content = self.post(url, headers=headers) credentials = urlparse.parse_qs(content.result) @@ -558,7 +545,8 @@ class MaliciousOAuth1Tests(OAuth1Tests): self.assertIsNotNone(self.request_token.key) url = self._authorize_request_token(request_key) - resp = self.put(url, expected_status=200) + body = {'roles': [{'id': self.role_id}]} + resp = self.put(url, body=body, expected_status=200) self.verifier = resp.result['token']['oauth_verifier'] self.request_token.set_verifier(self.verifier) diff --git a/keystone/token/providers/uuid.py b/keystone/token/providers/uuid.py index 262e10e4ac..d849eefd33 100644 --- a/keystone/token/providers/uuid.py +++ b/keystone/token/providers/uuid.py @@ -213,7 +213,15 @@ class V3TokenDataHelper(object): return if access_token: - token_data['roles'] = json.loads(access_token['requested_roles']) + filtered_roles = [] + authed_role_ids = json.loads(access_token['role_ids']) + all_roles = self.identity_api.list_roles() + for role in all_roles: + for authed_role in authed_role_ids: + if authed_role == role['id']: + filtered_roles.append({'id': role['id'], + 'name': role['name']}) + token_data['roles'] = filtered_roles return if CONF.trust.enabled and trust: