db: Migrate "share snapshot", "share snapshot instance" APIs to enginefacade

Thankfully the APIs being migrated here were _mostly_ sharing sessions
already, so we can simply migrate from public (decorated) methods to
private methods with minimal fuss.

Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
Change-Id: Id1b555e48106662d15e8c50567a5f3acecf6a8f1
This commit is contained in:
Stephen Finucane 2023-11-01 09:42:24 +00:00
parent 13b4c31117
commit 685acf6013
2 changed files with 154 additions and 131 deletions

View File

@ -1684,65 +1684,70 @@ def _share_instance_update(context, share_instance_id, values):
return share_instance_ref return share_instance_ref
@context_manager.writer
def share_and_snapshot_instances_status_update( def share_and_snapshot_instances_status_update(
context, values, share_instance_ids=None, snapshot_instance_ids=None, context, values, share_instance_ids=None, snapshot_instance_ids=None,
current_expected_status=None): current_expected_status=None,
):
updated_share_instances = None updated_share_instances = None
updated_snapshot_instances = None updated_snapshot_instances = None
session = get_session()
with session.begin():
if current_expected_status and share_instance_ids:
filters = {'instance_ids': share_instance_ids}
share_instances = _share_instance_get_all(
context, filters=filters, session=session)
all_instances_are_compliant = all(
instance['status'] == current_expected_status
for instance in share_instances)
if not all_instances_are_compliant: if current_expected_status and share_instance_ids:
msg = _('At least one of the shares is not in the %(status)s ' filters = {'instance_ids': share_instance_ids}
'status.') % { share_instances = _share_instance_get_all(context, filters=filters)
'status': current_expected_status all_instances_are_compliant = all(
} instance['status'] == current_expected_status
raise exception.InvalidShareInstance(reason=msg) for instance in share_instances)
if current_expected_status and snapshot_instance_ids: if not all_instances_are_compliant:
filters = {'instance_ids': snapshot_instance_ids} msg = _('At least one of the shares is not in the %(status)s '
snapshot_instances = share_snapshot_instance_get_all_with_filters( 'status.') % {
context, filters, session=session) 'status': current_expected_status
all_snap_instances_are_compliant = all( }
snap_instance['status'] == current_expected_status raise exception.InvalidShareInstance(reason=msg)
for snap_instance in snapshot_instances)
if not all_snap_instances_are_compliant:
msg = _('At least one of the snapshots is not in the '
'%(status)s status.') % {
'status': current_expected_status
}
raise exception.InvalidShareSnapshotInstance(reason=msg)
if share_instance_ids: if current_expected_status and snapshot_instance_ids:
updated_share_instances = share_instance_status_update( filters = {'instance_ids': snapshot_instance_ids}
context, share_instance_ids, values, session=session) snapshot_instances = _share_snapshot_instance_get_all_with_filters(
context, filters,
)
all_snap_instances_are_compliant = all(
snap_instance['status'] == current_expected_status
for snap_instance in snapshot_instances)
if not all_snap_instances_are_compliant:
msg = _('At least one of the snapshots is not in the '
'%(status)s status.') % {
'status': current_expected_status
}
raise exception.InvalidShareSnapshotInstance(reason=msg)
if snapshot_instance_ids: if share_instance_ids:
updated_snapshot_instances = ( updated_share_instances = _share_instance_status_update(
share_snapshot_instances_status_update( context, share_instance_ids, values,
context, snapshot_instance_ids, values, session=session)) )
if snapshot_instance_ids:
updated_snapshot_instances = _share_snapshot_instances_status_update(
context, snapshot_instance_ids, values,
)
return updated_share_instances, updated_snapshot_instances return updated_share_instances, updated_snapshot_instances
@require_context @require_context
def share_instance_status_update( @context_manager.writer
context, share_instance_ids, values, session=None): def share_instance_status_update(context, share_instance_ids, values):
session = session or get_session() return _share_instance_status_update(context, share_instance_ids, values)
result = (
model_query( def _share_instance_status_update(context, share_instance_ids, values):
context, models.ShareInstance, read_deleted="no", result = model_query(
session=session).filter( context, models.ShareInstance, read_deleted="no",
models.ShareInstance.id.in_(share_instance_ids)).update( ).filter(
values, synchronize_session=False)) models.ShareInstance.id.in_(share_instance_ids)
).update(
values, synchronize_session=False,
)
return result return result
@ -1784,12 +1789,10 @@ def share_instance_get_all(context, filters=None, session=None):
return _share_instance_get_all(context, filters=filters) return _share_instance_get_all(context, filters=filters)
# TODO(stephenfin): Remove the 'session' argument once all callers have been
# converted
@require_admin_context @require_admin_context
def _share_instance_get_all(context, filters=None, session=None): def _share_instance_get_all(context, filters=None):
query = model_query( query = model_query(
context, models.ShareInstance, session=session, read_deleted="no", context, models.ShareInstance, read_deleted="no",
).options( ).options(
joinedload('export_locations'), joinedload('export_locations'),
) )
@ -3258,8 +3261,12 @@ def share_instance_access_update(context, access_id, instance_id, updates):
@require_context @require_context
def share_snapshot_instance_create(context, snapshot_id, values, session=None): @context_manager.writer
session = session or get_session() def share_snapshot_instance_create(context, snapshot_id, values):
return _share_snapshot_instance_create(context, snapshot_id, values)
def _share_snapshot_instance_create(context, snapshot_id, values):
values = copy.deepcopy(values) values = copy.deepcopy(values)
values['share_snapshot_metadata'] = _metadata_refs( values['share_snapshot_metadata'] = _metadata_refs(
values.get('metadata'), models.ShareSnapshotMetadata) values.get('metadata'), models.ShareSnapshotMetadata)
@ -3272,17 +3279,15 @@ def share_snapshot_instance_create(context, snapshot_id, values, session=None):
instance_ref = models.ShareSnapshotInstance() instance_ref = models.ShareSnapshotInstance()
instance_ref.update(values) instance_ref.update(values)
instance_ref.save(session=session) instance_ref.save(session=context.session)
return share_snapshot_instance_get(context, instance_ref['id'], return _share_snapshot_instance_get(context, instance_ref['id'])
session=session)
@require_context @require_context
@context_manager.writer
def share_snapshot_instance_update(context, instance_id, values): def share_snapshot_instance_update(context, instance_id, values):
session = get_session() instance_ref = _share_snapshot_instance_get(context, instance_id)
instance_ref = share_snapshot_instance_get(context, instance_id,
session=session)
_change_size_to_instance_size(values) _change_size_to_instance_size(values)
# NOTE(u_glide): Ignore updates to custom properties # NOTE(u_glide): Ignore updates to custom properties
@ -3291,7 +3296,7 @@ def share_snapshot_instance_update(context, instance_id, values):
values.pop(extra_key) values.pop(extra_key)
instance_ref.update(values) instance_ref.update(values)
instance_ref.save(session=session) instance_ref.save(session=context.session)
return instance_ref return instance_ref
@ -3302,7 +3307,7 @@ def share_snapshot_instance_delete(context, snapshot_instance_id,
with session.begin(): with session.begin():
snapshot_instance_ref = share_snapshot_instance_get( snapshot_instance_ref = _share_snapshot_instance_get(
context, snapshot_instance_id, session=session) context, snapshot_instance_id, session=session)
access_rules = share_snapshot_access_get_all_for_snapshot_instance( access_rules = share_snapshot_access_get_all_for_snapshot_instance(
@ -3316,7 +3321,7 @@ def share_snapshot_instance_delete(context, snapshot_instance_id,
snapshot_instance_ref.soft_delete( snapshot_instance_ref.soft_delete(
session=session, update_status=True) session=session, update_status=True)
snapshot = share_snapshot_get( snapshot = _share_snapshot_get(
context, snapshot_instance_ref['snapshot_id'], session=session) context, snapshot_instance_ref['snapshot_id'], session=session)
if len(snapshot.instances) == 0: if len(snapshot.instances) == 0:
session.query(models.ShareSnapshotMetadata).filter_by( session.query(models.ShareSnapshotMetadata).filter_by(
@ -3325,11 +3330,18 @@ def share_snapshot_instance_delete(context, snapshot_instance_id,
@require_context @require_context
def share_snapshot_instance_get(context, snapshot_instance_id, session=None, @context_manager.reader
def share_snapshot_instance_get(context, snapshot_instance_id,
with_share_data=False): with_share_data=False):
return _share_snapshot_instance_get(
context, snapshot_instance_id, with_share_data=with_share_data,
)
session = session or get_session()
# TODO(stephenfin): Remove the 'session' argument once all callers have been
# converted
def _share_snapshot_instance_get(context, snapshot_instance_id, session=None,
with_share_data=False):
result = _share_snapshot_instance_get_with_filters( result = _share_snapshot_instance_get_with_filters(
context, instance_ids=[snapshot_instance_id], session=session).first() context, instance_ids=[snapshot_instance_id], session=session).first()
@ -3344,13 +3356,22 @@ def share_snapshot_instance_get(context, snapshot_instance_id, session=None,
@require_context @require_context
def share_snapshot_instance_get_all_with_filters(context, search_filters, @context_manager.reader
with_share_data=False, def share_snapshot_instance_get_all_with_filters(
session=None): context, search_filters, with_share_data=False,
):
"""Get snapshot instances filtered by known attrs, ignore unknown attrs. """Get snapshot instances filtered by known attrs, ignore unknown attrs.
All filters accept list/tuples to filter on, along with simple values. All filters accept list/tuples to filter on, along with simple values.
""" """
return _share_snapshot_instance_get_all_with_filters(
context, search_filters, with_share_data=with_share_data,
)
def _share_snapshot_instance_get_all_with_filters(
context, search_filters, with_share_data=False,
):
def listify(values): def listify(values):
if values: if values:
if not isinstance(values, (list, tuple, set)): if not isinstance(values, (list, tuple, set)):
@ -3358,17 +3379,17 @@ def share_snapshot_instance_get_all_with_filters(context, search_filters,
else: else:
return values return values
session = session or get_session()
_known_filters = ('instance_ids', 'snapshot_ids', 'share_instance_ids', _known_filters = ('instance_ids', 'snapshot_ids', 'share_instance_ids',
'statuses') 'statuses')
filters = {k: listify(search_filters.get(k)) for k in _known_filters} filters = {k: listify(search_filters.get(k)) for k in _known_filters}
result = _share_snapshot_instance_get_with_filters( result = _share_snapshot_instance_get_with_filters(
context, session=session, **filters).all() context, **filters,
).all()
if with_share_data: if with_share_data:
result = _set_share_snapshot_instance_data(context, result, session) result = _set_share_snapshot_instance_data(context, result)
return result return result
@ -3400,7 +3421,11 @@ def _share_snapshot_instance_get_with_filters(context, instance_ids=None,
return query return query
def _set_share_snapshot_instance_data(context, snapshot_instances, session): # TODO(stephenfin): Remove the 'session' argument once all callers have been
# converted
def _set_share_snapshot_instance_data(
context, snapshot_instances, session=None,
):
if snapshot_instances and not isinstance(snapshot_instances, list): if snapshot_instances and not isinstance(snapshot_instances, list):
snapshot_instances = [snapshot_instances] snapshot_instances = [snapshot_instances]
@ -3417,6 +3442,7 @@ def _set_share_snapshot_instance_data(context, snapshot_instances, session):
@require_context @require_context
@context_manager.writer
def share_snapshot_create(context, create_values, def share_snapshot_create(context, create_values,
create_snapshot_instance=True): create_snapshot_instance=True):
values = copy.deepcopy(create_values) values = copy.deepcopy(create_values)
@ -3430,28 +3456,23 @@ def share_snapshot_create(context, create_values,
) )
snapshot_ref.update(snapshot_values) snapshot_ref.update(snapshot_values)
session = get_session() share_ref = _share_get(
with session.begin(): context,
share_ref = _share_get( snapshot_values.get('share_id'),
)
snapshot_instance_values.update(
{'share_instance_id': share_ref.instance.id}
)
snapshot_ref.save(session=context.session)
if create_snapshot_instance:
_share_snapshot_instance_create(
context, context,
snapshot_values.get('share_id'), snapshot_ref['id'],
session=session, snapshot_instance_values,
) )
snapshot_instance_values.update( return _share_snapshot_get(context, snapshot_values['id'])
{'share_instance_id': share_ref.instance.id}
)
snapshot_ref.save(session=session)
if create_snapshot_instance:
share_snapshot_instance_create(
context,
snapshot_ref['id'],
snapshot_instance_values,
session=session
)
return share_snapshot_get(
context, snapshot_values['id'], session=session)
@require_admin_context @require_admin_context
@ -3477,12 +3498,13 @@ def _snapshot_data_get_for_project(
@require_context @require_context
def share_snapshot_get(context, snapshot_id, project_only=True, session=None): @context_manager.reader
return _share_snapshot_get( def share_snapshot_get(context, snapshot_id, project_only=True):
context, snapshot_id, project_only=project_only, session=session, return _share_snapshot_get(context, snapshot_id, project_only=project_only)
)
# TODO(stephenfin): Remove the 'session' argument once all callers have been
# converted
def _share_snapshot_get(context, snapshot_id, project_only=True, session=None): def _share_snapshot_get(context, snapshot_id, project_only=True, session=None):
result = (model_query(context, models.ShareSnapshot, session=session, result = (model_query(context, models.ShareSnapshot, session=session,
project_only=project_only). project_only=project_only).
@ -3601,6 +3623,7 @@ def _share_snapshot_get_all_with_filters(context, project_id=None,
@require_admin_context @require_admin_context
@context_manager.reader
def share_snapshot_get_all(context, filters=None, limit=None, offset=None, def share_snapshot_get_all(context, filters=None, limit=None, offset=None,
sort_key=None, sort_dir=None): sort_key=None, sort_dir=None):
return _share_snapshot_get_all_with_filters( return _share_snapshot_get_all_with_filters(
@ -3609,6 +3632,7 @@ def share_snapshot_get_all(context, filters=None, limit=None, offset=None,
@require_admin_context @require_admin_context
@context_manager.reader
def share_snapshot_get_all_with_count(context, filters=None, limit=None, def share_snapshot_get_all_with_count(context, filters=None, limit=None,
offset=None, sort_key=None, offset=None, sort_key=None,
sort_dir=None): sort_dir=None):
@ -3620,6 +3644,7 @@ def share_snapshot_get_all_with_count(context, filters=None, limit=None,
@require_context @require_context
@context_manager.reader
def share_snapshot_get_all_by_project(context, project_id, filters=None, def share_snapshot_get_all_by_project(context, project_id, filters=None,
limit=None, offset=None, limit=None, offset=None,
sort_key=None, sort_dir=None): sort_key=None, sort_dir=None):
@ -3630,6 +3655,7 @@ def share_snapshot_get_all_by_project(context, project_id, filters=None,
@require_context @require_context
@context_manager.reader
def share_snapshot_get_all_by_project_with_count(context, project_id, def share_snapshot_get_all_by_project_with_count(context, project_id,
filters=None, limit=None, filters=None, limit=None,
offset=None, sort_key=None, offset=None, sort_key=None,
@ -3643,6 +3669,7 @@ def share_snapshot_get_all_by_project_with_count(context, project_id,
@require_context @require_context
@context_manager.reader
def share_snapshot_get_all_for_share( def share_snapshot_get_all_for_share(
context, share_id, filters=None, sort_key=None, sort_dir=None, context, share_id, filters=None, sort_key=None, sort_dir=None,
): ):
@ -3661,8 +3688,8 @@ def _share_snapshot_get_all_for_share(
@require_context @require_context
@context_manager.reader
def share_snapshot_get_latest_for_share(context, share_id): def share_snapshot_get_latest_for_share(context, share_id):
snapshots = _share_snapshot_get_all_with_filters( snapshots = _share_snapshot_get_all_with_filters(
context, share_id=share_id, sort_key='created_at', sort_dir='desc') context, share_id=share_id, sort_key='created_at', sort_dir='desc')
return snapshots[0] if snapshots else None return snapshots[0] if snapshots else None
@ -3670,17 +3697,13 @@ def share_snapshot_get_latest_for_share(context, share_id):
@require_context @require_context
@oslo_db_api.wrap_db_retry(max_retries=5, retry_on_deadlock=True) @oslo_db_api.wrap_db_retry(max_retries=5, retry_on_deadlock=True)
@context_manager.writer
def share_snapshot_update(context, snapshot_id, values): def share_snapshot_update(context, snapshot_id, values):
session = get_session() return _share_snapshot_update(context, snapshot_id, values)
with session.begin():
return _share_snapshot_update(context, snapshot_id, values,
session=session)
# TODO(stephenfin): Remove the 'session' argument once all callers have been def _share_snapshot_update(context, snapshot_id, values):
# converted snapshot_ref = _share_snapshot_get(context, snapshot_id)
def _share_snapshot_update(context, snapshot_id, values, session=None):
snapshot_ref = _share_snapshot_get(context, snapshot_id, session=session)
instance_values, snapshot_values = ( instance_values, snapshot_values = (
_extract_snapshot_instance_values(values) _extract_snapshot_instance_values(values)
@ -3688,26 +3711,34 @@ def _share_snapshot_update(context, snapshot_id, values, session=None):
if snapshot_values: if snapshot_values:
snapshot_ref.update(snapshot_values) snapshot_ref.update(snapshot_values)
snapshot_ref.save(session=session or context.session) snapshot_ref.save(session=context.session)
if instance_values: if instance_values:
snapshot_ref.instance.update(instance_values) snapshot_ref.instance.update(instance_values)
snapshot_ref.instance.save(session=session or context.session) snapshot_ref.instance.save(session=context.session)
return snapshot_ref return snapshot_ref
@require_context @require_context
@context_manager.writer
def share_snapshot_instances_status_update( def share_snapshot_instances_status_update(
context, snapshot_instance_ids, values, session=None): context, snapshot_instance_ids, values,
session = session or get_session() ):
return _share_snapshot_instances_status_update(
context, snapshot_instance_ids, values,
)
result = (
model_query( def _share_snapshot_instances_status_update(
context, models.ShareSnapshotInstance, context, snapshot_instance_ids, values,
read_deleted="no", session=session).filter( ):
models.ShareSnapshotInstance.id.in_(snapshot_instance_ids) result = model_query(
).update(values, synchronize_session=False)) context, models.ShareSnapshotInstance,
read_deleted="no",
).filter(
models.ShareSnapshotInstance.id.in_(snapshot_instance_ids)
).update(values, synchronize_session=False)
return result return result
@ -3842,8 +3873,8 @@ def share_snapshot_access_create(context, values):
access_ref.update(values) access_ref.update(values)
access_ref.save(session=session) access_ref.save(session=session)
snapshot = share_snapshot_get(context, values['share_snapshot_id'], snapshot = _share_snapshot_get(context, values['share_snapshot_id'],
session=session) session=session)
for instance in snapshot.instances: for instance in snapshot.instances:
vals = { vals = {
@ -4059,7 +4090,7 @@ def _share_snapshot_instance_export_locations_get_query(context, session,
@require_context @require_context
def share_snapshot_export_locations_get(context, snapshot_id): def share_snapshot_export_locations_get(context, snapshot_id):
session = get_session() session = get_session()
snapshot = share_snapshot_get(context, snapshot_id, session=session) snapshot = _share_snapshot_get(context, snapshot_id, session=session)
ins_ids = [ins['id'] for ins in snapshot.instances] ins_ids = [ins['id'] for ins in snapshot.instances]
export_locations = _share_snapshot_instance_export_locations_get_query( export_locations = _share_snapshot_instance_export_locations_get_query(
context, session, {}).filter( context, session, {}).filter(

View File

@ -5097,7 +5097,6 @@ class ShareResourcesAPITestCase(test.TestCase):
share_instance = db_utils.create_share_instance( share_instance = db_utils.create_share_instance(
status=constants.STATUS_AVAILABLE, share_id='fake') status=constants.STATUS_AVAILABLE, share_id='fake')
share_instance_ids = [share_instance['id']] share_instance_ids = [share_instance['id']]
fake_session = db_api.get_session()
snap_instances = [ snap_instances = [
db_utils.create_snapshot_instance( db_utils.create_snapshot_instance(
'fake_snapshot_id_1', status=constants.STATUS_CREATING, 'fake_snapshot_id_1', status=constants.STATUS_CREATING,
@ -5108,24 +5107,21 @@ class ShareResourcesAPITestCase(test.TestCase):
values = {'status': constants.STATUS_AVAILABLE} values = {'status': constants.STATUS_AVAILABLE}
mock_update_share_instances = self.mock_object( mock_update_share_instances = self.mock_object(
db_api, 'share_instance_status_update', db_api, '_share_instance_status_update',
mock.Mock(return_value=[share_instance])) mock.Mock(return_value=[share_instance]))
mock_update_snap_instances = self.mock_object( mock_update_snap_instances = self.mock_object(
db_api, 'share_snapshot_instances_status_update', db_api, '_share_snapshot_instances_status_update',
mock.Mock(return_value=snap_instances)) mock.Mock(return_value=snap_instances))
mock_get_session = self.mock_object(
db_api, 'get_session', mock.Mock(return_value=fake_session))
updated_share_instances, updated_snap_instances = ( updated_share_instances, updated_snap_instances = (
db_api.share_and_snapshot_instances_status_update( db_api.share_and_snapshot_instances_status_update(
self.context, values, share_instance_ids=share_instance_ids, self.context, values, share_instance_ids=share_instance_ids,
snapshot_instance_ids=snapshot_instance_ids)) snapshot_instance_ids=snapshot_instance_ids))
mock_get_session.assert_called()
mock_update_share_instances.assert_called_once_with( mock_update_share_instances.assert_called_once_with(
self.context, share_instance_ids, values, session=fake_session) self.context, share_instance_ids, values)
mock_update_snap_instances.assert_called_once_with( mock_update_snap_instances.assert_called_once_with(
self.context, snapshot_instance_ids, values, session=fake_session) self.context, snapshot_instance_ids, values)
self.assertEqual(updated_share_instances, [share_instance]) self.assertEqual(updated_share_instances, [share_instance])
self.assertEqual(updated_snap_instances, snap_instances) self.assertEqual(updated_snap_instances, snap_instances)
@ -5152,15 +5148,12 @@ class ShareResourcesAPITestCase(test.TestCase):
share_instance_ids = [share_instance['id']] share_instance_ids = [share_instance['id']]
snap_instance_ids = [share_snapshot_instance['id']] snap_instance_ids = [share_snapshot_instance['id']]
values = {'status': constants.STATUS_AVAILABLE} values = {'status': constants.STATUS_AVAILABLE}
fake_session = db_api.get_session()
mock_get_session = self.mock_object(
db_api, 'get_session', mock.Mock(return_value=fake_session))
mock_instances_get_all = self.mock_object( mock_instances_get_all = self.mock_object(
db_api, '_share_instance_get_all', db_api, '_share_instance_get_all',
mock.Mock(return_value=[share_instance])) mock.Mock(return_value=[share_instance]))
mock_snap_instances_get_all = self.mock_object( mock_snap_instances_get_all = self.mock_object(
db_api, 'share_snapshot_instance_get_all_with_filters', db_api, '_share_snapshot_instance_get_all_with_filters',
mock.Mock(return_value=[share_snapshot_instance])) mock.Mock(return_value=[share_snapshot_instance]))
self.assertRaises(expected_exc, self.assertRaises(expected_exc,
@ -5171,14 +5164,13 @@ class ShareResourcesAPITestCase(test.TestCase):
snapshot_instance_ids=snap_instance_ids, snapshot_instance_ids=snap_instance_ids,
current_expected_status=constants.STATUS_AVAILABLE) current_expected_status=constants.STATUS_AVAILABLE)
mock_get_session.assert_called()
mock_instances_get_all.assert_called_once_with( mock_instances_get_all.assert_called_once_with(
self.context, filters={'instance_ids': share_instance_ids}, self.context, filters={'instance_ids': share_instance_ids},
session=fake_session) )
if snap_instance_status == constants.STATUS_ERROR: if snap_instance_status == constants.STATUS_ERROR:
mock_snap_instances_get_all.assert_called_once_with( mock_snap_instances_get_all.assert_called_once_with(
self.context, {'instance_ids': snap_instance_ids}, self.context, {'instance_ids': snap_instance_ids},
session=fake_session) )
@ddt.ddt @ddt.ddt