diff --git a/neutron_fwaas/db/firewall/v2/firewall_db_v2.py b/neutron_fwaas/db/firewall/v2/firewall_db_v2.py index c4cfaf599..ff39994f8 100644 --- a/neutron_fwaas/db/firewall/v2/firewall_db_v2.py +++ b/neutron_fwaas/db/firewall/v2/firewall_db_v2.py @@ -140,8 +140,10 @@ class Firewall_db_mixin_v2(fw_ext.Firewallv2PluginBase, base_db.CommonDbMixin): except exc.NoResultFound: raise fw_ext.FirewallRuleNotFound(firewall_rule_id=id) - def _validate_fwr_protocol_parameters(self, fwr): + def _validate_fwr_protocol_parameters(self, fwr, fwr_db=None): protocol = fwr.get('protocol', None) + if fwr_db and not protocol: + protocol = fwr_db.protocol if protocol not in (nl_constants.PROTO_NAME_TCP, nl_constants.PROTO_NAME_UDP): if (fwr.get('source_port', None) or @@ -353,9 +355,9 @@ class Firewall_db_mixin_v2(fw_ext.Firewallv2PluginBase, base_db.CommonDbMixin): def update_firewall_rule(self, context, id, firewall_rule): LOG.debug("update_firewall_rule() called") fwr = firewall_rule['firewall_rule'] - self._validate_fwr_protocol_parameters(fwr) - self._validate_fwr_src_dst_ip_version(fwr) fwr_db = self._get_firewall_rule(context, id) + self._validate_fwr_protocol_parameters(fwr, fwr_db=fwr_db) + self._validate_fwr_src_dst_ip_version(fwr) if 'source_port' in fwr: src_port_min, src_port_max = self._get_min_max_ports_from_range( fwr['source_port']) diff --git a/neutron_fwaas/tests/unit/db/firewall/v2/test_firewall_db_v2.py b/neutron_fwaas/tests/unit/db/firewall/v2/test_firewall_db_v2.py index 1ec4629d8..fc865fc09 100644 --- a/neutron_fwaas/tests/unit/db/firewall/v2/test_firewall_db_v2.py +++ b/neutron_fwaas/tests/unit/db/firewall/v2/test_firewall_db_v2.py @@ -951,6 +951,15 @@ class TestFirewallDBPluginV2(FirewallPluginV2DbTestCase): res = req.get_response(self.ext_api) self.assertEqual(400, res.status_int) + with self.firewall_rule(source_port=None, + destination_port=None, + protocol='icmp') as fwr: + data = {'firewall_rule': {'destination_port': 80}} + req = self.new_update_request('firewall_rules', data, + fwr['firewall_rule']['id']) + res = req.get_response(self.ext_api) + self.assertEqual(400, res.status_int) + def test_update_firewall_rule_with_policy_associated(self): name = "new_firewall_rule1" attrs = self._get_test_firewall_rule_attrs(name)