Modified the Deployable Object

based on the new DB scheme

Change-Id: I35e09416f2f6b267b029376671cb6c255d40c737
This commit is contained in:
Li Liu 2019-02-21 23:45:05 -05:00
parent cef089f12a
commit e6028de00f
10 changed files with 602 additions and 108 deletions

View File

@ -103,15 +103,6 @@ class ConductorManager(object):
"""
return objects.Deployable.get(context, uuid)
def deployable_get_by_host(self, context, host):
"""Retrieve a deployable.
:param context: request context.
:param host: host on which the deployable is located.
:returns: requested deployable object.
"""
return objects.Deployable.get_by_host(context, host)
def deployable_list(self, context):
"""Retrieve a list of deployables.

View File

@ -136,16 +136,6 @@ class ConductorAPI(object):
cctxt = self.client.prepare(topic=self.topic)
return cctxt.call(context, 'deployable_get', uuid=uuid)
def deployable_get_by_host(self, context, host):
"""Signal to conductor service to get a deployable by host.
:param context: request context.
:param host: host on which the deployable is located.
:returns: requested deployable object.
"""
cctxt = self.client.prepare(topic=self.topic)
return cctxt.call(context, 'deployable_get_by_host', host=host)
def deployable_list(self, context):
"""Signal to conductor service to get a list of deployables.

View File

@ -110,10 +110,6 @@ class Connection(object):
def deployable_get(self, context, uuid):
"""Get requested deployable."""
@abc.abstractmethod
def deployable_get_by_host(self, context, host):
"""Get requested deployable by host."""
@abc.abstractmethod
def deployable_list(self, context):
"""Get requested list of deployables."""

View File

@ -532,12 +532,6 @@ class Connection(api.Connection):
except NoResultFound:
raise exception.DeployableNotFound(uuid=uuid)
def deployable_get_by_host(self, context, host):
query = model_query(
context,
models.Deployable).filter_by(host=host)
return query.all()
def deployable_list(self, context):
query = model_query(context, models.Deployable)
return query.all()
@ -572,7 +566,7 @@ class Connection(api.Connection):
with _session_for_write():
query = model_query(context, models.Deployable)
query = add_identity_filter(query, uuid)
query.update({'root_uuid': None})
query.update({'root_id': None})
count = query.delete()
if count != 1:
raise exception.DeployableNotFound(uuid=uuid)
@ -580,7 +574,7 @@ class Connection(api.Connection):
def deployable_get_by_filters_with_attributes(self, context,
filters):
exact_match_filter_names = ['uuid', 'name',
exact_match_filter_names = ['id', 'uuid', 'name',
'parent_id', 'root_id',
'num_accelerators', 'device_id']
attribute_filters = {}
@ -680,7 +674,7 @@ class Connection(api.Connection):
query_prefix = model_query(context, models.Deployable)
filters = copy.deepcopy(filters)
exact_match_filter_names = ['uuid', 'name',
exact_match_filter_names = ['id', 'uuid', 'name',
'parent_id', 'root_id',
'num_accelerators', 'device_id']

View File

@ -28,7 +28,7 @@ LOG = logging.getLogger(__name__)
@base.CyborgObjectRegistry.register
class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat):
# Version 1.0: Initial version
VERSION = '1.0'
VERSION = '2.0'
dbapi = dbapi.get_instance()
attributes_list = []
@ -36,35 +36,21 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat):
fields = {
'id': object_fields.IntegerField(nullable=False),
'uuid': object_fields.UUIDField(nullable=False),
'parent_id': object_fields.IntegerField(nullable=True),
# parent_id refers to the id of the deployable's parent node
'root_id': object_fields.IntegerField(nullable=True),
# root_id refers to the id of the deployable's root to for nested tree
'name': object_fields.StringField(nullable=False),
'parent_uuid': object_fields.UUIDField(nullable=True),
# parent_uuid refers to the id of the VF's parent node
'root_uuid': object_fields.UUIDField(nullable=True),
# root_uuid refers to the id of the VF's root which has to be a PF
'address': object_fields.StringField(nullable=False),
# if interface_type is pci(/mdev), address is the pci_address(/path)
'host': object_fields.StringField(nullable=False),
'board': object_fields.StringField(nullable=False),
# board refers to a specific acc board type, e.g P100 GPU card
'vendor': object_fields.StringField(nullable=False),
'version': object_fields.StringField(nullable=False),
'type': object_fields.StringField(nullable=False),
# type of deployable, e.g, pf/vf/*f
'interface_type': object_fields.StringField(nullable=False),
# interface to hypervisor(libvirt), e.g, pci/mdev...
'assignable': object_fields.BooleanField(nullable=False),
# identify if an accelerator is in use
'instance_uuid': object_fields.UUIDField(nullable=True),
# The id of the virtualized accelerator instance
'availability': object_fields.StringField(nullable=False),
# identify the state of acc, e.g released/claimed/...
'accelerator_id': object_fields.IntegerField(nullable=False)
# Foreign key constrain to reference accelerator table
# name of the deplyable
'num_accelerators': object_fields.IntegerField(nullable=False),
# number of accelerators spawned by this deplyable
'device_id': object_fields.IntegerField(nullable=False)
# Foreign key constrain to reference device table
}
def _get_parent_root_uuid(self):
obj_dep = Deployable.get(None, self.parent_uuid)
return obj_dep.root_uuid
def _get_parent_root_uuid(self, context):
obj_dep = Deployable.get_by_id(context, self.parent_id)
return obj_dep.root_id
def create(self, context):
"""Create a Deployable record in the DB."""
@ -72,40 +58,40 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat):
raise exception.ObjectActionError(action='create',
reason='uuid is required')
if not hasattr(self, 'parent_uuid') or self.parent_uuid is None:
self.root_uuid = self.uuid
if not hasattr(self, 'parent_id') or self.parent_id is None:
self.root_id = self.id
else:
self.root_uuid = self._get_parent_root_uuid()
self.root_uuid = self._get_parent_root_uuid(context)
values = self.obj_get_changes()
db_dep = self.dbapi.deployable_create(context, values)
self._from_db_object(self, db_dep)
self.obj_reset_changes()
del self.attributes_list[:]
@classmethod
def get(cls, context, uuid):
def get(cls, context, uuid, with_attribute_list=True):
"""Find a DB Deployable and return an Obj Deployable."""
db_dep = cls.dbapi.deployable_get(context, uuid)
obj_dep = cls._from_db_object(cls(context), db_dep)
# retrieve all the attrobutes for this deployable
query = {"deployable_id": obj_dep.id}
attr_get_list = Attribute.get_by_filter(context,
query)
obj_dep.attributes_list = attr_get_list
if with_attribute_list:
query = {"deployable_id": obj_dep.id}
attr_get_list = Attribute.get_by_filter(context,
query)
obj_dep.attributes_list = attr_get_list
obj_dep.obj_reset_changes()
return obj_dep
@classmethod
def get_by_host(cls, context, host):
"""Get a Deployable by host."""
db_deps = cls.dbapi.deployable_get_by_host(context, host)
obj_dpl_list = cls._from_db_object_list(db_deps, context)
for obj_dpl in obj_dpl_list:
query = {"deployable_id": obj_dpl.id}
attr_get_list = Attribute.get_by_filter(context,
query)
obj_dpl.attributes_list = attr_get_list
return obj_dpl_list
def get_by_id(cls, context, id):
"""Find a DB Deployable and return an Obj Deployable."""
dpl_query = {"id": id}
obj_dep = Deployable.get_by_filter(context, id)[0]
obj_dep.obj_reset_changes()
return obj_dep
@classmethod
def list(cls, context, filters={}):
@ -134,6 +120,7 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat):
"""Update a Deployable record in the DB."""
updates = self.obj_get_changes()
db_dep = self.dbapi.deployable_update(context, self.uuid, updates)
self.obj_reset_changes()
self._from_db_object(self, db_dep)
query = {"deployable_id": self.id}
attr_get_list = Attribute.get_by_filter(context,
@ -202,15 +189,16 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat):
return obj_dpl_list
@classmethod
def _from_db_object(cls, obj, db_obj):
@staticmethod
def _from_db_object(obj, db_obj):
"""Converts a deployable to a formal object.
:param obj: An object of the class.
:param db_obj: A DB model of the object
:return: The object of the class with the database entity added
"""
obj = base.CyborgObject._from_db_object(obj, db_obj)
for field in obj.fields:
obj[field] = db_obj[field]
obj.attributes_list = []
return obj

View File

@ -33,21 +33,15 @@ def get_test_accelerator(**kw):
def get_test_deployable(**kw):
return {
'uuid': kw.get('uuid', '10efe63d-dfea-4a37-ad94-4116fba5098'),
'deleted': False,
'id': kw.get('id', 1),
'uuid': kw.get('uuid', '10efe63d-dfea-4a37-ad94-4116fba5011'),
'parent_id': kw.get('parent_id', None),
'root_id': kw.get('root_id', 0),
'name': kw.get('name', 'name'),
'parent_uuid': kw.get('parent_uuid', None),
'address': kw.get('address', '00:7f:0b.2'),
'host': kw.get('host', 'host'),
'board': kw.get('board', 'KU115'),
'vendor': kw.get('vendor', 'Xilinx'),
'version': kw.get('version', '1.0'),
'type': kw.get('type', '1.0'),
'interface_type': 'pci',
'assignable': True,
'instance_uuid': None,
'availability': 'Available',
'accelerator_id': kw.get('accelerator_id', 1),
'num_accelerators': kw.get('num_accelerators', 4),
'device_id': kw.get('device_id', 0),
'created_at': kw.get('created_at', None),
'updated_at': kw.get('updated_at', None)
}

View File

@ -25,22 +25,12 @@ def fake_db_deployable(**updates):
root_uuid = uuidutils.generate_uuid()
db_deployable = {
'id': 1,
'deleted': False,
'uuid': root_uuid,
'name': 'dp_name',
'parent_uuid': None,
'root_uuid': root_uuid,
'address': '00:7f:0b.2',
'host': 'host_name',
'board': 'KU115',
'vendor': 'Xilinx',
'version': '1.0',
'type': 'pf',
'interface_type': 'pci',
'assignable': True,
'instance_uuid': None,
'availability': 'Available',
'accelerator_id': 1
'parent_id': None,
'root_id': 1,
'num_accelerators': 4,
'device_id': 0
}
for name, field in objects.Deployable.fields.items():

View File

@ -0,0 +1,59 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
from oslo_serialization import jsonutils
from oslo_utils import uuidutils
from cyborg import objects
from cyborg.objects import fields
def fake_db_device(**updates):
root_uuid = uuidutils.generate_uuid()
db_device = {
'id': 1,
'uuid': root_uuid,
'type': 'FPGA',
'vendor': "vendor",
'model': "model",
'std_board_info': "std_board_info",
'vendor_board_info': "vendor_board_info",
'hostname': "hostname"
}
for name, field in objects.Device.fields.items():
if name in db_device:
continue
if field.nullable:
db_device[name] = None
elif field.default != fields.UnspecifiedDefault:
db_device[name] = field.default
else:
raise Exception('fake_db_device needs help with %s' % name)
if updates:
db_device.update(updates)
return db_device
def fake_device_obj(context, obj_device_class=None, **updates):
if obj_device_class is None:
obj_device_class = objects.Device
device = obj_device_class._from_db_object(obj_device_class(),
fake_db_device(**updates))
device.obj_reset_changes()
return device

View File

@ -0,0 +1,265 @@
# Copyright 2019 Huawei Technologies Co.,LTD.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from testtools.matchers import HasLength
from cyborg import objects
from cyborg.tests.unit.db import base
from cyborg.tests.unit.db import utils
from cyborg.tests.unit import fake_device
from cyborg.tests.unit import fake_deployable
from cyborg.tests.unit import fake_attribute
from cyborg.tests.unit.objects import test_objects
from cyborg.tests.unit.db.base import DbTestCase
from cyborg.common import exception
class _TestDeployableObject(DbTestCase):
@property
def fake_device(self):
db_device = fake_device.fake_db_device(id=1)
return db_device
@property
def fake_deployable(self):
db_deploy = fake_deployable.fake_db_deployable(id=1)
return db_deploy
@property
def fake_deployable2(self):
db_deploy = fake_deployable.fake_db_deployable(id=2)
return db_deploy
@property
def fake_attribute(self):
db_attr = fake_attribute.fake_db_attribute(id=2)
return db_attr
@property
def fake_attribute2(self):
db_attr = fake_attribute.fake_db_attribute(id=3)
return db_attr
@property
def fake_attribute3(self):
db_attr = fake_attribute.fake_db_attribute(id=4)
return db_attr
def test_create(self):
db_device = self.fake_device
device = objects.Device(context=self.context,
**db_device)
device.create(self.context)
device_get = objects.Device.get(self.context, device.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**db_dpl)
dpl.device_id = device_get.id
dpl.create(self.context)
self.assertEqual(db_dpl['uuid'], dpl.uuid)
def test_get(self):
db_device = self.fake_device
device = objects.Device(context=self.context,
**db_device)
device.create(self.context)
device_get = objects.Device.get(self.context, device.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**db_dpl)
dpl.device_id = device_get.id
dpl.create(self.context)
dpl_get = objects.Deployable.get(self.context, dpl.uuid)
self.assertEqual(dpl_get.uuid, dpl.uuid)
def test_get_by_filter(self):
db_device = self.fake_device
device = objects.Device(context=self.context,
**db_device)
device.create(self.context)
device_get = objects.Device.get(self.context, device.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**db_dpl)
dpl.device_id = device_get.id
dpl.create(self.context)
query = {"uuid": dpl['uuid']}
dpl_get_list = objects.Deployable.get_by_filter(self.context, query)
self.assertEqual(dpl_get_list[0].uuid, dpl.uuid)
def test_save(self):
db_device = self.fake_device
device = objects.Device(context=self.context,
**db_device)
device.create(self.context)
device_get = objects.Device.get(self.context, device.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**db_dpl)
dpl.device_id = device_get.id
dpl.create(self.context)
dpl.num_accelerators = 8
dpl.save(self.context)
dpl_get = objects.Deployable.get(self.context, dpl.uuid)
self.assertEqual(dpl_get.num_accelerators, 8)
def test_destroy(self):
db_device = self.fake_device
device = objects.Device(context=self.context,
**db_device)
device.create(self.context)
device_get = objects.Device.get(self.context, device.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**db_dpl)
dpl.device_id = device_get.id
dpl.create(self.context)
self.assertEqual(db_dpl['uuid'], dpl.uuid)
dpl.destroy(self.context)
self.assertRaises(exception.DeployableNotFound,
objects.Deployable.get, self.context,
dpl.uuid)
def test_add_attribute(self):
db_device = self.fake_device
device = objects.Device(context=self.context,
**db_device)
device.create(self.context)
device_get = objects.Device.get(self.context, device.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**db_dpl)
dpl.device_id = device_get.id
dpl.create(self.context)
dpl_get = objects.Deployable.get(self.context, dpl.uuid)
db_attr = self.fake_attribute
dpl.add_attribute(self.context, db_attr['key'], db_attr['value'])
dpl.save(self.context)
dpl_get = objects.Deployable.get(self.context, dpl.uuid)
self.assertEqual(len(dpl_get.attributes_list), 1)
def test_delete_attribute(self):
db_device = self.fake_device
device = objects.Device(context=self.context,
**db_device)
device.create(self.context)
device_get = objects.Device.get(self.context, device.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**db_dpl)
dpl.device_id = device_get.id
dpl.create(self.context)
dpl_get = objects.Deployable.get(self.context, dpl.uuid)
db_attr = self.fake_attribute
dpl_get.add_attribute(self.context, db_attr['key'], db_attr['value'])
dpl_get.save(self.context)
dpl_get = objects.Deployable.get(self.context, dpl_get.uuid)
self.assertEqual(len(dpl_get.attributes_list), 1)
dpl_get.delete_attribute(self.context, dpl_get.attributes_list[0])
self.assertEqual(len(dpl_get.attributes_list), 0)
def test_get_by_filter_with_attributes(self):
db_device = self.fake_device
device = objects.Device(context=self.context,
**db_device)
device.create(self.context)
device_get = objects.Device.get(self.context, device.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**db_dpl)
dpl.device_id = device_get.id
dpl.create(self.context)
dpl_get = objects.Deployable.get(self.context, dpl.uuid)
db_dpl2 = self.fake_deployable2
dpl2 = objects.Deployable(context=self.context,
**db_dpl2)
dpl2.device_id = device_get.id
dpl2.create(self.context)
dpl2_get = objects.Deployable.get(self.context, dpl2.uuid)
db_attr = self.fake_attribute
db_attr2 = self.fake_attribute2
db_attr3 = self.fake_attribute3
dpl.add_attribute(self.context, 'attr_key', 'attr_val')
dpl.save(self.context)
dpl2.add_attribute(self.context, 'test_key', 'test_val')
dpl2.add_attribute(self.context, 'test_key3', 'test_val3')
dpl2.save(self.context)
query = {"attr_key": "attr_val"}
dpl_get_list = objects.Deployable.get_by_filter(self.context, query)
self.assertEqual(len(dpl_get_list), 1)
self.assertEqual(dpl_get_list[0].uuid, dpl.uuid)
query = {"test_key": "test_val"}
dpl_get_list = objects.Deployable.get_by_filter(self.context, query)
self.assertEqual(len(dpl_get_list), 1)
self.assertEqual(dpl_get_list[0].uuid, dpl2.uuid)
query = {"test_key": "test_val", "test_key3": "test_val3"}
dpl_get_list = objects.Deployable.get_by_filter(self.context, query)
self.assertEqual(len(dpl_get_list), 1)
self.assertEqual(dpl_get_list[0].uuid, dpl2.uuid)
query = {"num_accelerators": 4, "test_key3": "test_val3"}
dpl_get_list = objects.Deployable.get_by_filter(self.context, query)
self.assertEqual(len(dpl_get_list), 1)
self.assertEqual(dpl_get_list[0].uuid, dpl2.uuid)
class TestDeployableObject(test_objects._LocalTest,
_TestDeployableObject):
def _test_save_objectfield_fk_constraint_fails(self, foreign_key,
expected_exception):
error = db_exc.DBReferenceError('table', 'constraint', foreign_key,
'key_table')
# Prevent lazy-loading any fields, results in InstanceNotFound
deployable = fake_deployable.fake_deployable_obj(self.context)
fields_with_save_methods = [field for field in deployable.fields
if hasattr(deployable, '_save_%s' % field)]
for field in fields_with_save_methods:
@mock.patch.object(deployable, '_save_%s' % field)
@mock.patch.object(deployable, 'obj_attr_is_set')
def _test(mock_is_set, mock_save_field):
mock_is_set.return_value = True
mock_save_field.side_effect = error
deployable.obj_reset_changes(fields=[field])
deployable._changed_fields.add(field)
self.assertRaises(expected_exception, deployable.save)
deployable.obj_reset_changes(fields=[field])
_test()

View File

@ -0,0 +1,227 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import contextlib
import copy
import datetime
import inspect
import os
import fixtures
import mock
from oslo_log import log
from oslo_utils import timeutils
from oslo_versionedobjects import base as ovo_base
from oslo_versionedobjects import exception as ovo_exc
from oslo_versionedobjects import fixture
import six
from oslo_context import context
from cyborg.common import exception
from cyborg import objects
from cyborg.objects import base
from cyborg.objects import fields
from cyborg import tests as test
LOG = log.getLogger(__name__)
class MyOwnedObject(base.CyborgPersistentObject, base.CyborgObject):
VERSION = '1.0'
fields = {'baz': fields.IntegerField()}
class MyObj(base.CyborgPersistentObject, base.CyborgObject,
base.CyborgObjectDictCompat):
VERSION = '1.6'
fields = {'foo': fields.IntegerField(default=1),
'bar': fields.StringField(),
'missing': fields.StringField(),
'readonly': fields.IntegerField(read_only=True),
'rel_object': fields.ObjectField('MyOwnedObject', nullable=True),
'rel_objects': fields.ListOfObjectsField('MyOwnedObject',
nullable=True),
'mutable_default': fields.ListOfStringsField(default=[]),
}
@staticmethod
def _from_db_object(context, obj, db_obj):
self = MyObj()
self.foo = db_obj['foo']
self.bar = db_obj['bar']
self.missing = db_obj['missing']
self.readonly = 1
self._context = context
return self
def obj_load_attr(self, attrname):
setattr(self, attrname, 'loaded!')
def query(cls, context):
obj = cls(context=context, foo=1, bar='bar')
obj.obj_reset_changes()
return obj
def marco(self):
return 'polo'
def _update_test(self):
self.bar = 'updated'
def save(self):
self.obj_reset_changes()
def refresh(self):
self.foo = 321
self.bar = 'refreshed'
self.obj_reset_changes()
def modify_save_modify(self):
self.bar = 'meow'
self.save()
self.foo = 42
self.rel_object = MyOwnedObject(baz=42)
def obj_make_compatible(self, primitive, target_version):
super(MyObj, self).obj_make_compatible(primitive, target_version)
# NOTE(danms): Simulate an older version that had a different
# format for the 'bar' attribute
if target_version == '1.1' and 'bar' in primitive:
primitive['bar'] = 'old%s' % primitive['bar']
class RandomMixInWithNoFields(object):
"""Used to test object inheritance using a mixin that has no fields."""
pass
@base.CyborgObjectRegistry.register_if(False)
class TestSubclassedObject(RandomMixInWithNoFields, MyObj):
fields = {'new_field': fields.StringField()}
class TestObjToPrimitive(test.base.TestCase):
def test_obj_to_primitive_list(self):
@base.CyborgObjectRegistry.register_if(False)
class MyObjElement(base.CyborgObject):
fields = {'foo': fields.IntegerField()}
def __init__(self, foo):
super(MyObjElement, self).__init__()
self.foo = foo
@base.CyborgObjectRegistry.register_if(False)
class MyList(base.ObjectListBase, base.CyborgObject):
fields = {'objects': fields.ListOfObjectsField('MyObjElement')}
mylist = MyList()
mylist.objects = [MyObjElement(1), MyObjElement(2), MyObjElement(3)]
self.assertEqual([1, 2, 3],
[x['foo'] for x in base.obj_to_primitive(mylist)])
def test_obj_to_primitive_dict(self):
base.CyborgObjectRegistry.register(MyObj)
myobj = MyObj(foo=1, bar='foo')
self.assertEqual({'foo': 1, 'bar': 'foo'},
base.obj_to_primitive(myobj))
def test_obj_to_primitive_recursive(self):
base.CyborgObjectRegistry.register(MyObj)
class MyList(base.ObjectListBase, base.CyborgObject):
fields = {'objects': fields.ListOfObjectsField('MyObj')}
mylist = MyList(objects=[MyObj(), MyObj()])
for i, value in enumerate(mylist):
value.foo = i
self.assertEqual([{'foo': 0}, {'foo': 1}],
base.obj_to_primitive(mylist))
def test_obj_to_primitive_with_ip_addr(self):
@base.CyborgObjectRegistry.register_if(False)
class TestObject(base.CyborgObject):
fields = {'addr': fields.IPAddressField(),
'cidr': fields.IPNetworkField()}
obj = TestObject(addr='1.2.3.4', cidr='1.1.1.1/16')
self.assertEqual({'addr': '1.2.3.4', 'cidr': '1.1.1.1/16'},
base.obj_to_primitive(obj))
def compare_obj(test, obj, db_obj, subs=None, allow_missing=None,
comparators=None):
"""Compare a CyborgObject and a dict-like database object.
This automatically converts TZ-aware datetimes and iterates over
the fields of the object.
:param:test: The TestCase doing the comparison
:param:obj: The CyborgObject to examine
:param:db_obj: The dict-like database object to use as reference
:param:subs: A dict of objkey=dbkey field substitutions
:param:allow_missing: A list of fields that may not be in db_obj
:param:comparators: Map of comparator functions to use for certain fields
"""
if subs is None:
subs = {}
if allow_missing is None:
allow_missing = []
if comparators is None:
comparators = {}
for key in obj.fields:
if key in allow_missing and not obj.obj_attr_is_set(key):
continue
obj_val = getattr(obj, key)
db_key = subs.get(key, key)
db_val = db_obj[db_key]
if isinstance(obj_val, datetime.datetime):
obj_val = obj_val.replace(tzinfo=None)
if key in comparators:
comparator = comparators[key]
comparator(db_val, obj_val)
else:
test.assertEqual(db_val, obj_val)
class _BaseTestCase(test.base.TestCase):
def setUp(self):
super(_BaseTestCase, self).setUp()
self.user_id = 'fake-user'
self.project_id = 'fake-project'
self.context = context.RequestContext(self.user_id, self.project_id)
base.CyborgObjectRegistry.register(MyObj)
base.CyborgObjectRegistry.register(MyOwnedObject)
def compare_obj(self, obj, db_obj, subs=None, allow_missing=None,
comparators=None):
compare_obj(self, obj, db_obj, subs=subs, allow_missing=allow_missing,
comparators=comparators)
def str_comparator(self, expected, obj_val):
"""Compare an object field to a string in the db by performing
a simple coercion on the object field value.
"""
self.assertEqual(expected, str(obj_val))
class _LocalTest(_BaseTestCase):
def setUp(self):
super(_LocalTest, self).setUp()