db: Only call private methods from public methods

This is the final preparation step for our removal of autocommit. This
makes our life easier since we can insist on only opening transactions
inside the public method.

On a related point, we remove one commit call from a private method and
keep another.

Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
Change-Id: I75cf29a18b2179949ee4713f85ca52256d170c84
This commit is contained in:
Stephen Finucane 2023-09-14 12:02:41 +01:00 committed by Takashi Kajinami
parent 9ba72bbc18
commit 67070d4ed5
1 changed files with 165 additions and 83 deletions

View File

@ -34,7 +34,6 @@ from sqlalchemy import and_
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import orm
from sqlalchemy.orm import aliased as orm_aliased
from heat.common import crypt
from heat.common import exception
@ -153,6 +152,10 @@ def _soft_delete_aware_query(context, *args, **kwargs):
def raw_template_get(context, template_id):
return _raw_template_get(context, template_id)
def _raw_template_get(context, template_id):
result = context.session.get(models.RawTemplate, template_id)
if not result:
@ -169,7 +172,7 @@ def raw_template_create(context, values):
def raw_template_update(context, template_id, values):
raw_template_ref = raw_template_get(context, template_id)
raw_template_ref = _raw_template_get(context, template_id)
# get only the changed values
values = dict((k, v) for k, v in values.items()
if getattr(raw_template_ref, k) != v)
@ -182,7 +185,7 @@ def raw_template_update(context, template_id, values):
def raw_template_delete(context, template_id):
try:
raw_template = raw_template_get(context, template_id)
raw_template = _raw_template_get(context, template_id)
except exception.NotFound:
# Ignore not found
return
@ -196,7 +199,7 @@ def raw_template_delete(context, template_id):
if context.session.query(models.RawTemplate).filter_by(
files_id=raw_tmpl_files_id).first() is None:
try:
raw_tmpl_files = raw_template_files_get(
raw_tmpl_files = _raw_template_files_get(
context, raw_tmpl_files_id)
except exception.NotFound:
# Ignore not found
@ -216,6 +219,10 @@ def raw_template_files_create(context, values):
def raw_template_files_get(context, files_id):
return _raw_template_files_get(context, files_id)
def _raw_template_files_get(context, files_id):
result = context.session.get(models.RawTemplateFiles, files_id)
if not result:
raise exception.NotFound(
@ -228,6 +235,10 @@ def raw_template_files_get(context, files_id):
def resource_create(context, values):
return _resource_create(context, values)
def _resource_create(context, values):
resource_ref = models.Resource()
resource_ref.update(values)
resource_ref.save(context.session)
@ -241,7 +252,7 @@ def resource_create_replacement(context,
atomic_key, expected_engine_id=None):
try:
with context.session.begin():
new_res = resource_create(context, new_res_values)
new_res = _resource_create(context, new_res_values)
update_data = {'replaced_by': new_res.id}
if not _try_resource_update(context,
existing_res_id, update_data,
@ -354,6 +365,12 @@ def resource_get_by_name_and_stack(context, resource_name, stack_id):
def resource_get_all_by_physical_resource_id(context, physical_resource_id):
return _resource_get_all_by_physical_resource_id(
context, physical_resource_id,
)
def _resource_get_all_by_physical_resource_id(context, physical_resource_id):
results = (context.session.query(models.Resource)
.filter_by(physical_resource_id=physical_resource_id)
.all())
@ -365,8 +382,9 @@ def resource_get_all_by_physical_resource_id(context, physical_resource_id):
def resource_get_by_physical_resource_id(context, physical_resource_id):
results = resource_get_all_by_physical_resource_id(context,
physical_resource_id)
results = _resource_get_all_by_physical_resource_id(
context, physical_resource_id,
)
try:
return next(results)
except StopIteration:
@ -515,15 +533,17 @@ def resource_data_get(context, resource_id, key):
Decrypts resource data if necessary.
"""
result = resource_data_get_by_key(context,
resource_id,
key)
result = _resource_data_get_by_key(context, resource_id, key)
if result.redact:
return crypt.decrypt(result.decrypt_method, result.value)
return result.value
def resource_data_get_by_key(context, resource_id, key):
return _resource_data_get_by_key(context, resource_id, key)
def _resource_data_get_by_key(context, resource_id, key):
"""Looks up resource_data by resource_id and key.
Does not decrypt resource_data.
@ -544,7 +564,7 @@ def resource_data_set(context, resource_id, key, value, redact=False):
else:
method = ''
try:
current = resource_data_get_by_key(context, resource_id, key)
current = _resource_data_get_by_key(context, resource_id, key)
except exception.NotFound:
current = models.ResourceData()
current.key = key
@ -557,7 +577,7 @@ def resource_data_set(context, resource_id, key, value, redact=False):
def resource_data_delete(context, resource_id, key):
result = resource_data_get_by_key(context, resource_id, key)
result = _resource_data_get_by_key(context, resource_id, key)
with context.session.begin():
context.session.delete(result)
@ -566,6 +586,10 @@ def resource_data_delete(context, resource_id, key):
def resource_prop_data_create_or_update(context, values, rpd_id=None):
return _resource_prop_data_create_or_update(context, values, rpd_id=rpd_id)
def _resource_prop_data_create_or_update(context, values, rpd_id=None):
obj_ref = None
if rpd_id is not None:
obj_ref = context.session.query(
@ -578,7 +602,7 @@ def resource_prop_data_create_or_update(context, values, rpd_id=None):
def resource_prop_data_create(context, values):
return resource_prop_data_create_or_update(context, values)
return _resource_prop_data_create_or_update(context, values)
def resource_prop_data_get(context, resource_prop_data_id):
@ -605,6 +629,10 @@ def stack_get_by_name_and_owner_id(context, stack_name, owner_id):
def stack_get_by_name(context, stack_name):
return _stack_get_by_name(context, stack_name)
def _stack_get_by_name(context, stack_name):
query = _soft_delete_aware_query(
context, models.Stack
).filter(sqlalchemy.or_(
@ -615,6 +643,12 @@ def stack_get_by_name(context, stack_name):
def stack_get(context, stack_id, show_deleted=False, eager_load=True):
return _stack_get(
context, stack_id, show_deleted=show_deleted, eager_load=eager_load
)
def _stack_get(context, stack_id, show_deleted=False, eager_load=True):
options = []
if eager_load:
options.append(orm.joinedload(models.Stack.raw_template))
@ -653,6 +687,10 @@ def stack_get_status(context, stack_id):
def stack_get_all_by_owner_id(context, owner_id):
return _stack_get_all_by_owner_id(context, owner_id)
def _stack_get_all_by_owner_id(context, owner_id):
results = _soft_delete_aware_query(
context, models.Stack,
).filter_by(
@ -662,9 +700,13 @@ def stack_get_all_by_owner_id(context, owner_id):
def stack_get_all_by_root_owner_id(context, owner_id):
for stack in stack_get_all_by_owner_id(context, owner_id):
return _stack_get_all_by_root_owner_id(context, owner_id)
def _stack_get_all_by_root_owner_id(context, owner_id):
for stack in _stack_get_all_by_owner_id(context, owner_id):
yield stack
for ch_st in stack_get_all_by_root_owner_id(context, stack.id):
for ch_st in _stack_get_all_by_root_owner_id(context, stack.id):
yield ch_st
@ -722,7 +764,7 @@ def _query_stack_get_all(context, show_deleted=False,
query = query.options(orm.subqueryload(models.Stack.tags))
if tags:
for tag in tags:
tag_alias = orm_aliased(models.StackTag)
tag_alias = orm.aliased(models.StackTag)
query = query.join(tag_alias, models.Stack.tags)
query = query.filter(tag_alias.tag == tag)
@ -736,7 +778,7 @@ def _query_stack_get_all(context, show_deleted=False,
context, models.Stack, show_deleted=show_deleted
)
for tag in not_tags:
tag_alias = orm_aliased(models.StackTag)
tag_alias = orm.aliased(models.StackTag)
subquery = subquery.join(tag_alias, models.Stack.tags)
subquery = subquery.filter(tag_alias.tag == tag)
not_stack_ids = [s.id for s in subquery.all()]
@ -811,7 +853,7 @@ def stack_create(context, values):
# Even though we just created a stack with this name, we may not find
# it again because some unit tests create stacks with deleted_at set. Also
# some backup stacks may not be found, for reasons that are unclear.
earliest = stack_get_by_name(context, stack_name)
earliest = _stack_get_by_name(context, stack_name)
if earliest is not None and earliest.id != stack_ref.id:
with context.session.begin():
context.session.query(models.Stack).filter_by(
@ -841,7 +883,7 @@ def stack_update(context, stack_id, values, exp_trvsl=None):
'expected traversal: %(trav)s',
{'id': stack_id, 'vals': str(values),
'trav': str(exp_trvsl)})
if not stack_get(context, stack_id, eager_load=False):
if not _stack_get(context, stack_id, eager_load=False):
raise exception.NotFound(
_('Attempt to update a stack with id: '
'%(id)s %(msg)s') % {
@ -852,7 +894,7 @@ def stack_update(context, stack_id, values, exp_trvsl=None):
def stack_delete(context, stack_id):
s = stack_get(context, stack_id, eager_load=False)
s = _stack_get(context, stack_id, eager_load=False)
if not s:
raise exception.NotFound(_('Attempt to delete a stack with id: '
'%(id)s %(msg)s') % {
@ -873,7 +915,14 @@ def stack_delete(context, stack_id):
_soft_delete(context, s)
def reset_stack_status(context, stack_id, stack=None):
def reset_stack_status(context, stack_id):
return _reset_stack_status(context, stack_id)
# NOTE(stephenfin): This method uses separate transactions to delete nested
# stacks, thus it's the only private method that is allowed to open a
# transaction (via 'context.session.begin')
def _reset_stack_status(context, stack_id, stack=None):
if stack is None:
stack = context.session.get(models.Stack, stack_id)
@ -901,7 +950,7 @@ def reset_stack_status(context, stack_id, stack=None):
query = context.session.query(models.Stack).filter_by(owner_id=stack_id)
for child in query:
reset_stack_status(context, child.id, child)
_reset_stack_status(context, child.id, child)
with context.session.begin():
if stack.status == 'IN_PROGRESS':
@ -915,7 +964,7 @@ def reset_stack_status(context, stack_id, stack=None):
def stack_tags_set(context, stack_id, tags):
with context.session.begin():
stack_tags_delete(context, stack_id)
_stack_tags_delete(context, stack_id)
result = []
for tag in tags:
stack_tag = models.StackTag()
@ -928,13 +977,21 @@ def stack_tags_set(context, stack_id, tags):
def stack_tags_delete(context, stack_id):
with transaction(context):
result = stack_tags_get(context, stack_id)
if result:
for tag in result:
context.session.delete(tag)
return _stack_tags_delete(context, stack_id)
def _stack_tags_delete(context, stack_id):
result = _stack_tags_get(context, stack_id)
if result:
for tag in result:
context.session.delete(tag)
def stack_tags_get(context, stack_id):
return _stack_tags_get(context, stack_id)
def _stack_tags_get(context, stack_id):
result = (context.session.query(models.StackTag)
.filter_by(stack_id=stack_id)
.all())
@ -1004,11 +1061,11 @@ def stack_lock_release(context, stack_id, engine_id):
def stack_get_root_id(context, stack_id):
s = stack_get(context, stack_id, eager_load=False)
s = _stack_get(context, stack_id, eager_load=False)
if not s:
return None
while s.owner_id:
s = stack_get(context, s.owner_id, eager_load=False)
s = _stack_get(context, s.owner_id, eager_load=False)
return s.id
@ -1152,6 +1209,10 @@ def _events_filter_and_page_query(context, query,
def event_count_all_by_stack(context, stack_id):
return _event_count_all_by_stack(context, stack_id)
def _event_count_all_by_stack(context, stack_id):
query = context.session.query(func.count(models.Event.id))
return query.filter_by(stack_id=stack_id).scalar()
@ -1212,47 +1273,46 @@ def _delete_event_rows(context, stack_id, limit):
# So we must manually supply the IN() values.
# pgsql SHOULD work with the pure DELETE/JOIN below but that must be
# confirmed via integration tests.
with context.session.begin():
query = context.session.query(models.Event).filter_by(
stack_id=stack_id,
)
query = query.order_by(models.Event.id).limit(limit)
id_pairs = [(e.id, e.rsrc_prop_data_id) for e in query.all()]
if not id_pairs:
return 0
(ids, rsrc_prop_ids) = zip(*id_pairs)
max_id = ids[-1]
# delete the events
retval = context.session.query(models.Event).filter(
models.Event.id <= max_id).filter(
models.Event.stack_id == stack_id).delete()
query = context.session.query(models.Event).filter_by(
stack_id=stack_id,
)
query = query.order_by(models.Event.id).limit(limit)
id_pairs = [(e.id, e.rsrc_prop_data_id) for e in query.all()]
if not id_pairs:
return 0
(ids, rsrc_prop_ids) = zip(*id_pairs)
max_id = ids[-1]
# delete the events
retval = context.session.query(models.Event).filter(
models.Event.id <= max_id).filter(
models.Event.stack_id == stack_id).delete()
# delete unreferenced resource_properties_data
def del_rpd(rpd_ids):
if not rpd_ids:
return
q_rpd = context.session.query(models.ResourcePropertiesData)
q_rpd = q_rpd.filter(models.ResourcePropertiesData.id.in_(rpd_ids))
q_rpd.delete(synchronize_session=False)
# delete unreferenced resource_properties_data
def del_rpd(rpd_ids):
if not rpd_ids:
return
q_rpd = context.session.query(models.ResourcePropertiesData)
q_rpd = q_rpd.filter(models.ResourcePropertiesData.id.in_(rpd_ids))
q_rpd.delete(synchronize_session=False)
if rsrc_prop_ids:
clr_prop_ids = set(rsrc_prop_ids) - _find_rpd_references(context,
stack_id)
clr_prop_ids.discard(None)
try:
del_rpd(clr_prop_ids)
except db_exception.DBReferenceError:
LOG.debug('Checking backup/stack pairs for RPD references')
found = False
for partner_stack_id in _all_backup_stack_ids(context,
stack_id):
found = True
clr_prop_ids -= _find_rpd_references(context,
partner_stack_id)
if not found:
LOG.debug('No backup/stack pairs found for %s', stack_id)
raise
del_rpd(clr_prop_ids)
if rsrc_prop_ids:
clr_prop_ids = set(rsrc_prop_ids) - _find_rpd_references(context,
stack_id)
clr_prop_ids.discard(None)
try:
del_rpd(clr_prop_ids)
except db_exception.DBReferenceError:
LOG.debug('Checking backup/stack pairs for RPD references')
found = False
for partner_stack_id in _all_backup_stack_ids(context,
stack_id):
found = True
clr_prop_ids -= _find_rpd_references(context,
partner_stack_id)
if not found:
LOG.debug('No backup/stack pairs found for %s', stack_id)
raise
del_rpd(clr_prop_ids)
return retval
@ -1263,13 +1323,18 @@ def event_create(context, values):
# only count events and purge on average
# 200.0/cfg.CONF.event_purge_batch_size percent of the time.
check = (2.0 / cfg.CONF.event_purge_batch_size) > random.uniform(0, 1)
if (check and
(event_count_all_by_stack(context, values['stack_id']) >=
cfg.CONF.max_events_per_stack)):
if (
check and _event_count_all_by_stack(
context, values['stack_id']
) >= cfg.CONF.max_events_per_stack
):
# prune
try:
_delete_event_rows(context, values['stack_id'],
cfg.CONF.event_purge_batch_size)
with context.session.begin():
_delete_event_rows(
context, values['stack_id'],
cfg.CONF.event_purge_batch_size,
)
except db_exception.DBError as exc:
LOG.error('Failed to purge events: %s', str(exc))
event_ref = models.Event()
@ -1289,6 +1354,10 @@ def software_config_create(context, values):
def software_config_get(context, config_id):
return _software_config_get(context, config_id)
def _software_config_get(context, config_id):
result = context.session.get(models.SoftwareConfig, config_id)
if (result is not None and context is not None and not context.is_admin and
result.tenant != context.tenant_id):
@ -1309,7 +1378,7 @@ def software_config_get_all(context, limit=None, marker=None):
def software_config_delete(context, config_id):
config = software_config_get(context, config_id)
config = _software_config_get(context, config_id)
# Query if the software config has been referenced by deployment.
result = context.session.query(models.SoftwareDeployment).filter_by(
config_id=config_id).first()
@ -1340,6 +1409,10 @@ def software_deployment_create(context, values):
def software_deployment_get(context, deployment_id):
return _software_deployment_get(context, deployment_id)
def _software_deployment_get(context, deployment_id):
result = context.session.get(models.SoftwareDeployment, deployment_id)
if (result is not None and context is not None and not context.is_admin and
context.tenant_id not in (result.tenant,
@ -1366,7 +1439,7 @@ def software_deployment_get_all(context, server_id=None):
def software_deployment_update(context, deployment_id, values):
deployment = software_deployment_get(context, deployment_id)
deployment = _software_deployment_get(context, deployment_id)
try:
update_and_save(context, deployment, values)
except db_exception.DBReferenceError:
@ -1377,7 +1450,7 @@ def software_deployment_update(context, deployment_id, values):
def software_deployment_delete(context, deployment_id):
deployment = software_deployment_get(context, deployment_id)
deployment = _software_deployment_get(context, deployment_id)
with context.session.begin():
context.session.delete(deployment)
@ -1393,6 +1466,10 @@ def snapshot_create(context, values):
def snapshot_get(context, snapshot_id):
return _snapshot_get(context, snapshot_id)
def _snapshot_get(context, snapshot_id):
result = context.session.get(models.Snapshot, snapshot_id)
if (result is not None and context is not None and
context.tenant_id != result.tenant):
@ -1405,7 +1482,7 @@ def snapshot_get(context, snapshot_id):
def snapshot_get_by_stack(context, snapshot_id, stack):
snapshot = snapshot_get(context, snapshot_id)
snapshot = _snapshot_get(context, snapshot_id)
if snapshot.stack_id != stack.id:
raise exception.SnapshotNotFound(snapshot=snapshot_id,
stack=stack.name)
@ -1414,14 +1491,14 @@ def snapshot_get_by_stack(context, snapshot_id, stack):
def snapshot_update(context, snapshot_id, values):
snapshot = snapshot_get(context, snapshot_id)
snapshot = _snapshot_get(context, snapshot_id)
snapshot.update(values)
snapshot.save(context.session)
return snapshot
def snapshot_delete(context, snapshot_id):
snapshot = snapshot_get(context, snapshot_id)
snapshot = _snapshot_get(context, snapshot_id)
with context.session.begin():
context.session.delete(snapshot)
@ -1442,7 +1519,7 @@ def service_create(context, values):
def service_update(context, service_id, values):
service = service_get(context, service_id)
service = _service_get(context, service_id)
values.update({'updated_at': timeutils.utcnow()})
service.update(values)
service.save(context.session)
@ -1450,7 +1527,7 @@ def service_update(context, service_id, values):
def service_delete(context, service_id, soft_delete=True):
service = service_get(context, service_id)
service = _service_get(context, service_id)
with context.session.begin():
if soft_delete:
_soft_delete(context, service)
@ -1459,6 +1536,10 @@ def service_delete(context, service_id, soft_delete=True):
def service_get(context, service_id):
return _service_get(context, service_id)
def _service_get(context, service_id):
result = context.session.get(models.Service, service_id)
if result is None:
raise exception.EntityNotFound(entity='Service', name=service_id)
@ -1466,8 +1547,9 @@ def service_get(context, service_id):
def service_get_all(context):
return (context.session.query(models.Service).
filter_by(deleted_at=None).all())
return context.session.query(models.Service).filter_by(
deleted_at=None,
).all()
def service_get_all_by_args(context, host, binary, hostname):