From 6f83466307fb21aee5bb596974644d457ae1fa60 Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Tue, 6 Feb 2018 18:12:38 -0800 Subject: [PATCH] Allow objects to opt in new engine facade New facade is enabled by setting new_facade = True for the object of interest. With new_facade on, all OVO actions will use the new reader / writer decorator to activate sessions. There are two new facade decorators added to OVO: db_context_reader and db_context_write that should be used instead of explicit autonested_transaction / reader.using / writer.using in OVO context. All neutron.objects.db.api helpers now receive OVO classes / objects instead of model classes, since they need to know which type of engine facade to use for which object. While it means we change signatures for those helper functions, they are not used anywhere outside neutron tree except vmware-nsx unit tests, and the latter pass anyway because the tests completely mock out them disregarding their signatures. This patch also adds several new OVO objects to be able to continue using neutron.objects.db.api helpers to persist models that previously didn't have corresponding OVO classes. Finally, the patch adds registration for missing options in neutron/tests/unit/extensions/test_qos_fip.py to be able to debug failures in those unit tests. Strictly speaking, this change doesn't belong to the patch, but I include it nevertheless to speed up merge in time close to release. There are several non-obvious changes included, specifically: - in neutron.objects.base, decorator() that refreshes / expunges models from the active session now opens a subtransaction for the whole span of call / refresh / expunge, so that we can safely refresh model regardless of whether caller opened another parent subtransaction (it was not the case for create_subnetpool in base db plugin code). - in neutron.db.l3_fip_qos, removed code that updates obj.db_model relationship directly after corresponding insertions for child policy binding model. This code is not needed because the only caller to the _process_extra_fip_qos_update method refetches latest state of floating ip OVO object anyway, and this code triggers several unit test failures. - unit tests checking that a single commit happens for get_object and get_objects are no longer valid for new facade objects that use reader decorator that doesn't commit but close. This change is as intended, so unit tests were tweaked to check close for new facade objects. Change-Id: I15ec238c18a464f977f7d1079605b82965052311 Related-Bug: #1746996 --- .../contributor/internals/objects_usage.rst | 37 ++++++++ neutron/db/l3_fip_qos.py | 14 +-- neutron/objects/agent.py | 4 +- neutron/objects/base.py | 94 ++++++++++-------- neutron/objects/db/api.py | 70 +++++++------- neutron/objects/network.py | 25 +++-- neutron/objects/ports.py | 27 ++++-- neutron/objects/qos/policy.py | 52 ++++++---- neutron/objects/qos/rule.py | 5 +- neutron/objects/quota.py | 3 +- neutron/objects/rbac_db.py | 42 ++++---- neutron/objects/securitygroup.py | 3 +- neutron/objects/stdattrs.py | 33 +++++++ neutron/objects/subnet.py | 4 +- neutron/objects/subnetpool.py | 5 +- neutron/objects/trunk.py | 5 +- neutron/tests/unit/extensions/test_qos_fip.py | 6 ++ neutron/tests/unit/objects/db/test_api.py | 95 +++++++++++-------- neutron/tests/unit/objects/qos/test_policy.py | 18 ++-- neutron/tests/unit/objects/test_base.py | 75 ++++++++------- neutron/tests/unit/objects/test_network.py | 2 + neutron/tests/unit/objects/test_objects.py | 4 + neutron/tests/unit/objects/test_ports.py | 10 ++ neutron/tests/unit/objects/test_rbac_db.py | 43 ++++++--- neutron/tests/unit/objects/test_subnet.py | 4 +- 25 files changed, 425 insertions(+), 255 deletions(-) create mode 100644 neutron/objects/stdattrs.py diff --git a/doc/source/contributor/internals/objects_usage.rst b/doc/source/contributor/internals/objects_usage.rst index d7b92dde351..f55357739f1 100644 --- a/doc/source/contributor/internals/objects_usage.rst +++ b/doc/source/contributor/internals/objects_usage.rst @@ -322,6 +322,43 @@ model, the nullable parameter is by default :code:`True`, while for OVO fields, the nullable is set to :code:`False`. Make sure you correctly map database column nullability properties to relevant object fields. +Database session activation +--------------------------- + +By default, all objects use old ``oslo.db`` engine facade. To enable the new +facade for a particular object, set ``new_facade`` class attribute to ``True``: + +.. code-block:: Python + + @obj_base.VersionedObjectRegistry.register + class ExampleObject(base.NeutronDbObject): + new_facade = True + +It will make all OVO actions - ``get_object``, ``update``, ``count`` etc. - to +use new ``reader.using`` or ``writer.using`` decorators to manage database +transactions. + +Whenever you need to open a new subtransaction in scope of OVO code, use the +following database session decorators: + +.. code-block:: Python + + @obj_base.VersionedObjectRegistry.register + class ExampleObject(base.NeutronDbObject): + + @classmethod + def get_object(cls, context, **kwargs): + with cls.db_context_reader(context): + super(ExampleObject, cls).get_object(context, **kwargs) + # fetch more data in the same transaction + + def create(self): + with self.db_context_writer(self.obj_context): + super(ExampleObject, self).create() + # apply more changes in the same transaction + +``db_context_reader`` and ``db_context_writer`` decorators abstract the choice +of engine facade used for particular object from action implementation. Synthetic fields ---------------- diff --git a/neutron/db/l3_fip_qos.py b/neutron/db/l3_fip_qos.py index ef19533b959..fad3204806a 100644 --- a/neutron/db/l3_fip_qos.py +++ b/neutron/db/l3_fip_qos.py @@ -44,8 +44,7 @@ class FloatingQoSDbMixin(object): def _create_fip_qos_db(self, context, fip_id, policy_id): policy = self._get_policy_obj(context, policy_id) policy.attach_floatingip(fip_id) - binding_db_obj = obj_db_api.get_object( - context, policy.fip_binding_model, fip_id=fip_id) + binding_db_obj = obj_db_api.get_object(policy, context, fip_id=fip_id) return binding_db_obj def _delete_fip_qos_db(self, context, fip_id, policy_id): @@ -73,14 +72,7 @@ class FloatingQoSDbMixin(object): self._delete_fip_qos_db(context, floatingip_obj['id'], old_qos_policy_id) - if floatingip_obj.db_obj.qos_policy_binding: - floatingip_obj.db_obj.qos_policy_binding['policy_id'] = ( - new_qos_policy_id) if not new_qos_policy_id: return - qos_policy_binding = self._create_fip_qos_db( - context, - floatingip_obj['id'], - new_qos_policy_id) - if not floatingip_obj.db_obj.qos_policy_binding: - floatingip_obj.db_obj.qos_policy_binding = qos_policy_binding + self._create_fip_qos_db( + context, floatingip_obj['id'], new_qos_policy_id) diff --git a/neutron/objects/agent.py b/neutron/objects/agent.py index 18ade02a1ec..73cebfea6a2 100644 --- a/neutron/objects/agent.py +++ b/neutron/objects/agent.py @@ -89,7 +89,7 @@ class Agent(base.NeutronDbObject): @classmethod def get_l3_agent_with_min_routers(cls, context, agent_ids): """Return l3 agent with the least number of routers.""" - with context.session.begin(subtransactions=True): + with cls.db_context_reader(context): query = context.session.query( agent_model.Agent, func.count( @@ -105,7 +105,7 @@ class Agent(base.NeutronDbObject): @classmethod def get_l3_agents_ordered_by_num_routers(cls, context, agent_ids): - with context.session.begin(subtransactions=True): + with cls.db_context_reader(context): query = (context.session.query(agent_model.Agent, func.count( rb_model.RouterL3AgentBinding.router_id) .label('count')). diff --git a/neutron/objects/base.py b/neutron/objects/base.py index 844b11decbc..108089d9f89 100644 --- a/neutron/objects/base.py +++ b/neutron/objects/base.py @@ -81,7 +81,7 @@ class Pager(object): self.page_reverse = page_reverse self.marker = marker - def to_kwargs(self, context, model): + def to_kwargs(self, context, obj_cls): res = { attr: getattr(self, attr) for attr in ('sorts', 'limit', 'page_reverse') @@ -89,7 +89,7 @@ class Pager(object): } if self.marker and self.limit: res['marker_obj'] = obj_db_api.get_object( - context, model, id=self.marker) + obj_cls, context, id=self.marker) return res def __str__(self): @@ -310,16 +310,16 @@ def _detach_db_obj(func): @functools.wraps(func) def decorator(self, *args, **kwargs): synthetic_changed = bool(self._get_changed_synthetic_fields()) - res = func(self, *args, **kwargs) - # some relationship based fields may be changed since we - # captured the model, let's refresh it for the latest database - # state - if synthetic_changed: - # TODO(ihrachys) consider refreshing just changed attributes - self.obj_context.session.refresh(self.db_obj) - # detach the model so that consequent fetches don't reuse it - self.obj_context.session.expunge(self.db_obj) - return res + with self.db_context_writer(self.obj_context): + res = func(self, *args, **kwargs) + # some relationship based fields may be changed since we captured + # the model, let's refresh it for the latest database state + if synthetic_changed: + # TODO(ihrachys) consider refreshing just changed attributes + self.obj_context.session.refresh(self.db_obj) + # detach the model so that consequent fetches don't reuse it + self.obj_context.session.expunge(self.db_obj) + return res return decorator @@ -390,6 +390,12 @@ class NeutronDbObject(NeutronObject): # should be overridden for all persistent objects db_model = None + # should be overridden for all rbac aware objects + rbac_db_cls = None + + # whether to use new engine facade for the object + new_facade = False + primary_keys = ['id'] # 'unique_keys' is a list of unique keys that can be used with get_object @@ -512,6 +518,20 @@ class NeutronDbObject(NeutronObject): if is_attr_nullable: self[attrname] = None + @classmethod + def db_context_writer(cls, context): + """Return read-write session activation decorator.""" + if cls.new_facade: + return db_api.context_manager.writer.using(context) + return db_api.autonested_transaction(context.session) + + @classmethod + def db_context_reader(cls, context): + """Return read-only session activation decorator.""" + if cls.new_facade: + return db_api.context_manager.reader.using(context) + return db_api.autonested_transaction(context.session) + @classmethod def get_object(cls, context, **kwargs): """ @@ -529,11 +549,9 @@ class NeutronDbObject(NeutronObject): raise o_exc.NeutronPrimaryKeyMissing(object_class=cls, missing_keys=missing_keys) - with context.session.begin(subtransactions=True): + with cls.db_context_reader(context): db_obj = obj_db_api.get_object( - context, cls.db_model, - **cls.modify_fields_to_db(kwargs) - ) + cls, context, **cls.modify_fields_to_db(kwargs)) if db_obj: return cls._load_object(context, db_obj) @@ -553,11 +571,9 @@ class NeutronDbObject(NeutronObject): """ if validate_filters: cls.validate_filters(**kwargs) - with context.session.begin(subtransactions=True): + with cls.db_context_reader(context): db_objs = obj_db_api.get_objects( - context, cls.db_model, _pager=_pager, - **cls.modify_fields_to_db(kwargs) - ) + cls, context, _pager=_pager, **cls.modify_fields_to_db(kwargs)) return [cls._load_object(context, db_obj) for db_obj in db_objs] @classmethod @@ -582,9 +598,9 @@ class NeutronDbObject(NeutronObject): return super(NeutronDbObject, cls).update_object( context, values, validate_filters=False, **kwargs) else: - with db_api.autonested_transaction(context.session): + with cls.db_context_writer(context): db_obj = obj_db_api.update_object( - context, cls.db_model, + cls, context, cls.modify_fields_to_db(values), **cls.modify_fields_to_db(kwargs)) return cls._load_object(context, db_obj) @@ -604,15 +620,14 @@ class NeutronDbObject(NeutronObject): if validate_filters: cls.validate_filters(**kwargs) - # if we have standard attributes, we will need to fetch records to - # update revision numbers - if cls.has_standard_attributes(): - return super(NeutronDbObject, cls).update_objects( - context, values, validate_filters=False, **kwargs) - - with db_api.autonested_transaction(context.session): + with cls.db_context_writer(context): + # if we have standard attributes, we will need to fetch records to + # update revision numbers + if cls.has_standard_attributes(): + return super(NeutronDbObject, cls).update_objects( + context, values, validate_filters=False, **kwargs) return obj_db_api.update_objects( - context, cls.db_model, + cls, context, cls.modify_fields_to_db(values), **cls.modify_fields_to_db(kwargs)) @@ -629,9 +644,9 @@ class NeutronDbObject(NeutronObject): """ if validate_filters: cls.validate_filters(**kwargs) - with context.session.begin(subtransactions=True): + with cls.db_context_writer(context): return obj_db_api.delete_objects( - context, cls.db_model, **cls.modify_fields_to_db(kwargs)) + cls, context, **cls.modify_fields_to_db(kwargs)) @classmethod def is_accessible(cls, context, db_obj): @@ -748,11 +763,10 @@ class NeutronDbObject(NeutronObject): def create(self): fields = self._get_changed_persistent_fields() - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): try: db_obj = obj_db_api.create_object( - self.obj_context, self.db_model, - self.modify_fields_to_db(fields)) + self, self.obj_context, self.modify_fields_to_db(fields)) except obj_exc.DBDuplicateEntry as db_exc: raise o_exc.NeutronDbObjectDuplicateEntry( object_class=self.__class__, db_exception=db_exc) @@ -786,16 +800,16 @@ class NeutronDbObject(NeutronObject): updates = self._get_changed_persistent_fields() updates = self._validate_changed_fields(updates) - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): db_obj = obj_db_api.update_object( - self.obj_context, self.db_model, + self, self.obj_context, self.modify_fields_to_db(updates), **self.modify_fields_to_db( self._get_composite_keys())) self.from_db_object(db_obj) def delete(self): - obj_db_api.delete_object(self.obj_context, self.db_model, + obj_db_api.delete_object(self, self.obj_context, **self.modify_fields_to_db( self._get_composite_keys())) self._captured_db_model = None @@ -814,7 +828,7 @@ class NeutronDbObject(NeutronObject): if validate_filters: cls.validate_filters(**kwargs) return obj_db_api.count( - context, cls.db_model, **cls.modify_fields_to_db(kwargs) + cls, context, **cls.modify_fields_to_db(kwargs) ) @classmethod @@ -832,5 +846,5 @@ class NeutronDbObject(NeutronObject): cls.validate_filters(**kwargs) # Succeed if at least a single object matches; no need to fetch more return bool(obj_db_api.get_object( - context, cls.db_model, **cls.modify_fields_to_db(kwargs)) + cls, context, **cls.modify_fields_to_db(kwargs)) ) diff --git a/neutron/objects/db/api.py b/neutron/objects/db/api.py index 069f5a03ef7..3053ebc300d 100644 --- a/neutron/objects/db/api.py +++ b/neutron/objects/db/api.py @@ -21,19 +21,20 @@ from neutron.objects import utils as obj_utils # Common database operation implementations -def _get_filter_query(context, model, **kwargs): - with context.session.begin(subtransactions=True): +def _get_filter_query(obj_cls, context, **kwargs): + with obj_cls.db_context_reader(context): filters = _kwargs_to_filters(**kwargs) - query = model_query.get_collection_query(context, model, filters) + query = model_query.get_collection_query( + context, obj_cls.db_model, filters) return query -def get_object(context, model, **kwargs): - return _get_filter_query(context, model, **kwargs).first() +def get_object(obj_cls, context, **kwargs): + return _get_filter_query(obj_cls, context, **kwargs).first() -def count(context, model, **kwargs): - return _get_filter_query(context, model, **kwargs).count() +def count(obj_cls, context, **kwargs): + return _get_filter_query(obj_cls, context, **kwargs).count() def _kwargs_to_filters(**kwargs): @@ -42,77 +43,80 @@ def _kwargs_to_filters(**kwargs): for k, v in kwargs.items()} -def get_objects(context, model, _pager=None, **kwargs): - with context.session.begin(subtransactions=True): +def get_objects(obj_cls, context, _pager=None, **kwargs): + with obj_cls.db_context_reader(context): filters = _kwargs_to_filters(**kwargs) return model_query.get_collection( - context, model, + context, obj_cls.db_model, dict_func=None, # return all the data filters=filters, - **(_pager.to_kwargs(context, model) if _pager else {})) + **(_pager.to_kwargs(context, obj_cls) if _pager else {})) -def create_object(context, model, values, populate_id=True): - with context.session.begin(subtransactions=True): - if populate_id and 'id' not in values and hasattr(model, 'id'): +def create_object(obj_cls, context, values, populate_id=True): + with obj_cls.db_context_writer(context): + if (populate_id and + 'id' not in values and + hasattr(obj_cls.db_model, 'id')): values['id'] = uuidutils.generate_uuid() - db_obj = model(**values) + db_obj = obj_cls.db_model(**values) context.session.add(db_obj) return db_obj -def _safe_get_object(context, model, **kwargs): - db_obj = get_object(context, model, **kwargs) +def _safe_get_object(obj_cls, context, **kwargs): + db_obj = get_object(obj_cls, context, **kwargs) if db_obj is None: key = ", ".join(['%s=%s' % (key, value) for (key, value) in kwargs.items()]) - raise n_exc.ObjectNotFound(id="%s(%s)" % (model.__name__, key)) + raise n_exc.ObjectNotFound( + id="%s(%s)" % (obj_cls.db_model.__name__, key)) return db_obj -def update_object(context, model, values, **kwargs): - with context.session.begin(subtransactions=True): - db_obj = _safe_get_object(context, model, **kwargs) +def update_object(obj_cls, context, values, **kwargs): + with obj_cls.db_context_writer(context): + db_obj = _safe_get_object(obj_cls, context, **kwargs) db_obj.update(values) db_obj.save(session=context.session) return db_obj -def delete_object(context, model, **kwargs): - with context.session.begin(subtransactions=True): - db_obj = _safe_get_object(context, model, **kwargs) +def delete_object(obj_cls, context, **kwargs): + with obj_cls.db_context_writer(context): + db_obj = _safe_get_object(obj_cls, context, **kwargs) context.session.delete(db_obj) -def update_objects(context, model, values, **kwargs): +def update_objects(obj_cls, context, values, **kwargs): '''Update matching objects, if any. Return number of updated objects. This function does not raise exceptions if nothing matches. - :param model: SQL model + :param obj_cls: Object class :param values: values to update in matching objects :param kwargs: multiple filters defined by key=value pairs :return: Number of entries updated ''' - with context.session.begin(subtransactions=True): + with obj_cls.db_context_writer(context): if not values: - return count(context, model, **kwargs) - q = _get_filter_query(context, model, **kwargs) + return count(obj_cls, context, **kwargs) + q = _get_filter_query(obj_cls, context, **kwargs) return q.update(values, synchronize_session=False) -def delete_objects(context, model, **kwargs): +def delete_objects(obj_cls, context, **kwargs): '''Delete matching objects, if any. Return number of deleted objects. This function does not raise exceptions if nothing matches. - :param model: SQL model + :param obj_cls: Object class :param kwargs: multiple filters defined by key=value pairs :return: Number of entries deleted ''' - with context.session.begin(subtransactions=True): - db_objs = get_objects(context, model, **kwargs) + with obj_cls.db_context_writer(context): + db_objs = get_objects(obj_cls, context, **kwargs) for db_obj in db_objs: context.session.delete(db_obj) return len(db_objs) diff --git a/neutron/objects/network.py b/neutron/objects/network.py index e78aac1ba83..e9eafbe5dca 100644 --- a/neutron/objects/network.py +++ b/neutron/objects/network.py @@ -16,7 +16,6 @@ from neutron_lib.api.definitions import availability_zone as az_def from neutron_lib.api.validators import availability_zone as az_validator from oslo_versionedobjects import fields as obj_fields -from neutron.db import api as db_api from neutron.db.models import dns as dns_models from neutron.db.models import external_net as ext_net_model from neutron.db.models import segment as segment_model @@ -32,6 +31,20 @@ from neutron.objects.qos import binding from neutron.objects import rbac_db +@base.NeutronObjectRegistry.register +class NetworkRBAC(base.NeutronDbObject): + # Version 1.0: Initial version + VERSION = '1.0' + + db_model = rbac_db_models.NetworkRBAC + + fields = { + 'object_id': obj_fields.StringField(), + 'target_tenant': obj_fields.StringField(), + 'action': obj_fields.StringField(), + } + + @base.NeutronObjectRegistry.register class NetworkDhcpAgentBinding(base.NeutronDbObject): # Version 1.0: Initial version @@ -86,7 +99,7 @@ class NetworkSegment(base.NeutronDbObject): def create(self): fields = self.obj_get_changes() - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): hosts = self.hosts if hosts is None: hosts = [] @@ -96,7 +109,7 @@ class NetworkSegment(base.NeutronDbObject): def update(self): fields = self.obj_get_changes() - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): super(NetworkSegment, self).update() if 'hosts' in fields: self._attach_hosts(fields['hosts']) @@ -176,7 +189,7 @@ class Network(rbac_db.NeutronRbacObject): # Version 1.0: Initial version VERSION = '1.0' - rbac_db_model = rbac_db_models.NetworkRBAC + rbac_db_cls = NetworkRBAC db_model = models_v2.Network fields = { @@ -223,7 +236,7 @@ class Network(rbac_db.NeutronRbacObject): def create(self): fields = self.obj_get_changes() - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): dns_domain = self.dns_domain qos_policy_id = self.qos_policy_id super(Network, self).create() @@ -234,7 +247,7 @@ class Network(rbac_db.NeutronRbacObject): def update(self): fields = self.obj_get_changes() - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): super(Network, self).update() if 'dns_domain' in fields: self._set_dns_domain(fields['dns_domain']) diff --git a/neutron/objects/ports.py b/neutron/objects/ports.py index a62f3056acb..05a85a39fcf 100644 --- a/neutron/objects/ports.py +++ b/neutron/objects/ports.py @@ -18,7 +18,6 @@ from oslo_utils import versionutils from oslo_versionedobjects import fields as obj_fields from neutron.common import utils -from neutron.db import api as db_api from neutron.db.models import dns as dns_models from neutron.db.models import l3 from neutron.db.models import securitygroup as sg_models @@ -234,6 +233,22 @@ class PortDNS(base.NeutronDbObject): primitive.pop('dns_domain', None) +@base.NeutronObjectRegistry.register +class SecurityGroupPortBinding(base.NeutronDbObject): + + # Version 1.0: Initial version + VERSION = '1.0' + + db_model = sg_models.SecurityGroupPortBinding + + fields = { + 'port_id': common_types.UUIDField(), + 'security_group_id': common_types.UUIDField(), + } + + primary_keys = ['port_id', 'security_group_id'] + + @base.NeutronObjectRegistry.register class Port(base.NeutronDbObject): # Version 1.0: Initial version @@ -318,7 +333,7 @@ class Port(base.NeutronDbObject): def create(self): fields = self.obj_get_changes() - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): sg_ids = self.security_group_ids if sg_ids is None: sg_ids = set() @@ -331,7 +346,7 @@ class Port(base.NeutronDbObject): def update(self): fields = self.obj_get_changes() - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): super(Port, self).update() if 'security_group_ids' in fields: self._attach_security_groups(fields['security_group_ids']) @@ -353,9 +368,7 @@ class Port(base.NeutronDbObject): # TODO(ihrachys): consider introducing an (internal) object for the # binding to decouple database operations a bit more obj_db_api.delete_objects( - self.obj_context, sg_models.SecurityGroupPortBinding, - port_id=self.id, - ) + SecurityGroupPortBinding, self.obj_context, port_id=self.id) if sg_ids: for sg_id in sg_ids: self._attach_security_group(sg_id) @@ -364,7 +377,7 @@ class Port(base.NeutronDbObject): def _attach_security_group(self, sg_id): obj_db_api.create_object( - self.obj_context, sg_models.SecurityGroupPortBinding, + SecurityGroupPortBinding, self.obj_context, {'port_id': self.id, 'security_group_id': sg_id} ) diff --git a/neutron/objects/qos/policy.py b/neutron/objects/qos/policy.py index 0826b43c327..86d03729dd1 100644 --- a/neutron/objects/qos/policy.py +++ b/neutron/objects/qos/policy.py @@ -22,11 +22,10 @@ from oslo_versionedobjects import exception from oslo_versionedobjects import fields as obj_fields from neutron.common import exceptions -from neutron.db import api as db_api from neutron.db.models import l3 from neutron.db import models_v2 from neutron.db.qos import models as qos_db_model -from neutron.db.rbac_db_models import QosPolicyRBAC +from neutron.db import rbac_db_models from neutron.objects import base as base_db from neutron.objects import common_types from neutron.objects.db import api as obj_db_api @@ -35,6 +34,20 @@ from neutron.objects.qos import rule as rule_obj_impl from neutron.objects import rbac_db +@base_db.NeutronObjectRegistry.register +class QosPolicyRBAC(base_db.NeutronDbObject): + # Version 1.0: Initial version + VERSION = '1.0' + + db_model = rbac_db_models.QosPolicyRBAC + + fields = { + 'object_id': obj_fields.StringField(), + 'target_tenant': obj_fields.StringField(), + 'action': obj_fields.StringField(), + } + + @base_db.NeutronObjectRegistry.register class QosPolicy(rbac_db.NeutronRbacObject): # Version 1.0: Initial version @@ -48,13 +61,9 @@ class QosPolicy(rbac_db.NeutronRbacObject): VERSION = '1.7' # required by RbacNeutronMetaclass - rbac_db_model = QosPolicyRBAC + rbac_db_cls = QosPolicyRBAC db_model = qos_db_model.QosPolicy - port_binding_model = qos_db_model.QosPortPolicyBinding - network_binding_model = qos_db_model.QosNetworkPolicyBinding - fip_binding_model = qos_db_model.QosFIPPolicyBinding - fields = { 'id': common_types.UUIDField(), 'project_id': obj_fields.StringField(), @@ -82,7 +91,7 @@ class QosPolicy(rbac_db.NeutronRbacObject): return super(QosPolicy, self).obj_load_attr(attrname) def _reload_rules(self): - rules = rule_obj_impl.get_rules(self.obj_context, self.id) + rules = rule_obj_impl.get_rules(self, self.obj_context, self.id) setattr(self, 'rules', rules) self.obj_reset_changes(['rules']) @@ -121,7 +130,7 @@ class QosPolicy(rbac_db.NeutronRbacObject): # We want to get the policy regardless of its tenant id. We'll make # sure the tenant has permission to access the policy later on. admin_context = context.elevated() - with db_api.autonested_transaction(admin_context.session): + with cls.db_context_reader(admin_context): policy_obj = super(QosPolicy, cls).get_object(admin_context, **kwargs) if (not policy_obj or @@ -138,7 +147,7 @@ class QosPolicy(rbac_db.NeutronRbacObject): # We want to get the policy regardless of its tenant id. We'll make # sure the tenant has permission to access the policy later on. admin_context = context.elevated() - with db_api.autonested_transaction(admin_context.session): + with cls.db_context_reader(admin_context): objs = super(QosPolicy, cls).get_objects(admin_context, _pager, validate_filters, **kwargs) @@ -152,37 +161,38 @@ class QosPolicy(rbac_db.NeutronRbacObject): return result @classmethod - def _get_object_policy(cls, context, model, **kwargs): - with db_api.autonested_transaction(context.session): - binding_db_obj = obj_db_api.get_object(context, model, **kwargs) + def _get_object_policy(cls, context, binding_cls, **kwargs): + with cls.db_context_reader(context): + binding_db_obj = obj_db_api.get_object(binding_cls, context, + **kwargs) if binding_db_obj: return cls.get_object(context, id=binding_db_obj['policy_id']) @classmethod def get_network_policy(cls, context, network_id): - return cls._get_object_policy(context, cls.network_binding_model, + return cls._get_object_policy(context, binding.QosPolicyNetworkBinding, network_id=network_id) @classmethod def get_port_policy(cls, context, port_id): - return cls._get_object_policy(context, cls.port_binding_model, + return cls._get_object_policy(context, binding.QosPolicyPortBinding, port_id=port_id) @classmethod def get_fip_policy(cls, context, fip_id): - return cls._get_object_policy(context, cls.fip_binding_model, - fip_id=fip_id) + return cls._get_object_policy( + context, binding.QosPolicyFloatingIPBinding, fip_id=fip_id) # TODO(QoS): Consider extending base to trigger registered methods for us def create(self): - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): super(QosPolicy, self).create() if self.is_default: self.set_default() self.obj_load_attr('rules') def update(self): - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): if 'is_default' in self.obj_what_changed(): if self.is_default: self.set_default() @@ -191,7 +201,7 @@ class QosPolicy(rbac_db.NeutronRbacObject): super(QosPolicy, self).update() def delete(self): - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): for object_type, obj_class in self.binding_models.items(): pager = base_db.Pager(limit=1) binding_obj = obj_class.get_objects(self.obj_context, @@ -322,7 +332,7 @@ class QosPolicy(rbac_db.NeutronRbacObject): fip = l3.FloatingIP qosfip = qos_db_model.QosFIPPolicyBinding bound_tenants = [] - with db_api.autonested_transaction(context.session): + with cls.db_context_reader(context): bound_tenants.extend(cls._get_bound_tenant_ids( context.session, qosnet, net, qosnet.network_id, policy_id)) bound_tenants.extend( diff --git a/neutron/objects/qos/rule.py b/neutron/objects/qos/rule.py index 21b88c6451a..e63eb710175 100644 --- a/neutron/objects/qos/rule.py +++ b/neutron/objects/qos/rule.py @@ -24,7 +24,6 @@ from oslo_versionedobjects import exception from oslo_versionedobjects import fields as obj_fields import six -from neutron.db import api as db_api from neutron.db.qos import models as qos_db_model from neutron.objects import base from neutron.objects import common_types @@ -32,9 +31,9 @@ from neutron.objects import common_types DSCP_MARK = 'dscp_mark' -def get_rules(context, qos_policy_id): +def get_rules(obj_cls, context, qos_policy_id): all_rules = [] - with db_api.autonested_transaction(context.session): + with obj_cls.db_context_reader(context): for rule_type in qos_consts.VALID_RULE_TYPES: rule_cls_name = 'Qos%sRule' % helpers.camelize(rule_type) rule_cls = getattr(sys.modules[__name__], rule_cls_name) diff --git a/neutron/objects/quota.py b/neutron/objects/quota.py index 518f8057a7f..668740b6661 100644 --- a/neutron/objects/quota.py +++ b/neutron/objects/quota.py @@ -16,7 +16,6 @@ from oslo_versionedobjects import fields as obj_fields import sqlalchemy as sa from sqlalchemy import sql -from neutron.db import api as db_api from neutron.db.quota import models from neutron.objects import base from neutron.objects import common_types @@ -60,7 +59,7 @@ class Reservation(base.NeutronDbObject): def create(self): deltas = self.resource_deltas - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): super(Reservation, self).create() if deltas: for delta in deltas: diff --git a/neutron/objects/rbac_db.py b/neutron/objects/rbac_db.py index 2c9ca5be3fb..f087cfbe014 100644 --- a/neutron/objects/rbac_db.py +++ b/neutron/objects/rbac_db.py @@ -25,7 +25,6 @@ from sqlalchemy import and_ from neutron._i18n import _ from neutron.common import exceptions as n_exc from neutron.db import _utils as db_utils -from neutron.db import api as db_api from neutron.db import rbac_db_mixin from neutron.db import rbac_db_models as models from neutron.extensions import rbac as ext_rbac @@ -37,7 +36,7 @@ from neutron.objects.db import api as obj_db_api class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin, base.NeutronDbObject): - rbac_db_model = None + rbac_db_cls = None @classmethod @abc.abstractmethod @@ -65,9 +64,10 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin, return False @staticmethod - def get_shared_with_tenant(context, rbac_db_model, obj_id, tenant_id): + def get_shared_with_tenant(context, rbac_db_cls, obj_id, tenant_id): # NOTE(korzen) This method enables to query within already started # session + rbac_db_model = rbac_db_cls.db_model return (db_utils.model_query(context, rbac_db_model).filter( and_(rbac_db_model.object_id == obj_id, rbac_db_model.action == models.ACCESS_SHARED, @@ -77,9 +77,8 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin, @classmethod def is_shared_with_tenant(cls, context, obj_id, tenant_id): ctx = context.elevated() - rbac_db_model = cls.rbac_db_model - with ctx.session.begin(subtransactions=True): - return cls.get_shared_with_tenant(ctx, rbac_db_model, + with cls.db_context_reader(ctx): + return cls.get_shared_with_tenant(ctx, cls.rbac_db_cls, obj_id, tenant_id) @classmethod @@ -91,23 +90,24 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin, @classmethod def _get_db_obj_rbac_entries(cls, context, rbac_obj_id, rbac_action): - rbac_db_model = cls.rbac_db_model + rbac_db_model = cls.rbac_db_cls.db_model return db_utils.model_query(context, rbac_db_model).filter( and_(rbac_db_model.object_id == rbac_obj_id, rbac_db_model.action == rbac_action)) @classmethod def _get_tenants_with_shared_access_to_db_obj(cls, context, obj_id): + rbac_db_model = cls.rbac_db_cls.db_model return set(itertools.chain.from_iterable(context.session.query( - cls.rbac_db_model.target_tenant).filter( - and_(cls.rbac_db_model.object_id == obj_id, - cls.rbac_db_model.action == models.ACCESS_SHARED, - cls.rbac_db_model.target_tenant != '*')))) + rbac_db_model.target_tenant).filter( + and_(rbac_db_model.object_id == obj_id, + rbac_db_model.action == models.ACCESS_SHARED, + rbac_db_model.target_tenant != '*')))) @classmethod def _validate_rbac_policy_delete(cls, context, obj_id, target_tenant): ctx_admin = context.elevated() - rb_model = cls.rbac_db_model + rb_model = cls.rbac_db_cls.db_model bound_tenant_ids = cls.get_bound_tenant_ids(ctx_admin, obj_id) db_obj_sharing_entries = cls._get_db_obj_rbac_entries( ctx_admin, obj_id, models.ACCESS_SHARED) @@ -146,7 +146,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin, return target_tenant = policy['target_tenant'] db_obj = obj_db_api.get_object( - context.elevated(), cls.db_model, id=policy['object_id']) + cls, context.elevated(), id=policy['object_id']) if db_obj.tenant_id == target_tenant: return cls._validate_rbac_policy_delete(context=context, @@ -181,10 +181,10 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin, # NeutronDbPluginV2.validate_network_rbac_policy_change(), those pieces # should be synced and contain the same bugs, until Network RBAC logic # (hopefully) melded with this one. - if object_type != cls.rbac_db_model.object_type: + if object_type != cls.rbac_db_cls.db_model.object_type: return db_obj = obj_db_api.get_object( - context.elevated(), cls.db_model, id=policy['object_id']) + cls, context.elevated(), id=policy['object_id']) if event in (events.BEFORE_CREATE, events.BEFORE_UPDATE): if (not context.is_admin and db_obj['tenant_id'] != context.tenant_id): @@ -198,7 +198,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin, object_type, policy, **kwargs) def attach_rbac(self, obj_id, tenant_id, target_tenant='*'): - obj_type = self.rbac_db_model.object_type + obj_type = self.rbac_db_cls.db_model.object_type rbac_policy = {'rbac_policy': {'object_id': obj_id, 'target_tenant': target_tenant, 'tenant_id': tenant_id, @@ -208,7 +208,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin, def update_shared(self, is_shared_new, obj_id): admin_context = self.obj_context.elevated() - shared_prev = obj_db_api.get_object(admin_context, self.rbac_db_model, + shared_prev = obj_db_api.get_object(self.rbac_db_cls, admin_context, object_id=obj_id, target_tenant='*', action=models.ACCESS_SHARED) @@ -233,7 +233,7 @@ def _update_post(self, obj_changes): def _update_hook(self, update_orig): - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): # NOTE(slaweq): copy of object changes is required to pass it later to # _update_post method because update() will reset all those changes obj_changes = self.obj_get_changes() @@ -247,7 +247,7 @@ def _create_post(self): def _create_hook(self, orig_create): - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): orig_create(self) _create_post(self) @@ -305,8 +305,8 @@ class RbacNeutronMetaclass(type): def validate_existing_attrs(cls_name, dct): if 'shared' not in dct['fields']: raise KeyError(_('No shared key in %s fields') % cls_name) - if 'rbac_db_model' not in dct: - raise AttributeError(_('rbac_db_model not found in %s') % cls_name) + if 'rbac_db_cls' not in dct: + raise AttributeError(_('rbac_db_cls not found in %s') % cls_name) @staticmethod def get_replaced_method(orig_method, new_method): diff --git a/neutron/objects/securitygroup.py b/neutron/objects/securitygroup.py index 3e726852bdd..79cb5ca9c13 100644 --- a/neutron/objects/securitygroup.py +++ b/neutron/objects/securitygroup.py @@ -13,7 +13,6 @@ from oslo_versionedobjects import fields as obj_fields from neutron.common import utils -from neutron.db import api as db_api from neutron.db.models import securitygroup as sg_models from neutron.objects import base from neutron.objects import common_types @@ -47,7 +46,7 @@ class SecurityGroup(base.NeutronDbObject): def create(self): # save is_default before super() resets it to False is_default = self.is_default - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): super(SecurityGroup, self).create() if is_default: default_group = DefaultSecurityGroup( diff --git a/neutron/objects/stdattrs.py b/neutron/objects/stdattrs.py new file mode 100644 index 00000000000..240101ce125 --- /dev/null +++ b/neutron/objects/stdattrs.py @@ -0,0 +1,33 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo_versionedobjects import fields as obj_fields + +from neutron.db import standard_attr +from neutron.objects import base +from neutron.objects.extensions import standardattributes as stdattr_obj + + +# TODO(ihrachys): add unit tests for the object +@base.NeutronObjectRegistry.register +class StandardAttribute(base.NeutronDbObject): + + # Version 1.0: Initial version + VERSION = '1.0' + + db_model = standard_attr.StandardAttribute + + fields = { + 'id': obj_fields.IntegerField(), + 'resource_type': obj_fields.StringField(), + } + fields.update(stdattr_obj.STANDARD_ATTRIBUTES) diff --git a/neutron/objects/subnet.py b/neutron/objects/subnet.py index b21cbce281a..b018a05d349 100644 --- a/neutron/objects/subnet.py +++ b/neutron/objects/subnet.py @@ -17,9 +17,9 @@ from oslo_versionedobjects import fields as obj_fields from neutron.common import utils from neutron.db.models import subnet_service_type from neutron.db import models_v2 -from neutron.db import rbac_db_models from neutron.objects import base from neutron.objects import common_types +from neutron.objects import network from neutron.objects import rbac_db @@ -228,7 +228,7 @@ class Subnet(base.NeutronDbObject): # create), it should be rare case to load 'shared' by that method shared = (rbac_db.RbacNeutronDbObjectMixin. get_shared_with_tenant(self.obj_context.elevated(), - rbac_db_models.NetworkRBAC, + network.NetworkRBAC, self.network_id, self.project_id)) setattr(self, 'shared', shared) diff --git a/neutron/objects/subnetpool.py b/neutron/objects/subnetpool.py index f82e8e00a53..37d6876df00 100644 --- a/neutron/objects/subnetpool.py +++ b/neutron/objects/subnetpool.py @@ -16,7 +16,6 @@ import netaddr from oslo_versionedobjects import fields as obj_fields -from neutron.db import api as db_api from neutron.db import models_v2 as models from neutron.objects import base from neutron.objects import common_types @@ -70,7 +69,7 @@ class SubnetPool(base.NeutronDbObject): # TODO(ihrachys): Consider extending base to trigger registered methods def create(self): fields = self.obj_get_changes() - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): prefixes = self.prefixes super(SubnetPool, self).create() if 'prefixes' in fields: @@ -79,7 +78,7 @@ class SubnetPool(base.NeutronDbObject): # TODO(ihrachys): Consider extending base to trigger registered methods def update(self): fields = self.obj_get_changes() - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): super(SubnetPool, self).update() if 'prefixes' in fields: self._attach_prefixes(fields['prefixes']) diff --git a/neutron/objects/trunk.py b/neutron/objects/trunk.py index f1bfc9f1173..c753550af9d 100644 --- a/neutron/objects/trunk.py +++ b/neutron/objects/trunk.py @@ -19,7 +19,6 @@ from oslo_db import exception as o_db_exc from oslo_utils import versionutils from oslo_versionedobjects import fields as obj_fields -from neutron.db import api as db_api from neutron.objects import base from neutron.objects import common_types from neutron.services.trunk import exceptions as t_exc @@ -52,7 +51,7 @@ class SubPort(base.NeutronDbObject): return _dict def create(self): - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): try: super(SubPort, self).create() except o_db_exc.DBReferenceError as ex: @@ -104,7 +103,7 @@ class Trunk(base.NeutronDbObject): synthetic_fields = ['sub_ports'] def create(self): - with db_api.autonested_transaction(self.obj_context.session): + with self.db_context_writer(self.obj_context): sub_ports = [] if self.obj_attr_is_set('sub_ports'): sub_ports = self.sub_ports diff --git a/neutron/tests/unit/extensions/test_qos_fip.py b/neutron/tests/unit/extensions/test_qos_fip.py index a67f4a5f026..14c05334e55 100644 --- a/neutron/tests/unit/extensions/test_qos_fip.py +++ b/neutron/tests/unit/extensions/test_qos_fip.py @@ -18,6 +18,7 @@ from oslo_config import cfg from oslo_utils import uuidutils from neutron.common import exceptions as n_exception +from neutron.conf.db import extraroute_db from neutron.db import l3_fip_qos from neutron.extensions import l3 from neutron.extensions import qos_fip @@ -211,9 +212,12 @@ class FloatingIPQoSDBIntTestCase(test_l3.L3BaseForIntTests, plugin = ('neutron.tests.unit.extensions.test_qos_fip.' 'TestFloatingIPQoSIntPlugin') service_plugins = {'qos': 'neutron.services.qos.qos_plugin.QoSPlugin'} + + extraroute_db.register_db_extraroute_opts() # for these tests we need to enable overlapping ips cfg.CONF.set_default('allow_overlapping_ips', True) cfg.CONF.set_default('max_routes', 3) + ext_mgr = FloatingIPQoSTestExtensionManager() super(test_l3.L3BaseForIntTests, self).setUp( plugin=plugin, @@ -236,9 +240,11 @@ class FloatingIPQoSDBSepTestCase(test_l3.L3BaseForSepTests, service_plugins = {'l3_plugin_name': l3_plugin, 'qos': 'neutron.services.qos.qos_plugin.QoSPlugin'} + extraroute_db.register_db_extraroute_opts() # for these tests we need to enable overlapping ips cfg.CONF.set_default('allow_overlapping_ips', True) cfg.CONF.set_default('max_routes', 3) + ext_mgr = FloatingIPQoSTestExtensionManager() super(test_l3.L3BaseForSepTests, self).setUp( plugin=plugin, diff --git a/neutron/tests/unit/objects/db/test_api.py b/neutron/tests/unit/objects/db/test_api.py index 5c7c8ec06ce..beef8bc1f1f 100644 --- a/neutron/tests/unit/objects/db/test_api.py +++ b/neutron/tests/unit/objects/db/test_api.py @@ -17,9 +17,9 @@ from neutron_lib import context from neutron_lib import exceptions as n_exc from neutron.db import _model_query as model_query -from neutron.db import models_v2 from neutron.objects import base from neutron.objects.db import api +from neutron.objects import network from neutron.objects import utils as obj_utils from neutron.tests import base as test_base from neutron.tests.unit import testlib_api @@ -28,6 +28,15 @@ from neutron.tests.unit import testlib_api PLUGIN_NAME = 'neutron.db.db_base_plugin_v2.NeutronDbPluginV2' +class FakeModel(object): + def __init__(self, *args, **kwargs): + pass + + +class FakeObj(base.NeutronDbObject): + db_model = FakeModel + + class GetObjectsTestCase(test_base.BaseTestCase): def setUp(self): @@ -38,7 +47,6 @@ class GetObjectsTestCase(test_base.BaseTestCase): def test_get_objects_pass_marker_obj_when_limit_and_marker_passed(self): ctxt = context.get_admin_context() - model = mock.sentinel.model marker = mock.sentinel.marker limit = mock.sentinel.limit pager = base.Pager(marker=marker, limit=limit) @@ -46,10 +54,10 @@ class GetObjectsTestCase(test_base.BaseTestCase): with mock.patch.object( model_query, 'get_collection') as get_collection: with mock.patch.object(api, 'get_object') as get_object: - api.get_objects(ctxt, model, _pager=pager) - get_object.assert_called_with(ctxt, model, id=marker) + api.get_objects(FakeObj, ctxt, _pager=pager) + get_object.assert_called_with(FakeObj, ctxt, id=marker) get_collection.assert_called_with( - ctxt, model, dict_func=None, + ctxt, FakeObj.db_model, dict_func=None, filters={}, limit=limit, marker_obj=get_object.return_value) @@ -58,15 +66,15 @@ class GetObjectsTestCase(test_base.BaseTestCase): class CreateObjectTestCase(test_base.BaseTestCase): def test_populate_id(self, populate_id=True): ctxt = context.get_admin_context() - model_cls = mock.Mock() values = {'x': 1, 'y': 2, 'z': 3} - with mock.patch.object(ctxt.__class__, 'session'): - api.create_object(ctxt, model_cls, values, - populate_id=populate_id) + with mock.patch.object(FakeObj, 'db_model') as db_model_mock: + with mock.patch.object(ctxt.__class__, 'session'): + api.create_object(FakeObj, ctxt, values, + populate_id=populate_id) expected = copy.copy(values) if populate_id: expected['id'] = mock.ANY - model_cls.assert_called_with(**expected) + db_model_mock.assert_called_with(**expected) def test_populate_id_False(self): self.test_populate_id(populate_id=False) @@ -82,90 +90,93 @@ class CRUDScenarioTestCase(testlib_api.SqlTestCase): # neutron.objects.db.api from core plugin instance self.setup_coreplugin(self.CORE_PLUGIN) # NOTE(ihrachys): nothing specific to networks in this test case, but - # we needed to pick some real model, so we picked the network. Any - # other model would work as well for our needs here. - self.model = models_v2.Network + # we needed to pick some real object, so we picked the network. Any + # other object would work as well for our needs here. + self.obj_cls = network.Network self.ctxt = context.get_admin_context() def test_get_object_with_None_value_in_filters(self): - obj = api.create_object(self.ctxt, self.model, {'name': 'foo'}) + obj = api.create_object(self.obj_cls, self.ctxt, {'name': 'foo'}) new_obj = api.get_object( - self.ctxt, self.model, name='foo', status=None) + self.obj_cls, self.ctxt, name='foo', status=None) self.assertEqual(obj, new_obj) def test_get_objects_with_None_value_in_filters(self): - obj = api.create_object(self.ctxt, self.model, {'name': 'foo'}) + obj = api.create_object(self.obj_cls, self.ctxt, {'name': 'foo'}) new_objs = api.get_objects( - self.ctxt, self.model, name='foo', status=None) + self.obj_cls, self.ctxt, name='foo', status=None) self.assertEqual(obj, new_objs[0]) def test_get_objects_with_string_matching_filters_contains(self): - obj1 = api.create_object(self.ctxt, self.model, {'name': 'obj_con_1'}) - obj2 = api.create_object(self.ctxt, self.model, {'name': 'obj_con_2'}) - obj3 = api.create_object(self.ctxt, self.model, {'name': 'obj_3'}) + obj1 = api.create_object( + self.obj_cls, self.ctxt, {'name': 'obj_con_1'}) + obj2 = api.create_object( + self.obj_cls, self.ctxt, {'name': 'obj_con_2'}) + obj3 = api.create_object( + self.obj_cls, self.ctxt, {'name': 'obj_3'}) objs = api.get_objects( - self.ctxt, self.model, name=obj_utils.StringContains('con')) + self.obj_cls, self.ctxt, name=obj_utils.StringContains('con')) self.assertEqual(2, len(objs)) self.assertIn(obj1, objs) self.assertIn(obj2, objs) self.assertNotIn(obj3, objs) def test_get_objects_with_string_matching_filters_starts(self): - obj1 = api.create_object(self.ctxt, self.model, {'name': 'pre_obj1'}) - obj2 = api.create_object(self.ctxt, self.model, {'name': 'pre_obj2'}) - obj3 = api.create_object(self.ctxt, self.model, {'name': 'obj_3'}) + obj1 = api.create_object(self.obj_cls, self.ctxt, {'name': 'pre_obj1'}) + obj2 = api.create_object(self.obj_cls, self.ctxt, {'name': 'pre_obj2'}) + obj3 = api.create_object(self.obj_cls, self.ctxt, {'name': 'obj_3'}) objs = api.get_objects( - self.ctxt, self.model, name=obj_utils.StringStarts('pre')) + self.obj_cls, self.ctxt, name=obj_utils.StringStarts('pre')) self.assertEqual(2, len(objs)) self.assertIn(obj1, objs) self.assertIn(obj2, objs) self.assertNotIn(obj3, objs) def test_get_objects_with_string_matching_filters_ends(self): - obj1 = api.create_object(self.ctxt, self.model, {'name': 'obj1_end'}) - obj2 = api.create_object(self.ctxt, self.model, {'name': 'obj2_end'}) - obj3 = api.create_object(self.ctxt, self.model, {'name': 'obj_3'}) + obj1 = api.create_object(self.obj_cls, self.ctxt, {'name': 'obj1_end'}) + obj2 = api.create_object(self.obj_cls, self.ctxt, {'name': 'obj2_end'}) + obj3 = api.create_object(self.obj_cls, self.ctxt, {'name': 'obj_3'}) objs = api.get_objects( - self.ctxt, self.model, name=obj_utils.StringEnds('end')) + self.obj_cls, self.ctxt, name=obj_utils.StringEnds('end')) self.assertEqual(2, len(objs)) self.assertIn(obj1, objs) self.assertIn(obj2, objs) self.assertNotIn(obj3, objs) def test_get_object_create_update_delete(self): - obj = api.create_object(self.ctxt, self.model, {'name': 'foo'}) + obj = api.create_object(self.obj_cls, self.ctxt, {'name': 'foo'}) - new_obj = api.get_object(self.ctxt, self.model, id=obj.id) + new_obj = api.get_object(self.obj_cls, self.ctxt, id=obj.id) self.assertEqual(obj, new_obj) obj = new_obj - api.update_object(self.ctxt, self.model, {'name': 'bar'}, id=obj.id) + api.update_object(self.obj_cls, self.ctxt, {'name': 'bar'}, id=obj.id) - new_obj = api.get_object(self.ctxt, self.model, id=obj.id) + new_obj = api.get_object(self.obj_cls, self.ctxt, id=obj.id) self.assertEqual(obj, new_obj) obj = new_obj - api.delete_object(self.ctxt, self.model, id=obj.id) + api.delete_object(self.obj_cls, self.ctxt, id=obj.id) - new_obj = api.get_object(self.ctxt, self.model, id=obj.id) + new_obj = api.get_object(self.obj_cls, self.ctxt, id=obj.id) self.assertIsNone(new_obj) # delete_object raises an exception on missing object self.assertRaises( n_exc.ObjectNotFound, - api.delete_object, self.ctxt, self.model, id=obj.id) + api.delete_object, self.obj_cls, self.ctxt, id=obj.id) # but delete_objects does not not - api.delete_objects(self.ctxt, self.model, id=obj.id) + api.delete_objects(self.obj_cls, self.ctxt, id=obj.id) def test_delete_objects_removes_all_matching_objects(self): # create some objects with identical description for i in range(10): api.create_object( - self.ctxt, self.model, + self.obj_cls, self.ctxt, {'name': 'foo%d' % i, 'description': 'bar'}) # create some more objects with a different description descriptions = set() @@ -173,16 +184,16 @@ class CRUDScenarioTestCase(testlib_api.SqlTestCase): desc = 'bar%d' % i descriptions.add(desc) api.create_object( - self.ctxt, self.model, + self.obj_cls, self.ctxt, {'name': 'foo%d' % i, 'description': desc}) # make sure that all objects are in the database - self.assertEqual(20, api.count(self.ctxt, self.model)) + self.assertEqual(20, api.count(self.obj_cls, self.ctxt)) # now delete just those with the 'bar' description - api.delete_objects(self.ctxt, self.model, description='bar') + api.delete_objects(self.obj_cls, self.ctxt, description='bar') # check that half of objects are gone, and remaining have expected # descriptions - objs = api.get_objects(self.ctxt, self.model) + objs = api.get_objects(self.obj_cls, self.ctxt) self.assertEqual(10, len(objs)) self.assertEqual( descriptions, diff --git a/neutron/tests/unit/objects/qos/test_policy.py b/neutron/tests/unit/objects/qos/test_policy.py index c3b42b0b979..7cc6447d1cd 100644 --- a/neutron/tests/unit/objects/qos/test_policy.py +++ b/neutron/tests/unit/objects/qos/test_policy.py @@ -18,9 +18,10 @@ from oslo_versionedobjects import exception import testtools from neutron.common import exceptions as n_exc -from neutron.db import models_v2 from neutron.objects.db import api as db_api from neutron.objects import network as net_obj +from neutron.objects import ports as port_obj +from neutron.objects.qos import binding from neutron.objects.qos import policy from neutron.objects.qos import rule from neutron.tests.unit.objects import test_base @@ -34,6 +35,9 @@ RULE_OBJ_CLS = { } +# TODO(ihrachys): add tests for QosPolicyRBAC + + class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): _test_class = policy.QosPolicy @@ -57,8 +61,8 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): self.model_map.update({ self._test_class.db_model: self.db_objs, - self._test_class.port_binding_model: [], - self._test_class.network_binding_model: [], + binding.QosPolicyPortBinding.db_model: [], + binding.QosPolicyNetworkBinding.db_model: [], rule.QosBandwidthLimitRule.db_model: self.db_qos_bandwidth_rules, rule.QosDscpMarkingRule.db_model: self.db_qos_dscp_rules, rule.QosMinimumBandwidthRule.db_model: @@ -73,7 +77,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): objs = self._test_class.get_objects(self.context) context_mock.assert_called_once_with() self.get_objects_mock.assert_any_call( - admin_context, self._test_class.db_model, _pager=None) + self._test_class, admin_context, _pager=None) self.assertItemsEqual( [test_base.get_obj_persistent_fields(obj) for obj in self.objs], [test_base.get_obj_persistent_fields(obj) for obj in objs]) @@ -95,7 +99,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): **self.valid_field_filter) context_mock.assert_called_once_with() get_objects_mock.assert_any_call( - admin_context, self._test_class.db_model, _pager=None, + self._test_class, admin_context, _pager=None, **self.valid_field_filter) self._check_equal(self.objs[0], objs[0]) @@ -110,7 +114,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): self._check_equal(self.objs[0], obj) context_mock.assert_called_once_with() get_object_mock.assert_called_once_with( - admin_context, self._test_class.db_model, id='fake_id') + self._test_class, admin_context, id='fake_id') def test_to_dict_makes_primitive_field_value(self): # is_shared_with_tenant requires DB @@ -237,7 +241,7 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase, def test_attach_and_get_multiple_policy_ports(self): port1_id = self._port['id'] - port2 = db_api.create_object(self.context, models_v2.Port, + port2 = db_api.create_object(port_obj.Port, self.context, {'tenant_id': 'fake_tenant_id', 'name': 'test-port2', 'network_id': self._network_id, diff --git a/neutron/tests/unit/objects/test_base.py b/neutron/tests/unit/objects/test_base.py index 129fb3b741f..79ec3094bcc 100644 --- a/neutron/tests/unit/objects/test_base.py +++ b/neutron/tests/unit/objects/test_base.py @@ -32,7 +32,6 @@ from oslo_versionedobjects import fields as obj_fields import testtools from neutron.db import _model_query as model_query -from neutron.db import standard_attr from neutron import objects from neutron.objects import agent from neutron.objects import base @@ -46,6 +45,7 @@ from neutron.objects.qos import policy as qos_policy from neutron.objects import rbac_db from neutron.objects import router from neutron.objects import securitygroup +from neutron.objects import stdattrs from neutron.objects import subnet from neutron.objects import utils as obj_utils from neutron.tests import base as test_base @@ -54,6 +54,7 @@ from neutron.tests.unit.db import test_db_base_plugin_v2 SQLALCHEMY_COMMIT = 'sqlalchemy.engine.Connection._commit_impl' +SQLALCHEMY_CLOSE = 'sqlalchemy.engine.Connection.close' OBJECTS_BASE_OBJ_FROM_PRIMITIVE = ('oslo_versionedobjects.base.' 'VersionedObject.obj_from_primitive') TIMESTAMP_FIELDS = ['created_at', 'updated_at', 'revision_number'] @@ -663,8 +664,8 @@ class _BaseObjectTestCase(object): def _is_test_class(cls, obj): return isinstance(obj, cls._test_class) - def fake_get_objects(self, context, model, **kwargs): - return self.model_map[model] + def fake_get_objects(self, obj_cls, context, **kwargs): + return self.model_map[obj_cls.db_model] def _get_object_synthetic_fields(self, objclass): return [field for field in objclass.synthetic_fields @@ -705,13 +706,14 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): # NOTE(ihrachys): for matters of basic object behaviour validation, # mock out rbac code accessing database. There are separate tests that # cover RBAC, per object type. - if getattr(self._test_class, 'rbac_db_model', None): - mock.patch.object( - rbac_db.RbacNeutronDbObjectMixin, - 'is_shared_with_tenant', return_value=False).start() - mock.patch.object( - rbac_db.RbacNeutronDbObjectMixin, - 'get_shared_with_tenant').start() + if self._test_class.rbac_db_cls is not None: + if getattr(self._test_class.rbac_db_cls, 'db_model', None): + mock.patch.object( + rbac_db.RbacNeutronDbObjectMixin, + 'is_shared_with_tenant', return_value=False).start() + mock.patch.object( + rbac_db.RbacNeutronDbObjectMixin, + 'get_shared_with_tenant').start() def fake_get_object(self, context, model, **kwargs): objs = self.model_map[model] @@ -719,8 +721,8 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): return None return [obj for obj in objs if obj['id'] == kwargs['id']][0] - def fake_get_objects(self, context, model, **kwargs): - return self.model_map[model] + def fake_get_objects(self, obj_cls, context, **kwargs): + return self.model_map[obj_cls.db_model] # TODO(ihrachys) document the intent of all common test cases in docstrings def test_get_object(self): @@ -734,7 +736,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self.assertTrue(self._is_test_class(obj)) self._check_equal(self.objs[0], obj) get_object_mock.assert_called_once_with( - self.context, self._test_class.db_model, + self._test_class, self.context, **self._test_class.modify_fields_to_db(obj_keys)) def test_get_object_missing_object(self): @@ -773,7 +775,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self.assertTrue(self._is_test_class(obj)) self._check_equal(self.objs[0], obj) get_object_mock.assert_called_once_with( - mock.ANY, self._test_class.db_model, + self._test_class, mock.ANY, **self._test_class.modify_fields_to_db(obj_keys)) def _get_synthetic_fields_get_objects_calls(self, db_objs): @@ -790,7 +792,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): } mock_calls.append( mock.call( - self.context, obj_class.db_model, + obj_class, self.context, _pager=self.pager_map[obj_class.obj_name()], **filter_kwargs)) return mock_calls @@ -805,7 +807,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): [get_obj_persistent_fields(obj) for obj in self.objs], [get_obj_persistent_fields(obj) for obj in objs]) get_objects_mock.assert_any_call( - self.context, self._test_class.db_model, + self._test_class, self.context, _pager=self.pager_map[self._test_class.obj_name()] ) @@ -952,7 +954,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): ) as delete_objects_mock: self.assertEqual(0, self._test_class.delete_objects(self.context)) delete_objects_mock.assert_any_call( - self.context, self._test_class.db_model) + self._test_class, self.context) def test_delete_objects_valid_fields(self): '''Test that a valid filter does not raise an error.''' @@ -1012,7 +1014,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): obj.create() self._check_equal(self.objs[0], obj) create_mock.assert_called_once_with( - self.context, self._test_class.db_model, + obj, self.context, self._test_class.modify_fields_to_db( get_obj_persistent_fields(self.objs[0]))) @@ -1126,7 +1128,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): update_mock.return_value[key] = value obj.update() update_mock.assert_called_once_with( - self.context, self._test_class.db_model, + obj, self.context, self._test_class.modify_fields_to_db(fields_to_update), **fixed_keys) @@ -1172,7 +1174,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): obj.delete() self._check_equal(self.objs[0], obj) delete_mock.assert_called_once_with( - self.context, self._test_class.db_model, + obj, self.context, **self._test_class.modify_fields_to_db(obj._get_composite_keys())) @mock.patch(OBJECTS_BASE_OBJ_FROM_PRIMITIVE) @@ -1229,7 +1231,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): pager = base.Pager() self._test_class.get_objects(self.context, _pager=pager) get_objects.assert_called_once_with( - mock.ANY, self._test_class.db_model, _pager=pager) + self._test_class, mock.ANY, _pager=pager) class BaseDbObjectNonStandardPrimaryKeyTestCase(BaseObjectIfaceTestCase): @@ -1545,9 +1547,8 @@ class BaseDbObjectTestCase(_BaseObjectTestCase, 'revision_number': tools.get_random_integer() } return obj_db_api.create_object( - self.context, - standard_attr.StandardAttribute, attrs, - populate_id=False)['id'] + stdattrs.StandardAttribute, + self.context, attrs, populate_id=False)['id'] def _create_test_flavor_id(self): attrs = self.get_random_object_fields(obj_cls=flavor.Flavor) @@ -1673,19 +1674,27 @@ class BaseDbObjectTestCase(_BaseObjectTestCase, obj.delete() self.assertEqual(1, mock_commit.call_count) - @mock.patch(SQLALCHEMY_COMMIT) - def test_get_objects_single_transaction(self, mock_commit): - self._test_class.get_objects(self.context) - self.assertEqual(1, mock_commit.call_count) + def _get_ro_txn_exit_func_name(self): + # for old engine facade, we didn't have distinction between r/o and r/w + # transactions and so we always call commit even for getters when the + # old facade is used + return ( + SQLALCHEMY_CLOSE + if self._test_class.new_facade else SQLALCHEMY_COMMIT) - @mock.patch(SQLALCHEMY_COMMIT) - def test_get_object_single_transaction(self, mock_commit): + def test_get_objects_single_transaction(self): + with mock.patch(self._get_ro_txn_exit_func_name()) as mock_exit: + self._test_class.get_objects(self.context) + self.assertEqual(1, mock_exit.call_count) + + def test_get_object_single_transaction(self): obj = self._make_object(self.obj_fields[0]) obj.create() - obj = self._test_class.get_object(self.context, - **obj._get_composite_keys()) - self.assertEqual(2, mock_commit.call_count) + with mock.patch(self._get_ro_txn_exit_func_name()) as mock_exit: + obj = self._test_class.get_object(self.context, + **obj._get_composite_keys()) + self.assertEqual(1, mock_exit.call_count) def test_get_objects_supports_extra_filtername(self): self.filtered_args = None diff --git a/neutron/tests/unit/objects/test_network.py b/neutron/tests/unit/objects/test_network.py index f86ad727f16..04ae46d3dd9 100644 --- a/neutron/tests/unit/objects/test_network.py +++ b/neutron/tests/unit/objects/test_network.py @@ -20,6 +20,8 @@ from neutron.tests.unit.objects import test_base as obj_test_base from neutron.tests.unit import testlib_api +# TODO(ihrachys): add tests for NetworkRBAC + class NetworkDhcpAgentBindingObjectIfaceTestCase( obj_test_base.BaseObjectIfaceTestCase): diff --git a/neutron/tests/unit/objects/test_objects.py b/neutron/tests/unit/objects/test_objects.py index 984cea9ebc7..826db0b962e 100644 --- a/neutron/tests/unit/objects/test_objects.py +++ b/neutron/tests/unit/objects/test_objects.py @@ -60,6 +60,7 @@ object_data = { 'NetworkDhcpAgentBinding': '1.0-6eeceb5fb4335cd65a305016deb41c68', 'NetworkDNSDomain': '1.0-420db7910294608534c1e2e30d6d8319', 'NetworkPortSecurity': '1.0-b30802391a87945ee9c07582b4ff95e3', + 'NetworkRBAC': '1.0-c8a67f39809c5a3c8c7f26f2f2c620b2', 'NetworkSegment': '1.0-57b7f2960971e3b95ded20cbc59244a8', 'Port': '1.1-5bf48d12a7bf7f5b7a319e8003b437a5', 'PortBinding': '1.0-3306deeaa6deb01e33af06777d48d578', @@ -72,6 +73,7 @@ object_data = { 'QosBandwidthLimitRule': '1.3-51b662b12a8d1dfa89288d826c6d26d3', 'QosDscpMarkingRule': '1.3-0313c6554b34fd10c753cb63d638256c', 'QosMinimumBandwidthRule': '1.3-314c3419f4799067cc31cc319080adff', + 'QosPolicyRBAC': '1.0-c8a67f39809c5a3c8c7f26f2f2c620b2', 'QosRuleType': '1.3-7286188edeb3a0386f9cf7979b9700fc', 'QosRuleTypeDriver': '1.0-7d8cb9f0ef661ac03700eae97118e3db', 'QosPolicy': '1.7-4adb0cde3102c10d8970ec9487fd7fe7', @@ -90,9 +92,11 @@ object_data = { 'RouterPort': '1.0-c8c8f499bcdd59186fcd83f323106908', 'RouterRoute': '1.0-07fc5337c801fb8c6ccfbcc5afb45907', 'SecurityGroup': '1.0-e26b90c409b31fd2e3c6fcec402ac0b9', + 'SecurityGroupPortBinding': '1.0-6879d5c0af80396ef5a72934b6a6ef20', 'SecurityGroupRule': '1.0-e9b8dace9d48b936c62ad40fe1f339d5', 'SegmentHostMapping': '1.0-521597cf82ead26217c3bd10738f00f0', 'ServiceProfile': '1.0-9beafc9e7d081b8258f3c5cb66ac5eed', + 'StandardAttribute': '1.0-617d4f46524c4ce734a6fc1cc0ac6a0b', 'Subnet': '1.0-927155c1fdd5a615cbcb981dda97bce4', 'SubnetPool': '1.0-a0e03895d1a6e7b9d4ab7b0ca13c3867', 'SubnetPoolPrefix': '1.0-13c15144135eb869faa4a76dc3ee3b6c', diff --git a/neutron/tests/unit/objects/test_ports.py b/neutron/tests/unit/objects/test_ports.py index c27e4db9365..371846f479f 100644 --- a/neutron/tests/unit/objects/test_ports.py +++ b/neutron/tests/unit/objects/test_ports.py @@ -24,6 +24,16 @@ from neutron.tests.unit.objects import test_base as obj_test_base from neutron.tests.unit import testlib_api +class SecurityGroupPortBindingIfaceObjTestCase( + obj_test_base.BaseObjectIfaceTestCase): + _test_class = ports.SecurityGroupPortBinding + + +class SecurityGroupPortBindingDbObjectTestCase( + obj_test_base.BaseDbObjectTestCase): + _test_class = ports.SecurityGroupPortBinding + + class BasePortBindingDbObjectTestCase(obj_test_base._BaseObjectTestCase, testlib_api.SqlTestCase): def setUp(self): diff --git a/neutron/tests/unit/objects/test_rbac_db.py b/neutron/tests/unit/objects/test_rbac_db.py index 6ce931b71c5..1ed405bc908 100644 --- a/neutron/tests/unit/objects/test_rbac_db.py +++ b/neutron/tests/unit/objects/test_rbac_db.py @@ -41,12 +41,25 @@ class FakeRbacModel(rbac_db_models.RBACColumns, model_base.BASEV2): return (rbac_db_models.ACCESS_SHARED,) +@base.NeutronObjectRegistry.register_if(False) +class FakeNeutronRbacObject(base.NeutronDbObject): + VERSION = '1.0' + + db_model = FakeRbacModel + + fields = { + 'object_id': obj_fields.StringField(), + 'target_tenant': obj_fields.StringField(), + 'action': obj_fields.StringField(), + } + + @base.NeutronObjectRegistry.register_if(False) class FakeNeutronDbObject(rbac_db.NeutronRbacObject): # Version 1.0: Initial version VERSION = '1.0' - rbac_db_model = FakeRbacModel + rbac_db_cls = FakeNeutronRbacObject db_model = FakeDbModel fields = { @@ -72,7 +85,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, super(RbacNeutronDbObjectTestCase, self).setUp() FakeNeutronDbObject.update_post = mock.Mock() - @mock.patch.object(_test_class, 'rbac_db_model') + @mock.patch.object(_test_class.rbac_db_cls, 'db_model') def test_get_tenants_with_shared_access_to_db_obj_return_tenant_ids( self, *mocks): ctx = mock.Mock() @@ -138,7 +151,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, context = mock.Mock(is_admin=True, tenant_id='db_obj_owner_id') self._rbac_policy_generate_change_events( resource=None, trigger='dummy_trigger', context=context, - object_type=self._test_class.rbac_db_model.object_type, + object_type=self._test_class.rbac_db_cls.db_model.object_type, policy={'object_id': 'fake_object_id'}, event_list=(events.BEFORE_CREATE, events.BEFORE_UPDATE)) @@ -154,7 +167,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, n_exc.InvalidInput, self._rbac_policy_generate_change_events, resource=mock.Mock(), trigger='dummy_trigger', context=context, - object_type=self._test_class.rbac_db_model.object_type, + object_type=self._test_class.rbac_db_cls.db_model.object_type, policy={'object_id': 'fake_object_id'}, event_list=(events.BEFORE_CREATE, events.BEFORE_UPDATE)) self.assertFalse(mock_validate_update.called) @@ -165,7 +178,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, self._test_class.validate_rbac_policy_delete( resource=mock.Mock(), event=events.BEFORE_DELETE, trigger='dummy_trigger', context=n_context.get_admin_context(), - object_type=self._test_class.rbac_db_model.object_type, + object_type=self._test_class.rbac_db_cls.db_model.object_type, policy=policy) mock_validate_delete.assert_not_called() @@ -205,7 +218,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, event=events.BEFORE_DELETE, trigger='dummy_trigger', context=context, - object_type=self._test_class.rbac_db_model.object_type, + object_type=self._test_class.rbac_db_cls.db_model.object_type, policy=policy) def test_validate_rbac_policy_delete_not_bound_tenant_success(self): @@ -247,7 +260,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, event=events.BEFORE_DELETE, trigger='dummy_trigger', context=context, - object_type=self._test_class.rbac_db_model.object_type, + object_type=self._test_class.rbac_db_cls.db_model.object_type, policy=policy) @mock.patch.object(_test_class, 'attach_rbac') @@ -257,10 +270,10 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, def test_update_shared_avoid_duplicate_update( self, mock_validate_delete, get_object_mock, attach_rbac_mock): obj_id = 'fake_obj_id' - self._test_class(mock.Mock()).update_shared(is_shared_new=True, - obj_id=obj_id) + obj = self._test_class(mock.Mock()) + obj.update_shared(is_shared_new=True, obj_id=obj_id) get_object_mock.assert_called_with( - mock.ANY, self._test_class.rbac_db_model, object_id=obj_id, + obj.rbac_db_cls, mock.ANY, object_id=obj_id, target_tenant='*', action=rbac_db_models.ACCESS_SHARED) self.assertFalse(mock_validate_delete.called) self.assertFalse(attach_rbac_mock.called) @@ -275,7 +288,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, test_neutron_obj = self._test_class(mock.Mock()) test_neutron_obj.update_shared(is_shared_new=True, obj_id=obj_id) get_object_mock.assert_called_with( - mock.ANY, self._test_class.rbac_db_model, object_id=obj_id, + test_neutron_obj.rbac_db_cls, mock.ANY, object_id=obj_id, target_tenant='*', action=rbac_db_models.ACCESS_SHARED) attach_rbac_mock.assert_called_with( @@ -292,10 +305,10 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, def test_update_shared_remove_wildcard_sharing( self, mock_validate_delete, get_object_mock, attach_rbac_mock): obj_id = 'fake_obj_id' - self._test_class(mock.Mock()).update_shared(is_shared_new=False, - obj_id=obj_id) + obj = self._test_class(mock.Mock()) + obj.update_shared(is_shared_new=False, obj_id=obj_id) get_object_mock.assert_called_with( - mock.ANY, self._test_class.rbac_db_model, object_id=obj_id, + obj.rbac_db_cls, mock.ANY, object_id=obj_id, target_tenant='*', action=rbac_db_models.ACCESS_SHARED) self.assertFalse(attach_rbac_mock.attach_rbac.called) @@ -313,4 +326,4 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, self.assertEqual(rbac_pol['target_tenant'], target_tenant) self.assertEqual(rbac_pol['action'], rbac_db_models.ACCESS_SHARED) self.assertEqual(rbac_pol['object_type'], - self._test_class.rbac_db_model.object_type) + self._test_class.rbac_db_cls.db_model.object_type) diff --git a/neutron/tests/unit/objects/test_subnet.py b/neutron/tests/unit/objects/test_subnet.py index 09d6d1671b6..2ace9cd6293 100644 --- a/neutron/tests/unit/objects/test_subnet.py +++ b/neutron/tests/unit/objects/test_subnet.py @@ -17,6 +17,7 @@ from oslo_utils import uuidutils from neutron.db import rbac_db_models from neutron.objects import base as obj_base from neutron.objects.db import api as obj_db_api +from neutron.objects import network as net_obj from neutron.objects import rbac_db from neutron.objects import subnet from neutron.tests.unit.objects import test_base as obj_test_base @@ -175,8 +176,7 @@ class SubnetDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, 'target_tenant': '*', 'action': rbac_db_models.ACCESS_SHARED } - obj_db_api.create_object(self.context, rbac_db_models.NetworkRBAC, - attrs) + obj_db_api.create_object(net_obj.NetworkRBAC, self.context, attrs) def test_get_subnet_shared_true(self): network = self._create_test_network()