neutron-classifier/neutron_classifier/tests/functional/test_api.py

191 lines
9.0 KiB
Python

# Copyright (c) 2018 Intel Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_utils import uuidutils
from neutron_lib.db import api as db_api
from neutron.tests.unit import testlib_api
from neutron_classifier.common import exceptions
from neutron_classifier.common import validators
from neutron_classifier.db.classification import\
TrafficClassificationGroupPlugin as cg_plugin
from neutron_classifier.objects import classifications
from neutron_classifier.services.classification.plugin import\
ClassificationPlugin as c_plugin
from neutron_classifier.tests import objects_base as obj_base
class ClassificationGroupApiTest(testlib_api.MySQLTestCaseMixin,
                                 testlib_api.SqlTestCase,
                                 obj_base._CCFObjectsTestCommon):
    """Functional tests for the classification-group plugin API.

    Each test drives ``TrafficClassificationGroupPlugin`` (imported as
    ``cg_plugin``) inside a ``db_api.CONTEXT_WRITER`` session.

    ``self.ctx``, ``_create_test_cg`` and ``_create_test_classification``
    are not defined in this class; they are expected to be provided by the
    base classes.
    """

    def setUp(self):
        """Instantiate a fresh plugin for every test."""
        super(ClassificationGroupApiTest, self).setUp()
        self.test_plugin = cg_plugin()

    def test_get_classification_group(self):
        """Fetching a group returns the dict built from the stored object."""
        with db_api.CONTEXT_WRITER.using(self.ctx):
            cg = self._create_test_cg('Test Group 0')
            cg_dict = self.test_plugin._make_db_dict(cg)
            fetch_cg = self.test_plugin.get_classification_group(self.ctx,
                                                                 cg.id)

            # The membership lists are copied over from the fetched result,
            # so the comparison below only checks the remaining fields.
            cg_dict['classification_group']['classifications'] = \
                fetch_cg['classification_group']['classifications']
            cg_dict['classification_group']['classification_groups'] = \
                fetch_cg['classification_group']['classification_groups']
            self.assertEqual(cg_dict, fetch_cg)

    def test_get_classification_groups(self):
        """Listing groups returns every created group, in any order."""
        with db_api.CONTEXT_WRITER.using(self.ctx):
            cg1 = self._create_test_cg('Test Group 1')
            cg2 = self._create_test_cg('Test Group 2')
            test_cgs = self.test_plugin._make_db_dicts([cg1, cg2])
            cgs = self.test_plugin.get_classification_groups(self.ctx)
            # assertCountEqual is the Python 3 name of the order-insensitive
            # comparison formerly spelled assertItemsEqual.
            self.assertCountEqual(test_cgs, cgs)

    def test_create_classification_group(self):
        """Creating a group stores its fields and its member mappings."""
        with db_api.CONTEXT_WRITER.using(self.ctx):
            tcp_class = classifications.TCPClassification
            ipv4_class = classifications.IPV4Classification
            cg2 = self._create_test_cg('Test Group 1')
            tcp = self._create_test_classification('tcp', tcp_class)
            ipv4 = self._create_test_classification('ipv4', ipv4_class)
            cg_dict = {'classification_group':
                       {'name': 'Test Group 0',
                        'description': "Description of test group",
                        'project_id': uuidutils.generate_uuid(),
                        'operator': 'AND',
                        'classifications': [tcp.id, ipv4.id],
                        'classification_groups': [cg2.id]
                        }}

            cg1 = self.test_plugin.create_classification_group(self.ctx,
                                                               cg_dict)
            fetch_cg1 = classifications.ClassificationGroup.get_object(
                self.ctx, id=cg1['id'])
            mapped_cgs = classifications._get_mapped_classification_groups(
                self.ctx, fetch_cg1)
            mapped_cs = classifications._get_mapped_classifications(
                self.ctx, fetch_cg1)
            mapped_classification_groups = [cg.id for cg in mapped_cgs]
            mapped_classifications = [c.id for c in mapped_cs]

            self.assertEqual(cg1, cg_dict['classification_group'])
            # Every id mapped to the new group must be one that was
            # requested in the create call.
            for cg in mapped_classification_groups:
                self.assertIn(
                    cg,
                    cg_dict['classification_group']['classification_groups'])
            for c in mapped_classifications:
                self.assertIn(
                    c, cg_dict['classification_group']['classifications'])

    def test_update_classification_group(self):
        """A name update is applied; an update with 'operator' is refused."""
        with db_api.CONTEXT_WRITER.using(self.ctx):
            cg1 = self._create_test_cg('Test Group 0')
            cg2 = self._create_test_cg('Test Group 1')
            self.test_plugin.update_classification_group(
                self.ctx, cg1.id,
                {'classification_group': {'name': 'Test Group updated'}})
            fetch_cg1 = classifications.ClassificationGroup.get_object(
                self.ctx, id=cg1.id)
            # An update request that also carries 'operator' is expected to
            # raise InvalidUpdateRequest.
            self.assertRaises(
                exceptions.InvalidUpdateRequest,
                self.test_plugin.update_classification_group,
                self.ctx, cg2.id,
                {'classification_group': {'name': 'Test Group updated',
                                          'operator': 'OR'}})
            self.assertEqual(fetch_cg1.name, 'Test Group updated')

    def test_delete_classification_group(self):
        """A deleted group can no longer be fetched."""
        with db_api.CONTEXT_WRITER.using(self.ctx):
            cg1 = self._create_test_cg('Test Group 0')
            self.test_plugin.delete_classification_group(self.ctx, cg1.id)
            fetch_cg1 = classifications.ClassificationGroup.get_object(
                self.ctx, id=cg1.id)
            self.assertIsNone(fetch_cg1)
class ClassificationApiTest(testlib_api.MySQLTestCaseMixin,
                            testlib_api.SqlTestCase,
                            obj_base._CCFObjectsTestCommon):
    """Functional tests for the classification service plugin API.

    The tests call ``ClassificationPlugin`` (imported as ``c_plugin``) and
    verify the results through the ``classifications`` objects.

    ``self.ctx``, ``get_random_attrs`` and ``_create_test_classification``
    are not defined in this class; they are expected to be provided by the
    base classes.
    """

    def setUp(self):
        """Instantiate a fresh plugin for every test."""
        super(ClassificationApiTest, self).setUp()
        self.test_clas_plugin = c_plugin()

    def test_create_classification(self):
        """Create an ethernet classification; reject a bad definition."""
        attrs = self.get_random_attrs(classifications.EthernetClassification)
        c_type = 'ethernet'
        attrs['c_type'] = c_type
        attrs['definition'] = {}
        # Move the type-specific fields out of the flat attribute dict and
        # into the nested 'definition' dict.
        for field in validators.type_validators[c_type]:
            attrs['definition'][field] = attrs.pop(field, None)
        request = {'classification': attrs}

        with db_api.CONTEXT_WRITER.using(self.ctx):
            created = self.test_clas_plugin.create_classification(
                self.ctx, request)
            stored = classifications.EthernetClassification.get_object(
                self.ctx, id=created['id'])

            # A second create with 'src_port' added to the ethernet
            # definition is expected to be refused.
            request['classification']['definition']['src_port'] = 'xyz'
            self.assertRaises(exceptions.InvalidClassificationDefintion,
                              self.test_clas_plugin.create_classification,
                              self.ctx, request)

            # Compare the top-level fields first, then the definition
            # fields, against the stored object.
            definition = created.pop('definition', None)
            for name, value in created.items():
                self.assertEqual(value, stored[name])
            for name, value in definition.items():
                self.assertEqual(value, stored[name])

    def test_delete_classification(self):
        """A deleted classification can no longer be fetched."""
        tcp_class = classifications.TCPClassification
        with db_api.CONTEXT_WRITER.using(self.ctx):
            created = self._create_test_classification('tcp', tcp_class)
            self.test_clas_plugin.delete_classification(
                self.ctx, created.id)
            remaining = classifications.TCPClassification.get_object(
                self.ctx, id=created.id)
            self.assertIsNone(remaining)

    def test_get_classification(self):
        """Fetching by id returns the merge_header() form of the object."""
        ipv4_class = classifications.IPV4Classification
        with db_api.CONTEXT_WRITER.using(self.ctx):
            created = self._create_test_classification('ipv4', ipv4_class)
            fetched = self.test_clas_plugin.get_classification(
                self.ctx, created.id)
            expected = self.test_clas_plugin.merge_header(created)
            self.assertEqual(fetched, expected)

    def test_get_classifications(self):
        """Listing with a c_type filter includes the matching object."""
        with db_api.CONTEXT_WRITER.using(self.ctx):
            ipv6_obj = self._create_test_classification(
                'ipv6', classifications.IPV6Classification)
            udp_obj = self._create_test_classification(
                'udp', classifications.UDPClassification)

            udp_results = self.test_clas_plugin.get_classifications(
                self.ctx, filters={'c_type': ['udp']})
            ipv6_results = self.test_clas_plugin.get_classifications(
                self.ctx, filters={'c_type': ['ipv6']})

            ipv6_dict = self.test_clas_plugin.merge_header(ipv6_obj)
            udp_dict = self.test_clas_plugin.merge_header(udp_obj)
            self.assertIn(ipv6_dict, ipv6_results)
            self.assertIn(udp_dict, udp_results)

    def test_update_classification(self):
        """Updating the name is visible when the object is re-read."""
        created = self._create_test_classification(
            'ethernet', classifications.EthernetClassification)
        updated_name = 'Test Updated Classification'
        with db_api.CONTEXT_WRITER.using(self.ctx):
            self.test_clas_plugin.update_classification(
                self.ctx, created.id,
                {'classification': {'name': updated_name}})
            reloaded = classifications.EthernetClassification.get_object(
                self.ctx, id=created.id)
            self.assertEqual(reloaded.name, updated_name)