From e6028de00f7466c82a28329e4302afe46c682386 Mon Sep 17 00:00:00 2001 From: Li Liu Date: Thu, 21 Feb 2019 23:45:05 -0500 Subject: [PATCH] Modified the Deployable Object based on the new DB scheme Change-Id: I35e09416f2f6b267b029376671cb6c255d40c737 --- cyborg/conductor/manager.py | 9 - cyborg/conductor/rpcapi.py | 10 - cyborg/db/api.py | 4 - cyborg/db/sqlalchemy/api.py | 12 +- cyborg/objects/deployable.py | 84 +++--- cyborg/tests/unit/db/utils.py | 22 +- cyborg/tests/unit/fake_deployable.py | 18 +- cyborg/tests/unit/fake_device.py | 59 +++++ cyborg/tests/unit/objects/test_deployable.py | 265 +++++++++++++++++++ cyborg/tests/unit/objects/test_objects.py | 227 ++++++++++++++++ 10 files changed, 602 insertions(+), 108 deletions(-) create mode 100644 cyborg/tests/unit/fake_device.py create mode 100644 cyborg/tests/unit/objects/test_deployable.py create mode 100644 cyborg/tests/unit/objects/test_objects.py diff --git a/cyborg/conductor/manager.py b/cyborg/conductor/manager.py index 47172544..5606732c 100644 --- a/cyborg/conductor/manager.py +++ b/cyborg/conductor/manager.py @@ -103,15 +103,6 @@ class ConductorManager(object): """ return objects.Deployable.get(context, uuid) - def deployable_get_by_host(self, context, host): - """Retrieve a deployable. - - :param context: request context. - :param host: host on which the deployable is located. - :returns: requested deployable object. - """ - return objects.Deployable.get_by_host(context, host) - def deployable_list(self, context): """Retrieve a list of deployables. diff --git a/cyborg/conductor/rpcapi.py b/cyborg/conductor/rpcapi.py index bc663f53..866a9185 100644 --- a/cyborg/conductor/rpcapi.py +++ b/cyborg/conductor/rpcapi.py @@ -136,16 +136,6 @@ class ConductorAPI(object): cctxt = self.client.prepare(topic=self.topic) return cctxt.call(context, 'deployable_get', uuid=uuid) - def deployable_get_by_host(self, context, host): - """Signal to conductor service to get a deployable by host. - - :param context: request context. - :param host: host on which the deployable is located. - :returns: requested deployable object. - """ - cctxt = self.client.prepare(topic=self.topic) - return cctxt.call(context, 'deployable_get_by_host', host=host) - def deployable_list(self, context): """Signal to conductor service to get a list of deployables. diff --git a/cyborg/db/api.py b/cyborg/db/api.py index 4fb4e8ce..c6622a46 100644 --- a/cyborg/db/api.py +++ b/cyborg/db/api.py @@ -110,10 +110,6 @@ class Connection(object): def deployable_get(self, context, uuid): """Get requested deployable.""" - @abc.abstractmethod - def deployable_get_by_host(self, context, host): - """Get requested deployable by host.""" - @abc.abstractmethod def deployable_list(self, context): """Get requested list of deployables.""" diff --git a/cyborg/db/sqlalchemy/api.py b/cyborg/db/sqlalchemy/api.py index 9db96a7a..6075304b 100644 --- a/cyborg/db/sqlalchemy/api.py +++ b/cyborg/db/sqlalchemy/api.py @@ -532,12 +532,6 @@ class Connection(api.Connection): except NoResultFound: raise exception.DeployableNotFound(uuid=uuid) - def deployable_get_by_host(self, context, host): - query = model_query( - context, - models.Deployable).filter_by(host=host) - return query.all() - def deployable_list(self, context): query = model_query(context, models.Deployable) return query.all() @@ -572,7 +566,7 @@ class Connection(api.Connection): with _session_for_write(): query = model_query(context, models.Deployable) query = add_identity_filter(query, uuid) - query.update({'root_uuid': None}) + query.update({'root_id': None}) count = query.delete() if count != 1: raise exception.DeployableNotFound(uuid=uuid) @@ -580,7 +574,7 @@ class Connection(api.Connection): def deployable_get_by_filters_with_attributes(self, context, filters): - exact_match_filter_names = ['uuid', 'name', + exact_match_filter_names = ['id', 'uuid', 'name', 'parent_id', 'root_id', 'num_accelerators', 'device_id'] attribute_filters = {} @@ -680,7 +674,7 @@ class Connection(api.Connection): query_prefix = model_query(context, models.Deployable) filters = copy.deepcopy(filters) - exact_match_filter_names = ['uuid', 'name', + exact_match_filter_names = ['id', 'uuid', 'name', 'parent_id', 'root_id', 'num_accelerators', 'device_id'] diff --git a/cyborg/objects/deployable.py b/cyborg/objects/deployable.py index 3f3b3bbe..74bb236f 100644 --- a/cyborg/objects/deployable.py +++ b/cyborg/objects/deployable.py @@ -28,7 +28,7 @@ LOG = logging.getLogger(__name__) @base.CyborgObjectRegistry.register class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): # Version 1.0: Initial version - VERSION = '1.0' + VERSION = '2.0' dbapi = dbapi.get_instance() attributes_list = [] @@ -36,35 +36,21 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): fields = { 'id': object_fields.IntegerField(nullable=False), 'uuid': object_fields.UUIDField(nullable=False), + 'parent_id': object_fields.IntegerField(nullable=True), + # parent_id refers to the id of the deployable's parent node + 'root_id': object_fields.IntegerField(nullable=True), + # root_id refers to the id of the deployable's root to for nested tree 'name': object_fields.StringField(nullable=False), - 'parent_uuid': object_fields.UUIDField(nullable=True), - # parent_uuid refers to the id of the VF's parent node - 'root_uuid': object_fields.UUIDField(nullable=True), - # root_uuid refers to the id of the VF's root which has to be a PF - 'address': object_fields.StringField(nullable=False), - # if interface_type is pci(/mdev), address is the pci_address(/path) - 'host': object_fields.StringField(nullable=False), - 'board': object_fields.StringField(nullable=False), - # board refers to a specific acc board type, e.g P100 GPU card - 'vendor': object_fields.StringField(nullable=False), - 'version': object_fields.StringField(nullable=False), - 'type': object_fields.StringField(nullable=False), - # type of deployable, e.g, pf/vf/*f - 'interface_type': object_fields.StringField(nullable=False), - # interface to hypervisor(libvirt), e.g, pci/mdev... - 'assignable': object_fields.BooleanField(nullable=False), - # identify if an accelerator is in use - 'instance_uuid': object_fields.UUIDField(nullable=True), - # The id of the virtualized accelerator instance - 'availability': object_fields.StringField(nullable=False), - # identify the state of acc, e.g released/claimed/... - 'accelerator_id': object_fields.IntegerField(nullable=False) - # Foreign key constrain to reference accelerator table + # name of the deplyable + 'num_accelerators': object_fields.IntegerField(nullable=False), + # number of accelerators spawned by this deplyable + 'device_id': object_fields.IntegerField(nullable=False) + # Foreign key constrain to reference device table } - def _get_parent_root_uuid(self): - obj_dep = Deployable.get(None, self.parent_uuid) - return obj_dep.root_uuid + def _get_parent_root_uuid(self, context): + obj_dep = Deployable.get_by_id(context, self.parent_id) + return obj_dep.root_id def create(self, context): """Create a Deployable record in the DB.""" @@ -72,40 +58,40 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): raise exception.ObjectActionError(action='create', reason='uuid is required') - if not hasattr(self, 'parent_uuid') or self.parent_uuid is None: - self.root_uuid = self.uuid + if not hasattr(self, 'parent_id') or self.parent_id is None: + self.root_id = self.id else: - self.root_uuid = self._get_parent_root_uuid() + self.root_uuid = self._get_parent_root_uuid(context) values = self.obj_get_changes() db_dep = self.dbapi.deployable_create(context, values) self._from_db_object(self, db_dep) + self.obj_reset_changes() del self.attributes_list[:] @classmethod - def get(cls, context, uuid): + def get(cls, context, uuid, with_attribute_list=True): """Find a DB Deployable and return an Obj Deployable.""" db_dep = cls.dbapi.deployable_get(context, uuid) obj_dep = cls._from_db_object(cls(context), db_dep) # retrieve all the attrobutes for this deployable - query = {"deployable_id": obj_dep.id} - attr_get_list = Attribute.get_by_filter(context, - query) - obj_dep.attributes_list = attr_get_list + if with_attribute_list: + query = {"deployable_id": obj_dep.id} + attr_get_list = Attribute.get_by_filter(context, + query) + obj_dep.attributes_list = attr_get_list + + obj_dep.obj_reset_changes() return obj_dep @classmethod - def get_by_host(cls, context, host): - """Get a Deployable by host.""" - db_deps = cls.dbapi.deployable_get_by_host(context, host) - obj_dpl_list = cls._from_db_object_list(db_deps, context) - for obj_dpl in obj_dpl_list: - query = {"deployable_id": obj_dpl.id} - attr_get_list = Attribute.get_by_filter(context, - query) - obj_dpl.attributes_list = attr_get_list - return obj_dpl_list + def get_by_id(cls, context, id): + """Find a DB Deployable and return an Obj Deployable.""" + dpl_query = {"id": id} + obj_dep = Deployable.get_by_filter(context, id)[0] + obj_dep.obj_reset_changes() + return obj_dep @classmethod def list(cls, context, filters={}): @@ -134,6 +120,7 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): """Update a Deployable record in the DB.""" updates = self.obj_get_changes() db_dep = self.dbapi.deployable_update(context, self.uuid, updates) + self.obj_reset_changes() self._from_db_object(self, db_dep) query = {"deployable_id": self.id} attr_get_list = Attribute.get_by_filter(context, @@ -202,15 +189,16 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): return obj_dpl_list - @classmethod - def _from_db_object(cls, obj, db_obj): + @staticmethod + def _from_db_object(obj, db_obj): """Converts a deployable to a formal object. :param obj: An object of the class. :param db_obj: A DB model of the object :return: The object of the class with the database entity added """ - obj = base.CyborgObject._from_db_object(obj, db_obj) + for field in obj.fields: + obj[field] = db_obj[field] obj.attributes_list = [] return obj diff --git a/cyborg/tests/unit/db/utils.py b/cyborg/tests/unit/db/utils.py index 9eed2e8e..47d535aa 100644 --- a/cyborg/tests/unit/db/utils.py +++ b/cyborg/tests/unit/db/utils.py @@ -33,21 +33,15 @@ def get_test_accelerator(**kw): def get_test_deployable(**kw): return { - 'uuid': kw.get('uuid', '10efe63d-dfea-4a37-ad94-4116fba5098'), - 'deleted': False, + 'id': kw.get('id', 1), + 'uuid': kw.get('uuid', '10efe63d-dfea-4a37-ad94-4116fba5011'), + 'parent_id': kw.get('parent_id', None), + 'root_id': kw.get('root_id', 0), 'name': kw.get('name', 'name'), - 'parent_uuid': kw.get('parent_uuid', None), - 'address': kw.get('address', '00:7f:0b.2'), - 'host': kw.get('host', 'host'), - 'board': kw.get('board', 'KU115'), - 'vendor': kw.get('vendor', 'Xilinx'), - 'version': kw.get('version', '1.0'), - 'type': kw.get('type', '1.0'), - 'interface_type': 'pci', - 'assignable': True, - 'instance_uuid': None, - 'availability': 'Available', - 'accelerator_id': kw.get('accelerator_id', 1), + 'num_accelerators': kw.get('num_accelerators', 4), + 'device_id': kw.get('device_id', 0), + 'created_at': kw.get('created_at', None), + 'updated_at': kw.get('updated_at', None) } diff --git a/cyborg/tests/unit/fake_deployable.py b/cyborg/tests/unit/fake_deployable.py index 3352783a..b1ecb7f9 100644 --- a/cyborg/tests/unit/fake_deployable.py +++ b/cyborg/tests/unit/fake_deployable.py @@ -25,22 +25,12 @@ def fake_db_deployable(**updates): root_uuid = uuidutils.generate_uuid() db_deployable = { 'id': 1, - 'deleted': False, 'uuid': root_uuid, 'name': 'dp_name', - 'parent_uuid': None, - 'root_uuid': root_uuid, - 'address': '00:7f:0b.2', - 'host': 'host_name', - 'board': 'KU115', - 'vendor': 'Xilinx', - 'version': '1.0', - 'type': 'pf', - 'interface_type': 'pci', - 'assignable': True, - 'instance_uuid': None, - 'availability': 'Available', - 'accelerator_id': 1 + 'parent_id': None, + 'root_id': 1, + 'num_accelerators': 4, + 'device_id': 0 } for name, field in objects.Deployable.fields.items(): diff --git a/cyborg/tests/unit/fake_device.py b/cyborg/tests/unit/fake_device.py new file mode 100644 index 00000000..7a01d36b --- /dev/null +++ b/cyborg/tests/unit/fake_device.py @@ -0,0 +1,59 @@ +# Copyright 2018 Huawei Technologies Co.,LTD. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import datetime + +from oslo_serialization import jsonutils +from oslo_utils import uuidutils + +from cyborg import objects +from cyborg.objects import fields + + +def fake_db_device(**updates): + root_uuid = uuidutils.generate_uuid() + db_device = { + 'id': 1, + 'uuid': root_uuid, + 'type': 'FPGA', + 'vendor': "vendor", + 'model': "model", + 'std_board_info': "std_board_info", + 'vendor_board_info': "vendor_board_info", + 'hostname': "hostname" + } + + for name, field in objects.Device.fields.items(): + if name in db_device: + continue + if field.nullable: + db_device[name] = None + elif field.default != fields.UnspecifiedDefault: + db_device[name] = field.default + else: + raise Exception('fake_db_device needs help with %s' % name) + + if updates: + db_device.update(updates) + + return db_device + + +def fake_device_obj(context, obj_device_class=None, **updates): + if obj_device_class is None: + obj_device_class = objects.Device + device = obj_device_class._from_db_object(obj_device_class(), + fake_db_device(**updates)) + device.obj_reset_changes() + return device diff --git a/cyborg/tests/unit/objects/test_deployable.py b/cyborg/tests/unit/objects/test_deployable.py new file mode 100644 index 00000000..61d70aee --- /dev/null +++ b/cyborg/tests/unit/objects/test_deployable.py @@ -0,0 +1,265 @@ +# Copyright 2019 Huawei Technologies Co.,LTD. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock + +from testtools.matchers import HasLength +from cyborg import objects +from cyborg.tests.unit.db import base +from cyborg.tests.unit.db import utils +from cyborg.tests.unit import fake_device +from cyborg.tests.unit import fake_deployable +from cyborg.tests.unit import fake_attribute +from cyborg.tests.unit.objects import test_objects +from cyborg.tests.unit.db.base import DbTestCase +from cyborg.common import exception + + +class _TestDeployableObject(DbTestCase): + + @property + def fake_device(self): + db_device = fake_device.fake_db_device(id=1) + return db_device + + @property + def fake_deployable(self): + db_deploy = fake_deployable.fake_db_deployable(id=1) + return db_deploy + + @property + def fake_deployable2(self): + db_deploy = fake_deployable.fake_db_deployable(id=2) + return db_deploy + + @property + def fake_attribute(self): + db_attr = fake_attribute.fake_db_attribute(id=2) + return db_attr + + @property + def fake_attribute2(self): + db_attr = fake_attribute.fake_db_attribute(id=3) + return db_attr + + @property + def fake_attribute3(self): + db_attr = fake_attribute.fake_db_attribute(id=4) + return db_attr + + def test_create(self): + db_device = self.fake_device + device = objects.Device(context=self.context, + **db_device) + device.create(self.context) + device_get = objects.Device.get(self.context, device.uuid) + db_dpl = self.fake_deployable + dpl = objects.Deployable(context=self.context, + **db_dpl) + + dpl.device_id = device_get.id + dpl.create(self.context) + + self.assertEqual(db_dpl['uuid'], dpl.uuid) + + def test_get(self): + db_device = self.fake_device + device = objects.Device(context=self.context, + **db_device) + device.create(self.context) + device_get = objects.Device.get(self.context, device.uuid) + db_dpl = self.fake_deployable + dpl = objects.Deployable(context=self.context, + **db_dpl) + + dpl.device_id = device_get.id + dpl.create(self.context) + dpl_get = objects.Deployable.get(self.context, dpl.uuid) + self.assertEqual(dpl_get.uuid, dpl.uuid) + + def test_get_by_filter(self): + db_device = self.fake_device + device = objects.Device(context=self.context, + **db_device) + device.create(self.context) + device_get = objects.Device.get(self.context, device.uuid) + db_dpl = self.fake_deployable + dpl = objects.Deployable(context=self.context, + **db_dpl) + + dpl.device_id = device_get.id + dpl.create(self.context) + query = {"uuid": dpl['uuid']} + dpl_get_list = objects.Deployable.get_by_filter(self.context, query) + + self.assertEqual(dpl_get_list[0].uuid, dpl.uuid) + + def test_save(self): + db_device = self.fake_device + device = objects.Device(context=self.context, + **db_device) + device.create(self.context) + device_get = objects.Device.get(self.context, device.uuid) + db_dpl = self.fake_deployable + dpl = objects.Deployable(context=self.context, + **db_dpl) + + dpl.device_id = device_get.id + dpl.create(self.context) + dpl.num_accelerators = 8 + dpl.save(self.context) + dpl_get = objects.Deployable.get(self.context, dpl.uuid) + self.assertEqual(dpl_get.num_accelerators, 8) + + def test_destroy(self): + db_device = self.fake_device + device = objects.Device(context=self.context, + **db_device) + device.create(self.context) + device_get = objects.Device.get(self.context, device.uuid) + db_dpl = self.fake_deployable + dpl = objects.Deployable(context=self.context, + **db_dpl) + + dpl.device_id = device_get.id + dpl.create(self.context) + self.assertEqual(db_dpl['uuid'], dpl.uuid) + dpl.destroy(self.context) + self.assertRaises(exception.DeployableNotFound, + objects.Deployable.get, self.context, + dpl.uuid) + + def test_add_attribute(self): + db_device = self.fake_device + device = objects.Device(context=self.context, + **db_device) + device.create(self.context) + device_get = objects.Device.get(self.context, device.uuid) + db_dpl = self.fake_deployable + dpl = objects.Deployable(context=self.context, + **db_dpl) + + dpl.device_id = device_get.id + dpl.create(self.context) + dpl_get = objects.Deployable.get(self.context, dpl.uuid) + + db_attr = self.fake_attribute + + dpl.add_attribute(self.context, db_attr['key'], db_attr['value']) + dpl.save(self.context) + + dpl_get = objects.Deployable.get(self.context, dpl.uuid) + self.assertEqual(len(dpl_get.attributes_list), 1) + + def test_delete_attribute(self): + db_device = self.fake_device + device = objects.Device(context=self.context, + **db_device) + device.create(self.context) + device_get = objects.Device.get(self.context, device.uuid) + db_dpl = self.fake_deployable + dpl = objects.Deployable(context=self.context, + **db_dpl) + + dpl.device_id = device_get.id + dpl.create(self.context) + dpl_get = objects.Deployable.get(self.context, dpl.uuid) + db_attr = self.fake_attribute + dpl_get.add_attribute(self.context, db_attr['key'], db_attr['value']) + dpl_get.save(self.context) + dpl_get = objects.Deployable.get(self.context, dpl_get.uuid) + self.assertEqual(len(dpl_get.attributes_list), 1) + + dpl_get.delete_attribute(self.context, dpl_get.attributes_list[0]) + self.assertEqual(len(dpl_get.attributes_list), 0) + + def test_get_by_filter_with_attributes(self): + db_device = self.fake_device + device = objects.Device(context=self.context, + **db_device) + device.create(self.context) + device_get = objects.Device.get(self.context, device.uuid) + db_dpl = self.fake_deployable + dpl = objects.Deployable(context=self.context, + **db_dpl) + + dpl.device_id = device_get.id + dpl.create(self.context) + dpl_get = objects.Deployable.get(self.context, dpl.uuid) + + db_dpl2 = self.fake_deployable2 + dpl2 = objects.Deployable(context=self.context, + **db_dpl2) + dpl2.device_id = device_get.id + dpl2.create(self.context) + dpl2_get = objects.Deployable.get(self.context, dpl2.uuid) + + db_attr = self.fake_attribute + + db_attr2 = self.fake_attribute2 + + db_attr3 = self.fake_attribute3 + + dpl.add_attribute(self.context, 'attr_key', 'attr_val') + dpl.save(self.context) + + dpl2.add_attribute(self.context, 'test_key', 'test_val') + dpl2.add_attribute(self.context, 'test_key3', 'test_val3') + dpl2.save(self.context) + + query = {"attr_key": "attr_val"} + + dpl_get_list = objects.Deployable.get_by_filter(self.context, query) + self.assertEqual(len(dpl_get_list), 1) + self.assertEqual(dpl_get_list[0].uuid, dpl.uuid) + + query = {"test_key": "test_val"} + dpl_get_list = objects.Deployable.get_by_filter(self.context, query) + self.assertEqual(len(dpl_get_list), 1) + self.assertEqual(dpl_get_list[0].uuid, dpl2.uuid) + + query = {"test_key": "test_val", "test_key3": "test_val3"} + dpl_get_list = objects.Deployable.get_by_filter(self.context, query) + self.assertEqual(len(dpl_get_list), 1) + self.assertEqual(dpl_get_list[0].uuid, dpl2.uuid) + + query = {"num_accelerators": 4, "test_key3": "test_val3"} + dpl_get_list = objects.Deployable.get_by_filter(self.context, query) + self.assertEqual(len(dpl_get_list), 1) + self.assertEqual(dpl_get_list[0].uuid, dpl2.uuid) + + +class TestDeployableObject(test_objects._LocalTest, + _TestDeployableObject): + def _test_save_objectfield_fk_constraint_fails(self, foreign_key, + expected_exception): + + error = db_exc.DBReferenceError('table', 'constraint', foreign_key, + 'key_table') + # Prevent lazy-loading any fields, results in InstanceNotFound + deployable = fake_deployable.fake_deployable_obj(self.context) + fields_with_save_methods = [field for field in deployable.fields + if hasattr(deployable, '_save_%s' % field)] + for field in fields_with_save_methods: + @mock.patch.object(deployable, '_save_%s' % field) + @mock.patch.object(deployable, 'obj_attr_is_set') + def _test(mock_is_set, mock_save_field): + mock_is_set.return_value = True + mock_save_field.side_effect = error + deployable.obj_reset_changes(fields=[field]) + deployable._changed_fields.add(field) + self.assertRaises(expected_exception, deployable.save) + deployable.obj_reset_changes(fields=[field]) + _test() diff --git a/cyborg/tests/unit/objects/test_objects.py b/cyborg/tests/unit/objects/test_objects.py new file mode 100644 index 00000000..90d27d29 --- /dev/null +++ b/cyborg/tests/unit/objects/test_objects.py @@ -0,0 +1,227 @@ +# Copyright 2018 Huawei Technologies Co.,LTD. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import contextlib +import copy +import datetime +import inspect +import os + +import fixtures +import mock +from oslo_log import log +from oslo_utils import timeutils +from oslo_versionedobjects import base as ovo_base +from oslo_versionedobjects import exception as ovo_exc +from oslo_versionedobjects import fixture +import six + +from oslo_context import context + +from cyborg.common import exception +from cyborg import objects +from cyborg.objects import base +from cyborg.objects import fields +from cyborg import tests as test + + +LOG = log.getLogger(__name__) + + +class MyOwnedObject(base.CyborgPersistentObject, base.CyborgObject): + VERSION = '1.0' + fields = {'baz': fields.IntegerField()} + + +class MyObj(base.CyborgPersistentObject, base.CyborgObject, + base.CyborgObjectDictCompat): + VERSION = '1.6' + fields = {'foo': fields.IntegerField(default=1), + 'bar': fields.StringField(), + 'missing': fields.StringField(), + 'readonly': fields.IntegerField(read_only=True), + 'rel_object': fields.ObjectField('MyOwnedObject', nullable=True), + 'rel_objects': fields.ListOfObjectsField('MyOwnedObject', + nullable=True), + 'mutable_default': fields.ListOfStringsField(default=[]), + } + + @staticmethod + def _from_db_object(context, obj, db_obj): + self = MyObj() + self.foo = db_obj['foo'] + self.bar = db_obj['bar'] + self.missing = db_obj['missing'] + self.readonly = 1 + self._context = context + return self + + def obj_load_attr(self, attrname): + setattr(self, attrname, 'loaded!') + + def query(cls, context): + obj = cls(context=context, foo=1, bar='bar') + obj.obj_reset_changes() + return obj + + def marco(self): + return 'polo' + + def _update_test(self): + self.bar = 'updated' + + def save(self): + self.obj_reset_changes() + + def refresh(self): + self.foo = 321 + self.bar = 'refreshed' + self.obj_reset_changes() + + def modify_save_modify(self): + self.bar = 'meow' + self.save() + self.foo = 42 + self.rel_object = MyOwnedObject(baz=42) + + def obj_make_compatible(self, primitive, target_version): + super(MyObj, self).obj_make_compatible(primitive, target_version) + # NOTE(danms): Simulate an older version that had a different + # format for the 'bar' attribute + if target_version == '1.1' and 'bar' in primitive: + primitive['bar'] = 'old%s' % primitive['bar'] + + +class RandomMixInWithNoFields(object): + """Used to test object inheritance using a mixin that has no fields.""" + pass + + +@base.CyborgObjectRegistry.register_if(False) +class TestSubclassedObject(RandomMixInWithNoFields, MyObj): + fields = {'new_field': fields.StringField()} + + +class TestObjToPrimitive(test.base.TestCase): + + def test_obj_to_primitive_list(self): + @base.CyborgObjectRegistry.register_if(False) + class MyObjElement(base.CyborgObject): + fields = {'foo': fields.IntegerField()} + + def __init__(self, foo): + super(MyObjElement, self).__init__() + self.foo = foo + + @base.CyborgObjectRegistry.register_if(False) + class MyList(base.ObjectListBase, base.CyborgObject): + fields = {'objects': fields.ListOfObjectsField('MyObjElement')} + + mylist = MyList() + mylist.objects = [MyObjElement(1), MyObjElement(2), MyObjElement(3)] + self.assertEqual([1, 2, 3], + [x['foo'] for x in base.obj_to_primitive(mylist)]) + + def test_obj_to_primitive_dict(self): + base.CyborgObjectRegistry.register(MyObj) + myobj = MyObj(foo=1, bar='foo') + self.assertEqual({'foo': 1, 'bar': 'foo'}, + base.obj_to_primitive(myobj)) + + def test_obj_to_primitive_recursive(self): + base.CyborgObjectRegistry.register(MyObj) + + class MyList(base.ObjectListBase, base.CyborgObject): + fields = {'objects': fields.ListOfObjectsField('MyObj')} + + mylist = MyList(objects=[MyObj(), MyObj()]) + for i, value in enumerate(mylist): + value.foo = i + self.assertEqual([{'foo': 0}, {'foo': 1}], + base.obj_to_primitive(mylist)) + + def test_obj_to_primitive_with_ip_addr(self): + @base.CyborgObjectRegistry.register_if(False) + class TestObject(base.CyborgObject): + fields = {'addr': fields.IPAddressField(), + 'cidr': fields.IPNetworkField()} + + obj = TestObject(addr='1.2.3.4', cidr='1.1.1.1/16') + self.assertEqual({'addr': '1.2.3.4', 'cidr': '1.1.1.1/16'}, + base.obj_to_primitive(obj)) + + +def compare_obj(test, obj, db_obj, subs=None, allow_missing=None, + comparators=None): + """Compare a CyborgObject and a dict-like database object. + + This automatically converts TZ-aware datetimes and iterates over + the fields of the object. + + :param:test: The TestCase doing the comparison + :param:obj: The CyborgObject to examine + :param:db_obj: The dict-like database object to use as reference + :param:subs: A dict of objkey=dbkey field substitutions + :param:allow_missing: A list of fields that may not be in db_obj + :param:comparators: Map of comparator functions to use for certain fields + """ + + if subs is None: + subs = {} + if allow_missing is None: + allow_missing = [] + if comparators is None: + comparators = {} + + for key in obj.fields: + if key in allow_missing and not obj.obj_attr_is_set(key): + continue + obj_val = getattr(obj, key) + db_key = subs.get(key, key) + db_val = db_obj[db_key] + if isinstance(obj_val, datetime.datetime): + obj_val = obj_val.replace(tzinfo=None) + + if key in comparators: + comparator = comparators[key] + comparator(db_val, obj_val) + else: + test.assertEqual(db_val, obj_val) + + +class _BaseTestCase(test.base.TestCase): + def setUp(self): + super(_BaseTestCase, self).setUp() + self.user_id = 'fake-user' + self.project_id = 'fake-project' + self.context = context.RequestContext(self.user_id, self.project_id) + + base.CyborgObjectRegistry.register(MyObj) + base.CyborgObjectRegistry.register(MyOwnedObject) + + def compare_obj(self, obj, db_obj, subs=None, allow_missing=None, + comparators=None): + compare_obj(self, obj, db_obj, subs=subs, allow_missing=allow_missing, + comparators=comparators) + + def str_comparator(self, expected, obj_val): + """Compare an object field to a string in the db by performing + a simple coercion on the object field value. + """ + self.assertEqual(expected, str(obj_val)) + + +class _LocalTest(_BaseTestCase): + def setUp(self): + super(_LocalTest, self).setUp()