Support qos rules and fields parameters in GET requests

Previously we didn't load the rules into policy object. This patch adds
loading the rules and defines bandwidth_limit_rules as a policy
resource in a single transaction. As a part of moving towards usage of
single transaction, create() and update() of rule were modified
accordingly.
Finally, we support types in GET requests in this patch.

API tests will follow in different patch.

Change-Id: I25c72aae74469b687766754bbeb749dfd1b8867c
This commit is contained in:
Jakub Libosvar 2015-07-21 08:04:00 +00:00
parent 66520a4293
commit 7ed1d4f616
8 changed files with 240 additions and 42 deletions

View File

@ -29,6 +29,26 @@ from neutron.db import models_v2
LOG = logging.getLogger(__name__)
def filter_fields(f):
@functools.wraps(f)
def inner_filter(*args, **kwargs):
result = f(*args, **kwargs)
fields = kwargs.get('fields')
if not fields:
pos = f.func_code.co_varnames.index('fields')
try:
fields = args[pos]
except IndexError:
return result
do_filter = lambda d: {k: v for k, v in d.items() if k in fields}
if isinstance(result, list):
return [do_filter(obj) for obj in result]
else:
return do_filter(result)
return inner_filter
class DbBasePluginCommon(common_db_mixin.CommonDbMixin):
"""Stores getters and helper methods for db_base_plugin_v2

View File

@ -61,7 +61,9 @@ RESOURCE_ATTRIBUTE_MAP = {
'convert_to': attr.convert_to_boolean},
'tenant_id': {'allow_post': True, 'allow_put': False,
'required_by_policy': True,
'is_visible': True}
'is_visible': True},
'bandwidth_limit_rules': {'allow_post': False, 'allow_put': False,
'is_visible': True},
},
'rule_types': {
'type': {'allow_post': False, 'allow_put': False,

View File

@ -75,12 +75,37 @@ class QosPolicy(base.NeutronObject):
setattr(self, attrname, rules)
self.obj_reset_changes([attrname])
def _load_rules(self):
for attr in self.rule_fields:
self.obj_load_attr(attr)
@classmethod
def get_by_id(cls, context, id):
with db_api.autonested_transaction(context.session):
policy_obj = super(QosPolicy, cls).get_by_id(context, id)
if policy_obj:
policy_obj._load_rules()
return policy_obj
# TODO(QoS): Test that all objects are fetched within one transaction
@classmethod
def get_objects(cls, context, **kwargs):
with db_api.autonested_transaction(context.session):
db_objs = db_api.get_objects(context, cls.db_model, **kwargs)
objs = list()
for db_obj in db_objs:
obj = cls(context, **db_obj)
obj._load_rules()
objs.append(obj)
return objs
@classmethod
def _get_object_policy(cls, context, model, **kwargs):
binding_db_obj = db_api.get_object(context, model, **kwargs)
# TODO(QoS): rethink handling missing binding case
if binding_db_obj:
return cls.get_by_id(context, binding_db_obj['policy_id'])
with db_api.autonested_transaction(context.session):
binding_db_obj = db_api.get_object(context, model, **kwargs)
# TODO(QoS): rethink handling missing binding case
if binding_db_obj:
return cls.get_by_id(context, binding_db_obj['policy_id'])
@classmethod
def get_network_policy(cls, context, network_id):
@ -92,6 +117,11 @@ class QosPolicy(base.NeutronObject):
return cls._get_object_policy(context, cls.port_binding_model,
port_id=port_id)
def create(self):
with db_api.autonested_transaction(self._context.session):
super(QosPolicy, self).create()
self._load_rules()
def attach_network(self, network_id):
qos_db_api.create_policy_network_binding(self._context,
policy_id=self.id,

View File

@ -96,7 +96,7 @@ class QosRule(base.NeutronObject):
obj.obj_reset_changes()
return obj
# TODO(QoS): create and update are not transactional safe
# TODO(QoS): Test that create is in single transaction
def create(self):
# TODO(QoS): enforce that type field value is bound to specific class
@ -104,18 +104,21 @@ class QosRule(base.NeutronObject):
# create base qos_rule
core_fields = self._get_changed_core_fields()
base_db_obj = db_api.create_object(
self._context, self.base_db_model, core_fields)
# create type specific qos_..._rule
addn_fields = self._get_changed_addn_fields()
self._copy_common_fields(core_fields, addn_fields)
addn_db_obj = db_api.create_object(
self._context, self.db_model, addn_fields)
with db_api.autonested_transaction(self._context.session):
base_db_obj = db_api.create_object(
self._context, self.base_db_model, core_fields)
# create type specific qos_..._rule
addn_fields = self._get_changed_addn_fields()
self._copy_common_fields(core_fields, addn_fields)
addn_db_obj = db_api.create_object(
self._context, self.db_model, addn_fields)
# merge two db objects into single neutron one
self.from_db_object(base_db_obj, addn_db_obj)
# TODO(QoS): Test that update is in single transaction
def update(self):
updated_db_objs = []
@ -123,16 +126,18 @@ class QosRule(base.NeutronObject):
# update base qos_rule, if needed
core_fields = self._get_changed_core_fields()
if core_fields:
base_db_obj = db_api.update_object(
self._context, self.base_db_model, self.id, core_fields)
updated_db_objs.append(base_db_obj)
addn_fields = self._get_changed_addn_fields()
if addn_fields:
addn_db_obj = db_api.update_object(
self._context, self.db_model, self.id, addn_fields)
updated_db_objs.append(addn_db_obj)
with db_api.autonested_transaction(self._context.session):
if core_fields:
base_db_obj = db_api.update_object(
self._context, self.base_db_model, self.id, core_fields)
updated_db_objs.append(base_db_obj)
addn_fields = self._get_changed_addn_fields()
if addn_fields:
addn_db_obj = db_api.update_object(
self._context, self.db_model, self.id, addn_fields)
updated_db_objs.append(addn_db_obj)
# update neutron object with values from both database objects
self.from_db_object(*updated_db_objs)

View File

@ -17,6 +17,7 @@ from neutron import manager
from neutron.api.rpc.callbacks import registry as rpc_registry
from neutron.api.rpc.callbacks import resources as rpc_resources
from neutron.db import db_base_plugin_common
from neutron.extensions import qos
from neutron.i18n import _LW
from neutron.objects.qos import policy as policy_object
@ -134,10 +135,11 @@ class QoSPlugin(qos.QoSPluginBase):
def _get_policy_obj(self, context, policy_id):
return policy_object.QosPolicy.get_by_id(context, policy_id)
@db_base_plugin_common.filter_fields
def get_policy(self, context, policy_id, fields=None):
#TODO(QoS): Support the fields parameter
return self._get_policy_obj(context, policy_id).to_dict()
@db_base_plugin_common.filter_fields
def get_policies(self, context, filters=None, fields=None,
sorts=None, limit=None, marker=None,
page_reverse=False):
@ -174,12 +176,13 @@ class QoSPlugin(qos.QoSPluginBase):
rule.id = rule_id
rule.delete()
@db_base_plugin_common.filter_fields
def get_policy_bandwidth_limit_rule(self, context, rule_id,
policy_id, fields=None):
#TODO(QoS): Support the fields parameter
return rule_object.QosBandwidthLimitRule.get_by_id(context,
rule_id).to_dict()
@db_base_plugin_common.filter_fields
def get_policy_bandwidth_limit_rules(self, context, policy_id,
filters=None, fields=None,
sorts=None, limit=None,
@ -188,6 +191,7 @@ class QoSPlugin(qos.QoSPluginBase):
return [rule_obj.to_dict() for rule_obj in
rule_object.QosBandwidthLimitRule.get_objects(context)]
@db_base_plugin_common.filter_fields
def get_rule_types(self, context, filters=None, fields=None,
sorts=None, limit=None,
marker=None, page_reverse=False):

View File

@ -0,0 +1,64 @@
# Copyright (c) 2015 Red Hat, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from neutron.db import db_base_plugin_common
from neutron.tests import base
class FilterFieldsTestCase(base.BaseTestCase):
@db_base_plugin_common.filter_fields
def method_dict(self, fields=None):
return {'one': 1, 'two': 2, 'three': 3}
@db_base_plugin_common.filter_fields
def method_list(self, fields=None):
return [self.method_dict() for _ in range(3)]
@db_base_plugin_common.filter_fields
def method_multiple_arguments(self, not_used, fields=None,
also_not_used=None):
return {'one': 1, 'two': 2, 'three': 3}
def test_no_fields(self):
expected = {'one': 1, 'two': 2, 'three': 3}
observed = self.method_dict()
self.assertEqual(expected, observed)
def test_dict(self):
expected = {'two': 2}
observed = self.method_dict(['two'])
self.assertEqual(expected, observed)
def test_list(self):
expected = [{'two': 2}, {'two': 2}, {'two': 2}]
observed = self.method_list(['two'])
self.assertEqual(expected, observed)
def test_multiple_arguments_positional(self):
expected = {'two': 2}
observed = self.method_multiple_arguments(list(), ['two'])
self.assertEqual(expected, observed)
def test_multiple_arguments_positional_and_keywords(self):
expected = {'two': 2}
observed = self.method_multiple_arguments(fields=['two'],
not_used=None)
self.assertEqual(expected, observed)
def test_multiple_arguments_keyword(self):
expected = {'two': 2}
observed = self.method_multiple_arguments(list(), fields=['two'])
self.assertEqual(expected, observed)

View File

@ -10,6 +10,8 @@
# License for the specific language governing permissions and limitations
# under the License.
import mock
from neutron.db import api as db_api
from neutron.db import models_v2
from neutron.objects.qos import policy
@ -22,6 +24,50 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
_test_class = policy.QosPolicy
def setUp(self):
super(QosPolicyObjectTestCase, self).setUp()
self.db_qos_rules = [self.get_random_fields(rule.QosRule)
for _ in range(3)]
# Tie qos rules with policies
self.db_qos_rules[0]['qos_policy_id'] = self.db_objs[0]['id']
self.db_qos_rules[1]['qos_policy_id'] = self.db_objs[0]['id']
self.db_qos_rules[2]['qos_policy_id'] = self.db_objs[1]['id']
self.db_qos_bandwidth_rules = [
self.get_random_fields(rule.QosBandwidthLimitRule)
for _ in range(3)]
# Tie qos rules with qos bandwidth limit rules
for i, qos_rule in enumerate(self.db_qos_rules):
self.db_qos_bandwidth_rules[i]['id'] = qos_rule['id']
self.model_map = {
self._test_class.db_model: self.db_objs,
rule.QosRule.base_db_model: self.db_qos_rules,
rule.QosBandwidthLimitRule.db_model: self.db_qos_bandwidth_rules}
def fake_get_objects(self, context, model, qos_policy_id=None):
objs = self.model_map[model]
if model is rule.QosRule.base_db_model and qos_policy_id:
return [obj for obj in objs
if obj['qos_policy_id'] == qos_policy_id]
return objs
def fake_get_object(self, context, model, id):
objects = self.model_map[model]
return [obj for obj in objects if obj['id'] == id][0]
def test_get_objects(self):
with mock.patch.object(
db_api, 'get_objects',
side_effect=self.fake_get_objects),\
mock.patch.object(
db_api, 'get_object',
side_effect=self.fake_get_object):
objs = self._test_class.get_objects(self.context)
self._validate_objects(self.db_objs, objs)
class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
testlib_api.SqlTestCase):
@ -42,6 +88,19 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
policy_obj.create()
return policy_obj
def _create_test_policy_with_rule(self):
policy_obj = self._create_test_policy()
rule_fields = self.get_random_fields(
obj_cls=rule.QosBandwidthLimitRule)
rule_fields['qos_policy_id'] = policy_obj.id
rule_fields['tenant_id'] = policy_obj.tenant_id
rule_obj = rule.QosBandwidthLimitRule(self.context, **rule_fields)
rule_obj.create()
return policy_obj, rule_obj
def _create_test_network(self):
# TODO(ihrachys): replace with network.create() once we get an object
# implementation for networks
@ -111,16 +170,22 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
self.assertIsNone(policy_obj)
def test_synthetic_rule_fields(self):
obj = policy.QosPolicy(self.context, **self.db_obj)
obj.create()
policy_obj, rule_obj = self._create_test_policy_with_rule()
policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id)
self.assertEqual([rule_obj], policy_obj.bandwidth_limit_rules)
rule_fields = self.get_random_fields(
obj_cls=rule.QosBandwidthLimitRule)
rule_fields['qos_policy_id'] = obj.id
rule_fields['tenant_id'] = obj.tenant_id
def test_create_is_in_single_transaction(self):
obj = self._test_class(self.context, **self.db_obj)
with mock.patch('sqlalchemy.engine.'
'Transaction.commit') as mock_commit,\
mock.patch.object(obj._context.session, 'add'):
obj.create()
self.assertEqual(1, mock_commit.call_count)
rule_obj = rule.QosBandwidthLimitRule(self.context, **rule_fields)
rule_obj.create()
def test_get_by_id_fetches_rules_non_lazily(self):
policy_obj, rule_obj = self._create_test_policy_with_rule()
policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id)
obj = policy.QosPolicy.get_by_id(self.context, obj.id)
self.assertEqual([rule_obj], obj.bandwidth_limit_rules)
primitive = policy_obj.obj_to_primitive()
self.assertNotEqual([], (primitive['versioned_object.data']
['bandwidth_limit_rules']))

View File

@ -23,10 +23,15 @@ from neutron.objects import base
from neutron.tests import base as test_base
class FakeModel(object):
def __init__(self, *args, **kwargs):
pass
@obj_base.VersionedObjectRegistry.register
class FakeNeutronObject(base.NeutronObject):
db_model = 'fake_model'
db_model = FakeModel
fields = {
'id': obj_fields.UUIDField(),
@ -106,13 +111,16 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
with mock.patch.object(db_api, 'get_objects',
return_value=self.db_objs) as get_objects_mock:
objs = self._test_class.get_objects(self.context)
self.assertFalse(
filter(lambda obj: not self._is_test_class(obj), objs))
self.assertEqual(
sorted(self.db_objs),
sorted(get_obj_db_fields(obj) for obj in objs))
get_objects_mock.assert_called_once_with(
self.context, self._test_class.db_model)
self._validate_objects(self.db_objs, objs)
get_objects_mock.assert_called_once_with(
self.context, self._test_class.db_model)
def _validate_objects(self, expected, observed):
self.assertFalse(
filter(lambda obj: not self._is_test_class(obj), observed))
self.assertEqual(
sorted(expected),
sorted(get_obj_db_fields(obj) for obj in observed))
def _check_equal(self, obj, db_obj):
self.assertEqual(