diff --git a/openstack/common/db/sqlalchemy/migration.py b/openstack/common/db/sqlalchemy/migration.py index bf9d150b7..04a6514fa 100644 --- a/openstack/common/db/sqlalchemy/migration.py +++ b/openstack/common/db/sqlalchemy/migration.py @@ -51,13 +51,9 @@ import sqlalchemy from sqlalchemy.schema import UniqueConstraint from openstack.common.db import exception -from openstack.common.db.sqlalchemy import session as db_session from openstack.common.gettextutils import _ -get_engine = db_session.get_engine - - def _get_unique_constraints(self, table): """Retrieve information about existing unique constraints of the table @@ -172,11 +168,12 @@ def patch_migrate(): sqlite.SQLiteConstraintGenerator) -def db_sync(abs_path, version=None, init_version=0): +def db_sync(engine, abs_path, version=None, init_version=0): """Upgrade or downgrade a database. Function runs the upgrade() or downgrade() functions in change scripts. + :param engine: SQLAlchemy engine instance for a given database :param abs_path: Absolute path to migrate repository. :param version: Database will upgrade/downgrade until this version. If None - database will update to the latest @@ -190,18 +187,23 @@ def db_sync(abs_path, version=None, init_version=0): raise exception.DbMigrationError( message=_("version should be an integer")) - current_version = db_version(abs_path, init_version) + current_version = db_version(engine, abs_path, init_version) repository = _find_migrate_repo(abs_path) - _db_schema_sanity_check() + _db_schema_sanity_check(engine) if version is None or version > current_version: - return versioning_api.upgrade(get_engine(), repository, version) + return versioning_api.upgrade(engine, repository, version) else: - return versioning_api.downgrade(get_engine(), repository, + return versioning_api.downgrade(engine, repository, version) -def _db_schema_sanity_check(): - engine = get_engine() +def _db_schema_sanity_check(engine): + """Ensure all database tables were created with required parameters. + + :param engine: SQLAlchemy engine instance for a given database + + """ + if engine.name == 'mysql': onlyutf8_sql = ('SELECT TABLE_NAME,TABLE_COLLATION ' 'from information_schema.TABLES ' @@ -216,23 +218,23 @@ def _db_schema_sanity_check(): ) % ','.join(table_names)) -def db_version(abs_path, init_version): +def db_version(engine, abs_path, init_version): """Show the current version of the repository. + :param engine: SQLAlchemy engine instance for a given database :param abs_path: Absolute path to migrate repository :param version: Initial database version """ repository = _find_migrate_repo(abs_path) try: - return versioning_api.db_version(get_engine(), repository) + return versioning_api.db_version(engine, repository) except versioning_exceptions.DatabaseNotControlledError: meta = sqlalchemy.MetaData() - engine = get_engine() meta.reflect(bind=engine) tables = meta.tables if len(tables) == 0 or 'alembic_version' in tables: db_version_control(abs_path, init_version) - return versioning_api.db_version(get_engine(), repository) + return versioning_api.db_version(engine, repository) else: raise exception.DbMigrationError( message=_( @@ -241,17 +243,18 @@ def db_version(abs_path, init_version): "manually.")) -def db_version_control(abs_path, version=None): +def db_version_control(engine, abs_path, version=None): """Mark a database as under this repository's version control. Once a database is under version control, schema changes should only be done via change scripts in this repository. + :param engine: SQLAlchemy engine instance for a given database :param abs_path: Absolute path to migrate repository :param version: Initial database version """ repository = _find_migrate_repo(abs_path) - versioning_api.version_control(get_engine(), repository, version) + versioning_api.version_control(engine, repository, version) return version diff --git a/openstack/common/db/sqlalchemy/migration_cli/ext_alembic.py b/openstack/common/db/sqlalchemy/migration_cli/ext_alembic.py index 59add6491..03b8a9897 100644 --- a/openstack/common/db/sqlalchemy/migration_cli/ext_alembic.py +++ b/openstack/common/db/sqlalchemy/migration_cli/ext_alembic.py @@ -40,6 +40,7 @@ class AlembicExtension(ext_base.MigrationExtensionBase): repo_path = migration_config.get('alembic_repo_path') if repo_path: self.config.set_main_option('script_location', repo_path) + self.db_url = migration_config['db_url'] def upgrade(self, version): return alembic.command.upgrade(self.config, version or 'head') @@ -50,7 +51,7 @@ class AlembicExtension(ext_base.MigrationExtensionBase): return alembic.command.downgrade(self.config, version) def version(self): - engine = db_session.get_engine() + engine = db_session.create_engine(self.db_url) with engine.connect() as conn: context = alembic_migration.MigrationContext.configure(conn) return context.get_current_revision() diff --git a/openstack/common/db/sqlalchemy/migration_cli/ext_migrate.py b/openstack/common/db/sqlalchemy/migration_cli/ext_migrate.py index fc7e27bcd..d5ae385c4 100644 --- a/openstack/common/db/sqlalchemy/migration_cli/ext_migrate.py +++ b/openstack/common/db/sqlalchemy/migration_cli/ext_migrate.py @@ -15,6 +15,7 @@ import os from openstack.common.db.sqlalchemy import migration from openstack.common.db.sqlalchemy.migration_cli import ext_base +from openstack.common.db.sqlalchemy import session as db_session from openstack.common.gettextutils import _ # noqa @@ -33,6 +34,8 @@ class MigrateExtension(ext_base.MigrationExtensionBase): def __init__(self, migration_config): self.repository = migration_config.get('migration_repo_path', '') self.init_version = migration_config.get('init_version', 0) + self.db_url = migration_config['db_url'] + self.engine = db_session.create_engine(self.db_url) @property def enabled(self): @@ -41,7 +44,7 @@ class MigrateExtension(ext_base.MigrationExtensionBase): def upgrade(self, version): version = None if version == 'head' else version return migration.db_sync( - self.repository, version, + self.engine, self.repository, version, init_version=self.init_version) def downgrade(self, version): @@ -51,7 +54,7 @@ class MigrateExtension(ext_base.MigrationExtensionBase): version = self.init_version version = int(version) return migration.db_sync( - self.repository, version, + self.engine, self.repository, version, init_version=self.init_version) except ValueError: LOG.error( @@ -63,4 +66,4 @@ class MigrateExtension(ext_base.MigrationExtensionBase): def version(self): return migration.db_version( - self.repository, init_version=self.init_version) + self.engine, self.repository, init_version=self.init_version) diff --git a/openstack/common/db/sqlalchemy/models.py b/openstack/common/db/sqlalchemy/models.py index 9d17de7ee..02712f9ae 100644 --- a/openstack/common/db/sqlalchemy/models.py +++ b/openstack/common/db/sqlalchemy/models.py @@ -26,7 +26,6 @@ from sqlalchemy import Column, Integer from sqlalchemy import DateTime from sqlalchemy.orm import object_mapper -from openstack.common.db.sqlalchemy import session as sa from openstack.common import timeutils @@ -34,10 +33,9 @@ class ModelBase(object): """Base class for models.""" __table_initialized__ = False - def save(self, session=None): + def save(self, session): """Save this object.""" - if not session: - session = sa.get_session() + # NOTE(boris-42): This part of code should be look like: # session.add(self) # session.flush() @@ -110,7 +108,7 @@ class SoftDeleteMixin(object): deleted_at = Column(DateTime) deleted = Column(Integer, default=0) - def soft_delete(self, session=None): + def soft_delete(self, session): """Mark this object as deleted.""" self.deleted = self.id self.deleted_at = timeutils.utcnow() diff --git a/openstack/common/db/sqlalchemy/session.py b/openstack/common/db/sqlalchemy/session.py index a72969d3f..9447c5578 100644 --- a/openstack/common/db/sqlalchemy/session.py +++ b/openstack/common/db/sqlalchemy/session.py @@ -87,7 +87,7 @@ Recommended ways to use sessions within this framework: .. code:: python def create_many_foo(context, foos): - session = get_session() + session = sessionmaker() with session.begin(): for foo in foos: foo_ref = models.Foo() @@ -95,7 +95,7 @@ Recommended ways to use sessions within this framework: session.add(foo_ref) def update_bar(context, foo_id, newbar): - session = get_session() + session = sessionmaker() with session.begin(): foo_ref = (model_query(context, models.Foo, session). filter_by(id=foo_id). @@ -142,7 +142,7 @@ Recommended ways to use sessions within this framework: foo1 = models.Foo() foo2 = models.Foo() foo1.id = foo2.id = 1 - session = get_session() + session = sessionmaker() try: with session.begin(): session.add(foo1) @@ -168,7 +168,7 @@ Recommended ways to use sessions within this framework: .. code:: python def myfunc(foo): - session = get_session() + session = sessionmaker() with session.begin(): # do some database things bar = _private_func(foo, session) @@ -176,7 +176,7 @@ Recommended ways to use sessions within this framework: def _private_func(foo, session=None): if not session: - session = get_session() + session = sessionmaker() with session.begin(subtransaction=True): # do some other database things return bar @@ -240,7 +240,7 @@ Efficient use of soft deletes: def complex_soft_delete_with_synchronization_bar(session=None): if session is None: - session = get_session() + session = sessionmaker() with session.begin(subtransactions=True): count = (model_query(BarModel). find(some_condition). @@ -257,7 +257,7 @@ Efficient use of soft deletes: .. code:: python def soft_delete_bar_model(): - session = get_session() + session = sessionmaker() with session.begin(): bar_ref = model_query(BarModel).find(some_condition).first() # Work with bar_ref @@ -269,7 +269,7 @@ Efficient use of soft deletes: .. code:: python def soft_delete_multi_models(): - session = get_session() + session = sessionmaker() with session.begin(): query = (model_query(BarModel, session=session). find(some_condition)) @@ -408,11 +408,6 @@ CONF.register_opts(database_opts, 'database') LOG = logging.getLogger(__name__) -_ENGINE = None -_MAKER = None -_SLAVE_ENGINE = None -_SLAVE_MAKER = None - def set_defaults(sql_connection, sqlite_db, max_pool_size=None, max_overflow=None, pool_timeout=None): @@ -433,24 +428,6 @@ def set_defaults(sql_connection, sqlite_db, max_pool_size=None, pool_timeout=pool_timeout) -def cleanup(): - global _ENGINE, _MAKER - global _SLAVE_ENGINE, _SLAVE_MAKER - - if _MAKER: - _MAKER.close_all() - _MAKER = None - if _ENGINE: - _ENGINE.dispose() - _ENGINE = None - if _SLAVE_MAKER: - _SLAVE_MAKER.close_all() - _SLAVE_MAKER = None - if _SLAVE_ENGINE: - _SLAVE_ENGINE.dispose() - _SLAVE_ENGINE = None - - class SqliteForeignKeysListener(PoolListener): """Ensures that the foreign key constraints are enforced in SQLite. @@ -462,30 +439,6 @@ class SqliteForeignKeysListener(PoolListener): dbapi_con.execute('pragma foreign_keys=ON') -def get_session(autocommit=True, expire_on_commit=False, sqlite_fk=False, - slave_session=False, mysql_traditional_mode=False): - """Return a SQLAlchemy session.""" - global _MAKER - global _SLAVE_MAKER - maker = _MAKER - - if slave_session: - maker = _SLAVE_MAKER - - if maker is None: - engine = get_engine(sqlite_fk=sqlite_fk, slave_engine=slave_session, - mysql_traditional_mode=mysql_traditional_mode) - maker = get_maker(engine, autocommit, expire_on_commit) - - if slave_session: - _SLAVE_MAKER = maker - else: - _MAKER = maker - - session = maker() - return session - - # note(boris-42): In current versions of DB backends unique constraint # violation messages follow the structure: # @@ -591,15 +544,19 @@ def _raise_if_deadlock_error(operational_error, engine_name): def _wrap_db_error(f): + #TODO(rpodolyaka): in a subsequent commit make this a class decorator to + # ensure it can only applied to Session subclasses instances (as we use + # Session instance bind attribute below) + @functools.wraps(f) - def _wrap(*args, **kwargs): + def _wrap(self, *args, **kwargs): try: - return f(*args, **kwargs) + return f(self, *args, **kwargs) except UnicodeEncodeError: raise exception.DBInvalidUnicodeParameter() except sqla_exc.OperationalError as e: - _raise_if_db_connection_lost(e, get_engine()) - _raise_if_deadlock_error(e, get_engine().name) + _raise_if_db_connection_lost(e, self.bind) + _raise_if_deadlock_error(e, self.bind.dialect.name) # NOTE(comstud): A lot of code is checking for OperationalError # so let's not wrap it for now. raise @@ -612,7 +569,7 @@ def _wrap_db_error(f): # instance_types) there are more than one unique constraint. This # means we should get names of columns, which values violate # unique constraint, from error message. - _raise_if_duplicate_entry_error(e, get_engine().name) + _raise_if_duplicate_entry_error(e, self.bind.dialect.name) raise exception.DBError(e) except Exception as e: LOG.exception(_('DB exception wrapped.')) @@ -620,29 +577,6 @@ def _wrap_db_error(f): return _wrap -def get_engine(sqlite_fk=False, slave_engine=False, - mysql_traditional_mode=False): - """Return a SQLAlchemy engine.""" - global _ENGINE - global _SLAVE_ENGINE - engine = _ENGINE - db_uri = CONF.database.connection - - if slave_engine: - engine = _SLAVE_ENGINE - db_uri = CONF.database.slave_connection - - if engine is None: - engine = create_engine(db_uri, sqlite_fk=sqlite_fk, - mysql_traditional_mode=mysql_traditional_mode) - if slave_engine: - _SLAVE_ENGINE = engine - else: - _ENGINE = engine - - return engine - - def _synchronous_switch_listener(dbapi_conn, connection_rec): """Switch sqlite connections to non-synchronous mode.""" dbapi_conn.execute("PRAGMA synchronous = OFF") @@ -902,3 +836,87 @@ def _assert_matching_drivers(): normal = sqlalchemy.engine.url.make_url(CONF.database.connection) slave = sqlalchemy.engine.url.make_url(CONF.database.slave_connection) assert normal.drivername == slave.drivername + + +class EngineFacade(object): + """A helper class for removing of global engine instances from oslo.db. + + As a library, oslo.db can't decide where to store/when to create engine + and sessionmaker instances, so this must be left for a target application. + + On the other hand, in order to simplify the adoption of oslo.db changes, + we'll provide a helper class, which creates engine and sessionmaker + on its instantiation and provides get_engine()/get_session() methods + that are compatible with corresponding utility functions that currently + exist in target projects, e.g. in Nova. + + engine/sessionmaker instances will still be global (and they are meant to + be global), but they will be stored in the app context, rather that in the + oslo.db context. + + Note: using of this helper is completely optional and you are encouraged to + integrate engine/sessionmaker instances into your apps any way you like + (e.g. one might want to bind a session to a request context). Two important + things to remember: + 1. An Engine instance is effectively a pool of DB connections, so it's + meant to be shared (and it's thread-safe). + 2. A Session instance is not meant to be shared and represents a DB + transactional context (i.e. it's not thread-safe). sessionmaker is + a factory of sessions. + + """ + + def __init__(self, sql_connection, + sqlite_fk=False, mysql_traditional_mode=False, + autocommit=True, expire_on_commit=False): + """Initialize engine and sessionmaker instances. + + :param sqlite_fk: enable foreign keys in SQLite + :type sqlite_fk: bool + + :param mysql_traditional_mode: enable traditional mode in MySQL + :type mysql_traditional_mode: bool + + :param autocommit: use autocommit mode for created Session instances + :type autocommit: bool + + :param expire_on_commit: expire session objects on commit + :type expire_on_commit: bool + + """ + + super(EngineFacade, self).__init__() + + self._engine = create_engine( + sql_connection=sql_connection, + sqlite_fk=sqlite_fk, + mysql_traditional_mode=mysql_traditional_mode) + self._session_maker = get_maker( + engine=self._engine, + autocommit=autocommit, + expire_on_commit=expire_on_commit) + + def get_engine(self): + """Get the engine instance (note, that it's shared).""" + + return self._engine + + def get_session(self, **kwargs): + """Get a Session instance. + + If passed, keyword arguments values override the ones used when the + sessionmaker instance was created. + + :keyword autocommit: use autocommit mode for created Session instances + :type autocommit: bool + + :keyword expire_on_commit: expire session objects on commit + :type expire_on_commit: bool + + """ + + for arg in kwargs: + if arg not in ('autocommit', 'expire_on_commit'): + del kwargs[arg] + + return self._session_maker(**kwargs) diff --git a/openstack/common/db/sqlalchemy/test_base.py b/openstack/common/db/sqlalchemy/test_base.py index 62a773218..81ac59e61 100644 --- a/openstack/common/db/sqlalchemy/test_base.py +++ b/openstack/common/db/sqlalchemy/test_base.py @@ -18,7 +18,6 @@ import functools import os import fixtures -from oslo.config import cfg import six from openstack.common.db.sqlalchemy import session @@ -38,18 +37,17 @@ class DbFixture(fixtures.Fixture): def _get_uri(self): return os.getenv('OS_TEST_DBAPI_CONNECTION', 'sqlite://') - def __init__(self): + def __init__(self, test): super(DbFixture, self).__init__() - self.conf = cfg.CONF - self.conf.import_opt('connection', - 'openstack.common.db.sqlalchemy.session', - group='database') + + self.test = test def setUp(self): super(DbFixture, self).setUp() - self.conf.set_default('connection', self._get_uri(), group='database') - self.addCleanup(self.conf.reset) + self.test.engine = session.create_engine(self._get_uri()) + self.test.sessionmaker = session.get_maker(self.test.engine) + self.addCleanup(self.test.engine.dispose) class DbTestCase(test.BaseTestCase): @@ -64,9 +62,7 @@ class DbTestCase(test.BaseTestCase): def setUp(self): super(DbTestCase, self).setUp() - self.useFixture(self.FIXTURE()) - - self.addCleanup(session.cleanup) + self.useFixture(self.FIXTURE(self)) ALLOWED_DIALECTS = ['sqlite', 'mysql', 'postgresql'] @@ -83,11 +79,10 @@ def backend_specific(*dialects): if not set(dialects).issubset(ALLOWED_DIALECTS): raise ValueError( "Please use allowed dialects: %s" % ALLOWED_DIALECTS) - engine = session.get_engine() - if engine.name not in dialects: + if self.engine.name not in dialects: msg = ('The test "%s" can be run ' 'only on %s. Current engine is %s.') - args = (f.__name__, ' '.join(dialects), engine.name) + args = (f.__name__, ' '.join(dialects), self.engine.name) self.skip(msg % args) else: return f(self) diff --git a/tests/unit/db/sqlalchemy/test_migrate.py b/tests/unit/db/sqlalchemy/test_migrate.py index 8ba5ddb83..a6e97008e 100644 --- a/tests/unit/db/sqlalchemy/test_migrate.py +++ b/tests/unit/db/sqlalchemy/test_migrate.py @@ -18,7 +18,6 @@ from migrate.changeset.databases import sqlite import sqlalchemy as sa from openstack.common.db.sqlalchemy import migration -from openstack.common.db.sqlalchemy import session from openstack.common.db.sqlalchemy import test_base @@ -44,7 +43,7 @@ class TestSqliteUniqueConstraints(test_base.DbTestCase): test_table = sa.Table( 'test_table', - sa.schema.MetaData(bind=session.get_engine()), + sa.schema.MetaData(bind=self.engine), sa.Column('a', sa.Integer), sa.Column('b', sa.String(10)), sa.Column('c', sa.Integer), @@ -58,7 +57,7 @@ class TestSqliteUniqueConstraints(test_base.DbTestCase): # we actually do in db migrations code self.reflected_table = sa.Table( 'test_table', - sa.schema.MetaData(bind=session.get_engine()), + sa.schema.MetaData(bind=self.engine), autoload=True ) diff --git a/tests/unit/db/sqlalchemy/test_migrate_cli.py b/tests/unit/db/sqlalchemy/test_migrate_cli.py index 5e0d5efd5..6911652c0 100644 --- a/tests/unit/db/sqlalchemy/test_migrate_cli.py +++ b/tests/unit/db/sqlalchemy/test_migrate_cli.py @@ -31,7 +31,8 @@ class MockWithCmp(mock.MagicMock): class TestAlembicExtension(test.BaseTestCase): def setUp(self): - self.migration_config = {'alembic_ini_path': '.'} + self.migration_config = {'alembic_ini_path': '.', + 'db_url': 'sqlite://'} self.alembic = ext_alembic.AlembicExtension(self.migration_config) super(TestAlembicExtension, self).setUp() @@ -90,7 +91,8 @@ class TestAlembicExtension(test.BaseTestCase): class TestMigrateExtension(test.BaseTestCase): def setUp(self): - self.migration_config = {'migration_repo_path': '.'} + self.migration_config = {'migration_repo_path': '.', + 'db_url': 'sqlite://'} self.migrate = ext_migrate.MigrateExtension(self.migration_config) super(TestMigrateExtension, self).setUp() @@ -105,40 +107,41 @@ class TestMigrateExtension(test.BaseTestCase): def test_upgrade_head(self, migration): self.migrate.upgrade('head') migration.db_sync.assert_called_once_with( - self.migrate.repository, None, init_version=0) + self.migrate.engine, self.migrate.repository, None, init_version=0) def test_upgrade_normal(self, migration): self.migrate.upgrade(111) migration.db_sync.assert_called_once_with( - self.migrate.repository, 111, init_version=0) + mock.ANY, self.migrate.repository, 111, init_version=0) def test_downgrade_init_version_from_base(self, migration): self.migrate.downgrade('base') migration.db_sync.assert_called_once_with( - self.migrate.repository, mock.ANY, + self.migrate.engine, self.migrate.repository, mock.ANY, init_version=mock.ANY) def test_downgrade_init_version_from_none(self, migration): self.migrate.downgrade(None) migration.db_sync.assert_called_once_with( - self.migrate.repository, mock.ANY, + self.migrate.engine, self.migrate.repository, mock.ANY, init_version=mock.ANY) def test_downgrade_normal(self, migration): self.migrate.downgrade(101) migration.db_sync.assert_called_once_with( - self.migrate.repository, 101, init_version=0) + self.migrate.engine, self.migrate.repository, 101, init_version=0) def test_version(self, migration): self.migrate.version() migration.db_version.assert_called_once_with( - self.migrate.repository, init_version=0) + self.migrate.engine, self.migrate.repository, init_version=0) def test_change_init_version(self, migration): self.migration_config['init_version'] = 101 migrate = ext_migrate.MigrateExtension(self.migration_config) migrate.downgrade(None) migration.db_sync.assert_called_once_with( + migrate.engine, self.migrate.repository, self.migration_config['init_version'], init_version=self.migration_config['init_version']) @@ -148,7 +151,8 @@ class TestMigrationManager(test.BaseTestCase): def setUp(self): self.migration_config = {'alembic_ini_path': '.', - 'migrate_repo_path': '.'} + 'migrate_repo_path': '.', + 'db_url': 'sqlite://'} self.migration_manager = manager.MigrationManager( self.migration_config) self.ext = mock.Mock() @@ -188,7 +192,8 @@ class TestMigrationRightOrder(test.BaseTestCase): def setUp(self): self.migration_config = {'alembic_ini_path': '.', - 'migrate_repo_path': '.'} + 'migrate_repo_path': '.', + 'db_url': 'sqlite://'} self.migration_manager = manager.MigrationManager( self.migration_config) self.first_ext = MockWithCmp() diff --git a/tests/unit/db/sqlalchemy/test_migration_common.py b/tests/unit/db/sqlalchemy/test_migration_common.py index c0648af15..3af0ec8e5 100644 --- a/tests/unit/db/sqlalchemy/test_migration_common.py +++ b/tests/unit/db/sqlalchemy/test_migration_common.py @@ -25,7 +25,6 @@ import sqlalchemy from openstack.common.db import exception as db_exception from openstack.common.db.sqlalchemy import migration -from openstack.common.db.sqlalchemy import session as db_session from openstack.common.db.sqlalchemy import test_base @@ -81,14 +80,15 @@ class TestMigrationCommon(test_base.DbTestCase): mock_find_repo.return_value = self.return_value version = migration.db_version_control( - self.path, self.test_version) + self.engine, self.path, self.test_version) self.assertEqual(version, self.test_version) mock_version_control.assert_called_once_with( - db_session.get_engine(), self.return_value, self.test_version) + self.engine, self.return_value, self.test_version) def test_db_version_return(self): - ret_val = migration.db_version(self.path, self.init_version) + ret_val = migration.db_version(self.engine, self.path, + self.init_version) self.assertEqual(ret_val, self.test_version) def test_db_version_raise_not_controlled_error_first(self): @@ -98,7 +98,8 @@ class TestMigrationCommon(test_base.DbTestCase): migrate_exception.DatabaseNotControlledError('oups'), self.test_version] - ret_val = migration.db_version(self.path, self.init_version) + ret_val = migration.db_version(self.engine, self.path, + self.init_version) self.assertEqual(ret_val, self.test_version) mock_ver.assert_called_once_with(self.path, self.init_version) @@ -112,7 +113,7 @@ class TestMigrationCommon(test_base.DbTestCase): self.assertRaises( db_exception.DbMigrationError, migration.db_version, - self.path, self.init_version) + self.engine, self.path, self.init_version) def test_db_sync_wrong_version(self): self.assertRaises( @@ -128,10 +129,11 @@ class TestMigrationCommon(test_base.DbTestCase): mock_find_repo.return_value = self.return_value self.mock_api_db_version.return_value = self.test_version - 1 - migration.db_sync(self.path, self.test_version, init_ver) + migration.db_sync(self.engine, self.path, self.test_version, + init_ver) mock_upgrade.assert_called_once_with( - db_session.get_engine(), self.return_value, self.test_version) + self.engine, self.return_value, self.test_version) def test_db_sync_downgrade(self): with contextlib.nested( @@ -142,10 +144,10 @@ class TestMigrationCommon(test_base.DbTestCase): mock_find_repo.return_value = self.return_value self.mock_api_db_version.return_value = self.test_version + 1 - migration.db_sync(self.path, self.test_version) + migration.db_sync(self.engine, self.path, self.test_version) mock_downgrade.assert_called_once_with( - db_session.get_engine(), self.return_value, self.test_version) + self.engine, self.return_value, self.test_version) def test_db_sync_sanity_called(self): with contextlib.nested( @@ -155,15 +157,15 @@ class TestMigrationCommon(test_base.DbTestCase): ) as (mock_find_repo, mock_sanity, mock_downgrade): mock_find_repo.return_value = self.return_value - migration.db_sync(self.path, self.test_version) + migration.db_sync(self.engine, self.path, self.test_version) mock_sanity.assert_called_once() def test_db_sanity_table_not_utf8(self): - with mock.patch.object(migration, 'get_engine') as mock_get_eng: - mock_eng = mock_get_eng.return_value + with mock.patch.object(self, 'engine') as mock_eng: type(mock_eng).name = mock.PropertyMock(return_value='mysql') mock_eng.execute.return_value = [['table_A', 'latin1'], ['table_B', 'latin1']] - self.assertRaises(ValueError, migration._db_schema_sanity_check) + self.assertRaises(ValueError, migration._db_schema_sanity_check, + mock_eng) diff --git a/tests/unit/db/sqlalchemy/test_sqlalchemy.py b/tests/unit/db/sqlalchemy/test_sqlalchemy.py index 4ed5146b0..530e6513a 100644 --- a/tests/unit/db/sqlalchemy/test_sqlalchemy.py +++ b/tests/unit/db/sqlalchemy/test_sqlalchemy.py @@ -129,7 +129,7 @@ class SessionErrorWrapperTestCase(test_base.DbTestCase): def setUp(self): super(SessionErrorWrapperTestCase, self).setUp() meta = MetaData() - meta.bind = session.get_engine() + meta.bind = self.engine test_table = Table(_TABLE_NAME, meta, Column('id', Integer, primary_key=True, nullable=False), @@ -143,16 +143,18 @@ class SessionErrorWrapperTestCase(test_base.DbTestCase): self.addCleanup(test_table.drop) def test_flush_wrapper(self): + _session = self.sessionmaker() + tbl = TmpTable() tbl.update({'foo': 10}) - tbl.save() + tbl.save(_session) tbl2 = TmpTable() tbl2.update({'foo': 10}) - self.assertRaises(db_exc.DBDuplicateEntry, tbl2.save) + self.assertRaises(db_exc.DBDuplicateEntry, tbl2.save, _session) def test_execute_wrapper(self): - _session = session.get_session() + _session = self.sessionmaker() with _session.begin(): for i in [10, 20]: tbl = TmpTable() @@ -180,7 +182,7 @@ class RegexpFilterTestCase(test_base.DbTestCase): def setUp(self): super(RegexpFilterTestCase, self).setUp() meta = MetaData() - meta.bind = session.get_engine() + meta.bind = self.engine test_table = Table(_REGEXP_TABLE_NAME, meta, Column('id', Integer, primary_key=True, nullable=False), @@ -189,7 +191,7 @@ class RegexpFilterTestCase(test_base.DbTestCase): self.addCleanup(test_table.drop) def _test_regexp_filter(self, regexp, expected): - _session = session.get_session() + _session = self.sessionmaker() with _session.begin(): for i in ['10', '20', u'♥']: tbl = RegexpTable() @@ -213,26 +215,6 @@ class RegexpFilterTestCase(test_base.DbTestCase): self._test_regexp_filter(u'♦', []) -class SlaveBackendTestCase(test.BaseTestCase): - - def test_slave_engine_nomatch(self): - default = session.CONF.database.connection - session.CONF.database.slave_connection = default - - e = session.get_engine() - slave_e = session.get_engine(slave_engine=True) - self.assertNotEqual(slave_e, e) - - def test_no_slave_engine_match(self): - slave_e = session.get_engine() - e = session.get_engine() - self.assertEqual(slave_e, e) - - def test_slave_backend_nomatch(self): - session.CONF.database.slave_connection = "mysql:///localhost" - self.assertRaises(AssertionError, session._assert_matching_drivers) - - class FakeDBAPIConnection(): def cursor(self): return FakeCursor() @@ -323,7 +305,8 @@ class MySQLTraditionalModeTestCase(test_base.MySQLOpportunisticTestCase): def setUp(self): super(MySQLTraditionalModeTestCase, self).setUp() - self.engine = session.get_engine(mysql_traditional_mode=True) + self.engine = session.create_engine(self.engine.url, + mysql_traditional_mode=True) self.connection = self.engine.connect() meta = MetaData() @@ -333,7 +316,6 @@ class MySQLTraditionalModeTestCase(test_base.MySQLOpportunisticTestCase): Column('bar', String(255))) self.test_table.create() - self.addCleanup(session.cleanup) self.addCleanup(self.test_table.drop) self.addCleanup(self.connection.close) @@ -341,3 +323,28 @@ class MySQLTraditionalModeTestCase(test_base.MySQLOpportunisticTestCase): with self.connection.begin(): self.assertRaises(DataError, self.connection.execute, self.test_table.insert(), bar='a' * 512) + + +class EngineFacadeTestCase(test.BaseTestCase): + def setUp(self): + super(EngineFacadeTestCase, self).setUp() + + self.facade = session.EngineFacade('sqlite://') + + def test_get_engine(self): + eng1 = self.facade.get_engine() + eng2 = self.facade.get_engine() + + self.assertIs(eng1, eng2) + + def test_get_session(self): + ses1 = self.facade.get_session() + ses2 = self.facade.get_session() + + self.assertIsNot(ses1, ses2) + + def test_get_session_arguments_override_default_settings(self): + ses = self.facade.get_session(autocommit=False, expire_on_commit=True) + + self.assertFalse(ses.autocommit) + self.assertTrue(ses.expire_on_commit)