diff --git a/etc/neutron.conf b/etc/neutron.conf index 13ff410c69a..99a3ca3761a 100644 --- a/etc/neutron.conf +++ b/etc/neutron.conf @@ -245,6 +245,9 @@ # # CIDR of the administrative network if HA mode is enabled # l3_ha_net_cidr = 169.254.192.0/18 +# +# Enable snat by default on external gateway when available +# enable_snat_by_default = True # =========== end of items for l3 extension ======= # =========== items for metadata proxy configuration ============== diff --git a/neutron/db/l3_gwmode_db.py b/neutron/db/l3_gwmode_db.py index 558adff15eb..8eabc9212a2 100644 --- a/neutron/db/l3_gwmode_db.py +++ b/neutron/db/l3_gwmode_db.py @@ -13,6 +13,7 @@ # under the License. # +from oslo_config import cfg from oslo_log import log as logging import sqlalchemy as sa from sqlalchemy import sql @@ -23,6 +24,12 @@ from neutron.extensions import l3 LOG = logging.getLogger(__name__) +OPTS = [ + cfg.BoolOpt('enable_snat_by_default', default=True, + help=_('Define the default value of enable_snat if not ' + 'provided in external_gateway_info.')) +] +cfg.CONF.register_opts(OPTS) EXTERNAL_GW_INFO = l3.EXTERNAL_GW_INFO # Modify the Router Data Model adding the enable_snat attribute @@ -55,10 +62,8 @@ class L3_NAT_dbonly_mixin(l3_db.L3_NAT_dbonly_mixin): # Load the router only if necessary if not router: router = self._get_router(context, router_id) - # if enable_snat is not specified then use the default value (True) - enable_snat = not info or info.get('enable_snat', True) with context.session.begin(subtransactions=True): - router.enable_snat = enable_snat + router.enable_snat = self._get_enable_snat(info) # Calls superclass, pass router db object for avoiding re-loading super(L3_NAT_dbonly_mixin, self)._update_router_gw_info( @@ -67,6 +72,13 @@ class L3_NAT_dbonly_mixin(l3_db.L3_NAT_dbonly_mixin): # method is overridden in child classes return router + @staticmethod + def _get_enable_snat(info): + if info and 'enable_snat' in info: + return info['enable_snat'] + # if enable_snat is not specified then use the default value + return cfg.CONF.enable_snat_by_default + def _build_routers_list(self, context, routers, gw_ports): for rtr in routers: gw_port_id = rtr['gw_port_id'] diff --git a/neutron/tests/unit/extensions/test_l3_ext_gw_mode.py b/neutron/tests/unit/extensions/test_l3_ext_gw_mode.py index cba57a6bf84..0d4f7fcc62d 100644 --- a/neutron/tests/unit/extensions/test_l3_ext_gw_mode.py +++ b/neutron/tests/unit/extensions/test_l3_ext_gw_mode.py @@ -16,6 +16,7 @@ import mock from oslo_config import cfg +import testscenarios from webob import exc from neutron.common import constants @@ -27,6 +28,7 @@ from neutron.db import models_v2 from neutron.extensions import l3 from neutron.extensions import l3_ext_gw_mode from neutron.openstack.common import uuidutils +from neutron.tests import base from neutron.tests.unit.db import test_db_base_plugin_v2 from neutron.tests.unit.extensions import test_l3 from neutron.tests.unit import testlib_api @@ -74,6 +76,33 @@ class TestDbSepPlugin(test_l3.TestL3NatServicePlugin, supported_extension_aliases = ["router", "ext-gw-mode"] +class TestGetEnableSnat(testscenarios.WithScenarios, base.BaseTestCase): + scenarios = [ + ('enabled', {'enable_snat_by_default': True}), + ('disabled', {'enable_snat_by_default': False})] + + def setUp(self): + super(TestGetEnableSnat, self).setUp() + self.config(enable_snat_by_default=self.enable_snat_by_default) + + def _test_get_enable_snat(self, expected, info): + observed = l3_gwmode_db.L3_NAT_dbonly_mixin._get_enable_snat(info) + self.assertEqual(expected, observed) + + def test_get_enable_snat_without_gw_info(self): + self._test_get_enable_snat(self.enable_snat_by_default, {}) + + def test_get_enable_snat_without_enable_snat(self): + info = {'network_id': _uuid()} + self._test_get_enable_snat(self.enable_snat_by_default, info) + + def test_get_enable_snat_with_snat_enabled(self): + self._test_get_enable_snat(True, {'enable_snat': True}) + + def test_get_enable_snat_with_snat_disabled(self): + self._test_get_enable_snat(False, {'enable_snat': False}) + + class TestL3GwModeMixin(testlib_api.SqlTestCase): def setUp(self):