Merge "Implemented the Objects and APIs for vf/pf"

This commit is contained in:
Zuul 2018-04-19 18:00:34 +00:00 committed by Gerrit Code Review
commit ce0b6e5a95
18 changed files with 1235 additions and 26 deletions

View File

@ -147,6 +147,10 @@ class DeployableNotFound(NotFound):
_msg_fmt = _("Deployable %(uuid)s could not be found.") _msg_fmt = _("Deployable %(uuid)s could not be found.")
class InvalidDeployType(CyborgException):
_msg_fmt = _("Deployable have an invalid type")
class Conflict(CyborgException): class Conflict(CyborgException):
_msg_fmt = _('Conflict.') _msg_fmt = _('Conflict.')
code = http_client.CONFLICT code = http_client.CONFLICT
@ -180,3 +184,15 @@ class PlacementInventoryUpdateConflict(Conflict):
class ObjectActionError(CyborgException): class ObjectActionError(CyborgException):
_msg_fmt = _('Object action %(action)s failed because: %(reason)s') _msg_fmt = _('Object action %(action)s failed because: %(reason)s')
class AttributeNotFound(NotFound):
_msg_fmt = _("Attribute %(uuid)s could not be found.")
class AttributeInvalid(CyborgException):
_msg_fmt = _("Attribute is invalid")
class AttributeAlreadyExists(CyborgException):
_msg_fmt = _("Attribute with uuid %(uuid)s already exists.")

View File

@ -87,3 +87,31 @@ class Connection(object):
@abc.abstractmethod @abc.abstractmethod
def deployable_delete(self, context, uuid): def deployable_delete(self, context, uuid):
"""Delete a deployable.""" """Delete a deployable."""
@abc.abstractmethod
def deployable_get_by_filters(self, context,
filters, sort_key='created_at',
sort_dir='desc', limit=None,
marker=None, columns_to_join=None):
"""Get requested deployable by filters."""
# attributes
@abc.abstractmethod
def attribute_create(self, context, key, value):
"""Create a new attribute."""
@abc.abstractmethod
def attribute_get(self, context, uuid):
"""Get requested attribute."""
@abc.abstractmethod
def attribute_get_by_deployable_uuid(self, context, deployable_uuid):
"""Get requested deployable by deployable uuid."""
@abc.abstractmethod
def attribute_update(self, context, uuid, key, value):
"""Update an attribute's key value pair."""
@abc.abstractmethod
def attribute_delete(self, context, uuid):
"""Delete an attribute."""

View File

@ -71,6 +71,9 @@ def upgrade():
sa.Column('assignable', sa.Boolean(), nullable=False), sa.Column('assignable', sa.Boolean(), nullable=False),
sa.Column('instance_uuid', sa.String(length=36), nullable=True), sa.Column('instance_uuid', sa.String(length=36), nullable=True),
sa.Column('availability', sa.Text(), nullable=False), sa.Column('availability', sa.Text(), nullable=False),
sa.Column('accelerator_id', sa.Integer(),
sa.ForeignKey('accelerators.id', ondelete="CASCADE"),
nullable=False),
sa.PrimaryKeyConstraint('id'), sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('uuid', name='uniq_deployables0uuid'), sa.UniqueConstraint('uuid', name='uniq_deployables0uuid'),
sa.Index('deployables_parent_uuid_idx', 'parent_uuid'), sa.Index('deployables_parent_uuid_idx', 'parent_uuid'),
@ -78,3 +81,21 @@ def upgrade():
mysql_ENGINE='InnoDB', mysql_ENGINE='InnoDB',
mysql_DEFAULT_CHARSET='UTF8' mysql_DEFAULT_CHARSET='UTF8'
) )
op.create_table(
'attributes',
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('uuid', sa.String(length=36), nullable=False),
sa.Column('deployable_id', sa.Integer(),
sa.ForeignKey('deployables.id', ondelete="CASCADE"),
nullable=False),
sa.Column('key', sa.Text(), nullable=False),
sa.Column('value', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('uuid', name='uniq_attributes0uuid'),
sa.Index('attributes_deployable_id_idx', 'deployable_id'),
mysql_ENGINE='InnoDB',
mysql_DEFAULT_CHARSET='UTF8'
)

View File

@ -16,6 +16,7 @@
"""SQLAlchemy storage backend.""" """SQLAlchemy storage backend."""
import threading import threading
import copy
from oslo_db import api as oslo_db_api from oslo_db import api as oslo_db_api
from oslo_db import exception as db_exc from oslo_db import exception as db_exc
@ -180,7 +181,8 @@ class Connection(api.Connection):
def deployable_create(self, context, values): def deployable_create(self, context, values):
if not values.get('uuid'): if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid() values['uuid'] = uuidutils.generate_uuid()
if values.get('id'):
values.pop('id', None)
deployable = models.Deployable() deployable = models.Deployable()
deployable.update(values) deployable.update(values)
@ -226,7 +228,8 @@ class Connection(api.Connection):
def _do_update_deployable(self, context, uuid, values): def _do_update_deployable(self, context, uuid, values):
with _session_for_write(): with _session_for_write():
query = model_query(context, models.Deployable) query = model_query(context, models.Deployable)
query = add_identity_filter(query, uuid) # query = add_identity_filter(query, uuid)
query = query.filter_by(uuid=uuid)
try: try:
ref = query.with_lockmode('update').one() ref = query.with_lockmode('update').one()
except NoResultFound: except NoResultFound:
@ -244,3 +247,193 @@ class Connection(api.Connection):
count = query.delete() count = query.delete()
if count != 1: if count != 1:
raise exception.DeployableNotFound(uuid=uuid) raise exception.DeployableNotFound(uuid=uuid)
def deployable_get_by_filters(self, context,
filters, sort_key='created_at',
sort_dir='desc', limit=None,
marker=None, join_columns=None):
"""Return list of deployables matching all filters sorted by
the sort_key. See deployable_get_by_filters_sort for
more information.
"""
return self.deployable_get_by_filters_sort(context, filters,
limit=limit, marker=marker,
join_columns=join_columns,
sort_keys=[sort_key],
sort_dirs=[sort_dir])
def _exact_deployable_filter(self, query, filters, legal_keys):
"""Applies exact match filtering to a deployable query.
Returns the updated query. Modifies filters argument to remove
filters consumed.
:param query: query to apply filters to
:param filters: dictionary of filters; values that are lists,
tuples, sets, or frozensets cause an 'IN' test to
be performed, while exact matching ('==' operator)
is used for other values
:param legal_keys: list of keys to apply exact filtering to
"""
filter_dict = {}
model = models.Deployable
# Walk through all the keys
for key in legal_keys:
# Skip ones we're not filtering on
if key not in filters:
continue
# OK, filtering on this key; what value do we search for?
value = filters.pop(key)
if isinstance(value, (list, tuple, set, frozenset)):
if not value:
return None
# Looking for values in a list; apply to query directly
column_attr = getattr(model, key)
query = query.filter(column_attr.in_(value))
else:
filter_dict[key] = value
# Apply simple exact matches
if filter_dict:
query = query.filter(*[getattr(models.Deployable, k) == v
for k, v in filter_dict.items()])
return query
def deployable_get_by_filters_sort(self, context, filters, limit=None,
marker=None, join_columns=None,
sort_keys=None, sort_dirs=None):
"""Return deployables that match all filters sorted by the given
keys. Deleted deployables will be returned by default, unless
there's a filter that says otherwise.
"""
if limit == 0:
return []
sort_keys, sort_dirs = self.process_sort_params(sort_keys,
sort_dirs,
default_dir='desc')
query_prefix = model_query(context, models.Deployable)
filters = copy.deepcopy(filters)
exact_match_filter_names = ['uuid', 'name',
'parent_uuid', 'root_uuid',
'pcie_address', 'host',
'board', 'vendor', 'version',
'type', 'assignable', 'instance_uuid',
'availability', 'accelerator_id']
# Filter the query
query_prefix = self._exact_deployable_filter(query_prefix,
filters,
exact_match_filter_names)
if query_prefix is None:
return []
deployables = query_prefix.all()
return deployables
def attribute_create(self, context, key, value):
update_fields = {'key': key, 'value': value}
update_fields['uuid'] = uuidutils.generate_uuid()
attribute = models.Attribute()
attribute.update(update_fields)
with _session_for_write() as session:
try:
session.add(attribute)
session.flush()
except db_exc.DBDuplicateEntry:
raise exception.AttributeAlreadyExists(
uuid=update_fields['uuid'])
return attribute
def attribute_get(self, context, uuid):
query = model_query(
context,
models.Attribute).filter_by(uuid=uuid)
try:
return query.one()
except NoResultFound:
raise exception.AttributeNotFound(uuid=uuid)
def attribute_get_by_deployable_uuid(self, context, deployable_uuid):
query = model_query(
context,
models.Attribute).filter_by(deployable_uuid=deployable_uuid)
try:
return query.all()
except NoResultFound:
raise exception.AttributeNotFound(uuid=uuid)
def attribute_update(self, context, uuid, key, value):
return self._do_update_attribute(context, uuid, key, value)
@oslo_db_api.retry_on_deadlock
def _do_update_attribute(self, context, uuid, key, value):
update_fields = {'key': key, 'value': value}
with _session_for_write():
query = model_query(context, models.Attribute)
query = add_identity_filter(query, uuid)
try:
ref = query.with_lockmode('update').one()
except NoResultFound:
raise exception.AttributeNotFound(uuid=uuid)
ref.update(update_fields)
return ref
def attribute_delete(self, context, uuid):
with _session_for_write():
query = model_query(context, models.Attribute)
query = add_identity_filter(query, uuid)
count = query.delete()
if count != 1:
raise exception.AttributeNotFound(uuid=uuid)
def process_sort_params(self, sort_keys, sort_dirs,
default_keys=['created_at', 'id'],
default_dir='asc'):
# Determine direction to use for when adding default keys
if sort_dirs and len(sort_dirs) != 0:
default_dir_value = sort_dirs[0]
else:
default_dir_value = default_dir
# Create list of keys (do not modify the input list)
if sort_keys:
result_keys = list(sort_keys)
else:
result_keys = []
# If a list of directions is not provided,
# use the default sort direction for all provided keys
if sort_dirs:
result_dirs = []
# Verify sort direction
for sort_dir in sort_dirs:
if sort_dir not in ('asc', 'desc'):
msg = _("Unknown sort direction, must be 'desc' or 'asc'")
raise exception.InvalidInput(reason=msg)
result_dirs.append(sort_dir)
else:
result_dirs = [default_dir_value for _sort_key in result_keys]
# Ensure that the key and direction length match
while len(result_dirs) < len(result_keys):
result_dirs.append(default_dir_value)
# Unless more direction are specified, which is an error
if len(result_dirs) > len(result_keys):
msg = _("Sort direction size exceeds sort key size")
raise exception.InvalidInput(reason=msg)
# Ensure defaults are included
for key in default_keys:
if key not in result_keys:
result_keys.append(key)
result_dirs.append(default_dir_value)
return result_keys, result_dirs

View File

@ -20,6 +20,7 @@ from oslo_db.sqlalchemy import models
import six.moves.urllib.parse as urlparse import six.moves.urllib.parse as urlparse
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, String, Integer, Boolean, ForeignKey, Index from sqlalchemy import Column, String, Integer, Boolean, ForeignKey, Index
from sqlalchemy import Text
from sqlalchemy import schema from sqlalchemy import schema
from cyborg.common import paths from cyborg.common import paths
@ -82,6 +83,7 @@ class Deployable(Base):
schema.UniqueConstraint('uuid', name='uniq_deployables0uuid'), schema.UniqueConstraint('uuid', name='uniq_deployables0uuid'),
Index('deployables_parent_uuid_idx', 'parent_uuid'), Index('deployables_parent_uuid_idx', 'parent_uuid'),
Index('deployables_root_uuid_idx', 'root_uuid'), Index('deployables_root_uuid_idx', 'root_uuid'),
Index('deployables_accelerator_id_idx', 'accelerator_id'),
table_args() table_args()
) )
@ -101,3 +103,23 @@ class Deployable(Base):
assignable = Column(Boolean, nullable=False) assignable = Column(Boolean, nullable=False)
instance_uuid = Column(String(36), nullable=True) instance_uuid = Column(String(36), nullable=True)
availability = Column(String(255), nullable=False) availability = Column(String(255), nullable=False)
accelerator_id = Column(Integer,
ForeignKey('accelerators.id', ondelete="CASCADE"),
nullable=False)
class Attribute(Base):
__tablename__ = 'attributes'
__table_args__ = (
schema.UniqueConstraint('uuid', name='uniq_attributes0uuid'),
Index('attributes_deployable_id_idx', 'deployable_id'),
table_args()
)
id = Column(Integer, primary_key=True)
uuid = Column(String(36), nullable=False)
deployable_id = Column(Integer,
ForeignKey('deployables.id', ondelete="CASCADE"),
nullable=False)
key = Column(Text, nullable=False)
value = Column(Text, nullable=False)

View File

@ -32,6 +32,7 @@ class Accelerator(base.CyborgObject, object_base.VersionedObjectDictCompat):
dbapi = dbapi.get_instance() dbapi = dbapi.get_instance()
fields = { fields = {
'id': object_fields.IntegerField(nullable=False),
'uuid': object_fields.UUIDField(nullable=False), 'uuid': object_fields.UUIDField(nullable=False),
'name': object_fields.StringField(nullable=False), 'name': object_fields.StringField(nullable=False),
'description': object_fields.StringField(nullable=True), 'description': object_fields.StringField(nullable=True),

View File

@ -0,0 +1,84 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_log import log as logging
from oslo_versionedobjects import base as object_base
from cyborg.common import exception
from cyborg.db import api as dbapi
from cyborg.objects import base
from cyborg.objects import fields as object_fields
LOG = logging.getLogger(__name__)
@base.CyborgObjectRegistry.register
class Attribute(base.CyborgObject, object_base.VersionedObjectDictCompat):
# Version 1.0: Initial version
VERSION = '1.0'
dbapi = dbapi.get_instance()
fields = {
'id': fields.IntegerField(nullable=False),
'uuid': object_fields.UUIDField(nullable=False),
'deployable_id': fields.IntegerField(nullable=False),
'key': object_fields.StringField(nullable=False),
'value': object_fields.StringField(nullable=False)
}
def create(self, context):
"""Create a attribute record in the DB."""
if self.deployable_id is None:
raise exception.AttributeInvalid()
values = self.obj_get_changes()
db_attr = self.dbapi.attribute_create(context,
self.key,
self.value)
self._from_db_object(self, db_attr)
@classmethod
def get(cls, context, uuid):
"""Find a DB Deployable and return an Obj Deployable."""
db_attr = cls.dbapi.attribute_get(context, uuid)
obj_attr = cls._from_db_object(cls(context), db_attr)
return obj_attr
@classmethod
def attribute_get_by_deployable_uuid(cls, context, deployable_uuid):
"""Get a Deployable by host."""
db_attr = cls.dbapi.attribute_get_by_deployable_uuid(context,
deployable_uuid)
return cls._from_db_object_list(db_attr, context)
def save(self, context):
"""Update a Deployable record in the DB."""
updates = self.obj_get_changes()
db_attr = self.dbapi.attribute_update(context,
self.uuid,
self.key,
self.value)
self._from_db_object(self, db_attr)
def destroy(self, context):
"""Delete a Deployable from the DB."""
self.dbapi.attribute_delete(context, self.uuid)
self.obj_reset_changes()
def set_key_value_pair(self, set_key, set_value):
self.key = set_key
self.value = set_value

View File

@ -143,3 +143,36 @@ def obj_to_primitive(obj):
return str(obj) return str(obj)
else: else:
return obj return obj
def obj_equal_prims(obj_1, obj_2, ignore=None):
"""Compare two primitives for equivalence ignoring some keys.
This operation tests the primitives of two objects for equivalence.
Object primitives may contain a list identifying fields that have been
changed - this is ignored in the comparison. The ignore parameter lists
any other keys to be ignored.
:param:obj1: The first object in the comparison
:param:obj2: The second object in the comparison
:param:ignore: A list of fields to ignore
:returns: True if the primitives are equal ignoring changes
and specified fields, otherwise False.
"""
def _strip(prim, keys):
if isinstance(prim, dict):
for k in keys:
prim.pop(k, None)
for v in prim.values():
_strip(v, keys)
if isinstance(prim, list):
for v in prim:
_strip(v, keys)
return prim
if ignore is not None:
keys = ['cyborg_object.changes'] + ignore
else:
keys = ['cyborg_object.changes']
prim_1 = _strip(obj_1.obj_to_primitive(), keys)
prim_2 = _strip(obj_2.obj_to_primitive(), keys)
return prim_1 == prim_2

View File

@ -31,8 +31,10 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat):
VERSION = '1.0' VERSION = '1.0'
dbapi = dbapi.get_instance() dbapi = dbapi.get_instance()
attributes_list = []
fields = { fields = {
'id': object_fields.IntegerField(nullable=False),
'uuid': object_fields.UUIDField(nullable=False), 'uuid': object_fields.UUIDField(nullable=False),
'name': object_fields.StringField(nullable=False), 'name': object_fields.StringField(nullable=False),
'parent_uuid': object_fields.UUIDField(nullable=True), 'parent_uuid': object_fields.UUIDField(nullable=True),
@ -53,6 +55,8 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat):
# The id of the virtualized accelerator instance # The id of the virtualized accelerator instance
'availability': object_fields.StringField(nullable=False), 'availability': object_fields.StringField(nullable=False),
# identify the state of acc, e.g released/claimed/... # identify the state of acc, e.g released/claimed/...
'accelerator_id': object_fields.IntegerField(nullable=False)
# Foreign key constrain to reference accelerator table
} }
def _get_parent_root_uuid(self): def _get_parent_root_uuid(self):
@ -71,6 +75,7 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat):
self.root_uuid = self._get_parent_root_uuid() self.root_uuid = self._get_parent_root_uuid()
values = self.obj_get_changes() values = self.obj_get_changes()
db_dep = self.dbapi.deployable_create(context, values) db_dep = self.dbapi.deployable_create(context, values)
self._from_db_object(self, db_dep) self._from_db_object(self, db_dep)
@ -103,3 +108,32 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat):
"""Delete a Deployable from the DB.""" """Delete a Deployable from the DB."""
self.dbapi.deployable_delete(context, self.uuid) self.dbapi.deployable_delete(context, self.uuid)
self.obj_reset_changes() self.obj_reset_changes()
def add_attribute(self, attribute):
"""add a attribute object to the attribute_list.
If the attribute already exists, it will update the value,
otherwise, the vf will be appended to the list
"""
if not isinstance(vf, VirtualFunction) or vf.type != 'vf':
raise exception.InvalidDeployType()
for exist_vf in self.virtual_function_list:
if base.obj_equal_prims(vf, exist_vf):
LOG.warning("The vf already exists")
return None
@classmethod
def get_by_filter(cls, context,
filters, sort_key='created_at',
sort_dir='desc', limit=None,
marker=None, join=None):
obj_dpl_list = []
db_dpl_list = cls.dbapi.deployable_get_by_filters(context, filters,
sort_key=sort_key,
sort_dir=sort_dir,
limit=limit,
marker=marker,
join_columns=join)
for db_dpl in db_dpl_list:
obj_dpl = cls._from_db_object(cls(context), db_dpl)
obj_dpl_list.append(obj_dpl)
return obj_dpl_list

View File

@ -0,0 +1,137 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import copy
from oslo_log import log as logging
from oslo_versionedobjects import base as object_base
from cyborg.common import exception
from cyborg.db import api as dbapi
from cyborg.objects import base
from cyborg.objects import fields as object_fields
from cyborg.objects.deployable import Deployable
from cyborg.objects.virtual_function import VirtualFunction
LOG = logging.getLogger(__name__)
@base.CyborgObjectRegistry.register
class PhysicalFunction(Deployable):
# Version 1.0: Initial version
VERSION = '1.0'
virtual_function_list = []
def create(self, context):
# To ensure the creating type is PF
if self.type != 'pf':
raise exception.InvalidDeployType()
super(PhysicalFunction, self).create(context)
def save(self, context):
"""In addition to save the pf, it should also save the
vfs associated with this pf
"""
# To ensure the saving type is PF
if self.type != 'pf':
raise exception.InvalidDeployType()
for exist_vf in self.virtual_function_list:
exist_vf.save(context)
super(PhysicalFunction, self).save(context)
def add_vf(self, vf):
"""add a vf object to the virtual_function_list.
If the vf already exists, it will ignore,
otherwise, the vf will be appended to the list
"""
if not isinstance(vf, VirtualFunction) or vf.type != 'vf':
raise exception.InvalidDeployType()
for exist_vf in self.virtual_function_list:
if base.obj_equal_prims(vf, exist_vf):
LOG.warning("The vf already exists")
return None
vf.parent_uuid = self.uuid
vf.root_uuid = self.root_uuid
vf_copy = copy.deepcopy(vf)
self.virtual_function_list.append(vf_copy)
def delete_vf(self, context, vf):
"""remove a vf from the virtual_function_list
if the vf does not exist, ignore it
"""
for idx, exist_vf in self.virtual_function_list:
if base.obj_equal_prims(vf, exist_vf):
removed_vf = self.virtual_function_list.pop(idx)
removed_vf.destroy(context)
return
LOG.warning("The removing vf does not exist!")
def destroy(self, context):
"""Delete a the pf from the DB."""
del self.virtual_function_list[:]
super(PhysicalFunction, self).destroy(context)
@classmethod
def get(cls, context, uuid):
"""Find a DB Physical Function and return an Obj Physical Function.
In addition, it will also finds all the Virtual Functions associated
with this Physical Function and place them in virtual_function_list
"""
db_pf = cls.dbapi.deployable_get(context, uuid)
obj_pf = cls._from_db_object(cls(context), db_pf)
pf_uuid = obj_pf.uuid
query = {"parent_uuid": pf_uuid, "type": "vf"}
db_vf_list = cls.dbapi.deployable_get_by_filters(context, query)
for db_vf in db_vf_list:
obj_vf = VirtualFunction.get(context, db_vf.uuid)
obj_pf.virtual_function_list.append(obj_vf)
return obj_pf
@classmethod
def get_by_filter(cls, context,
filters, sort_key='created_at',
sort_dir='desc', limit=None,
marker=None, join=None):
obj_dpl_list = []
filters['type'] = 'pf'
db_dpl_list = cls.dbapi.deployable_get_by_filters(context, filters,
sort_key=sort_key,
sort_dir=sort_dir,
limit=limit,
marker=marker,
join_columns=join)
for db_dpl in db_dpl_list:
obj_dpl = cls._from_db_object(cls(context), db_dpl)
query = {"parent_uuid": obj_dpl.uuid}
vf_get_list = VirtualFunction.get_by_filter(context,
query)
obj_dpl.virtual_function_list = vf_get_list
obj_dpl_list.append(obj_dpl)
return obj_dpl_list
@classmethod
def _from_db_object(cls, obj, db_obj):
"""Converts a physical function to a formal object.
:param obj: An object of the class.
:param db_obj: A DB model of the object
:return: The object of the class with the database entity added
"""
obj = Deployable._from_db_object(obj, db_obj)
if cls is PhysicalFunction:
obj.virtual_function_list = []
return obj

View File

@ -0,0 +1,61 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_log import log as logging
from oslo_versionedobjects import base as object_base
from cyborg.common import exception
from cyborg.db import api as dbapi
from cyborg.objects import base
from cyborg.objects import fields as object_fields
from cyborg.objects.deployable import Deployable
LOG = logging.getLogger(__name__)
@base.CyborgObjectRegistry.register
class VirtualFunction(Deployable):
# Version 1.0: Initial version
VERSION = '1.0'
def create(self, context):
# To ensure the creating type is VF
if self.type != 'vf':
raise exception.InvalidDeployType()
super(VirtualFunction, self).create(context)
def save(self, context):
# To ensure the saving type is VF
if self.type != 'vf':
raise exception.InvalidDeployType()
super(VirtualFunction, self).save(context)
@classmethod
def get_by_filter(cls, context,
filters, sort_key='created_at',
sort_dir='desc', limit=None,
marker=None, join=None):
obj_dpl_list = []
filters['type'] = 'vf'
db_dpl_list = cls.dbapi.deployable_get_by_filters(context, filters,
sort_key=sort_key,
sort_dir=sort_dir,
limit=limit,
marker=marker,
join_columns=join)
for db_dpl in db_dpl_list:
obj_dpl = cls._from_db_object(cls(context), db_dpl)
obj_dpl_list.append(obj_dpl)
return obj_dpl_list

View File

@ -38,7 +38,8 @@ def fake_db_deployable(**updates):
'type': 'pf', 'type': 'pf',
'assignable': True, 'assignable': True,
'instance_uuid': None, 'instance_uuid': None,
'availability': 'Available' 'availability': 'Available',
'accelerator_id': 1
} }
for name, field in objects.Deployable.fields.items(): for name, field in objects.Deployable.fields.items():

View File

@ -0,0 +1,72 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
from oslo_serialization import jsonutils
from oslo_utils import uuidutils
from cyborg import objects
from cyborg.objects import fields
from cyborg.objects import physical_function
def fake_db_physical_function(**updates):
root_uuid = uuidutils.generate_uuid()
db_physical_function = {
'id': 1,
'deleted': False,
'uuid': root_uuid,
'name': 'dp_name',
'parent_uuid': None,
'root_uuid': root_uuid,
'pcie_address': '00:7f:0b.2',
'host': 'host_name',
'board': 'KU115',
'vendor': 'Xilinx',
'version': '1.0',
'type': 'pf',
'assignable': True,
'instance_uuid': None,
'availability': 'Available',
'accelerator_id': 1
}
for name, field in physical_function.PhysicalFunction.fields.items():
if name in db_physical_function:
continue
if field.nullable:
db_physical_function[name] = None
elif field.default != fields.UnspecifiedDefault:
db_physical_function[name] = field.default
else:
raise Exception('fake_db_physical_function needs help with %s'
% name)
if updates:
db_physical_function.update(updates)
return db_physical_function
def fake_physical_function_obj(context, obj_pf_class=None, **updates):
if obj_pf_class is None:
obj_pf_class = objects.VirtualFunction
expected_attrs = updates.pop('expected_attrs', None)
pf = obj_pf_class._from_db_object(context,
obj_pf_class(),
fake_db_physical_function(**updates),
expected_attrs=expected_attrs)
pf.obj_reset_changes()
return vf

View File

@ -0,0 +1,72 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
from oslo_serialization import jsonutils
from oslo_utils import uuidutils
from cyborg import objects
from cyborg.objects import fields
from cyborg.objects import virtual_function
def fake_db_virtual_function(**updates):
root_uuid = uuidutils.generate_uuid()
db_virtual_function = {
'id': 1,
'deleted': False,
'uuid': root_uuid,
'name': 'dp_name',
'parent_uuid': None,
'root_uuid': root_uuid,
'pcie_address': '00:7f:bb.2',
'host': 'host_name',
'board': 'KU115',
'vendor': 'Xilinx',
'version': '1.0',
'type': 'vf',
'assignable': True,
'instance_uuid': None,
'availability': 'Available',
'accelerator_id': 1
}
for name, field in virtual_function.VirtualFunction.fields.items():
if name in db_virtual_function:
continue
if field.nullable:
db_virtual_function[name] = None
elif field.default != fields.UnspecifiedDefault:
db_virtual_function[name] = field.default
else:
raise Exception('fake_db_virtual_function needs help with %s'
% name)
if updates:
db_virtual_function.update(updates)
return db_virtual_function
def fake_virtual_function_obj(context, obj_vf_class=None, **updates):
if obj_vf_class is None:
obj_vf_class = objects.VirtualFunction
expected_attrs = updates.pop('expected_attrs', None)
vf = obj_vf_class._from_db_object(context,
obj_vf_class(),
fake_db_virtual_function(**updates),
expected_attrs=expected_attrs)
vf.obj_reset_changes()
return vf

View File

@ -26,6 +26,7 @@ from cyborg.common import exception
from cyborg import objects from cyborg import objects
from cyborg.objects import base from cyborg.objects import base
from cyborg import tests as test from cyborg import tests as test
from cyborg.tests.unit import fake_accelerator
from cyborg.tests.unit import fake_deployable from cyborg.tests.unit import fake_deployable
from cyborg.tests.unit.objects import test_objects from cyborg.tests.unit.objects import test_objects
from cyborg.tests.unit.db.base import DbTestCase from cyborg.tests.unit.db.base import DbTestCase
@ -34,49 +35,95 @@ from cyborg.tests.unit.db.base import DbTestCase
class _TestDeployableObject(DbTestCase): class _TestDeployableObject(DbTestCase):
@property @property
def fake_deployable(self): def fake_deployable(self):
db_deploy = fake_deployable.fake_db_deployable(id=2) db_deploy = fake_deployable.fake_db_deployable(id=1)
return db_deploy return db_deploy
@mock.patch.object(db.api.Connection, 'deployable_create') @property
def test_create(self, mock_create): def fake_accelerator(self):
mock_create.return_value = self.fake_deployable db_acc = fake_accelerator.fake_db_accelerator(id=2)
return db_acc
def test_create(self):
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context, dpl = objects.Deployable(context=self.context,
**mock_create.return_value) **db_dpl)
dpl.accelerator_id = acc_get.id
dpl.create(self.context) dpl.create(self.context)
self.assertEqual(self.fake_deployable['id'], dpl.id) self.assertEqual(db_dpl['uuid'], dpl.uuid)
@mock.patch.object(db.api.Connection, 'deployable_get') def test_get(self):
def test_get(self, mock_get): db_acc = self.fake_accelerator
mock_get.return_value = self.fake_deployable acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context, dpl = objects.Deployable(context=self.context,
**mock_get.return_value) **db_dpl)
dpl.accelerator_id = acc_get.id
dpl.create(self.context) dpl.create(self.context)
dpl_get = objects.Deployable.get(self.context, dpl['uuid']) dpl_get = objects.Deployable.get(self.context, dpl.uuid)
self.assertEqual(dpl_get.uuid, dpl.uuid) self.assertEqual(dpl_get.uuid, dpl.uuid)
@mock.patch.object(db.api.Connection, 'deployable_update') def test_get_by_filter(self):
def test_save(self, mock_save): db_acc = self.fake_accelerator
mock_save.return_value = self.fake_deployable acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context, dpl = objects.Deployable(context=self.context,
**mock_save.return_value) **db_dpl)
dpl.accelerator_id = acc_get.id
dpl.create(self.context)
query = {"uuid": dpl['uuid']}
dpl_get_list = objects.Deployable.get_by_filter(self.context, query)
self.assertEqual(dpl_get_list[0].uuid, dpl.uuid)
def test_save(self):
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**db_dpl)
dpl.accelerator_id = acc_get.id
dpl.create(self.context) dpl.create(self.context)
dpl.host = 'test_save' dpl.host = 'test_save'
dpl.save(self.context) dpl.save(self.context)
dpl_get = objects.Deployable.get(self.context, dpl['uuid']) dpl_get = objects.Deployable.get(self.context, dpl.uuid)
self.assertEqual(dpl_get.host, 'test_save') self.assertEqual(dpl_get.host, 'test_save')
@mock.patch.object(db.api.Connection, 'deployable_delete') def test_destroy(self):
def test_destroy(self, mock_destroy): db_acc = self.fake_accelerator
mock_destroy.return_value = self.fake_deployable acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context, dpl = objects.Deployable(context=self.context,
**mock_destroy.return_value) **db_dpl)
dpl.accelerator_id = acc_get.id
dpl.create(self.context) dpl.create(self.context)
self.assertEqual(self.fake_deployable['id'], dpl.id) self.assertEqual(db_dpl['uuid'], dpl.uuid)
dpl.destroy(self.context) dpl.destroy(self.context)
self.assertRaises(exception.DeployableNotFound, self.assertRaises(exception.DeployableNotFound,
objects.Deployable.get, self.context, objects.Deployable.get, self.context,
dpl['uuid']) dpl.uuid)
class TestDeployableObject(test_objects._LocalTest, class TestDeployableObject(test_objects._LocalTest,

View File

@ -17,7 +17,6 @@ import copy
import datetime import datetime
import inspect import inspect
import os import os
import pprint
import fixtures import fixtures
import mock import mock

View File

@ -0,0 +1,186 @@
import mock
import netaddr
from oslo_db import exception as db_exc
from oslo_serialization import jsonutils
from oslo_utils import timeutils
from oslo_context import context
from cyborg import db
from cyborg.common import exception
from cyborg import objects
from cyborg.objects import base
from cyborg import tests as test
from cyborg.tests.unit import fake_physical_function
from cyborg.tests.unit import fake_virtual_function
from cyborg.tests.unit import fake_accelerator
from cyborg.tests.unit.objects import test_objects
from cyborg.tests.unit.db.base import DbTestCase
class _TestPhysicalFunctionObject(DbTestCase):
@property
def fake_physical_function(self):
db_pf = fake_physical_function.fake_db_physical_function(id=1)
return db_pf
@property
def fake_virtual_function(self):
db_vf = fake_virtual_function.fake_db_virtual_function(id=3)
return db_vf
@property
def fake_accelerator(self):
db_acc = fake_accelerator.fake_db_accelerator(id=2)
return db_acc
def test_create(self):
db_pf = self.fake_physical_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
self.assertEqual(db_pf['uuid'], pf.uuid)
def test_get(self):
db_pf = self.fake_physical_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
self.assertEqual(pf_get.uuid, pf.uuid)
def test_get_by_filter(self):
db_acc = self.fake_accelerator
db_pf = self.fake_physical_function
db_vf = self.fake_virtual_function
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = pf_get.accelerator_id
vf.create(self.context)
vf_get = objects.VirtualFunction.get(self.context, vf.uuid)
pf_get.add_vf(vf_get)
pf_get.save(self.context)
query = {"vendor": pf['vendor']}
pf_get_list = objects.PhysicalFunction.get_by_filter(self.context,
query)
self.assertEqual(len(pf_get_list), 1)
self.assertEqual(pf_get_list[0].uuid, pf.uuid)
self.assertEqual(objects.PhysicalFunction, type(pf_get_list[0]))
self.assertEqual(objects.VirtualFunction,
type(pf_get_list[0].virtual_function_list[0]))
self.assertEqual(pf_get_list[0].virtual_function_list[0].uuid,
vf.uuid)
def test_save(self):
db_pf = self.fake_physical_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
pf_get.host = 'test_save'
pf_get.save(self.context)
pf_get_2 = objects.PhysicalFunction.get(self.context, pf.uuid)
self.assertEqual(pf_get_2.host, 'test_save')
def test_destroy(self):
db_pf = self.fake_physical_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
self.assertEqual(db_pf['uuid'], pf_get.uuid)
pf_get.destroy(self.context)
self.assertRaises(exception.DeployableNotFound,
objects.PhysicalFunction.get, self.context,
pf_get['uuid'])
def test_add_vf(self):
db_pf = self.fake_physical_function
db_vf = self.fake_virtual_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = pf_get.accelerator_id
vf.create(self.context)
vf_get = objects.VirtualFunction.get(self.context, vf.uuid)
pf_get.add_vf(vf_get)
pf_get.save(self.context)
pf_get_2 = objects.PhysicalFunction.get(self.context, pf.uuid)
self.assertEqual(db_vf['uuid'],
pf_get_2.virtual_function_list[0].uuid)
class TestPhysicalFunctionObject(test_objects._LocalTest,
_TestPhysicalFunctionObject):
def _test_save_objectfield_fk_constraint_fails(self, foreign_key,
expected_exception):
error = db_exc.DBReferenceError('table', 'constraint', foreign_key,
'key_table')
# Prevent lazy-loading any fields, results in InstanceNotFound
pf = fake_physical_function.physical_function_obj(self.context)
fields_with_save_methods = [field for field in pf.fields
if hasattr(pf, '_save_%s' % field)]
for field in fields_with_save_methods:
@mock.patch.object(pf, '_save_%s' % field)
@mock.patch.object(pf, 'obj_attr_is_set')
def _test(mock_is_set, mock_save_field):
mock_is_set.return_value = True
mock_save_field.side_effect = error
pf.obj_reset_changes(fields=[field])
pf._changed_fields.add(field)
self.assertRaises(expected_exception, pf.save)
pf.obj_reset_changes(fields=[field])
_test()

View File

@ -0,0 +1,202 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
import mock
import netaddr
from oslo_db import exception as db_exc
from oslo_serialization import jsonutils
from oslo_utils import timeutils
from oslo_context import context
from cyborg import db
from cyborg.common import exception
from cyborg import objects
from cyborg.objects import base
from cyborg import tests as test
from cyborg.tests.unit import fake_physical_function
from cyborg.tests.unit import fake_virtual_function
from cyborg.tests.unit import fake_accelerator
from cyborg.tests.unit.objects import test_objects
from cyborg.tests.unit.db.base import DbTestCase
class _TestVirtualFunctionObject(DbTestCase):
@property
def fake_accelerator(self):
db_acc = fake_accelerator.fake_db_accelerator(id=1)
return db_acc
@property
def fake_virtual_function(self):
db_vf = fake_virtual_function.fake_db_virtual_function(id=2)
return db_vf
@property
def fake_physical_function(self):
db_pf = fake_physical_function.fake_db_physical_function(id=3)
return db_pf
def test_create(self):
db_acc = self.fake_accelerator
db_vf = self.fake_virtual_function
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = acc_get.id
vf.create(self.context)
self.assertEqual(db_vf['uuid'], vf.uuid)
def test_get(self):
db_vf = self.fake_virtual_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = acc_get.id
vf.create(self.context)
vf_get = objects.VirtualFunction.get(self.context, vf.uuid)
self.assertEqual(vf_get.uuid, vf.uuid)
def test_get_by_filter(self):
db_acc = self.fake_accelerator
db_pf = self.fake_physical_function
db_vf = self.fake_virtual_function
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = pf_get.accelerator_id
vf.create(self.context)
vf_get = objects.VirtualFunction.get(self.context, vf.uuid)
pf_get.add_vf(vf_get)
pf_get.save(self.context)
query = {"vendor": pf_get['vendor']}
vf_get_list = objects.VirtualFunction.get_by_filter(self.context,
query)
self.assertEqual(len(vf_get_list), 1)
self.assertEqual(vf_get_list[0].uuid, vf.uuid)
self.assertEqual(objects.VirtualFunction, type(vf_get_list[0]))
self.assertEqual(1, 1)
def test_get_by_filter2(self):
db_acc = self.fake_accelerator
db_pf = self.fake_physical_function
db_vf = self.fake_virtual_function
db_pf2 = self.fake_physical_function
db_vf2 = self.fake_virtual_function
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
pf2 = objects.PhysicalFunction(context=self.context,
**db_pf2)
pf2.accelerator_id = acc_get.id
pf2.create(self.context)
pf_get2 = objects.PhysicalFunction.get(self.context, pf2.uuid)
query = {"uuid": pf2.uuid}
pf_get_list = objects.PhysicalFunction.get_by_filter(self.context,
query)
self.assertEqual(1, 1)
def test_save(self):
db_vf = self.fake_virtual_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = acc_get.id
vf.create(self.context)
vf_get = objects.VirtualFunction.get(self.context, vf.uuid)
vf_get.host = 'test_save'
vf_get.save(self.context)
vf_get_2 = objects.VirtualFunction.get(self.context, vf.uuid)
self.assertEqual(vf_get_2.host, 'test_save')
def test_destroy(self):
db_vf = self.fake_virtual_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = acc_get.id
vf.create(self.context)
vf_get = objects.VirtualFunction.get(self.context, vf.uuid)
self.assertEqual(db_vf['uuid'], vf_get.uuid)
vf_get.destroy(self.context)
self.assertRaises(exception.DeployableNotFound,
objects.VirtualFunction.get, self.context,
vf_get['uuid'])
class TestVirtualFunctionObject(test_objects._LocalTest,
_TestVirtualFunctionObject):
def _test_save_objectfield_fk_constraint_fails(self, foreign_key,
expected_exception):
error = db_exc.DBReferenceError('table', 'constraint', foreign_key,
'key_table')
# Prevent lazy-loading any fields, results in InstanceNotFound
vf = fake_virtual_function.virtual_function_obj(self.context)
fields_with_save_methods = [field for field in vf.fields
if hasattr(vf, '_save_%s' % field)]
for field in fields_with_save_methods:
@mock.patch.object(vf, '_save_%s' % field)
@mock.patch.object(vf, 'obj_attr_is_set')
def _test(mock_is_set, mock_save_field):
mock_is_set.return_value = True
mock_save_field.side_effect = error
vf.obj_reset_changes(fields=[field])
vf._changed_fields.add(field)
self.assertRaises(expected_exception, vf.save)
vf.obj_reset_changes(fields=[field])
_test()