diff --git a/designate/exceptions.py b/designate/exceptions.py index b018c5b6b..7d800c13d 100644 --- a/designate/exceptions.py +++ b/designate/exceptions.py @@ -87,6 +87,10 @@ class NetworkEndpointNotFound(BadRequest): error_code = 403 +class MarkerNotFound(BadRequest): + error_type = 'marker_not_found' + + class InvalidOperation(BadRequest): error_code = 400 error_type = 'invalid_operation' diff --git a/designate/openstack/common/db/__init__.py b/designate/openstack/common/db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/designate/openstack/common/db/api.py b/designate/openstack/common/db/api.py new file mode 100644 index 000000000..e8ca92dab --- /dev/null +++ b/designate/openstack/common/db/api.py @@ -0,0 +1,147 @@ +# Copyright (c) 2013 Rackspace Hosting +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Multiple DB API backend support. + +A DB backend module should implement a method named 'get_backend' which +takes no arguments. The method can return any object that implements DB +API methods. +""" + +import functools +import logging +import time + +from designate.openstack.common.db import exception +from designate.openstack.common.gettextutils import _ # noqa +from designate.openstack.common import importutils + + +LOG = logging.getLogger(__name__) + + +def safe_for_db_retry(f): + """Enable db-retry for decorated function, if config option enabled.""" + f.__dict__['enable_retry'] = True + return f + + +class wrap_db_retry(object): + """Retry db.api methods, if DBConnectionError() raised + + Retry decorated db.api methods. If we enabled `use_db_reconnect` + in config, this decorator will be applied to all db.api functions, + marked with @safe_for_db_retry decorator. + Decorator catchs DBConnectionError() and retries function in a + loop until it succeeds, or until maximum retries count will be reached. + """ + + def __init__(self, retry_interval, max_retries, inc_retry_interval, + max_retry_interval): + super(wrap_db_retry, self).__init__() + + self.retry_interval = retry_interval + self.max_retries = max_retries + self.inc_retry_interval = inc_retry_interval + self.max_retry_interval = max_retry_interval + + def __call__(self, f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + next_interval = self.retry_interval + remaining = self.max_retries + + while True: + try: + return f(*args, **kwargs) + except exception.DBConnectionError as e: + if remaining == 0: + LOG.exception(_('DB exceeded retry limit.')) + raise exception.DBError(e) + if remaining != -1: + remaining -= 1 + LOG.exception(_('DB connection error.')) + # NOTE(vsergeyev): We are using patched time module, so + # this effectively yields the execution + # context to another green thread. + time.sleep(next_interval) + if self.inc_retry_interval: + next_interval = min( + next_interval * 2, + self.max_retry_interval + ) + return wrapper + + +class DBAPI(object): + def __init__(self, backend_name, backend_mapping=None, **kwargs): + """Initialize the choosen DB API backend. + + :param backend_name: name of the backend to load + :type backend_name: str + + :param backend_mapping: backend name -> module/class to load mapping + :type backend_mapping: dict + + Keyword arguments: + + :keyword use_db_reconnect: retry DB transactions on disconnect or not + :type use_db_reconnect: bool + + :keyword retry_interval: seconds between transaction retries + :type retry_interval: int + + :keyword inc_retry_interval: increase retry interval or not + :type inc_retry_interval: bool + + :keyword max_retry_interval: max interval value between retries + :type max_retry_interval: int + + :keyword max_retries: max number of retries before an error is raised + :type max_retries: int + + """ + + if backend_mapping is None: + backend_mapping = {} + + # Import the untranslated name if we don't have a + # mapping. + backend_path = backend_mapping.get(backend_name, backend_name) + backend_mod = importutils.import_module(backend_path) + self.__backend = backend_mod.get_backend() + + self.use_db_reconnect = kwargs.get('use_db_reconnect', False) + self.retry_interval = kwargs.get('retry_interval', 1) + self.inc_retry_interval = kwargs.get('inc_retry_interval', True) + self.max_retry_interval = kwargs.get('max_retry_interval', 10) + self.max_retries = kwargs.get('max_retries', 20) + + def __getattr__(self, key): + attr = getattr(self.__backend, key) + + if not hasattr(attr, '__call__'): + return attr + # NOTE(vsergeyev): If `use_db_reconnect` option is set to True, retry + # DB API methods, decorated with @safe_for_db_retry + # on disconnect. + if self.use_db_reconnect and hasattr(attr, 'enable_retry'): + attr = wrap_db_retry( + retry_interval=self.retry_interval, + max_retries=self.max_retries, + inc_retry_interval=self.inc_retry_interval, + max_retry_interval=self.max_retry_interval)(attr) + + return attr diff --git a/designate/openstack/common/db/exception.py b/designate/openstack/common/db/exception.py new file mode 100644 index 000000000..fd3ba2464 --- /dev/null +++ b/designate/openstack/common/db/exception.py @@ -0,0 +1,56 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""DB related custom exceptions.""" + +import six + +from designate.openstack.common.gettextutils import _ + + +class DBError(Exception): + """Wraps an implementation specific exception.""" + def __init__(self, inner_exception=None): + self.inner_exception = inner_exception + super(DBError, self).__init__(six.text_type(inner_exception)) + + +class DBDuplicateEntry(DBError): + """Wraps an implementation specific exception.""" + def __init__(self, columns=[], inner_exception=None): + self.columns = columns + super(DBDuplicateEntry, self).__init__(inner_exception) + + +class DBDeadlock(DBError): + def __init__(self, inner_exception=None): + super(DBDeadlock, self).__init__(inner_exception) + + +class DBInvalidUnicodeParameter(Exception): + message = _("Invalid Parameter: " + "Unicode is not supported by the current database.") + + +class DbMigrationError(DBError): + """Wraps migration specific exception.""" + def __init__(self, message=None): + super(DbMigrationError, self).__init__(message) + + +class DBConnectionError(DBError): + """Wraps connection specific exception.""" + pass diff --git a/designate/openstack/common/db/options.py b/designate/openstack/common/db/options.py new file mode 100644 index 000000000..6e15a0f5c --- /dev/null +++ b/designate/openstack/common/db/options.py @@ -0,0 +1,177 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy +import os + +from oslo.config import cfg + + +sqlite_db_opts = [ + cfg.StrOpt('sqlite_db', + default='designate.sqlite', + help='The file name to use with SQLite'), + cfg.BoolOpt('sqlite_synchronous', + default=True, + help='If True, SQLite uses synchronous mode'), +] + +database_opts = [ + cfg.StrOpt('backend', + default='sqlalchemy', + deprecated_name='db_backend', + deprecated_group='DEFAULT', + help='The backend to use for db'), + cfg.StrOpt('connection', + default='sqlite:///' + + os.path.abspath(os.path.join(os.path.dirname(__file__), + '../', '$sqlite_db')), + help='The SQLAlchemy connection string used to connect to the ' + 'database', + secret=True, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_connection', + group='DATABASE'), + cfg.DeprecatedOpt('connection', + group='sql'), ]), + cfg.IntOpt('idle_timeout', + default=3600, + deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_idle_timeout', + group='DATABASE'), + cfg.DeprecatedOpt('idle_timeout', + group='sql')], + help='Timeout before idle sql connections are reaped'), + cfg.IntOpt('min_pool_size', + default=1, + deprecated_opts=[cfg.DeprecatedOpt('sql_min_pool_size', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_min_pool_size', + group='DATABASE')], + help='Minimum number of SQL connections to keep open in a ' + 'pool'), + cfg.IntOpt('max_pool_size', + default=None, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_pool_size', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_max_pool_size', + group='DATABASE')], + help='Maximum number of SQL connections to keep open in a ' + 'pool'), + cfg.IntOpt('max_retries', + default=10, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_retries', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_max_retries', + group='DATABASE')], + help='Maximum db connection retries during startup. ' + '(setting -1 implies an infinite retry count)'), + cfg.IntOpt('retry_interval', + default=10, + deprecated_opts=[cfg.DeprecatedOpt('sql_retry_interval', + group='DEFAULT'), + cfg.DeprecatedOpt('reconnect_interval', + group='DATABASE')], + help='Interval between retries of opening a sql connection'), + cfg.IntOpt('max_overflow', + default=None, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_overflow', + group='DEFAULT'), + cfg.DeprecatedOpt('sqlalchemy_max_overflow', + group='DATABASE')], + help='If set, use this value for max_overflow with sqlalchemy'), + cfg.IntOpt('connection_debug', + default=0, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection_debug', + group='DEFAULT')], + help='Verbosity of SQL debugging information. 0=None, ' + '100=Everything'), + cfg.BoolOpt('connection_trace', + default=False, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection_trace', + group='DEFAULT')], + help='Add python stack traces to SQL as comment strings'), + cfg.IntOpt('pool_timeout', + default=None, + deprecated_opts=[cfg.DeprecatedOpt('sqlalchemy_pool_timeout', + group='DATABASE')], + help='If set, use this value for pool_timeout with sqlalchemy'), + cfg.BoolOpt('use_db_reconnect', + default=False, + help='Enable the experimental use of database reconnect ' + 'on connection lost'), + cfg.IntOpt('db_retry_interval', + default=1, + help='seconds between db connection retries'), + cfg.BoolOpt('db_inc_retry_interval', + default=True, + help='Whether to increase interval between db connection ' + 'retries, up to db_max_retry_interval'), + cfg.IntOpt('db_max_retry_interval', + default=10, + help='max seconds between db connection retries, if ' + 'db_inc_retry_interval is enabled'), + cfg.IntOpt('db_max_retries', + default=20, + help='maximum db connection retries before error is raised. ' + '(setting -1 implies an infinite retry count)'), +] + +CONF = cfg.CONF +CONF.register_opts(sqlite_db_opts) +CONF.register_opts(database_opts, 'database') + + +def set_defaults(sql_connection, sqlite_db, max_pool_size=None, + max_overflow=None, pool_timeout=None): + """Set defaults for configuration variables.""" + cfg.set_defaults(database_opts, + connection=sql_connection) + cfg.set_defaults(sqlite_db_opts, + sqlite_db=sqlite_db) + # Update the QueuePool defaults + if max_pool_size is not None: + cfg.set_defaults(database_opts, + max_pool_size=max_pool_size) + if max_overflow is not None: + cfg.set_defaults(database_opts, + max_overflow=max_overflow) + if pool_timeout is not None: + cfg.set_defaults(database_opts, + pool_timeout=pool_timeout) + + +_opts = [ + (None, sqlite_db_opts), + ('database', database_opts), +] + + +def list_opts(): + """Returns a list of oslo.config options available in the library. + + The returned list includes all oslo.config options which may be registered + at runtime by the library. + + Each element of the list is a tuple. The first element is the name of the + group under which the list of elements in the second element will be + registered. A group name of None corresponds to the [DEFAULT] group in + config files. + + The purpose of this is to allow tools like the Oslo sample config file + generator to discover the options exposed to users by this library. + + :returns: a list of (group_name, opts) tuples + """ + return [(g, copy.deepcopy(o)) for g, o in _opts] diff --git a/designate/openstack/common/db/sqlalchemy/__init__.py b/designate/openstack/common/db/sqlalchemy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/designate/openstack/common/db/sqlalchemy/migration.py b/designate/openstack/common/db/sqlalchemy/migration.py new file mode 100644 index 000000000..c88fa6146 --- /dev/null +++ b/designate/openstack/common/db/sqlalchemy/migration.py @@ -0,0 +1,268 @@ +# coding: utf-8 +# +# Copyright (c) 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# Base on code in migrate/changeset/databases/sqlite.py which is under +# the following license: +# +# The MIT License +# +# Copyright (c) 2009 Evan Rosson, Jan Dittberner, Domen Kožar +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import os +import re + +from migrate.changeset import ansisql +from migrate.changeset.databases import sqlite +from migrate import exceptions as versioning_exceptions +from migrate.versioning import api as versioning_api +from migrate.versioning.repository import Repository +import sqlalchemy +from sqlalchemy.schema import UniqueConstraint + +from designate.openstack.common.db import exception +from designate.openstack.common.gettextutils import _ + + +def _get_unique_constraints(self, table): + """Retrieve information about existing unique constraints of the table + + This feature is needed for _recreate_table() to work properly. + Unfortunately, it's not available in sqlalchemy 0.7.x/0.8.x. + + """ + + data = table.metadata.bind.execute( + """SELECT sql + FROM sqlite_master + WHERE + type='table' AND + name=:table_name""", + table_name=table.name + ).fetchone()[0] + + UNIQUE_PATTERN = "CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)" + return [ + UniqueConstraint( + *[getattr(table.columns, c.strip(' "')) for c in cols.split(",")], + name=name + ) + for name, cols in re.findall(UNIQUE_PATTERN, data) + ] + + +def _recreate_table(self, table, column=None, delta=None, omit_uniques=None): + """Recreate the table properly + + Unlike the corresponding original method of sqlalchemy-migrate this one + doesn't drop existing unique constraints when creating a new one. + + """ + + table_name = self.preparer.format_table(table) + + # we remove all indexes so as not to have + # problems during copy and re-create + for index in table.indexes: + index.drop() + + # reflect existing unique constraints + for uc in self._get_unique_constraints(table): + table.append_constraint(uc) + # omit given unique constraints when creating a new table if required + table.constraints = set([ + cons for cons in table.constraints + if omit_uniques is None or cons.name not in omit_uniques + ]) + + self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name) + self.execute() + + insertion_string = self._modify_table(table, column, delta) + + table.create(bind=self.connection) + self.append(insertion_string % {'table_name': table_name}) + self.execute() + self.append('DROP TABLE migration_tmp') + self.execute() + + +def _visit_migrate_unique_constraint(self, *p, **k): + """Drop the given unique constraint + + The corresponding original method of sqlalchemy-migrate just + raises NotImplemented error + + """ + + self.recreate_table(p[0].table, omit_uniques=[p[0].name]) + + +def patch_migrate(): + """A workaround for SQLite's inability to alter things + + SQLite abilities to alter tables are very limited (please read + http://www.sqlite.org/lang_altertable.html for more details). + E. g. one can't drop a column or a constraint in SQLite. The + workaround for this is to recreate the original table omitting + the corresponding constraint (or column). + + sqlalchemy-migrate library has recreate_table() method that + implements this workaround, but it does it wrong: + + - information about unique constraints of a table + is not retrieved. So if you have a table with one + unique constraint and a migration adding another one + you will end up with a table that has only the + latter unique constraint, and the former will be lost + + - dropping of unique constraints is not supported at all + + The proper way to fix this is to provide a pull-request to + sqlalchemy-migrate, but the project seems to be dead. So we + can go on with monkey-patching of the lib at least for now. + + """ + + # this patch is needed to ensure that recreate_table() doesn't drop + # existing unique constraints of the table when creating a new one + helper_cls = sqlite.SQLiteHelper + helper_cls.recreate_table = _recreate_table + helper_cls._get_unique_constraints = _get_unique_constraints + + # this patch is needed to be able to drop existing unique constraints + constraint_cls = sqlite.SQLiteConstraintDropper + constraint_cls.visit_migrate_unique_constraint = \ + _visit_migrate_unique_constraint + constraint_cls.__bases__ = (ansisql.ANSIColumnDropper, + sqlite.SQLiteConstraintGenerator) + + +def db_sync(engine, abs_path, version=None, init_version=0): + """Upgrade or downgrade a database. + + Function runs the upgrade() or downgrade() functions in change scripts. + + :param engine: SQLAlchemy engine instance for a given database + :param abs_path: Absolute path to migrate repository. + :param version: Database will upgrade/downgrade until this version. + If None - database will update to the latest + available version. + :param init_version: Initial database version + """ + if version is not None: + try: + version = int(version) + except ValueError: + raise exception.DbMigrationError( + message=_("version should be an integer")) + + current_version = db_version(engine, abs_path, init_version) + repository = _find_migrate_repo(abs_path) + _db_schema_sanity_check(engine) + if version is None or version > current_version: + return versioning_api.upgrade(engine, repository, version) + else: + return versioning_api.downgrade(engine, repository, + version) + + +def _db_schema_sanity_check(engine): + """Ensure all database tables were created with required parameters. + + :param engine: SQLAlchemy engine instance for a given database + + """ + + if engine.name == 'mysql': + onlyutf8_sql = ('SELECT TABLE_NAME,TABLE_COLLATION ' + 'from information_schema.TABLES ' + 'where TABLE_SCHEMA=%s and ' + 'TABLE_COLLATION NOT LIKE "%%utf8%%"') + + table_names = [res[0] for res in engine.execute(onlyutf8_sql, + engine.url.database)] + if len(table_names) > 0: + raise ValueError(_('Tables "%s" have non utf8 collation, ' + 'please make sure all tables are CHARSET=utf8' + ) % ','.join(table_names)) + + +def db_version(engine, abs_path, init_version): + """Show the current version of the repository. + + :param engine: SQLAlchemy engine instance for a given database + :param abs_path: Absolute path to migrate repository + :param version: Initial database version + """ + repository = _find_migrate_repo(abs_path) + try: + return versioning_api.db_version(engine, repository) + except versioning_exceptions.DatabaseNotControlledError: + meta = sqlalchemy.MetaData() + meta.reflect(bind=engine) + tables = meta.tables + if len(tables) == 0 or 'alembic_version' in tables: + db_version_control(abs_path, init_version) + return versioning_api.db_version(engine, repository) + else: + raise exception.DbMigrationError( + message=_( + "The database is not under version control, but has " + "tables. Please stamp the current version of the schema " + "manually.")) + + +def db_version_control(engine, abs_path, version=None): + """Mark a database as under this repository's version control. + + Once a database is under version control, schema changes should + only be done via change scripts in this repository. + + :param engine: SQLAlchemy engine instance for a given database + :param abs_path: Absolute path to migrate repository + :param version: Initial database version + """ + repository = _find_migrate_repo(abs_path) + versioning_api.version_control(engine, repository, version) + return version + + +def _find_migrate_repo(abs_path): + """Get the project's change script repository + + :param abs_path: Absolute path to migrate repository + """ + if not os.path.exists(abs_path): + raise exception.DbMigrationError("Path %s not found" % abs_path) + return Repository(abs_path) diff --git a/designate/openstack/common/db/sqlalchemy/models.py b/designate/openstack/common/db/sqlalchemy/models.py new file mode 100644 index 000000000..01296da49 --- /dev/null +++ b/designate/openstack/common/db/sqlalchemy/models.py @@ -0,0 +1,115 @@ +# Copyright (c) 2011 X.commerce, a business unit of eBay Inc. +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2011 Piston Cloud Computing, Inc. +# Copyright 2012 Cloudscaling Group, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +SQLAlchemy models. +""" + +import six + +from sqlalchemy import Column, Integer +from sqlalchemy import DateTime +from sqlalchemy.orm import object_mapper + +from designate.openstack.common import timeutils + + +class ModelBase(object): + """Base class for models.""" + __table_initialized__ = False + + def save(self, session): + """Save this object.""" + + # NOTE(boris-42): This part of code should be look like: + # session.add(self) + # session.flush() + # But there is a bug in sqlalchemy and eventlet that + # raises NoneType exception if there is no running + # transaction and rollback is called. As long as + # sqlalchemy has this bug we have to create transaction + # explicitly. + with session.begin(subtransactions=True): + session.add(self) + session.flush() + + def __setitem__(self, key, value): + setattr(self, key, value) + + def __getitem__(self, key): + return getattr(self, key) + + def get(self, key, default=None): + return getattr(self, key, default) + + @property + def _extra_keys(self): + """Specifies custom fields + + Subclasses can override this property to return a list + of custom fields that should be included in their dict + representation. + + For reference check tests/db/sqlalchemy/test_models.py + """ + return [] + + def __iter__(self): + columns = dict(object_mapper(self).columns).keys() + # NOTE(russellb): Allow models to specify other keys that can be looked + # up, beyond the actual db columns. An example would be the 'name' + # property for an Instance. + columns.extend(self._extra_keys) + self._i = iter(columns) + return self + + def next(self): + n = six.advance_iterator(self._i) + return n, getattr(self, n) + + def update(self, values): + """Make the model object behave like a dict.""" + for k, v in six.iteritems(values): + setattr(self, k, v) + + def iteritems(self): + """Make the model object behave like a dict. + + Includes attributes from joins. + """ + local = dict(self) + joined = dict([(k, v) for k, v in six.iteritems(self.__dict__) + if not k[0] == '_']) + local.update(joined) + return six.iteritems(local) + + +class TimestampMixin(object): + created_at = Column(DateTime, default=lambda: timeutils.utcnow()) + updated_at = Column(DateTime, onupdate=lambda: timeutils.utcnow()) + + +class SoftDeleteMixin(object): + deleted_at = Column(DateTime) + deleted = Column(Integer, default=0) + + def soft_delete(self, session): + """Mark this object as deleted.""" + self.deleted = self.id + self.deleted_at = timeutils.utcnow() + self.save(session=session) diff --git a/designate/openstack/common/db/sqlalchemy/provision.py b/designate/openstack/common/db/sqlalchemy/provision.py new file mode 100644 index 000000000..47582ec88 --- /dev/null +++ b/designate/openstack/common/db/sqlalchemy/provision.py @@ -0,0 +1,187 @@ +# Copyright 2013 Mirantis.inc +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Provision test environment for specific DB backends""" + +import argparse +import os +import random +import string + +from six import moves +import sqlalchemy + +from designate.openstack.common.db import exception as exc + + +SQL_CONNECTION = os.getenv('OS_TEST_DBAPI_ADMIN_CONNECTION', 'sqlite://') + + +def _gen_credentials(*names): + """Generate credentials.""" + auth_dict = {} + for name in names: + val = ''.join(random.choice(string.ascii_lowercase) + for i in moves.range(10)) + auth_dict[name] = val + return auth_dict + + +def _get_engine(uri=SQL_CONNECTION): + """Engine creation + + By default the uri is SQL_CONNECTION which is admin credentials. + Call the function without arguments to get admin connection. Admin + connection required to create temporary user and database for each + particular test. Otherwise use existing connection to recreate connection + to the temporary database. + """ + return sqlalchemy.create_engine(uri, poolclass=sqlalchemy.pool.NullPool) + + +def _execute_sql(engine, sql, driver): + """Initialize connection, execute sql query and close it.""" + try: + with engine.connect() as conn: + if driver == 'postgresql': + conn.connection.set_isolation_level(0) + for s in sql: + conn.execute(s) + except sqlalchemy.exc.OperationalError: + msg = ('%s does not match database admin ' + 'credentials or database does not exist.') + raise exc.DBConnectionError(msg % SQL_CONNECTION) + + +def create_database(engine): + """Provide temporary user and database for each particular test.""" + driver = engine.name + + auth = _gen_credentials('database', 'user', 'passwd') + + sqls = { + 'mysql': [ + "drop database if exists %(database)s;", + "grant all on %(database)s.* to '%(user)s'@'localhost'" + " identified by '%(passwd)s';", + "create database %(database)s;", + ], + 'postgresql': [ + "drop database if exists %(database)s;", + "drop user if exists %(user)s;", + "create user %(user)s with password '%(passwd)s';", + "create database %(database)s owner %(user)s;", + ] + } + + if driver == 'sqlite': + return 'sqlite:////tmp/%s' % auth['database'] + + try: + sql_rows = sqls[driver] + except KeyError: + raise ValueError('Unsupported RDBMS %s' % driver) + sql_query = map(lambda x: x % auth, sql_rows) + + _execute_sql(engine, sql_query, driver) + + params = auth.copy() + params['backend'] = driver + return "%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s" % params + + +def drop_database(engine, current_uri): + """Drop temporary database and user after each particular test.""" + engine = _get_engine(current_uri) + admin_engine = _get_engine() + driver = engine.name + auth = {'database': engine.url.database, 'user': engine.url.username} + + if driver == 'sqlite': + try: + os.remove(auth['database']) + except OSError: + pass + return + + sqls = { + 'mysql': [ + "drop database if exists %(database)s;", + "drop user '%(user)s'@'localhost';", + ], + 'postgresql': [ + "drop database if exists %(database)s;", + "drop user if exists %(user)s;", + ] + } + + try: + sql_rows = sqls[driver] + except KeyError: + raise ValueError('Unsupported RDBMS %s' % driver) + sql_query = map(lambda x: x % auth, sql_rows) + + _execute_sql(admin_engine, sql_query, driver) + + +def main(): + """Controller to handle commands + + ::create: Create test user and database with random names. + ::drop: Drop user and database created by previous command. + """ + parser = argparse.ArgumentParser( + description='Controller to handle database creation and dropping' + ' commands.', + epilog='Under normal circumstances is not used directly.' + ' Used in .testr.conf to automate test database creation' + ' and dropping processes.') + subparsers = parser.add_subparsers( + help='Subcommands to manipulate temporary test databases.') + + create = subparsers.add_parser( + 'create', + help='Create temporary test ' + 'databases and users.') + create.set_defaults(which='create') + create.add_argument( + 'instances_count', + type=int, + help='Number of databases to create.') + + drop = subparsers.add_parser( + 'drop', + help='Drop temporary test databases and users.') + drop.set_defaults(which='drop') + drop.add_argument( + 'instances', + nargs='+', + help='List of databases uri to be dropped.') + + args = parser.parse_args() + + engine = _get_engine() + which = args.which + + if which == "create": + for i in range(int(args.instances_count)): + print(create_database(engine)) + elif which == "drop": + for db in args.instances: + drop_database(engine, db) + + +if __name__ == "__main__": + main() diff --git a/designate/openstack/common/db/sqlalchemy/session.py b/designate/openstack/common/db/sqlalchemy/session.py new file mode 100644 index 000000000..faaf7eb4a --- /dev/null +++ b/designate/openstack/common/db/sqlalchemy/session.py @@ -0,0 +1,809 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Session Handling for SQLAlchemy backend. + +Recommended ways to use sessions within this framework: + +* Don't use them explicitly; this is like running with ``AUTOCOMMIT=1``. + `model_query()` will implicitly use a session when called without one + supplied. This is the ideal situation because it will allow queries + to be automatically retried if the database connection is interrupted. + + .. note:: Automatic retry will be enabled in a future patch. + + It is generally fine to issue several queries in a row like this. Even though + they may be run in separate transactions and/or separate sessions, each one + will see the data from the prior calls. If needed, undo- or rollback-like + functionality should be handled at a logical level. For an example, look at + the code around quotas and `reservation_rollback()`. + + Examples: + + .. code:: python + + def get_foo(context, foo): + return (model_query(context, models.Foo). + filter_by(foo=foo). + first()) + + def update_foo(context, id, newfoo): + (model_query(context, models.Foo). + filter_by(id=id). + update({'foo': newfoo})) + + def create_foo(context, values): + foo_ref = models.Foo() + foo_ref.update(values) + foo_ref.save() + return foo_ref + + +* Within the scope of a single method, keep all the reads and writes within + the context managed by a single session. In this way, the session's + `__exit__` handler will take care of calling `flush()` and `commit()` for + you. If using this approach, you should not explicitly call `flush()` or + `commit()`. Any error within the context of the session will cause the + session to emit a `ROLLBACK`. Database errors like `IntegrityError` will be + raised in `session`'s `__exit__` handler, and any try/except within the + context managed by `session` will not be triggered. And catching other + non-database errors in the session will not trigger the ROLLBACK, so + exception handlers should always be outside the session, unless the + developer wants to do a partial commit on purpose. If the connection is + dropped before this is possible, the database will implicitly roll back the + transaction. + + .. note:: Statements in the session scope will not be automatically retried. + + If you create models within the session, they need to be added, but you + do not need to call `model.save()`: + + .. code:: python + + def create_many_foo(context, foos): + session = sessionmaker() + with session.begin(): + for foo in foos: + foo_ref = models.Foo() + foo_ref.update(foo) + session.add(foo_ref) + + def update_bar(context, foo_id, newbar): + session = sessionmaker() + with session.begin(): + foo_ref = (model_query(context, models.Foo, session). + filter_by(id=foo_id). + first()) + (model_query(context, models.Bar, session). + filter_by(id=foo_ref['bar_id']). + update({'bar': newbar})) + + .. note:: `update_bar` is a trivially simple example of using + ``with session.begin``. Whereas `create_many_foo` is a good example of + when a transaction is needed, it is always best to use as few queries as + possible. + + The two queries in `update_bar` can be better expressed using a single query + which avoids the need for an explicit transaction. It can be expressed like + so: + + .. code:: python + + def update_bar(context, foo_id, newbar): + subq = (model_query(context, models.Foo.id). + filter_by(id=foo_id). + limit(1). + subquery()) + (model_query(context, models.Bar). + filter_by(id=subq.as_scalar()). + update({'bar': newbar})) + + For reference, this emits approximately the following SQL statement: + + .. code:: sql + + UPDATE bar SET bar = ${newbar} + WHERE id=(SELECT bar_id FROM foo WHERE id = ${foo_id} LIMIT 1); + + .. note:: `create_duplicate_foo` is a trivially simple example of catching an + exception while using ``with session.begin``. Here create two duplicate + instances with same primary key, must catch the exception out of context + managed by a single session: + + .. code:: python + + def create_duplicate_foo(context): + foo1 = models.Foo() + foo2 = models.Foo() + foo1.id = foo2.id = 1 + session = sessionmaker() + try: + with session.begin(): + session.add(foo1) + session.add(foo2) + except exception.DBDuplicateEntry as e: + handle_error(e) + +* Passing an active session between methods. Sessions should only be passed + to private methods. The private method must use a subtransaction; otherwise + SQLAlchemy will throw an error when you call `session.begin()` on an existing + transaction. Public methods should not accept a session parameter and should + not be involved in sessions within the caller's scope. + + Note that this incurs more overhead in SQLAlchemy than the above means + due to nesting transactions, and it is not possible to implicitly retry + failed database operations when using this approach. + + This also makes code somewhat more difficult to read and debug, because a + single database transaction spans more than one method. Error handling + becomes less clear in this situation. When this is needed for code clarity, + it should be clearly documented. + + .. code:: python + + def myfunc(foo): + session = sessionmaker() + with session.begin(): + # do some database things + bar = _private_func(foo, session) + return bar + + def _private_func(foo, session=None): + if not session: + session = sessionmaker() + with session.begin(subtransaction=True): + # do some other database things + return bar + + +There are some things which it is best to avoid: + +* Don't keep a transaction open any longer than necessary. + + This means that your ``with session.begin()`` block should be as short + as possible, while still containing all the related calls for that + transaction. + +* Avoid ``with_lockmode('UPDATE')`` when possible. + + In MySQL/InnoDB, when a ``SELECT ... FOR UPDATE`` query does not match + any rows, it will take a gap-lock. This is a form of write-lock on the + "gap" where no rows exist, and prevents any other writes to that space. + This can effectively prevent any INSERT into a table by locking the gap + at the end of the index. Similar problems will occur if the SELECT FOR UPDATE + has an overly broad WHERE clause, or doesn't properly use an index. + + One idea proposed at ODS Fall '12 was to use a normal SELECT to test the + number of rows matching a query, and if only one row is returned, + then issue the SELECT FOR UPDATE. + + The better long-term solution is to use + ``INSERT .. ON DUPLICATE KEY UPDATE``. + However, this can not be done until the "deleted" columns are removed and + proper UNIQUE constraints are added to the tables. + + +Enabling soft deletes: + +* To use/enable soft-deletes, the `SoftDeleteMixin` must be added + to your model class. For example: + + .. code:: python + + class NovaBase(models.SoftDeleteMixin, models.ModelBase): + pass + + +Efficient use of soft deletes: + +* There are two possible ways to mark a record as deleted: + `model.soft_delete()` and `query.soft_delete()`. + + The `model.soft_delete()` method works with a single already-fetched entry. + `query.soft_delete()` makes only one db request for all entries that + correspond to the query. + +* In almost all cases you should use `query.soft_delete()`. Some examples: + + .. code:: python + + def soft_delete_bar(): + count = model_query(BarModel).find(some_condition).soft_delete() + if count == 0: + raise Exception("0 entries were soft deleted") + + def complex_soft_delete_with_synchronization_bar(session=None): + if session is None: + session = sessionmaker() + with session.begin(subtransactions=True): + count = (model_query(BarModel). + find(some_condition). + soft_delete(synchronize_session=True)) + # Here synchronize_session is required, because we + # don't know what is going on in outer session. + if count == 0: + raise Exception("0 entries were soft deleted") + +* There is only one situation where `model.soft_delete()` is appropriate: when + you fetch a single record, work with it, and mark it as deleted in the same + transaction. + + .. code:: python + + def soft_delete_bar_model(): + session = sessionmaker() + with session.begin(): + bar_ref = model_query(BarModel).find(some_condition).first() + # Work with bar_ref + bar_ref.soft_delete(session=session) + + However, if you need to work with all entries that correspond to query and + then soft delete them you should use the `query.soft_delete()` method: + + .. code:: python + + def soft_delete_multi_models(): + session = sessionmaker() + with session.begin(): + query = (model_query(BarModel, session=session). + find(some_condition)) + model_refs = query.all() + # Work with model_refs + query.soft_delete(synchronize_session=False) + # synchronize_session=False should be set if there is no outer + # session and these entries are not used after this. + + When working with many rows, it is very important to use query.soft_delete, + which issues a single query. Using `model.soft_delete()`, as in the following + example, is very inefficient. + + .. code:: python + + for bar_ref in bar_refs: + bar_ref.soft_delete(session=session) + # This will produce count(bar_refs) db requests. + +""" + +import functools +import logging +import re +import time + +import six +from sqlalchemy import exc as sqla_exc +from sqlalchemy.interfaces import PoolListener +import sqlalchemy.orm +from sqlalchemy.pool import NullPool, StaticPool +from sqlalchemy.sql.expression import literal_column + +from designate.openstack.common.db import exception +from designate.openstack.common.gettextutils import _ +from designate.openstack.common import timeutils + + +LOG = logging.getLogger(__name__) + + +class SqliteForeignKeysListener(PoolListener): + """Ensures that the foreign key constraints are enforced in SQLite. + + The foreign key constraints are disabled by default in SQLite, + so the foreign key constraints will be enabled here for every + database connection + """ + def connect(self, dbapi_con, con_record): + dbapi_con.execute('pragma foreign_keys=ON') + + +# note(boris-42): In current versions of DB backends unique constraint +# violation messages follow the structure: +# +# sqlite: +# 1 column - (IntegrityError) column c1 is not unique +# N columns - (IntegrityError) column c1, c2, ..., N are not unique +# +# sqlite since 3.7.16: +# 1 column - (IntegrityError) UNIQUE constraint failed: tbl.k1 +# +# N columns - (IntegrityError) UNIQUE constraint failed: tbl.k1, tbl.k2 +# +# postgres: +# 1 column - (IntegrityError) duplicate key value violates unique +# constraint "users_c1_key" +# N columns - (IntegrityError) duplicate key value violates unique +# constraint "name_of_our_constraint" +# +# mysql: +# 1 column - (IntegrityError) (1062, "Duplicate entry 'value_of_c1' for key +# 'c1'") +# N columns - (IntegrityError) (1062, "Duplicate entry 'values joined +# with -' for key 'name_of_our_constraint'") +_DUP_KEY_RE_DB = { + "sqlite": (re.compile(r"^.*columns?([^)]+)(is|are)\s+not\s+unique$"), + re.compile(r"^.*UNIQUE\s+constraint\s+failed:\s+(.+)$")), + "postgresql": (re.compile(r"^.*duplicate\s+key.*\"([^\"]+)\"\s*\n.*$"),), + "mysql": (re.compile(r"^.*\(1062,.*'([^\']+)'\"\)$"),) +} + + +def _raise_if_duplicate_entry_error(integrity_error, engine_name): + """Raise exception if two entries are duplicated. + + In this function will be raised DBDuplicateEntry exception if integrity + error wrap unique constraint violation. + """ + + def get_columns_from_uniq_cons_or_name(columns): + # note(vsergeyev): UniqueConstraint name convention: "uniq_t0c10c2" + # where `t` it is table name and columns `c1`, `c2` + # are in UniqueConstraint. + uniqbase = "uniq_" + if not columns.startswith(uniqbase): + if engine_name == "postgresql": + return [columns[columns.index("_") + 1:columns.rindex("_")]] + return [columns] + return columns[len(uniqbase):].split("0")[1:] + + if engine_name not in ["mysql", "sqlite", "postgresql"]: + return + + # FIXME(johannes): The usage of the .message attribute has been + # deprecated since Python 2.6. However, the exceptions raised by + # SQLAlchemy can differ when using unicode() and accessing .message. + # An audit across all three supported engines will be necessary to + # ensure there are no regressions. + for pattern in _DUP_KEY_RE_DB[engine_name]: + match = pattern.match(integrity_error.message) + if match: + break + else: + return + + columns = match.group(1) + + if engine_name == "sqlite": + columns = [c.split('.')[-1] for c in columns.strip().split(", ")] + else: + columns = get_columns_from_uniq_cons_or_name(columns) + raise exception.DBDuplicateEntry(columns, integrity_error) + + +# NOTE(comstud): In current versions of DB backends, Deadlock violation +# messages follow the structure: +# +# mysql: +# (OperationalError) (1213, 'Deadlock found when trying to get lock; try ' +# 'restarting transaction') +_DEADLOCK_RE_DB = { + "mysql": re.compile(r"^.*\(1213, 'Deadlock.*") +} + + +def _raise_if_deadlock_error(operational_error, engine_name): + """Raise exception on deadlock condition. + + Raise DBDeadlock exception if OperationalError contains a Deadlock + condition. + """ + re = _DEADLOCK_RE_DB.get(engine_name) + if re is None: + return + # FIXME(johannes): The usage of the .message attribute has been + # deprecated since Python 2.6. However, the exceptions raised by + # SQLAlchemy can differ when using unicode() and accessing .message. + # An audit across all three supported engines will be necessary to + # ensure there are no regressions. + m = re.match(operational_error.message) + if not m: + return + raise exception.DBDeadlock(operational_error) + + +def _wrap_db_error(f): + #TODO(rpodolyaka): in a subsequent commit make this a class decorator to + # ensure it can only applied to Session subclasses instances (as we use + # Session instance bind attribute below) + + @functools.wraps(f) + def _wrap(self, *args, **kwargs): + try: + return f(self, *args, **kwargs) + except UnicodeEncodeError: + raise exception.DBInvalidUnicodeParameter() + except sqla_exc.OperationalError as e: + _raise_if_db_connection_lost(e, self.bind) + _raise_if_deadlock_error(e, self.bind.dialect.name) + # NOTE(comstud): A lot of code is checking for OperationalError + # so let's not wrap it for now. + raise + # note(boris-42): We should catch unique constraint violation and + # wrap it by our own DBDuplicateEntry exception. Unique constraint + # violation is wrapped by IntegrityError. + except sqla_exc.IntegrityError as e: + # note(boris-42): SqlAlchemy doesn't unify errors from different + # DBs so we must do this. Also in some tables (for example + # instance_types) there are more than one unique constraint. This + # means we should get names of columns, which values violate + # unique constraint, from error message. + _raise_if_duplicate_entry_error(e, self.bind.dialect.name) + raise exception.DBError(e) + except Exception as e: + LOG.exception(_('DB exception wrapped.')) + raise exception.DBError(e) + return _wrap + + +def _synchronous_switch_listener(dbapi_conn, connection_rec): + """Switch sqlite connections to non-synchronous mode.""" + dbapi_conn.execute("PRAGMA synchronous = OFF") + + +def _add_regexp_listener(dbapi_con, con_record): + """Add REGEXP function to sqlite connections.""" + + def regexp(expr, item): + reg = re.compile(expr) + return reg.search(six.text_type(item)) is not None + dbapi_con.create_function('regexp', 2, regexp) + + +def _thread_yield(dbapi_con, con_record): + """Ensure other greenthreads get a chance to be executed. + + If we use eventlet.monkey_patch(), eventlet.greenthread.sleep(0) will + execute instead of time.sleep(0). + Force a context switch. With common database backends (eg MySQLdb and + sqlite), there is no implicit yield caused by network I/O since they are + implemented by C libraries that eventlet cannot monkey patch. + """ + time.sleep(0) + + +def _ping_listener(engine, dbapi_conn, connection_rec, connection_proxy): + """Ensures that MySQL and DB2 connections are alive. + + Borrowed from: + http://groups.google.com/group/sqlalchemy/msg/a4ce563d802c929f + """ + cursor = dbapi_conn.cursor() + try: + ping_sql = 'select 1' + if engine.name == 'ibm_db_sa': + # DB2 requires a table expression + ping_sql = 'select 1 from (values (1)) AS t1' + cursor.execute(ping_sql) + except Exception as ex: + if engine.dialect.is_disconnect(ex, dbapi_conn, cursor): + msg = _('Database server has gone away: %s') % ex + LOG.warning(msg) + raise sqla_exc.DisconnectionError(msg) + else: + raise + + +def _set_mode_traditional(dbapi_con, connection_rec, connection_proxy): + """Set engine mode to 'traditional'. + + Required to prevent silent truncates at insert or update operations + under MySQL. By default MySQL truncates inserted string if it longer + than a declared field just with warning. That is fraught with data + corruption. + """ + dbapi_con.cursor().execute("SET SESSION sql_mode = TRADITIONAL;") + + +def _is_db_connection_error(args): + """Return True if error in connecting to db.""" + # NOTE(adam_g): This is currently MySQL specific and needs to be extended + # to support Postgres and others. + # For the db2, the error code is -30081 since the db2 is still not ready + conn_err_codes = ('2002', '2003', '2006', '2013', '-30081') + for err_code in conn_err_codes: + if args.find(err_code) != -1: + return True + return False + + +def _raise_if_db_connection_lost(error, engine): + # NOTE(vsergeyev): Function is_disconnect(e, connection, cursor) + # requires connection and cursor in incoming parameters, + # but we have no possibility to create connection if DB + # is not available, so in such case reconnect fails. + # But is_disconnect() ignores these parameters, so it + # makes sense to pass to function None as placeholder + # instead of connection and cursor. + if engine.dialect.is_disconnect(error, None, None): + raise exception.DBConnectionError(error) + + +def create_engine(sql_connection, sqlite_fk=False, + mysql_traditional_mode=False, idle_timeout=3600, + connection_debug=0, max_pool_size=None, max_overflow=None, + pool_timeout=None, sqlite_synchronous=True, + connection_trace=False, max_retries=10, retry_interval=10): + """Return a new SQLAlchemy engine.""" + + connection_dict = sqlalchemy.engine.url.make_url(sql_connection) + + engine_args = { + "pool_recycle": idle_timeout, + "echo": False, + 'convert_unicode': True, + } + + # Map our SQL debug level to SQLAlchemy's options + if connection_debug >= 100: + engine_args['echo'] = 'debug' + elif connection_debug >= 50: + engine_args['echo'] = True + + if "sqlite" in connection_dict.drivername: + if sqlite_fk: + engine_args["listeners"] = [SqliteForeignKeysListener()] + engine_args["poolclass"] = NullPool + + if sql_connection == "sqlite://": + engine_args["poolclass"] = StaticPool + engine_args["connect_args"] = {'check_same_thread': False} + else: + if max_pool_size is not None: + engine_args['pool_size'] = max_pool_size + if max_overflow is not None: + engine_args['max_overflow'] = max_overflow + if pool_timeout is not None: + engine_args['pool_timeout'] = pool_timeout + + engine = sqlalchemy.create_engine(sql_connection, **engine_args) + + sqlalchemy.event.listen(engine, 'checkin', _thread_yield) + + if engine.name in ['mysql', 'ibm_db_sa']: + callback = functools.partial(_ping_listener, engine) + sqlalchemy.event.listen(engine, 'checkout', callback) + if engine.name == 'mysql': + if mysql_traditional_mode: + sqlalchemy.event.listen(engine, 'checkout', + _set_mode_traditional) + else: + LOG.warning(_("This application has not enabled MySQL " + "traditional mode, which means silent " + "data corruption may occur. " + "Please encourage the application " + "developers to enable this mode.")) + elif 'sqlite' in connection_dict.drivername: + if not sqlite_synchronous: + sqlalchemy.event.listen(engine, 'connect', + _synchronous_switch_listener) + sqlalchemy.event.listen(engine, 'connect', _add_regexp_listener) + + if connection_trace and engine.dialect.dbapi.__name__ == 'MySQLdb': + _patch_mysqldb_with_stacktrace_comments() + + try: + engine.connect() + except sqla_exc.OperationalError as e: + if not _is_db_connection_error(e.args[0]): + raise + + remaining = max_retries + if remaining == -1: + remaining = 'infinite' + while True: + msg = _('SQL connection failed. %s attempts left.') + LOG.warning(msg % remaining) + if remaining != 'infinite': + remaining -= 1 + time.sleep(retry_interval) + try: + engine.connect() + break + except sqla_exc.OperationalError as e: + if (remaining != 'infinite' and remaining == 0) or \ + not _is_db_connection_error(e.args[0]): + raise + return engine + + +class Query(sqlalchemy.orm.query.Query): + """Subclass of sqlalchemy.query with soft_delete() method.""" + def soft_delete(self, synchronize_session='evaluate'): + return self.update({'deleted': literal_column('id'), + 'updated_at': literal_column('updated_at'), + 'deleted_at': timeutils.utcnow()}, + synchronize_session=synchronize_session) + + +class Session(sqlalchemy.orm.session.Session): + """Custom Session class to avoid SqlAlchemy Session monkey patching.""" + @_wrap_db_error + def query(self, *args, **kwargs): + return super(Session, self).query(*args, **kwargs) + + @_wrap_db_error + def flush(self, *args, **kwargs): + return super(Session, self).flush(*args, **kwargs) + + @_wrap_db_error + def execute(self, *args, **kwargs): + return super(Session, self).execute(*args, **kwargs) + + +def get_maker(engine, autocommit=True, expire_on_commit=False): + """Return a SQLAlchemy sessionmaker using the given engine.""" + return sqlalchemy.orm.sessionmaker(bind=engine, + class_=Session, + autocommit=autocommit, + expire_on_commit=expire_on_commit, + query_cls=Query) + + +def _patch_mysqldb_with_stacktrace_comments(): + """Adds current stack trace as a comment in queries. + + Patches MySQLdb.cursors.BaseCursor._do_query. + """ + import MySQLdb.cursors + import traceback + + old_mysql_do_query = MySQLdb.cursors.BaseCursor._do_query + + def _do_query(self, q): + stack = '' + for filename, line, method, function in traceback.extract_stack(): + # exclude various common things from trace + if filename.endswith('session.py') and method == '_do_query': + continue + if filename.endswith('api.py') and method == 'wrapper': + continue + if filename.endswith('utils.py') and method == '_inner': + continue + if filename.endswith('exception.py') and method == '_wrap': + continue + # db/api is just a wrapper around db/sqlalchemy/api + if filename.endswith('db/api.py'): + continue + # only trace inside designate + index = filename.rfind('designate') + if index == -1: + continue + stack += "File:%s:%s Method:%s() Line:%s | " \ + % (filename[index:], line, method, function) + + # strip trailing " | " from stack + if stack: + stack = stack[:-3] + qq = "%s /* %s */" % (q, stack) + else: + qq = q + old_mysql_do_query(self, qq) + + setattr(MySQLdb.cursors.BaseCursor, '_do_query', _do_query) + + +class EngineFacade(object): + """A helper class for removing of global engine instances from designate.db. + + As a library, designate.db can't decide where to store/when to create engine + and sessionmaker instances, so this must be left for a target application. + + On the other hand, in order to simplify the adoption of designate.db changes, + we'll provide a helper class, which creates engine and sessionmaker + on its instantiation and provides get_engine()/get_session() methods + that are compatible with corresponding utility functions that currently + exist in target projects, e.g. in Nova. + + engine/sessionmaker instances will still be global (and they are meant to + be global), but they will be stored in the app context, rather that in the + designate.db context. + + Note: using of this helper is completely optional and you are encouraged to + integrate engine/sessionmaker instances into your apps any way you like + (e.g. one might want to bind a session to a request context). Two important + things to remember: + 1. An Engine instance is effectively a pool of DB connections, so it's + meant to be shared (and it's thread-safe). + 2. A Session instance is not meant to be shared and represents a DB + transactional context (i.e. it's not thread-safe). sessionmaker is + a factory of sessions. + + """ + + def __init__(self, sql_connection, + sqlite_fk=False, mysql_traditional_mode=False, + autocommit=True, expire_on_commit=False, **kwargs): + """Initialize engine and sessionmaker instances. + + :param sqlite_fk: enable foreign keys in SQLite + :type sqlite_fk: bool + + :param mysql_traditional_mode: enable traditional mode in MySQL + :type mysql_traditional_mode: bool + + :param autocommit: use autocommit mode for created Session instances + :type autocommit: bool + + :param expire_on_commit: expire session objects on commit + :type expire_on_commit: bool + + Keyword arguments: + + :keyword idle_timeout: timeout before idle sql connections are reaped + (defaults to 3600) + :keyword connection_debug: verbosity of SQL debugging information. + 0=None, 100=Everything (defaults to 0) + :keyword max_pool_size: maximum number of SQL connections to keep open + in a pool (defaults to SQLAlchemy settings) + :keyword max_overflow: if set, use this value for max_overflow with + sqlalchemy (defaults to SQLAlchemy settings) + :keyword pool_timeout: if set, use this value for pool_timeout with + sqlalchemy (defaults to SQLAlchemy settings) + :keyword sqlite_synchronous: if True, SQLite uses synchronous mode + (defaults to True) + :keyword connection_trace: add python stack traces to SQL as comment + strings (defaults to False) + :keyword max_retries: maximum db connection retries during startup. + (setting -1 implies an infinite retry count) + (defaults to 10) + :keyword retry_interval: interval between retries of opening a sql + connection (defaults to 10) + + """ + + super(EngineFacade, self).__init__() + + self._engine = create_engine( + sql_connection=sql_connection, + sqlite_fk=sqlite_fk, + mysql_traditional_mode=mysql_traditional_mode, + idle_timeout=kwargs.get('idle_timeout', 3600), + connection_debug=kwargs.get('connection_debug', 0), + max_pool_size=kwargs.get('max_pool_size', None), + max_overflow=kwargs.get('max_overflow', None), + pool_timeout=kwargs.get('pool_timeout', None), + sqlite_synchronous=kwargs.get('sqlite_synchronous', True), + connection_trace=kwargs.get('connection_trace', False), + max_retries=kwargs.get('max_retries', 10), + retry_interval=kwargs.get('retry_interval', 10)) + self._session_maker = get_maker( + engine=self._engine, + autocommit=autocommit, + expire_on_commit=expire_on_commit) + + def get_engine(self): + """Get the engine instance (note, that it's shared).""" + + return self._engine + + def get_session(self, **kwargs): + """Get a Session instance. + + If passed, keyword arguments values override the ones used when the + sessionmaker instance was created. + + :keyword autocommit: use autocommit mode for created Session instances + :type autocommit: bool + + :keyword expire_on_commit: expire session objects on commit + :type expire_on_commit: bool + + """ + + for arg in kwargs: + if arg not in ('autocommit', 'expire_on_commit'): + del kwargs[arg] + + return self._session_maker(**kwargs) diff --git a/designate/openstack/common/db/sqlalchemy/test_base.py b/designate/openstack/common/db/sqlalchemy/test_base.py new file mode 100644 index 000000000..a97875d00 --- /dev/null +++ b/designate/openstack/common/db/sqlalchemy/test_base.py @@ -0,0 +1,149 @@ +# Copyright (c) 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import abc +import functools +import os + +import fixtures +import six + +from designate.openstack.common.db.sqlalchemy import session +from designate.openstack.common.db.sqlalchemy import utils +from designate.openstack.common import test + + +class DbFixture(fixtures.Fixture): + """Basic database fixture. + + Allows to run tests on various db backends, such as SQLite, MySQL and + PostgreSQL. By default use sqlite backend. To override default backend + uri set env variable OS_TEST_DBAPI_CONNECTION with database admin + credentials for specific backend. + """ + + def _get_uri(self): + return os.getenv('OS_TEST_DBAPI_CONNECTION', 'sqlite://') + + def __init__(self, test): + super(DbFixture, self).__init__() + + self.test = test + + def setUp(self): + super(DbFixture, self).setUp() + + self.test.engine = session.create_engine(self._get_uri()) + self.test.sessionmaker = session.get_maker(self.test.engine) + self.addCleanup(self.test.engine.dispose) + + +class DbTestCase(test.BaseTestCase): + """Base class for testing of DB code. + + Using `DbFixture`. Intended to be the main database test case to use all + the tests on a given backend with user defined uri. Backend specific + tests should be decorated with `backend_specific` decorator. + """ + + FIXTURE = DbFixture + + def setUp(self): + super(DbTestCase, self).setUp() + self.useFixture(self.FIXTURE(self)) + + +ALLOWED_DIALECTS = ['sqlite', 'mysql', 'postgresql'] + + +def backend_specific(*dialects): + """Decorator to skip backend specific tests on inappropriate engines. + + ::dialects: list of dialects names under which the test will be launched. + """ + def wrap(f): + @functools.wraps(f) + def ins_wrap(self): + if not set(dialects).issubset(ALLOWED_DIALECTS): + raise ValueError( + "Please use allowed dialects: %s" % ALLOWED_DIALECTS) + if self.engine.name not in dialects: + msg = ('The test "%s" can be run ' + 'only on %s. Current engine is %s.') + args = (f.__name__, ' '.join(dialects), self.engine.name) + self.skip(msg % args) + else: + return f(self) + return ins_wrap + return wrap + + +@six.add_metaclass(abc.ABCMeta) +class OpportunisticFixture(DbFixture): + """Base fixture to use default CI databases. + + The databases exist in OpenStack CI infrastructure. But for the + correct functioning in local environment the databases must be + created manually. + """ + + DRIVER = abc.abstractproperty(lambda: None) + DBNAME = PASSWORD = USERNAME = 'openstack_citest' + + def _get_uri(self): + return utils.get_connect_string(backend=self.DRIVER, + user=self.USERNAME, + passwd=self.PASSWORD, + database=self.DBNAME) + + +@six.add_metaclass(abc.ABCMeta) +class OpportunisticTestCase(DbTestCase): + """Base test case to use default CI databases. + + The subclasses of the test case are running only when openstack_citest + database is available otherwise a tests will be skipped. + """ + + FIXTURE = abc.abstractproperty(lambda: None) + + def setUp(self): + credentials = { + 'backend': self.FIXTURE.DRIVER, + 'user': self.FIXTURE.USERNAME, + 'passwd': self.FIXTURE.PASSWORD, + 'database': self.FIXTURE.DBNAME} + + if self.FIXTURE.DRIVER and not utils.is_backend_avail(**credentials): + msg = '%s backend is not available.' % self.FIXTURE.DRIVER + return self.skip(msg) + + super(OpportunisticTestCase, self).setUp() + + +class MySQLOpportunisticFixture(OpportunisticFixture): + DRIVER = 'mysql' + + +class PostgreSQLOpportunisticFixture(OpportunisticFixture): + DRIVER = 'postgresql' + + +class MySQLOpportunisticTestCase(OpportunisticTestCase): + FIXTURE = MySQLOpportunisticFixture + + +class PostgreSQLOpportunisticTestCase(OpportunisticTestCase): + FIXTURE = PostgreSQLOpportunisticFixture diff --git a/designate/openstack/common/db/sqlalchemy/test_migrations.py b/designate/openstack/common/db/sqlalchemy/test_migrations.py new file mode 100644 index 000000000..5e2e278fa --- /dev/null +++ b/designate/openstack/common/db/sqlalchemy/test_migrations.py @@ -0,0 +1,269 @@ +# Copyright 2010-2011 OpenStack Foundation +# Copyright 2012-2013 IBM Corp. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import logging +import os +import subprocess + +import lockfile +from six import moves +from six.moves.urllib import parse +import sqlalchemy +import sqlalchemy.exc + +from designate.openstack.common.db.sqlalchemy import utils +from designate.openstack.common.gettextutils import _ +from designate.openstack.common import test + +LOG = logging.getLogger(__name__) + + +def _have_mysql(user, passwd, database): + present = os.environ.get('TEST_MYSQL_PRESENT') + if present is None: + return utils.is_backend_avail(backend='mysql', + user=user, + passwd=passwd, + database=database) + return present.lower() in ('', 'true') + + +def _have_postgresql(user, passwd, database): + present = os.environ.get('TEST_POSTGRESQL_PRESENT') + if present is None: + return utils.is_backend_avail(backend='postgres', + user=user, + passwd=passwd, + database=database) + return present.lower() in ('', 'true') + + +def _set_db_lock(lock_path=None, lock_prefix=None): + def decorator(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + try: + path = lock_path or os.environ.get("DESIGNATE_LOCK_PATH") + lock = lockfile.FileLock(os.path.join(path, lock_prefix)) + with lock: + LOG.debug(_('Got lock "%s"') % f.__name__) + return f(*args, **kwargs) + finally: + LOG.debug(_('Lock released "%s"') % f.__name__) + return wrapper + return decorator + + +class BaseMigrationTestCase(test.BaseTestCase): + """Base class fort testing of migration utils.""" + + def __init__(self, *args, **kwargs): + super(BaseMigrationTestCase, self).__init__(*args, **kwargs) + + self.DEFAULT_CONFIG_FILE = os.path.join(os.path.dirname(__file__), + 'test_migrations.conf') + # Test machines can set the TEST_MIGRATIONS_CONF variable + # to override the location of the config file for migration testing + self.CONFIG_FILE_PATH = os.environ.get('TEST_MIGRATIONS_CONF', + self.DEFAULT_CONFIG_FILE) + self.test_databases = {} + self.migration_api = None + + def setUp(self): + super(BaseMigrationTestCase, self).setUp() + + # Load test databases from the config file. Only do this + # once. No need to re-run this on each test... + LOG.debug('config_path is %s' % self.CONFIG_FILE_PATH) + if os.path.exists(self.CONFIG_FILE_PATH): + cp = moves.configparser.RawConfigParser() + try: + cp.read(self.CONFIG_FILE_PATH) + defaults = cp.defaults() + for key, value in defaults.items(): + self.test_databases[key] = value + except moves.configparser.ParsingError as e: + self.fail("Failed to read test_migrations.conf config " + "file. Got error: %s" % e) + else: + self.fail("Failed to find test_migrations.conf config " + "file.") + + self.engines = {} + for key, value in self.test_databases.items(): + self.engines[key] = sqlalchemy.create_engine(value) + + # We start each test case with a completely blank slate. + self._reset_databases() + + def tearDown(self): + # We destroy the test data store between each test case, + # and recreate it, which ensures that we have no side-effects + # from the tests + self._reset_databases() + super(BaseMigrationTestCase, self).tearDown() + + def execute_cmd(self, cmd=None): + process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + output = process.communicate()[0] + LOG.debug(output) + self.assertEqual(0, process.returncode, + "Failed to run: %s\n%s" % (cmd, output)) + + def _reset_pg(self, conn_pieces): + (user, + password, + database, + host) = utils.get_db_connection_info(conn_pieces) + os.environ['PGPASSWORD'] = password + os.environ['PGUSER'] = user + # note(boris-42): We must create and drop database, we can't + # drop database which we have connected to, so for such + # operations there is a special database template1. + sqlcmd = ("psql -w -U %(user)s -h %(host)s -c" + " '%(sql)s' -d template1") + + sql = ("drop database if exists %s;") % database + droptable = sqlcmd % {'user': user, 'host': host, 'sql': sql} + self.execute_cmd(droptable) + + sql = ("create database %s;") % database + createtable = sqlcmd % {'user': user, 'host': host, 'sql': sql} + self.execute_cmd(createtable) + + os.unsetenv('PGPASSWORD') + os.unsetenv('PGUSER') + + @_set_db_lock(lock_prefix='migration_tests-') + def _reset_databases(self): + for key, engine in self.engines.items(): + conn_string = self.test_databases[key] + conn_pieces = parse.urlparse(conn_string) + engine.dispose() + if conn_string.startswith('sqlite'): + # We can just delete the SQLite database, which is + # the easiest and cleanest solution + db_path = conn_pieces.path.strip('/') + if os.path.exists(db_path): + os.unlink(db_path) + # No need to recreate the SQLite DB. SQLite will + # create it for us if it's not there... + elif conn_string.startswith('mysql'): + # We can execute the MySQL client to destroy and re-create + # the MYSQL database, which is easier and less error-prone + # than using SQLAlchemy to do this via MetaData...trust me. + (user, password, database, host) = \ + utils.get_db_connection_info(conn_pieces) + sql = ("drop database if exists %(db)s; " + "create database %(db)s;") % {'db': database} + cmd = ("mysql -u \"%(user)s\" -p\"%(password)s\" -h %(host)s " + "-e \"%(sql)s\"") % {'user': user, 'password': password, + 'host': host, 'sql': sql} + self.execute_cmd(cmd) + elif conn_string.startswith('postgresql'): + self._reset_pg(conn_pieces) + + +class WalkVersionsMixin(object): + def _walk_versions(self, engine=None, snake_walk=False, downgrade=True): + # Determine latest version script from the repo, then + # upgrade from 1 through to the latest, with no data + # in the databases. This just checks that the schema itself + # upgrades successfully. + + # Place the database under version control + self.migration_api.version_control(engine, self.REPOSITORY, + self.INIT_VERSION) + self.assertEqual(self.INIT_VERSION, + self.migration_api.db_version(engine, + self.REPOSITORY)) + + LOG.debug('latest version is %s' % self.REPOSITORY.latest) + versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1) + + for version in versions: + # upgrade -> downgrade -> upgrade + self._migrate_up(engine, version, with_data=True) + if snake_walk: + downgraded = self._migrate_down( + engine, version - 1, with_data=True) + if downgraded: + self._migrate_up(engine, version) + + if downgrade: + # Now walk it back down to 0 from the latest, testing + # the downgrade paths. + for version in reversed(versions): + # downgrade -> upgrade -> downgrade + downgraded = self._migrate_down(engine, version - 1) + + if snake_walk and downgraded: + self._migrate_up(engine, version) + self._migrate_down(engine, version - 1) + + def _migrate_down(self, engine, version, with_data=False): + try: + self.migration_api.downgrade(engine, self.REPOSITORY, version) + except NotImplementedError: + # NOTE(sirp): some migrations, namely release-level + # migrations, don't support a downgrade. + return False + + self.assertEqual( + version, self.migration_api.db_version(engine, self.REPOSITORY)) + + # NOTE(sirp): `version` is what we're downgrading to (i.e. the 'target' + # version). So if we have any downgrade checks, they need to be run for + # the previous (higher numbered) migration. + if with_data: + post_downgrade = getattr( + self, "_post_downgrade_%03d" % (version + 1), None) + if post_downgrade: + post_downgrade(engine) + + return True + + def _migrate_up(self, engine, version, with_data=False): + """migrate up to a new version of the db. + + We allow for data insertion and post checks at every + migration version with special _pre_upgrade_### and + _check_### functions in the main test. + """ + # NOTE(sdague): try block is here because it's impossible to debug + # where a failed data migration happens otherwise + try: + if with_data: + data = None + pre_upgrade = getattr( + self, "_pre_upgrade_%03d" % version, None) + if pre_upgrade: + data = pre_upgrade(engine) + + self.migration_api.upgrade(engine, self.REPOSITORY, version) + self.assertEqual(version, + self.migration_api.db_version(engine, + self.REPOSITORY)) + if with_data: + check = getattr(self, "_check_%03d" % version, None) + if check: + check(engine, data) + except Exception: + LOG.error("Failed to migrate to version %s on engine %s" % + (version, engine)) + raise diff --git a/designate/openstack/common/db/sqlalchemy/utils.py b/designate/openstack/common/db/sqlalchemy/utils.py new file mode 100644 index 000000000..563d5bab8 --- /dev/null +++ b/designate/openstack/common/db/sqlalchemy/utils.py @@ -0,0 +1,547 @@ +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2010-2011 OpenStack Foundation. +# Copyright 2012 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import logging +import re + +from migrate.changeset import UniqueConstraint +import sqlalchemy +from sqlalchemy import Boolean +from sqlalchemy import CheckConstraint +from sqlalchemy import Column +from sqlalchemy.engine import reflection +from sqlalchemy.ext.compiler import compiles +from sqlalchemy import func +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy.sql.expression import literal_column +from sqlalchemy.sql.expression import UpdateBase +from sqlalchemy.sql import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy.types import NullType + +from designate.openstack.common.gettextutils import _ +from designate.openstack.common import timeutils + + +LOG = logging.getLogger(__name__) + +_DBURL_REGEX = re.compile(r"[^:]+://([^:]+):([^@]+)@.+") + + +def sanitize_db_url(url): + match = _DBURL_REGEX.match(url) + if match: + return '%s****:****%s' % (url[:match.start(1)], url[match.end(2):]) + return url + + +class InvalidSortKey(Exception): + message = _("Sort key supplied was not valid.") + + +# copy from glance/db/sqlalchemy/api.py +def paginate_query(query, model, limit, sort_keys, marker=None, + sort_dir=None, sort_dirs=None): + """Returns a query with sorting / pagination criteria added. + + Pagination works by requiring a unique sort_key, specified by sort_keys. + (If sort_keys is not unique, then we risk looping through values.) + We use the last row in the previous page as the 'marker' for pagination. + So we must return values that follow the passed marker in the order. + With a single-valued sort_key, this would be easy: sort_key > X. + With a compound-values sort_key, (k1, k2, k3) we must do this to repeat + the lexicographical ordering: + (k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3) + + We also have to cope with different sort_directions. + + Typically, the id of the last row is used as the client-facing pagination + marker, then the actual marker object must be fetched from the db and + passed in to us as marker. + + :param query: the query object to which we should add paging/sorting + :param model: the ORM model class + :param limit: maximum number of items to return + :param sort_keys: array of attributes by which results should be sorted + :param marker: the last item of the previous page; we returns the next + results after this value. + :param sort_dir: direction in which results should be sorted (asc, desc) + :param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys + + :rtype: sqlalchemy.orm.query.Query + :return: The query with sorting/pagination added. + """ + + if 'id' not in sort_keys: + # TODO(justinsb): If this ever gives a false-positive, check + # the actual primary key, rather than assuming its id + LOG.warning(_('Id not in sort_keys; is sort_keys unique?')) + + assert(not (sort_dir and sort_dirs)) + + # Default the sort direction to ascending + if sort_dirs is None and sort_dir is None: + sort_dir = 'asc' + + # Ensure a per-column sort direction + if sort_dirs is None: + sort_dirs = [sort_dir for _sort_key in sort_keys] + + assert(len(sort_dirs) == len(sort_keys)) + + # Add sorting + for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs): + try: + sort_dir_func = { + 'asc': sqlalchemy.asc, + 'desc': sqlalchemy.desc, + }[current_sort_dir] + except KeyError: + raise ValueError(_("Unknown sort direction, " + "must be 'desc' or 'asc'")) + try: + sort_key_attr = getattr(model, current_sort_key) + except AttributeError: + raise InvalidSortKey() + query = query.order_by(sort_dir_func(sort_key_attr)) + + # Add pagination + if marker is not None: + marker_values = [] + for sort_key in sort_keys: + v = getattr(marker, sort_key) + marker_values.append(v) + + # Build up an array of sort criteria as in the docstring + criteria_list = [] + for i in range(len(sort_keys)): + crit_attrs = [] + for j in range(i): + model_attr = getattr(model, sort_keys[j]) + crit_attrs.append((model_attr == marker_values[j])) + + model_attr = getattr(model, sort_keys[i]) + if sort_dirs[i] == 'desc': + crit_attrs.append((model_attr < marker_values[i])) + else: + crit_attrs.append((model_attr > marker_values[i])) + + criteria = sqlalchemy.sql.and_(*crit_attrs) + criteria_list.append(criteria) + + f = sqlalchemy.sql.or_(*criteria_list) + query = query.filter(f) + + if limit is not None: + query = query.limit(limit) + + return query + + +def get_table(engine, name): + """Returns an sqlalchemy table dynamically from db. + + Needed because the models don't work for us in migrations + as models will be far out of sync with the current data. + """ + metadata = MetaData() + metadata.bind = engine + return Table(name, metadata, autoload=True) + + +class InsertFromSelect(UpdateBase): + """Form the base for `INSERT INTO table (SELECT ... )` statement.""" + def __init__(self, table, select): + self.table = table + self.select = select + + +@compiles(InsertFromSelect) +def visit_insert_from_select(element, compiler, **kw): + """Form the `INSERT INTO table (SELECT ... )` statement.""" + return "INSERT INTO %s %s" % ( + compiler.process(element.table, asfrom=True), + compiler.process(element.select)) + + +class ColumnError(Exception): + """Error raised when no column or an invalid column is found.""" + + +def _get_not_supported_column(col_name_col_instance, column_name): + try: + column = col_name_col_instance[column_name] + except KeyError: + msg = _("Please specify column %s in col_name_col_instance " + "param. It is required because column has unsupported " + "type by sqlite).") + raise ColumnError(msg % column_name) + + if not isinstance(column, Column): + msg = _("col_name_col_instance param has wrong type of " + "column instance for column %s It should be instance " + "of sqlalchemy.Column.") + raise ColumnError(msg % column_name) + return column + + +def drop_unique_constraint(migrate_engine, table_name, uc_name, *columns, + **col_name_col_instance): + """Drop unique constraint from table. + + This method drops UC from table and works for mysql, postgresql and sqlite. + In mysql and postgresql we are able to use "alter table" construction. + Sqlalchemy doesn't support some sqlite column types and replaces their + type with NullType in metadata. We process these columns and replace + NullType with the correct column type. + + :param migrate_engine: sqlalchemy engine + :param table_name: name of table that contains uniq constraint. + :param uc_name: name of uniq constraint that will be dropped. + :param columns: columns that are in uniq constraint. + :param col_name_col_instance: contains pair column_name=column_instance. + column_instance is instance of Column. These params + are required only for columns that have unsupported + types by sqlite. For example BigInteger. + """ + + meta = MetaData() + meta.bind = migrate_engine + t = Table(table_name, meta, autoload=True) + + if migrate_engine.name == "sqlite": + override_cols = [ + _get_not_supported_column(col_name_col_instance, col.name) + for col in t.columns + if isinstance(col.type, NullType) + ] + for col in override_cols: + t.columns.replace(col) + + uc = UniqueConstraint(*columns, table=t, name=uc_name) + uc.drop() + + +def drop_old_duplicate_entries_from_table(migrate_engine, table_name, + use_soft_delete, *uc_column_names): + """Drop all old rows having the same values for columns in uc_columns. + + This method drop (or mark ad `deleted` if use_soft_delete is True) old + duplicate rows form table with name `table_name`. + + :param migrate_engine: Sqlalchemy engine + :param table_name: Table with duplicates + :param use_soft_delete: If True - values will be marked as `deleted`, + if False - values will be removed from table + :param uc_column_names: Unique constraint columns + """ + meta = MetaData() + meta.bind = migrate_engine + + table = Table(table_name, meta, autoload=True) + columns_for_group_by = [table.c[name] for name in uc_column_names] + + columns_for_select = [func.max(table.c.id)] + columns_for_select.extend(columns_for_group_by) + + duplicated_rows_select = select(columns_for_select, + group_by=columns_for_group_by, + having=func.count(table.c.id) > 1) + + for row in migrate_engine.execute(duplicated_rows_select): + # NOTE(boris-42): Do not remove row that has the biggest ID. + delete_condition = table.c.id != row[0] + is_none = None # workaround for pyflakes + delete_condition &= table.c.deleted_at == is_none + for name in uc_column_names: + delete_condition &= table.c[name] == row[name] + + rows_to_delete_select = select([table.c.id]).where(delete_condition) + for row in migrate_engine.execute(rows_to_delete_select).fetchall(): + LOG.info(_("Deleting duplicated row with id: %(id)s from table: " + "%(table)s") % dict(id=row[0], table=table_name)) + + if use_soft_delete: + delete_statement = table.update().\ + where(delete_condition).\ + values({ + 'deleted': literal_column('id'), + 'updated_at': literal_column('updated_at'), + 'deleted_at': timeutils.utcnow() + }) + else: + delete_statement = table.delete().where(delete_condition) + migrate_engine.execute(delete_statement) + + +def _get_default_deleted_value(table): + if isinstance(table.c.id.type, Integer): + return 0 + if isinstance(table.c.id.type, String): + return "" + raise ColumnError(_("Unsupported id columns type")) + + +def _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes): + table = get_table(migrate_engine, table_name) + + insp = reflection.Inspector.from_engine(migrate_engine) + real_indexes = insp.get_indexes(table_name) + existing_index_names = dict( + [(index['name'], index['column_names']) for index in real_indexes]) + + # NOTE(boris-42): Restore indexes on `deleted` column + for index in indexes: + if 'deleted' not in index['column_names']: + continue + name = index['name'] + if name in existing_index_names: + column_names = [table.c[c] for c in existing_index_names[name]] + old_index = Index(name, *column_names, unique=index["unique"]) + old_index.drop(migrate_engine) + + column_names = [table.c[c] for c in index['column_names']] + new_index = Index(index["name"], *column_names, unique=index["unique"]) + new_index.create(migrate_engine) + + +def change_deleted_column_type_to_boolean(migrate_engine, table_name, + **col_name_col_instance): + if migrate_engine.name == "sqlite": + return _change_deleted_column_type_to_boolean_sqlite( + migrate_engine, table_name, **col_name_col_instance) + insp = reflection.Inspector.from_engine(migrate_engine) + indexes = insp.get_indexes(table_name) + + table = get_table(migrate_engine, table_name) + + old_deleted = Column('old_deleted', Boolean, default=False) + old_deleted.create(table, populate_default=False) + + table.update().\ + where(table.c.deleted == table.c.id).\ + values(old_deleted=True).\ + execute() + + table.c.deleted.drop() + table.c.old_deleted.alter(name="deleted") + + _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes) + + +def _change_deleted_column_type_to_boolean_sqlite(migrate_engine, table_name, + **col_name_col_instance): + insp = reflection.Inspector.from_engine(migrate_engine) + table = get_table(migrate_engine, table_name) + + columns = [] + for column in table.columns: + column_copy = None + if column.name != "deleted": + if isinstance(column.type, NullType): + column_copy = _get_not_supported_column(col_name_col_instance, + column.name) + else: + column_copy = column.copy() + else: + column_copy = Column('deleted', Boolean, default=0) + columns.append(column_copy) + + constraints = [constraint.copy() for constraint in table.constraints] + + meta = table.metadata + new_table = Table(table_name + "__tmp__", meta, + *(columns + constraints)) + new_table.create() + + indexes = [] + for index in insp.get_indexes(table_name): + column_names = [new_table.c[c] for c in index['column_names']] + indexes.append(Index(index["name"], *column_names, + unique=index["unique"])) + + c_select = [] + for c in table.c: + if c.name != "deleted": + c_select.append(c) + else: + c_select.append(table.c.deleted == table.c.id) + + ins = InsertFromSelect(new_table, select(c_select)) + migrate_engine.execute(ins) + + table.drop() + [index.create(migrate_engine) for index in indexes] + + new_table.rename(table_name) + new_table.update().\ + where(new_table.c.deleted == new_table.c.id).\ + values(deleted=True).\ + execute() + + +def change_deleted_column_type_to_id_type(migrate_engine, table_name, + **col_name_col_instance): + if migrate_engine.name == "sqlite": + return _change_deleted_column_type_to_id_type_sqlite( + migrate_engine, table_name, **col_name_col_instance) + insp = reflection.Inspector.from_engine(migrate_engine) + indexes = insp.get_indexes(table_name) + + table = get_table(migrate_engine, table_name) + + new_deleted = Column('new_deleted', table.c.id.type, + default=_get_default_deleted_value(table)) + new_deleted.create(table, populate_default=True) + + deleted = True # workaround for pyflakes + table.update().\ + where(table.c.deleted == deleted).\ + values(new_deleted=table.c.id).\ + execute() + table.c.deleted.drop() + table.c.new_deleted.alter(name="deleted") + + _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes) + + +def _change_deleted_column_type_to_id_type_sqlite(migrate_engine, table_name, + **col_name_col_instance): + # NOTE(boris-42): sqlaclhemy-migrate can't drop column with check + # constraints in sqlite DB and our `deleted` column has + # 2 check constraints. So there is only one way to remove + # these constraints: + # 1) Create new table with the same columns, constraints + # and indexes. (except deleted column). + # 2) Copy all data from old to new table. + # 3) Drop old table. + # 4) Rename new table to old table name. + insp = reflection.Inspector.from_engine(migrate_engine) + meta = MetaData(bind=migrate_engine) + table = Table(table_name, meta, autoload=True) + default_deleted_value = _get_default_deleted_value(table) + + columns = [] + for column in table.columns: + column_copy = None + if column.name != "deleted": + if isinstance(column.type, NullType): + column_copy = _get_not_supported_column(col_name_col_instance, + column.name) + else: + column_copy = column.copy() + else: + column_copy = Column('deleted', table.c.id.type, + default=default_deleted_value) + columns.append(column_copy) + + def is_deleted_column_constraint(constraint): + # NOTE(boris-42): There is no other way to check is CheckConstraint + # associated with deleted column. + if not isinstance(constraint, CheckConstraint): + return False + sqltext = str(constraint.sqltext) + return (sqltext.endswith("deleted in (0, 1)") or + sqltext.endswith("deleted IN (:deleted_1, :deleted_2)")) + + constraints = [] + for constraint in table.constraints: + if not is_deleted_column_constraint(constraint): + constraints.append(constraint.copy()) + + new_table = Table(table_name + "__tmp__", meta, + *(columns + constraints)) + new_table.create() + + indexes = [] + for index in insp.get_indexes(table_name): + column_names = [new_table.c[c] for c in index['column_names']] + indexes.append(Index(index["name"], *column_names, + unique=index["unique"])) + + ins = InsertFromSelect(new_table, table.select()) + migrate_engine.execute(ins) + + table.drop() + [index.create(migrate_engine) for index in indexes] + + new_table.rename(table_name) + deleted = True # workaround for pyflakes + new_table.update().\ + where(new_table.c.deleted == deleted).\ + values(deleted=new_table.c.id).\ + execute() + + # NOTE(boris-42): Fix value of deleted column: False -> "" or 0. + deleted = False # workaround for pyflakes + new_table.update().\ + where(new_table.c.deleted == deleted).\ + values(deleted=default_deleted_value).\ + execute() + + +def get_connect_string(backend, database, user=None, passwd=None): + """Get database connection + + Try to get a connection with a very specific set of values, if we get + these then we'll run the tests, otherwise they are skipped + """ + args = {'backend': backend, + 'user': user, + 'passwd': passwd, + 'database': database} + if backend == 'sqlite': + template = '%(backend)s:///%(database)s' + else: + template = "%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s" + return template % args + + +def is_backend_avail(backend, database, user=None, passwd=None): + try: + connect_uri = get_connect_string(backend=backend, + database=database, + user=user, + passwd=passwd) + engine = sqlalchemy.create_engine(connect_uri) + connection = engine.connect() + except Exception: + # intentionally catch all to handle exceptions even if we don't + # have any backend code loaded. + return False + else: + connection.close() + engine.dispose() + return True + + +def get_db_connection_info(conn_pieces): + database = conn_pieces.path.strip('/') + loc_pieces = conn_pieces.netloc.split('@') + host = loc_pieces[1] + + auth_pieces = loc_pieces[0].split(':') + user = auth_pieces[0] + password = "" + if len(auth_pieces) > 1: + password = auth_pieces[1].strip() + + return (user, password, database, host) diff --git a/designate/storage/api.py b/designate/storage/api.py index 083c6c5dc..3d82eac8d 100644 --- a/designate/storage/api.py +++ b/designate/storage/api.py @@ -55,14 +55,16 @@ class StorageAPI(object): """ return self.storage.get_quota(context, quota_id) - def find_quotas(self, context, criterion=None): + def find_quotas(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): """ Find Quotas :param context: RPC Context. :param criterion: Criteria to filter by. """ - return self.storage.find_quotas(context, criterion) + return self.storage.find_quotas( + context, criterion, marker, limit, sort_key, sort_dir) def find_quota(self, context, criterion): """ @@ -140,14 +142,16 @@ class StorageAPI(object): """ return self.storage.get_server(context, server_id) - def find_servers(self, context, criterion=None): + def find_servers(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): """ Find Servers :param context: RPC Context. :param criterion: Criteria to filter by. """ - return self.storage.find_servers(context, criterion) + return self.storage.find_servers( + context, criterion, marker, limit, sort_key, sort_dir) def find_server(self, context, criterion): """ @@ -225,14 +229,16 @@ class StorageAPI(object): """ return self.storage.get_tld(context, tld_id) - def find_tlds(self, context, criterion=None): + def find_tlds(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): """ Find TLDs :param context: RPC Context. :param criterion: Criteria to filter by. """ - return self.storage.find_tlds(context, criterion) + return self.storage.find_tlds( + context, criterion, marker, limit, sort_key, sort_dir) def find_tld(self, context, criterion): """ @@ -309,14 +315,16 @@ class StorageAPI(object): """ return self.storage.get_tsigkey(context, tsigkey_id) - def find_tsigkeys(self, context, criterion=None): + def find_tsigkeys(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): """ Find Tsigkey :param context: RPC Context. :param criterion: Criteria to filter by. """ - return self.storage.find_tsigkeys(context, criterion) + return self.storage.find_tsigkeys( + context, criterion, marker, limit, sort_key, sort_dir) def find_tsigkey(self, context, criterion): """ @@ -419,14 +427,16 @@ class StorageAPI(object): """ return self.storage.get_domain(context, domain_id) - def find_domains(self, context, criterion=None): + def find_domains(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): """ Find Domains :param context: RPC Context. :param criterion: Criteria to filter by. """ - return self.storage.find_domains(context, criterion) + return self.storage.find_domains( + context, criterion, marker, limit, sort_key, sort_dir) def find_domain(self, context, criterion): """ @@ -515,14 +525,16 @@ class StorageAPI(object): """ return self.storage.get_recordset(context, recordset_id) - def find_recordsets(self, context, criterion=None): + def find_recordsets(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): """ Find RecordSets. :param context: RPC Context. :param criterion: Criteria to filter by. """ - return self.storage.find_recordsets(context, criterion) + return self.storage.find_recordsets( + context, criterion, marker, limit, sort_key, sort_dir) def find_recordset(self, context, criterion=None): """ @@ -612,14 +624,16 @@ class StorageAPI(object): """ return self.storage.get_record(context, record_id) - def find_records(self, context, criterion=None): + def find_records(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): """ Find Records. :param context: RPC Context. :param criterion: Criteria to filter by. """ - return self.storage.find_records(context, criterion) + return self.storage.find_records( + context, criterion, marker, limit, sort_key, sort_dir) def find_record(self, context, criterion=None): """ @@ -705,14 +719,16 @@ class StorageAPI(object): """ return self.storage.get_blacklist(context, blacklist_id) - def find_blacklists(self, context, criterion=None): + def find_blacklists(self, context, criterion=None, marker=None, limit=None, + sort_key=None, sort_dir=None): """ Find all Blacklisted Domains :param context: RPC Context. :param criterion: Criteria to filter by. """ - return self.storage.find_blacklists(context, criterion) + return self.storage.find_blacklists( + context, criterion, marker, limit, sort_key, sort_dir) def find_blacklist(self, context, criterion): """ diff --git a/designate/storage/base.py b/designate/storage/base.py index 11f650d8c..dcb1ddbc5 100644 --- a/designate/storage/base.py +++ b/designate/storage/base.py @@ -18,6 +18,7 @@ from designate.plugin import DriverPlugin class Storage(DriverPlugin): + """ Base class for storage plugins """ __metaclass__ = abc.ABCMeta __plugin_ns__ = 'designate.storage' @@ -42,12 +43,19 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def find_quotas(self, context, criterion): + def find_quotas(self, context, criterion=None, marker=None, + limit=None, sort_key=None, sort_dir=None): """ Find Quotas :param context: RPC Context. :param criterion: Criteria to filter by. + :param marker: Resource ID from which after the requested page will + start after + :param limit: Integer limit of objects of the page size after the + marker + :param sort_key: Key from which to sort after. + :param sort_dir: Direction to sort after using sort_key. """ @abc.abstractmethod @@ -88,12 +96,19 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def find_servers(self, context, criterion=None): + def find_servers(self, context, criterion=None, marker=None, + limit=None, sort_key=None, sort_dir=None): """ Find Servers. :param context: RPC Context. :param criterion: Criteria to filter by. + :param marker: Resource ID from which after the requested page will + start after + :param limit: Integer limit of objects of the page size after the + marker + :param sort_key: Key from which to sort after. + :param sort_dir: Direction to sort after using sort_key. """ @abc.abstractmethod @@ -143,21 +158,34 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def find_tlds(self, context, criterion=None): + def find_tlds(self, context, criterion=None, marker=None, + limit=None, sort_key=None, sort_dir=None): """ Find TLDs :param context: RPC Context. :param criterion: Criteria to filter by. + :param marker: Resource ID from which after the requested page will + start after + :param limit: Integer limit of objects of the page size after the + marker + :param sort_key: Key from which to sort after. + :param sort_dir: Direction to sort after using sort_key. """ @abc.abstractmethod - def find_tld(self, context, criterion): + def find_tld(self, context, criterion=None): """ Find a single TLD. :param context: RPC Context. :param criterion: Criteria to filter by. + :param marker: Resource ID from which after the requested page will + start after + :param limit: Integer limit of objects of the page size after the + marker + :param sort_key: Key from which to sort after. + :param sort_dir: Direction to sort after using sort_key. """ @abc.abstractmethod @@ -188,12 +216,19 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def find_tsigkeys(self, context, criterion=None): + def find_tsigkeys(self, context, criterion=None, + marker=None, limit=None, sort_key=None, sort_dir=None): """ Find TSIG Keys. :param context: RPC Context. :param criterion: Criteria to filter by. + :param marker: Resource ID from which after the requested page will + start after + :param limit: Integer limit of objects of the page size after the + marker + :param sort_key: Key from which to sort after. + :param sort_dir: Direction to sort after using sort_key. """ @abc.abstractmethod @@ -268,12 +303,19 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def find_domains(self, context, criterion=None): + def find_domains(self, context, criterion=None, marker=None, + limit=None, sort_key=None, sort_dir=None): """ Find Domains :param context: RPC Context. :param criterion: Criteria to filter by. + :param marker: Resource ID from which after the requested page will + start after + :param limit: Integer limit of objects of the page size after the + marker + :param sort_key: Key from which to sort after. + :param sort_dir: Direction to sort after using sort_key. """ @abc.abstractmethod @@ -333,13 +375,20 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def find_recordsets(self, context, domain_id, criterion=None): + def find_recordsets(self, context, criterion=None, + marker=None, limit=None, sort_key=None, sort_dir=None): """ Find RecordSets. :param context: RPC Context. :param domain_id: Domain ID where the recordsets reside. :param criterion: Criteria to filter by. + :param marker: Resource ID from which after the requested page will + start after + :param limit: Integer limit of objects of the page size after the + marker + :param sort_key: Key from which to sort after. + :param sort_dir: Direction to sort after using sort_key. """ @abc.abstractmethod @@ -399,12 +448,19 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def find_records(self, context, criterion=None): + def find_records(self, context, criterion=None, marker=None, + limit=None, sort_key=None, sort_dir=None): """ Find Records. :param context: RPC Context. :param criterion: Criteria to filter by. + :param marker: Resource ID from which after the requested page will + start after + :param limit: Integer limit of objects of the page size after the + marker + :param sort_key: Key from which to sort after. + :param sort_dir: Direction to sort after using sort_key. """ @abc.abstractmethod @@ -462,12 +518,19 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def find_blacklists(self, context, criterion): + def find_blacklists(self, context, criterion=None, marker=None, + limit=None, sort_key=None, sort_dir=None): """ Find Blacklists :param context: RPC Context. :param criterion: Criteria to filter by. + :param marker: Resource ID from which after the requested page will + start after + :param limit: Integer limit of objects of the page size after the + marker + :param sort_key: Key from which to sort after. + :param sort_dir: Direction to sort after using sort_key. """ @abc.abstractmethod diff --git a/designate/storage/impl_sqlalchemy/__init__.py b/designate/storage/impl_sqlalchemy/__init__.py index d49f6112d..a8a86e3c2 100644 --- a/designate/storage/impl_sqlalchemy/__init__.py +++ b/designate/storage/impl_sqlalchemy/__init__.py @@ -18,6 +18,7 @@ from sqlalchemy.orm import exc from sqlalchemy import distinct, func from oslo.config import cfg from designate.openstack.common import log as logging +from designate.openstack.common.db.sqlalchemy.utils import paginate_query from designate import exceptions from designate.storage import base from designate.storage.impl_sqlalchemy import models @@ -26,6 +27,7 @@ from designate.sqlalchemy.session import get_session from designate.sqlalchemy.session import get_engine from designate.sqlalchemy.session import SQLOPTS + LOG = logging.getLogger(__name__) cfg.CONF.register_group(cfg.OptGroup( @@ -92,7 +94,8 @@ class SQLAlchemyStorage(base.Storage): return query - def _find(self, model, context, criterion, one=False): + def _find(self, model, context, criterion, one=False, + marker=None, limit=None, sort_key=None, sort_dir=None): """ Base "finder" method @@ -112,7 +115,23 @@ class SQLAlchemyStorage(base.Storage): except (exc.NoResultFound, exc.MultipleResultsFound): raise exceptions.NotFound() else: + # If marker is not none and basestring we query it. # Othwewise, return all matching records + if marker is not None: + try: + marker = self._find(model, context, {'id': marker}, + one=True) + except exceptions.NotFound: + raise exceptions.MarkerNotFound( + 'Marker %s could not be found' % marker) + sort_key = sort_key or 'created_at' + sort_dir = sort_dir or 'asc' + + query = paginate_query( + query, model, limit, + [sort_key, 'id', 'created_at'], marker=marker, + sort_dir=sort_dir) + return query.all() ## CRUD for our resources (quota, server, tsigkey, tenant, domain & record) @@ -126,9 +145,12 @@ class SQLAlchemyStorage(base.Storage): ## # Quota Methods - def _find_quotas(self, context, criterion, one=False): + def _find_quotas(self, context, criterion, one=False, + marker=None, limit=None, sort_key=None, sort_dir=None): try: - return self._find(models.Quota, context, criterion, one) + return self._find(models.Quota, context, criterion, one=one, + marker=marker, limit=limit, sort_key=sort_key, + sort_dir=sort_dir) except exceptions.NotFound: raise exceptions.QuotaNotFound() @@ -149,8 +171,11 @@ class SQLAlchemyStorage(base.Storage): return dict(quota) - def find_quotas(self, context, criterion=None): - quotas = self._find_quotas(context, criterion) + def find_quotas(self, context, criterion=None, + marker=None, limit=None, sort_key=None, sort_dir=None): + quotas = self._find_quotas(context, criterion, marker=marker, + limit=limit, sort_key=sort_key, + sort_dir=sort_dir) return [dict(q) for q in quotas] @@ -177,9 +202,12 @@ class SQLAlchemyStorage(base.Storage): quota.delete(self.session) # Server Methods - def _find_servers(self, context, criterion, one=False): + def _find_servers(self, context, criterion, one=False, + marker=None, limit=None, sort_key=None, sort_dir=None): try: - return self._find(models.Server, context, criterion, one) + return self._find(models.Server, context, criterion, one, + marker=marker, limit=limit, sort_key=sort_key, + sort_dir=sort_dir) except exceptions.NotFound: raise exceptions.ServerNotFound() @@ -195,9 +223,11 @@ class SQLAlchemyStorage(base.Storage): return dict(server) - def find_servers(self, context, criterion=None): - servers = self._find_servers(context, criterion) - + def find_servers(self, context, criterion=None, + marker=None, limit=None, sort_key=None, sort_dir=None): + servers = self._find_servers(context, criterion, marker=marker, + limit=limit, sort_key=sort_key, + sort_dir=sort_dir) return [dict(s) for s in servers] def get_server(self, context, server_id): @@ -222,9 +252,12 @@ class SQLAlchemyStorage(base.Storage): server.delete(self.session) # TLD Methods - def _find_tlds(self, context, criterion, one=False): + def _find_tlds(self, context, criterion, one=False, + marker=None, limit=None, sort_key=None, sort_dir=None): try: - return self._find(models.Tld, context, criterion, one) + return self._find(models.Tld, context, criterion, one=one, + marker=marker, limit=limit, sort_key=sort_key, + sort_dir=sort_dir) except exceptions.NotFound: raise exceptions.TLDNotFound() @@ -239,8 +272,10 @@ class SQLAlchemyStorage(base.Storage): return dict(tld) - def find_tlds(self, context, criterion=None): - tlds = self._find_tlds(context, criterion) + def find_tlds(self, context, criterion=None, + marker=None, limit=None, sort_key=None, sort_dir=None): + tlds = self._find_tlds(context, criterion, marker=marker, limit=limit, + sort_key=sort_key, sort_dir=sort_dir) return [dict(s) for s in tlds] def find_tld(self, context, criterion=None): @@ -267,9 +302,12 @@ class SQLAlchemyStorage(base.Storage): tld.delete(self.session) # TSIG Key Methods - def _find_tsigkeys(self, context, criterion, one=False): + def _find_tsigkeys(self, context, criterion, one=False, + marker=None, limit=None, sort_key=None, sort_dir=None): try: - return self._find(models.TsigKey, context, criterion, one) + return self._find(models.TsigKey, context, criterion, one=one, + marker=marker, limit=limit, sort_key=sort_key, + sort_dir=sort_dir) except exceptions.NotFound: raise exceptions.TsigKeyNotFound() @@ -285,8 +323,11 @@ class SQLAlchemyStorage(base.Storage): return dict(tsigkey) - def find_tsigkeys(self, context, criterion=None): - tsigkeys = self._find_tsigkeys(context, criterion) + def find_tsigkeys(self, context, criterion=None, + marker=None, limit=None, sort_key=None, sort_dir=None): + tsigkeys = self._find_tsigkeys(context, criterion, marker=marker, + limit=limit, sort_key=sort_key, + sort_dir=sort_dir) return [dict(t) for t in tsigkeys] @@ -352,9 +393,12 @@ class SQLAlchemyStorage(base.Storage): ## ## Domain Methods ## - def _find_domains(self, context, criterion, one=False): + def _find_domains(self, context, criterion, one=False, + marker=None, limit=None, sort_key=None, sort_dir=None): try: - return self._find(models.Domain, context, criterion, one) + return self._find(models.Domain, context, criterion, one=one, + marker=marker, limit=limit, sort_key=sort_key, + sort_dir=sort_dir) except exceptions.NotFound: raise exceptions.DomainNotFound() @@ -375,8 +419,11 @@ class SQLAlchemyStorage(base.Storage): return dict(domain) - def find_domains(self, context, criterion=None): - domains = self._find_domains(context, criterion) + def find_domains(self, context, criterion=None, + marker=None, limit=None, sort_key=None, sort_dir=None): + domains = self._find_domains(context, criterion, marker=marker, + limit=limit, sort_key=sort_key, + sort_dir=sort_dir) return [dict(d) for d in domains] @@ -412,9 +459,13 @@ class SQLAlchemyStorage(base.Storage): return query.count() # RecordSet Methods - def _find_recordsets(self, context, criterion, one=False): + def _find_recordsets(self, context, criterion, one=False, + marker=None, limit=None, sort_key=None, + sort_dir=None): try: - return self._find(models.RecordSet, context, criterion, one) + return self._find(models.RecordSet, context, criterion, one=one, + marker=marker, limit=limit, sort_key=sort_key, + sort_dir=sort_dir) except exceptions.NotFound: raise exceptions.RecordSetNotFound() @@ -441,8 +492,11 @@ class SQLAlchemyStorage(base.Storage): return dict(recordset) - def find_recordsets(self, context, criterion=None): - recordsets = self._find_recordsets(context, criterion) + def find_recordsets(self, context, criterion=None, + marker=None, limit=None, sort_key=None, sort_dir=None): + recordsets = self._find_recordsets( + context, criterion, marker=marker, limit=limit, sort_key=sort_key, + sort_dir=sort_dir) return [dict(r) for r in recordsets] @@ -479,9 +533,12 @@ class SQLAlchemyStorage(base.Storage): return query.count() # Record Methods - def _find_records(self, context, criterion, one=False): + def _find_records(self, context, criterion, one=False, + marker=None, limit=None, sort_key=None, sort_dir=None): try: - return self._find(models.Record, context, criterion, one) + return self._find(models.Record, context, criterion, one=one, + marker=marker, limit=limit, sort_key=sort_key, + sort_dir=sort_dir) except exceptions.NotFound: raise exceptions.RecordNotFound() @@ -506,8 +563,11 @@ class SQLAlchemyStorage(base.Storage): return dict(record) - def find_records(self, context, criterion=None): - records = self._find_records(context, criterion) + def find_records(self, context, criterion=None, + marker=None, limit=None, sort_key=None, sort_dir=None): + records = self._find_records( + context, criterion, marker=marker, limit=limit, sort_key=sort_key, + sort_dir=sort_dir) return [dict(r) for r in records] @@ -549,9 +609,12 @@ class SQLAlchemyStorage(base.Storage): # # Blacklist Methods # - def _find_blacklist(self, context, criterion, one=False): + def _find_blacklist(self, context, criterion, one=False, + marker=None, limit=None, sort_key=None, sort_dir=None): try: - return self._find(models.Blacklists, context, criterion, one) + return self._find(models.Blacklists, context, criterion, one=one, + marker=marker, limit=limit, sort_key=sort_key, + sort_dir=sort_dir) except exceptions.NotFound: raise exceptions.BlacklistNotFound() @@ -567,8 +630,11 @@ class SQLAlchemyStorage(base.Storage): return dict(blacklist) - def find_blacklists(self, context, criterion=None): - blacklists = self._find_blacklist(context, criterion) + def find_blacklists(self, context, criterion=None, + marker=None, limit=None, sort_key=None, sort_dir=None): + blacklists = self._find_blacklist( + context, criterion, marker=marker, limit=limit, sort_key=sort_key, + sort_dir=sort_dir) return [dict(b) for b in blacklists] diff --git a/designate/tests/test_central/test_service.py b/designate/tests/test_central/test_service.py index cff711921..41f4ec074 100644 --- a/designate/tests/test_central/test_service.py +++ b/designate/tests/test_central/test_service.py @@ -968,8 +968,8 @@ class CentralServiceTest(CentralTestCase): self.admin_context, criterion) self.assertEqual(len(recordsets), 2) - self.assertEqual(recordsets[0]['name'], 'mail.%s' % domain['name']) - self.assertEqual(recordsets[1]['name'], 'www.%s' % domain['name']) + self.assertEqual(recordsets[0]['name'], 'www.%s' % domain['name']) + self.assertEqual(recordsets[1]['name'], 'mail.%s' % domain['name']) def test_find_recordset(self): domain = self.create_domain() diff --git a/designate/tests/test_storage/__init__.py b/designate/tests/test_storage/__init__.py index 2668f6e4e..076e49fc4 100644 --- a/designate/tests/test_storage/__init__.py +++ b/designate/tests/test_storage/__init__.py @@ -14,6 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. import testtools +import uuid from designate.openstack.common import log as logging from designate import exceptions @@ -77,6 +78,26 @@ class StorageTestCase(object): return fixture, self.storage.create_record( context, domain['id'], recordset['id'], fixture) + def _ensure_paging(self, data, method): + """ + Given an array of created items we iterate through them making sure + they match up to things returned by paged results. + """ + found = method(self.admin_context, limit=5) + x = 0 + for i in xrange(0, len(data)): + self.assertEqual(data[i]['id'], found[x]['id']) + x += 1 + if x == len(found): + x = 0 + found = method( + self.admin_context, limit=5, marker=found[-1:][0]['id']) + + def test_paging_marker_not_found(self): + with testtools.ExpectedException(exceptions.MarkerNotFound): + self.storage.find_servers( + self.admin_context, marker=str(uuid.uuid4()), limit=5) + # Quota Tests def test_create_quota(self): values = self.get_quota_fixture() @@ -270,20 +291,20 @@ class StorageTestCase(object): self.assertEqual(actual, []) # Create a single server - _, server_one = self.create_server() + _, server = self.create_server() actual = self.storage.find_servers(self.admin_context) self.assertEqual(len(actual), 1) + self.assertEqual(str(actual[0]['name']), str(server['name'])) - self.assertEqual(str(actual[0]['name']), str(server_one['name'])) + # Order of found items later will be reverse of the order they are + # created + created = [self.create_server( + values={'name': 'ns%s.example.org.' % i})[1] + for i in xrange(10, 20)] + created.insert(0, server) - # Create a second server - _, server_two = self.create_server(fixture=1) - - actual = self.storage.find_servers(self.admin_context) - self.assertEqual(len(actual), 2) - - self.assertEqual(str(actual[1]['name']), str(server_two['name'])) + self._ensure_paging(created, self.storage.find_servers) def test_find_servers_criterion(self): _, server_one = self.create_server(0) @@ -388,24 +409,22 @@ class StorageTestCase(object): self.assertEqual(actual, []) # Create a single tsigkey - _, tsigkey_one = self.create_tsigkey() + _, tsig = self.create_tsigkey() actual = self.storage.find_tsigkeys(self.admin_context) self.assertEqual(len(actual), 1) - self.assertEqual(actual[0]['name'], tsigkey_one['name']) - self.assertEqual(actual[0]['algorithm'], tsigkey_one['algorithm']) - self.assertEqual(actual[0]['secret'], tsigkey_one['secret']) + self.assertEqual(actual[0]['name'], tsig['name']) + self.assertEqual(actual[0]['algorithm'], tsig['algorithm']) + self.assertEqual(actual[0]['secret'], tsig['secret']) - # Create a second tsigkey - _, tsigkey_two = self.create_tsigkey(fixture=1) + # Order of found items later will be reverse of the order they are + # created + created = [self.create_tsigkey(values={'name': 'tsig%s.' % i})[1] + for i in xrange(10, 20)] + created.insert(0, tsig) - actual = self.storage.find_tsigkeys(self.admin_context) - self.assertEqual(len(actual), 2) - - self.assertEqual(actual[1]['name'], tsigkey_two['name']) - self.assertEqual(actual[1]['algorithm'], tsigkey_two['algorithm']) - self.assertEqual(actual[1]['secret'], tsigkey_two['secret']) + self._ensure_paging(created, self.storage.find_tsigkeys) def test_find_tsigkeys_criterion(self): _, tsigkey_one = self.create_tsigkey(fixture=0) @@ -582,19 +601,21 @@ class StorageTestCase(object): self.assertEqual(actual, []) # Create a single domain - fixture_one, domain_one = self.create_domain() + fixture_one, domain = self.create_domain() actual = self.storage.find_domains(self.admin_context) self.assertEqual(len(actual), 1) - self.assertEqual(actual[0]['name'], domain_one['name']) - self.assertEqual(actual[0]['email'], domain_one['email']) + self.assertEqual(actual[0]['name'], domain['name']) + self.assertEqual(actual[0]['email'], domain['email']) - # Create a second domain - self.create_domain(fixture=1) + # Order of found items later will be reverse of the order they are + # created + created = [self.create_domain(values={'name': 'x%s.org.' % i})[1] + for i in xrange(10, 20)] + created.insert(0, domain) - actual = self.storage.find_domains(self.admin_context) - self.assertEqual(len(actual), 2) + self._ensure_paging(created, self.storage.find_domains) def test_find_domains_criterion(self): _, domain_one = self.create_domain(0) @@ -812,14 +833,14 @@ class StorageTestCase(object): self.assertEqual(actual[0]['name'], recordset_one['name']) self.assertEqual(actual[0]['type'], recordset_one['type']) - # Create a second recordset - _, recordset_two = self.create_recordset(domain, fixture=1) + # Order of found items later will be reverse of the order they are + # created + created = [self.create_recordset( + domain, values={'name': 'test%s' % i + '.%s'})[1] + for i in xrange(10, 20)] + created.insert(0, recordset_one) - actual = self.storage.find_recordsets(self.admin_context, criterion) - self.assertEqual(len(actual), 2) - - self.assertEqual(actual[1]['name'], recordset_two['name']) - self.assertEqual(actual[1]['type'], recordset_two['type']) + self._ensure_paging(created, self.storage.find_recordsets) def test_find_recordsets_criterion(self): _, domain = self.create_domain() @@ -1016,22 +1037,23 @@ class StorageTestCase(object): self.assertEqual(actual, []) # Create a single record - _, record_one = self.create_record(domain, recordset, fixture=0) + _, record = self.create_record(domain, recordset, fixture=0) actual = self.storage.find_records(self.admin_context, criterion) self.assertEqual(len(actual), 1) - self.assertEqual(actual[0]['data'], record_one['data']) - self.assertIn('status', record_one) + self.assertEqual(actual[0]['data'], record['data']) + self.assertIn('status', record) - # Create a second record - _, record_two = self.create_record(domain, recordset, fixture=1) + # Order of found items later will be reverse of the order they are + # created + created = [self.create_record( + domain, recordset, + values={'data': '192.0.0.%s' % i})[1] + for i in xrange(10, 20)] + created.insert(0, record) - actual = self.storage.find_records(self.admin_context, criterion) - self.assertEqual(len(actual), 2) - - self.assertEqual(actual[1]['data'], record_two['data']) - self.assertIn('status', record_two) + self._ensure_paging(created, self.storage.find_records) def test_find_records_criterion(self): _, domain = self.create_domain() diff --git a/designate/tests/test_storage/test_api.py b/designate/tests/test_storage/test_api.py index 9062852ee..03b55e8bb 100644 --- a/designate/tests/test_storage/test_api.py +++ b/designate/tests/test_storage/test_api.py @@ -90,12 +90,20 @@ class StorageAPITest(TestCase): def test_find_quotas(self): context = mock.sentinel.context criterion = mock.sentinel.criterion + marker = mock.sentinel.marker + limit = mock.sentinel.limit + sort_key = mock.sentinel.sort_key + sort_dir = mock.sentinel.sort_dir quota = mock.sentinel.quota self._set_side_effect('find_quotas', [[quota]]) - result = self.storage_api.find_quotas(context, criterion) - self._assert_called_with('find_quotas', context, criterion) + result = self.storage_api.find_quotas( + context, criterion, + marker, limit, sort_key, sort_dir) + self._assert_called_with( + 'find_quotas', context, criterion, + marker, limit, sort_key, sort_dir) self.assertEqual([quota], result) def test_find_quota(self): @@ -198,12 +206,21 @@ class StorageAPITest(TestCase): def test_find_servers(self): context = mock.sentinel.context criterion = mock.sentinel.criterion + marker = mock.sentinel.marker + limit = mock.sentinel.limit + sort_key = mock.sentinel.sort_key + sort_dir = mock.sentinel.sort_dir + server = mock.sentinel.server self._set_side_effect('find_servers', [[server]]) - result = self.storage_api.find_servers(context, criterion) - self._assert_called_with('find_servers', context, criterion) + result = self.storage_api.find_servers( + context, criterion, + marker, limit, sort_key, sort_dir) + self._assert_called_with( + 'find_servers', context, criterion, + marker, limit, sort_key, sort_dir) self.assertEqual([server], result) def test_find_server(self): @@ -306,12 +323,19 @@ class StorageAPITest(TestCase): def test_find_tsigkeys(self): context = mock.sentinel.context criterion = mock.sentinel.criterion + marker = mock.sentinel.marker + limit = mock.sentinel.limit + sort_key = mock.sentinel.sort_key + sort_dir = mock.sentinel.sort_dir tsigkey = mock.sentinel.tsigkey self._set_side_effect('find_tsigkeys', [[tsigkey]]) - result = self.storage_api.find_tsigkeys(context, criterion) - self._assert_called_with('find_tsigkeys', context, criterion) + result = self.storage_api.find_tsigkeys( + context, criterion, marker, limit, sort_key, sort_dir) + self._assert_called_with( + 'find_tsigkeys', context, criterion, + marker, limit, sort_key, sort_dir) self.assertEqual([tsigkey], result) def test_find_tsigkey(self): @@ -444,12 +468,20 @@ class StorageAPITest(TestCase): def test_find_domains(self): context = mock.sentinel.context criterion = mock.sentinel.criterion + marker = mock.sentinel.marker + limit = mock.sentinel.limit + sort_key = mock.sentinel.sort_key + sort_dir = mock.sentinel.sort_dir domain = mock.sentinel.domain self._set_side_effect('find_domains', [[domain]]) - result = self.storage_api.find_domains(context, criterion) - self._assert_called_with('find_domains', context, criterion) + result = self.storage_api.find_domains( + context, criterion, + marker, limit, sort_key, sort_dir) + self._assert_called_with( + 'find_domains', context, criterion, + marker, limit, sort_key, sort_dir) self.assertEqual([domain], result) def test_find_domain(self): @@ -552,12 +584,20 @@ class StorageAPITest(TestCase): def test_find_recordsets(self): context = mock.sentinel.context criterion = mock.sentinel.criterion + marker = mock.sentinel.marker + limit = mock.sentinel.limit + sort_key = mock.sentinel.sort_key + sort_dir = mock.sentinel.sort_dir recordset = mock.sentinel.recordset self._set_side_effect('find_recordsets', [[recordset]]) - result = self.storage_api.find_recordsets(context, criterion) - self._assert_called_with('find_recordsets', context, criterion) + result = self.storage_api.find_recordsets( + context, criterion, + marker, limit, sort_key, sort_dir) + self._assert_called_with( + 'find_recordsets', context, criterion, + marker, limit, sort_key, sort_dir) self.assertEqual([recordset], result) def test_find_recordset(self): @@ -660,12 +700,20 @@ class StorageAPITest(TestCase): def test_find_records(self): context = mock.sentinel.context criterion = mock.sentinel.criterion + marker = mock.sentinel.marker + limit = mock.sentinel.limit + sort_key = mock.sentinel.sort_key + sort_dir = mock.sentinel.sort_dir record = mock.sentinel.record self._set_side_effect('find_records', [[record]]) - result = self.storage_api.find_records(context, criterion) - self._assert_called_with('find_records', context, criterion) + result = self.storage_api.find_records( + context, criterion, + marker, limit, sort_key, sort_dir) + self._assert_called_with( + 'find_records', context, criterion, + marker, limit, sort_key, sort_dir) self.assertEqual([record], result) diff --git a/openstack-common.conf b/openstack-common.conf index 818fa05f5..9edf292bb 100644 --- a/openstack-common.conf +++ b/openstack-common.conf @@ -2,6 +2,8 @@ # The list of modules to copy from oslo-incubator.git module=context +module=db +module=db.sqlalchemy module=excutils module=fixture module=gettextutils diff --git a/requirements.txt b/requirements.txt index 66200c250..ad853368e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ Flask>=0.10,<1.0 iso8601>=0.1.8 jsonschema>=2.0.0,<3.0.0 kombu>=2.4.8 +lockfile>=0.8 netaddr>=0.7.6 oslo.config>=1.2.0 Paste