From 326bc6cdb0f55a26dff80e99493ad518d32fb977 Mon Sep 17 00:00:00 2001 From: Ofer Ben-Yacov Date: Tue, 24 Jan 2017 16:52:42 +0200 Subject: [PATCH] Delete class tree from agent --- wan_qos/agent/tc_driver.py | 8 +-- wan_qos/agent/tc_manager.py | 18 +++++++ wan_qos/common/api.py | 6 +++ wan_qos/db/wan_qos_db.py | 97 +++++++++++++++++++++++++++++------ wan_qos/services/plugin.py | 6 ++- wan_qos/tests/unit/test_db.py | 69 +++++++++++++++++++++++-- 6 files changed, 178 insertions(+), 26 deletions(-) diff --git a/wan_qos/agent/tc_driver.py b/wan_qos/agent/tc_driver.py index 7ecae83..fddfbd5 100644 --- a/wan_qos/agent/tc_driver.py +++ b/wan_qos/agent/tc_driver.py @@ -66,11 +66,11 @@ class TcDriver(agent_api.AgentInterface): tc_dict['command'] = 'change' self._create_or_update_class(tc_dict) - def remove_traffic_class(self, tc_dict): - self._delete_filter(tc_dict) - cmd = 'sudo tc class del dev %s classid %s:%s' % ( + def remove_traffic_class(self, tc_dict, with_filter=False): + if with_filter: + self._delete_filter(tc_dict) + cmd = 'sudo tc class del dev %s classid 1:%s' % ( self.ports[tc_dict['port_side']], - tc_dict['parent'], tc_dict['child'] ) check_call(cmd, shell=True) diff --git a/wan_qos/agent/tc_manager.py b/wan_qos/agent/tc_manager.py index b2f9745..8f118a3 100644 --- a/wan_qos/agent/tc_manager.py +++ b/wan_qos/agent/tc_manager.py @@ -103,3 +103,21 @@ class TcAgentManager(manager.Manager): def _create_wtc_class(self, class_dict): self.agent.create_traffic_class(class_dict) + + def delete_wtc_class(self, context, wtc_class_tree): + for child in wtc_class_tree['child_list']: + self.delete_wtc_class(context, child) + self._delete_wtc_class(wtc_class_tree) + + def _delete_wtc_class(self, wtc_class): + tc_dict = { + 'parent': wtc_class['parent_class_ext_id'], + 'child': wtc_class['class_ext_id'] + } + + if wtc_class['direction'] == 'in' or wtc_class['direction'] == 'both': + tc_dict['port_side'] = 'lan_port' + self.agent.remove_traffic_class(tc_dict) + if wtc_class['direction'] == 'out' or wtc_class['direction'] == 'both': + tc_dict['port_side'] = 'wan_port' + self.agent.remove_traffic_class(tc_dict) diff --git a/wan_qos/common/api.py b/wan_qos/common/api.py index 0783706..31b4fed 100644 --- a/wan_qos/common/api.py +++ b/wan_qos/common/api.py @@ -57,3 +57,9 @@ class TcAgentApi(object): return cctxt.call(context, 'create_wtc_class', wtc_class_dict=wtc_class_dict) + + def delete_wtc_class(self, context, wtc_class_tree): + cctxt = self.client.prepare() + return cctxt.call(context, + 'delete_wtc_class', + wtc_class_tree=wtc_class_tree) diff --git a/wan_qos/db/wan_qos_db.py b/wan_qos/db/wan_qos_db.py index 2cb76a1..1db8a97 100644 --- a/wan_qos/db/wan_qos_db.py +++ b/wan_qos/db/wan_qos_db.py @@ -113,7 +113,7 @@ class WanTcDb(object): class_ext_id=self.get_last_class_ext_id(context) ) - if wtc_class['parent'] != '': + if 'parent' in wtc_class and wtc_class['parent'] != '': parent = wtc_class['parent'] parent_class = self.get_class_by_id(context, parent) if not parent_class: @@ -149,15 +149,26 @@ class WanTcDb(object): if wtc_class: return self._class_to_dict(wtc_class) - def get_all_classes(self, context): - wtc_classes_db = context.session.query(models.WanTcClass).filter( - models.WanTcClass.id != 'root').all() - wtc_classes = [] - for wtc_class in wtc_classes_db: - wtc_classes.append(self._class_to_dict(wtc_class)) - return wtc_classes + def get_all_classes(self, context, filters=None, + fields=None, + sorts=None, limit=None, marker=None, + page_reverse=False): + marker_obj = self._get_marker_obj( + context, 'wan_tc_class', limit, marker) + all_classes = self._get_collection(context, models.WanTcClass, + self._class_to_dict, + filters=filters, fields=fields, + sorts=sorts, limit=limit, + marker_obj=marker_obj, + page_reverse=page_reverse) + if not filters: + for wtc_class in all_classes: + if wtc_class['id'] == 'root': + all_classes.remove(wtc_class) + break + return all_classes - def _class_to_dict(self, wtc_class): + def _class_to_dict(self, wtc_class, fields=None): class_dict = { 'id': wtc_class.id, @@ -171,7 +182,7 @@ class WanTcDb(object): return class_dict - def _device_to_dict(self, device): + def _device_to_dict(self, device, fields=None): device_dict = { 'id': device.id, 'host': device.host, @@ -200,14 +211,10 @@ class WanTcDb(object): if device: return self._device_to_dict(device) - def get_class_tree(self): + def get_class_tree(self, start_from_id='root'): context = ctx.get_admin_context() - wtc_classes = self._get_root_classes(context) - return wtc_classes - - def _get_root_classes(self, context): root_class_db = context.session.query(models.WanTcClass).filter_by( - id='root').first() + id=start_from_id).first() root_class = self._class_to_dict(root_class_db) self._get_child_classes(context, root_class) return root_class @@ -220,3 +227,61 @@ class WanTcDb(object): child_class = self._class_to_dict(child_class_db) parent_class['child_list'].append(child_class) self._get_child_classes(context, child_class) + + def _get_collection(self, context, model, dict_func, filters=None, + fields=None, sorts=None, limit=None, marker_obj=None, + page_reverse=False): + """Get collection object based on query for resources.""" + query = self._get_collection_query(context, model, filters=filters, + sorts=sorts, + limit=limit, + marker_obj=marker_obj, + page_reverse=page_reverse) + items = [dict_func(c, fields) for c in query] + if limit and page_reverse: + items.reverse() + return items + + def _get_collection_query(self, context, model, filters=None, + sorts=None, limit=None, marker_obj=None, + page_reverse=False): + """Get collection query for the models.""" + collection = self._model_query(context, model) + collection = self._apply_filters_to_query(collection, model, filters) + return collection + + def _get_marker_obj(self, context, resource, limit, marker): + """Get marker object for the resource.""" + if limit and marker: + return getattr(self, '_get_%s' % resource)(context, marker) + return None + + def _fields(self, resource, fields): + """Get fields for the resource for get query.""" + if fields: + return dict(((key, item) for key, item in resource.items() + if key in fields)) + return resource + + def _model_query(self, context, model): + """Query model based on filter.""" + query = context.session.query(model) + query_filter = None + if not context.is_admin and hasattr(model, 'tenant_id'): + if hasattr(model, 'shared'): + query_filter = ((model.tenant_id == context.tenant_id) | + (model.shared == sa.true())) + else: + query_filter = (model.tenant_id == context.tenant_id) + if query_filter is not None: + query = query.filter(query_filter) + return query + + def _apply_filters_to_query(self, query, model, filters): + """Apply filters to query for the models.""" + if filters: + for key, value in filters.items(): + column = getattr(model, key, None) + if column: + query = query.filter(column.in_(value)) + return query diff --git a/wan_qos/services/plugin.py b/wan_qos/services/plugin.py index 811867e..d8d6e9a 100644 --- a/wan_qos/services/plugin.py +++ b/wan_qos/services/plugin.py @@ -109,11 +109,15 @@ class WanQosPlugin(wanqos.WanQosPluginBase, return wtc_class_db def delete_wan_tc_class(self, context, id): + LOG.debug('Got request to delete class id: %s' % id) + class_tree = self.db.get_class_tree(id) self.db.delete_wtc_class(context, id) + self.agent_rpc.delete_wtc_class(context, class_tree) def get_wan_tc_classs(self, context, filters=None, fields=None, sorts=None, limit=None, marker=None, page_reverse=False): - return self.db.get_all_classes(context) + return self.db.get_all_classes(context, filters, fields, sorts, limit, + marker, page_reverse) @staticmethod def _get_tenant_id_for_create(self, context, resource): diff --git a/wan_qos/tests/unit/test_db.py b/wan_qos/tests/unit/test_db.py index 7bf76e1..eb5992f 100644 --- a/wan_qos/tests/unit/test_db.py +++ b/wan_qos/tests/unit/test_db.py @@ -7,6 +7,7 @@ from oslo_config import cfg from wan_qos.db import wan_qos_db from wan_qos.services import plugin + class TestTcDb(testlib_api.SqlTestCase): def setUp(self): super(TestTcDb, self).setUp() @@ -22,15 +23,40 @@ class TestTcDb(testlib_api.SqlTestCase): assert wtc_class_db is not None + def test_get_class_by_id(self): + + class_db_1 = self._add_class(None, 'both', '1mbit', '2mbit') + class_db_2 = self._add_class(class_db_1['id'], 'both', '2mbit', + '3mbit') + class_db_3 = self._add_class(class_db_2['id'], 'both', '3mbit', + '4mbit') + + class_by_id = self.db.get_class_by_id(self.context, class_db_1['id']) + print (class_by_id) + def test_get_class_tree(self): - class_db = self._add_class(None, 'both', '1mbit', '2mbit') - class_db = self._add_class(class_db['id'], 'both', '2mbit', '3mbit') - class_db = self._add_class(class_db['id'], 'both', '3mbit', '4mbit') + class_db_1 = self._add_class(None, 'both', '1mbit', '2mbit') + class_db_2 = self._add_class(class_db_1['id'], 'both', '2mbit', + '3mbit') + class_db_3 = self._add_class(class_db_2['id'], 'both', '3mbit', + '4mbit') class_tree = self.db.get_class_tree() assert class_tree is not None - print class_tree + print (class_tree) + + class_tree = self.db.get_class_tree(class_db_1['id']) + assert class_tree is not None + print (class_tree) + + class_tree = self.db.get_class_tree(class_db_2['id']) + assert class_tree is not None + print (class_tree) + + class_tree = self.db.get_class_tree(class_db_3['id']) + assert class_tree is not None + print (class_tree) def test_get_classes(self): self.test_add_class() @@ -64,7 +90,40 @@ class TestPlugin(testlib_api.SqlTestCase): } } - wan_tc = self.plugin.create_wan_tc_class(ctx.get_admin_context(), wtc_class) + wan_tc = self.plugin.create_wan_tc_class(ctx.get_admin_context(), + wtc_class) assert wan_tc is not None print (wan_tc) + + def test_get_class_by_id(self): + + class_db_1 = self._add_class(None, 'both', '1mbit', '2mbit') + class_db_2 = self._add_class(class_db_1['id'], 'both', '2mbit', + '3mbit') + class_db_3 = self._add_class(class_db_2['id'], 'both', '3mbit', + '4mbit') + + tc_class = self.plugin.get_wan_tc_class(ctx.get_admin_context(), + class_db_1['id']) + + print(tc_class) + filters = {'id': [class_db_1['id']]} + tc_classes = self.plugin.get_wan_tc_classs(ctx.get_admin_context()) + + print(tc_classes) + + + def _add_class(self, parent, direction, min, max): + wtc_class = { + 'direction': direction, + } + if min: + wtc_class['min'] = min + if parent: + wtc_class['parent'] = parent + if max: + wtc_class['max'] = max + + return self.plugin.db.create_wan_tc_class(ctx.get_admin_context(), + wtc_class)