From 685acf601310eda98aa5c25970f6179f2b024358 Mon Sep 17 00:00:00 2001 From: Stephen Finucane Date: Wed, 1 Nov 2023 09:42:24 +0000 Subject: [PATCH] db: Migrate "share snapshot", "share snapshot instance" APIs to enginefacade Thankfully the APIs being migrated here were _mostly_ sharing sessions already, so we can simply migrate from public (decorated) methods to private methods with minimal fuss. Signed-off-by: Stephen Finucane Change-Id: Id1b555e48106662d15e8c50567a5f3acecf6a8f1 --- manila/db/sqlalchemy/api.py | 263 ++++++++++++++----------- manila/tests/db/sqlalchemy/test_api.py | 22 +-- 2 files changed, 154 insertions(+), 131 deletions(-) diff --git a/manila/db/sqlalchemy/api.py b/manila/db/sqlalchemy/api.py index 9e78c4fef4..a16c3bbeb4 100644 --- a/manila/db/sqlalchemy/api.py +++ b/manila/db/sqlalchemy/api.py @@ -1684,65 +1684,70 @@ def _share_instance_update(context, share_instance_id, values): return share_instance_ref +@context_manager.writer def share_and_snapshot_instances_status_update( - context, values, share_instance_ids=None, snapshot_instance_ids=None, - current_expected_status=None): + context, values, share_instance_ids=None, snapshot_instance_ids=None, + current_expected_status=None, +): updated_share_instances = None updated_snapshot_instances = None - session = get_session() - with session.begin(): - if current_expected_status and share_instance_ids: - filters = {'instance_ids': share_instance_ids} - share_instances = _share_instance_get_all( - context, filters=filters, session=session) - all_instances_are_compliant = all( - instance['status'] == current_expected_status - for instance in share_instances) - if not all_instances_are_compliant: - msg = _('At least one of the shares is not in the %(status)s ' - 'status.') % { - 'status': current_expected_status - } - raise exception.InvalidShareInstance(reason=msg) + if current_expected_status and share_instance_ids: + filters = {'instance_ids': share_instance_ids} + share_instances = _share_instance_get_all(context, filters=filters) + all_instances_are_compliant = all( + instance['status'] == current_expected_status + for instance in share_instances) - if current_expected_status and snapshot_instance_ids: - filters = {'instance_ids': snapshot_instance_ids} - snapshot_instances = share_snapshot_instance_get_all_with_filters( - context, filters, session=session) - all_snap_instances_are_compliant = all( - snap_instance['status'] == current_expected_status - for snap_instance in snapshot_instances) - if not all_snap_instances_are_compliant: - msg = _('At least one of the snapshots is not in the ' - '%(status)s status.') % { - 'status': current_expected_status - } - raise exception.InvalidShareSnapshotInstance(reason=msg) + if not all_instances_are_compliant: + msg = _('At least one of the shares is not in the %(status)s ' + 'status.') % { + 'status': current_expected_status + } + raise exception.InvalidShareInstance(reason=msg) - if share_instance_ids: - updated_share_instances = share_instance_status_update( - context, share_instance_ids, values, session=session) + if current_expected_status and snapshot_instance_ids: + filters = {'instance_ids': snapshot_instance_ids} + snapshot_instances = _share_snapshot_instance_get_all_with_filters( + context, filters, + ) + all_snap_instances_are_compliant = all( + snap_instance['status'] == current_expected_status + for snap_instance in snapshot_instances) + if not all_snap_instances_are_compliant: + msg = _('At least one of the snapshots is not in the ' + '%(status)s status.') % { + 'status': current_expected_status + } + raise exception.InvalidShareSnapshotInstance(reason=msg) - if snapshot_instance_ids: - updated_snapshot_instances = ( - share_snapshot_instances_status_update( - context, snapshot_instance_ids, values, session=session)) + if share_instance_ids: + updated_share_instances = _share_instance_status_update( + context, share_instance_ids, values, + ) + + if snapshot_instance_ids: + updated_snapshot_instances = _share_snapshot_instances_status_update( + context, snapshot_instance_ids, values, + ) return updated_share_instances, updated_snapshot_instances @require_context -def share_instance_status_update( - context, share_instance_ids, values, session=None): - session = session or get_session() +@context_manager.writer +def share_instance_status_update(context, share_instance_ids, values): + return _share_instance_status_update(context, share_instance_ids, values) - result = ( - model_query( - context, models.ShareInstance, read_deleted="no", - session=session).filter( - models.ShareInstance.id.in_(share_instance_ids)).update( - values, synchronize_session=False)) + +def _share_instance_status_update(context, share_instance_ids, values): + result = model_query( + context, models.ShareInstance, read_deleted="no", + ).filter( + models.ShareInstance.id.in_(share_instance_ids) + ).update( + values, synchronize_session=False, + ) return result @@ -1784,12 +1789,10 @@ def share_instance_get_all(context, filters=None, session=None): return _share_instance_get_all(context, filters=filters) -# TODO(stephenfin): Remove the 'session' argument once all callers have been -# converted @require_admin_context -def _share_instance_get_all(context, filters=None, session=None): +def _share_instance_get_all(context, filters=None): query = model_query( - context, models.ShareInstance, session=session, read_deleted="no", + context, models.ShareInstance, read_deleted="no", ).options( joinedload('export_locations'), ) @@ -3258,8 +3261,12 @@ def share_instance_access_update(context, access_id, instance_id, updates): @require_context -def share_snapshot_instance_create(context, snapshot_id, values, session=None): - session = session or get_session() +@context_manager.writer +def share_snapshot_instance_create(context, snapshot_id, values): + return _share_snapshot_instance_create(context, snapshot_id, values) + + +def _share_snapshot_instance_create(context, snapshot_id, values): values = copy.deepcopy(values) values['share_snapshot_metadata'] = _metadata_refs( values.get('metadata'), models.ShareSnapshotMetadata) @@ -3272,17 +3279,15 @@ def share_snapshot_instance_create(context, snapshot_id, values, session=None): instance_ref = models.ShareSnapshotInstance() instance_ref.update(values) - instance_ref.save(session=session) + instance_ref.save(session=context.session) - return share_snapshot_instance_get(context, instance_ref['id'], - session=session) + return _share_snapshot_instance_get(context, instance_ref['id']) @require_context +@context_manager.writer def share_snapshot_instance_update(context, instance_id, values): - session = get_session() - instance_ref = share_snapshot_instance_get(context, instance_id, - session=session) + instance_ref = _share_snapshot_instance_get(context, instance_id) _change_size_to_instance_size(values) # NOTE(u_glide): Ignore updates to custom properties @@ -3291,7 +3296,7 @@ def share_snapshot_instance_update(context, instance_id, values): values.pop(extra_key) instance_ref.update(values) - instance_ref.save(session=session) + instance_ref.save(session=context.session) return instance_ref @@ -3302,7 +3307,7 @@ def share_snapshot_instance_delete(context, snapshot_instance_id, with session.begin(): - snapshot_instance_ref = share_snapshot_instance_get( + snapshot_instance_ref = _share_snapshot_instance_get( context, snapshot_instance_id, session=session) access_rules = share_snapshot_access_get_all_for_snapshot_instance( @@ -3316,7 +3321,7 @@ def share_snapshot_instance_delete(context, snapshot_instance_id, snapshot_instance_ref.soft_delete( session=session, update_status=True) - snapshot = share_snapshot_get( + snapshot = _share_snapshot_get( context, snapshot_instance_ref['snapshot_id'], session=session) if len(snapshot.instances) == 0: session.query(models.ShareSnapshotMetadata).filter_by( @@ -3325,11 +3330,18 @@ def share_snapshot_instance_delete(context, snapshot_instance_id, @require_context -def share_snapshot_instance_get(context, snapshot_instance_id, session=None, +@context_manager.reader +def share_snapshot_instance_get(context, snapshot_instance_id, with_share_data=False): + return _share_snapshot_instance_get( + context, snapshot_instance_id, with_share_data=with_share_data, + ) - session = session or get_session() +# TODO(stephenfin): Remove the 'session' argument once all callers have been +# converted +def _share_snapshot_instance_get(context, snapshot_instance_id, session=None, + with_share_data=False): result = _share_snapshot_instance_get_with_filters( context, instance_ids=[snapshot_instance_id], session=session).first() @@ -3344,13 +3356,22 @@ def share_snapshot_instance_get(context, snapshot_instance_id, session=None, @require_context -def share_snapshot_instance_get_all_with_filters(context, search_filters, - with_share_data=False, - session=None): +@context_manager.reader +def share_snapshot_instance_get_all_with_filters( + context, search_filters, with_share_data=False, +): """Get snapshot instances filtered by known attrs, ignore unknown attrs. All filters accept list/tuples to filter on, along with simple values. """ + return _share_snapshot_instance_get_all_with_filters( + context, search_filters, with_share_data=with_share_data, + ) + + +def _share_snapshot_instance_get_all_with_filters( + context, search_filters, with_share_data=False, +): def listify(values): if values: if not isinstance(values, (list, tuple, set)): @@ -3358,17 +3379,17 @@ def share_snapshot_instance_get_all_with_filters(context, search_filters, else: return values - session = session or get_session() _known_filters = ('instance_ids', 'snapshot_ids', 'share_instance_ids', 'statuses') filters = {k: listify(search_filters.get(k)) for k in _known_filters} result = _share_snapshot_instance_get_with_filters( - context, session=session, **filters).all() + context, **filters, + ).all() if with_share_data: - result = _set_share_snapshot_instance_data(context, result, session) + result = _set_share_snapshot_instance_data(context, result) return result @@ -3400,7 +3421,11 @@ def _share_snapshot_instance_get_with_filters(context, instance_ids=None, return query -def _set_share_snapshot_instance_data(context, snapshot_instances, session): +# TODO(stephenfin): Remove the 'session' argument once all callers have been +# converted +def _set_share_snapshot_instance_data( + context, snapshot_instances, session=None, +): if snapshot_instances and not isinstance(snapshot_instances, list): snapshot_instances = [snapshot_instances] @@ -3417,6 +3442,7 @@ def _set_share_snapshot_instance_data(context, snapshot_instances, session): @require_context +@context_manager.writer def share_snapshot_create(context, create_values, create_snapshot_instance=True): values = copy.deepcopy(create_values) @@ -3430,28 +3456,23 @@ def share_snapshot_create(context, create_values, ) snapshot_ref.update(snapshot_values) - session = get_session() - with session.begin(): - share_ref = _share_get( + share_ref = _share_get( + context, + snapshot_values.get('share_id'), + ) + snapshot_instance_values.update( + {'share_instance_id': share_ref.instance.id} + ) + + snapshot_ref.save(session=context.session) + + if create_snapshot_instance: + _share_snapshot_instance_create( context, - snapshot_values.get('share_id'), - session=session, + snapshot_ref['id'], + snapshot_instance_values, ) - snapshot_instance_values.update( - {'share_instance_id': share_ref.instance.id} - ) - - snapshot_ref.save(session=session) - - if create_snapshot_instance: - share_snapshot_instance_create( - context, - snapshot_ref['id'], - snapshot_instance_values, - session=session - ) - return share_snapshot_get( - context, snapshot_values['id'], session=session) + return _share_snapshot_get(context, snapshot_values['id']) @require_admin_context @@ -3477,12 +3498,13 @@ def _snapshot_data_get_for_project( @require_context -def share_snapshot_get(context, snapshot_id, project_only=True, session=None): - return _share_snapshot_get( - context, snapshot_id, project_only=project_only, session=session, - ) +@context_manager.reader +def share_snapshot_get(context, snapshot_id, project_only=True): + return _share_snapshot_get(context, snapshot_id, project_only=project_only) +# TODO(stephenfin): Remove the 'session' argument once all callers have been +# converted def _share_snapshot_get(context, snapshot_id, project_only=True, session=None): result = (model_query(context, models.ShareSnapshot, session=session, project_only=project_only). @@ -3601,6 +3623,7 @@ def _share_snapshot_get_all_with_filters(context, project_id=None, @require_admin_context +@context_manager.reader def share_snapshot_get_all(context, filters=None, limit=None, offset=None, sort_key=None, sort_dir=None): return _share_snapshot_get_all_with_filters( @@ -3609,6 +3632,7 @@ def share_snapshot_get_all(context, filters=None, limit=None, offset=None, @require_admin_context +@context_manager.reader def share_snapshot_get_all_with_count(context, filters=None, limit=None, offset=None, sort_key=None, sort_dir=None): @@ -3620,6 +3644,7 @@ def share_snapshot_get_all_with_count(context, filters=None, limit=None, @require_context +@context_manager.reader def share_snapshot_get_all_by_project(context, project_id, filters=None, limit=None, offset=None, sort_key=None, sort_dir=None): @@ -3630,6 +3655,7 @@ def share_snapshot_get_all_by_project(context, project_id, filters=None, @require_context +@context_manager.reader def share_snapshot_get_all_by_project_with_count(context, project_id, filters=None, limit=None, offset=None, sort_key=None, @@ -3643,6 +3669,7 @@ def share_snapshot_get_all_by_project_with_count(context, project_id, @require_context +@context_manager.reader def share_snapshot_get_all_for_share( context, share_id, filters=None, sort_key=None, sort_dir=None, ): @@ -3661,8 +3688,8 @@ def _share_snapshot_get_all_for_share( @require_context +@context_manager.reader def share_snapshot_get_latest_for_share(context, share_id): - snapshots = _share_snapshot_get_all_with_filters( context, share_id=share_id, sort_key='created_at', sort_dir='desc') return snapshots[0] if snapshots else None @@ -3670,17 +3697,13 @@ def share_snapshot_get_latest_for_share(context, share_id): @require_context @oslo_db_api.wrap_db_retry(max_retries=5, retry_on_deadlock=True) +@context_manager.writer def share_snapshot_update(context, snapshot_id, values): - session = get_session() - with session.begin(): - return _share_snapshot_update(context, snapshot_id, values, - session=session) + return _share_snapshot_update(context, snapshot_id, values) -# TODO(stephenfin): Remove the 'session' argument once all callers have been -# converted -def _share_snapshot_update(context, snapshot_id, values, session=None): - snapshot_ref = _share_snapshot_get(context, snapshot_id, session=session) +def _share_snapshot_update(context, snapshot_id, values): + snapshot_ref = _share_snapshot_get(context, snapshot_id) instance_values, snapshot_values = ( _extract_snapshot_instance_values(values) @@ -3688,26 +3711,34 @@ def _share_snapshot_update(context, snapshot_id, values, session=None): if snapshot_values: snapshot_ref.update(snapshot_values) - snapshot_ref.save(session=session or context.session) + snapshot_ref.save(session=context.session) if instance_values: snapshot_ref.instance.update(instance_values) - snapshot_ref.instance.save(session=session or context.session) + snapshot_ref.instance.save(session=context.session) return snapshot_ref @require_context +@context_manager.writer def share_snapshot_instances_status_update( - context, snapshot_instance_ids, values, session=None): - session = session or get_session() + context, snapshot_instance_ids, values, +): + return _share_snapshot_instances_status_update( + context, snapshot_instance_ids, values, + ) - result = ( - model_query( - context, models.ShareSnapshotInstance, - read_deleted="no", session=session).filter( - models.ShareSnapshotInstance.id.in_(snapshot_instance_ids) - ).update(values, synchronize_session=False)) + +def _share_snapshot_instances_status_update( + context, snapshot_instance_ids, values, +): + result = model_query( + context, models.ShareSnapshotInstance, + read_deleted="no", + ).filter( + models.ShareSnapshotInstance.id.in_(snapshot_instance_ids) + ).update(values, synchronize_session=False) return result @@ -3842,8 +3873,8 @@ def share_snapshot_access_create(context, values): access_ref.update(values) access_ref.save(session=session) - snapshot = share_snapshot_get(context, values['share_snapshot_id'], - session=session) + snapshot = _share_snapshot_get(context, values['share_snapshot_id'], + session=session) for instance in snapshot.instances: vals = { @@ -4059,7 +4090,7 @@ def _share_snapshot_instance_export_locations_get_query(context, session, @require_context def share_snapshot_export_locations_get(context, snapshot_id): session = get_session() - snapshot = share_snapshot_get(context, snapshot_id, session=session) + snapshot = _share_snapshot_get(context, snapshot_id, session=session) ins_ids = [ins['id'] for ins in snapshot.instances] export_locations = _share_snapshot_instance_export_locations_get_query( context, session, {}).filter( diff --git a/manila/tests/db/sqlalchemy/test_api.py b/manila/tests/db/sqlalchemy/test_api.py index d7f0ad41be..84dad1a296 100644 --- a/manila/tests/db/sqlalchemy/test_api.py +++ b/manila/tests/db/sqlalchemy/test_api.py @@ -5097,7 +5097,6 @@ class ShareResourcesAPITestCase(test.TestCase): share_instance = db_utils.create_share_instance( status=constants.STATUS_AVAILABLE, share_id='fake') share_instance_ids = [share_instance['id']] - fake_session = db_api.get_session() snap_instances = [ db_utils.create_snapshot_instance( 'fake_snapshot_id_1', status=constants.STATUS_CREATING, @@ -5108,24 +5107,21 @@ class ShareResourcesAPITestCase(test.TestCase): values = {'status': constants.STATUS_AVAILABLE} mock_update_share_instances = self.mock_object( - db_api, 'share_instance_status_update', + db_api, '_share_instance_status_update', mock.Mock(return_value=[share_instance])) mock_update_snap_instances = self.mock_object( - db_api, 'share_snapshot_instances_status_update', + db_api, '_share_snapshot_instances_status_update', mock.Mock(return_value=snap_instances)) - mock_get_session = self.mock_object( - db_api, 'get_session', mock.Mock(return_value=fake_session)) updated_share_instances, updated_snap_instances = ( db_api.share_and_snapshot_instances_status_update( self.context, values, share_instance_ids=share_instance_ids, snapshot_instance_ids=snapshot_instance_ids)) - mock_get_session.assert_called() mock_update_share_instances.assert_called_once_with( - self.context, share_instance_ids, values, session=fake_session) + self.context, share_instance_ids, values) mock_update_snap_instances.assert_called_once_with( - self.context, snapshot_instance_ids, values, session=fake_session) + self.context, snapshot_instance_ids, values) self.assertEqual(updated_share_instances, [share_instance]) self.assertEqual(updated_snap_instances, snap_instances) @@ -5152,15 +5148,12 @@ class ShareResourcesAPITestCase(test.TestCase): share_instance_ids = [share_instance['id']] snap_instance_ids = [share_snapshot_instance['id']] values = {'status': constants.STATUS_AVAILABLE} - fake_session = db_api.get_session() - mock_get_session = self.mock_object( - db_api, 'get_session', mock.Mock(return_value=fake_session)) mock_instances_get_all = self.mock_object( db_api, '_share_instance_get_all', mock.Mock(return_value=[share_instance])) mock_snap_instances_get_all = self.mock_object( - db_api, 'share_snapshot_instance_get_all_with_filters', + db_api, '_share_snapshot_instance_get_all_with_filters', mock.Mock(return_value=[share_snapshot_instance])) self.assertRaises(expected_exc, @@ -5171,14 +5164,13 @@ class ShareResourcesAPITestCase(test.TestCase): snapshot_instance_ids=snap_instance_ids, current_expected_status=constants.STATUS_AVAILABLE) - mock_get_session.assert_called() mock_instances_get_all.assert_called_once_with( self.context, filters={'instance_ids': share_instance_ids}, - session=fake_session) + ) if snap_instance_status == constants.STATUS_ERROR: mock_snap_instances_get_all.assert_called_once_with( self.context, {'instance_ids': snap_instance_ids}, - session=fake_session) + ) @ddt.ddt