Support qos rules and fields parameters in GET requests
Previously we didn't load the rules into policy object. This patch adds loading the rules and defines bandwidth_limit_rules as a policy resource in a single transaction. As a part of moving towards usage of single transaction, create() and update() of rule were modified accordingly. Finally, we support types in GET requests in this patch. API tests will follow in different patch. Change-Id: I25c72aae74469b687766754bbeb749dfd1b8867c
This commit is contained in:
parent
66520a4293
commit
7ed1d4f616
|
@ -29,6 +29,26 @@ from neutron.db import models_v2
|
|||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def filter_fields(f):
|
||||
@functools.wraps(f)
|
||||
def inner_filter(*args, **kwargs):
|
||||
result = f(*args, **kwargs)
|
||||
fields = kwargs.get('fields')
|
||||
if not fields:
|
||||
pos = f.func_code.co_varnames.index('fields')
|
||||
try:
|
||||
fields = args[pos]
|
||||
except IndexError:
|
||||
return result
|
||||
|
||||
do_filter = lambda d: {k: v for k, v in d.items() if k in fields}
|
||||
if isinstance(result, list):
|
||||
return [do_filter(obj) for obj in result]
|
||||
else:
|
||||
return do_filter(result)
|
||||
return inner_filter
|
||||
|
||||
|
||||
class DbBasePluginCommon(common_db_mixin.CommonDbMixin):
|
||||
"""Stores getters and helper methods for db_base_plugin_v2
|
||||
|
||||
|
|
|
@ -61,7 +61,9 @@ RESOURCE_ATTRIBUTE_MAP = {
|
|||
'convert_to': attr.convert_to_boolean},
|
||||
'tenant_id': {'allow_post': True, 'allow_put': False,
|
||||
'required_by_policy': True,
|
||||
'is_visible': True}
|
||||
'is_visible': True},
|
||||
'bandwidth_limit_rules': {'allow_post': False, 'allow_put': False,
|
||||
'is_visible': True},
|
||||
},
|
||||
'rule_types': {
|
||||
'type': {'allow_post': False, 'allow_put': False,
|
||||
|
|
|
@ -75,12 +75,37 @@ class QosPolicy(base.NeutronObject):
|
|||
setattr(self, attrname, rules)
|
||||
self.obj_reset_changes([attrname])
|
||||
|
||||
def _load_rules(self):
|
||||
for attr in self.rule_fields:
|
||||
self.obj_load_attr(attr)
|
||||
|
||||
@classmethod
|
||||
def get_by_id(cls, context, id):
|
||||
with db_api.autonested_transaction(context.session):
|
||||
policy_obj = super(QosPolicy, cls).get_by_id(context, id)
|
||||
if policy_obj:
|
||||
policy_obj._load_rules()
|
||||
return policy_obj
|
||||
|
||||
# TODO(QoS): Test that all objects are fetched within one transaction
|
||||
@classmethod
|
||||
def get_objects(cls, context, **kwargs):
|
||||
with db_api.autonested_transaction(context.session):
|
||||
db_objs = db_api.get_objects(context, cls.db_model, **kwargs)
|
||||
objs = list()
|
||||
for db_obj in db_objs:
|
||||
obj = cls(context, **db_obj)
|
||||
obj._load_rules()
|
||||
objs.append(obj)
|
||||
return objs
|
||||
|
||||
@classmethod
|
||||
def _get_object_policy(cls, context, model, **kwargs):
|
||||
binding_db_obj = db_api.get_object(context, model, **kwargs)
|
||||
# TODO(QoS): rethink handling missing binding case
|
||||
if binding_db_obj:
|
||||
return cls.get_by_id(context, binding_db_obj['policy_id'])
|
||||
with db_api.autonested_transaction(context.session):
|
||||
binding_db_obj = db_api.get_object(context, model, **kwargs)
|
||||
# TODO(QoS): rethink handling missing binding case
|
||||
if binding_db_obj:
|
||||
return cls.get_by_id(context, binding_db_obj['policy_id'])
|
||||
|
||||
@classmethod
|
||||
def get_network_policy(cls, context, network_id):
|
||||
|
@ -92,6 +117,11 @@ class QosPolicy(base.NeutronObject):
|
|||
return cls._get_object_policy(context, cls.port_binding_model,
|
||||
port_id=port_id)
|
||||
|
||||
def create(self):
|
||||
with db_api.autonested_transaction(self._context.session):
|
||||
super(QosPolicy, self).create()
|
||||
self._load_rules()
|
||||
|
||||
def attach_network(self, network_id):
|
||||
qos_db_api.create_policy_network_binding(self._context,
|
||||
policy_id=self.id,
|
||||
|
|
|
@ -96,7 +96,7 @@ class QosRule(base.NeutronObject):
|
|||
obj.obj_reset_changes()
|
||||
return obj
|
||||
|
||||
# TODO(QoS): create and update are not transactional safe
|
||||
# TODO(QoS): Test that create is in single transaction
|
||||
def create(self):
|
||||
|
||||
# TODO(QoS): enforce that type field value is bound to specific class
|
||||
|
@ -104,18 +104,21 @@ class QosRule(base.NeutronObject):
|
|||
|
||||
# create base qos_rule
|
||||
core_fields = self._get_changed_core_fields()
|
||||
base_db_obj = db_api.create_object(
|
||||
self._context, self.base_db_model, core_fields)
|
||||
|
||||
# create type specific qos_..._rule
|
||||
addn_fields = self._get_changed_addn_fields()
|
||||
self._copy_common_fields(core_fields, addn_fields)
|
||||
addn_db_obj = db_api.create_object(
|
||||
self._context, self.db_model, addn_fields)
|
||||
with db_api.autonested_transaction(self._context.session):
|
||||
base_db_obj = db_api.create_object(
|
||||
self._context, self.base_db_model, core_fields)
|
||||
|
||||
# create type specific qos_..._rule
|
||||
addn_fields = self._get_changed_addn_fields()
|
||||
self._copy_common_fields(core_fields, addn_fields)
|
||||
addn_db_obj = db_api.create_object(
|
||||
self._context, self.db_model, addn_fields)
|
||||
|
||||
# merge two db objects into single neutron one
|
||||
self.from_db_object(base_db_obj, addn_db_obj)
|
||||
|
||||
# TODO(QoS): Test that update is in single transaction
|
||||
def update(self):
|
||||
updated_db_objs = []
|
||||
|
||||
|
@ -123,16 +126,18 @@ class QosRule(base.NeutronObject):
|
|||
|
||||
# update base qos_rule, if needed
|
||||
core_fields = self._get_changed_core_fields()
|
||||
if core_fields:
|
||||
base_db_obj = db_api.update_object(
|
||||
self._context, self.base_db_model, self.id, core_fields)
|
||||
updated_db_objs.append(base_db_obj)
|
||||
|
||||
addn_fields = self._get_changed_addn_fields()
|
||||
if addn_fields:
|
||||
addn_db_obj = db_api.update_object(
|
||||
self._context, self.db_model, self.id, addn_fields)
|
||||
updated_db_objs.append(addn_db_obj)
|
||||
with db_api.autonested_transaction(self._context.session):
|
||||
if core_fields:
|
||||
base_db_obj = db_api.update_object(
|
||||
self._context, self.base_db_model, self.id, core_fields)
|
||||
updated_db_objs.append(base_db_obj)
|
||||
|
||||
addn_fields = self._get_changed_addn_fields()
|
||||
if addn_fields:
|
||||
addn_db_obj = db_api.update_object(
|
||||
self._context, self.db_model, self.id, addn_fields)
|
||||
updated_db_objs.append(addn_db_obj)
|
||||
|
||||
# update neutron object with values from both database objects
|
||||
self.from_db_object(*updated_db_objs)
|
||||
|
|
|
@ -17,6 +17,7 @@ from neutron import manager
|
|||
|
||||
from neutron.api.rpc.callbacks import registry as rpc_registry
|
||||
from neutron.api.rpc.callbacks import resources as rpc_resources
|
||||
from neutron.db import db_base_plugin_common
|
||||
from neutron.extensions import qos
|
||||
from neutron.i18n import _LW
|
||||
from neutron.objects.qos import policy as policy_object
|
||||
|
@ -134,10 +135,11 @@ class QoSPlugin(qos.QoSPluginBase):
|
|||
def _get_policy_obj(self, context, policy_id):
|
||||
return policy_object.QosPolicy.get_by_id(context, policy_id)
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
def get_policy(self, context, policy_id, fields=None):
|
||||
#TODO(QoS): Support the fields parameter
|
||||
return self._get_policy_obj(context, policy_id).to_dict()
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
def get_policies(self, context, filters=None, fields=None,
|
||||
sorts=None, limit=None, marker=None,
|
||||
page_reverse=False):
|
||||
|
@ -174,12 +176,13 @@ class QoSPlugin(qos.QoSPluginBase):
|
|||
rule.id = rule_id
|
||||
rule.delete()
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
def get_policy_bandwidth_limit_rule(self, context, rule_id,
|
||||
policy_id, fields=None):
|
||||
#TODO(QoS): Support the fields parameter
|
||||
return rule_object.QosBandwidthLimitRule.get_by_id(context,
|
||||
rule_id).to_dict()
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
def get_policy_bandwidth_limit_rules(self, context, policy_id,
|
||||
filters=None, fields=None,
|
||||
sorts=None, limit=None,
|
||||
|
@ -188,6 +191,7 @@ class QoSPlugin(qos.QoSPluginBase):
|
|||
return [rule_obj.to_dict() for rule_obj in
|
||||
rule_object.QosBandwidthLimitRule.get_objects(context)]
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
def get_rule_types(self, context, filters=None, fields=None,
|
||||
sorts=None, limit=None,
|
||||
marker=None, page_reverse=False):
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright (c) 2015 Red Hat, Inc.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from neutron.db import db_base_plugin_common
|
||||
from neutron.tests import base
|
||||
|
||||
|
||||
class FilterFieldsTestCase(base.BaseTestCase):
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
def method_dict(self, fields=None):
|
||||
return {'one': 1, 'two': 2, 'three': 3}
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
def method_list(self, fields=None):
|
||||
return [self.method_dict() for _ in range(3)]
|
||||
|
||||
@db_base_plugin_common.filter_fields
|
||||
def method_multiple_arguments(self, not_used, fields=None,
|
||||
also_not_used=None):
|
||||
return {'one': 1, 'two': 2, 'three': 3}
|
||||
|
||||
def test_no_fields(self):
|
||||
expected = {'one': 1, 'two': 2, 'three': 3}
|
||||
observed = self.method_dict()
|
||||
self.assertEqual(expected, observed)
|
||||
|
||||
def test_dict(self):
|
||||
expected = {'two': 2}
|
||||
observed = self.method_dict(['two'])
|
||||
self.assertEqual(expected, observed)
|
||||
|
||||
def test_list(self):
|
||||
expected = [{'two': 2}, {'two': 2}, {'two': 2}]
|
||||
observed = self.method_list(['two'])
|
||||
self.assertEqual(expected, observed)
|
||||
|
||||
def test_multiple_arguments_positional(self):
|
||||
expected = {'two': 2}
|
||||
observed = self.method_multiple_arguments(list(), ['two'])
|
||||
self.assertEqual(expected, observed)
|
||||
|
||||
def test_multiple_arguments_positional_and_keywords(self):
|
||||
expected = {'two': 2}
|
||||
observed = self.method_multiple_arguments(fields=['two'],
|
||||
not_used=None)
|
||||
self.assertEqual(expected, observed)
|
||||
|
||||
def test_multiple_arguments_keyword(self):
|
||||
expected = {'two': 2}
|
||||
observed = self.method_multiple_arguments(list(), fields=['two'])
|
||||
self.assertEqual(expected, observed)
|
|
@ -10,6 +10,8 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
|
||||
from neutron.db import api as db_api
|
||||
from neutron.db import models_v2
|
||||
from neutron.objects.qos import policy
|
||||
|
@ -22,6 +24,50 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
|
|||
|
||||
_test_class = policy.QosPolicy
|
||||
|
||||
def setUp(self):
|
||||
super(QosPolicyObjectTestCase, self).setUp()
|
||||
self.db_qos_rules = [self.get_random_fields(rule.QosRule)
|
||||
for _ in range(3)]
|
||||
|
||||
# Tie qos rules with policies
|
||||
self.db_qos_rules[0]['qos_policy_id'] = self.db_objs[0]['id']
|
||||
self.db_qos_rules[1]['qos_policy_id'] = self.db_objs[0]['id']
|
||||
self.db_qos_rules[2]['qos_policy_id'] = self.db_objs[1]['id']
|
||||
|
||||
self.db_qos_bandwidth_rules = [
|
||||
self.get_random_fields(rule.QosBandwidthLimitRule)
|
||||
for _ in range(3)]
|
||||
|
||||
# Tie qos rules with qos bandwidth limit rules
|
||||
for i, qos_rule in enumerate(self.db_qos_rules):
|
||||
self.db_qos_bandwidth_rules[i]['id'] = qos_rule['id']
|
||||
|
||||
self.model_map = {
|
||||
self._test_class.db_model: self.db_objs,
|
||||
rule.QosRule.base_db_model: self.db_qos_rules,
|
||||
rule.QosBandwidthLimitRule.db_model: self.db_qos_bandwidth_rules}
|
||||
|
||||
def fake_get_objects(self, context, model, qos_policy_id=None):
|
||||
objs = self.model_map[model]
|
||||
if model is rule.QosRule.base_db_model and qos_policy_id:
|
||||
return [obj for obj in objs
|
||||
if obj['qos_policy_id'] == qos_policy_id]
|
||||
return objs
|
||||
|
||||
def fake_get_object(self, context, model, id):
|
||||
objects = self.model_map[model]
|
||||
return [obj for obj in objects if obj['id'] == id][0]
|
||||
|
||||
def test_get_objects(self):
|
||||
with mock.patch.object(
|
||||
db_api, 'get_objects',
|
||||
side_effect=self.fake_get_objects),\
|
||||
mock.patch.object(
|
||||
db_api, 'get_object',
|
||||
side_effect=self.fake_get_object):
|
||||
objs = self._test_class.get_objects(self.context)
|
||||
self._validate_objects(self.db_objs, objs)
|
||||
|
||||
|
||||
class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
|
||||
testlib_api.SqlTestCase):
|
||||
|
@ -42,6 +88,19 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
|
|||
policy_obj.create()
|
||||
return policy_obj
|
||||
|
||||
def _create_test_policy_with_rule(self):
|
||||
policy_obj = self._create_test_policy()
|
||||
|
||||
rule_fields = self.get_random_fields(
|
||||
obj_cls=rule.QosBandwidthLimitRule)
|
||||
rule_fields['qos_policy_id'] = policy_obj.id
|
||||
rule_fields['tenant_id'] = policy_obj.tenant_id
|
||||
|
||||
rule_obj = rule.QosBandwidthLimitRule(self.context, **rule_fields)
|
||||
rule_obj.create()
|
||||
|
||||
return policy_obj, rule_obj
|
||||
|
||||
def _create_test_network(self):
|
||||
# TODO(ihrachys): replace with network.create() once we get an object
|
||||
# implementation for networks
|
||||
|
@ -111,16 +170,22 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
|
|||
self.assertIsNone(policy_obj)
|
||||
|
||||
def test_synthetic_rule_fields(self):
|
||||
obj = policy.QosPolicy(self.context, **self.db_obj)
|
||||
obj.create()
|
||||
policy_obj, rule_obj = self._create_test_policy_with_rule()
|
||||
policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id)
|
||||
self.assertEqual([rule_obj], policy_obj.bandwidth_limit_rules)
|
||||
|
||||
rule_fields = self.get_random_fields(
|
||||
obj_cls=rule.QosBandwidthLimitRule)
|
||||
rule_fields['qos_policy_id'] = obj.id
|
||||
rule_fields['tenant_id'] = obj.tenant_id
|
||||
def test_create_is_in_single_transaction(self):
|
||||
obj = self._test_class(self.context, **self.db_obj)
|
||||
with mock.patch('sqlalchemy.engine.'
|
||||
'Transaction.commit') as mock_commit,\
|
||||
mock.patch.object(obj._context.session, 'add'):
|
||||
obj.create()
|
||||
self.assertEqual(1, mock_commit.call_count)
|
||||
|
||||
rule_obj = rule.QosBandwidthLimitRule(self.context, **rule_fields)
|
||||
rule_obj.create()
|
||||
def test_get_by_id_fetches_rules_non_lazily(self):
|
||||
policy_obj, rule_obj = self._create_test_policy_with_rule()
|
||||
policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id)
|
||||
|
||||
obj = policy.QosPolicy.get_by_id(self.context, obj.id)
|
||||
self.assertEqual([rule_obj], obj.bandwidth_limit_rules)
|
||||
primitive = policy_obj.obj_to_primitive()
|
||||
self.assertNotEqual([], (primitive['versioned_object.data']
|
||||
['bandwidth_limit_rules']))
|
||||
|
|
|
@ -23,10 +23,15 @@ from neutron.objects import base
|
|||
from neutron.tests import base as test_base
|
||||
|
||||
|
||||
class FakeModel(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@obj_base.VersionedObjectRegistry.register
|
||||
class FakeNeutronObject(base.NeutronObject):
|
||||
|
||||
db_model = 'fake_model'
|
||||
db_model = FakeModel
|
||||
|
||||
fields = {
|
||||
'id': obj_fields.UUIDField(),
|
||||
|
@ -106,13 +111,16 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
|
|||
with mock.patch.object(db_api, 'get_objects',
|
||||
return_value=self.db_objs) as get_objects_mock:
|
||||
objs = self._test_class.get_objects(self.context)
|
||||
self.assertFalse(
|
||||
filter(lambda obj: not self._is_test_class(obj), objs))
|
||||
self.assertEqual(
|
||||
sorted(self.db_objs),
|
||||
sorted(get_obj_db_fields(obj) for obj in objs))
|
||||
get_objects_mock.assert_called_once_with(
|
||||
self.context, self._test_class.db_model)
|
||||
self._validate_objects(self.db_objs, objs)
|
||||
get_objects_mock.assert_called_once_with(
|
||||
self.context, self._test_class.db_model)
|
||||
|
||||
def _validate_objects(self, expected, observed):
|
||||
self.assertFalse(
|
||||
filter(lambda obj: not self._is_test_class(obj), observed))
|
||||
self.assertEqual(
|
||||
sorted(expected),
|
||||
sorted(get_obj_db_fields(obj) for obj in observed))
|
||||
|
||||
def _check_equal(self, obj, db_obj):
|
||||
self.assertEqual(
|
||||
|
|
Loading…
Reference in New Issue