diff --git a/neutron_classifier/common/constants.py b/neutron_classifier/common/constants.py index 6b9ceda..b6a2848 100644 --- a/neutron_classifier/common/constants.py +++ b/neutron_classifier/common/constants.py @@ -22,3 +22,11 @@ PROTOCOLS = ['tcp', 'udp', 'icmp', 'icmpv6'] ENCAPSULATION_TYPES = ['vxlan', 'gre'] NEUTRON_SERVICES = ['neutron-fwaas', 'networking-sfc', 'security-group'] + +DIRECTIONS = ['INGRESS', 'EGRESS', 'BIDIRECTIONAL'] + +ETHERTYPE_IPV4 = 0x0800 +ETHERTYPE_IPV6 = 0x86DD + +SECURITYGROUP_ETHERTYPE_IPV4 = 'IPv4' +SECURITYGROUP_ETHERTYPE_IPV6 = 'IPv6' diff --git a/neutron_classifier/db/api.py b/neutron_classifier/db/api.py index 59318bf..dd48a70 100644 --- a/neutron_classifier/db/api.py +++ b/neutron_classifier/db/api.py @@ -12,62 +12,113 @@ # License for the specific language governing permissions and limitations # under the License. +from neutron_classifier.common import constants from neutron_classifier.db import models -def get_classifier_chain(): - pass +def security_group_ethertype_to_ethertype_value(ethertype): + if ethertype == constants.SECURITYGROUP_ETHERTYPE_IPV6: + return constants.ETHERTYPE_IPV6 + else: + return constants.ETHERTYPE_IPV4 -def create_classifier_chain(context, classifier_group, classifier): - chain = models.ClassifierChainEntry() - chain.sequence = 1 - chain.classifier = classifier - chain.classifier_group = classifier_group - context.session.add(chain) +def ethertype_value_to_security_group_ethertype(ethertype): + if ethertype == constants.ETHERTYPE_IPV6: + return constants.SECURITYGROUP_ETHERTYPE_IPV6 + else: + return constants.SECURITYGROUP_ETHERTYPE_IPV4 + + +def get_classifier_group(context, classifier_group_id): + return context.session.query(models.ClassifierGroup).get( + classifier_group_id) + + +def create_classifier_chain(classifier_group, classifiers, + incremeting_sequence=False): + if incremeting_sequence: + seq = 0 + + for classifier in classifiers: + ce = models.ClassifierChainEntry(classifier_group=classifier_group, + classifier=classifier) + if incremeting_sequence: + ce.sequence = seq + classifier_group.classifier_chain.append(ce) + + +def convert_security_group_to_classifier(context, security_group): + cgroup = models.ClassifierGroup() + cgroup.service = 'security-group' + for rule in security_group['security_group_rules']: + convert_security_group_rule_to_classifier(context, rule, cgroup) + context.session.add(cgroup) context.session.commit() - return chain + return cgroup -def convert_security_group_rule_to_classifier(context, security_group_rule): - # TODO(sc68cal) Pass in the classifier group - group = models.ClassifierGroup() - group.service = 'security-group' - +def convert_security_group_rule_to_classifier(context, sgr, group): # Pull the source from the SG rule cl1 = models.IpClassifier() - cl1.source_ip_prefix = security_group_rule['remote_ip_prefix'] + cl1.source_ip_prefix = sgr['remote_ip_prefix'] # Ports - cl2 = models.TransportClassifier() - cl2.destination_port_range_min = security_group_rule['port_range_min'] - cl2.destination_port_range_max = security_group_rule['port_range_max'] + cl2 = models.TransportClassifier( + dst_port_range_min=sgr['port_range_min'], + dst_port_range_max=sgr['port_range_max']) - chain1 = models.ClassifierChainEntry() - chain1.classifier_group = group - chain1.classifier = cl1 - chain1.sequence = 1 + # Direction + cl3 = models.DirectionClassifier( + direction=sgr['direction']) - chain2 = models.ClassifierChainEntry() - chain2.classifier_group = group - chain2.classifier = cl2 - # Security Group classifiers might not need to be nested or have sequences? - chain2.sequence = 1 - context.session.add(group) - context.session.add(cl1) - context.session.add(cl2) - context.session.add(chain1) - context.session.add(chain2) - context.session.commit() - return group + # Ethertype + cl4 = models.EthernetClassifier() + cl4.ethertype = security_group_ethertype_to_ethertype_value( + sgr['ethertype']) + + if cl4.ethertype == constants.ETHERTYPE_IPV6: + cl5 = models.Ipv6Classifier() + cl5.next_header = sgr['protocol'] + else: + cl5 = models.Ipv4Classifier() + cl5.protocol = sgr['protocol'] + + classifiers = [cl1, cl2, cl3, cl4, cl5] + create_classifier_chain(group, classifiers) def convert_firewall_rule_to_classifier(context, firewall_rule): pass -def convert_classifier_chain_to_security_group(context, chain_id): - pass +def convert_classifier_group_to_security_group(context, classifier_group_id): + sg_dict = {} + cg = get_classifier_group(context, classifier_group_id) + for classifier in [link.classifier for link in cg.classifier_chain]: + classifier_type = type(classifier) + if classifier_type is models.TransportClassifier: + sg_dict['port_range_min'] = classifier.destination_port_range_min + sg_dict['port_range_max'] = classifier.destination_port_range_max + continue + if classifier_type is models.IpClassifier: + sg_dict['remote_ip_prefix'] = classifier.source_ip_prefix + continue + if classifier_type is models.DirectionClassifier: + sg_dict['direction'] = classifier.direction + continue + if classifier_type is models.EthernetClassifier: + sg_dict['ethertype'] = ethertype_value_to_security_group_ethertype( + classifier.ethertype) + continue + if classifier_type is models.Ipv4Classifier: + sg_dict['protocol'] = classifier.protcol + continue + if classifier_type is models.Ipv6Classifier: + sg_dict['protocol'] = classifier.next_header + continue + + return sg_dict def convert_classifier_to_firewall_policy(context, chain_id): diff --git a/neutron_classifier/db/models.py b/neutron_classifier/db/models.py index 90b6e4b..0aaf7fa 100644 --- a/neutron_classifier/db/models.py +++ b/neutron_classifier/db/models.py @@ -67,6 +67,24 @@ class ClassifierChainEntry(Base, HasId): sequence = sa.Column(sa.Integer) classifier_group = orm.relationship(ClassifierGroup) + def __init__(self, classifier_group=None, classifier=None, sequence=None): + super(ClassifierChainEntry, self).__init__() + self.classifier = classifier + self.classifier_group = classifier_group + self.sequence = sequence + + +class DirectionClassifier(Classifier): + __tablename__ = 'direction_classifiers' + __mapper_args__ = {'polymorphic_identity': 'directionclassifier'} + id = sa.Column(sa.String(36), sa.ForeignKey('classifiers.id'), + primary_key=True) + direction = sa.Column(sa.Enum(*constants.DIRECTIONS)) + + def __init__(self, direction=None): + super(DirectionClassifier, self).__init__() + self.direction = direction + class EncapsulationClassifier(Classifier): __tablename__ = 'encapsulation_classifiers' @@ -135,6 +153,12 @@ class TransportClassifier(Classifier): destination_port_range_max = sa.Column(sa.Integer) destination_port_range_min = sa.Column(sa.Integer) + def __init__(self, dst_port_range_min=None, + dst_port_range_max=None): + super(TransportClassifier, self).__init__() + self.destination_port_range_min = dst_port_range_min + self.destination_port_range_max = dst_port_range_max + class VlanClassifier(Classifier): __tablename__ = 'vlan_classifiers' diff --git a/neutron_classifier/tests/test_db_api.py b/neutron_classifier/tests/test_db_api.py index 11952bf..03f701c 100644 --- a/neutron_classifier/tests/test_db_api.py +++ b/neutron_classifier/tests/test_db_api.py @@ -21,6 +21,16 @@ from oslotest import base CREATED = False +FAKE_SG_RULE = {'direction': 'INGRESS', 'protocol': 'tcp', 'ethertype': 'IPv6', + 'tenant_id': 'fake_tenant', 'port_range_min': 80, + 'port_range_max': 80, 'remote_ip_prefix': 'fddf:cb3b:bc4::/48', + } + +FAKE_SG = {'name': 'fake security group', + 'tenant_id': uuidutils.generate_uuid(), + 'description': 'this is fake', + 'security_group_rules': [FAKE_SG_RULE]} + class ClassifierTestContext(object): "Classifier Database Context." @@ -43,33 +53,34 @@ class DbApiTestCase(base.BaseTestCase): global CREATED CREATED = True + def _create_classifier_group(self, service): + cg = models.ClassifierGroup() + cg.tenant_id = uuidutils.generate_uuid() + cg.name = 'test classifier' + cg.description = 'ensure all data inserted correctly' + cg.service = service + return cg + def test_create_classifier_chain(self): - # TODO(sc68cal) Make this not hacky, and make it pass a session - # in a context - fake_tenant = uuidutils.generate_uuid() - a = models.ClassifierGroup() - a.tenant_id = fake_tenant - a.name = 'test classifier' - a.description = 'ensure all data inserted correctly' - a.service = 'neutron-fwaas' - b = models.IpClassifier() - b.destination_ip_prefix = 'fd70:fbb6:449e::/48' - b.source_ip_prefix = 'fddf:cb3b:bc4::/48' - result = api.create_classifier_chain(self.context, a, b) - self.assertIsNotNone(result) + cg = self._create_classifier_group('neutron-fwaas') + ipc = models.IpClassifier() + ipc.destination_ip_prefix = 'fd70:fbb6:449e::/48' + ipc.source_ip_prefix = 'fddf:cb3b:bc4::/48' + api.create_classifier_chain(cg, [ipc]) + self.assertGreater(len(cg.classifier_chain), 0) def test_convert_security_group_rule_to_classifier(self): - sg_rule = {'direction': 'INGRESS', - 'protocol': 'tcp', - 'ethertype': 6, - 'tenant_id': 'fake_tenant', - 'port_range_min': 80, - 'port_range_max': 80, - 'remote_ip_prefix': 'fddf:cb3b:bc4::/48', - } - result = api.convert_security_group_rule_to_classifier(self.context, - sg_rule) - self.assertIsNotNone(result) + # TODO(sc68cal) make this not call session.commit directly + cg = self._create_classifier_group('security-group') + api.convert_security_group_rule_to_classifier(self.context, + FAKE_SG_RULE, cg) + # Save to the database + self.context.session.add(cg) + self.context.session.commit() + + # Refresh the classifier group from the DB + cg = api.get_classifier_group(self.context, cg.id) + self.assertGreater(len(cg.classifier_chain), 0) def test_convert_firewall_rule_to_classifier(self): firewall_rule = {'protocol': 'foo', @@ -88,10 +99,17 @@ class DbApiTestCase(base.BaseTestCase): pass def test_convert_security_group_to_classifier_chain(self): - pass + result = api.convert_security_group_to_classifier(self.context, + FAKE_SG) + self.assertIsNotNone(result) def test_convert_classifier_chain_to_security_group(self): - pass + classifier_id = api.convert_security_group_to_classifier( + self.context, FAKE_SG).id + result = api.convert_classifier_group_to_security_group(self.context, + classifier_id) + result['tenant_id'] = FAKE_SG_RULE['tenant_id'] + self.assertEqual(FAKE_SG_RULE, result) def test_convert_classifier_chain_to_firewall_policy(self): pass