Fix the verification method before creating and updating the firewall rule

1.Verify the validity of the icmp port when the protocol is configured.
2.When updating the rule, the modified data should be verified.

Change-Id: I4c4b1cc5ff25b67e77669b721df4fdbb7d47515f
Closes-Bug: #1803499
This commit is contained in:
25643 2019-02-07 21:55:16 +08:00 committed by Chengqian Liu
parent 4a8687e3f3
commit 22aace21eb
2 changed files with 33 additions and 21 deletions

View File

@ -216,17 +216,20 @@ class FirewallPluginDb(common_db_mixin.CommonDbMixin):
except exc.NoResultFound:
raise f_exc.FirewallRuleNotFound(firewall_rule_id=id)
def _validate_fwr_protocol_parameters(self, fwr, fwr_db=None):
protocol = fwr.get('protocol', None)
if fwr_db and not protocol:
protocol = fwr_db.protocol
if protocol not in (nl_constants.PROTO_NAME_TCP,
nl_constants.PROTO_NAME_UDP):
if (fwr.get('source_port', None) or
fwr.get('destination_port', None)):
def _validate_fwr_protocol_parameters(self, fwr):
protocol = fwr['protocol']
source_port = fwr['source_port']
dest_port = fwr['destination_port']
if protocol and protocol not in (nl_constants.PROTO_NAME_TCP,
nl_constants.PROTO_NAME_UDP):
if source_port or dest_port:
raise f_exc.FirewallRuleInvalidICMPParameter(
param="Source, destination port")
if not protocol and (source_port or dest_port):
raise f_exc.FirewallRuleWithPortWithoutProtocolInvalid()
def _validate_fwr_src_dst_ip_version(self, fwr, fwr_db=None):
src_version = dst_version = None
if fwr.get('source_ip_address', None):
@ -451,9 +454,7 @@ class FirewallPluginDb(common_db_mixin.CommonDbMixin):
fwr = firewall_rule
self._validate_fwr_protocol_parameters(fwr)
self._validate_fwr_src_dst_ip_version(fwr)
if not fwr['protocol'] and (fwr['source_port'] or
fwr['destination_port']):
raise f_exc.FirewallRuleWithPortWithoutProtocolInvalid()
src_port_min, src_port_max = self._get_min_max_ports_from_range(
fwr['source_port'])
dst_port_min, dst_port_max = self._get_min_max_ports_from_range(
@ -481,8 +482,11 @@ class FirewallPluginDb(common_db_mixin.CommonDbMixin):
def update_firewall_rule(self, context, id, firewall_rule):
fwr = firewall_rule
fwr_db = self._get_firewall_rule(context, id)
self._validate_fwr_protocol_parameters(fwr, fwr_db=fwr_db)
self._validate_fwr_src_dst_ip_version(fwr, fwr_db=fwr_db)
fwr_db_updated = self._make_firewall_rule_dict(fwr_db)
fwr_db_updated.update(fwr)
self._validate_fwr_protocol_parameters(fwr_db_updated)
self._validate_fwr_src_dst_ip_version(fwr_db_updated)
if 'source_port' in fwr:
src_port_min, src_port_max = self._get_min_max_ports_from_range(
fwr['source_port'])
@ -496,14 +500,6 @@ class FirewallPluginDb(common_db_mixin.CommonDbMixin):
fwr['destination_port_range_max'] = dst_port_max
del fwr['destination_port']
with context.session.begin(subtransactions=True):
protocol = fwr.get('protocol', fwr_db['protocol'])
if not protocol:
sport = fwr.get('source_port_range_min',
fwr_db['source_port_range_min'])
dport = fwr.get('destination_port_range_min',
fwr_db['destination_port_range_min'])
if sport or dport:
raise f_exc.FirewallRuleWithPortWithoutProtocolInvalid()
fwr_db.update(fwr)
# if the rule on a policy, fix audited flag
fwp_ids = self.get_policies_with_rule(context, id)

View File

@ -697,6 +697,22 @@ class TestFirewallDBPluginV2(test_fwaas_plugin_v2.FirewallPluginV2TestCase):
res = req.get_response(self.ext_api)
self.assertEqual(400, res.status_int)
def test_update_firewall_rule_protocol_icmp(self):
with self.firewall_rule(source_port=10000) as fwr:
data = {'firewall_rule': {'protocol': 'icmp'}}
req = self.new_update_request('firewall_rules', data,
fwr['firewall_rule']['id'])
res = req.get_response(self.ext_api)
self.assertEqual(webob.exc.HTTPBadRequest.code, res.status_int)
def test_update_firewall_rule_protocol_none(self):
with self.firewall_rule(source_port=10000) as fwr:
data = {'firewall_rule': {'protocol': None}}
req = self.new_update_request('firewall_rules', data,
fwr['firewall_rule']['id'])
res = req.get_response(self.ext_api)
self.assertEqual(webob.exc.HTTPBadRequest.code, res.status_int)
def test_update_firewall_rule_with_policy_associated(self):
name = "new_firewall_rule1"
attrs = self._get_test_firewall_rule_attrs(name)