Merge "Added Unit tests for object Accelerator and Deployable"

This commit is contained in:
Zuul 2018-02-08 15:17:31 +00:00 committed by Gerrit Code Review
commit e6d198a60e
7 changed files with 631 additions and 0 deletions

View File

@ -15,6 +15,7 @@
"""Cyborg common internal object model"""
import netaddr
from oslo_utils import versionutils
from oslo_versionedobjects import base as object_base
@ -88,3 +89,57 @@ class CyborgObject(object_base.VersionedObject):
class CyborgObjectSerializer(object_base.VersionedObjectSerializer):
# Base class to use for object hydration
OBJ_BASE_CLASS = CyborgObject
CyborgObjectDictCompat = object_base.VersionedObjectDictCompat
class CyborgPersistentObject(object):
"""Mixin class for Persistent objects.
This adds the fields that we use in common for most persistent objects.
"""
fields = {
'created_at': object_fields.DateTimeField(nullable=True),
'updated_at': object_fields.DateTimeField(nullable=True),
'deleted_at': object_fields.DateTimeField(nullable=True),
'deleted': object_fields.BooleanField(default=False),
}
class ObjectListBase(object_base.ObjectListBase):
@classmethod
def _obj_primitive_key(cls, field):
return 'cyborg_object.%s' % field
@classmethod
def _obj_primitive_field(cls, primitive, field,
default=object_fields.UnspecifiedDefault):
key = cls._obj_primitive_key(field)
if default == object_fields.UnspecifiedDefault:
return primitive[key]
else:
return primitive.get(key, default)
def obj_to_primitive(obj):
"""Recursively turn an object into a python primitive.
A CyborgObject becomes a dict, and anything that implements ObjectListBase
becomes a list.
"""
if isinstance(obj, ObjectListBase):
return [obj_to_primitive(x) for x in obj]
elif isinstance(obj, CyborgObject):
result = {}
for key in obj.obj_fields:
if obj.obj_attr_is_set(key) or key in obj.obj_extra_fields:
result[key] = obj_to_primitive(getattr(obj, key))
return result
elif isinstance(obj, netaddr.IPAddress):
return str(obj)
elif isinstance(obj, netaddr.IPNetwork):
return str(obj)
else:
return obj

View File

@ -22,3 +22,9 @@ UUIDField = object_fields.UUIDField
StringField = object_fields.StringField
DateTimeField = object_fields.DateTimeField
BooleanField = object_fields.BooleanField
ObjectField = object_fields.ObjectField
ListOfObjectsField = object_fields.ListOfObjectsField
ListOfStringsField = object_fields.ListOfStringsField
IPAddressField = object_fields.IPAddressField
IPNetworkField = object_fields.IPNetworkField
UnspecifiedDefault = object_fields.UnspecifiedDefault

View File

@ -0,0 +1,66 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
from oslo_serialization import jsonutils
from oslo_utils import uuidutils
from cyborg import objects
from cyborg.objects import fields
def fake_db_accelerator(**updates):
db_accelerator = {
'id': 1,
'deleted': False,
'uuid': uuidutils.generate_uuid(),
'name': 'fake-name',
'description': 'fake-desc',
'project_id': 'fake-pid',
'user_id': 'fake-uid',
'device_type': 'fake-dtype',
'acc_type': 'fake-acc_type',
'acc_capability': 'fake-cap',
'vendor_id': 'fake-vid',
'product_id': 'fake-pid',
'remotable': 0
}
for name, field in objects.Accelerator.fields.items():
if name in db_accelerator:
continue
if field.nullable:
db_accelerator[name] = None
elif field.default != fields.UnspecifiedDefault:
db_accelerator[name] = field.default
else:
raise Exception('fake_db_accelerator needs help with %s' % name)
if updates:
db_accelerator.update(updates)
return db_accelerator
def fake_accelerator_obj(context, obj_accelerator_class=None, **updates):
if obj_accelerator_class is None:
obj_accelerator_class = objects.Accelerator
expected_attrs = updates.pop('expected_attrs', None)
acc = obj_instance_class._from_db_object(context,
obj_instance_class(),
fake_db_instance(**updates),
expected_attrs=expected_attrs)
acc.obj_reset_changes()
return acc

View File

@ -0,0 +1,69 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
from oslo_serialization import jsonutils
from oslo_utils import uuidutils
from cyborg import objects
from cyborg.objects import fields
def fake_db_deployable(**updates):
root_uuid = uuidutils.generate_uuid()
db_deployable = {
'id': 1,
'deleted': False,
'uuid': root_uuid,
'name': 'dp_name',
'parent_uuid': None,
'root_uuid': root_uuid,
'pcie_address': '00:7f:0b.2',
'host': 'host_name',
'board': 'KU115',
'vendor': 'Xilinx',
'version': '1.0',
'type': 'pf',
'assignable': True,
'instance_uuid': None,
'availability': 'Available'
}
for name, field in objects.Deployable.fields.items():
if name in db_deployable:
continue
if field.nullable:
db_deployable[name] = None
elif field.default != fields.UnspecifiedDefault:
db_deployable[name] = field.default
else:
raise Exception('fake_db_deployable needs help with %s' % name)
if updates:
db_deployable.update(updates)
return db_deployable
def fake_deployable_obj(context, obj_dpl_class=None, **updates):
if obj_dpl_class is None:
obj_dpl_class = objects.Deployable
expected_attrs = updates.pop('expected_attrs', None)
deploy = obj_dpl_class._from_db_object(context,
obj_dpl_class(),
fake_db_deployable(**updates),
expected_attrs=expected_attrs)
deploy.obj_reset_changes()
return deploy

View File

@ -0,0 +1,104 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
import mock
import netaddr
from oslo_db import exception as db_exc
from oslo_serialization import jsonutils
from oslo_utils import timeutils
from oslo_context import context
from cyborg import db
from cyborg.common import exception
from cyborg import objects
from cyborg.objects import base
from cyborg import tests as test
from cyborg.tests.unit import fake_accelerator
from cyborg.tests.unit.objects import test_objects
from cyborg.tests.unit.db.base import DbTestCase
class _TestAcceleratorObject(DbTestCase):
@property
def fake_accelerator(self):
db_acc = fake_accelerator.fake_db_accelerator(id=2)
return db_acc
@mock.patch.object(db.api.Connection, 'accelerator_create')
def test_create(self, mock_create):
mock_create.return_value = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**mock_create.return_value)
acc.create(self.context)
self.assertEqual(self.fake_accelerator['id'], acc.id)
@mock.patch.object(db.api.Connection, 'accelerator_get')
def test_get(self, mock_get):
mock_get.return_value = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**mock_get.return_value)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc['uuid'])
self.assertEqual(acc_get.uuid, acc.uuid)
@mock.patch.object(db.api.Connection, 'accelerator_update')
def test_save(self, mock_save):
mock_save.return_value = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**mock_save.return_value)
acc.create(self.context)
acc.name = 'test_save'
acc.save(self.context)
acc_get = objects.Accelerator.get(self.context, acc['uuid'])
self.assertEqual(acc_get.name, 'test_save')
@mock.patch.object(db.api.Connection, 'accelerator_delete')
def test_destroy(self, mock_destroy):
mock_destroy.return_value = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**mock_destroy.return_value)
acc.create(self.context)
self.assertEqual(self.fake_accelerator['id'], acc.id)
acc.destroy(self.context)
self.assertRaises(exception.AcceleratorNotFound,
objects.Accelerator.get, self.context,
acc['uuid'])
class TestAcceleratorObject(test_objects._LocalTest,
_TestAcceleratorObject):
def _test_save_objectfield_fk_constraint_fails(self, foreign_key,
expected_exception):
error = db_exc.DBReferenceError('table', 'constraint', foreign_key,
'key_table')
# Prevent lazy-loading any fields, results in InstanceNotFound
accelerator = fake_accelerator.fake_accelerator_obj(self.context)
fields_with_save_methods = [field for field in accelerator.fields
if hasattr(accelerator,
'_save_%s' % field)]
for field in fields_with_save_methods:
@mock.patch.object(accelerator, '_save_%s' % field)
@mock.patch.object(accelerator, 'obj_attr_is_set')
def _test(mock_is_set, mock_save_field):
mock_is_set.return_value = True
mock_save_field.side_effect = error
accelerator.obj_reset_changes(fields=[field])
accelerator._changed_fields.add(field)
self.assertRaises(expected_exception, accelerator.save)
accelerator.obj_reset_changes(fields=[field])
_test()

View File

@ -0,0 +1,103 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
import mock
import netaddr
from oslo_db import exception as db_exc
from oslo_serialization import jsonutils
from oslo_utils import timeutils
from oslo_context import context
from cyborg import db
from cyborg.common import exception
from cyborg import objects
from cyborg.objects import base
from cyborg import tests as test
from cyborg.tests.unit import fake_deployable
from cyborg.tests.unit.objects import test_objects
from cyborg.tests.unit.db.base import DbTestCase
class _TestDeployableObject(DbTestCase):
@property
def fake_deployable(self):
db_deploy = fake_deployable.fake_db_deployable(id=2)
return db_deploy
@mock.patch.object(db.api.Connection, 'deployable_create')
def test_create(self, mock_create):
mock_create.return_value = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**mock_create.return_value)
dpl.create(self.context)
self.assertEqual(self.fake_deployable['id'], dpl.id)
@mock.patch.object(db.api.Connection, 'deployable_get')
def test_get(self, mock_get):
mock_get.return_value = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**mock_get.return_value)
dpl.create(self.context)
dpl_get = objects.Deployable.get(self.context, dpl['uuid'])
self.assertEqual(dpl_get.uuid, dpl.uuid)
@mock.patch.object(db.api.Connection, 'deployable_update')
def test_save(self, mock_save):
mock_save.return_value = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**mock_save.return_value)
dpl.create(self.context)
dpl.host = 'test_save'
dpl.save(self.context)
dpl_get = objects.Deployable.get(self.context, dpl['uuid'])
self.assertEqual(dpl_get.host, 'test_save')
@mock.patch.object(db.api.Connection, 'deployable_delete')
def test_destroy(self, mock_destroy):
mock_destroy.return_value = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**mock_destroy.return_value)
dpl.create(self.context)
self.assertEqual(self.fake_deployable['id'], dpl.id)
dpl.destroy(self.context)
self.assertRaises(exception.DeployableNotFound,
objects.Deployable.get, self.context,
dpl['uuid'])
class TestDeployableObject(test_objects._LocalTest,
_TestDeployableObject):
def _test_save_objectfield_fk_constraint_fails(self, foreign_key,
expected_exception):
error = db_exc.DBReferenceError('table', 'constraint', foreign_key,
'key_table')
# Prevent lazy-loading any fields, results in InstanceNotFound
deployable = fake_deployable.fake_deployable_obj(self.context)
fields_with_save_methods = [field for field in deployable.fields
if hasattr(deployable, '_save_%s' % field)]
for field in fields_with_save_methods:
@mock.patch.object(deployable, '_save_%s' % field)
@mock.patch.object(deployable, 'obj_attr_is_set')
def _test(mock_is_set, mock_save_field):
mock_is_set.return_value = True
mock_save_field.side_effect = error
deployable.obj_reset_changes(fields=[field])
deployable._changed_fields.add(field)
self.assertRaises(expected_exception, deployable.save)
deployable.obj_reset_changes(fields=[field])
_test()

View File

@ -0,0 +1,228 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import contextlib
import copy
import datetime
import inspect
import os
import pprint
import fixtures
import mock
from oslo_log import log
from oslo_utils import timeutils
from oslo_versionedobjects import base as ovo_base
from oslo_versionedobjects import exception as ovo_exc
from oslo_versionedobjects import fixture
import six
from oslo_context import context
from cyborg.common import exception
from cyborg import objects
from cyborg.objects import base
from cyborg.objects import fields
from cyborg import tests as test
LOG = log.getLogger(__name__)
class MyOwnedObject(base.CyborgPersistentObject, base.CyborgObject):
VERSION = '1.0'
fields = {'baz': fields.IntegerField()}
class MyObj(base.CyborgPersistentObject, base.CyborgObject,
base.CyborgObjectDictCompat):
VERSION = '1.6'
fields = {'foo': fields.IntegerField(default=1),
'bar': fields.StringField(),
'missing': fields.StringField(),
'readonly': fields.IntegerField(read_only=True),
'rel_object': fields.ObjectField('MyOwnedObject', nullable=True),
'rel_objects': fields.ListOfObjectsField('MyOwnedObject',
nullable=True),
'mutable_default': fields.ListOfStringsField(default=[]),
}
@staticmethod
def _from_db_object(context, obj, db_obj):
self = MyObj()
self.foo = db_obj['foo']
self.bar = db_obj['bar']
self.missing = db_obj['missing']
self.readonly = 1
self._context = context
return self
def obj_load_attr(self, attrname):
setattr(self, attrname, 'loaded!')
def query(cls, context):
obj = cls(context=context, foo=1, bar='bar')
obj.obj_reset_changes()
return obj
def marco(self):
return 'polo'
def _update_test(self):
self.bar = 'updated'
def save(self):
self.obj_reset_changes()
def refresh(self):
self.foo = 321
self.bar = 'refreshed'
self.obj_reset_changes()
def modify_save_modify(self):
self.bar = 'meow'
self.save()
self.foo = 42
self.rel_object = MyOwnedObject(baz=42)
def obj_make_compatible(self, primitive, target_version):
super(MyObj, self).obj_make_compatible(primitive, target_version)
# NOTE(danms): Simulate an older version that had a different
# format for the 'bar' attribute
if target_version == '1.1' and 'bar' in primitive:
primitive['bar'] = 'old%s' % primitive['bar']
class RandomMixInWithNoFields(object):
"""Used to test object inheritance using a mixin that has no fields."""
pass
@base.CyborgObjectRegistry.register_if(False)
class TestSubclassedObject(RandomMixInWithNoFields, MyObj):
fields = {'new_field': fields.StringField()}
class TestObjToPrimitive(test.base.TestCase):
def test_obj_to_primitive_list(self):
@base.CyborgObjectRegistry.register_if(False)
class MyObjElement(base.CyborgObject):
fields = {'foo': fields.IntegerField()}
def __init__(self, foo):
super(MyObjElement, self).__init__()
self.foo = foo
@base.CyborgObjectRegistry.register_if(False)
class MyList(base.ObjectListBase, base.CyborgObject):
fields = {'objects': fields.ListOfObjectsField('MyObjElement')}
mylist = MyList()
mylist.objects = [MyObjElement(1), MyObjElement(2), MyObjElement(3)]
self.assertEqual([1, 2, 3],
[x['foo'] for x in base.obj_to_primitive(mylist)])
def test_obj_to_primitive_dict(self):
base.CyborgObjectRegistry.register(MyObj)
myobj = MyObj(foo=1, bar='foo')
self.assertEqual({'foo': 1, 'bar': 'foo'},
base.obj_to_primitive(myobj))
def test_obj_to_primitive_recursive(self):
base.CyborgObjectRegistry.register(MyObj)
class MyList(base.ObjectListBase, base.CyborgObject):
fields = {'objects': fields.ListOfObjectsField('MyObj')}
mylist = MyList(objects=[MyObj(), MyObj()])
for i, value in enumerate(mylist):
value.foo = i
self.assertEqual([{'foo': 0}, {'foo': 1}],
base.obj_to_primitive(mylist))
def test_obj_to_primitive_with_ip_addr(self):
@base.CyborgObjectRegistry.register_if(False)
class TestObject(base.CyborgObject):
fields = {'addr': fields.IPAddressField(),
'cidr': fields.IPNetworkField()}
obj = TestObject(addr='1.2.3.4', cidr='1.1.1.1/16')
self.assertEqual({'addr': '1.2.3.4', 'cidr': '1.1.1.1/16'},
base.obj_to_primitive(obj))
def compare_obj(test, obj, db_obj, subs=None, allow_missing=None,
comparators=None):
"""Compare a CyborgObject and a dict-like database object.
This automatically converts TZ-aware datetimes and iterates over
the fields of the object.
:param:test: The TestCase doing the comparison
:param:obj: The CyborgObject to examine
:param:db_obj: The dict-like database object to use as reference
:param:subs: A dict of objkey=dbkey field substitutions
:param:allow_missing: A list of fields that may not be in db_obj
:param:comparators: Map of comparator functions to use for certain fields
"""
if subs is None:
subs = {}
if allow_missing is None:
allow_missing = []
if comparators is None:
comparators = {}
for key in obj.fields:
if key in allow_missing and not obj.obj_attr_is_set(key):
continue
obj_val = getattr(obj, key)
db_key = subs.get(key, key)
db_val = db_obj[db_key]
if isinstance(obj_val, datetime.datetime):
obj_val = obj_val.replace(tzinfo=None)
if key in comparators:
comparator = comparators[key]
comparator(db_val, obj_val)
else:
test.assertEqual(db_val, obj_val)
class _BaseTestCase(test.base.TestCase):
def setUp(self):
super(_BaseTestCase, self).setUp()
self.user_id = 'fake-user'
self.project_id = 'fake-project'
self.context = context.RequestContext(self.user_id, self.project_id)
base.CyborgObjectRegistry.register(MyObj)
base.CyborgObjectRegistry.register(MyOwnedObject)
def compare_obj(self, obj, db_obj, subs=None, allow_missing=None,
comparators=None):
compare_obj(self, obj, db_obj, subs=subs, allow_missing=allow_missing,
comparators=comparators)
def str_comparator(self, expected, obj_val):
"""Compare an object field to a string in the db by performing
a simple coercion on the object field value.
"""
self.assertEqual(expected, str(obj_val))
class _LocalTest(_BaseTestCase):
def setUp(self):
super(_LocalTest, self).setUp()