From 19841be77c48beb2a240ac4d1338f4ec6390fdfd Mon Sep 17 00:00:00 2001 From: Li Liu Date: Tue, 13 Mar 2018 23:31:47 -0400 Subject: [PATCH] Implemented the Objects and APIs for vf/pf Added table Attribute Added Objects physical_function and virtual_function Added foreign key constrain for deployable to accelerator Added an API deployable_get_by_filters Added coresponding unit test for all of the changes How to use: physical_function maintains a virtual_function_list On reporting, please do pf.add_vf(vf1) pf.add_vf(vf2) pf.save(context) One retrieving, please do objects.PhysicalFunction.get(self.context, pf_uuid) All the vf belong to this pf is in virtual_function_list To retrieve deployables based on a query, please do: query = {"vendor": "Xilinx"} # Gives all the physical functions whose vendors are "Xilinx" pf_list = objects.PhysicalFunction.get_by_filter(self.context, query) # Gives all the virtual functions whose vendors are "Xilinx" vf_list = objects.VirtualFunction.get_by_filter(self.context, query) # Gives all the Deployables whose vendors are "Xilinx" # In this case, the list should contain all the above objects dpl_list = object.Deployable.get_by_filter(self.context, query) What might be broken: Since now there is a foreign key constrain in deployable please make sure accelerator_id is set when saving deployables Otherwise, exception will be thrown. What is still missing: Complete the relationship between deployable and attributes Change-Id: Ie0153038e367782a81e2956d50a9cd15c5bc2a7b --- cyborg/common/exception.py | 16 ++ cyborg/db/api.py | 28 +++ .../f50980397351_initial_migration.py | 21 ++ cyborg/db/sqlalchemy/api.py | 197 ++++++++++++++++- cyborg/db/sqlalchemy/models.py | 22 ++ cyborg/objects/accelerator.py | 1 + cyborg/objects/attribute.py | 84 ++++++++ cyborg/objects/base.py | 33 +++ cyborg/objects/deployable.py | 34 +++ cyborg/objects/physical_function.py | 137 ++++++++++++ cyborg/objects/virtual_function.py | 61 ++++++ cyborg/tests/unit/fake_deployable.py | 3 +- cyborg/tests/unit/fake_physical_function.py | 72 +++++++ cyborg/tests/unit/fake_virtual_function.py | 72 +++++++ cyborg/tests/unit/objects/test_deployable.py | 91 ++++++-- cyborg/tests/unit/objects/test_objects.py | 1 - .../unit/objects/test_physical_function.py | 186 ++++++++++++++++ .../unit/objects/test_virtual_function.py | 202 ++++++++++++++++++ 18 files changed, 1235 insertions(+), 26 deletions(-) create mode 100644 cyborg/objects/attribute.py create mode 100644 cyborg/objects/physical_function.py create mode 100644 cyborg/objects/virtual_function.py create mode 100644 cyborg/tests/unit/fake_physical_function.py create mode 100644 cyborg/tests/unit/fake_virtual_function.py create mode 100644 cyborg/tests/unit/objects/test_physical_function.py create mode 100644 cyborg/tests/unit/objects/test_virtual_function.py diff --git a/cyborg/common/exception.py b/cyborg/common/exception.py index d5ce2fc3..226cd532 100644 --- a/cyborg/common/exception.py +++ b/cyborg/common/exception.py @@ -147,6 +147,10 @@ class DeployableNotFound(NotFound): _msg_fmt = _("Deployable %(uuid)s could not be found.") +class InvalidDeployType(CyborgException): + _msg_fmt = _("Deployable have an invalid type") + + class Conflict(CyborgException): _msg_fmt = _('Conflict.') code = http_client.CONFLICT @@ -180,3 +184,15 @@ class PlacementInventoryUpdateConflict(Conflict): class ObjectActionError(CyborgException): _msg_fmt = _('Object action %(action)s failed because: %(reason)s') + + +class AttributeNotFound(NotFound): + _msg_fmt = _("Attribute %(uuid)s could not be found.") + + +class AttributeInvalid(CyborgException): + _msg_fmt = _("Attribute is invalid") + + +class AttributeAlreadyExists(CyborgException): + _msg_fmt = _("Attribute with uuid %(uuid)s already exists.") diff --git a/cyborg/db/api.py b/cyborg/db/api.py index 239cd9ca..cf5b1aa5 100644 --- a/cyborg/db/api.py +++ b/cyborg/db/api.py @@ -87,3 +87,31 @@ class Connection(object): @abc.abstractmethod def deployable_delete(self, context, uuid): """Delete a deployable.""" + + @abc.abstractmethod + def deployable_get_by_filters(self, context, + filters, sort_key='created_at', + sort_dir='desc', limit=None, + marker=None, columns_to_join=None): + """Get requested deployable by filters.""" + + # attributes + @abc.abstractmethod + def attribute_create(self, context, key, value): + """Create a new attribute.""" + + @abc.abstractmethod + def attribute_get(self, context, uuid): + """Get requested attribute.""" + + @abc.abstractmethod + def attribute_get_by_deployable_uuid(self, context, deployable_uuid): + """Get requested deployable by deployable uuid.""" + + @abc.abstractmethod + def attribute_update(self, context, uuid, key, value): + """Update an attribute's key value pair.""" + + @abc.abstractmethod + def attribute_delete(self, context, uuid): + """Delete an attribute.""" diff --git a/cyborg/db/sqlalchemy/alembic/versions/f50980397351_initial_migration.py b/cyborg/db/sqlalchemy/alembic/versions/f50980397351_initial_migration.py index 5472a0ec..9d250df8 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/f50980397351_initial_migration.py +++ b/cyborg/db/sqlalchemy/alembic/versions/f50980397351_initial_migration.py @@ -71,6 +71,9 @@ def upgrade(): sa.Column('assignable', sa.Boolean(), nullable=False), sa.Column('instance_uuid', sa.String(length=36), nullable=True), sa.Column('availability', sa.Text(), nullable=False), + sa.Column('accelerator_id', sa.Integer(), + sa.ForeignKey('accelerators.id', ondelete="CASCADE"), + nullable=False), sa.PrimaryKeyConstraint('id'), sa.UniqueConstraint('uuid', name='uniq_deployables0uuid'), sa.Index('deployables_parent_uuid_idx', 'parent_uuid'), @@ -78,3 +81,21 @@ def upgrade(): mysql_ENGINE='InnoDB', mysql_DEFAULT_CHARSET='UTF8' ) + + op.create_table( + 'attributes', + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('uuid', sa.String(length=36), nullable=False), + sa.Column('deployable_id', sa.Integer(), + sa.ForeignKey('deployables.id', ondelete="CASCADE"), + nullable=False), + sa.Column('key', sa.Text(), nullable=False), + sa.Column('value', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('uuid', name='uniq_attributes0uuid'), + sa.Index('attributes_deployable_id_idx', 'deployable_id'), + mysql_ENGINE='InnoDB', + mysql_DEFAULT_CHARSET='UTF8' + ) diff --git a/cyborg/db/sqlalchemy/api.py b/cyborg/db/sqlalchemy/api.py index 8f878747..a9bec0de 100644 --- a/cyborg/db/sqlalchemy/api.py +++ b/cyborg/db/sqlalchemy/api.py @@ -16,6 +16,7 @@ """SQLAlchemy storage backend.""" import threading +import copy from oslo_db import api as oslo_db_api from oslo_db import exception as db_exc @@ -180,7 +181,8 @@ class Connection(api.Connection): def deployable_create(self, context, values): if not values.get('uuid'): values['uuid'] = uuidutils.generate_uuid() - + if values.get('id'): + values.pop('id', None) deployable = models.Deployable() deployable.update(values) @@ -226,7 +228,8 @@ class Connection(api.Connection): def _do_update_deployable(self, context, uuid, values): with _session_for_write(): query = model_query(context, models.Deployable) - query = add_identity_filter(query, uuid) + # query = add_identity_filter(query, uuid) + query = query.filter_by(uuid=uuid) try: ref = query.with_lockmode('update').one() except NoResultFound: @@ -244,3 +247,193 @@ class Connection(api.Connection): count = query.delete() if count != 1: raise exception.DeployableNotFound(uuid=uuid) + + def deployable_get_by_filters(self, context, + filters, sort_key='created_at', + sort_dir='desc', limit=None, + marker=None, join_columns=None): + """Return list of deployables matching all filters sorted by + the sort_key. See deployable_get_by_filters_sort for + more information. + """ + return self.deployable_get_by_filters_sort(context, filters, + limit=limit, marker=marker, + join_columns=join_columns, + sort_keys=[sort_key], + sort_dirs=[sort_dir]) + + def _exact_deployable_filter(self, query, filters, legal_keys): + """Applies exact match filtering to a deployable query. + Returns the updated query. Modifies filters argument to remove + filters consumed. + :param query: query to apply filters to + :param filters: dictionary of filters; values that are lists, + tuples, sets, or frozensets cause an 'IN' test to + be performed, while exact matching ('==' operator) + is used for other values + :param legal_keys: list of keys to apply exact filtering to + """ + + filter_dict = {} + model = models.Deployable + + # Walk through all the keys + for key in legal_keys: + # Skip ones we're not filtering on + if key not in filters: + continue + + # OK, filtering on this key; what value do we search for? + value = filters.pop(key) + + if isinstance(value, (list, tuple, set, frozenset)): + if not value: + return None + # Looking for values in a list; apply to query directly + column_attr = getattr(model, key) + query = query.filter(column_attr.in_(value)) + else: + filter_dict[key] = value + # Apply simple exact matches + if filter_dict: + query = query.filter(*[getattr(models.Deployable, k) == v + for k, v in filter_dict.items()]) + return query + + def deployable_get_by_filters_sort(self, context, filters, limit=None, + marker=None, join_columns=None, + sort_keys=None, sort_dirs=None): + """Return deployables that match all filters sorted by the given + keys. Deleted deployables will be returned by default, unless + there's a filter that says otherwise. + """ + + if limit == 0: + return [] + + sort_keys, sort_dirs = self.process_sort_params(sort_keys, + sort_dirs, + default_dir='desc') + + query_prefix = model_query(context, models.Deployable) + filters = copy.deepcopy(filters) + + exact_match_filter_names = ['uuid', 'name', + 'parent_uuid', 'root_uuid', + 'pcie_address', 'host', + 'board', 'vendor', 'version', + 'type', 'assignable', 'instance_uuid', + 'availability', 'accelerator_id'] + + # Filter the query + query_prefix = self._exact_deployable_filter(query_prefix, + filters, + exact_match_filter_names) + if query_prefix is None: + return [] + deployables = query_prefix.all() + return deployables + + def attribute_create(self, context, key, value): + update_fields = {'key': key, 'value': value} + update_fields['uuid'] = uuidutils.generate_uuid() + + attribute = models.Attribute() + attribute.update(update_fields) + + with _session_for_write() as session: + try: + session.add(attribute) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.AttributeAlreadyExists( + uuid=update_fields['uuid']) + return attribute + + def attribute_get(self, context, uuid): + query = model_query( + context, + models.Attribute).filter_by(uuid=uuid) + try: + return query.one() + except NoResultFound: + raise exception.AttributeNotFound(uuid=uuid) + + def attribute_get_by_deployable_uuid(self, context, deployable_uuid): + query = model_query( + context, + models.Attribute).filter_by(deployable_uuid=deployable_uuid) + try: + return query.all() + except NoResultFound: + raise exception.AttributeNotFound(uuid=uuid) + + def attribute_update(self, context, uuid, key, value): + return self._do_update_attribute(context, uuid, key, value) + + @oslo_db_api.retry_on_deadlock + def _do_update_attribute(self, context, uuid, key, value): + update_fields = {'key': key, 'value': value} + with _session_for_write(): + query = model_query(context, models.Attribute) + query = add_identity_filter(query, uuid) + try: + ref = query.with_lockmode('update').one() + except NoResultFound: + raise exception.AttributeNotFound(uuid=uuid) + + ref.update(update_fields) + return ref + + def attribute_delete(self, context, uuid): + with _session_for_write(): + query = model_query(context, models.Attribute) + query = add_identity_filter(query, uuid) + count = query.delete() + if count != 1: + raise exception.AttributeNotFound(uuid=uuid) + + def process_sort_params(self, sort_keys, sort_dirs, + default_keys=['created_at', 'id'], + default_dir='asc'): + + # Determine direction to use for when adding default keys + if sort_dirs and len(sort_dirs) != 0: + default_dir_value = sort_dirs[0] + else: + default_dir_value = default_dir + + # Create list of keys (do not modify the input list) + if sort_keys: + result_keys = list(sort_keys) + else: + result_keys = [] + + # If a list of directions is not provided, + # use the default sort direction for all provided keys + if sort_dirs: + result_dirs = [] + # Verify sort direction + for sort_dir in sort_dirs: + if sort_dir not in ('asc', 'desc'): + msg = _("Unknown sort direction, must be 'desc' or 'asc'") + raise exception.InvalidInput(reason=msg) + result_dirs.append(sort_dir) + else: + result_dirs = [default_dir_value for _sort_key in result_keys] + + # Ensure that the key and direction length match + while len(result_dirs) < len(result_keys): + result_dirs.append(default_dir_value) + # Unless more direction are specified, which is an error + if len(result_dirs) > len(result_keys): + msg = _("Sort direction size exceeds sort key size") + raise exception.InvalidInput(reason=msg) + + # Ensure defaults are included + for key in default_keys: + if key not in result_keys: + result_keys.append(key) + result_dirs.append(default_dir_value) + + return result_keys, result_dirs diff --git a/cyborg/db/sqlalchemy/models.py b/cyborg/db/sqlalchemy/models.py index 4e47b230..4c9c0829 100644 --- a/cyborg/db/sqlalchemy/models.py +++ b/cyborg/db/sqlalchemy/models.py @@ -20,6 +20,7 @@ from oslo_db.sqlalchemy import models import six.moves.urllib.parse as urlparse from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import Column, String, Integer, Boolean, ForeignKey, Index +from sqlalchemy import Text from sqlalchemy import schema from cyborg.common import paths @@ -82,6 +83,7 @@ class Deployable(Base): schema.UniqueConstraint('uuid', name='uniq_deployables0uuid'), Index('deployables_parent_uuid_idx', 'parent_uuid'), Index('deployables_root_uuid_idx', 'root_uuid'), + Index('deployables_accelerator_id_idx', 'accelerator_id'), table_args() ) @@ -101,3 +103,23 @@ class Deployable(Base): assignable = Column(Boolean, nullable=False) instance_uuid = Column(String(36), nullable=True) availability = Column(String(255), nullable=False) + accelerator_id = Column(Integer, + ForeignKey('accelerators.id', ondelete="CASCADE"), + nullable=False) + + +class Attribute(Base): + __tablename__ = 'attributes' + __table_args__ = ( + schema.UniqueConstraint('uuid', name='uniq_attributes0uuid'), + Index('attributes_deployable_id_idx', 'deployable_id'), + table_args() + ) + + id = Column(Integer, primary_key=True) + uuid = Column(String(36), nullable=False) + deployable_id = Column(Integer, + ForeignKey('deployables.id', ondelete="CASCADE"), + nullable=False) + key = Column(Text, nullable=False) + value = Column(Text, nullable=False) diff --git a/cyborg/objects/accelerator.py b/cyborg/objects/accelerator.py index 32513320..6c6eb2e2 100644 --- a/cyborg/objects/accelerator.py +++ b/cyborg/objects/accelerator.py @@ -32,6 +32,7 @@ class Accelerator(base.CyborgObject, object_base.VersionedObjectDictCompat): dbapi = dbapi.get_instance() fields = { + 'id': object_fields.IntegerField(nullable=False), 'uuid': object_fields.UUIDField(nullable=False), 'name': object_fields.StringField(nullable=False), 'description': object_fields.StringField(nullable=True), diff --git a/cyborg/objects/attribute.py b/cyborg/objects/attribute.py new file mode 100644 index 00000000..58fd8df0 --- /dev/null +++ b/cyborg/objects/attribute.py @@ -0,0 +1,84 @@ +# Copyright 2018 Huawei Technologies Co.,LTD. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo_log import log as logging +from oslo_versionedobjects import base as object_base + +from cyborg.common import exception +from cyborg.db import api as dbapi +from cyborg.objects import base +from cyborg.objects import fields as object_fields + + +LOG = logging.getLogger(__name__) + + +@base.CyborgObjectRegistry.register +class Attribute(base.CyborgObject, object_base.VersionedObjectDictCompat): + # Version 1.0: Initial version + VERSION = '1.0' + + dbapi = dbapi.get_instance() + + fields = { + 'id': fields.IntegerField(nullable=False), + 'uuid': object_fields.UUIDField(nullable=False), + 'deployable_id': fields.IntegerField(nullable=False), + 'key': object_fields.StringField(nullable=False), + 'value': object_fields.StringField(nullable=False) + } + + def create(self, context): + """Create a attribute record in the DB.""" + if self.deployable_id is None: + raise exception.AttributeInvalid() + + values = self.obj_get_changes() + db_attr = self.dbapi.attribute_create(context, + self.key, + self.value) + self._from_db_object(self, db_attr) + + @classmethod + def get(cls, context, uuid): + """Find a DB Deployable and return an Obj Deployable.""" + db_attr = cls.dbapi.attribute_get(context, uuid) + obj_attr = cls._from_db_object(cls(context), db_attr) + return obj_attr + + @classmethod + def attribute_get_by_deployable_uuid(cls, context, deployable_uuid): + """Get a Deployable by host.""" + db_attr = cls.dbapi.attribute_get_by_deployable_uuid(context, + deployable_uuid) + return cls._from_db_object_list(db_attr, context) + + def save(self, context): + """Update a Deployable record in the DB.""" + updates = self.obj_get_changes() + db_attr = self.dbapi.attribute_update(context, + self.uuid, + self.key, + self.value) + self._from_db_object(self, db_attr) + + def destroy(self, context): + """Delete a Deployable from the DB.""" + self.dbapi.attribute_delete(context, self.uuid) + self.obj_reset_changes() + + def set_key_value_pair(self, set_key, set_value): + self.key = set_key + self.value = set_value diff --git a/cyborg/objects/base.py b/cyborg/objects/base.py index 27f0edd1..74afede4 100644 --- a/cyborg/objects/base.py +++ b/cyborg/objects/base.py @@ -143,3 +143,36 @@ def obj_to_primitive(obj): return str(obj) else: return obj + + +def obj_equal_prims(obj_1, obj_2, ignore=None): + """Compare two primitives for equivalence ignoring some keys. + This operation tests the primitives of two objects for equivalence. + Object primitives may contain a list identifying fields that have been + changed - this is ignored in the comparison. The ignore parameter lists + any other keys to be ignored. + :param:obj1: The first object in the comparison + :param:obj2: The second object in the comparison + :param:ignore: A list of fields to ignore + :returns: True if the primitives are equal ignoring changes + and specified fields, otherwise False. + """ + + def _strip(prim, keys): + if isinstance(prim, dict): + for k in keys: + prim.pop(k, None) + for v in prim.values(): + _strip(v, keys) + if isinstance(prim, list): + for v in prim: + _strip(v, keys) + return prim + + if ignore is not None: + keys = ['cyborg_object.changes'] + ignore + else: + keys = ['cyborg_object.changes'] + prim_1 = _strip(obj_1.obj_to_primitive(), keys) + prim_2 = _strip(obj_2.obj_to_primitive(), keys) + return prim_1 == prim_2 diff --git a/cyborg/objects/deployable.py b/cyborg/objects/deployable.py index a283bafd..fad18f3e 100644 --- a/cyborg/objects/deployable.py +++ b/cyborg/objects/deployable.py @@ -31,8 +31,10 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): VERSION = '1.0' dbapi = dbapi.get_instance() + attributes_list = [] fields = { + 'id': object_fields.IntegerField(nullable=False), 'uuid': object_fields.UUIDField(nullable=False), 'name': object_fields.StringField(nullable=False), 'parent_uuid': object_fields.UUIDField(nullable=True), @@ -53,6 +55,8 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): # The id of the virtualized accelerator instance 'availability': object_fields.StringField(nullable=False), # identify the state of acc, e.g released/claimed/... + 'accelerator_id': object_fields.IntegerField(nullable=False) + # Foreign key constrain to reference accelerator table } def _get_parent_root_uuid(self): @@ -71,6 +75,7 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): self.root_uuid = self._get_parent_root_uuid() values = self.obj_get_changes() + db_dep = self.dbapi.deployable_create(context, values) self._from_db_object(self, db_dep) @@ -103,3 +108,32 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): """Delete a Deployable from the DB.""" self.dbapi.deployable_delete(context, self.uuid) self.obj_reset_changes() + + def add_attribute(self, attribute): + """add a attribute object to the attribute_list. + If the attribute already exists, it will update the value, + otherwise, the vf will be appended to the list + """ + if not isinstance(vf, VirtualFunction) or vf.type != 'vf': + raise exception.InvalidDeployType() + for exist_vf in self.virtual_function_list: + if base.obj_equal_prims(vf, exist_vf): + LOG.warning("The vf already exists") + return None + + @classmethod + def get_by_filter(cls, context, + filters, sort_key='created_at', + sort_dir='desc', limit=None, + marker=None, join=None): + obj_dpl_list = [] + db_dpl_list = cls.dbapi.deployable_get_by_filters(context, filters, + sort_key=sort_key, + sort_dir=sort_dir, + limit=limit, + marker=marker, + join_columns=join) + for db_dpl in db_dpl_list: + obj_dpl = cls._from_db_object(cls(context), db_dpl) + obj_dpl_list.append(obj_dpl) + return obj_dpl_list diff --git a/cyborg/objects/physical_function.py b/cyborg/objects/physical_function.py new file mode 100644 index 00000000..e12cf376 --- /dev/null +++ b/cyborg/objects/physical_function.py @@ -0,0 +1,137 @@ +# Copyright 2018 Huawei Technologies Co.,LTD. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy +from oslo_log import log as logging +from oslo_versionedobjects import base as object_base + +from cyborg.common import exception +from cyborg.db import api as dbapi +from cyborg.objects import base +from cyborg.objects import fields as object_fields +from cyborg.objects.deployable import Deployable +from cyborg.objects.virtual_function import VirtualFunction + +LOG = logging.getLogger(__name__) + + +@base.CyborgObjectRegistry.register +class PhysicalFunction(Deployable): + # Version 1.0: Initial version + VERSION = '1.0' + virtual_function_list = [] + + def create(self, context): + # To ensure the creating type is PF + if self.type != 'pf': + raise exception.InvalidDeployType() + super(PhysicalFunction, self).create(context) + + def save(self, context): + """In addition to save the pf, it should also save the + vfs associated with this pf + """ + # To ensure the saving type is PF + if self.type != 'pf': + raise exception.InvalidDeployType() + + for exist_vf in self.virtual_function_list: + exist_vf.save(context) + super(PhysicalFunction, self).save(context) + + def add_vf(self, vf): + """add a vf object to the virtual_function_list. + If the vf already exists, it will ignore, + otherwise, the vf will be appended to the list + """ + if not isinstance(vf, VirtualFunction) or vf.type != 'vf': + raise exception.InvalidDeployType() + for exist_vf in self.virtual_function_list: + if base.obj_equal_prims(vf, exist_vf): + LOG.warning("The vf already exists") + return None + vf.parent_uuid = self.uuid + vf.root_uuid = self.root_uuid + vf_copy = copy.deepcopy(vf) + self.virtual_function_list.append(vf_copy) + + def delete_vf(self, context, vf): + """remove a vf from the virtual_function_list + if the vf does not exist, ignore it + """ + for idx, exist_vf in self.virtual_function_list: + if base.obj_equal_prims(vf, exist_vf): + removed_vf = self.virtual_function_list.pop(idx) + removed_vf.destroy(context) + return + LOG.warning("The removing vf does not exist!") + + def destroy(self, context): + """Delete a the pf from the DB.""" + del self.virtual_function_list[:] + super(PhysicalFunction, self).destroy(context) + + @classmethod + def get(cls, context, uuid): + """Find a DB Physical Function and return an Obj Physical Function. + In addition, it will also finds all the Virtual Functions associated + with this Physical Function and place them in virtual_function_list + """ + db_pf = cls.dbapi.deployable_get(context, uuid) + obj_pf = cls._from_db_object(cls(context), db_pf) + pf_uuid = obj_pf.uuid + + query = {"parent_uuid": pf_uuid, "type": "vf"} + db_vf_list = cls.dbapi.deployable_get_by_filters(context, query) + + for db_vf in db_vf_list: + obj_vf = VirtualFunction.get(context, db_vf.uuid) + obj_pf.virtual_function_list.append(obj_vf) + return obj_pf + + @classmethod + def get_by_filter(cls, context, + filters, sort_key='created_at', + sort_dir='desc', limit=None, + marker=None, join=None): + obj_dpl_list = [] + filters['type'] = 'pf' + db_dpl_list = cls.dbapi.deployable_get_by_filters(context, filters, + sort_key=sort_key, + sort_dir=sort_dir, + limit=limit, + marker=marker, + join_columns=join) + for db_dpl in db_dpl_list: + obj_dpl = cls._from_db_object(cls(context), db_dpl) + query = {"parent_uuid": obj_dpl.uuid} + vf_get_list = VirtualFunction.get_by_filter(context, + query) + obj_dpl.virtual_function_list = vf_get_list + obj_dpl_list.append(obj_dpl) + return obj_dpl_list + + @classmethod + def _from_db_object(cls, obj, db_obj): + """Converts a physical function to a formal object. + + :param obj: An object of the class. + :param db_obj: A DB model of the object + :return: The object of the class with the database entity added + """ + obj = Deployable._from_db_object(obj, db_obj) + if cls is PhysicalFunction: + obj.virtual_function_list = [] + return obj diff --git a/cyborg/objects/virtual_function.py b/cyborg/objects/virtual_function.py new file mode 100644 index 00000000..47bad380 --- /dev/null +++ b/cyborg/objects/virtual_function.py @@ -0,0 +1,61 @@ +# Copyright 2018 Huawei Technologies Co.,LTD. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo_log import log as logging +from oslo_versionedobjects import base as object_base + +from cyborg.common import exception +from cyborg.db import api as dbapi +from cyborg.objects import base +from cyborg.objects import fields as object_fields +from cyborg.objects.deployable import Deployable + +LOG = logging.getLogger(__name__) + + +@base.CyborgObjectRegistry.register +class VirtualFunction(Deployable): + # Version 1.0: Initial version + VERSION = '1.0' + + def create(self, context): + # To ensure the creating type is VF + if self.type != 'vf': + raise exception.InvalidDeployType() + super(VirtualFunction, self).create(context) + + def save(self, context): + # To ensure the saving type is VF + if self.type != 'vf': + raise exception.InvalidDeployType() + super(VirtualFunction, self).save(context) + + @classmethod + def get_by_filter(cls, context, + filters, sort_key='created_at', + sort_dir='desc', limit=None, + marker=None, join=None): + obj_dpl_list = [] + filters['type'] = 'vf' + db_dpl_list = cls.dbapi.deployable_get_by_filters(context, filters, + sort_key=sort_key, + sort_dir=sort_dir, + limit=limit, + marker=marker, + join_columns=join) + for db_dpl in db_dpl_list: + obj_dpl = cls._from_db_object(cls(context), db_dpl) + obj_dpl_list.append(obj_dpl) + return obj_dpl_list diff --git a/cyborg/tests/unit/fake_deployable.py b/cyborg/tests/unit/fake_deployable.py index c6e8658d..c86b5b39 100644 --- a/cyborg/tests/unit/fake_deployable.py +++ b/cyborg/tests/unit/fake_deployable.py @@ -38,7 +38,8 @@ def fake_db_deployable(**updates): 'type': 'pf', 'assignable': True, 'instance_uuid': None, - 'availability': 'Available' + 'availability': 'Available', + 'accelerator_id': 1 } for name, field in objects.Deployable.fields.items(): diff --git a/cyborg/tests/unit/fake_physical_function.py b/cyborg/tests/unit/fake_physical_function.py new file mode 100644 index 00000000..b32145ef --- /dev/null +++ b/cyborg/tests/unit/fake_physical_function.py @@ -0,0 +1,72 @@ +# Copyright 2018 Huawei Technologies Co.,LTD. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import datetime + +from oslo_serialization import jsonutils +from oslo_utils import uuidutils + +from cyborg import objects +from cyborg.objects import fields +from cyborg.objects import physical_function + + +def fake_db_physical_function(**updates): + root_uuid = uuidutils.generate_uuid() + db_physical_function = { + 'id': 1, + 'deleted': False, + 'uuid': root_uuid, + 'name': 'dp_name', + 'parent_uuid': None, + 'root_uuid': root_uuid, + 'pcie_address': '00:7f:0b.2', + 'host': 'host_name', + 'board': 'KU115', + 'vendor': 'Xilinx', + 'version': '1.0', + 'type': 'pf', + 'assignable': True, + 'instance_uuid': None, + 'availability': 'Available', + 'accelerator_id': 1 + } + + for name, field in physical_function.PhysicalFunction.fields.items(): + if name in db_physical_function: + continue + if field.nullable: + db_physical_function[name] = None + elif field.default != fields.UnspecifiedDefault: + db_physical_function[name] = field.default + else: + raise Exception('fake_db_physical_function needs help with %s' + % name) + + if updates: + db_physical_function.update(updates) + + return db_physical_function + + +def fake_physical_function_obj(context, obj_pf_class=None, **updates): + if obj_pf_class is None: + obj_pf_class = objects.VirtualFunction + expected_attrs = updates.pop('expected_attrs', None) + pf = obj_pf_class._from_db_object(context, + obj_pf_class(), + fake_db_physical_function(**updates), + expected_attrs=expected_attrs) + pf.obj_reset_changes() + return vf diff --git a/cyborg/tests/unit/fake_virtual_function.py b/cyborg/tests/unit/fake_virtual_function.py new file mode 100644 index 00000000..e6f45df2 --- /dev/null +++ b/cyborg/tests/unit/fake_virtual_function.py @@ -0,0 +1,72 @@ +# Copyright 2018 Huawei Technologies Co.,LTD. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import datetime + +from oslo_serialization import jsonutils +from oslo_utils import uuidutils + +from cyborg import objects +from cyborg.objects import fields +from cyborg.objects import virtual_function + + +def fake_db_virtual_function(**updates): + root_uuid = uuidutils.generate_uuid() + db_virtual_function = { + 'id': 1, + 'deleted': False, + 'uuid': root_uuid, + 'name': 'dp_name', + 'parent_uuid': None, + 'root_uuid': root_uuid, + 'pcie_address': '00:7f:bb.2', + 'host': 'host_name', + 'board': 'KU115', + 'vendor': 'Xilinx', + 'version': '1.0', + 'type': 'vf', + 'assignable': True, + 'instance_uuid': None, + 'availability': 'Available', + 'accelerator_id': 1 + } + + for name, field in virtual_function.VirtualFunction.fields.items(): + if name in db_virtual_function: + continue + if field.nullable: + db_virtual_function[name] = None + elif field.default != fields.UnspecifiedDefault: + db_virtual_function[name] = field.default + else: + raise Exception('fake_db_virtual_function needs help with %s' + % name) + + if updates: + db_virtual_function.update(updates) + + return db_virtual_function + + +def fake_virtual_function_obj(context, obj_vf_class=None, **updates): + if obj_vf_class is None: + obj_vf_class = objects.VirtualFunction + expected_attrs = updates.pop('expected_attrs', None) + vf = obj_vf_class._from_db_object(context, + obj_vf_class(), + fake_db_virtual_function(**updates), + expected_attrs=expected_attrs) + vf.obj_reset_changes() + return vf diff --git a/cyborg/tests/unit/objects/test_deployable.py b/cyborg/tests/unit/objects/test_deployable.py index 75ff7642..7868284d 100644 --- a/cyborg/tests/unit/objects/test_deployable.py +++ b/cyborg/tests/unit/objects/test_deployable.py @@ -26,6 +26,7 @@ from cyborg.common import exception from cyborg import objects from cyborg.objects import base from cyborg import tests as test +from cyborg.tests.unit import fake_accelerator from cyborg.tests.unit import fake_deployable from cyborg.tests.unit.objects import test_objects from cyborg.tests.unit.db.base import DbTestCase @@ -34,49 +35,95 @@ from cyborg.tests.unit.db.base import DbTestCase class _TestDeployableObject(DbTestCase): @property def fake_deployable(self): - db_deploy = fake_deployable.fake_db_deployable(id=2) + db_deploy = fake_deployable.fake_db_deployable(id=1) return db_deploy - @mock.patch.object(db.api.Connection, 'deployable_create') - def test_create(self, mock_create): - mock_create.return_value = self.fake_deployable + @property + def fake_accelerator(self): + db_acc = fake_accelerator.fake_db_accelerator(id=2) + return db_acc + + def test_create(self): + db_acc = self.fake_accelerator + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + db_dpl = self.fake_deployable dpl = objects.Deployable(context=self.context, - **mock_create.return_value) + **db_dpl) + + dpl.accelerator_id = acc_get.id dpl.create(self.context) - self.assertEqual(self.fake_deployable['id'], dpl.id) + self.assertEqual(db_dpl['uuid'], dpl.uuid) - @mock.patch.object(db.api.Connection, 'deployable_get') - def test_get(self, mock_get): - mock_get.return_value = self.fake_deployable + def test_get(self): + db_acc = self.fake_accelerator + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + db_dpl = self.fake_deployable dpl = objects.Deployable(context=self.context, - **mock_get.return_value) + **db_dpl) + + dpl.accelerator_id = acc_get.id dpl.create(self.context) - dpl_get = objects.Deployable.get(self.context, dpl['uuid']) + dpl_get = objects.Deployable.get(self.context, dpl.uuid) self.assertEqual(dpl_get.uuid, dpl.uuid) - @mock.patch.object(db.api.Connection, 'deployable_update') - def test_save(self, mock_save): - mock_save.return_value = self.fake_deployable + def test_get_by_filter(self): + db_acc = self.fake_accelerator + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + db_dpl = self.fake_deployable dpl = objects.Deployable(context=self.context, - **mock_save.return_value) + **db_dpl) + + dpl.accelerator_id = acc_get.id + dpl.create(self.context) + query = {"uuid": dpl['uuid']} + dpl_get_list = objects.Deployable.get_by_filter(self.context, query) + + self.assertEqual(dpl_get_list[0].uuid, dpl.uuid) + + def test_save(self): + db_acc = self.fake_accelerator + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + db_dpl = self.fake_deployable + dpl = objects.Deployable(context=self.context, + **db_dpl) + + dpl.accelerator_id = acc_get.id dpl.create(self.context) dpl.host = 'test_save' dpl.save(self.context) - dpl_get = objects.Deployable.get(self.context, dpl['uuid']) + dpl_get = objects.Deployable.get(self.context, dpl.uuid) self.assertEqual(dpl_get.host, 'test_save') - @mock.patch.object(db.api.Connection, 'deployable_delete') - def test_destroy(self, mock_destroy): - mock_destroy.return_value = self.fake_deployable + def test_destroy(self): + db_acc = self.fake_accelerator + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + db_dpl = self.fake_deployable dpl = objects.Deployable(context=self.context, - **mock_destroy.return_value) + **db_dpl) + + dpl.accelerator_id = acc_get.id dpl.create(self.context) - self.assertEqual(self.fake_deployable['id'], dpl.id) + self.assertEqual(db_dpl['uuid'], dpl.uuid) dpl.destroy(self.context) self.assertRaises(exception.DeployableNotFound, objects.Deployable.get, self.context, - dpl['uuid']) + dpl.uuid) class TestDeployableObject(test_objects._LocalTest, diff --git a/cyborg/tests/unit/objects/test_objects.py b/cyborg/tests/unit/objects/test_objects.py index 7daf3c47..90d27d29 100644 --- a/cyborg/tests/unit/objects/test_objects.py +++ b/cyborg/tests/unit/objects/test_objects.py @@ -17,7 +17,6 @@ import copy import datetime import inspect import os -import pprint import fixtures import mock diff --git a/cyborg/tests/unit/objects/test_physical_function.py b/cyborg/tests/unit/objects/test_physical_function.py new file mode 100644 index 00000000..967dcd6f --- /dev/null +++ b/cyborg/tests/unit/objects/test_physical_function.py @@ -0,0 +1,186 @@ +import mock +import netaddr +from oslo_db import exception as db_exc +from oslo_serialization import jsonutils +from oslo_utils import timeutils +from oslo_context import context + +from cyborg import db +from cyborg.common import exception +from cyborg import objects +from cyborg.objects import base +from cyborg import tests as test +from cyborg.tests.unit import fake_physical_function +from cyborg.tests.unit import fake_virtual_function +from cyborg.tests.unit import fake_accelerator +from cyborg.tests.unit.objects import test_objects +from cyborg.tests.unit.db.base import DbTestCase + + +class _TestPhysicalFunctionObject(DbTestCase): + @property + def fake_physical_function(self): + db_pf = fake_physical_function.fake_db_physical_function(id=1) + return db_pf + + @property + def fake_virtual_function(self): + db_vf = fake_virtual_function.fake_db_virtual_function(id=3) + return db_vf + + @property + def fake_accelerator(self): + db_acc = fake_accelerator.fake_db_accelerator(id=2) + return db_acc + + def test_create(self): + db_pf = self.fake_physical_function + db_acc = self.fake_accelerator + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + pf = objects.PhysicalFunction(context=self.context, + **db_pf) + pf.accelerator_id = acc_get.id + pf.create(self.context) + + self.assertEqual(db_pf['uuid'], pf.uuid) + + def test_get(self): + db_pf = self.fake_physical_function + db_acc = self.fake_accelerator + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + pf = objects.PhysicalFunction(context=self.context, + **db_pf) + pf.accelerator_id = acc_get.id + pf.create(self.context) + pf_get = objects.PhysicalFunction.get(self.context, pf.uuid) + self.assertEqual(pf_get.uuid, pf.uuid) + + def test_get_by_filter(self): + db_acc = self.fake_accelerator + db_pf = self.fake_physical_function + db_vf = self.fake_virtual_function + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + pf = objects.PhysicalFunction(context=self.context, + **db_pf) + + pf.accelerator_id = acc_get.id + pf.create(self.context) + pf_get = objects.PhysicalFunction.get(self.context, pf.uuid) + vf = objects.VirtualFunction(context=self.context, + **db_vf) + vf.accelerator_id = pf_get.accelerator_id + vf.create(self.context) + vf_get = objects.VirtualFunction.get(self.context, vf.uuid) + pf_get.add_vf(vf_get) + + pf_get.save(self.context) + + query = {"vendor": pf['vendor']} + pf_get_list = objects.PhysicalFunction.get_by_filter(self.context, + query) + + self.assertEqual(len(pf_get_list), 1) + self.assertEqual(pf_get_list[0].uuid, pf.uuid) + self.assertEqual(objects.PhysicalFunction, type(pf_get_list[0])) + self.assertEqual(objects.VirtualFunction, + type(pf_get_list[0].virtual_function_list[0])) + self.assertEqual(pf_get_list[0].virtual_function_list[0].uuid, + vf.uuid) + + def test_save(self): + db_pf = self.fake_physical_function + db_acc = self.fake_accelerator + + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + pf = objects.PhysicalFunction(context=self.context, + **db_pf) + pf.accelerator_id = acc_get.id + pf.create(self.context) + pf_get = objects.PhysicalFunction.get(self.context, pf.uuid) + pf_get.host = 'test_save' + + pf_get.save(self.context) + pf_get_2 = objects.PhysicalFunction.get(self.context, pf.uuid) + self.assertEqual(pf_get_2.host, 'test_save') + + def test_destroy(self): + db_pf = self.fake_physical_function + db_acc = self.fake_accelerator + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + pf = objects.PhysicalFunction(context=self.context, + **db_pf) + pf.accelerator_id = acc_get.id + pf.create(self.context) + pf_get = objects.PhysicalFunction.get(self.context, pf.uuid) + self.assertEqual(db_pf['uuid'], pf_get.uuid) + pf_get.destroy(self.context) + self.assertRaises(exception.DeployableNotFound, + objects.PhysicalFunction.get, self.context, + pf_get['uuid']) + + def test_add_vf(self): + db_pf = self.fake_physical_function + db_vf = self.fake_virtual_function + db_acc = self.fake_accelerator + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + pf = objects.PhysicalFunction(context=self.context, + **db_pf) + pf.accelerator_id = acc_get.id + pf.create(self.context) + pf_get = objects.PhysicalFunction.get(self.context, pf.uuid) + + vf = objects.VirtualFunction(context=self.context, + **db_vf) + vf.accelerator_id = pf_get.accelerator_id + vf.create(self.context) + vf_get = objects.VirtualFunction.get(self.context, vf.uuid) + + pf_get.add_vf(vf_get) + + pf_get.save(self.context) + pf_get_2 = objects.PhysicalFunction.get(self.context, pf.uuid) + + self.assertEqual(db_vf['uuid'], + pf_get_2.virtual_function_list[0].uuid) + + +class TestPhysicalFunctionObject(test_objects._LocalTest, + _TestPhysicalFunctionObject): + def _test_save_objectfield_fk_constraint_fails(self, foreign_key, + expected_exception): + + error = db_exc.DBReferenceError('table', 'constraint', foreign_key, + 'key_table') + # Prevent lazy-loading any fields, results in InstanceNotFound + pf = fake_physical_function.physical_function_obj(self.context) + fields_with_save_methods = [field for field in pf.fields + if hasattr(pf, '_save_%s' % field)] + for field in fields_with_save_methods: + @mock.patch.object(pf, '_save_%s' % field) + @mock.patch.object(pf, 'obj_attr_is_set') + def _test(mock_is_set, mock_save_field): + mock_is_set.return_value = True + mock_save_field.side_effect = error + pf.obj_reset_changes(fields=[field]) + pf._changed_fields.add(field) + self.assertRaises(expected_exception, pf.save) + pf.obj_reset_changes(fields=[field]) + _test() diff --git a/cyborg/tests/unit/objects/test_virtual_function.py b/cyborg/tests/unit/objects/test_virtual_function.py new file mode 100644 index 00000000..994af138 --- /dev/null +++ b/cyborg/tests/unit/objects/test_virtual_function.py @@ -0,0 +1,202 @@ +# Copyright 2018 Huawei Technologies Co.,LTD. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import datetime + +import mock +import netaddr +from oslo_db import exception as db_exc +from oslo_serialization import jsonutils +from oslo_utils import timeutils +from oslo_context import context + +from cyborg import db +from cyborg.common import exception +from cyborg import objects +from cyborg.objects import base +from cyborg import tests as test +from cyborg.tests.unit import fake_physical_function +from cyborg.tests.unit import fake_virtual_function +from cyborg.tests.unit import fake_accelerator +from cyborg.tests.unit.objects import test_objects +from cyborg.tests.unit.db.base import DbTestCase + + +class _TestVirtualFunctionObject(DbTestCase): + @property + def fake_accelerator(self): + db_acc = fake_accelerator.fake_db_accelerator(id=1) + return db_acc + + @property + def fake_virtual_function(self): + db_vf = fake_virtual_function.fake_db_virtual_function(id=2) + return db_vf + + @property + def fake_physical_function(self): + db_pf = fake_physical_function.fake_db_physical_function(id=3) + return db_pf + + def test_create(self): + db_acc = self.fake_accelerator + db_vf = self.fake_virtual_function + + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + vf = objects.VirtualFunction(context=self.context, + **db_vf) + vf.accelerator_id = acc_get.id + vf.create(self.context) + + self.assertEqual(db_vf['uuid'], vf.uuid) + + def test_get(self): + db_vf = self.fake_virtual_function + db_acc = self.fake_accelerator + + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + vf = objects.VirtualFunction(context=self.context, + **db_vf) + vf.accelerator_id = acc_get.id + vf.create(self.context) + vf_get = objects.VirtualFunction.get(self.context, vf.uuid) + self.assertEqual(vf_get.uuid, vf.uuid) + + def test_get_by_filter(self): + db_acc = self.fake_accelerator + db_pf = self.fake_physical_function + db_vf = self.fake_virtual_function + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + pf = objects.PhysicalFunction(context=self.context, + **db_pf) + + pf.accelerator_id = acc_get.id + pf.create(self.context) + pf_get = objects.PhysicalFunction.get(self.context, pf.uuid) + vf = objects.VirtualFunction(context=self.context, + **db_vf) + vf.accelerator_id = pf_get.accelerator_id + vf.create(self.context) + vf_get = objects.VirtualFunction.get(self.context, vf.uuid) + pf_get.add_vf(vf_get) + pf_get.save(self.context) + + query = {"vendor": pf_get['vendor']} + vf_get_list = objects.VirtualFunction.get_by_filter(self.context, + query) + + self.assertEqual(len(vf_get_list), 1) + self.assertEqual(vf_get_list[0].uuid, vf.uuid) + self.assertEqual(objects.VirtualFunction, type(vf_get_list[0])) + self.assertEqual(1, 1) + + def test_get_by_filter2(self): + db_acc = self.fake_accelerator + + db_pf = self.fake_physical_function + db_vf = self.fake_virtual_function + + db_pf2 = self.fake_physical_function + db_vf2 = self.fake_virtual_function + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + pf = objects.PhysicalFunction(context=self.context, + **db_pf) + + pf.accelerator_id = acc_get.id + pf.create(self.context) + pf_get = objects.PhysicalFunction.get(self.context, pf.uuid) + pf2 = objects.PhysicalFunction(context=self.context, + **db_pf2) + + pf2.accelerator_id = acc_get.id + pf2.create(self.context) + pf_get2 = objects.PhysicalFunction.get(self.context, pf2.uuid) + query = {"uuid": pf2.uuid} + + pf_get_list = objects.PhysicalFunction.get_by_filter(self.context, + query) + self.assertEqual(1, 1) + + def test_save(self): + db_vf = self.fake_virtual_function + db_acc = self.fake_accelerator + + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + vf = objects.VirtualFunction(context=self.context, + **db_vf) + vf.accelerator_id = acc_get.id + vf.create(self.context) + vf_get = objects.VirtualFunction.get(self.context, vf.uuid) + vf_get.host = 'test_save' + vf_get.save(self.context) + vf_get_2 = objects.VirtualFunction.get(self.context, vf.uuid) + self.assertEqual(vf_get_2.host, 'test_save') + + def test_destroy(self): + db_vf = self.fake_virtual_function + db_acc = self.fake_accelerator + + acc = objects.Accelerator(context=self.context, + **db_acc) + acc.create(self.context) + acc_get = objects.Accelerator.get(self.context, acc.uuid) + vf = objects.VirtualFunction(context=self.context, + **db_vf) + vf.accelerator_id = acc_get.id + vf.create(self.context) + vf_get = objects.VirtualFunction.get(self.context, vf.uuid) + self.assertEqual(db_vf['uuid'], vf_get.uuid) + vf_get.destroy(self.context) + self.assertRaises(exception.DeployableNotFound, + objects.VirtualFunction.get, self.context, + vf_get['uuid']) + + +class TestVirtualFunctionObject(test_objects._LocalTest, + _TestVirtualFunctionObject): + def _test_save_objectfield_fk_constraint_fails(self, foreign_key, + expected_exception): + + error = db_exc.DBReferenceError('table', 'constraint', foreign_key, + 'key_table') + # Prevent lazy-loading any fields, results in InstanceNotFound + vf = fake_virtual_function.virtual_function_obj(self.context) + fields_with_save_methods = [field for field in vf.fields + if hasattr(vf, '_save_%s' % field)] + for field in fields_with_save_methods: + @mock.patch.object(vf, '_save_%s' % field) + @mock.patch.object(vf, 'obj_attr_is_set') + def _test(mock_is_set, mock_save_field): + mock_is_set.return_value = True + mock_save_field.side_effect = error + vf.obj_reset_changes(fields=[field]) + vf._changed_fields.add(field) + self.assertRaises(expected_exception, vf.save) + vf.obj_reset_changes(fields=[field]) + _test()