Filter by owner SGs when retrieving the SG rules
Retrieving the SG rules now is used the admin context. This allows to get all possible rules, independently of the user calling. The filters passed and the RBAC policies filter those results, returning only: - The SG rules belonging to the user. - The SG rules belonging to a SG owned by the user. However, if the SG list is too long, the query can take a lot of time. Instead of this, the filtering is done in the DB query. If no filters are passed to "get_security_group_rules" and the context is not the admin context, only the rules specified in the first paragraph will be retrieved. Because overwriting the method "get_objects" is too complex, an intermediate query is done to retrieve the SG rule IDs. Those IDs will be used as a filter in the "get_objects" call. Closes-Bug: #1863201 Change-Id: I25d3da929f8d0b6ee15d7b90ec59b9d58a4ae6a5
This commit is contained in:
parent
8ba44d6720
commit
d874c46bff
|
@ -719,6 +719,11 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase,
|
|||
pager = base_obj.Pager(
|
||||
sorts=sorts, marker=marker, limit=limit, page_reverse=page_reverse)
|
||||
|
||||
if not filters and context.project_id and not context.is_admin:
|
||||
rule_ids = sg_obj.SecurityGroupRule.get_security_group_rule_ids(
|
||||
context.project_id)
|
||||
filters = {'id': rule_ids}
|
||||
|
||||
# NOTE(slaweq): use admin context here to be able to get all rules
|
||||
# which fits filters' criteria. Later in policy engine rules will be
|
||||
# filtered and only those which are allowed according to policy will
|
||||
|
|
|
@ -10,10 +10,12 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from neutron_lib import context as context_lib
|
||||
from neutron_lib.objects import common_types
|
||||
from neutron_lib.utils import net as net_utils
|
||||
from oslo_utils import versionutils
|
||||
from oslo_versionedobjects import fields as obj_fields
|
||||
from sqlalchemy import or_
|
||||
|
||||
from neutron.db.models import securitygroup as sg_models
|
||||
from neutron.db import rbac_db_models
|
||||
|
@ -155,3 +157,21 @@ class SecurityGroupRule(base.NeutronDbObject):
|
|||
fields['remote_ip_prefix'] = (
|
||||
net_utils.AuthenticIPNetwork(fields['remote_ip_prefix']))
|
||||
return fields
|
||||
|
||||
@classmethod
|
||||
def get_security_group_rule_ids(cls, project_id):
|
||||
"""Retrieve all SG rules related to this project_id
|
||||
|
||||
This method returns the SG rule IDs that meet these conditions:
|
||||
- The rule belongs to this project_id
|
||||
- The rule belongs to a security group that belongs to the project_id
|
||||
"""
|
||||
context = context_lib.get_admin_context()
|
||||
query = context.session.query(cls.db_model.id)
|
||||
query = query.join(
|
||||
SecurityGroup.db_model,
|
||||
cls.db_model.security_group_id == SecurityGroup.db_model.id)
|
||||
clauses = or_(SecurityGroup.db_model.project_id == project_id,
|
||||
cls.db_model.project_id == project_id)
|
||||
rule_ids = query.filter(clauses).all()
|
||||
return [rule_id[0] for rule_id in rule_ids]
|
||||
|
|
|
@ -479,3 +479,60 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
|
|||
{'port_range_min': None,
|
||||
'port_range_max': 200,
|
||||
'protocol': constants.PROTO_NAME_VRRP})
|
||||
|
||||
def _create_environment(self):
|
||||
self.sg = copy.deepcopy(FAKE_SECGROUP)
|
||||
self.user_ctx = context.Context(user_id='user1', tenant_id='tenant_1',
|
||||
is_admin=False, overwrite=False)
|
||||
self.admin_ctx = context.Context(user_id='user2', tenant_id='tenant_2',
|
||||
is_admin=True, overwrite=False)
|
||||
self.sg_user = self.mixin.create_security_group(
|
||||
self.user_ctx, {'security_group': {'name': 'name',
|
||||
'tenant_id': 'tenant_1',
|
||||
'description': 'fake'}})
|
||||
|
||||
def test_get_security_group_rules(self):
|
||||
self._create_environment()
|
||||
rules_before = self.mixin.get_security_group_rules(self.user_ctx)
|
||||
|
||||
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
|
||||
rule['security_group_rule']['security_group_id'] = self.sg_user['id']
|
||||
rule['security_group_rule']['tenant_id'] = 'tenant_2'
|
||||
self.mixin.create_security_group_rule(self.admin_ctx, rule)
|
||||
|
||||
rules_after = self.mixin.get_security_group_rules(self.user_ctx)
|
||||
self.assertEqual(len(rules_before) + 1, len(rules_after))
|
||||
for rule in (rule for rule in rules_after if rule not in rules_before):
|
||||
self.assertEqual('tenant_2', rule['tenant_id'])
|
||||
|
||||
def test_get_security_group_rules_filters_passed(self):
|
||||
self._create_environment()
|
||||
filters = {'security_group_id': self.sg_user['id']}
|
||||
rules_before = self.mixin.get_security_group_rules(self.user_ctx,
|
||||
filters=filters)
|
||||
|
||||
default_sg = self.mixin.get_security_groups(
|
||||
self.user_ctx, filters={'name': 'default'})[0]
|
||||
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
|
||||
rule['security_group_rule']['security_group_id'] = default_sg['id']
|
||||
rule['security_group_rule']['tenant_id'] = 'tenant_1'
|
||||
self.mixin.create_security_group_rule(self.user_ctx, rule)
|
||||
|
||||
rules_after = self.mixin.get_security_group_rules(self.user_ctx,
|
||||
filters=filters)
|
||||
self.assertEqual(rules_before, rules_after)
|
||||
|
||||
def test_get_security_group_rules_admin_context(self):
|
||||
self._create_environment()
|
||||
rules_before = self.mixin.get_security_group_rules(self.ctx)
|
||||
|
||||
rule = copy.deepcopy(FAKE_SECGROUP_RULE)
|
||||
rule['security_group_rule']['security_group_id'] = self.sg_user['id']
|
||||
rule['security_group_rule']['tenant_id'] = 'tenant_1'
|
||||
self.mixin.create_security_group_rule(self.user_ctx, rule)
|
||||
|
||||
rules_after = self.mixin.get_security_group_rules(self.ctx)
|
||||
self.assertEqual(len(rules_before) + 1, len(rules_after))
|
||||
for rule in (rule for rule in rules_after if rule not in rules_before):
|
||||
self.assertEqual('tenant_1', rule['tenant_id'])
|
||||
self.assertEqual(self.sg_user['id'], rule['security_group_id'])
|
||||
|
|
|
@ -1626,8 +1626,12 @@ class BaseDbObjectTestCase(_BaseObjectTestCase,
|
|||
self._router.create()
|
||||
return self._router['id']
|
||||
|
||||
def _create_test_security_group_id(self):
|
||||
def _create_test_security_group_id(self, fields=None):
|
||||
sg_fields = self.get_random_object_fields(securitygroup.SecurityGroup)
|
||||
fields = fields or {}
|
||||
for field, value in ((f, v) for (f, v) in fields.items() if
|
||||
f in sg_fields):
|
||||
sg_fields[field] = value
|
||||
_securitygroup = securitygroup.SecurityGroup(
|
||||
self.context, **sg_fields)
|
||||
_securitygroup.create()
|
||||
|
|
|
@ -10,8 +10,12 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import collections
|
||||
import itertools
|
||||
import random
|
||||
|
||||
from oslo_utils import uuidutils
|
||||
|
||||
from neutron.objects import securitygroup
|
||||
from neutron.tests.unit.objects import test_base
|
||||
from neutron.tests.unit.objects import test_rbac
|
||||
|
@ -213,3 +217,38 @@ class SecurityGroupRuleDbObjTestCase(test_base.BaseDbObjectTestCase,
|
|||
'remote_group_id':
|
||||
lambda: self._create_test_security_group_id()
|
||||
})
|
||||
|
||||
def test_get_security_group_rule_ids(self):
|
||||
"""Retrieve the SG rules associated to a project (see method desc.)
|
||||
|
||||
SG1 (PROJECT1) SG2 (PROJECT2)
|
||||
rule1a (PROJECT1) rule2a (PROJECT1)
|
||||
rule1b (PROJECT2) rule2b (PROJECT2)
|
||||
|
||||
query PROJECT1: rule1a, rule1b, rule2a
|
||||
query PROJECT2: rule1b, rule2a, rule2b
|
||||
"""
|
||||
projects = [uuidutils.generate_uuid(), uuidutils.generate_uuid()]
|
||||
sgs = [
|
||||
self._create_test_security_group_id({'project_id': projects[0]}),
|
||||
self._create_test_security_group_id({'project_id': projects[1]})]
|
||||
|
||||
rules_per_project = collections.defaultdict(list)
|
||||
rules_per_sg = collections.defaultdict(list)
|
||||
for project, sg in itertools.product(projects, sgs):
|
||||
sgrule_fields = self.get_random_object_fields(
|
||||
securitygroup.SecurityGroupRule)
|
||||
sgrule_fields['project_id'] = project
|
||||
sgrule_fields['security_group_id'] = sg
|
||||
rule = securitygroup.SecurityGroupRule(self.context,
|
||||
**sgrule_fields)
|
||||
rule.create()
|
||||
rules_per_project[project].append(rule.id)
|
||||
rules_per_sg[sg].append(rule.id)
|
||||
|
||||
for idx in range(2):
|
||||
rule_ids = securitygroup.SecurityGroupRule.\
|
||||
get_security_group_rule_ids(projects[idx])
|
||||
rule_ids_ref = set(rules_per_project[projects[idx]])
|
||||
rule_ids_ref.update(set(rules_per_sg[sgs[idx]]))
|
||||
self.assertEqual(rule_ids_ref, set(rule_ids))
|
||||
|
|
Loading…
Reference in New Issue