diff --git a/oslo_context/context.py b/oslo_context/context.py index bd1a6c4..adad06d 100644 --- a/oslo_context/context.py +++ b/oslo_context/context.py @@ -53,7 +53,7 @@ class RequestContext(object): def __init__(self, auth_token=None, user=None, tenant=None, domain=None, user_domain=None, project_domain=None, is_admin=False, read_only=False, show_deleted=False, request_id=None, - resource_uuid=None, overwrite=True): + resource_uuid=None, overwrite=True, roles=None): """Initialize the RequestContext :param overwrite: Set to False to ensure that the greenthread local @@ -69,6 +69,7 @@ class RequestContext(object): self.read_only = read_only self.show_deleted = show_deleted self.resource_uuid = resource_uuid + self.roles = roles or [] if not request_id: request_id = generate_request_id() self.request_id = request_id @@ -99,6 +100,7 @@ class RequestContext(object): 'auth_token': self.auth_token, 'request_id': self.request_id, 'resource_uuid': self.resource_uuid, + 'roles': self.roles, 'user_identity': user_idt} def get_logging_values(self): @@ -143,6 +145,9 @@ class RequestContext(object): kwargs.setdefault('project_domain', environ.get('HTTP_X_PROJECT_DOMAIN_ID')) + roles = environ.get('HTTP_X_ROLES') + kwargs.setdefault('roles', roles.split(',') if roles else []) + return cls(**kwargs) diff --git a/oslo_context/tests/test_context.py b/oslo_context/tests/test_context.py index 4555da3..eff31c9 100644 --- a/oslo_context/tests/test_context.py +++ b/oslo_context/tests/test_context.py @@ -135,12 +135,14 @@ class ContextTest(test_base.BaseTestCase): project_id = uuid.uuid4().hex user_domain_id = uuid.uuid4().hex project_domain_id = uuid.uuid4().hex + roles = [uuid.uuid4().hex, uuid.uuid4().hex, uuid.uuid4().hex] environ = {'HTTP_X_AUTH_TOKEN': auth_token, 'HTTP_X_USER_ID': user_id, 'HTTP_X_PROJECT_ID': project_id, 'HTTP_X_USER_DOMAIN_ID': user_domain_id, - 'HTTP_X_PROJECT_DOMAIN_ID': project_domain_id} + 'HTTP_X_PROJECT_DOMAIN_ID': project_domain_id, + 'HTTP_X_ROLES': ','.join(roles)} ctx = context.RequestContext.from_environ(environ) @@ -149,6 +151,14 @@ class ContextTest(test_base.BaseTestCase): self.assertEqual(project_id, ctx.tenant) self.assertEqual(user_domain_id, ctx.user_domain) self.assertEqual(project_domain_id, ctx.project_domain) + self.assertEqual(roles, ctx.roles) + + def test_from_environ_no_roles(self): + ctx = context.RequestContext.from_environ(environ={}) + self.assertEqual([], ctx.roles) + + ctx = context.RequestContext.from_environ(environ={'HTTP_X_ROLES': ''}) + self.assertEqual([], ctx.roles) def test_from_function_and_args(self): ctx = context.RequestContext(user="user1") @@ -214,6 +224,7 @@ class ContextTest(test_base.BaseTestCase): self.assertIn('request_id', d) self.assertIn('resource_uuid', d) self.assertIn('user_identity', d) + self.assertIn('roles', d) self.assertEqual(auth_token, d['auth_token']) self.assertEqual(tenant, d['tenant']) @@ -228,6 +239,7 @@ class ContextTest(test_base.BaseTestCase): user_identity = "%s %s %s %s %s" % (user, tenant, domain, user_domain, project_domain) self.assertEqual(user_identity, d['user_identity']) + self.assertEqual([], d['roles']) def test_get_logging_values(self): auth_token = "token1"