Don't store engine instances in oslo.db

oslo.db is meant to be a collection of helper utilities for
SQLAlchemy. In order to behave like a 'good' library, we must
stop using global state.

This patch ensures we don't store any engine instances globally in
oslo.db. It's up to end applications, not oslo.db, to decide how to
manage engines.

Partial-Bug: #1263908

Co-authored-by: Victor Sergeyev <vsergeyev@mirantis.com>

Change-Id: I330467ec1317e1a13ff9f0237e2d8900d718e379
Author: Roman Podoliaka
Date: 2014-01-23 17:45:50 +02:00
Parent: a01f79c305
Commit: ce69e7f8f6
10 changed files with 209 additions and 178 deletions


@@ -51,13 +51,9 @@ import sqlalchemy
from sqlalchemy.schema import UniqueConstraint
from openstack.common.db import exception
from openstack.common.db.sqlalchemy import session as db_session
from openstack.common.gettextutils import _
get_engine = db_session.get_engine
def _get_unique_constraints(self, table):
"""Retrieve information about existing unique constraints of the table
@@ -172,11 +168,12 @@ def patch_migrate():
sqlite.SQLiteConstraintGenerator)
def db_sync(abs_path, version=None, init_version=0):
def db_sync(engine, abs_path, version=None, init_version=0):
"""Upgrade or downgrade a database.
Function runs the upgrade() or downgrade() functions in change scripts.
:param engine: SQLAlchemy engine instance for a given database
:param abs_path: Absolute path to migrate repository.
:param version: Database will upgrade/downgrade until this version.
If None - database will update to the latest
@@ -190,18 +187,23 @@ def db_sync(abs_path, version=None, init_version=0):
raise exception.DbMigrationError(
message=_("version should be an integer"))
current_version = db_version(abs_path, init_version)
current_version = db_version(engine, abs_path, init_version)
repository = _find_migrate_repo(abs_path)
_db_schema_sanity_check()
_db_schema_sanity_check(engine)
if version is None or version > current_version:
return versioning_api.upgrade(get_engine(), repository, version)
return versioning_api.upgrade(engine, repository, version)
else:
return versioning_api.downgrade(get_engine(), repository,
return versioning_api.downgrade(engine, repository,
version)
def _db_schema_sanity_check():
engine = get_engine()
def _db_schema_sanity_check(engine):
"""Ensure all database tables were created with required parameters.
:param engine: SQLAlchemy engine instance for a given database
"""
if engine.name == 'mysql':
onlyutf8_sql = ('SELECT TABLE_NAME,TABLE_COLLATION '
'from information_schema.TABLES '
@@ -216,23 +218,23 @@ def _db_schema_sanity_check():
) % ','.join(table_names))
def db_version(abs_path, init_version):
def db_version(engine, abs_path, init_version):
"""Show the current version of the repository.
:param engine: SQLAlchemy engine instance for a given database
:param abs_path: Absolute path to migrate repository
:param init_version: Initial database version
"""
repository = _find_migrate_repo(abs_path)
try:
return versioning_api.db_version(get_engine(), repository)
return versioning_api.db_version(engine, repository)
except versioning_exceptions.DatabaseNotControlledError:
meta = sqlalchemy.MetaData()
engine = get_engine()
meta.reflect(bind=engine)
tables = meta.tables
if len(tables) == 0 or 'alembic_version' in tables:
db_version_control(abs_path, init_version)
return versioning_api.db_version(get_engine(), repository)
return versioning_api.db_version(engine, repository)
else:
raise exception.DbMigrationError(
message=_(
@@ -241,17 +243,18 @@ def db_version(abs_path, init_version):
"manually."))
def db_version_control(abs_path, version=None):
def db_version_control(engine, abs_path, version=None):
"""Mark a database as under this repository's version control.
Once a database is under version control, schema changes should
only be done via change scripts in this repository.
:param engine: SQLAlchemy engine instance for a given database
:param abs_path: Absolute path to migrate repository
:param version: Initial database version
"""
repository = _find_migrate_repo(abs_path)
versioning_api.version_control(get_engine(), repository, version)
versioning_api.version_control(engine, repository, version)
return version
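With this change, every migration helper takes an explicit engine as its first argument. A minimal sketch of the new calling convention (the connection URL and repository path below are placeholders):

.. code:: python

    from openstack.common.db.sqlalchemy import migration
    from openstack.common.db.sqlalchemy import session as db_session

    # The application creates and owns the engine; oslo.db no longer keeps one.
    engine = db_session.create_engine('sqlite:///example.db')

    # Pass the engine explicitly to every migration helper.
    migration.db_version_control(engine, '/path/to/migrate_repo')
    migration.db_sync(engine, '/path/to/migrate_repo', version=None, init_version=0)
    current = migration.db_version(engine, '/path/to/migrate_repo', init_version=0)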


@@ -40,6 +40,7 @@ class AlembicExtension(ext_base.MigrationExtensionBase):
repo_path = migration_config.get('alembic_repo_path')
if repo_path:
self.config.set_main_option('script_location', repo_path)
self.db_url = migration_config['db_url']
def upgrade(self, version):
return alembic.command.upgrade(self.config, version or 'head')
@@ -50,7 +51,7 @@ class AlembicExtension(ext_base.MigrationExtensionBase):
return alembic.command.downgrade(self.config, version)
def version(self):
engine = db_session.get_engine()
engine = db_session.create_engine(self.db_url)
with engine.connect() as conn:
context = alembic_migration.MigrationContext.configure(conn)
return context.get_current_revision()


@@ -15,6 +15,7 @@ import os
from openstack.common.db.sqlalchemy import migration
from openstack.common.db.sqlalchemy.migration_cli import ext_base
from openstack.common.db.sqlalchemy import session as db_session
from openstack.common.gettextutils import _ # noqa
@@ -33,6 +34,8 @@ class MigrateExtension(ext_base.MigrationExtensionBase):
def __init__(self, migration_config):
self.repository = migration_config.get('migration_repo_path', '')
self.init_version = migration_config.get('init_version', 0)
self.db_url = migration_config['db_url']
self.engine = db_session.create_engine(self.db_url)
@property
def enabled(self):
@@ -41,7 +44,7 @@ class MigrateExtension(ext_base.MigrationExtensionBase):
def upgrade(self, version):
version = None if version == 'head' else version
return migration.db_sync(
self.repository, version,
self.engine, self.repository, version,
init_version=self.init_version)
def downgrade(self, version):
@@ -51,7 +54,7 @@ class MigrateExtension(ext_base.MigrationExtensionBase):
version = self.init_version
version = int(version)
return migration.db_sync(
self.repository, version,
self.engine, self.repository, version,
init_version=self.init_version)
except ValueError:
LOG.error(
@@ -63,4 +66,4 @@ class MigrateExtension(ext_base.MigrationExtensionBase):
def version(self):
return migration.db_version(
self.repository, init_version=self.init_version)
self.engine, self.repository, init_version=self.init_version)
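Both migration_cli extensions now read the connection URL from a 'db_url' key in the migration config and build their own engine, instead of relying on a globally configured one. A hypothetical wiring sketch (the repository path is a placeholder):

.. code:: python

    from openstack.common.db.sqlalchemy.migration_cli import ext_migrate

    migration_config = {
        'migration_repo_path': '/path/to/migrate_repo',
        'init_version': 0,
        'db_url': 'sqlite://',  # now required by both extensions
    }
    extension = ext_migrate.MigrateExtension(migration_config)
    extension.upgrade('head')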


@@ -26,7 +26,6 @@ from sqlalchemy import Column, Integer
from sqlalchemy import DateTime
from sqlalchemy.orm import object_mapper
from openstack.common.db.sqlalchemy import session as sa
from openstack.common import timeutils
@@ -34,10 +33,9 @@ class ModelBase(object):
"""Base class for models."""
__table_initialized__ = False
def save(self, session=None):
def save(self, session):
"""Save this object."""
if not session:
session = sa.get_session()
# NOTE(boris-42): This part of the code should look like:
# session.add(self)
# session.flush()
@@ -110,7 +108,7 @@ class SoftDeleteMixin(object):
deleted_at = Column(DateTime)
deleted = Column(Integer, default=0)
def soft_delete(self, session=None):
def soft_delete(self, session):
"""Mark this object as deleted."""
self.deleted = self.id
self.deleted_at = timeutils.utcnow()
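ModelBase.save() and SoftDeleteMixin.soft_delete() no longer fall back to a globally created session; callers must pass one in explicitly. A minimal sketch, where 'sessionmaker' is the application's own session factory (as in the updated docs below) and Foo is a hypothetical model built on these mixins:

.. code:: python

    session = sessionmaker()
    with session.begin():
        foo = Foo()
        foo.save(session)         # an explicit session is now required
        foo.soft_delete(session)  # likewise for soft deletes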


@@ -87,7 +87,7 @@ Recommended ways to use sessions within this framework:
.. code:: python
def create_many_foo(context, foos):
session = get_session()
session = sessionmaker()
with session.begin():
for foo in foos:
foo_ref = models.Foo()
@@ -95,7 +95,7 @@ Recommended ways to use sessions within this framework:
session.add(foo_ref)
def update_bar(context, foo_id, newbar):
session = get_session()
session = sessionmaker()
with session.begin():
foo_ref = (model_query(context, models.Foo, session).
filter_by(id=foo_id).
@@ -142,7 +142,7 @@ Recommended ways to use sessions within this framework:
foo1 = models.Foo()
foo2 = models.Foo()
foo1.id = foo2.id = 1
session = get_session()
session = sessionmaker()
try:
with session.begin():
session.add(foo1)
@@ -168,7 +168,7 @@ Recommended ways to use sessions within this framework:
.. code:: python
def myfunc(foo):
session = get_session()
session = sessionmaker()
with session.begin():
# do some database things
bar = _private_func(foo, session)
@@ -176,7 +176,7 @@ Recommended ways to use sessions within this framework:
def _private_func(foo, session=None):
if not session:
session = get_session()
session = sessionmaker()
with session.begin(subtransaction=True):
# do some other database things
return bar
@@ -240,7 +240,7 @@ Efficient use of soft deletes:
def complex_soft_delete_with_synchronization_bar(session=None):
if session is None:
session = get_session()
session = sessionmaker()
with session.begin(subtransactions=True):
count = (model_query(BarModel).
find(some_condition).
@@ -257,7 +257,7 @@ Efficient use of soft deletes:
.. code:: python
def soft_delete_bar_model():
session = get_session()
session = sessionmaker()
with session.begin():
bar_ref = model_query(BarModel).find(some_condition).first()
# Work with bar_ref
@@ -269,7 +269,7 @@ Efficient use of soft deletes:
.. code:: python
def soft_delete_multi_models():
session = get_session()
session = sessionmaker()
with session.begin():
query = (model_query(BarModel, session=session).
find(some_condition))
@@ -408,11 +408,6 @@ CONF.register_opts(database_opts, 'database')
LOG = logging.getLogger(__name__)
_ENGINE = None
_MAKER = None
_SLAVE_ENGINE = None
_SLAVE_MAKER = None
def set_defaults(sql_connection, sqlite_db, max_pool_size=None,
max_overflow=None, pool_timeout=None):
@@ -433,24 +428,6 @@ def set_defaults(sql_connection, sqlite_db, max_pool_size=None,
pool_timeout=pool_timeout)
def cleanup():
global _ENGINE, _MAKER
global _SLAVE_ENGINE, _SLAVE_MAKER
if _MAKER:
_MAKER.close_all()
_MAKER = None
if _ENGINE:
_ENGINE.dispose()
_ENGINE = None
if _SLAVE_MAKER:
_SLAVE_MAKER.close_all()
_SLAVE_MAKER = None
if _SLAVE_ENGINE:
_SLAVE_ENGINE.dispose()
_SLAVE_ENGINE = None
class SqliteForeignKeysListener(PoolListener):
"""Ensures that the foreign key constraints are enforced in SQLite.
@@ -462,30 +439,6 @@ class SqliteForeignKeysListener(PoolListener):
dbapi_con.execute('pragma foreign_keys=ON')
def get_session(autocommit=True, expire_on_commit=False, sqlite_fk=False,
slave_session=False, mysql_traditional_mode=False):
"""Return a SQLAlchemy session."""
global _MAKER
global _SLAVE_MAKER
maker = _MAKER
if slave_session:
maker = _SLAVE_MAKER
if maker is None:
engine = get_engine(sqlite_fk=sqlite_fk, slave_engine=slave_session,
mysql_traditional_mode=mysql_traditional_mode)
maker = get_maker(engine, autocommit, expire_on_commit)
if slave_session:
_SLAVE_MAKER = maker
else:
_MAKER = maker
session = maker()
return session
# note(boris-42): In current versions of DB backends unique constraint
# violation messages follow the structure:
#
@@ -591,15 +544,19 @@ def _raise_if_deadlock_error(operational_error, engine_name):
def _wrap_db_error(f):
# TODO(rpodolyaka): in a subsequent commit make this a class decorator to
# ensure it can only be applied to Session subclass instances (as we use
# the Session instance's bind attribute below)
@functools.wraps(f)
def _wrap(*args, **kwargs):
def _wrap(self, *args, **kwargs):
try:
return f(*args, **kwargs)
return f(self, *args, **kwargs)
except UnicodeEncodeError:
raise exception.DBInvalidUnicodeParameter()
except sqla_exc.OperationalError as e:
_raise_if_db_connection_lost(e, get_engine())
_raise_if_deadlock_error(e, get_engine().name)
_raise_if_db_connection_lost(e, self.bind)
_raise_if_deadlock_error(e, self.bind.dialect.name)
# NOTE(comstud): A lot of code is checking for OperationalError
# so let's not wrap it for now.
raise
@@ -612,7 +569,7 @@ def _wrap_db_error(f):
# instance_types) there are more than one unique constraint. This
# means we should get names of columns, which values violate
# unique constraint, from error message.
_raise_if_duplicate_entry_error(e, get_engine().name)
_raise_if_duplicate_entry_error(e, self.bind.dialect.name)
raise exception.DBError(e)
except Exception as e:
LOG.exception(_('DB exception wrapped.'))
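Because _wrap_db_error now decorates methods of a Session subclass, 'self' is the session itself and its bound engine is reachable via self.bind, so the global get_engine() lookup is no longer needed. An illustrative sketch of such a decorated method (not part of this hunk):

.. code:: python

    import sqlalchemy.orm

    class Session(sqlalchemy.orm.session.Session):
        """Illustration only: a custom Session whose DB-facing methods are
        wrapped so errors are translated using the session's own engine."""

        @_wrap_db_error
        def flush(self, *args, **kwargs):
            return super(Session, self).flush(*args, **kwargs)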
@@ -620,29 +577,6 @@ def _wrap_db_error(f):
return _wrap
def get_engine(sqlite_fk=False, slave_engine=False,
mysql_traditional_mode=False):
"""Return a SQLAlchemy engine."""
global _ENGINE
global _SLAVE_ENGINE
engine = _ENGINE
db_uri = CONF.database.connection
if slave_engine:
engine = _SLAVE_ENGINE
db_uri = CONF.database.slave_connection
if engine is None:
engine = create_engine(db_uri, sqlite_fk=sqlite_fk,
mysql_traditional_mode=mysql_traditional_mode)
if slave_engine:
_SLAVE_ENGINE = engine
else:
_ENGINE = engine
return engine
def _synchronous_switch_listener(dbapi_conn, connection_rec):
"""Switch sqlite connections to non-synchronous mode."""
dbapi_conn.execute("PRAGMA synchronous = OFF")
@@ -902,3 +836,87 @@ def _assert_matching_drivers():
normal = sqlalchemy.engine.url.make_url(CONF.database.connection)
slave = sqlalchemy.engine.url.make_url(CONF.database.slave_connection)
assert normal.drivername == slave.drivername
class EngineFacade(object):
"""A helper class for removing of global engine instances from oslo.db.
As a library, oslo.db can't decide where to store/when to create engine
and sessionmaker instances, so this must be left for a target application.
On the other hand, in order to simplify the adoption of oslo.db changes,
we'll provide a helper class, which creates engine and sessionmaker
on its instantiation and provides get_engine()/get_session() methods
that are compatible with corresponding utility functions that currently
exist in target projects, e.g. in Nova.
engine/sessionmaker instances will still be global (and they are meant to
be global), but they will be stored in the app context, rather that in the
oslo.db context.
Note: using of this helper is completely optional and you are encouraged to
integrate engine/sessionmaker instances into your apps any way you like
(e.g. one might want to bind a session to a request context). Two important
things to remember:
1. An Engine instance is effectively a pool of DB connections, so it's
meant to be shared (and it's thread-safe).
2. A Session instance is not meant to be shared and represents a DB
transactional context (i.e. it's not thread-safe). sessionmaker is
a factory of sessions.
"""
def __init__(self, sql_connection,
sqlite_fk=False, mysql_traditional_mode=False,
autocommit=True, expire_on_commit=False):
"""Initialize engine and sessionmaker instances.
:param sqlite_fk: enable foreign keys in SQLite
:type sqlite_fk: bool
:param mysql_traditional_mode: enable traditional mode in MySQL
:type mysql_traditional_mode: bool
:param autocommit: use autocommit mode for created Session instances
:type autocommit: bool
:param expire_on_commit: expire session objects on commit
:type expire_on_commit: bool
"""
super(EngineFacade, self).__init__()
self._engine = create_engine(
sql_connection=sql_connection,
sqlite_fk=sqlite_fk,
mysql_traditional_mode=mysql_traditional_mode)
self._session_maker = get_maker(
engine=self._engine,
autocommit=autocommit,
expire_on_commit=expire_on_commit)
def get_engine(self):
"""Get the engine instance (note, that it's shared)."""
return self._engine
def get_session(self, **kwargs):
"""Get a Session instance.
If passed, keyword argument values override the ones used when the
sessionmaker instance was created.
:keyword autocommit: use autocommit mode for created Session instances
:type autocommit: bool
:keyword expire_on_commit: expire session objects on commit
:type expire_on_commit: bool
"""
# iterate over a copy of the keys so entries can be removed safely
for arg in list(kwargs):
if arg not in ('autocommit', 'expire_on_commit'):
del kwargs[arg]
return self._session_maker(**kwargs)
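A minimal usage sketch for the new helper (the connection URL is a placeholder and 'some_object' stands for any mapped model instance):

.. code:: python

    facade = EngineFacade('mysql://user:password@localhost/mydb')

    engine = facade.get_engine()    # one shared, thread-safe Engine
    session = facade.get_session()  # a fresh Session per call; don't share it

    with session.begin():
        session.add(some_object)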


@@ -18,7 +18,6 @@ import functools
import os
import fixtures
from oslo.config import cfg
import six
from openstack.common.db.sqlalchemy import session
@@ -38,18 +37,17 @@ class DbFixture(fixtures.Fixture):
def _get_uri(self):
return os.getenv('OS_TEST_DBAPI_CONNECTION', 'sqlite://')
def __init__(self):
def __init__(self, test):
super(DbFixture, self).__init__()
self.conf = cfg.CONF
self.conf.import_opt('connection',
'openstack.common.db.sqlalchemy.session',
group='database')
self.test = test
def setUp(self):
super(DbFixture, self).setUp()
self.conf.set_default('connection', self._get_uri(), group='database')
self.addCleanup(self.conf.reset)
self.test.engine = session.create_engine(self._get_uri())
self.test.sessionmaker = session.get_maker(self.test.engine)
self.addCleanup(self.test.engine.dispose)
class DbTestCase(test.BaseTestCase):
@@ -64,9 +62,7 @@ class DbTestCase(test.BaseTestCase):
def setUp(self):
super(DbTestCase, self).setUp()
self.useFixture(self.FIXTURE())
self.addCleanup(session.cleanup)
self.useFixture(self.FIXTURE(self))
ALLOWED_DIALECTS = ['sqlite', 'mysql', 'postgresql']
@@ -83,11 +79,10 @@ def backend_specific(*dialects):
if not set(dialects).issubset(ALLOWED_DIALECTS):
raise ValueError(
"Please use allowed dialects: %s" % ALLOWED_DIALECTS)
engine = session.get_engine()
if engine.name not in dialects:
if self.engine.name not in dialects:
msg = ('The test "%s" can be run '
'only on %s. Current engine is %s.')
args = (f.__name__, ' '.join(dialects), engine.name)
args = (f.__name__, ' '.join(dialects), self.engine.name)
self.skip(msg % args)
else:
return f(self)


@@ -18,7 +18,6 @@ from migrate.changeset.databases import sqlite
import sqlalchemy as sa
from openstack.common.db.sqlalchemy import migration
from openstack.common.db.sqlalchemy import session
from openstack.common.db.sqlalchemy import test_base
@@ -44,7 +43,7 @@ class TestSqliteUniqueConstraints(test_base.DbTestCase):
test_table = sa.Table(
'test_table',
sa.schema.MetaData(bind=session.get_engine()),
sa.schema.MetaData(bind=self.engine),
sa.Column('a', sa.Integer),
sa.Column('b', sa.String(10)),
sa.Column('c', sa.Integer),
@@ -58,7 +57,7 @@ class TestSqliteUniqueConstraints(test_base.DbTestCase):
# we actually do in db migrations code
self.reflected_table = sa.Table(
'test_table',
sa.schema.MetaData(bind=session.get_engine()),
sa.schema.MetaData(bind=self.engine),
autoload=True
)


@@ -31,7 +31,8 @@ class MockWithCmp(mock.MagicMock):
class TestAlembicExtension(test.BaseTestCase):
def setUp(self):
self.migration_config = {'alembic_ini_path': '.'}
self.migration_config = {'alembic_ini_path': '.',
'db_url': 'sqlite://'}
self.alembic = ext_alembic.AlembicExtension(self.migration_config)
super(TestAlembicExtension, self).setUp()
@@ -90,7 +91,8 @@ class TestAlembicExtension(test.BaseTestCase):
class TestMigrateExtension(test.BaseTestCase):
def setUp(self):
self.migration_config = {'migration_repo_path': '.'}
self.migration_config = {'migration_repo_path': '.',
'db_url': 'sqlite://'}
self.migrate = ext_migrate.MigrateExtension(self.migration_config)
super(TestMigrateExtension, self).setUp()
@@ -105,40 +107,41 @@ class TestMigrateExtension(test.BaseTestCase):
def test_upgrade_head(self, migration):
self.migrate.upgrade('head')
migration.db_sync.assert_called_once_with(
self.migrate.repository, None, init_version=0)
self.migrate.engine, self.migrate.repository, None, init_version=0)
def test_upgrade_normal(self, migration):
self.migrate.upgrade(111)
migration.db_sync.assert_called_once_with(
self.migrate.repository, 111, init_version=0)
mock.ANY, self.migrate.repository, 111, init_version=0)
def test_downgrade_init_version_from_base(self, migration):
self.migrate.downgrade('base')
migration.db_sync.assert_called_once_with(
self.migrate.repository, mock.ANY,
self.migrate.engine, self.migrate.repository, mock.ANY,
init_version=mock.ANY)
def test_downgrade_init_version_from_none(self, migration):
self.migrate.downgrade(None)
migration.db_sync.assert_called_once_with(
self.migrate.repository, mock.ANY,
self.migrate.engine, self.migrate.repository, mock.ANY,
init_version=mock.ANY)
def test_downgrade_normal(self, migration):
self.migrate.downgrade(101)
migration.db_sync.assert_called_once_with(
self.migrate.repository, 101, init_version=0)
self.migrate.engine, self.migrate.repository, 101, init_version=0)
def test_version(self, migration):
self.migrate.version()
migration.db_version.assert_called_once_with(
self.migrate.repository, init_version=0)
self.migrate.engine, self.migrate.repository, init_version=0)
def test_change_init_version(self, migration):
self.migration_config['init_version'] = 101
migrate = ext_migrate.MigrateExtension(self.migration_config)
migrate.downgrade(None)
migration.db_sync.assert_called_once_with(
migrate.engine,
self.migrate.repository,
self.migration_config['init_version'],
init_version=self.migration_config['init_version'])
@@ -148,7 +151,8 @@ class TestMigrationManager(test.BaseTestCase):
def setUp(self):
self.migration_config = {'alembic_ini_path': '.',
'migrate_repo_path': '.'}
'migrate_repo_path': '.',
'db_url': 'sqlite://'}
self.migration_manager = manager.MigrationManager(
self.migration_config)
self.ext = mock.Mock()
@@ -188,7 +192,8 @@ class TestMigrationRightOrder(test.BaseTestCase):
def setUp(self):
self.migration_config = {'alembic_ini_path': '.',
'migrate_repo_path': '.'}
'migrate_repo_path': '.',
'db_url': 'sqlite://'}
self.migration_manager = manager.MigrationManager(
self.migration_config)
self.first_ext = MockWithCmp()


@@ -25,7 +25,6 @@ import sqlalchemy
from openstack.common.db import exception as db_exception
from openstack.common.db.sqlalchemy import migration
from openstack.common.db.sqlalchemy import session as db_session
from openstack.common.db.sqlalchemy import test_base
@@ -81,14 +80,15 @@ class TestMigrationCommon(test_base.DbTestCase):
mock_find_repo.return_value = self.return_value
version = migration.db_version_control(
self.path, self.test_version)
self.engine, self.path, self.test_version)
self.assertEqual(version, self.test_version)
mock_version_control.assert_called_once_with(
db_session.get_engine(), self.return_value, self.test_version)
self.engine, self.return_value, self.test_version)
def test_db_version_return(self):
ret_val = migration.db_version(self.path, self.init_version)
ret_val = migration.db_version(self.engine, self.path,
self.init_version)
self.assertEqual(ret_val, self.test_version)
def test_db_version_raise_not_controlled_error_first(self):
@@ -98,7 +98,8 @@ class TestMigrationCommon(test_base.DbTestCase):
migrate_exception.DatabaseNotControlledError('oups'),
self.test_version]
ret_val = migration.db_version(self.path, self.init_version)
ret_val = migration.db_version(self.engine, self.path,
self.init_version)
self.assertEqual(ret_val, self.test_version)
mock_ver.assert_called_once_with(self.path, self.init_version)
@@ -112,7 +113,7 @@ class TestMigrationCommon(test_base.DbTestCase):
self.assertRaises(
db_exception.DbMigrationError, migration.db_version,
self.path, self.init_version)
self.engine, self.path, self.init_version)
def test_db_sync_wrong_version(self):
self.assertRaises(
@@ -128,10 +129,11 @@ class TestMigrationCommon(test_base.DbTestCase):
mock_find_repo.return_value = self.return_value
self.mock_api_db_version.return_value = self.test_version - 1
migration.db_sync(self.path, self.test_version, init_ver)
migration.db_sync(self.engine, self.path, self.test_version,
init_ver)
mock_upgrade.assert_called_once_with(
db_session.get_engine(), self.return_value, self.test_version)
self.engine, self.return_value, self.test_version)
def test_db_sync_downgrade(self):
with contextlib.nested(
@@ -142,10 +144,10 @@ class TestMigrationCommon(test_base.DbTestCase):
mock_find_repo.return_value = self.return_value
self.mock_api_db_version.return_value = self.test_version + 1
migration.db_sync(self.path, self.test_version)
migration.db_sync(self.engine, self.path, self.test_version)
mock_downgrade.assert_called_once_with(
db_session.get_engine(), self.return_value, self.test_version)
self.engine, self.return_value, self.test_version)
def test_db_sync_sanity_called(self):
with contextlib.nested(
@@ -155,15 +157,15 @@ class TestMigrationCommon(test_base.DbTestCase):
) as (mock_find_repo, mock_sanity, mock_downgrade):
mock_find_repo.return_value = self.return_value
migration.db_sync(self.path, self.test_version)
migration.db_sync(self.engine, self.path, self.test_version)
mock_sanity.assert_called_once()
def test_db_sanity_table_not_utf8(self):
with mock.patch.object(migration, 'get_engine') as mock_get_eng:
mock_eng = mock_get_eng.return_value
with mock.patch.object(self, 'engine') as mock_eng:
type(mock_eng).name = mock.PropertyMock(return_value='mysql')
mock_eng.execute.return_value = [['table_A', 'latin1'],
['table_B', 'latin1']]
self.assertRaises(ValueError, migration._db_schema_sanity_check)
self.assertRaises(ValueError, migration._db_schema_sanity_check,
mock_eng)


@@ -129,7 +129,7 @@ class SessionErrorWrapperTestCase(test_base.DbTestCase):
def setUp(self):
super(SessionErrorWrapperTestCase, self).setUp()
meta = MetaData()
meta.bind = session.get_engine()
meta.bind = self.engine
test_table = Table(_TABLE_NAME, meta,
Column('id', Integer, primary_key=True,
nullable=False),
@@ -143,16 +143,18 @@ class SessionErrorWrapperTestCase(test_base.DbTestCase):
self.addCleanup(test_table.drop)
def test_flush_wrapper(self):
_session = self.sessionmaker()
tbl = TmpTable()
tbl.update({'foo': 10})
tbl.save()
tbl.save(_session)
tbl2 = TmpTable()
tbl2.update({'foo': 10})
self.assertRaises(db_exc.DBDuplicateEntry, tbl2.save)
self.assertRaises(db_exc.DBDuplicateEntry, tbl2.save, _session)
def test_execute_wrapper(self):
_session = session.get_session()
_session = self.sessionmaker()
with _session.begin():
for i in [10, 20]:
tbl = TmpTable()
@@ -180,7 +182,7 @@ class RegexpFilterTestCase(test_base.DbTestCase):
def setUp(self):
super(RegexpFilterTestCase, self).setUp()
meta = MetaData()
meta.bind = session.get_engine()
meta.bind = self.engine
test_table = Table(_REGEXP_TABLE_NAME, meta,
Column('id', Integer, primary_key=True,
nullable=False),
@@ -189,7 +191,7 @@ class RegexpFilterTestCase(test_base.DbTestCase):
self.addCleanup(test_table.drop)
def _test_regexp_filter(self, regexp, expected):
_session = session.get_session()
_session = self.sessionmaker()
with _session.begin():
for i in ['10', '20', u'']:
tbl = RegexpTable()
@@ -213,26 +215,6 @@ class RegexpFilterTestCase(test_base.DbTestCase):
self._test_regexp_filter(u'', [])
class SlaveBackendTestCase(test.BaseTestCase):
def test_slave_engine_nomatch(self):
default = session.CONF.database.connection
session.CONF.database.slave_connection = default
e = session.get_engine()
slave_e = session.get_engine(slave_engine=True)
self.assertNotEqual(slave_e, e)
def test_no_slave_engine_match(self):
slave_e = session.get_engine()
e = session.get_engine()
self.assertEqual(slave_e, e)
def test_slave_backend_nomatch(self):
session.CONF.database.slave_connection = "mysql:///localhost"
self.assertRaises(AssertionError, session._assert_matching_drivers)
class FakeDBAPIConnection():
def cursor(self):
return FakeCursor()
@@ -323,7 +305,8 @@ class MySQLTraditionalModeTestCase(test_base.MySQLOpportunisticTestCase):
def setUp(self):
super(MySQLTraditionalModeTestCase, self).setUp()
self.engine = session.get_engine(mysql_traditional_mode=True)
self.engine = session.create_engine(self.engine.url,
mysql_traditional_mode=True)
self.connection = self.engine.connect()
meta = MetaData()
@@ -333,7 +316,6 @@ class MySQLTraditionalModeTestCase(test_base.MySQLOpportunisticTestCase):
Column('bar', String(255)))
self.test_table.create()
self.addCleanup(session.cleanup)
self.addCleanup(self.test_table.drop)
self.addCleanup(self.connection.close)
@@ -341,3 +323,28 @@ class MySQLTraditionalModeTestCase(test_base.MySQLOpportunisticTestCase):
with self.connection.begin():
self.assertRaises(DataError, self.connection.execute,
self.test_table.insert(), bar='a' * 512)
class EngineFacadeTestCase(test.BaseTestCase):
def setUp(self):
super(EngineFacadeTestCase, self).setUp()
self.facade = session.EngineFacade('sqlite://')
def test_get_engine(self):
eng1 = self.facade.get_engine()
eng2 = self.facade.get_engine()
self.assertIs(eng1, eng2)
def test_get_session(self):
ses1 = self.facade.get_session()
ses2 = self.facade.get_session()
self.assertIsNot(ses1, ses2)
def test_get_session_arguments_override_default_settings(self):
ses = self.facade.get_session(autocommit=False, expire_on_commit=True)
self.assertFalse(ses.autocommit)
self.assertTrue(ses.expire_on_commit)