Merge "Implemented the Objects and APIs for vf/pf"

This commit is contained in:
Zuul 2018-04-19 18:00:34 +00:00 committed by Gerrit Code Review
commit ce0b6e5a95
18 changed files with 1235 additions and 26 deletions

View File

@ -147,6 +147,10 @@ class DeployableNotFound(NotFound):
_msg_fmt = _("Deployable %(uuid)s could not be found.")
class InvalidDeployType(CyborgException):
_msg_fmt = _("Deployable have an invalid type")
class Conflict(CyborgException):
_msg_fmt = _('Conflict.')
code = http_client.CONFLICT
@ -180,3 +184,15 @@ class PlacementInventoryUpdateConflict(Conflict):
class ObjectActionError(CyborgException):
_msg_fmt = _('Object action %(action)s failed because: %(reason)s')
class AttributeNotFound(NotFound):
_msg_fmt = _("Attribute %(uuid)s could not be found.")
class AttributeInvalid(CyborgException):
_msg_fmt = _("Attribute is invalid")
class AttributeAlreadyExists(CyborgException):
_msg_fmt = _("Attribute with uuid %(uuid)s already exists.")

View File

@ -87,3 +87,31 @@ class Connection(object):
@abc.abstractmethod
def deployable_delete(self, context, uuid):
"""Delete a deployable."""
@abc.abstractmethod
def deployable_get_by_filters(self, context,
filters, sort_key='created_at',
sort_dir='desc', limit=None,
marker=None, columns_to_join=None):
"""Get requested deployable by filters."""
# attributes
@abc.abstractmethod
def attribute_create(self, context, key, value):
"""Create a new attribute."""
@abc.abstractmethod
def attribute_get(self, context, uuid):
"""Get requested attribute."""
@abc.abstractmethod
def attribute_get_by_deployable_uuid(self, context, deployable_uuid):
"""Get requested deployable by deployable uuid."""
@abc.abstractmethod
def attribute_update(self, context, uuid, key, value):
"""Update an attribute's key value pair."""
@abc.abstractmethod
def attribute_delete(self, context, uuid):
"""Delete an attribute."""

View File

@ -71,6 +71,9 @@ def upgrade():
sa.Column('assignable', sa.Boolean(), nullable=False),
sa.Column('instance_uuid', sa.String(length=36), nullable=True),
sa.Column('availability', sa.Text(), nullable=False),
sa.Column('accelerator_id', sa.Integer(),
sa.ForeignKey('accelerators.id', ondelete="CASCADE"),
nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('uuid', name='uniq_deployables0uuid'),
sa.Index('deployables_parent_uuid_idx', 'parent_uuid'),
@ -78,3 +81,21 @@ def upgrade():
mysql_ENGINE='InnoDB',
mysql_DEFAULT_CHARSET='UTF8'
)
op.create_table(
'attributes',
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('uuid', sa.String(length=36), nullable=False),
sa.Column('deployable_id', sa.Integer(),
sa.ForeignKey('deployables.id', ondelete="CASCADE"),
nullable=False),
sa.Column('key', sa.Text(), nullable=False),
sa.Column('value', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('uuid', name='uniq_attributes0uuid'),
sa.Index('attributes_deployable_id_idx', 'deployable_id'),
mysql_ENGINE='InnoDB',
mysql_DEFAULT_CHARSET='UTF8'
)

View File

@ -16,6 +16,7 @@
"""SQLAlchemy storage backend."""
import threading
import copy
from oslo_db import api as oslo_db_api
from oslo_db import exception as db_exc
@ -180,7 +181,8 @@ class Connection(api.Connection):
def deployable_create(self, context, values):
if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid()
if values.get('id'):
values.pop('id', None)
deployable = models.Deployable()
deployable.update(values)
@ -226,7 +228,8 @@ class Connection(api.Connection):
def _do_update_deployable(self, context, uuid, values):
with _session_for_write():
query = model_query(context, models.Deployable)
query = add_identity_filter(query, uuid)
# query = add_identity_filter(query, uuid)
query = query.filter_by(uuid=uuid)
try:
ref = query.with_lockmode('update').one()
except NoResultFound:
@ -244,3 +247,193 @@ class Connection(api.Connection):
count = query.delete()
if count != 1:
raise exception.DeployableNotFound(uuid=uuid)
def deployable_get_by_filters(self, context,
filters, sort_key='created_at',
sort_dir='desc', limit=None,
marker=None, join_columns=None):
"""Return list of deployables matching all filters sorted by
the sort_key. See deployable_get_by_filters_sort for
more information.
"""
return self.deployable_get_by_filters_sort(context, filters,
limit=limit, marker=marker,
join_columns=join_columns,
sort_keys=[sort_key],
sort_dirs=[sort_dir])
def _exact_deployable_filter(self, query, filters, legal_keys):
"""Applies exact match filtering to a deployable query.
Returns the updated query. Modifies filters argument to remove
filters consumed.
:param query: query to apply filters to
:param filters: dictionary of filters; values that are lists,
tuples, sets, or frozensets cause an 'IN' test to
be performed, while exact matching ('==' operator)
is used for other values
:param legal_keys: list of keys to apply exact filtering to
"""
filter_dict = {}
model = models.Deployable
# Walk through all the keys
for key in legal_keys:
# Skip ones we're not filtering on
if key not in filters:
continue
# OK, filtering on this key; what value do we search for?
value = filters.pop(key)
if isinstance(value, (list, tuple, set, frozenset)):
if not value:
return None
# Looking for values in a list; apply to query directly
column_attr = getattr(model, key)
query = query.filter(column_attr.in_(value))
else:
filter_dict[key] = value
# Apply simple exact matches
if filter_dict:
query = query.filter(*[getattr(models.Deployable, k) == v
for k, v in filter_dict.items()])
return query
def deployable_get_by_filters_sort(self, context, filters, limit=None,
marker=None, join_columns=None,
sort_keys=None, sort_dirs=None):
"""Return deployables that match all filters sorted by the given
keys. Deleted deployables will be returned by default, unless
there's a filter that says otherwise.
"""
if limit == 0:
return []
sort_keys, sort_dirs = self.process_sort_params(sort_keys,
sort_dirs,
default_dir='desc')
query_prefix = model_query(context, models.Deployable)
filters = copy.deepcopy(filters)
exact_match_filter_names = ['uuid', 'name',
'parent_uuid', 'root_uuid',
'pcie_address', 'host',
'board', 'vendor', 'version',
'type', 'assignable', 'instance_uuid',
'availability', 'accelerator_id']
# Filter the query
query_prefix = self._exact_deployable_filter(query_prefix,
filters,
exact_match_filter_names)
if query_prefix is None:
return []
deployables = query_prefix.all()
return deployables
def attribute_create(self, context, key, value):
update_fields = {'key': key, 'value': value}
update_fields['uuid'] = uuidutils.generate_uuid()
attribute = models.Attribute()
attribute.update(update_fields)
with _session_for_write() as session:
try:
session.add(attribute)
session.flush()
except db_exc.DBDuplicateEntry:
raise exception.AttributeAlreadyExists(
uuid=update_fields['uuid'])
return attribute
def attribute_get(self, context, uuid):
query = model_query(
context,
models.Attribute).filter_by(uuid=uuid)
try:
return query.one()
except NoResultFound:
raise exception.AttributeNotFound(uuid=uuid)
def attribute_get_by_deployable_uuid(self, context, deployable_uuid):
query = model_query(
context,
models.Attribute).filter_by(deployable_uuid=deployable_uuid)
try:
return query.all()
except NoResultFound:
raise exception.AttributeNotFound(uuid=uuid)
def attribute_update(self, context, uuid, key, value):
return self._do_update_attribute(context, uuid, key, value)
@oslo_db_api.retry_on_deadlock
def _do_update_attribute(self, context, uuid, key, value):
update_fields = {'key': key, 'value': value}
with _session_for_write():
query = model_query(context, models.Attribute)
query = add_identity_filter(query, uuid)
try:
ref = query.with_lockmode('update').one()
except NoResultFound:
raise exception.AttributeNotFound(uuid=uuid)
ref.update(update_fields)
return ref
def attribute_delete(self, context, uuid):
with _session_for_write():
query = model_query(context, models.Attribute)
query = add_identity_filter(query, uuid)
count = query.delete()
if count != 1:
raise exception.AttributeNotFound(uuid=uuid)
def process_sort_params(self, sort_keys, sort_dirs,
default_keys=['created_at', 'id'],
default_dir='asc'):
# Determine direction to use for when adding default keys
if sort_dirs and len(sort_dirs) != 0:
default_dir_value = sort_dirs[0]
else:
default_dir_value = default_dir
# Create list of keys (do not modify the input list)
if sort_keys:
result_keys = list(sort_keys)
else:
result_keys = []
# If a list of directions is not provided,
# use the default sort direction for all provided keys
if sort_dirs:
result_dirs = []
# Verify sort direction
for sort_dir in sort_dirs:
if sort_dir not in ('asc', 'desc'):
msg = _("Unknown sort direction, must be 'desc' or 'asc'")
raise exception.InvalidInput(reason=msg)
result_dirs.append(sort_dir)
else:
result_dirs = [default_dir_value for _sort_key in result_keys]
# Ensure that the key and direction length match
while len(result_dirs) < len(result_keys):
result_dirs.append(default_dir_value)
# Unless more direction are specified, which is an error
if len(result_dirs) > len(result_keys):
msg = _("Sort direction size exceeds sort key size")
raise exception.InvalidInput(reason=msg)
# Ensure defaults are included
for key in default_keys:
if key not in result_keys:
result_keys.append(key)
result_dirs.append(default_dir_value)
return result_keys, result_dirs

View File

@ -20,6 +20,7 @@ from oslo_db.sqlalchemy import models
import six.moves.urllib.parse as urlparse
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, String, Integer, Boolean, ForeignKey, Index
from sqlalchemy import Text
from sqlalchemy import schema
from cyborg.common import paths
@ -82,6 +83,7 @@ class Deployable(Base):
schema.UniqueConstraint('uuid', name='uniq_deployables0uuid'),
Index('deployables_parent_uuid_idx', 'parent_uuid'),
Index('deployables_root_uuid_idx', 'root_uuid'),
Index('deployables_accelerator_id_idx', 'accelerator_id'),
table_args()
)
@ -101,3 +103,23 @@ class Deployable(Base):
assignable = Column(Boolean, nullable=False)
instance_uuid = Column(String(36), nullable=True)
availability = Column(String(255), nullable=False)
accelerator_id = Column(Integer,
ForeignKey('accelerators.id', ondelete="CASCADE"),
nullable=False)
class Attribute(Base):
__tablename__ = 'attributes'
__table_args__ = (
schema.UniqueConstraint('uuid', name='uniq_attributes0uuid'),
Index('attributes_deployable_id_idx', 'deployable_id'),
table_args()
)
id = Column(Integer, primary_key=True)
uuid = Column(String(36), nullable=False)
deployable_id = Column(Integer,
ForeignKey('deployables.id', ondelete="CASCADE"),
nullable=False)
key = Column(Text, nullable=False)
value = Column(Text, nullable=False)

View File

@ -32,6 +32,7 @@ class Accelerator(base.CyborgObject, object_base.VersionedObjectDictCompat):
dbapi = dbapi.get_instance()
fields = {
'id': object_fields.IntegerField(nullable=False),
'uuid': object_fields.UUIDField(nullable=False),
'name': object_fields.StringField(nullable=False),
'description': object_fields.StringField(nullable=True),

View File

@ -0,0 +1,84 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_log import log as logging
from oslo_versionedobjects import base as object_base
from cyborg.common import exception
from cyborg.db import api as dbapi
from cyborg.objects import base
from cyborg.objects import fields as object_fields
LOG = logging.getLogger(__name__)
@base.CyborgObjectRegistry.register
class Attribute(base.CyborgObject, object_base.VersionedObjectDictCompat):
# Version 1.0: Initial version
VERSION = '1.0'
dbapi = dbapi.get_instance()
fields = {
'id': fields.IntegerField(nullable=False),
'uuid': object_fields.UUIDField(nullable=False),
'deployable_id': fields.IntegerField(nullable=False),
'key': object_fields.StringField(nullable=False),
'value': object_fields.StringField(nullable=False)
}
def create(self, context):
"""Create a attribute record in the DB."""
if self.deployable_id is None:
raise exception.AttributeInvalid()
values = self.obj_get_changes()
db_attr = self.dbapi.attribute_create(context,
self.key,
self.value)
self._from_db_object(self, db_attr)
@classmethod
def get(cls, context, uuid):
"""Find a DB Deployable and return an Obj Deployable."""
db_attr = cls.dbapi.attribute_get(context, uuid)
obj_attr = cls._from_db_object(cls(context), db_attr)
return obj_attr
@classmethod
def attribute_get_by_deployable_uuid(cls, context, deployable_uuid):
"""Get a Deployable by host."""
db_attr = cls.dbapi.attribute_get_by_deployable_uuid(context,
deployable_uuid)
return cls._from_db_object_list(db_attr, context)
def save(self, context):
"""Update a Deployable record in the DB."""
updates = self.obj_get_changes()
db_attr = self.dbapi.attribute_update(context,
self.uuid,
self.key,
self.value)
self._from_db_object(self, db_attr)
def destroy(self, context):
"""Delete a Deployable from the DB."""
self.dbapi.attribute_delete(context, self.uuid)
self.obj_reset_changes()
def set_key_value_pair(self, set_key, set_value):
self.key = set_key
self.value = set_value

View File

@ -143,3 +143,36 @@ def obj_to_primitive(obj):
return str(obj)
else:
return obj
def obj_equal_prims(obj_1, obj_2, ignore=None):
"""Compare two primitives for equivalence ignoring some keys.
This operation tests the primitives of two objects for equivalence.
Object primitives may contain a list identifying fields that have been
changed - this is ignored in the comparison. The ignore parameter lists
any other keys to be ignored.
:param:obj1: The first object in the comparison
:param:obj2: The second object in the comparison
:param:ignore: A list of fields to ignore
:returns: True if the primitives are equal ignoring changes
and specified fields, otherwise False.
"""
def _strip(prim, keys):
if isinstance(prim, dict):
for k in keys:
prim.pop(k, None)
for v in prim.values():
_strip(v, keys)
if isinstance(prim, list):
for v in prim:
_strip(v, keys)
return prim
if ignore is not None:
keys = ['cyborg_object.changes'] + ignore
else:
keys = ['cyborg_object.changes']
prim_1 = _strip(obj_1.obj_to_primitive(), keys)
prim_2 = _strip(obj_2.obj_to_primitive(), keys)
return prim_1 == prim_2

View File

@ -31,8 +31,10 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat):
VERSION = '1.0'
dbapi = dbapi.get_instance()
attributes_list = []
fields = {
'id': object_fields.IntegerField(nullable=False),
'uuid': object_fields.UUIDField(nullable=False),
'name': object_fields.StringField(nullable=False),
'parent_uuid': object_fields.UUIDField(nullable=True),
@ -53,6 +55,8 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat):
# The id of the virtualized accelerator instance
'availability': object_fields.StringField(nullable=False),
# identify the state of acc, e.g released/claimed/...
'accelerator_id': object_fields.IntegerField(nullable=False)
# Foreign key constrain to reference accelerator table
}
def _get_parent_root_uuid(self):
@ -71,6 +75,7 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat):
self.root_uuid = self._get_parent_root_uuid()
values = self.obj_get_changes()
db_dep = self.dbapi.deployable_create(context, values)
self._from_db_object(self, db_dep)
@ -103,3 +108,32 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat):
"""Delete a Deployable from the DB."""
self.dbapi.deployable_delete(context, self.uuid)
self.obj_reset_changes()
def add_attribute(self, attribute):
"""add a attribute object to the attribute_list.
If the attribute already exists, it will update the value,
otherwise, the vf will be appended to the list
"""
if not isinstance(vf, VirtualFunction) or vf.type != 'vf':
raise exception.InvalidDeployType()
for exist_vf in self.virtual_function_list:
if base.obj_equal_prims(vf, exist_vf):
LOG.warning("The vf already exists")
return None
@classmethod
def get_by_filter(cls, context,
filters, sort_key='created_at',
sort_dir='desc', limit=None,
marker=None, join=None):
obj_dpl_list = []
db_dpl_list = cls.dbapi.deployable_get_by_filters(context, filters,
sort_key=sort_key,
sort_dir=sort_dir,
limit=limit,
marker=marker,
join_columns=join)
for db_dpl in db_dpl_list:
obj_dpl = cls._from_db_object(cls(context), db_dpl)
obj_dpl_list.append(obj_dpl)
return obj_dpl_list

View File

@ -0,0 +1,137 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import copy
from oslo_log import log as logging
from oslo_versionedobjects import base as object_base
from cyborg.common import exception
from cyborg.db import api as dbapi
from cyborg.objects import base
from cyborg.objects import fields as object_fields
from cyborg.objects.deployable import Deployable
from cyborg.objects.virtual_function import VirtualFunction
LOG = logging.getLogger(__name__)
@base.CyborgObjectRegistry.register
class PhysicalFunction(Deployable):
# Version 1.0: Initial version
VERSION = '1.0'
virtual_function_list = []
def create(self, context):
# To ensure the creating type is PF
if self.type != 'pf':
raise exception.InvalidDeployType()
super(PhysicalFunction, self).create(context)
def save(self, context):
"""In addition to save the pf, it should also save the
vfs associated with this pf
"""
# To ensure the saving type is PF
if self.type != 'pf':
raise exception.InvalidDeployType()
for exist_vf in self.virtual_function_list:
exist_vf.save(context)
super(PhysicalFunction, self).save(context)
def add_vf(self, vf):
"""add a vf object to the virtual_function_list.
If the vf already exists, it will ignore,
otherwise, the vf will be appended to the list
"""
if not isinstance(vf, VirtualFunction) or vf.type != 'vf':
raise exception.InvalidDeployType()
for exist_vf in self.virtual_function_list:
if base.obj_equal_prims(vf, exist_vf):
LOG.warning("The vf already exists")
return None
vf.parent_uuid = self.uuid
vf.root_uuid = self.root_uuid
vf_copy = copy.deepcopy(vf)
self.virtual_function_list.append(vf_copy)
def delete_vf(self, context, vf):
"""remove a vf from the virtual_function_list
if the vf does not exist, ignore it
"""
for idx, exist_vf in self.virtual_function_list:
if base.obj_equal_prims(vf, exist_vf):
removed_vf = self.virtual_function_list.pop(idx)
removed_vf.destroy(context)
return
LOG.warning("The removing vf does not exist!")
def destroy(self, context):
"""Delete a the pf from the DB."""
del self.virtual_function_list[:]
super(PhysicalFunction, self).destroy(context)
@classmethod
def get(cls, context, uuid):
"""Find a DB Physical Function and return an Obj Physical Function.
In addition, it will also finds all the Virtual Functions associated
with this Physical Function and place them in virtual_function_list
"""
db_pf = cls.dbapi.deployable_get(context, uuid)
obj_pf = cls._from_db_object(cls(context), db_pf)
pf_uuid = obj_pf.uuid
query = {"parent_uuid": pf_uuid, "type": "vf"}
db_vf_list = cls.dbapi.deployable_get_by_filters(context, query)
for db_vf in db_vf_list:
obj_vf = VirtualFunction.get(context, db_vf.uuid)
obj_pf.virtual_function_list.append(obj_vf)
return obj_pf
@classmethod
def get_by_filter(cls, context,
filters, sort_key='created_at',
sort_dir='desc', limit=None,
marker=None, join=None):
obj_dpl_list = []
filters['type'] = 'pf'
db_dpl_list = cls.dbapi.deployable_get_by_filters(context, filters,
sort_key=sort_key,
sort_dir=sort_dir,
limit=limit,
marker=marker,
join_columns=join)
for db_dpl in db_dpl_list:
obj_dpl = cls._from_db_object(cls(context), db_dpl)
query = {"parent_uuid": obj_dpl.uuid}
vf_get_list = VirtualFunction.get_by_filter(context,
query)
obj_dpl.virtual_function_list = vf_get_list
obj_dpl_list.append(obj_dpl)
return obj_dpl_list
@classmethod
def _from_db_object(cls, obj, db_obj):
"""Converts a physical function to a formal object.
:param obj: An object of the class.
:param db_obj: A DB model of the object
:return: The object of the class with the database entity added
"""
obj = Deployable._from_db_object(obj, db_obj)
if cls is PhysicalFunction:
obj.virtual_function_list = []
return obj

View File

@ -0,0 +1,61 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_log import log as logging
from oslo_versionedobjects import base as object_base
from cyborg.common import exception
from cyborg.db import api as dbapi
from cyborg.objects import base
from cyborg.objects import fields as object_fields
from cyborg.objects.deployable import Deployable
LOG = logging.getLogger(__name__)
@base.CyborgObjectRegistry.register
class VirtualFunction(Deployable):
# Version 1.0: Initial version
VERSION = '1.0'
def create(self, context):
# To ensure the creating type is VF
if self.type != 'vf':
raise exception.InvalidDeployType()
super(VirtualFunction, self).create(context)
def save(self, context):
# To ensure the saving type is VF
if self.type != 'vf':
raise exception.InvalidDeployType()
super(VirtualFunction, self).save(context)
@classmethod
def get_by_filter(cls, context,
filters, sort_key='created_at',
sort_dir='desc', limit=None,
marker=None, join=None):
obj_dpl_list = []
filters['type'] = 'vf'
db_dpl_list = cls.dbapi.deployable_get_by_filters(context, filters,
sort_key=sort_key,
sort_dir=sort_dir,
limit=limit,
marker=marker,
join_columns=join)
for db_dpl in db_dpl_list:
obj_dpl = cls._from_db_object(cls(context), db_dpl)
obj_dpl_list.append(obj_dpl)
return obj_dpl_list

View File

@ -38,7 +38,8 @@ def fake_db_deployable(**updates):
'type': 'pf',
'assignable': True,
'instance_uuid': None,
'availability': 'Available'
'availability': 'Available',
'accelerator_id': 1
}
for name, field in objects.Deployable.fields.items():

View File

@ -0,0 +1,72 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
from oslo_serialization import jsonutils
from oslo_utils import uuidutils
from cyborg import objects
from cyborg.objects import fields
from cyborg.objects import physical_function
def fake_db_physical_function(**updates):
root_uuid = uuidutils.generate_uuid()
db_physical_function = {
'id': 1,
'deleted': False,
'uuid': root_uuid,
'name': 'dp_name',
'parent_uuid': None,
'root_uuid': root_uuid,
'pcie_address': '00:7f:0b.2',
'host': 'host_name',
'board': 'KU115',
'vendor': 'Xilinx',
'version': '1.0',
'type': 'pf',
'assignable': True,
'instance_uuid': None,
'availability': 'Available',
'accelerator_id': 1
}
for name, field in physical_function.PhysicalFunction.fields.items():
if name in db_physical_function:
continue
if field.nullable:
db_physical_function[name] = None
elif field.default != fields.UnspecifiedDefault:
db_physical_function[name] = field.default
else:
raise Exception('fake_db_physical_function needs help with %s'
% name)
if updates:
db_physical_function.update(updates)
return db_physical_function
def fake_physical_function_obj(context, obj_pf_class=None, **updates):
if obj_pf_class is None:
obj_pf_class = objects.VirtualFunction
expected_attrs = updates.pop('expected_attrs', None)
pf = obj_pf_class._from_db_object(context,
obj_pf_class(),
fake_db_physical_function(**updates),
expected_attrs=expected_attrs)
pf.obj_reset_changes()
return vf

View File

@ -0,0 +1,72 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
from oslo_serialization import jsonutils
from oslo_utils import uuidutils
from cyborg import objects
from cyborg.objects import fields
from cyborg.objects import virtual_function
def fake_db_virtual_function(**updates):
root_uuid = uuidutils.generate_uuid()
db_virtual_function = {
'id': 1,
'deleted': False,
'uuid': root_uuid,
'name': 'dp_name',
'parent_uuid': None,
'root_uuid': root_uuid,
'pcie_address': '00:7f:bb.2',
'host': 'host_name',
'board': 'KU115',
'vendor': 'Xilinx',
'version': '1.0',
'type': 'vf',
'assignable': True,
'instance_uuid': None,
'availability': 'Available',
'accelerator_id': 1
}
for name, field in virtual_function.VirtualFunction.fields.items():
if name in db_virtual_function:
continue
if field.nullable:
db_virtual_function[name] = None
elif field.default != fields.UnspecifiedDefault:
db_virtual_function[name] = field.default
else:
raise Exception('fake_db_virtual_function needs help with %s'
% name)
if updates:
db_virtual_function.update(updates)
return db_virtual_function
def fake_virtual_function_obj(context, obj_vf_class=None, **updates):
if obj_vf_class is None:
obj_vf_class = objects.VirtualFunction
expected_attrs = updates.pop('expected_attrs', None)
vf = obj_vf_class._from_db_object(context,
obj_vf_class(),
fake_db_virtual_function(**updates),
expected_attrs=expected_attrs)
vf.obj_reset_changes()
return vf

View File

@ -26,6 +26,7 @@ from cyborg.common import exception
from cyborg import objects
from cyborg.objects import base
from cyborg import tests as test
from cyborg.tests.unit import fake_accelerator
from cyborg.tests.unit import fake_deployable
from cyborg.tests.unit.objects import test_objects
from cyborg.tests.unit.db.base import DbTestCase
@ -34,49 +35,95 @@ from cyborg.tests.unit.db.base import DbTestCase
class _TestDeployableObject(DbTestCase):
@property
def fake_deployable(self):
db_deploy = fake_deployable.fake_db_deployable(id=2)
db_deploy = fake_deployable.fake_db_deployable(id=1)
return db_deploy
@mock.patch.object(db.api.Connection, 'deployable_create')
def test_create(self, mock_create):
mock_create.return_value = self.fake_deployable
@property
def fake_accelerator(self):
db_acc = fake_accelerator.fake_db_accelerator(id=2)
return db_acc
def test_create(self):
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**mock_create.return_value)
**db_dpl)
dpl.accelerator_id = acc_get.id
dpl.create(self.context)
self.assertEqual(self.fake_deployable['id'], dpl.id)
self.assertEqual(db_dpl['uuid'], dpl.uuid)
@mock.patch.object(db.api.Connection, 'deployable_get')
def test_get(self, mock_get):
mock_get.return_value = self.fake_deployable
def test_get(self):
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**mock_get.return_value)
**db_dpl)
dpl.accelerator_id = acc_get.id
dpl.create(self.context)
dpl_get = objects.Deployable.get(self.context, dpl['uuid'])
dpl_get = objects.Deployable.get(self.context, dpl.uuid)
self.assertEqual(dpl_get.uuid, dpl.uuid)
@mock.patch.object(db.api.Connection, 'deployable_update')
def test_save(self, mock_save):
mock_save.return_value = self.fake_deployable
def test_get_by_filter(self):
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**mock_save.return_value)
**db_dpl)
dpl.accelerator_id = acc_get.id
dpl.create(self.context)
query = {"uuid": dpl['uuid']}
dpl_get_list = objects.Deployable.get_by_filter(self.context, query)
self.assertEqual(dpl_get_list[0].uuid, dpl.uuid)
def test_save(self):
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**db_dpl)
dpl.accelerator_id = acc_get.id
dpl.create(self.context)
dpl.host = 'test_save'
dpl.save(self.context)
dpl_get = objects.Deployable.get(self.context, dpl['uuid'])
dpl_get = objects.Deployable.get(self.context, dpl.uuid)
self.assertEqual(dpl_get.host, 'test_save')
@mock.patch.object(db.api.Connection, 'deployable_delete')
def test_destroy(self, mock_destroy):
mock_destroy.return_value = self.fake_deployable
def test_destroy(self):
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
db_dpl = self.fake_deployable
dpl = objects.Deployable(context=self.context,
**mock_destroy.return_value)
**db_dpl)
dpl.accelerator_id = acc_get.id
dpl.create(self.context)
self.assertEqual(self.fake_deployable['id'], dpl.id)
self.assertEqual(db_dpl['uuid'], dpl.uuid)
dpl.destroy(self.context)
self.assertRaises(exception.DeployableNotFound,
objects.Deployable.get, self.context,
dpl['uuid'])
dpl.uuid)
class TestDeployableObject(test_objects._LocalTest,

View File

@ -17,7 +17,6 @@ import copy
import datetime
import inspect
import os
import pprint
import fixtures
import mock

View File

@ -0,0 +1,186 @@
import mock
import netaddr
from oslo_db import exception as db_exc
from oslo_serialization import jsonutils
from oslo_utils import timeutils
from oslo_context import context
from cyborg import db
from cyborg.common import exception
from cyborg import objects
from cyborg.objects import base
from cyborg import tests as test
from cyborg.tests.unit import fake_physical_function
from cyborg.tests.unit import fake_virtual_function
from cyborg.tests.unit import fake_accelerator
from cyborg.tests.unit.objects import test_objects
from cyborg.tests.unit.db.base import DbTestCase
class _TestPhysicalFunctionObject(DbTestCase):
@property
def fake_physical_function(self):
db_pf = fake_physical_function.fake_db_physical_function(id=1)
return db_pf
@property
def fake_virtual_function(self):
db_vf = fake_virtual_function.fake_db_virtual_function(id=3)
return db_vf
@property
def fake_accelerator(self):
db_acc = fake_accelerator.fake_db_accelerator(id=2)
return db_acc
def test_create(self):
db_pf = self.fake_physical_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
self.assertEqual(db_pf['uuid'], pf.uuid)
def test_get(self):
db_pf = self.fake_physical_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
self.assertEqual(pf_get.uuid, pf.uuid)
def test_get_by_filter(self):
db_acc = self.fake_accelerator
db_pf = self.fake_physical_function
db_vf = self.fake_virtual_function
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = pf_get.accelerator_id
vf.create(self.context)
vf_get = objects.VirtualFunction.get(self.context, vf.uuid)
pf_get.add_vf(vf_get)
pf_get.save(self.context)
query = {"vendor": pf['vendor']}
pf_get_list = objects.PhysicalFunction.get_by_filter(self.context,
query)
self.assertEqual(len(pf_get_list), 1)
self.assertEqual(pf_get_list[0].uuid, pf.uuid)
self.assertEqual(objects.PhysicalFunction, type(pf_get_list[0]))
self.assertEqual(objects.VirtualFunction,
type(pf_get_list[0].virtual_function_list[0]))
self.assertEqual(pf_get_list[0].virtual_function_list[0].uuid,
vf.uuid)
def test_save(self):
db_pf = self.fake_physical_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
pf_get.host = 'test_save'
pf_get.save(self.context)
pf_get_2 = objects.PhysicalFunction.get(self.context, pf.uuid)
self.assertEqual(pf_get_2.host, 'test_save')
def test_destroy(self):
db_pf = self.fake_physical_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
self.assertEqual(db_pf['uuid'], pf_get.uuid)
pf_get.destroy(self.context)
self.assertRaises(exception.DeployableNotFound,
objects.PhysicalFunction.get, self.context,
pf_get['uuid'])
def test_add_vf(self):
db_pf = self.fake_physical_function
db_vf = self.fake_virtual_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = pf_get.accelerator_id
vf.create(self.context)
vf_get = objects.VirtualFunction.get(self.context, vf.uuid)
pf_get.add_vf(vf_get)
pf_get.save(self.context)
pf_get_2 = objects.PhysicalFunction.get(self.context, pf.uuid)
self.assertEqual(db_vf['uuid'],
pf_get_2.virtual_function_list[0].uuid)
class TestPhysicalFunctionObject(test_objects._LocalTest,
_TestPhysicalFunctionObject):
def _test_save_objectfield_fk_constraint_fails(self, foreign_key,
expected_exception):
error = db_exc.DBReferenceError('table', 'constraint', foreign_key,
'key_table')
# Prevent lazy-loading any fields, results in InstanceNotFound
pf = fake_physical_function.physical_function_obj(self.context)
fields_with_save_methods = [field for field in pf.fields
if hasattr(pf, '_save_%s' % field)]
for field in fields_with_save_methods:
@mock.patch.object(pf, '_save_%s' % field)
@mock.patch.object(pf, 'obj_attr_is_set')
def _test(mock_is_set, mock_save_field):
mock_is_set.return_value = True
mock_save_field.side_effect = error
pf.obj_reset_changes(fields=[field])
pf._changed_fields.add(field)
self.assertRaises(expected_exception, pf.save)
pf.obj_reset_changes(fields=[field])
_test()

View File

@ -0,0 +1,202 @@
# Copyright 2018 Huawei Technologies Co.,LTD.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
import mock
import netaddr
from oslo_db import exception as db_exc
from oslo_serialization import jsonutils
from oslo_utils import timeutils
from oslo_context import context
from cyborg import db
from cyborg.common import exception
from cyborg import objects
from cyborg.objects import base
from cyborg import tests as test
from cyborg.tests.unit import fake_physical_function
from cyborg.tests.unit import fake_virtual_function
from cyborg.tests.unit import fake_accelerator
from cyborg.tests.unit.objects import test_objects
from cyborg.tests.unit.db.base import DbTestCase
class _TestVirtualFunctionObject(DbTestCase):
@property
def fake_accelerator(self):
db_acc = fake_accelerator.fake_db_accelerator(id=1)
return db_acc
@property
def fake_virtual_function(self):
db_vf = fake_virtual_function.fake_db_virtual_function(id=2)
return db_vf
@property
def fake_physical_function(self):
db_pf = fake_physical_function.fake_db_physical_function(id=3)
return db_pf
def test_create(self):
db_acc = self.fake_accelerator
db_vf = self.fake_virtual_function
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = acc_get.id
vf.create(self.context)
self.assertEqual(db_vf['uuid'], vf.uuid)
def test_get(self):
db_vf = self.fake_virtual_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = acc_get.id
vf.create(self.context)
vf_get = objects.VirtualFunction.get(self.context, vf.uuid)
self.assertEqual(vf_get.uuid, vf.uuid)
def test_get_by_filter(self):
db_acc = self.fake_accelerator
db_pf = self.fake_physical_function
db_vf = self.fake_virtual_function
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = pf_get.accelerator_id
vf.create(self.context)
vf_get = objects.VirtualFunction.get(self.context, vf.uuid)
pf_get.add_vf(vf_get)
pf_get.save(self.context)
query = {"vendor": pf_get['vendor']}
vf_get_list = objects.VirtualFunction.get_by_filter(self.context,
query)
self.assertEqual(len(vf_get_list), 1)
self.assertEqual(vf_get_list[0].uuid, vf.uuid)
self.assertEqual(objects.VirtualFunction, type(vf_get_list[0]))
self.assertEqual(1, 1)
def test_get_by_filter2(self):
db_acc = self.fake_accelerator
db_pf = self.fake_physical_function
db_vf = self.fake_virtual_function
db_pf2 = self.fake_physical_function
db_vf2 = self.fake_virtual_function
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
pf = objects.PhysicalFunction(context=self.context,
**db_pf)
pf.accelerator_id = acc_get.id
pf.create(self.context)
pf_get = objects.PhysicalFunction.get(self.context, pf.uuid)
pf2 = objects.PhysicalFunction(context=self.context,
**db_pf2)
pf2.accelerator_id = acc_get.id
pf2.create(self.context)
pf_get2 = objects.PhysicalFunction.get(self.context, pf2.uuid)
query = {"uuid": pf2.uuid}
pf_get_list = objects.PhysicalFunction.get_by_filter(self.context,
query)
self.assertEqual(1, 1)
def test_save(self):
db_vf = self.fake_virtual_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = acc_get.id
vf.create(self.context)
vf_get = objects.VirtualFunction.get(self.context, vf.uuid)
vf_get.host = 'test_save'
vf_get.save(self.context)
vf_get_2 = objects.VirtualFunction.get(self.context, vf.uuid)
self.assertEqual(vf_get_2.host, 'test_save')
def test_destroy(self):
db_vf = self.fake_virtual_function
db_acc = self.fake_accelerator
acc = objects.Accelerator(context=self.context,
**db_acc)
acc.create(self.context)
acc_get = objects.Accelerator.get(self.context, acc.uuid)
vf = objects.VirtualFunction(context=self.context,
**db_vf)
vf.accelerator_id = acc_get.id
vf.create(self.context)
vf_get = objects.VirtualFunction.get(self.context, vf.uuid)
self.assertEqual(db_vf['uuid'], vf_get.uuid)
vf_get.destroy(self.context)
self.assertRaises(exception.DeployableNotFound,
objects.VirtualFunction.get, self.context,
vf_get['uuid'])
class TestVirtualFunctionObject(test_objects._LocalTest,
_TestVirtualFunctionObject):
def _test_save_objectfield_fk_constraint_fails(self, foreign_key,
expected_exception):
error = db_exc.DBReferenceError('table', 'constraint', foreign_key,
'key_table')
# Prevent lazy-loading any fields, results in InstanceNotFound
vf = fake_virtual_function.virtual_function_obj(self.context)
fields_with_save_methods = [field for field in vf.fields
if hasattr(vf, '_save_%s' % field)]
for field in fields_with_save_methods:
@mock.patch.object(vf, '_save_%s' % field)
@mock.patch.object(vf, 'obj_attr_is_set')
def _test(mock_is_set, mock_save_field):
mock_is_set.return_value = True
mock_save_field.side_effect = error
vf.obj_reset_changes(fields=[field])
vf._changed_fields.add(field)
self.assertRaises(expected_exception, vf.save)
vf.obj_reset_changes(fields=[field])
_test()