diff --git a/neutron/plugins/ml2/extensions/port_security.py b/neutron/plugins/ml2/extensions/port_security.py index aceec24a235..cb582f3b28f 100644 --- a/neutron/plugins/ml2/extensions/port_security.py +++ b/neutron/plugins/ml2/extensions/port_security.py @@ -38,8 +38,10 @@ class PortSecurityExtensionDriver(api.ExtensionDriver, def process_create_network(self, context, data, result): # Create the network extension attributes. - if psec.PORTSECURITY in data: - self._process_network_port_security_create(context, data, result) + if psec.PORTSECURITY not in data: + data[psec.PORTSECURITY] = (psec.EXTENDED_ATTRIBUTES_2_0['networks'] + [psec.PORTSECURITY]['default']) + self._process_network_port_security_create(context, data, result) def process_update_network(self, context, data, result): # Update the network extension attributes. @@ -63,7 +65,12 @@ class PortSecurityExtensionDriver(api.ExtensionDriver, self._extend_port_security_dict(result, db_data) def _extend_port_security_dict(self, response_data, db_data): - response_data[psec.PORTSECURITY] = ( + if db_data.get('port_security') is None: + response_data[psec.PORTSECURITY] = ( + psec.EXTENDED_ATTRIBUTES_2_0['networks'] + [psec.PORTSECURITY]['default']) + else: + response_data[psec.PORTSECURITY] = ( db_data['port_security'][psec.PORTSECURITY]) def _determine_port_security(self, context, port): diff --git a/neutron/tests/unit/extensions/test_portsecurity.py b/neutron/tests/unit/extensions/test_portsecurity.py index 42d0c340cca..76a269839ec 100644 --- a/neutron/tests/unit/extensions/test_portsecurity.py +++ b/neutron/tests/unit/extensions/test_portsecurity.py @@ -23,6 +23,7 @@ from neutron.db import securitygroups_db from neutron.extensions import portsecurity as psec from neutron.extensions import securitygroup as ext_sg from neutron import manager +from neutron.plugins.ml2.extensions import port_security from neutron.tests.unit.db import test_db_base_plugin_v2 from neutron.tests.unit.extensions import test_securitygroup @@ -399,3 +400,15 @@ class TestPortSecurity(PortSecurityDBTestCase): '', 'not_network_owner') res = req.get_response(self.api) self.assertEqual(res.status_int, exc.HTTPForbidden.code) + + def test_extend_port_dict_no_port_security(self): + """Test _extend_port_security_dict won't crash + if port_security item is None + """ + for db_data in ({'port_security': None, 'name': 'net1'}, {}): + response_data = {} + + driver = port_security.PortSecurityExtensionDriver() + driver._extend_port_security_dict(response_data, db_data) + + self.assertTrue(response_data[psec.PORTSECURITY]) diff --git a/neutron/tests/unit/plugins/ml2/test_ext_portsecurity.py b/neutron/tests/unit/plugins/ml2/test_ext_portsecurity.py index 6180ff10e86..7ec80f75aa5 100644 --- a/neutron/tests/unit/plugins/ml2/test_ext_portsecurity.py +++ b/neutron/tests/unit/plugins/ml2/test_ext_portsecurity.py @@ -13,6 +13,9 @@ # License for the specific language governing permissions and limitations # under the License. +from neutron import context +from neutron.extensions import portsecurity as psec +from neutron import manager from neutron.plugins.ml2 import config from neutron.tests.unit.extensions import test_portsecurity as test_psec from neutron.tests.unit.plugins.ml2 import test_plugin @@ -27,3 +30,22 @@ class PSExtDriverTestCase(test_plugin.Ml2PluginV2TestCase, self._extension_drivers, group='ml2') super(PSExtDriverTestCase, self).setUp() + + def test_create_net_port_security_default(self): + _core_plugin = manager.NeutronManager.get_plugin() + admin_ctx = context.get_admin_context() + _default_value = (psec.EXTENDED_ATTRIBUTES_2_0['networks'] + [psec.PORTSECURITY]['default']) + args = {'network': + {'name': 'test', + 'tenant_id': '', + 'shared': False, + 'admin_state_up': True, + 'status': 'ACTIVE'}} + try: + network = _core_plugin.create_network(admin_ctx, args) + _value = network[psec.PORTSECURITY] + finally: + if network: + _core_plugin.delete_network(admin_ctx, network['id']) + self.assertEqual(_default_value, _value)