From 67070d4ed5128928b8451856d2841af4e5a54263 Mon Sep 17 00:00:00 2001 From: Stephen Finucane Date: Thu, 14 Sep 2023 12:02:41 +0100 Subject: [PATCH] db: Only call private methods from public methods This is the final preparation step for our removal of autocommit. This makes our life easier since we can insist on only opening transactions inside the public method. On a related point, we remove one commit call from a private method and keep another. Signed-off-by: Stephen Finucane Change-Id: I75cf29a18b2179949ee4713f85ca52256d170c84 --- heat/db/api.py | 248 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 165 insertions(+), 83 deletions(-) diff --git a/heat/db/api.py b/heat/db/api.py index 06d868a935..21879b6c9d 100644 --- a/heat/db/api.py +++ b/heat/db/api.py @@ -34,7 +34,6 @@ from sqlalchemy import and_ from sqlalchemy import func from sqlalchemy import or_ from sqlalchemy import orm -from sqlalchemy.orm import aliased as orm_aliased from heat.common import crypt from heat.common import exception @@ -153,6 +152,10 @@ def _soft_delete_aware_query(context, *args, **kwargs): def raw_template_get(context, template_id): + return _raw_template_get(context, template_id) + + +def _raw_template_get(context, template_id): result = context.session.get(models.RawTemplate, template_id) if not result: @@ -169,7 +172,7 @@ def raw_template_create(context, values): def raw_template_update(context, template_id, values): - raw_template_ref = raw_template_get(context, template_id) + raw_template_ref = _raw_template_get(context, template_id) # get only the changed values values = dict((k, v) for k, v in values.items() if getattr(raw_template_ref, k) != v) @@ -182,7 +185,7 @@ def raw_template_update(context, template_id, values): def raw_template_delete(context, template_id): try: - raw_template = raw_template_get(context, template_id) + raw_template = _raw_template_get(context, template_id) except exception.NotFound: # Ignore not found return @@ -196,7 +199,7 @@ def raw_template_delete(context, template_id): if context.session.query(models.RawTemplate).filter_by( files_id=raw_tmpl_files_id).first() is None: try: - raw_tmpl_files = raw_template_files_get( + raw_tmpl_files = _raw_template_files_get( context, raw_tmpl_files_id) except exception.NotFound: # Ignore not found @@ -216,6 +219,10 @@ def raw_template_files_create(context, values): def raw_template_files_get(context, files_id): + return _raw_template_files_get(context, files_id) + + +def _raw_template_files_get(context, files_id): result = context.session.get(models.RawTemplateFiles, files_id) if not result: raise exception.NotFound( @@ -228,6 +235,10 @@ def raw_template_files_get(context, files_id): def resource_create(context, values): + return _resource_create(context, values) + + +def _resource_create(context, values): resource_ref = models.Resource() resource_ref.update(values) resource_ref.save(context.session) @@ -241,7 +252,7 @@ def resource_create_replacement(context, atomic_key, expected_engine_id=None): try: with context.session.begin(): - new_res = resource_create(context, new_res_values) + new_res = _resource_create(context, new_res_values) update_data = {'replaced_by': new_res.id} if not _try_resource_update(context, existing_res_id, update_data, @@ -354,6 +365,12 @@ def resource_get_by_name_and_stack(context, resource_name, stack_id): def resource_get_all_by_physical_resource_id(context, physical_resource_id): + return _resource_get_all_by_physical_resource_id( + context, physical_resource_id, + ) + + +def _resource_get_all_by_physical_resource_id(context, physical_resource_id): results = (context.session.query(models.Resource) .filter_by(physical_resource_id=physical_resource_id) .all()) @@ -365,8 +382,9 @@ def resource_get_all_by_physical_resource_id(context, physical_resource_id): def resource_get_by_physical_resource_id(context, physical_resource_id): - results = resource_get_all_by_physical_resource_id(context, - physical_resource_id) + results = _resource_get_all_by_physical_resource_id( + context, physical_resource_id, + ) try: return next(results) except StopIteration: @@ -515,15 +533,17 @@ def resource_data_get(context, resource_id, key): Decrypts resource data if necessary. """ - result = resource_data_get_by_key(context, - resource_id, - key) + result = _resource_data_get_by_key(context, resource_id, key) if result.redact: return crypt.decrypt(result.decrypt_method, result.value) return result.value def resource_data_get_by_key(context, resource_id, key): + return _resource_data_get_by_key(context, resource_id, key) + + +def _resource_data_get_by_key(context, resource_id, key): """Looks up resource_data by resource_id and key. Does not decrypt resource_data. @@ -544,7 +564,7 @@ def resource_data_set(context, resource_id, key, value, redact=False): else: method = '' try: - current = resource_data_get_by_key(context, resource_id, key) + current = _resource_data_get_by_key(context, resource_id, key) except exception.NotFound: current = models.ResourceData() current.key = key @@ -557,7 +577,7 @@ def resource_data_set(context, resource_id, key, value, redact=False): def resource_data_delete(context, resource_id, key): - result = resource_data_get_by_key(context, resource_id, key) + result = _resource_data_get_by_key(context, resource_id, key) with context.session.begin(): context.session.delete(result) @@ -566,6 +586,10 @@ def resource_data_delete(context, resource_id, key): def resource_prop_data_create_or_update(context, values, rpd_id=None): + return _resource_prop_data_create_or_update(context, values, rpd_id=rpd_id) + + +def _resource_prop_data_create_or_update(context, values, rpd_id=None): obj_ref = None if rpd_id is not None: obj_ref = context.session.query( @@ -578,7 +602,7 @@ def resource_prop_data_create_or_update(context, values, rpd_id=None): def resource_prop_data_create(context, values): - return resource_prop_data_create_or_update(context, values) + return _resource_prop_data_create_or_update(context, values) def resource_prop_data_get(context, resource_prop_data_id): @@ -605,6 +629,10 @@ def stack_get_by_name_and_owner_id(context, stack_name, owner_id): def stack_get_by_name(context, stack_name): + return _stack_get_by_name(context, stack_name) + + +def _stack_get_by_name(context, stack_name): query = _soft_delete_aware_query( context, models.Stack ).filter(sqlalchemy.or_( @@ -615,6 +643,12 @@ def stack_get_by_name(context, stack_name): def stack_get(context, stack_id, show_deleted=False, eager_load=True): + return _stack_get( + context, stack_id, show_deleted=show_deleted, eager_load=eager_load + ) + + +def _stack_get(context, stack_id, show_deleted=False, eager_load=True): options = [] if eager_load: options.append(orm.joinedload(models.Stack.raw_template)) @@ -653,6 +687,10 @@ def stack_get_status(context, stack_id): def stack_get_all_by_owner_id(context, owner_id): + return _stack_get_all_by_owner_id(context, owner_id) + + +def _stack_get_all_by_owner_id(context, owner_id): results = _soft_delete_aware_query( context, models.Stack, ).filter_by( @@ -662,9 +700,13 @@ def stack_get_all_by_owner_id(context, owner_id): def stack_get_all_by_root_owner_id(context, owner_id): - for stack in stack_get_all_by_owner_id(context, owner_id): + return _stack_get_all_by_root_owner_id(context, owner_id) + + +def _stack_get_all_by_root_owner_id(context, owner_id): + for stack in _stack_get_all_by_owner_id(context, owner_id): yield stack - for ch_st in stack_get_all_by_root_owner_id(context, stack.id): + for ch_st in _stack_get_all_by_root_owner_id(context, stack.id): yield ch_st @@ -722,7 +764,7 @@ def _query_stack_get_all(context, show_deleted=False, query = query.options(orm.subqueryload(models.Stack.tags)) if tags: for tag in tags: - tag_alias = orm_aliased(models.StackTag) + tag_alias = orm.aliased(models.StackTag) query = query.join(tag_alias, models.Stack.tags) query = query.filter(tag_alias.tag == tag) @@ -736,7 +778,7 @@ def _query_stack_get_all(context, show_deleted=False, context, models.Stack, show_deleted=show_deleted ) for tag in not_tags: - tag_alias = orm_aliased(models.StackTag) + tag_alias = orm.aliased(models.StackTag) subquery = subquery.join(tag_alias, models.Stack.tags) subquery = subquery.filter(tag_alias.tag == tag) not_stack_ids = [s.id for s in subquery.all()] @@ -811,7 +853,7 @@ def stack_create(context, values): # Even though we just created a stack with this name, we may not find # it again because some unit tests create stacks with deleted_at set. Also # some backup stacks may not be found, for reasons that are unclear. - earliest = stack_get_by_name(context, stack_name) + earliest = _stack_get_by_name(context, stack_name) if earliest is not None and earliest.id != stack_ref.id: with context.session.begin(): context.session.query(models.Stack).filter_by( @@ -841,7 +883,7 @@ def stack_update(context, stack_id, values, exp_trvsl=None): 'expected traversal: %(trav)s', {'id': stack_id, 'vals': str(values), 'trav': str(exp_trvsl)}) - if not stack_get(context, stack_id, eager_load=False): + if not _stack_get(context, stack_id, eager_load=False): raise exception.NotFound( _('Attempt to update a stack with id: ' '%(id)s %(msg)s') % { @@ -852,7 +894,7 @@ def stack_update(context, stack_id, values, exp_trvsl=None): def stack_delete(context, stack_id): - s = stack_get(context, stack_id, eager_load=False) + s = _stack_get(context, stack_id, eager_load=False) if not s: raise exception.NotFound(_('Attempt to delete a stack with id: ' '%(id)s %(msg)s') % { @@ -873,7 +915,14 @@ def stack_delete(context, stack_id): _soft_delete(context, s) -def reset_stack_status(context, stack_id, stack=None): +def reset_stack_status(context, stack_id): + return _reset_stack_status(context, stack_id) + + +# NOTE(stephenfin): This method uses separate transactions to delete nested +# stacks, thus it's the only private method that is allowed to open a +# transaction (via 'context.session.begin') +def _reset_stack_status(context, stack_id, stack=None): if stack is None: stack = context.session.get(models.Stack, stack_id) @@ -901,7 +950,7 @@ def reset_stack_status(context, stack_id, stack=None): query = context.session.query(models.Stack).filter_by(owner_id=stack_id) for child in query: - reset_stack_status(context, child.id, child) + _reset_stack_status(context, child.id, child) with context.session.begin(): if stack.status == 'IN_PROGRESS': @@ -915,7 +964,7 @@ def reset_stack_status(context, stack_id, stack=None): def stack_tags_set(context, stack_id, tags): with context.session.begin(): - stack_tags_delete(context, stack_id) + _stack_tags_delete(context, stack_id) result = [] for tag in tags: stack_tag = models.StackTag() @@ -928,13 +977,21 @@ def stack_tags_set(context, stack_id, tags): def stack_tags_delete(context, stack_id): with transaction(context): - result = stack_tags_get(context, stack_id) - if result: - for tag in result: - context.session.delete(tag) + return _stack_tags_delete(context, stack_id) + + +def _stack_tags_delete(context, stack_id): + result = _stack_tags_get(context, stack_id) + if result: + for tag in result: + context.session.delete(tag) def stack_tags_get(context, stack_id): + return _stack_tags_get(context, stack_id) + + +def _stack_tags_get(context, stack_id): result = (context.session.query(models.StackTag) .filter_by(stack_id=stack_id) .all()) @@ -1004,11 +1061,11 @@ def stack_lock_release(context, stack_id, engine_id): def stack_get_root_id(context, stack_id): - s = stack_get(context, stack_id, eager_load=False) + s = _stack_get(context, stack_id, eager_load=False) if not s: return None while s.owner_id: - s = stack_get(context, s.owner_id, eager_load=False) + s = _stack_get(context, s.owner_id, eager_load=False) return s.id @@ -1152,6 +1209,10 @@ def _events_filter_and_page_query(context, query, def event_count_all_by_stack(context, stack_id): + return _event_count_all_by_stack(context, stack_id) + + +def _event_count_all_by_stack(context, stack_id): query = context.session.query(func.count(models.Event.id)) return query.filter_by(stack_id=stack_id).scalar() @@ -1212,47 +1273,46 @@ def _delete_event_rows(context, stack_id, limit): # So we must manually supply the IN() values. # pgsql SHOULD work with the pure DELETE/JOIN below but that must be # confirmed via integration tests. - with context.session.begin(): - query = context.session.query(models.Event).filter_by( - stack_id=stack_id, - ) - query = query.order_by(models.Event.id).limit(limit) - id_pairs = [(e.id, e.rsrc_prop_data_id) for e in query.all()] - if not id_pairs: - return 0 - (ids, rsrc_prop_ids) = zip(*id_pairs) - max_id = ids[-1] - # delete the events - retval = context.session.query(models.Event).filter( - models.Event.id <= max_id).filter( - models.Event.stack_id == stack_id).delete() + query = context.session.query(models.Event).filter_by( + stack_id=stack_id, + ) + query = query.order_by(models.Event.id).limit(limit) + id_pairs = [(e.id, e.rsrc_prop_data_id) for e in query.all()] + if not id_pairs: + return 0 + (ids, rsrc_prop_ids) = zip(*id_pairs) + max_id = ids[-1] + # delete the events + retval = context.session.query(models.Event).filter( + models.Event.id <= max_id).filter( + models.Event.stack_id == stack_id).delete() - # delete unreferenced resource_properties_data - def del_rpd(rpd_ids): - if not rpd_ids: - return - q_rpd = context.session.query(models.ResourcePropertiesData) - q_rpd = q_rpd.filter(models.ResourcePropertiesData.id.in_(rpd_ids)) - q_rpd.delete(synchronize_session=False) + # delete unreferenced resource_properties_data + def del_rpd(rpd_ids): + if not rpd_ids: + return + q_rpd = context.session.query(models.ResourcePropertiesData) + q_rpd = q_rpd.filter(models.ResourcePropertiesData.id.in_(rpd_ids)) + q_rpd.delete(synchronize_session=False) - if rsrc_prop_ids: - clr_prop_ids = set(rsrc_prop_ids) - _find_rpd_references(context, - stack_id) - clr_prop_ids.discard(None) - try: - del_rpd(clr_prop_ids) - except db_exception.DBReferenceError: - LOG.debug('Checking backup/stack pairs for RPD references') - found = False - for partner_stack_id in _all_backup_stack_ids(context, - stack_id): - found = True - clr_prop_ids -= _find_rpd_references(context, - partner_stack_id) - if not found: - LOG.debug('No backup/stack pairs found for %s', stack_id) - raise - del_rpd(clr_prop_ids) + if rsrc_prop_ids: + clr_prop_ids = set(rsrc_prop_ids) - _find_rpd_references(context, + stack_id) + clr_prop_ids.discard(None) + try: + del_rpd(clr_prop_ids) + except db_exception.DBReferenceError: + LOG.debug('Checking backup/stack pairs for RPD references') + found = False + for partner_stack_id in _all_backup_stack_ids(context, + stack_id): + found = True + clr_prop_ids -= _find_rpd_references(context, + partner_stack_id) + if not found: + LOG.debug('No backup/stack pairs found for %s', stack_id) + raise + del_rpd(clr_prop_ids) return retval @@ -1263,13 +1323,18 @@ def event_create(context, values): # only count events and purge on average # 200.0/cfg.CONF.event_purge_batch_size percent of the time. check = (2.0 / cfg.CONF.event_purge_batch_size) > random.uniform(0, 1) - if (check and - (event_count_all_by_stack(context, values['stack_id']) >= - cfg.CONF.max_events_per_stack)): + if ( + check and _event_count_all_by_stack( + context, values['stack_id'] + ) >= cfg.CONF.max_events_per_stack + ): # prune try: - _delete_event_rows(context, values['stack_id'], - cfg.CONF.event_purge_batch_size) + with context.session.begin(): + _delete_event_rows( + context, values['stack_id'], + cfg.CONF.event_purge_batch_size, + ) except db_exception.DBError as exc: LOG.error('Failed to purge events: %s', str(exc)) event_ref = models.Event() @@ -1289,6 +1354,10 @@ def software_config_create(context, values): def software_config_get(context, config_id): + return _software_config_get(context, config_id) + + +def _software_config_get(context, config_id): result = context.session.get(models.SoftwareConfig, config_id) if (result is not None and context is not None and not context.is_admin and result.tenant != context.tenant_id): @@ -1309,7 +1378,7 @@ def software_config_get_all(context, limit=None, marker=None): def software_config_delete(context, config_id): - config = software_config_get(context, config_id) + config = _software_config_get(context, config_id) # Query if the software config has been referenced by deployment. result = context.session.query(models.SoftwareDeployment).filter_by( config_id=config_id).first() @@ -1340,6 +1409,10 @@ def software_deployment_create(context, values): def software_deployment_get(context, deployment_id): + return _software_deployment_get(context, deployment_id) + + +def _software_deployment_get(context, deployment_id): result = context.session.get(models.SoftwareDeployment, deployment_id) if (result is not None and context is not None and not context.is_admin and context.tenant_id not in (result.tenant, @@ -1366,7 +1439,7 @@ def software_deployment_get_all(context, server_id=None): def software_deployment_update(context, deployment_id, values): - deployment = software_deployment_get(context, deployment_id) + deployment = _software_deployment_get(context, deployment_id) try: update_and_save(context, deployment, values) except db_exception.DBReferenceError: @@ -1377,7 +1450,7 @@ def software_deployment_update(context, deployment_id, values): def software_deployment_delete(context, deployment_id): - deployment = software_deployment_get(context, deployment_id) + deployment = _software_deployment_get(context, deployment_id) with context.session.begin(): context.session.delete(deployment) @@ -1393,6 +1466,10 @@ def snapshot_create(context, values): def snapshot_get(context, snapshot_id): + return _snapshot_get(context, snapshot_id) + + +def _snapshot_get(context, snapshot_id): result = context.session.get(models.Snapshot, snapshot_id) if (result is not None and context is not None and context.tenant_id != result.tenant): @@ -1405,7 +1482,7 @@ def snapshot_get(context, snapshot_id): def snapshot_get_by_stack(context, snapshot_id, stack): - snapshot = snapshot_get(context, snapshot_id) + snapshot = _snapshot_get(context, snapshot_id) if snapshot.stack_id != stack.id: raise exception.SnapshotNotFound(snapshot=snapshot_id, stack=stack.name) @@ -1414,14 +1491,14 @@ def snapshot_get_by_stack(context, snapshot_id, stack): def snapshot_update(context, snapshot_id, values): - snapshot = snapshot_get(context, snapshot_id) + snapshot = _snapshot_get(context, snapshot_id) snapshot.update(values) snapshot.save(context.session) return snapshot def snapshot_delete(context, snapshot_id): - snapshot = snapshot_get(context, snapshot_id) + snapshot = _snapshot_get(context, snapshot_id) with context.session.begin(): context.session.delete(snapshot) @@ -1442,7 +1519,7 @@ def service_create(context, values): def service_update(context, service_id, values): - service = service_get(context, service_id) + service = _service_get(context, service_id) values.update({'updated_at': timeutils.utcnow()}) service.update(values) service.save(context.session) @@ -1450,7 +1527,7 @@ def service_update(context, service_id, values): def service_delete(context, service_id, soft_delete=True): - service = service_get(context, service_id) + service = _service_get(context, service_id) with context.session.begin(): if soft_delete: _soft_delete(context, service) @@ -1459,6 +1536,10 @@ def service_delete(context, service_id, soft_delete=True): def service_get(context, service_id): + return _service_get(context, service_id) + + +def _service_get(context, service_id): result = context.session.get(models.Service, service_id) if result is None: raise exception.EntityNotFound(entity='Service', name=service_id) @@ -1466,8 +1547,9 @@ def service_get(context, service_id): def service_get_all(context): - return (context.session.query(models.Service). - filter_by(deleted_at=None).all()) + return context.session.query(models.Service).filter_by( + deleted_at=None, + ).all() def service_get_all_by_args(context, host, binary, hostname):