Sync db, db.sqlalchemy, gettextutils from oslo-incubator 6ba44fd

This change syncs oslo-incubator's db, db.sqlalchemy, and gettextutils
modules from commit hash
6ba44fd7f9d39a7930defb4e14c37b8b1046cbcb

 $ python update.py --nodeps --base keystone \
    --dest-dir ../keystone \
    --modules db,db.sqlalchemy,gettextutils

- Config options were moved from db.sqlalchemy.session to db.options.
- db.sqlalchemy.session no longer provides the get_session, get_engine,
  or cleanup functions; an EngineFacade is used instead (see the sketch
  below).
- db.sqlalchemy.migration.db_version() now requires an engine parameter.
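
For reference, a minimal sketch of the replacement pattern, assuming the
modules as synced into keystone (the facade variable and repository path
are illustrative):

    from oslo.config import cfg

    from keystone.openstack.common.db import options as db_options
    from keystone.openstack.common.db.sqlalchemy import migration
    from keystone.openstack.common.db.sqlalchemy import session as db_session

    CONF = cfg.CONF

    # Option defaults now live in db.options, not db.sqlalchemy.session.
    db_options.set_defaults(sql_connection='sqlite:///keystone.db',
                            sqlite_db='keystone.db')

    # get_session()/get_engine()/cleanup() are gone from the session
    # module; callers hold an EngineFacade instead.
    facade = db_session.EngineFacade(CONF.database.connection)
    engine = facade.get_engine()
    session = facade.get_session(expire_on_commit=False)

    # Migration helpers now take the engine explicitly.
    migration.db_version(engine, '/path/to/migrate_repo', init_version=0)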

Closes-Bug: #1227321

Change-Id: I742cef9dab68d9eed977df0039736cfe67ca493c
Author: Brant Knudson, 2014-02-21 16:22:06 -06:00 (committed by Dolph Mathews)
parent 5d65f0f040
commit 8f7b87b2a7
33 changed files with 1186 additions and 568 deletions


@ -303,17 +303,6 @@
#keystone_ec2_url=http://localhost:5000/v2.0/ec2tokens
#
# Options defined in keystone.openstack.common.db.sqlalchemy.session
#
# The file name to use with SQLite (string value)
#sqlite_db=keystone.sqlite
# If True, SQLite uses synchronous mode (boolean value)
#sqlite_synchronous=true
#
# Options defined in keystone.openstack.common.eventlet_backdoor
#
@ -572,28 +561,30 @@
[database]
#
# Options defined in keystone.openstack.common.db.api
# Options defined in keystone.openstack.common.db.options
#
# The file name to use with SQLite (string value)
#sqlite_db=keystone.sqlite
# If True, SQLite uses synchronous mode (boolean value)
#sqlite_synchronous=true
# The backend to use for db (string value)
# Deprecated group/name - [DEFAULT]/db_backend
#backend=sqlalchemy
#
# Options defined in keystone.openstack.common.db.sqlalchemy.session
#
# The SQLAlchemy connection string used to connect to the
# database (string value)
# Deprecated group/name - [DEFAULT]/sql_connection
# Deprecated group/name - [DATABASE]/sql_connection
# Deprecated group/name - [sql]/connection
#connection=sqlite:////keystone/openstack/common/db/$sqlite_db
#connection=<None>
# The SQLAlchemy connection string used to connect to the
# slave database (string value)
#slave_connection=
# The SQL mode to be used for MySQL sessions (default is
# empty, meaning do not override any server-side SQL mode
# setting) (string value)
#mysql_sql_mode=<None>
# Timeout before idle sql connections are reaped (integer
# value)
@ -647,6 +638,25 @@
# Deprecated group/name - [DATABASE]/sqlalchemy_pool_timeout
#pool_timeout=<None>
# Enable the experimental use of database reconnect on
# connection lost (boolean value)
#use_db_reconnect=false
# seconds between db connection retries (integer value)
#db_retry_interval=1
# Whether to increase interval between db connection retries,
# up to db_max_retry_interval (boolean value)
#db_inc_retry_interval=true
# max seconds between db connection retries, if
# db_inc_retry_interval is enabled (integer value)
#db_max_retry_interval=10
# maximum db connection retries before error is raised.
# (setting -1 implies an infinite retry count) (integer value)
#db_max_retries=20
[ec2]


@ -21,7 +21,6 @@ from keystone.common.sql import migration_helpers
from keystone import config
from keystone import exception
from keystone.openstack.common.db.sqlalchemy import migration
from keystone.openstack.common.db.sqlalchemy import session as db_session
CONF = config.CONF
@ -39,7 +38,8 @@ class Assignment(assignment.Driver):
# Internal interface to manage the database
def db_sync(self, version=None):
migration.db_sync(
migration_helpers.find_migrate_repo(), version=version)
sql.get_engine(), migration_helpers.find_migrate_repo(),
version=version)
def _get_project(self, session, project_id):
project_ref = session.query(Project).get(project_id)
@ -81,7 +81,7 @@ class Assignment(assignment.Driver):
# We aren't given a session when called by the manager directly.
if session is None:
session = db_session.get_session()
session = sql.get_session()
q = session.query(RoleAssignment)
q = q.filter_by(actor_id=user_id or group_id)
@ -296,7 +296,7 @@ class Assignment(assignment.Driver):
Role.id == RoleAssignment.role_id,
RoleAssignment.actor_id.in_(group_ids))
session = db_session.get_session()
session = sql.get_session()
with session.begin():
query = session.query(Role).filter(
sql_constraints).distinct()
@ -313,7 +313,7 @@ class Assignment(assignment.Driver):
entity.id == RoleAssignment.target_id,
RoleAssignment.actor_id.in_(group_ids))
session = db_session.get_session()
session = sql.get_session()
with session.begin():
query = session.query(entity).filter(
group_sql_conditions)


@ -22,7 +22,6 @@ from keystone.common.sql import migration_helpers
from keystone import config
from keystone import exception
from keystone.openstack.common.db.sqlalchemy import migration
from keystone.openstack.common.db.sqlalchemy import session as db_session
CONF = config.CONF
@ -82,11 +81,12 @@ class Endpoint(sql.ModelBase, sql.DictBase):
class Catalog(catalog.Driver):
def db_sync(self, version=None):
migration.db_sync(
migration_helpers.find_migrate_repo(), version=version)
sql.get_engine(), migration_helpers.find_migrate_repo(),
version=version)
# Regions
def list_regions(self):
session = db_session.get_session()
session = sql.get_session()
regions = session.query(Region).all()
return [s.to_dict() for s in list(regions)]
@ -120,11 +120,11 @@ class Catalog(catalog.Driver):
self._get_region(session, parent_region_id)
def get_region(self, region_id):
session = db_session.get_session()
session = sql.get_session()
return self._get_region(session, region_id).to_dict()
def delete_region(self, region_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = self._get_region(session, region_id)
self._delete_child_regions(session, region_id)
@ -133,7 +133,7 @@ class Catalog(catalog.Driver):
session.flush()
def create_region(self, region_ref):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
self._check_parent_region(session, region_ref)
region = Region.from_dict(region_ref)
@ -142,7 +142,7 @@ class Catalog(catalog.Driver):
return region.to_dict()
def update_region(self, region_id, region_ref):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
self._check_parent_region(session, region_ref)
ref = self._get_region(session, region_id)
@ -158,7 +158,7 @@ class Catalog(catalog.Driver):
# Services
@sql.truncated
def list_services(self, hints):
session = db_session.get_session()
session = sql.get_session()
services = session.query(Service)
services = sql.filter_limit_query(Service, services, hints)
return [s.to_dict() for s in list(services)]
@ -170,25 +170,25 @@ class Catalog(catalog.Driver):
return ref
def get_service(self, service_id):
session = db_session.get_session()
session = sql.get_session()
return self._get_service(session, service_id).to_dict()
def delete_service(self, service_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = self._get_service(session, service_id)
session.query(Endpoint).filter_by(service_id=service_id).delete()
session.delete(ref)
def create_service(self, service_id, service_ref):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
service = Service.from_dict(service_ref)
session.add(service)
return service.to_dict()
def update_service(self, service_id, service_ref):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = self._get_service(session, service_id)
old_dict = ref.to_dict()
@ -202,7 +202,7 @@ class Catalog(catalog.Driver):
# Endpoints
def create_endpoint(self, endpoint_id, endpoint_ref):
session = db_session.get_session()
session = sql.get_session()
self.get_service(endpoint_ref['service_id'])
new_endpoint = Endpoint.from_dict(endpoint_ref)
with session.begin():
@ -210,7 +210,7 @@ class Catalog(catalog.Driver):
return new_endpoint.to_dict()
def delete_endpoint(self, endpoint_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = self._get_endpoint(session, endpoint_id)
session.delete(ref)
@ -222,18 +222,18 @@ class Catalog(catalog.Driver):
raise exception.EndpointNotFound(endpoint_id=endpoint_id)
def get_endpoint(self, endpoint_id):
session = db_session.get_session()
session = sql.get_session()
return self._get_endpoint(session, endpoint_id).to_dict()
@sql.truncated
def list_endpoints(self, hints):
session = db_session.get_session()
session = sql.get_session()
endpoints = session.query(Endpoint)
endpoints = sql.filter_limit_query(Endpoint, endpoints, hints)
return [e.to_dict() for e in list(endpoints)]
def update_endpoint(self, endpoint_id, endpoint_ref):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = self._get_endpoint(session, endpoint_id)
old_dict = ref.to_dict()
@ -250,7 +250,7 @@ class Catalog(catalog.Driver):
d.update({'tenant_id': tenant_id,
'user_id': user_id})
session = db_session.get_session()
session = sql.get_session()
endpoints = (session.query(Endpoint).
options(sql.joinedload(Endpoint.service)).
all())
@ -278,7 +278,7 @@ class Catalog(catalog.Driver):
d.update({'tenant_id': tenant_id,
'user_id': user_id})
session = db_session.get_session()
session = sql.get_session()
services = (session.query(Service).
options(sql.joinedload(Service.endpoints)).
all())


@ -21,6 +21,7 @@ CONF() because it sets up configuration options.
import contextlib
import functools
from oslo.config import cfg
import six
import sqlalchemy as sql
from sqlalchemy.ext import declarative
@ -30,11 +31,14 @@ from sqlalchemy import types as sql_types
from keystone.common import utils
from keystone import exception
from keystone.openstack.common.db import exception as db_exception
from keystone.openstack.common.db import options as db_options
from keystone.openstack.common.db.sqlalchemy import models
from keystone.openstack.common.db.sqlalchemy import session as db_session
from keystone.openstack.common import jsonutils
CONF = cfg.CONF
ModelBase = declarative.declarative_base()
@ -64,7 +68,7 @@ and_ = sql.and_
def initialize():
"""Initialize the module."""
db_session.set_defaults(
db_options.set_defaults(
sql_connection="sqlite:///keystone.db",
sqlite_db="keystone.db")
@ -162,10 +166,38 @@ class ModelDictMixin(object):
return dict((name, getattr(self, name)) for name in names)
_engine_facade = None
def _get_engine_facade():
global _engine_facade
if not _engine_facade:
_engine_facade = db_session.EngineFacade(
CONF.database.connection,
**dict(six.iteritems(CONF.database)))
return _engine_facade
def cleanup():
global _engine_facade
_engine_facade = None
def get_engine():
return _get_engine_facade().get_engine()
def get_session(expire_on_commit=False):
return _get_engine_facade().get_session(expire_on_commit=expire_on_commit)
@contextlib.contextmanager
def transaction(expire_on_commit=False):
"""Return a SQLAlchemy session in a scoped transaction."""
session = db_session.get_session(expire_on_commit=expire_on_commit)
session = get_session(expire_on_commit=expire_on_commit)
with session.begin():
yield session
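
With these module-level wrappers in keystone.common.sql, driver code
keeps its old shape; a hedged usage sketch (create_thing and ref are
illustrative):

    from keystone.common import sql

    def create_thing(ref):
        session = sql.get_session()
        with session.begin():        # commits on exit, rolls back on error
            session.add(ref)

    # equivalently, using the scoped-transaction helper defined above:
    def create_thing_txn(ref):
        with sql.transaction() as session:
            session.add(ref)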


@ -126,7 +126,7 @@ def sync_database_to_version(extension=None, version=None):
try:
abs_path = find_migrate_repo(package)
try:
migration.db_version_control(abs_path)
migration.db_version_control(sql.get_engine(), abs_path)
# Register the repo with the version control API
# If it already knows about the repo, it will throw
# an exception that we can safely ignore
@ -135,7 +135,7 @@ def sync_database_to_version(extension=None, version=None):
except exception.MigrationNotProvided as e:
print(e)
sys.exit(1)
migration.db_sync(abs_path, version=version)
migration.db_sync(sql.get_engine(), abs_path, version=version)
def print_db_version(extension=None):


@ -17,7 +17,6 @@ from keystone.common.sql import migration_helpers
from keystone.contrib import endpoint_filter
from keystone import exception
from keystone.openstack.common.db.sqlalchemy import migration
from keystone.openstack.common.db.sqlalchemy import session as db_session
class ProjectEndpoint(sql.ModelBase, sql.DictBase):
@ -37,11 +36,11 @@ class EndpointFilter(object):
def db_sync(self, version=None):
abs_path = migration_helpers.find_migrate_repo(endpoint_filter)
migration.db_sync(abs_path, version=version)
migration.db_sync(sql.get_engine(), abs_path, version=version)
@sql.handle_conflicts(conflict_type='project_endpoint')
def add_endpoint_to_project(self, endpoint_id, project_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
endpoint_filter_ref = ProjectEndpoint(endpoint_id=endpoint_id,
project_id=project_id)
@ -58,25 +57,25 @@ class EndpointFilter(object):
return endpoint_filter_ref
def check_endpoint_in_project(self, endpoint_id, project_id):
session = db_session.get_session()
session = sql.get_session()
self._get_project_endpoint_ref(session, endpoint_id, project_id)
def remove_endpoint_from_project(self, endpoint_id, project_id):
session = db_session.get_session()
session = sql.get_session()
endpoint_filter_ref = self._get_project_endpoint_ref(
session, endpoint_id, project_id)
with session.begin():
session.delete(endpoint_filter_ref)
def list_endpoints_for_project(self, project_id):
session = db_session.get_session()
session = sql.get_session()
query = session.query(ProjectEndpoint)
query = query.filter_by(project_id=project_id)
endpoint_filter_refs = query.all()
return endpoint_filter_refs
def list_projects_for_endpoint(self, endpoint_id):
session = db_session.get_session()
session = sql.get_session()
query = session.query(ProjectEndpoint)
query = query.filter_by(endpoint_id=endpoint_id)
endpoint_filter_refs = query.all()


@ -18,7 +18,6 @@ from keystone.contrib import federation
from keystone.contrib.federation import core
from keystone import exception
from keystone.openstack.common.db.sqlalchemy import migration
from keystone.openstack.common.db.sqlalchemy import session as db_session
from keystone.openstack.common import jsonutils
@ -92,12 +91,12 @@ class Federation(core.Driver):
def db_sync(self):
abs_path = migration_helpers.find_migrate_repo(federation)
migration.db_sync(abs_path)
migration.db_sync(sql.get_engine(), abs_path)
# Identity Provider CRUD
@sql.handle_conflicts(conflict_type='identity_provider')
def create_idp(self, idp_id, idp):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
idp['id'] = idp_id
idp_ref = IdentityProviderModel.from_dict(idp)
@ -105,7 +104,7 @@ class Federation(core.Driver):
return idp_ref.to_dict()
def delete_idp(self, idp_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
idp_ref = self._get_idp(session, idp_id)
q = session.query(IdentityProviderModel)
@ -120,19 +119,19 @@ class Federation(core.Driver):
return idp_ref
def list_idps(self):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
idps = session.query(IdentityProviderModel)
idps_list = [idp.to_dict() for idp in idps]
return idps_list
def get_idp(self, idp_id):
session = db_session.get_session()
session = sql.get_session()
idp_ref = self._get_idp(session, idp_id)
return idp_ref.to_dict()
def update_idp(self, idp_id, idp):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
idp_ref = self._get_idp(session, idp_id)
old_idp = idp_ref.to_dict()
@ -155,7 +154,7 @@ class Federation(core.Driver):
@sql.handle_conflicts(conflict_type='federation_protocol')
def create_protocol(self, idp_id, protocol_id, protocol):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
self._get_idp(session, idp_id)
protocol['id'] = protocol_id
@ -165,7 +164,7 @@ class Federation(core.Driver):
return protocol_ref.to_dict()
def update_protocol(self, idp_id, protocol_id, protocol):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
proto_ref = self._get_protocol(session, idp_id, protocol_id)
old_proto = proto_ref.to_dict()
@ -176,19 +175,19 @@ class Federation(core.Driver):
return proto_ref.to_dict()
def get_protocol(self, idp_id, protocol_id):
session = db_session.get_session()
session = sql.get_session()
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
return protocol_ref.to_dict()
def list_protocols(self, idp_id):
session = db_session.get_session()
session = sql.get_session()
q = session.query(FederationProtocolModel)
q = q.filter_by(idp_id=idp_id)
protocols = [protocol.to_dict() for protocol in q]
return protocols
def delete_protocol(self, idp_id, protocol_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
key_ref = self._get_protocol(session, idp_id, protocol_id)
q = session.query(FederationProtocolModel)
@ -205,7 +204,7 @@ class Federation(core.Driver):
@sql.handle_conflicts(conflict_type='mapping')
def create_mapping(self, mapping_id, mapping):
session = db_session.get_session()
session = sql.get_session()
ref = {}
ref['id'] = mapping_id
ref['rules'] = jsonutils.dumps(mapping.get('rules'))
@ -215,19 +214,19 @@ class Federation(core.Driver):
return mapping_ref.to_dict()
def delete_mapping(self, mapping_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
mapping_ref = self._get_mapping(session, mapping_id)
session.delete(mapping_ref)
def list_mappings(self):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
mappings = session.query(MappingModel)
return [x.to_dict() for x in mappings]
def get_mapping(self, mapping_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
mapping_ref = self._get_mapping(session, mapping_id)
return mapping_ref.to_dict()
@ -237,7 +236,7 @@ class Federation(core.Driver):
ref = {}
ref['id'] = mapping_id
ref['rules'] = jsonutils.dumps(mapping.get('rules'))
session = db_session.get_session()
session = sql.get_session()
with session.begin():
mapping_ref = self._get_mapping(session, mapping_id)
old_mapping = mapping_ref.to_dict()
@ -248,7 +247,7 @@ class Federation(core.Driver):
return mapping_ref.to_dict()
def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
protocol_ref = self._get_protocol(session, idp_id, protocol_id)
mapping_id = protocol_ref.mapping_id


@ -24,7 +24,6 @@ from keystone.contrib import oauth1
from keystone.contrib.oauth1 import core
from keystone import exception
from keystone.openstack.common.db.sqlalchemy import migration
from keystone.openstack.common.db.sqlalchemy import session as db_session
from keystone.openstack.common import jsonutils
from keystone.openstack.common import timeutils
@ -86,7 +85,8 @@ class AccessToken(sql.ModelBase, sql.DictBase):
class OAuth1(object):
def db_sync(self):
migration.db_sync(migration_helpers.find_migrate_repo(oauth1))
migration.db_sync(sql.get_engine(),
migration_helpers.find_migrate_repo(oauth1))
def _get_consumer(self, session, consumer_id):
consumer_ref = session.query(Consumer).get(consumer_id)
@ -95,7 +95,7 @@ class OAuth1(object):
return consumer_ref
def get_consumer_with_secret(self, consumer_id):
session = db_session.get_session()
session = sql.get_session()
consumer_ref = self._get_consumer(session, consumer_id)
return consumer_ref.to_dict()
@ -107,7 +107,7 @@ class OAuth1(object):
consumer['secret'] = uuid.uuid4().hex
if not consumer.get('description'):
consumer['description'] = None
session = db_session.get_session()
session = sql.get_session()
with session.begin():
consumer_ref = Consumer.from_dict(consumer)
session.add(consumer_ref)
@ -143,19 +143,19 @@ class OAuth1(object):
session.delete(token_ref)
def delete_consumer(self, consumer_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
self._delete_request_tokens(session, consumer_id)
self._delete_access_tokens(session, consumer_id)
self._delete_consumer(session, consumer_id)
def list_consumers(self):
session = db_session.get_session()
session = sql.get_session()
cons = session.query(Consumer)
return [core.filter_consumer(x.to_dict()) for x in cons]
def update_consumer(self, consumer_id, consumer):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
consumer_ref = self._get_consumer(session, consumer_id)
old_consumer_dict = consumer_ref.to_dict()
@ -186,7 +186,7 @@ class OAuth1(object):
ref['role_ids'] = None
ref['consumer_id'] = consumer_id
ref['expires_at'] = expiry_date
session = db_session.get_session()
session = sql.get_session()
with session.begin():
token_ref = RequestToken.from_dict(ref)
session.add(token_ref)
@ -199,13 +199,13 @@ class OAuth1(object):
return token_ref
def get_request_token(self, request_token_id):
session = db_session.get_session()
session = sql.get_session()
token_ref = self._get_request_token(session, request_token_id)
return token_ref.to_dict()
def authorize_request_token(self, request_token_id, user_id,
role_ids):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
token_ref = self._get_request_token(session, request_token_id)
token_dict = token_ref.to_dict()
@ -227,7 +227,7 @@ class OAuth1(object):
access_token_id = uuid.uuid4().hex
if access_token_secret is None:
access_token_secret = uuid.uuid4().hex
session = db_session.get_session()
session = sql.get_session()
with session.begin():
req_token_ref = self._get_request_token(session, request_token_id)
token_dict = req_token_ref.to_dict()
@ -265,18 +265,18 @@ class OAuth1(object):
return token_ref
def get_access_token(self, access_token_id):
session = db_session.get_session()
session = sql.get_session()
token_ref = self._get_access_token(session, access_token_id)
return token_ref.to_dict()
def list_access_tokens(self, user_id):
session = db_session.get_session()
session = sql.get_session()
q = session.query(AccessToken)
user_auths = q.filter_by(authorizing_user_id=user_id)
return [core.filter_token(x.to_dict()) for x in user_auths]
def delete_access_token(self, user_id, access_token_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
token_ref = self._get_access_token(session, access_token_id)
token_dict = token_ref.to_dict()


@ -17,8 +17,6 @@ from keystone.common import sql
from keystone.contrib import revoke
from keystone.contrib.revoke import model
from keystone.openstack.common.db.sqlalchemy import session as db_session
CONF = config.CONF
@ -64,7 +62,7 @@ class Revoke(revoke.Driver):
def _prune_expired_events(self):
oldest = revoke.revoked_before_cutoff_time()
session = db_session.get_session()
session = sql.get_session()
dialect = session.bind.dialect.name
batch_size = self._flush_batch_size(dialect)
if batch_size > 0:
@ -86,7 +84,7 @@ class Revoke(revoke.Driver):
def get_events(self, last_fetch=None):
self._prune_expired_events()
session = db_session.get_session()
session = sql.get_session()
query = session.query(RevocationEvent).order_by(
RevocationEvent.revoked_at)
@ -108,6 +106,6 @@ class Revoke(revoke.Driver):
kwargs[attr] = getattr(event, attr)
kwargs['id'] = uuid.uuid4().hex
record = RevocationEvent(**kwargs)
session = db_session.get_session()
session = sql.get_session()
with session.begin():
session.add(record)


@ -17,7 +17,6 @@ from keystone.common.sql import migration_helpers
from keystone import credential
from keystone import exception
from keystone.openstack.common.db.sqlalchemy import migration
from keystone.openstack.common.db.sqlalchemy import session as db_session
class CredentialModel(sql.ModelBase, sql.DictBase):
@ -36,20 +35,21 @@ class Credential(credential.Driver):
# Internal interface to manage the database
def db_sync(self, version=None):
migration.db_sync(
migration_helpers.find_migrate_repo(), version=version)
sql.get_engine(), migration_helpers.find_migrate_repo(),
version=version)
# credential crud
@sql.handle_conflicts(conflict_type='credential')
def create_credential(self, credential_id, credential):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = CredentialModel.from_dict(credential)
session.add(ref)
return ref.to_dict()
def list_credentials(self, **filters):
session = db_session.get_session()
session = sql.get_session()
query = session.query(CredentialModel)
if 'user_id' in filters:
query = query.filter_by(user_id=filters.get('user_id'))
@ -63,12 +63,12 @@ class Credential(credential.Driver):
return ref
def get_credential(self, credential_id):
session = db_session.get_session()
session = sql.get_session()
return self._get_credential(session, credential_id).to_dict()
@sql.handle_conflicts(conflict_type='credential')
def update_credential(self, credential_id, credential):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = self._get_credential(session, credential_id)
old_dict = ref.to_dict()
@ -82,14 +82,14 @@ class Credential(credential.Driver):
return ref.to_dict()
def delete_credential(self, credential_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = self._get_credential(session, credential_id)
session.delete(ref)
def delete_credentials_for_project(self, project_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
query = session.query(CredentialModel)
@ -97,7 +97,7 @@ class Credential(credential.Driver):
query.delete()
def delete_credentials_for_user(self, user_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
query = session.query(CredentialModel)


@ -19,7 +19,6 @@ from keystone.common import utils
from keystone import exception
from keystone import identity
from keystone.openstack.common.db.sqlalchemy import migration
from keystone.openstack.common.db.sqlalchemy import session as db_session
# Import assignment sql to ensure that the models defined in there are
# available for the reference from User and Group to Domain.id.
@ -82,7 +81,8 @@ class Identity(identity.Driver):
# Internal interface to manage the database
def db_sync(self, version=None):
migration.db_sync(
migration_helpers.find_migrate_repo(), version=version)
sql.get_engine(), migration_helpers.find_migrate_repo(),
version=version)
def _check_password(self, password, user_ref):
"""Check the specified password against the data store.
@ -103,7 +103,7 @@ class Identity(identity.Driver):
# Identity interface
def authenticate(self, user_id, password):
session = db_session.get_session()
session = sql.get_session()
user_ref = None
try:
user_ref = self._get_user(session, user_id)
@ -118,7 +118,7 @@ class Identity(identity.Driver):
@sql.handle_conflicts(conflict_type='user')
def create_user(self, user_id, user):
user = utils.hash_user_password(user)
session = db_session.get_session()
session = sql.get_session()
with session.begin():
user_ref = User.from_dict(user)
session.add(user_ref)
@ -126,7 +126,7 @@ class Identity(identity.Driver):
@sql.truncated
def list_users(self, hints):
session = db_session.get_session()
session = sql.get_session()
query = session.query(User)
user_refs = sql.filter_limit_query(User, query, hints)
return [identity.filter_user(x.to_dict()) for x in user_refs]
@ -138,11 +138,11 @@ class Identity(identity.Driver):
return user_ref
def get_user(self, user_id):
session = db_session.get_session()
session = sql.get_session()
return identity.filter_user(self._get_user(session, user_id).to_dict())
def get_user_by_name(self, user_name, domain_id):
session = db_session.get_session()
session = sql.get_session()
query = session.query(User)
query = query.filter_by(name=user_name)
query = query.filter_by(domain_id=domain_id)
@ -154,7 +154,7 @@ class Identity(identity.Driver):
@sql.handle_conflicts(conflict_type='user')
def update_user(self, user_id, user):
session = db_session.get_session()
session = sql.get_session()
if 'id' in user and user_id != user['id']:
raise exception.ValidationError(_('Cannot change user ID'))
@ -172,7 +172,7 @@ class Identity(identity.Driver):
return identity.filter_user(user_ref.to_dict(include_extra_dict=True))
def add_user_to_group(self, user_id, group_id):
session = db_session.get_session()
session = sql.get_session()
self.get_group(group_id)
self.get_user(user_id)
query = session.query(UserGroupMembership)
@ -187,7 +187,7 @@ class Identity(identity.Driver):
group_id=group_id))
def check_user_in_group(self, user_id, group_id):
session = db_session.get_session()
session = sql.get_session()
self.get_group(group_id)
self.get_user(user_id)
query = session.query(UserGroupMembership)
@ -197,7 +197,7 @@ class Identity(identity.Driver):
raise exception.NotFound(_('User not found in group'))
def remove_user_from_group(self, user_id, group_id):
session = db_session.get_session()
session = sql.get_session()
# We don't check if user or group are still valid and let the remove
# be tried anyway - in case this is some kind of clean-up operation
query = session.query(UserGroupMembership)
@ -215,7 +215,7 @@ class Identity(identity.Driver):
# occurrence to filter on more than the user_id already being used
# here, this is left as future enhancement and until then we leave
# it for the controller to do for us.
session = db_session.get_session()
session = sql.get_session()
self.get_user(user_id)
query = session.query(Group).join(UserGroupMembership)
query = query.filter(UserGroupMembership.user_id == user_id)
@ -227,7 +227,7 @@ class Identity(identity.Driver):
# occurrence to filter on more than the group_id already being used
# here, this is left as future enhancement and until then we leave
# it for the controller to do for us.
session = db_session.get_session()
session = sql.get_session()
self.get_group(group_id)
query = session.query(User).join(UserGroupMembership)
query = query.filter(UserGroupMembership.group_id == group_id)
@ -235,7 +235,7 @@ class Identity(identity.Driver):
return [identity.filter_user(u.to_dict()) for u in query]
def delete_user(self, user_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = self._get_user(session, user_id)
@ -251,7 +251,7 @@ class Identity(identity.Driver):
@sql.handle_conflicts(conflict_type='group')
def create_group(self, group_id, group):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = Group.from_dict(group)
session.add(ref)
@ -259,7 +259,7 @@ class Identity(identity.Driver):
@sql.truncated
def list_groups(self, hints):
session = db_session.get_session()
session = sql.get_session()
query = session.query(Group)
refs = sql.filter_limit_query(Group, query, hints)
return [ref.to_dict() for ref in refs]
@ -271,12 +271,12 @@ class Identity(identity.Driver):
return ref
def get_group(self, group_id):
session = db_session.get_session()
session = sql.get_session()
return self._get_group(session, group_id).to_dict()
@sql.handle_conflicts(conflict_type='group')
def update_group(self, group_id, group):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = self._get_group(session, group_id)
@ -291,7 +291,7 @@ class Identity(identity.Driver):
return ref.to_dict()
def delete_group(self, group_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = self._get_group(session, group_id)


@ -0,0 +1,2 @@
import six
six.add_move(six.MovedModule('mox', 'mox', 'mox3.mox'))


@ -15,43 +15,148 @@
"""Multiple DB API backend support.
Supported configuration options:
The following two parameters are in the 'database' group:
`backend`: DB backend name or full module path to DB backend module.
A DB backend module should implement a method named 'get_backend' which
takes no arguments. The method can return any object that implements DB
API methods.
"""
from oslo.config import cfg
import functools
import logging
import threading
import time
from keystone.openstack.common.db import exception
from keystone.openstack.common.gettextutils import _LE
from keystone.openstack.common import importutils
db_opts = [
cfg.StrOpt('backend',
default='sqlalchemy',
deprecated_name='db_backend',
deprecated_group='DEFAULT',
help='The backend to use for db'),
]
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
CONF.register_opts(db_opts, 'database')
def safe_for_db_retry(f):
"""Enable db-retry for decorated function, if config option enabled."""
f.__dict__['enable_retry'] = True
return f
class wrap_db_retry(object):
"""Retry db.api methods, if DBConnectionError() raised
Retry decorated db.api methods. If we enabled `use_db_reconnect`
in config, this decorator will be applied to all db.api functions,
marked with @safe_for_db_retry decorator.
Decorator catchs DBConnectionError() and retries function in a
loop until it succeeds, or until maximum retries count will be reached.
"""
def __init__(self, retry_interval, max_retries, inc_retry_interval,
max_retry_interval):
super(wrap_db_retry, self).__init__()
self.retry_interval = retry_interval
self.max_retries = max_retries
self.inc_retry_interval = inc_retry_interval
self.max_retry_interval = max_retry_interval
def __call__(self, f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
next_interval = self.retry_interval
remaining = self.max_retries
while True:
try:
return f(*args, **kwargs)
except exception.DBConnectionError as e:
if remaining == 0:
LOG.exception(_LE('DB exceeded retry limit.'))
raise exception.DBError(e)
if remaining != -1:
remaining -= 1
LOG.exception(_LE('DB connection error.'))
# NOTE(vsergeyev): We are using a patched time module, so
# this effectively yields the execution
# context to another green thread.
time.sleep(next_interval)
if self.inc_retry_interval:
next_interval = min(
next_interval * 2,
self.max_retry_interval
)
return wrapper
class DBAPI(object):
def __init__(self, backend_mapping=None):
if backend_mapping is None:
backend_mapping = {}
backend_name = CONF.database.backend
# Import the untranslated name if we don't have a
# mapping.
backend_path = backend_mapping.get(backend_name, backend_name)
backend_mod = importutils.import_module(backend_path)
self.__backend = backend_mod.get_backend()
def __init__(self, backend_name, backend_mapping=None, lazy=False,
**kwargs):
"""Initialize the chosen DB API backend.
:param backend_name: name of the backend to load
:type backend_name: str
:param backend_mapping: backend name -> module/class to load mapping
:type backend_mapping: dict
:param lazy: load the DB backend lazily on the first DB API method call
:type lazy: bool
Keyword arguments:
:keyword use_db_reconnect: retry DB transactions on disconnect or not
:type use_db_reconnect: bool
:keyword retry_interval: seconds between transaction retries
:type retry_interval: int
:keyword inc_retry_interval: increase retry interval or not
:type inc_retry_interval: bool
:keyword max_retry_interval: max interval value between retries
:type max_retry_interval: int
:keyword max_retries: max number of retries before an error is raised
:type max_retries: int
"""
self._backend = None
self._backend_name = backend_name
self._backend_mapping = backend_mapping or {}
self._lock = threading.Lock()
if not lazy:
self._load_backend()
self.use_db_reconnect = kwargs.get('use_db_reconnect', False)
self.retry_interval = kwargs.get('retry_interval', 1)
self.inc_retry_interval = kwargs.get('inc_retry_interval', True)
self.max_retry_interval = kwargs.get('max_retry_interval', 10)
self.max_retries = kwargs.get('max_retries', 20)
def _load_backend(self):
with self._lock:
if not self._backend:
# Import the untranslated name if we don't have a mapping
backend_path = self._backend_mapping.get(self._backend_name,
self._backend_name)
backend_mod = importutils.import_module(backend_path)
self._backend = backend_mod.get_backend()
def __getattr__(self, key):
return getattr(self.__backend, key)
if not self._backend:
self._load_backend()
attr = getattr(self._backend, key)
if not hasattr(attr, '__call__'):
return attr
# NOTE(vsergeyev): If the `use_db_reconnect` option is set to True,
# retry DB API methods decorated with @safe_for_db_retry
# on disconnect.
if self.use_db_reconnect and hasattr(attr, 'enable_retry'):
attr = wrap_db_retry(
retry_interval=self.retry_interval,
max_retries=self.max_retries,
inc_retry_interval=self.inc_retry_interval,
max_retry_interval=self.max_retry_interval)(attr)
return attr
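
The backend contract stays small: the module named by the backend
option (or resolved through backend_mapping) must expose get_backend().
A hedged sketch of a conforming backend and of opting in to retries
(the mybackend module and get_user method are hypothetical):

    # mybackend.py -- hypothetical backend module
    from keystone.openstack.common.db import api as db_api

    class _Backend(object):
        @db_api.safe_for_db_retry            # opt in to wrap_db_retry
        def get_user(self, user_id):
            return {'id': user_id}

    def get_backend():
        return _Backend()

    # caller side: retries apply only when use_db_reconnect is enabled
    dbapi = db_api.DBAPI('mybackend', lazy=True, use_db_reconnect=True)
    user = dbapi.get_user('123')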


@ -0,0 +1,168 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import copy
from oslo.config import cfg
database_opts = [
cfg.StrOpt('sqlite_db',
deprecated_group='DEFAULT',
default='keystone.sqlite',
help='The file name to use with SQLite'),
cfg.BoolOpt('sqlite_synchronous',
deprecated_group='DEFAULT',
default=True,
help='If True, SQLite uses synchronous mode'),
cfg.StrOpt('backend',
default='sqlalchemy',
deprecated_name='db_backend',
deprecated_group='DEFAULT',
help='The backend to use for db'),
cfg.StrOpt('connection',
help='The SQLAlchemy connection string used to connect to the '
'database',
secret=True,
deprecated_opts=[cfg.DeprecatedOpt('sql_connection',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_connection',
group='DATABASE'),
cfg.DeprecatedOpt('connection',
group='sql'), ]),
cfg.StrOpt('mysql_sql_mode',
help='The SQL mode to be used for MySQL sessions '
'(default is empty, meaning do not override '
'any server-side SQL mode setting)'),
cfg.IntOpt('idle_timeout',
default=3600,
deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_idle_timeout',
group='DATABASE'),
cfg.DeprecatedOpt('idle_timeout',
group='sql')],
help='Timeout before idle sql connections are reaped'),
cfg.IntOpt('min_pool_size',
default=1,
deprecated_opts=[cfg.DeprecatedOpt('sql_min_pool_size',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_min_pool_size',
group='DATABASE')],
help='Minimum number of SQL connections to keep open in a '
'pool'),
cfg.IntOpt('max_pool_size',
default=None,
deprecated_opts=[cfg.DeprecatedOpt('sql_max_pool_size',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_max_pool_size',
group='DATABASE')],
help='Maximum number of SQL connections to keep open in a '
'pool'),
cfg.IntOpt('max_retries',
default=10,
deprecated_opts=[cfg.DeprecatedOpt('sql_max_retries',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_max_retries',
group='DATABASE')],
help='Maximum db connection retries during startup. '
'(setting -1 implies an infinite retry count)'),
cfg.IntOpt('retry_interval',
default=10,
deprecated_opts=[cfg.DeprecatedOpt('sql_retry_interval',
group='DEFAULT'),
cfg.DeprecatedOpt('reconnect_interval',
group='DATABASE')],
help='Interval between retries of opening a sql connection'),
cfg.IntOpt('max_overflow',
default=None,
deprecated_opts=[cfg.DeprecatedOpt('sql_max_overflow',
group='DEFAULT'),
cfg.DeprecatedOpt('sqlalchemy_max_overflow',
group='DATABASE')],
help='If set, use this value for max_overflow with sqlalchemy'),
cfg.IntOpt('connection_debug',
default=0,
deprecated_opts=[cfg.DeprecatedOpt('sql_connection_debug',
group='DEFAULT')],
help='Verbosity of SQL debugging information. 0=None, '
'100=Everything'),
cfg.BoolOpt('connection_trace',
default=False,
deprecated_opts=[cfg.DeprecatedOpt('sql_connection_trace',
group='DEFAULT')],
help='Add python stack traces to SQL as comment strings'),
cfg.IntOpt('pool_timeout',
default=None,
deprecated_opts=[cfg.DeprecatedOpt('sqlalchemy_pool_timeout',
group='DATABASE')],
help='If set, use this value for pool_timeout with sqlalchemy'),
cfg.BoolOpt('use_db_reconnect',
default=False,
help='Enable the experimental use of database reconnect '
'on connection lost'),
cfg.IntOpt('db_retry_interval',
default=1,
help='seconds between db connection retries'),
cfg.BoolOpt('db_inc_retry_interval',
default=True,
help='Whether to increase interval between db connection '
'retries, up to db_max_retry_interval'),
cfg.IntOpt('db_max_retry_interval',
default=10,
help='max seconds between db connection retries, if '
'db_inc_retry_interval is enabled'),
cfg.IntOpt('db_max_retries',
default=20,
help='maximum db connection retries before error is raised. '
'(setting -1 implies an infinite retry count)'),
]
CONF = cfg.CONF
CONF.register_opts(database_opts, 'database')
def set_defaults(sql_connection, sqlite_db, max_pool_size=None,
max_overflow=None, pool_timeout=None):
"""Set defaults for configuration variables."""
cfg.set_defaults(database_opts,
connection=sql_connection,
sqlite_db=sqlite_db)
# Update the QueuePool defaults
if max_pool_size is not None:
cfg.set_defaults(database_opts,
max_pool_size=max_pool_size)
if max_overflow is not None:
cfg.set_defaults(database_opts,
max_overflow=max_overflow)
if pool_timeout is not None:
cfg.set_defaults(database_opts,
pool_timeout=pool_timeout)
def list_opts():
"""Returns a list of oslo.config options available in the library.
The returned list includes all oslo.config options which may be registered
at runtime by the library.
Each element of the list is a tuple. The first element is the name of the
group under which the list of elements in the second element will be
registered. A group name of None corresponds to the [DEFAULT] group in
config files.
The purpose of this is to allow tools like the Oslo sample config file
generator to discover the options exposed to users by this library.
:returns: a list of (group_name, opts) tuples
"""
return [('database', copy.deepcopy(database_opts))]
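
A hedged sketch of consuming this module (importing it registers the
options under the [database] group):

    from keystone.openstack.common.db import options as db_options

    # Adjust the library defaults before config files are parsed.
    db_options.set_defaults(sql_connection='sqlite:///keystone.db',
                            sqlite_db='keystone.db')

    # Enumerate the exposed options, e.g. for a sample-config generator.
    for group, opts in db_options.list_opts():
        print('%s: %s' % (group, sorted(opt.name for opt in opts)))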


@ -51,13 +51,9 @@ import sqlalchemy
from sqlalchemy.schema import UniqueConstraint
from keystone.openstack.common.db import exception
from keystone.openstack.common.db.sqlalchemy import session as db_session
from keystone.openstack.common.gettextutils import _
get_engine = db_session.get_engine
def _get_unique_constraints(self, table):
"""Retrieve information about existing unique constraints of the table
@ -172,11 +168,12 @@ def patch_migrate():
sqlite.SQLiteConstraintGenerator)
def db_sync(abs_path, version=None, init_version=0):
def db_sync(engine, abs_path, version=None, init_version=0):
"""Upgrade or downgrade a database.
Function runs the upgrade() or downgrade() functions in change scripts.
:param engine: SQLAlchemy engine instance for a given database
:param abs_path: Absolute path to migrate repository.
:param version: Database will upgrade/downgrade until this version.
If None - database will update to the latest
@ -190,32 +187,54 @@ def db_sync(abs_path, version=None, init_version=0):
raise exception.DbMigrationError(
message=_("version should be an integer"))
current_version = db_version(abs_path, init_version)
current_version = db_version(engine, abs_path, init_version)
repository = _find_migrate_repo(abs_path)
_db_schema_sanity_check(engine)
if version is None or version > current_version:
return versioning_api.upgrade(get_engine(), repository, version)
return versioning_api.upgrade(engine, repository, version)
else:
return versioning_api.downgrade(get_engine(), repository,
return versioning_api.downgrade(engine, repository,
version)
def db_version(abs_path, init_version):
def _db_schema_sanity_check(engine):
"""Ensure all database tables were created with required parameters.
:param engine: SQLAlchemy engine instance for a given database
"""
if engine.name == 'mysql':
onlyutf8_sql = ('SELECT TABLE_NAME,TABLE_COLLATION '
'from information_schema.TABLES '
'where TABLE_SCHEMA=%s and '
'TABLE_COLLATION NOT LIKE "%%utf8%%"')
table_names = [res[0] for res in engine.execute(onlyutf8_sql,
engine.url.database)]
if len(table_names) > 0:
raise ValueError(_('Tables "%s" have non utf8 collation, '
'please make sure all tables are CHARSET=utf8'
) % ','.join(table_names))
def db_version(engine, abs_path, init_version):
"""Show the current version of the repository.
:param engine: SQLAlchemy engine instance for a given database
:param abs_path: Absolute path to migrate repository
:param init_version: Initial database version
"""
repository = _find_migrate_repo(abs_path)
try:
return versioning_api.db_version(get_engine(), repository)
return versioning_api.db_version(engine, repository)
except versioning_exceptions.DatabaseNotControlledError:
meta = sqlalchemy.MetaData()
engine = get_engine()
meta.reflect(bind=engine)
tables = meta.tables
if len(tables) == 0:
db_version_control(abs_path, init_version)
return versioning_api.db_version(get_engine(), repository)
if len(tables) == 0 or 'alembic_version' in tables:
db_version_control(engine, abs_path, version=init_version)
return versioning_api.db_version(engine, repository)
else:
raise exception.DbMigrationError(
message=_(
@ -224,17 +243,18 @@ def db_version(abs_path, init_version):
"manually."))
def db_version_control(abs_path, version=None):
def db_version_control(engine, abs_path, version=None):
"""Mark a database as under this repository's version control.
Once a database is under version control, schema changes should
only be done via change scripts in this repository.
:param engine: SQLAlchemy engine instance for a given database
:param abs_path: Absolute path to migrate repository
:param version: Initial database version
"""
repository = _find_migrate_repo(abs_path)
versioning_api.version_control(get_engine(), repository, version)
versioning_api.version_control(engine, repository, version)
return version
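
All three entry points now take the engine as their first positional
argument; a hedged sketch of the call chain (get_engine and abs_path
are assumed to be supplied by the caller):

    engine = get_engine()                    # e.g. from an EngineFacade
    migration.db_version_control(engine, abs_path, version=0)
    migration.db_sync(engine, abs_path)      # upgrade to the latest version
    current = migration.db_version(engine, abs_path, init_version=0)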


@ -26,7 +26,6 @@ from sqlalchemy import Column, Integer
from sqlalchemy import DateTime
from sqlalchemy.orm import object_mapper
from keystone.openstack.common.db.sqlalchemy import session as sa
from keystone.openstack.common import timeutils
@ -34,10 +33,9 @@ class ModelBase(object):
"""Base class for models."""
__table_initialized__ = False
def save(self, session=None):
def save(self, session):
"""Save this object."""
if not session:
session = sa.get_session()
# NOTE(boris-42): This part of code should look like:
# session.add(self)
# session.flush()
@ -102,15 +100,15 @@ class ModelBase(object):
class TimestampMixin(object):
created_at = Column(DateTime, default=timeutils.utcnow)
updated_at = Column(DateTime, onupdate=timeutils.utcnow)
created_at = Column(DateTime, default=lambda: timeutils.utcnow())
updated_at = Column(DateTime, onupdate=lambda: timeutils.utcnow())
class SoftDeleteMixin(object):
deleted_at = Column(DateTime)
deleted = Column(Integer, default=0)
def soft_delete(self, session=None):
def soft_delete(self, session):
"""Mark this object as deleted."""
self.deleted = self.id
self.deleted_at = timeutils.utcnow()
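
save() and soft_delete() no longer fall back to a global session
factory; the caller must supply a session. A hedged sketch (facade and
instance are assumed):

    session = facade.get_session()    # facade: an EngineFacade instance
    instance.save(session)            # runs in its own subtransaction
    instance.soft_delete(session)     # sets deleted/deleted_at, then saves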


@ -1,5 +1,3 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2013 Mirantis.inc
# All Rights Reserved.
#


@ -16,19 +16,6 @@
"""Session Handling for SQLAlchemy backend.
Initializing:
* Call `set_defaults()` with the minimal of the following kwargs:
``sql_connection``, ``sqlite_db``
Example:
.. code:: python
session.set_defaults(
sql_connection="sqlite:///var/lib/keystone/sqlite.db",
sqlite_db="/var/lib/keystone/sqlite.db")
Recommended ways to use sessions within this framework:
* Don't use them explicitly; this is like running with ``AUTOCOMMIT=1``.
@ -87,7 +74,7 @@ Recommended ways to use sessions within this framework:
.. code:: python
def create_many_foo(context, foos):
session = get_session()
session = sessionmaker()
with session.begin():
for foo in foos:
foo_ref = models.Foo()
@ -95,7 +82,7 @@ Recommended ways to use sessions within this framework:
session.add(foo_ref)
def update_bar(context, foo_id, newbar):
session = get_session()
session = sessionmaker()
with session.begin():
foo_ref = (model_query(context, models.Foo, session).
filter_by(id=foo_id).
@ -124,7 +111,9 @@ Recommended ways to use sessions within this framework:
filter_by(id=subq.as_scalar()).
update({'bar': newbar}))
For reference, this emits approximately the following SQL statement::
For reference, this emits approximately the following SQL statement:
.. code:: sql
UPDATE bar SET bar = ${newbar}
WHERE id=(SELECT bar_id FROM foo WHERE id = ${foo_id} LIMIT 1);
@ -140,7 +129,7 @@ Recommended ways to use sessions within this framework:
foo1 = models.Foo()
foo2 = models.Foo()
foo1.id = foo2.id = 1
session = get_session()
session = sessionmaker()
try:
with session.begin():
session.add(foo1)
@ -166,7 +155,7 @@ Recommended ways to use sessions within this framework:
.. code:: python
def myfunc(foo):
session = get_session()
session = sessionmaker()
with session.begin():
# do some database things
bar = _private_func(foo, session)
@ -174,7 +163,7 @@ Recommended ways to use sessions within this framework:
def _private_func(foo, session=None):
if not session:
session = get_session()
session = sessionmaker()
with session.begin(subtransaction=True):
# do some other database things
return bar
@ -238,7 +227,7 @@ Efficient use of soft deletes:
def complex_soft_delete_with_synchronization_bar(session=None):
if session is None:
session = get_session()
session = sessionmaker()
with session.begin(subtransactions=True):
count = (model_query(BarModel).
find(some_condition).
@ -255,7 +244,7 @@ Efficient use of soft deletes:
.. code:: python
def soft_delete_bar_model():
session = get_session()
session = sessionmaker()
with session.begin():
bar_ref = model_query(BarModel).find(some_condition).first()
# Work with bar_ref
@ -267,7 +256,7 @@ Efficient use of soft deletes:
.. code:: python
def soft_delete_multi_models():
session = get_session()
session = sessionmaker()
with session.begin():
query = (model_query(BarModel, session=session).
find(some_condition))
@ -291,11 +280,9 @@ Efficient use of soft deletes:
import functools
import logging
import os.path
import re
import time
from oslo.config import cfg
import six
from sqlalchemy import exc as sqla_exc
from sqlalchemy.interfaces import PoolListener
@ -304,150 +291,12 @@ from sqlalchemy.pool import NullPool, StaticPool
from sqlalchemy.sql.expression import literal_column
from keystone.openstack.common.db import exception
from keystone.openstack.common.gettextutils import _
from keystone.openstack.common.gettextutils import _LE, _LW, _LI
from keystone.openstack.common import timeutils
sqlite_db_opts = [
cfg.StrOpt('sqlite_db',
default='keystone.sqlite',
help='The file name to use with SQLite'),
cfg.BoolOpt('sqlite_synchronous',
default=True,
help='If True, SQLite uses synchronous mode'),
]
database_opts = [
cfg.StrOpt('connection',
default='sqlite:///' +
os.path.abspath(os.path.join(os.path.dirname(__file__),
'../', '$sqlite_db')),
help='The SQLAlchemy connection string used to connect to the '
'database',
secret=True,
deprecated_opts=[cfg.DeprecatedOpt('sql_connection',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_connection',
group='DATABASE'),
cfg.DeprecatedOpt('connection',
group='sql'), ]),
cfg.StrOpt('slave_connection',
default='',
secret=True,
help='The SQLAlchemy connection string used to connect to the '
'slave database'),
cfg.IntOpt('idle_timeout',
default=3600,
deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_idle_timeout',
group='DATABASE'),
cfg.DeprecatedOpt('idle_timeout',
group='sql')],
help='Timeout before idle sql connections are reaped'),
cfg.IntOpt('min_pool_size',
default=1,
deprecated_opts=[cfg.DeprecatedOpt('sql_min_pool_size',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_min_pool_size',
group='DATABASE')],
help='Minimum number of SQL connections to keep open in a '
'pool'),
cfg.IntOpt('max_pool_size',
default=None,
deprecated_opts=[cfg.DeprecatedOpt('sql_max_pool_size',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_max_pool_size',
group='DATABASE')],
help='Maximum number of SQL connections to keep open in a '
'pool'),
cfg.IntOpt('max_retries',
default=10,
deprecated_opts=[cfg.DeprecatedOpt('sql_max_retries',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_max_retries',
group='DATABASE')],
help='Maximum db connection retries during startup. '
'(setting -1 implies an infinite retry count)'),
cfg.IntOpt('retry_interval',
default=10,
deprecated_opts=[cfg.DeprecatedOpt('sql_retry_interval',
group='DEFAULT'),
cfg.DeprecatedOpt('reconnect_interval',
group='DATABASE')],
help='Interval between retries of opening a sql connection'),
cfg.IntOpt('max_overflow',
default=None,
deprecated_opts=[cfg.DeprecatedOpt('sql_max_overflow',
group='DEFAULT'),
cfg.DeprecatedOpt('sqlalchemy_max_overflow',
group='DATABASE')],
help='If set, use this value for max_overflow with sqlalchemy'),
cfg.IntOpt('connection_debug',
default=0,
deprecated_opts=[cfg.DeprecatedOpt('sql_connection_debug',
group='DEFAULT')],
help='Verbosity of SQL debugging information. 0=None, '
'100=Everything'),
cfg.BoolOpt('connection_trace',
default=False,
deprecated_opts=[cfg.DeprecatedOpt('sql_connection_trace',
group='DEFAULT')],
help='Add python stack traces to SQL as comment strings'),
cfg.IntOpt('pool_timeout',
default=None,
deprecated_opts=[cfg.DeprecatedOpt('sqlalchemy_pool_timeout',
group='DATABASE')],
help='If set, use this value for pool_timeout with sqlalchemy'),
]
CONF = cfg.CONF
CONF.register_opts(sqlite_db_opts)
CONF.register_opts(database_opts, 'database')
LOG = logging.getLogger(__name__)
_ENGINE = None
_MAKER = None
_SLAVE_ENGINE = None
_SLAVE_MAKER = None
def set_defaults(sql_connection, sqlite_db, max_pool_size=None,
max_overflow=None, pool_timeout=None):
"""Set defaults for configuration variables."""
cfg.set_defaults(database_opts,
connection=sql_connection)
cfg.set_defaults(sqlite_db_opts,
sqlite_db=sqlite_db)
# Update the QueuePool defaults
if max_pool_size is not None:
cfg.set_defaults(database_opts,
max_pool_size=max_pool_size)
if max_overflow is not None:
cfg.set_defaults(database_opts,
max_overflow=max_overflow)
if pool_timeout is not None:
cfg.set_defaults(database_opts,
pool_timeout=pool_timeout)
def cleanup():
global _ENGINE, _MAKER
global _SLAVE_ENGINE, _SLAVE_MAKER
if _MAKER:
_MAKER.close_all()
_MAKER = None
if _ENGINE:
_ENGINE.dispose()
_ENGINE = None
if _SLAVE_MAKER:
_SLAVE_MAKER.close_all()
_SLAVE_MAKER = None
if _SLAVE_ENGINE:
_SLAVE_ENGINE.dispose()
_SLAVE_ENGINE = None
class SqliteForeignKeysListener(PoolListener):
"""Ensures that the foreign key constraints are enforced in SQLite.
@ -460,30 +309,6 @@ class SqliteForeignKeysListener(PoolListener):
dbapi_con.execute('pragma foreign_keys=ON')
def get_session(autocommit=True, expire_on_commit=False, sqlite_fk=False,
slave_session=False, mysql_traditional_mode=False):
"""Return a SQLAlchemy session."""
global _MAKER
global _SLAVE_MAKER
maker = _MAKER
if slave_session:
maker = _SLAVE_MAKER
if maker is None:
engine = get_engine(sqlite_fk=sqlite_fk, slave_engine=slave_session,
mysql_traditional_mode=mysql_traditional_mode)
maker = get_maker(engine, autocommit, expire_on_commit)
if slave_session:
_SLAVE_MAKER = maker
else:
_MAKER = maker
session = maker()
return session
# note(boris-42): In current versions of DB backends unique constraint
# violation messages follow the structure:
#
@ -492,9 +317,9 @@ def get_session(autocommit=True, expire_on_commit=False, sqlite_fk=False,
# N columns - (IntegrityError) column c1, c2, ..., N are not unique
#
# sqlite since 3.7.16:
# 1 column - (IntegrityError) UNIQUE constraint failed: k1
# 1 column - (IntegrityError) UNIQUE constraint failed: tbl.k1
#
# N columns - (IntegrityError) UNIQUE constraint failed: k1, k2
# N columns - (IntegrityError) UNIQUE constraint failed: tbl.k1, tbl.k2
#
# postgres:
# 1 column - (IntegrityError) duplicate key value violates unique
@ -507,11 +332,20 @@ def get_session(autocommit=True, expire_on_commit=False, sqlite_fk=False,
# 'c1'")
# N columns - (IntegrityError) (1062, "Duplicate entry 'values joined
# with -' for key 'name_of_our_constraint'")
#
# ibm_db_sa:
# N columns - (IntegrityError) SQL0803N One or more values in the INSERT
# statement, UPDATE statement, or foreign key update caused by a
# DELETE statement are not valid because the primary key, unique
# constraint or unique index identified by "2" constrains table
# "NOVA.KEY_PAIRS" from having duplicate values for the index
# key.
_DUP_KEY_RE_DB = {
"sqlite": (re.compile(r"^.*columns?([^)]+)(is|are)\s+not\s+unique$"),
re.compile(r"^.*UNIQUE\s+constraint\s+failed:\s+(.+)$")),
"postgresql": (re.compile(r"^.*duplicate\s+key.*\"([^\"]+)\"\s*\n.*$"),),
"mysql": (re.compile(r"^.*\(1062,.*'([^\']+)'\"\)$"),)
"mysql": (re.compile(r"^.*\(1062,.*'([^\']+)'\"\)$"),),
"ibm_db_sa": (re.compile(r"^.*SQL0803N.*$"),),
}
@ -533,7 +367,7 @@ def _raise_if_duplicate_entry_error(integrity_error, engine_name):
return [columns]
return columns[len(uniqbase):].split("0")[1:]
if engine_name not in ["mysql", "sqlite", "postgresql"]:
if engine_name not in ["ibm_db_sa", "mysql", "sqlite", "postgresql"]:
return
# FIXME(johannes): The usage of the .message attribute has been
@ -548,10 +382,15 @@ def _raise_if_duplicate_entry_error(integrity_error, engine_name):
else:
return
columns = match.group(1)
# NOTE(mriedem): The ibm_db_sa integrity error message doesn't provide the
# columns so we have to omit that from the DBDuplicateEntry error.
columns = ''
if engine_name != 'ibm_db_sa':
columns = match.group(1)
if engine_name == "sqlite":
columns = columns.strip().split(", ")
columns = [c.split('.')[-1] for c in columns.strip().split(", ")]
else:
columns = get_columns_from_uniq_cons_or_name(columns)
raise exception.DBDuplicateEntry(columns, integrity_error)
@ -589,57 +428,39 @@ def _raise_if_deadlock_error(operational_error, engine_name):
def _wrap_db_error(f):
    #TODO(rpodolyaka): in a subsequent commit make this a class decorator to
    # ensure it can only be applied to instances of Session subclasses (as we
    # use the Session instance's bind attribute below)
@functools.wraps(f)
def _wrap(*args, **kwargs):
def _wrap(self, *args, **kwargs):
try:
return f(*args, **kwargs)
return f(self, *args, **kwargs)
except UnicodeEncodeError:
raise exception.DBInvalidUnicodeParameter()
# note(boris-42): We should catch unique constraint violation and
        # wrap it in our own DBDuplicateEntry exception. Unique constraint
# violation is wrapped by IntegrityError.
except sqla_exc.OperationalError as e:
_raise_if_deadlock_error(e, get_engine().name)
_raise_if_db_connection_lost(e, self.bind)
_raise_if_deadlock_error(e, self.bind.dialect.name)
# NOTE(comstud): A lot of code is checking for OperationalError
# so let's not wrap it for now.
raise
# note(boris-42): We should catch unique constraint violation and
        # wrap it in our own DBDuplicateEntry exception. Unique constraint
# violation is wrapped by IntegrityError.
except sqla_exc.IntegrityError as e:
        # note(boris-42): SqlAlchemy doesn't unify errors from different
        # DBs so we must do this. Also some tables (for example
        # instance_types) have more than one unique constraint. This
        # means we should get the names of the columns whose values
        # violate the unique constraint from the error message.
_raise_if_duplicate_entry_error(e, get_engine().name)
_raise_if_duplicate_entry_error(e, self.bind.dialect.name)
raise exception.DBError(e)
except Exception as e:
LOG.exception(_('DB exception wrapped.'))
LOG.exception(_LE('DB exception wrapped.'))
raise exception.DBError(e)
return _wrap
def get_engine(sqlite_fk=False, slave_engine=False,
mysql_traditional_mode=False):
"""Return a SQLAlchemy engine."""
global _ENGINE
global _SLAVE_ENGINE
engine = _ENGINE
db_uri = CONF.database.connection
if slave_engine:
engine = _SLAVE_ENGINE
db_uri = CONF.database.slave_connection
if engine is None:
engine = create_engine(db_uri, sqlite_fk=sqlite_fk,
mysql_traditional_mode=mysql_traditional_mode)
if slave_engine:
_SLAVE_ENGINE = engine
else:
_ENGINE = engine
return engine
def _synchronous_switch_listener(dbapi_conn, connection_rec):
"""Switch sqlite connections to non-synchronous mode."""
dbapi_conn.execute("PRAGMA synchronous = OFF")
@ -681,7 +502,7 @@ def _ping_listener(engine, dbapi_conn, connection_rec, connection_proxy):
cursor.execute(ping_sql)
except Exception as ex:
if engine.dialect.is_disconnect(ex, dbapi_conn, cursor):
msg = _('Database server has gone away: %s') % ex
msg = _LW('Database server has gone away: %s') % ex
LOG.warning(msg)
raise sqla_exc.DisconnectionError(msg)
else:
@ -696,7 +517,44 @@ def _set_mode_traditional(dbapi_con, connection_rec, connection_proxy):
    than a declared field, with just a warning. That is fraught with data
    corruption.
"""
dbapi_con.cursor().execute("SET SESSION sql_mode = TRADITIONAL;")
_set_session_sql_mode(dbapi_con, connection_rec,
connection_proxy, 'TRADITIONAL')
def _set_session_sql_mode(dbapi_con, connection_rec,
connection_proxy, sql_mode=None):
"""Set the sql_mode session variable.
MySQL supports several server modes. The default is None, but sessions
may choose to enable server modes like TRADITIONAL, ANSI,
several STRICT_* modes and others.
Note: passing in '' (empty string) for sql_mode clears
the SQL mode for the session, overriding a potentially set
server default. Passing in None (the default) makes this
a no-op, meaning if a server-side SQL mode is set, it still applies.
"""
cursor = dbapi_con.cursor()
if sql_mode is not None:
cursor.execute("SET SESSION sql_mode = %s", [sql_mode])
# Check against the real effective SQL mode. Even when unset by
# our own config, the server may still be operating in a specific
# SQL mode as set by the server configuration
cursor.execute("SHOW VARIABLES LIKE 'sql_mode'")
row = cursor.fetchone()
if row is None:
LOG.warning(_LW('Unable to detect effective SQL mode'))
return
realmode = row[1]
LOG.info(_LI('MySQL server mode set to %s') % realmode)
# 'TRADITIONAL' mode enables several other modes, so
# we need a substring match here
if not ('TRADITIONAL' in realmode.upper() or
'STRICT_ALL_TABLES' in realmode.upper()):
LOG.warning(_LW("MySQL SQL mode is '%s', "
"consider enabling TRADITIONAL or STRICT_ALL_TABLES")
% realmode)
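A brief usage sketch: the mode is normally requested through create_engine() further down, which registers this listener for pool checkout (the URI and credentials here are placeholders):

engine = create_engine('mysql://user:secret@localhost/keystone',
                       mysql_sql_mode='TRADITIONAL')
# Every connection checked out of the pool now runs
# "SET SESSION sql_mode = 'TRADITIONAL'" and logs the effective mode.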
def _is_db_connection_error(args):
@ -711,69 +569,79 @@ def _is_db_connection_error(args):
return False
def create_engine(sql_connection, sqlite_fk=False,
mysql_traditional_mode=False):
def _raise_if_db_connection_lost(error, engine):
    # NOTE(vsergeyev): Function is_disconnect(e, connection, cursor)
    #                  requires a connection and cursor as incoming parameters,
    #                  but we have no way to create a connection if the DB
    #                  is not available, so in that case reconnect fails.
    #                  But is_disconnect() ignores these parameters, so it
    #                  makes sense to pass the function None as placeholders
    #                  instead of a connection and cursor.
if engine.dialect.is_disconnect(error, None, None):
raise exception.DBConnectionError(error)
def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None,
mysql_traditional_mode=False, idle_timeout=3600,
connection_debug=0, max_pool_size=None, max_overflow=None,
pool_timeout=None, sqlite_synchronous=True,
connection_trace=False, max_retries=10, retry_interval=10):
"""Return a new SQLAlchemy engine."""
# NOTE(geekinutah): At this point we could be connecting to the normal
# db handle or the slave db handle. Things like
# _wrap_db_error aren't going to work well if their
# backends don't match. Let's check.
_assert_matching_drivers()
connection_dict = sqlalchemy.engine.url.make_url(sql_connection)
engine_args = {
"pool_recycle": CONF.database.idle_timeout,
"echo": False,
"pool_recycle": idle_timeout,
'convert_unicode': True,
}
# Map our SQL debug level to SQLAlchemy's options
if CONF.database.connection_debug >= 100:
engine_args['echo'] = 'debug'
elif CONF.database.connection_debug >= 50:
engine_args['echo'] = True
logger = logging.getLogger('sqlalchemy.engine')
# Map SQL debug level to Python log level
if connection_debug >= 100:
logger.setLevel(logging.DEBUG)
elif connection_debug >= 50:
logger.setLevel(logging.INFO)
else:
logger.setLevel(logging.WARNING)
if "sqlite" in connection_dict.drivername:
if sqlite_fk:
engine_args["listeners"] = [SqliteForeignKeysListener()]
engine_args["poolclass"] = NullPool
if CONF.database.connection == "sqlite://":
if sql_connection == "sqlite://":
engine_args["poolclass"] = StaticPool
engine_args["connect_args"] = {'check_same_thread': False}
else:
if CONF.database.max_pool_size is not None:
engine_args['pool_size'] = CONF.database.max_pool_size
if CONF.database.max_overflow is not None:
engine_args['max_overflow'] = CONF.database.max_overflow
if CONF.database.pool_timeout is not None:
engine_args['pool_timeout'] = CONF.database.pool_timeout
if max_pool_size is not None:
engine_args['pool_size'] = max_pool_size
if max_overflow is not None:
engine_args['max_overflow'] = max_overflow
if pool_timeout is not None:
engine_args['pool_timeout'] = pool_timeout
engine = sqlalchemy.create_engine(sql_connection, **engine_args)
sqlalchemy.event.listen(engine, 'checkin', _thread_yield)
if engine.name in ['mysql', 'ibm_db_sa']:
callback = functools.partial(_ping_listener, engine)
sqlalchemy.event.listen(engine, 'checkout', callback)
ping_callback = functools.partial(_ping_listener, engine)
sqlalchemy.event.listen(engine, 'checkout', ping_callback)
if engine.name == 'mysql':
if mysql_traditional_mode:
sqlalchemy.event.listen(engine, 'checkout',
_set_mode_traditional)
else:
LOG.warning(_("This application has not enabled MySQL "
"traditional mode, which means silent "
"data corruption may occur. "
"Please encourage the application "
"developers to enable this mode."))
mysql_sql_mode = 'TRADITIONAL'
if mysql_sql_mode:
mode_callback = functools.partial(_set_session_sql_mode,
sql_mode=mysql_sql_mode)
sqlalchemy.event.listen(engine, 'checkout', mode_callback)
elif 'sqlite' in connection_dict.drivername:
if not CONF.sqlite_synchronous:
if not sqlite_synchronous:
sqlalchemy.event.listen(engine, 'connect',
_synchronous_switch_listener)
sqlalchemy.event.listen(engine, 'connect', _add_regexp_listener)
if (CONF.database.connection_trace and
engine.dialect.dbapi.__name__ == 'MySQLdb'):
if connection_trace and engine.dialect.dbapi.__name__ == 'MySQLdb':
_patch_mysqldb_with_stacktrace_comments()
try:
@ -782,15 +650,15 @@ def create_engine(sql_connection, sqlite_fk=False,
if not _is_db_connection_error(e.args[0]):
raise
remaining = CONF.database.max_retries
remaining = max_retries
if remaining == -1:
remaining = 'infinite'
while True:
msg = _('SQL connection failed. %s attempts left.')
msg = _LW('SQL connection failed. %s attempts left.')
LOG.warning(msg % remaining)
if remaining != 'infinite':
remaining -= 1
time.sleep(CONF.database.retry_interval)
time.sleep(retry_interval)
try:
engine.connect()
break
@ -877,13 +745,117 @@ def _patch_mysqldb_with_stacktrace_comments():
setattr(MySQLdb.cursors.BaseCursor, '_do_query', _do_query)
def _assert_matching_drivers():
    """Make sure slave handle and normal handle have the same driver."""
    # NOTE(geekinutah): There's no use case for writing to one backend and
    #                   reading from another. Who knows what the future holds?
    if CONF.database.slave_connection == '':
        return

    normal = sqlalchemy.engine.url.make_url(CONF.database.connection)
    slave = sqlalchemy.engine.url.make_url(CONF.database.slave_connection)
    assert normal.drivername == slave.drivername


class EngineFacade(object):
    """A helper class for removing global engine instances from keystone.db.
As a library, keystone.db can't decide where to store/when to create engine
and sessionmaker instances, so this must be left for a target application.
On the other hand, in order to simplify the adoption of keystone.db changes,
we'll provide a helper class, which creates engine and sessionmaker
on its instantiation and provides get_engine()/get_session() methods
that are compatible with corresponding utility functions that currently
exist in target projects, e.g. in Nova.
engine/sessionmaker instances will still be global (and they are meant to
    be global), but they will be stored in the app context, rather than in the
keystone.db context.
    Note: use of this helper is completely optional and you are encouraged to
integrate engine/sessionmaker instances into your apps any way you like
(e.g. one might want to bind a session to a request context). Two important
things to remember:
1. An Engine instance is effectively a pool of DB connections, so it's
meant to be shared (and it's thread-safe).
2. A Session instance is not meant to be shared and represents a DB
transactional context (i.e. it's not thread-safe). sessionmaker is
a factory of sessions.
"""
def __init__(self, sql_connection,
sqlite_fk=False, mysql_sql_mode=None,
autocommit=True, expire_on_commit=False, **kwargs):
"""Initialize engine and sessionmaker instances.
:param sqlite_fk: enable foreign keys in SQLite
:type sqlite_fk: bool
:param mysql_sql_mode: set SQL mode in MySQL
:type mysql_sql_mode: string
:param autocommit: use autocommit mode for created Session instances
:type autocommit: bool
:param expire_on_commit: expire session objects on commit
:type expire_on_commit: bool
Keyword arguments:
:keyword idle_timeout: timeout before idle sql connections are reaped
(defaults to 3600)
:keyword connection_debug: verbosity of SQL debugging information.
0=None, 100=Everything (defaults to 0)
:keyword max_pool_size: maximum number of SQL connections to keep open
in a pool (defaults to SQLAlchemy settings)
:keyword max_overflow: if set, use this value for max_overflow with
sqlalchemy (defaults to SQLAlchemy settings)
:keyword pool_timeout: if set, use this value for pool_timeout with
sqlalchemy (defaults to SQLAlchemy settings)
:keyword sqlite_synchronous: if True, SQLite uses synchronous mode
(defaults to True)
:keyword connection_trace: add python stack traces to SQL as comment
strings (defaults to False)
:keyword max_retries: maximum db connection retries during startup.
(setting -1 implies an infinite retry count)
(defaults to 10)
:keyword retry_interval: interval between retries of opening a sql
connection (defaults to 10)
"""
super(EngineFacade, self).__init__()
self._engine = create_engine(
sql_connection=sql_connection,
sqlite_fk=sqlite_fk,
mysql_sql_mode=mysql_sql_mode,
idle_timeout=kwargs.get('idle_timeout', 3600),
connection_debug=kwargs.get('connection_debug', 0),
max_pool_size=kwargs.get('max_pool_size'),
max_overflow=kwargs.get('max_overflow'),
pool_timeout=kwargs.get('pool_timeout'),
sqlite_synchronous=kwargs.get('sqlite_synchronous', True),
connection_trace=kwargs.get('connection_trace', False),
max_retries=kwargs.get('max_retries', 10),
retry_interval=kwargs.get('retry_interval', 10))
self._session_maker = get_maker(
engine=self._engine,
autocommit=autocommit,
expire_on_commit=expire_on_commit)
def get_engine(self):
"""Get the engine instance (note, that it's shared)."""
return self._engine
def get_session(self, **kwargs):
"""Get a Session instance.
If passed, keyword arguments values override the ones used when the
sessionmaker instance was created.
:keyword autocommit: use autocommit mode for created Session instances
:type autocommit: bool
:keyword expire_on_commit: expire session objects on commit
:type expire_on_commit: bool
"""
        # Iterate over a copy of the keys: deleting entries from a dict
        # while iterating over it raises RuntimeError.
        for arg in list(kwargs):
            if arg not in ('autocommit', 'expire_on_commit'):
                del kwargs[arg]
return self._session_maker(**kwargs)
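A minimal usage sketch for the facade, assuming a placeholder connection URI; an application would typically create one instance at startup and share it:

facade = EngineFacade('sqlite:///keystone.db',
                      autocommit=True,
                      expire_on_commit=False)
engine = facade.get_engine()    # shared, thread-safe engine
session = facade.get_session()  # new session per unit of work
with session.begin():
    session.add(ref)            # ref: a placeholder model instance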


@ -0,0 +1,149 @@
# Copyright (c) 2013 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import abc
import functools
import os
import fixtures
import six
from keystone.openstack.common.db.sqlalchemy import session
from keystone.openstack.common.db.sqlalchemy import utils
from keystone.openstack.common import test
class DbFixture(fixtures.Fixture):
"""Basic database fixture.
    Allows tests to run on various db backends, such as SQLite, MySQL and
    PostgreSQL. The sqlite backend is used by default. To override it, set
    the OS_TEST_DBAPI_CONNECTION environment variable to a connection URI
    with database admin credentials for the specific backend.
"""
def _get_uri(self):
return os.getenv('OS_TEST_DBAPI_CONNECTION', 'sqlite://')
def __init__(self, test):
super(DbFixture, self).__init__()
self.test = test
def setUp(self):
super(DbFixture, self).setUp()
self.test.engine = session.create_engine(self._get_uri())
self.test.sessionmaker = session.get_maker(self.test.engine)
self.addCleanup(self.test.engine.dispose)
class DbTestCase(test.BaseTestCase):
"""Base class for testing of DB code.
    Uses `DbFixture`. Intended to be the main database test case, running all
    the tests on a given backend with a user-defined URI. Backend-specific
    tests should be decorated with the `backend_specific` decorator.
"""
FIXTURE = DbFixture
def setUp(self):
super(DbTestCase, self).setUp()
self.useFixture(self.FIXTURE(self))
ALLOWED_DIALECTS = ['sqlite', 'mysql', 'postgresql']
def backend_specific(*dialects):
"""Decorator to skip backend specific tests on inappropriate engines.
::dialects: list of dialects names under which the test will be launched.
"""
def wrap(f):
@functools.wraps(f)
def ins_wrap(self):
if not set(dialects).issubset(ALLOWED_DIALECTS):
raise ValueError(
"Please use allowed dialects: %s" % ALLOWED_DIALECTS)
if self.engine.name not in dialects:
msg = ('The test "%s" can be run '
'only on %s. Current engine is %s.')
args = (f.__name__, ' '.join(dialects), self.engine.name)
self.skip(msg % args)
else:
return f(self)
return ins_wrap
return wrap
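For example, a hypothetical test restricted to MySQL and PostgreSQL:

class DialectSpecificTestCase(DbTestCase):

    @backend_specific('mysql', 'postgresql')
    def test_engine_dialect(self):
        # Skipped automatically when the engine is e.g. sqlite.
        self.assertIn(self.engine.name, ('mysql', 'postgresql'))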
@six.add_metaclass(abc.ABCMeta)
class OpportunisticFixture(DbFixture):
"""Base fixture to use default CI databases.
    The databases exist in the OpenStack CI infrastructure. For correct
    functioning in a local environment, the databases must be created
    manually.
"""
DRIVER = abc.abstractproperty(lambda: None)
DBNAME = PASSWORD = USERNAME = 'openstack_citest'
def _get_uri(self):
return utils.get_connect_string(backend=self.DRIVER,
user=self.USERNAME,
passwd=self.PASSWORD,
database=self.DBNAME)
@six.add_metaclass(abc.ABCMeta)
class OpportunisticTestCase(DbTestCase):
"""Base test case to use default CI databases.
    Subclasses of this test case run only when the openstack_citest
    database is available; otherwise the tests are skipped.
"""
FIXTURE = abc.abstractproperty(lambda: None)
def setUp(self):
credentials = {
'backend': self.FIXTURE.DRIVER,
'user': self.FIXTURE.USERNAME,
'passwd': self.FIXTURE.PASSWORD,
'database': self.FIXTURE.DBNAME}
if self.FIXTURE.DRIVER and not utils.is_backend_avail(**credentials):
msg = '%s backend is not available.' % self.FIXTURE.DRIVER
return self.skip(msg)
super(OpportunisticTestCase, self).setUp()
class MySQLOpportunisticFixture(OpportunisticFixture):
DRIVER = 'mysql'
class PostgreSQLOpportunisticFixture(OpportunisticFixture):
DRIVER = 'postgresql'
class MySQLOpportunisticTestCase(OpportunisticTestCase):
FIXTURE = MySQLOpportunisticFixture
class PostgreSQLOpportunisticTestCase(OpportunisticTestCase):
FIXTURE = PostgreSQLOpportunisticFixture
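A subclass only needs test methods; this sketch (class name illustrative) runs when a MySQL openstack_citest database is reachable and is skipped otherwise:

class MySQLConnectTestCase(MySQLOpportunisticTestCase):

    def test_connect(self):
        # self.engine is built by the fixture from the CI credentials.
        self.assertEqual('mysql', self.engine.name)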


@ -15,83 +15,43 @@
# under the License.
import functools
import logging
import os
import subprocess
import lockfile
from six import moves
from six.moves.urllib import parse
import sqlalchemy
import sqlalchemy.exc
from keystone.openstack.common.gettextutils import _
from keystone.openstack.common import log as logging
from keystone.openstack.common.py3kcompat import urlutils
from keystone.openstack.common.db.sqlalchemy import utils
from keystone.openstack.common.gettextutils import _LE
from keystone.openstack.common import test
LOG = logging.getLogger(__name__)
def _get_connect_string(backend, user, passwd, database):
"""Get database connection
Try to get a connection with a very specific set of values, if we get
these then we'll run the tests, otherwise they are skipped
"""
if backend == "postgres":
backend = "postgresql+psycopg2"
elif backend == "mysql":
backend = "mysql+mysqldb"
else:
raise Exception("Unrecognized backend: '%s'" % backend)
return ("%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s"
% {'backend': backend, 'user': user, 'passwd': passwd,
'database': database})
def _is_backend_avail(backend, user, passwd, database):
try:
connect_uri = _get_connect_string(backend, user, passwd, database)
engine = sqlalchemy.create_engine(connect_uri)
connection = engine.connect()
except Exception:
# intentionally catch all to handle exceptions even if we don't
# have any backend code loaded.
return False
else:
connection.close()
engine.dispose()
return True
def _have_mysql(user, passwd, database):
present = os.environ.get('TEST_MYSQL_PRESENT')
if present is None:
return _is_backend_avail('mysql', user, passwd, database)
return utils.is_backend_avail(backend='mysql',
user=user,
passwd=passwd,
database=database)
return present.lower() in ('', 'true')
def _have_postgresql(user, passwd, database):
present = os.environ.get('TEST_POSTGRESQL_PRESENT')
if present is None:
return _is_backend_avail('postgres', user, passwd, database)
return utils.is_backend_avail(backend='postgres',
user=user,
passwd=passwd,
database=database)
return present.lower() in ('', 'true')
def get_db_connection_info(conn_pieces):
database = conn_pieces.path.strip('/')
loc_pieces = conn_pieces.netloc.split('@')
host = loc_pieces[1]
auth_pieces = loc_pieces[0].split(':')
user = auth_pieces[0]
password = ""
if len(auth_pieces) > 1:
password = auth_pieces[1].strip()
return (user, password, database, host)
def _set_db_lock(lock_path=None, lock_prefix=None):
def decorator(f):
@functools.wraps(f)
@ -100,10 +60,10 @@ def _set_db_lock(lock_path=None, lock_prefix=None):
path = lock_path or os.environ.get("KEYSTONE_LOCK_PATH")
lock = lockfile.FileLock(os.path.join(path, lock_prefix))
with lock:
LOG.debug(_('Got lock "%s"') % f.__name__)
LOG.debug('Got lock "%s"' % f.__name__)
return f(*args, **kwargs)
finally:
LOG.debug(_('Lock released "%s"') % f.__name__)
LOG.debug('Lock released "%s"' % f.__name__)
return wrapper
return decorator
@ -166,7 +126,10 @@ class BaseMigrationTestCase(test.BaseTestCase):
"Failed to run: %s\n%s" % (cmd, output))
def _reset_pg(self, conn_pieces):
(user, password, database, host) = get_db_connection_info(conn_pieces)
(user,
password,
database,
host) = utils.get_db_connection_info(conn_pieces)
os.environ['PGPASSWORD'] = password
os.environ['PGUSER'] = user
# note(boris-42): We must create and drop database, we can't
@ -190,7 +153,7 @@ class BaseMigrationTestCase(test.BaseTestCase):
def _reset_databases(self):
for key, engine in self.engines.items():
conn_string = self.test_databases[key]
conn_pieces = urlutils.urlparse(conn_string)
conn_pieces = parse.urlparse(conn_string)
engine.dispose()
if conn_string.startswith('sqlite'):
# We can just delete the SQLite database, which is
@ -205,7 +168,7 @@ class BaseMigrationTestCase(test.BaseTestCase):
# the MYSQL database, which is easier and less error-prone
# than using SQLAlchemy to do this via MetaData...trust me.
(user, password, database, host) = \
get_db_connection_info(conn_pieces)
utils.get_db_connection_info(conn_pieces)
sql = ("drop database if exists %(db)s; "
"create database %(db)s;") % {'db': database}
cmd = ("mysql -u \"%(user)s\" -p\"%(password)s\" -h %(host)s "
@ -301,6 +264,6 @@ class WalkVersionsMixin(object):
if check:
check(engine, data)
except Exception:
LOG.error("Failed to migrate to version %s on engine %s" %
LOG.error(_LE("Failed to migrate to version %s on engine %s") %
(version, engine))
raise


@ -16,6 +16,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import logging
import re
from migrate.changeset import UniqueConstraint
@ -29,6 +30,7 @@ from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import or_
from sqlalchemy.sql.expression import literal_column
from sqlalchemy.sql.expression import UpdateBase
from sqlalchemy.sql import select
@ -36,9 +38,9 @@ from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.types import NullType
from keystone.openstack.common.gettextutils import _
from keystone.openstack.common import log as logging
from keystone.openstack.common import context as request_context
from keystone.openstack.common.db.sqlalchemy import models
from keystone.openstack.common.gettextutils import _, _LI, _LW
from keystone.openstack.common import timeutils
@ -94,7 +96,7 @@ def paginate_query(query, model, limit, sort_keys, marker=None,
if 'id' not in sort_keys:
# TODO(justinsb): If this ever gives a false-positive, check
# the actual primary key, rather than assuming its id
LOG.warning(_('Id not in sort_keys; is sort_keys unique?'))
LOG.warning(_LW('Id not in sort_keys; is sort_keys unique?'))
assert(not (sort_dir and sort_dirs))
@ -157,6 +159,98 @@ def paginate_query(query, model, limit, sort_keys, marker=None,
return query
def _read_deleted_filter(query, db_model, read_deleted):
if 'deleted' not in db_model.__table__.columns:
raise ValueError(_("There is no `deleted` column in `%s` table. "
"Project doesn't use soft-deleted feature.")
% db_model.__name__)
default_deleted_value = db_model.__table__.c.deleted.default.arg
if read_deleted == 'no':
query = query.filter(db_model.deleted == default_deleted_value)
elif read_deleted == 'yes':
pass # omit the filter to include deleted and active
elif read_deleted == 'only':
query = query.filter(db_model.deleted != default_deleted_value)
else:
raise ValueError(_("Unrecognized read_deleted value '%s'")
% read_deleted)
return query
def _project_filter(query, db_model, context, project_only):
if project_only and 'project_id' not in db_model.__table__.columns:
raise ValueError(_("There is no `project_id` column in `%s` table.")
% db_model.__name__)
if request_context.is_user_context(context) and project_only:
if project_only == 'allow_none':
is_none = None
query = query.filter(or_(db_model.project_id == context.project_id,
db_model.project_id == is_none))
else:
query = query.filter(db_model.project_id == context.project_id)
return query
def model_query(context, model, session, args=None, project_only=False,
read_deleted=None):
"""Query helper that accounts for context's `read_deleted` field.
:param context: context to query under
:param model: Model to query. Must be a subclass of ModelBase.
:type model: models.ModelBase
:param session: The session to use.
:type session: sqlalchemy.orm.session.Session
:param args: Arguments to query. If None - model is used.
:type args: tuple
:param project_only: If present and context is user-type, then restrict
query to match the context's project_id. If set to
'allow_none', restriction includes project_id = None.
:type project_only: bool
:param read_deleted: If present, overrides context's read_deleted field.
:type read_deleted: bool
Usage:
    .. code:: python
result = (utils.model_query(context, models.Instance, session=session)
.filter_by(uuid=instance_uuid)
.all())
query = utils.model_query(
context, Node,
session=session,
args=(func.count(Node.id), func.sum(Node.ram))
).filter_by(project_id=project_id)
"""
if not read_deleted:
if hasattr(context, 'read_deleted'):
# NOTE(viktors): some projects use `read_deleted` attribute in
# their contexts instead of `show_deleted`.
read_deleted = context.read_deleted
else:
read_deleted = context.show_deleted
if not issubclass(model, models.ModelBase):
raise TypeError(_("model should be a subclass of ModelBase"))
query = session.query(model) if not args else session.query(*args)
query = _read_deleted_filter(query, model, read_deleted)
query = _project_filter(query, model, context, project_only)
return query
def get_table(engine, name):
"""Returns an sqlalchemy table dynamically from db.
@ -277,8 +371,8 @@ def drop_old_duplicate_entries_from_table(migrate_engine, table_name,
rows_to_delete_select = select([table.c.id]).where(delete_condition)
for row in migrate_engine.execute(rows_to_delete_select).fetchall():
LOG.info(_("Deleting duplicated row with id: %(id)s from table: "
"%(table)s") % dict(id=row[0], table=table_name))
LOG.info(_LI("Deleting duplicated row with id: %(id)s from table: "
"%(table)s") % dict(id=row[0], table=table_name))
if use_soft_delete:
delete_statement = table.update().\
@ -497,3 +591,52 @@ def _change_deleted_column_type_to_id_type_sqlite(migrate_engine, table_name,
where(new_table.c.deleted == deleted).\
values(deleted=default_deleted_value).\
execute()
def get_connect_string(backend, database, user=None, passwd=None):
"""Get database connection
Try to get a connection with a very specific set of values, if we get
these then we'll run the tests, otherwise they are skipped
"""
args = {'backend': backend,
'user': user,
'passwd': passwd,
'database': database}
if backend == 'sqlite':
template = '%(backend)s:///%(database)s'
else:
template = "%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s"
return template % args
def is_backend_avail(backend, database, user=None, passwd=None):
try:
connect_uri = get_connect_string(backend=backend,
database=database,
user=user,
passwd=passwd)
engine = sqlalchemy.create_engine(connect_uri)
connection = engine.connect()
except Exception:
# intentionally catch all to handle exceptions even if we don't
# have any backend code loaded.
return False
else:
connection.close()
engine.dispose()
return True
def get_db_connection_info(conn_pieces):
database = conn_pieces.path.strip('/')
loc_pieces = conn_pieces.netloc.split('@')
host = loc_pieces[1]
auth_pieces = loc_pieces[0].split(':')
user = auth_pieces[0]
password = ""
if len(auth_pieces) > 1:
password = auth_pieces[1].strip()
return (user, password, database, host)
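A round-trip sketch of these helpers, reusing the openstack_citest CI credentials from above:

from six.moves.urllib import parse

uri = get_connect_string(backend='postgresql',
                         database='openstack_citest',
                         user='openstack_citest',
                         passwd='openstack_citest')
# 'postgresql://openstack_citest:openstack_citest@localhost/openstack_citest'
pieces = parse.urlparse(uri)
user, password, database, host = get_db_connection_info(pieces)
# ('openstack_citest', 'openstack_citest', 'openstack_citest', 'localhost')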


@ -23,6 +23,7 @@ Usual usage in an openstack.common module:
"""
import copy
import functools
import gettext
import locale
from logging import handlers
@ -35,6 +36,17 @@ import six
_localedir = os.environ.get('keystone'.upper() + '_LOCALEDIR')
_t = gettext.translation('keystone', localedir=_localedir, fallback=True)
# We use separate translation catalogs for each log level, so set up a
# mapping between the log level name and the translator. The domain
# for the log level is project_name + "-log-" + log_level so messages
# for each level end up in their own catalog.
_t_log_levels = dict(
(level, gettext.translation('keystone' + '-log-' + level,
localedir=_localedir,
fallback=True))
for level in ['info', 'warning', 'error', 'critical']
)
_AVAILABLE_LANGUAGES = {}
USE_LAZY = False
@ -60,6 +72,28 @@ def _(msg):
return _t.ugettext(msg)
def _log_translation(msg, level):
"""Build a single translation of a log message
"""
if USE_LAZY:
return Message(msg, domain='keystone' + '-log-' + level)
else:
translator = _t_log_levels[level]
if six.PY3:
return translator.gettext(msg)
return translator.ugettext(msg)
# Translators for log levels.
#
# The abbreviated names are meant to reflect the usual use of a short
# name like '_'. The "L" is for "log" and the other letter comes from
# the level.
_LI = functools.partial(_log_translation, level='info')
_LW = functools.partial(_log_translation, level='warning')
_LE = functools.partial(_log_translation, level='error')
_LC = functools.partial(_log_translation, level='critical')
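Call sites then pair each translator with the matching log level, as in these lines from this change:

LOG.warning(_LW('Database server has gone away: %s') % ex)
LOG.info(_LI('MySQL server mode set to %s') % realmode)
LOG.exception(_LE('DB exception wrapped.'))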
def install(domain, lazy=False):
"""Install a _() function using the given translation domain.
@ -118,7 +152,8 @@ class Message(six.text_type):
and can be treated as such.
"""
def __new__(cls, msgid, msgtext=None, params=None, domain='keystone', *args):
def __new__(cls, msgid, msgtext=None, params=None,
domain='keystone', *args):
"""Create a new Message object.
In order for translation to work gettext requires a message ID, this
@ -193,10 +228,11 @@ class Message(six.text_type):
# When we mod a Message we want the actual operation to be performed
# by the parent class (i.e. unicode()), the only thing we do here is
# save the original msgid and the parameters in case of a translation
unicode_mod = super(Message, self).__mod__(other)
params = self._sanitize_mod_params(other)
unicode_mod = super(Message, self).__mod__(params)
modded = Message(self.msgid,
msgtext=unicode_mod,
params=self._sanitize_mod_params(other),
params=params,
domain=self.domain)
return modded
@ -235,8 +271,17 @@ class Message(six.text_type):
params = self._copy_param(dict_param)
else:
params = {}
# Save our existing parameters as defaults to protect
# ourselves from losing values if we are called through an
# (erroneous) chain that builds a valid Message with
# arguments, and then does something like "msg % kwds"
# where kwds is an empty dictionary.
src = {}
if isinstance(self.params, dict):
src.update(self.params)
src.update(dict_param)
for key in keys:
params[key] = self._copy_param(dict_param[key])
params[key] = self._copy_param(src[key])
return params
@ -287,9 +332,27 @@ def get_available_languages(domain):
list_identifiers = (getattr(localedata, 'list', None) or
getattr(localedata, 'locale_identifiers'))
locale_identifiers = list_identifiers()
for i in locale_identifiers:
if find(i) is not None:
language_list.append(i)
# NOTE(luisg): Babel>=1.0,<1.3 has a bug where some OpenStack supported
# locales (e.g. 'zh_CN', and 'zh_TW') aren't supported even though they
# are perfectly legitimate locales:
# https://github.com/mitsuhiko/babel/issues/37
# In Babel 1.3 they fixed the bug and they support these locales, but
# they are still not explicitly "listed" by locale_identifiers().
# That is why we add the locales here explicitly if necessary so that
# they are listed as supported.
aliases = {'zh': 'zh_CN',
'zh_Hant_HK': 'zh_HK',
'zh_Hant': 'zh_TW',
'fil': 'tl_PH'}
for (locale, alias) in six.iteritems(aliases):
if locale in language_list and alias not in language_list:
language_list.append(alias)
_AVAILABLE_LANGUAGES[domain] = language_list
return copy.copy(language_list)


@ -16,7 +16,6 @@ from keystone.common import sql
from keystone.common.sql import migration_helpers
from keystone import exception
from keystone.openstack.common.db.sqlalchemy import migration
from keystone.openstack.common.db.sqlalchemy import session as db_session
from keystone.policy.backends import rules
@ -33,11 +32,12 @@ class Policy(rules.Policy):
# Internal interface to manage the database
def db_sync(self, version=None):
migration.db_sync(
migration_helpers.find_migrate_repo(), version=version)
sql.get_engine(), migration_helpers.find_migrate_repo(),
version=version)
@sql.handle_conflicts(conflict_type='policy')
def create_policy(self, policy_id, policy):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = PolicyModel.from_dict(policy)
@ -46,7 +46,7 @@ class Policy(rules.Policy):
return ref.to_dict()
def list_policies(self):
session = db_session.get_session()
session = sql.get_session()
refs = session.query(PolicyModel).all()
return [ref.to_dict() for ref in refs]
@ -59,13 +59,13 @@ class Policy(rules.Policy):
return ref
def get_policy(self, policy_id):
session = db_session.get_session()
session = sql.get_session()
return self._get_policy(session, policy_id).to_dict()
@sql.handle_conflicts(conflict_type='policy')
def update_policy(self, policy_id, policy):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = self._get_policy(session, policy_id)
@ -79,7 +79,7 @@ class Policy(rules.Policy):
return ref.to_dict()
def delete_policy(self, policy_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = self._get_policy(session, policy_id)


@ -32,7 +32,6 @@ import testtools
from testtools import testcase
import webob
from keystone.openstack.common.db.sqlalchemy import migration
from keystone.openstack.common.fixture import mockpatch
from keystone.openstack.common import gettextutils
@ -61,7 +60,8 @@ from keystone.common import utils as common_utils
from keystone import config
from keystone import exception
from keystone import notifications
from keystone.openstack.common.db.sqlalchemy import session
from keystone.openstack.common.db import options as db_options
from keystone.openstack.common.db.sqlalchemy import migration
from keystone.openstack.common.fixture import config as config_fixture
from keystone.openstack.common import log
from keystone.openstack.common import timeutils
@ -110,7 +110,7 @@ class dirs:
# keystone.common.sql.initialize() for testing.
def _initialize_sql_session():
db_file = dirs.tmp('test.db')
session.set_defaults(
db_options.set_defaults(
sql_connection="sqlite:///" + db_file,
sqlite_db=db_file)
@ -157,7 +157,8 @@ def setup_database():
if os.path.exists(db):
os.unlink(db)
if not os.path.exists(pristine):
migration.db_sync((migration_helpers.find_migrate_repo()))
migration.db_sync(sql.get_engine(),
migration_helpers.find_migrate_repo())
migration_helpers.sync_database_to_version(extension='revoke')
shutil.copyfile(db, pristine)
else:
@ -165,7 +166,7 @@ def setup_database():
def teardown_database():
session.cleanup()
sql.cleanup()
@atexit.register
@ -405,8 +406,8 @@ class TestCase(testtools.TestCase):
# The credential backend only supports SQL, so we always have to load
# the tables.
self.engine = session.get_engine()
self.addCleanup(session.cleanup)
self.engine = sql.get_engine()
self.addCleanup(sql.cleanup)
sql.ModelBase.metadata.create_all(bind=self.engine)
self.addCleanup(sql.ModelBase.metadata.drop_all, bind=self.engine)


@ -14,6 +14,7 @@
import uuid
from keystone.common import sql
from keystone.common.sql import migration_helpers
from keystone import contrib
from keystone.openstack.common.db.sqlalchemy import migration
@ -38,8 +39,8 @@ class TestExtensionCase(test_v3.RestfulTestCase):
package_name = '.'.join((contrib.__name__, self.EXTENSION_NAME))
package = importutils.import_module(package_name)
abs_path = migration_helpers.find_migrate_repo(package)
migration.db_version_control(abs_path)
migration.db_sync(abs_path)
migration.db_version_control(sql.get_engine(), abs_path)
migration.db_sync(sql.get_engine(), abs_path)
def setUp(self):
super(TestExtensionCase, self).setUp()


@ -25,7 +25,6 @@ from keystone.common import sql
from keystone import config
from keystone import exception
from keystone import identity
from keystone.openstack.common.db.sqlalchemy import session
from keystone.openstack.common.fixture import moxstubout
from keystone import tests
from keystone.tests import default_fixtures
@ -1204,8 +1203,8 @@ class LdapIdentitySqlAssignment(tests.TestCase, BaseLDAPIdentity):
self.clear_database()
self.load_backends()
cache.configure_cache_region(cache.REGION)
self.engine = session.get_engine()
self.addCleanup(session.cleanup)
self.engine = sql.get_engine()
self.addCleanup(sql.cleanup)
sql.ModelBase.metadata.create_all(bind=self.engine)
self.addCleanup(sql.ModelBase.metadata.drop_all, bind=self.engine)
@ -1278,8 +1277,8 @@ class MultiLDAPandSQLIdentity(tests.TestCase, BaseLDAPIdentity):
self._set_config()
self.load_backends()
self.engine = session.get_engine()
self.addCleanup(session.cleanup)
self.engine = sql.get_engine()
self.addCleanup(sql.cleanup)
sql.ModelBase.metadata.create_all(bind=self.engine)
self.addCleanup(sql.ModelBase.metadata.drop_all, bind=self.engine)


@ -23,7 +23,6 @@ from keystone import config
from keystone import exception
from keystone.identity.backends import sql as identity_sql
from keystone.openstack.common.db import exception as db_exception
from keystone.openstack.common.db.sqlalchemy import session as db_session
from keystone.openstack.common.fixture import moxstubout
from keystone import tests
from keystone.tests import default_fixtures
@ -125,7 +124,7 @@ class SqlModels(SqlTests):
class SqlIdentity(SqlTests, test_backend.IdentityTests):
def test_password_hashed(self):
session = db_session.get_session()
session = sql.get_session()
user_ref = self.identity_api._get_user(session, self.user_foo['id'])
self.assertNotEqual(user_ref['password'], self.user_foo['password'])
@ -314,7 +313,7 @@ class SqlIdentity(SqlTests, test_backend.IdentityTests):
'password': uuid.uuid4().hex}
self.identity_api.create_user(user_id, user)
session = db_session.get_session()
session = sql.get_session()
query = session.query(identity_sql.User)
query = query.filter_by(id=user_id)
raw_user_ref = query.one()
@ -336,14 +335,14 @@ class SqlToken(SqlTests, test_backend.TokenTests):
fixture = self.useFixture(moxstubout.MoxStubout())
self.mox = fixture.mox
tok = token_sql.Token()
session = db_session.get_session()
session = sql.get_session()
q = session.query(token_sql.TokenModel.id,
token_sql.TokenModel.expires)
self.mox.StubOutWithMock(session, 'query')
session.query(token_sql.TokenModel.id,
token_sql.TokenModel.expires).AndReturn(q)
self.mox.StubOutWithMock(db_session, 'get_session')
db_session.get_session().AndReturn(session)
self.mox.StubOutWithMock(sql, 'get_session')
sql.get_session().AndReturn(session)
self.mox.ReplayAll()
tok.list_revoked_tokens()


@ -18,7 +18,6 @@ import webob
from keystone.common import sql
from keystone import config
from keystone.openstack.common.db.sqlalchemy import session
from keystone.openstack.common import jsonutils
from keystone.openstack.common import timeutils
from keystone import tests
@ -43,8 +42,8 @@ class CompatTestCase(tests.NoModule, tests.TestCase):
# credential api makes some very SQL specific assumptions that should
# be addressed allowing for non-SQL based testing to occur.
self.load_backends()
self.engine = session.get_engine()
self.addCleanup(session.cleanup)
self.engine = sql.get_engine()
self.addCleanup(sql.cleanup)
self.addCleanup(sql.ModelBase.metadata.drop_all,
bind=self.engine)
sql.ModelBase.metadata.create_all(bind=self.engine)


@ -44,7 +44,7 @@ from keystone import config
from keystone import credential
from keystone import exception
from keystone.openstack.common.db.sqlalchemy import migration
from keystone.openstack.common.db.sqlalchemy import session
from keystone.openstack.common.db.sqlalchemy import session as db_session
from keystone import tests
from keystone.tests import default_fixtures
@ -75,8 +75,8 @@ class SqlMigrateBase(tests.TestCase):
self.config(self.config_files())
# create and share a single sqlalchemy engine for testing
self.engine = session.get_engine()
self.Session = session.get_maker(self.engine, autocommit=False)
self.engine = sql.get_engine()
self.Session = db_session.get_maker(self.engine, autocommit=False)
self.initialize_sql()
self.repo_path = migration_helpers.find_migrate_repo(
@ -94,7 +94,7 @@ class SqlMigrateBase(tests.TestCase):
autoload=True)
self.downgrade(0)
table.drop(self.engine, checkfirst=True)
session.cleanup()
sql.cleanup()
super(SqlMigrateBase, self).tearDown()
def select_table(self, name):
@ -158,7 +158,7 @@ class SqlUpgradeTests(SqlMigrateBase):
self.assertTableDoesNotExist('user')
def test_start_version_0(self):
version = migration.db_version(self.repo_path, 0)
version = migration.db_version(sql.get_engine(), self.repo_path, 0)
self.assertEqual(version, 0, "DB is not at version 0")
def test_two_steps_forward_one_step_back(self):


@ -14,6 +14,7 @@ import random
import uuid
from keystone.auth import controllers as auth_controllers
from keystone.common import sql
from keystone.common.sql import migration_helpers
from keystone import config
from keystone import contrib
@ -45,8 +46,8 @@ class FederationTests(test_v3.RestfulTestCase):
package_name = '.'.join((contrib.__name__, self.EXTENSION_NAME))
package = importutils.import_module(package_name)
abs_path = migration_helpers.find_migrate_repo(package)
migration.db_version_control(abs_path)
migration.db_sync(abs_path)
migration.db_version_control(sql.get_engine(), abs_path)
migration.db_sync(sql.get_engine(), abs_path)
class FederatedIdentityProviderTests(FederationTests):


@ -17,6 +17,7 @@ import uuid
from six.moves import urllib
from keystone.common import sql
from keystone.common.sql import migration_helpers
from keystone import config
from keystone import contrib
@ -41,8 +42,8 @@ class OAuth1Tests(test_v3.RestfulTestCase):
package_name = '.'.join((contrib.__name__, self.EXTENSION_NAME))
package = importutils.import_module(package_name)
abs_path = migration_helpers.find_migrate_repo(package)
migration.db_version_control(abs_path)
migration.db_sync(abs_path)
migration.db_version_control(sql.get_engine(), abs_path)
migration.db_sync(sql.get_engine(), abs_path)
def setUp(self):
super(OAuth1Tests, self).setUp()


@ -17,7 +17,6 @@ import copy
from keystone.common import sql
from keystone import config
from keystone import exception
from keystone.openstack.common.db.sqlalchemy import session as db_session
from keystone.openstack.common import timeutils
from keystone import token
@ -45,7 +44,7 @@ class Token(token.Driver):
def get_token(self, token_id):
if token_id is None:
raise exception.TokenNotFound(token_id=token_id)
session = db_session.get_session()
session = sql.get_session()
token_ref = session.query(TokenModel).get(token_id)
if not token_ref or not token_ref.valid:
raise exception.TokenNotFound(token_id=token_id)
@ -60,13 +59,13 @@ class Token(token.Driver):
token_ref = TokenModel.from_dict(data_copy)
token_ref.valid = True
session = db_session.get_session()
session = sql.get_session()
with session.begin():
session.add(token_ref)
return token_ref.to_dict()
def delete_token(self, token_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
token_ref = session.query(TokenModel).get(token_id)
if not token_ref or not token_ref.valid:
@ -83,7 +82,7 @@ class Token(token.Driver):
or the trustor's user ID, so will use trust_id to query the tokens.
"""
session = db_session.get_session()
session = sql.get_session()
with session.begin():
now = timeutils.utcnow()
query = session.query(TokenModel)
@ -122,7 +121,7 @@ class Token(token.Driver):
return False
def _list_tokens_for_trust(self, trust_id):
session = db_session.get_session()
session = sql.get_session()
tokens = []
now = timeutils.utcnow()
query = session.query(TokenModel)
@ -136,7 +135,7 @@ class Token(token.Driver):
return tokens
def _list_tokens_for_user(self, user_id, tenant_id=None):
session = db_session.get_session()
session = sql.get_session()
tokens = []
now = timeutils.utcnow()
query = session.query(TokenModel)
@ -152,7 +151,7 @@ class Token(token.Driver):
def _list_tokens_for_consumer(self, user_id, consumer_id):
tokens = []
session = db_session.get_session()
session = sql.get_session()
with session.begin():
now = timeutils.utcnow()
query = session.query(TokenModel)
@ -178,7 +177,7 @@ class Token(token.Driver):
return self._list_tokens_for_user(user_id, tenant_id)
def list_revoked_tokens(self):
session = db_session.get_session()
session = sql.get_session()
tokens = []
now = timeutils.utcnow()
query = session.query(TokenModel.id, TokenModel.expires)
@ -211,7 +210,7 @@ class Token(token.Driver):
return batch_size
def flush_expired_tokens(self):
session = db_session.get_session()
session = sql.get_session()
dialect = session.bind.dialect.name
batch_size = self.token_flush_batch_size(dialect)
if batch_size > 0:


@ -14,7 +14,6 @@
from keystone.common import sql
from keystone import exception
from keystone.openstack.common.db.sqlalchemy import session as db_session
from keystone.openstack.common import timeutils
from keystone import trust
@ -47,7 +46,7 @@ class TrustRole(sql.ModelBase):
class Trust(trust.Driver):
@sql.handle_conflicts(conflict_type='trust')
def create_trust(self, trust_id, trust, roles):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = TrustModel.from_dict(trust)
ref['id'] = trust_id
@ -73,7 +72,7 @@ class Trust(trust.Driver):
@sql.handle_conflicts(conflict_type='trust')
def consume_use(self, trust_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
ref = (session.query(TrustModel).
with_lockmode('update').
@ -90,7 +89,7 @@ class Trust(trust.Driver):
raise exception.TrustUseLimitReached(trust_id=trust_id)
def get_trust(self, trust_id):
session = db_session.get_session()
session = sql.get_session()
ref = (session.query(TrustModel).
filter_by(deleted_at=None).
filter_by(id=trust_id).first())
@ -111,13 +110,13 @@ class Trust(trust.Driver):
@sql.handle_conflicts(conflict_type='trust')
def list_trusts(self):
session = db_session.get_session()
session = sql.get_session()
trusts = session.query(TrustModel).filter_by(deleted_at=None)
return [trust_ref.to_dict() for trust_ref in trusts]
@sql.handle_conflicts(conflict_type='trust')
def list_trusts_for_trustee(self, trustee_user_id):
session = db_session.get_session()
session = sql.get_session()
trusts = (session.query(TrustModel).
filter_by(deleted_at=None).
filter_by(trustee_user_id=trustee_user_id))
@ -125,7 +124,7 @@ class Trust(trust.Driver):
@sql.handle_conflicts(conflict_type='trust')
def list_trusts_for_trustor(self, trustor_user_id):
session = db_session.get_session()
session = sql.get_session()
trusts = (session.query(TrustModel).
filter_by(deleted_at=None).
filter_by(trustor_user_id=trustor_user_id))
@ -133,7 +132,7 @@ class Trust(trust.Driver):
@sql.handle_conflicts(conflict_type='trust')
def delete_trust(self, trust_id):
session = db_session.get_session()
session = sql.get_session()
with session.begin():
trust_ref = session.query(TrustModel).get(trust_id)
if not trust_ref: