diff --git a/nova/compute/api.py b/nova/compute/api.py index 6d3290ba626a..4a281d6f5950 100644 --- a/nova/compute/api.py +++ b/nova/compute/api.py @@ -3412,9 +3412,11 @@ class SecurityGroupAPI(base.Base, security_group_base.SecurityGroupBase): group = {'name': name, 'description': description} + columns_to_join = ['rules.grantee_group'] group_ref = self.db.security_group_update(context, - security_group['id'], - group) + security_group['id'], + group, + columns_to_join=columns_to_join) return group_ref def get(self, context, name=None, id=None, map_exception=False): diff --git a/nova/db/api.py b/nova/db/api.py index d4b10353259b..ef1d71406659 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -1231,9 +1231,11 @@ def security_group_create(context, values): return IMPL.security_group_create(context, values) -def security_group_update(context, security_group_id, values): +def security_group_update(context, security_group_id, values, + columns_to_join=None): """Update a security group.""" - return IMPL.security_group_update(context, security_group_id, values) + return IMPL.security_group_update(context, security_group_id, values, + columns_to_join=columns_to_join) def security_group_ensure_default(context): diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index e5d4d5fa8285..996e7b615b36 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -3694,13 +3694,16 @@ def security_group_create(context, values): @require_context -def security_group_update(context, security_group_id, values): +def security_group_update(context, security_group_id, values, + columns_to_join=None): session = get_session() with session.begin(): - security_group_ref = model_query(context, models.SecurityGroup, - session=session).\ - filter_by(id=security_group_id).\ - first() + query = model_query(context, models.SecurityGroup, + session=session).filter_by(id=security_group_id) + if columns_to_join: + for column in columns_to_join: + query = query.options(joinedload_all(column)) + security_group_ref = query.first() if not security_group_ref: raise exception.SecurityGroupNotFound( diff --git a/nova/tests/api/openstack/compute/contrib/test_security_groups.py b/nova/tests/api/openstack/compute/contrib/test_security_groups.py index 08bccbde16b4..e5543b5d5cc1 100644 --- a/nova/tests/api/openstack/compute/contrib/test_security_groups.py +++ b/nova/tests/api/openstack/compute/contrib/test_security_groups.py @@ -422,10 +422,11 @@ class TestSecurityGroups(test.TestCase): self.assertEquals(sg['id'], group_id) return security_group_db(sg) - def return_update_security_group(context, group_id, values): - self.assertEquals(sg_update['id'], group_id) - self.assertEquals(sg_update['name'], values['name']) - self.assertEquals(sg_update['description'], values['description']) + def return_update_security_group(context, group_id, values, + columns_to_join=None): + self.assertEqual(sg_update['id'], group_id) + self.assertEqual(sg_update['name'], values['name']) + self.assertEqual(sg_update['description'], values['description']) return security_group_db(sg_update) self.stubs.Set(nova.db, 'security_group_update', diff --git a/nova/tests/db/test_db_api.py b/nova/tests/db/test_db_api.py index 8ca2005cfe24..5627eeb1cafc 100644 --- a/nova/tests/db/test_db_api.py +++ b/nova/tests/db/test_db_api.py @@ -1247,11 +1247,14 @@ class SecurityGroupTestCase(test.TestCase, ModelsObjectComparatorMixin): 'user_id': 'fake_user1', 'project_id': 'fake_proj1', } + updated_group = db.security_group_update(self.ctxt, - security_group['id'], - new_values) + security_group['id'], + new_values, + columns_to_join=['rules.grantee_group']) for key, value in new_values.iteritems(): self.assertEqual(updated_group[key], value) + self.assertEqual(updated_group['rules'], []) def test_security_group_update_to_duplicate(self): security_group1 = self._create_security_group(