trove/trove/db/sqlalchemy/api.py

171 lines
4.2 KiB
Python

# Copyright 2011 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import sqlalchemy.exc
from trove.common import exception
from trove.db.sqlalchemy import migration
from trove.db.sqlalchemy import session
def list(query_func, *args, **kwargs):
    """Execute ``query_func(*args, **kwargs)`` and return all rows.

    The session attached to the query is committed after the rows are
    fetched.

    :param query_func: callable returning a query object.
    :returns: whatever the query's ``all()`` returns.
    """
    q = query_func(*args, **kwargs)
    rows = q.all()
    owning_session = q.session
    owning_session.commit()
    return rows
def count(query, *args, **kwargs):
    """Return the row count of ``query(*args, **kwargs)``.

    The query's session is committed after counting.

    :param query: callable returning a query object (not a query itself).
    :returns: whatever the query's ``count()`` returns.
    """
    q = query(*args, **kwargs)
    total = q.count()
    owning_session = q.session
    owning_session.commit()
    return total
def first(query, *args, **kwargs):
    """Return the first row produced by ``query(*args, **kwargs)``.

    The query's session is committed after the row is fetched.

    :param query: callable returning a query object (not a query itself).
    :returns: whatever the query's ``first()`` returns.
    """
    q = query(*args, **kwargs)
    head = q.first()
    owning_session = q.session
    owning_session.commit()
    return head
def join(query, model, *args):
    """Build ``query(model).join(*args)`` and commit the query's session.

    NOTE(review): the value returned is whatever ``.join()`` returns. For a
    SQLAlchemy ``Query`` that is a new, unexecuted query, so the commit here
    runs before any rows are loaded. Confirm that this is intended.

    :param query: callable invoked as ``query(model)`` returning a query.
    :param model: the entity passed to ``query``.
    :param args: join targets/conditions forwarded to ``.join()``.
    """
    base = query(model)
    joined = base.join(*args)
    base.session.commit()
    return joined
def find_all(model, **conditions):
    """Return an unexecuted query over *model* filtered by ``conditions``.

    Unlike most helpers in this module, nothing is executed or committed.
    """
    matching = _query_by(model, **conditions)
    return matching
def find_all_by_limit(query_func, model, conditions, limit, marker=None,
                      marker_column=None):
    """Fetch one page of rows using marker-based pagination.

    The page query is built by ``_limits`` and executed with ``all()``.
    Its session is then committed.

    :param query_func: callable invoked as ``query_func(model, **conditions)``.
    :param model: model class being queried.
    :param conditions: dict of keyword conditions for ``query_func``.
    :param limit: maximum number of rows.
    :param marker: optional value; only rows past it are returned.
    :param marker_column: column compared with ``marker`` and used for
        ordering (``_limits`` falls back to ``model.id``).
    :returns: the rows from ``all()``.
    """
    page_query = _limits(query_func, model, conditions, limit, marker,
                         marker_column)
    page = page_query.all()
    page_query.session.commit()
    return page
def find_by(model, **kwargs):
    """Return the first *model* row matching the keyword conditions.

    The session of the query is committed after the fetch.

    :param model: model class to query.
    :param kwargs: ``filter_by``-style equality conditions.
    :returns: whatever the query's ``first()`` returns.
    """
    q = _query_by(model, **kwargs)
    match = q.first()
    q.session.commit()
    return match
def find_by_filter(model, **kwargs):
    """Return an unexecuted query over *model* using filters and conditions.

    The special keyword ``filters`` holds positional filter expressions
    (defaulting to none). Every other keyword is treated as a ``filter_by``
    equality condition.
    """
    filters = kwargs.get('filters', [])
    conditions = {key: val for key, val in kwargs.items() if key != 'filters'}
    return _query_by_filter(model, *filters, **conditions)
def save(model):
    """Persist *model* by merging it into a new session and flushing.

    :param model: the model instance to persist.
    :returns: the instance returned by ``Session.merge`` (attached to the
        session obtained here), which need not be the object passed in.
    :raises exception.DBConstraintError: when SQLAlchemy raises an
        ``IntegrityError``; the underlying DB error text is included.
    """
    try:
        db_session = session.get_session()
        # begin() as a context manager commits on normal exit (including the
        # return below) and rolls back if an exception escapes.
        with db_session.begin():
            # Rebinding ``model`` means the except branch reports the merged
            # instance's class name.
            model = db_session.merge(model)
            # Flush eagerly so constraint violations surface inside the try.
            db_session.flush()
            return model
    except sqlalchemy.exc.IntegrityError as error:
        raise exception.DBConstraintError(model_name=model.__class__.__name__,
                                          error=str(error.orig))
def delete(model):
    """Delete *model* from the database inside its own transaction.

    The instance is first merged into a new session, because it may belong
    to a different or closed session. The attached copy is then deleted and
    flushed.
    """
    db_session = session.get_session()
    with db_session.begin():
        attached = db_session.merge(model)
        db_session.delete(attached)
        db_session.flush()
def delete_all(query_func, model, **conditions):
    """Bulk-delete every *model* row matching ``conditions`` and commit.

    :param query_func: callable invoked as ``query_func(model, **conditions)``
        that returns a query object.
    """
    doomed = query_func(model, **conditions)
    doomed.delete()
    owning_session = doomed.session
    owning_session.commit()
def update(model, **values):
    """Assign each keyword argument onto *model* via item assignment.

    Nothing is persisted here; the caller is expected to save the model.
    """
    for key in values:
        model[key] = values[key]
def update_all(query_func, model, conditions, values):
    """Bulk-update every *model* row matching ``conditions`` and commit.

    :param query_func: callable invoked as ``query_func(model, **conditions)``
        that returns a query object.
    :param model: model class to update.
    :param conditions: dict of keyword conditions forwarded to ``query_func``.
    :param values: mapping of attribute/column to new value, applied to
        every matched row.
    """
    query = query_func(model, **conditions)
    # Query.update() requires the values mapping. It was previously called
    # with no arguments, so ``values`` was ignored and the call failed.
    query.update(values)
    query.session.commit()
def configure_db(options, *plugins):
    """Configure the core database session, then each plugin's mapper.

    :param options: database options passed to ``session.configure_db``.
    :param plugins: objects exposing a ``mapper`` attribute; see
        ``configure_db_for_plugins``.
    """
    session.configure_db(options)
    configure_db_for_plugins(options, *plugins)
def configure_db_for_plugins(options, *plugins):
    """Run ``session.configure_db`` once per plugin.

    Each plugin's ``mapper`` is passed as ``models_mapper``.
    """
    for plugin_obj in plugins:
        mapper = plugin_obj.mapper
        session.configure_db(options, models_mapper=mapper)
def drop_db(options):
    """Drop the database; delegates to ``session.drop_db(options)``."""
    session.drop_db(options)
def clean_db():
    """Clean the database; delegates to ``session.clean_db()``."""
    session.clean_db()
def db_sync(options, version=None, repo_path=None):
    """Sync the schema via ``migration.db_sync``.

    :param options: database options.
    :param version: target schema version (presumably ``None`` means
        latest; confirm in the migration module).
    :param repo_path: optional migration repository path.
    """
    migration.db_sync(options, version, repo_path)
def db_upgrade(options, version=None, repo_path=None):
    """Upgrade the schema via ``migration.upgrade``.

    :param options: database options.
    :param version: target schema version (presumably ``None`` means
        latest; confirm in the migration module).
    :param repo_path: optional migration repository path.
    """
    migration.upgrade(options, version, repo_path)
def db_reset(options, *plugins):
    """Drop the database, re-sync the schema and reconfigure the session.

    :param options: database options passed to every step.
    :param plugins: plugins (objects with a ``mapper`` attribute) to
        reconfigure after the reset. With no plugins this behaves exactly as
        before.
    """
    drop_db(options)
    db_sync(options)
    # Forward plugins. They used to be accepted but silently discarded, so
    # plugin mappers were never reconfigured after a reset.
    configure_db(options, *plugins)
def _base_query(cls):
    """Return a new query over *cls*.

    The query is bound to a session from ``session.get_session()``.
    """
    return session.get_session().query(cls)
def _query_by(cls, **conditions):
    """Return a query over *cls*, narrowed by ``filter_by(**conditions)``.

    ``filter_by`` is applied only when any conditions are given.
    """
    base = _base_query(cls)
    if not conditions:
        return base
    return base.filter_by(**conditions)
def _query_by_filter(cls, *filters, **conditions):
    """Return a query over *cls* with keyword conditions and filters.

    Keyword conditions are applied first via ``filter_by``. The positional
    filter expressions are then applied via ``filter``, if any were given.
    """
    narrowed = _query_by(cls, **conditions)
    return narrowed.filter(*filters) if filters else narrowed
def _limits(query_func, model, conditions, limit, marker, marker_column=None):
query = query_func(model, **conditions)
marker_column = marker_column or model.id
if marker:
query = query.filter(marker_column > marker)
return query.order_by(marker_column).limit(limit)