diff --git a/heat/engine/clients/os/keystone/fake_keystoneclient.py b/heat/engine/clients/os/keystone/fake_keystoneclient.py index 524d09a8ed..9715c952f0 100644 --- a/heat/engine/clients/os/keystone/fake_keystoneclient.py +++ b/heat/engine/clients/os/keystone/fake_keystoneclient.py @@ -91,6 +91,13 @@ class FakeKeystoneClient(object): trust_id='atrust', trustor_user_id=self.user_id) + def regenerate_trust_context(self): + return context.RequestContext(username=self.username, + password=self.password, + is_admin=False, + trust_id='atrust', + trustor_user_id=self.user_id) + def delete_trust(self, trust_id): pass diff --git a/heat/engine/clients/os/keystone/heat_keystoneclient.py b/heat/engine/clients/os/keystone/heat_keystoneclient.py index d99c5fab45..9754491a96 100644 --- a/heat/engine/clients/os/keystone/heat_keystoneclient.py +++ b/heat/engine/clients/os/keystone/heat_keystoneclient.py @@ -188,19 +188,7 @@ class KsClientWrapper(object): return client - def create_trust_context(self): - """Create a trust using the trustor identity in the current context. - - The trust is created with the trustee as the heat service user. - - If the current context already contains a trust_id, we do nothing - and return the current context. - - Returns a context containing the new trust_id. - """ - if self.context.trust_id: - return self.context - + def _create_trust_context(self, trustor_user_id, trustor_proj_id): # We need the service admin user ID (not name), as the trustor user # can't lookup the ID in keystoneclient unless they're admin # workaround this by getting the user_id from admin_client @@ -211,9 +199,6 @@ class KsClientWrapper(object): LOG.error("Domain admin client authentication failed") raise exception.AuthorizationFailure() - trustor_user_id = self.context.auth_plugin.get_user_id(self.session) - trustor_proj_id = self.context.auth_plugin.get_project_id(self.session) - role_kw = {} # inherit the roles of the trustor, unless set trusts_delegated_roles if cfg.CONF.trusts_delegated_roles: @@ -245,6 +230,23 @@ class KsClientWrapper(object): trust_context.trustor_user_id = trustor_user_id return trust_context + def create_trust_context(self): + """Create a trust using the trustor identity in the current context. + + The trust is created with the trustee as the heat service user. + + If the current context already contains a trust_id, we do nothing + and return the current context. + + Returns a context containing the new trust_id. + """ + if self.context.trust_id: + return self.context + + trustor_user_id = self.context.auth_plugin.get_user_id(self.session) + trustor_proj_id = self.context.auth_plugin.get_project_id(self.session) + return self._create_trust_context(trustor_user_id, trustor_proj_id) + def delete_trust(self, trust_id): """Delete the specified trust.""" try: @@ -252,6 +254,23 @@ class KsClientWrapper(object): except (ks_exception.NotFound, ks_exception.Unauthorized): pass + def regenerate_trust_context(self): + """Regenerate a trust using the trustor identity of current user_id. + + The trust is created with the trustee as the heat service user. + + Returns a context containing the new trust_id. + """ + old_trust_id = self.context.trust_id + trustor_user_id = self.context.auth_plugin.get_user_id(self.session) + trustor_proj_id = self.context.auth_plugin.get_project_id(self.session) + trust_context = self._create_trust_context(trustor_user_id, + trustor_proj_id) + + if old_trust_id: + self.delete_trust(old_trust_id) + return trust_context + def _get_username(self, username): if(len(username) > 255): LOG.warning("Truncating the username %s to the last 255 " diff --git a/heat/engine/service.py b/heat/engine/service.py index f3e6571d66..d16699f3e7 100644 --- a/heat/engine/service.py +++ b/heat/engine/service.py @@ -1022,9 +1022,11 @@ class EngineService(service.ServiceBase): LOG.info('Updating stack %s', db_stack.name) if cfg.CONF.reauthentication_auth_method == 'trusts': current_stack = parser.Stack.load( - cnxt, stack=db_stack, use_stored_context=True) + cnxt, stack=db_stack, use_stored_context=True, + check_refresh_cred=True) else: - current_stack = parser.Stack.load(cnxt, stack=db_stack) + current_stack = parser.Stack.load(cnxt, stack=db_stack, + check_refresh_cred=True) self.resource_enforcer.enforce_stack(current_stack, is_registered_policy=True) diff --git a/heat/engine/stack.py b/heat/engine/stack.py index be3323b8bc..aa65b92e8f 100644 --- a/heat/engine/stack.py +++ b/heat/engine/stack.py @@ -126,7 +126,7 @@ class Stack(collections.Mapping): nested_depth=0, strict_validate=True, convergence=False, current_traversal=None, tags=None, prev_raw_template_id=None, current_deps=None, cache_data=None, - deleted_time=None, converge=False): + deleted_time=None, converge=False, refresh_cred=False): """Initialise the Stack. @@ -188,6 +188,9 @@ class Stack(collections.Mapping): self.thread_group_mgr = None self.converge = converge + # This flag is use to check whether credential needs to refresh or not + self.refresh_cred = refresh_cred + # strict_validate can be used to disable value validation # in the resource properties schema, this is useful when # performing validation when properties reference attributes @@ -541,10 +544,35 @@ class Stack(collections.Mapping): {'res': str(res), 'err': str(exc)}) + @classmethod + def _check_refresh_cred(cls, context, stack): + if stack.user_creds_id: + creds_obj = ucreds_object.UserCreds.get_by_id( + context, stack.user_creds_id) + creds = creds_obj.obj_to_primitive()["versioned_object.data"] + stored_context = common_context.StoredContext.from_dict(creds) + + if cfg.CONF.deferred_auth_method == 'trusts': + old_trustor_proj_id = stored_context.tenant_id + old_trustor_user_id = stored_context.trustor_user_id + + trustor_user_id = context.auth_plugin.get_user_id( + context.clients.client('keystone').session) + trustor_proj_id = context.auth_plugin.get_project_id( + context.clients.client('keystone').session) + return False if ( + old_trustor_user_id == trustor_user_id) and ( + old_trustor_proj_id == trustor_proj_id + ) else True + + # Should not raise error or allow refresh credential when we can't find + # user_creds_id in stack + return False + @classmethod def load(cls, context, stack_id=None, stack=None, show_deleted=True, use_stored_context=False, force_reload=False, cache_data=None, - load_template=True): + load_template=True, check_refresh_cred=False): """Retrieve a Stack from the database.""" if stack is None: stack = stack_object.Stack.get_by_id( @@ -555,13 +583,22 @@ class Stack(collections.Mapping): message = _('No stack exists with id "%s"') % str(stack_id) raise exception.NotFound(message) + refresh_cred = False + if check_refresh_cred and ( + cfg.CONF.deferred_auth_method == 'trusts' + ): + if cls._check_refresh_cred(context, stack): + use_stored_context = False + refresh_cred = True + if force_reload: stack.refresh() return cls._from_db(context, stack, use_stored_context=use_stored_context, cache_data=cache_data, - load_template=load_template) + load_template=load_template, + refresh_cred=refresh_cred) @classmethod def load_all(cls, context, limit=None, marker=None, sort_keys=None, @@ -595,7 +632,7 @@ class Stack(collections.Mapping): @classmethod def _from_db(cls, context, stack, use_stored_context=False, cache_data=None, - load_template=True): + load_template=True, refresh_cred=False): if load_template: template = tmpl.Template.load( context, stack.raw_template_id, stack.raw_template) @@ -619,7 +656,8 @@ class Stack(collections.Mapping): prev_raw_template_id=stack.prev_raw_template_id, current_deps=stack.current_deps, cache_data=cache_data, nested_depth=stack.nested_depth, - deleted_time=stack.deleted_at) + deleted_time=stack.deleted_at, + refresh_cred=refresh_cred) def get_kwargs_for_cloning(self, keep_status=False, only_db=False, keep_tags=False): @@ -687,6 +725,17 @@ class Stack(collections.Mapping): s['raw_template_id'] = self.t.id if self.id is not None: + if self.refresh_cred: + keystone = self.clients.client('keystone') + trust_ctx = keystone.regenerate_trust_context() + new_creds = ucreds_object.UserCreds.create(trust_ctx) + s['user_creds_id'] = new_creds.id + + self._delete_user_cred(raise_keystone_exception=True) + + self.user_creds_id = new_creds.id + self.refresh_cred = False + if exp_trvsl is None and not ignore_traversal_check: exp_trvsl = self.current_traversal @@ -1840,11 +1889,10 @@ class Stack(collections.Mapping): LOG.exception("Failed to retrieve user_creds") return None - def _delete_credentials(self, stack_status, reason, abandon): + def _delete_user_cred(self, stack_status=None, reason=None, + raise_keystone_exception=False): # Cleanup stored user_creds so they aren't accessible via # the soft-deleted stack which remains in the DB - # The stack_status and reason passed in are current values, which - # may get rewritten and returned from this method if self.user_creds_id: user_creds = self._try_get_user_creds() # If we created a trust, delete it @@ -1874,6 +1922,8 @@ class Stack(collections.Mapping): # Without this, they would need to issue # an additional stack-delete LOG.exception("Error deleting trust") + if raise_keystone_exception: + raise # Delete the stored credentials try: @@ -1883,13 +1933,18 @@ class Stack(collections.Mapping): LOG.info("Tried to delete user_creds that do not exist " "(stack=%(stack)s user_creds_id=%(uc)s)", {'stack': self.id, 'uc': self.user_creds_id}) + self.user_creds_id = None + return stack_status, reason - try: - self.user_creds_id = None - self.store() - except exception.NotFound: - LOG.info("Tried to store a stack that does not exist %s", - self.id) + def _delete_credentials(self, stack_status, reason, abandon): + # The stack_status and reason passed in are current values, which + # may get rewritten and returned from this method + stack_status, reason = self._delete_user_cred(stack_status, reason) + try: + self.store() + except exception.NotFound: + LOG.info("Tried to store a stack that does not exist %s", + self.id) # If the stack has a domain project, delete it if self.stack_user_project_id and not abandon: diff --git a/heat/tests/clients/test_heat_client.py b/heat/tests/clients/test_heat_client.py index 912fbdf75f..dabd004361 100644 --- a/heat/tests/clients/test_heat_client.py +++ b/heat/tests/clients/test_heat_client.py @@ -521,6 +521,65 @@ class KeystoneClientTest(common.HeatTestCase): self.assertRaises(exception.AuthorizationFailure, heat_keystoneclient.KeystoneClient, ctx) + def test_regenerate_trust_context_with_no_exist_trust_id(self): + + """Test regenerate_trust_context.""" + + class MockTrust(object): + id = 'dtrust123' + + mock_ks_auth, mock_auth_ref = self._stubs_auth(user_id='5678', + project_id='42', + stub_trust_context=True, + stub_admin_auth=True) + + cfg.CONF.set_override('deferred_auth_method', 'trusts') + + trustor_roles = ['heat_stack_owner', 'admin', '__member__'] + trustee_roles = trustor_roles + mock_auth_ref.user_id = '5678' + mock_auth_ref.project_id = '42' + + self.mock_ks_v3_client.trusts.create.return_value = MockTrust() + + ctx = utils.dummy_context(roles=trustor_roles) + ctx.trust_id = None + heat_ks_client = heat_keystoneclient.KeystoneClient(ctx) + trust_context = heat_ks_client.regenerate_trust_context() + self.assertEqual('dtrust123', trust_context.trust_id) + self.assertEqual('5678', trust_context.trustor_user_id) + ks_loading.load_auth_from_conf_options.assert_called_once_with( + cfg.CONF, 'trustee', trust_id=None) + self.mock_ks_v3_client.trusts.create.assert_called_once_with( + trustor_user='5678', + trustee_user='1234', + project='42', + impersonation=True, + allow_redelegation=False, + role_names=trustee_roles) + self.assertEqual(0, self.mock_ks_v3_client.trusts.delete.call_count) + + def test_regenerate_trust_context_with_exist_trust_id(self): + + """Test regenerate_trust_context.""" + + self._stubs_auth(method='trust') + cfg.CONF.set_override('deferred_auth_method', 'trusts') + + ctx = utils.dummy_context() + ctx.trust_id = 'atrust123' + ctx.trustor_user_id = 'trustor_user_id' + + class MockTrust(object): + id = 'dtrust123' + + self.mock_ks_v3_client.trusts.create.return_value = MockTrust() + heat_ks_client = heat_keystoneclient.KeystoneClient(ctx) + trust_context = heat_ks_client.regenerate_trust_context() + self.assertEqual('dtrust123', trust_context.trust_id) + self.mock_ks_v3_client.trusts.delete.assert_called_once_with( + ctx.trust_id) + def test_create_trust_context_trust_id(self): """Test create_trust_context with existing trust_id.""" diff --git a/heat/tests/convergence/test_converge.py b/heat/tests/convergence/test_converge.py index 081a8d43c2..8e1b104802 100644 --- a/heat/tests/convergence/test_converge.py +++ b/heat/tests/convergence/test_converge.py @@ -11,13 +11,15 @@ # License for the specific language governing permissions and limitations # under the License. +from oslo_config import cfg + +from heat.common import context from heat.engine import resource from heat.tests import common from heat.tests.convergence.framework import fake_resource from heat.tests.convergence.framework import processes from heat.tests.convergence.framework import scenario from heat.tests.convergence.framework import testutils -from oslo_config import cfg class ScenarioTest(common.HeatTestCase): @@ -27,6 +29,7 @@ class ScenarioTest(common.HeatTestCase): def setUp(self): super(ScenarioTest, self).setUp() + self.patchobject(context, 'StoredContext') resource._register_class('OS::Heat::TestResource', fake_resource.TestResource) self.procs = processes.Processes() diff --git a/heat/tests/engine/service/test_stack_action.py b/heat/tests/engine/service/test_stack_action.py index 48a704b676..74a1d96e58 100644 --- a/heat/tests/engine/service/test_stack_action.py +++ b/heat/tests/engine/service/test_stack_action.py @@ -159,6 +159,7 @@ class StackServiceUpdateActionsNotSupportedTest(common.HeatTestCase): self.ctx, old_stack.identifier(), template, params, None, {}) self.assertEqual(exception.NotSupported, ex.exc_info[0]) - mock_load.assert_called_once_with(self.ctx, stack=s) + mock_load.assert_called_once_with(self.ctx, stack=s, + check_refresh_cred=True) old_stack.delete() diff --git a/heat/tests/engine/service/test_stack_events.py b/heat/tests/engine/service/test_stack_events.py index a1da41582e..ef1907c47b 100644 --- a/heat/tests/engine/service/test_stack_events.py +++ b/heat/tests/engine/service/test_stack_events.py @@ -16,6 +16,7 @@ from unittest import mock from oslo_config import cfg from oslo_messaging import conffixture +from heat.common import context from heat.engine import resource as res from heat.engine.resources.aws.ec2 import instance as instances from heat.engine import service @@ -32,6 +33,7 @@ class StackEventTest(common.HeatTestCase): def setUp(self): super(StackEventTest, self).setUp() + self.patchobject(context, 'StoredContext') self.ctx = utils.dummy_context(tenant_id='stack_event_test_tenant') self.eng = service.EngineService('a-host', 'a-topic') diff --git a/heat/tests/engine/service/test_stack_update.py b/heat/tests/engine/service/test_stack_update.py index 17aaabb8b7..f1df832af1 100644 --- a/heat/tests/engine/service/test_stack_update.py +++ b/heat/tests/engine/service/test_stack_update.py @@ -18,6 +18,7 @@ from oslo_config import cfg from oslo_messaging import conffixture from oslo_messaging.rpc import dispatcher +from heat.common import context from heat.common import environment_util as env_util from heat.common import exception from heat.common import messaging @@ -45,6 +46,7 @@ class ServiceStackUpdateTest(common.HeatTestCase): def setUp(self): super(ServiceStackUpdateTest, self).setUp() self.useFixture(conffixture.ConfFixture(cfg.CONF)) + self.patchobject(context, 'StoredContext') self.ctx = utils.dummy_context() self.man = service.EngineService('a-host', 'a-topic') self.man.thread_group_mgr = tools.DummyThreadGroupManager() @@ -103,7 +105,8 @@ class ServiceStackUpdateTest(common.HeatTestCase): username='test_username', converge=True ) - mock_load.assert_called_once_with(self.ctx, stack=s) + mock_load.assert_called_once_with(self.ctx, stack=s, + check_refresh_cred=True) mock_validate.assert_called_once_with() def _test_stack_update_with_environment_files(self, stack_name, @@ -222,7 +225,8 @@ class ServiceStackUpdateTest(common.HeatTestCase): username='test_username', converge=False ) - mock_load.assert_called_once_with(self.ctx, stack=s) + mock_load.assert_called_once_with(self.ctx, stack=s, + check_refresh_cred=True) mock_validate.assert_called_once_with() def test_stack_update_existing_parameters(self): @@ -555,7 +559,8 @@ resources: mock_validate.assert_called_once_with() mock_tmpl.assert_called_once_with(template, files=None) mock_env.assert_called_once_with(params) - mock_load.assert_called_once_with(self.ctx, stack=s) + mock_load.assert_called_once_with(self.ctx, stack=s, + check_refresh_cred=True) mock_stack.assert_called_once_with( self.ctx, stk.name, stk.t, convergence=False, @@ -703,7 +708,8 @@ resources: username='test_username', converge=False ) - mock_load.assert_called_once_with(self.ctx, stack=s) + mock_load.assert_called_once_with(self.ctx, stack=s, + check_refresh_cred=True) mock_validate.assert_called_once_with() def test_stack_update_stack_id_equal(self): @@ -750,9 +756,11 @@ resources: old_stack['A'].properties['Foo']) self.assertEqual(create_stack['A'].id, old_stack['A'].id) - mock_load.assert_called_once_with(self.ctx, stack=s) + mock_load.assert_called_once_with(self.ctx, stack=s, + check_refresh_cred=True) def test_stack_update_exceeds_resource_limit(self): + self.patchobject(context, 'StoredContext') stack_name = 'test_stack_update_exceeds_resource_limit' params = {} tpl = {'HeatTemplateFormatVersion': '2012-12-12', @@ -822,7 +830,8 @@ resources: username='test_username', converge=False ) - mock_load.assert_called_once_with(self.ctx, stack=s) + mock_load.assert_called_once_with(self.ctx, stack=s, + check_refresh_cred=True) mock_validate.assert_called_once_with() def test_stack_update_nonexist(self): @@ -886,7 +895,8 @@ resources: user_creds_id=u'1', username='test_username', converge=False ) - mock_load.assert_called_once_with(self.ctx, stack=s) + mock_load.assert_called_once_with(self.ctx, stack=s, + check_refresh_cred=True) def test_stack_update_existing_template(self): '''Update a stack using the same template.''' diff --git a/heat/tests/test_engine_service.py b/heat/tests/test_engine_service.py index aa6b62818a..875d44dcd2 100644 --- a/heat/tests/test_engine_service.py +++ b/heat/tests/test_engine_service.py @@ -289,7 +289,7 @@ class StackConvergenceServiceCreateUpdateTest(common.HeatTestCase): self.assertIsInstance(result, dict) self.assertTrue(result['stack_id']) parser.Stack.load.assert_called_once_with( - self.ctx, stack=mock.ANY) + self.ctx, stack=mock.ANY, check_refresh_cred=True) templatem.Template.assert_called_once_with(template, files=None) environment.Environment.assert_called_once_with(params) diff --git a/heat/tests/test_stack.py b/heat/tests/test_stack.py index 0e8a0b4e44..790d0f4cae 100644 --- a/heat/tests/test_stack.py +++ b/heat/tests/test_stack.py @@ -478,7 +478,7 @@ class StackTest(common.HeatTestCase): prev_raw_template_id=None, current_deps=None, cache_data=None, nested_depth=0, - deleted_time=None) + deleted_time=None, refresh_cred=False) template.Template.load.assert_called_once_with( self.ctx, stk.raw_template_id, stk.raw_template) @@ -1630,6 +1630,31 @@ class StackTest(common.HeatTestCase): saved_stack = stack.Stack.load(self.ctx, stack_id=stack_ownee.id) self.assertEqual(self.stack.id, saved_stack.owner_id) + def _test_load_with_refresh_cred(self, refresh=True): + cfg.CONF.set_override('deferred_auth_method', 'trusts') + self.patchobject(self.ctx.auth_plugin, 'get_user_id', + return_value='old_trustor_user_id') + self.patchobject(self.ctx.auth_plugin, 'get_project_id', + return_value='test_tenant_id') + + old_context = utils.dummy_context() + old_context.trust_id = 'atrust123' + old_context.trustor_user_id = ( + 'trustor_user_id' if refresh else 'old_trustor_user_id') + m_sc = self.patchobject(context, 'StoredContext') + m_sc.from_dict.return_value = old_context + self.stack = stack.Stack(self.ctx, 'test_regenerate_trust', self.tmpl) + self.stack.store() + load_stack = stack.Stack.load(self.ctx, stack_id=self.stack.id, + check_refresh_cred=True) + self.assertEqual(refresh, load_stack.refresh_cred) + + def test_load_with_refresh_cred(self): + self._test_load_with_refresh_cred() + + def test_load_with_no_refresh_cred(self): + self._test_load_with_refresh_cred(refresh=False) + def test_requires_deferred_auth(self): tmpl = {'HeatTemplateFormatVersion': '2012-12-12', 'Resources': {'AResource': {'Type': 'GenericResourceType'}, diff --git a/heat/tests/test_stack_update.py b/heat/tests/test_stack_update.py index 71e97c9814..a466cef4f3 100644 --- a/heat/tests/test_stack_update.py +++ b/heat/tests/test_stack_update.py @@ -18,6 +18,7 @@ from unittest import mock from heat.common import exception from heat.common import template_format from heat.db.sqlalchemy import api as db_api +from heat.engine.clients.os.keystone import fake_keystoneclient from heat.engine import environment from heat.engine import resource from heat.engine import rsrc_defn @@ -72,6 +73,37 @@ class StackUpdateTest(common.HeatTestCase): self.assertRaises(exception.NotFound, db_api.raw_template_get, self.ctx, raw_template_id) + def test_update_with_refresh_creds(self): + tmpl = {'HeatTemplateFormatVersion': '2012-12-12', + 'Resources': {'AResource': {'Type': 'GenericResourceType'}}} + + self.stack = stack.Stack(self.ctx, 'update_test_stack', + template.Template(tmpl)) + self.stack.store() + self.stack.create() + self.assertEqual((stack.Stack.CREATE, stack.Stack.COMPLETE), + self.stack.state) + + tmpl2 = {'HeatTemplateFormatVersion': '2012-12-12', + 'Resources': { + 'AResource': {'Type': 'GenericResourceType'}, + 'BResource': {'Type': 'GenericResourceType'}}} + updated_stack = stack.Stack(self.ctx, 'updated_stack', + template.Template(tmpl2)) + old_user_creds_id = self.stack.user_creds_id + self.stack.refresh_cred = True + + self.stack.context.user_id = '5678' + + mock_del_trust = self.patchobject( + fake_keystoneclient.FakeKeystoneClient, 'delete_trust') + + self.stack.update(updated_stack) + self.assertEqual((stack.Stack.UPDATE, stack.Stack.COMPLETE), + self.stack.state) + self.assertEqual(1, mock_del_trust.call_count) + self.assertNotEqual(self.stack.user_creds_id, old_user_creds_id) + def test_update_remove(self): tmpl = {'HeatTemplateFormatVersion': '2012-12-12', 'Resources': {