diff --git a/neutron/api/rpc/handlers/securitygroups_rpc.py b/neutron/api/rpc/handlers/securitygroups_rpc.py index 20269abd665..752843a2776 100644 --- a/neutron/api/rpc/handlers/securitygroups_rpc.py +++ b/neutron/api/rpc/handlers/securitygroups_rpc.py @@ -438,6 +438,10 @@ class SecurityGroupServerAPIShim(sg_rpc_base.SecurityGroupInfoAPIMixin): for sg_id in p['security_group_ids'])) return [(sg_id, ) for sg_id in sg_ids] - def _is_security_group_stateful(self, context, sg_id): - sg = self.rcache.get_resource_by_id(resources.SECURITYGROUP, sg_id) - return sg.stateful + def _get_sgs_stateful_flag(self, context, sg_ids): + sgs_stateful = {} + for sg_id in sg_ids: + sg = self.rcache.get_resource_by_id(resources.SECURITYGROUP, sg_id) + sgs_stateful[sg_id] = sg.stateful + + return sgs_stateful diff --git a/neutron/db/securitygroups_rpc_base.py b/neutron/db/securitygroups_rpc_base.py index 3f32e913c24..c4c8f16102d 100644 --- a/neutron/db/securitygroups_rpc_base.py +++ b/neutron/db/securitygroups_rpc_base.py @@ -215,12 +215,10 @@ class SecurityGroupInfoAPIMixin(object): # this set will be serialized into a list by rpc code remote_address_group_info[remote_ag_id][ethertype] = set() direction = rule_in_db['direction'] - stateful = self._is_security_group_stateful(context, - security_group_id) rule_dict = { 'direction': direction, 'ethertype': ethertype, - 'stateful': stateful} + } for key in ('protocol', 'port_range_min', 'port_range_max', 'remote_ip_prefix', 'remote_group_id', @@ -238,6 +236,13 @@ class SecurityGroupInfoAPIMixin(object): if rule_dict not in sg_info['security_groups'][security_group_id]: sg_info['security_groups'][security_group_id].append( rule_dict) + + # Populate the security group "stateful" flag in the SGs list of rules. + for sg_id, stateful in self._get_sgs_stateful_flag( + context, sg_info['security_groups'].keys()).items(): + for rule in sg_info['security_groups'][sg_id]: + rule['stateful'] = stateful + # Update the security groups info if they don't have any rules sg_ids = self._select_sg_ids_for_ports(context, ports) for (sg_id, ) in sg_ids: @@ -431,13 +436,13 @@ class SecurityGroupInfoAPIMixin(object): """ raise NotImplementedError() - def _is_security_group_stateful(self, context, sg_id): - """Return whether the security group is stateful or not. + def _get_sgs_stateful_flag(self, context, sg_id): + """Return the security groups stateful flag. - Return True if the security group associated with the given ID - is stateful, else False. + Returns a dictionary with the SG ID as key and the stateful flag: + {sg_1: True, sg_2: False, ...} """ - return True + raise NotImplementedError() class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin, @@ -534,5 +539,5 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin, return ips_by_group @db_api.retry_if_session_inactive() - def _is_security_group_stateful(self, context, sg_id): - return sg_obj.SecurityGroup.get_sg_by_id(context, sg_id).stateful + def _get_sgs_stateful_flag(self, context, sg_ids): + return sg_obj.SecurityGroup.get_sgs_stateful_flag(context, sg_ids) diff --git a/neutron/objects/securitygroup.py b/neutron/objects/securitygroup.py index cd59481910b..4ce210d5fb4 100644 --- a/neutron/objects/securitygroup.py +++ b/neutron/objects/securitygroup.py @@ -130,6 +130,13 @@ class SecurityGroup(rbac_db.NeutronRbacObject): security_group_ids=[obj_id]) return {port.project_id for port in port_objs} + @classmethod + @db_api.CONTEXT_READER + def get_sgs_stateful_flag(cls, context, sg_ids): + query = context.session.query(cls.db_model.id, cls.db_model.stateful) + query = query.filter(cls.db_model.id.in_(sg_ids)) + return dict(query.all()) + @base.NeutronObjectRegistry.register class DefaultSecurityGroup(base.NeutronDbObject): diff --git a/neutron/tests/unit/objects/test_securitygroup.py b/neutron/tests/unit/objects/test_securitygroup.py index 77a575a6b2b..066608f8b82 100644 --- a/neutron/tests/unit/objects/test_securitygroup.py +++ b/neutron/tests/unit/objects/test_securitygroup.py @@ -219,6 +219,22 @@ class SecurityGroupDbObjTestCase(test_base.BaseDbObjectTestCase, self.assertEqual(len(sg_obj.rules), 0) self.assertIsNone(listed_objs[0].rules) + def test_get_sgs_stateful_flag(self): + for obj in self.objs: + obj.create() + + sg_ids = tuple(sg.id for sg in self.objs) + sgs_stateful = securitygroup.SecurityGroup.get_sgs_stateful_flag( + self.context, sg_ids) + for sg_id, stateful in sgs_stateful.items(): + for obj in (obj for obj in self.objs if obj.id == sg_id): + self.assertEqual(obj.stateful, stateful) + + sg_ids = sg_ids + ('random_id_not_present', ) + sgs_stateful = securitygroup.SecurityGroup.get_sgs_stateful_flag( + self.context, sg_ids) + self.assertEqual(len(self.objs), len(sgs_stateful)) + class DefaultSecurityGroupIfaceObjTestCase(test_base.BaseObjectIfaceTestCase):