From 7063585c60205fe031e1c74289d88886705cfb57 Mon Sep 17 00:00:00 2001 From: Doug Hellmann Date: Fri, 12 Dec 2014 14:23:13 -0500 Subject: [PATCH] Move files out of the namespace package Move the public API out of oslo.db to oslo_db. Retain the ability to import from the old namespace package for backwards compatibility for this release cycle. Blueprint: drop-namespace-packages Change-Id: Ie96b482b9fbcb1d85203ad35bb65c1f43e912a44 --- .testr.conf | 2 +- doc/source/api/api.rst | 4 +- doc/source/api/concurrency.rst | 4 +- doc/source/api/exception.rst | 4 +- doc/source/api/options.rst | 4 +- doc/source/api/sqlalchemy/index.rst | 2 +- doc/source/api/sqlalchemy/migration.rst | 4 +- doc/source/api/sqlalchemy/models.rst | 4 +- doc/source/api/sqlalchemy/provision.rst | 4 +- doc/source/api/sqlalchemy/session.rst | 4 +- doc/source/api/sqlalchemy/test_base.rst | 4 +- doc/source/api/sqlalchemy/test_migrations.rst | 4 +- doc/source/api/sqlalchemy/utils.rst | 4 +- oslo/db/__init__.py | 26 + oslo/db/api.py | 216 +--- oslo/db/concurrency.py | 68 +- oslo/db/exception.py | 160 +-- oslo/db/options.py | 229 +--- oslo/db/sqlalchemy/compat/__init__.py | 23 +- oslo/db/sqlalchemy/compat/utils.py | 17 +- oslo/db/sqlalchemy/exc_filters.py | 349 +----- oslo/db/sqlalchemy/migration.py | 165 +-- oslo/db/sqlalchemy/migration_cli/__init__.py | 18 + oslo/db/sqlalchemy/models.py | 115 +- oslo/db/sqlalchemy/provision.py | 494 +------- oslo/db/sqlalchemy/session.py | 834 +------------ oslo/db/sqlalchemy/test_base.py | 132 +- oslo/db/sqlalchemy/test_migrations.py | 600 +-------- oslo/db/sqlalchemy/utils.py | 999 +-------------- {tests => oslo_db}/__init__.py | 0 {oslo/db => oslo_db}/_i18n.py | 0 oslo_db/api.py | 229 ++++ oslo_db/concurrency.py | 81 ++ oslo_db/exception.py | 173 +++ oslo_db/options.py | 220 ++++ {tests => oslo_db}/sqlalchemy/__init__.py | 0 oslo_db/sqlalchemy/compat/__init__.py | 30 + .../sqlalchemy/compat/engine_connect.py | 2 +- .../sqlalchemy/compat/handle_error.py | 2 +- oslo_db/sqlalchemy/compat/utils.py | 26 + oslo_db/sqlalchemy/exc_filters.py | 358 ++++++ oslo_db/sqlalchemy/migration.py | 160 +++ .../sqlalchemy/migration_cli/README.rst | 0 oslo_db/sqlalchemy/migration_cli/__init__.py | 0 .../sqlalchemy/migration_cli/ext_alembic.py | 4 +- .../sqlalchemy/migration_cli/ext_base.py | 0 .../sqlalchemy/migration_cli/ext_migrate.py | 8 +- .../sqlalchemy/migration_cli/manager.py | 0 oslo_db/sqlalchemy/models.py | 128 ++ oslo_db/sqlalchemy/provision.py | 507 ++++++++ oslo_db/sqlalchemy/session.py | 847 +++++++++++++ oslo_db/sqlalchemy/test_base.py | 127 ++ oslo_db/sqlalchemy/test_migrations.py | 613 +++++++++ oslo_db/sqlalchemy/utils.py | 1012 +++++++++++++++ oslo_db/tests/__init__.py | 0 {tests => oslo_db/tests}/base.py | 0 oslo_db/tests/old_import_api/__init__.py | 0 oslo_db/tests/old_import_api/base.py | 53 + .../old_import_api/sqlalchemy/__init__.py | 0 .../sqlalchemy/test_engine_connect.py | 68 + .../sqlalchemy/test_exc_filters.py | 833 +++++++++++++ .../sqlalchemy/test_handle_error.py | 194 +++ .../sqlalchemy/test_migrate_cli.py | 2 +- .../sqlalchemy/test_migration_common.py | 174 +++ .../sqlalchemy/test_migrations.py | 0 .../old_import_api}/sqlalchemy/test_models.py | 0 .../sqlalchemy/test_options.py | 2 +- .../sqlalchemy/test_sqlalchemy.py | 554 +++++++++ .../old_import_api/sqlalchemy/test_utils.py | 1093 +++++++++++++++++ .../tests/old_import_api}/test_api.py | 6 +- .../tests/old_import_api}/test_concurrency.py | 8 +- oslo_db/tests/old_import_api/test_warning.py | 61 + .../tests/old_import_api}/utils.py | 0 oslo_db/tests/sqlalchemy/__init__.py | 0 .../tests}/sqlalchemy/test_engine_connect.py | 4 +- .../tests}/sqlalchemy/test_exc_filters.py | 12 +- .../tests}/sqlalchemy/test_handle_error.py | 8 +- oslo_db/tests/sqlalchemy/test_migrate_cli.py | 222 ++++ .../sqlalchemy/test_migration_common.py | 8 +- oslo_db/tests/sqlalchemy/test_migrations.py | 309 +++++ oslo_db/tests/sqlalchemy/test_models.py | 146 +++ oslo_db/tests/sqlalchemy/test_options.py | 127 ++ .../tests}/sqlalchemy/test_sqlalchemy.py | 40 +- .../tests}/sqlalchemy/test_utils.py | 18 +- oslo_db/tests/test_api.py | 177 +++ oslo_db/tests/test_concurrency.py | 108 ++ oslo_db/tests/utils.py | 40 + setup.cfg | 9 +- tools/run_cross_tests.sh | 5 + tox.ini | 2 +- 90 files changed, 8865 insertions(+), 4438 deletions(-) rename {tests => oslo_db}/__init__.py (100%) rename {oslo/db => oslo_db}/_i18n.py (100%) create mode 100644 oslo_db/api.py create mode 100644 oslo_db/concurrency.py create mode 100644 oslo_db/exception.py create mode 100644 oslo_db/options.py rename {tests => oslo_db}/sqlalchemy/__init__.py (100%) create mode 100644 oslo_db/sqlalchemy/compat/__init__.py rename {oslo/db => oslo_db}/sqlalchemy/compat/engine_connect.py (97%) rename {oslo/db => oslo_db}/sqlalchemy/compat/handle_error.py (99%) create mode 100644 oslo_db/sqlalchemy/compat/utils.py create mode 100644 oslo_db/sqlalchemy/exc_filters.py create mode 100644 oslo_db/sqlalchemy/migration.py rename {oslo/db => oslo_db}/sqlalchemy/migration_cli/README.rst (100%) create mode 100644 oslo_db/sqlalchemy/migration_cli/__init__.py rename {oslo/db => oslo_db}/sqlalchemy/migration_cli/ext_alembic.py (96%) rename {oslo/db => oslo_db}/sqlalchemy/migration_cli/ext_base.py (100%) rename {oslo/db => oslo_db}/sqlalchemy/migration_cli/ext_migrate.py (92%) rename {oslo/db => oslo_db}/sqlalchemy/migration_cli/manager.py (100%) create mode 100644 oslo_db/sqlalchemy/models.py create mode 100644 oslo_db/sqlalchemy/provision.py create mode 100644 oslo_db/sqlalchemy/session.py create mode 100644 oslo_db/sqlalchemy/test_base.py create mode 100644 oslo_db/sqlalchemy/test_migrations.py create mode 100644 oslo_db/sqlalchemy/utils.py create mode 100644 oslo_db/tests/__init__.py rename {tests => oslo_db/tests}/base.py (100%) create mode 100644 oslo_db/tests/old_import_api/__init__.py create mode 100644 oslo_db/tests/old_import_api/base.py create mode 100644 oslo_db/tests/old_import_api/sqlalchemy/__init__.py create mode 100644 oslo_db/tests/old_import_api/sqlalchemy/test_engine_connect.py create mode 100644 oslo_db/tests/old_import_api/sqlalchemy/test_exc_filters.py create mode 100644 oslo_db/tests/old_import_api/sqlalchemy/test_handle_error.py rename {tests => oslo_db/tests/old_import_api}/sqlalchemy/test_migrate_cli.py (99%) create mode 100644 oslo_db/tests/old_import_api/sqlalchemy/test_migration_common.py rename {tests => oslo_db/tests/old_import_api}/sqlalchemy/test_migrations.py (100%) rename {tests => oslo_db/tests/old_import_api}/sqlalchemy/test_models.py (100%) rename {tests => oslo_db/tests/old_import_api}/sqlalchemy/test_options.py (98%) create mode 100644 oslo_db/tests/old_import_api/sqlalchemy/test_sqlalchemy.py create mode 100644 oslo_db/tests/old_import_api/sqlalchemy/test_utils.py rename {tests => oslo_db/tests/old_import_api}/test_api.py (97%) rename {tests => oslo_db/tests/old_import_api}/test_concurrency.py (95%) create mode 100644 oslo_db/tests/old_import_api/test_warning.py rename {tests => oslo_db/tests/old_import_api}/utils.py (100%) create mode 100644 oslo_db/tests/sqlalchemy/__init__.py rename {tests => oslo_db/tests}/sqlalchemy/test_engine_connect.py (93%) rename {tests => oslo_db/tests}/sqlalchemy/test_exc_filters.py (99%) rename {tests => oslo_db/tests}/sqlalchemy/test_handle_error.py (97%) create mode 100644 oslo_db/tests/sqlalchemy/test_migrate_cli.py rename {tests => oslo_db/tests}/sqlalchemy/test_migration_common.py (98%) create mode 100644 oslo_db/tests/sqlalchemy/test_migrations.py create mode 100644 oslo_db/tests/sqlalchemy/test_models.py create mode 100644 oslo_db/tests/sqlalchemy/test_options.py rename {tests => oslo_db/tests}/sqlalchemy/test_sqlalchemy.py (94%) rename {tests => oslo_db/tests}/sqlalchemy/test_utils.py (99%) create mode 100644 oslo_db/tests/test_api.py create mode 100644 oslo_db/tests/test_concurrency.py create mode 100644 oslo_db/tests/utils.py diff --git a/.testr.conf b/.testr.conf index 35d9ba43..c9be815f 100644 --- a/.testr.conf +++ b/.testr.conf @@ -2,6 +2,6 @@ test_command=OS_STDOUT_CAPTURE=${OS_STDOUT_CAPTURE:-1} \ OS_STDERR_CAPTURE=${OS_STDERR_CAPTURE:-1} \ OS_TEST_TIMEOUT=${OS_TEST_TIMEOUT:-60} \ - ${PYTHON:-python} -m subunit.run discover -t ./ ./tests $LISTOPT $IDOPTION + ${PYTHON:-python} -m subunit.run discover -t ./ ./oslo_db/tests $LISTOPT $IDOPTION test_id_option=--load-list $IDFILE test_list_option=--list diff --git a/doc/source/api/api.rst b/doc/source/api/api.rst index 1591cc1c..d58cbaff 100644 --- a/doc/source/api/api.rst +++ b/doc/source/api/api.rst @@ -1,8 +1,8 @@ ============= - oslo.db.api + oslo_db.api ============= -.. automodule:: oslo.db.api +.. automodule:: oslo_db.api :members: :undoc-members: :show-inheritance: diff --git a/doc/source/api/concurrency.rst b/doc/source/api/concurrency.rst index 527883a5..d5775e75 100644 --- a/doc/source/api/concurrency.rst +++ b/doc/source/api/concurrency.rst @@ -1,8 +1,8 @@ ===================== - oslo.db.concurrency + oslo_db.concurrency ===================== -.. automodule:: oslo.db.concurrency +.. automodule:: oslo_db.concurrency :members: :undoc-members: :show-inheritance: diff --git a/doc/source/api/exception.rst b/doc/source/api/exception.rst index 12b0bf04..329525bf 100644 --- a/doc/source/api/exception.rst +++ b/doc/source/api/exception.rst @@ -1,8 +1,8 @@ =================== - oslo.db.exception + oslo_db.exception =================== -.. automodule:: oslo.db.exception +.. automodule:: oslo_db.exception :members: :undoc-members: :show-inheritance: diff --git a/doc/source/api/options.rst b/doc/source/api/options.rst index 40d5a833..4216cda0 100644 --- a/doc/source/api/options.rst +++ b/doc/source/api/options.rst @@ -1,8 +1,8 @@ ================= - oslo.db.options + oslo_db.options ================= -.. automodule:: oslo.db.options +.. automodule:: oslo_db.options :members: :undoc-members: :show-inheritance: diff --git a/doc/source/api/sqlalchemy/index.rst b/doc/source/api/sqlalchemy/index.rst index 62fc2b2a..5e2496e6 100644 --- a/doc/source/api/sqlalchemy/index.rst +++ b/doc/source/api/sqlalchemy/index.rst @@ -1,5 +1,5 @@ ==================== - oslo.db.sqlalchemy + oslo_db.sqlalchemy ==================== .. toctree:: diff --git a/doc/source/api/sqlalchemy/migration.rst b/doc/source/api/sqlalchemy/migration.rst index 2355cbc3..6c7ee469 100644 --- a/doc/source/api/sqlalchemy/migration.rst +++ b/doc/source/api/sqlalchemy/migration.rst @@ -1,8 +1,8 @@ ============================== - oslo.db.sqlalchemy.migration + oslo_db.sqlalchemy.migration ============================== -.. automodule:: oslo.db.sqlalchemy.migration +.. automodule:: oslo_db.sqlalchemy.migration :members: :undoc-members: :show-inheritance: diff --git a/doc/source/api/sqlalchemy/models.rst b/doc/source/api/sqlalchemy/models.rst index b9023204..5c6f4619 100644 --- a/doc/source/api/sqlalchemy/models.rst +++ b/doc/source/api/sqlalchemy/models.rst @@ -1,8 +1,8 @@ =========================== - oslo.db.sqlalchemy.models + oslo_db.sqlalchemy.models =========================== -.. automodule:: oslo.db.sqlalchemy.models +.. automodule:: oslo_db.sqlalchemy.models :members: :undoc-members: :show-inheritance: diff --git a/doc/source/api/sqlalchemy/provision.rst b/doc/source/api/sqlalchemy/provision.rst index 7d003d7d..0ae5ff1d 100644 --- a/doc/source/api/sqlalchemy/provision.rst +++ b/doc/source/api/sqlalchemy/provision.rst @@ -1,8 +1,8 @@ ============================== - oslo.db.sqlalchemy.provision + oslo_db.sqlalchemy.provision ============================== -.. automodule:: oslo.db.sqlalchemy.provision +.. automodule:: oslo_db.sqlalchemy.provision :members: :undoc-members: :show-inheritance: diff --git a/doc/source/api/sqlalchemy/session.rst b/doc/source/api/sqlalchemy/session.rst index 14e6f5ac..b6635bdc 100644 --- a/doc/source/api/sqlalchemy/session.rst +++ b/doc/source/api/sqlalchemy/session.rst @@ -1,8 +1,8 @@ ============================ - oslo.db.sqlalchemy.session + oslo_db.sqlalchemy.session ============================ -.. automodule:: oslo.db.sqlalchemy.session +.. automodule:: oslo_db.sqlalchemy.session :members: :undoc-members: :show-inheritance: diff --git a/doc/source/api/sqlalchemy/test_base.rst b/doc/source/api/sqlalchemy/test_base.rst index 77941705..0eb6878b 100644 --- a/doc/source/api/sqlalchemy/test_base.rst +++ b/doc/source/api/sqlalchemy/test_base.rst @@ -1,8 +1,8 @@ ============================== - oslo.db.sqlalchemy.test_base + oslo_db.sqlalchemy.test_base ============================== -.. automodule:: oslo.db.sqlalchemy.test_base +.. automodule:: oslo_db.sqlalchemy.test_base :members: :undoc-members: :show-inheritance: diff --git a/doc/source/api/sqlalchemy/test_migrations.rst b/doc/source/api/sqlalchemy/test_migrations.rst index 4b3c81c7..afb24ffa 100644 --- a/doc/source/api/sqlalchemy/test_migrations.rst +++ b/doc/source/api/sqlalchemy/test_migrations.rst @@ -1,8 +1,8 @@ ==================================== - oslo.db.sqlalchemy.test_migrations + oslo_db.sqlalchemy.test_migrations ==================================== -.. automodule:: oslo.db.sqlalchemy.test_migrations +.. automodule:: oslo_db.sqlalchemy.test_migrations :members: :undoc-members: :show-inheritance: diff --git a/doc/source/api/sqlalchemy/utils.rst b/doc/source/api/sqlalchemy/utils.rst index cccc93f4..e6576364 100644 --- a/doc/source/api/sqlalchemy/utils.rst +++ b/doc/source/api/sqlalchemy/utils.rst @@ -1,8 +1,8 @@ ========================== - oslo.db.sqlalchemy.utils + oslo_db.sqlalchemy.utils ========================== -.. automodule:: oslo.db.sqlalchemy.utils +.. automodule:: oslo_db.sqlalchemy.utils :members: :undoc-members: :show-inheritance: diff --git a/oslo/db/__init__.py b/oslo/db/__init__.py index e69de29b..73e54f3d 100644 --- a/oslo/db/__init__.py +++ b/oslo/db/__init__.py @@ -0,0 +1,26 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import warnings + + +def deprecated(): + new_name = __name__.replace('.', '_') + warnings.warn( + ('The oslo namespace package is deprecated. Please use %s instead.' % + new_name), + DeprecationWarning, + stacklevel=3, + ) + + +deprecated() diff --git a/oslo/db/api.py b/oslo/db/api.py index 906e88d8..c0453b52 100644 --- a/oslo/db/api.py +++ b/oslo/db/api.py @@ -1,4 +1,3 @@ -# Copyright (c) 2013 Rackspace Hosting # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -13,217 +12,4 @@ # License for the specific language governing permissions and limitations # under the License. -""" -================================= -Multiple DB API backend support. -================================= - -A DB backend module should implement a method named 'get_backend' which -takes no arguments. The method can return any object that implements DB -API methods. -""" - -import logging -import threading -import time - -from oslo.utils import importutils -import six - -from oslo.db._i18n import _LE -from oslo.db import exception -from oslo.db import options - - -LOG = logging.getLogger(__name__) - - -def safe_for_db_retry(f): - """Indicate api method as safe for re-connection to database. - - Database connection retries will be enabled for the decorated api method. - Database connection failure can have many causes, which can be temporary. - In such cases retry may increase the likelihood of connection. - - Usage:: - - @safe_for_db_retry - def api_method(self): - self.engine.connect() - - - :param f: database api method. - :type f: function. - """ - f.__dict__['enable_retry'] = True - return f - - -class wrap_db_retry(object): - """Decorator class. Retry db.api methods, if DBConnectionError() raised. - - Retry decorated db.api methods. If we enabled `use_db_reconnect` - in config, this decorator will be applied to all db.api functions, - marked with @safe_for_db_retry decorator. - Decorator catches DBConnectionError() and retries function in a - loop until it succeeds, or until maximum retries count will be reached. - - Keyword arguments: - - :param retry_interval: seconds between transaction retries - :type retry_interval: int - - :param max_retries: max number of retries before an error is raised - :type max_retries: int - - :param inc_retry_interval: determine increase retry interval or not - :type inc_retry_interval: bool - - :param max_retry_interval: max interval value between retries - :type max_retry_interval: int - """ - - def __init__(self, retry_interval, max_retries, inc_retry_interval, - max_retry_interval): - super(wrap_db_retry, self).__init__() - - self.retry_interval = retry_interval - self.max_retries = max_retries - self.inc_retry_interval = inc_retry_interval - self.max_retry_interval = max_retry_interval - - def __call__(self, f): - @six.wraps(f) - def wrapper(*args, **kwargs): - next_interval = self.retry_interval - remaining = self.max_retries - - while True: - try: - return f(*args, **kwargs) - except exception.DBConnectionError as e: - if remaining == 0: - LOG.exception(_LE('DB exceeded retry limit.')) - raise exception.DBError(e) - if remaining != -1: - remaining -= 1 - LOG.exception(_LE('DB connection error.')) - # NOTE(vsergeyev): We are using patched time module, so - # this effectively yields the execution - # context to another green thread. - time.sleep(next_interval) - if self.inc_retry_interval: - next_interval = min( - next_interval * 2, - self.max_retry_interval - ) - return wrapper - - -class DBAPI(object): - """Initialize the chosen DB API backend. - - After initialization API methods is available as normal attributes of - ``DBAPI`` subclass. Database API methods are supposed to be called as - DBAPI instance methods. - - :param backend_name: name of the backend to load - :type backend_name: str - - :param backend_mapping: backend name -> module/class to load mapping - :type backend_mapping: dict - :default backend_mapping: None - - :param lazy: load the DB backend lazily on the first DB API method call - :type lazy: bool - :default lazy: False - - :keyword use_db_reconnect: retry DB transactions on disconnect or not - :type use_db_reconnect: bool - - :keyword retry_interval: seconds between transaction retries - :type retry_interval: int - - :keyword inc_retry_interval: increase retry interval or not - :type inc_retry_interval: bool - - :keyword max_retry_interval: max interval value between retries - :type max_retry_interval: int - - :keyword max_retries: max number of retries before an error is raised - :type max_retries: int - """ - - def __init__(self, backend_name, backend_mapping=None, lazy=False, - **kwargs): - - self._backend = None - self._backend_name = backend_name - self._backend_mapping = backend_mapping or {} - self._lock = threading.Lock() - - if not lazy: - self._load_backend() - - self.use_db_reconnect = kwargs.get('use_db_reconnect', False) - self.retry_interval = kwargs.get('retry_interval', 1) - self.inc_retry_interval = kwargs.get('inc_retry_interval', True) - self.max_retry_interval = kwargs.get('max_retry_interval', 10) - self.max_retries = kwargs.get('max_retries', 20) - - def _load_backend(self): - with self._lock: - if not self._backend: - # Import the untranslated name if we don't have a mapping - backend_path = self._backend_mapping.get(self._backend_name, - self._backend_name) - LOG.debug('Loading backend %(name)r from %(path)r', - {'name': self._backend_name, - 'path': backend_path}) - backend_mod = importutils.import_module(backend_path) - self._backend = backend_mod.get_backend() - - def __getattr__(self, key): - if not self._backend: - self._load_backend() - - attr = getattr(self._backend, key) - if not hasattr(attr, '__call__'): - return attr - # NOTE(vsergeyev): If `use_db_reconnect` option is set to True, retry - # DB API methods, decorated with @safe_for_db_retry - # on disconnect. - if self.use_db_reconnect and hasattr(attr, 'enable_retry'): - attr = wrap_db_retry( - retry_interval=self.retry_interval, - max_retries=self.max_retries, - inc_retry_interval=self.inc_retry_interval, - max_retry_interval=self.max_retry_interval)(attr) - - return attr - - @classmethod - def from_config(cls, conf, backend_mapping=None, lazy=False): - """Initialize DBAPI instance given a config instance. - - :param conf: oslo.config config instance - :type conf: oslo.config.cfg.ConfigOpts - - :param backend_mapping: backend name -> module/class to load mapping - :type backend_mapping: dict - - :param lazy: load the DB backend lazily on the first DB API method call - :type lazy: bool - - """ - - conf.register_opts(options.database_opts, 'database') - - return cls(backend_name=conf.database.backend, - backend_mapping=backend_mapping, - lazy=lazy, - use_db_reconnect=conf.database.use_db_reconnect, - retry_interval=conf.database.db_retry_interval, - inc_retry_interval=conf.database.db_inc_retry_interval, - max_retry_interval=conf.database.db_max_retry_interval, - max_retries=conf.database.db_max_retries) +from oslo_db.api import * # noqa diff --git a/oslo/db/concurrency.py b/oslo/db/concurrency.py index c97690f3..a59f58ab 100644 --- a/oslo/db/concurrency.py +++ b/oslo/db/concurrency.py @@ -1,4 +1,3 @@ -# Copyright 2014 Mirantis.inc # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -13,69 +12,4 @@ # License for the specific language governing permissions and limitations # under the License. -import copy -import logging -import threading - -from oslo.config import cfg - -from oslo.db._i18n import _LE -from oslo.db import api - - -LOG = logging.getLogger(__name__) - -tpool_opts = [ - cfg.BoolOpt('use_tpool', - default=False, - deprecated_name='dbapi_use_tpool', - deprecated_group='DEFAULT', - help='Enable the experimental use of thread pooling for ' - 'all DB API calls'), -] - - -class TpoolDbapiWrapper(object): - """DB API wrapper class. - - This wraps the oslo DB API with an option to be able to use eventlet's - thread pooling. Since the CONF variable may not be loaded at the time - this class is instantiated, we must look at it on the first DB API call. - """ - - def __init__(self, conf, backend_mapping): - self._db_api = None - self._backend_mapping = backend_mapping - self._conf = conf - self._conf.register_opts(tpool_opts, 'database') - self._lock = threading.Lock() - - @property - def _api(self): - if not self._db_api: - with self._lock: - if not self._db_api: - db_api = api.DBAPI.from_config( - conf=self._conf, backend_mapping=self._backend_mapping) - if self._conf.database.use_tpool: - try: - from eventlet import tpool - except ImportError: - LOG.exception(_LE("'eventlet' is required for " - "TpoolDbapiWrapper.")) - raise - self._db_api = tpool.Proxy(db_api) - else: - self._db_api = db_api - return self._db_api - - def __getattr__(self, key): - return getattr(self._api, key) - - -def list_opts(): - """Returns a list of oslo.config options available in this module. - - :returns: a list of (group_name, opts) tuples - """ - return [('database', copy.deepcopy(tpool_opts))] +from oslo_db.concurrency import * # noqa diff --git a/oslo/db/exception.py b/oslo/db/exception.py index a96ad767..d29a4b18 100644 --- a/oslo/db/exception.py +++ b/oslo/db/exception.py @@ -1,5 +1,3 @@ -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -14,160 +12,4 @@ # License for the specific language governing permissions and limitations # under the License. -"""DB related custom exceptions. - -Custom exceptions intended to determine the causes of specific database -errors. This module provides more generic exceptions than the database-specific -driver libraries, and so users of oslo.db can catch these no matter which -database the application is using. Most of the exceptions are wrappers. Wrapper -exceptions take an original exception as positional argument and keep it for -purposes of deeper debug. - -Example:: - - try: - statement(arg) - except sqlalchemy.exc.OperationalError as e: - raise DBDuplicateEntry(e) - - -This is useful to determine more specific error cases further at execution, -when you need to add some extra information to an error message. Wrapper -exceptions takes care about original error message displaying to not to loose -low level cause of an error. All the database api exceptions wrapped into -the specific exceptions provided belove. - - -Please use only database related custom exceptions with database manipulations -with `try/except` statement. This is required for consistent handling of -database errors. -""" - -import six - -from oslo.db._i18n import _ - - -class DBError(Exception): - - """Base exception for all custom database exceptions. - - :kwarg inner_exception: an original exception which was wrapped with - DBError or its subclasses. - """ - - def __init__(self, inner_exception=None): - self.inner_exception = inner_exception - super(DBError, self).__init__(six.text_type(inner_exception)) - - -class DBDuplicateEntry(DBError): - """Duplicate entry at unique column error. - - Raised when made an attempt to write to a unique column the same entry as - existing one. :attr: `columns` available on an instance of the exception - and could be used at error handling:: - - try: - instance_type_ref.save() - except DBDuplicateEntry as e: - if 'colname' in e.columns: - # Handle error. - - :kwarg columns: a list of unique columns have been attempted to write a - duplicate entry. - :type columns: list - :kwarg value: a value which has been attempted to write. The value will - be None, if we can't extract it for a particular database backend. Only - MySQL and PostgreSQL 9.x are supported right now. - """ - def __init__(self, columns=None, inner_exception=None, value=None): - self.columns = columns or [] - self.value = value - super(DBDuplicateEntry, self).__init__(inner_exception) - - -class DBReferenceError(DBError): - """Foreign key violation error. - - :param table: a table name in which the reference is directed. - :type table: str - :param constraint: a problematic constraint name. - :type constraint: str - :param key: a broken reference key name. - :type key: str - :param key_table: a table name which contains the key. - :type key_table: str - """ - - def __init__(self, table, constraint, key, key_table, - inner_exception=None): - self.table = table - self.constraint = constraint - self.key = key - self.key_table = key_table - super(DBReferenceError, self).__init__(inner_exception) - - -class DBDeadlock(DBError): - - """Database dead lock error. - - Deadlock is a situation that occurs when two or more different database - sessions have some data locked, and each database session requests a lock - on the data that another, different, session has already locked. - """ - - def __init__(self, inner_exception=None): - super(DBDeadlock, self).__init__(inner_exception) - - -class DBInvalidUnicodeParameter(Exception): - - """Database unicode error. - - Raised when unicode parameter is passed to a database - without encoding directive. - """ - - message = _("Invalid Parameter: " - "Encoding directive wasn't provided.") - - -class DbMigrationError(DBError): - - """Wrapped migration specific exception. - - Raised when migrations couldn't be completed successfully. - """ - - def __init__(self, message=None): - super(DbMigrationError, self).__init__(message) - - -class DBConnectionError(DBError): - - """Wrapped connection specific exception. - - Raised when database connection is failed. - """ - - pass - - -class InvalidSortKey(Exception): - """A sort key destined for database query usage is invalid.""" - - message = _("Sort key supplied was not valid.") - - -class ColumnError(Exception): - """Error raised when no column or an invalid column is found.""" - - -class BackendNotAvailable(Exception): - """Error raised when a particular database backend is not available - - within a test suite. - - """ +from oslo_db.exception import * # noqa diff --git a/oslo/db/options.py b/oslo/db/options.py index b8550644..20884af0 100644 --- a/oslo/db/options.py +++ b/oslo/db/options.py @@ -1,220 +1,15 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at +# All Rights Reserved. # -# http://www.apache.org/licenses/LICENSE-2.0 +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. -import copy - -from oslo.config import cfg - - -database_opts = [ - cfg.StrOpt('sqlite_db', - deprecated_group='DEFAULT', - default='oslo.sqlite', - help='The file name to use with SQLite.'), - cfg.BoolOpt('sqlite_synchronous', - deprecated_group='DEFAULT', - default=True, - help='If True, SQLite uses synchronous mode.'), - cfg.StrOpt('backend', - default='sqlalchemy', - deprecated_name='db_backend', - deprecated_group='DEFAULT', - help='The back end to use for the database.'), - cfg.StrOpt('connection', - help='The SQLAlchemy connection string to use to connect to ' - 'the database.', - secret=True, - deprecated_opts=[cfg.DeprecatedOpt('sql_connection', - group='DEFAULT'), - cfg.DeprecatedOpt('sql_connection', - group='DATABASE'), - cfg.DeprecatedOpt('connection', - group='sql'), ]), - cfg.StrOpt('slave_connection', - secret=True, - help='The SQLAlchemy connection string to use to connect to the' - ' slave database.'), - cfg.StrOpt('mysql_sql_mode', - default='TRADITIONAL', - help='The SQL mode to be used for MySQL sessions. ' - 'This option, including the default, overrides any ' - 'server-set SQL mode. To use whatever SQL mode ' - 'is set by the server configuration, ' - 'set this to no value. Example: mysql_sql_mode='), - cfg.IntOpt('idle_timeout', - default=3600, - deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout', - group='DEFAULT'), - cfg.DeprecatedOpt('sql_idle_timeout', - group='DATABASE'), - cfg.DeprecatedOpt('idle_timeout', - group='sql')], - help='Timeout before idle SQL connections are reaped.'), - cfg.IntOpt('min_pool_size', - default=1, - deprecated_opts=[cfg.DeprecatedOpt('sql_min_pool_size', - group='DEFAULT'), - cfg.DeprecatedOpt('sql_min_pool_size', - group='DATABASE')], - help='Minimum number of SQL connections to keep open in a ' - 'pool.'), - cfg.IntOpt('max_pool_size', - deprecated_opts=[cfg.DeprecatedOpt('sql_max_pool_size', - group='DEFAULT'), - cfg.DeprecatedOpt('sql_max_pool_size', - group='DATABASE')], - help='Maximum number of SQL connections to keep open in a ' - 'pool.'), - cfg.IntOpt('max_retries', - default=10, - deprecated_opts=[cfg.DeprecatedOpt('sql_max_retries', - group='DEFAULT'), - cfg.DeprecatedOpt('sql_max_retries', - group='DATABASE')], - help='Maximum number of database connection retries ' - 'during startup. Set to -1 to specify an infinite ' - 'retry count.'), - cfg.IntOpt('retry_interval', - default=10, - deprecated_opts=[cfg.DeprecatedOpt('sql_retry_interval', - group='DEFAULT'), - cfg.DeprecatedOpt('reconnect_interval', - group='DATABASE')], - help='Interval between retries of opening a SQL connection.'), - cfg.IntOpt('max_overflow', - deprecated_opts=[cfg.DeprecatedOpt('sql_max_overflow', - group='DEFAULT'), - cfg.DeprecatedOpt('sqlalchemy_max_overflow', - group='DATABASE')], - help='If set, use this value for max_overflow with ' - 'SQLAlchemy.'), - cfg.IntOpt('connection_debug', - default=0, - deprecated_opts=[cfg.DeprecatedOpt('sql_connection_debug', - group='DEFAULT')], - help='Verbosity of SQL debugging information: 0=None, ' - '100=Everything.'), - cfg.BoolOpt('connection_trace', - default=False, - deprecated_opts=[cfg.DeprecatedOpt('sql_connection_trace', - group='DEFAULT')], - help='Add Python stack traces to SQL as comment strings.'), - cfg.IntOpt('pool_timeout', - deprecated_opts=[cfg.DeprecatedOpt('sqlalchemy_pool_timeout', - group='DATABASE')], - help='If set, use this value for pool_timeout with ' - 'SQLAlchemy.'), - cfg.BoolOpt('use_db_reconnect', - default=False, - help='Enable the experimental use of database reconnect ' - 'on connection lost.'), - cfg.IntOpt('db_retry_interval', - default=1, - help='Seconds between database connection retries.'), - cfg.BoolOpt('db_inc_retry_interval', - default=True, - help='If True, increases the interval between database ' - 'connection retries up to db_max_retry_interval.'), - cfg.IntOpt('db_max_retry_interval', - default=10, - help='If db_inc_retry_interval is set, the ' - 'maximum seconds between database connection retries.'), - cfg.IntOpt('db_max_retries', - default=20, - help='Maximum database connection retries before error is ' - 'raised. Set to -1 to specify an infinite retry ' - 'count.'), -] - - -def set_defaults(conf, connection=None, sqlite_db=None, - max_pool_size=None, max_overflow=None, - pool_timeout=None): - """Set defaults for configuration variables. - - Overrides default options values. - - :param conf: Config instance specified to set default options in it. Using - of instances instead of a global config object prevents conflicts between - options declaration. - :type conf: oslo.config.cfg.ConfigOpts instance. - - :keyword connection: SQL connection string. - Valid SQLite URL forms are: - * sqlite:///:memory: (or, sqlite://) - * sqlite:///relative/path/to/file.db - * sqlite:////absolute/path/to/file.db - :type connection: str - - :keyword sqlite_db: path to SQLite database file. - :type sqlite_db: str - - :keyword max_pool_size: maximum connections pool size. The size of the pool - to be maintained, defaults to 5, will be used if value of the parameter is - `None`. This is the largest number of connections that will be kept - persistently in the pool. Note that the pool begins with no connections; - once this number of connections is requested, that number of connections - will remain. - :type max_pool_size: int - :default max_pool_size: None - - :keyword max_overflow: The maximum overflow size of the pool. When the - number of checked-out connections reaches the size set in pool_size, - additional connections will be returned up to this limit. When those - additional connections are returned to the pool, they are disconnected and - discarded. It follows then that the total number of simultaneous - connections the pool will allow is pool_size + max_overflow, and the total - number of "sleeping" connections the pool will allow is pool_size. - max_overflow can be set to -1 to indicate no overflow limit; no limit will - be placed on the total number of concurrent connections. Defaults to 10, - will be used if value of the parameter in `None`. - :type max_overflow: int - :default max_overflow: None - - :keyword pool_timeout: The number of seconds to wait before giving up on - returning a connection. Defaults to 30, will be used if value of the - parameter is `None`. - :type pool_timeout: int - :default pool_timeout: None - """ - - conf.register_opts(database_opts, group='database') - - if connection is not None: - conf.set_default('connection', connection, group='database') - if sqlite_db is not None: - conf.set_default('sqlite_db', sqlite_db, group='database') - if max_pool_size is not None: - conf.set_default('max_pool_size', max_pool_size, group='database') - if max_overflow is not None: - conf.set_default('max_overflow', max_overflow, group='database') - if pool_timeout is not None: - conf.set_default('pool_timeout', pool_timeout, group='database') - - -def list_opts(): - """Returns a list of oslo.config options available in the library. - - The returned list includes all oslo.config options which may be registered - at runtime by the library. - - Each element of the list is a tuple. The first element is the name of the - group under which the list of elements in the second element will be - registered. A group name of None corresponds to the [DEFAULT] group in - config files. - - The purpose of this is to allow tools like the Oslo sample config file - generator to discover the options exposed to users by this library. - - :returns: a list of (group_name, opts) tuples - """ - return [('database', copy.deepcopy(database_opts))] +from oslo_db.options import * # noqa diff --git a/oslo/db/sqlalchemy/compat/__init__.py b/oslo/db/sqlalchemy/compat/__init__.py index 2ffe2933..86436ec4 100644 --- a/oslo/db/sqlalchemy/compat/__init__.py +++ b/oslo/db/sqlalchemy/compat/__init__.py @@ -1,3 +1,5 @@ +# All Rights Reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at @@ -9,22 +11,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -"""compatiblity extensions for SQLAlchemy versions. -Elements within this module provide SQLAlchemy features that have been -added at some point but for which oslo.db provides a compatible versions -for previous SQLAlchemy versions. - -""" -from oslo.db.sqlalchemy.compat import engine_connect as _e_conn -from oslo.db.sqlalchemy.compat import handle_error as _h_err - -# trying to get: "from oslo.db.sqlalchemy import compat; compat.handle_error" -# flake8 won't let me import handle_error directly -engine_connect = _e_conn.engine_connect -handle_error = _h_err.handle_error -handle_connect_context = _h_err.handle_connect_context - -__all__ = [ - 'engine_connect', 'handle_error', - 'handle_connect_context'] +from oslo_db.sqlalchemy.compat import engine_connect # noqa +from oslo_db.sqlalchemy.compat import handle_error # noqa +from oslo_db.sqlalchemy.compat import utils # noqa diff --git a/oslo/db/sqlalchemy/compat/utils.py b/oslo/db/sqlalchemy/compat/utils.py index fa6c3e77..b910a16f 100644 --- a/oslo/db/sqlalchemy/compat/utils.py +++ b/oslo/db/sqlalchemy/compat/utils.py @@ -1,3 +1,5 @@ +# All Rights Reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at @@ -9,18 +11,5 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -import re -import sqlalchemy - - -_SQLA_VERSION = tuple( - int(num) if re.match(r'^\d+$', num) else num - for num in sqlalchemy.__version__.split(".") -) - -sqla_100 = _SQLA_VERSION >= (1, 0, 0) -sqla_097 = _SQLA_VERSION >= (0, 9, 7) -sqla_094 = _SQLA_VERSION >= (0, 9, 4) -sqla_090 = _SQLA_VERSION >= (0, 9, 0) -sqla_08 = _SQLA_VERSION >= (0, 8) +from oslo_db.sqlalchemy.compat.utils import * # noqa diff --git a/oslo/db/sqlalchemy/exc_filters.py b/oslo/db/sqlalchemy/exc_filters.py index f283f08f..24b9cec0 100644 --- a/oslo/db/sqlalchemy/exc_filters.py +++ b/oslo/db/sqlalchemy/exc_filters.py @@ -1,3 +1,5 @@ +# All Rights Reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at @@ -9,350 +11,5 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -"""Define exception redefinitions for SQLAlchemy DBAPI exceptions.""" -import collections -import logging -import re - -from sqlalchemy import exc as sqla_exc - -from oslo.db._i18n import _LE -from oslo.db import exception -from oslo.db.sqlalchemy import compat - - -LOG = logging.getLogger(__name__) - - -_registry = collections.defaultdict( - lambda: collections.defaultdict( - list - ) -) - - -def filters(dbname, exception_type, regex): - """Mark a function as receiving a filtered exception. - - :param dbname: string database name, e.g. 'mysql' - :param exception_type: a SQLAlchemy database exception class, which - extends from :class:`sqlalchemy.exc.DBAPIError`. - :param regex: a string, or a tuple of strings, that will be processed - as matching regular expressions. - - """ - def _receive(fn): - _registry[dbname][exception_type].extend( - (fn, re.compile(reg)) - for reg in - ((regex,) if not isinstance(regex, tuple) else regex) - ) - return fn - return _receive - - -# NOTE(zzzeek) - for Postgresql, catch both OperationalError, as the -# actual error is -# psycopg2.extensions.TransactionRollbackError(OperationalError), -# as well as sqlalchemy.exc.DBAPIError, as SQLAlchemy will reraise it -# as this until issue #3075 is fixed. -@filters("mysql", sqla_exc.OperationalError, r"^.*\b1213\b.*Deadlock found.*") -@filters("mysql", sqla_exc.OperationalError, - r"^.*\b1205\b.*Lock wait timeout exceeded.*") -@filters("mysql", sqla_exc.InternalError, r"^.*\b1213\b.*Deadlock found.*") -@filters("postgresql", sqla_exc.OperationalError, r"^.*deadlock detected.*") -@filters("postgresql", sqla_exc.DBAPIError, r"^.*deadlock detected.*") -@filters("ibm_db_sa", sqla_exc.DBAPIError, r"^.*SQL0911N.*") -def _deadlock_error(operational_error, match, engine_name, is_disconnect): - """Filter for MySQL or Postgresql deadlock error. - - NOTE(comstud): In current versions of DB backends, Deadlock violation - messages follow the structure: - - mysql+mysqldb: - (OperationalError) (1213, 'Deadlock found when trying to get lock; try ' - 'restarting transaction') - - mysql+mysqlconnector: - (InternalError) 1213 (40001): Deadlock found when trying to get lock; try - restarting transaction - - postgresql: - (TransactionRollbackError) deadlock detected - - - ibm_db_sa: - SQL0911N The current transaction has been rolled back because of a - deadlock or timeout - - """ - raise exception.DBDeadlock(operational_error) - - -@filters("mysql", sqla_exc.IntegrityError, - r"^.*\b1062\b.*Duplicate entry '(?P[^']+)'" - r" for key '(?P[^']+)'.*$") -# NOTE(pkholkin): the first regex is suitable only for PostgreSQL 9.x versions -# the second regex is suitable for PostgreSQL 8.x versions -@filters("postgresql", sqla_exc.IntegrityError, - (r'^.*duplicate\s+key.*"(?P[^"]+)"\s*\n.*' - r'Key\s+\((?P.*)\)=\((?P.*)\)\s+already\s+exists.*$', - r"^.*duplicate\s+key.*\"(?P[^\"]+)\"\s*\n.*$")) -def _default_dupe_key_error(integrity_error, match, engine_name, - is_disconnect): - """Filter for MySQL or Postgresql duplicate key error. - - note(boris-42): In current versions of DB backends unique constraint - violation messages follow the structure: - - postgres: - 1 column - (IntegrityError) duplicate key value violates unique - constraint "users_c1_key" - N columns - (IntegrityError) duplicate key value violates unique - constraint "name_of_our_constraint" - - mysql+mysqldb: - 1 column - (IntegrityError) (1062, "Duplicate entry 'value_of_c1' for key - 'c1'") - N columns - (IntegrityError) (1062, "Duplicate entry 'values joined - with -' for key 'name_of_our_constraint'") - - mysql+mysqlconnector: - 1 column - (IntegrityError) 1062 (23000): Duplicate entry 'value_of_c1' for - key 'c1' - N columns - (IntegrityError) 1062 (23000): Duplicate entry 'values - joined with -' for key 'name_of_our_constraint' - - - - """ - - columns = match.group('columns') - - # note(vsergeyev): UniqueConstraint name convention: "uniq_t0c10c2" - # where `t` it is table name and columns `c1`, `c2` - # are in UniqueConstraint. - uniqbase = "uniq_" - if not columns.startswith(uniqbase): - if engine_name == "postgresql": - columns = [columns[columns.index("_") + 1:columns.rindex("_")]] - else: - columns = [columns] - else: - columns = columns[len(uniqbase):].split("0")[1:] - - value = match.groupdict().get('value') - - raise exception.DBDuplicateEntry(columns, integrity_error, value) - - -@filters("sqlite", sqla_exc.IntegrityError, - (r"^.*columns?(?P[^)]+)(is|are)\s+not\s+unique$", - r"^.*UNIQUE\s+constraint\s+failed:\s+(?P.+)$", - r"^.*PRIMARY\s+KEY\s+must\s+be\s+unique.*$")) -def _sqlite_dupe_key_error(integrity_error, match, engine_name, is_disconnect): - """Filter for SQLite duplicate key error. - - note(boris-42): In current versions of DB backends unique constraint - violation messages follow the structure: - - sqlite: - 1 column - (IntegrityError) column c1 is not unique - N columns - (IntegrityError) column c1, c2, ..., N are not unique - - sqlite since 3.7.16: - 1 column - (IntegrityError) UNIQUE constraint failed: tbl.k1 - N columns - (IntegrityError) UNIQUE constraint failed: tbl.k1, tbl.k2 - - sqlite since 3.8.2: - (IntegrityError) PRIMARY KEY must be unique - - """ - columns = [] - # NOTE(ochuprykov): We can get here by last filter in which there are no - # groups. Trying to access the substring that matched by - # the group will lead to IndexError. In this case just - # pass empty list to exception.DBDuplicateEntry - try: - columns = match.group('columns') - columns = [c.split('.')[-1] for c in columns.strip().split(", ")] - except IndexError: - pass - - raise exception.DBDuplicateEntry(columns, integrity_error) - - -@filters("sqlite", sqla_exc.IntegrityError, - r"(?i).*foreign key constraint failed") -@filters("postgresql", sqla_exc.IntegrityError, - r".*on table \"(?P[^\"]+)\" violates " - "foreign key constraint \"(?P[^\"]+)\"\s*\n" - "DETAIL: Key \((?P.+)\)=\(.+\) " - "is not present in table " - "\"(?P[^\"]+)\".") -@filters("mysql", sqla_exc.IntegrityError, - r".* 'Cannot add or update a child row: " - 'a foreign key constraint fails \([`"].+[`"]\.[`"](?P
.+)[`"], ' - 'CONSTRAINT [`"](?P.+)[`"] FOREIGN KEY ' - '\([`"](?P.+)[`"]\) REFERENCES [`"](?P.+)[`"] ') -def _foreign_key_error(integrity_error, match, engine_name, is_disconnect): - """Filter for foreign key errors.""" - - try: - table = match.group("table") - except IndexError: - table = None - try: - constraint = match.group("constraint") - except IndexError: - constraint = None - try: - key = match.group("key") - except IndexError: - key = None - try: - key_table = match.group("key_table") - except IndexError: - key_table = None - - raise exception.DBReferenceError(table, constraint, key, key_table, - integrity_error) - - -@filters("ibm_db_sa", sqla_exc.IntegrityError, r"^.*SQL0803N.*$") -def _db2_dupe_key_error(integrity_error, match, engine_name, is_disconnect): - """Filter for DB2 duplicate key errors. - - N columns - (IntegrityError) SQL0803N One or more values in the INSERT - statement, UPDATE statement, or foreign key update caused by a - DELETE statement are not valid because the primary key, unique - constraint or unique index identified by "2" constrains table - "NOVA.KEY_PAIRS" from having duplicate values for the index - key. - - """ - - # NOTE(mriedem): The ibm_db_sa integrity error message doesn't provide the - # columns so we have to omit that from the DBDuplicateEntry error. - raise exception.DBDuplicateEntry([], integrity_error) - - -@filters("mysql", sqla_exc.DBAPIError, r".*\b1146\b") -def _raise_mysql_table_doesnt_exist_asis( - error, match, engine_name, is_disconnect): - """Raise MySQL error 1146 as is. - - Raise MySQL error 1146 as is, so that it does not conflict with - the MySQL dialect's checking a table not existing. - """ - - raise error - - -@filters("*", sqla_exc.OperationalError, r".*") -def _raise_operational_errors_directly_filter(operational_error, - match, engine_name, - is_disconnect): - """Filter for all remaining OperationalError classes and apply. - - Filter for all remaining OperationalError classes and apply - special rules. - """ - if is_disconnect: - # operational errors that represent disconnect - # should be wrapped - raise exception.DBConnectionError(operational_error) - else: - # NOTE(comstud): A lot of code is checking for OperationalError - # so let's not wrap it for now. - raise operational_error - - -@filters("mysql", sqla_exc.OperationalError, r".*\(.*(?:2002|2003|2006|2013)") -@filters("ibm_db_sa", sqla_exc.OperationalError, r".*(?:30081)") -def _is_db_connection_error(operational_error, match, engine_name, - is_disconnect): - """Detect the exception as indicating a recoverable error on connect.""" - raise exception.DBConnectionError(operational_error) - - -@filters("*", sqla_exc.DBAPIError, r".*") -def _raise_for_remaining_DBAPIError(error, match, engine_name, is_disconnect): - """Filter for remaining DBAPIErrors. - - Filter for remaining DBAPIErrors and wrap if they represent - a disconnect error. - """ - if is_disconnect: - raise exception.DBConnectionError(error) - else: - LOG.exception( - _LE('DBAPIError exception wrapped from %s') % error) - raise exception.DBError(error) - - -@filters('*', UnicodeEncodeError, r".*") -def _raise_for_unicode_encode(error, match, engine_name, is_disconnect): - raise exception.DBInvalidUnicodeParameter() - - -@filters("*", Exception, r".*") -def _raise_for_all_others(error, match, engine_name, is_disconnect): - LOG.exception(_LE('DB exception wrapped.')) - raise exception.DBError(error) - - -def handler(context): - """Iterate through available filters and invoke those which match. - - The first one which raises wins. The order in which the filters - are attempted is sorted by specificity - dialect name or "*", - exception class per method resolution order (``__mro__``). - Method resolution order is used so that filter rules indicating a - more specific exception class are attempted first. - - """ - def _dialect_registries(engine): - if engine.dialect.name in _registry: - yield _registry[engine.dialect.name] - if '*' in _registry: - yield _registry['*'] - - for per_dialect in _dialect_registries(context.engine): - for exc in ( - context.sqlalchemy_exception, - context.original_exception): - for super_ in exc.__class__.__mro__: - if super_ in per_dialect: - regexp_reg = per_dialect[super_] - for fn, regexp in regexp_reg: - match = regexp.match(exc.args[0]) - if match: - try: - fn( - exc, - match, - context.engine.dialect.name, - context.is_disconnect) - except exception.DBConnectionError: - context.is_disconnect = True - raise - - -def register_engine(engine): - compat.handle_error(engine, handler) - - -def handle_connect_error(engine): - """Handle connect error. - - Provide a special context that will allow on-connect errors - to be treated within the filtering context. - - This routine is dependent on SQLAlchemy version, as version 1.0.0 - provides this functionality natively. - - """ - with compat.handle_connect_context(handler, engine): - return engine.connect() +from oslo_db.sqlalchemy.exc_filters import * # noqa diff --git a/oslo/db/sqlalchemy/migration.py b/oslo/db/sqlalchemy/migration.py index f1cecdd8..a8d7cee8 100644 --- a/oslo/db/sqlalchemy/migration.py +++ b/oslo/db/sqlalchemy/migration.py @@ -1,160 +1,15 @@ -# coding=utf-8 - -# Copyright (c) 2013 OpenStack Foundation # All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -# -# Base on code in migrate/changeset/databases/sqlite.py which is under -# the following license: -# -# The MIT License -# -# Copyright (c) 2009 Evan Rosson, Jan Dittberner, Domen Kožar -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. -import os - -from migrate import exceptions as versioning_exceptions -from migrate.versioning import api as versioning_api -from migrate.versioning.repository import Repository -import sqlalchemy - -from oslo.db._i18n import _ -from oslo.db import exception - - -def db_sync(engine, abs_path, version=None, init_version=0, sanity_check=True): - """Upgrade or downgrade a database. - - Function runs the upgrade() or downgrade() functions in change scripts. - - :param engine: SQLAlchemy engine instance for a given database - :param abs_path: Absolute path to migrate repository. - :param version: Database will upgrade/downgrade until this version. - If None - database will update to the latest - available version. - :param init_version: Initial database version - :param sanity_check: Require schema sanity checking for all tables - """ - - if version is not None: - try: - version = int(version) - except ValueError: - raise exception.DbMigrationError( - message=_("version should be an integer")) - - current_version = db_version(engine, abs_path, init_version) - repository = _find_migrate_repo(abs_path) - if sanity_check: - _db_schema_sanity_check(engine) - if version is None or version > current_version: - return versioning_api.upgrade(engine, repository, version) - else: - return versioning_api.downgrade(engine, repository, - version) - - -def _db_schema_sanity_check(engine): - """Ensure all database tables were created with required parameters. - - :param engine: SQLAlchemy engine instance for a given database - - """ - - if engine.name == 'mysql': - onlyutf8_sql = ('SELECT TABLE_NAME,TABLE_COLLATION ' - 'from information_schema.TABLES ' - 'where TABLE_SCHEMA=%s and ' - 'TABLE_COLLATION NOT LIKE \'%%utf8%%\'') - - # NOTE(morganfainberg): exclude the sqlalchemy-migrate and alembic - # versioning tables from the tables we need to verify utf8 status on. - # Non-standard table names are not supported. - EXCLUDED_TABLES = ['migrate_version', 'alembic_version'] - - table_names = [res[0] for res in - engine.execute(onlyutf8_sql, engine.url.database) if - res[0].lower() not in EXCLUDED_TABLES] - - if len(table_names) > 0: - raise ValueError(_('Tables "%s" have non utf8 collation, ' - 'please make sure all tables are CHARSET=utf8' - ) % ','.join(table_names)) - - -def db_version(engine, abs_path, init_version): - """Show the current version of the repository. - - :param engine: SQLAlchemy engine instance for a given database - :param abs_path: Absolute path to migrate repository - :param version: Initial database version - """ - repository = _find_migrate_repo(abs_path) - try: - return versioning_api.db_version(engine, repository) - except versioning_exceptions.DatabaseNotControlledError: - meta = sqlalchemy.MetaData() - meta.reflect(bind=engine) - tables = meta.tables - if len(tables) == 0 or 'alembic_version' in tables: - db_version_control(engine, abs_path, version=init_version) - return versioning_api.db_version(engine, repository) - else: - raise exception.DbMigrationError( - message=_( - "The database is not under version control, but has " - "tables. Please stamp the current version of the schema " - "manually.")) - - -def db_version_control(engine, abs_path, version=None): - """Mark a database as under this repository's version control. - - Once a database is under version control, schema changes should - only be done via change scripts in this repository. - - :param engine: SQLAlchemy engine instance for a given database - :param abs_path: Absolute path to migrate repository - :param version: Initial database version - """ - repository = _find_migrate_repo(abs_path) - versioning_api.version_control(engine, repository, version) - return version - - -def _find_migrate_repo(abs_path): - """Get the project's change script repository - - :param abs_path: Absolute path to migrate repository - """ - if not os.path.exists(abs_path): - raise exception.DbMigrationError("Path %s not found" % abs_path) - return Repository(abs_path) +from oslo_db.sqlalchemy.migration import * # noqa diff --git a/oslo/db/sqlalchemy/migration_cli/__init__.py b/oslo/db/sqlalchemy/migration_cli/__init__.py index e69de29b..79972c3c 100644 --- a/oslo/db/sqlalchemy/migration_cli/__init__.py +++ b/oslo/db/sqlalchemy/migration_cli/__init__.py @@ -0,0 +1,18 @@ +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo_db.sqlalchemy.migration_cli import ext_alembic # noqa +from oslo_db.sqlalchemy.migration_cli import ext_base # noqa +from oslo_db.sqlalchemy.migration_cli import ext_migrate # noqa +from oslo_db.sqlalchemy.migration_cli import manager # noqa diff --git a/oslo/db/sqlalchemy/models.py b/oslo/db/sqlalchemy/models.py index 818c1b40..6bcb8221 100644 --- a/oslo/db/sqlalchemy/models.py +++ b/oslo/db/sqlalchemy/models.py @@ -1,8 +1,3 @@ -# Copyright (c) 2011 X.commerce, a business unit of eBay Inc. -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. -# Copyright 2011 Piston Cloud Computing, Inc. -# Copyright 2012 Cloudscaling Group, Inc. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -16,113 +11,5 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -""" -SQLAlchemy models. -""" -import six - -from oslo.utils import timeutils -from sqlalchemy import Column, Integer -from sqlalchemy import DateTime -from sqlalchemy.orm import object_mapper - - -class ModelBase(six.Iterator): - """Base class for models.""" - __table_initialized__ = False - - def save(self, session): - """Save this object.""" - - # NOTE(boris-42): This part of code should be look like: - # session.add(self) - # session.flush() - # But there is a bug in sqlalchemy and eventlet that - # raises NoneType exception if there is no running - # transaction and rollback is called. As long as - # sqlalchemy has this bug we have to create transaction - # explicitly. - with session.begin(subtransactions=True): - session.add(self) - session.flush() - - def __setitem__(self, key, value): - setattr(self, key, value) - - def __getitem__(self, key): - return getattr(self, key) - - def __contains__(self, key): - return hasattr(self, key) - - def get(self, key, default=None): - return getattr(self, key, default) - - @property - def _extra_keys(self): - """Specifies custom fields - - Subclasses can override this property to return a list - of custom fields that should be included in their dict - representation. - - For reference check tests/db/sqlalchemy/test_models.py - """ - return [] - - def __iter__(self): - columns = list(dict(object_mapper(self).columns).keys()) - # NOTE(russellb): Allow models to specify other keys that can be looked - # up, beyond the actual db columns. An example would be the 'name' - # property for an Instance. - columns.extend(self._extra_keys) - - return ModelIterator(self, iter(columns)) - - def update(self, values): - """Make the model object behave like a dict.""" - for k, v in six.iteritems(values): - setattr(self, k, v) - - def iteritems(self): - """Make the model object behave like a dict. - - Includes attributes from joins. - """ - local = dict(self) - joined = dict([(k, v) for k, v in six.iteritems(self.__dict__) - if not k[0] == '_']) - local.update(joined) - return six.iteritems(local) - - -class ModelIterator(ModelBase, six.Iterator): - - def __init__(self, model, columns): - self.model = model - self.i = columns - - def __iter__(self): - return self - - # In Python 3, __next__() has replaced next(). - def __next__(self): - n = six.advance_iterator(self.i) - return n, getattr(self.model, n) - - -class TimestampMixin(object): - created_at = Column(DateTime, default=lambda: timeutils.utcnow()) - updated_at = Column(DateTime, onupdate=lambda: timeutils.utcnow()) - - -class SoftDeleteMixin(object): - deleted_at = Column(DateTime) - deleted = Column(Integer, default=0) - - def soft_delete(self, session): - """Mark this object as deleted.""" - self.deleted = self.id - self.deleted_at = timeutils.utcnow() - self.save(session=session) +from oslo_db.sqlalchemy.models import * # noqa diff --git a/oslo/db/sqlalchemy/provision.py b/oslo/db/sqlalchemy/provision.py index 260cb5c8..c091b023 100644 --- a/oslo/db/sqlalchemy/provision.py +++ b/oslo/db/sqlalchemy/provision.py @@ -1,4 +1,3 @@ -# Copyright 2013 Mirantis.inc # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -13,495 +12,4 @@ # License for the specific language governing permissions and limitations # under the License. -"""Provision test environment for specific DB backends""" - -import abc -import argparse -import logging -import os -import random -import re -import string - -import six -from six import moves -import sqlalchemy -from sqlalchemy.engine import url as sa_url - -from oslo.db._i18n import _LI -from oslo.db import exception -from oslo.db.sqlalchemy import session -from oslo.db.sqlalchemy import utils - -LOG = logging.getLogger(__name__) - - -class ProvisionedDatabase(object): - """Represent a single database node that can be used for testing in - - a serialized fashion. - - ``ProvisionedDatabase`` includes features for full lifecycle management - of a node, in a way that is context-specific. Depending on how the - test environment runs, ``ProvisionedDatabase`` should know if it needs - to create and drop databases or if it is making use of a database that - is maintained by an external process. - - """ - - def __init__(self, database_type): - self.backend = Backend.backend_for_database_type(database_type) - self.db_token = _random_ident() - - self.backend.create_named_database(self.db_token) - self.engine = self.backend.provisioned_engine(self.db_token) - - def dispose(self): - self.engine.dispose() - self.backend.drop_named_database(self.db_token) - - -class Backend(object): - """Represent a particular database backend that may be provisionable. - - The ``Backend`` object maintains a database type (e.g. database without - specific driver type, such as "sqlite", "postgresql", etc.), - a target URL, a base ``Engine`` for that URL object that can be used - to provision databases and a ``BackendImpl`` which knows how to perform - operations against this type of ``Engine``. - - """ - - backends_by_database_type = {} - - def __init__(self, database_type, url): - self.database_type = database_type - self.url = url - self.verified = False - self.engine = None - self.impl = BackendImpl.impl(database_type) - Backend.backends_by_database_type[database_type] = self - - @classmethod - def backend_for_database_type(cls, database_type): - """Return and verify the ``Backend`` for the given database type. - - Creates the engine if it does not already exist and raises - ``BackendNotAvailable`` if it cannot be produced. - - :return: a base ``Engine`` that allows provisioning of databases. - - :raises: ``BackendNotAvailable``, if an engine for this backend - cannot be produced. - - """ - try: - backend = cls.backends_by_database_type[database_type] - except KeyError: - raise exception.BackendNotAvailable(database_type) - else: - return backend._verify() - - @classmethod - def all_viable_backends(cls): - """Return an iterator of all ``Backend`` objects that are present - - and provisionable. - - """ - - for backend in cls.backends_by_database_type.values(): - try: - yield backend._verify() - except exception.BackendNotAvailable: - pass - - def _verify(self): - """Verify that this ``Backend`` is available and provisionable. - - :return: this ``Backend`` - - :raises: ``BackendNotAvailable`` if the backend is not available. - - """ - - if not self.verified: - try: - eng = self._ensure_backend_available(self.url) - except exception.BackendNotAvailable: - raise - else: - self.engine = eng - finally: - self.verified = True - if self.engine is None: - raise exception.BackendNotAvailable(self.database_type) - return self - - @classmethod - def _ensure_backend_available(cls, url): - url = sa_url.make_url(str(url)) - try: - eng = sqlalchemy.create_engine(url) - except ImportError as i_e: - # SQLAlchemy performs an "import" of the DBAPI module - # within create_engine(). So if ibm_db_sa, cx_oracle etc. - # isn't installed, we get an ImportError here. - LOG.info( - _LI("The %(dbapi)s backend is unavailable: %(err)s"), - dict(dbapi=url.drivername, err=i_e)) - raise exception.BackendNotAvailable("No DBAPI installed") - else: - try: - conn = eng.connect() - except sqlalchemy.exc.DBAPIError as d_e: - # upon connect, SQLAlchemy calls dbapi.connect(). This - # usually raises OperationalError and should always at - # least raise a SQLAlchemy-wrapped DBAPI Error. - LOG.info( - _LI("The %(dbapi)s backend is unavailable: %(err)s"), - dict(dbapi=url.drivername, err=d_e) - ) - raise exception.BackendNotAvailable("Could not connect") - else: - conn.close() - return eng - - def create_named_database(self, ident): - """Create a database with the given name.""" - - self.impl.create_named_database(self.engine, ident) - - def drop_named_database(self, ident, conditional=False): - """Drop a database with the given name.""" - - self.impl.drop_named_database( - self.engine, ident, - conditional=conditional) - - def database_exists(self, ident): - """Return True if a database of the given name exists.""" - - return self.impl.database_exists(self.engine, ident) - - def provisioned_engine(self, ident): - """Given the URL of a particular database backend and the string - - name of a particular 'database' within that backend, return - an Engine instance whose connections will refer directly to the - named database. - - For hostname-based URLs, this typically involves switching just the - 'database' portion of the URL with the given name and creating - an engine. - - For URLs that instead deal with DSNs, the rules may be more custom; - for example, the engine may need to connect to the root URL and - then emit a command to switch to the named database. - - """ - return self.impl.provisioned_engine(self.url, ident) - - @classmethod - def _setup(cls): - """Initial startup feature will scan the environment for configured - - URLs and place them into the list of URLs we will use for provisioning. - - This searches through OS_TEST_DBAPI_ADMIN_CONNECTION for URLs. If - not present, we set up URLs based on the "opportunstic" convention, - e.g. username+password = "openstack_citest". - - The provisioning system will then use or discard these URLs as they - are requested, based on whether or not the target database is actually - found to be available. - - """ - configured_urls = os.getenv('OS_TEST_DBAPI_ADMIN_CONNECTION', None) - if configured_urls: - configured_urls = configured_urls.split(";") - else: - configured_urls = [ - impl.create_opportunistic_driver_url() - for impl in BackendImpl.all_impls() - ] - - for url_str in configured_urls: - url = sa_url.make_url(url_str) - m = re.match(r'([^+]+?)(?:\+(.+))?$', url.drivername) - database_type, drivertype = m.group(1, 2) - Backend(database_type, url) - - -@six.add_metaclass(abc.ABCMeta) -class BackendImpl(object): - """Provide database-specific implementations of key provisioning - - functions. - - ``BackendImpl`` is owned by a ``Backend`` instance which delegates - to it for all database-specific features. - - """ - - @classmethod - def all_impls(cls): - """Return an iterator of all possible BackendImpl objects. - - These are BackendImpls that are implemented, but not - necessarily provisionable. - - """ - for database_type in cls.impl.reg: - if database_type == '*': - continue - yield BackendImpl.impl(database_type) - - @utils.dispatch_for_dialect("*") - def impl(drivername): - """Return a ``BackendImpl`` instance corresponding to the - - given driver name. - - This is a dispatched method which will refer to the constructor - of implementing subclasses. - - """ - raise NotImplementedError( - "No provision impl available for driver: %s" % drivername) - - def __init__(self, drivername): - self.drivername = drivername - - @abc.abstractmethod - def create_opportunistic_driver_url(self): - """Produce a string url known as the 'opportunistic' URL. - - This URL is one that corresponds to an established Openstack - convention for a pre-established database login, which, when - detected as available in the local environment, is automatically - used as a test platform for a specific type of driver. - - """ - - @abc.abstractmethod - def create_named_database(self, engine, ident): - """Create a database with the given name.""" - - @abc.abstractmethod - def drop_named_database(self, engine, ident, conditional=False): - """Drop a database with the given name.""" - - def provisioned_engine(self, base_url, ident): - """Return a provisioned engine. - - Given the URL of a particular database backend and the string - name of a particular 'database' within that backend, return - an Engine instance whose connections will refer directly to the - named database. - - For hostname-based URLs, this typically involves switching just the - 'database' portion of the URL with the given name and creating - an engine. - - For URLs that instead deal with DSNs, the rules may be more custom; - for example, the engine may need to connect to the root URL and - then emit a command to switch to the named database. - - """ - - url = sa_url.make_url(str(base_url)) - url.database = ident - return session.create_engine( - url, - logging_name="%s@%s" % (self.drivername, ident)) - - -@BackendImpl.impl.dispatch_for("mysql") -class MySQLBackendImpl(BackendImpl): - def create_opportunistic_driver_url(self): - return "mysql://openstack_citest:openstack_citest@localhost/" - - def create_named_database(self, engine, ident): - with engine.connect() as conn: - conn.execute("CREATE DATABASE %s" % ident) - - def drop_named_database(self, engine, ident, conditional=False): - with engine.connect() as conn: - if not conditional or self.database_exists(conn, ident): - conn.execute("DROP DATABASE %s" % ident) - - def database_exists(self, engine, ident): - return bool(engine.scalar("SHOW DATABASES LIKE '%s'" % ident)) - - -@BackendImpl.impl.dispatch_for("sqlite") -class SQLiteBackendImpl(BackendImpl): - def create_opportunistic_driver_url(self): - return "sqlite://" - - def create_named_database(self, engine, ident): - url = self._provisioned_database_url(engine.url, ident) - eng = sqlalchemy.create_engine(url) - eng.connect().close() - - def provisioned_engine(self, base_url, ident): - return session.create_engine( - self._provisioned_database_url(base_url, ident)) - - def drop_named_database(self, engine, ident, conditional=False): - url = self._provisioned_database_url(engine.url, ident) - filename = url.database - if filename and (not conditional or os.access(filename, os.F_OK)): - os.remove(filename) - - def database_exists(self, engine, ident): - url = self._provisioned_database_url(engine.url, ident) - filename = url.database - return not filename or os.access(filename, os.F_OK) - - def _provisioned_database_url(self, base_url, ident): - if base_url.database: - return sa_url.make_url("sqlite:////tmp/%s.db" % ident) - else: - return base_url - - -@BackendImpl.impl.dispatch_for("postgresql") -class PostgresqlBackendImpl(BackendImpl): - def create_opportunistic_driver_url(self): - return "postgresql://openstack_citest:openstack_citest"\ - "@localhost/postgres" - - def create_named_database(self, engine, ident): - with engine.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: - conn.execute("CREATE DATABASE %s" % ident) - - def drop_named_database(self, engine, ident, conditional=False): - with engine.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: - self._close_out_database_users(conn, ident) - if conditional: - conn.execute("DROP DATABASE IF EXISTS %s" % ident) - else: - conn.execute("DROP DATABASE %s" % ident) - - def database_exists(self, engine, ident): - return bool( - engine.scalar( - sqlalchemy.text( - "select datname from pg_database " - "where datname=:name"), name=ident) - ) - - def _close_out_database_users(self, conn, ident): - """Attempt to guarantee a database can be dropped. - - Optional feature which guarantees no connections with our - username are attached to the DB we're going to drop. - - This method has caveats; for one, the 'pid' column was named - 'procpid' prior to Postgresql 9.2. But more critically, - prior to 9.2 this operation required superuser permissions, - even if the connections we're closing are under the same username - as us. In more recent versions this restriction has been - lifted for same-user connections. - - """ - if conn.dialect.server_version_info >= (9, 2): - conn.execute( - sqlalchemy.text( - "select pg_terminate_backend(pid) " - "from pg_stat_activity " - "where usename=current_user and " - "pid != pg_backend_pid() " - "and datname=:dname" - ), dname=ident) - - -def _random_ident(): - return ''.join( - random.choice(string.ascii_lowercase) - for i in moves.range(10)) - - -def _echo_cmd(args): - idents = [_random_ident() for i in moves.range(args.instances_count)] - print("\n".join(idents)) - - -def _create_cmd(args): - idents = [_random_ident() for i in moves.range(args.instances_count)] - - for backend in Backend.all_viable_backends(): - for ident in idents: - backend.create_named_database(ident) - - print("\n".join(idents)) - - -def _drop_cmd(args): - for backend in Backend.all_viable_backends(): - for ident in args.instances: - backend.drop_named_database(ident, args.conditional) - -Backend._setup() - - -def main(argv=None): - """Command line interface to create/drop databases. - - ::create: Create test database with random names. - ::drop: Drop database created by previous command. - ::echo: create random names and display them; don't create. - """ - parser = argparse.ArgumentParser( - description='Controller to handle database creation and dropping' - ' commands.', - epilog='Typically called by the test runner, e.g. shell script, ' - 'testr runner via .testr.conf, or other system.') - subparsers = parser.add_subparsers( - help='Subcommands to manipulate temporary test databases.') - - create = subparsers.add_parser( - 'create', - help='Create temporary test databases.') - create.set_defaults(which=_create_cmd) - create.add_argument( - 'instances_count', - type=int, - help='Number of databases to create.') - - drop = subparsers.add_parser( - 'drop', - help='Drop temporary test databases.') - drop.set_defaults(which=_drop_cmd) - drop.add_argument( - 'instances', - nargs='+', - help='List of databases uri to be dropped.') - drop.add_argument( - '--conditional', - action="store_true", - help="Check if database exists first before dropping" - ) - - echo = subparsers.add_parser( - 'echo', - help="Create random database names and display only." - ) - echo.set_defaults(which=_echo_cmd) - echo.add_argument( - 'instances_count', - type=int, - help='Number of identifiers to create.') - - args = parser.parse_args(argv) - - cmd = args.which - cmd(args) - - -if __name__ == "__main__": - main() +from oslo_db.sqlalchemy.provision import * # noqa diff --git a/oslo/db/sqlalchemy/session.py b/oslo/db/sqlalchemy/session.py index da877f1a..9a9fcf37 100644 --- a/oslo/db/sqlalchemy/session.py +++ b/oslo/db/sqlalchemy/session.py @@ -1,5 +1,3 @@ -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -14,834 +12,4 @@ # License for the specific language governing permissions and limitations # under the License. -"""Session Handling for SQLAlchemy backend. - -Recommended ways to use sessions within this framework: - -* Don't use them explicitly; this is like running with ``AUTOCOMMIT=1``. - `model_query()` will implicitly use a session when called without one - supplied. This is the ideal situation because it will allow queries - to be automatically retried if the database connection is interrupted. - - .. note:: Automatic retry will be enabled in a future patch. - - It is generally fine to issue several queries in a row like this. Even though - they may be run in separate transactions and/or separate sessions, each one - will see the data from the prior calls. If needed, undo- or rollback-like - functionality should be handled at a logical level. For an example, look at - the code around quotas and `reservation_rollback()`. - - Examples: - - .. code-block:: python - - def get_foo(context, foo): - return (model_query(context, models.Foo). - filter_by(foo=foo). - first()) - - def update_foo(context, id, newfoo): - (model_query(context, models.Foo). - filter_by(id=id). - update({'foo': newfoo})) - - def create_foo(context, values): - foo_ref = models.Foo() - foo_ref.update(values) - foo_ref.save() - return foo_ref - - -* Within the scope of a single method, keep all the reads and writes within - the context managed by a single session. In this way, the session's - `__exit__` handler will take care of calling `flush()` and `commit()` for - you. If using this approach, you should not explicitly call `flush()` or - `commit()`. Any error within the context of the session will cause the - session to emit a `ROLLBACK`. Database errors like `IntegrityError` will be - raised in `session`'s `__exit__` handler, and any try/except within the - context managed by `session` will not be triggered. And catching other - non-database errors in the session will not trigger the ROLLBACK, so - exception handlers should always be outside the session, unless the - developer wants to do a partial commit on purpose. If the connection is - dropped before this is possible, the database will implicitly roll back the - transaction. - - .. note:: Statements in the session scope will not be automatically retried. - - If you create models within the session, they need to be added, but you - do not need to call `model.save()`: - - .. code-block:: python - - def create_many_foo(context, foos): - session = sessionmaker() - with session.begin(): - for foo in foos: - foo_ref = models.Foo() - foo_ref.update(foo) - session.add(foo_ref) - - def update_bar(context, foo_id, newbar): - session = sessionmaker() - with session.begin(): - foo_ref = (model_query(context, models.Foo, session). - filter_by(id=foo_id). - first()) - (model_query(context, models.Bar, session). - filter_by(id=foo_ref['bar_id']). - update({'bar': newbar})) - - .. note:: `update_bar` is a trivially simple example of using - ``with session.begin``. Whereas `create_many_foo` is a good example of - when a transaction is needed, it is always best to use as few queries as - possible. - - The two queries in `update_bar` can be better expressed using a single query - which avoids the need for an explicit transaction. It can be expressed like - so: - - .. code-block:: python - - def update_bar(context, foo_id, newbar): - subq = (model_query(context, models.Foo.id). - filter_by(id=foo_id). - limit(1). - subquery()) - (model_query(context, models.Bar). - filter_by(id=subq.as_scalar()). - update({'bar': newbar})) - - For reference, this emits approximately the following SQL statement: - - .. code-block:: sql - - UPDATE bar SET bar = ${newbar} - WHERE id=(SELECT bar_id FROM foo WHERE id = ${foo_id} LIMIT 1); - - .. note:: `create_duplicate_foo` is a trivially simple example of catching an - exception while using ``with session.begin``. Here create two duplicate - instances with same primary key, must catch the exception out of context - managed by a single session: - - .. code-block:: python - - def create_duplicate_foo(context): - foo1 = models.Foo() - foo2 = models.Foo() - foo1.id = foo2.id = 1 - session = sessionmaker() - try: - with session.begin(): - session.add(foo1) - session.add(foo2) - except exception.DBDuplicateEntry as e: - handle_error(e) - -* Passing an active session between methods. Sessions should only be passed - to private methods. The private method must use a subtransaction; otherwise - SQLAlchemy will throw an error when you call `session.begin()` on an existing - transaction. Public methods should not accept a session parameter and should - not be involved in sessions within the caller's scope. - - Note that this incurs more overhead in SQLAlchemy than the above means - due to nesting transactions, and it is not possible to implicitly retry - failed database operations when using this approach. - - This also makes code somewhat more difficult to read and debug, because a - single database transaction spans more than one method. Error handling - becomes less clear in this situation. When this is needed for code clarity, - it should be clearly documented. - - .. code-block:: python - - def myfunc(foo): - session = sessionmaker() - with session.begin(): - # do some database things - bar = _private_func(foo, session) - return bar - - def _private_func(foo, session=None): - if not session: - session = sessionmaker() - with session.begin(subtransaction=True): - # do some other database things - return bar - - -There are some things which it is best to avoid: - -* Don't keep a transaction open any longer than necessary. - - This means that your ``with session.begin()`` block should be as short - as possible, while still containing all the related calls for that - transaction. - -* Avoid ``with_lockmode('UPDATE')`` when possible. - - In MySQL/InnoDB, when a ``SELECT ... FOR UPDATE`` query does not match - any rows, it will take a gap-lock. This is a form of write-lock on the - "gap" where no rows exist, and prevents any other writes to that space. - This can effectively prevent any INSERT into a table by locking the gap - at the end of the index. Similar problems will occur if the SELECT FOR UPDATE - has an overly broad WHERE clause, or doesn't properly use an index. - - One idea proposed at ODS Fall '12 was to use a normal SELECT to test the - number of rows matching a query, and if only one row is returned, - then issue the SELECT FOR UPDATE. - - The better long-term solution is to use - ``INSERT .. ON DUPLICATE KEY UPDATE``. - However, this can not be done until the "deleted" columns are removed and - proper UNIQUE constraints are added to the tables. - - -Enabling soft deletes: - -* To use/enable soft-deletes, the `SoftDeleteMixin` must be added - to your model class. For example: - - .. code-block:: python - - class NovaBase(models.SoftDeleteMixin, models.ModelBase): - pass - - -Efficient use of soft deletes: - -* There are two possible ways to mark a record as deleted: - `model.soft_delete()` and `query.soft_delete()`. - - The `model.soft_delete()` method works with a single already-fetched entry. - `query.soft_delete()` makes only one db request for all entries that - correspond to the query. - -* In almost all cases you should use `query.soft_delete()`. Some examples: - - .. code-block:: python - - def soft_delete_bar(): - count = model_query(BarModel).find(some_condition).soft_delete() - if count == 0: - raise Exception("0 entries were soft deleted") - - def complex_soft_delete_with_synchronization_bar(session=None): - if session is None: - session = sessionmaker() - with session.begin(subtransactions=True): - count = (model_query(BarModel). - find(some_condition). - soft_delete(synchronize_session=True)) - # Here synchronize_session is required, because we - # don't know what is going on in outer session. - if count == 0: - raise Exception("0 entries were soft deleted") - -* There is only one situation where `model.soft_delete()` is appropriate: when - you fetch a single record, work with it, and mark it as deleted in the same - transaction. - - .. code-block:: python - - def soft_delete_bar_model(): - session = sessionmaker() - with session.begin(): - bar_ref = model_query(BarModel).find(some_condition).first() - # Work with bar_ref - bar_ref.soft_delete(session=session) - - However, if you need to work with all entries that correspond to query and - then soft delete them you should use the `query.soft_delete()` method: - - .. code-block:: python - - def soft_delete_multi_models(): - session = sessionmaker() - with session.begin(): - query = (model_query(BarModel, session=session). - find(some_condition)) - model_refs = query.all() - # Work with model_refs - query.soft_delete(synchronize_session=False) - # synchronize_session=False should be set if there is no outer - # session and these entries are not used after this. - - When working with many rows, it is very important to use query.soft_delete, - which issues a single query. Using `model.soft_delete()`, as in the following - example, is very inefficient. - - .. code-block:: python - - for bar_ref in bar_refs: - bar_ref.soft_delete(session=session) - # This will produce count(bar_refs) db requests. - -""" - -import itertools -import logging -import re -import time - -from oslo.utils import timeutils -import six -import sqlalchemy.orm -from sqlalchemy import pool -from sqlalchemy.sql.expression import literal_column -from sqlalchemy.sql.expression import select - -from oslo.db._i18n import _LW -from oslo.db import exception -from oslo.db import options -from oslo.db.sqlalchemy import compat -from oslo.db.sqlalchemy import exc_filters -from oslo.db.sqlalchemy import utils - -LOG = logging.getLogger(__name__) - - -def _thread_yield(dbapi_con, con_record): - """Ensure other greenthreads get a chance to be executed. - - If we use eventlet.monkey_patch(), eventlet.greenthread.sleep(0) will - execute instead of time.sleep(0). - Force a context switch. With common database backends (eg MySQLdb and - sqlite), there is no implicit yield caused by network I/O since they are - implemented by C libraries that eventlet cannot monkey patch. - """ - time.sleep(0) - - -def _connect_ping_listener(connection, branch): - """Ping the server at connection startup. - - Ping the server at transaction begin and transparently reconnect - if a disconnect exception occurs. - """ - if branch: - return - - # turn off "close with result". This can also be accomplished - # by branching the connection, however just setting the flag is - # more performant and also doesn't get involved with some - # connection-invalidation awkardness that occurs (see - # https://bitbucket.org/zzzeek/sqlalchemy/issue/3215/) - save_should_close_with_result = connection.should_close_with_result - connection.should_close_with_result = False - try: - # run a SELECT 1. use a core select() so that - # any details like that needed by Oracle, DB2 etc. are handled. - connection.scalar(select([1])) - except exception.DBConnectionError: - # catch DBConnectionError, which is raised by the filter - # system. - # disconnect detected. The connection is now - # "invalid", but the pool should be ready to return - # new connections assuming they are good now. - # run the select again to re-validate the Connection. - connection.scalar(select([1])) - finally: - connection.should_close_with_result = save_should_close_with_result - - -def _setup_logging(connection_debug=0): - """setup_logging function maps SQL debug level to Python log level. - - Connection_debug is a verbosity of SQL debugging information. - 0=None(default value), - 1=Processed only messages with WARNING level or higher - 50=Processed only messages with INFO level or higher - 100=Processed only messages with DEBUG level - """ - if connection_debug >= 0: - logger = logging.getLogger('sqlalchemy.engine') - if connection_debug >= 100: - logger.setLevel(logging.DEBUG) - elif connection_debug >= 50: - logger.setLevel(logging.INFO) - else: - logger.setLevel(logging.WARNING) - - -def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None, - idle_timeout=3600, - connection_debug=0, max_pool_size=None, max_overflow=None, - pool_timeout=None, sqlite_synchronous=True, - connection_trace=False, max_retries=10, retry_interval=10, - thread_checkin=True, logging_name=None): - """Return a new SQLAlchemy engine.""" - - url = sqlalchemy.engine.url.make_url(sql_connection) - - engine_args = { - "pool_recycle": idle_timeout, - 'convert_unicode': True, - 'connect_args': {}, - 'logging_name': logging_name - } - - _setup_logging(connection_debug) - - _init_connection_args( - url, engine_args, - sqlite_fk=sqlite_fk, - max_pool_size=max_pool_size, - max_overflow=max_overflow, - pool_timeout=pool_timeout - ) - - engine = sqlalchemy.create_engine(url, **engine_args) - - _init_events( - engine, - mysql_sql_mode=mysql_sql_mode, - sqlite_synchronous=sqlite_synchronous, - sqlite_fk=sqlite_fk, - thread_checkin=thread_checkin, - connection_trace=connection_trace - ) - - # register alternate exception handler - exc_filters.register_engine(engine) - - # register engine connect handler - compat.engine_connect(engine, _connect_ping_listener) - - # initial connect + test - _test_connection(engine, max_retries, retry_interval) - - return engine - - -@utils.dispatch_for_dialect('*', multiple=True) -def _init_connection_args( - url, engine_args, - max_pool_size=None, max_overflow=None, pool_timeout=None, **kw): - - pool_class = url.get_dialect().get_pool_class(url) - if issubclass(pool_class, pool.QueuePool): - if max_pool_size is not None: - engine_args['pool_size'] = max_pool_size - if max_overflow is not None: - engine_args['max_overflow'] = max_overflow - if pool_timeout is not None: - engine_args['pool_timeout'] = pool_timeout - - -@_init_connection_args.dispatch_for("sqlite") -def _init_connection_args(url, engine_args, **kw): - pool_class = url.get_dialect().get_pool_class(url) - # singletonthreadpool is used for :memory: connections; - # replace it with StaticPool. - if issubclass(pool_class, pool.SingletonThreadPool): - engine_args["poolclass"] = pool.StaticPool - engine_args['connect_args']['check_same_thread'] = False - - -@_init_connection_args.dispatch_for("postgresql") -def _init_connection_args(url, engine_args, **kw): - if 'client_encoding' not in url.query: - # Set encoding using engine_args instead of connect_args since - # it's supported for PostgreSQL 8.*. More details at: - # http://docs.sqlalchemy.org/en/rel_0_9/dialects/postgresql.html - engine_args['client_encoding'] = 'utf8' - - -@_init_connection_args.dispatch_for("mysql") -def _init_connection_args(url, engine_args, **kw): - if 'charset' not in url.query: - engine_args['connect_args']['charset'] = 'utf8' - - -@_init_connection_args.dispatch_for("mysql+mysqlconnector") -def _init_connection_args(url, engine_args, **kw): - # mysqlconnector engine (<1.0) incorrectly defaults to - # raise_on_warnings=True - # https://bitbucket.org/zzzeek/sqlalchemy/issue/2515 - if 'raise_on_warnings' not in url.query: - engine_args['connect_args']['raise_on_warnings'] = False - - -@_init_connection_args.dispatch_for("mysql+mysqldb") -@_init_connection_args.dispatch_for("mysql+oursql") -def _init_connection_args(url, engine_args, **kw): - # Those drivers require use_unicode=0 to avoid performance drop due - # to internal usage of Python unicode objects in the driver - # http://docs.sqlalchemy.org/en/rel_0_9/dialects/mysql.html - if 'use_unicode' not in url.query: - engine_args['connect_args']['use_unicode'] = 0 - - -@utils.dispatch_for_dialect('*', multiple=True) -def _init_events(engine, thread_checkin=True, connection_trace=False, **kw): - """Set up event listeners for all database backends.""" - - if connection_trace: - _add_trace_comments(engine) - - if thread_checkin: - sqlalchemy.event.listen(engine, 'checkin', _thread_yield) - - -@_init_events.dispatch_for("mysql") -def _init_events(engine, mysql_sql_mode=None, **kw): - """Set up event listeners for MySQL.""" - - if mysql_sql_mode is not None: - @sqlalchemy.event.listens_for(engine, "connect") - def _set_session_sql_mode(dbapi_con, connection_rec): - cursor = dbapi_con.cursor() - cursor.execute("SET SESSION sql_mode = %s", [mysql_sql_mode]) - - @sqlalchemy.event.listens_for(engine, "first_connect") - def _check_effective_sql_mode(dbapi_con, connection_rec): - if mysql_sql_mode is not None: - _set_session_sql_mode(dbapi_con, connection_rec) - - cursor = dbapi_con.cursor() - cursor.execute("SHOW VARIABLES LIKE 'sql_mode'") - realmode = cursor.fetchone() - - if realmode is None: - LOG.warning(_LW('Unable to detect effective SQL mode')) - else: - realmode = realmode[1] - LOG.debug('MySQL server mode set to %s', realmode) - if 'TRADITIONAL' not in realmode.upper() and \ - 'STRICT_ALL_TABLES' not in realmode.upper(): - LOG.warning( - _LW( - "MySQL SQL mode is '%s', " - "consider enabling TRADITIONAL or STRICT_ALL_TABLES"), - realmode) - - -@_init_events.dispatch_for("sqlite") -def _init_events(engine, sqlite_synchronous=True, sqlite_fk=False, **kw): - """Set up event listeners for SQLite. - - This includes several settings made on connections as they are - created, as well as transactional control extensions. - - """ - - def regexp(expr, item): - reg = re.compile(expr) - return reg.search(six.text_type(item)) is not None - - @sqlalchemy.event.listens_for(engine, "connect") - def _sqlite_connect_events(dbapi_con, con_record): - - # Add REGEXP functionality on SQLite connections - dbapi_con.create_function('regexp', 2, regexp) - - if not sqlite_synchronous: - # Switch sqlite connections to non-synchronous mode - dbapi_con.execute("PRAGMA synchronous = OFF") - - # Disable pysqlite's emitting of the BEGIN statement entirely. - # Also stops it from emitting COMMIT before any DDL. - # below, we emit BEGIN ourselves. - # see http://docs.sqlalchemy.org/en/rel_0_9/dialects/\ - # sqlite.html#serializable-isolation-savepoints-transactional-ddl - dbapi_con.isolation_level = None - - if sqlite_fk: - # Ensures that the foreign key constraints are enforced in SQLite. - dbapi_con.execute('pragma foreign_keys=ON') - - @sqlalchemy.event.listens_for(engine, "begin") - def _sqlite_emit_begin(conn): - # emit our own BEGIN, checking for existing - # transactional state - if 'in_transaction' not in conn.info: - conn.execute("BEGIN") - conn.info['in_transaction'] = True - - @sqlalchemy.event.listens_for(engine, "rollback") - @sqlalchemy.event.listens_for(engine, "commit") - def _sqlite_end_transaction(conn): - # remove transactional marker - conn.info.pop('in_transaction', None) - - -def _test_connection(engine, max_retries, retry_interval): - if max_retries == -1: - attempts = itertools.count() - else: - attempts = six.moves.range(max_retries) - # See: http://legacy.python.org/dev/peps/pep-3110/#semantic-changes for - # why we are not using 'de' directly (it can be removed from the local - # scope). - de_ref = None - for attempt in attempts: - try: - return exc_filters.handle_connect_error(engine) - except exception.DBConnectionError as de: - msg = _LW('SQL connection failed. %s attempts left.') - LOG.warning(msg, max_retries - attempt) - time.sleep(retry_interval) - de_ref = de - else: - if de_ref is not None: - six.reraise(type(de_ref), de_ref) - - -class Query(sqlalchemy.orm.query.Query): - """Subclass of sqlalchemy.query with soft_delete() method.""" - def soft_delete(self, synchronize_session='evaluate'): - return self.update({'deleted': literal_column('id'), - 'updated_at': literal_column('updated_at'), - 'deleted_at': timeutils.utcnow()}, - synchronize_session=synchronize_session) - - -class Session(sqlalchemy.orm.session.Session): - """Custom Session class to avoid SqlAlchemy Session monkey patching.""" - - -def get_maker(engine, autocommit=True, expire_on_commit=False): - """Return a SQLAlchemy sessionmaker using the given engine.""" - return sqlalchemy.orm.sessionmaker(bind=engine, - class_=Session, - autocommit=autocommit, - expire_on_commit=expire_on_commit, - query_cls=Query) - - -def _add_trace_comments(engine): - """Add trace comments. - - Augment statements with a trace of the immediate calling code - for a given statement. - """ - - import os - import sys - import traceback - target_paths = set([ - os.path.dirname(sys.modules['oslo.db'].__file__), - os.path.dirname(sys.modules['sqlalchemy'].__file__) - ]) - - @sqlalchemy.event.listens_for(engine, "before_cursor_execute", retval=True) - def before_cursor_execute(conn, cursor, statement, parameters, context, - executemany): - - # NOTE(zzzeek) - if different steps per DB dialect are desirable - # here, switch out on engine.name for now. - stack = traceback.extract_stack() - our_line = None - for idx, (filename, line, method, function) in enumerate(stack): - for tgt in target_paths: - if filename.startswith(tgt): - our_line = idx - break - if our_line: - break - - if our_line: - trace = "; ".join( - "File: %s (%s) %s" % ( - line[0], line[1], line[2] - ) - # include three lines of context. - for line in stack[our_line - 3:our_line] - - ) - statement = "%s -- %s" % (statement, trace) - - return statement, parameters - - -class EngineFacade(object): - """A helper class for removing of global engine instances from oslo.db. - - As a library, oslo.db can't decide where to store/when to create engine - and sessionmaker instances, so this must be left for a target application. - - On the other hand, in order to simplify the adoption of oslo.db changes, - we'll provide a helper class, which creates engine and sessionmaker - on its instantiation and provides get_engine()/get_session() methods - that are compatible with corresponding utility functions that currently - exist in target projects, e.g. in Nova. - - engine/sessionmaker instances will still be global (and they are meant to - be global), but they will be stored in the app context, rather that in the - oslo.db context. - - Note: using of this helper is completely optional and you are encouraged to - integrate engine/sessionmaker instances into your apps any way you like - (e.g. one might want to bind a session to a request context). Two important - things to remember: - - 1. An Engine instance is effectively a pool of DB connections, so it's - meant to be shared (and it's thread-safe). - 2. A Session instance is not meant to be shared and represents a DB - transactional context (i.e. it's not thread-safe). sessionmaker is - a factory of sessions. - - """ - - def __init__(self, sql_connection, slave_connection=None, - sqlite_fk=False, autocommit=True, - expire_on_commit=False, **kwargs): - """Initialize engine and sessionmaker instances. - - :param sql_connection: the connection string for the database to use - :type sql_connection: string - - :param slave_connection: the connection string for the 'slave' database - to use. If not provided, the master database - will be used for all operations. Note: this - is meant to be used for offloading of read - operations to asynchronously replicated slaves - to reduce the load on the master database. - :type slave_connection: string - - :param sqlite_fk: enable foreign keys in SQLite - :type sqlite_fk: bool - - :param autocommit: use autocommit mode for created Session instances - :type autocommit: bool - - :param expire_on_commit: expire session objects on commit - :type expire_on_commit: bool - - Keyword arguments: - - :keyword mysql_sql_mode: the SQL mode to be used for MySQL sessions. - (defaults to TRADITIONAL) - :keyword idle_timeout: timeout before idle sql connections are reaped - (defaults to 3600) - :keyword connection_debug: verbosity of SQL debugging information. - -1=Off, 0=None, 100=Everything (defaults - to 0) - :keyword max_pool_size: maximum number of SQL connections to keep open - in a pool (defaults to SQLAlchemy settings) - :keyword max_overflow: if set, use this value for max_overflow with - sqlalchemy (defaults to SQLAlchemy settings) - :keyword pool_timeout: if set, use this value for pool_timeout with - sqlalchemy (defaults to SQLAlchemy settings) - :keyword sqlite_synchronous: if True, SQLite uses synchronous mode - (defaults to True) - :keyword connection_trace: add python stack traces to SQL as comment - strings (defaults to False) - :keyword max_retries: maximum db connection retries during startup. - (setting -1 implies an infinite retry count) - (defaults to 10) - :keyword retry_interval: interval between retries of opening a sql - connection (defaults to 10) - :keyword thread_checkin: boolean that indicates that between each - engine checkin event a sleep(0) will occur to - allow other greenthreads to run (defaults to - True) - """ - - super(EngineFacade, self).__init__() - - engine_kwargs = { - 'sqlite_fk': sqlite_fk, - 'mysql_sql_mode': kwargs.get('mysql_sql_mode', 'TRADITIONAL'), - 'idle_timeout': kwargs.get('idle_timeout', 3600), - 'connection_debug': kwargs.get('connection_debug', 0), - 'max_pool_size': kwargs.get('max_pool_size'), - 'max_overflow': kwargs.get('max_overflow'), - 'pool_timeout': kwargs.get('pool_timeout'), - 'sqlite_synchronous': kwargs.get('sqlite_synchronous', True), - 'connection_trace': kwargs.get('connection_trace', False), - 'max_retries': kwargs.get('max_retries', 10), - 'retry_interval': kwargs.get('retry_interval', 10), - 'thread_checkin': kwargs.get('thread_checkin', True) - } - maker_kwargs = { - 'autocommit': autocommit, - 'expire_on_commit': expire_on_commit - } - - self._engine = create_engine(sql_connection=sql_connection, - **engine_kwargs) - self._session_maker = get_maker(engine=self._engine, - **maker_kwargs) - if slave_connection: - self._slave_engine = create_engine(sql_connection=slave_connection, - **engine_kwargs) - self._slave_session_maker = get_maker(engine=self._slave_engine, - **maker_kwargs) - else: - self._slave_engine = None - self._slave_session_maker = None - - def get_engine(self, use_slave=False): - """Get the engine instance (note, that it's shared). - - :param use_slave: if possible, use 'slave' database for this engine. - If the connection string for the slave database - wasn't provided, 'master' engine will be returned. - (defaults to False) - :type use_slave: bool - - """ - - if use_slave and self._slave_engine: - return self._slave_engine - - return self._engine - - def get_session(self, use_slave=False, **kwargs): - """Get a Session instance. - - :param use_slave: if possible, use 'slave' database connection for - this session. If the connection string for the - slave database wasn't provided, a session bound - to the 'master' engine will be returned. - (defaults to False) - :type use_slave: bool - - Keyword arugments will be passed to a sessionmaker instance as is (if - passed, they will override the ones used when the sessionmaker instance - was created). See SQLAlchemy Session docs for details. - - """ - - if use_slave and self._slave_session_maker: - return self._slave_session_maker(**kwargs) - - return self._session_maker(**kwargs) - - @classmethod - def from_config(cls, conf, - sqlite_fk=False, autocommit=True, expire_on_commit=False): - """Initialize EngineFacade using oslo.config config instance options. - - :param conf: oslo.config config instance - :type conf: oslo.config.cfg.ConfigOpts - - :param sqlite_fk: enable foreign keys in SQLite - :type sqlite_fk: bool - - :param autocommit: use autocommit mode for created Session instances - :type autocommit: bool - - :param expire_on_commit: expire session objects on commit - :type expire_on_commit: bool - - """ - - conf.register_opts(options.database_opts, 'database') - - return cls(sql_connection=conf.database.connection, - slave_connection=conf.database.slave_connection, - sqlite_fk=sqlite_fk, - autocommit=autocommit, - expire_on_commit=expire_on_commit, - mysql_sql_mode=conf.database.mysql_sql_mode, - idle_timeout=conf.database.idle_timeout, - connection_debug=conf.database.connection_debug, - max_pool_size=conf.database.max_pool_size, - max_overflow=conf.database.max_overflow, - pool_timeout=conf.database.pool_timeout, - sqlite_synchronous=conf.database.sqlite_synchronous, - connection_trace=conf.database.connection_trace, - max_retries=conf.database.max_retries, - retry_interval=conf.database.retry_interval) +from oslo_db.sqlalchemy.session import * # noqa diff --git a/oslo/db/sqlalchemy/test_base.py b/oslo/db/sqlalchemy/test_base.py index d483fadf..57e73a81 100644 --- a/oslo/db/sqlalchemy/test_base.py +++ b/oslo/db/sqlalchemy/test_base.py @@ -1,127 +1,15 @@ -# Copyright (c) 2013 OpenStack Foundation # All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. -import fixtures - -try: - from oslotest import base as test_base -except ImportError: - raise NameError('Oslotest is not installed. Please add oslotest in your' - ' test-requirements') - - -import six - -from oslo.db import exception -from oslo.db.sqlalchemy import provision -from oslo.db.sqlalchemy import session -from oslo.db.sqlalchemy import utils - - -class DbFixture(fixtures.Fixture): - """Basic database fixture. - - Allows to run tests on various db backends, such as SQLite, MySQL and - PostgreSQL. By default use sqlite backend. To override default backend - uri set env variable OS_TEST_DBAPI_CONNECTION with database admin - credentials for specific backend. - """ - - DRIVER = "sqlite" - - # these names are deprecated, and are not used by DbFixture. - # they are here for backwards compatibility with test suites that - # are referring to them directly. - DBNAME = PASSWORD = USERNAME = 'openstack_citest' - - def __init__(self, test): - super(DbFixture, self).__init__() - - self.test = test - - def setUp(self): - super(DbFixture, self).setUp() - - try: - self.provision = provision.ProvisionedDatabase(self.DRIVER) - self.addCleanup(self.provision.dispose) - except exception.BackendNotAvailable: - msg = '%s backend is not available.' % self.DRIVER - return self.test.skip(msg) - else: - self.test.engine = self.provision.engine - self.addCleanup(setattr, self.test, 'engine', None) - self.test.sessionmaker = session.get_maker(self.test.engine) - self.addCleanup(setattr, self.test, 'sessionmaker', None) - - -class DbTestCase(test_base.BaseTestCase): - """Base class for testing of DB code. - - Using `DbFixture`. Intended to be the main database test case to use all - the tests on a given backend with user defined uri. Backend specific - tests should be decorated with `backend_specific` decorator. - """ - - FIXTURE = DbFixture - - def setUp(self): - super(DbTestCase, self).setUp() - self.useFixture(self.FIXTURE(self)) - - -class OpportunisticTestCase(DbTestCase): - """Placeholder for backwards compatibility.""" - -ALLOWED_DIALECTS = ['sqlite', 'mysql', 'postgresql'] - - -def backend_specific(*dialects): - """Decorator to skip backend specific tests on inappropriate engines. - - ::dialects: list of dialects names under which the test will be launched. - """ - def wrap(f): - @six.wraps(f) - def ins_wrap(self): - if not set(dialects).issubset(ALLOWED_DIALECTS): - raise ValueError( - "Please use allowed dialects: %s" % ALLOWED_DIALECTS) - if self.engine.name not in dialects: - msg = ('The test "%s" can be run ' - 'only on %s. Current engine is %s.') - args = (utils.get_callable_name(f), ' '.join(dialects), - self.engine.name) - self.skip(msg % args) - else: - return f(self) - return ins_wrap - return wrap - - -class MySQLOpportunisticFixture(DbFixture): - DRIVER = 'mysql' - - -class PostgreSQLOpportunisticFixture(DbFixture): - DRIVER = 'postgresql' - - -class MySQLOpportunisticTestCase(OpportunisticTestCase): - FIXTURE = MySQLOpportunisticFixture - - -class PostgreSQLOpportunisticTestCase(OpportunisticTestCase): - FIXTURE = PostgreSQLOpportunisticFixture +from oslo_db.sqlalchemy.test_base import * # noqa diff --git a/oslo/db/sqlalchemy/test_migrations.py b/oslo/db/sqlalchemy/test_migrations.py index 4d9146aa..bfbf0a8a 100644 --- a/oslo/db/sqlalchemy/test_migrations.py +++ b/oslo/db/sqlalchemy/test_migrations.py @@ -1,5 +1,3 @@ -# Copyright 2010-2011 OpenStack Foundation -# Copyright 2012-2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -14,600 +12,4 @@ # License for the specific language governing permissions and limitations # under the License. -import abc -import collections -import logging -import pprint - -import alembic -import alembic.autogenerate -import alembic.migration -import pkg_resources as pkg -import six -import sqlalchemy -from sqlalchemy.engine import reflection -import sqlalchemy.exc -from sqlalchemy import schema -import sqlalchemy.sql.expression as expr -import sqlalchemy.types as types - -from oslo.db._i18n import _LE -from oslo.db import exception as exc -from oslo.db.sqlalchemy import utils - -LOG = logging.getLogger(__name__) - - -@six.add_metaclass(abc.ABCMeta) -class WalkVersionsMixin(object): - """Test mixin to check upgrade and downgrade ability of migration. - - This is only suitable for testing of migrate_ migration scripts. An - abstract class mixin. `INIT_VERSION`, `REPOSITORY` and `migration_api` - attributes must be implemented in subclasses. - - .. _auxiliary-dynamic-methods: Auxiliary Methods - - Auxiliary Methods: - - `migrate_up` and `migrate_down` instance methods of the class can be - used with auxiliary methods named `_pre_upgrade_`, - `_check_`, `_post_downgrade_`. The methods - intended to check applied changes for correctness of data operations. - This methods should be implemented for every particular revision - which you want to check with data. Implementation recommendations for - `_pre_upgrade_`, `_check_`, - `_post_downgrade_` implementation: - - * `_pre_upgrade_`: provide a data appropriate to - a next revision. Should be used an id of revision which - going to be applied. - - * `_check_`: Insert, select, delete operations - with newly applied changes. The data provided by - `_pre_upgrade_` will be used. - - * `_post_downgrade_`: check for absence - (inability to use) changes provided by reverted revision. - - Execution order of auxiliary methods when revision is upgrading: - - `_pre_upgrade_###` => `upgrade` => `_check_###` - - Execution order of auxiliary methods when revision is downgrading: - - `downgrade` => `_post_downgrade_###` - - .. _migrate: https://sqlalchemy-migrate.readthedocs.org/en/latest/ - - """ - - @abc.abstractproperty - def INIT_VERSION(self): - """Initial version of a migration repository. - - Can be different from 0, if a migrations were squashed. - - :rtype: int - """ - pass - - @abc.abstractproperty - def REPOSITORY(self): - """Allows basic manipulation with migration repository. - - :returns: `migrate.versioning.repository.Repository` subclass. - """ - pass - - @abc.abstractproperty - def migration_api(self): - """Provides API for upgrading, downgrading and version manipulations. - - :returns: `migrate.api` or overloaded analog. - """ - pass - - @abc.abstractproperty - def migrate_engine(self): - """Provides engine instance. - - Should be the same instance as used when migrations are applied. In - most cases, the `engine` attribute provided by the test class in a - `setUp` method will work. - - Example of implementation: - - def migrate_engine(self): - return self.engine - - :returns: sqlalchemy engine instance - """ - pass - - def _walk_versions(self, snake_walk=False, downgrade=True): - """Check if migration upgrades and downgrades successfully. - - DEPRECATED: this function is deprecated and will be removed from - oslo.db in a few releases. Please use walk_versions() method instead. - """ - self.walk_versions(snake_walk, downgrade) - - def _migrate_down(self, version, with_data=False): - """Migrate down to a previous version of the db. - - DEPRECATED: this function is deprecated and will be removed from - oslo.db in a few releases. Please use migrate_down() method instead. - """ - return self.migrate_down(version, with_data) - - def _migrate_up(self, version, with_data=False): - """Migrate up to a new version of the db. - - DEPRECATED: this function is deprecated and will be removed from - oslo.db in a few releases. Please use migrate_up() method instead. - """ - self.migrate_up(version, with_data) - - def walk_versions(self, snake_walk=False, downgrade=True): - """Check if migration upgrades and downgrades successfully. - - Determine the latest version script from the repo, then - upgrade from 1 through to the latest, with no data - in the databases. This just checks that the schema itself - upgrades successfully. - - `walk_versions` calls `migrate_up` and `migrate_down` with - `with_data` argument to check changes with data, but these methods - can be called without any extra check outside of `walk_versions` - method. - - :param snake_walk: enables checking that each individual migration can - be upgraded/downgraded by itself. - - If we have ordered migrations 123abc, 456def, 789ghi and we run - upgrading with the `snake_walk` argument set to `True`, the - migrations will be applied in the following order: - - `123abc => 456def => 123abc => - 456def => 789ghi => 456def => 789ghi` - - :type snake_walk: bool - :param downgrade: Check downgrade behavior if True. - :type downgrade: bool - """ - - # Place the database under version control - self.migration_api.version_control(self.migrate_engine, - self.REPOSITORY, - self.INIT_VERSION) - self.assertEqual(self.INIT_VERSION, - self.migration_api.db_version(self.migrate_engine, - self.REPOSITORY)) - - LOG.debug('latest version is %s', self.REPOSITORY.latest) - versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1) - - for version in versions: - # upgrade -> downgrade -> upgrade - self.migrate_up(version, with_data=True) - if snake_walk: - downgraded = self.migrate_down(version - 1, with_data=True) - if downgraded: - self.migrate_up(version) - - if downgrade: - # Now walk it back down to 0 from the latest, testing - # the downgrade paths. - for version in reversed(versions): - # downgrade -> upgrade -> downgrade - downgraded = self.migrate_down(version - 1) - - if snake_walk and downgraded: - self.migrate_up(version) - self.migrate_down(version - 1) - - def migrate_down(self, version, with_data=False): - """Migrate down to a previous version of the db. - - :param version: id of revision to downgrade. - :type version: str - :keyword with_data: Whether to verify the absence of changes from - migration(s) being downgraded, see - :ref:`auxiliary-dynamic-methods `. - :type with_data: Bool - """ - - try: - self.migration_api.downgrade(self.migrate_engine, - self.REPOSITORY, version) - except NotImplementedError: - # NOTE(sirp): some migrations, namely release-level - # migrations, don't support a downgrade. - return False - - self.assertEqual(version, self.migration_api.db_version( - self.migrate_engine, self.REPOSITORY)) - - # NOTE(sirp): `version` is what we're downgrading to (i.e. the 'target' - # version). So if we have any downgrade checks, they need to be run for - # the previous (higher numbered) migration. - if with_data: - post_downgrade = getattr( - self, "_post_downgrade_%03d" % (version + 1), None) - if post_downgrade: - post_downgrade(self.migrate_engine) - - return True - - def migrate_up(self, version, with_data=False): - """Migrate up to a new version of the db. - - :param version: id of revision to upgrade. - :type version: str - :keyword with_data: Whether to verify the applied changes with data, - see :ref:`auxiliary-dynamic-methods `. - :type with_data: Bool - """ - # NOTE(sdague): try block is here because it's impossible to debug - # where a failed data migration happens otherwise - try: - if with_data: - data = None - pre_upgrade = getattr( - self, "_pre_upgrade_%03d" % version, None) - if pre_upgrade: - data = pre_upgrade(self.migrate_engine) - - self.migration_api.upgrade(self.migrate_engine, - self.REPOSITORY, version) - self.assertEqual(version, - self.migration_api.db_version(self.migrate_engine, - self.REPOSITORY)) - if with_data: - check = getattr(self, "_check_%03d" % version, None) - if check: - check(self.migrate_engine, data) - except exc.DbMigrationError: - msg = _LE("Failed to migrate to version %(ver)s on engine %(eng)s") - LOG.error(msg, {"ver": version, "eng": self.migrate_engine}) - raise - - -@six.add_metaclass(abc.ABCMeta) -class ModelsMigrationsSync(object): - """A helper class for comparison of DB migration scripts and models. - - It's intended to be inherited by test cases in target projects. They have - to provide implementations for methods used internally in the test (as - we have no way to implement them here). - - test_model_sync() will run migration scripts for the engine provided and - then compare the given metadata to the one reflected from the database. - The difference between MODELS and MIGRATION scripts will be printed and - the test will fail, if the difference is not empty. The return value is - really a list of actions, that should be performed in order to make the - current database schema state (i.e. migration scripts) consistent with - models definitions. It's left up to developers to analyze the output and - decide whether the models definitions or the migration scripts should be - modified to make them consistent. - - Output:: - - [( - 'add_table', - description of the table from models - ), - ( - 'remove_table', - description of the table from database - ), - ( - 'add_column', - schema, - table name, - column description from models - ), - ( - 'remove_column', - schema, - table name, - column description from database - ), - ( - 'add_index', - description of the index from models - ), - ( - 'remove_index', - description of the index from database - ), - ( - 'add_constraint', - description of constraint from models - ), - ( - 'remove_constraint, - description of constraint from database - ), - ( - 'modify_nullable', - schema, - table name, - column name, - { - 'existing_type': type of the column from database, - 'existing_server_default': default value from database - }, - nullable from database, - nullable from models - ), - ( - 'modify_type', - schema, - table name, - column name, - { - 'existing_nullable': database nullable, - 'existing_server_default': default value from database - }, - database column type, - type of the column from models - ), - ( - 'modify_default', - schema, - table name, - column name, - { - 'existing_nullable': database nullable, - 'existing_type': type of the column from database - }, - connection column default value, - default from models - )] - - Method include_object() can be overridden to exclude some tables from - comparison (e.g. migrate_repo). - - """ - - @abc.abstractmethod - def db_sync(self, engine): - """Run migration scripts with the given engine instance. - - This method must be implemented in subclasses and run migration scripts - for a DB the given engine is connected to. - - """ - - @abc.abstractmethod - def get_engine(self): - """Return the engine instance to be used when running tests. - - This method must be implemented in subclasses and return an engine - instance to be used when running tests. - - """ - - @abc.abstractmethod - def get_metadata(self): - """Return the metadata instance to be used for schema comparison. - - This method must be implemented in subclasses and return the metadata - instance attached to the BASE model. - - """ - - def include_object(self, object_, name, type_, reflected, compare_to): - """Return True for objects that should be compared. - - :param object_: a SchemaItem object such as a Table or Column object - :param name: the name of the object - :param type_: a string describing the type of object (e.g. "table") - :param reflected: True if the given object was produced based on - table reflection, False if it's from a local - MetaData object - :param compare_to: the object being compared against, if available, - else None - - """ - - return True - - def compare_type(self, ctxt, insp_col, meta_col, insp_type, meta_type): - """Return True if types are different, False if not. - - Return None to allow the default implementation to compare these types. - - :param ctxt: alembic MigrationContext instance - :param insp_col: reflected column - :param meta_col: column from model - :param insp_type: reflected column type - :param meta_type: column type from model - - """ - - # some backends (e.g. mysql) don't provide native boolean type - BOOLEAN_METADATA = (types.BOOLEAN, types.Boolean) - BOOLEAN_SQL = BOOLEAN_METADATA + (types.INTEGER, types.Integer) - - if issubclass(type(meta_type), BOOLEAN_METADATA): - return not issubclass(type(insp_type), BOOLEAN_SQL) - - return None # tells alembic to use the default comparison method - - def compare_server_default(self, ctxt, ins_col, meta_col, - insp_def, meta_def, rendered_meta_def): - """Compare default values between model and db table. - - Return True if the defaults are different, False if not, or None to - allow the default implementation to compare these defaults. - - :param ctxt: alembic MigrationContext instance - :param insp_col: reflected column - :param meta_col: column from model - :param insp_def: reflected column default value - :param meta_def: column default value from model - :param rendered_meta_def: rendered column default value (from model) - - """ - return self._compare_server_default(ctxt.bind, meta_col, insp_def, - meta_def) - - @utils.DialectFunctionDispatcher.dispatch_for_dialect("*") - def _compare_server_default(bind, meta_col, insp_def, meta_def): - pass - - @_compare_server_default.dispatch_for('mysql') - def _compare_server_default(bind, meta_col, insp_def, meta_def): - if isinstance(meta_col.type, sqlalchemy.Boolean): - if meta_def is None or insp_def is None: - return meta_def != insp_def - return not ( - isinstance(meta_def.arg, expr.True_) and insp_def == "'1'" or - isinstance(meta_def.arg, expr.False_) and insp_def == "'0'" - ) - - if isinstance(meta_col.type, sqlalchemy.Integer): - if meta_def is None or insp_def is None: - return meta_def != insp_def - return meta_def.arg != insp_def.split("'")[1] - - @_compare_server_default.dispatch_for('postgresql') - def _compare_server_default(bind, meta_col, insp_def, meta_def): - if isinstance(meta_col.type, sqlalchemy.Enum): - if meta_def is None or insp_def is None: - return meta_def != insp_def - return insp_def != "'%s'::%s" % (meta_def.arg, meta_col.type.name) - elif isinstance(meta_col.type, sqlalchemy.String): - if meta_def is None or insp_def is None: - return meta_def != insp_def - return insp_def != "'%s'::character varying" % meta_def.arg - - def _cleanup(self): - engine = self.get_engine() - with engine.begin() as conn: - inspector = reflection.Inspector.from_engine(engine) - metadata = schema.MetaData() - tbs = [] - all_fks = [] - - for table_name in inspector.get_table_names(): - fks = [] - for fk in inspector.get_foreign_keys(table_name): - if not fk['name']: - continue - fks.append( - schema.ForeignKeyConstraint((), (), name=fk['name']) - ) - table = schema.Table(table_name, metadata, *fks) - tbs.append(table) - all_fks.extend(fks) - - for fkc in all_fks: - conn.execute(schema.DropConstraint(fkc)) - - for table in tbs: - conn.execute(schema.DropTable(table)) - - FKInfo = collections.namedtuple('fk_info', ['constrained_columns', - 'referred_table', - 'referred_columns']) - - def check_foreign_keys(self, metadata, bind): - """Compare foreign keys between model and db table. - - :returns: a list that contains information about: - - * should be a new key added or removed existing, - * name of that key, - * source table, - * referred table, - * constrained columns, - * referred columns - - Output:: - - [('drop_key', - 'testtbl_fk_check_fkey', - 'testtbl', - fk_info(constrained_columns=(u'fk_check',), - referred_table=u'table', - referred_columns=(u'fk_check',)))] - - """ - - diff = [] - insp = sqlalchemy.engine.reflection.Inspector.from_engine(bind) - # Get all tables from db - db_tables = insp.get_table_names() - # Get all tables from models - model_tables = metadata.tables - for table in db_tables: - if table not in model_tables: - continue - # Get all necessary information about key of current table from db - fk_db = dict((self._get_fk_info_from_db(i), i['name']) - for i in insp.get_foreign_keys(table)) - fk_db_set = set(fk_db.keys()) - # Get all necessary information about key of current table from - # models - fk_models = dict((self._get_fk_info_from_model(fk), fk) - for fk in model_tables[table].foreign_keys) - fk_models_set = set(fk_models.keys()) - for key in (fk_db_set - fk_models_set): - diff.append(('drop_key', fk_db[key], table, key)) - LOG.info(("Detected removed foreign key %(fk)r on " - "table %(table)r"), {'fk': fk_db[key], - 'table': table}) - for key in (fk_models_set - fk_db_set): - diff.append(('add_key', fk_models[key], table, key)) - LOG.info(( - "Detected added foreign key for column %(fk)r on table " - "%(table)r"), {'fk': fk_models[key].column.name, - 'table': table}) - return diff - - def _get_fk_info_from_db(self, fk): - return self.FKInfo(tuple(fk['constrained_columns']), - fk['referred_table'], - tuple(fk['referred_columns'])) - - def _get_fk_info_from_model(self, fk): - return self.FKInfo((fk.parent.name,), fk.column.table.name, - (fk.column.name,)) - - def test_models_sync(self): - # recent versions of sqlalchemy and alembic are needed for running of - # this test, but we already have them in requirements - try: - pkg.require('sqlalchemy>=0.8.4', 'alembic>=0.6.2') - except (pkg.VersionConflict, pkg.DistributionNotFound) as e: - self.skipTest('sqlalchemy>=0.8.4 and alembic>=0.6.3 are required' - ' for running of this test: %s' % e) - - # drop all tables after a test run - self.addCleanup(self._cleanup) - - # run migration scripts - self.db_sync(self.get_engine()) - - with self.get_engine().connect() as conn: - opts = { - 'include_object': self.include_object, - 'compare_type': self.compare_type, - 'compare_server_default': self.compare_server_default, - } - mc = alembic.migration.MigrationContext.configure(conn, opts=opts) - - # compare schemas and fail with diff, if it's not empty - diff1 = alembic.autogenerate.compare_metadata(mc, - self.get_metadata()) - diff2 = self.check_foreign_keys(self.get_metadata(), - self.get_engine()) - diff = diff1 + diff2 - if diff: - msg = pprint.pformat(diff, indent=2, width=20) - self.fail( - "Models and migration scripts aren't in sync:\n%s" % msg) +from oslo_db.sqlalchemy.test_migrations import * # noqa diff --git a/oslo/db/sqlalchemy/utils.py b/oslo/db/sqlalchemy/utils.py index dd891b6f..a51cccdf 100644 --- a/oslo/db/sqlalchemy/utils.py +++ b/oslo/db/sqlalchemy/utils.py @@ -1,7 +1,3 @@ -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. -# Copyright 2010-2011 OpenStack Foundation. -# Copyright 2012 Justin Santa Barbara # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -16,997 +12,4 @@ # License for the specific language governing permissions and limitations # under the License. -import collections -import logging -import re - -from oslo.utils import timeutils -import six -import sqlalchemy -from sqlalchemy import Boolean -from sqlalchemy import CheckConstraint -from sqlalchemy import Column -from sqlalchemy.engine import Connectable -from sqlalchemy.engine import reflection -from sqlalchemy.engine import url as sa_url -from sqlalchemy.ext.compiler import compiles -from sqlalchemy import func -from sqlalchemy import Index -from sqlalchemy import Integer -from sqlalchemy import MetaData -from sqlalchemy.sql.expression import literal_column -from sqlalchemy.sql.expression import UpdateBase -from sqlalchemy.sql import text -from sqlalchemy import String -from sqlalchemy import Table -from sqlalchemy.types import NullType - -from oslo.db import exception -from oslo.db._i18n import _, _LI, _LW -from oslo.db.sqlalchemy import models - -# NOTE(ochuprykov): Add references for backwards compatibility -InvalidSortKey = exception.InvalidSortKey -ColumnError = exception.ColumnError - -LOG = logging.getLogger(__name__) - -_DBURL_REGEX = re.compile(r"[^:]+://([^:]+):([^@]+)@.+") - - -def get_callable_name(function): - # TODO(harlowja): Replace this once - # it is possible to use https://review.openstack.org/#/c/122495/ which is - # a more complete and expansive module that does a similar thing... - try: - method_self = six.get_method_self(function) - except AttributeError: - method_self = None - if method_self is not None: - if isinstance(method_self, six.class_types): - im_class = method_self - else: - im_class = type(method_self) - try: - parts = (im_class.__module__, function.__qualname__) - except AttributeError: - parts = (im_class.__module__, im_class.__name__, function.__name__) - else: - try: - parts = (function.__module__, function.__qualname__) - except AttributeError: - parts = (function.__module__, function.__name__) - return '.'.join(parts) - - -def sanitize_db_url(url): - match = _DBURL_REGEX.match(url) - if match: - return '%s****:****%s' % (url[:match.start(1)], url[match.end(2):]) - return url - - -# copy from glance/db/sqlalchemy/api.py -def paginate_query(query, model, limit, sort_keys, marker=None, - sort_dir=None, sort_dirs=None): - """Returns a query with sorting / pagination criteria added. - - Pagination works by requiring a unique sort_key, specified by sort_keys. - (If sort_keys is not unique, then we risk looping through values.) - We use the last row in the previous page as the 'marker' for pagination. - So we must return values that follow the passed marker in the order. - With a single-valued sort_key, this would be easy: sort_key > X. - With a compound-values sort_key, (k1, k2, k3) we must do this to repeat - the lexicographical ordering: - (k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3) - - We also have to cope with different sort_directions. - - Typically, the id of the last row is used as the client-facing pagination - marker, then the actual marker object must be fetched from the db and - passed in to us as marker. - - :param query: the query object to which we should add paging/sorting - :param model: the ORM model class - :param limit: maximum number of items to return - :param sort_keys: array of attributes by which results should be sorted - :param marker: the last item of the previous page; we returns the next - results after this value. - :param sort_dir: direction in which results should be sorted (asc, desc) - :param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys - - :rtype: sqlalchemy.orm.query.Query - :return: The query with sorting/pagination added. - """ - - if 'id' not in sort_keys: - # TODO(justinsb): If this ever gives a false-positive, check - # the actual primary key, rather than assuming its id - LOG.warning(_LW('Id not in sort_keys; is sort_keys unique?')) - - assert(not (sort_dir and sort_dirs)) - - # Default the sort direction to ascending - if sort_dirs is None and sort_dir is None: - sort_dir = 'asc' - - # Ensure a per-column sort direction - if sort_dirs is None: - sort_dirs = [sort_dir for _sort_key in sort_keys] - - assert(len(sort_dirs) == len(sort_keys)) - - # Add sorting - for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs): - try: - sort_dir_func = { - 'asc': sqlalchemy.asc, - 'desc': sqlalchemy.desc, - }[current_sort_dir] - except KeyError: - raise ValueError(_("Unknown sort direction, " - "must be 'desc' or 'asc'")) - try: - sort_key_attr = getattr(model, current_sort_key) - except AttributeError: - raise exception.InvalidSortKey() - query = query.order_by(sort_dir_func(sort_key_attr)) - - # Add pagination - if marker is not None: - marker_values = [] - for sort_key in sort_keys: - v = getattr(marker, sort_key) - marker_values.append(v) - - # Build up an array of sort criteria as in the docstring - criteria_list = [] - for i in range(len(sort_keys)): - crit_attrs = [] - for j in range(i): - model_attr = getattr(model, sort_keys[j]) - crit_attrs.append((model_attr == marker_values[j])) - - model_attr = getattr(model, sort_keys[i]) - if sort_dirs[i] == 'desc': - crit_attrs.append((model_attr < marker_values[i])) - else: - crit_attrs.append((model_attr > marker_values[i])) - - criteria = sqlalchemy.sql.and_(*crit_attrs) - criteria_list.append(criteria) - - f = sqlalchemy.sql.or_(*criteria_list) - query = query.filter(f) - - if limit is not None: - query = query.limit(limit) - - return query - - -def _read_deleted_filter(query, db_model, deleted): - if 'deleted' not in db_model.__table__.columns: - raise ValueError(_("There is no `deleted` column in `%s` table. " - "Project doesn't use soft-deleted feature.") - % db_model.__name__) - - default_deleted_value = db_model.__table__.c.deleted.default.arg - if deleted: - query = query.filter(db_model.deleted != default_deleted_value) - else: - query = query.filter(db_model.deleted == default_deleted_value) - return query - - -def _project_filter(query, db_model, project_id): - if 'project_id' not in db_model.__table__.columns: - raise ValueError(_("There is no `project_id` column in `%s` table.") - % db_model.__name__) - - if isinstance(project_id, (list, tuple, set)): - query = query.filter(db_model.project_id.in_(project_id)) - else: - query = query.filter(db_model.project_id == project_id) - - return query - - -def model_query(model, session, args=None, **kwargs): - """Query helper for db.sqlalchemy api methods. - - This accounts for `deleted` and `project_id` fields. - - :param model: Model to query. Must be a subclass of ModelBase. - :type model: models.ModelBase - - :param session: The session to use. - :type session: sqlalchemy.orm.session.Session - - :param args: Arguments to query. If None - model is used. - :type args: tuple - - Keyword arguments: - - :keyword project_id: If present, allows filtering by project_id(s). - Can be either a project_id value, or an iterable of - project_id values, or None. If an iterable is passed, - only rows whose project_id column value is on the - `project_id` list will be returned. If None is passed, - only rows which are not bound to any project, will be - returned. - :type project_id: iterable, - model.__table__.columns.project_id.type, - None type - - :keyword deleted: If present, allows filtering by deleted field. - If True is passed, only deleted entries will be - returned, if False - only existing entries. - :type deleted: bool - - - Usage: - - .. code-block:: python - - from oslo.db.sqlalchemy import utils - - - def get_instance_by_uuid(uuid): - session = get_session() - with session.begin() - return (utils.model_query(models.Instance, session=session) - .filter(models.Instance.uuid == uuid) - .first()) - - def get_nodes_stat(): - data = (Node.id, Node.cpu, Node.ram, Node.hdd) - - session = get_session() - with session.begin() - return utils.model_query(Node, session=session, args=data).all() - - Also you can create your own helper, based on ``utils.model_query()``. - For example, it can be useful if you plan to use ``project_id`` and - ``deleted`` parameters from project's ``context`` - - .. code-block:: python - - from oslo.db.sqlalchemy import utils - - - def _model_query(context, model, session=None, args=None, - project_id=None, project_only=False, - read_deleted=None): - - # We suppose, that functions ``_get_project_id()`` and - # ``_get_deleted()`` should handle passed parameters and - # context object (for example, decide, if we need to restrict a user - # to query his own entries by project_id or only allow admin to read - # deleted entries). For return values, we expect to get - # ``project_id`` and ``deleted``, which are suitable for the - # ``model_query()`` signature. - kwargs = {} - if project_id is not None: - kwargs['project_id'] = _get_project_id(context, project_id, - project_only) - if read_deleted is not None: - kwargs['deleted'] = _get_deleted_dict(context, read_deleted) - session = session or get_session() - - with session.begin(): - return utils.model_query(model, session=session, - args=args, **kwargs) - - def get_instance_by_uuid(context, uuid): - return (_model_query(context, models.Instance, read_deleted='yes') - .filter(models.Instance.uuid == uuid) - .first()) - - def get_nodes_data(context, project_id, project_only='allow_none'): - data = (Node.id, Node.cpu, Node.ram, Node.hdd) - - return (_model_query(context, Node, args=data, project_id=project_id, - project_only=project_only) - .all()) - - """ - - if not issubclass(model, models.ModelBase): - raise TypeError(_("model should be a subclass of ModelBase")) - - query = session.query(model) if not args else session.query(*args) - if 'deleted' in kwargs: - query = _read_deleted_filter(query, model, kwargs['deleted']) - if 'project_id' in kwargs: - query = _project_filter(query, model, kwargs['project_id']) - - return query - - -def get_table(engine, name): - """Returns an sqlalchemy table dynamically from db. - - Needed because the models don't work for us in migrations - as models will be far out of sync with the current data. - - .. warning:: - - Do not use this method when creating ForeignKeys in database migrations - because sqlalchemy needs the same MetaData object to hold information - about the parent table and the reference table in the ForeignKey. This - method uses a unique MetaData object per table object so it won't work - with ForeignKey creation. - """ - metadata = MetaData() - metadata.bind = engine - return Table(name, metadata, autoload=True) - - -class InsertFromSelect(UpdateBase): - """Form the base for `INSERT INTO table (SELECT ... )` statement.""" - def __init__(self, table, select): - self.table = table - self.select = select - - -@compiles(InsertFromSelect) -def visit_insert_from_select(element, compiler, **kw): - """Form the `INSERT INTO table (SELECT ... )` statement.""" - return "INSERT INTO %s %s" % ( - compiler.process(element.table, asfrom=True), - compiler.process(element.select)) - - -def _get_not_supported_column(col_name_col_instance, column_name): - try: - column = col_name_col_instance[column_name] - except KeyError: - msg = _("Please specify column %s in col_name_col_instance " - "param. It is required because column has unsupported " - "type by SQLite.") - raise exception.ColumnError(msg % column_name) - - if not isinstance(column, Column): - msg = _("col_name_col_instance param has wrong type of " - "column instance for column %s It should be instance " - "of sqlalchemy.Column.") - raise exception.ColumnError(msg % column_name) - return column - - -def drop_old_duplicate_entries_from_table(migrate_engine, table_name, - use_soft_delete, *uc_column_names): - """Drop all old rows having the same values for columns in uc_columns. - - This method drop (or mark ad `deleted` if use_soft_delete is True) old - duplicate rows form table with name `table_name`. - - :param migrate_engine: Sqlalchemy engine - :param table_name: Table with duplicates - :param use_soft_delete: If True - values will be marked as `deleted`, - if False - values will be removed from table - :param uc_column_names: Unique constraint columns - """ - meta = MetaData() - meta.bind = migrate_engine - - table = Table(table_name, meta, autoload=True) - columns_for_group_by = [table.c[name] for name in uc_column_names] - - columns_for_select = [func.max(table.c.id)] - columns_for_select.extend(columns_for_group_by) - - duplicated_rows_select = sqlalchemy.sql.select( - columns_for_select, group_by=columns_for_group_by, - having=func.count(table.c.id) > 1) - - for row in migrate_engine.execute(duplicated_rows_select).fetchall(): - # NOTE(boris-42): Do not remove row that has the biggest ID. - delete_condition = table.c.id != row[0] - is_none = None # workaround for pyflakes - delete_condition &= table.c.deleted_at == is_none - for name in uc_column_names: - delete_condition &= table.c[name] == row[name] - - rows_to_delete_select = sqlalchemy.sql.select( - [table.c.id]).where(delete_condition) - for row in migrate_engine.execute(rows_to_delete_select).fetchall(): - LOG.info(_LI("Deleting duplicated row with id: %(id)s from table: " - "%(table)s"), dict(id=row[0], table=table_name)) - - if use_soft_delete: - delete_statement = table.update().\ - where(delete_condition).\ - values({ - 'deleted': literal_column('id'), - 'updated_at': literal_column('updated_at'), - 'deleted_at': timeutils.utcnow() - }) - else: - delete_statement = table.delete().where(delete_condition) - migrate_engine.execute(delete_statement) - - -def _get_default_deleted_value(table): - if isinstance(table.c.id.type, Integer): - return 0 - if isinstance(table.c.id.type, String): - return "" - raise exception.ColumnError(_("Unsupported id columns type")) - - -def _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes): - table = get_table(migrate_engine, table_name) - - insp = reflection.Inspector.from_engine(migrate_engine) - real_indexes = insp.get_indexes(table_name) - existing_index_names = dict( - [(index['name'], index['column_names']) for index in real_indexes]) - - # NOTE(boris-42): Restore indexes on `deleted` column - for index in indexes: - if 'deleted' not in index['column_names']: - continue - name = index['name'] - if name in existing_index_names: - column_names = [table.c[c] for c in existing_index_names[name]] - old_index = Index(name, *column_names, unique=index["unique"]) - old_index.drop(migrate_engine) - - column_names = [table.c[c] for c in index['column_names']] - new_index = Index(index["name"], *column_names, unique=index["unique"]) - new_index.create(migrate_engine) - - -def change_deleted_column_type_to_boolean(migrate_engine, table_name, - **col_name_col_instance): - if migrate_engine.name == "sqlite": - return _change_deleted_column_type_to_boolean_sqlite( - migrate_engine, table_name, **col_name_col_instance) - insp = reflection.Inspector.from_engine(migrate_engine) - indexes = insp.get_indexes(table_name) - - table = get_table(migrate_engine, table_name) - - old_deleted = Column('old_deleted', Boolean, default=False) - old_deleted.create(table, populate_default=False) - - table.update().\ - where(table.c.deleted == table.c.id).\ - values(old_deleted=True).\ - execute() - - table.c.deleted.drop() - table.c.old_deleted.alter(name="deleted") - - _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes) - - -def _change_deleted_column_type_to_boolean_sqlite(migrate_engine, table_name, - **col_name_col_instance): - insp = reflection.Inspector.from_engine(migrate_engine) - table = get_table(migrate_engine, table_name) - - columns = [] - for column in table.columns: - column_copy = None - if column.name != "deleted": - if isinstance(column.type, NullType): - column_copy = _get_not_supported_column(col_name_col_instance, - column.name) - else: - column_copy = column.copy() - else: - column_copy = Column('deleted', Boolean, default=0) - columns.append(column_copy) - - constraints = [constraint.copy() for constraint in table.constraints] - - meta = table.metadata - new_table = Table(table_name + "__tmp__", meta, - *(columns + constraints)) - new_table.create() - - indexes = [] - for index in insp.get_indexes(table_name): - column_names = [new_table.c[c] for c in index['column_names']] - indexes.append(Index(index["name"], *column_names, - unique=index["unique"])) - - c_select = [] - for c in table.c: - if c.name != "deleted": - c_select.append(c) - else: - c_select.append(table.c.deleted == table.c.id) - - ins = InsertFromSelect(new_table, sqlalchemy.sql.select(c_select)) - migrate_engine.execute(ins) - - table.drop() - for index in indexes: - index.create(migrate_engine) - - new_table.rename(table_name) - new_table.update().\ - where(new_table.c.deleted == new_table.c.id).\ - values(deleted=True).\ - execute() - - -def change_deleted_column_type_to_id_type(migrate_engine, table_name, - **col_name_col_instance): - if migrate_engine.name == "sqlite": - return _change_deleted_column_type_to_id_type_sqlite( - migrate_engine, table_name, **col_name_col_instance) - insp = reflection.Inspector.from_engine(migrate_engine) - indexes = insp.get_indexes(table_name) - - table = get_table(migrate_engine, table_name) - - new_deleted = Column('new_deleted', table.c.id.type, - default=_get_default_deleted_value(table)) - new_deleted.create(table, populate_default=True) - - deleted = True # workaround for pyflakes - table.update().\ - where(table.c.deleted == deleted).\ - values(new_deleted=table.c.id).\ - execute() - table.c.deleted.drop() - table.c.new_deleted.alter(name="deleted") - - _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes) - - -def _change_deleted_column_type_to_id_type_sqlite(migrate_engine, table_name, - **col_name_col_instance): - # NOTE(boris-42): sqlaclhemy-migrate can't drop column with check - # constraints in sqlite DB and our `deleted` column has - # 2 check constraints. So there is only one way to remove - # these constraints: - # 1) Create new table with the same columns, constraints - # and indexes. (except deleted column). - # 2) Copy all data from old to new table. - # 3) Drop old table. - # 4) Rename new table to old table name. - insp = reflection.Inspector.from_engine(migrate_engine) - meta = MetaData(bind=migrate_engine) - table = Table(table_name, meta, autoload=True) - default_deleted_value = _get_default_deleted_value(table) - - columns = [] - for column in table.columns: - column_copy = None - if column.name != "deleted": - if isinstance(column.type, NullType): - column_copy = _get_not_supported_column(col_name_col_instance, - column.name) - else: - column_copy = column.copy() - else: - column_copy = Column('deleted', table.c.id.type, - default=default_deleted_value) - columns.append(column_copy) - - def is_deleted_column_constraint(constraint): - # NOTE(boris-42): There is no other way to check is CheckConstraint - # associated with deleted column. - if not isinstance(constraint, CheckConstraint): - return False - sqltext = str(constraint.sqltext) - # NOTE(I159): in order to omit the CHECK constraint corresponding - # to `deleted` column we have to test these patterns which may - # vary depending on the SQLAlchemy version used. - constraint_markers = ( - "deleted in (0, 1)", - "deleted IN (:deleted_1, :deleted_2)", - "deleted IN (:param_1, :param_2)" - ) - return any(sqltext.endswith(marker) for marker in constraint_markers) - - constraints = [] - for constraint in table.constraints: - if not is_deleted_column_constraint(constraint): - constraints.append(constraint.copy()) - - new_table = Table(table_name + "__tmp__", meta, - *(columns + constraints)) - new_table.create() - - indexes = [] - for index in insp.get_indexes(table_name): - column_names = [new_table.c[c] for c in index['column_names']] - indexes.append(Index(index["name"], *column_names, - unique=index["unique"])) - - ins = InsertFromSelect(new_table, table.select()) - migrate_engine.execute(ins) - - table.drop() - for index in indexes: - index.create(migrate_engine) - - new_table.rename(table_name) - deleted = True # workaround for pyflakes - new_table.update().\ - where(new_table.c.deleted == deleted).\ - values(deleted=new_table.c.id).\ - execute() - - # NOTE(boris-42): Fix value of deleted column: False -> "" or 0. - deleted = False # workaround for pyflakes - new_table.update().\ - where(new_table.c.deleted == deleted).\ - values(deleted=default_deleted_value).\ - execute() - - -def get_connect_string(backend, database, user=None, passwd=None, - host='localhost'): - """Get database connection - - Try to get a connection with a very specific set of values, if we get - these then we'll run the tests, otherwise they are skipped - - DEPRECATED: this function is deprecated and will be removed from oslo.db - in a few releases. Please use the provisioning system for dealing - with URLs and database provisioning. - - """ - args = {'backend': backend, - 'user': user, - 'passwd': passwd, - 'host': host, - 'database': database} - if backend == 'sqlite': - template = '%(backend)s:///%(database)s' - else: - template = "%(backend)s://%(user)s:%(passwd)s@%(host)s/%(database)s" - return template % args - - -def is_backend_avail(backend, database, user=None, passwd=None): - """Return True if the given backend is available. - - - DEPRECATED: this function is deprecated and will be removed from oslo.db - in a few releases. Please use the provisioning system to access - databases based on backend availability. - - """ - from oslo.db.sqlalchemy import provision - - connect_uri = get_connect_string(backend=backend, - database=database, - user=user, - passwd=passwd) - try: - eng = provision.Backend._ensure_backend_available(connect_uri) - eng.dispose() - except exception.BackendNotAvailable: - return False - else: - return True - - -def get_db_connection_info(conn_pieces): - database = conn_pieces.path.strip('/') - loc_pieces = conn_pieces.netloc.split('@') - host = loc_pieces[1] - - auth_pieces = loc_pieces[0].split(':') - user = auth_pieces[0] - password = "" - if len(auth_pieces) > 1: - password = auth_pieces[1].strip() - - return (user, password, database, host) - - -def index_exists(migrate_engine, table_name, index_name): - """Check if given index exists. - - :param migrate_engine: sqlalchemy engine - :param table_name: name of the table - :param index_name: name of the index - """ - inspector = reflection.Inspector.from_engine(migrate_engine) - indexes = inspector.get_indexes(table_name) - index_names = [index['name'] for index in indexes] - return index_name in index_names - - -def add_index(migrate_engine, table_name, index_name, idx_columns): - """Create an index for given columns. - - :param migrate_engine: sqlalchemy engine - :param table_name: name of the table - :param index_name: name of the index - :param idx_columns: tuple with names of columns that will be indexed - """ - table = get_table(migrate_engine, table_name) - if not index_exists(migrate_engine, table_name, index_name): - index = Index( - index_name, *[getattr(table.c, col) for col in idx_columns] - ) - index.create() - else: - raise ValueError("Index '%s' already exists!" % index_name) - - -def drop_index(migrate_engine, table_name, index_name): - """Drop index with given name. - - :param migrate_engine: sqlalchemy engine - :param table_name: name of the table - :param index_name: name of the index - """ - table = get_table(migrate_engine, table_name) - for index in table.indexes: - if index.name == index_name: - index.drop() - break - else: - raise ValueError("Index '%s' not found!" % index_name) - - -def change_index_columns(migrate_engine, table_name, index_name, new_columns): - """Change set of columns that are indexed by given index. - - :param migrate_engine: sqlalchemy engine - :param table_name: name of the table - :param index_name: name of the index - :param new_columns: tuple with names of columns that will be indexed - """ - drop_index(migrate_engine, table_name, index_name) - add_index(migrate_engine, table_name, index_name, new_columns) - - -def column_exists(engine, table_name, column): - """Check if table has given column. - - :param engine: sqlalchemy engine - :param table_name: name of the table - :param column: name of the colmn - """ - t = get_table(engine, table_name) - return column in t.c - - -class DialectFunctionDispatcher(object): - @classmethod - def dispatch_for_dialect(cls, expr, multiple=False): - """Provide dialect-specific functionality within distinct functions. - - e.g.:: - - @dispatch_for_dialect("*") - def set_special_option(engine): - pass - - @set_special_option.dispatch_for("sqlite") - def set_sqlite_special_option(engine): - return engine.execute("sqlite thing") - - @set_special_option.dispatch_for("mysql+mysqldb") - def set_mysqldb_special_option(engine): - return engine.execute("mysqldb thing") - - After the above registration, the ``set_special_option()`` function - is now a dispatcher, given a SQLAlchemy ``Engine``, ``Connection``, - URL string, or ``sqlalchemy.engine.URL`` object:: - - eng = create_engine('...') - result = set_special_option(eng) - - The filter system supports two modes, "multiple" and "single". - The default is "single", and requires that one and only one function - match for a given backend. In this mode, the function may also - have a return value, which will be returned by the top level - call. - - "multiple" mode, on the other hand, does not support return - arguments, but allows for any number of matching functions, where - each function will be called:: - - # the initial call sets this up as a "multiple" dispatcher - @dispatch_for_dialect("*", multiple=True) - def set_options(engine): - # set options that apply to *all* engines - - @set_options.dispatch_for("postgresql") - def set_postgresql_options(engine): - # set options that apply to all Postgresql engines - - @set_options.dispatch_for("postgresql+psycopg2") - def set_postgresql_psycopg2_options(engine): - # set options that apply only to "postgresql+psycopg2" - - @set_options.dispatch_for("*+pyodbc") - def set_pyodbc_options(engine): - # set options that apply to all pyodbc backends - - Note that in both modes, any number of additional arguments can be - accepted by member functions. For example, to populate a dictionary of - options, it may be passed in:: - - @dispatch_for_dialect("*", multiple=True) - def set_engine_options(url, opts): - pass - - @set_engine_options.dispatch_for("mysql+mysqldb") - def _mysql_set_default_charset_to_utf8(url, opts): - opts.setdefault('charset', 'utf-8') - - @set_engine_options.dispatch_for("sqlite") - def _set_sqlite_in_memory_check_same_thread(url, opts): - if url.database in (None, 'memory'): - opts['check_same_thread'] = False - - opts = {} - set_engine_options(url, opts) - - The driver specifiers are of the form: - ``[+]``. That is, database name or "*", - followed by an optional ``+`` sign with driver or "*". Omitting - the driver name implies all drivers for that database. - - """ - if multiple: - cls = DialectMultiFunctionDispatcher - else: - cls = DialectSingleFunctionDispatcher - return cls().dispatch_for(expr) - - _db_plus_driver_reg = re.compile(r'([^+]+?)(?:\+(.+))?$') - - def dispatch_for(self, expr): - def decorate(fn): - dbname, driver = self._parse_dispatch(expr) - if fn is self: - fn = fn._last - self._last = fn - self._register(expr, dbname, driver, fn) - return self - return decorate - - def _parse_dispatch(self, text): - m = self._db_plus_driver_reg.match(text) - if not m: - raise ValueError("Couldn't parse database[+driver]: %r" % text) - return m.group(1) or '*', m.group(2) or '*' - - def __call__(self, *arg, **kw): - target = arg[0] - return self._dispatch_on( - self._url_from_target(target), target, arg, kw) - - def _url_from_target(self, target): - if isinstance(target, Connectable): - return target.engine.url - elif isinstance(target, six.string_types): - if "://" not in target: - target_url = sa_url.make_url("%s://" % target) - else: - target_url = sa_url.make_url(target) - return target_url - elif isinstance(target, sa_url.URL): - return target - else: - raise ValueError("Invalid target type: %r" % target) - - def dispatch_on_drivername(self, drivername): - """Return a sub-dispatcher for the given drivername. - - This provides a means of calling a different function, such as the - "*" function, for a given target object that normally refers - to a sub-function. - - """ - dbname, driver = self._db_plus_driver_reg.match(drivername).group(1, 2) - - def go(*arg, **kw): - return self._dispatch_on_db_driver(dbname, "*", arg, kw) - - return go - - def _dispatch_on(self, url, target, arg, kw): - dbname, driver = self._db_plus_driver_reg.match( - url.drivername).group(1, 2) - if not driver: - driver = url.get_dialect().driver - - return self._dispatch_on_db_driver(dbname, driver, arg, kw) - - def _invoke_fn(self, fn, arg, kw): - return fn(*arg, **kw) - - -class DialectSingleFunctionDispatcher(DialectFunctionDispatcher): - def __init__(self): - self.reg = collections.defaultdict(dict) - - def _register(self, expr, dbname, driver, fn): - fn_dict = self.reg[dbname] - if driver in fn_dict: - raise TypeError("Multiple functions for expression %r" % expr) - fn_dict[driver] = fn - - def _matches(self, dbname, driver): - for db in (dbname, '*'): - subdict = self.reg[db] - for drv in (driver, '*'): - if drv in subdict: - return subdict[drv] - else: - raise ValueError( - "No default function found for driver: %r" % - ("%s+%s" % (dbname, driver))) - - def _dispatch_on_db_driver(self, dbname, driver, arg, kw): - fn = self._matches(dbname, driver) - return self._invoke_fn(fn, arg, kw) - - -class DialectMultiFunctionDispatcher(DialectFunctionDispatcher): - def __init__(self): - self.reg = collections.defaultdict( - lambda: collections.defaultdict(list)) - - def _register(self, expr, dbname, driver, fn): - self.reg[dbname][driver].append(fn) - - def _matches(self, dbname, driver): - if driver != '*': - drivers = (driver, '*') - else: - drivers = ('*', ) - - for db in (dbname, '*'): - subdict = self.reg[db] - for drv in drivers: - for fn in subdict[drv]: - yield fn - - def _dispatch_on_db_driver(self, dbname, driver, arg, kw): - for fn in self._matches(dbname, driver): - if self._invoke_fn(fn, arg, kw) is not None: - raise TypeError( - "Return value not allowed for " - "multiple filtered function") - -dispatch_for_dialect = DialectFunctionDispatcher.dispatch_for_dialect - - -def get_non_innodb_tables(connectable, skip_tables=('migrate_version', - 'alembic_version')): - """Get a list of tables which don't use InnoDB storage engine. - - :param connectable: a SQLAlchemy Engine or a Connection instance - :param skip_tables: a list of tables which might have a different - storage engine - """ - - query_str = """ - SELECT table_name - FROM information_schema.tables - WHERE table_schema = :database AND - engine != 'InnoDB' - """ - - params = {} - if skip_tables: - params = dict( - ('skip_%s' % i, table_name) - for i, table_name in enumerate(skip_tables) - ) - - placeholders = ', '.join(':' + p for p in params) - query_str += ' AND table_name NOT IN (%s)' % placeholders - - params['database'] = connectable.engine.url.database - query = text(query_str) - noninnodb = connectable.execute(query, **params) - return [i[0] for i in noninnodb] +from oslo_db.sqlalchemy.utils import * # noqa diff --git a/tests/__init__.py b/oslo_db/__init__.py similarity index 100% rename from tests/__init__.py rename to oslo_db/__init__.py diff --git a/oslo/db/_i18n.py b/oslo_db/_i18n.py similarity index 100% rename from oslo/db/_i18n.py rename to oslo_db/_i18n.py diff --git a/oslo_db/api.py b/oslo_db/api.py new file mode 100644 index 00000000..e673b372 --- /dev/null +++ b/oslo_db/api.py @@ -0,0 +1,229 @@ +# Copyright (c) 2013 Rackspace Hosting +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +================================= +Multiple DB API backend support. +================================= + +A DB backend module should implement a method named 'get_backend' which +takes no arguments. The method can return any object that implements DB +API methods. +""" + +import logging +import threading +import time + +from oslo.utils import importutils +import six + +from oslo_db._i18n import _LE +from oslo_db import exception +from oslo_db import options + + +LOG = logging.getLogger(__name__) + + +def safe_for_db_retry(f): + """Indicate api method as safe for re-connection to database. + + Database connection retries will be enabled for the decorated api method. + Database connection failure can have many causes, which can be temporary. + In such cases retry may increase the likelihood of connection. + + Usage:: + + @safe_for_db_retry + def api_method(self): + self.engine.connect() + + + :param f: database api method. + :type f: function. + """ + f.__dict__['enable_retry'] = True + return f + + +class wrap_db_retry(object): + """Decorator class. Retry db.api methods, if DBConnectionError() raised. + + Retry decorated db.api methods. If we enabled `use_db_reconnect` + in config, this decorator will be applied to all db.api functions, + marked with @safe_for_db_retry decorator. + Decorator catches DBConnectionError() and retries function in a + loop until it succeeds, or until maximum retries count will be reached. + + Keyword arguments: + + :param retry_interval: seconds between transaction retries + :type retry_interval: int + + :param max_retries: max number of retries before an error is raised + :type max_retries: int + + :param inc_retry_interval: determine increase retry interval or not + :type inc_retry_interval: bool + + :param max_retry_interval: max interval value between retries + :type max_retry_interval: int + """ + + def __init__(self, retry_interval, max_retries, inc_retry_interval, + max_retry_interval): + super(wrap_db_retry, self).__init__() + + self.retry_interval = retry_interval + self.max_retries = max_retries + self.inc_retry_interval = inc_retry_interval + self.max_retry_interval = max_retry_interval + + def __call__(self, f): + @six.wraps(f) + def wrapper(*args, **kwargs): + next_interval = self.retry_interval + remaining = self.max_retries + + while True: + try: + return f(*args, **kwargs) + except exception.DBConnectionError as e: + if remaining == 0: + LOG.exception(_LE('DB exceeded retry limit.')) + raise exception.DBError(e) + if remaining != -1: + remaining -= 1 + LOG.exception(_LE('DB connection error.')) + # NOTE(vsergeyev): We are using patched time module, so + # this effectively yields the execution + # context to another green thread. + time.sleep(next_interval) + if self.inc_retry_interval: + next_interval = min( + next_interval * 2, + self.max_retry_interval + ) + return wrapper + + +class DBAPI(object): + """Initialize the chosen DB API backend. + + After initialization API methods is available as normal attributes of + ``DBAPI`` subclass. Database API methods are supposed to be called as + DBAPI instance methods. + + :param backend_name: name of the backend to load + :type backend_name: str + + :param backend_mapping: backend name -> module/class to load mapping + :type backend_mapping: dict + :default backend_mapping: None + + :param lazy: load the DB backend lazily on the first DB API method call + :type lazy: bool + :default lazy: False + + :keyword use_db_reconnect: retry DB transactions on disconnect or not + :type use_db_reconnect: bool + + :keyword retry_interval: seconds between transaction retries + :type retry_interval: int + + :keyword inc_retry_interval: increase retry interval or not + :type inc_retry_interval: bool + + :keyword max_retry_interval: max interval value between retries + :type max_retry_interval: int + + :keyword max_retries: max number of retries before an error is raised + :type max_retries: int + """ + + def __init__(self, backend_name, backend_mapping=None, lazy=False, + **kwargs): + + self._backend = None + self._backend_name = backend_name + self._backend_mapping = backend_mapping or {} + self._lock = threading.Lock() + + if not lazy: + self._load_backend() + + self.use_db_reconnect = kwargs.get('use_db_reconnect', False) + self.retry_interval = kwargs.get('retry_interval', 1) + self.inc_retry_interval = kwargs.get('inc_retry_interval', True) + self.max_retry_interval = kwargs.get('max_retry_interval', 10) + self.max_retries = kwargs.get('max_retries', 20) + + def _load_backend(self): + with self._lock: + if not self._backend: + # Import the untranslated name if we don't have a mapping + backend_path = self._backend_mapping.get(self._backend_name, + self._backend_name) + LOG.debug('Loading backend %(name)r from %(path)r', + {'name': self._backend_name, + 'path': backend_path}) + backend_mod = importutils.import_module(backend_path) + self._backend = backend_mod.get_backend() + + def __getattr__(self, key): + if not self._backend: + self._load_backend() + + attr = getattr(self._backend, key) + if not hasattr(attr, '__call__'): + return attr + # NOTE(vsergeyev): If `use_db_reconnect` option is set to True, retry + # DB API methods, decorated with @safe_for_db_retry + # on disconnect. + if self.use_db_reconnect and hasattr(attr, 'enable_retry'): + attr = wrap_db_retry( + retry_interval=self.retry_interval, + max_retries=self.max_retries, + inc_retry_interval=self.inc_retry_interval, + max_retry_interval=self.max_retry_interval)(attr) + + return attr + + @classmethod + def from_config(cls, conf, backend_mapping=None, lazy=False): + """Initialize DBAPI instance given a config instance. + + :param conf: oslo.config config instance + :type conf: oslo.config.cfg.ConfigOpts + + :param backend_mapping: backend name -> module/class to load mapping + :type backend_mapping: dict + + :param lazy: load the DB backend lazily on the first DB API method call + :type lazy: bool + + """ + + conf.register_opts(options.database_opts, 'database') + + return cls(backend_name=conf.database.backend, + backend_mapping=backend_mapping, + lazy=lazy, + use_db_reconnect=conf.database.use_db_reconnect, + retry_interval=conf.database.db_retry_interval, + inc_retry_interval=conf.database.db_inc_retry_interval, + max_retry_interval=conf.database.db_max_retry_interval, + max_retries=conf.database.db_max_retries) diff --git a/oslo_db/concurrency.py b/oslo_db/concurrency.py new file mode 100644 index 00000000..2c596230 --- /dev/null +++ b/oslo_db/concurrency.py @@ -0,0 +1,81 @@ +# Copyright 2014 Mirantis.inc +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy +import logging +import threading + +from oslo.config import cfg + +from oslo_db._i18n import _LE +from oslo_db import api + + +LOG = logging.getLogger(__name__) + +tpool_opts = [ + cfg.BoolOpt('use_tpool', + default=False, + deprecated_name='dbapi_use_tpool', + deprecated_group='DEFAULT', + help='Enable the experimental use of thread pooling for ' + 'all DB API calls'), +] + + +class TpoolDbapiWrapper(object): + """DB API wrapper class. + + This wraps the oslo DB API with an option to be able to use eventlet's + thread pooling. Since the CONF variable may not be loaded at the time + this class is instantiated, we must look at it on the first DB API call. + """ + + def __init__(self, conf, backend_mapping): + self._db_api = None + self._backend_mapping = backend_mapping + self._conf = conf + self._conf.register_opts(tpool_opts, 'database') + self._lock = threading.Lock() + + @property + def _api(self): + if not self._db_api: + with self._lock: + if not self._db_api: + db_api = api.DBAPI.from_config( + conf=self._conf, backend_mapping=self._backend_mapping) + if self._conf.database.use_tpool: + try: + from eventlet import tpool + except ImportError: + LOG.exception(_LE("'eventlet' is required for " + "TpoolDbapiWrapper.")) + raise + self._db_api = tpool.Proxy(db_api) + else: + self._db_api = db_api + return self._db_api + + def __getattr__(self, key): + return getattr(self._api, key) + + +def list_opts(): + """Returns a list of oslo.config options available in this module. + + :returns: a list of (group_name, opts) tuples + """ + return [('database', copy.deepcopy(tpool_opts))] diff --git a/oslo_db/exception.py b/oslo_db/exception.py new file mode 100644 index 00000000..5de7f1e1 --- /dev/null +++ b/oslo_db/exception.py @@ -0,0 +1,173 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""DB related custom exceptions. + +Custom exceptions intended to determine the causes of specific database +errors. This module provides more generic exceptions than the database-specific +driver libraries, and so users of oslo.db can catch these no matter which +database the application is using. Most of the exceptions are wrappers. Wrapper +exceptions take an original exception as positional argument and keep it for +purposes of deeper debug. + +Example:: + + try: + statement(arg) + except sqlalchemy.exc.OperationalError as e: + raise DBDuplicateEntry(e) + + +This is useful to determine more specific error cases further at execution, +when you need to add some extra information to an error message. Wrapper +exceptions takes care about original error message displaying to not to loose +low level cause of an error. All the database api exceptions wrapped into +the specific exceptions provided belove. + + +Please use only database related custom exceptions with database manipulations +with `try/except` statement. This is required for consistent handling of +database errors. +""" + +import six + +from oslo_db._i18n import _ + + +class DBError(Exception): + + """Base exception for all custom database exceptions. + + :kwarg inner_exception: an original exception which was wrapped with + DBError or its subclasses. + """ + + def __init__(self, inner_exception=None): + self.inner_exception = inner_exception + super(DBError, self).__init__(six.text_type(inner_exception)) + + +class DBDuplicateEntry(DBError): + """Duplicate entry at unique column error. + + Raised when made an attempt to write to a unique column the same entry as + existing one. :attr: `columns` available on an instance of the exception + and could be used at error handling:: + + try: + instance_type_ref.save() + except DBDuplicateEntry as e: + if 'colname' in e.columns: + # Handle error. + + :kwarg columns: a list of unique columns have been attempted to write a + duplicate entry. + :type columns: list + :kwarg value: a value which has been attempted to write. The value will + be None, if we can't extract it for a particular database backend. Only + MySQL and PostgreSQL 9.x are supported right now. + """ + def __init__(self, columns=None, inner_exception=None, value=None): + self.columns = columns or [] + self.value = value + super(DBDuplicateEntry, self).__init__(inner_exception) + + +class DBReferenceError(DBError): + """Foreign key violation error. + + :param table: a table name in which the reference is directed. + :type table: str + :param constraint: a problematic constraint name. + :type constraint: str + :param key: a broken reference key name. + :type key: str + :param key_table: a table name which contains the key. + :type key_table: str + """ + + def __init__(self, table, constraint, key, key_table, + inner_exception=None): + self.table = table + self.constraint = constraint + self.key = key + self.key_table = key_table + super(DBReferenceError, self).__init__(inner_exception) + + +class DBDeadlock(DBError): + + """Database dead lock error. + + Deadlock is a situation that occurs when two or more different database + sessions have some data locked, and each database session requests a lock + on the data that another, different, session has already locked. + """ + + def __init__(self, inner_exception=None): + super(DBDeadlock, self).__init__(inner_exception) + + +class DBInvalidUnicodeParameter(Exception): + + """Database unicode error. + + Raised when unicode parameter is passed to a database + without encoding directive. + """ + + message = _("Invalid Parameter: " + "Encoding directive wasn't provided.") + + +class DbMigrationError(DBError): + + """Wrapped migration specific exception. + + Raised when migrations couldn't be completed successfully. + """ + + def __init__(self, message=None): + super(DbMigrationError, self).__init__(message) + + +class DBConnectionError(DBError): + + """Wrapped connection specific exception. + + Raised when database connection is failed. + """ + + pass + + +class InvalidSortKey(Exception): + """A sort key destined for database query usage is invalid.""" + + message = _("Sort key supplied was not valid.") + + +class ColumnError(Exception): + """Error raised when no column or an invalid column is found.""" + + +class BackendNotAvailable(Exception): + """Error raised when a particular database backend is not available + + within a test suite. + + """ diff --git a/oslo_db/options.py b/oslo_db/options.py new file mode 100644 index 00000000..b8550644 --- /dev/null +++ b/oslo_db/options.py @@ -0,0 +1,220 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy + +from oslo.config import cfg + + +database_opts = [ + cfg.StrOpt('sqlite_db', + deprecated_group='DEFAULT', + default='oslo.sqlite', + help='The file name to use with SQLite.'), + cfg.BoolOpt('sqlite_synchronous', + deprecated_group='DEFAULT', + default=True, + help='If True, SQLite uses synchronous mode.'), + cfg.StrOpt('backend', + default='sqlalchemy', + deprecated_name='db_backend', + deprecated_group='DEFAULT', + help='The back end to use for the database.'), + cfg.StrOpt('connection', + help='The SQLAlchemy connection string to use to connect to ' + 'the database.', + secret=True, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_connection', + group='DATABASE'), + cfg.DeprecatedOpt('connection', + group='sql'), ]), + cfg.StrOpt('slave_connection', + secret=True, + help='The SQLAlchemy connection string to use to connect to the' + ' slave database.'), + cfg.StrOpt('mysql_sql_mode', + default='TRADITIONAL', + help='The SQL mode to be used for MySQL sessions. ' + 'This option, including the default, overrides any ' + 'server-set SQL mode. To use whatever SQL mode ' + 'is set by the server configuration, ' + 'set this to no value. Example: mysql_sql_mode='), + cfg.IntOpt('idle_timeout', + default=3600, + deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_idle_timeout', + group='DATABASE'), + cfg.DeprecatedOpt('idle_timeout', + group='sql')], + help='Timeout before idle SQL connections are reaped.'), + cfg.IntOpt('min_pool_size', + default=1, + deprecated_opts=[cfg.DeprecatedOpt('sql_min_pool_size', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_min_pool_size', + group='DATABASE')], + help='Minimum number of SQL connections to keep open in a ' + 'pool.'), + cfg.IntOpt('max_pool_size', + deprecated_opts=[cfg.DeprecatedOpt('sql_max_pool_size', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_max_pool_size', + group='DATABASE')], + help='Maximum number of SQL connections to keep open in a ' + 'pool.'), + cfg.IntOpt('max_retries', + default=10, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_retries', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_max_retries', + group='DATABASE')], + help='Maximum number of database connection retries ' + 'during startup. Set to -1 to specify an infinite ' + 'retry count.'), + cfg.IntOpt('retry_interval', + default=10, + deprecated_opts=[cfg.DeprecatedOpt('sql_retry_interval', + group='DEFAULT'), + cfg.DeprecatedOpt('reconnect_interval', + group='DATABASE')], + help='Interval between retries of opening a SQL connection.'), + cfg.IntOpt('max_overflow', + deprecated_opts=[cfg.DeprecatedOpt('sql_max_overflow', + group='DEFAULT'), + cfg.DeprecatedOpt('sqlalchemy_max_overflow', + group='DATABASE')], + help='If set, use this value for max_overflow with ' + 'SQLAlchemy.'), + cfg.IntOpt('connection_debug', + default=0, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection_debug', + group='DEFAULT')], + help='Verbosity of SQL debugging information: 0=None, ' + '100=Everything.'), + cfg.BoolOpt('connection_trace', + default=False, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection_trace', + group='DEFAULT')], + help='Add Python stack traces to SQL as comment strings.'), + cfg.IntOpt('pool_timeout', + deprecated_opts=[cfg.DeprecatedOpt('sqlalchemy_pool_timeout', + group='DATABASE')], + help='If set, use this value for pool_timeout with ' + 'SQLAlchemy.'), + cfg.BoolOpt('use_db_reconnect', + default=False, + help='Enable the experimental use of database reconnect ' + 'on connection lost.'), + cfg.IntOpt('db_retry_interval', + default=1, + help='Seconds between database connection retries.'), + cfg.BoolOpt('db_inc_retry_interval', + default=True, + help='If True, increases the interval between database ' + 'connection retries up to db_max_retry_interval.'), + cfg.IntOpt('db_max_retry_interval', + default=10, + help='If db_inc_retry_interval is set, the ' + 'maximum seconds between database connection retries.'), + cfg.IntOpt('db_max_retries', + default=20, + help='Maximum database connection retries before error is ' + 'raised. Set to -1 to specify an infinite retry ' + 'count.'), +] + + +def set_defaults(conf, connection=None, sqlite_db=None, + max_pool_size=None, max_overflow=None, + pool_timeout=None): + """Set defaults for configuration variables. + + Overrides default options values. + + :param conf: Config instance specified to set default options in it. Using + of instances instead of a global config object prevents conflicts between + options declaration. + :type conf: oslo.config.cfg.ConfigOpts instance. + + :keyword connection: SQL connection string. + Valid SQLite URL forms are: + * sqlite:///:memory: (or, sqlite://) + * sqlite:///relative/path/to/file.db + * sqlite:////absolute/path/to/file.db + :type connection: str + + :keyword sqlite_db: path to SQLite database file. + :type sqlite_db: str + + :keyword max_pool_size: maximum connections pool size. The size of the pool + to be maintained, defaults to 5, will be used if value of the parameter is + `None`. This is the largest number of connections that will be kept + persistently in the pool. Note that the pool begins with no connections; + once this number of connections is requested, that number of connections + will remain. + :type max_pool_size: int + :default max_pool_size: None + + :keyword max_overflow: The maximum overflow size of the pool. When the + number of checked-out connections reaches the size set in pool_size, + additional connections will be returned up to this limit. When those + additional connections are returned to the pool, they are disconnected and + discarded. It follows then that the total number of simultaneous + connections the pool will allow is pool_size + max_overflow, and the total + number of "sleeping" connections the pool will allow is pool_size. + max_overflow can be set to -1 to indicate no overflow limit; no limit will + be placed on the total number of concurrent connections. Defaults to 10, + will be used if value of the parameter in `None`. + :type max_overflow: int + :default max_overflow: None + + :keyword pool_timeout: The number of seconds to wait before giving up on + returning a connection. Defaults to 30, will be used if value of the + parameter is `None`. + :type pool_timeout: int + :default pool_timeout: None + """ + + conf.register_opts(database_opts, group='database') + + if connection is not None: + conf.set_default('connection', connection, group='database') + if sqlite_db is not None: + conf.set_default('sqlite_db', sqlite_db, group='database') + if max_pool_size is not None: + conf.set_default('max_pool_size', max_pool_size, group='database') + if max_overflow is not None: + conf.set_default('max_overflow', max_overflow, group='database') + if pool_timeout is not None: + conf.set_default('pool_timeout', pool_timeout, group='database') + + +def list_opts(): + """Returns a list of oslo.config options available in the library. + + The returned list includes all oslo.config options which may be registered + at runtime by the library. + + Each element of the list is a tuple. The first element is the name of the + group under which the list of elements in the second element will be + registered. A group name of None corresponds to the [DEFAULT] group in + config files. + + The purpose of this is to allow tools like the Oslo sample config file + generator to discover the options exposed to users by this library. + + :returns: a list of (group_name, opts) tuples + """ + return [('database', copy.deepcopy(database_opts))] diff --git a/tests/sqlalchemy/__init__.py b/oslo_db/sqlalchemy/__init__.py similarity index 100% rename from tests/sqlalchemy/__init__.py rename to oslo_db/sqlalchemy/__init__.py diff --git a/oslo_db/sqlalchemy/compat/__init__.py b/oslo_db/sqlalchemy/compat/__init__.py new file mode 100644 index 00000000..b49d5c41 --- /dev/null +++ b/oslo_db/sqlalchemy/compat/__init__.py @@ -0,0 +1,30 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +"""compatiblity extensions for SQLAlchemy versions. + +Elements within this module provide SQLAlchemy features that have been +added at some point but for which oslo.db provides a compatible versions +for previous SQLAlchemy versions. + +""" +from oslo_db.sqlalchemy.compat import engine_connect as _e_conn +from oslo_db.sqlalchemy.compat import handle_error as _h_err + +# trying to get: "from oslo_db.sqlalchemy import compat; compat.handle_error" +# flake8 won't let me import handle_error directly +engine_connect = _e_conn.engine_connect +handle_error = _h_err.handle_error +handle_connect_context = _h_err.handle_connect_context + +__all__ = [ + 'engine_connect', 'handle_error', + 'handle_connect_context'] diff --git a/oslo/db/sqlalchemy/compat/engine_connect.py b/oslo_db/sqlalchemy/compat/engine_connect.py similarity index 97% rename from oslo/db/sqlalchemy/compat/engine_connect.py rename to oslo_db/sqlalchemy/compat/engine_connect.py index d64d4624..6b50fc68 100644 --- a/oslo/db/sqlalchemy/compat/engine_connect.py +++ b/oslo_db/sqlalchemy/compat/engine_connect.py @@ -20,7 +20,7 @@ http://docs.sqlalchemy.org/en/rel_0_9/core/events.html. from sqlalchemy.engine import Engine from sqlalchemy import event -from oslo.db.sqlalchemy.compat import utils +from oslo_db.sqlalchemy.compat import utils def engine_connect(engine, listener): diff --git a/oslo/db/sqlalchemy/compat/handle_error.py b/oslo_db/sqlalchemy/compat/handle_error.py similarity index 99% rename from oslo/db/sqlalchemy/compat/handle_error.py rename to oslo_db/sqlalchemy/compat/handle_error.py index 1537480e..7e476a0f 100644 --- a/oslo/db/sqlalchemy/compat/handle_error.py +++ b/oslo_db/sqlalchemy/compat/handle_error.py @@ -24,7 +24,7 @@ from sqlalchemy.engine import Engine from sqlalchemy import event from sqlalchemy import exc as sqla_exc -from oslo.db.sqlalchemy.compat import utils +from oslo_db.sqlalchemy.compat import utils def handle_error(engine, listener): diff --git a/oslo_db/sqlalchemy/compat/utils.py b/oslo_db/sqlalchemy/compat/utils.py new file mode 100644 index 00000000..fa6c3e77 --- /dev/null +++ b/oslo_db/sqlalchemy/compat/utils.py @@ -0,0 +1,26 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import re + +import sqlalchemy + + +_SQLA_VERSION = tuple( + int(num) if re.match(r'^\d+$', num) else num + for num in sqlalchemy.__version__.split(".") +) + +sqla_100 = _SQLA_VERSION >= (1, 0, 0) +sqla_097 = _SQLA_VERSION >= (0, 9, 7) +sqla_094 = _SQLA_VERSION >= (0, 9, 4) +sqla_090 = _SQLA_VERSION >= (0, 9, 0) +sqla_08 = _SQLA_VERSION >= (0, 8) diff --git a/oslo_db/sqlalchemy/exc_filters.py b/oslo_db/sqlalchemy/exc_filters.py new file mode 100644 index 00000000..efdbb2f7 --- /dev/null +++ b/oslo_db/sqlalchemy/exc_filters.py @@ -0,0 +1,358 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +"""Define exception redefinitions for SQLAlchemy DBAPI exceptions.""" + +import collections +import logging +import re + +from sqlalchemy import exc as sqla_exc + +from oslo_db._i18n import _LE +from oslo_db import exception +from oslo_db.sqlalchemy import compat + + +LOG = logging.getLogger(__name__) + + +_registry = collections.defaultdict( + lambda: collections.defaultdict( + list + ) +) + + +def filters(dbname, exception_type, regex): + """Mark a function as receiving a filtered exception. + + :param dbname: string database name, e.g. 'mysql' + :param exception_type: a SQLAlchemy database exception class, which + extends from :class:`sqlalchemy.exc.DBAPIError`. + :param regex: a string, or a tuple of strings, that will be processed + as matching regular expressions. + + """ + def _receive(fn): + _registry[dbname][exception_type].extend( + (fn, re.compile(reg)) + for reg in + ((regex,) if not isinstance(regex, tuple) else regex) + ) + return fn + return _receive + + +# NOTE(zzzeek) - for Postgresql, catch both OperationalError, as the +# actual error is +# psycopg2.extensions.TransactionRollbackError(OperationalError), +# as well as sqlalchemy.exc.DBAPIError, as SQLAlchemy will reraise it +# as this until issue #3075 is fixed. +@filters("mysql", sqla_exc.OperationalError, r"^.*\b1213\b.*Deadlock found.*") +@filters("mysql", sqla_exc.OperationalError, + r"^.*\b1205\b.*Lock wait timeout exceeded.*") +@filters("mysql", sqla_exc.InternalError, r"^.*\b1213\b.*Deadlock found.*") +@filters("postgresql", sqla_exc.OperationalError, r"^.*deadlock detected.*") +@filters("postgresql", sqla_exc.DBAPIError, r"^.*deadlock detected.*") +@filters("ibm_db_sa", sqla_exc.DBAPIError, r"^.*SQL0911N.*") +def _deadlock_error(operational_error, match, engine_name, is_disconnect): + """Filter for MySQL or Postgresql deadlock error. + + NOTE(comstud): In current versions of DB backends, Deadlock violation + messages follow the structure: + + mysql+mysqldb: + (OperationalError) (1213, 'Deadlock found when trying to get lock; try ' + 'restarting transaction') + + mysql+mysqlconnector: + (InternalError) 1213 (40001): Deadlock found when trying to get lock; try + restarting transaction + + postgresql: + (TransactionRollbackError) deadlock detected + + + ibm_db_sa: + SQL0911N The current transaction has been rolled back because of a + deadlock or timeout + + """ + raise exception.DBDeadlock(operational_error) + + +@filters("mysql", sqla_exc.IntegrityError, + r"^.*\b1062\b.*Duplicate entry '(?P[^']+)'" + r" for key '(?P[^']+)'.*$") +# NOTE(pkholkin): the first regex is suitable only for PostgreSQL 9.x versions +# the second regex is suitable for PostgreSQL 8.x versions +@filters("postgresql", sqla_exc.IntegrityError, + (r'^.*duplicate\s+key.*"(?P[^"]+)"\s*\n.*' + r'Key\s+\((?P.*)\)=\((?P.*)\)\s+already\s+exists.*$', + r"^.*duplicate\s+key.*\"(?P[^\"]+)\"\s*\n.*$")) +def _default_dupe_key_error(integrity_error, match, engine_name, + is_disconnect): + """Filter for MySQL or Postgresql duplicate key error. + + note(boris-42): In current versions of DB backends unique constraint + violation messages follow the structure: + + postgres: + 1 column - (IntegrityError) duplicate key value violates unique + constraint "users_c1_key" + N columns - (IntegrityError) duplicate key value violates unique + constraint "name_of_our_constraint" + + mysql+mysqldb: + 1 column - (IntegrityError) (1062, "Duplicate entry 'value_of_c1' for key + 'c1'") + N columns - (IntegrityError) (1062, "Duplicate entry 'values joined + with -' for key 'name_of_our_constraint'") + + mysql+mysqlconnector: + 1 column - (IntegrityError) 1062 (23000): Duplicate entry 'value_of_c1' for + key 'c1' + N columns - (IntegrityError) 1062 (23000): Duplicate entry 'values + joined with -' for key 'name_of_our_constraint' + + + + """ + + columns = match.group('columns') + + # note(vsergeyev): UniqueConstraint name convention: "uniq_t0c10c2" + # where `t` it is table name and columns `c1`, `c2` + # are in UniqueConstraint. + uniqbase = "uniq_" + if not columns.startswith(uniqbase): + if engine_name == "postgresql": + columns = [columns[columns.index("_") + 1:columns.rindex("_")]] + else: + columns = [columns] + else: + columns = columns[len(uniqbase):].split("0")[1:] + + value = match.groupdict().get('value') + + raise exception.DBDuplicateEntry(columns, integrity_error, value) + + +@filters("sqlite", sqla_exc.IntegrityError, + (r"^.*columns?(?P[^)]+)(is|are)\s+not\s+unique$", + r"^.*UNIQUE\s+constraint\s+failed:\s+(?P.+)$", + r"^.*PRIMARY\s+KEY\s+must\s+be\s+unique.*$")) +def _sqlite_dupe_key_error(integrity_error, match, engine_name, is_disconnect): + """Filter for SQLite duplicate key error. + + note(boris-42): In current versions of DB backends unique constraint + violation messages follow the structure: + + sqlite: + 1 column - (IntegrityError) column c1 is not unique + N columns - (IntegrityError) column c1, c2, ..., N are not unique + + sqlite since 3.7.16: + 1 column - (IntegrityError) UNIQUE constraint failed: tbl.k1 + N columns - (IntegrityError) UNIQUE constraint failed: tbl.k1, tbl.k2 + + sqlite since 3.8.2: + (IntegrityError) PRIMARY KEY must be unique + + """ + columns = [] + # NOTE(ochuprykov): We can get here by last filter in which there are no + # groups. Trying to access the substring that matched by + # the group will lead to IndexError. In this case just + # pass empty list to exception.DBDuplicateEntry + try: + columns = match.group('columns') + columns = [c.split('.')[-1] for c in columns.strip().split(", ")] + except IndexError: + pass + + raise exception.DBDuplicateEntry(columns, integrity_error) + + +@filters("sqlite", sqla_exc.IntegrityError, + r"(?i).*foreign key constraint failed") +@filters("postgresql", sqla_exc.IntegrityError, + r".*on table \"(?P
[^\"]+)\" violates " + "foreign key constraint \"(?P[^\"]+)\"\s*\n" + "DETAIL: Key \((?P.+)\)=\(.+\) " + "is not present in table " + "\"(?P[^\"]+)\".") +@filters("mysql", sqla_exc.IntegrityError, + r".* 'Cannot add or update a child row: " + 'a foreign key constraint fails \([`"].+[`"]\.[`"](?P
.+)[`"], ' + 'CONSTRAINT [`"](?P.+)[`"] FOREIGN KEY ' + '\([`"](?P.+)[`"]\) REFERENCES [`"](?P.+)[`"] ') +def _foreign_key_error(integrity_error, match, engine_name, is_disconnect): + """Filter for foreign key errors.""" + + try: + table = match.group("table") + except IndexError: + table = None + try: + constraint = match.group("constraint") + except IndexError: + constraint = None + try: + key = match.group("key") + except IndexError: + key = None + try: + key_table = match.group("key_table") + except IndexError: + key_table = None + + raise exception.DBReferenceError(table, constraint, key, key_table, + integrity_error) + + +@filters("ibm_db_sa", sqla_exc.IntegrityError, r"^.*SQL0803N.*$") +def _db2_dupe_key_error(integrity_error, match, engine_name, is_disconnect): + """Filter for DB2 duplicate key errors. + + N columns - (IntegrityError) SQL0803N One or more values in the INSERT + statement, UPDATE statement, or foreign key update caused by a + DELETE statement are not valid because the primary key, unique + constraint or unique index identified by "2" constrains table + "NOVA.KEY_PAIRS" from having duplicate values for the index + key. + + """ + + # NOTE(mriedem): The ibm_db_sa integrity error message doesn't provide the + # columns so we have to omit that from the DBDuplicateEntry error. + raise exception.DBDuplicateEntry([], integrity_error) + + +@filters("mysql", sqla_exc.DBAPIError, r".*\b1146\b") +def _raise_mysql_table_doesnt_exist_asis( + error, match, engine_name, is_disconnect): + """Raise MySQL error 1146 as is. + + Raise MySQL error 1146 as is, so that it does not conflict with + the MySQL dialect's checking a table not existing. + """ + + raise error + + +@filters("*", sqla_exc.OperationalError, r".*") +def _raise_operational_errors_directly_filter(operational_error, + match, engine_name, + is_disconnect): + """Filter for all remaining OperationalError classes and apply. + + Filter for all remaining OperationalError classes and apply + special rules. + """ + if is_disconnect: + # operational errors that represent disconnect + # should be wrapped + raise exception.DBConnectionError(operational_error) + else: + # NOTE(comstud): A lot of code is checking for OperationalError + # so let's not wrap it for now. + raise operational_error + + +@filters("mysql", sqla_exc.OperationalError, r".*\(.*(?:2002|2003|2006|2013)") +@filters("ibm_db_sa", sqla_exc.OperationalError, r".*(?:30081)") +def _is_db_connection_error(operational_error, match, engine_name, + is_disconnect): + """Detect the exception as indicating a recoverable error on connect.""" + raise exception.DBConnectionError(operational_error) + + +@filters("*", sqla_exc.DBAPIError, r".*") +def _raise_for_remaining_DBAPIError(error, match, engine_name, is_disconnect): + """Filter for remaining DBAPIErrors. + + Filter for remaining DBAPIErrors and wrap if they represent + a disconnect error. + """ + if is_disconnect: + raise exception.DBConnectionError(error) + else: + LOG.exception( + _LE('DBAPIError exception wrapped from %s') % error) + raise exception.DBError(error) + + +@filters('*', UnicodeEncodeError, r".*") +def _raise_for_unicode_encode(error, match, engine_name, is_disconnect): + raise exception.DBInvalidUnicodeParameter() + + +@filters("*", Exception, r".*") +def _raise_for_all_others(error, match, engine_name, is_disconnect): + LOG.exception(_LE('DB exception wrapped.')) + raise exception.DBError(error) + + +def handler(context): + """Iterate through available filters and invoke those which match. + + The first one which raises wins. The order in which the filters + are attempted is sorted by specificity - dialect name or "*", + exception class per method resolution order (``__mro__``). + Method resolution order is used so that filter rules indicating a + more specific exception class are attempted first. + + """ + def _dialect_registries(engine): + if engine.dialect.name in _registry: + yield _registry[engine.dialect.name] + if '*' in _registry: + yield _registry['*'] + + for per_dialect in _dialect_registries(context.engine): + for exc in ( + context.sqlalchemy_exception, + context.original_exception): + for super_ in exc.__class__.__mro__: + if super_ in per_dialect: + regexp_reg = per_dialect[super_] + for fn, regexp in regexp_reg: + match = regexp.match(exc.args[0]) + if match: + try: + fn( + exc, + match, + context.engine.dialect.name, + context.is_disconnect) + except exception.DBConnectionError: + context.is_disconnect = True + raise + + +def register_engine(engine): + compat.handle_error(engine, handler) + + +def handle_connect_error(engine): + """Handle connect error. + + Provide a special context that will allow on-connect errors + to be treated within the filtering context. + + This routine is dependent on SQLAlchemy version, as version 1.0.0 + provides this functionality natively. + + """ + with compat.handle_connect_context(handler, engine): + return engine.connect() diff --git a/oslo_db/sqlalchemy/migration.py b/oslo_db/sqlalchemy/migration.py new file mode 100644 index 00000000..308ce60c --- /dev/null +++ b/oslo_db/sqlalchemy/migration.py @@ -0,0 +1,160 @@ +# coding=utf-8 + +# Copyright (c) 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# Base on code in migrate/changeset/databases/sqlite.py which is under +# the following license: +# +# The MIT License +# +# Copyright (c) 2009 Evan Rosson, Jan Dittberner, Domen Kožar +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import os + +from migrate import exceptions as versioning_exceptions +from migrate.versioning import api as versioning_api +from migrate.versioning.repository import Repository +import sqlalchemy + +from oslo_db._i18n import _ +from oslo_db import exception + + +def db_sync(engine, abs_path, version=None, init_version=0, sanity_check=True): + """Upgrade or downgrade a database. + + Function runs the upgrade() or downgrade() functions in change scripts. + + :param engine: SQLAlchemy engine instance for a given database + :param abs_path: Absolute path to migrate repository. + :param version: Database will upgrade/downgrade until this version. + If None - database will update to the latest + available version. + :param init_version: Initial database version + :param sanity_check: Require schema sanity checking for all tables + """ + + if version is not None: + try: + version = int(version) + except ValueError: + raise exception.DbMigrationError( + message=_("version should be an integer")) + + current_version = db_version(engine, abs_path, init_version) + repository = _find_migrate_repo(abs_path) + if sanity_check: + _db_schema_sanity_check(engine) + if version is None or version > current_version: + return versioning_api.upgrade(engine, repository, version) + else: + return versioning_api.downgrade(engine, repository, + version) + + +def _db_schema_sanity_check(engine): + """Ensure all database tables were created with required parameters. + + :param engine: SQLAlchemy engine instance for a given database + + """ + + if engine.name == 'mysql': + onlyutf8_sql = ('SELECT TABLE_NAME,TABLE_COLLATION ' + 'from information_schema.TABLES ' + 'where TABLE_SCHEMA=%s and ' + 'TABLE_COLLATION NOT LIKE \'%%utf8%%\'') + + # NOTE(morganfainberg): exclude the sqlalchemy-migrate and alembic + # versioning tables from the tables we need to verify utf8 status on. + # Non-standard table names are not supported. + EXCLUDED_TABLES = ['migrate_version', 'alembic_version'] + + table_names = [res[0] for res in + engine.execute(onlyutf8_sql, engine.url.database) if + res[0].lower() not in EXCLUDED_TABLES] + + if len(table_names) > 0: + raise ValueError(_('Tables "%s" have non utf8 collation, ' + 'please make sure all tables are CHARSET=utf8' + ) % ','.join(table_names)) + + +def db_version(engine, abs_path, init_version): + """Show the current version of the repository. + + :param engine: SQLAlchemy engine instance for a given database + :param abs_path: Absolute path to migrate repository + :param version: Initial database version + """ + repository = _find_migrate_repo(abs_path) + try: + return versioning_api.db_version(engine, repository) + except versioning_exceptions.DatabaseNotControlledError: + meta = sqlalchemy.MetaData() + meta.reflect(bind=engine) + tables = meta.tables + if len(tables) == 0 or 'alembic_version' in tables: + db_version_control(engine, abs_path, version=init_version) + return versioning_api.db_version(engine, repository) + else: + raise exception.DbMigrationError( + message=_( + "The database is not under version control, but has " + "tables. Please stamp the current version of the schema " + "manually.")) + + +def db_version_control(engine, abs_path, version=None): + """Mark a database as under this repository's version control. + + Once a database is under version control, schema changes should + only be done via change scripts in this repository. + + :param engine: SQLAlchemy engine instance for a given database + :param abs_path: Absolute path to migrate repository + :param version: Initial database version + """ + repository = _find_migrate_repo(abs_path) + versioning_api.version_control(engine, repository, version) + return version + + +def _find_migrate_repo(abs_path): + """Get the project's change script repository + + :param abs_path: Absolute path to migrate repository + """ + if not os.path.exists(abs_path): + raise exception.DbMigrationError("Path %s not found" % abs_path) + return Repository(abs_path) diff --git a/oslo/db/sqlalchemy/migration_cli/README.rst b/oslo_db/sqlalchemy/migration_cli/README.rst similarity index 100% rename from oslo/db/sqlalchemy/migration_cli/README.rst rename to oslo_db/sqlalchemy/migration_cli/README.rst diff --git a/oslo_db/sqlalchemy/migration_cli/__init__.py b/oslo_db/sqlalchemy/migration_cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/oslo/db/sqlalchemy/migration_cli/ext_alembic.py b/oslo_db/sqlalchemy/migration_cli/ext_alembic.py similarity index 96% rename from oslo/db/sqlalchemy/migration_cli/ext_alembic.py rename to oslo_db/sqlalchemy/migration_cli/ext_alembic.py index be3d8c63..243ae47e 100644 --- a/oslo/db/sqlalchemy/migration_cli/ext_alembic.py +++ b/oslo_db/sqlalchemy/migration_cli/ext_alembic.py @@ -16,8 +16,8 @@ import alembic from alembic import config as alembic_config import alembic.migration as alembic_migration -from oslo.db.sqlalchemy.migration_cli import ext_base -from oslo.db.sqlalchemy import session as db_session +from oslo_db.sqlalchemy.migration_cli import ext_base +from oslo_db.sqlalchemy import session as db_session class AlembicExtension(ext_base.MigrationExtensionBase): diff --git a/oslo/db/sqlalchemy/migration_cli/ext_base.py b/oslo_db/sqlalchemy/migration_cli/ext_base.py similarity index 100% rename from oslo/db/sqlalchemy/migration_cli/ext_base.py rename to oslo_db/sqlalchemy/migration_cli/ext_base.py diff --git a/oslo/db/sqlalchemy/migration_cli/ext_migrate.py b/oslo_db/sqlalchemy/migration_cli/ext_migrate.py similarity index 92% rename from oslo/db/sqlalchemy/migration_cli/ext_migrate.py rename to oslo_db/sqlalchemy/migration_cli/ext_migrate.py index 758fe609..e31ee3d8 100644 --- a/oslo/db/sqlalchemy/migration_cli/ext_migrate.py +++ b/oslo_db/sqlalchemy/migration_cli/ext_migrate.py @@ -13,10 +13,10 @@ import logging import os -from oslo.db._i18n import _LE -from oslo.db.sqlalchemy import migration -from oslo.db.sqlalchemy.migration_cli import ext_base -from oslo.db.sqlalchemy import session as db_session +from oslo_db._i18n import _LE +from oslo_db.sqlalchemy import migration +from oslo_db.sqlalchemy.migration_cli import ext_base +from oslo_db.sqlalchemy import session as db_session LOG = logging.getLogger(__name__) diff --git a/oslo/db/sqlalchemy/migration_cli/manager.py b/oslo_db/sqlalchemy/migration_cli/manager.py similarity index 100% rename from oslo/db/sqlalchemy/migration_cli/manager.py rename to oslo_db/sqlalchemy/migration_cli/manager.py diff --git a/oslo_db/sqlalchemy/models.py b/oslo_db/sqlalchemy/models.py new file mode 100644 index 00000000..818c1b40 --- /dev/null +++ b/oslo_db/sqlalchemy/models.py @@ -0,0 +1,128 @@ +# Copyright (c) 2011 X.commerce, a business unit of eBay Inc. +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2011 Piston Cloud Computing, Inc. +# Copyright 2012 Cloudscaling Group, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +SQLAlchemy models. +""" + +import six + +from oslo.utils import timeutils +from sqlalchemy import Column, Integer +from sqlalchemy import DateTime +from sqlalchemy.orm import object_mapper + + +class ModelBase(six.Iterator): + """Base class for models.""" + __table_initialized__ = False + + def save(self, session): + """Save this object.""" + + # NOTE(boris-42): This part of code should be look like: + # session.add(self) + # session.flush() + # But there is a bug in sqlalchemy and eventlet that + # raises NoneType exception if there is no running + # transaction and rollback is called. As long as + # sqlalchemy has this bug we have to create transaction + # explicitly. + with session.begin(subtransactions=True): + session.add(self) + session.flush() + + def __setitem__(self, key, value): + setattr(self, key, value) + + def __getitem__(self, key): + return getattr(self, key) + + def __contains__(self, key): + return hasattr(self, key) + + def get(self, key, default=None): + return getattr(self, key, default) + + @property + def _extra_keys(self): + """Specifies custom fields + + Subclasses can override this property to return a list + of custom fields that should be included in their dict + representation. + + For reference check tests/db/sqlalchemy/test_models.py + """ + return [] + + def __iter__(self): + columns = list(dict(object_mapper(self).columns).keys()) + # NOTE(russellb): Allow models to specify other keys that can be looked + # up, beyond the actual db columns. An example would be the 'name' + # property for an Instance. + columns.extend(self._extra_keys) + + return ModelIterator(self, iter(columns)) + + def update(self, values): + """Make the model object behave like a dict.""" + for k, v in six.iteritems(values): + setattr(self, k, v) + + def iteritems(self): + """Make the model object behave like a dict. + + Includes attributes from joins. + """ + local = dict(self) + joined = dict([(k, v) for k, v in six.iteritems(self.__dict__) + if not k[0] == '_']) + local.update(joined) + return six.iteritems(local) + + +class ModelIterator(ModelBase, six.Iterator): + + def __init__(self, model, columns): + self.model = model + self.i = columns + + def __iter__(self): + return self + + # In Python 3, __next__() has replaced next(). + def __next__(self): + n = six.advance_iterator(self.i) + return n, getattr(self.model, n) + + +class TimestampMixin(object): + created_at = Column(DateTime, default=lambda: timeutils.utcnow()) + updated_at = Column(DateTime, onupdate=lambda: timeutils.utcnow()) + + +class SoftDeleteMixin(object): + deleted_at = Column(DateTime) + deleted = Column(Integer, default=0) + + def soft_delete(self, session): + """Mark this object as deleted.""" + self.deleted = self.id + self.deleted_at = timeutils.utcnow() + self.save(session=session) diff --git a/oslo_db/sqlalchemy/provision.py b/oslo_db/sqlalchemy/provision.py new file mode 100644 index 00000000..4f74bc65 --- /dev/null +++ b/oslo_db/sqlalchemy/provision.py @@ -0,0 +1,507 @@ +# Copyright 2013 Mirantis.inc +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Provision test environment for specific DB backends""" + +import abc +import argparse +import logging +import os +import random +import re +import string + +import six +from six import moves +import sqlalchemy +from sqlalchemy.engine import url as sa_url + +from oslo_db._i18n import _LI +from oslo_db import exception +from oslo_db.sqlalchemy import session +from oslo_db.sqlalchemy import utils + +LOG = logging.getLogger(__name__) + + +class ProvisionedDatabase(object): + """Represent a single database node that can be used for testing in + + a serialized fashion. + + ``ProvisionedDatabase`` includes features for full lifecycle management + of a node, in a way that is context-specific. Depending on how the + test environment runs, ``ProvisionedDatabase`` should know if it needs + to create and drop databases or if it is making use of a database that + is maintained by an external process. + + """ + + def __init__(self, database_type): + self.backend = Backend.backend_for_database_type(database_type) + self.db_token = _random_ident() + + self.backend.create_named_database(self.db_token) + self.engine = self.backend.provisioned_engine(self.db_token) + + def dispose(self): + self.engine.dispose() + self.backend.drop_named_database(self.db_token) + + +class Backend(object): + """Represent a particular database backend that may be provisionable. + + The ``Backend`` object maintains a database type (e.g. database without + specific driver type, such as "sqlite", "postgresql", etc.), + a target URL, a base ``Engine`` for that URL object that can be used + to provision databases and a ``BackendImpl`` which knows how to perform + operations against this type of ``Engine``. + + """ + + backends_by_database_type = {} + + def __init__(self, database_type, url): + self.database_type = database_type + self.url = url + self.verified = False + self.engine = None + self.impl = BackendImpl.impl(database_type) + Backend.backends_by_database_type[database_type] = self + + @classmethod + def backend_for_database_type(cls, database_type): + """Return and verify the ``Backend`` for the given database type. + + Creates the engine if it does not already exist and raises + ``BackendNotAvailable`` if it cannot be produced. + + :return: a base ``Engine`` that allows provisioning of databases. + + :raises: ``BackendNotAvailable``, if an engine for this backend + cannot be produced. + + """ + try: + backend = cls.backends_by_database_type[database_type] + except KeyError: + raise exception.BackendNotAvailable(database_type) + else: + return backend._verify() + + @classmethod + def all_viable_backends(cls): + """Return an iterator of all ``Backend`` objects that are present + + and provisionable. + + """ + + for backend in cls.backends_by_database_type.values(): + try: + yield backend._verify() + except exception.BackendNotAvailable: + pass + + def _verify(self): + """Verify that this ``Backend`` is available and provisionable. + + :return: this ``Backend`` + + :raises: ``BackendNotAvailable`` if the backend is not available. + + """ + + if not self.verified: + try: + eng = self._ensure_backend_available(self.url) + except exception.BackendNotAvailable: + raise + else: + self.engine = eng + finally: + self.verified = True + if self.engine is None: + raise exception.BackendNotAvailable(self.database_type) + return self + + @classmethod + def _ensure_backend_available(cls, url): + url = sa_url.make_url(str(url)) + try: + eng = sqlalchemy.create_engine(url) + except ImportError as i_e: + # SQLAlchemy performs an "import" of the DBAPI module + # within create_engine(). So if ibm_db_sa, cx_oracle etc. + # isn't installed, we get an ImportError here. + LOG.info( + _LI("The %(dbapi)s backend is unavailable: %(err)s"), + dict(dbapi=url.drivername, err=i_e)) + raise exception.BackendNotAvailable("No DBAPI installed") + else: + try: + conn = eng.connect() + except sqlalchemy.exc.DBAPIError as d_e: + # upon connect, SQLAlchemy calls dbapi.connect(). This + # usually raises OperationalError and should always at + # least raise a SQLAlchemy-wrapped DBAPI Error. + LOG.info( + _LI("The %(dbapi)s backend is unavailable: %(err)s"), + dict(dbapi=url.drivername, err=d_e) + ) + raise exception.BackendNotAvailable("Could not connect") + else: + conn.close() + return eng + + def create_named_database(self, ident): + """Create a database with the given name.""" + + self.impl.create_named_database(self.engine, ident) + + def drop_named_database(self, ident, conditional=False): + """Drop a database with the given name.""" + + self.impl.drop_named_database( + self.engine, ident, + conditional=conditional) + + def database_exists(self, ident): + """Return True if a database of the given name exists.""" + + return self.impl.database_exists(self.engine, ident) + + def provisioned_engine(self, ident): + """Given the URL of a particular database backend and the string + + name of a particular 'database' within that backend, return + an Engine instance whose connections will refer directly to the + named database. + + For hostname-based URLs, this typically involves switching just the + 'database' portion of the URL with the given name and creating + an engine. + + For URLs that instead deal with DSNs, the rules may be more custom; + for example, the engine may need to connect to the root URL and + then emit a command to switch to the named database. + + """ + return self.impl.provisioned_engine(self.url, ident) + + @classmethod + def _setup(cls): + """Initial startup feature will scan the environment for configured + + URLs and place them into the list of URLs we will use for provisioning. + + This searches through OS_TEST_DBAPI_ADMIN_CONNECTION for URLs. If + not present, we set up URLs based on the "opportunstic" convention, + e.g. username+password = "openstack_citest". + + The provisioning system will then use or discard these URLs as they + are requested, based on whether or not the target database is actually + found to be available. + + """ + configured_urls = os.getenv('OS_TEST_DBAPI_ADMIN_CONNECTION', None) + if configured_urls: + configured_urls = configured_urls.split(";") + else: + configured_urls = [ + impl.create_opportunistic_driver_url() + for impl in BackendImpl.all_impls() + ] + + for url_str in configured_urls: + url = sa_url.make_url(url_str) + m = re.match(r'([^+]+?)(?:\+(.+))?$', url.drivername) + database_type, drivertype = m.group(1, 2) + Backend(database_type, url) + + +@six.add_metaclass(abc.ABCMeta) +class BackendImpl(object): + """Provide database-specific implementations of key provisioning + + functions. + + ``BackendImpl`` is owned by a ``Backend`` instance which delegates + to it for all database-specific features. + + """ + + @classmethod + def all_impls(cls): + """Return an iterator of all possible BackendImpl objects. + + These are BackendImpls that are implemented, but not + necessarily provisionable. + + """ + for database_type in cls.impl.reg: + if database_type == '*': + continue + yield BackendImpl.impl(database_type) + + @utils.dispatch_for_dialect("*") + def impl(drivername): + """Return a ``BackendImpl`` instance corresponding to the + + given driver name. + + This is a dispatched method which will refer to the constructor + of implementing subclasses. + + """ + raise NotImplementedError( + "No provision impl available for driver: %s" % drivername) + + def __init__(self, drivername): + self.drivername = drivername + + @abc.abstractmethod + def create_opportunistic_driver_url(self): + """Produce a string url known as the 'opportunistic' URL. + + This URL is one that corresponds to an established Openstack + convention for a pre-established database login, which, when + detected as available in the local environment, is automatically + used as a test platform for a specific type of driver. + + """ + + @abc.abstractmethod + def create_named_database(self, engine, ident): + """Create a database with the given name.""" + + @abc.abstractmethod + def drop_named_database(self, engine, ident, conditional=False): + """Drop a database with the given name.""" + + def provisioned_engine(self, base_url, ident): + """Return a provisioned engine. + + Given the URL of a particular database backend and the string + name of a particular 'database' within that backend, return + an Engine instance whose connections will refer directly to the + named database. + + For hostname-based URLs, this typically involves switching just the + 'database' portion of the URL with the given name and creating + an engine. + + For URLs that instead deal with DSNs, the rules may be more custom; + for example, the engine may need to connect to the root URL and + then emit a command to switch to the named database. + + """ + + url = sa_url.make_url(str(base_url)) + url.database = ident + return session.create_engine( + url, + logging_name="%s@%s" % (self.drivername, ident)) + + +@BackendImpl.impl.dispatch_for("mysql") +class MySQLBackendImpl(BackendImpl): + def create_opportunistic_driver_url(self): + return "mysql://openstack_citest:openstack_citest@localhost/" + + def create_named_database(self, engine, ident): + with engine.connect() as conn: + conn.execute("CREATE DATABASE %s" % ident) + + def drop_named_database(self, engine, ident, conditional=False): + with engine.connect() as conn: + if not conditional or self.database_exists(conn, ident): + conn.execute("DROP DATABASE %s" % ident) + + def database_exists(self, engine, ident): + return bool(engine.scalar("SHOW DATABASES LIKE '%s'" % ident)) + + +@BackendImpl.impl.dispatch_for("sqlite") +class SQLiteBackendImpl(BackendImpl): + def create_opportunistic_driver_url(self): + return "sqlite://" + + def create_named_database(self, engine, ident): + url = self._provisioned_database_url(engine.url, ident) + eng = sqlalchemy.create_engine(url) + eng.connect().close() + + def provisioned_engine(self, base_url, ident): + return session.create_engine( + self._provisioned_database_url(base_url, ident)) + + def drop_named_database(self, engine, ident, conditional=False): + url = self._provisioned_database_url(engine.url, ident) + filename = url.database + if filename and (not conditional or os.access(filename, os.F_OK)): + os.remove(filename) + + def database_exists(self, engine, ident): + url = self._provisioned_database_url(engine.url, ident) + filename = url.database + return not filename or os.access(filename, os.F_OK) + + def _provisioned_database_url(self, base_url, ident): + if base_url.database: + return sa_url.make_url("sqlite:////tmp/%s.db" % ident) + else: + return base_url + + +@BackendImpl.impl.dispatch_for("postgresql") +class PostgresqlBackendImpl(BackendImpl): + def create_opportunistic_driver_url(self): + return "postgresql://openstack_citest:openstack_citest"\ + "@localhost/postgres" + + def create_named_database(self, engine, ident): + with engine.connect().execution_options( + isolation_level="AUTOCOMMIT") as conn: + conn.execute("CREATE DATABASE %s" % ident) + + def drop_named_database(self, engine, ident, conditional=False): + with engine.connect().execution_options( + isolation_level="AUTOCOMMIT") as conn: + self._close_out_database_users(conn, ident) + if conditional: + conn.execute("DROP DATABASE IF EXISTS %s" % ident) + else: + conn.execute("DROP DATABASE %s" % ident) + + def database_exists(self, engine, ident): + return bool( + engine.scalar( + sqlalchemy.text( + "select datname from pg_database " + "where datname=:name"), name=ident) + ) + + def _close_out_database_users(self, conn, ident): + """Attempt to guarantee a database can be dropped. + + Optional feature which guarantees no connections with our + username are attached to the DB we're going to drop. + + This method has caveats; for one, the 'pid' column was named + 'procpid' prior to Postgresql 9.2. But more critically, + prior to 9.2 this operation required superuser permissions, + even if the connections we're closing are under the same username + as us. In more recent versions this restriction has been + lifted for same-user connections. + + """ + if conn.dialect.server_version_info >= (9, 2): + conn.execute( + sqlalchemy.text( + "select pg_terminate_backend(pid) " + "from pg_stat_activity " + "where usename=current_user and " + "pid != pg_backend_pid() " + "and datname=:dname" + ), dname=ident) + + +def _random_ident(): + return ''.join( + random.choice(string.ascii_lowercase) + for i in moves.range(10)) + + +def _echo_cmd(args): + idents = [_random_ident() for i in moves.range(args.instances_count)] + print("\n".join(idents)) + + +def _create_cmd(args): + idents = [_random_ident() for i in moves.range(args.instances_count)] + + for backend in Backend.all_viable_backends(): + for ident in idents: + backend.create_named_database(ident) + + print("\n".join(idents)) + + +def _drop_cmd(args): + for backend in Backend.all_viable_backends(): + for ident in args.instances: + backend.drop_named_database(ident, args.conditional) + +Backend._setup() + + +def main(argv=None): + """Command line interface to create/drop databases. + + ::create: Create test database with random names. + ::drop: Drop database created by previous command. + ::echo: create random names and display them; don't create. + """ + parser = argparse.ArgumentParser( + description='Controller to handle database creation and dropping' + ' commands.', + epilog='Typically called by the test runner, e.g. shell script, ' + 'testr runner via .testr.conf, or other system.') + subparsers = parser.add_subparsers( + help='Subcommands to manipulate temporary test databases.') + + create = subparsers.add_parser( + 'create', + help='Create temporary test databases.') + create.set_defaults(which=_create_cmd) + create.add_argument( + 'instances_count', + type=int, + help='Number of databases to create.') + + drop = subparsers.add_parser( + 'drop', + help='Drop temporary test databases.') + drop.set_defaults(which=_drop_cmd) + drop.add_argument( + 'instances', + nargs='+', + help='List of databases uri to be dropped.') + drop.add_argument( + '--conditional', + action="store_true", + help="Check if database exists first before dropping" + ) + + echo = subparsers.add_parser( + 'echo', + help="Create random database names and display only." + ) + echo.set_defaults(which=_echo_cmd) + echo.add_argument( + 'instances_count', + type=int, + help='Number of identifiers to create.') + + args = parser.parse_args(argv) + + cmd = args.which + cmd(args) + + +if __name__ == "__main__": + main() diff --git a/oslo_db/sqlalchemy/session.py b/oslo_db/sqlalchemy/session.py new file mode 100644 index 00000000..24bf31d4 --- /dev/null +++ b/oslo_db/sqlalchemy/session.py @@ -0,0 +1,847 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Session Handling for SQLAlchemy backend. + +Recommended ways to use sessions within this framework: + +* Don't use them explicitly; this is like running with ``AUTOCOMMIT=1``. + `model_query()` will implicitly use a session when called without one + supplied. This is the ideal situation because it will allow queries + to be automatically retried if the database connection is interrupted. + + .. note:: Automatic retry will be enabled in a future patch. + + It is generally fine to issue several queries in a row like this. Even though + they may be run in separate transactions and/or separate sessions, each one + will see the data from the prior calls. If needed, undo- or rollback-like + functionality should be handled at a logical level. For an example, look at + the code around quotas and `reservation_rollback()`. + + Examples: + + .. code-block:: python + + def get_foo(context, foo): + return (model_query(context, models.Foo). + filter_by(foo=foo). + first()) + + def update_foo(context, id, newfoo): + (model_query(context, models.Foo). + filter_by(id=id). + update({'foo': newfoo})) + + def create_foo(context, values): + foo_ref = models.Foo() + foo_ref.update(values) + foo_ref.save() + return foo_ref + + +* Within the scope of a single method, keep all the reads and writes within + the context managed by a single session. In this way, the session's + `__exit__` handler will take care of calling `flush()` and `commit()` for + you. If using this approach, you should not explicitly call `flush()` or + `commit()`. Any error within the context of the session will cause the + session to emit a `ROLLBACK`. Database errors like `IntegrityError` will be + raised in `session`'s `__exit__` handler, and any try/except within the + context managed by `session` will not be triggered. And catching other + non-database errors in the session will not trigger the ROLLBACK, so + exception handlers should always be outside the session, unless the + developer wants to do a partial commit on purpose. If the connection is + dropped before this is possible, the database will implicitly roll back the + transaction. + + .. note:: Statements in the session scope will not be automatically retried. + + If you create models within the session, they need to be added, but you + do not need to call `model.save()`: + + .. code-block:: python + + def create_many_foo(context, foos): + session = sessionmaker() + with session.begin(): + for foo in foos: + foo_ref = models.Foo() + foo_ref.update(foo) + session.add(foo_ref) + + def update_bar(context, foo_id, newbar): + session = sessionmaker() + with session.begin(): + foo_ref = (model_query(context, models.Foo, session). + filter_by(id=foo_id). + first()) + (model_query(context, models.Bar, session). + filter_by(id=foo_ref['bar_id']). + update({'bar': newbar})) + + .. note:: `update_bar` is a trivially simple example of using + ``with session.begin``. Whereas `create_many_foo` is a good example of + when a transaction is needed, it is always best to use as few queries as + possible. + + The two queries in `update_bar` can be better expressed using a single query + which avoids the need for an explicit transaction. It can be expressed like + so: + + .. code-block:: python + + def update_bar(context, foo_id, newbar): + subq = (model_query(context, models.Foo.id). + filter_by(id=foo_id). + limit(1). + subquery()) + (model_query(context, models.Bar). + filter_by(id=subq.as_scalar()). + update({'bar': newbar})) + + For reference, this emits approximately the following SQL statement: + + .. code-block:: sql + + UPDATE bar SET bar = ${newbar} + WHERE id=(SELECT bar_id FROM foo WHERE id = ${foo_id} LIMIT 1); + + .. note:: `create_duplicate_foo` is a trivially simple example of catching an + exception while using ``with session.begin``. Here create two duplicate + instances with same primary key, must catch the exception out of context + managed by a single session: + + .. code-block:: python + + def create_duplicate_foo(context): + foo1 = models.Foo() + foo2 = models.Foo() + foo1.id = foo2.id = 1 + session = sessionmaker() + try: + with session.begin(): + session.add(foo1) + session.add(foo2) + except exception.DBDuplicateEntry as e: + handle_error(e) + +* Passing an active session between methods. Sessions should only be passed + to private methods. The private method must use a subtransaction; otherwise + SQLAlchemy will throw an error when you call `session.begin()` on an existing + transaction. Public methods should not accept a session parameter and should + not be involved in sessions within the caller's scope. + + Note that this incurs more overhead in SQLAlchemy than the above means + due to nesting transactions, and it is not possible to implicitly retry + failed database operations when using this approach. + + This also makes code somewhat more difficult to read and debug, because a + single database transaction spans more than one method. Error handling + becomes less clear in this situation. When this is needed for code clarity, + it should be clearly documented. + + .. code-block:: python + + def myfunc(foo): + session = sessionmaker() + with session.begin(): + # do some database things + bar = _private_func(foo, session) + return bar + + def _private_func(foo, session=None): + if not session: + session = sessionmaker() + with session.begin(subtransaction=True): + # do some other database things + return bar + + +There are some things which it is best to avoid: + +* Don't keep a transaction open any longer than necessary. + + This means that your ``with session.begin()`` block should be as short + as possible, while still containing all the related calls for that + transaction. + +* Avoid ``with_lockmode('UPDATE')`` when possible. + + In MySQL/InnoDB, when a ``SELECT ... FOR UPDATE`` query does not match + any rows, it will take a gap-lock. This is a form of write-lock on the + "gap" where no rows exist, and prevents any other writes to that space. + This can effectively prevent any INSERT into a table by locking the gap + at the end of the index. Similar problems will occur if the SELECT FOR UPDATE + has an overly broad WHERE clause, or doesn't properly use an index. + + One idea proposed at ODS Fall '12 was to use a normal SELECT to test the + number of rows matching a query, and if only one row is returned, + then issue the SELECT FOR UPDATE. + + The better long-term solution is to use + ``INSERT .. ON DUPLICATE KEY UPDATE``. + However, this can not be done until the "deleted" columns are removed and + proper UNIQUE constraints are added to the tables. + + +Enabling soft deletes: + +* To use/enable soft-deletes, the `SoftDeleteMixin` must be added + to your model class. For example: + + .. code-block:: python + + class NovaBase(models.SoftDeleteMixin, models.ModelBase): + pass + + +Efficient use of soft deletes: + +* There are two possible ways to mark a record as deleted: + `model.soft_delete()` and `query.soft_delete()`. + + The `model.soft_delete()` method works with a single already-fetched entry. + `query.soft_delete()` makes only one db request for all entries that + correspond to the query. + +* In almost all cases you should use `query.soft_delete()`. Some examples: + + .. code-block:: python + + def soft_delete_bar(): + count = model_query(BarModel).find(some_condition).soft_delete() + if count == 0: + raise Exception("0 entries were soft deleted") + + def complex_soft_delete_with_synchronization_bar(session=None): + if session is None: + session = sessionmaker() + with session.begin(subtransactions=True): + count = (model_query(BarModel). + find(some_condition). + soft_delete(synchronize_session=True)) + # Here synchronize_session is required, because we + # don't know what is going on in outer session. + if count == 0: + raise Exception("0 entries were soft deleted") + +* There is only one situation where `model.soft_delete()` is appropriate: when + you fetch a single record, work with it, and mark it as deleted in the same + transaction. + + .. code-block:: python + + def soft_delete_bar_model(): + session = sessionmaker() + with session.begin(): + bar_ref = model_query(BarModel).find(some_condition).first() + # Work with bar_ref + bar_ref.soft_delete(session=session) + + However, if you need to work with all entries that correspond to query and + then soft delete them you should use the `query.soft_delete()` method: + + .. code-block:: python + + def soft_delete_multi_models(): + session = sessionmaker() + with session.begin(): + query = (model_query(BarModel, session=session). + find(some_condition)) + model_refs = query.all() + # Work with model_refs + query.soft_delete(synchronize_session=False) + # synchronize_session=False should be set if there is no outer + # session and these entries are not used after this. + + When working with many rows, it is very important to use query.soft_delete, + which issues a single query. Using `model.soft_delete()`, as in the following + example, is very inefficient. + + .. code-block:: python + + for bar_ref in bar_refs: + bar_ref.soft_delete(session=session) + # This will produce count(bar_refs) db requests. + +""" + +import itertools +import logging +import re +import time + +from oslo.utils import timeutils +import six +import sqlalchemy.orm +from sqlalchemy import pool +from sqlalchemy.sql.expression import literal_column +from sqlalchemy.sql.expression import select + +from oslo_db._i18n import _LW +from oslo_db import exception +from oslo_db import options +from oslo_db.sqlalchemy import compat +from oslo_db.sqlalchemy import exc_filters +from oslo_db.sqlalchemy import utils + +LOG = logging.getLogger(__name__) + + +def _thread_yield(dbapi_con, con_record): + """Ensure other greenthreads get a chance to be executed. + + If we use eventlet.monkey_patch(), eventlet.greenthread.sleep(0) will + execute instead of time.sleep(0). + Force a context switch. With common database backends (eg MySQLdb and + sqlite), there is no implicit yield caused by network I/O since they are + implemented by C libraries that eventlet cannot monkey patch. + """ + time.sleep(0) + + +def _connect_ping_listener(connection, branch): + """Ping the server at connection startup. + + Ping the server at transaction begin and transparently reconnect + if a disconnect exception occurs. + """ + if branch: + return + + # turn off "close with result". This can also be accomplished + # by branching the connection, however just setting the flag is + # more performant and also doesn't get involved with some + # connection-invalidation awkardness that occurs (see + # https://bitbucket.org/zzzeek/sqlalchemy/issue/3215/) + save_should_close_with_result = connection.should_close_with_result + connection.should_close_with_result = False + try: + # run a SELECT 1. use a core select() so that + # any details like that needed by Oracle, DB2 etc. are handled. + connection.scalar(select([1])) + except exception.DBConnectionError: + # catch DBConnectionError, which is raised by the filter + # system. + # disconnect detected. The connection is now + # "invalid", but the pool should be ready to return + # new connections assuming they are good now. + # run the select again to re-validate the Connection. + connection.scalar(select([1])) + finally: + connection.should_close_with_result = save_should_close_with_result + + +def _setup_logging(connection_debug=0): + """setup_logging function maps SQL debug level to Python log level. + + Connection_debug is a verbosity of SQL debugging information. + 0=None(default value), + 1=Processed only messages with WARNING level or higher + 50=Processed only messages with INFO level or higher + 100=Processed only messages with DEBUG level + """ + if connection_debug >= 0: + logger = logging.getLogger('sqlalchemy.engine') + if connection_debug >= 100: + logger.setLevel(logging.DEBUG) + elif connection_debug >= 50: + logger.setLevel(logging.INFO) + else: + logger.setLevel(logging.WARNING) + + +def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None, + idle_timeout=3600, + connection_debug=0, max_pool_size=None, max_overflow=None, + pool_timeout=None, sqlite_synchronous=True, + connection_trace=False, max_retries=10, retry_interval=10, + thread_checkin=True, logging_name=None): + """Return a new SQLAlchemy engine.""" + + url = sqlalchemy.engine.url.make_url(sql_connection) + + engine_args = { + "pool_recycle": idle_timeout, + 'convert_unicode': True, + 'connect_args': {}, + 'logging_name': logging_name + } + + _setup_logging(connection_debug) + + _init_connection_args( + url, engine_args, + sqlite_fk=sqlite_fk, + max_pool_size=max_pool_size, + max_overflow=max_overflow, + pool_timeout=pool_timeout + ) + + engine = sqlalchemy.create_engine(url, **engine_args) + + _init_events( + engine, + mysql_sql_mode=mysql_sql_mode, + sqlite_synchronous=sqlite_synchronous, + sqlite_fk=sqlite_fk, + thread_checkin=thread_checkin, + connection_trace=connection_trace + ) + + # register alternate exception handler + exc_filters.register_engine(engine) + + # register engine connect handler + compat.engine_connect(engine, _connect_ping_listener) + + # initial connect + test + _test_connection(engine, max_retries, retry_interval) + + return engine + + +@utils.dispatch_for_dialect('*', multiple=True) +def _init_connection_args( + url, engine_args, + max_pool_size=None, max_overflow=None, pool_timeout=None, **kw): + + pool_class = url.get_dialect().get_pool_class(url) + if issubclass(pool_class, pool.QueuePool): + if max_pool_size is not None: + engine_args['pool_size'] = max_pool_size + if max_overflow is not None: + engine_args['max_overflow'] = max_overflow + if pool_timeout is not None: + engine_args['pool_timeout'] = pool_timeout + + +@_init_connection_args.dispatch_for("sqlite") +def _init_connection_args(url, engine_args, **kw): + pool_class = url.get_dialect().get_pool_class(url) + # singletonthreadpool is used for :memory: connections; + # replace it with StaticPool. + if issubclass(pool_class, pool.SingletonThreadPool): + engine_args["poolclass"] = pool.StaticPool + engine_args['connect_args']['check_same_thread'] = False + + +@_init_connection_args.dispatch_for("postgresql") +def _init_connection_args(url, engine_args, **kw): + if 'client_encoding' not in url.query: + # Set encoding using engine_args instead of connect_args since + # it's supported for PostgreSQL 8.*. More details at: + # http://docs.sqlalchemy.org/en/rel_0_9/dialects/postgresql.html + engine_args['client_encoding'] = 'utf8' + + +@_init_connection_args.dispatch_for("mysql") +def _init_connection_args(url, engine_args, **kw): + if 'charset' not in url.query: + engine_args['connect_args']['charset'] = 'utf8' + + +@_init_connection_args.dispatch_for("mysql+mysqlconnector") +def _init_connection_args(url, engine_args, **kw): + # mysqlconnector engine (<1.0) incorrectly defaults to + # raise_on_warnings=True + # https://bitbucket.org/zzzeek/sqlalchemy/issue/2515 + if 'raise_on_warnings' not in url.query: + engine_args['connect_args']['raise_on_warnings'] = False + + +@_init_connection_args.dispatch_for("mysql+mysqldb") +@_init_connection_args.dispatch_for("mysql+oursql") +def _init_connection_args(url, engine_args, **kw): + # Those drivers require use_unicode=0 to avoid performance drop due + # to internal usage of Python unicode objects in the driver + # http://docs.sqlalchemy.org/en/rel_0_9/dialects/mysql.html + if 'use_unicode' not in url.query: + engine_args['connect_args']['use_unicode'] = 0 + + +@utils.dispatch_for_dialect('*', multiple=True) +def _init_events(engine, thread_checkin=True, connection_trace=False, **kw): + """Set up event listeners for all database backends.""" + + if connection_trace: + _add_trace_comments(engine) + + if thread_checkin: + sqlalchemy.event.listen(engine, 'checkin', _thread_yield) + + +@_init_events.dispatch_for("mysql") +def _init_events(engine, mysql_sql_mode=None, **kw): + """Set up event listeners for MySQL.""" + + if mysql_sql_mode is not None: + @sqlalchemy.event.listens_for(engine, "connect") + def _set_session_sql_mode(dbapi_con, connection_rec): + cursor = dbapi_con.cursor() + cursor.execute("SET SESSION sql_mode = %s", [mysql_sql_mode]) + + @sqlalchemy.event.listens_for(engine, "first_connect") + def _check_effective_sql_mode(dbapi_con, connection_rec): + if mysql_sql_mode is not None: + _set_session_sql_mode(dbapi_con, connection_rec) + + cursor = dbapi_con.cursor() + cursor.execute("SHOW VARIABLES LIKE 'sql_mode'") + realmode = cursor.fetchone() + + if realmode is None: + LOG.warning(_LW('Unable to detect effective SQL mode')) + else: + realmode = realmode[1] + LOG.debug('MySQL server mode set to %s', realmode) + if 'TRADITIONAL' not in realmode.upper() and \ + 'STRICT_ALL_TABLES' not in realmode.upper(): + LOG.warning( + _LW( + "MySQL SQL mode is '%s', " + "consider enabling TRADITIONAL or STRICT_ALL_TABLES"), + realmode) + + +@_init_events.dispatch_for("sqlite") +def _init_events(engine, sqlite_synchronous=True, sqlite_fk=False, **kw): + """Set up event listeners for SQLite. + + This includes several settings made on connections as they are + created, as well as transactional control extensions. + + """ + + def regexp(expr, item): + reg = re.compile(expr) + return reg.search(six.text_type(item)) is not None + + @sqlalchemy.event.listens_for(engine, "connect") + def _sqlite_connect_events(dbapi_con, con_record): + + # Add REGEXP functionality on SQLite connections + dbapi_con.create_function('regexp', 2, regexp) + + if not sqlite_synchronous: + # Switch sqlite connections to non-synchronous mode + dbapi_con.execute("PRAGMA synchronous = OFF") + + # Disable pysqlite's emitting of the BEGIN statement entirely. + # Also stops it from emitting COMMIT before any DDL. + # below, we emit BEGIN ourselves. + # see http://docs.sqlalchemy.org/en/rel_0_9/dialects/\ + # sqlite.html#serializable-isolation-savepoints-transactional-ddl + dbapi_con.isolation_level = None + + if sqlite_fk: + # Ensures that the foreign key constraints are enforced in SQLite. + dbapi_con.execute('pragma foreign_keys=ON') + + @sqlalchemy.event.listens_for(engine, "begin") + def _sqlite_emit_begin(conn): + # emit our own BEGIN, checking for existing + # transactional state + if 'in_transaction' not in conn.info: + conn.execute("BEGIN") + conn.info['in_transaction'] = True + + @sqlalchemy.event.listens_for(engine, "rollback") + @sqlalchemy.event.listens_for(engine, "commit") + def _sqlite_end_transaction(conn): + # remove transactional marker + conn.info.pop('in_transaction', None) + + +def _test_connection(engine, max_retries, retry_interval): + if max_retries == -1: + attempts = itertools.count() + else: + attempts = six.moves.range(max_retries) + # See: http://legacy.python.org/dev/peps/pep-3110/#semantic-changes for + # why we are not using 'de' directly (it can be removed from the local + # scope). + de_ref = None + for attempt in attempts: + try: + return exc_filters.handle_connect_error(engine) + except exception.DBConnectionError as de: + msg = _LW('SQL connection failed. %s attempts left.') + LOG.warning(msg, max_retries - attempt) + time.sleep(retry_interval) + de_ref = de + else: + if de_ref is not None: + six.reraise(type(de_ref), de_ref) + + +class Query(sqlalchemy.orm.query.Query): + """Subclass of sqlalchemy.query with soft_delete() method.""" + def soft_delete(self, synchronize_session='evaluate'): + return self.update({'deleted': literal_column('id'), + 'updated_at': literal_column('updated_at'), + 'deleted_at': timeutils.utcnow()}, + synchronize_session=synchronize_session) + + +class Session(sqlalchemy.orm.session.Session): + """Custom Session class to avoid SqlAlchemy Session monkey patching.""" + + +def get_maker(engine, autocommit=True, expire_on_commit=False): + """Return a SQLAlchemy sessionmaker using the given engine.""" + return sqlalchemy.orm.sessionmaker(bind=engine, + class_=Session, + autocommit=autocommit, + expire_on_commit=expire_on_commit, + query_cls=Query) + + +def _add_trace_comments(engine): + """Add trace comments. + + Augment statements with a trace of the immediate calling code + for a given statement. + """ + + import os + import sys + import traceback + target_paths = set([ + os.path.dirname(sys.modules['oslo_db'].__file__), + os.path.dirname(sys.modules['sqlalchemy'].__file__) + ]) + + @sqlalchemy.event.listens_for(engine, "before_cursor_execute", retval=True) + def before_cursor_execute(conn, cursor, statement, parameters, context, + executemany): + + # NOTE(zzzeek) - if different steps per DB dialect are desirable + # here, switch out on engine.name for now. + stack = traceback.extract_stack() + our_line = None + for idx, (filename, line, method, function) in enumerate(stack): + for tgt in target_paths: + if filename.startswith(tgt): + our_line = idx + break + if our_line: + break + + if our_line: + trace = "; ".join( + "File: %s (%s) %s" % ( + line[0], line[1], line[2] + ) + # include three lines of context. + for line in stack[our_line - 3:our_line] + + ) + statement = "%s -- %s" % (statement, trace) + + return statement, parameters + + +class EngineFacade(object): + """A helper class for removing of global engine instances from oslo.db. + + As a library, oslo.db can't decide where to store/when to create engine + and sessionmaker instances, so this must be left for a target application. + + On the other hand, in order to simplify the adoption of oslo.db changes, + we'll provide a helper class, which creates engine and sessionmaker + on its instantiation and provides get_engine()/get_session() methods + that are compatible with corresponding utility functions that currently + exist in target projects, e.g. in Nova. + + engine/sessionmaker instances will still be global (and they are meant to + be global), but they will be stored in the app context, rather that in the + oslo.db context. + + Note: using of this helper is completely optional and you are encouraged to + integrate engine/sessionmaker instances into your apps any way you like + (e.g. one might want to bind a session to a request context). Two important + things to remember: + + 1. An Engine instance is effectively a pool of DB connections, so it's + meant to be shared (and it's thread-safe). + 2. A Session instance is not meant to be shared and represents a DB + transactional context (i.e. it's not thread-safe). sessionmaker is + a factory of sessions. + + """ + + def __init__(self, sql_connection, slave_connection=None, + sqlite_fk=False, autocommit=True, + expire_on_commit=False, **kwargs): + """Initialize engine and sessionmaker instances. + + :param sql_connection: the connection string for the database to use + :type sql_connection: string + + :param slave_connection: the connection string for the 'slave' database + to use. If not provided, the master database + will be used for all operations. Note: this + is meant to be used for offloading of read + operations to asynchronously replicated slaves + to reduce the load on the master database. + :type slave_connection: string + + :param sqlite_fk: enable foreign keys in SQLite + :type sqlite_fk: bool + + :param autocommit: use autocommit mode for created Session instances + :type autocommit: bool + + :param expire_on_commit: expire session objects on commit + :type expire_on_commit: bool + + Keyword arguments: + + :keyword mysql_sql_mode: the SQL mode to be used for MySQL sessions. + (defaults to TRADITIONAL) + :keyword idle_timeout: timeout before idle sql connections are reaped + (defaults to 3600) + :keyword connection_debug: verbosity of SQL debugging information. + -1=Off, 0=None, 100=Everything (defaults + to 0) + :keyword max_pool_size: maximum number of SQL connections to keep open + in a pool (defaults to SQLAlchemy settings) + :keyword max_overflow: if set, use this value for max_overflow with + sqlalchemy (defaults to SQLAlchemy settings) + :keyword pool_timeout: if set, use this value for pool_timeout with + sqlalchemy (defaults to SQLAlchemy settings) + :keyword sqlite_synchronous: if True, SQLite uses synchronous mode + (defaults to True) + :keyword connection_trace: add python stack traces to SQL as comment + strings (defaults to False) + :keyword max_retries: maximum db connection retries during startup. + (setting -1 implies an infinite retry count) + (defaults to 10) + :keyword retry_interval: interval between retries of opening a sql + connection (defaults to 10) + :keyword thread_checkin: boolean that indicates that between each + engine checkin event a sleep(0) will occur to + allow other greenthreads to run (defaults to + True) + """ + + super(EngineFacade, self).__init__() + + engine_kwargs = { + 'sqlite_fk': sqlite_fk, + 'mysql_sql_mode': kwargs.get('mysql_sql_mode', 'TRADITIONAL'), + 'idle_timeout': kwargs.get('idle_timeout', 3600), + 'connection_debug': kwargs.get('connection_debug', 0), + 'max_pool_size': kwargs.get('max_pool_size'), + 'max_overflow': kwargs.get('max_overflow'), + 'pool_timeout': kwargs.get('pool_timeout'), + 'sqlite_synchronous': kwargs.get('sqlite_synchronous', True), + 'connection_trace': kwargs.get('connection_trace', False), + 'max_retries': kwargs.get('max_retries', 10), + 'retry_interval': kwargs.get('retry_interval', 10), + 'thread_checkin': kwargs.get('thread_checkin', True) + } + maker_kwargs = { + 'autocommit': autocommit, + 'expire_on_commit': expire_on_commit + } + + self._engine = create_engine(sql_connection=sql_connection, + **engine_kwargs) + self._session_maker = get_maker(engine=self._engine, + **maker_kwargs) + if slave_connection: + self._slave_engine = create_engine(sql_connection=slave_connection, + **engine_kwargs) + self._slave_session_maker = get_maker(engine=self._slave_engine, + **maker_kwargs) + else: + self._slave_engine = None + self._slave_session_maker = None + + def get_engine(self, use_slave=False): + """Get the engine instance (note, that it's shared). + + :param use_slave: if possible, use 'slave' database for this engine. + If the connection string for the slave database + wasn't provided, 'master' engine will be returned. + (defaults to False) + :type use_slave: bool + + """ + + if use_slave and self._slave_engine: + return self._slave_engine + + return self._engine + + def get_session(self, use_slave=False, **kwargs): + """Get a Session instance. + + :param use_slave: if possible, use 'slave' database connection for + this session. If the connection string for the + slave database wasn't provided, a session bound + to the 'master' engine will be returned. + (defaults to False) + :type use_slave: bool + + Keyword arugments will be passed to a sessionmaker instance as is (if + passed, they will override the ones used when the sessionmaker instance + was created). See SQLAlchemy Session docs for details. + + """ + + if use_slave and self._slave_session_maker: + return self._slave_session_maker(**kwargs) + + return self._session_maker(**kwargs) + + @classmethod + def from_config(cls, conf, + sqlite_fk=False, autocommit=True, expire_on_commit=False): + """Initialize EngineFacade using oslo.config config instance options. + + :param conf: oslo.config config instance + :type conf: oslo.config.cfg.ConfigOpts + + :param sqlite_fk: enable foreign keys in SQLite + :type sqlite_fk: bool + + :param autocommit: use autocommit mode for created Session instances + :type autocommit: bool + + :param expire_on_commit: expire session objects on commit + :type expire_on_commit: bool + + """ + + conf.register_opts(options.database_opts, 'database') + + return cls(sql_connection=conf.database.connection, + slave_connection=conf.database.slave_connection, + sqlite_fk=sqlite_fk, + autocommit=autocommit, + expire_on_commit=expire_on_commit, + mysql_sql_mode=conf.database.mysql_sql_mode, + idle_timeout=conf.database.idle_timeout, + connection_debug=conf.database.connection_debug, + max_pool_size=conf.database.max_pool_size, + max_overflow=conf.database.max_overflow, + pool_timeout=conf.database.pool_timeout, + sqlite_synchronous=conf.database.sqlite_synchronous, + connection_trace=conf.database.connection_trace, + max_retries=conf.database.max_retries, + retry_interval=conf.database.retry_interval) diff --git a/oslo_db/sqlalchemy/test_base.py b/oslo_db/sqlalchemy/test_base.py new file mode 100644 index 00000000..aaff621c --- /dev/null +++ b/oslo_db/sqlalchemy/test_base.py @@ -0,0 +1,127 @@ +# Copyright (c) 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import fixtures + +try: + from oslotest import base as test_base +except ImportError: + raise NameError('Oslotest is not installed. Please add oslotest in your' + ' test-requirements') + + +import six + +from oslo_db import exception +from oslo_db.sqlalchemy import provision +from oslo_db.sqlalchemy import session +from oslo_db.sqlalchemy import utils + + +class DbFixture(fixtures.Fixture): + """Basic database fixture. + + Allows to run tests on various db backends, such as SQLite, MySQL and + PostgreSQL. By default use sqlite backend. To override default backend + uri set env variable OS_TEST_DBAPI_CONNECTION with database admin + credentials for specific backend. + """ + + DRIVER = "sqlite" + + # these names are deprecated, and are not used by DbFixture. + # they are here for backwards compatibility with test suites that + # are referring to them directly. + DBNAME = PASSWORD = USERNAME = 'openstack_citest' + + def __init__(self, test): + super(DbFixture, self).__init__() + + self.test = test + + def setUp(self): + super(DbFixture, self).setUp() + + try: + self.provision = provision.ProvisionedDatabase(self.DRIVER) + self.addCleanup(self.provision.dispose) + except exception.BackendNotAvailable: + msg = '%s backend is not available.' % self.DRIVER + return self.test.skip(msg) + else: + self.test.engine = self.provision.engine + self.addCleanup(setattr, self.test, 'engine', None) + self.test.sessionmaker = session.get_maker(self.test.engine) + self.addCleanup(setattr, self.test, 'sessionmaker', None) + + +class DbTestCase(test_base.BaseTestCase): + """Base class for testing of DB code. + + Using `DbFixture`. Intended to be the main database test case to use all + the tests on a given backend with user defined uri. Backend specific + tests should be decorated with `backend_specific` decorator. + """ + + FIXTURE = DbFixture + + def setUp(self): + super(DbTestCase, self).setUp() + self.useFixture(self.FIXTURE(self)) + + +class OpportunisticTestCase(DbTestCase): + """Placeholder for backwards compatibility.""" + +ALLOWED_DIALECTS = ['sqlite', 'mysql', 'postgresql'] + + +def backend_specific(*dialects): + """Decorator to skip backend specific tests on inappropriate engines. + + ::dialects: list of dialects names under which the test will be launched. + """ + def wrap(f): + @six.wraps(f) + def ins_wrap(self): + if not set(dialects).issubset(ALLOWED_DIALECTS): + raise ValueError( + "Please use allowed dialects: %s" % ALLOWED_DIALECTS) + if self.engine.name not in dialects: + msg = ('The test "%s" can be run ' + 'only on %s. Current engine is %s.') + args = (utils.get_callable_name(f), ' '.join(dialects), + self.engine.name) + self.skip(msg % args) + else: + return f(self) + return ins_wrap + return wrap + + +class MySQLOpportunisticFixture(DbFixture): + DRIVER = 'mysql' + + +class PostgreSQLOpportunisticFixture(DbFixture): + DRIVER = 'postgresql' + + +class MySQLOpportunisticTestCase(OpportunisticTestCase): + FIXTURE = MySQLOpportunisticFixture + + +class PostgreSQLOpportunisticTestCase(OpportunisticTestCase): + FIXTURE = PostgreSQLOpportunisticFixture diff --git a/oslo_db/sqlalchemy/test_migrations.py b/oslo_db/sqlalchemy/test_migrations.py new file mode 100644 index 00000000..6c21275d --- /dev/null +++ b/oslo_db/sqlalchemy/test_migrations.py @@ -0,0 +1,613 @@ +# Copyright 2010-2011 OpenStack Foundation +# Copyright 2012-2013 IBM Corp. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import abc +import collections +import logging +import pprint + +import alembic +import alembic.autogenerate +import alembic.migration +import pkg_resources as pkg +import six +import sqlalchemy +from sqlalchemy.engine import reflection +import sqlalchemy.exc +from sqlalchemy import schema +import sqlalchemy.sql.expression as expr +import sqlalchemy.types as types + +from oslo_db._i18n import _LE +from oslo_db import exception as exc +from oslo_db.sqlalchemy import utils + +LOG = logging.getLogger(__name__) + + +@six.add_metaclass(abc.ABCMeta) +class WalkVersionsMixin(object): + """Test mixin to check upgrade and downgrade ability of migration. + + This is only suitable for testing of migrate_ migration scripts. An + abstract class mixin. `INIT_VERSION`, `REPOSITORY` and `migration_api` + attributes must be implemented in subclasses. + + .. _auxiliary-dynamic-methods: Auxiliary Methods + + Auxiliary Methods: + + `migrate_up` and `migrate_down` instance methods of the class can be + used with auxiliary methods named `_pre_upgrade_`, + `_check_`, `_post_downgrade_`. The methods + intended to check applied changes for correctness of data operations. + This methods should be implemented for every particular revision + which you want to check with data. Implementation recommendations for + `_pre_upgrade_`, `_check_`, + `_post_downgrade_` implementation: + + * `_pre_upgrade_`: provide a data appropriate to + a next revision. Should be used an id of revision which + going to be applied. + + * `_check_`: Insert, select, delete operations + with newly applied changes. The data provided by + `_pre_upgrade_` will be used. + + * `_post_downgrade_`: check for absence + (inability to use) changes provided by reverted revision. + + Execution order of auxiliary methods when revision is upgrading: + + `_pre_upgrade_###` => `upgrade` => `_check_###` + + Execution order of auxiliary methods when revision is downgrading: + + `downgrade` => `_post_downgrade_###` + + .. _migrate: https://sqlalchemy-migrate.readthedocs.org/en/latest/ + + """ + + @abc.abstractproperty + def INIT_VERSION(self): + """Initial version of a migration repository. + + Can be different from 0, if a migrations were squashed. + + :rtype: int + """ + pass + + @abc.abstractproperty + def REPOSITORY(self): + """Allows basic manipulation with migration repository. + + :returns: `migrate.versioning.repository.Repository` subclass. + """ + pass + + @abc.abstractproperty + def migration_api(self): + """Provides API for upgrading, downgrading and version manipulations. + + :returns: `migrate.api` or overloaded analog. + """ + pass + + @abc.abstractproperty + def migrate_engine(self): + """Provides engine instance. + + Should be the same instance as used when migrations are applied. In + most cases, the `engine` attribute provided by the test class in a + `setUp` method will work. + + Example of implementation: + + def migrate_engine(self): + return self.engine + + :returns: sqlalchemy engine instance + """ + pass + + def _walk_versions(self, snake_walk=False, downgrade=True): + """Check if migration upgrades and downgrades successfully. + + DEPRECATED: this function is deprecated and will be removed from + oslo.db in a few releases. Please use walk_versions() method instead. + """ + self.walk_versions(snake_walk, downgrade) + + def _migrate_down(self, version, with_data=False): + """Migrate down to a previous version of the db. + + DEPRECATED: this function is deprecated and will be removed from + oslo.db in a few releases. Please use migrate_down() method instead. + """ + return self.migrate_down(version, with_data) + + def _migrate_up(self, version, with_data=False): + """Migrate up to a new version of the db. + + DEPRECATED: this function is deprecated and will be removed from + oslo.db in a few releases. Please use migrate_up() method instead. + """ + self.migrate_up(version, with_data) + + def walk_versions(self, snake_walk=False, downgrade=True): + """Check if migration upgrades and downgrades successfully. + + Determine the latest version script from the repo, then + upgrade from 1 through to the latest, with no data + in the databases. This just checks that the schema itself + upgrades successfully. + + `walk_versions` calls `migrate_up` and `migrate_down` with + `with_data` argument to check changes with data, but these methods + can be called without any extra check outside of `walk_versions` + method. + + :param snake_walk: enables checking that each individual migration can + be upgraded/downgraded by itself. + + If we have ordered migrations 123abc, 456def, 789ghi and we run + upgrading with the `snake_walk` argument set to `True`, the + migrations will be applied in the following order: + + `123abc => 456def => 123abc => + 456def => 789ghi => 456def => 789ghi` + + :type snake_walk: bool + :param downgrade: Check downgrade behavior if True. + :type downgrade: bool + """ + + # Place the database under version control + self.migration_api.version_control(self.migrate_engine, + self.REPOSITORY, + self.INIT_VERSION) + self.assertEqual(self.INIT_VERSION, + self.migration_api.db_version(self.migrate_engine, + self.REPOSITORY)) + + LOG.debug('latest version is %s', self.REPOSITORY.latest) + versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1) + + for version in versions: + # upgrade -> downgrade -> upgrade + self.migrate_up(version, with_data=True) + if snake_walk: + downgraded = self.migrate_down(version - 1, with_data=True) + if downgraded: + self.migrate_up(version) + + if downgrade: + # Now walk it back down to 0 from the latest, testing + # the downgrade paths. + for version in reversed(versions): + # downgrade -> upgrade -> downgrade + downgraded = self.migrate_down(version - 1) + + if snake_walk and downgraded: + self.migrate_up(version) + self.migrate_down(version - 1) + + def migrate_down(self, version, with_data=False): + """Migrate down to a previous version of the db. + + :param version: id of revision to downgrade. + :type version: str + :keyword with_data: Whether to verify the absence of changes from + migration(s) being downgraded, see + :ref:`auxiliary-dynamic-methods `. + :type with_data: Bool + """ + + try: + self.migration_api.downgrade(self.migrate_engine, + self.REPOSITORY, version) + except NotImplementedError: + # NOTE(sirp): some migrations, namely release-level + # migrations, don't support a downgrade. + return False + + self.assertEqual(version, self.migration_api.db_version( + self.migrate_engine, self.REPOSITORY)) + + # NOTE(sirp): `version` is what we're downgrading to (i.e. the 'target' + # version). So if we have any downgrade checks, they need to be run for + # the previous (higher numbered) migration. + if with_data: + post_downgrade = getattr( + self, "_post_downgrade_%03d" % (version + 1), None) + if post_downgrade: + post_downgrade(self.migrate_engine) + + return True + + def migrate_up(self, version, with_data=False): + """Migrate up to a new version of the db. + + :param version: id of revision to upgrade. + :type version: str + :keyword with_data: Whether to verify the applied changes with data, + see :ref:`auxiliary-dynamic-methods `. + :type with_data: Bool + """ + # NOTE(sdague): try block is here because it's impossible to debug + # where a failed data migration happens otherwise + try: + if with_data: + data = None + pre_upgrade = getattr( + self, "_pre_upgrade_%03d" % version, None) + if pre_upgrade: + data = pre_upgrade(self.migrate_engine) + + self.migration_api.upgrade(self.migrate_engine, + self.REPOSITORY, version) + self.assertEqual(version, + self.migration_api.db_version(self.migrate_engine, + self.REPOSITORY)) + if with_data: + check = getattr(self, "_check_%03d" % version, None) + if check: + check(self.migrate_engine, data) + except exc.DbMigrationError: + msg = _LE("Failed to migrate to version %(ver)s on engine %(eng)s") + LOG.error(msg, {"ver": version, "eng": self.migrate_engine}) + raise + + +@six.add_metaclass(abc.ABCMeta) +class ModelsMigrationsSync(object): + """A helper class for comparison of DB migration scripts and models. + + It's intended to be inherited by test cases in target projects. They have + to provide implementations for methods used internally in the test (as + we have no way to implement them here). + + test_model_sync() will run migration scripts for the engine provided and + then compare the given metadata to the one reflected from the database. + The difference between MODELS and MIGRATION scripts will be printed and + the test will fail, if the difference is not empty. The return value is + really a list of actions, that should be performed in order to make the + current database schema state (i.e. migration scripts) consistent with + models definitions. It's left up to developers to analyze the output and + decide whether the models definitions or the migration scripts should be + modified to make them consistent. + + Output:: + + [( + 'add_table', + description of the table from models + ), + ( + 'remove_table', + description of the table from database + ), + ( + 'add_column', + schema, + table name, + column description from models + ), + ( + 'remove_column', + schema, + table name, + column description from database + ), + ( + 'add_index', + description of the index from models + ), + ( + 'remove_index', + description of the index from database + ), + ( + 'add_constraint', + description of constraint from models + ), + ( + 'remove_constraint, + description of constraint from database + ), + ( + 'modify_nullable', + schema, + table name, + column name, + { + 'existing_type': type of the column from database, + 'existing_server_default': default value from database + }, + nullable from database, + nullable from models + ), + ( + 'modify_type', + schema, + table name, + column name, + { + 'existing_nullable': database nullable, + 'existing_server_default': default value from database + }, + database column type, + type of the column from models + ), + ( + 'modify_default', + schema, + table name, + column name, + { + 'existing_nullable': database nullable, + 'existing_type': type of the column from database + }, + connection column default value, + default from models + )] + + Method include_object() can be overridden to exclude some tables from + comparison (e.g. migrate_repo). + + """ + + @abc.abstractmethod + def db_sync(self, engine): + """Run migration scripts with the given engine instance. + + This method must be implemented in subclasses and run migration scripts + for a DB the given engine is connected to. + + """ + + @abc.abstractmethod + def get_engine(self): + """Return the engine instance to be used when running tests. + + This method must be implemented in subclasses and return an engine + instance to be used when running tests. + + """ + + @abc.abstractmethod + def get_metadata(self): + """Return the metadata instance to be used for schema comparison. + + This method must be implemented in subclasses and return the metadata + instance attached to the BASE model. + + """ + + def include_object(self, object_, name, type_, reflected, compare_to): + """Return True for objects that should be compared. + + :param object_: a SchemaItem object such as a Table or Column object + :param name: the name of the object + :param type_: a string describing the type of object (e.g. "table") + :param reflected: True if the given object was produced based on + table reflection, False if it's from a local + MetaData object + :param compare_to: the object being compared against, if available, + else None + + """ + + return True + + def compare_type(self, ctxt, insp_col, meta_col, insp_type, meta_type): + """Return True if types are different, False if not. + + Return None to allow the default implementation to compare these types. + + :param ctxt: alembic MigrationContext instance + :param insp_col: reflected column + :param meta_col: column from model + :param insp_type: reflected column type + :param meta_type: column type from model + + """ + + # some backends (e.g. mysql) don't provide native boolean type + BOOLEAN_METADATA = (types.BOOLEAN, types.Boolean) + BOOLEAN_SQL = BOOLEAN_METADATA + (types.INTEGER, types.Integer) + + if issubclass(type(meta_type), BOOLEAN_METADATA): + return not issubclass(type(insp_type), BOOLEAN_SQL) + + return None # tells alembic to use the default comparison method + + def compare_server_default(self, ctxt, ins_col, meta_col, + insp_def, meta_def, rendered_meta_def): + """Compare default values between model and db table. + + Return True if the defaults are different, False if not, or None to + allow the default implementation to compare these defaults. + + :param ctxt: alembic MigrationContext instance + :param insp_col: reflected column + :param meta_col: column from model + :param insp_def: reflected column default value + :param meta_def: column default value from model + :param rendered_meta_def: rendered column default value (from model) + + """ + return self._compare_server_default(ctxt.bind, meta_col, insp_def, + meta_def) + + @utils.DialectFunctionDispatcher.dispatch_for_dialect("*") + def _compare_server_default(bind, meta_col, insp_def, meta_def): + pass + + @_compare_server_default.dispatch_for('mysql') + def _compare_server_default(bind, meta_col, insp_def, meta_def): + if isinstance(meta_col.type, sqlalchemy.Boolean): + if meta_def is None or insp_def is None: + return meta_def != insp_def + return not ( + isinstance(meta_def.arg, expr.True_) and insp_def == "'1'" or + isinstance(meta_def.arg, expr.False_) and insp_def == "'0'" + ) + + if isinstance(meta_col.type, sqlalchemy.Integer): + if meta_def is None or insp_def is None: + return meta_def != insp_def + return meta_def.arg != insp_def.split("'")[1] + + @_compare_server_default.dispatch_for('postgresql') + def _compare_server_default(bind, meta_col, insp_def, meta_def): + if isinstance(meta_col.type, sqlalchemy.Enum): + if meta_def is None or insp_def is None: + return meta_def != insp_def + return insp_def != "'%s'::%s" % (meta_def.arg, meta_col.type.name) + elif isinstance(meta_col.type, sqlalchemy.String): + if meta_def is None or insp_def is None: + return meta_def != insp_def + return insp_def != "'%s'::character varying" % meta_def.arg + + def _cleanup(self): + engine = self.get_engine() + with engine.begin() as conn: + inspector = reflection.Inspector.from_engine(engine) + metadata = schema.MetaData() + tbs = [] + all_fks = [] + + for table_name in inspector.get_table_names(): + fks = [] + for fk in inspector.get_foreign_keys(table_name): + if not fk['name']: + continue + fks.append( + schema.ForeignKeyConstraint((), (), name=fk['name']) + ) + table = schema.Table(table_name, metadata, *fks) + tbs.append(table) + all_fks.extend(fks) + + for fkc in all_fks: + conn.execute(schema.DropConstraint(fkc)) + + for table in tbs: + conn.execute(schema.DropTable(table)) + + FKInfo = collections.namedtuple('fk_info', ['constrained_columns', + 'referred_table', + 'referred_columns']) + + def check_foreign_keys(self, metadata, bind): + """Compare foreign keys between model and db table. + + :returns: a list that contains information about: + + * should be a new key added or removed existing, + * name of that key, + * source table, + * referred table, + * constrained columns, + * referred columns + + Output:: + + [('drop_key', + 'testtbl_fk_check_fkey', + 'testtbl', + fk_info(constrained_columns=(u'fk_check',), + referred_table=u'table', + referred_columns=(u'fk_check',)))] + + """ + + diff = [] + insp = sqlalchemy.engine.reflection.Inspector.from_engine(bind) + # Get all tables from db + db_tables = insp.get_table_names() + # Get all tables from models + model_tables = metadata.tables + for table in db_tables: + if table not in model_tables: + continue + # Get all necessary information about key of current table from db + fk_db = dict((self._get_fk_info_from_db(i), i['name']) + for i in insp.get_foreign_keys(table)) + fk_db_set = set(fk_db.keys()) + # Get all necessary information about key of current table from + # models + fk_models = dict((self._get_fk_info_from_model(fk), fk) + for fk in model_tables[table].foreign_keys) + fk_models_set = set(fk_models.keys()) + for key in (fk_db_set - fk_models_set): + diff.append(('drop_key', fk_db[key], table, key)) + LOG.info(("Detected removed foreign key %(fk)r on " + "table %(table)r"), {'fk': fk_db[key], + 'table': table}) + for key in (fk_models_set - fk_db_set): + diff.append(('add_key', fk_models[key], table, key)) + LOG.info(( + "Detected added foreign key for column %(fk)r on table " + "%(table)r"), {'fk': fk_models[key].column.name, + 'table': table}) + return diff + + def _get_fk_info_from_db(self, fk): + return self.FKInfo(tuple(fk['constrained_columns']), + fk['referred_table'], + tuple(fk['referred_columns'])) + + def _get_fk_info_from_model(self, fk): + return self.FKInfo((fk.parent.name,), fk.column.table.name, + (fk.column.name,)) + + def test_models_sync(self): + # recent versions of sqlalchemy and alembic are needed for running of + # this test, but we already have them in requirements + try: + pkg.require('sqlalchemy>=0.8.4', 'alembic>=0.6.2') + except (pkg.VersionConflict, pkg.DistributionNotFound) as e: + self.skipTest('sqlalchemy>=0.8.4 and alembic>=0.6.3 are required' + ' for running of this test: %s' % e) + + # drop all tables after a test run + self.addCleanup(self._cleanup) + + # run migration scripts + self.db_sync(self.get_engine()) + + with self.get_engine().connect() as conn: + opts = { + 'include_object': self.include_object, + 'compare_type': self.compare_type, + 'compare_server_default': self.compare_server_default, + } + mc = alembic.migration.MigrationContext.configure(conn, opts=opts) + + # compare schemas and fail with diff, if it's not empty + diff1 = alembic.autogenerate.compare_metadata(mc, + self.get_metadata()) + diff2 = self.check_foreign_keys(self.get_metadata(), + self.get_engine()) + diff = diff1 + diff2 + if diff: + msg = pprint.pformat(diff, indent=2, width=20) + self.fail( + "Models and migration scripts aren't in sync:\n%s" % msg) diff --git a/oslo_db/sqlalchemy/utils.py b/oslo_db/sqlalchemy/utils.py new file mode 100644 index 00000000..6a66bb90 --- /dev/null +++ b/oslo_db/sqlalchemy/utils.py @@ -0,0 +1,1012 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2010-2011 OpenStack Foundation. +# Copyright 2012 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import collections +import logging +import re + +from oslo.utils import timeutils +import six +import sqlalchemy +from sqlalchemy import Boolean +from sqlalchemy import CheckConstraint +from sqlalchemy import Column +from sqlalchemy.engine import Connectable +from sqlalchemy.engine import reflection +from sqlalchemy.engine import url as sa_url +from sqlalchemy.ext.compiler import compiles +from sqlalchemy import func +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy.sql.expression import literal_column +from sqlalchemy.sql.expression import UpdateBase +from sqlalchemy.sql import text +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy.types import NullType + +from oslo_db import exception +from oslo_db._i18n import _, _LI, _LW +from oslo_db.sqlalchemy import models + +# NOTE(ochuprykov): Add references for backwards compatibility +InvalidSortKey = exception.InvalidSortKey +ColumnError = exception.ColumnError + +LOG = logging.getLogger(__name__) + +_DBURL_REGEX = re.compile(r"[^:]+://([^:]+):([^@]+)@.+") + + +def get_callable_name(function): + # TODO(harlowja): Replace this once + # it is possible to use https://review.openstack.org/#/c/122495/ which is + # a more complete and expansive module that does a similar thing... + try: + method_self = six.get_method_self(function) + except AttributeError: + method_self = None + if method_self is not None: + if isinstance(method_self, six.class_types): + im_class = method_self + else: + im_class = type(method_self) + try: + parts = (im_class.__module__, function.__qualname__) + except AttributeError: + parts = (im_class.__module__, im_class.__name__, function.__name__) + else: + try: + parts = (function.__module__, function.__qualname__) + except AttributeError: + parts = (function.__module__, function.__name__) + return '.'.join(parts) + + +def sanitize_db_url(url): + match = _DBURL_REGEX.match(url) + if match: + return '%s****:****%s' % (url[:match.start(1)], url[match.end(2):]) + return url + + +# copy from glance/db/sqlalchemy/api.py +def paginate_query(query, model, limit, sort_keys, marker=None, + sort_dir=None, sort_dirs=None): + """Returns a query with sorting / pagination criteria added. + + Pagination works by requiring a unique sort_key, specified by sort_keys. + (If sort_keys is not unique, then we risk looping through values.) + We use the last row in the previous page as the 'marker' for pagination. + So we must return values that follow the passed marker in the order. + With a single-valued sort_key, this would be easy: sort_key > X. + With a compound-values sort_key, (k1, k2, k3) we must do this to repeat + the lexicographical ordering: + (k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3) + + We also have to cope with different sort_directions. + + Typically, the id of the last row is used as the client-facing pagination + marker, then the actual marker object must be fetched from the db and + passed in to us as marker. + + :param query: the query object to which we should add paging/sorting + :param model: the ORM model class + :param limit: maximum number of items to return + :param sort_keys: array of attributes by which results should be sorted + :param marker: the last item of the previous page; we returns the next + results after this value. + :param sort_dir: direction in which results should be sorted (asc, desc) + :param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys + + :rtype: sqlalchemy.orm.query.Query + :return: The query with sorting/pagination added. + """ + + if 'id' not in sort_keys: + # TODO(justinsb): If this ever gives a false-positive, check + # the actual primary key, rather than assuming its id + LOG.warning(_LW('Id not in sort_keys; is sort_keys unique?')) + + assert(not (sort_dir and sort_dirs)) + + # Default the sort direction to ascending + if sort_dirs is None and sort_dir is None: + sort_dir = 'asc' + + # Ensure a per-column sort direction + if sort_dirs is None: + sort_dirs = [sort_dir for _sort_key in sort_keys] + + assert(len(sort_dirs) == len(sort_keys)) + + # Add sorting + for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs): + try: + sort_dir_func = { + 'asc': sqlalchemy.asc, + 'desc': sqlalchemy.desc, + }[current_sort_dir] + except KeyError: + raise ValueError(_("Unknown sort direction, " + "must be 'desc' or 'asc'")) + try: + sort_key_attr = getattr(model, current_sort_key) + except AttributeError: + raise exception.InvalidSortKey() + query = query.order_by(sort_dir_func(sort_key_attr)) + + # Add pagination + if marker is not None: + marker_values = [] + for sort_key in sort_keys: + v = getattr(marker, sort_key) + marker_values.append(v) + + # Build up an array of sort criteria as in the docstring + criteria_list = [] + for i in range(len(sort_keys)): + crit_attrs = [] + for j in range(i): + model_attr = getattr(model, sort_keys[j]) + crit_attrs.append((model_attr == marker_values[j])) + + model_attr = getattr(model, sort_keys[i]) + if sort_dirs[i] == 'desc': + crit_attrs.append((model_attr < marker_values[i])) + else: + crit_attrs.append((model_attr > marker_values[i])) + + criteria = sqlalchemy.sql.and_(*crit_attrs) + criteria_list.append(criteria) + + f = sqlalchemy.sql.or_(*criteria_list) + query = query.filter(f) + + if limit is not None: + query = query.limit(limit) + + return query + + +def _read_deleted_filter(query, db_model, deleted): + if 'deleted' not in db_model.__table__.columns: + raise ValueError(_("There is no `deleted` column in `%s` table. " + "Project doesn't use soft-deleted feature.") + % db_model.__name__) + + default_deleted_value = db_model.__table__.c.deleted.default.arg + if deleted: + query = query.filter(db_model.deleted != default_deleted_value) + else: + query = query.filter(db_model.deleted == default_deleted_value) + return query + + +def _project_filter(query, db_model, project_id): + if 'project_id' not in db_model.__table__.columns: + raise ValueError(_("There is no `project_id` column in `%s` table.") + % db_model.__name__) + + if isinstance(project_id, (list, tuple, set)): + query = query.filter(db_model.project_id.in_(project_id)) + else: + query = query.filter(db_model.project_id == project_id) + + return query + + +def model_query(model, session, args=None, **kwargs): + """Query helper for db.sqlalchemy api methods. + + This accounts for `deleted` and `project_id` fields. + + :param model: Model to query. Must be a subclass of ModelBase. + :type model: models.ModelBase + + :param session: The session to use. + :type session: sqlalchemy.orm.session.Session + + :param args: Arguments to query. If None - model is used. + :type args: tuple + + Keyword arguments: + + :keyword project_id: If present, allows filtering by project_id(s). + Can be either a project_id value, or an iterable of + project_id values, or None. If an iterable is passed, + only rows whose project_id column value is on the + `project_id` list will be returned. If None is passed, + only rows which are not bound to any project, will be + returned. + :type project_id: iterable, + model.__table__.columns.project_id.type, + None type + + :keyword deleted: If present, allows filtering by deleted field. + If True is passed, only deleted entries will be + returned, if False - only existing entries. + :type deleted: bool + + + Usage: + + .. code-block:: python + + from oslo_db.sqlalchemy import utils + + + def get_instance_by_uuid(uuid): + session = get_session() + with session.begin() + return (utils.model_query(models.Instance, session=session) + .filter(models.Instance.uuid == uuid) + .first()) + + def get_nodes_stat(): + data = (Node.id, Node.cpu, Node.ram, Node.hdd) + + session = get_session() + with session.begin() + return utils.model_query(Node, session=session, args=data).all() + + Also you can create your own helper, based on ``utils.model_query()``. + For example, it can be useful if you plan to use ``project_id`` and + ``deleted`` parameters from project's ``context`` + + .. code-block:: python + + from oslo_db.sqlalchemy import utils + + + def _model_query(context, model, session=None, args=None, + project_id=None, project_only=False, + read_deleted=None): + + # We suppose, that functions ``_get_project_id()`` and + # ``_get_deleted()`` should handle passed parameters and + # context object (for example, decide, if we need to restrict a user + # to query his own entries by project_id or only allow admin to read + # deleted entries). For return values, we expect to get + # ``project_id`` and ``deleted``, which are suitable for the + # ``model_query()`` signature. + kwargs = {} + if project_id is not None: + kwargs['project_id'] = _get_project_id(context, project_id, + project_only) + if read_deleted is not None: + kwargs['deleted'] = _get_deleted_dict(context, read_deleted) + session = session or get_session() + + with session.begin(): + return utils.model_query(model, session=session, + args=args, **kwargs) + + def get_instance_by_uuid(context, uuid): + return (_model_query(context, models.Instance, read_deleted='yes') + .filter(models.Instance.uuid == uuid) + .first()) + + def get_nodes_data(context, project_id, project_only='allow_none'): + data = (Node.id, Node.cpu, Node.ram, Node.hdd) + + return (_model_query(context, Node, args=data, project_id=project_id, + project_only=project_only) + .all()) + + """ + + if not issubclass(model, models.ModelBase): + raise TypeError(_("model should be a subclass of ModelBase")) + + query = session.query(model) if not args else session.query(*args) + if 'deleted' in kwargs: + query = _read_deleted_filter(query, model, kwargs['deleted']) + if 'project_id' in kwargs: + query = _project_filter(query, model, kwargs['project_id']) + + return query + + +def get_table(engine, name): + """Returns an sqlalchemy table dynamically from db. + + Needed because the models don't work for us in migrations + as models will be far out of sync with the current data. + + .. warning:: + + Do not use this method when creating ForeignKeys in database migrations + because sqlalchemy needs the same MetaData object to hold information + about the parent table and the reference table in the ForeignKey. This + method uses a unique MetaData object per table object so it won't work + with ForeignKey creation. + """ + metadata = MetaData() + metadata.bind = engine + return Table(name, metadata, autoload=True) + + +class InsertFromSelect(UpdateBase): + """Form the base for `INSERT INTO table (SELECT ... )` statement.""" + def __init__(self, table, select): + self.table = table + self.select = select + + +@compiles(InsertFromSelect) +def visit_insert_from_select(element, compiler, **kw): + """Form the `INSERT INTO table (SELECT ... )` statement.""" + return "INSERT INTO %s %s" % ( + compiler.process(element.table, asfrom=True), + compiler.process(element.select)) + + +def _get_not_supported_column(col_name_col_instance, column_name): + try: + column = col_name_col_instance[column_name] + except KeyError: + msg = _("Please specify column %s in col_name_col_instance " + "param. It is required because column has unsupported " + "type by SQLite.") + raise exception.ColumnError(msg % column_name) + + if not isinstance(column, Column): + msg = _("col_name_col_instance param has wrong type of " + "column instance for column %s It should be instance " + "of sqlalchemy.Column.") + raise exception.ColumnError(msg % column_name) + return column + + +def drop_old_duplicate_entries_from_table(migrate_engine, table_name, + use_soft_delete, *uc_column_names): + """Drop all old rows having the same values for columns in uc_columns. + + This method drop (or mark ad `deleted` if use_soft_delete is True) old + duplicate rows form table with name `table_name`. + + :param migrate_engine: Sqlalchemy engine + :param table_name: Table with duplicates + :param use_soft_delete: If True - values will be marked as `deleted`, + if False - values will be removed from table + :param uc_column_names: Unique constraint columns + """ + meta = MetaData() + meta.bind = migrate_engine + + table = Table(table_name, meta, autoload=True) + columns_for_group_by = [table.c[name] for name in uc_column_names] + + columns_for_select = [func.max(table.c.id)] + columns_for_select.extend(columns_for_group_by) + + duplicated_rows_select = sqlalchemy.sql.select( + columns_for_select, group_by=columns_for_group_by, + having=func.count(table.c.id) > 1) + + for row in migrate_engine.execute(duplicated_rows_select).fetchall(): + # NOTE(boris-42): Do not remove row that has the biggest ID. + delete_condition = table.c.id != row[0] + is_none = None # workaround for pyflakes + delete_condition &= table.c.deleted_at == is_none + for name in uc_column_names: + delete_condition &= table.c[name] == row[name] + + rows_to_delete_select = sqlalchemy.sql.select( + [table.c.id]).where(delete_condition) + for row in migrate_engine.execute(rows_to_delete_select).fetchall(): + LOG.info(_LI("Deleting duplicated row with id: %(id)s from table: " + "%(table)s"), dict(id=row[0], table=table_name)) + + if use_soft_delete: + delete_statement = table.update().\ + where(delete_condition).\ + values({ + 'deleted': literal_column('id'), + 'updated_at': literal_column('updated_at'), + 'deleted_at': timeutils.utcnow() + }) + else: + delete_statement = table.delete().where(delete_condition) + migrate_engine.execute(delete_statement) + + +def _get_default_deleted_value(table): + if isinstance(table.c.id.type, Integer): + return 0 + if isinstance(table.c.id.type, String): + return "" + raise exception.ColumnError(_("Unsupported id columns type")) + + +def _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes): + table = get_table(migrate_engine, table_name) + + insp = reflection.Inspector.from_engine(migrate_engine) + real_indexes = insp.get_indexes(table_name) + existing_index_names = dict( + [(index['name'], index['column_names']) for index in real_indexes]) + + # NOTE(boris-42): Restore indexes on `deleted` column + for index in indexes: + if 'deleted' not in index['column_names']: + continue + name = index['name'] + if name in existing_index_names: + column_names = [table.c[c] for c in existing_index_names[name]] + old_index = Index(name, *column_names, unique=index["unique"]) + old_index.drop(migrate_engine) + + column_names = [table.c[c] for c in index['column_names']] + new_index = Index(index["name"], *column_names, unique=index["unique"]) + new_index.create(migrate_engine) + + +def change_deleted_column_type_to_boolean(migrate_engine, table_name, + **col_name_col_instance): + if migrate_engine.name == "sqlite": + return _change_deleted_column_type_to_boolean_sqlite( + migrate_engine, table_name, **col_name_col_instance) + insp = reflection.Inspector.from_engine(migrate_engine) + indexes = insp.get_indexes(table_name) + + table = get_table(migrate_engine, table_name) + + old_deleted = Column('old_deleted', Boolean, default=False) + old_deleted.create(table, populate_default=False) + + table.update().\ + where(table.c.deleted == table.c.id).\ + values(old_deleted=True).\ + execute() + + table.c.deleted.drop() + table.c.old_deleted.alter(name="deleted") + + _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes) + + +def _change_deleted_column_type_to_boolean_sqlite(migrate_engine, table_name, + **col_name_col_instance): + insp = reflection.Inspector.from_engine(migrate_engine) + table = get_table(migrate_engine, table_name) + + columns = [] + for column in table.columns: + column_copy = None + if column.name != "deleted": + if isinstance(column.type, NullType): + column_copy = _get_not_supported_column(col_name_col_instance, + column.name) + else: + column_copy = column.copy() + else: + column_copy = Column('deleted', Boolean, default=0) + columns.append(column_copy) + + constraints = [constraint.copy() for constraint in table.constraints] + + meta = table.metadata + new_table = Table(table_name + "__tmp__", meta, + *(columns + constraints)) + new_table.create() + + indexes = [] + for index in insp.get_indexes(table_name): + column_names = [new_table.c[c] for c in index['column_names']] + indexes.append(Index(index["name"], *column_names, + unique=index["unique"])) + + c_select = [] + for c in table.c: + if c.name != "deleted": + c_select.append(c) + else: + c_select.append(table.c.deleted == table.c.id) + + ins = InsertFromSelect(new_table, sqlalchemy.sql.select(c_select)) + migrate_engine.execute(ins) + + table.drop() + for index in indexes: + index.create(migrate_engine) + + new_table.rename(table_name) + new_table.update().\ + where(new_table.c.deleted == new_table.c.id).\ + values(deleted=True).\ + execute() + + +def change_deleted_column_type_to_id_type(migrate_engine, table_name, + **col_name_col_instance): + if migrate_engine.name == "sqlite": + return _change_deleted_column_type_to_id_type_sqlite( + migrate_engine, table_name, **col_name_col_instance) + insp = reflection.Inspector.from_engine(migrate_engine) + indexes = insp.get_indexes(table_name) + + table = get_table(migrate_engine, table_name) + + new_deleted = Column('new_deleted', table.c.id.type, + default=_get_default_deleted_value(table)) + new_deleted.create(table, populate_default=True) + + deleted = True # workaround for pyflakes + table.update().\ + where(table.c.deleted == deleted).\ + values(new_deleted=table.c.id).\ + execute() + table.c.deleted.drop() + table.c.new_deleted.alter(name="deleted") + + _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes) + + +def _change_deleted_column_type_to_id_type_sqlite(migrate_engine, table_name, + **col_name_col_instance): + # NOTE(boris-42): sqlaclhemy-migrate can't drop column with check + # constraints in sqlite DB and our `deleted` column has + # 2 check constraints. So there is only one way to remove + # these constraints: + # 1) Create new table with the same columns, constraints + # and indexes. (except deleted column). + # 2) Copy all data from old to new table. + # 3) Drop old table. + # 4) Rename new table to old table name. + insp = reflection.Inspector.from_engine(migrate_engine) + meta = MetaData(bind=migrate_engine) + table = Table(table_name, meta, autoload=True) + default_deleted_value = _get_default_deleted_value(table) + + columns = [] + for column in table.columns: + column_copy = None + if column.name != "deleted": + if isinstance(column.type, NullType): + column_copy = _get_not_supported_column(col_name_col_instance, + column.name) + else: + column_copy = column.copy() + else: + column_copy = Column('deleted', table.c.id.type, + default=default_deleted_value) + columns.append(column_copy) + + def is_deleted_column_constraint(constraint): + # NOTE(boris-42): There is no other way to check is CheckConstraint + # associated with deleted column. + if not isinstance(constraint, CheckConstraint): + return False + sqltext = str(constraint.sqltext) + # NOTE(I159): in order to omit the CHECK constraint corresponding + # to `deleted` column we have to test these patterns which may + # vary depending on the SQLAlchemy version used. + constraint_markers = ( + "deleted in (0, 1)", + "deleted IN (:deleted_1, :deleted_2)", + "deleted IN (:param_1, :param_2)" + ) + return any(sqltext.endswith(marker) for marker in constraint_markers) + + constraints = [] + for constraint in table.constraints: + if not is_deleted_column_constraint(constraint): + constraints.append(constraint.copy()) + + new_table = Table(table_name + "__tmp__", meta, + *(columns + constraints)) + new_table.create() + + indexes = [] + for index in insp.get_indexes(table_name): + column_names = [new_table.c[c] for c in index['column_names']] + indexes.append(Index(index["name"], *column_names, + unique=index["unique"])) + + ins = InsertFromSelect(new_table, table.select()) + migrate_engine.execute(ins) + + table.drop() + for index in indexes: + index.create(migrate_engine) + + new_table.rename(table_name) + deleted = True # workaround for pyflakes + new_table.update().\ + where(new_table.c.deleted == deleted).\ + values(deleted=new_table.c.id).\ + execute() + + # NOTE(boris-42): Fix value of deleted column: False -> "" or 0. + deleted = False # workaround for pyflakes + new_table.update().\ + where(new_table.c.deleted == deleted).\ + values(deleted=default_deleted_value).\ + execute() + + +def get_connect_string(backend, database, user=None, passwd=None, + host='localhost'): + """Get database connection + + Try to get a connection with a very specific set of values, if we get + these then we'll run the tests, otherwise they are skipped + + DEPRECATED: this function is deprecated and will be removed from oslo.db + in a few releases. Please use the provisioning system for dealing + with URLs and database provisioning. + + """ + args = {'backend': backend, + 'user': user, + 'passwd': passwd, + 'host': host, + 'database': database} + if backend == 'sqlite': + template = '%(backend)s:///%(database)s' + else: + template = "%(backend)s://%(user)s:%(passwd)s@%(host)s/%(database)s" + return template % args + + +def is_backend_avail(backend, database, user=None, passwd=None): + """Return True if the given backend is available. + + + DEPRECATED: this function is deprecated and will be removed from oslo.db + in a few releases. Please use the provisioning system to access + databases based on backend availability. + + """ + from oslo_db.sqlalchemy import provision + + connect_uri = get_connect_string(backend=backend, + database=database, + user=user, + passwd=passwd) + try: + eng = provision.Backend._ensure_backend_available(connect_uri) + eng.dispose() + except exception.BackendNotAvailable: + return False + else: + return True + + +def get_db_connection_info(conn_pieces): + database = conn_pieces.path.strip('/') + loc_pieces = conn_pieces.netloc.split('@') + host = loc_pieces[1] + + auth_pieces = loc_pieces[0].split(':') + user = auth_pieces[0] + password = "" + if len(auth_pieces) > 1: + password = auth_pieces[1].strip() + + return (user, password, database, host) + + +def index_exists(migrate_engine, table_name, index_name): + """Check if given index exists. + + :param migrate_engine: sqlalchemy engine + :param table_name: name of the table + :param index_name: name of the index + """ + inspector = reflection.Inspector.from_engine(migrate_engine) + indexes = inspector.get_indexes(table_name) + index_names = [index['name'] for index in indexes] + return index_name in index_names + + +def add_index(migrate_engine, table_name, index_name, idx_columns): + """Create an index for given columns. + + :param migrate_engine: sqlalchemy engine + :param table_name: name of the table + :param index_name: name of the index + :param idx_columns: tuple with names of columns that will be indexed + """ + table = get_table(migrate_engine, table_name) + if not index_exists(migrate_engine, table_name, index_name): + index = Index( + index_name, *[getattr(table.c, col) for col in idx_columns] + ) + index.create() + else: + raise ValueError("Index '%s' already exists!" % index_name) + + +def drop_index(migrate_engine, table_name, index_name): + """Drop index with given name. + + :param migrate_engine: sqlalchemy engine + :param table_name: name of the table + :param index_name: name of the index + """ + table = get_table(migrate_engine, table_name) + for index in table.indexes: + if index.name == index_name: + index.drop() + break + else: + raise ValueError("Index '%s' not found!" % index_name) + + +def change_index_columns(migrate_engine, table_name, index_name, new_columns): + """Change set of columns that are indexed by given index. + + :param migrate_engine: sqlalchemy engine + :param table_name: name of the table + :param index_name: name of the index + :param new_columns: tuple with names of columns that will be indexed + """ + drop_index(migrate_engine, table_name, index_name) + add_index(migrate_engine, table_name, index_name, new_columns) + + +def column_exists(engine, table_name, column): + """Check if table has given column. + + :param engine: sqlalchemy engine + :param table_name: name of the table + :param column: name of the colmn + """ + t = get_table(engine, table_name) + return column in t.c + + +class DialectFunctionDispatcher(object): + @classmethod + def dispatch_for_dialect(cls, expr, multiple=False): + """Provide dialect-specific functionality within distinct functions. + + e.g.:: + + @dispatch_for_dialect("*") + def set_special_option(engine): + pass + + @set_special_option.dispatch_for("sqlite") + def set_sqlite_special_option(engine): + return engine.execute("sqlite thing") + + @set_special_option.dispatch_for("mysql+mysqldb") + def set_mysqldb_special_option(engine): + return engine.execute("mysqldb thing") + + After the above registration, the ``set_special_option()`` function + is now a dispatcher, given a SQLAlchemy ``Engine``, ``Connection``, + URL string, or ``sqlalchemy.engine.URL`` object:: + + eng = create_engine('...') + result = set_special_option(eng) + + The filter system supports two modes, "multiple" and "single". + The default is "single", and requires that one and only one function + match for a given backend. In this mode, the function may also + have a return value, which will be returned by the top level + call. + + "multiple" mode, on the other hand, does not support return + arguments, but allows for any number of matching functions, where + each function will be called:: + + # the initial call sets this up as a "multiple" dispatcher + @dispatch_for_dialect("*", multiple=True) + def set_options(engine): + # set options that apply to *all* engines + + @set_options.dispatch_for("postgresql") + def set_postgresql_options(engine): + # set options that apply to all Postgresql engines + + @set_options.dispatch_for("postgresql+psycopg2") + def set_postgresql_psycopg2_options(engine): + # set options that apply only to "postgresql+psycopg2" + + @set_options.dispatch_for("*+pyodbc") + def set_pyodbc_options(engine): + # set options that apply to all pyodbc backends + + Note that in both modes, any number of additional arguments can be + accepted by member functions. For example, to populate a dictionary of + options, it may be passed in:: + + @dispatch_for_dialect("*", multiple=True) + def set_engine_options(url, opts): + pass + + @set_engine_options.dispatch_for("mysql+mysqldb") + def _mysql_set_default_charset_to_utf8(url, opts): + opts.setdefault('charset', 'utf-8') + + @set_engine_options.dispatch_for("sqlite") + def _set_sqlite_in_memory_check_same_thread(url, opts): + if url.database in (None, 'memory'): + opts['check_same_thread'] = False + + opts = {} + set_engine_options(url, opts) + + The driver specifiers are of the form: + ``[+]``. That is, database name or "*", + followed by an optional ``+`` sign with driver or "*". Omitting + the driver name implies all drivers for that database. + + """ + if multiple: + cls = DialectMultiFunctionDispatcher + else: + cls = DialectSingleFunctionDispatcher + return cls().dispatch_for(expr) + + _db_plus_driver_reg = re.compile(r'([^+]+?)(?:\+(.+))?$') + + def dispatch_for(self, expr): + def decorate(fn): + dbname, driver = self._parse_dispatch(expr) + if fn is self: + fn = fn._last + self._last = fn + self._register(expr, dbname, driver, fn) + return self + return decorate + + def _parse_dispatch(self, text): + m = self._db_plus_driver_reg.match(text) + if not m: + raise ValueError("Couldn't parse database[+driver]: %r" % text) + return m.group(1) or '*', m.group(2) or '*' + + def __call__(self, *arg, **kw): + target = arg[0] + return self._dispatch_on( + self._url_from_target(target), target, arg, kw) + + def _url_from_target(self, target): + if isinstance(target, Connectable): + return target.engine.url + elif isinstance(target, six.string_types): + if "://" not in target: + target_url = sa_url.make_url("%s://" % target) + else: + target_url = sa_url.make_url(target) + return target_url + elif isinstance(target, sa_url.URL): + return target + else: + raise ValueError("Invalid target type: %r" % target) + + def dispatch_on_drivername(self, drivername): + """Return a sub-dispatcher for the given drivername. + + This provides a means of calling a different function, such as the + "*" function, for a given target object that normally refers + to a sub-function. + + """ + dbname, driver = self._db_plus_driver_reg.match(drivername).group(1, 2) + + def go(*arg, **kw): + return self._dispatch_on_db_driver(dbname, "*", arg, kw) + + return go + + def _dispatch_on(self, url, target, arg, kw): + dbname, driver = self._db_plus_driver_reg.match( + url.drivername).group(1, 2) + if not driver: + driver = url.get_dialect().driver + + return self._dispatch_on_db_driver(dbname, driver, arg, kw) + + def _invoke_fn(self, fn, arg, kw): + return fn(*arg, **kw) + + +class DialectSingleFunctionDispatcher(DialectFunctionDispatcher): + def __init__(self): + self.reg = collections.defaultdict(dict) + + def _register(self, expr, dbname, driver, fn): + fn_dict = self.reg[dbname] + if driver in fn_dict: + raise TypeError("Multiple functions for expression %r" % expr) + fn_dict[driver] = fn + + def _matches(self, dbname, driver): + for db in (dbname, '*'): + subdict = self.reg[db] + for drv in (driver, '*'): + if drv in subdict: + return subdict[drv] + else: + raise ValueError( + "No default function found for driver: %r" % + ("%s+%s" % (dbname, driver))) + + def _dispatch_on_db_driver(self, dbname, driver, arg, kw): + fn = self._matches(dbname, driver) + return self._invoke_fn(fn, arg, kw) + + +class DialectMultiFunctionDispatcher(DialectFunctionDispatcher): + def __init__(self): + self.reg = collections.defaultdict( + lambda: collections.defaultdict(list)) + + def _register(self, expr, dbname, driver, fn): + self.reg[dbname][driver].append(fn) + + def _matches(self, dbname, driver): + if driver != '*': + drivers = (driver, '*') + else: + drivers = ('*', ) + + for db in (dbname, '*'): + subdict = self.reg[db] + for drv in drivers: + for fn in subdict[drv]: + yield fn + + def _dispatch_on_db_driver(self, dbname, driver, arg, kw): + for fn in self._matches(dbname, driver): + if self._invoke_fn(fn, arg, kw) is not None: + raise TypeError( + "Return value not allowed for " + "multiple filtered function") + +dispatch_for_dialect = DialectFunctionDispatcher.dispatch_for_dialect + + +def get_non_innodb_tables(connectable, skip_tables=('migrate_version', + 'alembic_version')): + """Get a list of tables which don't use InnoDB storage engine. + + :param connectable: a SQLAlchemy Engine or a Connection instance + :param skip_tables: a list of tables which might have a different + storage engine + """ + + query_str = """ + SELECT table_name + FROM information_schema.tables + WHERE table_schema = :database AND + engine != 'InnoDB' + """ + + params = {} + if skip_tables: + params = dict( + ('skip_%s' % i, table_name) + for i, table_name in enumerate(skip_tables) + ) + + placeholders = ', '.join(':' + p for p in params) + query_str += ' AND table_name NOT IN (%s)' % placeholders + + params['database'] = connectable.engine.url.database + query = text(query_str) + noninnodb = connectable.execute(query, **params) + return [i[0] for i in noninnodb] diff --git a/oslo_db/tests/__init__.py b/oslo_db/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/base.py b/oslo_db/tests/base.py similarity index 100% rename from tests/base.py rename to oslo_db/tests/base.py diff --git a/oslo_db/tests/old_import_api/__init__.py b/oslo_db/tests/old_import_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/oslo_db/tests/old_import_api/base.py b/oslo_db/tests/old_import_api/base.py new file mode 100644 index 00000000..69e6a802 --- /dev/null +++ b/oslo_db/tests/old_import_api/base.py @@ -0,0 +1,53 @@ +# Copyright 2010-2011 OpenStack Foundation +# Copyright (c) 2013 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os + +import fixtures +import testtools + +_TRUE_VALUES = ('true', '1', 'yes') + +# FIXME(dhellmann) Update this to use oslo.test library + + +class TestCase(testtools.TestCase): + + """Test case base class for all unit tests.""" + + def setUp(self): + """Run before each test method to initialize test environment.""" + + super(TestCase, self).setUp() + test_timeout = os.environ.get('OS_TEST_TIMEOUT', 0) + try: + test_timeout = int(test_timeout) + except ValueError: + # If timeout value is invalid do not set a timeout. + test_timeout = 0 + if test_timeout > 0: + self.useFixture(fixtures.Timeout(test_timeout, gentle=True)) + + self.useFixture(fixtures.NestedTempfile()) + self.useFixture(fixtures.TempHomeDir()) + + if os.environ.get('OS_STDOUT_CAPTURE') in _TRUE_VALUES: + stdout = self.useFixture(fixtures.StringStream('stdout')).stream + self.useFixture(fixtures.MonkeyPatch('sys.stdout', stdout)) + if os.environ.get('OS_STDERR_CAPTURE') in _TRUE_VALUES: + stderr = self.useFixture(fixtures.StringStream('stderr')).stream + self.useFixture(fixtures.MonkeyPatch('sys.stderr', stderr)) + + self.log_fixture = self.useFixture(fixtures.FakeLogger()) diff --git a/oslo_db/tests/old_import_api/sqlalchemy/__init__.py b/oslo_db/tests/old_import_api/sqlalchemy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/oslo_db/tests/old_import_api/sqlalchemy/test_engine_connect.py b/oslo_db/tests/old_import_api/sqlalchemy/test_engine_connect.py new file mode 100644 index 00000000..54e359fd --- /dev/null +++ b/oslo_db/tests/old_import_api/sqlalchemy/test_engine_connect.py @@ -0,0 +1,68 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Test the compatibility layer for the engine_connect() event. + +This event is added as of SQLAlchemy 0.9.0; oslo.db provides a compatibility +layer for prior SQLAlchemy versions. + +""" + +import mock +from oslotest import base as test_base +import sqlalchemy as sqla + +from oslo.db.sqlalchemy import compat + + +class EngineConnectTest(test_base.BaseTestCase): + + def setUp(self): + super(EngineConnectTest, self).setUp() + + self.engine = engine = sqla.create_engine("sqlite://") + self.addCleanup(engine.dispose) + + def test_connect_event(self): + engine = self.engine + + listener = mock.Mock() + compat.engine_connect(engine, listener) + + conn = engine.connect() + self.assertEqual( + listener.mock_calls, + [mock.call(conn, False)] + ) + + conn.close() + + conn2 = engine.connect() + conn2.close() + self.assertEqual( + listener.mock_calls, + [mock.call(conn, False), mock.call(conn2, False)] + ) + + def test_branch(self): + engine = self.engine + + listener = mock.Mock() + compat.engine_connect(engine, listener) + + conn = engine.connect() + branched = conn.connect() + conn.close() + self.assertEqual( + listener.mock_calls, + [mock.call(conn, False), mock.call(branched, True)] + ) diff --git a/oslo_db/tests/old_import_api/sqlalchemy/test_exc_filters.py b/oslo_db/tests/old_import_api/sqlalchemy/test_exc_filters.py new file mode 100644 index 00000000..4d4609a3 --- /dev/null +++ b/oslo_db/tests/old_import_api/sqlalchemy/test_exc_filters.py @@ -0,0 +1,833 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Test exception filters applied to engines.""" + +import contextlib +import itertools + +import mock +from oslotest import base as oslo_test_base +import six +import sqlalchemy as sqla +from sqlalchemy.orm import mapper + +from oslo.db import exception +from oslo.db.sqlalchemy import compat +from oslo.db.sqlalchemy import exc_filters +from oslo.db.sqlalchemy import test_base +from oslo_db.sqlalchemy import session as private_session +from oslo_db.tests.old_import_api import utils as test_utils + +_TABLE_NAME = '__tmp__test__tmp__' + + +class _SQLAExceptionMatcher(object): + def assertInnerException( + self, + matched, exception_type, message, sql=None, params=None): + + exc = matched.inner_exception + self.assertSQLAException(exc, exception_type, message, sql, params) + + def assertSQLAException( + self, + exc, exception_type, message, sql=None, params=None): + if isinstance(exception_type, (type, tuple)): + self.assertTrue(issubclass(exc.__class__, exception_type)) + else: + self.assertEqual(exc.__class__.__name__, exception_type) + self.assertEqual(str(exc.orig).lower(), message.lower()) + if sql is not None: + self.assertEqual(exc.statement, sql) + if params is not None: + self.assertEqual(exc.params, params) + + +class TestsExceptionFilter(_SQLAExceptionMatcher, oslo_test_base.BaseTestCase): + + class Error(Exception): + """DBAPI base error. + + This exception and subclasses are used in a mock context + within these tests. + + """ + + class OperationalError(Error): + pass + + class InterfaceError(Error): + pass + + class InternalError(Error): + pass + + class IntegrityError(Error): + pass + + class ProgrammingError(Error): + pass + + class TransactionRollbackError(OperationalError): + """Special psycopg2-only error class. + + SQLAlchemy has an issue with this per issue #3075: + + https://bitbucket.org/zzzeek/sqlalchemy/issue/3075/ + + """ + + def setUp(self): + super(TestsExceptionFilter, self).setUp() + self.engine = sqla.create_engine("sqlite://") + exc_filters.register_engine(self.engine) + self.engine.connect().close() # initialize + + @contextlib.contextmanager + def _dbapi_fixture(self, dialect_name): + engine = self.engine + with test_utils.nested( + mock.patch.object(engine.dialect.dbapi, + "Error", + self.Error), + mock.patch.object(engine.dialect, "name", dialect_name), + ): + yield + + @contextlib.contextmanager + def _fixture(self, dialect_name, exception, is_disconnect=False): + + def do_execute(self, cursor, statement, parameters, **kw): + raise exception + + engine = self.engine + + # ensure the engine has done its initial checks against the + # DB as we are going to be removing its ability to execute a + # statement + self.engine.connect().close() + + with test_utils.nested( + mock.patch.object(engine.dialect, "do_execute", do_execute), + # replace the whole DBAPI rather than patching "Error" + # as some DBAPIs might not be patchable (?) + mock.patch.object(engine.dialect, + "dbapi", + mock.Mock(Error=self.Error)), + mock.patch.object(engine.dialect, "name", dialect_name), + mock.patch.object(engine.dialect, + "is_disconnect", + lambda *args: is_disconnect) + ): + yield + + def _run_test(self, dialect_name, statement, raises, expected, + is_disconnect=False, params=()): + with self._fixture(dialect_name, raises, is_disconnect=is_disconnect): + with self.engine.connect() as conn: + matched = self.assertRaises( + expected, conn.execute, statement, params + ) + return matched + + +class TestFallthroughsAndNonDBAPI(TestsExceptionFilter): + + def test_generic_dbapi(self): + matched = self._run_test( + "mysql", "select you_made_a_programming_error", + self.ProgrammingError("Error 123, you made a mistake"), + exception.DBError + ) + self.assertInnerException( + matched, + "ProgrammingError", + "Error 123, you made a mistake", + 'select you_made_a_programming_error', ()) + + def test_generic_dbapi_disconnect(self): + matched = self._run_test( + "mysql", "select the_db_disconnected", + self.InterfaceError("connection lost"), + exception.DBConnectionError, + is_disconnect=True + ) + self.assertInnerException( + matched, + "InterfaceError", "connection lost", + "select the_db_disconnected", ()), + + def test_operational_dbapi_disconnect(self): + matched = self._run_test( + "mysql", "select the_db_disconnected", + self.OperationalError("connection lost"), + exception.DBConnectionError, + is_disconnect=True + ) + self.assertInnerException( + matched, + "OperationalError", "connection lost", + "select the_db_disconnected", ()), + + def test_operational_error_asis(self): + """Test operational errors. + + test that SQLAlchemy OperationalErrors that aren't disconnects + are passed through without wrapping. + """ + + matched = self._run_test( + "mysql", "select some_operational_error", + self.OperationalError("some op error"), + sqla.exc.OperationalError + ) + self.assertSQLAException( + matched, + "OperationalError", "some op error" + ) + + def test_unicode_encode(self): + # intentionally generate a UnicodeEncodeError, as its + # constructor is quite complicated and seems to be non-public + # or at least not documented anywhere. + uee_ref = None + try: + six.u('\u2435').encode('ascii') + except UnicodeEncodeError as uee: + # Python3.x added new scoping rules here (sadly) + # http://legacy.python.org/dev/peps/pep-3110/#semantic-changes + uee_ref = uee + + self._run_test( + "postgresql", six.u('select \u2435'), + uee_ref, + exception.DBInvalidUnicodeParameter + ) + + def test_garden_variety(self): + matched = self._run_test( + "mysql", "select some_thing_that_breaks", + AttributeError("mysqldb has an attribute error"), + exception.DBError + ) + self.assertEqual("mysqldb has an attribute error", matched.args[0]) + + +class TestReferenceErrorSQLite(_SQLAExceptionMatcher, test_base.DbTestCase): + + def setUp(self): + super(TestReferenceErrorSQLite, self).setUp() + + meta = sqla.MetaData(bind=self.engine) + + table_1 = sqla.Table( + "resource_foo", meta, + sqla.Column("id", sqla.Integer, primary_key=True), + sqla.Column("foo", sqla.Integer), + mysql_engine='InnoDB', + mysql_charset='utf8', + ) + table_1.create() + + self.table_2 = sqla.Table( + "resource_entity", meta, + sqla.Column("id", sqla.Integer, primary_key=True), + sqla.Column("foo_id", sqla.Integer, + sqla.ForeignKey("resource_foo.id", name="foo_fkey")), + mysql_engine='InnoDB', + mysql_charset='utf8', + ) + self.table_2.create() + + def test_raise(self): + self.engine.execute("PRAGMA foreign_keys = ON;") + + matched = self.assertRaises( + exception.DBReferenceError, + self.engine.execute, + self.table_2.insert({'id': 1, 'foo_id': 2}) + ) + + self.assertInnerException( + matched, + "IntegrityError", + "FOREIGN KEY constraint failed", + 'INSERT INTO resource_entity (id, foo_id) VALUES (?, ?)', + (1, 2) + ) + + self.assertIsNone(matched.table) + self.assertIsNone(matched.constraint) + self.assertIsNone(matched.key) + self.assertIsNone(matched.key_table) + + +class TestReferenceErrorPostgreSQL(TestReferenceErrorSQLite, + test_base.PostgreSQLOpportunisticTestCase): + def test_raise(self): + params = {'id': 1, 'foo_id': 2} + matched = self.assertRaises( + exception.DBReferenceError, + self.engine.execute, + self.table_2.insert(params) + ) + self.assertInnerException( + matched, + "IntegrityError", + "insert or update on table \"resource_entity\" " + "violates foreign key constraint \"foo_fkey\"\nDETAIL: Key " + "(foo_id)=(2) is not present in table \"resource_foo\".\n", + "INSERT INTO resource_entity (id, foo_id) VALUES (%(id)s, " + "%(foo_id)s)", + params, + ) + + self.assertEqual("resource_entity", matched.table) + self.assertEqual("foo_fkey", matched.constraint) + self.assertEqual("foo_id", matched.key) + self.assertEqual("resource_foo", matched.key_table) + + +class TestReferenceErrorMySQL(TestReferenceErrorSQLite, + test_base.MySQLOpportunisticTestCase): + def test_raise(self): + matched = self.assertRaises( + exception.DBReferenceError, + self.engine.execute, + self.table_2.insert({'id': 1, 'foo_id': 2}) + ) + + self.assertInnerException( + matched, + "IntegrityError", + "(1452, 'Cannot add or update a child row: a " + "foreign key constraint fails (`{0}`.`resource_entity`, " + "CONSTRAINT `foo_fkey` FOREIGN KEY (`foo_id`) REFERENCES " + "`resource_foo` (`id`))')".format(self.engine.url.database), + "INSERT INTO resource_entity (id, foo_id) VALUES (%s, %s)", + (1, 2) + ) + self.assertEqual("resource_entity", matched.table) + self.assertEqual("foo_fkey", matched.constraint) + self.assertEqual("foo_id", matched.key) + self.assertEqual("resource_foo", matched.key_table) + + def test_raise_ansi_quotes(self): + self.engine.execute("SET SESSION sql_mode = 'ANSI';") + matched = self.assertRaises( + exception.DBReferenceError, + self.engine.execute, + self.table_2.insert({'id': 1, 'foo_id': 2}) + ) + + self.assertInnerException( + matched, + "IntegrityError", + '(1452, \'Cannot add or update a child row: a ' + 'foreign key constraint fails ("{0}"."resource_entity", ' + 'CONSTRAINT "foo_fkey" FOREIGN KEY ("foo_id") REFERENCES ' + '"resource_foo" ("id"))\')'.format(self.engine.url.database), + "INSERT INTO resource_entity (id, foo_id) VALUES (%s, %s)", + (1, 2) + ) + self.assertEqual("resource_entity", matched.table) + self.assertEqual("foo_fkey", matched.constraint) + self.assertEqual("foo_id", matched.key) + self.assertEqual("resource_foo", matched.key_table) + + +class TestDuplicate(TestsExceptionFilter): + + def _run_dupe_constraint_test(self, dialect_name, message, + expected_columns=['a', 'b'], + expected_value=None): + matched = self._run_test( + dialect_name, "insert into table some_values", + self.IntegrityError(message), + exception.DBDuplicateEntry + ) + self.assertEqual(expected_columns, matched.columns) + self.assertEqual(expected_value, matched.value) + + def _not_dupe_constraint_test(self, dialect_name, statement, message, + expected_cls): + matched = self._run_test( + dialect_name, statement, + self.IntegrityError(message), + expected_cls + ) + self.assertInnerException( + matched, + "IntegrityError", + str(self.IntegrityError(message)), + statement + ) + + def test_sqlite(self): + self._run_dupe_constraint_test("sqlite", 'column a, b are not unique') + + def test_sqlite_3_7_16_or_3_8_2_and_higher(self): + self._run_dupe_constraint_test( + "sqlite", + 'UNIQUE constraint failed: tbl.a, tbl.b') + + def test_sqlite_dupe_primary_key(self): + self._run_dupe_constraint_test( + "sqlite", + "PRIMARY KEY must be unique 'insert into t values(10)'", + expected_columns=[]) + + def test_mysql_mysqldb(self): + self._run_dupe_constraint_test( + "mysql", + '(1062, "Duplicate entry ' + '\'2-3\' for key \'uniq_tbl0a0b\'")', expected_value='2-3') + + def test_mysql_mysqlconnector(self): + self._run_dupe_constraint_test( + "mysql", + '1062 (23000): Duplicate entry ' + '\'2-3\' for key \'uniq_tbl0a0b\'")', expected_value='2-3') + + def test_postgresql(self): + self._run_dupe_constraint_test( + 'postgresql', + 'duplicate key value violates unique constraint' + '"uniq_tbl0a0b"' + '\nDETAIL: Key (a, b)=(2, 3) already exists.\n', + expected_value='2, 3' + ) + + def test_mysql_single(self): + self._run_dupe_constraint_test( + "mysql", + "1062 (23000): Duplicate entry '2' for key 'b'", + expected_columns=['b'], + expected_value='2' + ) + + def test_postgresql_single(self): + self._run_dupe_constraint_test( + 'postgresql', + 'duplicate key value violates unique constraint "uniq_tbl0b"\n' + 'DETAIL: Key (b)=(2) already exists.\n', + expected_columns=['b'], + expected_value='2' + ) + + def test_unsupported_backend(self): + self._not_dupe_constraint_test( + "nonexistent", "insert into table some_values", + self.IntegrityError("constraint violation"), + exception.DBError + ) + + def test_ibm_db_sa(self): + self._run_dupe_constraint_test( + 'ibm_db_sa', + 'SQL0803N One or more values in the INSERT statement, UPDATE ' + 'statement, or foreign key update caused by a DELETE statement are' + ' not valid because the primary key, unique constraint or unique ' + 'index identified by "2" constrains table "NOVA.KEY_PAIRS" from ' + 'having duplicate values for the index key.', + expected_columns=[] + ) + + def test_ibm_db_sa_notadupe(self): + self._not_dupe_constraint_test( + 'ibm_db_sa', + 'ALTER TABLE instance_types ADD CONSTRAINT ' + 'uniq_name_x_deleted UNIQUE (name, deleted)', + 'SQL0542N The column named "NAME" cannot be a column of a ' + 'primary key or unique key constraint because it can contain null ' + 'values.', + exception.DBError + ) + + +class TestDeadlock(TestsExceptionFilter): + statement = ('SELECT quota_usages.created_at AS ' + 'quota_usages_created_at FROM quota_usages ' + 'WHERE quota_usages.project_id = %(project_id_1)s ' + 'AND quota_usages.deleted = %(deleted_1)s FOR UPDATE') + params = { + 'project_id_1': '8891d4478bbf48ad992f050cdf55e9b5', + 'deleted_1': 0 + } + + def _run_deadlock_detect_test( + self, dialect_name, message, + orig_exception_cls=TestsExceptionFilter.OperationalError): + self._run_test( + dialect_name, self.statement, + orig_exception_cls(message), + exception.DBDeadlock, + params=self.params + ) + + def _not_deadlock_test( + self, dialect_name, message, + expected_cls, expected_dbapi_cls, + orig_exception_cls=TestsExceptionFilter.OperationalError): + + matched = self._run_test( + dialect_name, self.statement, + orig_exception_cls(message), + expected_cls, + params=self.params + ) + + if isinstance(matched, exception.DBError): + matched = matched.inner_exception + + self.assertEqual(matched.orig.__class__.__name__, expected_dbapi_cls) + + def test_mysql_mysqldb_deadlock(self): + self._run_deadlock_detect_test( + "mysql", + "(1213, 'Deadlock found when trying " + "to get lock; try restarting " + "transaction')" + ) + + def test_mysql_mysqldb_galera_deadlock(self): + self._run_deadlock_detect_test( + "mysql", + "(1205, 'Lock wait timeout exceeded; " + "try restarting transaction')" + ) + + def test_mysql_mysqlconnector_deadlock(self): + self._run_deadlock_detect_test( + "mysql", + "1213 (40001): Deadlock found when trying to get lock; try " + "restarting transaction", + orig_exception_cls=self.InternalError + ) + + def test_mysql_not_deadlock(self): + self._not_deadlock_test( + "mysql", + "(1005, 'some other error')", + sqla.exc.OperationalError, # note OperationalErrors are sent thru + "OperationalError", + ) + + def test_postgresql_deadlock(self): + self._run_deadlock_detect_test( + "postgresql", + "deadlock detected", + orig_exception_cls=self.TransactionRollbackError + ) + + def test_postgresql_not_deadlock(self): + self._not_deadlock_test( + "postgresql", + 'relation "fake" does not exist', + # can be either depending on #3075 + (exception.DBError, sqla.exc.OperationalError), + "TransactionRollbackError", + orig_exception_cls=self.TransactionRollbackError + ) + + def test_ibm_db_sa_deadlock(self): + self._run_deadlock_detect_test( + "ibm_db_sa", + "SQL0911N The current transaction has been " + "rolled back because of a deadlock or timeout", + # use the lowest class b.c. I don't know what actual error + # class DB2's driver would raise for this + orig_exception_cls=self.Error + ) + + def test_ibm_db_sa_not_deadlock(self): + self._not_deadlock_test( + "ibm_db_sa", + "SQL01234B Some other error.", + exception.DBError, + "Error", + orig_exception_cls=self.Error + ) + + +class IntegrationTest(test_base.DbTestCase): + """Test an actual error-raising round trips against the database.""" + + def setUp(self): + super(IntegrationTest, self).setUp() + meta = sqla.MetaData() + self.test_table = sqla.Table( + _TABLE_NAME, meta, + sqla.Column('id', sqla.Integer, + primary_key=True, nullable=False), + sqla.Column('counter', sqla.Integer, + nullable=False), + sqla.UniqueConstraint('counter', + name='uniq_counter')) + self.test_table.create(self.engine) + self.addCleanup(self.test_table.drop, self.engine) + + class Foo(object): + def __init__(self, counter): + self.counter = counter + mapper(Foo, self.test_table) + self.Foo = Foo + + def test_flush_wrapper_duplicate_entry(self): + """test a duplicate entry exception.""" + + _session = self.sessionmaker() + + with _session.begin(): + foo = self.Foo(counter=1) + _session.add(foo) + + _session.begin() + self.addCleanup(_session.rollback) + foo = self.Foo(counter=1) + _session.add(foo) + self.assertRaises(exception.DBDuplicateEntry, _session.flush) + + def test_autoflush_wrapper_duplicate_entry(self): + """Test a duplicate entry exception raised. + + test a duplicate entry exception raised via query.all()-> autoflush + """ + + _session = self.sessionmaker() + + with _session.begin(): + foo = self.Foo(counter=1) + _session.add(foo) + + _session.begin() + self.addCleanup(_session.rollback) + foo = self.Foo(counter=1) + _session.add(foo) + self.assertTrue(_session.autoflush) + self.assertRaises(exception.DBDuplicateEntry, + _session.query(self.Foo).all) + + def test_flush_wrapper_plain_integrity_error(self): + """test a plain integrity error wrapped as DBError.""" + + _session = self.sessionmaker() + + with _session.begin(): + foo = self.Foo(counter=1) + _session.add(foo) + + _session.begin() + self.addCleanup(_session.rollback) + foo = self.Foo(counter=None) + _session.add(foo) + self.assertRaises(exception.DBError, _session.flush) + + def test_flush_wrapper_operational_error(self): + """test an operational error from flush() raised as-is.""" + + _session = self.sessionmaker() + + with _session.begin(): + foo = self.Foo(counter=1) + _session.add(foo) + + _session.begin() + self.addCleanup(_session.rollback) + foo = self.Foo(counter=sqla.func.imfake(123)) + _session.add(foo) + matched = self.assertRaises(sqla.exc.OperationalError, _session.flush) + self.assertTrue("no such function" in str(matched)) + + def test_query_wrapper_operational_error(self): + """test an operational error from query.all() raised as-is.""" + + _session = self.sessionmaker() + + _session.begin() + self.addCleanup(_session.rollback) + q = _session.query(self.Foo).filter( + self.Foo.counter == sqla.func.imfake(123)) + matched = self.assertRaises(sqla.exc.OperationalError, q.all) + self.assertTrue("no such function" in str(matched)) + + +class TestDBDisconnected(TestsExceptionFilter): + + @contextlib.contextmanager + def _fixture( + self, + dialect_name, exception, num_disconnects, is_disconnect=True): + engine = self.engine + + compat.engine_connect(engine, private_session._connect_ping_listener) + + real_do_execute = engine.dialect.do_execute + counter = itertools.count(1) + + def fake_do_execute(self, *arg, **kw): + if next(counter) > num_disconnects: + return real_do_execute(self, *arg, **kw) + else: + raise exception + + with self._dbapi_fixture(dialect_name): + with test_utils.nested( + mock.patch.object(engine.dialect, + "do_execute", + fake_do_execute), + mock.patch.object(engine.dialect, + "is_disconnect", + mock.Mock(return_value=is_disconnect)) + ): + yield + + def _test_ping_listener_disconnected( + self, dialect_name, exc_obj, is_disconnect=True): + with self._fixture(dialect_name, exc_obj, 1, is_disconnect): + conn = self.engine.connect() + with conn.begin(): + self.assertEqual(conn.scalar(sqla.select([1])), 1) + self.assertFalse(conn.closed) + self.assertFalse(conn.invalidated) + self.assertTrue(conn.in_transaction()) + + with self._fixture(dialect_name, exc_obj, 2, is_disconnect): + self.assertRaises( + exception.DBConnectionError, + self.engine.connect + ) + + # test implicit execution + with self._fixture(dialect_name, exc_obj, 1): + self.assertEqual(self.engine.scalar(sqla.select([1])), 1) + + def test_mysql_ping_listener_disconnected(self): + for code in [2006, 2013, 2014, 2045, 2055]: + self._test_ping_listener_disconnected( + "mysql", + self.OperationalError('%d MySQL server has gone away' % code) + ) + + def test_mysql_ping_listener_disconnected_regex_only(self): + # intentionally set the is_disconnect flag to False + # in the "sqlalchemy" layer to make sure the regexp + # on _is_db_connection_error is catching + for code in [2002, 2003, 2006, 2013]: + self._test_ping_listener_disconnected( + "mysql", + self.OperationalError('%d MySQL server has gone away' % code), + is_disconnect=False + ) + + def test_db2_ping_listener_disconnected(self): + self._test_ping_listener_disconnected( + "ibm_db_sa", + self.OperationalError( + 'SQL30081N: DB2 Server connection is no longer active') + ) + + def test_db2_ping_listener_disconnected_regex_only(self): + self._test_ping_listener_disconnected( + "ibm_db_sa", + self.OperationalError( + 'SQL30081N: DB2 Server connection is no longer active'), + is_disconnect=False + ) + + +class TestDBConnectRetry(TestsExceptionFilter): + + def _run_test(self, dialect_name, exception, count, retries): + counter = itertools.count() + + engine = self.engine + + # empty out the connection pool + engine.dispose() + + connect_fn = engine.dialect.connect + + def cant_connect(*arg, **kw): + if next(counter) < count: + raise exception + else: + return connect_fn(*arg, **kw) + + with self._dbapi_fixture(dialect_name): + with mock.patch.object(engine.dialect, "connect", cant_connect): + return private_session._test_connection(engine, retries, .01) + + def test_connect_no_retries(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, 0 + ) + # didnt connect because nothing was tried + self.assertIsNone(conn) + + def test_connect_inifinite_retries(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, -1 + ) + # conn is good + self.assertEqual(conn.scalar(sqla.select([1])), 1) + + def test_connect_retry_past_failure(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, 3 + ) + # conn is good + self.assertEqual(conn.scalar(sqla.select([1])), 1) + + def test_connect_retry_not_candidate_exception(self): + self.assertRaises( + sqla.exc.OperationalError, # remember, we pass OperationalErrors + # through at the moment :) + self._run_test, + "mysql", + self.OperationalError("Error: (2015) I can't connect period"), + 2, 3 + ) + + def test_connect_retry_stops_infailure(self): + self.assertRaises( + exception.DBConnectionError, + self._run_test, + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 3, 2 + ) + + def test_db2_error_positive(self): + conn = self._run_test( + "ibm_db_sa", + self.OperationalError("blah blah -30081 blah blah"), + 2, -1 + ) + # conn is good + self.assertEqual(conn.scalar(sqla.select([1])), 1) + + def test_db2_error_negative(self): + self.assertRaises( + sqla.exc.OperationalError, + self._run_test, + "ibm_db_sa", + self.OperationalError("blah blah -39981 blah blah"), + 2, 3 + ) diff --git a/oslo_db/tests/old_import_api/sqlalchemy/test_handle_error.py b/oslo_db/tests/old_import_api/sqlalchemy/test_handle_error.py new file mode 100644 index 00000000..fed029a4 --- /dev/null +++ b/oslo_db/tests/old_import_api/sqlalchemy/test_handle_error.py @@ -0,0 +1,194 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Test the compatibility layer for the handle_error() event. + +This event is added as of SQLAlchemy 0.9.7; oslo.db provides a compatibility +layer for prior SQLAlchemy versions. + +""" + +import mock +from oslotest import base as test_base +import sqlalchemy as sqla +from sqlalchemy.sql import column +from sqlalchemy.sql import literal +from sqlalchemy.sql import select +from sqlalchemy.types import Integer +from sqlalchemy.types import TypeDecorator + +from oslo.db.sqlalchemy import compat +from oslo.db.sqlalchemy.compat import utils +from oslo_db.tests.old_import_api import utils as test_utils + + +class MyException(Exception): + pass + + +class ExceptionReraiseTest(test_base.BaseTestCase): + + def setUp(self): + super(ExceptionReraiseTest, self).setUp() + + self.engine = engine = sqla.create_engine("sqlite://") + self.addCleanup(engine.dispose) + + def _fixture(self): + engine = self.engine + + def err(context): + if "ERROR ONE" in str(context.statement): + raise MyException("my exception") + compat.handle_error(engine, err) + + def test_exception_event_altered(self): + self._fixture() + + with mock.patch.object(self.engine.dialect.execution_ctx_cls, + "handle_dbapi_exception") as patched: + + matchee = self.assertRaises( + MyException, + self.engine.execute, "SELECT 'ERROR ONE' FROM I_DONT_EXIST" + ) + self.assertEqual(1, patched.call_count) + self.assertEqual("my exception", matchee.args[0]) + + def test_exception_event_non_altered(self): + self._fixture() + + with mock.patch.object(self.engine.dialect.execution_ctx_cls, + "handle_dbapi_exception") as patched: + + self.assertRaises( + sqla.exc.DBAPIError, + self.engine.execute, "SELECT 'ERROR TWO' FROM I_DONT_EXIST" + ) + self.assertEqual(1, patched.call_count) + + def test_is_disconnect_not_interrupted(self): + self._fixture() + + with test_utils.nested( + mock.patch.object( + self.engine.dialect.execution_ctx_cls, + "handle_dbapi_exception" + ), + mock.patch.object( + self.engine.dialect, "is_disconnect", + lambda *args: True + ) + ) as (handle_dbapi_exception, is_disconnect): + with self.engine.connect() as conn: + self.assertRaises( + MyException, + conn.execute, "SELECT 'ERROR ONE' FROM I_DONT_EXIST" + ) + self.assertEqual(1, handle_dbapi_exception.call_count) + self.assertTrue(conn.invalidated) + + def test_no_is_disconnect_not_invalidated(self): + self._fixture() + + with test_utils.nested( + mock.patch.object( + self.engine.dialect.execution_ctx_cls, + "handle_dbapi_exception" + ), + mock.patch.object( + self.engine.dialect, "is_disconnect", + lambda *args: False + ) + ) as (handle_dbapi_exception, is_disconnect): + with self.engine.connect() as conn: + self.assertRaises( + MyException, + conn.execute, "SELECT 'ERROR ONE' FROM I_DONT_EXIST" + ) + self.assertEqual(1, handle_dbapi_exception.call_count) + self.assertFalse(conn.invalidated) + + def test_exception_event_ad_hoc_context(self): + engine = self.engine + + nope = MyException("nope") + + class MyType(TypeDecorator): + impl = Integer + + def process_bind_param(self, value, dialect): + raise nope + + listener = mock.Mock(return_value=None) + compat.handle_error(engine, listener) + + self.assertRaises( + sqla.exc.StatementError, + engine.execute, + select([1]).where(column('foo') == literal('bar', MyType)) + ) + + ctx = listener.mock_calls[0][1][0] + self.assertTrue(ctx.statement.startswith("SELECT 1 ")) + self.assertIs(ctx.is_disconnect, False) + self.assertIs(ctx.original_exception, nope) + + def _test_alter_disconnect(self, orig_error, evt_value): + engine = self.engine + + def evt(ctx): + ctx.is_disconnect = evt_value + compat.handle_error(engine, evt) + + # if we are under sqla 0.9.7, and we are expecting to take + # an "is disconnect" exception and make it not a disconnect, + # that isn't supported b.c. the wrapped handler has already + # done the invalidation. + expect_failure = not utils.sqla_097 and orig_error and not evt_value + + with mock.patch.object(engine.dialect, + "is_disconnect", + mock.Mock(return_value=orig_error)): + + with engine.connect() as c: + conn_rec = c.connection._connection_record + try: + c.execute("SELECT x FROM nonexistent") + assert False + except sqla.exc.StatementError as st: + self.assertFalse(expect_failure) + + # check the exception's invalidation flag + self.assertEqual(st.connection_invalidated, evt_value) + + # check the Connection object's invalidation flag + self.assertEqual(c.invalidated, evt_value) + + # this is the ConnectionRecord object; it's invalidated + # when its .connection member is None + self.assertEqual(conn_rec.connection is None, evt_value) + + except NotImplementedError as ne: + self.assertTrue(expect_failure) + self.assertEqual( + str(ne), + "Can't reset 'disconnect' status of exception once it " + "is set with this version of SQLAlchemy") + + def test_alter_disconnect_to_true(self): + self._test_alter_disconnect(False, True) + self._test_alter_disconnect(True, True) + + def test_alter_disconnect_to_false(self): + self._test_alter_disconnect(True, False) + self._test_alter_disconnect(False, False) diff --git a/tests/sqlalchemy/test_migrate_cli.py b/oslo_db/tests/old_import_api/sqlalchemy/test_migrate_cli.py similarity index 99% rename from tests/sqlalchemy/test_migrate_cli.py rename to oslo_db/tests/old_import_api/sqlalchemy/test_migrate_cli.py index d7e7a5c5..135d44e3 100644 --- a/tests/sqlalchemy/test_migrate_cli.py +++ b/oslo_db/tests/old_import_api/sqlalchemy/test_migrate_cli.py @@ -28,7 +28,7 @@ class MockWithCmp(mock.MagicMock): self.__lt__ = lambda self, other: self.order < other.order -@mock.patch(('oslo.db.sqlalchemy.migration_cli.' +@mock.patch(('oslo_db.sqlalchemy.migration_cli.' 'ext_alembic.alembic.command')) class TestAlembicExtension(test_base.BaseTestCase): diff --git a/oslo_db/tests/old_import_api/sqlalchemy/test_migration_common.py b/oslo_db/tests/old_import_api/sqlalchemy/test_migration_common.py new file mode 100644 index 00000000..98ae46e9 --- /dev/null +++ b/oslo_db/tests/old_import_api/sqlalchemy/test_migration_common.py @@ -0,0 +1,174 @@ +# Copyright 2013 Mirantis Inc. +# All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# + +import os +import tempfile + +from migrate import exceptions as migrate_exception +from migrate.versioning import api as versioning_api +import mock +import sqlalchemy + +from oslo.db import exception as db_exception +from oslo.db.sqlalchemy import migration +from oslo.db.sqlalchemy import test_base +from oslo_db.sqlalchemy import migration as private_migration +from oslo_db.tests.old_import_api import utils as test_utils + + +class TestMigrationCommon(test_base.DbTestCase): + def setUp(self): + super(TestMigrationCommon, self).setUp() + + migration._REPOSITORY = None + self.path = tempfile.mkdtemp('test_migration') + self.path1 = tempfile.mkdtemp('test_migration') + self.return_value = '/home/openstack/migrations' + self.return_value1 = '/home/extension/migrations' + self.init_version = 1 + self.test_version = 123 + + self.patcher_repo = mock.patch.object(private_migration, 'Repository') + self.repository = self.patcher_repo.start() + self.repository.side_effect = [self.return_value, self.return_value1] + + self.mock_api_db = mock.patch.object(versioning_api, 'db_version') + self.mock_api_db_version = self.mock_api_db.start() + self.mock_api_db_version.return_value = self.test_version + + def tearDown(self): + os.rmdir(self.path) + self.mock_api_db.stop() + self.patcher_repo.stop() + super(TestMigrationCommon, self).tearDown() + + def test_db_version_control(self): + with test_utils.nested( + mock.patch('oslo_db.sqlalchemy.migration._find_migrate_repo'), + mock.patch.object(versioning_api, 'version_control'), + ) as (mock_find_repo, mock_version_control): + mock_find_repo.return_value = self.return_value + + version = migration.db_version_control( + self.engine, self.path, self.test_version) + + self.assertEqual(version, self.test_version) + mock_version_control.assert_called_once_with( + self.engine, self.return_value, self.test_version) + + def test_db_version_return(self): + ret_val = migration.db_version(self.engine, self.path, + self.init_version) + self.assertEqual(ret_val, self.test_version) + + def test_db_version_raise_not_controlled_error_first(self): + patcher = mock.patch.object(private_migration, 'db_version_control') + with patcher as mock_ver: + + self.mock_api_db_version.side_effect = [ + migrate_exception.DatabaseNotControlledError('oups'), + self.test_version] + + ret_val = migration.db_version(self.engine, self.path, + self.init_version) + self.assertEqual(ret_val, self.test_version) + mock_ver.assert_called_once_with(self.engine, self.path, + version=self.init_version) + + def test_db_version_raise_not_controlled_error_tables(self): + with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta: + self.mock_api_db_version.side_effect = \ + migrate_exception.DatabaseNotControlledError('oups') + my_meta = mock.MagicMock() + my_meta.tables = {'a': 1, 'b': 2} + mock_meta.return_value = my_meta + + self.assertRaises( + db_exception.DbMigrationError, migration.db_version, + self.engine, self.path, self.init_version) + + @mock.patch.object(versioning_api, 'version_control') + def test_db_version_raise_not_controlled_error_no_tables(self, mock_vc): + with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta: + self.mock_api_db_version.side_effect = ( + migrate_exception.DatabaseNotControlledError('oups'), + self.init_version) + my_meta = mock.MagicMock() + my_meta.tables = {} + mock_meta.return_value = my_meta + migration.db_version(self.engine, self.path, self.init_version) + + mock_vc.assert_called_once_with(self.engine, self.return_value1, + self.init_version) + + def test_db_sync_wrong_version(self): + self.assertRaises(db_exception.DbMigrationError, + migration.db_sync, self.engine, self.path, 'foo') + + def test_db_sync_upgrade(self): + init_ver = 55 + with test_utils.nested( + mock.patch('oslo_db.sqlalchemy.migration._find_migrate_repo'), + mock.patch.object(versioning_api, 'upgrade') + ) as (mock_find_repo, mock_upgrade): + + mock_find_repo.return_value = self.return_value + self.mock_api_db_version.return_value = self.test_version - 1 + + migration.db_sync(self.engine, self.path, self.test_version, + init_ver) + + mock_upgrade.assert_called_once_with( + self.engine, self.return_value, self.test_version) + + def test_db_sync_downgrade(self): + with test_utils.nested( + mock.patch('oslo_db.sqlalchemy.migration._find_migrate_repo'), + mock.patch.object(versioning_api, 'downgrade') + ) as (mock_find_repo, mock_downgrade): + + mock_find_repo.return_value = self.return_value + self.mock_api_db_version.return_value = self.test_version + 1 + + migration.db_sync(self.engine, self.path, self.test_version) + + mock_downgrade.assert_called_once_with( + self.engine, self.return_value, self.test_version) + + def test_db_sync_sanity_called(self): + with test_utils.nested( + mock.patch('oslo_db.sqlalchemy.migration._find_migrate_repo'), + mock.patch('oslo_db.sqlalchemy.migration._db_schema_sanity_check'), + mock.patch.object(versioning_api, 'downgrade') + ) as (mock_find_repo, mock_sanity, mock_downgrade): + + mock_find_repo.return_value = self.return_value + migration.db_sync(self.engine, self.path, self.test_version) + + mock_sanity.assert_called_once_with(self.engine) + + def test_db_sync_sanity_skipped(self): + with test_utils.nested( + mock.patch('oslo_db.sqlalchemy.migration._find_migrate_repo'), + mock.patch('oslo_db.sqlalchemy.migration._db_schema_sanity_check'), + mock.patch.object(versioning_api, 'downgrade') + ) as (mock_find_repo, mock_sanity, mock_downgrade): + + mock_find_repo.return_value = self.return_value + migration.db_sync(self.engine, self.path, self.test_version, + sanity_check=False) + + self.assertFalse(mock_sanity.called) diff --git a/tests/sqlalchemy/test_migrations.py b/oslo_db/tests/old_import_api/sqlalchemy/test_migrations.py similarity index 100% rename from tests/sqlalchemy/test_migrations.py rename to oslo_db/tests/old_import_api/sqlalchemy/test_migrations.py diff --git a/tests/sqlalchemy/test_models.py b/oslo_db/tests/old_import_api/sqlalchemy/test_models.py similarity index 100% rename from tests/sqlalchemy/test_models.py rename to oslo_db/tests/old_import_api/sqlalchemy/test_models.py diff --git a/tests/sqlalchemy/test_options.py b/oslo_db/tests/old_import_api/sqlalchemy/test_options.py similarity index 98% rename from tests/sqlalchemy/test_options.py rename to oslo_db/tests/old_import_api/sqlalchemy/test_options.py index 51bf470b..df0da997 100644 --- a/tests/sqlalchemy/test_options.py +++ b/oslo_db/tests/old_import_api/sqlalchemy/test_options.py @@ -15,7 +15,7 @@ from oslo.config import cfg from oslo.config import fixture as config from oslo.db import options -from tests import utils as test_utils +from oslo_db.tests.old_import_api import utils as test_utils class DbApiOptionsTestCase(test_utils.BaseTestCase): diff --git a/oslo_db/tests/old_import_api/sqlalchemy/test_sqlalchemy.py b/oslo_db/tests/old_import_api/sqlalchemy/test_sqlalchemy.py new file mode 100644 index 00000000..8d45cd46 --- /dev/null +++ b/oslo_db/tests/old_import_api/sqlalchemy/test_sqlalchemy.py @@ -0,0 +1,554 @@ +# coding=utf-8 + +# Copyright (c) 2012 Rackspace Hosting +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Unit tests for SQLAlchemy specific code.""" + +import logging + +import fixtures +import mock +from oslo.config import cfg +from oslotest import base as oslo_test +import sqlalchemy +from sqlalchemy import Column, MetaData, Table +from sqlalchemy import Integer, String +from sqlalchemy.ext.declarative import declarative_base + +from oslo.db import exception +from oslo.db.sqlalchemy import models +from oslo.db.sqlalchemy import session +from oslo.db.sqlalchemy import test_base +from oslo_db import options as db_options +from oslo_db.sqlalchemy import session as private_session + + +BASE = declarative_base() +_TABLE_NAME = '__tmp__test__tmp__' + +_REGEXP_TABLE_NAME = _TABLE_NAME + "regexp" + + +class RegexpTable(BASE, models.ModelBase): + __tablename__ = _REGEXP_TABLE_NAME + id = Column(Integer, primary_key=True) + bar = Column(String(255)) + + +class RegexpFilterTestCase(test_base.DbTestCase): + + def setUp(self): + super(RegexpFilterTestCase, self).setUp() + meta = MetaData() + meta.bind = self.engine + test_table = Table(_REGEXP_TABLE_NAME, meta, + Column('id', Integer, primary_key=True, + nullable=False), + Column('bar', String(255))) + test_table.create() + self.addCleanup(test_table.drop) + + def _test_regexp_filter(self, regexp, expected): + _session = self.sessionmaker() + with _session.begin(): + for i in ['10', '20', u'♥']: + tbl = RegexpTable() + tbl.update({'bar': i}) + tbl.save(session=_session) + + regexp_op = RegexpTable.bar.op('REGEXP')(regexp) + result = _session.query(RegexpTable).filter(regexp_op).all() + self.assertEqual([r.bar for r in result], expected) + + def test_regexp_filter(self): + self._test_regexp_filter('10', ['10']) + + def test_regexp_filter_nomatch(self): + self._test_regexp_filter('11', []) + + def test_regexp_filter_unicode(self): + self._test_regexp_filter(u'♥', [u'♥']) + + def test_regexp_filter_unicode_nomatch(self): + self._test_regexp_filter(u'♦', []) + + +class SQLiteSavepointTest(test_base.DbTestCase): + def setUp(self): + super(SQLiteSavepointTest, self).setUp() + meta = MetaData() + self.test_table = Table( + "test_table", meta, + Column('id', Integer, primary_key=True), + Column('data', String(10))) + self.test_table.create(self.engine) + self.addCleanup(self.test_table.drop, self.engine) + + def test_plain_transaction(self): + conn = self.engine.connect() + trans = conn.begin() + conn.execute( + self.test_table.insert(), + {'data': 'data 1'} + ) + self.assertEqual( + [(1, 'data 1')], + self.engine.execute( + self.test_table.select(). + order_by(self.test_table.c.id) + ).fetchall() + ) + trans.rollback() + self.assertEqual( + 0, + self.engine.scalar(self.test_table.count()) + ) + + def test_savepoint_middle(self): + with self.engine.begin() as conn: + conn.execute( + self.test_table.insert(), + {'data': 'data 1'} + ) + + savepoint = conn.begin_nested() + conn.execute( + self.test_table.insert(), + {'data': 'data 2'} + ) + savepoint.rollback() + + conn.execute( + self.test_table.insert(), + {'data': 'data 3'} + ) + + self.assertEqual( + [(1, 'data 1'), (2, 'data 3')], + self.engine.execute( + self.test_table.select(). + order_by(self.test_table.c.id) + ).fetchall() + ) + + def test_savepoint_beginning(self): + with self.engine.begin() as conn: + savepoint = conn.begin_nested() + conn.execute( + self.test_table.insert(), + {'data': 'data 1'} + ) + savepoint.rollback() + + conn.execute( + self.test_table.insert(), + {'data': 'data 2'} + ) + + self.assertEqual( + [(1, 'data 2')], + self.engine.execute( + self.test_table.select(). + order_by(self.test_table.c.id) + ).fetchall() + ) + + +class FakeDBAPIConnection(): + def cursor(self): + return FakeCursor() + + +class FakeCursor(): + def execute(self, sql): + pass + + +class FakeConnectionProxy(): + pass + + +class FakeConnectionRec(): + pass + + +class OperationalError(Exception): + pass + + +class ProgrammingError(Exception): + pass + + +class FakeDB2Engine(object): + + class Dialect(): + + def is_disconnect(self, e, *args): + expected_error = ('SQL30081N: DB2 Server connection is no longer ' + 'active') + return (str(e) == expected_error) + + dialect = Dialect() + name = 'ibm_db_sa' + + def dispose(self): + pass + + +class MySQLModeTestCase(test_base.MySQLOpportunisticTestCase): + + def __init__(self, *args, **kwargs): + super(MySQLModeTestCase, self).__init__(*args, **kwargs) + # By default, run in empty SQL mode. + # Subclasses override this with specific modes. + self.mysql_mode = '' + + def setUp(self): + super(MySQLModeTestCase, self).setUp() + + self.engine = session.create_engine(self.engine.url, + mysql_sql_mode=self.mysql_mode) + self.connection = self.engine.connect() + + meta = MetaData() + meta.bind = self.engine + self.test_table = Table(_TABLE_NAME + "mode", meta, + Column('id', Integer, primary_key=True), + Column('bar', String(255))) + self.test_table.create() + + self.addCleanup(self.test_table.drop) + self.addCleanup(self.connection.close) + + def _test_string_too_long(self, value): + with self.connection.begin(): + self.connection.execute(self.test_table.insert(), + bar=value) + result = self.connection.execute(self.test_table.select()) + return result.fetchone()['bar'] + + def test_string_too_long(self): + value = 'a' * 512 + # String is too long. + # With no SQL mode set, this gets truncated. + self.assertNotEqual(value, + self._test_string_too_long(value)) + + +class MySQLStrictAllTablesModeTestCase(MySQLModeTestCase): + "Test data integrity enforcement in MySQL STRICT_ALL_TABLES mode." + + def __init__(self, *args, **kwargs): + super(MySQLStrictAllTablesModeTestCase, self).__init__(*args, **kwargs) + self.mysql_mode = 'STRICT_ALL_TABLES' + + def test_string_too_long(self): + value = 'a' * 512 + # String is too long. + # With STRICT_ALL_TABLES or TRADITIONAL mode set, this is an error. + self.assertRaises(exception.DBError, + self._test_string_too_long, value) + + +class MySQLTraditionalModeTestCase(MySQLStrictAllTablesModeTestCase): + """Test data integrity enforcement in MySQL TRADITIONAL mode. + + Since TRADITIONAL includes STRICT_ALL_TABLES, this inherits all + STRICT_ALL_TABLES mode tests. + """ + + def __init__(self, *args, **kwargs): + super(MySQLTraditionalModeTestCase, self).__init__(*args, **kwargs) + self.mysql_mode = 'TRADITIONAL' + + +class EngineFacadeTestCase(oslo_test.BaseTestCase): + def setUp(self): + super(EngineFacadeTestCase, self).setUp() + + self.facade = session.EngineFacade('sqlite://') + + def test_get_engine(self): + eng1 = self.facade.get_engine() + eng2 = self.facade.get_engine() + + self.assertIs(eng1, eng2) + + def test_get_session(self): + ses1 = self.facade.get_session() + ses2 = self.facade.get_session() + + self.assertIsNot(ses1, ses2) + + def test_get_session_arguments_override_default_settings(self): + ses = self.facade.get_session(autocommit=False, expire_on_commit=True) + + self.assertFalse(ses.autocommit) + self.assertTrue(ses.expire_on_commit) + + @mock.patch('oslo_db.sqlalchemy.session.get_maker') + @mock.patch('oslo_db.sqlalchemy.session.create_engine') + def test_creation_from_config(self, create_engine, get_maker): + conf = cfg.ConfigOpts() + conf.register_opts(db_options.database_opts, group='database') + + overrides = { + 'connection': 'sqlite:///:memory:', + 'slave_connection': None, + 'connection_debug': 100, + 'max_pool_size': 10, + 'mysql_sql_mode': 'TRADITIONAL', + } + for optname, optvalue in overrides.items(): + conf.set_override(optname, optvalue, group='database') + + session.EngineFacade.from_config(conf, + autocommit=False, + expire_on_commit=True) + + create_engine.assert_called_once_with( + sql_connection='sqlite:///:memory:', + connection_debug=100, + max_pool_size=10, + mysql_sql_mode='TRADITIONAL', + sqlite_fk=False, + idle_timeout=mock.ANY, + retry_interval=mock.ANY, + max_retries=mock.ANY, + max_overflow=mock.ANY, + connection_trace=mock.ANY, + sqlite_synchronous=mock.ANY, + pool_timeout=mock.ANY, + thread_checkin=mock.ANY, + ) + get_maker.assert_called_once_with(engine=create_engine(), + autocommit=False, + expire_on_commit=True) + + def test_slave_connection(self): + paths = self.create_tempfiles([('db.master', ''), ('db.slave', '')], + ext='') + master_path = 'sqlite:///' + paths[0] + slave_path = 'sqlite:///' + paths[1] + + facade = session.EngineFacade( + sql_connection=master_path, + slave_connection=slave_path + ) + + master = facade.get_engine() + self.assertEqual(master_path, str(master.url)) + slave = facade.get_engine(use_slave=True) + self.assertEqual(slave_path, str(slave.url)) + + master_session = facade.get_session() + self.assertEqual(master_path, str(master_session.bind.url)) + slave_session = facade.get_session(use_slave=True) + self.assertEqual(slave_path, str(slave_session.bind.url)) + + def test_slave_connection_string_not_provided(self): + master_path = 'sqlite:///' + self.create_tempfiles( + [('db.master', '')], ext='')[0] + + facade = session.EngineFacade(sql_connection=master_path) + + master = facade.get_engine() + slave = facade.get_engine(use_slave=True) + self.assertIs(master, slave) + self.assertEqual(master_path, str(master.url)) + + master_session = facade.get_session() + self.assertEqual(master_path, str(master_session.bind.url)) + slave_session = facade.get_session(use_slave=True) + self.assertEqual(master_path, str(slave_session.bind.url)) + + +class SQLiteConnectTest(oslo_test.BaseTestCase): + + def _fixture(self, **kw): + return session.create_engine("sqlite://", **kw) + + def test_sqlite_fk_listener(self): + engine = self._fixture(sqlite_fk=True) + self.assertEqual( + engine.scalar("pragma foreign_keys"), + 1 + ) + + engine = self._fixture(sqlite_fk=False) + + self.assertEqual( + engine.scalar("pragma foreign_keys"), + 0 + ) + + def test_sqlite_synchronous_listener(self): + engine = self._fixture() + + # "The default setting is synchronous=FULL." (e.g. 2) + # http://www.sqlite.org/pragma.html#pragma_synchronous + self.assertEqual( + engine.scalar("pragma synchronous"), + 2 + ) + + engine = self._fixture(sqlite_synchronous=False) + + self.assertEqual( + engine.scalar("pragma synchronous"), + 0 + ) + + +class MysqlConnectTest(test_base.MySQLOpportunisticTestCase): + + def _fixture(self, sql_mode): + return session.create_engine(self.engine.url, mysql_sql_mode=sql_mode) + + def _assert_sql_mode(self, engine, sql_mode_present, sql_mode_non_present): + mode = engine.execute("SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1] + self.assertTrue( + sql_mode_present in mode + ) + if sql_mode_non_present: + self.assertTrue( + sql_mode_non_present not in mode + ) + + def test_set_mode_traditional(self): + engine = self._fixture(sql_mode='TRADITIONAL') + self._assert_sql_mode(engine, "TRADITIONAL", "ANSI") + + def test_set_mode_ansi(self): + engine = self._fixture(sql_mode='ANSI') + self._assert_sql_mode(engine, "ANSI", "TRADITIONAL") + + def test_set_mode_no_mode(self): + # If _mysql_set_mode_callback is called with sql_mode=None, then + # the SQL mode is NOT set on the connection. + + expected = self.engine.execute( + "SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1] + + engine = self._fixture(sql_mode=None) + self._assert_sql_mode(engine, expected, None) + + def test_fail_detect_mode(self): + # If "SHOW VARIABLES LIKE 'sql_mode'" results in no row, then + # we get a log indicating can't detect the mode. + + log = self.useFixture(fixtures.FakeLogger(level=logging.WARN)) + + mysql_conn = self.engine.raw_connection() + self.addCleanup(mysql_conn.close) + mysql_conn.detach() + mysql_cursor = mysql_conn.cursor() + + def execute(statement, parameters=()): + if "SHOW VARIABLES LIKE 'sql_mode'" in statement: + statement = "SHOW VARIABLES LIKE 'i_dont_exist'" + return mysql_cursor.execute(statement, parameters) + + test_engine = sqlalchemy.create_engine(self.engine.url, + _initialize=False) + + with mock.patch.object( + test_engine.pool, '_creator', + mock.Mock( + return_value=mock.Mock( + cursor=mock.Mock( + return_value=mock.Mock( + execute=execute, + fetchone=mysql_cursor.fetchone, + fetchall=mysql_cursor.fetchall + ) + ) + ) + ) + ): + private_session._init_events.dispatch_on_drivername("mysql")( + test_engine + ) + + test_engine.raw_connection() + self.assertIn('Unable to detect effective SQL mode', + log.output) + + def test_logs_real_mode(self): + # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value, then + # we get a log with the value. + + log = self.useFixture(fixtures.FakeLogger(level=logging.DEBUG)) + + engine = self._fixture(sql_mode='TRADITIONAL') + + actual_mode = engine.execute( + "SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1] + + self.assertIn('MySQL server mode set to %s' % actual_mode, + log.output) + + def test_warning_when_not_traditional(self): + # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that doesn't + # include 'TRADITIONAL', then a warning is logged. + + log = self.useFixture(fixtures.FakeLogger(level=logging.WARN)) + self._fixture(sql_mode='ANSI') + + self.assertIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES", + log.output) + + def test_no_warning_when_traditional(self): + # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that includes + # 'TRADITIONAL', then no warning is logged. + + log = self.useFixture(fixtures.FakeLogger(level=logging.WARN)) + self._fixture(sql_mode='TRADITIONAL') + + self.assertNotIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES", + log.output) + + def test_no_warning_when_strict_all_tables(self): + # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that includes + # 'STRICT_ALL_TABLES', then no warning is logged. + + log = self.useFixture(fixtures.FakeLogger(level=logging.WARN)) + self._fixture(sql_mode='TRADITIONAL') + + self.assertNotIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES", + log.output) + + +# NOTE(dhellmann): This test no longer works as written. The code in +# oslo_db.sqlalchemy.session filters out lines from modules under +# oslo_db, and now this test is under oslo_db, so the test filename +# does not appear in the context for the error message. LP #1405376 + +# class PatchStacktraceTest(test_base.DbTestCase): + +# def test_trace(self): +# engine = self.engine +# private_session._add_trace_comments(engine) +# conn = engine.connect() +# with mock.patch.object(engine.dialect, "do_execute") as mock_exec: + +# conn.execute("select * from table") + +# call = mock_exec.mock_calls[0] + +# # we're the caller, see that we're in there +# self.assertTrue("tests/sqlalchemy/test_sqlalchemy.py" in call[1][1]) diff --git a/oslo_db/tests/old_import_api/sqlalchemy/test_utils.py b/oslo_db/tests/old_import_api/sqlalchemy/test_utils.py new file mode 100644 index 00000000..6e865c6c --- /dev/null +++ b/oslo_db/tests/old_import_api/sqlalchemy/test_utils.py @@ -0,0 +1,1093 @@ +# Copyright (c) 2013 Boris Pavlovic (boris@pavlovic.me). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import uuid + +import fixtures +import mock +from oslotest import base as test_base +from oslotest import moxstubout +import six +from six.moves.urllib import parse +import sqlalchemy +from sqlalchemy.dialects import mysql +from sqlalchemy import Boolean, Index, Integer, DateTime, String, SmallInteger +from sqlalchemy import MetaData, Table, Column, ForeignKey +from sqlalchemy.engine import reflection +from sqlalchemy.engine import url as sa_url +from sqlalchemy.exc import OperationalError +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.sql import select +from sqlalchemy.types import UserDefinedType, NullType + +from oslo.db import exception +from oslo.db.sqlalchemy import models +from oslo.db.sqlalchemy import provision +from oslo.db.sqlalchemy import session +from oslo.db.sqlalchemy import test_base as db_test_base +from oslo.db.sqlalchemy import utils +from oslo_db.sqlalchemy import utils as private_utils +from oslo_db.tests.old_import_api import utils as test_utils + + +SA_VERSION = tuple(map(int, sqlalchemy.__version__.split('.'))) + + +class TestSanitizeDbUrl(test_base.BaseTestCase): + + def test_url_with_cred(self): + db_url = 'myproto://johndoe:secret@localhost/myschema' + expected = 'myproto://****:****@localhost/myschema' + actual = utils.sanitize_db_url(db_url) + self.assertEqual(expected, actual) + + def test_url_with_no_cred(self): + db_url = 'sqlite:///mysqlitefile' + actual = utils.sanitize_db_url(db_url) + self.assertEqual(db_url, actual) + + +class CustomType(UserDefinedType): + """Dummy column type for testing unsupported types.""" + def get_col_spec(self): + return "CustomType" + + +class FakeModel(object): + def __init__(self, values): + self.values = values + + def __getattr__(self, name): + try: + value = self.values[name] + except KeyError: + raise AttributeError(name) + return value + + def __getitem__(self, key): + if key in self.values: + return self.values[key] + else: + raise NotImplementedError() + + def __repr__(self): + return '' % self.values + + +class TestPaginateQuery(test_base.BaseTestCase): + def setUp(self): + super(TestPaginateQuery, self).setUp() + mox_fixture = self.useFixture(moxstubout.MoxStubout()) + self.mox = mox_fixture.mox + self.query = self.mox.CreateMockAnything() + self.mox.StubOutWithMock(sqlalchemy, 'asc') + self.mox.StubOutWithMock(sqlalchemy, 'desc') + self.marker = FakeModel({ + 'user_id': 'user', + 'project_id': 'p', + 'snapshot_id': 's', + }) + self.model = FakeModel({ + 'user_id': 'user', + 'project_id': 'project', + 'snapshot_id': 'snapshot', + }) + + def test_paginate_query_no_pagination_no_sort_dirs(self): + sqlalchemy.asc('user').AndReturn('asc_3') + self.query.order_by('asc_3').AndReturn(self.query) + sqlalchemy.asc('project').AndReturn('asc_2') + self.query.order_by('asc_2').AndReturn(self.query) + sqlalchemy.asc('snapshot').AndReturn('asc_1') + self.query.order_by('asc_1').AndReturn(self.query) + self.query.limit(5).AndReturn(self.query) + self.mox.ReplayAll() + utils.paginate_query(self.query, self.model, 5, + ['user_id', 'project_id', 'snapshot_id']) + + def test_paginate_query_no_pagination(self): + sqlalchemy.asc('user').AndReturn('asc') + self.query.order_by('asc').AndReturn(self.query) + sqlalchemy.desc('project').AndReturn('desc') + self.query.order_by('desc').AndReturn(self.query) + self.query.limit(5).AndReturn(self.query) + self.mox.ReplayAll() + utils.paginate_query(self.query, self.model, 5, + ['user_id', 'project_id'], + sort_dirs=['asc', 'desc']) + + def test_paginate_query_attribute_error(self): + sqlalchemy.asc('user').AndReturn('asc') + self.query.order_by('asc').AndReturn(self.query) + self.mox.ReplayAll() + self.assertRaises(exception.InvalidSortKey, + utils.paginate_query, self.query, + self.model, 5, ['user_id', 'non-existent key']) + + def test_paginate_query_assertion_error(self): + self.mox.ReplayAll() + self.assertRaises(AssertionError, + utils.paginate_query, self.query, + self.model, 5, ['user_id'], + marker=self.marker, + sort_dir='asc', sort_dirs=['asc']) + + def test_paginate_query_assertion_error_2(self): + self.mox.ReplayAll() + self.assertRaises(AssertionError, + utils.paginate_query, self.query, + self.model, 5, ['user_id'], + marker=self.marker, + sort_dir=None, sort_dirs=['asc', 'desk']) + + def test_paginate_query(self): + sqlalchemy.asc('user').AndReturn('asc_1') + self.query.order_by('asc_1').AndReturn(self.query) + sqlalchemy.desc('project').AndReturn('desc_1') + self.query.order_by('desc_1').AndReturn(self.query) + self.mox.StubOutWithMock(sqlalchemy.sql, 'and_') + sqlalchemy.sql.and_(False).AndReturn('some_crit') + sqlalchemy.sql.and_(True, False).AndReturn('another_crit') + self.mox.StubOutWithMock(sqlalchemy.sql, 'or_') + sqlalchemy.sql.or_('some_crit', 'another_crit').AndReturn('some_f') + self.query.filter('some_f').AndReturn(self.query) + self.query.limit(5).AndReturn(self.query) + self.mox.ReplayAll() + utils.paginate_query(self.query, self.model, 5, + ['user_id', 'project_id'], + marker=self.marker, + sort_dirs=['asc', 'desc']) + + def test_paginate_query_value_error(self): + sqlalchemy.asc('user').AndReturn('asc_1') + self.query.order_by('asc_1').AndReturn(self.query) + self.mox.ReplayAll() + self.assertRaises(ValueError, utils.paginate_query, + self.query, self.model, 5, ['user_id', 'project_id'], + marker=self.marker, sort_dirs=['asc', 'mixed']) + + +class TestMigrationUtils(db_test_base.DbTestCase): + + """Class for testing utils that are used in db migrations.""" + + def setUp(self): + super(TestMigrationUtils, self).setUp() + self.meta = MetaData(bind=self.engine) + self.conn = self.engine.connect() + self.addCleanup(self.meta.drop_all) + self.addCleanup(self.conn.close) + + def _populate_db_for_drop_duplicate_entries(self, engine, meta, + table_name): + values = [ + {'id': 11, 'a': 3, 'b': 10, 'c': 'abcdef'}, + {'id': 12, 'a': 5, 'b': 10, 'c': 'abcdef'}, + {'id': 13, 'a': 6, 'b': 10, 'c': 'abcdef'}, + {'id': 14, 'a': 7, 'b': 10, 'c': 'abcdef'}, + {'id': 21, 'a': 1, 'b': 20, 'c': 'aa'}, + {'id': 31, 'a': 1, 'b': 20, 'c': 'bb'}, + {'id': 41, 'a': 1, 'b': 30, 'c': 'aef'}, + {'id': 42, 'a': 2, 'b': 30, 'c': 'aef'}, + {'id': 43, 'a': 3, 'b': 30, 'c': 'aef'} + ] + + test_table = Table(table_name, meta, + Column('id', Integer, primary_key=True, + nullable=False), + Column('a', Integer), + Column('b', Integer), + Column('c', String(255)), + Column('deleted', Integer, default=0), + Column('deleted_at', DateTime), + Column('updated_at', DateTime)) + + test_table.create() + engine.execute(test_table.insert(), values) + return test_table, values + + def test_drop_old_duplicate_entries_from_table(self): + table_name = "__test_tmp_table__" + + test_table, values = self._populate_db_for_drop_duplicate_entries( + self.engine, self.meta, table_name) + utils.drop_old_duplicate_entries_from_table( + self.engine, table_name, False, 'b', 'c') + + uniq_values = set() + expected_ids = [] + for value in sorted(values, key=lambda x: x['id'], reverse=True): + uniq_value = (('b', value['b']), ('c', value['c'])) + if uniq_value in uniq_values: + continue + uniq_values.add(uniq_value) + expected_ids.append(value['id']) + + real_ids = [row[0] for row in + self.engine.execute(select([test_table.c.id])).fetchall()] + + self.assertEqual(len(real_ids), len(expected_ids)) + for id_ in expected_ids: + self.assertTrue(id_ in real_ids) + + def test_drop_dup_entries_in_file_conn(self): + table_name = "__test_tmp_table__" + tmp_db_file = self.create_tempfiles([['name', '']], ext='.sql')[0] + in_file_engine = session.EngineFacade( + 'sqlite:///%s' % tmp_db_file).get_engine() + meta = MetaData() + meta.bind = in_file_engine + test_table, values = self._populate_db_for_drop_duplicate_entries( + in_file_engine, meta, table_name) + utils.drop_old_duplicate_entries_from_table( + in_file_engine, table_name, False, 'b', 'c') + + def test_drop_old_duplicate_entries_from_table_soft_delete(self): + table_name = "__test_tmp_table__" + + table, values = self._populate_db_for_drop_duplicate_entries( + self.engine, self.meta, table_name) + utils.drop_old_duplicate_entries_from_table(self.engine, table_name, + True, 'b', 'c') + uniq_values = set() + expected_values = [] + soft_deleted_values = [] + + for value in sorted(values, key=lambda x: x['id'], reverse=True): + uniq_value = (('b', value['b']), ('c', value['c'])) + if uniq_value in uniq_values: + soft_deleted_values.append(value) + continue + uniq_values.add(uniq_value) + expected_values.append(value) + + base_select = table.select() + + rows_select = base_select.where(table.c.deleted != table.c.id) + row_ids = [row['id'] for row in + self.engine.execute(rows_select).fetchall()] + self.assertEqual(len(row_ids), len(expected_values)) + for value in expected_values: + self.assertTrue(value['id'] in row_ids) + + deleted_rows_select = base_select.where( + table.c.deleted == table.c.id) + deleted_rows_ids = [row['id'] for row in + self.engine.execute( + deleted_rows_select).fetchall()] + self.assertEqual(len(deleted_rows_ids), + len(values) - len(row_ids)) + for value in soft_deleted_values: + self.assertTrue(value['id'] in deleted_rows_ids) + + def test_change_deleted_column_type_does_not_drop_index(self): + table_name = 'abc' + + indexes = { + 'idx_a_deleted': ['a', 'deleted'], + 'idx_b_deleted': ['b', 'deleted'], + 'idx_a': ['a'] + } + + index_instances = [Index(name, *columns) + for name, columns in six.iteritems(indexes)] + + table = Table(table_name, self.meta, + Column('id', Integer, primary_key=True), + Column('a', String(255)), + Column('b', String(255)), + Column('deleted', Boolean), + *index_instances) + table.create() + utils.change_deleted_column_type_to_id_type(self.engine, table_name) + utils.change_deleted_column_type_to_boolean(self.engine, table_name) + + insp = reflection.Inspector.from_engine(self.engine) + real_indexes = insp.get_indexes(table_name) + self.assertEqual(len(real_indexes), 3) + for index in real_indexes: + name = index['name'] + self.assertIn(name, indexes) + self.assertEqual(set(index['column_names']), + set(indexes[name])) + + def test_change_deleted_column_type_to_id_type_integer(self): + table_name = 'abc' + table = Table(table_name, self.meta, + Column('id', Integer, primary_key=True), + Column('deleted', Boolean)) + table.create() + utils.change_deleted_column_type_to_id_type(self.engine, table_name) + + table = utils.get_table(self.engine, table_name) + self.assertTrue(isinstance(table.c.deleted.type, Integer)) + + def test_change_deleted_column_type_to_id_type_string(self): + table_name = 'abc' + table = Table(table_name, self.meta, + Column('id', String(255), primary_key=True), + Column('deleted', Boolean)) + table.create() + utils.change_deleted_column_type_to_id_type(self.engine, table_name) + + table = utils.get_table(self.engine, table_name) + self.assertTrue(isinstance(table.c.deleted.type, String)) + + @db_test_base.backend_specific('sqlite') + def test_change_deleted_column_type_to_id_type_custom(self): + table_name = 'abc' + table = Table(table_name, self.meta, + Column('id', Integer, primary_key=True), + Column('foo', CustomType), + Column('deleted', Boolean)) + table.create() + + # reflection of custom types has been fixed upstream + if SA_VERSION < (0, 9, 0): + self.assertRaises(exception.ColumnError, + utils.change_deleted_column_type_to_id_type, + self.engine, table_name) + + fooColumn = Column('foo', CustomType()) + utils.change_deleted_column_type_to_id_type(self.engine, table_name, + foo=fooColumn) + + table = utils.get_table(self.engine, table_name) + # NOTE(boris-42): There is no way to check has foo type CustomType. + # but sqlalchemy will set it to NullType. This has + # been fixed upstream in recent SA versions + if SA_VERSION < (0, 9, 0): + self.assertTrue(isinstance(table.c.foo.type, NullType)) + self.assertTrue(isinstance(table.c.deleted.type, Integer)) + + def test_change_deleted_column_type_to_boolean(self): + expected_types = {'mysql': mysql.TINYINT, + 'ibm_db_sa': SmallInteger} + table_name = 'abc' + table = Table(table_name, self.meta, + Column('id', Integer, primary_key=True), + Column('deleted', Integer)) + table.create() + + utils.change_deleted_column_type_to_boolean(self.engine, table_name) + + table = utils.get_table(self.engine, table_name) + self.assertIsInstance(table.c.deleted.type, + expected_types.get(self.engine.name, Boolean)) + + def test_change_deleted_column_type_to_boolean_with_fc(self): + expected_types = {'mysql': mysql.TINYINT, + 'ibm_db_sa': SmallInteger} + table_name_1 = 'abc' + table_name_2 = 'bcd' + + table_1 = Table(table_name_1, self.meta, + Column('id', Integer, primary_key=True), + Column('deleted', Integer)) + table_1.create() + + table_2 = Table(table_name_2, self.meta, + Column('id', Integer, primary_key=True), + Column('foreign_id', Integer, + ForeignKey('%s.id' % table_name_1)), + Column('deleted', Integer)) + table_2.create() + + utils.change_deleted_column_type_to_boolean(self.engine, table_name_2) + + table = utils.get_table(self.engine, table_name_2) + self.assertIsInstance(table.c.deleted.type, + expected_types.get(self.engine.name, Boolean)) + + @db_test_base.backend_specific('sqlite') + def test_change_deleted_column_type_to_boolean_type_custom(self): + table_name = 'abc' + table = Table(table_name, self.meta, + Column('id', Integer, primary_key=True), + Column('foo', CustomType), + Column('deleted', Integer)) + table.create() + + # reflection of custom types has been fixed upstream + if SA_VERSION < (0, 9, 0): + self.assertRaises(exception.ColumnError, + utils.change_deleted_column_type_to_boolean, + self.engine, table_name) + + fooColumn = Column('foo', CustomType()) + utils.change_deleted_column_type_to_boolean(self.engine, table_name, + foo=fooColumn) + + table = utils.get_table(self.engine, table_name) + # NOTE(boris-42): There is no way to check has foo type CustomType. + # but sqlalchemy will set it to NullType. This has + # been fixed upstream in recent SA versions + if SA_VERSION < (0, 9, 0): + self.assertTrue(isinstance(table.c.foo.type, NullType)) + self.assertTrue(isinstance(table.c.deleted.type, Boolean)) + + @db_test_base.backend_specific('sqlite') + def test_change_deleted_column_type_sqlite_drops_check_constraint(self): + table_name = 'abc' + table = Table(table_name, self.meta, + Column('id', Integer, primary_key=True), + Column('deleted', Boolean)) + table.create() + + private_utils._change_deleted_column_type_to_id_type_sqlite( + self.engine, + table_name, + ) + table = Table(table_name, self.meta, autoload=True) + # NOTE(I159): if the CHECK constraint has been dropped (expected + # behavior), any integer value can be inserted, otherwise only 1 or 0. + self.engine.execute(table.insert({'deleted': 10})) + + def test_insert_from_select(self): + insert_table_name = "__test_insert_to_table__" + select_table_name = "__test_select_from_table__" + uuidstrs = [] + for unused in range(10): + uuidstrs.append(uuid.uuid4().hex) + insert_table = Table( + insert_table_name, self.meta, + Column('id', Integer, primary_key=True, + nullable=False, autoincrement=True), + Column('uuid', String(36), nullable=False)) + select_table = Table( + select_table_name, self.meta, + Column('id', Integer, primary_key=True, + nullable=False, autoincrement=True), + Column('uuid', String(36), nullable=False)) + + insert_table.create() + select_table.create() + # Add 10 rows to select_table + for uuidstr in uuidstrs: + ins_stmt = select_table.insert().values(uuid=uuidstr) + self.conn.execute(ins_stmt) + + # Select 4 rows in one chunk from select_table + column = select_table.c.id + query_insert = select([select_table], + select_table.c.id < 5).order_by(column) + insert_statement = utils.InsertFromSelect(insert_table, + query_insert) + result_insert = self.conn.execute(insert_statement) + # Verify we insert 4 rows + self.assertEqual(result_insert.rowcount, 4) + + query_all = select([insert_table]).where( + insert_table.c.uuid.in_(uuidstrs)) + rows = self.conn.execute(query_all).fetchall() + # Verify we really have 4 rows in insert_table + self.assertEqual(len(rows), 4) + + +class PostgesqlTestMigrations(TestMigrationUtils, + db_test_base.PostgreSQLOpportunisticTestCase): + + """Test migrations on PostgreSQL.""" + pass + + +class MySQLTestMigrations(TestMigrationUtils, + db_test_base.MySQLOpportunisticTestCase): + + """Test migrations on MySQL.""" + pass + + +class TestConnectionUtils(test_utils.BaseTestCase): + + def setUp(self): + super(TestConnectionUtils, self).setUp() + + self.full_credentials = {'backend': 'postgresql', + 'database': 'test', + 'user': 'dude', + 'passwd': 'pass'} + + self.connect_string = 'postgresql://dude:pass@localhost/test' + + def test_connect_string(self): + connect_string = utils.get_connect_string(**self.full_credentials) + self.assertEqual(connect_string, self.connect_string) + + def test_connect_string_sqlite(self): + sqlite_credentials = {'backend': 'sqlite', 'database': 'test.db'} + connect_string = utils.get_connect_string(**sqlite_credentials) + self.assertEqual(connect_string, 'sqlite:///test.db') + + def test_is_backend_avail(self): + self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect') + fake_connection = self.mox.CreateMockAnything() + fake_connection.close() + sqlalchemy.engine.base.Engine.connect().AndReturn(fake_connection) + self.mox.ReplayAll() + + self.assertTrue(utils.is_backend_avail(**self.full_credentials)) + + def test_is_backend_unavail(self): + log = self.useFixture(fixtures.FakeLogger()) + err = OperationalError("Can't connect to database", None, None) + error_msg = "The postgresql backend is unavailable: %s\n" % err + self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect') + sqlalchemy.engine.base.Engine.connect().AndRaise(err) + self.mox.ReplayAll() + self.assertFalse(utils.is_backend_avail(**self.full_credentials)) + self.assertEqual(error_msg, log.output) + + def test_ensure_backend_available(self): + self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect') + fake_connection = self.mox.CreateMockAnything() + fake_connection.close() + sqlalchemy.engine.base.Engine.connect().AndReturn(fake_connection) + self.mox.ReplayAll() + + eng = provision.Backend._ensure_backend_available(self.connect_string) + self.assertIsInstance(eng, sqlalchemy.engine.base.Engine) + self.assertEqual(self.connect_string, str(eng.url)) + + def test_ensure_backend_available_no_connection_raises(self): + log = self.useFixture(fixtures.FakeLogger()) + err = OperationalError("Can't connect to database", None, None) + self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect') + sqlalchemy.engine.base.Engine.connect().AndRaise(err) + self.mox.ReplayAll() + + exc = self.assertRaises( + exception.BackendNotAvailable, + provision.Backend._ensure_backend_available, self.connect_string + ) + self.assertEqual("Could not connect", str(exc)) + self.assertEqual( + "The postgresql backend is unavailable: %s" % err, + log.output.strip()) + + def test_ensure_backend_available_no_dbapi_raises(self): + log = self.useFixture(fixtures.FakeLogger()) + self.mox.StubOutWithMock(sqlalchemy, 'create_engine') + sqlalchemy.create_engine( + sa_url.make_url(self.connect_string)).AndRaise( + ImportError("Can't import DBAPI module foobar")) + self.mox.ReplayAll() + + exc = self.assertRaises( + exception.BackendNotAvailable, + provision.Backend._ensure_backend_available, self.connect_string + ) + self.assertEqual("No DBAPI installed", str(exc)) + self.assertEqual( + "The postgresql backend is unavailable: Can't import " + "DBAPI module foobar", log.output.strip()) + + def test_get_db_connection_info(self): + conn_pieces = parse.urlparse(self.connect_string) + self.assertEqual(utils.get_db_connection_info(conn_pieces), + ('dude', 'pass', 'test', 'localhost')) + + def test_connect_string_host(self): + self.full_credentials['host'] = 'myhost' + connect_string = utils.get_connect_string(**self.full_credentials) + self.assertEqual(connect_string, 'postgresql://dude:pass@myhost/test') + + +class MyModelSoftDeletedProjectId(declarative_base(), models.ModelBase, + models.SoftDeleteMixin): + __tablename__ = 'soft_deleted_project_id_test_model' + id = Column(Integer, primary_key=True) + project_id = Column(Integer) + + +class MyModel(declarative_base(), models.ModelBase): + __tablename__ = 'test_model' + id = Column(Integer, primary_key=True) + + +class MyModelSoftDeleted(declarative_base(), models.ModelBase, + models.SoftDeleteMixin): + __tablename__ = 'soft_deleted_test_model' + id = Column(Integer, primary_key=True) + + +class TestModelQuery(test_base.BaseTestCase): + + def setUp(self): + super(TestModelQuery, self).setUp() + + self.session = mock.MagicMock() + self.session.query.return_value = self.session.query + self.session.query.filter.return_value = self.session.query + + def test_wrong_model(self): + self.assertRaises(TypeError, utils.model_query, + FakeModel, session=self.session) + + def test_no_soft_deleted(self): + self.assertRaises(ValueError, utils.model_query, + MyModel, session=self.session, deleted=True) + + def test_deleted_false(self): + mock_query = utils.model_query( + MyModelSoftDeleted, session=self.session, deleted=False) + + deleted_filter = mock_query.filter.call_args[0][0] + self.assertEqual(str(deleted_filter), + 'soft_deleted_test_model.deleted = :deleted_1') + self.assertEqual(deleted_filter.right.value, + MyModelSoftDeleted.__mapper__.c.deleted.default.arg) + + def test_deleted_true(self): + mock_query = utils.model_query( + MyModelSoftDeleted, session=self.session, deleted=True) + + deleted_filter = mock_query.filter.call_args[0][0] + self.assertEqual(str(deleted_filter), + 'soft_deleted_test_model.deleted != :deleted_1') + self.assertEqual(deleted_filter.right.value, + MyModelSoftDeleted.__mapper__.c.deleted.default.arg) + + @mock.patch('oslo_db.sqlalchemy.utils._read_deleted_filter') + def test_no_deleted_value(self, _read_deleted_filter): + utils.model_query(MyModelSoftDeleted, session=self.session) + self.assertEqual(_read_deleted_filter.call_count, 0) + + def test_project_filter(self): + project_id = 10 + + mock_query = utils.model_query( + MyModelSoftDeletedProjectId, session=self.session, + project_only=True, project_id=project_id) + + deleted_filter = mock_query.filter.call_args[0][0] + self.assertEqual( + str(deleted_filter), + 'soft_deleted_project_id_test_model.project_id = :project_id_1') + self.assertEqual(deleted_filter.right.value, project_id) + + def test_project_filter_wrong_model(self): + self.assertRaises(ValueError, utils.model_query, + MyModelSoftDeleted, session=self.session, + project_id=10) + + def test_project_filter_allow_none(self): + mock_query = utils.model_query( + MyModelSoftDeletedProjectId, + session=self.session, project_id=(10, None)) + + self.assertEqual( + str(mock_query.filter.call_args[0][0]), + 'soft_deleted_project_id_test_model.project_id' + ' IN (:project_id_1, NULL)' + ) + + def test_model_query_common(self): + utils.model_query(MyModel, args=(MyModel.id,), session=self.session) + self.session.query.assert_called_with(MyModel.id) + + +class TestUtils(db_test_base.DbTestCase): + def setUp(self): + super(TestUtils, self).setUp() + meta = MetaData(bind=self.engine) + self.test_table = Table( + 'test_table', + meta, + Column('a', Integer), + Column('b', Integer) + ) + self.test_table.create() + self.addCleanup(meta.drop_all) + + def test_index_exists(self): + self.assertFalse(utils.index_exists(self.engine, 'test_table', + 'new_index')) + Index('new_index', self.test_table.c.a).create(self.engine) + self.assertTrue(utils.index_exists(self.engine, 'test_table', + 'new_index')) + + def test_add_index(self): + self.assertFalse(utils.index_exists(self.engine, 'test_table', + 'new_index')) + utils.add_index(self.engine, 'test_table', 'new_index', ('a',)) + self.assertTrue(utils.index_exists(self.engine, 'test_table', + 'new_index')) + + def test_add_existing_index(self): + Index('new_index', self.test_table.c.a).create(self.engine) + self.assertRaises(ValueError, utils.add_index, self.engine, + 'test_table', 'new_index', ('a',)) + + def test_drop_index(self): + Index('new_index', self.test_table.c.a).create(self.engine) + utils.drop_index(self.engine, 'test_table', 'new_index') + self.assertFalse(utils.index_exists(self.engine, 'test_table', + 'new_index')) + + def test_drop_unexisting_index(self): + self.assertRaises(ValueError, utils.drop_index, self.engine, + 'test_table', 'new_index') + + @mock.patch('oslo_db.sqlalchemy.utils.drop_index') + @mock.patch('oslo_db.sqlalchemy.utils.add_index') + def test_change_index_columns(self, add_index, drop_index): + utils.change_index_columns(self.engine, 'test_table', 'a_index', + ('a',)) + drop_index.assert_called_once_with(self.engine, 'test_table', + 'a_index') + add_index.assert_called_once_with(self.engine, 'test_table', + 'a_index', ('a',)) + + def test_column_exists(self): + for col in ['a', 'b']: + self.assertTrue(utils.column_exists(self.engine, 'test_table', + col)) + self.assertFalse(utils.column_exists(self.engine, 'test_table', + 'fake_column')) + + +class TestUtilsMysqlOpportunistically( + TestUtils, db_test_base.MySQLOpportunisticTestCase): + pass + + +class TestUtilsPostgresqlOpportunistically( + TestUtils, db_test_base.PostgreSQLOpportunisticTestCase): + pass + + +class TestDialectFunctionDispatcher(test_base.BaseTestCase): + def _single_fixture(self): + callable_fn = mock.Mock() + + dispatcher = orig = utils.dispatch_for_dialect("*")( + callable_fn.default) + dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite) + dispatcher = dispatcher.dispatch_for("mysql+mysqldb")( + callable_fn.mysql_mysqldb) + dispatcher = dispatcher.dispatch_for("postgresql")( + callable_fn.postgresql) + + self.assertTrue(dispatcher is orig) + + return dispatcher, callable_fn + + def _multiple_fixture(self): + callable_fn = mock.Mock() + + for targ in [ + callable_fn.default, + callable_fn.sqlite, + callable_fn.mysql_mysqldb, + callable_fn.postgresql, + callable_fn.postgresql_psycopg2, + callable_fn.pyodbc + ]: + targ.return_value = None + + dispatcher = orig = utils.dispatch_for_dialect("*", multiple=True)( + callable_fn.default) + dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite) + dispatcher = dispatcher.dispatch_for("mysql+mysqldb")( + callable_fn.mysql_mysqldb) + dispatcher = dispatcher.dispatch_for("postgresql+*")( + callable_fn.postgresql) + dispatcher = dispatcher.dispatch_for("postgresql+psycopg2")( + callable_fn.postgresql_psycopg2) + dispatcher = dispatcher.dispatch_for("*+pyodbc")( + callable_fn.pyodbc) + + self.assertTrue(dispatcher is orig) + + return dispatcher, callable_fn + + def test_single(self): + + dispatcher, callable_fn = self._single_fixture() + dispatcher("sqlite://", 1) + dispatcher("postgresql+psycopg2://u:p@h/t", 2) + dispatcher("mysql://u:p@h/t", 3) + dispatcher("mysql+mysqlconnector://u:p@h/t", 4) + + self.assertEqual( + [ + mock.call.sqlite('sqlite://', 1), + mock.call.postgresql("postgresql+psycopg2://u:p@h/t", 2), + mock.call.mysql_mysqldb("mysql://u:p@h/t", 3), + mock.call.default("mysql+mysqlconnector://u:p@h/t", 4) + ], + callable_fn.mock_calls) + + def test_single_kwarg(self): + dispatcher, callable_fn = self._single_fixture() + dispatcher("sqlite://", foo='bar') + dispatcher("postgresql+psycopg2://u:p@h/t", 1, x='y') + + self.assertEqual( + [ + mock.call.sqlite('sqlite://', foo='bar'), + mock.call.postgresql( + "postgresql+psycopg2://u:p@h/t", + 1, x='y'), + ], + callable_fn.mock_calls) + + def test_dispatch_on_target(self): + callable_fn = mock.Mock() + + @utils.dispatch_for_dialect("*") + def default_fn(url, x, y): + callable_fn.default(url, x, y) + + @default_fn.dispatch_for("sqlite") + def sqlite_fn(url, x, y): + callable_fn.sqlite(url, x, y) + default_fn.dispatch_on_drivername("*")(url, x, y) + + default_fn("sqlite://", 4, 5) + self.assertEqual( + [ + mock.call.sqlite("sqlite://", 4, 5), + mock.call.default("sqlite://", 4, 5) + ], + callable_fn.mock_calls + ) + + def test_single_no_dispatcher(self): + callable_fn = mock.Mock() + + dispatcher = utils.dispatch_for_dialect("sqlite")(callable_fn.sqlite) + dispatcher = dispatcher.dispatch_for("mysql")(callable_fn.mysql) + exc = self.assertRaises( + ValueError, + dispatcher, "postgresql://s:t@localhost/test" + ) + self.assertEqual( + "No default function found for driver: 'postgresql+psycopg2'", + str(exc) + ) + + def test_multiple_no_dispatcher(self): + callable_fn = mock.Mock() + + dispatcher = utils.dispatch_for_dialect("sqlite", multiple=True)( + callable_fn.sqlite) + dispatcher = dispatcher.dispatch_for("mysql")(callable_fn.mysql) + dispatcher("postgresql://s:t@localhost/test") + self.assertEqual( + [], callable_fn.mock_calls + ) + + def test_multiple_no_driver(self): + callable_fn = mock.Mock( + default=mock.Mock(return_value=None), + sqlite=mock.Mock(return_value=None) + ) + + dispatcher = utils.dispatch_for_dialect("*", multiple=True)( + callable_fn.default) + dispatcher = dispatcher.dispatch_for("sqlite")( + callable_fn.sqlite) + + dispatcher.dispatch_on_drivername("sqlite")("foo") + self.assertEqual( + [mock.call.sqlite("foo"), mock.call.default("foo")], + callable_fn.mock_calls + ) + + def test_multiple_nesting(self): + callable_fn = mock.Mock( + default=mock.Mock(return_value=None), + mysql=mock.Mock(return_value=None) + ) + + dispatcher = utils.dispatch_for_dialect("*", multiple=True)( + callable_fn.default) + + dispatcher = dispatcher.dispatch_for("mysql+mysqlconnector")( + dispatcher.dispatch_for("mysql+mysqldb")( + callable_fn.mysql + ) + ) + + mysqldb_url = sqlalchemy.engine.url.make_url("mysql+mysqldb://") + mysqlconnector_url = sqlalchemy.engine.url.make_url( + "mysql+mysqlconnector://") + sqlite_url = sqlalchemy.engine.url.make_url("sqlite://") + + dispatcher(mysqldb_url, 1) + dispatcher(mysqlconnector_url, 2) + dispatcher(sqlite_url, 3) + + self.assertEqual( + [ + mock.call.mysql(mysqldb_url, 1), + mock.call.default(mysqldb_url, 1), + mock.call.mysql(mysqlconnector_url, 2), + mock.call.default(mysqlconnector_url, 2), + mock.call.default(sqlite_url, 3) + ], + callable_fn.mock_calls + ) + + def test_single_retval(self): + dispatcher, callable_fn = self._single_fixture() + callable_fn.mysql_mysqldb.return_value = 5 + + self.assertEqual( + dispatcher("mysql://u:p@h/t", 3), 5 + ) + + def test_engine(self): + eng = sqlalchemy.create_engine("sqlite:///path/to/my/db.db") + dispatcher, callable_fn = self._single_fixture() + + dispatcher(eng) + self.assertEqual( + [mock.call.sqlite(eng)], + callable_fn.mock_calls + ) + + def test_url(self): + url = sqlalchemy.engine.url.make_url( + "mysql+mysqldb://scott:tiger@localhost/test") + dispatcher, callable_fn = self._single_fixture() + + dispatcher(url, 15) + self.assertEqual( + [mock.call.mysql_mysqldb(url, 15)], + callable_fn.mock_calls + ) + + def test_invalid_target(self): + dispatcher, callable_fn = self._single_fixture() + + exc = self.assertRaises( + ValueError, + dispatcher, 20 + ) + self.assertEqual("Invalid target type: 20", str(exc)) + + def test_invalid_dispatch(self): + callable_fn = mock.Mock() + + dispatcher = utils.dispatch_for_dialect("*")(callable_fn.default) + + exc = self.assertRaises( + ValueError, + dispatcher.dispatch_for("+pyodbc"), callable_fn.pyodbc + ) + self.assertEqual( + "Couldn't parse database[+driver]: '+pyodbc'", + str(exc) + ) + + def test_single_only_one_target(self): + callable_fn = mock.Mock() + + dispatcher = utils.dispatch_for_dialect("*")(callable_fn.default) + dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite) + + exc = self.assertRaises( + TypeError, + dispatcher.dispatch_for("sqlite"), callable_fn.sqlite2 + ) + self.assertEqual( + "Multiple functions for expression 'sqlite'", str(exc) + ) + + def test_multiple(self): + dispatcher, callable_fn = self._multiple_fixture() + + dispatcher("postgresql+pyodbc://", 1) + dispatcher("mysql://", 2) + dispatcher("ibm_db_sa+db2://", 3) + dispatcher("postgresql+psycopg2://", 4) + + # TODO(zzzeek): there is a deterministic order here, but we might + # want to tweak it, or maybe provide options. default first? + # most specific first? is *+pyodbc or postgresql+* more specific? + self.assertEqual( + [ + mock.call.postgresql('postgresql+pyodbc://', 1), + mock.call.pyodbc('postgresql+pyodbc://', 1), + mock.call.default('postgresql+pyodbc://', 1), + mock.call.mysql_mysqldb('mysql://', 2), + mock.call.default('mysql://', 2), + mock.call.default('ibm_db_sa+db2://', 3), + mock.call.postgresql_psycopg2('postgresql+psycopg2://', 4), + mock.call.postgresql('postgresql+psycopg2://', 4), + mock.call.default('postgresql+psycopg2://', 4), + ], + callable_fn.mock_calls + ) + + def test_multiple_no_return_value(self): + dispatcher, callable_fn = self._multiple_fixture() + callable_fn.sqlite.return_value = 5 + + exc = self.assertRaises( + TypeError, + dispatcher, "sqlite://" + ) + self.assertEqual( + "Return value not allowed for multiple filtered function", + str(exc) + ) + + +class TestGetInnoDBTables(db_test_base.MySQLOpportunisticTestCase): + + def test_all_tables_use_innodb(self): + self.engine.execute("CREATE TABLE customers " + "(a INT, b CHAR (20), INDEX (a)) ENGINE=InnoDB") + self.assertEqual([], utils.get_non_innodb_tables(self.engine)) + + def test_all_tables_use_innodb_false(self): + self.engine.execute("CREATE TABLE employee " + "(i INT) ENGINE=MEMORY") + self.assertEqual(['employee'], + utils.get_non_innodb_tables(self.engine)) + + def test_skip_tables_use_default_value(self): + self.engine.execute("CREATE TABLE migrate_version " + "(i INT) ENGINE=MEMORY") + self.assertEqual([], + utils.get_non_innodb_tables(self.engine)) + + def test_skip_tables_use_passed_value(self): + self.engine.execute("CREATE TABLE some_table " + "(i INT) ENGINE=MEMORY") + self.assertEqual([], + utils.get_non_innodb_tables( + self.engine, skip_tables=('some_table',))) + + def test_skip_tables_use_empty_list(self): + self.engine.execute("CREATE TABLE some_table_3 " + "(i INT) ENGINE=MEMORY") + self.assertEqual(['some_table_3'], + utils.get_non_innodb_tables( + self.engine, skip_tables=())) + + def test_skip_tables_use_several_values(self): + self.engine.execute("CREATE TABLE some_table_1 " + "(i INT) ENGINE=MEMORY") + self.engine.execute("CREATE TABLE some_table_2 " + "(i INT) ENGINE=MEMORY") + self.assertEqual([], + utils.get_non_innodb_tables( + self.engine, + skip_tables=('some_table_1', 'some_table_2'))) diff --git a/tests/test_api.py b/oslo_db/tests/old_import_api/test_api.py similarity index 97% rename from tests/test_api.py rename to oslo_db/tests/old_import_api/test_api.py index 2168cd5a..aa69d556 100644 --- a/tests/test_api.py +++ b/oslo_db/tests/old_import_api/test_api.py @@ -21,7 +21,7 @@ from oslo.utils import importutils from oslo.db import api from oslo.db import exception -from tests import utils as test_utils +from oslo_db.tests.old_import_api import utils as test_utils sqla = importutils.try_import('sqlalchemy') if not sqla: @@ -66,7 +66,7 @@ class DBAPI(object): class DBAPITestCase(test_utils.BaseTestCase): def test_dbapi_full_path_module_method(self): - dbapi = api.DBAPI('tests.test_api') + dbapi = api.DBAPI('oslo_db.tests.test_api') result = dbapi.api_class_call1(1, 2, kwarg1='meow') expected = ((1, 2), {'kwarg1': 'meow'}) self.assertEqual(expected, result) @@ -75,7 +75,7 @@ class DBAPITestCase(test_utils.BaseTestCase): self.assertRaises(ImportError, api.DBAPI, 'tests.unit.db.not_existent') def test_dbapi_lazy_loading(self): - dbapi = api.DBAPI('tests.test_api', lazy=True) + dbapi = api.DBAPI('oslo_db.tests.test_api', lazy=True) self.assertIsNone(dbapi._backend) dbapi.api_class_call1(1, 'abc') diff --git a/tests/test_concurrency.py b/oslo_db/tests/old_import_api/test_concurrency.py similarity index 95% rename from tests/test_concurrency.py rename to oslo_db/tests/old_import_api/test_concurrency.py index a53ea804..eea07eda 100644 --- a/tests/test_concurrency.py +++ b/oslo_db/tests/old_import_api/test_concurrency.py @@ -18,7 +18,7 @@ import sys import mock from oslo.db import concurrency -from tests import utils as test_utils +from oslo_db.tests.old_import_api import utils as test_utils FAKE_BACKEND_MAPPING = {'sqlalchemy': 'fake.db.sqlalchemy.api'} @@ -47,7 +47,7 @@ class TpoolDbapiWrapperTestCase(test_utils.BaseTestCase): sys.modules['eventlet'] = self.eventlet self.addCleanup(sys.modules.pop, 'eventlet', None) - @mock.patch('oslo.db.api.DBAPI') + @mock.patch('oslo_db.api.DBAPI') def test_db_api_common(self, mock_db_api): # test context: # CONF.database.use_tpool == False @@ -73,7 +73,7 @@ class TpoolDbapiWrapperTestCase(test_utils.BaseTestCase): self.assertFalse(self.eventlet.tpool.Proxy.called) self.assertEqual(1, mock_db_api.from_config.call_count) - @mock.patch('oslo.db.api.DBAPI') + @mock.patch('oslo_db.api.DBAPI') def test_db_api_config_change(self, mock_db_api): # test context: # CONF.database.use_tpool == True @@ -94,7 +94,7 @@ class TpoolDbapiWrapperTestCase(test_utils.BaseTestCase): self.eventlet.tpool.Proxy.assert_called_once_with(fake_db_api) self.assertEqual(self.db_api._db_api, self.proxy) - @mock.patch('oslo.db.api.DBAPI') + @mock.patch('oslo_db.api.DBAPI') def test_db_api_without_installed_eventlet(self, mock_db_api): # test context: # CONF.database.use_tpool == True diff --git a/oslo_db/tests/old_import_api/test_warning.py b/oslo_db/tests/old_import_api/test_warning.py new file mode 100644 index 00000000..b8abdad2 --- /dev/null +++ b/oslo_db/tests/old_import_api/test_warning.py @@ -0,0 +1,61 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import imp +import os +import warnings + +import mock +from oslotest import base as test_base +import six + + +class DeprecationWarningTest(test_base.BaseTestCase): + + @mock.patch('warnings.warn') + def test_warning(self, mock_warn): + import oslo.db + imp.reload(oslo.db) + self.assertTrue(mock_warn.called) + args = mock_warn.call_args + self.assertIn('oslo_db', args[0][0]) + self.assertIn('deprecated', args[0][0]) + self.assertTrue(issubclass(args[0][1], DeprecationWarning)) + + def test_real_warning(self): + with warnings.catch_warnings(record=True) as warning_msgs: + warnings.resetwarnings() + warnings.simplefilter('always', DeprecationWarning) + import oslo.db + + # Use a separate function to get the stack level correct + # so we know the message points back to this file. This + # corresponds to an import or reload, which isn't working + # inside the test under Python 3.3. That may be due to a + # difference in the import implementation not triggering + # warnings properly when the module is reloaded, or + # because the warnings module is mostly implemented in C + # and something isn't cleanly resetting the global state + # used to track whether a warning needs to be + # emitted. Whatever the cause, we definitely see the + # warnings.warn() being invoked on a reload (see the test + # above) and warnings are reported on the console when we + # run the tests. A simpler test script run outside of + # testr does correctly report the warnings. + def foo(): + oslo.db.deprecated() + + foo() + self.assertEqual(1, len(warning_msgs)) + msg = warning_msgs[0] + self.assertIn('oslo_db', six.text_type(msg.message)) + self.assertEqual('test_warning.py', os.path.basename(msg.filename)) diff --git a/tests/utils.py b/oslo_db/tests/old_import_api/utils.py similarity index 100% rename from tests/utils.py rename to oslo_db/tests/old_import_api/utils.py diff --git a/oslo_db/tests/sqlalchemy/__init__.py b/oslo_db/tests/sqlalchemy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/sqlalchemy/test_engine_connect.py b/oslo_db/tests/sqlalchemy/test_engine_connect.py similarity index 93% rename from tests/sqlalchemy/test_engine_connect.py rename to oslo_db/tests/sqlalchemy/test_engine_connect.py index 80d17cb3..c75511f6 100644 --- a/tests/sqlalchemy/test_engine_connect.py +++ b/oslo_db/tests/sqlalchemy/test_engine_connect.py @@ -12,7 +12,7 @@ """Test the compatibility layer for the engine_connect() event. -This event is added as of SQLAlchemy 0.9.0; oslo.db provides a compatibility +This event is added as of SQLAlchemy 0.9.0; oslo_db provides a compatibility layer for prior SQLAlchemy versions. """ @@ -21,7 +21,7 @@ import mock from oslotest import base as test_base import sqlalchemy as sqla -from oslo.db.sqlalchemy.compat import engine_connect +from oslo_db.sqlalchemy.compat import engine_connect class EngineConnectTest(test_base.BaseTestCase): diff --git a/tests/sqlalchemy/test_exc_filters.py b/oslo_db/tests/sqlalchemy/test_exc_filters.py similarity index 99% rename from tests/sqlalchemy/test_exc_filters.py rename to oslo_db/tests/sqlalchemy/test_exc_filters.py index a3f91a6c..157a183e 100644 --- a/tests/sqlalchemy/test_exc_filters.py +++ b/oslo_db/tests/sqlalchemy/test_exc_filters.py @@ -21,12 +21,12 @@ import six import sqlalchemy as sqla from sqlalchemy.orm import mapper -from oslo.db import exception -from oslo.db.sqlalchemy import compat -from oslo.db.sqlalchemy import exc_filters -from oslo.db.sqlalchemy import session -from oslo.db.sqlalchemy import test_base -from tests import utils as test_utils +from oslo_db import exception +from oslo_db.sqlalchemy import compat +from oslo_db.sqlalchemy import exc_filters +from oslo_db.sqlalchemy import session +from oslo_db.sqlalchemy import test_base +from oslo_db.tests import utils as test_utils _TABLE_NAME = '__tmp__test__tmp__' diff --git a/tests/sqlalchemy/test_handle_error.py b/oslo_db/tests/sqlalchemy/test_handle_error.py similarity index 97% rename from tests/sqlalchemy/test_handle_error.py rename to oslo_db/tests/sqlalchemy/test_handle_error.py index 14269a38..83322ef9 100644 --- a/tests/sqlalchemy/test_handle_error.py +++ b/oslo_db/tests/sqlalchemy/test_handle_error.py @@ -12,7 +12,7 @@ """Test the compatibility layer for the handle_error() event. -This event is added as of SQLAlchemy 0.9.7; oslo.db provides a compatibility +This event is added as of SQLAlchemy 0.9.7; oslo_db provides a compatibility layer for prior SQLAlchemy versions. """ @@ -26,9 +26,9 @@ from sqlalchemy.sql import select from sqlalchemy.types import Integer from sqlalchemy.types import TypeDecorator -from oslo.db.sqlalchemy.compat import handle_error -from oslo.db.sqlalchemy.compat import utils -from tests import utils as test_utils +from oslo_db.sqlalchemy.compat import handle_error +from oslo_db.sqlalchemy.compat import utils +from oslo_db.tests import utils as test_utils class MyException(Exception): diff --git a/oslo_db/tests/sqlalchemy/test_migrate_cli.py b/oslo_db/tests/sqlalchemy/test_migrate_cli.py new file mode 100644 index 00000000..c1ab53c7 --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_migrate_cli.py @@ -0,0 +1,222 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock +from oslotest import base as test_base + +from oslo_db.sqlalchemy.migration_cli import ext_alembic +from oslo_db.sqlalchemy.migration_cli import ext_migrate +from oslo_db.sqlalchemy.migration_cli import manager + + +class MockWithCmp(mock.MagicMock): + + order = 0 + + def __init__(self, *args, **kwargs): + super(MockWithCmp, self).__init__(*args, **kwargs) + + self.__lt__ = lambda self, other: self.order < other.order + + +@mock.patch(('oslo_db.sqlalchemy.migration_cli.' + 'ext_alembic.alembic.command')) +class TestAlembicExtension(test_base.BaseTestCase): + + def setUp(self): + self.migration_config = {'alembic_ini_path': '.', + 'db_url': 'sqlite://'} + self.alembic = ext_alembic.AlembicExtension(self.migration_config) + super(TestAlembicExtension, self).setUp() + + def test_check_enabled_true(self, command): + """Check enabled returns True + + Verifies that enabled returns True on non empty + alembic_ini_path conf variable + """ + self.assertTrue(self.alembic.enabled) + + def test_check_enabled_false(self, command): + """Check enabled returns False + + Verifies enabled returns False on empty alembic_ini_path variable + """ + self.migration_config['alembic_ini_path'] = '' + alembic = ext_alembic.AlembicExtension(self.migration_config) + self.assertFalse(alembic.enabled) + + def test_upgrade_none(self, command): + self.alembic.upgrade(None) + command.upgrade.assert_called_once_with(self.alembic.config, 'head') + + def test_upgrade_normal(self, command): + self.alembic.upgrade('131daa') + command.upgrade.assert_called_once_with(self.alembic.config, '131daa') + + def test_downgrade_none(self, command): + self.alembic.downgrade(None) + command.downgrade.assert_called_once_with(self.alembic.config, 'base') + + def test_downgrade_int(self, command): + self.alembic.downgrade(111) + command.downgrade.assert_called_once_with(self.alembic.config, 'base') + + def test_downgrade_normal(self, command): + self.alembic.downgrade('131daa') + command.downgrade.assert_called_once_with( + self.alembic.config, '131daa') + + def test_revision(self, command): + self.alembic.revision(message='test', autogenerate=True) + command.revision.assert_called_once_with( + self.alembic.config, message='test', autogenerate=True) + + def test_stamp(self, command): + self.alembic.stamp('stamp') + command.stamp.assert_called_once_with( + self.alembic.config, revision='stamp') + + def test_version(self, command): + version = self.alembic.version() + self.assertIsNone(version) + + +@mock.patch(('oslo_db.sqlalchemy.migration_cli.' + 'ext_migrate.migration')) +class TestMigrateExtension(test_base.BaseTestCase): + + def setUp(self): + self.migration_config = {'migration_repo_path': '.', + 'db_url': 'sqlite://'} + self.migrate = ext_migrate.MigrateExtension(self.migration_config) + super(TestMigrateExtension, self).setUp() + + def test_check_enabled_true(self, migration): + self.assertTrue(self.migrate.enabled) + + def test_check_enabled_false(self, migration): + self.migration_config['migration_repo_path'] = '' + migrate = ext_migrate.MigrateExtension(self.migration_config) + self.assertFalse(migrate.enabled) + + def test_upgrade_head(self, migration): + self.migrate.upgrade('head') + migration.db_sync.assert_called_once_with( + self.migrate.engine, self.migrate.repository, None, init_version=0) + + def test_upgrade_normal(self, migration): + self.migrate.upgrade(111) + migration.db_sync.assert_called_once_with( + mock.ANY, self.migrate.repository, 111, init_version=0) + + def test_downgrade_init_version_from_base(self, migration): + self.migrate.downgrade('base') + migration.db_sync.assert_called_once_with( + self.migrate.engine, self.migrate.repository, mock.ANY, + init_version=mock.ANY) + + def test_downgrade_init_version_from_none(self, migration): + self.migrate.downgrade(None) + migration.db_sync.assert_called_once_with( + self.migrate.engine, self.migrate.repository, mock.ANY, + init_version=mock.ANY) + + def test_downgrade_normal(self, migration): + self.migrate.downgrade(101) + migration.db_sync.assert_called_once_with( + self.migrate.engine, self.migrate.repository, 101, init_version=0) + + def test_version(self, migration): + self.migrate.version() + migration.db_version.assert_called_once_with( + self.migrate.engine, self.migrate.repository, init_version=0) + + def test_change_init_version(self, migration): + self.migration_config['init_version'] = 101 + migrate = ext_migrate.MigrateExtension(self.migration_config) + migrate.downgrade(None) + migration.db_sync.assert_called_once_with( + migrate.engine, + self.migrate.repository, + self.migration_config['init_version'], + init_version=self.migration_config['init_version']) + + +class TestMigrationManager(test_base.BaseTestCase): + + def setUp(self): + self.migration_config = {'alembic_ini_path': '.', + 'migrate_repo_path': '.', + 'db_url': 'sqlite://'} + self.migration_manager = manager.MigrationManager( + self.migration_config) + self.ext = mock.Mock() + self.migration_manager._manager.extensions = [self.ext] + super(TestMigrationManager, self).setUp() + + def test_manager_update(self): + self.migration_manager.upgrade('head') + self.ext.obj.upgrade.assert_called_once_with('head') + + def test_manager_update_revision_none(self): + self.migration_manager.upgrade(None) + self.ext.obj.upgrade.assert_called_once_with(None) + + def test_downgrade_normal_revision(self): + self.migration_manager.downgrade('111abcd') + self.ext.obj.downgrade.assert_called_once_with('111abcd') + + def test_version(self): + self.migration_manager.version() + self.ext.obj.version.assert_called_once_with() + + def test_revision_message_autogenerate(self): + self.migration_manager.revision('test', True) + self.ext.obj.revision.assert_called_once_with('test', True) + + def test_revision_only_message(self): + self.migration_manager.revision('test', False) + self.ext.obj.revision.assert_called_once_with('test', False) + + def test_stamp(self): + self.migration_manager.stamp('stamp') + self.ext.obj.stamp.assert_called_once_with('stamp') + + +class TestMigrationRightOrder(test_base.BaseTestCase): + + def setUp(self): + self.migration_config = {'alembic_ini_path': '.', + 'migrate_repo_path': '.', + 'db_url': 'sqlite://'} + self.migration_manager = manager.MigrationManager( + self.migration_config) + self.first_ext = MockWithCmp() + self.first_ext.obj.order = 1 + self.first_ext.obj.upgrade.return_value = 100 + self.first_ext.obj.downgrade.return_value = 0 + self.second_ext = MockWithCmp() + self.second_ext.obj.order = 2 + self.second_ext.obj.upgrade.return_value = 200 + self.second_ext.obj.downgrade.return_value = 100 + self.migration_manager._manager.extensions = [self.first_ext, + self.second_ext] + super(TestMigrationRightOrder, self).setUp() + + def test_upgrade_right_order(self): + results = self.migration_manager.upgrade(None) + self.assertEqual(results, [100, 200]) + + def test_downgrade_right_order(self): + results = self.migration_manager.downgrade(None) + self.assertEqual(results, [100, 0]) diff --git a/tests/sqlalchemy/test_migration_common.py b/oslo_db/tests/sqlalchemy/test_migration_common.py similarity index 98% rename from tests/sqlalchemy/test_migration_common.py rename to oslo_db/tests/sqlalchemy/test_migration_common.py index 86f5d3f6..95efab13 100644 --- a/tests/sqlalchemy/test_migration_common.py +++ b/oslo_db/tests/sqlalchemy/test_migration_common.py @@ -22,10 +22,10 @@ from migrate.versioning import api as versioning_api import mock import sqlalchemy -from oslo.db import exception as db_exception -from oslo.db.sqlalchemy import migration -from oslo.db.sqlalchemy import test_base -from tests import utils as test_utils +from oslo_db import exception as db_exception +from oslo_db.sqlalchemy import migration +from oslo_db.sqlalchemy import test_base +from oslo_db.tests import utils as test_utils class TestMigrationCommon(test_base.DbTestCase): diff --git a/oslo_db/tests/sqlalchemy/test_migrations.py b/oslo_db/tests/sqlalchemy/test_migrations.py new file mode 100644 index 00000000..a372d8b9 --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_migrations.py @@ -0,0 +1,309 @@ +# Copyright 2010-2011 OpenStack Foundation +# Copyright 2012-2013 IBM Corp. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import fixtures +import mock +from oslotest import base as test +import six +import sqlalchemy as sa +import sqlalchemy.ext.declarative as sa_decl + +from oslo_db import exception as exc +from oslo_db.sqlalchemy import test_base +from oslo_db.sqlalchemy import test_migrations as migrate + + +class TestWalkVersions(test.BaseTestCase, migrate.WalkVersionsMixin): + migration_api = mock.MagicMock() + REPOSITORY = mock.MagicMock() + engine = mock.MagicMock() + INIT_VERSION = 4 + + @property + def migrate_engine(self): + return self.engine + + def test_migrate_up(self): + self.migration_api.db_version.return_value = 141 + + self.migrate_up(141) + + self.migration_api.upgrade.assert_called_with( + self.engine, self.REPOSITORY, 141) + self.migration_api.db_version.assert_called_with( + self.engine, self.REPOSITORY) + + def test_migrate_up_fail(self): + version = 141 + self.migration_api.db_version.return_value = version + expected_output = (u"Failed to migrate to version %(version)s on " + "engine %(engine)s\n" % + {'version': version, 'engine': self.engine}) + + with mock.patch.object(self.migration_api, + 'upgrade', + side_effect=exc.DbMigrationError): + log = self.useFixture(fixtures.FakeLogger()) + self.assertRaises(exc.DbMigrationError, self.migrate_up, version) + self.assertEqual(expected_output, log.output) + + def test_migrate_up_with_data(self): + test_value = {"a": 1, "b": 2} + self.migration_api.db_version.return_value = 141 + self._pre_upgrade_141 = mock.MagicMock() + self._pre_upgrade_141.return_value = test_value + self._check_141 = mock.MagicMock() + + self.migrate_up(141, True) + + self._pre_upgrade_141.assert_called_with(self.engine) + self._check_141.assert_called_with(self.engine, test_value) + + def test_migrate_down(self): + self.migration_api.db_version.return_value = 42 + + self.assertTrue(self.migrate_down(42)) + self.migration_api.db_version.assert_called_with( + self.engine, self.REPOSITORY) + + def test_migrate_down_not_implemented(self): + with mock.patch.object(self.migration_api, + 'downgrade', + side_effect=NotImplementedError): + self.assertFalse(self.migrate_down(self.engine, 42)) + + def test_migrate_down_with_data(self): + self._post_downgrade_043 = mock.MagicMock() + self.migration_api.db_version.return_value = 42 + + self.migrate_down(42, True) + + self._post_downgrade_043.assert_called_with(self.engine) + + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_up') + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_down') + def test_walk_versions_all_default(self, migrate_up, migrate_down): + self.REPOSITORY.latest = 20 + self.migration_api.db_version.return_value = self.INIT_VERSION + + self.walk_versions() + + self.migration_api.version_control.assert_called_with( + self.engine, self.REPOSITORY, self.INIT_VERSION) + self.migration_api.db_version.assert_called_with( + self.engine, self.REPOSITORY) + + versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1) + upgraded = [mock.call(v, with_data=True) + for v in versions] + self.assertEqual(self.migrate_up.call_args_list, upgraded) + + downgraded = [mock.call(v - 1) for v in reversed(versions)] + self.assertEqual(self.migrate_down.call_args_list, downgraded) + + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_up') + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_down') + def test_walk_versions_all_true(self, migrate_up, migrate_down): + self.REPOSITORY.latest = 20 + self.migration_api.db_version.return_value = self.INIT_VERSION + + self.walk_versions(snake_walk=True, downgrade=True) + + versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1) + upgraded = [] + for v in versions: + upgraded.append(mock.call(v, with_data=True)) + upgraded.append(mock.call(v)) + upgraded.extend([mock.call(v) for v in reversed(versions)]) + self.assertEqual(upgraded, self.migrate_up.call_args_list) + + downgraded_1 = [mock.call(v - 1, with_data=True) for v in versions] + downgraded_2 = [] + for v in reversed(versions): + downgraded_2.append(mock.call(v - 1)) + downgraded_2.append(mock.call(v - 1)) + downgraded = downgraded_1 + downgraded_2 + self.assertEqual(self.migrate_down.call_args_list, downgraded) + + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_up') + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_down') + def test_walk_versions_true_false(self, migrate_up, migrate_down): + self.REPOSITORY.latest = 20 + self.migration_api.db_version.return_value = self.INIT_VERSION + + self.walk_versions(snake_walk=True, downgrade=False) + + versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1) + + upgraded = [] + for v in versions: + upgraded.append(mock.call(v, with_data=True)) + upgraded.append(mock.call(v)) + self.assertEqual(upgraded, self.migrate_up.call_args_list) + + downgraded = [mock.call(v - 1, with_data=True) for v in versions] + self.assertEqual(self.migrate_down.call_args_list, downgraded) + + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_up') + @mock.patch.object(migrate.WalkVersionsMixin, 'migrate_down') + def test_walk_versions_all_false(self, migrate_up, migrate_down): + self.REPOSITORY.latest = 20 + self.migration_api.db_version.return_value = self.INIT_VERSION + + self.walk_versions(snake_walk=False, downgrade=False) + + versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1) + + upgraded = [mock.call(v, with_data=True) for v in versions] + self.assertEqual(upgraded, self.migrate_up.call_args_list) + + +class ModelsMigrationSyncMixin(test.BaseTestCase): + + def setUp(self): + super(ModelsMigrationSyncMixin, self).setUp() + + self.metadata = sa.MetaData() + self.metadata_migrations = sa.MetaData() + + sa.Table( + 'testtbl', self.metadata_migrations, + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('spam', sa.String(10), nullable=False), + sa.Column('eggs', sa.DateTime), + sa.Column('foo', sa.Boolean, + server_default=sa.sql.expression.true()), + sa.Column('bool_wo_default', sa.Boolean), + sa.Column('bar', sa.Numeric(10, 5)), + sa.Column('defaulttest', sa.Integer, server_default='5'), + sa.Column('defaulttest2', sa.String(8), server_default=''), + sa.Column('defaulttest3', sa.String(5), server_default="test"), + sa.Column('defaulttest4', sa.Enum('first', 'second', + name='testenum'), + server_default="first"), + sa.Column('fk_check', sa.String(36), nullable=False), + sa.UniqueConstraint('spam', 'eggs', name='uniq_cons'), + ) + + BASE = sa_decl.declarative_base(metadata=self.metadata) + + class TestModel(BASE): + __tablename__ = 'testtbl' + __table_args__ = ( + sa.UniqueConstraint('spam', 'eggs', name='uniq_cons'), + ) + + id = sa.Column('id', sa.Integer, primary_key=True) + spam = sa.Column('spam', sa.String(10), nullable=False) + eggs = sa.Column('eggs', sa.DateTime) + foo = sa.Column('foo', sa.Boolean, + server_default=sa.sql.expression.true()) + fk_check = sa.Column('fk_check', sa.String(36), nullable=False) + bool_wo_default = sa.Column('bool_wo_default', sa.Boolean) + defaulttest = sa.Column('defaulttest', + sa.Integer, server_default='5') + defaulttest2 = sa.Column('defaulttest2', sa.String(8), + server_default='') + defaulttest3 = sa.Column('defaulttest3', sa.String(5), + server_default="test") + defaulttest4 = sa.Column('defaulttest4', sa.Enum('first', 'second', + name='testenum'), + server_default="first") + bar = sa.Column('bar', sa.Numeric(10, 5)) + + class ModelThatShouldNotBeCompared(BASE): + __tablename__ = 'testtbl2' + + id = sa.Column('id', sa.Integer, primary_key=True) + spam = sa.Column('spam', sa.String(10), nullable=False) + + def get_metadata(self): + return self.metadata + + def get_engine(self): + return self.engine + + def db_sync(self, engine): + self.metadata_migrations.create_all(bind=engine) + + def include_object(self, object_, name, type_, reflected, compare_to): + if type_ == 'table': + return name == 'testtbl' + else: + return True + + def _test_models_not_sync(self): + self.metadata_migrations.clear() + sa.Table( + 'table', self.metadata_migrations, + sa.Column('fk_check', sa.String(36), nullable=False), + sa.PrimaryKeyConstraint('fk_check'), + mysql_engine='InnoDB' + ) + sa.Table( + 'testtbl', self.metadata_migrations, + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('spam', sa.String(8), nullable=True), + sa.Column('eggs', sa.DateTime), + sa.Column('foo', sa.Boolean, + server_default=sa.sql.expression.false()), + sa.Column('bool_wo_default', sa.Boolean, unique=True), + sa.Column('bar', sa.BigInteger), + sa.Column('defaulttest', sa.Integer, server_default='7'), + sa.Column('defaulttest2', sa.String(8), server_default=''), + sa.Column('defaulttest3', sa.String(5), server_default="fake"), + sa.Column('defaulttest4', + sa.Enum('first', 'second', name='testenum'), + server_default="first"), + sa.Column('fk_check', sa.String(36), nullable=False), + sa.UniqueConstraint('spam', 'foo', name='uniq_cons'), + sa.ForeignKeyConstraint(['fk_check'], ['table.fk_check']), + mysql_engine='InnoDB' + ) + + msg = six.text_type(self.assertRaises(AssertionError, + self.test_models_sync)) + # NOTE(I159): Check mentioning of the table and columns. + # The log is invalid json, so we can't parse it and check it for + # full compliance. We have no guarantee of the log items ordering, + # so we can't use regexp. + self.assertTrue(msg.startswith( + 'Models and migration scripts aren\'t in sync:')) + self.assertIn('testtbl', msg) + self.assertIn('spam', msg) + self.assertIn('eggs', msg) # test that the unique constraint is added + self.assertIn('foo', msg) + self.assertIn('bar', msg) + self.assertIn('bool_wo_default', msg) + self.assertIn('defaulttest', msg) + self.assertIn('defaulttest3', msg) + self.assertIn('drop_key', msg) + + +class ModelsMigrationsSyncMysql(ModelsMigrationSyncMixin, + migrate.ModelsMigrationsSync, + test_base.MySQLOpportunisticTestCase): + + def test_models_not_sync(self): + self._test_models_not_sync() + + +class ModelsMigrationsSyncPsql(ModelsMigrationSyncMixin, + migrate.ModelsMigrationsSync, + test_base.PostgreSQLOpportunisticTestCase): + + def test_models_not_sync(self): + self._test_models_not_sync() diff --git a/oslo_db/tests/sqlalchemy/test_models.py b/oslo_db/tests/sqlalchemy/test_models.py new file mode 100644 index 00000000..4a45576f --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_models.py @@ -0,0 +1,146 @@ +# Copyright 2012 Cloudscaling Group, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import collections + +from oslotest import base as oslo_test +from sqlalchemy import Column +from sqlalchemy import Integer, String +from sqlalchemy.ext.declarative import declarative_base + +from oslo_db.sqlalchemy import models +from oslo_db.sqlalchemy import test_base + + +BASE = declarative_base() + + +class ModelBaseTest(test_base.DbTestCase): + def setUp(self): + super(ModelBaseTest, self).setUp() + self.mb = models.ModelBase() + self.ekm = ExtraKeysModel() + + def test_modelbase_has_dict_methods(self): + dict_methods = ('__getitem__', + '__setitem__', + '__contains__', + 'get', + 'update', + 'save', + 'iteritems') + for method in dict_methods: + self.assertTrue(hasattr(models.ModelBase, method), + "Method %s() is not found" % method) + + def test_modelbase_is_iterable(self): + self.assertTrue(issubclass(models.ModelBase, collections.Iterable)) + + def test_modelbase_set(self): + self.mb['world'] = 'hello' + self.assertEqual(self.mb['world'], 'hello') + + def test_modelbase_update(self): + h = {'a': '1', 'b': '2'} + self.mb.update(h) + for key in h.keys(): + self.assertEqual(self.mb[key], h[key]) + + def test_modelbase_contains(self): + mb = models.ModelBase() + h = {'a': '1', 'b': '2'} + mb.update(h) + for key in h.keys(): + # Test 'in' syntax (instead of using .assertIn) + self.assertTrue(key in mb) + + self.assertFalse('non-existent-key' in mb) + + def test_modelbase_iteritems(self): + h = {'a': '1', 'b': '2'} + expected = { + 'id': None, + 'smth': None, + 'name': 'NAME', + 'a': '1', + 'b': '2', + } + self.ekm.update(h) + self.assertEqual(dict(self.ekm.iteritems()), expected) + + def test_modelbase_iter(self): + expected = { + 'id': None, + 'smth': None, + 'name': 'NAME', + } + i = iter(self.ekm) + found_items = 0 + while True: + r = next(i, None) + if r is None: + break + self.assertEqual(expected[r[0]], r[1]) + found_items += 1 + + self.assertEqual(len(expected), found_items) + + def test_modelbase_several_iters(self): + mb = ExtraKeysModel() + it1 = iter(mb) + it2 = iter(mb) + + self.assertFalse(it1 is it2) + self.assertEqual(dict(it1), dict(mb)) + self.assertEqual(dict(it2), dict(mb)) + + def test_extra_keys_empty(self): + """Test verifies that by default extra_keys return empty list.""" + self.assertEqual(self.mb._extra_keys, []) + + def test_extra_keys_defined(self): + """Property _extra_keys will return list with attributes names.""" + self.assertEqual(self.ekm._extra_keys, ['name']) + + def test_model_with_extra_keys(self): + data = dict(self.ekm) + self.assertEqual(data, {'smth': None, + 'id': None, + 'name': 'NAME'}) + + +class ExtraKeysModel(BASE, models.ModelBase): + __tablename__ = 'test_model' + + id = Column(Integer, primary_key=True) + smth = Column(String(255)) + + @property + def name(self): + return 'NAME' + + @property + def _extra_keys(self): + return ['name'] + + +class TimestampMixinTest(oslo_test.BaseTestCase): + + def test_timestampmixin_attr(self): + methods = ('created_at', + 'updated_at') + for method in methods: + self.assertTrue(hasattr(models.TimestampMixin, method), + "Method %s() is not found" % method) diff --git a/oslo_db/tests/sqlalchemy/test_options.py b/oslo_db/tests/sqlalchemy/test_options.py new file mode 100644 index 00000000..22a6e4f3 --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_options.py @@ -0,0 +1,127 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo.config import cfg +from oslo.config import fixture as config + +from oslo_db import options +from oslo_db.tests import utils as test_utils + + +class DbApiOptionsTestCase(test_utils.BaseTestCase): + def setUp(self): + super(DbApiOptionsTestCase, self).setUp() + + config_fixture = self.useFixture(config.Config()) + self.conf = config_fixture.conf + self.conf.register_opts(options.database_opts, group='database') + self.config = config_fixture.config + + def test_deprecated_session_parameters(self): + path = self.create_tempfiles([["tmp", b"""[DEFAULT] +sql_connection=x://y.z +sql_min_pool_size=10 +sql_max_pool_size=20 +sql_max_retries=30 +sql_retry_interval=40 +sql_max_overflow=50 +sql_connection_debug=60 +sql_connection_trace=True +"""]])[0] + self.conf(['--config-file', path]) + self.assertEqual(self.conf.database.connection, 'x://y.z') + self.assertEqual(self.conf.database.min_pool_size, 10) + self.assertEqual(self.conf.database.max_pool_size, 20) + self.assertEqual(self.conf.database.max_retries, 30) + self.assertEqual(self.conf.database.retry_interval, 40) + self.assertEqual(self.conf.database.max_overflow, 50) + self.assertEqual(self.conf.database.connection_debug, 60) + self.assertEqual(self.conf.database.connection_trace, True) + + def test_session_parameters(self): + path = self.create_tempfiles([["tmp", b"""[database] +connection=x://y.z +min_pool_size=10 +max_pool_size=20 +max_retries=30 +retry_interval=40 +max_overflow=50 +connection_debug=60 +connection_trace=True +pool_timeout=7 +"""]])[0] + self.conf(['--config-file', path]) + self.assertEqual(self.conf.database.connection, 'x://y.z') + self.assertEqual(self.conf.database.min_pool_size, 10) + self.assertEqual(self.conf.database.max_pool_size, 20) + self.assertEqual(self.conf.database.max_retries, 30) + self.assertEqual(self.conf.database.retry_interval, 40) + self.assertEqual(self.conf.database.max_overflow, 50) + self.assertEqual(self.conf.database.connection_debug, 60) + self.assertEqual(self.conf.database.connection_trace, True) + self.assertEqual(self.conf.database.pool_timeout, 7) + + def test_dbapi_database_deprecated_parameters(self): + path = self.create_tempfiles([['tmp', b'[DATABASE]\n' + b'sql_connection=fake_connection\n' + b'sql_idle_timeout=100\n' + b'sql_min_pool_size=99\n' + b'sql_max_pool_size=199\n' + b'sql_max_retries=22\n' + b'reconnect_interval=17\n' + b'sqlalchemy_max_overflow=101\n' + b'sqlalchemy_pool_timeout=5\n' + ]])[0] + self.conf(['--config-file', path]) + self.assertEqual(self.conf.database.connection, 'fake_connection') + self.assertEqual(self.conf.database.idle_timeout, 100) + self.assertEqual(self.conf.database.min_pool_size, 99) + self.assertEqual(self.conf.database.max_pool_size, 199) + self.assertEqual(self.conf.database.max_retries, 22) + self.assertEqual(self.conf.database.retry_interval, 17) + self.assertEqual(self.conf.database.max_overflow, 101) + self.assertEqual(self.conf.database.pool_timeout, 5) + + def test_dbapi_database_deprecated_parameters_sql(self): + path = self.create_tempfiles([['tmp', b'[sql]\n' + b'connection=test_sql_connection\n' + b'idle_timeout=99\n' + ]])[0] + self.conf(['--config-file', path]) + self.assertEqual(self.conf.database.connection, 'test_sql_connection') + self.assertEqual(self.conf.database.idle_timeout, 99) + + def test_deprecated_dbapi_parameters(self): + path = self.create_tempfiles([['tmp', b'[DEFAULT]\n' + b'db_backend=test_123\n' + ]])[0] + + self.conf(['--config-file', path]) + self.assertEqual(self.conf.database.backend, 'test_123') + + def test_dbapi_parameters(self): + path = self.create_tempfiles([['tmp', b'[database]\n' + b'backend=test_123\n' + ]])[0] + + self.conf(['--config-file', path]) + self.assertEqual(self.conf.database.backend, 'test_123') + + def test_set_defaults(self): + conf = cfg.ConfigOpts() + + options.set_defaults(conf, + connection='sqlite:///:memory:') + + self.assertTrue(len(conf.database.items()) > 1) + self.assertEqual('sqlite:///:memory:', conf.database.connection) diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/oslo_db/tests/sqlalchemy/test_sqlalchemy.py similarity index 94% rename from tests/sqlalchemy/test_sqlalchemy.py rename to oslo_db/tests/sqlalchemy/test_sqlalchemy.py index 1c463994..84cbbcf6 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/oslo_db/tests/sqlalchemy/test_sqlalchemy.py @@ -29,11 +29,11 @@ from sqlalchemy.engine import url from sqlalchemy import Integer, String from sqlalchemy.ext.declarative import declarative_base -from oslo.db import exception -from oslo.db import options as db_options -from oslo.db.sqlalchemy import models -from oslo.db.sqlalchemy import session -from oslo.db.sqlalchemy import test_base +from oslo_db import exception +from oslo_db import options as db_options +from oslo_db.sqlalchemy import models +from oslo_db.sqlalchemy import session +from oslo_db.sqlalchemy import test_base BASE = declarative_base() @@ -300,8 +300,8 @@ class EngineFacadeTestCase(oslo_test.BaseTestCase): self.assertFalse(ses.autocommit) self.assertTrue(ses.expire_on_commit) - @mock.patch('oslo.db.sqlalchemy.session.get_maker') - @mock.patch('oslo.db.sqlalchemy.session.create_engine') + @mock.patch('oslo_db.sqlalchemy.session.get_maker') + @mock.patch('oslo_db.sqlalchemy.session.create_engine') def test_creation_from_config(self, create_engine, get_maker): conf = cfg.ConfigOpts() conf.register_opts(db_options.database_opts, group='database') @@ -633,17 +633,23 @@ class CreateEngineTest(oslo_test.BaseTestCase): ) -class PatchStacktraceTest(test_base.DbTestCase): +# NOTE(dhellmann): This test no longer works as written. The code in +# oslo_db.sqlalchemy.session filters out lines from modules under +# oslo_db, and now this test is under oslo_db, so the test filename +# does not appear in the context for the error message. LP #1405376 - def test_trace(self): - engine = self.engine - session._add_trace_comments(engine) - conn = engine.connect() - with mock.patch.object(engine.dialect, "do_execute") as mock_exec: +# class PatchStacktraceTest(test_base.DbTestCase): - conn.execute("select * from table") +# def test_trace(self): +# engine = self.engine +# session._add_trace_comments(engine) +# conn = engine.connect() +# with mock.patch.object(engine.dialect, "do_execute") as mock_exec: - call = mock_exec.mock_calls[0] +# conn.execute("select * from table") - # we're the caller, see that we're in there - self.assertTrue("tests/sqlalchemy/test_sqlalchemy.py" in call[1][1]) +# call = mock_exec.mock_calls[0] + +# # we're the caller, see that we're in there +# self.assertIn("oslo_db/tests/sqlalchemy/test_sqlalchemy.py", +# call[1][1]) diff --git a/tests/sqlalchemy/test_utils.py b/oslo_db/tests/sqlalchemy/test_utils.py similarity index 99% rename from tests/sqlalchemy/test_utils.py rename to oslo_db/tests/sqlalchemy/test_utils.py index 78fef857..509cf48f 100644 --- a/tests/sqlalchemy/test_utils.py +++ b/oslo_db/tests/sqlalchemy/test_utils.py @@ -32,13 +32,13 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.sql import select from sqlalchemy.types import UserDefinedType, NullType -from oslo.db import exception -from oslo.db.sqlalchemy import models -from oslo.db.sqlalchemy import provision -from oslo.db.sqlalchemy import session -from oslo.db.sqlalchemy import test_base as db_test_base -from oslo.db.sqlalchemy import utils -from tests import utils as test_utils +from oslo_db import exception +from oslo_db.sqlalchemy import models +from oslo_db.sqlalchemy import provision +from oslo_db.sqlalchemy import session +from oslo_db.sqlalchemy import test_base as db_test_base +from oslo_db.sqlalchemy import utils +from oslo_db.tests import utils as test_utils SA_VERSION = tuple(map(int, sqlalchemy.__version__.split('.'))) @@ -738,8 +738,8 @@ class TestUtils(db_test_base.DbTestCase): self.assertRaises(ValueError, utils.drop_index, self.engine, 'test_table', 'new_index') - @mock.patch('oslo.db.sqlalchemy.utils.drop_index') - @mock.patch('oslo.db.sqlalchemy.utils.add_index') + @mock.patch('oslo_db.sqlalchemy.utils.drop_index') + @mock.patch('oslo_db.sqlalchemy.utils.add_index') def test_change_index_columns(self, add_index, drop_index): utils.change_index_columns(self.engine, 'test_table', 'a_index', ('a',)) diff --git a/oslo_db/tests/test_api.py b/oslo_db/tests/test_api.py new file mode 100644 index 00000000..5874a019 --- /dev/null +++ b/oslo_db/tests/test_api.py @@ -0,0 +1,177 @@ +# Copyright (c) 2013 Rackspace Hosting +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Unit tests for DB API.""" + +import mock +from oslo.config import cfg +from oslo.utils import importutils + +from oslo_db import api +from oslo_db import exception +from oslo_db.tests import utils as test_utils + +sqla = importutils.try_import('sqlalchemy') +if not sqla: + raise ImportError("Unable to import module 'sqlalchemy'.") + + +def get_backend(): + return DBAPI() + + +class DBAPI(object): + def _api_raise(self, *args, **kwargs): + """Simulate raising a database-has-gone-away error + + This method creates a fake OperationalError with an ID matching + a valid MySQL "database has gone away" situation. It also decrements + the error_counter so that we can artificially keep track of + how many times this function is called by the wrapper. When + error_counter reaches zero, this function returns True, simulating + the database becoming available again and the query succeeding. + """ + + if self.error_counter > 0: + self.error_counter -= 1 + orig = sqla.exc.DBAPIError(False, False, False) + orig.args = [2006, 'Test raise operational error'] + e = exception.DBConnectionError(orig) + raise e + else: + return True + + def api_raise_default(self, *args, **kwargs): + return self._api_raise(*args, **kwargs) + + @api.safe_for_db_retry + def api_raise_enable_retry(self, *args, **kwargs): + return self._api_raise(*args, **kwargs) + + def api_class_call1(_self, *args, **kwargs): + return args, kwargs + + +class DBAPITestCase(test_utils.BaseTestCase): + def test_dbapi_full_path_module_method(self): + dbapi = api.DBAPI('oslo_db.tests.test_api') + result = dbapi.api_class_call1(1, 2, kwarg1='meow') + expected = ((1, 2), {'kwarg1': 'meow'}) + self.assertEqual(expected, result) + + def test_dbapi_unknown_invalid_backend(self): + self.assertRaises(ImportError, api.DBAPI, 'tests.unit.db.not_existent') + + def test_dbapi_lazy_loading(self): + dbapi = api.DBAPI('oslo_db.tests.test_api', lazy=True) + + self.assertIsNone(dbapi._backend) + dbapi.api_class_call1(1, 'abc') + self.assertIsNotNone(dbapi._backend) + + def test_dbapi_from_config(self): + conf = cfg.ConfigOpts() + + dbapi = api.DBAPI.from_config(conf, + backend_mapping={'sqlalchemy': __name__}) + self.assertIsNotNone(dbapi._backend) + + +class DBReconnectTestCase(DBAPITestCase): + def setUp(self): + super(DBReconnectTestCase, self).setUp() + + self.test_db_api = DBAPI() + patcher = mock.patch(__name__ + '.get_backend', + return_value=self.test_db_api) + patcher.start() + self.addCleanup(patcher.stop) + + def test_raise_connection_error(self): + self.dbapi = api.DBAPI('sqlalchemy', {'sqlalchemy': __name__}) + + self.test_db_api.error_counter = 5 + self.assertRaises(exception.DBConnectionError, self.dbapi._api_raise) + + def test_raise_connection_error_decorated(self): + self.dbapi = api.DBAPI('sqlalchemy', {'sqlalchemy': __name__}) + + self.test_db_api.error_counter = 5 + self.assertRaises(exception.DBConnectionError, + self.dbapi.api_raise_enable_retry) + self.assertEqual(4, self.test_db_api.error_counter, 'Unexpected retry') + + def test_raise_connection_error_enabled(self): + self.dbapi = api.DBAPI('sqlalchemy', + {'sqlalchemy': __name__}, + use_db_reconnect=True) + + self.test_db_api.error_counter = 5 + self.assertRaises(exception.DBConnectionError, + self.dbapi.api_raise_default) + self.assertEqual(4, self.test_db_api.error_counter, 'Unexpected retry') + + def test_retry_one(self): + self.dbapi = api.DBAPI('sqlalchemy', + {'sqlalchemy': __name__}, + use_db_reconnect=True, + retry_interval=1) + + try: + func = self.dbapi.api_raise_enable_retry + self.test_db_api.error_counter = 1 + self.assertTrue(func(), 'Single retry did not succeed.') + except Exception: + self.fail('Single retry raised an un-wrapped error.') + + self.assertEqual( + 0, self.test_db_api.error_counter, + 'Counter not decremented, retry logic probably failed.') + + def test_retry_two(self): + self.dbapi = api.DBAPI('sqlalchemy', + {'sqlalchemy': __name__}, + use_db_reconnect=True, + retry_interval=1, + inc_retry_interval=False) + + try: + func = self.dbapi.api_raise_enable_retry + self.test_db_api.error_counter = 2 + self.assertTrue(func(), 'Multiple retry did not succeed.') + except Exception: + self.fail('Multiple retry raised an un-wrapped error.') + + self.assertEqual( + 0, self.test_db_api.error_counter, + 'Counter not decremented, retry logic probably failed.') + + def test_retry_until_failure(self): + self.dbapi = api.DBAPI('sqlalchemy', + {'sqlalchemy': __name__}, + use_db_reconnect=True, + retry_interval=1, + inc_retry_interval=False, + max_retries=3) + + func = self.dbapi.api_raise_enable_retry + self.test_db_api.error_counter = 5 + self.assertRaises( + exception.DBError, func, + 'Retry of permanent failure did not throw DBError exception.') + + self.assertNotEqual( + 0, self.test_db_api.error_counter, + 'Retry did not stop after sql_max_retries iterations.') diff --git a/oslo_db/tests/test_concurrency.py b/oslo_db/tests/test_concurrency.py new file mode 100644 index 00000000..769f5224 --- /dev/null +++ b/oslo_db/tests/test_concurrency.py @@ -0,0 +1,108 @@ +# Copyright 2014 Mirantis.inc +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import sys + +import mock + +from oslo_db import concurrency +from oslo_db.tests import utils as test_utils + +FAKE_BACKEND_MAPPING = {'sqlalchemy': 'fake.db.sqlalchemy.api'} + + +class TpoolDbapiWrapperTestCase(test_utils.BaseTestCase): + + def setUp(self): + super(TpoolDbapiWrapperTestCase, self).setUp() + self.db_api = concurrency.TpoolDbapiWrapper( + conf=self.conf, backend_mapping=FAKE_BACKEND_MAPPING) + + # NOTE(akurilin): We are not going to add `eventlet` to `oslo_db` in + # requirements (`requirements.txt` and `test-requirements.txt`) due to + # the following reasons: + # - supporting of eventlet's thread pooling is totally optional; + # - we don't need to test `tpool.Proxy` functionality itself, + # because it's a tool from the third party library; + # - `eventlet` would prevent us from running unit tests on Python 3.x + # versions, because it doesn't support them yet. + # + # As we don't test `tpool.Proxy`, we can safely mock it in tests. + + self.proxy = mock.MagicMock() + self.eventlet = mock.MagicMock() + self.eventlet.tpool.Proxy.return_value = self.proxy + sys.modules['eventlet'] = self.eventlet + self.addCleanup(sys.modules.pop, 'eventlet', None) + + @mock.patch('oslo_db.api.DBAPI') + def test_db_api_common(self, mock_db_api): + # test context: + # CONF.database.use_tpool == False + # eventlet is installed + # expected result: + # TpoolDbapiWrapper should wrap DBAPI + + fake_db_api = mock.MagicMock() + mock_db_api.from_config.return_value = fake_db_api + + # get access to some db-api method + self.db_api.fake_call_1 + + mock_db_api.from_config.assert_called_once_with( + conf=self.conf, backend_mapping=FAKE_BACKEND_MAPPING) + self.assertEqual(self.db_api._db_api, fake_db_api) + self.assertFalse(self.eventlet.tpool.Proxy.called) + + # get access to other db-api method to be sure that api didn't changed + self.db_api.fake_call_2 + + self.assertEqual(self.db_api._db_api, fake_db_api) + self.assertFalse(self.eventlet.tpool.Proxy.called) + self.assertEqual(1, mock_db_api.from_config.call_count) + + @mock.patch('oslo_db.api.DBAPI') + def test_db_api_config_change(self, mock_db_api): + # test context: + # CONF.database.use_tpool == True + # eventlet is installed + # expected result: + # TpoolDbapiWrapper should wrap tpool proxy + + fake_db_api = mock.MagicMock() + mock_db_api.from_config.return_value = fake_db_api + self.conf.set_override('use_tpool', True, group='database') + + # get access to some db-api method + self.db_api.fake_call + + # CONF.database.use_tpool is True, so we get tpool proxy in this case + mock_db_api.from_config.assert_called_once_with( + conf=self.conf, backend_mapping=FAKE_BACKEND_MAPPING) + self.eventlet.tpool.Proxy.assert_called_once_with(fake_db_api) + self.assertEqual(self.db_api._db_api, self.proxy) + + @mock.patch('oslo_db.api.DBAPI') + def test_db_api_without_installed_eventlet(self, mock_db_api): + # test context: + # CONF.database.use_tpool == True + # eventlet is not installed + # expected result: + # raise ImportError + + self.conf.set_override('use_tpool', True, group='database') + sys.modules['eventlet'] = None + + self.assertRaises(ImportError, getattr, self.db_api, 'fake') diff --git a/oslo_db/tests/utils.py b/oslo_db/tests/utils.py new file mode 100644 index 00000000..44eb1aeb --- /dev/null +++ b/oslo_db/tests/utils.py @@ -0,0 +1,40 @@ +# Copyright 2010-2011 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import contextlib + +from oslo.config import cfg +from oslotest import base as test_base +from oslotest import moxstubout +import six + + +if six.PY3: + @contextlib.contextmanager + def nested(*contexts): + with contextlib.ExitStack() as stack: + yield [stack.enter_context(c) for c in contexts] +else: + nested = contextlib.nested + + +class BaseTestCase(test_base.BaseTestCase): + def setUp(self, conf=cfg.CONF): + super(BaseTestCase, self).setUp() + moxfixture = self.useFixture(moxstubout.MoxStubout()) + self.mox = moxfixture.mox + self.stubs = moxfixture.stubs + self.conf = conf + self.addCleanup(self.conf.reset) diff --git a/setup.cfg b/setup.cfg index 119e2a80..c4ca694b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,17 +21,18 @@ classifier = packages = oslo oslo.db + oslo_db namespace_packages = oslo [entry_points] oslo.config.opts = - oslo.db = oslo.db.options:list_opts - oslo.db.concurrency = oslo.db.concurrency:list_opts + oslo.db = oslo_db.options:list_opts + oslo.db.concurrency = oslo_db.concurrency:list_opts oslo.db.migration = - alembic = oslo.db.sqlalchemy.migration_cli.ext_alembic:AlembicExtension - migrate = oslo.db.sqlalchemy.migration_cli.ext_migrate:MigrateExtension + alembic = oslo_db.sqlalchemy.migration_cli.ext_alembic:AlembicExtension + migrate = oslo_db.sqlalchemy.migration_cli.ext_migrate:MigrateExtension [build_sphinx] source-dir = doc/source diff --git a/tools/run_cross_tests.sh b/tools/run_cross_tests.sh index 5e7bc118..ec2b1c79 100755 --- a/tools/run_cross_tests.sh +++ b/tools/run_cross_tests.sh @@ -36,6 +36,11 @@ tox_envbin=$project_dir/.tox/$venv/bin our_name=$(python setup.py --name) +# Build the egg-info, including the source file list, +# so we install all of the files, even if the package +# list or name has changed. +python setup.py egg_info + # Replace the pip-installed package with the version in our source # tree. Look to see if we are already installed before trying to # uninstall ourselves, to avoid failures from packages that do not use us diff --git a/tox.ini b/tox.ini index e1fa3e0e..79111163 100644 --- a/tox.ini +++ b/tox.ini @@ -59,4 +59,4 @@ exclude=.venv,.git,.tox,dist,doc,*openstack/common*,*lib/python*,*egg,build [hacking] import_exceptions = - oslo.db._i18n + oslo_db._i18n