Added Unit tests for object Accelerator and Deployable
Change-Id: I9cf93306e7770172a95825d320de15e2d60c4d45
This commit is contained in:
parent
6ee7ed802f
commit
2ae19254bb
|
@ -15,6 +15,7 @@
|
|||
|
||||
"""Cyborg common internal object model"""
|
||||
|
||||
import netaddr
|
||||
from oslo_utils import versionutils
|
||||
from oslo_versionedobjects import base as object_base
|
||||
|
||||
|
@ -88,3 +89,57 @@ class CyborgObject(object_base.VersionedObject):
|
|||
class CyborgObjectSerializer(object_base.VersionedObjectSerializer):
|
||||
# Base class to use for object hydration
|
||||
OBJ_BASE_CLASS = CyborgObject
|
||||
|
||||
|
||||
CyborgObjectDictCompat = object_base.VersionedObjectDictCompat
|
||||
|
||||
|
||||
class CyborgPersistentObject(object):
|
||||
"""Mixin class for Persistent objects.
|
||||
|
||||
This adds the fields that we use in common for most persistent objects.
|
||||
"""
|
||||
fields = {
|
||||
'created_at': object_fields.DateTimeField(nullable=True),
|
||||
'updated_at': object_fields.DateTimeField(nullable=True),
|
||||
'deleted_at': object_fields.DateTimeField(nullable=True),
|
||||
'deleted': object_fields.BooleanField(default=False),
|
||||
}
|
||||
|
||||
|
||||
class ObjectListBase(object_base.ObjectListBase):
|
||||
|
||||
@classmethod
|
||||
def _obj_primitive_key(cls, field):
|
||||
return 'cyborg_object.%s' % field
|
||||
|
||||
@classmethod
|
||||
def _obj_primitive_field(cls, primitive, field,
|
||||
default=object_fields.UnspecifiedDefault):
|
||||
key = cls._obj_primitive_key(field)
|
||||
if default == object_fields.UnspecifiedDefault:
|
||||
return primitive[key]
|
||||
else:
|
||||
return primitive.get(key, default)
|
||||
|
||||
|
||||
def obj_to_primitive(obj):
|
||||
"""Recursively turn an object into a python primitive.
|
||||
|
||||
A CyborgObject becomes a dict, and anything that implements ObjectListBase
|
||||
becomes a list.
|
||||
"""
|
||||
if isinstance(obj, ObjectListBase):
|
||||
return [obj_to_primitive(x) for x in obj]
|
||||
elif isinstance(obj, CyborgObject):
|
||||
result = {}
|
||||
for key in obj.obj_fields:
|
||||
if obj.obj_attr_is_set(key) or key in obj.obj_extra_fields:
|
||||
result[key] = obj_to_primitive(getattr(obj, key))
|
||||
return result
|
||||
elif isinstance(obj, netaddr.IPAddress):
|
||||
return str(obj)
|
||||
elif isinstance(obj, netaddr.IPNetwork):
|
||||
return str(obj)
|
||||
else:
|
||||
return obj
|
||||
|
|
|
@ -22,3 +22,9 @@ UUIDField = object_fields.UUIDField
|
|||
StringField = object_fields.StringField
|
||||
DateTimeField = object_fields.DateTimeField
|
||||
BooleanField = object_fields.BooleanField
|
||||
ObjectField = object_fields.ObjectField
|
||||
ListOfObjectsField = object_fields.ListOfObjectsField
|
||||
ListOfStringsField = object_fields.ListOfStringsField
|
||||
IPAddressField = object_fields.IPAddressField
|
||||
IPNetworkField = object_fields.IPNetworkField
|
||||
UnspecifiedDefault = object_fields.UnspecifiedDefault
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2018 Huawei Technologies Co.,LTD.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import datetime
|
||||
|
||||
from oslo_serialization import jsonutils
|
||||
from oslo_utils import uuidutils
|
||||
|
||||
from cyborg import objects
|
||||
from cyborg.objects import fields
|
||||
|
||||
|
||||
def fake_db_accelerator(**updates):
|
||||
db_accelerator = {
|
||||
'id': 1,
|
||||
'deleted': False,
|
||||
'uuid': uuidutils.generate_uuid(),
|
||||
'name': 'fake-name',
|
||||
'description': 'fake-desc',
|
||||
'project_id': 'fake-pid',
|
||||
'user_id': 'fake-uid',
|
||||
'device_type': 'fake-dtype',
|
||||
'acc_type': 'fake-acc_type',
|
||||
'acc_capability': 'fake-cap',
|
||||
'vendor_id': 'fake-vid',
|
||||
'product_id': 'fake-pid',
|
||||
'remotable': 0
|
||||
}
|
||||
|
||||
for name, field in objects.Accelerator.fields.items():
|
||||
if name in db_accelerator:
|
||||
continue
|
||||
if field.nullable:
|
||||
db_accelerator[name] = None
|
||||
elif field.default != fields.UnspecifiedDefault:
|
||||
db_accelerator[name] = field.default
|
||||
else:
|
||||
raise Exception('fake_db_accelerator needs help with %s' % name)
|
||||
|
||||
if updates:
|
||||
db_accelerator.update(updates)
|
||||
|
||||
return db_accelerator
|
||||
|
||||
|
||||
def fake_accelerator_obj(context, obj_accelerator_class=None, **updates):
|
||||
if obj_accelerator_class is None:
|
||||
obj_accelerator_class = objects.Accelerator
|
||||
expected_attrs = updates.pop('expected_attrs', None)
|
||||
acc = obj_instance_class._from_db_object(context,
|
||||
obj_instance_class(),
|
||||
fake_db_instance(**updates),
|
||||
expected_attrs=expected_attrs)
|
||||
acc.obj_reset_changes()
|
||||
return acc
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2018 Huawei Technologies Co.,LTD.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import datetime
|
||||
|
||||
from oslo_serialization import jsonutils
|
||||
from oslo_utils import uuidutils
|
||||
|
||||
from cyborg import objects
|
||||
from cyborg.objects import fields
|
||||
|
||||
|
||||
def fake_db_deployable(**updates):
|
||||
root_uuid = uuidutils.generate_uuid()
|
||||
db_deployable = {
|
||||
'id': 1,
|
||||
'deleted': False,
|
||||
'uuid': root_uuid,
|
||||
'name': 'dp_name',
|
||||
'parent_uuid': None,
|
||||
'root_uuid': root_uuid,
|
||||
'pcie_address': '00:7f:0b.2',
|
||||
'host': 'host_name',
|
||||
'board': 'KU115',
|
||||
'vendor': 'Xilinx',
|
||||
'version': '1.0',
|
||||
'type': 'pf',
|
||||
'assignable': True,
|
||||
'instance_uuid': None,
|
||||
'availability': 'Available'
|
||||
}
|
||||
|
||||
for name, field in objects.Deployable.fields.items():
|
||||
if name in db_deployable:
|
||||
continue
|
||||
if field.nullable:
|
||||
db_deployable[name] = None
|
||||
elif field.default != fields.UnspecifiedDefault:
|
||||
db_deployable[name] = field.default
|
||||
else:
|
||||
raise Exception('fake_db_deployable needs help with %s' % name)
|
||||
|
||||
if updates:
|
||||
db_deployable.update(updates)
|
||||
|
||||
return db_deployable
|
||||
|
||||
|
||||
def fake_deployable_obj(context, obj_dpl_class=None, **updates):
|
||||
if obj_dpl_class is None:
|
||||
obj_dpl_class = objects.Deployable
|
||||
expected_attrs = updates.pop('expected_attrs', None)
|
||||
deploy = obj_dpl_class._from_db_object(context,
|
||||
obj_dpl_class(),
|
||||
fake_db_deployable(**updates),
|
||||
expected_attrs=expected_attrs)
|
||||
deploy.obj_reset_changes()
|
||||
return deploy
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright 2018 Huawei Technologies Co.,LTD.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import datetime
|
||||
|
||||
import mock
|
||||
import netaddr
|
||||
from oslo_db import exception as db_exc
|
||||
from oslo_serialization import jsonutils
|
||||
from oslo_utils import timeutils
|
||||
from oslo_context import context
|
||||
|
||||
from cyborg import db
|
||||
from cyborg.common import exception
|
||||
from cyborg import objects
|
||||
from cyborg.objects import base
|
||||
from cyborg import tests as test
|
||||
from cyborg.tests.unit import fake_accelerator
|
||||
from cyborg.tests.unit.objects import test_objects
|
||||
from cyborg.tests.unit.db.base import DbTestCase
|
||||
|
||||
|
||||
class _TestAcceleratorObject(DbTestCase):
|
||||
@property
|
||||
def fake_accelerator(self):
|
||||
db_acc = fake_accelerator.fake_db_accelerator(id=2)
|
||||
return db_acc
|
||||
|
||||
@mock.patch.object(db.api.Connection, 'accelerator_create')
|
||||
def test_create(self, mock_create):
|
||||
mock_create.return_value = self.fake_accelerator
|
||||
acc = objects.Accelerator(context=self.context,
|
||||
**mock_create.return_value)
|
||||
acc.create(self.context)
|
||||
|
||||
self.assertEqual(self.fake_accelerator['id'], acc.id)
|
||||
|
||||
@mock.patch.object(db.api.Connection, 'accelerator_get')
|
||||
def test_get(self, mock_get):
|
||||
mock_get.return_value = self.fake_accelerator
|
||||
acc = objects.Accelerator(context=self.context,
|
||||
**mock_get.return_value)
|
||||
acc.create(self.context)
|
||||
acc_get = objects.Accelerator.get(self.context, acc['uuid'])
|
||||
self.assertEqual(acc_get.uuid, acc.uuid)
|
||||
|
||||
@mock.patch.object(db.api.Connection, 'accelerator_update')
|
||||
def test_save(self, mock_save):
|
||||
mock_save.return_value = self.fake_accelerator
|
||||
acc = objects.Accelerator(context=self.context,
|
||||
**mock_save.return_value)
|
||||
acc.create(self.context)
|
||||
acc.name = 'test_save'
|
||||
acc.save(self.context)
|
||||
acc_get = objects.Accelerator.get(self.context, acc['uuid'])
|
||||
self.assertEqual(acc_get.name, 'test_save')
|
||||
|
||||
@mock.patch.object(db.api.Connection, 'accelerator_delete')
|
||||
def test_destroy(self, mock_destroy):
|
||||
mock_destroy.return_value = self.fake_accelerator
|
||||
acc = objects.Accelerator(context=self.context,
|
||||
**mock_destroy.return_value)
|
||||
acc.create(self.context)
|
||||
self.assertEqual(self.fake_accelerator['id'], acc.id)
|
||||
acc.destroy(self.context)
|
||||
self.assertRaises(exception.AcceleratorNotFound,
|
||||
objects.Accelerator.get, self.context,
|
||||
acc['uuid'])
|
||||
|
||||
|
||||
class TestAcceleratorObject(test_objects._LocalTest,
|
||||
_TestAcceleratorObject):
|
||||
def _test_save_objectfield_fk_constraint_fails(self, foreign_key,
|
||||
expected_exception):
|
||||
|
||||
error = db_exc.DBReferenceError('table', 'constraint', foreign_key,
|
||||
'key_table')
|
||||
# Prevent lazy-loading any fields, results in InstanceNotFound
|
||||
accelerator = fake_accelerator.fake_accelerator_obj(self.context)
|
||||
fields_with_save_methods = [field for field in accelerator.fields
|
||||
if hasattr(accelerator,
|
||||
'_save_%s' % field)]
|
||||
for field in fields_with_save_methods:
|
||||
@mock.patch.object(accelerator, '_save_%s' % field)
|
||||
@mock.patch.object(accelerator, 'obj_attr_is_set')
|
||||
def _test(mock_is_set, mock_save_field):
|
||||
mock_is_set.return_value = True
|
||||
mock_save_field.side_effect = error
|
||||
accelerator.obj_reset_changes(fields=[field])
|
||||
accelerator._changed_fields.add(field)
|
||||
self.assertRaises(expected_exception, accelerator.save)
|
||||
accelerator.obj_reset_changes(fields=[field])
|
||||
_test()
|
|
@ -0,0 +1,103 @@
|
|||
# Copyright 2018 Huawei Technologies Co.,LTD.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import datetime
|
||||
|
||||
import mock
|
||||
import netaddr
|
||||
from oslo_db import exception as db_exc
|
||||
from oslo_serialization import jsonutils
|
||||
from oslo_utils import timeutils
|
||||
from oslo_context import context
|
||||
|
||||
from cyborg import db
|
||||
from cyborg.common import exception
|
||||
from cyborg import objects
|
||||
from cyborg.objects import base
|
||||
from cyborg import tests as test
|
||||
from cyborg.tests.unit import fake_deployable
|
||||
from cyborg.tests.unit.objects import test_objects
|
||||
from cyborg.tests.unit.db.base import DbTestCase
|
||||
|
||||
|
||||
class _TestDeployableObject(DbTestCase):
|
||||
@property
|
||||
def fake_deployable(self):
|
||||
db_deploy = fake_deployable.fake_db_deployable(id=2)
|
||||
return db_deploy
|
||||
|
||||
@mock.patch.object(db.api.Connection, 'deployable_create')
|
||||
def test_create(self, mock_create):
|
||||
mock_create.return_value = self.fake_deployable
|
||||
dpl = objects.Deployable(context=self.context,
|
||||
**mock_create.return_value)
|
||||
dpl.create(self.context)
|
||||
|
||||
self.assertEqual(self.fake_deployable['id'], dpl.id)
|
||||
|
||||
@mock.patch.object(db.api.Connection, 'deployable_get')
|
||||
def test_get(self, mock_get):
|
||||
mock_get.return_value = self.fake_deployable
|
||||
dpl = objects.Deployable(context=self.context,
|
||||
**mock_get.return_value)
|
||||
dpl.create(self.context)
|
||||
dpl_get = objects.Deployable.get(self.context, dpl['uuid'])
|
||||
self.assertEqual(dpl_get.uuid, dpl.uuid)
|
||||
|
||||
@mock.patch.object(db.api.Connection, 'deployable_update')
|
||||
def test_save(self, mock_save):
|
||||
mock_save.return_value = self.fake_deployable
|
||||
dpl = objects.Deployable(context=self.context,
|
||||
**mock_save.return_value)
|
||||
dpl.create(self.context)
|
||||
dpl.host = 'test_save'
|
||||
dpl.save(self.context)
|
||||
dpl_get = objects.Deployable.get(self.context, dpl['uuid'])
|
||||
self.assertEqual(dpl_get.host, 'test_save')
|
||||
|
||||
@mock.patch.object(db.api.Connection, 'deployable_delete')
|
||||
def test_destroy(self, mock_destroy):
|
||||
mock_destroy.return_value = self.fake_deployable
|
||||
dpl = objects.Deployable(context=self.context,
|
||||
**mock_destroy.return_value)
|
||||
dpl.create(self.context)
|
||||
self.assertEqual(self.fake_deployable['id'], dpl.id)
|
||||
dpl.destroy(self.context)
|
||||
self.assertRaises(exception.DeployableNotFound,
|
||||
objects.Deployable.get, self.context,
|
||||
dpl['uuid'])
|
||||
|
||||
|
||||
class TestDeployableObject(test_objects._LocalTest,
|
||||
_TestDeployableObject):
|
||||
def _test_save_objectfield_fk_constraint_fails(self, foreign_key,
|
||||
expected_exception):
|
||||
|
||||
error = db_exc.DBReferenceError('table', 'constraint', foreign_key,
|
||||
'key_table')
|
||||
# Prevent lazy-loading any fields, results in InstanceNotFound
|
||||
deployable = fake_deployable.fake_deployable_obj(self.context)
|
||||
fields_with_save_methods = [field for field in deployable.fields
|
||||
if hasattr(deployable, '_save_%s' % field)]
|
||||
for field in fields_with_save_methods:
|
||||
@mock.patch.object(deployable, '_save_%s' % field)
|
||||
@mock.patch.object(deployable, 'obj_attr_is_set')
|
||||
def _test(mock_is_set, mock_save_field):
|
||||
mock_is_set.return_value = True
|
||||
mock_save_field.side_effect = error
|
||||
deployable.obj_reset_changes(fields=[field])
|
||||
deployable._changed_fields.add(field)
|
||||
self.assertRaises(expected_exception, deployable.save)
|
||||
deployable.obj_reset_changes(fields=[field])
|
||||
_test()
|
|
@ -0,0 +1,228 @@
|
|||
# Copyright 2018 Huawei Technologies Co.,LTD.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import datetime
|
||||
import inspect
|
||||
import os
|
||||
import pprint
|
||||
|
||||
import fixtures
|
||||
import mock
|
||||
from oslo_log import log
|
||||
from oslo_utils import timeutils
|
||||
from oslo_versionedobjects import base as ovo_base
|
||||
from oslo_versionedobjects import exception as ovo_exc
|
||||
from oslo_versionedobjects import fixture
|
||||
import six
|
||||
|
||||
from oslo_context import context
|
||||
|
||||
from cyborg.common import exception
|
||||
from cyborg import objects
|
||||
from cyborg.objects import base
|
||||
from cyborg.objects import fields
|
||||
from cyborg import tests as test
|
||||
|
||||
|
||||
LOG = log.getLogger(__name__)
|
||||
|
||||
|
||||
class MyOwnedObject(base.CyborgPersistentObject, base.CyborgObject):
|
||||
VERSION = '1.0'
|
||||
fields = {'baz': fields.IntegerField()}
|
||||
|
||||
|
||||
class MyObj(base.CyborgPersistentObject, base.CyborgObject,
|
||||
base.CyborgObjectDictCompat):
|
||||
VERSION = '1.6'
|
||||
fields = {'foo': fields.IntegerField(default=1),
|
||||
'bar': fields.StringField(),
|
||||
'missing': fields.StringField(),
|
||||
'readonly': fields.IntegerField(read_only=True),
|
||||
'rel_object': fields.ObjectField('MyOwnedObject', nullable=True),
|
||||
'rel_objects': fields.ListOfObjectsField('MyOwnedObject',
|
||||
nullable=True),
|
||||
'mutable_default': fields.ListOfStringsField(default=[]),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _from_db_object(context, obj, db_obj):
|
||||
self = MyObj()
|
||||
self.foo = db_obj['foo']
|
||||
self.bar = db_obj['bar']
|
||||
self.missing = db_obj['missing']
|
||||
self.readonly = 1
|
||||
self._context = context
|
||||
return self
|
||||
|
||||
def obj_load_attr(self, attrname):
|
||||
setattr(self, attrname, 'loaded!')
|
||||
|
||||
def query(cls, context):
|
||||
obj = cls(context=context, foo=1, bar='bar')
|
||||
obj.obj_reset_changes()
|
||||
return obj
|
||||
|
||||
def marco(self):
|
||||
return 'polo'
|
||||
|
||||
def _update_test(self):
|
||||
self.bar = 'updated'
|
||||
|
||||
def save(self):
|
||||
self.obj_reset_changes()
|
||||
|
||||
def refresh(self):
|
||||
self.foo = 321
|
||||
self.bar = 'refreshed'
|
||||
self.obj_reset_changes()
|
||||
|
||||
def modify_save_modify(self):
|
||||
self.bar = 'meow'
|
||||
self.save()
|
||||
self.foo = 42
|
||||
self.rel_object = MyOwnedObject(baz=42)
|
||||
|
||||
def obj_make_compatible(self, primitive, target_version):
|
||||
super(MyObj, self).obj_make_compatible(primitive, target_version)
|
||||
# NOTE(danms): Simulate an older version that had a different
|
||||
# format for the 'bar' attribute
|
||||
if target_version == '1.1' and 'bar' in primitive:
|
||||
primitive['bar'] = 'old%s' % primitive['bar']
|
||||
|
||||
|
||||
class RandomMixInWithNoFields(object):
|
||||
"""Used to test object inheritance using a mixin that has no fields."""
|
||||
pass
|
||||
|
||||
|
||||
@base.CyborgObjectRegistry.register_if(False)
|
||||
class TestSubclassedObject(RandomMixInWithNoFields, MyObj):
|
||||
fields = {'new_field': fields.StringField()}
|
||||
|
||||
|
||||
class TestObjToPrimitive(test.base.TestCase):
|
||||
|
||||
def test_obj_to_primitive_list(self):
|
||||
@base.CyborgObjectRegistry.register_if(False)
|
||||
class MyObjElement(base.CyborgObject):
|
||||
fields = {'foo': fields.IntegerField()}
|
||||
|
||||
def __init__(self, foo):
|
||||
super(MyObjElement, self).__init__()
|
||||
self.foo = foo
|
||||
|
||||
@base.CyborgObjectRegistry.register_if(False)
|
||||
class MyList(base.ObjectListBase, base.CyborgObject):
|
||||
fields = {'objects': fields.ListOfObjectsField('MyObjElement')}
|
||||
|
||||
mylist = MyList()
|
||||
mylist.objects = [MyObjElement(1), MyObjElement(2), MyObjElement(3)]
|
||||
self.assertEqual([1, 2, 3],
|
||||
[x['foo'] for x in base.obj_to_primitive(mylist)])
|
||||
|
||||
def test_obj_to_primitive_dict(self):
|
||||
base.CyborgObjectRegistry.register(MyObj)
|
||||
myobj = MyObj(foo=1, bar='foo')
|
||||
self.assertEqual({'foo': 1, 'bar': 'foo'},
|
||||
base.obj_to_primitive(myobj))
|
||||
|
||||
def test_obj_to_primitive_recursive(self):
|
||||
base.CyborgObjectRegistry.register(MyObj)
|
||||
|
||||
class MyList(base.ObjectListBase, base.CyborgObject):
|
||||
fields = {'objects': fields.ListOfObjectsField('MyObj')}
|
||||
|
||||
mylist = MyList(objects=[MyObj(), MyObj()])
|
||||
for i, value in enumerate(mylist):
|
||||
value.foo = i
|
||||
self.assertEqual([{'foo': 0}, {'foo': 1}],
|
||||
base.obj_to_primitive(mylist))
|
||||
|
||||
def test_obj_to_primitive_with_ip_addr(self):
|
||||
@base.CyborgObjectRegistry.register_if(False)
|
||||
class TestObject(base.CyborgObject):
|
||||
fields = {'addr': fields.IPAddressField(),
|
||||
'cidr': fields.IPNetworkField()}
|
||||
|
||||
obj = TestObject(addr='1.2.3.4', cidr='1.1.1.1/16')
|
||||
self.assertEqual({'addr': '1.2.3.4', 'cidr': '1.1.1.1/16'},
|
||||
base.obj_to_primitive(obj))
|
||||
|
||||
|
||||
def compare_obj(test, obj, db_obj, subs=None, allow_missing=None,
|
||||
comparators=None):
|
||||
"""Compare a CyborgObject and a dict-like database object.
|
||||
|
||||
This automatically converts TZ-aware datetimes and iterates over
|
||||
the fields of the object.
|
||||
|
||||
:param:test: The TestCase doing the comparison
|
||||
:param:obj: The CyborgObject to examine
|
||||
:param:db_obj: The dict-like database object to use as reference
|
||||
:param:subs: A dict of objkey=dbkey field substitutions
|
||||
:param:allow_missing: A list of fields that may not be in db_obj
|
||||
:param:comparators: Map of comparator functions to use for certain fields
|
||||
"""
|
||||
|
||||
if subs is None:
|
||||
subs = {}
|
||||
if allow_missing is None:
|
||||
allow_missing = []
|
||||
if comparators is None:
|
||||
comparators = {}
|
||||
|
||||
for key in obj.fields:
|
||||
if key in allow_missing and not obj.obj_attr_is_set(key):
|
||||
continue
|
||||
obj_val = getattr(obj, key)
|
||||
db_key = subs.get(key, key)
|
||||
db_val = db_obj[db_key]
|
||||
if isinstance(obj_val, datetime.datetime):
|
||||
obj_val = obj_val.replace(tzinfo=None)
|
||||
|
||||
if key in comparators:
|
||||
comparator = comparators[key]
|
||||
comparator(db_val, obj_val)
|
||||
else:
|
||||
test.assertEqual(db_val, obj_val)
|
||||
|
||||
|
||||
class _BaseTestCase(test.base.TestCase):
|
||||
def setUp(self):
|
||||
super(_BaseTestCase, self).setUp()
|
||||
self.user_id = 'fake-user'
|
||||
self.project_id = 'fake-project'
|
||||
self.context = context.RequestContext(self.user_id, self.project_id)
|
||||
|
||||
base.CyborgObjectRegistry.register(MyObj)
|
||||
base.CyborgObjectRegistry.register(MyOwnedObject)
|
||||
|
||||
def compare_obj(self, obj, db_obj, subs=None, allow_missing=None,
|
||||
comparators=None):
|
||||
compare_obj(self, obj, db_obj, subs=subs, allow_missing=allow_missing,
|
||||
comparators=comparators)
|
||||
|
||||
def str_comparator(self, expected, obj_val):
|
||||
"""Compare an object field to a string in the db by performing
|
||||
a simple coercion on the object field value.
|
||||
"""
|
||||
self.assertEqual(expected, str(obj_val))
|
||||
|
||||
|
||||
class _LocalTest(_BaseTestCase):
|
||||
def setUp(self):
|
||||
super(_LocalTest, self).setUp()
|
Loading…
Reference in New Issue