Use oslo.db sessions

Keystone was using its own keystone.common.sql module for
database access. oslo-incubator's db.sqlalchemy module provides
the same or better functionality, so use that instead.

DocImpact
- The options that were in the [sql] section are deprecated and
  replaced by options in the [database] section. There are
  also several new options in this section. If database
  configuration is described for another project that uses
  oslo-incubator's db.sqlalchemy module, the documentation can be shared.

Part of bp use-common-oslo-db-code

Change-Id: I25b717d9616e9d31316441ae3671d2f86229c2bf
This commit is contained in:
Brant Knudson 2013-12-15 10:00:41 -06:00
parent 1a96f961e3
commit 44ceda2816
23 changed files with 179 additions and 500 deletions

View File

@ -43,6 +43,7 @@ from keystone.openstack.common import gettextutils
gettextutils.install('keystone')
from keystone.common import environment
from keystone.common import sql
from keystone.common import utils
from keystone import config
from keystone.openstack.common import importutils
@ -101,6 +102,8 @@ if __name__ == '__main__':
if os.path.exists(dev_conf):
config_files = [dev_conf]
sql.initialize()
CONF(project='keystone',
version=pbr.version.VersionInfo('keystone').version_string(),
default_config_files=config_files)

View File

@ -44,6 +44,12 @@
# similar to max_param_size, but provides an exception for token values
# max_token_size = 8192
# The filename to use with SQLite
# sqlite_db = keystone.db
# If true, use synchronous mode for sqlite
# sqlite_synchronous = True
# === Logging Options ===
# Print debugging output
# (includes plaintext request logging, potentially including passwords)
@ -145,11 +151,48 @@
[sql]
# The SQLAlchemy connection string used to connect to the database
# DEPRECATED: use connection in the [database] section instead.
# connection = sqlite:///keystone.db
# the timeout before idle sql connections are reaped
# DEPRECATED: use idle_timeout in the [database] section instead.
# idle_timeout = 200
[database]
# The SQLAlchemy connection string used to connect to the database
# connection = sqlite:///keystone.db
# The SQLAlchemy connection string used to connect to the slave database
# Note that Keystone does not use this option.
# slave_connection =
# timeout before idle sql connections are reaped
# idle_timeout = 3600
# Minimum number of SQL connections to keep open in a pool
# min_pool_size = 1
# Maximum number of SQL connections to keep open in a pool
# max_pool_size =
# maximum db connection retries during startup. (setting -1 implies an infinite retry count)
# max_retries = 10
# interval between retries of opening a sql connection
# retry_interval = 10
# If set, use this value for max_overflow with sqlalchemy
# max_overflow =
# Verbosity of SQL debugging information. 0=None, 100=Everything
# connection_debug = 0
# Add python stack traces to SQL as comment strings
# connection_trace = False
# If set, use this value for pool_timeout with sqlalchemy
# pool_timeout =
[identity]
# driver = keystone.identity.backends.sql.Identity

View File

@ -27,12 +27,16 @@ from keystone.openstack.common import gettextutils
gettextutils.install('keystone')
from keystone.common import environment
from keystone.common import sql
from keystone import config
from keystone.openstack.common import log
from keystone import service
CONF = config.CONF
sql.initialize()
CONF(project='keystone')
config.setup_logging()

View File

@ -24,6 +24,7 @@ from oslo.config import cfg
import pbr.version
from keystone.common import openssl
from keystone.common import sql
from keystone.common.sql import migration
from keystone.common import utils
from keystone import config
@ -210,6 +211,9 @@ command_opt = cfg.SubCommandOpt('command',
def main(argv=None, config_files=None):
CONF.register_cli_opt(command_opt)
sql.initialize()
CONF(args=argv[1:],
project='keystone',
version=pbr.version.VersionInfo('keystone').version_string(),

View File

@ -125,10 +125,6 @@ FILE_OPTIONS = {
cfg.StrOpt('cert_subject',
default=('/C=US/ST=Unset/L=Unset/O=Unset/'
'CN=www.example.com'))],
'sql': [
cfg.StrOpt('connection', secret=True,
default='sqlite:///keystone.db'),
cfg.IntOpt('idle_timeout', default=200)],
'assignment': [
# assignment has no default for backward compatibility reasons.
# If assignment driver is not specified, the identity driver chooses

View File

@ -14,32 +14,25 @@
# License for the specific language governing permissions and limitations
# under the License.
"""SQL backends for the various services."""
"""SQL backends for the various services.
Before using this module, call initialize(). This has to be done before
CONF() because it sets up configuration options.
"""
import contextlib
import functools
import sqlalchemy as sql
import sqlalchemy.engine.url
from sqlalchemy.exc import DisconnectionError
from sqlalchemy.ext import declarative
import sqlalchemy.orm
from sqlalchemy.orm.attributes import flag_modified, InstrumentedAttribute
import sqlalchemy.pool
from sqlalchemy import types as sql_types
from keystone import config
from keystone import exception
from keystone.openstack.common.db import exception as db_exception
from keystone.openstack.common.db.sqlalchemy import models
from keystone.openstack.common.db.sqlalchemy import session
from keystone.openstack.common import jsonutils
from keystone.openstack.common import log as logging
LOG = logging.getLogger(__name__)
CONF = config.CONF
# maintain a single engine reference for sqlalchemy engine
GLOBAL_ENGINE = None
GLOBAL_ENGINE_CALLBACKS = set()
ModelBase = declarative.declarative_base()
@ -63,6 +56,14 @@ joinedload = sql.orm.joinedload
flag_modified = flag_modified
def initialize():
    """Initialize the keystone.common.sql module.

    Registers the oslo.db session defaults (the sqlite connection string
    and database filename).  Per the module docstring, this must be
    called before CONF() because it sets up configuration options.
    """
    session.set_defaults(
        sql_connection="sqlite:///keystone.db",
        sqlite_db="keystone.db")
def initialize_decorator(init):
"""Ensure that the length of string field do not exceed the limit.
@ -96,51 +97,6 @@ def initialize_decorator(init):
ModelBase.__init__ = initialize_decorator(ModelBase.__init__)
def set_global_engine(engine):
"""Set the global engine.
This sets the current global engine, which is returned by
Base.get_engine(allow_global_engine=True).
When the global engine is changed, all of the callbacks registered via
register_global_engine_callback since the last time set_global_engine was
changed are called. The callback functions are invoked with no arguments.
"""
global GLOBAL_ENGINE
global GLOBAL_ENGINE_CALLBACKS
if engine is GLOBAL_ENGINE:
# It's the same engine so nothing to do.
return
GLOBAL_ENGINE = engine
cbs = GLOBAL_ENGINE_CALLBACKS
GLOBAL_ENGINE_CALLBACKS = set()
for cb in cbs:
try:
cb()
except Exception:
LOG.exception(_("Global engine callback raised."))
# Just logging the exception so can process other callbacks.
def register_global_engine_callback(cb_fn):
"""Register a function to be called when the global engine is set.
Note that the callback will be called only once or not at all, so to get
called each time the global engine is changed the function must be
re-registered.
"""
global GLOBAL_ENGINE_CALLBACKS
GLOBAL_ENGINE_CALLBACKS.add(cb_fn)
# Special Fields
class JsonBlob(sql_types.TypeDecorator):
@ -188,59 +144,9 @@ class DictBase(models.ModelBase):
return getattr(self, key)
def mysql_on_checkout(dbapi_conn, connection_rec, connection_proxy):
    """Ensures that MySQL connections checked out of the pool are alive.

    Borrowed from:
    http://groups.google.com/group/sqlalchemy/msg/a4ce563d802c929f

    Error codes caught:
    * 2006 MySQL server has gone away
    * 2013 Lost connection to MySQL server during query
    * 2014 Commands out of sync; you can't run this command now
    * 2045 Can't open shared memory; no answer from server (%lu)
    * 2055 Lost connection to MySQL server at '%s', system error: %d

    from http://dev.mysql.com/doc/refman/5.6/en/error-messages-client.html
    """
    try:
        # Cheap liveness probe; raises OperationalError on a dead link.
        dbapi_conn.cursor().execute('select 1')
    except dbapi_conn.OperationalError as e:
        if e.args[0] in (2006, 2013, 2014, 2045, 2055):
            LOG.warn(_('MySQL server has gone away: %s'), e)
            # DisconnectionError signals the pool to discard this
            # connection and retry with a fresh one.
            raise DisconnectionError("Database server went away")
        else:
            # Any other operational error is a real failure; propagate.
            raise
def db2_on_checkout(engine, dbapi_conn, connection_rec, connection_proxy):
    """Ensures that DB2 connections checked out of the pool are alive."""
    cursor = dbapi_conn.cursor()
    try:
        # DB2 has no bare 'select 1'; probe via a one-row VALUES table.
        cursor.execute('select 1 from (values (1)) AS t1')
    except Exception as e:
        # Defer to the engine's dialect to classify the failure.
        is_disconnect = engine.dialect.is_disconnect(e, dbapi_conn, cursor)
        if is_disconnect:
            LOG.warn(_('DB2 server has gone away: %s'), e)
            # Tell the pool to drop this connection and retry.
            raise DisconnectionError("Database server went away")
        else:
            # Not a disconnect; propagate the original error.
            raise
# Backends
class Base(object):
_engine = None
_sessionmaker = None
def get_session(self, autocommit=True, expire_on_commit=False):
"""Return a SQLAlchemy session."""
if not self._engine:
self._engine = self.get_engine()
self._sessionmaker = self.get_sessionmaker(self._engine)
register_global_engine_callback(self.clear_engine)
return self._sessionmaker(autocommit=autocommit,
expire_on_commit=expire_on_commit)
get_session = session.get_session
@contextlib.contextmanager
def transaction(self, expire_on_commit=False):
@ -249,74 +155,23 @@ class Base(object):
with session.begin():
yield session
def get_engine(self, allow_global_engine=True):
"""Return a SQLAlchemy engine.
If allow_global_engine is True and an in-memory sqlite connection
string is provided by CONF, all backends will share a global sqlalchemy
engine.
"""
def new_engine():
connection_dict = sql.engine.url.make_url(CONF.sql.connection)
engine_config = {
'convert_unicode': True,
'echo': CONF.debug and CONF.verbose,
'pool_recycle': CONF.sql.idle_timeout,
}
if 'sqlite' in connection_dict.drivername:
engine_config['poolclass'] = sqlalchemy.pool.StaticPool
engine = sql.create_engine(CONF.sql.connection, **engine_config)
if engine.name == 'mysql':
sql.event.listen(engine, 'checkout', mysql_on_checkout)
elif engine.name == 'ibm_db_sa':
callback = functools.partial(db2_on_checkout, engine)
sql.event.listen(engine, 'checkout', callback)
return engine
if not allow_global_engine:
return new_engine()
if GLOBAL_ENGINE:
return GLOBAL_ENGINE
engine = new_engine()
# auto-build the db to support wsgi server w/ in-memory backend
if CONF.sql.connection == 'sqlite://':
ModelBase.metadata.create_all(bind=engine)
set_global_engine(engine)
return engine
def get_sessionmaker(self, engine, autocommit=True,
expire_on_commit=False):
"""Return a SQLAlchemy sessionmaker using the given engine."""
return sqlalchemy.orm.sessionmaker(
bind=engine,
autocommit=autocommit,
expire_on_commit=expire_on_commit)
def clear_engine(self):
self._engine = None
self._sessionmaker = None
def handle_conflicts(conflict_type='object'):
"""Converts IntegrityError into HTTP 409 Conflict."""
"""Converts select sqlalchemy exceptions into HTTP 409 Conflict."""
def decorator(method):
@functools.wraps(method)
def wrapper(*args, **kwargs):
try:
return method(*args, **kwargs)
except (IntegrityError, OperationalError) as e:
raise exception.Conflict(type=conflict_type,
details=str(e.orig))
except db_exception.DBDuplicateEntry as e:
raise exception.Conflict(type=conflict_type, details=str(e))
except db_exception.DBError as e:
# TODO(blk-u): inspecting inner_exception breaks encapsulation;
# oslo.db should provide exception we need.
if isinstance(e.inner_exception, IntegrityError):
raise exception.Conflict(type=conflict_type,
details=str(e))
raise
return wrapper
return decorator

View File

@ -41,11 +41,11 @@ except ImportError:
def migrate_repository(version, current_version, repo_path):
if version is None or version > current_version:
result = versioning_api.upgrade(CONF.sql.connection,
result = versioning_api.upgrade(CONF.database.connection,
repo_path, version)
else:
result = versioning_api.downgrade(
CONF.sql.connection, repo_path, version)
CONF.database.connection, repo_path, version)
return result
@ -65,7 +65,7 @@ def db_version(repo_path=None):
if repo_path is None:
repo_path = find_migrate_repo()
try:
return versioning_api.db_version(CONF.sql.connection, repo_path)
return versioning_api.db_version(CONF.database.connection, repo_path)
except versioning_exceptions.DatabaseNotControlledError:
return db_version_control(0)
@ -73,7 +73,8 @@ def db_version(repo_path=None):
def db_version_control(version=None, repo_path=None):
if repo_path is None:
repo_path = find_migrate_repo()
versioning_api.version_control(CONF.sql.connection, repo_path, version)
versioning_api.version_control(CONF.database.connection, repo_path,
version)
return version

View File

@ -1,4 +1,4 @@
#Used for running the Migrate tests against a live DB2 Server
#See _sql_livetest.py
[sql]
[database]
connection = ibm_db_sa://keystone:keystone@/staktest?charset=utf8

View File

@ -1,4 +1,4 @@
[sql]
[database]
connection = sqlite://
#For a file based sqlite use
#connection = sqlite:////tmp/keystone.db

View File

@ -1,4 +1,4 @@
[sql]
[database]
connection = sqlite://
#For a file based sqlite use
#connection = sqlite:////tmp/keystone.db

View File

@ -1,4 +1,4 @@
#Used for running the Migrate tests against a live Mysql Server
#See _sql_livetest.py
[sql]
[database]
connection = mysql://keystone:keystone@localhost/keystone_test?charset=utf8

View File

@ -1,4 +1,4 @@
#Used for running the Migrate tests against a live Postgresql Server
#See _sql_livetest.py
[sql]
[database]
connection = postgresql://keystone:keystone@localhost/keystone_test?client_encoding=utf8

View File

@ -1,4 +1,4 @@
[sql]
[database]
connection = sqlite://
#For a file based sqlite use
#connection = sqlite:////tmp/keystone.db

View File

@ -1,2 +1,2 @@
[sql]
[database]
connection = sqlite:///keystone/tests/tmp/test.db

View File

@ -54,6 +54,7 @@ from keystone.common import utils
from keystone.common import wsgi
from keystone import config
from keystone import exception
from keystone.openstack.common.db.sqlalchemy import session
from keystone.openstack.common import log
from keystone.openstack.common import timeutils
from keystone import service
@ -94,6 +95,17 @@ class dirs:
return os.path.join(TMPDIR, *p)
# keystone.common.sql.initialize() for testing.
def _initialize_sql_session():
    """Point the oslo.db session defaults at a per-run sqlite file.

    Test analogue of keystone.common.sql.initialize(): the database file
    lives under the test tmp directory rather than the production default.
    """
    db_file = dirs.tmp('test.db')
    session.set_defaults(
        sql_connection="sqlite:///" + db_file,
        sqlite_db=db_file)


# Run at import time so the defaults are in place before CONF() is called.
_initialize_sql_session()
def checkout_vendor(repo, rev):
# TODO(termie): this function is a good target for some optimizations :PERF
name = repo.split('/')[-1]
@ -165,7 +177,7 @@ def remove_generated_paste_config(extension_name):
def teardown_database():
sql.core.set_global_engine(None)
session.cleanup()
def skip_if_cache_disabled(*sections):

View File

@ -0,0 +1,8 @@
# Options in this file are deprecated. See test_config.
[sql]
# These options were deprecated in Icehouse with the switch to oslo's
# db.sqlalchemy.
connection = sqlite://deprecated
idle_timeout = 54321

View File

@ -0,0 +1,15 @@
# Options in this file are deprecated. See test_config.
[sql]
# These options were deprecated in Icehouse with the switch to oslo's
# db.sqlalchemy.
connection = sqlite://deprecated
idle_timeout = 54321
[database]
# These are the new options from the [sql] section.
connection = sqlite://new
idle_timeout = 65432

View File

@ -27,6 +27,7 @@ from keystone.common import sql
from keystone import config
from keystone import exception
from keystone import identity
from keystone.openstack.common.db.sqlalchemy import session
from keystone.openstack.common.fixture import moxstubout
from keystone import tests
from keystone.tests import default_fixtures
@ -1004,7 +1005,7 @@ class LdapIdentitySqlAssignment(sql.Base, tests.TestCase, BaseLDAPIdentity):
self.clear_database()
self.load_backends()
cache.configure_cache_region(cache.REGION)
self.engine = self.get_engine()
self.engine = session.get_engine()
sql.ModelBase.metadata.create_all(bind=self.engine)
self.load_fixtures(default_fixtures)
#defaulted by the data load
@ -1012,8 +1013,7 @@ class LdapIdentitySqlAssignment(sql.Base, tests.TestCase, BaseLDAPIdentity):
def tearDown(self):
sql.ModelBase.metadata.drop_all(bind=self.engine)
self.engine.dispose()
sql.set_global_engine(None)
session.cleanup()
super(LdapIdentitySqlAssignment, self).tearDown()
def test_domain_crud(self):
@ -1055,7 +1055,7 @@ class MultiLDAPandSQLIdentity(sql.Base, tests.TestCase, BaseLDAPIdentity):
self._set_config()
self.load_backends()
self.engine = self.get_engine()
self.engine = session.get_engine()
sql.ModelBase.metadata.create_all(bind=self.engine)
self._setup_domain_test_data()
@ -1081,8 +1081,7 @@ class MultiLDAPandSQLIdentity(sql.Base, tests.TestCase, BaseLDAPIdentity):
'identity',
domain_specific_drivers_enabled=self.orig_config_domains_enabled)
sql.ModelBase.metadata.drop_all(bind=self.engine)
self.engine.dispose()
sql.set_global_engine(None)
session.cleanup()
def _set_config(self):
self.config([tests.dirs.etc('keystone.conf.sample'),

View File

@ -22,6 +22,7 @@ from keystone.common import sql
from keystone import config
from keystone import exception
from keystone.identity.backends import sql as identity_sql
from keystone.openstack.common.db.sqlalchemy import session
from keystone.openstack.common.fixture import moxstubout
from keystone import tests
from keystone.tests import default_fixtures
@ -45,7 +46,7 @@ class SqlTests(tests.TestCase, sql.Base):
# create tables and keep an engine reference for cleanup.
# this must be done after the models are loaded by the managers.
self.engine = self.get_engine()
self.engine = session.get_engine()
sql.ModelBase.metadata.create_all(bind=self.engine)
# populate the engine with tables & fixtures
@ -55,8 +56,7 @@ class SqlTests(tests.TestCase, sql.Base):
def tearDown(self):
sql.ModelBase.metadata.drop_all(bind=self.engine)
self.engine.dispose()
sql.set_global_engine(None)
session.cleanup()
super(SqlTests, self).tearDown()

View File

@ -40,3 +40,37 @@ class ConfigTestCase(tests.TestCase):
CONF.auth.password)
self.assertEqual('keystone.auth.plugins.token.Token',
CONF.auth.token)
class DeprecatedTestCase(tests.TestCase):
"""Test using the original (deprecated) name for renamed options."""
def setUp(self):
super(DeprecatedTestCase, self).setUp()
self.config([tests.dirs.etc('keystone.conf.sample'),
tests.dirs.tests('test_overrides.conf'),
tests.dirs.tests('deprecated.conf'), ])
def test_sql(self):
# Options in [sql] were moved to [database] in Icehouse for the change
# to use oslo-incubator's db.sqlalchemy.sessions.
self.assertEqual(CONF.database.connection, 'sqlite://deprecated')
self.assertEqual(CONF.database.idle_timeout, 54321)
class DeprecatedOverrideTestCase(tests.TestCase):
"""Test using the deprecated AND new name for renamed options."""
def setUp(self):
super(DeprecatedOverrideTestCase, self).setUp()
self.config([tests.dirs.etc('keystone.conf.sample'),
tests.dirs.tests('test_overrides.conf'),
tests.dirs.tests('deprecated_override.conf'), ])
def test_sql(self):
# Options in [sql] were moved to [database] in Icehouse for the change
# to use oslo-incubator's db.sqlalchemy.sessions.
self.assertEqual(CONF.database.connection, 'sqlite://new')
self.assertEqual(CONF.database.idle_timeout, 65432)

View File

@ -21,6 +21,7 @@ from keystoneclient.contrib.ec2 import utils as ec2_utils
from keystone.common import sql
from keystone import config
from keystone.openstack.common.db.sqlalchemy import session
from keystone import tests
from keystone.tests import test_keystoneclient
@ -36,7 +37,7 @@ class KcMasterSqlTestCase(test_keystoneclient.KcMasterTestCase, sql.Base):
tests.dirs.tests('backend_sql.conf')])
self.load_backends()
self.engine = self.get_engine()
self.engine = session.get_engine()
sql.ModelBase.metadata.create_all(bind=self.engine)
def setUp(self):
@ -45,8 +46,7 @@ class KcMasterSqlTestCase(test_keystoneclient.KcMasterTestCase, sql.Base):
def tearDown(self):
sql.ModelBase.metadata.drop_all(bind=self.engine)
self.engine.dispose()
sql.set_global_engine(None)
session.cleanup()
super(KcMasterSqlTestCase, self).tearDown()
def test_endpoint_crud(self):

View File

@ -1,292 +0,0 @@
# Copyright 2013 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from sqlalchemy.exc import DisconnectionError
from keystone.common import sql
from keystone import tests
class CallbackMonitor:
    """Records whether a registered callback was invoked as expected.

    Stands in for a real callback: register ``call_this`` with the code
    under test, then invoke ``check`` to verify the expectation.  Once
    ``check`` has completed, later invocations of ``call_this`` become
    no-ops.
    """

    def __init__(self, expect_called=True, raise_=False):
        # Whether check() should demand that the callback was invoked.
        self.expect_called = expect_called
        # Flips to True on the first call_this() invocation.
        self.called = False
        # After check() completes, call_this() is ignored entirely.
        self._complete = False
        # When True, call_this() raises after recording the call.
        self._raise = raise_

    def call_this(self):
        """The callback under test; tolerates at most one invocation."""
        if self._complete:
            return
        if not self.expect_called:
            raise Exception("Did not expect callback.")
        if self.called:
            raise Exception("Callback already called.")
        self.called = True
        if self._raise:
            raise Exception("When called, raises.")

    def check(self):
        """Verify the expectation held, then deactivate the monitor."""
        if self.expect_called and not self.called:
            raise Exception("Expected function to be called.")
        self._complete = True
class TestGlobalEngine(tests.TestCase):
    """Tests for sql.set_global_engine() callback notification."""

    def tearDown(self):
        # Clear the global engine so state does not leak between tests.
        sql.set_global_engine(None)
        super(TestGlobalEngine, self).tearDown()

    def test_notify_on_set(self):
        # If call sql.set_global_engine(), notify callbacks get called.
        cb_mon = CallbackMonitor()
        sql.register_global_engine_callback(cb_mon.call_this)
        fake_engine = object()
        sql.set_global_engine(fake_engine)
        cb_mon.check()

    def test_multi_notify(self):
        # You can also set multiple notify callbacks and they each get called.
        cb_mon1 = CallbackMonitor()
        cb_mon2 = CallbackMonitor()
        sql.register_global_engine_callback(cb_mon1.call_this)
        sql.register_global_engine_callback(cb_mon2.call_this)
        fake_engine = object()
        sql.set_global_engine(fake_engine)
        cb_mon1.check()
        cb_mon2.check()

    def test_notify_once(self):
        # After a callback is called, it's not called again if set global
        # engine again.
        cb_mon = CallbackMonitor()
        sql.register_global_engine_callback(cb_mon.call_this)
        fake_engine = object()
        sql.set_global_engine(fake_engine)
        fake_engine = object()
        # Note that cb_mon.call_this would raise if it's called again.
        sql.set_global_engine(fake_engine)
        cb_mon.check()

    def test_set_same_engine(self):
        # If you set the global engine to the same engine, callbacks don't get
        # called.
        fake_engine = object()
        sql.set_global_engine(fake_engine)
        cb_mon = CallbackMonitor(expect_called=False)
        sql.register_global_engine_callback(cb_mon.call_this)
        # Note that cb_mon.call_this would raise if it's called.
        sql.set_global_engine(fake_engine)
        cb_mon.check()

    def test_notify_register_same(self):
        # If you register the same callback twice, only gets called once.
        cb_mon = CallbackMonitor()
        sql.register_global_engine_callback(cb_mon.call_this)
        sql.register_global_engine_callback(cb_mon.call_this)
        fake_engine = object()
        # Note that cb_mon.call_this would raise if it's called twice.
        sql.set_global_engine(fake_engine)
        cb_mon.check()

    def test_callback_throws(self):
        # If a callback function raises,
        # a) the caller doesn't know about it,
        # b) other callbacks are still called
        cb_mon1 = CallbackMonitor(raise_=True)
        cb_mon2 = CallbackMonitor()
        sql.register_global_engine_callback(cb_mon1.call_this)
        sql.register_global_engine_callback(cb_mon2.call_this)
        fake_engine = object()
        sql.set_global_engine(fake_engine)
        cb_mon1.check()
        cb_mon2.check()
class TestBase(tests.TestCase):
    """Tests for the sql.Base backend helper class."""

    def tearDown(self):
        # Clear the global engine so state does not leak between tests.
        sql.set_global_engine(None)
        super(TestBase, self).tearDown()

    def test_get_engine_global(self):
        # If call get_engine() twice, get the same global engine.
        base = sql.Base()
        engine1 = base.get_engine()
        self.assertIsNotNone(engine1)
        engine2 = base.get_engine()
        self.assertIs(engine1, engine2)

    def test_get_engine_not_global(self):
        # If call get_engine() twice, once with allow_global_engine=True
        # and once with allow_global_engine=False, get different engines.
        base = sql.Base()
        engine1 = base.get_engine()
        engine2 = base.get_engine(allow_global_engine=False)
        self.assertIsNot(engine1, engine2)

    def test_get_session(self):
        # autocommit and expire_on_commit flags to get_session() are passed on
        # to the session created.
        base = sql.Base()
        session = base.get_session(autocommit=False, expire_on_commit=True)
        self.assertFalse(session.autocommit)
        self.assertTrue(session.expire_on_commit)

    def test_get_session_invalidated(self):
        # If clear the global engine, a new engine is used for get_session().
        base = sql.Base()
        session1 = base.get_session()
        sql.set_global_engine(None)
        session2 = base.get_session()
        self.assertIsNot(session1.bind, session2.bind)
class FakeDbapiConn(object):
    """Simulates the dbapi_conn passed to mysql_on_checkout."""

    class OperationalError(Exception):
        """Stands in for the DB driver's OperationalError class."""

    class Cursor(object):
        """Cursor whose execute() optionally fails with a canned error."""

        def __init__(self, failwith=None):
            self._error = failwith

        def execute(self, sql):
            # Raise the canned exception, if one was configured.
            if self._error:
                raise self._error

    def __init__(self, failwith=None):
        # One cursor instance is shared by every cursor() call.
        self._shared_cursor = self.Cursor(failwith=failwith)

    def cursor(self):
        return self._shared_cursor
class TestMysqlCheckoutHandler(tests.TestCase):
    """Tests for the sql.mysql_on_checkout pool checkout listener."""

    def _do_on_checkout(self, failwith=None):
        # The rec/proxy arguments are unused by the handler, so None works.
        dbapi_conn = FakeDbapiConn(failwith=failwith)
        connection_rec = None
        connection_proxy = None
        sql.mysql_on_checkout(dbapi_conn, connection_rec, connection_proxy)

    def test_checkout_success(self):
        # If call mysql_on_checkout and query doesn't raise anything, then no
        # problems

        # If this doesn't raise then the test is successful.
        self._do_on_checkout()

    def test_disconnected(self):
        # If call mysql_on_checkout and query raises OperationalError with
        # specific errors, then raises DisconnectionError.

        # mysql_on_checkout should look for 2006 among others.
        disconnected_exception = FakeDbapiConn.OperationalError(2006)
        self.assertRaises(DisconnectionError,
                          self._do_on_checkout,
                          failwith=disconnected_exception)

    def test_error(self):
        # If call mysql_on_checkout and query raises an exception that doesn't
        # indicate disconnected, then the original error is raised.

        # mysql_on_checkout doesn't look for 2056
        other_exception = FakeDbapiConn.OperationalError(2056)
        self.assertRaises(FakeDbapiConn.OperationalError,
                          self._do_on_checkout,
                          failwith=other_exception)
class TestDb2CheckoutHandler(tests.TestCase):
    """Tests for the sql.db2_on_checkout pool checkout listener."""

    class FakeEngine(object):
        """Engine stub whose dialect recognizes one sentinel exception."""

        class Dialect():
            # The single exception instance treated as a disconnect.
            DISCONNECT_EXCEPTION = Exception()

            @classmethod
            def is_disconnect(cls, e, *args):
                # Identity check: only the sentinel counts as a disconnect.
                return (e is cls.DISCONNECT_EXCEPTION)

        dialect = Dialect()

    def _do_on_checkout(self, failwith=None):
        # The rec/proxy arguments are unused by the handler, so None works.
        engine = self.FakeEngine()
        dbapi_conn = FakeDbapiConn(failwith=failwith)
        connection_rec = None
        connection_proxy = None
        sql.db2_on_checkout(engine, dbapi_conn, connection_rec,
                            connection_proxy)

    def test_checkout_success(self):
        # If call db2_on_checkout and query doesn't raise anything, then no
        # problems

        # If this doesn't raise then the test is successful.
        self._do_on_checkout()

    def test_disconnected(self):
        # If call db2_on_checkout and query raises exception that engine
        # dialect says is a disconnect problem, then raises DisconnectionError.
        disconnected_exception = self.FakeEngine.Dialect.DISCONNECT_EXCEPTION
        self.assertRaises(DisconnectionError,
                          self._do_on_checkout,
                          failwith=disconnected_exception)

    def test_error(self):
        # If call db2_on_checkout and query raises an exception that engine
        # dialect says is not a disconnect problem, then the original error is
        # raised.

        # fake engine dialect doesn't look for this exception.
        class OtherException(Exception):
            pass

        other_exception = OtherException()
        self.assertRaises(OtherException,
                          self._do_on_checkout,
                          failwith=other_exception)

View File

@ -38,12 +38,12 @@ import uuid
from migrate.versioning import api as versioning_api
import sqlalchemy
from keystone.common import sql
from keystone.common.sql import migration
from keystone.common import utils
from keystone import config
from keystone import credential
from keystone import exception
from keystone.openstack.common.db.sqlalchemy import session
from keystone import tests
from keystone.tests import default_fixtures
@ -72,13 +72,10 @@ class SqlMigrateBase(tests.TestCase):
super(SqlMigrateBase, self).setUp()
self.config(self.config_files())
self.base = sql.Base()
# create and share a single sqlalchemy engine for testing
self.engine = self.base.get_engine(allow_global_engine=False)
sql.core.set_global_engine(self.engine)
self.Session = self.base.get_sessionmaker(engine=self.engine,
autocommit=False)
self.engine = session.get_engine()
self.Session = session.get_maker(self.engine, autocommit=False)
self.initialize_sql()
self.repo_path = migration.find_migrate_repo(self.repo_package())
@ -95,7 +92,7 @@ class SqlMigrateBase(tests.TestCase):
autoload=True)
self.downgrade(0)
table.drop(self.engine, checkfirst=True)
sql.core.set_global_engine(None)
session.cleanup()
super(SqlMigrateBase, self).tearDown()
def select_table(self, name):