From ce69e7f8f65b175343d7035b3f084c91fd29fbf5 Mon Sep 17 00:00:00 2001 From: Roman Podoliaka Date: Thu, 23 Jan 2014 17:45:50 +0200 Subject: [PATCH] Don't store engine instances in oslo.db oslo.db is meant to be a collection of helper utilities for SQLAlchemy. In order to behave like a 'good' library, we must stop using global state. This patch ensures we don't store any engine instances globally in oslo.db. It's up to end applications to decide how to cope with engines, not oslo.db. Partial-Bug: #1263908 Co-authored-by: Victor Sergeyev Change-Id: I330467ec1317e1a13ff9f0237e2d8900d718e379 --- openstack/common/db/sqlalchemy/migration.py | 37 ++-- .../sqlalchemy/migration_cli/ext_alembic.py | 3 +- .../sqlalchemy/migration_cli/ext_migrate.py | 9 +- openstack/common/db/sqlalchemy/models.py | 8 +- openstack/common/db/sqlalchemy/session.py | 184 ++++++++++-------- openstack/common/db/sqlalchemy/test_base.py | 23 +-- tests/unit/db/sqlalchemy/test_migrate.py | 5 +- tests/unit/db/sqlalchemy/test_migrate_cli.py | 25 ++- .../db/sqlalchemy/test_migration_common.py | 30 +-- tests/unit/db/sqlalchemy/test_sqlalchemy.py | 63 +++--- 10 files changed, 209 insertions(+), 178 deletions(-) diff --git a/openstack/common/db/sqlalchemy/migration.py b/openstack/common/db/sqlalchemy/migration.py index bf9d150b7..04a6514fa 100644 --- a/openstack/common/db/sqlalchemy/migration.py +++ b/openstack/common/db/sqlalchemy/migration.py @@ -51,13 +51,9 @@ import sqlalchemy from sqlalchemy.schema import UniqueConstraint from openstack.common.db import exception -from openstack.common.db.sqlalchemy import session as db_session from openstack.common.gettextutils import _ -get_engine = db_session.get_engine - - def _get_unique_constraints(self, table): """Retrieve information about existing unique constraints of the table @@ -172,11 +168,12 @@ def patch_migrate(): sqlite.SQLiteConstraintGenerator) -def db_sync(abs_path, version=None, init_version=0): +def db_sync(engine, abs_path, version=None, init_version=0): """Upgrade or downgrade a database. Function runs the upgrade() or downgrade() functions in change scripts. + :param engine: SQLAlchemy engine instance for a given database :param abs_path: Absolute path to migrate repository. :param version: Database will upgrade/downgrade until this version. If None - database will update to the latest @@ -190,18 +187,23 @@ def db_sync(abs_path, version=None, init_version=0): raise exception.DbMigrationError( message=_("version should be an integer")) - current_version = db_version(abs_path, init_version) + current_version = db_version(engine, abs_path, init_version) repository = _find_migrate_repo(abs_path) - _db_schema_sanity_check() + _db_schema_sanity_check(engine) if version is None or version > current_version: - return versioning_api.upgrade(get_engine(), repository, version) + return versioning_api.upgrade(engine, repository, version) else: - return versioning_api.downgrade(get_engine(), repository, + return versioning_api.downgrade(engine, repository, version) -def _db_schema_sanity_check(): - engine = get_engine() +def _db_schema_sanity_check(engine): + """Ensure all database tables were created with required parameters. + + :param engine: SQLAlchemy engine instance for a given database + + """ + if engine.name == 'mysql': onlyutf8_sql = ('SELECT TABLE_NAME,TABLE_COLLATION ' 'from information_schema.TABLES ' @@ -216,23 +218,23 @@ def _db_schema_sanity_check(): ) % ','.join(table_names)) -def db_version(abs_path, init_version): +def db_version(engine, abs_path, init_version): """Show the current version of the repository. + :param engine: SQLAlchemy engine instance for a given database :param abs_path: Absolute path to migrate repository :param version: Initial database version """ repository = _find_migrate_repo(abs_path) try: - return versioning_api.db_version(get_engine(), repository) + return versioning_api.db_version(engine, repository) except versioning_exceptions.DatabaseNotControlledError: meta = sqlalchemy.MetaData() - engine = get_engine() meta.reflect(bind=engine) tables = meta.tables if len(tables) == 0 or 'alembic_version' in tables: db_version_control(abs_path, init_version) - return versioning_api.db_version(get_engine(), repository) + return versioning_api.db_version(engine, repository) else: raise exception.DbMigrationError( message=_( @@ -241,17 +243,18 @@ def db_version(abs_path, init_version): "manually.")) -def db_version_control(abs_path, version=None): +def db_version_control(engine, abs_path, version=None): """Mark a database as under this repository's version control. Once a database is under version control, schema changes should only be done via change scripts in this repository. + :param engine: SQLAlchemy engine instance for a given database :param abs_path: Absolute path to migrate repository :param version: Initial database version """ repository = _find_migrate_repo(abs_path) - versioning_api.version_control(get_engine(), repository, version) + versioning_api.version_control(engine, repository, version) return version diff --git a/openstack/common/db/sqlalchemy/migration_cli/ext_alembic.py b/openstack/common/db/sqlalchemy/migration_cli/ext_alembic.py index 59add6491..03b8a9897 100644 --- a/openstack/common/db/sqlalchemy/migration_cli/ext_alembic.py +++ b/openstack/common/db/sqlalchemy/migration_cli/ext_alembic.py @@ -40,6 +40,7 @@ class AlembicExtension(ext_base.MigrationExtensionBase): repo_path = migration_config.get('alembic_repo_path') if repo_path: self.config.set_main_option('script_location', repo_path) + self.db_url = migration_config['db_url'] def upgrade(self, version): return alembic.command.upgrade(self.config, version or 'head') @@ -50,7 +51,7 @@ class AlembicExtension(ext_base.MigrationExtensionBase): return alembic.command.downgrade(self.config, version) def version(self): - engine = db_session.get_engine() + engine = db_session.create_engine(self.db_url) with engine.connect() as conn: context = alembic_migration.MigrationContext.configure(conn) return context.get_current_revision() diff --git a/openstack/common/db/sqlalchemy/migration_cli/ext_migrate.py b/openstack/common/db/sqlalchemy/migration_cli/ext_migrate.py index fc7e27bcd..d5ae385c4 100644 --- a/openstack/common/db/sqlalchemy/migration_cli/ext_migrate.py +++ b/openstack/common/db/sqlalchemy/migration_cli/ext_migrate.py @@ -15,6 +15,7 @@ import os from openstack.common.db.sqlalchemy import migration from openstack.common.db.sqlalchemy.migration_cli import ext_base +from openstack.common.db.sqlalchemy import session as db_session from openstack.common.gettextutils import _ # noqa @@ -33,6 +34,8 @@ class MigrateExtension(ext_base.MigrationExtensionBase): def __init__(self, migration_config): self.repository = migration_config.get('migration_repo_path', '') self.init_version = migration_config.get('init_version', 0) + self.db_url = migration_config['db_url'] + self.engine = db_session.create_engine(self.db_url) @property def enabled(self): @@ -41,7 +44,7 @@ class MigrateExtension(ext_base.MigrationExtensionBase): def upgrade(self, version): version = None if version == 'head' else version return migration.db_sync( - self.repository, version, + self.engine, self.repository, version, init_version=self.init_version) def downgrade(self, version): @@ -51,7 +54,7 @@ class MigrateExtension(ext_base.MigrationExtensionBase): version = self.init_version version = int(version) return migration.db_sync( - self.repository, version, + self.engine, self.repository, version, init_version=self.init_version) except ValueError: LOG.error( @@ -63,4 +66,4 @@ class MigrateExtension(ext_base.MigrationExtensionBase): def version(self): return migration.db_version( - self.repository, init_version=self.init_version) + self.engine, self.repository, init_version=self.init_version) diff --git a/openstack/common/db/sqlalchemy/models.py b/openstack/common/db/sqlalchemy/models.py index 9d17de7ee..02712f9ae 100644 --- a/openstack/common/db/sqlalchemy/models.py +++ b/openstack/common/db/sqlalchemy/models.py @@ -26,7 +26,6 @@ from sqlalchemy import Column, Integer from sqlalchemy import DateTime from sqlalchemy.orm import object_mapper -from openstack.common.db.sqlalchemy import session as sa from openstack.common import timeutils @@ -34,10 +33,9 @@ class ModelBase(object): """Base class for models.""" __table_initialized__ = False - def save(self, session=None): + def save(self, session): """Save this object.""" - if not session: - session = sa.get_session() + # NOTE(boris-42): This part of code should be look like: # session.add(self) # session.flush() @@ -110,7 +108,7 @@ class SoftDeleteMixin(object): deleted_at = Column(DateTime) deleted = Column(Integer, default=0) - def soft_delete(self, session=None): + def soft_delete(self, session): """Mark this object as deleted.""" self.deleted = self.id self.deleted_at = timeutils.utcnow() diff --git a/openstack/common/db/sqlalchemy/session.py b/openstack/common/db/sqlalchemy/session.py index a72969d3f..9447c5578 100644 --- a/openstack/common/db/sqlalchemy/session.py +++ b/openstack/common/db/sqlalchemy/session.py @@ -87,7 +87,7 @@ Recommended ways to use sessions within this framework: .. code:: python def create_many_foo(context, foos): - session = get_session() + session = sessionmaker() with session.begin(): for foo in foos: foo_ref = models.Foo() @@ -95,7 +95,7 @@ Recommended ways to use sessions within this framework: session.add(foo_ref) def update_bar(context, foo_id, newbar): - session = get_session() + session = sessionmaker() with session.begin(): foo_ref = (model_query(context, models.Foo, session). filter_by(id=foo_id). @@ -142,7 +142,7 @@ Recommended ways to use sessions within this framework: foo1 = models.Foo() foo2 = models.Foo() foo1.id = foo2.id = 1 - session = get_session() + session = sessionmaker() try: with session.begin(): session.add(foo1) @@ -168,7 +168,7 @@ Recommended ways to use sessions within this framework: .. code:: python def myfunc(foo): - session = get_session() + session = sessionmaker() with session.begin(): # do some database things bar = _private_func(foo, session) @@ -176,7 +176,7 @@ Recommended ways to use sessions within this framework: def _private_func(foo, session=None): if not session: - session = get_session() + session = sessionmaker() with session.begin(subtransaction=True): # do some other database things return bar @@ -240,7 +240,7 @@ Efficient use of soft deletes: def complex_soft_delete_with_synchronization_bar(session=None): if session is None: - session = get_session() + session = sessionmaker() with session.begin(subtransactions=True): count = (model_query(BarModel). find(some_condition). @@ -257,7 +257,7 @@ Efficient use of soft deletes: .. code:: python def soft_delete_bar_model(): - session = get_session() + session = sessionmaker() with session.begin(): bar_ref = model_query(BarModel).find(some_condition).first() # Work with bar_ref @@ -269,7 +269,7 @@ Efficient use of soft deletes: .. code:: python def soft_delete_multi_models(): - session = get_session() + session = sessionmaker() with session.begin(): query = (model_query(BarModel, session=session). find(some_condition)) @@ -408,11 +408,6 @@ CONF.register_opts(database_opts, 'database') LOG = logging.getLogger(__name__) -_ENGINE = None -_MAKER = None -_SLAVE_ENGINE = None -_SLAVE_MAKER = None - def set_defaults(sql_connection, sqlite_db, max_pool_size=None, max_overflow=None, pool_timeout=None): @@ -433,24 +428,6 @@ def set_defaults(sql_connection, sqlite_db, max_pool_size=None, pool_timeout=pool_timeout) -def cleanup(): - global _ENGINE, _MAKER - global _SLAVE_ENGINE, _SLAVE_MAKER - - if _MAKER: - _MAKER.close_all() - _MAKER = None - if _ENGINE: - _ENGINE.dispose() - _ENGINE = None - if _SLAVE_MAKER: - _SLAVE_MAKER.close_all() - _SLAVE_MAKER = None - if _SLAVE_ENGINE: - _SLAVE_ENGINE.dispose() - _SLAVE_ENGINE = None - - class SqliteForeignKeysListener(PoolListener): """Ensures that the foreign key constraints are enforced in SQLite. @@ -462,30 +439,6 @@ class SqliteForeignKeysListener(PoolListener): dbapi_con.execute('pragma foreign_keys=ON') -def get_session(autocommit=True, expire_on_commit=False, sqlite_fk=False, - slave_session=False, mysql_traditional_mode=False): - """Return a SQLAlchemy session.""" - global _MAKER - global _SLAVE_MAKER - maker = _MAKER - - if slave_session: - maker = _SLAVE_MAKER - - if maker is None: - engine = get_engine(sqlite_fk=sqlite_fk, slave_engine=slave_session, - mysql_traditional_mode=mysql_traditional_mode) - maker = get_maker(engine, autocommit, expire_on_commit) - - if slave_session: - _SLAVE_MAKER = maker - else: - _MAKER = maker - - session = maker() - return session - - # note(boris-42): In current versions of DB backends unique constraint # violation messages follow the structure: # @@ -591,15 +544,19 @@ def _raise_if_deadlock_error(operational_error, engine_name): def _wrap_db_error(f): + #TODO(rpodolyaka): in a subsequent commit make this a class decorator to + # ensure it can only applied to Session subclasses instances (as we use + # Session instance bind attribute below) + @functools.wraps(f) - def _wrap(*args, **kwargs): + def _wrap(self, *args, **kwargs): try: - return f(*args, **kwargs) + return f(self, *args, **kwargs) except UnicodeEncodeError: raise exception.DBInvalidUnicodeParameter() except sqla_exc.OperationalError as e: - _raise_if_db_connection_lost(e, get_engine()) - _raise_if_deadlock_error(e, get_engine().name) + _raise_if_db_connection_lost(e, self.bind) + _raise_if_deadlock_error(e, self.bind.dialect.name) # NOTE(comstud): A lot of code is checking for OperationalError # so let's not wrap it for now. raise @@ -612,7 +569,7 @@ def _wrap_db_error(f): # instance_types) there are more than one unique constraint. This # means we should get names of columns, which values violate # unique constraint, from error message. - _raise_if_duplicate_entry_error(e, get_engine().name) + _raise_if_duplicate_entry_error(e, self.bind.dialect.name) raise exception.DBError(e) except Exception as e: LOG.exception(_('DB exception wrapped.')) @@ -620,29 +577,6 @@ def _wrap_db_error(f): return _wrap -def get_engine(sqlite_fk=False, slave_engine=False, - mysql_traditional_mode=False): - """Return a SQLAlchemy engine.""" - global _ENGINE - global _SLAVE_ENGINE - engine = _ENGINE - db_uri = CONF.database.connection - - if slave_engine: - engine = _SLAVE_ENGINE - db_uri = CONF.database.slave_connection - - if engine is None: - engine = create_engine(db_uri, sqlite_fk=sqlite_fk, - mysql_traditional_mode=mysql_traditional_mode) - if slave_engine: - _SLAVE_ENGINE = engine - else: - _ENGINE = engine - - return engine - - def _synchronous_switch_listener(dbapi_conn, connection_rec): """Switch sqlite connections to non-synchronous mode.""" dbapi_conn.execute("PRAGMA synchronous = OFF") @@ -902,3 +836,87 @@ def _assert_matching_drivers(): normal = sqlalchemy.engine.url.make_url(CONF.database.connection) slave = sqlalchemy.engine.url.make_url(CONF.database.slave_connection) assert normal.drivername == slave.drivername + + +class EngineFacade(object): + """A helper class for removing of global engine instances from oslo.db. + + As a library, oslo.db can't decide where to store/when to create engine + and sessionmaker instances, so this must be left for a target application. + + On the other hand, in order to simplify the adoption of oslo.db changes, + we'll provide a helper class, which creates engine and sessionmaker + on its instantiation and provides get_engine()/get_session() methods + that are compatible with corresponding utility functions that currently + exist in target projects, e.g. in Nova. + + engine/sessionmaker instances will still be global (and they are meant to + be global), but they will be stored in the app context, rather that in the + oslo.db context. + + Note: using of this helper is completely optional and you are encouraged to + integrate engine/sessionmaker instances into your apps any way you like + (e.g. one might want to bind a session to a request context). Two important + things to remember: + 1. An Engine instance is effectively a pool of DB connections, so it's + meant to be shared (and it's thread-safe). + 2. A Session instance is not meant to be shared and represents a DB + transactional context (i.e. it's not thread-safe). sessionmaker is + a factory of sessions. + + """ + + def __init__(self, sql_connection, + sqlite_fk=False, mysql_traditional_mode=False, + autocommit=True, expire_on_commit=False): + """Initialize engine and sessionmaker instances. + + :param sqlite_fk: enable foreign keys in SQLite + :type sqlite_fk: bool + + :param mysql_traditional_mode: enable traditional mode in MySQL + :type mysql_traditional_mode: bool + + :param autocommit: use autocommit mode for created Session instances + :type autocommit: bool + + :param expire_on_commit: expire session objects on commit + :type expire_on_commit: bool + + """ + + super(EngineFacade, self).__init__() + + self._engine = create_engine( + sql_connection=sql_connection, + sqlite_fk=sqlite_fk, + mysql_traditional_mode=mysql_traditional_mode) + self._session_maker = get_maker( + engine=self._engine, + autocommit=autocommit, + expire_on_commit=expire_on_commit) + + def get_engine(self): + """Get the engine instance (note, that it's shared).""" + + return self._engine + + def get_session(self, **kwargs): + """Get a Session instance. + + If passed, keyword arguments values override the ones used when the + sessionmaker instance was created. + + :keyword autocommit: use autocommit mode for created Session instances + :type autocommit: bool + + :keyword expire_on_commit: expire session objects on commit + :type expire_on_commit: bool + + """ + + for arg in kwargs: + if arg not in ('autocommit', 'expire_on_commit'): + del kwargs[arg] + + return self._session_maker(**kwargs) diff --git a/openstack/common/db/sqlalchemy/test_base.py b/openstack/common/db/sqlalchemy/test_base.py index 62a773218..81ac59e61 100644 --- a/openstack/common/db/sqlalchemy/test_base.py +++ b/openstack/common/db/sqlalchemy/test_base.py @@ -18,7 +18,6 @@ import functools import os import fixtures -from oslo.config import cfg import six from openstack.common.db.sqlalchemy import session @@ -38,18 +37,17 @@ class DbFixture(fixtures.Fixture): def _get_uri(self): return os.getenv('OS_TEST_DBAPI_CONNECTION', 'sqlite://') - def __init__(self): + def __init__(self, test): super(DbFixture, self).__init__() - self.conf = cfg.CONF - self.conf.import_opt('connection', - 'openstack.common.db.sqlalchemy.session', - group='database') + + self.test = test def setUp(self): super(DbFixture, self).setUp() - self.conf.set_default('connection', self._get_uri(), group='database') - self.addCleanup(self.conf.reset) + self.test.engine = session.create_engine(self._get_uri()) + self.test.sessionmaker = session.get_maker(self.test.engine) + self.addCleanup(self.test.engine.dispose) class DbTestCase(test.BaseTestCase): @@ -64,9 +62,7 @@ class DbTestCase(test.BaseTestCase): def setUp(self): super(DbTestCase, self).setUp() - self.useFixture(self.FIXTURE()) - - self.addCleanup(session.cleanup) + self.useFixture(self.FIXTURE(self)) ALLOWED_DIALECTS = ['sqlite', 'mysql', 'postgresql'] @@ -83,11 +79,10 @@ def backend_specific(*dialects): if not set(dialects).issubset(ALLOWED_DIALECTS): raise ValueError( "Please use allowed dialects: %s" % ALLOWED_DIALECTS) - engine = session.get_engine() - if engine.name not in dialects: + if self.engine.name not in dialects: msg = ('The test "%s" can be run ' 'only on %s. Current engine is %s.') - args = (f.__name__, ' '.join(dialects), engine.name) + args = (f.__name__, ' '.join(dialects), self.engine.name) self.skip(msg % args) else: return f(self) diff --git a/tests/unit/db/sqlalchemy/test_migrate.py b/tests/unit/db/sqlalchemy/test_migrate.py index 8ba5ddb83..a6e97008e 100644 --- a/tests/unit/db/sqlalchemy/test_migrate.py +++ b/tests/unit/db/sqlalchemy/test_migrate.py @@ -18,7 +18,6 @@ from migrate.changeset.databases import sqlite import sqlalchemy as sa from openstack.common.db.sqlalchemy import migration -from openstack.common.db.sqlalchemy import session from openstack.common.db.sqlalchemy import test_base @@ -44,7 +43,7 @@ class TestSqliteUniqueConstraints(test_base.DbTestCase): test_table = sa.Table( 'test_table', - sa.schema.MetaData(bind=session.get_engine()), + sa.schema.MetaData(bind=self.engine), sa.Column('a', sa.Integer), sa.Column('b', sa.String(10)), sa.Column('c', sa.Integer), @@ -58,7 +57,7 @@ class TestSqliteUniqueConstraints(test_base.DbTestCase): # we actually do in db migrations code self.reflected_table = sa.Table( 'test_table', - sa.schema.MetaData(bind=session.get_engine()), + sa.schema.MetaData(bind=self.engine), autoload=True ) diff --git a/tests/unit/db/sqlalchemy/test_migrate_cli.py b/tests/unit/db/sqlalchemy/test_migrate_cli.py index 5e0d5efd5..6911652c0 100644 --- a/tests/unit/db/sqlalchemy/test_migrate_cli.py +++ b/tests/unit/db/sqlalchemy/test_migrate_cli.py @@ -31,7 +31,8 @@ class MockWithCmp(mock.MagicMock): class TestAlembicExtension(test.BaseTestCase): def setUp(self): - self.migration_config = {'alembic_ini_path': '.'} + self.migration_config = {'alembic_ini_path': '.', + 'db_url': 'sqlite://'} self.alembic = ext_alembic.AlembicExtension(self.migration_config) super(TestAlembicExtension, self).setUp() @@ -90,7 +91,8 @@ class TestAlembicExtension(test.BaseTestCase): class TestMigrateExtension(test.BaseTestCase): def setUp(self): - self.migration_config = {'migration_repo_path': '.'} + self.migration_config = {'migration_repo_path': '.', + 'db_url': 'sqlite://'} self.migrate = ext_migrate.MigrateExtension(self.migration_config) super(TestMigrateExtension, self).setUp() @@ -105,40 +107,41 @@ class TestMigrateExtension(test.BaseTestCase): def test_upgrade_head(self, migration): self.migrate.upgrade('head') migration.db_sync.assert_called_once_with( - self.migrate.repository, None, init_version=0) + self.migrate.engine, self.migrate.repository, None, init_version=0) def test_upgrade_normal(self, migration): self.migrate.upgrade(111) migration.db_sync.assert_called_once_with( - self.migrate.repository, 111, init_version=0) + mock.ANY, self.migrate.repository, 111, init_version=0) def test_downgrade_init_version_from_base(self, migration): self.migrate.downgrade('base') migration.db_sync.assert_called_once_with( - self.migrate.repository, mock.ANY, + self.migrate.engine, self.migrate.repository, mock.ANY, init_version=mock.ANY) def test_downgrade_init_version_from_none(self, migration): self.migrate.downgrade(None) migration.db_sync.assert_called_once_with( - self.migrate.repository, mock.ANY, + self.migrate.engine, self.migrate.repository, mock.ANY, init_version=mock.ANY) def test_downgrade_normal(self, migration): self.migrate.downgrade(101) migration.db_sync.assert_called_once_with( - self.migrate.repository, 101, init_version=0) + self.migrate.engine, self.migrate.repository, 101, init_version=0) def test_version(self, migration): self.migrate.version() migration.db_version.assert_called_once_with( - self.migrate.repository, init_version=0) + self.migrate.engine, self.migrate.repository, init_version=0) def test_change_init_version(self, migration): self.migration_config['init_version'] = 101 migrate = ext_migrate.MigrateExtension(self.migration_config) migrate.downgrade(None) migration.db_sync.assert_called_once_with( + migrate.engine, self.migrate.repository, self.migration_config['init_version'], init_version=self.migration_config['init_version']) @@ -148,7 +151,8 @@ class TestMigrationManager(test.BaseTestCase): def setUp(self): self.migration_config = {'alembic_ini_path': '.', - 'migrate_repo_path': '.'} + 'migrate_repo_path': '.', + 'db_url': 'sqlite://'} self.migration_manager = manager.MigrationManager( self.migration_config) self.ext = mock.Mock() @@ -188,7 +192,8 @@ class TestMigrationRightOrder(test.BaseTestCase): def setUp(self): self.migration_config = {'alembic_ini_path': '.', - 'migrate_repo_path': '.'} + 'migrate_repo_path': '.', + 'db_url': 'sqlite://'} self.migration_manager = manager.MigrationManager( self.migration_config) self.first_ext = MockWithCmp() diff --git a/tests/unit/db/sqlalchemy/test_migration_common.py b/tests/unit/db/sqlalchemy/test_migration_common.py index c0648af15..3af0ec8e5 100644 --- a/tests/unit/db/sqlalchemy/test_migration_common.py +++ b/tests/unit/db/sqlalchemy/test_migration_common.py @@ -25,7 +25,6 @@ import sqlalchemy from openstack.common.db import exception as db_exception from openstack.common.db.sqlalchemy import migration -from openstack.common.db.sqlalchemy import session as db_session from openstack.common.db.sqlalchemy import test_base @@ -81,14 +80,15 @@ class TestMigrationCommon(test_base.DbTestCase): mock_find_repo.return_value = self.return_value version = migration.db_version_control( - self.path, self.test_version) + self.engine, self.path, self.test_version) self.assertEqual(version, self.test_version) mock_version_control.assert_called_once_with( - db_session.get_engine(), self.return_value, self.test_version) + self.engine, self.return_value, self.test_version) def test_db_version_return(self): - ret_val = migration.db_version(self.path, self.init_version) + ret_val = migration.db_version(self.engine, self.path, + self.init_version) self.assertEqual(ret_val, self.test_version) def test_db_version_raise_not_controlled_error_first(self): @@ -98,7 +98,8 @@ class TestMigrationCommon(test_base.DbTestCase): migrate_exception.DatabaseNotControlledError('oups'), self.test_version] - ret_val = migration.db_version(self.path, self.init_version) + ret_val = migration.db_version(self.engine, self.path, + self.init_version) self.assertEqual(ret_val, self.test_version) mock_ver.assert_called_once_with(self.path, self.init_version) @@ -112,7 +113,7 @@ class TestMigrationCommon(test_base.DbTestCase): self.assertRaises( db_exception.DbMigrationError, migration.db_version, - self.path, self.init_version) + self.engine, self.path, self.init_version) def test_db_sync_wrong_version(self): self.assertRaises( @@ -128,10 +129,11 @@ class TestMigrationCommon(test_base.DbTestCase): mock_find_repo.return_value = self.return_value self.mock_api_db_version.return_value = self.test_version - 1 - migration.db_sync(self.path, self.test_version, init_ver) + migration.db_sync(self.engine, self.path, self.test_version, + init_ver) mock_upgrade.assert_called_once_with( - db_session.get_engine(), self.return_value, self.test_version) + self.engine, self.return_value, self.test_version) def test_db_sync_downgrade(self): with contextlib.nested( @@ -142,10 +144,10 @@ class TestMigrationCommon(test_base.DbTestCase): mock_find_repo.return_value = self.return_value self.mock_api_db_version.return_value = self.test_version + 1 - migration.db_sync(self.path, self.test_version) + migration.db_sync(self.engine, self.path, self.test_version) mock_downgrade.assert_called_once_with( - db_session.get_engine(), self.return_value, self.test_version) + self.engine, self.return_value, self.test_version) def test_db_sync_sanity_called(self): with contextlib.nested( @@ -155,15 +157,15 @@ class TestMigrationCommon(test_base.DbTestCase): ) as (mock_find_repo, mock_sanity, mock_downgrade): mock_find_repo.return_value = self.return_value - migration.db_sync(self.path, self.test_version) + migration.db_sync(self.engine, self.path, self.test_version) mock_sanity.assert_called_once() def test_db_sanity_table_not_utf8(self): - with mock.patch.object(migration, 'get_engine') as mock_get_eng: - mock_eng = mock_get_eng.return_value + with mock.patch.object(self, 'engine') as mock_eng: type(mock_eng).name = mock.PropertyMock(return_value='mysql') mock_eng.execute.return_value = [['table_A', 'latin1'], ['table_B', 'latin1']] - self.assertRaises(ValueError, migration._db_schema_sanity_check) + self.assertRaises(ValueError, migration._db_schema_sanity_check, + mock_eng) diff --git a/tests/unit/db/sqlalchemy/test_sqlalchemy.py b/tests/unit/db/sqlalchemy/test_sqlalchemy.py index 4ed5146b0..530e6513a 100644 --- a/tests/unit/db/sqlalchemy/test_sqlalchemy.py +++ b/tests/unit/db/sqlalchemy/test_sqlalchemy.py @@ -129,7 +129,7 @@ class SessionErrorWrapperTestCase(test_base.DbTestCase): def setUp(self): super(SessionErrorWrapperTestCase, self).setUp() meta = MetaData() - meta.bind = session.get_engine() + meta.bind = self.engine test_table = Table(_TABLE_NAME, meta, Column('id', Integer, primary_key=True, nullable=False), @@ -143,16 +143,18 @@ class SessionErrorWrapperTestCase(test_base.DbTestCase): self.addCleanup(test_table.drop) def test_flush_wrapper(self): + _session = self.sessionmaker() + tbl = TmpTable() tbl.update({'foo': 10}) - tbl.save() + tbl.save(_session) tbl2 = TmpTable() tbl2.update({'foo': 10}) - self.assertRaises(db_exc.DBDuplicateEntry, tbl2.save) + self.assertRaises(db_exc.DBDuplicateEntry, tbl2.save, _session) def test_execute_wrapper(self): - _session = session.get_session() + _session = self.sessionmaker() with _session.begin(): for i in [10, 20]: tbl = TmpTable() @@ -180,7 +182,7 @@ class RegexpFilterTestCase(test_base.DbTestCase): def setUp(self): super(RegexpFilterTestCase, self).setUp() meta = MetaData() - meta.bind = session.get_engine() + meta.bind = self.engine test_table = Table(_REGEXP_TABLE_NAME, meta, Column('id', Integer, primary_key=True, nullable=False), @@ -189,7 +191,7 @@ class RegexpFilterTestCase(test_base.DbTestCase): self.addCleanup(test_table.drop) def _test_regexp_filter(self, regexp, expected): - _session = session.get_session() + _session = self.sessionmaker() with _session.begin(): for i in ['10', '20', u'♥']: tbl = RegexpTable() @@ -213,26 +215,6 @@ class RegexpFilterTestCase(test_base.DbTestCase): self._test_regexp_filter(u'♦', []) -class SlaveBackendTestCase(test.BaseTestCase): - - def test_slave_engine_nomatch(self): - default = session.CONF.database.connection - session.CONF.database.slave_connection = default - - e = session.get_engine() - slave_e = session.get_engine(slave_engine=True) - self.assertNotEqual(slave_e, e) - - def test_no_slave_engine_match(self): - slave_e = session.get_engine() - e = session.get_engine() - self.assertEqual(slave_e, e) - - def test_slave_backend_nomatch(self): - session.CONF.database.slave_connection = "mysql:///localhost" - self.assertRaises(AssertionError, session._assert_matching_drivers) - - class FakeDBAPIConnection(): def cursor(self): return FakeCursor() @@ -323,7 +305,8 @@ class MySQLTraditionalModeTestCase(test_base.MySQLOpportunisticTestCase): def setUp(self): super(MySQLTraditionalModeTestCase, self).setUp() - self.engine = session.get_engine(mysql_traditional_mode=True) + self.engine = session.create_engine(self.engine.url, + mysql_traditional_mode=True) self.connection = self.engine.connect() meta = MetaData() @@ -333,7 +316,6 @@ class MySQLTraditionalModeTestCase(test_base.MySQLOpportunisticTestCase): Column('bar', String(255))) self.test_table.create() - self.addCleanup(session.cleanup) self.addCleanup(self.test_table.drop) self.addCleanup(self.connection.close) @@ -341,3 +323,28 @@ class MySQLTraditionalModeTestCase(test_base.MySQLOpportunisticTestCase): with self.connection.begin(): self.assertRaises(DataError, self.connection.execute, self.test_table.insert(), bar='a' * 512) + + +class EngineFacadeTestCase(test.BaseTestCase): + def setUp(self): + super(EngineFacadeTestCase, self).setUp() + + self.facade = session.EngineFacade('sqlite://') + + def test_get_engine(self): + eng1 = self.facade.get_engine() + eng2 = self.facade.get_engine() + + self.assertIs(eng1, eng2) + + def test_get_session(self): + ses1 = self.facade.get_session() + ses2 = self.facade.get_session() + + self.assertIsNot(ses1, ses2) + + def test_get_session_arguments_override_default_settings(self): + ses = self.facade.get_session(autocommit=False, expire_on_commit=True) + + self.assertFalse(ses.autocommit) + self.assertTrue(ses.expire_on_commit)